111 changes: 80 additions & 31 deletions nemo_rl/algorithms/grpo.py
@@ -2376,6 +2376,7 @@ def async_grpo_train(
     grpo_save_state: GRPOSaveState,
     master_config: MasterConfig,
     max_trajectory_age_steps: int = 1,
+    rlix_hooks: Optional[Any] = None,
 ) -> None:
     """Run asynchronous GRPO training with replay buffer.

@@ -2394,6 +2395,18 @@ def async_grpo_train(
         master_config: Master configuration
         max_trajectory_age_steps: Maximum age (in training steps) for trajectories to be used in training
+        rlix_hooks: Optional hooks object implementing RLixHooksProtocol; None selects the NoOpRLixHooks default (standalone mode)
     """
+    # F5/F11: RLix integration flag.
+    # True when the RLIX_CONTROL_PLANE env var is set to "rlix"; False in standalone mode.
+    # Controls: skip the standalone refit, enable the before/after_training hooks, and
+    # skip prepare_for_generation() / refit_policy_generation(), which conflict
+    # with scheduler-driven sleep/wake.
+    DO_TIME_SHARING: bool = os.environ.get("RLIX_CONTROL_PLANE") == "rlix"
+
+    # F5/F9: Resolve hooks — use the injected real implementation or the no-op default.
+    from nemo_rl.algorithms.rlix_hooks import NoOpRLixHooks, RLixHooksProtocol
+    hooks: RLixHooksProtocol = rlix_hooks if rlix_hooks is not None else NoOpRLixHooks()
+
     # Ensure we are running with a compatible async generation backend
     assert _should_use_async_rollouts(master_config), (
         "Async GRPO requires vLLM backend with vllm_cfg.async_engine=True. "
@@ -2533,36 +2546,44 @@ def async_grpo_train(
     # Ensure collector knows initial weight version
     trajectory_collector.set_weight_version.remote(weight_version)

+    # F6: Register collector handle with pipeline actor so _expand_workers can
+    # call set_weight_version after each selective sync (before routing activation).
+    hooks.on_trajectory_collector_created(trajectory_collector)
+
     print("📦 Started continuous background trajectory collection")

     print(
         f"🚀 Starting async GRPO training with buffer_size={optimal_buffer_size}, max_age={max_trajectory_age_steps} steps"
     )

     print("⏳ Preparing policy generation for training...")
-    if NEED_REFIT and POLICY_GENERATION_STALE:
-        print("🔄 Refitting policy generation with actual model weights...")
-        try:
-            refit_policy_generation(policy, policy_generation, colocated_inference)
-            print("✅ Policy generation refit completed successfully")
-            POLICY_GENERATION_STALE = False
-        except Exception as e:
-            print(f"❌ Policy generation refit failed: {e}")
-            import traceback
-
-            traceback.print_exc()
-            return
-    else:
-        print("🔄 Preparing policy generation for inference...")
-        try:
-            policy_generation.prepare_for_generation()
-            print("✅ Policy generation preparation completed successfully")
-        except Exception as e:
-            print(f"❌ Policy generation preparation failed: {e}")
-            import traceback
-
-            traceback.print_exc()
-            return
+    # F5/F11: In RLix mode, skip the initial refit and prepare_for_generation.
+    # Weights are synced on the first scheduler expand; sleep/wake is scheduler-driven.
+    # Calling prepare_for_generation here would reinitialize already-running inference workers.
+    if not DO_TIME_SHARING:
+        if NEED_REFIT and POLICY_GENERATION_STALE:
+            print("🔄 Refitting policy generation with actual model weights...")
+            try:
+                refit_policy_generation(policy, policy_generation, colocated_inference)
+                print("✅ Policy generation refit completed successfully")
+                POLICY_GENERATION_STALE = False
+            except Exception as e:
+                print(f"❌ Policy generation refit failed: {e}")
+                import traceback
+
+                traceback.print_exc()
+                return
+        else:
+            print("🔄 Preparing policy generation for inference...")
+            try:
+                policy_generation.prepare_for_generation()
+                print("✅ Policy generation preparation completed successfully")
+            except Exception as e:
+                print(f"❌ Policy generation preparation failed: {e}")
+                import traceback
+
+                traceback.print_exc()
+                return

print("✅ Policy generation setup complete, proceeding to validation...")

Expand Down Expand Up @@ -2782,6 +2802,10 @@ def async_grpo_train(

         # Training phase (same as sync version)
         print("▶ Preparing for logprob inference...")
+        # F5: Block until scheduler grants actor_train GPUs.
+        # In RLix mode: scheduler asynchronously shrinks overlap inference
+        # workers before returning. In standalone mode: no-op.
+        hooks.before_training(step)
         with timer.time("logprob_inference_prep"):
             policy.prepare_for_lp_inference()

@@ -2853,7 +2878,30 @@

print("🔄 Synchronizing policy weights to trajectory collector…")
generation_logger_metrics = None
if NEED_REFIT:
if DO_TIME_SHARING:
# F5/F11: RLix mode — replace standalone refit with scheduler-
# driven expand. The scheduler's resize_infer(add=overlap_ranks)
# calls pipeline._expand_workers() which does the atomic
# wake + selective sync + version update + routing activation (F6).
#
# TODO F4: policy.build_cpu_bucket_cache(step)
# self._cache_ready_step = step
# TODO F11: policy.offload_training_gpu()
# policy.destroy_nccl_groups()
with timer.time("weight_sync"):
# Notify scheduler: actor_train GPUs are free.
# Scheduler asynchronously triggers expand + weight sync.
published_version = hooks.after_training(step)
# RLix publishes version=cache_ready_step after active
# refresh completes. Fall back to step for older hooks.
weight_version = (
int(published_version)
if published_version is not None
else int(step)
)
POLICY_GENERATION_STALE = False
elif NEED_REFIT:
# Standalone mode — original refit path.
# Measure pending-generation wait as exposed_generation time
print("🔄 Coordinating with trajectory collector before refit...")
with timer.time("exposed_generation"):
@@ -2894,13 +2942,14 @@
             # Pause trajectory collection during validation to reduce memory pressure
             trajectory_collector.pause.remote()

-            if NEED_REFIT and POLICY_GENERATION_STALE:
-                refit_policy_generation(
-                    policy, policy_generation, colocated_inference
-                )
-                POLICY_GENERATION_STALE = False
-            else:
-                policy_generation.prepare_for_generation()
+            if not DO_TIME_SHARING:
+                if NEED_REFIT and POLICY_GENERATION_STALE:
+                    refit_policy_generation(
+                        policy, policy_generation, colocated_inference
+                    )
+                    POLICY_GENERATION_STALE = False
+                else:
+                    policy_generation.prepare_for_generation()
             val_metrics, validation_timings = validate(
                 policy_generation,
                 val_dataloader,
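Taken together, the hunks above reduce the RLix integration to a small contract: resolve a hooks object once, then bracket each train step with `before_training` / `after_training`. A minimal caller-side sketch of the resolution logic, assuming only what the diff shows; the `NemoRLRLixHooks` import path comes from the docstrings in `rlix_hooks.py` below, and its no-argument constructor is an assumption, since this PR does not show its signature:

```python
import os

from nemo_rl.algorithms.rlix_hooks import NoOpRLixHooks, RLixHooksProtocol


def resolve_rlix_hooks() -> RLixHooksProtocol:
    """Mirror async_grpo_train's hook resolution for an external launcher."""
    if os.environ.get("RLIX_CONTROL_PLANE") == "rlix":
        # Real implementation lives in rlix/pipeline/nemo_rl_pipeline.py per
        # the module docstring below; the no-arg constructor is an assumption.
        from rlix.pipeline.nemo_rl_pipeline import NemoRLRLixHooks

        return NemoRLRLixHooks()
    return NoOpRLixHooks()
```

Passing the result as `rlix_hooks=` keeps grpo.py free of any direct rlix import, preserving the one-way dependency described in the new module.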
81 changes: 81 additions & 0 deletions nemo_rl/algorithms/rlix_hooks.py
@@ -0,0 +1,81 @@
"""RLix hook protocol and no-op default for NeMo RL's async_grpo_train.

This module is the seam between NeMo RL (caller) and RLix (implementor).
Import direction:
nemo_rl/algorithms/grpo.py → rlix_hooks.py (this file, NeMo RL repo)
rlix/pipeline/nemo_rl_pipeline.py → provides NemoRLRLixHooks (real impl)

NeMo RL code only depends on RLixHooksProtocol + NoOpRLixHooks from this file.
It never imports from the rlix package directly, preventing circular dependencies.
"""
from __future__ import annotations

from typing import Any, runtime_checkable

from typing_extensions import Protocol


@runtime_checkable
class RLixHooksProtocol(Protocol):
"""Protocol that async_grpo_train expects from its rlix_hooks argument.

NeMo RL standalone mode uses NoOpRLixHooks (all methods are no-ops).
RLix mode passes NemoRLRLixHooks (rlix/pipeline/nemo_rl_pipeline.py),
which makes blocking Ray RPC calls to the scheduler.
"""

def before_training(self, step: int) -> None:
"""Called before logprob inference + training; may block on scheduler.

F5: in RLix mode, blocks until the scheduler grants actor_train GPUs.
The scheduler shrinks overlap inference workers before returning.
In standalone mode, this is a no-op.
"""
...

def after_training(self, step: int) -> int | None:
"""Called after policy.train() completes; notifies scheduler to expand.

F5: in RLix mode, notifies the scheduler that actor_train GPUs are
released. The scheduler asynchronously calls coordinator.resize_infer
(add=overlap_ranks), which routes to pipeline._expand_workers() (F6).
Weight sync and version update happen inside _expand_workers.
In standalone mode, this is a no-op.

Preconditions (must be satisfied before calling in RLix mode):
- CPU bucket cache built (TODO F4: policy.build_cpu_bucket_cache)
- Training GPU VRAM offloaded (TODO F11: policy.offload_training_gpu)
- Megatron NCCL groups destroyed (TODO F11: destroy_nccl_groups)
"""
...

def on_trajectory_collector_created(self, collector: Any) -> None:
"""Register the trajectory collector Ray actor handle with the pipeline.

F6 dependency: _expand_workers calls collector.set_weight_version after
each selective sync. The handle must be registered here before the
first expand fires, otherwise _expand_workers logs a warning and skips
the version update.

Called once, immediately after AsyncTrajectoryCollector is created and
set_weight_version has been called with the initial value.
"""
...


class NoOpRLixHooks:
"""Default no-op implementation used in NeMo RL standalone mode.

Satisfies RLixHooksProtocol so grpo.py always calls hooks.* without
guarding against None. In standalone mode these are all no-ops;
the real implementations live in rlix/pipeline/nemo_rl_pipeline.py.
"""

def before_training(self, step: int) -> None:
pass

def after_training(self, step: int) -> int | None:
return None

def on_trajectory_collector_created(self, collector: Any) -> None:
pass
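Because the protocol is decorated with `@runtime_checkable`, a caller can sanity-check any third-party hooks object with `isinstance` before handing it to `async_grpo_train`. A small illustrative test double; `RecordingHooks` is hypothetical and not part of this PR:

```python
from nemo_rl.algorithms.rlix_hooks import NoOpRLixHooks, RLixHooksProtocol


class RecordingHooks:
    """Hypothetical structural implementation that records hook calls."""

    def __init__(self) -> None:
        self.calls: list[tuple[str, int]] = []

    def before_training(self, step: int) -> None:
        self.calls.append(("before_training", step))

    def after_training(self, step: int) -> int | None:
        self.calls.append(("after_training", step))
        return step  # pretend the scheduler published version == step

    def on_trajectory_collector_created(self, collector: object) -> None:
        self.calls.append(("collector_created", id(collector)))


# Note: a runtime_checkable isinstance check only verifies that the method
# names exist; it does not validate signatures or return types.
assert isinstance(NoOpRLixHooks(), RLixHooksProtocol)
assert isinstance(RecordingHooks(), RLixHooksProtocol)
```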
149 changes: 149 additions & 0 deletions nemo_rl/models/generation/vllm/vllm_backend.py
@@ -309,6 +309,155 @@ def _load_model_weights(weights, model_runner):

         return True

+    def setup_collective_group(
+        self,
+        model_update_name: str,
+        comm_plan: dict[int, Any],
+        mode: str,
+        timeout_s: float | None = None,
+        dp_rank: int | None = None,
+    ) -> bool:
+        """Create a temporary RLix model-update NCCL group for this vLLM rank."""
+        del timeout_s  # StatelessProcessGroup does not expose timeout control.
+        from nemo_rl.distributed.stateless_process_group import StatelessProcessGroup
+
+        if len(comm_plan) != 1:
+            raise ValueError(
+                "RLix model-update receiver expects a single owner comm plan; "
+                f"got {len(comm_plan)} entries"
+            )
+        owner_plan = next(iter(comm_plan.values()))
+        local_rank = int(torch.distributed.get_rank())
+        rank = None
+        if mode == "sender":
+            rank = 0
+        else:
+            if dp_rank is None:
+                raise ValueError("dp_rank is required for receiver setup")
+            local_ranks = owner_plan.get("broadcast_local_ranks_by_dp_rank", {}).get(
+                int(dp_rank), []
+            )
+            if local_rank not in [int(x) for x in local_ranks]:
+                return True
+            devices = owner_plan.get("tgt_devices", [])
+            ordered = [
+                (int(item["rank"]), int(item["device"]))
+                for item in devices
+            ]
+            try:
+                rank = 1 + ordered.index((int(dp_rank), local_rank))
+            except ValueError:
+                return True
+
+        groups = getattr(self, "_rlix_model_update_groups", None)
+        if groups is None:
+            groups = {}
+            self._rlix_model_update_groups = groups
+        group = StatelessProcessGroup(
+            master_address=str(owner_plan["master_addr"]),
+            port=int(owner_plan["master_port"]),
+            rank=int(rank),
+            world_size=1 + len(owner_plan.get("tgt_devices", [])),
+        )
+        group.init_nccl_communicator(device=self.device)
+        groups[str(model_update_name)] = group
+        return True
+
+    def update_parameter_in_bucket(
+        self,
+        payload_list: list[Any],
+        is_lora: bool = False,
+        ipc_local_ranks: list[int] | None = None,
+        model_update_transport: str | None = None,
+    ) -> bool:
+        """Apply one IPC-delivered RLix weight bucket to selected local ranks."""
+        del is_lora, model_update_transport
+        local_rank = int(torch.distributed.get_rank())
+        if ipc_local_ranks is not None and local_rank not in [int(x) for x in ipc_local_ranks]:
+            return True
+        weights: list[tuple[str, torch.Tensor]] = []
+        for item in payload_list:
+            if isinstance(item, tuple) and len(item) == 2:
+                weights.append(item)
+            elif isinstance(item, dict):
+                name = item.get("name") or item.get("key")
+                tensor = item.get("tensor")
+                if tensor is None:
+                    tensor = item.get("value")
+                if name is not None and tensor is not None:
+                    weights.append((str(name), tensor))
+        policy_weights, draft_weights = self._split_policy_and_draft_weights(weights)
+        from nemo_rl.models.generation.vllm.quantization import fp8
+
+        if fp8.is_fp8_model(self.model_runner.vllm_config):
+            fp8.load_weights(policy_weights, self.model_runner)
+        else:
+            self.model_runner.model.load_weights(weights=policy_weights)
+        self._load_draft_weights(draft_weights)
+        torch.cuda.current_stream().synchronize()
+        return True
+
+    def broadcast_parameter(
+        self,
+        group_name: str,
+        names: list[str],
+        dtypes: list[Any],
+        shapes: list[Any],
+        is_lora: bool = False,
+        broadcast_local_ranks: list[int] | None = None,
+    ) -> bool:
+        """Receive one RLix broadcast bucket and apply it to selected local ranks."""
+        del is_lora
+        local_rank = int(torch.distributed.get_rank())
+        if broadcast_local_ranks is not None and local_rank not in [int(x) for x in broadcast_local_ranks]:
+            return True
+        groups = getattr(self, "_rlix_model_update_groups", {})
+        group = groups.get(str(group_name))
+        if group is None:
+            raise RuntimeError(f"RLix model update group {group_name!r} is not initialized")
+        state_dict_info = {
+            str(name): (torch.Size(shape), dtype)
+            for name, dtype, shape in zip(names, dtypes, shapes)
+        }
+
+        def _load(weights: list[tuple[str, torch.Tensor]]) -> None:
+            self.update_parameter_in_bucket(weights)
+
+        packed_broadcast_consumer(
+            iterator=iter(state_dict_info.items()),
+            group=group,
+            src=0,
+            post_unpack_func=_load,
+        )
+        return True
+
+    def destroy_collective_group(self, group_name: str) -> bool:
+        """Destroy a temporary RLix model-update group if this rank joined it."""
+        groups = getattr(self, "_rlix_model_update_groups", None)
+        if not groups:
+            return True
+        groups.pop(str(group_name), None)
+        gc.collect()
+        torch.cuda.empty_cache()
+        return True
+
+    def verify_model(self, expected_stats: dict[str, Any]) -> bool:
+        """Receiver API placeholder for RLix checksum verification."""
+        del expected_stats
+        return True
+
+    def finalize_weight_update(self) -> bool:
+        """Run vLLM post-load hooks once after all RLix buckets are applied."""
+        from vllm.model_executor.model_loader.utils import process_weights_after_loading
+
+        process_weights_after_loading(
+            self.model_runner.model, self.model_config, self.device
+        )
+        self._maybe_process_fp8_kv_cache()
+        gc.collect()
+        torch.cuda.empty_cache()
+        return True
+
     def cleanup(self) -> None:
         """Shutdown and cleanup resources."""
         # Close ZMQ socket and context if they exist
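The five receiver methods above form a fixed per-refresh sequence on each vLLM rank: join the temporary group, receive each bucket, run the post-load hooks once, then tear the group down. A condensed driver-side sketch of that ordering; the function, the bucket structure, and the group-name string are assumed from the signatures above, not from any RLix driver shown in this PR:

```python
from typing import Any


def apply_weight_refresh(
    worker: Any,  # the vLLM worker extension defined above (hypothetical handle)
    comm_plan: dict[int, Any],
    buckets: list[tuple[list[str], list[Any], list[Any]]],  # (names, dtypes, shapes)
    dp_rank: int,
) -> None:
    group_name = "rlix_update"  # illustrative; any stable string key works
    # 1. Join the temporary NCCL group (returns early on non-target ranks).
    worker.setup_collective_group(
        model_update_name=group_name,
        comm_plan=comm_plan,
        mode="receiver",
        dp_rank=dp_rank,
    )
    # 2. Receive and load each packed bucket broadcast from the sender (rank 0).
    for names, dtypes, shapes in buckets:
        worker.broadcast_parameter(
            group_name=group_name,
            names=names,
            dtypes=dtypes,
            shapes=shapes,
        )
    # 3. Run vLLM post-load processing exactly once, then drop the group.
    worker.finalize_weight_update()
    worker.destroy_collective_group(group_name)
```

Note the naming asymmetry in the code above: `setup_collective_group` stores the group under `model_update_name`, while `broadcast_parameter` and `destroy_collective_group` look it up by `group_name`, so callers must pass the same string to all three.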