From 23edbc39ad828c004916fa51ec53c7db76b8727b Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 9 Mar 2026 13:55:37 +0100 Subject: [PATCH] fix: continuous colorscale strings with discrete-only chart types Passing a continuous colorscale like `colors='turbo'` to area(), line(), fast_bar(), box(), or pie() previously failed because `resolve_colors` set `color_continuous_scale` which these px functions don't accept. Now samples from the continuous scale into a discrete sequence for these chart types. bar() and scatter() natively support continuous scales so are left unchanged. Co-Authored-By: Claude Opus 4.6 --- xarray_plotly/common.py | 30 ++++++++++++++++++++++++++++-- xarray_plotly/plotting.py | 10 +++++----- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/xarray_plotly/common.py b/xarray_plotly/common.py index 786e839..7805d4b 100644 --- a/xarray_plotly/common.py +++ b/xarray_plotly/common.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: import pandas as pd - from xarray import DataArray + from xarray import DataArray, Dataset class _AUTO: @@ -251,7 +251,21 @@ def _get_qualitative_scale_names() -> frozenset[str]: ) -def resolve_colors(colors: Colors, px_kwargs: dict[str, Any]) -> dict[str, Any]: +def _sample_colorscale(name: str, n: int) -> list[str]: + """Sample *n* evenly-spaced colors from a named Plotly colorscale.""" + scale = px.colors.get_colorscale(name) + samplepoints = [i / max(n - 1, 1) for i in range(n)] + result: list[str] = px.colors.sample_colorscale(scale, samplepoints) + return result + + +def resolve_colors( + colors: Colors, + px_kwargs: dict[str, Any], + *, + color_dim: Hashable | None = None, + darray: DataArray | Dataset | None = None, +) -> dict[str, Any]: """Map unified `colors` parameter to appropriate Plotly px_kwargs. Direct color_* kwargs take precedence and trigger a warning if @@ -260,6 +274,14 @@ def resolve_colors(colors: Colors, px_kwargs: dict[str, Any]) -> dict[str, Any]: Args: colors: Unified color specification (str, list, dict, or None). px_kwargs: Existing kwargs to pass to Plotly Express. + color_dim: Dimension name used for discrete color grouping. + When provided together with *darray*, a continuous colorscale + string is sampled into a discrete sequence whose length + matches the number of coordinates along this dimension. + Use for chart types that only accept discrete color + parameters (line, area, box, pie). + darray: Source DataArray or Dataset; used with *color_dim* to + determine the number of discrete colors to sample. Returns: Updated px_kwargs with color parameters injected. @@ -284,6 +306,10 @@ def resolve_colors(colors: Colors, px_kwargs: dict[str, Any]) -> dict[str, Any]: # Check if it's a qualitative (discrete) palette name if colors in _get_qualitative_scale_names(): px_kwargs["color_discrete_sequence"] = getattr(px.colors.qualitative, colors) + elif color_dim is not None and darray is not None: + # Sample from continuous scale into a discrete sequence + n = darray.sizes[color_dim] + px_kwargs["color_discrete_sequence"] = _sample_colorscale(colors, n) else: # Assume continuous scale px_kwargs["color_continuous_scale"] = colors diff --git a/xarray_plotly/plotting.py b/xarray_plotly/plotting.py index a0ceeb9..a45cbd5 100644 --- a/xarray_plotly/plotting.py +++ b/xarray_plotly/plotting.py @@ -82,7 +82,6 @@ def line( ------- plotly.graph_objects.Figure """ - px_kwargs = resolve_colors(colors, px_kwargs) slots = assign_slots( list(darray.dims), "line", @@ -94,6 +93,7 @@ def line( facet_row=facet_row, animation_frame=animation_frame, ) + px_kwargs = resolve_colors(colors, px_kwargs, color_dim=slots.get("color"), darray=darray) df = to_dataframe(darray) value_col = get_value_col(darray) @@ -329,7 +329,6 @@ def fast_bar( ------- plotly.graph_objects.Figure """ - px_kwargs = resolve_colors(colors, px_kwargs) slots = assign_slots( list(darray.dims), "fast_bar", @@ -339,6 +338,7 @@ def fast_bar( facet_row=facet_row, animation_frame=animation_frame, ) + px_kwargs = resolve_colors(colors, px_kwargs, color_dim=slots.get("color"), darray=darray) df = to_dataframe(darray) value_col = get_value_col(darray) @@ -410,7 +410,6 @@ def area( ------- plotly.graph_objects.Figure """ - px_kwargs = resolve_colors(colors, px_kwargs) slots = assign_slots( list(darray.dims), "area", @@ -421,6 +420,7 @@ def area( facet_row=facet_row, animation_frame=animation_frame, ) + px_kwargs = resolve_colors(colors, px_kwargs, color_dim=slots.get("color"), darray=darray) df = to_dataframe(darray) value_col = get_value_col(darray) @@ -487,7 +487,6 @@ def box( ------- plotly.graph_objects.Figure """ - px_kwargs = resolve_colors(colors, px_kwargs) slots = assign_slots( list(darray.dims), "box", @@ -498,6 +497,7 @@ def box( facet_row=facet_row, animation_frame=animation_frame, ) + px_kwargs = resolve_colors(colors, px_kwargs, color_dim=slots.get("color"), darray=darray) df = to_dataframe(darray) value_col = get_value_col(darray) @@ -746,7 +746,6 @@ def pie( ------- plotly.graph_objects.Figure """ - px_kwargs = resolve_colors(colors, px_kwargs) slots = assign_slots( list(darray.dims), "pie", @@ -754,6 +753,7 @@ def pie( facet_col=facet_col, facet_row=facet_row, ) + px_kwargs = resolve_colors(colors, px_kwargs, color_dim=slots.get("names"), darray=darray) df = to_dataframe(darray) value_col = get_value_col(darray)