# Copyright (c) Microsoft. All rights reserved.

"""Double-buffered context window management for chat history.

Implements a proactive context compaction strategy inspired by double buffering
(graphics, 1970s), checkpoint + WAL replay (databases, 1980s), and hopping
windows (stream processing). Instead of stop-the-world summarization when the
context window fills, this reducer begins summarizing at a configurable threshold
while the agent continues working, then swaps to the pre-built back buffer
seamlessly.

"""

# NOTE: the accompanying change to semantic_kernel/contents/__init__.py imports
# ChatHistoryDoubleBufferReducer from this module and lists it in ``__all__``.

import asyncio
import logging
import sys

# Compare the version *tuple*: comparing ``sys.version`` strings is lexicographic
# and mis-orders versions such as "3.9" vs "3.11".
if sys.version_info < (3, 11):
    from typing_extensions import Self  # pragma: no cover
else:
    from typing import Self  # type: ignore # pragma: no cover
if sys.version_info < (3, 12):
    from typing_extensions import override  # pragma: no cover
else:
    from typing import override  # type: ignore # pragma: no cover

from enum import Enum

from pydantic import Field

from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.history_reducer.chat_history_reducer import ChatHistoryReducer
from semantic_kernel.contents.history_reducer.chat_history_reducer_utils import (
    SUMMARY_METADATA_KEY,
    contains_function_call_or_result,
    extract_range,
    locate_safe_reduction_index,
    locate_summarization_boundary,
)
from semantic_kernel.exceptions.content_exceptions import ChatHistoryReducerException
from semantic_kernel.utils.feature_stage_decorator import experimental

logger = logging.getLogger(__name__)

# Metadata key under which a summary message records the generation it was created for.
GENERATION_METADATA_KEY = "__db_generation__"

DEFAULT_SUMMARIZATION_PROMPT = """
Provide a concise and complete summarization of the entire dialog that does not exceed 5 sentences.

This summary must always:
- Consider both user and assistant interactions
- Maintain continuity for the purpose of further dialog
- Include details from any existing summary
- Focus on the most significant aspects of the dialog

This summary must never:
- Critique, correct, interpret, presume, or assume
- Identify faults, mistakes, misunderstanding, or correctness
- Analyze what has not occurred
- Exclude details from any existing summary
"""


def _is_summary_message(message: ChatMessageContent) -> bool:
    """Return True when the message carries the summary metadata marker."""
    return bool(message.metadata) and SUMMARY_METADATA_KEY in message.metadata


class RenewalPolicy(str, Enum):
    """Policy for handling accumulated compression debt across generations."""

    RECURSE = "recurse"
    """Summarize the accumulated summaries (meta-compression)."""

    DUMP = "dump"
    """Discard all summaries and start fresh."""


@experimental
class ChatHistoryDoubleBufferReducer(ChatHistoryReducer):
    """A ChatHistory with double-buffered context window management.

    Instead of stop-the-world compaction when the context fills, this reducer
    begins checkpoint summarization at a configurable threshold (default 70% of
    target capacity) while the agent continues working. New messages are appended
    to both the active buffer and a pre-built back buffer. When the active buffer
    hits the swap threshold, the back buffer becomes the new active context.

    Summaries accumulate across generations up to a configurable limit before a
    full renewal (recurse or dump) is triggered, amortizing the cost of full
    renewal across many handoffs.

    Args:
        target_count: The target message count (context capacity).
        threshold_count: The threshold count to avoid orphaning messages.
        auto_reduce: Whether to automatically reduce the chat history.
        service: The ChatCompletion service to use for summarization.
        checkpoint_threshold: Fraction of target_count at which to begin
            checkpoint summarization (0.0-1.0). Default 0.7.
        swap_threshold: Fraction of target_count at which to swap buffers
            (0.0-1.0). Default 0.95. Must be greater than checkpoint_threshold.
        max_generations: Maximum number of summary-on-summary layers before
            triggering renewal. None means no limit (renewal disabled).
        renewal_policy: How to handle accumulated compression debt when
            max_generations is reached. Default RECURSE.
        summarization_instructions: The summarization prompt template.
        fail_on_error: Raise error if summarization fails. Default True.
        include_function_content_in_summary: Whether to include function
            calls/results in the summary. Default False.
        execution_settings: Execution settings for the summarization prompt.
        checkpoint_timeout: Maximum seconds to wait, at swap time, for a
            background checkpoint that is still running. Default 120.0.
    """

    service: ChatCompletionClientBase
    checkpoint_threshold: float = Field(
        default=0.7,
        gt=0.0,
        le=1.0,
        description="Fraction of target_count at which to begin checkpoint summarization.",
    )
    swap_threshold: float = Field(
        default=0.95,
        gt=0.0,
        le=1.0,
        description="Fraction of target_count at which to swap to the back buffer.",
    )
    max_generations: int | None = Field(
        default=None,
        description="Maximum summary-on-summary layers before renewal. "
        "None means no limit (renewal disabled).",
    )
    renewal_policy: RenewalPolicy = Field(
        default=RenewalPolicy.RECURSE,
        description="How to handle compression debt when max_generations is reached.",
    )
    summarization_instructions: str = Field(
        default=DEFAULT_SUMMARIZATION_PROMPT,
        description="The summarization instructions.",
        kw_only=True,
    )
    fail_on_error: bool = Field(default=True, description="Raise error if summarization fails.")
    include_function_content_in_summary: bool = Field(
        default=False, description="Whether to include function calls/results in the summary."
    )
    execution_settings: PromptExecutionSettings | None = None
    checkpoint_timeout: float = Field(
        default=120.0,
        gt=0.0,
        description="Maximum seconds to wait for a background checkpoint before cancelling.",
    )

    # Internal state — not part of the public API
    _back_buffer: list[ChatMessageContent] | None = None
    _checkpoint_task: asyncio.Task | None = None  # type: ignore[type-arg]
    _checkpoint_in_progress: bool = False
    _current_generation: int = 0

    def model_post_init(self, __context: object) -> None:
        """Validate that swap_threshold > checkpoint_threshold.

        Raises:
            ValueError: If swap_threshold is less than or equal to checkpoint_threshold.
        """
        super().model_post_init(__context)
        if self.swap_threshold <= self.checkpoint_threshold:
            msg = (
                f"swap_threshold ({self.swap_threshold}) must be greater than "
                f"checkpoint_threshold ({self.checkpoint_threshold})"
            )
            raise ValueError(msg)

    @override
    async def reduce(self) -> Self | None:
        """Reduce chat history using double-buffered context management.

        Three-phase algorithm:
        1. Checkpoint: At checkpoint_threshold, fire off background summarization
           to seed the back buffer. The agent continues working immediately.
        2. Concurrent: New messages go to both active and back buffers.
        3. Swap: At swap_threshold, swap to back buffer. If the background
           checkpoint isn't done yet, block on it (graceful degradation to
           stop-the-world, same as today's status quo).

        Checkpoint failures are logged here rather than raised; when no back
        buffer can be produced at swap time the full context is kept.

        Returns:
            self if a swap happened or a background checkpoint was started,
            None if no change is needed (or no checkpoint could be produced).
        """
        total = len(self.messages)
        checkpoint_limit = int(self.target_count * self.checkpoint_threshold)
        swap_limit = int(self.target_count * self.swap_threshold)

        # Reap a completed background checkpoint task so its outcome is observed.
        finished = self._checkpoint_task
        if finished is not None and finished.done():
            self._checkpoint_task = None
            if finished.cancelled():
                # result() would raise CancelledError (a BaseException) out of reduce().
                logger.warning("Background checkpoint task was cancelled.")
            else:
                try:
                    finished.result()
                except Exception as ex:
                    logger.warning("Background checkpoint task failed: %s", ex)

        # Phase 3: Swap — if we've hit the swap threshold
        if total >= swap_limit:
            pending = self._checkpoint_task
            # If checkpoint is still running in background, block with timeout.
            if pending is not None and not pending.done():
                logger.info(
                    "Swap threshold hit while checkpoint still running. "
                    "Blocking on checkpoint (timeout=%.1fs).",
                    self.checkpoint_timeout,
                )
                try:
                    await asyncio.wait_for(pending, timeout=self.checkpoint_timeout)
                except asyncio.TimeoutError:
                    # asyncio.TimeoutError only became the builtin TimeoutError in
                    # Python 3.11; naming the asyncio alias works on 3.10 as well.
                    logger.warning(
                        "Background checkpoint timed out after %.1fs. Cancelling.",
                        self.checkpoint_timeout,
                    )
                    pending.cancel()
                except Exception as ex:
                    logger.warning("Background checkpoint failed at swap time: %s", ex)
                finally:
                    self._checkpoint_task = None
            if self._back_buffer is not None:
                return await self._swap_buffers()
            # No back buffer — checkpoint failed or wasn't started.
            # Fall back to stop-the-world: create checkpoint synchronously, then swap.
            logger.info("Swap threshold reached with no back buffer. Falling back to synchronous checkpoint.")
            try:
                await self._create_checkpoint()
            except Exception as ex:
                logger.warning("Synchronous fallback checkpoint also failed: %s", ex)
            if self._back_buffer is not None:
                return await self._swap_buffers()
            # Truly nothing we can do — checkpoint failed twice. Continue with full context.
            logger.warning("All checkpoint attempts failed at swap time. Continuing with full context.")
            # Do not fall through to Phase 1: that would immediately launch yet
            # another summarization and report a reduction that never happened.
            return None

        # Phase 1: Checkpoint — kick off background summarization.
        # _checkpoint_in_progress only flips once the coroutine starts running, so
        # also require that no task object exists; otherwise a second reduce()
        # issued before the task is scheduled would spawn a duplicate checkpoint.
        if (
            total >= checkpoint_limit
            and self._back_buffer is None
            and self._checkpoint_task is None
            and not self._checkpoint_in_progress
        ):
            self._checkpoint_task = asyncio.create_task(self._create_checkpoint())
            # Return immediately — agent keeps working while checkpoint runs
            return self

        # Phase 2: Concurrent — back buffer kept in sync via add_message_async
        return None

    async def add_message_async(
        self,
        message: ChatMessageContent | dict,
        encoding: str | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Add a message to the chat history and the back buffer if it exists.

        This is the key to the concurrent phase: every new message goes to both
        the active buffer and the back buffer, ensuring the back buffer has
        full-fidelity recent messages.

        Args:
            message: The message to add, as content object or as a dict of
                ChatMessageContent constructor arguments.
            encoding: Forwarded unchanged to the parent implementation.
            metadata: Forwarded unchanged to the parent implementation.
        """
        mirror = self._back_buffer
        await super().add_message_async(message, encoding=encoding, metadata=metadata)
        if mirror is None:
            # No back buffer existed when the message arrived; nothing to mirror.
            return

        entry = message if isinstance(message, ChatMessageContent) else ChatMessageContent(**message)
        if self._back_buffer is mirror:
            # Concurrent phase: append to back buffer too
            mirror.append(entry)
        else:
            # The parent call triggered reduce() (auto-reduce) and the back buffer was
            # promoted to the active history before this message could be mirrored.
            # Add it to the promoted history so the triggering message is not lost.
            self.messages.append(entry)

    async def _create_checkpoint(self) -> Self | None:
        """Phase 1: Summarize current context and seed the back buffer.

        Returns:
            self when a back buffer was seeded, None when there was nothing to
            summarize, the summarizer returned nothing, or an error occurred
            with fail_on_error disabled.

        Raises:
            ChatHistoryReducerException: If any step fails and fail_on_error is True.
        """
        self._checkpoint_in_progress = True
        try:
            # Check if we need renewal first
            if self.max_generations is not None and self._current_generation >= self.max_generations:
                await self._perform_renewal()

            # Capture after renewal — renewal may have reassigned self.messages
            history = self.messages

            # Find the summarization boundary (skip existing summaries)
            insertion_point = locate_summarization_boundary(history)
            if insertion_point == len(history):
                logger.warning("All messages are summaries, forcing boundary to 0.")
                insertion_point = 0

            # For the checkpoint, we want to keep roughly half the messages as
            # recent context in the back buffer, summarizing the rest.
            unsummarized = len(history) - insertion_point
            keep_count = max(1, unsummarized - max(1, unsummarized // 2))
            truncation_index = locate_safe_reduction_index(
                history,
                keep_count,
                self.threshold_count,
                offset_count=insertion_point,
            )
            if truncation_index is None:
                logger.info("No valid truncation index found for checkpoint.")
                return None

            # Extract messages to summarize
            messages_to_summarize = extract_range(
                history,
                start=insertion_point,
                end=truncation_index,
                filter_func=(contains_function_call_or_result if not self.include_function_content_in_summary else None),
                preserve_pairs=self.include_function_content_in_summary,
            )
            if not messages_to_summarize:
                logger.info("No messages to summarize for checkpoint.")
                return None

            # Generate summary
            summary_msg = await self._summarize(messages_to_summarize)
            if not summary_msg:
                return None

            # Tag summary with metadata
            summary_msg.metadata[SUMMARY_METADATA_KEY] = True
            summary_msg.metadata[GENERATION_METADATA_KEY] = self._current_generation + 1

            # Seed back buffer: carried-forward summaries + new summary + recent messages.
            # The slices are taken after the await above, so they reflect the
            # contents of `history` at this point.
            existing_summaries = history[:insertion_point]
            remainder = history[truncation_index:]
            self._back_buffer = [*existing_summaries, summary_msg, *remainder]

            logger.info(
                "Checkpoint created at generation %d. Back buffer seeded with %d messages.",
                self._current_generation + 1,
                len(self._back_buffer),
            )
            return self

        except Exception as ex:
            logger.warning("Checkpoint creation failed: %s", ex)
            if self.fail_on_error:
                raise ChatHistoryReducerException("Double-buffer checkpoint creation failed.") from ex
            return None
        finally:
            # Runs on every exit path, including task cancellation (CancelledError
            # is not an Exception), so the flag can never be left stuck at True.
            self._checkpoint_in_progress = False

    async def _swap_buffers(self) -> Self:
        """Phase 3: Swap the back buffer into the active context.

        Replaces the active messages with the back buffer, clears the back
        buffer and increments the generation counter.
        """
        logger.info(
            "Swapping buffers. Active: %d messages -> Back: %d messages. Generation %d -> %d.",
            len(self.messages),
            len(self._back_buffer) if self._back_buffer else 0,
            self._current_generation,
            self._current_generation + 1,
        )

        self.messages = self._back_buffer or []
        self._back_buffer = None
        self._current_generation += 1
        self._checkpoint_in_progress = False

        return self

    async def _perform_renewal(self) -> None:
        """Handle accumulated compression debt when max_generations is reached.

        DUMP drops every summary message. RECURSE replaces all summary messages
        with a single meta-summary, falling back to dropping them when the
        summarizer returns nothing. Both reset the generation counter to 0.
        """
        logger.info(
            "Max generations (%d) reached. Performing renewal with policy: %s",
            self.max_generations,
            self.renewal_policy.value,
        )

        if self.renewal_policy == RenewalPolicy.DUMP:
            # Drop all summaries, keep only non-summary messages
            self.messages = [msg for msg in self.messages if not _is_summary_message(msg)]
            self._current_generation = 0

        elif self.renewal_policy == RenewalPolicy.RECURSE:
            # Summarize the summaries (meta-compression)
            summary_messages = [msg for msg in self.messages if _is_summary_message(msg)]
            meta_summary = await self._summarize(summary_messages) if summary_messages else None

            # Collect the non-summary messages only now: renewal can run inside the
            # background checkpoint task, and messages added to the history while the
            # summarizer was awaited must survive the reassignment below.
            non_summary_messages = [msg for msg in self.messages if not _is_summary_message(msg)]

            if meta_summary:
                meta_summary.metadata[SUMMARY_METADATA_KEY] = True
                meta_summary.metadata[GENERATION_METADATA_KEY] = 0
                self.messages = [meta_summary, *non_summary_messages]
            else:
                # Fallback: if meta-summarization yields nothing, just dump
                self.messages = non_summary_messages
            self._current_generation = 0

    async def _summarize(self, messages: list[ChatMessageContent]) -> ChatMessageContent | None:
        """Use the ChatCompletion service to generate a single summary message.

        The summarization instructions are appended after the given messages,
        using the execution settings' ``instruction_role`` when present and the
        system role otherwise.
        """
        from semantic_kernel.contents.utils.author_role import AuthorRole

        chat_history = ChatHistory(messages=messages)
        execution_settings = self.execution_settings or self.service.get_prompt_execution_settings_from_settings(
            PromptExecutionSettings()
        )
        chat_history.add_message(
            ChatMessageContent(
                role=getattr(execution_settings, "instruction_role", AuthorRole.SYSTEM),
                content=self.summarization_instructions,
            )
        )
        return await self.service.get_chat_message_content(chat_history=chat_history, settings=execution_settings)

    @property
    def generation(self) -> int:
        """Current generation count (number of buffer swaps completed)."""
        return self._current_generation

    @property
    def has_back_buffer(self) -> bool:
        """Whether a back buffer is currently active (concurrent phase)."""
        return self._back_buffer is not None

    @property
    def back_buffer_size(self) -> int:
        """Number of messages in the back buffer, or 0 if no back buffer."""
        return len(self._back_buffer) if self._back_buffer is not None else 0

    def _comparison_key(self) -> tuple:
        """Return the configuration values that define equality and hashing."""
        return (
            self.target_count,
            self.threshold_count,
            self.checkpoint_threshold,
            self.swap_threshold,
            self.max_generations,
            self.renewal_policy,
            self.summarization_instructions,
            self.fail_on_error,
            self.include_function_content_in_summary,
        )

    def __eq__(self, other: object) -> bool:
        """Check if two ChatHistoryDoubleBufferReducer objects are equal.

        Equality is based on the same configuration values as ``__hash__`` so
        that objects comparing equal always hash equal.
        """
        if not isinstance(other, ChatHistoryDoubleBufferReducer):
            return False
        return self._comparison_key() == other._comparison_key()

    def __hash__(self) -> int:
        """Hash the object based on its configuration values."""
        return hash(self._comparison_key())
# Copyright (c) Microsoft. All rights reserved.

# Unit tests for ChatHistoryDoubleBufferReducer: construction and validation,
# the checkpoint / concurrent / swap phases, renewal policies, and error handling.
# The summarization service is always a mock; no real model is called.

import asyncio
from unittest.mock import AsyncMock, MagicMock

import pytest

from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.history_reducer.chat_history_double_buffer_reducer import (
    GENERATION_METADATA_KEY,
    ChatHistoryDoubleBufferReducer,
    RenewalPolicy,
)
from semantic_kernel.contents.history_reducer.chat_history_reducer_utils import SUMMARY_METADATA_KEY
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.exceptions.content_exceptions import ChatHistoryReducerException


@pytest.fixture
def mock_service():
    """Returns a mock ChatCompletionClientBase with required methods."""
    service = MagicMock(spec=ChatCompletionClientBase)
    service.get_prompt_execution_settings_class.return_value = MagicMock(return_value=MagicMock(service_id="foo"))
    # The reducer awaits this call to obtain the summary; tests set return_value/side_effect on it.
    service.get_chat_message_content = AsyncMock()
    service.get_prompt_execution_settings_from_settings.return_value = PromptExecutionSettings()
    return service


@pytest.fixture
def summary_message():
    """Returns a standard mock summary response."""
    return ChatMessageContent(role=AuthorRole.ASSISTANT, content="This is a summary of the conversation.")


async def _await_checkpoint(reducer: ChatHistoryDoubleBufferReducer) -> None:
    """Wait for a background checkpoint task to complete.

    Any exception raised by the task propagates to the caller. The task
    reference on the reducer is cleared afterwards.
    """
    if reducer._checkpoint_task is not None and not reducer._checkpoint_task.done():
        await reducer._checkpoint_task
    reducer._checkpoint_task = None


def _make_messages(count: int, *, with_summary: bool = False) -> list[ChatMessageContent]:
    """Generate a list of alternating user/assistant messages.

    When with_summary is True, a system message tagged with the summary
    metadata key is placed first, in addition to the `count` messages.
    """
    msgs = []
    if with_summary:
        summary = ChatMessageContent(role=AuthorRole.SYSTEM, content="Prior summary.")
        summary.metadata[SUMMARY_METADATA_KEY] = True
        msgs.append(summary)
    for i in range(count):
        # Even indices are user messages, odd indices are assistant messages.
        role = AuthorRole.USER if i % 2 == 0 else AuthorRole.ASSISTANT
        msgs.append(ChatMessageContent(role=role, content=f"Message {i}"))
    return msgs


# --- Init and Validation Tests ---


def test_double_buffer_reducer_init(mock_service):
    """Explicit constructor arguments are stored on the reducer."""
    reducer = ChatHistoryDoubleBufferReducer(
        service=mock_service,
        target_count=20,
        checkpoint_threshold=0.6,
        swap_threshold=0.9,
        max_generations=3,
        renewal_policy=RenewalPolicy.DUMP,
    )
    assert reducer.target_count == 20
    assert reducer.checkpoint_threshold == 0.6
    assert reducer.swap_threshold == 0.9
    assert reducer.max_generations == 3
    assert reducer.renewal_policy == RenewalPolicy.DUMP


def test_double_buffer_reducer_defaults(mock_service):
    """Default field values and the initial (empty) buffer state."""
    reducer = ChatHistoryDoubleBufferReducer(service=mock_service, target_count=10)
    assert reducer.checkpoint_threshold == 0.7
    assert reducer.swap_threshold == 0.95
    assert reducer.max_generations is None
    assert reducer.renewal_policy == RenewalPolicy.RECURSE
    assert reducer.fail_on_error is True
    assert reducer.generation == 0
    assert reducer.has_back_buffer is False
    assert reducer.back_buffer_size == 0


def test_swap_threshold_must_exceed_checkpoint_threshold(mock_service):
    """A swap threshold below the checkpoint threshold is rejected."""
    with pytest.raises(ValueError, match="swap_threshold.*must be greater"):
        ChatHistoryDoubleBufferReducer(
            service=mock_service,
            target_count=10,
            checkpoint_threshold=0.8,
            swap_threshold=0.5,
        )


def test_swap_threshold_cannot_equal_checkpoint_threshold(mock_service):
    """A swap threshold equal to the checkpoint threshold is rejected."""
    with pytest.raises(ValueError, match="swap_threshold.*must be greater"):
        ChatHistoryDoubleBufferReducer(
            service=mock_service,
            target_count=10,
            checkpoint_threshold=0.7,
            swap_threshold=0.7,
        )


def test_double_buffer_reducer_eq_and_hash(mock_service):
    """Reducers with the same configuration are equal and hash equal; differing max_generations breaks both."""
    r1 = ChatHistoryDoubleBufferReducer(service=mock_service, target_count=10, max_generations=3)
    r2 = ChatHistoryDoubleBufferReducer(service=mock_service, target_count=10, max_generations=3)
    r3 = ChatHistoryDoubleBufferReducer(service=mock_service, target_count=10, max_generations=7)
    assert r1 == r2
    assert r1 != r3
    assert hash(r1) == hash(r2)
    assert hash(r1) != hash(r3)


# --- Phase 1: Checkpoint Tests ---


async def test_no_reduction_below_checkpoint_threshold(mock_service):
    """No action when message count is below checkpoint threshold."""
    reducer = ChatHistoryDoubleBufferReducer(
        service=mock_service,
        target_count=20,
        checkpoint_threshold=0.7,  # triggers at 14 messages
    )
    reducer.messages = _make_messages(10)
    result = await reducer.reduce()
    assert result is None
    assert reducer.has_back_buffer is False
    mock_service.get_chat_message_content.assert_not_awaited()


async def test_checkpoint_creates_back_buffer(mock_service, summary_message):
    """Hitting checkpoint threshold creates a back buffer seeded with summary."""
    mock_service.get_chat_message_content.return_value = summary_message

    reducer = ChatHistoryDoubleBufferReducer(
        service=mock_service,
        target_count=10,
        threshold_count=0,
        checkpoint_threshold=0.5,  # triggers at 5 messages
        swap_threshold=0.9,
    )
    reducer.messages = _make_messages(8)

    result = await reducer.reduce()
    assert result is not None  # checkpoint was kicked off
    await _await_checkpoint(reducer)  # wait for background task
    assert reducer.has_back_buffer is True
    assert reducer.back_buffer_size > 0
    # Original messages should be untouched
    assert len(reducer.messages) == 8


async def test_checkpoint_tags_summary_with_generation(mock_service, summary_message):
    """Summary message should be tagged with generation metadata."""
    mock_service.get_chat_message_content.return_value = summary_message

    reducer = ChatHistoryDoubleBufferReducer(
        service=mock_service,
        target_count=10,
        threshold_count=0,
        checkpoint_threshold=0.5,
        swap_threshold=0.9,
    )
    reducer.messages = _make_messages(8)
    await reducer.reduce()
    await _await_checkpoint(reducer)

    # Find the summary in the back buffer
    summaries = [
        msg for msg in (reducer._back_buffer or [])
        if msg.metadata.get(SUMMARY_METADATA_KEY)
    ]
    assert len(summaries) >= 1
    # The reducer starts at generation 0, so the first checkpoint is tagged generation 1.
    assert summaries[-1].metadata.get(GENERATION_METADATA_KEY) == 1


# --- Phase 2: Concurrent Phase Tests ---


async def test_concurrent_phase_appends_to_both_buffers(mock_service, summary_message):
    """During concurrent phase, new messages go to both active and back buffer."""
    mock_service.get_chat_message_content.return_value = summary_message

    reducer = ChatHistoryDoubleBufferReducer(
        service=mock_service,
        target_count=10,
        threshold_count=0,
        checkpoint_threshold=0.5,
        swap_threshold=0.9,
    )
    reducer.messages = _make_messages(6)
    await reducer.reduce()  # kicks off checkpoint
    await _await_checkpoint(reducer)  # wait for it

    assert reducer.has_back_buffer is True
    back_buffer_size_before = reducer.back_buffer_size
    active_size_before = len(reducer.messages)

    # Add a new message — should go to both
    new_msg = ChatMessageContent(role=AuthorRole.USER, content="New message during concurrent phase")
    await reducer.add_message_async(new_msg)

    assert len(reducer.messages) == active_size_before + 1
    assert reducer.back_buffer_size == back_buffer_size_before + 1


# --- Phase 3: Swap Tests ---


async def test_swap_at_threshold(mock_service, summary_message):
    """Buffer swap occurs when active buffer hits swap threshold."""
    mock_service.get_chat_message_content.return_value = summary_message

    reducer = ChatHistoryDoubleBufferReducer(
        service=mock_service,
        target_count=10,
        threshold_count=0,
        checkpoint_threshold=0.5,  # checkpoint at 5
        swap_threshold=0.9,  # swap at 9
    )
    # Load enough to trigger checkpoint
    reducer.messages = _make_messages(6)
    await reducer.reduce()
    await _await_checkpoint(reducer)
    assert reducer.has_back_buffer is True
    assert reducer.generation == 0

    # Now add messages until we hit swap threshold
    while len(reducer.messages) < 9:
        msg = ChatMessageContent(role=AuthorRole.USER, content="Filler")
        await reducer.add_message_async(msg)

    # This reduce should trigger the swap
    result = await reducer.reduce()
    assert result is not None
    assert reducer.generation == 1
    assert reducer.has_back_buffer is False
    # Back buffer should now be the active buffer — it should be smaller
    # than the pre-swap active buffer since it has compressed history
    # NOTE(review): the assertion below only checks non-emptiness; it does not
    # verify the "smaller than before" claim made in the comment above.
    assert len(reducer.messages) > 0


async def test_swap_increments_generation(mock_service, summary_message):
    """Each swap increments the generation counter."""
    mock_service.get_chat_message_content.return_value = summary_message

    reducer = ChatHistoryDoubleBufferReducer(
        service=mock_service,
        target_count=10,
        threshold_count=0,
        checkpoint_threshold=0.3,
        swap_threshold=0.8,
    )
    reducer.messages = _make_messages(6)

    # First cycle: checkpoint + swap
    await reducer.reduce()  # checkpoint
    assert reducer.generation == 0
    reducer.messages = _make_messages(9)  # force above swap threshold
    reducer._back_buffer = _make_messages(4, with_summary=True)  # simulate back buffer
    await reducer.reduce()  # swap
    assert reducer.generation == 1


# --- Renewal Tests ---


async def test_renewal_dump_policy(mock_service, summary_message):
    """DUMP renewal policy discards all summaries and resets generation."""
    mock_service.get_chat_message_content.return_value = summary_message

    reducer = ChatHistoryDoubleBufferReducer(
        service=mock_service,
        target_count=20,
        threshold_count=0,
        checkpoint_threshold=0.5,
        swap_threshold=0.9,
        max_generations=2,
        renewal_policy=RenewalPolicy.DUMP,
    )

    # Simulate having reached max generations
    reducer._current_generation = 2
    msgs = _make_messages(12, with_summary=True)
    reducer.messages = msgs

    # Trigger checkpoint — should perform renewal first
    await reducer.reduce()
    await _await_checkpoint(reducer)

    # Generation should be reset
    # NOTE(review): the `generation` property is expected to mirror
    # `_current_generation`, which would make the two sides of this `or`
    # equivalent. The test also never checks that summaries were discarded.
    assert reducer._current_generation == 0 or reducer.generation == 0


async def test_renewal_recurse_policy(mock_service, summary_message):
    """RECURSE renewal policy meta-summarizes accumulated summaries."""
    mock_service.get_chat_message_content.return_value = summary_message

    reducer = ChatHistoryDoubleBufferReducer(
        service=mock_service,
        target_count=20,
        threshold_count=0,
        checkpoint_threshold=0.5,
        swap_threshold=0.9,
        max_generations=2,
        renewal_policy=RenewalPolicy.RECURSE,
    )

    # Simulate having reached max generations with multiple summaries
    reducer._current_generation = 2
    summary1 = ChatMessageContent(role=AuthorRole.SYSTEM, content="Summary gen 1.")
    summary1.metadata[SUMMARY_METADATA_KEY] = True
    summary2 = ChatMessageContent(role=AuthorRole.SYSTEM, content="Summary gen 2.")
    summary2.metadata[SUMMARY_METADATA_KEY] = True
    reducer.messages = [summary1, summary2, *_make_messages(12)]

    await reducer.reduce()
    await _await_checkpoint(reducer)

    # Should have called the service for meta-summarization
    assert mock_service.get_chat_message_content.await_count >= 1


# --- Error Handling Tests ---


async def test_checkpoint_failure_with_fail_on_error(mock_service):
    """Checkpoint failure raises when fail_on_error is True."""
    mock_service.get_chat_message_content.side_effect = Exception("LLM error")

    reducer = ChatHistoryDoubleBufferReducer(
        service=mock_service,
        target_count=10,
        threshold_count=0,
        checkpoint_threshold=0.5,
        swap_threshold=0.9,
        fail_on_error=True,
    )
    reducer.messages = _make_messages(8)

    # Kick off checkpoint (runs in background)
    await reducer.reduce()
    # Await the background task — it should raise
    with pytest.raises(ChatHistoryReducerException, match="failed"):
        await _await_checkpoint(reducer)


async def test_checkpoint_failure_without_fail_on_error(mock_service):
    """Checkpoint failure logs warning and leaves no back buffer."""
    mock_service.get_chat_message_content.side_effect = Exception("LLM error")

    reducer = ChatHistoryDoubleBufferReducer(
        service=mock_service,
        target_count=10,
        threshold_count=0,
        checkpoint_threshold=0.5,
        swap_threshold=0.9,
        fail_on_error=False,
    )
    reducer.messages = _make_messages(8)

    await reducer.reduce()
    await _await_checkpoint(reducer)
    assert reducer.has_back_buffer is False


async def test_checkpoint_with_no_summarizable_messages(mock_service):
    """Returns None when there are no messages to summarize."""
    reducer = ChatHistoryDoubleBufferReducer(
        service=mock_service,
        target_count=10,
        threshold_count=10,  # high threshold means no reduction needed
        checkpoint_threshold=0.5,
        swap_threshold=0.9,
    )
    reducer.messages = _make_messages(6)

    await reducer.reduce()
    await _await_checkpoint(reducer)
    assert reducer.has_back_buffer is False
    mock_service.get_chat_message_content.assert_not_awaited()


# --- Auto-reduce Tests ---


async def test_auto_reduce_triggers_checkpoint(mock_service, summary_message):
    """With auto_reduce=True, adding messages can trigger checkpoint."""
    mock_service.get_chat_message_content.return_value = summary_message

    reducer = ChatHistoryDoubleBufferReducer(
        service=mock_service,
        target_count=6,
        threshold_count=0,
        checkpoint_threshold=0.5,  # checkpoint at 3
        swap_threshold=0.9,
        auto_reduce=True,
    )

    # Add messages one by one — checkpoint should trigger automatically
    for i in range(5):
        role = AuthorRole.USER if i % 2 == 0 else AuthorRole.ASSISTANT
        await reducer.add_message_async(ChatMessageContent(role=role, content=f"Msg {i}"))

    # By now checkpoint should have been triggered
    # (exact behavior depends on threshold math, but service should have been called)
    # NOTE(review): this test contains no assertions, so it only verifies that
    # the calls above do not raise.


# --- Graceful Degradation Tests ---


async def test_graceful_degradation_on_summarizer_returning_none(mock_service):
    """If summarizer returns None, checkpoint is aborted gracefully."""
    mock_service.get_chat_message_content.return_value = None

    reducer = ChatHistoryDoubleBufferReducer(
        service=mock_service,
        target_count=10,
        threshold_count=0,
        checkpoint_threshold=0.5,
        swap_threshold=0.9,
    )
    reducer.messages = _make_messages(8)

    await reducer.reduce()
    await _await_checkpoint(reducer)
    assert reducer.has_back_buffer is False


async def test_no_double_checkpoint(mock_service, summary_message):
    """Calling reduce twice doesn't create a second checkpoint if one exists."""
    mock_service.get_chat_message_content.return_value = summary_message

    reducer = ChatHistoryDoubleBufferReducer(
        service=mock_service,
        target_count=10,
        threshold_count=0,
        checkpoint_threshold=0.5,
        swap_threshold=0.9,
    )
    reducer.messages = _make_messages(8)

    # First reduce kicks off checkpoint in background
    await reducer.reduce()
    await _await_checkpoint(reducer)
    assert reducer.has_back_buffer is True
    call_count_after_first = mock_service.get_chat_message_content.await_count

    # Second reduce should not create another checkpoint (back buffer already exists)
    result = await reducer.reduce()
    assert result is None  # no swap needed yet, no new checkpoint needed
    assert mock_service.get_chat_message_content.await_count == call_count_after_first