Source code for yt.geometry.geometry_handler

import abc
import os
import weakref
from typing import Optional

import numpy as np

from yt._maintenance.deprecation import issue_deprecation_warning
from yt.config import ytcfg
from yt.units._numpy_wrapper_functions import uconcatenate
from yt.units.yt_array import YTArray
from yt.utilities.exceptions import YTFieldNotFound
from yt.utilities.io_handler import io_registry
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.on_demand_imports import _h5py as h5py
from yt.utilities.parallel_tools.parallel_analysis_interface import (
    ParallelAnalysisInterface,
    parallel_root_only,
)


class Index(ParallelAnalysisInterface, abc.ABC):
    """The base index class"""

    _unsupported_objects: tuple[str, ...] = ()
    _index_properties: tuple[str, ...] = ()

    def __init__(self, ds, dataset_type):
        ParallelAnalysisInterface.__init__(self)
        self.dataset = weakref.proxy(ds)
        self.ds = self.dataset

        self._initialize_state_variables()

        mylog.debug("Initializing data storage.")
        self._initialize_data_storage()

        mylog.debug("Setting up domain geometry.")
        self._setup_geometry()

        mylog.debug("Initializing grid data IO.")
        self._setup_data_io()

        # Note that this falls under the "geometry" object since it's
        # potentially quite expensive, and should be done with the indexing.
        mylog.debug("Detecting fields.")
        self._detect_output_fields()

    @abc.abstractmethod
    def _detect_output_fields(self):
        pass

    def _icoords_to_fcoords(
        self,
        icoords: np.ndarray,
        ires: np.ndarray,
        axes: Optional[tuple[int, ...]] = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        # What's the use of raising NotImplementedError for this, when it's an
        # abstract base class? Well, only *some* of the subclasses have it --
        # and for those that *don't*, we should not be calling it -- and since
        # it's a semi-private method, it shouldn't be called outside of yt
        # machinery. So we shouldn't ever get here!
        raise NotImplementedError

    def _initialize_state_variables(self):
        self._parallel_locking = False
        self._data_file = None
        self._data_mode = None
        self.num_grids = None

    def _initialize_data_storage(self):
        if not ytcfg.get("yt", "serialize"):
            return
        fn = self.ds.storage_filename
        if fn is None:
            if os.path.isfile(
                os.path.join(self.directory, f"{self.ds.unique_identifier}.yt")
            ):
                fn = os.path.join(self.directory, f"{self.ds.unique_identifier}.yt")
            else:
                fn = os.path.join(self.directory, f"{self.dataset.basename}.yt")
        dir_to_check = os.path.dirname(fn)
        if dir_to_check == "":
            dir_to_check = "."
        # We have four options:
        #   Writeable, does not exist     : create, open as append
        #   Writeable, does exist         : open as append
        #   Not writeable, does not exist : do not attempt to open
        #   Not writeable, does exist     : open as read-only
        exists = os.path.isfile(fn)
        if not exists:
            writeable = os.access(dir_to_check, os.W_OK)
        else:
            writeable = os.access(fn, os.W_OK)
        writeable = writeable and not ytcfg.get("yt", "only_deserialize")
        # We now know whether the file can be opened, and in which mode.
        self.comm.barrier()
        if not writeable and not exists:
            return
        if writeable:
            try:
                if not exists:
                    self.__create_data_file(fn)
                self._data_mode = "a"
            except OSError:
                self._data_mode = None
                return
        else:
            self._data_mode = "r"

        self.__data_filename = fn
        self._data_file = h5py.File(fn, mode=self._data_mode)

    def __create_data_file(self, fn):
        # Note that this used to be parallel_root_only; it no longer is,
        # because we have better logic to decide who owns the file.
        f = h5py.File(fn, mode="a")
        f.close()

    def _setup_data_io(self):
        if getattr(self, "io", None) is not None:
            return
        self.io = io_registry[self.dataset_type](self.dataset)
    @parallel_root_only
    def save_data(
        self, array, node, name, set_attr=None, force=False, passthrough=False
    ):
        """
        Arbitrary numpy data will be saved to the region in the datafile
        described by *node* and *name*.  If the data file does not exist, no
        error is raised and the data is simply not saved.
        """
        if self._data_mode != "a":
            return
        try:
            node_loc = self._data_file[node]
            if name in node_loc and force:
                mylog.info("Overwriting node %s/%s", node, name)
                del self._data_file[node][name]
            elif name in node_loc and passthrough:
                return
        except Exception:
            pass
        myGroup = self._data_file["/"]
        for q in node.split("/"):
            if q:
                myGroup = myGroup.require_group(q)
        arr = myGroup.create_dataset(name, data=array)
        if set_attr is not None:
            for i, j in set_attr.items():
                arr.attrs[i] = j
        self._data_file.flush()
    def _reload_data_file(self, *args, **kwargs):
        if self._data_file is None:
            return
        self._data_file.close()
        del self._data_file
        self._data_file = h5py.File(self.__data_filename, mode=self._data_mode)
    def get_data(self, node, name):
        """
        Return the dataset with a given *name* located at *node* in the
        datafile.
        """
        if self._data_file is None:
            return None
        if node[0] != "/":
            node = f"/{node}"

        myGroup = self._data_file["/"]
        for group in node.split("/"):
            if group:
                if group not in myGroup:
                    return None
                myGroup = myGroup[group]
        if name not in myGroup:
            return None

        full_name = f"{node}/{name}"
        try:
            return self._data_file[full_name][:]
        except TypeError:
            return self._data_file[full_name]
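    # Illustrative sketch (not part of yt): how save_data and get_data pair
    # up when the .yt datafile is open in append mode. This helper method is
    # invented for illustration only.
    def _example_round_trip(self):
        self.save_data(
            np.arange(8.0),
            "/example/node",
            "my_array",
            set_attr={"units": "dimensionless"},
        )
        # get_data prepends a missing leading "/" and returns None when
        # either the datafile or the requested node/name is absent.
        return self.get_data("example/node", "my_array")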
    def _get_particle_type_counts(self):
        # this is implemented by subclasses
        raise NotImplementedError

    def _close_data_file(self):
        if self._data_file:
            self._data_file.close()
            del self._data_file
            self._data_file = None

    def _split_fields(self, fields):
        # This will split fields into either generated or read fields
        fields_to_read, fields_to_generate = [], []
        for ftype, fname in fields:
            if fname in self.field_list or (ftype, fname) in self.field_list:
                fields_to_read.append((ftype, fname))
            elif (
                fname in self.ds.derived_field_list
                or (ftype, fname) in self.ds.derived_field_list
            ):
                fields_to_generate.append((ftype, fname))
            else:
                raise YTFieldNotFound((ftype, fname), self.ds)
        return fields_to_read, fields_to_generate

    def _read_particle_fields(self, fields, dobj, chunk=None):
        if len(fields) == 0:
            return {}, []
        fields_to_read, fields_to_generate = self._split_fields(fields)
        if len(fields_to_read) == 0:
            return {}, fields_to_generate
        selector = dobj.selector
        if chunk is None:
            self._identify_base_chunk(dobj)
        chunks = self._chunk_io(dobj, cache=False)
        fields_to_return = self.io._read_particle_selection(
            chunks, selector, fields_to_read
        )
        return fields_to_return, fields_to_generate

    def _read_fluid_fields(self, fields, dobj, chunk=None):
        if len(fields) == 0:
            return {}, []
        fields_to_read, fields_to_generate = self._split_fields(fields)
        if len(fields_to_read) == 0:
            return {}, fields_to_generate
        selector = dobj.selector
        if chunk is None:
            self._identify_base_chunk(dobj)
            chunk_size = dobj.size
        else:
            chunk_size = chunk.data_size
        fields_to_return = self.io._read_fluid_selection(
            self._chunk_io(dobj), selector, fields_to_read, chunk_size
        )
        return fields_to_return, fields_to_generate

    def _chunk(self, dobj, chunking_style, ngz=0, **kwargs):
        # A chunk is either None or (grids, size)
        if dobj._current_chunk is None:
            self._identify_base_chunk(dobj)
        if ngz != 0 and chunking_style != "spatial":
            raise NotImplementedError
        if chunking_style == "all":
            return self._chunk_all(dobj, **kwargs)
        elif chunking_style == "spatial":
            return self._chunk_spatial(dobj, ngz, **kwargs)
        elif chunking_style == "io":
            return self._chunk_io(dobj, **kwargs)
        else:
            raise NotImplementedError
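# Illustrative sketch (not part of yt): a hypothetical minimal Index subclass,
# plus one way the _chunk dispatcher above is driven. Beyond the abstract
# _detect_output_fields, __init__ also expects a _setup_geometry hook, and
# _setup_data_io reads a dataset_type attribute; every name here is invented
# for illustration.
class _ExampleIndex(Index):
    def __init__(self, ds, dataset_type):
        self.dataset_type = dataset_type  # consumed by _setup_data_io
        super().__init__(ds, dataset_type)

    def _setup_geometry(self):
        pass  # a real index would build its spatial structures here

    def _detect_output_fields(self):
        # a real index would inspect the on-disk output; hardcoded here
        self.field_list = [("example", "density")]


def _example_io_chunks(index, dobj):
    # "all" yields one chunk over the whole selection, "spatial" honors ngz
    # ghost zones, and "io" batches by on-disk layout.
    for chunk in index._chunk(dobj, "io"):
        yield chunk.data_size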
def cacheable_property(func):
    # not quite equivalent to functools.cached_property:
    # this decorator allows caching to be disabled via a self._cache flag attribute
    n = f"_{func.__name__}"

    @property
    def cacheable_func(self):
        if self._cache and getattr(self, n, None) is not None:
            return getattr(self, n)
        if self.data_size is None:
            tr = self._accumulate_values(n[1:])
        else:
            tr = func(self)
        if self._cache:
            setattr(self, n, tr)
        return tr

    return cacheable_func
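# Illustrative sketch (not part of yt): what cacheable_property expects from
# its host class -- a _cache flag, a data_size attribute, and an
# _accumulate_values fallback for when data_size is not yet known. All names
# below except cacheable_property are invented for illustration.
class _ExampleChunk:
    def __init__(self):
        self._cache = True  # set False to recompute on every access
        self.data_size = 4

    def _accumulate_values(self, name):
        # only reached when data_size is None
        raise NotImplementedError

    @cacheable_property
    def squares(self):
        # computed once, then served from self._squares while _cache is True
        return np.arange(self.data_size) ** 2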
class YTDataChunk:
    def __init__(
        self,
        dobj,
        chunk_type,
        objs,
        data_size=None,
        field_type=None,
        cache=False,
        fast_index=None,
    ):
        self.dobj = dobj
        self.chunk_type = chunk_type
        self.objs = objs
        self.data_size = data_size
        self._field_type = field_type
        self._cache = cache
        self._fast_index = fast_index

    def _accumulate_values(self, method):
        # We call this generically. It's somewhat slower, since we're doing
        # costly getattr functions, but this allows us to generalize.
        mname = f"select_{method}"
        arrs = []
        for obj in self._fast_index or self.objs:
            f = getattr(obj, mname)
            arrs.append(f(self.dobj))
        if method == "dtcoords":
            arrs = [arr[0] for arr in arrs]
        elif method == "tcoords":
            arrs = [arr[1] for arr in arrs]
        arrs = uconcatenate(arrs)
        self.data_size = arrs.shape[0]
        return arrs

    @cacheable_property
    def fcoords(self):
        if self._fast_index is not None:
            ci = self._fast_index.select_fcoords(self.dobj.selector, self.data_size)
            ci = YTArray(ci, units="code_length", registry=self.dobj.ds.unit_registry)
            return ci
        ci = np.empty((self.data_size, 3), dtype="float64")
        ci = YTArray(ci, units="code_length", registry=self.dobj.ds.unit_registry)
        if self.data_size == 0:
            return ci
        ind = 0
        for obj in self._fast_index or self.objs:
            c = obj.select_fcoords(self.dobj)
            if c.shape[0] == 0:
                continue
            ci.d[ind : ind + c.shape[0], :] = c
            ind += c.shape[0]
        return ci

    @cacheable_property
    def icoords(self):
        if self._fast_index is not None:
            ci = self._fast_index.select_icoords(self.dobj.selector, self.data_size)
            return ci
        ci = np.empty((self.data_size, 3), dtype="int64")
        if self.data_size == 0:
            return ci
        ind = 0
        for obj in self._fast_index or self.objs:
            c = obj.select_icoords(self.dobj)
            if c.shape[0] == 0:
                continue
            ci[ind : ind + c.shape[0], :] = c
            ind += c.shape[0]
        return ci

    @cacheable_property
    def fwidth(self):
        if self._fast_index is not None:
            ci = self._fast_index.select_fwidth(self.dobj.selector, self.data_size)
            ci = YTArray(ci, units="code_length", registry=self.dobj.ds.unit_registry)
            return ci
        ci = np.empty((self.data_size, 3), dtype="float64")
        ci = YTArray(ci, units="code_length", registry=self.dobj.ds.unit_registry)
        if self.data_size == 0:
            return ci
        ind = 0
        for obj in self._fast_index or self.objs:
            c = obj.select_fwidth(self.dobj)
            if c.shape[0] == 0:
                continue
            ci.d[ind : ind + c.shape[0], :] = c
            ind += c.shape[0]
        return ci

    @cacheable_property
    def ires(self):
        if self._fast_index is not None:
            ci = self._fast_index.select_ires(self.dobj.selector, self.data_size)
            return ci
        ci = np.empty(self.data_size, dtype="int64")
        if self.data_size == 0:
            return ci
        ind = 0
        for obj in self._fast_index or self.objs:
            c = obj.select_ires(self.dobj)
            if c.size == 0:
                continue
            ci[ind : ind + c.size] = c
            ind += c.size
        return ci

    @cacheable_property
    def tcoords(self):
        self.dtcoords
        return self._tcoords

    @cacheable_property
    def dtcoords(self):
        ct = np.empty(self.data_size, dtype="float64")
        cdt = np.empty(self.data_size, dtype="float64")
        self._tcoords = ct  # Set this for tcoords
        if self.data_size == 0:
            return cdt
        ind = 0
        for obj in self._fast_index or self.objs:
            gdt, gt = obj.select_tcoords(self.dobj)
            if gt.size == 0:
                continue
            ct[ind : ind + gt.size] = gt
            cdt[ind : ind + gdt.size] = gdt
            ind += gt.size
        return cdt

    @cacheable_property
    def fcoords_vertex(self):
        nodes_per_elem = self.dobj.index.meshes[0].connectivity_indices.shape[1]
        dim = self.dobj.ds.dimensionality
        ci = np.empty((self.data_size, nodes_per_elem, dim), dtype="float64")
        ci = YTArray(ci, units="code_length", registry=self.dobj.ds.unit_registry)
        if self.data_size == 0:
            return ci
        ind = 0
        for obj in self.objs:
            c = obj.select_fcoords_vertex(self.dobj)
            if c.shape[0] == 0:
                continue
            ci.d[ind : ind + c.shape[0], :, :] = c
            ind += c.shape[0]
        return ci
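# Illustrative note (not part of yt): when a YTDataChunk is built with
# data_size=None, the first access to a cacheable_property such as .fcoords
# falls back to _accumulate_values("fcoords"), which both concatenates the
# per-object selections and fills in data_size as a side effect. Hypothetical
# usage, assuming `dobj` is an initialized data object and `objs` its grids:
def _example_chunk_coords(dobj, objs):
    chunk = YTDataChunk(dobj, "io", objs, data_size=None, cache=True)
    fc = chunk.fcoords  # triggers _accumulate_values, sets chunk.data_size
    return fc, chunk.data_size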
class ChunkDataCache:
    def __init__(self, base_iter, preload_fields, geometry_handler, max_length=256):
        # At some point, max_length should instead become a heuristic function,
        # potentially looking at estimated memory usage. Note that this never
        # initializes the iterator; it assumes the iterator is already created,
        # and it calls next() on it.
        self.base_iter = base_iter.__iter__()
        self.queue = []
        self.max_length = max_length
        self.preload_fields = preload_fields
        self.geometry_handler = geometry_handler
        self.cache = {}

    def __iter__(self):
        return self

    def __next__(self):
        if len(self.queue) == 0:
            for _ in range(self.max_length):
                try:
                    self.queue.append(next(self.base_iter))
                except StopIteration:
                    break
            # If it's still zero ...
            if len(self.queue) == 0:
                raise StopIteration
            chunk = YTDataChunk(None, "cache", self.queue, cache=False)
            self.cache = (
                self.geometry_handler.io._read_chunk_data(chunk, self.preload_fields)
                or {}
            )
        g = self.queue.pop(0)
        g._initialize_cache(self.cache.pop(g.id, {}))
        return g
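# Illustrative sketch (not part of yt): ChunkDataCache wraps an iterator of
# grids and preloads the requested fields in batches of max_length through
# the IO handler. `index` and `grids` are hypothetical stand-ins for a real
# index and its grid list.
def _example_preloaded_iteration(index, grids):
    cached = ChunkDataCache(grids, [("gas", "density")], index, max_length=64)
    for g in cached:
        # each grid arrives with its preloaded fields already in its cache
        yield g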
def is_curvilinear(geo):
    # tell whether the geometry is curvilinear
    issue_deprecation_warning(
        "the is_curvilinear() function is deprecated. "
        "Instead, compare the geometry object directly with "
        "yt.geometry.geometry_enum.Geometry enum members, as for instance:\n"
        "if is_curvilinear(geometry):\n    ...\n"
        "should be rewritten as:\n"
        "if geometry is Geometry.POLAR or geometry is Geometry.CYLINDRICAL "
        "or geometry is Geometry.SPHERICAL:\n    ...",
        stacklevel=3,
        since="4.2",
    )
    return geo in ("polar", "cylindrical", "spherical")
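# Illustrative sketch: the replacement recommended by the deprecation warning
# above, comparing against Geometry enum members directly. The helper name is
# invented for illustration.
def _example_is_curvilinear(geometry):
    from yt.geometry.geometry_enum import Geometry

    return (
        geometry is Geometry.POLAR
        or geometry is Geometry.CYLINDRICAL
        or geometry is Geometry.SPHERICAL
    )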