import base64
import contextlib
import copy
import errno
import glob
import inspect
import itertools
import os
import re
import struct
import subprocess
import sys
import time
import traceback
from collections import UserDict
from collections.abc import Callable
from copy import deepcopy
from functools import lru_cache, wraps
from numbers import Number as numeric_type
from typing import Any
import numpy as np
from more_itertools import always_iterable, collapse, first
from yt._maintenance.deprecation import issue_deprecation_warning
from yt._maintenance.ipython_compat import IS_IPYTHON
from yt.config import ytcfg
from yt.units import YTArray, YTQuantity
from yt.utilities.exceptions import YTFieldNotFound, YTInvalidWidthError
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.on_demand_imports import _requests as requests
# Some functions for handling sequences and other types
def is_sequence(obj):
"""
Grabbed from Python Cookbook / matplotlib.cbook. Returns True if *obj* has a
length (i.e. is sequence-like), False otherwise.
Parameters
----------
obj : iterable
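Examples
--------
A minimal illustration: containers have a length, plain scalars do not.
>>> is_sequence([1, 2, 3])
True
>>> is_sequence(5)
False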
"""
try:
len(obj)
return True
except TypeError:
return False
def iter_fields(field_or_fields):
"""
Create an iterator over field names, specified either as single strings or as
(ftype, fname) tuples.
This can safely be used in places where we accept a single field or a list as input.
Parameters
----------
field_or_fields: str, tuple(str, str), or any iterable of the previous types.
Examples
--------
>>> fields = ("gas", "density")
>>> for field in iter_fields(fields):
... print(field)
density
>>> fields = ("gas", "density")
>>> for field in iter_fields(fields):
... print(field)
('gas', 'density')
>>> fields = [("gas", "density"), ("gas", "temperature"), ("index", "dx")]
>>> for field in iter_fields(fields):
... print(field)
density
temperature
('index', 'dx')
"""
return always_iterable(field_or_fields, base_type=(tuple, str, bytes))
def ensure_numpy_array(obj):
"""
This function ensures that *obj* is a numpy array. Typically used to
convert a scalar, list, or tuple argument passed to functions using Cython.
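Examples
--------
A short sketch of the three conversion branches:
>>> ensure_numpy_array(np.array(5))
array([5])
>>> ensure_numpy_array([1, 2, 3])
array([1, 2, 3])
>>> ensure_numpy_array(1.0)
array([1.])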
"""
if isinstance(obj, np.ndarray):
if obj.shape == ():
return np.array([obj])
# We cast to ndarray to catch ndarray subclasses
return np.array(obj)
elif isinstance(obj, (list, tuple)):
return np.asarray(obj)
else:
return np.asarray([obj])
def read_struct(f, fmt):
"""
This reads a struct, and only that struct, from an open file.
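Examples
--------
A sketch using an in-memory buffer in place of an open binary file:
>>> import io
>>> read_struct(io.BytesIO(struct.pack("2i", 1, 2)), "2i")
(1, 2)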
"""
s = f.read(struct.calcsize(fmt))
return struct.unpack(fmt, s)
def just_one(obj):
# If we have an iterable, sometimes we only want one item
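# e.g. just_one([(1, 2), 3]) -> 1, and just_one(5) -> 5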
return first(collapse(obj))
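# Recursively check that every key of dict1 also appears in dict2 with an
# equal (possibly array-valued) value, e.g.
# compare_dicts({"a": 1}, {"a": 1, "b": 2}) -> True
# compare_dicts({"a": 1}, {"a": 2}) -> False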
def compare_dicts(dict1, dict2):
if not set(dict1) <= set(dict2):
return False
for key in dict1.keys():
if dict1[key] is not None and dict2[key] is not None:
if isinstance(dict1[key], dict):
if compare_dicts(dict1[key], dict2[key]):
continue
else:
return False
try:
comparison = np.array_equal(dict1[key], dict2[key])
except TypeError:
comparison = dict1[key] == dict2[key]
if not comparison:
return False
return True
# Taken from
# http://www.goldb.org/goldblog/2008/02/06/PythonConvertSecsIntoHumanReadableTimeStringHHMMSS.aspx
def humanize_time(secs):
"""
Takes a number of seconds *secs* and returns a nicely formatted HH:MM:SS string
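Examples
--------
>>> humanize_time(3661)
'01:01:01'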
"""
mins, secs = divmod(secs, 60)
hours, mins = divmod(mins, 60)
return "%02d:%02d:%02d" % (hours, mins, secs)
#
# Some function wrappers that come in handy once in a while
#
def get_memory_usage(subtract_share=False):
"""
Return the resident memory size in megabytes
"""
pid = os.getpid()
# we use the resource module to get the memory page size
try:
import resource
except ImportError:
return -1024
else:
pagesize = resource.getpagesize()
status_file = f"/proc/{pid}/statm"
if not os.path.isfile(status_file):
return -1024
with open(status_file) as fh:
line = fh.read()
size, resident, share, text, library, data, dt = (int(i) for i in line.split())
if subtract_share:
resident -= share
return resident * pagesize / (1024 * 1024) # return in megs
def time_execution(func):
r"""
Decorator for seeing how long a given function takes, depending on whether
or not the global 'yt.time_functions' config parameter is set.
"""
@wraps(func)
def wrapper(*arg, **kw):
t1 = time.time()
res = func(*arg, **kw)
t2 = time.time()
mylog.debug("%s took %0.3f s", func.__name__, (t2 - t1))
return res
if ytcfg.get("yt", "time_functions"):
return wrapper
else:
return func
def print_tb(func):
"""
This function is used as a decorator on a function to have the calling stack
printed whenever that function is entered.
This can be used like so:
>>> @print_tb
... def some_deeply_nested_function(*args, **kwargs):
... ...
"""
@wraps(func)
def run_func(*args, **kwargs):
traceback.print_stack()
return func(*args, **kwargs)
return run_func
def rootonly(func):
"""
This is a decorator that, when used, will only call the function on the
root processor.
This can be used like so:
.. code-block:: python
@rootonly
def some_root_only_function(*args, **kwargs): ...
"""
@wraps(func)
def check_parallel_rank(*args, **kwargs):
if ytcfg.get("yt", "internals", "topcomm_parallel_rank") > 0:
return
return func(*args, **kwargs)
return check_parallel_rank
def pdb_run(func):
"""
This decorator inserts a pdb session on top of the call-stack into a
function.
This can be used like so:
>>> @pdb_run
... def some_function_to_debug(*args, **kwargs):
... ...
"""
import pdb
@wraps(func)
def wrapper(*args, **kw):
pdb.runcall(func, *args, **kw)
return wrapper
__header = """
== Welcome to the embedded IPython Shell ==
You are currently inside the function:
%(fname)s
Defined in:
%(filename)s:%(lineno)s
"""
def insert_ipython(num_up=1):
"""
Placed inside a function, this will insert an IPython interpreter at that
current location. This enables detailed inspection of the current
execution environment, as well as (optional) modification of that environment.
*num_up* refers to how many frames of the stack get stripped off, and
defaults to 1 so that this function itself is stripped off.
"""
import IPython
from IPython.terminal.embed import InteractiveShellEmbed
try:
from traitlets.config.loader import Config
except ImportError:
from IPython.config.loader import Config
frame = inspect.stack()[num_up]
loc = frame[0].f_locals.copy()
glo = frame[0].f_globals
dd = {"fname": frame[3], "filename": frame[1], "lineno": frame[2]}
cfg = Config()
cfg.InteractiveShellEmbed.local_ns = loc
cfg.InteractiveShellEmbed.global_ns = glo
IPython.embed(config=cfg, banner2=__header % dd)
ipshell = InteractiveShellEmbed(config=cfg)
del ipshell
#
# Our progress bar types and how to get one
#
class TqdmProgressBar:
# This is a drop-in replacement for the classic pbar
# interface, backed by tqdm
def __init__(self, title, maxval):
from tqdm import tqdm
self._pbar = tqdm(leave=True, total=maxval, desc=title)
self.i = 0
def update(self, i=None):
if i is None:
i = self.i + 1
n = i - self.i
self.i = i
self._pbar.update(n)
def finish(self):
self._pbar.close()
class DummyProgressBar:
# This progress bar gets handed out when we don't
# want ANY output
def __init__(self, *args, **kwargs):
return
def update(self, *args, **kwargs):
return
def finish(self, *args, **kwargs):
return
def get_pbar(title, maxval):
"""
This returns a progressbar of the most appropriate type, given a *title*
and a *maxval*.
"""
maxval = max(maxval, 1)
if (
ytcfg.get("yt", "suppress_stream_logging")
or ytcfg.get("yt", "internals", "within_testing")
or maxval == 1
or not is_root()
):
return DummyProgressBar()
return TqdmProgressBar(title, maxval)
def only_on_root(func, *args, **kwargs):
"""
This function accepts a *func* along with a set of *args* and *kwargs*, and
calls the function only on the root processor. All other processors get "None"
handed back.
"""
if kwargs.pop("global_rootonly", False):
cfg_option = "global_parallel_rank"
else:
cfg_option = "topcomm_parallel_rank"
if not ytcfg.get("yt", "internals", "parallel"):
return func(*args, **kwargs)
if ytcfg.get("yt", "internals", cfg_option) > 0:
return
return func(*args, **kwargs)
def is_root():
"""
This function returns True if it is on the root processor of the
topcomm and False otherwise.
"""
if not ytcfg.get("yt", "internals", "parallel"):
return True
return ytcfg.get("yt", "internals", "topcomm_parallel_rank") == 0
#
# Our signal and traceback handling functions
#
def signal_print_traceback(signo, frame):
# traceback.print_stack writes directly to stderr and returns None,
# so wrapping it in print() would emit a spurious "None"
traceback.print_stack(frame)
def signal_problem(signo, frame):
raise RuntimeError()
def signal_ipython(signo, frame):
insert_ipython(2)
def paste_traceback(exc_type, exc, tb):
"""
This is a traceback handler that knows how to paste to the pastebin.
Should only be used in sys.excepthook.
"""
sys.__excepthook__(exc_type, exc, tb)
import xmlrpc.client
from io import StringIO
p = xmlrpc.client.ServerProxy(
"http://paste.yt-project.org/xmlrpc/", allow_none=True
)
s = StringIO()
traceback.print_exception(exc_type, exc, tb, file=s)
s = s.getvalue()
ret = p.pastes.newPaste("pytb", s, None, "", "", True)
print()
print(f"Traceback pasted to http://paste.yt-project.org/show/{ret}")
print()
def paste_traceback_detailed(exc_type, exc, tb):
"""
This is a traceback handler that knows how to paste to the pastebin.
Should only be used in sys.excepthook.
"""
import cgitb
import xmlrpc.client
from io import StringIO
s = StringIO()
handler = cgitb.Hook(format="text", file=s)
handler(exc_type, exc, tb)
s = s.getvalue()
print(s)
p = xmlrpc.client.ServerProxy(
"http://paste.yt-project.org/xmlrpc/", allow_none=True
)
ret = p.pastes.newPaste("text", s, None, "", "", True)
print()
print(f"Traceback pasted to http://paste.yt-project.org/show/{ret}")
print()
_ss = "fURbBUUBE0cLXgETJnZgJRMXVhVGUQpQAUBuehQMUhJWRFFRAV1ERAtBXw1dAxMLXT4zXBFfABNN\nC0ZEXw1YUURHCxMXVlFERwxWCQw=\n"
def _rdbeta(key):
# base64.decodestring was removed in Python 3.9; decodebytes expects bytes
enc_s = base64.decodebytes(_ss.encode("ascii"))
dec_s = "".join(chr(a ^ ord(b)) for a, b in zip(enc_s, itertools.cycle(key)))
print(dec_s)
#
# Some exceptions
#
class NoCUDAException(Exception):
pass
class YTEmptyClass:
pass
def update_git(path):
try:
import git
except ImportError:
print("Updating and precise version information requires ")
print("gitpython to be installed.")
print("Try: python -m pip install gitpython")
return -1
with open(os.path.join(path, "yt_updater.log"), "a") as f:
repo = git.Repo(path)
if repo.is_dirty(untracked_files=True):
print("Changes have been made to the yt source code so I won't ")
print("update the code. You will have to do this yourself.")
print("Here's a set of sample commands:")
print("")
print(f" $ cd {path}")
print(" $ git stash")
print(" $ git checkout main")
print(" $ git pull")
print(" $ git stash pop")
print(f" $ {sys.executable} setup.py develop")
print("")
return 1
if repo.active_branch.name != "main":
print("yt repository is not tracking the main branch so I won't ")
print("update the code. You will have to do this yourself.")
print("Here's a set of sample commands:")
print("")
print(f" $ cd {path}")
print(" $ git checkout main")
print(" $ git pull")
print(f" $ {sys.executable} setup.py develop")
print("")
return 1
print("Updating the repository")
f.write("Updating the repository\n\n")
old_version = repo.git.rev_parse("HEAD", short=12)
try:
remote = repo.remotes.yt_upstream
except AttributeError:
remote = repo.create_remote(
"yt_upstream", url="https://github.com/yt-project/yt"
)
remote.fetch()
main = repo.heads.main
main.set_tracking_branch(remote.refs.main)
main.checkout()
remote.pull()
new_version = repo.git.rev_parse("HEAD", short=12)
f.write(f"Updated from {old_version} to {new_version}\n\n")
rebuild_modules(path, f)
print("Updated successfully")
def rebuild_modules(path, f):
f.write("Rebuilding modules\n\n")
p = subprocess.Popen(
[sys.executable, "setup.py", "build_clib", "build_ext", "-i"],
cwd=path,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
stdout, stderr = p.communicate()
f.write(stdout.decode("utf-8"))
f.write("\n\n")
if p.returncode:
print(f"BROKEN: See {os.path.join(path, 'yt_updater.log')}")
sys.exit(1)
f.write("Successful!\n")
def get_git_version(path):
try:
import git
except ImportError:
print("Updating and precise version information requires ")
print("gitpython to be installed.")
print("Try: python -m pip install gitpython")
return None
try:
repo = git.Repo(path)
return repo.git.rev_parse("HEAD", short=12)
except git.InvalidGitRepositoryError:
# path is not a git repository
return None
def get_yt_version():
import importlib.resources as importlib_resources
version = get_git_version(os.path.dirname(importlib_resources.files("yt")))
if version is None:
return version
else:
v_str = version[:12].strip()
if hasattr(v_str, "decode"):
v_str = v_str.decode("utf-8")
return v_str
def get_version_stack():
import matplotlib
from yt._version import __version__ as yt_version
version_info = {}
version_info["yt"] = yt_version
version_info["numpy"] = np.version.version
version_info["matplotlib"] = matplotlib.__version__
return version_info
def get_script_contents():
top_frame = inspect.stack()[-1]
finfo = inspect.getframeinfo(top_frame[0])
if finfo[2] != "<module>":
return None
if not os.path.exists(finfo[0]):
return None
try:
with open(finfo[0]) as source_file:
contents = source_file.read()
except Exception:
contents = None
return contents
def download_file(url, filename):
try:
return fancy_download_file(url, filename, requests)
except ImportError:
# fancy_download_file requires requests
return simple_download_file(url, filename)
def fancy_download_file(url, filename, requests=None):
response = requests.get(url, stream=True)
total_length = response.headers.get("content-length")
with open(filename, "wb") as fh:
if total_length is None:
fh.write(response.content)
else:
blocksize = 4 * 1024**2
iterations = int(float(total_length) / float(blocksize))
pbar = get_pbar(
"Downloading {} to {} ".format(*os.path.split(filename)[::-1]),
iterations,
)
iteration = 0
for chunk in response.iter_content(chunk_size=blocksize):
fh.write(chunk)
iteration += 1
pbar.update(iteration)
pbar.finish()
return filename
def simple_download_file(url, filename):
import urllib.error
import urllib.request
try:
fn, h = urllib.request.urlretrieve(url, filename)
except urllib.error.HTTPError as err:
raise RuntimeError(
f"Attempt to download file from {url} failed with error {err.code}: {err.msg}."
) from None
return fn
# This code snippet is modified from Georg Brandl
def bb_apicall(endpoint, data, use_pass=True):
import getpass
import urllib.parse
import urllib.request
uri = f"https://api.bitbucket.org/1.0/{endpoint}/"
# since bitbucket doesn't return the required WWW-Authenticate header when
# making a request without Authorization, we cannot use the standard
# urllib.request auth handlers; we have to add the requisite header from the start
if data is not None:
data = urllib.parse.urlencode(data).encode("utf-8")
req = urllib.request.Request(uri, data)
if use_pass:
username = input("Bitbucket Username? ")
password = getpass.getpass()
upw = f"{username}:{password}"
# b64encode operates on bytes in Python 3
token = base64.b64encode(upw.encode("utf-8")).decode("ascii")
req.add_header("Authorization", f"Basic {token}")
return urllib.request.urlopen(req).read()
def fix_length(length, ds):
registry = ds.unit_registry
if isinstance(length, YTArray):
if registry is not None:
length.units.registry = registry
return length.in_units("code_length")
if isinstance(length, numeric_type):
return YTArray(length, "code_length", registry=registry)
length_valid_tuple = isinstance(length, (list, tuple)) and len(length) == 2
# guard the element checks so that invalid input reaches the RuntimeError
# below instead of raising an IndexError or TypeError here
unit_is_string = length_valid_tuple and isinstance(length[1], str)
length_is_number = length_valid_tuple and (
isinstance(length[0], numeric_type) and not isinstance(length[0], YTArray)
)
if length_valid_tuple and unit_is_string and length_is_number:
return YTArray(*length, registry=registry)
else:
raise RuntimeError(f"Length {str(length)} is invalid")
@contextlib.contextmanager
def parallel_profile(prefix):
r"""A context manager for profiling parallel code execution using cProfile
This is a simple context manager that automatically profiles the execution
of a snippet of code.
Parameters
----------
prefix : string
A string name to prefix outputs with.
Examples
--------
>>> from yt import PhasePlot
>>> from yt.testing import fake_random_ds
>>> fields = ("density", "temperature", "cell_mass")
>>> units = ("g/cm**3", "K", "g")
>>> ds = fake_random_ds(16, fields=fields, units=units)
>>> with parallel_profile("my_profile"):
... plot = PhasePlot(ds.all_data(), *fields)
"""
import cProfile
fn = "%s_%04i_%04i.cprof" % (
prefix,
ytcfg.get("yt", "internals", "topcomm_parallel_size"),
ytcfg.get("yt", "internals", "topcomm_parallel_rank"),
)
p = cProfile.Profile()
p.enable()
yield fn
p.disable()
p.dump_stats(fn)
def get_num_threads():
from .config import ytcfg
nt = ytcfg.get("yt", "num_threads")
if nt < 0:
# environment variables are strings; cast for a consistent integer return type
return int(os.environ.get("OMP_NUM_THREADS", 0))
return nt
def fix_axis(axis, ds):
return ds.coordinates.axis_id.get(axis, axis)
def get_output_filename(name, keyword, suffix):
r"""Return an appropriate filename for output.
With a name provided by the user, this will decide how to appropriately name the
output file by the following rules:
1. if name is None, the filename will be the keyword plus the suffix.
2. if name ends with "/" (resp "\" on Windows), assume name is a directory and the
file will be named name/(keyword+suffix). If the directory does not exist, first
try to create it and raise an exception if an error occurs.
3. if name does not end in the suffix, add the suffix.
Parameters
----------
name : str
A filename given by the user.
keyword : str
A default filename prefix if name is None.
suffix : str
Suffix that must appear at end of the filename.
This will be added if not present.
Examples
--------
>>> get_output_filename(None, "Projection_x", ".png")
'Projection_x.png'
>>> get_output_filename("my_file", "Projection_x", ".png")
'my_file.png'
>>> get_output_filename("my_dir/", "Projection_x", ".png")
'my_dir/Projection_x.png'
"""
if name is None:
name = keyword
name = os.path.expanduser(name)
if name.endswith(os.sep) and not os.path.isdir(name):
ensure_dir(name)
if os.path.isdir(name):
name = os.path.join(name, keyword)
if not name.endswith(suffix):
name += suffix
return name
def ensure_dir_exists(path):
r"""Create all directories in path recursively in a parallel safe manner"""
my_dir = os.path.dirname(path)
# If path is a file in the current directory, like "test.txt", then my_dir
# would be an empty string, resulting in FileNotFoundError when passed to
# ensure_dir. Let's avoid that.
if my_dir:
ensure_dir(my_dir)
def ensure_dir(path):
r"""Parallel safe directory maker."""
if os.path.exists(path):
return path
try:
os.makedirs(path)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
return path
def validate_width_tuple(width):
if not is_sequence(width) or len(width) != 2:
raise YTInvalidWidthError(f"width ({width}) is not a two element tuple")
is_numeric = isinstance(width[0], numeric_type)
length_has_units = isinstance(width[0], YTArray)
unit_is_string = isinstance(width[1], str)
if not is_numeric or (length_has_units and unit_is_string):
msg = f"width ({str(width)}) is invalid. "
msg += "Valid widths look like this: (12, 'au')"
raise YTInvalidWidthError(msg)
_first_cap_re = re.compile("(.)([A-Z][a-z]+)")
_all_cap_re = re.compile("([a-z0-9])([A-Z])")
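# A sketch of the conversion performed by the two regexes above:
#   camelcase_to_underscore("CamelCaseName") -> "camel_case_name"
#   camelcase_to_underscore("HTTPServer") -> "http_server"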
@lru_cache(maxsize=128, typed=False)
def camelcase_to_underscore(name):
s1 = _first_cap_re.sub(r"\1_\2", name)
return _all_cap_re.sub(r"\1_\2", s1).lower()
def set_intersection(some_list):
if len(some_list) == 0:
return set()
# This accepts a list of iterables, which we get the intersection of.
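# e.g. set_intersection([[1, 2, 3], (2, 3, 4)]) -> {2, 3}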
s = set(some_list[0])
for l in some_list[1:]:
s.intersection_update(l)
return s
@contextlib.contextmanager
def memory_checker(interval=15, dest=None):
r"""This is a context manager that monitors memory usage.
Parameters
----------
interval : int
The number of seconds between printing the current memory usage in
gigabytes of the current Python interpreter.
dest : file-like object, optional
The stream the report is written to; defaults to sys.stdout.
Examples
--------
>>> with memory_checker(10):
... arr = np.zeros(1024 * 1024 * 1024, dtype="float64")
... time.sleep(15)
... del arr
MEMORY: -1.000e+00 gb
"""
import threading
if dest is None:
dest = sys.stdout
class MemoryChecker(threading.Thread):
def __init__(self, event, interval):
self.event = event
self.interval = interval
threading.Thread.__init__(self)
def run(self):
while not self.event.wait(self.interval):
print(f"MEMORY: {get_memory_usage() / 1024.0:0.3e} gb", file=dest)
e = threading.Event()
mem_check = MemoryChecker(e, interval)
mem_check.start()
try:
yield
finally:
e.set()
def enable_plugins(plugin_filename=None):
"""Forces a plugin file to be parsed.
A plugin file is a means of creating custom fields, quantities,
data objects, colormaps, and other code classes and objects to be used
in yt scripts without modifying the yt source directly.
If ``plugin_filename`` is omitted, this function will look for a plugin file at
``$HOME/.config/yt/my_plugins.py``, which is the preferred behaviour for a
system-level configuration.
Warning: a script using this function will only be reproducible if your plugin
file is shared with it.
"""
import yt
from yt.config import config_dir, ytcfg
from yt.fields.my_plugin_fields import my_plugins_fields
if plugin_filename is not None:
_fn = plugin_filename
if not os.path.isfile(_fn):
raise FileNotFoundError(_fn)
else:
# Determine the global plugin location. By decreasing priority order:
# - the configured filename as given (absolute, or relative to the
#   current working directory)
# - relative to the yt config directory
my_plugin_name = ytcfg.get("yt", "plugin_filename")
for base_prefix in ("", config_dir()):
if os.path.isfile(os.path.join(base_prefix, my_plugin_name)):
_fn = os.path.join(base_prefix, my_plugin_name)
break
else:
raise FileNotFoundError("Could not find a global system plugin file.")
mylog.info("Loading plugins from %s", _fn)
ytdict = yt.__dict__
execdict = ytdict.copy()
execdict["add_field"] = my_plugins_fields.add_field
with open(_fn) as f:
code = compile(f.read(), _fn, "exec")
exec(code, execdict, execdict)
ytnamespace = list(ytdict.keys())
for k in execdict.keys():
if k not in ytnamespace:
if callable(execdict[k]):
setattr(yt, k, execdict[k])
def subchunk_count(n_total, chunk_size):
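# Yield successive chunk sizes summing to n_total,
# e.g. list(subchunk_count(10, 4)) -> [4, 4, 2]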
handled = 0
while handled < n_total:
tr = min(n_total - handled, chunk_size)
yield tr
handled += tr
def fix_unitary(u):
if u == "1":
return "unitary"
else:
return u
def get_hash(infile, algorithm="md5", BLOCKSIZE=65536):
"""Generate file hash without reading in the entire file at once.
Original code licensed under MIT. Source:
https://www.pythoncentral.io/hashing-files-with-python/
Parameters
----------
infile : str
File of interest (including the path).
algorithm : str (optional)
Hash algorithm of choice. Defaults to 'md5'.
BLOCKSIZE : int (optional)
How much data in bytes to read in at once.
Returns
-------
hash : str
The hash of the file.
Examples
--------
>>> from tempfile import NamedTemporaryFile
>>> with NamedTemporaryFile() as file:
... get_hash(file.name)
'd41d8cd98f00b204e9800998ecf8427e'
"""
import hashlib
try:
hasher = getattr(hashlib, algorithm)()
except AttributeError as e:
raise NotImplementedError(
f"'{algorithm}' not available! Available algorithms: {hashlib.algorithms}"
) from e
filesize = os.path.getsize(infile)
iterations = int(float(filesize) / float(BLOCKSIZE))
pbar = get_pbar(f"Generating {algorithm} hash", iterations)
iteration = 0  # avoid shadowing the built-in iter()
with open(infile, "rb") as f:
buf = f.read(BLOCKSIZE)
while len(buf) > 0:
hasher.update(buf)
buf = f.read(BLOCKSIZE)
iteration += 1
pbar.update(iteration)
pbar.finish()
return hasher.hexdigest()
def get_brewer_cmap(cmap):
"""Returns a colorbrewer colormap from palettable"""
try:
import palettable
except ImportError as exc:
raise RuntimeError(
"Please install palettable to use colorbrewer colormaps"
) from exc
bmap = palettable.colorbrewer.get_map(*cmap)
return bmap.get_mpl_colormap(N=cmap[2])
def matplotlib_style_context(style="yt.default", after_reset=False):
"""Returns a context manager for controlling matplotlib style.
Arguments are passed to matplotlib.style.context() if specified. Defaults
to setting yt's "yt.default" style, after resetting to the default config parameters.
"""
# FUTURE: this function should be deprecated in favour of matplotlib.style.context
# after support for matplotlib 3.6 and older versions is dropped.
import importlib.resources as importlib_resources
import matplotlib as mpl
import matplotlib.style
if style == "yt.default" and mpl.__version_info__ < (3, 7):
style = importlib_resources.files("yt") / "default.mplstyle"
return matplotlib.style.context(style, after_reset=after_reset)
interactivity = False
"""Sets the condition that interactive backends can be used."""
def toggle_interactivity():
global interactivity
interactivity = not interactivity
if interactivity:
if IS_IPYTHON:
import IPython
shell = IPython.get_ipython()
shell.magic("matplotlib")
else:
import matplotlib
matplotlib.interactive(True)
def get_interactivity():
return interactivity
def setdefaultattr(obj, name, value):
"""Set attribute with *name* on *obj* with *value* if it doesn't exist yet
Analogous to dict.setdefault
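Examples
--------
>>> class Container:  # a throwaway class for illustration
...     pass
>>> c = Container()
>>> setdefaultattr(c, "x", 1)
1
>>> setdefaultattr(c, "x", 2)  # already set, so the old value is kept
1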
"""
if not hasattr(obj, name):
setattr(obj, name, value)
return getattr(obj, name)
def parse_h5_attr(f, attr):
"""A Python3-safe function for getting hdf5 attributes.
If an attribute is supposed to be a string, this will return it as such.
"""
val = f.attrs.get(attr, None)
if isinstance(val, bytes):
return val.decode("utf8")
else:
return val
def obj_length(v):
if is_sequence(v):
return len(v)
else:
# If something isn't iterable, we return 0
# to signify zero length (aka a scalar).
return 0
def array_like_field(data, x, field):
field = data._determine_fields(field)[0]
finfo = data.ds._get_field_info(field)
if finfo.sampling_type == "particle":
units = finfo.output_units
else:
units = finfo.units
if isinstance(x, YTArray):
arr = copy.deepcopy(x)
arr.convert_to_units(units)
return arr
if isinstance(x, np.ndarray):
return data.ds.arr(x, units)
else:
return data.ds.quan(x, units)
def _full_type_name(obj: object = None, /, *, cls: type | None = None) -> str:
if cls is not None and obj is not None:
raise TypeError("_full_type_name takes an object or a class, but not both")
if cls is None:
cls = obj.__class__
prefix = f"{cls.__module__}." if cls.__module__ != "builtins" else ""
return f"{prefix}{cls.__name__}"
def validate_3d_array(obj):
if not is_sequence(obj) or len(obj) != 3:
# obj_length returns 0 for non-sequences, so formatting this message
# cannot itself raise a TypeError when obj has no len()
raise TypeError(
f"Expected an array of size (3,), "
f"received {_full_type_name(obj)!r} of length {obj_length(obj)}"
)
def validate_float(obj):
"""Validates if the passed argument is a float value.
Raises an exception if `obj` is not a single float value
or a YTQuantity of size 1.
Parameters
----------
obj : Any
Any argument which needs to be checked for a single float value.
Raises
------
TypeError
Raised if `obj` is not a single float value or YTQuantity
Examples
--------
>>> validate_float(1)
>>> validate_float(1.50)
>>> validate_float(YTQuantity(1, "cm"))
>>> validate_float((1, "cm"))
>>> validate_float([1, 1, 1])
Traceback (most recent call last):
...
TypeError: Expected a numeric value (or size-1 array), received 'list' of length 3
>>> validate_float([YTQuantity(1, "cm"), YTQuantity(2, "cm")])
Traceback (most recent call last):
...
TypeError: Expected a numeric value (or size-1 array), received 'list' of length 2
"""
if isinstance(obj, tuple):
if (
len(obj) != 2
or not isinstance(obj[0], numeric_type)
or not isinstance(obj[1], str)
):
raise TypeError(
"Expected a numeric value (or tuple of format "
f"(float, String)), received an inconsistent tuple {str(obj)!r}."
)
else:
return
if is_sequence(obj) and (len(obj) != 1 or not isinstance(obj[0], numeric_type)):
raise TypeError(
"Expected a numeric value (or size-1 array), "
f"received {_full_type_name(obj)!r} of length {len(obj)}"
)
def validate_sequence(obj):
if obj is not None and not is_sequence(obj):
raise TypeError(
"Expected an iterable object, " f"received {_full_type_name(obj)!r}"
)
def validate_field_key(key):
if (
isinstance(key, tuple)
and len(key) == 2
and all(isinstance(_, str) for _ in key)
):
return
raise TypeError(
"Expected a 2-tuple of strings formatted as\n"
"(field or particle type, field name)\n"
f"Received invalid field key: {key}, with type {type(key)}"
)
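# e.g. is_valid_field_key(("gas", "density")) -> True, while
# is_valid_field_key("density") -> False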
def is_valid_field_key(key):
try:
validate_field_key(key)
except TypeError:
return False
else:
return True
def validate_object(obj, data_type):
if obj is not None and not isinstance(obj, data_type):
raise TypeError(
f"Expected an object of {_full_type_name(cls=data_type)!r} type, "
f"received {_full_type_name(obj)!r}"
)
def validate_axis(ds, axis):
if ds is not None:
valid_axis = sorted(
ds.coordinates.axis_name.keys(), key=lambda k: str(k).swapcase()
)
else:
valid_axis = [0, 1, 2, "x", "y", "z", "X", "Y", "Z"]
if axis not in valid_axis:
raise TypeError(f"Expected axis to be any of {valid_axis}, received {axis!r}")
def validate_center(center):
if isinstance(center, str):
c = center.lower()
if (
c not in ["c", "center", "m", "max", "min"]
and not c.startswith("max_")
and not c.startswith("min_")
):
raise TypeError(
"Expected 'center' to be in ['c', 'center', "
"'m', 'max', 'min'] or the prefix to be "
f"'max_'/'min_', received {center!r}."
)
elif not isinstance(center, (numeric_type, YTQuantity)) and not is_sequence(center):
raise TypeError(
"Expected 'center' to be a numeric object of type "
"list/tuple/np.ndarray/YTArray/YTQuantity, "
f"received {_full_type_name(center)}."
)
def parse_center_array(center, ds, axis: int | None = None):
known_shortnames = {"m": "max", "c": "center", "l": "left", "r": "right"}
valid_single_str_values = ("center", "left", "right")
valid_field_loc_str_values = ("min", "max")
valid_str_values = valid_single_str_values + valid_field_loc_str_values
default_error_message = (
"Expected any of the following\n"
"- 'c', 'center', 'l', 'left', 'r', 'right', 'm', 'max', or 'min'\n"
"- a 2 element tuple with 'min' or 'max' as the first element, followed by a field identifier\n"
"- a 3 element array-like: for a unyt_array, expects length dimensions, otherwise code_lenght is assumed"
)
# store an unmodified copy of user input to be inserted in error messages
center_input = deepcopy(center)
if isinstance(center, str):
centerl = center.lower()
if centerl in known_shortnames:
centerl = known_shortnames[centerl]
match = re.match(r"^(?P<extremum>(min|max))(_(?P<field>[\w_]+))?", centerl)
if match is not None:
if match["field"] is not None:
for ftype, fname in ds.derived_field_list: # noqa: B007
if fname == match["field"]:
break
else:
raise YTFieldNotFound(match["field"], ds)
else:
ftype, fname = ("gas", "density")
center = (match["extremum"], (ftype, fname))
elif centerl in ("center", "left", "right"):
# domain_left_edge and domain_right_edge might not be
# initialized until we create the index, so create it
ds.index
center = ds.domain_center.copy()
if centerl in ("left", "right") and axis is None:
raise ValueError(f"center={center!r} is not valid with axis=None")
if centerl == "left":
center = ds.domain_center.copy()
center[axis] = ds.domain_left_edge[axis]
elif centerl == "right":
# note that the right edge of a grid is excluded by the slice selector,
# which is why we offset the region center by the smallest possible distance
center = ds.domain_center.copy()
center[axis] = (
ds.domain_right_edge[axis] - center.uq * np.finfo(center.dtype).eps
)
elif centerl not in valid_str_values:
raise ValueError(
f"Received unknown center single string value {center!r}. "
+ default_error_message
)
if is_sequence(center):
if (
len(center) == 2
and isinstance(center[0], str)
and (is_valid_field_key(center[1]) or isinstance(center[1], str))
):
center0l = center[0].lower()
if center0l not in valid_str_values:
raise ValueError(
f"Received unknown string value {center[0]!r}. "
f"Expected one of {valid_field_loc_str_values} (case insensitive)"
)
field_key = center[1]
if center0l == "min":
v, center = ds.find_min(field_key)
else:
assert center0l == "max"
v, center = ds.find_max(field_key)
center = ds.arr(center, "code_length")
elif len(center) == 2 and is_sequence(center[0]) and isinstance(center[1], str):
center = ds.arr(center[0], center[1])
elif len(center) == 3 and all(isinstance(_, YTQuantity) for _ in center):
center = ds.arr([c.copy() for c in center], dtype="float64")
elif len(center) == 3:
center = ds.arr(center, "code_length")
if isinstance(center, np.ndarray) and center.ndim > 1:
mylog.debug("Removing singleton dimensions from 'center'.")
center = np.squeeze(center)
if not isinstance(center, YTArray):
raise TypeError(
f"Received {center_input!r}, but failed to transform to a unyt_array (obtained {center!r}).\n"
+ default_error_message
+ "\n"
"If you supplied an expected type, consider filing a bug report"
)
if center.shape != (3,):
raise TypeError(
f"Received {center_input!r} and obtained {center!r} after sanitizing.\n"
+ default_error_message
+ "\n"
"If you supplied an expected type, consider filing a bug report"
)
# make sure the return value shares all
# unit symbols with ds.unit_registry
# we rely on unyt to invalidate unit dimensionality here
center = ds.arr(center).in_units("code_length")
if not ds._is_within_domain(center):
mylog.warning(
"Requested center at %s is outside of data domain with "
"left edge = %s, "
"right edge = %s, "
"periodicity = %s",
center,
ds.domain_left_edge,
ds.domain_right_edge,
ds.periodicity,
)
return center.astype("float64")
def sglob(pattern):
"""
Return the results of a glob through the sorted() function.
"""
return sorted(glob.glob(pattern))
def dictWithFactory(factory: Callable[[Any], Any]) -> type:
"""
Create a dictionary class with a default factory function.
Unlike `collections.defaultdict`, the factory takes
the missing key as an input parameter.
Parameters
----------
factory : callable(key) -> value
The factory to call when hitting a missing key
Returns
-------
DictWithFactory class
A class to create new dictionaries handling missing keys.
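Examples
--------
>>> DoublingDict = dictWithFactory(lambda key: key * 2)
>>> d = DoublingDict()
>>> d["ab"]
'abab'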
"""
issue_deprecation_warning(
"yt.funcs.dictWithFactory will be removed in a future version of yt, please do not rely on it. "
"If you need it, copy paste this function from yt's source code",
stacklevel=3,
since="4.1",
)
class DictWithFactory(UserDict):
def __init__(self, *args, **kwargs):
self.factory = factory
super().__init__(*args, **kwargs)
def __missing__(self, key):
val = self.factory(key)
self[key] = val
return val
return DictWithFactory
def levenshtein_distance(seq1, seq2, max_dist=None):
"""
Compute the Levenshtein distance between seq1 and seq2.
From https://stackabuse.com/levenshtein-distance-and-text-similarity-in-python/
Parameters
----------
seq1 : str
seq2 : str
The strings to compute the distance between
max_dist : integer
If not None, maximum distance returned (see notes).
Returns
-------
The Levenshtein distance as an integer.
Notes
-----
This computes the Levenshtein distance, i.e. the number of edits to change
seq1 into seq2. If a maximum distance is passed, the algorithm will stop as soon
as the number of edits goes above the value. This allows for an earlier break
and speeds calculations up.
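Examples
--------
The classic example requires three edits; with max_dist=1 the early
break triggers and max_dist + 1 is returned instead:
>>> int(levenshtein_distance("kitten", "sitting"))
3
>>> int(levenshtein_distance("kitten", "sitting", max_dist=1))
2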
"""
size_x = len(seq1) + 1
size_y = len(seq2) + 1
if max_dist is None:
max_dist = max(size_x, size_y)
if abs(size_x - size_y) > max_dist:
return max_dist + 1
matrix = np.zeros((size_x, size_y), dtype=int)
for x in range(size_x):
matrix[x, 0] = x
for y in range(size_y):
matrix[0, y] = y
for x in range(1, size_x):
for y in range(1, size_y):
if seq1[x - 1] == seq2[y - 1]:
matrix[x, y] = min(
matrix[x - 1, y] + 1, matrix[x - 1, y - 1], matrix[x, y - 1] + 1
)
else:
matrix[x, y] = min(
matrix[x - 1, y] + 1, matrix[x - 1, y - 1] + 1, matrix[x, y - 1] + 1
)
# Early break: the minimum distance is already larger than the
# maximum allowed value, so we can return safely.
if matrix[x].min() > max_dist:
return max_dist + 1
return matrix[size_x - 1, size_y - 1]
def validate_moment(moment, weight_field):
if moment == 2 and weight_field is None:
raise ValueError(
"Cannot compute the second moment of a projection if weight_field=None!"
)
if moment not in [1, 2]:
raise ValueError(
"Weighted projections can only be made of averages "
"(moment = 1) or standard deviations (moment = 2)!"
)