# ======================================================================================
# Copyright (©) 2015-2025 LCS - Laboratoire Catalyse et Spectrochimie, Caen, France.
# CeCILL-B FREE SOFTWARE LICENSE AGREEMENT
# See full LICENSE agreement in the root directory.
# ======================================================================================
"""Module containing multiplot function(s)."""
__all__ = [
"multiplot",
"multiplot_map",
"multiplot_stack",
"multiplot_image",
"multiplot_lines",
"multiplot_scatter",
"multiplot_with_transposed",
"plot_with_transposed",
]
__dataset_methods__ = []
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
# TODO: tight_layout module is deprecated
from matplotlib._tight_layout import get_subplotspec_list # get_renderer,
from matplotlib._tight_layout import get_tight_layout_figure # get_renderer,
from spectrochempy.utils.misc import is_sequence
from spectrochempy.utils.plots import _Axes
# from spectrochempy.core import preferences, project_preferences
[docs]
def multiplot_scatter(datasets, **kwargs):
"""
Plot a multiplot with 1D scatter type plots.
Alias of multiplot (with `method` argument set to `scatter` .
"""
kwargs["method"] = "scatter"
return multiplot(datasets, **kwargs)
[docs]
def multiplot_lines(datasets, **kwargs):
"""
Plot a multiplot with 1D linetype plots.
Alias of multiplot (with `method` argument set to `lines` .
"""
kwargs["method"] = "lines"
return multiplot(datasets, **kwargs)
[docs]
def multiplot_stack(datasets, **kwargs):
"""
Plot a multiplot with 2D stack type plots.
Alias of multiplot (with `method` argument set to `stack` .
"""
kwargs["method"] = "stack"
return multiplot(datasets, **kwargs)
[docs]
def multiplot_map(datasets, **kwargs):
"""
Plot a multiplot with 2D map type plots.
Alias of multiplot (with `method` argument set to `map` .
"""
kwargs["method"] = "map"
return multiplot(datasets, **kwargs)
[docs]
def multiplot_image(datasets, **kwargs):
"""
Plot a multiplot with 2D image type plots.
Alias of multiplot (with `method` argument set to `image` .
"""
kwargs["method"] = "image"
return multiplot(datasets, **kwargs)
# with transpose plot -----------------------------------------------------------------
[docs]
def plot_with_transposed(dataset, **kwargs):
"""
Plot a 2D dataset as a stacked plot with its transposition in a second axe.
Alias of plot_2D (with `method` argument set to `with_transposed` ).
"""
kwargs["method"] = "with_transposed"
return multiplot(dataset, **kwargs)
multiplot_with_transposed = plot_with_transposed
[docs]
def multiplot(
datasets=None,
labels=None,
nrow=1,
ncol=1,
method="stack",
sharex=False,
sharey=False,
sharez=False,
colorbar=False,
suptitle=None,
suptitle_color="k",
mpl_event=True,
**kwargs,
):
"""
Generate a figure with multiple axes arranged in array (n rows, n columns).
Parameters
----------
datasets : nddataset or list of nddataset
Datasets to plot.
labels : list of str
The labels that will be used as title of each axes.
method : str, default to `map` for 2D and `lines` for 1D data
Type of plot to draw in all axes (`lines` , `scatter` , `stack` , `map`
,`image` or `with_transposed` ).
nrows, ncols : int, default=1
Number of rows/cols of the subplot grid. ncol*nrow must be equal
to the number of datasets to plot.
sharex, sharey : bool or {'none', 'all', 'row', 'col'}, default=False
Controls sharing of properties among x (`sharex` ) or y (`sharey` )
axes::
- True or 'all' : x- or y-axis will be shared among all subplots.
- False or 'none' : each subplot x- or y-axis will be independent.
- 'row' : each subplot row will share an x- or y-axis.
- 'col' : each subplot column will share an x- or y-axis.
When subplots have a shared x-axis along a column, only the x tick
labels of the bottom subplot are visible. Similarly, when
subplots have a shared y-axis along a row, only the y tick labels
of the first column subplot are visible.
sharez : bool or {'none', 'all', 'row', 'col'}, default=False
Equivalent to sharey for 1D plot.
for 2D plot, z is the intensity axis (i.e., contour levels for maps or
the vertical axis for stack plot), y is the third axis.
figsize : 2-tuple of floats
`(width, height)` tuple in inches.
dpi : float
Dots per inch
facecolor : color
The figure patch facecolor; defaults to rc `figure.facecolor` .
edgecolor : color
The figure patch edge color; defaults to rc `figure.edgecolor` .
linewidth : float
The figure patch edge linewidth; the default linewidth of the frame.
frameon : bool
If `False` , suppress drawing the figure frame.
left : float in the [0-1] interval
The left side of the subplots of the figure.
right : float in the [0-1] interval
The right side of the subplots of the figure.
bottom : float in the [0-1] interval
The bottom of the subplots of the figure.
top : float in the [0-1] interval
The top of the subplots of the figure.
wspace : float in the [0-1] interval
The amount of width reserved for blank space between subplots,
expressed as a fraction of the average axis width.
hspace : float in the [0-1] interval
The amount of height reserved for white space between subplots,
expressed as a fraction of the average axis height.
suptitle : str
Title of the figure to display on top.
suptitle_color : color
Color of the subtitles
"""
# some basic checking
# ------------------------------------------------------------------------
if labels is None:
labels = []
if datasets is None:
datasets = []
show_transposed = False
if method in "with_transposed":
show_transposed = True
method = "stack"
nrow = 2
ncol = 1
datasets = [datasets, datasets] # we need to datasets
sharez = True
single = False
if not is_sequence(datasets):
single = True
datasets = [datasets] # make a list
if len(datasets) < nrow * ncol and not show_transposed:
# not enough datasets given in this list.
raise ValueError("Not enough datasets given in this list")
# if labels and len(labels) != len(datasets):
# # not enough labels given in this list.
# raise ValueError('Not enough labels given in this list')
if nrow == ncol and nrow == 1 and not show_transposed and single:
# obviously a single plot, return it
return datasets[0].plot(**kwargs)
if nrow * ncol < len(datasets):
nrow = ncol = len(datasets) // 2
if nrow * ncol < len(datasets):
ncol += 1
ndims = {dataset.ndim for dataset in datasets}
if len(ndims) > 1:
raise NotImplementedError("mixed dataset shape.")
ndim = list(ndims)[0]
# create the subplots and plot the ndarrays
# ------------------------------------------------------------------------
mpl.rcParams["figure.autolayout"] = False
figsize = kwargs.pop("figsize", None)
dpi = kwargs.pop("dpi", 150)
fig = kwargs.pop("fig", None)
if fig is None:
fig = plt.figure(figsize=figsize, dpi=dpi)
else:
fig.clf()
fig.set_size_inches(*figsize)
fig.rcParams = plt.rcParams.copy() # save params used for this figure
if suptitle is not None:
fig.suptitle(suptitle, color=suptitle_color)
# axes is dictionary with keys such as 'axe12', where the fist number
# is the row and the second the column
axes = {}
# limits
xlims = []
ylims = []
zlims = []
if sharex not in [None, True, False, "all", "col"]:
raise ValueError(
"invalid option for sharex. Should be"
" among (None, False, True, 'all' or 'col')",
)
if sharex:
sharex = "all"
if ndim == 1:
sharez = False
textsharey = "sharey"
textsharez = "sharez"
if method in ["stack"]:
sharez, sharey = sharey, sharez # we echange them
zlims, ylims = ylims, zlims
# for our internal needs as only sharex and sharey are recognized by
# matplotlib subplots
textsharey = "sharez"
textsharez = "sharey"
if sharey not in [None, False, True, "all", "col"]:
raise ValueError(
f"invalid option for {textsharey}. Should be"
" among (None, False, True, 'all' or 'row')",
)
if sharez not in [None, False, True, "all", "col", "row"]:
raise ValueError(
f"invalid option for {textsharez}. Should be"
" among (None, False, True, "
"'all', 'row' or 'col')",
)
if sharey:
sharey = "all"
if sharez:
sharez = "all"
for irow in range(nrow):
for icol in range(ncol):
idx = irow * ncol + icol
dataset = datasets[idx]
try:
label = labels[idx]
except Exception:
label = ""
_sharex = None
_sharey = None
_sharez = None
# on the type of the plot and
if (
(irow == icol and irow == 0)
or (sharex == "col" and irow == 0)
or (sharey == "row" and icol == 0)
):
ax = _Axes(fig, nrow, ncol, irow * ncol + icol + 1)
ax = fig.add_subplot(ax)
else:
if sharex == "all":
_sharex = axes["axe11"]
elif sharex == "col":
_sharex = axes[f"axe1{icol + 1}"]
if sharey == "all":
_sharey = axes["axe11"]
elif sharey == "row":
_sharey = axes[f"axe{irow + 1}1"]
# in the last dimension
if sharez == "all":
_sharez = axes["axe11"]
elif sharez == "row":
_sharez = axes[f"axe{irow + 1}1"]
elif sharez == "col":
_sharez = axes[f"axe1{icol + 1}"]
ax = _Axes(fig, nrow, ncol, idx + 1, sharex=_sharex, sharey=_sharey)
ax = fig.add_subplot(ax)
ax._sharez = _sharez # we add a new share info to the ax.
# which will be useful for the interactive masks
ax.name = f"axe{irow + 1}{icol + 1}"
axes[ax.name] = ax
if icol > 0 and sharey:
# hide the redondant ticklabels on left side of interior figures
plt.setp(axes[ax.name].get_yticklabels(), visible=False)
axes[ax.name].yaxis.set_tick_params(
which="both",
labelleft=False,
labelright=False,
)
axes[ax.name].yaxis.offsetText.set_visible(False)
if irow < nrow - 1 and sharex:
# hide the bottom ticklabels of interior rows
plt.setp(axes[ax.name].get_xticklabels(), visible=False)
axes[ax.name].xaxis.set_tick_params(
which="both",
labelbottom=False,
labeltop=False,
)
axes[ax.name].xaxis.offsetText.set_visible(False)
transposed = bool(show_transposed and irow == 1)
dataset.plot(
method=method,
ax=ax,
clear=False,
autolayout=False,
colorbar=colorbar,
data_transposed=transposed,
**kwargs,
)
ax.set_title(label, fontsize=8)
if sharex and irow < nrow - 1:
ax.xaxis.label.set_visible(False)
if sharey and icol > 0:
ax.yaxis.label.set_visible(False)
xlims.append(ax.get_xlim())
ylims.append(ax.get_ylim())
xrev = (ax.get_xlim()[1] - ax.get_xlim()[0]) < 0
# yrev = (ax.get_ylim()[1] - ax.get_ylim()[0]) < 0
# TODO: add a common color bar (set vmin and vmax using zlims)
amp = np.ptp(np.array(ylims))
ylim = [np.min(np.array(ylims) - amp * 0.01), np.max(np.array(ylims)) + amp * 0.01]
for ax in axes.values():
ax.set_ylim(ylim)
# if yrev:
# ylim = ylim[::-1]
# amp = np.ptp(np.array(xlims))
if not show_transposed:
xlim = [np.min(np.array(xlims)), np.max(np.array(xlims))]
if xrev:
xlim = xlim[::-1]
for ax in axes.values():
ax.set_xlim(xlim)
def do_tight_layout(fig, axes, suptitle, **kwargs):
# tight_layout
renderer = fig.canvas.get_renderer()
axeslist = list(axes.values())
subplots_list = list(get_subplotspec_list(axeslist))
kw = get_tight_layout_figure(
fig,
axeslist,
subplots_list,
renderer,
# pad=1.1,
h_pad=0,
w_pad=0,
rect=None,
)
left = kwargs.get("left", kw["left"])
bottom = kwargs.get("bottom", kw["bottom"])
right = kwargs.get("right", kw["right"])
top = kw["top"]
if suptitle:
top = top * 0.95
top = kwargs.get("top", top)
ws = kwargs.get("wspace", kw.get("wspace", 0) * 1.1)
hs = kwargs.get("hspace", kw.get("hspace", 0) * 1.1)
plt.subplots_adjust(
left=left,
bottom=bottom,
right=right,
top=top,
wspace=ws,
hspace=hs,
)
do_tight_layout(fig, axes, suptitle, **kwargs)
if mpl_event:
# make an event that will trigger subplot adjust each time the mouse leave
# or enter the axes or figure
def _onenter(event):
do_tight_layout(fig, axes, suptitle, **kwargs)
fig.canvas.draw()
fig.canvas.mpl_connect("axes_enter_event", _onenter)
fig.canvas.mpl_connect("axes_leave_event", _onenter)
fig.canvas.mpl_connect("figure_enter_event", _onenter)
fig.canvas.mpl_connect("figure_leave_event", _onenter)
return axes