Source code for WrightTools.diagrams.WMEL

"""WMEL diagrams."""


# --- import --------------------------------------------------------------------------------------


import numpy as np
import matplotlib.pyplot as plt


# --- define --------------------------------------------------------------------------------------


# --- subplot -------------------------------------------------------------------------------------


[docs]class Subplot: """Subplot containing WMEL."""
[docs] def __init__( self, ax, energies, number_of_interactions=4, title="", title_font_size=16, state_names=None, virtual=[None], state_font_size=14, state_text_buffer=0.5, label_side="left", ): """Subplot. Parameters ---------- ax : matplotlib axis The axis. energies : 1D array-like Energies (scaled between 0 and 1) number_of_interactions : integer Number of interactions in diagram. title : string (optional) Title of subplot. Default is empty string. state_names: list of str (optional) list of the names of the states virtual: list of ints (optional) list of indexes of any vitual energy states state_font_size: numtype (optional) font size for the state lables state_text_buffer: numtype (optional) space between the energy level bars and the state labels """ self.ax = ax self.energies = energies self.interactions = number_of_interactions self.state_names = state_names # Plot Energy Levels for i in range(len(self.energies)): if i in virtual: linestyle = "--" else: linestyle = "-" self.ax.axhline(self.energies[i], color="k", linewidth=2, ls=linestyle, zorder=5) # add state names if isinstance(state_names, list): for i in range(len(self.energies)): if label_side == "left": ax.text( -state_text_buffer, energies[i], state_names[i], fontsize=state_font_size, verticalalignment="center", horizontalalignment="center", ) elif label_side == "right": ax.text( 1 + state_text_buffer, energies[i], state_names[i], fontsize=state_font_size, verticalalignment="center", horizontalalignment="center", ) # calculate interaction_positons self.x_pos = np.linspace(0, 1, number_of_interactions) # set limits self.ax.set_xlim(-0.1, 1.1) self.ax.set_ylim(-0.01, 1.01) # remove guff self.ax.axis("off") # title self.ax.set_title(title, fontsize=title_font_size)
[docs] def add_arrow( self, index, between, kind, label="", head_length=10, head_aspect=1, font_size=14, color="k", ): """Add an arrow to the WMEL diagram. Parameters ---------- index : integer The interaction, or start and stop interaction for the arrow. between : 2-element iterable of integers The inital and final state of the arrow kind : {'ket', 'bra', 'outbra', 'outket'} The kind of interaction. label : string (optional) Interaction label. Default is empty string. head_length: number (optional) size of arrow head font_size : number (optional) Label font size. Default is 14. color : matplotlib color (optional) Arrow color. Default is black. Returns ------- [line,arrow_head,text] """ if hasattr(index, "index"): x_pos = list(index) else: x_pos = [index] * 2 x_pos = [np.linspace(0, 1, self.interactions)[i] for i in x_pos] y_pos = [self.energies[between[0]], self.energies[between[1]]] # calculate arrow length arrow_length = self.energies[between[1]] - self.energies[between[0]] arrow_end = self.energies[between[1]] if arrow_length > 0: direction = 1 elif arrow_length < 0: direction = -1 else: raise ValueError("between invalid!") length = abs(y_pos[0] - y_pos[1]) if kind == "ket": line = self.ax.plot(x_pos, y_pos, linestyle="-", color=color, linewidth=2, zorder=9) elif kind == "bra": line = self.ax.plot(x_pos, y_pos, linestyle="--", color=color, linewidth=2, zorder=9) elif kind == "out": yi = np.linspace(y_pos[0], y_pos[1], 100) xi = ( np.sin((yi - y_pos[0]) * int((1 / length) * 20) * 2 * np.pi * length) / 40 + x_pos[0] ) line = self.ax.plot( xi[:-5], yi[:-5], linestyle="-", color=color, linewidth=2, solid_capstyle="butt", zorder=9, ) # the yi[:-5] simply cuts off 5 line points to not show past the arrowhead elif kind == "outbra": yi = np.linspace(y_pos[0], y_pos[1], 100) xi = ( np.sin((yi - y_pos[0]) * int((1 / length) * 20) * 2 * np.pi * length) / 40 + x_pos[0] ) counter = 0 while counter - 13 <= len(yi): subyi = yi[counter : counter + 15] subxi = xi[counter : counter + 15] line = self.ax.plot( subxi[:-5], subyi[:-5], linestyle="-", color=color, linewidth=2, solid_capstyle="butt", zorder=9, ) counter += 13 # increment must be equal to the while condition else: raise ValueError("kind is not 'ket', 'bra', 'out' or 'outbra'.") # add arrow head dx = x_pos[1] - x_pos[0] dy = y_pos[1] - y_pos[0] xytext = (x_pos[1] - dx * 1e-2, y_pos[1] - dy * 1e-2) annotation = self.ax.annotate( "", xy=(x_pos[1], y_pos[1]), xytext=xytext, arrowprops=dict( fc=color, ec=color, shrink=0, headwidth=head_length * head_aspect, headlength=head_length, linewidth=0, zorder=10, ), size=25, ) # add text text = self.ax.text( np.mean(x_pos), -0.15, label, fontsize=font_size, horizontalalignment="center" ) return line, annotation.arrow_patch, text
# --- artist --------------------------------------------------------------------------------------
[docs]class Artist: """Dedicated WMEL figure artist."""
[docs] def __init__( self, size, energies, state_names=None, number_of_interactions=4, virtual=[None], state_font_size=8, state_text_buffer=0.5, ): """Initialize. Parameters ---------- size : [columns, rows] Layout. energies : list of numbers State energies. state_names : list of strings (optional) State names. Default is None. number_of_interactions : integer (optional) Number of interactions. Default is 4. virtual : list of integers (optional) Indices of states which are virtual. Default is [None]. state_font_size : number (optional) State font size. Default is 8. state_text_buffer : number (optional) Size of buffer around state text. Default is 0.5. """ self.cascades = {} # stores grid address of cascades (if any) for referencing # create figure figsize = [int(size[0] * ((number_of_interactions + 1.0) / 6)), size[1] * 2.5] fig, (subplots) = plt.subplots(size[1], size[0], figsize=figsize) self.fig = fig # wrap subplots if need be if size == [1, 1]: self.subplots = np.array([[subplots]]) plt.subplots_adjust(left=0.3) elif size[1] == 1: self.subplots = np.array([subplots]) else: self.subplots = subplots # add energy levels self.energies = energies for plot in self.subplots.flatten(): for i in range(len(self.energies)): if i in virtual: linestyle = "--" else: linestyle = "-" plot.axhline(energies[i], color="k", linewidth=2, linestyle=linestyle) # add state names to leftmost plots if state_names: for i in range(size[1]): plot = self.subplots[i][0] for j in range(len(self.energies)): plot.text( -state_text_buffer, energies[j], state_names[j], fontsize=state_font_size, verticalalignment="center", horizontalalignment="center", ) # calculate interaction_positons self.x_pos = np.linspace(0, 1, number_of_interactions) # plot cleans up a bunch - call it now as well as later self.plot()
[docs] def label_rows(self, labels, font_size=15, text_buffer=1.5): """Label rows. Parameters ---------- labels : list of strings Labels. font_size : number (optional) Font size. Default is 15. text_buffer : number Buffer around text. Default is 1.5. """ for i in range(len(self.subplots)): plot = self.subplots[i][-1] plot.text( text_buffer, 0.5, labels[i], fontsize=font_size, verticalalignment="center", horizontalalignment="center", )
[docs] def label_columns(self, labels, font_size=15, text_buffer=1.15): """Label columns. Parameters ---------- labels : list of strings Labels. font_size : number (optional) Font size. Default is 15. text_buffer : number Buffer around text. Default is 1.5. """ for i in range(len(labels)): plot = self.subplots[0][i] plot.text( 0.5, text_buffer, labels[i], fontsize=font_size, verticalalignment="center", horizontalalignment="center", )
[docs] def clear_diagram(self, diagram): """Clear diagram. Parameters ---------- diagram : [column, row] Diagram to clear. """ plot = self.subplots[diagram[1]][diagram[0]] plot.cla()
[docs] def add_arrow( self, diagram, number, between, kind, label="", head_length=0.075, font_size=7, color="k" ): """Add arrow. Parameters ---------- diagram : [column, row] Diagram position. number : integer Arrow position. between : [start, stop] Arrow span. kind : {'ket', 'bra', 'out'} Arrow style. label : string (optional) Arrow label. Default is ''. head_length : number (optional) Arrow head length. Default 0.075. font_size : number (optional) Font size. Default is 7. color : matplotlib color Arrow color. Default is 'k'. Returns ------- list [line, arrow_head, text] """ column, row = diagram x_pos = self.x_pos[number] # calculate arrow length arrow_length = self.energies[between[1]] - self.energies[between[0]] arrow_end = self.energies[between[1]] if arrow_length > 0: direction = 1 y_poss = [self.energies[between[0]], self.energies[between[1]] - head_length] elif arrow_length < 0: direction = -1 y_poss = [self.energies[between[0]], self.energies[between[1]] + head_length] else: raise ValueError("Variable between invalid") subplot = self.subplots[row][column] # add line length = abs(y_poss[0] - y_poss[1]) if kind == "ket": line = subplot.plot([x_pos, x_pos], y_poss, linestyle="-", color=color, linewidth=2) elif kind == "bra": line = subplot.plot([x_pos, x_pos], y_poss, linestyle="--", color=color, linewidth=2) elif kind == "out": yi = np.linspace(y_poss[0], y_poss[1], 100) xi = ( np.sin((yi - y_poss[0]) * int((1 / length) * 20) * 2 * np.pi * length) / 40 + x_pos ) line = subplot.plot( xi, yi, linestyle="-", color=color, linewidth=2, solid_capstyle="butt" ) else: raise ValueError("unexpected value for 'kind': {kind}, expected ('ket', 'bra', 'out')") # add arrow head arrow_head = subplot.arrow( self.x_pos[number], arrow_end - head_length * direction, 0, 0.0001 * direction, head_width=head_length * 2, head_length=head_length, fc=color, ec=color, linestyle="solid", linewidth=0, ) # add text text = subplot.text( self.x_pos[number], -0.1, label, fontsize=font_size, horizontalalignment="center" ) return line, arrow_head, text
[docs] def plot(self, save_path=None, close=False, bbox_inches="tight", pad_inches=1): """Plot figure. Parameters ---------- save_path : string (optional) Save path. Default is None. close : boolean (optional) Toggle automatic figure closure after plotting. Default is False. bbox_inches : number (optional) Bounding box size, in inches. Default is 'tight'. pad_inches : number (optional) Pad inches. Default is 1. """ # final manipulations for plot in self.subplots.flatten(): # set limits plot.set_xlim(-0.1, 1.1) plot.set_ylim(-0.1, 1.1) # remove guff plot.axis("off") # save if save_path: plt.savefig( save_path, transparent=True, dpi=300, bbox_inches=bbox_inches, pad_inches=pad_inches, ) # close if close: plt.close()
[docs] def add_cascade( self, diagram, number, number_of_interactions=4, titles=[None], title_font_size=16, state_names=None, virtual=[None], state_font_size=14, state_text_buffer=0.5, label_side="left", bbox_adjust=[0.2, 0.9, 0.2, 0.78], transfer_arrow_linewidth=1.5, transfer_arrow_offset=[0.5, 5], ): """Add cascading process as plot. Parameters --------- diagram : [row,column] row and column positions to place the cascade in the artist grid. number : int number of cascading processes to add in series number_of_interactions : int (optional) number of interactions for each sub process. Equal for all subprocesses. Default is 4. titles : list of str (optional) list of title for each process. Default is [None] title_font_size : int (optional) set the title font size. Default is 16 state_names : list of str (optional) list of state names. Default is None virtal : list of int (optional) list of indices of virtual states. Default is [None] state_font_size : int (optional) set the font size for the labels. Default is 14 state_text_buffer : float (optional) set the state text buffer. Default is 0.5 label_side : str (optional) set the side on which the state label will appear. Either 'left' or 'right'. Default is 'left' bbox_adjust : list of floats (optional) adjusts the left, right, bottom, and top bounds of the cascade diagram. Default is bbox_adjust=[0.2, 0.9, 0.2, 0.78] transfer_arrow_linewidth = float (ooptional) adjusts the thickness of the arrow between cascading WMEL diagrams transfer_arrow_offset = [x_offset, y_offset] (optional) adjusts the label of the energy transfer arrow along the x and y axis """ x, y = diagram[0], diagram[1] self.clear_diagram([x, y]) # clear prev. added single process in grid self.cascades[f"[{x},{y}]"] = [] # store list of cascades in grid for iteration gridspec = ( self.subplots[x, y].get_subplotspec().get_gridspec() ) # get grid spec to confine cascade sfig = self.fig.add_subfigure(gridspec[x, y]) # create subfigure bbox = self.subplots[x - 1][y].get_position() # get grid box position for cascade spec = { "width_ratios": [1 + 0 ** (i % 2) for i in range(3 * number - 3)] } # define width ratios for subplots casc_subplot = sfig.subplots( 1, 3 * number - 3, gridspec_kw=spec, subplot_kw={"position": bbox} ) # make subplots # empirical adjustment to subfigure bbox to align levels with other row plots sfig.subplots_adjust( left=bbox_adjust[0], bottom=bbox_adjust[2], right=bbox_adjust[1], top=bbox_adjust[3] ) # plot the arrows between cascading processes as a sine arrow for arrowindex, arrowplot in enumerate(casc_subplot.flatten()[:-1]): if arrowindex % 2 == 1: xout = np.linspace(0, 4, 100) yout = [np.sin(x * 2 * np.pi) for x in xout] arrowplot.plot(xout, yout, color="k", linewidth=transfer_arrow_linewidth) arrowplot.set_ylim(-20, 20) # sets the visible amplitude of the sine wave arr_lbl_x, arr_lbl_y = transfer_arrow_offset[0], transfer_arrow_offset[1] arrowplot.annotate( r"$\mathrm{\hbar}$$\mathrm{\omega}$", [xout[len(xout) // 4] - arr_lbl_x, yout[len(yout) // 2] + arr_lbl_y], ) # adds hv label arrowplot.arrow( xout[-2], yout[-2], 0.0001, 0, head_width=2, head_length=1, fc="k", ec="k" ) # add arrowhead # plot the cascade's processes' diagrams for ind, plot in enumerate(casc_subplot.flatten()): plot.axis("off") if ind % 2 == 0: if titles != [None]: subtitles = [] for t in titles: subtitles.append(t) subtitles.append("") elif titles == [None]: subtitles = [""] * (2 * number_of_interactions) if ind == 0 and label_side == "left": casc_wmel = Subplot( plot, energies=self.energies, number_of_interactions=number_of_interactions, title=subtitles[ind], title_font_size=title_font_size, state_names=state_names, virtual=virtual, state_font_size=state_font_size, state_text_buffer=state_text_buffer, label_side=label_side, ) # Uses the Subplot class to make the each process diagram elif ind == len(subtitles) - 2 and label_side == "right": casc_wmel = Subplot( plot, energies=self.energies, number_of_interactions=number_of_interactions, title=subtitles[ind], title_font_size=title_font_size, state_names=state_names, virtual=virtual, state_font_size=state_font_size, state_text_buffer=state_text_buffer, label_side=label_side, ) # Uses the Subplot class to make the each process diagram else: casc_wmel = Subplot( plot, energies=self.energies, number_of_interactions=number_of_interactions, title=subtitles[ind], title_font_size=title_font_size, virtual=virtual, ) # Uses the Subplot class to make the each process diagram self.cascades[f"[{x},{y}]"].append(casc_wmel)
[docs] def add_cascade_arrow( self, diagram, cascade_number, index, between, kind, label="", head_length=10, head_aspect=1, font_size=14, color="k", ): """Add arrow to cascade subprocess. Parameters ---------- diagram : [row,column] row and column indices location of the cascade in the figure. cascade_number : int index of subprocess to add the arrow index : int interaction index in the subprocess to which the arrow will be added between : [start,stop] start and stop energy level indeces from which the arrow will point to kind : {'bra', 'ket', 'out'} the kind of interaction to add label : str (optional) interaction label. Default is "" head_length : int (optional) interaction arrow head length. Default is 10 head_aspect : int (optional) set arrow head aspect ratio. Default is 1 font_size : int (optional) set label font size. Default is 14 color : str (optional) set arrow color. Default is black """ x, y = diagram[0], diagram[1] self.cascades[f"[{x},{y}]"][cascade_number].add_arrow( index, between, kind, label=label, head_length=head_length, head_aspect=head_aspect, font_size=font_size, color=color, )