feat: Add Alpamayo end-to-end Expert diffusion support #67
Turoad wants to merge 1 commit into NVIDIA:main
Conversation
Add end-to-end TensorRT inference for the Alpamayo 1.5 (10B VLM + Diffusion Expert) autonomous driving model. After VLM decode produces a KV cache, the Expert runner performs 10-step flow-matching Euler integration to generate 6 candidate trajectories, all on GPU without host round-trips.

New files:
- AlpamayoExpertRunner: loads the Expert TRT engine, manages GPU buffers (KV reshape, noisy action, timestep, attention mask), runs diffusion
- CUDA kernels: kvCacheReshapeRepeat, buildPositionIds, fillTimestep, eulerUpdate

Modified files:
- llmInferenceRuntime: Expert integration after VLM decode, StopAfterEOS for the traj_future_start token, single-seq and multi-seq candidate modes, KV cache dump for offline validation
- llm_inference.cpp: CLI flags --expertEngine, --numCandidates, --numDiffusionSteps, --multiSeq, --dumpKVCache
- QwenViTRunner: preprocessPreparedVisual for pre-processed multi-camera images (bypasses runtime image decoding)
- CMakeLists: curand linkage, SM 110 (Jetson Thor) architecture

Validated on Jetson AGX Thor with 300 samples: FP8 VLM + BF16 Expert, 6 candidates, steady-state ~2.83 s/sample, mean minADE 0.799 m, median 0.637 m (aligned with the PyTorch reference: mean 0.827 m, median 0.700 m).

Signed-off-by: thor <thor@nvidia.com>
Thanks a lot for this PR! The core team is also working on supporting Alpamayo in a new release. We will take this PR as a reference and properly cite this great contribution.
Thanks for the response! Glad to hear the core team is also working on Alpamayo support. I'd love to collaborate rather than have parallel efforts — a few thoughts:
What would make this most useful for the team — merging as-is, adapting it to your internal branch, or contributing specific pieces? I'm flexible on the path; I just want to make sure the validated work doesn't go to waste.
@Turoad I'd love the test harness and dataset scripts you offered above. Trying to reproduce your end-to-end pipeline on a Jetson AGX Thor against
What does this PR do?
Type of change: New feature
Overview: Add end-to-end TensorRT inference for Alpamayo 1.5 (10B VLM + Diffusion Expert) autonomous driving model. After VLM decode produces a KV cache, the Expert runner performs 10-step flow-matching Euler integration to generate 6 candidate trajectories — all on GPU without host round-trips.
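To make the diffusion step concrete, here is a minimal host-side C++ sketch of the 10-step flow-matching Euler integration described above (the step corresponding to the `eulerUpdate` kernel). `predictVelocity` is a hypothetical stand-in for one Expert TRT engine invocation; a toy linear velocity field is used so the sketch is self-contained, and the real runner operates on GPU buffers rather than `std::vector`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one Expert engine call: given the current
// noisy action and a timestep, return the predicted velocity.
// A toy field v = -x keeps this sketch self-contained.
std::vector<float> predictVelocity(const std::vector<float>& x, float /*t*/) {
    std::vector<float> v(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) v[i] = -x[i];
    return v;
}

// Euler integration of the flow-matching ODE dx/dt = v(x, t),
// stepping t from 0 to 1 in numSteps uniform increments.
std::vector<float> runDiffusion(std::vector<float> x, int numSteps = 10) {
    const float dt = 1.0f / static_cast<float>(numSteps);
    for (int step = 0; step < numSteps; ++step) {
        const float t = static_cast<float>(step) * dt;
        const std::vector<float> v = predictVelocity(x, t);
        for (std::size_t i = 0; i < x.size(); ++i)
            x[i] += dt * v[i];  // the per-element eulerUpdate
    }
    return x;
}
```

In the PR itself, this loop runs once per candidate trajectory (6 by default) with the velocity coming from the Expert engine, so no intermediate state ever leaves the GPU.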
New files
- `cpp/runtime/alpamayoExpertRunner.h/.cpp`: loads the Expert TRT engine, manages GPU buffers (KV reshape, noisy action, timestep, attention mask), runs 10-step Euler diffusion
- `cpp/kernels/alpamayoExpertKernels/`: `kvCacheReshapeRepeat`, `buildPositionIds`, `fillTimestep`, `eulerUpdate`

Modified files
- `llmInferenceRuntime`: Expert integration after VLM decode, StopAfterEOS for the `<traj_future_start>` token, single-seq and multi-seq candidate modes, KV cache dump for offline validation
- `llm_inference.cpp`: CLI flags `--expertEngine`, `--numCandidates`, `--numDiffusionSteps`, `--multiSeq`, `--dumpKVCache`
- `QwenViTRunner`: `preprocessPreparedVisual()` for pre-processed multi-camera images (bypasses runtime image decoding)
- `CMakeLists`: curand linkage, SM 110 (Jetson Thor) architecture

Validation
Validated on Jetson AGX Thor with 300 samples:

- FP8 VLM + BF16 Expert, 6 candidates: steady-state ~2.83 s/sample
- Mean minADE 0.799 m, median 0.637 m (aligned with the PyTorch reference: mean 0.827 m, median 0.700 m)
Usage
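A sketch of an invocation using the flags this PR adds. The engine and input paths are placeholders, and any flags the existing `llm_inference` binary already requires (e.g. for the VLM engine) are assumed to be passed as before:

```shell
# Hypothetical paths; only the flags added by this PR are shown.
./llm_inference \
  --expertEngine alpamayo_expert_bf16.engine \
  --numCandidates 6 \
  --numDiffusionSteps 10 \
  --multiSeq \
  --dumpKVCache
```

`--multiSeq` selects the multi-seq candidate mode and `--dumpKVCache` writes the post-decode KV cache for offline validation.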
🚀 Pull Request Checklist
✅ Pre-commit Checks
🧪 Tests
📄 Documentation
⚙️ Compatibility
`--expertEngine` flag

Additional Information
Related issue: #32