Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions agent/src/browser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Browser screenshot functions for AgentCore BrowserCustom.

Best-effort (fail-open): all operations are wrapped in try/except
so a Browser API outage never blocks the agent pipeline.
"""

import json
import os

_lambda_client = None


def _get_lambda_client():
"""Lazy-init and cache the Lambda client."""
global _lambda_client
if _lambda_client is not None:
return _lambda_client
import boto3

region = os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION")
if not region:
raise ValueError("AWS_REGION or AWS_DEFAULT_REGION must be set")
_lambda_client = boto3.client("lambda", region_name=region)
return _lambda_client


def capture_screenshot(url: str, task_id: str = "") -> str | None:
"""Invoke browser-tool Lambda to capture a screenshot. Returns pre-signed URL or None."""
function_name = os.environ.get("BROWSER_TOOL_FUNCTION_NAME")
if not function_name:
return None
try:
client = _get_lambda_client()
except ValueError as e:
print(f"[browser] [ERROR] Configuration error: {e}", flush=True)
return None
try:
payload = json.dumps({"action": "screenshot", "url": url, "taskId": task_id})
response = client.invoke(
FunctionName=function_name,
InvocationType="RequestResponse",
Payload=payload,
)
if "FunctionError" in response:
error_payload = json.loads(response["Payload"].read())
print(
f"[browser] [ERROR] Lambda function crashed: "
f"{response['FunctionError']} — {error_payload}",
flush=True,
)
return None
result = json.loads(response["Payload"].read())
if result.get("status") == "success":
print(f"[browser] Screenshot captured: {result.get('screenshotS3Key')}", flush=True)
return result.get("presignedUrl")
print(f"[browser] Screenshot failed: {result.get('error', 'unknown')}", flush=True)
return None
except Exception as e:
print(
f"[browser] [WARN] capture_screenshot failed (transient): {type(e).__name__}: {e}",
flush=True,
)
return None
1 change: 1 addition & 0 deletions agent/src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,4 @@ class TaskResult(BaseModel):
output_tokens: int | None = None
cache_read_input_tokens: int | None = None
cache_creation_input_tokens: int | None = None
screenshot_urls: list[str] = Field(default_factory=list)
16 changes: 16 additions & 0 deletions agent/src/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,21 @@ def run_task(
pr_url = ensure_pr(
config, setup, build_passed, lint_passed, agent_result=agent_result
)
# Screenshot capture (fail-open)
screenshot_urls: list[str] = []
if pr_url:
from post_hooks import _append_screenshots_to_pr, capture_pr_screenshots

try:
screenshot_urls = capture_pr_screenshots(pr_url, config.task_id)
if screenshot_urls:
_append_screenshots_to_pr(config, setup, screenshot_urls)
except Exception as exc:
log(
"WARN",
f"Screenshot capture failed (non-fatal): {type(exc).__name__}: {exc}",
)

post_span.set_attribute("build.passed", build_passed)
post_span.set_attribute("lint.passed", lint_passed)
post_span.set_attribute("pr.url", pr_url or "")
Expand Down Expand Up @@ -398,6 +413,7 @@ def run_task(
output_tokens=usage.output_tokens if usage else None,
cache_read_input_tokens=usage.cache_read_input_tokens if usage else None,
cache_creation_input_tokens=usage.cache_creation_input_tokens if usage else None,
screenshot_urls=screenshot_urls,
)

result_dict = result.model_dump()
Expand Down
80 changes: 80 additions & 0 deletions agent/src/post_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,86 @@ def ensure_pr(
return None


def capture_pr_screenshots(pr_url: str, task_id: str = "") -> list[str]:
"""Capture screenshot of PR page. Returns list of pre-signed URLs (fail-open)."""
from browser import capture_screenshot

if not pr_url or not pr_url.startswith("https://github.com/"):
return []
try:
url = capture_screenshot(pr_url, task_id)
return [url] if url else []
except Exception as e:
log("WARN", f"PR screenshot capture failed (non-fatal): {type(e).__name__}: {e}")
return []


def _append_screenshots_to_pr(
config: TaskConfig,
setup: RepoSetup,
screenshot_urls: list[str],
) -> None:
"""Append ## Screenshots section to PR body via gh pr edit."""
if not screenshot_urls:
return
try:
result = subprocess.run(
[
"gh",
"pr",
"view",
setup.branch,
"--repo",
config.repo_url,
"--json",
"body",
"-q",
".body",
],
cwd=setup.repo_dir,
capture_output=True,
text=True,
timeout=30,
)
if result.returncode != 0:
log("WARN", "Could not read PR body for screenshot append")
return

current_body = result.stdout.strip()
images_md = "\n".join(
f"![Screenshot {i + 1}]({url})" for i, url in enumerate(screenshot_urls)
)
screenshots_section = f"## Screenshots\n\n{images_md}"

if re.search(r"## Screenshots", current_body):
updated_body = re.sub(
r"## Screenshots\n.*?(?=\n## |\Z)",
screenshots_section,
current_body,
flags=re.DOTALL,
)
else:
updated_body = f"{current_body}\n\n{screenshots_section}"

edit_result = subprocess.run(
["gh", "pr", "edit", setup.branch, "--repo", config.repo_url, "--body", updated_body],
cwd=setup.repo_dir,
capture_output=True,
text=True,
timeout=30,
)
if edit_result.returncode == 0:
log("POST", f"Appended {len(screenshot_urls)} screenshot(s) to PR body")
else:
log(
"WARN",
f"gh pr edit failed (rc={edit_result.returncode}): "
f"{edit_result.stderr.strip()[:200]}",
)
except Exception as e:
log("WARN", f"Failed to append screenshots to PR: {type(e).__name__}: {e}")


def _extract_agent_notes(repo_dir: str, branch: str, config: TaskConfig) -> str | None:
"""Extract the "## Agent notes" section from the PR body.

Expand Down
82 changes: 82 additions & 0 deletions agent/tests/test_browser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Unit tests for browser.py screenshot functions."""

import json
from io import BytesIO
from unittest.mock import MagicMock, patch

import pytest

import browser


@pytest.fixture(autouse=True)
def _reset_client():
"""Reset the cached Lambda client between tests."""
browser._lambda_client = None
yield
browser._lambda_client = None


class TestCaptureScreenshot:
def test_success_returns_presigned_url(self, monkeypatch):
monkeypatch.setenv("BROWSER_TOOL_FUNCTION_NAME", "my-browser-fn")
monkeypatch.setenv("AWS_REGION", "us-east-1")

response_payload = json.dumps(
{
"status": "success",
"screenshotS3Key": "screenshots/abc123.png",
"presignedUrl": "https://s3.amazonaws.com/bucket/screenshots/abc123.png",
}
).encode()

mock_client = MagicMock()
mock_client.invoke.return_value = {
"Payload": BytesIO(response_payload),
}

with patch("boto3.client", return_value=mock_client):
url = browser.capture_screenshot("https://github.com/owner/repo/pull/1", "task-123")

assert url == "https://s3.amazonaws.com/bucket/screenshots/abc123.png"
mock_client.invoke.assert_called_once()

def test_error_response_returns_none(self, monkeypatch):
monkeypatch.setenv("BROWSER_TOOL_FUNCTION_NAME", "my-browser-fn")
monkeypatch.setenv("AWS_REGION", "us-east-1")

response_payload = json.dumps(
{
"status": "error",
"error": "page not found",
}
).encode()

mock_client = MagicMock()
mock_client.invoke.return_value = {
"Payload": BytesIO(response_payload),
}

with patch("boto3.client", return_value=mock_client):
url = browser.capture_screenshot("https://example.com", "task-123")

assert url is None

def test_missing_env_var_returns_none(self, monkeypatch):
monkeypatch.delenv("BROWSER_TOOL_FUNCTION_NAME", raising=False)

url = browser.capture_screenshot("https://example.com", "task-123")

assert url is None

def test_lambda_invocation_exception_returns_none(self, monkeypatch):
monkeypatch.setenv("BROWSER_TOOL_FUNCTION_NAME", "my-browser-fn")
monkeypatch.setenv("AWS_REGION", "us-east-1")

mock_client = MagicMock()
mock_client.invoke.side_effect = Exception("Lambda timeout")

with patch("boto3.client", return_value=mock_client):
url = browser.capture_screenshot("https://example.com", "task-123")

assert url is None
81 changes: 81 additions & 0 deletions agent/tests/test_post_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Unit tests for post_hooks screenshot functions."""

from unittest.mock import MagicMock, patch

from post_hooks import _append_screenshots_to_pr, capture_pr_screenshots


class TestCapturePrScreenshots:
def test_returns_urls_on_success(self):
with patch("browser.capture_screenshot", return_value="https://s3/img.png"):
result = capture_pr_screenshots("https://github.com/owner/repo/pull/1", "task-1")
assert result == ["https://s3/img.png"]

def test_returns_empty_list_when_pr_url_empty(self):
result = capture_pr_screenshots("", "task-1")
assert result == []

def test_returns_empty_list_when_pr_url_not_github(self):
result = capture_pr_screenshots("https://gitlab.com/owner/repo/pull/1", "task-1")
assert result == []

def test_returns_empty_list_on_exception(self):
with patch("browser.capture_screenshot", side_effect=RuntimeError("boom")):
result = capture_pr_screenshots("https://github.com/owner/repo/pull/1", "task-1")
assert result == []


class TestAppendScreenshotsToPr:
def _make_mocks(self):
config = MagicMock()
config.repo_url = "https://github.com/owner/repo"
setup = MagicMock()
setup.branch = "bgagent/task-1"
setup.repo_dir = "/tmp/repo"
return config, setup

def test_appends_screenshots_section(self):
config, setup = self._make_mocks()
view_result = MagicMock(returncode=0, stdout="## Summary\n\nSome PR body")
edit_result = MagicMock(returncode=0, stderr="")
with patch("post_hooks.subprocess.run", side_effect=[view_result, edit_result]) as mock_run:
_append_screenshots_to_pr(config, setup, ["https://s3/img1.png"])
edit_call = mock_run.call_args_list[1]
body_arg = edit_call[0][0][edit_call[0][0].index("--body") + 1]
assert "## Screenshots" in body_arg
assert "![Screenshot 1](https://s3/img1.png)" in body_arg

def test_replaces_existing_screenshots_section(self):
config, setup = self._make_mocks()
existing_body = "## Summary\n\nBody\n\n## Screenshots\n\n![Screenshot 1](https://old.png)"
view_result = MagicMock(returncode=0, stdout=existing_body)
edit_result = MagicMock(returncode=0, stderr="")
with patch("post_hooks.subprocess.run", side_effect=[view_result, edit_result]) as mock_run:
_append_screenshots_to_pr(config, setup, ["https://s3/new.png"])
edit_call = mock_run.call_args_list[1]
body_arg = edit_call[0][0][edit_call[0][0].index("--body") + 1]
assert "![Screenshot 1](https://s3/new.png)" in body_arg
assert "https://old.png" not in body_arg
# Should only have one ## Screenshots section
assert body_arg.count("## Screenshots") == 1

def test_handles_gh_pr_view_failure(self):
config, setup = self._make_mocks()
view_result = MagicMock(returncode=1, stdout="", stderr="not found")
with patch("post_hooks.subprocess.run", return_value=view_result):
# Should not raise
_append_screenshots_to_pr(config, setup, ["https://s3/img.png"])

def test_handles_gh_pr_edit_failure(self):
config, setup = self._make_mocks()
view_result = MagicMock(returncode=0, stdout="## Summary\n\nBody")
edit_result = MagicMock(returncode=1, stderr="permission denied")
with patch("post_hooks.subprocess.run", side_effect=[view_result, edit_result]):
# Should not raise
_append_screenshots_to_pr(config, setup, ["https://s3/img.png"])

def test_does_nothing_when_urls_empty(self):
config, setup = self._make_mocks()
with patch("post_hooks.subprocess.run") as mock_run:
_append_screenshots_to_pr(config, setup, [])
mock_run.assert_not_called()
7 changes: 7 additions & 0 deletions agent/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ def boom(**_kwargs):
assert body["status"] == "unhealthy"
assert body["reason"] == "background_pipeline_failed"

# Wait for the background thread to finish so write_terminal has been called
# (the 503 flag is set before write_terminal in the except block)
with server._threads_lock:
threads = list(server._active_threads)
for t in threads:
t.join(timeout=5)

mock_write.assert_called()
call_kw = mock_write.call_args
assert call_kw[0][0] == "task-crash-1"
Expand Down
6 changes: 5 additions & 1 deletion cdk/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
"@aws-sdk/client-bedrock-agentcore": "^3.1021.0",
"@aws-sdk/client-bedrock-runtime": "^3.1021.0",
"@aws-sdk/client-ecs": "^3.1021.0",
"@aws-sdk/client-s3": "^3.1021.0",
"@aws-sdk/s3-request-presigner": "^3.1021.0",
"@aws-sdk/client-dynamodb": "^3.1021.0",
"@aws-sdk/client-lambda": "^3.1021.0",
"@aws-sdk/client-secrets-manager": "^3.1021.0",
Expand All @@ -27,14 +29,16 @@
"aws-cdk-lib": "^2.238.0",
"cdk-nag": "^2.37.55",
"constructs": "^10.3.0",
"ulid": "^3.0.2"
"ulid": "^3.0.2",
"ws": "^8.18.0"
},
"devDependencies": {
"@cdklabs/eslint-plugin": "^1.5.10",
"@stylistic/eslint-plugin": "^2",
"@types/aws-lambda": "^8.10.161",
"@types/jest": "^30.0.0",
"@types/node": "^20",
"@types/ws": "^8.18.0",
"@typescript-eslint/eslint-plugin": "^8",
"@typescript-eslint/parser": "^8",
"aws-cdk": "^2",
Expand Down
Loading
Loading