Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions dlio_benchmark/data_loader/torch_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from dlio_benchmark.common.enumerations import DatasetType, DataLoaderType
from dlio_benchmark.data_loader.base_data_loader import BaseDataLoader
from dlio_benchmark.reader.reader_factory import ReaderFactory
from dlio_benchmark.utils.utility import utcnow, DLIOMPI, Profile, dft_ai
from dlio_benchmark.utils.utility import utcnow, DLIOMPI, DLIOLogger, Profile, dft_ai
from dlio_benchmark.utils.config import ConfigArguments

dlp = Profile(MODULE_DATA_LOADER)
Expand Down Expand Up @@ -415,16 +415,35 @@ def __init__(self, rank, size, num_samples, epochs):
self.rank = rank
self.num_samples = num_samples
self.epochs = epochs
samples_per_proc = int(math.ceil(num_samples/size))
# Use floor division so every rank gets the same sample count. With
# math.ceil() the last rank was clamped to fewer samples than its
# peers when num_samples % size != 0; mismatched per-rank batch
# counts caused the per-step and end-of-epoch barriers in main._train()
# to be paired with one another across different iterations, which
# deadlocked at the next epoch boundary.
# The trailing (num_samples % size) samples are dropped on purpose;
# pick num_samples as a multiple of comm_size to use every sample.
samples_per_proc = num_samples // size
start_sample = self.rank * samples_per_proc
end_sample = (self.rank + 1) * samples_per_proc - 1
if end_sample > num_samples - 1:
end_sample = num_samples - 1
self.indices = list(range(start_sample, end_sample + 1))
dropped = num_samples - samples_per_proc * size
if dropped > 0 and self.rank == 0:
DLIOLogger.get_instance().warning(
f"{utcnow()} dlio_sampler: dropping {dropped} sample(s) — "
f"num_samples ({num_samples}) is not a multiple of comm_size "
f"({size}). Each rank will process {samples_per_proc} samples. "
f"Choose num_samples as a multiple of {size} to use every sample."
)


def __len__(self):
return self.num_samples
# Per-rank shard length — must match what __iter__ yields. Returning
# self.num_samples (the global count) here is a pre-existing bug that
# the floor-division change above makes provable: len(self.indices) is
# now num_samples // size while self.num_samples is still num_samples,
# so any caller that builds len(DataLoader) from len(sampler) would
# over-report by a factor of comm_size.
return len(self.indices)

def __iter__(self):
for sample in self.indices:
Expand Down
53 changes: 46 additions & 7 deletions dlio_benchmark/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,8 +697,24 @@ def derive_configurations(self, file_list_train=None, file_list_eval=None):
self.num_files_train = len(file_list_train)
self.total_samples_train = self.num_samples_per_file * len(self.file_list_train)
self.total_samples_eval = self.num_samples_per_file * len(self.file_list_eval)
self.train_sample_index_sum = self.total_samples_train * (self.total_samples_train - 1) // 2
self.eval_sample_index_sum = self.total_samples_eval * (self.total_samples_eval - 1) // 2

# The sampler intentionally drops the trailing remainder when the
# total sample count is not divisible by comm_size. Compute the
# validation sums from the effective sample counts so reconfigure()
# validates exactly the indices that are assigned to ranks.
effective_train_samples = (
self.total_samples_train // self.comm_size
) * self.comm_size
effective_eval_samples = (
self.total_samples_eval // self.comm_size
) * self.comm_size

self.train_sample_index_sum = (
effective_train_samples * (effective_train_samples - 1) // 2
)
self.eval_sample_index_sum = (
effective_eval_samples * (effective_eval_samples - 1) // 2
)
self.required_samples = self.comm_size * self.batch_size
if self.read_threads > 0:
self.required_samples *= self.read_threads
Expand Down Expand Up @@ -872,12 +888,24 @@ def build_sample_map_iter(self, file_list, total_samples, epoch_number):
num_threads = 1
if self.read_threads > 0 and self.data_loader is not DataLoaderType.DALI:
num_threads = self.read_threads
samples_per_proc = int(math.ceil(total_samples/self.comm_size))
# Floor division so every rank gets the same sample count.
# See dlio_sampler in torch_data_loader.py for the rationale —
# mismatched per-rank counts deadlock the per-step / end-of-epoch
# barriers in main._train(). Drops up to (comm_size - 1) samples
# per epoch; warn once from rank 0 when that happens.
samples_per_proc = total_samples // self.comm_size
self.samples_per_thread = samples_per_proc // num_threads
start_sample_index = samples_per_proc * self.my_rank
end_sample_index = samples_per_proc * (self.my_rank + 1) - 1
if end_sample_index > total_samples - 1:
end_sample_index = total_samples - 1
dropped = total_samples - samples_per_proc * self.comm_size
if dropped > 0 and self.my_rank == 0:
self.logger.warning(
f"build_sample_map_iter: dropping {dropped} sample(s) — "
f"total_samples ({total_samples}) is not a multiple of "
f"comm_size ({self.comm_size}). Each rank will process "
f"{samples_per_proc} samples. Choose total_samples as a "
f"multiple of {self.comm_size} to use every sample."
)
sample_list = np.arange(start_sample_index, end_sample_index + 1)
self.logger.debug(f"{self.my_rank} {start_sample_index} {end_sample_index}")
if self.sample_shuffle is not Shuffle.OFF:
Expand Down Expand Up @@ -915,9 +943,20 @@ def get_global_map_index(self, file_list, total_samples, epoch_number):
if num_files == 0:
return {}, 0

samples_per_proc = int(math.ceil(total_samples / self.comm_size))
# Floor division so every rank gets the same sample count. See
# dlio_sampler in torch_data_loader.py for the deadlock rationale.
samples_per_proc = total_samples // self.comm_size
start_sample = self.my_rank * samples_per_proc
end_sample = min((self.my_rank + 1) * samples_per_proc - 1, total_samples - 1)
end_sample = (self.my_rank + 1) * samples_per_proc - 1
dropped = total_samples - samples_per_proc * self.comm_size
if dropped > 0 and self.my_rank == 0:
self.logger.warning(
f"get_global_map_index: dropping {dropped} sample(s) — "
f"total_samples ({total_samples}) is not a multiple of "
f"comm_size ({self.comm_size}). Each rank will process "
f"{samples_per_proc} samples. Choose total_samples as a "
f"multiple of {self.comm_size} to use every sample."
)
self.logger.debug(f"my_rank: {self.my_rank}, start_sample: {start_sample}, end_sample: {end_sample}")

# Determine shuffle seed (None = no shuffle)
Expand Down
52 changes: 52 additions & 0 deletions tests/test_dlio_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Regression tests for dlio_sampler — per-rank equality and Sampler contract.

Pins the fix for mlcommons/storage#455: inter-epoch deadlock caused by
math.ceil(N/size) producing unequal per-rank batch counts when N is not a
multiple of comm_size.
"""

from dlio_benchmark.data_loader.torch_data_loader import dlio_sampler


def test_dlio_sampler_equalizes_uneven_rank_counts():
    """Check that an uneven split still hands every rank an equal shard.

    With N=100 samples over 7 ranks, the previous ceil+clamp scheme gave
    shard sizes [15,15,15,15,15,15,10], so the last rank ran fewer batches
    than its peers under drop_last. The sampler is now expected to give
    each rank 14 samples (4 whole batches of 3) and leave the trailing 2
    samples unassigned.
    """
    num_samples, world_size, samples_per_batch = 100, 7, 3

    # Materialise each rank's shard and record how many indices it yields.
    shard_sizes = []
    for rank in range(world_size):
        shard = list(dlio_sampler(rank, world_size, num_samples, epochs=1))
        shard_sizes.append(len(shard))

    # Whole batches per rank, i.e. what a drop_last loader would produce.
    batch_counts = [count // samples_per_batch for count in shard_sizes]

    assert shard_sizes == [14] * 7
    assert batch_counts == [4] * 7
    assert num_samples - sum(shard_sizes) == 2


def test_dlio_sampler_len_matches_iterator_length():
    """Check len(sampler) against the number of indices the sampler yields.

    Runs for every rank of an uneven split (N=100, size=7) so the reported
    length is verified for each rank's shard.
    """
    num_samples, world_size = 100, 7

    for rank in range(world_size):
        shard_sampler = dlio_sampler(rank, world_size, num_samples, epochs=1)
        # Query the reported length before iterating, as the assertion
        # has always evaluated its left-hand side first.
        reported = len(shard_sampler)
        yielded = list(iter(shard_sampler))
        assert reported == len(yielded)


def test_dlio_sampler_even_division_unchanged():
    """Check that an exact split assigns every sample and drops none.

    With N=100 samples over 10 ranks, each rank must receive exactly 10
    samples, leaving no remainder.
    """
    num_samples, world_size = 100, 10

    shard_sizes = []
    for rank in range(world_size):
        shard = list(dlio_sampler(rank, world_size, num_samples, epochs=1))
        shard_sizes.append(len(shard))

    assert shard_sizes == [10] * 10
    assert num_samples - sum(shard_sizes) == 0
Loading