Skip to content
Merged
101 changes: 100 additions & 1 deletion dlio_benchmark/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,106 @@ def _filter_round_robin(chunk, start_idx):
if (start_idx + i) % self.comm_size == self.my_rank:
my_files.append(fpath)

if num_subfolders > 0:
if self.args.skip_listing:
# ── Deterministic file list (skip S3 listing entirely) ─
# Generate file URIs from DLIO's naming convention without
# any storage API calls or MPI communication. Each rank
# independently computes its own round-robin shard.
# Convention: {file_prefix}_{index:0N}_of_{total}.{format}
# For subfoldered layouts: {subfolder}/{file_prefix}_{index:0N}_of_{total}.{format}
# where subfolder = str(index % num_subfolders).zfill(nd_sf)
num_files_expected = (
self.num_files_train if dataset_type is DatasetType.TRAIN
else (self.num_files_eval if self.do_eval else 0)
)
if num_files_expected > 0:
nd_f = len(str(num_files_expected))
nd_sf = len(str(max(num_subfolders - 1, 0))) if num_subfolders > 0 else 0
for idx in range(self.my_rank, num_files_expected, self.comm_size):
fname = f"{self.args.file_prefix}_{str(idx).zfill(nd_f)}_of_{num_files_expected}.{self.args.format}"
if num_subfolders > 0:
sf = str(idx % num_subfolders).zfill(nd_sf)
rel = os.path.join(sf, fname)
else:
rel = fname
uri = self.storage.get_uri(
os.path.join(self.args.data_folder, f"{dataset_type}", rel))
my_files.append(uri)
global_count = num_files_expected
# ── Sampling validation (rank 0 only) ─────────────
# Confirm the naming convention is correct by checking
# that a sample of files actually exists in storage.
# Always checks the first and last file, plus every
# listing_validation_interval-th file in between.
# If any check fails, raises an informative error.
if self.my_rank == 0 and num_files_expected > 0 and \
self.args.listing_validation_interval > 0:
interval = self.args.listing_validation_interval
val_indices = sorted(
{0, num_files_expected - 1} |
set(range(0, num_files_expected, interval))
)
n_checks = len(val_indices)
# ── Header: tell the user what is about to happen ──
self.logger.output(
f"{utcnow()} skip_listing [{dataset_type}]: validating "
f"{n_checks:,} of {num_files_expected:,} files "
f"(first, last, every {interval:,}) via HEAD requests ...")
failed_uris = []
t_val_start = time.time()
# Report progress every ~10 % of checks, but at least
# every 500 checks and no more often than every 100.
progress_stride = max(100, min(500, n_checks // 10))
for check_num, vidx in enumerate(val_indices):
vfname = f"{self.args.file_prefix}_{str(vidx).zfill(nd_f)}_of_{num_files_expected}.{self.args.format}"
if num_subfolders > 0:
vsf = str(vidx % num_subfolders).zfill(nd_sf)
vrel = os.path.join(vsf, vfname)
else:
vrel = vfname
vuri = self.storage.get_uri(
os.path.join(self.args.data_folder, f"{dataset_type}", vrel))
if not self.storage.file_exists(vuri):
failed_uris.append(vuri)
# Periodic progress line (but not on the very first check)
if check_num > 0 and check_num % progress_stride == 0:
elapsed = time.time() - t_val_start
rate = check_num / elapsed if elapsed > 0 else 0
pct = 100.0 * check_num / n_checks
eta = (n_checks - check_num) / rate if rate > 0 else 0
self.logger.output(
f"{utcnow()} skip_listing [{dataset_type}]: "
f"{check_num:,}/{n_checks:,} checked "
f"({pct:.0f}%) — "
f"{rate:.0f} checks/s — "
f"ETA {eta:.0f}s — "
f"{len(failed_uris)} failed so far")
t_val_end = time.time()
elapsed_total = t_val_end - t_val_start
rate_total = n_checks / elapsed_total if elapsed_total > 0 else 0
if failed_uris:
sample_shown = failed_uris[:3]
raise Exception(
f"skip_listing validation failed: {len(failed_uris)} of "
f"{n_checks:,} sampled files missing in [{dataset_type}] "
f"after {elapsed_total:.1f}s. "
f"First failures: {sample_shown}. "
f"Ensure data was generated with DLIO's standard naming "
f"convention or set skip_listing=False to use directory "
f"listing instead.")
self.logger.output(
f"{utcnow()} skip_listing [{dataset_type}]: validation complete — "
f"all {n_checks:,} samples exist "
f"({elapsed_total:.1f}s, {rate_total:.0f} checks/s); "
f"{len(my_files):,} URIs ready for rank 0 "
f"({global_count:,} total across all ranks)")
elif self.my_rank == 0:
self.logger.output(
f"{utcnow()} skip_listing [{dataset_type}]: generated "
f"{len(my_files):,} file URIs deterministically "
f"({global_count:,} total — validation disabled)")

elif num_subfolders > 0:
# ── Subfoldered layout: stream with chunked bcast ─────
subfolder_names = None
if self.my_rank == 0:
Expand Down
27 changes: 15 additions & 12 deletions dlio_benchmark/reader/_s3_iterable_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,9 @@ def _s3_init(self, opts: dict) -> None:
Raises ``ImportError`` immediately if the configured library is not
installed, rather than deferring failure to the first I/O call.
"""
# storage_library is REQUIRED — there is no default. Every object
# storage workload must explicitly declare which library to use.
self._storage_library: str = opts.get("storage_library")
if self._storage_library is None:
raise ValueError(
"storage_options['storage_library'] is required for S3 readers. "
"Add 'storage_library: <value>' under the 'storage:' section of "
"your workload YAML. Supported values: minio, s3dlio, s3torchconnector."
)
# Default to s3dlio — consistent with how data is generated. Users can
# override by setting storage_library in storage_options.
self._storage_library: str = opts.get("storage_library") or "s3dlio"
self._opts: dict = opts
self._object_cache: dict = {} # obj_key → int (raw byte count only)
self._minio_client = None # cached across epochs for TCP keep-alive
Expand Down Expand Up @@ -562,9 +556,18 @@ def _s3_prefetch_all(self) -> None:
self._object_cache = self._prefetch(obj_keys)

def _s3_ensure_cached(self, filename: str) -> None:
"""Fetch a single object on demand if it is not already in the cache."""
if filename not in self._object_cache:
self._object_cache.update(self._prefetch([filename]))
"""Fetch a single object on demand, always re-fetching from storage.

The cache is intentionally NOT short-circuited so that every epoch
measures real I/O. With persistent_workers=True (still used on the
iterable dataset paths), reusing a cached byte count from a previous
epoch would skip the GET entirely in epochs 2+, producing invalid AU.

This mirrors the fix applied to _localfs_ensure_cached in PR #26 —
that fix covered the local-filesystem map-style path but the identical
guard (``if filename not in self._object_cache``) was not removed here.
"""
self._object_cache.update(self._prefetch([filename]))

def finalize_s3_bytes(self) -> None:
"""
Expand Down
9 changes: 4 additions & 5 deletions dlio_benchmark/reader/reader_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,10 @@ def get_reader(type, dataset_type, thread_index, epoch_number):
elif type == FormatType.TFRECORD:
if _args.odirect == True:
raise Exception("O_DIRECT for %s format is not yet supported." %type)
elif _args.storage_type in (StorageType.S3, StorageType.AISTORE):
storage_library = (getattr(_args, "storage_options", {}) or {}).get("storage_library")
if storage_library in ("s3dlio", "s3torchconnector", "minio"):
from dlio_benchmark.reader.tfrecord_reader_s3_iterable import TFRecordReaderS3Iterable
return TFRecordReaderS3Iterable(dataset_type, thread_index, epoch_number)
elif (getattr(_args, "storage_options", {}) or {}).get("storage_library") == "s3dlio":
# s3dlio handles both s3:// and file:// URIs.
from dlio_benchmark.reader.tfrecord_reader_s3_iterable import TFRecordReaderS3Iterable
return TFRecordReaderS3Iterable(dataset_type, thread_index, epoch_number)
if _args.data_loader == DataLoaderType.NATIVE_DALI:
from dlio_benchmark.reader.dali_tfrecord_reader import DaliTFRecordReader
return DaliTFRecordReader(dataset_type, thread_index, epoch_number)
Expand Down
9 changes: 8 additions & 1 deletion dlio_benchmark/reader/tfrecord_reader_s3_iterable.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"""
# Copyright (c) 2025, UChicago Argonne, LLC. Apache 2.0 License.
from dlio_benchmark.common.constants import MODULE_DATA_READER
from dlio_benchmark.reader.reader_handler import FormatReader
from dlio_benchmark.reader.npy_reader import NPYReader
from dlio_benchmark.reader._s3_iterable_mixin import _S3IterableMixin
from dlio_benchmark.utils.utility import Profile, utcnow
Expand All @@ -52,6 +53,10 @@ class TFRecordReaderS3Iterable(NPYReader, _S3IterableMixin):

_object_cache[filename] holds an int (byte count), same pattern as all
other S3 iterable readers.

Note: read_index() calls FormatReader.read_index() directly to bypass
NPYReader._localfs_ensure_cached() which would attempt a local filesystem
read on an S3 URI.
"""

@dlp.log_init
Expand Down Expand Up @@ -119,7 +124,9 @@ def read_index(self, image_idx, step):
filename, _ = self.global_index_map[image_idx]
self._s3_ensure_cached(filename)
dlp.update(step=step)
return super().read_index(image_idx, step)
# Call FormatReader.read_index() directly — skips NPYReader.read_index()
# which would invoke _localfs_ensure_cached() on an S3 URI and fail.
return FormatReader.read_index(self, image_idx, step)

@dlp.log
def finalize(self):
Expand Down
4 changes: 4 additions & 0 deletions dlio_benchmark/storage/file_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ def get_data(self, id, data, offset=None, length=None):
def isfile(self, id):
return os.path.isfile(id)

def file_exists(self, id):
"""Return True if the local file exists."""
return os.path.isfile(id)

def get_basename(self, id):
return os.path.basename(id)

Expand Down
26 changes: 24 additions & 2 deletions dlio_benchmark/storage/obj_store_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,15 @@ def __init__(self, client, bucket, obj_name):
self.buffer = BytesIO()

def write(self, data):
if isinstance(data, bytes):
if isinstance(data, (bytes, bytearray, memoryview)):
self.buffer.write(data)
else:
self.buffer.write(data.encode())
# Handle buffer-protocol objects (e.g. s3dlio BytesView) that
# are not bytes but support the buffer protocol. bytes() works
# for BytesView, memoryview, bytearray, and any C-extension type
# that implements __buffer__. Calling .encode() on these fails
# with AttributeError — .encode() is a str-only method.
self.buffer.write(bytes(data))

def close(self):
self.buffer.seek(0)
Expand Down Expand Up @@ -591,5 +596,22 @@ def list_objects(self, container_name, prefix=None):
def isfile(self, id):
return super().isfile(self.get_uri(id))

def file_exists(self, id):
"""Return True if the object exists in the store, False otherwise.

Uses s3dlio.exists() for s3dlio backend (HEAD request), or
s3_client.stat_object() for s3torchconnector/minio backends.
"""
uri = self.get_uri(id)
if self.storage_library == "s3dlio":
return self._s3dlio.exists(uri)
else:
bucket_name, object_key = self._normalize_object_key(uri)
try:
self.s3_client.stat_object(bucket_name, object_key)
return True
except Exception:
return False

def get_basename(self, id):
return os.path.basename(id)
19 changes: 17 additions & 2 deletions dlio_benchmark/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,19 @@ class ConfigArguments:
files_pre_sharded: bool = False
# Number of threads rank 0 uses to list subfolders in parallel.
listing_threads: int = 4
# When True, skip S3/filesystem listing entirely and generate file URIs
# deterministically from DLIO's known naming convention:
# {file_prefix}_{index:0N}_of_{num_files}.{format}
# Each rank independently computes its own round-robin shard with zero
# network calls and zero MPI communication. Use this for DLIO-generated
# datasets where filenames are guaranteed to follow this pattern.
# Eliminates multi-hour S3 listing for large datasets (issue #472).
skip_listing: bool = False
# When skip_listing=True, rank 0 verifies that a sample of the generated
# file URIs actually exist in storage before training begins.
# The first file, last file, and every N-th file are checked via HEAD
# (s3dlio.exists() / os.path.isfile()). Set to 0 to disable validation.
listing_validation_interval: int = 1000

# derived fields
required_samples: int = 1
Expand Down Expand Up @@ -366,8 +379,10 @@ def validate(self):
if (self.do_profiling == True) and (self.profiler == Profiler('darshan')):
if ('LD_PRELOAD' not in os.environ or os.environ["LD_PRELOAD"].find("libdarshan") == -1):
raise Exception("Please set darshan runtime library in LD_PRELOAD")
if self.format is FormatType.TFRECORD and (self.data_loader is DataLoaderType.PYTORCH):
raise Exception(f"{self.framework} support for tfrecord is not implemented for {self.data_loader}.")
if self.format is FormatType.TFRECORD and (self.data_loader is DataLoaderType.PYTORCH) and (self.do_train or self.do_eval):
# TFRecordReaderS3Iterable handles pytorch+tfrecord via s3dlio (s3:// and file://).
if (self.storage_options or {}).get("storage_library") != "s3dlio":
raise Exception(f"{self.framework} support for tfrecord is not implemented for {self.data_loader}.")
if (self.framework == FrameworkType.TENSORFLOW and self.data_loader == DataLoaderType.PYTORCH) or (
self.framework == FrameworkType.PYTORCH and self.data_loader == DataLoaderType.TENSORFLOW):
raise Exception("Imcompatible between framework and data_loader setup.")
Expand Down
24 changes: 23 additions & 1 deletion dlio_benchmark/utils/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,20 @@ def classname(cls):

def initialize(self):
from mpi4py import MPI
if self.mpi_state == MPIState.CHILD_INITIALIZED:
# The main process can end up in CHILD_INITIALIZED when
# TorchIterableDatasetSimple.__iter__ calls worker_init(0) directly
# in the main thread (num_workers=0 path). That deserializes
# ConfigArguments via pickle.loads → __setstate__ → DLIOMPI.reset()
# + set_parent_values(), leaving the singleton in CHILD_INITIALIZED.
# If MPI is actually running (MPI.Is_initialized()), we are the
# real MPI process — reset to UNINITIALIZED so initialization
# proceeds normally below. If MPI is not running, we truly are
# in a child process and must refuse.
if MPI.Is_initialized():
self.mpi_state = MPIState.UNINITIALIZED
else:
raise Exception(f"method {self.classname()}.initialize() called in a child process")
if self.mpi_state == MPIState.UNINITIALIZED:
# MPI may have already been initialized by dlio_benchmark_test.py
if not MPI.Is_initialized():
Expand Down Expand Up @@ -346,14 +360,22 @@ def reduce(self, num):
return MPI.COMM_WORLD.allreduce(num, op=MPI.SUM)

def allreduce_min(self, value):
from mpi4py import MPI
if self.mpi_state == MPIState.UNINITIALIZED:
raise Exception(f"method {self.classname()}.allreduce_min() called before initializing MPI")
# Single-rank or child-process (DataLoader worker): no collective needed.
# Child processes can never issue MPI collectives; returning the local
# value is correct for single-rank runs and safe for workers.
if self.mpi_state == MPIState.CHILD_INITIALIZED or self.mpi_size <= 1:
return value
from mpi4py import MPI
return self.comm().allreduce(value, op=MPI.MIN)

def alltoall(self, data):
if self.mpi_state == MPIState.UNINITIALIZED:
raise Exception(f"method {self.classname()}.alltoall() called before initializing MPI")
# Single-rank or child-process: identity operation.
if self.mpi_state == MPIState.CHILD_INITIALIZED or self.mpi_size <= 1:
return data
return self.comm().alltoall(data)

def finalize(self):
Expand Down
7 changes: 4 additions & 3 deletions docs/DLIO-Object-Storage_Analysis.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,11 @@ However, this code path is **only triggered by the TFDataLoader**, which calls `

## 6. One Actual Issue Found (Not Timing-Related)

The `configs/dlio/workload/unet3d_h100_s3dlio.yaml` file still contains the hardcoded endpoint and personal paths that were cleaned from `tests/object-store/`. Specifically:
The `configs/dlio/workload/unet3d_h100_s3dlio.yaml` file still contains a hardcoded
endpoint and personal paths that were cleaned from `tests/object-store/`. Specifically:

- `endpoint_url: http://172.16.1.40:9000`
- `source /home/eval/Documents/Code/mlp-storage/.env` in the comments
- `endpoint_url: <hardcoded-internal-ip>:9000`
- A local filesystem path in the comments

This was outside the scope of the previous cleanup pass and is a separate issue from timing correctness.

Expand Down
2 changes: 1 addition & 1 deletion docs/DLRM-Parquet-S3-Throughput-Analysis.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ This is the only path to full 10 GiB/s with the existing per-file/per-row-group

**Pros:**
- Clean separation: s3dlio stays format-agnostic
- Can be used without s3dlio (e.g. with boto3 backend)
- Can be used without s3dlio (e.g. with a different S3 backend)
- Easier to publish independently

**Cons:**
Expand Down
Loading
Loading