"""Axis class and associated."""
# --- import --------------------------------------------------------------------------------------
import re
import numexpr
import operator
import functools
import numpy as np
from .. import exceptions as wt_exceptions
from .. import kit as wt_kit
from .. import units as wt_units
# --- define --------------------------------------------------------------------------------------
__all__ = ["Axis"]
operator_to_identifier = {}
operator_to_identifier["/"] = "__d__"
operator_to_identifier["="] = "__e__"
operator_to_identifier["-"] = "__m__"
operator_to_identifier["+"] = "__p__"
operator_to_identifier["*"] = "__t__"
identifier_to_operator = {value: key for key, value in operator_to_identifier.items()}
operators = "".join(operator_to_identifier.keys())
# --- class ---------------------------------------------------------------------------------------
[docs]
class Axis(object):
"""Axis class."""
[docs]
def __init__(self, parent, expression, units=None):
"""Data axis.
Parameters
----------
parent : WrightTools.Data
Parent data object.
expression : string
Axis expression. Space characters are ignored.
units : string (optional)
Axis units. Default is None.
"""
self.parent = parent
self.expression = expression.replace(" ", "") # ignore spaces
if units is None:
self.units = self.variables[0].units
else:
self.units = units
def __getitem__(self, index):
vs = {}
for variable in self.variables:
arr = variable[index]
vs[variable.natural_name] = wt_units.converter(arr, variable.units, self.units)
return numexpr.evaluate(self.expression.split("=")[0], local_dict=vs)
def __repr__(self) -> str:
return "<WrightTools.Axis {0} ({1}) at {2}>".format(
self.expression, str(self.units), id(self)
)
@property
def _leaf(self):
out = self.expression
if self.units is not None:
out += " ({0})".format(self.units)
out += " {0}".format(self.shape)
return out
@property
def full(self) -> np.ndarray:
"""Axis expression evaluated and repeated to match the shape of the parent data object."""
arr = self[:]
for i in range(arr.ndim):
if arr.shape[i] == 1:
arr = np.repeat(arr, self.parent.shape[i], axis=i)
return arr
@property
def identity(self) -> str:
"""Complete identifier written to disk in data.attrs['axes']."""
return self.expression + " {%s}" % self.units
@property
def label(self) -> str:
"""A latex formatted label representing axis expression."""
label = self.expression.replace("_", "\\;")
if self.units_kind:
symbol = wt_units.get_symbol(self.units)
if symbol is not None:
for v in self.variables:
vl = "%s_{%s}" % (symbol, v.label)
vl = vl.replace("_{}", "") # label can be empty, no empty subscripts
label = label.replace(v.natural_name, vl)
label += rf"\,\left({wt_units.ureg.Unit(self.units):~}\right)"
label = r"$\mathsf{%s}$" % label
return label
@property
def natural_name(self) -> str:
"""Valid python identifier representation of the expession."""
name = self.expression.strip()
return wt_kit.string2identifier(name, replace=operator_to_identifier)
@property
def ndim(self) -> int:
"""Get number of dimensions."""
try:
assert self._ndim is not None
except (AssertionError, AttributeError):
self._ndim = self.variables[0].ndim
finally:
return self._ndim
@property
def points(self) -> np.ndarray:
"""Squeezed array."""
return np.squeeze(self[:])
@property
def shape(self) -> tuple:
"""Shape."""
return wt_kit.joint_shape(*self.variables)
@property
def size(self) -> int:
"""Size."""
return functools.reduce(operator.mul, self.shape)
@property
def units(self):
return self._units
@units.setter
def units(self, value):
if value == "None":
value = None
if value is not None and value not in wt_units.ureg:
raise ValueError(f"'{value}' is not in the unit registry")
self._units = value
@property
def units_kind(self) -> str:
"""Units kind."""
return wt_units.kind(self.units)
@property
def variables(self) -> list:
"""Variables."""
try:
assert self._variables is not None
except (AssertionError, AttributeError):
pattern = "|".join(map(re.escape, operators))
keys = re.split(pattern, self.expression)
indices = []
for key in keys:
if key in self.parent.variable_names:
indices.append(self.parent.variable_names.index(key))
self._variables = [self.parent.variables[i] for i in indices]
finally:
return self._variables
@property
def masked(self) -> np.ndarray:
"""Axis expression evaluated, and masked with NaN shared from data channels."""
arr = self[:]
arr.shape = self.shape
arr = wt_kit.share_nans(arr, *self.parent.channels)[0]
return np.nanmean(
arr, keepdims=True, axis=tuple(i for i in range(self.ndim) if self.shape[i] == 1)
)
[docs]
def convert(self, destination_units, *, convert_variables=False):
"""Convert axis to destination_units.
Parameters
----------
destination_units : string
Destination units.
convert_variables : boolean (optional)
Toggle conversion of stored arrays. Default is False.
"""
if self.units is None and (destination_units is None or destination_units == "None"):
return
if not wt_units.is_valid_conversion(self.units, destination_units):
valid = wt_units.get_valid_conversions(self.units)
raise wt_exceptions.UnitsError(valid, destination_units)
if convert_variables:
for v in self.variables:
v.convert(destination_units)
self.units = destination_units
self.parent._on_axes_updated()
[docs]
def max(self):
"""Axis max."""
return np.nanmax(self[:])
[docs]
def min(self):
"""Axis min."""
return np.nanmin(self[:])