Source code for gsplot.plot.scatter

from typing import Any

import numpy as np
from matplotlib import colors
from matplotlib.axes import Axes
from matplotlib.collections import PathCollection
from matplotlib.typing import ColorType
from numpy.typing import ArrayLike, NDArray

from ..base.base import CreateClassParams, ParamsGetter, bind_passed_params
from ..base.base_alias_validator import AliasValidator
from ..figure.axes_range_base import AxesRangeSingleton
from .line_base import AutoColor, NumLines

# Public API of this module: only the `scatter` function is listed, so it is the
# only name exported by `from ... import *`; the `Scatter` class is not listed.
__all__: list[str] = ["scatter"]


class Scatter:
    """
    A class for creating scatter plots on a specified Matplotlib axis.

    Parameters
    --------------------
    ax : matplotlib.axes.Axes
        The target axis for the scatter.
    x : ArrayLike
        The x-coordinates of the scatter points.
    y : ArrayLike
        The y-coordinates of the scatter points.
    color : ColorType or None, optional
        Color of the points. If `None`, a default color from the axis's cycle is used (default is `None`).
    size : int or float, optional
        Size of the scatter points (default is 1).
    alpha : int or float, optional
        Opacity of the scatter points (default is 1).
    **kwargs : Any
        Additional keyword arguments passed to the `scatter` method of Matplotlib's `Axes`.
        When `c` is supplied here and `color` is left as `None`, the automatically
        chosen color is not forwarded, so that `c` (e.g. combined with `cmap`) is honoured.

    Attributes
    --------------------
    ax : matplotlib.axes.Axes
        The axis the points are drawn on.
    x : numpy.ndarray
        The x-coordinates as a NumPy array.
    y : numpy.ndarray
        The y-coordinates as a NumPy array.
    color : ColorType
        The resolved color (explicit color, or the automatic one when `color` was `None`).
    size : int or float
        Marker size forwarded as Matplotlib's `s`.
    alpha : int or float
        Opacity forwarded as Matplotlib's `alpha`.
    kwargs : dict[str, Any]
        Extra keyword arguments forwarded to `Axes.scatter`.

    Methods
    --------------------
    get_color() -> ColorType
        Determines the color for the scatter points, either from the user input or the axis's color cycle.
    plot() -> matplotlib.collections.PathCollection
        Creates and plots the scatter points on the axis.

    Examples
    --------------------
    >>> x = [1, 2, 3, 4]
    >>> y = [10, 20, 15, 25]
    >>> scatter = Scatter(ax=ax, x=x, y=y, color="blue", size=10, alpha=0.5)
    >>> scatter.plot()
    """

    def __init__(
        self,
        ax: Axes,
        x: ArrayLike,
        y: ArrayLike,
        color: ColorType | None = None,
        size: int | float = 1,
        alpha: int | float = 1,
        **kwargs: Any,
    ) -> None:
        self.ax: Axes = ax

        # Raw user inputs are kept untouched next to their normalized counterparts.
        self._x: ArrayLike = x
        self._y: ArrayLike = y
        self._color: ColorType | None = color
        self.size: int | float = size
        self.alpha: int | float = alpha
        self.kwargs: dict[str, Any] = kwargs

        # np.array (not np.asarray) so the stored coordinates are copies of the input.
        self.x: NDArray[Any] = np.array(self._x)
        self.y: NDArray[Any] = np.array(self._y)

        self.color: ColorType = self.get_color()

    def get_color(self) -> ColorType:
        """
        Determines the color for the scatter points.

        If a color is not explicitly provided, it retrieves a default color from
        `AutoColor` for this axis.

        Returns
        --------------------
        ColorType
            The explicit color if one was given, otherwise the automatic color.

        Notes
        --------------------
        An automatic color returned as a NumPy array is converted to a hexadecimal
        string with `matplotlib.colors.to_hex`.

        Examples
        --------------------
        >>> scatter = Scatter(ax=ax, x=[1, 2], y=[3, 4])
        >>> scatter.get_color()
        """
        # NOTE(review): the automatic color is looked up even when an explicit color
        # was given; kept as-is because AutoColor's implementation is not visible here.
        cycle_color: NDArray | str = AutoColor(self.ax).get_color()
        if isinstance(cycle_color, np.ndarray):
            # to_hex is given a plain tuple rather than the array itself.
            cycle_color = colors.to_hex(tuple(cycle_color))

        if self._color is None:
            return cycle_color
        return self._color

    @NumLines.count
    @AxesRangeSingleton.update
    def plot(self) -> PathCollection:
        """
        Plots the scatter points on the specified axis.

        This method creates a scatter plot using the stored x, y, color, size, and alpha values,
        plus any extra keyword arguments.

        Returns
        --------------------
        matplotlib.collections.PathCollection
            The scatter plot as a PathCollection object.

        Notes
        --------------------
        - This method is decorated with `@NumLines.count` to track the number of scatter calls on the axis.
        - It is also decorated with `@AxesRangeSingleton.update` to update the axis range with the scatter data.
        - Matplotlib rejects calls that supply both `c` and `color`. Because the color is
          always resolved to a concrete value, the automatic color is dropped when the
          caller passed `c` without an explicit `color`. An explicit `color` together
          with `c` is still forwarded unchanged, so Matplotlib reports the conflict.

        Examples
        --------------------
        >>> scatter = Scatter(ax=ax, x=[1, 2, 3], y=[4, 5, 6], size=50, alpha=0.8)
        >>> scatter.plot()
        <matplotlib.collections.PathCollection>
        """
        color_kwargs: dict[str, Any] = {}
        c_replaces_auto_color = (
            self._color is None and self.kwargs.get("c") is not None
        )
        if not c_replaces_auto_color:
            color_kwargs["color"] = self.color

        return self.ax.scatter(
            self.x,
            self.y,
            s=self.size,
            alpha=self.alpha,
            **color_kwargs,
            **self.kwargs,
        )


@bind_passed_params()
def scatter(
    ax: Axes,
    x: ArrayLike,
    y: ArrayLike,
    color: ColorType | None = None,
    size: int | float = 1,
    alpha: int | float = 1,
    **kwargs: Any,
) -> PathCollection:
    """
    Creates a scatter plot on the specified axis.

    This function uses the `Scatter` class to generate a scatter plot with
    customizable size, color, and transparency.

    Parameters
    --------------------
    ax : matplotlib.axes.Axes
        The target axis for the scatter.
    x : ArrayLike
        The x-coordinates of the scatter points.
    y : ArrayLike
        The y-coordinates of the scatter points.
    color : ColorType or None, optional
        Color of the points. If `None`, a default color from the axis's cycle is used (default is `None`).
    size : int or float, optional
        Size of the scatter points (default is 1).
    alpha : int or float, optional
        Opacity of the scatter points (default is 1).
    **kwargs : Any
        Additional keyword arguments passed to the `scatter` method of Matplotlib's `Axes`.

    Notes
    --------------------
    - This function utilizes the `ParamsGetter` to retrieve bound parameters and
      the `CreateClassParams` class to handle the merging of default, configuration,
      and passed parameters.
    - Alias validation is performed using the `AliasValidator` class. Accepted aliases:
        - 's' (size)

    Returns
    --------------------
    matplotlib.collections.PathCollection
        The scatter plot as a PathCollection object.

    Examples
    --------------------
    >>> import gsplot as gs
    >>> x = [1, 2, 3, 4]
    >>> y = [10, 20, 15, 25]
    >>> gs.scatter(ax=ax, x=x, y=y, color="red", size=20, alpha=0.8)
    <matplotlib.collections.PathCollection>
    """
    # Short Matplotlib-style alias -> canonical parameter name of this function.
    alias_map = {
        "s": "size",
    }

    # The local arguments are not read directly: the values come from the
    # parameters recorded under "passed_params" (presumably by the
    # `bind_passed_params` decorator), validated against the alias map and then
    # turned into the final class parameters.
    passed_params: dict[str, Any] = ParamsGetter("passed_params").get_bound_params()
    AliasValidator(alias_map, passed_params).validate()
    class_params: dict[str, Any] = CreateClassParams(passed_params).get_class_params()

    _scatter = Scatter(
        class_params["ax"],
        class_params["x"],
        class_params["y"],
        class_params["color"],
        class_params["size"],
        class_params["alpha"],
        **class_params["kwargs"],
    )
    return _scatter.plot()