Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 44 additions & 21 deletions dlio_benchmark/data_loader/torch_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,25 +415,37 @@ def __init__(self, rank, size, num_samples, epochs):
self.rank = rank
self.num_samples = num_samples
self.epochs = epochs
# Use floor division so every rank gets the same sample count. With
# math.ceil() the last rank was clamped to fewer samples than its
# peers when num_samples % size != 0; mismatched per-rank batch
# counts caused the per-step and end-of-epoch barriers in main._train()
# to match across iterations and deadlock at the next epoch boundary.
# The trailing (num_samples % size) samples are dropped on purpose;
# pick num_samples as a multiple of comm_size to use every sample.
samples_per_proc = num_samples // size
start_sample = self.rank * samples_per_proc
end_sample = (self.rank + 1) * samples_per_proc - 1
self.indices = list(range(start_sample, end_sample + 1))
dropped = num_samples - samples_per_proc * size
if dropped > 0 and self.rank == 0:
DLIOLogger.get_instance().warning(
f"{utcnow()} dlio_sampler: dropping {dropped} sample(s) — "
f"num_samples ({num_samples}) is not a multiple of comm_size "
f"({size}). Each rank will process {samples_per_proc} samples. "
f"Choose num_samples as a multiple of {size} to use every sample."
)
try:
pre_sharded = ConfigArguments.get_instance().files_pre_sharded
except Exception:
# MPI not initialized (e.g. unit tests) — treat as non-pre-sharded.
pre_sharded = False
if pre_sharded:
# Files already distributed — this rank owns all its local samples.
# Round-robin gives even counts, but allreduce_min as safety net.
aligned = DLIOMPI.get_instance().allreduce_min(num_samples)
self.indices = list(range(aligned))
self.num_samples = aligned
else:
# Use floor division so every rank gets the same sample count. With
# math.ceil() the last rank was clamped to fewer samples than its
# peers when num_samples % size != 0; mismatched per-rank batch
# counts caused the per-step and end-of-epoch barriers in main._train()
# to match across iterations and deadlock at the next epoch boundary.
# The trailing (num_samples % size) samples are dropped on purpose;
# pick num_samples as a multiple of comm_size to use every sample.
samples_per_proc = num_samples // size
start_sample = self.rank * samples_per_proc
end_sample = (self.rank + 1) * samples_per_proc - 1
self.indices = list(range(start_sample, end_sample + 1))
dropped = num_samples - samples_per_proc * size
if dropped > 0 and self.rank == 0:
DLIOLogger.get_instance().warning(
f"{utcnow()} dlio_sampler: dropping {dropped} sample(s) — "
f"num_samples ({num_samples}) is not a multiple of comm_size "
f"({size}). Each rank will process {samples_per_proc} samples. "
f"Choose num_samples as a multiple of {size} to use every sample."
)


def __len__(self):
Expand Down Expand Up @@ -623,8 +635,8 @@ def read(self):
else:
kwargs={'multiprocessing_context':self._args.multiprocessing_context,
'prefetch_factor': prefetch_factor}
if torch.__version__ != '1.3.1':
kwargs['persistent_workers'] = True
# persistent_workers=False: workers re-spawn each epoch to pick up
# resharded file lists from updated serial_args.
pin_memory = self._args.pin_memory and torch.cuda.is_available()
if torch.__version__ == '1.3.1':
if 'prefetch_factor' in kwargs:
Expand Down Expand Up @@ -665,6 +677,17 @@ def next(self):
dlp.update(epoch=self.epoch_number)
dft_ai.update(epoch=self.epoch_number)

def refresh_args(self):
    """Re-serialize ``self._args`` onto the wrapped dataset's ``serial_args``.

    Call after reconfigure() to propagate resharded file lists to workers.
    Without persistent_workers, workers re-spawn on each iter() and call
    worker_init → pickle.loads(serial_args), so overwriting ``serial_args``
    here is what keeps the next worker spawn up to date.

    This is a best-effort refresh: it does nothing when the loader has not
    been built yet (``_dataset`` missing or ``None``), when the loader has no
    ``dataset`` attribute, or when that dataset carries no ``serial_args``.

    Returns:
        None. The only effect is the in-place update of
        ``dataset.serial_args``.
    """
    # Tolerate being called before the loader exists instead of raising
    # AttributeError; there is nothing to refresh in that case.
    loader = getattr(self, '_dataset', None)
    dataset = getattr(loader, 'dataset', None)
    # Only overwrite an existing attribute; never create serial_args on a
    # dataset type that does not use it.
    if hasattr(dataset, 'serial_args'):
        dataset.serial_args = pickle.dumps(self._args)

@dlp.log
def finalize(self):
# When read_threads=0 the reader lives in-process on the dataset object.
Expand Down
Loading
Loading