import os
import time
import uuid
import weakref
from collections import UserDict
from functools import cached_property
from itertools import chain, product, repeat
from numbers import Number as numeric_type
import numpy as np
from more_itertools import always_iterable
from yt._typing import AxisOrder, FieldKey
from yt.data_objects.field_data import YTFieldData
from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch
from yt.data_objects.index_subobjects.octree_subset import OctreeSubset
from yt.data_objects.index_subobjects.stretched_grid import StretchedGrid
from yt.data_objects.index_subobjects.unstructured_mesh import (
    SemiStructuredMesh,
    UnstructuredMesh,
)
from yt.data_objects.static_output import Dataset, ParticleFile
from yt.data_objects.unions import MeshUnion, ParticleUnion
from yt.frontends.sph.data_structures import SPHParticleIndex
from yt.funcs import setdefaultattr
from yt.geometry.api import Geometry
from yt.geometry.geometry_handler import Index, YTDataChunk
from yt.geometry.grid_geometry_handler import GridIndex
from yt.geometry.oct_container import OctreeContainer
from yt.geometry.oct_geometry_handler import OctreeIndex
from yt.geometry.unstructured_mesh_handler import UnstructuredIndex
from yt.units import YTQuantity
from yt.utilities.io_handler import io_registry
from yt.utilities.lib.cykdtree import PyKDTree
from yt.utilities.lib.misc_utilities import (
    _obtain_coords_and_widths,
    get_box_grids_level,
)
from yt.utilities.lib.particle_kdtree_tools import (
    estimate_density,
    generate_smoothing_length,
)
from yt.utilities.logger import ytLogger as mylog
from .definitions import process_data, set_particle_types
from .fields import StreamFieldInfo
[docs]
class StreamGrid(AMRGridPatch):
    """
    A single in-memory grid patch.

    Parent/child links are stored as integer grid ids and resolved through
    ``self.index.grids`` whenever ``Parent`` or ``Children`` is read.
    """

    __slots__ = ["proc_num"]
    _id_offset = 0

    def __init__(self, id, index):
        """
        Build the grid numbered *id* that belongs to *index*.

        The grid has no backing file, so ``filename=None`` is handed to the
        base class.
        """
        # All of the field parameters will be passed to us as needed.
        AMRGridPatch.__init__(self, id, filename=None, index=index)
        # -1 marks "no parent"; see ``Parent``.
        self._parent_id = -1
        self._children_ids = []
        self.Level = -1

    def set_filename(self, filename):
        """Accept and ignore *filename*."""
        return None

    @property
    def Parent(self):
        """The parent grid, or ``None`` when ``_parent_id`` is -1."""
        parent_id = self._parent_id
        if parent_id != -1:
            return self.index.grids[parent_id - self._id_offset]
        return None

    @property
    def Children(self):
        """The child grids, looked up one by one from ``_children_ids``."""
        children = []
        for child_id in self._children_ids:
            children.append(self.index.grids[child_id - self._id_offset])
        return children
[docs]
class StreamStretchedGrid(StretchedGrid):
    """
    In-memory grid patch whose cell widths come from
    ``index.grid_cell_widths``.
    """

    _id_offset = 0

    def __init__(self, id, index):
        """Build grid *id*, taking its cell widths from *index*."""
        widths = index.grid_cell_widths[id - self._id_offset]
        super().__init__(id, widths, index=index)
        # -1 marks "no parent"; see ``Parent``.
        self._parent_id = -1
        self._children_ids = []
        self.Level = -1

    @property
    def Parent(self):
        """The parent grid, or ``None`` when ``_parent_id`` is -1."""
        parent_id = self._parent_id
        if parent_id != -1:
            return self.index.grids[parent_id - self._id_offset]
        return None

    @property
    def Children(self):
        """The child grids, looked up one by one from ``_children_ids``."""
        children = []
        for child_id in self._children_ids:
            children.append(self.index.grids[child_id - self._id_offset])
        return children
[docs]
class StreamHandler:
    """
    Plain container for the arrays and metadata describing in-memory data.

    The constructor stores its arguments as attributes of the same name.
    ``left_edges`` and ``right_edges`` are converted with ``np.array``;
    ``parameters`` is shallow-copied; everything else is stored as passed.
    """

    def __init__(
        self,
        left_edges,
        right_edges,
        dimensions,
        levels,
        parent_ids,
        particle_count,
        processor_ids,
        fields,
        field_units,
        code_units,
        io=None,
        particle_types=None,
        periodicity=(True, True, True),
        *,
        cell_widths=None,
        parameters=None,
    ):
        """
        Parameters
        ----------
        left_edges, right_edges : array_like
            Stored as new arrays created with ``np.array``.
        dimensions, levels, parent_ids, particle_count, processor_ids
            Per-grid data, stored as passed.  ``num_grids`` is the number of
            elements in *levels*.
        fields
            Field container; ``get_fields`` reads its ``all_fields``
            attribute.
        field_units, code_units
            Stored as passed.
        io : optional
            Stored as passed (default ``None``).
        particle_types : dict, optional
            Mapping queried by ``get_particle_type``; a fresh empty dict is
            used when omitted.
        periodicity : tuple, optional
            Stored as passed.
        cell_widths : optional, keyword-only
            Stored as passed.
        parameters : dict, optional, keyword-only
            Shallow-copied; an empty dict is used when omitted.
        """
        self.left_edges = np.array(left_edges)
        self.right_edges = np.array(right_edges)
        self.dimensions = dimensions
        self.levels = levels
        self.parent_ids = parent_ids
        self.particle_count = particle_count
        self.processor_ids = processor_ids
        # np.size works for ndarrays and for plain sequences alike, whereas
        # ``levels.size`` would fail on a list.
        self.num_grids = np.size(levels)
        self.fields = fields
        self.field_units = field_units
        self.code_units = code_units
        self.io = io
        # Never share a mutable default between instances.
        self.particle_types = {} if particle_types is None else particle_types
        self.periodicity = periodicity
        self.cell_widths = cell_widths
        # Copy so that later edits do not leak back into the caller's dict.
        self.parameters = {} if parameters is None else parameters.copy()

    def get_fields(self):
        """Return ``self.fields.all_fields``."""
        return self.fields.all_fields

    def get_particle_type(self, field):
        """Return the entry stored for *field* in ``particle_types``, or ``False``."""
        return self.particle_types.get(field, False)
 
[docs]
class StreamHierarchy(GridIndex):
    """Grid index populated from the arrays held by ``ds.stream_handler``."""

    grid = StreamGrid

    def __init__(self, ds, dataset_type=None):
        """Record the stream handler of *ds* and delegate to ``GridIndex``."""
        self.dataset_type = dataset_type
        self.float_type = "float64"
        self.dataset = weakref.proxy(ds)  # for _obtain_enzo
        self.stream_handler = ds.stream_handler
        self.directory = os.getcwd()
        GridIndex.__init__(self, ds, dataset_type)

    def _count_grids(self):
        """Copy the grid count from the stream handler."""
        self.num_grids = self.stream_handler.num_grids

    def _icoords_to_fcoords(self, icoords, ires, axes=None):
        """
        Convert integer cell coordinates to (positions, widths).

        When no per-grid cell widths are present the base-class
        implementation is used; otherwise the conversion goes through
        ``_obtain_coords_and_widths`` one axis at a time.
        """
        if self.grid_cell_widths is None:
            return super()._icoords_to_fcoords(icoords, ires, axes)
        if axes is None:
            axes = [0, 1, 2]
        coords = np.empty(icoords.shape, dtype="f8")
        cell_widths = np.empty(icoords.shape, dtype="f8")
        for i, ax in enumerate(axes):
            # NOTE(review): only the first grid's widths
            # (``grid_cell_widths[0]``) are consulted here - confirm that a
            # single stretched grid is the only supported layout.
            coords[:, i], cell_widths[:, i] = _obtain_coords_and_widths(
                icoords[:, i],
                ires,
                self.grid_cell_widths[0][ax],
                self.ds.domain_left_edge[ax].d,
            )
        return coords, cell_widths

    def _parse_index(self):
        """
        Fill the index arrays from the stream handler, instantiate the grid
        objects, wire up parent/child links and prepare every grid.
        """
        handler = self.stream_handler
        self.grid_dimensions = handler.dimensions
        self.grid_left_edge[:] = handler.left_edges
        self.grid_right_edge[:] = handler.right_edges
        self.grid_levels[:] = handler.levels
        self.min_level = self.grid_levels.min()
        self.grid_procs = handler.processor_ids
        self.grid_particle_count[:] = handler.particle_count
        if handler.cell_widths is not None:
            self.grid_cell_widths = handler.cell_widths[:]
            # Stretched data: switch the grid class for this instance.
            self.grid = StreamStretchedGrid
        else:
            self.grid_cell_widths = None
        mylog.debug("Copying reverse tree")
        self.grids = []
        # Grid ids are 0-indexed (``_id_offset`` is 0 on both stream grid
        # classes), so the id doubles as the position in ``self.grids``.
        for grid_id in range(self.num_grids):
            grid = self.grid(grid_id, self)
            grid.Level = self.grid_levels[grid_id, 0]
            self.grids.append(grid)
        parent_ids = handler.parent_ids
        if parent_ids is not None:
            reverse_tree = parent_ids.tolist()
            # Initial setup:
            for gid, pid in enumerate(reverse_tree):
                if pid >= 0:
                    self.grids[gid]._parent_id = pid
                    self.grids[pid]._children_ids.append(self.grids[gid].id)
        else:
            mylog.debug("Reconstructing parent-child relationships")
            self._reconstruct_parent_child()
        self.max_level = self.grid_levels.max()
        mylog.debug("Preparing grids")
        temp_grids = np.empty(self.num_grids, dtype="object")
        for i, grid in enumerate(self.grids):
            if i % 10_000 == 0:
                mylog.debug("Prepared % 7i / % 7i grids", i, self.num_grids)
            grid.filename = None
            grid._prepare_grid()
            grid._setup_dx()
            grid.proc_num = self.grid_procs[i]
            temp_grids[i] = grid
        self.grids = temp_grids
        mylog.debug("Prepared")

    def _reconstruct_parent_child(self):
        """
        Derive parent/child links when the handler has no ``parent_ids``.

        The first pass asks ``get_box_grids_level`` for the grids one level
        below each grid and stores them as that grid's children; the second
        pass records the matching parent id on every child and in
        ``stream_handler.parent_ids``.
        """
        mask = np.empty(len(self.grids), dtype="int32")
        mylog.debug("First pass; identifying child grids")
        for i, grid in enumerate(self.grids):
            get_box_grids_level(
                self.grid_left_edge[i, :],
                self.grid_right_edge[i, :],
                self.grid_levels[i].item() + 1,
                self.grid_left_edge,
                self.grid_right_edge,
                self.grid_levels,
                mask,
            )
            ids = np.where(mask.astype("bool"))
            grid._children_ids = ids[0]  # where is a tuple
        mylog.debug("Second pass; identifying parents")
        self.stream_handler.parent_ids = (
            np.zeros(self.stream_handler.num_grids, "int64") - 1
        )
        for i, grid in enumerate(self.grids):  # Second pass
            for child in grid.Children:
                child._parent_id = i
                # _id_offset = 0
                self.stream_handler.parent_ids[child.id] = i

    def _initialize_grid_arrays(self):
        """Allocate the base-class arrays plus ``grid_procs``."""
        GridIndex._initialize_grid_arrays(self)
        self.grid_procs = np.zeros((self.num_grids, 1), "int32")

    def _detect_output_fields(self):
        """Set ``field_list`` to the handler's fields merged with any existing list."""
        # NOTE: Because particle unions add to the actual field list, without
        # having the keys in the field list itself, we need to double check
        # here.
        fl = set(self.stream_handler.get_fields())
        fl.update(set(getattr(self, "field_list", [])))
        self.field_list = list(fl)

    def _populate_grid_objects(self):
        """Run ``_setup_dx`` on every grid and refresh ``max_level``."""
        for g in self.grids:
            g._setup_dx()
        self.max_level = self.grid_levels.max()

    def _setup_data_io(self):
        """Use the handler's IO object if it has one, else build one from the registry."""
        if self.stream_handler.io is not None:
            self.io = self.stream_handler.io
        else:
            self.io = io_registry[self.dataset_type](self.ds)

    def _reset_particle_count(self):
        """Copy particle counts from the handler onto the index and each grid."""
        self.grid_particle_count[:] = self.stream_handler.particle_count
        for i, grid in enumerate(self.grids):
            grid.NumberOfParticles = self.grid_particle_count[i, 0]

    def update_data(self, data):
        """
        Update the stream data with a new data dict. If fields already exist,
        they will be replaced, but if they do not, they will be added. Fields
        already in the stream but not part of the data dict will be left
        alone.

        *data* is indexed per grid (``data[i]`` belongs to ``self.grids[i]``).
        """
        particle_types = set_particle_types(data[0])
        self.stream_handler.particle_types.update(particle_types)
        self.ds._find_particle_types()
        for i, grid in enumerate(self.grids):
            field_units, gdata, number_of_particles = process_data(data[i])
            self.stream_handler.particle_count[i] = number_of_particles
            self.stream_handler.field_units.update(field_units)
            for field in gdata:
                # Drop any cached copy so the new values are read next time.
                grid.field_data.pop(field, None)
                self.stream_handler.fields[grid.id][field] = gdata[field]
        self._reset_particle_count()
        # We only want to create a superset of fields here.
        # Filter in place: calling ``remove`` while iterating the same list
        # skips the element that follows each removal.
        self.ds.field_list[:] = [
            field for field in self.ds.field_list if field[0] != "all"
        ]
        self._detect_output_fields()
        self.ds.create_field_info()
        mylog.debug("Creating Particle Union 'all'")
        pu = ParticleUnion("all", list(self.ds.particle_types_raw))
        self.ds.add_particle_union(pu)
        self.ds.particle_types = tuple(set(self.ds.particle_types))
 
[docs]
class StreamDataset(Dataset):
    """Dataset whose content is supplied in memory through a stream handler."""

    _index_class: type[Index] = StreamHierarchy
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream"

    def __init__(
        self,
        stream_handler,
        storage_filename=None,
        geometry="cartesian",
        unit_system="cgs",
        default_species_fields=None,
        *,
        axis_order: AxisOrder | None = None,
    ):
        """
        Attach *stream_handler*, register the dataset under a generated
        ``InMemoryParameterFile_<hex>`` name and run the base initialiser.
        """
        self.fluid_types += ("stream",)
        self.geometry = Geometry(geometry)
        self.stream_handler = stream_handler
        self._find_particle_types()
        name = f"InMemoryParameterFile_{uuid.uuid4().hex}"
        from yt.data_objects.static_output import _cached_datasets

        if geometry == "spectral_cube":
            # mimic FITSDataset specific interface to allow testing with
            # fake, in memory data
            plain_defaults = (
                ("lon_axis", 0),
                ("lat_axis", 1),
                ("spec_axis", 2),
                ("lon_name", "X"),
                ("lat_name", "Y"),
                ("spec_name", "z"),
                ("spec_unit", ""),
            )
            for attr_name, fallback in plain_defaults:
                setdefaultattr(self, attr_name, fallback)
            setdefaultattr(
                self,
                "pixel2spec",
                lambda pixel_value: self.arr(pixel_value, self.spec_unit),  # type: ignore [attr-defined]
            )
            setdefaultattr(
                self,
                "spec2pixel",
                lambda spec_value: self.arr(spec_value, "code_length"),
            )
        _cached_datasets[name] = self
        Dataset.__init__(
            self,
            name,
            self._dataset_type,
            unit_system=unit_system,
            default_species_fields=default_species_fields,
            axis_order=axis_order,
        )

    @property
    def filename(self):
        """The ``name`` attribute of the stream handler."""
        return self.stream_handler.name

    @cached_property
    def unique_identifier(self) -> str:
        """String form of the ``CurrentTimeIdentifier`` parameter."""
        return str(self.parameters["CurrentTimeIdentifier"])

    def _parse_parameter_file(self):
        """Copy domain, time and cosmology information from the stream handler."""
        handler = self.stream_handler
        self.parameters["CurrentTimeIdentifier"] = time.time()
        self.domain_left_edge = handler.domain_left_edge.copy()
        self.domain_right_edge = handler.domain_right_edge.copy()
        self.refine_by = handler.refine_by
        self.dimensionality = handler.dimensionality
        self._periodicity = handler.periodicity
        self.domain_dimensions = handler.domain_dimensions
        self.current_time = handler.simulation_time
        self.gamma = 5.0 / 3.0
        # Fixed placeholder values; handler-supplied parameters applied
        # afterwards take precedence.
        self.parameters.update(
            {
                "EOSType": -1,
                "CosmologyHubbleConstantNow": 1.0,
                "CosmologyCurrentRedshift": 1.0,
                "HydroMethod": -1,
            }
        )
        self.parameters.update(handler.parameters)
        if not handler.cosmology_simulation:
            self.current_redshift = 0.0
            self.omega_lambda = 0.0
            self.omega_matter = 0.0
            self.hubble_constant = 0.0
            self.cosmological_simulation = 0
        else:
            self.cosmological_simulation = 1
            self.current_redshift = handler.current_redshift
            self.omega_lambda = handler.omega_lambda
            self.omega_matter = handler.omega_matter
            self.hubble_constant = handler.hubble_constant

    def _set_units(self):
        """Expose the handler's field units on the dataset."""
        self.field_units = self.stream_handler.field_units

    def _set_code_unit_attributes(self):
        """
        Turn the handler's ``code_units`` into the five ``*_unit`` attributes.

        Each entry may be a unit string, a number (taken to be in the
        matching CGS unit), a ``YTQuantity`` or a ``(value, unit)`` tuple.
        The string ``"code_magnetic"`` is skipped, and ``magnetic_unit`` is
        derived from the mass, time and length units if it is still unset
        after the loop.
        """
        attr_names = (
            "length_unit",
            "mass_unit",
            "time_unit",
            "velocity_unit",
            "magnetic_unit",
        )
        cgs_names = ("cm", "g", "s", "cm/s", "gauss")
        supplied = self.stream_handler.code_units
        for unit, attr, cgs_unit in zip(supplied, attr_names, cgs_names, strict=True):
            if isinstance(unit, str):
                if unit == "code_magnetic":
                    # If no magnetic unit was explicitly specified
                    # we skip it now and take care of it at the bottom
                    continue
                quantity = self.quan(1.0, unit)
            elif isinstance(unit, numeric_type):
                quantity = self.quan(unit, cgs_unit)
            elif isinstance(unit, YTQuantity):
                quantity = unit
            elif isinstance(unit, tuple):
                value, unit_name = unit[0], unit[1]
                quantity = self.quan(value, unit_name)
            else:
                raise RuntimeError(f"{attr} ({unit}) is invalid.")
            setattr(self, attr, quantity)
        if not hasattr(self, "magnetic_unit"):
            denominator = self.time_unit**2 * self.length_unit
            self.magnetic_unit = np.sqrt(4 * np.pi * self.mass_unit / denominator)

    @classmethod
    def _is_valid(cls, filename: str, *args, **kwargs) -> bool:
        """Always ``False``."""
        return False

    @property
    def _skip_cache(self):
        """Always ``True``."""
        return True

    def _find_particle_types(self):
        """
        Collect the first element of every key in the handler's
        ``particle_types`` whose value is truthy.
        """
        flagged = {
            key[0]
            for key, is_particle in self.stream_handler.particle_types.items()
            if is_particle
        }
        self.particle_types = tuple(flagged)
        self.particle_types_raw = self.particle_types
[docs]
class StreamDictFieldHandler(UserDict):
    """
    Dict of per-entry field mappings with an aggregated ``all_fields`` view.
    """

    _additional_fields: tuple[FieldKey, ...] = ()

    @property
    def all_fields(self):
        """
        De-duplicated list of ``_additional_fields`` plus every key found in
        the stored mappings.
        """
        stored_keys = set(chain.from_iterable(entry.keys() for entry in self.values()))
        combined = [*self._additional_fields, *stored_keys]
        return list(set(combined))
[docs]
class StreamParticleIndex(SPHParticleIndex):
    """Particle index reading its data from ``ds.stream_handler``."""

    def __init__(self, ds, dataset_type=None):
        """Record the stream handler of *ds* and delegate to the base class."""
        self.stream_handler = ds.stream_handler
        super().__init__(ds, dataset_type)

    def _setup_data_io(self):
        """Use the handler's IO object if it has one, else build one from the registry."""
        if self.stream_handler.io is not None:
            self.io = self.stream_handler.io
        else:
            self.io = io_registry[self.dataset_type](self.ds)

    def update_data(self, data):
        """
        Update the stream data with a new data dict. If fields already exist,
        they will be replaced, but if they do not, they will be added. Fields
        already in the stream but not part of the data dict will be left
        alone.

        Keys that are not tuples are stored under the ``"io"`` particle type.
        """
        # Alias
        ds = self.ds
        handler = ds.stream_handler
        # Preprocess
        field_units, data, _ = process_data(data)
        pdata = {}
        for key, values in data.items():
            if isinstance(key, tuple):
                field = key
            else:
                field = ("io", key)
                mylog.debug("Reassigning '%s' to '%s'", key, field)
            pdata[field] = values
        data = pdata  # Drop reference count
        particle_types = set_particle_types(data)
        # Update particle types
        handler.particle_types.update(particle_types)
        ds._find_particle_types()
        # Update fields
        handler.field_units.update(field_units)
        fields = handler.fields
        for field in data:
            if field not in fields._additional_fields:
                fields._additional_fields += (field,)
        fields["stream_file"].update(data)
        # Update field list
        # Filter in place: calling ``remove`` while iterating the same list
        # skips the element that follows each removal.
        ds.field_list[:] = [
            field for field in ds.field_list if field[0] not in ("all", "nbody")
        ]
        self._detect_output_fields()
        ds.create_field_info()
 
[docs]
class StreamParticleFile(ParticleFile):
    """``ParticleFile`` subclass that adds nothing of its own."""

    pass 
[docs]
class StreamParticlesDataset(StreamDataset):
    """In-memory particle dataset, optionally carrying one SPH particle type."""

    _index_class = StreamParticleIndex
    _file_class = StreamParticleFile
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream_particles"
    file_count = 1
    filename_template = "stream_file"
    _proj_type = "particle_proj"

    def __init__(
        self,
        stream_handler,
        storage_filename=None,
        geometry="cartesian",
        unit_system="cgs",
        default_species_fields=None,
        *,
        axis_order: AxisOrder | None = None,
    ):
        """
        Initialise the dataset, then flag as SPH the particle type (at most
        one) that has both a "density" and a "smoothing_length" field.

        Raises
        ------
        ValueError
            If more than one particle type qualifies.
        """
        super().__init__(
            stream_handler,
            storage_filename=storage_filename,
            geometry=geometry,
            unit_system=unit_system,
            default_species_fields=default_species_fields,
            axis_order=axis_order,
        )
        available = list(stream_handler.fields["stream_file"].keys())
        sph_ptypes = [
            ptype
            for ptype in self.particle_types
            if (ptype, "density") in available
            and (ptype, "smoothing_length") in available
        ]
        n_sph = len(sph_ptypes)
        if n_sph > 1:
            raise ValueError("Multiple SPH particle types are currently not supported!")
        if n_sph == 1:
            self._sph_ptypes = tuple(sph_ptypes)

    def add_sph_fields(self, n_neighbors=32, kernel="cubic", sph_ptype="io"):
        """Add SPH fields for the specified particle type.

        For a particle type with "particle_position" and "particle_mass" already
        defined, this method adds the "smoothing_length" and "density" fields.
        "smoothing_length" is computed as the distance to the nth nearest
        neighbor. "density" is computed as the SPH (gather) smoothed mass. The
        SPH fields are added only if they don't already exist.

        Parameters
        ----------
        n_neighbors : int
            The number of neighbors to use in smoothing length computation.
        kernel : str
            The kernel function to use in density estimation.
        sph_ptype : str
            The SPH particle type. Each dataset has one sph_ptype only. This
            method will overwrite existing sph_ptype of the dataset.
        """
        mylog.info("Generating SPH fields")
        # Unify units
        l_unit = "code_length"
        m_unit = "code_mass"
        d_unit = "code_mass / code_length**3"
        # Read basic fields
        ad = self.all_data()
        pos = ad[sph_ptype, "particle_position"].to(l_unit).d
        mass = ad[sph_ptype, "particle_mass"].to(m_unit).d
        # Construct k-d tree
        kdtree = PyKDTree(
            pos.astype("float64"),
            left_edge=self.domain_left_edge.to_value(l_unit),
            right_edge=self.domain_right_edge.to_value(l_unit),
            periodic=self.periodicity,
            leafsize=2 * int(n_neighbors),
        )
        # ``order`` undoes the tree's reordering: results are computed on
        # arrays indexed by ``kdtree.idx`` and mapped back with it.
        order = np.argsort(kdtree.idx)

        def exists(fname):
            present = (sph_ptype, fname) in self.derived_field_list
            if present:
                mylog.info(
                    "Field ('%s','%s') already exists. Skipping", sph_ptype, fname
                )
            else:
                mylog.info("Generating field ('%s','%s')", sph_ptype, fname)
            return present

        data = {}
        # Add smoothing length field
        fname = "smoothing_length"
        if exists(fname):
            hsml = ad[sph_ptype, fname].to(l_unit).d
        else:
            tree_ordered_hsml = generate_smoothing_length(
                pos[kdtree.idx], kdtree, n_neighbors
            )
            hsml = tree_ordered_hsml[order]
            data[sph_ptype, "smoothing_length"] = (hsml, l_unit)
        # Add density field
        fname = "density"
        if not exists(fname):
            tree_ordered_dens = estimate_density(
                pos[kdtree.idx],
                mass[kdtree.idx],
                hsml[kdtree.idx],
                kdtree,
                kernel_name=kernel,
            )
            dens = tree_ordered_dens[order]
            data[sph_ptype, "density"] = (dens, d_unit)
        # Add fields
        self._sph_ptypes = (sph_ptype,)
        self.index.update_data(data)
        self.num_neighbors = n_neighbors
 
# Offsets of the 8 corners of a unit cell, one (i, j, k) triple per row, in
# the order produced by itertools.product.
_cis = np.fromiter(
    chain.from_iterable(product([0, 1], [0, 1], [0, 1])), dtype=np.int64, count=8 * 3
).reshape(8, 3)


def hexahedral_connectivity(xgrid, ygrid, zgrid):
    r"""Define the vertex coordinates and cell connectivity of a hexahedral
    mesh for a semistructured grid. Used to specify the connectivity and
    coordinates parameters used in
    :func:`~yt.frontends.stream.data_structures.load_hexahedral_mesh`.

    Parameters
    ----------
    xgrid : array_like
       x-coordinates of boundaries of the hexahedral cells. Should be
       one-dimensional.
    ygrid : array_like
       y-coordinates of boundaries of the hexahedral cells. Should be
       one-dimensional.
    zgrid : array_like
       z-coordinates of boundaries of the hexahedral cells. Should be
       one-dimensional.

    Returns
    -------
    coords : numpy.ndarray
        The (x,y,z) coordinates of the vertices of the mesh.
        Is of shape (M,3) where M is the number of vertices.
    connectivity : numpy.ndarray
        For each hexahedron h in the mesh, gives the row indices into
        ``coords`` of h's 8 corner vertices. Is of shape (N,8), where N is
        the number of hexahedra.

    Examples
    --------
    >>> xgrid = np.array([-1, -0.25, 0, 0.25, 1])
    >>> coords, conn = hexahedral_connectivity(xgrid, xgrid, xgrid)
    >>> coords
    array([[-1.  , -1.  , -1.  ],
           [-1.  , -1.  , -0.25],
           [-1.  , -1.  ,  0.  ],
           ...,
           [ 1.  ,  1.  ,  0.  ],
           [ 1.  ,  1.  ,  0.25],
           [ 1.  ,  1.  ,  1.  ]])
    >>> conn
    array([[  0,   1,   5,   6,  25,  26,  30,  31],
           [  1,   2,   6,   7,  26,  27,  31,  32],
           [  2,   3,   7,   8,  27,  28,  32,  33],
           ...,
           [ 91,  92,  96,  97, 116, 117, 121, 122],
           [ 92,  93,  97,  98, 117, 118, 122, 123],
           [ 93,  94,  98,  99, 118, 119, 123, 124]])
    """
    # Accept any array_like (lists, tuples, ...): the broadcasting below
    # needs real ndarrays.
    xgrid = np.asarray(xgrid)
    ygrid = np.asarray(ygrid)
    zgrid = np.asarray(zgrid)
    nx = len(xgrid)
    ny = len(ygrid)
    nz = len(zgrid)
    coords = np.zeros((nx, ny, nz, 3), dtype="float64", order="C")
    coords[:, :, :, 0] = xgrid[:, None, None]
    coords[:, :, :, 1] = ygrid[None, :, None]
    coords[:, :, :, 2] = zgrid[None, None, :]
    coords = coords.reshape(nx * ny * nz, 3)
    # One (i, j, k) triple per cell: the index of the cell's lowest corner.
    n_cells = (nx - 1) * (ny - 1) * (nz - 1)
    cycle = np.rollaxis(np.indices((nx - 1, ny - 1, nz - 1)), 0, 4)
    cycle = cycle.reshape(n_cells, 3)
    # Corner indices of every cell, shape (n_cells, 8, 3).
    off = _cis + cycle[:, np.newaxis]
    # Flatten (i, j, k) to the row index of the C-ordered ``coords`` array.
    connectivity = np.array(
        ((off[:, :, 0] * ny) + off[:, :, 1]) * nz + off[:, :, 2], order="C"
    )
    return coords, connectivity
[docs]
class StreamHexahedralMesh(SemiStructuredMesh):
    """Semi-structured mesh used by ``StreamHexahedralHierarchy``."""

    # Eight connectivity entries per element.
    _connectivity_length = 8
    # Connectivity indices are 0-based.
    _index_offset = 0 
[docs]
class StreamHexahedralHierarchy(UnstructuredIndex):
    """Unstructured index holding one ``StreamHexahedralMesh``."""

    def __init__(self, ds, dataset_type=None):
        """Record the stream handler of *ds* and delegate to the base class."""
        self.stream_handler = ds.stream_handler
        super().__init__(ds, dataset_type)

    def _initialize_mesh(self):
        """
        Build the single mesh, removing the "coordinates" and "connectivity"
        entries from the handler's fields in the process.
        """
        handler_fields = self.stream_handler.fields
        coords = handler_fields.pop("coordinates")
        connect = handler_fields.pop("connectivity")
        mesh = StreamHexahedralMesh(0, self.index_filename, connect, coords, self)
        self.meshes = [mesh]

    def _setup_data_io(self):
        """Use the handler's IO object if it has one, else build one from the registry."""
        handler_io = self.stream_handler.io
        if handler_io is None:
            self.io = io_registry[self.dataset_type](self.ds)
        else:
            self.io = handler_io

    def _detect_output_fields(self):
        """Set ``field_list`` to the de-duplicated fields reported by the handler."""
        unique_fields = set(self.stream_handler.get_fields())
        self.field_list = list(unique_fields)
[docs]
class StreamHexahedralDataset(StreamDataset):
    """Stream dataset indexed by ``StreamHexahedralHierarchy``."""

    _index_class = StreamHexahedralHierarchy
    _field_info_class = StreamFieldInfo
    # Key under which the IO handler is looked up in ``io_registry``.
    _dataset_type = "stream_hexahedral" 
[docs]
class StreamOctreeSubset(OctreeSubset):
    """Octree subset over a base region of an in-memory octree dataset."""

    domain_id = 1
    _domain_offset = 1

    def __init__(self, base_region, ds, oct_handler, num_zones=2, num_ghost_zones=0):
        """
        Parameters
        ----------
        base_region
            Data object whose ``selector`` becomes ``base_selector``.
        ds
            Owning dataset.
        oct_handler
            Octree container returned by the ``oct_handler`` property.
        num_zones : int
            Stored as ``_num_zones``.
        num_ghost_zones : int
            When positive, a ghost-zone-free companion subset is also built
            and stored as ``_base_grid``.
        """
        self._num_zones = num_zones
        self.field_data = YTFieldData()
        self.field_parameters = {}
        self.ds = ds
        self._oct_handler = oct_handler
        self._last_mask = None
        self._last_selector_id = None
        self._current_particle_type = "io"
        self._current_fluid_type = self.ds.default_fluid_type
        self.base_region = base_region
        self.base_selector = base_region.selector
        self._num_ghost_zones = num_ghost_zones
        if num_ghost_zones > 0:
            if not all(ds.periodicity):
                mylog.warning(
                    "Ghost zones will wrongly assume the domain to be periodic."
                )
            base_grid = StreamOctreeSubset(base_region, ds, oct_handler, num_zones)
            self._base_grid = base_grid

    @property
    def oct_handler(self):
        """The octree container passed to the constructor."""
        return self._oct_handler

    def retrieve_ghost_zones(self, ngz, fields, smoothed=False):
        """
        Return a subset over the same region with *ngz* ghost zones.

        The subset is created on first use and cached on the instance;
        *fields* and *smoothed* are not used.
        """
        # NOTE(review): a cached subset is returned regardless of the *ngz*
        # requested on later calls - confirm callers never vary it.
        try:
            new_subset = self._subset_with_gz
            mylog.debug("Reusing previous subset with ghost zone.")
        except AttributeError:
            new_subset = StreamOctreeSubset(
                self.base_region,
                self.ds,
                self.oct_handler,
                self._num_zones,
                num_ghost_zones=ngz,
            )
            self._subset_with_gz = new_subset
        return new_subset

    def _fill_no_ghostzones(self, content, dest, selector, offset):
        """Fill *dest* from *content* for the selected cells; return the fill result."""
        # Here we get a copy of the file, which we skip through and read the
        # bits we want.
        oct_handler = self.oct_handler
        cell_count = selector.count_oct_cells(oct_handler, self.domain_id)
        levels, cell_inds, file_inds = oct_handler.file_index_octs(
            selector, self.domain_id, cell_count
        )
        # Everything is filled in a single ``fill_level(0, ...)`` pass below.
        levels[:] = 0
        dest.update((field, np.empty(cell_count, dtype="float64")) for field in content)
        # Make references ...
        count = oct_handler.fill_level(
            0, levels, cell_inds, file_inds, dest, content, offset
        )
        return count

    def _fill_with_ghostzones(self, content, dest, selector, offset):
        """
        Ghost-zone variant of ``_fill_no_ghostzones``.

        Returns the result of ``fill_level`` so that ``fill`` hands back the
        same kind of value on both paths.
        """
        oct_handler = self.oct_handler
        ndim = self.ds.dimensionality
        cell_count = selector.count_octs(oct_handler, self.domain_id) * self.nz**ndim
        # NOTE(review): the index cache is reused whatever *selector* is
        # passed later - confirm a subset only ever sees one selector.
        gz_cache = getattr(self, "_ghost_zone_cache", None)
        if not gz_cache:
            gz_cache = oct_handler.file_index_octs_with_ghost_zones(
                selector, self.domain_id, cell_count
            )
            self._ghost_zone_cache = gz_cache
        levels, cell_inds, file_inds, _domains = gz_cache
        # Everything is filled in a single ``fill_level(0, ...)`` pass below.
        levels[:] = 0
        dest.update((field, np.empty(cell_count, dtype="float64")) for field in content)
        # Make references ...
        count = oct_handler.fill_level(
            0, levels, cell_inds, file_inds, dest, content, offset
        )
        return count

    def fill(self, content, dest, selector, offset):
        """Dispatch to the ghost-zone or plain fill based on ``_num_ghost_zones``."""
        if self._num_ghost_zones == 0:
            return self._fill_no_ghostzones(content, dest, selector, offset)
        return self._fill_with_ghostzones(content, dest, selector, offset)
 
[docs]
class StreamOctreeHandler(OctreeIndex):
    """Octree index built from the attributes of an in-memory dataset."""

    def __init__(self, ds, dataset_type=None):
        """Record the handler and dataset type, then delegate to the base class."""
        self.dataset_type = dataset_type
        self.stream_handler = ds.stream_handler
        super().__init__(ds, dataset_type)

    def _setup_data_io(self):
        """Use the handler's IO object if it has one, else build one from the registry."""
        handler_io = self.stream_handler.io
        if handler_io is None:
            self.io = io_registry[self.dataset_type](self.ds)
        else:
            self.io = handler_io

    def _initialize_oct_handler(self):
        """Load the octree container from the dataset and publish its max level."""
        header = {
            "dims": self.ds.domain_dimensions // self.ds.num_zones,
            "left_edge": self.ds.domain_left_edge,
            "right_edge": self.ds.domain_right_edge,
            "octree": self.ds.octree_mask,
            "num_zones": self.ds.num_zones,
            "partial_coverage": self.ds.partial_coverage,
        }
        container = OctreeContainer.load_octree(header)
        self.oct_handler = container
        # We do now need to get the maximum level set, as well.
        self.ds.max_level = container.max_level

    def _identify_base_chunk(self, dobj):
        """Attach a one-subset ``_chunk_info`` to *dobj* if missing and set its current chunk."""
        if getattr(dobj, "_chunk_info", None) is None:
            region = getattr(dobj, "base_region", dobj)
            dobj._chunk_info = [
                StreamOctreeSubset(
                    region,
                    self.dataset,
                    self.oct_handler,
                    self.ds.num_zones,
                )
            ]
        all_chunks = list(self._chunk_all(dobj))
        dobj._current_chunk = all_chunks[0]

    def _chunk_all(self, dobj):
        """Yield one "all" chunk covering every subset of *dobj*."""
        subsets = getattr(dobj._current_chunk, "objs", dobj._chunk_info)
        yield YTDataChunk(dobj, "all", subsets, None)

    def _chunk_spatial(self, dobj, ngz, sort=None, preload_fields=None):
        """Yield one "spatial" chunk per subset, with ghost zones when *ngz* > 0."""
        subsets = getattr(dobj._current_chunk, "objs", dobj._chunk_info)
        # This is where we will perform cutting of the Octree and
        # load-balancing.  That may require a specialized selector object to
        # cut based on some space-filling curve index.
        for subset in subsets:
            chunk_obj = (
                subset.retrieve_ghost_zones(ngz, [], smoothed=True)
                if ngz > 0
                else subset
            )
            yield YTDataChunk(dobj, "spatial", [chunk_obj])

    def _chunk_io(self, dobj, cache=True, local_only=False):
        """Yield one "io" chunk per subset."""
        subsets = getattr(dobj._current_chunk, "objs", dobj._chunk_info)
        for subset in subsets:
            yield YTDataChunk(dobj, "io", [subset], None, cache=cache)

    def _setup_classes(self):
        """Set up data-object classes using the data reader dict."""
        reader_dict = self._get_data_reader_dict()
        super()._setup_classes(reader_dict)

    def _detect_output_fields(self):
        """Set ``field_list`` to the handler's fields merged with any existing list."""
        # NOTE: Because particle unions add to the actual field list, without
        # having the keys in the field list itself, we need to double check
        # here.
        known = set(self.stream_handler.get_fields())
        known.update(set(getattr(self, "field_list", [])))
        self.field_list = list(known)
[docs]
class StreamOctreeDataset(StreamDataset):
    """Stream dataset indexed by ``StreamOctreeHandler``."""

    _index_class = StreamOctreeHandler
    _field_info_class = StreamFieldInfo
    _dataset_type = "stream_octree"
    levelmax = None

    def __init__(
        self,
        stream_handler,
        storage_filename=None,
        geometry="cartesian",
        unit_system="cgs",
        default_species_fields=None,
        *,
        axis_order: AxisOrder | None = None,
    ):
        """
        Initialise the dataset and derive ``max_level`` / ``min_level`` from
        ``stream_handler.levels``.

        ``axis_order`` is keyword-only and forwarded to ``StreamDataset``,
        matching the signature of the other stream dataset classes.
        """
        super().__init__(
            stream_handler,
            storage_filename,
            geometry,
            unit_system,
            default_species_fields=default_species_fields,
            axis_order=axis_order,
        )
        # Record the level range found in the handler.
        self.max_level = stream_handler.levels.max()
        self.min_level = stream_handler.levels.min()
[docs]
class StreamUnstructuredMesh(UnstructuredMesh):
    """Unstructured mesh whose connectivity length is read from its own data."""

    _index_offset = 0

    def __init__(self, *args, **kwargs):
        """Forward all arguments, then size ``_connectivity_length`` from the indices."""
        super().__init__(*args, **kwargs)
        # Second axis of the connectivity array = entries per element.
        indices_shape = self.connectivity_indices.shape
        self._connectivity_length = indices_shape[1]
[docs]
class StreamUnstructuredIndex(UnstructuredIndex):
    """Unstructured index holding one mesh per connectivity array."""

    def __init__(self, ds, dataset_type=None):
        """Record the stream handler of *ds* and delegate to the base class."""
        self.stream_handler = ds.stream_handler
        super().__init__(ds, dataset_type)

    def _initialize_mesh(self):
        """
        Build one mesh per connectivity entry, all sharing the same
        coordinates, plus a ``MeshUnion`` over them.  The "coordinates" and
        "connectivity" entries are removed from the handler's fields.
        """
        handler_fields = self.stream_handler.fields
        coords = handler_fields.pop("coordinates")
        connect = always_iterable(handler_fields.pop("connectivity"))
        meshes = []
        for mesh_id, connectivity in enumerate(connect):
            meshes.append(
                StreamUnstructuredMesh(
                    mesh_id, self.index_filename, connectivity, coords, self
                )
            )
        self.meshes = meshes
        self.mesh_union = MeshUnion("mesh_union", self.meshes)

    def _setup_data_io(self):
        """Use the handler's IO object if it has one, else build one from the registry."""
        handler_io = self.stream_handler.io
        if handler_io is None:
            self.io = io_registry[self.dataset_type](self.ds)
        else:
            self.io = handler_io

    def _detect_output_fields(self):
        """
        Set ``field_list`` to the handler's fields, then append an
        ``("all", name)`` entry for every distinct field name.
        """
        self.field_list = list(set(self.stream_handler.get_fields()))
        distinct_names = list({name for _ftype, name in self.field_list})
        self.field_list += [("all", name) for name in distinct_names]
[docs]
class StreamUnstructuredMeshDataset(StreamDataset):
    """Stream dataset indexed by ``StreamUnstructuredIndex``."""

    _index_class = StreamUnstructuredIndex
    _field_info_class = StreamFieldInfo
    # Key under which the IO handler is looked up in ``io_registry``.
    _dataset_type = "stream_unstructured"
    def _find_particle_types(self):
        """Override that does nothing: no particle type attributes are set here."""
        pass