Source code for yt.frontends.stream.data_structures

import os
import time
import uuid
import weakref
from collections import UserDict
from functools import cached_property
from itertools import chain, product, repeat
from numbers import Number as numeric_type

import numpy as np
from more_itertools import always_iterable

from yt._typing import AxisOrder, FieldKey
from yt.data_objects.field_data import YTFieldData
from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch
from yt.data_objects.index_subobjects.octree_subset import OctreeSubset
from yt.data_objects.index_subobjects.stretched_grid import StretchedGrid
from yt.data_objects.index_subobjects.unstructured_mesh import (
    SemiStructuredMesh,
    UnstructuredMesh,
)
from yt.data_objects.static_output import Dataset, ParticleFile
from yt.data_objects.unions import MeshUnion, ParticleUnion
from yt.frontends.sph.data_structures import SPHParticleIndex
from yt.funcs import setdefaultattr
from yt.geometry.api import Geometry
from yt.geometry.geometry_handler import Index, YTDataChunk
from yt.geometry.grid_geometry_handler import GridIndex
from yt.geometry.oct_container import OctreeContainer
from yt.geometry.oct_geometry_handler import OctreeIndex
from yt.geometry.unstructured_mesh_handler import UnstructuredIndex
from yt.units import YTQuantity
from yt.utilities.io_handler import io_registry
from yt.utilities.lib.cykdtree import PyKDTree
from yt.utilities.lib.misc_utilities import (
    _obtain_coords_and_widths,
    get_box_grids_level,
)
from yt.utilities.lib.particle_kdtree_tools import (
    estimate_density,
    generate_smoothing_length,
)
from yt.utilities.logger import ytLogger as mylog

from .definitions import process_data, set_particle_types
from .fields import StreamFieldInfo


class StreamGrid(AMRGridPatch):
    """
    Class representing a single In-memory Grid instance.

    The grid has no backing file; its parent/child links are stored as grid
    ids and resolved through ``self.index.grids`` on access.
    """

    __slots__ = ["proc_num"]
    _id_offset = 0

    def __init__(self, id, index):
        """
        Create an in-memory grid with identifier *id* attached to *index*.

        The new grid starts with no parent (``_parent_id`` of -1), no
        children, and a placeholder ``Level`` of -1.
        """
        # Field parameters are handed to the grid later, on demand.
        AMRGridPatch.__init__(self, id, filename=None, index=index)
        self._children_ids = []
        self._parent_id = -1
        self.Level = -1

    def set_filename(self, filename):
        # In-memory grids have no file to point at: intentionally a no-op.
        pass

    @property
    def Parent(self):
        """The parent grid, or ``None`` if this grid has no parent."""
        parent_id = self._parent_id
        if parent_id == -1:
            return None
        return self.index.grids[parent_id - self._id_offset]

    @property
    def Children(self):
        """List of this grid's child grids, looked up by id on the index."""
        offset = self._id_offset
        return [self.index.grids[child_id - offset] for child_id in self._children_ids]
class StreamStretchedGrid(StretchedGrid):
    """
    In-memory grid whose cells may have non-uniform widths.

    The per-grid cell widths are fetched from ``index.grid_cell_widths``
    using this grid's id (shifted by ``_id_offset``).
    """

    _id_offset = 0

    def __init__(self, id, index):
        widths_for_grid = index.grid_cell_widths[id - self._id_offset]
        super().__init__(id, widths_for_grid, index=index)
        # No relatives yet: parent id -1 means "no parent".
        self._children_ids = []
        self._parent_id = -1
        # Placeholder level until the owning index assigns the real one.
        self.Level = -1

    @property
    def Parent(self):
        """The parent grid, or ``None`` if this grid has no parent."""
        if self._parent_id != -1:
            return self.index.grids[self._parent_id - self._id_offset]
        return None

    @property
    def Children(self):
        """List of this grid's child grids, looked up by id on the index."""
        offset = self._id_offset
        return [self.index.grids[child_id - offset] for child_id in self._children_ids]
class StreamHandler:
    """
    Hold the raw arrays and metadata that define an in-memory dataset.

    Grid edges are copied into NumPy arrays, ``parameters`` is shallow-copied,
    and every other argument is stored as given.  ``num_grids`` is taken from
    ``levels.size``.
    """

    def __init__(
        self,
        left_edges,
        right_edges,
        dimensions,
        levels,
        parent_ids,
        particle_count,
        processor_ids,
        fields,
        field_units,
        code_units,
        io=None,
        particle_types=None,
        periodicity=(True, True, True),
        *,
        cell_widths=None,
        parameters=None,
    ):
        # Geometry: edges are materialized as (copied) ndarrays.
        self.left_edges = np.array(left_edges)
        self.right_edges = np.array(right_edges)
        self.dimensions = dimensions
        # Hierarchy bookkeeping.
        self.levels = levels
        self.parent_ids = parent_ids
        self.particle_count = particle_count
        self.processor_ids = processor_ids
        self.num_grids = self.levels.size
        # Field storage, units and I/O.
        self.fields = fields
        self.field_units = field_units
        self.code_units = code_units
        self.io = io
        self.particle_types = {} if particle_types is None else particle_types
        self.periodicity = periodicity
        self.cell_widths = cell_widths
        # Copy so later changes to the caller's dict do not leak in.
        self.parameters = {} if parameters is None else parameters.copy()

    def get_fields(self):
        """Return ``all_fields`` of the stored field container."""
        return self.fields.all_fields

    def get_particle_type(self, field):
        """Return the particle flag recorded for *field*, or ``False`` if absent."""
        return self.particle_types.get(field, False)
class StreamHierarchy(GridIndex):
    """
    Grid index for in-memory (stream) AMR datasets.

    Grid edges, levels, parent links, particle counts and field data are all
    taken from the dataset's ``stream_handler`` instead of from disk.
    """

    grid = StreamGrid

    def __init__(self, ds, dataset_type=None):
        self.dataset_type = dataset_type
        self.float_type = "float64"
        self.dataset = weakref.proxy(ds)  # for _obtain_enzo
        self.stream_handler = ds.stream_handler
        self.directory = os.getcwd()
        GridIndex.__init__(self, ds, dataset_type)

    def _count_grids(self):
        self.num_grids = self.stream_handler.num_grids

    def _icoords_to_fcoords(self, icoords, ires, axes=None):
        """
        Convert integer cell coordinates to float coordinates and widths.

        Without stretched-grid cell widths this defers to the parent class.
        With them, coordinates and widths are computed per axis from the
        first grid's cell widths (``self.grid_cell_widths[0]``).
        """
        if self.grid_cell_widths is None:
            return super()._icoords_to_fcoords(icoords, ires, axes)
        if axes is None:
            axes = [0, 1, 2]
        coords = np.empty(icoords.shape, dtype="f8")
        cell_widths = np.empty(icoords.shape, dtype="f8")
        for i, ax in enumerate(axes):
            coords[:, i], cell_widths[:, i] = _obtain_coords_and_widths(
                icoords[:, i],
                ires,
                self.grid_cell_widths[0][ax],
                self.ds.domain_left_edge[ax].d,
            )
        return coords, cell_widths

    def _parse_index(self):
        handler = self.stream_handler
        self.grid_dimensions = handler.dimensions
        self.grid_left_edge[:] = handler.left_edges
        self.grid_right_edge[:] = handler.right_edges
        self.grid_levels[:] = handler.levels
        self.min_level = self.grid_levels.min()
        self.grid_procs = handler.processor_ids
        self.grid_particle_count[:] = handler.particle_count
        if handler.cell_widths is not None:
            self.grid_cell_widths = handler.cell_widths[:]
            self.grid = StreamStretchedGrid
        else:
            self.grid_cell_widths = None
        mylog.debug("Copying reverse tree")
        self.grids = []
        # Grid ids are 0-based and double as indices into ``self.grids``;
        # parent ids from the handler use the same numbering.
        for gid in range(self.num_grids):
            self.grids.append(self.grid(gid, self))
            self.grids[gid].Level = self.grid_levels[gid, 0]
        parent_ids = handler.parent_ids
        if parent_ids is not None:
            # A negative parent id marks a root grid.
            for gid, pid in enumerate(parent_ids.tolist()):
                if pid >= 0:
                    self.grids[gid]._parent_id = pid
                    self.grids[pid]._children_ids.append(self.grids[gid].id)
        else:
            mylog.debug("Reconstructing parent-child relationships")
            self._reconstruct_parent_child()
        self.max_level = self.grid_levels.max()
        mylog.debug("Preparing grids")
        temp_grids = np.empty(self.num_grids, dtype="object")
        for i, grid in enumerate(self.grids):
            if (i % 1e4) == 0:
                mylog.debug("Prepared % 7i / % 7i grids", i, self.num_grids)
            grid.filename = None
            grid._prepare_grid()
            grid._setup_dx()
            grid.proc_num = self.grid_procs[i]
            temp_grids[i] = grid
        self.grids = temp_grids
        mylog.debug("Prepared")

    def _reconstruct_parent_child(self):
        """
        Rebuild parent/child links from grid geometry and levels.

        Children of each grid are found with ``get_box_grids_level`` at the
        next level; parents are then assigned from those child lists and
        stored back on ``stream_handler.parent_ids`` (-1 for root grids).
        """
        mask = np.empty(len(self.grids), dtype="int32")
        mylog.debug("First pass; identifying child grids")
        for i, grid in enumerate(self.grids):
            get_box_grids_level(
                self.grid_left_edge[i, :],
                self.grid_right_edge[i, :],
                self.grid_levels[i].item() + 1,
                self.grid_left_edge,
                self.grid_right_edge,
                self.grid_levels,
                mask,
            )
            grid._children_ids = np.where(mask.astype("bool"))[0]
        mylog.debug("Second pass; identifying parents")
        self.stream_handler.parent_ids = (
            np.zeros(self.stream_handler.num_grids, "int64") - 1
        )
        for i, grid in enumerate(self.grids):
            for child in grid.Children:
                child._parent_id = i
                # _id_offset = 0, so the child id is also its index.
                self.stream_handler.parent_ids[child.id] = i

    def _initialize_grid_arrays(self):
        GridIndex._initialize_grid_arrays(self)
        self.grid_procs = np.zeros((self.num_grids, 1), "int32")

    def _detect_output_fields(self):
        # NOTE: Because particle unions add to the actual field list, without
        # having the keys in the field list itself, we need to double check
        # here.
        fl = set(self.stream_handler.get_fields())
        fl.update(set(getattr(self, "field_list", [])))
        self.field_list = list(fl)

    def _populate_grid_objects(self):
        for g in self.grids:
            g._setup_dx()
        self.max_level = self.grid_levels.max()

    def _setup_data_io(self):
        if self.stream_handler.io is not None:
            self.io = self.stream_handler.io
        else:
            self.io = io_registry[self.dataset_type](self.ds)

    def _reset_particle_count(self):
        self.grid_particle_count[:] = self.stream_handler.particle_count
        for i, grid in enumerate(self.grids):
            grid.NumberOfParticles = self.grid_particle_count[i, 0]

    def update_data(self, data):
        """
        Update the stream data with a new data dict.

        Fields that already exist are replaced; fields that do not exist yet
        are added.  Fields already in the stream but absent from *data* are
        left alone.

        Parameters
        ----------
        data : sequence of dict
            One dict of field data per grid, indexed like ``self.grids``.
            Particle types are inferred from ``data[0]`` only.
        """
        handler = self.stream_handler
        handler.particle_types.update(set_particle_types(data[0]))
        self.ds._find_particle_types()
        for i, grid in enumerate(self.grids):
            field_units, gdata, number_of_particles = process_data(data[i])
            handler.particle_count[i] = number_of_particles
            handler.field_units.update(field_units)
            for field in gdata:
                # Drop any cached copy so the new values are used.
                grid.field_data.pop(field, None)
                handler.fields[grid.id][field] = gdata[field]
        self._reset_particle_count()
        # We only want to create a superset of fields here.  Filter in place
        # (slice assignment keeps the same list object) rather than calling
        # ``remove`` while iterating, which skips the entry after each removal.
        field_list = self.ds.field_list
        field_list[:] = [field for field in field_list if field[0] != "all"]
        self._detect_output_fields()
        self.ds.create_field_info()
        mylog.debug("Creating Particle Union 'all'")
        pu = ParticleUnion("all", list(self.ds.particle_types_raw))
        self.ds.add_particle_union(pu)
        self.ds.particle_types = tuple(set(self.ds.particle_types))
class StreamDataset(Dataset):
    """
    Dataset backed entirely by in-memory data held in a ``stream_handler``.

    Domain geometry, time, cosmology, parameters and code units are all read
    from attributes of the stream handler rather than from a file.
    """

    _index_class: type[Index] = StreamHierarchy
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream"

    def __init__(
        self,
        stream_handler,
        storage_filename=None,
        geometry="cartesian",
        unit_system="cgs",
        default_species_fields=None,
        *,
        axis_order: AxisOrder | None = None,
    ):
        self.fluid_types += ("stream",)
        self.geometry = Geometry(geometry)
        self.stream_handler = stream_handler
        self._find_particle_types()
        # Random unique name: used as the _cached_datasets key and passed to
        # Dataset.__init__ in place of a filename.
        name = f"InMemoryParameterFile_{uuid.uuid4().hex}"
        from yt.data_objects.static_output import _cached_datasets

        if geometry == "spectral_cube":
            # mimic FITSDataset specific interface to allow testing with
            # fake, in memory data
            spectral_defaults = (
                ("lon_axis", 0),
                ("lat_axis", 1),
                ("spec_axis", 2),
                ("lon_name", "X"),
                ("lat_name", "Y"),
                ("spec_name", "z"),
                ("spec_unit", ""),
            )
            for attr_name, default_value in spectral_defaults:
                setdefaultattr(self, attr_name, default_value)
            setdefaultattr(
                self,
                "pixel2spec",
                lambda pixel_value: self.arr(pixel_value, self.spec_unit),  # type: ignore [attr-defined]
            )
            setdefaultattr(
                self,
                "spec2pixel",
                lambda spec_value: self.arr(spec_value, "code_length"),
            )

        _cached_datasets[name] = self
        Dataset.__init__(
            self,
            name,
            self._dataset_type,
            unit_system=unit_system,
            default_species_fields=default_species_fields,
            axis_order=axis_order,
        )

    @property
    def filename(self):
        """The ``name`` attribute of the stream handler."""
        return self.stream_handler.name

    @cached_property
    def unique_identifier(self) -> str:
        """String form of the creation timestamp stored at parse time."""
        return str(self.parameters["CurrentTimeIdentifier"])

    def _parse_parameter_file(self):
        handler = self.stream_handler
        self.parameters["CurrentTimeIdentifier"] = time.time()
        self.domain_left_edge = handler.domain_left_edge.copy()
        self.domain_right_edge = handler.domain_right_edge.copy()
        self.refine_by = handler.refine_by
        self.dimensionality = handler.dimensionality
        self._periodicity = handler.periodicity
        self.domain_dimensions = handler.domain_dimensions
        self.current_time = handler.simulation_time
        self.gamma = 5.0 / 3.0
        # Placeholder values; anything in handler.parameters overrides them.
        self.parameters.update(
            {
                "EOSType": -1,
                "CosmologyHubbleConstantNow": 1.0,
                "CosmologyCurrentRedshift": 1.0,
                "HydroMethod": -1,
            }
        )
        self.parameters.update(handler.parameters)
        if not handler.cosmology_simulation:
            self.current_redshift = 0.0
            self.omega_lambda = 0.0
            self.omega_matter = 0.0
            self.hubble_constant = 0.0
            self.cosmological_simulation = 0
            return
        self.cosmological_simulation = 1
        self.current_redshift = handler.current_redshift
        self.omega_lambda = handler.omega_lambda
        self.omega_matter = handler.omega_matter
        self.hubble_constant = handler.hubble_constant

    def _set_units(self):
        self.field_units = self.stream_handler.field_units

    def _set_code_unit_attributes(self):
        """
        Turn the handler's ``code_units`` into unit-quantity attributes.

        Each entry may be a unit string, a bare number (interpreted in the
        matching CGS unit), a ``YTQuantity``, or a ``(value, unit)`` tuple.
        A ``"code_magnetic"`` entry is skipped, and a missing
        ``magnetic_unit`` is derived from mass, time and length units.
        """
        attrs = (
            "length_unit",
            "mass_unit",
            "time_unit",
            "velocity_unit",
            "magnetic_unit",
        )
        cgs_units = ("cm", "g", "s", "cm/s", "gauss")
        unit_specs = zip(self.stream_handler.code_units, attrs, cgs_units, strict=True)
        for unit, attr, cgs_unit in unit_specs:
            if isinstance(unit, str) and unit == "code_magnetic":
                # No explicit magnetic unit: it is derived below instead.
                continue
            if isinstance(unit, str):
                quantity = self.quan(1.0, unit)
            elif isinstance(unit, numeric_type):
                quantity = self.quan(unit, cgs_unit)
            elif isinstance(unit, YTQuantity):
                quantity = unit
            elif isinstance(unit, tuple):
                quantity = self.quan(unit[0], unit[1])
            else:
                raise RuntimeError(f"{attr} ({unit}) is invalid.")
            setattr(self, attr, quantity)
        if not hasattr(self, "magnetic_unit"):
            self.magnetic_unit = np.sqrt(
                4 * np.pi * self.mass_unit / (self.time_unit**2 * self.length_unit)
            )

    @classmethod
    def _is_valid(cls, filename: str, *args, **kwargs) -> bool:
        # Stream datasets are never auto-detected from a file.
        return False

    @property
    def _skip_cache(self):
        return True

    def _find_particle_types(self):
        # A field key's first element is its type; keep types flagged as particle.
        found = {
            key[0]
            for key, is_particle in self.stream_handler.particle_types.items()
            if is_particle
        }
        self.particle_types = tuple(found)
        self.particle_types_raw = self.particle_types
class StreamDictFieldHandler(UserDict):
    """
    Mapping from chunk/grid key to a per-chunk dict of field data.

    ``all_fields`` reports the union of the field keys found in every
    per-chunk dict plus any keys registered in ``_additional_fields``.
    """

    _additional_fields: tuple[FieldKey, ...] = ()

    @property
    def all_fields(self):
        """Deduplicated list of every known field key (order unspecified)."""
        per_chunk_keys = list(
            {key for chunk_fields in self.values() for key in chunk_fields.keys()}
        )
        combined = list(self._additional_fields) + per_chunk_keys
        return list(set(combined))
class StreamParticleIndex(SPHParticleIndex):
    """Particle index for in-memory (stream) particle datasets."""

    def __init__(self, ds, dataset_type=None):
        self.stream_handler = ds.stream_handler
        super().__init__(ds, dataset_type)

    def _setup_data_io(self):
        if self.stream_handler.io is not None:
            self.io = self.stream_handler.io
        else:
            self.io = io_registry[self.dataset_type](self.ds)

    def update_data(self, data):
        """
        Update the stream data with a new data dict.

        Fields that already exist are replaced; fields that do not exist yet
        are added.  Fields already in the stream but absent from *data* are
        left alone.  Non-tuple keys are treated as fields of the ``"io"``
        particle type.
        """
        ds = self.ds
        handler = ds.stream_handler
        field_units, data, _ = process_data(data)
        # Normalise keys to (ptype, fname) tuples.
        pdata = {}
        for key, value in data.items():
            if isinstance(key, tuple):
                field = key
            else:
                field = ("io", key)
                mylog.debug("Reassigning '%s' to '%s'", key, field)
            pdata[field] = value
        data = pdata  # Drop reference count
        # Update particle types
        handler.particle_types.update(set_particle_types(data))
        ds._find_particle_types()
        # Update fields
        handler.field_units.update(field_units)
        fields = handler.fields
        for field in data:
            if field not in fields._additional_fields:
                fields._additional_fields += (field,)
        fields["stream_file"].update(data)
        # Drop union-type fields before re-detecting.  Filter in place (same
        # list object) instead of calling ``remove`` while iterating, which
        # skips the entry following each removed one.
        field_list = ds.field_list
        field_list[:] = [f for f in field_list if f[0] not in ("all", "nbody")]
        self._detect_output_fields()
        ds.create_field_info()
class StreamParticleFile(ParticleFile):
    """In-memory stand-in for a particle file; adds nothing to ``ParticleFile``."""
class StreamParticlesDataset(StreamDataset):
    """
    In-memory particle dataset.

    On construction, a particle type that already has both ``density`` and
    ``smoothing_length`` fields is recorded as the SPH particle type; more
    than one such type is rejected.
    """

    _index_class = StreamParticleIndex
    _file_class = StreamParticleFile
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream_particles"
    file_count = 1
    filename_template = "stream_file"
    _proj_type = "particle_proj"

    def __init__(
        self,
        stream_handler,
        storage_filename=None,
        geometry="cartesian",
        unit_system="cgs",
        default_species_fields=None,
        *,
        axis_order: AxisOrder | None = None,
    ):
        super().__init__(
            stream_handler,
            storage_filename=storage_filename,
            geometry=geometry,
            unit_system=unit_system,
            default_species_fields=default_species_fields,
            axis_order=axis_order,
        )
        known_fields = set(stream_handler.fields["stream_file"].keys())
        sph_ptypes = [
            ptype
            for ptype in self.particle_types
            if (ptype, "density") in known_fields
            and (ptype, "smoothing_length") in known_fields
        ]
        if len(sph_ptypes) > 1:
            raise ValueError("Multiple SPH particle types are currently not supported!")
        if sph_ptypes:
            self._sph_ptypes = tuple(sph_ptypes)

    def add_sph_fields(self, n_neighbors=32, kernel="cubic", sph_ptype="io"):
        """Add SPH fields for the specified particle type.

        For a particle type with "particle_position" and "particle_mass" already
        defined, this method adds the "smoothing_length" and "density" fields.
        "smoothing_length" is computed as the distance to the nth nearest
        neighbor. "density" is computed as the SPH (gather) smoothed mass. The
        SPH fields are added only if they don't already exist.

        Parameters
        ----------
        n_neighbors : int
            The number of neighbors to use in smoothing length computation.
        kernel : str
            The kernel function to use in density estimation.
        sph_ptype : str
            The SPH particle type. Each dataset has one sph_ptype only. This
            method will overwrite existing sph_ptype of the dataset.

        """
        mylog.info("Generating SPH fields")

        # Work in code units throughout.
        l_unit = "code_length"
        m_unit = "code_mass"
        d_unit = "code_mass / code_length**3"

        ad = self.all_data()
        pos = ad[sph_ptype, "particle_position"].to(l_unit).d
        mass = ad[sph_ptype, "particle_mass"].to(m_unit).d

        kdtree = PyKDTree(
            pos.astype("float64"),
            left_edge=self.domain_left_edge.to_value(l_unit),
            right_edge=self.domain_right_edge.to_value(l_unit),
            periodic=self.periodicity,
            leafsize=2 * int(n_neighbors),
        )
        # Inputs are permuted into tree order with kdtree.idx; indexing a
        # result with argsort(kdtree.idx) restores the original particle order.
        order = np.argsort(kdtree.idx)

        def _needs_generation(fname):
            if (sph_ptype, fname) in self.derived_field_list:
                mylog.info(
                    "Field ('%s','%s') already exists. Skipping", sph_ptype, fname
                )
                return False
            mylog.info("Generating field ('%s','%s')", sph_ptype, fname)
            return True

        data = {}

        if _needs_generation("smoothing_length"):
            hsml = generate_smoothing_length(pos[kdtree.idx], kdtree, n_neighbors)[
                order
            ]
            data[sph_ptype, "smoothing_length"] = (hsml, l_unit)
        else:
            hsml = ad[sph_ptype, "smoothing_length"].to(l_unit).d

        if _needs_generation("density"):
            dens = estimate_density(
                pos[kdtree.idx],
                mass[kdtree.idx],
                hsml[kdtree.idx],
                kdtree,
                kernel_name=kernel,
            )[order]
            data[sph_ptype, "density"] = (dens, d_unit)

        self._sph_ptypes = (sph_ptype,)
        self.index.update_data(data)
        self.num_neighbors = n_neighbors
_cis = np.fromiter( chain.from_iterable(product([0, 1], [0, 1], [0, 1])), dtype=np.int64, count=8 * 3 ) _cis.shape = (8, 3)
def hexahedral_connectivity(xgrid, ygrid, zgrid):
    r"""Define the vertex coordinates and connectivity of a hexahedral mesh
    for a semistructured grid.

    Used to specify the connectivity and coordinates parameters used in
    :func:`~yt.frontends.stream.data_structures.load_hexahedral_mesh`.

    Parameters
    ----------
    xgrid : array_like
       x-coordinates of boundaries of the hexahedral cells. Should be a
       one-dimensional array.
    ygrid : array_like
       y-coordinates of boundaries of the hexahedral cells. Should be a
       one-dimensional array.
    zgrid : array_like
       z-coordinates of boundaries of the hexahedral cells. Should be a
       one-dimensional array.

    Returns
    -------
    coords : array_like
        The list of (x,y,z) coordinates of the vertices of the mesh.
        Is of size (M,3) where M is the number of vertices.
    connectivity : array_like
        For each hexahedron h in the mesh, gives the indices (rows of
        *coords*) of h's eight corner vertices. Is of size (N,8), where N
        is the number of hexahedra.

    Examples
    --------

    >>> xgrid = np.array([-1, -0.25, 0, 0.25, 1])
    >>> coords, conn = hexahedral_connectivity(xgrid, xgrid, xgrid)
    >>> coords
    array([[-1.  , -1.  , -1.  ],
           [-1.  , -1.  , -0.25],
           [-1.  , -1.  ,  0.  ],
           ...,
           [ 1.  ,  1.  ,  0.  ],
           [ 1.  ,  1.  ,  0.25],
           [ 1.  ,  1.  ,  1.  ]])

    >>> conn
    array([[  0,   1,   5,   6,  25,  26,  30,  31],
           [  1,   2,   6,   7,  26,  27,  31,  32],
           [  2,   3,   7,   8,  27,  28,  32,  33],
           ...,
           [ 91,  92,  96,  97, 116, 117, 121, 122],
           [ 92,  93,  97,  98, 117, 118, 122, 123],
           [ 93,  94,  98,  99, 118, 119, 123, 124]])

    """
    # Accept any array_like (e.g. Python lists), as documented; broadcasting
    # with ``[:, None, None]`` below requires ndarrays.
    xgrid = np.asarray(xgrid)
    ygrid = np.asarray(ygrid)
    zgrid = np.asarray(zgrid)
    nx, ny, nz = len(xgrid), len(ygrid), len(zgrid)

    coords = np.zeros((nx, ny, nz, 3), dtype="float64", order="C")
    coords[:, :, :, 0] = xgrid[:, None, None]
    coords[:, :, :, 1] = ygrid[None, :, None]
    coords[:, :, :, 2] = zgrid[None, None, :]
    # Vertex (i, j, k) lands on row (i * ny + j) * nz + k.
    coords = coords.reshape(nx * ny * nz, 3)

    # Corner offsets of a unit cell in itertools.product order; same values
    # as the module-level ``_cis``.
    corners = np.array(list(product((0, 1), repeat=3)), dtype=np.int64)
    # (i, j, k) index of each cell's lowest corner, in C order.
    cells = np.moveaxis(np.indices((nx - 1, ny - 1, nz - 1)), 0, -1).reshape(
        (nx - 1) * (ny - 1) * (nz - 1), 3
    )
    # (ncells, 8, 3): vertex indices of every corner of every cell.
    off = corners + cells[:, np.newaxis]
    connectivity = np.array(
        ((off[:, :, 0] * ny) + off[:, :, 1]) * nz + off[:, :, 2], order="C"
    )
    return coords, connectivity
class StreamHexahedralMesh(SemiStructuredMesh):
    """In-memory semi-structured mesh of hexahedral cells."""

    # Eight vertex indices per cell (one per hexahedron corner).
    _connectivity_length = 8
    # Connectivity indices are already 0-based.
    _index_offset = 0
class StreamHexahedralHierarchy(UnstructuredIndex):
    """Index over a single in-memory hexahedral mesh."""

    def __init__(self, ds, dataset_type=None):
        self.stream_handler = ds.stream_handler
        super().__init__(ds, dataset_type)

    def _initialize_mesh(self):
        # The mesh geometry is popped out of the handler's field container
        # (presumably so it is not later reported as a field).
        handler_fields = self.stream_handler.fields
        coords = handler_fields.pop("coordinates")
        connect = handler_fields.pop("connectivity")
        mesh = StreamHexahedralMesh(0, self.index_filename, connect, coords, self)
        self.meshes = [mesh]

    def _setup_data_io(self):
        custom_io = self.stream_handler.io
        if custom_io is None:
            self.io = io_registry[self.dataset_type](self.ds)
        else:
            self.io = custom_io

    def _detect_output_fields(self):
        unique_fields = set(self.stream_handler.get_fields())
        self.field_list = list(unique_fields)
class StreamHexahedralDataset(StreamDataset):
    """In-memory dataset defined on a semi-structured hexahedral mesh."""

    _index_class = StreamHexahedralHierarchy
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream_hexahedral"
class StreamOctreeSubset(OctreeSubset):
    """
    Octree subset backed by an in-memory octree container.

    Parameters
    ----------
    base_region : data object
        Region the subset is drawn from; its ``selector`` is stored as
        ``base_selector``.
    ds : Dataset
        Owning dataset.
    oct_handler : OctreeContainer
        In-memory octree used for cell indexing and filling.
    num_zones : int, optional
        Zone count per oct (presumably per oct edge — confirm); default 2.
    num_ghost_zones : int, optional
        Ghost zones to include; default 0.  When positive, a twin subset
        without ghost zones is built and stored as ``_base_grid``.
    """

    domain_id = 1
    _domain_offset = 1

    def __init__(self, base_region, ds, oct_handler, num_zones=2, num_ghost_zones=0):
        self._num_zones = num_zones
        self.field_data = YTFieldData()
        self.field_parameters = {}
        self.ds = ds
        self._oct_handler = oct_handler
        self._last_mask = None
        self._last_selector_id = None
        self._current_particle_type = "io"
        self._current_fluid_type = self.ds.default_fluid_type
        self.base_region = base_region
        self.base_selector = base_region.selector
        self._num_ghost_zones = num_ghost_zones
        if num_ghost_zones > 0:
            if not all(ds.periodicity):
                mylog.warning(
                    "Ghost zones will wrongly assume the domain to be periodic."
                )
            base_grid = StreamOctreeSubset(base_region, ds, oct_handler, num_zones)
            self._base_grid = base_grid

    @property
    def oct_handler(self):
        return self._oct_handler

    def retrieve_ghost_zones(self, ngz, fields, smoothed=False):
        """
        Return a copy of this subset carrying *ngz* ghost zones.

        The last subset built is cached and reused only when it has the same
        ghost-zone count, so a request for a different *ngz* is not answered
        with a stale subset.  *fields* and *smoothed* are accepted for
        interface compatibility but not used.
        """
        cached = getattr(self, "_subset_with_gz", None)
        if cached is not None and cached._num_ghost_zones == ngz:
            mylog.debug("Reusing previous subset with ghost zone.")
            return cached
        new_subset = StreamOctreeSubset(
            self.base_region,
            self.ds,
            self.oct_handler,
            self._num_zones,
            num_ghost_zones=ngz,
        )
        self._subset_with_gz = new_subset
        return new_subset

    def _fill_no_ghostzones(self, content, dest, selector, offset):
        """Fill *dest* with the selected cells of *content*; return fill_level's count."""
        oct_handler = self.oct_handler
        cell_count = selector.count_oct_cells(oct_handler, self.domain_id)
        levels, cell_inds, file_inds = oct_handler.file_index_octs(
            selector, self.domain_id, cell_count
        )
        # Every oct is filled as level 0.
        levels[:] = 0
        dest.update((field, np.empty(cell_count, dtype="float64")) for field in content)
        # Make references ...
        count = oct_handler.fill_level(
            0, levels, cell_inds, file_inds, dest, content, offset
        )
        return count

    def _fill_with_ghostzones(self, content, dest, selector, offset):
        """
        Ghost-zone variant of the fill; returns fill_level's count.

        The ghost-zone index arrays are computed once and cached on
        ``_ghost_zone_cache``.
        """
        oct_handler = self.oct_handler
        ndim = self.ds.dimensionality
        # ``self.nz`` comes from OctreeSubset (presumably zones per oct edge
        # including ghost zones — confirm).
        cell_count = selector.count_octs(oct_handler, self.domain_id) * self.nz**ndim

        gz_cache = getattr(self, "_ghost_zone_cache", None)
        if gz_cache:
            levels, cell_inds, file_inds, domains = gz_cache
        else:
            gz_cache = (
                levels,
                cell_inds,
                file_inds,
                domains,
            ) = oct_handler.file_index_octs_with_ghost_zones(
                selector, self.domain_id, cell_count
            )
            self._ghost_zone_cache = gz_cache
        levels[:] = 0
        dest.update((field, np.empty(cell_count, dtype="float64")) for field in content)
        # Make references ...
        count = oct_handler.fill_level(
            0, levels, cell_inds, file_inds, dest, content, offset
        )
        return count

    def fill(self, content, dest, selector, offset):
        """Dispatch to the ghost-zone or plain fill and return its count."""
        if self._num_ghost_zones == 0:
            return self._fill_no_ghostzones(content, dest, selector, offset)
        else:
            return self._fill_with_ghostzones(content, dest, selector, offset)
class StreamOctreeHandler(OctreeIndex):
    """
    Octree index for in-memory octree datasets.

    The whole domain is exposed as a single ``StreamOctreeSubset``; the octree
    itself is loaded from the dataset's ``octree_mask`` and related settings.
    """

    def __init__(self, ds, dataset_type=None):
        self.stream_handler = ds.stream_handler
        self.dataset_type = dataset_type
        super().__init__(ds, dataset_type)

    def _setup_data_io(self):
        custom_io = self.stream_handler.io
        if custom_io is None:
            self.io = io_registry[self.dataset_type](self.ds)
        else:
            self.io = custom_io

    def _initialize_oct_handler(self):
        ds = self.ds
        header = dict(
            dims=ds.domain_dimensions // ds.num_zones,
            left_edge=ds.domain_left_edge,
            right_edge=ds.domain_right_edge,
            octree=ds.octree_mask,
            num_zones=ds.num_zones,
            partial_coverage=ds.partial_coverage,
        )
        self.oct_handler = OctreeContainer.load_octree(header)
        # The dataset's max level is only known once the octree is loaded.
        self.ds.max_level = self.oct_handler.max_level

    def _identify_base_chunk(self, dobj):
        if getattr(dobj, "_chunk_info", None) is None:
            region = getattr(dobj, "base_region", dobj)
            dobj._chunk_info = [
                StreamOctreeSubset(
                    region,
                    self.dataset,
                    self.oct_handler,
                    self.ds.num_zones,
                )
            ]
        # _chunk_all yields exactly one chunk.
        dobj._current_chunk = next(self._chunk_all(dobj))

    def _chunk_all(self, dobj):
        subsets = getattr(dobj._current_chunk, "objs", dobj._chunk_info)
        yield YTDataChunk(dobj, "all", subsets, None)

    def _chunk_spatial(self, dobj, ngz, sort=None, preload_fields=None):
        subsets = getattr(dobj._current_chunk, "objs", dobj._chunk_info)
        # This is where we will perform cutting of the Octree and
        # load-balancing. That may require a specialized selector object to
        # cut based on some space-filling curve index.
        for subset in subsets:
            chunk_obj = (
                subset.retrieve_ghost_zones(ngz, [], smoothed=True) if ngz > 0 else subset
            )
            yield YTDataChunk(dobj, "spatial", [chunk_obj])

    def _chunk_io(self, dobj, cache=True, local_only=False):
        subsets = getattr(dobj._current_chunk, "objs", dobj._chunk_info)
        for subset in subsets:
            yield YTDataChunk(dobj, "io", [subset], None, cache=cache)

    def _setup_classes(self):
        super()._setup_classes(self._get_data_reader_dict())

    def _detect_output_fields(self):
        # NOTE: Because particle unions add to the actual field list, without
        # having the keys in the field list itself, we need to double check
        # here.
        detected = set(self.stream_handler.get_fields())
        detected.update(set(getattr(self, "field_list", [])))
        self.field_list = list(detected)
class StreamOctreeDataset(StreamDataset):
    """
    In-memory dataset whose cells are organised in an octree.

    ``max_level``/``min_level`` are initialised from the stream handler's
    ``levels``; the octree index may later overwrite ``max_level``.
    """

    _index_class = StreamOctreeHandler
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream_octree"

    levelmax = None

    def __init__(
        self,
        stream_handler,
        storage_filename=None,
        geometry="cartesian",
        unit_system="cgs",
        default_species_fields=None,
        *,
        axis_order: AxisOrder | None = None,
    ):
        # ``axis_order`` is forwarded like in the other stream datasets; it was
        # previously not accepted here at all.
        super().__init__(
            stream_handler,
            storage_filename,
            geometry,
            unit_system,
            default_species_fields=default_species_fields,
            axis_order=axis_order,
        )
        # Set up levelmax
        self.max_level = stream_handler.levels.max()
        self.min_level = stream_handler.levels.min()
class StreamUnstructuredMesh(UnstructuredMesh):
    """In-memory unstructured mesh; element size comes from its connectivity."""

    _index_offset = 0

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Vertices per element = number of columns in the connectivity array.
        n_vertices_per_element = self.connectivity_indices.shape[1]
        self._connectivity_length = n_vertices_per_element
class StreamUnstructuredIndex(UnstructuredIndex):
    """
    Index over one or more in-memory unstructured meshes sharing one set of
    vertex coordinates, plus a ``mesh_union`` spanning all of them.
    """

    def __init__(self, ds, dataset_type=None):
        self.stream_handler = ds.stream_handler
        super().__init__(ds, dataset_type)

    def _initialize_mesh(self):
        handler_fields = self.stream_handler.fields
        coords = handler_fields.pop("coordinates")
        connectivities = always_iterable(handler_fields.pop("connectivity"))
        # One mesh per connectivity array; every mesh uses the same coords.
        self.meshes = [
            StreamUnstructuredMesh(mesh_id, self.index_filename, conn, mesh_coords, self)
            for mesh_id, (conn, mesh_coords) in enumerate(
                zip(connectivities, repeat(coords))
            )
        ]
        self.mesh_union = MeshUnion("mesh_union", self.meshes)

    def _setup_data_io(self):
        custom_io = self.stream_handler.io
        if custom_io is None:
            self.io = io_registry[self.dataset_type](self.ds)
        else:
            self.io = custom_io

    def _detect_output_fields(self):
        unique_fields = list(set(self.stream_handler.get_fields()))
        # Every field name is also exposed under the "all" union type.
        bare_names = list({fname for _ftype, fname in unique_fields})
        self.field_list = unique_fields + [("all", fname) for fname in bare_names]
class StreamUnstructuredMeshDataset(StreamDataset):
    """In-memory dataset defined on one or more unstructured meshes."""

    _index_class = StreamUnstructuredIndex
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream_unstructured"

    def _find_particle_types(self):
        # Intentionally a no-op: particle types are not discovered here.
        return None