Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,9 @@ def render_shapes(
`fill_alpha` will overwrite the value present in the cmap.
groups : list[str] | str | None
When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of
them. Other values are set to NA. If element is None, broadcasting behaviour is attempted (use the same
values for all elements).
them. By default, non-matching elements are hidden. To show non-matching elements, set ``na_color``
explicitly.
If element is None, broadcasting behaviour is attempted (use the same values for all elements).
palette : list[str] | str | None
Palette for discrete annotations. List of valid color names that should be used for the categories. Must
match the number of groups. If element is None, broadcasting behaviour is attempted (use the same values for
Expand Down Expand Up @@ -398,8 +399,9 @@ def render_points(
value is used instead.
groups : list[str] | str | None
When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of
them. Other values are set to NA. If `element` is `None`, broadcasting behaviour is attempted (use the same
values for all elements).
them. By default, non-matching points are filtered out entirely. To show non-matching points, set
``na_color`` explicitly.
If element is None, broadcasting behaviour is attempted (use the same values for all elements).
palette : list[str] | str | None
Palette for discrete annotations. List of valid color names that should be used for the categories. Must
match the number of groups. If `element` is `None`, broadcasting behaviour is attempted (use the same values
Expand Down Expand Up @@ -671,7 +673,7 @@ def render_labels(
table_name to be used for the element if you would like a specific table to be used.
groups : list[str] | str | None
When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of
them. Other values are set to NA. The list can contain multiple discrete labels to be visualized.
them. By default, non-matching labels are hidden. To show non-matching labels, set ``na_color`` explicitly.
palette : list[str] | str | None
Palette for discrete annotations. List of valid color names that should be used for the categories. Must
match the number of groups. The list can contain multiple palettes (one per group) to be visualized. If
Expand Down
246 changes: 148 additions & 98 deletions src/spatialdata_plot/pl/render.py

Large diffs are not rendered by default.

40 changes: 14 additions & 26 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,31 +990,17 @@ def _build_alignment_dtype_hint(
table_name: str | None,
) -> str:
"""Build a diagnostic hint string for dtype mismatches between element and table indices."""
hints: list[str] = []
color_index_dtype = getattr(color_series.index, "dtype", None)
element_index_dtype = getattr(getattr(element, "index", None), "dtype", None) if element is not None else None

table_instance_dtype = None
instance_key = None
if table_name is not None and sdata is not None and table_name in sdata.tables:
table = sdata.tables[table_name]
try:
_, _, instance_key = get_table_keys(table)
except (KeyError, ValueError, TypeError, AttributeError):
instance_key = None
if instance_key is not None and hasattr(table, "obs") and instance_key in table.obs:
table_instance_dtype = table.obs[instance_key].dtype

if (
element_index_dtype is not None
and table_instance_dtype is not None
and element_index_dtype != table_instance_dtype
):
hints.append(f"element index dtype is {element_index_dtype}, '{instance_key}' dtype is {table_instance_dtype}")
if color_index_dtype is not None and element_index_dtype is not None and color_index_dtype != element_index_dtype:
hints.append(f"color index dtype is {color_index_dtype}, element index dtype is {element_index_dtype}")

return f" (hint: {'; '.join(hints)})" if hints else ""
el_dtype = getattr(getattr(element, "index", None), "dtype", None)
if el_dtype is None or table_name is None or sdata is None or table_name not in sdata.tables:
return ""
try:
_, _, instance_key = get_table_keys(sdata.tables[table_name])
except (KeyError, ValueError):
return ""
tbl_dtype = sdata.tables[table_name].obs[instance_key].dtype
if el_dtype != tbl_dtype:
return f" (hint: element index dtype is {el_dtype}, '{instance_key}' dtype is {tbl_dtype})"
return ""


def _set_color_source_vec(
Expand Down Expand Up @@ -1119,7 +1105,9 @@ def _set_color_source_vec(
table_to_use = None
else:
table_keys = list(sdata.tables.keys())
if table_keys:
if len(table_keys) == 1:
table_to_use = table_keys[0]
elif len(table_keys) > 1:
table_to_use = table_keys[0]
logger.warning(f"No table name provided, using '{table_to_use}' as fallback for color mapping.")
else:
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Points_datashader_can_color_by_category.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Shapes_can_filter_with_groups.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
30 changes: 30 additions & 0 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,36 @@ def test_plot_can_annotate_points_with_nan_in_df_continuous_datashader(self, sda
sdata_blobs["blobs_points"]["cont_color"] = pd.Series([np.nan, 2, 9, 13] * 50)
sdata_blobs.pl.render_points("blobs_points", color="cont_color", size=40, method="datashader").pl.show()

def test_plot_groups_na_color_none_filters_points(self, sdata_blobs: SpatialData):
"""With groups, non-matching points are filtered by default; na_color='red' keeps them visible."""
sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category")
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")
sdata_blobs.pl.render_points("blobs_points", color="cat_color", groups=["a"], na_color="red", size=30).pl.show(
ax=axs[0], title="na_color='red'"
)
sdata_blobs.pl.render_points("blobs_points", color="cat_color", groups=["a"], size=30).pl.show(
ax=axs[1], title="default (filtered)"
)

def test_plot_groups_na_color_none_filters_points_datashader(self, sdata_blobs: SpatialData):
"""With groups + datashader, non-matching points are filtered by default."""
sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category")
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")
sdata_blobs.pl.render_points(
"blobs_points", color="cat_color", groups=["a"], na_color="red", size=30, method="datashader"
).pl.show(ax=axs[0], title="na_color='red'")
sdata_blobs.pl.render_points(
"blobs_points", color="cat_color", groups=["a"], size=30, method="datashader"
).pl.show(ax=axs[1], title="default (filtered)")


def test_groups_na_color_none_no_match_points(sdata_blobs: SpatialData):
"""When no elements match the groups, the plot should render without error."""
sdata_blobs["blobs_points"]["cat_color"] = pd.Series(["a", "b", "c", "a"] * 50, dtype="category")
sdata_blobs.pl.render_points(
"blobs_points", color="cat_color", groups=["nonexistent"], na_color=None, size=30
).pl.show()


def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData):
# Work on an independent copy since we mutate tables
Expand Down
28 changes: 28 additions & 0 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,34 @@ def test_plot_can_annotate_shapes_with_nan_in_df_continuous_datashader(self, sda
sdata_blobs["blobs_polygons"]["cont_color"] = [np.nan, 2, 3, 4, 5]
sdata_blobs.pl.render_shapes("blobs_polygons", color="cont_color", method="datashader").pl.show()

def test_plot_groups_na_color_none_filters_shapes(self, sdata_blobs: SpatialData):
"""With groups, non-matching shapes are filtered by default; na_color='red' keeps them visible."""
sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category")
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")
sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_color", groups=["a"], na_color="red").pl.show(
ax=axs[0], title="na_color='red'"
)
sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_color", groups=["a"]).pl.show(
ax=axs[1], title="default (filtered)"
)

def test_plot_groups_na_color_none_filters_shapes_datashader(self, sdata_blobs: SpatialData):
"""With groups + datashader, non-matching shapes are filtered by default."""
sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category")
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")
sdata_blobs.pl.render_shapes(
"blobs_polygons", color="cat_color", groups=["a"], na_color="red", method="datashader"
).pl.show(ax=axs[0], title="na_color='red'")
sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_color", groups=["a"], method="datashader").pl.show(
ax=axs[1], title="default (filtered)"
)


def test_groups_na_color_none_no_match_shapes(sdata_blobs: SpatialData):
"""When no elements match the groups, the plot should render without error."""
sdata_blobs["blobs_polygons"]["cat_color"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category")
sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_color", groups=["nonexistent"], na_color=None).pl.show()


def test_plot_can_handle_nan_values_in_color_data(sdata_blobs: SpatialData, caplog):
"""Test that NaN values in color data are handled gracefully and logged."""
Expand Down
Loading