Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def __init__(
endpoint_url=endpoint_url,
region_name=resolved_region,
)
self._boto_session = session

logger.debug("region=<%s> | bedrock client created", self.client.meta.region_name)

Expand All @@ -208,6 +209,39 @@ def _cache_strategy(self) -> str | None:
model_id = self.config.get("model_id", "").lower()
if "claude" in model_id or "anthropic" in model_id:
return "anthropic"
# ARN-based application inference profiles use opaque IDs that don't contain the model name.
# Resolve the underlying foundation model via GetInferenceProfile, then re-check.
if "application-inference-profile" in model_id:
if not hasattr(self, "_resolved_application_profile_strategy"):
self._resolved_application_profile_strategy = self._resolve_application_inference_profile_strategy()
return self._resolved_application_profile_strategy
return None

def _resolve_application_inference_profile_strategy(self) -> str | None:
"""Resolve the cache strategy for an ARN-based application inference profile.

Calls GetInferenceProfile on the Bedrock management API (requires the
``bedrock:GetInferenceProfile`` IAM permission) to discover the underlying
foundation model, then checks whether that model supports prompt caching.
Returns None and logs a debug message on any error, including missing permissions.
"""
try:
# GetInferenceProfile is a Bedrock management API, not available on bedrock-runtime.
bedrock_client = self._boto_session.client(
service_name="bedrock",
region_name=self.client.meta.region_name,
)
response = bedrock_client.get_inference_profile(inferenceProfileIdentifier=self.config["model_id"])
for model_ref in response.get("models", []):
model_arn = model_ref.get("modelArn", "").lower()
if "claude" in model_arn or "anthropic" in model_arn:
return "anthropic"
except Exception:
logger.debug(
"model_id=<%s> | could not resolve application inference profile for cache strategy detection; "
"use CacheConfig(strategy='anthropic') to force-enable caching",
self.config.get("model_id"),
)
return None

@override
Expand All @@ -218,6 +252,8 @@ def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type:
**model_config: Configuration overrides.
"""
validate_config_keys(model_config, self.BedrockConfig)
if "model_id" in model_config:
self.__dict__.pop("_resolved_application_profile_strategy", None)
self.config.update(model_config)

@override
Expand Down
2 changes: 1 addition & 1 deletion src/strands/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

T = TypeVar("T", bound=BaseModel)


def _heuristic_estimate_text(text: str) -> int:
"""Estimate token count from text using characters / 4 heuristic."""
return math.ceil(len(text) / 4)
Expand Down Expand Up @@ -84,7 +85,6 @@ def _count_content_block_tokens(
return total



def _estimate_tokens_with_heuristic(
messages: Messages,
tool_specs: list[ToolSpec] | None = None,
Expand Down
96 changes: 96 additions & 0 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2929,6 +2929,102 @@ def test_cache_strategy_none_for_non_claude(bedrock_client):
assert model._cache_strategy is None


def test_cache_strategy_application_inference_profile_claude(mock_client_method, bedrock_client):
    """A Claude-backed application inference profile ARN resolves to the 'anthropic' strategy."""
    arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/abc123"
    claude_model_arn = "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20241022-v2:0"
    bedrock_client.get_inference_profile.return_value = {"models": [{"modelArn": claude_model_arn}]}

    model = BedrockModel(model_id=arn)

    assert model._cache_strategy == "anthropic"
    bedrock_client.get_inference_profile.assert_called_once_with(inferenceProfileIdentifier=arn)

    # The lookup must go through the "bedrock" management client, not "bedrock-runtime".
    mgmt_call_count = sum(
        1 for call in mock_client_method.call_args_list if call.kwargs.get("service_name") == "bedrock"
    )
    assert mgmt_call_count == 1


def test_cache_strategy_application_inference_profile_non_claude(bedrock_client):
    """An application inference profile backed by a non-Claude model yields no cache strategy."""
    arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/abc123"
    nova_model_arn = "arn:aws:bedrock:us-east-1::foundation-model/amazon.nova-pro-v1:0"
    bedrock_client.get_inference_profile.return_value = {"models": [{"modelArn": nova_model_arn}]}

    assert BedrockModel(model_id=arn)._cache_strategy is None


def test_cache_strategy_application_inference_profile_api_error(bedrock_client):
    """A failing GetInferenceProfile call degrades gracefully to None instead of raising."""
    arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/abc123"
    bedrock_client.get_inference_profile.side_effect = Exception("AccessDeniedException")

    model = BedrockModel(model_id=arn)

    assert model._cache_strategy is None


def test_cache_strategy_application_inference_profile_cached(bedrock_client):
    """Profile resolution happens once per model instance, not on every _cache_strategy access."""
    arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/abc123"
    claude_model_arn = "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20241022-v2:0"
    bedrock_client.get_inference_profile.return_value = {"models": [{"modelArn": claude_model_arn}]}

    model = BedrockModel(model_id=arn)

    # Repeated accesses must reuse the cached resolution.
    for _ in range(3):
        _ = model._cache_strategy

    assert bedrock_client.get_inference_profile.call_count == 1


def test_cache_strategy_application_inference_profile_invalidated_on_model_id_change(bedrock_client):
    """update_config(model_id=...) discards the cached resolution and triggers a fresh lookup."""
    first_arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/abc123"
    claude_model_arn = "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20241022-v2:0"
    bedrock_client.get_inference_profile.return_value = {"models": [{"modelArn": claude_model_arn}]}

    model = BedrockModel(model_id=first_arn)
    assert model._cache_strategy == "anthropic"
    assert bedrock_client.get_inference_profile.call_count == 1

    # Point the model at a different profile backed by a non-Claude model.
    second_arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/xyz999"
    nova_model_arn = "arn:aws:bedrock:us-east-1::foundation-model/amazon.nova-pro-v1:0"
    bedrock_client.get_inference_profile.return_value = {"models": [{"modelArn": nova_model_arn}]}
    model.update_config(model_id=second_arn)

    assert model._cache_strategy is None
    assert bedrock_client.get_inference_profile.call_count == 2


def test_cache_strategy_application_inference_profile_auto_injects_cache_point(bedrock_client):
    """End-to-end: with strategy='auto', a Claude-backed profile gets a cachePoint in the request."""
    arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/abc123"
    claude_model_arn = "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20241022-v2:0"
    bedrock_client.get_inference_profile.return_value = {"models": [{"modelArn": claude_model_arn}]}

    model = BedrockModel(model_id=arn, cache_config=CacheConfig(strategy="auto"))
    request = model._format_request([{"role": "user", "content": [{"text": "hello"}]}])

    final_message_content = request["messages"][-1]["content"]
    assert any("cachePoint" in content_block for content_block in final_message_content)


def test_inject_cache_point_adds_to_last_user(bedrock_client):
"""Test that _inject_cache_point adds cache point to last user message."""
model = BedrockModel(
Expand Down
1 change: 0 additions & 1 deletion tests/strands/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,6 @@ async def test_count_tokens_all_inputs(model):
assert result == 50



class TestHeuristicEstimation:
"""Tests for _estimate_tokens_with_heuristic."""

Expand Down