You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

205 lines
7.9 KiB

import hashlib
import os
import site
import sys
import tarfile
import urllib.parse
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional
from wasabi import msg
from ..errors import Errors
from ..util import check_spacy_env_vars, download_file, ensure_pathy, get_checksum
from ..util import get_hash, make_tempdir, upload_file
if TYPE_CHECKING:
from cloudpathlib import CloudPath
class RemoteStorage:
"""Push and pull outputs to and from a remote file storage.
Remotes can be anything that `smart_open` can support: AWS, GCS, file system,
ssh, etc.
"""
def __init__(self, project_root: Path, url: str, *, compression="gz"):
self.root = project_root
self.url = ensure_pathy(url)
self.compression = compression
def push(self, path: Path, command_hash: str, content_hash: str) -> "CloudPath":
"""Compress a file or directory within a project and upload it to a remote
storage. If an object exists at the full URL, nothing is done.
Within the remote storage, files are addressed by their project path
(url encoded) and two user-supplied hashes, representing their creation
context and their file contents. If the URL already exists, the data is
not uploaded. Paths are archived and compressed prior to upload.
"""
loc = self.root / path
if not loc.exists():
raise IOError(f"Cannot push {loc}: does not exist.")
url = self.make_url(path, command_hash, content_hash)
if url.exists():
return url
tmp: Path
with make_tempdir() as tmp:
tar_loc = tmp / self.encode_name(str(path))
mode_string = f"w:{self.compression}" if self.compression else "w"
with tarfile.open(tar_loc, mode=mode_string) as tar_file:
tar_file.add(str(loc), arcname=str(path))
upload_file(tar_loc, url)
return url
def pull(
self,
path: Path,
*,
command_hash: Optional[str] = None,
content_hash: Optional[str] = None,
) -> Optional["CloudPath"]:
"""Retrieve a file from the remote cache. If the file already exists,
nothing is done.
If the command_hash and/or content_hash are specified, only matching
results are returned. If no results are available, an error is raised.
"""
dest = self.root / path
if dest.exists():
return None
url = self.find(path, command_hash=command_hash, content_hash=content_hash)
if url is None:
return url
else:
# Make sure the destination exists
if not dest.parent.exists():
dest.parent.mkdir(parents=True)
tmp: Path
with make_tempdir() as tmp:
tar_loc = tmp / url.parts[-1]
download_file(url, tar_loc)
mode_string = f"r:{self.compression}" if self.compression else "r"
with tarfile.open(tar_loc, mode=mode_string) as tar_file:
# This requires that the path is added correctly, relative
# to root. This is how we set things up in push()
# Disallow paths outside the current directory for the tar
# file (CVE-2007-4559, directory traversal vulnerability)
def is_within_directory(directory, target):
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory
def safe_extract(tar, path):
for member in tar.getmembers():
member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path):
raise ValueError(Errors.E201)
if sys.version_info >= (3, 12):
tar.extractall(path, filter="data")
else:
tar.extractall(path)
safe_extract(tar_file, self.root)
return url
def find(
self,
path: Path,
*,
command_hash: Optional[str] = None,
content_hash: Optional[str] = None,
) -> Optional["CloudPath"]:
"""Find the best matching version of a file within the storage,
or `None` if no match can be found. If both the creation and content hash
are specified, only exact matches will be returned. Otherwise, the most
recent matching file is preferred.
"""
name = self.encode_name(str(path))
urls = []
if command_hash is not None and content_hash is not None:
url = self.url / name / command_hash / content_hash
urls = [url] if url.exists() else []
elif command_hash is not None:
if (self.url / name / command_hash).exists():
urls = list((self.url / name / command_hash).iterdir())
else:
if (self.url / name).exists():
for sub_dir in (self.url / name).iterdir():
urls.extend(sub_dir.iterdir())
if content_hash is not None:
urls = [url for url in urls if url.parts[-1] == content_hash]
if len(urls) >= 2:
try:
urls.sort(key=lambda x: x.stat().st_mtime)
except Exception:
msg.warn(
"Unable to sort remote files by last modified. The file(s) "
"pulled from the cache may not be the most recent."
)
return urls[-1] if urls else None
def make_url(self, path: Path, command_hash: str, content_hash: str) -> "CloudPath":
"""Construct a URL from a subpath, a creation hash and a content hash."""
return self.url / self.encode_name(str(path)) / command_hash / content_hash
def encode_name(self, name: str) -> str:
"""Encode a subpath into a URL-safe name."""
return urllib.parse.quote_plus(name)
def get_content_hash(loc: Path) -> str:
return get_checksum(loc)
def get_command_hash(
site_hash: str, env_hash: str, deps: List[Path], cmd: List[str]
) -> str:
"""Create a hash representing the execution of a command. This includes the
currently installed packages, whatever environment variables have been marked
as relevant, and the command.
"""
check_spacy_env_vars()
dep_checksums = [get_checksum(dep) for dep in sorted(deps)]
hashes = [site_hash, env_hash] + dep_checksums
hashes.extend(cmd)
creation_bytes = "".join(hashes).encode("utf8")
return hashlib.md5(creation_bytes).hexdigest()
def get_site_hash():
"""Hash the current Python environment's site-packages contents, including
the name and version of the libraries. The list we're hashing is what
`pip freeze` would output.
"""
site_dirs = site.getsitepackages()
if site.ENABLE_USER_SITE:
site_dirs.extend(site.getusersitepackages())
packages = set()
for site_dir in site_dirs:
site_dir = Path(site_dir)
for subpath in site_dir.iterdir():
if subpath.parts[-1].endswith("dist-info"):
packages.add(subpath.parts[-1].replace(".dist-info", ""))
package_bytes = "".join(sorted(packages)).encode("utf8")
return hashlib.md5sum(package_bytes).hexdigest()
def get_env_hash(env: Dict[str, str]) -> str:
"""Construct a hash of the environment variables that will be passed into
the commands.
Values in the env dict may be references to the current os.environ, using
the syntax $ENV_VAR to mean os.environ[ENV_VAR]
"""
env_vars = {}
for key, value in env.items():
if value.startswith("$"):
env_vars[key] = os.environ.get(value[1:], "")
else:
env_vars[key] = value
return get_hash(env_vars)