diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index f49121fd..8af4e177 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -213,8 +213,9 @@ def render_shapes( `fill_alpha` will overwrite the value present in the cmap. groups : list[str] | str | None When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of - them. Other values are set to NA. If element is None, broadcasting behaviour is attempted (use the same - values for all elements). + them. By default, non-matching elements are hidden. To show non-matching elements, set ``na_color`` + explicitly. + If element is None, broadcasting behaviour is attempted (use the same values for all elements). palette : list[str] | str | None Palette for discrete annotations. List of valid color names that should be used for the categories. Must match the number of groups. If element is None, broadcasting behaviour is attempted (use the same values for @@ -398,8 +399,9 @@ def render_points( value is used instead. groups : list[str] | str | None When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of - them. Other values are set to NA. If `element` is `None`, broadcasting behaviour is attempted (use the same - values for all elements). + them. By default, non-matching points are filtered out entirely. To show non-matching points, set + ``na_color`` explicitly. + If element is None, broadcasting behaviour is attempted (use the same values for all elements). palette : list[str] | str | None Palette for discrete annotations. List of valid color names that should be used for the categories. Must match the number of groups. If `element` is `None`, broadcasting behaviour is attempted (use the same values @@ -671,7 +673,7 @@ def render_labels( table_name to be used for the element if you would like a specific table to be used. groups : list[str] | str | None When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of - them. Other values are set to NA. The list can contain multiple discrete labels to be visualized. + them. By default, non-matching labels are hidden. To show non-matching labels, set ``na_color`` explicitly. palette : list[str] | str | None Palette for discrete annotations. List of valid color names that should be used for the categories. Must match the number of groups. The list can contain multiple palettes (one per group) to be visualized. If diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index a5bfb015..9cddb499 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -69,33 +69,15 @@ _DS_NAN_CATEGORY = "ds_nan" -def _coerce_categorical_source(cat_source: Any) -> pd.Categorical: - """Return a pandas Categorical from known, concrete sources only. - - Raises - ------ - TypeError - If *cat_source* is not a ``dd.Series``, ``pd.Series``, - ``pd.Categorical``, or ``np.ndarray``. - """ - if isinstance(cat_source, dd.Series): - if isinstance(cat_source.dtype, pd.CategoricalDtype) and getattr(cat_source.cat, "known", True) is False: - cat_source = cat_source.cat.as_known() - cat_source = cat_source.compute() - - if isinstance(cat_source, pd.Series): - if isinstance(cat_source.dtype, pd.CategoricalDtype): - return cat_source.array - return pd.Categorical(cat_source) - if isinstance(cat_source, pd.Categorical): - return cat_source - if isinstance(cat_source, np.ndarray): - return pd.Categorical(cat_source) - - raise TypeError( - f"Cannot coerce {type(cat_source).__name__} to pd.Categorical. " - "Expected dd.Series, pd.Series, pd.Categorical, or np.ndarray." - ) +def _coerce_categorical_source(series: pd.Series | dd.Series) -> pd.Categorical: + """Return a ``pd.Categorical`` from a pandas or dask Series.""" + if isinstance(series, dd.Series): + if isinstance(series.dtype, pd.CategoricalDtype) and getattr(series.cat, "known", True) is False: + series = series.cat.as_known() + series = series.compute() + if isinstance(series.dtype, pd.CategoricalDtype): + return series.array + return pd.Categorical(series) def _build_datashader_color_key( @@ -104,18 +86,87 @@ def _build_datashader_color_key( na_color_hex: str, ) -> dict[str, str]: """Build a datashader ``color_key`` dict from a categorical series and its color vector.""" + na_hex = _hex_no_alpha(na_color_hex) if na_color_hex.startswith("#") else na_color_hex + # Map each category to its first-occurrence color via codes colors_arr = np.asarray(color_vector, dtype=object) - color_key: dict[str, str] = {} - for cat in cat_series.categories: - if cat == _DS_NAN_CATEGORY: - key_color = na_color_hex - else: - idx = np.flatnonzero(cat_series == cat) - key_color = colors_arr[idx[0]] if idx.size else na_color_hex - if isinstance(key_color, str) and key_color.startswith("#"): - key_color = _hex_no_alpha(key_color) - color_key[str(cat)] = key_color - return color_key + first_color: dict[str, str] = {} + for code, color in zip(cat_series.codes, colors_arr, strict=False): + if code < 0: + continue + cat_name = str(cat_series.categories[code]) + if cat_name not in first_color: + first_color[cat_name] = _hex_no_alpha(color) if isinstance(color, str) and color.startswith("#") else color + return {str(c): first_color.get(str(c), na_hex) for c in cat_series.categories} + + +def _inject_ds_nan_sentinel(series: pd.Series, sentinel: str = _DS_NAN_CATEGORY) -> pd.Series: + """Add a sentinel category for NaN values in a categorical series. + + Safely handles series that are not yet categorical, dask-backed + categoricals that need ``as_known()``, and series that already + contain the sentinel. + """ + if not isinstance(series.dtype, pd.CategoricalDtype): + series = series.astype("category") + if hasattr(series.cat, "as_known"): + series = series.cat.as_known() + if sentinel not in series.cat.categories: + series = series.cat.add_categories(sentinel) + return series.fillna(sentinel) + + +def _want_decorations(color_vector: Any, na_color: Color) -> bool: + """Return whether legend/colorbar decorations should be shown. + + Decorations are suppressed when all colors equal the NA color + (i.e., nothing informative to display). + """ + if color_vector is None: + return False + cv = np.asarray(color_vector) + if cv.size == 0: + return False + unique_vals = set(cv.tolist()) + if len(unique_vals) != 1: + return True + only_val = next(iter(unique_vals)) + na_hex = na_color.get_hex() + if isinstance(only_val, str) and only_val.startswith("#") and na_hex.startswith("#"): + return _hex_no_alpha(only_val) != _hex_no_alpha(na_hex) + return bool(only_val != na_hex) + + +def _reparse_points( + sdata_filt: sd.SpatialData, + element: str, + df: pd.DataFrame, + transformation: Any, + coordinate_system: str, +) -> None: + """Re-register a points DataFrame in *sdata_filt* with its transformation.""" + dd_frame = dask.dataframe.from_pandas(df, npartitions=1) + sdata_filt.points[element] = PointsModel.parse(dd_frame, coordinates={"x": "x", "y": "y"}) + set_transformation( + element=sdata_filt.points[element], + transformation=transformation, + to_coordinate_system=coordinate_system, + ) + + +def _filter_groups_transparent_na( + groups: str | list[str], + color_source_vector: pd.Categorical, + color_vector: pd.Series | np.ndarray | list[str], +) -> tuple[np.ndarray, pd.Categorical, np.ndarray]: + """Return a boolean mask and filtered color vectors for groups filtering. + + Used when ``na_color=None`` (fully transparent) so that non-matching + elements are removed entirely instead of rendered invisibly. + """ + keep = color_source_vector.isin(groups) + filtered_csv = color_source_vector[keep] + filtered_cv = np.asarray(color_vector)[keep] + return keep, filtered_csv, filtered_cv def _split_colorbar_params(params: dict[str, object] | None) -> tuple[dict[str, object], dict[str, object], str | None]: @@ -221,6 +272,18 @@ def _render_shapes( values_are_categorical = color_source_vector is not None + # When groups are specified, filter out non-matching elements by default. + # Only show non-matching elements if the user explicitly sets na_color. + _na = render_params.cmap_params.na_color + if groups is not None and values_are_categorical and (_na.default_color_set or _na.alpha == "00"): + keep, color_source_vector, color_vector = _filter_groups_transparent_na( + groups, color_source_vector, color_vector + ) + shapes = shapes[keep].reset_index(drop=True) + if len(shapes) == 0: + return + sdata_filt[element] = shapes + # color_source_vector is None when the values aren't categorical if values_are_categorical and render_params.transfunc is not None: color_vector = render_params.transfunc(color_vector) @@ -337,11 +400,13 @@ def _render_shapes( # If single color, broadcast to all shapes color_vector = [color_vector[0]] * len(transformed_element) else: - # If lengths don't match, pad or truncate to match + logger.warning( + f"Color vector length ({len(color_vector)}) does not match element count " + f"({len(transformed_element)}). This may indicate a bug." + ) if len(color_vector) > len(transformed_element): color_vector = color_vector[: len(transformed_element)] else: - # Pad with the last color or na_color na_color = render_params.cmap_params.na_color.get_hex_with_alpha() color_vector = list(color_vector) + [na_color] * (len(transformed_element) - len(color_vector)) @@ -356,12 +421,10 @@ def _render_shapes( aggregate_with_reduction = None continuous_nan_shapes = None - if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1): + if col_for_color is not None: if color_by_categorical: # add a sentinel category so that shapes with NaN value are colored in the na_color - transformed_element[col_for_color] = ( - transformed_element[col_for_color].cat.add_categories(_DS_NAN_CATEGORY).fillna(_DS_NAN_CATEGORY) - ) + transformed_element[col_for_color] = _inject_ds_nan_sentinel(transformed_element[col_for_color]) agg = cvs.polygons( transformed_element, geometry="geometry", @@ -466,9 +529,7 @@ def _render_shapes( if continuous_nan_shapes is not None: # for coloring by continuous variable: render nan shapes separately - nan_color_hex = render_params.cmap_params.na_color.get_hex() - if nan_color_hex.startswith("#") and len(nan_color_hex) == 9: - nan_color_hex = nan_color_hex[:7] + nan_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) continuous_nan_shapes = ds.tf.shade( continuous_nan_shapes, cmap=nan_color_hex, @@ -637,10 +698,7 @@ def _render_shapes( vmax = 1.0 _cax.set_clim(vmin=vmin, vmax=vmax) - if ( - len(set(color_vector)) != 1 - or list(set(color_vector))[0] != render_params.cmap_params.na_color.get_hex_with_alpha() - ): + if _want_decorations(color_vector, render_params.cmap_params.na_color): # necessary in case different shapes elements are annotated with one table if color_source_vector is not None and render_params.col_for_color is not None: color_source_vector = color_source_vector.remove_unused_categories() @@ -781,14 +839,7 @@ def _render_points( # Convert back to dask dataframe to modify sdata transformation_in_cs = sdata_filt.points[element].attrs["transform"][coordinate_system] - points_dd = dask.dataframe.from_pandas(points_for_model, npartitions=1) - sdata_filt.points[element] = PointsModel.parse(points_dd, coordinates={"x": "x", "y": "y"}) - # restore transformation in coordinate system of interest - set_transformation( - element=sdata_filt.points[element], - transformation=transformation_in_cs, - to_coordinate_system=coordinate_system, - ) + _reparse_points(sdata_filt, element, points_for_model, transformation_in_cs, coordinate_system) if col_for_color is not None: assert isinstance(col_for_color, str) @@ -831,14 +882,22 @@ def _render_points( ) if added_color_from_table and col_for_color is not None: - points_with_color_dd = dask.dataframe.from_pandas(points_pd_with_color, npartitions=1) - sdata_filt.points[element] = PointsModel.parse(points_with_color_dd, coordinates={"x": "x", "y": "y"}) - set_transformation( - element=sdata_filt.points[element], - transformation=transformation_in_cs, - to_coordinate_system=coordinate_system, + _reparse_points(sdata_filt, element, points_pd_with_color, transformation_in_cs, coordinate_system) + + # When groups are specified, filter out non-matching elements by default. + # Only show non-matching elements if the user explicitly sets na_color. + _na = render_params.cmap_params.na_color + if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.alpha == "00"): + keep, color_source_vector, color_vector = _filter_groups_transparent_na( + groups, color_source_vector, color_vector ) - points_dd = points_with_color_dd + n_points = int(keep.sum()) + if n_points == 0: + return + # filter the materialized points, adata, and re-register in sdata_filt + points = points[keep].reset_index(drop=True) + adata = adata[keep] + _reparse_points(sdata_filt, element, points, transformation_in_cs, coordinate_system) # color_source_vector is None when the values aren't categorical if color_source_vector is None and render_params.transfunc is not None: @@ -918,14 +977,7 @@ def _render_points( if col_for_color is not None: if color_by_categorical: # add nan as category so that nan points are shown in the nan color - cat_series = transformed_element[col_for_color] - if not isinstance(cat_series.dtype, pd.CategoricalDtype): - cat_series = cat_series.astype("category") - if hasattr(cat_series.cat, "as_known"): - cat_series = cat_series.cat.as_known() - if _DS_NAN_CATEGORY not in cat_series.cat.categories: - cat_series = cat_series.cat.add_categories(_DS_NAN_CATEGORY) - transformed_element[col_for_color] = cat_series.fillna(_DS_NAN_CATEGORY) + transformed_element[col_for_color] = _inject_ds_nan_sentinel(transformed_element[col_for_color]) agg = cvs.points(transformed_element, "x", "y", agg=ds.by(col_for_color, ds.count())) else: reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "sum" @@ -1015,9 +1067,7 @@ def _render_points( if continuous_nan_points is not None: # for coloring by continuous variable: render nan points separately - nan_color_hex = render_params.cmap_params.na_color.get_hex() - if nan_color_hex.startswith("#") and len(nan_color_hex) == 9: - nan_color_hex = nan_color_hex[:7] + nan_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) continuous_nan_points = ds.tf.spread(continuous_nan_points, px=px, how="max") continuous_nan_points = ds.tf.shade( continuous_nan_points, @@ -1085,27 +1135,7 @@ def _render_points( ax.set_xbound(extent["x"]) ax.set_ybound(extent["y"]) - # Decide whether there is any informative color variation. - # We skip legend/colorbar only if all colors are equal to the NA color. - want_decorations = True - if color_vector is None: - want_decorations = False - else: - cv = np.asarray(color_vector) - if cv.size == 0: - want_decorations = False - else: - unique_vals = set(cv.tolist()) - if len(unique_vals) == 1: - only_val = next(iter(unique_vals)) - na_hex = render_params.cmap_params.na_color.get_hex() - if isinstance(only_val, str) and only_val.startswith("#") and na_hex.startswith("#"): - only_norm = _hex_no_alpha(only_val) - na_norm = _hex_no_alpha(na_hex) - if only_norm == na_norm: - want_decorations = False - - if want_decorations: + if _want_decorations(color_vector, render_params.cmap_params.na_color): if color_source_vector is None: palette = ListedColormap(dict.fromkeys(color_vector)) else: @@ -1501,6 +1531,26 @@ def _render_labels( else: assert color_source_vector is None + # When groups are specified, zero out non-matching label IDs so they render as background. + # Only show non-matching labels if the user explicitly sets na_color. + _na = render_params.cmap_params.na_color + if ( + groups is not None + and categorical + and color_source_vector is not None + and (_na.default_color_set or _na.alpha == "00") + ): + keep_vec = color_source_vector.isin(groups) + matching_ids = instance_id[keep_vec] + keep_mask = np.isin(label.values, matching_ids) + label = label.copy() + label.values[~keep_mask] = 0 + instance_id = instance_id[keep_vec] + color_source_vector = color_source_vector[keep_vec] + color_vector = color_vector[keep_vec] + if isinstance(color_vector.dtype, pd.CategoricalDtype): + color_vector = color_vector.remove_unused_categories() + def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) -> matplotlib.image.AxesImage: labels = _map_color_seg( seg=label.values, diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index a0aced27..402a2191 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -990,31 +990,17 @@ def _build_alignment_dtype_hint( table_name: str | None, ) -> str: """Build a diagnostic hint string for dtype mismatches between element and table indices.""" - hints: list[str] = [] - color_index_dtype = getattr(color_series.index, "dtype", None) - element_index_dtype = getattr(getattr(element, "index", None), "dtype", None) if element is not None else None - - table_instance_dtype = None - instance_key = None - if table_name is not None and sdata is not None and table_name in sdata.tables: - table = sdata.tables[table_name] - try: - _, _, instance_key = get_table_keys(table) - except (KeyError, ValueError, TypeError, AttributeError): - instance_key = None - if instance_key is not None and hasattr(table, "obs") and instance_key in table.obs: - table_instance_dtype = table.obs[instance_key].dtype - - if ( - element_index_dtype is not None - and table_instance_dtype is not None - and element_index_dtype != table_instance_dtype - ): - hints.append(f"element index dtype is {element_index_dtype}, '{instance_key}' dtype is {table_instance_dtype}") - if color_index_dtype is not None and element_index_dtype is not None and color_index_dtype != element_index_dtype: - hints.append(f"color index dtype is {color_index_dtype}, element index dtype is {element_index_dtype}") - - return f" (hint: {'; '.join(hints)})" if hints else "" + el_dtype = getattr(getattr(element, "index", None), "dtype", None) + if el_dtype is None or table_name is None or sdata is None or table_name not in sdata.tables: + return "" + try: + _, _, instance_key = get_table_keys(sdata.tables[table_name]) + except (KeyError, ValueError): + return "" + tbl_dtype = sdata.tables[table_name].obs[instance_key].dtype + if el_dtype != tbl_dtype: + return f" (hint: element index dtype is {el_dtype}, '{instance_key}' dtype is {tbl_dtype})" + return "" def _set_color_source_vec( @@ -1119,7 +1105,9 @@ def _set_color_source_vec( table_to_use = None else: table_keys = list(sdata.tables.keys()) - if table_keys: + if len(table_keys) == 1: + table_to_use = table_keys[0] + elif len(table_keys) > 1: table_to_use = table_keys[0] logger.warning(f"No table name provided, using '{table_to_use}' as fallback for color mapping.") else: diff --git a/tests/_images/Labels_respects_custom_colors_from_uns_with_groups_and_palette.png b/tests/_images/Labels_respects_custom_colors_from_uns_with_groups_and_palette.png index 46230f10..499fb50f 100644 Binary files a/tests/_images/Labels_respects_custom_colors_from_uns_with_groups_and_palette.png and b/tests/_images/Labels_respects_custom_colors_from_uns_with_groups_and_palette.png differ diff --git a/tests/_images/Labels_subset_categorical_label_maintains_order.png b/tests/_images/Labels_subset_categorical_label_maintains_order.png index a18d77cd..9e1c2cc6 100644 Binary files a/tests/_images/Labels_subset_categorical_label_maintains_order.png and b/tests/_images/Labels_subset_categorical_label_maintains_order.png differ diff --git a/tests/_images/Labels_subset_categorical_label_maintains_order_when_palette_overwrite.png b/tests/_images/Labels_subset_categorical_label_maintains_order_when_palette_overwrite.png index 34f063bf..6d0806c5 100644 Binary files a/tests/_images/Labels_subset_categorical_label_maintains_order_when_palette_overwrite.png and b/tests/_images/Labels_subset_categorical_label_maintains_order_when_palette_overwrite.png differ diff --git a/tests/_images/Points_can_annotate_points_with_table_and_groups.png b/tests/_images/Points_can_annotate_points_with_table_and_groups.png index 17358b18..cec3844b 100644 Binary files a/tests/_images/Points_can_annotate_points_with_table_and_groups.png and b/tests/_images/Points_can_annotate_points_with_table_and_groups.png differ diff --git a/tests/_images/Points_datashader_can_color_by_category.png b/tests/_images/Points_datashader_can_color_by_category.png index 1badee5f..d1180d32 100644 Binary files a/tests/_images/Points_datashader_can_color_by_category.png and b/tests/_images/Points_datashader_can_color_by_category.png differ diff --git a/tests/_images/Points_groups_na_color_none_filters_points.png b/tests/_images/Points_groups_na_color_none_filters_points.png new file mode 100644 index 00000000..4c440ecf Binary files /dev/null and b/tests/_images/Points_groups_na_color_none_filters_points.png differ diff --git a/tests/_images/Points_groups_na_color_none_filters_points_datashader.png b/tests/_images/Points_groups_na_color_none_filters_points_datashader.png new file mode 100644 index 00000000..b2b920a5 Binary files /dev/null and b/tests/_images/Points_groups_na_color_none_filters_points_datashader.png differ diff --git a/tests/_images/Shapes_can_filter_with_groups.png b/tests/_images/Shapes_can_filter_with_groups.png index 5d66f966..4f9b71be 100644 Binary files a/tests/_images/Shapes_can_filter_with_groups.png and b/tests/_images/Shapes_can_filter_with_groups.png differ diff --git a/tests/_images/Shapes_groups_na_color_none_filters_shapes.png b/tests/_images/Shapes_groups_na_color_none_filters_shapes.png new file mode 100644 index 00000000..8d8380f0 Binary files /dev/null and b/tests/_images/Shapes_groups_na_color_none_filters_shapes.png differ diff --git a/tests/_images/Shapes_groups_na_color_none_filters_shapes_datashader.png b/tests/_images/Shapes_groups_na_color_none_filters_shapes_datashader.png new file mode 100644 index 00000000..32009600 Binary files /dev/null and b/tests/_images/Shapes_groups_na_color_none_filters_shapes_datashader.png differ diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index 95270930..69191ea8 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -576,6 +576,36 @@ def test_plot_can_annotate_points_with_nan_in_df_continuous_datashader(self, sda sdata_blobs["blobs_points"]["cont_color"] = pd.Series([np.nan, 2, 9, 13] * 50) sdata_blobs.pl.render_points("blobs_points", color="cont_color", size=40, method="datashader").pl.show() + def test_plot_groups_na_color_none_filters_points(self, sdata_blobs: SpatialData): + """With groups, non-matching points are filtered by default; na_color='red' keeps them visible.""" + sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category") + _, axs = plt.subplots(nrows=1, ncols=2, layout="tight") + sdata_blobs.pl.render_points("blobs_points", color="cat_color", groups=["a"], na_color="red", size=30).pl.show( + ax=axs[0], title="na_color='red'" + ) + sdata_blobs.pl.render_points("blobs_points", color="cat_color", groups=["a"], size=30).pl.show( + ax=axs[1], title="default (filtered)" + ) + + def test_plot_groups_na_color_none_filters_points_datashader(self, sdata_blobs: SpatialData): + """With groups + datashader, non-matching points are filtered by default.""" + sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category") + _, axs = plt.subplots(nrows=1, ncols=2, layout="tight") + sdata_blobs.pl.render_points( + "blobs_points", color="cat_color", groups=["a"], na_color="red", size=30, method="datashader" + ).pl.show(ax=axs[0], title="na_color='red'") + sdata_blobs.pl.render_points( + "blobs_points", color="cat_color", groups=["a"], size=30, method="datashader" + ).pl.show(ax=axs[1], title="default (filtered)") + + +def test_groups_na_color_none_no_match_points(sdata_blobs: SpatialData): + """When no elements match the groups, the plot should render without error.""" + sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category") + sdata_blobs.pl.render_points( + "blobs_points", color="cat_color", groups=["nonexistent"], na_color=None, size=30 + ).pl.show() + def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData): # Work on an independent copy since we mutate tables diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index d38e0c51..dece73b3 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -983,6 +983,34 @@ def test_plot_can_annotate_shapes_with_nan_in_df_continuous_datashader(self, sda sdata_blobs["blobs_polygons"]["cont_color"] = [np.nan, 2, 3, 4, 5] sdata_blobs.pl.render_shapes("blobs_polygons", color="cont_color", method="datashader").pl.show() + def test_plot_groups_na_color_none_filters_shapes(self, sdata_blobs: SpatialData): + """With groups, non-matching shapes are filtered by default; na_color='red' keeps them visible.""" + sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category") + _, axs = plt.subplots(nrows=1, ncols=2, layout="tight") + sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_color", groups=["a"], na_color="red").pl.show( + ax=axs[0], title="na_color='red'" + ) + sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_color", groups=["a"]).pl.show( + ax=axs[1], title="default (filtered)" + ) + + def test_plot_groups_na_color_none_filters_shapes_datashader(self, sdata_blobs: SpatialData): + """With groups + datashader, non-matching shapes are filtered by default.""" + sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category") + _, axs = plt.subplots(nrows=1, ncols=2, layout="tight") + sdata_blobs.pl.render_shapes( + "blobs_polygons", color="cat_color", groups=["a"], na_color="red", method="datashader" + ).pl.show(ax=axs[0], title="na_color='red'") + sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_color", groups=["a"], method="datashader").pl.show( + ax=axs[1], title="default (filtered)" + ) + + +def test_groups_na_color_none_no_match_shapes(sdata_blobs: SpatialData): + """When no elements match the groups, the plot should render without error.""" + sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category") + sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_color", groups=["nonexistent"], na_color=None).pl.show() + def test_plot_can_handle_nan_values_in_color_data(sdata_blobs: SpatialData, caplog): """Test that NaN values in color data are handled gracefully and logged."""