diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index 355e4e0f40ae..dce719859682 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -1,14 +1,16 @@ import asyncio import base64 +import ipaddress import math import os +import socket import tempfile from collections import defaultdict from dataclasses import dataclass from io import BytesIO from pathlib import Path from typing import Any, Coroutine, Dict, List, Optional, Tuple, TypedDict, Union -from urllib.parse import unquote, urlparse +from urllib.parse import unquote, urljoin, urlparse import aiohttp import numpy as np @@ -111,6 +113,54 @@ def convert_image_mode(image: Image.Image, to_mode: str) -> Image.Image: return image.convert(to_mode) +# SSRF/DoS protections for user-supplied URLs in multimodal inputs (see NVBugs 5911304). +_MAX_RESPONSE_BYTES = 200 * 1024 * 1024 +_MAX_REDIRECTS = 5 +_REDIRECT_STATUSES = (301, 302, 303, 307, 308) + + +def _validate_public_url(url: str) -> None: + """Reject non-http(s) schemes and URLs resolving to non-public addresses.""" + parsed = urlparse(url) + if parsed.scheme not in ("http", "https"): + raise RuntimeError( + f"Only http/https URLs are allowed, got {parsed.scheme!r}") + if not parsed.hostname: + raise RuntimeError("URL has no hostname") + try: + infos = socket.getaddrinfo(parsed.hostname, + None, + proto=socket.IPPROTO_TCP) + except socket.gaierror as exc: + raise RuntimeError( + f"Could not resolve hostname {parsed.hostname!r}") from exc + for *_, sockaddr in infos: + ip = ipaddress.ip_address(sockaddr[0]) + if not ip.is_global or ip.is_multicast: + raise RuntimeError(f"URL resolves to a non-public address ({ip})") + + +async def _safe_aiohttp_get(url: str, session: aiohttp.ClientSession) -> bytes: + """Fetch *url*, validating each redirect hop and capping response size.""" + current = url + for _ in range(_MAX_REDIRECTS + 1): + await asyncio.to_thread(_validate_public_url, current) + async with session.get(current, allow_redirects=False) as response: + if response.status in _REDIRECT_STATUSES: + current = urljoin(current, response.headers.get("Location", "")) + continue + response.raise_for_status() + buf = BytesIO() + total = 0 + async for chunk in response.content.iter_chunked(1 << 20): + total += len(chunk) + if total > _MAX_RESPONSE_BYTES: + raise RuntimeError("Response exceeds maximum allowed size") + buf.write(chunk) + return buf.getvalue() + raise RuntimeError("Too many redirects") + + def _load_and_convert_image(image): image = Image.open(image) image.load() @@ -176,8 +226,7 @@ async def async_load_image( if parsed_url.scheme in ["http", "https"]: session = await _get_aiohttp_session() - async with session.get(image) as response: - content = await response.read() + content = await _safe_aiohttp_get(image, session) image = await asyncio.to_thread(_load_and_convert_image, BytesIO(content)) elif parsed_url.scheme == "data": @@ -440,8 +489,7 @@ def _load_from_bytes(data: bytes) -> VideoData: if parsed_url.scheme in ["http", "https"]: session = await _get_aiohttp_session() - async with session.get(video) as response: - content = await response.content.read() + content = await _safe_aiohttp_get(video, session) return await asyncio.to_thread(_load_from_bytes, content) elif parsed_url.scheme == "data": decoded_video = load_base64_video(video) @@ -489,8 +537,7 @@ async def async_load_audio( if parsed_url.scheme in ["http", "https"]: session = await _get_aiohttp_session() - async with session.get(audio) as response: - content = await response.content.read() + content = await _safe_aiohttp_get(audio, session) # Offload CPU-bound soundfile decoding to thread pool return await asyncio.to_thread(soundfile.read, BytesIO(content)) elif parsed_url.scheme == "file": diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 58d04c7d6a13..da9c915f2d2a 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -37,6 +37,7 @@ l0_a10: - unittest/inputs/test_chat_template_dispatch.py - unittest/inputs/test_content_format.py - unittest/inputs/test_multimodal.py + - unittest/inputs/test_url_validation.py - unittest/others/test_convert_utils.py - unittest/others/test_lora_manager.py - unittest/others/test_lora_module_count.py diff --git a/tests/unittest/inputs/test_url_validation.py b/tests/unittest/inputs/test_url_validation.py new file mode 100644 index 000000000000..4e3d004d90e8 --- /dev/null +++ b/tests/unittest/inputs/test_url_validation.py @@ -0,0 +1,145 @@ +"""Regression tests for SSRF/DoS protections in tensorrt_llm.inputs.utils. + +Covers NVBugs 5911304: user-supplied multimodal URLs were fetched without +validation, allowing SSRF to private/loopback/IMDS addresses, unbounded +response sizes, and unrestricted redirects. +""" + +import asyncio +import socket +from unittest.mock import patch + +import pytest + +from tensorrt_llm.inputs.utils import _MAX_RESPONSE_BYTES, _safe_aiohttp_get, _validate_public_url + + +def _dns(ip: str): + return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", (ip, 0))] + + +PUBLIC_DNS = _dns("93.184.216.34") # example.com + + +class TestValidatePublicUrl: + def test_rejects_non_http_scheme(self): + with pytest.raises(RuntimeError, match="Only http/https"): + _validate_public_url("file:///etc/passwd") + + def test_rejects_missing_hostname(self): + with pytest.raises(RuntimeError, match="no hostname"): + _validate_public_url("http:///path") + + @pytest.mark.parametrize( + "ip", + [ + "127.0.0.1", + "::1", + "10.0.0.1", + "172.16.0.1", + "192.168.1.1", + "169.254.169.254", # AWS/Azure/GCP IMDS + ], + ) + def test_rejects_non_public_addresses(self, ip): + with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns(ip)): + with pytest.raises(RuntimeError, match="non-public"): + _validate_public_url("http://target.example/") + + def test_rejects_unresolvable_hostname(self): + with patch( + "tensorrt_llm.inputs.utils.socket.getaddrinfo", side_effect=socket.gaierror("nope") + ): + with pytest.raises(RuntimeError, match="Could not resolve"): + _validate_public_url("http://this.does.not.exist.invalid/") + + def test_accepts_public_address(self): + with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS): + _validate_public_url("https://example.com/image.jpg") # no raise + + +class _FakeContent: + def __init__(self, data: bytes, chunk_size: int = 1 << 20): + self._data = data + self._chunk = chunk_size + + async def iter_chunked(self, size): + for i in range(0, len(self._data), self._chunk): + yield self._data[i : i + self._chunk] + + +class _FakeResponse: + def __init__(self, status=200, headers=None, body=b""): + self.status = status + self.headers = headers or {} + self.content = _FakeContent(body) + + async def __aenter__(self): + return self + + async def __aexit__(self, *_): + return False + + def raise_for_status(self): + if self.status >= 400: + raise RuntimeError(f"http {self.status}") + + +class _FakeSession: + def __init__(self, responses): + self._responses = list(responses) + + def get(self, url, **kwargs): + return self._responses.pop(0) + + +def _run(coro): + return asyncio.new_event_loop().run_until_complete(coro) + + +class TestSafeAiohttpGet: + def test_validates_before_request(self): + session = _FakeSession([]) # no responses needed; should fail validation + with patch( + "tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=_dns("169.254.169.254") + ): + with pytest.raises(RuntimeError, match="non-public"): + _run(_safe_aiohttp_get("http://target/", session)) + + def test_rejects_redirect_to_private(self): + responses = [ + _FakeResponse(status=302, headers={"Location": "http://internal.corp/secret"}), + ] + session = _FakeSession(responses) + + def dns(host, *a, **kw): + return _dns("10.0.0.1") if "internal" in host else PUBLIC_DNS + + with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", side_effect=dns): + with pytest.raises(RuntimeError, match="non-public"): + _run(_safe_aiohttp_get("http://example.com/", session)) + + def test_rejects_oversized_response(self): + big = b"x" * (_MAX_RESPONSE_BYTES + 1) + responses = [_FakeResponse(status=200, body=big)] + session = _FakeSession(responses) + with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS): + with pytest.raises(RuntimeError, match="maximum allowed size"): + _run(_safe_aiohttp_get("http://example.com/", session)) + + def test_rejects_too_many_redirects(self): + # 7 redirects -> exceeds _MAX_REDIRECTS (5) + responses = [ + _FakeResponse(status=302, headers={"Location": "http://example.com/next"}) + for _ in range(7) + ] + session = _FakeSession(responses) + with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS): + with pytest.raises(RuntimeError, match="Too many redirects"): + _run(_safe_aiohttp_get("http://example.com/", session)) + + def test_returns_body_on_success(self): + responses = [_FakeResponse(status=200, body=b"hello")] + session = _FakeSession(responses) + with patch("tensorrt_llm.inputs.utils.socket.getaddrinfo", return_value=PUBLIC_DNS): + assert _run(_safe_aiohttp_get("http://example.com/", session)) == b"hello"