from collections import defaultdict
import numpy as np
from matplotlib.colors import LogNorm, Normalize, SymLogNorm
from yt.funcs import is_sequence, mylog
from yt.units.unit_object import Unit # type: ignore
from yt.units.yt_array import YTArray
from yt.visualization.plot_container import (
BaseLinePlot,
PlotDictionary,
invalidate_plot,
)
class LineBuffer:
    r"""
    LineBuffer(ds, start_point, end_point, npoints, label = None)

    Fixed-resolution sampler of field values along a straight line.

    A LineBuffer remembers a dataset, the two end points of a line and the
    number of sample points. Indexing it with a field asks the dataset's
    coordinate handler to "pixelize" that field along the line; the result
    is cached so that the same field is only sampled once.

    Parameters
    ----------
    ds : :class:`yt.data_objects.static_output.Dataset`
        The dataset object holding the data that is sampled by the
        LineBuffer.
    start_point : n-element list, tuple, ndarray, or YTArray
        Coordinates of the first point of the line. Must contain n elements
        where n is the dimensionality of the dataset.
    end_point : n-element list, tuple, ndarray, or YTArray
        Coordinates of the last point of the line. Must contain n elements
        where n is the dimensionality of the dataset.
    npoints : int
        How many points to sample between start_point and end_point.
    label : str, optional
        Stored unchanged as the ``label`` attribute; LinePlot uses it as
        part of the legend entry for this line.
        Default: None

    Examples
    --------
    >>> lb = yt.LineBuffer(ds, (0.25, 0, 0), (0.25, 1, 0), 100)
    >>> lb["all", "u"].max()
    0.11562424257143075 dimensionless
    """

    def __init__(self, ds, start_point, end_point, npoints, label=None):
        # Cache of already-sampled fields, keyed by whatever the caller
        # indexed the buffer with.
        self.data = {}
        self.ds = ds
        self.npoints = npoints
        self.label = label
        # Both end points are validated (and padded to three components)
        # before being stored; the start point is validated first.
        self.start_point = _validate_point(start_point, ds, start=True)
        self.end_point = _validate_point(end_point, ds)

    def keys(self):
        """Return a view of the fields that have been sampled or set."""
        return self.data.keys()

    def __setitem__(self, item, val):
        """Store ``val`` in the cache under ``item``."""
        self.data[item] = val

    def __getitem__(self, item):
        """Return the values for ``item``, sampling them on first access."""
        if item not in self.data:
            mylog.info("Making a line buffer with %d points of %s", self.npoints, item)
            sampled = self.ds.coordinates.pixelize_line(
                item, self.start_point, self.end_point, self.npoints
            )
            # pixelize_line hands back a pair; the first element is kept on
            # the buffer as ``points`` and the second is cached as the
            # field values.
            positions, values = sampled
            self.points = positions
            self.data[item] = values
        return self.data[item]

    def __delitem__(self, item):
        """Drop ``item`` from the cache."""
        del self.data[item]
class LinePlotDictionary(PlotDictionary):
    """
    Plot dictionary in which fields sharing unit dimensions share one entry.

    Every key is first translated by :meth:`_sanitize_dimensions` into the
    first key that was seen with the same unit dimensions, and that
    canonical key is then handed to the parent ``PlotDictionary``.
    """

    def __init__(self, data_source):
        super().__init__(data_source)
        # Maps unit dimensions -> the first key looked up with them.
        self.known_dimensions = {}

    def _sanitize_dimensions(self, item):
        """
        Return the canonical key for the unit dimensions of ``item``.

        The first key seen for a given set of dimensions becomes the
        canonical one and is returned for every later key with the same
        dimensions. Note that this registers ``item`` as canonical when its
        dimensions have not been seen before, so even a membership test
        records the key.
        """
        source = self.data_source
        resolved = source._determine_fields(item)[0]
        units = source.ds.field_info[resolved].units
        dims = Unit(units, registry=source.ds.unit_registry).dimensions
        # setdefault stores ``item`` only if ``dims`` is new, and returns
        # whichever key is registered for ``dims`` afterwards.
        return self.known_dimensions.setdefault(dims, item)

    def __getitem__(self, item):
        return super().__getitem__(self._sanitize_dimensions(item))

    def __setitem__(self, item, value):
        super().__setitem__(self._sanitize_dimensions(item), value)

    def __contains__(self, item):
        return super().__contains__(self._sanitize_dimensions(item))
class LinePlot(BaseLinePlot):
    r"""
    A class for constructing line plots

    Parameters
    ----------
    ds : :class:`yt.data_objects.static_output.Dataset`
        This is the dataset object corresponding to the
        simulation output to be plotted.
    fields : string / tuple, or list of strings / tuples
        The name(s) of the field(s) to be plotted.
    start_point : n-element list, tuple, ndarray, or YTArray
        Contains the coordinates of the first point for constructing the line.
        Must contain n elements where n is the dimensionality of the dataset.
    end_point : n-element list, tuple, ndarray, or YTArray
        Contains the coordinates of the last point for constructing the line.
        Must contain n elements where n is the dimensionality of the dataset.
    npoints : int
        How many points to sample between start_point and end_point for
        constructing the line plot
    figure_size : int or two-element iterable of ints
        Size in inches of the image.
        Default: 5 (5x5)
    fontsize : int
        Font size for all text in the plot.
        Default: 14
    field_labels : dictionary
        Keys should be the field names. Values should be latex-formattable
        strings used in the LinePlot legend. The dictionary is copied; the
        caller's object is not modified.
        Default: None

    Example
    -------

    >>> import yt
    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
    >>> plot = yt.LinePlot(ds, "density", [0, 0, 0], [1, 1, 1], 512)
    >>> plot.annotate_legend("density")
    >>> plot.set_x_unit("cm")
    >>> plot.set_unit("density", "kg/cm**3")
    >>> plot.save()
    """

    _plot_dict_type = LinePlotDictionary
    _plot_type = "line_plot"

    _default_figure_size = (5.0, 5.0)
    _default_font_size = 14.0

    def __init__(
        self,
        ds,
        fields,
        start_point,
        end_point,
        npoints,
        figure_size=None,
        fontsize: float | None = None,
        field_labels=None,
    ):
        """
        Sets up figure and axes
        """
        line = LineBuffer(ds, start_point, end_point, npoints, label=None)
        self.lines = [line]
        self._initialize_instance(self, ds, fields, figure_size, fontsize, field_labels)
        self._setup_plots()

    @classmethod
    def _initialize_instance(
        cls, obj, ds, fields, figure_size, fontsize, field_labels=None
    ):
        """
        Initialize the plot state shared by ``__init__`` and ``from_lines``.

        ``obj`` is the (possibly bare) LinePlot instance to fill in. Every
        field without an entry in ``field_labels`` gets the second element
        of its field tuple as its label.
        """
        obj._x_unit = None
        obj._titles = {}
        data_source = ds.all_data()
        obj.fields = data_source._determine_fields(fields)
        obj.include_legend = defaultdict(bool)
        super(LinePlot, obj).__init__(
            data_source, figure_size=figure_size, fontsize=fontsize
        )
        # Work on a copy so that filling in the default labels below does
        # not write into the dictionary owned by the caller.
        obj.field_labels = {} if field_labels is None else dict(field_labels)
        for f in obj.fields:
            if f not in obj.field_labels:
                obj.field_labels[f] = f[1]

    def _get_axrect(self):
        """
        Return the axes placement as four fractions of the total figure.

        The returned tuple is (x offset, y offset, width, height), where the
        offsets are the fractions taken by the axis-label margins and the
        width/height are the fractions taken by ``figure_size``. All margins
        scale with the ratio of the current font size to the default one.
        Presumably consumed as a matplotlib axes rectangle -- confirm
        against the base class.
        """
        fontscale = self._font_properties._size / self.__class__._default_font_size
        top_buff_size = 0.35 * fontscale

        x_axis_size = 1.35 * fontscale
        y_axis_size = 0.7 * fontscale
        right_buff_size = 0.2 * fontscale

        if is_sequence(self.figure_size):
            figure_size = self.figure_size
        else:
            figure_size = (self.figure_size, self.figure_size)

        xbins = np.array([x_axis_size, figure_size[0], right_buff_size])
        ybins = np.array([y_axis_size, figure_size[1], top_buff_size])

        x_frac_widths = xbins / xbins.sum()
        y_frac_widths = ybins / ybins.sum()

        return (
            x_frac_widths[0],
            y_frac_widths[0],
            x_frac_widths[1],
            y_frac_widths[1],
        )

    @classmethod
    def from_lines(
        cls, ds, fields, lines, figure_size=None, font_size=None, field_labels=None
    ):
        """
        A class method for constructing a line plot from multiple sampling lines

        Parameters
        ----------

        ds : :class:`yt.data_objects.static_output.Dataset`
            This is the dataset object corresponding to the
            simulation output to be plotted.
        fields : field name or list of field names
            The name(s) of the field(s) to be plotted.
        lines : list of :class:`yt.visualization.line_plot.LineBuffer` instances
            The lines from which to sample data
        figure_size : int or two-element iterable of ints
            Size in inches of the image.
            Default: 5 (5x5)
        font_size : int
            Font size for all text in the plot.
            Default: 14
        field_labels : dictionary
            Keys should be the field names. Values should be latex-formattable
            strings used in the LinePlot legend. The dictionary is copied;
            the caller's object is not modified.
            Default: None

        Example
        --------
        >>> ds = yt.load(
        ...     "SecondOrderTris/RZ_p_no_parts_do_nothing_bcs_cone_out.e", step=-1
        ... )
        >>> fields = [field for field in ds.field_list if field[0] == "all"]
        >>> lines = [
        ...     yt.LineBuffer(ds, [0.25, 0, 0], [0.25, 1, 0], 100, label="x = 0.25"),
        ...     yt.LineBuffer(ds, [0.5, 0, 0], [0.5, 1, 0], 100, label="x = 0.5"),
        ... ]
        >>> plot = yt.LinePlot.from_lines(ds, fields, lines)
        >>> plot.save()

        """
        # Bypass __init__, which builds a single line from two end points.
        obj = cls.__new__(cls)
        obj.lines = lines
        cls._initialize_instance(obj, ds, fields, figure_size, font_size, field_labels)
        obj._setup_plots()
        return obj

    def _setup_plots(self):
        """
        Draw every field along every line onto the matplotlib axes.

        Does nothing if the plot is already valid. Otherwise all axes are
        cleared and redrawn, and the plot is marked valid at the end.
        """
        if self._plot_valid:
            return

        for plot in self.plots.values():
            plot.axes.cla()

        # How many plotted fields share each set of unit dimensions. This
        # depends only on the fields, so it is computed once rather than
        # once per line.
        field_dimensions = {}
        dimensions_counter = defaultdict(int)
        for field in self.fields:
            finfo = self.ds.field_info[field]
            dimensions = Unit(
                finfo.units, registry=self.ds.unit_registry
            ).dimensions
            field_dimensions[field] = dimensions
            dimensions_counter[dimensions] += 1

        for line in self.lines:
            for field in self.fields:
                # get plot instance
                plot = self._get_plot_instance(field)

                # calculate x and y
                x, y = self.ds.coordinates.pixelize_line(
                    field, line.start_point, line.end_point, line.npoints
                )

                # scale x and y to proper units
                if self._x_unit is None:
                    unit_x = x.units
                else:
                    unit_x = self._x_unit

                unit_y = plot.norm_handler.display_units

                x.convert_to_units(unit_x)
                y.convert_to_units(unit_y)

                # determine legend label; an empty/None line label is
                # dropped so that no dangling delimiter is produced
                str_seq = [line.label, self.field_labels[field]]
                legend_label = "; ".join(filter(None, str_seq))

                # apply plot to matplotlib axes
                plot.axes.plot(x, y, label=legend_label)

                # apply log transforms if requested; the exact type is
                # compared because the norm classes derive from Normalize
                norm = plot.norm_handler.get_norm(data=y)
                y_norm_type = type(norm)
                if y_norm_type is Normalize:
                    plot.axes.set_yscale("linear")
                elif y_norm_type is LogNorm:
                    plot.axes.set_yscale("log")
                elif y_norm_type is SymLogNorm:
                    plot.axes.set_yscale("symlog")
                else:
                    raise NotImplementedError(
                        f"LinePlot doesn't support y norm with type {type(norm)}"
                    )

                # set font properties
                plot._set_font_properties(self._font_properties, None)

                # set x and y axis labels
                axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)

                if self._xlabel is not None:
                    x_label = self._xlabel
                else:
                    x_label = r"$\rm{Path\ Length" + axes_unit_labels[0] + "}$"

                if self._ylabel is not None:
                    y_label = self._ylabel
                elif dimensions_counter[field_dimensions[field]] > 1:
                    # several fields are drawn on these axes, so no single
                    # field name can label the y axis
                    y_label = (
                        r"$\rm{Multiple\ Fields}$"
                        + r"$\rm{"
                        + axes_unit_labels[1]
                        + "}$"
                    )
                else:
                    finfo = self.ds.field_info[field]
                    y_label = (
                        finfo.get_latex_display_name()
                        + r"$\rm{"
                        + axes_unit_labels[1]
                        + "}$"
                    )

                plot.axes.set_xlabel(x_label)
                plot.axes.set_ylabel(y_label)

                # apply title
                if field in self._titles:
                    plot.axes.set_title(self._titles[field])

                # apply legend
                dim_field = self.plots._sanitize_dimensions(field)
                if self.include_legend[dim_field]:
                    plot.axes.legend()

        self._plot_valid = True

    @invalidate_plot
    def annotate_legend(self, field):
        """
        Adds a legend to the `LinePlot` instance. The `_sanitize_dimensions`
        call ensures that a legend label will be added for every field of
        a multi-field plot

        Parameters
        ----------
        field: str or field tuple
            A field drawn on the axes that should receive the legend
        """
        dim_field = self.plots._sanitize_dimensions(field)
        self.include_legend[dim_field] = True

    @invalidate_plot
    def set_x_unit(self, unit_name):
        """Set the unit to use along the x-axis

        Parameters
        ----------
        unit_name: str
            The name of the unit to use for the x-axis unit
        """
        self._x_unit = unit_name

    @invalidate_plot
    def set_unit(self, field, new_unit):
        """Set the unit used to plot the field

        Parameters
        ----------
        field: str or field tuple
            The name of the field to set the units for
        new_unit: string or Unit object
            The unit to display the field in
        """
        field = self.data_source._determine_fields(field)[0]
        pnh = self.plots[field].norm_handler
        pnh.display_units = new_unit

    @invalidate_plot
    def annotate_title(self, field, title):
        """Set the title of the plot that shows the field

        Parameters
        ----------
        field: str or field tuple
            The name of the field whose plot gets the title
        title: str
            The title to use for the plot
        """
        self._titles[self.data_source._determine_fields(field)[0]] = title
def _validate_point(point, ds, start=False):
    """
    Validate a line end point and pad it to three components.

    Parameters
    ----------
    point : list, tuple, ndarray, or YTArray
        The coordinates to validate. Anything that is not already a YTArray
        is converted with ``ds.arr`` in units of ``code_length``.
    ds : dataset
        Supplies ``arr`` for the conversion and ``dimensionality`` for the
        minimum number of components.
    start : bool
        Selects the padding value for missing components: 0 when True,
        1 when False (expressed in the units of ``point``).

    Returns
    -------
    The point as a 1D unit-carrying array. Points with fewer than three
    components are padded up to exactly three.

    Raises
    ------
    RuntimeError
        If ``point`` is not a sequence, is not one-dimensional, or has
        fewer components than the dataset has dimensions.
    """
    if not is_sequence(point):
        raise RuntimeError("Input point must be array-like")
    if not isinstance(point, YTArray):
        point = ds.arr(point, "code_length", dtype=np.float64)
    if len(point.shape) != 1:
        raise RuntimeError("Input point must be a 1D array")
    if point.shape[0] < ds.dimensionality:
        raise RuntimeError("Input point must have an element for each dimension")
    # need to pad to 3D elements to avoid issues later. The number of
    # missing components is derived from the point itself, not from the
    # dataset dimensionality: a point may carry more components than the
    # dataset has dimensions, and must still end up with exactly three.
    n_missing = 3 - point.shape[0]
    if n_missing > 0:
        pad_value = 0 if start else 1
        # strip the units, pad the bare values, then re-attach the units
        point = np.append(point.d, [pad_value] * n_missing) * point.uq
    return point