Source code for yt.utilities.parallel_tools.io_runner

import time
from contextlib import contextmanager

import numpy as np

from yt.utilities.io_handler import BaseIOHandler
from yt.utilities.logger import ytLogger as mylog

from .parallel_analysis_interface import ProcessorPool, parallel_objects

try:
    from .parallel_analysis_interface import MPI
except ImportError:
    pass

YT_TAG_MESSAGE = 317  # Cell 317 knows where to go


class IOCommunicator(BaseIOHandler):
    def __init__(self, ds, wg, pool):
        mylog.info("Initializing IOCommunicator")
        self.ds = ds
        self.wg = wg  # We don't need to use this!
        self.pool = pool
        self.comm = pool.comm
        # We read our grids here
        self.grids = []
        storage = {}
        grids = ds.index.grids.tolist()
        grids.sort(key=lambda a: a.filename)
        for sto, g in parallel_objects(grids, storage=storage):
            sto.result = self.comm.rank
            sto.result_id = g.id
            self.grids.append(g)
        self._id_offset = ds.index.grids[0]._id_offset
        mylog.info("Reading from disk ...")
        self.initialize_data()
        mylog.info("Broadcasting ...")
        self.comm.comm.bcast(storage, root=wg.ranks[0])
        mylog.info("Done.")
        self.hooks = []
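
    # After the parallel_objects loop, ``storage`` maps each grid id to the
    # rank (within the IO group) that read it.  Broadcasting it above is what
    # lets IOHandlerRemote.__init__ build its ``proc_map`` on the work ranks,
    # so every worker knows which IO rank to ask for a given grid.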

    def initialize_data(self):
        ds = self.ds
        fields = [
            f
            for f in ds.field_list
            if not ds.field_info[f].sampling_type == "particle"
        ]
        pfields = [
            f for f in ds.field_list if ds.field_info[f].sampling_type == "particle"
        ]
        # Preload is only defined for Enzo ...
        if ds.index.io._dataset_type == "enzo_packed_3d":
            self.queue = ds.index.io.queue
            ds.index.io.preload(self.grids, fields)
            for g in self.grids:
                for f in fields:
                    if f not in self.queue[g.id]:
                        d = np.zeros(g.ActiveDimensions, dtype="float64")
                        self.queue[g.id][f] = d
                for f in pfields:
                    self.queue[g.id][f] = self._read(g, f)
        else:
            self.queue = {}
            for g in self.grids:
                # Create the per-grid entry first; the nested assignment
                # below would otherwise raise a KeyError.
                self.queue[g.id] = {}
                for f in fields + pfields:
                    self.queue[g.id][f] = ds.index.io._read(g, f)
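
    # At this point ``self.queue`` is a two-level mapping, grid id ->
    # {field name -> float64 array}, holding every on-disk field for the
    # grids owned by this IO rank.  wait() and _send_data below serve
    # requests straight out of this in-memory cache.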

    def _read(self, g, f):
        fi = self.ds.field_info[f]
        if fi.sampling_type == "particle" and g.NumberOfParticles == 0:
            # because this gets upcast to float
            return np.array([], dtype="float64")
        try:
            temp = self.ds.index.io._read_data_set(g, f)
        except Exception:  # self.ds.index.io._read_exception as exc:
            if fi.not_in_all:
                temp = np.zeros(g.ActiveDimensions, dtype="float64")
            else:
                raise
        return temp

    def wait(self):
        status = MPI.Status()
        while True:
            if self.comm.comm.Iprobe(MPI.ANY_SOURCE, YT_TAG_MESSAGE, status=status):
                msg = self.comm.comm.recv(source=status.source, tag=YT_TAG_MESSAGE)
                if msg["op"] == "end":
                    mylog.debug("Shutting down IO.")
                    break
                self._send_data(msg, status.source)
                status = MPI.Status()
            else:
                time.sleep(1e-2)

    def _send_data(self, msg, dest):
        grid_id = msg["grid_id"]
        field = msg["field"]
        ts = self.queue[grid_id][field].astype("float64")
        mylog.debug("Opening send to %s (%s)", dest, ts.shape)
        self.hooks.append(self.comm.comm.Isend([ts, MPI.DOUBLE], dest=dest))
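

# The exchange between the two work groups is a simple pairing of
# point-to-point calls: a work rank posts a non-blocking Irecv for the
# float64 buffer and then sends a small request dict tagged YT_TAG_MESSAGE;
# the IO rank's wait() loop picks the request up with Iprobe/recv and
# answers it with the matching Isend in _send_data above.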


class IOHandlerRemote(BaseIOHandler):
    _dataset_type = "remote"

    def __init__(self, ds, wg, pool):
        self.ds = ds
        self.wg = wg  # probably won't need
        self.pool = pool
        self.comm = pool.comm
        self.proc_map = self.comm.comm.bcast(None, root=pool["io"].ranks[0])
        super().__init__()

    def _read_data_set(self, grid, field):
        dest = self.proc_map[grid.id]
        msg = {"grid_id": grid.id, "field": field, "op": "read"}
        mylog.debug("Requesting %s for %s from %s", field, grid, dest)
        if self.ds.field_info[field].sampling_type == "particle":
            data = np.empty(grid.NumberOfParticles, "float64")
        else:
            data = np.empty(grid.ActiveDimensions, "float64")
        hook = self.comm.comm.Irecv([data, MPI.DOUBLE], source=dest)
        self.comm.comm.send(msg, dest=dest, tag=YT_TAG_MESSAGE)
        mylog.debug("Waiting for data.")
        MPI.Request.Wait(hook)
        return data

    def _read_data_slice(self, grid, field, axis, coord):
        sl = [slice(None), slice(None), slice(None)]
        sl[axis] = slice(coord, coord + 1)
        # sl = tuple(reversed(sl))
        return self._read_data_set(grid, field)[tuple(sl)]

    def terminate(self):
        msg = {"op": "end"}
        if self.wg.comm.rank == 0:
            for rank in self.pool["io"].ranks:
                mylog.debug("Sending termination message to %s", rank)
                self.comm.comm.send(msg, dest=rank, tag=YT_TAG_MESSAGE)


@contextmanager
def remote_io(ds, wg, pool):
    original_io = ds.index.io
    ds.index.io = IOHandlerRemote(ds, wg, pool)
    yield
    ds.index.io.terminate()
    ds.index.io = original_io


def io_nodes(fn, n_io, n_work, func, *args, **kwargs):
    from yt.loaders import load

    pool, wg = ProcessorPool.from_sizes([(n_io, "io"), (n_work, "work")])
    rv = None
    if wg.name == "work":
        ds = load(fn)
        with remote_io(ds, wg, pool):
            rv = func(ds, *args, **kwargs)
    elif wg.name == "io":
        ds = load(fn)
        io = IOCommunicator(ds, wg, pool)
        io.wait()
    # We should broadcast the result
    rv = pool.comm.mpi_bcast(rv, root=pool["work"].ranks[0])
    pool.free_all()
    mylog.debug("Return value: %s", rv)
    return rv


# Here is an example of how to use this functionality.
if __name__ == "__main__":

    def gq(ds):
        dd = ds.all_data()
        return dd.quantities["TotalQuantity"](("gas", "cell_mass"))

    q = io_nodes("DD0087/DD0087", 8, 24, gq)
    mylog.info(q)
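
# A sketch of how a script like the example above might be launched: the
# ProcessorPool asks for 8 IO ranks plus 24 work ranks, so the job needs
# 32 MPI processes in total.  The launcher command and script name are
# illustrative assumptions, not something this module prescribes:
#
#   $ mpirun -np 32 python io_runner_example.py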