Source code for WrightTools.diagrams.WMEL

"""WMEL diagrams."""


# --- import --------------------------------------------------------------------------------------


import numpy as np
import matplotlib.pyplot as plt


# --- define --------------------------------------------------------------------------------------


# --- subplot -------------------------------------------------------------------------------------


class Subplot:
    """Subplot containing a single WMEL diagram."""

    def __init__(
        self,
        ax,
        energies,
        number_of_interactions=4,
        title="",
        title_font_size=16,
        state_names=None,
        virtual=None,
        state_font_size=14,
        state_text_buffer=0.5,
        label_side="left",
    ):
        """Draw energy levels (and optional state labels) onto ``ax``.

        Parameters
        ----------
        ax : matplotlib axis
            The axis to draw on.
        energies : 1D array-like
            Energies (scaled between 0 and 1).
        number_of_interactions : integer (optional)
            Number of interactions in diagram. Default is 4.
        title : string (optional)
            Title of subplot. Default is empty string.
        title_font_size : number (optional)
            Font size of the title. Default is 16.
        state_names : list, tuple or array of str (optional)
            Names of the states, one per energy level. Default is None
            (no labels).
        virtual : list of ints (optional)
            Indices of any virtual energy states; these are drawn dashed.
            Default is None (no virtual states).
        state_font_size : number (optional)
            Font size for the state labels. Default is 14.
        state_text_buffer : number (optional)
            Horizontal space (data units) between the energy level bars and
            the state labels. Default is 0.5.
        label_side : {'left', 'right'} (optional)
            Side on which state labels are drawn. Default is 'left'. Any other
            value draws no labels.
        """
        # None replaces the old shared mutable default ``[None]``; both mean
        # "no virtual states".
        if virtual is None:
            virtual = ()
        self.ax = ax
        self.energies = energies
        self.interactions = number_of_interactions
        self.state_names = state_names
        # Energy levels: solid for real states, dashed for virtual ones.
        for i, energy in enumerate(self.energies):
            linestyle = "--" if i in virtual else "-"
            self.ax.axhline(energy, color="k", linewidth=2, ls=linestyle, zorder=5)
        # State labels, centred at a fixed x offset on the requested side of the
        # diagram (which spans x = 0..1).
        label_x = {"left": -state_text_buffer, "right": 1 + state_text_buffer}.get(label_side)
        if isinstance(state_names, (list, tuple, np.ndarray)) and label_x is not None:
            for i, energy in enumerate(self.energies):
                ax.text(
                    label_x,
                    energy,
                    state_names[i],
                    fontsize=state_font_size,
                    verticalalignment="center",
                    horizontalalignment="center",
                )
        # Evenly spaced x positions (0..1) at which interactions are drawn.
        self.x_pos = np.linspace(0, 1, number_of_interactions)
        self.ax.set_xlim(-0.1, 1.1)
        self.ax.set_ylim(-0.01, 1.01)
        # Hide axis lines, ticks and labels.
        self.ax.axis("off")
        self.ax.set_title(title, fontsize=title_font_size)

    def add_arrow(
        self,
        index,
        between,
        kind,
        label="",
        head_length=10,
        head_aspect=1,
        font_size=14,
        color="k",
    ):
        """Add an arrow to the WMEL diagram.

        Parameters
        ----------
        index : integer or 2-element iterable of integers
            The interaction (vertical arrow), or start and stop interaction
            (slanted arrow). Lists, tuples and numpy arrays are accepted.
        between : 2-element iterable of integers
            The initial and final state of the arrow.
        kind : {'ket', 'bra', 'out'}
            The kind of interaction: solid, dashed, or wavy line respectively.
        label : string (optional)
            Interaction label. Default is empty string.
        head_length : number (optional)
            Size of arrow head. Default is 10.
        head_aspect : number (optional)
            Ratio of arrow head width to head length. Default is 1.
        font_size : number (optional)
            Label font size. Default is 14.
        color : matplotlib color (optional)
            Arrow color. Default is black.

        Returns
        -------
        [line, arrow_head, text]

        Raises
        ------
        ValueError
            If both states have the same energy, or ``kind`` is unknown.
        """
        # A scalar index means a vertical arrow at a single interaction.
        # np.ndim treats ints as scalars and lists/tuples/arrays as sequences.
        if np.ndim(index) == 0:
            indices = [index, index]
        else:
            indices = list(index)
        positions = np.linspace(0, 1, self.interactions)
        x_pos = [positions[i] for i in indices]
        y_pos = [self.energies[between[0]], self.energies[between[1]]]
        delta = y_pos[1] - y_pos[0]
        # Zero (or NaN) energy difference gives no arrow direction.
        if not abs(delta) > 0:
            raise ValueError("between invalid!")
        length = abs(delta)
        if kind == "ket":
            line = self.ax.plot(x_pos, y_pos, linestyle="-", color=color, linewidth=2, zorder=9)
        elif kind == "bra":
            line = self.ax.plot(x_pos, y_pos, linestyle="--", color=color, linewidth=2, zorder=9)
        elif kind == "out":
            yi = np.linspace(y_pos[0], y_pos[1], 100)
            # Cycle count scales with 1/length, so the wave's wavelength stays
            # roughly constant regardless of arrow length.
            cycles = int((1 / length) * 20)
            xi = np.sin((yi - y_pos[0]) * cycles * 2 * np.pi * length) / 40 + x_pos[0]
            line = self.ax.plot(
                xi, yi, linestyle="-", color=color, linewidth=2, solid_capstyle="butt", zorder=9
            )
        else:
            raise ValueError("kind is not 'ket', 'out', or 'bra'.")
        # Arrow head: a zero-text annotation whose tail sits 1% of the way back
        # along the arrow, so only the head is visible.
        dx = x_pos[1] - x_pos[0]
        dy = y_pos[1] - y_pos[0]
        xytext = (x_pos[1] - dx * 1e-2, y_pos[1] - dy * 1e-2)
        annotation = self.ax.annotate(
            "",
            xy=(x_pos[1], y_pos[1]),
            xytext=xytext,
            arrowprops=dict(
                fc=color,
                ec=color,
                shrink=0,
                headwidth=head_length * head_aspect,
                headlength=head_length,
                linewidth=0,
                zorder=10,
            ),
            size=25,
        )
        # Label below the diagram, centred on the arrow.
        text = self.ax.text(
            np.mean(x_pos), -0.15, label, fontsize=font_size, horizontalalignment="center"
        )
        return line, annotation.arrow_patch, text
# --- artist --------------------------------------------------------------------------------------
class Artist:
    """Dedicated WMEL figure artist: a grid of WMEL diagrams in one figure."""

    def __init__(
        self,
        size,
        energies,
        state_names=None,
        number_of_interactions=4,
        virtual=None,
        state_font_size=8,
        state_text_buffer=0.5,
    ):
        """Create the figure and draw energy levels on every diagram.

        Parameters
        ----------
        size : [columns, rows]
            Layout: number of diagram columns, then number of rows.
        energies : list of numbers
            State energies (presumably scaled to 0..1; the visible y range is
            -0.1 to 1.1).
        state_names : list of strings (optional)
            State names, drawn beside the leftmost diagram of each row.
            Default is None.
        number_of_interactions : integer (optional)
            Number of interactions. Default is 4.
        virtual : list of integers (optional)
            Indices of states which are virtual (drawn dashed). Default is None
            (no virtual states).
        state_font_size : number (optional)
            State font size. Default is 8.
        state_text_buffer : number (optional)
            Distance (data units) to the left of x = 0 at which state names are
            centred. Default is 0.5.
        """
        # None replaces the old shared mutable default ``[None]``; both mean
        # "no virtual states".
        if virtual is None:
            virtual = ()
        columns, rows = size[0], size[1]
        # Width scales with columns and interactions; int() truncation must never
        # produce a zero-width figure.
        width = max(1, int(columns * ((number_of_interactions + 1.0) / 6.0)))
        figsize = [width, rows * 2.5]
        # squeeze=False guarantees a 2D (rows, columns) array of axes for every
        # layout, including a single column of several rows.
        fig, subplots = plt.subplots(rows, columns, figsize=figsize, squeeze=False)
        self.fig = fig
        self.subplots = subplots
        if list(size) == [1, 1]:
            # Presumably leaves room for the state names of a lone diagram.
            fig.subplots_adjust(left=0.3)
        # add energy levels: solid for real states, dashed for virtual ones
        self.energies = energies
        for plot in self.subplots.flat:
            for i, energy in enumerate(energies):
                linestyle = "--" if i in virtual else "-"
                plot.axhline(energy, color="k", linewidth=2, linestyle=linestyle)
        # add state names to the leftmost diagram of each row
        if state_names is not None and len(state_names) > 0:
            for row_axes in self.subplots:
                for j, energy in enumerate(energies):
                    row_axes[0].text(
                        -state_text_buffer,
                        energy,
                        state_names[j],
                        fontsize=state_font_size,
                        verticalalignment="center",
                        horizontalalignment="center",
                    )
        # Evenly spaced x positions (0..1) at which interactions are drawn.
        self.x_pos = np.linspace(0, 1, number_of_interactions)
        # plot() sets limits and hides axes - call it now as well as later
        self.plot()

    def label_rows(self, labels, font_size=15, text_buffer=1.5):
        """Label rows, writing each label to the right of the row's last diagram.

        Parameters
        ----------
        labels : list of strings
            Labels, one per row.
        font_size : number (optional)
            Font size. Default is 15.
        text_buffer : number (optional)
            x position (data units) of the label. Default is 1.5.
        """
        for i, row_axes in enumerate(self.subplots):
            row_axes[-1].text(
                text_buffer,
                0.5,
                labels[i],
                fontsize=font_size,
                verticalalignment="center",
                horizontalalignment="center",
            )

    def label_columns(self, labels, font_size=15, text_buffer=1.15):
        """Label columns, writing each label above the column's top diagram.

        Parameters
        ----------
        labels : list of strings
            Labels, one per column.
        font_size : number (optional)
            Font size. Default is 15.
        text_buffer : number (optional)
            y position (data units) of the label. Default is 1.15.
        """
        for i, label in enumerate(labels):
            self.subplots[0][i].text(
                0.5,
                text_buffer,
                label,
                fontsize=font_size,
                verticalalignment="center",
                horizontalalignment="center",
            )

    def clear_diagram(self, diagram):
        """Clear diagram.

        This calls ``cla()`` on the axes, removing everything drawn on it,
        including the energy levels.

        Parameters
        ----------
        diagram : [column, row]
            Diagram to clear.
        """
        column, row = diagram[0], diagram[1]
        self.subplots[row][column].cla()

    def add_arrow(
        self, diagram, number, between, kind, label="", head_length=0.075, font_size=7, color="k"
    ):
        """Add arrow.

        Parameters
        ----------
        diagram : [column, row]
            Diagram position.
        number : integer
            Arrow position (interaction index).
        between : [start, stop]
            Indices of the initial and final states.
        kind : {'ket', 'bra', 'out'}
            Arrow style: solid, dashed, or wavy line respectively.
        label : string (optional)
            Arrow label. Default is ''.
        head_length : number (optional)
            Arrow head length. Default 0.075.
        font_size : number (optional)
            Font size. Default is 7.
        color : matplotlib color
            Arrow color. Default is 'k'.

        Returns
        -------
        list
            [line, arrow_head, text]

        Raises
        ------
        ValueError
            If both states have the same energy, or ``kind`` is unknown.
        """
        column, row = diagram
        x_pos = self.x_pos[number]
        start = self.energies[between[0]]
        arrow_end = self.energies[between[1]]
        arrow_length = arrow_end - start
        if arrow_length > 0:
            direction = 1
        elif arrow_length < 0:
            direction = -1
        else:
            raise ValueError("Variable between invalid")
        # The shaft stops one head-length short of the target; the head covers the rest.
        y_poss = [start, arrow_end - head_length * direction]
        subplot = self.subplots[row][column]
        length = abs(y_poss[0] - y_poss[1])
        if kind == "ket":
            line = subplot.plot([x_pos, x_pos], y_poss, linestyle="-", color=color, linewidth=2)
        elif kind == "bra":
            line = subplot.plot([x_pos, x_pos], y_poss, linestyle="--", color=color, linewidth=2)
        elif kind == "out":
            yi = np.linspace(y_poss[0], y_poss[1], 100)
            # Cycle count scales with 1/length, so the wave's wavelength stays
            # roughly constant regardless of arrow length.
            xi = (
                np.sin((yi - y_poss[0]) * int((1 / length) * 20) * 2 * np.pi * length) / 40
                + x_pos
            )
            line = subplot.plot(
                xi, yi, linestyle="-", color=color, linewidth=2, solid_capstyle="butt"
            )
        else:
            # Previously fell through and failed with UnboundLocalError after drawing.
            raise ValueError("kind is not 'ket', 'out', or 'bra'.")
        # add arrow head: a near-zero-length arrow whose head ends at the target state
        arrow_head = subplot.arrow(
            self.x_pos[number],
            arrow_end - head_length * direction,
            0,
            0.0001 * direction,
            head_width=head_length * 2,
            head_length=head_length,
            fc=color,
            ec=color,
            linestyle="solid",
            linewidth=0,
        )
        # add text below the diagram
        text = subplot.text(
            self.x_pos[number], -0.1, label, fontsize=font_size, horizontalalignment="center"
        )
        return line, arrow_head, text

    def plot(self, save_path=None, close=False, bbox_inches="tight", pad_inches=1):
        """Finalize the figure and optionally save and/or close it.

        Parameters
        ----------
        save_path : string (optional)
            Save path. Default is None (do not save).
        close : boolean (optional)
            Toggle automatic figure closure after plotting. Default is False.
        bbox_inches : string or matplotlib Bbox (optional)
            Passed to ``savefig``. Default is 'tight'.
        pad_inches : number (optional)
            Pad inches. Default is 1.
        """
        # final manipulations
        for plot in self.subplots.flat:
            plot.set_xlim(-0.1, 1.1)
            plot.set_ylim(-0.1, 1.1)
            plot.axis("off")
        # Act on this artist's own figure, not whichever figure is current.
        if save_path:
            self.fig.savefig(
                save_path,
                transparent=True,
                dpi=300,
                bbox_inches=bbox_inches,
                pad_inches=pad_inches,
            )
        if close:
            plt.close(self.fig)