from collections import defaultdict
import numpy as np
from yt.frontends.gadget_fof.io import IOHandlerGadgetFOFHaloHDF5
from yt.funcs import parse_h5_attr
from yt.units._numpy_wrapper_functions import uvstack
from yt.utilities.io_handler import BaseParticleIOHandler
from yt.utilities.on_demand_imports import _h5py as h5py
class IOHandlerYTHaloCatalog(BaseParticleIOHandler):
_dataset_type = "ythalocatalog"
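
    # Halo catalogs contain no fluid/mesh data, so fluid selection is not
    # supported by this handler.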
def _read_fluid_selection(self, chunks, selector, fields, size):
raise NotImplementedError
def _read_particle_coords(self, chunks, ptf):
        # Read the chunks and yield halo coordinates.
        # Only halo reading is supported for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
ptype = "halos"
pn = "particle_position_%s"
for data_file in self._sorted_chunk_iterator(chunks):
with h5py.File(data_file.filename, mode="r") as f:
units = parse_h5_attr(f[pn % "x"], "units")
pos = data_file._get_particle_positions(ptype, f=f)
x, y, z = (self.ds.arr(pos[:, i], units) for i in range(3))
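                # the trailing 0.0 is the smoothing length; halos are
                # treated as point particles by the selector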
yield "halos", (x, y, z), 0.0
def _yield_coordinates(self, data_file):
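        # used when yt builds its particle index; yields all halo
        # positions in code_length units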
pn = "particle_position_%s"
with h5py.File(data_file.filename, mode="r") as f:
units = parse_h5_attr(f[pn % "x"], "units")
x, y, z = (
self.ds.arr(f[pn % ax][()].astype("float64"), units) for ax in "xyz"
)
pos = uvstack([x, y, z]).T
pos.convert_to_units("code_length")
yield "halos", pos
def _read_particle_fields(self, chunks, ptf, selector):
        # Only halo reading is supported for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
pn = "particle_position_%s"
for data_file in self._sorted_chunk_iterator(chunks):
si, ei = data_file.start, data_file.end
with h5py.File(data_file.filename, mode="r") as f:
for ptype, field_list in sorted(ptf.items()):
units = parse_h5_attr(f[pn % "x"], "units")
pos = data_file._get_particle_positions(ptype, f=f)
x, y, z = (self.ds.arr(pos[:, i], units) for i in range(3))
mask = selector.select_points(x, y, z, 0.0)
del x, y, z
if mask is None:
continue
for field in field_list:
data = f[field][si:ei][mask].astype("float64")
yield (ptype, field), data
def _count_particles(self, data_file):
si, ei = data_file.start, data_file.end
nhalos = data_file.header["num_halos"]
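        # a data file may be assigned only a [start, end) sub-range of the
        # halos when split into virtual chunks; clip the count accordingly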
if None not in (si, ei):
nhalos = np.clip(nhalos - si, 0, ei - si)
return {"halos": nhalos}
def _identify_fields(self, data_file):
with h5py.File(data_file.filename, mode="r") as f:
fields = [
("halos", field) for field in f if not isinstance(f[field], h5py.Group)
]
units = {("halos", field): parse_h5_attr(f[field], "units") for field in f}
return fields, units
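

# A minimal usage sketch (file name hypothetical): yt dispatches to this
# handler for datasets of type "ythalocatalog", e.g. halo catalogs written
# by yt_astro_analysis:
#
#     import yt
#     ds = yt.load("halo_catalog_0.0.h5")  # hypothetical catalog file
#     ad = ds.all_data()
#     # field reads go through _read_particle_fields above
#     masses = ad["halos", "particle_mass"]
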
class HaloDatasetIOHandler:
"""
Base class for io handlers to load halo member particles.
"""
def _read_particle_coords(self, chunks, ptf):
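        # intentionally a no-op: halo data containers read everything
        # through _read_particle_fields and the scalar/member readers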
pass
def _read_particle_fields(self, dobj, ptf):
# separate member particle fields from scalar fields
scalar_fields = defaultdict(list)
member_fields = defaultdict(list)
for ptype, field_list in sorted(ptf.items()):
for field in field_list:
if (ptype, field) in self.ds.scalar_field_list:
scalar_fields[ptype].append(field)
else:
member_fields[ptype].append(field)
all_data = self._read_scalar_fields(dobj, scalar_fields)
all_data.update(self._read_member_fields(dobj, member_fields))
for field, field_data in all_data.items():
yield field, field_data
# This will be refactored.
_read_particle_selection = IOHandlerGadgetFOFHaloHDF5._read_particle_selection


# ignoring type in this mixin to circumvent this error from mypy:
#     Definition of "_read_particle_fields" in base class "HaloDatasetIOHandler"
#     is incompatible with definition in base class "IOHandlerYTHaloCatalog"
# it may not be possible to refactor out of this situation without breaking
# downstream code
class IOHandlerYTHalo(HaloDatasetIOHandler, IOHandlerYTHaloCatalog): # type: ignore
_dataset_type = "ythalo"
def _identify_fields(self, data_file):
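        # unlike the base class, this returns scalar fields and member (id)
        # fields separately, along with their union and units, since halo
        # containers treat the two kinds differently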
with h5py.File(data_file.filename, mode="r") as f:
scalar_fields = [
("halos", field) for field in f if not isinstance(f[field], h5py.Group)
]
units = {("halos", field): parse_h5_attr(f[field], "units") for field in f}
if "particles" in f:
id_fields = [("halos", field) for field in f["particles"]]
else:
id_fields = []
return scalar_fields + id_fields, scalar_fields, id_fields, units
def _read_member_fields(self, dobj, member_fields):
all_data = defaultdict(lambda: np.empty(dobj.particle_number, dtype=np.float64))
if not member_fields:
return all_data
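        # a halo's member particles may span several data files;
        # field_start/field_end track where each file's slice lands in the
        # preallocated output arrays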
field_start = 0
for i, data_file in enumerate(dobj.field_data_files):
start_index = dobj.field_data_start[i]
end_index = dobj.field_data_end[i]
pcount = end_index - start_index
if pcount == 0:
continue
            field_end = field_start + pcount
with h5py.File(data_file.filename, mode="r") as f:
for ptype, field_list in sorted(member_fields.items()):
for field in field_list:
field_data = all_data[ptype, field]
my_data = f["particles"][field][start_index:end_index].astype(
"float64"
)
field_data[field_start:field_end] = my_data
field_start = field_end
return all_data
def _read_scalar_fields(self, dobj, scalar_fields):
all_data = {}
if not scalar_fields:
return all_data
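        # scalar fields hold one value per halo; scalar_index selects the
        # row belonging to this halo container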
with h5py.File(dobj.scalar_data_file.filename, mode="r") as f:
for ptype, field_list in sorted(scalar_fields.items()):
for field in field_list:
data = np.array([f[field][dobj.scalar_index]]).astype("float64")
all_data[ptype, field] = data
return all_data
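

# A hedged sketch of how the halo-container readers above are exercised
# (file and member-field names hypothetical):
#
#     import yt
#     ds = yt.load("catalog_0046.0.h5")  # hypothetical halo catalog
#     halo = ds.halo("halos", 0)         # a single-halo data container
#     halo["halos", "particle_mass"]     # scalar  -> _read_scalar_fields
#     halo["halos", "ids"]               # member  -> _read_member_fields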