Add parameter to DiceMetric and DiceHelper classes #8774
VijayVignesh1 wants to merge 9 commits into Project-MONAI:dev from
Conversation
Signed-off-by: Vijay Vignesh Prasad Rao <vijayvigneshp02@gmail.com>
for more information, see https://pre-commit.ci
📝 Walkthrough

Adds a per_component mode to DiceMetric, DiceHelper, and compute_dice to compute Dice per connected component. When enabled, inputs must be 5D binary segmentation with exactly 2 channels; the ground-truth foreground is decomposed into connected components, Voronoi regions are computed, and per-component Dice scores are produced via the new DiceHelper methods compute_voronoi_regions_fast and compute_cc_dice. The per_component flag is propagated through initializers and compute paths; DiceHelper.__call__ validates the input shape and raises ValueError for mismatches. Tests were added to validate per-component values and input-dimension checks (skipped if scipy.ndimage is unavailable).

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
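The per-component scheme described in the walkthrough can be sketched as below. This is an illustrative reimplementation of the idea only (component labelling, an EDT-based Voronoi partition, then one Dice score per component); the helper names are hypothetical and do not reflect MONAI's actual `compute_voronoi_regions_fast`/`compute_cc_dice` code.

```python
# Illustrative sketch only: names and details are assumptions, not MONAI's
# implementation of compute_voronoi_regions_fast / compute_cc_dice.
import numpy as np
from scipy.ndimage import distance_transform_edt, label


def voronoi_regions(gt_fg: np.ndarray) -> np.ndarray:
    """Label GT components, then expand each label over its Voronoi cell."""
    cc, num = label(gt_fg > 0)
    if num == 0:
        return np.zeros_like(cc)
    # For each background voxel, the EDT returns the index of the nearest
    # foreground voxel; indexing cc with it propagates the component labels.
    idx = distance_transform_edt(cc == 0, return_distances=False, return_indices=True)
    return cc[tuple(idx)]


def per_component_dice(pred_fg: np.ndarray, gt_fg: np.ndarray) -> list[float]:
    """One Dice score per GT connected component; the prediction is split
    among components by the Voronoi partition."""
    cc, num = label(gt_fg > 0)
    regions = voronoi_regions(gt_fg)
    scores = []
    for k in range(1, num + 1):
        g = cc == k
        p = (pred_fg > 0) & (regions == k)  # prediction restricted to this cell
        denom = g.sum() + p.sum()
        scores.append(2.0 * (g & p).sum() / denom if denom else 1.0)
    return scores
```

A perfect prediction scores 1.0 for every component; a prediction that misses one component entirely scores 0.0 for that component while the others are unaffected.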
✅ Passed checks (3 passed)
Actionable comments posted: 4
🧹 Nitpick comments (3)
monai/metrics/meandice.py (1)
418-426: Wasted computation when per_component=True.

Lines 420-423 compute channel Dice, then lines 424-425 discard it and overwrite c_list. Move the branch earlier.

Proposed fix

```diff
         for b in range(y_pred.shape[0]):
-            c_list = []
-            for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
-                x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
-                x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
-                c_list.append(self.compute_channel(x_pred, x))
             if self.per_component:
                 c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
+            else:
+                c_list = []
+                for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
+                    x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
+                    x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
+                    c_list.append(self.compute_channel(x_pred, x))
             data.append(torch.stack(c_list))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/metrics/meandice.py` around lines 418 - 426, The loop is doing wasted work: it always computes per-channel Dice via compute_channel for each c and only when self.per_component is True it discards those results and replaces c_list with a compute_cc_dice call. Change the logic inside the for b in range(...) loop to check self.per_component before computing channels; if self.per_component is True, directly set c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))] and skip the per-channel compute_channel loop and related x_pred/x extraction, otherwise run the existing per-channel path that builds c_list with compute_channel as before. Ensure references to y_pred, y, compute_channel, compute_cc_dice, c_list and per_component are used so the branch correctly short-circuits the expensive channel computations.

tests/metrics/test_compute_meandice.py (2)
253-276: Test data construction is hard to follow; expected value undocumented.

The lambda-walrus pattern obscures setup. Consider a helper function. Also document how 0.5120 was derived, for maintainability.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/metrics/test_compute_meandice.py` around lines 253 - 276, TEST_CASE_16 uses a lambda-walrus pattern (variables y and y_pred inside TEST_CASE_16) that makes the test data setup hard to read and omits explanation of the expected 0.5120 value; extract the tensor construction into a small descriptive helper (e.g., build_test_case_16_tensors or make_meandice_case_16) and replace the inline lambdas with calls to that helper, and add a short comment next to the expected value explaining how 0.5120 was computed (e.g., describe overlapping voxel counts and Dice formula for the two shifted cubes) so the test is readable and the expected number is documented.
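A possible shape for the suggested helper is sketched below. The function name, tensor sizes, and the resulting expected value are all illustrative, not the actual TEST_CASE_16 data or its 0.5120 expectation.

```python
import torch


def make_two_cube_case(size: int = 10, cube: int = 4, shift: int = 2):
    """Two one-hot 5D volumes: a GT cube at the origin corner and a
    prediction cube shifted by `shift` along every axis (illustrative only)."""
    y = torch.zeros(1, 2, size, size, size)
    y_pred = torch.zeros_like(y)
    y[0, 1, :cube, :cube, :cube] = 1
    y_pred[0, 1, shift:shift + cube, shift:shift + cube, shift:shift + cube] = 1
    y[0, 0] = 1 - y[0, 1]            # background channel completes the one-hot
    y_pred[0, 0] = 1 - y_pred[0, 1]
    # Dice = 2|A∩B| / (|A| + |B|); the overlap is a (cube - shift)^3 block.
    expected = 2 * (cube - shift) ** 3 / (2 * cube**3)
    return y_pred, y, expected
```

Documenting the expected value as a formula next to the data (as above) is what the comment asks for: a reader can re-derive the number without reverse-engineering tensors.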
337-339: Shape mismatch may obscure test intent.

Both tensors are 4D (not 5D) and have 3 channels (not 2). The spatial mismatch (144 vs 145) is irrelevant to the validation. Use matching shapes to clarify:

```diff
-        DiceMetric(per_component=True)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))
+        DiceMetric(per_component=True)(torch.ones([3, 3, 64, 64]), torch.ones([3, 3, 64, 64]))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/metrics/test_compute_meandice.py` around lines 337 - 339, The test currently uses two 4D tensors with mismatched spatial sizes and 3 channels, which obscures the intent to validate dimensionality; update test_input_dimensions so both tensors have identical shapes but still 4D to trigger the ValueError (e.g., use torch.ones([3, 2, 144, 144]) for both), ensuring the failure comes from incorrect dimensionality for DiceMetric rather than a spatial-size mismatch or wrong channel count.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@monai/metrics/meandice.py`:
- Around line 14-17: The module imports SciPy unconditionally causing CI
failures when SciPy is not installed; change the top-level imports to use
MONAI's optional_import pattern to import distance_transform_edt,
generate_binary_structure and label (sn_label) and expose a has_scipy flag, then
in compute_voronoi_regions_fast check has_scipy and raise a clear RuntimeError
if False; update references to
sn_label/distance_transform_edt/generate_binary_structure in the file to use the
optionally imported symbols so runtime usage is guarded.
- Line 416: The code currently sets first_ch based on a combined condition which
silently ignores include_background when per_component is True; update the logic
in the MeanDice/meandice implementation to detect the conflicting flags
(self.per_component True and self.include_background False) and emit a clear
warning (e.g., warnings.warn or using the module logger) that include_background
will be ignored when per_component is enabled, then keep the existing behavior
for first_ch (set first_ch=1) to preserve compatibility; reference the
attributes self.include_background, self.per_component and the local variable
first_ch so reviewers can locate and adjust the check and add the warning.
- Around line 300-321: The compute_voronoi_regions_fast function's docstring
lacks a Returns section and the function always returns a CPU tensor
(torch.from_numpy) even if the original input was a CUDA tensor; update the
docstring to include a Returns: description and type (torch.Tensor on same
device as input) and change the implementation to preserve input type/device:
accept numpy array or torch.Tensor for labels, record the original device and
dtype (if torch.Tensor), convert input to CPU numpy for EDT processing, then
convert the resulting voronoi numpy array back to a torch.Tensor and move it to
the original device and appropriate dtype before returning; reference
compute_voronoi_regions_fast, labels, edt_input, indices, and voronoi when
locating where to apply these changes.
- Around line 323-364: The compute_cc_dice method's docstring and
empty-ground-truth handling are incorrect: update the docstring for
compute_cc_dice to state the actual expected input shapes (e.g., tensors that
may include batch and channel dims such as (1, C, D, H, W) or
per-channel/per-item spatial tensors) and then change the empty-GT branch (the
y_idx[0].sum() == 0 case) to consult self.ignore_empty (return
torch.tensor(0.0/1.0 or skip/ignore according to class semantics) instead of
always appending 1.0/0.0), and move the inf/nan replacement logic (the two
torch.where lines that sanitize values) out of the else block so they run for
both empty and non-empty cases; refer to symbols compute_cc_dice, y_idx,
y_pred_idx, self.ignore_empty, cc_assignment, uniq/inv/hist/dice_scores to
locate and update the logic and docstring.
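The `optional_import` pattern requested in the first comment can be sketched with a minimal stand-in, shown below. MONAI's real `monai.utils.optional_import` has more features (version checks, descriptive errors), but the `(object, has_flag)` return shape is the same; the guarded function here is illustrative, not the actual method body.

```python
import importlib


def optional_import(module: str, name: str = ""):
    """Minimal stand-in for monai.utils.optional_import: return (obj, True)
    on success, or (a placeholder that raises on use, False) on failure."""
    try:
        mod = importlib.import_module(module)
        return (getattr(mod, name) if name else mod), True
    except ImportError:
        def _raiser(*args, **kwargs):
            raise RuntimeError(f"optional module '{module}' is required but not installed")
        return _raiser, False


# Import at module level without failing; guard actual usage at call time.
sn_label, has_scipy = optional_import("scipy.ndimage", name="label")


def compute_voronoi_regions_fast_guarded(labels):
    if not has_scipy:
        raise RuntimeError("per_component=True requires scipy.ndimage")
    return sn_label(labels)  # illustrative: the real method does more work here
```

This keeps CI green on environments without SciPy while still producing a clear error the moment the per-component path is exercised.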
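The empty-ground-truth semantics that several comments reference can be sketched as follows. This mirrors the usual MONAI convention (nan when `ignore_empty=True` and the GT is empty, otherwise 1.0 for an also-empty prediction and 0.0 for a false positive), but it is a simplified illustration, not the class's actual code.

```python
import torch


def dice_with_empty_gt(pred: torch.Tensor, gt: torch.Tensor, ignore_empty: bool = True) -> torch.Tensor:
    """Binary Dice with explicit empty-GT handling (illustrative sketch)."""
    pred, gt = pred.bool(), gt.bool()
    if not gt.any():
        if ignore_empty:
            return torch.tensor(float("nan"))  # dropped by nan-aware aggregation
        # GT empty and not ignored: perfect if prediction is also empty,
        # zero if the prediction contains false positives.
        return torch.tensor(1.0 if not pred.any() else 0.0)
    inter = (pred & gt).sum()
    return 2.0 * inter / (pred.sum() + gt.sum())
```

Routing the empty case through `ignore_empty` (instead of always returning 1.0/0.0) is exactly the fix the review asks compute_cc_dice to adopt.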
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: f8c5c98e-0bb3-413d-9471-3bef41a45cfa
📒 Files selected for processing (2)
monai/metrics/meandice.py
tests/metrics/test_compute_meandice.py
…itai - docstring issues, ignore_empty bug Signed-off-by: Vijay Vignesh Prasad Rao <vijayvigneshp02@gmail.com>
Signed-off-by: Vijay Vignesh Prasad Rao <vijayvigneshp02@gmail.com>
Actionable comments posted: 2
♻️ Duplicate comments (2)
monai/metrics/meandice.py (2)
427-427: ⚠️ Potential issue | 🟡 Minor

include_background is still silently ignored with per_component=True. Line 427 forces foreground-only behavior without signaling it.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/metrics/meandice.py` at line 427, The current assignment to first_ch in MeanDice (meandice.py) ignores include_background when per_component is True; change the logic so include_background is honored regardless of per_component by setting first_ch based solely on self.include_background (e.g., first_ch = 0 if self.include_background else 1) instead of conditioning on not self.per_component; update any related comments/tests that assumed the previous behavior.
318-328: ⚠️ Potential issue | 🟠 Major

Per-component Dice is not CUDA-safe. Line 318 uses np.asarray(labels), which breaks for CUDA tensors; line 328 returns a CPU tensor, and line 360 then mixes devices.

Proposed fix

```diff
     def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
@@
-        x = np.asarray(labels)
+        labels_t = labels if isinstance(labels, torch.Tensor) else torch.as_tensor(labels)
+        in_device = labels_t.device
+        x = labels_t.detach().cpu().numpy()
@@
-        if num == 0:
-            return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32)
+        if num == 0:
+            return torch.zeros_like(labels_t, dtype=torch.int32, device=in_device)
@@
-        return torch.from_numpy(voronoi)
+        return torch.from_numpy(voronoi).to(device=in_device, dtype=torch.int32)
```

```bash
#!/bin/bash
set -euo pipefail
# Verify current implementation uses numpy conversion without explicit CPU transfer
rg -n -C2 'def compute_voronoi_regions_fast|np\.asarray\(labels\)|torch\.from_numpy\(voronoi\)|compute_voronoi_regions_fast\(y_idx\[0\]\)' monai/metrics/meandice.py
# Confirm there is no explicit detach+cpu numpy conversion in this function body
python - <<'PY'
from pathlib import Path
text = Path("monai/metrics/meandice.py").read_text()
start = text.index("def compute_voronoi_regions_fast")
end = text.index("def compute_cc_dice")
chunk = text[start:end]
print("contains_detach_cpu_numpy:", ".detach().cpu().numpy()" in chunk or ".cpu().numpy()" in chunk)
PY
```

Also applies to: 356-361
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/metrics/meandice.py` around lines 318 - 328, The function compute_voronoi_regions_fast currently uses np.asarray(labels) and torch.from_numpy(voronoi) which breaks for CUDA tensors; capture the input tensor's device and dtype first (e.g., orig_device = labels.device if isinstance(labels, torch.Tensor) else torch.device('cpu')), convert safely to CPU numpy via labels = labels.detach().cpu().numpy() (or leave numpy arrays unchanged), run the existing numpy logic, then convert the result back using torch.from_numpy(voronoi).to(device=orig_device, dtype=torch.int32) or torch.as_tensor(voronoi, device=orig_device, dtype=torch.int32) so the returned tensor is on the same device as the input; update both compute_voronoi_regions_fast and the similar code at the other location (lines ~356-361 / compute_cc_dice caller) to follow this pattern.
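The round trip the fix asks for reduces to a small pattern: record the input's device, do the NumPy work on CPU, and move the result back. A minimal sketch (helper names are illustrative, not MONAI API):

```python
import numpy as np
import torch


def as_cpu_numpy(t):
    """Accept a torch.Tensor or array-like; return (ndarray, original device)."""
    if isinstance(t, torch.Tensor):
        return t.detach().cpu().numpy(), t.device
    return np.asarray(t), torch.device("cpu")


def back_to_device(arr: np.ndarray, device: torch.device, dtype=torch.int32) -> torch.Tensor:
    """Convert a NumPy result back to a tensor on the caller's device/dtype."""
    return torch.from_numpy(np.ascontiguousarray(arr)).to(device=device, dtype=dtype)
```

On a CUDA input the result lands back on the same GPU; on a CPU input the round trip is essentially free.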
🧹 Nitpick comments (2)
tests/metrics/test_compute_meandice.py (1)
334-337: Add per-component validation tests for invalid y shape/channel.

test_input_dimensions covers only one invalid input pattern. Add cases for y not being (B, 2, D, H, W) and for y_pred/y channel mismatch.

As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/metrics/test_compute_meandice.py` around lines 334 - 337, Extend the test_input_dimensions in tests/metrics/test_compute_meandice.py to add two more invalid-shape cases: (1) verify DiceMetric(per_component=True) raises ValueError when y has wrong channel count (e.g., y shape not (B,2,D,H,W) such as torch.ones([3,1,144,144]) or torch.ones([3,3,144,144]) depending on 2D/3D expectation), and (2) verify DiceMetric(per_component=True) raises ValueError when y_pred and y have mismatched channel counts (call DiceMetric(per_component=True)(y_pred, y) where y_pred has 2 channels and y has 1 or vice versa). Reference the DiceMetric class and the existing test_input_dimensions to add these assertions so coverage includes invalid y shapes and channel mismatches.

monai/metrics/meandice.py (1)
429-437: Skip per-channel Dice work when per_component=True.

Lines 431-434 compute channel Dice, then line 436 overwrites c_list. This is unnecessary work on every batch item.

Proposed refactor

```diff
         data = []
         for b in range(y_pred.shape[0]):
+            if self.per_component:
+                data.append(self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0)).unsqueeze(0))
+                continue
             c_list = []
             for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
                 x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
                 x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
                 c_list.append(self.compute_channel(x_pred, x))
-            if self.per_component:
-                c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
             data.append(torch.stack(c_list))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/metrics/meandice.py` around lines 429 - 437, The loop currently always computes per-channel Dice via compute_channel for each class then overwrites c_list when self.per_component is True, causing unnecessary work; update the logic in the batch loop (the block using variables b, c_list, first_ch, n_pred_ch and calling compute_channel and compute_cc_dice) to short-circuit when self.per_component is True—i.e., if self.per_component is True, skip the inner per-channel loop entirely and directly set c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]; otherwise run the existing per-channel computation using compute_channel. Ensure you preserve behavior for n_pred_ch == 1 and that data.append(torch.stack(c_list)) still executes.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@monai/metrics/meandice.py`:
- Around line 421-425: The per_component branch only validates y_pred; add the
same validation for y so incorrect shapes or channel counts on the ground truth
produce an immediate error. In the function/method that contains the existing
check (the block referencing self.per_component and y_pred), validate that y
also has 5 dimensions and y.shape[1] == 2 (or that y.shape matches y_pred) and
raise a ValueError with a message parallel to the existing one (e.g.,
"per_component requires 5D binary segmentation with 2 channels... Got shape
{y.shape}"). Ensure you reference the same symbol names (self.per_component,
y_pred, y) so the check runs before any computation that assumes the 5D
two-channel layout.
In `@tests/metrics/test_compute_meandice.py`:
- Around line 275-276: Remove the class-level `@unittest.skipUnless`(has_ndimage,
...) on TestComputeMeanDice and instead apply that skip only to the tests that
exercise the per_component code-path; identify and decorate the specific methods
(e.g., any test methods named test_*per_component* or those that call
compute_mean_dice(..., per_component=True) such as test_mean_dice_per_component)
with `@unittest.skipUnless`(has_ndimage, "Requires scipy.ndimage."); keep other
tests in TestComputeMeanDice unskipped so non-ndimage paths still run.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 2469c914-d7e5-4549-930b-3212056a1266
📒 Files selected for processing (2)
monai/metrics/meandice.py
tests/metrics/test_compute_meandice.py
…eck bug Signed-off-by: Vijay Vignesh Prasad Rao <vijayvigneshp02@gmail.com>
Actionable comments posted: 1
🧹 Nitpick comments (3)
tests/metrics/test_compute_meandice.py (1)
256-272: Test data note: batches 1-4 have all-zero tensors.

Valid for testing empty GT handling (ignore_empty=False returns 1.0), but technically not proper one-hot encoding. Consider adding a comment explaining intent.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/metrics/test_compute_meandice.py` around lines 256 - 272, TEST_CASE_16 uses y and y_hat where batches 1-4 are all-zero (not proper one-hot), so clarify intent: update the test comment near TEST_CASE_16 to state that y and y_hat intentionally include all-zero batches to validate per_component DiceMetric behavior with ignore_empty=False (expecting 1.0), reference the variables y, y_hat and the test case name TEST_CASE_16; do not change data values, only add a concise comment explaining that these batches are intentionally empty and used to test empty-GT handling.

monai/metrics/meandice.py (2)
430-438: Wasteful computation when per_component=True.

Lines 432-435 compute per-channel Dice, but when per_component=True, line 437 replaces c_list entirely, discarding that work.

♻️ Suggested optimization

```diff
         for b in range(y_pred.shape[0]):
             c_list = []
-            for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
-                x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
-                x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
-                c_list.append(self.compute_channel(x_pred, x))
             if self.per_component:
                 c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
+            else:
+                for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
+                    x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
+                    x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
+                    c_list.append(self.compute_channel(x_pred, x))
             data.append(torch.stack(c_list))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/metrics/meandice.py` around lines 430 - 438, The loop currently computes per-channel Dice via compute_channel for every class into c_list and then, if self.per_component is True, throws that work away by replacing c_list with compute_cc_dice; to fix, short-circuit the per-component path: inside the outer loop over b, check self.per_component first and only call self.compute_cc_dice for that batch (y_pred[b].unsqueeze(0), y[b].unsqueeze(0)) to create c_list, otherwise run the existing per-channel computation using compute_channel; this avoids wasted compute and ensures c_list is only populated by the needed branch.
322-328: Minor: dtype inconsistency and allocation inefficiency.

Line 323 creates an intermediate tensor unnecessarily. Also, the return dtype depends on platform (sn_label may return int32 or int64), but the docstring promises int32.

♻️ Suggested fix

```diff
         if num == 0:
-            return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32)
+            return torch.zeros(x.shape, dtype=torch.int32)
         edt_input = np.ones(cc.shape, dtype=np.uint8)
         edt_input[cc > 0] = 0
         indices = distance_transform_edt(edt_input, sampling=sampling, return_distances=False, return_indices=True)
         voronoi = cc[tuple(indices)]
-        return torch.from_numpy(voronoi)
+        return torch.from_numpy(voronoi.astype(np.int32))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/metrics/meandice.py` around lines 322 - 328, The early-return creates an unnecessary tensor from x and the final return dtype can vary; change the num==0 branch to directly return a torch tensor of zeros with the same shape as cc and dtype=torch.int32 (avoid torch.from_numpy(x)). After computing voronoi = cc[tuple(indices)], ensure voronoi is cast to a stable 32-bit integer numpy type (e.g., voronoi = voronoi.astype(np.int32)) before converting with torch.from_numpy so the returned tensor is always int32; update the code around variables num, cc, indices, edt_input and voronoi in meandice.py accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@monai/metrics/meandice.py`:
- Line 428: The code currently forces first_ch = 1 when per_component is True,
silently ignoring include_background; update the logic in the MeanDice class
(where first_ch is computed) to detect the conflicting flags
(include_background=True and per_component=True) and emit a clear warning (using
warnings.warn or the module logger) stating that include_background will be
ignored in per_component mode and that first_ch is set to 1; keep the existing
behavior unless you intend to change semantics, but ensure the warning is raised
at construction or first use (e.g., in __init__ or the method computing
first_ch) so users are informed.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 623f6db3-2ab1-4bdf-8c74-90860dc9678d
📒 Files selected for processing (2)
monai/metrics/meandice.py
tests/metrics/test_compute_meandice.py
Fixes #8733
Description
This PR adds support for connected component-based Dice metric calculation to the existing DiceMetric and DiceHelper classes.
Changes
Reference
Types of changes
- `./runtests.sh -f -u --net --coverage`.
- `./runtests.sh --quick --unittests --disttests`.
- `make html` command in the `docs/` folder.