Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -303,13 +303,18 @@ def _create_chat_message_content(
if part.text:
items.append(TextContent(text=part.text, inner_content=response, metadata=response_metadata))
elif part.function_call:
fc_metadata: dict[str, Any] = {}
thought_sig = getattr(part, "thought_signature", None)
if thought_sig:
fc_metadata["thought_signature"] = thought_sig
items.append(
FunctionCallContent(
id=f"{part.function_call.name}_{idx!s}",
name=format_gemini_function_name_to_kernel_function_fully_qualified_name(
part.function_call.name # type: ignore[arg-type]
),
arguments={k: v for k, v in part.function_call.args.items()}, # type: ignore
metadata=fc_metadata if fc_metadata else None,
)
)

Expand Down Expand Up @@ -360,13 +365,18 @@ def _create_streaming_chat_message_content(
)
)
elif part.function_call:
fc_metadata: dict[str, Any] = {}
thought_sig = getattr(part, "thought_signature", None)
if thought_sig:
fc_metadata["thought_signature"] = thought_sig
items.append(
FunctionCallContent(
id=f"{part.function_call.name}_{idx!s}",
name=format_gemini_function_name_to_kernel_function_fully_qualified_name(
part.function_call.name # type: ignore[arg-type]
),
arguments={k: v for k, v in part.function_call.args.items()}, # type: ignore
metadata=fc_metadata if fc_metadata else None,
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,24 @@ def format_assistant_message(message: ChatMessageContent) -> list[Part]:
if item.text:
parts.append(Part.from_text(text=item.text))
elif isinstance(item, FunctionCallContent):
parts.append(
Part.from_function_call(
name=item.name, # type: ignore[arg-type]
args=json.loads(item.arguments) if isinstance(item.arguments, str) else item.arguments, # type: ignore[arg-type]
thought_signature = item.metadata.get("thought_signature") if item.metadata else None
if thought_signature:
parts.append(
Part(
function_call={
"name": item.name, # type: ignore[arg-type]
"args": json.loads(item.arguments) if isinstance(item.arguments, str) else item.arguments,
},
thought_signature=thought_signature,
)
)
else:
parts.append(
Part.from_function_call(
name=item.name, # type: ignore[arg-type]
args=json.loads(item.arguments) if isinstance(item.arguments, str) else item.arguments, # type: ignore[arg-type]
)
)
)
elif isinstance(item, ImageContent):
parts.append(_create_image_part(item))
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from google.cloud.aiplatform_v1beta1.types.content import Candidate
from vertexai.generative_models import FunctionDeclaration, Part, Tool, ToolConfig
Expand Down Expand Up @@ -89,14 +89,16 @@ def format_assistant_message(message: ChatMessageContent) -> list[Part]:
if item.text:
parts.append(Part.from_text(item.text))
elif isinstance(item, FunctionCallContent):
parts.append(
Part.from_dict({
"function_call": {
"name": item.name,
"args": json.loads(item.arguments) if isinstance(item.arguments, str) else item.arguments,
}
})
)
part_dict: dict[str, Any] = {
"function_call": {
"name": item.name, # type: ignore[arg-type]
"args": json.loads(item.arguments) if isinstance(item.arguments, str) else item.arguments,
}
}
thought_signature = item.metadata.get("thought_signature") if item.metadata else None
if thought_signature:
part_dict["thought_signature"] = thought_signature
parts.append(Part.from_dict(part_dict))
elif isinstance(item, ImageContent):
parts.append(_create_image_part(item))
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,18 @@ def _create_chat_message_content(self, response: GenerationResponse, candidate:
if "text" in part_dict:
items.append(TextContent(text=part.text, inner_content=response, metadata=response_metadata))
elif "function_call" in part_dict:
fc_metadata: dict[str, Any] = {}
thought_sig = part_dict.get("thought_signature")
if thought_sig:
fc_metadata["thought_signature"] = thought_sig
items.append(
FunctionCallContent(
id=f"{part.function_call.name}_{idx!s}",
name=format_gemini_function_name_to_kernel_function_fully_qualified_name(
part.function_call.name
),
arguments={k: v for k, v in part.function_call.args.items()},
metadata=fc_metadata if fc_metadata else None,
)
)

Expand Down Expand Up @@ -309,13 +314,18 @@ def _create_streaming_chat_message_content(
)
)
elif "function_call" in part_dict:
fc_metadata_s: dict[str, Any] = {}
thought_sig_s = part_dict.get("thought_signature")
if thought_sig_s:
fc_metadata_s["thought_signature"] = thought_sig_s
items.append(
FunctionCallContent(
id=f"{part.function_call.name}_{idx!s}",
name=format_gemini_function_name_to_kernel_function_fully_qualified_name(
part.function_call.name
),
arguments={k: v for k, v in part.function_call.args.items()},
metadata=fc_metadata_s if fc_metadata_s else None,
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,197 @@ def test_google_ai_chat_completion_parse_chat_history_correctly(google_ai_unit_t
assert parsed_chat_history[0].parts[0].text == "test_user_message"
assert parsed_chat_history[1].role == "model"
assert parsed_chat_history[1].parts[0].text == "test_assistant_message"


# region deserialization (Part → FunctionCallContent round-trip)


def test_create_chat_message_content_with_thought_signature(google_ai_unit_test_env) -> None:
    """A Part carrying a thought_signature is deserialized into FunctionCallContent.metadata."""
    from google.genai.types import (
        Candidate,
        Content,
        GenerateContentResponse,
        GenerateContentResponseUsageMetadata,
        Part,
    )
    from google.genai.types import FinishReason as GFinishReason

    from semantic_kernel.contents.function_call_content import FunctionCallContent

    signature = b"test-thought-sig-bytes"
    function_part = Part.from_function_call(name="test_function", args={"key": "value"})
    function_part.thought_signature = signature

    candidate = Candidate()
    candidate.index = 0
    candidate.content = Content(role="user", parts=[function_part])
    candidate.finish_reason = GFinishReason.STOP

    response = GenerateContentResponse()
    response.candidates = [candidate]
    response.usage_metadata = GenerateContentResponseUsageMetadata(
        prompt_token_count=0, cached_content_token_count=0, candidates_token_count=0, total_token_count=0
    )

    result = GoogleAIChatCompletion()._create_chat_message_content(response, candidate)

    function_calls = [item for item in result.items if isinstance(item, FunctionCallContent)]
    assert len(function_calls) == 1
    # The signature must survive the round-trip into the content's metadata.
    assert function_calls[0].metadata is not None
    assert function_calls[0].metadata["thought_signature"] == signature


def test_create_chat_message_content_without_thought_signature(google_ai_unit_test_env) -> None:
    """FunctionCallContent deserialization still works when the Part has no thought_signature."""
    from google.genai.types import (
        Candidate,
        Content,
        GenerateContentResponse,
        GenerateContentResponseUsageMetadata,
        Part,
    )
    from google.genai.types import FinishReason as GFinishReason

    from semantic_kernel.contents.function_call_content import FunctionCallContent

    function_part = Part.from_function_call(name="test_function", args={"key": "value"})

    candidate = Candidate()
    candidate.index = 0
    candidate.content = Content(role="user", parts=[function_part])
    candidate.finish_reason = GFinishReason.STOP

    response = GenerateContentResponse()
    response.candidates = [candidate]
    response.usage_metadata = GenerateContentResponseUsageMetadata(
        prompt_token_count=0, cached_content_token_count=0, candidates_token_count=0, total_token_count=0
    )

    result = GoogleAIChatCompletion()._create_chat_message_content(response, candidate)

    function_calls = [item for item in result.items if isinstance(item, FunctionCallContent)]
    assert len(function_calls) == 1
    # No signature on the Part means none should appear in metadata.
    assert "thought_signature" not in function_calls[0].metadata


def test_create_streaming_chat_message_content_with_thought_signature(google_ai_unit_test_env) -> None:
    """The streaming path also carries thought_signature into FunctionCallContent.metadata."""
    from google.genai.types import (
        Candidate,
        Content,
        GenerateContentResponse,
        GenerateContentResponseUsageMetadata,
        Part,
    )
    from google.genai.types import FinishReason as GFinishReason

    from semantic_kernel.contents.function_call_content import FunctionCallContent

    signature = b"streaming-thought-sig"
    function_part = Part.from_function_call(name="stream_func", args={"a": "b"})
    function_part.thought_signature = signature

    candidate = Candidate()
    candidate.index = 0
    candidate.content = Content(role="user", parts=[function_part])
    candidate.finish_reason = GFinishReason.STOP

    chunk = GenerateContentResponse()
    chunk.candidates = [candidate]
    chunk.usage_metadata = GenerateContentResponseUsageMetadata(
        prompt_token_count=0, cached_content_token_count=0, candidates_token_count=0, total_token_count=0
    )

    result = GoogleAIChatCompletion()._create_streaming_chat_message_content(chunk, candidate)

    function_calls = [item for item in result.items if isinstance(item, FunctionCallContent)]
    assert len(function_calls) == 1
    # The streaming deserializer must preserve the signature just like the non-streaming one.
    assert function_calls[0].metadata is not None
    assert function_calls[0].metadata["thought_signature"] == signature


def test_create_streaming_chat_message_content_without_thought_signature(google_ai_unit_test_env) -> None:
    """Streaming FunctionCallContent deserialization works when the Part lacks thought_signature."""
    from google.genai.types import (
        Candidate,
        Content,
        GenerateContentResponse,
        GenerateContentResponseUsageMetadata,
        Part,
    )
    from google.genai.types import FinishReason as GFinishReason

    from semantic_kernel.contents.function_call_content import FunctionCallContent

    function_part = Part.from_function_call(name="stream_func", args={"a": "b"})

    candidate = Candidate()
    candidate.index = 0
    candidate.content = Content(role="user", parts=[function_part])
    candidate.finish_reason = GFinishReason.STOP

    chunk = GenerateContentResponse()
    chunk.candidates = [candidate]
    chunk.usage_metadata = GenerateContentResponseUsageMetadata(
        prompt_token_count=0, cached_content_token_count=0, candidates_token_count=0, total_token_count=0
    )

    result = GoogleAIChatCompletion()._create_streaming_chat_message_content(chunk, candidate)

    function_calls = [item for item in result.items if isinstance(item, FunctionCallContent)]
    assert len(function_calls) == 1
    # No signature on the Part means none should appear in metadata.
    assert "thought_signature" not in function_calls[0].metadata


def test_create_chat_message_content_getattr_guard_on_missing_attribute(google_ai_unit_test_env) -> None:
    """The getattr guard tolerates SDK versions whose Part has no thought_signature attribute."""
    from unittest.mock import MagicMock

    from google.genai.types import (
        GenerateContentResponse,
        GenerateContentResponseUsageMetadata,
    )

    from semantic_kernel.contents.function_call_content import FunctionCallContent

    # Mock Part that simulates an older SDK: deleting the attribute makes the
    # MagicMock raise AttributeError on access, so getattr must fall back to None.
    fake_part = MagicMock()
    fake_part.text = None
    fake_part.function_call.name = "test_func"
    fake_part.function_call.args = {"x": "y"}
    del fake_part.thought_signature

    # Fully-mocked candidate sidesteps Content pydantic validation.
    fake_candidate = MagicMock()
    fake_candidate.index = 0
    fake_candidate.content.parts = [fake_part]
    fake_candidate.finish_reason = 1  # STOP

    response = GenerateContentResponse()
    response.candidates = [fake_candidate]
    response.usage_metadata = GenerateContentResponseUsageMetadata(
        prompt_token_count=0, cached_content_token_count=0, candidates_token_count=0, total_token_count=0
    )

    result = GoogleAIChatCompletion()._create_chat_message_content(response, fake_candidate)

    function_calls = [item for item in result.items if isinstance(item, FunctionCallContent)]
    assert len(function_calls) == 1
    # A missing attribute must not leak a signature entry into metadata.
    assert "thought_signature" not in function_calls[0].metadata


# endregion deserialization
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,47 @@ def test_format_assistant_message_with_unsupported_items() -> None:

with pytest.raises(ServiceInvalidRequestError):
format_assistant_message(assistant_message)


def test_format_assistant_message_with_thought_signature() -> None:
    """A thought_signature stored in FunctionCallContent.metadata is serialized onto the Part."""
    import base64

    signature = base64.b64encode(b"test_thought_signature_data")
    assistant_message = ChatMessageContent(
        role=AuthorRole.ASSISTANT,
        items=[
            FunctionCallContent(
                name="test_function",
                arguments={"arg1": "value1"},
                metadata={"thought_signature": signature},
            ),
        ],
    )

    parts = format_assistant_message(assistant_message)

    assert len(parts) == 1
    assert isinstance(parts[0], Part)
    # Function call payload is unchanged and the signature rides along on the Part.
    assert parts[0].function_call.name == "test_function"
    assert parts[0].function_call.args == {"arg1": "value1"}
    assert parts[0].thought_signature == signature


def test_format_assistant_message_without_thought_signature() -> None:
    """Function calls without a thought_signature still serialize to a plain Part."""
    assistant_message = ChatMessageContent(
        role=AuthorRole.ASSISTANT,
        items=[FunctionCallContent(name="test_function", arguments={"arg1": "value1"})],
    )

    parts = format_assistant_message(assistant_message)

    assert len(parts) == 1
    assert isinstance(parts[0], Part)
    assert parts[0].function_call.name == "test_function"
    assert parts[0].function_call.args == {"arg1": "value1"}
    # getattr fallback covers SDK versions where the attribute may not exist at all.
    assert not getattr(parts[0], "thought_signature", None)
Loading
Loading