import os
import sys
import warnings
from collections.abc import Callable
from pathlib import Path
from more_itertools import always_iterable
from yt.utilities.configuration_tree import ConfigLeaf, ConfigNode
# Registry of callables that each accept a YTConfig instance and return None.
# NOTE(review): nothing in the visible part of this module appends to or
# invokes this list; presumably it is filled and run elsewhere -- confirm.
configuration_callbacks: list[Callable[["YTConfig"], None]] = []
[docs]
def config_dir():
    """Return the path of yt's per-user configuration directory.

    The directory is ``$XDG_CONFIG_HOME/yt``. When ``XDG_CONFIG_HOME`` is
    unset *or empty*, ``~/.config`` is used as the root instead, as required
    by the XDG Base Directory specification.

    Returns
    -------
    str
        The configuration directory path. It is not created here.
    """
    config_root = os.environ.get("XDG_CONFIG_HOME")
    # An empty value must be treated like an unset one; otherwise the result
    # would silently become the relative path "yt".
    if not config_root:
        config_root = os.path.join(os.path.expanduser("~"), ".config")
    return os.path.join(config_root, "yt")
[docs]
class YTConfig:
    """Configuration store backed by a ``ConfigNode`` tree.

    Entries are addressed by a section name followed by zero or more nested
    keys. The tree lives in ``self.config_root``; this class adds TOML
    loading/saving and mapping-style access on top of it.
    """

    def __init__(self, defaults=None):
        """Create a configuration with an empty root node.

        Parameters
        ----------
        defaults : dict, optional
            Accepted for interface compatibility only.
            NOTE(review): this value has never been applied to the tree by
            the constructor; use :meth:`update` to load default values.
        """
        self.config_root = ConfigNode(None)

    def get(self, section, *keys, callback=None):
        """Look up the entry stored under ``section`` and ``keys``.

        Parameters
        ----------
        section : str
            Top-level section name.
        *keys : str
            Nested keys below the section.
        callback : callable, optional
            If given and the entry is a ``ConfigLeaf``, the leaf itself is
            passed to ``callback`` and its result is returned.

        Returns
        -------
        object
            The leaf's ``value`` (or ``callback(leaf)``) when the entry is a
            ``ConfigLeaf``; otherwise whatever ``config_root.get`` returned,
            unchanged.
        """
        node_or_leaf = self.config_root.get(section, *keys)
        if not isinstance(node_or_leaf, ConfigLeaf):
            return node_or_leaf
        if callback is not None:
            return callback(node_or_leaf)
        return node_or_leaf.value

    def get_most_specific(self, section, *keys, **kwargs):
        """Return ``config_root.get_deepest_leaf(section, *keys)``.

        Parameters
        ----------
        section : str
            Top-level section name.
        *keys : str
            Nested keys below the section.
        fallback : object, optional (keyword only)
            Returned instead of raising when the lookup raises ``KeyError``.
            Passing ``fallback=None`` explicitly is honoured.

        Raises
        ------
        KeyError
            If the lookup fails and no ``fallback`` keyword was supplied.
        """
        try:
            return self.config_root.get_deepest_leaf(section, *keys)
        except KeyError:
            # Presence of the keyword matters, not its value, so that None
            # can be used as a legitimate fallback.
            if "fallback" in kwargs:
                return kwargs["fallback"]
            raise

    def update(self, new_values, metadata=None):
        """Forward ``new_values`` to ``config_root.update``.

        Parameters
        ----------
        new_values : dict
            Values to merge into the tree.
        metadata : dict, optional
            Extra data passed along with the values; defaults to ``{}``.
        """
        if metadata is None:
            metadata = {}
        self.config_root.update(new_values, metadata)

    def has_section(self, section):
        """Return True if ``config_root.get_child(section)`` succeeds.

        A ``KeyError`` from the lookup is reported as ``False``.
        """
        try:
            self.config_root.get_child(section)
        except KeyError:
            return False
        return True

    def add_section(self, section):
        """Add ``section`` as a child of the root node."""
        self.config_root.add_child(section)

    def remove_section(self, section):
        """Remove ``section`` if present.

        Returns
        -------
        bool
            True if the section existed and was removed, False otherwise.
        """
        if not self.has_section(section):
            return False
        self.config_root.remove_child(section)
        return True

    def set(self, *args, metadata=None):
        """Store a value; call as ``set(section, *keys, value)``.

        Parameters
        ----------
        *args
            The section, any nested keys, and finally the value to store.
            At least a section and a value are required; fewer arguments
            make the unpacking raise ``ValueError``.
        metadata : dict, optional
            Extra data stored with the value; defaults to
            ``{"source": "runtime"}``.
        """
        section, *keys, value = args
        if metadata is None:
            metadata = {"source": "runtime"}
        self.config_root.upsert_from_list(
            [section, *keys], value, extra_data=metadata
        )

    def remove(self, *args):
        """Remove the leaf addressed by ``args`` (section followed by keys)."""
        self.config_root.pop_leaf(args)

    def read(self, file_names):
        """Load one or more TOML files into the configuration.

        Parameters
        ----------
        file_names : path-like or iterable of path-like
            A single file name or several of them.

        Returns
        -------
        list
            The file names that were loaded successfully. Files that do not
            exist are skipped silently; files containing invalid TOML emit a
            warning and are skipped.
        """
        file_names_read = []
        for fname in always_iterable(file_names):
            if not os.path.exists(fname):
                continue
            metadata = {"source": f"file: {fname}"}
            # Imported lazily so that no TOML parser is required unless a
            # file actually has to be parsed.
            if sys.version_info >= (3, 11):
                import tomllib
            else:
                import tomli as tomllib
            try:
                with open(fname, "rb") as fh:
                    data = tomllib.load(fh)
            except tomllib.TOMLDecodeError as exc:
                warnings.warn(
                    f"Could not load configuration file {fname} (invalid TOML: {exc})",
                    stacklevel=2,
                )
            else:
                self.update(data, metadata=metadata)
                file_names_read.append(fname)
        return file_names_read

    def write(self, file_handler):
        """Serialise the configuration as TOML.

        Parameters
        ----------
        file_handler : path-like or object with a ``write`` method
            Destination. A path is written as UTF-8 text and its parent
            directory is created (with a warning) if missing. Anything that
            ``pathlib.Path`` rejects must provide ``write(str)``.

        Raises
        ------
        TypeError
            If ``file_handler`` is neither path-like nor writable.
        """
        import tomli_w

        value = self.config_root.as_dict()
        config_as_str = tomli_w.dumps(value)
        try:
            file_path = Path(file_handler)
        except TypeError:
            if not hasattr(file_handler, "write"):
                raise TypeError(
                    f"Expected a path to a file, or a writable object, got {file_handler}"
                ) from None
            file_handler.write(config_as_str)
        else:
            pdir = file_path.parent
            if not pdir.exists():
                warnings.warn(
                    f"{pdir!s} does not exist, creating it (recursively)", stacklevel=2
                )
                # exist_ok guards against the directory appearing between
                # the check above and this call.
                pdir.mkdir(parents=True, exist_ok=True)
            # TOML documents must be UTF-8, and read() parses the raw bytes
            # with tomllib; never fall back to the locale's encoding.
            file_path.write_text(config_as_str, encoding="utf-8")

    @staticmethod
    def get_global_config_file():
        """Return the path ``<config_dir()>/yt.toml``."""
        return os.path.join(config_dir(), "yt.toml")

    @staticmethod
    def get_local_config_file():
        """Return the path of the closest ``yt.toml`` at or above the cwd.

        The current working directory and then each of its ancestors are
        probed in turn; the filesystem root itself is not probed. If no
        file is found, the would-be location ``<cwd>/yt.toml`` is returned.

        Returns
        -------
        str
            An absolute path.
        """
        path = Path.cwd()
        # Compare by value: a root path is its own parent, and equality does
        # not depend on pathlib returning the very same object.
        while path.parent != path:
            candidate = path.joinpath("yt.toml")
            if candidate.is_file():
                return os.path.abspath(candidate)
            path = path.parent
        return os.path.join(os.path.abspath(os.curdir), "yt.toml")

    def __setitem__(self, args, value):
        """Support ``cfg[section, key, ...] = value`` via :meth:`set`."""
        section, *keys = always_iterable(args)
        self.set(section, *keys, value, metadata=None)

    def __getitem__(self, key):
        """Support ``cfg[section, key, ...]`` via :meth:`get`."""
        section, *keys = always_iterable(key)
        return self.get(section, *keys)

    def __contains__(self, item):
        """Delegate membership testing to the root node."""
        return item in self.config_root

    # Add support for IPython rich display
    # see https://ipython.readthedocs.io/en/stable/config/integrating.html
    def _repr_json_(self):
        """Return the root node's JSON representation for IPython."""
        return self.config_root._repr_json_()
# Module-level YTConfig instance that the get_config / set_config /
# write_config / rm_config helpers in this module operate on.
CONFIG = YTConfig()
def _cast_bool_helper(value):
if value in ("true", "True", True):
return True
elif value in ("false", "False", False):
return False
else:
raise ValueError("Cannot safely cast to bool")
def _expand_all(s):
return os.path.expandvars(os.path.expanduser(s))
def _cast_value_helper(value, types=(_cast_bool_helper, int, float, _expand_all)):
    """Convert ``value`` with the first converter in ``types`` that accepts it.

    Converters are tried in order; one that raises ``ValueError`` is skipped.
    Any other exception propagates. If every converter raises ``ValueError``
    the function returns ``None``.
    """
    for convert in types:
        try:
            return convert(value)
        except ValueError:
            continue
    return None
[docs]
def get_config(section, option):
    """Return the entry of ``CONFIG`` addressed by ``section`` and ``option``.

    ``option`` is a dot-separated path; each component becomes one nested
    key below ``section``.
    """
    key_path = option.split(".")
    return CONFIG.get(section, *key_path)
[docs]
def set_config(section, option, value, config_file):
    """Set an option in ``CONFIG`` and write the result to ``config_file``.

    The section is added first if it does not exist. ``option`` is a
    dot-separated path of nested keys, and ``value`` is passed through
    ``_cast_value_helper`` before being stored.
    """
    section_missing = not CONFIG.has_section(section)
    if section_missing:
        CONFIG.add_section(section)
    key_path = option.split(".")
    converted = _cast_value_helper(value)
    CONFIG.set(section, *key_path, converted)
    write_config(config_file)
[docs]
def write_config(config_file):
    """Write the module-level ``CONFIG`` to ``config_file`` via ``CONFIG.write``."""
    CONFIG.write(config_file)
[docs]
def rm_config(section, option, config_file):
    """Remove an option from ``CONFIG`` and write the result to ``config_file``.

    ``option`` is a dot-separated path of nested keys below ``section``.
    """
    CONFIG.remove(section, *option.split("."))
    write_config(config_file)