diff --git a/dlio_benchmark/data_loader/torch_data_loader.py b/dlio_benchmark/data_loader/torch_data_loader.py index 3ca9ccf2..f47e4eaa 100644 --- a/dlio_benchmark/data_loader/torch_data_loader.py +++ b/dlio_benchmark/data_loader/torch_data_loader.py @@ -25,7 +25,7 @@ from dlio_benchmark.common.enumerations import DatasetType, DataLoaderType from dlio_benchmark.data_loader.base_data_loader import BaseDataLoader from dlio_benchmark.reader.reader_factory import ReaderFactory -from dlio_benchmark.utils.utility import utcnow, DLIOMPI, Profile, dft_ai +from dlio_benchmark.utils.utility import utcnow, DLIOMPI, DLIOLogger, Profile, dft_ai from dlio_benchmark.utils.config import ConfigArguments dlp = Profile(MODULE_DATA_LOADER) @@ -415,16 +415,35 @@ def __init__(self, rank, size, num_samples, epochs): self.rank = rank self.num_samples = num_samples self.epochs = epochs - samples_per_proc = int(math.ceil(num_samples/size)) + # Use floor division so every rank gets the same sample count. With + # math.ceil() the last rank was clamped to fewer samples than its + # peers when num_samples % size != 0; mismatched per-rank batch + # counts caused the per-step and end-of-epoch barriers in main._train() + # to match across iterations and deadlock at the next epoch boundary. + # The trailing (num_samples % size) samples are dropped on purpose; + # pick num_samples as a multiple of comm_size to use every sample. + samples_per_proc = num_samples // size start_sample = self.rank * samples_per_proc end_sample = (self.rank + 1) * samples_per_proc - 1 - if end_sample > num_samples - 1: - end_sample = num_samples - 1 self.indices = list(range(start_sample, end_sample + 1)) + dropped = num_samples - samples_per_proc * size + if dropped > 0 and self.rank == 0: + DLIOLogger.get_instance().warning( + f"{utcnow()} dlio_sampler: dropping {dropped} sample(s) — " + f"num_samples ({num_samples}) is not a multiple of comm_size " + f"({size}). Each rank will process {samples_per_proc} samples. " + f"Choose num_samples as a multiple of {size} to use every sample." + ) def __len__(self): - return self.num_samples + # Per-rank shard length — must match what __iter__ yields. Returning + # self.num_samples (the global count) here is a pre-existing bug that + # the floor-division change above makes provable: len(self.indices) is + # now num_samples // size while self.num_samples is still num_samples, + # so any caller that builds len(DataLoader) from len(sampler) would + # over-report by a factor of comm_size. + return len(self.indices) def __iter__(self): for sample in self.indices: diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py index cc6f9a80..e513810b 100644 --- a/dlio_benchmark/utils/config.py +++ b/dlio_benchmark/utils/config.py @@ -697,8 +697,24 @@ def derive_configurations(self, file_list_train=None, file_list_eval=None): self.num_files_train = len(file_list_train) self.total_samples_train = self.num_samples_per_file * len(self.file_list_train) self.total_samples_eval = self.num_samples_per_file * len(self.file_list_eval) - self.train_sample_index_sum = self.total_samples_train * (self.total_samples_train - 1) // 2 - self.eval_sample_index_sum = self.total_samples_eval * (self.total_samples_eval - 1) // 2 + + # The sampler intentionally drops the trailing remainder when the + # total sample count is not divisible by comm_size. Compute the + # validation sums from the effective sample counts so reconfigure() + # validates exactly the indices that are assigned to ranks. + effective_train_samples = ( + self.total_samples_train // self.comm_size + ) * self.comm_size + effective_eval_samples = ( + self.total_samples_eval // self.comm_size + ) * self.comm_size + + self.train_sample_index_sum = ( + effective_train_samples * (effective_train_samples - 1) // 2 + ) + self.eval_sample_index_sum = ( + effective_eval_samples * (effective_eval_samples - 1) // 2 + ) self.required_samples = self.comm_size * self.batch_size if self.read_threads > 0: self.required_samples *= self.read_threads @@ -872,12 +888,24 @@ def build_sample_map_iter(self, file_list, total_samples, epoch_number): num_threads = 1 if self.read_threads > 0 and self.data_loader is not DataLoaderType.DALI: num_threads = self.read_threads - samples_per_proc = int(math.ceil(total_samples/self.comm_size)) + # Floor division so every rank gets the same sample count. + # See dlio_sampler in torch_data_loader.py for the rationale — + # mismatched per-rank counts deadlock the per-step / end-of-epoch + # barriers in main._train(). Drops up to (comm_size - 1) samples + # per epoch; warn once from rank 0 when that happens. + samples_per_proc = total_samples // self.comm_size self.samples_per_thread = samples_per_proc // num_threads start_sample_index = samples_per_proc * self.my_rank end_sample_index = samples_per_proc * (self.my_rank + 1) - 1 - if end_sample_index > total_samples - 1: - end_sample_index = total_samples - 1 + dropped = total_samples - samples_per_proc * self.comm_size + if dropped > 0 and self.my_rank == 0: + self.logger.warning( + f"build_sample_map_iter: dropping {dropped} sample(s) — " + f"total_samples ({total_samples}) is not a multiple of " + f"comm_size ({self.comm_size}). Each rank will process " + f"{samples_per_proc} samples. Choose total_samples as a " + f"multiple of {self.comm_size} to use every sample." + ) sample_list = np.arange(start_sample_index, end_sample_index + 1) self.logger.debug(f"{self.my_rank} {start_sample_index} {end_sample_index}") if self.sample_shuffle is not Shuffle.OFF: @@ -915,9 +943,20 @@ def get_global_map_index(self, file_list, total_samples, epoch_number): if num_files == 0: return {}, 0 - samples_per_proc = int(math.ceil(total_samples / self.comm_size)) + # Floor division so every rank gets the same sample count. See + # dlio_sampler in torch_data_loader.py for the deadlock rationale. + samples_per_proc = total_samples // self.comm_size start_sample = self.my_rank * samples_per_proc - end_sample = min((self.my_rank + 1) * samples_per_proc - 1, total_samples - 1) + end_sample = (self.my_rank + 1) * samples_per_proc - 1 + dropped = total_samples - samples_per_proc * self.comm_size + if dropped > 0 and self.my_rank == 0: + self.logger.warning( + f"get_global_map_index: dropping {dropped} sample(s) — " + f"total_samples ({total_samples}) is not a multiple of " + f"comm_size ({self.comm_size}). Each rank will process " + f"{samples_per_proc} samples. Choose total_samples as a " + f"multiple of {self.comm_size} to use every sample." + ) self.logger.debug(f"my_rank: {self.my_rank}, start_sample: {start_sample}, end_sample: {end_sample}") # Determine shuffle seed (None = no shuffle) diff --git a/tests/test_dlio_sampler.py b/tests/test_dlio_sampler.py new file mode 100644 index 00000000..e427bcd1 --- /dev/null +++ b/tests/test_dlio_sampler.py @@ -0,0 +1,52 @@ +"""Regression tests for dlio_sampler — per-rank equality and Sampler contract. + +Pins the fix for mlcommons/storage#455: inter-epoch deadlock caused by +math.ceil(N/size) producing unequal per-rank batch counts when N is not a +multiple of comm_size. +""" + +from dlio_benchmark.data_loader.torch_data_loader import dlio_sampler + + +def test_dlio_sampler_equalizes_uneven_rank_counts(): + """Every rank gets the same per-rank shard, with trailing samples dropped. + + The original ceil+clamp produced [15,15,15,15,15,15,10] for (N=100, size=7); + the deadlock comes from the last rank doing fewer batches under drop_last. + """ + total = 100 + size = 7 + batch_size = 3 + + per_rank_samples = [ + len(list(dlio_sampler(rank, size, total, epochs=1))) + for rank in range(size) + ] + per_rank_batches = [n // batch_size for n in per_rank_samples] + + assert per_rank_samples == [14] * 7 + assert per_rank_batches == [4] * 7 + assert total - sum(per_rank_samples) == 2 + + +def test_dlio_sampler_len_matches_iterator_length(): + """PyTorch Sampler contract: len(sampler) == len(list(iter(sampler))).""" + total = 100 + size = 7 + + for rank in range(size): + sampler = dlio_sampler(rank, size, total, epochs=1) + assert len(sampler) == len(list(iter(sampler))) + + +def test_dlio_sampler_even_division_unchanged(): + """When N is a multiple of size, behavior is identical to the old impl.""" + total = 100 + size = 10 + + per_rank_samples = [ + len(list(dlio_sampler(rank, size, total, epochs=1))) + for rank in range(size) + ] + assert per_rank_samples == [10] * 10 + assert total - sum(per_rank_samples) == 0