diff --git a/langfuse/_client/observe.py b/langfuse/_client/observe.py
index c648a0a62..3ef0d8706 100644
--- a/langfuse/_client/observe.py
+++ b/langfuse/_client/observe.py
@@ -535,6 +535,20 @@ def _wrap_async_generator_result(
 observe = _decorator.observe
 
 
+def _get_generator_output(
+    items: List[Any],
+    transform_fn: Optional[Callable[[Iterable], str]],
+) -> Any:
+    output: Any = items
+
+    if transform_fn is not None:
+        output = transform_fn(items)
+    elif all(isinstance(item, str) for item in items):
+        output = "".join(items)
+
+    return output
+
+
 class _ContextPreservedSyncGeneratorWrapper:
     """Sync generator wrapper that ensures each iteration runs in preserved context."""
 
@@ -560,9 +574,17 @@ def __init__(
         self.items: List[Any] = []
         self.span = span
         self.transform_fn = transform_fn
+        self._finalized = False
 
-    def __iter__(self) -> "_ContextPreservedSyncGeneratorWrapper":
-        return self
+    def __iter__(self) -> Generator[Any, None, None]:
+        try:
+            while True:
+                yield self.__next__()
+        except StopIteration:
+            return
+        finally:
+            if not self._finalized:
+                self.close()
 
     def __next__(self) -> Any:
         try:
@@ -573,25 +595,65 @@ def __next__(self) -> Any:
             return item
 
         except StopIteration:
-            # Handle output and span cleanup when generator is exhausted
-            output: Any = self.items
+            self._finalize()
+            raise  # Re-raise StopIteration
 
-            if self.transform_fn is not None:
-                output = self.transform_fn(self.items)
+        except (Exception, asyncio.CancelledError) as e:
+            self._finalize(error=e)
+            raise
 
-            elif all(isinstance(item, str) for item in self.items):
-                output = "".join(self.items)
+    def close(self) -> None:
+        if self._finalized:
+            return
 
-            self.span.update(output=output).end()
+        try:
+            close_method = getattr(self.generator, "close", None)
+            if callable(close_method):
+                self.context.run(close_method)
+        except (Exception, asyncio.CancelledError) as e:
+            self._finalize(error=e)
+            raise
 
-            raise  # Re-raise StopIteration
+        self._finalize()
 
+    def throw(self, typ: Any, val: Any = None, tb: Any = None) -> Any:
+        throw_method = getattr(self.generator, "throw", None)
+        if not callable(throw_method):
+            raise AttributeError("Wrapped generator does not support throw()")
+
+        try:
+            if tb is not None:
+                item = self.context.run(throw_method, typ, val, tb)
+            elif val is not None:
+                item = self.context.run(throw_method, typ, val)
+            else:
+                item = self.context.run(throw_method, typ)
+
+            self.items.append(item)
+
+            return item
+        except StopIteration:
+            self._finalize()
+            raise
         except (Exception, asyncio.CancelledError) as e:
+            self._finalize(error=e)
+            raise
+
+    def _finalize(self, error: Optional[BaseException] = None) -> None:
+        if self._finalized:
+            return
+
+        self._finalized = True
+
+        if error is not None:
             self.span.update(
-                level="ERROR", status_message=str(e) or type(e).__name__
+                level="ERROR", status_message=str(error) or type(error).__name__
             ).end()
+            return
 
-            raise
+        self.span.update(
+            output=_get_generator_output(self.items, self.transform_fn)
+        ).end()
 
 
 class _ContextPreservedAsyncGeneratorWrapper:
@@ -619,6 +681,7 @@ def __init__(
         self.items: List[Any] = []
         self.span = span
         self.transform_fn = transform_fn
+        self._finalized = False
 
     def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper":
         return self
 
@@ -626,36 +689,85 @@ def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper":
     async def __anext__(self) -> Any:
         try:
             # Run the generator's __anext__ in the preserved context
-            try:
-                # Python 3.10+ approach with context parameter
-                item = await asyncio.create_task(
-                    self.generator.__anext__(),  # type: ignore
-                    context=self.context,
-                )  # type: ignore
-            except TypeError:
-                # Python < 3.10 fallback - context parameter not supported
-                item = await self.generator.__anext__()
+            item = await self._run_in_preserved_context(self.generator.__anext__)
 
             self.items.append(item)
 
             return item
 
         except StopAsyncIteration:
-            # Handle output and span cleanup when generator is exhausted
-            output: Any = self.items
+            self._finalize()
+            raise  # Re-raise StopAsyncIteration
+        except (Exception, asyncio.CancelledError) as e:
+            self._finalize(error=e)
+            raise
 
-            if self.transform_fn is not None:
-                output = self.transform_fn(self.items)
+    async def close(self) -> None:
+        await self.aclose()
 
-            elif all(isinstance(item, str) for item in self.items):
-                output = "".join(self.items)
+    async def aclose(self) -> None:
+        if self._finalized:
+            return
 
-            self.span.update(output=output).end()
+        try:
+            close_method = getattr(self.generator, "aclose", None)
+            if callable(close_method):
+                await self._run_in_preserved_context(close_method)
+        except (Exception, asyncio.CancelledError) as e:
+            self._finalize(error=e)
+            raise
 
-            raise  # Re-raise StopAsyncIteration
+        self._finalize()
+
+    async def athrow(self, typ: Any, val: Any = None, tb: Any = None) -> Any:
+        throw_method = getattr(self.generator, "athrow", None)
+        if not callable(throw_method):
+            raise AttributeError("Wrapped async generator does not support athrow()")
+
+        try:
+            if tb is not None:
+                item = await self._run_in_preserved_context(
+                    lambda: throw_method(typ, val, tb)
+                )
+            elif val is not None:
+                item = await self._run_in_preserved_context(
+                    lambda: throw_method(typ, val)
+                )
+            else:
+                item = await self._run_in_preserved_context(lambda: throw_method(typ))
+
+            self.items.append(item)
+
+            return item
+        except StopAsyncIteration:
+            self._finalize()
+            raise
         except (Exception, asyncio.CancelledError) as e:
+            self._finalize(error=e)
+            raise
+
+    async def _run_in_preserved_context(self, factory: Callable[[], Any]) -> Any:
+        awaitable = self.context.run(factory)
+
+        try:
+            task = asyncio.create_task(awaitable, context=self.context)  # type: ignore[call-arg]
+        except TypeError:
+            task = self.context.run(asyncio.create_task, awaitable)
+
+        return await task
+
+    def _finalize(self, error: Optional[BaseException] = None) -> None:
+        if self._finalized:
+            return
+
+        self._finalized = True
+
+        if error is not None:
             self.span.update(
-                level="ERROR", status_message=str(e) or type(e).__name__
+                level="ERROR", status_message=str(error) or type(error).__name__
             ).end()
+            return
 
-            raise
+        self.span.update(
+            output=_get_generator_output(self.items, self.transform_fn)
+        ).end()
diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py
index 0cd4dd133..373461c51 100644
--- a/langfuse/langchain/CallbackHandler.py
+++ b/langfuse/langchain/CallbackHandler.py
@@ -1057,10 +1057,10 @@ def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
             and len(message.tool_calls) > 0
         ):
             message_dict["tool_calls"] = message.tool_calls
-
+
         if (
-            hasattr(message, "invalid_tool_calls")
-            and message.invalid_tool_calls is not None
+            hasattr(message, "invalid_tool_calls")
+            and message.invalid_tool_calls is not None
             and len(message.invalid_tool_calls) > 0
         ):
             message_dict["invalid_tool_calls"] = message.invalid_tool_calls
diff --git a/langfuse/openai.py b/langfuse/openai.py
index 16d293e73..eb5a3b8c2 100644
--- a/langfuse/openai.py
+++ b/langfuse/openai.py
@@ -1010,6 +1010,7 @@ def __init__(
         self.response = response
         self.generation = generation
         self.completion_start_time: Optional[datetime] = None
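+        # Tracks whether the generation span was already ended, so explicit
+        # close() and iterator exhaustion cannot finalize it twice.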
+        self._finalized = False
 
     def __iter__(self) -> Any:
         try:
@@ -1039,12 +1040,31 @@ def __next__(self) -> Any:
             raise
 
     def __enter__(self) -> Any:
-        return self.__iter__()
+        return self
 
     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
-        pass
+        self.close()
+
+    def close(self) -> None:
+        if self._finalized:
+            return
+
+        close_method = getattr(self.response, "close", None)
+        if callable(close_method):
+            try:
+                close_method()
+            finally:
+                self._finalize()
+            return
+
+        self._finalize()
 
     def _finalize(self) -> None:
+        if self._finalized:
+            return
+
+        self._finalized = True
+
         try:
             model, completion, usage, metadata = (
                 _extract_streamed_response_api_response(self.items)
             )
@@ -1081,6 +1101,7 @@ def __init__(
         self.response = response
         self.generation = generation
         self.completion_start_time: Optional[datetime] = None
+        self._finalized = False
 
     async def __aiter__(self) -> Any:
         try:
@@ -1110,12 +1131,17 @@ async def __anext__(self) -> Any:
             raise
 
     async def __aenter__(self) -> Any:
-        return self.__aiter__()
+        return self
 
     async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
-        pass
+        await self.aclose()
 
     async def _finalize(self) -> None:
+        if self._finalized:
+            return
+
+        self._finalized = True
+
         try:
             model, completion, usage, metadata = (
                 _extract_streamed_response_api_response(self.items)
             )
@@ -1142,11 +1168,40 @@ async def close(self) -> None:
 
         Automatically called if the response body is read to completion.
         """
-        await self.response.close()
+        if self._finalized:
+            return
+
+        close_method = getattr(self.response, "close", None)
+        if callable(close_method):
+            try:
+                await close_method()
+            finally:
+                await self._finalize()
+            return
+
+        await self._finalize()
 
     async def aclose(self) -> None:
         """Close the response and release the connection.
 
         Automatically called if the response body is read to completion.
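+
+        Safe to call more than once; the generation is finalized only once.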
""" - await self.response.aclose() + if self._finalized: + return + + close_method = getattr(self.response, "aclose", None) + if callable(close_method): + try: + await close_method() + finally: + await self._finalize() + else: + close_method = getattr(self.response, "close", None) + if callable(close_method): + try: + await close_method() + finally: + await self._finalize() + return + + await self._finalize() diff --git a/tests/test_decorators.py b/tests/test_decorators.py index c6ed42594..5cf0714d2 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -1,8 +1,10 @@ import asyncio +import contextvars import os import sys from collections import defaultdict from concurrent.futures import ThreadPoolExecutor +from contextlib import aclosing from time import sleep from typing import Optional @@ -13,6 +15,10 @@ from langfuse import Langfuse, get_client, observe, propagate_attributes from langfuse._client.environment_variables import LANGFUSE_PUBLIC_KEY +from langfuse._client.observe import ( + _ContextPreservedAsyncGeneratorWrapper, + _ContextPreservedSyncGeneratorWrapper, +) from langfuse._client.resource_manager import LangfuseResourceManager from langfuse.langchain import CallbackHandler from langfuse.media import LangfuseMedia @@ -25,6 +31,25 @@ mock_kwargs = {"a": 1, "b": 2, "c": 3} +class _FakeObservation: + def __init__(self) -> None: + self.output = None + self.level = None + self.status_message = None + self.end_calls = 0 + + def update(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + return self + + def end(self): + self.end_calls += 1 + + return self + + def removeMockResourceManagerInstances(): with LangfuseResourceManager._lock: for public_key in list(LangfuseResourceManager._instances.keys()): @@ -1759,6 +1784,35 @@ def root_function(): assert generator_obs.output == "item_0item_1item_2" +def test_sync_generator_wrapper_finalizes_on_early_break(): + closed = [] + + def generator(): + try: + yield "item_0" + yield "item_1" + finally: + closed.append("closed") + + span = _FakeObservation() + wrapper = _ContextPreservedSyncGeneratorWrapper( + generator(), + contextvars.copy_context(), + span, + None, + ) + + items = [] + for item in wrapper: + items.append(item) + break + + assert items == ["item_0"] + assert closed == ["closed"] + assert span.output == "item_0" + assert span.end_calls == 1 + + @pytest.mark.asyncio @pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11 or higher") async def test_async_generator_context_preservation(): @@ -1938,6 +1992,79 @@ async def root_function(): assert "Generator failure test" in failing_obs.status_message +@pytest.mark.asyncio +async def test_async_generator_wrapper_aclose_finalizes_partial_output(): + closed = [] + + async def generator(): + try: + yield "async_item_0" + yield "async_item_1" + finally: + closed.append("closed") + + span = _FakeObservation() + wrapper = _ContextPreservedAsyncGeneratorWrapper( + generator(), + contextvars.copy_context(), + span, + None, + ) + + items = [] + async with aclosing(wrapper) as stream: + async for item in stream: + items.append(item) + break + + assert items == ["async_item_0"] + assert closed == ["closed"] + assert span.output == "async_item_0" + assert span.end_calls == 1 + + +@pytest.mark.asyncio +async def test_async_generator_wrapper_preserves_context_without_task_context_kwarg( + monkeypatch, +): + marker = contextvars.ContextVar("marker", default=None) + preserved_context = contextvars.copy_context() + span = 
+    original_create_task = asyncio.create_task
+
+    token = marker.set("preserved")
+    try:
+        preserved_context = contextvars.copy_context()
+    finally:
+        marker.reset(token)
+
+    async def generator():
+        assert marker.get() == "preserved"
+        yield "ok"
+
+    def patched_create_task(awaitable, *args, **kwargs):
+        if "context" in kwargs:
+            raise TypeError(
+                "create_task() got an unexpected keyword argument 'context'"
+            )
+
+        return original_create_task(awaitable, *args, **kwargs)
+
+    monkeypatch.setattr(asyncio, "create_task", patched_create_task)
+
+    wrapper = _ContextPreservedAsyncGeneratorWrapper(
+        generator(),
+        preserved_context,
+        span,
+        None,
+    )
+
+    assert await wrapper.__anext__() == "ok"
+    await wrapper.aclose()
+    assert span.output == "ok"
+    assert span.end_calls == 1
+
+
 def test_sync_generator_empty_context_preservation():
     """Test that empty sync generators work correctly with context preservation"""
     langfuse = get_client()
diff --git a/tests/test_openai.py b/tests/test_openai.py
index 47f17a5c8..9c1e24fc1 100644
--- a/tests/test_openai.py
+++ b/tests/test_openai.py
@@ -1,16 +1,90 @@
 import importlib
 import os
+from types import SimpleNamespace
 from time import sleep
 
 import pytest
 from pydantic import BaseModel
 
 from langfuse._client.client import Langfuse
+from langfuse.openai import (
+    LangfuseResponseGeneratorAsync,
+    LangfuseResponseGeneratorSync,
+)
 from tests.utils import create_uuid, encode_file_to_base64, get_api
 
 langfuse = Langfuse()
 
 
+class _FakeGeneration:
+    def __init__(self) -> None:
+        self.output = None
+        self.model = None
+        self.metadata = None
+        self.completion_start_time = None
+        self.end_calls = 0
+
+    def update(self, **kwargs):
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+        return self
+
+    def end(self):
+        self.end_calls += 1
+
+        return self
+
+
+class _FakeSyncStreamResponse:
+    def __init__(self, chunks):
+        self._iterator = iter(chunks)
+        self.close_calls = 0
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return next(self._iterator)
+
+    def close(self):
+        self.close_calls += 1
+
+
+class _FakeAsyncStreamResponse:
+    def __init__(self, chunks):
+        self._chunks = list(chunks)
+        self.close_calls = []
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if not self._chunks:
+            raise StopAsyncIteration
+
+        return self._chunks.pop(0)
+
+    async def close(self):
+        self.close_calls.append("close")
+
+    async def aclose(self):
+        self.close_calls.append("aclose")
+
+
+def _stream_chunk(content: str, *, role: str = "assistant", model: str = "gpt-test"):
+    return SimpleNamespace(
+        model=model,
+        usage=None,
+        choices=[
+            SimpleNamespace(
+                delta=SimpleNamespace(role=role, content=content),
+                finish_reason=None,
+            )
+        ],
+    )
+
+
 @pytest.fixture(scope="module")
 def openai():
     import openai
@@ -73,6 +147,27 @@ def test_openai_chat_completion(openai):
     assert generation.data[0].output["role"] == "assistant"
 
 
+def test_sync_stream_close_finalizes_partial_output():
+    generation = _FakeGeneration()
+    response = _FakeSyncStreamResponse([_stream_chunk("Hel"), _stream_chunk("lo")])
+    stream = LangfuseResponseGeneratorSync(
+        resource=SimpleNamespace(type="chat", object="ChatCompletions"),
+        response=response,
+        generation=generation,
+    )
+
+    first_chunk = next(stream)
+    assert first_chunk.choices[0].delta.content == "Hel"
+
+    stream.close()
+    stream.close()
+
+    assert response.close_calls == 1
+    assert generation.output == "Hel"
+    assert generation.model == "gpt-test"
+    assert generation.end_calls == 1
+
+
 def test_openai_chat_completion_stream(openai):
     generation_name = create_uuid()
     completion = openai.OpenAI().chat.completions.create(
@@ -1163,6 +1258,28 @@ async def test_close_async_stream(openai):
     assert generation.data[0].completion_start_time <= generation.data[0].end_time
 
 
+@pytest.mark.asyncio
+async def test_async_stream_aclose_finalizes_partial_output():
+    generation = _FakeGeneration()
+    response = _FakeAsyncStreamResponse([_stream_chunk("Hel"), _stream_chunk("lo")])
+    stream = LangfuseResponseGeneratorAsync(
+        resource=SimpleNamespace(type="chat", object="ChatCompletions"),
+        response=response,
+        generation=generation,
+    )
+
+    first_chunk = await stream.__anext__()
+    assert first_chunk.choices[0].delta.content == "Hel"
+
+    await stream.aclose()
+    await stream.aclose()
+
+    assert response.close_calls == ["aclose"]
+    assert generation.output == "Hel"
+    assert generation.model == "gpt-test"
+    assert generation.end_calls == 1
+
+
 def test_base_64_image_input(openai):
     client = openai.OpenAI()
     generation_name = "test_base_64_image_input" + create_uuid()[:8]