From 06cd64b966e0b88bf7f9f2cb876b0b847b38e808 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Mon, 29 Jun 2026 10:34:54 +0000 Subject: [PATCH 1/5] Introduce rollout topology binding --- xtuner/v1/rl/rollout/controller.py | 125 +++++----- xtuner/v1/rl/rollout/health_manager.py | 10 +- xtuner/v1/rl/rollout/lmdeploy.py | 194 +++++++--------- xtuner/v1/rl/rollout/rollout_topology.py | 122 ++++++++++ xtuner/v1/rl/rollout/sglang.py | 94 ++------ xtuner/v1/rl/rollout/vllm.py | 15 +- xtuner/v1/rl/rollout/worker.py | 277 +++-------------------- xtuner/v1/rl/rollout/worker_registry.py | 156 +++++++------ 8 files changed, 420 insertions(+), 573 deletions(-) create mode 100644 xtuner/v1/rl/rollout/rollout_topology.py diff --git a/xtuner/v1/rl/rollout/controller.py b/xtuner/v1/rl/rollout/controller.py index ae4572334..6b3696f63 100644 --- a/xtuner/v1/rl/rollout/controller.py +++ b/xtuner/v1/rl/rollout/controller.py @@ -19,7 +19,7 @@ RolloutConfig, get_rollout_worker_base_cls, ) -from .worker_registry import RolloutWorkerMetadata, RolloutWorkerRegistry +from .worker_registry import RolloutWorkerRegistry # Keep this as a Ray actor because Ray AgentLoop actors need a shared, cross-process handle to the same controller @@ -63,17 +63,13 @@ def __init__( ) self.health_manager.start() - def get_rollout_metadata(self) -> RolloutWorkerMetadata: + def get_rollout_metadata(self) -> dict: """Get information about the current rollout setup. Returns: - dict: A dictionary containing the engine mesh list, server URL - dictionary, and the rollout configuration. + Legacy trainer/update-weight rollout metadata dictionary. """ - rollout_metadata = self.registry.training_metadata_snapshot() - self.logger.info(f"Rollout worker server URLs: {rollout_metadata['server_url_dict']}") - self.logger.info(f"Rollout worker session server URLs: {rollout_metadata['worker_session_url_dict']}") - return rollout_metadata + return self.registry.metadata().to_legacy() def register_active_workers_to_proxy(self) -> None: if self.proxy_manager is None: @@ -203,13 +199,16 @@ def _build_remote_worker_cls(self, worker_base_cls): }, )(worker_base_cls) - def _init_workers(self, placement_group: PlacementGroup) -> RolloutWorkerRegistry: + def _init_workers( + self, + placement_group: PlacementGroup, + ) -> RolloutWorkerRegistry: """Initializes and configures the pool of RolloutWorker actors. This method follows the same high-level flow as the legacy implementation: - create workers, initialize worker-local ports, build engine groups, - select workers that launch rollout servers, launch servers, and - expose request-entrypoint server URLs to rollout traffic. + create workers, initialize worker-local ports, build the bound rollout + topology, launch rollout servers, and expose request-entrypoint server + URLs to rollout traffic. Returns: A registry containing all server-process workers and the public @@ -222,79 +221,63 @@ def _init_workers(self, placement_group: PlacementGroup) -> RolloutWorkerRegistr workers, rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group( worker_cls, self.config, placement_group ) - rank_to_actor = {rank: worker for (rank, _), worker in zip(rank_bundle_idx_list, workers)} - - # Reserve worker-local ports for all actors first. build_engine_launch_specs - # uses the returned addresses to bind each ServerProcessSpec to its - # logical engine rendezvous address; only server-process owners call init(). - rank_to_dist_init_addr = { - rank: dist_init_addr - for (rank, _), dist_init_addr in zip( - rank_bundle_idx_list, - ray.get([worker.init_dist_port.remote() for worker in workers]), # type: ignore[attr-defined] - ) + dist_init_results = ray.get( + [ + worker.init_dist_port.remote() # type: ignore[attr-defined] + for worker in workers + ] + ) + rank_to_worker = { + rank: worker for worker, (rank, _dist_init_addr) in zip(workers, dist_init_results, strict=True) } + rank_to_dist_init_addr = dict(dist_init_results) - # Build engine groups and server-process specs from the rank/bundle mapping. - engine_launch_specs = worker_base_cls.build_engine_launch_specs( + rollout_topology = worker_base_cls.build_rollout_topology( self.config, rank_bundle_idx_list, rank_to_dist_init_addr, ) - # Keep the public metadata mesh compatible with origin/main. Backends - # may expose a different update-weight mesh than their internal launch - # topology, e.g. LMDeploy EP has one logical engine but one public entry - # per request-serving EP rank. - engine_rank_mesh_array = worker_base_cls.build_metadata_engine_rank_mesh_array(engine_launch_specs) - - # Launch every server process described by the backend-specific specs. - server_rank_to_url = dict( - ray.get( - [ - rank_to_actor[server_process.worker_rank].init.remote( # type: ignore[attr-defined] - engine_launch_spec=engine_spec, - ) - for engine_spec in engine_launch_specs - for server_process in engine_spec.server_processes - ] - ) + server_launch_specs = rollout_topology.server_launch_specs() + server_workers = tuple( + (launch_spec, rank_to_worker[launch_spec.worker_rank]) for launch_spec in server_launch_specs ) - session_url_by_rank = dict( + + ray.get( + [ + worker.bind_server_launch_spec.remote(launch_spec) # type: ignore[attr-defined] + for launch_spec, worker in server_workers + ] + ) + init_results = tuple( ray.get( [ - ( - rank_to_actor[server_process.worker_rank].get_session_server_info.remote() # type: ignore[attr-defined] - ) - for engine_spec in engine_launch_specs - for server_process in engine_spec.server_processes + worker.init.remote() # type: ignore[attr-defined] + for _launch_spec, worker in server_workers ] ) ) - - registry = RolloutWorkerRegistry( - engine_rank_mesh_array=engine_rank_mesh_array, - rollout_config=self.config, - ) - for engine_spec in engine_launch_specs: - for server_process in engine_spec.server_processes: - rank = server_process.worker_rank - url = server_rank_to_url[rank] - session_url = session_url_by_rank.get(rank) - if server_process.accepts_rollout_requests and session_url is None: - raise RuntimeError(f"Rollout worker rank={rank} did not return session server URL during init.") - registry.register_started_server( - rank=rank, - actor=rank_to_actor[rank], - server_url=url, - session_url=session_url, - lifecycle_group_ranks=engine_spec.server_worker_ranks, - is_request_entrypoint=server_process.accepts_rollout_requests, + registry = RolloutWorkerRegistry(rollout_topology=rollout_topology, rollout_config=self.config) + for init_result in init_results: + if rollout_topology.is_request_entrypoint_rank(init_result.rank) and init_result.session_url is None: + raise RuntimeError( + f"Rollout worker rank={init_result.rank} did not return session server URL during init." ) + registry.register_started_server( + rank=init_result.rank, + actor=rank_to_worker[init_result.rank], + server_url=init_result.server_url, + session_url=init_result.session_url, + ) - server_process_workers_info = registry.all_workers() - self.logger.info(f"Rollout server-process worker URLs: {[info.url for info in server_process_workers_info]}") - lifecycle_groups = sorted({info.lifecycle_group_ranks for info in server_process_workers_info}) - self.logger.info(f"Rollout worker lifecycle groups: {lifecycle_groups}") + rollout_metadata = registry.metadata() + legacy_metadata = rollout_metadata.to_legacy() + self.logger.info( + "Rollout worker registry snapshot: " + f"server_urls={legacy_metadata['server_url_dict']}, " + f"session_urls={legacy_metadata['worker_session_url_dict']}, " + f"server_process_urls={[worker.url for worker in registry.all_workers()]}, " + f"lifecycle_groups={registry.lifecycle_groups()}" + ) return registry diff --git a/xtuner/v1/rl/rollout/health_manager.py b/xtuner/v1/rl/rollout/health_manager.py index 846cd37bd..30236e298 100644 --- a/xtuner/v1/rl/rollout/health_manager.py +++ b/xtuner/v1/rl/rollout/health_manager.py @@ -486,8 +486,8 @@ def _restart_worker_group( ) init_results = ray.get( [ - # init() reuses the immutable launch spec cached on each actor - # during controller startup, including placement bundles and dist addr. + # init() reuses the server launch spec bound during + # controller startup. worker.actor.init.remote() # type: ignore[attr-defined] for worker in group.workers ], @@ -505,11 +505,11 @@ def _restart_worker_group( return False for worker, init_result in zip(group.workers, init_results): - init_rank, init_url = init_result - if init_rank != worker.rank or init_url != worker.url: + if init_result.rank != worker.rank or init_result.server_url != worker.url: logger.error( f"Rollout worker restart returned unexpected endpoint: rank={worker.rank}, " - f"init_rank={init_rank}, expected_url={worker.url}, init_url={init_url}." + f"init_rank={init_result.rank}, expected_url={worker.url}, " + f"init_url={init_result.server_url}." ) self._shutdown_worker_group(group, wait_server_down=False, best_effort=True) return False diff --git a/xtuner/v1/rl/rollout/lmdeploy.py b/xtuner/v1/rl/rollout/lmdeploy.py index 7c4506238..044eb1c19 100644 --- a/xtuner/v1/rl/rollout/lmdeploy.py +++ b/xtuner/v1/rl/rollout/lmdeploy.py @@ -1,6 +1,6 @@ import os from argparse import Namespace -from typing import Any, Dict, List +from typing import Any, Dict, List, Mapping import numpy as np import ray @@ -10,7 +10,8 @@ from transformers import AutoTokenizer from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams -from .worker import EngineLaunchSpec, EngineLaunchSpecs, RolloutConfig, RolloutWorker, ServerProcessSpec +from .rollout_topology import RolloutTopology +from .worker import RolloutConfig, RolloutWorker SHARED_STORE = "shared_store" @@ -80,118 +81,77 @@ def __init__( self.lmdeploy_actor = None @classmethod - def build_engine_launch_specs( + def build_rollout_topology( cls, config: RolloutConfig, rank_bundle_idx_list: list[tuple[int, int]], - rank_to_dist_init_addr: dict[int, str] | None = None, - ) -> EngineLaunchSpecs: - """Build LMDeploy server launch layout. - - LMDeploy EP starts one request-serving server per EP rank. - - Example with expert_parallel_size=2: - rank_bundle_idx_list is [(0, 0), (1, 1), (2, 2), (3, 3)]. - rank identifies the rollout worker; bundle idx identifies the Ray - placement-group bundle that owns the GPU resource. - - If rank_to_dist_init_addr is: - {0: "addr0", 1: "addr1", 2: "addr2", 3: "addr3"} - - The launch specs are: - EngineLaunchSpec( - engine_ranks=(0, 1), - server_processes=( - ServerProcessSpec( - worker_rank=0, - placement_group_bundle_idxs=(0,), - dist_init_addr="addr0", - ), - ServerProcessSpec( - worker_rank=1, - placement_group_bundle_idxs=(1,), - dist_init_addr="addr0", - ), - ), - ) - EngineLaunchSpec( - engine_ranks=(2, 3), - server_processes=( - ServerProcessSpec( - worker_rank=2, - placement_group_bundle_idxs=(2,), - dist_init_addr="addr2", - ), - ServerProcessSpec( - worker_rank=3, - placement_group_bundle_idxs=(3,), - dist_init_addr="addr2", - ), - ), - ) - - Each EP rank launches a server process, so server_worker_ranks is the - same as engine_ranks, and every server accepts rollout requests. - """ - if config.expert_parallel_size <= 1: - return RolloutWorker.build_engine_launch_specs( - config, - rank_bundle_idx_list, - rank_to_dist_init_addr, - ) - - ep_size = config.expert_parallel_size + rank_to_dist_init_addr: Mapping[int, str], + ) -> RolloutTopology: + """Build LMDeploy rollout topology with bound engine dist-init + addresses.""" + engines = [] num_workers = len(rank_bundle_idx_list) - if num_workers % ep_size != 0: - raise ValueError(f"num_rollout_workers={num_workers} must be divisible by expert_parallel_size={ep_size}.") - - engine_launch_specs: list[EngineLaunchSpec] = [] - for engine_start in range(0, num_workers, ep_size): - engine_meta = rank_bundle_idx_list[engine_start : engine_start + ep_size] - engine_ranks = tuple(rank for rank, _ in engine_meta) - engine_dist_init_addr = None if rank_to_dist_init_addr is None else rank_to_dist_init_addr[engine_ranks[0]] - # LMDeploy EP launches one server process for each EP rank. Each - # server owns exactly one placement-group bundle, and every server - # can be used as a rollout request entrypoint. - engine_launch_specs.append( - EngineLaunchSpec( - engine_ranks=engine_ranks, - server_processes=tuple( - ServerProcessSpec( - worker_rank=server_rank, - placement_group_bundle_idxs=(bundle_idx,), - dist_init_addr=engine_dist_init_addr, - ) - for server_rank, bundle_idx in engine_meta - ), + if config.expert_parallel_size <= 1: + num_gpus_per_engine = config.num_gpus_per_engine + if num_workers % num_gpus_per_engine != 0: + raise ValueError( + f"num_rollout_workers={num_workers} must be divisible by " + f"num_gpus_per_engine={num_gpus_per_engine}." ) - ) - return cls.validate_engine_launch_specs( - tuple(engine_launch_specs), - known_worker_ranks=tuple(rank for rank, _ in rank_bundle_idx_list), - ) - - @classmethod - def build_metadata_engine_rank_mesh_array( - cls, - engine_launch_specs: EngineLaunchSpecs, - ) -> list[list[int]]: - """Keep LMDeploy EP metadata compatible with origin/main. - - Pure EP uses one request-serving server per EP rank. The logical engine topology is still stored in - EngineLaunchSpec.engine_ranks for dp_rank and lifecycle operations, but update_weighter expects the public - metadata mesh to contain one single-rank entry per request server. - """ - metadata_engine_rank_mesh_array: list[list[int]] = [] - for engine_spec in engine_launch_specs: - request_entrypoint_servers = engine_spec.request_entrypoint_servers - if len(request_entrypoint_servers) > 1: - metadata_engine_rank_mesh_array.extend( - [server_process.worker_rank] for server_process in request_entrypoint_servers + for engine_start in range(0, num_workers, num_gpus_per_engine): + engine_meta = rank_bundle_idx_list[engine_start : engine_start + num_gpus_per_engine] + engine_ranks = tuple(rank for rank, _ in engine_meta) + engine_bundle_idxs = tuple(bundle_idx for _, bundle_idx in engine_meta) + dist_init_addr_owner_rank = engine_ranks[0] + engines.append( + RolloutTopology.engine( + engine_ranks=engine_ranks, + dist_init_addr=rank_to_dist_init_addr[dist_init_addr_owner_rank], + server_processes=( + RolloutTopology.server_process( + worker_rank=engine_ranks[0], + placement_group_bundle_idxs=engine_bundle_idxs, + ), + ), + ) + ) + else: + ep_size = config.expert_parallel_size + if num_workers % ep_size != 0: + raise ValueError( + f"num_rollout_workers={num_workers} must be divisible by expert_parallel_size={ep_size}." + ) + for engine_start in range(0, num_workers, ep_size): + engine_meta = rank_bundle_idx_list[engine_start : engine_start + ep_size] + engine_ranks = tuple(rank for rank, _ in engine_meta) + dist_init_addr_owner_rank = engine_ranks[0] + engines.append( + RolloutTopology.engine( + engine_ranks=engine_ranks, + dist_init_addr=rank_to_dist_init_addr[dist_init_addr_owner_rank], + server_processes=tuple( + RolloutTopology.server_process( + worker_rank=server_rank, + placement_group_bundle_idxs=(bundle_idx,), + ) + for server_rank, bundle_idx in engine_meta + ), + ) ) + + training_engine_mesh: list[tuple[int, ...]] = [] + for engine in engines: + entrypoint_processes = tuple( + process for process in engine.server_processes if process.accepts_rollout_requests + ) + if len(entrypoint_processes) == 1: + training_engine_mesh.append(tuple(engine.engine_ranks)) else: - metadata_engine_rank_mesh_array.append(list(engine_spec.engine_ranks)) - return metadata_engine_rank_mesh_array + training_engine_mesh.extend((process.worker_rank,) for process in entrypoint_processes) + return RolloutTopology( + engines=tuple(engines), + training_engine_mesh=tuple(training_engine_mesh), + ) def offload(self): """Offloads the model weights and KV cache.""" @@ -342,7 +302,7 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: "NPU": "ascend", } - extra_config = self.config.extra_rollout_config or dict() + extra_config = self.config.extra_rollout_config lmdeploy_config_kwargs = { k.replace("lmdeploy_", ""): v for k, v in extra_config.items() if k.startswith("lmdeploy_") } @@ -383,14 +343,13 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: if backend == "pytorch" and self.config.max_prefill_token_num: extra_engine_config["max_prefill_token_num"] = self.config.max_prefill_token_num + assert self.server_launch_spec is not None dp_rank = 0 if backend == "pytorch": # currently only support ep > 1 and tp == 1 / ep == 1 and tp > 1 assert ep_size == 1 or tp_size == 1 if ep_size > 1: - engine_launch_spec = self.engine_launch_spec - assert engine_launch_spec is not None - dp_rank = engine_launch_spec.engine_ranks.index(self.rank) + dp_rank = self.server_launch_spec.engine_rank backend_config = ( PytorchEngineConfig( @@ -413,7 +372,10 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: else TurbomindEngineConfig( tp=tp_size, max_batch_size=self.config.rollout_max_batch_size_per_instance, - devices=[bundle_idxs % self.config.gpus_per_node for bundle_idxs in self.engine_bundle_idxs], + devices=[ + bundle_idx % self.config.gpus_per_node + for bundle_idx in self.server_launch_spec.placement_group_bundle_idxs + ], empty_init=self.config.skip_load_weights, session_len=self.config.context_length, model_format="fp8" if self.config.enable_float8 else None, @@ -431,7 +393,9 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: env = { "LMDEPLOY_RAY_EXTERNAL_NS": ray_runtime_ctx.namespace, "LMDEPLOY_RAY_EXTERNAL_PG_NAME": current_pg_name, - "LMDEPLOY_RAY_EXTERNAL_PG_BUNDLES": ",".join(map(str, self.engine_bundle_idxs)), + "LMDEPLOY_RAY_EXTERNAL_PG_BUNDLES": ",".join( + map(str, self.server_launch_spec.placement_group_bundle_idxs) + ), } if self.accelerator == "NPU": @@ -444,7 +408,7 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: ) if tp_size > 1: - dist_addr, dist_port = self.dist_init_addr.split(":")[:2] + dist_addr, dist_port = self.server_launch_spec.dist_init_addr.split(":")[:2] env.update( { "LMDEPLOY_DIST_MASTER_ADDR": dist_addr, @@ -452,7 +416,7 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: } ) elif ep_size > 1: - dist_addr, dist_port = self.dist_init_addr.split(":")[:2] + dist_addr, dist_port = self.server_launch_spec.dist_init_addr.split(":")[:2] if speculative_num_draft_tokens is not None: deepep_max_tokens_per_rank = max_batch_size * (1 + speculative_num_draft_tokens) else: diff --git a/xtuner/v1/rl/rollout/rollout_topology.py b/xtuner/v1/rl/rollout/rollout_topology.py new file mode 100644 index 000000000..db54f8050 --- /dev/null +++ b/xtuner/v1/rl/rollout/rollout_topology.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + + +__all__ = ["RolloutTopology", "ServerLaunchSpec"] + + +@dataclass(frozen=True) +class _ServerProcess: + """Rollout topology expression for one worker-owned server process.""" + + worker_rank: int + placement_group_bundle_idxs: tuple[int, ...] + accepts_rollout_requests: bool = True + node_rank: int = 0 + nnodes: int = 1 + + +@dataclass(frozen=True) +class _Engine: + """Rollout layout for one logical inference engine.""" + + engine_ranks: tuple[int, ...] + dist_init_addr: str + server_processes: tuple[_ServerProcess, ...] + + +@dataclass(frozen=True) +class ServerLaunchSpec: + """Worker-facing launch data projected from rollout topology.""" + + worker_rank: int + placement_group_bundle_idxs: tuple[int, ...] + dist_init_addr: str + engine_rank: int + node_rank: int = 0 + nnodes: int = 1 + + +@dataclass(frozen=True) +class RolloutTopology: + """Immutable rollout engine layout after dist-init addresses are resolved. + + Actor handles, server URLs, session URLs, and lifecycle state belong to RolloutWorkerRegistry. + """ + + engines: tuple[_Engine, ...] + training_engine_mesh: tuple[tuple[int, ...], ...] + _server_process_by_rank: dict[int, _ServerProcess] = field(init=False, repr=False, compare=False) + _lifecycle_group_by_rank: dict[int, tuple[int, ...]] = field(init=False, repr=False, compare=False) + + def __post_init__(self) -> None: + server_process_by_rank: dict[int, _ServerProcess] = {} + lifecycle_group_by_rank: dict[int, tuple[int, ...]] = {} + for engine in self.engines: + lifecycle_group = tuple(server.worker_rank for server in engine.server_processes) + for server in engine.server_processes: + if server.worker_rank in server_process_by_rank: + raise ValueError(f"Duplicate rollout server process worker_rank={server.worker_rank}.") + server_process_by_rank[server.worker_rank] = server + lifecycle_group_by_rank[server.worker_rank] = lifecycle_group + + object.__setattr__(self, "_server_process_by_rank", server_process_by_rank) + object.__setattr__(self, "_lifecycle_group_by_rank", lifecycle_group_by_rank) + + @staticmethod + def engine( + *, + engine_ranks: tuple[int, ...], + dist_init_addr: str, + server_processes: tuple[_ServerProcess, ...], + ) -> _Engine: + return _Engine( + engine_ranks=engine_ranks, + dist_init_addr=dist_init_addr, + server_processes=server_processes, + ) + + @staticmethod + def server_process( + *, + worker_rank: int, + placement_group_bundle_idxs: tuple[int, ...], + accepts_rollout_requests: bool = True, + node_rank: int = 0, + nnodes: int = 1, + ) -> _ServerProcess: + return _ServerProcess( + worker_rank=worker_rank, + placement_group_bundle_idxs=placement_group_bundle_idxs, + accepts_rollout_requests=accepts_rollout_requests, + node_rank=node_rank, + nnodes=nnodes, + ) + + def server_launch_specs(self) -> tuple[ServerLaunchSpec, ...]: + return tuple( + ServerLaunchSpec( + worker_rank=server.worker_rank, + placement_group_bundle_idxs=server.placement_group_bundle_idxs, + dist_init_addr=engine.dist_init_addr, + engine_rank=engine.engine_ranks.index(server.worker_rank), + node_rank=server.node_rank, + nnodes=server.nnodes, + ) + for engine in self.engines + for server in engine.server_processes + ) + + def lifecycle_groups(self) -> tuple[tuple[int, ...], ...]: + return tuple(dict.fromkeys(self._lifecycle_group_by_rank.values())) + + def is_request_entrypoint_rank(self, rank: int) -> bool: + server = self._server_process_by_rank.get(rank) + return server is not None and server.accepts_rollout_requests + + def lifecycle_group_for_server_rank(self, rank: int) -> tuple[int, ...]: + try: + return self._lifecycle_group_by_rank[rank] + except KeyError: + raise KeyError(f"rank={rank} does not own a rollout server process.") from None diff --git a/xtuner/v1/rl/rollout/sglang.py b/xtuner/v1/rl/rollout/sglang.py index 047fa2d5a..89dd125c7 100644 --- a/xtuner/v1/rl/rollout/sglang.py +++ b/xtuner/v1/rl/rollout/sglang.py @@ -1,6 +1,6 @@ import base64 import os -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Mapping, Union import numpy as np import ray @@ -11,13 +11,8 @@ from xtuner.v1.data_proto.rl_data import RolloutState from xtuner.v1.utils import XTUNER_DETERMINISTIC -from .worker import ( - EngineLaunchSpec, - EngineLaunchSpecs, - RolloutConfig, - RolloutWorker, - ServerProcessSpec, -) +from .rollout_topology import RolloutTopology +from .worker import RolloutConfig, RolloutWorker class SGLangWorker(RolloutWorker): @@ -49,53 +44,14 @@ def __init__( self.enable_return_routed_experts = self.config.enable_return_routed_experts @classmethod - def build_engine_launch_specs( + def build_rollout_topology( cls, config: RolloutConfig, rank_bundle_idx_list: list[tuple[int, int]], - rank_to_dist_init_addr: dict[int, str] | None = None, - ) -> EngineLaunchSpecs: - """Build SGLang server launch layout. - - SGLang starts one server per node in a logical engine. Only node 0 is - used as the rollout request entrypoint. - - Example with expert_parallel_size=16 and gpus_per_node=8: - rank_bundle_idx_list is: - [(0, 0), (1, 1), ..., (15, 15)] - - If rank_to_dist_init_addr is: - {0: "addr0", 1: "addr1", ..., 15: "addr15"} - - The launch spec is: - EngineLaunchSpec( - engine_ranks=(0, 1, 2, 3, 4, 5, 6, 7, - 8, 9, 10, 11, 12, 13, 14, 15), - server_processes=( - ServerProcessSpec( - worker_rank=0, - placement_group_bundle_idxs=(0, 1, 2, 3, 4, 5, 6, 7), - dist_init_addr="addr0", - accepts_rollout_requests=True, - node_rank=0, - nnodes=2, - ), - ServerProcessSpec( - worker_rank=8, - placement_group_bundle_idxs=(8, 9, 10, 11, 12, 13, 14, 15), - dist_init_addr="addr0", - accepts_rollout_requests=False, - node_rank=1, - nnodes=2, - ), - ), - ) - - SGLang starts one server per node, so server_worker_ranks is (0, 8). - Only the node-0 server accepts rollout requests. - """ + rank_to_dist_init_addr: Mapping[int, str], + ) -> RolloutTopology: num_workers = len(rank_bundle_idx_list) - num_gpus_per_engine = cls._get_num_gpus_per_engine(config) + num_gpus_per_engine = config.num_gpus_per_engine if num_workers % num_gpus_per_engine != 0: raise ValueError( f"num_rollout_workers={num_workers} must be divisible by num_gpus_per_engine={num_gpus_per_engine}." @@ -106,7 +62,7 @@ def build_engine_launch_specs( ) nnodes = max(1, num_gpus_per_engine // config.gpus_per_node) - engine_launch_specs: list[EngineLaunchSpec] = [] + engines = [] for engine_start in range(0, num_workers, num_gpus_per_engine): engine_meta = rank_bundle_idx_list[engine_start : engine_start + num_gpus_per_engine] engine_ranks = tuple(rank for rank, _ in engine_meta) @@ -115,30 +71,30 @@ def build_engine_launch_specs( # first rank of each node owns that node's bundles, while only node # 0 is exposed as the rollout request entrypoint. server_ranks = engine_ranks[:: config.gpus_per_node] - engine_dist_init_addr = None if rank_to_dist_init_addr is None else rank_to_dist_init_addr[server_ranks[0]] - server_processes: list[ServerProcessSpec] = [] + dist_init_addr_owner_rank = server_ranks[0] + server_processes = [] for node_rank, server_rank in enumerate(server_ranks): node_bundle_start = node_rank * config.gpus_per_node node_bundle_end = node_bundle_start + config.gpus_per_node server_processes.append( - ServerProcessSpec( + RolloutTopology.server_process( worker_rank=server_rank, placement_group_bundle_idxs=engine_bundle_idxs[node_bundle_start:node_bundle_end], - dist_init_addr=engine_dist_init_addr, accepts_rollout_requests=node_rank == 0, node_rank=node_rank, nnodes=nnodes, ) ) - engine_launch_specs.append( - EngineLaunchSpec( + engines.append( + RolloutTopology.engine( engine_ranks=engine_ranks, + dist_init_addr=rank_to_dist_init_addr[dist_init_addr_owner_rank], server_processes=tuple(server_processes), ) ) - return cls.validate_engine_launch_specs( - tuple(engine_launch_specs), - known_worker_ranks=tuple(rank for rank, _ in rank_bundle_idx_list), + return RolloutTopology( + engines=tuple(engines), + training_engine_mesh=tuple(tuple(engine.engine_ranks) for engine in engines), ) def _get_request_payload(self, rollout_state: RolloutState) -> dict: @@ -325,7 +281,7 @@ def _transform_rollout_config_to_server_configs(self): os.environ.pop("CUDA_VISIBLE_DEVICES", None) from sglang.srt.server_args import ServerArgs - extra_config = self.config.extra_rollout_config or dict() + extra_config = self.config.extra_rollout_config sglang_config_kwargs = { k.replace("sglang_", ""): v for k, v in extra_config.items() if k.startswith("sglang_") } @@ -338,13 +294,7 @@ def _transform_rollout_config_to_server_configs(self): ) tp_size = num_gpus_per_engine if self.config.expert_parallel_size > 1 else self.config.tensor_parallel_size ep_size = num_gpus_per_engine if self.config.expert_parallel_size > 1 else self.config.expert_parallel_size - server_process_spec = self._get_current_server_process_spec() - nnodes = ( - server_process_spec.nnodes - if server_process_spec is not None - else max(1, num_gpus_per_engine // self.config.gpus_per_node) - ) - node_rank = server_process_spec.node_rank if server_process_spec is not None else 0 + assert self.server_launch_spec is not None assigned_gpu_id = int(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0]) # SGLang 0.5.10 默认启用的 Piecewise CUDA Graph 在启动 warmup compile 阶段会报错。sglang的文档提到这个功能还是实验功能,可能还不太稳定(https://sgl-project-sglang-93.mintlify.app/optimization/cuda-graph#bug-report)。暂时先通过disable_piecewise_cuda_graph=True关掉改功能 @@ -354,11 +304,11 @@ def _transform_rollout_config_to_server_configs(self): host=self.host, port=self.server_port, nccl_port=self.nccl_port, - dist_init_addr=self.dist_init_addr, + dist_init_addr=self.server_launch_spec.dist_init_addr, base_gpu_id=assigned_gpu_id, gpu_id_step=1, - nnodes=nnodes, - node_rank=node_rank, + nnodes=self.server_launch_spec.nnodes, + node_rank=self.server_launch_spec.node_rank, skip_server_warmup=True, mem_fraction_static=self.config.gpu_memory_utilization, enable_memory_saver=True, diff --git a/xtuner/v1/rl/rollout/vllm.py b/xtuner/v1/rl/rollout/vllm.py index 8cbeaef69..a999baf98 100644 --- a/xtuner/v1/rl/rollout/vllm.py +++ b/xtuner/v1/rl/rollout/vllm.py @@ -2,7 +2,7 @@ import os import traceback from argparse import Namespace -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Mapping, Union import numpy as np import ray @@ -16,6 +16,7 @@ from xtuner.v1.data_proto.rl_data import RolloutState, Status, update_status_from_finish_reason from xtuner.v1.utils.device import get_device, get_torch_device_module +from .rollout_topology import RolloutTopology from .worker import RolloutConfig, RolloutWorker @@ -131,6 +132,15 @@ def run_lmdeploy_server_wrapper(server_namespace: Namespace): class vLLMWorker(RolloutWorker): + @classmethod + def build_rollout_topology( + cls, + config: RolloutConfig, + rank_bundle_idx_list: list[tuple[int, int]], + rank_to_dist_init_addr: Mapping[int, str], + ) -> RolloutTopology: + raise NotImplementedError("vLLM rollout topology has not been verified after topology refactor.") + def __init__( self, config: RolloutConfig, @@ -323,13 +333,14 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: args["limit_mm_per_prompt"] = {"image": 10, "video": 0} args["enable_log_requests"] = False args["uvicorn_log_level"] = "error" + assert self.server_launch_spec is not None env = { "VLLM_VERSION": "0.11.0", "TASK_QUEUE_ENABLE": "0", "CPU_AFFINITY_CONF": "2", "VLLM_USE_V1": "1", "VLLM_RAY_PER_WORKER_GPUS": "0.1", - "VLLM_RAY_BUNDLE_INDICES": ",".join(map(str, self.engine_bundle_idxs)), + "VLLM_RAY_BUNDLE_INDICES": ",".join(map(str, self.server_launch_spec.placement_group_bundle_idxs)), "VLLM_MONITOR": "1", "VLLM_ACCU_MONITOR": "0", "CUSTOM_SCHEDULE_KV_LIMIT": "0.9", diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py index c50ffa6bf..aed48e50e 100644 --- a/xtuner/v1/rl/rollout/worker.py +++ b/xtuner/v1/rl/rollout/worker.py @@ -9,7 +9,7 @@ from abc import abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, TypeAlias, Union, cast +from typing import TYPE_CHECKING, Any, Callable, List, Literal, Mapping, Optional, Union, cast import httpx import ray @@ -40,6 +40,7 @@ from .constants import ROLLOUT_HTTP_MAX_CONNECTIONS, ROLLOUT_RAY_GENERATE_MAX_CONCURRENCY from .health_manager import ROLLOUT_RAY_GET_TIMEOUT +from .rollout_topology import RolloutTopology, ServerLaunchSpec from .session_server import SessionServerActor from .utils import PartialRolloutHandler @@ -53,56 +54,12 @@ @dataclass(frozen=True) -class ServerProcessSpec: - """How to start one rollout server process.""" - - # Worker rank that owns this server process. - worker_rank: int - # Placement-group bundle indexes assigned to this server process. - placement_group_bundle_idxs: tuple[int, ...] - # Distributed init address used by every server process in the same engine. - # Filled after init_dist_port initializes worker-local ports. - dist_init_addr: str | None = None - # Whether this server is exposed as a rollout request entrypoint. Some - # backends launch extra server processes that must participate in - # lifecycle/health operations but must not be added to worker_server_urls_map - # or receive normal rollout traffic. - accepts_rollout_requests: bool = True - # Node index of this server inside a multi-node logical engine. - node_rank: int = 0 - # Number of nodes used by this logical engine. - nnodes: int = 1 +class RolloutWorkerInitResult: + """Result returned by RolloutWorker.init() after its server starts.""" - -@dataclass(frozen=True) -class EngineLaunchSpec: - """How to launch rollout servers for one logical inference engine.""" - - # All worker ranks that form this logical inference engine. - engine_ranks: tuple[int, ...] - # Server processes required by this engine. - server_processes: tuple[ServerProcessSpec, ...] - - @property - def server_worker_ranks(self) -> tuple[int, ...]: - return tuple(server.worker_rank for server in self.server_processes) - - @property - def request_entrypoint_servers(self) -> tuple[ServerProcessSpec, ...]: - return tuple(server for server in self.server_processes if server.accepts_rollout_requests) - - @property - def request_entrypoint_worker_ranks(self) -> tuple[int, ...]: - return tuple(server.worker_rank for server in self.request_entrypoint_servers) - - @property - def placement_group_bundle_idxs(self) -> tuple[int, ...]: - return tuple( - bundle_idx for server in self.server_processes for bundle_idx in server.placement_group_bundle_idxs - ) - - -EngineLaunchSpecs: TypeAlias = tuple[EngineLaunchSpec, ...] + rank: int + server_url: str + session_url: str | None def get_rollout_worker_base_cls(config: "RolloutConfig") -> type["RolloutWorker"]: @@ -579,8 +536,7 @@ def __init__( self.accelerator = accelerator self.server_func: Callable self.endpoints: dict[str, str] = dict() - self.engine_rank_mesh_array: list[list[int]] = [] - self.engine_launch_spec: EngineLaunchSpec | None = None + self.server_launch_spec: ServerLaunchSpec | None = None # Keep this deliberately large so requests do not queue in the # RolloutWorker/httpx client; the inference engine owns rollout request # scheduling and queueing. @@ -588,7 +544,6 @@ def __init__( limits = httpx.Limits(max_connections=http_concurrency, max_keepalive_connections=100) self.client = httpx.AsyncClient(limits=limits, timeout=self.config.rollout_timeout) self.server_task = None - self.engine_bundle_idxs: list[int] = [] self.server_process: Optional[multiprocessing.Process] = None self.session_server_actor: Any | None = None self.session_server_url: str | None = None @@ -602,205 +557,48 @@ def __init__( self.logger.info(f"Using eos_token: {eos_token} for model at {self.config.model_path}") self.eos_token: List[int] = [eos_token] if isinstance(eos_token, int) else eos_token self.receive_abort_request = threading.Event() - self.dist_init_addr: str = "" self.serverl_url: str = "" self.partial_rollout_handler = PartialRolloutHandler() self.enable_partial_rollout: bool = False - @staticmethod - def _get_num_gpus_per_engine(config: RolloutConfig) -> int: - return config.num_gpus_per_engine - @classmethod - def validate_engine_launch_specs( - cls, - engine_launch_specs: EngineLaunchSpecs, - *, - known_worker_ranks: tuple[int, ...] | None = None, - ) -> EngineLaunchSpecs: - """Validate backend launch layout before the controller launches - servers.""" - if not engine_launch_specs: - raise ValueError("engine_launch_specs must define at least one engine.") - - known_worker_rank_set = set(known_worker_ranks) if known_worker_ranks is not None else None - seen_engine_ranks: set[int] = set() - seen_server_ranks: set[int] = set() - seen_bundle_idxs: set[int] = set() - for engine_index, engine_spec in enumerate(engine_launch_specs): - if not engine_spec.engine_ranks: - raise ValueError(f"EngineLaunchSpec[{engine_index}] must define at least one engine rank.") - engine_rank_set = set(engine_spec.engine_ranks) - if len(engine_rank_set) != len(engine_spec.engine_ranks): - raise ValueError( - f"EngineLaunchSpec[{engine_index}] has duplicate engine ranks: {engine_spec.engine_ranks}." - ) - if known_worker_rank_set is not None: - unknown_engine_ranks = sorted( - rank for rank in engine_spec.engine_ranks if rank not in known_worker_rank_set - ) - if unknown_engine_ranks: - raise ValueError( - f"EngineLaunchSpec[{engine_index}] references unknown engine ranks: {unknown_engine_ranks}." - ) - duplicated_engine_ranks = sorted(rank for rank in engine_spec.engine_ranks if rank in seen_engine_ranks) - if duplicated_engine_ranks: - raise ValueError( - f"EngineLaunchSpec[{engine_index}] engine ranks appear in more than one engine: " - f"{duplicated_engine_ranks}." - ) - seen_engine_ranks.update(engine_spec.engine_ranks) - - if not engine_spec.server_processes: - raise ValueError(f"EngineLaunchSpec[{engine_index}] must define at least one server process.") - - for server_process in engine_spec.server_processes: - server_rank = server_process.worker_rank - if server_rank not in engine_rank_set: - raise ValueError( - f"EngineLaunchSpec[{engine_index}] server worker_rank={server_rank} " - f"must be part of engine_ranks={engine_spec.engine_ranks}." - ) - if server_rank in seen_server_ranks: - raise ValueError(f"Server worker_rank={server_rank} appears in more than one server process.") - seen_server_ranks.add(server_rank) - - if not server_process.placement_group_bundle_idxs: - raise ValueError(f"Server worker_rank={server_rank} must own at least one placement-group bundle.") - if len(set(server_process.placement_group_bundle_idxs)) != len( - server_process.placement_group_bundle_idxs - ): - raise ValueError( - f"Server worker_rank={server_rank} has duplicate placement-group bundles: " - f"{server_process.placement_group_bundle_idxs}." - ) - duplicated_bundle_idxs = sorted( - bundle_idx - for bundle_idx in server_process.placement_group_bundle_idxs - if bundle_idx in seen_bundle_idxs - ) - if duplicated_bundle_idxs: - raise ValueError( - f"Placement-group bundles are assigned to multiple server processes: {duplicated_bundle_idxs}." - ) - seen_bundle_idxs.update(server_process.placement_group_bundle_idxs) - - if server_process.nnodes < 1: - raise ValueError(f"Server worker_rank={server_rank} must have nnodes >= 1.") - if server_process.node_rank < 0 or server_process.node_rank >= server_process.nnodes: - raise ValueError( - f"Server worker_rank={server_rank} has invalid node_rank={server_process.node_rank} " - f"for nnodes={server_process.nnodes}." - ) - - if not engine_spec.request_entrypoint_servers: - raise ValueError(f"EngineLaunchSpec[{engine_index}] must expose at least one request entrypoint.") - - if known_worker_rank_set is not None: - missing_engine_ranks = sorted(known_worker_rank_set - seen_engine_ranks) - if missing_engine_ranks: - raise ValueError( - f"EngineLaunchSpecs do not cover known worker ranks in engine_ranks: {missing_engine_ranks}." - ) - - return engine_launch_specs - - @classmethod - def build_engine_launch_specs( + @abstractmethod + def build_rollout_topology( cls, config: RolloutConfig, rank_bundle_idx_list: list[tuple[int, int]], - rank_to_dist_init_addr: dict[int, str] | None = None, - ) -> EngineLaunchSpecs: - """Build default launch spec: one request-serving server per engine.""" - num_gpus_per_engine = cls._get_num_gpus_per_engine(config) - num_workers = len(rank_bundle_idx_list) - if num_workers % num_gpus_per_engine != 0: - raise ValueError( - f"num_rollout_workers={num_workers} must be divisible by num_gpus_per_engine={num_gpus_per_engine}." - ) - - engine_launch_specs: list[EngineLaunchSpec] = [] - for engine_start in range(0, num_workers, num_gpus_per_engine): - engine_meta = rank_bundle_idx_list[engine_start : engine_start + num_gpus_per_engine] - engine_ranks = tuple(rank for rank, _ in engine_meta) - engine_bundle_idxs = tuple(bundle_idx for _, bundle_idx in engine_meta) - engine_dist_init_addr = None if rank_to_dist_init_addr is None else rank_to_dist_init_addr[engine_ranks[0]] - engine_launch_specs.append( - EngineLaunchSpec( - engine_ranks=engine_ranks, - server_processes=( - ServerProcessSpec( - worker_rank=engine_ranks[0], - placement_group_bundle_idxs=engine_bundle_idxs, - dist_init_addr=engine_dist_init_addr, - ), - ), - ) - ) - return cls.validate_engine_launch_specs( - tuple(engine_launch_specs), - known_worker_ranks=tuple(rank for rank, _ in rank_bundle_idx_list), - ) - - @classmethod - def build_metadata_engine_rank_mesh_array( - cls, - engine_launch_specs: EngineLaunchSpecs, - ) -> list[list[int]]: - """Build the public engine mesh returned in rollout metadata. - - By default, the public metadata mesh matches the logical engine topology. Backends with multiple request - servers per logical engine can override this to preserve their legacy update-weight mesh semantics. - """ - return [list(engine_spec.engine_ranks) for engine_spec in engine_launch_specs] - - def _get_current_server_process_spec( - self, - engine_launch_spec: EngineLaunchSpec | None = None, - ) -> ServerProcessSpec | None: - engine_launch_spec = engine_launch_spec or self.engine_launch_spec - if engine_launch_spec is None: - return None - - for server_process_spec in engine_launch_spec.server_processes: - if server_process_spec.worker_rank == self.rank: - return server_process_spec - raise RuntimeError( - f"Engine launch spec does not include rollout worker rank={self.rank} " - f"in server_worker_ranks={engine_launch_spec.server_worker_ranks}." - ) + rank_to_dist_init_addr: Mapping[int, str], + ) -> RolloutTopology: + raise NotImplementedError("Concrete rollout worker classes must implement build_rollout_topology().") def set_enable_partial_rollout(self, enable: bool) -> None: self.enable_partial_rollout = enable + def bind_server_launch_spec(self, server_launch_spec: ServerLaunchSpec) -> None: + if server_launch_spec.worker_rank != self.rank: + raise ValueError( + f"Server launch spec rank={server_launch_spec.worker_rank} does not match worker rank={self.rank}." + ) + self.server_launch_spec = server_launch_spec + def init( self, - *, - engine_launch_spec: EngineLaunchSpec | None = None, - ) -> tuple[int, str]: + ) -> RolloutWorkerInitResult: """Initialize the worker and launch the server. Returns: - Tuple[int, str]: A tuple containing the worker's rank and its - server URL. + Startup result containing rank, server URL, and session URL. """ - if engine_launch_spec is not None: - # Initial controller startup passes the immutable launch spec and caches - # it on the actor. Recovery calls init() without arguments after - # shutdown, intentionally reusing this cached placement/dist layout. - self.engine_launch_spec = engine_launch_spec - server_process_spec = cast( - ServerProcessSpec, - self._get_current_server_process_spec(engine_launch_spec), - ) - self.engine_bundle_idxs = list(server_process_spec.placement_group_bundle_idxs) - if server_process_spec.dist_init_addr is not None: - self.dist_init_addr = server_process_spec.dist_init_addr + if self.server_launch_spec is None: + raise RuntimeError("RolloutWorker.bind_server_launch_spec() must be called before init().") self.receive_abort_request.clear() self._launch_server() self._start_session_server() - return (self.rank, self.server_url) + return RolloutWorkerInitResult( + rank=self.rank, + server_url=self.server_url, + session_url=self.session_server_url, + ) def set_skip_load_weights(self, skip_load_weights: bool) -> None: self.config = self.config.model_copy(update={"skip_load_weights": skip_load_weights}) @@ -808,7 +606,7 @@ def set_skip_load_weights(self, skip_load_weights: bool) -> None: def restore_skip_load_weights(self) -> None: self.config = self.config.model_copy(update={"skip_load_weights": self._default_skip_load_weights}) - def init_dist_port(self) -> str: + def init_dist_port(self) -> tuple[int, str]: """Initialize distributed communication ports. This method initializes four fixed ports for the distributed setup: @@ -816,7 +614,7 @@ def init_dist_port(self) -> str: for NCCL, and one for the session server. Returns: - str: The distributed initialization address (host:port). + Worker rank and distributed initialization address (host:port). """ local_rank = int(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0]) base_port = self.config.dist_port_base + local_rank * 4 @@ -825,9 +623,9 @@ def init_dist_port(self) -> str: self.server_port = base_port + 1 self.nccl_port = base_port + 2 self.session_server_port = base_port + 3 - self.dist_init_addr = f"{self.host}:{self.dist_port}" + dist_init_addr = f"{self.host}:{self.dist_port}" self.server_url = f"http://{self.host}:{self.server_port}" - return self.dist_init_addr + return self.rank, dist_init_addr def shutdown(self, *, stop_session_server: bool = False): """Shut down the worker, its server task, and any child processes.""" @@ -873,11 +671,12 @@ def _start_session_server(self) -> None: if self.session_server_actor is not None: return + assert self.server_launch_spec is not None current_pg = ray.util.get_current_placement_group() scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=current_pg, placement_group_capture_child_tasks=False, - placement_group_bundle_index=self.engine_bundle_idxs[0], + placement_group_bundle_index=self.server_launch_spec.placement_group_bundle_idxs[0], ) self.session_server_actor = ( ray.remote(SessionServerActor) @@ -907,9 +706,6 @@ def _stop_session_server(self) -> None: self.session_server_actor = None self.session_server_url = None - def get_session_server_info(self) -> tuple[int, str | None]: - return self.rank, self.session_server_url - async def pause_generation(self): """Pause the worker's generation process.""" self.receive_abort_request.set() @@ -1218,11 +1014,12 @@ def _launch_server(self): else: # launch the server as ray task # so that the lmdeploy backend could get externl pg + assert self.server_launch_spec is not None current_pg = ray.util.get_current_placement_group() scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=current_pg, placement_group_capture_child_tasks=True, - placement_group_bundle_index=self.engine_bundle_idxs[0], + placement_group_bundle_index=self.server_launch_spec.placement_group_bundle_idxs[0], ) assert ray.is_initialized() ray_kwargs = ( diff --git a/xtuner/v1/rl/rollout/worker_registry.py b/xtuner/v1/rl/rollout/worker_registry.py index 4770dd59b..25f6a785d 100644 --- a/xtuner/v1/rl/rollout/worker_registry.py +++ b/xtuner/v1/rl/rollout/worker_registry.py @@ -3,13 +3,15 @@ import threading from dataclasses import dataclass, replace from enum import Enum -from typing import TYPE_CHECKING, Iterable, TypedDict +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from .rollout_topology import RolloutTopology from .worker import RolloutConfig, RolloutWorker __all__ = [ + "RolloutWorkerEndpointMetadata", "RolloutWorkerMetadata", "RolloutWorkerRegistry", "WorkerGroup", @@ -31,20 +33,12 @@ class WorkerLifecycleState(str, Enum): class WorkerSnapshot: """Read-only snapshot for one rollout server process.""" + rank: int actor: RolloutWorker url: str session_url: str | None = None - lifecycle_state: WorkerLifecycleState = WorkerLifecycleState.ACTIVE - lifecycle_group_ranks: tuple[int, ...] = () is_request_entrypoint: bool = True - rank: int = -1 - - def __post_init__(self) -> None: - lifecycle_state = ( - WorkerLifecycleState.ACTIVE if self.lifecycle_state is None else WorkerLifecycleState(self.lifecycle_state) - ) - object.__setattr__(self, "lifecycle_state", lifecycle_state) - object.__setattr__(self, "lifecycle_group_ranks", tuple(self.lifecycle_group_ranks)) + lifecycle_state: WorkerLifecycleState = WorkerLifecycleState.ACTIVE def is_active(self) -> bool: return self.lifecycle_state is WorkerLifecycleState.ACTIVE @@ -56,31 +50,49 @@ class WorkerGroup: workers: tuple[WorkerSnapshot, ...] -class RolloutWorkerMetadata(TypedDict): - """Legacy rollout worker metadata consumed by trainer/update-weight - code.""" +@dataclass(frozen=True) +class RolloutWorkerEndpointMetadata: + """URL and lifecycle state for one request-serving rollout endpoint.""" - engine_rank_mesh_array: list[list[int]] - server_url_dict: dict[int, str] - rollout_config: RolloutConfig - worker_server_urls_status: dict[str, bool] - worker_session_url_dict: dict[int, str] - worker_session_urls_status: dict[str, bool] + rank: int + server_url: str + session_url: str | None + lifecycle_state: WorkerLifecycleState + + @property + def is_active(self) -> bool: + return self.lifecycle_state is WorkerLifecycleState.ACTIVE -def _build_worker_groups(workers: Iterable[WorkerSnapshot]) -> dict[tuple[int, ...], WorkerGroup]: - grouped_workers: dict[tuple[int, ...], list[WorkerSnapshot]] = {} - for worker in workers: - group_ranks = worker.lifecycle_group_ranks or (worker.rank,) - grouped_workers.setdefault(group_ranks, []).append(worker) +@dataclass(frozen=True) +class RolloutWorkerMetadata: + """Structured rollout worker metadata consumed by trainer/update-weight + code.""" - return { - group_ranks: WorkerGroup( - ranks=group_ranks, - workers=tuple(sorted(group_workers, key=lambda worker: worker.rank)), - ) - for group_ranks, group_workers in grouped_workers.items() - } + rollout_config: RolloutConfig + training_engine_mesh: tuple[tuple[int, ...], ...] + request_endpoints: tuple[RolloutWorkerEndpointMetadata, ...] + + def to_legacy(self) -> dict[str, Any]: + """Serialize to the current trainer-facing rollout metadata dict.""" + return { + "engine_rank_mesh_array": [list(engine_ranks) for engine_ranks in self.training_engine_mesh], + "server_url_dict": {endpoint.rank: endpoint.server_url for endpoint in self.request_endpoints}, + "rollout_config": self.rollout_config, + "worker_server_urls_status": { + endpoint.server_url: endpoint.is_active for endpoint in self.request_endpoints + }, + "worker_session_url_dict": { + endpoint.rank: endpoint.session_url + for endpoint in self.request_endpoints + if endpoint.session_url is not None + }, + "worker_session_urls_status": { + endpoint.session_url: endpoint.is_active + for endpoint in self.request_endpoints + if endpoint.session_url is not None + }, + } class RolloutWorkerRegistry: @@ -90,12 +102,11 @@ class RolloutWorkerRegistry: def __init__( self, *, - engine_rank_mesh_array: list[list[int]], + rollout_topology: RolloutTopology, rollout_config: RolloutConfig, ): - """Initialize an empty registry with the training-side metadata - projection.""" - self._engine_rank_mesh_array = [list(engine_ranks) for engine_ranks in engine_rank_mesh_array] + """Initialize an empty registry with the rollout topology.""" + self._rollout_topology = rollout_topology self._rollout_config = rollout_config self._workers: dict[int, WorkerSnapshot] = {} self._lock = threading.RLock() @@ -107,8 +118,7 @@ def register_started_server( actor: RolloutWorker, server_url: str, session_url: str | None = None, - lifecycle_group_ranks: tuple[int, ...] = (), - is_request_entrypoint: bool = True, + lifecycle_state: WorkerLifecycleState = WorkerLifecycleState.ACTIVE, ) -> None: """Register one worker actor after its rollout server process has started.""" @@ -118,8 +128,8 @@ def register_started_server( actor=actor, url=server_url, session_url=session_url, - lifecycle_group_ranks=lifecycle_group_ranks or (rank,), - is_request_entrypoint=is_request_entrypoint, + is_request_entrypoint=self._rollout_topology.is_request_entrypoint_rank(rank), + lifecycle_state=lifecycle_state, ) def all_workers(self) -> tuple[WorkerSnapshot, ...]: @@ -162,11 +172,28 @@ def active_entrypoint_by_rank(self, rank: int) -> WorkerSnapshot | None: return None return worker + def lifecycle_groups(self) -> tuple[tuple[int, ...], ...]: + """Return registered lifecycle groups in rank order.""" + with self._lock: + return tuple(sorted(self._rollout_topology.lifecycle_groups())) + + def _build_worker_groups(self) -> dict[tuple[int, ...], WorkerGroup]: + grouped_ranks = { + self._rollout_topology.lifecycle_group_for_server_rank(worker.rank) for worker in self._workers.values() + } + return { + group_ranks: WorkerGroup( + ranks=group_ranks, + workers=tuple(self._workers[rank] for rank in group_ranks if rank in self._workers), + ) + for group_ranks in grouped_ranks + } + def claim_inactive_groups_for_recovery(self) -> tuple[WorkerGroup, ...]: """Claim non-active worker groups by moving them to recovering state.""" with self._lock: - worker_groups = _build_worker_groups(self._workers.values()) + worker_groups = self._build_worker_groups() inactive_groups = [ group for group in worker_groups.values() @@ -184,16 +211,14 @@ def mark_unhealthy_ranks(self, ranks: set[int]) -> tuple[WorkerGroup, ...]: """Mark every lifecycle group containing a failed rank as inactive.""" with self._lock: failed_group_ranks = { - worker.lifecycle_group_ranks or (worker.rank,) - for rank, worker in self._workers.items() - if rank in ranks + self._rollout_topology.lifecycle_group_for_server_rank(rank) for rank in ranks if rank in self._workers } for group_ranks in failed_group_ranks: for rank in group_ranks: worker = self._workers.get(rank) if worker is not None: self._workers[rank] = replace(worker, lifecycle_state=WorkerLifecycleState.INACTIVE) - worker_groups = _build_worker_groups(self._workers.values()) + worker_groups = self._build_worker_groups() return tuple( worker_groups[group_ranks] for group_ranks in sorted(failed_group_ranks) @@ -214,29 +239,24 @@ def set_group_recovery_result( worker = self._workers.get(rank) if worker is not None: self._workers[rank] = replace(worker, lifecycle_state=lifecycle_state) - worker_groups = _build_worker_groups(self._workers.values()) + worker_groups = self._build_worker_groups() return worker_groups.get(group.ranks) - def training_metadata_snapshot(self) -> RolloutWorkerMetadata: - """Build the legacy trainer/update-weight metadata from one registry - snapshot.""" + def metadata(self) -> RolloutWorkerMetadata: + """Build trainer/update-weight metadata from one registry snapshot.""" with self._lock: - request_entrypoints = {rank: info for rank, info in self._workers.items() if info.is_request_entrypoint} - worker_server_urls_map = {rank: info.url for rank, info in request_entrypoints.items()} - worker_server_urls_status = {info.url: info.is_active() for info in request_entrypoints.values()} - worker_session_url_dict: dict[int, str] = {} - worker_session_urls_status: dict[str, bool] = {} - for rank, info in request_entrypoints.items(): - if info.session_url is None: - continue - worker_session_url_dict[rank] = info.session_url - worker_session_urls_status[info.session_url] = info.is_active() - - return { - "engine_rank_mesh_array": [list(engine_ranks) for engine_ranks in self._engine_rank_mesh_array], - "server_url_dict": worker_server_urls_map, - "rollout_config": self._rollout_config, - "worker_server_urls_status": worker_server_urls_status, - "worker_session_url_dict": worker_session_url_dict, - "worker_session_urls_status": worker_session_urls_status, - } + request_endpoints = tuple( + RolloutWorkerEndpointMetadata( + rank=worker.rank, + server_url=worker.url, + session_url=worker.session_url, + lifecycle_state=worker.lifecycle_state, + ) + for worker in self.all_workers() + if worker.is_request_entrypoint + ) + return RolloutWorkerMetadata( + rollout_config=self._rollout_config, + training_engine_mesh=self._rollout_topology.training_engine_mesh, + request_endpoints=request_endpoints, + ) From 06aa3a53db4dbf0b35c1c01b20518481884fbc7e Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Mon, 29 Jun 2026 10:37:09 +0000 Subject: [PATCH 2/5] Project rollout weight update targets from registry --- xtuner/v1/rl/rollout/controller.py | 5 +++++ xtuner/v1/rl/rollout/lmdeploy.py | 2 ++ xtuner/v1/rl/rollout/rollout_topology.py | 12 ++++++++++++ xtuner/v1/rl/rollout/sglang.py | 1 + xtuner/v1/rl/rollout/worker_registry.py | 24 ++++++++++++++++++++++++ xtuner/v1/rl/weight_update/__init__.py | 2 ++ xtuner/v1/rl/weight_update/data.py | 22 ++++++++++++++++++++++ 7 files changed, 68 insertions(+) diff --git a/xtuner/v1/rl/rollout/controller.py b/xtuner/v1/rl/rollout/controller.py index 6b3696f63..e85bbbb91 100644 --- a/xtuner/v1/rl/rollout/controller.py +++ b/xtuner/v1/rl/rollout/controller.py @@ -8,6 +8,7 @@ from xtuner.v1.data_proto.rl_data import RolloutState, Status from xtuner.v1.rl.utils import AutoAcceleratorWorkers +from xtuner.v1.rl.weight_update.data import RolloutWeightUpdateTarget from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger from .constants import ROLLOUT_RAY_GENERATE_MAX_CONCURRENCY @@ -71,6 +72,10 @@ def get_rollout_metadata(self) -> dict: """ return self.registry.metadata().to_legacy() + def get_weight_update_targets(self) -> tuple[RolloutWeightUpdateTarget, ...]: + """Return rollout endpoints that can receive weight update requests.""" + return self.registry.weight_update_targets() + def register_active_workers_to_proxy(self) -> None: if self.proxy_manager is None: return diff --git a/xtuner/v1/rl/rollout/lmdeploy.py b/xtuner/v1/rl/rollout/lmdeploy.py index 044eb1c19..aea6540b1 100644 --- a/xtuner/v1/rl/rollout/lmdeploy.py +++ b/xtuner/v1/rl/rollout/lmdeploy.py @@ -111,6 +111,7 @@ def build_rollout_topology( RolloutTopology.server_process( worker_rank=engine_ranks[0], placement_group_bundle_idxs=engine_bundle_idxs, + weight_update_ranks=engine_ranks, ), ), ) @@ -133,6 +134,7 @@ def build_rollout_topology( RolloutTopology.server_process( worker_rank=server_rank, placement_group_bundle_idxs=(bundle_idx,), + weight_update_ranks=(server_rank,), ) for server_rank, bundle_idx in engine_meta ), diff --git a/xtuner/v1/rl/rollout/rollout_topology.py b/xtuner/v1/rl/rollout/rollout_topology.py index db54f8050..bdbfd6a8a 100644 --- a/xtuner/v1/rl/rollout/rollout_topology.py +++ b/xtuner/v1/rl/rollout/rollout_topology.py @@ -12,10 +12,15 @@ class _ServerProcess: worker_rank: int placement_group_bundle_idxs: tuple[int, ...] + weight_update_ranks: tuple[int, ...] = () accepts_rollout_requests: bool = True node_rank: int = 0 nnodes: int = 1 + @property + def is_weight_update_endpoint(self) -> bool: + return bool(self.weight_update_ranks) + @dataclass(frozen=True) class _Engine: @@ -82,6 +87,7 @@ def server_process( *, worker_rank: int, placement_group_bundle_idxs: tuple[int, ...], + weight_update_ranks: tuple[int, ...] = (), accepts_rollout_requests: bool = True, node_rank: int = 0, nnodes: int = 1, @@ -89,6 +95,7 @@ def server_process( return _ServerProcess( worker_rank=worker_rank, placement_group_bundle_idxs=placement_group_bundle_idxs, + weight_update_ranks=weight_update_ranks, accepts_rollout_requests=accepts_rollout_requests, node_rank=node_rank, nnodes=nnodes, @@ -111,6 +118,11 @@ def server_launch_specs(self) -> tuple[ServerLaunchSpec, ...]: def lifecycle_groups(self) -> tuple[tuple[int, ...], ...]: return tuple(dict.fromkeys(self._lifecycle_group_by_rank.values())) + def weight_update_endpoint_processes(self) -> tuple[_ServerProcess, ...]: + return tuple( + server for engine in self.engines for server in engine.server_processes if server.is_weight_update_endpoint + ) + def is_request_entrypoint_rank(self, rank: int) -> bool: server = self._server_process_by_rank.get(rank) return server is not None and server.accepts_rollout_requests diff --git a/xtuner/v1/rl/rollout/sglang.py b/xtuner/v1/rl/rollout/sglang.py index 89dd125c7..fb54de8de 100644 --- a/xtuner/v1/rl/rollout/sglang.py +++ b/xtuner/v1/rl/rollout/sglang.py @@ -80,6 +80,7 @@ def build_rollout_topology( RolloutTopology.server_process( worker_rank=server_rank, placement_group_bundle_idxs=engine_bundle_idxs[node_bundle_start:node_bundle_end], + weight_update_ranks=engine_ranks if node_rank == 0 else (), accepts_rollout_requests=node_rank == 0, node_rank=node_rank, nnodes=nnodes, diff --git a/xtuner/v1/rl/rollout/worker_registry.py b/xtuner/v1/rl/rollout/worker_registry.py index 25f6a785d..2160545b2 100644 --- a/xtuner/v1/rl/rollout/worker_registry.py +++ b/xtuner/v1/rl/rollout/worker_registry.py @@ -7,6 +7,8 @@ if TYPE_CHECKING: + from xtuner.v1.rl.weight_update.data import RolloutWeightUpdateTarget + from .rollout_topology import RolloutTopology from .worker import RolloutConfig, RolloutWorker @@ -260,3 +262,25 @@ def metadata(self) -> RolloutWorkerMetadata: training_engine_mesh=self._rollout_topology.training_engine_mesh, request_endpoints=request_endpoints, ) + + def weight_update_targets(self) -> tuple[RolloutWeightUpdateTarget, ...]: + """Return weight-update targets resolved with current runtime state.""" + from xtuner.v1.rl.weight_update.data import RolloutWeightUpdateTarget + + with self._lock: + targets: list[RolloutWeightUpdateTarget] = [] + for server in self._rollout_topology.weight_update_endpoint_processes(): + worker = self._workers.get(server.worker_rank) + if worker is None: + raise RuntimeError( + f"Rollout weight update endpoint rank={server.worker_rank} has not been registered." + ) + targets.append( + RolloutWeightUpdateTarget( + endpoint_rank=server.worker_rank, + update_ranks=server.weight_update_ranks, + server_url=worker.url, + lifecycle_state=worker.lifecycle_state.value, + ) + ) + return tuple(sorted(targets, key=lambda target: target.endpoint_rank)) diff --git a/xtuner/v1/rl/weight_update/__init__.py b/xtuner/v1/rl/weight_update/__init__.py index 312f779fe..c5e85c3c3 100644 --- a/xtuner/v1/rl/weight_update/__init__.py +++ b/xtuner/v1/rl/weight_update/__init__.py @@ -3,6 +3,7 @@ RolloutBackend, RolloutEngineInfo, RolloutWeightUpdateInfo, + RolloutWeightUpdateTarget, ServiceUrlMap, TrainRolloutMode, WeightTransportType, @@ -33,6 +34,7 @@ "RolloutBackend", "RolloutEngineInfo", "RolloutWeightUpdateInfo", + "RolloutWeightUpdateTarget", "SGLangIPCBackendAdapter", "SGLangNCCLBackendAdapter", "ServiceUrlMap", diff --git a/xtuner/v1/rl/weight_update/data.py b/xtuner/v1/rl/weight_update/data.py index 6041ff643..65b05430c 100644 --- a/xtuner/v1/rl/weight_update/data.py +++ b/xtuner/v1/rl/weight_update/data.py @@ -15,6 +15,28 @@ WeightTransportType: TypeAlias = Literal["ipc", "nccl"] # Supported weight transport types. +@dataclass(frozen=True) +class RolloutWeightUpdateTarget: + """Runtime weight-update endpoint resolved from rollout registry state.""" + + # Server-process worker rank that receives weight update requests. + endpoint_rank: int + # Rollout ranks updated through this endpoint. + update_ranks: tuple[int, ...] + # Runtime rollout server URL resolved from WorkerSnapshot. + server_url: str + # Registry lifecycle state value for this endpoint. + lifecycle_state: str + + @property + def is_active(self) -> bool: + return self.lifecycle_state == "active" + + @property + def engine_size(self) -> int: + return len(self.update_ranks) + + @dataclass class RolloutWeightUpdateInfo: # Common rollout metadata. From 6520b277dde79eab99e62c5f5ed9cefff42dc228 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Mon, 29 Jun 2026 10:39:49 +0000 Subject: [PATCH 3/5] Bind rollout weight updates with structured targets --- xtuner/v1/rl/trainer/controller.py | 22 ++++++++++++++++++ xtuner/v1/rl/trainer/worker.py | 4 ++++ xtuner/v1/rl/utils/ray_utils.py | 19 +++++++++++++-- xtuner/v1/rl/weight_update/update_weighter.py | 23 +++++++++++++++++++ xtuner/v1/train/rl_trainer.py | 14 +++++++---- 5 files changed, 76 insertions(+), 6 deletions(-) diff --git a/xtuner/v1/rl/trainer/controller.py b/xtuner/v1/rl/trainer/controller.py index d0965f216..dcb8601fe 100644 --- a/xtuner/v1/rl/trainer/controller.py +++ b/xtuner/v1/rl/trainer/controller.py @@ -303,6 +303,28 @@ def update_rollout_info(self, info_dict, train_rollout_mode, weight_update_host= ] ) + def bind_rollout_weight_update( + self, + *, + targets, + rollout_config, + train_rollout_mode, + weight_update_host=None, + weight_update_port=None, + ): + ray.get( + [ + worker.bind_rollout_weight_update.remote( + targets=targets, + rollout_config=rollout_config, + train_rollout_mode=train_rollout_mode, + weight_update_host=weight_update_host, + weight_update_port=weight_update_port, + ) + for worker in self.workers + ] + ) + def update_weights(self): """Update the weights of the training workers.""" handles = [worker.update_weights.remote() for worker in self.workers] diff --git a/xtuner/v1/rl/trainer/worker.py b/xtuner/v1/rl/trainer/worker.py index 00ce2e067..c2b1909cf 100644 --- a/xtuner/v1/rl/trainer/worker.py +++ b/xtuner/v1/rl/trainer/worker.py @@ -280,6 +280,10 @@ def __init__( def update_rollout_info(self, *args, **kwargs): return self.update_weighter.update_rollout_info(*args, **kwargs) + @ray_method + def bind_rollout_weight_update(self, *args, **kwargs): + return self.update_weighter.bind_rollout_weight_update(*args, **kwargs) + @ray_method def update_weights(self): return self.update_weighter.update_weights() diff --git a/xtuner/v1/rl/utils/ray_utils.py b/xtuner/v1/rl/utils/ray_utils.py index ad6b1dcc9..039016da9 100644 --- a/xtuner/v1/rl/utils/ray_utils.py +++ b/xtuner/v1/rl/utils/ray_utils.py @@ -159,6 +159,10 @@ def signal_handler(signum, frame): def bind_train_rollout( train_workers, rollout_controller, + rollout_config, + train_rollout_mode, + weight_update_host=None, + weight_update_port=None, ) -> None: """Bind the training and rollout workers for updating weights. @@ -170,6 +174,17 @@ def bind_train_rollout( train_workers: A list of training worker actors. rollout_controller: The rollout controller actor. """ - info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) # type: ignore[attr-defined] - ray.get([worker.update_rollout_info.remote(**info_dict) for worker in train_workers]) # type: ignore[attr-defined] + targets = ray.get(rollout_controller.get_weight_update_targets.remote()) # type: ignore[attr-defined] + ray.get( + [ + worker.bind_rollout_weight_update.remote( + targets=targets, + rollout_config=rollout_config, + train_rollout_mode=train_rollout_mode, + weight_update_host=weight_update_host, + weight_update_port=weight_update_port, + ) + for worker in train_workers + ] + ) # type: ignore[attr-defined] return diff --git a/xtuner/v1/rl/weight_update/update_weighter.py b/xtuner/v1/rl/weight_update/update_weighter.py index 7adb7ba8d..1732a2b15 100644 --- a/xtuner/v1/rl/weight_update/update_weighter.py +++ b/xtuner/v1/rl/weight_update/update_weighter.py @@ -11,6 +11,7 @@ DeviceMeshRaw, RolloutBackend, RolloutWeightUpdateInfo, + RolloutWeightUpdateTarget, ServiceUrlMap, TrainRolloutMode, ) @@ -121,6 +122,28 @@ def update_rollout_info( if self._transport is None: self._set_transport() + def bind_rollout_weight_update( + self, + *, + targets: tuple[RolloutWeightUpdateTarget, ...], + rollout_config: RolloutConfig, + train_rollout_mode: TrainRolloutMode, + weight_update_host: str | None = None, + weight_update_port: int | None = None, + ): + """Bind rollout weight-update targets through the legacy metadata + path.""" + + self.update_rollout_info( + engine_rank_mesh_array=[list(target.update_ranks) for target in targets], + server_url_dict={target.endpoint_rank: target.server_url for target in targets}, + rollout_config=rollout_config, + worker_server_urls_status={target.server_url: target.is_active for target in targets}, + train_rollout_mode=train_rollout_mode, + weight_update_host=weight_update_host, + weight_update_port=weight_update_port, + ) + def _ensure_rollout_device_mesh(self): if self.rollout_info.rollout_device_mesh is None: # 非共卡 SGLang 不使用这个 mesh;只有共卡/旧权重同步路径需要 diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index fdb49f3c7..9a8d312aa 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -109,17 +109,19 @@ def check_fa3(): def bind_train_rollout( train_controller: TrainingController, rollout_controller: RolloutControllerProxy, + rollout_config: RolloutConfig, train_rollout_mode: TrainRolloutMode | str, weight_update_host: str | None = None, weight_update_port: int | None = None, ) -> None: """Bind the training and rollout workers for update weights.""" - info_dict = ray.get( - rollout_controller.get_rollout_metadata.remote(), # type: ignore[attr-defined] + targets = ray.get( + rollout_controller.get_weight_update_targets.remote(), # type: ignore[attr-defined] timeout=RL_TRAINER_RAY_GET_TIMEOUT, ) - train_controller.update_rollout_info( - info_dict, + train_controller.bind_rollout_weight_update( + targets=targets, + rollout_config=rollout_config, train_rollout_mode=train_rollout_mode, weight_update_host=weight_update_host, weight_update_port=weight_update_port, @@ -1561,6 +1563,7 @@ def __init__(self, cfg: RLColocateTrainerConfig): bind_train_rollout( train_controller=self.train_controller, rollout_controller=self.rollout_controller, + rollout_config=self._rollout_config, train_rollout_mode="colocate", ) @@ -1721,6 +1724,7 @@ def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict) -> bool bind_train_rollout( train_controller=self.train_controller, rollout_controller=self.rollout_controller, + rollout_config=self._rollout_config, train_rollout_mode="colocate", ) ray.get( @@ -1768,6 +1772,7 @@ def __init__(self, cfg: RLDisaggregatedTrainerConfig): bind_train_rollout( train_controller=self.train_controller, rollout_controller=self.rollout_controller, + rollout_config=self._rollout_config, train_rollout_mode="disaggregated", weight_update_host=self._rollout_config.weight_update_host, weight_update_port=self._rollout_config.weight_update_port, @@ -1962,6 +1967,7 @@ async def _sync_weights_and_save(self, model_step: int, step_timer_dict: dict): bind_train_rollout( train_controller=self.train_controller, rollout_controller=self.rollout_controller, + rollout_config=self._rollout_config, train_rollout_mode="disaggregated", ) self.update_weights() From 049ba1ab6b6f8fd1e8594dc010d83241482a314b Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Mon, 29 Jun 2026 10:40:41 +0000 Subject: [PATCH 4/5] Derive weight update metadata from rollout targets --- xtuner/v1/rl/rollout/controller.py | 19 +- xtuner/v1/rl/rollout/lmdeploy.py | 26 +-- xtuner/v1/rl/rollout/rollout_topology.py | 146 +++++++++---- xtuner/v1/rl/rollout/sglang.py | 9 +- xtuner/v1/rl/rollout/worker_registry.py | 80 +------ xtuner/v1/rl/trainer/controller.py | 13 -- xtuner/v1/rl/trainer/worker.py | 8 - xtuner/v1/rl/weight_update/__init__.py | 8 +- xtuner/v1/rl/weight_update/data.py | 202 +++++++++++++++--- xtuner/v1/rl/weight_update/transport.py | 46 ++-- xtuner/v1/rl/weight_update/update_weighter.py | 187 +++------------- xtuner/v1/rl/weight_update/weight_iterator.py | 11 +- 12 files changed, 347 insertions(+), 408 deletions(-) diff --git a/xtuner/v1/rl/rollout/controller.py b/xtuner/v1/rl/rollout/controller.py index e85bbbb91..f58bb679f 100644 --- a/xtuner/v1/rl/rollout/controller.py +++ b/xtuner/v1/rl/rollout/controller.py @@ -64,14 +64,6 @@ def __init__( ) self.health_manager.start() - def get_rollout_metadata(self) -> dict: - """Get information about the current rollout setup. - - Returns: - Legacy trainer/update-weight rollout metadata dictionary. - """ - return self.registry.metadata().to_legacy() - def get_weight_update_targets(self) -> tuple[RolloutWeightUpdateTarget, ...]: """Return rollout endpoints that can receive weight update requests.""" return self.registry.weight_update_targets() @@ -216,8 +208,7 @@ def _init_workers( URLs to rollout traffic. Returns: - A registry containing all server-process workers and the public - training metadata mesh. + A registry containing all server-process workers and runtime state. """ worker_base_cls = get_rollout_worker_base_cls(self.config) worker_cls = self._build_remote_worker_cls(worker_base_cls) @@ -261,7 +252,7 @@ def _init_workers( ] ) ) - registry = RolloutWorkerRegistry(rollout_topology=rollout_topology, rollout_config=self.config) + registry = RolloutWorkerRegistry(rollout_topology=rollout_topology) for init_result in init_results: if rollout_topology.is_request_entrypoint_rank(init_result.rank) and init_result.session_url is None: raise RuntimeError( @@ -274,12 +265,10 @@ def _init_workers( session_url=init_result.session_url, ) - rollout_metadata = registry.metadata() - legacy_metadata = rollout_metadata.to_legacy() self.logger.info( "Rollout worker registry snapshot: " - f"server_urls={legacy_metadata['server_url_dict']}, " - f"session_urls={legacy_metadata['worker_session_url_dict']}, " + f"weight_update_targets={registry.weight_update_targets()}, " + f"active_entrypoints={registry.active_entrypoints()}, " f"server_process_urls={[worker.url for worker in registry.all_workers()]}, " f"lifecycle_groups={registry.lifecycle_groups()}" ) diff --git a/xtuner/v1/rl/rollout/lmdeploy.py b/xtuner/v1/rl/rollout/lmdeploy.py index aea6540b1..31b847130 100644 --- a/xtuner/v1/rl/rollout/lmdeploy.py +++ b/xtuner/v1/rl/rollout/lmdeploy.py @@ -10,7 +10,7 @@ from transformers import AutoTokenizer from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams -from .rollout_topology import RolloutTopology +from .rollout_topology import RolloutEngine, RolloutServerProcess, RolloutTopology from .worker import RolloutConfig, RolloutWorker @@ -89,7 +89,7 @@ def build_rollout_topology( ) -> RolloutTopology: """Build LMDeploy rollout topology with bound engine dist-init addresses.""" - engines = [] + engines: list[RolloutEngine] = [] num_workers = len(rank_bundle_idx_list) if config.expert_parallel_size <= 1: num_gpus_per_engine = config.num_gpus_per_engine @@ -104,11 +104,11 @@ def build_rollout_topology( engine_bundle_idxs = tuple(bundle_idx for _, bundle_idx in engine_meta) dist_init_addr_owner_rank = engine_ranks[0] engines.append( - RolloutTopology.engine( + RolloutEngine( engine_ranks=engine_ranks, dist_init_addr=rank_to_dist_init_addr[dist_init_addr_owner_rank], server_processes=( - RolloutTopology.server_process( + RolloutServerProcess( worker_rank=engine_ranks[0], placement_group_bundle_idxs=engine_bundle_idxs, weight_update_ranks=engine_ranks, @@ -127,11 +127,11 @@ def build_rollout_topology( engine_ranks = tuple(rank for rank, _ in engine_meta) dist_init_addr_owner_rank = engine_ranks[0] engines.append( - RolloutTopology.engine( + RolloutEngine( engine_ranks=engine_ranks, dist_init_addr=rank_to_dist_init_addr[dist_init_addr_owner_rank], server_processes=tuple( - RolloutTopology.server_process( + RolloutServerProcess( worker_rank=server_rank, placement_group_bundle_idxs=(bundle_idx,), weight_update_ranks=(server_rank,), @@ -141,19 +141,7 @@ def build_rollout_topology( ) ) - training_engine_mesh: list[tuple[int, ...]] = [] - for engine in engines: - entrypoint_processes = tuple( - process for process in engine.server_processes if process.accepts_rollout_requests - ) - if len(entrypoint_processes) == 1: - training_engine_mesh.append(tuple(engine.engine_ranks)) - else: - training_engine_mesh.extend((process.worker_rank,) for process in entrypoint_processes) - return RolloutTopology( - engines=tuple(engines), - training_engine_mesh=tuple(training_engine_mesh), - ) + return RolloutTopology(engines=tuple(engines)) def offload(self): """Offloads the model weights and KV cache.""" diff --git a/xtuner/v1/rl/rollout/rollout_topology.py b/xtuner/v1/rl/rollout/rollout_topology.py index bdbfd6a8a..5dd70320e 100644 --- a/xtuner/v1/rl/rollout/rollout_topology.py +++ b/xtuner/v1/rl/rollout/rollout_topology.py @@ -3,18 +3,29 @@ from dataclasses import dataclass, field -__all__ = ["RolloutTopology", "ServerLaunchSpec"] +__all__ = [ + "RolloutEngine", + "RolloutServerProcess", + "RolloutTopology", + "ServerLaunchSpec", +] @dataclass(frozen=True) -class _ServerProcess: - """Rollout topology expression for one worker-owned server process.""" +class RolloutServerProcess: + """Static topology for one worker-owned rollout server process.""" + # Worker rank that owns and starts this server process. worker_rank: int + # Placement-group bundles assigned to this server process. placement_group_bundle_idxs: tuple[int, ...] - weight_update_ranks: tuple[int, ...] = () + # Rollout ranks updated through this server process. + weight_update_ranks: tuple[int, ...] + # Whether this server process can receive rollout generation requests. accepts_rollout_requests: bool = True + # Node index used by backends that launch one server process per node. node_rank: int = 0 + # Number of nodes participating in this server launch. nnodes: int = 1 @property @@ -23,42 +34,117 @@ def is_weight_update_endpoint(self) -> bool: @dataclass(frozen=True) -class _Engine: - """Rollout layout for one logical inference engine.""" +class RolloutEngine: + """Static topology for one logical inference engine.""" + # Rollout ranks that jointly form this logical inference engine. engine_ranks: tuple[int, ...] + # Rendezvous address shared by every server process in this engine. dist_init_addr: str - server_processes: tuple[_ServerProcess, ...] + # Server processes that expose this engine to rollout traffic or control paths. + server_processes: tuple[RolloutServerProcess, ...] @dataclass(frozen=True) class ServerLaunchSpec: """Worker-facing launch data projected from rollout topology.""" + # Worker rank that should receive this launch spec. worker_rank: int + # Placement-group bundles assigned to the launched server process. placement_group_bundle_idxs: tuple[int, ...] + # Engine rendezvous address resolved by RolloutTopology. dist_init_addr: str + # Rank of this worker inside the logical inference engine. engine_rank: int + # Node index for multi-node backend launches. node_rank: int = 0 + # Number of nodes for multi-node backend launches. nnodes: int = 1 @dataclass(frozen=True) class RolloutTopology: - """Immutable rollout engine layout after dist-init addresses are resolved. + """Immutable rollout topology after dist-init addresses are resolved. Actor handles, server URLs, session URLs, and lifecycle state belong to RolloutWorkerRegistry. """ - engines: tuple[_Engine, ...] - training_engine_mesh: tuple[tuple[int, ...], ...] - _server_process_by_rank: dict[int, _ServerProcess] = field(init=False, repr=False, compare=False) + # Logical inference engines and their server-process topology. + engines: tuple[RolloutEngine, ...] + # Server-process lookup keyed by worker rank. + _server_process_by_rank: dict[int, RolloutServerProcess] = field(init=False, repr=False, compare=False) + # Lifecycle group lookup keyed by server-process worker rank. _lifecycle_group_by_rank: dict[int, tuple[int, ...]] = field(init=False, repr=False, compare=False) def __post_init__(self) -> None: - server_process_by_rank: dict[int, _ServerProcess] = {} + if not self.engines: + raise ValueError("RolloutTopology must define at least one engine.") + + seen_engine_ranks: set[int] = set() + seen_bundle_idxs: set[int] = set() + server_process_by_rank: dict[int, RolloutServerProcess] = {} lifecycle_group_by_rank: dict[int, tuple[int, ...]] = {} - for engine in self.engines: + for engine_index, engine in enumerate(self.engines): + if not engine.engine_ranks: + raise ValueError(f"RolloutTopology engine[{engine_index}] must define at least one engine rank.") + if len(set(engine.engine_ranks)) != len(engine.engine_ranks): + raise ValueError( + f"RolloutTopology engine[{engine_index}] has duplicate engine ranks: {engine.engine_ranks}." + ) + duplicate_engine_ranks = sorted(set(engine.engine_ranks).intersection(seen_engine_ranks)) + if duplicate_engine_ranks: + raise ValueError( + f"RolloutTopology engine[{engine_index}] engine ranks appear in more than one engine: " + f"{duplicate_engine_ranks}." + ) + seen_engine_ranks.update(engine.engine_ranks) + + if not engine.server_processes: + raise ValueError(f"RolloutTopology engine[{engine_index}] must define at least one server process.") + if not any(server.accepts_rollout_requests for server in engine.server_processes): + raise ValueError( + f"RolloutTopology engine[{engine_index}] must expose at least one request entrypoint." + ) + + engine_rank_set = set(engine.engine_ranks) + covered_update_ranks: set[int] = set() + for server in engine.server_processes: + duplicate_bundle_idxs = sorted(set(server.placement_group_bundle_idxs).intersection(seen_bundle_idxs)) + if duplicate_bundle_idxs: + raise ValueError( + f"RolloutTopology engine[{engine_index}] server worker_rank={server.worker_rank} " + f"reuses placement-group bundle indexes: {duplicate_bundle_idxs}." + ) + seen_bundle_idxs.update(server.placement_group_bundle_idxs) + if server.worker_rank not in engine_rank_set: + raise ValueError( + f"RolloutTopology engine[{engine_index}] server worker_rank={server.worker_rank} " + f"is not in engine_ranks={engine.engine_ranks}." + ) + unknown_update_ranks = sorted( + rank for rank in server.weight_update_ranks if rank not in engine_rank_set + ) + if unknown_update_ranks: + raise ValueError( + f"RolloutTopology engine[{engine_index}] server worker_rank={server.worker_rank} " + f"references unknown weight_update_ranks={unknown_update_ranks}." + ) + duplicate_update_ranks = sorted(set(server.weight_update_ranks).intersection(covered_update_ranks)) + if duplicate_update_ranks: + raise ValueError( + f"RolloutTopology engine[{engine_index}] has duplicate weight_update_ranks=" + f"{duplicate_update_ranks}." + ) + covered_update_ranks.update(server.weight_update_ranks) + + if covered_update_ranks != engine_rank_set: + missing_update_ranks = sorted(engine_rank_set.difference(covered_update_ranks)) + raise ValueError( + f"RolloutTopology engine[{engine_index}] weight_update_ranks do not cover engine ranks: " + f"{missing_update_ranks}." + ) + lifecycle_group = tuple(server.worker_rank for server in engine.server_processes) for server in engine.server_processes: if server.worker_rank in server_process_by_rank: @@ -69,38 +155,6 @@ def __post_init__(self) -> None: object.__setattr__(self, "_server_process_by_rank", server_process_by_rank) object.__setattr__(self, "_lifecycle_group_by_rank", lifecycle_group_by_rank) - @staticmethod - def engine( - *, - engine_ranks: tuple[int, ...], - dist_init_addr: str, - server_processes: tuple[_ServerProcess, ...], - ) -> _Engine: - return _Engine( - engine_ranks=engine_ranks, - dist_init_addr=dist_init_addr, - server_processes=server_processes, - ) - - @staticmethod - def server_process( - *, - worker_rank: int, - placement_group_bundle_idxs: tuple[int, ...], - weight_update_ranks: tuple[int, ...] = (), - accepts_rollout_requests: bool = True, - node_rank: int = 0, - nnodes: int = 1, - ) -> _ServerProcess: - return _ServerProcess( - worker_rank=worker_rank, - placement_group_bundle_idxs=placement_group_bundle_idxs, - weight_update_ranks=weight_update_ranks, - accepts_rollout_requests=accepts_rollout_requests, - node_rank=node_rank, - nnodes=nnodes, - ) - def server_launch_specs(self) -> tuple[ServerLaunchSpec, ...]: return tuple( ServerLaunchSpec( @@ -118,7 +172,7 @@ def server_launch_specs(self) -> tuple[ServerLaunchSpec, ...]: def lifecycle_groups(self) -> tuple[tuple[int, ...], ...]: return tuple(dict.fromkeys(self._lifecycle_group_by_rank.values())) - def weight_update_endpoint_processes(self) -> tuple[_ServerProcess, ...]: + def weight_update_endpoint_processes(self) -> tuple[RolloutServerProcess, ...]: return tuple( server for engine in self.engines for server in engine.server_processes if server.is_weight_update_endpoint ) diff --git a/xtuner/v1/rl/rollout/sglang.py b/xtuner/v1/rl/rollout/sglang.py index fb54de8de..87c1ff747 100644 --- a/xtuner/v1/rl/rollout/sglang.py +++ b/xtuner/v1/rl/rollout/sglang.py @@ -11,7 +11,7 @@ from xtuner.v1.data_proto.rl_data import RolloutState from xtuner.v1.utils import XTUNER_DETERMINISTIC -from .rollout_topology import RolloutTopology +from .rollout_topology import RolloutEngine, RolloutServerProcess, RolloutTopology from .worker import RolloutConfig, RolloutWorker @@ -77,17 +77,17 @@ def build_rollout_topology( node_bundle_start = node_rank * config.gpus_per_node node_bundle_end = node_bundle_start + config.gpus_per_node server_processes.append( - RolloutTopology.server_process( + RolloutServerProcess( worker_rank=server_rank, placement_group_bundle_idxs=engine_bundle_idxs[node_bundle_start:node_bundle_end], - weight_update_ranks=engine_ranks if node_rank == 0 else (), accepts_rollout_requests=node_rank == 0, + weight_update_ranks=engine_ranks if node_rank == 0 else (), node_rank=node_rank, nnodes=nnodes, ) ) engines.append( - RolloutTopology.engine( + RolloutEngine( engine_ranks=engine_ranks, dist_init_addr=rank_to_dist_init_addr[dist_init_addr_owner_rank], server_processes=tuple(server_processes), @@ -95,7 +95,6 @@ def build_rollout_topology( ) return RolloutTopology( engines=tuple(engines), - training_engine_mesh=tuple(tuple(engine.engine_ranks) for engine in engines), ) def _get_request_payload(self, rollout_state: RolloutState) -> dict: diff --git a/xtuner/v1/rl/rollout/worker_registry.py b/xtuner/v1/rl/rollout/worker_registry.py index 2160545b2..2769dd888 100644 --- a/xtuner/v1/rl/rollout/worker_registry.py +++ b/xtuner/v1/rl/rollout/worker_registry.py @@ -3,18 +3,16 @@ import threading from dataclasses import dataclass, replace from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING if TYPE_CHECKING: from xtuner.v1.rl.weight_update.data import RolloutWeightUpdateTarget from .rollout_topology import RolloutTopology - from .worker import RolloutConfig, RolloutWorker + from .worker import RolloutWorker __all__ = [ - "RolloutWorkerEndpointMetadata", - "RolloutWorkerMetadata", "RolloutWorkerRegistry", "WorkerGroup", "WorkerLifecycleState", @@ -35,11 +33,17 @@ class WorkerLifecycleState(str, Enum): class WorkerSnapshot: """Read-only snapshot for one rollout server process.""" + # Worker rank that owns the runtime snapshot. rank: int + # Ray actor handle for the rollout worker. actor: RolloutWorker + # Base URL of the rollout server process. url: str + # Session server URL used only by proxy/session routing. session_url: str | None = None + # Whether this worker can receive rollout generation requests. is_request_entrypoint: bool = True + # Current lifecycle state observed by registry and health manager. lifecycle_state: WorkerLifecycleState = WorkerLifecycleState.ACTIVE def is_active(self) -> bool: @@ -48,55 +52,12 @@ def is_active(self) -> bool: @dataclass(frozen=True) class WorkerGroup: + # Worker ranks that share one lifecycle action. ranks: tuple[int, ...] + # Runtime snapshots for registered workers in this lifecycle group. workers: tuple[WorkerSnapshot, ...] -@dataclass(frozen=True) -class RolloutWorkerEndpointMetadata: - """URL and lifecycle state for one request-serving rollout endpoint.""" - - rank: int - server_url: str - session_url: str | None - lifecycle_state: WorkerLifecycleState - - @property - def is_active(self) -> bool: - return self.lifecycle_state is WorkerLifecycleState.ACTIVE - - -@dataclass(frozen=True) -class RolloutWorkerMetadata: - """Structured rollout worker metadata consumed by trainer/update-weight - code.""" - - rollout_config: RolloutConfig - training_engine_mesh: tuple[tuple[int, ...], ...] - request_endpoints: tuple[RolloutWorkerEndpointMetadata, ...] - - def to_legacy(self) -> dict[str, Any]: - """Serialize to the current trainer-facing rollout metadata dict.""" - return { - "engine_rank_mesh_array": [list(engine_ranks) for engine_ranks in self.training_engine_mesh], - "server_url_dict": {endpoint.rank: endpoint.server_url for endpoint in self.request_endpoints}, - "rollout_config": self.rollout_config, - "worker_server_urls_status": { - endpoint.server_url: endpoint.is_active for endpoint in self.request_endpoints - }, - "worker_session_url_dict": { - endpoint.rank: endpoint.session_url - for endpoint in self.request_endpoints - if endpoint.session_url is not None - }, - "worker_session_urls_status": { - endpoint.session_url: endpoint.is_active - for endpoint in self.request_endpoints - if endpoint.session_url is not None - }, - } - - class RolloutWorkerRegistry: """Own runtime rollout worker state and expose consistent query snapshots.""" @@ -105,11 +66,9 @@ def __init__( self, *, rollout_topology: RolloutTopology, - rollout_config: RolloutConfig, ): """Initialize an empty registry with the rollout topology.""" self._rollout_topology = rollout_topology - self._rollout_config = rollout_config self._workers: dict[int, WorkerSnapshot] = {} self._lock = threading.RLock() @@ -244,25 +203,6 @@ def set_group_recovery_result( worker_groups = self._build_worker_groups() return worker_groups.get(group.ranks) - def metadata(self) -> RolloutWorkerMetadata: - """Build trainer/update-weight metadata from one registry snapshot.""" - with self._lock: - request_endpoints = tuple( - RolloutWorkerEndpointMetadata( - rank=worker.rank, - server_url=worker.url, - session_url=worker.session_url, - lifecycle_state=worker.lifecycle_state, - ) - for worker in self.all_workers() - if worker.is_request_entrypoint - ) - return RolloutWorkerMetadata( - rollout_config=self._rollout_config, - training_engine_mesh=self._rollout_topology.training_engine_mesh, - request_endpoints=request_endpoints, - ) - def weight_update_targets(self) -> tuple[RolloutWeightUpdateTarget, ...]: """Return weight-update targets resolved with current runtime state.""" from xtuner.v1.rl.weight_update.data import RolloutWeightUpdateTarget diff --git a/xtuner/v1/rl/trainer/controller.py b/xtuner/v1/rl/trainer/controller.py index dcb8601fe..8872d4af1 100644 --- a/xtuner/v1/rl/trainer/controller.py +++ b/xtuner/v1/rl/trainer/controller.py @@ -290,19 +290,6 @@ def onload(self, target: Literal["model", "optimizer", "all"] = "all"): ray.get([worker.onload_optimizer.remote() for worker in self.workers], timeout=TRAIN_RAY_GET_TIMEOUT) # type: ignore return - def update_rollout_info(self, info_dict, train_rollout_mode, weight_update_host=None, weight_update_port=None): - ray.get( - [ - worker.update_rollout_info.remote( - **info_dict, - train_rollout_mode=train_rollout_mode, - weight_update_host=weight_update_host, - weight_update_port=weight_update_port, - ) - for worker in self.workers - ] - ) - def bind_rollout_weight_update( self, *, diff --git a/xtuner/v1/rl/trainer/worker.py b/xtuner/v1/rl/trainer/worker.py index c2b1909cf..76e9bd372 100644 --- a/xtuner/v1/rl/trainer/worker.py +++ b/xtuner/v1/rl/trainer/worker.py @@ -7,11 +7,9 @@ from pathlib import Path from typing import ( TYPE_CHECKING, - Dict, Iterable, List, Sequence, - TypeAlias, TypedDict, cast, ) @@ -63,8 +61,6 @@ from ..rollout_is import merge_rollout_is_metrics -DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices -ServiceUrlMap: TypeAlias = Dict[int, str] # A dictionary mapping service names to their URLs DEVICE = get_device() DEVICE_MODULE = get_torch_device_module() @@ -276,10 +272,6 @@ def __init__( engine=self._engine, ) - @ray_method - def update_rollout_info(self, *args, **kwargs): - return self.update_weighter.update_rollout_info(*args, **kwargs) - @ray_method def bind_rollout_weight_update(self, *args, **kwargs): return self.update_weighter.bind_rollout_weight_update(*args, **kwargs) diff --git a/xtuner/v1/rl/weight_update/__init__.py b/xtuner/v1/rl/weight_update/__init__.py index c5e85c3c3..2b36f5e95 100644 --- a/xtuner/v1/rl/weight_update/__init__.py +++ b/xtuner/v1/rl/weight_update/__init__.py @@ -1,10 +1,7 @@ from .data import ( - DeviceMeshRaw, RolloutBackend, - RolloutEngineInfo, RolloutWeightUpdateInfo, RolloutWeightUpdateTarget, - ServiceUrlMap, TrainRolloutMode, WeightTransportType, WeightUpdateBatch, @@ -25,19 +22,16 @@ __all__ = [ - "DeviceMeshRaw", "IPCBackendAdapter", "IPCWeightTransport", "LMDeployIPCBackendAdapter", "NCCLBackendAdapter", "NCCLWeightTransport", "RolloutBackend", - "RolloutEngineInfo", - "RolloutWeightUpdateInfo", "RolloutWeightUpdateTarget", + "RolloutWeightUpdateInfo", "SGLangIPCBackendAdapter", "SGLangNCCLBackendAdapter", - "ServiceUrlMap", "TrainRolloutMode", "UpdateWeighter", "WeightIterator", diff --git a/xtuner/v1/rl/weight_update/data.py b/xtuner/v1/rl/weight_update/data.py index 65b05430c..18b18b712 100644 --- a/xtuner/v1/rl/weight_update/data.py +++ b/xtuner/v1/rl/weight_update/data.py @@ -1,20 +1,59 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import Dict, List, Literal, TypeAlias +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast import torch -from torch.distributed.device_mesh import DeviceMesh -DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices. -ServiceUrlMap: TypeAlias = Dict[int, str] # A dictionary mapping rollout ranks to their server URLs. -RolloutEngineInfo: TypeAlias = list[tuple[int, str, int]] # (rollout rank, server url, engine gpu count) +if TYPE_CHECKING: + from xtuner.v1.rl.rollout.worker import RolloutConfig + + TrainRolloutMode: TypeAlias = Literal["colocate", "disaggregated"] # Train and rollout deployment mode. RolloutBackend: TypeAlias = Literal["sglang", "vllm", "pytorch", "turbomind"] # Rollout inference backend. WeightTransportType: TypeAlias = Literal["ipc", "nccl"] # Supported weight transport types. +def _resolve_rollout_backend(rollout_config: RolloutConfig) -> RolloutBackend: + # Backend selection follows rollout launcher precedence. + if os.environ.get("XTUNER_USE_SGLANG", "0") == "1": + backend = "sglang" + elif os.environ.get("XTUNER_USE_VLLM", "0") == "1": + backend = "vllm" + else: + backend = (rollout_config.extra_rollout_config or dict()).get("lmdeploy_backend", "pytorch") + + backend = backend.lower() + if backend not in ("sglang", "vllm", "pytorch", "turbomind"): + raise ValueError( + f"Unsupported rollout backend: {backend!r}. Expected 'sglang', 'vllm', 'pytorch' or 'turbomind'." + ) + return cast(RolloutBackend, backend) + + +def _resolve_transport_type( + *, + train_rollout_mode: TrainRolloutMode | str, + backend: RolloutBackend, +) -> tuple[TrainRolloutMode, WeightTransportType]: + assert train_rollout_mode is not None, "bind_rollout_weight_update() must set train_rollout_mode." + + mode = train_rollout_mode.lower() + if mode not in ("colocate", "disaggregated"): + raise ValueError( + f"Unsupported train_rollout_mode: {train_rollout_mode!r}. Expected 'colocate' or 'disaggregated'." + ) + mode = cast(TrainRolloutMode, mode) + if mode == "colocate": + return mode, "ipc" + + if backend == "vllm" or backend == "turbomind": + raise NotImplementedError(f"Disaggregated train-rollout mode is not supported for {backend} backend.") + return mode, "nccl" + + @dataclass(frozen=True) class RolloutWeightUpdateTarget: """Runtime weight-update endpoint resolved from rollout registry state.""" @@ -37,34 +76,147 @@ def engine_size(self) -> int: return len(self.update_ranks) -@dataclass +@dataclass(frozen=True) class RolloutWeightUpdateInfo: - # Common rollout metadata. - api_key: list[str] | str | None = None - rollout_url: str | None = None - backend: RolloutBackend | None = None - tp: int = 1 - ep: int = 1 - train_rollout_mode: TrainRolloutMode | None = None - transport_type: WeightTransportType | None = None - rollout_cfg_info: dict = field(default_factory=dict) - endpoints: dict[str, str] = field(default_factory=lambda: {"update_weights": "update_weights"}) - - # Colocated rollout metadata. - rollout_device_mesh: DeviceMesh | None = None - rollout_engine_rank_mesh_array: DeviceMeshRaw = field(default_factory=list) - - # Disaggregated rollout metadata. - rollout_server_url_dict: ServiceUrlMap = field(default_factory=dict) - worker_server_urls_status: dict[str, bool] = field(default_factory=dict) + # Rollout config owns api_key, backend choice, TP/EP, and default update host/port. + rollout_config: RolloutConfig + # Registry-resolved rollout update targets visible to every train worker. + weight_update_targets: tuple[RolloutWeightUpdateTarget, ...] + # Current train worker rank; used to derive the local weight update target. + train_rank: int + # Deployment mode that decides which weight transport family is used. + train_rollout_mode: TrainRolloutMode + # Concrete transport selected from train_rollout_mode and rollout config. + transport_type: WeightTransportType + # Resolved rollout backend used by transports and iterators. + backend: RolloutBackend + # Optional host used by NCCL external weight update groups. weight_update_host: str | None = None + # Optional port used by NCCL external weight update groups. weight_update_port: int | None = None + @classmethod + def from_targets( + cls, + *, + rollout_config: RolloutConfig, + weight_update_targets: tuple[RolloutWeightUpdateTarget, ...], + train_rank: int, + train_rollout_mode: TrainRolloutMode | str, + weight_update_host: str | None = None, + weight_update_port: int | None = None, + ) -> RolloutWeightUpdateInfo: + backend = _resolve_rollout_backend(rollout_config) + tp = rollout_config.tensor_parallel_size + ep = rollout_config.expert_parallel_size + assert tp == 1 or ep == 1, "Either tensor parallel size or engine parallel size must be 1." + mode, transport_type = _resolve_transport_type( + train_rollout_mode=train_rollout_mode, + backend=backend, + ) + return cls( + rollout_config=rollout_config, + weight_update_targets=weight_update_targets, + train_rank=train_rank, + train_rollout_mode=mode, + transport_type=transport_type, + backend=backend, + weight_update_host=weight_update_host, + weight_update_port=weight_update_port if weight_update_port is not None else 30000, + ) + + @property + def local_update_target(self) -> RolloutWeightUpdateTarget | None: + return next( + (target for target in self.weight_update_targets if self.train_rank == target.endpoint_rank), + None, + ) + + @property + def rollout_url(self) -> str | None: + target = self.local_update_target + if target is None or not target.is_active: + return None + return target.server_url + + @property + def ipc_rank_mesh(self) -> tuple[tuple[int, ...], ...]: + return tuple(target.update_ranks for target in self.weight_update_targets) + + @property + def _ipc_update_target(self) -> RolloutWeightUpdateTarget | None: + return next( + (target for target in self.weight_update_targets if self.train_rank in target.update_ranks), + None, + ) + + @property + def ipc_engine_parallel_rank(self) -> int | None: + target = self._ipc_update_target + if target is None: + return None + return target.update_ranks.index(self.train_rank) + + @property + def ipc_engine_parallel_size(self) -> int | None: + target = self._ipc_update_target + if target is None: + return None + return target.engine_size + + @property + def active_update_targets(self) -> tuple[RolloutWeightUpdateTarget, ...]: + return tuple(target for target in self.weight_update_targets if target.is_active) + + @property + def nccl_engine_infos(self) -> tuple[tuple[int, str, int], ...]: + return tuple( + (target.endpoint_rank, target.server_url, target.engine_size) for target in self.active_update_targets + ) + + @property + def transport_signature(self) -> tuple[Any, ...]: + target_signature = tuple( + ( + target.endpoint_rank, + tuple(int(rank) for rank in target.update_ranks), + target.server_url, + target.lifecycle_state, + ) + for target in self.weight_update_targets + ) + frozen_api_key = tuple(self.api_key) if isinstance(self.api_key, list) else self.api_key + return ( + self.train_rollout_mode, + self.backend, + self.tp, + self.ep, + frozen_api_key, + self.weight_update_host, + self.weight_update_port, + target_signature, + ) + + @property + def api_key(self) -> list[str] | str | None: + return self.rollout_config.api_key + + @property + def tp(self) -> int: + return self.rollout_config.tensor_parallel_size + + @property + def ep(self) -> int: + return self.rollout_config.expert_parallel_size + @dataclass class WeightUpdateBatch: """A single bucket of weights to send to rollout workers.""" + # HF-style named tensors or backend-specific tensors for one update bucket. state_dict: dict[str, torch.Tensor] + # Whether the train model uses EP and may need rollout EP slicing. train_enable_ep: bool = False + # Whether this is the final bucket in the current update stream. finished: bool = False diff --git a/xtuner/v1/rl/weight_update/transport.py b/xtuner/v1/rl/weight_update/transport.py index 2ae3639be..765416892 100644 --- a/xtuner/v1/rl/weight_update/transport.py +++ b/xtuner/v1/rl/weight_update/transport.py @@ -14,6 +14,7 @@ import torch import torch.distributed as dist from packaging.version import parse as parse_version +from torch.distributed.device_mesh import DeviceMesh from torch.distributed.distributed_c10d import ( Backend, PrefixStore, @@ -39,7 +40,9 @@ @dataclass class WeightUpdateRequest: + # HTTP endpoint on the rollout server that should receive this update. endpoint: str + # JSON body sent to the rollout backend adapter endpoint. body: dict[str, Any] @@ -60,7 +63,7 @@ def __init__(self, *, rollout_info: RolloutWeightUpdateInfo, logger: Any, rank: self._adapter: WeightTransportAdapter | None = None self.rollout_url = self.rollout_info.rollout_url - if self.rollout_url is None: + if self.rollout_url is None and self.rollout_info.transport_type == "ipc": self.logger.error(f"rank {self.rank} url in None, cannot update weights and skip") @staticmethod @@ -431,9 +434,12 @@ def __init__( self.config = config self._adapter = self._build_adapter() - assert self.rollout_info.rollout_device_mesh is not None - self.rollout_device_mesh = self.rollout_info.rollout_device_mesh - self.cpu_mesh = self.rollout_info.rollout_device_mesh["engine_parallel"] + self.ipc_update_device_mesh = DeviceMesh( + "cpu", + mesh=[list(ranks) for ranks in self.rollout_info.ipc_rank_mesh], + mesh_dim_names=("engine_instance", "engine_parallel"), + ) + self.cpu_mesh = self.ipc_update_device_mesh["engine_parallel"] self.cpu_group = self.cpu_mesh.get_group() self.head_rank = int(self.cpu_mesh.mesh[0].item()) @@ -652,36 +658,8 @@ def ensure_nccl_weight_update_group(self): if self.group is not None: return - # Map rollout rank to its engine size. - rank_to_engine_size = { - int(rank): len(engine_ranks) - for engine_ranks in self.rollout_info.rollout_engine_rank_mesh_array - for rank in engine_ranks - } - - # Deduplicate rollout engine URLs while keeping the first rank associated - # with each URL as the representative rank for that engine. - url_to_rank: dict[str, int] = {} - for rank, url in sorted( - self.rollout_info.rollout_server_url_dict.items(), - key=lambda item: int(item[0]), - ): - if url: - url_to_rank.setdefault(url, int(rank)) - - # Collect the representative rank, URL, and engine size needed to create - # the NCCL weight update process group. - engine_info = [ - ( - rank, - url, - rank_to_engine_size.get( - rank, - max(self.rollout_info.tp, self.rollout_info.ep), - ), - ) - for url, rank in url_to_rank.items() - ] + # RolloutWeightUpdateInfo owns the runtime target projection. + engine_info = self.rollout_info.nccl_engine_infos if not engine_info: self.logger.error("No active rollout engine url, cannot init sglang weight update group") diff --git a/xtuner/v1/rl/weight_update/update_weighter.py b/xtuner/v1/rl/weight_update/update_weighter.py index 1732a2b15..82a9e75ec 100644 --- a/xtuner/v1/rl/weight_update/update_weighter.py +++ b/xtuner/v1/rl/weight_update/update_weighter.py @@ -1,18 +1,12 @@ from __future__ import annotations -import os -from typing import Any, cast - -from torch.distributed.device_mesh import DeviceMesh +from typing import Any from xtuner.v1.rl.rollout.worker import RolloutConfig from .data import ( - DeviceMeshRaw, - RolloutBackend, RolloutWeightUpdateInfo, RolloutWeightUpdateTarget, - ServiceUrlMap, TrainRolloutMode, ) from .transport import IPCWeightTransport, NCCLWeightTransport, WeightTransport @@ -25,86 +19,37 @@ def __init__(self, *, rank: int, logger: Any, config: Any, engine: Any): self.logger = logger self.config = config self._engine = engine - # Used to update weight to rollout engine. - self.rollout_info = RolloutWeightUpdateInfo() + # Bound rollout weight-update metadata, available after bind_rollout_weight_update(). + self.rollout_info: RolloutWeightUpdateInfo | None = None + # Lazily constructed iterator bound to the current rollout_info. + self.weight_iterator: WeightIterator | None = None self._global_hf_keys_mapping_cache: dict[str, list[str]] = {} - # Transport is initialized after update_rollout_info() is called. + # Transport is initialized after bind_rollout_weight_update() is called. self._transport: WeightTransport | None = None # Used to detect changes in rollout metadata that require resetting the transport. self._transport_signature: tuple[Any, ...] | None = None - @staticmethod - def _normalize_rollout_backend(rollout_config: RolloutConfig) -> RolloutBackend: - # Backend selection follows rollout launcher precedence: explicit SGLang/vLLM env vars win, - # otherwise the LMDeploy backend decides between pytorch and turbomind. - if os.environ.get("XTUNER_USE_SGLANG", "0") == "1": - backend = "sglang" - elif os.environ.get("XTUNER_USE_VLLM", "0") == "1": - backend = "vllm" - else: - backend = (rollout_config.extra_rollout_config or dict()).get("lmdeploy_backend", "pytorch") - - backend = backend.lower() - if backend not in ("sglang", "vllm", "pytorch", "turbomind"): - raise ValueError( - f"Unsupported rollout backend: {backend!r}. Expected 'sglang', 'vllm', 'pytorch' or 'turbomind'." - ) - return cast(RolloutBackend, backend) - - def update_rollout_info( + def bind_rollout_weight_update( self, - engine_rank_mesh_array: DeviceMeshRaw, - server_url_dict: ServiceUrlMap, + *, + targets: tuple[RolloutWeightUpdateTarget, ...], rollout_config: RolloutConfig, - worker_server_urls_status: dict[str, bool], train_rollout_mode: TrainRolloutMode, weight_update_host: str | None = None, weight_update_port: int | None = None, - worker_session_url_dict: ServiceUrlMap | None = None, - worker_session_urls_status: dict[str, bool] | None = None, ): - """Update the rollout information for the training worker.""" - - self.rollout_info.backend = self._normalize_rollout_backend(rollout_config) - self.set_train_rollout_mode(train_rollout_mode=train_rollout_mode) - - # Common rollout metadata. - tp = rollout_config.tensor_parallel_size - ep = rollout_config.expert_parallel_size - assert tp == 1 or ep == 1, "Either tensor parallel size or engine parallel size must be 1." - self.rollout_info.tp = tp - self.rollout_info.ep = ep - self.rollout_info.api_key = rollout_config.api_key - rollout_server_url = server_url_dict.get(self.rank, "") - if not worker_server_urls_status.get(rollout_server_url, False): - self.logger.error(f"Rollout server url {rollout_server_url} is not available.") - self.rollout_info.rollout_url = None - else: - self.rollout_info.rollout_url = rollout_server_url - - if self.rollout_info.transport_type == "ipc": - # Colocated rollout metadata. - # rollout_device_mesh is created after train_rollout_mode is set. - self.rollout_info.rollout_engine_rank_mesh_array = [ - [int(rank) for rank in ranks] for ranks in engine_rank_mesh_array - ] - self._ensure_rollout_device_mesh() - elif self.rollout_info.transport_type == "nccl": - # Disaggregated rollout metadata. - self.rollout_info.rollout_server_url_dict = {int(rank): url for rank, url in server_url_dict.items()} - self.rollout_info.worker_server_urls_status = worker_server_urls_status - self.rollout_info.weight_update_host = weight_update_host - self.rollout_info.weight_update_port = weight_update_port if weight_update_port is not None else 30000 + """Bind this train worker to rollout weight-update targets.""" - new_transport_signature = self._build_transport_signature( - engine_rank_mesh_array=engine_rank_mesh_array, - server_url_dict=server_url_dict, - worker_server_urls_status=worker_server_urls_status, + self.rollout_info = RolloutWeightUpdateInfo.from_targets( + rollout_config=rollout_config, + weight_update_targets=targets, + train_rank=self.rank, train_rollout_mode=train_rollout_mode, - backend=self.rollout_info.backend, - tp=tp, - ep=ep, + weight_update_host=weight_update_host, + weight_update_port=weight_update_port, ) + + new_transport_signature = self.rollout_info.transport_signature # Weight transports may cache resources derived from rollout metadata. # Since rollout workers can fail and recover with new URL/status/mesh metadata, # reset the cached transport whenever that metadata changes. @@ -122,112 +67,32 @@ def update_rollout_info( if self._transport is None: self._set_transport() - def bind_rollout_weight_update( - self, - *, - targets: tuple[RolloutWeightUpdateTarget, ...], - rollout_config: RolloutConfig, - train_rollout_mode: TrainRolloutMode, - weight_update_host: str | None = None, - weight_update_port: int | None = None, - ): - """Bind rollout weight-update targets through the legacy metadata - path.""" - - self.update_rollout_info( - engine_rank_mesh_array=[list(target.update_ranks) for target in targets], - server_url_dict={target.endpoint_rank: target.server_url for target in targets}, - rollout_config=rollout_config, - worker_server_urls_status={target.server_url: target.is_active for target in targets}, - train_rollout_mode=train_rollout_mode, - weight_update_host=weight_update_host, - weight_update_port=weight_update_port, - ) - - def _ensure_rollout_device_mesh(self): - if self.rollout_info.rollout_device_mesh is None: - # 非共卡 SGLang 不使用这个 mesh;只有共卡/旧权重同步路径需要 - # 用 rollout rank 构造 torch DeviceMesh。 - self.rollout_info.rollout_device_mesh = DeviceMesh( - "cpu", - mesh=self.rollout_info.rollout_engine_rank_mesh_array, - mesh_dim_names=("engine_instance", "engine_parallel"), - ) - - def set_train_rollout_mode(self, train_rollout_mode: TrainRolloutMode | str): - assert train_rollout_mode is not None, "update_rollout_info() must set train_rollout_mode." - - if self.rollout_info.backend is None: - raise RuntimeError("rollout backend is not set. Please set rollout backend in update_rollout_info().") - - mode = train_rollout_mode.lower() - if mode not in ("colocate", "disaggregated"): - raise ValueError( - f"Unsupported train_rollout_mode: {train_rollout_mode!r}. Expected 'colocate' or 'disaggregated'." - ) - mode = cast(TrainRolloutMode, mode) - self.rollout_info.train_rollout_mode = mode - if mode == "colocate": - self.rollout_info.transport_type = "ipc" - elif mode == "disaggregated": - self.rollout_info.transport_type = "nccl" - - backend = self.rollout_info.backend - if backend == "vllm" or backend == "turbomind": - raise NotImplementedError(f"Disaggregated train-rollout mode is not supported for {backend} backend.") - def update_weights(self): """Update the model weights.""" + assert self.rollout_info is not None, "bind_rollout_weight_update() must be called before update_weights()." assert self._transport is not None, ( f"Weight transport is not initialized. transport_type={self.rollout_info.transport_type!r}, " f"backend={self.rollout_info.backend!r}." ) + assert self.weight_iterator is not None, "Weight iterator is not initialized." self._transport.update(self.weight_iterator) def _set_transport(self) -> None: - if self.rollout_info.transport_type == "ipc": + rollout_info = self.rollout_info + assert rollout_info is not None, "bind_rollout_weight_update() must be called before setting transport." + if rollout_info.transport_type == "ipc": self._transport = IPCWeightTransport( rank=self.rank, logger=self.logger, config=self.config, - rollout_info=self.rollout_info, + rollout_info=rollout_info, ) - elif self.rollout_info.transport_type == "nccl": - self._transport = NCCLWeightTransport(rank=self.rank, logger=self.logger, rollout_info=self.rollout_info) + elif rollout_info.transport_type == "nccl": + self._transport = NCCLWeightTransport(rank=self.rank, logger=self.logger, rollout_info=rollout_info) else: raise NotImplementedError - def _build_transport_signature( - self, - *, - engine_rank_mesh_array: DeviceMeshRaw, - server_url_dict: ServiceUrlMap, - worker_server_urls_status: dict[str, bool], - train_rollout_mode: TrainRolloutMode, - backend: RolloutBackend, - tp: int, - ep: int, - ) -> tuple[Any, ...]: - mesh = tuple(tuple(int(rank) for rank in ranks) for ranks in engine_rank_mesh_array) - - active_urls = tuple( - sorted( - (int(rank), url) - for rank, url in server_url_dict.items() - if url and worker_server_urls_status.get(url, False) - ) - ) - - return ( - train_rollout_mode, - backend, - tp, - ep, - mesh, - active_urls, - ) - def _reset_transport(self) -> None: if self._transport is not None: self._transport.teardown() diff --git a/xtuner/v1/rl/weight_update/weight_iterator.py b/xtuner/v1/rl/weight_update/weight_iterator.py index 9e8c5b783..099b510f1 100644 --- a/xtuner/v1/rl/weight_update/weight_iterator.py +++ b/xtuner/v1/rl/weight_update/weight_iterator.py @@ -194,14 +194,16 @@ def iter_hf_batches(self, submodule=None, final_update=False): if train_enable_ep: if self.rollout_info.train_rollout_mode == "colocate" and self.rollout_info.ep > 1: - rollout_device_mesh = self.rollout_info.rollout_device_mesh - assert rollout_device_mesh is not None + target_ep_rank = self.rollout_info.ipc_engine_parallel_rank + target_ep_size = self.rollout_info.ipc_engine_parallel_size + assert target_ep_rank is not None, "IPC rollout target for current train rank is not resolved." + assert target_ep_size is not None, "IPC rollout target size for current train rank is not resolved." # Colocated IPC can send only the expert slice needed by the local rollout # EP rank fused_gen = self._rl_get_fused_ep_hf_param( model, - target_ep_rank=rollout_device_mesh["engine_parallel"].get_coordinate()[0], - target_ep_size=rollout_device_mesh["engine_parallel"].size(), + target_ep_rank=target_ep_rank, + target_ep_size=target_ep_size, bucket_size=bucket_size, should_gather_train_ep_shards=should_gather_train_ep_shards, ) @@ -251,7 +253,6 @@ def iter_hf_batches(self, submodule=None, final_update=False): @torch.no_grad() def iter_layer_batches(self): """Update the model weights.""" - assert self.rollout_info.rollout_device_mesh is not None model = self._engine.model DEVICE_MODULE.empty_cache() From 3cfb76294e219258a392ab75055ac67cfc4e5524 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Mon, 29 Jun 2026 10:41:35 +0000 Subject: [PATCH 5/5] Update rollout topology and weight update tests --- .../rl/test_multi_task_agent_loop_manager.py | 4 +- tests/rl/test_producer.py | 2 +- tests/rl/test_rl_colocate_trainer.py | 2 +- tests/rl/test_rl_trainer_checkpoint.py | 23 +- tests/rl/test_rollout_logic.py | 384 ++++++++++++++++-- tests/rl/test_update_weight_disaggregated.py | 32 +- 6 files changed, 385 insertions(+), 62 deletions(-) diff --git a/tests/rl/test_multi_task_agent_loop_manager.py b/tests/rl/test_multi_task_agent_loop_manager.py index 124e923e6..8db40aa22 100644 --- a/tests/rl/test_multi_task_agent_loop_manager.py +++ b/tests/rl/test_multi_task_agent_loop_manager.py @@ -167,7 +167,7 @@ def _fake_agent_loop(): rollout_ctl = MagicMock() rollout_ctl.continue_generation.remote = AsyncMock() rollout_ctl.pause_generation.remote = AsyncMock() - rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}}) + rollout_ctl.get_weight_update_targets.remote = AsyncMock(return_value=()) agent_loop = MagicMock() agent_loop.rollout_ctl = rollout_ctl return agent_loop @@ -177,7 +177,7 @@ def _fake_rollout_controller(): rollout_controller = MagicMock() rollout_controller.continue_generation.remote = AsyncMock() rollout_controller.pause_generation.remote = AsyncMock() - rollout_controller.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}}) + rollout_controller.get_weight_update_targets.remote = AsyncMock(return_value=()) return rollout_controller diff --git a/tests/rl/test_producer.py b/tests/rl/test_producer.py index 879a3ca79..3514542f8 100644 --- a/tests/rl/test_producer.py +++ b/tests/rl/test_producer.py @@ -100,7 +100,7 @@ def _build_agent_loop(self, sleep_by_id: dict[int, float] | None = None): mock_agent_loop = MagicMock() mock_agent_loop.rollout_ctl.continue_generation.remote = AsyncMock(return_value=None) mock_agent_loop.rollout_ctl.pause_generation.remote = AsyncMock(return_value=None) - mock_agent_loop.rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}}) + mock_agent_loop.rollout_ctl.get_weight_update_targets.remote = AsyncMock(return_value=()) async def mock_pause(): await mock_agent_loop.rollout_ctl.pause_generation.remote() diff --git a/tests/rl/test_rl_colocate_trainer.py b/tests/rl/test_rl_colocate_trainer.py index 389879db2..69f902d6d 100644 --- a/tests/rl/test_rl_colocate_trainer.py +++ b/tests/rl/test_rl_colocate_trainer.py @@ -75,7 +75,7 @@ def _build_fake_rollout_controller(): rollout_ctl = MagicMock() rollout_ctl.continue_generation.remote = AsyncMock(return_value=None) rollout_ctl.pause_generation.remote = AsyncMock(return_value=None) - rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}}) + rollout_ctl.get_weight_update_targets.remote = AsyncMock(return_value=()) return rollout_ctl diff --git a/tests/rl/test_rl_trainer_checkpoint.py b/tests/rl/test_rl_trainer_checkpoint.py index b3f2db0bc..9d14972b2 100644 --- a/tests/rl/test_rl_trainer_checkpoint.py +++ b/tests/rl/test_rl_trainer_checkpoint.py @@ -104,7 +104,7 @@ def __init__(self): self.restart_inactive_workers = _RemoteMethod(return_value="rollout_restarted") self.onload_weights = _RemoteMethod(return_value="weights_loaded") self.onload_kvcache = _RemoteMethod(return_value="kvcache_loaded") - self.get_rollout_metadata = _RemoteMethod(return_value={"server_url_dict": {}}) + self.get_weight_update_targets = _RemoteMethod(return_value=()) self.set_enable_partial_rollout = _RemoteMethod(return_value=None) self.validate_registered_workers_to_proxy = _RemoteMethod(return_value=_AwaitableValue(None)) @@ -127,14 +127,19 @@ def __init__(self): self.update_weights_count = 0 self.rollout_info = None - def update_rollout_info( - self, - info, - train_rollout_mode, - weight_update_host, - weight_update_port - ): - self.rollout_info = info + def bind_rollout_weight_update( + self, + *, + targets, + rollout_config, + train_rollout_mode, + weight_update_host=None, + weight_update_port=None, + ): + self.rollout_info = { + "targets": targets, + "rollout_config": rollout_config, + } self.train_rollout_mode = train_rollout_mode self.weight_update_host = weight_update_host self.weight_update_port = weight_update_port diff --git a/tests/rl/test_rollout_logic.py b/tests/rl/test_rollout_logic.py index cb3f73a47..b9740784b 100644 --- a/tests/rl/test_rollout_logic.py +++ b/tests/rl/test_rollout_logic.py @@ -23,12 +23,19 @@ from xtuner.v1.rl.agent_loop import AgentLoopConfig from xtuner.v1.rl.rollout.controller import RolloutController from xtuner.v1.rl.rollout.health_manager import RolloutHealthManager +from xtuner.v1.rl.rollout.lmdeploy import LMDeployWorker +from xtuner.v1.rl.rollout.rollout_topology import RolloutEngine, RolloutTopology, RolloutServerProcess from xtuner.v1.rl.rollout.proxy_manager import RolloutProxyManager -from xtuner.v1.rl.rollout.worker_registry import RolloutWorkerRegistry, WorkerLifecycleState, WorkerSnapshot +from xtuner.v1.rl.rollout.worker_registry import ( + RolloutWorkerRegistry, + WorkerLifecycleState, + WorkerSnapshot, +) from xtuner.v1.rl.rollout.sglang import SGLangWorker from xtuner.v1.rl.rollout.utils import PartialRolloutHandler, SessionRouter from xtuner.v1.rl.rollout.worker import RolloutWorker from xtuner.v1.rl.utils.misc import delete_from_routedapiproxy +from xtuner.v1.rl.weight_update.data import RolloutWeightUpdateInfo from xtuner.v1.train.rl_trainer import BaseRLTrainer, _agent_loop_manager_requires_rollout_proxy from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult @@ -122,6 +129,183 @@ def test_trainer_auto_enables_rollout_proxy_when_agent_loop_requires_it(self): self.assertTrue(trainer._rollout_config.enable_proxy) +class TestRolloutTopologyAPI(unittest.TestCase): + def _rollout_config( + self, + *, + tp: int, + ep: int, + num_gpus_per_engine: int, + gpus_per_node: int = 8, + ): + return SimpleNamespace( + api_key="test-key", + tensor_parallel_size=tp, + expert_parallel_size=ep, + num_gpus_per_engine=num_gpus_per_engine, + gpus_per_node=gpus_per_node, + extra_rollout_config={"lmdeploy_backend": "pytorch"}, + ) + + def _rank_bundle_idx_list(self, num_workers: int): + return [(rank, rank) for rank in range(num_workers)] + + def _rank_to_dist_init_addr(self, num_workers: int): + return {rank: f"host{rank}:25{rank:03d}" for rank in range(num_workers)} + + def _weight_update_targets(self, topology: RolloutTopology): + registry = RolloutWorkerRegistry(rollout_topology=topology) + for spec in topology.server_launch_specs(): + registry.register_started_server( + rank=spec.worker_rank, + actor=object(), + server_url=f"http://worker-{spec.worker_rank}", + session_url=f"http://session-{spec.worker_rank}", + ) + return registry.weight_update_targets() + + def _rollout_info(self, *, config, targets, train_rank: int): + return RolloutWeightUpdateInfo.from_targets( + rollout_config=config, + weight_update_targets=targets, + train_rank=train_rank, + train_rollout_mode="colocate", + ) + + def test_rollout_topology_resolves_engine_dist_init_addr_when_created(self): + rank_to_dist_init_addr = {0: "host0:25000", 1: "host1:25004"} + dist_init_addr_owner_rank = 0 + engine = RolloutEngine( + engine_ranks=(0, 1), + dist_init_addr=rank_to_dist_init_addr[dist_init_addr_owner_rank], + server_processes=( + RolloutServerProcess( + worker_rank=0, + placement_group_bundle_idxs=(0,), + accepts_rollout_requests=True, + weight_update_ranks=(0, 1), + ), + RolloutServerProcess( + worker_rank=1, + placement_group_bundle_idxs=(1,), + weight_update_ranks=(), + accepts_rollout_requests=False, + ), + ), + ) + + topology = RolloutTopology( + engines=(engine,), + ) + + launch_specs = topology.server_launch_specs() + self.assertEqual(tuple(spec.worker_rank for spec in launch_specs), (0, 1)) + rank_0_launch_spec, rank_1_launch_spec = launch_specs + self.assertEqual(rank_0_launch_spec.dist_init_addr, "host0:25000") + self.assertEqual(rank_1_launch_spec.dist_init_addr, "host0:25000") + self.assertEqual(rank_0_launch_spec.engine_rank, 0) + self.assertEqual(rank_1_launch_spec.engine_rank, 1) + self.assertEqual(rank_1_launch_spec.placement_group_bundle_idxs, (1,)) + self.assertTrue(topology.is_request_entrypoint_rank(0)) + self.assertFalse(topology.is_request_entrypoint_rank(1)) + self.assertEqual(topology.lifecycle_group_for_server_rank(1), (0, 1)) + self.assertEqual( + tuple( + (server.worker_rank, server.weight_update_ranks) + for server in topology.weight_update_endpoint_processes() + ), + ((0, (0, 1)),), + ) + + def test_rollout_topology_rejects_duplicate_server_process_ranks(self): + with self.assertRaisesRegex(ValueError, "Duplicate rollout server process worker_rank=0"): + RolloutTopology( + engines=( + RolloutEngine( + engine_ranks=(0,), + dist_init_addr="addr0", + server_processes=( + RolloutServerProcess( + worker_rank=0, + placement_group_bundle_idxs=(0,), + weight_update_ranks=(0,), + ), + RolloutServerProcess( + worker_rank=0, + placement_group_bundle_idxs=(1,), + weight_update_ranks=(), + accepts_rollout_requests=False, + ), + ), + ), + ), + ) + + def test_lmdeploy_tp16_weight_update_targets_match_legacy_mesh_and_url_semantics(self): + config = self._rollout_config(tp=16, ep=1, num_gpus_per_engine=16) + topology = LMDeployWorker.build_rollout_topology( + config, + self._rank_bundle_idx_list(16), + self._rank_to_dist_init_addr(16), + ) + targets = self._weight_update_targets(topology) + + self.assertEqual( + tuple((target.endpoint_rank, target.update_ranks) for target in targets), + ((0, tuple(range(16))),), + ) + self.assertEqual(self._rollout_info(config=config, targets=targets, train_rank=0).rollout_url, "http://worker-0") + self.assertIsNone(self._rollout_info(config=config, targets=targets, train_rank=1).rollout_url) + self.assertEqual( + self._rollout_info(config=config, targets=targets, train_rank=1).ipc_rank_mesh, + (tuple(range(16)),), + ) + + def test_lmdeploy_ep16_weight_update_targets_match_legacy_mesh_and_url_semantics(self): + config = self._rollout_config(tp=1, ep=16, num_gpus_per_engine=16) + topology = LMDeployWorker.build_rollout_topology( + config, + self._rank_bundle_idx_list(16), + self._rank_to_dist_init_addr(16), + ) + targets = self._weight_update_targets(topology) + + self.assertEqual( + tuple((target.endpoint_rank, target.update_ranks) for target in targets), + tuple((rank, (rank,)) for rank in range(16)), + ) + self.assertEqual(self._rollout_info(config=config, targets=targets, train_rank=0).rollout_url, "http://worker-0") + self.assertEqual( + self._rollout_info(config=config, targets=targets, train_rank=15).rollout_url, + "http://worker-15", + ) + self.assertEqual( + self._rollout_info(config=config, targets=targets, train_rank=0).ipc_rank_mesh, + tuple((rank,) for rank in range(16)), + ) + + def test_sglang_tp16_cross_node_weight_update_targets_match_legacy_mesh_and_url_semantics(self): + config = self._rollout_config(tp=16, ep=1, num_gpus_per_engine=16, gpus_per_node=8) + topology = SGLangWorker.build_rollout_topology( + config, + self._rank_bundle_idx_list(16), + self._rank_to_dist_init_addr(16), + ) + targets = self._weight_update_targets(topology) + + self.assertEqual(tuple(spec.worker_rank for spec in topology.server_launch_specs()), (0, 8)) + self.assertEqual( + tuple((target.endpoint_rank, target.update_ranks) for target in targets), + ((0, tuple(range(16))),), + ) + self.assertEqual(self._rollout_info(config=config, targets=targets, train_rank=0).rollout_url, "http://worker-0") + self.assertIsNone(self._rollout_info(config=config, targets=targets, train_rank=8).rollout_url) + self.assertEqual( + self._rollout_info(config=config, targets=targets, train_rank=8).ipc_rank_mesh, + (tuple(range(16)),), + ) + + class TestRolloutController(unittest.IsolatedAsyncioTestCase): def _state(self, uid: int, session_id: int) -> RolloutState: return RolloutState( @@ -142,6 +326,27 @@ def _build_controller(self, router): controller.logger = MagicMock() return controller + def _build_registry(self, ranks): + rollout_topology = RolloutTopology( + engines=tuple( + RolloutEngine( + engine_ranks=(rank,), + dist_init_addr=f"addr{rank}", + server_processes=( + RolloutServerProcess( + worker_rank=rank, + placement_group_bundle_idxs=(rank,), + weight_update_ranks=(rank,), + ), + ), + ) + for rank in ranks + ), + ) + return RolloutWorkerRegistry( + rollout_topology=rollout_topology, + ) + async def test_generate_fails_fast_when_no_active_worker(self): # router 找不到 active worker 时,controller 应直接把原样本标成 FAILED,避免请求悬挂。 state = self._state(uid=1, session_id=123) @@ -175,20 +380,18 @@ async def test_generate_routes_to_active_worker(self): def test_register_active_workers_to_proxy_delegates_active_session_urls(self): controller = RolloutController.__new__(RolloutController) - controller.registry = RolloutWorkerRegistry(engine_rank_mesh_array=[[0], [1]], rollout_config=SimpleNamespace()) + controller.registry = self._build_registry((0, 1)) controller.registry.register_started_server( rank=0, actor=object(), server_url="http://worker-0", session_url="http://session-0", - is_request_entrypoint=True, ) controller.registry.register_started_server( rank=1, actor=object(), server_url="http://worker-1", session_url="http://session-1", - is_request_entrypoint=True, ) controller.registry.mark_unhealthy_ranks({1}) controller.proxy_manager = MagicMock() @@ -199,7 +402,7 @@ def test_register_active_workers_to_proxy_delegates_active_session_urls(self): def test_register_active_workers_to_proxy_noops_without_proxy_manager(self): controller = RolloutController.__new__(RolloutController) - controller.registry = RolloutWorkerRegistry(engine_rank_mesh_array=[], rollout_config=SimpleNamespace()) + controller.registry = MagicMock() controller.proxy_manager = None controller.register_active_workers_to_proxy() @@ -360,34 +563,64 @@ class TestRolloutWorkerRegistry(unittest.TestCase): def _worker_by_rank(self, registry, rank): return next(worker for worker in registry.all_workers() if worker.rank == rank) - def test_registry_filters_entrypoints_and_builds_metadata_snapshot(self): - config = SimpleNamespace() - registry = RolloutWorkerRegistry(engine_rank_mesh_array=[[0, 1]], rollout_config=config) + def _runtime_layout( + self, + *, + engine_ranks=(0,), + server_processes=None, + ): + if server_processes is None: + server_processes = ( + RolloutServerProcess( + worker_rank=engine_ranks[0], + placement_group_bundle_idxs=tuple(range(len(engine_ranks))), + accepts_rollout_requests=True, + weight_update_ranks=tuple(engine_ranks), + ), + ) + dist_init_addr_owner_rank = server_processes[0].worker_rank + return RolloutTopology( + engines=( + RolloutEngine( + engine_ranks=tuple(engine_ranks), + dist_init_addr=f"addr{dist_init_addr_owner_rank}", + server_processes=tuple(server_processes), + ), + ), + ) + + def test_registry_filters_entrypoints_and_tracks_lifecycle(self): + runtime_layout = self._runtime_layout( + engine_ranks=(0, 1), + server_processes=( + RolloutServerProcess( + worker_rank=0, + placement_group_bundle_idxs=(0,), + accepts_rollout_requests=True, + weight_update_ranks=(0, 1), + ), + RolloutServerProcess( + worker_rank=1, + placement_group_bundle_idxs=(1,), + weight_update_ranks=(), + accepts_rollout_requests=False, + ), + ), + ) + registry = RolloutWorkerRegistry(rollout_topology=runtime_layout) registry.register_started_server( rank=0, actor=object(), server_url="http://worker-0", session_url="http://session-0", - lifecycle_group_ranks=(0, 1), - is_request_entrypoint=True, ) registry.register_started_server( rank=1, actor=object(), server_url="http://worker-1", session_url=None, - lifecycle_group_ranks=(0, 1), - is_request_entrypoint=False, ) - metadata = registry.training_metadata_snapshot() - - self.assertEqual(metadata["engine_rank_mesh_array"], [[0, 1]]) - self.assertIs(metadata["rollout_config"], config) - self.assertEqual(metadata["server_url_dict"], {0: "http://worker-0"}) - self.assertEqual(metadata["worker_server_urls_status"], {"http://worker-0": True}) - self.assertEqual(metadata["worker_session_url_dict"], {0: "http://session-0"}) - self.assertEqual(metadata["worker_session_urls_status"], {"http://session-0": True}) active_entrypoint = registry.active_entrypoints()[0] self.assertIsInstance(active_entrypoint, WorkerSnapshot) self.assertEqual(active_entrypoint.rank, 0) @@ -395,11 +628,8 @@ def test_registry_filters_entrypoints_and_builds_metadata_snapshot(self): active_entrypoint.lifecycle_state = WorkerLifecycleState.INACTIVE unhealthy_groups = registry.mark_unhealthy_ranks({0}) - metadata = registry.training_metadata_snapshot() self.assertEqual(unhealthy_groups[0].ranks, (0, 1)) - self.assertEqual(metadata["worker_server_urls_status"], {"http://worker-0": False}) - self.assertEqual(metadata["worker_session_urls_status"], {"http://session-0": False}) self.assertEqual(tuple(worker.rank for worker in registry.inactive_workers()), (0, 1)) self.assertEqual(registry.active_entrypoints(), ()) claimed_groups = registry.claim_inactive_groups_for_recovery() @@ -408,13 +638,72 @@ def test_registry_filters_entrypoints_and_builds_metadata_snapshot(self): registry.set_group_recovery_result(claimed_groups[0], recovered=False) self.assertEqual(self._worker_by_rank(registry, 0).lifecycle_state, WorkerLifecycleState.INACTIVE) + def test_registry_projects_weight_update_targets_from_topology_and_runtime_state(self): + runtime_layout = self._runtime_layout(engine_ranks=(0, 1)) + registry = RolloutWorkerRegistry(rollout_topology=runtime_layout) + registry.register_started_server( + rank=0, + actor=object(), + server_url="http://worker-0", + session_url="http://session-0", + ) + + targets = registry.weight_update_targets() + + self.assertEqual(len(targets), 1) + target = targets[0] + self.assertEqual(target.endpoint_rank, 0) + self.assertEqual(target.update_ranks, (0, 1)) + self.assertEqual(target.engine_size, 2) + self.assertEqual(target.server_url, "http://worker-0") + self.assertEqual(target.lifecycle_state, WorkerLifecycleState.ACTIVE.value) + self.assertTrue(target.is_active) + class TestSessionRouter(unittest.IsolatedAsyncioTestCase): async def test_sticky_session_reselects_when_previous_entrypoint_is_inactive(self): actor_0 = object() actor_1 = object() - registry = RolloutWorkerRegistry(engine_rank_mesh_array=[[0], [1]], rollout_config=SimpleNamespace()) - registry.register_started_server(rank=0, actor=actor_0, server_url="http://worker-0") - registry.register_started_server(rank=1, actor=actor_1, server_url="http://worker-1") + rollout_topology = RolloutTopology( + engines=( + RolloutEngine( + engine_ranks=(0,), + dist_init_addr="addr0", + server_processes=( + RolloutServerProcess( + worker_rank=0, + placement_group_bundle_idxs=(0,), + weight_update_ranks=(0,), + ), + ), + ), + RolloutEngine( + engine_ranks=(1,), + dist_init_addr="addr1", + server_processes=( + RolloutServerProcess( + worker_rank=1, + placement_group_bundle_idxs=(1,), + weight_update_ranks=(1,), + ), + ), + ), + ), + ) + registry = RolloutWorkerRegistry( + rollout_topology=rollout_topology, + ) + registry.register_started_server( + rank=0, + actor=actor_0, + server_url="http://worker-0", + session_url="http://session-0", + ) + registry.register_started_server( + rank=1, + actor=actor_1, + server_url="http://worker-1", + session_url="http://session-1", + ) router = SessionRouter(registry, max_idle_seconds=None) self.assertIs(await router.get_worker(7), actor_0) @@ -667,19 +956,35 @@ def _worker_by_rank(self, registry, rank): return next(worker for worker in registry.all_workers() if worker.rank == rank) def _build_registry(self, workers_info): + engines = [] + for rank in sorted(workers_info): + engines.append( + RolloutEngine( + engine_ranks=(rank,), + dist_init_addr=f"addr{rank}", + server_processes=( + RolloutServerProcess( + worker_rank=rank, + placement_group_bundle_idxs=(rank,), + accepts_rollout_requests=True, + weight_update_ranks=(rank,), + ), + ), + ) + ) + rollout_topology = RolloutTopology( + engines=tuple(engines), + ) registry = RolloutWorkerRegistry( - engine_rank_mesh_array=[sorted(workers_info)], - rollout_config=SimpleNamespace(), + rollout_topology=rollout_topology, ) for rank, worker_info in workers_info.items(): - lifecycle_group_ranks = worker_info.lifecycle_group_ranks or (rank,) registry.register_started_server( rank=rank, actor=worker_info.actor, server_url=worker_info.url, - session_url=worker_info.session_url, - lifecycle_group_ranks=lifecycle_group_ranks, - is_request_entrypoint=worker_info.is_request_entrypoint, + session_url=worker_info.session_url or f"http://session-{rank}", + lifecycle_state=worker_info.lifecycle_state, ) if worker_info.lifecycle_state is WorkerLifecycleState.INACTIVE: registry.mark_unhealthy_ranks({rank}) @@ -710,7 +1015,7 @@ def _build_manager( def test_marks_worker_inactive_after_consecutive_health_failures(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(False)) - worker_info = WorkerSnapshot(actor=actor, url="http://worker-0") + worker_info = WorkerSnapshot(rank=0, actor=actor, url="http://worker-0") workers_info = {0: worker_info} inactive_groups = [] listener = SimpleNamespace( @@ -738,7 +1043,7 @@ def test_marks_worker_inactive_after_consecutive_health_failures(self): def test_inactive_listener_runs_under_operation_lock(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(False)) - worker_info = WorkerSnapshot(actor=actor, url="http://worker-0") + worker_info = WorkerSnapshot(rank=0, actor=actor, url="http://worker-0") lock_acquired_by_listener = [] manager, _ = self._build_manager({0: worker_info}, failure_threshold=1) @@ -764,6 +1069,7 @@ def test_inactive_worker_is_not_cleaned_up_again(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) workers_info = { 0: WorkerSnapshot( + rank=0, actor=actor, url="http://worker-0", lifecycle_state=WorkerLifecycleState.INACTIVE, @@ -779,7 +1085,7 @@ def test_inactive_worker_is_not_cleaned_up_again(self): def test_health_check_threshold_zero_disables_periodic_health_check(self): # threshold <= 0 表示关闭周期健康监测,不应把 active worker 直接判 inactive。 actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(False)) - worker_info = WorkerSnapshot(actor=actor, url="http://worker-0") + worker_info = WorkerSnapshot(rank=0, actor=actor, url="http://worker-0") manager, registry = self._build_manager({0: worker_info}, failure_threshold=0) checked_count = manager._check_and_deactivate_failed_worker_groups() @@ -790,7 +1096,7 @@ def test_health_check_threshold_zero_disables_periodic_health_check(self): def test_fail_fast_health_check_still_runs_when_periodic_health_check_is_disabled(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(False)) - worker_info = WorkerSnapshot(actor=actor, url="http://worker-0") + worker_info = WorkerSnapshot(rank=0, actor=actor, url="http://worker-0") manager, registry = self._build_manager({0: worker_info}, failure_threshold=0) checked_count = manager._check_and_deactivate_failed_worker_groups(fail_fast=True) @@ -801,7 +1107,7 @@ def test_fail_fast_health_check_still_runs_when_periodic_health_check_is_disable def test_health_check_uses_configured_timeout(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) - worker_info = WorkerSnapshot(actor=actor, url="http://worker-0") + worker_info = WorkerSnapshot(rank=0, actor=actor, url="http://worker-0") manager, _ = self._build_manager({0: worker_info}, check_timeout=2.5) observed_timeouts = [] @@ -817,6 +1123,7 @@ async def fake_wait_for(awaitable, timeout): def test_shutdown_barrier_keeps_failed_shutdown_group_inactive(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) worker_info = WorkerSnapshot( + rank=0, actor=actor, url="http://worker-0", lifecycle_state=WorkerLifecycleState.INACTIVE, @@ -838,6 +1145,7 @@ def test_shutdown_barrier_keeps_failed_shutdown_group_inactive(self): def test_restart_barrier_keeps_failed_recovery_group_inactive(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) worker_info = WorkerSnapshot( + rank=0, actor=actor, url="http://worker-0", lifecycle_state=WorkerLifecycleState.INACTIVE, @@ -859,6 +1167,7 @@ def test_restart_barrier_keeps_failed_recovery_group_inactive(self): def test_restart_barrier_notifies_recovered_group_after_success(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) worker_info = WorkerSnapshot( + rank=0, actor=actor, url="http://worker-0", session_url="http://session-0", @@ -884,6 +1193,7 @@ def test_restart_barrier_notifies_recovered_group_after_success(self): def test_recovered_listener_runs_under_operation_lock(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) worker_info = WorkerSnapshot( + rank=0, actor=actor, url="http://worker-0", lifecycle_state=WorkerLifecycleState.INACTIVE, diff --git a/tests/rl/test_update_weight_disaggregated.py b/tests/rl/test_update_weight_disaggregated.py index b53c9121d..28a5223a2 100644 --- a/tests/rl/test_update_weight_disaggregated.py +++ b/tests/rl/test_update_weight_disaggregated.py @@ -118,12 +118,8 @@ def init_config(self): ) def _check_sglang_weights(self, rollout_controller, action): - info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) - active_urls = [ - url - for url, is_active in info_dict["worker_server_urls_status"].items() - if is_active - ] + targets = ray.get(rollout_controller.get_weight_update_targets.remote()) + active_urls = [target.server_url for target in targets if target.is_active] self.assertGreater(len(active_urls), 0) results = [] for url in active_urls: @@ -159,8 +155,12 @@ def test_sglang_disaggregated_update_weight_and_generate(self): input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) - info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) - train_controller.update_rollout_info(info_dict, train_rollout_mode="disaggregated") + targets = ray.get(rollout_controller.get_weight_update_targets.remote()) + train_controller.bind_rollout_weight_update( + targets=targets, + rollout_config=self.rollout_cfg, + train_rollout_mode="disaggregated", + ) train_controller.update_weights() res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) @@ -194,8 +194,12 @@ def test_sglang_disaggregated_update_weight_equal_after_reset(self): self._check_sglang_weights(rollout_controller, action="snapshot_parameters") self._check_sglang_weights(rollout_controller, action="reset_parameters") - info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) - train_controller.update_rollout_info(info_dict, train_rollout_mode="disaggregated") + targets = ray.get(rollout_controller.get_weight_update_targets.remote()) + train_controller.bind_rollout_weight_update( + targets=targets, + rollout_config=self.rollout_cfg, + train_rollout_mode="disaggregated", + ) train_controller.update_weights() self._check_sglang_weights(rollout_controller, action="compare_parameters") @@ -229,8 +233,12 @@ def test_lmdeploy_disaggregated_update_weight_and_generate(self): input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) - info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) - train_controller.update_rollout_info(info_dict, train_rollout_mode="disaggregated") + targets = ray.get(rollout_controller.get_weight_update_targets.remote()) + train_controller.bind_rollout_weight_update( + targets=targets, + rollout_config=self.rollout_cfg, + train_rollout_mode="disaggregated", + ) train_controller.update_weights() res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state))