import numpy as np
from yt.utilities.io_handler import BaseIOHandler
from yt.utilities.logger import ytLogger as mylog
[docs]
class IOHandlerFITS(BaseIOHandler):
_particle_reader = False
_dataset_type = "fits"
def __init__(self, ds):
super().__init__(ds)
self.ds = ds
self._handle = ds._handle
def _read_particles(
self, fields_to_read, type, args, grid_list, count_list, conv_factors
):
pass
def _read_particle_coords(self, chunks, ptf):
pdata = self.ds._handle[self.ds.first_image].data
assert len(ptf) == 1
ptype = list(ptf.keys())[0]
x = np.asarray(pdata.field("X"), dtype="=f8")
y = np.asarray(pdata.field("Y"), dtype="=f8")
z = np.ones(x.shape)
x = (x - 0.5) / self.ds.reblock + 0.5
y = (y - 0.5) / self.ds.reblock + 0.5
yield ptype, (x, y, z), 0.0
def _read_particle_fields(self, chunks, ptf, selector):
pdata = self.ds._handle[self.ds.first_image].data
assert len(ptf) == 1
ptype = list(ptf.keys())[0]
field_list = ptf[ptype]
x = np.asarray(pdata.field("X"), dtype="=f8")
y = np.asarray(pdata.field("Y"), dtype="=f8")
z = np.ones(x.shape)
x = (x - 0.5) / self.ds.reblock + 0.5
y = (y - 0.5) / self.ds.reblock + 0.5
mask = selector.select_points(x, y, z, 0.0)
if mask is None:
return
for field in field_list:
fd = field.split("_")[-1]
data = pdata.field(fd.upper())
if fd in ["x", "y"]:
data = (data.copy() - 0.5) / self.ds.reblock + 0.5
yield (ptype, field), data[mask]
def _read_fluid_selection(self, chunks, selector, fields, size):
chunks = list(chunks)
if any((ftype != "fits" for ftype, fname in fields)):
raise NotImplementedError
rv = {}
dt = "float64"
for field in fields:
rv[field] = np.empty(size, dtype=dt)
ng = sum(len(c.objs) for c in chunks)
mylog.debug(
"Reading %s cells of %s fields in %s grids",
size,
[f2 for f1, f2 in fields],
ng,
)
dx = self.ds.domain_width / self.ds.domain_dimensions
for field in fields:
ftype, fname = field
f = self.ds.index._file_map[fname]
ds = f[self.ds.index._ext_map[fname]]
bzero, bscale = self.ds.index._scale_map[fname]
ind = 0
for chunk in chunks:
for g in chunk.objs:
start = ((g.LeftEdge - self.ds.domain_left_edge) / dx).d.astype(
"int"
)
end = start + g.ActiveDimensions
slices = [slice(start[i], end[i]) for i in range(3)]
if self.ds.dimensionality == 2:
nx, ny = g.ActiveDimensions[:2]
nz = 1
data = np.zeros((nx, ny, nz))
data[:, :, 0] = ds.data[slices[1], slices[0]].T
elif self.ds.naxis == 4:
idx = self.ds.index._axis_map[fname]
data = ds.data[idx, slices[2], slices[1], slices[0]].T
else:
data = ds.data[slices[2], slices[1], slices[0]].T
if fname in self.ds.nan_mask:
data[np.isnan(data)] = self.ds.nan_mask[fname]
elif "all" in self.ds.nan_mask:
data[np.isnan(data)] = self.ds.nan_mask["all"]
data = bzero + bscale * data
ind += g.select(selector, data.astype("float64"), rv[field], ind)
return rv