import glob
import weakref
from collections import defaultdict
from functools import cached_property, partial
import numpy as np
from yt.data_objects.selection_objects.data_selection_objects import (
YTSelectionContainer,
)
from yt.data_objects.static_output import (
ParticleDataset,
ParticleFile,
)
from yt.frontends.ytdata.data_structures import SavedDataset
from yt.funcs import parse_h5_attr
from yt.geometry.particle_geometry_handler import ParticleIndex
from yt.utilities.on_demand_imports import _h5py as h5py
from .fields import YTHaloCatalogFieldInfo, YTHaloCatalogHaloFieldInfo
[docs]
class HaloCatalogFile(ParticleFile):
"""
Base class for data files of halo catalog datasets.
This is mainly here to correct for periodicity when
reading particle positions.
"""
def __init__(self, ds, io, filename, file_id, frange):
super().__init__(ds, io, filename, file_id, frange)
def _read_particle_positions(self, ptype, f=None):
raise NotImplementedError
def _get_particle_positions(self, ptype, f=None):
pcount = self.total_particles[ptype]
if pcount == 0:
return None
# Correct for periodicity.
dle = self.ds.domain_left_edge.to("code_length").v
dw = self.ds.domain_width.to("code_length").v
pos = self._read_particle_positions(ptype, f=f)
si, ei = self.start, self.end
if None not in (si, ei):
pos = pos[si:ei]
np.subtract(pos, dle, out=pos)
np.mod(pos, dw, out=pos)
np.add(pos, dle, out=pos)
return pos
[docs]
class YTHaloCatalogFile(HaloCatalogFile):
"""
Data file class for the YTHaloCatalogDataset.
"""
def __init__(self, ds, io, filename, file_id, frange):
with h5py.File(filename, mode="r") as f:
self.header = {field: parse_h5_attr(f, field) for field in f.attrs.keys()}
pids = f.get("particles/ids")
self.total_ids = 0 if pids is None else pids.size
self.group_length_sum = self.total_ids
super().__init__(ds, io, filename, file_id, frange)
def _read_particle_positions(self, ptype, f=None):
"""
Read all particle positions in this file.
"""
if f is None:
close = True
f = h5py.File(self.filename, mode="r")
else:
close = False
pcount = self.header["num_halos"]
pos = np.empty((pcount, 3), dtype="float64")
for i, ax in enumerate("xyz"):
pos[:, i] = f[f"particle_position_{ax}"][()]
if close:
f.close()
return pos
[docs]
class YTHaloCatalogDataset(SavedDataset):
"""
Dataset class for halo catalogs made with yt.
This covers yt FoF/HoP halo finders and the halo analysis
in yt_astro_analysis.
"""
_load_requirements = ["h5py"]
_index_class = ParticleIndex
_file_class = YTHaloCatalogFile
_field_info_class = YTHaloCatalogFieldInfo
_suffix = ".h5"
_con_attrs = (
"cosmological_simulation",
"current_time",
"current_redshift",
"hubble_constant",
"omega_matter",
"omega_lambda",
"domain_left_edge",
"domain_right_edge",
)
def __init__(
self,
filename,
dataset_type="ythalocatalog",
index_order=None,
units_override=None,
unit_system="cgs",
):
self.index_order = index_order
super().__init__(
filename,
dataset_type,
units_override=units_override,
unit_system=unit_system,
)
[docs]
def add_field(self, *args, **kwargs):
super().add_field(*args, **kwargs)
self._halos_ds.add_field(*args, **kwargs)
@property
def halos_field_list(self):
return self._halos_ds.field_list
@property
def halos_derived_field_list(self):
return self._halos_ds.derived_field_list
@cached_property
def _halos_ds(self):
return YTHaloDataset(self)
def _setup_classes(self):
super()._setup_classes()
self.halo = partial(YTHaloCatalogHaloContainer, ds=self._halos_ds)
self.halo.__doc__ = YTHaloCatalogHaloContainer.__doc__
def _parse_parameter_file(self):
self.refine_by = 2
self.dimensionality = 3
self.domain_dimensions = np.ones(self.dimensionality, "int32")
self._periodicity = (True, True, True)
prefix = ".".join(self.parameter_filename.rsplit(".", 2)[:-2])
self.filename_template = f"{prefix}.%(num)s{self._suffix}"
self.file_count = len(glob.glob(prefix + "*" + self._suffix))
self.particle_types = ("halos",)
self.particle_types_raw = ("halos",)
super()._parse_parameter_file()
@classmethod
def _is_valid(cls, filename: str, *args, **kwargs) -> bool:
if not filename.endswith(".h5"):
return False
if cls._missing_load_requirements():
return False
with h5py.File(filename, mode="r") as f:
if (
"data_type" in f.attrs
and parse_h5_attr(f, "data_type") == "halo_catalog"
):
return True
return False
[docs]
class YTHaloParticleIndex(ParticleIndex):
"""
Particle index for getting halo particles from YTHaloCatalogDatasets.
"""
def __init__(self, ds, dataset_type):
self.real_ds = weakref.proxy(ds.real_ds)
super().__init__(ds, dataset_type)
def _calculate_particle_index_starts(self):
"""
Create a dict of halo id offsets for each file.
"""
particle_count = defaultdict(int)
offset_count = 0
for data_file in self.data_files:
data_file.index_start = {
ptype: particle_count[ptype] for ptype in data_file.total_particles
}
data_file.offset_start = offset_count
for ptype in data_file.total_particles:
particle_count[ptype] += data_file.total_particles[ptype]
offset_count += getattr(data_file, "total_offset", 0)
self._halo_index_start = {}
for ptype in self.ds.particle_types_raw:
d = [data_file.index_start[ptype] for data_file in self.data_files]
self._halo_index_start.update({ptype: np.array(d)})
def _detect_output_fields(self):
field_list = []
scalar_field_list = []
units = {}
pc = {}
for ptype in self.ds.particle_types_raw:
d = [df.total_particles[ptype] for df in self.data_files]
pc.update({ptype: sum(d)})
found_fields = {ptype: False for ptype, pnum in pc.items() if pnum > 0}
has_ids = False
for data_file in self.data_files:
fl, sl, idl, _units = self.io._identify_fields(data_file)
units.update(_units)
field_list.extend([f for f in fl if f not in field_list])
scalar_field_list.extend([f for f in sl if f not in scalar_field_list])
for ptype in found_fields:
found_fields[ptype] |= data_file.total_particles[ptype]
has_ids |= len(idl) > 0
if all(found_fields.values()) and has_ids:
break
self.field_list = field_list
self.scalar_field_list = scalar_field_list
ds = self.dataset
ds.scalar_field_list = scalar_field_list
ds.particle_types = tuple({pt for pt, ds in field_list})
ds.field_units.update(units)
ds.particle_types_raw = ds.particle_types
def _get_halo_file_indices(self, ptype, identifiers):
"""
Get the index of the data file list where this halo lives.
Digitize returns i such that bins[i-1] <= x < bins[i], so we subtract
one because we will open data file i.
"""
return np.digitize(identifiers, self._halo_index_start[ptype], right=False) - 1
def _get_halo_scalar_index(self, ptype, identifier):
i_scalar = self._get_halo_file_indices(ptype, [identifier])[0]
scalar_index = identifier - self._halo_index_start[ptype][i_scalar]
return scalar_index
def _get_halo_values(self, ptype, identifiers, fields, f=None):
"""
Get field values for halo data containers.
"""
# if a file is already open, don't open it again
filename = None if f is None else f.filename
data = defaultdict(lambda: np.empty(identifiers.size))
i_scalars = self._get_halo_file_indices(ptype, identifiers)
for i_scalar in np.unique(i_scalars):
# mask array to get field data for this halo
target = i_scalars == i_scalar
scalar_indices = identifiers - self._halo_index_start[ptype][i_scalar]
# only open file if it's not already open
my_f = (
f
if self.data_files[i_scalar].filename == filename
else h5py.File(self.data_files[i_scalar].filename, mode="r")
)
for field in fields:
data[field][target] = self._read_halo_particle_field(
my_f, ptype, field, scalar_indices[target]
)
if self.data_files[i_scalar].filename != filename:
my_f.close()
return data
def _identify_base_chunk(self, dobj):
pass
def _read_halo_particle_field(self, fh, ptype, field, indices):
return fh[field][indices]
def _read_particle_fields(self, fields, dobj, chunk=None):
if not fields:
return {}, []
fields_to_read, fields_to_generate = self._split_fields(fields)
if not fields_to_read:
return {}, fields_to_generate
fields_to_return = self.io._read_particle_selection(dobj, fields_to_read)
return fields_to_return, fields_to_generate
def _setup_data_io(self):
super()._setup_data_io()
if self.real_ds._instantiated_index is None:
self.real_ds.index
self.real_ds.index
# inherit some things from parent index
self._data_files = self.real_ds.index.data_files
self._total_particles = self.real_ds.index.total_particles
self._calculate_particle_index_starts()
[docs]
class HaloDataset(ParticleDataset):
"""
Base class for dataset accessing particles from halo catalogs.
"""
def __init__(self, ds, dataset_type):
self.real_ds = ds
for attr in [
"filename_template",
"file_count",
"particle_types_raw",
"particle_types",
"_periodicity",
]:
setattr(self, attr, getattr(self.real_ds, attr))
super().__init__(self.real_ds.parameter_filename, dataset_type)
[docs]
def print_key_parameters(self):
pass
def _set_derived_attrs(self):
pass
def _parse_parameter_file(self):
for attr in [
"cosmological_simulation",
"cosmology",
"current_redshift",
"current_time",
"dimensionality",
"domain_dimensions",
"domain_left_edge",
"domain_right_edge",
"domain_width",
"hubble_constant",
"omega_lambda",
"omega_matter",
"unique_identifier",
]:
setattr(self, attr, getattr(self.real_ds, attr))
[docs]
def set_code_units(self):
self._set_code_unit_attributes()
self.unit_registry = self.real_ds.unit_registry
def _set_code_unit_attributes(self):
for unit in ["length", "time", "mass", "velocity", "magnetic", "temperature"]:
my_unit = f"{unit}_unit"
setattr(self, my_unit, getattr(self.real_ds, my_unit, None))
def __str__(self):
return f"{self.real_ds}"
def _setup_classes(self):
self.objects = []
[docs]
class YTHaloDataset(HaloDataset):
"""
Dataset used for accessing member particles from YTHaloCatalogDatasets.
"""
_index_class = YTHaloParticleIndex
_file_class = YTHaloCatalogFile
_field_info_class = YTHaloCatalogHaloFieldInfo
def __init__(self, ds, dataset_type="ythalo"):
super().__init__(ds, dataset_type)
def _set_code_unit_attributes(self):
pass
@classmethod
def _is_valid(cls, filename: str, *args, **kwargs) -> bool:
# We don't ever want this to be loaded by yt.load.
return False
[docs]
class HaloContainer(YTSelectionContainer):
"""
Base class for data containers providing halo particles.
"""
_type_name = "halo"
_con_args = ("ptype", "particle_identifier")
_skip_add = True
_spatial = False
def __init__(self, ptype, particle_identifier, ds=None):
if ptype not in ds.particle_types_raw:
raise RuntimeError(
f'Possible halo types are {ds.particle_types_raw}, supplied "{ptype}".'
)
self.ptype = ptype
self._current_particle_type = ptype
super().__init__(ds, {})
self._set_identifiers(particle_identifier)
# Find the file that has the scalar values for this halo.
i_scalar = self.index._get_halo_file_indices(ptype, [self.particle_identifier])[
0
]
self.i_scalar = i_scalar
self.scalar_data_file = self.index.data_files[i_scalar]
# Data files containing particles belonging to this halo.
self.field_data_files = [self.index.data_files[i_scalar]]
# index within halo arrays that corresponds to this halo
self.scalar_index = self.index._get_halo_scalar_index(
ptype, self.particle_identifier
)
self._set_io_data()
self.particle_number = self._get_particle_number()
# starting and ending indices for each file containing particles
self._set_field_indices()
@cached_property
def mass(self):
return self[self.ptype, "particle_mass"][0]
@cached_property
def radius(self):
return self[self.ptype, "virial_radius"][0]
@cached_property
def position(self):
return self[self.ptype, "particle_position"][0]
@cached_property
def velocity(self):
return self[self.ptype, "particle_velocity"][0]
def _set_io_data(self):
halo_fields = self._get_member_fieldnames()
my_data = self.index._get_halo_values(
self.ptype, np.array([self.particle_identifier]), halo_fields
)
self._io_data = {field: np.int64(val[0]) for field, val in my_data.items()}
def __repr__(self):
return f"{self.ds}_{self.ptype}_{self.particle_identifier:09d}"
[docs]
class YTHaloCatalogHaloContainer(HaloContainer):
"""
Data container for accessing particles from a halo.
Create a data container to get member particles and individual
values from halos and subhalos. Halo mass, radius, position, and
velocity are set as attributes. Halo IDs are accessible
through the field, "member_ids". Other fields that are one
value per halo are accessible as normal. The field list for
halo objects can be seen in `ds.halos_field_list`.
Parameters
----------
ptype : string
The type of halo. Possible options can be found by
inspecting the value of ds.particle_types_raw.
particle_identifier : int
The halo id.
Examples
--------
>>> import yt
>>> ds = yt.load("tiny_fof_halos/DD0046/DD0046.0.h5")
>>> halo = ds.halo("halos", 0)
>>> print(halo.particle_identifier)
0
>>> print(halo.mass)
8724990744704.453 Msun
>>> print(halo.radius)
658.8140635766607 kpc
>>> print(halo.position)
[0.05496909 0.19451951 0.04056824] code_length
>>> print(halo.velocity)
[7034181.07118151 5323471.09102874 3234522.50495914] cm/s
>>> # particle ids for this halo
>>> print(halo["member_ids"])
[ 1248. 129. 128. 31999. 31969. 31933. 31934. 159. 31903. 31841. ...
2241. 2240. 2239. 2177. 2209. 2207. 2208.] dimensionless
"""
def _get_member_fieldnames(self):
return ["particle_number", "particle_index_start"]
def _get_particle_number(self):
return self._io_data["particle_number"]
def _set_field_indices(self):
self.field_data_start = [self._io_data["particle_index_start"]]
self.field_data_end = [self.field_data_start[0] + self.particle_number]
def _set_identifiers(self, particle_identifier):
self.particle_identifier = particle_identifier
self.group_identifier = self.particle_identifier