Conversation
… any TorchRL collector and yields batches of exactly N complete, padded trajectories. It reconstructs episodes split across collector iterations using traj_ids, supports strict on-policy mode (discard partials on weight update), and outputs standard TensorDict batches with a (collector, mask) field.
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3584
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Cancelled Job, 2 Pending as of commit b673adf with merge base 143b67e.
NEW FAILURE - The following job has failed:
CANCELLED JOB - The following job was cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
| Prefix | Label Applied | Example |
|---|---|---|
| [BugFix] | BugFix | [BugFix] Fix memory leak in collector |
| [Feature] | Feature | [Feature] Add new optimizer |
| [Doc] or [Docs] | Documentation | [Doc] Update installation guide |
| [Refactor] | Refactoring | [Refactor] Clean up module imports |
| [CI] | CI | [CI] Fix workflow permissions |
| [Test] or [Tests] | Tests | [Tests] Add unit tests for buffer |
| [Environment] or [Environments] | Environments | [Environments] Add Gymnasium support |
| [Data] | Data | [Data] Fix replay buffer sampling |
| [Performance] or [Perf] | Performance | [Performance] Optimize tensor ops |
| [BC-Breaking] | bc breaking | [BC-Breaking] Remove deprecated API |
| [Deprecation] | Deprecation | [Deprecation] Mark old function |
| [Quality] | Quality | [Quality] Fix typos and add codespell |

Note: Common variations like singular/plural are supported (e.g., [Doc] or [Docs]).
vmoens left a comment
Thanks for this!
Curious to hear your thoughts about the need for a new class instead of a new kwarg in collector build functions?
As a wrapper, TrajectoryBatcher avoids modifying every collector class; it handles the cross-batch trajectory reassembly and provides the strict_on_policy bookkeeping. We can't do that easily by adding a kwarg.
Is it as efficient though? Can we make this a post-process rather than a wrapper?
TrajectoryBatcher is gone; the logic now lives behind a num_trajectories_per_batch kwarg on all collectors, implemented in BaseCollector's __iter__ using private stateful helpers.
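The bookkeeping behind that kwarg can be sketched in plain Python. This is a simplified model, not the TorchRL implementation: `batch_trajectories` is a hypothetical helper, frames are `(traj_id, value, done)` tuples rather than TensorDicts, and padding is done with nested lists.

```python
from collections import defaultdict

def batch_trajectories(collector_iters, num_trajectories):
    """Reassemble trajectories split across collector iterations and
    yield zero-padded batches of exactly `num_trajectories` complete
    episodes, each with a validity mask.  Toy model only."""
    partial = defaultdict(list)   # traj_id -> frames collected so far
    complete = []                 # finished episodes, in completion order

    for frames in collector_iters:
        for traj_id, value, done in frames:
            partial[traj_id].append(value)
            if done:
                complete.append(partial.pop(traj_id))
        # Emit a batch whenever enough full episodes are available.
        while len(complete) >= num_trajectories:
            batch = complete[:num_trajectories]
            complete = complete[num_trajectories:]
            max_len = max(len(t) for t in batch)
            # Zero-pad to [N, max_len] and mark valid steps.
            padded = [t + [0] * (max_len - len(t)) for t in batch]
            mask = [[i < len(t) for i in range(max_len)] for t in batch]
            yield padded, mask

# Two collector iterations; trajectory 1 spans both of them.
iters = [
    [(0, 10, False), (0, 11, True), (1, 20, False)],
    [(1, 21, False), (1, 22, True)],
]
batches = list(batch_trajectories(iters, num_trajectories=2))
```

Trajectory 1 is started in the first iteration and finished in the second, so the only batch contains one padded episode `[10, 11, 0]` and one full-length episode `[20, 21, 22]`.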
vmoens left a comment
I love it!
Some comments on tests and missing items from docstrings.
Thanks for this!
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
vmoens left a comment
Amazing!
Now I think a follow-up PR would be nice: it could cover using this to fill a replay buffer with trajectories (using collector.start() for async collection).
This would work best with SliceSampler and similar trajectory-oriented tooling in the RB API.
Description
Adds TrajectoryBatcher to torchrl/collectors/utils.py (and exports it from torchrl.collectors). It wraps any TorchRL collector and yields batches of exactly num_trajectories complete episodes, reconstructed across collector iterations using ("collector", "traj_ids"). Output is a zero-padded [N, max_len, ...] TensorDict with a ("collector", "mask") boolean field marking valid time steps.
```python
from torchrl.collectors import Collector, TrajectoryBatcher

collector = Collector(env_fn, policy, frames_per_batch=200, total_frames=10_000)
batcher = TrajectoryBatcher(collector, num_trajectories=32, strict_on_policy=True)

for batch in batcher:
    # batch.shape == [32, max_episode_len]
    # batch[("collector", "mask")] marks valid steps
    loss = compute_loss(batch)
    loss.backward()
    optimizer.step()
    batcher.update_policy_weights_()  # burns partial trajectories when strict_on_policy=True
```
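The strict_on_policy rule amounts to discarding any partially collected episode whenever the policy weights change, so that no yielded trajectory mixes actions from two policy versions. A minimal model of that rule (plain Python; `StrictOnPolicyBuffer` is a hypothetical illustration, not the TorchRL code):

```python
class StrictOnPolicyBuffer:
    """Track per-trajectory frames; drop partials on weight update."""

    def __init__(self):
        self.partial = {}    # traj_id -> frames from the current policy
        self.complete = []   # finished episodes, safe to yield

    def add(self, traj_id, frame, done):
        self.partial.setdefault(traj_id, []).append(frame)
        if done:
            self.complete.append(self.partial.pop(traj_id))

    def update_policy_weights_(self):
        # Any episode still in flight would mix old- and new-policy
        # actions once collection resumes, so burn it.
        dropped = len(self.partial)
        self.partial.clear()
        return dropped

buf = StrictOnPolicyBuffer()
buf.add(0, "s0", done=False)   # trajectory 0 still in flight
buf.add(1, "s1", done=True)    # trajectory 1 complete
dropped = buf.update_policy_weights_()  # trajectory 0 is discarded
```

Completed episodes survive the update; only the in-flight trajectory is dropped.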
Fixes #3234
Motivation and Context
Many RL algorithms (Monte Carlo / REINFORCE, episodic PPO, imitation learning, RLHF) require full episodes padded into a single batch. Currently users must manually:
- track traj_ids,
- rebuild trajectories across collector iterations,
- detect episode completion via done/mask,
- pad variable-length episodes, and
- optionally discard mixed-policy episodes when the policy is updated.

This logic is error-prone and reimplemented repeatedly.
There is no existing high-level utility in TorchRL that provides "give me exactly N full episodes as a padded TensorDict", despite the collectors already exposing all the needed information (traj_ids, done, mask).
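Padding is only half the story: every downstream computation must also respect the mask so that padded steps contribute nothing. A masked discounted-return pass over a padded [N, max_len] batch can be sketched in plain Python (illustrative only; `discounted_returns`, the nested-list representation, and the sample numbers are assumptions, not TorchRL API):

```python
def discounted_returns(rewards, mask, gamma=0.99):
    """Per-step discounted returns over a zero-padded batch.

    `rewards` and `mask` are [N][max_len] nested lists; padded steps
    (mask False) are skipped and keep a return of 0.0.
    """
    returns = []
    for ep_rewards, ep_mask in zip(rewards, mask):
        g = 0.0
        ep_returns = [0.0] * len(ep_rewards)
        # Walk backwards so each step accumulates its discounted future.
        for t in reversed(range(len(ep_rewards))):
            if ep_mask[t]:
                g = ep_rewards[t] + gamma * g
                ep_returns[t] = g
        returns.append(ep_returns)
    return returns

rewards = [[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]      # row 0 is padded at t=2
mask = [[True, True, False], [True, True, True]]
rets = discounted_returns(rewards, mask, gamma=0.5)
```

With gamma=0.5 the first (two-step) episode yields returns [1.5, 1.0, 0.0] and the second [1.75, 1.5, 1.0]; the padded slot stays at 0.0 because the mask excludes it.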