Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions httpx/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
import typing
import warnings
import weakref
from contextlib import asynccontextmanager, contextmanager
from types import TracebackType

Expand Down Expand Up @@ -140,13 +141,14 @@ class BoundSyncStream(SyncByteStream):
"""
A byte stream that is bound to a given response instance, and that
ensures the `response.elapsed` is set once the response is closed.
Uses weakref to avoid reference cycles with the response object.
"""

def __init__(
self, stream: SyncByteStream, response: Response, start: float
) -> None:
self._stream = stream
self._response = response
self._response_ref: weakref.ref[Response] = weakref.ref(response)
self._start = start

def __iter__(self) -> typing.Iterator[bytes]:
Expand All @@ -155,21 +157,24 @@ def __iter__(self) -> typing.Iterator[bytes]:

def close(self) -> None:
elapsed = time.perf_counter() - self._start
self._response.elapsed = datetime.timedelta(seconds=elapsed)
response = self._response_ref()
if response is not None:
response.elapsed = datetime.timedelta(seconds=elapsed)
self._stream.close()


class BoundAsyncStream(AsyncByteStream):
"""
An async byte stream that is bound to a given response instance, and that
ensures the `response.elapsed` is set once the response is closed.
Uses weakref to avoid reference cycles with the response object.
"""

def __init__(
self, stream: AsyncByteStream, response: Response, start: float
) -> None:
self._stream = stream
self._response = response
self._response_ref: weakref.ref[Response] = weakref.ref(response)
self._start = start

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
Expand All @@ -178,7 +183,9 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:

async def aclose(self) -> None:
elapsed = time.perf_counter() - self._start
self._response.elapsed = datetime.timedelta(seconds=elapsed)
response = self._response_ref()
if response is not None:
response.elapsed = datetime.timedelta(seconds=elapsed)
await self._stream.aclose()


Expand Down
100 changes: 100 additions & 0 deletions tests/test_bound_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
Tests for BoundSyncStream and BoundAsyncStream weakref behavior.
These tests verify that the streams properly break reference cycles
to allow garbage collection.
"""

import gc
import typing
import weakref

import pytest

import httpx
from httpx._client import BoundAsyncStream, BoundSyncStream
from httpx._types import AsyncByteStream, SyncByteStream


class MockSyncStream(SyncByteStream):
def __init__(self) -> None:
self.closed = False

def __iter__(self) -> typing.Iterator[bytes]: # pragma: no cover
yield b"test"

def close(self) -> None:
self.closed = True


class MockAsyncStream(AsyncByteStream):
def __init__(self) -> None:
self.closed = False

async def __aiter__(self) -> typing.AsyncIterator[bytes]: # pragma: no cover
yield b"test"

async def aclose(self) -> None:
self.closed = True


def test_bound_sync_stream_sets_elapsed():
response = httpx.Response(200, content=b"")
stream = MockSyncStream()
bound_stream = BoundSyncStream(stream, response=response, start=0.0)
bound_stream.close()
assert hasattr(response, "_elapsed")
assert response.elapsed.total_seconds() >= 0


def test_bound_sync_stream_handles_collected_response():
response = httpx.Response(200, content=b"")
stream = MockSyncStream()
bound_stream = BoundSyncStream(stream, response=response, start=0.0)
del response
gc.collect()
bound_stream.close()
assert stream.closed


def test_bound_sync_stream_no_reference_cycle():
response = httpx.Response(200, content=b"")
response_ref = weakref.ref(response)
stream = MockSyncStream()
bound_stream = BoundSyncStream(stream, response=response, start=0.0)
response.stream = bound_stream
del response
gc.collect()
assert response_ref() is None, "Response should have been garbage collected"


@pytest.mark.anyio
async def test_bound_async_stream_sets_elapsed():
response = httpx.Response(200, content=b"")
stream = MockAsyncStream()
bound_stream = BoundAsyncStream(stream, response=response, start=0.0)
await bound_stream.aclose()
assert hasattr(response, "_elapsed")
assert response.elapsed.total_seconds() >= 0


@pytest.mark.anyio
async def test_bound_async_stream_handles_collected_response():
response = httpx.Response(200, content=b"")
stream = MockAsyncStream()
bound_stream = BoundAsyncStream(stream, response=response, start=0.0)
del response
gc.collect()
await bound_stream.aclose()
assert stream.closed


@pytest.mark.anyio
async def test_bound_async_stream_no_reference_cycle():
response = httpx.Response(200, content=b"")
response_ref = weakref.ref(response)
stream = MockAsyncStream()
bound_stream = BoundAsyncStream(stream, response=response, start=0.0)
response.stream = bound_stream
del response
gc.collect()
assert response_ref() is None, "Response should have been garbage collected"