"""Join multiple data objects together."""
# --- import --------------------------------------------------------------------------------------
import collections
import warnings
import numpy as np
from .. import kit as wt_kit
from .. import exceptions as wt_exceptions
from ._data import Data
from ..collection import Collection
# --- define --------------------------------------------------------------------------------------
__all__ = ["join"]
# --- functions -----------------------------------------------------------------------------------
[docs]
def join(
    datas, *, atol=None, rtol=None, name="join", parent=None, method="first", verbose=True
) -> Data:
    """Join a list of data objects into one data object.

    The underlying dataset arrays are merged.

    Joined datas must have the same axes and axes order.
    The axes define the array structure for the joined dataset.
    As such, each axis must

    * map to a single Variable
    * project along one or no dimension of the data shape (i.e. axis shapes should have
      no more than one dimension with shape greater than 1)
    * be orthogonal to all other axes

    Join does not perform any interpolation.
    For that, look to ``Data.map_variable`` or ``Data.heal``

    Parameters
    ----------
    datas : list of data or WrightTools.Collection
        The list or collection of data objects to join together. Data must have matching axes.
    atol : numeric or list of numeric
        The absolute tolerance to use (in ``np.isclose``) to consider points overlapped.
        If given as a single number, applies to all axes.
        If given as a list, must have same length as the data transformation.
        ``None`` in the list invokes default behavior.
        Default is 10% of the minimum spacing between consecutive points in any
        input data file (``0`` if no input has more than one point along the axis).
    rtol : numeric or list of numeric
        The relative tolerance to use (in ``np.isclose``) to consider points overlapped.
        If given as a single number, applies to all axes.
        If given as a list, must have same length as the data transformation.
        ``None`` in the list invokes default behavior.
        Default is ``4 * np.finfo(dtype).resolution`` for floating point types,
        ``0`` for integer types.
    name : str (optional)
        The name for the data object which is created. Default is 'join'.
    parent : WrightTools.Collection (optional)
        The location to place the joined data object. Default is new temp file at root.
    method : {'first', 'last', 'min', 'max', 'mean'}
        Mode to use for merged points in the joined space.
        Default is 'first'.
    verbose : bool (optional)
        Toggle talkback. Default is True.

    Returns
    -------
    WrightTools.data.Data
        A new Data instance.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported modes.
    WrightTools.exceptions.ValueError
        If ``datas`` is empty, if the datas do not share the same axis expressions,
        or if an axis of the first data is built from more than one variable.
    WrightTools.exceptions.MultidimensionalAxisError
        If an axis variable spans more than one dimension.
    """
    warnings.warn("join", category=wt_exceptions.EntireDatasetInMemoryWarning)
    if isinstance(datas, Collection):
        datas = datas.values()
    valid_methods = ["first", "last", "min", "max", "mean"]
    if method not in valid_methods:
        if method == "sum":
            raise ValueError("method 'sum' is deprecated; consider 'mean' instead.")
        raise ValueError(f"invalid method {method!r}: expected {valid_methods}")
    datas = list(datas)
    if not datas:
        # Everything below is defined relative to datas[0]; fail with a clear message
        # instead of an IndexError.
        raise wt_exceptions.ValueError("join requires at least one data object")
    # A scalar (or None) tolerance applies to every axis
    if not isinstance(atol, collections.abc.Iterable):
        atol = [atol] * len(datas[0].axes)
    if not isinstance(rtol, collections.abc.Iterable):
        rtol = [rtol] * len(datas[0].axes)
    # check if variables are valid
    axis_expressions = datas[0].axis_expressions
    # Only variables and channels present in *every* data are carried into the output
    variable_names = set(datas[0].variable_names)
    channel_names = set(datas[0].channel_names)
    for d in datas[1:]:
        if d.axis_expressions != axis_expressions:
            raise wt_exceptions.ValueError("Joined data must have same axis_expressions")
        for a in d.axes:
            if a.variables[0][:].squeeze().ndim > 1:
                raise wt_exceptions.MultidimensionalAxisError(a.natural_name, "join")
        variable_names &= set(d.variable_names)
        channel_names &= set(d.channel_names)
    variable_names = list(variable_names)
    channel_names = list(channel_names)
    axis_variable_names = []
    axis_variable_units = []
    for a in datas[0].axes:
        if len(a.variables) > 1:
            raise wt_exceptions.ValueError("Applied transform must have single variable axes")
        if a.variables[0][:].squeeze().ndim > 1:
            raise wt_exceptions.MultidimensionalAxisError(a.natural_name, "join")
        for v in a.variables:
            axis_variable_names.append(v.natural_name)
            axis_variable_units.append(v.units)
    # Build the joined coordinate array for each axis variable, in axis order
    vs = collections.OrderedDict()
    for n, units, atol_, rtol_ in zip(axis_variable_names, axis_variable_units, atol, rtol):
        dtype = np.result_type(*[d[n].dtype for d in datas])
        if atol_ is None:
            # 10% of the minimum spacing between consecutive points in any single input data.
            # The arrays are flattened first: np.diff works along the last array axis by
            # default, so a variable whose points lie along any other dimension
            # (e.g. shape (N, 1)) would otherwise yield an empty difference.
            spacings = [
                np.min(np.abs(np.diff(d[n][:].ravel()))) for d in datas if d[n].size > 1
            ]
            # No data has more than one point along this axis: no spacing to derive from
            atol_ = min(spacings) * 0.1 if spacings else 0
        if rtol_ is None:
            # Ignore floating point precision rounding, if dtype is floating
            # NOTE(review): np.finfo only accepts inexact (float/complex) dtypes; the
            # 'm'/'M' kinds listed here would presumably raise -- confirm whether
            # datetime/timedelta axes are expected to reach this point.
            rtol_ = 4 * np.finfo(dtype).resolution if dtype.kind in "fcmM" else 0
        values = np.concatenate([d[n][:].flat for d in datas])
        values = np.sort(values)
        filtered = []
        i = 0
        # Filter out consecutive values that are "close":
        # each run of close neighbours is replaced by its mean
        while i < len(values):
            sum_ = values[i]
            n_close = 1
            i += 1
            if i < len(values):
                while np.isclose(values[i - 1], values[i], atol=atol_, rtol=rtol_):
                    sum_ += values[i]
                    n_close += 1
                    i += 1
                    if i >= len(values):
                        break
            filtered.append(sum_ / n_close)
        vs[n] = {"values": np.array(filtered), "units": units}

    # TODO: the following should become a new from method
    def from_dict(d, parent=None):
        """Create the output Data with one orthogonal variable per entry of ``d``."""
        ndim = len(d)
        out = Data(name=name, parent=parent)
        for i, (k, v) in enumerate(d.items()):
            values = v["values"]
            # The i-th variable varies only along the i-th dimension
            shape = [1] * ndim
            shape[i] = values.size
            values.shape = tuple(shape)
            # **attrs passes the name and units as well
            out.create_variable(values=values, **datas[0][k].attrs)
        out.transform(*list(d.keys()))
        return out

    def get_shape(out, datas, item_name):
        """Return the shape ``item_name`` needs in ``out``.

        A dimension takes the full output size if the item varies along the matching
        axis in any input data, or if that axis is a single point in every input data;
        otherwise it is left at 1.
        """
        shape = [1] * out.ndim
        for i, s in enumerate(out.shape):
            # For each data, the dimension along which the i-th axis runs
            idx = [np.argmax(d[out.axes[i].expression].shape) for d in datas]
            if any(d[item_name].shape[j] != 1 for d, j in zip(datas, idx)) or all(
                d[out.axes[i].expression].size == 1 for d in datas
            ):
                shape[i] = s
        return shape

    out = from_dict(vs, parent=parent)
    # Per-item record of how many times each output point has been written
    count = {}
    for channel_name in channel_names:
        # **attrs passes the name and units as well
        out.create_channel(
            shape=get_shape(out, datas, channel_name),
            **datas[0][channel_name].attrs,
            dtype=datas[0][channel_name].dtype,
        )
        count[channel_name] = np.zeros_like(out[channel_name], dtype=int)
    for variable_name in variable_names:
        if variable_name not in vs.keys():
            # **attrs passes the name and units as well
            out.create_variable(
                shape=get_shape(out, datas, variable_name),
                **datas[0][variable_name].attrs,
                dtype=datas[0][variable_name].dtype,
            )
            count[variable_name] = np.zeros_like(out[variable_name], dtype=int)

    def combine(data, out, item_name, new_idx, transpose, slice_):
        """Merge ``data[item_name]`` into ``out[item_name]`` according to ``method``."""
        old = data[item_name]
        new = out[item_name]
        vals = np.empty_like(new)
        # Default fill value based on whether dtype is floating or not
        if vals.dtype.kind == "f":
            vals[:] = np.nan
        elif vals.dtype.kind == "M":
            vals[:] = np.datetime64("NaT")
        elif vals.dtype.kind == "c":
            vals[:] = complex(np.nan, np.nan)
        else:
            vals[:] = 0
        # Use advanced indexing to populate vals, a temporary array with same shape as out
        valid_index = tuple(wt_kit.valid_index(new_idx, new.shape))
        vals[valid_index] = old[:].transpose(transpose)[slice_]
        # Overlap methods are accomplished by adding the existing array with the one added
        # for this particular data. Thus locations which should be set, but conflict by
        # the method chosen are set to 0. Handling for floating point vs. integer types may vary.
        # For floating types, nan indicates invalid, and must be explicitly allowed to add in.
        if method == "first":
            # Set any locations which have already been populated
            vals[~np.isnan(new[:])] = 0
            if vals.dtype.kind not in "fcmM":
                vals[count[item_name] > 0] = 0
        elif method == "last":
            # Reset points which are to be overwritten
            new[~np.isnan(vals)] = 0
            if vals.dtype.kind not in "fcmM":
                new[valid_index] = 0
        elif method == "min":
            rep_new = new > vals
            rep_vals = vals > new
            new[rep_new] = 0
            vals[rep_vals] = 0
        elif method == "max":
            rep_new = new < vals
            rep_vals = vals < new
            new[rep_new] = 0
            vals[rep_vals] = 0
        # Ensure that previously NaN points which have values are written
        new[np.isnan(new) & ~np.isnan(vals)] = 0
        # Ensure that new data does not overwrite any previous data with nan
        vals[np.isnan(vals)] = 0
        # Track how many times each point is set (for mean)
        count[item_name][valid_index] += 1
        new[:] += vals

    for data in datas:
        new_idx = []
        transpose = []
        slice_ = []
        for variable_name in vs.keys():
            # p is at most 1-D by precondition to join
            p = data[variable_name].points
            # If p not scalar, append the proper transposition to interop with out
            # And do not add new axis
            if np.ndim(p) > 0:
                transpose.append(np.argmax(data[variable_name].shape))
                slice_.append(slice(None))
            # If p is scalar, a new axis must be added, no transpose needed
            else:
                slice_.append(np.newaxis)
            # Triple subscripting needed because newaxis only applies to numpy array
            # New axis added so that subtracting p will broadcast
            arr = out[variable_name][:][..., np.newaxis]
            # For every point of this data, the index of the nearest joined coordinate
            i = np.argmin(np.abs(arr - p), axis=np.argmax(arr.shape))
            # Reshape i, to match with the output shape
            sh = [1] * i.ndim
            sh[np.argmax(arr.shape)] = i.size
            i.shape = sh
            new_idx.append(i)
        slice_ = tuple(slice_)
        for variable_name in out.variable_names:
            if variable_name not in vs.keys():
                combine(data, out, variable_name, new_idx, transpose, slice_)
        for channel_name in channel_names:
            combine(data, out, channel_name, new_idx, transpose, slice_)
    if method == "mean":
        # combine() accumulated sums; divide by the number of contributions per point.
        # The loop variable is deliberately not called ``name`` so the ``name`` argument
        # (also read by from_dict) is not rebound.
        for item_name, c in count.items():
            if out[item_name].dtype.kind in "fcmM":
                out[item_name][:] /= c
            else:
                out[item_name][:] //= c
    out.transform(*axis_expressions)
    if verbose:
        print(len(datas), "datas joined to create new data:")
        print(" axes:")
        for axis in out.axes:
            points = axis[:]
            print(
                " {0} : {1} points from {2} to {3} {4}".format(
                    axis.expression, points.size, np.min(points), np.max(points), axis.units
                )
            )
        print(" channels:")
        for channel in out.channels:
            percent_nan = np.around(
                100.0 * (np.isnan(channel[:]).sum() / float(channel.size)), decimals=2
            )
            print(
                " {0} : {1} to {2} ({3}% NaN)".format(
                    channel.name, channel.min(), channel.max(), percent_nan
                )
            )
    return out