Source code for yt.funcs

import base64
import contextlib
import copy
import errno
import glob
import inspect
import itertools
import os
import re
import struct
import subprocess
import sys
import time
import traceback
from collections import UserDict
from collections.abc import Callable
from copy import deepcopy
from functools import lru_cache, wraps
from numbers import Number as numeric_type
from typing import Any

import numpy as np
from more_itertools import always_iterable, collapse, first

from yt._maintenance.deprecation import issue_deprecation_warning
from yt._maintenance.ipython_compat import IS_IPYTHON
from yt.config import ytcfg
from yt.units import YTArray, YTQuantity
from yt.utilities.exceptions import YTFieldNotFound, YTInvalidWidthError
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.on_demand_imports import _requests as requests

# Some functions for handling sequences and other types


def is_sequence(obj):
    """
    Grabbed from Python Cookbook / matplotlib.cbook.
    Returns True if *obj* is iterable (i.e. has a length), False otherwise.

    Parameters
    ----------
    obj : iterable
    """
    try:
        len(obj)
        return True
    except TypeError:
        return False

def iter_fields(field_or_fields):
    """
    Create an iterator for field names, specified as single strings or
    tuples (ftype, fname) alike.
    This can safely be used in places where we accept a single field or a
    list as input.

    Parameters
    ----------
    field_or_fields: str, tuple(str, str), or any iterable of the previous types.

    Examples
    --------
    >>> fields = "density"
    >>> for field in iter_fields(fields):
    ...     print(field)
    density

    >>> fields = ("gas", "density")
    >>> for field in iter_fields(fields):
    ...     print(field)
    ('gas', 'density')

    >>> fields = [("gas", "density"), ("gas", "temperature"), ("index", "dx")]
    >>> for field in iter_fields(fields):
    ...     print(field)
    ('gas', 'density')
    ('gas', 'temperature')
    ('index', 'dx')
    """
    return always_iterable(field_or_fields, base_type=(tuple, str, bytes))

def ensure_numpy_array(obj):
    """
    This function ensures that *obj* is a numpy array. Typically used to
    convert scalar, list or tuple argument passed to functions using Cython.
    """
    if isinstance(obj, np.ndarray):
        if obj.shape == ():
            return np.array([obj])
        # We cast to ndarray to catch ndarray subclasses
        return np.array(obj)
    elif isinstance(obj, (list, tuple)):
        return np.asarray(obj)
    else:
        return np.asarray([obj])

def read_struct(f, fmt):
    """
    This reads a struct, and only that struct, from an open file.
    """
    s = f.read(struct.calcsize(fmt))
    return struct.unpack(fmt, s)

def just_one(obj):
    # If we have an iterable, sometimes we only want one item
    return first(collapse(obj))

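# A minimal usage sketch (editor's addition, not part of the original module):
# just_one flattens arbitrarily nested input with more_itertools.collapse and
# returns the first element, treating strings and bytes as atoms.
#
#   >>> just_one([[1, 2], 3])
#   1
#   >>> just_one("density")
#   'density'
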
def compare_dicts(dict1, dict2):
    if not set(dict1) <= set(dict2):
        return False
    for key in dict1.keys():
        if dict1[key] is not None and dict2[key] is not None:
            if isinstance(dict1[key], dict):
                if compare_dicts(dict1[key], dict2[key]):
                    continue
                else:
                    return False
            try:
                comparison = np.array_equal(dict1[key], dict2[key])
            except TypeError:
                comparison = dict1[key] == dict2[key]
            if not comparison:
                return False
    return True

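# Hedged illustration (editor's addition): compare_dicts only requires dict1's
# keys to be a subset of dict2's, so the comparison is not symmetric, and
# values may be nested dicts or numpy arrays.
#
#   >>> compare_dicts({"a": np.arange(3)}, {"a": np.arange(3), "b": 1})
#   True
#   >>> compare_dicts({"a": np.arange(3), "b": 1}, {"a": np.arange(3)})
#   False
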
# Taken from
# http://www.goldb.org/goldblog/2008/02/06/PythonConvertSecsIntoHumanReadableTimeStringHHMMSS.aspx
def humanize_time(secs):
    """
    Takes *secs* and returns a nicely formatted string
    """
    mins, secs = divmod(secs, 60)
    hours, mins = divmod(mins, 60)
    return "%02d:%02d:%02d" % (hours, mins, secs)

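# Example (editor's sketch): 3661 seconds is one hour, one minute and one
# second, formatted as HH:MM:SS.
#
#   >>> humanize_time(3661)
#   '01:01:01'
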
#
# Some function wrappers that come in handy once in a while
#
def get_memory_usage(subtract_share=False):
    """
    Return the resident memory size in megabytes.
    """
    pid = os.getpid()
    # we use the resource module to get the memory page size
    try:
        import resource
    except ImportError:
        return -1024
    else:
        pagesize = resource.getpagesize()
    status_file = f"/proc/{pid}/statm"
    if not os.path.isfile(status_file):
        return -1024
    with open(status_file) as fh:
        line = fh.read()
    size, resident, share, text, library, data, dt = (int(i) for i in line.split())
    if subtract_share:
        resident -= share
    return resident * pagesize / (1024 * 1024)  # return in megs

def time_execution(func):
    r"""
    Decorator for seeing how long a given function takes, depending on whether
    or not the global 'yt.time_functions' config parameter is set.
    """

    @wraps(func)
    def wrapper(*arg, **kw):
        t1 = time.time()
        res = func(*arg, **kw)
        t2 = time.time()
        mylog.debug("%s took %0.3f s", func.__name__, (t2 - t1))
        return res

    if ytcfg.get("yt", "time_functions"):
        return wrapper
    else:
        return func

def rootonly(func):
    """
    This is a decorator that, when used, will only call the function on the
    root processor.

    This can be used like so:

    .. code-block:: python

        @rootonly
        def some_root_only_function(*args, **kwargs): ...
    """

    @wraps(func)
    def check_parallel_rank(*args, **kwargs):
        if ytcfg.get("yt", "internals", "topcomm_parallel_rank") > 0:
            return
        return func(*args, **kwargs)

    return check_parallel_rank

def pdb_run(func):
    """
    This decorator inserts a pdb session on top of the call-stack into a
    function.

    This can be used like so:

    >>> @pdb_run
    ... def some_function_to_debug(*args, **kwargs): ...
    """
    import pdb

    @wraps(func)
    def wrapper(*args, **kw):
        pdb.runcall(func, *args, **kw)

    return wrapper

__header = """ == Welcome to the embedded IPython Shell == You are currently inside the function: %(fname)s Defined in: %(filename)s:%(lineno)s """
def insert_ipython(num_up=1):
    """
    Placed inside a function, this will insert an IPython interpreter at that
    current location.  This will enable detailed inspection of the current
    execution environment, as well as (optional) modification of that
    environment.  *num_up* refers to how many frames of the stack get stripped
    off, and defaults to 1 so that this function itself is stripped off.
    """
    import IPython
    from IPython.terminal.embed import InteractiveShellEmbed

    try:
        from traitlets.config.loader import Config
    except ImportError:
        from IPython.config.loader import Config

    frame = inspect.stack()[num_up]
    loc = frame[0].f_locals.copy()
    glo = frame[0].f_globals
    dd = {"fname": frame[3], "filename": frame[1], "lineno": frame[2]}
    cfg = Config()
    cfg.InteractiveShellEmbed.local_ns = loc
    cfg.InteractiveShellEmbed.global_ns = glo
    IPython.embed(config=cfg, banner2=__header % dd)
    ipshell = InteractiveShellEmbed(config=cfg)

    del ipshell

#
# Our progress bar types and how to get one
#
class TqdmProgressBar:
    # This is a drop-in replacement for the old pbar interface, backed by tqdm
    def __init__(self, title, maxval):
        from tqdm import tqdm

        self._pbar = tqdm(leave=True, total=maxval, desc=title)
        self.i = 0

    def update(self, i=None):
        if i is None:
            i = self.i + 1
        n = i - self.i
        self.i = i
        self._pbar.update(n)

    def finish(self):
        self._pbar.close()

class DummyProgressBar:
    # This progress bar gets handed out if we don't want ANY output
    def __init__(self, *args, **kwargs):
        return

    def update(self, *args, **kwargs):
        return

    def finish(self, *args, **kwargs):
        return

def get_pbar(title, maxval):
    """
    This returns a progressbar of the most appropriate type, given a *title*
    and a *maxval*.
    """
    maxval = max(maxval, 1)
    if (
        ytcfg.get("yt", "suppress_stream_logging")
        or ytcfg.get("yt", "internals", "within_testing")
        or maxval == 1
        or not is_root()
    ):
        return DummyProgressBar()
    return TqdmProgressBar(title, maxval)

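# A minimal usage sketch (editor's addition; assumes a serial session on the
# root processor with stream logging enabled, so a real bar is returned
# rather than a DummyProgressBar):
#
#   pbar = get_pbar("Processing chunks", 100)
#   for i in range(100):
#       ...  # do one unit of work here
#       pbar.update(i + 1)
#   pbar.finish()
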
def only_on_root(func, *args, **kwargs):
    """
    This function accepts a *func*, a set of *args* and *kwargs* and then only
    on the root processor calls the function.  All other processors get "None"
    handed back.
    """
    if kwargs.pop("global_rootonly", False):
        cfg_option = "global_parallel_rank"
    else:
        cfg_option = "topcomm_parallel_rank"
    if not ytcfg.get("yt", "internals", "parallel"):
        return func(*args, **kwargs)
    if ytcfg.get("yt", "internals", cfg_option) > 0:
        return
    return func(*args, **kwargs)

def is_root():
    """
    This function returns True if it is on the root processor of the
    topcomm and False otherwise.
    """
    if not ytcfg.get("yt", "internals", "parallel"):
        return True
    return ytcfg.get("yt", "internals", "topcomm_parallel_rank") == 0

#
# Our signal and traceback handling functions
#
def signal_print_traceback(signo, frame):
    # traceback.print_stack writes directly to stderr and returns None,
    # so wrapping it in print() would only emit a spurious "None" line
    traceback.print_stack(frame)

def signal_problem(signo, frame):
    raise RuntimeError()


def signal_ipython(signo, frame):
    insert_ipython(2)

def paste_traceback(exc_type, exc, tb):
    """
    This is a traceback handler that knows how to paste to the pastebin.
    Should only be used in sys.excepthook.
    """
    sys.__excepthook__(exc_type, exc, tb)
    import xmlrpc.client
    from io import StringIO

    p = xmlrpc.client.ServerProxy(
        "http://paste.yt-project.org/xmlrpc/", allow_none=True
    )
    s = StringIO()
    traceback.print_exception(exc_type, exc, tb, file=s)
    s = s.getvalue()
    ret = p.pastes.newPaste("pytb", s, None, "", "", True)
    print()
    print(f"Traceback pasted to http://paste.yt-project.org/show/{ret}")
    print()

def paste_traceback_detailed(exc_type, exc, tb):
    """
    This is a traceback handler that knows how to paste to the pastebin.
    Should only be used in sys.excepthook.
    """
    import cgitb
    import xmlrpc.client
    from io import StringIO

    s = StringIO()
    handler = cgitb.Hook(format="text", file=s)
    handler(exc_type, exc, tb)
    s = s.getvalue()
    print(s)
    p = xmlrpc.client.ServerProxy(
        "http://paste.yt-project.org/xmlrpc/", allow_none=True
    )
    ret = p.pastes.newPaste("text", s, None, "", "", True)
    print()
    print(f"Traceback pasted to http://paste.yt-project.org/show/{ret}")
    print()

_ss = "fURbBUUBE0cLXgETJnZgJRMXVhVGUQpQAUBuehQMUhJWRFFRAV1ERAtBXw1dAxMLXT4zXBFfABNN\nC0ZEXw1YUURHCxMXVlFERwxWCQw=\n"


def _rdbeta(key):
    # base64.decodestring was removed in Python 3.9; decodebytes is the
    # modern equivalent and operates on (and returns) bytes
    enc_s = base64.decodebytes(_ss.encode())
    dec_s = "".join(chr(c ^ ord(b)) for c, b in zip(enc_s, itertools.cycle(key)))
    print(dec_s)


#
# Some exceptions
#
class NoCUDAException(Exception):
    pass


class YTEmptyClass:
    pass

def update_git(path):
    try:
        import git
    except ImportError:
        print("Updating and precise version information requires ")
        print("gitpython to be installed.")
        print("Try: python -m pip install gitpython")
        return -1
    with open(os.path.join(path, "yt_updater.log"), "a") as f:
        repo = git.Repo(path)
        if repo.is_dirty(untracked_files=True):
            print("Changes have been made to the yt source code so I won't ")
            print("update the code. You will have to do this yourself.")
            print("Here's a set of sample commands:")
            print("")
            print(f"    $ cd {path}")
            print("    $ git stash")
            print("    $ git checkout main")
            print("    $ git pull")
            print("    $ git stash pop")
            print(f"    $ {sys.executable} setup.py develop")
            print("")
            return 1
        if repo.active_branch.name != "main":
            print("yt repository is not tracking the main branch so I won't ")
            print("update the code. You will have to do this yourself.")
            print("Here's a set of sample commands:")
            print("")
            print(f"    $ cd {path}")
            print("    $ git checkout main")
            print("    $ git pull")
            print(f"    $ {sys.executable} setup.py develop")
            print("")
            return 1
        print("Updating the repository")
        f.write("Updating the repository\n\n")
        old_version = repo.git.rev_parse("HEAD", short=12)
        try:
            remote = repo.remotes.yt_upstream
        except AttributeError:
            remote = repo.create_remote(
                "yt_upstream", url="https://github.com/yt-project/yt"
            )
        remote.fetch()
        main = repo.heads.main
        main.set_tracking_branch(remote.refs.main)
        main.checkout()
        remote.pull()
        new_version = repo.git.rev_parse("HEAD", short=12)
        f.write(f"Updated from {old_version} to {new_version}\n\n")
        rebuild_modules(path, f)
    print("Updated successfully")

def rebuild_modules(path, f):
    f.write("Rebuilding modules\n\n")
    p = subprocess.Popen(
        [sys.executable, "setup.py", "build_clib", "build_ext", "-i"],
        cwd=path,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    stdout, stderr = p.communicate()
    f.write(stdout.decode("utf-8"))
    f.write("\n\n")
    if p.returncode:
        print(f"BROKEN: See {os.path.join(path, 'yt_updater.log')}")
        sys.exit(1)
    f.write("Successful!\n")

def get_git_version(path):
    try:
        import git
    except ImportError:
        print("Updating and precise version information requires ")
        print("gitpython to be installed.")
        print("Try: python -m pip install gitpython")
        return None
    try:
        repo = git.Repo(path)
        return repo.git.rev_parse("HEAD", short=12)
    except git.InvalidGitRepositoryError:
        # path is not a git repository
        return None

def get_yt_version():
    import importlib.resources as importlib_resources

    version = get_git_version(os.path.dirname(importlib_resources.files("yt")))
    if version is None:
        return version
    else:
        v_str = version[:12].strip()
        if hasattr(v_str, "decode"):
            v_str = v_str.decode("utf-8")
        return v_str

def get_version_stack():
    import matplotlib

    from yt._version import __version__ as yt_version

    version_info = {}
    version_info["yt"] = yt_version
    version_info["numpy"] = np.version.version
    version_info["matplotlib"] = matplotlib.__version__
    return version_info

def get_script_contents():
    top_frame = inspect.stack()[-1]
    finfo = inspect.getframeinfo(top_frame[0])
    if finfo[2] != "<module>":
        return None
    if not os.path.exists(finfo[0]):
        return None
    try:
        contents = open(finfo[0]).read()
    except Exception:
        contents = None
    return contents

def download_file(url, filename):
    try:
        return fancy_download_file(url, filename, requests)
    except ImportError:
        # fancy_download_file requires requests
        return simple_download_file(url, filename)

def fancy_download_file(url, filename, requests=None):
    response = requests.get(url, stream=True)
    total_length = response.headers.get("content-length")

    with open(filename, "wb") as fh:
        if total_length is None:
            fh.write(response.content)
        else:
            blocksize = 4 * 1024**2
            iterations = int(float(total_length) / float(blocksize))

            pbar = get_pbar(
                "Downloading {} to {} ".format(*os.path.split(filename)[::-1]),
                iterations,
            )
            iteration = 0
            for chunk in response.iter_content(chunk_size=blocksize):
                fh.write(chunk)
                iteration += 1
                pbar.update(iteration)
            pbar.finish()
    return filename

def simple_download_file(url, filename):
    import urllib.error
    import urllib.request

    try:
        fn, h = urllib.request.urlretrieve(url, filename)
    except urllib.error.HTTPError as err:
        raise RuntimeError(
            f"Attempt to download file from {url} failed with error "
            f"{err.code}: {err.msg}."
        ) from None
    return fn

# This code snippet is modified from Georg Brandl
def bb_apicall(endpoint, data, use_pass=True):
    import getpass
    import urllib.parse
    import urllib.request

    uri = f"https://api.bitbucket.org/1.0/{endpoint}/"
    # since bitbucket doesn't return the required WWW-Authenticate header when
    # making a request without Authorization, we cannot use the standard urllib2
    # auth handlers; we have to add the requisite header from the start
    if data is not None:
        # urlencode returns str; Request expects a bytes body
        data = urllib.parse.urlencode(data).encode()
    req = urllib.request.Request(uri, data)
    if use_pass:
        username = input("Bitbucket Username? ")
        password = getpass.getpass()
        upw = f"{username}:{password}"
        # b64encode operates on bytes, so encode before and decode after
        token = base64.b64encode(upw.encode()).decode().strip()
        req.add_header("Authorization", f"Basic {token}")
    return urllib.request.urlopen(req).read()

def fix_length(length, ds):
    registry = ds.unit_registry
    if isinstance(length, YTArray):
        if registry is not None:
            length.units.registry = registry
        return length.in_units("code_length")
    if isinstance(length, numeric_type):
        return YTArray(length, "code_length", registry=registry)
    length_valid_tuple = isinstance(length, (list, tuple)) and len(length) == 2
    unit_is_string = isinstance(length[1], str)
    length_is_number = isinstance(length[0], numeric_type) and not isinstance(
        length[0], YTArray
    )
    if length_valid_tuple and unit_is_string and length_is_number:
        return YTArray(*length, registry=registry)
    else:
        raise RuntimeError(f"Length {str(length)} is invalid")

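# Hedged illustration (editor's addition; `ds` stands for any loaded yt
# dataset): plain numbers are interpreted as code_length, while
# (value, unit) tuples keep their explicit unit.
#
#   fix_length(1.0, ds)           # -> YTArray(1.0, "code_length")
#   fix_length((5.0, "kpc"), ds)  # -> YTArray(5.0, "kpc")
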
@contextlib.contextmanager
def parallel_profile(prefix):
    r"""A context manager for profiling parallel code execution using cProfile

    This is a simple context manager that automatically profiles the execution
    of a snippet of code.

    Parameters
    ----------
    prefix : string
        A string name to prefix outputs with.

    Examples
    --------

    >>> from yt import PhasePlot
    >>> from yt.testing import fake_random_ds
    >>> fields = ("density", "temperature", "cell_mass")
    >>> units = ("g/cm**3", "K", "g")
    >>> ds = fake_random_ds(16, fields=fields, units=units)
    >>> with parallel_profile("my_profile"):
    ...     plot = PhasePlot(ds.all_data(), *fields)
    """
    import cProfile

    fn = "%s_%04i_%04i.cprof" % (
        prefix,
        ytcfg.get("yt", "internals", "topcomm_parallel_size"),
        ytcfg.get("yt", "internals", "topcomm_parallel_rank"),
    )
    p = cProfile.Profile()
    p.enable()
    yield fn
    p.disable()
    p.dump_stats(fn)

def get_num_threads():
    from .config import ytcfg

    nt = ytcfg.get("yt", "num_threads")
    if nt < 0:
        return os.environ.get("OMP_NUM_THREADS", 0)
    return nt

def fix_axis(axis, ds):
    return ds.coordinates.axis_id.get(axis, axis)

def get_output_filename(name, keyword, suffix):
    r"""Return an appropriate filename for output.

    With a name provided by the user, this will decide how to appropriately
    name the output file by the following rules:

    1. if name is None, the filename will be the keyword plus the suffix.
    2. if name ends with "/" (resp "\" on Windows), assume name is a directory
       and the file will be named name/(keyword+suffix).  If the directory
       does not exist, first try to create it and raise an exception if an
       error occurs.
    3. if name does not end in the suffix, add the suffix.

    Parameters
    ----------
    name : str
        A filename given by the user.
    keyword : str
        A default filename prefix if name is None.
    suffix : str
        Suffix that must appear at end of the filename.
        This will be added if not present.

    Examples
    --------

    >>> get_output_filename(None, "Projection_x", ".png")
    'Projection_x.png'
    >>> get_output_filename("my_file", "Projection_x", ".png")
    'my_file.png'
    >>> get_output_filename("my_dir/", "Projection_x", ".png")
    'my_dir/Projection_x.png'

    """
    if name is None:
        name = keyword
    name = os.path.expanduser(name)
    if name.endswith(os.sep) and not os.path.isdir(name):
        ensure_dir(name)
    if os.path.isdir(name):
        name = os.path.join(name, keyword)
    if not name.endswith(suffix):
        name += suffix
    return name

def ensure_dir_exists(path):
    r"""Create all directories in path recursively in a parallel safe manner"""
    my_dir = os.path.dirname(path)
    # If path is a file in the current directory, like "test.txt", then my_dir
    # would be an empty string, resulting in FileNotFoundError when passed to
    # ensure_dir. Let's avoid that.
    if my_dir:
        ensure_dir(my_dir)

def ensure_dir(path):
    r"""Parallel safe directory maker."""
    if os.path.exists(path):
        return path

    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno == errno.EEXIST:
            pass
        else:
            raise
    return path

def validate_width_tuple(width):
    if not is_sequence(width) or len(width) != 2:
        raise YTInvalidWidthError(f"width ({width}) is not a two element tuple")
    is_numeric = isinstance(width[0], numeric_type)
    length_has_units = isinstance(width[0], YTArray)
    unit_is_string = isinstance(width[1], str)
    if not is_numeric or length_has_units and unit_is_string:
        msg = f"width ({str(width)}) is invalid. "
        msg += "Valid widths look like this: (12, 'au')"
        raise YTInvalidWidthError(msg)

_first_cap_re = re.compile("(.)([A-Z][a-z]+)")
_all_cap_re = re.compile("([a-z0-9])([A-Z])")

@lru_cache(maxsize=128, typed=False)
def camelcase_to_underscore(name):
    s1 = _first_cap_re.sub(r"\1_\2", name)
    return _all_cap_re.sub(r"\1_\2", s1).lower()

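# Example conversions (editor's sketch): the two regexes split before each
# capitalized word, so runs of consecutive capitals (acronyms) stay together.
#
#   >>> camelcase_to_underscore("DensityField")
#   'density_field'
#   >>> camelcase_to_underscore("HTTPServer")
#   'http_server'
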
def set_intersection(some_list):
    if len(some_list) == 0:
        return set()
    # This accepts a list of iterables, which we get the intersection of.
    s = set(some_list[0])
    for l in some_list[1:]:
        s.intersection_update(l)
    return s

@contextlib.contextmanager
def memory_checker(interval=15, dest=None):
    r"""This is a context manager that monitors memory usage.

    Parameters
    ----------
    interval : int
        The number of seconds between printing the current memory usage in
        gigabytes of the current Python interpreter.

    Examples
    --------

    >>> with memory_checker(10):
    ...     arr = np.zeros(1024 * 1024 * 1024, dtype="float64")
    ...     time.sleep(15)
    ...     del arr
    MEMORY: -1.000e+00 gb
    """
    import threading

    if dest is None:
        dest = sys.stdout

    class MemoryChecker(threading.Thread):
        def __init__(self, event, interval):
            self.event = event
            self.interval = interval
            threading.Thread.__init__(self)

        def run(self):
            while not self.event.wait(self.interval):
                print(f"MEMORY: {get_memory_usage() / 1024.0:0.3e} gb", file=dest)

    e = threading.Event()
    mem_check = MemoryChecker(e, interval)
    mem_check.start()
    try:
        yield
    finally:
        e.set()

def enable_plugins(plugin_filename=None):
    """Forces a plugin file to be parsed.

    A plugin file is a means of creating custom fields, quantities,
    data objects, colormaps, and other code classes and objects to be used
    in yt scripts without modifying the yt source directly.

    If ``plugin_filename`` is omitted, this function will look for a plugin
    file at ``$HOME/.config/yt/my_plugins.py``, which is the preferred
    behaviour for a system-level configuration.

    Warning: a script using this function will only be reproducible if your
    plugin file is shared with it.
    """
    import yt
    from yt.config import config_dir, ytcfg
    from yt.fields.my_plugin_fields import my_plugins_fields

    if plugin_filename is not None:
        _fn = plugin_filename
        if not os.path.isfile(_fn):
            raise FileNotFoundError(_fn)
    else:
        # Determine global plugin location. By decreasing priority order:
        # - absolute path
        # - CONFIG_DIR
        # - obsolete config dir.
        my_plugin_name = ytcfg.get("yt", "plugin_filename")
        for base_prefix in ("", config_dir()):
            if os.path.isfile(os.path.join(base_prefix, my_plugin_name)):
                _fn = os.path.join(base_prefix, my_plugin_name)
                break
        else:
            raise FileNotFoundError("Could not find a global system plugin file.")

    mylog.info("Loading plugins from %s", _fn)
    ytdict = yt.__dict__
    execdict = ytdict.copy()
    execdict["add_field"] = my_plugins_fields.add_field
    with open(_fn) as f:
        code = compile(f.read(), _fn, "exec")
        exec(code, execdict, execdict)
    ytnamespace = list(ytdict.keys())
    for k in execdict.keys():
        if k not in ytnamespace:
            if callable(execdict[k]):
                setattr(yt, k, execdict[k])

def subchunk_count(n_total, chunk_size):
    handled = 0
    while handled < n_total:
        tr = min(n_total - handled, chunk_size)
        yield tr
        handled += tr

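# Example (editor's sketch): the generator yields full chunks until the
# remainder is smaller than chunk_size.
#
#   >>> list(subchunk_count(10, 4))
#   [4, 4, 2]
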
def fix_unitary(u):
    if u == "1":
        return "unitary"
    else:
        return u

def get_hash(infile, algorithm="md5", BLOCKSIZE=65536):
    """Generate file hash without reading in the entire file at once.

    Original code licensed under MIT.  Source:
    https://www.pythoncentral.io/hashing-files-with-python/

    Parameters
    ----------
    infile : str
        File of interest (including the path).
    algorithm : str (optional)
        Hash algorithm of choice. Defaults to 'md5'.
    BLOCKSIZE : int (optional)
        How much data in bytes to read in at once.

    Returns
    -------
    hash : str
        The hash of the file.

    Examples
    --------
    >>> from tempfile import NamedTemporaryFile
    >>> with NamedTemporaryFile() as file:
    ...     get_hash(file.name)
    'd41d8cd98f00b204e9800998ecf8427e'
    """
    import hashlib

    try:
        hasher = getattr(hashlib, algorithm)()
    except AttributeError as e:
        # hashlib.algorithms was removed in Python 3; algorithms_available
        # is the modern spelling
        raise NotImplementedError(
            f"'{algorithm}' not available! "
            f"Available algorithms: {hashlib.algorithms_available}"
        ) from e

    filesize = os.path.getsize(infile)
    iterations = int(float(filesize) / float(BLOCKSIZE))

    pbar = get_pbar(f"Generating {algorithm} hash", iterations)

    iter = 0
    with open(infile, "rb") as f:
        buf = f.read(BLOCKSIZE)
        while len(buf) > 0:
            hasher.update(buf)
            buf = f.read(BLOCKSIZE)
            iter += 1
            pbar.update(iter)
        pbar.finish()

    return hasher.hexdigest()

def get_brewer_cmap(cmap):
    """Returns a colorbrewer colormap from palettable"""
    try:
        import palettable
    except ImportError as exc:
        raise RuntimeError(
            "Please install palettable to use colorbrewer colormaps"
        ) from exc
    bmap = palettable.colorbrewer.get_map(*cmap)
    return bmap.get_mpl_colormap(N=cmap[2])

def matplotlib_style_context(style="yt.default", after_reset=False):
    """Returns a context manager for controlling matplotlib style.

    Arguments are passed to matplotlib.style.context() if specified. Defaults
    to setting yt's "yt.default" style, after resetting to the default config
    parameters.
    """
    # FUTURE: this function should be deprecated in favour of
    # matplotlib.style.context after support for matplotlib 3.6 and older
    # versions is dropped.
    import importlib.resources as importlib_resources

    import matplotlib as mpl
    import matplotlib.style

    if style == "yt.default" and mpl.__version_info__ < (3, 7):
        style = importlib_resources.files("yt") / "default.mplstyle"

    return matplotlib.style.context(style, after_reset=after_reset)

interactivity = False
"""Sets the condition that interactive backends can be used."""

def toggle_interactivity():
    global interactivity
    interactivity = not interactivity
    if interactivity:
        if IS_IPYTHON:
            import IPython

            shell = IPython.get_ipython()
            shell.magic("matplotlib")
        else:
            import matplotlib

            matplotlib.interactive(True)

def get_interactivity():
    return interactivity

def setdefaultattr(obj, name, value):
    """Set attribute with *name* on *obj* with *value* if it doesn't exist yet

    Analogous to dict.setdefault
    """
    if not hasattr(obj, name):
        setattr(obj, name, value)
    return getattr(obj, name)

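# Example (editor's sketch): the first call sets the attribute, later calls
# leave it untouched and return the stored value, mirroring dict.setdefault.
#
#   >>> class Empty:
#   ...     pass
#   >>> o = Empty()
#   >>> setdefaultattr(o, "x", 1)
#   1
#   >>> setdefaultattr(o, "x", 2)
#   1
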
def parse_h5_attr(f, attr):
    """A Python3-safe function for getting hdf5 attributes.

    If an attribute is supposed to be a string, this will return it as such.
    """
    val = f.attrs.get(attr, None)
    if isinstance(val, bytes):
        return val.decode("utf8")
    else:
        return val

def obj_length(v):
    if is_sequence(v):
        return len(v)
    else:
        # If something isn't iterable, we return 0
        # to signify zero length (aka a scalar).
        return 0

def array_like_field(data, x, field):
    field = data._determine_fields(field)[0]
    finfo = data.ds._get_field_info(field)
    if finfo.sampling_type == "particle":
        units = finfo.output_units
    else:
        units = finfo.units
    if isinstance(x, YTArray):
        arr = copy.deepcopy(x)
        arr.convert_to_units(units)
        return arr
    if isinstance(x, np.ndarray):
        return data.ds.arr(x, units)
    else:
        return data.ds.quan(x, units)

def _full_type_name(obj: object = None, /, *, cls: type | None = None) -> str:
    if cls is not None and obj is not None:
        raise TypeError("_full_type_name takes an object or a class, but not both")
    if cls is None:
        cls = obj.__class__
    prefix = f"{cls.__module__}." if cls.__module__ != "builtins" else ""
    return f"{prefix}{cls.__name__}"

def validate_3d_array(obj):
    if not is_sequence(obj) or len(obj) != 3:
        raise TypeError(
            "Expected an array of size (3,), "
            f"received {_full_type_name(obj)!r} of length {len(obj)}"
        )

def validate_float(obj):
    """Validates if the passed argument is a float value.

    Raises an exception if `obj` is not a single float value or a YTQuantity
    of size 1.

    Parameters
    ----------
    obj : Any
        Any argument which needs to be checked for a single float value.

    Raises
    ------
    TypeError
        Raised if `obj` is not a single float value or YTQuantity

    Examples
    --------
    >>> validate_float(1)
    >>> validate_float(1.50)
    >>> validate_float(YTQuantity(1, "cm"))
    >>> validate_float((1, "cm"))
    >>> validate_float([1, 1, 1])
    Traceback (most recent call last):
    ...
    TypeError: Expected a numeric value (or size-1 array), received 'list' of length 3

    >>> validate_float([YTQuantity(1, "cm"), YTQuantity(2, "cm")])
    Traceback (most recent call last):
    ...
    TypeError: Expected a numeric value (or size-1 array), received 'list' of length 2
    """
    if isinstance(obj, tuple):
        if (
            len(obj) != 2
            or not isinstance(obj[0], numeric_type)
            or not isinstance(obj[1], str)
        ):
            raise TypeError(
                "Expected a numeric value (or tuple of format "
                f"(float, String)), received an inconsistent tuple {str(obj)!r}."
            )
        else:
            return
    if is_sequence(obj) and (len(obj) != 1 or not isinstance(obj[0], numeric_type)):
        raise TypeError(
            "Expected a numeric value (or size-1 array), "
            f"received {_full_type_name(obj)!r} of length {len(obj)}"
        )

def validate_sequence(obj):
    if obj is not None and not is_sequence(obj):
        raise TypeError(
            f"Expected an iterable object, received {_full_type_name(obj)!r}"
        )

def validate_field_key(key):
    if (
        isinstance(key, tuple)
        and len(key) == 2
        and all(isinstance(_, str) for _ in key)
    ):
        return
    raise TypeError(
        "Expected a 2-tuple of strings formatted as\n"
        "(field or particle type, field name)\n"
        f"Received invalid field key: {key}, with type {type(key)}"
    )

def is_valid_field_key(key):
    try:
        validate_field_key(key)
    except TypeError:
        return False
    else:
        return True

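# Example (editor's sketch): only 2-tuples of strings qualify as field keys.
#
#   >>> is_valid_field_key(("gas", "density"))
#   True
#   >>> is_valid_field_key("density")
#   False
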
def validate_object(obj, data_type):
    if obj is not None and not isinstance(obj, data_type):
        raise TypeError(
            f"Expected an object of {_full_type_name(cls=data_type)!r} type, "
            f"received {_full_type_name(obj)!r}"
        )

def validate_axis(ds, axis):
    if ds is not None:
        valid_axis = sorted(
            ds.coordinates.axis_name.keys(), key=lambda k: str(k).swapcase()
        )
    else:
        valid_axis = [0, 1, 2, "x", "y", "z", "X", "Y", "Z"]
    if axis not in valid_axis:
        raise TypeError(f"Expected axis to be any of {valid_axis}, received {axis!r}")

def validate_center(center):
    if isinstance(center, str):
        c = center.lower()
        if (
            c not in ["c", "center", "m", "max", "min"]
            and not c.startswith("max_")
            and not c.startswith("min_")
        ):
            raise TypeError(
                "Expected 'center' to be in ['c', 'center', "
                "'m', 'max', 'min'] or the prefix to be "
                f"'max_'/'min_', received {center!r}."
            )
    elif not isinstance(center, (numeric_type, YTQuantity)) and not is_sequence(
        center
    ):
        raise TypeError(
            "Expected 'center' to be a numeric object of type "
            "list/tuple/np.ndarray/YTArray/YTQuantity, "
            f"received {_full_type_name(center)}."
        )

def parse_center_array(center, ds, axis: int | None = None):
    known_shortnames = {"m": "max", "c": "center", "l": "left", "r": "right"}
    valid_single_str_values = ("center", "left", "right")
    valid_field_loc_str_values = ("min", "max")
    valid_str_values = valid_single_str_values + valid_field_loc_str_values
    default_error_message = (
        "Expected any of the following\n"
        "- 'c', 'center', 'l', 'left', 'r', 'right', 'm', 'max', or 'min'\n"
        "- a 2 element tuple with 'min' or 'max' as the first element, "
        "followed by a field identifier\n"
        "- a 3 element array-like: for a unyt_array, expects length dimensions, "
        "otherwise code_length is assumed"
    )
    # store an unmodified copy of user input to be inserted in error messages
    center_input = deepcopy(center)

    if isinstance(center, str):
        centerl = center.lower()
        if centerl in known_shortnames:
            centerl = known_shortnames[centerl]
        match = re.match(r"^(?P<extremum>(min|max))(_(?P<field>[\w_]+))?", centerl)
        if match is not None:
            if match["field"] is not None:
                for ftype, fname in ds.derived_field_list:  # noqa: B007
                    if fname == match["field"]:
                        break
                else:
                    raise YTFieldNotFound(match["field"], ds)
            else:
                ftype, fname = ("gas", "density")
            center = (match["extremum"], (ftype, fname))
        elif centerl in ("center", "left", "right"):
            # domain_left_edge and domain_right_edge might not be
            # initialized until we create the index, so create it
            ds.index
            center = ds.domain_center.copy()
            if centerl in ("left", "right") and axis is None:
                raise ValueError(f"center={center!r} is not valid with axis=None")
            if centerl == "left":
                center = ds.domain_center.copy()
                center[axis] = ds.domain_left_edge[axis]
            elif centerl == "right":
                # note that the right edge of a grid is excluded by slice
                # selector, which is why we offset the region center by the
                # smallest distance possible
                center = ds.domain_center.copy()
                center[axis] = (
                    ds.domain_right_edge[axis] - center.uq * np.finfo(center.dtype).eps
                )
        elif centerl not in valid_str_values:
            raise ValueError(
                f"Received unknown center single string value {center!r}. "
                + default_error_message
            )

    if is_sequence(center):
        if (
            len(center) == 2
            and isinstance(center[0], str)
            and (is_valid_field_key(center[1]) or isinstance(center[1], str))
        ):
            center0l = center[0].lower()
            if center0l not in valid_str_values:
                raise ValueError(
                    f"Received unknown string value {center[0]!r}. "
                    f"Expected one of {valid_field_loc_str_values} (case insensitive)"
                )
            field_key = center[1]
            if center0l == "min":
                v, center = ds.find_min(field_key)
            else:
                assert center0l == "max"
                v, center = ds.find_max(field_key)
            center = ds.arr(center, "code_length")
        elif len(center) == 2 and is_sequence(center[0]) and isinstance(center[1], str):
            center = ds.arr(center[0], center[1])
        elif len(center) == 3 and all(isinstance(_, YTQuantity) for _ in center):
            center = ds.arr([c.copy() for c in center], dtype="float64")
        elif len(center) == 3:
            center = ds.arr(center, "code_length")

    if isinstance(center, np.ndarray) and center.ndim > 1:
        mylog.debug("Removing singleton dimensions from 'center'.")
        center = np.squeeze(center)

    if not isinstance(center, YTArray):
        raise TypeError(
            f"Received {center_input!r}, but failed to transform to a unyt_array "
            f"(obtained {center!r}).\n"
            + default_error_message
            + "\n"
            "If you supplied an expected type, consider filing a bug report"
        )

    if center.shape != (3,):
        raise TypeError(
            f"Received {center_input!r} and obtained {center!r} after sanitizing.\n"
            + default_error_message
            + "\n"
            "If you supplied an expected type, consider filing a bug report"
        )

    # make sure the return value shares all unit symbols with ds.unit_registry;
    # we rely on unyt to invalidate unit dimensionality here
    center = ds.arr(center).in_units("code_length")

    if not ds._is_within_domain(center):
        mylog.warning(
            "Requested center at %s is outside of data domain with "
            "left edge = %s, "
            "right edge = %s, "
            "periodicity = %s",
            center,
            ds.domain_left_edge,
            ds.domain_right_edge,
            ds.periodicity,
        )

    return center.astype("float64")

def sglob(pattern):
    """
    Return the results of a glob through the sorted() function.
    """
    return sorted(glob.glob(pattern))

def dictWithFactory(factory: Callable[[Any], Any]) -> type:
    """
    Create a dictionary class with a default factory function.
    Contrary to `collections.defaultdict`, the factory takes
    the missing key as input parameter.

    Parameters
    ----------
    factory : callable(key) -> value
        The factory to call when hitting a missing key

    Returns
    -------
    DictWithFactory class
        A class to create new dictionaries handling missing keys.
    """
    issue_deprecation_warning(
        "yt.funcs.dictWithFactory will be removed in a future version of yt, "
        "please do not rely on it. "
        "If you need it, copy paste this function from yt's source code",
        stacklevel=3,
        since="4.1",
    )

    class DictWithFactory(UserDict):
        def __init__(self, *args, **kwargs):
            self.factory = factory
            super().__init__(*args, **kwargs)

        def __missing__(self, key):
            val = self.factory(key)
            self[key] = val
            return val

    return DictWithFactory

def levenshtein_distance(seq1, seq2, max_dist=None):
    """
    Compute the levenshtein distance between seq1 and seq2.
    From https://stackabuse.com/levenshtein-distance-and-text-similarity-in-python/

    Parameters
    ----------
    seq1 : str
    seq2 : str
        The strings to compute the distance between
    max_dist : integer
        If not None, maximum distance returned (see notes).

    Returns
    -------
    The Levenshtein distance as an integer.

    Notes
    -----
    This computes the Levenshtein distance, i.e. the number of edits to change
    seq1 into seq2. If a maximum distance is passed, the algorithm will stop
    as soon as the number of edits goes above the value. This allows for an
    earlier break and speeds calculations up.
    """
    size_x = len(seq1) + 1
    size_y = len(seq2) + 1
    if max_dist is None:
        max_dist = max(size_x, size_y)

    if abs(size_x - size_y) > max_dist:
        return max_dist + 1
    matrix = np.zeros((size_x, size_y), dtype=int)
    for x in range(size_x):
        matrix[x, 0] = x
    for y in range(size_y):
        matrix[0, y] = y

    for x in range(1, size_x):
        for y in range(1, size_y):
            if seq1[x - 1] == seq2[y - 1]:
                matrix[x, y] = min(
                    matrix[x - 1, y] + 1, matrix[x - 1, y - 1], matrix[x, y - 1] + 1
                )
            else:
                matrix[x, y] = min(
                    matrix[x - 1, y] + 1,
                    matrix[x - 1, y - 1] + 1,
                    matrix[x, y - 1] + 1,
                )

        # Early break: the minimum distance is already larger than the
        # maximum allowed value, can return safely.
        if matrix[x].min() > max_dist:
            return max_dist + 1
    return matrix[size_x - 1, size_y - 1]

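# Example (editor's sketch): the classic kitten -> sitting distance is 3
# (two substitutions and one insertion); with max_dist=1 the early break
# fires and max_dist + 1 = 2 is returned instead.
#
#   >>> levenshtein_distance("kitten", "sitting")
#   3
#   >>> levenshtein_distance("kitten", "sitting", max_dist=1)
#   2
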
def validate_moment(moment, weight_field):
    if moment == 2 and weight_field is None:
        raise ValueError(
            "Cannot compute the second moment of a projection if weight_field=None!"
        )
    if moment not in [1, 2]:
        raise ValueError(
            "Weighted projections can only be made of averages "
            "(moment = 1) or standard deviations (moment = 2)!"
        )

def setdefault_mpl_metadata(save_kwargs: dict[str, Any], name: str) -> None:
    """
    Set a default Software metadata entry for use with Matplotlib outputs.
    """
    _, ext = os.path.splitext(name.lower())
    if ext in (".eps", ".ps", ".svg", ".pdf"):
        key = "Creator"
    elif ext == ".png":
        key = "Software"
    else:
        return
    default_software = (
        "Matplotlib version{matplotlib}, https://matplotlib.org|NumPy-{numpy}|yt-{yt}"
    ).format(**get_version_stack())
    if "metadata" in save_kwargs:
        save_kwargs["metadata"].setdefault(key, default_software)
    else:
        save_kwargs["metadata"] = {key: default_software}