Skip to content
Open
108 changes: 107 additions & 1 deletion monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,19 @@

from __future__ import annotations

import numpy as np
import torch

from monai.metrics.utils import do_metric_reduction
from monai.utils import MetricReduction, deprecated_arg
from monai.utils.module import optional_import

from .metric import CumulativeIterationMetric

distance_transform_edt, has_ndimage = optional_import("scipy.ndimage", name="distance_transform_edt")
generate_binary_structure, _ = optional_import("scipy.ndimage", name="generate_binary_structure")
sn_label, _ = optional_import("scipy.ndimage", name="label")

__all__ = ["DiceMetric", "compute_dice", "DiceHelper"]


Expand Down Expand Up @@ -95,6 +101,9 @@ class DiceMetric(CumulativeIterationMetric):
If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True,
the index begins at "0", otherwise at "1". It can also take a list of label names.
The outcome will then be returned as a dictionary.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires 5D binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.

"""

Expand All @@ -106,6 +115,7 @@ def __init__(
ignore_empty: bool = True,
num_classes: int | None = None,
return_with_label: bool | list[str] = False,
per_component: bool = False,
) -> None:
super().__init__()
self.include_background = include_background
Expand All @@ -114,13 +124,15 @@ def __init__(
self.ignore_empty = ignore_empty
self.num_classes = num_classes
self.return_with_label = return_with_label
self.per_component = per_component
self.dice_helper = DiceHelper(
include_background=self.include_background,
reduction=MetricReduction.NONE,
get_not_nans=False,
apply_argmax=False,
ignore_empty=self.ignore_empty,
num_classes=self.num_classes,
per_component=self.per_component,
)

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
Expand Down Expand Up @@ -175,6 +187,7 @@ def compute_dice(
include_background: bool = True,
ignore_empty: bool = True,
num_classes: int | None = None,
per_component: bool = False,
) -> torch.Tensor:
"""
Computes Dice score metric for a batch of predictions. This performs the same computation as
Expand All @@ -192,6 +205,9 @@ def compute_dice(
num_classes: number of input channels (always including the background). When this is ``None``,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires 5D binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.

Returns:
Dice scores per batch and per class, (shape: [batch_size, num_classes]).
Expand All @@ -204,6 +220,7 @@ def compute_dice(
apply_argmax=False,
ignore_empty=ignore_empty,
num_classes=num_classes,
per_component=per_component,
)(y_pred=y_pred, y=y)


Expand Down Expand Up @@ -246,6 +263,9 @@ class DiceHelper:
num_classes: number of input channels (always including the background). When this is ``None``,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires 5D binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
"""

@deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax")
Expand All @@ -262,6 +282,7 @@ def __init__(
num_classes: int | None = None,
sigmoid: bool | None = None,
softmax: bool | None = None,
per_component: bool = False,
) -> None:
# handling deprecated arguments
if sigmoid is not None:
Expand All @@ -277,6 +298,81 @@ def __init__(
self.activate = activate
self.ignore_empty = ignore_empty
self.num_classes = num_classes
self.per_component = per_component

def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
"""
Voronoi assignment to connected components (CPU, single EDT) without cc3d.
Returns the ID of the nearest component for each voxel.

Args:
labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.
connectivity (int): 6, 18, or 26 for 3D connectivity. Defaults to 26.
sampling (tuple[float, ...] | None): Voxel spacing for anisotropic distances.

Returns:
torch.Tensor: Voronoi region IDs (int32) on CPU.
"""
if not has_ndimage:
raise RuntimeError("scipy.ndimage is required for per_component Dice computation.")
x = np.asarray(labels)
conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3)
structure = generate_binary_structure(rank=3, connectivity=conn_rank)
cc, num = sn_label(x > 0, structure=structure)
if num == 0:
return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32)
edt_input = np.ones(cc.shape, dtype=np.uint8)
edt_input[cc > 0] = 0
indices = distance_transform_edt(edt_input, sampling=sampling, return_distances=False, return_indices=True)
voronoi = cc[tuple(indices)]
return torch.from_numpy(voronoi)

def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Compute per-component Dice for a single batch item.

Args:
y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W).
y (torch.Tensor): Ground truth with shape (1, 2, D, H, W).

Returns:
torch.Tensor: Mean Dice over connected components.
"""
data = []
if y_pred.ndim == y.ndim:
y_pred_idx = torch.argmax(y_pred, dim=1)
y_idx = torch.argmax(y, dim=1)
else:
y_pred_idx = y_pred
y_idx = y
if y_idx[0].sum() == 0:
if self.ignore_empty:
data.append(torch.tensor(float("nan"), device=y_idx.device))
elif y_pred_idx.sum() == 0:
data.append(torch.tensor(1.0, device=y_idx.device))
else:
data.append(torch.tensor(0.0, device=y_idx.device))
else:
cc_assignment = self.compute_voronoi_regions_fast(y_idx[0])
uniq, inv = torch.unique(cc_assignment.view(-1), return_inverse=True)
nof_components = uniq.numel()
code = (y_idx.view(-1) << 1) | y_pred_idx.view(-1)
idx = (inv << 2) | code
hist = torch.bincount(idx, minlength=nof_components * 4).reshape(-1, 4)
_, fp, fn, tp = hist[:, 0], hist[:, 1], hist[:, 2], hist[:, 3]
denom = 2 * tp + fp + fn
dice_scores = torch.where(
denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device)
)
data.append(dice_scores.unsqueeze(-1))
data = [
torch.where(torch.isinf(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data
]
data = [
torch.where(torch.isnan(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data
]
data = [x.reshape(-1, 1) for x in data]
return torch.stack([x.mean() for x in data])

def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -322,15 +418,25 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
y_pred = torch.sigmoid(y_pred)
y_pred = y_pred > 0.5

first_ch = 0 if self.include_background else 1
if self.per_component:
if len(y_pred.shape) != 5 or len(y.shape) != 5 or y_pred.shape[1] != 2 or y.shape[1] != 2:
raise ValueError(
"per_component requires both y_pred and y to be 5D binary segmentations "
f"with 2 channels. Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}."
)

first_ch = 0 if self.include_background and not self.per_component else 1
data = []
for b in range(y_pred.shape[0]):
c_list = []
for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
c_list.append(self.compute_channel(x_pred, x))
if self.per_component:
c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
data.append(torch.stack(c_list))

data = torch.stack(data, dim=0).contiguous() # type: ignore

f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore
Expand Down
35 changes: 35 additions & 0 deletions tests/metrics/test_compute_meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from parameterized import parameterized

from monai.metrics import DiceHelper, DiceMetric, compute_dice
from monai.utils.module import optional_import

_, has_ndimage = optional_import("scipy.ndimage")

_device = "cuda:0" if torch.cuda.is_available() else "cpu"
# keep background
Expand Down Expand Up @@ -250,6 +253,24 @@
{"label_1": 0.4000, "label_2": 0.6667},
]

# Testcase for per_component DiceMetric
y = torch.zeros((5, 2, 64, 64, 64))
y_hat = torch.zeros((5, 2, 64, 64, 64))

y[0, 1, 20:25, 20:25, 20:25] = 1
y[0, 1, 40:45, 40:45, 40:45] = 1
y[0, 0] = 1 - y[0, 1]

y_hat[0, 1, 21:26, 21:26, 21:26] = 1
y_hat[0, 1, 41:46, 39:44, 41:46] = 1
y_hat[0, 0] = 1 - y_hat[0, 1]

TEST_CASE_16 = [
{"per_component": True, "ignore_empty": False},
{"y": y, "y_pred": y_hat},
[[[0.5120]], [[1.0]], [[1.0]], [[1.0]], [[1.0]]],
]


class TestComputeMeanDice(unittest.TestCase):

Expand Down Expand Up @@ -301,6 +322,20 @@ def test_nans_class(self, params, input_data, expected_value):
else:
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

# CC DiceMetric tests
@parameterized.expand([TEST_CASE_16])
@unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.")
def test_cc_dice_value(self, params, input_data, expected_value):
dice_metric = DiceMetric(**params)
dice_metric(**input_data)
result = dice_metric.aggregate(reduction="none")
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

@unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.")
def test_input_dimensions(self):
with self.assertRaises(ValueError):
DiceMetric(per_component=True)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))


if __name__ == "__main__":
unittest.main()
Loading