Conversation
…to change it, and also using argmax instead of sum and inverting in the correct order, to avoid multiplying the effect of the transformation out of proportion
Pull request overview
This PR fixes issues with the in/out segmentation region selection logic and adds the mix_prob parameter to transform classes. The fix addresses the incorrect handling of multi-channel segmentation masks and reorders operations in transform application.
Changes:
- Fixed `_apply_region_mode` function to correctly handle multi-channel segmentation masks by using `argmax` before applying mode inversion
- Added `mix_prob` parameter to `RandomInverseGPU` and `RandomHistogramEqualizationGPU` classes
- Reordered operations in `RandomConvTransformGPU` to apply `mix_prob` mixing before region selection
- Updated configuration values for various transform parameters
- Code formatting improvements (trailing whitespace removal, function signature formatting)
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| auglab/transforms/gpu/contrast.py | Fixed segmentation mask handling logic, added mix_prob parameter support, reordered transform operations |
| auglab/transforms/gpu/transforms.py | Added mix_prob parameter to transform instantiations, formatting improvements |
| auglab/configs/transform_params_gpu.json | Updated transform configuration parameters including probabilities and in_seg/out_seg values |
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 16 comments.
```python
if not isinstance(num_transforms, int) or num_transforms < 0:
    raise ValueError(f"num_transforms must be a non-negative int. Got {num_transforms!r}.")
```
The validation raises a `ValueError` for negative `num_transforms` but doesn't handle the case where `num_transforms` exceeds the length of `transforms_list`. While the code handles this gracefully at runtime with `k = min(self.num_transforms, len(self.transforms_list))`, it might be better to warn users at initialization if they provide a `num_transforms` that exceeds the available transforms.
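A minimal sketch of such an init-time warning, reusing the names from the quoted validation (the surrounding constructor is assumed):

```python
import warnings

# Hypothetical init-time check; names mirror the quoted validation code.
if not isinstance(num_transforms, int) or num_transforms < 0:
    raise ValueError(f"num_transforms must be a non-negative int. Got {num_transforms!r}.")
if num_transforms > len(transforms_list):
    warnings.warn(
        f"num_transforms={num_transforms} exceeds the {len(transforms_list)} available "
        f"transforms; only {len(transforms_list)} will be applied.",
        UserWarning,
    )
```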
```python
for j in idx.tolist():
    t = self.transforms_list[j]
    if torch.rand(1, device=x.device, dtype=x.dtype) > t.p:
```
The check `torch.rand(1, device=x.device, dtype=x.dtype) > t.p` uses the image dtype for the random draw, which is problematic if `x.dtype` is an integer type: `torch.rand` only supports floating-point dtypes and will raise an error. This should use `torch.float32` or a similar floating-point dtype to ensure a proper probability comparison.
```diff
-if torch.rand(1, device=x.device, dtype=x.dtype) > t.p:
+if torch.rand(1, device=x.device, dtype=torch.float32) > t.p:
```
```python
m = torch.argmax(m, axis=1) > 0
m = m.to(transformed.dtype)
```
The logic change from `torch.sum(m, axis=1)` to `torch.argmax(m, axis=1) > 0` fundamentally alters the behavior. Previously, the sum allowed multiple overlapping segmentation channels to contribute. Now, `argmax` selects only the channel with the maximum value, and `> 0` converts it to a binary mask. This change is significant and could affect how multiple overlapping segmentation regions are handled. Ensure this is the intended behavior and that it has been tested with multi-channel segmentation masks.
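A standalone toy comparison of the two reductions on an overlapping multi-channel mask (illustration only, not the PR's code):

```python
import torch

# Batch of 1, 3 channels (channel 0 = background), 2x2 spatial grid.
m = torch.zeros(1, 3, 2, 2)
m[0, 1, 0, 0] = 1.0  # only channel 1 active at (0, 0)
m[0, 1, 0, 1] = 1.0  # channels 1 and 2 overlap at (0, 1)
m[0, 2, 0, 1] = 1.0

summed = torch.sum(m, axis=1)           # overlap accumulates: values can exceed 1
argmaxed = torch.argmax(m, axis=1) > 0  # binary mask of the dominant non-background channel

print(summed)    # tensor([[[1., 2.], [0., 0.]]])
print(argmaxed)  # tensor([[[True, True], [False, False]]])
```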
Technically we should not have overlapping segmentations, so both should work, I think.
```diff
 o = torch.randint(0, 2, (seg_mask.shape[0],), device=seg_mask.device, dtype=seg_mask.dtype)
 m = m * o.view(-1, 1, 1, 1)  # Broadcasting o to match the dimensions of m
-m = torch.sum(m, axis=0)
+m = torch.argmax(m, axis=0) > 0
```
The same logic change from `torch.sum(m, axis=0)` to `torch.argmax(m, axis=0) > 0` appears here for 3D images. This fundamentally changes how overlapping segmentation channels are handled, converting from a sum of all channels to selecting only the dominant channel. Ensure this behavior change is intentional and properly tested.
```diff
-m = torch.argmax(m, axis=0) > 0
+m = torch.sum(m, axis=0) > 0
```
| """Randomly choose X transforms to apply from a given list of ImageOnlyTransform transforms (GPU version). | ||
|
|
||
| Args: | ||
| transforms_list: List of initialized ImageOnlyTransform to choose from. | ||
| num_transforms: Number of transforms to randomly select and apply. | ||
| same_on_batch: apply the same transformation across the batch. | ||
| p: probability for applying the X transforms to a batch. This param controls the augmentation | ||
| probabilities batch-wise. | ||
| keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch | ||
| form ``False``. |
The new class `RandomChooseXTransformsGPU` is missing comprehensive documentation. The docstring should explain the behavior when `same_on_batch=False` (applies different random selections per batch item), how the transforms are selected (without replacement using `randperm`), and what happens when a selected transform's probability check fails (it is skipped). Additionally, document the `kwargs` parameter or remove it if unused.
| """Randomly choose X transforms to apply from a given list of ImageOnlyTransform transforms (GPU version). | |
| Args: | |
| transforms_list: List of initialized ImageOnlyTransform to choose from. | |
| num_transforms: Number of transforms to randomly select and apply. | |
| same_on_batch: apply the same transformation across the batch. | |
| p: probability for applying the X transforms to a batch. This param controls the augmentation | |
| probabilities batch-wise. | |
| keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch | |
| form ``False``. | |
| """Randomly choose and apply a subset of transforms from a given list (GPU version). | |
| This transform samples up to ``num_transforms`` transforms from ``transforms_list`` **without | |
| replacement** for each application. The sampling is implemented by taking a random | |
| permutation of the available transforms via :func:`torch.randperm` and selecting the | |
| first ``num_transforms`` indices. | |
| For each selected transform ``t``, a separate probability check is performed against | |
| ``t.p``. If the random draw for ``t`` is greater than ``t.p``, that transform is | |
| **skipped** and not applied, and no additional transform is selected to replace it. | |
| This means that in practice fewer than ``num_transforms`` transforms may be applied | |
| to a given sample. | |
| When ``same_on_batch=True``, the same subset of transforms (with the same random | |
| outcomes of their internal sampling logic) is used for the entire batch: one | |
| selection is made and applied to all batch elements. | |
| When ``same_on_batch=False``, each element in the batch is processed independently: | |
| for every item, a new random subset of transforms is sampled (still without | |
| replacement), and each selected transform performs its own probability check and | |
| random parameter sampling. | |
| Args: | |
| transforms_list: List of initialized :class:`ImageOnlyTransform` instances to | |
| choose from. | |
| num_transforms: Maximum number of transforms to randomly select (without | |
| replacement) and attempt to apply to each sample. | |
| same_on_batch: If ``True``, apply the same randomly selected subset of transforms | |
| to every element in the batch. If ``False``, sample a separate subset for | |
| each batch element. | |
| p: Probability of applying this *composite* transform to a batch. This controls | |
| whether the selection-and-application process runs at all for a given call, | |
| independently from the per-transform probabilities ``t.p``. | |
| keepdim: Whether to keep the output shape the same as the input (``True``) or | |
| broadcast it to the batch form (``False``). | |
| **kwargs: Additional keyword arguments accepted for API compatibility. They are | |
| currently ignored and do not affect the behavior of this transform. |
```python
if seg is not None and isinstance(seg, torch.Tensor) and seg.shape[0] == batch_size:
    seg_i = seg[i : i + 1]
else:
    seg_i = seg
```
In the per-batch iteration (lines 351-359), when `seg` is not `None` and its batch size matches `input.shape[0]`, individual slices are extracted correctly. However, when `seg` exists but has a different batch size (line 357), the code falls back to using the entire `seg` for each batch item. This could lead to shape mismatches or incorrect segmentation application. Consider adding validation to ensure `seg` has the correct shape, or raising an error if the shapes don't match.
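A hedged sketch of the suggested validation, reusing the `seg`/`batch_size` names from the quoted hunk:

```python
# Hypothetical shape check; names are assumed from the review context.
if seg is not None and isinstance(seg, torch.Tensor) and seg.shape[0] != batch_size:
    raise ValueError(
        f"seg batch size {seg.shape[0]} does not match input batch size {batch_size}."
    )
```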
```python
if mix_in_out:
    for i in range(seg_mask.shape[0]):
        # Create a tensor with random one and zero
        o = torch.randint(0, 2, (seg_mask.shape[1],), device=seg_mask.device, dtype=seg_mask.dtype)
        m[i] = m[i] * o.view(-1, 1, 1, 1)  # Broadcasting o to match the dimensions of m
```
When `mix_in_out=True`, the code randomly zeros out some segmentation channels before applying `argmax`. However, if all channels get zeroed out for a sample (which happens with probability `2^(-num_channels)`), then `argmax(m, axis=1) > 0` will be `False` everywhere, effectively making the entire mask zero. This edge case should be handled, or the probability should be adjusted to ensure at least one channel remains active.
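One possible guard against the all-zero draw, as a sketch (the helper name and the choice to force one random channel back on are assumptions, not the PR's code):

```python
import torch

def sample_channel_keep_mask(num_channels, device, dtype):
    # Random 0/1 keep-flags per channel, as in the quoted mix_in_out loop.
    o = torch.randint(0, 2, (num_channels,), device=device, dtype=dtype)
    if o.sum() == 0:
        # All channels were dropped (probability 2^-num_channels);
        # re-enable one channel chosen uniformly at random.
        o[torch.randint(0, num_channels, (1,), device=device)] = 1
    return o
```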