import abc
import weakref
from functools import cached_property
from numbers import Number
from typing import Any, Literal, overload
import numpy as np
from yt._typing import AxisOrder
from yt.funcs import fix_unitary, is_sequence, parse_center_array, validate_width_tuple
from yt.units.yt_array import YTArray, YTQuantity
from yt.utilities.exceptions import YTCoordinateNotImplemented, YTInvalidWidthError
def _unknown_coord(field, data):
raise YTCoordinateNotImplemented
def _get_coord_fields(axi, units="code_length"):
def _dds(field, data):
rv = data.ds.arr(data.fwidth[..., axi].copy(), units)
return data._reshape_vals(rv)
def _coords(field, data):
rv = data.ds.arr(data.fcoords[..., axi].copy(), units)
return data._reshape_vals(rv)
return _dds, _coords
def _get_vert_fields(axi, units="code_length"):
def _vert(field, data):
rv = data.ds.arr(data.fcoords_vertex[..., axi].copy(), units)
return rv
return _vert
def _setup_dummy_cartesian_coords_and_widths(registry, axes: tuple[str]):
for ax in axes:
registry.add_field(
("index", f"d{ax}"), sampling_type="cell", function=_unknown_coord
)
registry.add_field(("index", ax), sampling_type="cell", function=_unknown_coord)
def _setup_polar_coordinates(registry, axis_id):
f1, f2 = _get_coord_fields(axis_id["r"])
registry.add_field(
("index", "dr"),
sampling_type="cell",
function=f1,
display_field=False,
units="code_length",
)
registry.add_field(
("index", "r"),
sampling_type="cell",
function=f2,
display_field=False,
units="code_length",
)
f1, f2 = _get_coord_fields(axis_id["theta"], "dimensionless")
registry.add_field(
("index", "dtheta"),
sampling_type="cell",
function=f1,
display_field=False,
units="dimensionless",
)
registry.add_field(
("index", "theta"),
sampling_type="cell",
function=f2,
display_field=False,
units="dimensionless",
)
def _path_r(field, data):
return data["index", "dr"]
registry.add_field(
("index", "path_element_r"),
sampling_type="cell",
function=_path_r,
units="code_length",
)
def _path_theta(field, data):
# Note: this already assumes cell-centered
return data["index", "r"] * data["index", "dtheta"]
registry.add_field(
("index", "path_element_theta"),
sampling_type="cell",
function=_path_theta,
units="code_length",
)
[docs]
def validate_sequence_width(width, ds, unit=None):
if isinstance(width[0], tuple) and isinstance(width[1], tuple):
validate_width_tuple(width[0])
validate_width_tuple(width[1])
return (
ds.quan(width[0][0], fix_unitary(width[0][1])),
ds.quan(width[1][0], fix_unitary(width[1][1])),
)
elif isinstance(width[0], Number) and isinstance(width[1], Number):
return (ds.quan(width[0], "code_length"), ds.quan(width[1], "code_length"))
elif isinstance(width[0], YTQuantity) and isinstance(width[1], YTQuantity):
return (ds.quan(width[0]), ds.quan(width[1]))
else:
validate_width_tuple(width)
# If width and unit are both valid width tuples, we
# assume width controls x and unit controls y
try:
validate_width_tuple(unit)
return (
ds.quan(width[0], fix_unitary(width[1])),
ds.quan(unit[0], fix_unitary(unit[1])),
)
except YTInvalidWidthError:
return (
ds.quan(width[0], fix_unitary(width[1])),
ds.quan(width[0], fix_unitary(width[1])),
)
[docs]
class CoordinateHandler(abc.ABC):
name: str
_default_axis_order: AxisOrder
def __init__(self, ds, ordering: AxisOrder | None = None):
self.ds = weakref.proxy(ds)
if ordering is not None:
self.axis_order = ordering
else:
self.axis_order = self._default_axis_order
[docs]
@abc.abstractmethod
def setup_fields(self):
# This should return field definitions for x, y, z, r, theta, phi
pass
@overload
def pixelize(
self,
dimension,
data_source,
field,
bounds,
size,
antialias=True,
periodic=True,
*,
return_mask: Literal[False],
) -> "np.ndarray[Any, np.dtype[np.float64]]": ...
@overload
def pixelize(
self,
dimension,
data_source,
field,
bounds,
size,
antialias=True,
periodic=True,
*,
return_mask: Literal[True],
) -> tuple[
"np.ndarray[Any, np.dtype[np.float64]]", "np.ndarray[Any, np.dtype[np.bool_]]"
]: ...
[docs]
@abc.abstractmethod
def pixelize(
self,
dimension,
data_source,
field,
bounds,
size,
antialias=True,
periodic=True,
*,
return_mask=False,
):
# This should *actually* be a pixelize call, not just returning the
# pixelizer
pass
[docs]
@abc.abstractmethod
def pixelize_line(self, field, start_point, end_point, npoints):
pass
[docs]
def distance(self, start, end):
p1 = self.convert_to_cartesian(start)
p2 = self.convert_to_cartesian(end)
return np.sqrt(((p1 - p2) ** 2.0).sum())
[docs]
@abc.abstractmethod
def convert_from_cartesian(self, coord):
pass
[docs]
@abc.abstractmethod
def convert_to_cartesian(self, coord):
pass
[docs]
@abc.abstractmethod
def convert_to_cylindrical(self, coord):
pass
[docs]
@abc.abstractmethod
def convert_from_cylindrical(self, coord):
pass
[docs]
@abc.abstractmethod
def convert_to_spherical(self, coord):
pass
[docs]
@abc.abstractmethod
def convert_from_spherical(self, coord):
pass
@cached_property
def data_projection(self):
return {ax: None for ax in self.axis_order}
@cached_property
def data_transform(self):
return {ax: None for ax in self.axis_order}
@cached_property
def axis_name(self):
an = {}
for axi, ax in enumerate(self.axis_order):
an[axi] = ax
an[ax] = ax
an[ax.capitalize()] = ax
return an
@cached_property
def axis_id(self):
ai = {}
for axi, ax in enumerate(self.axis_order):
ai[ax] = ai[axi] = axi
return ai
@property
def image_axis_name(self):
rv = {}
for i in range(3):
rv[i] = (self.axis_name[self.x_axis[i]], self.axis_name[self.y_axis[i]])
rv[self.axis_name[i]] = rv[i]
rv[self.axis_name[i].capitalize()] = rv[i]
return rv
@cached_property
def x_axis(self):
ai = self.axis_id
xa = {}
for a1, a2 in self._x_pairs:
xa[a1] = xa[ai[a1]] = ai[a2]
return xa
@cached_property
def y_axis(self):
ai = self.axis_id
ya = {}
for a1, a2 in self._y_pairs:
ya[a1] = ya[ai[a1]] = ai[a2]
return ya
@property
@abc.abstractmethod
def period(self):
pass
[docs]
def sanitize_depth(self, depth):
if is_sequence(depth):
validate_width_tuple(depth)
depth = (self.ds.quan(depth[0], fix_unitary(depth[1])),)
elif isinstance(depth, Number):
depth = (
self.ds.quan(depth, "code_length", registry=self.ds.unit_registry),
)
elif isinstance(depth, YTQuantity):
depth = (depth,)
else:
raise YTInvalidWidthError(depth)
return depth
[docs]
def sanitize_width(self, axis, width, depth):
if width is None:
# initialize the index if it is not already initialized
self.ds.index
# Default to code units
if not is_sequence(axis):
xax = self.x_axis[axis]
yax = self.y_axis[axis]
w = self.ds.domain_width[np.array([xax, yax])]
else:
# axis is actually the normal vector
# for an off-axis data object.
mi = np.argmin(self.ds.domain_width)
w = self.ds.domain_width[np.array((mi, mi))]
width = (w[0], w[1])
elif is_sequence(width):
width = validate_sequence_width(width, self.ds)
elif isinstance(width, YTQuantity):
width = (width, width)
elif isinstance(width, Number):
width = (
self.ds.quan(width, "code_length"),
self.ds.quan(width, "code_length"),
)
else:
raise YTInvalidWidthError(width)
if depth is not None:
depth = self.sanitize_depth(depth)
return width + depth
return width
[docs]
def sanitize_center(self, center, axis):
center = parse_center_array(center, ds=self.ds, axis=axis)
# This has to return both a center and a display_center
display_center = self.convert_to_cartesian(center)
return center, display_center
[docs]
def cartesian_to_cylindrical(coord, center=(0, 0, 0)):
c2 = np.zeros_like(coord)
if not isinstance(center, YTArray):
center = center * coord.uq
c2[..., 0] = (
(coord[..., 0] - center[0]) ** 2.0 + (coord[..., 1] - center[1]) ** 2.0
) ** 0.5
c2[..., 1] = coord[..., 2] # rzt
c2[..., 2] = np.arctan2(coord[..., 1] - center[1], coord[..., 0] - center[0])
return c2
[docs]
def cylindrical_to_cartesian(coord, center=(0, 0, 0)):
c2 = np.zeros_like(coord)
if not isinstance(center, YTArray):
center = center * coord.uq
c2[..., 0] = np.cos(coord[..., 0]) * coord[..., 1] + center[0]
c2[..., 1] = np.sin(coord[..., 0]) * coord[..., 1] + center[1]
c2[..., 2] = coord[..., 2]
return c2
def _get_polar_bounds(self: CoordinateHandler, axes: tuple[str, str]):
# a small helper function that is needed by two unrelated classes
ri = self.axis_id[axes[0]]
pi = self.axis_id[axes[1]]
rmin = self.ds.domain_left_edge[ri]
rmax = self.ds.domain_right_edge[ri]
phimin = self.ds.domain_left_edge[pi]
phimax = self.ds.domain_right_edge[pi]
corners = [
(rmin, phimin),
(rmin, phimax),
(rmax, phimin),
(rmax, phimax),
]
def to_polar_plane(r, phi):
x = r * np.cos(phi)
y = r * np.sin(phi)
return x, y
conic_corner_coords = [to_polar_plane(*corner) for corner in corners]
phimin = phimin.d
phimax = phimax.d
if phimin <= np.pi <= phimax:
xxmin = -rmax
else:
xxmin = min(xx for xx, yy in conic_corner_coords)
if phimin <= 0 <= phimax:
xxmax = rmax
else:
xxmax = max(xx for xx, yy in conic_corner_coords)
if phimin <= 3 * np.pi / 2 <= phimax:
yymin = -rmax
else:
yymin = min(yy for xx, yy in conic_corner_coords)
if phimin <= np.pi / 2 <= phimax:
yymax = rmax
else:
yymax = max(yy for xx, yy in conic_corner_coords)
return xxmin, xxmax, yymin, yymax