
Commit e3c8fcc

Copilot authored and super-linter committed
super-linter: fix linting issues [skip ci]
1 parent 701afe4 commit e3c8fcc

10 files changed: +902 -765 lines changed


examples/megatron-lm/GPT3-175B/aks/helm/megatron-training/scripts/train_megatron.sh

Lines changed: 4 additions & 5 deletions
@@ -42,10 +42,10 @@ export OMPI_MCA_coll_hcoll_enable=0 \

 if [ "$USE_SHARP" -eq 1 ]; then
     export SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \
-        SHARP_COLL_ENABLE_SAT=1 \
-        SHARP_COLL_LOG_LEVEL=3 \
-        SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \
-        NCCL_COLLNET_ENABLE=1
+        SHARP_COLL_ENABLE_SAT=1 \
+        SHARP_COLL_LOG_LEVEL=3 \
+        SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \
+        NCCL_COLLNET_ENABLE=1
 fi

 export NCCL_TOPO_FILE=$TOPO_FILE

@@ -61,7 +61,6 @@ VOCAB_FILE=${VOCAB_FILE:-$DATA_PATH/../bpe/vocab.json}
 MERGE_FILE=${MERGE_FILE:-$DATA_PATH/../bpe/merges.txt}
 DATA_CACHE_DIR=${DATA_CACHE_DIR:-$STORAGE_MOUNT/datacache}

-
 DATA_SET_SIZE=$(find $DATA_PATH -name "*.bin" -type f | wc -l)

 readarray -t TRAIN_DATA < <(find $DATA_PATH -name "*.bin" -type f | sort | head -n $(($DATA_SET_SIZE - $CHUNKS - $CHUNKS)) | xargs -n 1 echo 1.0 | sed "s/.bin//g")
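
The readarray pipeline above is the densest line in this hunk, so a short illustration may help. The following is a hypothetical, self-contained sketch (not part of this commit): it lists the .bin shards under $DATA_PATH, keeps all but the last 2*CHUNKS of them, and prefixes each surviving path with a 1.0 weight while stripping the .bin suffix (Megatron-LM data-path arguments take weighted path prefixes rather than the .bin files themselves). The demo directory, shard names, and CHUNKS value are invented.

# Hypothetical demo only: the directory, shard names, and CHUNKS value are invented.
CHUNKS=1
DATA_PATH=/tmp/demo-data
mkdir -p "$DATA_PATH"
touch "$DATA_PATH"/shard_{0..4}.bin

# Count the shards, as train_megatron.sh does.
DATA_SET_SIZE=$(find "$DATA_PATH" -name "*.bin" -type f | wc -l)

# Keep all but the last 2*CHUNKS shards, prefix each with weight 1.0,
# and drop the .bin suffix.
readarray -t TRAIN_DATA < <(find "$DATA_PATH" -name "*.bin" -type f | sort \
    | head -n $((DATA_SET_SIZE - CHUNKS - CHUNKS)) \
    | xargs -n 1 echo 1.0 | sed "s/.bin//g")

printf '%s\n' "${TRAIN_DATA[@]}"
# Prints (5 shards, CHUNKS=1, so the first 3 are kept):
# 1.0 /tmp/demo-data/shard_0
# 1.0 /tmp/demo-data/shard_1
# 1.0 /tmp/demo-data/shard_2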

examples/megatron-lm/GPT3-175B/aks/helm/prepare-data/scripts/concatenate.py

Lines changed: 40 additions & 20 deletions
@@ -1,6 +1,6 @@
-import os
 import argparse
 import logging
+import os
 from glob import glob

 logging.basicConfig(

@@ -9,58 +9,78 @@
     handlers=[logging.StreamHandler()],
 )

-def concatenate(input_directory="", output_directory="", worker_index=0, total_workers=1):
+
+def concatenate(
+    input_directory="", output_directory="", worker_index=0, total_workers=1
+):
     shards_per_file = 1200
     files = sorted(glob(os.path.join(input_directory, "example_train_chunk*.jsonl")))
     num_files = len(files)
-
+
     logging.info(f"Input directory: {input_directory}")
     logging.info(f"Output directory: {output_directory}")
     logging.info(f"Found {num_files} files to process")
-
+
     # Find the ceiling of the result
-    shards = ((num_files + shards_per_file - 1) // shards_per_file)
-
-    logging.info(f"Creating {shards} combined chunk(s) comprising {shards_per_file} files each")
-
+    shards = (num_files + shards_per_file - 1) // shards_per_file
+
+    logging.info(
+        f"Creating {shards} combined chunk(s) comprising {shards_per_file} files each"
+    )
+
     # Ensure output directory exists
     os.makedirs(output_directory, exist_ok=True)
-
+
     chunks_processed = 0
     for i in range(shards):
         if ((i - worker_index) % total_workers) != 0:
             continue
-
+
         file_start = i * shards_per_file
-
+
         if ((i + 1) * shards_per_file) >= len(files):
             file_stop = len(files)
         else:
             file_stop = (i + 1) * shards_per_file
-
+
         logging.info(f"Building chunk {i} with files {file_start} to {file_stop}")
-
+
         output_file = os.path.join(output_directory, f"slim_pajama_{i}.jsonl")
         with open(output_file, "w") as outf:
             for file_idx in range(file_start, min(file_stop, len(files))):
                 with open(files[file_idx], "r") as inf:
                     outf.write(inf.read())
-
+
         chunks_processed += 1
-
+
     # Create completion marker file in output directory
-    completion_file = os.path.join(output_directory, f".concatenate-{worker_index}-complete")
+    completion_file = os.path.join(
+        output_directory, f".concatenate-{worker_index}-complete"
+    )
     with open(completion_file, "w") as f:
-        f.write(f"Worker {worker_index} completed concatenating {chunks_processed} chunks")
+        f.write(
+            f"Worker {worker_index} completed concatenating {chunks_processed} chunks"
+        )
     logging.info(f"Created completion marker: {completion_file}")

+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Concatenate JSONL files")
-    parser.add_argument("--input-directory", type=str, required=True, help="Directory containing input files")
-    parser.add_argument("--output-directory", type=str, required=True, help="Directory to write concatenated files")
+    parser.add_argument(
+        "--input-directory",
+        type=str,
+        required=True,
+        help="Directory containing input files",
+    )
+    parser.add_argument(
+        "--output-directory",
+        type=str,
+        required=True,
+        help="Directory to write concatenated files",
+    )
     parser.add_argument("--worker-index", type=int, default=0, help="Worker index")
     parser.add_argument("--total-workers", type=int, default=1, help="Total workers")
-
+
     args = parser.parse_args()

     # Handle backward compatibility
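
Since the reformatting touches the argparse flags, here is a hypothetical invocation (not part of this commit) showing how concatenate.py could be run with two parallel workers. Each worker only builds the combined chunks whose index i satisfies (i - worker_index) % total_workers == 0, so the processes never write the same slim_pajama_<i>.jsonl file. The directories and worker count below are placeholders.

#!/usr/bin/env bash
# Hypothetical example: the input/output paths and worker count are invented.
set -euo pipefail

INPUT_DIR=/mnt/storage/slimpajama/raw
OUTPUT_DIR=/mnt/storage/slimpajama/combined
TOTAL_WORKERS=2

for WORKER in $(seq 0 $((TOTAL_WORKERS - 1))); do
    # Worker 0 builds chunks 0, 2, 4, ...; worker 1 builds chunks 1, 3, 5, ...
    python concatenate.py \
        --input-directory "$INPUT_DIR" \
        --output-directory "$OUTPUT_DIR" \
        --worker-index "$WORKER" \
        --total-workers "$TOTAL_WORKERS" &
done
wait

# Each worker leaves a .concatenate-<index>-complete marker when it finishes.
ls "$OUTPUT_DIR"/.concatenate-*-complete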

examples/megatron-lm/GPT3-175B/aks/helm/prepare-data/scripts/download_slimpajama.py

Lines changed: 61 additions & 26 deletions
@@ -9,6 +9,7 @@
 import logging
 import os
 import time
+
 import requests

 CHUNKS = 10

@@ -26,6 +27,7 @@
     handlers=[logging.StreamHandler()],
 )

+
 def download_shard(url, filename, retry=RETRIES):
     """Download a shard from the given URL and save it to the specified filename."""
     if os.path.exists(filename):

@@ -34,7 +36,7 @@ def download_shard(url, filename, retry=RETRIES):

     try:
         response = requests.get(url, timeout=REQUEST_TIMEOUT)
-
+
         if response.status_code == 429 and retry > 0:
             time.sleep(BACKOFF_TIME)
             logging.warning("Throttled. Retrying download for %s...", filename)

@@ -44,76 +46,109 @@ def download_shard(url, filename, retry=RETRIES):
         if response.status_code != 200:
             if retry > 0:
                 time.sleep(BACKOFF_TIME)
-                logging.warning("HTTP %s for %s. Retrying (%d attempts left)...",
-                                response.status_code, filename, retry)
+                logging.warning(
+                    "HTTP %s for %s. Retrying (%d attempts left)...",
+                    response.status_code,
+                    filename,
+                    retry,
+                )
                 download_shard(url, filename, retry=retry - 1)
                 return
             else:
-                logging.error("Failed to download %s: HTTP %s", url, response.status_code)
+                logging.error(
+                    "Failed to download %s: HTTP %s", url, response.status_code
+                )
                 return

         with open(filename, "wb") as fn:
             fn.write(response.content)
         logging.info("Downloaded %s", filename)
-
+
     except requests.exceptions.Timeout:
         if retry > 0:
             time.sleep(BACKOFF_TIME)
-            logging.warning("Timeout downloading %s. Retrying (%d attempts left)...", filename, retry)
+            logging.warning(
+                "Timeout downloading %s. Retrying (%d attempts left)...",
+                filename,
+                retry,
+            )
             download_shard(url, filename, retry=retry - 1)
         else:
             logging.error("Timeout downloading %s after %d retries", filename, RETRIES)
     except requests.exceptions.RequestException as e:
         if retry > 0:
             time.sleep(BACKOFF_TIME)
-            logging.warning("Network error downloading %s: %s. Retrying (%d attempts left)...",
-                            filename, str(e), retry)
+            logging.warning(
+                "Network error downloading %s: %s. Retrying (%d attempts left)...",
+                filename,
+                str(e),
+                retry,
+            )
             download_shard(url, filename, retry=retry - 1)
         else:
-            logging.error("Network error downloading %s after %d retries: %s", filename, RETRIES, str(e))
+            logging.error(
+                "Network error downloading %s after %d retries: %s",
+                filename,
+                RETRIES,
+                str(e),
+            )
+

-def download(directory, full_dataset=True, sample_files=100, worker_index=0, total_workers=1):
+def download(
+    directory, full_dataset=True, sample_files=100, worker_index=0, total_workers=1
+):
     """Download SlimPajama dataset from Hugging Face with parallel worker support."""
     files_downloaded = 0
     files_to_process = []
-
+
     # First, calculate all files that need to be downloaded
     for chunk in range(1, CHUNKS + 1):
-        shard_limit = SHARDS if full_dataset else min(sample_files // CHUNKS + 1, SHARDS)
+        shard_limit = (
+            SHARDS if full_dataset else min(sample_files // CHUNKS + 1, SHARDS)
+        )
         for shard in range(0, shard_limit):
             if not full_dataset and len(files_to_process) >= sample_files:
                 break
-
+
             filename = f"example_train_chunk{chunk}_shard{shard}.jsonl.zst"
             url = f"{REPOSITORY_PATH}/chunk{chunk}/example_train_{shard}.jsonl.zst"
             files_to_process.append((filename, url))
-
+
         if not full_dataset and len(files_to_process) >= sample_files:
             break
-
+
     # Limit to sample_files if not downloading full dataset
     if not full_dataset:
         files_to_process = files_to_process[:sample_files]
-
+
     # Distribute files across workers using modulo
-    worker_files = [file_info for i, file_info in enumerate(files_to_process) if i % total_workers == worker_index]
-
-    logging.info(f"Worker {worker_index}/{total_workers}: Processing {len(worker_files)} files out of {len(files_to_process)} total files")
-
+    worker_files = [
+        file_info
+        for i, file_info in enumerate(files_to_process)
+        if i % total_workers == worker_index
+    ]
+
+    logging.info(
+        f"Worker {worker_index}/{total_workers}: Processing {len(worker_files)} files out of {len(files_to_process)} total files"
+    )
+
     # Download assigned files
     for filename, url in worker_files:
         full_filename = os.path.join(directory, filename)
         download_shard(url, full_filename)
         files_downloaded += 1
-
-    logging.info(f"Worker {worker_index} completed: Downloaded {files_downloaded} files")
-
+
+    logging.info(
+        f"Worker {worker_index} completed: Downloaded {files_downloaded} files"
+    )
+
     # Create completion marker file
     completion_file = os.path.join(directory, f".download-{worker_index}-complete")
     with open(completion_file, "w") as f:
         f.write(f"Worker {worker_index} completed downloading {files_downloaded} files")
     logging.info(f"Created completion marker: {completion_file}")

+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Download SlimPajama from Hugging Face with parallel worker support."

@@ -151,9 +186,9 @@ def download(directory, full_dataset=True, sample_files=100, worker_index=0, tot

     os.makedirs(args.directory, exist_ok=True)
     download(
-        args.directory,
-        args.full_dataset,
+        args.directory,
+        args.full_dataset,
         args.sample_files,
         args.worker_index,
-        args.total_workers
+        args.total_workers,
     )
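
The .download-<index>-complete markers written at the end of download(), like the .concatenate-* markers above, give downstream steps something concrete to poll. Below is a rough, hypothetical sketch (not part of this commit) of a consumer that blocks until every download worker has finished; the data directory and worker count are placeholders.

#!/usr/bin/env bash
# Hypothetical sketch: DATA_DIR and TOTAL_WORKERS are placeholders.
set -euo pipefail

DATA_DIR=/mnt/storage/slimpajama/raw
TOTAL_WORKERS=4

# Block until every worker has written its .download-<index>-complete marker.
for ((i = 0; i < TOTAL_WORKERS; i++)); do
    until [ -f "$DATA_DIR/.download-$i-complete" ]; do
        echo "Waiting for download worker $i..."
        sleep 30
    done
done
echo "All $TOTAL_WORKERS download workers finished."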
