import os
import re
import weakref
import numpy as np
from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch
from yt.data_objects.static_output import Dataset
from yt.fields.magnetic_field import get_magnetic_normalization
from yt.funcs import mylog, sglob
from yt.geometry.api import Geometry
from yt.geometry.geometry_handler import YTDataChunk
from yt.geometry.grid_geometry_handler import GridIndex
from yt.utilities.chemical_formulas import compute_mu
from yt.utilities.decompose import decompose_array, get_psize
from yt.utilities.lib.misc_utilities import get_box_grids_level
from .fields import AthenaFieldInfo
def chk23(strin):
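    """UTF-8 encode a str as bytes (the "23" suffix is presumably a relic of
    the Python 2/3 transition)."""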
return strin.encode("utf-8")
def str23(strin):
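    """UTF-8 decode bytes (or a list of bytes) back to str."""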
if isinstance(strin, list):
return [s.decode("utf-8") for s in strin]
else:
return strin.decode("utf-8")
def check_readline(fl):
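    """Read one line from fl, trimming anything that precedes an embedded
    SCALARS or VECTORS token (binary data can run into the next header line).
    """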
line = fl.readline()
chk = chk23("SCALARS")
if chk in line and not line.startswith(chk):
line = line[line.find(chk) :]
chk = chk23("VECTORS")
if chk in line and not line.startswith(chk):
line = line[line.find(chk) :]
return line
def check_break(line):
    """Return True when header parsing should stop: at a SCALARS/VECTORS
    field declaration, a LOOKUP_TABLE line, or an empty line (EOF).
    """
    splitup = line.strip().split()
    # Any one of these alternatives is enough to stop, so OR them together.
    do_break = chk23("SCALARS") in splitup
    do_break = (chk23("VECTORS") in splitup) | do_break
    do_break = (chk23("LOOKUP_TABLE") in splitup) | do_break
    do_break = (len(line) == 0) | do_break
    return do_break
def _get_convert(fname):
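    """Return a closure that converts data using the factor associated
    with field ``fname``."""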
def _conv(data):
return data.convert(fname)
return _conv
class AthenaGrid(AMRGridPatch):
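    """A single patch of Athena data, covering one (real or virtual) grid
    stored in a VTK file."""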
_id_offset = 0
def __init__(self, id, index, level, start, dimensions, file_offset, read_dims):
gname = index.grid_filenames[id]
AMRGridPatch.__init__(self, id, filename=gname, index=index)
self.filename = gname
self.Parent = []
self.Children = []
self.Level = level
self.start_index = start.copy()
self.stop_index = self.start_index + dimensions
self.ActiveDimensions = dimensions.copy()
self.file_offset = file_offset
self.read_dims = read_dims
def _setup_dx(self):
        # First figure out this grid's index.  We do not assume dx == dy == dz
        # here, although other parts of the code may.
id = self.id - self._id_offset
if len(self.Parent) > 0:
self.dds = self.Parent[0].dds / self.ds.refine_by
else:
LE, RE = self.index.grid_left_edge[id, :], self.index.grid_right_edge[id, :]
self.dds = self.ds.arr((RE - LE) / self.ActiveDimensions, "code_length")
if self.ds.dimensionality < 2:
self.dds[1] = 1.0
if self.ds.dimensionality < 3:
self.dds[2] = 1.0
self.field_data["dx"], self.field_data["dy"], self.field_data["dz"] = self.dds
def parse_line(line, grid):
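    """Parse one VTK header line and record whatever it declares (time,
    level, domain, DIMENSIONS, ORIGIN, SPACING, cell count, or a
    SCALARS/VECTORS field declaration) in the ``grid`` dict.
    """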
splitup = line.strip().split()
if chk23("vtk") in splitup:
grid["vtk_version"] = str23(splitup[-1])
elif chk23("time=") in splitup:
time_index = splitup.index(chk23("time="))
grid["time"] = float(str23(splitup[time_index + 1]).rstrip(","))
grid["level"] = int(str23(splitup[time_index + 3]).rstrip(","))
grid["domain"] = int(str23(splitup[time_index + 5]).rstrip(","))
elif chk23("DIMENSIONS") in splitup:
grid["dimensions"] = np.array(str23(splitup[-3:]), dtype="int")
elif chk23("ORIGIN") in splitup:
grid["left_edge"] = np.array(str23(splitup[-3:]), dtype="float64")
elif chk23("SPACING") in splitup:
grid["dds"] = np.array(str23(splitup[-3:]), dtype="float64")
elif chk23("CELL_DATA") in splitup or chk23("POINT_DATA") in splitup:
grid["ncells"] = int(str23(splitup[-1]))
elif chk23("SCALARS") in splitup:
field = str23(splitup[1])
grid["read_field"] = field
grid["read_type"] = "scalar"
elif chk23("VECTORS") in splitup:
field = str23(splitup[1])
grid["read_field"] = field
grid["read_type"] = "vector"
elif chk23("time") in splitup:
time_index = splitup.index(chk23("time"))
grid["time"] = float(str23(splitup[time_index + 1]))
class AthenaHierarchy(GridIndex):
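    """Grid index for Athena datasets: walks the VTK headers to build the
    field list and the grid hierarchy."""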
grid = AthenaGrid
_dataset_type = "athena"
_data_file = None
def __init__(self, ds, dataset_type="athena"):
self.dataset = weakref.proxy(ds)
self.directory = os.path.dirname(self.dataset.filename)
self.dataset_type = dataset_type
# for now, the index file is the dataset!
self.index_filename = os.path.join(os.getcwd(), self.dataset.filename)
self._fhandle = open(self.index_filename, "rb")
GridIndex.__init__(self, ds, dataset_type)
self._fhandle.close()
def _detect_output_fields(self):
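        # Scan the header for the grid dimensions and cell count, then walk
        # the SCALARS/VECTORS declarations, recording each field's byte
        # offset past the lookup table so the IO handler can seek to it.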
field_map = {}
f = open(self.index_filename, "rb")
line = check_readline(f)
chkwhile = chk23("")
while line != chkwhile:
splitup = line.strip().split()
chkd = chk23("DIMENSIONS")
chkc = chk23("CELL_DATA")
chkp = chk23("POINT_DATA")
if chkd in splitup:
field = str23(splitup[-3:])
grid_dims = np.array(field, dtype="int64")
line = check_readline(f)
elif chkc in splitup or chkp in splitup:
grid_ncells = int(str23(splitup[-1]))
line = check_readline(f)
if np.prod(grid_dims) != grid_ncells:
grid_dims -= 1
grid_dims[grid_dims == 0] = 1
if np.prod(grid_dims) != grid_ncells:
mylog.error(
"product of dimensions %i not equal to number of cells %i",
np.prod(grid_dims),
grid_ncells,
)
                        raise TypeError("dimensions do not match cell count")
break
else:
line = check_readline(f)
read_table = False
read_table_offset = f.tell()
while line != chkwhile:
splitup = line.strip().split()
chks = chk23("SCALARS")
chkv = chk23("VECTORS")
if chks in line and chks not in splitup:
splitup = str23(line[line.find(chks) :].strip().split())
if chkv in line and chkv not in splitup:
splitup = str23(line[line.find(chkv) :].strip().split())
if chks in splitup:
field = ("athena", str23(splitup[1]))
dtype = str23(splitup[-1]).lower()
if not read_table:
line = check_readline(f) # Read the lookup table line
read_table = True
field_map[field] = ("scalar", f.tell() - read_table_offset, dtype)
read_table = False
elif chkv in splitup:
field = str23(splitup[1])
dtype = str23(splitup[-1]).lower()
for ax in "xyz":
field_map["athena", f"{field}_{ax}"] = (
"vector",
f.tell() - read_table_offset,
dtype,
)
line = check_readline(f)
f.close()
self.field_list = list(field_map.keys())
self._field_map = field_map
def _count_grids(self):
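        # One real grid per VTK file, each split into nprocs virtual grids.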
self.num_grids = self.dataset.nvtk * self.dataset.nprocs
def _parse_index(self):
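        # Read the first grid's header for simulation parameters, glob for
        # the remaining per-processor (id*) and SMR (lev*) VTK files, then
        # assemble left edges, dimensions, and spacings for every grid.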
f = open(self.index_filename, "rb")
grid = {}
grid["read_field"] = None
grid["read_type"] = None
line = f.readline()
while grid["read_field"] is None:
parse_line(line, grid)
if check_break(line):
break
line = f.readline()
f.close()
# It seems some datasets have a mismatch between ncells and
# the actual grid dimensions.
if np.prod(grid["dimensions"]) != grid["ncells"]:
grid["dimensions"] -= 1
grid["dimensions"][grid["dimensions"] == 0] = 1
if np.prod(grid["dimensions"]) != grid["ncells"]:
mylog.error(
"product of dimensions %i not equal to number of cells %i",
np.prod(grid["dimensions"]),
grid["ncells"],
)
                raise TypeError("dimensions do not match cell count")
# Need to determine how many grids: self.num_grids
dataset_dir = os.path.dirname(self.index_filename)
dname = os.path.split(self.index_filename)[-1]
if dataset_dir.endswith("id0"):
dname = "id0/" + dname
dataset_dir = dataset_dir[:-3]
gridlistread = sglob(
os.path.join(dataset_dir, f"id*/{dname[4:-9]}-id*{dname[-9:]}")
)
gridlistread.insert(0, self.index_filename)
if "id0" in dname:
gridlistread += sglob(
os.path.join(dataset_dir, f"id*/lev*/{dname[4:-9]}*-lev*{dname[-9:]}")
)
else:
gridlistread += sglob(
os.path.join(dataset_dir, f"lev*/{dname[:-9]}*-lev*{dname[-9:]}")
)
ndots = dname.count(".")
gridlistread = [
fn for fn in gridlistread if os.path.basename(fn).count(".") == ndots
]
self.num_grids = len(gridlistread)
dxs = []
levels = np.zeros(self.num_grids, dtype="int32")
glis = np.empty((self.num_grids, 3), dtype="float64")
gdds = np.empty((self.num_grids, 3), dtype="float64")
gdims = np.ones_like(glis)
j = 0
self.grid_filenames = gridlistread
while j < (self.num_grids):
f = open(gridlistread[j], "rb")
gridread = {}
gridread["read_field"] = None
gridread["read_type"] = None
line = f.readline()
while gridread["read_field"] is None:
parse_line(line, gridread)
splitup = line.strip().split()
if chk23("X_COORDINATES") in splitup:
gridread["left_edge"] = np.zeros(3)
gridread["dds"] = np.zeros(3)
v = np.fromfile(f, dtype=">f8", count=2)
gridread["left_edge"][0] = v[0] - 0.5 * (v[1] - v[0])
gridread["dds"][0] = v[1] - v[0]
if chk23("Y_COORDINATES") in splitup:
v = np.fromfile(f, dtype=">f8", count=2)
gridread["left_edge"][1] = v[0] - 0.5 * (v[1] - v[0])
gridread["dds"][1] = v[1] - v[0]
if chk23("Z_COORDINATES") in splitup:
v = np.fromfile(f, dtype=">f8", count=2)
gridread["left_edge"][2] = v[0] - 0.5 * (v[1] - v[0])
gridread["dds"][2] = v[1] - v[0]
if check_break(line):
break
line = f.readline()
f.close()
levels[j] = gridread.get("level", 0)
glis[j, 0] = gridread["left_edge"][0]
glis[j, 1] = gridread["left_edge"][1]
glis[j, 2] = gridread["left_edge"][2]
# It seems some datasets have a mismatch between ncells and
# the actual grid dimensions.
if np.prod(gridread["dimensions"]) != gridread["ncells"]:
gridread["dimensions"] -= 1
gridread["dimensions"][gridread["dimensions"] == 0] = 1
if np.prod(gridread["dimensions"]) != gridread["ncells"]:
mylog.error(
"product of dimensions %i not equal to number of cells %i",
np.prod(gridread["dimensions"]),
gridread["ncells"],
)
                    raise TypeError("dimensions do not match cell count")
gdims[j, 0] = gridread["dimensions"][0]
gdims[j, 1] = gridread["dimensions"][1]
gdims[j, 2] = gridread["dimensions"][2]
# Setting dds=1 for non-active dimensions in 1D/2D datasets
gridread["dds"][gridread["dimensions"] == 1] = 1.0
gdds[j, :] = gridread["dds"]
j = j + 1
gres = glis + gdims * gdds
# Now we convert the glis, which were left edges (floats), to indices
# from the domain left edge. Then we do a bunch of fixing now that we
# know the extent of all the grids.
glis = np.round(
(glis - self.dataset.domain_left_edge.ndarray_view()) / gdds
).astype("int64")
new_dre = np.max(gres, axis=0)
dre_units = self.dataset.domain_right_edge.uq
self.dataset.domain_right_edge = np.round(new_dre, decimals=12) * dre_units
self.dataset.domain_width = (
self.dataset.domain_right_edge - self.dataset.domain_left_edge
)
self.dataset.domain_center = 0.5 * (
self.dataset.domain_left_edge + self.dataset.domain_right_edge
)
domain_dimensions = np.round(self.dataset.domain_width / gdds[0]).astype(
"int64"
)
if self.dataset.dimensionality <= 2:
domain_dimensions[2] = 1
if self.dataset.dimensionality == 1:
domain_dimensions[1] = 1
self.dataset.domain_dimensions = domain_dimensions
dle = self.dataset.domain_left_edge
dre = self.dataset.domain_right_edge
dx_root = (
self.dataset.domain_right_edge - self.dataset.domain_left_edge
) / self.dataset.domain_dimensions
if self.dataset.nprocs > 1:
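            # Split every on-disk grid into nprocs virtual grids so IO and
            # parallel analysis can be load-balanced; decompose_array supplies
            # the sub-edges, shapes, and slices for each piece.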
gle_all = []
gre_all = []
shapes_all = []
levels_all = []
new_gridfilenames = []
file_offsets = []
read_dims = []
for i in range(levels.shape[0]):
dx = dx_root / self.dataset.refine_by ** (levels[i])
gle_orig = self.ds.arr(
np.round(dle + dx * glis[i], decimals=12), "code_length"
)
gre_orig = self.ds.arr(
np.round(gle_orig + dx * gdims[i], decimals=12), "code_length"
)
bbox = np.array(
[[le, re] for le, re in zip(gle_orig, gre_orig, strict=True)]
)
psize = get_psize(self.ds.domain_dimensions, self.ds.nprocs)
gle, gre, shapes, slices, _ = decompose_array(gdims[i], psize, bbox)
gle_all += gle
gre_all += gre
shapes_all += shapes
levels_all += [levels[i]] * self.dataset.nprocs
new_gridfilenames += [self.grid_filenames[i]] * self.dataset.nprocs
file_offsets += [
[slc[0].start, slc[1].start, slc[2].start] for slc in slices
]
read_dims += [
np.array([gdims[i][0], gdims[i][1], shape[2]], dtype="int64")
for shape in shapes
]
self.num_grids *= self.dataset.nprocs
self.grids = np.empty(self.num_grids, dtype="object")
self.grid_filenames = new_gridfilenames
self.grid_left_edge = self.ds.arr(gle_all, "code_length")
self.grid_right_edge = self.ds.arr(gre_all, "code_length")
self.grid_dimensions = np.array(list(shapes_all), dtype="int32")
gdds = (self.grid_right_edge - self.grid_left_edge) / self.grid_dimensions
glis = np.round(
(self.grid_left_edge - self.ds.domain_left_edge) / gdds
).astype("int64")
for i in range(self.num_grids):
self.grids[i] = self.grid(
i,
self,
levels_all[i],
glis[i],
shapes_all[i],
file_offsets[i],
read_dims[i],
)
else:
self.grids = np.empty(self.num_grids, dtype="object")
for i in range(levels.shape[0]):
self.grids[i] = self.grid(
i, self, levels[i], glis[i], gdims[i], [0] * 3, gdims[i]
)
dx = dx_root / self.dataset.refine_by ** (levels[i])
dxs.append(dx)
dx = self.ds.arr(dxs, "code_length")
self.grid_left_edge = self.ds.arr(
np.round(dle + dx * glis, decimals=12), "code_length"
)
self.grid_dimensions = gdims.astype("int32")
self.grid_right_edge = self.ds.arr(
np.round(self.grid_left_edge + dx * self.grid_dimensions, decimals=12),
"code_length",
)
if self.dataset.dimensionality <= 2:
self.grid_right_edge[:, 2] = dre[2]
if self.dataset.dimensionality == 1:
self.grid_right_edge[:, 1:] = dre[1:]
self.grid_particle_count = np.zeros([self.num_grids, 1], dtype="int64")
def _populate_grid_objects(self):
for g in self.grids:
g._prepare_grid()
g._setup_dx()
self._reconstruct_parent_child()
self.max_level = self.grid_levels.max()
def _reconstruct_parent_child(self):
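        # Athena's VTK output stores no parent/child links, so rebuild them
        # geometrically from the grid edges and levels.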
mask = np.empty(len(self.grids), dtype="int32")
mylog.debug("First pass; identifying child grids")
for i, grid in enumerate(self.grids):
get_box_grids_level(
self.grid_left_edge[i, :],
self.grid_right_edge[i, :],
self.grid_levels[i].item() + 1,
self.grid_left_edge,
self.grid_right_edge,
self.grid_levels,
mask,
)
grid.Children = [
g for g in self.grids[mask.astype("bool")] if g.Level == grid.Level + 1
]
mylog.debug("Second pass; identifying parents")
for grid in self.grids: # Second pass
for child in grid.Children:
child.Parent.append(grid)
def _get_grid_children(self, grid):
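        """Return the grids one level finer that overlap ``grid``'s extent."""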
mask = np.zeros(self.num_grids, dtype="bool")
grids, grid_ind = self.get_box_grids(grid.LeftEdge, grid.RightEdge)
mask[grid_ind] = True
return [g for g in self.grids[mask] if g.Level == grid.Level + 1]
def _chunk_io(self, dobj, cache=True, local_only=False):
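        # Yield one IO chunk per grid so reads can be streamed grid-by-grid.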
gobjs = getattr(dobj._current_chunk, "objs", dobj._chunk_info)
for subset in gobjs:
yield YTDataChunk(
dobj, "io", [subset], self._count_selection(dobj, [subset]), cache=cache
)
class AthenaDataset(Dataset):
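    """Dataset class for legacy (VTK-based) Athena output.

    A minimal usage sketch; the filename and unit values below are
    hypothetical examples, since Athena VTK files carry no unit metadata::

        import yt

        units_override = {
            "length_unit": (1.0, "Mpc"),
            "time_unit": (1.0, "Myr"),
            "mass_unit": (1.0e14, "Msun"),
        }
        ds = yt.load("id0/cluster_merger.0250.vtk", units_override=units_override)
    """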
_index_class = AthenaHierarchy
_field_info_class = AthenaFieldInfo
_dataset_type = "athena"
def __init__(
self,
filename,
dataset_type="athena",
storage_filename=None,
parameters=None,
units_override=None,
nprocs=1,
unit_system="cgs",
default_species_fields=None,
magnetic_normalization="gaussian",
):
self.fluid_types += ("athena",)
self.nprocs = nprocs
if parameters is None:
parameters = {}
self.specified_parameters = parameters.copy()
if units_override is None:
units_override = {}
self._magnetic_factor = get_magnetic_normalization(magnetic_normalization)
Dataset.__init__(
self,
filename,
dataset_type,
units_override=units_override,
unit_system=unit_system,
default_species_fields=default_species_fields,
)
if storage_filename is None:
storage_filename = self.basename + ".yt"
self.storage_filename = storage_filename
# Unfortunately we now have to mandate that the index gets
# instantiated so that we can make sure we have the correct left
# and right domain edges.
self.index
def _set_code_unit_attributes(self):
"""
Generates the conversion to various physical _units based on the
parameter file
"""
if "length_unit" not in self.units_override:
self.no_cgs_equiv_length = True
for unit, cgs in [("length", "cm"), ("time", "s"), ("mass", "g")]:
# We set these to cgs for now, but they may have been overridden
if getattr(self, unit + "_unit", None) is not None:
continue
mylog.warning("Assuming 1.0 = 1.0 %s", cgs)
setattr(self, f"{unit}_unit", self.quan(1.0, cgs))
self.magnetic_unit = np.sqrt(
self._magnetic_factor
* self.mass_unit
/ (self.time_unit**2 * self.length_unit)
)
self.magnetic_unit.convert_to_units("gauss")
self.velocity_unit = self.length_unit / self.time_unit
def _parse_parameter_file(self):
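        # Read the first grid's VTK header to recover time, dimensions,
        # origin, and spacing; the true domain extent is corrected later by
        # the index, once every grid has been seen.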
self._handle = open(self.parameter_filename, "rb")
# Read the start of a grid to get simulation parameters.
grid = {}
grid["read_field"] = None
line = self._handle.readline()
while grid["read_field"] is None:
parse_line(line, grid)
splitup = line.strip().split()
if chk23("X_COORDINATES") in splitup:
grid["left_edge"] = np.zeros(3)
grid["dds"] = np.zeros(3)
v = np.fromfile(self._handle, dtype=">f8", count=2)
grid["left_edge"][0] = v[0] - 0.5 * (v[1] - v[0])
grid["dds"][0] = v[1] - v[0]
if chk23("Y_COORDINATES") in splitup:
v = np.fromfile(self._handle, dtype=">f8", count=2)
grid["left_edge"][1] = v[0] - 0.5 * (v[1] - v[0])
grid["dds"][1] = v[1] - v[0]
if chk23("Z_COORDINATES") in splitup:
v = np.fromfile(self._handle, dtype=">f8", count=2)
grid["left_edge"][2] = v[0] - 0.5 * (v[1] - v[0])
grid["dds"][2] = v[1] - v[0]
if check_break(line):
break
line = self._handle.readline()
self.domain_left_edge = grid["left_edge"]
mylog.info(
"Temporarily setting domain_right_edge = -domain_left_edge. "
"This will be corrected automatically if it is not the case."
)
self.domain_right_edge = -self.domain_left_edge
self.domain_width = self.domain_right_edge - self.domain_left_edge
domain_dimensions = np.round(self.domain_width / grid["dds"]).astype("int32")
        # Athena outputs do not record a refinement ratio; assume the
        # standard factor of 2.
        self.refine_by = 2
dimensionality = 3
if grid["dimensions"][2] == 1:
dimensionality = 2
if grid["dimensions"][1] == 1:
dimensionality = 1
if dimensionality <= 2:
domain_dimensions[2] = np.int32(1)
if dimensionality == 1:
domain_dimensions[1] = np.int32(1)
if dimensionality != 3 and self.nprocs > 1:
raise RuntimeError("Virtual grids are only supported for 3D outputs!")
self.domain_dimensions = domain_dimensions
self.dimensionality = dimensionality
self.current_time = grid["time"]
self.num_ghost_zones = 0
self.field_ordering = "fortran"
self.boundary_conditions = [1] * 6
self._periodicity = tuple(
self.specified_parameters.get("periodicity", (True, True, True))
)
if "gamma" in self.specified_parameters:
self.gamma = float(self.specified_parameters["gamma"])
else:
self.gamma = 5.0 / 3.0
dataset_dir = os.path.dirname(self.parameter_filename)
dname = os.path.split(self.parameter_filename)[-1]
if dataset_dir.endswith("id0"):
dname = "id0/" + dname
dataset_dir = dataset_dir[:-3]
gridlistread = sglob(
os.path.join(dataset_dir, f"id*/{dname[4:-9]}-id*{dname[-9:]}")
)
if "id0" in dname:
gridlistread += sglob(
os.path.join(dataset_dir, f"id*/lev*/{dname[4:-9]}*-lev*{dname[-9:]}")
)
else:
gridlistread += sglob(
os.path.join(dataset_dir, f"lev*/{dname[:-9]}*-lev*{dname[-9:]}")
)
ndots = dname.count(".")
gridlistread = [
fn for fn in gridlistread if os.path.basename(fn).count(".") == ndots
]
self.nvtk = len(gridlistread) + 1
self.current_redshift = 0.0
self.omega_lambda = 0.0
self.omega_matter = 0.0
self.hubble_constant = 0.0
self.cosmological_simulation = 0
# Hardcode time conversion for now.
self.parameters["Time"] = self.current_time
# Hardcode for now until field staggering is supported.
self.parameters["HydroMethod"] = 0
if "gamma" in self.specified_parameters:
self.parameters["Gamma"] = self.specified_parameters["gamma"]
else:
self.parameters["Gamma"] = 5.0 / 3.0
self.geometry = Geometry(self.specified_parameters.get("geometry", "cartesian"))
self._handle.close()
self.mu = self.specified_parameters.get(
"mu", compute_mu(self.default_species_fields)
)
@classmethod
def _is_valid(cls, filename: str, *args, **kwargs) -> bool:
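        """Check whether ``filename`` is an Athena VTK dump by matching the
        two header lines that every such file begins with.
        """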
if not filename.endswith(".vtk"):
return False
with open(filename, "rb") as fh:
if not re.match(b"# vtk DataFile Version \\d\\.\\d\n", fh.readline(256)):
return False
if (
re.search(
b"at time= .*, level= \\d, domain= \\d\n",
fh.readline(256),
)
is None
):
                # VTK dump headers start with either "CONSERVED vars" or
                # "PRIMITIVE vars", while VTK output headers start with
                # "Really cool Athena data"; we consider this an
                # implementation detail and do not attempt to match it
                # exactly here.  See Athena's user guide for reference:
                # https://princetonuniversity.github.io/Athena-Cversion/AthenaDocsUGbtk
return False
return True
@property
def _skip_cache(self):
return True
def __str__(self):
return self.basename.rsplit(".", 1)[0]