diff --git a/MODULE.bazel b/MODULE.bazel index 50381160aa..bcbb1efa1c 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -26,6 +26,13 @@ new_local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local. local_torch = use_repo_rule("//toolchains:local_torch.bzl", "local_torch") +local_executorch = use_repo_rule("//toolchains:local_executorch.bzl", "local_executorch") + +# Detect the locally installed ExecuTorch source tree at build time. +# Set EXECUTORCH_PATH to the directory containing runtime/, extension/, cmake-out/. +# Requires cmake-out/libexecutorch_core.a to be built first. +local_executorch(name = "executorch") + # External dependency for torch_tensorrt if you already have precompiled binaries. new_local_repository( name = "torch_tensorrt", @@ -77,6 +84,15 @@ local_torch(name = "libtorch") # build_file = "third_party/libtorch/BUILD" #) +# ExecuTorch source tree. The repository root is the *parent* of the +# executorch/ directory so that headers resolve as . +# Requires a cmake-out/ build inside the executorch source tree. +#new_local_repository( +# name = "executorch", +# build_file = "@//third_party/executorch:BUILD", +# path = "/home/lanl/git/executorch", +#) + #new_local_repository( # name = "tensorrt", # path = "/usr/", diff --git a/core/runtime/BUILD b/core/runtime/BUILD index 19260149ae..ef6ae94e00 100644 --- a/core/runtime/BUILD +++ b/core/runtime/BUILD @@ -1,6 +1,7 @@ load("@rules_cc//cc:defs.bzl", "cc_library") load("@rules_pkg//:pkg.bzl", "pkg_tar") load("@rules_pkg//pkg:mappings.bzl", "pkg_files") + package(default_visibility = ["//visibility:public"]) config_setting( @@ -58,8 +59,13 @@ config_setting( ], ) +# runtime_base: TRTEngine + device management + serialization utilities. 
+# Does NOT include register_jit_hooks.cpp (TorchScript torch::class_ / +# TORCH_LIBRARY registrations), so it can be linked into +# libtrt_executorch_backend.so without causing a duplicate-registration +# crash when libtorchtrt.so is also loaded in the same process. cc_library( - name = "runtime", + name = "runtime_base", srcs = [ "DeviceList.cpp", "Platform.cpp", @@ -67,8 +73,8 @@ cc_library( "TRTEngine.cpp", "TRTEngineProfiler.cpp", "execute_engine.cpp", - "register_jit_hooks.cpp", "runtime.cpp", + "runtime_utils.cpp", ], hdrs = [ "Platform.h", @@ -100,6 +106,26 @@ cc_library( alwayslink = True, ) +# runtime: full runtime including TorchScript torch::class_ / TORCH_LIBRARY +# registrations. Used by the main libtorchtrt.so. +cc_library( + name = "runtime", + srcs = [ + "register_jit_hooks.cpp", + ], + hdrs = [ + "Platform.h", + "RTDevice.h", + "TRTEngine.h", + "TRTEngineProfiler.h", + "runtime.h", + ], + deps = [ + ":runtime_base", + ], + alwayslink = True, +) + filegroup( name = "include_files", srcs = [ @@ -121,6 +147,6 @@ pkg_tar( pkg_files( name = "include_pkg_files", srcs = [":include_files"], - visibility = ["//visibility:public"], prefix = "include/torch_tensorrt/core/runtime/", + visibility = ["//visibility:public"], ) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index d29daa112b..bc96a7a8ca 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -447,6 +447,29 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) { } void TRTEngine::verify_serialization_fmt(const std::vector& serialized_info) { + static const char* kIndexNames[] = { + "ABI_TARGET_IDX", + "NAME_IDX", + "DEVICE_IDX", + "ENGINE_IDX", + "INPUT_BINDING_NAMES_IDX", + "OUTPUT_BINDING_NAMES_IDX", + "HW_COMPATIBLE_IDX", + "SERIALIZED_METADATA_IDX", + "TARGET_PLATFORM_IDX", + "REQUIRES_OUTPUT_ALLOCATOR_IDX", + "RESOURCE_ALLOCATION_STRATEGY_IDX", + }; + fprintf(stderr, "[verify_serialization_fmt] %zu entries (expected %d):\n", serialized_info.size(), 
SERIALIZATION_LEN); + for (size_t i = 0; i < serialized_info.size(); ++i) { + const char* name = (i < sizeof(kIndexNames) / sizeof(kIndexNames[0])) ? kIndexNames[i] : "?"; + if (i == ENGINE_IDX) { + fprintf(stderr, " [%zu] %-35s = \n", i, name, serialized_info[i].size()); + } else { + fprintf(stderr, " [%zu] %-35s = \"%s\"\n", i, name, serialized_info[i].c_str()); + } + } + TORCHTRT_CHECK( serialized_info.size() == SERIALIZATION_LEN, "Program to be deserialized targets an incompatible Torch-TensorRT ABI"); diff --git a/core/runtime/executorch/BUILD b/core/runtime/executorch/BUILD new file mode 100644 index 0000000000..93b5cf9b2e --- /dev/null +++ b/core/runtime/executorch/BUILD @@ -0,0 +1,82 @@ +load("@rules_cc//cc:defs.bzl", "cc_library") + +package(default_visibility = ["//visibility:public"]) + +config_setting( + name = "use_torch_whl", + flag_values = { + "//toolchains/dep_src:torch": "whl", + }, +) + +config_setting( + name = "rtx_x86_64", + constraint_values = [ + "@platforms//cpu:x86_64", + "@platforms//os:linux", + ], + flag_values = { + "//toolchains/dep_collection:compute_libs": "rtx", + }, +) + +config_setting( + name = "rtx_win", + constraint_values = [ + "@platforms//os:windows", + ], + flag_values = { + "//toolchains/dep_collection:compute_libs": "rtx", + }, +) + +config_setting( + name = "sbsa", + constraint_values = [ + "@platforms//cpu:aarch64", + ], + flag_values = { + "//toolchains/dep_collection:compute_libs": "default", + }, +) + +config_setting( + name = "jetpack", + constraint_values = [ + "@platforms//cpu:aarch64", + ], + flag_values = { + "//toolchains/dep_collection:compute_libs": "jetpack", + }, +) + +config_setting( + name = "windows", + constraint_values = [ + "@platforms//os:windows", + ], +) + +cc_library( + name = "tensorrt_executorch_backend", + srcs = ["TensorRTBackend.cpp"], + hdrs = ["TensorRTBackend.h"], + # Use executorch_headers (no static link) so that register_backend / + # find_backend / FreeableBuffer etc. 
remain undefined symbols in the + # final libtrt_executorch_backend.so. At runtime they are resolved from + # libqnn_executorch_backend.so, which _portable_lib.so loads with + # RTLD_GLOBAL, ensuring both share the same registry instance. + deps = [ + "//core/runtime:runtime_base", + "//core/util:prelude", + "@executorch//:executorch_headers", + ] + select({ + ":jetpack": ["@tensorrt_l4t//:nvinfer"], + ":rtx_win": ["@tensorrt_rtx_win//:nvinfer"], + ":rtx_x86_64": ["@tensorrt_rtx//:nvinfer"], + ":sbsa": ["@tensorrt_sbsa//:nvinfer"], + ":windows": ["@tensorrt_win//:nvinfer"], + "//conditions:default": ["@tensorrt//:nvinfer"], + }), + alwayslink = True, +) diff --git a/core/runtime/executorch/TensorRTBackend.cpp b/core/runtime/executorch/TensorRTBackend.cpp index 93f97dc8ce..ea8e76787f 100644 --- a/core/runtime/executorch/TensorRTBackend.cpp +++ b/core/runtime/executorch/TensorRTBackend.cpp @@ -19,6 +19,11 @@ #include #include +// RTDevice and Platform must be included before TRTEngine.h because TRTEngine.h +// references them without including their headers directly (Bazel handles this +// via transitive deps, but a standalone compile needs them explicit). +#include "core/runtime/Platform.h" +#include "core/runtime/RTDevice.h" #include "core/runtime/TRTEngine.h" #include "core/util/prelude.h" @@ -116,7 +121,18 @@ bool TensorRTBackend::is_available() const { // provided by the ExecuTorch MemoryAllocator so that ExecuTorch owns the // lifetime; destroy() calls the destructor explicitly. 
// --------------------------------------------------------------------------- -Result TensorRTBackend::init(BackendInitContext& context, FreeableBuffer* processed) const { +Result TensorRTBackend::init( + BackendInitContext& context, + FreeableBuffer* processed, + ArrayRef compile_specs) const { + (void)compile_specs; + ET_LOG(Info, "TensorRTBackend::init: enter"); + + if (!is_available()) { + ET_LOG(Error, "TensorRT backend is not available"); + return Error::NotSupported; + } + if (processed == nullptr || processed->data() == nullptr) { ET_LOG(Error, "TensorRTBackend::init: null processed buffer"); return Error::InvalidArgument; @@ -125,19 +141,31 @@ Result TensorRTBackend::init(BackendInitContext& context, Freea auto serialized_info = deserialize_engine_info(processed->data(), processed->size()); if (serialized_info.empty()) { + fprintf(stderr, "[TensorRTBackend::init] FAIL: deserialize_engine_info returned empty\n"); ET_LOG(Error, "TensorRTBackend::init: failed to deserialize engine blob"); return Error::InvalidArgument; } + ET_LOG(Info, "TensorRTBackend::init: deserialized %zu entries", serialized_info.size()); // Validate the vector length before handing to TRTEngine // (verify_serialization_fmt throws on mismatch) - core::runtime::TRTEngine::verify_serialization_fmt(serialized_info); + ET_LOG(Info, "TensorRTBackend::init: calling verify_serialization_fmt"); + try { + core::runtime::TRTEngine::verify_serialization_fmt(serialized_info); + } catch (const std::exception& e) { + ET_LOG(Error, "TensorRTBackend::init: verify_serialization_fmt threw: %s", e.what()); + return Error::InvalidArgument; + } catch (...) 
{ + ET_LOG(Error, "TensorRTBackend::init: verify_serialization_fmt threw unknown exception"); + return Error::InvalidArgument; + } MemoryAllocator* allocator = context.get_runtime_allocator(); if (allocator == nullptr) { ET_LOG(Error, "TensorRTBackend::init: null runtime allocator"); return Error::InvalidState; } + ET_LOG(Info, "TensorRTBackend::init: got allocator"); // Allocate raw storage for TRTEngine from ExecuTorch's arena core::runtime::TRTEngine* engine = allocator->allocateInstance(); @@ -145,11 +173,23 @@ Result TensorRTBackend::init(BackendInitContext& context, Freea ET_LOG(Error, "TensorRTBackend::init: allocateInstance failed"); return Error::MemoryAllocationFailed; } + ET_LOG(Info, "TensorRTBackend::init: allocated engine storage at %p", (void*)engine); // Construct in-place; TRTEngine(std::vector) deserializes the // engine bytes, builds the IRuntime/ICudaEngine/IExecutionContext, and // populates in_binding_names / out_binding_names / num_io. - new (engine) core::runtime::TRTEngine(std::move(serialized_info)); + ET_LOG(Info, "TensorRTBackend::init: constructing TRTEngine in-place"); + try { + new (engine) core::runtime::TRTEngine(std::move(serialized_info)); + } catch (const std::exception& e) { + fprintf(stderr, "[TensorRTBackend::init] FAIL: TRTEngine constructor threw: %s\n", e.what()); + ET_LOG(Error, "TensorRTBackend::init: TRTEngine constructor threw: %s", e.what()); + return Error::InvalidArgument; + } catch (...) 
{ + fprintf(stderr, "[TensorRTBackend::init] FAIL: TRTEngine constructor threw unknown exception\n"); + ET_LOG(Error, "TensorRTBackend::init: TRTEngine constructor threw unknown exception"); + return Error::InvalidArgument; + } // Release the blob; we no longer need it processed->Free(); @@ -178,28 +218,58 @@ Result TensorRTBackend::init(BackendInitContext& context, Freea // --------------------------------------------------------------------------- Error TensorRTBackend::execute(BackendExecutionContext& context, DelegateHandle* handle, Span args) const { (void)context; + fprintf(stderr, "[TensorRTBackend::execute] enter: handle=%p args.size()=%zu\n", (void*)handle, args.size()); + ET_LOG(Info, "TensorRTBackend::execute: enter"); if (handle == nullptr) { ET_LOG(Error, "TensorRTBackend::execute: null delegate handle"); return Error::InvalidArgument; } - + ET_LOG(Info, "TensorRTBackend::execute: got delegate handle"); auto* engine = static_cast(handle); const size_t num_inputs = engine->num_io.first; const size_t num_outputs = engine->num_io.second; + ET_LOG(Info, "TensorRTBackend::execute: got num_inputs %zu and num_outputs %zu", num_inputs, num_outputs); if (args.size() < num_inputs + num_outputs) { ET_LOG( Error, "TensorRTBackend::execute: expected at least %zu args, got %zu", num_inputs + num_outputs, args.size()); return Error::InvalidArgument; } - + ET_LOG(Info, "TensorRTBackend::execute: got engine"); // IExecutionContext::enqueueV3 is not thread-safe; use the engine mutex std::unique_lock lock(engine->mu); nvinfer1::IExecutionContext* ctx = engine->exec_ctx.get(); + cudaStream_t stream = c10::cuda::getCurrentCUDAStream(static_cast(engine->device_info.id)); + + // ExecuTorch's portable runtime pre-allocates output tensors as CPU buffers. + // TRT requires CUDA device pointers for all bindings. We use + // cudaPointerGetAttributes to detect CPU pointers and stage them through + // temporary CUDA allocations, copying back after inference. 
+ auto is_cuda_ptr = [](const void* ptr) -> bool { + if (ptr == nullptr) + return false; + cudaPointerAttributes attrs{}; + cudaError_t err = cudaPointerGetAttributes(&attrs, ptr); + return err == cudaSuccess && attrs.type == cudaMemoryTypeDevice; + }; + + std::vector temp_input_bufs(num_inputs, nullptr); + std::vector temp_output_bufs(num_outputs, nullptr); + + // Cleanup helper – called on every return path. + auto free_temp = [&]() { + for (void* p : temp_input_bufs) + if (p) + cudaFree(p); + for (void* p : temp_output_bufs) + if (p) + cudaFree(p); + }; + // ------------------------------------------------------------------ // 1. Bind input shapes and addresses // ------------------------------------------------------------------ @@ -207,6 +277,7 @@ Error TensorRTBackend::execute(BackendExecutionContext& context, DelegateHandle* EValue* arg = args[i]; if (arg == nullptr || !arg->isTensor()) { ET_LOG(Error, "TensorRTBackend::execute: input %zu is not a tensor", i); + free_temp(); return Error::InvalidArgument; } @@ -216,18 +287,31 @@ Error TensorRTBackend::execute(BackendExecutionContext& context, DelegateHandle* if (!ctx->setInputShape(name.c_str(), dims)) { ET_LOG(Error, "TensorRTBackend::execute: setInputShape failed for '%s'", name.c_str()); + free_temp(); return Error::InvalidState; } - void* ptr = et_in.mutable_data_ptr(); - // TRT requires a non-null address even for 0-element tensors + void* src_ptr = et_in.mutable_data_ptr(); + void* trt_ptr = src_ptr; + static char placeholder[16] = {}; - if (ptr == nullptr || et_in.numel() == 0) { - ptr = placeholder; + if (src_ptr == nullptr || et_in.numel() == 0) { + trt_ptr = placeholder; + } else if (!is_cuda_ptr(src_ptr)) { + // CPU input: stage to a temporary CUDA buffer + size_t nbytes = et_in.nbytes(); + if (cudaMalloc(&temp_input_bufs[i], nbytes) != cudaSuccess) { + ET_LOG(Error, "TensorRTBackend::execute: cudaMalloc failed for input %zu", i); + free_temp(); + return Error::MemoryAllocationFailed; + } + 
cudaMemcpyAsync(temp_input_bufs[i], src_ptr, nbytes, cudaMemcpyHostToDevice, stream); + trt_ptr = temp_input_bufs[i]; } - if (!ctx->setTensorAddress(name.c_str(), ptr)) { + if (!ctx->setTensorAddress(name.c_str(), trt_ptr)) { ET_LOG(Error, "TensorRTBackend::execute: setTensorAddress failed for input '%s'", name.c_str()); + free_temp(); return Error::InvalidState; } } @@ -241,26 +325,58 @@ Error TensorRTBackend::execute(BackendExecutionContext& context, DelegateHandle* const int32_t n_unresolved = ctx->inferShapes(io_size, unresolved.data()); if (n_unresolved != 0) { ET_LOG(Error, "TensorRTBackend::execute: inferShapes could not resolve %d tensor(s)", n_unresolved); + free_temp(); return Error::InvalidState; } } // ------------------------------------------------------------------ - // 3. Bind output addresses (ExecuTorch pre-allocates the buffers) + // 3. Bind output addresses + // ExecuTorch pre-allocates output tensors at the maximum shape for + // dynamic models. After inferShapes() TRT knows the actual output + // dims, so update the ExecuTorch TensorImpl's sizes before computing + // nbytes() and before the Python binding reads back the shape. + // If the buffer is CPU, stage through a temporary CUDA allocation. // ------------------------------------------------------------------ for (size_t o = 0; o < num_outputs; ++o) { EValue* arg = args[num_inputs + o]; if (arg == nullptr || !arg->isTensor()) { ET_LOG(Error, "TensorRTBackend::execute: output %zu is not a tensor", o); + free_temp(); return Error::InvalidArgument; } exec_aten::Tensor et_out = arg->toTensor(); const std::string& name = engine->out_binding_names[o]; - void* ptr = et_out.mutable_data_ptr(); - if (!ctx->setTensorAddress(name.c_str(), ptr)) { + // Update the ExecuTorch tensor shape to the actual TRT output shape. + // getTensorShape() is valid after inferShapes() has been called. 
+ nvinfer1::Dims actual_dims = ctx->getTensorShape(name.c_str()); + if (actual_dims.nbDims > 0) { + exec_aten::SizesType new_sizes[nvinfer1::Dims::MAX_DIMS]; + for (int d = 0; d < actual_dims.nbDims; ++d) { + new_sizes[d] = static_cast(actual_dims.d[d]); + } + et_out.unsafeGetTensorImpl()->set_sizes_contiguous({new_sizes, static_cast(actual_dims.nbDims)}); + } + + void* dst_ptr = et_out.mutable_data_ptr(); + void* trt_ptr = dst_ptr; + + if (!is_cuda_ptr(dst_ptr)) { + // CPU output buffer: allocate temporary CUDA memory for TRT to write into + size_t nbytes = et_out.nbytes(); // uses updated shape + if (cudaMalloc(&temp_output_bufs[o], nbytes) != cudaSuccess) { + ET_LOG(Error, "TensorRTBackend::execute: cudaMalloc failed for output %zu", o); + free_temp(); + return Error::MemoryAllocationFailed; + } + trt_ptr = temp_output_bufs[o]; + } + + if (!ctx->setTensorAddress(name.c_str(), trt_ptr)) { ET_LOG(Error, "TensorRTBackend::execute: setTensorAddress failed for output '%s'", name.c_str()); + free_temp(); return Error::InvalidState; } } @@ -268,16 +384,26 @@ Error TensorRTBackend::execute(BackendExecutionContext& context, DelegateHandle* // ------------------------------------------------------------------ // 4. Enqueue inference on the current CUDA stream // ------------------------------------------------------------------ - cudaStream_t stream = c10::cuda::getCurrentCUDAStream(static_cast(engine->device_info.id)); - if (!ctx->enqueueV3(stream)) { ET_LOG(Error, "TensorRTBackend::execute: enqueueV3 failed"); + free_temp(); return Error::InvalidState; } - // Synchronize so that outputs are visible to downstream ExecuTorch ops + // ------------------------------------------------------------------ + // 5. 
Copy temporary CUDA outputs back to the ExecuTorch CPU buffers + // ------------------------------------------------------------------ + for (size_t o = 0; o < num_outputs; ++o) { + if (temp_output_bufs[o] != nullptr) { + exec_aten::Tensor et_out = args[num_inputs + o]->toTensor(); + cudaMemcpyAsync(et_out.mutable_data_ptr(), temp_output_bufs[o], et_out.nbytes(), cudaMemcpyDeviceToHost, stream); + } + } + + // Synchronize so outputs are visible to downstream ExecuTorch ops cudaStreamSynchronize(stream); + free_temp(); return Error::Ok; } diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index e8f6217a21..bffb50f03e 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -4,67 +4,14 @@ #include "core/runtime/runtime.h" #include "core/util/macros.h" +// serialize_bindings / base64_encode / base64_decode are defined in +// runtime_utils.cpp so the ExecuTorch backend can link them without +// pulling in the torch::class_ registration below. 
+ namespace torch_tensorrt { namespace core { namespace runtime { -std::string serialize_bindings(const std::vector& bindings) { - std::stringstream ss; - for (size_t i = 0; i < bindings.size() - 1; i++) { - ss << bindings[i] << TRTEngine::BINDING_DELIM; - } - ss << bindings[bindings.size() - 1]; - - std::string serialized_binding_info = ss.str(); - - LOG_DEBUG("Serialized Binding Info: " << serialized_binding_info); - - return serialized_binding_info; -} - -static const std::string sym_table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; //= -std::string base64_encode(const std::string& in) { - std::string out; - int64_t val = 0, valb = -6; - for (unsigned char c : in) { - val = (val << 8) + c; - valb += 8; - while (valb >= 0) { - out.push_back(sym_table[(val >> valb) & 0x3F]); - valb -= 6; - } - } - if (valb > -6) { - out.push_back(sym_table[((val << 8) >> (valb + 8)) & 0x3F]); - }; - while (out.size() % 4) { - out.push_back('='); - } - return out; -} - -std::string base64_decode(const std::string& in) { - std::string out; - std::vector T(256, -1); - for (int i = 0; i < 64; i++) { - T[sym_table[i]] = i; - } - - int64_t val = 0, valb = -8; - for (unsigned char c : in) { - if (T[c] == -1) { - break; - } - val = (val << 6) + T[c]; - valb += 6; - if (valb >= 0) { - out.push_back(char((val >> valb) & 0xFF)); - valb -= 8; - } - } - return out; -} - namespace { // TODO: Implement a call method // c10::List TRTEngine::Run(c10::List inputs) { diff --git a/core/runtime/runtime_utils.cpp b/core/runtime/runtime_utils.cpp new file mode 100644 index 0000000000..77b3df07ba --- /dev/null +++ b/core/runtime/runtime_utils.cpp @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * Serialization utilities shared by TRTEngine and register_jit_hooks. 
+ * Kept in a separate translation unit so that the ExecuTorch backend + * (libtrt_executorch_backend.so) can link these without pulling in the + * TorchScript torch::class_ / TORCH_LIBRARY registrations in + * register_jit_hooks.cpp, which would cause a duplicate-registration + * crash when libtorchtrt.so is also loaded in the same process. + */ + +#include +#include +#include + +#include "core/runtime/runtime.h" +#include "core/util/macros.h" + +namespace torch_tensorrt { +namespace core { +namespace runtime { + +std::string serialize_bindings(const std::vector& bindings) { + std::stringstream ss; + for (size_t i = 0; i < bindings.size() - 1; i++) { + ss << bindings[i] << TRTEngine::BINDING_DELIM; + } + ss << bindings[bindings.size() - 1]; + + std::string serialized_binding_info = ss.str(); + + LOG_DEBUG("Serialized Binding Info: " << serialized_binding_info); + + return serialized_binding_info; +} + +// Base64 alphabet (RFC 4648 §4) +static const std::string sym_table = // NOLINT(cert-err58-cpp) + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; //= + +std::string base64_encode(const std::string& in) { + std::string out; + int64_t val = 0, valb = -6; + for (unsigned char c : in) { + val = (val << 8) + c; + valb += 8; + while (valb >= 0) { + out.push_back(sym_table[(val >> valb) & 0x3F]); + valb -= 6; + } + } + if (valb > -6) { + out.push_back(sym_table[((val << 8) >> (valb + 8)) & 0x3F]); + } + while (out.size() % 4) { + out.push_back('='); + } + return out; +} + +std::string base64_decode(const std::string& in) { + std::string out; + std::vector T(256, -1); + for (int i = 0; i < 64; i++) { + T[sym_table[i]] = i; + } + + int64_t val = 0, valb = -8; + for (unsigned char c : in) { + if (T[c] == -1) { + break; + } + val = (val << 6) + T[c]; + valb += 6; + if (valb >= 0) { + out.push_back(char((val >> valb) & 0xFF)); + valb -= 8; + } + } + return out; +} + +} // namespace runtime +} // namespace core +} // namespace torch_tensorrt diff --git 
a/cpp/lib/BUILD b/cpp/lib/BUILD index 9054cd93d0..cdbb40b825 100644 --- a/cpp/lib/BUILD +++ b/cpp/lib/BUILD @@ -1,5 +1,6 @@ load("@rules_cc//cc:defs.bzl", "cc_binary") load("@rules_pkg//pkg:mappings.bzl", "pkg_files") + package(default_visibility = ["//visibility:public"]) cc_binary( @@ -54,10 +55,25 @@ cc_binary( ], ) +cc_binary( + name = "libtrt_executorch_backend.so", + srcs = [], + # Allow undefined symbols: ExecuTorch runtime functions (register_backend, + # find_backend, FreeableBuffer, …) are intentionally left unresolved at + # link time. They are resolved at dlopen() time from + # libqnn_executorch_backend.so, which _portable_lib.so has already loaded + # into the process with RTLD_GLOBAL. + linkopts = ["-Wl,--allow-shlib-undefined"], + linkshared = True, + linkstatic = True, + deps = [ + "//core/runtime/executorch:tensorrt_executorch_backend", + ], +) pkg_files( name = "lib_pkg_files", srcs = ["torchtrt.dll"], - visibility = ["//visibility:public"], prefix = "lib/", -) \ No newline at end of file + visibility = ["//visibility:public"], +) diff --git a/examples/torchtrt_executorch_example/BUILD b/examples/torchtrt_executorch_example/BUILD new file mode 100644 index 0000000000..4a26127858 --- /dev/null +++ b/examples/torchtrt_executorch_example/BUILD @@ -0,0 +1,33 @@ +load("@rules_cc//cc:defs.bzl", "cc_binary") + +package(default_visibility = ["//visibility:public"]) + +# C++ inference runner for .pte files compiled with Torch-TensorRT. +# Mirrors load_static_shape.py but drives the ExecuTorch C++ runtime API +# directly instead of Python pybindings. +# +# Build: +# bazel build //examples/torchtrt_executorch_example:trt_executor_runner +# +# Run: +# ./bazel-bin/examples/torchtrt_executorch_example/trt_executor_runner \ +# --model_path=model.pte [--num_runs=1] +cc_binary( + name = "trt_executor_runner", + srcs = ["executor_runner.cpp"], + deps = [ + # Headers for ExecuTorch runtime/ and extension/ (). 
+ "@executorch//:executorch_headers", + # Static ExecuTorch C++ runtime: Program, Method, MethodMeta, + # runtime_init, HierarchicalAllocator, MemoryManager, … + "@executorch//:executorch_core", + # FileDataLoader (extension/data_loader/file_data_loader.cpp). + "@executorch//:executorch_file_data_loader", + # libqnn_executorch_backend.so — carries the 6 backend-registry symbols + # that TensorRTBackend.o leaves undefined (register_backend, vlogf, …). + # Also needed at runtime: load it with RTLD_GLOBAL before executing. + "@executorch//:executorch_runtime", + # Registers TensorRTBackend via static initialiser (alwayslink = True). + "//core/runtime/executorch:tensorrt_executorch_backend", + ], +) diff --git a/examples/torchtrt_executorch_example/executor_runner.cpp b/examples/torchtrt_executorch_example/executor_runner.cpp new file mode 100644 index 0000000000..21ae949666 --- /dev/null +++ b/examples/torchtrt_executorch_example/executor_runner.cpp @@ -0,0 +1,249 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * C++ inference runner for .pte files compiled with Torch-TensorRT. + * Mirrors load_static_shape.py but uses the ExecuTorch C++ runtime API + * directly instead of the Python pybindings. + * + * Usage: + * trt_executor_runner --model_path=model.pte [--num_runs=1] + * + * The runner fills all inputs with ones, runs inference, and prints the + * output shape and first/last values. 
+ */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using executorch::extension::FileDataLoader; +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::HierarchicalAllocator; +using executorch::runtime::MemoryAllocator; +using executorch::runtime::MemoryManager; +using executorch::runtime::Method; +using executorch::runtime::MethodMeta; +using executorch::runtime::Program; +using executorch::runtime::Result; +using executorch::runtime::Span; +using executorch::runtime::TensorInfo; + +// ExecuTorch does not use malloc; all memory comes from these static pools. +static uint8_t method_allocator_pool[4 * 1024U * 1024U]; // 4 MB +static uint8_t temp_allocator_pool[1 * 1024U * 1024U]; // 1 MB + +// --------------------------------------------------------------------------- +// Simple arg parser (avoids gflags dependency) +// --------------------------------------------------------------------------- +static const char* get_flag(int argc, char** argv, const char* flag, const char* def) { + const size_t n = strlen(flag); + for (int i = 1; i < argc; ++i) { + if (strncmp(argv[i], flag, n) == 0 && argv[i][n] == '=') { + return argv[i] + n + 1; + } + } + return def; +} + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- +int main(int argc, char** argv) { + executorch::runtime::runtime_init(); + + const char* model_path = get_flag(argc, argv, "--model_path", "model.pte"); + const int num_runs = atoi(get_flag(argc, argv, "--num_runs", "1")); + + // ------------------------------------------------------------------ + // 1. 
Load the .pte file + // ------------------------------------------------------------------ + Result loader_result = FileDataLoader::from(model_path); + if (!loader_result.ok()) { + ET_LOG( + Error, + "FileDataLoader::from('%s') failed: 0x%" PRIx32, + model_path, + static_cast(loader_result.error())); + return 1; + } + auto loader = std::make_unique(std::move(loader_result.get())); + + Result program = Program::load(loader.get()); + if (!program.ok()) { + ET_LOG(Error, "Failed to parse model '%s'", model_path); + return 1; + } + ET_LOG(Info, "Model '%s' loaded.", model_path); + + // ------------------------------------------------------------------ + // 2. Identify the method to run (use the first one, typically "forward") + // ------------------------------------------------------------------ + const char* method_name = nullptr; + { + auto name_result = program->get_method_name(0); + ET_CHECK_MSG(name_result.ok(), "Program has no methods"); + method_name = *name_result; + } + ET_LOG(Info, "Method: '%s'", method_name); + + // ------------------------------------------------------------------ + // 3. Inspect memory requirements via MethodMeta + // ------------------------------------------------------------------ + Result method_meta = program->method_meta(method_name); + ET_CHECK_MSG( + method_meta.ok(), + "method_meta('%s') failed: 0x%" PRIx32, + method_name, + static_cast(method_meta.error())); + + // ------------------------------------------------------------------ + // 4. 
Allocate memory + // - method_allocator: C++ metadata objects for the loaded Method + // - planned_memory: mutable tensor data (sizes from the .pte) + // - temp_allocator: scratch space for kernel temporaries + // ------------------------------------------------------------------ + MemoryAllocator method_allocator{MemoryAllocator(sizeof(method_allocator_pool), method_allocator_pool)}; + MemoryAllocator temp_allocator{MemoryAllocator(sizeof(temp_allocator_pool), temp_allocator_pool)}; + + std::vector> planned_buffers; + std::vector> planned_spans; + const size_t num_planned = method_meta->num_memory_planned_buffers(); + for (size_t i = 0; i < num_planned; ++i) { + const size_t sz = static_cast(method_meta->memory_planned_buffer_size(i).get()); + ET_LOG(Info, " planned buffer[%zu] = %zu bytes", i, sz); + planned_buffers.push_back(std::make_unique(sz)); + planned_spans.push_back({planned_buffers.back().get(), sz}); + } + HierarchicalAllocator planned_memory{{planned_spans.data(), planned_spans.size()}}; + MemoryManager memory_manager{&method_allocator, &planned_memory, &temp_allocator}; + + // ------------------------------------------------------------------ + // 5. Load the method (this triggers TensorRTBackend::init for any + // TRT delegate sub-graphs in the .pte) + // ------------------------------------------------------------------ + Result method = program->load_method(method_name, &memory_manager, /*event_tracer=*/nullptr); + ET_CHECK_MSG(method.ok(), "load_method('%s') failed: 0x%" PRIx32, method_name, static_cast(method.error())); + ET_LOG(Info, "Method loaded. inputs=%zu outputs=%zu", method->inputs_size(), method->outputs_size()); + + // ------------------------------------------------------------------ + // 6. Prepare input tensors (allocate + fill with 1.0f) + // We create one float32 buffer per input, sized from MethodMeta, + // build a TensorImpl, and call method->set_input(). 
+ // ------------------------------------------------------------------ + const size_t num_inputs = method_meta->num_inputs(); + // These buffers must outlive the execution loop. + std::vector> input_data(num_inputs); + std::vector> input_sizes(num_inputs); + std::vector> input_dim_order(num_inputs); + std::vector> input_strides(num_inputs); + std::vector input_impls; + input_impls.reserve(num_inputs); + + for (size_t i = 0; i < num_inputs; ++i) { + Result tensor_info = method_meta->input_tensor_meta(i); + ET_CHECK_MSG( + tensor_info.ok(), "input_tensor_meta(%zu) failed: 0x%" PRIx32, i, static_cast(tensor_info.error())); + + // Copy sizes and compute strides (row-major / contiguous) + const auto& sizes_ref = tensor_info->sizes(); + const ssize_t ndim = static_cast(sizes_ref.size()); + + input_sizes[i].assign(sizes_ref.begin(), sizes_ref.end()); + input_dim_order[i].resize(ndim); + input_strides[i].resize(ndim); + for (ssize_t d = 0; d < ndim; ++d) { + input_dim_order[i][d] = static_cast(d); + } + exec_aten::StridesType stride = 1; + for (ssize_t d = ndim - 1; d >= 0; --d) { + input_strides[i][d] = stride; + stride *= static_cast(input_sizes[i][d]); + } + + const size_t numel = static_cast(tensor_info->nbytes() / sizeof(float)); + input_data[i].assign(numel, 1.0f); // fill with ones + + // Print input shape + fprintf(stderr, " input[%zu] shape=[", i); + for (ssize_t d = 0; d < ndim; ++d) { + fprintf(stderr, "%d%s", input_sizes[i][d], d + 1 < ndim ? "," : ""); + } + fprintf(stderr, "] numel=%zu\n", numel); + + input_impls.emplace_back( + tensor_info->scalar_type(), + ndim, + input_sizes[i].data(), + input_data[i].data(), + input_dim_order[i].data(), + input_strides[i].data()); + } + + // ------------------------------------------------------------------ + // 7. 
Run inference (num_runs times) + // ------------------------------------------------------------------ + for (int run = 0; run < num_runs; ++run) { + // Set inputs (must be done each run in case memory planning reuses them) + for (size_t i = 0; i < num_inputs; ++i) { + exec_aten::Tensor input_tensor(&input_impls[i]); + EValue input_evalue(input_tensor); + Error err = method->set_input(input_evalue, i); + ET_CHECK_MSG(err == Error::Ok, "set_input(%zu) failed: 0x%" PRIx32, i, static_cast(err)); + } + + Error status = method->execute(); + ET_CHECK_MSG(status == Error::Ok, "execute() failed on run %d: 0x%" PRIx32, run, static_cast(status)); + } + ET_LOG(Info, "Inference completed (%d run(s)).", num_runs); + + // ------------------------------------------------------------------ + // 8. Read and print outputs + // ------------------------------------------------------------------ + const size_t num_outputs = method->outputs_size(); + std::vector outputs(num_outputs); + Error status = method->get_outputs(outputs.data(), num_outputs); + ET_CHECK_MSG(status == Error::Ok, "get_outputs() failed"); + + for (size_t i = 0; i < num_outputs; ++i) { + if (!outputs[i].isTensor()) { + ET_LOG(Info, "output[%zu]: not a tensor", i); + continue; + } + exec_aten::Tensor t = outputs[i].toTensor(); + fprintf(stderr, "output[%zu] shape=[", i); + for (ssize_t d = 0; d < t.dim(); ++d) { + fprintf(stderr, "%d%s", (int)t.size(d), d + 1 < t.dim() ? "," : ""); + } + fprintf(stderr, "] numel=%zu dtype=%d\n", (size_t)t.numel(), (int)t.scalar_type()); + + // Print up to the first 8 float values + if (t.scalar_type() == exec_aten::ScalarType::Float) { + const float* data = t.const_data_ptr(); + const size_t print_n = t.numel() < 8 ? 
(size_t)t.numel() : 8; + fprintf(stderr, " first %zu values:", print_n); + for (size_t j = 0; j < print_n; ++j) { + fprintf(stderr, " %.4f", data[j]); + } + fprintf(stderr, "\n"); + } + } + + return 0; +} diff --git a/examples/torchtrt_executorch_example/export_dynamic_shape.py b/examples/torchtrt_executorch_example/export_dynamic_shape.py new file mode 100644 index 0000000000..62b71cac9e --- /dev/null +++ b/examples/torchtrt_executorch_example/export_dynamic_shape.py @@ -0,0 +1,84 @@ +""" +.. _executorch_export_dynamic: + +Saving a Torch-TensorRT Model with Dynamic Shapes in ExecuTorch Format (.pte) +============================================================================== + +This example demonstrates how to compile a model with Torch-TensorRT using +dynamic (range-based) input shapes and save it as an ExecuTorch ``.pte`` file. + +The TRT engine is built with a shape profile: batch size can vary between 1 +and 8, and the spatial dimensions between 2 and 8, while the channel dimension +is fixed at 3. The ExecuTorch runtime will select the correct binding sizes +at execute() time based on the actual input shapes. + +Prerequisites +------------- +Install ExecuTorch before running this example:: + + pip install executorch + +See https://pytorch.org/executorch/stable/getting-started-setup.html for details. +""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import torch +import torch_tensorrt + + +class MyModel(torch.nn.Module): + def forward(self, x): + return x + 1 + + +# %% +# Compile with Torch-TensorRT using dynamic shapes +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# torch_tensorrt.Input with min/opt/max shapes builds a TRT optimization +# profile that covers the full range. The exported program is traced at the +# opt shape; min/max define the runtime range accepted by the engine. 
+ +with torch.no_grad(): + model = MyModel().eval().cuda() + # Trace at the opt shape + opt_input = (torch.randn((4, 3, 4, 4)).cuda(),) + + # Mark variable dimensions so the ExecuTorch .pte allocates them as + # resizable tensors. Without dynamic_shapes the portable runtime + # pre-allocates fixed-size tensors at the export shape and rejects + # inputs of any other size at execute() time. + batch = torch.export.Dim("batch", min=1, max=8) + spatial = torch.export.Dim("spatial", min=2, max=8) + dynamic_shapes = {"x": {0: batch, 2: spatial, 3: spatial}} + + exported_program = torch.export.export( + model, opt_input, dynamic_shapes=dynamic_shapes + ) + compile_settings = { + "arg_inputs": [ + torch_tensorrt.Input( + min_shape=(1, 3, 2, 2), + opt_shape=(4, 3, 4, 4), + max_shape=(8, 3, 8, 8), + dtype=torch.float32, + ), + ], + "min_block_size": 1, + } + trt_gm = torch_tensorrt.dynamo.compile(exported_program, **compile_settings) + + # %% + # Save as ExecuTorch .pte format + # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + torch_tensorrt.save( + trt_gm, + "model_dynamic.pte", + output_format="executorch", + arg_inputs=opt_input, + retrace=False, + ) + + print("Saved model_dynamic.pte successfully.") diff --git a/examples/torchtrt_executorch_example/load_dynamic_shape.py b/examples/torchtrt_executorch_example/load_dynamic_shape.py new file mode 100644 index 0000000000..a7fae2d9e6 --- /dev/null +++ b/examples/torchtrt_executorch_example/load_dynamic_shape.py @@ -0,0 +1,82 @@ +""" +.. _executorch_load_dynamic: + +Loading a Torch-TensorRT Dynamic-Shape Model from ExecuTorch Format (.pte) +=========================================================================== + +This example demonstrates how to load a ``.pte`` file produced by +``export_dynamic_shape.py`` and run inference at several different input +shapes within the compiled min/max range. + +Prerequisites +------------- +- ExecuTorch installed with a runtime that includes the TensorRT backend. 
+- Run ``export_dynamic_shape.py`` first to produce ``model_dynamic.pte``. +""" + +# %% +# Imports +# ^^^^^^^ + +import ctypes +import os + +import torch +import torch_tensorrt # noqa: F401 -- loads libtorchtrt.so / libtorchtrt_runtime.so + +# libqnn_executorch_backend.so carries the ExecuTorch runtime (including +# executorch::runtime::internal::vlogf and register_backend). It must be +# loaded with RTLD_GLOBAL so its symbols are visible to subsequently +# dlopen'd libraries (libtrt_executorch_backend.so and portable_lib). +_executorch_path = os.environ.get("EXECUTORCH_PATH", "/home/lanl/git/executorch") +ctypes.CDLL( + os.path.join(_executorch_path, "backends/qualcomm/libqnn_executorch_backend.so"), + mode=ctypes.RTLD_GLOBAL, +) + +# Load the TensorRT ExecuTorch backend shared library, which runs a static +# initializer that calls executorch::runtime::register_backend("TensorRTBackend"). +# Resolves runtime symbols from libqnn_executorch_backend.so loaded above. +_lib_dir = os.path.join(os.path.dirname(torch_tensorrt.__file__), "lib") +ctypes.CDLL(os.path.join(_lib_dir, "libtrt_executorch_backend.so")) + +from executorch.extension.pybindings import portable_lib as runtime + +# %% +# Load the .pte file +# ^^^^^^^^^^^^^^^^^^ + +executorch_module = runtime._load_for_executorch("model_dynamic.pte") + +# %% +# Run inference at multiple shapes within the compiled profile +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Inputs must be CPU tensors; execute() stages them to CUDA internally. 
+# All shapes must lie within the min/max range used at export time: +# batch: [1, 8] channels: 3 (fixed) spatial: [2, 8] + + +class MyModel(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + 1 + + +ref_model = MyModel().eval() + +test_shapes = [ + (1, 3, 2, 2), # minimum shape + (4, 3, 4, 4), # opt shape (used at export time) + (8, 3, 8, 8), # maximum shape + (2, 3, 6, 6), # arbitrary shape within range +] + +for shape in test_shapes: + example_input = torch.randn(shape) # CPU tensor + outputs = executorch_module.forward([example_input]) + with torch.no_grad(): + expected = ref_model(example_input) + + torch.testing.assert_close(outputs[0], expected, rtol=1e-3, atol=1e-3) + print(f"shape={shape} output={outputs[0].shape} dtype={outputs[0].dtype} OK") + +print("All dynamic-shape inference runs passed.") diff --git a/examples/torchtrt_executorch_example/load_static_shape.py b/examples/torchtrt_executorch_example/load_static_shape.py new file mode 100644 index 0000000000..ac27fcbbd9 --- /dev/null +++ b/examples/torchtrt_executorch_example/load_static_shape.py @@ -0,0 +1,80 @@ +""" +.. _executorch_load: + +Loading a Torch-TensorRT Model from ExecuTorch Format (.pte) +============================================================= + +This example demonstrates how to load a ``.pte`` file produced by +``export_static_shape.py`` and run inference with the ExecuTorch runtime. + +Prerequisites +------------- +- ExecuTorch installed with a runtime that includes the TensorRT backend. +- Run ``export_static_shape.py`` first to produce ``model.pte``. +""" + +# %% +# Imports +# ^^^^^^^ + +import ctypes +import os + +import torch +import torch_tensorrt # noqa: F401 -- loads libtorchtrt.so / libtorchtrt_runtime.so + +# libqnn_executorch_backend.so carries the ExecuTorch runtime (including +# executorch::runtime::internal::vlogf and register_backend). 
# It must be
# loaded with RTLD_GLOBAL so its symbols are visible to subsequently
# dlopen'd libraries (libtrt_executorch_backend.so and portable_lib).
#
# NOTE(review): the fallback path is specific to one developer machine; on any
# other host EXECUTORCH_PATH has to be exported.  The explicit check below
# turns the otherwise opaque dlopen failure into an actionable message.
_executorch_path = os.environ.get("EXECUTORCH_PATH", "/home/lanl/git/executorch")
_qnn_backend_lib = os.path.join(
    _executorch_path, "backends/qualcomm/libqnn_executorch_backend.so"
)
if not os.path.isfile(_qnn_backend_lib):
    # FileNotFoundError is an OSError, the same family ctypes.CDLL raises.
    raise FileNotFoundError(
        f"{_qnn_backend_lib} not found. Set EXECUTORCH_PATH to the ExecuTorch "
        "source tree that contains backends/qualcomm/libqnn_executorch_backend.so."
    )
ctypes.CDLL(_qnn_backend_lib, mode=ctypes.RTLD_GLOBAL)

# Load the TensorRT ExecuTorch backend shared library, which runs a static
# initializer that calls executorch::runtime::register_backend("TensorRTBackend").
# Resolves runtime symbols from libqnn_executorch_backend.so loaded above.
_lib_dir = os.path.join(os.path.dirname(torch_tensorrt.__file__), "lib")
_trt_backend_lib = os.path.join(_lib_dir, "libtrt_executorch_backend.so")
if not os.path.isfile(_trt_backend_lib):
    raise FileNotFoundError(
        f"{_trt_backend_lib} not found. The installed torch_tensorrt package "
        "does not ship the ExecuTorch TensorRT backend library."
    )
ctypes.CDLL(_trt_backend_lib)

# Imported only after both shared libraries are loaded so that portable_lib
# binds to the runtime symbols made global above.
from executorch.extension.pybindings import portable_lib as runtime  # noqa: E402

# %%
# Load the .pte file
# ^^^^^^^^^^^^^^^^^^
# _load_for_executorch returns an ExecuTorchModule whose methods mirror the
# original model's exported methods (e.g. "forward").

executorch_module = runtime._load_for_executorch("model.pte")
# %%
# Run inference
# ^^^^^^^^^^^^^
# Inputs must be passed as a list of tensors matching the static shapes used
# at export time: (2, 3, 4, 4) float32, supplied as a CPU tensor.
+ +example_input = torch.randn( + (2, 3, 4, 4) +) # CPU tensor; execute() stages it to CUDA internally +outputs = executorch_module.forward([example_input]) + +print("Output shape:", outputs[0].shape) +print("Output dtype:", outputs[0].dtype) + +# %% +# Verify against eager mode +# ^^^^^^^^^^^^^^^^^^^^^^^^^ + + +class MyModel(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + 1 + + +model = MyModel().eval() +with torch.no_grad(): + expected = model(example_input) + +torch.testing.assert_close(outputs[0], expected, rtol=1e-3, atol=1e-3) +print("Output matches eager mode.") diff --git a/third_party/executorch/BUILD b/third_party/executorch/BUILD new file mode 100644 index 0000000000..ae1be929e3 --- /dev/null +++ b/third_party/executorch/BUILD @@ -0,0 +1,56 @@ +load("@rules_cc//cc:defs.bzl", "cc_import", "cc_library") + +package(default_visibility = ["//visibility:public"]) + +# This BUILD file is used both by: +# - toolchains/local_executorch.bzl (symlinks runtime/, extension/, cmake-out/ +# into the synthetic repo root) +# - new_local_repository(path = "${EXECUTORCH_PATH}") (repo root IS the +# executorch source tree) +# +# In both cases the repo root contains runtime/, extension/, and cmake-out/. +# include_prefix = "executorch" remaps runtime/... → . + +# Headers-only: provides include paths for compilation but does NOT link any +# static ExecuTorch library. ExecuTorch runtime symbols (register_backend, +# find_backend, FreeableBuffer, …) are left undefined and resolved at dlopen() +# time from libqnn_executorch_backend.so, which _portable_lib.so loads with +# RTLD_GLOBAL. This ensures all backends share the same registry instance. +cc_library( + name = "executorch_headers", + hdrs = glob([ + "runtime/**/*.h", + "extension/**/*.h", + ]), + include_prefix = "executorch", +) + +# ExecuTorch runtime shared library. Carries 6 backend-registry symbols that +# libtrt_executorch_backend.so leaves undefined at link time. 
Must be loaded +# with RTLD_GLOBAL before any TRT backend .so is dlopen'd so all backends share +# the same registry instance. +cc_import( + name = "executorch_runtime", + shared_library = "libqnn_executorch_backend.so", +) + +# Full ExecuTorch C++ runtime (Program, Method, MethodMeta, runtime_init, …) +# built from the ExecuTorch cmake-out. Link this into C++ binaries that call +# the ExecuTorch C++ API directly (e.g. trt_executor_runner). +cc_import( + name = "executorch_core", + static_library = "libexecutorch_core.a", +) + +# FileDataLoader — lives in extension/data_loader/, compiled from the source +# tree that local_executorch.bzl symlinks into the synthetic repo root. +# Needs @libtorch for c10/util/safe_numerics.h. +cc_library( + name = "executorch_file_data_loader", + srcs = ["extension/data_loader/file_data_loader.cpp"], + deps = [ + ":executorch_core", + ":executorch_headers", + "@libtorch", + ], +) diff --git a/toolchains/ci_workspaces/MODULE.bazel.tmpl b/toolchains/ci_workspaces/MODULE.bazel.tmpl index 7f386d0a4f..5201262196 100644 --- a/toolchains/ci_workspaces/MODULE.bazel.tmpl +++ b/toolchains/ci_workspaces/MODULE.bazel.tmpl @@ -150,6 +150,14 @@ new_local_repository( build_file = "third_party/libtorch/BUILD" ) +# ExecuTorch source tree. EXECUTORCH_PATH must point to the directory that +# contains runtime/, extension/, and cmake-out/libexecutorch_core.a. +new_local_repository( + name = "executorch", + path = "${EXECUTORCH_PATH}", + build_file = "third_party/executorch/BUILD" +) + #new_local_repository( # name = "tensorrt", # path = "/usr/", diff --git a/toolchains/local_executorch.bzl b/toolchains/local_executorch.bzl new file mode 100644 index 0000000000..2da608140d --- /dev/null +++ b/toolchains/local_executorch.bzl @@ -0,0 +1,115 @@ +"""Repository rule that locates the locally installed ExecuTorch source tree. + +Discovery order: + 1. 
EXECUTORCH_PATH env var — absolute path to the executorch source root
     (the directory containing runtime/ and extension/).
  2. import executorch from a Python interpreter — walks up from the
     package directory to find the source root that contains runtime/backend/.
     The interpreter is the first one that exists among VIRTUAL_ENV,
     CONDA_PREFIX, a workspace-local .venv/ or venv/, and python3/python
     on PATH.

What is placed in the synthetic repository root:
  - runtime/ and extension/ (symlinked when present) — the headers exposed by
    :executorch_headers.
  - libqnn_executorch_backend.so — symlinked from backends/qualcomm/ when it
    exists; referenced by :executorch_runtime.
  - libexecutorch_core.a — symlinked from cmake-out/ when it exists;
    referenced by :executorch_core.

The two libraries are optional at repository-fetch time: header-only consumers
do not need them, so a missing library is skipped here rather than reported.
"""

def _find_python(ctx):
    """Returns the first existing Python interpreter path, or None.

    Candidates are tried in this order: $VIRTUAL_ENV/bin, $CONDA_PREFIX/bin,
    .venv/bin and venv/bin under the workspace root, then python3/python
    found on PATH.
    """
    candidates = []

    # Interpreters from an activated environment come first.
    for env_var in ["VIRTUAL_ENV", "CONDA_PREFIX"]:
        prefix = ctx.os.environ.get(env_var, "")
        if prefix:
            candidates.append(ctx.path(prefix + "/bin/python3"))
            candidates.append(ctx.path(prefix + "/bin/python"))

    # Then virtualenvs that live inside the workspace itself.
    workspace = ctx.workspace_root
    for rel in [".venv/bin/python3", ".venv/bin/python", "venv/bin/python3", "venv/bin/python"]:
        candidates.append(workspace.get_child(rel))

    # Finally whatever is on PATH.
    for name in ["python3", "python"]:
        on_path = ctx.which(name)
        if on_path:
            candidates.append(on_path)

    for candidate in candidates:
        if candidate.exists:
            return candidate
    return None

def _find_executorch_source(ctx):
    """Returns the path to the executorch source root, or None.

    Fails the fetch outright when EXECUTORCH_PATH is set but unusable, so a
    typo in the override is never masked by the Python-based fallback.
    """

    # 1. Env-var override
    et_path = ctx.os.environ.get("EXECUTORCH_PATH", "").strip()
    if et_path:
        root = ctx.path(et_path)
        if not root.exists:
            fail("EXECUTORCH_PATH is set to '{}' but that directory does not exist.".format(et_path))

        # Without runtime/ the header glob in the BUILD file has nothing to
        # match, which only surfaces later as a confusing build error.
        if not root.get_child("runtime").exists:
            fail(
                "EXECUTORCH_PATH is set to '{}' but that directory has no runtime/ ".format(et_path) +
                "subdirectory; point it at the root of the ExecuTorch source tree.",
            )
        return root

    # 2. Python import — walk up from the package to find the source root
    python = _find_python(ctx)
    if python:
        result = ctx.execute([
            python,
            "-c",
            "\n".join([
                "import executorch, os",
                "pkg = os.path.dirname(executorch.__file__)",
                "# Walk upward from the package looking for runtime/backend/",
                "for d in [pkg, os.path.join(pkg, '..'), os.path.join(pkg, '..', '..')]:",
                "    d = os.path.realpath(d)",
                "    if os.path.isdir(os.path.join(d, 'runtime', 'backend')):",
                "        print(d)",
                "        break",
            ]),
        ])

        # A non-zero exit (e.g. executorch not installed) or empty output
        # simply means this discovery step found nothing.
        found = result.stdout.strip() if result.return_code == 0 else ""
        if found:
            root = ctx.path(found)
            if root.exists:
                return root

    return None

def _local_executorch_impl(ctx):
    """Populates the synthetic @executorch repository from the local tree."""
    et_dir = _find_executorch_source(ctx)
    if et_dir == None:
        fail(
            "Cannot locate the ExecuTorch source tree. " +
            "Set EXECUTORCH_PATH to the directory that contains runtime/ and " +
            "extension/ (e.g. export EXECUTORCH_PATH=/path/to/executorch). " +
            "Targets that link :executorch_core additionally need " +
            "cmake-out/libexecutorch_core.a to be built.",
        )

    # Symlink the header directories referenced by the BUILD file into the
    # synthetic repo root, mirroring the new_local_repository(path=EXECUTORCH_PATH)
    # layout.  include_prefix = "executorch" in the BUILD file handles the
    # header remapping.
    for sub in ["runtime", "extension"]:
        child = et_dir.get_child(sub)
        if child.exists:
            ctx.symlink(child, sub)

    # Optional prebuilt libraries, exposed under flat names at the repo root:
    #   - libqnn_executorch_backend.so: lets libtrt_executorch_backend.so
    #     resolve the ExecuTorch backend-registry symbols at dlopen() time
    #     (register_backend, find_backend, vlogf, …).
    #   - libexecutorch_core.a: lets C++ binaries (e.g. trt_executor_runner)
    #     link the full ExecuTorch runtime (Program::load, Method::execute,
    #     runtime_init, MethodMeta, …) statically.
    # Each is skipped when absent so that header-only use keeps working.
    optional_libs = [
        ("backends/qualcomm/libqnn_executorch_backend.so", "libqnn_executorch_backend.so"),
        ("cmake-out/libexecutorch_core.a", "libexecutorch_core.a"),
    ]
    for (src, dst) in optional_libs:
        lib = et_dir.get_child(src)
        if lib.exists:
            ctx.symlink(lib, dst)

    ctx.file("BUILD", ctx.read(Label("@//third_party/executorch:BUILD")))

local_executorch = repository_rule(
    implementation = _local_executorch_impl,
    environ = ["EXECUTORCH_PATH", "VIRTUAL_ENV", "CONDA_PREFIX"],
    doc = "Exposes a locally available ExecuTorch source tree as an external repository.",
)