Source code for lazyfmri.plotting

import string
import numpy as np
from . import utils
import pandas as pd
import seaborn as sns
import matplotlib as mpl
from typing import Union
import matplotlib.pyplot as plt
import matplotlib.patches as patches


[docs] class Defaults(): """Defaults Default settings for plotting. Parameters ---------- pad_title: int Set the distance between the title and the plot. Default = 20 title_size: int Set the font size of titles. Default = 22; you can set it equal to font_size (or any other size) by specifying "font_size". font_size: int Set the font size of axis labels/titles. Default = 18 label_size: int Set the font size of tick labels. Default = 14 tick_width: float Set the line-width of the ticks. Default = 0.5 tick_length: float Set the length of the ticks. Default = 0 (no ticks) axis_width: float Set the line-width of axes. Default = 0.5 line_width: int Line widths for either all graphs (then *int*) or a *list* with the number of elements as requested graphs, default=1. line_style: str Set the style of data in line-plots. Default = "solid" sns_offset: int Set the distance between y-axis and start of plot. Default = None sns_trim: bool Trim the axes following seaborn's convention. Default = False sns_bottom: bool Trim specifically the x-axis of plots. Default = False sns_ori: str, optional Default orientation for bar-plots. Default is up-right (vertical). Allowed options are "v" or "h". :class:`lazyfmri.plotting.LazyBar()` sns_rot: int, float, optional Rotation of labels in bar plot. Specific to :class:`lazyfmri.plotting.LazyBar()` xkcd: bool Plot with cartoon style. Default = False ylim_bottom: float Set the y-limiter at the bottom of the plot. Default = None ylim_top: float Set the y-limiter at the top of the plot. Default = None xlim_left: float Set the x-limiter at the left of the plot. Default = None xlim_right: float Set the x-limiter at the right of the plot. Default = None set_xlim_zero: bool Set the distance between plot and y-axis to 0. Default = False legend_handletext: float Set the distance between the handle and text in legends. Default = 0.05 (bit closer than default) x_label: str, optional Label of x-axis, by default None y_label: str, optional Label of y-axis, by default None labels: str, list, optional String (if 1 timeseries) or list (with the length of `ts`) of colors, by default None. Labels for the timeseries to be used in the legend title: str, dict, optional String of dictionary collecting the following keys representing information about the title: .. code-block:: python title = { 'title' "some title", # title text 'color': 'k', # color (default = 'k') 'fontweight': "bold" # fontweight (default = 'normal'), can be any of the matplotib fontweight options (e.g., 'italic', 'bold', 'normal' etc.) } color: str, list, optional String (if 1 timeseries) or list (with the length of `ts`) of colors, by default None. If nothing is specified, we'll use `cmap` to create a color palette save_as: str, list, optional Save the plot, by default None. If you want to use figures in Inkscape, save them as PDFs to retain high resolution; specify a list of strings to save the plot with multiple extensions y_lim: list, optional List for `self._set_ylim` x_lim: list, optional List for `self_.set_xlim` x_dec: int, optional Enforce `x_ticks` to have `x_dec` decimal accuracy. Default is whatever the data dictates y_dec: int, optional Enforce `y_ticks` to have `y_dec` decimal accuracy. Default is whatever the data dictates add_hline: dict, optional Dictionary for a horizontal line through the plot, by default None. Collects the following items: .. code-block:: python add_hline = { 'pos' 0, # position 'color': 'k', # color 'lw': 1, # linewidth 'ls': '--' # linestyle } You can get the settings above by specifying *add_hline='default'*. Now also accepts *add_hline='mean'* for single inputs add_vline: dict, optional Dictionary for a vertical line through the plot, by default None. Same keys as `add_hline` dpi: int, optional Save figures with DPI-value. Default is 300 figure_background_color: str, optional Background of images. Default is "white" bbox_inches: str, optional Bounding box settings. Default is "tight" """ def __init__(self, **kwargs): self.ls_kwargs = [ "pad_title", "font_size", "title_size", "label_size", "tick_width", "tick_length", "axis_width", "line_width", "line_style", "sns_offset", "sns_trim", "sns_bottom", "sns_ori", "sns_rot", "sns_despine", "xkcd", "ylim_bottom", "ylim_top", "xlim_left", "xlim_right", "set_xlim_zero", "legend_handletext", "x_label", "y_label", "title", "save_as", "y_lim", "x_lim", "x_ticks", "y_ticks", "z_ticks", "x_ticklabels", "y_ticklabels", "z_ticklabels", "axs", "color", "y_dec", "x_dec", "add_vline", "add_hline", "add_vspan", "add_hspan", "dpi", "figure_background_color", "bbox_inches", "fontname", "legend_kwargs" ] self.pad_title = 20 self.font_size = 14 self.title_size = 16 self.label_size = 10 self.tick_width = 0.5 self.tick_length = 0 self.axis_width = 0.5 self.line_width = 1 self.line_style = "solid" self.sns_offset = 0 self.sns_trim = False self.sns_bottom = False self.sns_ori = "v" self.sns_rot = None self.xkcd = False self.ylim_bottom = None self.ylim_top = None self.xlim_left = None self.xlim_right = None self.set_xlim_zero = False self.legend_handletext = 0.25 self.title = None self.save_as = None self.axs = None self.color = None self.add_vline = None self.add_hline = None self.add_vspan = None self.add_hspan = None self.dpi = 300 self.figure_background_color = "white" self.bbox_inches = "tight" self.fontname = None self.sns_despine = True self.legend_kwargs = {} for i in ["label", "lim", "ticks", "ticklabels", "dec"]: for ii in ["x", "y", "z"]: setattr(self, f"{ii}_{i}", None) # update kwargs self.__dict__.update(kwargs) # set default font if self.xkcd: self.fontname = "Humor Sans" else: if not isinstance(self.fontname, str): self.fontname = "Montserrat" # update font widely self.update_rc(self.fontname)
[docs] def update_rc(self, font): """update font""" plt.rcParams.update({'font.family': font})
def _set_figure_axs(self, figsize=None): if not isinstance(self.axs, mpl.axes._axes.Axes): if not isinstance(figsize, tuple): figsize = self.figsize _, self.axs = plt.subplots(figsize=figsize) def _set_spine_width(self, ax): """update spine width""" for axis in ['top', 'bottom', 'left', 'right']: ax.spines[axis].set_linewidth(self.axis_width) def _set_axlabel(self, ax, lbl, axis="x", **kwargs): """_set_axlabel Internal function to set the x/y/z-labels of a plot. Parameters ---------- ax: <AxesSubplot:>, optional Subplot axis to put the plot on, by default None lbl: str, dict if `lbl` is a string, we'll just simply use this as the label annotation. By default, `fontsize` and `fontname` are set by :class:`lazyfmri.plotting.Defaults()`. .. code-block:: python # update with defaults defaults_kws = { "fontsize": self.font_size, "fontname": self.fontname, } `lbl` can also be a dictionary collecting keys relevant for '<ax>.set_{x|y|z}label': .. code-block:: python title_dict = { "label": "this is the actual label", "color: "b" "fontweight": "bold" } axis: str which axis to label. Must be one of 'x', 'y', or 'z' """ kwargs = {} if isinstance(lbl, (dict, str)): if isinstance(lbl, dict): assert "label" in lbl, f"dictionary must contain a key 'label' with a string representing the actual label, dictionary now has: {lbl.keys()}" kwargs = lbl lbl = kwargs.pop("label") allowed_axes = ["x", "y", 'z'] assert axis in allowed_axes, f"axis must be one of {allowed_axes}, not {axis}" assert hasattr(ax, f"set_{axis}label"), f"{ax}-object does not have 'set_{axis}label' attribute. Make sure to specify a valid axis object" # update with defaults defaults_kws = { "fontsize": self.font_size, "fontname": self.fontname, } for key, val in defaults_kws.items(): kwargs = utils.update_kwargs( kwargs, key, val ) if axis == "x": ffunc = ax.set_xlabel elif axis == "y": ffunc = ax.set_ylabel else: ffunc = ax.set_zlabel if isinstance(lbl, (str, list)): ffunc( lbl, **kwargs ) def _set_tick_params(self, ax, **kwargs): """set width/length/labelsize of ticks""" ax.tick_params( width=self.tick_width, length=self.tick_length, labelsize=self.label_size, **kwargs ) def _set_title(self, ax, title): """_set_title Internal function to set the main title of the axis object. Parameters ---------- ax: <AxesSubplot:>, optional Subplot axis to put the plot on, by default None title: str, dict if `title` is a string, we'll just simply use this as the label annotation. By default, `fontsize` and `fontname` are set by :class:`lazyfmri.plotting.Defaults()`. .. code-block:: python # update with defaults defaults_kws = { "fontsize": self.font_size, "fontname": self.fontname, } `title` can also be a dictionary collecting keys relevant for '<ax>.set_title': .. code-block:: python title_dict = { "title": "this is the actual title", "color: "b" "fontweight": "bold" } """ if isinstance(title, (str, dict)): if isinstance(self.title_size, str): if hasattr(self, self.title_size): self.title_size = getattr(self, self.title_size) assert isinstance(title, (dict, str)), f"title must be a string or a dictionary with a 'title' key representing the title, as well as additional arguments passed to ax.set_title(**kwargs)" default_dict = { 'color': 'k', 'fontweight': 'normal', "fontsize": self.title_size, "fontname": self.fontname, "pad": self.pad_title } # pop 'title' key kwargs = {} if isinstance(title, dict): assert "title" in title, f"dictionary must contain a key 'title' with a string representing the actual title, dictionary now has: {title.keys()}" kwargs = title.copy() title = kwargs.pop("title") # add default keys if they're missing in dictionary for key, val in default_dict.items(): kwargs = utils.update_kwargs( kwargs, key, val ) ax.set_title( title, **kwargs ) def _set_bar_lim(self, ax, lim): if isinstance(lim, list): if self.sns_ori == 'h': fc = self._set_xlim elif self.sns_ori == "v": fc = self._set_ylim else: raise ValueError( f"sns_ori must be 'v' or 'h', not '{self.sns_ori}'") # set fc(ax, lim) def _set_bar_ticks(self, ax, ticks): if isinstance(ticks, list): if self.sns_ori == 'h': fc = "x" elif self.sns_ori == "v": fc = "y" else: raise ValueError( f"sns_ori must be 'v' or 'h', not '{self.sns_ori}'") # set self._set_ticks(ax, ticks, axis=fc) @staticmethod def _set_ticks(ax, ticks, axis="x", **kwargs): """set ticks""" if hasattr(ax, f"set_{axis}ticks"): if axis == "x": ffunc = ax.set_xticks elif axis == "y": ffunc = ax.set_yticks else: ffunc = ax.set_zticks if isinstance(ticks, (pd.Series, pd.DataFrame)): ticks = ticks.values if isinstance(ticks, np.ndarray): ticks = [float(i) for i in ticks] if isinstance(ticks, (list)): # check if elements are output of get_?ticklabels() if all([isinstance(i, mpl.text.Text) for i in ticks]): ticks = [float(i._text) for i in ticks] ffunc(ticks) @staticmethod def _set_ticklabels(ax, ticks, axis="x", **kwargs): if hasattr(ax, f"set_{axis}ticklabels"): if axis == "x": ffunc = ax.set_xticklabels elif axis == "y": ffunc = ax.set_yticklabels else: ffunc = ax.set_zticklabels if isinstance(ticks, (pd.Series, pd.DataFrame)): ticks = ticks.values if isinstance(ticks, np.ndarray): ticks = [float(i) for i in ticks] if isinstance(ticks, list): ffunc(ticks, **kwargs) @staticmethod def _set_ylim(ax, lim): """set y-limit""" if isinstance(lim, list): ax.set_ylim(lim) elif isinstance(lim, (int, float)): ax.set_ylim(top=lim) @staticmethod def _set_xlim(ax, lim): """set x-limit""" if isinstance(lim, list): ax.set_xlim(lim) elif isinstance(lim, (int, float)): ax.set_xlim(left=lim) def _despine(self, ax, **kwargs): """despine plot""" default_dict = { "offset": self.sns_offset, "trim": self.sns_trim } for key, val in default_dict.items(): kwargs = utils.update_kwargs( kwargs, key, val ) if self.sns_despine: sns.despine( ax=ax, **kwargs ) @staticmethod def _set_ticker(ax, dec, axis="x", **kwargs): """set all y-ticks to decimal""" if hasattr(ax, f"{axis}axis"): if axis == "x": ffunc = ax.xaxis elif axis == "y": ffunc = ax.yaxis else: ffunc = ax.zaxis if isinstance(dec, int): from matplotlib.ticker import FormatStrFormatter ffunc.set_major_formatter(FormatStrFormatter(f"%.{dec}f")) def _set_shaded_error( self, x: np.ndarray = None, tc: np.ndarray = None, ax: mpl.axes._axes.Axes = None, yerr: np.ndarray = None, **kwargs ): if yerr is None: return if not isinstance(yerr, (int, float, list, tuple, np.ndarray)): return # scalar symmetric error or vector symmetric error if np.isscalar(yerr): ymin = tc - yerr ymax = tc + yerr else: # explicit lower/upper bounds if isinstance(yerr, (list, tuple)) and len(yerr) == 2: ymin = np.asarray(yerr[0]) ymax = np.asarray(yerr[1]) else: yerr = np.asarray(yerr) if len(yerr) != len(tc): raise ValueError( f"Error vector length ({len(yerr)}) does not match data length ({len(tc)})" ) if np.all(np.isnan(yerr)): raise TypeError("Error contains only NaNs") ymin = tc - yerr ymax = tc + yerr if len(ymin) != len(tc) or len(ymax) != len(tc): raise ValueError( f"Shaded bounds must match data length ({len(tc)}); got {len(ymin)} and {len(ymax)}" ) ax.fill_between( x, ymin, ymax, **kwargs ) def _draw_errorbar( self, x: np.ndarray = None, y: np.ndarray = None, ax: mpl.axes._axes.Axes = None, yerr: np.ndarray = None, xerr: np.ndarray = None, **kwargs ): # set bunch of defaults if "linewidth" not in list( kwargs.keys()) and "lw" not in list(kwargs.keys()): kwargs["lw"] = self.line_width if "color" not in list(kwargs.keys()): kwargs["ecolor"] = "k" if "zorder" not in list(kwargs.keys()): kwargs["zorder"] = 0 if "linestyle" not in list( kwargs.keys()) and "ls" not in list(kwargs.keys()): kwargs["ls"] = "None" if isinstance(x, (pd.Series, pd.DataFrame)): x = x.values if isinstance(y, (pd.Series, pd.DataFrame)): y = y.values if isinstance(yerr, (pd.Series, pd.DataFrame)): yerr = yerr.values if isinstance(xerr, (pd.Series, pd.DataFrame)): xerr = xerr.values ax.errorbar( x, y, yerr=yerr, xerr=xerr, **kwargs ) def _set_legend_labels(self, ax, labels=None): if isinstance(labels, (list, np.ndarray)): if "font_size" not in list(self.legend_kwargs.keys( )) and "fontsize" not in list(self.legend_kwargs.keys()): self.legend_kwargs["fontsize"] = self.label_size ax.legend( frameon=False, handletextpad=self.legend_handletext, **self.legend_kwargs, ) def _save_as(self, save_as, **kwargs): """simple save function""" if isinstance(save_as, str): plt.savefig( save_as, bbox_inches=self.bbox_inches, dpi=self.dpi, facecolor=self.figure_background_color, **kwargs ) def _save_figure(self, save_as): """save same figure with multiple extensions""" if isinstance(save_as, (list, str)): if isinstance(save_as, str): save_as = [save_as] for ii in save_as: self._save_as(ii) def _return_element(self, ddict, el, ix=0): if isinstance(ddict[el], list): return ddict[el][ix] else: return ddict[el] def _add_span( self, ax=None ): """Add shaded horizontal and/or vertical spans to an Axes. Reads configuration from :pyattr:`self.add_hspan` and :pyattr:`self.add_vspan` and draws rectangular highlight regions using ``ax.axhspan`` (horizontal band) and/or ``ax.axvspan`` (vertical band). Each attribute can be a concise list ``[start, stop]`` or a dictionary for full control. Parameters ---------- ax : matplotlib.axes.Axes, optional Target axes to draw on. Required; drawn in place. Default = None. Configuration (via ``self.add_hspan`` / ``self.add_vspan``) ----------------------------------------------------------- Each attribute may be: - ``list`` : Two-element list ``[start, stop]`` in **data coordinates** along the oriented axis (y for horizontal spans, x for vertical spans). - ``dict`` : Must contain a ``'loc'`` key with ``[start, stop]`` as above. May also include any valid matplotlib span kwargs (e.g., ``color``, ``alpha``) and the special keys ``'min'`` / ``'max'`` to control the extent along the **orthogonal axis** in **axes coordinates** [0, 1]. Dictionary keys ---------------- loc : list of 2 floats, **required** Span limits in data coords along the oriented axis: - For ``add_hspan`` → ``ymin=loc[0]``, ``ymax=loc[1]``. - For ``add_vspan`` → ``xmin=loc[0]``, ``xmax=loc[1]``. min : float, optional Lower bound along the orthogonal axis, in axes fraction [0, 1]. - Maps to ``xmin`` for ``axhspan``; to ``ymin`` for ``axvspan``. Default = 0. max : float, optional Upper bound along the orthogonal axis, in axes fraction [0, 1]. - Maps to ``xmax`` for ``axhspan``; to ``ymax`` for ``axvspan``. Default = 1. color : str, optional Face color of the span. Default = '#cccccc'. alpha : float, optional Opacity of the span. Default = 0.3. **other matplotlib kwargs** Any additional keyword args accepted by :meth:`matplotlib.axes.Axes.axhspan` / :meth:`~matplotlib.axes.Axes.axvspan` (e.g., ``zorder``, ``linewidth``, ``edgecolor``). These are merged with defaults; explicit values override defaults. Defaults -------- - color = '#cccccc' - alpha = 0.3 - min = 0, max = 1 (full extent across the orthogonal axis) Behavior -------- - Horizontal spans use ``ax.axhspan(ymin, ymax, xmin=min, xmax=max, **kwargs)``. - Vertical spans use ``ax.axvspan(xmin, xmax, ymin=min, ymax=max, **kwargs)``. - When a dict is provided, ``'loc'`` is extracted; optional ``'min'``/``'max'`` are popped and mapped to the correct kwargs depending on orientation. - User kwargs are merged with defaults via ``utils.update_kwargs`` so you can override only the settings you need. Returns ------- None Draws on ``ax`` in place. Examples -------- Full-width horizontal band between y=0.5 and y=1.0: .. code-block:: python self.add_hspan = [0.5, 1.0] self._add_span(ax) Vertical band from x=2 to x=3 covering the central 30–90% of the axis height: .. code-block:: python self.add_vspan = {'loc': [2, 3], 'min': 0.30, 'max': 0.90, 'color': 'tab:blue', 'alpha': 0.15} self._add_span(ax) """ default_dict = { "color": "#cccccc", "alpha": 0.3, } min, max = [0, 1] for ii, ffunc in zip(["hspan", "vspan"], [ax.axhspan, ax.axvspan]): kwargs = {} test_attr = getattr(self, f"add_{ii}") if isinstance(test_attr, (list, dict)): if isinstance(test_attr, dict): assert "loc" in test_attr, f"Input must contain a 'loc'-key, representing a [start, stop] list. Dictionary now contains: {test_attr.keys()}" # extract loc kwargs = test_attr loc = kwargs.pop("loc") if "min" in kwargs: min = kwargs.pop("min") if "max" in kwargs: max = kwargs.pop("max") else: loc = test_attr if ii == "hspan": default_dict["xmin"] = min default_dict["xmax"] = max default_dict["ymin"] = loc[0] default_dict["ymax"] = loc[1] else: default_dict["xmin"] = loc[0] default_dict["xmax"] = loc[1] default_dict["ymin"] = min default_dict["ymax"] = max # update with defaults for key, val in default_dict.items(): kwargs = utils.update_kwargs( kwargs, key, val ) # plop into function ffunc(**kwargs) def _add_line( self, ax=None, **kwargs ): """_add_line Add horizontal and/or vertical reference lines to an Axes. Reads configuration from the instance attributes :pyattr:`self.add_hline` and :pyattr:`self.add_vline` and draws one or more ``axhline``/``axvline`` objects with sensible defaults. Each attribute can be provided in several shorthand forms (see below). Optionally, labels can be attached to each line. Parameters ---------- ax : matplotlib.axes.Axes, optional Target axes to draw on. Must be provided; no internal fallback. Default = None. **kwargs Currently unused; accepted for API compatibility. Configuration (via ``self.add_hline`` / ``self.add_vline``) ----------------------------------------------------------- Each of these attributes may be: - ``"default"`` : Draw a single line at position ``0`` using defaults. - ``float`` or ``int`` : Draw a single line at the given data position. - ``list`` or ``numpy.ndarray`` of numbers: Draw multiple lines at the given positions. - ``dict`` : Full control via keys described below. Dictionary keys ---------------- The following keys are recognized when the attribute is a ``dict``. Scalars apply to all lines; lists/arrays are indexed per position. pos : float, int, list, or numpy.ndarray, **required** Position(s) of the line(s) in data coordinates (x for vertical, y for horizontal). If omitted, a ``ValueError`` is raised. color : str or list, optional Line color(s). Default = 'k'. ls : str or list, optional Line style(s), e.g. 'dashed'. Default = 'dashed'. lw : float or list, optional Line width(s). Default = 0.5. min : float or list, optional For horizontal lines this maps to ``xmin`` (axes fraction in [0, 1]); for vertical lines to ``ymin``. Default = 0. max : float or list, optional For horizontal lines this maps to ``xmax``; for vertical lines to ``ymax``. Default = 1. add_label : list, dict or any, optional If this key is present, a label is added at each line via :meth:`add_label_to_line`. When ``add_label`` is a ``dict``, its contents are forwarded as keyword arguments to ``add_label_to_line`` (e.g., ``{'lbl': 'threshold', 'l_max': 0.9}``). If ``add_label`` is present but not a dict, a label is still added with default settings and the text defaults to the line position. If `lbl`-key contains `=pos`, this will be substituted for the actual value of ``pos`` For further customization, generate the list of dictionaries before passing them to the function Defaults -------- - color = 'k' - ls = 'dashed' - lw = 0.5 - min = 0, max = 1 (interpreted as axes fractions; mapped to ``xmin/xmax`` or ``ymin/ymax`` depending on orientation) Behavior -------- - Lines are drawn with :meth:`matplotlib.axes.Axes.axhline` (for ``add_hline``) or :meth:`matplotlib.axes.Axes.axvline` (for ``add_vline``). - Positions provided as scalars are internally normalized to a list so that per-line styling can be broadcast via lists/arrays of the same length. - If a provided style key is a list/array, the *i*-th element is used for the *i*-th line; otherwise the scalar value is reused for all lines. Returns ------- None This method draws on ``ax`` in-place and does not return a value. Examples -------- How to call with :class:`lazyfmri.plotting.LazyLine()` .. code-block:: python # simple line at 0 pl = plotting.LazyLine( <some data>, xx=<some axis>, figsize=(5,5), add_vline=5 # relative to <some axis> ) .. code-block:: python # list of positions pl = plotting.LazyLine( <some data>, xx=<some axis>, figsize=(5,5), add_vline=[0, 2, 10] # relative to <some axis> ) # list of annotations ('pos' in lbl will be subsituted for the actual 'pos' value) pl = plotting.LazyLine( <some data>, xx=<some axis>, figsize=(5,5), add_vline={ "pos": [0, 1, 4], "add_label": [ { "lbl": "t=pos", # t=0 "l_max": 0.15, }, { "lbl": "t=pos", # t=1 "l_max": 0.15, }, { "lbl": "t=pos", # t=4 "l_max": 0.15, } ] } ) .. code-block:: python # full control + annotation pl = plotting.LazyLine( <some data>, xx=<some axis), figsize=(5,5), add_vline={ "pos": 0, "add_label": { "lbl": "onset", "l_max": 0.15, "fontweight": "bold", "color": "r", "fontsize": 20 } } ) """ for ii in ["hline", "vline"]: line_kw = {} test_attr = getattr(self, f"add_{ii}") if isinstance(test_attr, (float, int, dict, str, list)): add_line = True # define default dictionary default_dict = { 'color': 'k', 'ls': 'dashed', 'lw': 0.5 } # set fixer depending on line being drawn if ii == "hline": default_dict["min"] = 0 default_dict["max"] = 1 else: default_dict["min"] = 0 default_dict["max"] = 1 # add lines if test_attr == "default": test_attr = {'pos': 0} elif isinstance(test_attr, (float, int, list, np.ndarray)): test_attr = {"pos": test_attr} elif isinstance(test_attr, dict): add_line = True else: add_line = False if add_line: for key in list(default_dict.keys()): if key not in list(test_attr.keys()): test_attr[key] = default_dict[key] # enforce list so we only need to call functions once if "pos" not in list(test_attr.keys()): raise ValueError("Need the 'pos' key to denote position..") else: if isinstance(test_attr['pos'], (int, float)): test_attr['pos'] = [test_attr['pos']] # loop through elements if isinstance(test_attr['pos'], (list, np.ndarray)): for ix, pos in enumerate(test_attr['pos']): # define function kwargs_list = ["color", "lw", "ls"] if ii == "hline": ffunc = ax.axhline kwargs_list += ["xmin", "xmax"] else: ffunc = ax.axvline kwargs_list += ["ymin", "ymax"] # update kwargs for key, val in zip( ["color", "lw", "ls", "min", "max"], kwargs_list ): line_kw[val] = self._return_element(test_attr, key, ix=ix) # run func kwargs = {} ffunc( pos, **line_kw, **kwargs ) # check if we should add label if "add_label" in test_attr: lbl = test_attr["add_label"] if not isinstance(lbl, list): lbl = [lbl] # ensure lists are the same length if len(test_attr['pos']) != len(lbl): lbl = [lbl[0] for _ in range(len(test_attr['pos']))] lbl_kws = {} if isinstance(lbl[ix], dict): lbl_kws = lbl[ix] # set 'pos' if not specified in add_labels; allows custom position if not "pos" in lbl_kws: lbl_kws["pos"] = pos self.add_label_to_line( ax, ori=ii, **lbl_kws )
[docs] def add_label_to_line( self, ax, pos=None, lbl=None, ori="vline", l_max=0.85, boxstyle="round", pad=0.2, fc="white", ec="white", alpha=1, **kwargs ): """Add a label to a vertical or horizontal reference line. Places a text label at a fixed position along an axis-aligned line using mixed data/axes transforms. For vertical lines, *x* is in data units and *y* in axes coordinates; for horizontal lines, *y* is in data units and *x* in axes coordinates. A rounded text box is added by default. Parameters ---------- ax : matplotlib.axes.Axes Target axes to draw on. pos : float Line position in data coordinates. Interpreted as *x* when `ori='vline'` and as *y* when `ori!='vline'` (i.e., horizontal line). lbl : str or float, optional Label text. If `None`, the value of `pos` is used. Default = None. ori : {'vline', 'hline'}, optional Orientation of the line to label. `'vline'` places a label on a vertical line with text rotated 90°, `'hline'` on a horizontal line with no rotation. Default = 'vline'. l_max : float, optional Position along the orthogonal axis in axes coordinates [0, 1] where the label is placed (e.g., height within the axes for a vertical line). Default = 0.85. boxstyle : str, optional Matplotlib fancy box style string (e.g., 'round', 'round4', 'square', 'larrow', 'rarrow'). Combined internally as ``f"{boxstyle},pad={pad}"``. Default = 'round'. pad : float, optional Padding for the text box (in fraction of font size) passed to the boxstyle. Default = 0.2. fc : str, optional Face color of the label's bounding box. Default = 'white'. ec : str, optional Edge color of the label's bounding box. Default = 'white'. alpha : float, optional Transparency of the label's bounding box. Default = 1. **kwargs Additional keyword arguments forwarded to :meth:`matplotlib.axes.Axes.text`. If `bbox` is provided here, it is merged with the defaults; any provided keys (e.g., `fontsize`, `fontweight`, `bbox`, `ha`, `va`, `rotation`) override the internal defaults. Notes ----- - Transform logic: - `ori='vline'`: coordinates are ``(x=data, y=axes)`` via ``ax.get_xaxis_transform()``, with default rotation 90°. - `ori!='vline'` (horizontal): coordinates are ``(x=axes, y=data)`` via ``ax.get_yaxis_transform()``, with default rotation 0°. - Default text alignment is centered both horizontally and vertically. - The text is drawn with ``clip_on=False`` and ``zorder=3``. - The function merges user-supplied `bbox`/style kwargs with sensible defaults so you can override only what you need. Returns ------- matplotlib.axes.Axes The same `ax` object, to allow for method chaining. Examples -------- Add a label to a vertical line at x=0 near the bottom of the axes with 'onset' text: .. code-block:: python # example argument passed on :class:`lazyfmri.plotting.LazyLine()` add_vline={ "pos": 0, "add_label": { "lbl": "onset", "l_max": 0.15 } } .. code-block:: python # multi-position lines add_vline={ "pos": [0, 1, 4], "add_label": [ { "lbl": "t=pos", # t=0 "l_max": 0.15, }, { "lbl": "t=pos", # t=1 "l_max": 0.15, }, { "lbl": "t=pos", # t=4 "l_max": 0.15, } ] } """ if ori == "vline": rot = 90 coord = (pos, l_max) # x in data units, y in axes [0..1] tfm = ax.get_xaxis_transform() else: rot = 0 coord= (l_max, pos) # y in data units, x in axes [0..1] tfm = ax.get_yaxis_transform() # text box kwargs bbox_defaults = { "boxstyle": f"{boxstyle},pad={pad}", "fc": fc, "ec": ec, "alpha": alpha, } # position kwargs defaults = { "ha": "center", "va": "center", "rotation": rot } # initialize bbox kwargs if absent if not "bbox" in kwargs: kwargs["bbox"] = {} # append defaults if not specified for key, val in bbox_defaults.items(): kwargs["bbox"] = utils.update_kwargs( kwargs["bbox"], key,val ) # add regular defaults for key, val in defaults.items(): kwargs = utils.update_kwargs( kwargs, key, val ) # set label to pos if not specified if lbl is None: lbl = pos else: lbl = lbl.replace("=pos", f"={str(pos)}") ax.text( coord[0], coord[1], lbl, transform=tfm, clip_on=False, zorder=3, **kwargs ) return ax
[docs] class LazyPRF(Defaults): """LazyPRF Plot the geometric location of the Gaussian pRF. Parameters ---------- prf: numpy.ndarray instantiation of `gauss2D_iso_cart`; will be np.squeeze'ed over the first axis if `ndim >= 3`. vf_extent: list the space the pRF lives in cmap: str, optional Colormap for imshow; accepts output from :func:`lazyfmri.utils.make_binary_cm`. Defaults to 'magma' cross_color: str, optional Color for the fixation cross; defaults to 'white'. You can set it to 'k' if you have a binary colormap as input alpha: float, optional Opacity for imshow shrink_factor: float, optional When the background of the image is white, we create a black border around the Circle patch. If this is equal to `vf_extent`, the border is cut off at some points. This factor shrinks the radius of the Circle, so that we can have a nice border. When set to 0.9, it becomes sort of like a target. This is relevant for **all** non-`magma` color maps that you insert, specifically a :func:`lazyfmri.utils.make_binary_cm` object full_axis: bool, optional If `True`, the entire axis of `vf_extent` will be used for the ticks (more detailed). If `False`, a truncated/trimmed version will be returned (looks cleaner). Default = False axis_off: bool, optional If `True` the x/y axis will be maintained, and the `vf_extent` will be given as ticks. If `False`, axis will be turned off. If `axis_off=True`, then `full_axis` and other label/axis parameters are ignored. Default = True vf_only: bool, optional Only show the outline of the the visual field, without pRF. You still need to specify the pRF as we'll `imshow` an empty array with the same shape rather than the pRF. Default = False line_width: float, optional Width of the outer border of the visual field if `cmap` is not *viridis* or *magma* (these color maps are quite default, and do not require an extra border like :func:`lazyfmri.utils.make_binary_cm`-objects do). Default is 0.5. cross_width: float, optional Width of the cross denoting the x/y axis. Default is 0.5, but can be increased if `cmap` is not *viridis* or *magma* to enhance visibility z_lines: int, optional Set the order of the vertical/horizontal lines. Default is **on top** of the pRF (1) z_prf: int, optional Set the order of the pRF imshow. Default is below the axis lines, but can be changed to be on top of them. Default = 0 imshow_kw: dict, optional Additional kwargs passed on to `imshow` Returns ---------- matplotlib.pyplot plot """ def __init__( self, prf, vf_extent, cmap='RdBu_r', cross_color="white", alpha=None, shrink_factor=1, axis_off=True, figsize=(3.54, 3.54), full_axis=False, vf_only=False, cross_width=0.5, concentric=None, z_lines=1, z_prf=0, edge_color=None, imshow_kw={}, **kwargs): self.prf = prf self.vf_extent = vf_extent self.cmap = cmap self.cross_color = cross_color self.alpha = alpha self.shrink_factor = shrink_factor self.axis_off = axis_off self.figsize = figsize self.full_axis = full_axis self.vf_only = vf_only self.cross_width = cross_width self.concentric = concentric self.z_lines = z_lines self.z_prf = z_prf self.edge_color = edge_color self.imshow_kw = imshow_kw super().__init__() self.__dict__.update(kwargs) self.update_rc(self.fontname) if self.xkcd: with plt.xkcd(): self.plot() else: self.plot() # save self._save_figure(self.save_as)
[docs] def plot(self): # set figure axis self._set_figure_axs() if self.prf.ndim >= 3: self.prf = np.squeeze(self.prf, axis=0) if self.alpha is None: self.alpha = 1 # add cross-hair for ii in ["hline", "vline"]: self.line_kw = { "pos": 0, "color": self.cross_color, "lw": self.cross_width } setattr(self, f"add_{ii}", self.line_kw) self._add_line( self.axs, zorder=self.z_lines ) if not self.vf_only: plot_obj = self.prf else: plot_obj = np.zeros_like(self.prf) # check if pRF has negatives if plot_obj.min() < 0: vmin = plot_obj.min() vmax = -plot_obj.min() else: vmin = -plot_obj.max() vmax = plot_obj.max() if len(self.vf_extent) < 4: self.use_extent = self.vf_extent + self.vf_extent else: self.use_extent = self.vf_extent im = self.axs.imshow( plot_obj, extent=self.use_extent, cmap=self.cmap, alpha=self.alpha, zorder=self.z_prf, vmin=vmin, vmax=vmax, **self.imshow_kw) # In case of a white background, the circle for the visual field is cut # off, so we need to make an adjustment: if self.cmap != 'magma' and self.cmap != 'viridis': radius = self.use_extent[-1] * self.shrink_factor else: radius = self.use_extent[-1] # set title self._set_title(self.axs, self.title) # set patch self.patch = patches.Circle( (0, 0), radius=radius, transform=self.axs.transData, edgecolor=self.edge_color, facecolor="None", linewidth=self.line_width) self.axs.add_patch(self.patch) im.set_clip_path(self.patch) if self.axis_off: self.axs.axis('off') else: # set tick params self._set_tick_params(self.axs) # set spine widths self._set_spine_width(self.axs) if self.full_axis: self.use_ticks = np.arange( self.vf_extent[0], self.vf_extent[1] + 1, 1) else: self.use_ticks = self.vf_extent # set ticks loop_funcs = [ "_set_ticks", "_set_ticklabels", "_set_ticker" ] for x in ["x", "y"]: for ff, el in zip( loop_funcs, ["ticks", "ticklabels", "dec"]): add_to_ax = getattr(self, f"{x}_{el}") getattr(self, ff)(self.axs, add_to_ax, axis=x) self._despine(self.axs)
[docs] class LazyLine(Defaults): """LazyLine Class for plotting because I'm lazy and I don't want to go through the ``matplotlib`` motion everything I quickly want to visualize something. This class makes that a lot easier. It allows single inputs, lists with multiple timecourses, labels, error shadings, and much more. Parameters ---------- ts: list, numpy.ndarray Input data. Can either be a single list, or a list of multiple numpy arrays. If you want labels, custom colors, or error bars, these inputs must come in lists of similar length as ``ts``! xx: list, numpy.ndarray, optional X-axis array error: list, numpy.ndarray, optional Error data with the same length/shape as the input timeseries, by default None. Can be either a numpy.ndarray for 1 timeseries, or a list of numpy.ndarrays for multiple timeseries error_alpha: float, optional Opacity level for error shadings, by default 0.3 cmap: str, optional Color palette to use for colors if no individual colors are specified, by default 'viridis' figsize: tuple, optional Figure dimensions as per usual matplotlib conventions, by default (25,5) markers: str, list, optional Use markers during plotting. A single option will be applied to all elements in ``ts``. A separate list for each element of ``ts`` is also accepted. If one array in ``ts`` should not have markers, use `None`. E.g., if `len(ts) == 3`, and we want only the first timecourse to have markers use: ``markers=['.',None,None]`` markersize: str, list, optional Specify marker sizes during plotting. A single option will be applied to all elements in ``ts``. A separate list for each element of ``ts`` is also accepted. If one array in ``ts`` should not have markers, use `None`. E.g., if ``len(ts) == 3``, and we want only the first timecourse to have markers use: ``markers=['.',None,None]`` markerfc: str, list, optional Specify marker facecolor during plotting. A single option will be applied to all elements in ``ts``. A separate list for each element of ``ts`` is also accepted markerec: str, list, optional Specify marker edgecolor during plotting. A single option will be applied to all elements in ``ts``. A separate list for each element of ``ts`` is also accepted plot_kw: dict, optional Further plotting options passed to ``matplotlib.pyplot.plot``, such as ``markerfacecolor`` (removes fill from markers, leaving only the edges), and ``markeredgewidth`` (sets the width of the marker outline). These options are passed to **every** elements in ``ts``. x_ticks: list, optional Locations where to put the ticks on the x-axis y_ticks: list, optional Locations where to put the ticks on the y-axis Example ---------- .. code-block:: python # create a bunch of timeseries from lazyfmri import utils ts = utils.random_timeseries(1.2, 0.0, 100) ts1 = utils.random_timeseries(1.2, 0.3, 100) ts2 = utils.random_timeseries(1.2, 0.5, 100) ts3 = utils.random_timeseries(1.2, 0.8, 100) ts4 = utils.random_timeseries(1.2, 1, 100) # plot 1 timecourse plotting.LazyLine(ts2, figsize=(20, 5)) <lazyfmri.plotting.LazyLine at 0x7f839b0289d0> .. code-block:: python # plot multiple timecourses, add labels, and save file plotting.LazyLine([ts, ts1, ts2, ts3, ts4], figsize=(20, 5), save_as="test_LazyLine.pdf", labels=['vol=0', 'vol=0.3', 'vol=0.5', 'vol=0.8', 'vol=1.0']) <lazyfmri.plotting.LazyLine at 0x7f839b2177c0> .. code-block:: python # add horizontal line at y=0 hline = {'pos': 0, 'color': 'k', 'lw': 0.5, 'ls': '--'} plotting.LazyLine(ts2, figsize=(20, 5), add_hline=hline) <lazyfmri.plotting.LazyLine at 0x7f839b053580> .. code-block:: python # add shaded error bars from scipy.stats import sem # make some stack stack = np.hstack((ts1[...,np.newaxis],ts2[...,np.newaxis],ts4[...,np.newaxis])) avg = stack.mean(axis=-1) # calculate mean err = sem(stack, axis=-1) # calculate error plotting.LazyLine(avg, figsize=(20, 5), error=err) <lazyfmri.plotting.LazyLine at 0x7f839b0d5220> Notes ---------- See https://lazyfmri.readthedocs.io/en/latest/examples/example.html for more examples """ def __init__( self, ts, xx=None, error=None, error_alpha=0.3, figsize=(7.1, 3.54), cmap='viridis', labels=None, markers=None, markersize=None, markerfc=None, markerec=None, plot_alpha=None, plot_kw={}, **kwargs): self.array = ts self.xx = xx self.error = error self.error_alpha = error_alpha self.plot_alpha = plot_alpha self.figsize = figsize self.cmap = cmap self.labels = labels self.markers = markers self.markersize = markersize self.markerfc = markerfc self.markerec = markerec self.plot_kw = plot_kw super().__init__() self.__dict__.update(kwargs) self.update_rc(self.fontname) # plot if self.xkcd: with plt.xkcd(): self.plot() else: self.plot() # save self._save_figure(self.save_as)
[docs] def plot(self): """main plotting function""" # set figure axis self._set_figure_axs() # sort out color if isinstance(self.array, np.ndarray): self.array = [self.array] if not self.color: self.color = sns.color_palette(self.cmap, 1)[0] else: self.color = [self.color] # check if alpha's match nr of elements in array if isinstance(self.array, list): if not isinstance(self.plot_alpha, list): if self.plot_alpha is None: self.plot_alpha = [1 for _ in range(len(self.array))] elif isinstance(self.plot_alpha, (int, float)): self.plot_alpha = [ self.plot_alpha for _ in range(len(self.array))] else: self.plot_alpha = [self.plot_alpha] if len(self.plot_alpha) != len(self.array): raise ValueError( f"Alpha list ({len(self.plot_alpha)}) does not match length of data list ({len(self.array)})") if isinstance(self.color, str): self.color = [self.color for _ in range(len(self.array))] if not isinstance(self.markers, list): if self.markers is None: self.markers = [None for _ in range(len(self.array))] else: self.markers = [ self.markers for _ in range(len(self.array))] if len(self.markers) != len(self.array): raise ValueError( f"Marker list ({len(self.markers)}) does not match length of data list ({len(self.array)})") self.fix_error = True if not isinstance(self.error, (int, float, np.ndarray, list)): self.fix_error = False if not isinstance(self.error, list): if not isinstance(self.error, (int, float, np.ndarray, list)): self.error = [None for ii in range(len(self.array))] elif isinstance(self.error, (int, float)): self.error = [self.error for ii in range(len(self.array))] else: self.error = [self.error] # filter out all-NaN errors if self.fix_error: self.tmp_error = [None for _ in range(len(self.array))] for ix, err in enumerate(self.error): if err is None: continue if isinstance(err, (int, float)): self.tmp_error[ix] = err continue # explicit bounds: (ymin, ymax) or [ymin, ymax] if isinstance(err, (list, tuple)) and len(err) == 2: ymin = np.asarray(err[0]) ymax = np.asarray(err[1]) if ymin.shape != ymax.shape: raise ValueError( f"Lower/upper error bounds must have same shape; got {ymin.shape} and {ymax.shape}" ) if not (np.all(np.isnan(ymin)) and np.all(np.isnan(ymax))): self.tmp_error[ix] = (ymin, ymax) continue # regular 1D error vector err = np.asarray(err) if not np.all(np.isnan(err)): self.tmp_error[ix] = err self.error = self.tmp_error.copy() if len(self.error) != len(self.array): raise ValueError( f"Error list ({len(self.error)}) does not match length of data list ({len(self.array)})" ) if not isinstance(self.markersize, list): if self.markersize is None: self.markersize = [None for ii in range(len(self.array))] else: self.markersize = [self.markersize] if len(self.markersize) != len(self.array): raise ValueError( f"Markersize list ({len(self.markersize)}) does not match length of data list ({len(self.array)})") # decide on color scheme if not isinstance(self.color, list): self.color_list = sns.color_palette(self.cmap, len(self.array)) else: self.color_list = self.color if len(self.color_list) != len(self.array): raise ValueError( f"Length color list ({len(self.color_list)}) does not match length of data list ({len(self.array)})") for idx, el in enumerate(self.array): # squeeze dimensions if el.ndim > 1: el = el.squeeze() # decide on line-width if isinstance(self.line_width, list): if len(self.line_width) != len(self.array): raise ValueError( f"""Length of line width lenghts {len(self.line_width)} does not match length of data list ({len(self.array)}""") use_width = self.line_width[idx] elif isinstance(self.line_width, (int, float)): use_width = self.line_width else: use_width = "" # decide on line-style if isinstance(self.line_style, list): if len(self.line_style) != len(self.array): raise ValueError( f"""Length of line width lenghts {len(self.line_style)} does not match length of data list ({len(self.array)}""") use_style = self.line_style[idx] elif isinstance(self.line_style, str): use_style = self.line_style else: use_style = "solid" # decide on x-axis if not isinstance(self.xx, (np.ndarray, list, range, pd.DataFrame, pd.Series)): self.t_ = np.arange(0, len(el)) else: # range has no copy attribute if isinstance(self.xx, range): self.t_ = self.xx elif isinstance(self.xx, (pd.DataFrame, pd.Series)): self.t_ = self.xx.values elif isinstance(self.xx, list): # received array-specific list if len(self.xx) == len(self.array): self.t_ = self.xx[idx] else: self.t_ = self.xx elif isinstance(self.xx, np.ndarray): if self.xx.ndim > 1: self.xx = self.xx.squeeze() self.t_ = self.xx.copy() else: raise TypeError( f"Could not derive time axis from input type {type(self.xx)}") if isinstance(self.labels, (list, np.ndarray)): lbl = self.labels[idx] else: lbl = None if isinstance(self.markerec, (list, str)): if isinstance(self.markerec, list): use_ec = self.markerec[idx] else: use_ec = self.markerec self.plot_kw["markeredgecolor"] = use_ec if isinstance(self.markerfc, (list, str)): if isinstance(self.markerfc, list): use_fc = self.markerfc[idx] else: use_fc = self.markerfc self.plot_kw["markerfacecolor"] = use_fc # plot self.axs.plot( self.t_, el, color=self.color_list[idx], label=lbl, lw=use_width, ls=use_style, marker=self.markers[idx], markersize=self.markersize[idx], alpha=self.plot_alpha[idx], **self.plot_kw) # plot shaded error bars self._set_shaded_error( x=self.t_, ax=self.axs, tc=el, yerr=self.error[idx], color=self.color_list[idx], alpha=self.error_alpha ) # axis labels and titles self._set_legend_labels(self.axs, labels=self.labels) # set title self._set_title(self.axs, self.title) # set tick params self._set_tick_params(self.axs) # set spine widths self._set_spine_width(self.axs) # give priority to specify x-lims rather than seaborn's xlim if not self.x_lim: if isinstance(self.xlim_left, (float, int)): self.axs.set_xlim(left=self.xlim_left) else: self.axs.set_xlim(left=self.t_[0]) if self.xlim_right: self.axs.set_xlim(right=self.xlim_right) else: self.axs.set_xlim(right=self.t_[-1]) else: self.axs.set_xlim(self.x_lim) if not self.y_lim: if isinstance(self.ylim_bottom, (float, int)): self.axs.set_ylim(bottom=self.ylim_bottom) if self.ylim_top: self.axs.set_ylim(top=self.ylim_top) else: self.axs.set_ylim(self.y_lim) # set ticks loop_funcs = [ "_set_ticks", "_set_ticklabels", "_set_ticker", "_set_axlabel" ] for x in ["x", "y"]: for ff, el in zip( loop_funcs, ["ticks", "ticklabels", "dec", "label"] ): add_to_ax = getattr(self, f"{x}_{el}") getattr(self, ff)(self.axs, add_to_ax, axis=x) # draw horizontal/vertical lines with ax?line self._add_line(ax=self.axs) self._add_span(ax=self.axs) # despine self._despine(self.axs)
[docs] class LazyCorr(Defaults): """LazyCorr Wrapper around seaborn's regplot. Plot data and a linear regression model fit. In addition to creating the plot, you can also run a regression or correlation using pingouin by setting the corresponding argument to ``True``. Parameters ---------- data: pd.DataFrame, optional Input DataFrame. In this case, use strings representing column names for ``x``, ``y``, and ``color_by``. Internally, the dataframe is parsed into arrays so that it's compatible with matplotlib's scatter-/ color-by functions x: str, np.ndarray, list First variable to include in plot/regression. Can be a list/array representing data, or a column name from ``data`` y: str, np.ndarray, list Second variable to include in plot/regression. Can be a list/array representing data, or a column name from ``data`` color_by: str, np.ndarray, list Color the points according to a separate array. Can be a list/array representing data, or a column name from ``data``. Default color map for this is 'viridis', and can be changed by passing arguments to `scatter_kwargs` color: str, list, optional String representing a color, by default "#ccccccc" to color the regression fit figsize: tuple, optional Figure dimensions as per usual matplotlib conventions, by default (8,8) axs: <AxesSubplot:>, optional Matplotlib axis to store the figure on correlation: bool, optional Run a correlation between ``x`` and ``y``. The result is stored in ``self.correlation_result`` regression: bool, optional Run a regression between ``x`` and ``y``. The result is stored in ``self.regression_result`` scatter_kwargs: dict, optional Additional options passed on to the ``scatter`` function from matplotlib. Set colorbar to nothing by passing: .. code-block:: python scatter_kwargs={"cbar": False} result_to_plot: bool, optional Add the correlation/regression result to the plot stat_kwargs: dict, optional Options passed on to pingouin's stats functions reg_kwargs: dict, optional Options passed on the seaborn's regplot Example ---------- .. code-block:: python from lazyfmri import plotting import matplotlib.pyplot as plt # vanilla version; here, the regression fit has the same color as the dots. fig,axs = plt.subplots(figsize=(7,7)) plotting.LazyCorr( x_data, y_data, axs=axs, x_label="add xlabel", y_label="add ylabel" ) .. code-block:: python # more exotic version: color each dot differently from lazyfmri import utils # create color map between red and blue; return as list fig,axs = plt.subplots(figsize=(7,7)) colors = utils.make_between_cm(["r","b], as_list=True, N=len(y_data)) for ix,val in enumerate(y_data): axs.plot(x_data[ix], val, 'o', color=colors[ix], alpha=0.6) # add the regression fit plotting.LazyCorr( x_data, y_data, axs=axs, add_points=False, # turn off points; we've already plotted them x_label="add xlabel", y_label="add ylabel" ) Notes ---------- see documentation of :class:`lazyfmri.plotting.Defaults()` for formatting options """ def __init__( self, data: pd.DataFrame = None, x: Union[str, np.ndarray] = None, y: Union[str, np.ndarray] = None, color: str = "#cccccc", figsize: tuple = (3.54, 3.54), points: bool = True, label: str = None, scatter_kwargs: dict = {}, stat_kwargs: dict = {}, color_by: Union[list, np.ndarray] = None, regression: bool = False, correlation: bool = False, reg_kwargs: dict = {}, error_kwargs: dict = {}, result_to_plot: bool = False, result_loc: tuple = (0.55, 0.1), result_ec: tuple = (1., 0.5, 0.5), result_fc: tuple = (1., 0.8, 0.8), result_dec: int = 2, verbose=False, *args, **kwargs): # init default plotter class super().__init__(**kwargs) self.data = data self.x = x self.y = y self.color = color self.figsize = figsize self.points = points self.label = label self.scatter_kwargs = scatter_kwargs self.stat_kwargs = stat_kwargs self.color_by = color_by self.regression = regression self.correlation = correlation self.reg_kwargs = reg_kwargs self.error_kwargs = error_kwargs self.result_to_plot = result_to_plot self.result_loc = result_loc self.result_fc = result_fc self.result_ec = result_ec self.result_dec = result_dec self.verbose = verbose if self.xkcd: with plt.xkcd(): self.plot() else: self.plot() # run quick regression with pingouin if self.regression: self._run_regression() # run quick correlation with pingouin if self.correlation: self._run_correlation() # print results? if self.verbose: self.print_results() # check if we should add text box if self.result_to_plot: self.add_result_to_plot(*args) # save self._save_figure(self.save_as) def _get_result(self): if hasattr(self, "correlation_result"): met_name = "r" p_name = "p-unc" test = "correlation" result_obj = self.correlation_result meth = result_obj["method"].values[0] elif hasattr(self, "regression_result"): met_name = "r2" p_name = "pval" test = "regression" result_obj = self.regression_result meth = "" p_val = result_obj[p_name][0] met_val = result_obj[met_name][0] return { "p": p_val, "test": test, "metric": met_name, "value": met_val, "method": meth }
[docs] def print_results(self, return_result=False): res = self._get_result() if isinstance(self.x_col, str) and isinstance(self.y_col, str): col_txt = f"\tA={self.x_col}\tB={self.y_col}" else: col_txt = "" if len(res['method']) > 0: meth_txt = f" ({res['method']})" else: meth_txt = "" r, m, v, p = res["test"], res["metric"], res["value"], res["p"] utils.verbose( f"Test={r}{meth_txt}{col_txt} | {m}={round(v,self.result_dec)},\tp={round(p,self.result_dec)}", True) if return_result: return res
[docs] def add_result_to_plot(self): res = self.print_results(return_result=True) self.axs.text( *self.result_loc, f"{res['metric']}={round(res['value'],self.result_dec)}, p={round(res['p'],self.result_dec)}", size=self.font_size * 0.8, bbox=dict(boxstyle="round", ec=self.result_ec, fc=self.result_fc), transform=self.axs.transAxes )
def _run_regression(self): try: import pingouin as pg except BaseException: raise ImportError( "Could not import pingouin, so this functionality is not available") self.regression_result = pg.linear_regression( self.x, self.y, remove_na=True, **self.stat_kwargs ) def _run_correlation(self): try: import pingouin as pg except BaseException: raise ImportError( "Could not import pingouin, so this functionality is not available") # convert to dataframe self.data = pd.DataFrame({"x": self.x, "y": self.y}) self.x = "x" self.y = "y" self.correlation_result = pg.pairwise_corr( self.data, columns=["x", "y"], **self.stat_kwargs )
[docs] def plot(self): # set figure axis self._set_figure_axs() # sort out stuff if input is data if isinstance(self.data, pd.DataFrame): # can reset index on dataframe without index, to try-except try: self.data = self.data.reset_index() except BaseException: pass self.x_col = self.x self.y_col = self.y self.x = self.data[self.x].values self.y = self.data[self.y].values if isinstance(self.color_by, str): self.color_by = self.data[self.color_by].values.astype(float) self.data = None else: self.x_col = self.y_col = None for e in [self.x, self.y]: # convert list to array if isinstance(e, list): e = np.array(e) # enforce float e = e.astype(float) # c-arguments clashes with "color" argument if you pass it to # sns.regplot in "scatter_kws"; hence this solution if isinstance(self.color_by, (list, np.ndarray, pd.DataFrame, pd.Series)): # get array if isinstance(self.color_by, (pd.DataFrame, pd.Series)): self.color_by = self.color_by.values # remove cbar from kwargs if "cbar" in list(self.scatter_kwargs.keys()): add_cbar = self.scatter_kwargs["cbar"] _ = self.scatter_kwargs.pop("cbar") else: add_cbar = True points = self.axs.scatter( self.x, self.y, c=self.color_by, **self.scatter_kwargs ) # set colorbar if add_cbar: self.cbar = plt.colorbar(points) if "label" in list(self.scatter_kwargs.keys()): self.cbar.set_label( self.scatter_kwargs["label"], fontsize=self.font_size, fontname=self.fontname ) # sort out ticks self._set_tick_params(self.cbar.ax) self._set_spine_width(self.cbar.ax) self._set_ticker( self.cbar.ax, self.y_dec, axis="y" ) # remove outside edge from colorbar self.cbar.ax.set_frame_on(False) # set stuff to false/empty for sns.regplot self.points = False self.scatter_kwargs = {} # draw scatter error if desired self._draw_errorbar( x=self.x, y=self.y, ax=self.axs, **self.error_kwargs ) self.kde_color = utils.make_between_cm(self.color, self.color, as_list=True) self.reg_ = sns.regplot( x=self.x, y=self.y, color=self.color, ax=self.axs, scatter=self.points, label=self.label, scatter_kws=self.scatter_kwargs, **self.reg_kwargs ) # sort out ticks self._set_spine_width(self.axs) self._set_tick_params(self.axs) self._set_title(self.axs, self.title) # set ticks loop_funcs = [ "_set_ticks", "_set_ticklabels", "_set_ticker", "_set_axlabel" ] for x in ["x", "y"]: for ff, el in zip( loop_funcs, ["ticks", "ticklabels", "dec", "label"]): add_to_ax = getattr(self, f"{x}_{el}") getattr(self, ff)(self.axs, add_to_ax, axis=x) # draw horizontal/vertical lines with ax?line self._add_line(ax=self.axs) self._add_span(ax=self.axs) # set tickers & despine self._despine(self.axs)
[docs] class LazyBar(): """LazyBar Wrapper around :func:`seaborn.barplot` to follow the same aesthetics of the other Lazy* functions. It is strongly recommended to use a dataframe for this function to make the formatting somewhat easier, but you can input arrays for `x` and `y`. You can round the edges of the bar using `fancy=True`. Parameters ---------- data: pd.DataFrame, optional Input dataframe, by default None x: str, list, np.ndarray, optional Variable for the x-axis, by default None. Can be a column name from `data`, or a list/np.ndarray with labels for input `y`. y: str, list, np.ndarray, optional Variable for the y-axis, by default None. Can be a column name from `data`, or a list/np.ndarray. If `x` is not specified, indices from 0 to `y.shape` will be used to construct the input dataframe. labels: list, np.ndarray, optional custom labels that can be used when `x` denotes a column name in dataframe `data`. The replacing labels should have the same length as the labels that are being overwritten. axs: <AxesSubplot:>, optional Subplot axis to put the plot on, by default None add_points: bool, optional Add the actual datapoints rather than just the bars, by default False. Though default is `False` points_color: str, tuple, optional Color of the points if you do not have nested categories, by default None points_palette: list, sns.palettes._ColorPalette, optional Color palette for the points if you have nested categories (e.g., multiple variables per subject so you can color the individual subjects' data points), by default None points_cmap: str, optional Color map for the points if you did not specify `points_palette`, by default "viridis" points_legend: bool, optional Add legend of the data points (if you have nested categories), by default False. The functionality of these interchangeable legends (`bar_legend` and `points_legend`) is quite tricky, so user discretion is advised. points_alpha: float, optional Alpha of the points, by default 1. Sometimes useful to adjust if you have LOADS of data points error: str, optional Type of error bar to use for the bar, by default "sem". Can be {'sem'|'se'} or {'std'|'sd'. Internally, we'll check if there's enough samples to calculate errors from, otherwise `error` will be set to `None` fancy: bool, optional Flag to round the edges of the bars, by default False. By default, the rounding is scaled by the min/max of the plot, regardless whether `lim` was specified. This ensures equal rounding across inputs. The other `fancy`-arguments below are a bit vague, so leaving them default will ensure nice rounding of the bars fancy_rounding: float, optional Amount of rounding, by default 0.15 fancy_pad: float, optional Vague variable, by default -0.004 fancy_aspect: float, optional Vague variable, by default None. If None, the rounding is scaled by the min/max of the plot, regardless whether `lim` was specified. fancy_denom: int, optional Scaling factor for `fancy_aspect`, by default 4 (which works well for data where the max value is ~50). Use higher values (e.g., 6) if your data range is large bar_legend: bool, optional Legend for the bars, rather than points, by default False. The functionality of these interchangeable legends (`bar_legend` and `points_legend`) is quite tricky, so user discretion is advised. strip_kw, dict, optional Additional kwargs passed on to seaborn's stripplot. Several factors are being set via regular arguments in the function, such as `dodge`, `palette`, `color`, and `hue`. Example ---------- .. code-block:: python # this figure size works well for plots with 2 bars fig,axs = plt.subplots(figsize=(2,8)) plotting.LazyBar( data=df_wm, x="group", y="t1", sns_ori="v", axs=axs, add_labels=True, palette=[con_color,mdd_color], add_points=True, points_color="k", trim_bottom=True, sns_offset=4, y_label2="white matter T1 (ms)", lim=[800,1600], fancy=True, fancy_denom=6 ) Notes ---------- see documentation of :class:`lazyfmri.plotting.Defaults()` for formatting options """ def __init__( self, data: pd.DataFrame = None, x: Union[str, np.ndarray] = None, y: Union[str, np.ndarray] = None, labels: list = None, palette: Union[list, sns.palettes._ColorPalette] = None, cmap: str = "inferno", hue: str = None, figsize=(1.77, 3.54), add_labels: bool = False, lim: list = None, ticks: list = None, add_points: bool = False, points_color: Union[str, tuple] = None, points_palette: Union[list, sns.palettes._ColorPalette] = None, points_cmap: str = "viridis", points_legend: bool = False, points_alpha: float = 1, error: str = "se", ci_mode: str = "auto", ci_low: str = None, ci_high: str = None, ci_kws: dict={}, fancy: bool = False, fancy_rounding: float = 0.15, fancy_pad: float = -0.004, fancy_aspect: float = None, fancy_denom: int = 4, bar_legend: bool = False, lbl_legend: list = None, strip_kw: dict = {}, connect: bool = False, connect_kw: dict = {}, **kwargs ): self.data = data self.x = x self.y = y self.hue = hue self.labels = labels self.palette = palette self.cmap = cmap self.add_labels = add_labels self.lim = lim self.ticks = ticks self.bar_legend = bar_legend self.lbl_legend = lbl_legend self.add_points = add_points self.points_color = points_color self.points_palette = points_palette self.points_cmap = points_cmap self.points_legend = points_legend self.points_alpha = points_alpha self.error = error self.fancy = fancy self.fancy_rounding = fancy_rounding self.fancy_pad = fancy_pad self.fancy_aspect = fancy_aspect self.fancy_denom = fancy_denom self.figsize = figsize self.strip_kw = strip_kw self.connect = connect self.connect_kw = connect_kw self.ci_mode = ci_mode self.ci_low = ci_low self.ci_high = ci_high self.ci_kws = ci_kws self.kw_defaults = Defaults() # avoid that these kwargs are passed down to matplotlib.bar.. Throws # errors ignore_kwargs = [ "trim_left", "trim_bottom", "points_hue", "points_alpha", "bbox_to_anchor", "fancy", "fancy_rounding", "fancy_pad", "fancy_aspect", "fancy_denom", "font_name", "bar_legend", "labels" "strip_kw", "fontname", "add_legend", "connect", "connect_kw" ] kw_sns = {} for ii in kwargs: # filter out non-ls kwargs if ii not in self.kw_defaults.ls_kwargs + ignore_kwargs: kw_sns[ii] = kwargs[ii] else: # overwrite ls-kwargs if ii in self.kw_defaults.ls_kwargs: if not getattr(self.kw_defaults, ii) == kwargs[ii]: setattr(self.kw_defaults, ii, kwargs[ii]) self.__dict__.update(**self.kw_defaults.__dict__) self.__dict__.update(**kwargs) self.kw_defaults.update_rc(self.fontname) if not hasattr(self, "bbox_to_anchor"): self.bbox_to_anchor = None if self.xkcd: with plt.xkcd(): self.plot(**kw_sns) else: self.plot(**kw_sns) # save self.kw_defaults._save_figure(self.save_as) def _resolve_ci_mode(self): if self.ci_mode is None: return None if self.ci_mode == "auto": if isinstance(self.ci_low, str) and isinstance(self.ci_high, str): return "columns" return "raw" if self.ci_mode not in ["raw", "columns"]: raise ValueError("ci_mode must be one of {'auto', 'raw', 'columns', None}") return self.ci_mode def _add_dataframe_cis(self): if not isinstance(self.ci_low, str) or not isinstance(self.ci_high, str): raise ValueError("When ci_mode='columns', ci_low and ci_high must be column names") if self.ci_low not in self.data.columns: raise ValueError(f"Column '{self.ci_low}' not found in dataframe") if self.ci_high not in self.data.columns: raise ValueError(f"Column '{self.ci_high}' not found in dataframe") patches_ = [ p for p in self.ff.patches if p.get_height() != 0 or p.get_width() != 0 ] n_rows = self.data.shape[0] if len(patches_) < n_rows: raise ValueError( f"Found fewer bar patches ({len(patches_)}) than dataframe rows ({n_rows}). " "This can happen when using incompatible input for manual CI drawing." ) value_col = self.y if self.sns_ori == "v" else self.x ci_defs = { "fmt": "none", "ecolor": "black", "capsize": 0, "lw": 1 } for key, val in ci_defs.items(): if key not in self.ci_kws: self.ci_kws = utils.update_kwargs(self.ci_kws, key, val) # Reorder dataframe to match seaborn patch order data_ci = self.data.copy() x_col = self.x if self.sns_ori == "v" else self.y hue_col = self.hue if hue_col is not None: # seaborn draws grouped bar patches hue-first, then x if hasattr(self, "hue_order") and self.hue_order is not None: hue_order = list(self.hue_order) else: hue_order = list(data_ci[hue_col].dropna().unique()) if hasattr(self, "order") and self.order is not None: x_order = list(self.order) else: x_order = list(data_ci[x_col].dropna().unique()) data_ci[x_col] = pd.Categorical( data_ci[x_col], categories=x_order, ordered=True ) data_ci[hue_col] = pd.Categorical( data_ci[hue_col], categories=hue_order, ordered=True ) data_ci = ( data_ci .sort_values([hue_col, x_col]) .reset_index(drop=True) ) else: if hasattr(self, "order") and self.order is not None: x_order = list(self.order) data_ci[x_col] = pd.Categorical( data_ci[x_col], categories=x_order, ordered=True ) data_ci = ( data_ci .sort_values(x_col) .reset_index(drop=True) ) else: data_ci = data_ci.reset_index(drop=True) # Draw CIs for patch, (_, row) in zip(patches_, data_ci.iterrows()): value = row[value_col] low = row[self.ci_low] high = row[self.ci_high] if pd.isna(value) or pd.isna(low) or pd.isna(high): continue err_low = value - low err_high = high - value if err_low < 0 or err_high < 0: raise ValueError( f"CI bounds inconsistent for row: value={value}, " f"low={low}, high={high}. Expected low <= value <= high." ) if self.sns_ori == "v": x_center = patch.get_x() + patch.get_width() / 2 self.ff.errorbar( x=x_center, y=value, yerr=np.array([[err_low], [err_high]]), **self.ci_kws ) else: y_center = patch.get_y() + patch.get_height() / 2 self.ff.errorbar( x=value, y=y_center, xerr=np.array([[err_low], [err_high]]), **self.ci_kws )
[docs] def connect_hue_pairs(self, **kwargs): n_xs = utils.get_unique_ids(self.data, id=self.xx, sort=False) for i, x in enumerate(n_xs): hue_data = self.data.loc[(self.data[self.xx] == x)] self.connect_pairs( hue_data, self.hue, hue_id=i, **kwargs )
[docs] def connect_pairs( self, data, x, hue_id=None, **kwargs): # first assess which elements on the x-axis we have x_inputs = utils.get_unique_ids( data, id=x, sort=False ) # find subsequent pairs pairs = utils.pairwise(x_inputs) connect_kws = { "color": "black", "alpha": 0.1 } for key, val in connect_kws.items(): kwargs = utils.update_kwargs( kwargs, key, val ) # get children of the axis children = self.ff.get_children() # make all strings so we can regex the "collections" child_str_list = [str(i) for i in children] str_collection = utils.get_file_from_substring( "collections", child_str_list) # then index in actual children list real_collection = [ children[child_str_list.index(i)] for i in str_collection] # get correct collections based on hue_id if isinstance(hue_id, int): start_idx = hue_id * len(x_inputs) real_collection = real_collection[start_idx:start_idx + len(x_inputs)] # find collections collections = utils.pairwise(real_collection) for pair, coll in zip(pairs, collections): # find starting index of pair list locs1 = coll[0].get_offsets() locs2 = coll[1].get_offsets() # find subdatasets sets = [data.loc[data[x] == i][self.yy].values for i in pair] # before plotting, we need to sort so that the data points # correspond to each other sort_idxs1 = np.argsort(sets[0]) sort_idxs2 = np.argsort(sets[1]) # revert "ascending sort" through sort_idxs2.argsort(), # and then sort into order corresponding with set1 locs2_sorted = locs2[sort_idxs2.argsort()][sort_idxs1] for i in range(locs1.shape[0]): if self.sns_ori == "v": x_idx, y_idx = 0, 1 else: x_idx, y_idx = 1, 0 x1 = [locs1[i, x_idx], locs2_sorted[i, x_idx]] y1 = [locs1[i, y_idx], locs2_sorted[i, y_idx]] self.ff.plot(x1, y1, **kwargs)
[docs] def plot(self, **kw_sns): # set figure axis self.kw_defaults._set_figure_axs(figsize=self.figsize) # construct dataframe from loose inputs if isinstance(self.y, (np.ndarray, list)): if isinstance(self.y, list): self.y = np.array(self.y) if not isinstance(self.x, (np.ndarray, list)): self.x = np.arange(0, self.y.shape[0]) self.data = pd.DataFrame({"x": self.x, "y": self.y}) self.x = "x" self.y = "y" # check if we should reset the index of dataframe try: self.data = self.data.reset_index() except BaseException: pass # check if we got custom labels if isinstance(self.labels, (np.ndarray, list)): # check if x is multi-level if len(self.data[self.x]) > len(self.labels): # get unique values unique_x = utils.get_unique_ids( self.data, id=self.x, sort=False ) # new xw new_x = [] for x in self.data[self.x].values: # find index of old x in unique x curr_x = unique_x.index(x) # substitute for new x new_x.append(self.labels[curr_x]) self.data[self.x] = new_x else: self.data[self.x] = self.labels if self.sns_ori == "h": self.xx = self.x self.yy = self.y self.trim_bottom = False self.trim_left = True elif self.sns_ori == "v": self.xx = self.x self.yy = self.y self.trim_bottom = True self.trim_left = False else: raise ValueError( f"sns_ori must be 'v' or 'h', not '{self.sns_ori}'") if isinstance(self.color, (str, tuple, list)): if isinstance(self.color, (str, tuple)): self.palette = None self.cmap = None elif isinstance(self.color, list): self.palette = sns.color_palette(palette=self.color) self.color = None else: self.color = None if isinstance(self.palette, list): self.palette = sns.color_palette(palette=self.palette) if not isinstance(self.palette, sns.palettes._ColorPalette): # self.palette = sns.color_palette(self.cmap, self.data.shape[0]) self.palette = self.cmap # allow more input types if isinstance(self.error, str): if self.error.lower() in "sem": self.error = "se" elif self.error.lower() == "std": self.error = "sd" ci_mode = self._resolve_ci_mode() sns_errorbar = self.error if ci_mode == "raw" else None self.ff = sns.barplot( data=self.data, x=self.xx, y=self.yy, ax=self.axs, orient=self.sns_ori, errorbar=sns_errorbar, hue=self.hue, **dict( kw_sns, color=self.color, palette=self.palette )) if ci_mode == "columns": self._add_dataframe_cis() multi_strip = False # need points in order to connect them.. if self.connect: self.add_points = True if self.add_points: if not hasattr(self, "points_hue"): self.points_hue = None if not self.points_palette: self.points_palette = self.points_cmap # give priority to given points_color if isinstance(self.points_color, (str, tuple)): self.points_palette = None self.points_hue = None if isinstance(self.hue, str): if isinstance(self.points_hue, str): if self.points_hue != self.hue: multi_strip = True self.hue_items = list( np.unique(self.data[self.points_hue].values)) if isinstance(self.points_color, (str, tuple)): self.hue_colors = [ self.points_color for ii in range(len(self.hue_items))] else: self.hue_colors = sns.color_palette( self.points_palette, len(self.hue_items)) for it, color in zip(self.hue_items, self.hue_colors): df_per_it = self.data[self.data[self.points_hue] == it] sns.stripplot( data=df_per_it, x=self.xx, y=self.yy, hue=self.hue, dodge=False, palette=[color] * 2, ax=self.ff, **self.strip_kw ) else: multi_strip = True sns.stripplot( data=self.data, x=self.xx, y=self.yy, hue=self.hue, dodge=True, ax=self.ff, color=self.points_color, palette=self.points_palette, alpha=self.points_alpha, **self.strip_kw ) else: sns.stripplot( data=self.data, x=self.xx, y=self.yy, hue=self.points_hue, dodge=False, ax=self.ff, color=self.points_color, palette=self.points_palette, alpha=self.points_alpha, **self.strip_kw ) # connect dots if self.connect: if isinstance(self.hue, str): self.connect_hue_pairs(**self.connect_kw) else: self.connect_pairs( self.data, self.xx, **self.connect_kw ) # set tick params self.kw_defaults._set_tick_params(self.ff) # set spine widths self.kw_defaults._set_spine_width(self.ff) if not self.add_labels: if self.sns_ori == 'h': self.kw_defaults._set_ticks(self.ff, [], axis="y") elif self.sns_ori == "v": self.kw_defaults._set_ticks(self.ff, [], axis="x") else: raise ValueError( f"sns_ori must be 'v' or 'h', not '{self.sns_ori}'") elif isinstance(self.add_labels, list): self.kw_defaults._set_xlabel(self.ff, self.add_labels) if isinstance(self.sns_rot, (int, float)): if self.sns_ori == 'h': self.kw_defaults._set_ticklabels( self.ff, self.ff.get_yticklabels(), rotation=self.sns_rot, fontname=self.fontname, axis="y" ) elif self.sns_ori == "v": self.kw_defaults._set_ticklabels( self.ff, self.ff.get_xticklabels(), axis="x", rotation=self.sns_rot, fontname=self.fontname ) else: raise ValueError( f"sns_ori must be 'v' or 'h', not '{self.sns_ori}'") # set limits depending on orientation self.kw_defaults._set_bar_lim(self.ff, self.lim) # set ticks depending on orientation self.kw_defaults._set_bar_ticks(self.ff, self.ticks) # from: https://stackoverflow.com/a/61569240 if self.fancy: new_patches = [] for patch in self.ff.patches: bb = patch.get_bbox() color = patch.get_facecolor() # max of axis divided by 4 gives nice rounding if not isinstance(self.fancy_aspect, (int, float)): if self.sns_ori == "v": y_limiter = patch._axes.get_ylim()[-1] if isinstance(self.lim, list): y_limiter -= self.lim[0] self.fancy_aspect = y_limiter / self.fancy_denom else: x_limiter = patch._axes.get_xlim()[-1] if isinstance(self.lim, list): x_limiter -= self.lim[0] self.fancy_aspect = x_limiter / self.fancy_denom # make rounding at limit if isinstance(self.lim, list): if self.sns_ori == "v": ymin = self.lim[0] xmin = bb.xmin height = bb.height - ymin width = bb.width else: xmin = self.lim[0] ymin = bb.ymin width = bb.width - xmin height = bb.height else: xmin = bb.xmin ymin = bb.ymin height = bb.height width = bb.width p_bbox = patches.FancyBboxPatch( (xmin, ymin), abs(width), abs(height), boxstyle=f"round,pad={self.fancy_pad},rounding_size={self.fancy_rounding}", ec="none", fc=color, mutation_aspect=self.fancy_aspect ) patch.remove() new_patches.append(p_bbox) for patch in new_patches: self.ff.add_patch(patch) # set xlabel to none of nothing is specified if isinstance(self.x, str) and not isinstance(self.x_label, str): self.ff.set(xlabel=None) if isinstance(self.y, str) and not isinstance(self.y_label, str): self.ff.set(ylabel=None) # set these explicitly; remove left axis is orientation = horizontal | # remove bottom axis if orientation is vertical if hasattr(self, "trim_left"): trim_left = self.trim_left else: trim_left = False if hasattr(self, "trim_bottom"): trim_bottom = self.trim_bottom else: trim_bottom = False # draw horizontal/vertical lines with ax?line self.kw_defaults._add_line(ax=self.ff) self.kw_defaults._add_span(ax=self.ff) # set tickers & despine # set ticks loop_funcs = [ "_set_ticker", "_set_axlabel" ] for x in ["x", "y"]: for ff, el in zip( loop_funcs, ["dec", "label"]): add_to_ax = getattr(self, f"{x}_{el}") getattr(self.kw_defaults, ff)(self.ff, add_to_ax, axis=x) # sort out legend if self.bar_legend or self.points_legend: self.add_legend = True # filter out handles that correspond to labels self.legend_kw = {} for key, val in zip( ["fontsize", "handletextpad", "frameon"], [self.label_size, self.legend_handletext, False]): self.legend_kw[key] = val if isinstance(self.bbox_to_anchor, tuple): self.legend_kw["bbox_to_anchor"] = self.bbox_to_anchor # get handles handles, labels = self.ff.get_legend_handles_labels() # bar legend if self.bar_legend: # do some more exotic stuff to disentangle coloring from bars # and hue if isinstance(self.hue, str): # find categorical handles handles, labels = self.ff.get_legend_handles_labels() # find indices of categorical handles in list cc = self.data[self.hue].values indexes = np.unique(cc, return_index=True)[1] cond = [cc[index] for index in sorted(indexes)] if multi_strip: handles = handles[-len(cond):] labels = labels[-len(cond):] else: if not self.add_points: self.add_legend = False else: self.add_legend = False # fill in legend if self.add_legend: if isinstance(self.lbl_legend, list): labels = self.lbl_legend # brute force if len(handles) == 0: handles = list(self.ff.patches) self.legend = self.ff.legend( handles, labels, **self.legend_kw ) else: self.legend = self.ff.legend([], [], frameon=False) self.kw_defaults._despine( self.ff, left=trim_left, bottom=trim_bottom ) # set title self.kw_defaults._set_title(self.ff, self.title)
[docs] class LazyHist(Defaults): """LazyHist Wrapper around seaborn's histogram plotter Parameters ---------- data: numpy.ndarray Input data for histogram kde: bool, optional Add kernel density plot to histogram with seaborn (https://seaborn.pydata.org/generated/seaborn.kdeplot.html). Default is False hist: bool, optional Add histogram to plot. Default is True fill: bool, optional Fill the area below the kde plot. Default is False bins: str, optional Set bins for histogram; default = "auto" kde_kwargs: dict, optional Additional arguments passed on the seaborn's `kde_plot` hist_kwargs: dict, optional Additional arguments passed on to matplotlib's `hist` fuction Returns ---------- matplotlib.pyplot plot Example ---------- .. code-block:: python from lazyfmri import plotting import matplotlib.pyplot as plt fig,axs = plt.subplots(figsize=(7,7)) plotting.LazyHist( y_data, axs=axs, kde=True, hist=True, fill=False, y_label2="add y_label", x_label2="add x_label", hist_kwargs={"alpha": 0.4}, kde_kwargs={"linewidth": 4} ) Notes ---------- see documentation of :class:`lazyfmri.plotting.Defaults()` for formatting options """ def __init__( self, data, x=None, y=None, figsize=(3.54, 3.54), kde=False, hist=True, bins="auto", fill=False, kde_kwargs={}, hist_kwargs={}, color="#cccccc", cmap=None, fancy: bool = False, fancy_rounding: float = 0.15, fancy_pad: float = -0.004, fancy_aspect: float = None, **kwargs): super().__init__() self.__dict__.update(kwargs) self.update_rc(self.fontname) # read regular arguments self.data = data self.x = x self.y = y self.figsize = figsize self.kde = kde self.kde_kwargs = kde_kwargs self.hist_kwargs = hist_kwargs self.hist = hist self.bins = bins self.fill = fill self.color = color self.kwargs = kwargs self.fancy = fancy self.fancy_rounding = fancy_rounding self.fancy_pad = fancy_pad self.fancy_aspect = fancy_aspect self.cmap = cmap # self.__dict__.update(self.kde_kwargs) if self.xkcd: with plt.xkcd(): self.plot() else: self.plot() if self.kde: try: self.kde_ = self.return_kde() except BaseException: pass # save self._save_figure(self.save_as)
[docs] def return_kde(self): return self.ff.get_lines()[0].get_data()
def _set_color(self): if isinstance(self.cmap, str): self.color = sns.color_palette(self.cmap, 1)[0]
[docs] def plot(self): # overwrite color argument if cmap is specified self._set_color() # set figure axis self._set_figure_axs() if self.hist: if self.kde: # ensure density for KDE self.hist_kwargs = utils.update_kwargs( self.hist_kwargs, "density", True ) self.vals, self.bins, self.patches = self.axs.hist( self.data, bins=self.bins, color=self.color, **self.hist_kwargs ) # from: https://stackoverflow.com/a/61569240 if self.fancy: new_patches = [] for patch in reversed(self.patches): # max of axis divided by 4 gives nice rounding if not isinstance(self.fancy_aspect, (int, float)): self.fancy_aspect = patch._axes.get_ylim()[-1] / 4 bb = patch.get_bbox() color = patch.get_facecolor() p_bbox = patches.FancyBboxPatch( (bb.xmin, bb.ymin), abs(bb.width), abs(bb.height), boxstyle=f"round,pad={self.fancy_pad},rounding_size={self.fancy_rounding}", ec="none", fc=color, mutation_aspect=self.fancy_aspect ) patch.remove() new_patches.append(p_bbox) for patch in new_patches: self.axs.add_patch(patch) if self.kde: # turn off legend by default if "legend" not in list(self.kde_kwargs): self.kde_kwargs["legend"] = False # if not self.hist: self.kde_kwargs = utils.update_kwargs( self.kde_kwargs, "color", self.color, force=True ) self.ff = sns.kdeplot( data=self.data, x=self.x, y=self.y, ax=self.axs, fill=self.fill, **self.kde_kwargs ) # there's no self.ff if kde=False if hasattr(self, "ff"): self.active_axs = self.ff else: self.active_axs = self.axs # set titles self._set_title(self.active_axs, self.title) # set tick params/axis width self._set_tick_params(self.active_axs) self._set_spine_width(self.active_axs) # set limits self._set_xlim(self.active_axs, self.x_lim) self._set_ylim(self.active_axs, self.y_lim) # set tickers & despine loop_funcs = [ "_set_ticks", "_set_ticklabels", "_set_ticker", "_set_axlabel" ] for x in ["x", "y"]: for ff, el in zip( loop_funcs, ["ticks", "ticklabels", "dec", "label"]): add_to_ax = getattr(self, f"{x}_{el}") getattr(self, ff)( self.active_axs, add_to_ax, axis=x, fontname=self.fontname ) # set axis labels if not isinstance(self.x_label, str): self.active_axs.set(xlabel=None) if not isinstance(self.y_label, str): self.active_axs.set(ylabel=None) if hasattr(self, "trim_left"): trim_left = self.trim_left else: trim_left = False if "trim_bottom" in list(self.kwargs.keys()): trim_bottom = self.kwargs["trim_bottom"] else: trim_bottom = False self._despine( self.active_axs, left=trim_left, bottom=trim_bottom ) # draw horizontal/vertical lines with ax?line self._add_line(ax=self.active_axs) self._add_span(ax=self.active_axs) # set title self._set_title(self.active_axs, self.title)
[docs] def conform_ax_to_obj( ax, obj=None, **kwargs ): """conform_ax_to_obj Function to conform any plot to the aesthetics of this plotting module. Can be used when a plot is created with functions other than :class:`lazyfmri.plotting.LazyLine`, :class:`lazyfmri.plotting.LazyCorr`, :class:`lazyfmri.plotting.LazyHist`, or any other function specified in this file. Assumes `ax` is a `matplotlib.axes._subplots.AxesSubplot` object, and `obj` a `lazyfmri.plotting.Lazy*`-object. Parameters ---------- ax: matplotlib.axes._subplots.AxesSubplot input axis that needs to be modified obj: lazyfmri.plotting.Lazy* linecanning-specified plotting object containing the information with which `ax` will be conformed kwargs: dict, optional other elements defined in :class:`lazyfmri.plotting.Defaults`, such as `font_size`, `label_size`, or `axis_width`. Overwrites elements in `obj`, if passed Returns ---------- matplotlib.axes._subplots.AxesSubplot input axis with updated aesthetics Example ---------- .. code-block:: python # convert statsmodels's QQ-plot from lazyfmri import plotting import matplotlib as plt from scipy import stats fig,ax = plt.subplots(figsize=(8,8)) sm.qqplot(mdf.resid, dist=stats.norm, line='s', ax=ax) ax = plotting.conform_ax_to_obj(ax,pl) """ # check if Lazy* object was passed on if obj is None: obj = Defaults() # update with kwargs | do separate if len(kwargs) > 0: for key, val in kwargs.items(): setattr(obj, key, val) obj.update_rc(obj.fontname) # try to read some stuff from the passed axis if not isinstance(ax, mpl.colorbar.Colorbar): if not isinstance(obj.title, (str, dict)): obj.title = ax.get_title() if not isinstance(obj.y_label, str): obj.y_label = ax.get_ylabel() if not isinstance(obj.x_label, str): obj.x_label = ax.get_xlabel() # format ticks try: obj._set_spine_width(ax) except BaseException: pass obj._set_tick_params(ax) obj._set_title(ax, obj.title) loop_funcs = [ "_set_ticks", "_set_ticklabels", "_set_ticker", "_set_axlabel" ] for x in ["x", "y", "z"]: for ff, el in zip( loop_funcs, ["ticks", "ticklabels", "dec", "label"] ): add_to_ax = getattr(obj, f"{x}_{el}") getattr(obj, ff)(ax, add_to_ax, axis=x) # format limits for lim, func in zip( [obj.x_lim, obj.y_lim], [obj._set_xlim, obj._set_ylim]): if isinstance(lim, (float, int, list)): func(ax, lim) # draw horizontal/vertical lines with ax?line obj._add_line(ax=ax) obj._add_span(ax=ax) # despine try: obj._despine(ax) except BaseException: pass return ax, obj
[docs] class LazyColorbar(Defaults): def __init__( self, cmap="magma_r", txt=None, vmin=0, vmax=10, ori="vertical", ticks=None, labels=None, flip_ticks=False, flip_label=False, figsize=(3.54, 0.5), save_as=None, cm_nr=5, cm_decimal=3, cb_kws={}, font_kws={}, **kwargs): self.cmap = cmap self.txt = txt self.vmin = vmin self.vmax = vmax self.ori = ori self.ticks = ticks self.flip_ticks = flip_ticks self.flip_label = flip_label self.figsize = figsize self.save_as = save_as self.cm_nr = cm_nr self.cm_decimal = cm_decimal self.labels = labels self.font_kws = font_kws if self.ori == "vertical": self.figsize = (self.figsize[1], self.figsize[0]) super().__init__(**kwargs) # set figure axis self._set_figure_axs() self.__dict__.update(kwargs) self.update_rc(self.fontname) if self.axs is None: if isinstance(self.save_as, str): self.fig, self.axs = plt.subplots(figsize=self.figsize) else: self.fig, self.axs = plt.subplots(figsize=self.figsize) # make colorbase instance if isinstance(self.cmap, str): self.cmap = mpl.cm.get_cmap(self.cmap, 256) # decide ticks if not isinstance(self.ticks, (np.ndarray, list)): self.ticks = self.colormap_ticks( vmin=self.vmin, vmax=self.vmax, key=self.txt, dec=self.cm_decimal, nr=self.cm_nr ) # plop everything in class mpl.colorbar.Colorbar( self.axs, orientation=self.ori, cmap=self.cmap, norm=mpl.colors.Normalize(vmin, vmax), label=self.txt, ticks=self.ticks, **cb_kws ) if self.ori == "vertical": tick_ax = "y" # set font stuff if self.flip_ticks: self.axs.yaxis.set_ticks_position("left") if self.flip_label: self.axs.yaxis.set_label_position("left") text = self.axs.yaxis.label else: tick_ax = "x" if self.flip_ticks: self.axs.xaxis.set_ticks_position("top") if self.flip_label: self.axs.xaxis.set_label_position("top") text = self.axs.xaxis.label if "color" in list(self.font_kws.keys()): ft_color = self.font_kws["color"] self.font_kws.pop('color') else: ft_color = "black" font = mpl.font_manager.FontProperties( size=self.font_size, **self.font_kws ) text.set_font_properties(font) text.set_color(ft_color) # fix ticks self._set_tick_params( self.axs, labelcolor=ft_color, ) # set tick labels? if isinstance(self.labels, (np.ndarray, list)): self._set_ticklabels( self.axs, self.labels, axis=tick_ax ) # turn off frame self.axs.set_frame_on(False) # save self._save_figure(self.save_as)
[docs] @staticmethod def colormap_ticks( vmin=None, vmax=None, key=None, dec=3, nr=5 ): # store colormaps if isinstance(key, str): if key == "polar" or key == "polar angle" or "polar" in key: ticks = [-np.pi, 0, np.pi] else: ticks = list(np.linspace(vmin, vmax, endpoint=True, num=nr)) else: ticks = list(np.linspace(vmin, vmax, endpoint=True, num=nr)) # round ticks ticks = [round(ii, dec) for ii in ticks] # check if minimum of ticks > minimum of data if ticks[0] < vmin: ticks[0] = utils.round_decimals_up(vmin, dec) # check if maximum of ticks < maximum of data if ticks[-1] > vmax: ticks[-1] = utils.round_decimals_down(vmax, dec) return ticks
[docs] def show(self): fig = plt.figure() new_manager = fig.canvas.manager new_manager.canvas.figure = self.fig self.fig.set_canvas(new_manager.canvas)
[docs] def fig_annot( fig, axs=None, y=1.01, x0_corr=0, x_corr=-0.09, fontsize=28, lower=False, brackets=False, square=False, **kwargs ): # get figure letters if lower: alphabet = list(string.ascii_lowercase) else: alphabet = list(string.ascii_uppercase) if isinstance(axs, list): ax_list = axs else: ax_list = fig.axes # make annotations for ix, ax in enumerate(ax_list): bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) if isinstance(x_corr, (float, int)): use_x_pos = x_corr else: if len(x_corr) != len(ax_list) - 1: raise ValueError( f"""List with 'x_corr'-values must match list with axes-1. x_corr contains {len(x_corr)} elements, while {len(ax_list)-1} axes are specified""") use_x_pos = x_corr[ix - 1] if ix == 0: move_frac = x0_corr / bbox.width else: move_frac = use_x_pos / bbox.width pos = move_frac if isinstance(y, list): if len(y) != len(ax_list): raise ValueError( f"""List with y-values must match list with axes. y contains {len(y)} elements, while {len(ax_list)} axes are specified""") y_pos = y[ix] else: y_pos = y txt = alphabet[ix] if brackets: if square: txt = f"[{txt}]" else: txt = f"({txt})" ax.annotate( txt, (pos, y_pos), fontsize=fontsize, xycoords="axes fraction", **kwargs )
turbo_colormap_data = np.array( [[0.18995, 0.07176, 0.23217], [0.19483, 0.08339, 0.26149], [0.19956, 0.09498, 0.29024], [0.20415, 0.10652, 0.31844], [0.20860, 0.11802, 0.34607], [0.21291, 0.12947, 0.37314], [0.21708, 0.14087, 0.39964], [0.22111, 0.15223, 0.42558], [0.22500, 0.16354, 0.45096], [0.22875, 0.17481, 0.47578], [0.23236, 0.18603, 0.50004], [0.23582, 0.19720, 0.52373], [0.23915, 0.20833, 0.54686], [0.24234, 0.21941, 0.56942], [0.24539, 0.23044, 0.59142], [0.24830, 0.24143, 0.61286], [0.25107, 0.25237, 0.63374], [0.25369, 0.26327, 0.65406], [0.25618, 0.27412, 0.67381], [0.25853, 0.28492, 0.69300], [0.26074, 0.29568, 0.71162], [0.26280, 0.30639, 0.72968], [0.26473, 0.31706, 0.74718], [0.26652, 0.32768, 0.76412], [0.26816, 0.33825, 0.78050], [0.26967, 0.34878, 0.79631], [0.27103, 0.35926, 0.81156], [0.27226, 0.36970, 0.82624], [0.27334, 0.38008, 0.84037], [0.27429, 0.39043, 0.85393], [0.27509, 0.40072, 0.86692], [0.27576, 0.41097, 0.87936], [0.27628, 0.42118, 0.89123], [0.27667, 0.43134, 0.90254], [0.27691, 0.44145, 0.91328], [0.27701, 0.45152, 0.92347], [0.27698, 0.46153, 0.93309], [0.27680, 0.47151, 0.94214], [0.27648, 0.48144, 0.95064], [0.27603, 0.49132, 0.95857], [0.27543, 0.50115, 0.96594], [0.27469, 0.51094, 0.97275], [0.27381, 0.52069, 0.97899], [0.27273, 0.53040, 0.98461], [0.27106, 0.54015, 0.98930], [0.26878, 0.54995, 0.99303], [0.26592, 0.55979, 0.99583], [0.26252, 0.56967, 0.99773], [0.25862, 0.57958, 0.99876], [0.25425, 0.58950, 0.99896], [0.24946, 0.59943, 0.99835], [0.24427, 0.60937, 0.99697], [0.23874, 0.61931, 0.99485], [0.23288, 0.62923, 0.99202], [0.22676, 0.63913, 0.98851], [0.22039, 0.64901, 0.98436], [0.21382, 0.65886, 0.97959], [0.20708, 0.66866, 0.97423], [0.20021, 0.67842, 0.96833], [0.19326, 0.68812, 0.96190], [0.18625, 0.69775, 0.95498], [0.17923, 0.70732, 0.94761], [0.17223, 0.71680, 0.93981], [0.16529, 0.72620, 0.93161], [0.15844, 0.73551, 0.92305], [0.15173, 0.74472, 0.91416], [0.14519, 0.75381, 0.90496], [0.13886, 0.76279, 0.89550], [0.13278, 0.77165, 0.88580], [0.12698, 0.78037, 0.87590], [0.12151, 0.78896, 0.86581], [0.11639, 0.79740, 0.85559], [0.11167, 0.80569, 0.84525], [0.10738, 0.81381, 0.83484], [0.10357, 0.82177, 0.82437], [0.10026, 0.82955, 0.81389], [0.09750, 0.83714, 0.80342], [0.09532, 0.84455, 0.79299], [0.09377, 0.85175, 0.78264], [0.09287, 0.85875, 0.77240], [0.09267, 0.86554, 0.76230], [0.09320, 0.87211, 0.75237], [0.09451, 0.87844, 0.74265], [0.09662, 0.88454, 0.73316], [0.09958, 0.89040, 0.72393], [0.10342, 0.89600, 0.71500], [0.10815, 0.90142, 0.70599], [0.11374, 0.90673, 0.69651], [0.12014, 0.91193, 0.68660], [0.12733, 0.91701, 0.67627], [0.13526, 0.92197, 0.66556], [0.14391, 0.92680, 0.65448], [0.15323, 0.93151, 0.64308], [0.16319, 0.93609, 0.63137], [0.17377, 0.94053, 0.61938], [0.18491, 0.94484, 0.60713], [0.19659, 0.94901, 0.59466], [0.20877, 0.95304, 0.58199], [0.22142, 0.95692, 0.56914], [0.23449, 0.96065, 0.55614], [0.24797, 0.96423, 0.54303], [0.26180, 0.96765, 0.52981], [0.27597, 0.97092, 0.51653], [0.29042, 0.97403, 0.50321], [0.30513, 0.97697, 0.48987], [0.32006, 0.97974, 0.47654], [0.33517, 0.98234, 0.46325], [0.35043, 0.98477, 0.45002], [0.36581, 0.98702, 0.43688], [0.38127, 0.98909, 0.42386], [0.39678, 0.99098, 0.41098], [0.41229, 0.99268, 0.39826], [0.42778, 0.99419, 0.38575], [0.44321, 0.99551, 0.37345], [0.45854, 0.99663, 0.36140], [0.47375, 0.99755, 0.34963], [0.48879, 0.99828, 0.33816], [0.50362, 0.99879, 0.32701], [0.51822, 0.99910, 0.31622], [0.53255, 0.99919, 0.30581], [0.54658, 0.99907, 0.29581], [0.56026, 0.99873, 0.28623], [0.57357, 0.99817, 0.27712], [0.58646, 0.99739, 0.26849], [0.59891, 0.99638, 0.26038], [0.61088, 0.99514, 0.25280], [0.62233, 0.99366, 0.24579], [0.63323, 0.99195, 0.23937], [0.64362, 0.98999, 0.23356], [0.65394, 0.98775, 0.22835], [0.66428, 0.98524, 0.22370], [0.67462, 0.98246, 0.21960], [0.68494, 0.97941, 0.21602], [0.69525, 0.97610, 0.21294], [0.70553, 0.97255, 0.21032], [0.71577, 0.96875, 0.20815], [0.72596, 0.96470, 0.20640], [0.73610, 0.96043, 0.20504], [0.74617, 0.95593, 0.20406], [0.75617, 0.95121, 0.20343], [0.76608, 0.94627, 0.20311], [0.77591, 0.94113, 0.20310], [0.78563, 0.93579, 0.20336], [0.79524, 0.93025, 0.20386], [0.80473, 0.92452, 0.20459], [0.81410, 0.91861, 0.20552], [0.82333, 0.91253, 0.20663], [0.83241, 0.90627, 0.20788], [0.84133, 0.89986, 0.20926], [0.85010, 0.89328, 0.21074], [0.85868, 0.88655, 0.21230], [0.86709, 0.87968, 0.21391], [0.87530, 0.87267, 0.21555], [0.88331, 0.86553, 0.21719], [0.89112, 0.85826, 0.21880], [0.89870, 0.85087, 0.22038], [0.90605, 0.84337, 0.22188], [0.91317, 0.83576, 0.22328], [0.92004, 0.82806, 0.22456], [0.92666, 0.82025, 0.22570], [0.93301, 0.81236, 0.22667], [0.93909, 0.80439, 0.22744], [0.94489, 0.79634, 0.22800], [0.95039, 0.78823, 0.22831], [0.95560, 0.78005, 0.22836], [0.96049, 0.77181, 0.22811], [0.96507, 0.76352, 0.22754], [0.96931, 0.75519, 0.22663], [0.97323, 0.74682, 0.22536], [0.97679, 0.73842, 0.22369], [0.98000, 0.73000, 0.22161], [0.98289, 0.72140, 0.21918], [0.98549, 0.71250, 0.21650], [0.98781, 0.70330, 0.21358], [0.98986, 0.69382, 0.21043], [0.99163, 0.68408, 0.20706], [0.99314, 0.67408, 0.20348], [0.99438, 0.66386, 0.19971], [0.99535, 0.65341, 0.19577], [0.99607, 0.64277, 0.19165], [0.99654, 0.63193, 0.18738], [0.99675, 0.62093, 0.18297], [0.99672, 0.60977, 0.17842], [0.99644, 0.59846, 0.17376], [0.99593, 0.58703, 0.16899], [0.99517, 0.57549, 0.16412], [0.99419, 0.56386, 0.15918], [0.99297, 0.55214, 0.15417], [0.99153, 0.54036, 0.14910], [0.98987, 0.52854, 0.14398], [0.98799, 0.51667, 0.13883], [0.98590, 0.50479, 0.13367], [0.98360, 0.49291, 0.12849], [0.98108, 0.48104, 0.12332], [0.97837, 0.46920, 0.11817], [0.97545, 0.45740, 0.11305], [0.97234, 0.44565, 0.10797], [0.96904, 0.43399, 0.10294], [0.96555, 0.42241, 0.09798], [0.96187, 0.41093, 0.09310], [0.95801, 0.39958, 0.08831], [0.95398, 0.38836, 0.08362], [0.94977, 0.37729, 0.07905], [0.94538, 0.36638, 0.07461], [0.94084, 0.35566, 0.07031], [0.93612, 0.34513, 0.06616], [0.93125, 0.33482, 0.06218], [0.92623, 0.32473, 0.05837], [0.92105, 0.31489, 0.05475], [0.91572, 0.30530, 0.05134], [0.91024, 0.29599, 0.04814], [0.90463, 0.28696, 0.04516], [0.89888, 0.27824, 0.04243], [0.89298, 0.26981, 0.03993], [0.88691, 0.26152, 0.03753], [0.88066, 0.25334, 0.03521], [0.87422, 0.24526, 0.03297], [0.86760, 0.23730, 0.03082], [0.86079, 0.22945, 0.02875], [0.85380, 0.22170, 0.02677], [0.84662, 0.21407, 0.02487], [0.83926, 0.20654, 0.02305], [0.83172, 0.19912, 0.02131], [0.82399, 0.19182, 0.01966], [0.81608, 0.18462, 0.01809], [0.80799, 0.17753, 0.01660], [0.79971, 0.17055, 0.01520], [0.79125, 0.16368, 0.01387], [0.78260, 0.15693, 0.01264], [0.77377, 0.15028, 0.01148], [0.76476, 0.14374, 0.01041], [0.75556, 0.13731, 0.00942], [0.74617, 0.13098, 0.00851], [0.73661, 0.12477, 0.00769], [0.72686, 0.11867, 0.00695], [0.71692, 0.11268, 0.00629], [0.70680, 0.10680, 0.00571], [0.69650, 0.10102, 0.00522], [0.68602, 0.09536, 0.00481], [0.67535, 0.08980, 0.00449], [0.66449, 0.08436, 0.00424], [0.65345, 0.07902, 0.00408], [0.64223, 0.07380, 0.00401], [0.63082, 0.06868, 0.00401], [0.61923, 0.06367, 0.00410], [0.60746, 0.05878, 0.00427], [0.59550, 0.05399, 0.00453], [0.58336, 0.04931, 0.00486], [0.57103, 0.04474, 0.00529], [0.55852, 0.04028, 0.00579], [0.54583, 0.03593, 0.00638], [0.53295, 0.03169, 0.00705], [0.51989, 0.02756, 0.00780], [0.50664, 0.02354, 0.00863], [0.49321, 0.01963, 0.00955], [0.47960, 0.01583, 0.01055]] )
[docs] def RGBToPyCmap(rgbdata): nsteps = rgbdata.shape[0] stepaxis = np.linspace(0, 1, nsteps) rdata = [] gdata = [] bdata = [] for istep in range(nsteps): r = rgbdata[istep, 0] g = rgbdata[istep, 1] b = rgbdata[istep, 2] rdata.append((stepaxis[istep], r, r)) gdata.append((stepaxis[istep], g, g)) bdata.append((stepaxis[istep], b, b)) mpl_data = { 'red': rdata, 'green': gdata, 'blue': bdata } return mpl_data
mpl_data = RGBToPyCmap(turbo_colormap_data) try: plt.register_cmap( name='turbo', data=mpl_data, lut=turbo_colormap_data.shape[0] ) except BaseException: pass mpl_data_r = RGBToPyCmap(turbo_colormap_data[::-1, :]) try: plt.register_cmap( name='turbo_r', data=mpl_data_r, lut=turbo_colormap_data.shape[0] ) except BaseException: pass
[docs] def make_wm_pial_ticks( data, start=0, end=100, step=25, force_int=True ): x_ticks = [0, data.shape[0]//4, data.shape[0]//2, (data.shape[0]//2+data.shape[0]//4), data.shape[0]] x_labels = list(np.arange(start, end*1.1, step=step)) if len(x_ticks) != len(x_labels): raise ValueError( f"Length of ticks ({len(x_ticks)}) {x_ticks} != length of labels ({len(x_labels)}) {x_labels}") if force_int: x_labels = [int(round(i, 0)) for i in x_labels] return { "ticks": x_ticks, "labels": x_labels }
[docs] def annotate_cortical_ribbon( axs, pial_pos=(0.02, 0.92), wm_pos=(0.02, 0.02), lbls=["pial", "wm"], **kwargs ): if "xycoords" not in list(kwargs.keys()): kwargs["xycoords"] = "axes fraction" for pos, tag in zip([pial_pos, wm_pos], lbls): axs.annotate( tag, pos, **kwargs )
[docs] def add_axvspan( axs, loc=(0, 2), color="#cccccc", alpha=0.3, ymin=0, ymax=1, **kwargs ): axs.axvspan( *loc, ymin=ymin, ymax=ymax, alpha=alpha, color=color, **kwargs ) return axs