@@ -9,6 +9,7 @@
 import logging
 import os
 import time
+
 import requests
 
 CHUNKS = 10
@@ -26,6 +27,7 @@
     handlers=[logging.StreamHandler()],
 )
 
+
 def download_shard(url, filename, retry=RETRIES):
     """Download a shard from the given URL and save it to the specified filename."""
     if os.path.exists(filename):
@@ -34,7 +36,7 @@ def download_shard(url, filename, retry=RETRIES):
 
     try:
         response = requests.get(url, timeout=REQUEST_TIMEOUT)
-
+
         if response.status_code == 429 and retry > 0:
             time.sleep(BACKOFF_TIME)
             logging.warning("Throttled. Retrying download for %s...", filename)
@@ -44,76 +46,109 @@ def download_shard(url, filename, retry=RETRIES):
         if response.status_code != 200:
             if retry > 0:
                 time.sleep(BACKOFF_TIME)
-                logging.warning("HTTP %s for %s. Retrying (%d attempts left)...",
-                                response.status_code, filename, retry)
+                logging.warning(
+                    "HTTP %s for %s. Retrying (%d attempts left)...",
+                    response.status_code,
+                    filename,
+                    retry,
+                )
                 download_shard(url, filename, retry=retry - 1)
                 return
             else:
-                logging.error("Failed to download %s: HTTP %s", url, response.status_code)
+                logging.error(
+                    "Failed to download %s: HTTP %s", url, response.status_code
+                )
                 return
 
         with open(filename, "wb") as fn:
             fn.write(response.content)
         logging.info("Downloaded %s", filename)
-
+
     except requests.exceptions.Timeout:
         if retry > 0:
             time.sleep(BACKOFF_TIME)
-            logging.warning("Timeout downloading %s. Retrying (%d attempts left)...", filename, retry)
+            logging.warning(
+                "Timeout downloading %s. Retrying (%d attempts left)...",
+                filename,
+                retry,
+            )
             download_shard(url, filename, retry=retry - 1)
         else:
             logging.error("Timeout downloading %s after %d retries", filename, RETRIES)
     except requests.exceptions.RequestException as e:
         if retry > 0:
             time.sleep(BACKOFF_TIME)
-            logging.warning("Network error downloading %s: %s. Retrying (%d attempts left)...",
-                            filename, str(e), retry)
+            logging.warning(
+                "Network error downloading %s: %s. Retrying (%d attempts left)...",
+                filename,
+                str(e),
+                retry,
+            )
             download_shard(url, filename, retry=retry - 1)
         else:
-            logging.error("Network error downloading %s after %d retries: %s", filename, RETRIES, str(e))
+            logging.error(
+                "Network error downloading %s after %d retries: %s",
+                filename,
+                RETRIES,
+                str(e),
+            )
+
 
-def download(directory, full_dataset=True, sample_files=100, worker_index=0, total_workers=1):
+def download(
+    directory, full_dataset=True, sample_files=100, worker_index=0, total_workers=1
+):
     """Download SlimPajama dataset from Hugging Face with parallel worker support."""
     files_downloaded = 0
     files_to_process = []
-
+
     # First, calculate all files that need to be downloaded
     for chunk in range(1, CHUNKS + 1):
-        shard_limit = SHARDS if full_dataset else min(sample_files // CHUNKS + 1, SHARDS)
+        shard_limit = (
+            SHARDS if full_dataset else min(sample_files // CHUNKS + 1, SHARDS)
+        )
        for shard in range(0, shard_limit):
             if not full_dataset and len(files_to_process) >= sample_files:
                 break
-
+
             filename = f"example_train_chunk{chunk}_shard{shard}.jsonl.zst"
             url = f"{REPOSITORY_PATH}/chunk{chunk}/example_train_{shard}.jsonl.zst"
             files_to_process.append((filename, url))
-
+
         if not full_dataset and len(files_to_process) >= sample_files:
             break
-
+
     # Limit to sample_files if not downloading full dataset
     if not full_dataset:
         files_to_process = files_to_process[:sample_files]
-
+
     # Distribute files across workers using modulo
-    worker_files = [file_info for i, file_info in enumerate(files_to_process) if i % total_workers == worker_index]
-
-    logging.info(f"Worker {worker_index}/{total_workers}: Processing {len(worker_files)} files out of {len(files_to_process)} total files")
-
+    worker_files = [
+        file_info
+        for i, file_info in enumerate(files_to_process)
+        if i % total_workers == worker_index
+    ]
+
+    logging.info(
+        f"Worker {worker_index}/{total_workers}: Processing {len(worker_files)} files out of {len(files_to_process)} total files"
+    )
+
     # Download assigned files
     for filename, url in worker_files:
         full_filename = os.path.join(directory, filename)
         download_shard(url, full_filename)
         files_downloaded += 1
-
-    logging.info(f"Worker {worker_index} completed: Downloaded {files_downloaded} files")
-
+
+    logging.info(
+        f"Worker {worker_index} completed: Downloaded {files_downloaded} files"
+    )
+
     # Create completion marker file
     completion_file = os.path.join(directory, f".download-{worker_index}-complete")
     with open(completion_file, "w") as f:
         f.write(f"Worker {worker_index} completed downloading {files_downloaded} files")
     logging.info(f"Created completion marker: {completion_file}")
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Download SlimPajama from Hugging Face with parallel worker support."
@@ -151,9 +186,9 @@ def download(directory, full_dataset=True, sample_files=100, worker_index=0, tot
 
     os.makedirs(args.directory, exist_ok=True)
     download(
-        args.directory,
-        args.full_dataset,
+        args.directory,
+        args.full_dataset,
         args.sample_files,
         args.worker_index,
-        args.total_workers
+        args.total_workers,
     )
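
Note: the retry shape in download_shard above (sleep, then recurse with one fewer attempt) is easiest to see in isolation. Below is a minimal standalone sketch of the same pattern; the constant values and the fetch_with_retry/action names are assumptions for illustration, since RETRIES and BACKOFF_TIME are defined earlier in the script and not shown in this diff.

import time

RETRIES = 3       # assumed value; the real constant is defined above this diff
BACKOFF_TIME = 5  # assumed value, in seconds

def fetch_with_retry(action, retry=RETRIES):
    """Run action(), retrying up to RETRIES times with a fixed backoff."""
    try:
        return action()
    except IOError:
        if retry > 0:
            time.sleep(BACKOFF_TIME)
            # Recurse with a smaller budget, mirroring
            # download_shard(url, filename, retry=retry - 1)
            return fetch_with_retry(action, retry=retry - 1)
        raise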
0 commit comments
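Likewise, the modulo distribution in download() partitions work with no coordination between workers: index i is claimed only by worker i % total_workers, so the per-worker lists are disjoint and together cover every file. A quick illustrative check (hypothetical file names and worker counts, not part of the script):

files = [f"shard{i}.jsonl.zst" for i in range(10)]  # stand-in file list
total_workers = 3

for worker_index in range(total_workers):
    # Same list-comprehension shape as download()'s worker_files
    worker_files = [f for i, f in enumerate(files) if i % total_workers == worker_index]
    print(worker_index, worker_files)

# Worker 0 gets indices 0, 3, 6, 9; worker 1 gets 1, 4, 7; worker 2 gets 2, 5, 8.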