Source code for WrightTools.artists._base

"""Base tools for visualizing data."""

# --- import --------------------------------------------------------------------------------------

import numpy as np

import matplotlib
from matplotlib.projections import register_projection
from matplotlib.colors import Normalize
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

from .. import exceptions as wt_exceptions
from .. import kit as wt_kit
from import Data
from ._colors import colormaps

# --- define -------------------------------------------------------------------------------------

# public names exported by ``from ... import *``
__all__ = [
    "Axes",
    "Figure",
    "GridSpec",
    "apply_rcparams",
]

# --- classes -------------------------------------------------------------------------------------

class Axes(matplotlib.axes.Axes):
    """Matplotlib ``Axes`` that accepts WrightTools ``Data`` objects.

    The wrapped plotting methods (``contour``, ``contourf``, ``imshow``, ``pcolor``,
    ``pcolormesh``, ``plot``, ``scatter``) accept a ``Data`` object as their first
    positional argument; otherwise they defer to the matplotlib parent method.
    ``name`` is the projection name that ``Figure.add_subplot`` requests by default.
    """

    name = "wright"
    # True for side axes attached along "y" by ``add_sideplot``
    transposed = False
    # True for axes created by ``add_sideplot``
    is_sideplot = False

    def _apply_labels(
        self, autolabel="none", xlabel=None, ylabel=None, data=None, channel_index=0
    ):
        """Apply x and y labels to axes.

        Parameters
        ----------
        autolabel : {'none', 'both', 'xy', 'x', 'y'} (optional)
            Label(s) to read from ``data``. Explicit ``xlabel``/``ylabel`` take
            precedence. Default is 'none'.
        xlabel : string or bool (optional)
            x label. ``True`` reads the label of ``data.axes[0]``. Default is None.
        ylabel : string or bool (optional)
            y label. ``True`` reads the label of ``data.axes[1]``. Default is None.
        data : object (optional)
            Data to read labels from; required whenever a label is read from data.
            Default is None.
        channel_index : integer (optional)
            Channel whose label becomes the y label of 1D data. Default is 0.
        """
        # read from data
        if autolabel in ["xy", "both", "x"] and not xlabel:
            xlabel = data.axes[0].label
        if autolabel in ["xy", "both", "y"] and not ylabel:
            if data.ndim == 1:
                ylabel = data.channels[channel_index].label
            elif data.ndim == 2:
                ylabel = data.axes[1].label
        # apply
        if xlabel:
            if isinstance(xlabel, bool):
                xlabel = data.axes[0].label
            self.set_xlabel(xlabel, fontsize=18)
        if ylabel:
            if isinstance(ylabel, bool):
                ylabel = data.axes[1].label
            self.set_ylabel(ylabel, fontsize=18)

    def _parse_limits(self, zi=None, data=None, channel_index=None, dynamic_range=False, **kwargs):
        """Fill in default ``vmin``/``vmax`` entries of plotting kwargs.

        Parameters
        ----------
        zi : array (optional)
            Values being plotted. When given, limits come from explicit ``levels``
            values (if present in kwargs) or from the nan-ignoring min/max of ``zi``.
        data : object (optional)
            Data object; used when ``zi`` is None.
        channel_index : integer (optional)
            Channel of ``data`` to read limits from.
        dynamic_range : boolean (optional)
            For signed channels, use the minor rather than the major extent.
        **kwargs
            Plotting kwargs. User-supplied ``vmin``/``vmax`` are never overwritten.

        Returns
        -------
        dict
            ``kwargs``; returned untouched if it contains ``norm``.
        """
        if "norm" in kwargs:
            # an explicit norm already defines the color limits
            return kwargs
        if zi is not None:
            levels = kwargs.get("levels")
            if levels is not None and np.ndim(levels) > 0:
                # explicit level values (array or list) bound the color range;
                # an integer level count carries no limits
                vmin = np.min(levels)
                vmax = np.max(levels)
            else:
                vmin = np.nanmin(zi)
                vmax = np.nanmax(zi)
        elif data is not None:
            channel = data.channels[channel_index]
            null = channel.null
            if channel.signed and dynamic_range:
                vmin = -channel.minor_extent + null
                vmax = +channel.minor_extent + null
            elif channel.signed:
                vmin = -channel.major_extent + null
                vmax = +channel.major_extent + null
            else:
                vmin = null
                vmax = channel.max()
        # don't overwrite
        if "vmin" not in kwargs:
            kwargs["vmin"] = vmin
        if "vmax" not in kwargs:
            kwargs["vmax"] = vmax
        return kwargs

    def _parse_plot_args(self, *args, **kwargs):
        """Translate wrapper arguments into arguments for a matplotlib 2D plot method.

        The required keyword ``plot_type`` names the matplotlib method the result is
        meant for. If the first positional argument is a ``Data`` object, its
        channel and first two axes are unpacked into arrays; otherwise positional
        arguments are passed through. Default color limits, contour levels,
        colormap, labels and face color are filled in.

        Returns
        -------
        tuple
            ``(args, kwargs)`` for the matplotlib parent method.

        Raises
        ------
        NotImplementedError
            If ``plot_type`` is not supported.
        wt_exceptions.ValueError
            If an axis spans a dimension that the channel does not.
        wt_exceptions.DimensionalityError
            If the channel is not 2D after squeezing.
        """
        plot_type = kwargs.pop("plot_type")
        if plot_type not in ["pcolor", "pcolormesh", "contourf", "contour", "imshow"]:
            raise NotImplementedError
        args = list(args)  # offer pop, append etc
        dynamic_range = kwargs.pop("dynamic_range", False)
        if isinstance(args[0], Data):
            data = args.pop(0)
            channel = kwargs.pop("channel", 0)
            channel_index = wt_kit.get_index(data.channel_names, channel)
            squeeze = np.array(data.channels[channel_index].shape) == 1
            xa = data.axes[0]
            ya = data.axes[1]
            # the axes may not vary along dimensions the channel does not span
            for sq, xs, ys in zip(squeeze, xa.shape, ya.shape):
                if sq and (xs != 1 or ys != 1):
                    raise wt_exceptions.ValueError("Cannot squeeze axis to fit channel")
            zi = data.channels[channel_index].points
            if not zi.ndim == 2:
                raise wt_exceptions.DimensionalityError(2, zi.ndim)
            # index 0 removes dimensions where the channel has size 1
            squeeze = tuple([0 if i else slice(None) for i in squeeze])
            if plot_type == "imshow":
                kwargs.setdefault("aspect", "auto")
                kwargs.setdefault("origin", "lower")
                if "interpolation" not in kwargs:
                    if max(zi.shape) < 10**3:  # TODO: better decision logic
                        kwargs["interpolation"] = "nearest"
                    else:
                        kwargs["interpolation"] = "antialiased"
                xi = xa[:][squeeze]
                yi = ya[:][squeeze]
                zi = zi.transpose(_order_for_imshow(xi, yi))
                # extract extent
                if "extent" not in kwargs:
                    xlim = [xi[0, 0], xi[-1, -1]]
                    ylim = [yi[0, 0], yi[-1, -1]]
                    # pad by half the point spacing so pixel centers sit on the
                    # points: N points span N - 1 intervals
                    xstep = (xlim[1] - xlim[0]) / (2 * max(xi.size - 1, 1))
                    ystep = (ylim[1] - ylim[0]) / (2 * max(yi.size - 1, 1))
                    kwargs["extent"] = [
                        xlim[0] - xstep,
                        xlim[1] + xstep,
                        ylim[0] - ystep,
                        ylim[1] + ystep,
                    ]
                args = [zi] + args
            else:
                xi = xa.full[squeeze]
                yi = ya.full[squeeze]
                args = [xi, yi, zi] + args
            # limits
            kwargs = self._parse_limits(
                data=data, channel_index=channel_index, dynamic_range=dynamic_range, **kwargs
            )
            if plot_type == "contourf":
                if "levels" not in kwargs:
                    kwargs["levels"] = np.linspace(kwargs["vmin"], kwargs["vmax"], 256)
            elif plot_type == "contour":
                if "levels" not in kwargs:
                    n = 11 if data.channels[channel_index].signed else 6
                    # drop the end points: no lines exactly at the limits
                    kwargs["levels"] = np.linspace(kwargs.pop("vmin"), kwargs.pop("vmax"), n)[1:-1]
                # colors
                if "colors" not in kwargs and "cmap" not in kwargs:
                    kwargs["colors"] = "k"
                kwargs.setdefault("alpha", 0.5)
            if plot_type in ["pcolor", "pcolormesh", "contourf", "imshow"]:
                kwargs = _parse_cmap(data=data, channel_index=channel_index, **kwargs)
        else:
            # matplotlib signatures: imshow(Z, ...); others ([X, Y,] Z, ...),
            # where contour/contourf may also take Z followed by levels
            if plot_type == "imshow" or len(args) < 3:
                zi = args[0]
            else:
                zi = args[2]
            kwargs = self._parse_limits(zi=zi, **kwargs)
            data = None
            channel_index = 0
            if plot_type == "contourf":
                if "levels" not in kwargs:
                    kwargs["levels"] = np.linspace(kwargs["vmin"], kwargs["vmax"], 256)
            if plot_type in ["pcolor", "pcolormesh", "contourf", "imshow"]:
                kwargs = _parse_cmap(**kwargs)
        # labels
        self._apply_labels(
            autolabel=kwargs.pop("autolabel", False),
            xlabel=kwargs.pop("xlabel", None),
            ylabel=kwargs.pop("ylabel", None),
            data=data,
            channel_index=channel_index,
        )
        if plot_type != "contour":
            self.set_facecolor([0.75] * 3)
        if plot_type.startswith("pcolor"):
            kwargs.setdefault("shading", "auto")
        return args, kwargs

    def add_sideplot(self, along, pad=0, height=0.75, ymin=0, ymax=1.1):
        """Add a side axis.

        Parameters
        ----------
        along : {'x', 'y'}
            Axis to add along. 'x' appends an axis on top that shares x; 'y'
            appends an axis on the right that shares y.
        pad : float (optional)
            Side axis pad. Default is 0.
        height : float (optional)
            Side axis height. Default is 0.75.
        ymin, ymax : float (optional)
            Fixed limits of the side axis along its value direction (its y limits
            for 'x', its x limits for 'y'). Defaults are 0 and 1.1.

        Returns
        -------
        Axes
            The new side axis, also stored as ``self.sidex`` or ``self.sidey``.

        Raises
        ------
        ValueError
            If ``along`` is not 'x' or 'y'.
        """
        if along not in ("x", "y"):
            raise ValueError(f"unexpected value for 'along': {along}, expected 'x' or 'y'")
        # divider should only be created once
        if hasattr(self, "divider"):
            divider = self.divider
        else:
            divider = make_axes_locatable(self)
            setattr(self, "divider", divider)
        # create
        if along == "x":
            ax = self.sidex = divider.append_axes("top", height, pad=pad, sharex=self)
            ax.set_ylim(ymin, ymax)
        else:
            ax = self.sidey = divider.append_axes("right", height, pad=pad, sharey=self)
            ax.transposed = True
            ax.set_xlim(ymin, ymax)
        # beautify
        ax.autoscale(enable=False)
        ax.set_adjustable("box")
        ax.is_sideplot = True
        plt.setp(ax.get_xticklabels(), visible=False)
        plt.setp(ax.get_yticklabels(), visible=False)
        ax.tick_params(axis="both", which="both", length=0)
        return ax

    def contour(self, *args, **kwargs):
        """Plot contours.

        If a 3D or higher Data object is passed, a lower dimensional channel can be
        plotted, provided the ``squeeze`` of the channel has ``ndim==2`` and the
        first two axes do not span dimensions other than those spanned by that
        channel.

        Parameters
        ----------
        data : 2D object
            Data to plot.
        channel : int or string (optional)
            Channel index or name. Default is 0.
        dynamic_range : boolean (optional)
            Force plotting of all contours, overloading for major extent. Only
            applies to signed data. Default is False.
        autolabel : {'none', 'both', 'x', 'y'} (optional)
            Parameterize application of labels directly from data object.
            Default is none.
        xlabel : string (optional)
            xlabel. Default is None.
        ylabel : string (optional)
            ylabel. Default is None.
        **kwargs
            Optional keyword arguments passed to ``matplotlib.axes.Axes.contour``.

        Returns
        -------
        matplotlib.contour.QuadContourSet
        """
        args, kwargs = self._parse_plot_args(*args, **kwargs, plot_type="contour")
        return super().contour(*args, **kwargs)

    def contourf(self, *args, **kwargs):
        """Plot filled contours.

        If a 3D or higher Data object is passed, a lower dimensional channel can be
        plotted, provided the ``squeeze`` of the channel has ``ndim==2`` and the
        first two axes do not span dimensions other than those spanned by that
        channel. Contour lines at the midpoints between filled levels are drawn just
        beneath the fill. ``antialiased`` is forced to False and ``extend`` to
        "both".

        Parameters
        ----------
        data : 2D object
            Data to plot.
        channel : int or string (optional)
            Channel index or name. Default is 0.
        dynamic_range : boolean (optional)
            Force plotting of all contours, overloading for major extent. Only
            applies to signed data. Default is False.
        autolabel : {'none', 'both', 'x', 'y'} (optional)
            Parameterize application of labels directly from data object.
            Default is none.
        xlabel : string (optional)
            xlabel. Default is None.
        ylabel : string (optional)
            ylabel. Default is None.
        **kwargs
            Optional keyword arguments passed to ``matplotlib.axes.Axes.contourf``.

        Returns
        -------
        matplotlib.contour.QuadContourSet
        """
        args, kwargs = self._parse_plot_args(*args, **kwargs, plot_type="contourf")
        # Filled contours alias when saved as vector graphics. Setting the edgecolor
        # to 'face' removed the aliasing but subtly shifted the colors relative to
        # the data, so instead contour lines are drawn just below the fill. This
        # reduces (but does not eliminate) aliasing without distorting the data.
        # ---Blaise 2017-07-30
        kwargs["antialiased"] = False
        kwargs["extend"] = "both"
        contours = super().contourf(*args, **kwargs)
        # fill lines
        if hasattr(contours, "get_zorder"):
            # matplotlib >= 3.8: the ContourSet is itself a single Collection
            zorder = contours.get_zorder() - 0.1
        else:
            # older matplotlib: one Collection per level in ``collections``
            zorder = contours.collections[0].zorder - 0.1
        levels = (contours.levels[1:] + contours.levels[:-1]) / 2
        # pass only the coordinate/value arrays ([X, Y,] Z), never positional levels
        xyz = args[:3] if len(args) >= 3 else args[:1]
        matplotlib.axes.Axes.contour(self, *xyz, levels=levels, cmap=contours.cmap, zorder=zorder)
        # decoration
        self.set_facecolor([0.75] * 3)
        return contours

    def legend(self, *args, **kwargs):
        """Add a legend.

        Defaults to ``fancybox=False`` and ``framealpha=1.0`` unless given.

        Parameters
        ----------
        *args
            matplotlib legend args.
        **kwargs
            matplotlib legend kwargs.

        Returns
        -------
        legend
        """
        kwargs.setdefault("fancybox", False)
        kwargs.setdefault("framealpha", 1.0)
        return super().legend(*args, **kwargs)

    def pcolor(self, *args, **kwargs):
        """Create a pseudocolor plot of a 2-D array.

        If a 3D or higher Data object is passed, a lower dimensional channel can be
        plotted, provided the ``squeeze`` of the channel has ``ndim==2`` and the
        first two axes do not span dimensions other than those spanned by that
        channel.

        Defaults to ``shading="auto"`` to ensure that color boundaries are drawn
        bisecting point positions, when possible.

        Parameters
        ----------
        data : 2D object
            Data to plot.
        channel : int or string (optional)
            Channel index or name. Default is 0.
        dynamic_range : boolean (optional)
            Use the channel's minor (rather than major) extent for the color
            limits. Only applies to signed data. Default is False.
        autolabel : {'none', 'both', 'x', 'y'} (optional)
            Parameterize application of labels directly from data object.
            Default is none.
        xlabel : string (optional)
            xlabel. Default is None.
        ylabel : string (optional)
            ylabel. Default is None.
        **kwargs
            Optional keyword arguments passed to ``matplotlib.axes.Axes.pcolor``.

        Returns
        -------
        matplotlib.collections.PolyCollection
        """
        args, kwargs = self._parse_plot_args(*args, **kwargs, plot_type="pcolor")
        return super().pcolor(*args, **kwargs)

    def imshow(self, *args, **kwargs):
        """Create a pseudocolor plot of a 2-D array.

        The array is plotted with uniform spacing. Quicker than pcolor, pcolormesh.
        **Requires that the plotted axes are grid aligned (i.e. the ``squeeze`` of
        each axis has ``ndim==1``).**

        If a 3D or higher Data object is passed, a lower dimensional channel can be
        plotted, provided the ``squeeze`` of the channel has ``ndim==2``.

        Defaults to ``aspect="auto"`` (pixels are stretched to fit the subplot axes)
        and ``origin="lower"``. If ``interpolation`` is not specified, it defaults to
        "antialiased" (for large images) or "nearest" (for small arrays).
        ``extent`` defaults to placing pixel centers on the point positions.
        The axes' orientation (inverted or not) is preserved.

        Parameters
        ----------
        data : 2D object
            Data to plot.
        channel : int or string (optional)
            Channel index or name. Default is 0.
        dynamic_range : boolean (optional)
            Use the channel's minor (rather than major) extent for the color
            limits. Only applies to signed data. Default is False.
        autolabel : {'none', 'both', 'x', 'y'} (optional)
            Parameterize application of labels directly from data object.
            Default is none.
        xlabel : string (optional)
            xlabel. Default is None.
        ylabel : string (optional)
            ylabel. Default is None.
        **kwargs
            Optional keyword arguments passed to ``matplotlib.axes.Axes.imshow``.

        Returns
        -------
        matplotlib.image.AxesImage
        """
        xlim, ylim = super().get_xlim(), super().get_ylim()
        old_signs = [(lim[1] - lim[0]) > 0 for lim in (xlim, ylim)]
        args, kwargs = self._parse_plot_args(*args, **kwargs, plot_type="imshow")
        out = super().imshow(*args, **kwargs)
        # undo axis order if it was flipped
        xlim, ylim = super().get_xlim(), super().get_ylim()
        new_signs = [(lim[1] - lim[0]) > 0 for lim in (xlim, ylim)]
        if old_signs[0] != new_signs[0]:
            super().invert_xaxis()
        if old_signs[1] != new_signs[1]:
            super().invert_yaxis()
        return out

    def scatter(self, *args, **kwargs):
        """Scatter plot a channel against two variables.

        Scatter point color reflects channel values. Data need not be structured.
        If the first positional argument is not a Data object (or there are no
        positional arguments), this reverts to ``matplotlib.axes.Axes.scatter``.

        Parameters
        ----------
        data : 2D object
            Data to plot.
        x : int or string (optional)
            Axis name or index for x (abscissa). Default is 0. If x does not match
            an axis, variable names are searched for a match.
        y : int or string (optional)
            Axis name or index for y (ordinate). Default is 1. If y does not match
            an axis, variable names are searched for a match.
        channel : int or string (optional)
            Channel index or name. Default is 0.
        vmin, vmax : number (optional)
            Color limits. Default to limits derived from the channel.
        cmap : string or colormap (optional)
            Colormap for the points. Default depends on whether the channel is
            signed.
        autolabel : {'none', 'both', 'x', 'y'} (optional)
            Parameterize application of labels directly from data object.
            Default is none.
        xlabel : string (optional)
            xlabel. Default is None.
        ylabel : string (optional)
            ylabel. Default is None.
        **kwargs
            Optional keyword arguments passed to ``matplotlib.axes.Axes.scatter``.

        Returns
        -------
        matplotlib.collections.PathCollection

        Raises
        ------
        KeyError
            If ``c`` is given together with a Data object.
        """
        args = list(args)
        if not (args and isinstance(args[0], Data)):
            return super().scatter(*args, **kwargs)
        data = args.pop(0)
        coords = []
        for axis in [kwargs.pop("x", 0), kwargs.pop("y", 1)]:
            try:  # check axes
                axis = wt_kit.get_index(data.axis_names, axis)
                axis = data.axes[axis][:]
            except (ValueError, IndexError):  # check vars
                axis = wt_kit.get_index(data.variable_names, axis)
                axis = data.variables[axis][:]
            coords.append(axis)
        if "c" in kwargs:
            raise KeyError(
                "'c' kwarg not allowed when data object provided. "
                "Use `cmap` instead to control colors."
            )
        channel = kwargs.pop("channel", 0)
        channel_index = wt_kit.get_index(data.channel_names, channel)
        # colors are computed here, so color kwargs are consumed rather than
        # forwarded to matplotlib (which ignores them when `c` holds RGBA values)
        limits = {key: kwargs.pop(key) for key in ("vmin", "vmax") if key in kwargs}
        limits = self._parse_limits(data=data, channel_index=channel_index, **limits)
        norm = Normalize(**limits)
        cmap_kwargs = {"cmap": kwargs.pop("cmap")} if "cmap" in kwargs else {}
        cmap = _parse_cmap(data, channel_index=channel_index, **cmap_kwargs)["cmap"]
        z = data.channels[channel_index][:]
        # fill x, y, z to joint shape
        shape = wt_kit.joint_shape(z, *coords)

        def full(arr, shape):
            # repeat along every size-1 dimension to reach ``shape``
            for i in range(arr.ndim):
                if arr.shape[i] == 1:
                    arr = np.repeat(arr, shape[i], axis=i)
            return arr

        args = [full(ax, shape).flatten() for ax in coords] + args
        z = full(z, shape).flatten()
        kwargs["c"] = cmap(norm(z))
        self._apply_labels(
            autolabel=kwargs.pop("autolabel", False),
            xlabel=kwargs.pop("xlabel", None),
            ylabel=kwargs.pop("ylabel", None),
            data=data,
            channel_index=channel_index,
        )
        return super().scatter(*args, **kwargs)

    def pcolormesh(self, *args, **kwargs):
        """Create a pseudocolor plot of a 2-D array.

        If a 3D or higher Data object is passed, a lower dimensional channel can be
        plotted, provided the ``squeeze`` of the channel has ``ndim==2`` and the
        first two axes do not span dimensions other than those spanned by that
        channel.

        Defaults to ``shading="auto"`` to ensure that color boundaries are drawn
        bisecting point positions, when possible. Quicker than pcolor.

        Parameters
        ----------
        data : 2D object
            Data to plot.
        channel : int or string (optional)
            Channel index or name. Default is 0.
        dynamic_range : boolean (optional)
            Use the channel's minor (rather than major) extent for the color
            limits. Only applies to signed data. Default is False.
        autolabel : {'none', 'both', 'x', 'y'} (optional)
            Parameterize application of labels directly from data object.
            Default is none.
        xlabel : string (optional)
            xlabel. Default is None.
        ylabel : string (optional)
            ylabel. Default is None.
        **kwargs
            Optional keyword arguments passed to ``matplotlib.axes.Axes.pcolormesh``.

        Returns
        -------
        matplotlib.collections.QuadMesh
        """
        args, kwargs = self._parse_plot_args(*args, **kwargs, plot_type="pcolormesh")
        return super().pcolormesh(*args, **kwargs)

    def plot(self, *args, **kwargs):
        """Plot lines and/or markers.

        If a 2D or higher Data object is passed, a lower dimensional channel can be
        plotted, provided the ``squeeze`` of the channel has ``ndim==1`` and the
        first axis does not span dimensions other than that spanned by the channel.

        Parameters
        ----------
        data : 1D object
            Data to plot.
        channel : int or string (optional)
            Channel index or name. Default is 0.
        dynamic_range : boolean (optional)
            Accepted for consistency with the 2D methods; has no effect on line
            plots. Default is False.
        autolabel : {'none', 'both', 'x', 'y'} (optional)
            Parameterize application of labels directly from data object.
            Default is none.
        xlabel : string (optional)
            xlabel. Default is None.
        ylabel : string (optional)
            ylabel. Default is None.
        **kwargs
            Optional keyword arguments passed to ``matplotlib.axes.Axes.plot``.

        Returns
        -------
        list
            list of matplotlib.lines.line2D objects
        """
        args = list(args)  # offer pop, append etc
        # documented for API symmetry; matplotlib would reject it
        kwargs.pop("dynamic_range", None)
        # unpack data object, if given
        if args and isinstance(args[0], Data):
            data = args.pop(0)
            channel = kwargs.pop("channel", 0)
            channel_index = wt_kit.get_index(data.channel_names, channel)
            squeeze = np.array(data.channels[channel_index].shape) == 1
            xa = data.axes[0]
            for sq, xs in zip(squeeze, xa.shape):
                if sq and xs != 1:
                    raise wt_exceptions.ValueError("Cannot squeeze axis to fit channel")
            squeeze = tuple([0 if i else slice(None) for i in squeeze])
            zi = data.channels[channel_index].points
            xi = xa[squeeze]
            if not zi.ndim == 1:
                raise wt_exceptions.DimensionalityError(1, zi.ndim)
            args = [xi, zi] + args
        else:
            data = None
            channel_index = 0
        # labels
        self._apply_labels(
            autolabel=kwargs.pop("autolabel", False),
            xlabel=kwargs.pop("xlabel", None),
            ylabel=kwargs.pop("ylabel", None),
            data=data,
            channel_index=channel_index,
        )
        # call parent
        return super().plot(*args, **kwargs)
class Figure(matplotlib.figure.Figure):
    """Matplotlib ``Figure`` whose subplots default to the ``"wright"`` projection."""

    def add_subplot(self, *args, **kwargs):
        """Add a subplot to the figure.

        Unless a ``projection`` keyword is supplied, the subplot is created with
        the ``"wright"`` projection.

        Parameters
        ----------
        *args, **kwargs
            Forwarded to ``matplotlib.figure.Figure.add_subplot``.

        Returns
        -------
        WrightTools.artists.Axes object
        """
        if "projection" not in kwargs:
            kwargs["projection"] = "wright"
        return super().add_subplot(*args, **kwargs)
class GridSpec(matplotlib.gridspec.GridSpec):
    """GridSpec: a plain subclass of ``matplotlib.gridspec.GridSpec`` adding no behavior."""
# --- artist helpers ------------------------------------------------------------------------------
def apply_rcparams(kind="fast"):
    """Quickly apply rcparams for given purposes.

    Parameters
    ----------
    kind: {'default', 'fast', 'publication'} (optional)
        Settings to use. Default is 'fast'.

        * 'default' restores matplotlib's default rcParams.
        * 'fast' uses mathtext (``text.usetex`` off) with a sans-serif font family.
        * 'publication' renders text with LaTeX (``text.usetex`` on); the preamble
          loads the ``sfmath`` and ``amssymb`` LaTeX packages.

        Any other value leaves rcParams unchanged.
    """
    if kind == "default":
        matplotlib.rcdefaults()
    elif kind == "fast":
        matplotlib.rcParams["text.usetex"] = False
        matplotlib.rcParams["mathtext.fontset"] = "cm"
        matplotlib.rcParams["font.family"] = "sans-serif"
        matplotlib.rcParams["font.size"] = 14
        matplotlib.rcParams["legend.edgecolor"] = "grey"
        matplotlib.rcParams["contour.negative_linestyle"] = "solid"
    elif kind == "publication":
        matplotlib.rcParams["text.usetex"] = True
        preamble = "\\usepackage[cm]{sfmath}\\usepackage{amssymb}"
        matplotlib.rcParams["text.latex.preamble"] = preamble
        matplotlib.rcParams["mathtext.fontset"] = "cm"
        matplotlib.rcParams["font.family"] = "sans-serif"
        matplotlib.rcParams["font.serif"] = "cm"
        matplotlib.rcParams["font.sans-serif"] = "cm"
        matplotlib.rcParams["font.size"] = 14
        matplotlib.rcParams["legend.edgecolor"] = "grey"
        matplotlib.rcParams["contour.negative_linestyle"] = "solid"
def _order_for_imshow(xi, yi):
    """Determine how to transpose zi so that imshow receives zi[rows=y, cols=x].

    **Requires orthogonal, 1D axes**: ``xi`` and ``yi`` are 2D-shaped arrays that
    each vary along a single dimension.

    Returns
    -------
    tuple of int
        Transpose order to apply to zi: ``(0, 1)`` when x varies along the last
        dimension, ``(1, 0)`` when x varies along the first.

    Raises
    ------
    TypeError
        If an axis varies along more than one dimension, or the axes are not
        orthogonal.
    """
    sx = np.array(xi.shape)
    sy = np.array(yi.shape)
    # an axis is effectively 1D when its largest dimension holds every element,
    # i.e. every other dimension has size 1
    if not (sx.max() == xi.size and sy.max() == yi.size):
        raise TypeError(f"Axes are not 1D: {xi.shape}, {yi.shape}")
    # check that axes are orthogonal and orient z accordingly
    if sx[0] == 1 and sy[1] == 1:
        return (0, 1)  # zi[y, x] already matches imshow's zi[rows, cols]
    if sx[1] == 1 and sy[0] == 1:
        return (1, 0)  # zi[x, y]; imshow expects zi[rows, cols]
    raise TypeError(f"x and y must be orthogonal; shapes are: {xi.shape}, {yi.shape}")


def _parse_cmap(data=None, channel_index=None, **kwargs):
    """Fill in ``kwargs["cmap"]`` with a WrightTools colormap.

    A user-supplied ``cmap`` is kept; if it is a string it is looked up in
    ``colormaps``. Otherwise signed channels of ``data`` get ``colormaps["signed"]``
    and everything else gets ``colormaps["default"]``.

    Returns
    -------
    dict
        The updated kwargs.
    """
    if "cmap" in kwargs:
        if isinstance(kwargs["cmap"], str):
            kwargs["cmap"] = colormaps[kwargs["cmap"]]
        return kwargs
    if data and data.channels[channel_index].signed:
        kwargs["cmap"] = colormaps["signed"]
    else:
        kwargs["cmap"] = colormaps["default"]
    return kwargs