Source code for yt.visualization.line_plot

from collections import defaultdict

import numpy as np
from matplotlib.colors import LogNorm, Normalize, SymLogNorm

from yt.funcs import is_sequence, mylog
from yt.units.unit_object import Unit  # type: ignore
from yt.units.yt_array import YTArray
from yt.visualization.plot_container import (
    BaseLinePlot,
    PlotDictionary,
    invalidate_plot,
)


class LineBuffer:
    r"""
    LineBuffer(ds, start_point, end_point, npoints, label = None)

    This takes a data source and implements a protocol for generating a
    'pixelized', fixed-resolution line buffer. In other words, LineBuffer
    takes a starting point, ending point, and number of sampling points and
    can subsequently generate YTArrays of field values along the sample
    points.

    Parameters
    ----------
    ds : :class:`yt.data_objects.static_output.Dataset`
        This is the dataset object holding the data that can be sampled by
        the LineBuffer
    start_point : n-element list, tuple, ndarray, or YTArray
        Contains the coordinates of the first point for constructing the
        LineBuffer. Must contain n elements where n is the dimensionality of
        the dataset.
    end_point : n-element list, tuple, ndarray, or YTArray
        Contains the coordinates of the last point for constructing the
        LineBuffer. Must contain n elements where n is the dimensionality of
        the dataset.
    npoints : int
        How many points to sample between start_point and end_point
    label : str, optional
        A label for this line. LinePlot includes it in the legend entry of
        every field sampled along this line.
        Default: None

    Examples
    --------
    >>> lb = yt.LineBuffer(ds, (0.25, 0, 0), (0.25, 1, 0), 100)
    >>> lb["all", "u"].max()
    0.11562424257143075 dimensionless

    """

    def __init__(self, ds, start_point, end_point, npoints, label=None):
        self.ds = ds
        # Both endpoints are validated and padded to 3 components; the start
        # point is padded with 0 and the end point with 1.
        self.start_point = _validate_point(start_point, ds, start=True)
        self.end_point = _validate_point(end_point, ds)
        self.npoints = npoints
        self.label = label
        # Cache of already-sampled fields, keyed by the requested item.
        self.data = {}

    def keys(self):
        """Return the keys of the fields sampled (or set) so far."""
        return self.data.keys()

    def __setitem__(self, item, val):
        self.data[item] = val

    def __getitem__(self, item):
        """Return the field values along the line, sampling them on first use.

        Results are cached in ``self.data``; a cache miss samples the field
        via ``ds.coordinates.pixelize_line`` and also stores the first value
        that call returns in ``self.points`` (presumably the sample
        positions along the line).
        """
        if item in self.data:
            return self.data[item]
        mylog.info("Making a line buffer with %d points of %s", self.npoints, item)
        self.points, self.data[item] = self.ds.coordinates.pixelize_line(
            item, self.start_point, self.end_point, self.npoints
        )
        return self.data[item]

    def __delitem__(self, item):
        del self.data[item]
class LinePlotDictionary(PlotDictionary):
    """Plot dictionary that groups fields by their physical dimensions.

    Every lookup, assignment, or membership test first maps the requested
    field to the first field seen with the same unit dimensions. Fields that
    share dimensions therefore share a single entry (and plot). Note that
    each access registers a not-yet-seen dimension as a side effect.
    """

    def __init__(self, data_source):
        super().__init__(data_source)
        # unit dimensions -> the first item requested with those dimensions
        self.known_dimensions = {}

    def _sanitize_dimensions(self, item):
        """Return the representative item for ``item``'s unit dimensions."""
        source = self.data_source
        resolved = source._determine_fields(item)[0]
        units = source.ds.field_info[resolved].units
        dims = Unit(units, registry=source.ds.unit_registry).dimensions
        # The first item registered for a given dimension wins.
        return self.known_dimensions.setdefault(dims, item)

    def __getitem__(self, item):
        return super().__getitem__(self._sanitize_dimensions(item))

    def __setitem__(self, item, value):
        super().__setitem__(self._sanitize_dimensions(item), value)

    def __contains__(self, item):
        return super().__contains__(self._sanitize_dimensions(item))
class LinePlot(BaseLinePlot):
    r"""
    A class for constructing line plots

    Parameters
    ----------

    ds : :class:`yt.data_objects.static_output.Dataset`
        This is the dataset object corresponding to the
        simulation output to be plotted.
    fields : string / tuple, or list of strings / tuples
        The name(s) of the field(s) to be plotted.
    start_point : n-element list, tuple, ndarray, or YTArray
        Contains the coordinates of the first point for constructing the line.
        Must contain n elements where n is the dimensionality of the dataset.
    end_point : n-element list, tuple, ndarray, or YTArray
        Contains the coordinates of the last point for constructing the line.
        Must contain n elements where n is the dimensionality of the dataset.
    npoints : int
        How many points to sample between start_point and end_point for
        constructing the line plot
    figure_size : int or two-element iterable of ints
        Size in inches of the image.
        Default: 5 (5x5)
    fontsize : int
        Font size for all text in the plot.
        Default: 14
    field_labels : dictionary
        Keys should be the field names. Values should be latex-formattable
        strings used in the LinePlot legend. The dictionary is copied, not
        modified.
        Default: None


    Example
    -------

    >>> import yt

    >>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")

    >>> plot = yt.LinePlot(ds, "density", [0, 0, 0], [1, 1, 1], 512)
    >>> plot.annotate_legend("density")
    >>> plot.set_x_unit("cm")
    >>> plot.set_unit("density", "kg/cm**3")
    >>> plot.save()

    """

    _plot_dict_type = LinePlotDictionary
    _plot_type = "line_plot"

    _default_figure_size = (5.0, 5.0)
    _default_font_size = 14.0

    def __init__(
        self,
        ds,
        fields,
        start_point,
        end_point,
        npoints,
        figure_size=None,
        fontsize: float | None = None,
        field_labels=None,
    ):
        """
        Sets up figure and axes
        """
        line = LineBuffer(ds, start_point, end_point, npoints, label=None)
        self.lines = [line]
        self._initialize_instance(self, ds, fields, figure_size, fontsize, field_labels)
        self._setup_plots()

    @classmethod
    def _initialize_instance(
        cls, obj, ds, fields, figure_size, fontsize, field_labels=None
    ):
        """Shared initialization for ``__init__`` and ``from_lines``.

        Sets the per-instance state on ``obj`` (x unit, titles, fields,
        legend flags, field labels) and runs the BaseLinePlot initializer on
        ``ds.all_data()``.
        """
        obj._x_unit = None
        obj._titles = {}

        data_source = ds.all_data()

        obj.fields = data_source._determine_fields(fields)
        obj.include_legend = defaultdict(bool)
        super(LinePlot, obj).__init__(
            data_source, figure_size=figure_size, fontsize=fontsize
        )

        # Copy the caller's labels so that filling in defaults below does not
        # mutate the dictionary that was passed in.
        obj.field_labels = {} if field_labels is None else dict(field_labels)
        for f in obj.fields:
            # Fields without an explicit label use the field name (second
            # element of the field tuple).
            obj.field_labels.setdefault(f, f[1])

    def _get_axrect(self):
        """Return the axes rectangle as figure fractions.

        The tuple is (left margin, bottom margin, axes width, axes height).
        The margins reserved for axis labels and buffers scale with the font
        size relative to the class default.
        """
        fontscale = self._font_properties._size / self.__class__._default_font_size
        top_buff_size = 0.35 * fontscale

        x_axis_size = 1.35 * fontscale
        y_axis_size = 0.7 * fontscale
        right_buff_size = 0.2 * fontscale

        if is_sequence(self.figure_size):
            figure_size = self.figure_size
        else:
            figure_size = (self.figure_size, self.figure_size)

        xbins = np.array([x_axis_size, figure_size[0], right_buff_size])
        ybins = np.array([y_axis_size, figure_size[1], top_buff_size])

        x_frac_widths = xbins / xbins.sum()
        y_frac_widths = ybins / ybins.sum()

        return (
            x_frac_widths[0],
            y_frac_widths[0],
            x_frac_widths[1],
            y_frac_widths[1],
        )

    @classmethod
    def from_lines(
        cls, ds, fields, lines, figure_size=None, font_size=None, field_labels=None
    ):
        """
        A class method for constructing a line plot from multiple sampling lines

        Parameters
        ----------

        ds : :class:`yt.data_objects.static_output.Dataset`
            This is the dataset object corresponding to the
            simulation output to be plotted.
        fields : field name or list of field names
            The name(s) of the field(s) to be plotted.
        lines : list of :class:`yt.visualization.line_plot.LineBuffer` instances
            The lines from which to sample data
        figure_size : int or two-element iterable of ints
            Size in inches of the image.
            Default: 5 (5x5)
        font_size : int
            Font size for all text in the plot.
            Default: 14
        field_labels : dictionary
            Keys should be the field names. Values should be latex-formattable
            strings used in the LinePlot legend. The dictionary is copied,
            not modified.
            Default: None

        Example
        -------
        >>> ds = yt.load(
        ...     "SecondOrderTris/RZ_p_no_parts_do_nothing_bcs_cone_out.e", step=-1
        ... )
        >>> fields = [field for field in ds.field_list if field[0] == "all"]
        >>> lines = [
        ...     yt.LineBuffer(ds, [0.25, 0, 0], [0.25, 1, 0], 100, label="x = 0.25"),
        ...     yt.LineBuffer(ds, [0.5, 0, 0], [0.5, 1, 0], 100, label="x = 0.5"),
        ... ]
        >>> plot = yt.LinePlot.from_lines(ds, fields, lines)
        >>> plot.save()

        """
        # Bypass __init__, which would build a single LineBuffer itself.
        obj = cls.__new__(cls)
        obj.lines = lines
        cls._initialize_instance(obj, ds, fields, figure_size, font_size, field_labels)
        obj._setup_plots()
        return obj

    def _setup_plots(self):
        """Redraw every plot from the current lines, fields and settings.

        Does nothing while the plots are still valid.
        """
        if self._plot_valid:
            return
        for plot in self.plots.values():
            plot.axes.cla()

        # Unit dimensions of every field, and how many fields share each set of
        # dimensions. LinePlotDictionary maps fields with equal dimensions to
        # one plot, so a shared plot gets a generic y label. This depends only
        # on self.fields, so compute it once rather than once per line.
        field_dimensions = {}
        dimensions_counter = defaultdict(int)
        for field in self.fields:
            finfo = self.ds.field_info[field]
            dimensions = Unit(finfo.units, registry=self.ds.unit_registry).dimensions
            field_dimensions[field] = dimensions
            dimensions_counter[dimensions] += 1

        for line in self.lines:
            for field in self.fields:
                # get plot instance
                plot = self._get_plot_instance(field)

                # calculate x and y
                x, y = self.ds.coordinates.pixelize_line(
                    field, line.start_point, line.end_point, line.npoints
                )

                # scale x and y to proper units
                if self._x_unit is None:
                    unit_x = x.units
                else:
                    unit_x = self._x_unit

                unit_y = plot.norm_handler.display_units

                x.convert_to_units(unit_x)
                y.convert_to_units(unit_y)

                # determine legend label; missing (None) parts are dropped
                str_seq = [line.label, self.field_labels[field]]
                legend_label = "; ".join(filter(None, str_seq))

                # apply plot to matplotlib axes
                plot.axes.plot(x, y, label=legend_label)

                # apply log transforms if requested
                norm = plot.norm_handler.get_norm(data=y)
                y_norm_type = type(norm)
                if y_norm_type is Normalize:
                    plot.axes.set_yscale("linear")
                elif y_norm_type is LogNorm:
                    plot.axes.set_yscale("log")
                elif y_norm_type is SymLogNorm:
                    plot.axes.set_yscale("symlog")
                else:
                    raise NotImplementedError(
                        f"LinePlot doesn't support y norm with type {type(norm)}"
                    )

                # set font properties
                plot._set_font_properties(self._font_properties, None)

                # set x and y axis labels
                axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y)

                if self._xlabel is not None:
                    x_label = self._xlabel
                else:
                    x_label = r"$\rm{Path\ Length" + axes_unit_labels[0] + "}$"

                if self._ylabel is not None:
                    y_label = self._ylabel
                elif dimensions_counter[field_dimensions[field]] > 1:
                    y_label = (
                        r"$\rm{Multiple\ Fields}$"
                        + r"$\rm{"
                        + axes_unit_labels[1]
                        + "}$"
                    )
                else:
                    finfo = self.ds.field_info[field]
                    y_label = (
                        finfo.get_latex_display_name()
                        + r"$\rm{"
                        + axes_unit_labels[1]
                        + "}$"
                    )

                plot.axes.set_xlabel(x_label)
                plot.axes.set_ylabel(y_label)

                # apply title
                if field in self._titles:
                    plot.axes.set_title(self._titles[field])

                # apply legend
                dim_field = self.plots._sanitize_dimensions(field)
                if self.include_legend[dim_field]:
                    plot.axes.legend()

        self._plot_valid = True

    @invalidate_plot
    def annotate_legend(self, field):
        """
        Adds a legend to the `LinePlot` instance. The `_sanitize_dimensions`
        call ensures that a legend label will be added for every field of
        a multi-field plot
        """
        dim_field = self.plots._sanitize_dimensions(field)
        self.include_legend[dim_field] = True

    @invalidate_plot
    def set_x_unit(self, unit_name):
        """Set the unit to use along the x-axis

        Parameters
        ----------
        unit_name: str
          The name of the unit to use for the x-axis unit
        """
        self._x_unit = unit_name

    @invalidate_plot
    def set_unit(self, field, new_unit):
        """Set the unit used to plot the field

        Parameters
        ----------
        field: str or field tuple
           The name of the field to set the units for
        new_unit: string or Unit object
        """
        field = self.data_source._determine_fields(field)[0]
        pnh = self.plots[field].norm_handler
        pnh.display_units = new_unit

    @invalidate_plot
    def annotate_title(self, field, title):
        """Set the title of the plot that displays the field

        Parameters
        ----------
        field: str or field tuple
           The name of the field whose plot receives the title
        title: str
           The title to use for the plot
        """
        self._titles[self.data_source._determine_fields(field)[0]] = title
def _validate_point(point, ds, start=False):
    """Validate a line endpoint and pad it to at least three components.

    Parameters
    ----------
    point : array-like or YTArray
        Coordinates of the point. Inputs that are not already a YTArray are
        converted with ``ds.arr(point, "code_length", dtype=np.float64)``.
    ds : Dataset
        The dataset the point belongs to; its ``dimensionality`` gives the
        minimum number of components required.
    start : bool
        Whether this is the start point of a line. Missing components are
        padded with 0 for a start point and with 1 otherwise.

    Returns
    -------
    YTArray
        A 1D array with at least 3 components. Points that already have 3 or
        more components are returned unpadded.

    Raises
    ------
    RuntimeError
        If the point is not array-like, is not 1D, or has fewer components
        than the dataset dimensionality.
    """
    if not is_sequence(point):
        raise RuntimeError("Input point must be array-like")
    if not isinstance(point, YTArray):
        point = ds.arr(point, "code_length", dtype=np.float64)
    if len(point.shape) != 1:
        raise RuntimeError("Input point must be a 1D array")
    if point.shape[0] < ds.dimensionality:
        raise RuntimeError("Input point must have an element for each dimension")
    # need to pad to 3D elements to avoid issues later. Pad by the number of
    # components actually missing from the point, not from the dataset
    # dimensionality: a point may carry more components than the dataset has.
    n_missing = 3 - point.shape[0]
    if n_missing > 0:
        val = 0 if start else 1
        point = np.append(point.d, [val] * n_missing) * point.uq
    return point