 from lighteval.tasks.requests import Doc, SamplingMethod
 from lighteval.utils.utils import as_list

-
 logger = logging.getLogger(__name__)


@@ -58,7 +57,9 @@ def __str__(self):
         return f"{self.task_name} ({self.task_hash}, {self.sampling_method.name})"

     def __hash__(self):
-        return int.from_bytes(hashlib.sha256(str(self).encode()).digest(), byteorder="big")
+        return int.from_bytes(
+            hashlib.sha256(str(self).encode()).digest(), byteorder="big"
+        )


 class SampleCache:
@@ -84,7 +85,9 @@ def __init__(self, model_config: ModelConfig):
         self.model_hash = self.get_model_hash(model_config)

         self.cache_dir = (
-            Path(os.path.expanduser(self.model_config.cache_dir)) / self.model_config.model_name / self.model_hash
+            Path(os.path.expanduser(self.model_config.cache_dir))
+            / self.model_config.model_name
+            / self.model_hash
         )
         self.cache_dir.mkdir(parents=True, exist_ok=True)

@@ -115,10 +118,14 @@ def _load_cached_indices(self) -> dict:
             # cache_file.parts gives all the subfolders of the url, up to the file name
             # last 3 are task_name/task_hash/file_name.parquet, so we take -3 and -2
             task_name, task_hash = cache_file.parts[-3:-1]
-            sampling_method = SamplingMethod[cache_file.stem]  # removes the file extension
+            sampling_method = SamplingMethod[
+                cache_file.stem
+            ]  # removes the file extension
             task_id = TaskID(task_name, task_hash, sampling_method)

-            full_dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
+            full_dataset = load_dataset(
+                "parquet", data_files=str(cache_file), split="train"
+            )
             sample_ids = []
             for row in full_dataset:
                 try:
@@ -169,7 +176,9 @@ def _get_task_hash(self, full_task_name: str) -> str:
             task_configs: list[LightevalTaskConfig] = sorted(
                 self.registry.task_to_configs[f"{task_suite}|{task_name}"]
             )
-            config_str = "|".join([task_config.__str__(lite=True) for task_config in task_configs])
+            config_str = "|".join(
+                [task_config.__str__(lite=True) for task_config in task_configs]
+            )
             task_hash = hashlib.sha256(config_str.encode()).hexdigest()[:16]
             self._task_hashes[full_task_name] = task_hash
         return self._task_hashes[full_task_name]
@@ -183,7 +192,12 @@ def get_cache_path(self, task_id: TaskID) -> Path:
         Returns:
             Path: Path to the cache file for the given task and sample type
         """
-        return self.cache_dir / task_id.task_name / task_id.task_hash / f"{task_id.sampling_method.name}.parquet"
+        return (
+            self.cache_dir
+            / task_id.task_name
+            / task_id.task_hash
+            / f"{task_id.sampling_method.name}.parquet"
+        )

     def get_task_id(self, task_name: str, sampling_method: SamplingMethod) -> TaskID:
         """Returns a unique task identifier. Depends on the task name,
@@ -202,12 +216,16 @@ def get_task_id(self, task_name: str, sampling_method: SamplingMethod) -> TaskID

     def get_sampling_method(self, sample: dict) -> str:
         if len(sample.get("logprobs", [])) > 0:
+            if len(sample.get("text", [])) == 0:
+                return SamplingMethod.PERPLEXITY
             return SamplingMethod.LOGPROBS
         if len(sample.get("text", [])) > 0:
             return SamplingMethod.GENERATIVE
         return None

-    def _load_sample(self, sample: pd.core.series.Series | dict) -> Union[dict, ModelResponse]:
+    def _load_sample(
+        self, sample: pd.core.series.Series | dict
+    ) -> Union[dict, ModelResponse]:
         """Load a sample from cached data based on sample type.

         Args:
@@ -261,7 +279,10 @@ def get_samples_to_process_and_cache(
         return docs_not_cached, set(tasks_with_cached_samples)

     def get_samples_from_cache(
-        self, docs: List[Doc], task_ids: List[TaskID] | set[TaskID], sampling_method: SamplingMethod
+        self,
+        docs: List[Doc],
+        task_ids: List[TaskID] | set[TaskID],
+        sampling_method: SamplingMethod,
     ) -> List[dict | ModelResponse]:
         """Get cached samples for the given docs.
         Warning: Assumes all docs and task_names provided are stored in cache, will fail otherwise.
@@ -277,11 +298,15 @@ def get_samples_from_cache(
                 continue
             cache_file = self.get_cache_path(task_id)
             try:
-                dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
+                dataset = load_dataset(
+                    "parquet", data_files=str(cache_file), split="train"
+                )
                 dataset_df = dataset.to_pandas().set_index("sample_id")
                 task_datasets[task_id] = dataset_df
             except Exception as e:
-                logger.warning(f"Error loading prediction cache for {str(task_id)}: {e}")
+                logger.warning(
+                    f"Error loading prediction cache for {str(task_id)}: {e}"
+                )

         # Build results list
         results = []
@@ -311,7 +336,11 @@ def cache_samples( # noqa C901
             sample = self._dump_sample(result)

             processed_data[task_id].append({"sample_id": doc.id, "sample": sample})
-        processed_data = {task_id: task_data for task_id, task_data in processed_data.items() if task_data}
+        processed_data = {
+            task_id: task_data
+            for task_id, task_data in processed_data.items()
+            if task_data
+        }

         # Concatenate it with existing data and save to file
         for task_id, task_data in processed_data.items():
@@ -325,32 +354,49 @@ def cache_samples( # noqa C901
             existing_samples = {}
             if cache_file.exists():
                 try:
-                    existing_dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
+                    existing_dataset = load_dataset(
+                        "parquet", data_files=str(cache_file), split="train"
+                    )
                     existing_data = existing_dataset.to_list()
                 except KeyError:
                     logger.info(f"No data was cached for {str(task_id)}")
                 except Exception as e:
-                    logger.error(f"Error loading existing prediction cache for {str(task_id)}: {e}")
+                    logger.error(
+                        f"Error loading existing prediction cache for {str(task_id)}: {e}"
+                    )

-                existing_samples = {(row["sample_id"], sampling_method) for row in existing_data}
-                if any((row["sample_id"], sampling_method) in existing_samples for row in task_data):
+                existing_samples = {
+                    (row["sample_id"], sampling_method) for row in existing_data
+                }
+                if any(
+                    (row["sample_id"], sampling_method) in existing_samples
+                    for row in task_data
+                ):
                     logger.warning(
                         "Unexpected behavior: You have reprocessed already cached items - we will ignore the new version."
                     )

             # Merge with new data (new data overwrites existing)
             # We look at id + sampling method
-            new_data = [row for row in task_data if (row["sample_id"], sampling_method) not in existing_samples]
+            new_data = [
+                row
+                for row in task_data
+                if (row["sample_id"], sampling_method) not in existing_samples
+            ]
             all_samples = existing_data + new_data

             # Save updated dataset
             dataset = Dataset.from_list(all_samples)
             dataset.to_parquet(str(cache_file))

-            logger.info(f"Cached {len(all_samples)} samples of {str(task_id)} at {str(cache_file)}.")
+            logger.info(
+                f"Cached {len(all_samples)} samples of {str(task_id)} at {str(cache_file)}."
+            )

             # Refresh cached indices after storing new samples
-            self.existing_indices[task_id] = [sample["sample_id"] for sample in all_samples]
+            self.existing_indices[task_id] = [
+                sample["sample_id"] for sample in all_samples
+            ]


 def cached(sampling_method: SamplingMethod = None):  # noqa C901
@@ -381,12 +427,16 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901
             cache: SampleCache = self._cache

             # Extract task names
-            task_ids = {cache.get_task_id(doc.task_name, sampling_method) for doc in docs}
+            task_ids = {
+                cache.get_task_id(doc.task_name, sampling_method) for doc in docs
+            }

             # 1) Identify which samples must be processed because they are not cached
             docs_not_cached: List[Doc]
             tasks_with_cached_samples: Set[TaskID]
-            docs_not_cached, tasks_with_cached_samples = cache.get_samples_to_process_and_cache(docs, sampling_method)
+            docs_not_cached, tasks_with_cached_samples = (
+                cache.get_samples_to_process_and_cache(docs, sampling_method)
+            )

             # Log cache statistics
             cached_count = len(docs) - len(docs_not_cached)
@@ -399,7 +449,8 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901
             new_results = []
             if docs_not_cached:
                 tasks_needing_sample_processing = {
-                    cache.get_task_id(doc.task_name, sampling_method) for doc in docs_not_cached
-                }
+                    cache.get_task_id(doc.task_name, sampling_method)
+                    for doc in docs_not_cached
+                }
                 logger.info(
                     f"Cache: Starting to process {len(docs_not_cached)}/{len(docs)} samples (not found in cache) for tasks {','.join(str(t) for t in tasks_needing_sample_processing)}"
@@ -415,15 +466,21 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901
                 )

             # 3) Create final results by pulling from newly saved file cache
-            final_cached_results = cache.get_samples_from_cache(docs, task_ids, sampling_method)
+            final_cached_results = cache.get_samples_from_cache(
+                docs, task_ids, sampling_method
+            )

             # 4) We only keep samples with the correct sampling method
             final_results = [
-                s for s in final_cached_results if cache.get_sampling_method(cache._dump_sample(s)) == sampling_method
+                s
+                for s in final_cached_results
+                if cache.get_sampling_method(cache._dump_sample(s)) == sampling_method
             ]

             if any(r is None for r in final_results):
-                raise ValueError("Problem while loading and aggregating items from cache.")
+                raise ValueError(
+                    "Problem while loading and aggregating items from cache."
+                )

             return final_results

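A minimal standalone sketch of the sampling-method dispatch added in the get_sampling_method hunk above. The SamplingMethod enum is re-declared here purely for illustration (the real one lives in lighteval.tasks.requests), and the dict shape mirrors the cached response fields the method inspects:

from enum import Enum, auto


class SamplingMethod(Enum):
    # Stand-in for lighteval.tasks.requests.SamplingMethod, for this sketch only.
    GENERATIVE = auto()
    LOGPROBS = auto()
    PERPLEXITY = auto()


def get_sampling_method(sample: dict) -> SamplingMethod | None:
    # Logprobs with no generated text -> perplexity-style sample (the new branch in this diff).
    if len(sample.get("logprobs", [])) > 0:
        if len(sample.get("text", [])) == 0:
            return SamplingMethod.PERPLEXITY
        return SamplingMethod.LOGPROBS
    # Generated text only -> generative sample.
    if len(sample.get("text", [])) > 0:
        return SamplingMethod.GENERATIVE
    return None


assert get_sampling_method({"logprobs": [-0.3], "text": []}) is SamplingMethod.PERPLEXITY
assert get_sampling_method({"logprobs": [-0.3], "text": ["foo"]}) is SamplingMethod.LOGPROBS
assert get_sampling_method({"text": ["foo"]}) is SamplingMethod.GENERATIVE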
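For reference, a hedged round-trip sketch of the parquet layout these hunks read and write: one file per (task, sampling method) under <cache_dir>/<model_name>/<model_hash>/<task_name>/<task_hash>/<METHOD>.parquet, with "sample_id" and "sample" columns. The cache root and the serialized "sample" payload below are illustrative assumptions; the real payload comes from SampleCache._dump_sample, which is not shown in this diff.

from pathlib import Path

from datasets import Dataset, load_dataset

# Hypothetical cache root for the sketch, not lighteval's actual default.
cache_root = Path("~/.cache/lighteval-demo").expanduser()
# Layout implied by __init__ and get_cache_path above (all path segments are made up here).
cache_file = (
    cache_root / "my-model" / "0011aabb" / "suite|task" / "deadbeefdeadbeef" / "GENERATIVE.parquet"
)
cache_file.parent.mkdir(parents=True, exist_ok=True)

# cache_samples(): one row per doc, keyed by the doc id, payload faked for illustration.
Dataset.from_list([{"sample_id": "doc-0", "sample": '{"text": ["hello"]}'}]).to_parquet(str(cache_file))

# get_samples_from_cache(): reload the parquet split and index rows by sample_id.
reloaded = load_dataset("parquet", data_files=str(cache_file), split="train")
rows = reloaded.to_pandas().set_index("sample_id")
print(rows.loc["doc-0", "sample"])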