"""
Title: framework.py
Purpose: Contains answer tests that are used by yt's various frontends
"""
import contextlib
import hashlib
import logging
import os
import pickle
import shelve
import sys
import tempfile
import time
import warnings
import zlib
from collections import defaultdict
import numpy as np
from matplotlib import image as mpimg
from matplotlib.testing.compare import compare_images
from nose.plugins import Plugin
from numpy.testing import assert_almost_equal, assert_equal
from yt.config import ytcfg
from yt.data_objects.static_output import Dataset
from yt.funcs import get_pbar, get_yt_version
from yt.loaders import load, load_simulation
from yt.testing import (
assert_allclose_units,
assert_rel_equal,
skipif,
)
from yt.utilities.exceptions import (
YTAmbiguousDataType,
YTCloudError,
YTNoAnswerNameSpecified,
YTNoOldAnswer,
YTUnidentifiedDataType,
)
from yt.utilities.logger import disable_stream_logging
from yt.visualization import (
image_writer as image_writer,
particle_plots as particle_plots,
plot_window as pw,
profile_plotter as profile_plotter,
)
# Logger shared by the answer-testing machinery in this module.
mylog = logging.getLogger("nose.plugins.answer-testing")
# Module-level switch read by requires_ds(): tests declared with big_data=True
# are skipped while this is falsy.
run_big_data = False
# Set the latest gold and local standard filenames
_latest = ytcfg.get("yt", "gold_standard_filename")
_latest_local = ytcfg.get("yt", "local_standard_filename")
# URL template; AnswerTestCloudStorage.get formats it with
# (reference_name, ds_name) to locate a stored answer.
_url_path = ytcfg.get("yt", "answer_tests_url")
class AnswerTesting(Plugin):
    """nose plugin providing yt's answer-testing command line interface.

    ``options`` registers the command line flags, ``configure`` turns the
    parsed values into shared state on :class:`AnswerTestingTest` (result
    storage, reference storage and the parsed options), and ``finalize``
    writes/uploads the collected results when ``--answer-store`` was given.
    """

    name = "answer-testing"
    _my_version = None

    def options(self, parser, env=os.environ):
        """Register the answer-testing options on nose's option *parser*."""
        super().options(parser, env=env)
        parser.add_option(
            "--answer-name",
            dest="answer_name",
            metavar="str",
            default=None,
            help="The name of the standard to store/compare against",
        )
        parser.add_option(
            "--answer-store",
            dest="store_results",
            metavar="bool",
            default=False,
            action="store_true",
            help="Should we store this result instead of comparing?",
        )
        parser.add_option(
            "--local",
            dest="local_results",
            default=False,
            action="store_true",
            help="Store/load reference results locally?",
        )
        parser.add_option(
            "--answer-big-data",
            dest="big_data",
            default=False,
            help="Should we run against big data, too?",
            action="store_true",
        )
        parser.add_option(
            "--local-dir",
            dest="output_dir",
            metavar="str",
            help="The name of the directory to store local results",
        )

    @property
    def my_version(self):
        """Version string used as the default store name.

        Computed once from ``get_yt_version()`` and cached; falls back to
        ``UNKNOWN<timestamp>`` if the version cannot be determined.
        """
        if self._my_version is None:
            try:
                self._my_version = get_yt_version()
            except Exception:
                self._my_version = f"UNKNOWN{time.time()}"
        return self._my_version

    def configure(self, options, conf):
        """Translate the parsed options into storage objects and test state.

        Without this hook none of the options registered in :meth:`options`
        take effect, ``AnswerTestingTest.result_storage`` stays ``None`` (so
        every answer test silently returns) and :meth:`finalize` has no
        storage to dump.
        """
        super().configure(options, conf)
        if not self.enabled:
            return
        disable_stream_logging()

        if options.store_results:
            # Storing: the default store name is the latest local standard
            # and nothing is compared.
            if options.answer_name is None:
                self.store_name = _latest_local
            else:
                self.store_name = options.answer_name
            self.compare_name = None
        else:
            # Comparing: the default reference is the latest gold standard
            # (or the latest local standard when --local is given).
            if options.answer_name is None:
                if options.local_results:
                    self.compare_name = _latest_local
                else:
                    self.compare_name = _latest
            else:
                self.compare_name = options.answer_name
            self.store_name = self.my_version

        self.store_results = options.store_results

        ytcfg["yt", "internals", "within_testing"] = True
        self.result_storage = defaultdict(dict)
        AnswerTestingTest.result_storage = self.result_storage
        if self.compare_name == "SKIP":
            self.compare_name = None
        elif self.compare_name == "latest":
            self.compare_name = _latest

        if options.local_results:
            if options.output_dir is None:
                sys.exit(
                    "Please supply an output directory with the --local-dir option"
                )
            storage_class = AnswerTestLocalStorage
            output_dir = os.path.realpath(options.output_dir)
            # Local answers live in <output_dir>/<name>/<name>.
            if self.compare_name is not None:
                self.compare_name = os.path.join(
                    output_dir, self.compare_name, self.compare_name
                )
            # Only create the answer directory here when --answer-name was
            # given; otherwise AnswerTestingTest.__call__ derives the location
            # from the test's own `answer_name`.
            if options.store_results and options.answer_name is not None:
                name_dir_path = os.path.join(output_dir, self.store_name)
                os.makedirs(name_dir_path, exist_ok=True)
                self.store_name = os.path.join(name_dir_path, self.store_name)
        else:
            storage_class = AnswerTestCloudStorage

        self.storage = storage_class(self.compare_name, self.store_name)
        AnswerTestingTest.reference_storage = self.storage
        AnswerTestingTest.options = options
        self.local_results = options.local_results

        global run_big_data
        run_big_data = options.big_data

    def finalize(self, result=None):
        """Persist the collected results when running with --answer-store."""
        # getattr: configure() returns early for a disabled plugin, in which
        # case no storage was set up and there is nothing to dump.
        if not getattr(self, "store_results", False):
            return
        self.storage.dump(self.result_storage)

    def help(self):
        return "yt answer testing support"
class AnswerTestStorage:
    """Abstract store of answer-test results, keyed by dataset name.

    ``reference_name`` identifies the answers to compare against (``None``
    disables lookups); ``answer_name`` identifies where new answers are
    written (``None`` disables writes). ``cache`` is available to subclasses
    for memoizing fetched answers.
    """

    def __init__(self, reference_name=None, answer_name=None):
        self.reference_name = reference_name
        self.answer_name = answer_name
        self.cache = {}

    def dump(self, result_storage, result=None):
        """Persist *result_storage*; must be implemented by subclasses.

        ``result`` is optional so the base signature matches the subclasses'
        ``dump(result_storage)`` and the way the plugin calls it.
        """
        raise NotImplementedError

    def get(self, ds_name, default=None):
        """Return stored answers for *ds_name*; must be implemented by subclasses."""
        raise NotImplementedError
class AnswerTestCloudStorage(AnswerTestStorage):
    """Answer storage backed by a remote object store.

    Reference answers are downloaded from ``_url_path`` over HTTP and cached
    per dataset; new answers are uploaded with ``pyrax`` to the
    ``yt-answer-tests`` container.
    """

    def get(self, ds_name, default=None):
        """Fetch and unpickle the reference answers for *ds_name*.

        Returns *default* when no reference name is configured.

        Raises
        ------
        YTNoOldAnswer
            If the server answers with an HTTP error.
        YTCloudError
            If the response body cannot be read after three attempts.
        """
        import urllib.error
        import urllib.request

        if self.reference_name is None:
            return default
        if ds_name in self.cache:
            return self.cache[ds_name]
        url = _url_path.format(self.reference_name, ds_name)
        try:
            resp = urllib.request.urlopen(url)
        except urllib.error.HTTPError as exc:
            raise YTNoOldAnswer(url) from exc
        # Close the response whether or not reading it succeeds.
        with resp:
            for _ in range(3):
                try:
                    data = resp.read()
                except Exception:
                    time.sleep(0.01)
                else:
                    # We were successful
                    break
            else:
                # Raise error if all tries were unsuccessful
                raise YTCloudError(url)
        # This is dangerous, but we have a controlled S3 environment.
        # NOTE: never point `answer_tests_url` at an untrusted host.
        rv = pickle.loads(data)
        self.cache[ds_name] = rv
        return rv

    def progress_callback(self, current, total):
        # NOTE(review): relies on `self.pbar`, which this class never sets;
        # the caller must assign it before using this callback.
        self.pbar.update(current)

    def dump(self, result_storage):
        """Upload each dataset's answers as ``<answer_name>_<ds_name>``.

        Existing objects with the same name are deleted before storing.
        Does nothing when no answer name is configured.
        """
        if self.answer_name is None:
            return
        # This is where we dump our result storage up to the cloud, if we
        # are able to.
        import pyrax

        credentials = os.path.expanduser(os.path.join("~", ".yt", "rackspace"))
        pyrax.set_credential_file(credentials)
        cf = pyrax.cloudfiles
        container = cf.get_container("yt-answer-tests")
        # Fetch the remote listing once rather than once per dataset.
        existing = set(container.get_object_names())
        pb = get_pbar("Storing results ", len(result_storage))
        try:
            for i, ds_name in enumerate(result_storage):
                pb.update(i + 1)
                rs = pickle.dumps(result_storage[ds_name])
                object_name = f"{self.answer_name}_{ds_name}"
                if object_name in existing:
                    container.delete_object(container.get_object(object_name))
                container.store_object(object_name, rs)
        finally:
            pb.finish()
class AnswerTestLocalStorage(AnswerTestStorage):
    """Answer storage backed by a local ``shelve`` database."""

    def dump(self, result_storage):
        """Write every dataset's answers into the shelve at ``answer_name``.

        Nothing is written when ``answer_name`` is ``None`` or when the run is
        marked tainted.
        """
        # The 'tainted' attribute is automatically set to 'True'
        # if the dataset required for an answer test is missing
        # (see can_run_ds()).
        # This logic check prevents creating a shelve with empty answers.
        storage_is_tainted = result_storage.get("tainted", False)
        if self.answer_name is None or storage_is_tainted:
            return
        # The context manager closes the shelve even if a write fails.
        with shelve.open(self.answer_name, protocol=-1) as db:
            for ds_name, answers in result_storage.items():
                key = f"{ds_name}"
                if key in db:
                    mylog.info("Overwriting %s", key)
                db[key] = answers

    def get(self, ds_name, default=None):
        """Return the stored answers for *ds_name*, or *default* if absent.

        Returns *default* immediately when no reference name is configured.
        Opening the shelve creates it (and its parent directory) if missing.
        """
        if self.reference_name is None:
            return default
        ref_dir = os.path.dirname(self.reference_name)
        # os.makedirs("") raises, so only create a directory when there is one.
        if ref_dir:
            os.makedirs(ref_dir, exist_ok=True)
        with shelve.open(self.reference_name, protocol=-1) as db:
            return db.get(f"{ds_name}", default)
@contextlib.contextmanager
def temp_cwd(cwd):
    """Temporarily change the working directory to *cwd*.

    The previous working directory is restored on exit, including when the
    body of the ``with`` statement raises.
    """
    oldcwd = os.getcwd()
    os.chdir(cwd)
    try:
        yield
    finally:
        os.chdir(oldcwd)
def can_run_ds(ds_fn, file_check=False):
    """Return whether the answer test for *ds_fn* can run.

    A test can run only when answer-test result storage is configured and
    the dataset is reachable: either an already-loaded ``Dataset``, an
    existing file under the test data directory (``file_check=True``), or
    something ``load`` can open.

    With ``strict_requires`` enabled, a missing dataset marks the result
    storage as tainted and re-raises ``FileNotFoundError``.
    """
    storage = AnswerTestingTest.result_storage
    if isinstance(ds_fn, Dataset):
        return storage is not None
    data_dir = ytcfg.get("yt", "test_data_dir")
    if not os.path.isdir(data_dir):
        return False
    if file_check:
        candidate = os.path.join(data_dir, ds_fn)
        return os.path.isfile(candidate) and storage is not None
    try:
        load(ds_fn)
    except FileNotFoundError:
        if not ytcfg.get("yt", "internals", "strict_requires"):
            return False
        # Taint the run so that no answer file is written from partial results.
        if storage is not None:
            storage["tainted"] = True
        raise
    except (YTUnidentifiedDataType, YTAmbiguousDataType):
        return False
    return storage is not None
def data_dir_load(ds_fn, cls=None, args=None, kwargs=None):
    """Load a test dataset and build its index.

    An existing ``Dataset`` is returned unchanged. Returns ``False`` when the
    configured test data directory does not exist. With *cls* given, the
    dataset is constructed from ``<test_data_dir>/<ds_fn>``; otherwise
    ``load(ds_fn)`` is used.
    """
    positional = args or ()
    keywords = kwargs or {}
    data_dir = ytcfg.get("yt", "test_data_dir")
    if isinstance(ds_fn, Dataset):
        return ds_fn
    if not os.path.isdir(data_dir):
        return False
    if cls is not None:
        ds = cls(os.path.join(data_dir, ds_fn), *positional, **keywords)
    else:
        ds = load(ds_fn, *positional, **keywords)
    # Access the index so it is built before the dataset is handed out.
    ds.index
    return ds
def data_dir_load_v2(fn, *args, **kwargs):
    """Load *fn* relative to the configured test data directory.

    A simpler counterpart of ``data_dir_load``: no ``Dataset`` passthrough,
    no custom class and no eager index build; extra arguments go to ``load``.
    """
    data_dir = ytcfg.get("yt", "test_data_dir")
    full_path = os.path.join(data_dir, fn)
    return load(full_path, *args, **kwargs)
def sim_dir_load(sim_fn, path=None, sim_type="Enzo", find_outputs=False):
    """Load a simulation time series for answer testing.

    *sim_fn* is used as given when it exists; otherwise it is looked up
    under *path*.

    Raises
    ------
    FileNotFoundError
        (an ``OSError`` subclass) if *sim_fn* does not exist and no *path*
        was supplied to look it up in.
    """
    sim_exists = os.path.exists(sim_fn)
    if path is None and not sim_exists:
        raise FileNotFoundError(
            f"Simulation file {sim_fn!r} does not exist and no search path was given"
        )
    if sim_exists or not path:
        path = "."
    return load_simulation(
        os.path.join(path, sim_fn), sim_type, find_outputs=find_outputs
    )
class AnswerTestingTest:
    """Base class for a single yt answer test.

    Subclasses provide ``_type_name``, ``_attrs``, ``run()`` and
    ``compare(new_result, old_result)``. Calling an instance runs the test
    and, when a reference is configured, compares the result against it; the
    new result is always recorded in ``result_storage``.
    """

    reference_storage = None
    result_storage = None
    prefix = ""
    options = None
    # This variable should be set if we are not providing `--answer-name` as
    # command line parameter while running yt's answer testing using nosetests.
    answer_name = None

    def __init__(self, ds_fn):
        if ds_fn is None:
            self.ds = None
        elif isinstance(ds_fn, Dataset):
            self.ds = ds_fn
        else:
            self.ds = data_dir_load(ds_fn, kwargs={"unit_system": "code"})

    def __call__(self):
        """Run the test and compare against, or record, its answer.

        Returns ``None`` without running anything when result storage has not
        been configured.

        Raises
        ------
        YTNoAnswerNameSpecified
            If neither ``--answer-name`` nor ``answer_name`` is set.
        YTNoOldAnswer
            If comparing and no stored answer exists for this test.
        """
        if AnswerTestingTest.result_storage is None:
            return
        nv = self.run()
        # Test answer name should be provided either as command line parameters
        # or by setting AnswerTestingTest.answer_name
        if self.options.answer_name is None and self.answer_name is None:
            raise YTNoAnswerNameSpecified()
        # This is for running answer test when `--answer-name` is not set in
        # nosetests command line arguments. In this case, set the answer_name
        # from the `answer_name` keyword in the test case
        if self.options.answer_name is None:
            self._configure_local_answer_storage()
        if self.reference_storage.reference_name is not None:
            # Compare test generated values against the golden answer
            dd = self.reference_storage.get(self.storage_name)
            if dd is None or self.description not in dd:
                raise YTNoOldAnswer(f"{self.storage_name} : {self.description}")
            self.compare(nv, dd[self.description])
        self.result_storage[self.storage_name][self.description] = nv

    def _configure_local_answer_storage(self):
        """Point reference storage at ``<output_dir>/<pyXY>_<answer_name>``.

        The Python-version prefix is computed into a local variable rather
        than written back to ``self.answer_name``, so calling the test more
        than once does not accumulate ``pyXY_`` prefixes.
        """
        pyver = f"py{sys.version_info.major}{sys.version_info.minor}"
        answer_name = f"{pyver}_{self.answer_name}"
        answer_store_dir = os.path.realpath(self.options.output_dir)
        answer_test_dir = os.path.join(answer_store_dir, answer_name)
        ref_name = os.path.join(answer_test_dir, answer_name)
        self.reference_storage.reference_name = ref_name
        self.reference_storage.answer_name = ref_name
        # If we are generating golden answers (passed --answer-store arg):
        # - create the answer directory for this test
        # - self.reference_storage.answer_name will be path to answer files
        if self.options.store_results:
            os.makedirs(answer_test_dir, exist_ok=True)
            self.reference_storage.reference_name = None

    @property
    def storage_name(self):
        """Key under which this test's dataset answers are stored."""
        if self.prefix != "":
            return f"{self.prefix}_{self.ds}"
        return str(self.ds)

    def compare(self, new_result, old_result):
        """Compare results; subclasses must override.

        Raises ``NotImplementedError`` (a ``RuntimeError`` subclass, so
        existing ``except RuntimeError`` handlers still catch it).
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement compare()"
        )

    def create_plot(self, ds, plot_type, plot_field, plot_axis, plot_kwargs=None):
        """Instantiate ``plot_type`` (a class name) from plot_window or particle_plots.

        ``plot_kwargs`` may be ``None``, meaning no extra keyword arguments.
        """
        if plot_type is None:
            raise RuntimeError("Must explicitly request a plot type")
        if plot_kwargs is None:
            plot_kwargs = {}
        cls = getattr(pw, plot_type, None)
        if cls is None:
            cls = getattr(particle_plots, plot_type)
        return cls(ds, plot_axis, plot_field, **plot_kwargs)

    @property
    def sim_center(self):
        """
        This returns the center of the domain.
        """
        return 0.5 * (self.ds.domain_right_edge + self.ds.domain_left_edge)

    @property
    def max_dens_location(self):
        """
        This is a helper function to return the location of the most dense
        point.
        """
        return self.ds.find_max(("gas", "density"))[1]

    @property
    def entire_simulation(self):
        """
        Return an unsorted array of values that cover the entire domain.
        """
        return self.ds.all_data()

    @property
    def description(self):
        """Unique answer key built from type name, dataset, object and ``_attrs``.

        Dots are replaced with underscores.
        """
        obj_type = getattr(self, "obj_type", None)
        if obj_type is None:
            oname = "all"
        else:
            oname = "_".join(str(s) for s in obj_type)
        args = [self._type_name, str(self.ds), oname]
        args += [str(getattr(self, an)) for an in self._attrs]
        suffix = getattr(self, "suffix", None)
        if suffix:
            args.append(suffix)
        return "_".join(args).replace(".", "_")
class FieldValuesTest(AnswerTestingTest):
    """Answer test on the weighted average and extrema of a field.

    The weight is ``particle_ones`` for particle fields, ``ones`` for SPH
    fields and ``("index", "ones")`` otherwise. With ``decimals=None`` the
    comparison is exact.
    """

    _type_name = "FieldValues"
    _attrs = ("field",)

    def __init__(self, ds_fn, field, obj_type=None, particle_type=False, decimals=10):
        super().__init__(ds_fn)
        self.obj_type = obj_type
        self.field = field
        self.particle_type = particle_type
        self.decimals = decimals

    def run(self):
        source = create_obj(self.ds, self.obj_type)
        resolved = source._determine_fields(self.field)[0]
        info = self.ds.field_info[resolved]
        ftype = resolved[0]
        if self.particle_type:
            weight = (ftype, "particle_ones")
        elif info.is_sph_field:
            weight = (ftype, "ones")
        else:
            weight = ("index", "ones")
        average = source.quantities.weighted_average_quantity(resolved, weight=weight)
        low, high = source.quantities.extrema(self.field)
        return [average, low, high]

    def compare(self, new_result, old_result):
        err_msg = f"Field values for {self.field} not equal."
        new_result = getattr(new_result, "d", new_result)
        old_result = getattr(old_result, "d", old_result)
        if self.decimals is None:
            assert_equal(new_result, old_result, err_msg=err_msg, verbose=True)
            return
        # Answers stored without units are assumed to share the units of
        # the new result, so coerce them before the unitful comparison.
        unitless_old = isinstance(old_result, np.ndarray) and not hasattr(
            old_result, "in_units"
        )
        if unitless_old:
            old_result = old_result * new_result[0].uq
        tolerance = 10.0 ** (-self.decimals)
        assert_allclose_units(
            new_result,
            old_result,
            tolerance,
            err_msg=err_msg,
            verbose=True,
        )
class AllFieldValuesTest(AnswerTestingTest):
    """Answer test on every value of a field within a data object.

    Compared exactly when ``decimals`` is ``None``, otherwise to a relative
    precision of ``decimals`` digits.
    """

    _type_name = "AllFieldValues"
    _attrs = ("field",)

    def __init__(self, ds_fn, field, obj_type=None, decimals=None):
        super().__init__(ds_fn)
        self.obj_type = obj_type
        self.field = field
        self.decimals = decimals

    def run(self):
        source = create_obj(self.ds, self.obj_type)
        return source[self.field]

    def compare(self, new_result, old_result):
        err_msg = f"All field values for {self.field} not equal."
        new_values = getattr(new_result, "d", new_result)
        old_values = getattr(old_result, "d", old_result)
        if self.decimals is not None:
            assert_rel_equal(
                new_values, old_values, self.decimals, err_msg=err_msg, verbose=True
            )
        else:
            assert_equal(new_values, old_values, err_msg=err_msg, verbose=True)
class ProjectionValuesTest(AnswerTestingTest):
    """Answer test on the raw field data of a projection along one axis.

    Returns ``None`` (and skips comparison) when the domain is a single cell
    thick along the projection axis. Cells that are NaN in any field of
    either result are excluded from the comparison.
    """

    _type_name = "ProjectionValues"
    _attrs = ("field", "axis", "weight_field")

    def __init__(
        self, ds_fn, axis, field, weight_field=None, obj_type=None, decimals=10
    ):
        super().__init__(ds_fn)
        self.axis = axis
        self.field = field
        self.weight_field = weight_field
        self.obj_type = obj_type
        self.decimals = decimals

    def run(self):
        source = None if self.obj_type is None else create_obj(self.ds, self.obj_type)
        if self.ds.domain_dimensions[self.axis] == 1:
            return None
        projection = self.ds.proj(
            self.field, self.axis, weight_field=self.weight_field, data_source=source
        )
        return projection.field_data

    def compare(self, new_result, old_result):
        if new_result is None:
            return
        assert len(new_result) == len(old_result)
        # Accumulate a NaN mask over all fields for each result.
        old_nan = None
        new_nan = None
        for key in new_result:
            assert key in old_result
            old_isnan = np.isnan(old_result[key])
            if old_nan is None:
                old_nan = np.array(old_isnan)
            np.logical_or(old_nan, old_isnan, old_nan)
            new_isnan = np.isnan(new_result[key])
            if new_nan is None:
                new_nan = np.array(new_isnan)
            np.logical_or(new_nan, new_isnan, new_nan)
        old_valid = ~old_nan
        new_valid = ~new_nan
        for key in new_result:
            # Our weight_field can vary between unit systems, whereas we
            # can do a unitful comparison for the other fields. So we do
            # not do the test here.
            if key == "weight_field":
                continue
            err_msg = f"{key} values of {self.field} ({self.weight_field} weighted) projection (axis {self.axis}) not equal."
            new_vals = new_result[key][new_valid]
            old_vals = old_result[key][old_valid]
            new_vals = getattr(new_vals, "d", new_vals)
            old_vals = getattr(old_vals, "d", old_vals)
            if self.decimals is not None:
                assert_allclose_units(
                    new_vals, old_vals, 10.0**-(self.decimals), err_msg=err_msg
                )
            else:
                assert_equal(new_vals, old_vals, err_msg=err_msg)
class PixelizedProjectionValuesTest(AnswerTestingTest):
    """Answer test on a 256x256 fixed-resolution image of a projection.

    Besides the rendered image(s), the float64 sum of every raw projection
    field is recorded under ``"<field>_sum"``.
    """

    _type_name = "PixelizedProjectionValues"
    _attrs = ("field", "axis", "weight_field")

    def __init__(self, ds_fn, axis, field, weight_field=None, obj_type=None):
        super().__init__(ds_fn)
        self.axis = axis
        self.field = field
        self.weight_field = weight_field
        self.obj_type = obj_type

    def _get_frb(self, obj):
        proj = self.ds.proj(
            self.field, self.axis, weight_field=self.weight_field, data_source=obj
        )
        frb = proj.to_frb((1.0, "unitary"), 256)
        return proj, frb

    def run(self):
        source = None if self.obj_type is None else create_obj(self.ds, self.obj_type)
        proj, frb = self._get_frb(source)
        to_render = [self.field]
        if self.weight_field is not None:
            to_render.append(self.weight_field)
        for name in to_render:
            frb.render(name)
        images = frb.data
        raw = proj.field_data
        for key in raw:
            # `key` is sometimes a tuple; it is stringified into the label.
            images[f"{key}_sum"] = raw[key].sum(dtype="float64")
        return images

    def compare(self, new_result, old_result):
        assert len(new_result) == len(old_result)
        assert all(key in old_result for key in new_result)
        for key in new_result:
            # weight_field does not have units, so we do not directly compare them
            if key == "weight_field_sum":
                continue
            try:
                assert_allclose_units(new_result[key], old_result[key], 1e-10)
            except AssertionError:
                dump_images(new_result[key], old_result[key])
                raise
class PixelizedParticleProjectionValuesTest(PixelizedProjectionValuesTest):
    """Pixelized projection answer test built from a ParticleProjectionPlot."""

    def _get_frb(self, obj):
        # NOTE(review): `obj` (the requested data object) is not passed to
        # the plot, so the projection always covers the plot's default
        # data source — confirm whether that is intended.
        plot = particle_plots.ParticleProjectionPlot(
            self.ds, self.axis, [self.field], weight_field=self.weight_field
        )
        source, frb = plot.data_source, plot.frb
        return source, frb
class GridValuesTest(AnswerTestingTest):
    """Answer test fingerprinting a field's raw bytes on every grid (MD5)."""

    _type_name = "GridValues"
    _attrs = ("field",)

    def __init__(self, ds_fn, field):
        super().__init__(ds_fn)
        self.field = field

    def run(self):
        """Return ``{grid.id: md5 hexdigest of the field's bytes}``."""
        hashes = {}
        for grid in self.ds.index.grids:
            # md5 is a change fingerprint, not a security primitive; saying so
            # lets this run on FIPS-restricted interpreters where plain md5()
            # is rejected.
            digest = hashlib.md5(grid[self.field].tobytes(), usedforsecurity=False)
            hashes[grid.id] = digest.hexdigest()
            # Free the field data after hashing to bound memory use.
            grid.clear_data()
        return hashes

    def compare(self, new_result, old_result):
        assert len(new_result) == len(old_result)
        for k in new_result:
            assert k in old_result
        for k in new_result:
            if hasattr(new_result[k], "d"):
                new_result[k] = new_result[k].d
            if hasattr(old_result[k], "d"):
                old_result[k] = old_result[k].d
            assert_equal(new_result[k], old_result[k])
class VerifySimulationSameTest(AnswerTestingTest):
    """Answer test checking the output times of a simulation time series."""

    _type_name = "VerifySimulationSame"
    _attrs = ()

    def __init__(self, simulation_obj):
        # Deliberately bypasses AnswerTestingTest.__init__: the "dataset" is
        # an iterable of outputs, not something data_dir_load can open.
        self.ds = simulation_obj

    def run(self):
        return [output.current_time for output in self.ds]

    def compare(self, new_result, old_result):
        assert_equal(
            len(new_result),
            len(old_result),
            err_msg="Number of outputs not equal.",
            verbose=True,
        )
        # Lengths are equal at this point, so zip covers every output.
        for new_time, old_time in zip(new_result, old_result):
            assert_equal(
                new_time,
                old_time,
                err_msg="Output times not equal.",
                verbose=True,
            )
class GridHierarchyTest(AnswerTestingTest):
    """Answer test on the grid index: dimensions, edges, levels, particle counts."""

    _type_name = "GridHierarchy"
    _attrs = ()

    def run(self):
        index = self.ds.index
        return {
            "grid_dimensions": index.grid_dimensions,
            "grid_left_edges": index.grid_left_edge,
            "grid_right_edges": index.grid_right_edge,
            "grid_levels": index.grid_levels,
            "grid_particle_count": index.grid_particle_count,
        }

    def compare(self, new_result, old_result):
        for key in new_result:
            # Strip units in place (new first, then old) before comparing.
            for results in (new_result, old_result):
                if hasattr(results[key], "d"):
                    results[key] = results[key].d
            assert_equal(new_result[key], old_result[key])
class ParentageRelationshipsTest(AnswerTestingTest):
    """Answer test on each grid's parent id(s) and child ids."""

    _type_name = "ParentageRelationships"
    _attrs = ()

    def run(self):
        parents = []
        children = []
        for grid in self.ds.index.grids:
            parent = grid.Parent
            # Parent may be None, a single grid (has `.id`) or a collection.
            if parent is None:
                parent_entry = None
            elif hasattr(parent, "id"):
                parent_entry = parent.id
            else:
                parent_entry = [pg.id for pg in parent]
            parents.append(parent_entry)
            children.append([child.id for child in grid.Children])
        return {"parents": parents, "children": children}

    def compare(self, new_result, old_result):
        for key in ("parents", "children"):
            for new_item, old_item in zip(
                new_result[key],
                old_result[key],
                strict=True,
            ):
                assert new_item == old_item
def dump_images(new_result, old_result, decimals=10):
    """Write both projections to temporary PNGs and compare them.

    On mismatch, a ``[[ATTACHMENT|<path>]]`` marker is written to stderr for
    each PNG path in matplotlib's report. The temporary files are left on
    disk so they can be attached.
    """

    def _reserve_png(prefix):
        handle, name = tempfile.mkstemp(prefix=prefix, suffix=".png")
        os.close(handle)
        return name

    old_image = _reserve_png("baseline_")
    new_image = _reserve_png("thisPR_")
    image_writer.write_projection(new_result, new_image)
    image_writer.write_projection(old_result, old_image)
    report = compare_images(old_image, new_image, 10 ** (-decimals))
    if report is None:
        return
    for line in report.split("\n"):
        if line.endswith(".png"):
            sys.stderr.write(f"\n[[ATTACHMENT|{line.strip()}]]")
    sys.stderr.write("\n")
def ensure_image_comparability(a, b):
    """Return versions of *a* and *b* with a common shape.

    Arrays of equal shape are returned unchanged. Otherwise both are
    NaN-padded on the right and bottom (first two axes) up to the smallest
    shape containing both, after emitting a warning. Trailing axes (e.g.
    colour channels) must already match.
    """
    if a.shape == b.shape:
        return a, b
    assert a.shape[2:] == b.shape[2:]
    warnings.warn(
        f"Images have different shapes {a.shape} and {b.shape}. "
        "Padding nans to make them comparable.",
        stacklevel=2,
    )
    rows = max(a.shape[0], b.shape[0])
    cols = max(a.shape[1], b.shape[1])
    target_shape = (rows, cols, *a.shape[2:])

    def _pad(image):
        padded = np.full(target_shape, np.nan)
        padded[: image.shape[0], : image.shape[1], ...] = image
        return padded

    return _pad(a), _pad(b)
def compare_image_lists(new_result, old_result, decimals):
    """Compare two lists of zlib-compressed, pickled image arrays pairwise.

    Each pair is made shape-compatible, written to temporary PNGs and
    compared with matplotlib's ``compare_images`` at tolerance
    ``10**-decimals``. On a mismatch the unpadded images replace the padded
    ones on disk, attachment markers are written to stderr under Jenkins,
    and an ``AssertionError`` is raised. The temporary files are then
    deliberately left in place for inspection; they are removed only when
    every pair matches.
    """
    fns = []
    for _ in range(2):
        tmpfd, tmpname = tempfile.mkstemp(suffix=".png")
        os.close(tmpfd)
        fns.append(tmpname)
    num_images = len(old_result)
    assert num_images > 0
    for i in range(num_images):
        # NOTE: stored answers are unpickled and must come from a trusted source.
        expected = pickle.loads(zlib.decompress(old_result[i]))
        actual = pickle.loads(zlib.decompress(new_result[i]))
        expected_p, actual_p = ensure_image_comparability(expected, actual)
        mpimg.imsave(fns[0], expected_p)
        mpimg.imsave(fns[1], actual_p)
        results = compare_images(fns[0], fns[1], 10 ** (-decimals))
        if results is not None:
            tempfiles = [
                line.strip() for line in results.split("\n") if line.endswith(".png")
            ]
            # matplotlib's report lists the expected, actual *and* diff image
            # paths, so only the first two line up with (expected, actual).
            # A strict zip raised ValueError here and masked the real
            # image-mismatch failure.
            for fn, img, padded in zip(
                tempfiles,
                (expected, actual),
                (expected_p, actual_p),
                strict=False,
            ):
                # padded images are convenient for comparison
                # but what we really want to store and upload
                # are the actual results
                if padded.shape != img.shape:
                    mpimg.imsave(fn, img)
            if os.environ.get("JENKINS_HOME") is not None:
                for fn in tempfiles:
                    sys.stderr.write(f"\n[[ATTACHMENT|{fn}]]")
                sys.stderr.write("\n")
        assert_equal(results, None, results)
    for fn in fns:
        os.remove(fn)
class PlotWindowAttributeTest(AnswerTestingTest):
    """Image answer test for a plot-window plot after optional modifications.

    Each callable in ``callback_runners`` is invoked as ``runner(test, plot)``.
    When both ``attr_name`` and ``attr_args`` are set, the plot method named
    ``attr_name`` is called with ``*attr_args[0], **attr_args[1]``. The saved
    PNG is returned as a one-element list of zlib-compressed pickled arrays.
    """

    _type_name = "PlotWindowAttribute"
    _attrs = (
        "plot_type",
        "plot_field",
        "plot_axis",
        "attr_name",
        "attr_args",
        "callback_id",
    )

    def __init__(
        self,
        ds_fn: str,
        plot_field: str,
        plot_axis: str,
        attr_name: str | None = None,
        attr_args: tuple | None = None,
        decimals: int | None = 12,
        plot_type: str | None = "SlicePlot",
        callback_id: str | None = "",
        callback_runners: tuple | None = None,
    ):
        super().__init__(ds_fn)
        self.plot_type = plot_type
        self.plot_field = plot_field
        self.plot_axis = plot_axis
        self.plot_kwargs = {}
        self.attr_name = attr_name
        self.attr_args = attr_args
        self.decimals = decimals
        # callback_id names the callbacks in the answer key so the callables
        # themselves never need to be hashed.
        self.callback_id = callback_id
        self.callback_runners = () if callback_runners is None else callback_runners

    def run(self):
        plot = self.create_plot(
            self.ds, self.plot_type, self.plot_field, self.plot_axis, self.plot_kwargs
        )
        for runner in self.callback_runners:
            runner(self, plot)
        if self.attr_name and self.attr_args:
            method = getattr(plot, self.attr_name)
            method(*self.attr_args[0], **self.attr_args[1])
        handle, png_path = tempfile.mkstemp(suffix=".png")
        os.close(handle)
        plot.save(name=png_path)
        pixels = mpimg.imread(png_path)
        os.remove(png_path)
        return [zlib.compress(pixels.dumps())]

    def compare(self, new_result, old_result):
        compare_image_lists(new_result, old_result, self.decimals)
class PhasePlotAttributeTest(AnswerTestingTest):
    """Image answer test for a profile/phase plot after calling one of its methods.

    The plot is built over ``ds.all_data()``; the method named ``attr_name``
    is called with ``*attr_args[0], **attr_args[1]`` and the saved PNG is
    returned as a one-element list of zlib-compressed pickled arrays.
    """

    _type_name = "PhasePlotAttribute"
    _attrs = ("plot_type", "x_field", "y_field", "z_field", "attr_name", "attr_args")

    def __init__(
        self,
        ds_fn,
        x_field,
        y_field,
        z_field,
        attr_name,
        attr_args,
        decimals,
        plot_type="PhasePlot",
    ):
        super().__init__(ds_fn)
        self.data_source = self.ds.all_data()
        self.plot_type = plot_type
        self.x_field = x_field
        self.y_field = y_field
        self.z_field = z_field
        self.plot_kwargs = {}
        self.attr_name = attr_name
        self.attr_args = attr_args
        self.decimals = decimals

    def create_plot(
        self, data_source, x_field, y_field, z_field, plot_type, plot_kwargs=None
    ):
        """Instantiate ``plot_type`` from profile_plotter or particle_plots.

        ``plot_kwargs`` may be ``None`` (its default), meaning no extra
        keyword arguments; previously this raised ``TypeError`` on ``**None``.
        """
        if plot_type is None:
            raise RuntimeError("Must explicitly request a plot type")
        if plot_kwargs is None:
            plot_kwargs = {}
        cls = getattr(profile_plotter, plot_type, None)
        if cls is None:
            cls = getattr(particle_plots, plot_type)
        return cls(data_source, x_field, y_field, z_field, **plot_kwargs)

    def run(self):
        plot = self.create_plot(
            self.data_source,
            self.x_field,
            self.y_field,
            self.z_field,
            self.plot_type,
            self.plot_kwargs,
        )
        attr = getattr(plot, self.attr_name)
        attr(*self.attr_args[0], **self.attr_args[1])
        tmpfd, tmpname = tempfile.mkstemp(suffix=".png")
        os.close(tmpfd)
        plot.save(name=tmpname)
        image = mpimg.imread(tmpname)
        os.remove(tmpname)
        return [zlib.compress(image.dumps())]

    def compare(self, new_result, old_result):
        compare_image_lists(new_result, old_result, self.decimals)
class GenericArrayTest(AnswerTestingTest):
    """Answer test on the return value of an arbitrary callable.

    A non-dict result is compared as ``{"answer": result}``. Each entry is
    compared with ``assert_almost_equal`` when ``decimals`` is ``None``,
    otherwise with a relative tolerance of ``10**-decimals``.
    """

    _type_name = "GenericArray"
    _attrs = ("array_func_name", "args", "kwargs")

    def __init__(self, ds_fn, array_func, args=None, kwargs=None, decimals=None):
        super().__init__(ds_fn)
        self.array_func = array_func
        self.array_func_name = array_func.__name__
        self.args = args
        self.kwargs = kwargs
        self.decimals = decimals

    def run(self):
        call_args = [] if self.args is None else self.args
        call_kwargs = {} if self.kwargs is None else self.kwargs
        return self.array_func(*call_args, **call_kwargs)

    def compare(self, new_result, old_result):
        if not isinstance(new_result, dict):
            new_result, old_result = {"answer": new_result}, {"answer": old_result}
        assert_equal(
            len(new_result),
            len(old_result),
            err_msg="Number of outputs not equal.",
            verbose=True,
        )
        for key in new_result:
            # Strip units in place before comparing.
            for results in (new_result, old_result):
                if hasattr(results[key], "d"):
                    results[key] = results[key].d
            if self.decimals is not None:
                assert_allclose_units(
                    new_result[key], old_result[key], 10 ** (-self.decimals)
                )
            else:
                assert_almost_equal(new_result[key], old_result[key])
class AxialPixelizationTest(AnswerTestingTest):
    """Answer test on 512x512 pixelizations of the coordinate fields.

    Typically used once per geometry or coordinates type: for every axis, a
    slice through the domain centre is pixelized for both in-plane
    coordinate fields, with non-finite filler values replaced by 0.
    """

    _type_name = "AxialPixelization"
    _attrs = ("geometry",)

    def __init__(self, ds_fn, decimals=None):
        super().__init__(ds_fn)
        self.decimals = decimals
        self.geometry = self.ds.coordinates.name

    def run(self):
        ds = self.ds
        coords = ds.coordinates
        images = {}
        for i, axis in enumerate(coords.axis_order):
            bounds, center, _display_center = pw.get_window_parameters(
                axis, ds.domain_center, None, ds
            )
            slc = ds.slice(axis, center[i])
            for label, axis_map in (("x", coords.x_axis), ("y", coords.y_axis)):
                field_name = coords.axis_name[axis_map[axis]]
                pix = coords.pixelize(
                    axis, slc, ("gas", field_name), bounds, (512, 512)
                )
                # Wipe out invalid values (fillers)
                pix[~np.isfinite(pix)] = 0.0
                images[f"{axis}_{label}"] = pix
        return images

    def compare(self, new_result, old_result):
        assert_equal(
            len(new_result),
            len(old_result),
            err_msg="Number of outputs not equal.",
            verbose=True,
        )
        for key in new_result:
            for results in (new_result, old_result):
                if hasattr(results[key], "d"):
                    results[key] = results[key].d
            if self.decimals is not None:
                assert_allclose_units(
                    new_result[key], old_result[key], 10 ** (-self.decimals)
                )
            else:
                assert_almost_equal(new_result[key], old_result[key])
def requires_answer_testing():
    """Skip decorator that applies when answer-test result storage is unset.

    The condition is evaluated when this function is called, not when the
    decorated test runs.
    """
    storage_missing = AnswerTestingTest.result_storage is None
    return skipif(
        storage_missing,
        reason="answer testing storage is not properly setup",
    )
def requires_ds(ds_fn, big_data=False, file_check=False):
    """Skip decorator for tests needing dataset *ds_fn*.

    Big-data tests are skipped unless ``run_big_data`` is set; in that case
    the dataset is not even probed. Otherwise the test is skipped when
    ``can_run_ds`` reports it cannot run.
    """
    if big_data and not run_big_data:
        skip = True
    else:
        skip = not can_run_ds(ds_fn, file_check)
    return skipif(skip, reason=f"cannot load dataset {ds_fn}")
def small_patch_amr(ds_fn, fields, input_center="max", input_weight=("gas", "density")):
    """Yield the standard answer tests for a small patch-based AMR dataset.

    Yields hierarchy and parentage tests, then for each field: grid values,
    unweighted and weighted projections along all three axes, and field
    values, over the whole domain and over a 0.1-unitary sphere.
    Yields nothing if the dataset cannot be run.
    """
    if not can_run_ds(ds_fn):
        return
    data_objects = (None, ("sphere", (input_center, (0.1, "unitary"))))
    yield GridHierarchyTest(ds_fn)
    yield ParentageRelationshipsTest(ds_fn)
    for field in fields:
        yield GridValuesTest(ds_fn, field)
        for dobj in data_objects:
            for axis in (0, 1, 2):
                for weight_field in (None, input_weight):
                    yield ProjectionValuesTest(ds_fn, axis, field, weight_field, dobj)
            yield FieldValuesTest(ds_fn, field, dobj)
def big_patch_amr(ds_fn, fields, input_center="max", input_weight=("gas", "density")):
    """Yield the standard answer tests for a large patch-based AMR dataset.

    Like ``small_patch_amr`` but using pixelized projections (whole domain
    and 0.1-unitary sphere, unweighted and weighted) and no field-value
    tests. Yields nothing if the dataset cannot be run.
    """
    if not can_run_ds(ds_fn):
        return
    data_objects = (None, ("sphere", (input_center, (0.1, "unitary"))))
    yield GridHierarchyTest(ds_fn)
    yield ParentageRelationshipsTest(ds_fn)
    for field in fields:
        yield GridValuesTest(ds_fn, field)
        for axis in (0, 1, 2):
            for dobj in data_objects:
                for weight_field in (None, input_weight):
                    yield PixelizedProjectionValuesTest(
                        ds_fn, axis, field, weight_field, dobj
                    )
def _particle_answers(
    ds, ds_str_repr, ds_nparticles, fields, proj_test_class, center="c"
):
    """Yield the standard answer tests for a particle dataset.

    This is a generator, so the sanity checks below (string representation
    and particle counts) run on first iteration. ``fields`` maps each field
    to its projection weight; non-particle fields get a ``proj_test_class``
    projection per axis, and every field gets a ``FieldValuesTest``, over
    the whole domain and a 0.1-unitary sphere at *center*.
    """
    if not can_run_ds(ds):
        return
    assert_equal(str(ds), ds_str_repr)
    data_objects = (None, ("sphere", (center, (0.1, "unitary"))))
    dd = ds.all_data()
    # this needs to explicitly be "all"
    assert_equal(dd["all", "particle_position"].shape, (ds_nparticles, 3))
    per_type_counts = (
        dd[ptype, "particle_position"].shape[0] for ptype in ds.particle_types_raw
    )
    assert_equal(sum(per_type_counts), ds_nparticles)
    for dobj in data_objects:
        for field, weight_field in fields.items():
            is_particle = field[0] in ds.particle_types
            if not is_particle:
                for axis in (0, 1, 2):
                    yield proj_test_class(ds, axis, field, weight_field, dobj)
            yield FieldValuesTest(ds, field, dobj, particle_type=is_particle)
def nbody_answer(ds, ds_str_repr, ds_nparticles, fields, center="c"):
    """Standard N-body answer tests, using particle-plot pixelized projections."""
    projection_test = PixelizedParticleProjectionValuesTest
    return _particle_answers(
        ds, ds_str_repr, ds_nparticles, fields, projection_test, center=center
    )
def sph_answer(ds, ds_str_repr, ds_nparticles, fields, center="c"):
    """Standard SPH answer tests, using gridded pixelized projections."""
    projection_test = PixelizedProjectionValuesTest
    return _particle_answers(
        ds, ds_str_repr, ds_nparticles, fields, projection_test, center=center
    )
def create_obj(ds, obj_type):
    """Build a data object on *ds*.

    *obj_type* is ``None`` for ``ds.all_data()``, or a ``(method_name, args)``
    pair, e.g. ``("sphere", (center, radius))``, which calls
    ``getattr(ds, method_name)(*args)``.
    """
    if obj_type is None:
        return ds.all_data()
    factory = getattr(ds, obj_type[0])
    return factory(*obj_type[1])