
Commit 9b2ca83

fix perplexity metric issues

Author: Duc Hoang (committed)
1 parent: 5aa09c5

File tree: 1 file changed (+82 −25)

src/lighteval/utils/cache_management.py

Lines changed: 82 additions & 25 deletions
@@ -39,7 +39,6 @@
 from lighteval.tasks.requests import Doc, SamplingMethod
 from lighteval.utils.utils import as_list
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -58,7 +57,9 @@ def __str__(self):
         return f"{self.task_name} ({self.task_hash}, {self.sampling_method.name})"
 
     def __hash__(self):
-        return int.from_bytes(hashlib.sha256(str(self).encode()).digest(), byteorder="big")
+        return int.from_bytes(
+            hashlib.sha256(str(self).encode()).digest(), byteorder="big"
+        )
 
 
 class SampleCache:
@@ -84,7 +85,9 @@ def __init__(self, model_config: ModelConfig):
         self.model_hash = self.get_model_hash(model_config)
 
         self.cache_dir = (
-            Path(os.path.expanduser(self.model_config.cache_dir)) / self.model_config.model_name / self.model_hash
+            Path(os.path.expanduser(self.model_config.cache_dir))
+            / self.model_config.model_name
+            / self.model_hash
         )
         self.cache_dir.mkdir(parents=True, exist_ok=True)
 
@@ -115,10 +118,14 @@ def _load_cached_indices(self) -> dict:
             # cache_file.parts gives all the subfolders of the url, up to the file name
             # last 3 are task_name/task_hash/file_name.parquet, so we take -3 and -2
             task_name, task_hash = cache_file.parts[-3:-1]
-            sampling_method = SamplingMethod[cache_file.stem]  # removes the file extension
+            sampling_method = SamplingMethod[
+                cache_file.stem
+            ]  # removes the file extension
             task_id = TaskID(task_name, task_hash, sampling_method)
 
-            full_dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
+            full_dataset = load_dataset(
+                "parquet", data_files=str(cache_file), split="train"
+            )
             sample_ids = []
             for row in full_dataset:
                 try:
@@ -169,7 +176,9 @@ def _get_task_hash(self, full_task_name: str) -> str:
             task_configs: list[LightevalTaskConfig] = sorted(
                 self.registry.task_to_configs[f"{task_suite}|{task_name}"]
             )
-            config_str = "|".join([task_config.__str__(lite=True) for task_config in task_configs])
+            config_str = "|".join(
+                [task_config.__str__(lite=True) for task_config in task_configs]
+            )
             task_hash = hashlib.sha256(config_str.encode()).hexdigest()[:16]
             self._task_hashes[full_task_name] = task_hash
         return self._task_hashes[full_task_name]
@@ -183,7 +192,12 @@ def get_cache_path(self, task_id: TaskID) -> Path:
         Returns:
             Path: Path to the cache file for the given task and sample type
         """
-        return self.cache_dir / task_id.task_name / task_id.task_hash / f"{task_id.sampling_method.name}.parquet"
+        return (
+            self.cache_dir
+            / task_id.task_name
+            / task_id.task_hash
+            / f"{task_id.sampling_method.name}.parquet"
+        )
 
     def get_task_id(self, task_name: str, sampling_method: SamplingMethod) -> TaskID:
         """Returns a unique task indentifier. Depends on the task name,
@@ -202,12 +216,16 @@ def get_task_id(self, task_name: str, sampling_method: SamplingMethod) -> TaskID
 
     def get_sampling_method(self, sample: dict) -> str:
         if len(sample.get("logprobs", [])) > 0:
+            if len(sample.get("text", [])) == 0:
+                return SamplingMethod.PERPLEXITY
             return SamplingMethod.LOGPROBS
         if len(sample.get("text", [])) > 0:
             return SamplingMethod.GENERATIVE
         return None
 
-    def _load_sample(self, sample: pd.core.series.Series | dict) -> Union[dict, ModelResponse]:
+    def _load_sample(
+        self, sample: pd.core.series.Series | dict
+    ) -> Union[dict, ModelResponse]:
         """Load a sample from cached data based on sample type.
 
         Args:
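The behavioral change of this commit is the new PERPLEXITY branch in get_sampling_method above: a cached sample that carries logprobs but no generated text is now classified as a perplexity sample instead of falling through to LOGPROBS. A standalone sketch of the fixed dispatch, with a stand-in enum and made-up sample dicts (the real samples are produced by _dump_sample from a ModelResponse):

from enum import Enum, auto

class SamplingMethod(Enum):  # stand-in for lighteval.tasks.requests.SamplingMethod
    GENERATIVE = auto()
    LOGPROBS = auto()
    PERPLEXITY = auto()

def get_sampling_method(sample: dict) -> SamplingMethod | None:
    # Mirrors the patched logic: logprobs without text -> PERPLEXITY,
    # logprobs with text -> LOGPROBS, text only -> GENERATIVE.
    if len(sample.get("logprobs", [])) > 0:
        if len(sample.get("text", [])) == 0:
            return SamplingMethod.PERPLEXITY
        return SamplingMethod.LOGPROBS
    if len(sample.get("text", [])) > 0:
        return SamplingMethod.GENERATIVE
    return None

# Made-up samples, shapes assumed for illustration.
assert get_sampling_method({"logprobs": [-1.2, -0.4], "text": []}) is SamplingMethod.PERPLEXITY
assert get_sampling_method({"logprobs": [-1.2], "text": ["A"]}) is SamplingMethod.LOGPROBS
assert get_sampling_method({"logprobs": [], "text": ["an answer"]}) is SamplingMethod.GENERATIVE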
@@ -261,7 +279,10 @@ def get_samples_to_process_and_cache(
         return docs_not_cached, set(tasks_with_cached_samples)
 
     def get_samples_from_cache(
-        self, docs: List[Doc], task_ids: List[TaskID] | set[TaskID], sampling_method: SamplingMethod
+        self,
+        docs: List[Doc],
+        task_ids: List[TaskID] | set[TaskID],
+        sampling_method: SamplingMethod,
     ) -> List[dict | ModelResponse]:
         """Get cached samples for the given docs.
         Warning: Assumes all docs and task_names provided are stored in cache, will fail otherwise.
@@ -277,11 +298,15 @@
                 continue
             cache_file = self.get_cache_path(task_id)
             try:
-                dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
+                dataset = load_dataset(
+                    "parquet", data_files=str(cache_file), split="train"
+                )
                 dataset_df = dataset.to_pandas().set_index("sample_id")
                 task_datasets[task_id] = dataset_df
             except Exception as e:
-                logger.warning(f"Error loading prediction cache for {str(task_id)}: {e}")
+                logger.warning(
+                    f"Error loading prediction cache for {str(task_id)}: {e}"
+                )
 
         # Build results list
         results = []
@@ -311,7 +336,11 @@ def cache_samples( # noqa C901
            sample = self._dump_sample(result)
 
            processed_data[task_id].append({"sample_id": doc.id, "sample": sample})
-        processed_data = {task_id: task_data for task_id, task_data in processed_data.items() if task_data}
+        processed_data = {
+            task_id: task_data
+            for task_id, task_data in processed_data.items()
+            if task_data
+        }
 
         # Concatenate it with existing data and save to file
         for task_id, task_data in processed_data.items():
@@ -325,32 +354,49 @@ def cache_samples( # noqa C901
             existing_samples = {}
             if cache_file.exists():
                 try:
-                    existing_dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
+                    existing_dataset = load_dataset(
+                        "parquet", data_files=str(cache_file), split="train"
+                    )
                     existing_data = existing_dataset.to_list()
                 except KeyError:
                     logger.info(f"No data was cached for {str(task_id)}")
                 except Exception as e:
-                    logger.error(f"Error loading existing prediction cache for {str(task_id)}: {e}")
+                    logger.error(
+                        f"Error loading existing prediction cache for {str(task_id)}: {e}"
+                    )
 
-            existing_samples = {(row["sample_id"], sampling_method) for row in existing_data}
-            if any((row["sample_id"], sampling_method) in existing_samples for row in task_data):
+            existing_samples = {
+                (row["sample_id"], sampling_method) for row in existing_data
+            }
+            if any(
+                (row["sample_id"], sampling_method) in existing_samples
+                for row in task_data
+            ):
                 logger.warning(
                     "Unexpected behavior: You have reprocessed already cached items - we will ignore the new version."
                 )
 
             # Merge with new data (new data overwrites existing)
             # We look at id + sampling method
-            new_data = [row for row in task_data if (row["sample_id"], sampling_method) not in existing_samples]
+            new_data = [
+                row
+                for row in task_data
+                if (row["sample_id"], sampling_method) not in existing_samples
+            ]
             all_samples = existing_data + new_data
 
             # Save updated dataset
             dataset = Dataset.from_list(all_samples)
             dataset.to_parquet(str(cache_file))
 
-            logger.info(f"Cached {len(all_samples)} samples of {str(task_id)} at {str(cache_file)}.")
+            logger.info(
+                f"Cached {len(all_samples)} samples of {str(task_id)} at {str(cache_file)}."
+            )
 
             # Refresh cached indices after storing new samples
-            self.existing_indices[task_id] = [sample["sample_id"] for sample in all_samples]
+            self.existing_indices[task_id] = [
+                sample["sample_id"] for sample in all_samples
+            ]
 
 
 def cached(sampling_method: SamplingMethod = None): # noqa C901
@@ -381,12 +427,16 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901
         cache: SampleCache = self._cache
 
         # Extract task names
-        task_ids = {cache.get_task_id(doc.task_name, sampling_method) for doc in docs}
+        task_ids = {
+            cache.get_task_id(doc.task_name, sampling_method) for doc in docs
+        }
 
         # 1) Identify which samples must be processed because they are not cached
         docs_not_cached: List[Doc]
         tasks_with_cached_samples: Set[TaskID]
-        docs_not_cached, tasks_with_cached_samples = cache.get_samples_to_process_and_cache(docs, sampling_method)
+        docs_not_cached, tasks_with_cached_samples = (
+            cache.get_samples_to_process_and_cache(docs, sampling_method)
+        )
 
         # Log cache statistics
         cached_count = len(docs) - len(docs_not_cached)
@@ -399,7 +449,8 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901
         new_results = []
         if docs_not_cached:
             tasks_needing_sample_processing = {
-                cache.get_task_id(doc.task_name, sampling_method) for doc in docs_not_cached
+                cache.get_task_id(doc.task_name, sampling_method)
+                for doc in docs_not_cached
             }
             logger.info(
                 f"Cache: Starting to process {len(docs_not_cached)}/{len(docs)} samples (not found in cache) for tasks {','.join(str(t) for t in tasks_needing_sample_processing)}"
415466
)
416467

417468
# 3) Create final results by pulling from newly saved file cache
418-
final_cached_results = cache.get_samples_from_cache(docs, task_ids, sampling_method)
469+
final_cached_results = cache.get_samples_from_cache(
470+
docs, task_ids, sampling_method
471+
)
419472

420473
# 4) We only keep samples with the correct sampling method
421474
final_results = [
422-
s for s in final_cached_results if cache.get_sampling_method(cache._dump_sample(s)) == sampling_method
475+
s
476+
for s in final_cached_results
477+
if cache.get_sampling_method(cache._dump_sample(s)) == sampling_method
423478
]
424479

425480
if any(r is None for r in final_results):
426-
raise ValueError("Problem while loading and aggregating items from cache.")
481+
raise ValueError(
482+
"Problem while loading and aggregating items from cache."
483+
)
427484

428485
return final_results
429486

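Step 4 of the wrapper above is where the detection fix becomes visible: results pulled back from the parquet cache are kept only when the detected sampling method matches the requested one, so a perplexity response (logprobs, no generated text) that the old logic labeled LOGPROBS was silently dropped from final_results. A hedged sketch of that filter, reusing the get_sampling_method stand-in from the earlier sketch and a hypothetical dump_sample in place of cache._dump_sample:

def filter_by_sampling_method(cached_results, requested, dump_sample=lambda s: s):
    # Keep only cached entries whose detected method matches the request,
    # mirroring step 4 of the wrapper. dump_sample converts a ModelResponse
    # back into a plain dict; for the dicts used here it is the identity.
    return [
        s for s in cached_results
        if get_sampling_method(dump_sample(s)) == requested
    ]

# With the fix, a logprobs-only sample now survives a PERPLEXITY request.
perplexity_sample = {"logprobs": [-0.7, -1.1], "text": []}
assert filter_by_sampling_method([perplexity_sample], SamplingMethod.PERPLEXITY) == [perplexity_sample]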