Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 54 additions & 7 deletions tensorrt_llm/inputs/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import asyncio
import base64
import ipaddress
import math
import os
import socket
import tempfile
from collections import defaultdict
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import Any, Coroutine, Dict, List, Optional, Tuple, TypedDict, Union
from urllib.parse import unquote, urlparse
from urllib.parse import unquote, urljoin, urlparse

import aiohttp
import numpy as np
Expand Down Expand Up @@ -111,6 +113,54 @@ def convert_image_mode(image: Image.Image, to_mode: str) -> Image.Image:
return image.convert(to_mode)


# SSRF/DoS protections for user-supplied URLs in multimodal inputs (see NVBugs 5911304).
# Hard cap on downloaded media size; the streaming fetch aborts once more
# than this many bytes have arrived.
_MAX_RESPONSE_BYTES = 200 * 1024 * 1024
# Maximum number of redirect hops followed before failing with "Too many redirects".
_MAX_REDIRECTS = 5
# HTTP statuses treated as redirects; each hop's Location target is re-validated.
_REDIRECT_STATUSES = (301, 302, 303, 307, 308)


def _validate_public_url(url: str) -> None:
"""Reject non-http(s) schemes and URLs resolving to non-public addresses."""
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
raise RuntimeError(
f"Only http/https URLs are allowed, got {parsed.scheme!r}")
if not parsed.hostname:
raise RuntimeError("URL has no hostname")
try:
infos = socket.getaddrinfo(parsed.hostname,
None,
proto=socket.IPPROTO_TCP)
except socket.gaierror as exc:
raise RuntimeError(
f"Could not resolve hostname {parsed.hostname!r}") from exc
for *_, sockaddr in infos:
ip = ipaddress.ip_address(sockaddr[0])
if not ip.is_global or ip.is_multicast:
raise RuntimeError(f"URL resolves to a non-public address ({ip})")


async def _safe_aiohttp_get(url: str, session: aiohttp.ClientSession) -> bytes:
"""Fetch *url*, validating each redirect hop and capping response size."""
current = url
for _ in range(_MAX_REDIRECTS + 1):
await asyncio.to_thread(_validate_public_url, current)
async with session.get(current, allow_redirects=False) as response:
Comment on lines +143 to +148
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

cat -n tensorrt_llm/inputs/utils.py | head -200

Repository: NVIDIA/TensorRT-LLM

Length of output: 9041


🏁 Script executed:

python3 << 'EOF'
# Research aiohttp's DNS handling behavior
import aiohttp
import inspect

# Check what connector aiohttp uses by default
session = aiohttp.ClientSession()
print("Default connector:", session.connector)
print("Connector type:", type(session.connector))

# Check if there's DNS caching
try:
    print("\nConnector attributes:")
    print(dir(session.connector))
except Exception as e:
    print(f"Error: {e}")

EOF

Repository: NVIDIA/TensorRT-LLM

Length of output: 354


🏁 Script executed:

# Check if there are any DNS-related configurations or mitigations in the codebase
rg -i "dns|resolver|connector|resolve" tensorrt_llm/inputs/utils.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 702


🏁 Script executed:

# Search for aiohttp usage patterns in the codebase
rg -A 5 -B 5 "aiohttp.ClientSession\|session.get\|session.post" tensorrt_llm/inputs/utils.py | head -100

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Search for aiohttp documentation or comments about DNS handling
rg -i "resolver|connector" tensorrt_llm/ | grep -E "\.py:" | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 2032


🏁 Script executed:

# Check if there are any custom connectors or resolver configurations
fd -e py -x grep -l "TCPConnector\|resolver\|Resolver" {} \; | head -10

Repository: NVIDIA/TensorRT-LLM

Length of output: 537


🏁 Script executed:

# Look at aiohttp's connector behavior - check what version is used
grep -r "aiohttp" tensorrt_llm/requirements.txt pyproject.toml setup.py 2>/dev/null | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Check if there's a requirements file with aiohttp version
find . -name "requirements*.txt" -o -name "pyproject.toml" -o -name "setup.py" | head -5 | xargs grep -h "aiohttp" 2>/dev/null | head -10

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Search for any DNS rebinding or SSRF-related security discussions in the repo
rg -i "dns.rebind|rebind|ssrf" tensorrt_llm/ --type py | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 647


🏁 Script executed:

# Look at the full context around the _safe_aiohttp_get function to see if there are other mitigations
sed -n '140,165p' tensorrt_llm/inputs/utils.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1165


🏁 Script executed:

# Check if there's any DNS resolver configuration for the aiohttp session creation
rg -B 10 -A 10 "_get_aiohttp_session" tensorrt_llm/inputs/utils.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 3308


🏁 Script executed:

# Research if aiohttp has built-in DNS caching that would prevent rebinding
python3 << 'EOF'
# Check aiohttp documentation for DNS behavior
# By default, aiohttp.TCPConnector uses asyncio's getaddrinfo with no explicit DNS caching
# This means each connection attempt will perform a new DNS lookup

# The vulnerability window:
# 1. Line 147: socket.getaddrinfo() validates the hostname
# 2. Between lines 147-148: DNS TTL expires or cache is cleared
# 3. Line 148: aiohttp.session.get() performs another getaddrinfo() which gets a different IP

# This is a classic DNS rebinding attack vector:
# Attacker controls dns.example.com
# Step 1 (validation): dns.example.com → 8.8.8.8 (public IP) ✓ passes validation
# Step 2 (connection): dns.example.com → 127.0.0.1 (private IP) ✗ bypasses validation

print("DNS Rebinding Vulnerability Analysis:")
print("=" * 50)
print("The vulnerability is technically valid because:")
print("1. socket.getaddrinfo() and aiohttp's resolver are separate")
print("2. No DNS result pinning between validation and connection")
print("3. Attacker can control DNS responses with low TTL")
print("4. TOCTOU (Time-of-check to time-of-use) vulnerability")
EOF

Repository: NVIDIA/TensorRT-LLM

Length of output: 425


🏁 Script executed:

# Check if there are any mitigations or discussions about this in comments or issues
rg -B 5 -A 5 "NVBugs 5911304" tensorrt_llm/inputs/utils.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 344


🏁 Script executed:

# Verify the actual implementation to confirm the vulnerability is real
sed -n '122,162p' tensorrt_llm/inputs/utils.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1899


DNS rebinding allows bypassing the SSRF validation.

Line 147 validates the hostname resolves to a public IP, but line 148's session.get() performs a separate DNS resolution that could return a different address. An attacker controlling DNS can return a public address during validation and a private/loopback address during connect, defeating the SSRF protection. The validated IP must be pinned to the actual socket connection via a custom connector, or the request must use the resolved address directly instead of the hostname.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/inputs/utils.py` around lines 143 - 148, The SSRF validation in
_safe_aiohttp_get is vulnerable to DNS rebinding because
_validate_public_url(current) checks the hostname but session.get(current)
performs its own DNS resolution; fix by resolving the hostname once via
_validate_public_url (or a new resolver helper) to obtain the concrete IP/port
and then ensure the aiohttp request uses that same IP for the TCP connection
(e.g., create a custom aiohttp.TCPConnector/Resolver that pins the resolved IP
or call session.get() with the resolved IP and set the Host header to the
original hostname). Update _safe_aiohttp_get to pass the pinned IP into the
connector/request so DNS cannot be re-resolved, and keep redirect validation
logic (including calls to _validate_public_url and _MAX_REDIRECTS) intact.

if response.status in _REDIRECT_STATUSES:
current = urljoin(current, response.headers.get("Location", ""))
continue
response.raise_for_status()
buf = BytesIO()
total = 0
async for chunk in response.content.iter_chunked(1 << 20):
total += len(chunk)
if total > _MAX_RESPONSE_BYTES:
raise RuntimeError("Response exceeds maximum allowed size")
buf.write(chunk)
return buf.getvalue()
Comment on lines +148 to +160
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

fd -t f "utils.py" | grep inputs

Repository: NVIDIA/TensorRT-LLM

Length of output: 92


🏁 Script executed:

cat tensorrt_llm/inputs/utils.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 45677


Add timeout parameters to prevent DoS from stalled or trickling remote servers.

The code caps response bytes but lacks connect/read timeouts on the aiohttp session. A hostile server can indefinitely stall the handshake or trickle chunks to keep the shared session occupied, making the DoS mitigation incomplete.

Suggested fix
+# Keep slow or stalled remote media fetches from hanging forever.
+_AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(
+    total=30,
+    connect=10,
+    sock_connect=10,
+    sock_read=10,
+)
+
 async def _safe_aiohttp_get(url: str, session: aiohttp.ClientSession) -> bytes:
     """Fetch *url*, validating each redirect hop and capping response size."""
     current = url
     for _ in range(_MAX_REDIRECTS + 1):
         await asyncio.to_thread(_validate_public_url, current)
-        async with session.get(current, allow_redirects=False) as response:
+        async with session.get(
+                current,
+                allow_redirects=False,
+                timeout=_AIOHTTP_TIMEOUT) as response:
             if response.status in _REDIRECT_STATUSES:
                 current = urljoin(current, response.headers.get("Location", ""))
                 continue
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/inputs/utils.py` around lines 148 - 160, Add explicit connect
and read timeouts to the HTTP request to prevent stalled or trickling responses
from exhausting the session: when calling session.get(...) (the request around
the async with session.get(current, allow_redirects=False) block) pass an
aiohttp timeout (ClientTimeout) that sets reasonable sock_connect and sock_read
(and optionally total) values, and ensure those timeouts apply to redirects
handling; keep the existing response size guard around
response.content.iter_chunked and raise as before if exceeded.

raise RuntimeError("Too many redirects")


def _load_and_convert_image(image):
image = Image.open(image)
image.load()
Expand Down Expand Up @@ -176,8 +226,7 @@ async def async_load_image(

if parsed_url.scheme in ["http", "https"]:
session = await _get_aiohttp_session()
async with session.get(image) as response:
content = await response.read()
content = await _safe_aiohttp_get(image, session)
image = await asyncio.to_thread(_load_and_convert_image,
BytesIO(content))
elif parsed_url.scheme == "data":
Expand Down Expand Up @@ -440,8 +489,7 @@ def _load_from_bytes(data: bytes) -> VideoData:

if parsed_url.scheme in ["http", "https"]:
session = await _get_aiohttp_session()
async with session.get(video) as response:
content = await response.content.read()
content = await _safe_aiohttp_get(video, session)
return await asyncio.to_thread(_load_from_bytes, content)
elif parsed_url.scheme == "data":
decoded_video = load_base64_video(video)
Expand Down Expand Up @@ -489,8 +537,7 @@ async def async_load_audio(

if parsed_url.scheme in ["http", "https"]:
session = await _get_aiohttp_session()
async with session.get(audio) as response:
content = await response.content.read()
content = await _safe_aiohttp_get(audio, session)
# Offload CPU-bound soundfile decoding to thread pool
return await asyncio.to_thread(soundfile.read, BytesIO(content))
elif parsed_url.scheme == "file":
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_a10.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ l0_a10:
- unittest/inputs/test_chat_template_dispatch.py
- unittest/inputs/test_content_format.py
- unittest/inputs/test_multimodal.py
- unittest/inputs/test_url_validation.py
- unittest/others/test_convert_utils.py
- unittest/others/test_lora_manager.py
- unittest/others/test_lora_module_count.py
Expand Down
145 changes: 145 additions & 0 deletions tests/unittest/inputs/test_url_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""Regression tests for SSRF/DoS protections in tensorrt_llm.inputs.utils.

Covers NVBugs 5911304: user-supplied multimodal URLs were fetched without
validation, allowing SSRF to private/loopback/IMDS addresses, unbounded
response sizes, and unrestricted redirects.
"""

import asyncio
import socket
from unittest.mock import patch

import pytest

from tensorrt_llm.inputs.utils import _MAX_RESPONSE_BYTES, _safe_aiohttp_get, _validate_public_url


def _dns(ip: str):
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", (ip, 0))]


PUBLIC_DNS = _dns("93.184.216.34") # example.com


class TestValidatePublicUrl:
    """Unit tests for the scheme/hostname/IP screening in _validate_public_url."""

    def test_rejects_non_http_scheme(self):
        with pytest.raises(RuntimeError, match="Only http/https"):
            _validate_public_url("file:///etc/passwd")

    def test_rejects_missing_hostname(self):
        with pytest.raises(RuntimeError, match="no hostname"):
            _validate_public_url("http:///path")

    @pytest.mark.parametrize(
        "blocked_ip",
        [
            "127.0.0.1",        # loopback v4
            "::1",              # loopback v6
            "10.0.0.1",         # RFC1918
            "172.16.0.1",       # RFC1918
            "192.168.1.1",      # RFC1918
            "169.254.169.254",  # AWS/Azure/GCP IMDS
        ],
    )
    def test_rejects_non_public_addresses(self, blocked_ip):
        fake_dns = patch("tensorrt_llm.inputs.utils.socket.getaddrinfo",
                         return_value=_dns(blocked_ip))
        with fake_dns:
            with pytest.raises(RuntimeError, match="non-public"):
                _validate_public_url("http://target.example/")

    def test_rejects_unresolvable_hostname(self):
        fake_dns = patch("tensorrt_llm.inputs.utils.socket.getaddrinfo",
                         side_effect=socket.gaierror("nope"))
        with fake_dns:
            with pytest.raises(RuntimeError, match="Could not resolve"):
                _validate_public_url("http://this.does.not.exist.invalid/")

    def test_accepts_public_address(self):
        with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo",
                   return_value=PUBLIC_DNS):
            # Must complete without raising for a public address.
            _validate_public_url("https://example.com/image.jpg")


class _FakeContent:
def __init__(self, data: bytes, chunk_size: int = 1 << 20):
self._data = data
self._chunk = chunk_size

async def iter_chunked(self, size):
for i in range(0, len(self._data), self._chunk):
yield self._data[i : i + self._chunk]


class _FakeResponse:
    """Minimal stand-in for aiohttp.ClientResponse usable as an async context."""

    def __init__(self, status=200, headers=None, body=b""):
        self.status = status
        self.headers = headers or {}
        self.content = _FakeContent(body)

    async def __aenter__(self):
        # Entering the async context hands back the response itself,
        # matching aiohttp's `async with session.get(...) as response` shape.
        return self

    async def __aexit__(self, *_):
        # Never suppress exceptions raised inside the context.
        return False

    def raise_for_status(self):
        if self.status >= 400:
            raise RuntimeError(f"http {self.status}")


class _FakeSession:
def __init__(self, responses):
self._responses = list(responses)

def get(self, url, **kwargs):
return self._responses.pop(0)
Comment on lines +88 to +93
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, find the file
find . -name "test_url_validation.py" -type f

Repository: NVIDIA/TensorRT-LLM

Length of output: 110


🏁 Script executed:

# Read the file to understand the context
cat -n tests/unittest/inputs/test_url_validation.py | head -120

Repository: NVIDIA/TensorRT-LLM

Length of output: 4831


🏁 Script executed:

# Search for _safe_aiohttp_get to understand how it's being tested
rg -n "_safe_aiohttp_get" tests/unittest/inputs/test_url_validation.py -A 5 -B 2

Repository: NVIDIA/TensorRT-LLM

Length of output: 2184


🏁 Script executed:

# Find the utils file containing _safe_aiohttp_get
find . -name "utils.py" -path "*/inputs/*" -type f

Repository: NVIDIA/TensorRT-LLM

Length of output: 94


🏁 Script executed:

# Read the implementation
cat -n tensorrt_llm/inputs/utils.py | grep -A 30 "def _safe_aiohttp_get"

Repository: NVIDIA/TensorRT-LLM

Length of output: 1524


Add assertion to test double enforcing allow_redirects=False.

_FakeSession.get() currently ignores kwargs, so tests won't catch a regression if _safe_aiohttp_get() stops passing allow_redirects=False. This parameter is critical to the security contract: it ensures per-hop URL validation instead of allowing aiohttp to auto-follow redirects unsafely, defeating SSRF protections.

Tighten the test double
 class _FakeSession:
     def __init__(self, responses):
         self._responses = list(responses)
 
     def get(self, url, **kwargs):
+        assert kwargs.get("allow_redirects") is False
         return self._responses.pop(0)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
class _FakeSession:
def __init__(self, responses):
self._responses = list(responses)
def get(self, url, **kwargs):
return self._responses.pop(0)
class _FakeSession:
def __init__(self, responses):
self._responses = list(responses)
def get(self, url, **kwargs):
assert kwargs.get("allow_redirects") is False
return self._responses.pop(0)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/inputs/test_url_validation.py` around lines 88 - 93, The test
double _FakeSession.get currently ignores kwargs so regressions where
_safe_aiohttp_get stops passing allow_redirects=False won’t be caught; update
_FakeSession.get to assert that kwargs contains allow_redirects set to False (or
raise/record an error if missing/True) before returning the next response so the
unit test enforces the security contract and fails if _safe_aiohttp_get no
longer passes allow_redirects=False.



def _run(coro):
return asyncio.new_event_loop().run_until_complete(coro)


class TestSafeAiohttpGet:
    """End-to-end checks of _safe_aiohttp_get using fake session and DNS."""

    def test_validates_before_request(self):
        # Validation must fail before any HTTP request is issued, so the
        # fake session deliberately carries no responses at all.
        empty_session = _FakeSession([])
        dns_patch = patch("tensorrt_llm.inputs.utils.socket.getaddrinfo",
                          return_value=_dns("169.254.169.254"))
        with dns_patch:
            with pytest.raises(RuntimeError, match="non-public"):
                _run(_safe_aiohttp_get("http://target/", empty_session))

    def test_rejects_redirect_to_private(self):
        hop = _FakeResponse(
            status=302, headers={"Location": "http://internal.corp/secret"})
        session = _FakeSession([hop])

        def fake_dns(host, *args, **kwargs):
            if "internal" in host:
                return _dns("10.0.0.1")
            return PUBLIC_DNS

        with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo",
                   side_effect=fake_dns):
            with pytest.raises(RuntimeError, match="non-public"):
                _run(_safe_aiohttp_get("http://example.com/", session))

    def test_rejects_oversized_response(self):
        oversized = _FakeResponse(status=200,
                                  body=b"x" * (_MAX_RESPONSE_BYTES + 1))
        session = _FakeSession([oversized])
        with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo",
                   return_value=PUBLIC_DNS):
            with pytest.raises(RuntimeError, match="maximum allowed size"):
                _run(_safe_aiohttp_get("http://example.com/", session))

    def test_rejects_too_many_redirects(self):
        # Seven hops: one more than _MAX_REDIRECTS (5) permits.
        hops = [
            _FakeResponse(status=302,
                          headers={"Location": "http://example.com/next"})
            for _ in range(7)
        ]
        session = _FakeSession(hops)
        with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo",
                   return_value=PUBLIC_DNS):
            with pytest.raises(RuntimeError, match="Too many redirects"):
                _run(_safe_aiohttp_get("http://example.com/", session))

    def test_returns_body_on_success(self):
        session = _FakeSession([_FakeResponse(status=200, body=b"hello")])
        with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo",
                   return_value=PUBLIC_DNS):
            fetched = _run(_safe_aiohttp_get("http://example.com/", session))
            assert fetched == b"hello"
Loading