-
Notifications
You must be signed in to change notification settings - Fork 2.4k
[https://nvbugs/5911304][fix] Add URL validation for media input loading #13680
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,16 @@ | ||
| import asyncio | ||
| import base64 | ||
| import ipaddress | ||
| import math | ||
| import os | ||
| import socket | ||
| import tempfile | ||
| from collections import defaultdict | ||
| from dataclasses import dataclass | ||
| from io import BytesIO | ||
| from pathlib import Path | ||
| from typing import Any, Coroutine, Dict, List, Optional, Tuple, TypedDict, Union | ||
| from urllib.parse import unquote, urlparse | ||
| from urllib.parse import unquote, urljoin, urlparse | ||
|
|
||
| import aiohttp | ||
| import numpy as np | ||
|
|
@@ -111,6 +113,54 @@ def convert_image_mode(image: Image.Image, to_mode: str) -> Image.Image: | |
| return image.convert(to_mode) | ||
|
|
||
|
|
||
| # SSRF/DoS protections for user-supplied URLs in multimodal inputs (see NVBugs 5911304). | ||
| _MAX_RESPONSE_BYTES = 200 * 1024 * 1024 | ||
| _MAX_REDIRECTS = 5 | ||
| _REDIRECT_STATUSES = (301, 302, 303, 307, 308) | ||
|
|
||
|
|
||
| def _validate_public_url(url: str) -> None: | ||
| """Reject non-http(s) schemes and URLs resolving to non-public addresses.""" | ||
| parsed = urlparse(url) | ||
| if parsed.scheme not in ("http", "https"): | ||
| raise RuntimeError( | ||
| f"Only http/https URLs are allowed, got {parsed.scheme!r}") | ||
| if not parsed.hostname: | ||
| raise RuntimeError("URL has no hostname") | ||
| try: | ||
| infos = socket.getaddrinfo(parsed.hostname, | ||
| None, | ||
| proto=socket.IPPROTO_TCP) | ||
| except socket.gaierror as exc: | ||
| raise RuntimeError( | ||
| f"Could not resolve hostname {parsed.hostname!r}") from exc | ||
| for *_, sockaddr in infos: | ||
| ip = ipaddress.ip_address(sockaddr[0]) | ||
| if not ip.is_global or ip.is_multicast: | ||
| raise RuntimeError(f"URL resolves to a non-public address ({ip})") | ||
|
|
||
|
|
||
| async def _safe_aiohttp_get(url: str, session: aiohttp.ClientSession) -> bytes: | ||
| """Fetch *url*, validating each redirect hop and capping response size.""" | ||
| current = url | ||
| for _ in range(_MAX_REDIRECTS + 1): | ||
| await asyncio.to_thread(_validate_public_url, current) | ||
| async with session.get(current, allow_redirects=False) as response: | ||
| if response.status in _REDIRECT_STATUSES: | ||
| current = urljoin(current, response.headers.get("Location", "")) | ||
| continue | ||
| response.raise_for_status() | ||
| buf = BytesIO() | ||
| total = 0 | ||
| async for chunk in response.content.iter_chunked(1 << 20): | ||
| total += len(chunk) | ||
| if total > _MAX_RESPONSE_BYTES: | ||
| raise RuntimeError("Response exceeds maximum allowed size") | ||
| buf.write(chunk) | ||
| return buf.getvalue() | ||
|
Comment on lines
+148
to
+160
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: fd -t f "utils.py" | grep inputsRepository: NVIDIA/TensorRT-LLM Length of output: 92 🏁 Script executed: cat tensorrt_llm/inputs/utils.pyRepository: NVIDIA/TensorRT-LLM Length of output: 45677 Add timeout parameters to prevent DoS from stalled or trickling remote servers. The code caps response bytes but lacks connect/read timeouts on the aiohttp session. A hostile server can indefinitely stall the handshake or trickle chunks to keep the shared session occupied, making the DoS mitigation incomplete. Suggested fix+# Keep slow or stalled remote media fetches from hanging forever.
+_AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(
+ total=30,
+ connect=10,
+ sock_connect=10,
+ sock_read=10,
+)
+
async def _safe_aiohttp_get(url: str, session: aiohttp.ClientSession) -> bytes:
"""Fetch *url*, validating each redirect hop and capping response size."""
current = url
for _ in range(_MAX_REDIRECTS + 1):
await asyncio.to_thread(_validate_public_url, current)
- async with session.get(current, allow_redirects=False) as response:
+ async with session.get(
+ current,
+ allow_redirects=False,
+ timeout=_AIOHTTP_TIMEOUT) as response:
if response.status in _REDIRECT_STATUSES:
current = urljoin(current, response.headers.get("Location", ""))
continue🤖 Prompt for AI Agents |
||
| raise RuntimeError("Too many redirects") | ||
|
|
||
|
|
||
| def _load_and_convert_image(image): | ||
| image = Image.open(image) | ||
| image.load() | ||
|
|
@@ -176,8 +226,7 @@ async def async_load_image( | |
|
|
||
| if parsed_url.scheme in ["http", "https"]: | ||
| session = await _get_aiohttp_session() | ||
| async with session.get(image) as response: | ||
| content = await response.read() | ||
| content = await _safe_aiohttp_get(image, session) | ||
| image = await asyncio.to_thread(_load_and_convert_image, | ||
| BytesIO(content)) | ||
| elif parsed_url.scheme == "data": | ||
|
|
@@ -440,8 +489,7 @@ def _load_from_bytes(data: bytes) -> VideoData: | |
|
|
||
| if parsed_url.scheme in ["http", "https"]: | ||
| session = await _get_aiohttp_session() | ||
| async with session.get(video) as response: | ||
| content = await response.content.read() | ||
| content = await _safe_aiohttp_get(video, session) | ||
| return await asyncio.to_thread(_load_from_bytes, content) | ||
| elif parsed_url.scheme == "data": | ||
| decoded_video = load_base64_video(video) | ||
|
|
@@ -489,8 +537,7 @@ async def async_load_audio( | |
|
|
||
| if parsed_url.scheme in ["http", "https"]: | ||
| session = await _get_aiohttp_session() | ||
| async with session.get(audio) as response: | ||
| content = await response.content.read() | ||
| content = await _safe_aiohttp_get(audio, session) | ||
| # Offload CPU-bound soundfile decoding to thread pool | ||
| return await asyncio.to_thread(soundfile.read, BytesIO(content)) | ||
| elif parsed_url.scheme == "file": | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,145 @@ | ||||||||||||||||||||||||||||
| """Regression tests for SSRF/DoS protections in tensorrt_llm.inputs.utils. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Covers NVBugs 5911304: user-supplied multimodal URLs were fetched without | ||||||||||||||||||||||||||||
| validation, allowing SSRF to private/loopback/IMDS addresses, unbounded | ||||||||||||||||||||||||||||
| response sizes, and unrestricted redirects. | ||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||||||
| import socket | ||||||||||||||||||||||||||||
| from unittest.mock import patch | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| from tensorrt_llm.inputs.utils import _MAX_RESPONSE_BYTES, _safe_aiohttp_get, _validate_public_url | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _dns(ip: str): | ||||||||||||||||||||||||||||
| return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", (ip, 0))] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| PUBLIC_DNS = _dns("93.184.216.34") # example.com | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| class TestValidatePublicUrl: | ||||||||||||||||||||||||||||
| def test_rejects_non_http_scheme(self): | ||||||||||||||||||||||||||||
| with pytest.raises(RuntimeError, match="Only http/https"): | ||||||||||||||||||||||||||||
| _validate_public_url("file:///etc/passwd") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def test_rejects_missing_hostname(self): | ||||||||||||||||||||||||||||
| with pytest.raises(RuntimeError, match="no hostname"): | ||||||||||||||||||||||||||||
| _validate_public_url("http:///path") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| @pytest.mark.parametrize( | ||||||||||||||||||||||||||||
| "ip", | ||||||||||||||||||||||||||||
| [ | ||||||||||||||||||||||||||||
| "127.0.0.1", | ||||||||||||||||||||||||||||
| "::1", | ||||||||||||||||||||||||||||
| "10.0.0.1", | ||||||||||||||||||||||||||||
| "172.16.0.1", | ||||||||||||||||||||||||||||
| "192.168.1.1", | ||||||||||||||||||||||||||||
| "169.254.169.254", # AWS/Azure/GCP IMDS | ||||||||||||||||||||||||||||
| ], | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| def test_rejects_non_public_addresses(self, ip): | ||||||||||||||||||||||||||||
| with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns(ip)): | ||||||||||||||||||||||||||||
| with pytest.raises(RuntimeError, match="non-public"): | ||||||||||||||||||||||||||||
| _validate_public_url("http://target.example/") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def test_rejects_unresolvable_hostname(self): | ||||||||||||||||||||||||||||
| with patch( | ||||||||||||||||||||||||||||
| "tensorrt_llm.inputs.utils.socket.getaddrinfo", side_effect=socket.gaierror("nope") | ||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||
| with pytest.raises(RuntimeError, match="Could not resolve"): | ||||||||||||||||||||||||||||
| _validate_public_url("http://this.does.not.exist.invalid/") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def test_accepts_public_address(self): | ||||||||||||||||||||||||||||
| with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS): | ||||||||||||||||||||||||||||
| _validate_public_url("https://example.com/image.jpg") # no raise | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| class _FakeContent: | ||||||||||||||||||||||||||||
| def __init__(self, data: bytes, chunk_size: int = 1 << 20): | ||||||||||||||||||||||||||||
| self._data = data | ||||||||||||||||||||||||||||
| self._chunk = chunk_size | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| async def iter_chunked(self, size): | ||||||||||||||||||||||||||||
| for i in range(0, len(self._data), self._chunk): | ||||||||||||||||||||||||||||
| yield self._data[i : i + self._chunk] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| class _FakeResponse: | ||||||||||||||||||||||||||||
| def __init__(self, status=200, headers=None, body=b""): | ||||||||||||||||||||||||||||
| self.status = status | ||||||||||||||||||||||||||||
| self.headers = headers or {} | ||||||||||||||||||||||||||||
| self.content = _FakeContent(body) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| async def __aenter__(self): | ||||||||||||||||||||||||||||
| return self | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| async def __aexit__(self, *_): | ||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def raise_for_status(self): | ||||||||||||||||||||||||||||
| if self.status >= 400: | ||||||||||||||||||||||||||||
| raise RuntimeError(f"http {self.status}") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| class _FakeSession: | ||||||||||||||||||||||||||||
| def __init__(self, responses): | ||||||||||||||||||||||||||||
| self._responses = list(responses) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def get(self, url, **kwargs): | ||||||||||||||||||||||||||||
| return self._responses.pop(0) | ||||||||||||||||||||||||||||
|
Comment on lines
+88
to
+93
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: # First, find the file
find . -name "test_url_validation.py" -type fRepository: NVIDIA/TensorRT-LLM Length of output: 110 🏁 Script executed: # Read the file to understand the context
cat -n tests/unittest/inputs/test_url_validation.py | head -120Repository: NVIDIA/TensorRT-LLM Length of output: 4831 🏁 Script executed: # Search for _safe_aiohttp_get to understand how it's being tested
rg -n "_safe_aiohttp_get" tests/unittest/inputs/test_url_validation.py -A 5 -B 2Repository: NVIDIA/TensorRT-LLM Length of output: 2184 🏁 Script executed: # Find the utils file containing _safe_aiohttp_get
find . -name "utils.py" -path "*/inputs/*" -type fRepository: NVIDIA/TensorRT-LLM Length of output: 94 🏁 Script executed: # Read the implementation
cat -n tensorrt_llm/inputs/utils.py | grep -A 30 "def _safe_aiohttp_get"Repository: NVIDIA/TensorRT-LLM Length of output: 1524 Add assertion to test double enforcing
Tighten the test double class _FakeSession:
def __init__(self, responses):
self._responses = list(responses)
def get(self, url, **kwargs):
+ assert kwargs.get("allow_redirects") is False
return self._responses.pop(0)📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _run(coro): | ||||||||||||||||||||||||||||
| return asyncio.new_event_loop().run_until_complete(coro) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| class TestSafeAiohttpGet: | ||||||||||||||||||||||||||||
| def test_validates_before_request(self): | ||||||||||||||||||||||||||||
| session = _FakeSession([]) # no responses needed; should fail validation | ||||||||||||||||||||||||||||
| with patch( | ||||||||||||||||||||||||||||
| "tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("169.254.169.254") | ||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||
| with pytest.raises(RuntimeError, match="non-public"): | ||||||||||||||||||||||||||||
| _run(_safe_aiohttp_get("http://target/", session)) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def test_rejects_redirect_to_private(self): | ||||||||||||||||||||||||||||
| responses = [ | ||||||||||||||||||||||||||||
| _FakeResponse(status=302, headers={"Location": "http://internal.corp/secret"}), | ||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||
| session = _FakeSession(responses) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def dns(host, *a, **kw): | ||||||||||||||||||||||||||||
| return _dns("10.0.0.1") if "internal" in host else PUBLIC_DNS | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", side_effect=dns): | ||||||||||||||||||||||||||||
| with pytest.raises(RuntimeError, match="non-public"): | ||||||||||||||||||||||||||||
| _run(_safe_aiohttp_get("http://example.com/", session)) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def test_rejects_oversized_response(self): | ||||||||||||||||||||||||||||
| big = b"x" * (_MAX_RESPONSE_BYTES + 1) | ||||||||||||||||||||||||||||
| responses = [_FakeResponse(status=200, body=big)] | ||||||||||||||||||||||||||||
| session = _FakeSession(responses) | ||||||||||||||||||||||||||||
| with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS): | ||||||||||||||||||||||||||||
| with pytest.raises(RuntimeError, match="maximum allowed size"): | ||||||||||||||||||||||||||||
| _run(_safe_aiohttp_get("http://example.com/", session)) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def test_rejects_too_many_redirects(self): | ||||||||||||||||||||||||||||
| # 7 redirects -> exceeds _MAX_REDIRECTS (5) | ||||||||||||||||||||||||||||
| responses = [ | ||||||||||||||||||||||||||||
| _FakeResponse(status=302, headers={"Location": "http://example.com/next"}) | ||||||||||||||||||||||||||||
| for _ in range(7) | ||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||
| session = _FakeSession(responses) | ||||||||||||||||||||||||||||
| with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS): | ||||||||||||||||||||||||||||
| with pytest.raises(RuntimeError, match="Too many redirects"): | ||||||||||||||||||||||||||||
| _run(_safe_aiohttp_get("http://example.com/", session)) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def test_returns_body_on_success(self): | ||||||||||||||||||||||||||||
| responses = [_FakeResponse(status=200, body=b"hello")] | ||||||||||||||||||||||||||||
| session = _FakeSession(responses) | ||||||||||||||||||||||||||||
| with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS): | ||||||||||||||||||||||||||||
| assert _run(_safe_aiohttp_get("http://example.com/", session)) == b"hello" | ||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
cat -n tensorrt_llm/inputs/utils.py | head -200Repository: NVIDIA/TensorRT-LLM
Length of output: 9041
🏁 Script executed:
Repository: NVIDIA/TensorRT-LLM
Length of output: 354
🏁 Script executed:
Repository: NVIDIA/TensorRT-LLM
Length of output: 702
🏁 Script executed:
Repository: NVIDIA/TensorRT-LLM
Length of output: 45
🏁 Script executed:
Repository: NVIDIA/TensorRT-LLM
Length of output: 2032
🏁 Script executed:
Repository: NVIDIA/TensorRT-LLM
Length of output: 537
🏁 Script executed:
Repository: NVIDIA/TensorRT-LLM
Length of output: 45
🏁 Script executed:
Repository: NVIDIA/TensorRT-LLM
Length of output: 45
🏁 Script executed:
Repository: NVIDIA/TensorRT-LLM
Length of output: 647
🏁 Script executed:
Repository: NVIDIA/TensorRT-LLM
Length of output: 1165
🏁 Script executed:
Repository: NVIDIA/TensorRT-LLM
Length of output: 3308
🏁 Script executed:
Repository: NVIDIA/TensorRT-LLM
Length of output: 425
🏁 Script executed:
Repository: NVIDIA/TensorRT-LLM
Length of output: 344
🏁 Script executed:
Repository: NVIDIA/TensorRT-LLM
Length of output: 1899
DNS rebinding allows bypassing the SSRF validation.
Line 147 validates the hostname resolves to a public IP, but line 148's
session.get()performs a separate DNS resolution that could return a different address. An attacker controlling DNS can return a public address during validation and a private/loopback address during connect, defeating the SSRF protection. The validated IP must be pinned to the actual socket connection via a custom connector, or the request must use the resolved address directly instead of the hostname.🤖 Prompt for AI Agents