diff --git a/monai/apps/deepedit/transforms.py b/monai/apps/deepedit/transforms.py index 5af082e2b0..14c37be860 100644 --- a/monai/apps/deepedit/transforms.py +++ b/monai/apps/deepedit/transforms.py @@ -24,7 +24,7 @@ from monai.data import MetaTensor from monai.networks.layers import GaussianFilter from monai.transforms.transform import MapTransform, Randomizable, Transform -from monai.utils import min_version, optional_import +from monai.utils import deprecated, min_version, optional_import measure, _ = optional_import("skimage.measure", "0.14.2", min_version) @@ -84,18 +84,44 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.nda return d -class NormalizeLabelsInDatasetd(MapTransform): +class RemapLabelsToSequentiald(MapTransform): + """ + Remap label values from a dataset-specific schema to sequential indices (0, 1, 2, 3, ...). + + This transform takes labels with arbitrary values defined in a label dictionary and remaps them + to a sequential range starting from 1 (with background always set to 0). This is useful for + standardizing labels across different datasets or ensuring labels are in a contiguous range. + + The output label indices are assigned in alphabetical order by label name to ensure + deterministic behavior regardless of input dictionary ordering. + + Args: + keys: The ``keys`` parameter will be used to get and set the actual data item to transform + label_names: Dictionary mapping label names to their current values in the dataset. + For example: {"spleen": 1, "liver": 6, "background": 0} + Will be remapped to: {"background": 0, "liver": 1, "spleen": 2} + (alphabetically sorted, excluding background) + allow_missing_keys: If True, missing keys in the data dictionary will not raise an error + + Example: + >>> transform = RemapLabelsToSequentiald( + ... keys="label", + ... label_names={"liver": 6, "spleen": 1, "background": 0} + ... ) + >>> # Input label has values [0, 1, 6] + >>> # Output label will have values [0, 1, 2] (background=0, liver=1, spleen=2) + >>> # And updates d["label_names"] to {"background": 0, "liver": 1, "spleen": 2} + + Note: + - Background label (if present) is always mapped to 0 + - Non-background labels are mapped to sequential indices 1, 2, 3, ... in alphabetical order + - Undefined labels (not in label_names) will be set to 0 (background) + - The transform updates the data dictionary with a new "label_names" key containing the remapped values + """ def __init__( self, keys: KeysCollection, label_names: dict[str, int] | None = None, allow_missing_keys: bool = False ): - """ - Normalize label values according to label names dictionary - - Args: - keys: The ``keys`` parameter will be used to get and set the actual data item to transform - label_names: all label names - """ super().__init__(keys, allow_missing_keys) self.label_names = label_names or {} @@ -106,13 +132,18 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.nda # Dictionary containing new label numbers new_label_names = {} label = np.zeros(d[key].shape) - # Making sure the range values and number of labels are the same - for idx, (key_label, val_label) in enumerate(self.label_names.items(), start=1): - if key_label != "background": - new_label_names[key_label] = idx - label[d[key] == val_label] = idx - if key_label == "background": - new_label_names["background"] = 0 + + # Sort label names to ensure deterministic ordering (exclude background) + sorted_labels = sorted([(k, v) for k, v in self.label_names.items() if k != "background"]) + + # Always set background to 0 first + if "background" in self.label_names: + new_label_names["background"] = 0 + + # Assign sequential indices to sorted non-background labels + for idx, (key_label, val_label) in enumerate(sorted_labels, start=1): + new_label_names[key_label] = idx + label[d[key] == val_label] = idx d["label_names"] = new_label_names if isinstance(d[key], MetaTensor): @@ -122,6 +153,20 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.nda return d +@deprecated(since="1.6", removed="1.8", msg_suffix="Use `RemapLabelsToSequentiald` instead.") +class NormalizeLabelsInDatasetd(RemapLabelsToSequentiald): + """ + .. deprecated:: 1.6.0 + `NormalizeLabelsInDatasetd` is deprecated and will be removed in version 1.8.0. + Use :class:`RemapLabelsToSequentiald` instead. + + This class is maintained for backward compatibility. Please use RemapLabelsToSequentiald + which better describes the transform's functionality. + """ + + pass + + class SingleLabelSelectiond(MapTransform): def __init__( diff --git a/tests/apps/deepedit/test_deepedit_transforms.py b/tests/apps/deepedit/test_deepedit_transforms.py index 18d6567fd7..db4d872d56 100644 --- a/tests/apps/deepedit/test_deepedit_transforms.py +++ b/tests/apps/deepedit/test_deepedit_transforms.py @@ -25,6 +25,7 @@ FindAllValidSlicesMissingLabelsd, FindDiscrepancyRegionsDeepEditd, NormalizeLabelsInDatasetd, + RemapLabelsToSequentiald, ResizeGuidanceMultipleLabelDeepEditd, SingleLabelSelectiond, SplitPredsLabeld, @@ -282,6 +283,100 @@ def test_correct_results(self, arguments, input_data, expected_result): result = add_fn(input_data) self.assertEqual(len(np.unique(result["label"])), expected_result) + def test_ordering_determinism(self): + """Test that different input ordering produces the same output (alphabetical)""" + # Create a label array with different label values + label = np.array([[[0, 1, 6, 3]]]) # background=0, spleen=1, liver=6, kidney=3 + + # Test case 1: liver first, then kidney, then spleen + data1 = {"label": label.copy()} + transform1 = RemapLabelsToSequentiald( + keys="label", label_names={"liver": 6, "kidney": 3, "spleen": 1, "background": 0} + ) + result1 = transform1(data1) + + # Test case 2: spleen first, then kidney, then liver (different order) + data2 = {"label": label.copy()} + transform2 = RemapLabelsToSequentiald( + keys="label", label_names={"spleen": 1, "kidney": 3, "liver": 6, "background": 0} + ) + result2 = transform2(data2) + + # Both should produce the same output (alphabetically sorted) + # Expected mapping: background=0, kidney=1, liver=2, spleen=3 + np.testing.assert_array_equal(result1["label"], result2["label"]) + + # Verify the actual mapping is alphabetical + expected_output = np.array([[[0, 3, 2, 1]]]) # kidney=1, liver=2, spleen=3, background=0 + np.testing.assert_array_equal(result1["label"], expected_output) + + # Verify label_names is correct + self.assertEqual(result1["label_names"], {"background": 0, "kidney": 1, "liver": 2, "spleen": 3}) + self.assertEqual(result2["label_names"], {"background": 0, "kidney": 1, "liver": 2, "spleen": 3}) + + def test_multiple_labels(self): + """Test with multiple non-background labels""" + label = np.array([[[0, 1, 2, 5]]]) # background, spleen, kidney, liver + data = {"label": label.copy()} + transform = RemapLabelsToSequentiald( + keys="label", label_names={"spleen": 1, "kidney": 2, "liver": 5, "background": 0} + ) + result = transform(data) + + # Expected: background=0, kidney=1, liver=2, spleen=3 (alphabetical) + expected = np.array([[[0, 3, 1, 2]]]) + np.testing.assert_array_equal(result["label"], expected) + self.assertEqual(result["label_names"], {"background": 0, "kidney": 1, "liver": 2, "spleen": 3}) + + def test_deprecated_name_warning(self): + """Test that NormalizeLabelsInDatasetd is properly deprecated. + + The deprecation warning only triggers when MONAI version >= 1.6 (since="1.6"). + This test verifies: + 1. The actual NormalizeLabelsInDatasetd class is marked as deprecated in docstring + 2. The class is a subclass of RemapLabelsToSequentiald + 3. The deprecation mechanism works correctly (tested via version_val simulation) + 4. The actual class functions correctly + """ + import warnings + + from monai.utils import deprecated + + # Verify NormalizeLabelsInDatasetd docstring indicates deprecation + self.assertIn("deprecated", NormalizeLabelsInDatasetd.__doc__.lower()) + self.assertIn("RemapLabelsToSequentiald", NormalizeLabelsInDatasetd.__doc__) + + # Verify NormalizeLabelsInDatasetd is a subclass of RemapLabelsToSequentiald + self.assertTrue(issubclass(NormalizeLabelsInDatasetd, RemapLabelsToSequentiald)) + + # Test the deprecation mechanism using version_val to simulate version 1.6 + # This verifies the @deprecated decorator behavior that NormalizeLabelsInDatasetd uses + @deprecated( + since="1.6", + removed="1.8", + msg_suffix="Use `RemapLabelsToSequentiald` instead.", + version_val="1.6", # Simulate version 1.6 to trigger warning + ) + class DeprecatedNormalizeLabels(RemapLabelsToSequentiald): + pass + + data = {"label": np.array([[[0, 1]]])} + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + transform = DeprecatedNormalizeLabels(keys="label", label_names={"spleen": 1, "background": 0}) + _ = transform(data) + + # Check that a deprecation warning was raised + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, FutureWarning)) + self.assertIn("RemapLabelsToSequentiald", str(w[0].message)) + + # Verify the actual NormalizeLabelsInDatasetd class works correctly + transform_actual = NormalizeLabelsInDatasetd(keys="label", label_names={"spleen": 1, "background": 0}) + result = transform_actual({"label": np.array([[[0, 1]]])}) + self.assertIn("label", result) + class TestResizeGuidanceMultipleLabelCustomd(unittest.TestCase):