import base64
import os
from functools import wraps
from typing import TYPE_CHECKING, Any
import matplotlib
import numpy as np
from more_itertools.more import always_iterable, unzip
from yt._maintenance.ipython_compat import IS_IPYTHON
from yt._typing import FieldKey
from yt.data_objects.profiles import create_profile, sanitize_field_tuple_keys
from yt.data_objects.static_output import Dataset
from yt.frontends.ytdata.data_structures import YTProfileDataset
from yt.funcs import iter_fields, matplotlib_style_context
from yt.utilities.exceptions import YTNotInsideNotebook
from yt.visualization._commons import _get_units_label
from yt.visualization._handlers import ColorbarHandler, NormHandler
from yt.visualization.base_plot_types import ImagePlotMPL, PlotMPL
from ..data_objects.selection_objects.data_selection_objects import YTSelectionContainer
from ._commons import validate_image_name
from .plot_container import (
BaseLinePlot,
ImagePlotContainer,
invalidate_plot,
validate_plot,
)
if TYPE_CHECKING:
from collections.abc import Iterable
from yt._typing import FieldKey
[docs]
def invalidate_profile(f):
@wraps(f)
def newfunc(*args, **kwargs):
rv = f(*args, **kwargs)
args[0]._profile_valid = False
return rv
return newfunc
[docs]
def sanitize_label(labels, nprofiles):
labels = list(always_iterable(labels)) or [None]
if len(labels) == 1:
labels = labels * nprofiles
if len(labels) != nprofiles:
raise ValueError(
f"Number of labels {len(labels)} must match number of profiles {nprofiles}"
)
invalid_data = [
(label, type(label))
for label in labels
if label is not None and not isinstance(label, str)
]
if invalid_data:
invalid_labels, types = unzip(invalid_data)
raise TypeError(
"All labels must be None or a string, "
f"received {invalid_labels} with type {types}"
)
return labels
[docs]
def data_object_or_all_data(data_source):
if isinstance(data_source, Dataset):
data_source = data_source.all_data()
if not isinstance(data_source, YTSelectionContainer):
raise RuntimeError("data_source must be a yt selection data object")
return data_source
[docs]
class ProfilePlot(BaseLinePlot):
r"""
Create a 1d profile plot from a data source or from a list
of profile objects.
Given a data object (all_data, region, sphere, etc.), an x field,
and a y field (or fields), this will create a one-dimensional profile
of the average (or total) value of the y field in bins of the x field.
This can be used to create profiles from given fields or to plot
multiple profiles created from
`yt.data_objects.profiles.create_profile`.
Parameters
----------
data_source : YTSelectionContainer Object
The data object to be profiled, such as all_data, region, or
sphere. If a dataset is passed in instead, an all_data data object
is generated internally from the dataset.
x_field : str
The binning field for the profile.
y_fields : str or list
The field or fields to be profiled.
weight_field : str
The weight field for calculating weighted averages. If None,
the profile values are the sum of the field values within the bin.
Otherwise, the values are a weighted average.
Default : ("gas", "mass")
n_bins : int
The number of bins in the profile.
Default: 64.
accumulation : bool
If True, the profile values for a bin N are the cumulative sum of
all the values from bin 0 to N.
Default: False.
fractional : If True the profile values are divided by the sum of all
the profile data such that the profile represents a probability
distribution function.
label : str or list of strings
If a string, the label to be put on the line plotted. If a list,
this should be a list of labels for each profile to be overplotted.
Default: None.
plot_spec : dict or list of dicts
A dictionary or list of dictionaries containing plot keyword
arguments. For example, dict(color="red", linestyle=":").
Default: None.
x_log : bool
Whether the x_axis should be plotted with a logarithmic
scaling (True), or linear scaling (False).
Default: True.
y_log : dict or bool
A dictionary containing field:boolean pairs, setting the logarithmic
property for that field. May be overridden after instantiation using
set_log
A single boolean can be passed to signify all fields should use
logarithmic (True) or linear scaling (False).
Default: True.
Examples
--------
This creates profiles of a single dataset.
>>> import yt
>>> ds = yt.load("enzo_tiny_cosmology/DD0046/DD0046")
>>> ad = ds.all_data()
>>> plot = yt.ProfilePlot(
... ad,
... ("gas", "density"),
... [("gas", "temperature"), ("gas", "velocity_x")],
... weight_field=("gas", "mass"),
... plot_spec=dict(color="red", linestyle="--"),
... )
>>> plot.save()
This creates profiles from a time series object.
>>> es = yt.load_simulation("AMRCosmology.enzo", "Enzo")
>>> es.get_time_series()
>>> profiles = []
>>> labels = []
>>> plot_specs = []
>>> for ds in es[-4:]:
... ad = ds.all_data()
... profiles.append(
... create_profile(
... ad,
... [("gas", "density")],
... fields=[("gas", "temperature"), ("gas", "velocity_x")],
... )
... )
... labels.append(ds.current_redshift)
... plot_specs.append(dict(linestyle="--", alpha=0.7))
>>> plot = yt.ProfilePlot.from_profiles(
... profiles, labels=labels, plot_specs=plot_specs
... )
>>> plot.save()
Use set_line_property to change line properties of one or all profiles.
"""
_default_figure_size = (10.0, 8.0)
_default_font_size = 18.0
x_log = None
y_log = None
x_title = None
y_title = None
_plot_valid = False
def __init__(
self,
data_source,
x_field,
y_fields,
weight_field=("gas", "mass"),
n_bins=64,
accumulation=False,
fractional=False,
label=None,
plot_spec=None,
x_log=True,
y_log=True,
):
data_source = data_object_or_all_data(data_source)
y_fields = list(iter_fields(y_fields))
logs = {x_field: bool(x_log)}
if isinstance(y_log, bool):
y_log = {y_field: y_log for y_field in y_fields}
if isinstance(data_source.ds, YTProfileDataset):
profiles = [data_source.ds.profile]
else:
profiles = [
create_profile(
data_source,
[x_field],
n_bins=[n_bins],
fields=y_fields,
weight_field=weight_field,
accumulation=accumulation,
fractional=fractional,
logs=logs,
)
]
if plot_spec is None:
plot_spec = [{} for p in profiles]
if not isinstance(plot_spec, list):
plot_spec = [plot_spec.copy() for p in profiles]
ProfilePlot._initialize_instance(
self, data_source, profiles, label, plot_spec, y_log
)
@classmethod
def _initialize_instance(
cls,
obj,
data_source,
profiles,
labels,
plot_specs,
y_log,
):
obj._plot_title = {}
obj._plot_text = {}
obj._text_xpos = {}
obj._text_ypos = {}
obj._text_kwargs = {}
super(ProfilePlot, obj).__init__(data_source)
obj.profiles = list(always_iterable(profiles))
obj.x_log = None
obj.y_log = sanitize_field_tuple_keys(y_log, data_source) or {}
obj.y_title = {}
obj.x_title = None
obj.label = sanitize_label(labels, len(obj.profiles))
if plot_specs is None:
plot_specs = [{} for p in obj.profiles]
obj.plot_spec = plot_specs
obj._xlim = (None, None)
obj._setup_plots()
obj._plot_valid = False # see https://github.com/yt-project/yt/issues/4489
return obj
def _get_axrect(self):
return (0.1, 0.1, 0.8, 0.8)
[docs]
@validate_plot
def save(
self,
name: str | None = None,
suffix: str | None = None,
mpl_kwargs: dict[str, Any] | None = None,
):
r"""
Saves a 1d profile plot.
Parameters
----------
name : str, optional
The output file keyword.
suffix : string, optional
Specify the image type by its suffix. If not specified, the output
type will be inferred from the filename. Defaults to '.png'.
mpl_kwargs : dict, optional
A dict of keyword arguments to be passed to matplotlib.
"""
if not self._plot_valid:
self._setup_plots()
# Mypy is hardly convinced that we have a `profiles` attribute
# at this stage, so we're lasily going to deactivate it locally
unique = set(self.plots.values())
iters: Iterable[tuple[int | FieldKey, PlotMPL]]
if len(unique) < len(self.plots):
iters = enumerate(sorted(unique))
else:
iters = self.plots.items()
if name is None:
if len(self.profiles) == 1: # type: ignore
name = str(self.profiles[0].ds) # type: ignore
else:
name = "Multi-data"
name = validate_image_name(name, suffix)
prefix, suffix = os.path.splitext(name)
xfn = self.profiles[0].x_field # type: ignore
if isinstance(xfn, tuple):
xfn = xfn[1]
names = []
for uid, plot in iters:
if isinstance(uid, tuple):
uid = uid[1] # type: ignore
uid_name = f"{prefix}_1d-Profile_{xfn}_{uid}{suffix}"
names.append(uid_name)
with matplotlib_style_context():
plot.save(uid_name, mpl_kwargs=mpl_kwargs)
return names
[docs]
@validate_plot
def show(self):
r"""This will send any existing plots to the IPython notebook.
If yt is being run from within an IPython session, and it is able to
determine this, this function will send any existing plots to the
notebook for display.
If yt can't determine if it's inside an IPython session, it will raise
YTNotInsideNotebook.
Examples
--------
>>> import yt
>>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
>>> pp = ProfilePlot(ds.all_data(), ("gas", "density"), ("gas", "temperature"))
>>> pp.show()
"""
if IS_IPYTHON:
from IPython.display import display
display(self)
else:
raise YTNotInsideNotebook
@validate_plot
def _repr_html_(self):
"""Return an html representation of the plot object. Will display as a
png for each WindowPlotMPL instance in self.plots"""
ret = ""
unique = set(self.plots.values())
if len(unique) < len(self.plots):
iters = sorted(unique)
else:
iters = self.plots.values()
for plot in iters:
with matplotlib_style_context():
img = plot._repr_png_()
img = base64.b64encode(img).decode()
ret += (
r'<img style="max-width:100%;max-height:100%;" '
rf'src="data:image/png;base64,{img}"><br>'
)
return ret
def _setup_plots(self):
if self._plot_valid:
return
for f, p in self.plots.items():
p.axes.cla()
if f in self._plot_text:
p.axes.text(
self._text_xpos[f],
self._text_ypos[f],
self._plot_text[f],
fontproperties=self._font_properties,
**self._text_kwargs[f],
)
self._set_font_properties()
for i, profile in enumerate(self.profiles):
for field, field_data in profile.items():
plot = self._get_plot_instance(field)
plot.axes.plot(
np.array(profile.x),
np.array(field_data),
label=self.label[i],
**self.plot_spec[i],
)
for profile in self.profiles:
for fname in profile.keys():
axes = self.plots[fname].axes
xscale, yscale = self._get_field_log(fname, profile)
xtitle, ytitle = self._get_field_title(fname, profile)
axes.set_xscale(xscale)
axes.set_yscale(yscale)
axes.set_ylabel(ytitle)
axes.set_xlabel(xtitle)
pnh = self.plots[fname].norm_handler
axes.set_ylim(pnh.vmin, pnh.vmax)
axes.set_xlim(*self._xlim)
if fname in self._plot_title:
axes.set_title(self._plot_title[fname])
if any(self.label):
axes.legend(loc="best")
self._set_font_properties()
self._plot_valid = True
[docs]
@classmethod
def from_profiles(cls, profiles, labels=None, plot_specs=None, y_log=None):
r"""
Instantiate a ProfilePlot object from a list of profiles
created with :func:`~yt.data_objects.profiles.create_profile`.
Parameters
----------
profiles : a profile or list of profiles
A single profile or list of profile objects created with
:func:`~yt.data_objects.profiles.create_profile`.
labels : list of strings
A list of labels for each profile to be overplotted.
Default: None.
plot_specs : list of dicts
A list of dictionaries containing plot keyword
arguments. For example, [dict(color="red", linestyle=":")].
Default: None.
Examples
--------
>>> from yt import load_simulation
>>> es = load_simulation("AMRCosmology.enzo", "Enzo")
>>> es.get_time_series()
>>> profiles = []
>>> labels = []
>>> plot_specs = []
>>> for ds in es[-4:]:
... ad = ds.all_data()
... profiles.append(
... create_profile(
... ad,
... [("gas", "density")],
... fields=[("gas", "temperature"), ("gas", "velocity_x")],
... )
... )
... labels.append(ds.current_redshift)
... plot_specs.append(dict(linestyle="--", alpha=0.7))
>>> plot = ProfilePlot.from_profiles(
... profiles, labels=labels, plot_specs=plot_specs
... )
>>> plot.save()
"""
if labels is not None and len(profiles) != len(labels):
raise RuntimeError("Profiles list and labels list must be the same size.")
if plot_specs is not None and len(plot_specs) != len(profiles):
raise RuntimeError(
"Profiles list and plot_specs list must be the same size."
)
obj = cls.__new__(cls)
profiles = list(always_iterable(profiles))
return cls._initialize_instance(
obj, profiles[0].data_source, profiles, labels, plot_specs, y_log
)
[docs]
@invalidate_plot
def set_line_property(self, property, value, index=None):
r"""
Set properties for one or all lines to be plotted.
Parameters
----------
property : str
The line property to be set.
value : str, int, float
The value to set for the line property.
index : int
The index of the profile in the list of profiles to be
changed. If None, change all plotted lines.
Default : None.
Examples
--------
Change all the lines in a plot
plot.set_line_property("linestyle", "-")
Change a single line.
plot.set_line_property("linewidth", 4, index=0)
"""
if index is None:
specs = self.plot_spec
else:
specs = [self.plot_spec[index]]
for spec in specs:
spec[property] = value
return self
[docs]
@invalidate_plot
def set_log(self, field, log):
"""set a field to log or linear.
Parameters
----------
field : string
the field to set a transform
log : boolean
Log on/off.
"""
if field == "all":
self.x_log = log
for field in list(self.profiles[0].field_data.keys()):
self.y_log[field] = log
else:
(field,) = self.profiles[0].data_source._determine_fields([field])
if field == self.profiles[0].x_field:
self.x_log = log
elif field in self.profiles[0].field_data:
self.y_log[field] = log
else:
raise KeyError(f"Field {field} not in profile plot!")
return self
[docs]
@invalidate_plot
def set_ylabel(self, field, label):
"""Sets a new ylabel for the specified fields
Parameters
----------
field : string
The name of the field that is to be changed.
label : string
The label to be placed on the y-axis
"""
if field == "all":
for field in self.profiles[0].field_data:
self.y_title[field] = label
else:
(field,) = self.profiles[0].data_source._determine_fields([field])
if field in self.profiles[0].field_data:
self.y_title[field] = label
else:
raise KeyError(f"Field {field} not in profile plot!")
return self
[docs]
@invalidate_plot
def set_xlabel(self, label):
"""Sets a new xlabel for all profiles
Parameters
----------
label : string
The label to be placed on the x-axis
"""
self.x_title = label
return self
[docs]
@invalidate_plot
def set_unit(self, field, unit):
"""Sets a new unit for the requested field
Parameters
----------
field : string
The name of the field that is to be changed.
unit : string or Unit object
The name of the new unit.
"""
fd = self.profiles[0].data_source._determine_fields(field)[0]
for profile in self.profiles:
if fd == profile.x_field:
profile.set_x_unit(unit)
elif fd[1] in self.profiles[0].field_map:
profile.set_field_unit(field, unit)
else:
raise KeyError(f"Field {field} not in profile plot!")
return self
[docs]
@invalidate_plot
def set_xlim(self, xmin=None, xmax=None):
"""Sets the limits of the bin field
Parameters
----------
xmin : float or None
The new x minimum. Defaults to None, which leaves the xmin
unchanged.
xmax : float or None
The new x maximum. Defaults to None, which leaves the xmax
unchanged.
Examples
--------
>>> import yt
>>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
>>> pp = yt.ProfilePlot(
... ds.all_data(), ("gas", "density"), ("gas", "temperature")
... )
>>> pp.set_xlim(1e-29, 1e-24)
>>> pp.save()
"""
self._xlim = (xmin, xmax)
for i, p in enumerate(self.profiles):
if xmin is None:
xmi = p.x_bins.min()
else:
xmi = xmin
if xmax is None:
xma = p.x_bins.max()
else:
xma = xmax
extrema = {p.x_field: ((xmi, str(p.x.units)), (xma, str(p.x.units)))}
units = {p.x_field: str(p.x.units)}
if self.x_log is None:
logs = None
else:
logs = {p.x_field: self.x_log}
for field in p.field_map.values():
units[field] = str(p.field_data[field].units)
self.profiles[i] = create_profile(
p.data_source,
p.x_field,
n_bins=len(p.x_bins) - 1,
fields=list(p.field_map.values()),
weight_field=p.weight_field,
accumulation=p.accumulation,
fractional=p.fractional,
logs=logs,
extrema=extrema,
units=units,
)
return self
[docs]
@invalidate_plot
def set_ylim(self, field, ymin=None, ymax=None):
"""Sets the plot limits for the specified field we are binning.
Parameters
----------
field : string or field tuple
The field that we want to adjust the plot limits for.
ymin : float or None
The new y minimum. Defaults to None, which leaves the ymin
unchanged.
ymax : float or None
The new y maximum. Defaults to None, which leaves the ymax
unchanged.
Examples
--------
>>> import yt
>>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
>>> pp = yt.ProfilePlot(
... ds.all_data(),
... ("gas", "density"),
... [("gas", "temperature"), ("gas", "velocity_x")],
... )
>>> pp.set_ylim(("gas", "temperature"), 1e4, 1e6)
>>> pp.save()
"""
fields = list(self.plots.keys()) if field == "all" else field
for profile in self.profiles:
for field in profile.data_source._determine_fields(fields):
if field in profile.field_map:
field = profile.field_map[field]
pnh = self.plots[field].norm_handler
pnh.vmin = ymin
pnh.vmax = ymax
# Continue on to the next profile.
break
return self
def _set_font_properties(self):
for f in self.plots:
self.plots[f]._set_font_properties(self._font_properties, self._font_color)
def _get_field_log(self, field_y, profile):
yfi = profile.field_info[field_y]
if self.x_log is None:
x_log = profile.x_log
else:
x_log = self.x_log
y_log = self.y_log.get(field_y, yfi.take_log)
scales = {True: "log", False: "linear"}
return scales[x_log], scales[y_log]
def _get_field_label(self, field, field_info, field_unit, fractional=False):
field_unit = field_unit.latex_representation()
field_name = field_info.display_name
if isinstance(field, tuple):
field = field[1]
if field_name is None:
field_name = field_info.get_latex_display_name()
elif field_name.find("$") == -1:
field_name = field_name.replace(" ", r"\ ")
field_name = r"$\rm{" + field_name + r"}$"
if fractional:
label = field_name + r"$\rm{\ Probability\ Density}$"
elif field_unit is None or field_unit == "":
label = field_name
else:
label = field_name + _get_units_label(field_unit)
return label
def _get_field_title(self, field_y, profile):
field_x = profile.x_field
xfi = profile.field_info[field_x]
yfi = profile.field_info[field_y]
x_unit = profile.x.units
y_unit = profile.field_units[field_y]
fractional = profile.fractional
x_title = self.x_title or self._get_field_label(field_x, xfi, x_unit)
y_title = self.y_title.get(field_y, None) or self._get_field_label(
field_y, yfi, y_unit, fractional
)
return (x_title, y_title)
[docs]
@invalidate_plot
def annotate_title(self, title, field="all"):
r"""Set a title for the plot.
Parameters
----------
title : str
The title to add.
field : str or list of str
The field name for which title needs to be set.
Examples
--------
>>> # To set title for all the fields:
>>> plot.annotate_title("This is a Profile Plot")
>>> # To set title for specific fields:
>>> plot.annotate_title("Profile Plot for Temperature", ("gas", "temperature"))
>>> # Setting same plot title for both the given fields
>>> plot.annotate_title(
... "Profile Plot: Temperature-Dark Matter Density",
... [("gas", "temperature"), ("deposit", "dark_matter_density")],
... )
"""
fields = list(self.plots.keys()) if field == "all" else field
for profile in self.profiles:
for field in profile.data_source._determine_fields(fields):
if field in profile.field_map:
field = profile.field_map[field]
self._plot_title[field] = title
return self
[docs]
@invalidate_plot
def annotate_text(self, xpos=0.0, ypos=0.0, text=None, field="all", **text_kwargs):
r"""Allow the user to insert text onto the plot
The x-position and y-position must be given as well as the text string.
Add *text* to plot at location *xpos*, *ypos* in plot coordinates for
the given fields or by default for all fields.
(see example below).
Parameters
----------
xpos : float
Position on plot in x-coordinates.
ypos : float
Position on plot in y-coordinates.
text : str
The text to insert onto the plot.
field : str or tuple
The name of the field to add text to.
**text_kwargs : dict
Extra keyword arguments will be passed to matplotlib text instance
>>> import yt
>>> from yt.units import kpc
>>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
>>> my_galaxy = ds.disk(ds.domain_center, [0.0, 0.0, 1.0], 10 * kpc, 3 * kpc)
>>> plot = yt.ProfilePlot(
... my_galaxy, ("gas", "density"), [("gas", "temperature")]
... )
>>> # Annotate text for all the fields
>>> plot.annotate_text(1e-26, 1e5, "This is annotated text in the plot area.")
>>> plot.save()
>>> # Annotate text for a given field
>>> plot.annotate_text(1e-26, 1e5, "Annotated text", ("gas", "temperature"))
>>> plot.save()
>>> # Annotate text for multiple fields
>>> fields = [("gas", "temperature"), ("gas", "density")]
>>> plot.annotate_text(1e-26, 1e5, "Annotated text", fields)
>>> plot.save()
"""
fields = list(self.plots.keys()) if field == "all" else field
for profile in self.profiles:
for field in profile.data_source._determine_fields(fields):
if field in profile.field_map:
field = profile.field_map[field]
self._plot_text[field] = text
self._text_xpos[field] = xpos
self._text_ypos[field] = ypos
self._text_kwargs[field] = text_kwargs
return self
[docs]
class PhasePlot(ImagePlotContainer):
r"""
Create a 2d profile (phase) plot from a data source or from
profile object created with
`yt.data_objects.profiles.create_profile`.
Given a data object (all_data, region, sphere, etc.), an x field,
y field, and z field (or fields), this will create a two-dimensional
profile of the average (or total) value of the z field in bins of the
x and y fields.
Parameters
----------
data_source : YTSelectionContainer Object
The data object to be profiled, such as all_data, region, or
sphere. If a dataset is passed in instead, an all_data data object
is generated internally from the dataset.
x_field : str
The x binning field for the profile.
y_field : str
The y binning field for the profile.
z_fields : str or list
The field or fields to be profiled.
weight_field : str
The weight field for calculating weighted averages. If None,
the profile values are the sum of the field values within the bin.
Otherwise, the values are a weighted average.
Default : ("gas", "mass")
x_bins : int
The number of bins in x field for the profile.
Default: 128.
y_bins : int
The number of bins in y field for the profile.
Default: 128.
accumulation : bool or list of bools
If True, the profile values for a bin n are the cumulative sum of
all the values from bin 0 to n. If -True, the sum is reversed so
that the value for bin n is the cumulative sum from bin N (total bins)
to n. A list of values can be given to control the summation in each
dimension independently.
Default: False.
fractional : If True the profile values are divided by the sum of all
the profile data such that the profile represents a probability
distribution function.
fontsize : int
Font size for all text in the plot.
Default: 18.
figure_size : int
Size in inches of the image.
Default: 8 (8x8)
shading : str
This argument is directly passed down to matplotlib.axes.Axes.pcolormesh
see
https://matplotlib.org/3.3.1/gallery/images_contours_and_fields/pcolormesh_grids.html#sphx-glr-gallery-images-contours-and-fields-pcolormesh-grids-py # noqa
Default: 'nearest'
Examples
--------
>>> import yt
>>> ds = yt.load("enzo_tiny_cosmology/DD0046/DD0046")
>>> ad = ds.all_data()
>>> plot = yt.PhasePlot(
... ad,
... ("gas", "density"),
... ("gas", "temperature"),
... [("gas", "mass")],
... weight_field=None,
... )
>>> plot.save()
>>> # Change plot properties.
>>> plot.set_cmap(("gas", "mass"), "jet")
>>> plot.set_zlim(("gas", "mass"), 1e8, 1e13)
>>> plot.annotate_title("This is a phase plot")
"""
x_log = None
y_log = None
plot_title = None
_plot_valid = False
_profile_valid = False
_plot_type = "Phase"
_xlim = (None, None)
_ylim = (None, None)
def __init__(
self,
data_source,
x_field,
y_field,
z_fields,
weight_field=("gas", "mass"),
x_bins=128,
y_bins=128,
accumulation=False,
fractional=False,
fontsize=18,
figure_size=8.0,
shading="nearest",
):
data_source = data_object_or_all_data(data_source)
if isinstance(z_fields, tuple):
z_fields = [z_fields]
z_fields = list(always_iterable(z_fields))
if isinstance(data_source.ds, YTProfileDataset):
profile = data_source.ds.profile
else:
profile = create_profile(
data_source,
[x_field, y_field],
z_fields,
n_bins=[x_bins, y_bins],
weight_field=weight_field,
accumulation=accumulation,
fractional=fractional,
)
type(self)._initialize_instance(
self, data_source, profile, fontsize, figure_size, shading
)
@classmethod
def _initialize_instance(
cls, obj, data_source, profile, fontsize, figure_size, shading
):
obj.plot_title = {}
obj.z_log = {}
obj.z_title = {}
obj._initfinished = False
obj.x_log = None
obj.y_log = None
obj._plot_text = {}
obj._text_xpos = {}
obj._text_ypos = {}
obj._text_kwargs = {}
obj._profile = profile
obj._shading = shading
obj._profile_valid = True
obj._xlim = (None, None)
obj._ylim = (None, None)
super(PhasePlot, obj).__init__(data_source, figure_size, fontsize)
obj._setup_plots()
obj._plot_valid = False # see https://github.com/yt-project/yt/issues/4489
obj._initfinished = True
return obj
def _get_field_title(self, field_z, profile):
field_x = profile.x_field
field_y = profile.y_field
xfi = profile.field_info[field_x]
yfi = profile.field_info[field_y]
zfi = profile.field_info[field_z]
x_unit = profile.x.units
y_unit = profile.y.units
z_unit = profile.field_units[field_z]
fractional = profile.fractional
x_label, y_label, z_label = self._get_axes_labels(field_z)
x_title = x_label or self._get_field_label(field_x, xfi, x_unit)
y_title = y_label or self._get_field_label(field_y, yfi, y_unit)
z_title = z_label or self._get_field_label(field_z, zfi, z_unit, fractional)
return (x_title, y_title, z_title)
def _get_field_label(self, field, field_info, field_unit, fractional=False):
field_unit = field_unit.latex_representation()
field_name = field_info.display_name
if isinstance(field, tuple):
field = field[1]
if field_name is None:
field_name = field_info.get_latex_display_name()
elif field_name.find("$") == -1:
field_name = field_name.replace(" ", r"\ ")
field_name = r"$\rm{" + field_name + r"}$"
if fractional:
label = field_name + r"$\rm{\ Probability\ Density}$"
elif field_unit is None or field_unit == "":
label = field_name
else:
label = field_name + _get_units_label(field_unit)
return label
def _get_field_log(self, field_z, profile):
zfi = profile.field_info[field_z]
if self.x_log is None:
x_log = profile.x_log
else:
x_log = self.x_log
if self.y_log is None:
y_log = profile.y_log
else:
y_log = self.y_log
if field_z in self.z_log:
z_log = self.z_log[field_z]
else:
z_log = zfi.take_log
scales = {True: "log", False: "linear"}
return scales[x_log], scales[y_log], scales[z_log]
@property
def profile(self):
if not self._profile_valid:
self._recreate_profile()
return self._profile
@property
def fields(self):
return list(self.plots.keys())
def _setup_plots(self):
if self._plot_valid:
return
for f, data in self.profile.items():
if f in self.plots:
pnh = self.plots[f].norm_handler
cbh = self.plots[f].colorbar_handler
draw_axes = self.plots[f]._draw_axes
if self.plots[f].figure is not None:
fig = self.plots[f].figure
axes = self.plots[f].axes
cax = self.plots[f].cax
else:
fig = None
axes = None
cax = None
else:
pnh, cbh = self._get_default_handlers(
field=f, default_display_units=self.profile[f].units
)
fig = None
axes = None
cax = None
draw_axes = True
x_scale, y_scale, z_scale = self._get_field_log(f, self.profile)
x_title, y_title, z_title = self._get_field_title(f, self.profile)
font_size = self._font_properties.get_size()
f = self.profile.data_source._determine_fields(f)[0]
# if this is a Particle Phase Plot AND if we using a single color,
# override the colorbar here.
splat_color = getattr(self, "splat_color", None)
if splat_color is not None:
cbh.cmap = matplotlib.colors.ListedColormap(splat_color, "dummy")
masked_data = data.copy()
masked_data[~self.profile.used] = np.nan
self.plots[f] = PhasePlotMPL(
self.profile.x,
self.profile.y,
masked_data,
x_scale,
y_scale,
self.figure_size,
font_size,
fig,
axes,
cax,
shading=self._shading,
norm_handler=pnh,
colorbar_handler=cbh,
)
self.plots[f]._toggle_axes(draw_axes)
self.plots[f]._toggle_colorbar(cbh.draw_cbar)
self.plots[f].axes.xaxis.set_label_text(x_title)
self.plots[f].axes.yaxis.set_label_text(y_title)
self.plots[f].cax.yaxis.set_label_text(z_title)
self.plots[f].axes.set_xlim(self._xlim)
self.plots[f].axes.set_ylim(self._ylim)
if f in self._plot_text:
self.plots[f].axes.text(
self._text_xpos[f],
self._text_ypos[f],
self._plot_text[f],
fontproperties=self._font_properties,
**self._text_kwargs[f],
)
if f in self.plot_title:
self.plots[f].axes.set_title(self.plot_title[f])
# x-y axes minorticks
if f not in self._minorticks:
self._minorticks[f] = True
if self._minorticks[f]:
self.plots[f].axes.minorticks_on()
else:
self.plots[f].axes.minorticks_off()
self._set_font_properties()
# if this is a particle plot with one color only, hide the cbar here
if hasattr(self, "use_cbar") and not self.use_cbar:
self.plots[f].hide_colorbar()
self._plot_valid = True
[docs]
@classmethod
def from_profile(cls, profile, fontsize=18, figure_size=8.0, shading="nearest"):
r"""
Instantiate a PhasePlot object from a profile object created
with :func:`~yt.data_objects.profiles.create_profile`.
Parameters
----------
profile : An instance of :class:`~yt.data_objects.profiles.ProfileND`
A single profile object.
fontsize : float
The fontsize to use, in points.
figure_size : float
The figure size to use, in inches.
shading : str
This argument is directly passed down to matplotlib.axes.Axes.pcolormesh
see
https://matplotlib.org/3.3.1/gallery/images_contours_and_fields/pcolormesh_grids.html#sphx-glr-gallery-images-contours-and-fields-pcolormesh-grids-py # noqa
Default: 'nearest'
Examples
--------
>>> import yt
>>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
>>> extrema = {
... ("gas", "density"): (1e-31, 1e-24),
... ("gas", "temperature"): (1e1, 1e8),
... ("gas", "mass"): (1e-6, 1e-1),
... }
>>> profile = yt.create_profile(
... ds.all_data(),
... [("gas", "density"), ("gas", "temperature")],
... fields=[("gas", "mass")],
... extrema=extrema,
... fractional=True,
... )
>>> ph = yt.PhasePlot.from_profile(profile)
>>> ph.save()
"""
obj = cls.__new__(cls)
data_source = profile.data_source
return cls._initialize_instance(
obj, data_source, profile, fontsize, figure_size, shading
)
[docs]
def annotate_text(self, xpos=0.0, ypos=0.0, text=None, **text_kwargs):
r"""
Allow the user to insert text onto the plot
The x-position and y-position must be given as well as the text string.
Add *text* tp plot at location *xpos*, *ypos* in plot coordinates
(see example below).
Parameters
----------
xpos : float
Position on plot in x-coordinates.
ypos : float
Position on plot in y-coordinates.
text : str
The text to insert onto the plot.
**text_kwargs : dict
Extra keyword arguments will be passed to matplotlib text instance
>>> plot.annotate_text(1e-15, 5e4, "Hello YT")
"""
for f in self.data_source._determine_fields(list(self.plots.keys())):
if self.plots[f].figure is not None and text is not None:
self.plots[f].axes.text(
xpos,
ypos,
text,
fontproperties=self._font_properties,
**text_kwargs,
)
self._plot_text[f] = text
self._text_xpos[f] = xpos
self._text_ypos[f] = ypos
self._text_kwargs[f] = text_kwargs
return self
[docs]
@validate_plot
def save(self, name: str | None = None, suffix: str | None = None, mpl_kwargs=None):
r"""
Saves a 2d profile plot.
Parameters
----------
name : str, optional
The output file keyword.
suffix : string, optional
Specify the image type by its suffix. If not specified, the output
type will be inferred from the filename. Defaults to '.png'.
mpl_kwargs : dict, optional
A dict of keyword arguments to be passed to matplotlib.
>>> plot.save(mpl_kwargs={"bbox_inches": "tight"})
"""
names = []
if not self._plot_valid:
self._setup_plots()
if mpl_kwargs is None:
mpl_kwargs = {}
if name is None:
name = str(self.profile.ds)
name = os.path.expanduser(name)
xfn = self.profile.x_field
yfn = self.profile.y_field
if isinstance(xfn, tuple):
xfn = xfn[1]
if isinstance(yfn, tuple):
yfn = yfn[1]
for f in self.profile.field_data:
_f = f
if isinstance(f, tuple):
_f = _f[1]
middle = f"2d-Profile_{xfn}_{yfn}_{_f}"
splitname = os.path.split(name)
if splitname[0] != "" and not os.path.isdir(splitname[0]):
os.makedirs(splitname[0])
if os.path.isdir(name) and name != str(self.profile.ds):
name = name + (os.sep if name[-1] != os.sep else "")
name += str(self.profile.ds)
new_name = validate_image_name(name, suffix)
if new_name == name:
for v in self.plots.values():
out_name = v.save(name, mpl_kwargs)
names.append(out_name)
return names
name = new_name
prefix, suffix = os.path.splitext(name)
name = f"{prefix}_{middle}{suffix}"
names.append(name)
self.plots[f].save(name, mpl_kwargs)
return names
[docs]
@invalidate_plot
def set_title(self, field, title):
"""Set a title for the plot.
Parameters
----------
field : str
The z field of the plot to add the title.
title : str
The title to add.
Examples
--------
>>> plot.set_title(("gas", "mass"), "This is a phase plot")
"""
self.plot_title[self.data_source._determine_fields(field)[0]] = title
return self
[docs]
@invalidate_plot
def annotate_title(self, title):
"""Set a title for the plot.
Parameters
----------
title : str
The title to add.
Examples
--------
>>> plot.annotate_title("This is a phase plot")
"""
for f in self._profile.field_data:
if isinstance(f, tuple):
f = f[1]
self.plot_title[self.data_source._determine_fields(f)[0]] = title
return self
[docs]
@invalidate_plot
def reset_plot(self):
self.plots = {}
return self
[docs]
@invalidate_plot
def set_log(self, field, log):
"""set a field to log or linear.
Parameters
----------
field : string
the field to set a transform
log : boolean
Log on/off.
"""
p = self._profile
if field == "all":
self.x_log = log
self.y_log = log
for field in p.field_data:
self.z_log[field] = log
self._profile_valid = False
else:
(field,) = self.profile.data_source._determine_fields([field])
if field == p.x_field:
self.x_log = log
self._profile_valid = False
elif field == p.y_field:
self.y_log = log
self._profile_valid = False
elif field in p.field_data:
super().set_log(field, log)
else:
raise KeyError(f"Field {field} not in phase plot!")
return self
[docs]
@invalidate_plot
def set_unit(self, field, unit):
"""Sets a new unit for the requested field
Parameters
----------
field : string
The name of the field that is to be changed.
unit : string or Unit object
The name of the new unit.
"""
fd = self.data_source._determine_fields(field)[0]
if fd == self.profile.x_field:
self.profile.set_x_unit(unit)
elif fd == self.profile.y_field:
self.profile.set_y_unit(unit)
elif fd in self.profile.field_data.keys():
self.profile.set_field_unit(field, unit)
self.plots[field].norm_handler.display_units = unit
else:
raise KeyError(f"Field {field} not in phase plot!")
return self
[docs]
@invalidate_plot
@invalidate_profile
def set_xlim(self, xmin=None, xmax=None):
"""Sets the limits of the x bin field
Parameters
----------
xmin : float or None
The new x minimum in the current x-axis units. Defaults to None,
which leaves the xmin unchanged.
xmax : float or None
The new x maximum in the current x-axis units. Defaults to None,
which leaves the xmax unchanged.
Examples
--------
>>> import yt
>>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
>>> pp = yt.PhasePlot(ds.all_data(), "density", "temperature", ("gas", "mass"))
>>> pp.set_xlim(1e-29, 1e-24)
>>> pp.save()
"""
p = self._profile
if xmin is None:
xmin = p.x_bins.min()
elif not hasattr(xmin, "units"):
xmin = self.ds.quan(xmin, p.x_bins.units)
if xmax is None:
xmax = p.x_bins.max()
elif not hasattr(xmax, "units"):
xmax = self.ds.quan(xmax, p.x_bins.units)
self._xlim = (xmin, xmax)
return self
[docs]
@invalidate_plot
@invalidate_profile
def set_ylim(self, ymin=None, ymax=None):
"""Sets the plot limits for the y bin field.
Parameters
----------
ymin : float or None
The new y minimum in the current y-axis units. Defaults to None,
which leaves the ymin unchanged.
ymax : float or None
The new y maximum in the current y-axis units. Defaults to None,
which leaves the ymax unchanged.
Examples
--------
>>> import yt
>>> ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")
>>> pp = yt.PhasePlot(
... ds.all_data(),
... ("gas", "density"),
... ("gas", "temperature"),
... ("gas", "mass"),
... )
>>> pp.set_ylim(1e4, 1e6)
>>> pp.save()
"""
p = self._profile
if ymin is None:
ymin = p.y_bins.min()
elif not hasattr(ymin, "units"):
ymin = self.ds.quan(ymin, p.y_bins.units)
if ymax is None:
ymax = p.y_bins.max()
elif not hasattr(ymax, "units"):
ymax = self.ds.quan(ymax, p.y_bins.units)
self._ylim = (ymin, ymax)
return self
def _recreate_profile(self):
p = self._profile
units = {p.x_field: str(p.x.units), p.y_field: str(p.y.units)}
zunits = {field: str(p.field_units[field]) for field in p.field_units}
extrema = {p.x_field: self._xlim, p.y_field: self._ylim}
if self.x_log is not None or self.y_log is not None:
logs = {}
else:
logs = None
if self.x_log is not None:
logs[p.x_field] = self.x_log
if self.y_log is not None:
logs[p.y_field] = self.y_log
deposition = getattr(p, "deposition", None)
additional_kwargs = {
"accumulation": p.accumulation,
"fractional": p.fractional,
"deposition": deposition,
}
self._profile = create_profile(
p.data_source,
[p.x_field, p.y_field],
list(p.field_map.values()),
n_bins=[len(p.x_bins) - 1, len(p.y_bins) - 1],
weight_field=p.weight_field,
units=units,
extrema=extrema,
logs=logs,
**additional_kwargs,
)
for field in zunits:
self._profile.set_field_unit(field, zunits[field])
self._profile_valid = True
[docs]
class PhasePlotMPL(ImagePlotMPL):
"""A container for a single matplotlib figure and axes for a PhasePlot"""
def __init__(
self,
x_data,
y_data,
data,
x_scale,
y_scale,
figure_size,
fontsize,
figure,
axes,
cax,
shading="nearest",
*,
norm_handler: NormHandler,
colorbar_handler: ColorbarHandler,
):
self._initfinished = False
self._shading = shading
self._setup_layout_constraints(figure_size, fontsize)
# this line is added purely to prevent exact image comparison tests
# to fail, but eventually we should embrace the change and
# use similar values for PhasePlotMPL and WindowPlotMPL
self._ax_text_size[0] *= 1.1 / 1.2 # TODO: remove this
super().__init__(
figure=figure,
axes=axes,
cax=cax,
norm_handler=norm_handler,
colorbar_handler=colorbar_handler,
)
self._init_image(x_data, y_data, data, x_scale, y_scale)
self._initfinished = True
def _init_image(
self,
x_data,
y_data,
image_data,
x_scale,
y_scale,
):
"""Store output of imshow in image variable"""
norm = self.norm_handler.get_norm(image_data)
self.image = None
self.cb = None
self.image = self.axes.pcolormesh(
np.array(x_data),
np.array(y_data),
np.array(image_data.T),
norm=norm,
cmap=self.colorbar_handler.cmap,
shading=self._shading,
)
self._set_axes()
self.axes.set_xscale(x_scale)
self.axes.set_yscale(y_scale)