diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index 85953eb0ce..a4acd737ba 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -2376,6 +2376,7 @@ def async_grpo_train(
     grpo_save_state: GRPOSaveState,
     master_config: MasterConfig,
     max_trajectory_age_steps: int = 1,
+    rlix_hooks: Optional[Any] = None,
 ) -> None:
     """Run asynchronous GRPO training with replay buffer.
 
@@ -2394,6 +2395,17 @@ def async_grpo_train(
         master_config: Master configuration
         max_trajectory_age_steps: Maximum age (in training steps) for trajectories to be used in training
+        rlix_hooks: Optional hooks satisfying RLixHooksProtocol; defaults to no-op hooks in standalone mode
     """
+    # F5/F11: RLix integration flag.
+    # True when the RLIX_CONTROL_PLANE=rlix env var is set; False in standalone mode.
+    # Controls: skip standalone refit, enable before/after_training hooks, and
+    # skip prepare_for_generation() / refit_policy_generation(), which conflict
+    # with scheduler-driven sleep/wake.
+    DO_TIME_SHARING: bool = os.environ.get("RLIX_CONTROL_PLANE") == "rlix"
+
+    # F5/F9: Resolve hooks: use the injected real implementation or the no-op default.
+    from nemo_rl.algorithms.rlix_hooks import NoOpRLixHooks, RLixHooksProtocol
+
+    hooks: RLixHooksProtocol = rlix_hooks if rlix_hooks is not None else NoOpRLixHooks()
+
     # Ensure we are running with a compatible async generation backend
     assert _should_use_async_rollouts(master_config), (
         "Async GRPO requires vLLM backend with vllm_cfg.async_engine=True. "
@@ -2533,6 +2545,10 @@ def async_grpo_train(
     # Ensure collector knows initial weight version
     trajectory_collector.set_weight_version.remote(weight_version)
 
+    # F6: Register the collector handle with the pipeline actor so _expand_workers can
+    # call set_weight_version after each selective sync (before routing activation).
+    hooks.on_trajectory_collector_created(trajectory_collector)
+
     print("📦 Started continuous background trajectory collection")
     print(
@@ -2540,29 +2556,33 @@ def async_grpo_train(
     )
 
     print("⏳ Preparing policy generation for training...")
-    if NEED_REFIT and POLICY_GENERATION_STALE:
-        print("🔄 Refitting policy generation with actual model weights...")
-        try:
-            refit_policy_generation(policy, policy_generation, colocated_inference)
-            print("✅ Policy generation refit completed successfully")
-            POLICY_GENERATION_STALE = False
-        except Exception as e:
-            print(f"❌ Policy generation refit failed: {e}")
-            import traceback
-
-            traceback.print_exc()
-            return
-    else:
-        print("🔄 Preparing policy generation for inference...")
-        try:
-            policy_generation.prepare_for_generation()
-            print("✅ Policy generation preparation completed successfully")
-        except Exception as e:
-            print(f"❌ Policy generation preparation failed: {e}")
-            import traceback
-
-            traceback.print_exc()
-            return
+    # F5/F11: In RLix mode, skip the initial refit and prepare_for_generation.
+    # Weights are synced on the first scheduler expand; sleep/wake is scheduler-driven.
+    # Calling prepare_for_generation here would reinitialize already-running inference workers.
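+    # Illustrative wiring (an assumption about launcher-side code, which is not
+    # part of this patch; the scheduler argument is hypothetical):
+    #   os.environ["RLIX_CONTROL_PLANE"] = "rlix"
+    #   async_grpo_train(..., rlix_hooks=NemoRLRLixHooks(scheduler))
+    # Standalone launches leave both unset, so the original path below runs unchanged.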
+    if not DO_TIME_SHARING:
+        if NEED_REFIT and POLICY_GENERATION_STALE:
+            print("🔄 Refitting policy generation with actual model weights...")
+            try:
+                refit_policy_generation(policy, policy_generation, colocated_inference)
+                print("✅ Policy generation refit completed successfully")
+                POLICY_GENERATION_STALE = False
+            except Exception as e:
+                print(f"❌ Policy generation refit failed: {e}")
+                import traceback
+
+                traceback.print_exc()
+                return
+        else:
+            print("🔄 Preparing policy generation for inference...")
+            try:
+                policy_generation.prepare_for_generation()
+                print("✅ Policy generation preparation completed successfully")
+            except Exception as e:
+                print(f"❌ Policy generation preparation failed: {e}")
+                import traceback
+
+                traceback.print_exc()
+                return
 
     print("✅ Policy generation setup complete, proceeding to validation...")
 
@@ -2782,6 +2802,10 @@ def async_grpo_train(
             # Training phase (same as sync version)
             print("▶ Preparing for logprob inference...")
 
+            # F5: Block until the scheduler grants actor_train GPUs.
+            # In RLix mode, the scheduler asynchronously shrinks overlap inference
+            # workers before returning. In standalone mode, this is a no-op.
+            hooks.before_training(step)
             with timer.time("logprob_inference_prep"):
                 policy.prepare_for_lp_inference()
 
@@ -2853,7 +2877,30 @@ def async_grpo_train(
             print("🔄 Synchronizing policy weights to trajectory collector…")
             generation_logger_metrics = None
-            if NEED_REFIT:
+            if DO_TIME_SHARING:
+                # F5/F11: RLix mode: replace the standalone refit with a
+                # scheduler-driven expand. The scheduler's resize_infer(add=overlap_ranks)
+                # calls pipeline._expand_workers(), which does the atomic
+                # wake + selective sync + version update + routing activation (F6).
+                #
+                # TODO F4: policy.build_cpu_bucket_cache(step)
+                #          self._cache_ready_step = step
+                # TODO F11: policy.offload_training_gpu()
+                #           policy.destroy_nccl_groups()
+                with timer.time("weight_sync"):
+                    # Notify the scheduler that actor_train GPUs are free.
+                    # The scheduler asynchronously triggers expand + weight sync.
+                    published_version = hooks.after_training(step)
+                    # RLix publishes version=cache_ready_step after the active
+                    # refresh completes. Fall back to step for older hooks.
+                    weight_version = (
+                        int(published_version)
+                        if published_version is not None
+                        else int(step)
+                    )
+                POLICY_GENERATION_STALE = False
+            elif NEED_REFIT:
+                # Standalone mode: original refit path.
                 # Measure pending-generation wait as exposed_generation time
                 print("🔄 Coordinating with trajectory collector before refit...")
                 with timer.time("exposed_generation"):
@@ -2894,13 +2941,14 @@ def async_grpo_train(
                 # Pause trajectory collection during validation to reduce memory pressure
                 trajectory_collector.pause.remote()
 
-                if NEED_REFIT and POLICY_GENERATION_STALE:
-                    refit_policy_generation(
-                        policy, policy_generation, colocated_inference
-                    )
-                    POLICY_GENERATION_STALE = False
-                else:
-                    policy_generation.prepare_for_generation()
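+                # Skipped in RLix mode: validation runs on whatever weights the
+                # scheduler last synced during expand (_expand_workers); forcing
+                # a refit or prepare_for_generation() here would clash with
+                # scheduler-driven sleep/wake.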
+                if not DO_TIME_SHARING:
+                    if NEED_REFIT and POLICY_GENERATION_STALE:
+                        refit_policy_generation(
+                            policy, policy_generation, colocated_inference
+                        )
+                        POLICY_GENERATION_STALE = False
+                    else:
+                        policy_generation.prepare_for_generation()
                 val_metrics, validation_timings = validate(
                     policy_generation,
                     val_dataloader,
diff --git a/nemo_rl/algorithms/rlix_hooks.py b/nemo_rl/algorithms/rlix_hooks.py
new file mode 100644
index 0000000000..3184fc2628
--- /dev/null
+++ b/nemo_rl/algorithms/rlix_hooks.py
@@ -0,0 +1,81 @@
+"""RLix hook protocol and no-op default for NeMo RL's async_grpo_train.
+
+This module is the seam between NeMo RL (caller) and RLix (implementor).
+Import direction:
+    nemo_rl/algorithms/grpo.py → rlix_hooks.py (this file, NeMo RL repo)
+    rlix/pipeline/nemo_rl_pipeline.py → provides NemoRLRLixHooks (real impl)
+
+NeMo RL code depends only on RLixHooksProtocol and NoOpRLixHooks from this file.
+It never imports from the rlix package directly, preventing circular dependencies.
+"""
+from __future__ import annotations
+
+from typing import Any, runtime_checkable
+
+from typing_extensions import Protocol
+
+
+@runtime_checkable
+class RLixHooksProtocol(Protocol):
+    """Protocol that async_grpo_train expects from its rlix_hooks argument.
+
+    NeMo RL standalone mode uses NoOpRLixHooks (all methods are no-ops).
+    RLix mode passes NemoRLRLixHooks (rlix/pipeline/nemo_rl_pipeline.py),
+    which makes blocking Ray RPC calls to the scheduler.
+    """
+
+    def before_training(self, step: int) -> None:
+        """Called before logprob inference and training; may block on the scheduler.
+
+        F5: in RLix mode, blocks until the scheduler grants actor_train GPUs.
+        The scheduler shrinks overlap inference workers before returning.
+        In standalone mode, this is a no-op.
+        """
+        ...
+
+    def after_training(self, step: int) -> int | None:
+        """Called after policy.train() completes; notifies the scheduler to expand.
+
+        F5: in RLix mode, notifies the scheduler that actor_train GPUs are
+        released. The scheduler asynchronously calls coordinator.resize_infer
+        (add=overlap_ranks), which routes to pipeline._expand_workers() (F6).
+        Weight sync and version update happen inside _expand_workers.
+        In standalone mode, this is a no-op.
+
+        Preconditions (must be satisfied before calling in RLix mode):
+        - CPU bucket cache built (TODO F4: policy.build_cpu_bucket_cache)
+        - Training GPU VRAM offloaded (TODO F11: policy.offload_training_gpu)
+        - Megatron NCCL groups destroyed (TODO F11: destroy_nccl_groups)
+        """
+        ...
+
+    def on_trajectory_collector_created(self, collector: Any) -> None:
+        """Register the trajectory collector Ray actor handle with the pipeline.
+
+        F6 dependency: _expand_workers calls collector.set_weight_version after
+        each selective sync. The handle must be registered here before the
+        first expand fires; otherwise _expand_workers logs a warning and skips
+        the version update.
+
+        Called once, immediately after AsyncTrajectoryCollector is created and
+        set_weight_version has been called with the initial value.
+        """
+        ...
+
+
+class NoOpRLixHooks:
+    """Default no-op implementation used in NeMo RL standalone mode.
+
+    Satisfies RLixHooksProtocol so grpo.py can always call hooks.* without
+    guarding against None. In standalone mode these are all no-ops;
+    the real implementations live in rlix/pipeline/nemo_rl_pipeline.py.
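+
+    Illustrative custom implementation (a sketch; the class name and prints are
+    hypothetical, only the three protocol methods are required):
+
+        class LoggingHooks:
+            def before_training(self, step: int) -> None:
+                print(f"step {step}: waiting for actor_train GPUs")
+
+            def after_training(self, step: int) -> int | None:
+                print(f"step {step}: actor_train GPUs released")
+                return None  # caller falls back to weight_version = step
+
+            def on_trajectory_collector_created(self, collector: Any) -> None:
+                self._collector = collector
+
+        assert isinstance(LoggingHooks(), RLixHooksProtocol)  # runtime_checkable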
+ """ + + def before_training(self, step: int) -> None: + pass + + def after_training(self, step: int) -> int | None: + return None + + def on_trajectory_collector_created(self, collector: Any) -> None: + pass diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 9237788be1..357b31edc7 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -309,6 +309,155 @@ def _load_model_weights(weights, model_runner): return True + def setup_collective_group( + self, + model_update_name: str, + comm_plan: dict[int, Any], + mode: str, + timeout_s: float | None = None, + dp_rank: int | None = None, + ) -> bool: + """Create a temporary RLix model-update NCCL group for this vLLM rank.""" + del timeout_s # StatelessProcessGroup does not expose timeout control. + from nemo_rl.distributed.stateless_process_group import StatelessProcessGroup + + if len(comm_plan) != 1: + raise ValueError( + "RLix model-update receiver expects a single owner comm plan; " + f"got {len(comm_plan)} entries" + ) + owner_plan = next(iter(comm_plan.values())) + local_rank = int(torch.distributed.get_rank()) + rank = None + if mode == "sender": + rank = 0 + else: + if dp_rank is None: + raise ValueError("dp_rank is required for receiver setup") + local_ranks = owner_plan.get("broadcast_local_ranks_by_dp_rank", {}).get( + int(dp_rank), [] + ) + if local_rank not in [int(x) for x in local_ranks]: + return True + devices = owner_plan.get("tgt_devices", []) + ordered = [ + (int(item["rank"]), int(item["device"])) + for item in devices + ] + try: + rank = 1 + ordered.index((int(dp_rank), local_rank)) + except ValueError: + return True + + groups = getattr(self, "_rlix_model_update_groups", None) + if groups is None: + groups = {} + self._rlix_model_update_groups = groups + group = StatelessProcessGroup( + master_address=str(owner_plan["master_addr"]), + port=int(owner_plan["master_port"]), + rank=int(rank), + world_size=1 + len(owner_plan.get("tgt_devices", [])), + ) + group.init_nccl_communicator(device=self.device) + groups[str(model_update_name)] = group + return True + + def update_parameter_in_bucket( + self, + payload_list: list[Any], + is_lora: bool = False, + ipc_local_ranks: list[int] | None = None, + model_update_transport: str | None = None, + ) -> bool: + """Apply one IPC-delivered RLix weight bucket to selected local ranks.""" + del is_lora, model_update_transport + local_rank = int(torch.distributed.get_rank()) + if ipc_local_ranks is not None and local_rank not in [int(x) for x in ipc_local_ranks]: + return True + weights: list[tuple[str, torch.Tensor]] = [] + for item in payload_list: + if isinstance(item, tuple) and len(item) == 2: + weights.append(item) + elif isinstance(item, dict): + name = item.get("name") or item.get("key") + tensor = item.get("tensor") + if tensor is None: + tensor = item.get("value") + if name is not None and tensor is not None: + weights.append((str(name), tensor)) + policy_weights, draft_weights = self._split_policy_and_draft_weights(weights) + from nemo_rl.models.generation.vllm.quantization import fp8 + + if fp8.is_fp8_model(self.model_runner.vllm_config): + fp8.load_weights(policy_weights, self.model_runner) + else: + self.model_runner.model.load_weights(weights=policy_weights) + self._load_draft_weights(draft_weights) + torch.cuda.current_stream().synchronize() + return True + + def broadcast_parameter( + self, + group_name: str, + names: list[str], + dtypes: list[Any], + 
+    def broadcast_parameter(
+        self,
+        group_name: str,
+        names: list[str],
+        dtypes: list[Any],
+        shapes: list[Any],
+        is_lora: bool = False,
+        broadcast_local_ranks: list[int] | None = None,
+    ) -> bool:
+        """Receive one RLix broadcast bucket and apply it to selected local ranks."""
+        del is_lora
+        local_rank = int(torch.distributed.get_rank())
+        if broadcast_local_ranks is not None and local_rank not in [int(x) for x in broadcast_local_ranks]:
+            return True
+        groups = getattr(self, "_rlix_model_update_groups", {})
+        group = groups.get(str(group_name))
+        if group is None:
+            raise RuntimeError(f"RLix model update group {group_name!r} is not initialized")
+        state_dict_info = {
+            str(name): (torch.Size(shape), dtype)
+            for name, dtype, shape in zip(names, dtypes, shapes)
+        }
+
+        def _load(weights: list[tuple[str, torch.Tensor]]) -> None:
+            self.update_parameter_in_bucket(weights)
+
+        packed_broadcast_consumer(
+            iterator=iter(state_dict_info.items()),
+            group=group,
+            src=0,
+            post_unpack_func=_load,
+        )
+        return True
+
+    def destroy_collective_group(self, group_name: str) -> bool:
+        """Destroy a temporary RLix model-update group if this rank joined it."""
+        groups = getattr(self, "_rlix_model_update_groups", None)
+        if not groups:
+            return True
+        groups.pop(str(group_name), None)
+        gc.collect()
+        torch.cuda.empty_cache()
+        return True
+
+    def verify_model(self, expected_stats: dict[str, Any]) -> bool:
+        """Receiver API placeholder for RLix checksum verification."""
+        del expected_stats
+        return True
+
+    def finalize_weight_update(self) -> bool:
+        """Run vLLM post-load hooks once after all RLix buckets are applied."""
+        from vllm.model_executor.model_loader.utils import process_weights_after_loading
+
+        process_weights_after_loading(
+            self.model_runner.model, self.model_config, self.device
+        )
+        self._maybe_process_fp8_kv_cache()
+        gc.collect()
+        torch.cuda.empty_cache()
+        return True
+
     def cleanup(self) -> None:
         """Shutdown and cleanup resources."""
         # Close ZMQ socket and context if they exist
diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py
index 0faaad17a1..bc9d9246b1 100644
--- a/nemo_rl/models/generation/vllm/vllm_generation.py
+++ b/nemo_rl/models/generation/vllm/vllm_generation.py
@@ -16,6 +16,7 @@
 import os
 import warnings
 from collections import defaultdict
+from types import SimpleNamespace
 from typing import (
     Any,
     AsyncGenerator,
@@ -851,6 +852,43 @@ def update_weights_from_collective(self) -> list[ray.ObjectRef]:
         # this function should co-work with lm_policy, so we should wait for all futures to complete outside
         return futures
 
+    def get_model_update_receiver(self) -> Any:
+        """Expose a cluster-like receiver surface for RLix selective sync."""
+        if not self.worker_group or not self.worker_group.workers:
+            raise RuntimeError("Worker group is not initialized")
+
+        rank2worker = {
+            dp_rank: self.worker_group.workers[
+                self.worker_group.get_dp_leader_worker_idx(dp_rank)
+            ]
+            for dp_rank in range(self.worker_group.dp_size)
+        }
+        worker_config = SimpleNamespace(
+            device_mapping=getattr(self, "_rlix_device_mapping", None),
+            num_gpus_per_worker=self.cfg["vllm_cfg"]["tensor_parallel_size"],
+        )
+        return SimpleNamespace(
+            workers=list(self.worker_group.workers),
+            rank2worker=rank2worker,
+            worker_config=worker_config,
+        )
+
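+    # Illustrative RLix-side consumption of this surface (an assumption, not
+    # part of this patch; variable names are hypothetical):
+    #   receiver = vllm_generation.get_model_update_receiver()
+    #   leader = receiver.rank2worker[dp_rank]
+    #   ray.get(leader.rlix_model_update_rpc.remote(
+    #       "setup_collective_group", update_name, comm_plan, "receiver", None, dp_rank))
+    #   ray.get(vllm_generation.finalize_weight_update([dp_rank]))
+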
+    def finalize_weight_update(self, dp_ranks: list[int]) -> list[ray.ObjectRef]:
+        """Run post-load hooks once on selected DP workers after RLix bucket sync."""
+        if not self.worker_group or not self.worker_group.workers:
+            raise RuntimeError("Worker group is not initialized")
+        futures: list[ray.ObjectRef] = []
+        for dp_rank in sorted(set(int(rank) for rank in dp_ranks)):
+            worker_idx = self.worker_group.get_dp_leader_worker_idx(dp_rank)
+            futures.append(
+                self.worker_group.run_single_worker_single_data(
+                    "rlix_model_update_rpc",
+                    worker_idx=worker_idx,
+                    method_name="finalize_weight_update",
+                )
+            )
+        return futures
+
     def start_gpu_profiling(self) -> None:
         """Start GPU profiling."""
         futures = self.worker_group.run_all_workers_single_data("start_gpu_profiling")
diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py
index 15935b548a..121eb5d194 100644
--- a/nemo_rl/models/generation/vllm/vllm_worker.py
+++ b/nemo_rl/models/generation/vllm/vllm_worker.py
@@ -968,6 +968,11 @@ def update_weights_from_collective(self) -> bool:
             traceback.print_exc()
             return False
 
+    def rlix_model_update_rpc(self, method_name: str, *args: Any) -> bool:
+        """Forward an RLix model-update method to vLLM internal workers."""
+        result = self.llm.collective_rpc(method_name, args=args)
+        return all(bool(x) for x in result)
+
     def reset_prefix_cache(self):
         """Reset the prefix cache of vLLM engine."""
         assert self.llm is not None, (
diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py
index 12f73572b4..a59f2bfa89 100644
--- a/nemo_rl/models/generation/vllm/vllm_worker_async.py
+++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py
@@ -1117,6 +1117,15 @@ async def update_weights_from_collective_async(self) -> bool:
             traceback.print_exc()
             return False
 
+    async def rlix_model_update_rpc(self, method_name: str, *args: Any) -> bool:
+        """Forward an RLix model-update method to vLLM internal workers."""
+        result_or_coro = await self.llm.collective_rpc(method_name, args=args)
+        if asyncio.iscoroutine(result_or_coro):
+            result = await result_or_coro
+        else:
+            result = result_or_coro
+        return all(bool(x) for x in result)
+
     async def reset_prefix_cache_async(self):
         """Async version of reset_prefix_cache."""
         assert self.llm is not None, (