diff --git a/env_service/env_service.py b/env_service/env_service.py index f51bab46..7413b8f3 100644 --- a/env_service/env_service.py +++ b/env_service/env_service.py @@ -13,6 +13,7 @@ import asyncio from dataclasses import dataclass import importlib +import re import os import sys import time @@ -58,6 +59,28 @@ def ensure_env(name: str, rel_path: str) -> None: if PROJECT_ROOT not in sys.path: sys.path.insert(0, PROJECT_ROOT) +ALLOWED_ENV_TYPES = frozenset( + d.name + for d in Path( + os.path.join(os.path.dirname(__file__), "environments"), + ).iterdir() + if d.is_dir() and not d.name.startswith("_") +) + + +def _validate_env_type(env_type: str) -> None: + """Validate that env_type is a known, safe environment name.""" + if not re.match(r"^[a-zA-Z0-9_]+$", env_type): + raise ValueError( + f"Invalid env_type: {env_type!r}. " + "Must contain only alphanumeric characters and underscores.", + ) + if env_type not in ALLOWED_ENV_TYPES: + raise ValueError( + f"Unknown env_type: {env_type!r}. " + f"Available: {sorted(ALLOWED_ENV_TYPES)}", + ) + def import_and_register_env(env_name, env_file=None): """ @@ -71,6 +94,7 @@ def import_and_register_env(env_name, env_file=None): Returns: The registered environment class or None on failure. """ + _validate_env_type(env_name) try: if env_file is None: env_file = f"{env_name}_env" @@ -181,6 +205,7 @@ def get_remote_env_cls(self, env_type: str): Returns: The remote environment class. """ + _validate_env_type(env_type) if env_type in self.remote_env: return self.remote_env[env_type]