from __future__ import annotations
from bisect import bisect
from collections.abc import Iterable, Iterator
import datetime
from functools import lru_cache
from importlib.metadata import version as importlib_version
import inspect
import io
import itertools
import json
from mimetypes import guess_type
import os
import os.path as op
from pathlib import Path, PurePath, PurePosixPath
import pdb
import platform
import re
import shutil
import subprocess
import sys
from time import sleep
import traceback
import types
from typing import IO, Any, List, Optional, Protocol, TypeVar, Union
import dateutil.parser
from multidict import MultiDict # dependency of yarl
from pydantic import BaseModel, Field
import requests
import ruamel.yaml
from semantic_version import Version
from yarl import URL
from . import __version__, get_logger
from .consts import DandiInstance, known_instances, known_instances_rev
from .exceptions import BadCliVersionError, CliVersionTooOldError
# Type alias for values accepted where a filesystem path is expected
AnyPath = Union[str, Path]
# Logger obtained from the package-level ``get_logger()`` helper
lgr = get_logger()
_sys_excepthook = sys.excepthook # Just in case we ever need original one
#
# Some useful variables
#
# Lower-cased operating system name as reported by ``platform.system()``
platform_system = platform.system().lower()
on_windows, on_osx, on_linux = (
    platform_system == name for name in ("windows", "darwin", "linux")
)
# True only on Windows when ``MSYSTEM`` starts with "MSYS" or "MING" and
# ``MSYS_NO_PATHCONV`` is not set in the environment
on_msys_tainted_paths = (
    on_windows
    and "MSYS_NO_PATHCONV" not in os.environ
    and os.environ.get("MSYSTEM", "").startswith(("MSYS", "MING"))
)
# Identification string combining the dandi, requests, and Python
# implementation versions (presumably sent as the HTTP User-Agent header --
# confirm against the code that consumes it)
USER_AGENT = (
    f"dandi/{__version__}"
    f" requests/{requests.__version__}"
    f" {platform.python_implementation()}/{platform.python_version()}"
)
class Hasher(Protocol):
    """Structural type for objects offering ``update()`` and ``hexdigest()``"""

    def update(self, data: bytes) -> None:
        """Feed ``data`` into the hasher"""

    def hexdigest(self) -> str:
        """Return the digest as a string"""
def is_interactive() -> bool:
    """Return True if stdin, stdout, and stderr are all attached to a tty"""
    # TODO: check on windows if a hasattr check would work correctly and add
    # value
    return all(stream.isatty() for stream in (sys.stdin, sys.stdout, sys.stderr))
def setup_exceptionhook(ipython: bool = False) -> None:
    """Replace ``sys.excepthook`` with a debugging-oriented handler.

    Parameters
    ----------
    ipython: bool, optional
        If True, install IPython's verbose ``FormattedTB`` handler, asking it
        to call pdb only when the session is interactive.  Otherwise install
        a handler that prints the traceback followed by an empty line and,
        when the session is interactive, starts ``pdb.post_mortem``.
    """
    if ipython:
        from IPython.core import ultratb  # type: ignore[import]

        sys.excepthook = ultratb.FormattedTB(
            mode="Verbose",
            # color_scheme='Linux',
            call_pdb=is_interactive(),
        )
        return

    def _print_then_debug(
        exc_type: type[BaseException],
        value: BaseException,
        tb: types.TracebackType | None,
    ) -> None:
        traceback.print_exception(exc_type, value, tb)
        print()
        if is_interactive():
            pdb.post_mortem(tb)

    sys.excepthook = _print_then_debug
def get_utcnow_datetime(microseconds: bool = True) -> datetime.datetime:
    """Return the current time as a datetime with time zone information.

    The current moment is obtained in UTC and then passed through
    ``astimezone()`` without arguments.  If a string representation is
    desired, call ``.isoformat()`` on the result.

    Parameters
    ----------
    microseconds: bool, optional
        If False, the microsecond component is set to zero.
    """
    now = datetime.datetime.now(datetime.timezone.utc).astimezone()
    return now if microseconds else now.replace(microsecond=0)
def is_same_time(
    *times: datetime.datetime | int | float | str,
    tolerance: float = 1e-6,
    strip_tzinfo: bool = False,
) -> bool:
    """Helper to compare time points for equality within a tolerance

    Every value is first normalized with `ensure_datetime()` (passing along
    ``strip_tzinfo`` and requesting UTC), and then every pair of normalized
    values is compared.

    Parameters
    ----------
    *times:
        At least two time points to compare
    tolerance: float, optional
        Seconds of difference between times to tolerate.  By default a
        difference of up to a microsecond is ok
    strip_tzinfo: bool, optional
        Forwarded to `ensure_datetime()`
    """
    assert len(times) >= 2
    utc = datetime.timezone.utc
    normalized = [ensure_datetime(t, strip_tzinfo=strip_tzinfo, tz=utc) for t in times]
    max_delta = datetime.timedelta(seconds=tolerance)
    for first, second in itertools.combinations(normalized, 2):
        # abs() spares us the negative-days form of a negative timedelta
        if abs(first - second) > max_delta:
            return False
    return True
def ensure_strtime(
    t: str | int | float | datetime.datetime, isoformat: bool = True
) -> str:
    """Ensure that a time is represented as a string

    Note: if `t` is already a string, it is returned as is, without any
    conversion or validation.  Numbers are converted with
    `ensure_datetime()`, where epoch time is assumed to be local (not utc).

    Parameters
    ----------
    isoformat: bool, optional
        If True, use .isoformat() and otherwise str().  With .isoformat()
        there is no space but T to separate date from time.

    Raises
    ------
    TypeError
        If `t` is of none of the supported types
    """
    if isinstance(t, str):
        return t
    moment = ensure_datetime(t) if isinstance(t, (int, float)) else t
    if not isinstance(moment, datetime.datetime):
        raise TypeError(f"Do not know how to convert {t!r} to string datetime")
    return moment.isoformat() if isoformat else str(moment)
def ensure_datetime(
    t: datetime.datetime | int | float | str,
    strip_tzinfo: bool = False,
    tz: datetime.tzinfo | None = None,
) -> datetime.datetime:
    """Ensure that a time is represented as a datetime

    Parameters
    ----------
    t: datetime, int, float, or str
        A datetime is used as is.  A number is interpreted as epoch time via
        ``datetime.fromtimestamp()`` and made timezone-aware with
        ``astimezone()``.  A string is parsed as ISO 8601.
    strip_tzinfo: bool, optional
        Applies only to str values: if True, time zone information present in
        the parsed string is dropped.
    tz: datetime.tzinfo, optional
        If given, the result is converted with ``astimezone(tz=tz)``.

    Raises
    ------
    TypeError
        If `t` is of none of the supported types
    """
    if isinstance(t, datetime.datetime):
        pass
    elif isinstance(t, (int, float)):
        t = datetime.datetime.fromtimestamp(t).astimezone()
    elif isinstance(t, str):
        # could be in different formats, for now parse as ISO.  The parser
        # comes from the module-level ``import dateutil.parser``.
        t = dateutil.parser.isoparse(t)
        if strip_tzinfo and t.tzinfo is not None:
            # TODO: check a proper way to handle this so we could account
            # for a possibly present tz
            t = t.replace(tzinfo=None)
    else:
        raise TypeError(f"Do not know how to convert {t!r} to datetime")
    if tz is not None:
        t = t.astimezone(tz=tz)
    return t
#
# Generic
#
def flatten(it: Iterable) -> Iterator:
    """Yield items flattened if list, tuple or a generator

    Nested lists, tuples, and generator objects are expanded recursively and
    lazily: a nested generator is consumed only as far as the caller iterates.
    Any other item (including strings and other iterables) is yielded as is.
    """
    for item in it:
        if isinstance(item, (list, tuple)) or inspect.isgenerator(item):
            # Recurse directly instead of materializing an intermediate list
            yield from flatten(item)
        else:
            yield item
def flattened(it: Iterable) -> list:
    """Return a list of the items produced by `flatten()` for ``it``"""
    return [*flatten(it)]
#
# Paths and files
#
def load_jsonl(filename: AnyPath) -> list:
    """Load json lines formatted file

    The file is read as UTF-8 regardless of the locale's preferred encoding,
    and every line is parsed with ``json.loads``.

    Returns
    -------
    list
        The parsed value of each line, in file order
    """
    with open(filename, encoding="utf-8") as f:
        return [json.loads(line) for line in f]
_VCS_NAMES = {".git", ".gitattributes", ".svn", ".bzr", ".hg"}


def find_files(
    regex: str,
    paths: AnyPath | Iterable[AnyPath] = os.curdir,
    exclude: str | None = None,
    exclude_dotfiles: bool = True,
    exclude_dotdirs: bool = True,
    exclude_vcs: bool = True,
    exclude_datalad: bool = False,
    dirs: bool = False,
    dirs_avoid: str | None = None,
) -> Iterator[str]:
    """Generator to find files matching regex

    Parameters
    ----------
    regex: string
        Regex to search target files.  Is not applied to filter out directories
    paths: string or list, optional
        Directories or files to search among (directories are searched
        recursively)
    exclude: string, optional
        Regex of paths to exclude
    exclude_dotfiles:
        If True, names starting with "." encountered while walking a
        directory are not reported.
    exclude_dotdirs:
        If True, does not descend into directories starting with ".".
    exclude_vcs:
        If True, excludes paths that have a component named like a commonly
        known VCS file or directory (see ``_VCS_NAMES``)
    exclude_datalad:
        If True, excludes paths that have a ``.datalad`` component
    dirs: bool, optional
        Whether to match directories as well as files
    dirs_avoid: string, optional
        Regex for directories to not recurse under (they might still be
        reported if `dirs=True`)

    Yields
    ------
    str
        Each matching path
    """
    # Compile once rather than once per visited directory
    matches_regex = re.compile(regex).search

    def exclude_path(path: str) -> bool:
        path = path.rstrip(op.sep)
        if exclude and re.search(exclude, path):
            return True
        parts = Path(path).parts
        if exclude_vcs and any(part in _VCS_NAMES for part in parts):
            return True
        if exclude_datalad and ".datalad" in parts:
            return True
        return False

    def good_file(path: str) -> bool:
        return bool(matches_regex(path)) and not exclude_path(path)

    if not isinstance(paths, (str, Path)):
        for entry in paths:
            if op.isdir(entry):
                yield from find_files(
                    regex,
                    paths=entry,
                    exclude=exclude,
                    exclude_dotfiles=exclude_dotfiles,
                    exclude_dotdirs=exclude_dotdirs,
                    exclude_vcs=exclude_vcs,
                    exclude_datalad=exclude_datalad,
                    dirs=dirs,
                    dirs_avoid=dirs_avoid,
                )
            elif good_file(str(entry)):
                yield str(entry)
            # Otherwise the entry did not match regex or was excluded
        return
    if op.isfile(paths):
        if good_file(str(paths)):
            yield str(paths)
        return
    for dirpath, dirnames, filenames in os.walk(paths):
        # Collect the names to report *before* pruning ``dirnames`` so that
        # directories we do not descend into can still be reported
        names = (dirnames + filenames) if dirs else filenames
        # TODO: might want to uniformize on windows to use '/'
        if exclude_dotfiles:
            names = [n for n in names if not n.startswith(".")]
        if exclude_dotdirs or dirs_avoid:
            # Prune in place: os.walk() consults this very list to decide
            # where to descend next
            dirnames[:] = [
                d
                for d in dirnames
                if not (
                    (exclude_dotdirs and d.startswith("."))
                    or (dirs_avoid and re.search(dirs_avoid, d))
                )
            ]
        strpaths = [op.join(dirpath, name) for name in names]
        for p in filter(matches_regex, strpaths):
            if exclude_path(p):
                continue
            if op.islink(p) and op.isdir(p):
                lgr.warning("%s: Ignoring unsupported symbolic link to directory", p)
            else:
                yield p
def list_paths(
    dirpath: AnyPath, dirs: bool = False, exclude_vcs: bool = True
) -> list[Path]:
    """Return a sorted list of `Path` objects found under ``dirpath``

    The search is delegated to `find_files()` with a match-everything regex
    and with dotfiles and dot-directories included.

    Parameters
    ----------
    dirpath:
        Directory (or file) to list
    dirs: bool, optional
        Forwarded to `find_files()`: whether to report directories too
    exclude_vcs: bool, optional
        Forwarded to `find_files()`
    """
    found = find_files(
        r".*",
        [dirpath],
        dirs=dirs,
        exclude_dotfiles=False,
        exclude_dotdirs=False,
        exclude_vcs=exclude_vcs,
    )
    return sorted(Path(p) for p in found)
# Tri-state cache: None until probed, then whether ``cp`` mentions --reflink
# in its help output.  Never probed on Windows.
_cp_supports_reflink: bool | None = False if on_windows else None


def copy_file(src: AnyPath, dst: AnyPath) -> None:
    """Copy file from src to dst

    On the first call (except on Windows) the output of ``cp --help`` is
    inspected for ``--reflink``, and the outcome is remembered in the
    module-level ``_cp_supports_reflink``.  When the option is available the
    copy is done by ``cp -f --reflink=auto``; otherwise by ``shutil.copy2()``.
    """
    global _cp_supports_reflink
    if _cp_supports_reflink is None:
        # Ignore command failures (e.g., if cp doesn't support --help), as the
        # command will still likely output its usage info.
        help_text = subprocess.run(
            ["cp", "--help"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        ).stdout
        _cp_supports_reflink = "--reflink" in help_text
    if not _cp_supports_reflink:
        shutil.copy2(src, dst)
        return
    subprocess.run(
        ["cp", "-f", "--reflink=auto", "--", str(src), str(dst)], check=True
    )
def move_file(src: AnyPath, dst: AnyPath) -> Any:
    """Move file from src to dst

    Both arguments are converted to `str` and handed to ``shutil.move()``,
    whose return value is passed back to the caller.
    """
    source, destination = str(src), str(dst)
    return shutil.move(source, destination)
def find_parent_directory_containing(
    filename: AnyPath, path: AnyPath | None = None
) -> Path | None:
    """Find the closest directory, going up from ``path``, containing filename

    Parameters
    ----------
    filename:
        Name to look for inside each candidate directory
    path: optional
        Where to start.  If not given (or empty), the current working
        directory is used.  A relative path is made absolute relative to cwd.

    Returns
    -------
    Path or None
        Absolute path of the first directory (``path`` itself, then each of
        its parents in turn) in which ``filename`` exists, or None if there
        is no such directory.
    """
    start = Path(path) if path else Path.cwd()
    if not start.is_absolute():
        start = start.absolute()
    for candidate in (start, *start.parents):
        # lexists() so that a dangling symlink still counts as present
        if op.lexists(candidate / filename):
            return candidate
    return None
def yaml_dump(rec: Any) -> str:
    """Consistent dump into yaml

    Uses a ``ruamel.yaml`` "safe" instance with ``default_flow_style``
    disabled and returns the produced document as a string.
    """
    dumper = ruamel.yaml.YAML(typ="safe")
    dumper.default_flow_style = False
    with io.StringIO() as buffer:
        dumper.dump(rec, buffer)
        return buffer.getvalue()
def yaml_load(f: str | IO[str], typ: str | None = None) -> Any:
    """
    Load YAML source from a file or string.

    Parameters
    ----------
    f: str or IO[str]
        The YAML source to load
    typ: str, optional
        The value of `typ` to pass to `ruamel.yaml.YAML`.  May be "rt"
        (default), "safe", "unsafe", or "base"

    Returns
    -------
    Any
        The parsed YAML value
    """
    parser = ruamel.yaml.YAML(typ=typ)
    return parser.load(f)
#
# Borrowed from DataLad (MIT license)
#
def with_pathsep(path: str) -> str:
    """Little helper to guarantee that path ends with the path separator"""
    if path.endswith(op.sep):
        return path
    return path + op.sep
def _get_normalized_paths(path: str, prefix: str) -> tuple[str, str]:
if op.isabs(path) != op.isabs(prefix):
raise ValueError(
"Both paths must either be absolute or relative. "
"Got %r and %r" % (path, prefix)
)
path = with_pathsep(path)
prefix = with_pathsep(prefix)
return path, prefix
def path_is_subpath(path: str, prefix: str) -> bool:
    """Return True if path is a subpath of prefix

    It will return False if path == prefix.

    Parameters
    ----------
    path: str
    prefix: str

    Raises
    ------
    ValueError
        If one of the paths is absolute while the other is relative
    """
    norm_path, norm_prefix = _get_normalized_paths(path, prefix)
    if len(norm_path) <= len(norm_prefix):
        return False
    return norm_path.startswith(norm_prefix)
def shortened_repr(value: Any, length: int = 30) -> str:
    """Return a `repr` of ``value`` suitable for compact display

    - If the type of ``value`` does not define its own ``__repr__``, or the
      repr has the default-looking ``<... object at 0x...>`` form, or
      computing the repr fails, ``<ClassName>`` is returned.
    - A repr which does not start with ``<`` and is longer than ``length``
      is abbreviated to its beginning, the number of omitted characters, and
      its last four characters.
    - Any other repr is returned unchanged.

    Parameters
    ----------
    value:
        Object to represent
    length: int, optional
        Maximal length of a repr to be returned without abbreviation
    """
    try:
        # Look ``__repr__`` up on the type: accessed on the instance it is a
        # bound method, which is never identical to ``object.__repr__``
        if type(value).__repr__ is object.__repr__:
            raise ValueError("gimme class")
        value_repr = repr(value)
        if not value_repr.startswith("<"):
            if len(value_repr) > length:
                value_repr = "<<%s++%d chars++%s>>" % (
                    value_repr[: length - 16],
                    len(value_repr) - (length - 16 + 4),
                    value_repr[-4:],
                )
        elif value_repr.endswith(">") and " object at 0x" in value_repr:
            raise ValueError("I hate those useless long reprs")
    except Exception:
        value_repr = "<%s>" % value.__class__.__name__.split(".")[-1]
    return value_repr
def Parallel(**kwargs: Any) -> Any:  # TODO: disable lint complaint
    """Adapter for joblib.Parallel so we could, if desired, centralize control

    joblib is imported only when this function is called.
    """
    # ATM just a straight invocation
    from joblib import Parallel as joblib_parallel

    return joblib_parallel(**kwargs)
def delayed(*args, **kwargs):
    """Adapter for joblib.delayed so we could, if desired, centralize control

    joblib is imported only when this function is called.
    """
    # ATM just a straight invocation
    from joblib import delayed as joblib_delayed

    return joblib_delayed(*args, **kwargs)
class ServiceURL(BaseModel):
    # Model with a single field: the URL of one service
    # Don't use pydantic.AnyHttpUrl, as that adds a trailing slash, and so URLs
    # retrieved for known instances won't match the known values
    url: str
class ServerServices(BaseModel):
    # Services listed in a server's info document: ``api`` is required,
    # the other two may be absent
    api: ServiceURL
    webui: Union[ServiceURL, None] = None
    jupyterhub: Union[ServiceURL, None] = None
class ServerInfo(BaseModel):
    # Model against which `_get_instance()` validates a server's info document
    # schema_version: str
    # schema_url: str
    version: str
    services: ServerServices
    # Both fields below are required and are read from hyphenated keys
    cli_minimal_version: str = Field(..., alias="cli-minimal-version")
    cli_bad_versions: List[str] = Field(..., alias="cli-bad-versions")
def get_instance(dandi_instance_id: str | DandiInstance) -> DandiInstance:
    """Resolve an instance designation into a `DandiInstance`

    Parameters
    ----------
    dandi_instance_id: str or DandiInstance
        One of:

        - a `DandiInstance`, which is looked up by its own API URL;
        - an ``http://`` or ``https://`` URL.  If (after stripping trailing
          slashes) it is listed in ``known_instances_rev``, the
          corresponding known instance is used; otherwise the URL is reduced
          to its scheme and host part and queried as a non-API URL;
        - the name of an entry in ``known_instances``.

    Raises
    ------
    KeyError
        If a name is given that is not present in ``known_instances``
    """
    if isinstance(dandi_instance_id, DandiInstance):
        known = dandi_instance_id
        return _get_instance(known.api.rstrip("/"), True, known, known.name)
    if not dandi_instance_id.lower().startswith(("http://", "https://")):
        known = known_instances[dandi_instance_id]
        return _get_instance(known.api.rstrip("/"), True, known, dandi_instance_id)
    base_url = dandi_instance_id.rstrip("/")
    name = known_instances_rev.get(base_url)
    if name is None:
        # Unrecognized URL: drop path, query, and fragment before querying
        root_url = str(URL(base_url).with_path("").with_query(None).with_fragment(None))
        return _get_instance(root_url, False, None, None)
    known = known_instances[name]
    return _get_instance(base_url, known.api.rstrip("/") == base_url, known, name)
@lru_cache
def _get_instance(
    url: str, is_api: bool, instance: DandiInstance | None, dandi_id: str | None
) -> DandiInstance:
    """Query a server for its info document and build a `DandiInstance`

    Results are memoized by `lru_cache` on the combination of arguments.

    Parameters
    ----------
    url: str
        Base URL to query
    is_api: bool
        If True, ``<url>/info/`` is requested.  Otherwise
        ``<url>/server-info`` is requested, falling back to
        ``<url>/api/info/`` upon a 404 response.
    instance: DandiInstance or None
        Known instance to fall back on if the request or the validation of
        its response fails
    dandi_id: str or None
        Name for the resulting instance.  If None, it is derived from the
        host (and explicit port, if any) of the API URL reported by the
        server.

    Raises
    ------
    RuntimeError
        If the server info could not be retrieved and no fallback
        ``instance`` was given
    ValueError
        If the server reports a malformed version
    CliVersionTooOldError
        If this client is older than the server's minimal supported version
    BadCliVersionError
        If this client's version is among the server's known-bad versions
    """
    try:
        if is_api:
            r = requests.get(joinurl(url, "/info/"))
        else:
            r = requests.get(joinurl(url, "/server-info"))
            if r.status_code == 404:
                r = requests.get(joinurl(url, "/api/info/"))
        r.raise_for_status()
        server_info = ServerInfo.model_validate(r.json())
    except Exception as e:
        lgr.warning("Request to %s failed (%s)", url, str(e))
        if instance is not None:
            lgr.warning("Using hard-coded URLs")
            return instance
        raise RuntimeError(
            f"Could not retrieve server info from {url},"
            " and client does not recognize URL"
        ) from e
    try:
        minversion = Version(server_info.cli_minimal_version)
        bad_versions = [Version(v) for v in server_info.cli_bad_versions]
    except ValueError as e:
        raise ValueError(
            f"{url} returned an incorrectly formatted version;"
            f" please contact that server's administrators: {e}"
        ) from e
    our_version = Version(__version__)
    if our_version < minversion:
        raise CliVersionTooOldError(our_version, minversion, bad_versions)
    if our_version in bad_versions:
        raise BadCliVersionError(our_version, minversion, bad_versions)
    api_url = server_info.services.api.url
    if dandi_id is None:
        # Don't use pydantic.AnyHttpUrl, as that sets the `port` attribute even
        # if it's not present in the string.
        u = URL(api_url)
        if u.host is not None:
            dandi_id = u.host
            if (port := u.explicit_port) is not None:
                # A host containing ":" gets bracketed so that the appended
                # port stays unambiguous
                if ":" in dandi_id:
                    dandi_id = f"[{dandi_id}]"
                dandi_id += f":{port}"
        else:
            dandi_id = api_url
    webui = server_info.services.webui
    return DandiInstance(
        name=dandi_id,
        gui=webui.url if webui is not None else None,
        api=api_url,
    )
def is_url(s: str) -> bool:
    """Very primitive url detection for now: only the prefix is checked,
    case-insensitively

    TODO: redo
    """
    # Slashes are not required after "dandi:" so as to support "DANDI:<id>"
    prefixes = ("http://", "https://", "dandi:", "ftp://")
    return s.lower().startswith(prefixes)
def get_module_version(module: str | types.ModuleType) -> str | None:
    """Return version of the module

    Return module's `__version__` if present, or use importlib
    to get version.

    Parameters
    ----------
    module: str or module
        A module object, or the name of a module.  A name is looked up in
        ``sys.modules`` only; nothing gets imported.

    Returns
    -------
    str or None
        None if the version could not be determined
    """
    modobj: types.ModuleType | None
    if isinstance(module, str):
        mod_name = module
        modobj = sys.modules.get(module)
    else:
        # Ask importlib about the top-level package of a submodule
        mod_name = module.__name__.split(".", 1)[0]
        modobj = module
    version = None if modobj is None else getattr(modobj, "__version__", None)
    if version is not None:
        return version
    # Let's use the standard Python mechanism if underlying module
    # did not provide __version__
    try:
        return importlib_version(mod_name)
    except Exception as exc:
        lgr.debug("Failed to determine version of the %s: %s", mod_name, exc)
        return None
def pluralize(n: int, word: str, plural: str | None = None) -> str:
    """Return ``n`` followed by ``word`` in the form matching the count

    Parameters
    ----------
    n: int
        The count
    word: str
        Singular form, used when ``n == 1``
    plural: str, optional
        Plural form, used for any other count.  Defaults to ``word`` with an
        "s" appended.
    """
    if n == 1:
        noun = word
    elif plural is None:
        noun = word + "s"
    else:
        noun = plural
    return f"{n} {noun}"
def abbrev_prompt(msg: str, *options: str) -> str:
    """
    Prompt the user to input one of several options, which can be entered as
    either a whole word or the first letter of a word.  All input is handled
    case-insensitively.  Returns the complete word corresponding to the input,
    lowercased.  The prompt is repeated until a recognized answer is given.

    For example, ``abbrev_prompt("Delete assets?", "yes", "no", "list")``
    prompts the user with the message ``Delete assets? ([y]es/[n]o/[l]ist): ``
    and accepts as input ``y``, ``yes``, ``n``, ``no``, ``l``, and ``list``.

    Raises
    ------
    ValueError
        If two options are equal, or share their first letter, after
        lowercasing
    """
    answers: dict[str, str] = {}
    labels: list[str] = []
    for raw in options:
        word = raw.lower()
        if word in answers:
            raise ValueError(f"Repeated option: {word}")
        initial, rest = word[0], word[1:]
        if initial in answers:
            raise ValueError(f"Repeated abbreviated option: {initial}")
        answers[word] = answers[initial] = word
        labels.append(f"[{initial}]{rest}")
    prompt = f"{msg} ({'/'.join(labels)}): "
    while (answer := input(prompt).lower()) not in answers:
        pass
    return answers[answer]
def get_mime_type(filename: str, strict: bool = False) -> str:
    """
    Like `mimetypes.guess_type()`, except that if the file is compressed, the
    MIME type for the compression is returned.  Also, the default return value
    is now ``'application/octet-stream'`` instead of `None`.
    """
    mtype, encoding = guess_type(filename, strict)
    if encoding == "gzip":
        # application/gzip is defined by RFC 6713
        # There is also a "+gzip" MIME structured syntax suffix defined by RFC
        # 8460; exactly when can that be used?
        # return mtype + '+gzip'
        return "application/gzip"
    if encoding is not None:
        return "application/x-" + encoding
    return mtype or "application/octet-stream"
def check_dandi_version() -> None:
    """Check via etelemetry whether this version of the client may be used

    Does nothing if the ``DANDI_NO_ET`` environment variable is set to a
    non-empty value.  Otherwise the check is run and ``DANDI_NO_ET`` is then
    set to ``"1"``.  Failures of the check itself (including etelemetry not
    being importable) are only logged as a warning.

    Raises
    ------
    SystemExit
        If etelemetry reports the running version as a bad one
    """
    if os.environ.get("DANDI_NO_ET"):
        return
    bad_version_notice = (
        "DANDI CLI has detected that you are using a version that is known to "
        "contain bugs, is incompatible with our current data archive, or has "
        "other significant performance limitations. "
        "To continue using DANDI CLI, please upgrade your dandi client to a newer "
        "version (e.g., using pip install --upgrade dandi if you installed using pip). "
        "If you have any issues, please contact the DANDI "
        "helpdesk at https://github.com/dandi/helpdesk/issues/new/choose ."
    )
    try:
        import etelemetry

        try:
            etelemetry.check_available_version(
                "dandi/dandi-cli", __version__, lgr=lgr, raise_exception=True
            )
        except etelemetry.client.BadVersionError:
            # note: SystemExit is based of BaseException, so is not Exception
            # and thus passes through the handler below
            raise SystemExit(bad_version_notice)
    except Exception as exc:
        lgr.warning(
            "Failed to check for a more recent version available with etelemetry: %s",
            exc,
        )
    os.environ["DANDI_NO_ET"] = "1"
T = TypeVar("T")


def chunked(iterable: Iterable[T], size: int) -> Iterator[list[T]]:
    """Yield successive lists of up to ``size`` items taken from ``iterable``

    Every yielded list has exactly ``size`` items, except possibly the last
    one, which holds whatever remained.  Nothing is yielded for an empty
    iterable.

    Raises
    ------
    ValueError
        If ``size`` is less than 1 (raised once iteration starts)
    """
    # cf. chunked() from more-itertools
    if size < 1:
        # A non-positive size would otherwise never consume any input
        raise ValueError(f"size must be at least 1, got {size!r}")
    source = iter(iterable)
    while chunk := list(itertools.islice(source, size)):
        yield chunk
def is_page2_url(page1: str, page2: str) -> bool:
    """
    Tests whether the URL ``page2`` is the same as ``page1`` but with the
    ``page`` query parameter set to ``2``.  The order of the query parameters
    is not significant.
    """
    first, second = URL(page1), URL(page2)
    expected_params = MultiDict(first.query)
    expected_params["page"] = "2"
    if first.with_query(None) != second.with_query(None):
        return False
    return sorted(expected_params.items()) == sorted(second.query.items())
def exclude_from_zarr(path: Path) -> bool:
    """
    Returns `True` if the ``path`` is a file or directory that should be
    excluded from consideration when located in a Zarr.  Only the final
    component of ``path`` is looked at.
    """
    excluded_names = {".dandi", ".datalad", ".git", ".gitattributes", ".gitmodules"}
    return path.name in excluded_names
[docs]def under_paths(
paths: Iterable[str | PurePath], filter_paths: Iterable[str | PurePath]
) -> Iterator[PurePosixPath]:
"""
Return all elements of ``paths`` (converted to `PurePosixPath` instances)
that are equal to or under/start with one or more paths in
``filter_paths``. The elements of both iterables must be relative &
normalized.
Based on ``get_filtered_paths_`` from datalad's
:file:`datalad/support/path.py`
"""
path_parts = _prepare_path_parts(paths)
filter_path_parts = _prepare_path_parts(filter_paths)
for path in path_parts:
i = bisect(filter_path_parts, path)
if i > 0 and _starts_with(path, filter_path_parts[i - 1]):
yield PurePosixPath(*path)
elif i == len(filter_path_parts):
break
def _prepare_path_parts(paths: Iterable[str | PurePath]) -> list[tuple[str, ...]]:
path_parts: list[tuple[str, ...]] = []
for p in paths:
pp = PurePosixPath(p)
if pp.is_absolute():
raise ValueError(f"Absolute path: {p!r}")
parts = pp.parts
if ".." in parts or "." in parts:
raise ValueError(f"Non-normalized path: {p!r}")
path_parts.append(parts)
path_parts.sort()
return path_parts
def _starts_with(t: tuple[str, ...], prefix: tuple[str, ...]) -> bool:
return t[: len(prefix)] == prefix
def pre_upload_size_check(path: Path) -> int:
    """Return the size of ``path``, re-checking for a while if it is zero

    If the filesystem reports a size of zero for a file we're about to
    upload, double-check the size in case we're on a flaky NFS system: the
    size is read up to 20 times, with a 0.1 second pause before every attempt
    but the first, and the first non-zero value is returned.  If the size
    stays zero throughout, zero is returned.
    """
    size = 0
    for attempt in range(20):
        sleep(0.1 if attempt else 0)
        size = path.stat().st_size
        if size != 0:
            break
    return size
def post_upload_size_check(path: Path, pre_check_size: int, erroring: bool) -> None:
    """Verify that the size of ``path`` still equals ``pre_check_size``

    More checks for NFS flakiness.

    Parameters
    ----------
    path: Path
        The file that was uploaded
    pre_check_size: int
        The size the file had at the start of the upload
    erroring: bool
        If True, a size mismatch is only logged as an error; if False, it
        raises.  (Presumably True is passed when another error is already
        being handled -- confirm against the callers.)

    Raises
    ------
    RuntimeError
        If the size changed and ``erroring`` is False
    """
    size = path.stat().st_size
    if size == pre_check_size:
        return
    msg = (
        f"Size of {path} was {pre_check_size} at start of upload but is"
        f" now {size} after upload"
    )
    if not erroring:
        raise RuntimeError(msg)
    lgr.error(msg)
def joinurl(base: str, path: str) -> str:
    """
    Append a slash-separated ``path`` to a base HTTP(S) URL ``base``.  The two
    components are separated by a single slash, removing any excess slashes
    that would be present after naïve concatenation.

    If ``path`` is already an absolute HTTP(S) URL, it is returned unchanged.

    Note that this function differs from `urllib.parse.urljoin()` when the path
    portion of ``base`` is nonempty and does not end in a slash.
    """
    if path.lower().startswith(("http://", "https://")):
        return path
    return f"{base.rstrip('/')}/{path.lstrip('/')}"