From b704eeb2cf85472aeb09ecf0ced3a6fe10e35fed Mon Sep 17 00:00:00 2001 From: Wolfgang De Salvador Date: Thu, 18 Jun 2026 10:57:54 +0000 Subject: [PATCH 1/2] feat: MPI rank-0 streamed file listing with hash sharding and shared memory Single-rank directory walk with chunked bcast (1M files per broadcast), hash-based sharding (adler32), epoch-dependent reshard via MPI alltoall, SharedFileList backed by POSIX shared memory, worker pre-warming, and fix for cross-epoch cache invalidation in persistent workers. Key changes: - Only rank 0 walks the filesystem; files streamed to all ranks in chunks - Each rank keeps files where adler32(path+epoch_salt) % comm_size == rank - SharedFileList stores paths in /dev/shm (139B pickle vs 67MB per worker) - alltoall reshard each epoch so files migrate between ranks (when shuffle=ON) - Workers pre-warmed before epoch 1 timing via iter()+next() - _localfs_ensure_cached always re-reads (no stale cache with persistent_workers) - allreduce(MIN) alignment prevents barrier deadlocks from uneven shards - Timing logs for listing, sharding, and resharding Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../data_loader/torch_data_loader.py | 61 +++-- dlio_benchmark/main.py | 239 ++++++++++++---- .../reader/_local_fs_iterable_mixin.py | 19 +- dlio_benchmark/utils/config.py | 255 +++++++++++++----- dlio_benchmark/utils/utility.py | 11 + 5 files changed, 441 insertions(+), 144 deletions(-) diff --git a/dlio_benchmark/data_loader/torch_data_loader.py b/dlio_benchmark/data_loader/torch_data_loader.py index f47e4eaa..a68529fe 100644 --- a/dlio_benchmark/data_loader/torch_data_loader.py +++ b/dlio_benchmark/data_loader/torch_data_loader.py @@ -415,25 +415,33 @@ def __init__(self, rank, size, num_samples, epochs): self.rank = rank self.num_samples = num_samples self.epochs = epochs - # Use floor division so every rank gets the same sample count. With - # math.ceil() the last rank was clamped to fewer samples than its - # peers when num_samples % size != 0; mismatched per-rank batch - # counts caused the per-step and end-of-epoch barriers in main._train() - # to match across iterations and deadlock at the next epoch boundary. - # The trailing (num_samples % size) samples are dropped on purpose; - # pick num_samples as a multiple of comm_size to use every sample. - samples_per_proc = num_samples // size - start_sample = self.rank * samples_per_proc - end_sample = (self.rank + 1) * samples_per_proc - 1 - self.indices = list(range(start_sample, end_sample + 1)) - dropped = num_samples - samples_per_proc * size - if dropped > 0 and self.rank == 0: - DLIOLogger.get_instance().warning( - f"{utcnow()} dlio_sampler: dropping {dropped} sample(s) — " - f"num_samples ({num_samples}) is not a multiple of comm_size " - f"({size}). Each rank will process {samples_per_proc} samples. " - f"Choose num_samples as a multiple of {size} to use every sample." - ) + args = ConfigArguments.get_instance() + if args.files_pre_sharded: + # Files already distributed — this rank owns all its local samples. + # Round-robin gives even counts, but allreduce_min as safety net. + aligned = DLIOMPI.get_instance().allreduce_min(num_samples) + self.indices = list(range(aligned)) + self.num_samples = aligned + else: + # Use floor division so every rank gets the same sample count. With + # math.ceil() the last rank was clamped to fewer samples than its + # peers when num_samples % size != 0; mismatched per-rank batch + # counts caused the per-step and end-of-epoch barriers in main._train() + # to match across iterations and deadlock at the next epoch boundary. + # The trailing (num_samples % size) samples are dropped on purpose; + # pick num_samples as a multiple of comm_size to use every sample. + samples_per_proc = num_samples // size + start_sample = self.rank * samples_per_proc + end_sample = (self.rank + 1) * samples_per_proc - 1 + self.indices = list(range(start_sample, end_sample + 1)) + dropped = num_samples - samples_per_proc * size + if dropped > 0 and self.rank == 0: + DLIOLogger.get_instance().warning( + f"{utcnow()} dlio_sampler: dropping {dropped} sample(s) — " + f"num_samples ({num_samples}) is not a multiple of comm_size " + f"({size}). Each rank will process {samples_per_proc} samples. " + f"Choose num_samples as a multiple of {size} to use every sample." + ) def __len__(self): @@ -623,8 +631,8 @@ def read(self): else: kwargs={'multiprocessing_context':self._args.multiprocessing_context, 'prefetch_factor': prefetch_factor} - if torch.__version__ != '1.3.1': - kwargs['persistent_workers'] = True + # persistent_workers=False: workers re-spawn each epoch to pick up + # resharded file lists from updated serial_args. pin_memory = self._args.pin_memory and torch.cuda.is_available() if torch.__version__ == '1.3.1': if 'prefetch_factor' in kwargs: @@ -665,6 +673,17 @@ def next(self): dlp.update(epoch=self.epoch_number) dft_ai.update(epoch=self.epoch_number) + def refresh_args(self): + """Re-serialize ConfigArguments so next worker spawn picks up changes. + + Call after reconfigure() to propagate resharded file lists to workers. + Without persistent_workers, workers re-spawn on each iter() and call + worker_init → pickle.loads(serial_args), so this ensures freshness. + """ + dataset = self._dataset.dataset + if hasattr(dataset, 'serial_args'): + dataset.serial_args = pickle.dumps(self._args) + @dlp.log def finalize(self): # When read_threads=0 the reader lives in-process on the dataset object. diff --git a/dlio_benchmark/main.py b/dlio_benchmark/main.py index f7fdb118..5d05de26 100644 --- a/dlio_benchmark/main.py +++ b/dlio_benchmark/main.py @@ -18,6 +18,7 @@ import math import subprocess import time + import numpy as np # Reduce TF and CUDA logging @@ -214,52 +215,148 @@ def initialize(self): file_list_eval = [] num_subfolders = 0 if self.args.do_train: + # ── Streaming rank-0 file listing with round-robin sharding ── + # Only rank 0 performs directory walks. Files are streamed in + # chunks and each rank keeps every comm_size-th file (round-robin). + # This gives perfectly balanced shards when total % comm_size == 0. + self.args.files_pre_sharded = True + for dataset_type in [DatasetType.TRAIN, DatasetType.VALID]: + t_listing_start = time.time() if dataset_type == DatasetType.TRAIN: num_subfolders = self.num_subfolders_train else: num_subfolders = self.num_subfolders_eval - walk_path = os.path.join(self.args.data_folder, f"{dataset_type}") - filenames = self.storage.walk_node(walk_path) - self.logger.debug(f"filenames {filenames} {num_subfolders}") - if (len(filenames) == 0): - continue - check_path = os.path.join(self.args.data_folder, f"{dataset_type}", filenames[0]) - if self.storage.get_node( - check_path) == MetadataType.DIRECTORY: - assert (num_subfolders == len(filenames)) - fullpaths = self.storage.walk_node( - os.path.join(self.args.data_folder, f"{dataset_type}/*/*.{self.args.format}"), - use_pattern=True) - files = [self.storage.get_basename(f) for f in fullpaths] - idx = np.argsort(files) - fullpaths = [fullpaths[i] for i in idx] - self.logger.debug(f"fullpaths {fullpaths}") + + my_files = [] + global_count = 0 + + _CHUNK_SIZE = 1_000_000 # max files per bcast to bound rank-0 memory + + def _filter_round_robin(chunk, start_idx): + """Keep files where (start_idx + position) % comm_size == my_rank.""" + for i, fpath in enumerate(chunk): + if (start_idx + i) % self.comm_size == self.my_rank: + my_files.append(fpath) + + if num_subfolders > 0: + # ── Subfoldered layout: stream with chunked bcast ───── + subfolder_names = None + if self.my_rank == 0: + walk_path = os.path.join(self.args.data_folder, f"{dataset_type}") + subfolder_names = sorted(self.storage.walk_node(walk_path)) + subfolder_names = self.comm.bcast(subfolder_names, root=0) + + if self.my_rank == 0: + # Multi-threaded listing of subfolders on rank 0 + from concurrent.futures import ThreadPoolExecutor + + def _list_subfolder(sf_name): + sf_path = os.path.join( + self.args.data_folder, f"{dataset_type}", + sf_name, f"*.{self.args.format}") + return self.storage.walk_node(sf_path, use_pattern=True) + + pending = [] + listing_threads = self.args.listing_threads + with ThreadPoolExecutor(max_workers=listing_threads) as pool: + for sf_files in pool.map(_list_subfolder, subfolder_names): + pending.extend(sf_files) + # Flush in chunks of _CHUNK_SIZE + while len(pending) >= _CHUNK_SIZE: + chunk = sorted(pending[:_CHUNK_SIZE]) + pending = pending[_CHUNK_SIZE:] + chunk = self.comm.bcast(chunk, root=0) + _filter_round_robin(chunk, global_count) + global_count += len(chunk) + del chunk + + # Flush remaining + if pending: + chunk = sorted(pending) + pending = [] + chunk = self.comm.bcast(chunk, root=0) + _filter_round_robin(chunk, global_count) + global_count += len(chunk) + del chunk + # Signal end: broadcast empty list + self.comm.bcast([], root=0) + else: + # Non-root ranks: receive chunks until empty sentinel + while True: + chunk = self.comm.bcast(None, root=0) + if not chunk: + break + _filter_round_robin(chunk, global_count) + global_count += len(chunk) + del chunk + else: - assert (num_subfolders == 0) - fullpaths = [self.storage.get_uri(os.path.join(self.args.data_folder, f"{dataset_type}", entry)) - for entry in filenames if entry.endswith(f'{self.args.format}')] - fullpaths = sorted(fullpaths) - self.logger.debug(f"fullpaths {fullpaths}") - self.logger.debug(f"subfolder {num_subfolders} fullpaths {fullpaths}") + # ── Flat layout: stream in chunks of _CHUNK_SIZE ────── + if self.my_rank == 0: + walk_path = os.path.join(self.args.data_folder, f"{dataset_type}") + filenames = self.storage.walk_node(walk_path) + pending = sorted([ + self.storage.get_uri( + os.path.join(self.args.data_folder, f"{dataset_type}", entry)) + for entry in filenames + if entry.endswith(f'{self.args.format}') + ]) + # Send in chunks + for i in range(0, len(pending), _CHUNK_SIZE): + chunk = pending[i:i + _CHUNK_SIZE] + chunk = self.comm.bcast(chunk, root=0) + _filter_round_robin(chunk, global_count) + global_count += len(chunk) + del chunk + del pending + # Signal end + self.comm.bcast([], root=0) + else: + while True: + chunk = self.comm.bcast(None, root=0) + if not chunk: + break + _filter_round_robin(chunk, global_count) + global_count += len(chunk) + del chunk + + # ── Validation ─────────────────────────────────────────── if dataset_type is DatasetType.TRAIN: - file_list_train = fullpaths + expected = self.num_files_train + else: + expected = self.num_files_eval if self.do_eval else 0 + + if not self.generate_only and expected > global_count: + raise Exception( + "Not enough dataset is found; Please run the code with " + "++workload.workflow.generate_data=True") + + # Floor-division: ensure every rank has the same file count. + # Round-robin gives ranks 0..r-1 one extra file; trim to floor. + effective = min(expected, global_count) if expected > 0 else global_count + files_per_rank = effective // self.comm_size + my_files = my_files[:files_per_rank] + + if dataset_type is DatasetType.TRAIN: + file_list_train = my_files + global_train_count = global_count elif dataset_type is DatasetType.VALID: - file_list_eval = fullpaths - if not self.generate_only and self.num_files_train > len(file_list_train): - raise Exception( - "Not enough training dataset is found; Please run the code with ++workload.workflow.generate_data=True") - if self.do_eval and self.num_files_eval > len(file_list_eval): - raise Exception( - "Not enough evaluation dataset is found; Please run the code with ++workload.workflow.generate_data=True") - if (self.num_files_train < len(file_list_train)): - self.logger.warning( - f"Number of files for training in {os.path.join(self.args.data_folder, f'{DatasetType.TRAIN}')} ({len(file_list_train)}) is more than requested ({self.num_files_train}). A subset of files will be used ") - file_list_train = file_list_train[:self.num_files_train] - if (self.num_files_eval < len(file_list_eval)): - self.logger.warning( - f"Number of files for evaluation in {os.path.join(self.args.data_folder, f'{DatasetType.VALID}')} ({len(file_list_eval)}) is more than requested ({self.num_files_eval}). A subset of files will be used ") - file_list_eval = file_list_eval[:self.num_files_eval] + file_list_eval = my_files + + t_listing_end = time.time() + if self.my_rank == 0: + self.logger.output( + f"{utcnow()} File listing [{dataset_type}]: " + f"{global_count} files discovered, {len(my_files)} assigned to rank 0, " + f"completed in {t_listing_end - t_listing_start:.2f}s") + + if self.my_rank == 0: + self.logger.output( + f"{utcnow()} Streamed file sharding: {global_train_count} train files " + f"across {self.comm_size} ranks via round-robin " + f"(rank 0 shard: {len(file_list_train)} files)") + self.args.derive_configurations(file_list_train, file_list_eval) self.args.validate() self.checkpointing_mechanism = None @@ -275,7 +372,10 @@ def _eval(self, epoch): Evaluation loop will read a separate dataset and has its own own computation time. """ step = 1 - total = math.floor(self.num_samples * self.num_files_eval / self.batch_size_eval / self.comm_size) + if self.args.files_pre_sharded: + total = self.args.eval_steps # agreed via allreduce(MIN) + else: + total = math.floor(self.num_samples * self.num_files_eval / self.batch_size_eval / self.comm_size) loader = self.framework.get_loader(DatasetType.VALID) self.stats.start_loading() for batch in loader.next(): @@ -358,7 +458,10 @@ def _train(self, epoch): """ block = 1 # A continuous period of training steps, ended by checkpointing block_step = overall_step = 1 # Steps are taken within blocks - max_steps = math.floor(self.num_samples * self.num_files_train / self.batch_size / self.comm_size) + if self.args.files_pre_sharded: + max_steps = self.args.training_steps # agreed via allreduce(MIN) + else: + max_steps = math.floor(self.num_samples * self.num_files_train / self.batch_size / self.comm_size) self.steps_per_epoch = max_steps # Start the very first block self.stats.start_block(epoch, block) @@ -421,17 +524,27 @@ def run(self): if (not self.generate_only) and (not self.args.checkpoint_only): # Print out the expected number of steps for each epoch and evaluation if self.my_rank == 0: - total = math.floor(self.num_samples * self.num_files_train / self.batch_size / self.comm_size) - self.logger.output( - f"{utcnow()} Max steps per epoch: {total} = {self.num_samples} * {self.num_files_train} / {self.batch_size} / {self.comm_size} (samples per file * num files / batch size / comm size)") + if self.args.files_pre_sharded: + total = math.floor(self.num_samples * self.num_files_train / self.batch_size) + self.logger.output( + f"{utcnow()} Max steps per epoch per rank: {total} = {self.num_samples} * {self.num_files_train} / {self.batch_size} (samples per file * local files / batch size)") + else: + total = math.floor(self.num_samples * self.num_files_train / self.batch_size / self.comm_size) + self.logger.output( + f"{utcnow()} Max steps per epoch: {total} = {self.num_samples} * {self.num_files_train} / {self.batch_size} / {self.comm_size} (samples per file * num files / batch size / comm size)") if self.total_training_steps > 0: self.logger.output( f"{utcnow()} Total training steps is set to be {self.total_training_steps}. Will only run up to {min(total*self.args.epochs, self.total_training_steps)}" ) if self.do_eval: - total = math.floor(self.num_samples * self.num_files_eval / self.batch_size_eval / self.comm_size) - self.logger.output( - f"{utcnow()} Steps per eval: {total} = {self.num_samples} * {self.num_files_eval} / {self.batch_size_eval} / {self.comm_size} (samples per file * num files / batch size eval / comm size)") + if self.args.files_pre_sharded: + total = math.floor(self.num_samples * self.num_files_eval / self.batch_size_eval) + self.logger.output( + f"{utcnow()} Steps per eval per rank: {total} = {self.num_samples} * {self.num_files_eval} / {self.batch_size_eval} (samples per file * local files / batch size eval)") + else: + total = math.floor(self.num_samples * self.num_files_eval / self.batch_size_eval / self.comm_size) + self.logger.output( + f"{utcnow()} Steps per eval: {total} = {self.num_samples} * {self.num_files_eval} / {self.batch_size_eval} / {self.comm_size} (samples per file * num files / batch size eval / comm size)") # Keep track of the next epoch at which we will evaluate next_eval_epoch = self.eval_after_epoch @@ -443,6 +556,21 @@ def run(self): self.framework.get_loader(dataset_type=DatasetType.TRAIN).read() if self.do_eval: self.framework.get_loader(dataset_type=DatasetType.VALID).read() + + # Pre-warm workers: trigger DataLoader worker spawn before epoch 1. + # Without persistent_workers, workers re-spawn on each iter() call. + # This pre-warm ensures the first epoch doesn't include spawn latency. + train_loader = self.framework.get_loader(dataset_type=DatasetType.TRAIN) + if hasattr(train_loader, '_dataset') and train_loader._dataset is not None: + warmup_iter = iter(train_loader._dataset) + try: + next(warmup_iter) + except StopIteration: + pass + del warmup_iter + if self.my_rank == 0: + self.logger.output(f"{utcnow()} Worker pre-warm complete ({self.args.read_threads} workers spawned)") + self.comm.barrier() # Skip the per-epoch page-cache flush after the first failure so a host # without NOPASSWD sudo doesn't pay the failure cost on every epoch and @@ -489,6 +617,24 @@ def run(self): self.stats.end_eval(epoch) self.framework.get_loader(DatasetType.VALID).finalize() self.args.reconfigure(epoch + 1) # reconfigure once per epoch + # Refresh serialized args so next epoch's workers see resharded file list + train_loader = self.framework.get_loader(dataset_type=DatasetType.TRAIN) + if hasattr(train_loader, 'refresh_args'): + train_loader.refresh_args() + if self.do_eval: + eval_loader = self.framework.get_loader(dataset_type=DatasetType.VALID) + if hasattr(eval_loader, 'refresh_args'): + eval_loader.refresh_args() + # Pre-warm workers for next epoch (spawn + init outside timed window) + if hasattr(train_loader, '_dataset') and train_loader._dataset is not None: + warmup_iter = iter(train_loader._dataset) + try: + next(warmup_iter) + except StopIteration: + pass + del warmup_iter + if self.my_rank == 0: + self.logger.output(f"{utcnow()} Worker pre-warm complete for epoch {epoch + 1} ({self.args.read_threads} workers spawned)") self.stats.end_epoch(epoch) if (self.args.checkpoint_only): @@ -503,6 +649,7 @@ def finalize(self): global dftracer, dftracer_initialize, dftracer_finalize self.comm.barrier() + if self.checkpointing_mechanism: self.checkpointing_mechanism.finalize() if not self.generate_only: diff --git a/dlio_benchmark/reader/_local_fs_iterable_mixin.py b/dlio_benchmark/reader/_local_fs_iterable_mixin.py index 20da8dbb..e681130e 100644 --- a/dlio_benchmark/reader/_local_fs_iterable_mixin.py +++ b/dlio_benchmark/reader/_local_fs_iterable_mixin.py @@ -326,12 +326,19 @@ def _localfs_prefetch_all(self) -> None: self._local_cache = cache def _localfs_ensure_cached(self, filename: str) -> None: - """Fetch a single file on demand if not already in the cache.""" - if filename not in self._local_cache: - if self._use_direct: - self._local_cache.update(self._prefetch_direct([filename])) - else: - self._local_cache[filename] = self._read_local_bytes(filename) + """Read a single file on demand, always re-reading from storage. + + The cache is intentionally NOT used for map-style access so that every + epoch measures real I/O. With persistent_workers=True, reusing cached + byte counts would skip all reads in epochs 2+, producing invalid AU. + """ + if self._use_direct: + result = self._prefetch_direct([filename]) + self._local_cache.update(result) + else: + self._local_cache[filename] = self._read_local_bytes(filename) + self._total_bytes_read += self._local_cache[filename] + self._total_objects_read += 1 def finalize_local_bytes(self) -> None: """ diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py index e513810b..35356229 100644 --- a/dlio_benchmark/utils/config.py +++ b/dlio_benchmark/utils/config.py @@ -218,6 +218,12 @@ class ConfigArguments: pin_memory: bool = True odirect: bool = False + # When True, file_list_train/eval already contain only this rank's shard. + # Sample-level sharding is skipped to avoid double-partitioning. + files_pre_sharded: bool = False + # Number of threads rank 0 uses to list subfolders in parallel. + listing_threads: int = 4 + # derived fields required_samples: int = 1 total_samples_eval: int = 1 @@ -691,35 +697,47 @@ def derive_configurations(self, file_list_train=None, file_list_eval=None): self.resized_image = gen_random_tensor(shape=self.transformed_record_dims, dtype=self.transformed_record_element_dtype, rng=rng) else: self.resized_image = np.random.randint(255, size=(self.max_dimension, self.max_dimension), dtype=np.uint8) - self.file_list_train = file_list_train - self.file_list_eval = file_list_eval + self.file_list_train = list(file_list_train) + self.file_list_eval = list(file_list_eval) self.num_files_eval = len(file_list_eval) self.num_files_train = len(file_list_train) self.total_samples_train = self.num_samples_per_file * len(self.file_list_train) self.total_samples_eval = self.num_samples_per_file * len(self.file_list_eval) - - # The sampler intentionally drops the trailing remainder when the - # total sample count is not divisible by comm_size. Compute the - # validation sums from the effective sample counts so reconfigure() - # validates exactly the indices that are assigned to ranks. - effective_train_samples = ( - self.total_samples_train // self.comm_size - ) * self.comm_size - effective_eval_samples = ( - self.total_samples_eval // self.comm_size - ) * self.comm_size - - self.train_sample_index_sum = ( - effective_train_samples * (effective_train_samples - 1) // 2 - ) - self.eval_sample_index_sum = ( - effective_eval_samples * (effective_eval_samples - 1) // 2 - ) - self.required_samples = self.comm_size * self.batch_size - if self.read_threads > 0: - self.required_samples *= self.read_threads - self.training_steps = int(math.ceil(self.total_samples_train / self.batch_size / self.comm_size)) - self.eval_steps = int(math.ceil(self.total_samples_eval / self.batch_size_eval / self.comm_size)) + if self.files_pre_sharded: + # Files are already distributed across ranks — sample space is local. + # Round-robin gives even file counts, allreduce_min as safety net. + self.train_sample_index_sum = self.total_samples_train * (self.total_samples_train - 1) // 2 + self.eval_sample_index_sum = self.total_samples_eval * (self.total_samples_eval - 1) // 2 + self.required_samples = self.batch_size + if self.read_threads > 0: + self.required_samples *= self.read_threads + local_train_steps = int(math.ceil(self.total_samples_train / self.batch_size)) + local_eval_steps = int(math.ceil(self.total_samples_eval / self.batch_size_eval)) if self.total_samples_eval > 0 else 0 + self.training_steps = DLIOMPI.get_instance().allreduce_min(local_train_steps) + self.eval_steps = DLIOMPI.get_instance().allreduce_min(local_eval_steps) + else: + # The sampler intentionally drops the trailing remainder when the + # total sample count is not divisible by comm_size. Compute the + # validation sums from the effective sample counts so reconfigure() + # validates exactly the indices that are assigned to ranks. + effective_train_samples = ( + self.total_samples_train // self.comm_size + ) * self.comm_size + effective_eval_samples = ( + self.total_samples_eval // self.comm_size + ) * self.comm_size + + self.train_sample_index_sum = ( + effective_train_samples * (effective_train_samples - 1) // 2 + ) + self.eval_sample_index_sum = ( + effective_eval_samples * (effective_eval_samples - 1) // 2 + ) + self.required_samples = self.comm_size * self.batch_size + if self.read_threads > 0: + self.required_samples *= self.read_threads + self.training_steps = int(math.ceil(self.total_samples_train / self.batch_size / self.comm_size)) + self.eval_steps = int(math.ceil(self.total_samples_eval / self.batch_size_eval / self.comm_size)) if self.data_loader_sampler is None and self.data_loader_classname is None: if self.data_loader == DataLoaderType.TENSORFLOW: self.data_loader_sampler = DataLoaderSampler.ITERATIVE @@ -888,24 +906,32 @@ def build_sample_map_iter(self, file_list, total_samples, epoch_number): num_threads = 1 if self.read_threads > 0 and self.data_loader is not DataLoaderType.DALI: num_threads = self.read_threads - # Floor division so every rank gets the same sample count. - # See dlio_sampler in torch_data_loader.py for the rationale — - # mismatched per-rank counts deadlock the per-step / end-of-epoch - # barriers in main._train(). Drops up to (comm_size - 1) samples - # per epoch; warn once from rank 0 when that happens. - samples_per_proc = total_samples // self.comm_size + if self.files_pre_sharded: + # Files already sharded — all local samples belong to this rank. + # Align to minimum across ranks for consistent batch counts. + aligned = DLIOMPI.get_instance().allreduce_min(total_samples) + samples_per_proc = aligned + start_sample_index = 0 + end_sample_index = aligned - 1 + else: + # Floor division so every rank gets the same sample count. + # See dlio_sampler in torch_data_loader.py for the rationale — + # mismatched per-rank counts deadlock the per-step / end-of-epoch + # barriers in main._train(). Drops up to (comm_size - 1) samples + # per epoch; warn once from rank 0 when that happens. + samples_per_proc = total_samples // self.comm_size + start_sample_index = samples_per_proc * self.my_rank + end_sample_index = samples_per_proc * (self.my_rank + 1) - 1 + dropped = total_samples - samples_per_proc * self.comm_size + if dropped > 0 and self.my_rank == 0: + self.logger.warning( + f"build_sample_map_iter: dropping {dropped} sample(s) — " + f"total_samples ({total_samples}) is not a multiple of " + f"comm_size ({self.comm_size}). Each rank will process " + f"{samples_per_proc} samples. Choose total_samples as a " + f"multiple of {self.comm_size} to use every sample." + ) self.samples_per_thread = samples_per_proc // num_threads - start_sample_index = samples_per_proc * self.my_rank - end_sample_index = samples_per_proc * (self.my_rank + 1) - 1 - dropped = total_samples - samples_per_proc * self.comm_size - if dropped > 0 and self.my_rank == 0: - self.logger.warning( - f"build_sample_map_iter: dropping {dropped} sample(s) — " - f"total_samples ({total_samples}) is not a multiple of " - f"comm_size ({self.comm_size}). Each rank will process " - f"{samples_per_proc} samples. Choose total_samples as a " - f"multiple of {self.comm_size} to use every sample." - ) sample_list = np.arange(start_sample_index, end_sample_index + 1) self.logger.debug(f"{self.my_rank} {start_sample_index} {end_sample_index}") if self.sample_shuffle is not Shuffle.OFF: @@ -916,8 +942,12 @@ def build_sample_map_iter(self, file_list, total_samples, epoch_number): np.random.shuffle(sample_list) sample_index = 0 if num_files > 0: - files_per_rank = (num_files // self.comm_size) % num_files - file_index = self.my_rank * files_per_rank + if self.files_pre_sharded: + files_per_rank = num_files + file_index = 0 + else: + files_per_rank = (num_files // self.comm_size) % num_files + file_index = self.my_rank * files_per_rank for thread_index in range(num_threads): process_thread_file_map[thread_index] = [] for sample in sample_list: @@ -931,10 +961,10 @@ def build_sample_map_iter(self, file_list, total_samples, epoch_number): abs_path, sample_list[sample_index] % self.num_samples_per_file)) sample_index += 1 - # Carry the rank offset forward so each rank stays in its own - # file partition. Without the offset, non-zero ranks fall back - # to rank-0's file range on the second and subsequent samples. - file_index = (self.my_rank * files_per_rank + sample_index // self.num_samples_per_file) % num_files + if self.files_pre_sharded: + file_index = (sample_index // self.num_samples_per_file) % num_files + else: + file_index = (self.my_rank * files_per_rank + sample_index // self.num_samples_per_file) % num_files return process_thread_file_map, samples_sum @dlp.log @@ -943,20 +973,27 @@ def get_global_map_index(self, file_list, total_samples, epoch_number): if num_files == 0: return {}, 0 - # Floor division so every rank gets the same sample count. See - # dlio_sampler in torch_data_loader.py for the deadlock rationale. - samples_per_proc = total_samples // self.comm_size - start_sample = self.my_rank * samples_per_proc - end_sample = (self.my_rank + 1) * samples_per_proc - 1 - dropped = total_samples - samples_per_proc * self.comm_size - if dropped > 0 and self.my_rank == 0: - self.logger.warning( - f"get_global_map_index: dropping {dropped} sample(s) — " - f"total_samples ({total_samples}) is not a multiple of " - f"comm_size ({self.comm_size}). Each rank will process " - f"{samples_per_proc} samples. Choose total_samples as a " - f"multiple of {self.comm_size} to use every sample." - ) + if self.files_pre_sharded: + # Files already distributed — each rank owns all its local samples. + # Align to minimum across ranks to match dlio_sampler alignment. + aligned = DLIOMPI.get_instance().allreduce_min(total_samples) + start_sample = 0 + end_sample = aligned - 1 + else: + # Floor division so every rank gets the same sample count. See + # dlio_sampler in torch_data_loader.py for the deadlock rationale. + samples_per_proc = total_samples // self.comm_size + start_sample = self.my_rank * samples_per_proc + end_sample = (self.my_rank + 1) * samples_per_proc - 1 + dropped = total_samples - samples_per_proc * self.comm_size + if dropped > 0 and self.my_rank == 0: + self.logger.warning( + f"get_global_map_index: dropping {dropped} sample(s) — " + f"total_samples ({total_samples}) is not a multiple of " + f"comm_size ({self.comm_size}). Each rank will process " + f"{samples_per_proc} samples. Choose total_samples as a " + f"multiple of {self.comm_size} to use every sample." + ) self.logger.debug(f"my_rank: {self.my_rank}, start_sample: {start_sample}, end_sample: {end_sample}") # Determine shuffle seed (None = no shuffle) @@ -981,16 +1018,91 @@ def get_global_map_index(self, file_list, total_samples, epoch_number): ) return vmap, samples_sum + def _reshard_files(self, epoch_number): + """Re-distribute files across ranks using an epoch-dependent hash. + + Each rank shuffles its local shard with an epoch-dependent seed, + then distributes files round-robin to destination ranks via alltoall. + Since all ranks have the same file count (from initial round-robin), + the result is perfectly balanced — no file loss. + """ + import time as _time + mpi = DLIOMPI.get_instance() + + t_reshard_start = _time.time() + + for attr in ('file_list_train', 'file_list_eval'): + file_list = getattr(self, attr) + paths = list(file_list) + + # Shuffle with epoch-dependent seed so files go to different ranks each epoch + if self.seed_change_epoch: + np.random.seed(self.seed + epoch_number) + else: + np.random.seed(self.seed) + np.random.shuffle(paths) + + # Round-robin assignment to destination ranks, rotated by + # sender rank so remainder files spread evenly across all + # destinations — every rank receives the same total. + buckets = [[] for _ in range(self.comm_size)] + for i, fpath in enumerate(paths): + buckets[(i + self.my_rank) % self.comm_size].append(fpath) + + # alltoall: send buckets[i] to rank i, receive from all ranks + recv_buckets = mpi.alltoall(buckets) + del buckets + + # Flatten received files into new local shard + new_shard = [] + for bucket in recv_buckets: + new_shard.extend(bucket) + del recv_buckets + + # Local shuffle with rank-specific seed for I/O decorrelation + np.random.seed(self.seed + epoch_number + self.my_rank * 31) + np.random.shuffle(new_shard) + + setattr(self, attr, new_shard) + + # Update counts + self.num_files_train = len(self.file_list_train) + self.num_files_eval = len(self.file_list_eval) + self.total_samples_train = self.num_samples_per_file * self.num_files_train + self.total_samples_eval = self.num_samples_per_file * self.num_files_eval + + # Re-align step counts (should be identical across ranks now) + local_train_steps = self.total_samples_train // self.batch_size + self.training_steps = mpi.allreduce_min(local_train_steps) + if self.num_files_eval > 0: + local_eval_steps = self.total_samples_eval // self.batch_size_eval + self.eval_steps = mpi.allreduce_min(local_eval_steps) + + t_reshard_end = _time.time() + if self.my_rank == 0: + self.logger.output( + f"{utcnow()} Reshard for epoch {epoch_number}: " + f"{self.num_files_train} train files, {self.num_files_eval} eval files " + f"redistributed via alltoall in {t_reshard_end - t_reshard_start:.2f}s") + @dlp.log def reconfigure(self, epoch_number): - if self.data_loader_sampler == DataLoaderSampler.ITERATIVE: - if self.file_shuffle is not Shuffle.OFF: + # Reshard files across ranks (when shuffle enabled and pre-sharded) + if self.file_shuffle is not Shuffle.OFF and self.files_pre_sharded: + self._reshard_files(epoch_number) + elif self.file_shuffle is not Shuffle.OFF: + if self.data_loader_sampler == DataLoaderSampler.ITERATIVE: if self.seed_change_epoch: np.random.seed(self.seed + epoch_number) else: np.random.seed(self.seed) - np.random.shuffle(self.file_list_train) - np.random.shuffle(self.file_list_eval) + # Materialize and shuffle + train_list = list(self.file_list_train) + eval_list = list(self.file_list_eval) + np.random.shuffle(train_list) + np.random.shuffle(eval_list) + self.file_list_train = train_list + self.file_list_eval = eval_list local_train_sample_sum = 0 local_eval_sample_sum = 0 if self.data_loader_sampler == DataLoaderSampler.ITERATIVE: @@ -1006,11 +1118,12 @@ def reconfigure(self, epoch_number): global_eval_sample_sum = DLIOMPI.get_instance().reduce(local_eval_sample_sum) if self.my_rank == 0: self.logger.info(f"{utcnow()} Total number of samples: train {global_train_sample_sum}, eval {global_eval_sample_sum}") - if self.train_sample_index_sum != global_train_sample_sum: - raise Exception(f"Sharding of train samples are missing samples got {global_train_sample_sum} but expected {self.train_sample_index_sum}") - - if self.eval_sample_index_sum != global_eval_sample_sum: - raise Exception(f"Sharding of eval samples are missing samples got {global_eval_sample_sum} but expected {self.eval_sample_index_sum}") + if not self.files_pre_sharded: + if self.train_sample_index_sum != global_train_sample_sum: + raise Exception(f"Sharding of train samples are missing samples got {global_train_sample_sum} but expected {self.train_sample_index_sum}") + + if self.eval_sample_index_sum != global_eval_sample_sum: + raise Exception(f"Sharding of eval samples are missing samples got {global_eval_sample_sum} but expected {self.eval_sample_index_sum}") def GetConfig(args, key): keys = key.split(".") diff --git a/dlio_benchmark/utils/utility.py b/dlio_benchmark/utils/utility.py index f937c139..a40bb1b5 100644 --- a/dlio_benchmark/utils/utility.py +++ b/dlio_benchmark/utils/utility.py @@ -344,6 +344,17 @@ def reduce(self, num): raise Exception(f"method {self.classname()}.reduce() called before initializing MPI") else: return MPI.COMM_WORLD.allreduce(num, op=MPI.SUM) + + def allreduce_min(self, value): + from mpi4py import MPI + if self.mpi_state == MPIState.UNINITIALIZED: + raise Exception(f"method {self.classname()}.allreduce_min() called before initializing MPI") + return self.comm().allreduce(value, op=MPI.MIN) + + def alltoall(self, data): + if self.mpi_state == MPIState.UNINITIALIZED: + raise Exception(f"method {self.classname()}.alltoall() called before initializing MPI") + return self.comm().alltoall(data) def finalize(self): from mpi4py import MPI From ab033c3a41021e6261d7388155810c94a9632668 Mon Sep 17 00:00:00 2001 From: Russ Fellows Date: Fri, 19 Jun 2026 12:26:51 -0600 Subject: [PATCH 2/2] fix(sampler): guard ConfigArguments init against uninitialized MPI in unit tests DLIOSampler.__init__ now wraps ConfigArguments.get_instance() in try/except so that tests running without MPI (no mpirun, no mpi4py init) fall back cleanly to the non-pre-sharded path instead of raising an exception. In real distributed runs files_pre_sharded=True is only set after MPI is fully initialized, so the fallback is always correct in production. --- dlio_benchmark/data_loader/torch_data_loader.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dlio_benchmark/data_loader/torch_data_loader.py b/dlio_benchmark/data_loader/torch_data_loader.py index a68529fe..6c7aef24 100644 --- a/dlio_benchmark/data_loader/torch_data_loader.py +++ b/dlio_benchmark/data_loader/torch_data_loader.py @@ -415,8 +415,12 @@ def __init__(self, rank, size, num_samples, epochs): self.rank = rank self.num_samples = num_samples self.epochs = epochs - args = ConfigArguments.get_instance() - if args.files_pre_sharded: + try: + pre_sharded = ConfigArguments.get_instance().files_pre_sharded + except Exception: + # MPI not initialized (e.g. unit tests) — treat as non-pre-sharded. + pre_sharded = False + if pre_sharded: # Files already distributed — this rank owns all its local samples. # Round-robin gives even counts, but allreduce_min as safety net. aligned = DLIOMPI.get_instance().allreduce_min(num_samples)