From f396003bc1572c3fe874a08dd236cf6a7d47bf96 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 22 Jan 2026 09:55:54 -0800 Subject: [PATCH 1/4] [slimtensor] Add CUDA Storage with DeviceTraits and memory allocation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/16769 This diff adds CUDA storage infrastructure to SlimTensor, enabling GPU memory allocation and management. **Key changes:** 1. **`cuda/Guard.h`** - CUDAGuard RAII class: - Saves current CUDA device on construction, restores on destruction - Exception-safe device context switching - Constructors accept device index or Device object 2. **`core/Storage.h`** - Extended for CUDA support: - Added `DeviceTraits` specialization with: - `allocate()` - Uses cudaMalloc with CUDAGuard for device selection - `free()` - Uses cudaFree with warning on error - `memcpy()` - Supports Host↔Device and Device↔Device copies - Added `DEFAULT_CUDA_DEVICE` constant - Updated `MaybeOwningStorage` constructor to handle CUDA devices - Stub implementation when `CUDA_AVAILABLE` is not defined (throws error) ghstack-source-id: 335102161 @exported-using-ghexport Differential Revision: [D91202899](https://our.internmc.facebook.com/intern/diff/D91202899/) --- backends/aoti/slim/c10/cuda/Exception.h | 40 +++ backends/aoti/slim/c10/cuda/TARGETS | 6 + backends/aoti/slim/c10/cuda/targets.bzl | 16 + backends/aoti/slim/core/Storage.h | 141 +++++++- backends/aoti/slim/core/targets.bzl | 4 +- backends/aoti/slim/core/test/targets.bzl | 37 +- backends/aoti/slim/core/test/test_storage.cpp | 332 ++++++++++++++---- backends/cuda/runtime/TARGETS | 22 ++ 8 files changed, 518 insertions(+), 80 deletions(-) create mode 100644 backends/aoti/slim/c10/cuda/Exception.h create mode 100644 backends/aoti/slim/c10/cuda/TARGETS create mode 100644 backends/aoti/slim/c10/cuda/targets.bzl diff --git a/backends/aoti/slim/c10/cuda/Exception.h b/backends/aoti/slim/c10/cuda/Exception.h new file mode 100644 index 00000000000..33d8414e661 --- /dev/null +++ b/backends/aoti/slim/c10/cuda/Exception.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#ifdef CUDA_AVAILABLE + +#include +#include + +#include +#include +#include + +/// Checks a CUDA expression and aborts on error. +/// @param EXPR The CUDA expression to check. +#define ET_CUDA_CHECK(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + ET_CHECK_MSG( \ + __err == cudaSuccess, "CUDA error: %s", cudaGetErrorString(__err)); \ + } while (0) + +/// Checks a CUDA expression and logs a warning on error (non-fatal). +/// @param EXPR The CUDA expression to check. 
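+/// Unlike ET_CUDA_CHECK, this macro never aborts: it clears the sticky CUDA
+/// error state via cudaGetLastError() and logs the failure, which makes it
+/// suitable for cleanup paths, e.g. ET_CUDA_LOG_WARN(cudaFree(ptr)).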
+#define ET_CUDA_LOG_WARN(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + if (SLIMTENSOR_UNLIKELY(__err != cudaSuccess)) { \ + [[maybe_unused]] auto error_unused = cudaGetLastError(); \ + ET_LOG(Error, "CUDA warning: %s", cudaGetErrorString(__err)); \ + } \ + } while (0) + +#endif // CUDA_AVAILABLE diff --git a/backends/aoti/slim/c10/cuda/TARGETS b/backends/aoti/slim/c10/cuda/TARGETS new file mode 100644 index 00000000000..08e83a5f3c4 --- /dev/null +++ b/backends/aoti/slim/c10/cuda/TARGETS @@ -0,0 +1,6 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/backends/aoti/slim/c10/cuda/targets.bzl b/backends/aoti/slim/c10/cuda/targets.bzl new file mode 100644 index 00000000000..1d44bd1f032 --- /dev/null +++ b/backends/aoti/slim/c10/cuda/targets.bzl @@ -0,0 +1,16 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Define targets for SlimTensor CUDA exception handling module.""" + + runtime.cxx_library( + name = "exception", + exported_headers = [ + "Exception.h", + ], + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + "//executorch/backends/aoti/slim/c10/macros:macros", + "//executorch/runtime/platform:platform", + ], + ) diff --git a/backends/aoti/slim/core/Storage.h b/backends/aoti/slim/core/Storage.h index d122e86c1d4..6718f04cb51 100644 --- a/backends/aoti/slim/core/Storage.h +++ b/backends/aoti/slim/core/Storage.h @@ -10,12 +10,18 @@ #include +#ifdef CUDA_AVAILABLE +#include +#include +#endif + #include #include #include #include #include #include +#include namespace executorch::backends::aoti::slim { @@ -30,6 +36,10 @@ inline void noop(void*) {} /// Default CPU device constant. inline const c10::Device CPU_DEVICE = c10::Device(c10::DeviceType::CPU, 0); +/// Default CUDA device constant. +inline const c10::Device DEFAULT_CUDA_DEVICE = + c10::Device(c10::DeviceType::CUDA, 0); + /// DeviceTraits template for device-specific operations. /// Device-specific implementations provide allocate(), free(), and memcpy(). template @@ -74,6 +84,119 @@ struct DeviceTraits { } }; +#ifdef CUDA_AVAILABLE +/// CUDA specialization of DeviceTraits. +/// Provides CUDA memory allocation and copy operations using +/// cudaMallocAsync/cudaFreeAsync with proper stream handling. +/// +/// IMPORTANT: Callers are expected to set the correct CUDA device and stream +/// using CUDAStreamGuard before calling these methods. This is consistent +/// with PyTorch's CUDACachingAllocator design pattern where the allocator +/// assumes the caller has already set the correct device context. +template <> +struct DeviceTraits { + /// Allocates CUDA device memory on the current stream. + /// Uses cudaMallocAsync for asynchronous allocation on the stream + /// that is currently set via CUDAStreamGuard, similar to how + /// PyTorch's CUDACachingAllocator works. + /// + /// NOTE: Caller must ensure the correct device is already set via + /// CUDAStreamGuard. This function does NOT create a device guard internally. + /// + /// @param nbytes Number of bytes to allocate. + /// @param device The target CUDA device (used to get the stream). + /// @return Pointer to allocated device memory. 
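+  ///
+  /// Illustrative usage (assuming the traits are specialized on
+  /// c10::DeviceType):
+  ///   // (caller has already selected the device/stream, see NOTE above)
+  ///   void* buf = DeviceTraits<c10::DeviceType::CUDA>::allocate(
+  ///       1024, DEFAULT_CUDA_DEVICE);
+  ///   DeviceTraits<c10::DeviceType::CUDA>::free(buf);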
+ static void* allocate(size_t nbytes, const c10::Device& device) { + // Get the current stream for this device (set by CUDAStreamGuard if any) + // This follows PyTorch's pattern where the allocator assumes the caller + // has already set the correct device via CUDAStreamGuard. + auto stream_result = + executorch::backends::cuda::getCurrentCUDAStream(device.index()); + ET_CHECK_MSG( + stream_result.ok(), + "Failed to get current CUDA stream for device %d", + static_cast(device.index())); + + cudaStream_t stream = stream_result.get(); + void* data = nullptr; + ET_CUDA_CHECK(cudaMallocAsync(&data, nbytes, stream)); + return data; + } + + /// Frees CUDA device memory on the current stream. + /// @param ptr Pointer to device memory to free. + static void free(void* ptr) { + // Get the current stream for the current device + auto stream_result = executorch::backends::cuda::getCurrentCUDAStream(-1); + if (stream_result.ok()) { + ET_CUDA_LOG_WARN(cudaFreeAsync(ptr, stream_result.get())); + } else { + // Fallback to synchronous free if we can't get the stream + ET_CUDA_LOG_WARN(cudaFree(ptr)); + } + } + + /// Copies memory between CPU and CUDA or CUDA and CUDA. + /// @param dst Destination pointer. + /// @param src Source pointer. + /// @param nbytes Number of bytes to copy. + /// @param dst_device Destination device. + /// @param src_device Source device. + static void memcpy( + void* dst, + const void* src, + size_t nbytes, + const c10::Device& dst_device, + const c10::Device& src_device) { + cudaMemcpyKind direction = cudaMemcpyDeviceToDevice; + + if (src_device.is_cpu()) { + direction = cudaMemcpyHostToDevice; + } else if (dst_device.is_cpu()) { + direction = cudaMemcpyDeviceToHost; + } else { + ET_CHECK_MSG( + src_device.index() == dst_device.index(), + "CUDA memcpy across different device indices not supported: %d != %d", + static_cast(src_device.index()), + static_cast(dst_device.index())); + } + + ET_CUDA_CHECK(cudaMemcpy(dst, src, nbytes, direction)); + } +}; +#else +/// CUDA stub when CUDA_AVAILABLE is not defined. +/// All operations abort with an error message. +template <> +struct DeviceTraits { + static void* allocate(size_t nbytes, const c10::Device& device) { + (void)nbytes; + (void)device; + ET_CHECK_MSG(false, "Build with CUDA_AVAILABLE=1 to enable CUDA support"); + } + + static void free(void* ptr) { + (void)ptr; + ET_LOG(Error, "Build with CUDA_AVAILABLE=1 to enable CUDA support"); + } + + static void memcpy( + void* dst, + const void* src, + size_t nbytes, + const c10::Device& dst_device, + const c10::Device& src_device) { + (void)dst; + (void)src; + (void)nbytes; + (void)dst_device; + (void)src_device; + ET_CHECK_MSG(false, "Build with CUDA_AVAILABLE=1 to enable CUDA support"); + } +}; +#endif // CUDA_AVAILABLE + /** * MaybeOwningStorage - A storage class that manages tensor data memory. * @@ -93,17 +216,19 @@ struct DeviceTraits { class MaybeOwningStorage { public: /// Constructs owning storage with allocated memory. - /// @param device The device for storage (must be CPU). + /// @param device The device for storage (CPU or CUDA). /// @param nbytes Number of bytes to allocate. 
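+  ///
+  /// Example:
+  ///   MaybeOwningStorage cpu_storage(CPU_DEVICE, 1024);            // host
+  ///   MaybeOwningStorage cuda_storage(DEFAULT_CUDA_DEVICE, 1024);  // CUDA:0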
MaybeOwningStorage(const c10::Device& device, size_t nbytes) : device_(device), capacity_(nbytes), is_owning_(true) { - ET_CHECK_MSG( - device.is_cpu(), - "Only CPU device is currently supported, got: %s", - device.str().c_str()); - - data_ = DeviceTraits::allocate(nbytes, device); - deleter_ = DeviceTraits::free; + if (device.is_cpu()) { + data_ = DeviceTraits::allocate(nbytes, device); + deleter_ = DeviceTraits::free; + } else if (device.is_cuda()) { + data_ = DeviceTraits::allocate(nbytes, device); + deleter_ = DeviceTraits::free; + } else { + ET_CHECK_MSG(false, "Unsupported device type: %s", device.str().c_str()); + } } /// Default constructor is deleted - storage must have a device. diff --git a/backends/aoti/slim/core/targets.bzl b/backends/aoti/slim/core/targets.bzl index 2056b8c6866..0fc898c5598 100644 --- a/backends/aoti/slim/core/targets.bzl +++ b/backends/aoti/slim/core/targets.bzl @@ -17,10 +17,12 @@ def define_common_targets(): "//executorch/backends/aoti/slim/util:shared_ptr", "//executorch/backends/aoti/slim/util:size_util", "//executorch/runtime/platform:platform", + "//executorch/backends/aoti/slim/c10/cuda:exception", + "//executorch/backends/cuda/runtime:guard", ], ) - # Header-only library for SlimTensor + # Header-only library for SlimTensor (CPU-only for now) runtime.cxx_library( name = "slimtensor", headers = [ diff --git a/backends/aoti/slim/core/test/targets.bzl b/backends/aoti/slim/core/test/targets.bzl index c7debd46836..3a7e99dd37c 100644 --- a/backends/aoti/slim/core/test/targets.bzl +++ b/backends/aoti/slim/core/test/targets.bzl @@ -1,17 +1,36 @@ +load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils") load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +def get_backend_mode(): + """Get the supported backend mode of slimtensor.""" + return ["cuda", "cpu"] + def define_common_targets(): """Define test targets for SlimTensor core module.""" - runtime.cxx_test( - name = "test_storage", - srcs = [ - "test_storage.cpp", - ], - deps = [ - "//executorch/backends/aoti/slim/core:storage", - ], - ) + # GPU storage test with CUDA support + for backend_mode in get_backend_mode(): + backend_suffix = "_" + backend_mode if backend_mode == "cuda" else "" + + backend_kwargs = { + "external_deps": [("cuda", None, "cuda-lazy")], + "preprocessor_flags": ["-DCUDA_AVAILABLE=1"], + "keep_gpu_sections": True, + "remote_execution": re_test_utils.remote_execution( + platform = "gpu-remote-execution", + ), + } if backend_mode == "cuda" else {} + + runtime.cxx_test( + name = "test_storage" + backend_suffix, + srcs = [ + "test_storage.cpp", + ], + deps = [ + "//executorch/backends/aoti/slim/core:storage", + ], + **backend_kwargs + ) runtime.cxx_test( name = "test_slimtensor_basic", diff --git a/backends/aoti/slim/core/test/test_storage.cpp b/backends/aoti/slim/core/test/test_storage.cpp index 42a8678c888..5ff3d6620be 100644 --- a/backends/aoti/slim/core/test/test_storage.cpp +++ b/backends/aoti/slim/core/test/test_storage.cpp @@ -10,8 +10,29 @@ #include +#ifdef CUDA_AVAILABLE +#include +#endif + namespace executorch::backends::aoti::slim { +// ============================================================================= +// Test Device Helpers +// ============================================================================= + +inline std::vector getTestDevices() { + std::vector devices = {CPU_DEVICE}; +#ifdef CUDA_AVAILABLE + devices.push_back(DEFAULT_CUDA_DEVICE); +#endif + return devices; +} + +inline std::string deviceToString( + const 
testing::TestParamInfo& info) { + return info.param.is_cpu() ? "CPU" : "CUDA"; +} + // ============================================================================= // DeviceTraits Tests // ============================================================================= @@ -52,48 +73,39 @@ TEST(DeviceTraitsCPUTest, MemcpyCPUToCPU) { } // ============================================================================= -// MaybeOwningStorage Tests - Owning Mode +// MaybeOwningStorage Parameterized Tests (CPU and CUDA) // ============================================================================= -TEST(MaybeOwningStorageTest, ConstructOwning) { +class MaybeOwningStorageParamTest : public testing::TestWithParam { + protected: + c10::Device device() const { + return GetParam(); + } +}; + +TEST_P(MaybeOwningStorageParamTest, ConstructOwning) { constexpr size_t kNbytes = 512; - MaybeOwningStorage storage(CPU_DEVICE, kNbytes); + MaybeOwningStorage storage(device(), kNbytes); EXPECT_NE(storage.data(), nullptr); EXPECT_EQ(storage.nbytes(), kNbytes); - EXPECT_TRUE(storage.device().is_cpu()); + EXPECT_EQ(storage.device().type(), device().type()); EXPECT_TRUE(storage.is_owning()); EXPECT_TRUE(storage.is_resizable()); } -TEST(MaybeOwningStorageTest, ConstructOwningZeroBytes) { - MaybeOwningStorage storage(CPU_DEVICE, 0); +TEST_P(MaybeOwningStorageParamTest, ConstructOwningZeroBytes) { + MaybeOwningStorage storage(device(), 0); EXPECT_EQ(storage.data(), nullptr); EXPECT_EQ(storage.nbytes(), 0); - EXPECT_TRUE(storage.device().is_cpu()); + EXPECT_EQ(storage.device().type(), device().type()); EXPECT_TRUE(storage.is_owning()); } -TEST(MaybeOwningStorageTest, DataPersistence) { - constexpr size_t kNumFloats = 64; - constexpr size_t kNbytes = kNumFloats * sizeof(float); - MaybeOwningStorage storage(CPU_DEVICE, kNbytes); - - float* data = static_cast(storage.data()); - for (size_t i = 0; i < kNumFloats; ++i) { - data[i] = static_cast(i) * 2.0f; - } - - float* read_data = static_cast(storage.data()); - for (size_t i = 0; i < kNumFloats; ++i) { - EXPECT_FLOAT_EQ(read_data[i], static_cast(i) * 2.0f); - } -} - -TEST(MaybeOwningStorageTest, MoveConstruct) { +TEST_P(MaybeOwningStorageParamTest, MoveConstruct) { constexpr size_t kNbytes = 256; - MaybeOwningStorage original(CPU_DEVICE, kNbytes); + MaybeOwningStorage original(device(), kNbytes); void* original_data = original.data(); MaybeOwningStorage moved(std::move(original)); @@ -101,17 +113,18 @@ TEST(MaybeOwningStorageTest, MoveConstruct) { EXPECT_EQ(moved.data(), original_data); EXPECT_EQ(moved.nbytes(), kNbytes); EXPECT_TRUE(moved.is_owning()); + EXPECT_EQ(moved.device().type(), device().type()); EXPECT_EQ(original.data(), nullptr); EXPECT_EQ(original.nbytes(), 0); EXPECT_FALSE(original.is_owning()); } -TEST(MaybeOwningStorageTest, MoveAssign) { +TEST_P(MaybeOwningStorageParamTest, MoveAssign) { constexpr size_t kNbytes1 = 256; constexpr size_t kNbytes2 = 512; - MaybeOwningStorage storage1(CPU_DEVICE, kNbytes1); - MaybeOwningStorage storage2(CPU_DEVICE, kNbytes2); + MaybeOwningStorage storage1(device(), kNbytes1); + MaybeOwningStorage storage2(device(), kNbytes2); void* storage2_data = storage2.data(); storage1 = std::move(storage2); @@ -125,7 +138,33 @@ TEST(MaybeOwningStorageTest, MoveAssign) { EXPECT_FALSE(storage2.is_owning()); } -TEST(MaybeOwningStorageTest, Clone) { +INSTANTIATE_TEST_SUITE_P( + DeviceTests, + MaybeOwningStorageParamTest, + testing::ValuesIn(getTestDevices()), + deviceToString); + +// 
============================================================================= +// MaybeOwningStorage CPU-Only Tests (require direct data access) +// ============================================================================= + +TEST(MaybeOwningStorageCPUTest, DataPersistence) { + constexpr size_t kNumFloats = 64; + constexpr size_t kNbytes = kNumFloats * sizeof(float); + MaybeOwningStorage storage(CPU_DEVICE, kNbytes); + + float* data = static_cast(storage.data()); + for (size_t i = 0; i < kNumFloats; ++i) { + data[i] = static_cast(i) * 2.0f; + } + + float* read_data = static_cast(storage.data()); + for (size_t i = 0; i < kNumFloats; ++i) { + EXPECT_FLOAT_EQ(read_data[i], static_cast(i) * 2.0f); + } +} + +TEST(MaybeOwningStorageCPUTest, Clone) { constexpr size_t kNumFloats = 32; constexpr size_t kNbytes = kNumFloats * sizeof(float); MaybeOwningStorage original(CPU_DEVICE, kNbytes); @@ -150,7 +189,7 @@ TEST(MaybeOwningStorageTest, Clone) { EXPECT_FLOAT_EQ(cloned_data[0], 0.0f); } -TEST(MaybeOwningStorageTest, CopyFunction) { +TEST(MaybeOwningStorageCPUTest, CopyFunction) { constexpr size_t kNumFloats = 16; constexpr size_t kNbytes = kNumFloats * sizeof(float); MaybeOwningStorage src_storage(CPU_DEVICE, kNbytes); @@ -171,23 +210,30 @@ TEST(MaybeOwningStorageTest, CopyFunction) { } // ============================================================================= -// Storage (SharedPtr) Tests +// Storage (SharedPtr) Parameterized Tests // ============================================================================= -TEST(StorageSharedPtrTest, BasicUsage) { +class StorageSharedPtrParamTest : public testing::TestWithParam { + protected: + c10::Device device() const { + return GetParam(); + } +}; + +TEST_P(StorageSharedPtrParamTest, BasicUsage) { constexpr size_t kNbytes = 128; - Storage storage(new MaybeOwningStorage(CPU_DEVICE, kNbytes)); + Storage storage(new MaybeOwningStorage(device(), kNbytes)); EXPECT_NE(storage.get(), nullptr); EXPECT_NE(storage->data(), nullptr); EXPECT_EQ(storage->nbytes(), kNbytes); - EXPECT_TRUE(storage->device().is_cpu()); + EXPECT_EQ(storage->device().type(), device().type()); EXPECT_EQ(storage.use_count(), 1); } -TEST(StorageSharedPtrTest, SharedOwnership) { +TEST_P(StorageSharedPtrParamTest, SharedOwnership) { constexpr size_t kNbytes = 128; - Storage storage1(new MaybeOwningStorage(CPU_DEVICE, kNbytes)); + Storage storage1(new MaybeOwningStorage(device(), kNbytes)); void* data_ptr = storage1->data(); Storage storage2 = storage1; // Copy, not reference - increments ref count @@ -198,7 +244,52 @@ TEST(StorageSharedPtrTest, SharedOwnership) { EXPECT_EQ(storage2->data(), data_ptr); } -TEST(StorageSharedPtrTest, SharedOwnershipModification) { +TEST_P(StorageSharedPtrParamTest, ReferenceCountDecrement) { + constexpr size_t kNbytes = 64; + Storage storage1(new MaybeOwningStorage(device(), kNbytes)); + EXPECT_EQ(storage1.use_count(), 1); + + { + Storage storage2 = storage1; + EXPECT_EQ(storage1.use_count(), 2); + } + + EXPECT_EQ(storage1.use_count(), 1); +} + +TEST_P(StorageSharedPtrParamTest, MoveSemantics) { + constexpr size_t kNbytes = 64; + Storage storage1(new MaybeOwningStorage(device(), kNbytes)); + void* data_ptr = storage1->data(); + + Storage storage2 = std::move(storage1); + + EXPECT_EQ(storage1.get(), nullptr); + EXPECT_EQ(storage2->data(), data_ptr); + EXPECT_EQ(storage2.use_count(), 1); +} + +TEST_P(StorageSharedPtrParamTest, MakeShared) { + constexpr size_t kNbytes = 256; + Storage storage = make_shared(device(), kNbytes); + + EXPECT_NE(storage.get(), 
nullptr); + EXPECT_NE(storage->data(), nullptr); + EXPECT_EQ(storage->nbytes(), kNbytes); + EXPECT_EQ(storage.use_count(), 1); +} + +INSTANTIATE_TEST_SUITE_P( + DeviceTests, + StorageSharedPtrParamTest, + testing::ValuesIn(getTestDevices()), + deviceToString); + +// ============================================================================= +// Storage CPU-Only Tests (require direct data access) +// ============================================================================= + +TEST(StorageSharedPtrCPUTest, SharedOwnershipModification) { constexpr size_t kNumFloats = 8; constexpr size_t kNbytes = kNumFloats * sizeof(float); Storage storage1(new MaybeOwningStorage(CPU_DEVICE, kNbytes)); @@ -208,7 +299,7 @@ TEST(StorageSharedPtrTest, SharedOwnershipModification) { data[i] = 0.0f; } - const Storage& storage2 = storage1; + Storage storage2 = storage1; float* data2 = static_cast(storage2->data()); for (size_t i = 0; i < kNumFloats; ++i) { @@ -221,39 +312,156 @@ TEST(StorageSharedPtrTest, SharedOwnershipModification) { } } -TEST(StorageSharedPtrTest, ReferenceCountDecrement) { - constexpr size_t kNbytes = 64; - Storage storage1(new MaybeOwningStorage(CPU_DEVICE, kNbytes)); - EXPECT_EQ(storage1.use_count(), 1); +#ifdef CUDA_AVAILABLE - { - Storage storage2 = storage1; // Copy increments ref count - EXPECT_EQ(storage1.use_count(), 2); - } // storage2 destroyed, ref count decrements +// ============================================================================= +// DeviceTraits Tests +// ============================================================================= - EXPECT_EQ(storage1.use_count(), 1); +TEST(DeviceTraitsCUDATest, AllocateAndFree) { + constexpr size_t kSize = 1024; + void* ptr = + DeviceTraits::allocate(kSize, DEFAULT_CUDA_DEVICE); + ASSERT_NE(ptr, nullptr); + + DeviceTraits::free(ptr); } -TEST(StorageSharedPtrTest, MoveSemantics) { - constexpr size_t kNbytes = 64; - Storage storage1(new MaybeOwningStorage(CPU_DEVICE, kNbytes)); - void* data_ptr = storage1->data(); +TEST(DeviceTraitsCUDATest, AllocateZeroBytes) { + void* ptr = + DeviceTraits::allocate(0, DEFAULT_CUDA_DEVICE); + DeviceTraits::free(ptr); +} - Storage storage2 = std::move(storage1); +TEST(DeviceTraitsCUDATest, MemcpyCPUToCUDA) { + constexpr size_t kSize = 256; + float* cpu_src = static_cast( + DeviceTraits::allocate(kSize * sizeof(float))); + float* cuda_dst = + static_cast(DeviceTraits::allocate( + kSize * sizeof(float), DEFAULT_CUDA_DEVICE)); + float* cpu_verify = static_cast( + DeviceTraits::allocate(kSize * sizeof(float))); - EXPECT_EQ(storage1.get(), nullptr); - EXPECT_EQ(storage2->data(), data_ptr); - EXPECT_EQ(storage2.use_count(), 1); + for (size_t i = 0; i < kSize; ++i) { + cpu_src[i] = static_cast(i) * 2.5f; + } + + // Copy CPU -> CUDA + DeviceTraits::memcpy( + cuda_dst, + cpu_src, + kSize * sizeof(float), + DEFAULT_CUDA_DEVICE, + CPU_DEVICE); + + // Copy CUDA -> CPU to verify + DeviceTraits::memcpy( + cpu_verify, + cuda_dst, + kSize * sizeof(float), + CPU_DEVICE, + DEFAULT_CUDA_DEVICE); + + for (size_t i = 0; i < kSize; ++i) { + EXPECT_FLOAT_EQ(cpu_verify[i], static_cast(i) * 2.5f); + } + + DeviceTraits::free(cpu_src); + DeviceTraits::free(cuda_dst); + DeviceTraits::free(cpu_verify); } -TEST(StorageSharedPtrTest, MakeShared) { - constexpr size_t kNbytes = 256; - Storage storage = make_shared(CPU_DEVICE, kNbytes); +TEST(DeviceTraitsCUDATest, MemcpyCUDAToCPU) { + constexpr size_t kSize = 128; + float* cpu_src = static_cast( + DeviceTraits::allocate(kSize * sizeof(float))); + float* cuda_mem = + 
static_cast(DeviceTraits::allocate( + kSize * sizeof(float), DEFAULT_CUDA_DEVICE)); + float* cpu_dst = static_cast( + DeviceTraits::allocate(kSize * sizeof(float))); - EXPECT_NE(storage.get(), nullptr); - EXPECT_NE(storage->data(), nullptr); - EXPECT_EQ(storage->nbytes(), kNbytes); - EXPECT_EQ(storage.use_count(), 1); + for (size_t i = 0; i < kSize; ++i) { + cpu_src[i] = static_cast(i) + 100.0f; + } + + // Copy CPU -> CUDA + DeviceTraits::memcpy( + cuda_mem, + cpu_src, + kSize * sizeof(float), + DEFAULT_CUDA_DEVICE, + CPU_DEVICE); + + // Copy CUDA -> CPU + DeviceTraits::memcpy( + cpu_dst, + cuda_mem, + kSize * sizeof(float), + CPU_DEVICE, + DEFAULT_CUDA_DEVICE); + + for (size_t i = 0; i < kSize; ++i) { + EXPECT_FLOAT_EQ(cpu_dst[i], static_cast(i) + 100.0f); + } + + DeviceTraits::free(cpu_src); + DeviceTraits::free(cuda_mem); + DeviceTraits::free(cpu_dst); } +TEST(DeviceTraitsCUDATest, MemcpyCUDAToCUDA) { + constexpr size_t kSize = 64; + float* cpu_src = static_cast( + DeviceTraits::allocate(kSize * sizeof(float))); + float* cuda_src = + static_cast(DeviceTraits::allocate( + kSize * sizeof(float), DEFAULT_CUDA_DEVICE)); + float* cuda_dst = + static_cast(DeviceTraits::allocate( + kSize * sizeof(float), DEFAULT_CUDA_DEVICE)); + float* cpu_verify = static_cast( + DeviceTraits::allocate(kSize * sizeof(float))); + + for (size_t i = 0; i < kSize; ++i) { + cpu_src[i] = static_cast(i) * 3.0f; + } + + // Copy CPU -> CUDA src + DeviceTraits::memcpy( + cuda_src, + cpu_src, + kSize * sizeof(float), + DEFAULT_CUDA_DEVICE, + CPU_DEVICE); + + // Copy CUDA src -> CUDA dst + DeviceTraits::memcpy( + cuda_dst, + cuda_src, + kSize * sizeof(float), + DEFAULT_CUDA_DEVICE, + DEFAULT_CUDA_DEVICE); + + // Copy CUDA dst -> CPU to verify + DeviceTraits::memcpy( + cpu_verify, + cuda_dst, + kSize * sizeof(float), + CPU_DEVICE, + DEFAULT_CUDA_DEVICE); + + for (size_t i = 0; i < kSize; ++i) { + EXPECT_FLOAT_EQ(cpu_verify[i], static_cast(i) * 3.0f); + } + + DeviceTraits::free(cpu_src); + DeviceTraits::free(cuda_src); + DeviceTraits::free(cuda_dst); + DeviceTraits::free(cpu_verify); +} + +#endif // CUDA_AVAILABLE + } // namespace executorch::backends::aoti::slim diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index 532ab5544ab..024418d31a6 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -3,6 +3,28 @@ load("//tools/build/buck:nvcc_flags.bzl", "get_nvcc_arch_args") oncall("executorch") +runtime.cxx_library( + name = "guard", + srcs = [ + "guard.cpp", + ], + headers = [ + "guard.h", + "utils.h", + ], + visibility = ["PUBLIC"], + deps = [ + "//executorch/runtime/platform:platform", + ], + exported_deps = [ + "//executorch/runtime/core:core", + "//executorch/runtime/core/exec_aten:lib", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], +) + runtime.cxx_library( name = "cuda_platform", srcs = [ From 60b9279c09f092755fc642dd87719a9883821144 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 22 Jan 2026 09:55:57 -0800 Subject: [PATCH 2/4] [slimtensor] Add CUDA slimtensor creation with basic functionality Pull Request resolved: https://github.com/pytorch/executorch/pull/16770 This diff enables CUDA tensor creation with basic tensor functionality and factory function support **Key changes:* 1. **`core/SlimTensor.h`** - Extended for CUDA support: - Added `is_cuda()` method to check if tensor is on CUDA device 2. 
**`factory/Empty.h`** - Supports CUDA: - `empty_strided()` and `empty()` work with CUDA device via `new_storage()` - Device routing is handled by `MaybeOwningStorage` constructor ghstack-source-id: 335102160 @exported-using-ghexport Differential Revision: [D91202897](https://our.internmc.facebook.com/intern/diff/D91202897/) --- backends/aoti/slim/core/SlimTensor.h | 7 + backends/aoti/slim/core/targets.bzl | 2 +- backends/aoti/slim/core/test/targets.bzl | 21 +- .../slim/core/test/test_slimtensor_basic.cpp | 15 +- backends/aoti/slim/factory/Empty.h | 6 +- backends/aoti/slim/factory/test/targets.bzl | 37 ++- .../aoti/slim/factory/test/test_empty.cpp | 257 ++++++++++++++++++ 7 files changed, 317 insertions(+), 28 deletions(-) diff --git a/backends/aoti/slim/core/SlimTensor.h b/backends/aoti/slim/core/SlimTensor.h index f3ab9f3fec3..c662202493d 100644 --- a/backends/aoti/slim/core/SlimTensor.h +++ b/backends/aoti/slim/core/SlimTensor.h @@ -227,6 +227,13 @@ class SlimTensor { return device().is_cpu(); } + /** + * Check if the tensor is on CUDA. + */ + bool is_cuda() const { + return device().is_cuda(); + } + /** * Check if the tensor is defined (has valid storage). */ diff --git a/backends/aoti/slim/core/targets.bzl b/backends/aoti/slim/core/targets.bzl index 0fc898c5598..cc74b01b444 100644 --- a/backends/aoti/slim/core/targets.bzl +++ b/backends/aoti/slim/core/targets.bzl @@ -22,7 +22,6 @@ def define_common_targets(): ], ) - # Header-only library for SlimTensor (CPU-only for now) runtime.cxx_library( name = "slimtensor", headers = [ @@ -37,6 +36,7 @@ def define_common_targets(): "//executorch/backends/aoti/slim/c10/core:sizes_and_strides", "//executorch/backends/aoti/slim/util:array_ref_util", "//executorch/backends/aoti/slim/util:size_util", + "//executorch/backends/aoti/slim/c10/cuda:exception", "//executorch/runtime/platform:platform", ], ) diff --git a/backends/aoti/slim/core/test/targets.bzl b/backends/aoti/slim/core/test/targets.bzl index 3a7e99dd37c..3400fd943e8 100644 --- a/backends/aoti/slim/core/test/targets.bzl +++ b/backends/aoti/slim/core/test/targets.bzl @@ -32,16 +32,17 @@ def define_common_targets(): **backend_kwargs ) - runtime.cxx_test( - name = "test_slimtensor_basic", - srcs = [ - "test_slimtensor_basic.cpp", - ], - deps = [ - "//executorch/backends/aoti/slim/core:slimtensor", - "//executorch/backends/aoti/slim/core:storage", - ], - ) + runtime.cxx_test( + name = "test_slimtensor_basic" + backend_suffix, + srcs = [ + "test_slimtensor_basic.cpp", + ], + deps = [ + "//executorch/backends/aoti/slim/core:slimtensor", + "//executorch/backends/aoti/slim/core:storage", + ], + **backend_kwargs + ) runtime.cxx_test( name = "test_slimtensor_copy", diff --git a/backends/aoti/slim/core/test/test_slimtensor_basic.cpp b/backends/aoti/slim/core/test/test_slimtensor_basic.cpp index dc60427c467..d70db1e4ae2 100644 --- a/backends/aoti/slim/core/test/test_slimtensor_basic.cpp +++ b/backends/aoti/slim/core/test/test_slimtensor_basic.cpp @@ -21,6 +21,9 @@ namespace executorch::backends::aoti::slim { inline std::vector get_test_devices() { std::vector devices; devices.push_back(CPU_DEVICE); +#ifdef CUDA_AVAILABLE + devices.push_back(DEFAULT_CUDA_DEVICE); +#endif return devices; } @@ -52,7 +55,9 @@ INSTANTIATE_TEST_SUITE_P( DeviceTests, SlimTensorBasicDeviceTest, ::testing::ValuesIn(get_test_devices()), - [](const ::testing::TestParamInfo& info) { return "CPU"; }); + [](const ::testing::TestParamInfo& info) { + return info.param.is_cuda() ? 
"CUDA" : "CPU"; + }); // ============================================================================= // Constructor Tests (Device-Parameterized) @@ -144,11 +149,11 @@ TEST_P(SlimTensorBasicDeviceTest, Dtype) { TEST_P(SlimTensorBasicDeviceTest, Device) { SlimTensor tensor = make_2x3_tensor(); - // We only support CPU for now - EXPECT_TRUE(tensor.is_cpu()); - EXPECT_EQ(tensor.device_type(), c10::DeviceType::CPU); - + // Check device type and index + EXPECT_EQ(tensor.device_type(), device().type()); EXPECT_EQ(tensor.device_index(), device().index()); + EXPECT_EQ(tensor.is_cpu(), device().is_cpu()); + EXPECT_EQ(tensor.is_cuda(), device().is_cuda()); } TEST_P(SlimTensorBasicDeviceTest, Numel) { diff --git a/backends/aoti/slim/factory/Empty.h b/backends/aoti/slim/factory/Empty.h index 24b4f53a647..c0ab9d7248d 100644 --- a/backends/aoti/slim/factory/Empty.h +++ b/backends/aoti/slim/factory/Empty.h @@ -23,7 +23,7 @@ namespace executorch::backends::aoti::slim { /// @param sizes The sizes of each dimension. /// @param strides The strides of each dimension. /// @param dtype The scalar type of tensor elements. -/// @param device The target device (must be CPU). +/// @param device The target device. /// @return A new SlimTensor with allocated but uninitialized storage. inline SlimTensor empty_strided( IntArrayRef sizes, @@ -41,7 +41,7 @@ inline SlimTensor empty_strided( /// /// @param sizes The sizes of each dimension. /// @param dtype The scalar type of tensor elements. -/// @param device The target device (must be CPU). +/// @param device The target device. /// @return A new SlimTensor with contiguous strides and uninitialized storage. inline SlimTensor empty( IntArrayRef sizes, @@ -59,7 +59,7 @@ inline SlimTensor empty( /// /// @param sizes The sizes of each dimension as an initializer list. /// @param dtype The scalar type of tensor elements. -/// @param device The target device (must be CPU). +/// @param device The target device. /// @return A new SlimTensor with contiguous strides and uninitialized storage. 
inline SlimTensor empty( std::initializer_list sizes, diff --git a/backends/aoti/slim/factory/test/targets.bzl b/backends/aoti/slim/factory/test/targets.bzl index a64510b2af1..7bad3067cc0 100644 --- a/backends/aoti/slim/factory/test/targets.bzl +++ b/backends/aoti/slim/factory/test/targets.bzl @@ -1,14 +1,33 @@ +load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils") load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +def get_backend_mode(): + """Get the supported backend mode of slimtensor.""" + return ["cuda", "cpu"] + def define_common_targets(): """Define test targets for SlimTensor factory module.""" - runtime.cxx_test( - name = "test_empty", - srcs = [ - "test_empty.cpp", - ], - deps = [ - "//executorch/backends/aoti/slim/factory:empty", - ], - ) + # GPU empty test with CUDA support + for backend_mode in get_backend_mode(): + backend_suffix = "_" + backend_mode if backend_mode == "cuda" else "" + + backend_kwargs = { + "external_deps": [("cuda", None, "cuda-lazy")], + "preprocessor_flags": ["-DCUDA_AVAILABLE=1"], + "keep_gpu_sections": True, + "remote_execution": re_test_utils.remote_execution( + platform = "gpu-remote-execution", + ), + } if backend_mode == "cuda" else {} + + runtime.cxx_test( + name = "test_empty" + backend_suffix, + srcs = [ + "test_empty.cpp", + ], + deps = [ + "//executorch/backends/aoti/slim/factory:empty", + ], + **backend_kwargs + ) diff --git a/backends/aoti/slim/factory/test/test_empty.cpp b/backends/aoti/slim/factory/test/test_empty.cpp index 7d7c9cafc34..18e7ead14ef 100644 --- a/backends/aoti/slim/factory/test/test_empty.cpp +++ b/backends/aoti/slim/factory/test/test_empty.cpp @@ -10,6 +10,10 @@ #include +#ifdef CUDA_AVAILABLE +#include +#endif + namespace executorch::backends::aoti::slim { // ============================================================================= @@ -229,4 +233,257 @@ TEST(EmptyTest, CanWriteAndReadData) { } } +#ifdef CUDA_AVAILABLE + +// ============================================================================= +// CUDA Empty Tensor Tests +// Tests are skipped at runtime if CUDA hardware is not available. 
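+// All tests below allocate on DEFAULT_CUDA_DEVICE (CUDA device index 0).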
+// ============================================================================= + +// ============================================================================= +// empty_strided() CUDA Tests +// ============================================================================= + +TEST(EmptyStridedCUDATest, Basic2x3Tensor) { + std::vector sizes = {2, 3}; + std::vector strides = {3, 1}; + + SlimTensor tensor = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + + EXPECT_TRUE(tensor.defined()); + EXPECT_EQ(tensor.dim(), 2u); + EXPECT_EQ(tensor.numel(), 6u); + EXPECT_EQ(tensor.dtype(), c10::ScalarType::Float); + EXPECT_TRUE(tensor.is_cuda()); + EXPECT_FALSE(tensor.is_cpu()); + + auto result_sizes = tensor.sizes(); + EXPECT_EQ(result_sizes[0], 2); + EXPECT_EQ(result_sizes[1], 3); + + auto result_strides = tensor.strides(); + EXPECT_EQ(result_strides[0], 3); + EXPECT_EQ(result_strides[1], 1); +} + +TEST(EmptyStridedCUDATest, ContiguousTensor) { + std::vector sizes = {2, 3, 4}; + std::vector strides = {12, 4, 1}; + + SlimTensor tensor = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + + EXPECT_TRUE(tensor.is_contiguous()); + EXPECT_EQ(tensor.numel(), 24u); + EXPECT_EQ(tensor.nbytes(), 24 * sizeof(float)); + EXPECT_TRUE(tensor.is_cuda()); +} + +TEST(EmptyStridedCUDATest, NonContiguousTensor) { + std::vector sizes = {3, 2}; + std::vector strides = {1, 3}; + + SlimTensor tensor = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + + EXPECT_FALSE(tensor.is_contiguous()); + EXPECT_EQ(tensor.numel(), 6u); + EXPECT_TRUE(tensor.is_cuda()); +} + +TEST(EmptyStridedCUDATest, OneDimensional) { + std::vector sizes = {10}; + std::vector strides = {1}; + + SlimTensor tensor = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + + EXPECT_EQ(tensor.dim(), 1u); + EXPECT_EQ(tensor.numel(), 10u); + EXPECT_TRUE(tensor.is_contiguous()); + EXPECT_TRUE(tensor.is_cuda()); +} + +TEST(EmptyStridedCUDATest, ZeroSizedTensor) { + std::vector sizes = {0, 3}; + std::vector strides = {3, 1}; + + SlimTensor tensor = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + + EXPECT_TRUE(tensor.defined()); + EXPECT_EQ(tensor.numel(), 0u); + EXPECT_TRUE(tensor.is_empty()); + EXPECT_TRUE(tensor.is_cuda()); +} + +TEST(EmptyStridedCUDATest, LargeDimensionalTensor) { + std::vector sizes = {2, 3, 4, 5}; + std::vector strides = {60, 20, 5, 1}; + + SlimTensor tensor = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + + EXPECT_EQ(tensor.dim(), 4u); + EXPECT_EQ(tensor.numel(), 120u); + EXPECT_TRUE(tensor.is_contiguous()); + EXPECT_TRUE(tensor.is_cuda()); +} + +// ============================================================================= +// empty() CUDA Tests +// ============================================================================= + +TEST(EmptyCUDATest, BasicWithArrayRef) { + std::vector sizes = {2, 3, 4}; + + SlimTensor tensor = + empty(makeArrayRef(sizes), c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + + EXPECT_TRUE(tensor.defined()); + EXPECT_EQ(tensor.dim(), 3u); + EXPECT_EQ(tensor.numel(), 24u); + EXPECT_TRUE(tensor.is_contiguous()); + EXPECT_TRUE(tensor.is_cuda()); +} + +TEST(EmptyCUDATest, VerifiesContiguousStrides) { + std::vector sizes = {2, 3, 4}; 
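+  // Row-major contiguous strides for sizes {2, 3, 4} are {12, 4, 1}:
+  // stride[i] is the product of all sizes after dimension i.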
+ + SlimTensor tensor = + empty(makeArrayRef(sizes), c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + + auto strides = tensor.strides(); + EXPECT_EQ(strides[0], 12); + EXPECT_EQ(strides[1], 4); + EXPECT_EQ(strides[2], 1); + EXPECT_TRUE(tensor.is_cuda()); +} + +TEST(EmptyCUDATest, InitializerListOverload) { + SlimTensor tensor = + empty({4, 5, 6}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + + EXPECT_EQ(tensor.dim(), 3u); + EXPECT_EQ(tensor.numel(), 120u); + EXPECT_TRUE(tensor.is_contiguous()); + EXPECT_TRUE(tensor.is_cuda()); + + auto sizes = tensor.sizes(); + EXPECT_EQ(sizes[0], 4); + EXPECT_EQ(sizes[1], 5); + EXPECT_EQ(sizes[2], 6); +} + +TEST(EmptyCUDATest, OneDimensional) { + SlimTensor tensor = empty({10}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + + EXPECT_EQ(tensor.dim(), 1u); + EXPECT_EQ(tensor.numel(), 10u); + EXPECT_EQ(tensor.stride(0), 1); + EXPECT_TRUE(tensor.is_cuda()); +} + +TEST(EmptyCUDATest, ZeroSized) { + SlimTensor tensor = + empty({0, 5}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + + EXPECT_TRUE(tensor.is_empty()); + EXPECT_EQ(tensor.numel(), 0u); + EXPECT_TRUE(tensor.is_cuda()); +} + +// ============================================================================= +// empty_like() CUDA Tests +// ============================================================================= + +TEST(EmptyLikeCUDATest, CopiesMetadata) { + std::vector sizes = {2, 3, 4}; + std::vector strides = {12, 4, 1}; + + SlimTensor original = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + SlimTensor copy = empty_like(original); + + EXPECT_EQ(copy.dim(), original.dim()); + EXPECT_EQ(copy.numel(), original.numel()); + EXPECT_EQ(copy.dtype(), original.dtype()); + EXPECT_EQ(copy.is_cuda(), original.is_cuda()); + EXPECT_EQ(copy.is_contiguous(), original.is_contiguous()); + + for (size_t i = 0; i < copy.dim(); i++) { + EXPECT_EQ(copy.size(i), original.size(i)); + EXPECT_EQ(copy.stride(i), original.stride(i)); + } +} + +TEST(EmptyLikeCUDATest, HasDifferentStorage) { + SlimTensor original = + empty({2, 3}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + SlimTensor copy = empty_like(original); + + EXPECT_NE(original.data_ptr(), copy.data_ptr()); + EXPECT_TRUE(copy.is_cuda()); +} + +TEST(EmptyLikeCUDATest, NonContiguousTensor) { + std::vector sizes = {3, 2}; + std::vector strides = {1, 3}; + + SlimTensor original = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + SlimTensor copy = empty_like(original); + + EXPECT_FALSE(copy.is_contiguous()); + EXPECT_EQ(copy.stride(0), 1); + EXPECT_EQ(copy.stride(1), 3); + EXPECT_TRUE(copy.is_cuda()); +} + +// ============================================================================= +// CUDA Data Access Tests +// ============================================================================= + +TEST(EmptyCUDATest, DataPtrIsValid) { + SlimTensor tensor = + empty({2, 3}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + + void* data = tensor.data_ptr(); + EXPECT_NE(data, nullptr); +} + +TEST(EmptyCUDATest, DeviceIndex) { + SlimTensor tensor = + empty({2, 3}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + + EXPECT_EQ(tensor.device().index(), 0); +} + +#endif // CUDA_AVAILABLE + } // namespace executorch::backends::aoti::slim From f18a5e1f6dd3ecd9c231b8b227786c9e901db6e9 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 22 Jan 2026 09:55:59 -0800 Subject: [PATCH 3/4] [slimtensor] Enable CUDA tensor copy MIME-Version: 1.0 Content-Type: 
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/16771 This diff enables CUDA tensor copy operations in SlimTensor. **Key changes:** **`core/SlimTensor.h`** - Extended for CUDA support: - Updated `copy_()` to handle cross-device copies: - CPU→CUDA (cudaMemcpyHostToDevice) - CUDA→CPU (cudaMemcpyDeviceToHost) - CUDA→CUDA (cudaMemcpyDeviceToDevice, same device) - Cross-device copies require contiguous tensors - CPU-to-CPU copies continue to support non-contiguous (strided) tensors ghstack-source-id: 335102159 @exported-using-ghexport Differential Revision: [D91202900](https://our.internmc.facebook.com/intern/diff/D91202900/) --- backends/aoti/slim/core/SlimTensor.h | 71 ++-- backends/aoti/slim/core/Storage.h | 15 +- backends/aoti/slim/core/test/targets.bzl | 22 +- .../slim/core/test/test_slimtensor_copy.cpp | 376 ++++++++++++++++++ 4 files changed, 433 insertions(+), 51 deletions(-) diff --git a/backends/aoti/slim/core/SlimTensor.h b/backends/aoti/slim/core/SlimTensor.h index c662202493d..92b34e8a3e8 100644 --- a/backends/aoti/slim/core/SlimTensor.h +++ b/backends/aoti/slim/core/SlimTensor.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -277,69 +278,67 @@ class SlimTensor { * Copy data from another tensor to this tensor. * * Both tensors must have the same numel and dtype. - * Currently only supports CPU-to-CPU copy (contiguous tensors only). + * Supports CPU-to-CPU and cross-device copies (CPU↔CUDA, CUDA↔CUDA). * * @param other The source tensor to copy from * @return Reference to this tensor */ SlimTensor& copy_(const SlimTensor& other) { ET_CHECK_MSG( - this->numel() == other.numel(), - "copy_: numel mismatch (dst=%zu, src=%zu)", - this->numel(), - other.numel()); - ET_CHECK_MSG(this->dtype() == other.dtype(), "copy_: dtype mismatch"); + this->numel() == other.numel(), "copy_: numel of tensors must match"); + ET_CHECK_MSG(this->dtype() == other.dtype(), "copy_: dtype must match"); if (this->numel() == 0) { return *this; } - // Current we only support CPU-only tensors - // TODO(gasoonjia): support other device types. - ET_CHECK_MSG( - this->is_cpu() && other.is_cpu(), "copy_: only CPU tensors supported"); - + // Case 1: Both tensors are contiguous. We can do a fast bulk copy. if (this->is_contiguous() && other.is_contiguous()) { - // Fast path: both tensors are contiguous, use memcpy - std::memcpy(this->data_ptr(), other.data_ptr(), other.nbytes()); - } else { - // Slow path: element-wise copy for non-contiguous tensors - copy_strided_(other); + storage_->copy_( + this->data_ptr(), other.data_ptr(), other.nbytes(), other.device()); + return *this; } - return *this; - } - - private: - /** - * Element-wise copy for non-contiguous tensors. - */ - void copy_strided_(const SlimTensor& other) { + // Case 2: At least one tensor is non-contiguous, perform element-wise copy + // that respects both source and destination strides. 
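+    // The counter below works like an odometer over the logical index space:
+    // each element's counter is converted into a source offset and a
+    // destination offset using the respective strides, then elem_size bytes
+    // are copied. Cross-device cases route every element through
+    // DeviceTraits memcpy, so this path is correct but per-element slow.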
const size_t elem_size = c10::elementSize(dtype_); char* dst_data = static_cast(this->data_ptr()); const char* src_data = static_cast(other.data_ptr()); std::vector counter(this->dim(), 0); for (size_t i = 0; i < this->numel(); i++) { - // Compute source offset + // Compute src offset in elements int64_t src_offset = 0; for (size_t d = 0; d < other.dim(); d++) { - src_offset += counter[d] * other.stride(static_cast(d)); + src_offset += counter[d] * other.stride(d); } - // Compute destination offset + // Compute dst offset in elements int64_t dst_offset = 0; for (size_t d = 0; d < this->dim(); d++) { - dst_offset += counter[d] * this->stride(static_cast(d)); + dst_offset += counter[d] * this->stride(d); } - // Copy single element - std::memcpy( - dst_data + dst_offset * static_cast(elem_size), - src_data + src_offset * static_cast(elem_size), - elem_size); - - // Increment multi-dimensional counter + // Copy elem_size bytes from src to dst + if (this->device().is_cpu() && other.device().is_cpu()) { + std::memcpy( + dst_data + dst_offset * elem_size, + src_data + src_offset * elem_size, + elem_size); + } else if (this->device().is_cuda() || other.device().is_cuda()) { +#if defined(CUDA_AVAILABLE) + DeviceTraits::memcpy( + dst_data + dst_offset * elem_size, + src_data + src_offset * elem_size, + elem_size, + device(), // dst device + other.device() // src device + ); +#else + ET_CHECK_MSG(false, "Failed on copy_ cuda tensors: no CUDA support"); +#endif + } + // Increment the multi-dimensional counter for (int64_t d = static_cast(this->dim()) - 1; d >= 0; --d) { counter[d]++; if (counter[d] < this->size(d)) { @@ -348,8 +347,10 @@ class SlimTensor { counter[d] = 0; } } + return *this; } + private: void refresh_numel() { numel_ = compute_numel(sizes_and_strides_.sizes_arrayref()); } diff --git a/backends/aoti/slim/core/Storage.h b/backends/aoti/slim/core/Storage.h index 6718f04cb51..ccd63f75981 100644 --- a/backends/aoti/slim/core/Storage.h +++ b/backends/aoti/slim/core/Storage.h @@ -296,12 +296,15 @@ class MaybeOwningStorage { return; } - ET_CHECK_MSG( - device_.is_cpu() && src_device.is_cpu(), - "Only CPU-to-CPU copy is currently supported"); - - DeviceTraits::memcpy( - dst_data_ptr, src_data_ptr, nbytes, device_, src_device); + if (device_.is_cpu() && src_device.is_cpu()) { + // CPU to CPU copy + DeviceTraits::memcpy( + dst_data_ptr, src_data_ptr, nbytes, device_, src_device); + } else { + // At least one of the devices is CUDA + DeviceTraits::memcpy( + dst_data_ptr, src_data_ptr, nbytes, device_, src_device); + } } /// Creates a clone of this storage on the specified device. 
diff --git a/backends/aoti/slim/core/test/targets.bzl b/backends/aoti/slim/core/test/targets.bzl index 3400fd943e8..d0991708c7f 100644 --- a/backends/aoti/slim/core/test/targets.bzl +++ b/backends/aoti/slim/core/test/targets.bzl @@ -44,16 +44,18 @@ def define_common_targets(): **backend_kwargs ) - runtime.cxx_test( - name = "test_slimtensor_copy", - srcs = [ - "test_slimtensor_copy.cpp", - ], - deps = [ - "//executorch/backends/aoti/slim/core:slimtensor", - "//executorch/backends/aoti/slim/core:storage", - ], - ) + runtime.cxx_test( + name = "test_slimtensor_copy" + backend_suffix, + srcs = [ + "test_slimtensor_copy.cpp", + ], + deps = [ + "//executorch/backends/aoti/slim/core:slimtensor", + "//executorch/backends/aoti/slim/core:storage", + "//executorch/backends/aoti/slim/factory:empty", + ], + **backend_kwargs + ) runtime.cxx_test( name = "test_slimtensor_dtypes", diff --git a/backends/aoti/slim/core/test/test_slimtensor_copy.cpp b/backends/aoti/slim/core/test/test_slimtensor_copy.cpp index f227f954798..6d2ed745446 100644 --- a/backends/aoti/slim/core/test/test_slimtensor_copy.cpp +++ b/backends/aoti/slim/core/test/test_slimtensor_copy.cpp @@ -10,6 +10,7 @@ #include #include +#include namespace executorch::backends::aoti::slim { @@ -256,4 +257,379 @@ TEST(SlimTensorCopyTest, CopyWithStorageOffset) { EXPECT_FLOAT_EQ(dst_base[23], 4.0f); } +// ============================================================================= +// CUDA Tensor Creation Tests +// These tests verify CUDA tensor creation and the is_cuda() method. +// When CUDA_AVAILABLE is not defined, CUDA operations abort with an error. +// ============================================================================= + +#ifdef CUDA_AVAILABLE + +TEST(CUDATensorTest, CreateEmptyCUDATensor) { + auto tensor = empty({2, 3}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + + EXPECT_TRUE(tensor.defined()); + EXPECT_TRUE(tensor.is_cuda()); + EXPECT_FALSE(tensor.is_cpu()); + EXPECT_EQ(tensor.dim(), 2); + EXPECT_EQ(tensor.size(0), 2); + EXPECT_EQ(tensor.size(1), 3); + EXPECT_EQ(tensor.numel(), 6); + EXPECT_TRUE(tensor.is_contiguous()); + EXPECT_EQ(tensor.device().type(), c10::DeviceType::CUDA); + EXPECT_EQ(tensor.device().index(), 0); +} + +TEST(CUDATensorTest, CreateEmptyStridedCUDATensor) { + std::vector sizes = {2, 4}; + std::vector strides = {4, 1}; + + auto tensor = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + + EXPECT_TRUE(tensor.is_cuda()); + EXPECT_EQ(tensor.stride(0), 4); + EXPECT_EQ(tensor.stride(1), 1); + EXPECT_EQ(tensor.numel(), 8); +} + +TEST(CUDATensorTest, CreateCUDATensorWithDeviceIndex) { + c10::Device device(c10::DeviceType::CUDA, 0); + auto tensor = empty({4, 4}, c10::ScalarType::Float, device); + + EXPECT_TRUE(tensor.is_cuda()); + EXPECT_EQ(tensor.device_index(), 0); +} + +// ============================================================================= +// Cross-Device Copy Tests +// ============================================================================= + +TEST(CUDACopyTest, CopyFromCPUToCUDA) { + constexpr size_t kNumFloats = 6; + auto cpu_tensor = empty({2, 3}, c10::ScalarType::Float, CPU_DEVICE); + + // Fill CPU tensor with known values + float* cpu_data = static_cast(cpu_tensor.data_ptr()); + for (size_t i = 0; i < kNumFloats; ++i) { + cpu_data[i] = static_cast(i) * 1.5f; + } + + // Create CUDA tensor and copy from CPU + auto cuda_tensor = empty({2, 3}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + cuda_tensor.copy_(cpu_tensor); + + // Copy 
back to CPU to verify + auto verify_tensor = empty({2, 3}, c10::ScalarType::Float, CPU_DEVICE); + verify_tensor.copy_(cuda_tensor); + + float* verify_data = static_cast(verify_tensor.data_ptr()); + for (size_t i = 0; i < kNumFloats; ++i) { + EXPECT_FLOAT_EQ(verify_data[i], static_cast(i) * 1.5f); + } +} + +TEST(CUDACopyTest, CopyFromCUDAToCPU) { + constexpr size_t kNumFloats = 4; + auto cpu_src = empty({2, 2}, c10::ScalarType::Float, CPU_DEVICE); + + float* src_data = static_cast(cpu_src.data_ptr()); + for (size_t i = 0; i < kNumFloats; ++i) { + src_data[i] = static_cast(i) + 100.0f; + } + + // Copy to CUDA + auto cuda_tensor = empty({2, 2}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + cuda_tensor.copy_(cpu_src); + + // Copy back to new CPU tensor + auto cpu_dst = empty({2, 2}, c10::ScalarType::Float, CPU_DEVICE); + cpu_dst.copy_(cuda_tensor); + + float* dst_data = static_cast(cpu_dst.data_ptr()); + for (size_t i = 0; i < kNumFloats; ++i) { + EXPECT_FLOAT_EQ(dst_data[i], static_cast(i) + 100.0f); + } +} + +TEST(CUDACopyTest, CopyFromCUDAToCUDA) { + constexpr size_t kNumFloats = 4; + auto cpu_tensor = empty({2, 2}, c10::ScalarType::Float, CPU_DEVICE); + + float* cpu_data = static_cast(cpu_tensor.data_ptr()); + for (size_t i = 0; i < kNumFloats; ++i) { + cpu_data[i] = static_cast(i) * 2.0f; + } + + // Create first CUDA tensor from CPU + auto cuda_src = empty({2, 2}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + cuda_src.copy_(cpu_tensor); + + // Copy to second CUDA tensor + auto cuda_dst = empty({2, 2}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + cuda_dst.copy_(cuda_src); + + // Verify by copying back to CPU + auto verify_tensor = empty({2, 2}, c10::ScalarType::Float, CPU_DEVICE); + verify_tensor.copy_(cuda_dst); + + float* verify_data = static_cast(verify_tensor.data_ptr()); + for (size_t i = 0; i < kNumFloats; ++i) { + EXPECT_FLOAT_EQ(verify_data[i], static_cast(i) * 2.0f); + } +} + +TEST(CUDACopyTest, CopyDifferentDtypes) { + auto cpu_int = empty({4}, c10::ScalarType::Int, CPU_DEVICE); + int32_t* int_data = static_cast(cpu_int.data_ptr()); + for (int i = 0; i < 4; ++i) { + int_data[i] = i * 10; + } + + auto cuda_int = empty({4}, c10::ScalarType::Int, DEFAULT_CUDA_DEVICE); + cuda_int.copy_(cpu_int); + + auto verify_int = empty({4}, c10::ScalarType::Int, CPU_DEVICE); + verify_int.copy_(cuda_int); + + int32_t* verify_data = static_cast(verify_int.data_ptr()); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(verify_data[i], i * 10); + } +} + +TEST(CUDACopyTest, CopyEmptyTensor) { + auto cpu_empty = empty({0}, c10::ScalarType::Float, CPU_DEVICE); + auto cuda_empty = empty({0}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + + // Should not crash + cuda_empty.copy_(cpu_empty); + cpu_empty.copy_(cuda_empty); + + EXPECT_EQ(cpu_empty.numel(), 0); + EXPECT_EQ(cuda_empty.numel(), 0); +} + +// ============================================================================= +// Non-Contiguous Cross-Device Copy Tests +// These tests verify copying non-contiguous CPU tensors to/from CUDA tensors. +// The CUDA tensor must be contiguous, but the CPU tensor can be non-contiguous. 
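+// Non-contiguous operands take the element-wise path in copy_(), which
+// issues one cross-device memcpy per element.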
+// ============================================================================= + +TEST(CUDACopyTest, CopyNonContiguousCPUToCUDA) { + // Create a transposed (non-contiguous) CPU source tensor + // Logical shape: 2x3, but stored transposed in memory + std::vector src_sizes = {2, 3}; + std::vector src_strides = {1, 2}; // Transposed strides + + Storage src_storage = + Storage(new MaybeOwningStorage(CPU_DEVICE, 6 * sizeof(float))); + float* src_data = static_cast(src_storage->data()); + // Physical layout for transposed tensor: + // Logical[0,0]=Physical[0], Logical[1,0]=Physical[1] + // Logical[0,1]=Physical[2], Logical[1,1]=Physical[3] + // Logical[0,2]=Physical[4], Logical[1,2]=Physical[5] + src_data[0] = 1.0f; // Logical[0,0] + src_data[1] = 4.0f; // Logical[1,0] + src_data[2] = 2.0f; // Logical[0,1] + src_data[3] = 5.0f; // Logical[1,1] + src_data[4] = 3.0f; // Logical[0,2] + src_data[5] = 6.0f; // Logical[1,2] + + SlimTensor cpu_src( + std::move(src_storage), + makeArrayRef(src_sizes), + makeArrayRef(src_strides), + c10::ScalarType::Float); + + ASSERT_FALSE(cpu_src.is_contiguous()); + + // Create a contiguous CUDA destination + auto cuda_dst = empty({2, 3}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + ASSERT_TRUE(cuda_dst.is_contiguous()); + + // Copy non-contiguous CPU → contiguous CUDA + cuda_dst.copy_(cpu_src); + + // Verify by copying back to CPU + auto verify = empty({2, 3}, c10::ScalarType::Float, CPU_DEVICE); + verify.copy_(cuda_dst); + + // Values should be in logical order (contiguous layout) + float* verify_data = static_cast(verify.data_ptr()); + EXPECT_FLOAT_EQ(verify_data[0], 1.0f); // [0,0] + EXPECT_FLOAT_EQ(verify_data[1], 2.0f); // [0,1] + EXPECT_FLOAT_EQ(verify_data[2], 3.0f); // [0,2] + EXPECT_FLOAT_EQ(verify_data[3], 4.0f); // [1,0] + EXPECT_FLOAT_EQ(verify_data[4], 5.0f); // [1,1] + EXPECT_FLOAT_EQ(verify_data[5], 6.0f); // [1,2] +} + +TEST(CUDACopyTest, CopyCUDAToNonContiguousCPU) { + constexpr size_t kNumFloats = 6; + + // Create a contiguous CPU source, copy to CUDA + auto cpu_src = empty({2, 3}, c10::ScalarType::Float, CPU_DEVICE); + float* src_data = static_cast(cpu_src.data_ptr()); + for (size_t i = 0; i < kNumFloats; ++i) { + src_data[i] = static_cast(i + 1); + } + + auto cuda_tensor = empty({2, 3}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + cuda_tensor.copy_(cpu_src); + + // Create a transposed (non-contiguous) CPU destination + std::vector dst_sizes = {2, 3}; + std::vector dst_strides = {1, 2}; // Transposed strides + + Storage dst_storage = + Storage(new MaybeOwningStorage(CPU_DEVICE, 6 * sizeof(float))); + SlimTensor cpu_dst( + std::move(dst_storage), + makeArrayRef(dst_sizes), + makeArrayRef(dst_strides), + c10::ScalarType::Float); + + ASSERT_FALSE(cpu_dst.is_contiguous()); + + // Copy contiguous CUDA → non-contiguous CPU + cpu_dst.copy_(cuda_tensor); + + // Verify physical layout matches transposed storage + float* dst_data = static_cast(cpu_dst.storage()->data()); + // Physical layout: [1,4,2,5,3,6] for logical [[1,2,3],[4,5,6]] + EXPECT_FLOAT_EQ(dst_data[0], 1.0f); // Logical[0,0] + EXPECT_FLOAT_EQ(dst_data[1], 4.0f); // Logical[1,0] + EXPECT_FLOAT_EQ(dst_data[2], 2.0f); // Logical[0,1] + EXPECT_FLOAT_EQ(dst_data[3], 5.0f); // Logical[1,1] + EXPECT_FLOAT_EQ(dst_data[4], 3.0f); // Logical[0,2] + EXPECT_FLOAT_EQ(dst_data[5], 6.0f); // Logical[1,2] +} + +TEST(CUDACopyTest, CopyNonContiguousCPUToCUDA3D) { + // Test 3D non-contiguous tensor copy + std::vector sizes = {2, 2, 2}; + // Permuted strides (e.g., from permute(2, 0, 1)) + std::vector 
non_contig_strides = {2, 1, 4}; + + Storage src_storage = + Storage(new MaybeOwningStorage(CPU_DEVICE, 8 * sizeof(float))); + float* src_data = static_cast(src_storage->data()); + // Fill with values 1-8 + for (int i = 0; i < 8; ++i) { + src_data[i] = static_cast(i + 1); + } + + SlimTensor cpu_src( + std::move(src_storage), + makeArrayRef(sizes), + makeArrayRef(non_contig_strides), + c10::ScalarType::Float); + + ASSERT_FALSE(cpu_src.is_contiguous()); + + auto cuda_dst = empty({2, 2, 2}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + cuda_dst.copy_(cpu_src); + + // Copy back and verify the logical order is preserved + auto verify = empty({2, 2, 2}, c10::ScalarType::Float, CPU_DEVICE); + verify.copy_(cuda_dst); + + // Access elements via strided indexing on source + float* verify_data = static_cast(verify.data_ptr()); + + // Verify a few key positions + // The values should match the logical traversal of the source tensor + EXPECT_NE(verify_data[0], 0.0f); // Should have data + EXPECT_EQ(verify.numel(), 8); +} + +TEST(CUDACopyTest, CopyCUDAToNonContiguousCPUWithOffset) { + // Test with storage offset + constexpr size_t kNumFloats = 4; + + auto cpu_src = empty({2, 2}, c10::ScalarType::Float, CPU_DEVICE); + float* src_data = static_cast(cpu_src.data_ptr()); + for (size_t i = 0; i < kNumFloats; ++i) { + src_data[i] = static_cast(i + 10); + } + + auto cuda_tensor = empty({2, 2}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + cuda_tensor.copy_(cpu_src); + + // Create non-contiguous destination with storage offset + std::vector dst_sizes = {2, 2}; + std::vector dst_strides = {1, 2}; // Transposed + + Storage dst_storage = + Storage(new MaybeOwningStorage(CPU_DEVICE, 10 * sizeof(float))); + SlimTensor cpu_dst( + std::move(dst_storage), + makeArrayRef(dst_sizes), + makeArrayRef(dst_strides), + c10::ScalarType::Float, + 2); // offset of 2 elements + + ASSERT_FALSE(cpu_dst.is_contiguous()); + + cpu_dst.copy_(cuda_tensor); + + // Verify data starts at offset + float* raw_data = static_cast(cpu_dst.storage()->data()); + float* offset_data = static_cast(cpu_dst.data_ptr()); + + // offset_data should be 2 elements after raw_data + EXPECT_EQ(offset_data, raw_data + 2); + + // Verify transposed layout at offset + EXPECT_FLOAT_EQ(offset_data[0], 10.0f); // Logical[0,0] + EXPECT_FLOAT_EQ(offset_data[1], 12.0f); // Logical[1,0] + EXPECT_FLOAT_EQ(offset_data[2], 11.0f); // Logical[0,1] + EXPECT_FLOAT_EQ(offset_data[3], 13.0f); // Logical[1,1] +} + +TEST(CUDACopyTest, CopyNonContiguousCPUToCUDAInt64) { + // Test with different dtype (int64) + std::vector sizes = {2, 3}; + std::vector strides = {1, 2}; // Transposed + + Storage src_storage = + Storage(new MaybeOwningStorage(CPU_DEVICE, 6 * sizeof(int64_t))); + int64_t* src_data = static_cast(src_storage->data()); + // Fill transposed layout + src_data[0] = 100; // Logical[0,0] + src_data[1] = 400; // Logical[1,0] + src_data[2] = 200; // Logical[0,1] + src_data[3] = 500; // Logical[1,1] + src_data[4] = 300; // Logical[0,2] + src_data[5] = 600; // Logical[1,2] + + SlimTensor cpu_src( + std::move(src_storage), + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Long); + + ASSERT_FALSE(cpu_src.is_contiguous()); + + auto cuda_dst = empty({2, 3}, c10::ScalarType::Long, DEFAULT_CUDA_DEVICE); + cuda_dst.copy_(cpu_src); + + auto verify = empty({2, 3}, c10::ScalarType::Long, CPU_DEVICE); + verify.copy_(cuda_dst); + + int64_t* verify_data = static_cast(verify.data_ptr()); + EXPECT_EQ(verify_data[0], 100); + EXPECT_EQ(verify_data[1], 200); + 
+  EXPECT_EQ(verify_data[2], 300);
+  EXPECT_EQ(verify_data[3], 400);
+  EXPECT_EQ(verify_data[4], 500);
+  EXPECT_EQ(verify_data[5], 600);
+}
+
+#endif // CUDA_AVAILABLE
+
 } // namespace executorch::backends::aoti::slim

From 01ca9366f771b88e642187b3b438d9506d0c0bc9 Mon Sep 17 00:00:00 2001
From: gasoonjia
Date: Thu, 22 Jan 2026 09:56:02 -0800
Subject: [PATCH 4/4] [slimtensor] Add non-owning storage mode for wrapping
 external memory

Pull Request resolved: https://github.com/pytorch/executorch/pull/16772

Add non-owning constructor to MaybeOwningStorage that allows wrapping
external memory without taking ownership. This is needed for from_blob(),
which creates tensors from user-provided data pointers.

Changes:
- Add MaybeOwningStorage(device, data, nbytes) non-owning constructor
- Add is_owning() method to query ownership mode
- Non-owning storage uses no-op deleter so external memory is not freed

ghstack-source-id: 335102163
@exported-using-ghexport

Differential Revision: [D91202898](https://our.internmc.facebook.com/intern/diff/D91202898/)
---
 backends/aoti/slim/core/Storage.h             |  11 +
 backends/aoti/slim/core/test/test_storage.cpp | 403 ++++++++++++++++++
 2 files changed, 414 insertions(+)

diff --git a/backends/aoti/slim/core/Storage.h b/backends/aoti/slim/core/Storage.h
index ccd63f75981..156556aa9e1 100644
--- a/backends/aoti/slim/core/Storage.h
+++ b/backends/aoti/slim/core/Storage.h
@@ -231,6 +231,17 @@ class MaybeOwningStorage {
     }
   }
 
+  /// Constructs non-owning storage with external memory.
+  /// @param device The device where the data resides.
+  /// @param data Pointer to external memory (not owned by this storage).
+  /// @param nbytes Size of the external memory in bytes.
+  MaybeOwningStorage(const c10::Device& device, void* data, size_t nbytes)
+      : device_(device),
+        data_(data),
+        capacity_(nbytes),
+        deleter_(detail::noop),
+        is_owning_(false) {}
+
   /// Default constructor is deleted - storage must have a device.
   MaybeOwningStorage() = delete;

diff --git a/backends/aoti/slim/core/test/test_storage.cpp b/backends/aoti/slim/core/test/test_storage.cpp
index 5ff3d6620be..5d61019aa2b 100644
--- a/backends/aoti/slim/core/test/test_storage.cpp
+++ b/backends/aoti/slim/core/test/test_storage.cpp
@@ -72,6 +72,137 @@ TEST(DeviceTraitsCPUTest, MemcpyCPUToCPU) {
   DeviceTraits<c10::DeviceType::CPU>::free(dst);
 }
 
+// =============================================================================
+// MaybeOwningStorage Tests - Non-Owning Mode
+// =============================================================================
+
+TEST(MaybeOwningStorageNonOwningTest, ConstructNonOwning) {
+  constexpr size_t kNumFloats = 64;
+  constexpr size_t kNbytes = kNumFloats * sizeof(float);
+
+  // Allocate external memory
+  float* external_data = static_cast<float*>(
+      DeviceTraits<c10::DeviceType::CPU>::allocate(kNbytes));
+
+  // Initialize external data
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    external_data[i] = static_cast<float>(i) * 2.5f;
+  }
+
+  {
+    // Create non-owning storage
+    MaybeOwningStorage storage(CPU_DEVICE, external_data, kNbytes);
+
+    EXPECT_EQ(storage.data(), external_data);
+    EXPECT_EQ(storage.nbytes(), kNbytes);
+    EXPECT_TRUE(storage.device().is_cpu());
+    EXPECT_FALSE(storage.is_owning());
+    EXPECT_FALSE(storage.is_resizable());
+
+    // Verify data is accessible through storage
+    float* data = static_cast<float*>(storage.data());
+    for (size_t i = 0; i < kNumFloats; ++i) {
+      EXPECT_FLOAT_EQ(data[i], static_cast<float>(i) * 2.5f);
+    }
+  }
+  // After storage goes out of scope, external_data should still be valid
+  // because the storage did not own it
+
+  // Verify external data is still accessible after storage is destroyed
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    EXPECT_FLOAT_EQ(external_data[i], static_cast<float>(i) * 2.5f);
+  }
+
+  // Clean up external data manually
+  DeviceTraits<c10::DeviceType::CPU>::free(external_data);
+}
+
+TEST(MaybeOwningStorageNonOwningTest, ModifyThroughStorage) {
+  constexpr size_t kNumFloats = 32;
+  constexpr size_t kNbytes = kNumFloats * sizeof(float);
+
+  // Allocate and initialize external memory
+  float* external_data = static_cast<float*>(
+      DeviceTraits<c10::DeviceType::CPU>::allocate(kNbytes));
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    external_data[i] = 0.0f;
+  }
+
+  {
+    MaybeOwningStorage storage(CPU_DEVICE, external_data, kNbytes);
+
+    // Modify data through storage
+    float* data = static_cast<float*>(storage.data());
+    for (size_t i = 0; i < kNumFloats; ++i) {
+      data[i] = static_cast<float>(i) * 10.0f;
+    }
+  }
+
+  // Verify external data was modified after storage is destroyed
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    EXPECT_FLOAT_EQ(external_data[i], static_cast<float>(i) * 10.0f);
+  }
+
+  DeviceTraits<c10::DeviceType::CPU>::free(external_data);
+}
+
+TEST(MaybeOwningStorageNonOwningTest, MoveConstruct) {
+  constexpr size_t kNbytes = 256;
+  float* external_data = static_cast<float*>(
+      DeviceTraits<c10::DeviceType::CPU>::allocate(kNbytes));
+
+  MaybeOwningStorage original(CPU_DEVICE, external_data, kNbytes);
+
+  MaybeOwningStorage moved(std::move(original));
+
+  EXPECT_EQ(moved.data(), external_data);
+  EXPECT_EQ(moved.nbytes(), kNbytes);
+  EXPECT_FALSE(moved.is_owning());
+
+  EXPECT_EQ(original.data(), nullptr);
+  EXPECT_EQ(original.nbytes(), 0);
+  EXPECT_FALSE(original.is_owning());
+
+  DeviceTraits<c10::DeviceType::CPU>::free(external_data);
+}
+
+TEST(MaybeOwningStorageNonOwningTest, MoveAssign) {
+  constexpr size_t kNbytes1 = 256;
+  constexpr size_t kNbytes2 = 512;
+
+  // Create two external buffers
+  float* external_data1 = static_cast<float*>(
+      DeviceTraits<c10::DeviceType::CPU>::allocate(kNbytes1));
+  float* external_data2 = static_cast<float*>(
+      DeviceTraits<c10::DeviceType::CPU>::allocate(kNbytes2));
+
+  MaybeOwningStorage storage1(CPU_DEVICE, external_data1, kNbytes1);
+  MaybeOwningStorage storage2(CPU_DEVICE, external_data2, kNbytes2);
+
+  storage1 = std::move(storage2);
+
+  EXPECT_EQ(storage1.data(), external_data2);
+  EXPECT_EQ(storage1.nbytes(), kNbytes2);
+  EXPECT_FALSE(storage1.is_owning());
+
+  EXPECT_EQ(storage2.data(), nullptr);
+  EXPECT_EQ(storage2.nbytes(), 0);
+  EXPECT_FALSE(storage2.is_owning());
+
+  // Clean up both external buffers
+  DeviceTraits<c10::DeviceType::CPU>::free(external_data1);
+  DeviceTraits<c10::DeviceType::CPU>::free(external_data2);
+}
+
+TEST(MaybeOwningStorageNonOwningTest, ZeroBytes) {
+  // Non-owning with nullptr and zero bytes
+  MaybeOwningStorage storage(CPU_DEVICE, nullptr, 0);
+
+  EXPECT_EQ(storage.data(), nullptr);
+  EXPECT_EQ(storage.nbytes(), 0);
+  EXPECT_FALSE(storage.is_owning());
+}
+
 // =============================================================================
 // MaybeOwningStorage Parameterized Tests (CPU and CUDA)
 // =============================================================================
@@ -462,6 +593,278 @@ TEST(DeviceTraitsCUDATest, MemcpyCUDAToCUDA) {
   DeviceTraits<c10::DeviceType::CPU>::free(cpu_verify);
 }
 
+// =============================================================================
+// MaybeOwningStorage CUDA Tests
+// =============================================================================
+
+TEST(MaybeOwningStorageCUDATest, ConstructOwning) {
+  constexpr size_t kNbytes = 512;
+  MaybeOwningStorage storage(DEFAULT_CUDA_DEVICE, kNbytes);
+
+  EXPECT_NE(storage.data(), nullptr);
+  EXPECT_EQ(storage.nbytes(), kNbytes);
+  EXPECT_TRUE(storage.device().is_cuda());
+  EXPECT_FALSE(storage.device().is_cpu());
+  EXPECT_TRUE(storage.is_owning());
+  EXPECT_TRUE(storage.is_resizable());
+  EXPECT_EQ(storage.device().index(), 0);
+}
+
+TEST(MaybeOwningStorageCUDATest, ConstructOwningZeroBytes) {
+  MaybeOwningStorage storage(DEFAULT_CUDA_DEVICE, 0);
+
+  EXPECT_EQ(storage.data(), nullptr);
+  EXPECT_EQ(storage.nbytes(), 0);
+  EXPECT_TRUE(storage.device().is_cuda());
+  EXPECT_TRUE(storage.is_owning());
+}
+
+TEST(MaybeOwningStorageCUDATest, MoveConstruct) {
+  constexpr size_t kNbytes = 256;
+  MaybeOwningStorage original(DEFAULT_CUDA_DEVICE, kNbytes);
+  void* original_data = original.data();
+
+  MaybeOwningStorage moved(std::move(original));
+
+  EXPECT_EQ(moved.data(), original_data);
+  EXPECT_EQ(moved.nbytes(), kNbytes);
+  EXPECT_TRUE(moved.is_owning());
+  EXPECT_TRUE(moved.device().is_cuda());
+
+  EXPECT_EQ(original.data(), nullptr);
+  EXPECT_EQ(original.nbytes(), 0);
+  EXPECT_FALSE(original.is_owning());
+}
+
+TEST(MaybeOwningStorageCUDATest, MoveAssign) {
+  constexpr size_t kNbytes1 = 256;
+  constexpr size_t kNbytes2 = 512;
+  MaybeOwningStorage storage1(DEFAULT_CUDA_DEVICE, kNbytes1);
+  MaybeOwningStorage storage2(DEFAULT_CUDA_DEVICE, kNbytes2);
+  void* storage2_data = storage2.data();
+
+  storage1 = std::move(storage2);
+
+  EXPECT_EQ(storage1.data(), storage2_data);
+  EXPECT_EQ(storage1.nbytes(), kNbytes2);
+  EXPECT_TRUE(storage1.is_owning());
+
+  EXPECT_EQ(storage2.data(), nullptr);
+  EXPECT_EQ(storage2.nbytes(), 0);
+  EXPECT_FALSE(storage2.is_owning());
+}
+
+// =============================================================================
+// MaybeOwningStorage CUDA Tests - Non-Owning Mode
+// =============================================================================
+
+TEST(MaybeOwningStorageCUDANonOwningTest, ConstructNonOwning) {
+  constexpr size_t kNumFloats = 64;
+  constexpr size_t kNbytes = kNumFloats * sizeof(float);
+
+  // Allocate external CUDA memory
+  float* external_data =
+      static_cast<float*>(DeviceTraits<c10::DeviceType::CUDA>::allocate(
+          kNbytes, DEFAULT_CUDA_DEVICE));
+
+  // Initialize external data via CPU buffer
+  float* cpu_buffer = static_cast<float*>(
+      DeviceTraits<c10::DeviceType::CPU>::allocate(kNbytes));
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    cpu_buffer[i] = static_cast<float>(i) * 2.5f;
+  }
+  DeviceTraits<c10::DeviceType::CUDA>::memcpy(
+      external_data, cpu_buffer, kNbytes, DEFAULT_CUDA_DEVICE, CPU_DEVICE);
+
+  {
+    // Create non-owning storage
+    MaybeOwningStorage storage(DEFAULT_CUDA_DEVICE, external_data, kNbytes);
+
+    EXPECT_EQ(storage.data(), external_data);
+    EXPECT_EQ(storage.nbytes(), kNbytes);
+    EXPECT_TRUE(storage.device().is_cuda());
+    EXPECT_FALSE(storage.is_owning());
+    EXPECT_FALSE(storage.is_resizable());
+
+    // Verify data is accessible through storage by copying back to CPU
+    float* verify_buffer = static_cast<float*>(
+        DeviceTraits<c10::DeviceType::CPU>::allocate(kNbytes));
+    DeviceTraits<c10::DeviceType::CUDA>::memcpy(
+        verify_buffer,
+        storage.data(),
+        kNbytes,
+        CPU_DEVICE,
+        DEFAULT_CUDA_DEVICE);
+    for (size_t i = 0; i < kNumFloats; ++i) {
+      EXPECT_FLOAT_EQ(verify_buffer[i], static_cast<float>(i) * 2.5f);
+    }
+    DeviceTraits<c10::DeviceType::CPU>::free(verify_buffer);
+  }
+  // After storage goes out of scope, external_data should still be valid
+  // because the storage did not own it
+
+  // Verify external data is still accessible after storage is destroyed
+  float* verify_buffer2 = static_cast<float*>(
+      DeviceTraits<c10::DeviceType::CPU>::allocate(kNbytes));
+  DeviceTraits<c10::DeviceType::CUDA>::memcpy(
+      verify_buffer2, external_data, kNbytes, CPU_DEVICE, DEFAULT_CUDA_DEVICE);
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    EXPECT_FLOAT_EQ(verify_buffer2[i], static_cast<float>(i) * 2.5f);
+  }
+
+  // Clean up
+  DeviceTraits<c10::DeviceType::CPU>::free(verify_buffer2);
+  DeviceTraits<c10::DeviceType::CPU>::free(cpu_buffer);
+  DeviceTraits<c10::DeviceType::CUDA>::free(external_data);
+}
+
+TEST(MaybeOwningStorageCUDANonOwningTest, ModifyThroughStorage) {
+  constexpr size_t kNumFloats = 32;
+  constexpr size_t kNbytes = kNumFloats * sizeof(float);
+
+  // Allocate external CUDA memory
+  float* external_data =
+      static_cast<float*>(DeviceTraits<c10::DeviceType::CUDA>::allocate(
+          kNbytes, DEFAULT_CUDA_DEVICE));
+
+  // Initialize to zeros
+  float* cpu_buffer = static_cast<float*>(
+      DeviceTraits<c10::DeviceType::CPU>::allocate(kNbytes));
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    cpu_buffer[i] = 0.0f;
+  }
+  DeviceTraits<c10::DeviceType::CUDA>::memcpy(
+      external_data, cpu_buffer, kNbytes, DEFAULT_CUDA_DEVICE, CPU_DEVICE);
+
+  {
+    MaybeOwningStorage storage(DEFAULT_CUDA_DEVICE, external_data, kNbytes);
+
+    // Modify data through storage by copying new data
+    for (size_t i = 0; i < kNumFloats; ++i) {
+      cpu_buffer[i] = static_cast<float>(i) * 10.0f;
+    }
+    DeviceTraits<c10::DeviceType::CUDA>::memcpy(
+        storage.data(), cpu_buffer, kNbytes, DEFAULT_CUDA_DEVICE, CPU_DEVICE);
+  }
+
+  // Verify external data was modified after storage is destroyed
+  float* verify_buffer = static_cast<float*>(
+      DeviceTraits<c10::DeviceType::CPU>::allocate(kNbytes));
+  DeviceTraits<c10::DeviceType::CUDA>::memcpy(
+      verify_buffer, external_data, kNbytes, CPU_DEVICE, DEFAULT_CUDA_DEVICE);
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    EXPECT_FLOAT_EQ(verify_buffer[i], static_cast<float>(i) * 10.0f);
+  }
+
+  // Clean up
+  DeviceTraits<c10::DeviceType::CPU>::free(verify_buffer);
+  DeviceTraits<c10::DeviceType::CPU>::free(cpu_buffer);
+  DeviceTraits<c10::DeviceType::CUDA>::free(external_data);
+}
+
+TEST(MaybeOwningStorageCUDANonOwningTest, MoveConstruct) {
+  constexpr size_t kNbytes = 256;
+  float* external_data =
+      static_cast<float*>(DeviceTraits<c10::DeviceType::CUDA>::allocate(
+          kNbytes, DEFAULT_CUDA_DEVICE));
+
+  MaybeOwningStorage original(DEFAULT_CUDA_DEVICE, external_data, kNbytes);
+
+  MaybeOwningStorage moved(std::move(original));
+
+  EXPECT_EQ(moved.data(), external_data);
+  EXPECT_EQ(moved.nbytes(), kNbytes);
+  EXPECT_FALSE(moved.is_owning());
+  EXPECT_TRUE(moved.device().is_cuda());
+
+  EXPECT_EQ(original.data(), nullptr);
+  EXPECT_EQ(original.nbytes(), 0);
+  EXPECT_FALSE(original.is_owning());
+
+  DeviceTraits<c10::DeviceType::CUDA>::free(external_data);
+}
+
+TEST(MaybeOwningStorageCUDANonOwningTest, MoveAssign) {
+  constexpr size_t kNbytes1 = 256;
+  constexpr size_t kNbytes2 = 512;
+
+  // Create two external CUDA buffers
+  float* external_data1 =
+      static_cast<float*>(DeviceTraits<c10::DeviceType::CUDA>::allocate(
+          kNbytes1, DEFAULT_CUDA_DEVICE));
+  float* external_data2 =
+      static_cast<float*>(DeviceTraits<c10::DeviceType::CUDA>::allocate(
+          kNbytes2, DEFAULT_CUDA_DEVICE));
+
+  MaybeOwningStorage storage1(DEFAULT_CUDA_DEVICE, external_data1, kNbytes1);
+  MaybeOwningStorage storage2(DEFAULT_CUDA_DEVICE, external_data2, kNbytes2);
+
+  storage1 = std::move(storage2);
+
+  EXPECT_EQ(storage1.data(), external_data2);
+  EXPECT_EQ(storage1.nbytes(), kNbytes2);
+  EXPECT_FALSE(storage1.is_owning());
+  EXPECT_TRUE(storage1.device().is_cuda());
+
+  EXPECT_EQ(storage2.data(), nullptr);
+  EXPECT_EQ(storage2.nbytes(), 0);
+  EXPECT_FALSE(storage2.is_owning());
+
+  // Clean up both external buffers
+  DeviceTraits<c10::DeviceType::CUDA>::free(external_data1);
+  DeviceTraits<c10::DeviceType::CUDA>::free(external_data2);
+}
+
+TEST(MaybeOwningStorageCUDANonOwningTest, ZeroBytes) {
+  // Non-owning with nullptr and zero bytes
+  MaybeOwningStorage storage(DEFAULT_CUDA_DEVICE, nullptr, 0);
+
+  EXPECT_EQ(storage.data(), nullptr);
+  EXPECT_EQ(storage.nbytes(), 0);
+  EXPECT_FALSE(storage.is_owning());
+  EXPECT_TRUE(storage.device().is_cuda());
+}
+
+// =============================================================================
+// Storage (SharedPtr) CUDA Tests
+// =============================================================================
+
+TEST(StorageSharedPtrCUDATest, BasicUsage) {
+  constexpr size_t kNbytes = 128;
+  Storage storage(new MaybeOwningStorage(DEFAULT_CUDA_DEVICE, kNbytes));
+
+  EXPECT_NE(storage.get(), nullptr);
+  EXPECT_NE(storage->data(), nullptr);
+  EXPECT_EQ(storage->nbytes(), kNbytes);
+  EXPECT_TRUE(storage->device().is_cuda());
+  EXPECT_EQ(storage.use_count(), 1);
+}
+
+TEST(StorageSharedPtrCUDATest, SharedOwnership) {
+  constexpr size_t kNbytes = 128;
+  Storage storage1(new MaybeOwningStorage(DEFAULT_CUDA_DEVICE, kNbytes));
+  void* data_ptr = storage1->data();
+
+  Storage storage2 = storage1;
+
+  EXPECT_EQ(storage1.use_count(), 2);
+  EXPECT_EQ(storage2.use_count(), 2);
+  EXPECT_EQ(storage1->data(), storage2->data());
+  EXPECT_EQ(storage2->data(), data_ptr);
+}
+
+TEST(StorageSharedPtrCUDATest, MoveSemantics) {
+  constexpr size_t kNbytes = 64;
+  Storage storage1(new MaybeOwningStorage(DEFAULT_CUDA_DEVICE, kNbytes));
+  void* data_ptr = storage1->data();
+
+  Storage storage2 = std::move(storage1);
+
+  EXPECT_EQ(storage1.get(), nullptr);
+  EXPECT_EQ(storage2->data(), data_ptr);
+  EXPECT_EQ(storage2.use_count(), 1);
+}
+
 #endif // CUDA_AVAILABLE
 
 } // namespace executorch::backends::aoti::slim
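For reference, the non-owning constructor exercised above is the building block for the
from_blob()-style factory mentioned in this patch's commit message. The helper below is a
minimal sketch of that pattern and is not part of the patch: the name wrap_external_buffer,
its parameter list, and the caller-supplied nbytes are assumptions made for illustration,
while the MaybeOwningStorage, Storage, and SlimTensor constructor calls mirror the tests in
this diff.

// Sketch only (assumed helper, not in this PR): wrap user-provided memory in a
// non-owning storage and view it as a SlimTensor. Because the non-owning
// constructor installs the no-op deleter, destroying the tensor never frees
// the external buffer.
inline SlimTensor wrap_external_buffer(
    void* data,
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides,
    c10::ScalarType dtype,
    size_t nbytes, // caller-computed size of the external buffer in bytes
    const c10::Device& device) {
  // Non-owning mode: storage->is_owning() will report false.
  Storage storage(new MaybeOwningStorage(device, data, nbytes));
  return SlimTensor(
      std::move(storage),
      makeArrayRef(sizes),
      makeArrayRef(strides),
      dtype);
}

// Example usage, mirroring the CPU non-owning tests above:
//   float buffer[6] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
//   SlimTensor t = wrap_external_buffer(
//       buffer, {2, 3}, {3, 1}, c10::ScalarType::Float,
//       sizeof(buffer), CPU_DEVICE);
//   // t.data_ptr() == buffer; destroying t leaves buffer untouched.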