import contextlib
import inspect
import re
from collections.abc import Iterable
from typing import Optional
from more_itertools import always_iterable
import yt.units.dimensions as ytdims
from yt._maintenance.deprecation import issue_deprecation_warning
from yt._typing import FieldKey
from yt.funcs import iter_fields, validate_field_key
from yt.units.unit_object import Unit # type: ignore
from yt.utilities.exceptions import YTFieldNotFound
from yt.utilities.logger import ytLogger as mylog
from yt.visualization._commons import _get_units_label
from .field_detector import FieldDetector
from .field_exceptions import (
FieldUnitsError,
NeedsDataField,
NeedsGridType,
NeedsOriginalGrid,
NeedsParameter,
NeedsProperty,
)
def TranslationFunc(field_name):
    """Build a field function that aliases another field, *field_name*.

    The generated callable looks the target field up on the data object and
    returns a copy of it, and carries an ``alias_name`` attribute recording
    which field it points at.
    """

    def _TranslationFunc(field, data):
        # Hand back a copy: callers routinely modify field data in place,
        # and the cached original must stay intact.
        return data[field_name].copy()

    _TranslationFunc.alias_name = field_name
    return _TranslationFunc
def NullFunc(field, data):
    """Placeholder function for on-disk fields; evaluating it always fails."""
    raise YTFieldNotFound(field.name)
def DeprecatedFieldFunc(ret_field, func, since, removal):
    """Wrap a field function so each real evaluation logs a deprecation warning.

    *ret_field* is the replacement field to point users at, *func* is the
    original field function, *since* the yt version that deprecated it, and
    *removal* (optional) the version in which it will be removed.
    """

    def _DeprecatedFieldFunc(field, data):
        # Stay silent while field detection is still running; only warn on
        # genuine evaluations.
        if data.ds.fields_detected:
            fmt = "The Derived Field %s is deprecated as of yt v%s "
            fmt_args = [field.name, since]
            if removal is not None:
                fmt += "and will be removed in yt v%s "
                fmt_args.append(removal)
            if ret_field != field.name:
                fmt += ", use %s instead"
                fmt_args.append(ret_field)
            mylog.warning(fmt, *fmt_args)
        return func(field, data)

    return _DeprecatedFieldFunc
class DerivedField:
    """
    This is the base class used to describe a cell-by-cell derived field.

    Parameters
    ----------
    name : str
        is the name of the field.
    function : callable
        A function handle that defines the field. Should accept
        arguments (field, data)
    units : str
        A plain text string encoding the unit, or a query to a unit system of
        a dataset. Powers must be in Python syntax (** instead of ^). If set
        to 'auto' or None (default), units will be inferred from the return value
        of the field function.
    take_log : bool
        Describes whether the field should be logged
    validators : list
        A list of :class:`FieldValidator` objects
    sampling_type : string, default = "cell"
        How is the field sampled? This can be one of the following options at
        present: "cell" (cell-centered), "discrete" (or "particle") for
        discretely sampled data.
    vector_field : bool
        Describes the dimensionality of the field. Currently unused.
    display_field : bool
        Governs its appearance in the dropdowns in Reason
    not_in_all : bool
        Used for baryon fields from the data that are not in all the grids
    display_name : str
        A name used in the plots
    output_units : str
        For fields that exist on disk, which we may want to convert to other
        fields or that get aliased to themselves, we can specify a different
        desired output unit than the unit found on disk.
    dimensions : str or object from yt.units.dimensions
        The dimensions of the field, only used for error checking with units='auto'.
    nodal_flag : array-like with three components
        This describes how the field is centered within a cell. If nodal_flag
        is [0, 0, 0], then the field is cell-centered. If any of the components
        of nodal_flag are 1, then the field is nodal in that direction, meaning
        it is defined at the lo and hi sides of the cell rather than at the center.
        For example, a field with nodal_flag = [1, 0, 0] would be defined at the
        middle of the 2 x-faces of each cell. nodal_flag = [0, 1, 1] would mean the
        that the field defined at the centers of the 4 edges that are normal to the
        x axis, while nodal_flag = [1, 1, 1] would be defined at the 8 cell corners.
    """

    # Set by the particle-filter machinery when this field is inherited by a
    # filtered particle type instead of being defined for it explicitly.
    _inherited_particle_filter = False

    # Class-level default; temporarily overridden per-evaluation by the
    # ``unit_registry`` context manager below.
    _unit_registry = None

    def __init__(
        self,
        name: FieldKey,
        sampling_type,
        function,
        units: str | bytes | Unit | None = None,
        take_log=True,
        validators=None,
        vector_field=False,
        display_field=True,
        not_in_all=False,
        display_name=None,
        output_units=None,
        dimensions=None,
        ds=None,
        nodal_flag=None,
        *,
        alias: Optional["DerivedField"] = None,
    ):
        validate_field_key(name)
        self.name = name
        self.take_log = take_log
        self.display_name = display_name
        self.not_in_all = not_in_all
        self.display_field = display_field
        self.sampling_type = sampling_type
        self.vector_field = vector_field
        self.ds = ds
        # Pick up the dataset's preferred ionization label style when a
        # dataset is attached; otherwise fall back to roman numerals.
        if self.ds is not None:
            self._ionization_label_format = self.ds._ionization_label_format
        else:
            self._ionization_label_format = "roman_numeral"
        # [0, 0, 0] means cell-centered; see the class docstring.
        if nodal_flag is None:
            self.nodal_flag = [0, 0, 0]
        else:
            self.nodal_flag = nodal_flag
        self._function = function
        self.validators = list(always_iterable(validators))
        # handle units: normalize everything to a str (or None for 'auto')
        self.units: str | bytes | Unit | None
        if units in (None, "auto"):
            self.units = None
        elif isinstance(units, str):
            self.units = units
        elif isinstance(units, Unit):
            self.units = str(units)
        elif isinstance(units, bytes):
            self.units = units.decode("utf-8")
        else:
            raise FieldUnitsError(
                f"Cannot handle units {units!r} (type {type(units)}). "
                "Please provide a string or Unit object."
            )
        if output_units is None:
            output_units = self.units
        self.output_units = output_units
        if isinstance(dimensions, str):
            dimensions = getattr(ytdims, dimensions)
        self.dimensions = dimensions
        # Aliased fields share a single list: index 0 is the canonical field,
        # every later entry is an alias of it.
        if alias is None:
            self._shared_aliases_list = [self]
        else:
            self._shared_aliases_list = alias._shared_aliases_list
            self._shared_aliases_list.append(self)

    def _copy_def(self):
        """Return a dict of constructor-style attributes describing this field."""
        dd = {}
        dd["name"] = self.name
        dd["units"] = self.units
        dd["take_log"] = self.take_log
        dd["validators"] = list(self.validators)
        dd["sampling_type"] = self.sampling_type
        dd["vector_field"] = self.vector_field
        # NOTE(review): this hardcodes True rather than copying
        # self.display_field — confirm whether that is intentional.
        dd["display_field"] = True
        dd["not_in_all"] = self.not_in_all
        dd["display_name"] = self.display_name
        return dd

    @property
    def is_sph_field(self):
        # Cell-centered fields are never SPH.
        if self.sampling_type == "cell":
            return False
        is_sph_field = False
        # For aliases, test against the canonical (aliased-to) field name.
        if self.is_alias:
            name = self.alias_name
        else:
            name = self.name
        if hasattr(self.ds, "_sph_ptypes"):
            is_sph_field |= name[0] in (self.ds._sph_ptypes + ("gas",))
        return is_sph_field

    @property
    def local_sampling(self):
        # True for fields computed pointwise from local data only.
        return self.sampling_type in ("discrete", "particle", "local")

    def get_units(self):
        """Return the LaTeX representation of this field's units."""
        if self.ds is not None:
            u = Unit(self.units, registry=self.ds.unit_registry)
        else:
            u = Unit(self.units)
        return u.latex_representation()

    def get_projected_units(self):
        """Return the LaTeX representation of the units after line integration
        (i.e. multiplied by a length, cm)."""
        if self.ds is not None:
            u = Unit(self.units, registry=self.ds.unit_registry)
        else:
            u = Unit(self.units)
        return (u * Unit("cm")).latex_representation()

    def check_available(self, data):
        """
        This raises an exception of the appropriate type if the set of
        validation mechanisms are not met, and otherwise returns True.
        """
        for validator in self.validators:
            validator(data)
        # If we don't get an exception, we're good to go
        return True

    def get_dependencies(self, *args, **kwargs):
        """
        This returns a list of names of fields that this field depends on.
        """
        # NOTE(review): this actually returns the FieldDetector itself (which
        # records the dependencies), not a plain list — confirm callers rely
        # on the detector object.
        e = FieldDetector(*args, **kwargs)
        e[self.name]
        return e

    def _get_needed_parameters(self, fd):
        """Collect the field parameters (and permutation values) required by
        this field's ValidateParameter validators, reading current values
        from the field detector *fd*."""
        params = []
        values = []
        permute_params = {}
        vals = [v for v in self.validators if isinstance(v, ValidateParameter)]
        for val in vals:
            if val.parameter_values is not None:
                permute_params.update(val.parameter_values)
            else:
                params.extend(val.parameters)
                values.extend([fd.get_field_parameter(fp) for fp in val.parameters])
        return dict(zip(params, values, strict=True)), permute_params

    @contextlib.contextmanager
    def unit_registry(self, data):
        """Temporarily bind the unit registry associated with *data* for the
        duration of a field evaluation."""
        old_registry = self._unit_registry
        if hasattr(data, "unit_registry"):
            ur = data.unit_registry
        elif hasattr(data, "ds"):
            ur = data.ds.unit_registry
        else:
            ur = None
        self._unit_registry = ur
        try:
            yield
        finally:
            # Restore even if the wrapped field function raises, so a failed
            # evaluation cannot leave a stale registry behind.
            self._unit_registry = old_registry

    def __call__(self, data):
        """Return the value of the field in a given *data* object."""
        self.check_available(data)
        # NOTE(review): assumes data.keys() returns a snapshot (yt data
        # containers return a list, not a live dict view) — confirm.
        original_fields = data.keys()  # Copy
        if self._function is NullFunc:
            raise RuntimeError(
                "Something has gone terribly wrong, _function is NullFunc "
                + f"for {self.name}"
            )
        with self.unit_registry(data):
            dd = self._function(self, data)
        # Iterate over a snapshot of the keys: we delete entries as we go,
        # which is unsafe while iterating a live view.
        for field_name in list(data.keys()):
            if field_name not in original_fields:
                del data[field_name]
        return dd

    def get_source(self):
        """
        Return a string containing the source of the function (if possible.)
        """
        return inspect.getsource(self._function)

    def get_label(self, projected=False):
        """
        Return a data label for the given field, including units.

        Raises NotImplementedError when *projected* is True.
        """
        name = self.name[1]
        if self.display_name is not None:
            name = self.display_name
        # Start with the field name
        data_label = rf"$\rm{{{name}}}"
        # Grab the correct units
        if projected:
            raise NotImplementedError
        else:
            if self.ds is not None:
                units = Unit(self.units, registry=self.ds.unit_registry)
            else:
                units = Unit(self.units)
        # Add unit label
        if not units.is_dimensionless:
            data_label += _get_units_label(units.latex_representation()).strip("$")
        data_label += r"$"
        return data_label

    @property
    def alias_field(self) -> bool:
        # Deprecated spelling of ``is_alias``; kept for backward compatibility.
        issue_deprecation_warning(
            "DerivedField.alias_field is a deprecated equivalent to DerivedField.is_alias ",
            stacklevel=3,
            since="4.1",
        )
        return self.is_alias

    @property
    def is_alias(self) -> bool:
        # Index 0 in the shared list is the canonical field; anything after
        # it is an alias.
        return self._shared_aliases_list.index(self) > 0

    def is_alias_to(self, other: "DerivedField") -> bool:
        """Return True if *self* and *other* belong to the same alias group."""
        return self._shared_aliases_list is other._shared_aliases_list

    @property
    def alias_name(self) -> FieldKey | None:
        # The name of the canonical field this one aliases, or None if this
        # field is itself canonical.
        if self.is_alias:
            return self._shared_aliases_list[0].name
        return None

    def __repr__(self):
        if self._function is NullFunc:
            s = "On-Disk Field "
        elif self.is_alias:
            s = f"Alias Field for {self.alias_name!r} "
        else:
            s = "Derived Field "
        s += f"{self.name!r}: (units: {self.units!r}"
        if self.display_name is not None:
            s += f", display_name: {self.display_name!r}"
        if self.sampling_type == "particle":
            s += ", particle field"
        s += ")"
        return s

    def _is_ion(self):
        # Ion fields embed an ionization-state token like "_p1_" in the
        # field name (e.g. "H_p1_density").
        p = re.compile("_p[0-9]+_")
        result = False
        if p.search(self.name[1]) is not None:
            result = True
        return result

    def _ion_to_label(self):
        """Convert an ion field name into a display label, using either roman
        numerals or +/- superscripts depending on the dataset's preference."""
        # check to see if the output format has changed
        if self.ds is not None:
            self._ionization_label_format = self.ds._ionization_label_format
        # Charge number -> roman numeral spectroscopic notation ("p0" -> I).
        pnum2rom = {
            "0": "I",
            "1": "II",
            "2": "III",
            "3": "IV",
            "4": "V",
            "5": "VI",
            "6": "VII",
            "7": "VIII",
            "8": "IX",
            "9": "X",
            "10": "XI",
            "11": "XII",
            "12": "XIII",
            "13": "XIV",
            "14": "XV",
            "15": "XVI",
            "16": "XVII",
            "17": "XVIII",
            "18": "XIX",
            "19": "XX",
            "20": "XXI",
            "21": "XXII",
            "22": "XXIII",
            "23": "XXIV",
            "24": "XXV",
            "25": "XXVI",
            "26": "XXVII",
            "27": "XXVIII",
            "28": "XXIX",
            "29": "XXX",
        }
        # first look for charge to decide if it is an ion
        p = re.compile("_p[0-9]+_")
        m = p.search(self.name[1])
        if m is not None:
            # Find the ionization state
            pstr = m.string[m.start() + 1 : m.end() - 1]
            segments = self.name[1].split("_")
            # find the ionization index
            # NOTE(review): ipstr is only bound when pstr is found among the
            # segments; the regex match should guarantee this — confirm.
            for i, s in enumerate(segments):
                if s == pstr:
                    ipstr = i
            for i, s in enumerate(segments):
                # If its the species we don't want to change the capitalization
                if i == ipstr - 1:
                    continue
                segments[i] = s.capitalize()
            species = segments[ipstr - 1]
            # If there is a number in the species part of the label
            # that indicates part of a molecule
            symbols = []
            for symb in species:
                # can't just use underscore b/c gets replaced later with space
                if symb.isdigit():
                    symbols.append("latexsub{" + symb + "}")
                else:
                    symbols.append(symb)
            species_label = "".join(symbols)
            # Use roman numerals for ionization
            if self._ionization_label_format == "roman_numeral":
                roman = pnum2rom[pstr[1:]]
                label = (
                    species_label
                    + r"\ "
                    + roman
                    + r"\ "
                    + r"\ ".join(segments[ipstr + 1 :])
                )
            # use +/- for ionization
            else:
                sign = "+" * int(pstr[1:])
                label = (
                    "{"
                    + species_label
                    + "}"
                    + "^{"
                    + sign
                    + "}"
                    + r"\ "
                    + r"\ ".join(segments[ipstr + 1 :])
                )
        else:
            label = self.name[1]
        return label

    def get_latex_display_name(self):
        """Return a LaTeX-formatted display name for plots, generating one
        from the field name when no explicit display_name was set."""
        label = self.display_name
        if label is None:
            if self._is_ion():
                fname = self._ion_to_label()
                label = r"$\rm{" + fname.replace("_", r"\ ") + r"}$"
                # "latexsub" was used as a placeholder so the underscore
                # survives the replace() above; restore it now.
                label = label.replace("latexsub", "_")
            else:
                label = r"$\rm{" + self.name[1].replace("_", r"\ ").title() + r"}$"
        elif label.find("$") == -1:
            # Plain-text display names get wrapped; existing LaTeX is kept as-is.
            label = label.replace(" ", r"\ ")
            label = r"$\rm{" + label + r"}$"
        return label

    def __copy__(self):
        # a shallow copy doesn't copy the _shared_aliases_list attr
        # This method is implemented in support to ParticleFilter.wrap_func
        return type(self)(
            name=self.name,
            sampling_type=self.sampling_type,
            function=self._function,
            units=self.units,
            take_log=self.take_log,
            validators=self.validators,
            vector_field=self.vector_field,
            display_field=self.display_field,
            not_in_all=self.not_in_all,
            display_name=self.display_name,
            output_units=self.output_units,
            dimensions=self.dimensions,
            ds=self.ds,
            nodal_flag=self.nodal_flag,
        )
class FieldValidator:
    """
    Base class for FieldValidator objects. Available subclasses include:
    """

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Keep the base-class docstring as a living registry of subclasses.
        if FieldValidator.__doc__ is None:
            # Docstrings are stripped under ``python -OO``; the original
            # unguarded concatenation would raise TypeError here.
            return
        class_str = f":class:`{cls.__name__}`"
        if ":class:" in FieldValidator.__doc__:
            # Not the first subclass: separate entries with a comma.
            class_str = ", " + class_str
        FieldValidator.__doc__ += class_str
class ValidateParameter(FieldValidator):
    """
    A :class:`FieldValidator` that ensures the dataset has a given parameter.

    Parameters
    ----------
    parameters: str, iterable[str]
        a single parameter or list of parameters to require
    parameter_values: dict
        If *parameter_values* is supplied, this dict should map from field
        parameter to a value or list of values. It will ensure that the field
        is available for all permutations of the field parameter.
    """

    def __init__(
        self,
        parameters: str | Iterable[str],
        parameter_values: dict | None = None,
    ):
        FieldValidator.__init__(self)
        # Normalize a bare string into a one-element list.
        self.parameters = list(always_iterable(parameters))
        self.parameter_values = parameter_values

    def __call__(self, data):
        # Collect every required parameter the data object cannot supply.
        missing = [p for p in self.parameters if not data.has_field_parameter(p)]
        if missing:
            raise NeedsParameter(missing)
        return True
class ValidateDataField(FieldValidator):
    """
    A :class:`FieldValidator` that ensures the output file has a given data field stored
    in it.

    Parameters
    ----------
    field: str, tuple[str, str], or any iterable of the previous types.
        the field or fields to require
    """

    def __init__(self, field):
        FieldValidator.__init__(self)
        self.fields = list(iter_fields(field))

    def __call__(self, data):
        # During field detection the on-disk field list is not meaningful,
        # so a FieldDetector always validates.
        if isinstance(data, FieldDetector):
            return True
        missing = [f for f in self.fields if f not in data.index.field_list]
        if missing:
            raise NeedsDataField(missing)
        return True
class ValidateProperty(FieldValidator):
    """
    A :class:`FieldValidator` that ensures the data object has a given python attribute.

    Parameters
    ----------
    prop: str, iterable[str]
        the required property or properties to require
    """

    def __init__(self, prop: str | Iterable[str]):
        FieldValidator.__init__(self)
        # Normalize a bare string into a one-element list.
        self.prop = list(always_iterable(prop))

    def __call__(self, data):
        missing = []
        for attr in self.prop:
            if not hasattr(data, attr):
                missing.append(attr)
        if missing:
            raise NeedsProperty(missing)
        return True
class ValidateSpatial(FieldValidator):
    """
    A :class:`FieldValidator` that ensures the data handed to the field is of spatial
    nature -- that is to say, 3-D.

    Parameters
    ----------
    ghost_zones: int
        If supplied, will validate that the number of ghost zones required
        for the field is <= the available ghost zones. Default is 0.
    fields: Optional str, tuple[str, str], or any iterable of the previous types.
        The field or fields to validate.
    """

    def __init__(self, ghost_zones: int | None = 0, fields=None):
        FieldValidator.__init__(self)
        self.ghost_zones = ghost_zones
        self.fields = fields

    def __call__(self, data):
        # "Spatial" means the data object carries a three-dimensional grid
        # structure; it must also provide enough ghost zones. The _spatial
        # check short-circuits so _num_ghost_zones is only read on spatial
        # objects, matching the original two-step test.
        if getattr(data, "_spatial", False) and self.ghost_zones <= data._num_ghost_zones:
            return True
        raise NeedsGridType(self.ghost_zones, self.fields)
class ValidateGridType(FieldValidator):
    """
    A :class:`FieldValidator` that ensures the data handed to the field is an actual
    grid patch, not a covering grid of any kind. Does not accept parameters.
    """

    def __init__(self):
        FieldValidator.__init__(self)

    def __call__(self, data):
        # Field detection has no concrete container type; let it pass.
        if isinstance(data, FieldDetector):
            return True
        # Only genuine AMR grid patches qualify.
        if getattr(data, "_type_name", None) == "grid":
            return True
        raise NeedsOriginalGrid()