# Source code for omfit_classes.utils_plot

# NOTE(review): the opening `try:`, the closing `):` of the condition and the final `else: raise`
# were missing from this copy of the file and have been restored.
try:
    # framework is running
    from .startup_choice import *
except ImportError as _excp:
    # class is imported by itself
    if (
        'attempted relative import with no known parent package' in str(_excp)
        or 'No module named \'omfit_classes\'' in str(_excp)
        or "No module named '__main__.startup_choice'" in str(_excp)
    ):
        from startup_choice import *
    else:
        # any other import failure is a real error and must not be hidden
        raise

if framework:
    print('Loading plot utility functions...')

import collections
from signal import pause
from omfit_classes.utils_base import _available_to_user_plot, compare_version
from omfit_classes.utils_math import is_uncertain, interp1e, unsorted_unique, RectBivariateSplineNaN, point_to_line
from omfit_classes.sortedDict import SortedDict

# explicit imports
import numpy as np
from xarray import Dataset, DataArray
import uncertainties
import inspect
from scipy import interpolate, integrate
import matplotlib
from matplotlib import pyplot, gridspec, cm
from matplotlib.widgets import RectangleSelector, RadioButtons, AxesWidget
from mpl_toolkits.mplot3d import Axes3D
import platform
import re

default_colorblind_line_cycle = [
default_matplotlib_line_cycle = eval(re.sub(r'.*(\[.*\]).*', r'\1', str(matplotlib.rcParams['axes.prop_cycle'])))
default_colorblind_line_cycle = matplotlib.cycler('color', default_colorblind_line_cycle)
default_matplotlib_cmap_cycle = 'viridis' + '/extras/styles')
# Colormap definition for "Standard Gamma-II" colormap from IDL
# fmt: off
_r = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               4, 9, 14, 19, 23, 28, 33, 38, 42, 47, 52, 57, 61, 66, 71, 76,
               81, 81, 81, 81, 81, 81, 81, 81, 80, 80, 80, 80, 80, 80, 80, 79,
               84, 89, 94, 99, 104, 109, 114, 119, 124, 129, 134, 139, 144, 149, 154, 159,
               164, 169, 174, 180, 185, 190, 196, 201, 206, 212, 217, 222, 228, 233, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 248, 240, 232, 225, 217, 209, 202, 194, 186, 179, 171, 163, 168,
               173, 178, 183, 188, 193, 198, 203, 209, 214, 219, 224, 229, 234, 239, 244, 249,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255])
_g = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 5, 10, 16, 21, 27, 32, 37, 43, 48, 54, 59, 64, 70, 75,
               81, 85, 90, 95, 100, 105, 109, 114, 119, 124, 129, 134, 138, 143, 148, 153,
               158, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163,
               163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163,
               163, 169, 175, 181, 187, 193, 199, 205, 212, 218, 224, 230, 236, 242, 248, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255])
_b = np.array([0, 5, 10, 15, 20, 26, 31, 36, 41, 46, 52, 57, 62, 67, 72, 78,
               83, 88, 93, 98, 104, 109, 114, 119, 124, 130, 135, 140, 145, 150, 156, 161,
               166, 171, 176, 182, 187, 192, 197, 202, 208, 213, 218, 223, 228, 234, 239, 244,
               249, 255, 250, 245, 239, 234, 228, 223, 218, 212, 207, 201, 196, 190, 185, 180,
               174, 169, 163, 158, 152, 147, 142, 136, 131, 125, 120, 114, 109, 104, 98, 93,
               87, 82, 76, 71, 66, 60, 55, 49, 44, 38, 33, 28, 22, 17, 11, 6,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 4, 9, 14, 19, 24, 28, 33, 38, 43, 48, 53, 57, 62, 67, 72,
               77, 82, 77, 71, 65, 59, 53, 47, 41, 36, 30, 24, 18, 12, 6, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 3, 6, 9, 12, 16, 19, 22, 25, 29, 32, 35, 38, 41, 45, 48,
               51, 54, 58, 61, 64, 67, 71, 74, 77, 80, 83, 87, 90, 93, 96, 100,
               103, 106, 109, 112, 116, 119, 122, 125, 129, 132, 135, 138, 142, 145, 148, 151,
               154, 158, 161, 164, 167, 171, 174, 177, 180, 183, 187, 190, 193, 196, 200, 203,
               206, 209, 213, 216, 219, 222, 225, 229, 232, 235, 238, 242, 245, 248, 251, 255])
# fmt: on
# Stack the red/green/blue tables column-wise and rescale from 0-255 integers to 0-1 floats
_rgb = np.column_stack((_r, _g, _b)) / 255.0
_gamma2_cmap = matplotlib.colors.ListedColormap(_rgb, name='Standard Gamma-II')
if hasattr(matplotlib, 'colormaps'):
    # Newer matplotlib: cm.register_cmap is deprecated and later removed, use the colormap registry.
    # The membership test keeps a module reload from failing on an already registered name.
    if 'Standard Gamma-II' not in matplotlib.colormaps:
        matplotlib.colormaps.register(_gamma2_cmap)
else:
    cm.register_cmap(cmap=_gamma2_cmap)
matplotlib.rcParams['image.cmap'] = default_matplotlib_cmap_cycle

# Mouse-button indices used for right and middle clicks.
# The two values are swapped on macOS ('Darwin') with respect to every other platform.
if platform.system() == 'Darwin':
    rightClickMPLindex = 2
    middleClickMPLindex = 3
else:
    rightClickMPLindex = 3
    middleClickMPLindex = 2

@_available_to_user_plot
def autofmt_sharexy(trim_xlabel=True, trim_ylabel=True, fig=None):
    """
    Apply autofmt_sharex followed by autofmt_sharey to the same figure

    :param trim_xlabel: bool. Forwarded to autofmt_sharey

    :param trim_ylabel: bool. Forwarded to autofmt_sharex

    :param fig: Figure. Defaults to current figure.
    """
    target = pyplot.gcf() if fig is None else fig
    autofmt_sharex(trim_ylabel=trim_ylabel, fig=target)
    autofmt_sharey(trim_xlabel=trim_xlabel, fig=target)
@_available_to_user_plot
def autofmt_sharey(trim_xlabel=True, fig=None, wspace=0):
    """
    Prunes y-tick labels and y-axis labels from all but the first cols axes and moves cols (optionally) closer together.

    :param trim_xlabel: bool. prune the upper xtick label of every axes to prevent overlap.

    :param fig: Figure. Defaults to current figure.

    :param wspace: Horizontal spacing between axes.
    """
    if fig is None:
        fig = pyplot.gcf()

    for a in fig.axes:
        # Newer matplotlib exposes the grid position on the SubplotSpec (Axes.is_first_col was removed);
        # older versions only have the Axes method. Axes that are not part of a grid are left alone.
        spec = a.get_subplotspec() if hasattr(a, 'get_subplotspec') else None
        if spec is not None and hasattr(spec, 'is_first_col'):
            first_col = spec.is_first_col()
        elif hasattr(a, 'is_first_col'):
            first_col = a.is_first_col()
        else:
            first_col = True

        if not first_col:
            for yt in a.get_yticklabels():
                yt.set_visible(False)
            a.set_ylabel('')

        if len(a.get_xticklabels()) and trim_xlabel:
            try:
                a.xaxis.major.locator.set_params(prune='upper')
            except Exception:
                # not every locator accepts `prune`; trimming is best-effort
                pass

    fig.subplots_adjust(wspace=wspace)
@_available_to_user_plot
def autofmt_sharex(trim_ylabel=True, fig=None, hspace=0):
    """
    Prunes x-tick labels and x-axis labels from all but the last row axes and moves rows (optionally) closer together.

    :param trim_ylabel: bool. prune top ytick label to prevent overlap.

    :param fig: Figure. Defaults to current figure.

    :param hspace: Vertical spacing between axes.
    """
    if fig is None:
        fig = pyplot.gcf()

    for a in fig.axes:
        # Newer matplotlib exposes the grid position on the SubplotSpec (Axes.is_last_row was removed);
        # older versions only have the Axes method. Axes that are not part of a grid are left alone.
        spec = a.get_subplotspec() if hasattr(a, 'get_subplotspec') else None
        if spec is not None and hasattr(spec, 'is_last_row'):
            last_row = spec.is_last_row()
        elif hasattr(a, 'is_last_row'):
            last_row = a.is_last_row()
        else:
            last_row = True

        if not last_row:
            for xt in a.get_xticklabels():
                xt.set_visible(False)
            a.set_xlabel('')
            # the axis offset/exponent text would otherwise collide with the axes below
            a.xaxis.get_offset_text().set_visible(False)

        if len(a.get_yticklabels()) and trim_ylabel:
            try:
                a.yaxis.major.locator.set_params(prune='upper')
            except Exception:
                # not every locator accepts `prune`; trimming is best-effort
                pass

    fig.subplots_adjust(hspace=hspace)
@_available_to_user_plot
def uerrorbar(x, y, ax=None, **kwargs):
    r"""
    Plot x,y data carrying uncertainties with matplotlib's errorbar: nominal values
    give the points, standard deviations give the x and y error bars.

    If y or x is more than 1D, it is flattened along every dimension but the last.

    :param x: array of independent axis values (may carry uncertainties)

    :param y: array of dependent values (may carry uncertainties)

    :param ax: The axes instance into which to plot (default: pyplot.gca())

    :param \**kwargs: Passed to ax.errorbar

    :return: list. A list of ErrorbarContainer objects containing the line, bars, and caps of each (x,y) along the last dimension.
    """
    if ax is None:
        ax = pyplot.gca()

    # defaults: markers only, no connecting line unless the caller asked for one
    kwargs.setdefault('marker', 'o')
    kwargs.pop('axes', None)  # `ax` is the supported way to select the axes
    if not ('linestyle' in kwargs or 'ls' in kwargs):
        kwargs['linestyle'] = ''
    if np.all(std_devs(y) == 0) and np.all(std_devs(x) == 0):
        # no uncertainties at all: do not draw caps
        kwargs.setdefault('capsize', 0)

    # bring both inputs to 2D (curves, points) so 1D and 2D x/y can be mixed
    ys = np.array(y)
    ys = ys.reshape(-1, ys.shape[-1])
    xs = np.array(x)
    xs = xs.reshape(-1, xs.shape[-1])
    if xs.shape[0] == 1 and ys.shape[0] > 1:
        # a single x shared by every y curve
        xs = np.tile(xs[0, :], ys.shape[0]).reshape(-1, xs.shape[-1])

    return [
        ax.errorbar(nominal_values(xi), nominal_values(yi), xerr=std_devs(xi), yerr=std_devs(yi), **kwargs)
        for xi, yi in zip(xs, ys)
    ]
class Uband(object):
    """
    This class wraps the line and PolyCollection(s) associated with a banded
    errorbar plot for use in the uband function.

    Attribute access that is not resolved on the Uband itself is delegated to the
    wrapped line; a few styling setters are applied to the bands as well.
    """

    # setters that must act on the bands as well as on the line
    _shared_setters = ('set_color', 'set_lw', 'set_linewidth', 'set_dashes', 'set_linestyle')

    def __init__(self, line, bands):
        """
        :param line: Line2D
            A line of the x,y nominal values

        :param bands: list of PolyCollections
            The fill_between and/or fill_betweenx PolyCollections spanning the std_devs of the x,y data
        """
        from matplotlib.cbook import flatten

        self.line = line  # matplotlib.lines.Line2D
        self.bands = list(flatten([bands]))  # matplotlib.collections.PolyCollection(s)

    def __getattr__(self, attr):
        # __getattr__ is only reached when normal lookup fails. If `line`/`bands` themselves are
        # missing (bare instance, copy, unpickle) delegating to self.line would recurse forever.
        if attr in ('line', 'bands'):
            raise AttributeError(attr)

        if attr in self._shared_setters:

            def _band_line_method(*args, **kw):
                """
                Call the same method for line and band.
                Returns Line2D method call result.
                """
                for band in self.bands:
                    getattr(band, attr)(*args, **kw)
                return getattr(self.line, attr)(*args, **kw)

            return _band_line_method

        return getattr(self.line, attr)
@_available_to_user_plot
def uband(x, y, ax=None, fill_kwargs=None, **kwargs):
    r"""
    Given arguments x,y where either or both have uncertainties, plot x,y using pyplot.plot
    of the nominal values and surround it with a shaded error band using matplotlib's
    fill_between and/or fill_betweenx.

    If y or x is more than 1D, it is flattened along every dimension but the last.

    :param x: array of independent axis values

    :param y: array of values with uncertainties, for which shaded error band is plotted

    :param ax: The axes instance into which to plot (default: pyplot.gca())

    :param fill_kwargs: dict. Passed to pyplot.fill_between (the dictionary is not modified)

    :param \**kwargs: Passed to pyplot.plot

    :return: list. A list of Uband objects containing the line and bands of each (x,y) along the last dimension.
    """
    result = []
    if ax is None:
        ax = pyplot.gca()

    # work on a private copy so the defaults set here never leak into the caller's dictionary
    fill_kwargs = dict(fill_kwargs) if fill_kwargs is not None else {}
    fill_kwargs.setdefault('alpha', 0.25)

    # enable combinations of 1D and 2D x's and y's
    y = np.array(y)
    y = y.reshape(-1, y.shape[-1])
    x = np.array(x)
    x = x.reshape(-1, x.shape[-1])
    if x.shape[0] == 1 and y.shape[0] > 1:  # one x for all y's
        x = np.tile(x[0, :], y.shape[0]).reshape(-1, x.shape[-1])

    # plot each (x,y) and collect the lines/bands into a single object
    for xi, yi in zip(x, y):
        xnom = np.atleast_1d(np.squeeze(uncertainties.unumpy.nominal_values(xi)))
        xerr = np.atleast_1d(np.squeeze(uncertainties.unumpy.std_devs(xi)))
        ynom = np.atleast_1d(np.squeeze(uncertainties.unumpy.nominal_values(yi)))
        yerr = np.atleast_1d(np.squeeze(uncertainties.unumpy.std_devs(yi)))

        (l,) = ax.plot(xnom, ynom, **kwargs)

        # per-curve copy: the band color defaults to the color of its own line
        fkwargs = dict(fill_kwargs)
        fkwargs.setdefault('color', l.get_color())

        bands = []
        if np.any(yerr != 0):
            bands.append(ax.fill_between(xnom, ynom - yerr, ynom + yerr, **fkwargs))
        if np.any(xerr != 0):
            bands.append(ax.fill_betweenx(ynom, xnom - xerr, xnom + xerr, **fkwargs))

        result.append(Uband(l, bands))

    return result
@_available_to_user_plot
def hardcopy(fn, bbox_inches='tight', fig=None, **keyw):
    r"""
    Save a figure to file with `text.usetex` enabled and ghostscript as the PS distiller.
    The two rcParams are restored afterwards. A failure to save is reported but not raised.

    :param fn: file name, passed to fig.savefig

    :param bbox_inches: passed to fig.savefig

    :param fig: Figure. Defaults to current figure.

    :param \**keyw: extra keywords passed to fig.savefig

    :return: fn
    """
    import glob

    if fig is None:
        fig = pyplot.gcf()

    saved = {key: matplotlib.rcParams[key] for key in ('ps.usedistiller', 'text.usetex')}

    # Empty matplotlib's TeX cache. Done with os.remove instead of a shell `rm` so that home
    # directories containing spaces, or an unset HOME variable, do not break the call.
    cache_dir = os.path.join(os.path.expanduser('~'), '.config', 'matplotlib', 'tex.cache')
    for cached in glob.glob(os.path.join(cache_dir, '*')):
        try:
            os.remove(cached)
        except OSError:
            # best-effort cleanup (e.g. sub-directories or files removed meanwhile)
            pass

    try:
        matplotlib.rcParams['ps.usedistiller'] = 'ghostscript'
        matplotlib.rcParams['text.usetex'] = True
        fig.savefig(fn, bbox_inches=bbox_inches, **keyw)
        printi('Saved to %s' % (fn))
    except Exception as _excp:
        printe('Error in saving file: %r' % (_excp,))
    finally:
        # always give the user's settings back, whatever happened above
        for key, value in saved.items():
            matplotlib.rcParams[key] = value
    return fn
@_available_to_user_plot
def set_fontsize(fig=None, fontsize='+0'):
    """
    For each text object of a figure fig, set the font size to fontsize

    :param fig: matplotlib.figure object. For convenience the font size can be passed as
        first argument, in which case the current figure is used

    :param fontsize: can be an absolute number (e.g 10) or a relative number given as a
        string starting with a sign ('-2' or '+2')

    :return: None
    """

    def match(artist):
        return artist.__module__ == "matplotlib.text"

    # support the set_fontsize(10) / set_fontsize('+2') calling convention
    if is_int(fig) or isinstance(fig, (float, str)):
        fontsize = fig
        fig = pyplot.gcf()
    if fig is None:
        fig = pyplot.gcf()

    # Only strings with a leading sign are relative: numeric sizes cannot be searched for '+'/'-',
    # and named sizes such as 'x-large' contain a dash without being an increment.
    relative = isinstance(fontsize, str) and fontsize.strip()[:1] in ('+', '-')
    delta = float(fontsize) if relative else None

    for textobj in fig.findobj(match=match):
        if relative:
            textobj.set_fontsize(textobj.get_fontsize() + delta)
        else:
            textobj.set_fontsize(fontsize)
def user_lines_cmap_cycle():
    """
    Colormap that the user selected for drawing lines.

    Falls back to default_matplotlib_cmap_cycle when the setting is missing or cannot be read.
    """
    aliases = {'blind': 'viridis'}
    try:
        chosen = OMFIT['MainSettings']['SETUP']['PlotAppearance']['lines.cmap'][0]
        return aliases.get(chosen, chosen)
    except KeyError:
        # setting not defined: use the default silently
        return default_matplotlib_cmap_cycle
    except Exception as _excp:
        printe(_excp)
        return default_matplotlib_cmap_cycle
def user_image_cmap_cycle():
    """
    Colormap that the user selected for drawing images.

    Falls back to default_matplotlib_cmap_cycle when the setting is missing or cannot be read.
    """
    try:
        appearance = OMFIT['MainSettings']['SETUP']['PlotAppearance']
        return appearance['image.cmap']
    except KeyError:
        # setting not defined: use the default silently
        return default_matplotlib_cmap_cycle
    except Exception as _excp:
        printe(_excp)
        return default_matplotlib_cmap_cycle
@_available_to_user_plot
def color_cycle(n=10, k=None, cmap_name=None):
    """
    Utility function to conveniently return the color of an index in a colormap cycle

    :param n: number of uniformly spaced colors, or array defining the colors' spacings

    :param k: index of the color (if None an array of colors of length n will be returned)

    :param cmap_name: name of the colormap (default: the user's colormap for lines)

    :return: color of index k from colormap cmap_name made of n colors, or array of colors of length n if k is None
        Note: if n is an array, then the associated ScalarMappable object is also returned (e.g. for use in a colorbar)
    """
    if cmap_name is None:
        cmap_name = user_lines_cmap_cycle()
    # pyplot.get_cmap resolves built-in as well as registered colormaps (e.g. 'Standard Gamma-II'),
    # which an attribute lookup on matplotlib.cm cannot do
    cmap = pyplot.get_cmap(cmap_name)

    if np.iterable(n):
        values = np.asarray(n)
        nm = matplotlib.colors.Normalize(np.nanmin(values), np.nanmax(values))
        # NOTE(review): the ScalarMappable construction was garbled in this copy of the file
        # (`sm =, norm=nm)`); restored as a mappable over the selected colormap and `nm`.
        sm = cm.ScalarMappable(cmap=cmap, norm=nm)
        sm.set_array(values)
        colors = sm.cmap(sm.norm(values))
        if k is not None:
            return colors[k], sm
        return colors, sm

    # uniformly spaced colors: index i of n maps to i / (n - 1), guarding n == 1
    span = max([1.0, float(n - 1)])
    if k is not None:
        return cmap(k / span)
    return [cmap(i / span) for i in range(n)]
@_available_to_user_plot
def cycle_cmap(length=50, cmap=None, start=None, stop=None, ax=None):
    """
    Set default color cycle of matplotlib based on colormap

    Note that the default color cycle is not changed if ax parameter is set; only the axes's color cycle will be changed

    :param length: The number of colors in the cycle

    :param cmap: Name of a matplotlib colormap (default: the user's colormap for lines)

    :param start: Limit colormap to this range (0 <= start <= 1)

    :param stop: Limit colormap to this range (0 <= stop <= 1)

    :param ax: If ax is not None, then change the axes's color cycle instead of the default color cycle

    :return: color_cycle (list of hexadecimal color strings)
    """
    if cmap is None:
        cmap = user_lines_cmap_cycle()
    # NOTE(review): the lookup was garbled in this copy (`getattr(, cmap)`); pyplot.get_cmap also
    # resolves colormaps that were registered at runtime.
    cmap = pyplot.get_cmap(cmap)

    crange = [0, 1]
    if start is not None:
        crange[0] = start
    if stop is not None:
        crange[1] = stop

    assert 0 <= crange[0] <= 1
    assert 0 <= crange[1] <= 1

    colors = [RGB_to_HEX(*rgb[:3]) for rgb in cmap(np.linspace(crange[0], crange[1], num=length))]
    if ax is None:
        from cycler import cycler

        pyplot.rc('axes', prop_cycle=cycler('color', colors))
    else:
        # a bare list is rejected by set_prop_cycle: the property name must be given explicitly
        ax.set_prop_cycle('color', colors)
    return colors
@_available_to_user_plot
def contrasting_color(line_or_color):
    """
    Pick a color that contrasts well with the given one.

    The input can be a matplotlib color specification, a line2D instance, or a list with a
    line2D instance as the first element. Plain inversion is not used because, for example,
    inverting blue gives yellow, which doesn't display well on a white background.

    :param line_or_color: matplotlib color spec, line2D instance, or list w/ line2D instance as the first element

    :return: 4 element array
        RGBA color specification for a contrasting color
    """
    line_type = matplotlib.lines.Line2D
    if isinstance(line_or_color, line_type):
        spec = line_or_color.get_color()
    elif np.iterable(line_or_color) and isinstance(line_or_color[0], line_type):
        spec = line_or_color[0].get_color()
    else:
        spec = line_or_color

    # separate the opacity, then work in HSV space
    rgba = np.array(matplotlib.colors.to_rgba(spec))
    alpha = rgba[3]
    hue, saturation, value = matplotlib.colors.rgb_to_hsv(rgba[0:3])

    # keep black and white from mapping onto themselves and keep the result fairly bright:
    # saturation is confined to [0.5, 0.9] and value to [0.75, 0.95]
    saturation = min([max([saturation, 0.5]), 0.9])
    value = min([max([value, 0.75]), 0.95])

    # rotate the hue by roughly a third of the color wheel (1/3 seems to look better than 1/2)
    hue = (hue + 0.333) % 1

    rgb = matplotlib.colors.hsv_to_rgb(np.array([hue, saturation, value]))
    return np.append(rgb, alpha)
@_available_to_user_plot
def associated_color(line_or_color):
    """
    Pick a color that looks thematically linked to the given one, yet distinguishable from it.

    The input can be a matplotlib color specification, a line2D instance, or a list with a
    line2D instance as the first element.

    :param line_or_color: matplotlib color spec, line2D instance, or list w/ line2D instance as the first element

    :return: 4 element array
        RGBA color specification for a related, similar (but distinguishable) color
    """
    line_type = matplotlib.lines.Line2D
    if isinstance(line_or_color, line_type):
        spec = line_or_color.get_color()
    elif np.iterable(line_or_color) and isinstance(line_or_color[0], line_type):
        spec = line_or_color[0].get_color()
    else:
        spec = line_or_color

    # separate the opacity, then work in HSV space
    rgba = np.array(matplotlib.colors.to_rgba(spec))
    alpha = rgba[3]
    hsv = matplotlib.colors.rgb_to_hsv(rgba[0:3])

    # saturation: boost weakly saturated colors, tone down the others
    hsv[1] = hsv[1] * 2 if hsv[1] <= 0.4 else hsv[1] * 0.75

    # brightness: lift dark colors, darken bright ones
    if hsv[2] <= 0.25:
        hsv[2] = 0.5 + hsv[2]
    elif hsv[2] <= 0.4:
        hsv[2] = hsv[2] * 1.8
    else:
        hsv[2] = hsv[2] * 0.45

    # small hue shift
    hsv[0] = (hsv[0] - 0.1) % 1

    return np.append(matplotlib.colors.hsv_to_rgb(hsv), alpha)
@_available_to_user_plot
def blur_image(im, n, ny=None):
    """
    Blur an image by convolving it with a normalized 2D gaussian kernel.

    :param im: image data

    :param n: typical size of the kernel (the kernel spans -n..n along the first dimension)

    :param ny: optional different size along the second dimension (defaults to n)

    :return: blurred image, same shape as the input
    """
    from scipy import signal

    half_x = int(n)
    half_y = int(ny) if ny else half_x

    # normalized gaussian kernel sampled on an integer grid
    grid_x, grid_y = np.mgrid[-half_x : half_x + 1, -half_y : half_y + 1]
    kernel = np.exp(-(grid_x**2 / float(half_x) + grid_y**2 / float(half_y)))
    kernel = kernel / kernel.sum()

    return signal.convolve(im, kernel, mode='same')
@_available_to_user_plot
def pcolor2(*args, fast=False, **kwargs):
    r"""
    Plots 2D data as a patch collection.
    Differently from matplotlib.pyplot.pcolor the mesh is extended by one element so that the number of tiles equals
    the number of data points in the Z matrix. The X,Y grid does not have to be rectangular.

    :param \*args: Z or X,Y,Z data to be plotted

    :param fast: bool
        Use pcolorfast instead of pcolor. Speed improvements may be dramatic.
        However, pcolorfast is marked as experimental and may produce unexpected behavior.

    :param \**kwargs: these arguments are passed to matplotlib.pyplot.pcolor
        (`ax` selects the axes; `interpolation` is discarded)

    :return: the object returned by ax.pcolor or ax.pcolorfast
    """
    import copy

    ax = kwargs.pop('ax', None)
    if ax is None:
        ax = pyplot.gca()

    cmap = kwargs.pop('cmap', None)
    if cmap is None:
        # default colormap with invalid data shown in white; work on a copy so that the
        # globally registered colormap is never modified
        cmap = copy.copy(pyplot.get_cmap())
        cmap.set_bad(color='w', alpha=1.0)
    kwargs['cmap'] = cmap

    # keywords that the pcolor family does not understand
    naughty_list = ['interpolation']
    kw2 = {k: v for k, v in kwargs.items() if k not in naughty_list}

    # NOTE(review): the masking calls were garbled in this copy of the file (`ax.pcolor([0]), **kw2)`,
    # `masked =[2])`); restored as numpy masked arrays of the Z data with invalid entries masked.
    if len(args) == 1:
        masked = np.ma.masked_invalid(args[0])
        if fast:
            return ax.pcolorfast(masked, **kw2)
        return ax.pcolor(masked, **kw2)

    if len(args[0].shape) == 1:
        xy1d = True
        xdata, ydata = np.meshgrid(args[0], args[1])
    else:
        xy1d = False
        xdata = args[0]
        ydata = args[1]
    masked = np.ma.masked_invalid(args[2])

    if not fast:
        xdata, ydata = meshgrid_expand(xdata, ydata)
        return ax.pcolor(xdata, ydata, masked, **kw2)

    if not xy1d:
        xdata, ydata = meshgrid_expand(xdata, ydata)
        return ax.pcolorfast(xdata, ydata, masked, **kw2)

    # pcolorfast doesn't print the z coordinate if x,y are 2D, so reduce to 1D when possible
    xdata = xdata[0, :]
    ydata = ydata[:, 0]
    # turn cell centers into cell edges: midpoints inside, one extrapolated edge at each end
    inner_edges = (xdata[:-1] + xdata[1:]) / 2.0
    xdata = np.append(np.append(inner_edges[0] - xdata[1] + xdata[0], inner_edges), inner_edges[-1] + xdata[-1] - xdata[-2])
    inner_edges = (ydata[:-1] + ydata[1:]) / 2.0
    ydata = np.append(np.append(inner_edges[0] - ydata[1] + ydata[0], inner_edges), inner_edges[-1] + ydata[-1] - ydata[-2])
    # pcolorfast says it's even faster if given a tuple of (min, max) for x and y, for even spacing
    if (np.max(abs(np.diff(np.diff(xdata)))) == 0) and (np.max(abs(np.diff(np.diff(ydata)))) == 0):
        xdata = (np.min(xdata), np.max(xdata))
        ydata = (np.min(ydata), np.max(ydata))
    return ax.pcolorfast(xdata, ydata, masked, **kw2)
@_available_to_user_plot
def image(*args, **kwargs):
    r"""
    Plots 2D data as an image.

    Much faster than pcolor/pcolor2(fast=False), but the data have to be on a rectangular X,Y grid

    :param \*args: Z or X,Y,Z data to be plotted (X and/or Y may be None, in which case indexes are used)

    :param \**kwargs: these arguments are passed to pcolor2, which is called with fast=True by default
        (`ax` selects the axes; `zlim` sets the (min, max) of the color normalization)

    :return: the object returned by pcolor2
    """
    import copy
    from omfit_classes.omfit_dmp import mpl_dump, mpl_dump_enable

    if len(args) == 2 or not len(args):
        raise TypeError('image() expects either Z or X,Y,Z data')

    # the plot calls made in here must not be recorded one by one: a single entry is dumped at the end
    mpl_dump_enable[0] = False
    try:
        # todo: add clipping
        ax = kwargs.pop('ax', None)
        if ax is None:
            ax = pyplot.gca()

        cmap = kwargs.pop('cmap', None)
        if cmap is None:
            # copy, so that the globally registered colormap is never modified
            cmap = copy.copy(pyplot.get_cmap())
            cmap.set_bad(color='w', alpha=1.0)
        kwargs['cmap'] = cmap

        zlim = kwargs.pop('zlim', None)

        if len(args) == 1:
            xdata = ydata = None
            zdata = args[0]
        else:
            xdata, ydata, zdata = args[0], args[1], args[2]

        # NOTE(review): this statement was garbled in this copy of the file (`zdata =`);
        # an array conversion is assumed since zdata is sliced below -- confirm against upstream.
        zdata = np.asanyarray(zdata)
        if xdata is None:
            xdata = np.linspace(0, zdata.shape[1] - 1, zdata.shape[1])
        if ydata is None:
            ydata = np.linspace(0, zdata.shape[0] - 1, zdata.shape[0])

        # accounts for data where x,ymax is at min
        if xdata[0] > xdata[-1]:
            xdata = xdata[::-1]
            zdata = zdata[:, ::-1]
        if ydata[0] > ydata[-1]:
            ydata = ydata[::-1]
            zdata = zdata[::-1, :]

        extent = [np.nanmin(xdata), np.nanmax(xdata), np.nanmin(ydata), np.nanmax(ydata)]

        # draw and immediately remove a diagonal across the extent, as the original code did.
        # Artist.remove() is used because the axes line list can no longer be edited with `del`.
        (helper_line,) = ax.plot([extent[0], extent[1]], [extent[2], extent[3]])
        helper_line.remove()

        kwargs.setdefault('fast', True)
        obj = pcolor2(xdata, ydata, zdata, ax=ax, **kwargs)

        if zlim is not None:
            norm = matplotlib.colors.Normalize(vmin=zlim[0], vmax=zlim[1])
            obj.set_norm(norm)

        ax.set_xlim(extent[0], extent[1])
        ax.set_ylim(extent[2], extent[3])
    finally:
        mpl_dump_enable[0] = True

    dumpDict = {'args': [ax], 'func': 'image'}
    args = list(args)
    # for DMP purposes reduce color scale to 16bits
    # NOTE(review): the dtype tests were garbled in this copy (`args[0]'float')`); restored as a
    # "floating point data" check -- confirm against upstream.
    z_index = 0 if len(args) == 1 else 2
    if len(args) in (1, 3) and np.issubdtype(np.asarray(args[z_index]).dtype, np.floating):
        args[z_index] = np.asarray(args[z_index]).astype(np.float16)
    dumpDict['args'].extend(args)
    kwargs.pop('ax', None)
    dumpDict['kwargs'] = kwargs
    mpl_dump(dumpDict)

    return obj
@_available_to_user_plot
def meshgrid_expand(xdata, ydata):
    """
    Return the vertices of a mesh whose cell centers are given by xdata and ydata.

    xdata and ydata are 2D matrices, which could for example be generated by np.meshgrid.
    The returned matrices have one more row and one more column than the inputs.

    :param xdata: center of the mesh

    :param ydata: center of the mesh

    :return: tuple with x and y coordinates of the mesh vertices
    """
    gx_rows, gx_cols = np.gradient(xdata)
    gy_rows, gy_cols = np.gradient(ydata)

    # x: extrapolate one extra row, then one extra column
    x_ext = np.vstack((xdata, xdata[-1, :] + gx_rows[-1, :]))
    gx_cols_ext = np.vstack((gx_cols, gx_cols[-1, :]))
    x_ext = np.hstack((x_ext, x_ext[:, -1:] + gx_cols_ext[:, -1:]))

    # y: extrapolate one extra column, then one extra row
    y_ext = np.hstack((ydata, ydata[:, -1:] + gy_cols[:, -1:]))
    gy_rows_ext = np.hstack((gy_rows, gy_rows[:, -1:]))
    y_ext = np.vstack((y_ext, y_ext[-1, :] + gy_rows_ext[-1, :]))

    # shift every point back by half of the local spacing (sum of the two gradients)
    x_rows, x_cols = np.gradient(x_ext)
    y_rows, y_cols = np.gradient(y_ext)
    x_vertices = x_ext - (x_rows + x_cols) / 2.0
    y_vertices = y_ext - (y_rows + y_cols) / 2.0

    return x_vertices, y_vertices
@_available_to_user_plot
def map_HBS_to_RGB(H, B, S=1.0, cmap=None):
    """
    map to a RGB colormap separate HUE, BRIGHTNESS and SATURATIONS arrays

    :param H: HUE data (any shape array); each value selects the colormap entry whose
        position in [0, 1] is closest to it

    :param B: BRIGHTNESS data (any shape array)

    :param S: SATURATION data (scalar or any shape array)

    :param cmap: matplotlib.colormap to be used (default: the user's colormap for images)

    :return: RGB array (shape of input array with one more dimension of size 3 (RGB) )
    """
    import colorsys

    if cmap is None:
        cmap = user_image_cmap_cycle()
    if isinstance(cmap, str):
        # pyplot.get_cmap keeps working on matplotlib versions where cm.get_cmap is gone
        cmap = pyplot.get_cmap(cmap, 256)

    x = np.linspace(0, 1, cmap.N)

    # flatten once, outside of the loop (flattening at every iteration made the loop quadratic)
    h_flat = H.flatten()
    b_flat = B.flatten()
    scalar_saturation = is_float(S * 1.0)
    s_flat = None if scalar_saturation else S.flatten()

    w = np.zeros((H.size, 3))
    for n in range(H.size):
        # the colormap only provides the hue: lightness and saturation come from B and S
        t = cmap(np.nanargmin(abs(x - h_flat[n])))
        hue = colorsys.rgb_to_hls(t[0], t[1], t[2])[0]
        w[n] = colorsys.hls_to_rgb(hue, b_flat[n], S if scalar_saturation else s_flat[n])

    return w.reshape(list(H.shape) + [3])
@_available_to_user_plot
def RGB_to_HEX(R, G, B):
    """
    Convert color from numerical RGB to hexadecimal representation

    :param R: integer 0<x<255 or float 0.0<x<1.0

    :param G: integer 0<x<255 or float 0.0<x<1.0

    :param B: integer 0<x<255 or float 0.0<x<1.0

    :return: hexadecimal representation of the color
    """

    def _to_byte(channel):
        # Integers (numpy integers included) are already on the 0-255 scale; anything else is a
        # 0-1 fraction. Without the numpy check np.int64(200) would be scaled to 51000.
        if isinstance(channel, (int, np.integer)):
            return int(channel)
        return int(round(255 * channel))

    return '#%02x%02x%02x' % (_to_byte(R), _to_byte(G), _to_byte(B))
@_available_to_user_plot
def plotc(*args, **kwargs):
    r"""
    Plot the various curves defined by the arguments [X],Y,[Z]
    where X is the x value, Y is the y value, and Z is the color.
    If one argument is given it is interpreted as Y; if two, then X, Y; if three
    then X, Y, Z. If all three are given, the curves are drawn as line collections
    colored by Z and the labels are discarded.
    If Z is omitted, the curves are colored along a colormap, from its first color for the
    first curve to its last color for the last curve. A different color map can be given
    with the cmap keyword.
    If X is omitted, then the (ith) index of Y is used as the x value.

    :param \*args: [X],Y,[Z] followed by any positional argument for the plotting call

    :param \**kwargs: `ax` selects the axes; `vmin`, `vmax` and `norm` (a [min, max] pair or a
        Normalize instance) control the color scale when Z is given; `labels` and `cmap` apply
        when Z is omitted; everything else is passed to the plotting call

    :return: what ax.plot returns for a single 1D curve, otherwise the list of
        LineCollection (Z given) or Line2D (Z omitted) objects
    """
    ax = kwargs.pop('ax', None)
    if ax is None:
        ax = pyplot.gca()

    args = list(args)
    X = args.pop(0)
    Y = None
    Z = None
    if len(args) >= 1 and not isinstance(args[0], str):
        Y = args.pop(0)
    if len(args) >= 1 and not isinstance(args[0], str):
        Z = args.pop(0)

    # from list to matrix, this is done to work nicely with across function of SortedDict
    def list2mat(inv):
        if isinstance(inv, list):
            return np.array(inv).T
        return inv

    X, Y, Z = list2mat(X), list2mat(Y), list2mat(Z)

    if Y is None and Z is None:
        Y = X
        X = np.linspace(0, Y.shape[0] - 1, Y.shape[0])

    if len(X.shape) == len(Y.shape) == 1 and Z is None:
        return ax.plot(X, Y, *args, **kwargs)

    Y = np.reshape(Y, (Y.shape[0], -1))
    if len(X.shape) == 1:
        X = np.tile(X.flatten(), (Y.shape[1], 1)).T

    tmp_prop_cycle = matplotlib.rcParams['axes.prop_cycle']
    # remember the interactive state so that it is restored rather than forced on at the end
    was_interactive = pyplot.isinteractive()
    LC = []
    try:
        nlines = Y.shape[1]
        pyplot.ioff()

        if Z is not None:
            vmax = kwargs.pop('vmax', None)
            vmin = kwargs.pop('vmin', None)
            if len(Z.shape) == 1:
                Z = np.tile(Z.flatten(), (Y.shape[1], 1)).T

            # Resolve the color normalization once. A Normalize instance passed by the caller is
            # used as is; a [min, max] pair (or vmin/vmax) has its None entries filled from Z.
            norm = kwargs.pop('norm', None)
            if not isinstance(norm, matplotlib.colors.Normalize):
                low, high = (vmin, vmax) if norm is None else norm
                if low is None:
                    low = np.nanmin(Z)
                if high is None:
                    high = np.nanmax(Z)
                norm = matplotlib.colors.Normalize(low, high)
            kwargs['norm'] = norm

            # Plot the points normally once, since this sets the axes and other stuff, then remove
            # them. Artist.remove() is used because the axes line list cannot be edited with `del`.
            (helper_line,) = ax.plot(X.flatten(), Y.flatten(), '.')
            helper_line.remove()

            for k in range(nlines):
                # one segment per pair of consecutive points, colored by z
                points = np.array([X[:, k], Y[:, k]]).T.reshape(-1, 1, 2)
                segments = np.concatenate([points[:-1], points[1:]], axis=1)
                lc = matplotlib.collections.LineCollection(segments, *args, **kwargs)
                lc.set_array(Z[:, k])
                ax.add_collection(lc)
                LC.append(lc)

        else:
            labels = kwargs.pop('labels', ['_line' + str(k) for k in range(1, nlines + 1)])
            # get_cmap accepts names (registered ones included) as well as Colormap instances
            cmap = pyplot.get_cmap(kwargs.pop('cmap', default_matplotlib_cmap_cycle))
            for k in range(nlines):
                x = X[:, k]
                valid = ~np.isnan(x)
                kwargs['label'] = labels[k]
                kwargs['color'] = cmap(k / max([1, float(nlines - 1)]))
                ax.plot(x[valid], Y[valid, k], *args, **kwargs)
                LC.append(ax.lines[-1])

    finally:
        matplotlib.rcParams['axes.prop_cycle'] = tmp_prop_cycle
        pyplot.draw()
        if was_interactive:
            pyplot.ion()

    return LC
@_available_to_user_plot
def title_inside(string, x=0.5, y=0.9, ax=None, **kwargs):
    r"""
    Place the title within the axes area instead of above it

    :param string: title string

    :param x: x location of the title string (default 0.5, that is centered)

    :param y: y location of the title string (default 0.9)

    :param ax: axes to operate on (the current axes via pyplot.title if None)

    :param \**kwargs: additional keywords passed to pyplot.title (vertical alignment defaults to 'top')

    :return: pyplot.title text object
    """
    kwargs.setdefault('va', 'top')
    set_title = pyplot.title if ax is None else ax.set_title
    return set_title(string, x=x, y=y, **kwargs)
@_available_to_user_plot
def increase_resolution(*args, **kwargs):
    """
    This function takes 1 (Z) or 3 (X,Y,Z) 2D tables and interpolates them to higher resolution by bivariate spline interpolation.
    If 1 (3) table(s) is(are) provided, then the second(fourth) argument is the resolution change (it can also be
    given as the `resolution` keyword), which can be:

    * a positive integer n: every grid interval is split in two, n times (res=res0*2^n)

    * a negative integer -n: only every n-th point of the tables is kept

    * zero: the tables are returned unchanged

    * a float: the grid size in the units provided by the X and Y tables

    :param quiet: (keyword) bool. Do not print what is being done (default True)

    :return: the new table if only one table was given, otherwise the tuple (table, X, Y)
    """
    oneInput = False
    args = list(args)
    if len(args) in [1, 2]:
        # only the table was given: use indexes as coordinates
        args.insert(1, np.arange(args[0].shape[1]))
        args.insert(2, np.arange(args[0].shape[0]))
        oneInput = True

    Qin = args[0]
    Rin = args[1]
    Zin = args[2]
    # coordinates given as 2D meshes are reduced to their 1D axes
    if len(Rin.shape) == 2:
        Rin = Rin[0, :]
    if len(Zin.shape) == 2:
        Zin = Zin[:, 0]

    if len(args) == 4:
        kwargs['resolution'] = args[3]
    resolution = kwargs['resolution']
    quiet = kwargs.get('quiet', True)

    if resolution == 0:
        Q = Qin
        R = Rin
        Z = Zin

    elif is_int(resolution):
        if resolution > 0:
            if not quiet:
                printi('Increasing tables resolution by factor of ' + str(2**resolution) + ' ...')
            # splitting every interval in two maps n points to 2n-1 points, `resolution` times
            nr = (len(Rin) - 1) * 2**resolution + 1
            nz = (len(Zin) - 1) * 2**resolution + 1
            R = np.linspace(min(Rin), max(Rin), nr)
            Z = np.linspace(min(Zin), max(Zin), nz)
            Q = RectBivariateSplineNaN(Zin, Rin, Qin)(Z, R)
        else:
            if not quiet:
                printi('Decreasing tables resolution by factor of ' + str(abs(resolution)) + ' ...')
            R = Rin[::-resolution]
            Z = Zin[::-resolution]
            Q = Qin[::-resolution, ::-resolution]

    elif is_float(resolution):
        if not quiet:
            printi('Interpolating tables to ' + str(resolution) + ' m resolution ...')
        R = np.linspace(min(Rin), max(Rin), int(np.ceil((max(Rin) - min(Rin)) / resolution)))
        Z = np.linspace(min(Zin), max(Zin), int(np.ceil((max(Zin) - min(Zin)) / resolution)))
        Q = RectBivariateSplineNaN(Zin, Rin, Qin)(Z, R)

    else:
        # previously this fell through to an UnboundLocalError at the return statement
        raise TypeError('resolution must be an integer or a float, not %r' % (resolution,))

    if oneInput:
        return Q
    return Q, R, Z
@_available_to_user_plot
class infoScatter(object):
    r"""
    improved version of:

    Callback for matplotlib to display an annotation when points are clicked on

    :param x: x of the annotations

    :param y: y of the annotations

    :param annotes: list of string annotations

    :param axis: axis on which to operate on (default to current axis)

    :param tol: vicinity in pixels where to look for annotations

    :param func: function to call with signature: func(x,y,annote,visible,axis)

    :param all_on: Make all of the text visible to begin with

    :param suppress_canvas_draw: Do not actively draw the canvas if all_on is True, makes plotting faster when there are many subplots

    :param \**kw: extra keywords passed to matplotlib text class
    """

    # NOTE(review): the attribute names `data` and `text_kwargs` were stripped from this copy of the
    # file (e.g. `= list(zip(x, y, annotes))`, `** m = axis.scatter`); they are reconstructed here
    # and should be confirmed against the upstream source.

    def __init__(self, x, y, annotes, axis=None, tol=5, func=None, all_on=False, suppress_canvas_draw=False, **kw):
        self.data = list(zip(x, y, annotes))
        self.tol = tol
        self.axis = pyplot.gca() if axis is None else axis
        # maps (x, y) to the (text, marker) artists that were drawn for that point
        self.drawnAnnotations = {}
        self.func = func
        self.text_kwargs = kw
        self.text_kwargs.setdefault('clip_on', True)
        if all_on:
            for cur_x, cur_y, cur_annote in zip(x, y, annotes):
                self.drawAnnote(self.axis, cur_x, cur_y, cur_annote, redraw_canvas=False)
            if not suppress_canvas_draw:
                self.axis.figure.canvas.draw()
        self.axis.figure.canvas.mpl_connect('button_press_event', self)

    def __call__(self, event):
        """
        Mouse click callback: toggle the annotation of the point closest to the click,
        provided it lies within `tol` pixels of it
        """
        if event.inaxes:
            clickX = event.x
            clickY = event.y
            if self.axis is None or self.axis == event.inaxes:
                winner = [None, None]  # [distance in pixels, index in self.data]
                for k, (xd, yd, a) in enumerate(self.data):
                    # compare in display (pixel) coordinates
                    x, y = self.axis.transData.transform((xd, yd))
                    d = np.sqrt((x - clickX) ** 2 + (y - clickY) ** 2)
                    if d < self.tol:
                        if winner[0] is None or d < winner[0]:
                            winner = [d, k]
                if winner[0] is not None:
                    xw, yw, annote = self.data[winner[1]]
                    self.drawAnnote(event.inaxes, xw, yw, annote)

    def drawAnnote(self, axis, x, y, annote, redraw_canvas=True):
        """
        Draw the annotation on the plot, or toggle its visibility if it was already drawn

        :param axis: axes where the annotation is drawn

        :param x: x of the annotation

        :param y: y of the annotation

        :param annote: annotation string

        :param redraw_canvas: redraw the figure canvas after the change
        """
        if (x, y) in self.drawnAnnotations:
            markers = self.drawnAnnotations[(x, y)]
            for m in markers:
                m.set_visible(not m.get_visible())
            if redraw_canvas:
                self.axis.figure.canvas.draw()
        else:
            printi(annote)
            t = axis.text(x, y, annote, **self.text_kwargs)
            m = axis.scatter([x], [y], marker='+', c='r', zorder=100)
            self.drawnAnnotations[(x, y)] = (t, m)
            if redraw_canvas:
                self.axis.figure.canvas.draw()
        if self.func is not None:
            # in both branches `m` is the marker artist of this annotation
            self.func(x, y, annote, m.get_visible(), axis)

    def drawSpecificAnnote(self, annote):
        """
        Draw (or toggle) every annotation whose text equals `annote`

        :param annote: annotation string to look for
        """
        annotesToDraw = [(x, y, a) for x, y, a in self.data if a == annote]
        for x, y, a in annotesToDraw:
            self.drawAnnote(self.axis, x, y, a)
@_available_to_user_plot
def infoPoint(fig=None):
    """
    Print the x,y coordinates of every mouse click made on a figure

    :param fig: matplotlib figure to listen on (defaults to the current figure)
    """
    from matplotlib import pyplot

    target = pyplot.gcf() if fig is None else fig

    def _report_click(event):
        print(event.xdata, event.ydata)

    target.canvas.mpl_connect('button_press_event', _report_click)
@_available_to_user_plot
def XKCDify(
    ax,
    mag=1.0,
    f1=50,
    f2=0.01,
    f3=15,
    bgcolor='w',
    xaxis_loc=None,
    yaxis_loc=None,
    xaxis_arrow='+',
    yaxis_arrow='+',
    ax_extend=0.1,
    expand_axes=False,
    ylabel_rot=78,
):
    """
    XKCD plot generator by, Jake Vanderplas; Modified by Sterling Smith

    This is a script that will take any matplotlib line diagram, and convert it
    to an XKCD-style plot.  It will work for plots with line & text elements,
    including axes labels and titles (but not axes tick labels).

    The idea for this comes from work by Damon McDougall

    This adjusts all lines, text, legends, and axes in the figure to look
    like xkcd plots.  Other plot elements are not modified.

    :param ax: Axes instance the axes to be modified.

    :param mag: float the magnitude of the distortion

    :param f1, f2, f3: int, float, int filtering parameters.
        f1 gives the size of the window, f2 gives the high-frequency cutoff,
        f3 gives the size of the filter

    :param xaxis_loc, yaxis_loc: float The locations to draw the x and y axes.
        If not specified, they will be drawn from the bottom left of the plot

    :param xaxis_arrow: str where to draw arrows on the x axes.  Options are '+', '-', '+-', or ''

    :param yaxis_arrow: str where to draw arrows on the y axes.  Options are '+', '-', '+-', or ''

    :param ax_extend: float How far (fractionally) to extend the drawn axes beyond the original axes limits

    :param expand_axes: bool if True, then expand axes to fill the figure
        (useful if there is only a single axes in the figure)

    :param ylabel_rot: float number of degrees to rotate the y axis label

    :return: the modified Axes instance
    """
    import pylab as pl
    import matplotlib.font_manager as fm

    fontfn = os.sep.join([os.path.dirname(__file__), '..', 'extras', 'graphics', 'fonts', 'Humor-Sans.ttf'])

    def xkcd_line(x, y, xlim=None, ylim=None, mag=1.0, f1=30, f2=0.05, f3=15):
        """
        Mimic a hand-drawn line from (x, y) data

        Parameters
        ----------
        x, y : array_like
            arrays to be modified
        xlim, ylim : data range
            the assumed plot range for the modification.  If not specified,
            they will be guessed from the data
        mag : float
            magnitude of distortions
        f1, f2, f3 : int, float, int
            filtering parameters.  f1 gives the size of the window, f2 gives
            the high-frequency cutoff, f3 gives the size of the filter

        Returns
        -------
        x, y : np.ndarrays
            The modified lines
        """
        from scipy import interpolate, signal

        x = np.asarray(x)
        y = np.asarray(y)

        # get limits for rescaling
        if xlim is None:
            xlim = (x.min(), x.max())
        if ylim is None:
            ylim = (y.min(), y.max())

        # a degenerate range would divide by zero below: borrow the other one
        if xlim[1] == xlim[0]:
            xlim = ylim
        if ylim[1] == ylim[0]:
            ylim = xlim

        # scale the data
        x_scaled = (x - xlim[0]) * 1.0 / (xlim[1] - xlim[0])
        y_scaled = (y - ylim[0]) * 1.0 / (ylim[1] - ylim[0])

        # compute the total distance along the path
        dx = x_scaled[1:] - x_scaled[:-1]
        dy = y_scaled[1:] - y_scaled[:-1]
        dist_tot = np.sum(np.sqrt(dx * dx + dy * dy))

        # number of interpolated points is proportional to the distance
        # (at least 2, so that the normalization by Nu - 1 is well defined)
        Nu = max(int(200 * dist_tot), 2)
        u = np.arange(-1, Nu + 1) * 1.0 / (Nu - 1)

        # interpolate curve at sampled points
        k = min(3, len(x) - 1)
        res = interpolate.splprep([x_scaled, y_scaled], s=0, k=k)
        x_int, y_int = interpolate.splev(u, res[0])

        # we'll perturb perpendicular to the drawn line
        dx = x_int[2:] - x_int[:-2]
        dy = y_int[2:] - y_int[:-2]
        dist = np.sqrt(dx * dx + dy * dy)

        # create a filtered perturbation
        coeffs = mag * np.random.normal(0, 0.01, len(x_int) - 2)
        b = signal.firwin(f1, f2 * dist_tot, window=('kaiser', f3))
        response = signal.lfilter(b, 1, coeffs)

        x_int[1:-1] += response * dy / dist
        y_int[1:-1] += response * dx / dist

        # un-scale data
        x_int = x_int[1:-1] * (xlim[1] - xlim[0]) + xlim[0]
        y_int = y_int[1:-1] * (ylim[1] - ylim[0]) + ylim[0]

        return x_int, y_int

    # Get axes aspect
    ext = ax.get_window_extent().extents
    aspect = (ext[3] - ext[1]) / (ext[2] - ext[0])

    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    xspan = xlim[1] - xlim[0]
    # the y span must be measured from ylim[0] (it was measured from xlim[0])
    yspan = ylim[1] - ylim[0]

    xax_lim = (xlim[0] - ax_extend * xspan, xlim[1] + ax_extend * xspan)
    yax_lim = (ylim[0] - ax_extend * yspan, ylim[1] + ax_extend * yspan)

    if xaxis_loc is None:
        xaxis_loc = ylim[0]

    if yaxis_loc is None:
        yaxis_loc = xlim[0]

    # Draw axes
    xaxis = pl.Line2D([xax_lim[0], xax_lim[1]], [xaxis_loc, xaxis_loc], linestyle='-', color='k')
    yaxis = pl.Line2D([yaxis_loc, yaxis_loc], [yax_lim[0], yax_lim[1]], linestyle='-', color='k')

    # Label axes
    ax.text(xax_lim[1], xaxis_loc - 0.02 * yspan, ax.get_xlabel(), fontsize=14, ha='right', va='top', rotation=12)
    ax.text(yaxis_loc - 0.02 * xspan, yax_lim[1], ax.get_ylabel(), fontsize=14, ha='right', va='top', rotation=ylabel_rot)
    ax.set_xlabel('')
    ax.set_ylabel('')

    # Add title
    ax.text(0.5 * (xax_lim[1] + xax_lim[0]), yax_lim[1], ax.get_title(), ha='center', va='bottom', fontsize=16)
    ax.set_title('')

    # Detach the existing lines so they can be re-added on top of their background lines.
    # Artist.remove() is used because Axes.lines cannot be modified in place in recent matplotlib.
    existing_lines = list(ax.lines)
    for line in existing_lines:
        line.remove()
    lines = [xaxis, yaxis] + existing_lines

    for line in lines:
        x, y = line.get_data()

        x_int, y_int = xkcd_line(x, y, xlim, ylim, mag, f1, f2, f3)

        # create foreground and background line
        lw = line.get_linewidth()
        line.set_linewidth(2 * lw)
        line.set_data(x_int, y_int)

        # don't add background line for axes
        if (line is not xaxis) and (line is not yaxis):
            line_bg = pl.Line2D(x_int, y_int, color=bgcolor, linewidth=8 * lw)

            ax.add_line(line_bg)
        ax.add_line(line)

    # Draw arrow-heads at the end of axes lines
    arr1 = 0.03 * np.array([-1, 0, -1])
    arr2 = 0.02 * np.array([-1, 0, 1])

    arr1[::2] += np.random.normal(0, 0.005, 2)
    arr2[::2] += np.random.normal(0, 0.005, 2)

    x, y = xaxis.get_data()
    if '+' in str(xaxis_arrow):
        ax.plot(x[-1] + arr1 * xspan * aspect, y[-1] + arr2 * yspan, color='k', lw=2)
    if '-' in str(xaxis_arrow):
        ax.plot(x[0] - arr1 * xspan * aspect, y[0] - arr2 * yspan, color='k', lw=2)

    x, y = yaxis.get_data()
    if '+' in str(yaxis_arrow):
        ax.plot(x[-1] + arr2 * xspan * aspect, y[-1] + arr1 * yspan, color='k', lw=2)
    if '-' in str(yaxis_arrow):
        ax.plot(x[0] - arr2 * xspan * aspect, y[0] - arr1 * yspan, color='k', lw=2)

    # Change all the fonts to humor-sans.
    prop = fm.FontProperties(fname=fontfn, size=16)
    for text in ax.texts:
        text.set_fontproperties(prop)

    # modify legend
    leg = ax.get_legend()
    if leg is not None:
        leg.set_frame_on(False)

        for child in leg.get_children():
            if isinstance(child, pl.Line2D):
                x, y = child.get_data()
                child.set_data(xkcd_line(x, y, mag=10, f1=100, f2=0.001))
                child.set_linewidth(2 * child.get_linewidth())
            if isinstance(child, pl.Text):
                child.set_fontproperties(prop)

    # Set the axis limits
    ax.set_xlim(xax_lim[0] - 0.1 * xspan, xax_lim[1] + 0.1 * xspan)
    ax.set_ylim(yax_lim[0] - 0.1 * yspan, yax_lim[1] + 0.1 * yspan)

    # adjust the axes
    ax.set_xticks([])
    ax.set_yticks([])

    if expand_axes:
        ax.figure.set_facecolor(bgcolor)
        ax.set_axis_off()
        ax.set_position([0, 0, 1, 1])

    return ax
@_available_to_user_plot
def autoscale_y(ax, margin=0.1):
    """
    Rescales the y-axis based on the data that is visible given the current xlim of the axis.

    Created by eldond at 2017 Mar 23  20:26

    This function was adapted from an answer posted online by DanKickstein.

    Only the lines returned by ax.get_lines() are considered, so shaded bands such as would be used to
    display error bars are not taken into account. Increasing the margin may be a good idea when dealing
    with such plots.

    The y limits are left untouched if no line has at least two points within the current x range.

    :param ax: a matplotlib axes object

    :param margin: The fraction of the total height of the y-data to pad the upper and lower ylims
    """

    printd('autoscale_y()...')

    def get_bottom_top(line2):
        """Return padded (bottom, top) of the part of line2 within the x range, or (inf, -inf) if nothing is shown"""
        # line data can be stored as plain lists (e.g. by axhline): make element-wise comparison possible
        xd = np.asarray(line2.get_xdata())
        yd = np.asarray(line2.get_ydata())
        # sorted so that an inverted x-axis (left > right) selects the same points
        lo, hi = sorted(ax.get_xlim())
        with warnings.catch_warnings():
            # Ignore RuntimeWarnings for these calculations (could be caused by NaNs in xd or yd)
            warnings.filterwarnings("ignore", category=RuntimeWarning)
            if (len(xd) == 2) and (xd[0] == 0.0) and (xd[1] == 1.0):
                # Special case to handle axhline - added by D. Eldon
                y_displayed = yd
            else:
                y_displayed = yd[(xd >= lo) & (xd <= hi)]
            if len(y_displayed) > 1:
                y_min = np.nanmin(y_displayed)
                y_max = np.nanmax(y_displayed)
                h = y_max - y_min
                return y_min - margin * h, y_max + margin * h
            printd('Could not find any elements of y_displayed in range')
            return np.inf, -np.inf

    bot, top = np.inf, -np.inf

    for j, line in enumerate(ax.get_lines()):
        printd('  line', j)
        new_bot, new_top = get_bottom_top(line)
        printd('  new_bot = {:}, new_top = {:}'.format(new_bot, new_top))
        if new_bot < bot:
            bot = new_bot
        if new_top > top:
            top = new_top

    if (bot < top) and (bot > (-np.inf)) and (top < np.inf):
        printd('Set ylim to ', bot, top)
        ax.set_ylim(bot, top)
    else:
        printd('Did not set ylim because one of these was False:')
        printd('  bot > (-np.inf) = {:}, bot = {:}'.format(bot > (-np.inf), bot))
        printd('  top < np.inf = {:}, top = {:}'.format(top < np.inf, top))
        printd('  bot < top = {:}'.format(bot < top))
    printd('Done with autoscale_y().')
    return
@_available_to_user_plot
def set_linearray(lines, values=None, cmap=default_matplotlib_cmap_cycle, vmin=None, vmax=None):
    """
    Set colors of lines to colormapping of values.

    Other good sequential colormaps are YlOrBr and autumn.
    A good diverging colormap is bwr.

    :param lines: Lines to set colors. Errorbar containers are supported:
        all of the artists they are made of get the same color.
    :type lines: list
    :param values: Values corresponding to each line. Default is indexing.
    :type values: array like
    :param cmap: Valid matplotlib colormap name.
    :type cmap: str
    :param vmax: Upper bound of colormapping.
    :type vmax: float
    :param vmin: Lower bound of colormapping.
    :type vmin: float
    :return: ScalarMappable. A mapping object used for colorbars.

    """
    from matplotlib.cbook import flatten

    if values is None:
        values = list(range(len(lines)))

    nm = matplotlib.colors.Normalize(vmin, vmax)
    sm = cm.ScalarMappable(cmap=cmap, norm=nm)
    sm.set_array(values)
    # one RGBA color per value, in the same order as `lines`
    colors = sm.cmap(sm.norm(values))
    for l, c in zip(lines, colors):
        if isinstance(l, matplotlib.container.ErrorbarContainer):
            # a container is a nested group of artists (data line, caps, bars)
            for li in flatten(l.lines):
                li.set_color(c)
        else:
            l.set_color(c)

    return sm
@_available_to_user_plot
def pi_multiple(x, pos=None):
    """
    Format x as a string that is a multiple of the fraction pi/'denominator',
    using a formatter built by multiple_formatter() with its default arguments.

    See multiple_formatter documentation for more info.

    :param x: value to be formatted

    :param pos: tick position, handed over to the formatter
    """
    return multiple_formatter()(x, pos)
@_available_to_user_plot
def multiple_formatter(denominator=24, number=np.pi, latex=r'\pi'):
    """
    Returns a tick formatting function that creates tick labels in multiples of 'number'/'denominator'.

    The value to format is rounded to the closest multiple of 'number'/'denominator' and the
    resulting fraction is reduced, e.g. with the defaults pi/2 is labelled as a fraction of pi over 2
    rather than 12 pi over 24.

    :param denominator: The denominator of the fraction that tick labels are created in multiples of.

    :param number: The numerator of the fraction that tick labels are created in multiples of

    :param latex: The latex string used to represent 'number'

    :return: function with signature f(x, pos) returning a latex string
    """

    def gcd(a, b):
        # Euclid's algorithm
        while b:
            a, b = b, a % b
        return a

    def _multiple_formatter(x, pos):
        den = denominator
        # builtin int is used since the np.int alias no longer exists in recent NumPy
        num = int(np.rint(den * x / number))
        # reduce the fraction; integer division is exact since com divides both
        com = gcd(num, den)
        (num, den) = (num // com, den // com)
        if den == 1:
            if num == 0:
                return r'$0$'
            if num == 1:
                return r'$%s$' % latex
            elif num == -1:
                return r'$-%s$' % latex
            else:
                return r'$%s%s$' % (num, latex)
        else:
            if num == 1:
                return r'$\frac{%s}{%s}$' % (latex, den)
            elif num == -1:
                return r'$\frac{-%s}{%s}$' % (latex, den)
            else:
                return r'$\frac{%s%s}{%s}$' % (num, latex, den)

    return _multiple_formatter
@_available_to_user_plot
def convert_ticks_to_pi_multiple(axis=None, major=2, minor=4):
    """
    Place the ticks of an axis object at multiples of pi and label the major ones
    with pi_multiple [...,-2pi,-pi,0,pi,2pi,...]

    :param axis: An axis object, such as pyplot.gca().xaxis (which is the default)

    :param major: int
        Denominator of pi for major tick marks. 2: major ticks at 0, pi/2., pi, ... Can't be greater than 24.

    :param minor: int
        Denominator of pi for minor tick marks. 4: minor ticks at 0, pi/4., pi/2., ...

    :return: None
    """
    if axis is None:
        axis = pyplot.gca().xaxis

    # values above 24 are silently capped
    major = min([major, 24])

    major_locator = MultipleLocator(np.pi / major)
    axis.set_major_locator(major_locator)
    minor_locator = MultipleLocator(np.pi / minor)
    axis.set_minor_locator(minor_locator)
    axis.set_major_formatter(FuncFormatter(pi_multiple))
@_available_to_user_plot
def is_colorbar(ax):
    """
    Guesses whether a set of Axes is home to a colorbar

    :param ax: Axes instance

    :return: bool
        True if exactly one of the x and y axes satisfies all of the following and thus looks like it's
        probably a colorbar:
        No ticks, no tick labels, no axis label, and range is (0, 1)
    """

    def _looks_bare(get_ticks, get_ticklabels, get_label, get_lim):
        # getters are called lazily, left to right, stopping at the first failed condition
        return (len(get_ticks()) == 0) and (len(get_ticklabels()) == 0) and (len(get_label()) == 0) and (get_lim() == (0, 1))

    x_is_bare = _looks_bare(ax.get_xticks, ax.get_xticklabels, ax.get_xlabel, ax.get_xlim)
    y_is_bare = _looks_bare(ax.get_yticks, ax.get_yticklabels, ax.get_ylabel, ax.get_ylim)

    # for two bools, != acts as an exclusive or
    return x_is_bare != y_is_bare
@_available_to_user_plot
def tag_plots_abc(
    fig=None,
    axes=None,
    corner=[1, 1],
    font_size=matplotlib.rcParams['xtick.labelsize'],
    skip_suspected_colorbars=True,
    start_at=0,
    **annotate_kw,
):
    """
    Tag plots with (a), (b), (c), ...

    Lowercase letters are used first, followed by uppercase ones.
    Axes whose axis is turned off are not tagged and do not consume a letter.

    :param fig: Specify a figure instance instead of letting the function pick the most recent one

    :param axes: Specify a plot axes instance or list/array of plot axes instances instead of letting the function use
        fig.get_axes()

    :param corner: Which corner does the tag go in? [0, 0] for bottom left, [1, 0] for bottom right, etc.

    :param font_size: Font size of the annotation.

    :param skip_suspected_colorbars: bool
        Try to detect axes which are home to colorbars and skip tagging them. An Axes instance is suspected of having a
        colorbar if either the xaxis or yaxis satisfies all of these conditions:
        - Length of tick list is 0
        - Length of tick label list is 0
        - Length of axis label is 0
        - Axis range is (0,1)

    :param start_at: int
        Offset value for skipping some numbers. Useful if you aren't doing real subfigs, but two separate plots and
        placing them next to each other in a publication. Set to 1 to start at (b) instead of (a), for example.

    :param annotate_kw: dict
        Additional keywords passed to annotate(). Keywords used by settings such as corner, etc. will be overriden.
    """
    from matplotlib.cbook import flatten

    if fig is None:
        fig = pyplot.gcf()
    if axes is None:
        axes = fig.get_axes()

    letters = string.ascii_lowercase + string.ascii_uppercase

    # small inset from the chosen corner, in axes fraction
    horizontal, vertical = corner[0], corner[1]
    tag_x = abs(horizontal - 0.005)
    tag_y = abs(vertical - 0.025)
    tag_ha = ['left', 'right'][horizontal]
    tag_va = ['bottom', 'top'][vertical]
    printd('Plot tag information: ', tag_x, tag_y, tag_ha, tag_va)

    valid_axes = []
    for candidate in flatten(np.atleast_1d(axes)):
        if skip_suspected_colorbars and is_colorbar(candidate):
            continue
        valid_axes.append(candidate)

    # Assign all keywords into annotate_kw to prevent duplicates
    annotate_kw.update([('xycoords', 'axes fraction'), ('ha', tag_ha), ('va', tag_va), ('fontsize', font_size), ('xy', (tag_x, tag_y))])

    j = 0 + start_at
    for ax in valid_axes:
        if not ax.axison:  # Thanks to Suever for pointing out ax.axison:
            continue
        tag = ax.annotate(' ({:}) '.format(letters[j]), **annotate_kw)
        tag.draggable(True)
        j += 1

    return
@_available_to_user_plot
def mark_as_interactive(ax, interactive=True):
    """
    Flag an axis as interactive or not by setting its `is_interactive` attribute

    :param ax: axis

    :param interactive: boolean

    :return: axis
    """
    setattr(ax, 'is_interactive', interactive)
    return ax
[docs]@_available_to_user_plot class View1d(object): """ Plot 2D or 3D data as line-plots with interactive navigation through the alternate dimensions. Navigation uses the 4 arrow keys to traverse up to 2 alternate dimensions. The data must be on a regular grid, and is formed into a xarray DataArray if not already. Uses matplotlib line plot for float/int data, OMFIT uerrrorbar for uncertainty variables. Examples: The view1d can be used to interactively explore data. For usual arrays it draws line slices. >> t = np.arange(20) >> s = np.linspace(0,2*np.pi,60) >> y = np.sin(np.atleast_2d(s).T+np.atleast_2d(t)) >> da = xarray.DataArray(y,coords=SortedDict([('space',s),('time',t)]),name='sine') >> v = View1d(da.transpose('time','space'),dim='space',time=10) For uncertainties arrays, it draws errorbars using the uerrorbar function. Multiple views with the same dimensions can be linked for increased speed (eliminate redundant calls to redraw). >> y_u = unumpy.uarray(y+(random(y.shape)-0.5),random(y.shape)) >> da_u = xarray.DataArray(y_u,coords=SortedDict([('space',s),('time',t)]),name='measured') >> v_u = View1d(da_u,dim='space',time=10,axes=pyplot.gca()) >> # v will remain connected to keypress events and drive vu Variable dependent axis data can be viewed if x and y share a regular grid in some coordinates, >> x = np.array([s+(random(s.shape)-0.5)*0.2 for i in t]).T >> da_x = xarray.DataArray(x,coords=SortedDict([('space',s),('time',t)]),name='varspace') >> ds = da_u.to_dataset().merge(da_x.to_dataset()) >> v_x = View1d(ds,name='measured',dim='varspace',time=10,axes=pyplot.gca()) >> """ def __init__( self, data, coords=None, dims=None, name=None, dim=None, axes=None, dynamic_ylim=False, use_uband=False, cornernote_options=None, plot_options=None, **indexers, ): r""" :param data: DataArray or array-like 2D or 3D data values to be viewed. :param coords: dict-like Dictionary of Coordinate objects that label values along each dimension. 
:param dims: tuple Dimension names associated with this array. :param name: string Label used in legend. Empty or beginning with '_' produces no legend label. If the data is a DataArray it will be renamed before plotting. If the data is a Dataset, the name specifies which of its existing data_vars to plot. :param dim: string, DataArray Dimension plotted on x-axis. If DataArray, must have same dims as data. :param axes: Axes instance The axes plotting is done in. :param dynamic_ylim: bool Re-scale y limits of axes when new slices are plotted. :param use_uband: bool Use uband instead of uerrorbar to plot uncertainties variables. :param cornernote_options: dict Key word arguments passed to cornernote (such as root, shot, device). If this is present, then cornernote will be updated with the new time if there is only one time to show, or the time will be erased from the cornernote if more than one time is shown by this View1d instance (such as by freezing one slice). :param plot_options: dict Key word arguments passed to plot/uerrorbar/uband. :param \**indexers: dict Dictionary with keys given by dimension names and values given by arrays of coordinate index values. Must include all dimensions other, than the fundamental. 
""" # check args and make Dataset if needed if isinstance(data, DataArray): if coords or dims: raise ValueError("View cannot re-assign coords/dims to data that is already a DataArray") if name: = name elif not # require name name = 'viewdata' = name else: name = data = data.to_dataset() elif isinstance(data, Dataset): if coords or dims: raise ValueError("View cannot re-assign coords/dims to data that is already a Dataset") else: if name is None: name = 'viewdata' data = DataArray(data, coords=coords, dims=dims, name=name).to_dataset() # check key words self.dynamic_ylim = dynamic_ylim self.cornernote_options = cornernote_options if plot_options: self.plot_options = plot_options else: self.plot_options = {} # Make figure & axes if needed if not axes: axes = pyplot.gca() self.figure, self.axes = axes.get_figure(), axes # make view aware of full data, the slice data, and slice indexes self.links = [] self.is_uncertain = is_uncertain(data[name].values) = data = name if dim: if isinstance(dim, DataArray): if not = 'dim' data = data.merge(dim.to_dataset()) dim = self.dim = dim # reduce any large datasets just to the things we need = Dataset(dict(((,[]), (dim,[dim])))) # put dimension at front of indexing if dim in dim0 = dim elif dim in # it will return the first dimension which is not in indexers dim0 = [d for d in if d not in indexers][0] else: raise ValueError("Argument `dim` must be in DataArray dims or in Dataset") dims = list( dims.insert(0, dims.pop(dims.index(dim0))) =*dims) else: self.dim = list([0] self.values =**indexers) self.indexes = SortedDict(indexers) for key, val in list(self.indexes.items()): self.indexes[key] = np.atleast_1d(val) # use metadata for labeling self.axes.set_xlabel(dim) self.label_prefix = self.plot_options.pop('label', '') self._label_indexers_visible = True lbl = self._get_label() # actual plotting self.use_uband(use_uband) self._inactive = [] if self.is_uncertain: self._active = self._uplot(self.values[self.dim], self.values[], 
ax=self.axes, label=lbl, **self.plot_options) else: self._active = self.axes.plot(self.values[self.dim], self.values[], label=lbl, **self.plot_options) # Update annotations if self.cornernote_options is not None: self._update_cornernote() legend = self.axes.legend(loc=0) legend.draggable(True) self.figure.canvas.draw_idle() # store all views within the axes for safe keeping (persistence) if not hasattr(self.axes, 'views'): self.axes.views = [] self.axes.views.append(self) self.axes.is_interactive = True # connect up the key/link driven navigation self._accelerate = False self.cid = self.figure.canvas.mpl_connect('key_press_event', self.key_command) if not hasattr(self.figure, 'shortcuts'): self.figure.shortcuts = SortedDict() self.figure.shortcuts['left'] = 'Reduce index of first non-axial dimension' self.figure.shortcuts['right'] = 'Increase index of first non-axial dimension' self.figure.shortcuts['up'] = 'Reduce index of second non-axial dimension' self.figure.shortcuts['down'] = 'Increase index of second non-axial dimension' self.figure.shortcuts['shift+up/down/left/right'] = 'Accelerate: Change index by one fifth the range' self.figure.shortcuts['a'] = 'Toggle accelerate all on/off' self.figure.shortcuts['w'] = 'Persist current slice of viewed data' self.figure.shortcuts['e'] = 'Erase all persisting slices of viewed data except the active one'
[docs] def use_uband(self, use=True): """Toggle use of uband instead of uerrorbar for plotting function""" self._use_uband = use if use: self._uplot = uband else: self._uplot = uerrorbar
[docs] def set_label_indexers_visible(self, visible=True): """Include the current indexers in line labels.""" self._label_indexers_visible = visible # update plot self.isel(draw=False, **self.indexes)
def _get_label(self): """(Re-)Set the view Line's label indicating the current indexes slicing""" if self.label_prefix: lbl = self.label_prefix + ' ' + else: lbl = if not return lbl for k, i in list(self.indexes.items()): i = np.unique(np.clip(i, 0,[k].size - 1)) if self._label_indexers_visible and len(i) <[k].shape[0]: # don't label if not sliced at all if len(i) == 1: lbl += ', {key} = {vals} {units}'.format(, key=k,[k][i].values[0],[k][i].attrs.get('units', '') ) else: lbl += ', {key} = {vals}'.format(, key=k,[k][i].values) return lbl.rstrip(', ').lstrip(', ')
[docs] def key_command(self, event, draw=True, **plot_options): """ Use arrows to navigate up to 2 extra dimensions by incrementing the slice indexes. Use w/e to write/erase slices to persist outside on navigation. """ # prevent infinite loops if not hasattr(self, '_lastevent'): self._lastevent = None if self._lastevent == event: return self._lastevent = event # drivers can override plot_options self.plot_options.update(plot_options) # get standard key press name if not event.key: return key_press = event.key.lower() # had to add this because mac Tk doesn't register shift+key if key_press == 'a': self._accelerate = not self._accelerate # slice navigation if key_press in ['left', 'right', 'down', 'up', 'shift+left', 'shift+right', 'shift+up', 'shift+down']: # select dimension being changed if key_press in ['left', 'right', 'shift+left', 'shift+right']: dim = list(self.indexes.keys())[0] if key_press in ['down', 'up', 'shift+up', 'shift+down']: i = min(len(self.indexes) - 1, 1) dim = list(self.indexes.keys())[i] # apply change in right direction if key_press in ['left', 'down']: self.indexes[dim] -= 1 + self._accelerate * (int([dim].size // 5) - 1) if key_press in ['shift+left', 'shift+down']: self.indexes[dim] -= int([dim].size // 5) if key_press in ['right', 'up']: self.indexes[dim] += 1 + self._accelerate * (int([dim].size // 5) - 1) if key_press in ['shift+right', 'shift+up']: self.indexes[dim] += int([dim].size // 5) # prevent stepping out of the dimension range self.indexes[dim] = np.clip(self.indexes[dim], 0,[dim] - 1) # update plot self.isel(draw=False, **self.indexes) # snapshot write / erase if key_press == 'w': self._add_lines() if isinstance(self._inactive[-1][0], matplotlib.container.ErrorbarContainer): plot_options['color'] = self._inactive[-1][0].lines[0].get_color() else: plot_options['color'] = self._inactive[-1][0].get_color() if key_press == 'e': self._clear_lines() # drive linked views, then draw once for link in self.links: 
link.key_command(event, draw=False, **plot_options) if draw: self.figure.canvas.draw_idle() # if linked to views in other figures draw those too otherfigs = unsorted_unique([self.figure] + [v.figure for v in self.links])[1:] for f in otherfigs: f.canvas.draw_idle() return
def _add_lines(self): """Plot a persistent snapshot of the current slice""" lbl = self._get_label() if self.is_uncertain: ys = self.values[].values.reshape(self.values[].shape[0], -1) xs = self.values[self.dim].values.reshape(self.values[self.dim].shape[0], -1) if xs.shape[1] == 1: xs = np.tile(xs, ys.shape[1]) for x, y in zip(xs.T, ys.T): if self._use_uband: self._inactive.append(uband(x, y, ax=self.axes, label=lbl, **self.plot_options)) else: self._inactive.append(uerrorbar(x, y, ax=self.axes, label=lbl, **self.plot_options)) else: self._inactive.append(self.axes.plot(self.values[self.dim], self.values[], label=lbl, **self.plot_options)) # Update annotations if self.cornernote_options is not None: self._update_cornernote() def _clear_lines(self): """Remove otherwise persistent snapshots""" from matplotlib.cbook import flatten for lines in self._inactive: if lines in self.axes.containers: # bug in matplotlib ErrorbarContainer.remove for l in flatten(lines.lines): l.remove() self.axes.containers.remove(lines) else: for l in lines: l.remove() self._inactive = [] self.isel(**self.indexes) # redraw current slice
[docs] def isel(self, draw=True, **indexers): """ Re-slice the data along its extra dimensions using indexes. """ from matplotlib.cbook import flatten self.indexes = SortedDict(indexers) # only plot valid indexers (but allow keeping track of others) for dim in indexers: indexers[dim] = np.unique(np.clip(self.indexes[dim], 0,[dim].size - 1)) self.values =**indexers).transpose(*[].dims) # make sure dim order stays consistent lbl = self._get_label() # update errorbar container lines if self.is_uncertain: # can't animate PollyCollection - have to delete and draw new if self._use_uband: colors = [] for obj in self._active: colors.append(obj.get_color()) if isinstance(obj, Uband): if obj.line in self.axes.lines: self.axes.lines.remove(obj.line) for band in obj.bands: if band in self.axes.collections: self.axes.collections.remove(band) else: # assume errorbar container for l in flatten(obj.lines): if l in self.axes.lines: self.axes.lines.remove(l) if l in self.axes.collections: self.axes.collections.remove(l) self._active = [] ys = self.values[].values.reshape(self.values[].shape[0], -1) xs = self.values[self.dim].values.reshape(self.values[self.dim].shape[0], -1) if xs.shape[0] == 1 and ys.shape[0] == 1: xs = xs.T ys = ys.T if xs.shape[1] == 1: xs = np.tile(xs, ys.shape[1]) lim = (np.nanmin(nominal_values(ys)), np.nanmax(nominal_values(ys))) match_colors = not 'color' in self.plot_options # keep same color even if not explicitly set the first time for i, (xi, yi) in enumerate(zip(xs.T, ys.T)): if match_colors and i < len(colors): self.plot_options['color'] = colors[i] self._active.append(self._uplot(xi, yi, ax=self.axes, label=lbl, **self.plot_options)[0]) if match_colors: self.plot_options.pop('color', None) # painfully update only the individual pieces of a ErrorbarContainer for faster redraw else: ys = self.values[].values.reshape(self.values[].shape[0], -1) xs = self.values[self.dim].values.reshape(self.values[self.dim].shape[0], -1) if xs.shape[1] == 1: xs = 
np.tile(xs, ys.shape[1]) lim = (np.nanmin(nominal_values(ys)), np.nanmax(nominal_values(ys))) for i, (xi, yi) in enumerate(zip(xs.T, ys.T)): container = self._active[i] y, ye = nominal_values(yi), std_devs(yi) x, xe = nominal_values(xi), std_devs(xi) # matplotlib error bar does not do multiple lines at once, so safe to simply flatten x, xe = x.flatten(), xe.flatten() line, capinfo, barinfo = container.lines if len(barinfo) == 2: # X and Y errors (xbars, ybars) = barinfo if len(capinfo): # With end-caps xtop, xbot, ytop, ybot = capinfo else: # Without end-caps xtop, xbot, ytop, ybot = None, None, None, None else: # Y errors only (ybars,) = barinfo[0] if len(capinfo): # With end-caps ytop, ybot = capinfo else: # Without end-caps ytop, ybot = None, None xtop, xbot = None, None yt, yb = y + ye, y - ye xt, xb = x + xe, x - xe line.set_data([x, y]) if ybot: ytop.set_data([x, yt]) ybot.set_data([x, yb]) new_segments_y = [np.array([[xn, yts], [xn, ybs]]) for xn, yts, ybs in zip(x, yt, yb)] ybars.set_segments(new_segments_y) if xbot: xtop.set_data([xt, y]) xbot.set_data([xb, y]) new_segments_x = [np.array([[xts, yn], [xbs, yn]]) for xts, xbs, yn in zip(xt, xb, y)] xbars.set_segments(new_segments_x) container.set_label(lbl) # update simple lines else: lim = (np.nanmin(self.values[]), np.nanmax(self.values[])) xlen = len(self._active[0].get_xdata()) ys = self.values[].values.reshape(xlen, -1).T if self.dim in # usual case where x is a 1D dimension xs = [self.values[self.dim]] * ys.shape[0] else: # special case where x is a dataset variable xs = self.values[self.dim].values.reshape(xlen, -1).T # update each line for line, x, y in zip(self._active, xs, ys): line.set_data([x, y]) line.set_label(lbl) # remove lines that got crunched into upper/lower bounds for line in self._active[len(ys) :]: self.axes.lines.remove(line) self._active.remove(line) # make sure the new labels get into the legend if self.axes.get_legend(): visible = self.axes.get_legend().get_visible() leg = 
self.axes.legend(loc=0) leg.draggable(True) leg.set_visible(visible) # expand axis limits if necessary if self.dynamic_ylim: ylim = self.axes.get_ylim() if ylim[0] > lim[0]: self.axes.set_ylim(ymin=lim[0]) if ylim[1] < lim[1]: self.axes.set_ylim(ymax=lim[1]) # Update annotations if self.cornernote_options is not None: self._update_cornernote() if draw: self.figure.canvas.draw_idle()
def _update_cornernote(self):
    """
    Refresh the corner note with the time(s) of the currently selected index.

    Nothing is done unless ``'time'`` is the one and only indexed dimension.
    With one (or two) distinct time indices and active lines, the matching
    time value(s) are written into the note; in every other case the note is
    drawn with an empty time field.

    :return: None
    """
    if len(self.indexes) != 1 or 'time' not in self.indexes:
        return

    # Clip so that out-of-range indices fall back onto the first/last sample
    time_index = np.unique(np.clip(self.indexes['time'], 0, self.data['time'].size - 1))
    times = self.data['time'].values[time_index]

    if self._inactive or len(time_index) not in (1, 2):
        time_label = ''
    elif len(time_index) == 1:
        time_label = times[0]
    else:
        time_label = '{}, {}'.format(times[0], times[1])

    options = self.cornernote_options
    cornernote(
        root=options.get('root', None),
        device=options.get('device', None),
        shot=options.get('shot', None),
        time=time_label,
        ax=self.axes,
    )
@_available_to_user_plot
class View2d(object):
    """
    Plot 2D data with interactive slice viewers attached to the 2D Axes.
    Left clicking on the 2D plot refreshes the line plot slices, right clicking
    overplots new slices.

    The original design of this viewer was for data on a rectangular grid, for which
    x and y are 1D arrays defining the axes but may be irregularly spaced. In this case,
    the line plot points correspond to the given data. If x or y is a 2D array,
    the data is assumed irregular and interpolated to a regular grid using
    scipy.interpolate.griddata.

    **Example:**

    Explore a basic 2D np array without labels,

    >> x = np.linspace(-1, 1, 200)
    >> y = np.linspace(-2, 2, 200)
    >> xx, yy = np.meshgrid(x, y)
    >> z = np.exp(-xx**2 - yy**2)
    >> v = View2d(z)

    To add more meaningful labels to the axes and data do,

    >> v = View2d(z, coords={'x':x, 'y':y}, dims=('x', 'y'), name='wow')

    or use a DataArray,

    >> d = DataArray(z, coords={'x':x, 'y':y}, dims=('x', 'y'), name='wow')
    >> v = View2d(d)

    Note that the coordinates should be 1D. Initializing a view with regular grid,
    2D coordinates will result in an attempt to slice them appropriately.
    This is done for consistency with some matplotlib 2D plotting routines,
    but is not recommended.

    >> v = View2d(z, coords=dict(x=x, y=y), dims=('x', 'y'))

    If you have irregularly distributed 2D data, it is recommended that you first interpolate
    it to a 2D grid in whatever way is most applicable. If you do not, initializing a view
    will result in an attempt to linearly interpolate to an automatically chosen grid.

    >> x = np.random.rand(1000)
    >> y = np.random.rand(1000) * 2
    >> z = np.exp(-x**2 - y**2)
    >> v = View2d(z, coords=dict(x=x, y=y), dims=('x', 'y'))

    The same applies for 2D collections of irregular points and values.

    >> x = x.reshape((50, 20))
    >> y = y.reshape((50, 20))
    >> z = z.reshape((50, 20))
    >> v = View2d(z, coords=[('x', x), ('y', y)], dims=('x', 'y'))

    """

    def __init__(
        self,
        data,
        coords=None,
        dims=None,
        name=None,
        axes=None,
        quiet=False,
        use_uband=False,
        contour_levels=0,
        imag_options=None,
        plot_options=None,
        **indexers,
    ):
        r"""
        :param data: DataArray or array-like
            2D or 3D data values to be viewed.

        :param coords: dict-like
            Dictionary of Coordinate objects that label values along each dimension.

        :param dims: tuple
            Dimension names associated with this array.

        :param name: string
            Label used in legend. Empty or beginning with '_' produces no legend label.

        :param axes: Axes instance
            The axes plotting is done in.

        :param quiet: bool
            Suppress printed messages.

        :param use_uband: bool
            Use uband for 1D slice plots instead of uerrorbar.

        :param contour_levels: int or np.ndarray
            Number of or specific levels used to draw black contour lines over the 2D image.

        :param imag_options: dict
            Key word arguments passed to the DataArray plot method (modified pcolormesh).
            Defaults to no extra options.

        :param plot_options: dict
            Key word arguments passed to plot or uerrorbar. Color will be determined by cmap variable.
            Defaults to ``{'marker': '', 'ls': '-'}``.

        :param \**indexers: dict
            Dictionary with keys given by dimension names and values given by
            arrays of coordinate index values. Not used by this constructor.
        """
        # Private copies: vslice/hslice update plot_options in place, so a shared
        # default (or the caller's own dict) must never be stored directly.
        imag_options = {} if imag_options is None else dict(imag_options)
        plot_options = {'marker': '', 'ls': '-'} if plot_options is None else dict(plot_options)

        # check args and make a named DataArray if needed
        if isinstance(data, DataArray):
            if coords or dims:
                raise ValueError("View cannot re-assign coords/dims to data that is already a DataArray")
            if name:
                data.name = name
            elif not data.name:  # require name
                name = 'viewdata'
                data.name = name
            else:
                name = data.name
        elif isinstance(data, Dataset):
            if coords or dims:
                raise ValueError("View cannot re-assign coords/dims to data that is already a Dataset")
            # NOTE(review): this selection still yields a Dataset, while the rest of the
            # constructor indexes data.dims by position -- confirm Dataset input is supported.
            data = data[list(data.data_vars.keys())]
        else:
            if name is None:
                name = 'viewdata'
            try:
                data = DataArray(data, coords=coords, dims=dims, name=name)
            except Exception as error:
                # Without any coordinates there is nothing to regrid against,
                # so surface the original construction failure.
                if coords is None and not isinstance(dims, dict):
                    raise
                # turn the flexible input styles into the standard coords dict, dims list
                if coords is not None:
                    coords = collections.OrderedDict(coords)
                if isinstance(dims, dict):
                    x0, x1 = list(dims.values())[:2]
                    dims = list(dims.keys())[:2]
                elif dims is not None:
                    x0, x1 = [coords[k] for k in dims[:2]]
                else:
                    x0, x1 = list(coords.values())[:2]
                    dims = list(coords.keys())[:2]
                # plain sequences have no .shape; arrays pass through unchanged
                x0, x1, data = np.asarray(x0), np.asarray(x1), np.asarray(data)
                # check if called with 2D coordinates of a regular grid
                if (
                    (x0.shape == data.shape and x1.shape == data.shape)
                    and np.all(x0 == np.roll(x0, -1, axis=0))
                    and np.all(x1 == np.roll(x1, -1, axis=1))
                ):
                    coords_1d = {dims[0]: x0[0], dims[1]: x1[:, 0]}
                    data = DataArray(data, coords=coords_1d, dims=dims, name=name)
                # interpolate weird inputs to a nice 2D grid
                else:
                    if not quiet:
                        printe("WARNING: Could not directly form a DataArray from inputs.")
                        printe(" {:}".format(error))
                        printe(" > Interpolating to a regular grid.")
                    # make a regular grid whose spacing is the median non-zero point spacing
                    x0, x1 = np.ravel(x0), np.ravel(x1)
                    dx0, dx1 = np.diff(sorted(x0)), np.diff(sorted(x1))
                    n0 = int((np.nanmax(x0) - np.nanmin(x0)) / np.median(dx0[np.where(dx0 != 0)]))
                    n1 = int((np.nanmax(x1) - np.nanmin(x1)) / np.median(dx1[np.where(dx1 != 0)]))
                    # never more than 500 points per axis
                    i0 = np.linspace(np.nanmin(x0), np.nanmax(x0), min(500, n0))
                    i1 = np.linspace(np.nanmin(x1), np.nanmax(x1), min(500, n1))
                    data = interpolate.griddata((x0, x1), np.ravel(data), (i0[None, :], i1[:, None]), method='linear')
                    data = data.reshape((i1.shape[0], i0.shape[0])).T
                    data = DataArray(data, coords=[(k, x) for k, x in zip(dims, (i0, i1))], dims=dims, name=name)

        # pass on key words
        if not np.shape(contour_levels) and contour_levels == 0:
            contour_levels = None  # captures 0.0, 0, None, False, etc.
        self._contour_levels = contour_levels
        self._use_uband = use_uband
        self.plot_options = plot_options
        self.plot_cmap = plot_options.get('cmap', None)
        self.imag_options = imag_options

        # store basic inputs
        self.data = data

        # make alarm for timing
        self.alarm = None

        # empty list of links
        self.links = []

        # Make figure & main axes if needed
        if axes is None:
            fig = pyplot.figure(figsize=(10, 8))
            gs = matplotlib.gridspec.GridSpec(4, 4)
            gs.update(wspace=0.05, hspace=0.05)
            axm = fig.add_subplot(gs[1:, :-1])  # main 2D plot
        else:
            # make a grid within the original axes location
            fig = axes.get_figure()
            axm = axes
            rect = axm.get_position().bounds
            gs = matplotlib.gridspec.GridSpec(4, 4, left=rect[0], bottom=rect[1], right=rect[0] + rect[2], top=rect[1] + rect[3])
            axm.set_position(gs[1:, :-1].get_position(fig))  # main 2D plot
        # add surrounding axes
        axx = fig.add_subplot(gs[0, :-1], sharex=axm)
        axy = fig.add_subplot(gs[1:, -1], sharey=axm)
        axr = fig.add_subplot(gs[0, -1])
        # add to class
        self.fig = fig
        self.axm = axm
        self.axx = axx
        self.axy = axy
        self.axr = axr
        # last rectangle handled by line_select_callback
        self.old_xmin = 0.0
        self.old_xmax = 0.0
        self.old_ymin = 0.0
        self.old_ymax = 1.0
        # autoscaling behaves poorly
        for ax in [axm, axy, axx, axr]:
            ax.autoscale(False)

        # store within the axes for safe keeping (persistence)
        if not hasattr(self.axm, 'views'):
            self.axm.views = []
        self.axm.views.append(self)
        self.axm.is_interactive = True

        # 2D plot with colorbar
        axm.set_xlabel(data.dims[0])
        axm.set_ylabel(data.dims[1])
        self._set_values(self.data, **imag_options)
        axm.axis('tight')
        self.colorbar = self.fig.colorbar(self.image, ax=[axm, axy, axx, axr], use_gridspec=True, shrink=0.75, anchor=(0, 0.0))
        self.colorbar.set_label(self.data.name)
        if not self.plot_cmap:
            self.plot_cmap = self.axm.collections[-1].get_cmap()

        # 1D plots aesthetics
        axx.set_ylabel(self.data.name)
        axx.yaxis.set_major_locator(pyplot.MaxNLocator(5))
        axy.set_xlabel(self.data.name)
        axy.xaxis.set_major_locator(pyplot.MaxNLocator(5))
        pyplot.setp(axy.xaxis.get_majorticklabels(), rotation=90)
        for tl in axx.xaxis.get_ticklabels() + axy.yaxis.get_ticklabels():
            tl.set_visible(False)

        # connect axes to events
        selector_options = dict(
            useblit=True,
            button=[1, 3],  # don't use middle button
            spancoords='data',
        )
        try:
            self.RectangleSelector = RectangleSelector(self.axm, self.line_select_callback, drawtype='box', **selector_options)
        except TypeError:
            # Matplotlib releases that dropped the `drawtype` keyword draw a box anyway
            self.RectangleSelector = RectangleSelector(self.axm, self.line_select_callback, **selector_options)
        self.fig.canvas.mpl_connect('button_press_event', self.line_select_callback)
        self.fig.canvas.mpl_connect('key_press_event', self.toggle_selector)

        # connect radio buttons
        dim0, dim1 = self.data.dims[0], self.data.dims[1]
        self.radio = RadioButtons(
            axr,
            ('data', 'd/d-' + dim0, 'd/d-' + dim1, 'int-' + dim0, 'int-' + dim1),
            activecolor='black',
        )
        # NOTE(review): these two statements were reconstructed from how set_data and
        # _radio_select use self.radio -- confirm against the upstream source.
        self.radio.on_clicked(self._radio_select)
        self.radio.value_selected = 'data'

        self._log_toggled = False
        if hasattr(self.fig, 'shortcuts'):
            self.fig.shortcuts['j'] = 'Toggle log scaling of 2D image data.'
        else:
            printe('No shortcuts attribute so no log scaling')
        self.fig.canvas.draw_idle()
def use_uband(self, use=True):
    """
    Toggle use of uband instead of uerrorbar in 1D slice plots.

    Only the flag is stored here; it is read the next time slices are drawn.

    :param use: bool
        True selects uband, False selects uerrorbar.

    :return: None
    """
    self._use_uband = use
def key_navigate_cuts(self, key_press):
    """
    Move the current slice window by one grid point in response to an arrow key.

    :param key_press: string
        'left' / 'right' shift the vertical cut along the first dimension;
        any other value (in practice 'up' / 'down') shifts the horizontal cut
        along the second dimension. 'right' and 'up' step forward, everything
        else steps backward.

    :return: None
    """
    along_first_dim = key_press in ('left', 'right')
    if along_first_dim:
        coord = self.data[self.data.dims[0]].values
        lower, upper, std = self.get_vslice_args()
    else:
        coord = self.data[self.data.dims[1]].values
        lower, upper, std = self.get_hslice_args()

    # vslice/hslice address string-labelled coordinates by position; do the same here
    # so the distance arithmetic below is always numeric.
    if isinstance(coord[0], str):
        coord = np.arange(len(coord))

    imin = np.abs(coord - lower).argmin()
    imax = np.abs(coord - upper).argmin()

    step = 1 if key_press in ('right', 'up') else -1

    # Bound the results to valid indices
    last = len(coord) - 1
    min_index = min(max(imin + step, 0), last)
    max_index = min(max(imax + step, 0), last)

    if along_first_dim:
        self.vslice(coord[min_index], coord[max_index], std)
    else:
        self.hslice(coord[min_index], coord[max_index], std)
# Disabled debounce code kept from the end of the preceding method:
# if self.alarm:
#     OMFITaux['rootGUI'].after_cancel(self.alarm)
#     self.alarm=None
# self.alarm=OMFITaux['rootGUI'].after(100,None)


def _set_values(self, values, draw=True, log_toggled=False, **kw):
    r"""
    Plots the main 2D image in axm. Optionally re-sets displayed values and axes,
    but does not change original data.

    :param values: 2D DataArray of values for display/interaction.
    :type values: DataArray.
    :param draw: Re-draw the figure when done.
    :type draw: bool
    :param log_toggled: Value stored as the log-scaling state (default False).
    :type log_toggled: bool
    :param \**kw: Additional key word arguments passed to the DataArray plot method (modified pcolormesh).

    :return: None
    :raises ValueError: on the first call, if the stored data has a dimension of length 1.
    """
    self._log_toggled = log_toggled
    self.values = values

    # mask values (allows for nan's)
    zim = nominal_values(values)
    zim = np.ma.masked_where(np.isnan(zim), zim)

    # axis coordinates; string labels are addressed by their position
    axis_coords = []
    for dim in values.dims[:2]:
        if isinstance(values[dim].values[0], str):
            axis_coords.append(np.arange(len(values[dim])))
        else:
            axis_coords.append(nominal_values(values[dim]))
    x0, x1 = axis_coords

    # limits
    xlim = (np.nanmin(x0), np.nanmax(x0))
    ylim = (np.nanmin(x1), np.nanmax(x1))
    vlim = (np.nanmin(zim), np.nanmax(zim))

    # default key words
    dkw = {}
    dkw.update(kw)

    # first call creates the base image from the original data
    if not hasattr(self, 'image'):
        if 1 in self.data.shape:
            raise ValueError("Must have 2D data to use View2D. This data is {:}x{:}".format(*self.data.shape))
        self.image = self.data.plot(ax=self.axm, add_colorbar=False, **dkw)

    # updated image
    self.image.set_array(zim.T.flatten())
    self.image.autoscale()
    for k, v in kw.items():
        try:
            getattr(self.image, 'set_' + k)(v)
        except Exception:
            # best effort: not every plot key word has a matching setter
            printe('Cannot reset {:}, change manually.'.format(k))

    # contour: drop the previous lines, then redraw unless the image is flat
    if hasattr(self, 'contour'):
        for c in self.contour.collections:
            if c in self.axm.collections:
                # Artist.remove works whether or not axes.collections is a plain list
                c.remove()
    if not np.all(zim.T == zim.flat[0]) and self._contour_levels is not None:
        self.contour = self.axm.contour(x0, x1, zim.T, self._contour_levels, colors='k')

    # aesthetics
    self.axm.set_xlim(*xlim)
    self.axm.set_ylim(*ylim)
    self.axx.set_ylim(*vlim)
    self.axy.set_xlim(*vlim)

    # re-slice: reuse the last slice windows, or start from the centre of each axis
    if hasattr(self, 'get_vslice_args') and hasattr(self, 'get_hslice_args'):
        self.vslice(*self.get_vslice_args(), draw=draw, force=True)
        self.hslice(*self.get_hslice_args(), draw=draw, force=True)
    else:
        self.vslice(*[np.nanmean(x0)] * 2, draw=draw, force=True)
        self.hslice(*[np.nanmean(x1)] * 2, draw=draw, force=True)

    # drawing is slow
    if draw:
        self.fig.canvas.draw()  # Why does this change the axm limits????

    return
def set_data(self, data=None, **kw):
    """
    Set base data and axes, as well as displayed values.
    Call with no arguments re-sets to interactively displayed values (use after :func:`der`, or :func:`int`).

    :param data: 2D array of data.
    :type data: DataArray
    :param kw: Accepted but not used by this method.

    :return: None
    """
    if data is not None:
        self.data = data
    # keep the radio selection: re-apply whichever transformation is currently chosen
    self._radio_select(self.radio.value_selected)
    return
def der(self, axis=0):
    """
    Set 2D values to derivative of data along specified axis.

    Nothing is redrawn (and linked views are not touched) when the result
    matches the values already displayed.

    :param axis: Axis along which derivative is taken (0=horizontal, 1=vertical).
        A dimension name is also accepted.
    :type axis: int or str

    :return: None
    """
    import numbers

    # numbers.Integral also covers NumPy integer scalars, which are not `int` instances
    if not isinstance(axis, numbers.Integral):
        axis = self.data.dims.index(axis)

    coord = self.data[self.data.dims[axis]]
    # "+ 0 * <DataArray>" re-attaches the labelled coordinates to the plain gradient arrays
    dx = np.gradient(nominal_values(coord)) + 0 * coord
    dy = np.gradient(nominal_values(self.data), edge_order=2)[axis] + 0 * self.data
    data = dy / dx

    # already displayed: stop here (also prevents linked views from looping forever)
    if np.allclose(nominal_values(self.values), nominal_values(data)):
        return
    self._set_values(data, **self.imag_options)

    # force action in linked views
    for v in self.links:
        v.der(axis=axis)

    return
def int(self, axis=0):
    """
    Set 2D values to the cumulative integral of data along specified axis.

    NaNs are treated as zero in the integrand. Nothing is redrawn (and linked
    views are not touched) when the result matches the values already displayed.

    :param axis: Axis along which integration is taken (0=horizontal, 1=vertical).
        A dimension name is also accepted.
    :type axis: int or str

    :return: None
    """
    import numbers

    if isinstance(axis, numbers.Integral):
        x = self.data[self.data.dims[axis]]
    else:
        x = self.data[axis]
        # the integrator needs a positional axis, not a dimension name
        axis = self.data.dims.index(axis)

    # `cumtrapz` was renamed `cumulative_trapezoid`; use whichever this SciPy provides
    cumulative = getattr(integrate, 'cumulative_trapezoid', None) or integrate.cumtrapz
    # "+ 0 * self.data" re-attaches the labelled coordinates to the plain result array
    data = cumulative(np.nan_to_num(nominal_values(self.data)), x=x, axis=axis, initial=0) + 0 * self.data

    # already displayed: stop here (also prevents linked views from looping forever)
    if np.allclose(nominal_values(self.values), nominal_values(data)):
        return
    self._set_values(data, **self.imag_options)

    # force action in linked views
    for v in self.links:
        v.int(axis=axis)

    return
def toggle_log(self):
    """
    Toggle log/linear scaling of data in 2D plot.

    Switching to log replaces the displayed values with their base-10 logarithm,
    prefixes the data labels with 'log ' and puts the 1D slice axes on a log
    scale. Switching back raises 10 to the displayed values and restores the
    plain data name and linear axes.

    :return: None
    """
    if self._log_toggled:
        self._set_values(10**self.values, **self.imag_options)
        # The linear label is simply the data name. (str.lstrip('log ') strips a
        # *set of characters*, which would mangle names such as "gold".)
        lbl = self.data.name
        self.colorbar.set_label(lbl)
        self.axx.set_ylabel(lbl)
        self.axy.set_xlabel(lbl)
        self.axx.set_yscale('linear')
        self.axy.set_xscale('linear')
        self._log_toggled = False
    else:
        self._set_values(np.log10(self.values), **self.imag_options)
        # relabel all axes
        lbl = 'log ' + self.data.name
        self.colorbar.set_label(lbl)
        self.axx.set_ylabel(lbl)
        self.axy.set_xlabel(lbl)
        # log scale line plots
        self.axx.set_yscale('log')
        self.axy.set_xscale('log')
        # robust log scale axes should not go to 0
        zim = nominal_values(self.values)
        zim = np.ma.masked_where(np.isnan(zim), zim)
        zim = np.ma.masked_where(zim <= 0, zim)
        vlim = (np.nanmin(zim), np.nanmax(zim))
        self.axx.set_ylim(*vlim)
        self.axy.set_xlim(*vlim)
        self._log_toggled = True

    return
def vslice(self, xmin, xmax, std=False, draw=True, force=False, **kw):
    r"""
    Plot line collection of x slices.

    :param xmin: Lower bound of slices displayed in line plot.
    :type xmin: float
    :param xmax: Upper bound of slices displayed in line plot.
    :type xmax: float
    :param std: Display mean and standard deviation instead of individual slices.
    :type std: bool
    :param draw: Redraw the figure canvas
    :type draw: bool
    :param force: Re-slice even if arguments are identical to last slice.
    :type force: bool
    :param \**kw: Extra key words are passed to the 1D :func:`plot` function
        (they are merged into the stored plot options and persist).

    :return: Possibly modified (xmin,xmax,std)
    :rtype: tuple
    """
    # plot options
    self.plot_options.update(kw)

    # coordinates of both image axes; string labels are addressed by their position
    dim0, dim1 = self.data.dims[0], self.data.dims[1]
    if isinstance(self.data[dim0].values[0], str):
        x0 = np.arange(len(self.data[dim0]))
    else:
        x0 = nominal_values(self.data[dim0])
    if isinstance(self.data[dim1].values[0], str):
        x1 = np.arange(len(self.data[dim1]))
    else:
        x1 = nominal_values(self.data[dim1])
    if isinstance(xmin, DataArray):
        xmin = xmin.values
    if isinstance(xmax, DataArray):
        xmax = xmax.values

    # nearest grid points to the requested bounds
    imin = np.abs(x0 - xmin).argmin()
    imax = np.abs(x0 - xmax).argmin()
    imax = max(imax, min(imin, len(x0) - 2))
    x = self.values[imin : imax + 1, :]
    y = self.data[dim1]

    # Stop infinite loops (linked views call each other)
    if hasattr(self, 'get_vslice_args'):
        if not force and (x0[imin], x0[imax], std) == self.get_vslice_args():
            return self.get_vslice_args()
    self.get_vslice_args = lambda: (x0[imin], x0[imax], std)

    # clear the previous slices from the side axes
    self.axy.containers = []
    if compare_version(matplotlib.__version__, '3.3.4') <= 0:
        self.axy.collections = []
        self.axy.lines = []
    else:
        # iterate over copies: removing an artist shrinks the underlying list
        for collection in list(self.axy.collections):
            collection.remove()
        for line in list(self.axy.lines):
            line.remove()

    xm, sd = np.mean(nominal_values(x), axis=0), np.std(nominal_values(x), axis=0)
    self._vslice_mean = xm
    self._vslice_std = sd
    self._vslice_xdata = y

    if std:
        self.axy.plot(nominal_values(xm), nominal_values(y), color='r')
        self.axy.fill_betweenx(nominal_values(y), nominal_values(xm - sd), nominal_values(xm + sd), alpha=0.3, color='r')
    else:
        # one row of y coordinates per selected slice
        y = x1.repeat(x.shape[0]).reshape((len(x1), -1)).T
        lines = []
        for i, (xi, yi) in enumerate(zip(x, y)):
            if self._use_uband:
                (l,) = uband(xi, yi, ax=self.axy, **self.plot_options)
            else:
                (l,) = uerrorbar(xi, yi, ax=self.axy, **self.plot_options)
            # only the first and last slices get a legend entry
            if i == 0 or i == imax - imin:
                l.set_label('{:}={:6.2g}'.format(self.axm.get_xlabel(), x0[imin + i]))
            lines.append(l)
        set_linearray(lines[:], x0[imin : imax + 1], cmap=self.plot_cmap)
        # NOTE(review): the legend is drawn only where labelled lines exist; the
        # original nesting was lost in extraction -- confirm.
        legend = self.axy.legend(numpoints=1, prop={'size': 'xx-small'})
        if hasattr(legend, 'set_draggable'):
            legend.set_draggable(True)
        else:
            legend.draggable(True)

    # replace the highlighted band on the main image
    previous_span = getattr(self.axm, '_vspan', None)
    if previous_span is not None and previous_span in self.axm.patches:
        previous_span.remove()
    self.axm._vspan = self.axm.axvspan(x0[imin], x0[imax], facecolor='w', edgecolor='w', alpha=0.3)

    if draw:
        self.fig.canvas.draw()

    # force action in linked views
    for v in self.links:
        v.vslice(*self.get_vslice_args())

    return self.get_vslice_args()
def hslice(self, ymin, ymax, std=False, draw=True, force=False, **kw):
    r"""
    Plot line collection of y slices.

    :param ymin: Lower bound of slices displayed in line plot.
    :type ymin: float
    :param ymax: Upper bound of slices displayed in line plot.
    :type ymax: float
    :param std: Display mean and standard deviation instead of individual slices.
    :type std: bool
    :param draw: Redraw the figure canvas
    :type draw: bool
    :param force: Re-slice even if arguments are identical to last slice.
    :type force: bool
    :param \**kw: Extra key words are passed to the 1D :func:`plot` function
        (they are merged into the stored plot options and persist).

    :return: Possibly modified (ymin,ymax,std)
    :rtype: tuple
    """
    # plot options
    self.plot_options.update(kw)

    # coordinates of both image axes; string labels are addressed by their position
    dim0, dim1 = self.data.dims[0], self.data.dims[1]
    if isinstance(self.data[dim0].values[0], str):
        x0 = np.arange(len(self.data[dim0]))
    else:
        x0 = nominal_values(self.data[dim0])
    if isinstance(self.data[dim1].values[0], str):
        x1 = np.arange(len(self.data[dim1]))
    else:
        x1 = nominal_values(self.data[dim1])
    if isinstance(ymin, DataArray):
        ymin = ymin.values
    if isinstance(ymax, DataArray):
        ymax = ymax.values

    # nearest grid points to the requested bounds
    imin = np.abs(x1 - ymin).argmin()
    imax = np.abs(x1 - ymax).argmin()
    imax = max(imax, min(imin, len(x1) - 2))
    y = self.values[:, imin : imax + 1]
    x = self.data[dim0].values
    if isinstance(x[0], str):
        x = np.arange(len(x))

    # Prevent infinite loops (linked views call each other)
    if hasattr(self, 'get_hslice_args'):
        if not force and (x1[imin], x1[imax], std) == self.get_hslice_args():
            return self.get_hslice_args()
    self.get_hslice_args = lambda: (x1[imin], x1[imax], std)

    # clear the previous slices from the top axes
    self.axx.containers = []
    if compare_version(matplotlib.__version__, '3.3.4') <= 0:
        self.axx.collections = []
        self.axx.lines = []
    else:
        # iterate over copies: removing an artist shrinks the underlying list
        for collection in list(self.axx.collections):
            collection.remove()
        for line in list(self.axx.lines):
            line.remove()

    ym, sd = np.mean(nominal_values(y), axis=1), np.std(nominal_values(y), axis=1)
    self._hslice_mean = ym
    self._hslice_std = sd
    self._hslice_xdata = x0

    if std:
        self.axx.plot(x0, nominal_values(ym), color='b')
        self.axx.fill_between(nominal_values(x), nominal_values(ym - sd), nominal_values(ym + sd), alpha=0.3, color='b')
    else:
        lines = []
        for i, yi in enumerate(y.T):
            if self._use_uband:
                (l,) = uband(x, yi.values, ax=self.axx, **self.plot_options)
            else:
                (l,) = uerrorbar(x, yi.values, ax=self.axx, **self.plot_options)
            # only the first and last slices get a legend entry
            if i == 0 or i == imax - imin:
                l.set_label('{:}={:6.2g}'.format(self.axm.get_ylabel(), x1[imin + i]))
            lines.append(l)
        set_linearray(lines[:], x1[imin : imax + 1], cmap=self.plot_cmap)
        # NOTE(review): the legend is drawn only where labelled lines exist; the
        # original nesting was lost in extraction -- confirm.
        legend = self.axx.legend(numpoints=1, prop={'size': 'xx-small'})
        if hasattr(legend, 'set_draggable'):
            legend.set_draggable(True)
        else:
            legend.draggable(True)

    # replace the highlighted band on the main image
    previous_span = getattr(self.axm, '_hspan', None)
    if previous_span is not None and previous_span in self.axm.patches:
        previous_span.remove()
    self.axm._hspan = self.axm.axhspan(x1[imin], x1[imax], facecolor='w', edgecolor='w', alpha=0.3)

    if draw:
        self.fig.canvas.draw()

    # force action in linked views
    for v in self.links:
        v.hslice(*self.get_hslice_args())

    return self.get_hslice_args()
def line_select_callback(self, eclick, erelease=None):
    """
    Call vslice and hslice for the range of x and y spanned by the rectangle
    between mouse press and release. Called by RectangleSelector, and directly
    by the button press event (in which case there is no release event).

    A click made with the right mouse button selects the mean / standard
    deviation display of the slices; any other button shows individual slices.

    :param eclick: Matplotlib mouse click.
    :type eclick: matplotlib Event
    :param erelease: Matplotlib mouse release.
    :type erelease: matplotlib Event

    :return: None
    """
    # do nothing if click not in main axes
    if eclick.inaxes != self.axm:
        return

    # The comparison result must be kept: passing the button constant itself as
    # `std` is always truthy and would force the mean/std display on every click.
    std = eclick.button == rightClickMPLindex  # right click

    # special treatment for single click (no dragged box)
    if erelease is None:
        x, y = self.axm.transData.inverted().transform([eclick.x, eclick.y])
        self.hslice(y, y, std=std)
        self.vslice(x, x, std=std)
        self.fig.canvas.draw_idle()
        return

    # check if box is new (prevent infinite loops)
    x1, y1 = eclick.xdata, eclick.ydata
    x2, y2 = erelease.xdata, erelease.ydata
    xmin, xmax = sorted([x1, x2])
    ymin, ymax = sorted([y1, y2])
    if (xmin, xmax, ymin, ymax) == (self.old_xmin, self.old_xmax, self.old_ymin, self.old_ymax):
        return
    self.old_xmin, self.old_xmax, self.old_ymin, self.old_ymax = xmin, xmax, ymin, ymax

    self.hslice(ymin, ymax, std=std)
    self.vslice(xmin, xmax, std=std)
    self.fig.canvas.draw_idle()
def toggle_selector(self, event):
    """
    Connected to key press events.

    * arrow keys move the slice cuts by one grid point
    * 'a' / 'q' (either case) turn the rectangle selector on / off
    * 'j' toggles log scaling of the 2D image

    :param event: key press event.
    :type event: matplotlib event

    :return: None
    """
    if event is None or event.key is None:
        return

    key = event.key.lower()
    if key in ('up', 'down', 'left', 'right'):
        self.key_navigate_cuts(key)

    selector = self.RectangleSelector
    if key == 'q' and selector.active:
        print(' RectangleSelector deactivated.')
        selector.set_active(False)
    if key == 'a' and not selector.active:
        print(' RectangleSelector activated.')
        selector.set_active(True)

    # the log toggle is bound to lower-case 'j' only
    if event.key == 'j':
        self.toggle_log()
# Radio-button callback: receives the selected `label` and dispatches on it ('data', 'd/d-…', 'int-…'),
# then loops over `self.links` to update button face colors.
# NOTE(review): this line has lost text and does not parse as written: the target of `= label`, the first
# argument of `self._set_values(`, the operands after `'d/d-' +` / `'int-' +`, the bodies of both 'int-'
# branches, the arguments of `zip(`, `c.set_facecolor(` and `view.fig.canvas.blit(` are all missing.
# Restore them from the upstream file before editing this method.
def _radio_select(self, label): """ Radio button effect. """ = label if label == 'data': self._set_values(, **self.imag_options) elif label == 'd/d-' +[0]: self.der(0) elif label == 'd/d-' +[1]: self.der(1) elif label == 'int-' +[0]: elif label == 'int-' +[1]: for view in self.links: # have to manually change the fill for l, c in zip(, if l.get_text() == label: c.set_facecolor( else: c.set_facecolor( view.fig.canvas.blit(
@_available_to_user_plot
class View3d(View2d):
    """
    View 3D data by scrolling through 3rd dimension in a View2d plot.
    """

    def __init__(
        self,
        data,
        coords=None,
        dims=None,
        name=None,
        axes=None,
        use_uband=False,
        quiet=False,
        contour_levels=0,
        imag_options=None,
        plot_options=None,
        **indexers,
    ):
        r"""
        :param data: DataArray or array-like
            3D data values to be viewed.

        :param coords: dict-like
            Dictionary of Coordinate objects that label values along each dimension.

        :param dims: tuple
            Dimension names associated with this array.

        :param name: string
            Label used in legend. Empty or begining with '_' produces no legend label.

        :param axes: Axes instance
            The axes plotting is done in.

        :param quiet: bool
            Suppress printed messages.

        :param use_uband: bool
            Use uband for 1D slice plots instead of uerrorbar.

        :param contour_levels: int or np.ndarray
            Number of or specific levels used to draw black contour lines over the 2D image.

        :param imag_options: dict
            Key word arguments passed to the DataArray plot method (modified pcolormesh).
            Defaults to an empty dict.

        :param plot_options: dict
            Key word arguments passed to plot or uerrorbar/uband. Color will be determined by cmap variable.
            Defaults to ``{'marker': '', 'ls': '-'}``.

        :param \**indexers: dict
            Dictionary with keys given by dimension names and values given by
            arrays of coordinate index values.
        """
        # build the default dicts per call so that instances never share (and mutate) one default object
        if imag_options is None:
            imag_options = {}
        if plot_options is None:
            plot_options = {'marker': '', 'ls': '-'}

        # check args and make DataArray if needed
        # NOTE(review): the `data.name` accesses below were missing from the listing this
        # was restored from -- confirm against upstream.
        if isinstance(data, DataArray):
            if coords or dims:
                raise ValueError("View cannot re-assign coords/dims to data that is already a DataArray")
            if name:
                data.name = name
            elif not data.name:  # require name
                name = 'viewdata'
                data.name = name
            else:
                name = data.name
        elif isinstance(data, Dataset):
            if coords or dims:
                raise ValueError("View cannot re-assign coords/dims to data that is already a Dataset")
            # the viewer slices `data[:, :, i]`, which needs a DataArray: use the first data variable
            data = data[list(data.data_vars.keys())[0]]
        else:
            if name is None:
                name = 'viewdata'
            # must be regular gridded data
            data = DataArray(data, coords=coords, dims=dims, name=name)

        # pass on key words
        self._use_uband = use_uband
        self.plot_options = plot_options
        self.plot_cmap = plot_options.get('cmap', None)
        self.imag_options = imag_options

        # store basic inputs
        self.data3d = data

        # scroll in steps of about 1% of the 3rd dimension, starting from the middle slice
        self.step = max(1, int(len(data[data.dims[2]]) // 100))
        self._index = int(self.data3d.shape[-1] // 2)

        View2d.__init__(
            self,
            data[:, :, self._index],
            axes=axes,
            use_uband=use_uband,
            contour_levels=contour_levels,
            imag_options=imag_options,
            plot_options=plot_options,
            **indexers,
        )
        self.axm.set_title('')  # remove xarray label of 3rd dimension's slice (we have our own dynamically updated one)
        self._set_3d()

        self.alarm = None
        self.widget = AxesWidget(self.axm)
        self.widget.connect_event('scroll_event', self._scrollx)
        # NOTE(review): assignment target restored as the widget's `active` flag -- confirm against upstream
        self.widget.active = True
        if not quiet:
            printi("Scroll over 2D plot to change " + data.dims[2])
        self.fig.canvas.draw()

    def _scrollx(self, event):
        """
        Update viewer2d axes to show new 3d slice.

        The redraw is deferred by 10 ms through the GUI `after` mechanism and any
        pending redraw is cancelled first, so only the latest scroll is drawn.

        :param event: matplotlib scroll event (`button` is 'up' or 'down')
        """
        key = self.data3d.dims[2]
        last = self.data3d[key].shape[0] - 1

        # change index
        if event.button == 'up':
            step = self.step
        elif event.button == 'down':
            step = -self.step
        else:
            return
        new_index = min(max(self._index + step, 0), last)
        # ignore if already at limit
        if new_index == self._index:
            return
        self._index = new_index

        if self.alarm:
            OMFITaux['rootGUI'].after_cancel(self.alarm)
            self.alarm = None
        self.alarm = OMFITaux['rootGUI'].after(10, self._set_3d)

    def _set_3d(self):
        """
        Show the slice at the current index and propagate it to linked views.
        """
        self.set_3d(self._index, index=True)
        for v in self.links:
            # Best effort: a linked view may not be 3D, or may have a shorter 3rd dimension.
            # (This used to call `v._scrollx(event)` with an undefined `event`, which always
            # failed silently, so linked views were never updated.)
            try:
                v.set_3d(self._index, index=True)
            except (AttributeError, IndexError):
                pass

    def set_3d(self, x2, index=False, draw=True):
        """
        Set third dimension of view to value nearest to slice.

        :param x2: Slice in third dimension. Type depends on value of `index`.
        :type x2: float or int

        :param index: Set True if x2 is the integer index
        :type index: bool

        :param draw: If True, redraw the canvas
        :type draw: bool

        :return: None
        """
        key = self.data3d.dims[2]
        values = self.data3d[key].values
        if index:
            indx = int(x2)
        else:
            # nearest index
            indx = int(np.abs(values - x2).argmin())
        if indx < 0:
            indx += len(values)
        # remember the displayed slice so that scrolling continues from it
        self._index = indx

        # update the central 2D plot
        self.set_data(self.data3d[:, :, indx], draw=False)
        ax = self.axm
        if compare_version(matplotlib.__version__, '3.3.4') <= 0:
            ax.texts = []
        else:
            # iterate over a copy: removing an artist changes ax.texts
            for aText in list(ax.texts):
                aText.remove()
        ax.text(0.02, 0.95, '{:}={:6.2g}'.format(key, float(values[indx])), transform=ax.transAxes)
        if draw:
            self.fig.canvas.draw()
# DragPoints: control points on a matplotlib axes that can be dragged with the mouse and edited with the keyboard.
# The *Array arguments are strings; they are read with eval() and written back with exec() by the methods of this class.
# NOTE(review): this listing has lost text in the constructor and does not parse as written: the statement after
# `if ax is not None:`, the assignment targets in `= pyplot.gca() = self`, and the calls that should precede
# `, y_, animated=True, **ckw)` and the two `, y, animated=True, ...)` argument lists are missing.
# `copy` and `interp1d` are used but are not among the imports visible at the top of the file -- presumably they
# come from the star import; verify. Restore the missing text from the upstream file before editing.
[docs]@_available_to_user_plot class DragPoints(object): r""" This class is used to define matplotlib draggable arrays :param yArray: location in the OMFIT tree for the y array :param xArray: location in the OMFIT tree for the x array :param eyArray: location in the OMFIT tree for the y error array :param exArray: location in the OMFIT tree for the x error array :param editY: allow dragging of points along Y axis :param editX: allow dragging of points along X axis :param exDelta: increments of exArray :param eyDelta: increments of eyArray :param editX: allow dragging of points along X axis :param sorted: keep points sorted in x :param continuous: draw continuously even while user is dragging (may be good to disable if takes long time) :param showOriginal: show original points :param func: a function with signature like `func(x,y,motion,fargs=[])`. where: * x : x coordinate of control points * y : y coordinate of control points * motion : True if the user has dragged a point This function must return x\_, y\_, x, y where: * x\_ : interpolating points between x coordinate of control points * y\_ : interpolating points between y coordinate of control points * x : x coordinate of control points * y : y coordinate of control points :param fargs: arguments to the function :param cyArray: location in the OMFIT tree for the y interpolation array :param cxArray: location in the OMFIT tree for the x interpolation array :param onMove: function handle to call whenever control point is moved :param resize: boolean to whether update axes size when drag point gets to an edge for the figure :param show_hints: bool. Show a cornernote with tips for interaction. :param ax: Axes. Axes in which to plot. All other key word arguments passed to the matplotlib plot function. 
:return: """ epsilon = 5 def __init__( self, yArray, xArray=None, eyArray=None, exArray=None, editY=True, editX=True, exDelta=1, eyDelta=1, sorted=False, continuous=True, showOriginal=False, func=None, fargs=[], cyArray=None, cxArray=None, onMove=None, resize=False, show_hints=True, ax=None, **kwargs, ): self.xArray = xArray self.yArray = yArray self.exArray = exArray self.eyArray = eyArray self.cxArray = cxArray self.cyArray = cyArray if ax is not None: # enable any key words for lines kw = copy.deepcopy(kwargs) for k in ['animated', 'label']: # these are set by this class kw.pop(k, None) ckw = copy.deepcopy(kw) kw['linestyle'] = '' # the interpolation line gets the linestyle kw.setdefault('marker', kw.pop('m', 'o')) if func == 'nearest': def func(x, y, motion, *fargs, **fkw): if self.cxArray is not None: x_ = eval(self.cxArray) else: x_ = np.linspace(min(x), max(x), 1001) try: return x_, interp1d(x, y, kind='nearest')(x_), x, y except ValueError as _excp: printe(repr(_excp)) elif func == 'linear': def func(x, y, motion, *fargs, **fkw): if self.cxArray is not None: x_ = eval(self.cxArray) else: x_ = np.linspace(min(x), max(x), 1001) try: return x_, interpolate.InterpolatedUnivariateSpline(x, y, k=1, *fargs, **fkw)(x_), x, y except ValueError as _excp: printe(repr(_excp)) elif func == 'spline': def func(x, y, motion, *fargs, **fkw): if self.cxArray is not None: x_ = eval(self.cxArray) else: x_ = np.linspace(min(x), max(x), 1001) try: return x_, interpolate.InterpolatedUnivariateSpline(x, y, *fargs, **fkw)(x_), x, y except ValueError as _excp: printe(repr(_excp)) elif func == 'pchip': def func(x, y, motion, *fargs, **fkw): if self.cxArray is not None: x_ = eval(self.cxArray) else: x_ = np.linspace(min(x), max(x), 1001) try: return x_, interpolate.PchipInterpolator(x, y, *fargs, **fkw)(x_), x, y except ValueError as _excp: printe(repr(_excp)) elif func == 'circular': sorted = False def func(x, y, motion, *fargs, **fkw): x[-1] = x[0] y[-1] = y[0] x0 = np.mean(x) y0 = 
np.mean(y) xm = max(x) - min(x) ym = max(y) - min(y) t = np.unwrap(np.arctan2((y - y0) / ym, (x - x0) / xm)) if t[0] > t[1]: t = -t if self.cxArray is not None: x_ = eval(self.cxArray) else: x_ = np.linspace(min(x), max(x), 1001) if self.cyArray is not None: y_ = eval(self.cyArray) else: y_ = np.linspace(min(y), max(y), 1001) t_ = np.linspace(min(t), max(t), len(y_)) tckp = interpolate.splrep(t, x, k=3, per=True) xi = interpolate.splev(t_, tckp, ext=0) tckp = interpolate.splrep(t, y, k=3, per=True) yi = interpolate.splev(t_, tckp, ext=0) return xi, yi, x, y self.func = func self.onMove = onMove self.fargs = fargs self.sorted = sorted self.continuous = continuous self.motion = False self.resize = resize self.editX = editX self.editY = editY self.exDelta = exDelta self.eyDelta = eyDelta = pyplot.gca() = self self._ind = None self.canvas = pyplot.gcf().canvas y = eval(self.yArray) if self.xArray is None: x = np.array(list(range(y.size))) self.editX = False else: x = eval(self.xArray) if self.exArray is None: ex = y * 0 else: ex = eval(self.exArray) if self.eyArray is None: ey = y * 0 else: ey = eval(self.eyArray) if self.cxArray is None: cx = y * 0 else: cx = eval(self.cxArray) if self.cyArray is None: cy = y * 0 else: cy = eval(self.cyArray) if self.func is not None: argspec = inspect.getfullargspec(self.func) if len(argspec.args) > 2 and argspec.args[2] == 'motion': x_, y_ = self.func(x, y, False, *self.fargs)[:2] else: # backward compatibility x_, y_ = self.func(x, y, *self.fargs)[:2] self.cline, *_ =, y_, animated=True, **ckw) if self.cxArray is not None: self.cline.set_marker('.') self.cline.set_markerfacecolor('k') if showOriginal: okw = copy.deepcopy(kw) okw.pop('color', None) okw.pop('c', None) self.line_orig, *_ =, y, animated=True, label='Original', color='grey', **okw) self.line, *_ =, y, animated=True, label='Control point', **kw) if self.onMove is not None: self.onMove(self.cline._xy[:, 0], self.cline._xy[:, 1], self.line._xy[:, 0], self.line._xy[:, 1]) 
self.orig = {} self.orig['xArray'] = copy.deepcopy(x) self.orig['yArray'] = copy.deepcopy(y) self.orig['cxArray'] = copy.deepcopy(cx) self.orig['cyArray'] = copy.deepcopy(cy) self.orig['exArray'] = copy.deepcopy(ex) self.orig['eyArray'] = copy.deepcopy(ey) self.canvas.mpl_connect('draw_event', self.draw_callback) self.canvas.mpl_connect('button_press_event', self.button_press_callback) self.canvas.mpl_connect('key_press_event', self.key_press_callback) self.canvas.mpl_connect('button_release_event', self.button_release_callback) self.canvas.mpl_connect('motion_notify_event', self.motion_notify_callback) txt = '<u>: undo\n<i>: insert node\n<d>: delete node\n' if self.eyArray is not None: txt += '<=>: increment ey\n' txt += '<->: decrement ey\n' if self.exArray is not None: txt += '<+>: increment ex\n' txt += '<_>: decrement ex\n' if show_hints: cornernote('', txt.strip()) self.inhibitFunc = True
# Draw-event handler: caches the canvas background, optionally sorts the control points in x, recomputes the
# interpolating curve through `self.func` (or re-reads it from cxArray/cyArray while `inhibitFunc` is set or on an
# undo key event), writes the curve back to cxArray/cyArray via exec(), and finally calls `onMove` if defined.
# If `self.func` raises, the handler returns early without updating anything.
# NOTE(review): this line has lost text and does not parse as written: the arguments of
# `self.canvas.copy_from_bbox(` and `self.canvas.blit(` and the body of `if hasattr(self, 'line_orig'):` are
# missing, and other statements may have been dropped as well. Restore from the upstream file before editing.
[docs] def draw_callback(self, event=None): self.background = self.canvas.copy_from_bbox( if hasattr(self, 'line_orig'): if self.sorted: index = np.argsort(self.line._xy[:, 0]) self.line._xy[:, 0] = self.line._xy[index, 0] self.line._xy[:, 1] = self.line._xy[index, 1] if self.func is not None: if not self.inhibitFunc or (event is not None and hasattr(event, 'key') and event.key == 'u' and self.cyArray is not None): x_ = self.cline._xy[:, 0] if self.cxArray is not None: x_ = np.squeeze(eval(self.cxArray)) y_ = self.cline._xy[:, 1] if self.cyArray is not None: y_ = np.squeeze(eval(self.cyArray)) x = self.line._xy[:, 0] y = self.line._xy[:, 1] else: argspec = inspect.getfullargspec(self.func) if len(argspec.args) > 2 and argspec.args[2] == 'motion': try: x_, y_, x, y = self.func(self.line._xy[:, 0], self.line._xy[:, 1], self.motion, *self.fargs) except Exception as _excp: return else: # backward compatibility try: x_, y_, x, y = self.func(self.line._xy[:, 0], self.line._xy[:, 1], *self.fargs) except Exception as _excp: return self.motion = False self.cline.set_data(x_, y_) if self.cxArray is not None: exec('%s=np.reshape(x_,%s.shape)' % (self.cxArray, self.cxArray)) if self.cyArray is not None: exec('%s=np.reshape(y_,%s.shape)' % (self.cyArray, self.cyArray)) self.line._xy[:, 0] = np.squeeze(x) self.line._xy[:, 1] = np.squeeze(y) self.canvas.blit( if self.onMove is not None: self.onMove(self.cline._xy[:, 0], self.cline._xy[:, 1], self.line._xy[:, 0], self.line._xy[:, 1])
def get_ind_under_point(self, event):
    """
    Return the index of the control point under the mouse pointer.

    Distances are measured in display (pixel) coordinates between the event
    position and the transformed control points.

    :param event: matplotlib event providing pixel coordinates `x` and `y`

    :return: index of the nearest control point if it is closer than
        `self.epsilon` pixels, otherwise None (also when there are no points)
    """
    xyt = self.line.get_transform().transform(self.line._xy)
    if len(xyt) == 0:
        # nothing to pick (e.g. all points were deleted)
        return None
    d = np.hypot(xyt[:, 0] - event.x, xyt[:, 1] - event.y)
    ind = int(np.argmin(d))
    # written as `not <` so that a NaN distance is never treated as a hit
    if not d[ind] < self.epsilon:
        return None
    return ind
def button_press_callback(self, event):
    """
    Start dragging: remember which control point (if any) is under the pointer.

    The interpolation function is inhibited on every press; only a left-button
    press inside an axes selects a point.

    :param event: matplotlib mouse press event
    """
    self.inhibitFunc = True
    pressed_in_axes = event.inaxes is not None
    if not pressed_in_axes or event.button != 1:
        return
    self._ind = self.get_ind_under_point(event)
def key_press_callback(self, event=None):
    """
    Keyboard editing of the control points.

    * ``u`` (or `event=None`): restore the original arrays
    * ``d``: delete the control point under the pointer
    * ``i``: insert a control point at the pointer, if it is within `epsilon` pixels of the curve
    * ``=`` / ``-``: increment / decrement the y error of the point under the pointer by `eyDelta`
    * ``+`` / ``_``: increment / decrement the x error of the point under the pointer by `exDelta`
    * ``x`` / ``y``: toggle dragging along X / Y

    :param event: matplotlib key press event, or None to undo
    """
    if event is not None and not event.inaxes:
        return

    # NOTE: xArray, yArray, ... are strings that are read with eval() and assigned with exec().
    # The exec() strings refer to the local names `tmp`, `ind`, `ex`, `ey` and `self`, so those
    # locals must keep their names. Never pass untrusted strings as array locations.
    if event is None or event.key == 'u':
        self.line.set_data(copy.deepcopy(self.orig['xArray']), copy.deepcopy(self.orig['yArray']))
        for name in ('xArray', 'yArray', 'cxArray', 'cyArray', 'exArray', 'eyArray'):
            target = getattr(self, name)
            if target is not None:
                exec(target + "=copy.deepcopy(self.orig['" + name + "'])")

    elif event.key == 'd':
        ind = self.get_ind_under_point(event)
        if ind is not None:
            printd('Deleted point #' + str(ind), topic='figure')
            tmp = np.delete(self.line._xy, ind, axis=0)
            if not self.editX:
                # x is just the point index: renumber after the deletion
                tmp.T[0] = np.array(list(range(tmp.T[0].size)))
            self.line.set_data(tmp.T[0], tmp.T[1])
            if self.xArray is not None:
                exec(self.xArray + '=tmp.T[0]')
            if self.yArray is not None:
                exec(self.yArray + '=tmp.T[1]')
            if self.exArray is not None:
                tmp = np.delete(np.squeeze(eval(self.exArray)), ind)
                exec(self.exArray + '=tmp')
            if self.eyArray is not None:
                tmp = np.delete(np.squeeze(eval(self.eyArray)), ind)
                exec(self.eyArray + '=tmp')

    elif event.key == 'i':
        xysl = self.line.get_transform().transform(self.line._xy)
        try:
            xys = self.cline.get_transform().transform(self.cline._xy)
        except Exception:
            # no interpolating curve: search along the control points themselves
            xys = xysl
        p = event.x, event.y
        new_point = (event.xdata, event.ydata)
        for i in range(len(xys) - 1):
            s0 = xys[i]
            s1 = xys[i + 1]
            d = point_to_line(p[0], p[1], s0[0], s0[1], s1[0], s1[1])
            if d <= self.epsilon:
                tmp = np.array(list(self.line._xy[: i + 1]) + [new_point] + list(self.line._xy[i + 1 :]))
                if not self.editX:
                    tmp.T[0] = np.array(list(range(tmp.T[0].size)))
                self.line.set_data(tmp.T[0], tmp.T[1])
                if self.xArray is not None:
                    exec(self.xArray + '=tmp.T[0]')
                if self.yArray is not None:
                    exec(self.yArray + '=tmp.T[1]')
                # the new point gets the mean of its neighbours' errors
                if self.exArray is not None:
                    ex = np.squeeze(eval(self.exArray))
                    ex = np.hstack((ex[: i + 1], (ex[i] + ex[i + 1]) * 0.5, ex[i + 1 :]))
                    exec(self.exArray + '=ex')
                if self.eyArray is not None:
                    ey = np.squeeze(eval(self.eyArray))
                    ey = np.hstack((ey[: i + 1], (ey[i] + ey[i + 1]) * 0.5, ey[i + 1 :]))
                    exec(self.eyArray + '=ey')
                break

    elif event.key == '=' and self.eyArray is not None:
        ind = self.get_ind_under_point(event)
        if ind is not None:
            exec(self.eyArray + '[ind]+=self.eyDelta')
            print('figure: Increased ey point #%d to %3.3f' % (ind, np.squeeze(eval(self.eyArray))[ind]))

    elif event.key == '-' and self.eyArray is not None:
        ind = self.get_ind_under_point(event)
        if ind is not None:
            exec(self.eyArray + '[ind]-=self.eyDelta')
            print('figure: Decreased ey point #%d to %3.3f' % (ind, np.squeeze(eval(self.eyArray))[ind]))

    elif event.key == '+' and self.exArray is not None:
        ind = self.get_ind_under_point(event)
        if ind is not None:
            exec(self.exArray + '[ind]+=self.exDelta')
            # report the x error that was just changed (this used to read eyArray, which
            # printed the wrong value and failed when only exArray was given)
            print('figure: Increased ex point #%d to %3.3f' % (ind, np.squeeze(eval(self.exArray))[ind]))

    elif event.key == '_' and self.exArray is not None:
        ind = self.get_ind_under_point(event)
        if ind is not None:
            exec(self.exArray + '[ind]-=self.exDelta')
            print('figure: Decreased ex point #%d to %3.3f' % (ind, np.squeeze(eval(self.exArray))[ind]))

    elif event.key == 'x':
        self.editX = not self.editX
        printi('drag X:' + str(self.editX))

    elif event.key == 'y':
        self.editY = not self.editY
        printi('drag Y:' + str(self.editY))

    OMFITaux['rootGUI'].event_generate("<<update_treeGUI>>")

    if self.onMove is not None:
        self.onMove(self.cline._xy[:, 0], self.cline._xy[:, 1], self.line._xy[:, 0], self.line._xy[:, 1])

    self.canvas.draw_idle()
    self.canvas.restore_region(self.background)
    self.draw_callback(event)
def button_release_callback(self, event):
    """
    Finish a drag when the left mouse button is released.

    Clears the selected point, re-enables the interpolation function,
    asks the GUI tree to refresh and redraws the figure.

    :param event: matplotlib mouse release event
    """
    if event.button == 1:
        self._ind = None
        self.inhibitFunc = False
        OMFITaux['rootGUI'].event_generate("<<update_treeGUI>>")
        pyplot.draw()
# Mouse-motion handler used while a control point is being dragged (left button, inside an axes, `_ind` set):
# moves the selected point along the enabled directions (editX/editY), optionally grows the axis limits by 1/20 of
# their span when `resize` is set and the pointer gets within that margin of an edge, writes the new coordinates
# back to xArray/yArray via exec(), then restores the cached background and redraws.
# NOTE(review): the final `self.canvas.blit(` call has lost its argument in this listing (and a statement may be
# missing before it); restore from the upstream file before editing.
[docs] def motion_notify_callback(self, event): if self._ind is None: return if event.inaxes is None: return if event.button != 1: return x, y = event.xdata, event.ydata tmp = self.line._xy if self.editX: tmp[self._ind][0] = x if self.editY: tmp[self._ind][1] = y self.line.set_data(tmp.T[0], tmp.T[1]) if self.editY: dy = np.diff(pyplot.ylim())[0] / 20.0 if self.resize and y > pyplot.ylim()[1] - dy: pyplot.ylim(np.array(pyplot.ylim()) + np.array([0, dy])) if self.resize and y < pyplot.ylim()[0] + dy: pyplot.ylim(np.array(pyplot.ylim()) + np.array([-dy, 0])) if self.editX: dx = np.diff(pyplot.xlim())[0] / 20.0 if self.resize and x > pyplot.xlim()[1] - dx: pyplot.xlim(np.array(pyplot.xlim()) + np.array([0, dx])) if self.resize and x < pyplot.xlim()[0] + dx: pyplot.xlim(np.array(pyplot.xlim()) + np.array([-dx, 0])) if self.xArray is not None: exec(self.xArray + '=tmp.T[0]') if self.yArray is not None: exec(self.yArray + '=tmp.T[1]') self.motion = True self.canvas.restore_region(self.background) if self.continuous: self.draw_callback() else: self.canvas.blit(
@_available_to_user_plot
def editProfile(yi, xi=None, n=None, showOriginal=True, func='spline', onMove=None):
    """
    This function opens an interactive figure for convenient editing of profiles via spline

    :param yi: string whose `eval` yields the y data

    :param xi: string whose `eval` yields the x data (if None, the point index is used and stored in `OMFIT['scratch']`)

    :param n: number of control points (by default 1/5 of the number of data points, at least 4 and at most half of them)

    :param showOriginal: plot original data

    :param func: interpolation function used to interpolate between control points
                 'linear', 'spline', 'pchip', 'nearest', 'circular'

    :param onMove: function to call when moving of control points occurs

    :return: DragPoints object

    :raise ValueError: if `func` is not one of the supported interpolation names
    """
    from omfit_classes.utils_fit import autoknot

    # NOTE: `yi` and `xi` are evaluated with eval(); never pass untrusted strings.
    ydata = np.squeeze(eval(yi))
    npts = len(ydata)

    if n is None:
        n = int(npts // 5)
        n = max([4, n])
        n = min([n, int(npts // 2)])

    if xi is None:
        OMFIT['scratch']['x_' + str(id(yi))] = np.linspace(0, npts - 1, npts)
        xi = "OMFIT['scratch']['x_" + str(id(yi)) + "']"
    xdata = np.squeeze(eval(xi))

    def _frozen(interpolant):
        """Wrap `interpolant` in a class constructed as `cls(x, y)` and called as `obj(x0)`, ignoring x and y."""

        class _Frozen(object):
            def __init__(self, x, y):
                pass

            def __call__(self, x0):
                return interpolant(x0)

        return _Frozen

    if func == 'linear':
        x0, y0 = autoknot(xdata, ydata, n, s=1, allKnots=True)

    elif func == 'pchip':
        tmp = interpolate.PchipInterpolator(xdata, ydata)
        x0, y0 = autoknot(xdata, ydata, n, mindist=1e-6, allKnots=True, userFunc=_frozen(tmp))

    elif func == 'nearest':
        tmp = interpolate.interp1d(xdata, ydata, kind='nearest', bounds_error=False)
        x0, y0 = autoknot(xdata, ydata, n, mindist=1e-6, allKnots=True, userFunc=_frozen(tmp))

    elif func == 'spline':
        x0, y0 = autoknot(xdata, ydata, n, s=3, allKnots=True)

    elif func == 'circular':
        # parametrize the closed curve by its arc length, then pick knots on the normalized curve
        t = np.cumsum(np.sqrt(np.gradient(xdata) ** 2 + np.gradient(ydata) ** 2))
        te = np.linspace(min(t), max(t), 1001)
        xe = interp1e(t, xdata, kind=3)(te)
        ye = interp1e(t, ydata, kind=3)(te)
        xe0 = np.mean(xe)
        ye0 = np.mean(ye)
        xem = max(xe) - min(xe)
        yem = max(ye) - min(ye)
        tec = np.cumsum(np.sqrt(np.gradient((xe - xe0) / xem) ** 2 + np.gradient((ye - ye0) / yem) ** 2))
        t0, x0 = autoknot(tec, (xe - xe0) / xem, n + 1, mindist=1e-6, s=5, allKnots=True)
        t1, y0 = autoknot(tec, (ye - ye0) / yem, n + 1, mindist=1e-6, s=5, allKnots=True)
        t_ = (t0 + t1) / 2.0
        x0 = interp1e(tec, xe)(t_)
        y0 = interp1e(tec, ye)(t_)

    else:
        raise ValueError("func must be one of 'linear', 'spline', 'pchip', 'nearest', 'circular', not %r" % (func,))

    OMFIT['scratch']['x0_' + str(id(x0))] = x0
    OMFIT['scratch']['y0_' + str(id(y0))] = y0

    if showOriginal:
        pyplot.plot(xdata, ydata)

    tmp = DragPoints(
        "OMFIT['scratch']['y0_" + str(id(y0)) + "']",
        "OMFIT['scratch']['x0_" + str(id(x0)) + "']",
        func=func,
        sorted=True,
        cxArray=xi,
        cyArray=yi,
        onMove=onMove,
        resize=False,
    )

    if func == 'circular':
        # NOTE(review): the call in front of `'equal')` was missing from the listing this was restored
        # from; setting an equal aspect ratio on the current axes is assumed -- confirm against upstream.
        pyplot.gca().set_aspect('equal')

    return tmp
@_available_to_user_plot
def cornernote(
    text='', root=None, device=None, shot=None, time=None, ax=None, fontsize='small', clean=True, remove=False, remove_specific=False
):
    """
    Write text at the bottom right corner of a figure

    :param text: text to appear in the bottom right corner

    :param root:
        * if '' append nothing
        * if None append shot/time as from `OMFIT['MainSettings']['EXPERIMENT']`
        * if OMFITmodule append shot/time as from `root['SETTINGS']['EXPERIMENT']`

    :param device: override device string (does not print device at all if empty string)

    :param shot: override shot string (does not print shot at all if empty string)

    :param time: override time string (does not print time at all if empty string)

    :param ax: axis to plot on (defaults to the current axes)

    :param fontsize: str or float. Sets font size of the Axes annotate method.

    :param clean: delete existing cornernote(s) from current axes before drawing a new cornernote

    :param remove: delete existing cornernote(s) and return before drawing any new ones

    :param remove_specific: delete existing cornernote(s) from current axes only if text matches the text that would be
        printed by the current call to cornernote() (such as identical shot, time, etc.)

    :return: Matplotlib annotate object (None if `remove` or `remove_specific` is set)
    """
    if ax is None:
        ax = pyplot.gca()

    if root is None:
        root = eval('OMFIT')

    if root and not isinstance(root, str):
        # fill in whatever was not given explicitly, looking first in MainSettings and then in SETTINGS
        experiment = {'device': device, 'shot': shot, 'time': time}
        for location in ['MainSettings', 'SETTINGS']:
            for key in experiment:
                if experiment[key] is None:
                    try:
                        experiment[key] = root[location]['EXPERIMENT'][key]
                    except Exception:
                        # best effort: the location or the entry may not exist in this tree
                        pass
        device, shot, time = experiment['device'], experiment['shot'], experiment['time']

        text_ = []
        if device is not None and len(str(device)):
            text_.append(str(device))
        if shot is not None and len(str(shot)):
            text_.append('#' + str(shot))
        if time is not None and len(str(time)):
            text_.append(str(time) + ' ms')
        if len(text):
            text = ' '.join(text_) + ' : ' + text
        else:
            text = ' '.join(text_)

    elif isinstance(root, str) and not len(text):
        text = root

    # cornernote deletion options
    # (loops run over a copy of ax.texts: removing an artist while iterating the live list can skip notes)
    if remove_specific:
        # delete existing cornernote only if the text that would be written by this call matches exactly
        # NOTE(review): the position test compares display coordinates with the figure-fraction
        # position (0.98, 0.02) -- confirm that this matches notes that have already been drawn.
        for textObj in list(ax.texts):
            if textObj.get_text() == text and textObj.get_label() == 'cornernote' and np.all(textObj._get_xy_display() == [0.98, 0.02]):
                textObj.remove()
        # for the specific version only, return without writing a new cornernote
        # (assume you want to delete just one specific note and not recreate it exactly)
        return

    # if not trying to remove a specific note, then scrub all notes but don't return
    if clean or remove:
        for textObj in list(ax.texts):
            if textObj.get_label() == 'cornernote':
                textObj.remove()
    if remove:
        # the remove keyword should leave the corner blank (no cornernote visible)
        return

    # no return after clean because the default behavior is: clear all notes and then write a fresh one
    return ax.annotate(text, xy=(0.98, 0.02), xycoords='figure fraction', fontsize=fontsize, ha='right', va='bottom', label='cornernote')
# Diagonal line to go with axhline and axvline
class axdline(matplotlib.pyplot.Line2D):
    """
    Diagonal counterpart of axhline/axvline: a line defined by slope and y-intercept
    that always spans the current x range of its axes.

    Additional arguments are passed to the <matplotlib.lines.Line2D> constructor.
    The axes can be given as `axes=` or `ax=`; the current axes are used otherwise.

    From a stackoverflow answer by ali_m (originally named ABLine2D).
    """

    def __init__(self, slope=1, intercept=0, *args, **kwargs):
        # resolve the target axes: explicit `axes`, else `ax` (translated to `axes`), else the current axes
        if 'axes' not in kwargs:
            kwargs['axes'] = kwargs.pop('ax') if 'ax' in kwargs else pyplot.gca()
        target = kwargs['axes']

        # start with an empty line; the data are filled in by _update_lim
        super().__init__([], [], *args, **kwargs)
        self._slope = slope
        self._intercept = intercept
        target.add_line(self)

        # draw once so that a renderer is cached, then compute the first set of end points
        target.figure.canvas.draw()
        self._update_lim(None)

        # follow every change of the axis limits
        for signal in ('xlim_changed', 'ylim_changed'):
            self.axes.callbacks.connect(signal, self._update_lim)

    def _update_lim(self, event):
        """Recompute the end points of the line from the current x bounds of the axes."""
        xends = np.array(self.axes.get_xbound())
        self.set_data(xends, (self._slope * xends) + self._intercept)
        try:
            self.axes.draw_artist(self)
        except AttributeError:
            printw('WARNING: could not update axdline artist! Diagonal lines might not display correctly!')
def square_subplots(nplots, ncol_max=np.inf, flip=False, sparse_column=True, just_numbers=False, identify=False, fig=None, **kw):
    r"""
    Creates a set of subplots in an approximate square, with a few empty subplots if needed

    :param nplots: int
        Number of subplots desired

    :param ncol_max: int
        Maximum number of columns to allow (at least one column is always used)

    :param flip: bool
        True: Puts row 0 at the bottom, so every plot on the bottom row can accept an X axis label
        False: Normal plot numbering with row 0 at the top. The bottom row may be sparsely populated.

    :param sparse_column: bool
        Controls the arrangement of empty subplots.
        True: the last column is sparse. That is, all the empty plots will be in the last column. There will be at most
            one plot missing from the last row, and potentially several from the last column. The advantage is this
            provides plenty of X axes on the bottom row to accept labels. To get natural numbering of flattened
            subplots, transpose before flattening: axs.T.flatten(), or just use the 1D axsf array that's returned.
        False: the last row is sparse. All the empty plots will be in the last row. The last column will be missing at
            most one plot, but the last row may be missing several. This arrangement goes more smoothly with the
            numbering of axes after flattening.

    :param just_numbers: bool
        Don't create any axes, but instead just return the number of rows, columns, and empty subplots in the array.

    :param identify: bool
        For debugging: write the number (as flattened) and [row, col] coordinates of each subplot on the plot itself.
        These go in the center, in black. In the top left corner in red is the naive flattened count, which will appear
        on empty plots as well to show how wrong it is. In the bottom right corner in blue is the proper flattened count
        based on axsf.

    :param fig: Figure instance [optional]
        If given, the subplots are added to this figure one by one; otherwise a new figure is created.

    :param \**kw: keywords passed to pyplot.subplots when creating axes (like sharex, etc.)
        When `fig` is given only `sharex` and `sharey` are used.

    :return: (axs, axsf) or, if just_numbers is set, (nr, nc, on, empty)
        axs: 2d array of Axes instances. If flip is set, it is flipped vertically relative to normal axes output by
            pyplot.subplots, so the 0th row is the bottom. This is so the bottom row will be populated and can
            receive x axis labels.
        axsf: 1d array of Axes instances, leaving out the empty ones (they might not be in order nicely)
        nr: int: number of rows
        nc: int: number of columns
        on: 2d bool array: flags indicating which axes should be on (True) and which should be hidden/off (False)
        empty: int: number of empty cells in axs.

    :raise ValueError: if nplots is negative
    """
    if nplots < 0:
        raise ValueError('nplots must not be negative, got %r' % (nplots,))

    kw.setdefault('sharex', True)
    kw['squeeze'] = False

    # Find number of columns, rows, and empty plots
    nc = int(np.ceil(min(int(nplots**0.5), ncol_max)))
    nc = max(nc, 1)  # at least one column, so that nplots=0 or ncol_max<1 cannot cause a division by zero
    nr = int(np.ceil(nplots / float(nc)))
    empty = nr * nc - nplots

    # Make a rule for hiding the unused subplots
    on = np.ones((nr, nc), bool)
    for j in range(empty):
        if sparse_column:
            on[-(1 + j), -1] = False
        else:
            on[-1, -(1 + j)] = False

    # Map 1D indices to 2D indices: the j-th plot goes into the j-th active cell in row-major order
    coords = [(int(ir), int(ic)) for ir, ic in np.argwhere(on)]

    if just_numbers:
        return nr, nc, on, empty

    # Make the plot grid
    if fig is None:
        fig, axs = pyplot.subplots(nr, nc, **kw)  # Default to new figure instead of pyplot.gcf()
        if flip:
            axs = axs[::-1, :]
        for axx in axs.flatten()[~on.flatten()]:
            axx.axis('off')
        axsf = axs.flatten()[on.flatten()]
    else:
        pyplot.figure(num=fig.number)
        axs = np.empty((nr, nc), object)
        axsf = np.empty(nplots, object)
        sharex = None
        sharey = None
        for j in range(nplots):
            ir, ic = coords[j]
            if flip:
                slot = (nr - 1 - ir) * nc + ic + 1
            else:
                slot = ir * nc + ic + 1
            axx = pyplot.subplot(nr, nc, slot, sharex=sharex, sharey=sharey)
            axs[ir, ic] = axx
            axsf[j] = axx
            if kw.get('sharex', False):
                if ((ir < (nr - 1)) and (not flip)) or ((ir > 0) and flip):
                    # This is not the last row, so tradition demands we hide the axis labels
                    for tick in axx.xaxis.get_ticklabels():
                        tick.set_visible(False)
                if sharex is None:
                    sharex = axx  # If sharex=True, catch the first axes as the ones to share with
            if kw.get('sharey', False):
                if ic > 0:
                    # This is not the first column, so tradition demands we hide the axis labels
                    for tick in axx.yaxis.get_ticklabels():
                        tick.set_visible(False)
                if sharey is None:
                    sharey = axx

    # Write numbers on the subplots to identify them
    if identify:
        for j in range(nplots):
            ir, ic = coords[j]
            axx = axs[ir, ic]  # Get a reference to the subplot we're working with
            label = '{} [{}, {}]'.format(j, ir, ic)
            axx.text(0.5, 0.5, label, transform=axx.transAxes, ha='center', va='center')
            axsf[j].text(0.95, 0.05, str(j), transform=axsf[j].transAxes, ha='right', va='bottom', color='blue')
            if on[ir, ic]:
                axx.text(0.75, 0.75, 'ON', transform=axx.transAxes, ha='center', va='center', color='blue')
            else:
                if axx is not None:
                    axx.text(0.75, 0.75, 'OFF', transform=axx.transAxes, ha='center', va='center', color='red')

        for j in range(nc * nr):
            ax1 = axs.flatten()[j]
            axt = axs.T.flatten()[j]
            if ax1 is not None:
                ax1.text(0.05, 0.95, str(j), transform=ax1.transAxes, ha='left', va='top', color='red')
            if axt is not None:
                axt.text(0.95, 0.95, 'T' + str(j), transform=axt.transAxes, ha='right', va='top', color='green')

    return axs, axsf
try:
    from matplotlib.widgets import TextBox as _TextBox
except ImportError:

    class _TextBox(AxesWidget):
        """
        A GUI neutral text input box, used when matplotlib does not provide
        ``matplotlib.widgets.TextBox`` itself.

        For the text box to remain responsive you must keep a reference to it.

        The following attributes are accessible:

          *ax*
            The :class:`matplotlib.axes.Axes` the text box renders into.

          *label*
            A :class:`matplotlib.text.Text` instance.

          *color*
            The color of the text box when not hovering.

          *hovercolor*
            The color of the text box when hovering.

        Call :meth:`on_text_change` to be updated whenever the text changes.

        Call :meth:`on_submit` to be updated whenever the user hits enter or
        leaves the text entry field.

        Taken from a pull request to matplotlib; smithsp added the
        :meth:`set_val` method to be merged into that pull request.
        """

        def __init__(self, ax, label, initial='', color='.95', hovercolor='1', label_pad=0.01):
            """
            :param ax: matplotlib.axes.Axes
                The :class:`matplotlib.axes.Axes` instance the text box will be placed into.

            :param label: str
                Label for this text box.

            :param initial: str
                Initial value in the text box

            :param color: color
                The color of the box

            :param hovercolor: color
                The color of the box when the mouse is over it

            :param label_pad: float
                The distance between the label and the right side of the textbox
            """
            AxesWidget.__init__(self, ax)

            self.DIST_FROM_LEFT = 0.05

            # rcParams entries whose name contains 'keymap'; they are blanked
            # while the user is typing (see begin_typing / stop_typing)
            self.params_to_disable = [key for key in list(matplotlib.rcParams.keys()) if 'keymap' in key]

            self.text = initial
            self.label = ax.text(
                -label_pad, 0.5, label, verticalalignment='center', horizontalalignment='right', transform=ax.transAxes
            )
            self.text_disp = self._make_text_disp(self.text)

            # a single counter hands out connection ids for both observer
            # dictionaries, so an id is unique across the two of them
            self.cnt = 0
            self.change_observers = {}
            self.submit_observers = {}

            # If these lines are removed, the cursor won't appear the first
            # time the box is clicked:
            self.ax.set_xlim(0, 1)
            self.ax.set_ylim(0, 1)

            self.cursor_index = 0

            # Because this is initialized, _rendercursor
            # can assume that cursor exists.
            self.cursor = self.ax.vlines(0, 0, 0)
            self.cursor.set_visible(False)

            self.connect_event('button_press_event', self._click)
            self.connect_event('button_release_event', self._release)
            self.connect_event('motion_notify_event', self._motion)
            self.connect_event('key_press_event', self._keypress)
            self.connect_event('resize_event', self._resize)
            ax.set_navigate(False)
            self._set_background(color)
            ax.set_xticks([])
            ax.set_yticks([])
            self.color = color
            self.hovercolor = hovercolor

            self._lastcolor = color

            self.capturekeystrokes = False

        def _set_background(self, color):
            """
            Set the background color of the axes

            ``set_facecolor`` is preferred; ``set_axis_bgcolor`` is only used
            when the running matplotlib does not offer ``set_facecolor``.

            :param color: color to apply to the axes background
            """
            setter = getattr(self.ax, 'set_facecolor', None)
            if setter is None:
                setter = self.ax.set_axis_bgcolor
            setter(color)

        def _make_text_disp(self, string):
            """
            Create the text artist that shows `string` inside the box

            :param string: text to display

            :return: the text artist returned by ``self.ax.text``
            """
            return self.ax.text(
                self.DIST_FROM_LEFT,
                0.5,
                string,
                verticalalignment='center',
                horizontalalignment='left',
                transform=self.ax.transAxes,
            )

        def _rendercursor(self):
            """
            Redraw the cursor at the position given by ``self.cursor_index``
            """
            # this is a hack to figure out where the cursor should go.
            # we draw the text up to where the cursor should go, measure
            # and save its dimensions, draw the real text, then put the cursor
            # at the saved dimensions
            widthtext = self.text[: self.cursor_index]
            no_text = False
            if widthtext in ("", " ", "  "):
                # empty/blank text has no measurable extent: measure a
                # placeholder character instead
                no_text = widthtext == ""
                widthtext = ","

            wt_disp = self._make_text_disp(widthtext)

            self.ax.figure.canvas.draw()
            bb = wt_disp.get_window_extent()
            inv = self.ax.transData.inverted()
            bb = inv.transform(bb)
            wt_disp.set_visible(False)
            if no_text:
                # nothing before the cursor: collapse the placeholder width
                bb[1, 0] = bb[0, 0]
            # hack done
            self.cursor.set_visible(False)

            self.cursor = self.ax.vlines(bb[1, 0], bb[0, 1], bb[1, 1])
            self.ax.figure.canvas.draw()

        def _notify_submit_observers(self):
            """
            Call every callback registered with :meth:`on_submit` with the current text
            """
            # iterate over a snapshot so that a callback may disconnect itself
            for func in list(self.submit_observers.values()):
                func(self.text)

        def _release(self, event):
            """
            Release the mouse grab when the button is released

            :param event: matplotlib button_release_event
            """
            if self.ignore(event):
                return
            if event.canvas.mouse_grabber != self.ax:
                return
            event.canvas.release_mouse(self.ax)

        def _keypress(self, event):
            """
            Edit the text / move the cursor according to the pressed key

            :param event: matplotlib key_press_event
            """
            if self.ignore(event):
                return
            if not self.capturekeystrokes:
                return
            key = event.key
            if key is None:
                # no key name available for this event: nothing to process
                return

            if len(key) == 1:
                self.text = self.text[: self.cursor_index] + key + self.text[self.cursor_index :]
                self.cursor_index += 1
            elif key == "right":
                if self.cursor_index != len(self.text):
                    self.cursor_index += 1
            elif key == "left":
                if self.cursor_index != 0:
                    self.cursor_index -= 1
            elif key == "home":
                self.cursor_index = 0
            elif key == "end":
                self.cursor_index = len(self.text)
            elif key == "backspace":
                if self.cursor_index != 0:
                    self.text = self.text[: self.cursor_index - 1] + self.text[self.cursor_index :]
                    self.cursor_index -= 1
            elif key == "delete":
                if self.cursor_index != len(self.text):
                    self.text = self.text[: self.cursor_index] + self.text[self.cursor_index + 1 :]

            self.text_disp.set_text(self.text)
            self._rendercursor()
            self._notify_change_observers()
            if key == "enter":
                self._notify_submit_observers()

        def set_val(self, val):
            """
            Set the content of the text box programmatically

            Change and submit observers are notified, unless the text is unchanged.

            :param val: new value; it is converted with ``str``
            """
            newval = str(val)
            if self.text == newval:
                return
            self.text = newval
            self.text_disp.set_text(self.text)
            self._rendercursor()
            self._notify_change_observers()
            self._notify_submit_observers()

        def _notify_change_observers(self):
            """
            Call every callback registered with :meth:`on_text_change` with the current text
            """
            # iterate over a snapshot so that a callback may disconnect itself
            for func in list(self.change_observers.values()):
                func(self.text)

        def begin_typing(self, x):
            """
            Start capturing keystrokes

            :param x: not used by this method
            """
            self.capturekeystrokes = True
            # disable command keys so that the user can type without
            # command keys causing figure to be saved, etc
            self.reset_params = {}
            for key in self.params_to_disable:
                self.reset_params[key] = matplotlib.rcParams[key]
                matplotlib.rcParams[key] = []

        def stop_typing(self):
            """
            Stop capturing keystrokes, restore the command keys and notify the submit observers
            """
            notifysubmit = False
            # because _notify_submit_observers might throw an error in the
            # user's code, we only want to call it once we've already done
            # our cleanup.
            if self.capturekeystrokes:
                # since the user is no longer typing,
                # reactivate the standard command keys
                for key in self.params_to_disable:
                    matplotlib.rcParams[key] = self.reset_params[key]
                notifysubmit = True
            self.capturekeystrokes = False
            self.cursor.set_visible(False)
            if notifysubmit:
                self._notify_submit_observers()

        def position_cursor(self, x):
            """
            Place the cursor at the character closest to horizontal position `x`

            :param x: horizontal position of the click (``event.x``)
            """
            # now, we have to figure out where the cursor goes.
            # approximate it based on assuming all characters the same length
            if len(self.text) == 0:
                self.cursor_index = 0
            else:
                bb = self.text_disp.get_window_extent()

                trans = self.ax.transData
                inv = self.ax.transData.inverted()
                bb = trans.transform(inv.transform(bb))

                text_start = bb[0, 0]
                text_end = bb[1, 0]

                width = text_end - text_start
                if width == 0:
                    # zero-width text extent: avoid dividing by zero
                    ratio = 0
                else:
                    ratio = (x - text_start) / width

                if ratio < 0:
                    ratio = 0
                if ratio > 1:
                    ratio = 1

                self.cursor_index = int(len(self.text) * ratio)

            self._rendercursor()

        def _click(self, event):
            """
            Start typing when the box is clicked, stop typing when the click is elsewhere

            :param event: matplotlib button_press_event
            """
            if self.ignore(event):
                return
            if event.inaxes != self.ax:
                self.stop_typing()
                return
            if not self.eventson:
                return
            if event.canvas.mouse_grabber != self.ax:
                event.canvas.grab_mouse(self.ax)
            if not self.capturekeystrokes:
                self.begin_typing(event.x)
            self.position_cursor(event.x)

        def _resize(self, event):
            """
            Stop typing when the figure is resized

            :param event: matplotlib resize_event
            """
            self.stop_typing()

        def _motion(self, event):
            """
            Switch between `color` and `hovercolor` as the mouse enters/leaves the box

            :param event: matplotlib motion_notify_event
            """
            if self.ignore(event):
                return
            if event.inaxes == self.ax:
                c = self.hovercolor
            else:
                c = self.color
            if c != self._lastcolor:
                self._set_background(c)
                self._lastcolor = c
                if self.drawon:
                    self.ax.figure.canvas.draw()

        def on_text_change(self, func):
            """
            When the text changes, call this *func* with the current text.

            A connection id is returned which can be used to disconnect.
            """
            cid = self.cnt
            self.change_observers[cid] = func
            self.cnt += 1
            return cid

        def on_submit(self, func):
            """
            When the user hits enter or leaves the submission box, call this
            *func* with the current text.

            A connection id is returned which can be used to disconnect.
            """
            cid = self.cnt
            self.submit_observers[cid] = func
            self.cnt += 1
            return cid

        def disconnect(self, cid):
            """remove the observer with connection id *cid*"""
            # connection ids are unique across both dictionaries (shared
            # counter), so at most one of these removes something;
            # unknown ids are silently ignored
            self.change_observers.pop(cid, None)
            self.submit_observers.pop(cid, None)

    # make the fallback available under the standard matplotlib name
    matplotlib.widgets.TextBox = _TextBox

if __name__ == '__main__':
    from pylab import figure, imshow
    import uncertainties

    x = np.linspace(0, 2 * np.pi, 20)
    y = np.linspace(0, 3 * np.pi, 20)
    s = np.sin(x)
    z = []
    for i in range(10):
        z.append(np.sin(x) + i)
    z = np.array(z)
    xx, yy = np.meshgrid(x, y)
    zz = np.sin(xx) * np.cos(yy)
    xerr = abs(0.1 * np.random.random(x.shape) * x)
    serr = abs(0.1 * np.random.random(x.shape) * s)
    X = uncertainties.unumpy.uarray(x, xerr)
    S = uncertainties.unumpy.uarray(s, serr)

    figure('sin')
    pyplot.plot(x, s)

    figure('multi-sin')
    plotc(x, z.T)
    ax = pyplot.gca()
    XKCDify(ax, xaxis_loc=0)
    draw()

    map_HBS_to_RGB(np.array([1]), np.array([2]), np.array([0]))

    paths = contourPaths(x, y, zz, [-0.75, 0, 0.75])

    fig = figure('pcolormesh')
    pyplot.figure(num=fig.number)
    ax = pyplot.subplot(121)
    pcolor2(xx, yy, zz)
    ax2 = pyplot.subplot(122, sharey=ax, sharex=ax)
    im = blur_image(zz, 3)
    ax2.imshow(im, extent=(min(x), max(x), min(y), max(y)), origin='lower')
    set_fontsize(fontsize='+10')
    plot_equality_lims(ax=ax2)
    autofmt_sharey()
    hardcopy('/tmp/%s/test_hardcopy.pdf' % os.environ['USER'])

    fig = figure('uerrorbar')
    pyplot.figure(num=fig.number)
    ax = pyplot.subplot(211)
    uerrorbar(X, S, ax=ax)
    ax2 = pyplot.subplot(212, sharex=ax)
    uerrorbar(X, S, ax=ax2)
    plot_equality(X, S, ax=ax2)
    autofmt_sharex()

    fig = figure('uband')
    uband(X, S)