Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions xarray_plotly/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

if TYPE_CHECKING:
import pandas as pd
from xarray import DataArray
from xarray import DataArray, Dataset


class _AUTO:
Expand Down Expand Up @@ -251,7 +251,21 @@ def _get_qualitative_scale_names() -> frozenset[str]:
)


def resolve_colors(colors: Colors, px_kwargs: dict[str, Any]) -> dict[str, Any]:
def _sample_colorscale(name: str, n: int) -> list[str]:
"""Sample *n* evenly-spaced colors from a named Plotly colorscale."""
scale = px.colors.get_colorscale(name)
samplepoints = [i / max(n - 1, 1) for i in range(n)]
result: list[str] = px.colors.sample_colorscale(scale, samplepoints)
return result


def resolve_colors(
colors: Colors,
px_kwargs: dict[str, Any],
*,
color_dim: Hashable | None = None,
darray: DataArray | Dataset | None = None,
) -> dict[str, Any]:
"""Map unified `colors` parameter to appropriate Plotly px_kwargs.

Direct color_* kwargs take precedence and trigger a warning if
Expand All @@ -260,6 +274,14 @@ def resolve_colors(colors: Colors, px_kwargs: dict[str, Any]) -> dict[str, Any]:
Args:
colors: Unified color specification (str, list, dict, or None).
px_kwargs: Existing kwargs to pass to Plotly Express.
color_dim: Dimension name used for discrete color grouping.
When provided together with *darray*, a continuous colorscale
string is sampled into a discrete sequence whose length
matches the number of coordinates along this dimension.
Use for chart types that only accept discrete color
parameters (line, area, box, pie).
darray: Source DataArray or Dataset; used with *color_dim* to
determine the number of discrete colors to sample.

Returns:
Updated px_kwargs with color parameters injected.
Expand All @@ -284,6 +306,10 @@ def resolve_colors(colors: Colors, px_kwargs: dict[str, Any]) -> dict[str, Any]:
# Check if it's a qualitative (discrete) palette name
if colors in _get_qualitative_scale_names():
px_kwargs["color_discrete_sequence"] = getattr(px.colors.qualitative, colors)
elif color_dim is not None and darray is not None:
# Sample from continuous scale into a discrete sequence
n = darray.sizes[color_dim]
px_kwargs["color_discrete_sequence"] = _sample_colorscale(colors, n)
else:
# Assume continuous scale
px_kwargs["color_continuous_scale"] = colors
Expand Down
10 changes: 5 additions & 5 deletions xarray_plotly/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def line(
-------
plotly.graph_objects.Figure
"""
px_kwargs = resolve_colors(colors, px_kwargs)
slots = assign_slots(
list(darray.dims),
"line",
Expand All @@ -94,6 +93,7 @@ def line(
facet_row=facet_row,
animation_frame=animation_frame,
)
px_kwargs = resolve_colors(colors, px_kwargs, color_dim=slots.get("color"), darray=darray)

df = to_dataframe(darray)
value_col = get_value_col(darray)
Expand Down Expand Up @@ -329,7 +329,6 @@ def fast_bar(
-------
plotly.graph_objects.Figure
"""
px_kwargs = resolve_colors(colors, px_kwargs)
slots = assign_slots(
list(darray.dims),
"fast_bar",
Expand All @@ -339,6 +338,7 @@ def fast_bar(
facet_row=facet_row,
animation_frame=animation_frame,
)
px_kwargs = resolve_colors(colors, px_kwargs, color_dim=slots.get("color"), darray=darray)

df = to_dataframe(darray)
value_col = get_value_col(darray)
Expand Down Expand Up @@ -410,7 +410,6 @@ def area(
-------
plotly.graph_objects.Figure
"""
px_kwargs = resolve_colors(colors, px_kwargs)
slots = assign_slots(
list(darray.dims),
"area",
Expand All @@ -421,6 +420,7 @@ def area(
facet_row=facet_row,
animation_frame=animation_frame,
)
px_kwargs = resolve_colors(colors, px_kwargs, color_dim=slots.get("color"), darray=darray)

df = to_dataframe(darray)
value_col = get_value_col(darray)
Expand Down Expand Up @@ -487,7 +487,6 @@ def box(
-------
plotly.graph_objects.Figure
"""
px_kwargs = resolve_colors(colors, px_kwargs)
slots = assign_slots(
list(darray.dims),
"box",
Expand All @@ -498,6 +497,7 @@ def box(
facet_row=facet_row,
animation_frame=animation_frame,
)
px_kwargs = resolve_colors(colors, px_kwargs, color_dim=slots.get("color"), darray=darray)

df = to_dataframe(darray)
value_col = get_value_col(darray)
Expand Down Expand Up @@ -746,14 +746,14 @@ def pie(
-------
plotly.graph_objects.Figure
"""
px_kwargs = resolve_colors(colors, px_kwargs)
slots = assign_slots(
list(darray.dims),
"pie",
names=names,
facet_col=facet_col,
facet_row=facet_row,
)
px_kwargs = resolve_colors(colors, px_kwargs, color_dim=slots.get("names"), darray=darray)

df = to_dataframe(darray)
value_col = get_value_col(darray)
Expand Down