import string
import numpy as np
from . import utils
import pandas as pd
import seaborn as sns
import matplotlib as mpl
from typing import Union
import matplotlib.pyplot as plt
import matplotlib.patches as patches
[docs]
class Defaults():
"""Defaults
Default settings for plotting.
Parameters
----------
pad_title: int
Set the distance between the title and the plot. Default = 20
title_size: int
Set the font size of titles. Default = 16; you can set it equal to `font_size` (or any other size attribute) by
specifying the attribute name as a string, e.g., "font_size".
font_size: int
Set the font size of axis labels/titles. Default = 14
label_size: int
Set the font size of tick labels. Default = 10
tick_width: float
Set the line-width of the ticks. Default = 0.5
tick_length: float
Set the length of the ticks. Default = 0 (no ticks)
axis_width: float
Set the line-width of axes. Default = 0.5
line_width: int
Line widths for either all graphs (then *int*) or a *list* with the number of elements as requested graphs, default=1.
line_style: str
Set the style of data in line-plots. Default = "solid"
sns_offset: int
Set the distance between y-axis and start of plot. Default = 0
sns_trim: bool
Trim the axes following seaborn's convention. Default = False
sns_bottom: bool
Trim specifically the x-axis of plots. Default = False
sns_ori: str, optional
Default orientation for bar-plots. Default is upright (vertical, "v"). Allowed options are "v" or "h". Specific to
:class:`lazyfmri.plotting.LazyBar()`
sns_rot: int, float, optional
Rotation of labels in bar plot. Specific to :class:`lazyfmri.plotting.LazyBar()`
xkcd: bool
Plot with cartoon style. Default = False
ylim_bottom: float
Set the y-limiter at the bottom of the plot. Default = None
ylim_top: float
Set the y-limiter at the top of the plot. Default = None
xlim_left: float
Set the x-limiter at the left of the plot. Default = None
xlim_right: float
Set the x-limiter at the right of the plot. Default = None
set_xlim_zero: bool
Set the distance between plot and y-axis to 0. Default = False
legend_handletext: float
Set the distance between the handle and text in legends. Default = 0.25
x_label: str, optional
Label of x-axis, by default None
y_label: str, optional
Label of y-axis, by default None
labels: str, list, optional
String (if 1 timeseries) or list (with the length of `ts`) of labels, by default None. Labels for the timeseries to be
used in the legend
title: str, dict, optional
String or dictionary collecting the following keys representing information about the title:
.. code-block:: python
title = {
'title': "some title", # title text
'color': 'k', # color (default = 'k')
'fontweight': "bold" # fontweight (default = 'normal'), can be any of the matplotlib fontweight options
(e.g., 'light', 'bold', 'normal' etc.)
}
color: str, list, optional
String (if 1 timeseries) or list (with the length of `ts`) of colors, by default None. If nothing is specified, we'll
use `cmap` to create a color palette
save_as: str, list, optional
Save the plot, by default None. If you want to use figures in Inkscape, save them as PDFs to retain high resolution;
specify a list of strings to save the plot with multiple extensions
y_lim: list, optional
List for `self._set_ylim`
x_lim: list, optional
List for `self._set_xlim`
x_dec: int, optional
Enforce `x_ticks` to have `x_dec` decimal accuracy. Default is whatever the data dictates
y_dec: int, optional
Enforce `y_ticks` to have `y_dec` decimal accuracy. Default is whatever the data dictates
add_hline: dict, optional
Dictionary for a horizontal line through the plot, by default None. Collects the following items:
.. code-block:: python
add_hline = {
'pos': 0, # position
'color': 'k', # color
'lw': 0.5, # linewidth
'ls': 'dashed' # linestyle
}
You can get the settings above by specifying *add_hline='default'*. Now also accepts *add_hline='mean'* for single
inputs
add_vline: dict, optional
Dictionary for a vertical line through the plot, by default None. Same keys as `add_hline`
dpi: int, optional
Save figures with DPI-value. Default is 300
figure_background_color: str, optional
Background of images. Default is "white"
bbox_inches: str, optional
Bounding box settings. Default is "tight"
"""
def __init__(self, **kwargs):
self.ls_kwargs = [
"pad_title",
"font_size",
"title_size",
"label_size",
"tick_width",
"tick_length",
"axis_width",
"line_width",
"line_style",
"sns_offset",
"sns_trim",
"sns_bottom",
"sns_ori",
"sns_rot",
"sns_despine",
"xkcd",
"ylim_bottom",
"ylim_top",
"xlim_left",
"xlim_right",
"set_xlim_zero",
"legend_handletext",
"x_label",
"y_label",
"title",
"save_as",
"y_lim",
"x_lim",
"x_ticks",
"y_ticks",
"z_ticks",
"x_ticklabels",
"y_ticklabels",
"z_ticklabels",
"axs",
"color",
"y_dec",
"x_dec",
"add_vline",
"add_hline",
"add_vspan",
"add_hspan",
"dpi",
"figure_background_color",
"bbox_inches",
"fontname",
"legend_kwargs"
]
self.pad_title = 20
self.font_size = 14
self.title_size = 16
self.label_size = 10
self.tick_width = 0.5
self.tick_length = 0
self.axis_width = 0.5
self.line_width = 1
self.line_style = "solid"
self.sns_offset = 0
self.sns_trim = False
self.sns_bottom = False
self.sns_ori = "v"
self.sns_rot = None
self.xkcd = False
self.ylim_bottom = None
self.ylim_top = None
self.xlim_left = None
self.xlim_right = None
self.set_xlim_zero = False
self.legend_handletext = 0.25
self.title = None
self.save_as = None
self.axs = None
self.color = None
self.add_vline = None
self.add_hline = None
self.add_vspan = None
self.add_hspan = None
self.dpi = 300
self.figure_background_color = "white"
self.bbox_inches = "tight"
self.fontname = None
self.sns_despine = True
self.legend_kwargs = {}
for i in ["label", "lim", "ticks", "ticklabels", "dec"]:
for ii in ["x", "y", "z"]:
setattr(self, f"{ii}_{i}", None)
# update kwargs
self.__dict__.update(kwargs)
# set default font
if self.xkcd:
self.fontname = "Humor Sans"
else:
if not isinstance(self.fontname, str):
self.fontname = "Montserrat"
# update font widely
self.update_rc(self.fontname)
[docs]
def update_rc(self, font):
    """Set ``font`` as the ``font.family`` entry of matplotlib's rcParams.

    Parameters
    ----------
    font: str
        Font family name written into ``plt.rcParams``.
    """
    rc_changes = {'font.family': font}
    plt.rcParams.update(rc_changes)
def _set_figure_axs(self, figsize=None):
    """Create a figure/axes pair unless ``self.axs`` already is an Axes.

    Parameters
    ----------
    figsize: tuple, optional
        Size passed to ``plt.subplots``. Anything that is not a tuple is
        replaced by ``self.figsize``.
    """
    # an existing axes object is used as-is
    if isinstance(self.axs, mpl.axes._axes.Axes):
        return

    size = figsize if isinstance(figsize, tuple) else self.figsize
    _, self.axs = plt.subplots(figsize=size)
def _set_spine_width(self, ax):
"""update spine width"""
for axis in ['top', 'bottom', 'left', 'right']:
ax.spines[axis].set_linewidth(self.axis_width)
def _set_axlabel(self, ax, lbl, axis="x", **kwargs):
    """_set_axlabel

    Internal function to set the x/y/z-label of a plot.

    Parameters
    ----------
    ax: <AxesSubplot:>
        Axis object that receives the label
    lbl: str, dict
        If `lbl` is a string, it is used as the label text. `lbl` can also be a dictionary with a mandatory
        'label' key holding the text; all remaining keys are forwarded to '<ax>.set_{x|y|z}label':

        .. code-block:: python

            lbl = {
                "label": "this is the actual label",
                "color": "b",
                "fontweight": "bold"
            }

        The dictionary is copied before use, so the caller's object is left untouched. Any other type (e.g.,
        None) makes this function a no-op.
    axis: str
        Which axis to label. Must be one of 'x', 'y', or 'z'
    **kwargs
        Extra keyword arguments for the label setter. Keys present in a `lbl` dictionary take precedence.
        `fontsize` (``self.font_size``) and `fontname` (``self.fontname``) are added through
        ``utils.update_kwargs`` as defaults.
    """
    if not isinstance(lbl, (dict, str)):
        return

    # start from the caller's keyword arguments instead of discarding them
    label_kws = dict(kwargs)
    if isinstance(lbl, dict):
        assert "label" in lbl, f"dictionary must contain a key 'label' with a string representing the actual label, dictionary now has: {lbl.keys()}"

        # merge into a fresh dict so popping 'label' does not alter the caller's dictionary
        label_kws.update(lbl)
        lbl = label_kws.pop("label")

    allowed_axes = ["x", "y", 'z']
    assert axis in allowed_axes, f"axis must be one of {allowed_axes}, not {axis}"
    assert hasattr(ax, f"set_{axis}label"), f"{ax}-object does not have 'set_{axis}label' attribute. Make sure to specify a valid axis object"

    # update with defaults
    defaults_kws = {
        "fontsize": self.font_size,
        "fontname": self.fontname,
    }
    for key, val in defaults_kws.items():
        label_kws = utils.update_kwargs(
            label_kws,
            key,
            val
        )

    # axis was validated above, so the setter exists
    ffunc = getattr(ax, f"set_{axis}label")
    if isinstance(lbl, (str, list)):
        ffunc(
            lbl,
            **label_kws
        )
def _set_tick_params(self, ax, **kwargs):
"""set width/length/labelsize of ticks"""
ax.tick_params(
width=self.tick_width,
length=self.tick_length,
labelsize=self.label_size,
**kwargs
)
def _set_title(self, ax, title):
    """_set_title

    Internal function to set the main title of an axis object.

    Parameters
    ----------
    ax: <AxesSubplot:>
        Axis object that receives the title
    title: str, dict
        A plain string is used as the title text. A dictionary must hold a 'title' key with the text; the other
        keys are forwarded to '<ax>.set_title':

        .. code-block:: python

            title = {
                "title": "this is the actual title",
                "color": "b",
                "fontweight": "bold"
            }

        Missing keys are filled through ``utils.update_kwargs`` with: color 'k', fontweight 'normal', fontsize
        ``self.title_size``, fontname ``self.fontname`` and pad ``self.pad_title``. Any other type (e.g., None) is
        ignored.

    Notes
    -----
    If ``self.title_size`` is a string naming an existing attribute (e.g., "font_size"), it is replaced by that
    attribute's value.
    """
    if not isinstance(title, (str, dict)):
        return

    # resolve a symbolic title size such as "font_size"
    size_ref = self.title_size
    if isinstance(size_ref, str) and hasattr(self, size_ref):
        self.title_size = getattr(self, size_ref)

    assert isinstance(title, (dict, str)), f"title must be a string or a dictionary with a 'title' key representing the title, as well as additional arguments passed to ax.set_title(**kwargs)"

    fallbacks = {
        'color': 'k',
        'fontweight': 'normal',
        "fontsize": self.title_size,
        "fontname": self.fontname,
        "pad": self.pad_title
    }

    # separate the text from the styling keys (on a copy of the dictionary)
    text_kws = {}
    if isinstance(title, dict):
        assert "title" in title, f"dictionary must contain a key 'title' with a string representing the actual title, dictionary now has: {title.keys()}"
        text_kws = title.copy()
        title = text_kws.pop("title")

    # add default keys if they're missing
    for name, fallback in fallbacks.items():
        text_kws = utils.update_kwargs(text_kws, name, fallback)

    ax.set_title(title, **text_kws)
def _set_bar_lim(self, ax, lim):
if isinstance(lim, list):
if self.sns_ori == 'h':
fc = self._set_xlim
elif self.sns_ori == "v":
fc = self._set_ylim
else:
raise ValueError(
f"sns_ori must be 'v' or 'h', not '{self.sns_ori}'")
# set
fc(ax, lim)
def _set_bar_ticks(self, ax, ticks):
if isinstance(ticks, list):
if self.sns_ori == 'h':
fc = "x"
elif self.sns_ori == "v":
fc = "y"
else:
raise ValueError(
f"sns_ori must be 'v' or 'h', not '{self.sns_ori}'")
# set
self._set_ticks(ax, ticks, axis=fc)
@staticmethod
def _set_ticks(ax, ticks, axis="x", **kwargs):
"""set ticks"""
if hasattr(ax, f"set_{axis}ticks"):
if axis == "x":
ffunc = ax.set_xticks
elif axis == "y":
ffunc = ax.set_yticks
else:
ffunc = ax.set_zticks
if isinstance(ticks, (pd.Series, pd.DataFrame)):
ticks = ticks.values
if isinstance(ticks, np.ndarray):
ticks = [float(i) for i in ticks]
if isinstance(ticks, (list)):
# check if elements are output of get_?ticklabels()
if all([isinstance(i, mpl.text.Text) for i in ticks]):
ticks = [float(i._text) for i in ticks]
ffunc(ticks)
@staticmethod
def _set_ticklabels(ax, ticks, axis="x", **kwargs):
if hasattr(ax, f"set_{axis}ticklabels"):
if axis == "x":
ffunc = ax.set_xticklabels
elif axis == "y":
ffunc = ax.set_yticklabels
else:
ffunc = ax.set_zticklabels
if isinstance(ticks, (pd.Series, pd.DataFrame)):
ticks = ticks.values
if isinstance(ticks, np.ndarray):
ticks = [float(i) for i in ticks]
if isinstance(ticks, list):
ffunc(ticks, **kwargs)
@staticmethod
def _set_ylim(ax, lim):
"""set y-limit"""
if isinstance(lim, list):
ax.set_ylim(lim)
elif isinstance(lim, (int, float)):
ax.set_ylim(top=lim)
@staticmethod
def _set_xlim(ax, lim):
"""set x-limit"""
if isinstance(lim, list):
ax.set_xlim(lim)
elif isinstance(lim, (int, float)):
ax.set_xlim(left=lim)
def _despine(self, ax, **kwargs):
    """Despine ``ax`` through ``sns.despine`` when ``self.sns_despine`` is truthy.

    ``offset`` (``self.sns_offset``) and ``trim`` (``self.sns_trim``) are
    merged into ``kwargs`` through ``utils.update_kwargs`` before the call;
    all keyword arguments are forwarded to ``sns.despine``.
    """
    fallbacks = (
        ("offset", self.sns_offset),
        ("trim", self.sns_trim),
    )
    for name, fallback in fallbacks:
        kwargs = utils.update_kwargs(kwargs, name, fallback)

    if not self.sns_despine:
        return

    sns.despine(ax=ax, **kwargs)
@staticmethod
def _set_ticker(ax, dec, axis="x", **kwargs):
"""set all y-ticks to decimal"""
if hasattr(ax, f"{axis}axis"):
if axis == "x":
ffunc = ax.xaxis
elif axis == "y":
ffunc = ax.yaxis
else:
ffunc = ax.zaxis
if isinstance(dec, int):
from matplotlib.ticker import FormatStrFormatter
ffunc.set_major_formatter(FormatStrFormatter(f"%.{dec}f"))
def _set_shaded_error(
self,
x: np.ndarray = None,
tc: np.ndarray = None,
ax: mpl.axes._axes.Axes = None,
yerr: np.ndarray = None,
**kwargs
):
if yerr is None:
return
if not isinstance(yerr, (int, float, list, tuple, np.ndarray)):
return
# scalar symmetric error or vector symmetric error
if np.isscalar(yerr):
ymin = tc - yerr
ymax = tc + yerr
else:
# explicit lower/upper bounds
if isinstance(yerr, (list, tuple)) and len(yerr) == 2:
ymin = np.asarray(yerr[0])
ymax = np.asarray(yerr[1])
else:
yerr = np.asarray(yerr)
if len(yerr) != len(tc):
raise ValueError(
f"Error vector length ({len(yerr)}) does not match data length ({len(tc)})"
)
if np.all(np.isnan(yerr)):
raise TypeError("Error contains only NaNs")
ymin = tc - yerr
ymax = tc + yerr
if len(ymin) != len(tc) or len(ymax) != len(tc):
raise ValueError(
f"Shaded bounds must match data length ({len(tc)}); got {len(ymin)} and {len(ymax)}"
)
ax.fill_between(
x,
ymin,
ymax,
**kwargs
)
def _draw_errorbar(
self,
x: np.ndarray = None,
y: np.ndarray = None,
ax: mpl.axes._axes.Axes = None,
yerr: np.ndarray = None,
xerr: np.ndarray = None,
**kwargs
):
# set bunch of defaults
if "linewidth" not in list(
kwargs.keys()) and "lw" not in list(kwargs.keys()):
kwargs["lw"] = self.line_width
if "color" not in list(kwargs.keys()):
kwargs["ecolor"] = "k"
if "zorder" not in list(kwargs.keys()):
kwargs["zorder"] = 0
if "linestyle" not in list(
kwargs.keys()) and "ls" not in list(kwargs.keys()):
kwargs["ls"] = "None"
if isinstance(x, (pd.Series, pd.DataFrame)):
x = x.values
if isinstance(y, (pd.Series, pd.DataFrame)):
y = y.values
if isinstance(yerr, (pd.Series, pd.DataFrame)):
yerr = yerr.values
if isinstance(xerr, (pd.Series, pd.DataFrame)):
xerr = xerr.values
ax.errorbar(
x,
y,
yerr=yerr,
xerr=xerr,
**kwargs
)
def _set_legend_labels(self, ax, labels=None):
if isinstance(labels, (list, np.ndarray)):
if "font_size" not in list(self.legend_kwargs.keys(
)) and "fontsize" not in list(self.legend_kwargs.keys()):
self.legend_kwargs["fontsize"] = self.label_size
ax.legend(
frameon=False,
handletextpad=self.legend_handletext,
**self.legend_kwargs,
)
def _save_as(self, save_as, **kwargs):
    """Save the current figure to ``save_as`` via ``plt.savefig``.

    Nothing is written unless ``save_as`` is a string. The bounding box,
    DPI and face color come from ``self.bbox_inches``, ``self.dpi`` and
    ``self.figure_background_color``; ``**kwargs`` are forwarded to
    ``plt.savefig``.
    """
    if not isinstance(save_as, str):
        return

    save_settings = {
        "bbox_inches": self.bbox_inches,
        "dpi": self.dpi,
        "facecolor": self.figure_background_color,
    }
    plt.savefig(save_as, **save_settings, **kwargs)
def _save_figure(self, save_as):
"""save same figure with multiple extensions"""
if isinstance(save_as, (list, str)):
if isinstance(save_as, str):
save_as = [save_as]
for ii in save_as:
self._save_as(ii)
def _return_element(self, ddict, el, ix=0):
if isinstance(ddict[el], list):
return ddict[el][ix]
else:
return ddict[el]
def _add_span(
    self,
    ax=None
):
    """Add shaded horizontal and/or vertical spans to an Axes.

    Reads the configuration from ``self.add_hspan`` and ``self.add_vspan`` and draws the regions with
    ``ax.axhspan`` (horizontal band) and/or ``ax.axvspan`` (vertical band). Attributes that are neither a list
    nor a dictionary are skipped. The configuration objects are never modified, so the method can be called
    repeatedly.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes to draw on. Required; drawn in place.

    Configuration (via ``self.add_hspan`` / ``self.add_vspan``)
    -----------------------------------------------------------
    - ``list``: two-element list ``[start, stop]`` in **data coordinates** along the oriented axis (y for
      horizontal spans, x for vertical spans).
    - ``dict``: must contain a ``'loc'`` key with ``[start, stop]`` as above. Recognized keys:

      loc : list of 2 floats, **required**
          ``ymin/ymax`` for ``add_hspan``; ``xmin/xmax`` for ``add_vspan``.
      min : float, optional
          Lower bound along the orthogonal axis in axes fraction [0, 1] (``xmin`` for ``axhspan``, ``ymin`` for
          ``axvspan``). Default = 0.
      max : float, optional
          Upper bound along the orthogonal axis in axes fraction [0, 1]. Default = 1.
      color : str, optional
          Face color of the span. Default = '#cccccc'.
      alpha : float, optional
          Opacity of the span. Default = 0.3.
      other keys
          Forwarded to ``axhspan``/``axvspan`` (e.g., ``zorder``, ``edgecolor``).

    Defaults are merged with the user's keys through ``utils.update_kwargs``. ``min``/``max`` are resolved
    separately for each span; a value given for one orientation does not carry over to the other.

    Returns
    -------
    None
        Draws on ``ax`` in place.

    Examples
    --------
    .. code-block:: python

        # full-width horizontal band between y=0.5 and y=1.0
        self.add_hspan = [0.5, 1.0]
        self._add_span(ax)

        # vertical band from x=2 to x=3 covering the central 30-90% of the axis height
        self.add_vspan = {'loc': [2, 3], 'min': 0.30, 'max': 0.90, 'color': 'tab:blue', 'alpha': 0.15}
        self._add_span(ax)
    """
    style_defaults = {
        "color": "#cccccc",
        "alpha": 0.3,
    }

    for ii, ffunc in zip(["hspan", "vspan"], [ax.axhspan, ax.axvspan]):
        spec = getattr(self, f"add_{ii}")
        if not isinstance(spec, (list, dict)):
            continue

        # orthogonal extent is reset for every span so settings do not leak
        # from the horizontal into the vertical pass
        lower, upper = 0, 1

        if isinstance(spec, dict):
            assert "loc" in spec, f"Input must contain a 'loc'-key, representing a [start, stop] list. Dictionary now contains: {spec.keys()}"

            # copy: popping from the user's dictionary would break the next call
            span_kws = dict(spec)
            loc = span_kws.pop("loc")
            lower = span_kws.pop("min", lower)
            upper = span_kws.pop("max", upper)
        else:
            span_kws = {}
            loc = spec

        if ii == "hspan":
            geometry = {"xmin": lower, "xmax": upper, "ymin": loc[0], "ymax": loc[1]}
        else:
            geometry = {"xmin": loc[0], "xmax": loc[1], "ymin": lower, "ymax": upper}

        # update with defaults
        for key, val in {**style_defaults, **geometry}.items():
            span_kws = utils.update_kwargs(
                span_kws,
                key,
                val
            )

        # plop into function
        ffunc(**span_kws)
def _add_line(
self,
ax=None,
**kwargs
):
"""_add_line
Add horizontal and/or vertical reference lines to an Axes.
Reads configuration from the instance attributes :pyattr:`self.add_hline` and
:pyattr:`self.add_vline` and draws one or more ``axhline``/``axvline`` objects
with sensible defaults. Each attribute can be provided in several shorthand
forms (see below). Optionally, labels can be attached to each line.
Parameters
----------
ax : matplotlib.axes.Axes, optional
Target axes to draw on. Must be provided; no internal fallback. Default = None.
**kwargs
Currently unused; accepted for API compatibility.
Configuration (via ``self.add_hline`` / ``self.add_vline``)
-----------------------------------------------------------
Each of these attributes may be:
- ``"default"`` :
Draw a single line at position ``0`` using defaults.
- ``float`` or ``int`` :
Draw a single line at the given data position.
- ``list`` or ``numpy.ndarray`` of numbers:
Draw multiple lines at the given positions.
- ``dict`` :
Full control via keys described below.
Dictionary keys
----------------
The following keys are recognized when the attribute is a ``dict``. Scalars
apply to all lines; lists/arrays are indexed per position.
pos : float, int, list, or numpy.ndarray, **required**
Position(s) of the line(s) in data coordinates (x for vertical, y for horizontal).
If omitted, a ``ValueError`` is raised.
color : str or list, optional
Line color(s). Default = 'k'.
ls : str or list, optional
Line style(s), e.g. 'dashed'. Default = 'dashed'.
lw : float or list, optional
Line width(s). Default = 0.5.
min : float or list, optional
For horizontal lines this maps to ``xmin`` (axes fraction in [0, 1]);
for vertical lines to ``ymin``. Default = 0.
max : float or list, optional
For horizontal lines this maps to ``xmax``; for vertical lines to ``ymax``.
Default = 1.
add_label : list, dict or any, optional
If this key is present, a label is added at each line via
:meth:`add_label_to_line`. When ``add_label`` is a ``dict``, its contents
are forwarded as keyword arguments to ``add_label_to_line`` (e.g.,
``{'lbl': 'threshold', 'l_max': 0.9}``). If ``add_label`` is present but
not a dict, a label is still added with default settings and the text
defaults to the line position.
If `lbl`-key contains `=pos`, this will be substituted for the actual value of ``pos``
For further customization, generate the list of dictionaries before passing them to
the function
Defaults
--------
- color = 'k'
- ls = 'dashed'
- lw = 0.5
- min = 0, max = 1 (interpreted as axes fractions; mapped to ``xmin/xmax`` or
``ymin/ymax`` depending on orientation)
Behavior
--------
- Lines are drawn with :meth:`matplotlib.axes.Axes.axhline` (for ``add_hline``)
or :meth:`matplotlib.axes.Axes.axvline` (for ``add_vline``).
- Positions provided as scalars are internally normalized to a list so that
per-line styling can be broadcast via lists/arrays of the same length.
- If a provided style key is a list/array, the *i*-th element is used for the
*i*-th line; otherwise the scalar value is reused for all lines.
Returns
-------
None
This method draws on ``ax`` in-place and does not return a value.
Examples
--------
How to call with :class:`lazyfmri.plotting.LazyLine()`
.. code-block:: python
# simple line at 0
pl = plotting.LazyLine(
<some data>,
xx=<some axis>,
figsize=(5,5),
add_vline=5 # relative to <some axis>
)
.. code-block:: python
# list of positions
pl = plotting.LazyLine(
<some data>,
xx=<some axis>,
figsize=(5,5),
add_vline=[0, 2, 10] # relative to <some axis>
)
# list of annotations ('pos' in lbl will be subsituted for the actual 'pos' value)
pl = plotting.LazyLine(
<some data>,
xx=<some axis>,
figsize=(5,5),
add_vline={
"pos": [0, 1, 4],
"add_label": [
{
"lbl": "t=pos", # t=0
"l_max": 0.15,
},
{
"lbl": "t=pos", # t=1
"l_max": 0.15,
},
{
"lbl": "t=pos", # t=4
"l_max": 0.15,
}
]
}
)
.. code-block:: python
# full control + annotation
pl = plotting.LazyLine(
<some data>,
xx=<some axis),
figsize=(5,5),
add_vline={
"pos": 0,
"add_label": {
"lbl": "onset",
"l_max": 0.15,
"fontweight": "bold",
"color": "r",
"fontsize": 20
}
}
)
"""
for ii in ["hline", "vline"]:
line_kw = {}
test_attr = getattr(self, f"add_{ii}")
if isinstance(test_attr, (float, int, dict, str, list)):
add_line = True
# define default dictionary
default_dict = {
'color': 'k',
'ls': 'dashed',
'lw': 0.5
}
# set fixer depending on line being drawn
if ii == "hline":
default_dict["min"] = 0
default_dict["max"] = 1
else:
default_dict["min"] = 0
default_dict["max"] = 1
# add lines
if test_attr == "default":
test_attr = {'pos': 0}
elif isinstance(test_attr, (float, int, list, np.ndarray)):
test_attr = {"pos": test_attr}
elif isinstance(test_attr, dict):
add_line = True
else:
add_line = False
if add_line:
for key in list(default_dict.keys()):
if key not in list(test_attr.keys()):
test_attr[key] = default_dict[key]
# enforce list so we only need to call functions once
if "pos" not in list(test_attr.keys()):
raise ValueError("Need the 'pos' key to denote position..")
else:
if isinstance(test_attr['pos'], (int, float)):
test_attr['pos'] = [test_attr['pos']]
# loop through elements
if isinstance(test_attr['pos'], (list, np.ndarray)):
for ix, pos in enumerate(test_attr['pos']):
# define function
kwargs_list = ["color", "lw", "ls"]
if ii == "hline":
ffunc = ax.axhline
kwargs_list += ["xmin", "xmax"]
else:
ffunc = ax.axvline
kwargs_list += ["ymin", "ymax"]
# update kwargs
for key, val in zip(
["color", "lw", "ls", "min", "max"],
kwargs_list
):
line_kw[val] = self._return_element(test_attr, key, ix=ix)
# run func
kwargs = {}
ffunc(
pos,
**line_kw,
**kwargs
)
# check if we should add label
if "add_label" in test_attr:
lbl = test_attr["add_label"]
if not isinstance(lbl, list):
lbl = [lbl]
# ensure lists are the same length
if len(test_attr['pos']) != len(lbl):
lbl = [lbl[0] for _ in range(len(test_attr['pos']))]
lbl_kws = {}
if isinstance(lbl[ix], dict):
lbl_kws = lbl[ix]
# set 'pos' if not specified in add_labels; allows custom position
if not "pos" in lbl_kws:
lbl_kws["pos"] = pos
self.add_label_to_line(
ax,
ori=ii,
**lbl_kws
)
[docs]
def add_label_to_line(
    self,
    ax,
    pos=None,
    lbl=None,
    ori="vline",
    l_max=0.85,
    boxstyle="round",
    pad=0.2,
    fc="white",
    ec="white",
    alpha=1,
    **kwargs
):
    """Add a label to a vertical or horizontal reference line.

    The text is placed with a mixed data/axes transform: for vertical lines *x* is in data units and *y* in
    axes coordinates; for horizontal lines *y* is in data units and *x* in axes coordinates. A text box is
    drawn behind the label.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes to draw on.
    pos : float
        Line position in data coordinates. Interpreted as *x* when `ori='vline'` and as *y* otherwise.
    lbl : str or float, optional
        Label text. If `None`, the value of `pos` is used. Otherwise the value is converted to a string and
        every occurrence of ``"=pos"`` is replaced by ``"=<pos>"``. Default = None.
    ori : {'vline', 'hline'}, optional
        Orientation of the labeled line. `'vline'` rotates the text by 90 degrees and uses
        ``ax.get_xaxis_transform()``; any other value uses no rotation and ``ax.get_yaxis_transform()``.
        Default = 'vline'.
    l_max : float, optional
        Position along the line in axes coordinates [0, 1]. Default = 0.85.
    boxstyle : str, optional
        Box style name; combined as ``f"{boxstyle},pad={pad}"``. Default = 'round'.
    pad : float, optional
        Padding passed to the box style. Default = 0.2.
    fc : str, optional
        Face color of the box. Default = 'white'.
    ec : str, optional
        Edge color of the box. Default = 'white'.
    alpha : float, optional
        Transparency of the box. Default = 1.
    **kwargs
        Forwarded to ``ax.text``. A `bbox` dictionary is copied and completed with the box settings above;
        ``ha='center'``, ``va='center'`` and ``rotation`` are added as defaults. Defaults are merged through
        ``utils.update_kwargs``. The text is drawn with ``clip_on=False`` and ``zorder=3``.

    Returns
    -------
    matplotlib.axes.Axes
        The same `ax` object.

    Examples
    --------
    .. code-block:: python

        # example argument passed on to :class:`lazyfmri.plotting.LazyLine()`
        add_vline={
            "pos": [0, 1, 4],
            "add_label": {
                "lbl": "t=pos", # t=0, t=1, t=4
                "l_max": 0.15
            }
        }
    """
    if ori == "vline":
        rot = 90
        coord = (pos, l_max)

        # x in data units, y in axes [0..1]
        tfm = ax.get_xaxis_transform()
    else:
        rot = 0
        coord = (l_max, pos)

        # y in data units, x in axes [0..1]
        tfm = ax.get_yaxis_transform()

    # text box kwargs
    bbox_defaults = {
        "boxstyle": f"{boxstyle},pad={pad}",
        "fc": fc,
        "ec": ec,
        "alpha": alpha,
    }

    # position kwargs
    defaults = {
        "ha": "center",
        "va": "center",
        "rotation": rot
    }

    # copy the box settings so a dictionary reused by the caller is not altered
    bbox_kws = dict(kwargs.get("bbox") or {})
    for key, val in bbox_defaults.items():
        bbox_kws = utils.update_kwargs(
            bbox_kws,
            key,
            val
        )
    kwargs["bbox"] = bbox_kws

    # add regular defaults
    for key, val in defaults.items():
        kwargs = utils.update_kwargs(
            kwargs,
            key,
            val
        )

    # set label to pos if not specified
    if lbl is None:
        lbl = pos
    else:
        # numeric labels are allowed; only strings have .replace
        lbl = str(lbl).replace("=pos", f"={str(pos)}")

    ax.text(
        coord[0], coord[1],
        lbl,
        transform=tfm,
        clip_on=False,
        zorder=3,
        **kwargs
    )

    return ax
[docs]
class LazyPRF(Defaults):
"""LazyPRF
Plot the geometric location of the Gaussian pRF.
Parameters
----------
prf: numpy.ndarray
instantiation of `gauss2D_iso_cart`; will be np.squeeze'ed over the first axis if `ndim >= 3`.
vf_extent: list
the space the pRF lives in
cmap: str, optional
Colormap for imshow; accepts output from :func:`lazyfmri.utils.make_binary_cm`. Defaults to 'RdBu_r'
cross_color: str, optional
Color for the fixation cross; defaults to 'white'. You can set it to 'k' if you have a binary colormap as input
alpha: float, optional
Opacity for imshow
shrink_factor: float, optional
When the background of the image is white, we create a black border around the Circle patch. If this is equal to
`vf_extent`, the border is cut off at some points. This factor shrinks the radius of the Circle, so that we can have a
nice border. When set to 0.9, it becomes sort of like a target. This is relevant for **all** non-`magma` color maps
that you insert, specifically a :func:`lazyfmri.utils.make_binary_cm` object
full_axis: bool, optional
If `True`, the entire axis of `vf_extent` will be used for the ticks (more detailed). If `False`, a truncated/trimmed
version will be returned (looks cleaner). Default = False
axis_off: bool, optional
If `True`, the axis will be turned off. If `False`, the x/y axis will be maintained and styled with the tick/spine
settings. If `axis_off=True`, then `full_axis` and other label/axis parameters are ignored. Default = True
vf_only: bool, optional
Only show the outline of the visual field, without pRF. You still need to specify the pRF as we'll `imshow` an
empty array with the same shape rather than the pRF. Default = False
line_width: float, optional
Width of the outer border of the visual field, which is mainly relevant if `cmap` is not *viridis* or *magma* (these
color maps are quite default, and do not require an extra border like :func:`lazyfmri.utils.make_binary_cm`-objects
do). Default is 1 (inherited from :class:`lazyfmri.plotting.Defaults()`).
cross_width: float, optional
Width of the cross denoting the x/y axis. Default is 0.5, but can be increased if `cmap` is not *viridis* or *magma*
to enhance visibility
z_lines: int, optional
Set the order of the vertical/horizontal lines. Default is **on top** of the pRF (1)
z_prf: int, optional
Set the order of the pRF imshow. Default is below the axis lines, but can be changed to be on top of them. Default = 0
imshow_kw: dict, optional
Additional kwargs passed on to `imshow`
Returns
----------
matplotlib.pyplot plot
"""
def __init__(
    self,
    prf,
    vf_extent,
    cmap='RdBu_r',
    cross_color="white",
    alpha=None,
    shrink_factor=1,
    axis_off=True,
    figsize=(3.54, 3.54),
    full_axis=False,
    vf_only=False,
    cross_width=0.5,
    concentric=None,
    z_lines=1,
    z_prf=0,
    edge_color=None,
    imshow_kw=None,
        **kwargs):
    """Store the settings, draw the pRF and optionally save the figure.

    All named arguments are stored as attributes of the same name. The remaining ``**kwargs`` are written
    into the instance after the defaults from :class:`Defaults` were initialized, so they override those
    defaults (e.g., ``title``, ``save_as``, ``axs``, ``line_width``). The figure is drawn by :meth:`plot`
    (inside ``plt.xkcd()`` if ``xkcd`` is truthy) and saved through ``self._save_figure(self.save_as)``.

    Parameters
    ----------
    imshow_kw: dict, optional
        Additional kwargs passed on to `imshow`. ``None`` (default) means no extra arguments; a new dictionary
        is created for every instance.
    """
    self.prf = prf
    self.vf_extent = vf_extent
    self.cmap = cmap
    self.cross_color = cross_color
    self.alpha = alpha
    self.shrink_factor = shrink_factor
    self.axis_off = axis_off
    self.figsize = figsize
    self.full_axis = full_axis
    self.vf_only = vf_only
    self.cross_width = cross_width
    self.concentric = concentric
    self.z_lines = z_lines
    self.z_prf = z_prf
    self.edge_color = edge_color

    # a `{}` default in the signature would be shared between all instances
    self.imshow_kw = {} if imshow_kw is None else imshow_kw

    # initialize the plotting defaults, then let the user's kwargs override them
    super().__init__()
    self.__dict__.update(kwargs)

    # the font may have been changed by kwargs
    self.update_rc(self.fontname)

    if self.xkcd:
        with plt.xkcd():
            self.plot()
    else:
        self.plot()

    # save
    self._save_figure(self.save_as)
[docs]
def plot(self):
    """Draw the pRF, clipped to a circular visual field, on ``self.axs``.

    The image is shown with a colour range that is symmetric around zero, a cross-hair through the origin is
    requested via ``add_hline``/``add_vline``, and a circle patch centred on (0, 0) is used as clip path for
    the image. When ``axis_off`` is `False`, tick parameters, spine widths, ticks, and despining are
    delegated to the helper methods called below.
    """
    # set figure axis
    self._set_figure_axs()

    # `np.squeeze(..., axis=0)` removes a leading axis of length 1 (and raises if it has another length)
    if self.prf.ndim >= 3:
        self.prf = np.squeeze(self.prf, axis=0)

    if self.alpha is None:
        self.alpha = 1

    # add cross-hair: one settings dict per line so the two attributes never share an object
    for ii in ["hline", "vline"]:
        self.line_kw = {
            "pos": 0,
            "color": self.cross_color,
            "lw": self.cross_width
        }
        setattr(self, f"add_{ii}", self.line_kw)

    self._add_line(
        self.axs,
        zorder=self.z_lines
    )

    # with `vf_only` an all-zero image of the same shape is shown instead of the pRF
    if not self.vf_only:
        plot_obj = self.prf
    else:
        plot_obj = np.zeros_like(self.prf)

    # colour limits symmetric around zero, driven by the minimum if the image has negatives
    lowest = plot_obj.min()
    if lowest < 0:
        vmin, vmax = lowest, -lowest
    else:
        highest = plot_obj.max()
        vmin, vmax = -highest, highest

    # a 2-element extent is repeated for the second axis; going through `list` keeps this a concatenation
    # for tuples and numpy arrays as well (`array + array` would add the values element-wise instead)
    if len(self.vf_extent) < 4:
        self.use_extent = list(self.vf_extent) * 2
    else:
        self.use_extent = self.vf_extent

    im = self.axs.imshow(
        plot_obj,
        extent=self.use_extent,
        cmap=self.cmap,
        alpha=self.alpha,
        zorder=self.z_prf,
        vmin=vmin,
        vmax=vmax,
        **self.imshow_kw)

    # In case of a white background, the circle for the visual field is cut
    # off, so we need to make an adjustment:
    if self.cmap != 'magma' and self.cmap != 'viridis':
        radius = self.use_extent[-1] * self.shrink_factor
    else:
        radius = self.use_extent[-1]

    # set title
    self._set_title(self.axs, self.title)

    # set patch
    self.patch = patches.Circle(
        (0, 0),
        radius=radius,
        transform=self.axs.transData,
        edgecolor=self.edge_color,
        facecolor="None",
        linewidth=self.line_width)

    self.axs.add_patch(self.patch)
    im.set_clip_path(self.patch)

    if self.axis_off:
        self.axs.axis('off')
    else:
        # set tick params
        self._set_tick_params(self.axs)

        # set spine widths
        self._set_spine_width(self.axs)

        # NOTE(review): `use_ticks` is stored but not read anywhere in this method; the loop below takes
        # its ticks from `self.x_ticks`/`self.y_ticks` - confirm whether `use_ticks` should feed them
        if self.full_axis:
            self.use_ticks = np.arange(
                self.vf_extent[0], self.vf_extent[1] + 1, 1)
        else:
            self.use_ticks = self.vf_extent

        # set ticks
        loop_funcs = [
            "_set_ticks",
            "_set_ticklabels",
            "_set_ticker"
        ]

        for x in ["x", "y"]:
            for ff, el in zip(
                    loop_funcs,
                    ["ticks", "ticklabels", "dec"]):
                add_to_ax = getattr(self, f"{x}_{el}")
                getattr(self, ff)(self.axs, add_to_ax, axis=x)

        self._despine(self.axs)
[docs]
class LazyLine(Defaults):
    """LazyLine

    Class for plotting because I'm lazy and I don't want to go through the ``matplotlib`` motions every time I quickly
    want to visualize something. This class makes that a lot easier. It allows single inputs, lists with multiple
    timecourses, labels, error shadings, and much more.

    Parameters
    ----------
    ts: list, numpy.ndarray
        Input data. Can either be a single array, or a list of multiple numpy arrays. If you want labels, custom colors,
        or error bars, these inputs must come in lists of similar length as ``ts``!
    xx: list, numpy.ndarray, range, pandas.Series, pandas.DataFrame, optional
        X-axis array. A list with as many elements as ``ts`` is read as one x-axis per timecourse. If not specified,
        ``np.arange(len(timecourse))`` is used
    error: list, numpy.ndarray, int, float, optional
        Error data with the same length/shape as the input timeseries, by default None. Can be either a numpy.ndarray
        for 1 timeseries, or a list of numpy.ndarrays for multiple timeseries. A single number is used for every
        timecourse. An element that is a 2-element list/tuple is read as explicit ``(lower, upper)`` bounds, which must
        have equal shapes. Elements that consist of NaNs only are ignored
    error_alpha: float, optional
        Opacity level for error shadings, by default 0.3
    figsize: tuple, optional
        Figure dimensions as per usual matplotlib conventions, by default (7.1, 3.54)
    cmap: str, optional
        Color palette to use for colors if no individual colors are specified, by default 'viridis'
    labels: list, numpy.ndarray, optional
        Labels for the elements in ``ts``, by default None
    markers: str, list, optional
        Use markers during plotting. A single option will be applied to all elements in ``ts``. A separate list for each
        element of ``ts`` is also accepted. If one array in ``ts`` should not have markers, use `None`.
        E.g., if `len(ts) == 3`, and we want only the first timecourse to have markers use: ``markers=['.',None,None]``
    markersize: int, float, list, optional
        Specify marker sizes during plotting. A single option will be applied to all elements in ``ts``. A separate list
        for each element of ``ts`` is also accepted. If one array in ``ts`` should use the default size, use `None`.
        E.g., if ``len(ts) == 3``, and we want only the first timecourse to have larger markers use:
        ``markersize=[10,None,None]``
    markerfc: str, list, optional
        Specify marker facecolor during plotting. A single option will be applied to all elements in ``ts``. A separate
        list for each element of ``ts`` is also accepted
    markerec: str, list, optional
        Specify marker edgecolor during plotting. A single option will be applied to all elements in ``ts``. A separate
        list for each element of ``ts`` is also accepted
    plot_alpha: int, float, list, optional
        Opacity of the lines. A single number will be applied to all elements in ``ts``. A separate list for each
        element of ``ts`` is also accepted. By default None, which translates to 1 for every line
    plot_kw: dict, optional
        Further plotting options passed to ``matplotlib.pyplot.plot``, such as ``markerfacecolor`` (removes fill from
        markers, leaving only the edges), and ``markeredgewidth`` (sets the width of the marker outline). These options
        are passed to **every** element in ``ts``.
    x_ticks: list, optional
        Locations where to put the ticks on the x-axis
    y_ticks: list, optional
        Locations where to put the ticks on the y-axis

    Example
    ----------
    .. code-block:: python

        # create a bunch of timeseries
        from lazyfmri import utils
        ts = utils.random_timeseries(1.2, 0.0, 100)
        ts1 = utils.random_timeseries(1.2, 0.3, 100)
        ts2 = utils.random_timeseries(1.2, 0.5, 100)
        ts3 = utils.random_timeseries(1.2, 0.8, 100)
        ts4 = utils.random_timeseries(1.2, 1, 100)

        # plot 1 timecourse
        plotting.LazyLine(ts2, figsize=(20, 5))
        <lazyfmri.plotting.LazyLine at 0x7f839b0289d0>

    .. code-block:: python

        # plot multiple timecourses, add labels, and save file
        plotting.LazyLine([ts, ts1, ts2, ts3, ts4], figsize=(20, 5), save_as="test_LazyLine.pdf", labels=['vol=0',
        'vol=0.3', 'vol=0.5', 'vol=0.8', 'vol=1.0'])
        <lazyfmri.plotting.LazyLine at 0x7f839b2177c0>

    .. code-block:: python

        # add horizontal line at y=0
        hline = {'pos': 0, 'color': 'k', 'lw': 0.5, 'ls': '--'}
        plotting.LazyLine(ts2, figsize=(20, 5), add_hline=hline)
        <lazyfmri.plotting.LazyLine at 0x7f839b053580>

    .. code-block:: python

        # add shaded error bars
        from scipy.stats import sem

        # make some stack
        stack = np.hstack((ts1[...,np.newaxis],ts2[...,np.newaxis],ts4[...,np.newaxis]))
        avg = stack.mean(axis=-1) # calculate mean
        err = sem(stack, axis=-1) # calculate error
        plotting.LazyLine(avg, figsize=(20, 5), error=err)
        <lazyfmri.plotting.LazyLine at 0x7f839b0d5220>

    Notes
    ----------
    See https://lazyfmri.readthedocs.io/en/latest/examples/example.html for more examples
    """

    def __init__(
            self,
            ts,
            xx=None,
            error=None,
            error_alpha=0.3,
            figsize=(7.1, 3.54),
            cmap='viridis',
            labels=None,
            markers=None,
            markersize=None,
            markerfc=None,
            markerec=None,
            plot_alpha=None,
            plot_kw={},
            **kwargs):

        self.array = ts
        self.xx = xx
        self.error = error
        self.error_alpha = error_alpha
        self.plot_alpha = plot_alpha
        self.figsize = figsize
        self.cmap = cmap
        self.labels = labels
        self.markers = markers
        self.markersize = markersize
        self.markerfc = markerfc
        self.markerec = markerec

        # `plot` writes marker colours into this dict; work on a copy so neither the caller's dict nor the
        # shared `{}` default carries those entries over to other instances
        self.plot_kw = dict(plot_kw) if plot_kw else {}

        super().__init__()
        self.__dict__.update(kwargs)
        self.update_rc(self.fontname)

        # plot
        if self.xkcd:
            with plt.xkcd():
                self.plot()
        else:
            self.plot()

        # save
        self._save_figure(self.save_as)

    def plot(self):
        """main plotting function"""

        def limit_given(limit):
            # anything truthy counts as before; a literal 0 is a valid axis limit too and must not be dropped
            return bool(limit) or (
                isinstance(limit, (int, float)) and not isinstance(limit, bool))

        # set figure axis
        self._set_figure_axs()

        # sort out color; a single array is wrapped so everything below deals with a list
        if isinstance(self.array, np.ndarray):
            self.array = [self.array]
            if not self.color:
                self.color = sns.color_palette(self.cmap, 1)[0]
            else:
                self.color = [self.color]

        if isinstance(self.array, list):
            n_lines = len(self.array)

            # check if alpha's match nr of elements in array
            if not isinstance(self.plot_alpha, list):
                if self.plot_alpha is None:
                    self.plot_alpha = [1] * n_lines
                elif isinstance(self.plot_alpha, (int, float)):
                    self.plot_alpha = [self.plot_alpha] * n_lines
                else:
                    self.plot_alpha = [self.plot_alpha]

            if len(self.plot_alpha) != n_lines:
                raise ValueError(
                    f"Alpha list ({len(self.plot_alpha)}) does not match length of data list ({n_lines})")

            if isinstance(self.color, str):
                self.color = [self.color] * n_lines

            # a single marker (or None) applies to every line
            if not isinstance(self.markers, list):
                self.markers = [self.markers] * n_lines

            if len(self.markers) != n_lines:
                raise ValueError(
                    f"Marker list ({len(self.markers)}) does not match length of data list ({n_lines})")

            # errors are only filtered if something error-like was passed at all
            self.fix_error = isinstance(
                self.error, (int, float, np.ndarray, list))

            if not isinstance(self.error, list):
                if not self.fix_error:
                    self.error = [None] * n_lines
                elif isinstance(self.error, (int, float)):
                    self.error = [self.error] * n_lines
                else:
                    self.error = [self.error]

            # filter out all-NaN errors
            if self.fix_error:
                self.tmp_error = [None] * n_lines
                for ix, err in enumerate(self.error):
                    if err is None:
                        continue

                    if isinstance(err, (int, float)):
                        self.tmp_error[ix] = err
                        continue

                    # explicit bounds: (ymin, ymax) or [ymin, ymax]
                    if isinstance(err, (list, tuple)) and len(err) == 2:
                        ymin = np.asarray(err[0])
                        ymax = np.asarray(err[1])

                        if ymin.shape != ymax.shape:
                            raise ValueError(
                                f"Lower/upper error bounds must have same shape; got {ymin.shape} and {ymax.shape}"
                            )

                        if not (np.all(np.isnan(ymin)) and np.all(np.isnan(ymax))):
                            self.tmp_error[ix] = (ymin, ymax)
                        continue

                    # regular 1D error vector
                    err = np.asarray(err)
                    if not np.all(np.isnan(err)):
                        self.tmp_error[ix] = err

                self.error = self.tmp_error.copy()

            if len(self.error) != n_lines:
                raise ValueError(
                    f"Error list ({len(self.error)}) does not match length of data list ({n_lines})"
                )

            # a single marker size (or None) applies to every line, as documented
            if not isinstance(self.markersize, list):
                self.markersize = [self.markersize] * n_lines

            if len(self.markersize) != n_lines:
                raise ValueError(
                    f"Markersize list ({len(self.markersize)}) does not match length of data list ({n_lines})")

            # decide on color scheme
            if not isinstance(self.color, list):
                self.color_list = sns.color_palette(self.cmap, n_lines)
            else:
                self.color_list = self.color

            if len(self.color_list) != n_lines:
                raise ValueError(
                    f"Length color list ({len(self.color_list)}) does not match length of data list ({n_lines})")

            for idx, el in enumerate(self.array):

                # squeeze dimensions
                if el.ndim > 1:
                    el = el.squeeze()

                # decide on line-width
                if isinstance(self.line_width, list):
                    if len(self.line_width) != n_lines:
                        raise ValueError(
                            f"Length of line width list ({len(self.line_width)}) does not match length of data "
                            f"list ({n_lines})")
                    use_width = self.line_width[idx]
                elif isinstance(self.line_width, (int, float, np.integer, np.floating)):
                    use_width = self.line_width
                else:
                    # None makes matplotlib fall back to its configured default width; the empty string
                    # used here before cannot be converted to a number by matplotlib
                    use_width = None

                # decide on line-style
                if isinstance(self.line_style, list):
                    if len(self.line_style) != n_lines:
                        raise ValueError(
                            f"Length of line style list ({len(self.line_style)}) does not match length of data "
                            f"list ({n_lines})")
                    use_style = self.line_style[idx]
                elif isinstance(self.line_style, str):
                    use_style = self.line_style
                else:
                    use_style = "solid"

                # decide on x-axis
                if isinstance(self.xx, range):
                    # range has no copy attribute
                    self.t_ = self.xx
                elif isinstance(self.xx, (pd.DataFrame, pd.Series)):
                    self.t_ = self.xx.values
                elif isinstance(self.xx, list):
                    # received array-specific list
                    if len(self.xx) == n_lines:
                        self.t_ = self.xx[idx]
                    else:
                        self.t_ = self.xx
                elif isinstance(self.xx, np.ndarray):
                    if self.xx.ndim > 1:
                        self.xx = self.xx.squeeze()
                    self.t_ = self.xx.copy()
                else:
                    # nothing usable specified: fall back to sample indices
                    self.t_ = np.arange(0, len(el))

                if isinstance(self.labels, (list, np.ndarray)):
                    lbl = self.labels[idx]
                else:
                    lbl = None

                # marker colors travel via `plot_kw`: per-line entry for lists, same value for strings
                if isinstance(self.markerec, list):
                    self.plot_kw["markeredgecolor"] = self.markerec[idx]
                elif isinstance(self.markerec, str):
                    self.plot_kw["markeredgecolor"] = self.markerec

                if isinstance(self.markerfc, list):
                    self.plot_kw["markerfacecolor"] = self.markerfc[idx]
                elif isinstance(self.markerfc, str):
                    self.plot_kw["markerfacecolor"] = self.markerfc

                # plot
                self.axs.plot(
                    self.t_,
                    el,
                    color=self.color_list[idx],
                    label=lbl,
                    lw=use_width,
                    ls=use_style,
                    marker=self.markers[idx],
                    markersize=self.markersize[idx],
                    alpha=self.plot_alpha[idx],
                    **self.plot_kw)

                # plot shaded error bars
                self._set_shaded_error(
                    x=self.t_,
                    ax=self.axs,
                    tc=el,
                    yerr=self.error[idx],
                    color=self.color_list[idx],
                    alpha=self.error_alpha
                )

        # axis labels and titles
        self._set_legend_labels(self.axs, labels=self.labels)

        # set title
        self._set_title(self.axs, self.title)

        # set tick params
        self._set_tick_params(self.axs)

        # set spine widths
        self._set_spine_width(self.axs)

        # give priority to specify x-lims rather than seaborn's xlim
        if not self.x_lim:
            if isinstance(self.xlim_left, (float, int)):
                self.axs.set_xlim(left=self.xlim_left)
            else:
                self.axs.set_xlim(left=self.t_[0])

            if limit_given(self.xlim_right):
                self.axs.set_xlim(right=self.xlim_right)
            else:
                self.axs.set_xlim(right=self.t_[-1])
        else:
            self.axs.set_xlim(self.x_lim)

        if not self.y_lim:
            if isinstance(self.ylim_bottom, (float, int)):
                self.axs.set_ylim(bottom=self.ylim_bottom)

            if limit_given(self.ylim_top):
                self.axs.set_ylim(top=self.ylim_top)
        else:
            self.axs.set_ylim(self.y_lim)

        # set ticks
        loop_funcs = [
            "_set_ticks",
            "_set_ticklabels",
            "_set_ticker",
            "_set_axlabel"
        ]

        for x in ["x", "y"]:
            for ff, el in zip(
                    loop_funcs,
                    ["ticks", "ticklabels", "dec", "label"]
            ):
                add_to_ax = getattr(self, f"{x}_{el}")
                getattr(self, ff)(self.axs, add_to_ax, axis=x)

        # draw horizontal/vertical lines with ax?line
        self._add_line(ax=self.axs)
        self._add_span(ax=self.axs)

        # despine
        self._despine(self.axs)
[docs]
class LazyCorr(Defaults):
    """LazyCorr

    Wrapper around seaborn's regplot. Plot data and a linear regression model fit. In addition to creating the plot, you
    can also run a regression or correlation using pingouin by setting the corresponding argument to ``True``.

    Parameters
    ----------
    data: pd.DataFrame, optional
        Input DataFrame. In this case, use strings representing column names for ``x``, ``y``, and ``color_by``.
        Internally, the dataframe is parsed into arrays so that it's compatible with matplotlib's scatter-/ color-by
        functions
    x: str, np.ndarray, list
        First variable to include in plot/regression. Can be a list/array representing data, or a column name from
        ``data``
    y: str, np.ndarray, list
        Second variable to include in plot/regression. Can be a list/array representing data, or a column name from
        ``data``
    color_by: str, np.ndarray, list
        Color the points according to a separate array. Can be a list/array representing data, or a column name from
        ``data``. The color map can be changed by passing arguments to `scatter_kwargs`
    color: str, list, optional
        String representing a color, by default "#cccccc" to color the regression fit
    figsize: tuple, optional
        Figure dimensions as per usual matplotlib conventions, by default (3.54, 3.54)
    points: bool, optional
        Let seaborn's regplot draw the data points, by default True. Forced to ``False`` when ``color_by`` is used, as
        the points are then drawn with matplotlib's ``scatter``
    label: str, optional
        Label passed on to seaborn's regplot, by default None
    axs: <AxesSubplot:>, optional
        Matplotlib axis to store the figure on
    correlation: bool, optional
        Run a correlation between ``x`` and ``y``. The result is stored in ``self.correlation_result``
    regression: bool, optional
        Run a regression between ``x`` and ``y``. The result is stored in ``self.regression_result``
    scatter_kwargs: dict, optional
        Additional options passed on to the ``scatter`` function from matplotlib. Set colorbar to nothing by passing:

        .. code-block:: python

            scatter_kwargs={"cbar": False}

    error_kwargs: dict, optional
        Options passed on to the internal function drawing error bars around the points
    result_to_plot: bool, optional
        Add the correlation/regression result to the plot
    result_loc: tuple, optional
        Location of the result text box in axes coordinates, by default (0.55, 0.1)
    result_ec: tuple, optional
        Edge color of the result text box, by default (1., 0.5, 0.5)
    result_fc: tuple, optional
        Face color of the result text box, by default (1., 0.8, 0.8)
    result_dec: int, optional
        Number of decimals to round the reported statistics to, by default 2
    verbose: bool, optional
        Print the result of the correlation/regression, by default False
    stat_kwargs: dict, optional
        Options passed on to pingouin's stats functions
    reg_kwargs: dict, optional
        Options passed on the seaborn's regplot

    Example
    ----------
    .. code-block:: python

        from lazyfmri import plotting
        import matplotlib.pyplot as plt

        # vanilla version; here, the regression fit has the same color as the dots.
        fig,axs = plt.subplots(figsize=(7,7))
        plotting.LazyCorr(
            x=x_data,
            y=y_data,
            axs=axs,
            x_label="add xlabel",
            y_label="add ylabel"
        )

    .. code-block:: python

        # more exotic version: color each dot differently
        from lazyfmri import utils

        # create color map between red and blue; return as list
        fig,axs = plt.subplots(figsize=(7,7))
        colors = utils.make_between_cm(["r","b"], as_list=True, N=len(y_data))
        for ix,val in enumerate(y_data):
            axs.plot(x_data[ix], val, 'o', color=colors[ix], alpha=0.6)

        # add the regression fit
        plotting.LazyCorr(
            x=x_data,
            y=y_data,
            axs=axs,
            points=False, # turn off points; we've already plotted them
            x_label="add xlabel",
            y_label="add ylabel"
        )

    Notes
    ----------
    see documentation of :class:`lazyfmri.plotting.Defaults()` for formatting options
    """

    def __init__(
            self,
            data: pd.DataFrame = None,
            x: Union[str, np.ndarray] = None,
            y: Union[str, np.ndarray] = None,
            color: str = "#cccccc",
            figsize: tuple = (3.54, 3.54),
            points: bool = True,
            label: str = None,
            scatter_kwargs: dict = {},
            stat_kwargs: dict = {},
            color_by: Union[list, np.ndarray] = None,
            regression: bool = False,
            correlation: bool = False,
            reg_kwargs: dict = {},
            error_kwargs: dict = {},
            result_to_plot: bool = False,
            result_loc: tuple = (0.55, 0.1),
            result_ec: tuple = (1., 0.5, 0.5),
            result_fc: tuple = (1., 0.8, 0.8),
            result_dec: int = 2,
            verbose=False,
            *args,
            **kwargs):

        # init default plotter class
        super().__init__(**kwargs)

        self.data = data
        self.x = x
        self.y = y
        self.color = color
        self.figsize = figsize
        self.points = points
        self.label = label
        self.color_by = color_by
        self.regression = regression
        self.correlation = correlation
        self.result_to_plot = result_to_plot
        self.result_loc = result_loc
        self.result_fc = result_fc
        self.result_ec = result_ec
        self.result_dec = result_dec
        self.verbose = verbose

        # private copies: the `{}` defaults are shared between calls and `plot` removes "cbar" from
        # `scatter_kwargs`, which must not alter the dict the caller handed in
        self.scatter_kwargs = dict(scatter_kwargs) if scatter_kwargs else {}
        self.stat_kwargs = dict(stat_kwargs) if stat_kwargs else {}
        self.reg_kwargs = dict(reg_kwargs) if reg_kwargs else {}
        self.error_kwargs = dict(error_kwargs) if error_kwargs else {}

        if self.xkcd:
            with plt.xkcd():
                self.plot()
        else:
            self.plot()

        # run quick regression with pingouin
        if self.regression:
            self._run_regression()

        # run quick correlation with pingouin
        if self.correlation:
            self._run_correlation()

        # print results? Only possible if a test was actually requested
        if self.verbose and (self.regression or self.correlation):
            self.print_results()

        # check if we should add text box
        if self.result_to_plot:
            self.add_result_to_plot(*args)

        # save
        self._save_figure(self.save_as)

    @staticmethod
    def _import_pingouin():
        """Import pingouin on demand; any failure while importing is reported as ImportError."""
        try:
            import pingouin as pg
        except Exception as err:
            raise ImportError(
                "Could not import pingouin, so this functionality is not available") from err
        return pg

    def _get_result(self):
        """Collect test name, metric, value, p-value, and method from the stored result.

        A correlation result takes precedence over a regression result if both exist.

        Raises
        ----------
        RuntimeError
            If neither ``correlation_result`` nor ``regression_result`` exists yet
        """
        if hasattr(self, "correlation_result"):
            met_name = "r"
            p_name = "p-unc"
            test = "correlation"
            result_obj = self.correlation_result
            meth = result_obj["method"].values[0]
        elif hasattr(self, "regression_result"):
            met_name = "r2"
            p_name = "pval"
            test = "regression"
            result_obj = self.regression_result
            meth = ""
        else:
            raise RuntimeError(
                "No result available; set 'regression=True' or 'correlation=True' first")

        # NOTE(review): entry 0 is reported for both tests; for a regression table that may be the
        # intercept rather than the predictor - confirm against pingouin's output
        p_val = result_obj[p_name][0]
        met_val = result_obj[met_name][0]

        return {
            "p": p_val,
            "test": test,
            "metric": met_name,
            "value": met_val,
            "method": meth
        }

    def print_results(self, return_result=False):
        """Print a one-line summary of the stored result; return the result dict if ``return_result=True``."""
        res = self._get_result()

        # column names are only known if the input came from a dataframe
        if isinstance(self.x_col, str) and isinstance(self.y_col, str):
            col_txt = f"\tA={self.x_col}\tB={self.y_col}"
        else:
            col_txt = ""

        if len(res['method']) > 0:
            meth_txt = f" ({res['method']})"
        else:
            meth_txt = ""

        r, m, v, p = res["test"], res["metric"], res["value"], res["p"]
        utils.verbose(
            f"Test={r}{meth_txt}{col_txt} | {m}={round(v,self.result_dec)},\tp={round(p,self.result_dec)}",
            True)

        if return_result:
            return res

    def add_result_to_plot(self):
        """Write the metric and p-value in a rounded text box at ``result_loc`` (axes coordinates)."""
        res = self.print_results(return_result=True)
        self.axs.text(
            *self.result_loc,
            f"{res['metric']}={round(res['value'],self.result_dec)}, p={round(res['p'],self.result_dec)}",
            size=self.font_size * 0.8,
            bbox=dict(boxstyle="round", ec=self.result_ec, fc=self.result_fc),
            transform=self.axs.transAxes
        )

    def _run_regression(self):
        """Run ``pingouin.linear_regression`` on ``x``/``y`` and store it in ``self.regression_result``."""
        pg = self._import_pingouin()
        self.regression_result = pg.linear_regression(
            self.x,
            self.y,
            remove_na=True,
            **self.stat_kwargs
        )

    def _run_correlation(self):
        """Run ``pingouin.pairwise_corr`` on ``x``/``y`` and store it in ``self.correlation_result``.

        Note that ``self.data`` is replaced by a two-column dataframe and ``self.x``/``self.y`` by the column
        names "x"/"y".
        """
        pg = self._import_pingouin()

        # convert to dataframe
        self.data = pd.DataFrame({"x": self.x, "y": self.y})
        self.x = "x"
        self.y = "y"

        self.correlation_result = pg.pairwise_corr(
            self.data,
            columns=["x", "y"],
            **self.stat_kwargs
        )

    def plot(self):
        """Draw the (optionally colored) points, error bars, and the regression fit on ``self.axs``."""

        # set figure axis
        self._set_figure_axs()

        # sort out stuff if input is data
        if isinstance(self.data, pd.DataFrame):

            # resetting the index can fail (e.g., if the index name already exists as column); in that
            # case continue with the dataframe as is
            try:
                self.data = self.data.reset_index()
            except Exception:
                pass

            self.x_col = self.x
            self.y_col = self.y
            self.x = self.data[self.x].values
            self.y = self.data[self.y].values

            if isinstance(self.color_by, str):
                self.color_by = self.data[self.color_by].values.astype(float)

            self.data = None
        else:
            self.x_col = self.y_col = None

        # convert lists to arrays and enforce float; the converted values are stored back on the
        # instance so that plotting and statistics operate on them
        converted = []
        for e in (self.x, self.y):
            if isinstance(e, list):
                e = np.array(e)
            converted.append(e.astype(float))
        self.x, self.y = converted

        # c-arguments clashes with "color" argument if you pass it to
        # sns.regplot in "scatter_kws"; hence this solution
        if isinstance(self.color_by, (list, np.ndarray, pd.DataFrame, pd.Series)):

            # get array
            if isinstance(self.color_by, (pd.DataFrame, pd.Series)):
                self.color_by = self.color_by.values

            # remove cbar from kwargs; the colorbar is drawn unless explicitly disabled
            add_cbar = self.scatter_kwargs.pop("cbar", True)

            points = self.axs.scatter(
                self.x,
                self.y,
                c=self.color_by,
                **self.scatter_kwargs
            )

            # set colorbar
            if add_cbar:
                self.cbar = plt.colorbar(points)
                if "label" in self.scatter_kwargs:
                    self.cbar.set_label(
                        self.scatter_kwargs["label"],
                        fontsize=self.font_size,
                        fontname=self.fontname
                    )

                # sort out ticks
                self._set_tick_params(self.cbar.ax)
                self._set_spine_width(self.cbar.ax)
                self._set_ticker(
                    self.cbar.ax,
                    self.y_dec,
                    axis="y"
                )

                # remove outside edge from colorbar
                self.cbar.ax.set_frame_on(False)

            # set stuff to false/empty for sns.regplot
            self.points = False
            self.scatter_kwargs = {}

        # draw scatter error if desired
        self._draw_errorbar(
            x=self.x,
            y=self.y,
            ax=self.axs,
            **self.error_kwargs
        )

        self.kde_color = utils.make_between_cm(self.color, self.color, as_list=True)
        self.reg_ = sns.regplot(
            x=self.x,
            y=self.y,
            color=self.color,
            ax=self.axs,
            scatter=self.points,
            label=self.label,
            scatter_kws=self.scatter_kwargs,
            **self.reg_kwargs
        )

        # sort out ticks
        self._set_spine_width(self.axs)
        self._set_tick_params(self.axs)
        self._set_title(self.axs, self.title)

        # set ticks
        loop_funcs = [
            "_set_ticks",
            "_set_ticklabels",
            "_set_ticker",
            "_set_axlabel"
        ]

        for x in ["x", "y"]:
            for ff, el in zip(
                    loop_funcs,
                    ["ticks", "ticklabels", "dec", "label"]):
                add_to_ax = getattr(self, f"{x}_{el}")
                getattr(self, ff)(self.axs, add_to_ax, axis=x)

        # draw horizontal/vertical lines with ax?line
        self._add_line(ax=self.axs)
        self._add_span(ax=self.axs)

        # set tickers & despine
        self._despine(self.axs)
[docs]
class LazyBar():
"""LazyBar
Wrapper around :func:`seaborn.barplot` to follow the same aesthetics of the other Lazy* functions. It is strongly
recommended to use a dataframe for this function to make the formatting somewhat easier, but you can input arrays for `x`
and `y`. You can round the edges of the bar using `fancy=True`.
Parameters
----------
data: pd.DataFrame, optional
Input dataframe, by default None
x: str, list, np.ndarray, optional
Variable for the x-axis, by default None. Can be a column name from `data`, or a list/np.ndarray with labels for input
`y`.
y: str, list, np.ndarray, optional
Variable for the y-axis, by default None. Can be a column name from `data`, or a list/np.ndarray. If `x` is not
specified, indices from 0 to `y.shape` will be used to construct the input dataframe.
labels: list, np.ndarray, optional
custom labels that can be used when `x` denotes a column name in dataframe `data`. The replacing labels should have
the same length as the labels that are being overwritten.
axs: <AxesSubplot:>, optional
Subplot axis to put the plot on, by default None
add_points: bool, optional
Add the actual datapoints on top of the bars, by default False
points_color: str, tuple, optional
Color of the points if you do not have nested categories, by default None
points_palette: list, sns.palettes._ColorPalette, optional
Color palette for the points if you have nested categories (e.g., multiple variables per subject so you can color the
individual subjects' data points), by default None
points_cmap: str, optional
Color map for the points if you did not specify `points_palette`, by default "viridis"
points_legend: bool, optional
Add legend of the data points (if you have nested categories), by default False. The functionality of these
interchangeable legends (`bar_legend` and `points_legend`) is quite tricky, so user discretion is advised.
points_alpha: float, optional
Alpha of the points, by default 1. Sometimes useful to adjust if you have LOADS of data points
error: str, optional
Type of error bar to use for the bar, by default "se". Can be {'sem'|'se'} or {'std'|'sd'}. Internally, we'll check if
there's enough samples to calculate errors from, otherwise `error` will be set to `None`
fancy: bool, optional
Flag to round the edges of the bars, by default False. By default, the rounding is scaled by the min/max of the plot,
regardless whether `lim` was specified. This ensures equal rounding across inputs. The other `fancy`-arguments below
are a bit vague, so leaving them default will ensure nice rounding of the bars
fancy_rounding: float, optional
Amount of rounding, by default 0.15
fancy_pad: float, optional
Vague variable, by default -0.004
fancy_aspect: float, optional
Vague variable, by default None. If None, the rounding is scaled by the min/max of the plot, regardless whether `lim`
was specified.
fancy_denom: int, optional
Scaling factor for `fancy_aspect`, by default 4 (which works well for data where the max value is ~50). Use higher
values (e.g., 6) if your data range is large
bar_legend: bool, optional
Legend for the bars, rather than points, by default False. The functionality of these interchangeable legends
(`bar_legend` and `points_legend`) is quite tricky, so user discretion is advised.
strip_kw: dict, optional
Additional kwargs passed on to seaborn's stripplot. Several factors are being set via regular arguments in the
function, such as `dodge`, `palette`, `color`, and `hue`.
Example
----------
.. code-block:: python
# this figure size works well for plots with 2 bars
fig,axs = plt.subplots(figsize=(2,8))
plotting.LazyBar(
data=df_wm,
x="group",
y="t1",
sns_ori="v",
axs=axs,
add_labels=True,
palette=[con_color,mdd_color],
add_points=True,
points_color="k",
trim_bottom=True,
sns_offset=4,
y_label2="white matter T1 (ms)",
lim=[800,1600],
fancy=True,
fancy_denom=6
)
Notes
----------
see documentation of :class:`lazyfmri.plotting.Defaults()` for formatting options
"""
def __init__(
        self,
        data: pd.DataFrame = None,
        x: Union[str, np.ndarray] = None,
        y: Union[str, np.ndarray] = None,
        labels: list = None,
        palette: Union[list, sns.palettes._ColorPalette] = None,
        cmap: str = "inferno",
        hue: str = None,
        figsize=(1.77, 3.54),
        add_labels: bool = False,
        lim: list = None,
        ticks: list = None,
        add_points: bool = False,
        points_color: Union[str, tuple] = None,
        points_palette: Union[list, sns.palettes._ColorPalette] = None,
        points_cmap: str = "viridis",
        points_legend: bool = False,
        points_alpha: float = 1,
        error: str = "se",
        ci_mode: str = "auto",
        ci_low: str = None,
        ci_high: str = None,
        ci_kws: dict={},
        fancy: bool = False,
        fancy_rounding: float = 0.15,
        fancy_pad: float = -0.004,
        fancy_aspect: float = None,
        fancy_denom: int = 4,
        bar_legend: bool = False,
        lbl_legend: list = None,
        strip_kw: dict = {},
        connect: bool = False,
        connect_kw: dict = {},
        **kwargs
):
    """Store the settings, sort ``kwargs`` into formatting and seaborn options, then plot and save.

    A ``Defaults`` instance is kept as ``self.kw_defaults``. Every entry of ``kwargs`` whose name is listed
    in its ``ls_kwargs`` is set on that instance; entries that are neither in ``ls_kwargs`` nor in the
    ignore-list below are forwarded to :meth:`plot`. Afterwards the attributes of ``kw_defaults`` and all
    ``kwargs`` are copied onto this instance, in that order.
    """
    self.data = data
    self.x = x
    self.y = y
    self.hue = hue
    self.labels = labels
    self.palette = palette
    self.cmap = cmap
    self.add_labels = add_labels
    self.lim = lim
    self.ticks = ticks
    self.bar_legend = bar_legend
    self.lbl_legend = lbl_legend
    self.add_points = add_points
    self.points_color = points_color
    self.points_palette = points_palette
    self.points_cmap = points_cmap
    self.points_legend = points_legend
    self.points_alpha = points_alpha
    self.error = error
    self.fancy = fancy
    self.fancy_rounding = fancy_rounding
    self.fancy_pad = fancy_pad
    self.fancy_aspect = fancy_aspect
    self.fancy_denom = fancy_denom
    self.figsize = figsize
    self.connect = connect
    self.ci_mode = ci_mode
    self.ci_low = ci_low
    self.ci_high = ci_high

    # private copies of the dict options: the `{}` defaults are shared between calls, and e.g. `ci_kws`
    # gets default entries filled in by `_add_dataframe_cis`
    self.strip_kw = dict(strip_kw) if strip_kw else {}
    self.connect_kw = dict(connect_kw) if connect_kw else {}
    self.ci_kws = dict(ci_kws) if ci_kws else {}

    self.kw_defaults = Defaults()

    # avoid that these kwargs are passed down to matplotlib.bar.. Throws
    # errors
    ignore_kwargs = [
        "trim_left",
        "trim_bottom",
        "points_hue",
        "points_alpha",
        "bbox_to_anchor",
        "fancy",
        "fancy_rounding",
        "fancy_pad",
        "fancy_aspect",
        "fancy_denom",
        "font_name",
        "bar_legend",
        "labels",
        "strip_kw",
        "fontname",
        "add_legend",
        "connect",
        "connect_kw"
    ]

    # names that must not reach seaborn; built once instead of on every iteration
    blocked = self.kw_defaults.ls_kwargs + ignore_kwargs

    kw_sns = {}
    for key, val in kwargs.items():
        if key not in blocked:
            # filter out non-ls kwargs
            kw_sns[key] = val
        elif key in self.kw_defaults.ls_kwargs:
            # overwrite ls-kwargs; assigned unconditionally because comparing against the current
            # value first is ambiguous for array-like values
            setattr(self.kw_defaults, key, val)

    self.__dict__.update(**self.kw_defaults.__dict__)
    self.__dict__.update(**kwargs)
    self.kw_defaults.update_rc(self.fontname)

    if not hasattr(self, "bbox_to_anchor"):
        self.bbox_to_anchor = None

    if self.xkcd:
        with plt.xkcd():
            self.plot(**kw_sns)
    else:
        self.plot(**kw_sns)

    # save
    self.kw_defaults._save_figure(self.save_as)
def _resolve_ci_mode(self):
if self.ci_mode is None:
return None
if self.ci_mode == "auto":
if isinstance(self.ci_low, str) and isinstance(self.ci_high, str):
return "columns"
return "raw"
if self.ci_mode not in ["raw", "columns"]:
raise ValueError("ci_mode must be one of {'auto', 'raw', 'columns', None}")
return self.ci_mode
def _add_dataframe_cis(self):
if not isinstance(self.ci_low, str) or not isinstance(self.ci_high, str):
raise ValueError("When ci_mode='columns', ci_low and ci_high must be column names")
if self.ci_low not in self.data.columns:
raise ValueError(f"Column '{self.ci_low}' not found in dataframe")
if self.ci_high not in self.data.columns:
raise ValueError(f"Column '{self.ci_high}' not found in dataframe")
patches_ = [
p for p in self.ff.patches
if p.get_height() != 0 or p.get_width() != 0
]
n_rows = self.data.shape[0]
if len(patches_) < n_rows:
raise ValueError(
f"Found fewer bar patches ({len(patches_)}) than dataframe rows ({n_rows}). "
"This can happen when using incompatible input for manual CI drawing."
)
value_col = self.y if self.sns_ori == "v" else self.x
ci_defs = {
"fmt": "none",
"ecolor": "black",
"capsize": 0,
"lw": 1
}
for key, val in ci_defs.items():
if key not in self.ci_kws:
self.ci_kws = utils.update_kwargs(self.ci_kws, key, val)
# Reorder dataframe to match seaborn patch order
data_ci = self.data.copy()
x_col = self.x if self.sns_ori == "v" else self.y
hue_col = self.hue
if hue_col is not None:
# seaborn draws grouped bar patches hue-first, then x
if hasattr(self, "hue_order") and self.hue_order is not None:
hue_order = list(self.hue_order)
else:
hue_order = list(data_ci[hue_col].dropna().unique())
if hasattr(self, "order") and self.order is not None:
x_order = list(self.order)
else:
x_order = list(data_ci[x_col].dropna().unique())
data_ci[x_col] = pd.Categorical(
data_ci[x_col],
categories=x_order,
ordered=True
)
data_ci[hue_col] = pd.Categorical(
data_ci[hue_col],
categories=hue_order,
ordered=True
)
data_ci = (
data_ci
.sort_values([hue_col, x_col])
.reset_index(drop=True)
)
else:
if hasattr(self, "order") and self.order is not None:
x_order = list(self.order)
data_ci[x_col] = pd.Categorical(
data_ci[x_col],
categories=x_order,
ordered=True
)
data_ci = (
data_ci
.sort_values(x_col)
.reset_index(drop=True)
)
else:
data_ci = data_ci.reset_index(drop=True)
# Draw CIs
for patch, (_, row) in zip(patches_, data_ci.iterrows()):
value = row[value_col]
low = row[self.ci_low]
high = row[self.ci_high]
if pd.isna(value) or pd.isna(low) or pd.isna(high):
continue
err_low = value - low
err_high = high - value
if err_low < 0 or err_high < 0:
raise ValueError(
f"CI bounds inconsistent for row: value={value}, "
f"low={low}, high={high}. Expected low <= value <= high."
)
if self.sns_ori == "v":
x_center = patch.get_x() + patch.get_width() / 2
self.ff.errorbar(
x=x_center,
y=value,
yerr=np.array([[err_low], [err_high]]),
**self.ci_kws
)
else:
y_center = patch.get_y() + patch.get_height() / 2
self.ff.errorbar(
x=value,
y=y_center,
xerr=np.array([[err_low], [err_high]]),
**self.ci_kws
)
[docs]
def connect_hue_pairs(self, **kwargs):
    """Connect paired points between hue levels, one x-category at a time.

    The unique values of column ``self.xx`` in ``self.data`` are collected
    with ``utils.get_unique_ids(..., sort=False)``. For each of them, the
    matching rows are passed to :meth:`connect_pairs` with ``self.hue`` as
    the column to pair on and the running index of the category as
    ``hue_id``.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to :meth:`connect_pairs`.
    """
    x_levels = utils.get_unique_ids(self.data, id=self.xx, sort=False)
    for level_idx, level in enumerate(x_levels):
        # rows belonging to the current x-category only
        subset = self.data.loc[self.data[self.xx] == level]
        self.connect_pairs(subset, self.hue, hue_id=level_idx, **kwargs)
[docs]
def connect_pairs(
        self,
        data,
        x,
        hue_id=None,
        **kwargs):
    """Draw lines between corresponding points of neighbouring categories.

    The point positions are read from the collection artists that are
    already present on ``self.ff`` (NOTE(review): presumably the scatter
    collections created by the strip plot -- confirm), so the points must
    have been drawn before this method is called.

    Parameters
    ----------
    data: pandas.DataFrame
        Data containing column `x` and column ``self.yy``.
    x: str
        Column whose unique values (as returned by
        ``utils.get_unique_ids(..., sort=False)``) define the categories
        that are connected pairwise.
    hue_id: int, optional
        When an integer, only the ``len(categories)`` collections starting
        at ``hue_id * len(categories)`` are considered. Default = None
        (use all collections found on the axis).
    **kwargs
        Passed to ``self.ff.plot``. ``color="black"`` and ``alpha=0.1`` are
        supplied through ``utils.update_kwargs``.
    """
    # categories present in `x` and the subsequent pairs they form
    levels = utils.get_unique_ids(data, id=x, sort=False)
    level_pairs = utils.pairwise(levels)

    # line defaults
    for option, default in (("color", "black"), ("alpha", 0.1)):
        kwargs = utils.update_kwargs(kwargs, option, default)

    # find the "collections" children of the axis through their string
    # representation, then map the matches back onto the actual artists
    artists = self.ff.get_children()
    artist_names = [str(artist) for artist in artists]
    matched_names = utils.get_file_from_substring("collections", artist_names)
    point_sets = [artists[artist_names.index(name)] for name in matched_names]

    # restrict to the block of collections that belongs to this hue index
    if isinstance(hue_id, int):
        first = hue_id * len(levels)
        point_sets = point_sets[first:first + len(levels)]

    # column of the offsets used as x/y; swapped unless orientation is "v"
    if self.sns_ori == "v":
        x_idx, y_idx = 0, 1
    else:
        x_idx, y_idx = 1, 0

    for pair, coll in zip(level_pairs, utils.pairwise(point_sets)):
        offsets_a = coll[0].get_offsets()
        offsets_b = coll[1].get_offsets()

        # y-values of both categories in the pair
        sets = [data.loc[data[x] == level][self.yy].values for level in pair]

        # before plotting, we need to sort so that the data points
        # correspond to each other
        order_a = np.argsort(sets[0])
        order_b = np.argsort(sets[1])

        # revert "ascending sort" through order_b.argsort(), and then sort
        # into the order corresponding with the first set
        offsets_b_matched = offsets_b[order_b.argsort()][order_a]

        for row in range(offsets_a.shape[0]):
            self.ff.plot(
                [offsets_a[row, x_idx], offsets_b_matched[row, x_idx]],
                [offsets_a[row, y_idx], offsets_b_matched[row, y_idx]],
                **kwargs
            )
[docs]
def plot(self, **kw_sns):
    """Draw the bar plot on the configured axis and apply all formatting.

    In order, this method: builds a dataframe from loose ``x``/``y``
    inputs, substitutes custom labels, resolves color/palette and the
    error-bar setting, draws the bars with :func:`seaborn.barplot`,
    optionally overlays (and connects) individual points, and finally
    applies ticks, limits, rounded ("fancy") bars, axis labels, legend,
    despining and the title.

    Parameters
    ----------
    **kw_sns
        Additional keyword arguments forwarded to
        :func:`seaborn.barplot`. ``color`` and ``palette`` are always
        replaced by the values resolved in this method.

    Raises
    ------
    ValueError
        If ``self.sns_ori`` is not ``"v"`` or ``"h"``.
    """
    # set figure axis
    self.kw_defaults._set_figure_axs(figsize=self.figsize)

    # construct dataframe from loose inputs
    if isinstance(self.y, (np.ndarray, list)):
        if isinstance(self.y, list):
            self.y = np.array(self.y)

        # no array-like x given: use a running index
        if not isinstance(self.x, (np.ndarray, list)):
            self.x = np.arange(0, self.y.shape[0])

        self.data = pd.DataFrame({"x": self.x, "y": self.y})
        self.x = "x"
        self.y = "y"

    # check if we should reset the index of dataframe; best-effort: the
    # data is kept untouched when resetting is not possible
    try:
        self.data = self.data.reset_index()
    except Exception:
        pass

    # check if we got custom labels
    if isinstance(self.labels, (np.ndarray, list)):
        # check if x is multi-level
        if len(self.data[self.x]) > len(self.labels):
            # get unique values
            unique_x = utils.get_unique_ids(
                self.data,
                id=self.x,
                sort=False
            )

            # substitute every old x by the label at the same position in
            # the list of unique x-values
            new_x = []
            for x in self.data[self.x].values:
                curr_x = unique_x.index(x)
                new_x.append(self.labels[curr_x])

            self.data[self.x] = new_x
        else:
            self.data[self.x] = self.labels

    # orientation decides which axis is removed when despining
    if self.sns_ori == "h":
        self.xx = self.x
        self.yy = self.y
        self.trim_bottom = False
        self.trim_left = True
    elif self.sns_ori == "v":
        self.xx = self.x
        self.yy = self.y
        self.trim_bottom = True
        self.trim_left = False
    else:
        raise ValueError(
            f"sns_ori must be 'v' or 'h', not '{self.sns_ori}'")

    # a single color overrules palette/cmap; a list of colors is a palette
    if isinstance(self.color, (str, tuple, list)):
        if isinstance(self.color, (str, tuple)):
            self.palette = None
            self.cmap = None
        elif isinstance(self.color, list):
            self.palette = sns.color_palette(palette=self.color)
            self.color = None
    else:
        self.color = None

    if isinstance(self.palette, list):
        self.palette = sns.color_palette(palette=self.palette)

    # anything that is not a seaborn palette falls back to the colormap
    if not isinstance(self.palette, sns.palettes._ColorPalette):
        self.palette = self.cmap

    # allow more input types for the error specification. Exact alias
    # matching is used on purpose: a substring test such as `in "sem"`
    # would also turn "", "s", "e" or "m" into "se"
    if isinstance(self.error, str):
        err_alias = self.error.lower()
        if err_alias in ("sem", "se"):
            self.error = "se"
        elif err_alias in ("std", "sd"):
            self.error = "sd"

    # seaborn only draws the error bars in "raw" mode; in "columns" mode
    # they are added afterwards from dataframe columns
    ci_mode = self._resolve_ci_mode()
    sns_errorbar = self.error if ci_mode == "raw" else None

    self.ff = sns.barplot(
        data=self.data,
        x=self.xx,
        y=self.yy,
        ax=self.axs,
        orient=self.sns_ori,
        errorbar=sns_errorbar,
        hue=self.hue,
        **dict(
            kw_sns,
            color=self.color,
            palette=self.palette
        ))

    if ci_mode == "columns":
        self._add_dataframe_cis()

    multi_strip = False

    # need points in order to connect them..
    if self.connect:
        self.add_points = True

    if self.add_points:

        if not hasattr(self, "points_hue"):
            self.points_hue = None

        if not self.points_palette:
            self.points_palette = self.points_cmap

        # give priority to given points_color
        if isinstance(self.points_color, (str, tuple)):
            self.points_palette = None
            self.points_hue = None

        if isinstance(self.hue, str):
            if isinstance(self.points_hue, str):
                if self.points_hue != self.hue:
                    multi_strip = True

                    # points are colored by a different variable than the
                    # bars: draw one strip plot per level of points_hue
                    self.hue_items = list(
                        np.unique(self.data[self.points_hue].values))

                    if isinstance(self.points_color, (str, tuple)):
                        self.hue_colors = [
                            self.points_color for ii in range(len(self.hue_items))]
                    else:
                        self.hue_colors = sns.color_palette(
                            self.points_palette, len(self.hue_items))

                    for it, color in zip(self.hue_items, self.hue_colors):
                        df_per_it = self.data[self.data[self.points_hue] == it]
                        sns.stripplot(
                            data=df_per_it,
                            x=self.xx,
                            y=self.yy,
                            hue=self.hue,
                            dodge=False,
                            palette=[color] * 2,
                            ax=self.ff,
                            **self.strip_kw
                        )
            else:
                # points follow the hue of the bars
                multi_strip = True
                sns.stripplot(
                    data=self.data,
                    x=self.xx,
                    y=self.yy,
                    hue=self.hue,
                    dodge=True,
                    ax=self.ff,
                    color=self.points_color,
                    palette=self.points_palette,
                    alpha=self.points_alpha,
                    **self.strip_kw
                )
        else:
            sns.stripplot(
                data=self.data,
                x=self.xx,
                y=self.yy,
                hue=self.points_hue,
                dodge=False,
                ax=self.ff,
                color=self.points_color,
                palette=self.points_palette,
                alpha=self.points_alpha,
                **self.strip_kw
            )

        # connect dots
        if self.connect:
            if isinstance(self.hue, str):
                self.connect_hue_pairs(**self.connect_kw)
            else:
                self.connect_pairs(
                    self.data,
                    self.xx,
                    **self.connect_kw
                )

    # set tick params
    self.kw_defaults._set_tick_params(self.ff)

    # set spine widths
    self.kw_defaults._set_spine_width(self.ff)

    # remove the category ticks, or replace them by the given list
    if not self.add_labels:
        if self.sns_ori == 'h':
            self.kw_defaults._set_ticks(self.ff, [], axis="y")
        elif self.sns_ori == "v":
            self.kw_defaults._set_ticks(self.ff, [], axis="x")
        else:
            raise ValueError(
                f"sns_ori must be 'v' or 'h', not '{self.sns_ori}'")
    elif isinstance(self.add_labels, list):
        self.kw_defaults._set_xlabel(self.ff, self.add_labels)

    # rotate the tick labels of the category axis
    if isinstance(self.sns_rot, (int, float)):
        if self.sns_ori == 'h':
            self.kw_defaults._set_ticklabels(
                self.ff,
                self.ff.get_yticklabels(),
                rotation=self.sns_rot,
                fontname=self.fontname,
                axis="y"
            )
        elif self.sns_ori == "v":
            self.kw_defaults._set_ticklabels(
                self.ff,
                self.ff.get_xticklabels(),
                axis="x",
                rotation=self.sns_rot,
                fontname=self.fontname
            )
        else:
            raise ValueError(
                f"sns_ori must be 'v' or 'h', not '{self.sns_ori}'")

    # set limits depending on orientation
    self.kw_defaults._set_bar_lim(self.ff, self.lim)

    # set ticks depending on orientation
    self.kw_defaults._set_bar_ticks(self.ff, self.ticks)

    # replace the rectangular bars by rounded boxes
    # from: https://stackoverflow.com/a/61569240
    if self.fancy:
        new_patches = []
        for patch in self.ff.patches:
            bb = patch.get_bbox()
            color = patch.get_facecolor()

            # derive the aspect from the upper axis limit (relative to the
            # lower limit, if given) divided by fancy_denom; this is done
            # only as long as no numeric fancy_aspect is available
            if not isinstance(self.fancy_aspect, (int, float)):
                if self.sns_ori == "v":
                    y_limiter = patch._axes.get_ylim()[-1]
                    if isinstance(self.lim, list):
                        y_limiter -= self.lim[0]

                    self.fancy_aspect = y_limiter / self.fancy_denom
                else:
                    x_limiter = patch._axes.get_xlim()[-1]
                    if isinstance(self.lim, list):
                        x_limiter -= self.lim[0]

                    self.fancy_aspect = x_limiter / self.fancy_denom

            # make rounding at limit
            if isinstance(self.lim, list):
                if self.sns_ori == "v":
                    ymin = self.lim[0]
                    xmin = bb.xmin
                    height = bb.height - ymin
                    width = bb.width
                else:
                    xmin = self.lim[0]
                    ymin = bb.ymin
                    width = bb.width - xmin
                    height = bb.height
            else:
                xmin = bb.xmin
                ymin = bb.ymin
                height = bb.height
                width = bb.width

            p_bbox = patches.FancyBboxPatch(
                (xmin, ymin),
                abs(width), abs(height),
                boxstyle=f"round,pad={self.fancy_pad},rounding_size={self.fancy_rounding}",
                ec="none",
                fc=color,
                mutation_aspect=self.fancy_aspect
            )

            patch.remove()
            new_patches.append(p_bbox)

        # add the boxes only after all original bars have been removed
        for patch in new_patches:
            self.ff.add_patch(patch)

    # set xlabel to none if nothing is specified
    if isinstance(self.x, str) and not isinstance(self.x_label, str):
        self.ff.set(xlabel=None)

    if isinstance(self.y, str) and not isinstance(self.y_label, str):
        self.ff.set(ylabel=None)

    # set these explicitly; remove left axis if orientation = horizontal |
    # remove bottom axis if orientation is vertical
    if hasattr(self, "trim_left"):
        trim_left = self.trim_left
    else:
        trim_left = False

    if hasattr(self, "trim_bottom"):
        trim_bottom = self.trim_bottom
    else:
        trim_bottom = False

    # draw horizontal/vertical lines with ax?line
    self.kw_defaults._add_line(ax=self.ff)
    self.kw_defaults._add_span(ax=self.ff)

    # set tickers & axis labels for both axes
    loop_funcs = [
        "_set_ticker",
        "_set_axlabel"
    ]

    for x in ["x", "y"]:
        for ff, el in zip(
                loop_funcs,
                ["dec", "label"]):
            add_to_ax = getattr(self, f"{x}_{el}")
            getattr(self.kw_defaults, ff)(self.ff, add_to_ax, axis=x)

    # sort out legend
    if self.bar_legend or self.points_legend:

        self.add_legend = True

        self.legend_kw = {
            "fontsize": self.label_size,
            "handletextpad": self.legend_handletext,
            "frameon": False
        }

        if isinstance(self.bbox_to_anchor, tuple):
            self.legend_kw["bbox_to_anchor"] = self.bbox_to_anchor

        # get handles
        handles, labels = self.ff.get_legend_handles_labels()

        # bar legend
        if self.bar_legend:
            # do some more exotic stuff to disentangle coloring from bars
            # and hue
            if isinstance(self.hue, str):
                # hue levels in order of first appearance
                cc = self.data[self.hue].values
                indexes = np.unique(cc, return_index=True)[1]
                cond = [cc[index] for index in sorted(indexes)]

                # with points drawn per hue, keep only the last handles
                # (one per hue level)
                if multi_strip:
                    handles = handles[-len(cond):]
                    labels = labels[-len(cond):]
            else:
                if not self.add_points:
                    self.add_legend = False
    else:
        self.add_legend = False

    # fill in legend
    if self.add_legend:
        if isinstance(self.lbl_legend, list):
            labels = self.lbl_legend

        # brute force: fall back to the bar patches as handles
        if len(handles) == 0:
            handles = list(self.ff.patches)

        self.legend = self.ff.legend(
            handles,
            labels,
            **self.legend_kw
        )
    else:
        # empty legend suppresses the one seaborn may have added
        self.legend = self.ff.legend([], [], frameon=False)

    self.kw_defaults._despine(
        self.ff,
        left=trim_left,
        bottom=trim_bottom
    )

    # set title
    self.kw_defaults._set_title(self.ff, self.title)
[docs]
class LazyHist(Defaults):
    """LazyHist

    Wrapper around seaborn's histogram plotter

    Parameters
    ----------
    data: numpy.ndarray
        Input data for histogram
    x: optional
        Passed as `x` to seaborn's `kdeplot`. Default is None
    y: optional
        Passed as `y` to seaborn's `kdeplot`. Default is None
    figsize: tuple, optional
        Figure size stored on the instance. Default is (3.54, 3.54)
    kde: bool, optional
        Add kernel density plot to histogram with seaborn (https://seaborn.pydata.org/generated/seaborn.kdeplot.html). Default
        is False
    hist: bool, optional
        Add histogram to plot. Default is True
    fill: bool, optional
        Fill the area below the kde plot. Default is False
    bins: str, optional
        Set bins for histogram; default = "auto"
    kde_kwargs: dict, optional
        Additional arguments passed on the seaborn's `kde_plot`. The dictionary is copied, so the caller's object is never
        modified
    hist_kwargs: dict, optional
        Additional arguments passed on to matplotlib's `hist` function. The dictionary is copied, so the caller's object is
        never modified
    color: str, optional
        Color of the histogram and the kde line. Default is "#cccccc"
    cmap: str, optional
        If a string, the first color of this seaborn palette replaces `color`. Default is None
    fancy: bool, optional
        Replace the histogram bars with rounded boxes. Default is False
    fancy_rounding: float, optional
        Rounding size of the rounded boxes. Default is 0.15
    fancy_pad: float, optional
        Padding of the rounded boxes. Default is -0.004
    fancy_aspect: float, optional
        Aspect of the rounding. If not a number, the upper y-limit divided by 4 is used. Default is None

    Returns
    ----------
    matplotlib.pyplot plot

    Example
    ----------
    .. code-block:: python

        from lazyfmri import plotting
        import matplotlib.pyplot as plt
        fig,axs = plt.subplots(figsize=(7,7))
        plotting.LazyHist(
            y_data,
            axs=axs,
            kde=True,
            hist=True,
            fill=False,
            y_label2="add y_label",
            x_label2="add x_label",
            hist_kwargs={"alpha": 0.4},
            kde_kwargs={"linewidth": 4}
        )

    Notes
    ----------
    see documentation of :class:`lazyfmri.plotting.Defaults()` for formatting options
    """

    def __init__(
            self,
            data,
            x=None,
            y=None,
            figsize=(3.54, 3.54),
            kde=False,
            hist=True,
            bins="auto",
            fill=False,
            kde_kwargs=None,
            hist_kwargs=None,
            color="#cccccc",
            cmap=None,
            fancy: bool = False,
            fancy_rounding: float = 0.15,
            fancy_pad: float = -0.004,
            fancy_aspect: float = None,
            **kwargs):

        super().__init__()
        self.__dict__.update(kwargs)
        self.update_rc(self.fontname)

        # read regular arguments
        self.data = data
        self.x = x
        self.y = y
        self.figsize = figsize
        self.kde = kde

        # plot() edits these dictionaries, so work on private copies; this
        # keeps caller-owned dicts untouched and avoids sharing state
        # between instances
        self.kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
        self.hist_kwargs = {} if hist_kwargs is None else dict(hist_kwargs)
        self.hist = hist
        self.bins = bins
        self.fill = fill
        self.color = color
        self.kwargs = kwargs
        self.fancy = fancy
        self.fancy_rounding = fancy_rounding
        self.fancy_pad = fancy_pad
        self.fancy_aspect = fancy_aspect
        self.cmap = cmap

        if self.xkcd:
            with plt.xkcd():
                self.plot()
        else:
            self.plot()

        # store the kde curve if it can be retrieved; best-effort
        if self.kde:
            try:
                self.kde_ = self.return_kde()
            except Exception:
                pass

        # save
        self._save_figure(self.save_as)

    def return_kde(self):
        """Return the data of the first line on the plotted axis (the kde curve)."""
        return self.ff.get_lines()[0].get_data()

    def _set_color(self):
        """Overwrite `color` with the first color of `cmap` if `cmap` is a string."""
        if isinstance(self.cmap, str):
            self.color = sns.color_palette(self.cmap, 1)[0]

    def plot(self):
        """Draw the histogram and/or kde and apply the formatting."""
        # overwrite color argument if cmap is specified
        self._set_color()

        # set figure axis
        self._set_figure_axs()

        if self.hist:
            if self.kde:
                # ensure density for KDE
                self.hist_kwargs = utils.update_kwargs(
                    self.hist_kwargs,
                    "density",
                    True
                )

            self.vals, self.bins, self.patches = self.axs.hist(
                self.data,
                bins=self.bins,
                color=self.color,
                **self.hist_kwargs
            )

            # replace the rectangular bars by rounded boxes
            # from: https://stackoverflow.com/a/61569240
            if self.fancy:
                new_patches = []
                for patch in reversed(self.patches):
                    # max of axis divided by 4 gives nice rounding
                    if not isinstance(self.fancy_aspect, (int, float)):
                        self.fancy_aspect = patch._axes.get_ylim()[-1] / 4

                    bb = patch.get_bbox()
                    color = patch.get_facecolor()
                    p_bbox = patches.FancyBboxPatch(
                        (bb.xmin, bb.ymin),
                        abs(bb.width), abs(bb.height),
                        boxstyle=f"round,pad={self.fancy_pad},rounding_size={self.fancy_rounding}",
                        ec="none",
                        fc=color,
                        mutation_aspect=self.fancy_aspect
                    )

                    patch.remove()
                    new_patches.append(p_bbox)

                for patch in new_patches:
                    self.axs.add_patch(patch)

        if self.kde:
            # turn off legend by default
            if "legend" not in self.kde_kwargs:
                self.kde_kwargs["legend"] = False

            # the kde line always takes the color of the histogram
            self.kde_kwargs = utils.update_kwargs(
                self.kde_kwargs,
                "color",
                self.color,
                force=True
            )

            self.ff = sns.kdeplot(
                data=self.data,
                x=self.x,
                y=self.y,
                ax=self.axs,
                fill=self.fill,
                **self.kde_kwargs
            )

        # there's no self.ff if kde=False
        if hasattr(self, "ff"):
            self.active_axs = self.ff
        else:
            self.active_axs = self.axs

        # set tick params/axis width
        self._set_tick_params(self.active_axs)
        self._set_spine_width(self.active_axs)

        # set limits
        self._set_xlim(self.active_axs, self.x_lim)
        self._set_ylim(self.active_axs, self.y_lim)

        # set ticks, tick labels, tickers and axis labels for both axes
        loop_funcs = [
            "_set_ticks",
            "_set_ticklabels",
            "_set_ticker",
            "_set_axlabel"
        ]

        for x in ["x", "y"]:
            for ff, el in zip(
                    loop_funcs,
                    ["ticks", "ticklabels", "dec", "label"]):
                add_to_ax = getattr(self, f"{x}_{el}")
                getattr(self, ff)(
                    self.active_axs,
                    add_to_ax,
                    axis=x,
                    fontname=self.fontname
                )

        # remove axis labels that were not explicitly requested
        if not isinstance(self.x_label, str):
            self.active_axs.set(xlabel=None)

        if not isinstance(self.y_label, str):
            self.active_axs.set(ylabel=None)

        if hasattr(self, "trim_left"):
            trim_left = self.trim_left
        else:
            trim_left = False

        # trim_bottom is only honoured when passed explicitly as keyword
        if "trim_bottom" in self.kwargs:
            trim_bottom = self.kwargs["trim_bottom"]
        else:
            trim_bottom = False

        self._despine(
            self.active_axs,
            left=trim_left,
            bottom=trim_bottom
        )

        # draw horizontal/vertical lines with ax?line
        self._add_line(ax=self.active_axs)
        self._add_span(ax=self.active_axs)

        # set title
        self._set_title(self.active_axs, self.title)
[docs]
class LazyColorbar(Defaults):
    """LazyColorbar

    Draw a standalone colorbar on its own axis.

    Parameters
    ----------
    cmap: str, matplotlib colormap, optional
        Colormap to show. A string is resolved to a colormap with 256 entries. Default = "magma_r"
    txt: str, optional
        Label of the colorbar. It is also used as `key` in :func:`colormap_ticks`. Default = None
    vmin: float, optional
        Lower limit of the color normalization. Default = 0
    vmax: float, optional
        Upper limit of the color normalization. Default = 10
    ori: str, optional
        Orientation passed to matplotlib's `Colorbar`. For "vertical", `figsize` is swapped and the y-axis is formatted;
        for any other value the x-axis is formatted. Default = "vertical"
    ticks: list, numpy.ndarray, optional
        Tick positions. If not a list/array, they are derived with :func:`colormap_ticks`. Default = None
    labels: list, numpy.ndarray, optional
        Custom tick labels. Default = None
    flip_ticks: bool, optional
        Put the ticks on the left (vertical) or on top (otherwise). Default = False
    flip_label: bool, optional
        Put the label on the left (vertical) or on top (otherwise). Default = False
    figsize: tuple, optional
        Figure size, specified for a horizontal bar. Default = (3.54, 0.5)
    save_as: str, optional
        Passed to `_save_figure`. Default = None
    cm_nr: int, optional
        Number of ticks generated by :func:`colormap_ticks`. Default = 5
    cm_decimal: int, optional
        Number of decimals the generated ticks are rounded to. Default = 3
    cb_kws: dict, optional
        Additional arguments passed to matplotlib's `Colorbar`
    font_kws: dict, optional
        Additional arguments passed to matplotlib's `FontProperties` for the label. A "color" entry is taken out and used
        as color for the label and the tick labels (default "black"). The dictionary is copied, so the caller's object is
        never modified
    **kwargs
        Passed to :class:`Defaults` and additionally set as attributes on the instance
    """

    def __init__(
            self,
            cmap="magma_r",
            txt=None,
            vmin=0,
            vmax=10,
            ori="vertical",
            ticks=None,
            labels=None,
            flip_ticks=False,
            flip_label=False,
            figsize=(3.54, 0.5),
            save_as=None,
            cm_nr=5,
            cm_decimal=3,
            cb_kws=None,
            font_kws=None,
            **kwargs):

        self.cmap = cmap
        self.txt = txt
        self.vmin = vmin
        self.vmax = vmax
        self.ori = ori
        self.ticks = ticks
        self.flip_ticks = flip_ticks
        self.flip_label = flip_label
        self.figsize = figsize
        self.save_as = save_as
        self.cm_nr = cm_nr
        self.cm_decimal = cm_decimal
        self.labels = labels

        # "color" is popped from font_kws below; copy so that neither the
        # caller's dictionary nor a shared default is modified
        self.font_kws = {} if font_kws is None else dict(font_kws)
        cb_kws = {} if cb_kws is None else cb_kws

        # figsize is given for a horizontal bar; swap for a vertical one
        if self.ori == "vertical":
            self.figsize = (self.figsize[1], self.figsize[0])

        super().__init__(**kwargs)

        # set figure axis
        self._set_figure_axs()

        self.__dict__.update(kwargs)
        self.update_rc(self.fontname)

        if self.axs is None:
            self.fig, self.axs = plt.subplots(figsize=self.figsize)

        # make colorbase instance; pyplot.get_cmap is used because
        # matplotlib.cm.get_cmap is no longer available in recent
        # Matplotlib releases
        if isinstance(self.cmap, str):
            self.cmap = plt.get_cmap(self.cmap, 256)

        # decide ticks
        if not isinstance(self.ticks, (np.ndarray, list)):
            self.ticks = self.colormap_ticks(
                vmin=self.vmin,
                vmax=self.vmax,
                key=self.txt,
                dec=self.cm_decimal,
                nr=self.cm_nr
            )

        # plop everything in class
        mpl.colorbar.Colorbar(
            self.axs,
            orientation=self.ori,
            cmap=self.cmap,
            norm=mpl.colors.Normalize(vmin, vmax),
            label=self.txt,
            ticks=self.ticks,
            **cb_kws
        )

        # pick the axis that carries ticks/label and flip sides if requested
        if self.ori == "vertical":
            tick_ax = "y"
            if self.flip_ticks:
                self.axs.yaxis.set_ticks_position("left")

            if self.flip_label:
                self.axs.yaxis.set_label_position("left")

            text = self.axs.yaxis.label
        else:
            tick_ax = "x"
            if self.flip_ticks:
                self.axs.xaxis.set_ticks_position("top")

            if self.flip_label:
                self.axs.xaxis.set_label_position("top")

            text = self.axs.xaxis.label

        # color is not a FontProperties argument; handle it separately
        ft_color = self.font_kws.pop("color", "black")

        font = mpl.font_manager.FontProperties(
            size=self.font_size,
            **self.font_kws
        )

        text.set_font_properties(font)
        text.set_color(ft_color)

        # fix ticks
        self._set_tick_params(
            self.axs,
            labelcolor=ft_color,
        )

        # set tick labels?
        if isinstance(self.labels, (np.ndarray, list)):
            self._set_ticklabels(
                self.axs,
                self.labels,
                axis=tick_ax
            )

        # turn off frame
        self.axs.set_frame_on(False)

        # save
        self._save_figure(self.save_as)

    @staticmethod
    def colormap_ticks(
            vmin=None,
            vmax=None,
            key=None,
            dec=3,
            nr=5):
        """Generate rounded tick positions for a colorbar.

        Parameters
        ----------
        vmin: float
            Lower limit of the colorbar
        vmax: float
            Upper limit of the colorbar
        key: str, optional
            If a string containing "polar", the ticks are [-pi, 0, pi]; otherwise `nr` evenly spaced ticks between
            `vmin` and `vmax` (inclusive) are used
        dec: int, optional
            Number of decimals to round the ticks to. Default = 3
        nr: int, optional
            Number of evenly spaced ticks. Default = 5

        Returns
        ----------
        list
            Tick positions; the first/last tick is replaced via `utils.round_decimals_up`/`utils.round_decimals_down`
            when rounding moved it below `vmin`/above `vmax`
        """
        if isinstance(key, str) and "polar" in key:
            ticks = [-np.pi, 0, np.pi]
        else:
            ticks = list(np.linspace(vmin, vmax, endpoint=True, num=nr))

        # round ticks
        ticks = [round(ii, dec) for ii in ticks]

        # rounding may have pushed the outer ticks beyond the limits
        if ticks[0] < vmin:
            ticks[0] = utils.round_decimals_up(vmin, dec)

        if ticks[-1] > vmax:
            ticks[-1] = utils.round_decimals_down(vmax, dec)

        return ticks

    def show(self):
        """Attach `self.fig` to the canvas of a newly created pyplot figure."""
        fig = plt.figure()
        new_manager = fig.canvas.manager
        new_manager.canvas.figure = self.fig
        self.fig.set_canvas(new_manager.canvas)
[docs]
def fig_annot(
        fig,
        axs=None,
        y=1.01,
        x0_corr=0,
        x_corr=-0.09,
        fontsize=28,
        lower=False,
        brackets=False,
        square=False,
        **kwargs):
    """Annotate the axes of a figure with consecutive letters (A, B, C, ...).

    Parameters
    ----------
    fig: figure object
        Figure providing ``axes`` and ``dpi_scale_trans``.
    axs: list, optional
        Axes to annotate. Anything that is not a list means that all axes
        in ``fig.axes`` are annotated. Default = None
    y: float, list, optional
        Vertical position in axes fraction; a list must have one element
        per axis. Default = 1.01
    x0_corr: float, optional
        Horizontal offset for the first axis; it is divided by the width of
        the axis (window extent transformed with the inverse of
        ``fig.dpi_scale_trans``). Default = 0
    x_corr: float, int, list, optional
        Horizontal offset for all other axes, divided by the axis width in
        the same way. A list must contain one element per axis except the
        first. Default = -0.09
    fontsize: int, optional
        Font size of the annotation. Default = 28
    lower: bool, optional
        Use lowercase letters. Default = False
    brackets: bool, optional
        Wrap the letter in brackets. Default = False
    square: bool, optional
        With ``brackets=True``, use ``[]`` rather than ``()``.
        Default = False
    **kwargs
        Passed to ``ax.annotate``.

    Raises
    ------
    ValueError
        If `x_corr` or `y` is a sequence of the wrong length.
    """
    # get figure letters
    alphabet = list(string.ascii_lowercase if lower else string.ascii_uppercase)
    ax_list = axs if isinstance(axs, list) else fig.axes
    scalar_shift = isinstance(x_corr, (float, int))
    per_axis_y = isinstance(y, list)

    # make annotations
    for ix, ax in enumerate(ax_list):
        bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())

        if scalar_shift:
            use_x_pos = x_corr
        elif len(x_corr) == len(ax_list) - 1:
            # the first axis uses x0_corr, so the list is shifted by one
            use_x_pos = x_corr[ix - 1]
        else:
            raise ValueError(
                f"""List with 'x_corr'-values must match list with axes-1. x_corr contains {len(x_corr)} elements,
while {len(ax_list)-1} axes are specified""")

        # express the offset as fraction of the axis width
        pos = (x0_corr if ix == 0 else use_x_pos) / bbox.width

        if not per_axis_y:
            y_pos = y
        elif len(y) == len(ax_list):
            y_pos = y[ix]
        else:
            raise ValueError(
                f"""List with y-values must match list with axes. y contains {len(y)} elements,
while {len(ax_list)} axes are specified""")

        txt = alphabet[ix]
        if brackets:
            txt = f"[{txt}]" if square else f"({txt})"

        ax.annotate(
            txt,
            (pos, y_pos),
            fontsize=fontsize,
            xycoords="axes fraction",
            **kwargs
        )
# Lookup table with 256 rows of RGB triplets (values between 0 and 1). It is
# converted with RGBToPyCmap and registered with matplotlib under the names
# 'turbo' and 'turbo_r' (reversed row order) further down in this module.
turbo_colormap_data = np.array(
    [[0.18995, 0.07176, 0.23217],
     [0.19483, 0.08339, 0.26149],
     [0.19956, 0.09498, 0.29024],
     [0.20415, 0.10652, 0.31844],
     [0.20860, 0.11802, 0.34607],
     [0.21291, 0.12947, 0.37314],
     [0.21708, 0.14087, 0.39964],
     [0.22111, 0.15223, 0.42558],
     [0.22500, 0.16354, 0.45096],
     [0.22875, 0.17481, 0.47578],
     [0.23236, 0.18603, 0.50004],
     [0.23582, 0.19720, 0.52373],
     [0.23915, 0.20833, 0.54686],
     [0.24234, 0.21941, 0.56942],
     [0.24539, 0.23044, 0.59142],
     [0.24830, 0.24143, 0.61286],
     [0.25107, 0.25237, 0.63374],
     [0.25369, 0.26327, 0.65406],
     [0.25618, 0.27412, 0.67381],
     [0.25853, 0.28492, 0.69300],
     [0.26074, 0.29568, 0.71162],
     [0.26280, 0.30639, 0.72968],
     [0.26473, 0.31706, 0.74718],
     [0.26652, 0.32768, 0.76412],
     [0.26816, 0.33825, 0.78050],
     [0.26967, 0.34878, 0.79631],
     [0.27103, 0.35926, 0.81156],
     [0.27226, 0.36970, 0.82624],
     [0.27334, 0.38008, 0.84037],
     [0.27429, 0.39043, 0.85393],
     [0.27509, 0.40072, 0.86692],
     [0.27576, 0.41097, 0.87936],
     [0.27628, 0.42118, 0.89123],
     [0.27667, 0.43134, 0.90254],
     [0.27691, 0.44145, 0.91328],
     [0.27701, 0.45152, 0.92347],
     [0.27698, 0.46153, 0.93309],
     [0.27680, 0.47151, 0.94214],
     [0.27648, 0.48144, 0.95064],
     [0.27603, 0.49132, 0.95857],
     [0.27543, 0.50115, 0.96594],
     [0.27469, 0.51094, 0.97275],
     [0.27381, 0.52069, 0.97899],
     [0.27273, 0.53040, 0.98461],
     [0.27106, 0.54015, 0.98930],
     [0.26878, 0.54995, 0.99303],
     [0.26592, 0.55979, 0.99583],
     [0.26252, 0.56967, 0.99773],
     [0.25862, 0.57958, 0.99876],
     [0.25425, 0.58950, 0.99896],
     [0.24946, 0.59943, 0.99835],
     [0.24427, 0.60937, 0.99697],
     [0.23874, 0.61931, 0.99485],
     [0.23288, 0.62923, 0.99202],
     [0.22676, 0.63913, 0.98851],
     [0.22039, 0.64901, 0.98436],
     [0.21382, 0.65886, 0.97959],
     [0.20708, 0.66866, 0.97423],
     [0.20021, 0.67842, 0.96833],
     [0.19326, 0.68812, 0.96190],
     [0.18625, 0.69775, 0.95498],
     [0.17923, 0.70732, 0.94761],
     [0.17223, 0.71680, 0.93981],
     [0.16529, 0.72620, 0.93161],
     [0.15844, 0.73551, 0.92305],
     [0.15173, 0.74472, 0.91416],
     [0.14519, 0.75381, 0.90496],
     [0.13886, 0.76279, 0.89550],
     [0.13278, 0.77165, 0.88580],
     [0.12698, 0.78037, 0.87590],
     [0.12151, 0.78896, 0.86581],
     [0.11639, 0.79740, 0.85559],
     [0.11167, 0.80569, 0.84525],
     [0.10738, 0.81381, 0.83484],
     [0.10357, 0.82177, 0.82437],
     [0.10026, 0.82955, 0.81389],
     [0.09750, 0.83714, 0.80342],
     [0.09532, 0.84455, 0.79299],
     [0.09377, 0.85175, 0.78264],
     [0.09287, 0.85875, 0.77240],
     [0.09267, 0.86554, 0.76230],
     [0.09320, 0.87211, 0.75237],
     [0.09451, 0.87844, 0.74265],
     [0.09662, 0.88454, 0.73316],
     [0.09958, 0.89040, 0.72393],
     [0.10342, 0.89600, 0.71500],
     [0.10815, 0.90142, 0.70599],
     [0.11374, 0.90673, 0.69651],
     [0.12014, 0.91193, 0.68660],
     [0.12733, 0.91701, 0.67627],
     [0.13526, 0.92197, 0.66556],
     [0.14391, 0.92680, 0.65448],
     [0.15323, 0.93151, 0.64308],
     [0.16319, 0.93609, 0.63137],
     [0.17377, 0.94053, 0.61938],
     [0.18491, 0.94484, 0.60713],
     [0.19659, 0.94901, 0.59466],
     [0.20877, 0.95304, 0.58199],
     [0.22142, 0.95692, 0.56914],
     [0.23449, 0.96065, 0.55614],
     [0.24797, 0.96423, 0.54303],
     [0.26180, 0.96765, 0.52981],
     [0.27597, 0.97092, 0.51653],
     [0.29042, 0.97403, 0.50321],
     [0.30513, 0.97697, 0.48987],
     [0.32006, 0.97974, 0.47654],
     [0.33517, 0.98234, 0.46325],
     [0.35043, 0.98477, 0.45002],
     [0.36581, 0.98702, 0.43688],
     [0.38127, 0.98909, 0.42386],
     [0.39678, 0.99098, 0.41098],
     [0.41229, 0.99268, 0.39826],
     [0.42778, 0.99419, 0.38575],
     [0.44321, 0.99551, 0.37345],
     [0.45854, 0.99663, 0.36140],
     [0.47375, 0.99755, 0.34963],
     [0.48879, 0.99828, 0.33816],
     [0.50362, 0.99879, 0.32701],
     [0.51822, 0.99910, 0.31622],
     [0.53255, 0.99919, 0.30581],
     [0.54658, 0.99907, 0.29581],
     [0.56026, 0.99873, 0.28623],
     [0.57357, 0.99817, 0.27712],
     [0.58646, 0.99739, 0.26849],
     [0.59891, 0.99638, 0.26038],
     [0.61088, 0.99514, 0.25280],
     [0.62233, 0.99366, 0.24579],
     [0.63323, 0.99195, 0.23937],
     [0.64362, 0.98999, 0.23356],
     [0.65394, 0.98775, 0.22835],
     [0.66428, 0.98524, 0.22370],
     [0.67462, 0.98246, 0.21960],
     [0.68494, 0.97941, 0.21602],
     [0.69525, 0.97610, 0.21294],
     [0.70553, 0.97255, 0.21032],
     [0.71577, 0.96875, 0.20815],
     [0.72596, 0.96470, 0.20640],
     [0.73610, 0.96043, 0.20504],
     [0.74617, 0.95593, 0.20406],
     [0.75617, 0.95121, 0.20343],
     [0.76608, 0.94627, 0.20311],
     [0.77591, 0.94113, 0.20310],
     [0.78563, 0.93579, 0.20336],
     [0.79524, 0.93025, 0.20386],
     [0.80473, 0.92452, 0.20459],
     [0.81410, 0.91861, 0.20552],
     [0.82333, 0.91253, 0.20663],
     [0.83241, 0.90627, 0.20788],
     [0.84133, 0.89986, 0.20926],
     [0.85010, 0.89328, 0.21074],
     [0.85868, 0.88655, 0.21230],
     [0.86709, 0.87968, 0.21391],
     [0.87530, 0.87267, 0.21555],
     [0.88331, 0.86553, 0.21719],
     [0.89112, 0.85826, 0.21880],
     [0.89870, 0.85087, 0.22038],
     [0.90605, 0.84337, 0.22188],
     [0.91317, 0.83576, 0.22328],
     [0.92004, 0.82806, 0.22456],
     [0.92666, 0.82025, 0.22570],
     [0.93301, 0.81236, 0.22667],
     [0.93909, 0.80439, 0.22744],
     [0.94489, 0.79634, 0.22800],
     [0.95039, 0.78823, 0.22831],
     [0.95560, 0.78005, 0.22836],
     [0.96049, 0.77181, 0.22811],
     [0.96507, 0.76352, 0.22754],
     [0.96931, 0.75519, 0.22663],
     [0.97323, 0.74682, 0.22536],
     [0.97679, 0.73842, 0.22369],
     [0.98000, 0.73000, 0.22161],
     [0.98289, 0.72140, 0.21918],
     [0.98549, 0.71250, 0.21650],
     [0.98781, 0.70330, 0.21358],
     [0.98986, 0.69382, 0.21043],
     [0.99163, 0.68408, 0.20706],
     [0.99314, 0.67408, 0.20348],
     [0.99438, 0.66386, 0.19971],
     [0.99535, 0.65341, 0.19577],
     [0.99607, 0.64277, 0.19165],
     [0.99654, 0.63193, 0.18738],
     [0.99675, 0.62093, 0.18297],
     [0.99672, 0.60977, 0.17842],
     [0.99644, 0.59846, 0.17376],
     [0.99593, 0.58703, 0.16899],
     [0.99517, 0.57549, 0.16412],
     [0.99419, 0.56386, 0.15918],
     [0.99297, 0.55214, 0.15417],
     [0.99153, 0.54036, 0.14910],
     [0.98987, 0.52854, 0.14398],
     [0.98799, 0.51667, 0.13883],
     [0.98590, 0.50479, 0.13367],
     [0.98360, 0.49291, 0.12849],
     [0.98108, 0.48104, 0.12332],
     [0.97837, 0.46920, 0.11817],
     [0.97545, 0.45740, 0.11305],
     [0.97234, 0.44565, 0.10797],
     [0.96904, 0.43399, 0.10294],
     [0.96555, 0.42241, 0.09798],
     [0.96187, 0.41093, 0.09310],
     [0.95801, 0.39958, 0.08831],
     [0.95398, 0.38836, 0.08362],
     [0.94977, 0.37729, 0.07905],
     [0.94538, 0.36638, 0.07461],
     [0.94084, 0.35566, 0.07031],
     [0.93612, 0.34513, 0.06616],
     [0.93125, 0.33482, 0.06218],
     [0.92623, 0.32473, 0.05837],
     [0.92105, 0.31489, 0.05475],
     [0.91572, 0.30530, 0.05134],
     [0.91024, 0.29599, 0.04814],
     [0.90463, 0.28696, 0.04516],
     [0.89888, 0.27824, 0.04243],
     [0.89298, 0.26981, 0.03993],
     [0.88691, 0.26152, 0.03753],
     [0.88066, 0.25334, 0.03521],
     [0.87422, 0.24526, 0.03297],
     [0.86760, 0.23730, 0.03082],
     [0.86079, 0.22945, 0.02875],
     [0.85380, 0.22170, 0.02677],
     [0.84662, 0.21407, 0.02487],
     [0.83926, 0.20654, 0.02305],
     [0.83172, 0.19912, 0.02131],
     [0.82399, 0.19182, 0.01966],
     [0.81608, 0.18462, 0.01809],
     [0.80799, 0.17753, 0.01660],
     [0.79971, 0.17055, 0.01520],
     [0.79125, 0.16368, 0.01387],
     [0.78260, 0.15693, 0.01264],
     [0.77377, 0.15028, 0.01148],
     [0.76476, 0.14374, 0.01041],
     [0.75556, 0.13731, 0.00942],
     [0.74617, 0.13098, 0.00851],
     [0.73661, 0.12477, 0.00769],
     [0.72686, 0.11867, 0.00695],
     [0.71692, 0.11268, 0.00629],
     [0.70680, 0.10680, 0.00571],
     [0.69650, 0.10102, 0.00522],
     [0.68602, 0.09536, 0.00481],
     [0.67535, 0.08980, 0.00449],
     [0.66449, 0.08436, 0.00424],
     [0.65345, 0.07902, 0.00408],
     [0.64223, 0.07380, 0.00401],
     [0.63082, 0.06868, 0.00401],
     [0.61923, 0.06367, 0.00410],
     [0.60746, 0.05878, 0.00427],
     [0.59550, 0.05399, 0.00453],
     [0.58336, 0.04931, 0.00486],
     [0.57103, 0.04474, 0.00529],
     [0.55852, 0.04028, 0.00579],
     [0.54583, 0.03593, 0.00638],
     [0.53295, 0.03169, 0.00705],
     [0.51989, 0.02756, 0.00780],
     [0.50664, 0.02354, 0.00863],
     [0.49321, 0.01963, 0.00955],
     [0.47960, 0.01583, 0.01055]]
)
[docs]
def RGBToPyCmap(rgbdata):
    """Convert a table of RGB rows into a per-channel anchor dictionary.

    Parameters
    ----------
    rgbdata: numpy.ndarray
        2D array with one color per row; columns 0, 1 and 2 are read as
        red, green and blue.

    Returns
    ----------
    dict
        Dictionary with keys 'red', 'green' and 'blue'. Each value is a
        list with one ``(position, value, value)`` tuple per row, where the
        positions are evenly spaced between 0 and 1 (inclusive).
    """
    positions = np.linspace(0, 1, rgbdata.shape[0])

    def _channel(column):
        # one anchor per row; the channel value is repeated in both slots
        return [
            (pos, val, val)
            for pos, val in zip(positions, rgbdata[:, column])
        ]

    return {
        'red': _channel(0),
        'green': _channel(1),
        'blue': _channel(2)
    }
# Register the turbo lookup table with matplotlib, once in the original and
# once in reversed row order. Registration is best-effort: a failure is
# ignored so that importing this module never breaks on it.
# NOTE(review): presumably this fails when the name is already known to
# matplotlib or when `plt.register_cmap` is not available in the installed
# version -- confirm. Only `Exception` is caught, so KeyboardInterrupt and
# SystemExit still propagate during import.
mpl_data = RGBToPyCmap(turbo_colormap_data)
try:
    plt.register_cmap(
        name='turbo',
        data=mpl_data,
        lut=turbo_colormap_data.shape[0]
    )
except Exception:
    pass

mpl_data_r = RGBToPyCmap(turbo_colormap_data[::-1, :])
try:
    plt.register_cmap(
        name='turbo_r',
        data=mpl_data_r,
        lut=turbo_colormap_data.shape[0]
    )
except Exception:
    pass
[docs]
def make_wm_pial_ticks(
        data,
        start=0,
        end=100,
        step=25,
        force_int=True):
    """Create five tick positions along the first axis of `data` with labels.

    Parameters
    ----------
    data: numpy.ndarray
        Array whose first dimension defines the tick positions: 0, a
        quarter, half, half plus a quarter (integer divisions), and the
        full length.
    start: int, float, optional
        First label. Default = 0
    end: int, float, optional
        Labels are generated with ``np.arange(start, end*1.1, step)``.
        Default = 100
    step: int, float, optional
        Spacing between labels. Default = 25
    force_int: bool, optional
        Round the labels and convert them to `int`. Default = True

    Returns
    ----------
    dict
        Dictionary with keys "ticks" and "labels".

    Raises
    ----------
    ValueError
        If the number of generated labels differs from the number of ticks.
    """
    n_samples = data.shape[0]
    quarter, half = n_samples // 4, n_samples // 2
    x_ticks = [0, quarter, half, (half + quarter), n_samples]

    x_labels = list(np.arange(start, end*1.1, step=step))
    if len(x_labels) != len(x_ticks):
        raise ValueError(
            f"Length of ticks ({len(x_ticks)}) {x_ticks} != length of labels ({len(x_labels)}) {x_labels}")

    if force_int:
        x_labels = [int(round(lbl, 0)) for lbl in x_labels]

    return dict(ticks=x_ticks, labels=x_labels)
[docs]
def annotate_cortical_ribbon(
        axs,
        pial_pos=(0.02, 0.92),
        wm_pos=(0.02, 0.02),
        lbls=["pial", "wm"],
        **kwargs):
    """Add two text annotations to an axis, by default "pial" and "wm".

    Parameters
    ----------
    axs: axis object
        Axis on which ``annotate`` is called.
    pial_pos: tuple, optional
        Position of the first label. Default = (0.02, 0.92)
    wm_pos: tuple, optional
        Position of the second label. Default = (0.02, 0.02)
    lbls: list, optional
        Texts for the first and second position. Default = ["pial", "wm"]
    **kwargs
        Passed to ``axs.annotate``. ``xycoords`` is set to
        "axes fraction" unless specified.
    """
    kwargs.setdefault("xycoords", "axes fraction")

    positions = [pial_pos, wm_pos]
    for where, text in zip(positions, lbls):
        axs.annotate(text, where, **kwargs)
[docs]
def add_axvspan(
        axs,
        loc=(0, 2),
        color="#cccccc",
        alpha=0.3,
        ymin=0,
        ymax=1,
        **kwargs):
    """Call ``axs.axvspan`` with the given location and styling.

    Parameters
    ----------
    axs: axis object
        Axis on which ``axvspan`` is called.
    loc: tuple, optional
        Unpacked into the positional arguments of ``axvspan``.
        Default = (0, 2)
    color: str, optional
        Color of the span. Default = "#cccccc"
    alpha: float, optional
        Opacity of the span. Default = 0.3
    ymin: float, optional
        Passed as ``ymin``. Default = 0
    ymax: float, optional
        Passed as ``ymax``. Default = 1
    **kwargs
        Additional arguments passed to ``axvspan``.

    Returns
    ----------
    The input `axs`.
    """
    span_kws = dict(ymin=ymin, ymax=ymax, alpha=alpha, color=color)
    axs.axvspan(*loc, **span_kws, **kwargs)
    return axs