"""Quick plotting."""
# --- import --------------------------------------------------------------------------------------
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from contextlib import closing
from functools import reduce
from ._helpers import _title, create_figure, plot_colorbar, savefig
from .. import kit as wt_kit
# --- define --------------------------------------------------------------------------------------
__all__ = ["quick1D", "quick2D"]
# --- general purpose plotting functions ----------------------------------------------------------
[docs]
def quick1D(
    data,
    axis=0,
    at=None,
    channel=0,
    *,
    local=False,
    autosave=False,
    save_directory=None,
    fname=None,
    verbose=True,
):
    """Quickly plot 1D slice(s) of data.

    Parameters
    ----------
    data : WrightTools.Data object
        Data to plot.
    axis : string or integer (optional)
        Expression or index of axis. Default is 0.
    at : dictionary (optional)
        Dictionary of parameters in non-plotted dimension(s). If not
        provided, plots will be made at each coordinate. Default is None,
        which is treated as an empty dictionary.
    channel : string or integer (optional)
        Name or index of channel to plot. Default is 0.
    local : boolean (optional)
        Toggle plotting locally. If False, every figure shares y limits
        derived from the minimum and maximum of the whole channel.
        Default is False.
    autosave : boolean (optional)
        Toggle autosave. Default is False. Saving is forced when more than
        10 figures are expected.
    save_directory : string (optional)
        Location to save image(s). Default is None (auto-generated).
    fname : string (optional)
        File name. If None, data name is used. Default is None.
    verbose : boolean (optional)
        Toggle talkback. Default is True.

    Returns
    -------
    list of strings
        List of saved image files (if any).
    """
    # build a fresh dictionary per call rather than sharing one mutable default
    at = {} if at is None else at
    # channel index
    channel_index = wt_kit.get_index(data.channel_names, channel)
    global_channel = data.channels[channel_index]
    shape = global_channel.shape
    # remove dimensions that do not involve the channel
    channel_slice = [0 if size == 1 else slice(None) for size in shape]
    sliced_constants = [
        data.axis_expressions[i] for i in range(len(shape)) if not channel_slice[i]
    ]
    # determine saving, prepare for saving
    removed_shape = data._chop_prep(axis, at=at)[0]
    # number of figures that will be generated
    len_chopped = reduce(int.__mul__, removed_shape) // reduce(int.__mul__, shape)
    if len_chopped > 10 and not autosave:
        print(f"expecting {len_chopped} figures. Forcing autosave.")
        autosave = True
    if autosave:
        save_directory, filepath_seed = _filepath_seed(
            save_directory, fname if fname else data.natural_name, len_chopped, "quick1D"
        )
        # the directory may already exist (it can be the working directory, or a
        # folder _filepath_seed just created), which must not be an error
        save_directory.mkdir(parents=True, exist_ok=True)
    # global y limits are the same for every figure: compute them once, and only
    # when they are going to be used
    if not local:
        global_min, global_max = global_channel.min(), global_channel.max()
        padding = (global_max - global_min) * 0.05
        ymin = global_min - padding
        ymax = global_max + padding
        # do not let the padding push a limit across zero
        if np.sign(ymin) != np.sign(global_min):
            ymin = 0
        if np.sign(ymax) != np.sign(global_max):
            ymax = 0
    out = []
    # prepare data
    with closing(data._from_slice(channel_slice)) as sliced:
        for constant in sliced_constants:
            sliced.remove_constant(constant)
        # chew through image generation
        for i, d in enumerate(sliced.ichop(axis, at=at)):
            # unpack data (names kept distinct from the `axis`/`channel` arguments)
            plot_axis = d.axes[0]
            xi = plot_axis.full
            plot_channel = d.channels[channel_index]
            zi = plot_channel[:]
            # create figure
            aspects = [[[0, 0], 0.5]]
            fig, gs = create_figure(width="single", nrows=1, cols=[1], aspects=aspects)
            ax = plt.subplot(gs[0, 0])
            # plot
            plt.plot(xi, zi, lw=2)
            plt.scatter(xi, zi, color="grey", alpha=0.5, edgecolor="none")
            # decoration
            plt.grid()
            # limits
            if not local:
                plt.ylim(ymin, ymax)
            # label axes
            ax.set_xlabel(plot_axis.label, fontsize=18)
            ax.set_ylabel(plot_channel.natural_name, fontsize=18)
            plt.xticks(rotation=45)
            plt.axvline(0, lw=2, c="k")
            plt.xlim(xi.min(), xi.max())
            # constants: variable marker lines, title
            labels = []
            for constant in d.constants:
                if constant.expression in sliced_constants:
                    # ignore these constants; no relation to the data
                    continue
                labels.append(constant.label)
                if constant.units and (plot_axis.units_kind == constant.units_kind):
                    constant.convert(plot_axis.units)
                    plt.axvline(constant.value, color="k", linewidth=4, alpha=0.25)
            _title(fig, data.natural_name, subtitle=", ".join(labels))
            # save
            if autosave:
                # the seed only holds the file name; place it in the save directory
                filepath = str(save_directory / filepath_seed.format(i))
                savefig(filepath, fig=fig, facecolor="white")
                # close this figure specifically, never an unrelated current figure
                plt.close(fig)
                if verbose:
                    print("image saved at", filepath)
                out.append(filepath)
    return out
[docs]
def quick2D(
    data,
    xaxis=0,
    yaxis=1,
    at=None,
    channel=0,
    *,
    cmap=None,
    contours=0,
    pixelated=True,
    dynamic_range=False,
    local=False,
    contours_local=True,
    autosave=False,
    save_directory=None,
    fname=None,
    verbose=True,
):
    """Quickly plot 2D slice(s) of data.

    Parameters
    ----------
    data : WrightTools.Data object.
        Data to plot.
    xaxis : string or integer (optional)
        Expression or index of horizontal axis. Default is 0.
    yaxis : string or integer (optional)
        Expression or index of vertical axis. Default is 1.
    at : dictionary (optional)
        Dictionary of parameters in non-plotted dimension(s). If not
        provided, plots will be made at each coordinate. Default is None,
        which is treated as an empty dictionary.
    cmap : Colormap
        Colormap to use. If None, will use "default" or "signed" depending on channel values.
    channel : string or integer (optional)
        Name or index of channel to plot. Default is 0.
    contours : integer (optional)
        The number of black contour lines to add to the plot. Default is 0.
    pixelated : boolean (optional)
        Toggle between pcolor and contourf (Delaunay) plotting backends.
        Default is True (pcolor).
    dynamic_range : boolean (optional)
        Force the colorbar to use all of its colors. Only changes behavior
        for signed channels. Default is False.
    local : boolean (optional)
        Toggle plotting locally. Default is False.
    contours_local : boolean (optional)
        Toggle plotting black contour lines locally. Default is True.
    autosave : boolean (optional)
        Toggle autosave. Default is False when the number of plots is 10 or less.
        When the number of plots is greater than 10, saving is forced.
    save_directory : string (optional)
        Location to save image(s). Default is None (auto-generated).
    fname : string (optional)
        File name. If None, data name is used. Default is None.
    verbose : boolean (optional)
        Toggle talkback. Default is True.

    Returns
    -------
    list of strings
        List of saved image files (if any).
    """
    # build a fresh dictionary per call rather than sharing one mutable default
    at = {} if at is None else at
    # channel index
    channel_index = wt_kit.get_index(data.channel_names, channel)
    global_channel = data.channels[channel_index]
    shape = global_channel.shape
    # remove axes that are independent of channel
    channel_slice = [0 if size == 1 else slice(None) for size in shape]
    sliced_constants = [
        data.axis_expressions[i] for i in range(len(shape)) if not channel_slice[i]
    ]
    # determine saving, prepare for saving
    removed_shape = data._chop_prep(xaxis, yaxis, at=at)[0]
    # number of figures that will be generated
    len_chopped = reduce(int.__mul__, removed_shape) // reduce(int.__mul__, shape)
    if len_chopped > 10 and not autosave:
        print(f"expecting {len_chopped} figures. Forcing autosave.")
        autosave = True
    if autosave:
        save_directory, filepath_seed = _filepath_seed(
            save_directory, fname if fname else data.natural_name, len_chopped, "quick2D"
        )
        # the directory may already exist (it can be the working directory, or a
        # folder _filepath_seed just created), which must not be an error
        save_directory.mkdir(parents=True, exist_ok=True)
    # only forward a colormap when the caller asked for one
    kwargs = {}
    if cmap:
        kwargs["cmap"] = cmap
    out = []
    with closing(data._from_slice(channel_slice)) as sliced:
        for constant in sliced_constants:
            sliced.remove_constant(constant)
        for i, d in enumerate(sliced.ichop(xaxis, yaxis, at=at)):
            # unpack data (names kept distinct from the function arguments)
            x_axis = d.axes[0]
            xlim = x_axis.min(), x_axis.max()
            y_axis = d.axes[1]
            ylim = y_axis.min(), y_axis.max()
            plot_channel = d.channels[channel_index]
            # create figure
            if x_axis.units == y_axis.units:
                # axes share units: use the true aspect ratio, kept within [1/3, 3]
                xr = xlim[1] - xlim[0]
                yr = ylim[1] - ylim[0]
                aspect = np.abs(yr / xr)
                if 3 < aspect or aspect < 1 / 3.0:
                    # TODO: raise warning here
                    aspect = np.clip(aspect, 1 / 3.0, 3.0)
            else:
                aspect = 1
            fig, gs = create_figure(
                width="single", nrows=1, cols=[1, "cbar"], aspects=[[[0, 0], aspect]]
            )
            ax = plt.subplot(gs[0])
            # colors
            levels = determine_levels(plot_channel, global_channel, dynamic_range, local)
            if pixelated:
                ax.pcolor(d, channel=channel_index, vmin=levels.min(), vmax=levels.max(), **kwargs)
            else:
                ax.contourf(d, channel=channel_index, levels=levels, **kwargs)
            # contour lines
            if contours:
                # determine_contour_levels needs both the number of contours and
                # the local/global toggle
                contour_levels = determine_contour_levels(
                    plot_channel, global_channel, contours, contours_local
                )
                ax.contour(d, channel=channel_index, levels=contour_levels)
            # decoration
            plt.xticks(rotation=45, fontsize=14)
            plt.yticks(fontsize=14)
            ax.set_xlabel(x_axis.label, fontsize=18)
            ax.set_ylabel(y_axis.label, fontsize=18)
            ax.grid()
            # lims
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)
            # add zero lines
            plt.axvline(0, lw=2, c="k")
            plt.axhline(0, lw=2, c="k")
            # constants: variable marker lines, title
            labels = []
            for constant in d.constants:
                if constant.expression in sliced_constants:
                    # ignore these constants; no relation to the data
                    continue
                labels.append(constant.label)
                if constant.units:
                    # x axis
                    if x_axis.units_kind == constant.units_kind:
                        constant.convert(x_axis.units)
                        plt.axvline(constant.value, color="k", linewidth=4, alpha=0.25)
                    # y axis
                    if y_axis.units_kind == constant.units_kind:
                        constant.convert(y_axis.units)
                        plt.axhline(constant.value, color="k", linewidth=4, alpha=0.25)
            _title(fig, data.natural_name, subtitle=", ".join(labels))
            # colorbar
            cax = plt.subplot(gs[1])
            cbar_ticks = np.linspace(levels.min(), levels.max(), 11)
            plot_colorbar(cax=cax, ticks=cbar_ticks, label=plot_channel.natural_name, **kwargs)
            plt.sca(ax)
            # save figure
            if autosave:
                # the seed only holds the file name; place it in the save directory
                filepath = str(save_directory / filepath_seed.format(i))
                savefig(filepath, fig=fig, facecolor="white")
                # close this figure specifically, never an unrelated current figure
                plt.close(fig)
                if verbose:
                    print("image saved at", filepath)
                out.append(filepath)
    return out
def determine_levels(local_channel, global_channel, dynamic_range, local):
    """Determine the color levels for one 2D slice.

    Parameters
    ----------
    local_channel : channel object
        Channel of the slice being plotted. Its ``signed`` attribute selects
        between the symmetric (signed) and one-sided (unsigned) scales.
    global_channel : channel object
        The same channel over the full data set.
    dynamic_range : boolean
        Signed, non-local scales only: if True, the half-width of the scale
        is the smaller of the distances from the global null to the global
        minimum and maximum, instead of ``global_channel.mag()``.
    local : boolean
        If True the scale is derived from ``local_channel``, otherwise from
        ``global_channel``.

    Returns
    -------
    numpy.ndarray
        200 evenly spaced levels in increasing order.
    """
    n_levels = 200
    if local_channel.signed:
        # symmetric scale centered on the null of the plotted slice
        if local:
            limit = local_channel.mag()
        elif dynamic_range:
            limit = min(
                abs(global_channel.null - global_channel.min()),
                abs(global_channel.null - global_channel.max()),
            )
        else:
            limit = global_channel.mag()
        return np.linspace(-limit + local_channel.null, limit + local_channel.null, n_levels)
    # unsigned: the scale runs between null and the extreme of the reference channel
    reference = local_channel if local else global_channel
    null = reference.null
    upper = reference.max()
    if upper < null:
        # the whole channel lies below null: span min..null so that the levels
        # stay increasing and actually cover the data
        return np.linspace(reference.min(), null, n_levels)
    return np.linspace(null, upper, n_levels)
def determine_contour_levels(local_channel, global_channel, contours, local):
    """Determine the levels of the contour lines for one 2D slice.

    Parameters
    ----------
    local_channel : channel object
        Channel of the slice being plotted. Its ``signed`` attribute selects
        between the symmetric (signed) and one-sided (unsigned) ranges, and
        its ``null`` anchors the range.
    global_channel : channel object
        The same channel over the full data set.
    contours : integer
        Number of contour levels to return.
    local : boolean
        If True the range is derived from ``local_channel``, otherwise from
        ``global_channel``.

    Returns
    -------
    numpy.ndarray
        ``contours`` evenly spaced levels in increasing order, strictly
        inside the range (the two end points are dropped).
    """
    reference = local_channel if local else global_channel
    null = local_channel.null
    if local_channel.signed:
        limit = reference.mag()
        lower, upper = -limit + null, limit + null
    else:
        lower, upper = null, reference.max()
        if upper < null:
            # the whole channel lies below null: span min..null so that the
            # levels stay increasing, as contour plotting requires
            lower, upper = reference.min(), null
    # force top and bottom contour to be data range then clip them out
    return np.linspace(lower, upper, contours + 2)[1:-1]
def _filepath_seed(save_directory, fname, nchops, artist):
    """Determine the autosave directory and the file name template.

    Parameters
    ----------
    save_directory : string, path-like or None
        Requested save location. None selects the current working directory.
    fname : string
        Base name of the image file(s).
    nchops : integer
        Number of images that will be saved. When greater than 1, a folder
        named ``"<artist> <timestamp>"`` is created inside the save location
        and returned as the save directory.
    artist : string
        Name of the calling artist; used in the folder name.

    Returns
    -------
    tuple (pathlib.Path, string)
        The save directory, and a file name template (no directory part)
        whose single format field is the image index, zero padded to 3 digits.
    """
    if save_directory is None:
        save_directory = pathlib.Path.cwd()
    else:
        # accepts strings as well as pathlib.Path / path-like objects
        save_directory = pathlib.Path(save_directory)
    # create a folder if multiple images
    if nchops > 1:
        save_directory = save_directory / f"{artist} {wt_kit.TimeStamp().path}"
        # tolerate a save location that does not exist yet, and a folder that
        # already exists (e.g. two calls producing the same timestamp)
        save_directory.mkdir(parents=True, exist_ok=True)
    return save_directory, fname + " {0:0>3}.png"