Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/runpod_flash/cli/commands/_run_server_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ async def lb_execute(resource_config, func, body: dict):
func: The @remote LB route handler function.
body: Parsed request body (from FastAPI's automatic JSON parsing).
"""
# Extract dependencies before unwrapping the Endpoint facade
dependencies = getattr(resource_config, "dependencies", None)
system_dependencies = getattr(resource_config, "system_dependencies", None)
accelerate_downloads = getattr(resource_config, "accelerate_downloads", False)

# Endpoint facade wraps an internal resource config
if hasattr(resource_config, "_build_resource_config"):
resource_config = resource_config._build_resource_config()
Expand All @@ -121,7 +126,9 @@ async def lb_execute(resource_config, func, body: dict):
log.info(f"{resource_config} | {route_label}")

try:
result = await stub(func, None, None, False, **kwargs)
result = await stub(
func, dependencies, system_dependencies, accelerate_downloads, **kwargs
)
log.info(f"{resource_config} | Execution complete")
return result
except TimeoutError as e:
Expand Down
35 changes: 17 additions & 18 deletions src/runpod_flash/cli/commands/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,18 @@
PIP_MODULE = "pip"


# Packages pre-installed in base Docker images (runpod/pytorch:*).
# Always excluded from build artifacts to avoid:
# 1. Exceeding the 500 MB tarball limit (torch alone is ~500 MB)
# 2. Redundant copies — these are already in the base Docker image
# NOTE: numpy is excluded because the base Docker image provides it, and
# keeping it out of the tarball saves ~30 MB toward the 500 MB limit.
BASE_IMAGE_PACKAGES: frozenset[str] = frozenset(
# These are CUDA/GPU-oriented packages whose large CUDA builds are already
# provided by the GPU base images (runpod/pytorch:*) and therefore should
# not be bundled into the tarball.
# Do NOT add packages here just because the GPU image ships them (e.g. numpy).
# The blacklist is defined strictly by size constraints, not by whether a
# package happens to be present in a particular base image.
SIZE_PROHIBITIVE_PACKAGES: frozenset[str] = frozenset(
{
"torch",
"torchvision",
"torchaudio",
"numpy",
"triton",
"torch", # ~500 MB
"torchvision", # ~50 MB, requires torch
"torchaudio", # ~30 MB, requires torch
"triton", # ~150 MB, CUDA compiler
}
)

Expand Down Expand Up @@ -272,11 +271,11 @@ def run_build(
# Create build directory first to ensure clean state before collecting files
build_dir = create_build_directory(project_dir, app_name)

# Parse exclusions: merge user-specified with always-excluded base image packages
# Parse exclusions: merge user-specified with always-excluded size-prohibitive packages
user_excluded = []
if exclude:
user_excluded = [pkg.strip().lower() for pkg in exclude.split(",")]
excluded_packages = list(set(user_excluded) | BASE_IMAGE_PACKAGES)
excluded_packages = list(set(user_excluded) | SIZE_PROHIBITIVE_PACKAGES)

spec = load_ignore_patterns(project_dir)
files = get_file_tree(project_dir, spec)
Expand Down Expand Up @@ -370,7 +369,7 @@ def run_build(
for req in requirements:
if should_exclude_package(req, excluded_packages):
pkg_name = extract_package_name(req)
if pkg_name in BASE_IMAGE_PACKAGES:
if pkg_name in SIZE_PROHIBITIVE_PACKAGES:
auto_matched.add(pkg_name)
if pkg_name in user_excluded:
user_matched.add(pkg_name)
Expand All @@ -381,12 +380,12 @@ def run_build(

if auto_matched:
console.print(
f"[dim]Auto-excluded base image packages: "
f"[dim]Auto-excluded size-prohibitive packages: "
f"{', '.join(sorted(auto_matched))}[/dim]"
)

# Only warn about unmatched user-specified packages (not auto-excludes)
user_unmatched = set(user_excluded) - user_matched - BASE_IMAGE_PACKAGES
user_unmatched = set(user_excluded) - user_matched - SIZE_PROHIBITIVE_PACKAGES
if user_unmatched:
console.print(
f"[yellow]Warning:[/yellow] No packages matched exclusions: "
Expand Down Expand Up @@ -981,7 +980,7 @@ def create_tarball(
excluded_packages: list[str] | None = None,
) -> None:
"""
Create gzipped tarball of build directory, excluding base image packages.
Create gzipped tarball of build directory, excluding size-prohibitive packages.

Filters at tarball creation time rather than constraining pip resolution,
because pip constraints (`<0.0.0a0`) break resolution for any package that
Expand Down
29 changes: 24 additions & 5 deletions src/runpod_flash/cli/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,20 +242,30 @@ def _module_parent_subdir(module_path: str) -> str | None:
return parts[0].replace(".", "/")


def _make_import_line(module_path: str, name: str, alias: str | None = None) -> str:
    """Build an import statement for *name* from *module_path*.

    Uses a regular ``from … import …`` when the module path is a valid
    Python identifier chain. Falls back to ``_flash_import()`` (a generated
    helper in server.py) when any segment starts with a digit. The helper
    temporarily scopes ``sys.path`` so sibling imports in the target module
    resolve to the correct directory.

    Args:
        module_path: Dotted module path to import from.
        name: Symbol name to import.
        alias: If provided, assign the import to this variable name instead
            of *name*. Prevents collisions when multiple modules export the
            same symbol (e.g. multiple files exporting ``api``).
    """
    target = alias or name
    if _has_numeric_module_segments(module_path):
        subdir = _module_parent_subdir(module_path)
        if subdir:
            return f'{target} = _flash_import("{module_path}", "{name}", "{subdir}")'
        return f'{target} = _flash_import("{module_path}", "{name}")'
    if alias:
        return f"from {module_path} import {name} as {alias}"
    return f"from {module_path} import {name}"


Expand Down Expand Up @@ -390,13 +400,22 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat
)
elif worker.worker_type == "LB":
# Import the resource config variable (e.g. "api" from api = LiveLoadBalancer(...))
# Use aliased names to prevent collisions when multiple files export
# the same variable name (e.g. multiple files exporting "api").
config_vars = {
r["config_variable"]
for r in worker.lb_routes
if r.get("config_variable")
}
for var in sorted(config_vars):
all_imports.append(_make_import_line(worker.module_path, var))
alias = f"_cfg_{_sanitize_fn_name(worker.resource_name)}"
all_imports.append(
_make_import_line(worker.module_path, var, alias=alias)
)
# Store the alias so route codegen can reference it
for r in worker.lb_routes:
if r.get("config_variable") == var:
r["_config_alias"] = alias
for fn_name in worker.functions:
all_imports.append(_make_import_line(worker.module_path, fn_name))

Expand Down Expand Up @@ -561,7 +580,7 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat
method = route["method"].lower()
sub_path = route["path"].lstrip("/")
fn_name = route["fn_name"]
config_var = route["config_variable"]
config_var = route.get("_config_alias") or route["config_variable"]
full_path = f"{worker.url_prefix}/{sub_path}"
handler_name = _sanitize_fn_name(
f"_route_{worker.resource_name}_{fn_name}"
Expand Down
33 changes: 31 additions & 2 deletions src/runpod_flash/core/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def _find_resource_config_vars(self, file_path: Path) -> Set[str]:
Detects:
- @remote(resource_config=var) / @remote(var) patterns
- ep = Endpoint(...) variables used as LB route decorators (@ep.get, @ep.post, etc)
- @Endpoint(...) used directly as a function/class decorator (QB pattern)

Args:
file_path: Path to Python file to parse
Expand Down Expand Up @@ -143,6 +144,10 @@ def _find_resource_config_vars(self, file_path: Path) -> Set[str]:
if var_name:
var_names.add(var_name)

# @Endpoint(name=..., gpu=...) directly on function/class (QB)
elif self._is_endpoint_direct_decorator(decorator):
var_names.add(node.name)

except Exception as e:
log.warning(f"Failed to parse {file_path}: {e}")

Expand Down Expand Up @@ -170,6 +175,21 @@ def _extract_endpoint_var_from_route(self, decorator: ast.Call) -> str:
return func.value.id
return ""

def _is_endpoint_direct_decorator(self, decorator: ast.expr) -> bool:
"""Check if decorator is @Endpoint(...) used directly on a function/class (QB pattern).

Matches @Endpoint(name=..., gpu=...) but NOT @ep.get()/@ep.post() (which are
attribute calls on an Endpoint variable, handled separately).
"""
if not isinstance(decorator, ast.Call):
return False
func = decorator.func
if isinstance(func, ast.Name) and func.id == "Endpoint":
return True
if isinstance(func, ast.Attribute) and func.attr == "Endpoint":
return True
return False

def _is_remote_decorator(self, decorator: ast.expr) -> bool:
"""Check if decorator is @remote.

Expand Down Expand Up @@ -248,8 +268,10 @@ def _import_module(self, file_path: Path):
def _resolve_resource_variable(self, module, var_name: str) -> DeployableResource:
"""Resolve variable name to DeployableResource instance.

Handles both legacy resource config objects (LiveServerless, etc) and
Endpoint facade objects (unwraps via _build_resource_config()).
Handles:
- Legacy resource config objects (LiveServerless, etc)
- Endpoint facade objects (unwraps via _build_resource_config())
- QB-decorated functions/classes (unwraps __remote_config__["resource_config"])

Args:
module: Imported module
Expand All @@ -270,6 +292,13 @@ def _resolve_resource_variable(self, module, var_name: str) -> DeployableResourc
if isinstance(resource, DeployableResource):
return resource

# unwrap @Endpoint(...)-decorated function/class (QB pattern).
# Endpoint.__call__ wraps via @remote which attaches __remote_config__
if obj is not None and hasattr(obj, "__remote_config__"):
resource = obj.__remote_config__.get("resource_config")
if isinstance(resource, DeployableResource):
return resource

if obj is not None:
log.warning(
f"Resource '{var_name}' failed to resolve to DeployableResource "
Expand Down
12 changes: 6 additions & 6 deletions tests/integration/test_p1_integration_gaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def test_qb_and_lb_workers_in_same_project(self):
assert "_call_with_body(process" in content

# LB worker: config + function import + LB route
assert "from api import api_config" in content
assert "from api import api_config as _cfg_api" in content
assert "from api import list_items" in content
assert "_lb_execute(api_config, list_items," in content
assert "_lb_execute(_cfg_api, list_items," in content

# Both import helpers should be present
assert "_call_with_body" in content
Expand Down Expand Up @@ -113,9 +113,9 @@ def test_qb_class_and_lb_function_in_same_project(self):
assert "_instance_TextModel.predict" in content

# LB function: config import + route
assert "from health import health_config" in content
assert "from health import health_config as _cfg_health" in content
assert "from health import status" in content
assert "_lb_execute(health_config, status," in content
assert "_lb_execute(_cfg_health, status," in content

def test_multiple_lb_routes_alongside_qb(self):
"""Multiple LB routes + QB function all present."""
Expand Down Expand Up @@ -160,8 +160,8 @@ def test_multiple_lb_routes_alongside_qb(self):
content = server_path.read_text()

# Both LB routes registered
assert "_lb_execute(lb_config, create," in content
assert "_lb_execute(lb_config, read," in content
assert "_lb_execute(_cfg_routes, create," in content
assert "_lb_execute(_cfg_routes, read," in content
# QB route also present
assert '"/worker/runsync"' in content

Expand Down
29 changes: 16 additions & 13 deletions tests/unit/cli/commands/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import typer

from runpod_flash.cli.commands.build import (
BASE_IMAGE_PACKAGES,
SIZE_PROHIBITIVE_PACKAGES,
_find_runpod_flash,
_resolve_pip_python_version,
collect_requirements,
Expand Down Expand Up @@ -601,15 +601,18 @@ def _stack():
return _stack()

def test_constant_contains_expected_packages(self):
    """Verify CUDA/torch ecosystem packages are in SIZE_PROHIBITIVE_PACKAGES."""
    assert "torch" in SIZE_PROHIBITIVE_PACKAGES
    assert "torchvision" in SIZE_PROHIBITIVE_PACKAGES
    assert "torchaudio" in SIZE_PROHIBITIVE_PACKAGES
    assert "triton" in SIZE_PROHIBITIVE_PACKAGES

def test_numpy_not_in_size_prohibitive_packages(self):
    """NumPy must NOT be excluded — CPU images (python-slim) don't ship it."""
    assert "numpy" not in SIZE_PROHIBITIVE_PACKAGES

def test_auto_excludes_torch_without_flag(self, tmp_path):
"""Torch and numpy are filtered even with no --exclude flag."""
"""Torch is filtered even with no --exclude flag; numpy passes through."""
project_dir = tmp_path / "project"
project_dir.mkdir()
(project_dir / "worker.py").write_text(
Expand Down Expand Up @@ -637,11 +640,11 @@ def fake_install(_build_dir, reqs, _no_deps, target_python_version=None):

pkg_names = [extract_package_name(r) for r in installed]
assert "torch" not in pkg_names
assert "numpy" not in pkg_names
assert "numpy" in pkg_names
assert "requests" in pkg_names

def test_user_excludes_merged_with_auto(self, tmp_path):
"""User --exclude scipy + auto torch/numpy = all excluded."""
"""User --exclude scipy + auto torch = all excluded; numpy passes through."""
project_dir = tmp_path / "project"
project_dir.mkdir()
(project_dir / "worker.py").write_text(
Expand Down Expand Up @@ -669,12 +672,12 @@ def fake_install(_build_dir, reqs, _no_deps, target_python_version=None):

pkg_names = [extract_package_name(r) for r in installed]
assert "torch" not in pkg_names
assert "numpy" not in pkg_names
assert "numpy" in pkg_names
assert "scipy" not in pkg_names
assert "pandas" in pkg_names

def test_auto_exclude_silent_when_not_in_requirements(self, tmp_path, capsys):
"""No auto-exclude message if no base image packages are in requirements."""
"""No auto-exclude message if no size-prohibitive packages are in requirements."""
project_dir = tmp_path / "project"
project_dir.mkdir()
(project_dir / "worker.py").write_text(
Expand All @@ -695,7 +698,7 @@ def test_auto_exclude_silent_when_not_in_requirements(self, tmp_path, capsys):
run_build(project_dir, "test_app", no_deps=True)

captured = capsys.readouterr()
assert "Auto-excluded base image packages" not in captured.out
assert "Auto-excluded size-prohibitive packages" not in captured.out

def test_user_unmatched_warning_excludes_base_image_packages(
self, tmp_path, capsys
Expand Down
7 changes: 3 additions & 4 deletions tests/unit/cli/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,23 +469,22 @@ def test_post_lb_route_generates_body_param(self, tmp_path):
worker = self._make_lb_worker(tmp_path, method)
content = _generate_flash_server(tmp_path, [worker]).read_text()
assert "body: _api_list_routes_Input" in content
assert "_lb_execute(api_config, list_routes, _to_dict(body))" in content
assert "_lb_execute(_cfg_api, list_routes, _to_dict(body))" in content

def test_get_lb_route_uses_query_params(self, tmp_path):
"""GET LB routes pass query params as a dict."""
worker = self._make_lb_worker(tmp_path, "GET")
content = _generate_flash_server(tmp_path, [worker]).read_text()
assert "async def _route_api_list_routes(request: Request):" in content
assert (
"_lb_execute(api_config, list_routes, dict(request.query_params))"
in content
"_lb_execute(_cfg_api, list_routes, dict(request.query_params))" in content
)

def test_lb_config_var_and_function_imported(self, tmp_path):
"""LB config vars and functions are both imported for remote dispatch."""
worker = self._make_lb_worker(tmp_path)
content = _generate_flash_server(tmp_path, [worker]).read_text()
assert "from api import api_config" in content
assert "from api import api_config as _cfg_api" in content
assert "from api import list_routes" in content

def test_lb_execute_import_present_when_lb_routes_exist(self, tmp_path):
Expand Down
8 changes: 5 additions & 3 deletions tests/unit/core/api/test_runpod_graphql_extended.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,11 @@ class TestGraphQLQueries:
lambda r: len(r) == 1,
),
],
ids=lambda x: x
if isinstance(x, str) and not x.startswith("{") and not x.startswith("(")
else "",
ids=lambda x: (
x
if isinstance(x, str) and not x.startswith("{") and not x.startswith("(")
else ""
),
)
async def test_query_success(
self, method_name, call_args, mock_response, assert_fn
Expand Down
Loading
Loading