# ======================================================================================
# Copyright (©) 2015-2025 LCS - Laboratoire Catalyse et Spectrochimie, Caen, France.
# CeCILL-B FREE SOFTWARE LICENSE AGREEMENT
# See full LICENSE agreement in the root directory.
# ======================================================================================
import textwrap
import matplotlib as mpl
import matplotlib.axes as maxes
import mpl_toolkits.mplot3d.axes3d as maxes3D # noqa: N812
import numpy as np
from matplotlib import pyplot as plt
from pint import __version__
def _pint_version_number(version):
    """
    Reduce a pint version string to a single comparable integer.

    Parameters
    ----------
    version : str
        Version string such as ``"0.24.4"`` or ``"0.24rc1"``.

    Returns
    -------
    int
        ``major * 100 + minor``. For every ``0.x`` release this is simply the
        minor number (the value historically stored in `pint_version`), while a
        ``1.x`` or later release always compares greater than any ``0.x`` one.

    Raises
    ------
    ValueError
        If the string does not start with a numeric component.
    """
    import re

    # Only the leading digits of each component are used, so suffixes such as
    # "rc1" or ".dev0+g1234" do not break the conversion.
    match = re.match(r"(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"Cannot parse pint version string: {version!r}")
    major = int(match.group(1))
    minor = int(match.group(2) or 0)
    return major * 100 + minor


pint_version = _pint_version_number(__version__)
# NOTE(review): `subplot_class_factory` may not exist in recent matplotlib
# releases - confirm against the matplotlib versions this package supports.
@maxes.subplot_class_factory
class _Axes(maxes.Axes):
    """
    Subclass of matplotlib Axes class.

    Every plotting, annotation and limit-setting method below is re-declared only
    to wrap it with the `remove_args_units` decorator; each body forwards all
    positional and keyword arguments to the inherited method of the same name.
    """

    # Imported in the class body so that the name is available as a decorator
    # while the methods below are being defined.
    from spectrochempy.core.units import remove_args_units

    def __init__(self, *args, **kwargs):
        """Initialize the axes by forwarding all arguments to the parent class."""
        super().__init__(*args, **kwargs)

    # def draw(self, renderer):
    # # # with plt.rc_context({"something": self.xxx}):
    # return super().draw(renderer)

    @remove_args_units
    def plot(self, *args, **kwargs):
        """Call the inherited `plot` through `remove_args_units`."""
        return super().plot(*args, **kwargs)

    @remove_args_units
    def errorbar(self, *args, **kwargs):
        """Call the inherited `errorbar` through `remove_args_units`."""
        return super().errorbar(*args, **kwargs)

    @remove_args_units
    def scatter(self, *args, **kwargs):
        """Call the inherited `scatter` through `remove_args_units`."""
        return super().scatter(*args, **kwargs)

    @remove_args_units
    def plot_date(self, *args, **kwargs):
        """Call the inherited `plot_date` through `remove_args_units`."""
        return super().plot_date(*args, **kwargs)

    @remove_args_units
    def step(self, *args, **kwargs):
        """Call the inherited `step` through `remove_args_units`."""
        return super().step(*args, **kwargs)

    @remove_args_units
    def loglog(self, *args, **kwargs):
        """Call the inherited `loglog` through `remove_args_units`."""
        return super().loglog(*args, **kwargs)

    @remove_args_units
    def semilogx(self, *args, **kwargs):
        """Call the inherited `semilogx` through `remove_args_units`."""
        return super().semilogx(*args, **kwargs)

    @remove_args_units
    def semilogy(self, *args, **kwargs):
        """Call the inherited `semilogy` through `remove_args_units`."""
        return super().semilogy(*args, **kwargs)

    @remove_args_units
    def fill_between(self, *args, **kwargs):
        """Call the inherited `fill_between` through `remove_args_units`."""
        return super().fill_between(*args, **kwargs)

    @remove_args_units
    def fill_betweenx(self, *args, **kwargs):
        """Call the inherited `fill_betweenx` through `remove_args_units`."""
        return super().fill_betweenx(*args, **kwargs)

    @remove_args_units
    def bar(self, *args, **kwargs):
        """Call the inherited `bar` through `remove_args_units`."""
        return super().bar(*args, **kwargs)

    @remove_args_units
    def barh(self, *args, **kwargs):
        """Call the inherited `barh` through `remove_args_units`."""
        return super().barh(*args, **kwargs)

    @remove_args_units
    def bar_label(self, *args, **kwargs):
        """Call the inherited `bar_label` through `remove_args_units`."""
        return super().bar_label(*args, **kwargs)

    @remove_args_units
    def stem(self, *args, **kwargs):
        """Call the inherited `stem` through `remove_args_units`."""
        return super().stem(*args, **kwargs)

    @remove_args_units
    def eventplot(self, *args, **kwargs):
        """Call the inherited `eventplot` through `remove_args_units`."""
        return super().eventplot(*args, **kwargs)

    @remove_args_units
    def pie(self, *args, **kwargs):
        """Call the inherited `pie` through `remove_args_units`."""
        return super().pie(*args, **kwargs)

    @remove_args_units
    def stackplot(self, *args, **kwargs):
        """Call the inherited `stackplot` through `remove_args_units`."""
        return super().stackplot(*args, **kwargs)

    @remove_args_units
    def broken_barh(self, *args, **kwargs):
        """Call the inherited `broken_barh` through `remove_args_units`."""
        return super().broken_barh(*args, **kwargs)

    @remove_args_units
    def vlines(self, *args, **kwargs):
        """Call the inherited `vlines` through `remove_args_units`."""
        return super().vlines(*args, **kwargs)

    @remove_args_units
    def hlines(self, *args, **kwargs):
        """Call the inherited `hlines` through `remove_args_units`."""
        return super().hlines(*args, **kwargs)

    @remove_args_units
    def fill(self, *args, **kwargs):
        """Call the inherited `fill` through `remove_args_units`."""
        return super().fill(*args, **kwargs)

    @remove_args_units
    def axhline(self, *args, **kwargs):
        """Call the inherited `axhline` through `remove_args_units`."""
        return super().axhline(*args, **kwargs)

    @remove_args_units
    def axhspan(self, *args, **kwargs):
        """Call the inherited `axhspan` through `remove_args_units`."""
        return super().axhspan(*args, **kwargs)

    @remove_args_units
    def axvline(self, *args, **kwargs):
        """Call the inherited `axvline` through `remove_args_units`."""
        return super().axvline(*args, **kwargs)

    @remove_args_units
    def axvspan(self, *args, **kwargs):
        """Call the inherited `axvspan` through `remove_args_units`."""
        return super().axvspan(*args, **kwargs)

    @remove_args_units
    def axline(self, *args, **kwargs):
        """Call the inherited `axline` through `remove_args_units`."""
        return super().axline(*args, **kwargs)

    @remove_args_units
    def acorr(self, *args, **kwargs):
        """Call the inherited `acorr` through `remove_args_units`."""
        return super().acorr(*args, **kwargs)

    @remove_args_units
    def angle_spectrum(self, *args, **kwargs):
        """Call the inherited `angle_spectrum` through `remove_args_units`."""
        return super().angle_spectrum(*args, **kwargs)

    @remove_args_units
    def cohere(self, *args, **kwargs):
        """Call the inherited `cohere` through `remove_args_units`."""
        return super().cohere(*args, **kwargs)

    @remove_args_units
    def csd(self, *args, **kwargs):
        """Call the inherited `csd` through `remove_args_units`."""
        return super().csd(*args, **kwargs)

    @remove_args_units
    def magnitude_spectrum(self, *args, **kwargs):
        """Call the inherited `magnitude_spectrum` through `remove_args_units`."""
        return super().magnitude_spectrum(*args, **kwargs)

    @remove_args_units
    def phase_spectrum(self, *args, **kwargs):
        """Call the inherited `phase_spectrum` through `remove_args_units`."""
        return super().phase_spectrum(*args, **kwargs)

    @remove_args_units
    def psd(self, *args, **kwargs):
        """Call the inherited `psd` through `remove_args_units`."""
        return super().psd(*args, **kwargs)

    @remove_args_units
    def specgram(self, *args, **kwargs):
        """Call the inherited `specgram` through `remove_args_units`."""
        return super().specgram(*args, **kwargs)

    @remove_args_units
    def xcorr(self, *args, **kwargs):
        """Call the inherited `xcorr` through `remove_args_units`."""
        return super().xcorr(*args, **kwargs)

    @remove_args_units
    def boxplot(self, *args, **kwargs):
        """Call the inherited `boxplot` through `remove_args_units`."""
        return super().boxplot(*args, **kwargs)

    @remove_args_units
    def violinplot(self, *args, **kwargs):
        """Call the inherited `violinplot` through `remove_args_units`."""
        return super().violinplot(*args, **kwargs)

    @remove_args_units
    def violin(self, *args, **kwargs):
        """Call the inherited `violin` through `remove_args_units`."""
        return super().violin(*args, **kwargs)

    @remove_args_units
    def bxp(self, *args, **kwargs):
        """Call the inherited `bxp` through `remove_args_units`."""
        return super().bxp(*args, **kwargs)

    @remove_args_units
    def hexbin(self, *args, **kwargs):
        """Call the inherited `hexbin` through `remove_args_units`."""
        return super().hexbin(*args, **kwargs)

    @remove_args_units
    def hist(self, *args, **kwargs):
        """Call the inherited `hist` through `remove_args_units`."""
        return super().hist(*args, **kwargs)

    @remove_args_units
    def hist2d(self, *args, **kwargs):
        """Call the inherited `hist2d` through `remove_args_units`."""
        return super().hist2d(*args, **kwargs)

    @remove_args_units
    def stairs(self, *args, **kwargs):
        """Call the inherited `stairs` through `remove_args_units`."""
        return super().stairs(*args, **kwargs)

    @remove_args_units
    def contour(self, *args, **kwargs):
        """Call the inherited `contour` through `remove_args_units`."""
        return super().contour(*args, **kwargs)

    @remove_args_units
    def contourf(self, *args, **kwargs):
        """Call the inherited `contourf` through `remove_args_units`."""
        return super().contourf(*args, **kwargs)

    @remove_args_units
    def imshow(self, *args, **kwargs):
        """Call the inherited `imshow` through `remove_args_units`."""
        return super().imshow(*args, **kwargs)

    @remove_args_units
    def matshow(self, *args, **kwargs):
        """Call the inherited `matshow` through `remove_args_units`."""
        return super().matshow(*args, **kwargs)

    @remove_args_units
    def pcolor(self, *args, **kwargs):
        """Call the inherited `pcolor` through `remove_args_units`."""
        return super().pcolor(*args, **kwargs)

    @remove_args_units
    def pcolorfast(self, *args, **kwargs):
        """Call the inherited `pcolorfast` through `remove_args_units`."""
        return super().pcolorfast(*args, **kwargs)

    @remove_args_units
    def pcolormesh(self, *args, **kwargs):
        """Call the inherited `pcolormesh` through `remove_args_units`."""
        return super().pcolormesh(*args, **kwargs)

    @remove_args_units
    def spy(self, *args, **kwargs):
        """Plot a spy."""
        return super().spy(*args, **kwargs)

    @remove_args_units
    def tripcolor(self, *args, **kwargs):
        """Plot a tripcolor."""
        return super().tripcolor(*args, **kwargs)

    @remove_args_units
    def triplot(self, *args, **kwargs):
        """Plot a triplot."""
        return super().triplot(*args, **kwargs)

    @remove_args_units
    def tricontour(self, *args, **kwargs):
        """Plot a tricontour."""
        return super().tricontour(*args, **kwargs)

    @remove_args_units
    def tricontourf(self, *args, **kwargs):
        """Plot a tricontourf."""
        return super().tricontourf(*args, **kwargs)

    @remove_args_units
    def annotate(self, *args, **kwargs):
        """Add an annotation to the axes."""
        return super().annotate(*args, **kwargs)

    @remove_args_units
    def text(self, *args, **kwargs):
        """Add text to the axes."""
        return super().text(*args, **kwargs)

    @remove_args_units
    def table(self, *args, **kwargs):
        """Add a table to the axes."""
        return super().table(*args, **kwargs)

    @remove_args_units
    def arrow(self, *args, **kwargs):
        """Add an arrow to the axes."""
        return super().arrow(*args, **kwargs)

    @remove_args_units
    def set_xlim(self, *args, **kwargs):
        """Set the x-axis limits."""
        return super().set_xlim(*args, **kwargs)

    @remove_args_units
    def set_ylim(self, *args, **kwargs):
        """Set the y-axis limits."""
        return super().set_ylim(*args, **kwargs)
class _Axes3D(maxes3D.Axes3D):
    """
    Subclass of matplotlib Axes3D class.

    Only `plot_surface` is re-declared, so that it is wrapped by the
    `remove_args_units` decorator before delegating to the parent class.
    """

    # Imported in the class body so that the name is available as a decorator
    # while the methods below are being defined.
    from spectrochempy.core.units import remove_args_units

    def __init__(self, *args, **kwargs):
        """Initialize the 3D axes."""
        super().__init__(*args, **kwargs)

    @remove_args_units
    def plot_surface(self, *args, **kwargs):
        """Plot a surface."""
        return super().plot_surface(*args, **kwargs)
def plot_method(type, doc):
    """
    Select a plot method from the function name.

    Build a decorator that replaces a stub function named ``plot_<method>`` by a
    dispatcher. The ``<method>`` suffix of the stub's name is passed as the
    ``method`` keyword to the actual plotting routine.

    Parameters
    ----------
    type : str
        Suffix of the dataset plotting method to call for datasets with at least
        two dimensions, i.e. ``dataset.plot_<type>`` is invoked. (The name shadows
        the builtin but is kept for backward compatibility.)
    doc : str
        Text inserted under "Other Parameters" in the generated docstring.

    Returns
    -------
    callable
        A decorator. The function it returns has the signature
        ``wrapper(dataset, *args, **kwargs)`` and dispatches as follows:

        * ``dataset.ndim < 2``: ``plot_1D(dataset, *args, method=..., **kwargs)``;
        * ``use_plotly`` keyword true: ``dataset.plotly(method=..., **kwargs)``
          (positional arguments are not forwarded on this path);
        * otherwise: ``dataset.plot_<type>(*args, method=..., **kwargs)``.

        A ``method`` keyword supplied by the caller is always discarded, since the
        method is fixed by the decorated function's name.
    """
    import functools

    def decorator_plot_method(func):
        # "plot_stack" -> "stack"
        method = func.__name__.split("plot_")[-1]

        @functools.wraps(func)
        def wrapper(dataset, *args, **kwargs):
            # Drop any caller-supplied "method" on every path: forwarding it
            # together with the explicit keyword would raise a TypeError.
            kwargs.pop("method", None)
            if dataset.ndim < 2:
                from spectrochempy.core.plotters.plot1d import plot_1D

                return plot_1D(dataset, *args, method=method, **kwargs)
            if kwargs.get("use_plotly", False):
                return dataset.plotly(method=method, **kwargs)
            return getattr(dataset, f"plot_{type}")(*args, method=method, **kwargs)

        # The generated docstring lists every sibling plot function under
        # "See Also"; the entry naming the decorated function itself is removed.
        # Whole lines are matched so that e.g. "plot_scatter" does not truncate
        # the "plot_scatter_pen" entry.
        wrapper.__doc__ = f"""
{textwrap.dedent(func.__doc__ or "").strip()}
Parameters
----------
dataset : `NDDataset`
The dataset to plot.
**kwargs
Optional keyword parameters (see Other Parameters).
Other Parameters
----------------
{doc.strip()}
See Also
--------
plot_1D
plot_pen
plot_bar
plot_scatter_pen
plot_multiple
plot_2D
plot_stack
plot_map
plot_image
plot_3D
plot_surface
plot_waterfall
multiplot
""".replace(f"\nplot_{method}\n", "\n")
        return wrapper

    return decorator_plot_method
# color conversion function
def cmyk2rgb(C, M, Y, K):
    """
    Convert a CMYK colour to RGB.

    Parameters
    ----------
    C, M, Y, K : float
        Cyan, magenta, yellow and black components, given in percent.

    Returns
    -------
    tuple of float
        The ``(R, G, B)`` components, each in the range 0..1 for inputs in 0..100.
    """
    cyan, magenta, yellow, black = (part / 100.0 for part in (C, M, Y, K))

    # Each RGB channel is the complement of one ink, dimmed by the black (K) ink:
    # red <- cyan, green <- magenta, blue <- yellow.
    red = (1.0 - cyan) * (1.0 - black)
    green = (1.0 - magenta) * (1.0 - black)
    blue = (1.0 - yellow) * (1.0 - black)
    return red, green, blue
# Constants
# --------------------------------------------------------------------------------------
# For color blind people, it is safe to use only 4 colors in graphs:
# see http://jfly.iam.u-tokyo.ac.jp/color/ichihara_etal_2008.pdf
# Black CMYK= 0, 0, 0, 100 %  (CMYK 0, 0, 0, 0 would be white with `cmyk2rgb`)
# Red CMYK= 0, 77, 100, 0 %
# Blue CMYK= 100, 30, 0, 0 %
# Green CMYK= 85, 0, 60, 10 %
# Each constant below is an (R, G, B) tuple; black is written directly in RGB,
# the others are converted from the CMYK percentages listed above.
NBlack = (0, 0, 0)
NRed = cmyk2rgb(0, 77, 100, 0)
NBlue = cmyk2rgb(100, 30, 0, 0)
NGreen = cmyk2rgb(85, 0, 60, 10)
# TODO : make a color cycle based on these colors
def figure(preferences=None, **kwargs):
    """
    Open a new figure.

    Parameters
    ----------
    preferences : Meta dictionary, optional
        Per object saved plot configuration. When omitted, a new empty `Meta`
        instance is used.
    **kwargs : any
        Keywords arguments forwarded unchanged to `get_figure`.

    Returns
    -------
    The value returned by `get_figure`.
    """
    from spectrochempy.core.dataset.baseobjects.meta import Meta

    prefs = Meta() if preferences is None else preferences
    return get_figure(preferences=prefs, **kwargs)
def show():
    """
    Force the `matplotlib` figure display.

    When the package-level ``NO_DISPLAY`` flag is set, all open figures are closed
    instead of being shown. Otherwise, if `get_figure` reports an existing figure,
    ``plt.show`` is called in blocking mode; nothing happens when no figure is open.
    """
    from spectrochempy import NO_DISPLAY

    if NO_DISPLAY:
        plt.close("all")
    elif get_figure(clear=False):
        # clear=False never creates a figure: the result is None when none is open
        plt.show(block=True)
def get_figure(**kwargs):
    """
    Get the figure where to plot.

    Parameters
    ----------
    clear : bool, default: True
        If False and at least one figure is open, the figure whose number is the
        last entry of ``plt.get_fignums()`` is returned and all other keywords are
        ignored. Otherwise a new figure is created.
    preferences : Meta object, optional
        Per object plot configuration providing the ``figure_*`` defaults listed
        below. It is required to create a figure: without it, None is returned.
    figsize : 2-tuple of floats, default: preferences.figure_figsize
        Figure dimension (width, height) in inches.
    dpi : float, default: preferences.figure_dpi
        Dots per inch (converted to int).
    facecolor : default: preferences.figure_facecolor
        The figure patch facecolor.
    edgecolor : default: preferences.figure_edgecolor
        The figure patch edge color.
    frameon : bool, default: preferences.figure_frameon
        If False, suppress drawing the figure background patch.
    tight_layout : bool or dict, default: preferences.figure_autolayout
        Value passed to ``Figure.set_tight_layout``. The historical keyword
        ``autolayout`` is still accepted; ``tight_layout`` takes precedence when
        both are given.

    Returns
    -------
    matplotlib figure instance or None
        None when a new figure is needed but no `preferences` were supplied.

    Notes
    -----
    A ``constrained_layout`` keyword is not handled by this function.
    """
    fignums = plt.get_fignums()
    clear = kwargs.get("clear", True)

    if fignums and not clear:
        # a figure already exists - if several we take the last
        return plt.figure(fignums[-1])

    # A new figure must be created, which requires the preferences.
    prefs = kwargs.pop("preferences", None)
    if prefs is None:
        return None

    figsize = kwargs.get("figsize", prefs.figure_figsize)
    dpi = int(kwargs.get("dpi", prefs.figure_dpi))
    facecolor = kwargs.get("facecolor", prefs.figure_facecolor)
    edgecolor = kwargs.get("edgecolor", prefs.figure_edgecolor)
    frameon = kwargs.get("frameon", prefs.figure_frameon)
    # "tight_layout" is the documented keyword; "autolayout" is kept as a
    # backward-compatible alias.
    tight_layout = kwargs.get(
        "tight_layout", kwargs.get("autolayout", prefs.figure_autolayout)
    )

    fig = plt.figure(figsize=figsize)
    fig.set_dpi(dpi)
    fig.set_frameon(frameon)

    # Colors rejected by matplotlib are retried after evaluating them as Python
    # expressions (e.g. a tuple stored as a string).
    # NOTE(review): eval() executes arbitrary code; this is only acceptable as
    # long as these values come from trusted preferences - confirm.
    try:
        fig.set_edgecolor(edgecolor)
    except ValueError:
        fig.set_edgecolor(eval(edgecolor))  # noqa: S307
    try:
        fig.set_facecolor(facecolor)
    except ValueError:
        try:
            fig.set_facecolor(eval(facecolor))  # noqa: S307
        except ValueError:
            # last resort: treat the evaluated value as a hex string without "#"
            fig.set_facecolor("#" + eval(facecolor))  # noqa: S307

    fig.set_tight_layout(tight_layout)
    return fig
# FOR PLOTLY
def get_plotly_figure(clear=True, fig=None, **kwargs):
    """
    Get the figure where to plot.

    Parameters
    ----------
    clear : bool
        If False the figure provided in the `fig` parameters is used.
    fig : plotly figure
        If provided, and clear is not True, it will be used for plotting.
    **kwargs : any
        Accepted for call compatibility; currently not used by this function.

    Returns
    -------
    Plotly figure instance
        A new empty ``plotly.graph_objects.Figure`` when `clear` is True or `fig`
        is None, otherwise `fig` itself.

    Raises
    ------
    ImportError
        If plotly cannot be imported.
    """
    from spectrochempy.utils.optional import import_optional_dependency

    go = import_optional_dependency("plotly.graph_objects", errors="ignore")
    if go is None:
        raise ImportError("Plotly is not installed. Use pip or conda to install it")

    if clear or fig is None:
        # create a figure
        return go.Figure()

    # reuse the figure supplied by the caller
    return fig
class colorscale:
    """
    Map scalar values to ``rgba(...)`` color strings using a matplotlib colormap.

    `normalize` must be called first: it creates the ``scalarMap`` attribute that
    `rgba` reads. The class name is rebound to a single shared instance right
    after the class definition.
    """

    def normalize(self, vmin, vmax, cmap="viridis", rev=False, offset=0):
        """
        Normalize the color scale based on the given parameters.

        Parameters
        ----------
        vmin, vmax : float
            Data values mapped to the two ends of the colormap; `offset` is
            subtracted from both.
        cmap : str, default: "viridis"
            Name of the matplotlib colormap.
        rev : bool, default: False
            If True, the ``"_r"`` suffix is appended to `cmap` (reversed colormap).
        offset : float, default: 0
            Value subtracted from `vmin` and `vmax`.
        """
        if rev:
            cmap = cmap + "_r"
        _colormap = plt.get_cmap(cmap)
        _norm = mpl.colors.Normalize(vmin=vmin - offset, vmax=vmax - offset)
        self.scalarMap = mpl.cm.ScalarMappable(norm=_norm, cmap=_colormap)

    def rgba(self, z, offset=0):
        """
        Return the rgba color for the given value.

        Parameters
        ----------
        z : array-like with a ``squeeze`` method
            Value to convert; presumably a single value once squeezed - TODO confirm.
        offset : float, default: 0
            Value subtracted from `z` before the color lookup.

        Returns
        -------
        str
            A string such as ``"rgba(68.0, 1.0, 84.0, 1.0)"``: the three color
            channels scaled to 0..255 and rounded, followed by the unscaled alpha.
        """
        c = np.array(self.scalarMap.to_rgba(z.squeeze() - offset))
        # Round to the nearest integer level (a cast to an integer type before
        # rounding would truncate instead).
        c[0:3] = np.round(c[0:3] * 255)
        # tolist() yields plain Python floats, so the text does not depend on the
        # NumPy scalar repr (which is "np.float64(...)" since NumPy 2).
        return f"rgba{tuple(c.tolist())}"


colorscale = colorscale()
def make_label(ss, lab="<no_axe_label>", use_mpl=True):
    """
    Make a label from title and units.

    Parameters
    ----------
    ss : object with ``title`` and ``units`` attributes, or None
        Source of the label text and units.
    lab : str, default: "<no_axe_label>"
        Label returned as is when `ss` is None, and used in place of an empty title.
    use_mpl : bool, default: True
        If True, units are formatted with the ``~L`` format spec and wrapped in a
        ``$\\mathrm{...}$`` block; otherwise they are formatted with ``~H`` and
        appended after ``" / "``.

    Returns
    -------
    str
        The label. A title containing ``"<untitled>"`` is replaced by ``"values"``.
    """
    if ss is None:
        return lab

    text = ss.title or lab
    if "<untitled>" in text:
        text = "values"

    unit = ss.units

    if not use_mpl:
        has_unit = unit is not None and str(unit) != "dimensionless"
        suffix = rf"{unit:~H}" if has_unit else ""
        return rf"{text} / {suffix}"

    # Units that are never displayed on a matplotlib label
    hidden = ("dimensionless", "absolute_transmittance")
    if unit is None or str(unit) in hidden:
        suffix = ""
    else:
        suffix = rf"/\ {unit:~L}"
        if pint_version < 24:
            # with older pint the percent sign is escaped here
            suffix = suffix.replace("%", r"\%")
    return rf"{text} $\mathrm{{{suffix}}}$"
def make_attr(key):
    """
    Derive plot attributes from a key.

    Parameters
    ----------
    key : sequence
        Its second item (``key[1]``) selects the marker and color, through the
        letters "P", "A" or "B" (tested in that order); the presence of ``"400"``
        in `key` selects the fill color and line style.

    Returns
    -------
    tuple
        ``(m, c, k, f, s)`` where `m` is the marker, `c` the color, `k` the
        ``$\\mathrm{M_<key[1]>}$`` label, `f` the fill color (``"w"`` when "400"
        is in `key`, else `c`) and `s` the line style (``":"`` when "400" is in
        `key`, else ``"-"``).

    Raises
    ------
    ValueError
        If ``key[1]`` contains none of "P", "A" or "B".
    """
    name = f"M_{key[1]}"
    k = rf"$\mathrm{{{name}}}$"

    # (letter, marker, color), tested in order
    styles = (("P", "o", NBlack), ("A", "^", NBlue), ("B", "s", NRed))
    for letter, m, c in styles:
        if letter in name:
            break
    else:
        # Without this guard the function failed later with an UnboundLocalError.
        raise ValueError(
            f"Cannot derive plot attributes from key {key!r}: "
            "its second item must contain 'P', 'A' or 'B'"
        )

    if "400" in key:
        f, s = "w", ":"
    else:
        f, s = c, "-"
    return m, c, k, f, s