31 changes: 28 additions & 3 deletions examples/tts/magpietts_inference.py
@@ -117,11 +117,27 @@ def create_formatted_metrics_mean_ci(metrics_mean_ci: dict) -> dict:
    return metrics_mean_ci


def filter_datasets(dataset_meta_info: dict, datasets: Optional[List[str]]) -> List[str]:
    """Select datasets from the dataset meta info."""
    if datasets is None:
        # Dataset filtering not specified, return all datasets
        return list(dataset_meta_info.keys())
    else:
        datasets = datasets.split(",")
        # Check if datasets are valid
        for dataset in datasets:
            if dataset not in dataset_meta_info:
                raise ValueError(f"Dataset {dataset} not found in dataset meta info")
        # Return all requested datasets
        return datasets
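A quick sketch of the new helper's contract for reviewers. The import path and the metadata dict below are illustrative assumptions; real metadata comes from load_evalset_config().

from examples.tts.magpietts_inference import filter_datasets  # import path is an assumption

# Placeholder metadata; real entries come from load_evalset_config().
meta = {"vctk": {}, "libritts": {}}

print(filter_datasets(meta, None))    # ['vctk', 'libritts'] -- no filter requested, keep everything
print(filter_datasets(meta, "vctk"))  # ['vctk'] -- comma-separated selection

try:
    filter_datasets(meta, "vctk,unknown")
except ValueError as err:
    print(err)  # Dataset unknown not found in dataset meta info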


def run_inference_and_evaluation(
    model_config: ModelLoadConfig,
    inference_config: InferenceConfig,
    eval_config: EvaluationConfig,
    dataset_meta_info: dict,
    datasets: Optional[List[str]],
    out_dir: str,
    num_repeats: int = 1,
    confidence_level: float = 0.95,
@@ -137,6 +153,8 @@ def run_inference_and_evaluation(
        inference_config: Configuration for inference.
        eval_config: Configuration for evaluation.
        dataset_meta_info: Dictionary containing dataset metadata.
        datasets: List of dataset names to run inference and evaluation on. If None, all datasets in the
            dataset meta info will be processed.
        out_dir: Output directory for results.
        num_repeats: Number of times to repeat inference (for CI estimation).
        confidence_level: Confidence level for CI calculation.
@@ -170,7 +188,6 @@ def run_inference_and_evaluation(
    runner = MagpieInferenceRunner(model, inference_config)

    # Tracking metrics across datasets
    datasets = list(dataset_meta_info.keys())
    ssim_per_dataset = []
    cer_per_dataset = []
    all_datasets_filewise_metrics = {}
@@ -369,8 +386,15 @@ def create_argument_parser() -> argparse.ArgumentParser:
    data_group.add_argument(
        '--datasets_json_path',
        type=str,
        required=True,
        default=None,
        help='Path to dataset configuration JSON file (will process all datasets in the file if --datasets is not specified)',
    )
    data_group.add_argument(
        '--datasets',
        type=str,
        default=None,
        help='Path to dataset configuration JSON file (will process all datasets in the file)',
        help='Comma-separated list of dataset names to process, using names from the datasets_json_path file. If not specified, all datasets in datasets_json_path will be processed.',
    )
    data_group.add_argument(
        '--out_dir',
@@ -478,7 +502,7 @@ def main():
    args = parser.parse_args()

    dataset_meta_info = load_evalset_config(args.datasets_json_path)
    datasets = list(dataset_meta_info.keys())
    datasets = filter_datasets(dataset_meta_info, args.datasets)

    logging.info(f"Loaded {len(datasets)} datasets: {', '.join(datasets)}")
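For context, a standalone sketch of how the reworked flags parse. This re-declares only the two arguments touched by the diff, not the script's full parser; paths and dataset names are placeholders.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--datasets_json_path', type=str, default=None)
parser.add_argument('--datasets', type=str, default=None)

args = parser.parse_args(['--datasets_json_path', 'evalsets.json'])
print(args.datasets)  # None -> filter_datasets() falls back to every dataset in the JSON

args = parser.parse_args(['--datasets_json_path', 'evalsets.json', '--datasets', 'vctk,riva'])
print(args.datasets)  # 'vctk,riva' -> split on commas inside filter_datasets()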

@@ -549,6 +573,7 @@ def main():
        inference_config=inference_config,
        eval_config=eval_config,
        dataset_meta_info=dataset_meta_info,
        datasets=datasets,
        out_dir=args.out_dir,
        num_repeats=args.num_repeats,
        confidence_level=args.confidence_level,