
[Feature] Trajectory batcher#3584

Merged
vmoens merged 7 commits into pytorch:main from theap06:feat/trajectory-batcher
Mar 31, 2026

Conversation

@theap06
Contributor

@theap06 theap06 commented Mar 30, 2026

Description

Adds TrajectoryBatcher to torchrl/collectors/utils.py (and exports it from torchrl.collectors). It wraps any TorchRL collector and yields batches of exactly num_trajectories complete episodes, reconstructed across collector iterations using ("collector", "traj_ids"). Output is a zero-padded [N, max_len, ...] TensorDict with a ("collector", "mask") boolean field marking valid time steps.

```python
from torchrl.collectors import Collector, TrajectoryBatcher

collector = Collector(env_fn, policy, frames_per_batch=200, total_frames=10_000)
batcher = TrajectoryBatcher(collector, num_trajectories=32, strict_on_policy=True)

for batch in batcher:
    # batch.shape == [32, max_episode_len]
    # batch[("collector", "mask")] marks valid steps
    loss = compute_loss(batch)
    loss.backward()
    optimizer.step()
    batcher.update_policy_weights_()  # burns partial trajectories when strict_on_policy=True
```
Fixes #3234

Motivation and Context
Many RL algorithms (Monte Carlo / REINFORCE, episodic PPO, imitation learning, RLHF) require full episodes padded into a single batch. Currently, users must manually track traj_ids, rebuild trajectories across collector iterations, detect episode completion via done/mask, pad variable-length episodes, and optionally discard mixed-policy episodes when the policy is updated. This logic is error-prone and reimplemented repeatedly.

There is no existing high-level utility in TorchRL that provides "give me exactly N full episodes as a padded TensorDict", despite the collectors already exposing all the needed information (traj_ids, done, mask).
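For concreteness, the bookkeeping that users currently have to hand-roll looks roughly like this. This is a framework-free sketch using plain Python dicts and lists in place of TensorDicts; all names (`rebatch_trajectories`, the step-dict keys) are illustrative, not TorchRL API:

```python
from collections import defaultdict

def rebatch_trajectories(collector_batches, num_trajectories):
    """Accumulate steps per traj_id across collector iterations and yield
    groups of exactly `num_trajectories` finished, padded episodes."""
    partial = defaultdict(list)   # traj_id -> list of per-step records
    finished = []                 # completed episodes, in completion order

    for batch in collector_batches:      # each batch: a list of step dicts
        for step in batch:
            tid = step["traj_id"]
            partial[tid].append(step)
            if step["done"]:             # episode boundary detected
                finished.append(partial.pop(tid))
        while len(finished) >= num_trajectories:
            group = finished[:num_trajectories]
            finished = finished[num_trajectories:]
            max_len = max(len(ep) for ep in group)
            # pad to [N, max_len] and build the validity mask
            padded = [ep + [None] * (max_len - len(ep)) for ep in group]
            mask = [[t < len(ep) for t in range(max_len)] for ep in group]
            yield padded, mask
```

The proposed utility does the tensorized version of this on TensorDicts, which is exactly the logic each project currently reimplements.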

… any TorchRL collector and yields batches of exactly N complete, padded trajectories. It reconstructs episodes split across collector iterations using traj_ids, supports strict on-policy mode (discard partials on weight update), and outputs standard TensorDict batches with a ("collector", "mask") field.
@pytorch-bot

pytorch-bot bot commented Mar 30, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3584

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Cancelled Job, 2 Pending

As of commit b673adf with merge base 143b67e:

NEW FAILURE - The following job has failed:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label Mar 30, 2026
@github-actions
Contributor

⚠️ PR Title Label Error

PR title must start with a label prefix in brackets (e.g., [BugFix]).

Current title: Feat/trajectory batcher

Supported Prefixes (case-sensitive)

Your PR title must start with exactly one of these prefixes:

| Prefix | Label Applied | Example |
| --- | --- | --- |
| [BugFix] | BugFix | [BugFix] Fix memory leak in collector |
| [Feature] | Feature | [Feature] Add new optimizer |
| [Doc] or [Docs] | Documentation | [Doc] Update installation guide |
| [Refactor] | Refactoring | [Refactor] Clean up module imports |
| [CI] | CI | [CI] Fix workflow permissions |
| [Test] or [Tests] | Tests | [Tests] Add unit tests for buffer |
| [Environment] or [Environments] | Environments | [Environments] Add Gymnasium support |
| [Data] | Data | [Data] Fix replay buffer sampling |
| [Performance] or [Perf] | Performance | [Performance] Optimize tensor ops |
| [BC-Breaking] | bc breaking | [BC-Breaking] Remove deprecated API |
| [Deprecation] | Deprecation | [Deprecation] Mark old function |
| [Quality] | Quality | [Quality] Fix typos and add codespell |

Note: Common variations like singular/plural are supported (e.g., [Doc] or [Docs]).

@theap06 theap06 changed the title from Feat/trajectory batcher to [Feature Request] Feat/trajectory batcher Mar 30, 2026
@github-actions
Contributor

⚠️ PR Title Label Error

Unknown or invalid prefix [Feature Request].

Current title: [Feature Request] Feat/trajectory batcher

Supported Prefixes (case-sensitive)

The supported prefix table and note are the same as in the bot's previous comment above.

@theap06 theap06 changed the title from [Feature Request] Feat/trajectory batcher to [Feature] Feat/trajectory batcher Mar 30, 2026
@github-actions github-actions bot added the Feature label Mar 30, 2026
@vmoens vmoens changed the title from [Feature] Feat/trajectory batcher to [Feature] Trajectory batcher Mar 30, 2026
Collaborator

@vmoens vmoens left a comment


Thanks for this!
Curious to hear your thoughts about the need for a new class instead of a new kwarg in collector build functions?

@github-actions github-actions bot added the Documentation label Mar 30, 2026
@theap06
Contributor Author

theap06 commented Mar 30, 2026

Thanks for this!
Curious to hear your thoughts about the need for a new class instead of a new kwarg in collector build functions?

TrajectoryBatcher works as a wrapper, which avoids modifying every collector class; it handles the cross-batch trajectory reassembly and provides the strict_on_policy bookkeeping. That can't easily be done by adding a kwarg.

@theap06 theap06 requested a review from vmoens March 30, 2026 09:10
@vmoens
Collaborator

vmoens commented Mar 30, 2026

TrajectoryBatcher works as a wrapper, which avoids modifying every collector class; it handles the cross-batch trajectory reassembly and provides the strict_on_policy bookkeeping. That can't easily be done by adding a kwarg.

Is it as efficient though? Can we make this a post-process rather than a wrapper?
I'm asking this because wrappers have a lot of unforeseen consequences like breaking isinstance checks, obscuring the nature of the object they wrap etc. For example weight sync, state-dict, etc are all kinds of things that we'll need to wire up.
Adding the fact that we have collectors that have that kwarg already so the "two ways to do the same thing" may be puzzling and drive people away from collectors entirely.

@theap06
Contributor Author

theap06 commented Mar 30, 2026

TrajectoryBatcher works as a wrapper, which avoids modifying every collector class; it handles the cross-batch trajectory reassembly and provides the strict_on_policy bookkeeping. That can't easily be done by adding a kwarg.

Is it as efficient though? Can we make this a post-process rather than a wrapper? I'm asking this because wrappers have a lot of unforeseen consequences like breaking isinstance checks, obscuring the nature of the object they wrap etc. For example weight sync, state-dict, etc are all kinds of things that we'll need to wire up. Adding the fact that we have collectors that have that kwarg already so the "two ways to do the same thing" may be puzzling and drive people away from collectors entirely.

TrajectoryBatcher is gone; the logic now lives as a num_trajectories_per_batch kwarg on all collectors, implemented in BaseCollector's __iter__ using private stateful helpers.
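To illustrate the design that was settled on, here is a toy, framework-free mock of a collector whose iterator re-batches complete episodes through a private stateful helper. All names and structure here are illustrative only; the real implementation in TorchRL's BaseCollector differs in detail:

```python
class _TrajBuffer:
    """Private stateful helper: holds partial episodes across iterations."""

    def __init__(self, num_trajectories):
        self.num_trajectories = num_trajectories
        self.partial = {}      # traj_id -> list of steps seen so far
        self.finished = []     # completed episodes, in completion order

    def extend(self, steps):
        for step in steps:
            self.partial.setdefault(step["traj_id"], []).append(step)
            if step["done"]:
                self.finished.append(self.partial.pop(step["traj_id"]))

    def pop_batch(self):
        if len(self.finished) < self.num_trajectories:
            return None
        batch = self.finished[:self.num_trajectories]
        self.finished = self.finished[self.num_trajectories:]
        return batch


class MiniCollector:
    """Toy collector: yields raw frame batches by default, or re-batched
    full episodes when num_trajectories_per_batch is set."""

    def __init__(self, frame_batches, num_trajectories_per_batch=None):
        self.frame_batches = frame_batches
        self.num_trajectories_per_batch = num_trajectories_per_batch

    def __iter__(self):
        if self.num_trajectories_per_batch is None:
            yield from self.frame_batches
            return
        buf = _TrajBuffer(self.num_trajectories_per_batch)
        for frames in self.frame_batches:
            buf.extend(frames)
            while (batch := buf.pop_batch()) is not None:
                yield batch
```

Keeping the helper private and driving it from `__iter__` preserves the collector's type and public surface (isinstance checks, weight sync, state-dict), which was the concern with the wrapper approach.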

Collaborator

@vmoens vmoens left a comment


I love it!
Some comments on tests and missing items from docstrings.
Thanks for this!

@theap06 theap06 requested a review from vmoens March 31, 2026 06:29
Collaborator

@vmoens vmoens left a comment


Amazing!
Now I think a follow-up PR would be nice; it could cover using this to fill a replay buffer with trajectories (using collector.start() for async collection).
This would work best with SliceSampler and such traj-oriented tooling in the RB API.
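For readers unfamiliar with that pattern: a slice sampler draws fixed-length contiguous windows from stored trajectories. A framework-free sketch of the idea follows; the names are illustrative only, not the actual SliceSampler / replay buffer API:

```python
import random

def sample_slices(trajectories, slice_len, num_slices, seed=0):
    """Sample fixed-length contiguous windows from whole stored episodes,
    mimicking what a slice-based trajectory sampler provides."""
    rng = random.Random(seed)
    # only episodes long enough to contain a full window are eligible
    eligible = [ep for ep in trajectories if len(ep) >= slice_len]
    slices = []
    for _ in range(num_slices):
        ep = rng.choice(eligible)
        start = rng.randrange(len(ep) - slice_len + 1)
        slices.append(ep[start:start + slice_len])
    return slices
```

Storing whole trajectories (as this PR's re-batching enables) is what makes this kind of window sampling well-defined, since windows never straddle episode boundaries.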

@theap06 theap06 requested a review from vmoens March 31, 2026 08:18
@vmoens vmoens merged commit e2c8a8d into pytorch:main Mar 31, 2026
118 of 121 checks passed

Labels

CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Collectors
Documentation: Improvements or additions to documentation
Feature: New feature

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature Request] **TrajectoryBatcher** high-level util for complete episode batching

2 participants