diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 3bca0c9c2..52530c93b 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -386,6 +386,29 @@ def _translate(self, ptr, from_rank, to_rank): return translated_ptr + @gluon.jit + def as_remote(self, ptr, rank): + """ + Translate a local pointer to point at the target rank's copy. + + Convenience wrapper over ``_translate`` that fills in ``cur_rank`` + automatically. Returns a pointer usable directly with ``gl.load`` + / ``gl.store``. + + Args: + ptr: Pointer in the current rank's address space + rank: Target rank ID + + Returns: + Translated pointer in the target rank's address space + + Example:: + + remote_ptr = ctx.as_remote(buf + offsets, target_rank) + data = gl.load(remote_ptr, mask=mask) + """ + return self._translate(ptr, self.cur_rank, rank) + @gluon.jit def load(self, pointer, from_rank, mask=None, other=None, cache_modifier=None, volatile=False): """ @@ -1358,6 +1381,54 @@ def is_symmetric(self, tensor: torch.Tensor) -> bool: """ return self.heap.is_symmetric(tensor) + def as_remote(self, tensor: torch.Tensor, rank: int) -> torch.Tensor: + """ + Return a zero-copy view of a symmetric tensor pointing to the target rank's copy. + + Takes a tensor allocated on the symmetric heap and returns a new tensor with + the same shape, dtype, and strides, but whose ``data_ptr()`` points to the + corresponding location in the target rank's heap. This is useful for hoisting + pointer translation out of loops or passing pre-translated pointers to kernels. + + Args: + tensor (torch.Tensor): A tensor on the symmetric heap + rank (int): Target rank whose copy to point at + + Returns: + torch.Tensor: A view pointing to the target rank's symmetric heap + + Raises: + ValueError: If tensor is not on the symmetric heap or rank is out of range + + Example: + >>> import iris.experimental.iris_gluon as iris_gl + >>> ctx = iris_gl.iris(heap_size=2**30) + >>> buf = ctx.zeros(1024, dtype=torch.float32) + >>> remote_buf = ctx.as_remote(buf, target_rank) + >>> # remote_buf.data_ptr() now points to target_rank's copy + """ + if not self.is_symmetric(tensor): + raise ValueError("as_remote requires a tensor on the symmetric heap") + if rank < 0 or rank >= self.num_ranks: + raise ValueError(f"rank {rank} out of range [0, {self.num_ranks})") + + local_base = int(self.heap.heap_bases[self.cur_rank].item()) + remote_base = int(self.heap.heap_bases[rank].item()) + offset = tensor.data_ptr() - local_base + remote_ptr = remote_base + offset + + elem_size = tensor.element_size() + if tensor.numel() == 0: + storage_bytes = 0 + else: + max_offset = sum((s - 1) * st for s, st in zip(tensor.shape, tensor.stride())) + storage_bytes = (max_offset + 1) * elem_size + + from iris.tensor_utils import tensor_from_ptr + + flat = tensor_from_ptr(remote_ptr, storage_bytes, dtype=tensor.dtype, device=str(tensor.device)) + return torch.as_strided(flat, tensor.shape, tensor.stride()) + def iris(heap_size=1 << 30): """ diff --git a/iris/iris.py b/iris/iris.py index 8c750ba67..f2359cf6b 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -629,6 +629,53 @@ def is_symmetric(self, tensor: torch.Tensor) -> bool: """ return self.heap.is_symmetric(tensor) + def as_remote(self, tensor: torch.Tensor, rank: int) -> torch.Tensor: + """ + Return a zero-copy view of a symmetric tensor pointing to the target rank's copy. 
+ + Takes a tensor allocated on the symmetric heap and returns a new tensor with + the same shape, dtype, and strides, but whose ``data_ptr()`` points to the + corresponding location in the target rank's heap. This is useful for hoisting + pointer translation out of loops or passing pre-translated pointers to kernels. + + Args: + tensor (torch.Tensor): A tensor on the symmetric heap + rank (int): Target rank whose copy to point at + + Returns: + torch.Tensor: A view pointing to the target rank's symmetric heap + + Raises: + ValueError: If tensor is not on the symmetric heap or rank is out of range + + Example: + >>> ctx = iris.iris(heap_size=2**30) + >>> buf = ctx.zeros(1024, dtype=torch.float32) + >>> remote_buf = ctx.as_remote(buf, target_rank) + >>> # remote_buf.data_ptr() now points to target_rank's copy + """ + if not self.is_symmetric(tensor): + raise ValueError("as_remote requires a tensor on the symmetric heap") + if rank < 0 or rank >= self.num_ranks: + raise ValueError(f"rank {rank} out of range [0, {self.num_ranks})") + + local_base = int(self.heap.heap_bases[self.cur_rank].item()) + remote_base = int(self.heap.heap_bases[rank].item()) + offset = tensor.data_ptr() - local_base + remote_ptr = remote_base + offset + + elem_size = tensor.element_size() + if tensor.numel() == 0: + storage_bytes = 0 + else: + max_offset = sum((s - 1) * st for s, st in zip(tensor.shape, tensor.stride())) + storage_bytes = (max_offset + 1) * elem_size + + from iris.tensor_utils import tensor_from_ptr + + flat = tensor_from_ptr(remote_ptr, storage_bytes, dtype=tensor.dtype, device=str(tensor.device)) + return torch.as_strided(flat, tensor.shape, tensor.stride()) + def full(self, size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): """ Creates a tensor of size size filled with fill_value. The tensor's dtype is inferred from fill_value. diff --git a/tests/unittests/test_as_remote.py b/tests/unittests/test_as_remote.py new file mode 100644 index 000000000..962c289ac --- /dev/null +++ b/tests/unittests/test_as_remote.py @@ -0,0 +1,371 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Tests for as_remote() API — both host-side (Iris, IrisGluon) and device-side (IrisDeviceCtx). 
+""" + +import gc + +import torch +import pytest +import iris +import iris.experimental.iris_gluon as iris_gl +from triton.experimental import gluon +from triton.experimental.gluon import language as gl + + +# --------------------------------------------------------------------------- +# Host-side tests (single-process, no torchrun required) +# --------------------------------------------------------------------------- + + +class TestAsRemoteHostIris: + """Host-side as_remote() tests using the Iris (Triton) backend.""" + + def test_host_basic(self): + """as_remote returns a tensor with matching shape/dtype/strides but different data_ptr.""" + ctx = iris.iris(1 << 20, allocator_type="torch") + num_ranks = ctx.num_ranks + if num_ranks < 2: + pytest.skip("Need >= 2 ranks") + + buf = ctx.zeros(128, dtype=torch.float32) + target = (ctx.cur_rank + 1) % num_ranks + remote = ctx.as_remote(buf, target) + + assert remote.shape == buf.shape + assert remote.dtype == buf.dtype + assert remote.stride() == buf.stride() + assert remote.data_ptr() != buf.data_ptr() + + def test_host_pointer_math(self): + """Offset from respective heap base must be identical.""" + ctx = iris.iris(1 << 20, allocator_type="torch") + num_ranks = ctx.num_ranks + if num_ranks < 2: + pytest.skip("Need >= 2 ranks") + + buf = ctx.zeros(64, dtype=torch.float32) + target = (ctx.cur_rank + 1) % num_ranks + + local_base = int(ctx.heap.heap_bases[ctx.cur_rank].item()) + remote_base = int(ctx.heap.heap_bases[target].item()) + + remote = ctx.as_remote(buf, target) + assert remote.data_ptr() - remote_base == buf.data_ptr() - local_base + + def test_host_self_rank(self): + """as_remote(tensor, cur_rank) returns a tensor with the same data_ptr.""" + ctx = iris.iris(1 << 20, allocator_type="torch") + buf = ctx.zeros(64, dtype=torch.float32) + remote = ctx.as_remote(buf, ctx.cur_rank) + assert remote.data_ptr() == buf.data_ptr() + assert remote.shape == buf.shape + + def test_host_non_symmetric_raises(self): + """as_remote on a non-symmetric tensor raises ValueError.""" + ctx = iris.iris(1 << 20, allocator_type="torch") + external = torch.zeros(64, dtype=torch.float32, device="cuda") + with pytest.raises(ValueError, match="symmetric heap"): + ctx.as_remote(external, 0) + + def test_host_rank_out_of_range(self): + """as_remote with invalid rank raises ValueError.""" + ctx = iris.iris(1 << 20, allocator_type="torch") + buf = ctx.zeros(64, dtype=torch.float32) + with pytest.raises(ValueError, match="out of range"): + ctx.as_remote(buf, ctx.num_ranks) + with pytest.raises(ValueError, match="out of range"): + ctx.as_remote(buf, -1) + + def test_host_non_contiguous(self): + """as_remote preserves strides of a non-contiguous (sliced) tensor.""" + ctx = iris.iris(1 << 20, allocator_type="torch") + buf_2d = ctx.zeros(16, 16, dtype=torch.float32) + sliced = buf_2d[::2, ::2] # non-contiguous view + assert not sliced.is_contiguous() + + remote = ctx.as_remote(sliced, ctx.cur_rank) + assert remote.shape == sliced.shape + assert remote.stride() == sliced.stride() + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) + def test_host_multi_dtype(self, dtype): + """as_remote works across multiple dtypes.""" + ctx = iris.iris(1 << 20, allocator_type="torch") + buf = ctx.zeros(64, dtype=dtype) + remote = ctx.as_remote(buf, ctx.cur_rank) + assert remote.dtype == dtype + assert remote.shape == buf.shape + + +# --------------------------------------------------------------------------- +# Host-side tests for IrisGluon backend 
(multi-GPU, needs torchrun) +# --------------------------------------------------------------------------- + + +class TestAsRemoteHostGluon: + """Host-side as_remote() tests using the IrisGluon (Gluon) backend.""" + + def test_host_basic(self): + """as_remote returns a tensor with matching shape/dtype/strides but different data_ptr.""" + ctx = iris_gl.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() + if num_ranks < 2: + ctx.barrier() + del ctx + gc.collect() + pytest.skip("Need >= 2 ranks") + + buf = ctx.zeros(128, dtype=torch.float32) + target = (cur_rank + 1) % num_ranks + remote = ctx.as_remote(buf, target) + + assert remote.shape == buf.shape + assert remote.dtype == buf.dtype + assert remote.stride() == buf.stride() + assert remote.data_ptr() != buf.data_ptr() + + ctx.barrier() + del ctx + gc.collect() + + def test_host_pointer_math(self): + """Offset from respective heap base must be identical.""" + ctx = iris_gl.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() + if num_ranks < 2: + ctx.barrier() + del ctx + gc.collect() + pytest.skip("Need >= 2 ranks") + + buf = ctx.zeros(64, dtype=torch.float32) + target = (cur_rank + 1) % num_ranks + + local_base = int(ctx.heap.heap_bases[cur_rank].item()) + remote_base = int(ctx.heap.heap_bases[target].item()) + + remote = ctx.as_remote(buf, target) + assert remote.data_ptr() - remote_base == buf.data_ptr() - local_base + + ctx.barrier() + del ctx + gc.collect() + + def test_host_self_rank(self): + """as_remote(tensor, cur_rank) returns a tensor with the same data_ptr.""" + ctx = iris_gl.iris(1 << 20) + cur_rank = ctx.get_rank() + + buf = ctx.zeros(64, dtype=torch.float32) + remote = ctx.as_remote(buf, cur_rank) + assert remote.data_ptr() == buf.data_ptr() + assert remote.shape == buf.shape + + ctx.barrier() + del ctx + gc.collect() + + def test_host_non_symmetric_raises(self): + """as_remote on a non-symmetric tensor raises ValueError.""" + ctx = iris_gl.iris(1 << 20) + external = torch.zeros(64, dtype=torch.float32, device="cuda") + with pytest.raises(ValueError, match="symmetric heap"): + ctx.as_remote(external, 0) + + ctx.barrier() + del ctx + gc.collect() + + def test_host_rank_out_of_range(self): + """as_remote with invalid rank raises ValueError.""" + ctx = iris_gl.iris(1 << 20) + buf = ctx.zeros(64, dtype=torch.float32) + with pytest.raises(ValueError, match="out of range"): + ctx.as_remote(buf, ctx.get_num_ranks()) + with pytest.raises(ValueError, match="out of range"): + ctx.as_remote(buf, -1) + + ctx.barrier() + del ctx + gc.collect() + + def test_host_non_contiguous(self): + """as_remote preserves strides of a non-contiguous (sliced) tensor.""" + ctx = iris_gl.iris(1 << 20) + buf_2d = ctx.zeros(16, 16, dtype=torch.float32) + sliced = buf_2d[::2, ::2] + assert not sliced.is_contiguous() + + remote = ctx.as_remote(sliced, ctx.get_rank()) + assert remote.shape == sliced.shape + assert remote.stride() == sliced.stride() + + ctx.barrier() + del ctx + gc.collect() + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) + def test_host_multi_dtype(self, dtype): + """as_remote works across multiple dtypes.""" + ctx = iris_gl.iris(1 << 20) + buf = ctx.zeros(64, dtype=dtype) + remote = ctx.as_remote(buf, ctx.get_rank()) + assert remote.dtype == dtype + assert remote.shape == buf.shape + + ctx.barrier() + del ctx + gc.collect() + + +# --------------------------------------------------------------------------- +# Device-side tests (multi-GPU, needs torchrun) +# 
--------------------------------------------------------------------------- + + +@gluon.jit +def as_remote_read_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + data, + results, + source_rank: gl.constexpr, + num_ranks: gl.constexpr, + BLOCK_SIZE: gl.constexpr, +): + """Read from a remote rank using ctx.as_remote + gl.load.""" + ctx = IrisDeviceCtx.initialize(context_tensor) + pid = gl.program_id(0) + + partner = int((source_rank + num_ranks // 2) % num_ranks) + + block_start = pid * BLOCK_SIZE + layout: gl.constexpr = gl.BlockedLayout([1], [64], [1], [0]) + offsets = block_start + gl.arange(0, BLOCK_SIZE, layout=layout) + mask = offsets < BLOCK_SIZE + + # Translate pointer then load directly (instead of ctx.load) + remote_ptr = ctx.as_remote(data + offsets, partner) + result = gl.load(remote_ptr, mask=mask) + gl.store(results + offsets, result, mask=mask) + + +@gluon.jit +def as_remote_write_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + data, + results, + destination_rank: gl.constexpr, + num_ranks: gl.constexpr, + BLOCK_SIZE: gl.constexpr, +): + """Write to a remote rank using ctx.as_remote + gl.store.""" + ctx = IrisDeviceCtx.initialize(context_tensor) + pid = gl.program_id(0) + + block_start = pid * BLOCK_SIZE + layout: gl.constexpr = gl.BlockedLayout([1], [64], [1], [0]) + offsets = block_start + gl.arange(0, BLOCK_SIZE, layout=layout) + mask = offsets < BLOCK_SIZE + + value = gl.load(data + offsets, mask=mask) + + # Translate pointer then store directly (instead of ctx.store) + for dst_rank in range(num_ranks): + remote_ptr = ctx.as_remote(results + offsets, dst_rank) + gl.store(remote_ptr, value, mask=mask) + + +@pytest.mark.parametrize( + "dtype", + [torch.float16, torch.bfloat16, torch.float32], +) +@pytest.mark.parametrize("BLOCK_SIZE", [16, 32]) +def test_device_as_remote_read(dtype, BLOCK_SIZE): + """Rank reads from its partner using ctx.as_remote + gl.load.""" + ctx = iris_gl.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + context_tensor = ctx.get_device_context() + source_rank = ctx.get_rank() + partner = int((source_rank + num_ranks // 2) % num_ranks) + + data = ctx.full((BLOCK_SIZE,), source_rank, dtype=dtype) + results = ctx.zeros_like(data) + + ctx.barrier() + + as_remote_read_kernel[(1,)]( + iris_gl.IrisDeviceCtx, + context_tensor, + data, + results, + source_rank, + num_ranks, + BLOCK_SIZE, + num_warps=1, + ) + ctx.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=dtype, device="cuda") * partner + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + finally: + ctx.barrier() + del ctx + gc.collect() + + +@pytest.mark.parametrize( + "dtype", + [torch.float16, torch.bfloat16, torch.float32], +) +@pytest.mark.parametrize("BLOCK_SIZE", [16, 32]) +def test_device_as_remote_write(dtype, BLOCK_SIZE): + """Each rank writes 1s to all ranks using ctx.as_remote + gl.store.""" + ctx = iris_gl.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + context_tensor = ctx.get_device_context() + destination_rank = ctx.get_rank() + + src = ctx.ones(BLOCK_SIZE, dtype=dtype) + results = ctx.zeros_like(src) + + ctx.barrier() + + as_remote_write_kernel[(1,)]( + iris_gl.IrisDeviceCtx, + context_tensor, + src, + results, + destination_rank, + num_ranks, + BLOCK_SIZE, + num_warps=1, + ) + ctx.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=dtype, device="cuda") + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except 
AssertionError as e: + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + finally: + ctx.barrier() + del ctx + gc.collect()
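
A minimal host-side sketch of the "hoisting pointer translation out of loops" use case named in the ``as_remote`` docstring. The per-peer views are computed once and reused across iterations. Using ``Tensor.copy_`` as the transport is an assumption here: it relies on the peer's heap being P2P-mapped and writable from this rank, which is what the zero-copy view implies. ``cur_rank``/``num_ranks`` are used as in ``TestAsRemoteHostIris`` above, and ``ctx.barrier()`` is assumed to be available on the Iris context as it is on the Gluon one::

    import torch
    import iris

    ctx = iris.iris(heap_size=1 << 20)
    buf = ctx.zeros(4096, dtype=torch.float32)

    # Translate once, outside the loop: one zero-copy view per peer rank.
    peer_views = [
        ctx.as_remote(buf, r) for r in range(ctx.num_ranks) if r != ctx.cur_rank
    ]

    for step in range(10):
        payload = torch.full((4096,), float(step), device="cuda")
        for view in peer_views:
            # Plain tensor copy; view.data_ptr() already points at the
            # peer's copy of buf on the symmetric heap.
            view.copy_(payload)
        ctx.barrier()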
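A second sketch, for the "passing pre-translated pointers to kernels" use case from the same docstring. ``copy_kernel`` below is a hypothetical, rank-agnostic Triton kernel; the remote view produced by ``as_remote`` is handed to it like any other tensor. This assumes the translated pointer is directly dereferenceable from the launching rank, which is the same assumption the device-side tests make when they pair ``ctx.as_remote`` with ``gl.load``/``gl.store``; allocation and context helpers are used as in the tests above::

    import torch
    import triton
    import triton.language as tl
    import iris

    @triton.jit
    def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        # Rank-agnostic copy: load from src_ptr, store to dst_ptr.
        pid = tl.program_id(0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        val = tl.load(src_ptr + offsets, mask=mask)
        tl.store(dst_ptr + offsets, val, mask=mask)

    ctx = iris.iris(heap_size=1 << 20)
    n = 1024
    src = ctx.full((n,), 1.0, dtype=torch.float32)
    dst = ctx.zeros(n, dtype=torch.float32)
    peer = (ctx.cur_rank + 1) % ctx.num_ranks

    # Host-side translation: the kernel never sees a rank, only a pointer
    # that already targets the peer's copy of dst.
    remote_dst = ctx.as_remote(dst, peer)

    ctx.barrier()
    copy_kernel[(triton.cdiv(n, 256),)](src, remote_dst, n, BLOCK_SIZE=256)
    ctx.barrier()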