Source code for gsplot.style.legend

import warnings
from typing import Any

import matplotlib.pyplot as plt
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.legend import Legend as Lg
from matplotlib.legend_handler import HandlerBase

from ..base.base import CreateClassParams, ParamsGetter, bind_passed_params

__all__: list[str] = [
    "legend",
    "legend_axes",
    "legend_handlers",
    "legend_reverse",
    "legend_get_handlers",
]


class Legend:
    """
    A class to manage legends for a specific Matplotlib axis.

    This class provides functionality for customizing, reversing, and managing
    legends on a specific axis in a Matplotlib figure.

    Parameters
    --------------------
    ax : Axes
        The target axis for the legend.
    handles : list[Any], optional
        A list of handles for the legend.
    labels : list[str], optional
        A list of labels for the legend.
    handlers : dict, optional
        A dictionary of custom legend handlers.
    *args : Any
        Additional positional arguments for the legend.
    **kwargs : Any
        Additional keyword arguments for the legend.

    Attributes
    --------------------
    handles : list[Any] | None
        The legend handles.
    labels : list[str] | None
        The legend labels.
    handlers : dict | None
        The custom legend handlers.

    Methods
    --------------------
    get_legend_handlers() -> tuple[list[Artist], list[str], dict[Artist, HandlerBase]]
        Retrieves the legend handles, labels, and associated handlers.
    legend() -> matplotlib.legend.Legend
        Adds a legend to the axis.
    legend_handlers() -> matplotlib.legend.Legend
        Adds a legend with custom handles, labels, and handlers.
    reverse_legend() -> matplotlib.legend.Legend
        Adds a legend to the axis with reversed order of handles and labels.
    """

    def __init__(
        self,
        ax: Axes,
        handles: list[Any] | None = None,
        labels: list[str] | None = None,
        handlers: dict | None = None,
        *args: Any,
        **kwargs: Any
    ):
        self.ax: Axes = ax
        self.handles: list[Any] | None = handles
        self.labels: list[str] | None = labels
        self.handlers: dict | None = handlers
        self.args: Any = args
        self.kwargs: Any = kwargs

    def get_legend_handlers(
        self,
    ) -> tuple[list[Artist], list[str], dict[Artist, HandlerBase]]:
        """
        Retrieves the legend handles, labels, and associated handlers for the target axis.

        Returns
        --------------------
        tuple[list[Artist], list[str], dict[Artist, HandlerBase]]
            - handles: The list of legend handles.
            - labels: The list of legend labels.
            - handlers: A dictionary mapping handles to their legend handlers.
        """

        handles, labels = self.ax.get_legend_handles_labels()

        handler_map = Lg(parent=self.ax, handles=[], labels=[]).get_legend_handler_map()

        handlers = {}
        for handle in handles:
            if type(handle) in handler_map:
                print(handle)
                handlers[handle] = handler_map[type(handle)]
            else:
                # if handle is not in handler_map, pass
                pass

        return handles, labels, handlers

    def legend(self) -> Lg:
        """
        Adds a legend to the target axis.

        Returns
        --------------------
        matplotlib.legend.Legend
            The created legend object.
        """
        _lg = self.ax.legend(*self.args, **self.kwargs)

        return _lg

    def legend_handlers(self) -> Lg:
        """
        Adds a legend with custom handles, labels, and handlers to the target axis.

        Returns
        --------------------
        matplotlib.legend.Legend
            The created legend object with the provided custom handlers.
        """
        _lg = self.ax.legend(
            handles=self.handles,
            labels=self.labels,
            handler_map=self.handlers,
            *self.args,
            **self.kwargs,
        )

        return _lg

    def reverse_legend(self) -> Lg:
        """
        Adds a legend to the target axis with reversed order of handles and labels.

        Returns
        --------------------
        matplotlib.legend.Legend
            The created legend object with reversed order.
        """

        handles, labels, handlers = self.get_legend_handlers()
        _lg = self.ax.legend(
            handles=handles[::-1],
            labels=labels[::-1],
            handler_map=handlers,
            *self.args,
            **self.kwargs,
        )
        return _lg


class LegendAxes:
    """
    A class to manage legends for all axes in the current Matplotlib figure.

    Parameters
    --------------------
    *args : Any
        Additional positional arguments for legends.
    **kwargs : Any
        Additional keyword arguments for legends.

    Attributes
    --------------------
    args : Any
        Positional arguments for legends.
    kwargs : Any
        Keyword arguments for legends.

    Methods
    --------------------
    legend_axes() -> list[matplotlib.legend.Legend]
        Adds legends to all axes in the current figure.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        self.args: Any = args
        self.kwargs: Any = kwargs

    def legend_axes(self) -> list[Lg]:
        """
        Adds legends to all axes in the current Matplotlib figure.

        Returns
        --------------------
        list[matplotlib.legend.Legend]
            A list of legend objects created for each axis.
        """
        _lg_list = []
        for ax in plt.gcf().axes:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                _lg = ax.legend(*self.args, **self.kwargs)
                _lg_list.append(_lg)
        return _lg_list


[docs] @bind_passed_params() def legend(ax: Axes, *args: Any, **kwargs: Any) -> Lg: """ Adds a legend to the specified axis. Parameters -------------------- ax : matplotlib.axes.Axes The target axis for the legend. *args : Any Additional positional arguments for the legend. **kwargs : Any Additional keyword arguments for the legend. Notes -------------------- This function utilizes the `ParamsGetter` to retrieve bound parameters and the `CreateClassParams` class to handle the merging of default, configuration, and passed parameters. Returns -------------------- matplotlib.legend.Legend The created legend object. Examples -------------------- >>> import matplotlib.pyplot as plt >>> import numpy as np >>> import gsplot as gs >>> x = np.linspace(0, 10, 100) >>> plt.plot(x, np.sin(x), label="Sine") >>> plt.plot(x, np.cos(x), label="Cosine") >>> gs.legend(plt.gcf()[0]) # Adds legend to the first axis >>> plt.show() """ passed_params: dict[str, Any] = ParamsGetter("passed_params").get_bound_params() class_params = CreateClassParams(passed_params).get_class_params() _legend = Legend( class_params["ax"], *class_params["args"], **class_params["kwargs"], ) return _legend.legend()
[docs] @bind_passed_params() def legend_axes(*args: Any, **kwargs: Any) -> list[Lg]: """ Adds legends to all axes in the current Matplotlib figure. Parameters -------------------- *args : Any Additional positional arguments for legends. **kwargs : Any Additional keyword arguments for legends. Notes -------------------- This function utilizes the `ParamsGetter` to retrieve bound parameters and the `CreateClassParams` class to handle the merging of default, configuration, and passed parameters. Returns -------------------- list[matplotlib.legend.Legend] A list of legend objects created for each axis. Examples -------------------- >>> import matplotlib.pyplot as plt >>> fig, axes = plt.subplots(2, 1) >>> import gsplot as gs >>> axes[0].plot([1, 2, 3], [4, 5, 6], label="Line 1") >>> axes[1].plot([1, 2, 3], [6, 5, 4], label="Line 2") >>> gs.legend_axes() # Adds legends to all axes >>> plt.show() """ passed_params: dict[str, Any] = ParamsGetter("passed_params").get_bound_params() class_params = CreateClassParams(passed_params).get_class_params() _legend_axes = LegendAxes( *class_params["args"], **class_params["kwargs"], ) return _legend_axes.legend_axes()
[docs] @bind_passed_params() def legend_handlers( ax: Axes, handles: list[Any] | None = None, labels: list[str] | None = None, handlers: dict | None = None, *args: Any, **kwargs: Any ) -> Lg: """ Adds a legend with custom handles, labels, and handlers to the specified axis. Parameters -------------------- ax : Axes The target axis for the legend. Can be an axis index or an `Axes` object. handles : list[Any], optional A list of custom handles for the legend. labels : list[str], optional A list of custom labels for the legend. handlers : dict, optional A dictionary of custom legend handlers. *args : Any Additional positional arguments for the legend. **kwargs : Any Additional keyword arguments for the legend. Notes -------------------- This function utilizes the `ParamsGetter` to retrieve bound parameters and the `CreateClassParams` class to handle the merging of default, configuration, and passed parameters. Returns -------------------- matplotlib.legend.Legend The created legend object with the provided custom handlers. Examples -------------------- >>> import matplotlib.pyplot as plt >>> from matplotlib.lines import Line2D >>> import gsplot as gs >>> fig, ax = plt.subplots() >>> ax.plot([0, 1], [0, 1], label="Line A") >>> custom_handle = [Line2D([0], [0], color="r", lw=2)] >>> gs.legend_handlers(ax, handles=custom_handle, labels=["Custom Line"]) >>> plt.show() """ passed_params: dict[str, Any] = ParamsGetter("passed_params").get_bound_params() class_params = CreateClassParams(passed_params).get_class_params() _legend = Legend( class_params["ax"], class_params["handles"], class_params["labels"], class_params["handlers"], *class_params["args"], **class_params["kwargs"], ) return _legend.legend_handlers()
[docs] @bind_passed_params() def legend_reverse( ax: Axes, handles: list[Any] | None = None, labels: list[str] | None = None, handlers: dict | None = None, *args: Any, **kwargs: Any ) -> Lg: """ Adds a legend to the specified axis with reversed order of handles and labels. Parameters -------------------- ax : Axes The target axis for the legend. Can be an axis index or an `Axes` object. handles : list[Any], optional A list of custom handles for the legend. labels : list[str], optional A list of custom labels for the legend. handlers : dict, optional A dictionary of custom legend handlers. *args : Any Additional positional arguments for the legend. **kwargs : Any Additional keyword arguments for the legend. Notes -------------------- This function utilizes the `ParamsGetter` to retrieve bound parameters and the `CreateClassParams` class to handle the merging of default, configuration, and passed parameters. Returns -------------------- matplotlib.legend.Legend The created legend object with reversed order. Examples -------------------- >>> import matplotlib.pyplot as plt >>> x = [1, 2, 3] >>> y1 = [4, 5, 6] >>> y2 = [6, 5, 4] >>> plt.plot(x, y1, label="Line 1") >>> plt.plot(x, y2, label="Line 2") >>> legend_reverse(plt.gca()[0]) # Reverses the legend order >>> plt.show() """ passed_params: dict[str, Any] = ParamsGetter("passed_params").get_bound_params() class_params = CreateClassParams(passed_params).get_class_params() _legend = Legend( class_params["ax"], class_params["handles"], class_params["labels"], class_params["handlers"], *class_params["args"], **class_params["kwargs"], ) return _legend.reverse_legend()
[docs] def legend_get_handlers( ax: Axes, ) -> tuple: """ Retrieves the legend handles, labels, and associated handlers for the specified axis. Parameters -------------------- ax : Axes The target axis for retrieving the legend handlers. Can be an axis index or an `Axes` object. Returns -------------------- tuple - handles: The list of legend handles. - labels: The list of legend labels. - handlers: A dictionary mapping handles to their legend handlers. Examples -------------------- >>> import matplotlib.pyplot as plt >>> import gsplot as gs >>> fig, ax = plt.subplots() >>> ax.plot([0, 1], [0, 1], label="Line A") >>> ax.plot([1, 0], [0, 1], label="Line B") >>> handles, labels, handlers = gs.legend_get_handlers(ax) >>> print("Handles:", handles) >>> print("Labels:", labels) >>> print("Handlers:", handlers) """ return Legend(ax).get_legend_handlers()