Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
808ee01
Multi-Device TensorRT Runtime with Native NCCL Collectives
apbose Apr 1, 2026
aaa6557
removing the try-except block in TRTengine.cpp and correcting the typos
apbose Apr 8, 2026
4c1e68d
Redesign distributed inference API: auto-detect rank, lazy NCCL setup…
apbose Apr 9, 2026
7cfa40b
remove nccl.h dependency
apbose Apr 9, 2026
ac96255
clean up import and add comment
apbose Apr 9, 2026
fe1c6f4
moving setup_nccl_library call to example script
apbose Apr 9, 2026
b658c7a
work on the save/load export part-add is_md flag, guard export tracin…
apbose Apr 10, 2026
a35dfe6
refactor: Adjusting how we use NCCL
narendasan Apr 10, 2026
3def3f7
fix: enable torch.compile(backend='tensorrt') for LLMs with dynamic s…
narendasan Apr 10, 2026
2aa8f14
test: add torch.compile(backend='tensorrt') integration test for Llam…
narendasan Apr 10, 2026
6f81a66
feat: llama3.2 working with MD-TRT
narendasan Apr 10, 2026
0d2d61c
feat: Support exported and serialization workflows for MD-TRT
narendasan Apr 12, 2026
e08b0c5
ci: fix nccl builds in CI
narendasan Apr 12, 2026
754b62b
chore: Some reorg and cleaning the constructor
narendasan Apr 14, 2026
bf432ad
fix: thread the MD-TRT requirement through the conversion system
narendasan Apr 14, 2026
f4e77ad
fix: DeviceMesh FakeScriptObjects get passed in as arguments into tor…
narendasan Apr 16, 2026
6ba00cf
fix: Address segfaults when a distributed context is manually destroy…
narendasan Apr 16, 2026
edf6518
replacing torchrun with torchtrtrun for right .so
apbose Apr 16, 2026
9e390eb
chore: apply linting
narendasan Apr 16, 2026
1b4e559
use correct group for dummy all_reduce
apbose Apr 16, 2026
df51acf
Broaden NCCL skip guards to include native TRT collectives and fix di…
apbose Apr 17, 2026
4665692
Merge branch 'main' into push-vqqzkszwrvyx
narendasan Apr 18, 2026
fd32b5b
fix: Update the engine cache to be aware of the new setting
narendasan Apr 16, 2026
7dce75e
update TRT version
narendasan Apr 16, 2026
3d3c7d1
fix: the md property was not properly handled across the library
narendasan Apr 20, 2026
fe59779
chore: skip test which is not valid ATen attn
narendasan Apr 20, 2026
fdfd45a
finalizing internal apis
narendasan Apr 21, 2026
f9bd6a4
fix: align apis, make sure to defer binding unless there is one obvio…
narendasan Apr 21, 2026
34ccf1d
test: make bert test backwards compatible
narendasan Apr 21, 2026
6379b06
fix: address some issues in the converters
narendasan Apr 21, 2026
25537bf
fix: we should inherit information about the device mesh from torch d…
narendasan Apr 21, 2026
2f05500
fix: address non MD-TRT build torch bind
narendasan Apr 22, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,105 changes: 595 additions & 510 deletions .github/workflows/build-test-linux-x86_64.yml

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion .github/workflows/linux-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ on:
default: false
type: boolean
required: false
runner:
description: "Override the runner label (e.g. linux.g4dn.12xlarge.nvidia.gpu for multi-GPU jobs). Defaults to matrix.validation_runner."
default: ""
type: string
required: false

jobs:
test:
Expand All @@ -76,7 +81,7 @@ jobs:
USE_TRT_RTX: ${{ inputs.use-rtx }}
DOWNLOAD_ARTIFACT_NAME: pytorch_tensorrt_${{ matrix.tensorrt.version }}_${{ matrix.python_version }}_${{ matrix.desired_cuda }}_${{ inputs.architecture }}
name: ${{ inputs.job-name }}-${{ matrix.tensorrt.version }}-${{ matrix.python_version }}-${{ matrix.desired_cuda }}
runs-on: ${{ matrix.validation_runner }}
runs-on: ${{ inputs.runner != '' && inputs.runner || matrix.validation_runner }}
container:
image: ${{ matrix.container_image }}
options: ${{ matrix.gpu_arch_type == 'cuda' && '--gpus all --shm-size=1g' || ' ' }}
Expand Down
17 changes: 11 additions & 6 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ bazel_dep(name = "googletest", version = "1.16.0")
bazel_dep(name = "platforms", version = "0.0.11")
bazel_dep(name = "rules_cc", version = "0.1.1")
bazel_dep(name = "rules_python", version = "1.3.0")
bazel_dep(name = "bazel_skylib", version = "1.7.1")

python = use_extension("@rules_python//python/extensions:python.bzl", "python")
python.toolchain(
Expand All @@ -26,6 +27,10 @@ new_local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local.

local_torch = use_repo_rule("//toolchains:local_torch.bzl", "local_torch")

torch_nccl_detect = use_repo_rule("//toolchains/torch_nccl:defs.bzl", "torch_nccl_detect")

torch_nccl_detect(name = "torch_nccl")

# External dependency for torch_tensorrt if you already have precompiled binaries.
new_local_repository(
name = "torch_tensorrt",
Expand Down Expand Up @@ -131,9 +136,9 @@ http_archive(
http_archive(
name = "tensorrt",
build_file = "@//third_party/tensorrt/archive:BUILD",
strip_prefix = "TensorRT-10.16.0.72",
strip_prefix = "TensorRT-10.16.1.11",
urls = [
"https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.16.0/tars/TensorRT-10.16.0.72.Linux.x86_64-gnu.cuda-13.2.tar.gz",
"https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.16.1/tars/TensorRT-10.16.1.11.Linux.x86_64-gnu.cuda-13.2.tar.gz",
],
)

Expand All @@ -149,9 +154,9 @@ http_archive(
http_archive(
name = "tensorrt_sbsa",
build_file = "@//third_party/tensorrt/archive:BUILD",
strip_prefix = "TensorRT-10.16.0.72",
strip_prefix = "TensorRT-10.16.1.11",
urls = [
"https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.16.0/tars/TensorRT-10.16.0.72.Linux.aarch64-gnu.cuda-13.2.tar.gz",
"https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.16.1/tars/TensorRT-10.16.1.11.Linux.aarch64-gnu.cuda-13.2.tar.gz",
],
)

Expand All @@ -167,9 +172,9 @@ http_archive(
http_archive(
name = "tensorrt_win",
build_file = "@//third_party/tensorrt/archive:BUILD",
strip_prefix = "TensorRT-10.16.0.72",
strip_prefix = "TensorRT-10.16.1.11",
urls = [
"https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.16.0/zip/TensorRT-10.16.0.72.Windows.amd64.cuda-13.2.zip",
"https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.16.1/zip/TensorRT-10.16.1.11.Windows.amd64.cuda-13.2.zip",
],
)

Expand Down
7 changes: 5 additions & 2 deletions core/runtime/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_pkg//:pkg.bzl", "pkg_tar")
load("@rules_pkg//pkg:mappings.bzl", "pkg_files")
load("//toolchains/torch_nccl:defs.bzl", "if_torch_nccl")

package(default_visibility = ["//visibility:public"])

config_setting(
Expand Down Expand Up @@ -77,13 +79,14 @@ cc_library(
"TRTEngineProfiler.h",
"runtime.h",
],
copts = if_torch_nccl(["-DUSE_C10D_NCCL"]),
linkopts = [
"-lstdc++fs",
],
deps = [
"//core/plugins:torch_tensorrt_plugins",
"//core/util:prelude",
] + select({
] + if_torch_nccl(["@torch_nccl//:nccl_headers"]) + select({
":jetpack": ["@tensorrt_l4t//:nvinfer"],
":rtx_win": ["@tensorrt_rtx_win//:nvinfer"],
":rtx_x86_64": ["@tensorrt_rtx//:nvinfer"],
Expand Down Expand Up @@ -121,6 +124,6 @@ pkg_tar(
pkg_files(
name = "include_pkg_files",
srcs = [":include_files"],
visibility = ["//visibility:public"],
prefix = "include/torch_tensorrt/core/runtime/",
visibility = ["//visibility:public"],
)
155 changes: 143 additions & 12 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
#include "core/util/prelude.h"
#include "torch/torch.h"

#ifdef ENABLE_TRT_NCCL_COLLECTIVES
#include "torch/csrc/distributed/c10d/GroupRegistry.hpp"
#include "torch/csrc/distributed/c10d/NCCLUtils.hpp"
#include "torch/csrc/distributed/c10d/ProcessGroup.hpp"
#include "torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp"
#endif

namespace torch_tensorrt {
namespace core {
namespace runtime {
Expand Down Expand Up @@ -88,7 +95,12 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
serialized_info[SERIALIZED_METADATA_IDX],
(static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]))
? ResourceAllocationStrategy::kDynamic
: ResourceAllocationStrategy::kStatic)) {}
: ResourceAllocationStrategy::kStatic)) {
this->requires_native_multidevice = std::stoi(serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]);
if (this->requires_native_multidevice) {
LOG_INFO("Loaded distributed TRT engine (contains NCCL collectives); NCCL comm will be bound on first execution");
}
}

TRTEngine::TRTEngine(
const std::string& mod_name,
Expand Down Expand Up @@ -261,10 +273,21 @@ TRTEngine::TRTEngine(
this->enable_profiling();
#endif
LOG_DEBUG(*this);

#ifdef ENABLE_TRT_NCCL_COLLECTIVES
// Attempt to bind the NCCL communicator immediately after exec_ctx is ready.
// This handles the common case where dist.init_process_group() and an initial
// collective have already been called before the engine is constructed.
// If the communicator isn't available yet (e.g. engine constructed before the
// first collective), bind_nccl_comm returns false and execute_engine() will
// retry on its first invocation.
if (this->requires_native_multidevice) {
bind_nccl_comm();
}
#endif
}

TRTEngine::~TRTEngine() {
torch::cuda::synchronize(device_info.id);
trt_engine_profiler.reset();
exec_ctx.reset();
cuda_engine.reset();
Expand Down Expand Up @@ -383,6 +406,13 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) {
if (profile_execution) {
enable_profiling();
}
#ifdef ENABLE_TRT_NCCL_COLLECTIVES
// exec_ctx was recreated — re-bind the NCCL communicator if this is a
// distributed engine that has already been set up.
if (nccl_initialized) {
bind_nccl_comm();
}
#endif
// Indicates to reevaluate the runtime settings
runtime_states.context_changed = true;

Expand Down Expand Up @@ -428,6 +458,7 @@ std::string TRTEngine::to_str() const {
ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
ss << " Target Platform: " << target_platform << std::endl;
ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl;
ss << " Multi-Device Engine: " << (requires_native_multidevice) << std::endl;
// clang-format on
return ss.str();
}
Expand All @@ -437,15 +468,6 @@ std::ostream& operator<<(std::ostream& os, const TRTEngine& engine) {
return os;
}

TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
rt = other.rt;
cuda_engine = other.cuda_engine;
device_info = other.device_info;
exec_ctx = other.exec_ctx;
num_io = other.num_io;
return (*this);
}

void TRTEngine::verify_serialization_fmt(const std::vector<std::string>& serialized_info) {
TORCHTRT_CHECK(
serialized_info.size() == SERIALIZATION_LEN,
Expand All @@ -472,7 +494,8 @@ FlattenedState TRTEngine::__obj_flatten__() {
std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]),
std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]),
std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]),
std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]));
std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]),
std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]));
}

std::vector<std::string> TRTEngine::serialize() {
Expand All @@ -497,6 +520,8 @@ std::vector<std::string> TRTEngine::serialize() {
serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();
serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] =
this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";
serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX] = this->requires_native_multidevice ? "1" : "0";
// rank/world_size are runtime facts (may differ at load time); not serialized.

return serialized_info;
}
Expand All @@ -519,6 +544,112 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt
}
}

#ifdef ENABLE_TRT_NCCL_COLLECTIVES
// Resolve the c10d process group this engine belongs to and attach its NCCL
// communicator to the TensorRT execution context.
//
// Returns true once the communicator has been bound; returns false (soft
// failure) when binding must be deferred — either no process group is
// registered yet, or the NCCL communicator has not been lazily created by a
// first collective. Callers (constructor / execute_engine) retry later.
//
// NOTE(review): this translation unit guards this code with
// ENABLE_TRT_NCCL_COLLECTIVES, while the Bazel target adds -DUSE_C10D_NCCL —
// confirm ENABLE_TRT_NCCL_COLLECTIVES is actually defined by the build, or
// this function is silently compiled out.
bool TRTEngine::bind_nccl_comm() {
  // When group_name is empty (e.g. engine loaded from a serialized
  // ExportedProgram where the Python TorchTensorRTModule wrapper was
  // inlined and set_group_name() was never called), auto-resolve the
  // process group from the c10d registry.
  if (this->group_name.empty() && this->requires_native_multidevice) {
    // PyTorch assigns sequential numeric names ("0", "1", ...) to process
    // groups. Collect every group that has an NCCL backend; we can only
    // auto-resolve when there is exactly one — if there are several (TP+DP,
    // Megatron 4-D parallelism, etc.) we cannot know which group this engine
    // belongs to and the caller must pin it explicitly.
    //
    // NOTE(review): the probe cap of 20 is arbitrary — groups numbered >= 20
    // would be missed; confirm this bound is acceptable. Also confirm that
    // c10d::resolve_process_group returns nullptr (rather than raising a
    // TORCH_CHECK error) for names that are not registered — if it throws,
    // this probe loop would abort instead of soft-failing.
    std::vector<std::string> nccl_groups;
    for (int i = 0; i < 20; ++i) {
      auto candidate = std::to_string(i);
      auto probe = c10d::resolve_process_group(candidate);
      if (probe != nullptr && probe->getBackendType() == c10d::ProcessGroup::BackendType::NCCL) {
        nccl_groups.push_back(candidate);
      }
    }

    if (nccl_groups.size() == 1) {
      // Unambiguous: adopt the single NCCL group.
      this->group_name = nccl_groups[0];
      LOG_INFO("Auto-resolved distributed group name to '" << this->group_name << "'");
    } else if (nccl_groups.size() > 1) {
      // Ambiguous: build a comma-separated list of candidates for the warning
      // and leave group_name empty so binding stays deferred.
      std::string names;
      for (const auto& n : nccl_groups) {
        if (!names.empty())
          names += ", ";
        names += "'" + n + "'";
      }
      LOG_WARNING(
          "This TRT engine requires NCCL but multiple NCCL process groups are registered ("
          << names
          << "). Cannot auto-select a group — NCCL bind deferred. "
             "Use the recommended workflow: "
             "with torch_tensorrt.distributed.distributed_context(group, model) as m: m(inp)");
    } else {
      // No NCCL group found at all — the user has not called
      // dist.init_process_group(backend='nccl') yet.
      LOG_WARNING(
          "This TRT engine requires NCCL (requires_native_multidevice=true) but no NCCL process group "
          "was found in the c10d registry. Ensure dist.init_process_group(backend='nccl') "
          "has been called before loading the engine. You can also set the group name "
          "manually via: torch_tensorrt.distributed.distributed_context(group, model)");
    }
  }

  // Soft-return when the process group isn't available yet (e.g. at engine
  // construction time when the caller hasn't called dist.init_process_group()).
  auto pg = c10d::resolve_process_group(this->group_name);
  if (pg == nullptr) {
    LOG_DEBUG("ProcessGroup '" << this->group_name << "' not yet registered in c10d; NCCL bind deferred.");
    return false;
  }

  // Cache runtime topology facts from the resolved group; these are runtime
  // properties and are deliberately not serialized (see serialize()).
  this->rank = pg->getRank();
  this->world_size = pg->getSize();

  auto backend = pg->getBackend(c10d::ProcessGroup::BackendType::NCCL);
  TORCHTRT_CHECK(backend != nullptr, "ProcessGroup '" << this->group_name << "' has no NCCL backend");

  // Downcast to the concrete NCCL backend to reach getCommPtr(); a hard
  // failure here means the registry lied about the backend type.
  auto* nccl_pg = dynamic_cast<c10d::ProcessGroupNCCL*>(backend.get());
  TORCHTRT_CHECK(nccl_pg != nullptr, "Backend is not ProcessGroupNCCL");

  // Make this engine's device current before querying the per-device
  // communicator.
  at::cuda::set_device(this->device_info.id);

  int64_t comm_ptr = nccl_pg->getCommPtr();
  // Soft-return when NCCL hasn't run a collective yet. The communicator is
  // created lazily by PyTorch on the first collective — callers should ensure
  // at least one collective (e.g. dist.barrier()) has been issued before the
  // first TRT forward pass.
  if (comm_ptr == 0) {
    LOG_DEBUG(
        "NCCL communicator not yet initialized for device " << this->device_info.id
        << "; NCCL bind deferred until first execute_engine call.");
    return false;
  }

  // Hand the raw ncclComm_t (carried as an opaque int64) to TensorRT so the
  // engine's native collective layers can use it.
  TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Cannot bind NCCL communicator: execution context is null");
  exec_ctx->setCommunicator(reinterpret_cast<void*>(comm_ptr));
  this->nccl_initialized = true;
  LOG_INFO("NCCL comm bound (rank=" << this->rank << ", device=" << this->device_info.id << ")");
  return true;
}

// Detach the NCCL communicator from this engine by destroying and recreating
// the TensorRT execution context. TensorRT exposes no "unset communicator"
// call, so dropping the context is the only way to release the engine's
// reference to the comm (e.g. before the process group is destroyed, to avoid
// use-after-free in TRT's collective layers).
//
// No-op when no communicator was ever bound. After this call,
// nccl_initialized is false and bind_nccl_comm() must run again before the
// next distributed execution.
void TRTEngine::release_nccl_comm() {
  if (!this->nccl_initialized) {
    return;
  }
  LOG_INFO("Releasing NCCL communicator from engine '" << this->name << "'");
  // Drain in-flight work on this engine's device before tearing down the
  // context that may still reference the communicator.
  torch::cuda::synchronize(device_info.id);
  this->exec_ctx.reset();
  // Recreate the context with the same allocation strategy the engine was
  // configured with, mirroring the constructor's behavior.
  if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) {
    this->exec_ctx =
        make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
  } else {
    this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
  }
  TORCHTRT_CHECK(
      (exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context after releasing NCCL comm");
  // NOTE(review): unlike set_device_memory_budget(), this path does not set
  // runtime_states.context_changed or re-enable profiling on the fresh
  // context — confirm callers do not rely on those being refreshed here.
  this->nccl_initialized = false;
  LOG_INFO("NCCL communicator released from engine '" << this->name << "'");
}
#endif // ENABLE_TRT_NCCL_COLLECTIVES

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
Loading
Loading