import base64
import os
from io import BytesIO
import numpy as np
from yt.fields.derived_field import ValidateSpatial
from yt.units.yt_array import YTArray, YTQuantity
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.on_demand_imports import _astropy
def _make_counts(emin, emax):
    """Build a field function that images events in one energy band.

    The returned callable selects the events whose ``("all", "event_energy")``
    value, converted to keV, satisfies ``emin <= energy < emax`` and deposits
    them with ``data.deposit(..., method="count")``.

    If the data object carries a positive ``"sigma"`` field parameter, the
    ``[:, :, 0]`` plane of the deposited image is convolved with a 2D Gaussian
    kernel of that standard deviation (presumably in pixels -- confirm
    against the astropy convolution API).
    """

    def _counts(field, data):
        energies = data["all", "event_energy"].in_units("keV")
        in_band = np.logical_and(energies >= emin, energies < emax)
        xs = data["all", "event_x"][in_band]
        ys = data["all", "event_y"][in_band]
        # The third column is a constant 1 so every event is a 3-vector.
        positions = np.array([xs, ys, np.ones(xs.shape)]).T
        image = data.deposit(positions, method="count")

        sigma = (
            data.get_field_parameter("sigma")
            if data.has_field_parameter("sigma")
            else None
        )
        wants_smoothing = sigma is not None and sigma > 0.0
        if wants_smoothing:
            conv = _astropy.conv
            kernel = conv.Gaussian2DKernel(x_stddev=sigma)
            image[:, :, 0] = conv.convolve(image[:, :, 0], kernel)

        return data.ds.arr(image, "counts/pixel")

    return _counts
[docs]
def setup_counts_fields(ds, ebounds, ftype="gas"):
    r"""
    Register one deposited counts-image field per energy band.

    For every ``(emin, emax)`` pair a field named ``counts_{emin}-{emax}`` is
    added to the dataset under the requested field type. Each field is built
    by ``_make_counts`` and has units of ``counts/pixel``.

    Parameters
    ----------
    ds : ~yt.data_objects.static_output.Dataset
        The FITS events file dataset to add the counts fields to.
    ebounds : list of tuples
        A list of tuples, one for each field, with (emin, emax) as the
        energy bounds for the image.
    ftype : string, optional
        The field type of the resulting field. Defaults to "gas".

    Examples
    --------
    >>> ds = yt.load("evt.fits")
    >>> ebounds = [(0.1, 2.0), (2.0, 3.0)]
    >>> setup_counts_fields(ds, ebounds)
    """
    for band_lo, band_hi in ebounds:
        deposit_func = _make_counts(band_lo, band_hi)
        band_name = f"counts_{band_lo}-{band_hi}"
        mylog.info("Creating counts field %s.", band_name)
        field_options = {
            "sampling_type": "cell",
            "function": deposit_func,
            "units": "counts/pixel",
            "validators": [ValidateSpatial()],
            "display_name": f"Counts ({band_lo}-{band_hi} keV)",
        }
        ds.add_field((ftype, band_name), **field_options)
[docs]
def create_spectral_slabs(filename, slab_centers, slab_width, **kwargs):
    r"""
    Given a dictionary of spectral slab centers and a width in
    spectral units, extract data from a spectral cube at these slab
    centers and return a `FITSDataset` instance containing the different
    slabs as separate yt fields. Useful for extracting individual
    lines from a spectral cube and separating them out as different fields.

    Requires the SpectralCube (https://spectral-cube.readthedocs.io/en/latest/)
    library.

    All keyword arguments will be passed on to the `FITSDataset` constructor.

    Parameters
    ----------
    filename : string
        The spectral cube FITS file to extract the data from.
    slab_centers : dict of (float, string) tuples or YTQuantities
        The centers of the slabs, where the keys are the names
        of the new fields and the values are (float, string) tuples or
        YTQuantities, specifying a value for each center and its unit.
        Must contain at least one entry.
    slab_width : YTQuantity or (float, string) tuple
        The width of the slab along the spectral axis.

    Returns
    -------
    FITSDataset
        A dataset with one field per key of ``slab_centers``.

    Raises
    ------
    ValueError
        If ``slab_centers`` is empty, since there would be no slab from which
        to take the output WCS.

    Examples
    --------
    >>> slab_centers = {
    ...     "13CN": (218.03117, "GHz"),
    ...     "CH3CH2CHO": (218.284256, "GHz"),
    ...     "CH3NH2": (218.40956, "GHz"),
    ... }
    >>> slab_width = (0.05, "GHz")
    >>> ds = create_spectral_slabs(
    ...     "intensity_cube.fits", slab_centers, slab_width, nan_mask=0.0
    ... )
    """
    # Fail fast, before the heavy imports and file I/O: the WCS of the output
    # is taken from the last extracted slab, so at least one slab is required.
    if not slab_centers:
        raise ValueError("slab_centers must contain at least one slab center.")

    from spectral_cube import SpectralCube

    from yt.frontends.fits.api import FITSDataset
    from yt.visualization.fits_image import FITSImageData

    cube = SpectralCube.read(filename)
    if not isinstance(slab_width, YTQuantity):
        slab_width = YTQuantity(slab_width[0], slab_width[1])
    slab_data = {}
    field_units = cube.header.get("bunit", "dimensionless")
    for k, v in slab_centers.items():
        if not isinstance(v, YTQuantity):
            slab_center = YTQuantity(v[0], v[1])
        else:
            slab_center = v
        mylog.info("Adding slab field %s at %g %s", k, slab_center.v, slab_center.units)
        # The slab spans half a width on either side of its center.
        slab_lo = (slab_center - 0.5 * slab_width).to_astropy()
        slab_hi = (slab_center + 0.5 * slab_width).to_astropy()
        subcube = cube.spectral_slab(slab_lo, slab_hi)
        slab_data[k] = YTArray(subcube.filled_data[:, :, :], field_units)

    # Re-reference the last WCS axis using the final slab: the reference pixel
    # becomes 0.5 and the reference value minus half the slab extent.
    width = subcube.header["naxis3"] * cube.header["cdelt3"]
    w = subcube.wcs.copy()
    w.wcs.crpix[-1] = 0.5
    w.wcs.crval[-1] = -0.5 * width
    fid = FITSImageData(slab_data, wcs=w)
    # Drop any rest-frequency keywords from every HDU header.
    # NOTE(review): presumably they no longer apply to the re-referenced
    # spectral axis -- confirm.
    for hdu in fid:
        hdu.header.pop("RESTFREQ", None)
        hdu.header.pop("RESTFRQ", None)
    ds = FITSDataset(fid, **kwargs)
    return ds
[docs]
def ds9_region(ds, reg, obj=None, field_parameters=None):
    r"""
    Build a cut region of a data container from a ds9 region. Requires the
    regions package (https://astropy-regions.readthedocs.io/) to be installed.

    Only the first region found in ``reg`` is used. A mask field named after
    ``reg`` is added to the dataset under the "gas" field type, and the
    returned object selects the cells where that field is positive.

    Parameters
    ----------
    ds : FITSDataset
        The Dataset to create the region from.
    reg : string
        The filename of the ds9 region, or a region string to be parsed.
    obj : data container, optional
        The data container that will be used to create the new region.
        Defaults to ds.all_data.
    field_parameters : dictionary, optional
        A set of field parameters to apply to the region.

    Examples
    --------
    >>> ds = yt.load("m33_hi.fits")
    >>> circle_region = ds9_region(ds, "circle.reg")
    >>> print(circle_region.quantities.extrema("flux"))
    """
    from yt.utilities.on_demand_imports import _astropy, _regions

    Regions = _regions.Regions
    WCS = _astropy.WCS
    from yt.frontends.fits.api import EventsFITSDataset

    # An existing path is read from disk; anything else is parsed as text.
    load_regions = Regions.read if os.path.exists(reg) else Regions.parse
    sky_region = load_regions(reg, format="ds9").regions[0]
    field_name = reg

    header = ds.wcs_2d.to_header()
    # The FITS header only contains WCS-related keywords, so the image
    # dimensions are filled in from the dataset.
    dims = ds.domain_dimensions
    header["NAXIS1"] = dims[ds.lon_axis]
    header["NAXIS2"] = dims[ds.lat_axis]
    pixel_region = sky_region.to_pixel(WCS(header))
    image_shape = (header["NAXIS1"], header["NAXIS2"])
    mask = pixel_region.to_mask().to_image(image_shape).astype(bool)

    prefix = "event_" if isinstance(ds, EventsFITSDataset) else ""

    def _reg_field(field, data):
        # Coordinates are shifted by one to give zero-based mask indices.
        lon_idx = data[prefix + "xyz"[ds.lon_axis]].d.astype("int64") - 1
        lat_idx = data[prefix + "xyz"[ds.lat_axis]].d.astype("int64") - 1
        inside = mask[lon_idx, lat_idx]
        flags = np.zeros(data[prefix + "x"].shape)
        flags[inside] = 1.0
        return flags

    ds.add_field(("gas", field_name), sampling_type="cell", function=_reg_field)

    container = ds.all_data() if obj is None else obj
    if field_parameters is not None:
        for param, value in field_parameters.items():
            container.set_field_parameter(param, value)
    return container.cut_region([f"obj['{field_name}'] > 0"])
[docs]
class PlotWindowWCS:
    r"""
    Use AstroPy's WCSAxes class to plot celestial coordinates on the axes of an
    on-axis PlotWindow plot. See
    http://docs.astropy.org/en/stable/visualization/wcsaxes/ for more details
    on how it works under the hood. This functionality requires a version of
    AstroPy >= 1.3.

    Parameters
    ----------
    pw : on-axis PlotWindow instance
        The PlotWindow instance to add celestial axes to.

    Raises
    ------
    NotImplementedError
        If the plot is oblique, the dataset has no ``wcs_2d`` attribute, or
        the plot axis is not the dataset's spectral axis.
    """

    def __init__(self, pw):
        WCSAxes = _astropy.wcsaxes.WCSAxes
        if pw.oblique:
            raise NotImplementedError("WCS axes are not implemented for oblique plots.")
        if not hasattr(pw.ds, "wcs_2d"):
            raise NotImplementedError("WCS axes are not implemented for this dataset.")
        if pw.data_source.axis != pw.ds.spec_axis:
            raise NotImplementedError("WCS axes are not implemented for this axis.")
        # Maps each key of pw.plots to the figure that received WCS axes.
        self.plots = {}
        self.pw = pw
        for f in pw.plots:
            plot = pw.plots[f]
            # The second element of the layout is used as the axes rectangle.
            rect = plot._get_best_layout()[1]
            fig = plot.figure
            ax = fig.axes[0]
            wcs_ax = WCSAxes(fig, rect, wcs=pw.ds.wcs_2d, frameon=False)
            fig.add_axes(wcs_ax)
            wcs = pw.ds.wcs_2d.wcs
            xax = pw.ds.coordinates.x_axis[pw.data_source.axis]
            yax = pw.ds.coordinates.y_axis[pw.data_source.axis]
            # Label with the coordinate type (text before the first "-" of
            # CTYPE) and its unit.
            xlabel = f"{wcs.ctype[xax].split('-')[0]} ({wcs.cunit[xax]})"
            ylabel = f"{wcs.ctype[yax].split('-')[0]} ({wcs.cunit[yax]})"
            fp = pw._font_properties
            wcs_ax.coords[0].set_axislabel(xlabel, fontproperties=fp, minpad=0.5)
            wcs_ax.coords[1].set_axislabel(ylabel, fontproperties=fp, minpad=0.4)
            wcs_ax.coords[0].ticklabels.set_fontproperties(fp)
            wcs_ax.coords[1].ticklabels.set_fontproperties(fp)
            # Hide the original axes so that only the WCS axes are drawn.
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            wcs_ax.set_xlim(pw.xlim[0].value, pw.xlim[1].value)
            wcs_ax.set_ylim(pw.ylim[0].value, pw.ylim[1].value)
            # NOTE(review): resets a private WCSAxes frame attribute;
            # presumably forces the frame to be recomputed -- confirm.
            wcs_ax.coords.frame._update_cache = []
            self.plots[f] = fig

    def keys(self):
        """Return the keys of the stored figures."""
        return self.plots.keys()

    def values(self):
        """Return the stored figures."""
        return self.plots.values()

    def items(self):
        """Return the (key, figure) pairs of the stored figures."""
        return self.plots.items()

    def __getitem__(self, key):
        """Return the figure whose key has ``key`` as its second element.

        Raises
        ------
        KeyError
            If no stored key matches, rather than silently returning None.
        """
        for k in self.keys():
            if k[1] == key:
                return self.plots[k]
        raise KeyError(key)

    def show(self):
        """Return this object so that a notebook renders it via ``_repr_html_``."""
        return self

    def save(self, name=None, mpl_kwargs=None):
        """Save through the wrapped PlotWindow, forcing ``bbox_inches="tight"``.

        The caller's ``mpl_kwargs`` dictionary is copied, not modified.
        """
        merged_kwargs = {**(mpl_kwargs or {}), "bbox_inches": "tight"}
        self.pw.save(name=name, mpl_kwargs=merged_kwargs)

    def _repr_html_(self):
        """Return one base64-embedded ``<img>`` tag per stored figure."""
        from matplotlib.backends.backend_agg import FigureCanvasAgg

        tags = []
        for fig in self.plots.values():
            canvas = FigureCanvasAgg(fig)
            buf = BytesIO()
            canvas.print_figure(buf)
            img = base64.b64encode(buf.getvalue()).decode()
            # This string is never %-formatted, so a single "%" is required;
            # "100%%" would be emitted verbatim and is not valid CSS.
            tags.append(
                '<img style="max-width:100%;max-height:100%;" '
                f'src="data:image/png;base64,{img}"><br>'
            )
        return "".join(tags)