-
Notifications
You must be signed in to change notification settings - Fork 655
Enable arithmetic operations between device tensors/batches and scalars #6143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Rostan Tabet <[email protected]>
Signed-off-by: Rostan Tabet <[email protected]>
Signed-off-by: Rostan Tabet <[email protected]>
Greptile SummaryThis PR enables arithmetic operations between GPU tensors/batches and scalars by consolidating the
Confidence Score: 5/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User
participant Tensor/Batch
participant _arithm_op
participant as_tensor
participant _arithmetic_generic_op
User->>Tensor/Batch: gpu_tensor + scalar
Tensor/Batch->>_arithm_op: __add__(gpu_tensor, scalar)
Note over _arithm_op: Check all args for GPU tensors
_arithm_op->>_arithm_op: gpu = any(arg.device == "gpu"<br/>for Tensor/Batch args)
alt Scalar argument found
_arithm_op->>_arithm_op: Check if implicitly convertible
alt GPU mode & not convertible
_arithm_op-->>User: ValueError: not implicitly copyable
else Convertible
_arithm_op->>as_tensor: as_tensor(scalar, device="gpu")
as_tensor-->>_arithm_op: GPU tensor
end
end
Note over _arithm_op: Verify no CPU/GPU mixing
_arithm_op->>_arithm_op: Check all new_args devices match
alt Device mismatch
_arithm_op-->>User: ValueError: Cannot mix GPU and CPU
else All same device
_arithm_op->>_arithmetic_generic_op: Execute operation
_arithmetic_generic_op-->>User: Result tensor
end
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (3)
-
dali/python/nvidia/dali/experimental/dynamic/_arithmetic.py, line 36 (link)syntax: Typo: "implictly" should be "implicitly"
-
dali/python/nvidia/dali/experimental/dynamic/_arithmetic.py, line 30-46 (link)logic: Logic bug: when a scalar appears before a GPU tensor (e.g.,
3 + gpu_tensor), the scalar gets converted to CPU before detecting the GPU tensor, causing the final check to fail with "Cannot mix GPU and CPU inputs."The algorithm needs two passes:
- First pass: scan all args to detect if any GPU tensor exists
- Second pass: convert scalars to appropriate device
Example failure case:
gpu_tensor = ndd.tensor([1, 2, 3], device="gpu") result = 5 + gpu_tensor # Will raise ValueError
-
dali/test/python/experimental_mode/test_arithm_ops.py, line 128 (link)style: Test only covers
tensor + scalarbut notscalar + tensor. Add reverse operation tests:# Also test scalar + tensor result_reversed = ndd.as_tensor(apply_bin_op(op, scalar, x)) ref_reversed = apply_bin_op(op, scalar, tensor) if not np.allclose(result_reversed.cpu(), ref_reversed): msg = f"{scalar} {op} {tensor} = \n{result_reversed}\n!=\n{ref_reversed}" raise AssertionError(msg)
5 files reviewed, 3 comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (1)
-
dali/python/nvidia/dali/experimental/dynamic/_arithmetic.py, line 20 (link)syntax: using
|operator inisinstance()requires Python 3.10+, butpyproject.tomltargets Python 3.8+. Use tuple syntax instead for compatibility.
5 files reviewed, 1 comment
c0cc95a to
2265d73
Compare
Signed-off-by: Rostan Tabet <[email protected]>
2265d73 to
8d39fde
Compare
Category: Bug fix (non-breaking change which fixes an issue)
Description:
This PR enables arithmetic operations between tensors/batches and scalars. Previously,
x + nworked ifxwas a CPU tensor andna scalar, but not ifxwas a GPU tensor.Arithmetic operations between a GPU tensor/batch and a Python list or tuple are also supported. Tensor types need to be explicitly copied.
Additional information:
Affected modules and functionalities:
Dynamic mode tensors and batches.
Key points relevant for the review:
Do arithmetic operations work the way we intend to?
Tests:
Checklist
Documentation
DALI team only
Requirements
REQ IDs: N/A
JIRA TASK: DALI-4545