diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py new file mode 100644 index 000000000..d4ae3e452 --- /dev/null +++ b/gigl/distributed/base_dist_loader.py @@ -0,0 +1,587 @@ +""" +Base distributed loader that consolidates shared initialization logic +from DistNeighborLoader and DistABLPLoader. + +Subclasses GLT's DistLoader and handles: +- Dataset metadata storage +- Colocated mode: DistLoader attribute setting + staggered producer init +- Graph Store mode: barrier loop + async RPC dispatch + channel creation +""" + +import sys +import time +from collections import Counter +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import torch +from graphlearn_torch.channel import RemoteReceivingChannel, ShmChannel +from graphlearn_torch.distributed import ( + DistLoader, + MpDistSamplingWorkerOptions, + RemoteDistSamplingWorkerOptions, + get_context, +) +from graphlearn_torch.distributed.dist_client import async_request_server +from graphlearn_torch.distributed.dist_sampling_producer import DistMpSamplingProducer +from graphlearn_torch.distributed.rpc import rpc_is_initialized +from graphlearn_torch.sampler import ( + NodeSamplerInput, + RemoteSamplerInput, + SamplingConfig, + SamplingType, +) +from torch_geometric.typing import EdgeType +from typing_extensions import Self + +import gigl.distributed.utils +from gigl.common.logger import Logger +from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT +from gigl.distributed.dist_context import DistributedContext +from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.graph_store.dist_server import DistServer +from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.utils.neighborloader import ( + DatasetSchema, + patch_fanout_for_sampling, +) +from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE + +logger = Logger() + + +# We don't see logs for graph store mode for 
whatever reason. +# TOOD(#442): Revert this once the GCP issues are resolved. +def _flush() -> None: + sys.stdout.flush() + sys.stderr.flush() + + +@dataclass(frozen=True) +class DistributedRuntimeInfo: + """Plain data container for resolved distributed context information.""" + + node_world_size: int + node_rank: int + rank: int + world_size: int + local_rank: int + local_world_size: int + master_ip_address: str + should_cleanup_distributed_context: bool + + +class BaseDistLoader(DistLoader): + """Base class for GiGL distributed loaders. + + Consolidates shared initialization logic from DistNeighborLoader and DistABLPLoader. + Subclasses GLT's DistLoader but does NOT call its ``__init__`` — instead, it + replicates the relevant attribute-setting logic to allow configurable producer classes. + + Subclasses should: + 1. Call ``resolve_runtime()`` to get runtime context. + 2. Determine mode (colocated vs graph store). + 3. Call ``create_sampling_config()`` to build the SamplingConfig. + 4. For colocated: call ``create_colocated_channel()`` and construct the + ``DistMpSamplingProducer`` (or subclass), then pass the producer as ``sampler``. + 5. For graph store: pass the RPC function (e.g. ``DistServer.create_sampling_producer``) + as ``sampler``. + 6. Call ``super().__init__()`` with the prepared data. + + Args: + dataset: ``DistDataset`` (colocated) or ``RemoteDistDataset`` (graph store). + sampler_input: Prepared by the subclass. Single input for colocated mode, + list (one per server) for graph store mode. + dataset_schema: Contains edge types, feature info, edge dir, etc. + worker_options: ``MpDistSamplingWorkerOptions`` (colocated) or + ``RemoteDistSamplingWorkerOptions`` (graph store). + sampling_config: Configuration for the sampler (created via ``create_sampling_config``). + device: Target device for sampled results. + runtime: Resolved distributed runtime information. 
+ sampler: Either a pre-constructed ``DistMpSamplingProducer`` (colocated mode) + or a callable to dispatch on the ``DistServer`` (graph store mode). + process_start_gap_seconds: Delay between each process for staggered colocated init. + """ + + @staticmethod + def resolve_runtime( + context: Optional[DistributedContext] = None, + local_process_rank: Optional[int] = None, + local_process_world_size: Optional[int] = None, + ) -> DistributedRuntimeInfo: + """Resolves distributed context from either a DistributedContext or torch.distributed. + + Args: + context: (Deprecated) If provided, derives rank info from the DistributedContext. + Requires local_process_rank and local_process_world_size. + local_process_rank: (Deprecated) Required when context is provided. + local_process_world_size: (Deprecated) Required when context is provided. + + Returns: + A DistributedRuntimeInfo containing all resolved rank/topology information. + """ + should_cleanup_distributed_context: bool = False + + if context: + assert ( + local_process_world_size is not None + ), "context: DistributedContext provided, so local_process_world_size must be provided." + assert ( + local_process_rank is not None + ), "context: DistributedContext provided, so local_process_rank must be provided." + + master_ip_address = context.main_worker_ip_address + node_world_size = context.global_world_size + node_rank = context.global_rank + local_world_size = local_process_world_size + local_rank = local_process_rank + + rank = node_rank * local_world_size + local_rank + world_size = node_world_size * local_world_size + + if not torch.distributed.is_initialized(): + logger.info( + "process group is not available, trying to torch.distributed.init_process_group " + "to communicate necessary setup information." 
+ ) + should_cleanup_distributed_context = True + logger.info( + f"Initializing process group with master ip address: {master_ip_address}, " + f"rank: {rank}, world size: {world_size}, " + f"local_rank: {local_rank}, local_world_size: {local_world_size}." + ) + torch.distributed.init_process_group( + backend="gloo", + init_method=f"tcp://{master_ip_address}:{DEFAULT_MASTER_INFERENCE_PORT}", + rank=rank, + world_size=world_size, + ) + else: + assert torch.distributed.is_initialized(), ( + "context: DistributedContext is None, so process group must be " + "initialized before constructing the loader." + ) + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + rank_ip_addresses = gigl.distributed.utils.get_internal_ip_from_all_ranks() + master_ip_address = rank_ip_addresses[0] + + count_ranks_per_ip_address = Counter(rank_ip_addresses) + local_world_size = count_ranks_per_ip_address[master_ip_address] + for rank_ip_address, count in count_ranks_per_ip_address.items(): + if count != local_world_size: + raise ValueError( + f"All ranks must have the same number of processes, but found " + f"{count} processes for rank {rank} on ip {rank_ip_address}, " + f"expected {local_world_size}. 
" + f"count_ranks_per_ip_address = {count_ranks_per_ip_address}" + ) + + node_world_size = len(count_ranks_per_ip_address) + local_rank = rank % local_world_size + node_rank = rank // local_world_size + + return DistributedRuntimeInfo( + node_world_size=node_world_size, + node_rank=node_rank, + rank=rank, + world_size=world_size, + local_rank=local_rank, + local_world_size=local_world_size, + master_ip_address=master_ip_address, + should_cleanup_distributed_context=should_cleanup_distributed_context, + ) + + def __init__( + self, + dataset: Union[DistDataset, RemoteDistDataset], + sampler_input: Union[NodeSamplerInput, list[NodeSamplerInput]], + dataset_schema: DatasetSchema, + worker_options: Union[ + MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions + ], + sampling_config: SamplingConfig, + device: torch.device, + runtime: DistributedRuntimeInfo, + sampler: Union[DistMpSamplingProducer, Callable[..., int]], + process_start_gap_seconds: float = 60.0, + ): + # Set right away so __del__ can clean up if we throw during init. + # Will be set to False once connections are initialized. 
+ self._shutdowned = True + + # Store dataset metadata for subclass _collate_fn usage + self._is_homogeneous_with_labeled_edge_type = ( + dataset_schema.is_homogeneous_with_labeled_edge_type + ) + self._node_feature_info = dataset_schema.node_feature_info + self._edge_feature_info = dataset_schema.edge_feature_info + + # --- Attributes shared by both modes (mirrors GLT DistLoader.__init__) --- + self.input_data = sampler_input + self.sampling_type = sampling_config.sampling_type + self.num_neighbors = sampling_config.num_neighbors + self.batch_size = sampling_config.batch_size + self.shuffle = sampling_config.shuffle + self.drop_last = sampling_config.drop_last + self.with_edge = sampling_config.with_edge + self.with_weight = sampling_config.with_weight + self.collect_features = sampling_config.collect_features + self.edge_dir = sampling_config.edge_dir + self.sampling_config = sampling_config + self.to_device = device + self.worker_options = worker_options + + self._is_collocated_worker = False + self._with_channel = True + self._num_recv = 0 + self._epoch = 0 + + # --- Mode-specific attributes and connection initialization --- + if isinstance(sampler, DistMpSamplingProducer): + assert isinstance(dataset, DistDataset) + assert isinstance(worker_options, MpDistSamplingWorkerOptions) + assert isinstance(sampler_input, NodeSamplerInput) + + self.data: Optional[DistDataset] = dataset + self._is_mp_worker = True + self._is_remote_worker = False + + self.num_data_partitions = dataset.num_partitions + self.data_partition_idx = dataset.partition_idx + self._set_ntypes_and_etypes( + dataset.get_node_types(), dataset.get_edge_types() + ) + + self._input_len = len(sampler_input) + self._input_type = sampler_input.input_type + self._num_expected = self._input_len // self.batch_size + if not self.drop_last and self._input_len % self.batch_size != 0: + self._num_expected += 1 + + self._shutdowned = False + self._init_colocated_connections( + dataset=dataset, + producer=sampler, 
+ runtime=runtime, + process_start_gap_seconds=process_start_gap_seconds, + ) + else: + assert isinstance(dataset, RemoteDistDataset) + assert isinstance(worker_options, RemoteDistSamplingWorkerOptions) + assert isinstance(sampler_input, list) + assert callable(sampler) + + self.data = None + self._is_mp_worker = False + self._is_remote_worker = True + self._num_expected = float("inf") + + self._server_rank_list: list[int] = ( + worker_options.server_rank + if isinstance(worker_options.server_rank, list) + else [worker_options.server_rank] + ) + self._input_data_list = sampler_input + self._input_type = self._input_data_list[0].input_type + + self.num_data_partitions = dataset.cluster_info.num_storage_nodes + self.data_partition_idx = dataset.cluster_info.compute_node_rank + edge_types = dataset_schema.edge_types or [] + if edge_types: + node_types = list( + set([et[0] for et in edge_types] + [et[2] for et in edge_types]) + ) + else: + node_types = [DEFAULT_HOMOGENEOUS_NODE_TYPE] + self._set_ntypes_and_etypes(node_types, edge_types) + + self._shutdowned = False + self._init_graph_store_connections( + dataset=dataset, + create_producer_fn=sampler, + ) + + @staticmethod + def create_sampling_config( + num_neighbors: Union[list[int], dict[EdgeType, list[int]]], + dataset_schema: DatasetSchema, + batch_size: int = 1, + shuffle: bool = False, + drop_last: bool = False, + ) -> SamplingConfig: + """Creates a SamplingConfig with patched fanout. + + Patches ``num_neighbors`` to zero-out label edge types, then creates + the SamplingConfig used by both colocated and graph store modes. + + Args: + num_neighbors: Fanout per hop. + dataset_schema: Contains edge types and edge dir. + batch_size: How many samples per batch. + shuffle: Whether to shuffle input nodes. + drop_last: Whether to drop the last incomplete batch. + + Returns: + A fully configured SamplingConfig. 
+ """ + num_neighbors = patch_fanout_for_sampling( + edge_types=dataset_schema.edge_types, + num_neighbors=num_neighbors, + ) + return SamplingConfig( + sampling_type=SamplingType.NODE, + num_neighbors=num_neighbors, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + with_edge=True, + collect_features=True, + with_neg=False, + with_weight=False, + edge_dir=dataset_schema.edge_dir, + seed=None, + ) + + @staticmethod + def create_colocated_channel( + worker_options: MpDistSamplingWorkerOptions, + ) -> ShmChannel: + """Creates a ShmChannel for colocated mode. + + Creates and optionally pin-memories the shared-memory channel. + + Args: + worker_options: The colocated worker options (must already be fully configured). + + Returns: + A ShmChannel ready to be passed to a DistMpSamplingProducer. + """ + channel = ShmChannel( + worker_options.channel_capacity, worker_options.channel_size + ) + if worker_options.pin_memory: + channel.pin_memory() + return channel + + def _init_colocated_connections( + self, + dataset: DistDataset, + producer: DistMpSamplingProducer, + runtime: DistributedRuntimeInfo, + process_start_gap_seconds: float, + ) -> None: + """Initialize colocated mode connections. + + Validates the GLT distributed context, stores the pre-constructed producer, + and performs staggered initialization to avoid memory OOM. + + All DistLoader attributes are already set by ``__init__`` before this is called. + + Args: + dataset: The local DistDataset. + producer: A pre-constructed DistMpSamplingProducer (or subclass). + runtime: Resolved distributed runtime info (used for staggered sleep). + process_start_gap_seconds: Delay multiplier for staggered init. 
+ """ + # Validate context and store the pre-constructed producer and its channel + current_ctx = get_context() + if not current_ctx.is_worker(): + raise RuntimeError( + f"'{self.__class__.__name__}': only supports " + f"launching multiprocessing sampling workers with " + f"a non-server distribution mode, current role of " + f"distributed context is {current_ctx.role}." + ) + if dataset is None: + raise ValueError( + f"'{self.__class__.__name__}': missing input dataset " + f"when launching multiprocessing sampling workers." + ) + self.worker_options._set_worker_ranks(current_ctx) + self._channel = producer.output_channel + self._mp_producer = producer + + # Staggered init — sleep proportional to local_rank to avoid + # concurrent initialization spikes that cause CPU memory OOM. + logger.info( + f"---Machine {runtime.rank} local process number {runtime.local_rank} " + f"preparing to sleep for {process_start_gap_seconds * runtime.local_rank} seconds" + ) + time.sleep(process_start_gap_seconds * runtime.local_rank) + self._mp_producer.init() + + def _init_graph_store_connections( + self, + dataset: RemoteDistDataset, + create_producer_fn: Callable[..., int], + ) -> None: + """Initialize Graph Store mode connections. + + Validates the GLT distributed context, performs a sequential barrier loop + across compute nodes, dispatches async RPCs to create sampling producers on + storage nodes, and creates a RemoteReceivingChannel. + + All DistLoader attributes are already set by ``__init__`` before this is called. + + Uses ``async_request_server`` instead of ``ThreadPoolExecutor`` to avoid + TensorPipe rendezvous deadlock with many servers. + + For Graph Store mode it's important to distinguish "compute node" (e.g. physical compute machine) from "compute process" (e.g. process running on the compute node). + Since in practice we have multiple compute processes per compute node, and each compute process needs to initialize the connection to the storage nodes. + E.g. 
if there are 4 gpus per compute node, then there will be 4 connections from each compute node to each storage node. + + See below for a connection setup. + ╔═══════════════════════════════════════════════════════════════════════════════════════╗ + ║ COMPUTE TO STORAGE NODE CONNECTIONS ║ + ╚═══════════════════════════════════════════════════════════════════════════════════════╝ + + COMPUTE NODES STORAGE NODES + ═════════════ ═════════════ + + ┌──────────────────────┐ (1) ┌───────────────┐ + │ COMPUTE NODE 0 │ │ │ + │ ┌────┬────┬────┬────┤ ══════════════════════════════════│ STORAGE 0 │ + │ │GPU │GPU │GPU │GPU │ ╱ │ │ + │ │ 0 │ 1 │ 2 │ 3 │ ════════════════════╲ ╱ └───────────────┘ + │ └────┴────┴────┴────┤ (2) ╲ ╱ + └──────────────────────┘ ╲ ╱ + ╳ + (3) ╱ ╲ (4) + ┌──────────────────────┐ ╱ ╲ ┌───────────────┐ + │ COMPUTE NODE 1 │ ╱ ╲ │ │ + │ ┌────┬────┬────┬────┤ ═════════════════╱ ═│ STORAGE 1 │ + │ │GPU │GPU │GPU │GPU │ │ │ + │ │ 0 │ 1 │ 2 │ 3 │ ══════════════════════════════════│ │ + │ └────┴────┴────┴────┤ └───────────────┘ + └──────────────────────┘ + + ┌─────────────────────────────────────────────────────────────────────────────┐ + │ (1) Compute Node 0 → Storage 0 (4 connections, one per GPU) │ + │ (2) Compute Node 0 → Storage 1 (4 connections, one per GPU) │ + │ (3) Compute Node 1 → Storage 0 (4 connections, one per GPU) │ + │ (4) Compute Node 1 → Storage 1 (4 connections, one per GPU) │ + └─────────────────────────────────────────────────────────────────────────────┘ + """ + # Validate distributed context + ctx = get_context() + if ctx is None: + raise RuntimeError( + f"'{self.__class__.__name__}': the distributed context " + f"has not been initialized." + ) + if not ctx.is_client(): + raise RuntimeError( + f"'{self.__class__.__name__}': must be used on a client " + f"worker process." 
+ ) + + # Move input to CPU before sending to server + for inp in self._input_data_list: + if not isinstance(inp, RemoteSamplerInput): + inp.to(torch.device("cpu")) + + node_rank = dataset.cluster_info.compute_node_rank + + _flush() + start_time = time.time() + rpc_futures: list[tuple[int, torch.futures.Future[int]]] = [] + # Dispatch ALL create_producer RPCs async. + # async_request_server queues the RPC in TensorPipe and returns + # immediately, allowing all storage nodes to start their worker + # rendezvous simultaneously. + logger.info( + f"node_rank={node_rank} dispatching create_sampling_producer to " + f"{len(self._server_rank_list)} servers" + ) + _flush() + t_dispatch = time.time() + for server_rank, inp_data in zip(self._server_rank_list, self._input_data_list): + fut = async_request_server( + server_rank, + create_producer_fn, + inp_data, + self.sampling_config, + self.worker_options, + ) + rpc_futures.append((server_rank, fut)) + logger.info( + f"node_rank={node_rank} all {len(rpc_futures)} RPCs dispatched in " + f"{time.time() - t_dispatch:.3f}s, waiting for responses" + ) + _flush() + + # Wait for all results + self._producer_id_list: list[int] = [] + for server_rank, fut in rpc_futures: + t_wait = time.time() + producer_id: int = fut.wait() + logger.info( + f"node_rank={node_rank} create_sampling_producer" + f"(server_rank={server_rank}) returned " + f"producer_id={producer_id} in {time.time() - t_wait:.2f}s" + ) + _flush() + self._producer_id_list.append(producer_id) + logger.info( + f"node_rank={node_rank} all {len(self._producer_id_list)} producers " + f"created in {time.time() - t_dispatch:.2f}s total" + ) + _flush() + # Create remote receiving channel for cross-machine message passing + self._channel = RemoteReceivingChannel( + self._server_rank_list, + self._producer_id_list, + self.worker_options.prefetch_size, + ) + + logger.info( + f"node_rank {node_rank} initialized the dist loader in " + f"{time.time() - start_time:.2f}s" + ) + _flush() + 
+ # Overwrite DistLoader.shutdown to so we can use our own shutdown and rpc calls + def shutdown(self) -> None: + if self._shutdowned: + return + if self._is_collocated_worker: + self._collocated_producer.shutdown() + elif self._is_mp_worker: + self._mp_producer.shutdown() + elif rpc_is_initialized() is True: + rpc_futures: list[torch.futures.Future[None]] = [] + for server_rank, producer_id in zip( + self._server_rank_list, self._producer_id_list + ): + fut = async_request_server( + server_rank, DistServer.destroy_sampling_producer, producer_id + ) + rpc_futures.append(fut) + torch.futures.wait_all(rpc_futures) + self._shutdowned = True + + # Overwrite DistLoader.__iter__ to so we can use our own __iter__ and rpc calls + def __iter__(self) -> Self: + self._num_recv = 0 + if self._is_collocated_worker: + self._collocated_producer.reset() + elif self._is_mp_worker: + self._mp_producer.produce_all() + else: + rpc_futures: list[torch.futures.Future[None]] = [] + for server_rank, producer_id in zip( + self._server_rank_list, self._producer_id_list + ): + fut = async_request_server( + server_rank, + DistServer.start_new_epoch_sampling, + producer_id, + self._epoch, + ) + rpc_futures.append(fut) + torch.futures.wait_all(rpc_futures) + self._channel.reset() + self._epoch += 1 + return self diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 1db2fc277..42b04e71e 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -1,27 +1,20 @@ import ast -import concurrent.futures -import time -from collections import Counter, abc, defaultdict +from collections import abc, defaultdict from itertools import count -from typing import Optional, Union +from typing import Callable, Optional, Union import torch -from graphlearn_torch.channel import RemoteReceivingChannel, SampleMessage, ShmChannel +from graphlearn_torch.channel import SampleMessage from 
graphlearn_torch.distributed import ( - DistLoader, MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions, - get_context, - request_server, ) -from graphlearn_torch.sampler import SamplingConfig, SamplingType -from graphlearn_torch.utils import reverse_edge_type from torch_geometric.data import Data, HeteroData from torch_geometric.typing import EdgeType import gigl.distributed.utils from gigl.common.logger import Logger -from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT +from gigl.distributed.base_dist_loader import BaseDistLoader from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_sampling_producer import DistABLPSamplingProducer @@ -38,7 +31,6 @@ DatasetSchema, SamplingClusterSetup, labeled_to_homogeneous, - patch_fanout_for_sampling, set_missing_features, shard_nodes_by_process, strip_label_edges, @@ -61,7 +53,7 @@ logger = Logger() -class DistABLPLoader(DistLoader): +class DistABLPLoader(BaseDistLoader): # Counts instantiations of this class, per process. # This is needed so we can generate unique worker key for each instance, for graph store mode. # NOTE: This is per-class, not per-instance. @@ -204,21 +196,8 @@ def __init__( # Set self._shutdowned right away, that way if we throw here, and __del__ is called, # then we can properly clean up and don't get extraneous error messages. - # We set to `True` as we don't need to cleanup right away, and this will get set - # to `False` in super().__init__()` e.g. 
- # https://github.com/alibaba/graphlearn-for-pytorch/blob/26fe3d4e050b081bc51a79dc9547f244f5d314da/graphlearn_torch/python/distributed/dist_loader.py#L125C1-L126C1 self._shutdowned = True - node_world_size: int - node_rank: int - rank: int - world_size: int - local_rank: int - local_world_size: int - - master_ip_address: str - should_cleanup_distributed_context: bool = False - # Determine sampling cluster setup based on dataset type if isinstance(dataset, RemoteDistDataset): self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE @@ -250,81 +229,23 @@ def __init__( del supervision_edge_type self._instance_count = next(self._counter) - self.data: Optional[Union[DistDataset, RemoteDistDataset]] = None - if isinstance(dataset, DistDataset): - self.data = dataset - - if context: - assert ( - local_process_world_size is not None - ), "context: DistributedContext provided, so local_process_world_size must be provided." - assert ( - local_process_rank is not None - ), "context: DistributedContext provided, so local_process_rank must be provided." - - master_ip_address = context.main_worker_ip_address - node_world_size = context.global_world_size - node_rank = context.global_rank - local_world_size = local_process_world_size - local_rank = local_process_rank - - rank = node_rank * local_world_size + local_rank - world_size = node_world_size * local_world_size - - if not torch.distributed.is_initialized(): - logger.info( - "process group is not available, trying to torch.distributed.init_process_group to communicate necessary setup information." 
- ) - should_cleanup_distributed_context = True - logger.info( - f"Initializing process group with master ip address: {master_ip_address}, rank: {rank}, world size: {world_size}, local_rank: {local_rank}, local_world_size: {local_world_size}" - ) - torch.distributed.init_process_group( - backend="gloo", # We just default to gloo for this temporary process group - init_method=f"tcp://{master_ip_address}:{DEFAULT_MASTER_INFERENCE_PORT}", - rank=rank, - world_size=world_size, - ) - else: - assert ( - torch.distributed.is_initialized() - ), f"context: DistributedContext is None, so process group must be initialized before constructing this object {self.__class__.__name__}." - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - - rank_ip_addresses = gigl.distributed.utils.get_internal_ip_from_all_ranks() - master_ip_address = rank_ip_addresses[0] - - count_ranks_per_ip_address = Counter(rank_ip_addresses) - local_world_size = count_ranks_per_ip_address[master_ip_address] - for rank_ip_address, count in count_ranks_per_ip_address.items(): - if count != local_world_size: - raise ValueError( - f"All ranks must have the same number of processes, but found {count} processes for rank {rank} on ip {rank_ip_address}, expected {local_world_size}." - + f"count_ranks_per_ip_address = {count_ranks_per_ip_address}" - ) - - node_world_size = len(count_ranks_per_ip_address) - local_rank = rank % local_world_size - node_rank = rank // local_world_size - - del ( - context, - local_process_rank, - local_process_world_size, - ) # delete deprecated vars so we don't accidentally use them. 
+ # Resolve distributed context + runtime = BaseDistLoader.resolve_runtime( + context, local_process_rank, local_process_world_size + ) + del context, local_process_rank, local_process_world_size device = ( pin_memory_device if pin_memory_device else gigl.distributed.utils.get_available_device( - local_process_rank=local_rank + local_process_rank=runtime.local_rank ) ) self.to_device = device - # Call appropriate setup method based on sampling cluster setup + # Mode-specific setup if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: assert isinstance( dataset, DistDataset @@ -334,26 +255,29 @@ def __init__( raise ValueError( f"When using Colocated mode, input_nodes must be of type " f"(torch.Tensor | tuple[NodeType, torch.Tensor] | None), " - f"received Graph Store format: dict[int, ABLPInputNodes]" + f"received {type(input_nodes)}" ) - ( - sampler_input, - worker_options, - dataset_metadata, - ) = self._setup_for_colocated( + setup_info = self._setup_for_colocated( input_nodes=input_nodes, dataset=dataset, - local_rank=local_rank, - local_world_size=local_world_size, + local_rank=runtime.local_rank, + local_world_size=runtime.local_world_size, device=device, - master_ip_address=master_ip_address, - node_rank=node_rank, - node_world_size=node_world_size, + master_ip_address=runtime.master_ip_address, + node_rank=runtime.node_rank, + node_world_size=runtime.node_world_size, num_workers=num_workers, worker_concurrency=worker_concurrency, channel_size=channel_size, num_cpu_threads=num_cpu_threads, ) + sampler_input: Union[ + ABLPNodeSamplerInput, list[ABLPNodeSamplerInput] + ] = setup_info[0] + worker_options: Union[ + MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions + ] = setup_info[1] + dataset_schema: DatasetSchema = setup_info[2] else: # Graph Store mode assert isinstance( dataset, RemoteDistDataset @@ -371,7 +295,7 @@ def __init__( ( sampler_input, worker_options, - dataset_metadata, + dataset_schema, ) = self._setup_for_graph_store( 
input_nodes=input_nodes, dataset=dataset, @@ -380,146 +304,56 @@ def __init__( prefetch_size=prefetch_size, ) - self.is_homogeneous_with_labeled_edge_type = ( - dataset_metadata.is_homogeneous_with_labeled_edge_type - ) - self._node_feature_info = dataset_metadata.node_feature_info - self._edge_feature_info = dataset_metadata.edge_feature_info - - num_neighbors = patch_fanout_for_sampling( - dataset_metadata.edge_types, num_neighbors - ) - - if should_cleanup_distributed_context and torch.distributed.is_initialized(): + # Cleanup temporary process group if needed + if ( + runtime.should_cleanup_distributed_context + and torch.distributed.is_initialized() + ): logger.info( f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." ) torch.distributed.destroy_process_group() - sampling_config = SamplingConfig( - sampling_type=SamplingType.NODE, + # Create SamplingConfig (with patched fanout) + sampling_config = BaseDistLoader.create_sampling_config( num_neighbors=num_neighbors, + dataset_schema=dataset_schema, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, - with_edge=True, - collect_features=True, - with_neg=False, - with_weight=False, - edge_dir=dataset_metadata.edge_dir, - seed=None, # it's actually optional - None means random. ) + # Build the sampler: a pre-constructed producer for colocated mode, + # or an RPC callable for graph store mode. if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: - # Code below this point is taken from the GLT DistNeighborLoader.__init__() function - # (graphlearn_torch/python/distributed/dist_neighbor_loader.py). - # We do this so that we may override the DistSamplingProducer that is used with the GiGL implementation. 
- - # Type narrowing for colocated mode - - self.input_data = sampler_input[0] - del sampler_input - assert isinstance(self.data, DistDataset) - assert isinstance(self.input_data, ABLPNodeSamplerInput) - - self.sampling_type = sampling_config.sampling_type - self.num_neighbors = sampling_config.num_neighbors - self.batch_size = sampling_config.batch_size - self.shuffle = sampling_config.shuffle - self.drop_last = sampling_config.drop_last - self.with_edge = sampling_config.with_edge - self.with_weight = sampling_config.with_weight - self.collect_features = sampling_config.collect_features - self.edge_dir = sampling_config.edge_dir - self.sampling_config = sampling_config - self.worker_options = worker_options - - # We can set shutdowned to false now - self._shutdowned = False - - self._is_mp_worker = True - self._is_collocated_worker = False - self._is_remote_worker = False - - self.num_data_partitions = self.data.num_partitions - self.data_partition_idx = self.data.partition_idx - self._set_ntypes_and_etypes( - self.data.get_node_types(), self.data.get_edge_types() - ) - - self._num_recv = 0 - self._epoch = 0 - - current_ctx = get_context() - - self._input_len = len(self.input_data) - self._input_type = self.input_data.input_type - self._num_expected = self._input_len // self.batch_size - if not self.drop_last and self._input_len % self.batch_size != 0: - self._num_expected += 1 - - if not current_ctx.is_worker(): - raise RuntimeError( - f"'{self.__class__.__name__}': only supports " - f"launching multiprocessing sampling workers with " - f"a non-server distribution mode, current role of " - f"distributed context is {current_ctx.role}." - ) - if self.data is None: - raise ValueError( - f"'{self.__class__.__name__}': missing input dataset " - f"when launching multiprocessing sampling workers." 
- ) - - # Launch multiprocessing sampling workers - self._with_channel = True - self.worker_options._set_worker_ranks(current_ctx) - - self._channel = ShmChannel( - self.worker_options.channel_capacity, self.worker_options.channel_size - ) - if self.worker_options.pin_memory: - self._channel.pin_memory() - - self._mp_producer = DistABLPSamplingProducer( - self.data, - self.input_data, - self.sampling_config, - self.worker_options, - self._channel, - ) - # When initiating data loader(s), there will be a spike of memory usage lasting for ~30s. - # The current hypothesis is making connections across machines require a lot of memory. - # If we start all data loaders in all processes simultaneously, the spike of memory - # usage will add up and cause CPU memory OOM. Hence, we initiate the data loaders group by group - # to smooth the memory usage. The definition of group is discussed in init_neighbor_loader_worker. - logger.info( - f"---Machine {rank} local process number {local_rank} preparing to sleep for {process_start_gap_seconds * local_rank} seconds" + assert isinstance(dataset, DistDataset) + assert isinstance(worker_options, MpDistSamplingWorkerOptions) + channel = BaseDistLoader.create_colocated_channel(worker_options) + sampler: Union[ + DistABLPSamplingProducer, Callable[..., int] + ] = DistABLPSamplingProducer( + dataset, + sampler_input, + sampling_config, + worker_options, + channel, ) - time.sleep(process_start_gap_seconds * local_rank) - self._mp_producer.init() else: - # Graph Store mode - re-implement remote worker setup - # Use sequential initialization per compute node to avoid race conditions - # when initializing the samplers on the storage nodes. 
- node_rank = dataset.cluster_info.compute_node_rank - for target_node_rank in range(dataset.cluster_info.num_compute_nodes): - if node_rank == target_node_rank: - self._init_remote_worker( - dataset=dataset, - sampler_input=sampler_input, - sampling_config=sampling_config, - worker_options=worker_options, - dataset_metadata=dataset_metadata, - ) - logger.info( - f"node_rank {node_rank} / {dataset.cluster_info.num_compute_nodes} initialized the dist loader" - ) - torch.distributed.barrier() - torch.distributed.barrier() - logger.info( - f"node_rank {node_rank} / {dataset.cluster_info.num_compute_nodes} finished initializing the dist loader" - ) + sampler = DistServer.create_sampling_ablp_producer + + # Call base class — handles metadata storage and connection initialization + # (including staggered init for colocated mode). + super().__init__( + dataset=dataset, + sampler_input=sampler_input, + dataset_schema=dataset_schema, + worker_options=worker_options, + sampling_config=sampling_config, + device=device, + runtime=runtime, + sampler=sampler, + process_start_gap_seconds=process_start_gap_seconds, + ) def _setup_for_colocated( self, @@ -540,7 +374,7 @@ def _setup_for_colocated( worker_concurrency: int, channel_size: str, num_cpu_threads: Optional[int], - ) -> tuple[list[ABLPNodeSamplerInput], MpDistSamplingWorkerOptions, DatasetSchema]: + ) -> tuple[ABLPNodeSamplerInput, MpDistSamplingWorkerOptions, DatasetSchema]: """ Setup method for colocated (non-Graph Store) mode. @@ -559,7 +393,7 @@ def _setup_for_colocated( num_cpu_threads: Number of CPU threads for PyTorch. Returns: - Tuple of (list[ABLPNodeSamplerInput], MpDistSamplingWorkerOptions, DatasetSchema). + Tuple of (ABLPNodeSamplerInput, MpDistSamplingWorkerOptions, DatasetSchema). 
""" # Validate input format - should not be Graph Store format if isinstance(input_nodes, abc.Mapping): @@ -741,7 +575,7 @@ def _setup_for_colocated( edge_types = list(dataset.graph.keys()) return ( - [sampler_input], + sampler_input, worker_options, DatasetSchema( is_homogeneous_with_labeled_edge_type=is_homogeneous_with_labeled_edge_type, @@ -788,10 +622,6 @@ def _setup_for_graph_store( num_ports=dataset.cluster_info.num_compute_nodes ) sampling_port = sampling_ports[node_rank] - # TODO(kmonte) - We need to be able to differentiate between different instances of the same loader. - # e.g. if we have two different DistABLPLoaders, then they will have conflicting worker keys. - # And they will share each others data. Therefor, the second loader will not load the data it's expecting. - # Probably, we can just keep track of the insantiations on the server-side and include the count in the worker key. worker_key = ( f"compute_ablp_loader_rank_{node_rank}_worker_{self._instance_count}" ) @@ -919,114 +749,6 @@ def _setup_for_graph_store( ), ) - def _init_remote_worker( - self, - dataset: RemoteDistDataset, - sampler_input: list[ABLPNodeSamplerInput], - sampling_config: SamplingConfig, - worker_options: RemoteDistSamplingWorkerOptions, - dataset_metadata: DatasetSchema, - ) -> None: - """ - Initialize the remote worker code path for Graph Store mode. - - This re-implements GLT's DistLoader remote worker setup but uses GiGL's DistServer. - - Args: - dataset: The RemoteDistDataset to sample from. - sampler_input: List of ABLPNodeSamplerInput, one per server. - sampling_config: Configuration for sampling. - worker_options: Options for remote sampling workers. - dataset_metadata: Metadata about the dataset schema. 
- """ - # Set instance variables (like DistLoader does) - self.sampling_type = sampling_config.sampling_type - self.num_neighbors = sampling_config.num_neighbors - self.batch_size = sampling_config.batch_size - self.shuffle = sampling_config.shuffle - self.drop_last = sampling_config.drop_last - self.with_edge = sampling_config.with_edge - self.with_weight = sampling_config.with_weight - self.collect_features = sampling_config.collect_features - self.edge_dir = sampling_config.edge_dir - self.sampling_config = sampling_config - self.worker_options = worker_options - - self._shutdowned = False - - # Set worker type flags - self._is_mp_worker = False - self._is_collocated_worker = False - self._is_remote_worker = True - - # For remote worker, end of epoch is determined by server - self._num_expected = float("inf") - self._with_channel = True - - self._num_recv = 0 - self._epoch = 0 - - # Get server rank list from worker_options - self._server_rank_list = ( - worker_options.server_rank - if isinstance(worker_options.server_rank, list) - else [worker_options.server_rank] - ) - self._input_data_list = sampler_input # Already a list (one per server) - - # Get input type from first input - self._input_type = self._input_data_list[0].input_type - - # Get dataset metadata from cluster_info (not via RPC) - self.num_data_partitions = dataset.cluster_info.num_storage_nodes - self.data_partition_idx = dataset.cluster_info.compute_node_rank - - # Derive node types from edge types - # For labeled homogeneous: edge_types contains DEFAULT_HOMOGENEOUS_EDGE_TYPE - # For heterogeneous: extract unique src/dst types from edge types - edge_types = dataset_metadata.edge_types or [] - if edge_types: - node_types = list( - set([et[0] for et in edge_types] + [et[2] for et in edge_types]) - ) - else: - node_types = [DEFAULT_HOMOGENEOUS_NODE_TYPE] - self._set_ntypes_and_etypes(node_types, edge_types) - - # Create sampling producers on each server (concurrently) - # Move input data to CPU 
before sending to server - for input_data in self._input_data_list: - input_data.to(torch.device("cpu")) - - self._producer_id_list = [] - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [ - executor.submit( - request_server, - server_rank, - DistServer.create_sampling_ablp_producer, - input_data, - self.sampling_config, - self.worker_options, - ) - for server_rank, input_data in zip( - self._server_rank_list, self._input_data_list - ) - ] - - for future in futures: - producer_id = future.result() - self._producer_id_list.append(producer_id) - logger.info( - f"DistABLPLoader rank {torch.distributed.get_rank()} producers: ({[producer_id for producer_id in self._producer_id_list]})" - ) - # Create remote receiving channel for cross-machine message passing - self._channel = RemoteReceivingChannel( - self._server_rank_list, - self._producer_id_list, - self.worker_options.prefetch_size, - ) - def _get_labels( self, msg: SampleMessage ) -> tuple[ @@ -1190,7 +912,7 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: ) if isinstance(data, HeteroData): data = strip_label_edges(data) - if not self.is_homogeneous_with_labeled_edge_type: + if not self._is_homogeneous_with_labeled_edge_type: if len(self._supervision_edge_types) != 1: raise ValueError( f"Expected 1 supervision edge type, got {len(self._supervision_edge_types)}" diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 1c0634042..867e80f29 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -1,39 +1,30 @@ import sys -import time -from collections import Counter, abc +from collections import abc from itertools import count -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch -from graphlearn_torch.channel import RemoteReceivingChannel, SampleMessage +from graphlearn_torch.channel import 
 SampleMessage from graphlearn_torch.distributed import ( - DistLoader, MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions, ) -from graphlearn_torch.distributed.dist_context import get_context -from graphlearn_torch.sampler import ( - NodeSamplerInput, - RemoteSamplerInput, - SamplingConfig, - SamplingType, -) +from graphlearn_torch.distributed.dist_sampling_producer import DistMpSamplingProducer +from graphlearn_torch.sampler import NodeSamplerInput from torch_geometric.data import Data, HeteroData from torch_geometric.typing import EdgeType import gigl.distributed.utils from gigl.common.logger import Logger -from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT +from gigl.distributed.base_dist_loader import BaseDistLoader from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset -from gigl.distributed.graph_store.compute import async_request_server, request_server from gigl.distributed.graph_store.dist_server import DistServer as GiglDistServer from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.utils.neighborloader import ( DatasetSchema, SamplingClusterSetup, labeled_to_homogeneous, - patch_fanout_for_sampling, set_missing_features, shard_nodes_by_process, strip_label_edges, @@ -52,12 +43,14 @@ DEFAULT_NUM_CPU_THREADS = 2 +# We don't see logs for graph store mode for whatever reason. +# TODO(#442): Revert this once the GCP issues are resolved. def flush(): sys.stdout.flush() sys.stderr.flush() -class DistNeighborLoader(DistLoader): +class DistNeighborLoader(BaseDistLoader): # Counts instantiations of this class, per process. # This is needed so we can generate unique worker key for each instance, for graph store mode. # NOTE: This is per-class, not per-instance. @@ -90,6 +83,12 @@ def __init__( drop_last: bool = False, ): """ + Distributed Neighbor Loader. + Takes in some input nodes and samples neighbors from the dataset. 
 + This loader should be used if you do not have any special sampling needs, + e.g. if you need to generate *training* examples for Anchor Based Link Prediction (ABLP) tasks. + Though this loader is useful for generating random negative examples for ABLP training. + Note: We try to adhere to pyg dataloader api as much as possible. See the following for reference: https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/loader/node_loader.html#NodeLoader @@ -151,81 +150,15 @@ def __init__( # Set self._shutdowned right away, that way if we throw here, and __del__ is called, # then we can properly clean up and don't get extraneous error messages. - # We set to `True` as we don't need to cleanup right away, and this will get set - # to `False` in super().__init__()` e.g. - # https://github.com/alibaba/graphlearn-for-pytorch/blob/26fe3d4e050b081bc51a79dc9547f244f5d314da/graphlearn_torch/python/distributed/dist_loader.py#L125C1-L126C1 self._shutdowned = True - node_world_size: int - node_rank: int - rank: int - world_size: int - local_rank: int - local_world_size: int - - master_ip_address: str - should_cleanup_distributed_context: bool = False - - if context: - assert ( - local_process_world_size is not None - ), "context: DistributedContext provided, so local_process_world_size must be provided." - assert ( - local_process_rank is not None - ), "context: DistributedContext provided, so local_process_rank must be provided." - - master_ip_address = context.main_worker_ip_address - node_world_size = context.global_world_size - node_rank = context.global_rank - local_world_size = local_process_world_size - local_rank = local_process_rank - - rank = node_rank * local_world_size + local_rank - world_size = node_world_size * local_world_size - - if not torch.distributed.is_initialized(): - logger.info( - "process group is not available, trying to torch.distributed.init_process_group to communicate necessary setup information." 
- ) - should_cleanup_distributed_context = True - logger.info( - f"Initializing process group with master ip address: {master_ip_address}, rank: {rank}, world size: {world_size}, local_rank: {local_rank}, local_world_size: {local_world_size}." - ) - torch.distributed.init_process_group( - backend="gloo", # We just default to gloo for this temporary process group - init_method=f"tcp://{master_ip_address}:{DEFAULT_MASTER_INFERENCE_PORT}", - rank=rank, - world_size=world_size, - ) - - else: - assert ( - torch.distributed.is_initialized() - ), f"context: DistributedContext is None, so process group must be initialized before constructing this object {self.__class__.__name__}." - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - - rank_ip_addresses = gigl.distributed.utils.get_internal_ip_from_all_ranks() - master_ip_address = rank_ip_addresses[0] - - count_ranks_per_ip_address = Counter(rank_ip_addresses) - local_world_size = count_ranks_per_ip_address[master_ip_address] - for rank_ip_address, count in count_ranks_per_ip_address.items(): - if count != local_world_size: - raise ValueError( - f"All ranks must have the same number of processes, but found {count} processes for rank {rank} on ip {rank_ip_address}, expected {local_world_size}." - + f"count_ranks_per_ip_address = {count_ranks_per_ip_address}" - ) - - node_world_size = len(count_ranks_per_ip_address) - local_rank = rank % local_world_size - node_rank = rank // local_world_size + # Resolve distributed context + runtime = BaseDistLoader.resolve_runtime( + context, local_process_rank, local_process_world_size + ) + del context, local_process_rank, local_process_world_size - del ( - context, - local_process_rank, - local_process_world_size, - ) # delete deprecated vars so we don't accidentally use them. 
+ # Determine mode if isinstance(dataset, RemoteDistDataset): self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE else: @@ -241,37 +174,37 @@ def __init__( pin_memory_device if pin_memory_device else gigl.distributed.utils.get_available_device( - local_process_rank=local_rank + local_process_rank=runtime.local_rank ) ) - # Determines if the node ids passed in are heterogeneous or homogeneous. + # Mode-specific setup if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: assert isinstance( dataset, DistDataset ), "When using colocated mode, dataset must be a DistDataset." - input_data, worker_options, dataset_metadata = self._setup_for_colocated( + input_data, worker_options, dataset_schema = self._setup_for_colocated( input_nodes=input_nodes, dataset=dataset, - local_rank=local_rank, - local_world_size=local_world_size, + local_rank=runtime.local_rank, + local_world_size=runtime.local_world_size, device=device, - master_ip_address=master_ip_address, - node_rank=node_rank, - node_world_size=node_world_size, + master_ip_address=runtime.master_ip_address, + node_rank=runtime.node_rank, + node_world_size=runtime.node_world_size, num_workers=num_workers, worker_concurrency=worker_concurrency, channel_size=channel_size, num_cpu_threads=num_cpu_threads, ) - else: # Graph Store mode + else: assert isinstance( dataset, RemoteDistDataset ), "When using Graph Store mode, dataset must be a RemoteDistDataset." 
if prefetch_size is None: logger.info(f"prefetch_size is not provided, using default of 4") prefetch_size = 4 - input_data, worker_options, dataset_metadata = self._setup_for_graph_store( + input_data, worker_options, dataset_schema = self._setup_for_graph_store( input_nodes=input_nodes, dataset=dataset, num_workers=num_workers, @@ -279,65 +212,56 @@ def __init__( channel_size=channel_size, ) - self._is_homogeneous_with_labeled_edge_type = ( - dataset_metadata.is_homogeneous_with_labeled_edge_type - ) - self._node_feature_info = dataset_metadata.node_feature_info - self._edge_feature_info = dataset_metadata.edge_feature_info + # Cleanup temporary process group if needed + if ( + runtime.should_cleanup_distributed_context + and torch.distributed.is_initialized() + ): + logger.info( + f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." + ) + torch.distributed.destroy_process_group() - logger.info(f"num_neighbors before patch: {num_neighbors}") - num_neighbors = patch_fanout_for_sampling( - edge_types=dataset_metadata.edge_types, - num_neighbors=num_neighbors, - ) - logger.info( - f"num_neighbors: {num_neighbors}, edge_types: {dataset_metadata.edge_types}" - ) - sampling_config = SamplingConfig( - sampling_type=SamplingType.NODE, + # Create SamplingConfig (with patched fanout) + sampling_config = BaseDistLoader.create_sampling_config( num_neighbors=num_neighbors, + dataset_schema=dataset_schema, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, - with_edge=True, - collect_features=True, - with_neg=False, - with_weight=False, - edge_dir=dataset_metadata.edge_dir, - seed=None, # it's actually optional - None means random. ) - if should_cleanup_distributed_context and torch.distributed.is_initialized(): - logger.info( - f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." 
- ) - torch.distributed.destroy_process_group() - + # Build the sampler: a pre-constructed producer for colocated mode, + # or an RPC callable for graph store mode. if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: - # When initiating data loader(s), there will be a spike of memory usage lasting for ~30s. - # The current hypothesis is making connections across machines require a lot of memory. - # If we start all data loaders in all processes simultaneously, the spike of memory - # usage will add up and cause CPU memory OOM. Hence, we initiate the data loaders group by group - # to smooth the memory usage. The definition of group is discussed in init_neighbor_loader_worker. - logger.info( - f"---Machine {rank} local process number {local_rank} preparing to sleep for {process_start_gap_seconds * local_rank} seconds" - ) - time.sleep(process_start_gap_seconds * local_rank) - super().__init__( - dataset, # Pass in the dataset for colocated mode. + assert isinstance(dataset, DistDataset) + assert isinstance(worker_options, MpDistSamplingWorkerOptions) + channel = BaseDistLoader.create_colocated_channel(worker_options) + sampler: Union[ + DistMpSamplingProducer, Callable[..., int] + ] = DistMpSamplingProducer( + dataset, input_data, sampling_config, - device, worker_options, + channel, ) else: - self._init_graph_store_connections( - dataset=dataset, - input_data=input_data, - sampling_config=sampling_config, - device=device, - worker_options=worker_options, - ) + sampler = GiglDistServer.create_sampling_producer + + # Call base class — handles metadata storage and connection initialization + # (including staggered init for colocated mode). 
+ super().__init__( + dataset=dataset, + sampler_input=input_data, + dataset_schema=dataset_schema, + worker_options=worker_options, + sampling_config=sampling_config, + device=device, + runtime=runtime, + sampler=sampler, + process_start_gap_seconds=process_start_gap_seconds, + ) def _setup_for_graph_store( self, @@ -353,7 +277,7 @@ def _setup_for_graph_store( num_workers: int, prefetch_size: int, channel_size: str, - ) -> tuple[NodeSamplerInput, RemoteDistSamplingWorkerOptions, DatasetSchema]: + ) -> tuple[list[NodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema]: if input_nodes is None: raise ValueError( f"When using Graph Store mode, input nodes must be provided, received {input_nodes}" @@ -430,7 +354,7 @@ def _setup_for_graph_store( servers = nodes.keys() if max(servers) >= dataset.cluster_info.num_storage_nodes or min(servers) < 0: raise ValueError( - f"When using Graph Store mode, the server ranks must be less than the number of storage nodes and greater than 0, received inputs for servers: {list(nodes.keys())}" + f"When using Graph Store mode, the server ranks must be in range [0, num_servers ({dataset.cluster_info.num_storage_nodes})), received inputs for servers: {list(nodes.keys())}" ) input_data: list[NodeSamplerInput] = [] for server_rank in range(dataset.cluster_info.num_storage_nodes): @@ -606,236 +530,6 @@ def _setup_for_colocated( ), ) - def _init_graph_store_connections( - self, - dataset: RemoteDistDataset, - input_data: list[NodeSamplerInput], - sampling_config: SamplingConfig, - device: torch.device, - worker_options: RemoteDistSamplingWorkerOptions, - ): - # Graph Store mode — initialize DistLoader attributes directly instead of - # calling super().__init__() to avoid the ThreadPoolExecutor deadlock at scale. - # - # GLT's DistLoader.__init__() dispatches create_sampling_producer RPCs via - # ThreadPoolExecutor(max_workers=32). With 60+ servers, only 32 threads run, - # causing a TensorPipe rendezvous deadlock. 
Instead, we inline the DistLoader - # init code and dispatch all RPCs asynchronously in a simple loop. - - node_rank = dataset.cluster_info.compute_node_rank - num_storage_nodes = dataset.cluster_info.num_storage_nodes - - # --- Set all DistLoader attributes (mirrors GLT DistLoader.__init__) --- - # These are required by inherited methods: shutdown(), __iter__(), __next__(), - # __del__(), _collate_fn(), _set_ntypes_and_etypes(). - self.data = None # No local data in Graph Store mode - self.input_data = input_data - self.sampling_type = sampling_config.sampling_type - self.num_neighbors = sampling_config.num_neighbors - self.batch_size = sampling_config.batch_size - self.shuffle = sampling_config.shuffle - self.drop_last = sampling_config.drop_last - self.with_edge = sampling_config.with_edge - self.with_weight = sampling_config.with_weight - self.collect_features = sampling_config.collect_features - self.edge_dir = sampling_config.edge_dir - self.sampling_config = sampling_config - self.to_device = device - self.worker_options = worker_options - self._shutdowned = False - - self._is_collocated_worker = False - self._is_mp_worker = False - self._is_remote_worker = True - - self._num_recv = 0 - self._epoch = 0 - - # Context validation - ctx = get_context() - if ctx is None: - raise RuntimeError( - f"'{self.__class__.__name__}': the distributed context " - f"has not been initialized." - ) - if not ctx.is_client(): - raise RuntimeError( - f"'{self.__class__.__name__}': must be used on a client " - f"worker process." 
- ) - - # Remote worker attributes - self._num_expected = float("inf") - self._with_channel = True - - self._server_rank_list: list[int] = ( - self.worker_options.server_rank - if isinstance(self.worker_options.server_rank, list) - else [self.worker_options.server_rank] - ) - self._input_data_list: list[NodeSamplerInput] = ( - self.input_data if isinstance(self.input_data, list) else [self.input_data] - ) - self._input_type = self._input_data_list[0].input_type - - # --- Barrier loop: one compute node at a time --- - logger.info( - f"node_rank {node_rank} starting barrier loop with " - f"{dataset.cluster_info.num_compute_nodes} compute nodes" - ) - flush() - # For Graph Store mode, we need to start the communcation between compute and storage nodes sequentially, by compute node. - # E.g. intialize connections between compute node 0 and storage nodes 0, 1, 2, 3, then compute node 1 and storage nodes 0, 1, 2, 3, etc. - # Note that each compute node may have multiple connections to each storage node, once per compute process. - # It's important to distinguish "compute node" (e.g. physical compute machine) from "compute process" (e.g. process running on the compute node). - # Since in practice we have multiple compute processes per compute node, and each compute process needs to initialize the connection to the storage nodes. - # E.g. if there are 4 gpus per compute node, then there will be 4 connections from each compute node to each storage node. - # We need to this because if we don't, then there is a race condition when initalizing the samplers on the storage nodes [1] - # Where since the lock is per *server* (e.g. per storage node), if we try to start one connection from compute node 0, and compute node 1 - # Then we deadlock and fail. - # Specifically, the race condition happens in `DistLoader.__init__` when it initializes the sampling producers on the storage nodes. 
[2] - # [1]: https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/distributed/dist_server.py#L129-L167 - # [2]: https://github.com/alibaba/graphlearn-for-pytorch/blob/88ff111ac0d9e45c6c9d2d18cfc5883dca07e9f9/graphlearn_torch/python/distributed/dist_loader.py#L187-L193 - - # See below for a connection setup. - # ╔═══════════════════════════════════════════════════════════════════════════════════════╗ - # ║ COMPUTE TO STORAGE NODE CONNECTIONS ║ - # ╚═══════════════════════════════════════════════════════════════════════════════════════╝ - - # COMPUTE NODES STORAGE NODES - # ═════════════ ═════════════ - - # ┌──────────────────────┐ (1) ┌───────────────┐ - # │ COMPUTE NODE 0 │ │ │ - # │ ┌────┬────┬────┬────┤ ══════════════════════════════════│ STORAGE 0 │ - # │ │GPU │GPU │GPU │GPU │ ╱ │ │ - # │ │ 0 │ 1 │ 2 │ 3 │ ════════════════════╲ ╱ └───────────────┘ - # │ └────┴────┴────┴────┤ (2) ╲ ╱ - # └──────────────────────┘ ╲ ╱ - # ╳ - # (3) ╱ ╲ (4) - # ┌──────────────────────┐ ╱ ╲ ┌───────────────┐ - # │ COMPUTE NODE 1 │ ╱ ╲ │ │ - # │ ┌────┬────┬────┬────┤ ═════════════════╱ ═│ STORAGE 1 │ - # │ │GPU │GPU │GPU │GPU │ │ │ - # │ │ 0 │ 1 │ 2 │ 3 │ ══════════════════════════════════│ │ - # │ └────┴────┴────┴────┤ └───────────────┘ - # └──────────────────────┘ - - # ┌─────────────────────────────────────────────────────────────────────────────┐ - # │ (1) Compute Node 0 → Storage 0 (4 connections, one per GPU) │ - # │ (2) Compute Node 0 → Storage 1 (4 connections, one per GPU) │ - # │ (3) Compute Node 1 → Storage 0 (4 connections, one per GPU) │ - # │ (4) Compute Node 1 → Storage 1 (4 connections, one per GPU) │ - # └─────────────────────────────────────────────────────────────────────────────┘ - for target_node_rank in range(dataset.cluster_info.num_compute_nodes): - start_time = time.time() - if node_rank == target_node_rank: - # Step 1: Get dataset metadata via RPC (single call, fast) - ( - self.num_data_partitions, - self.data_partition_idx, - 
ntypes, - etypes, - ) = request_server( - self._server_rank_list[0], - GiglDistServer.get_dataset_meta, - ) - self._set_ntypes_and_etypes(ntypes, etypes) - - # Step 2: Move input data to CPU if needed - for i, inp in enumerate(self._input_data_list): - if not isinstance(inp, RemoteSamplerInput): - self._input_data_list[i] = inp.to(torch.device("cpu")) - - # Step 3: Dispatch ALL create_sampling_producer RPCs async. - # - # This is the key fix: async_request_server queues the RPC - # in TensorPipe and returns immediately. By dispatching all - # N RPCs in a loop BEFORE waiting for any response, all - # storage nodes receive the RPC and start their worker - # rendezvous simultaneously. No ThreadPoolExecutor needed. - logger.info( - f"node_rank={node_rank} dispatching " - f"create_sampling_producer to " - f"{num_storage_nodes} servers" - ) - flush() - t_dispatch = time.time() - rpc_futures: list[tuple[int, torch.futures.Future[int]]] = [] - for server_rank, inp_data in zip( - self._server_rank_list, self._input_data_list - ): - fut = async_request_server( - server_rank, - GiglDistServer.create_sampling_producer, - inp_data, - self.sampling_config, - self.worker_options, - ) - rpc_futures.append((server_rank, fut)) - logger.info( - f"node_rank={node_rank} all " - f"{len(rpc_futures)} RPCs dispatched in " - f"{time.time() - t_dispatch:.3f}s, " - f"waiting for responses" - ) - flush() - - # Step 4: Wait for all results - self._producer_id_list: list[int] = [] - for server_rank, fut in rpc_futures: - t_wait = time.time() - producer_id: int = fut.wait() - logger.info( - f"node_rank={node_rank} " - f"create_sampling_producer" - f"(server_rank={server_rank}) returned " - f"producer_id={producer_id} in " - f"{time.time() - t_wait:.2f}s" - ) - flush() - self._producer_id_list.append(producer_id) - logger.info( - f"node_rank={node_rank} all " - f"{len(self._producer_id_list)} producers created " - f"in {time.time() - t_dispatch:.2f}s total" - ) - flush() - - # Step 5: Create 
remote receiving channel - self._channel = RemoteReceivingChannel( - self._server_rank_list, - self._producer_id_list, - self.worker_options.prefetch_size, - ) - - logger.info( - f"node_rank {node_rank} initialized the dist loader in " - f"{time.time() - start_time:.2f}s" - ) - flush() - else: - logger.info( - f"node_rank {node_rank} waiting for barrier " - f"for rank {target_node_rank}" - ) - flush() - torch.distributed.barrier(device_ids=torch.device("cpu")) - logger.info( - f"node_rank {node_rank} barrier for rank " - f"{target_node_rank} in {time.time() - start_time:.2f}s" - ) - flush() - - torch.distributed.barrier(device_ids=torch.device("cpu")) - logger.info( - f"node_rank {node_rank}: all " - f"{dataset.cluster_info.num_compute_nodes} node ranks " - f"initialized the dist loader" - ) - flush() - def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: data = super()._collate_fn(msg) data = set_missing_features(