Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions unstract/sdk1/src/unstract/sdk1/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def complete(self, prompt: str, **kwargs: object) -> dict[str, object]:

completion_kwargs = self.adapter.validate({**self.kwargs, **kwargs})
completion_kwargs.pop("cost_model", None)
self._set_litellm_retry_params(completion_kwargs)

# if hasattr(self, "model") and self.model not in O1_MODELS:
# completion_kwargs["temperature"] = 0.003
Expand Down Expand Up @@ -295,6 +296,7 @@ def stream_complete(

completion_kwargs = self.adapter.validate({**self.kwargs, **kwargs})
completion_kwargs.pop("cost_model", None)
self._set_litellm_retry_params(completion_kwargs)

for chunk in litellm.completion(
messages=messages,
Expand Down Expand Up @@ -363,6 +365,7 @@ async def acomplete(self, prompt: str, **kwargs: object) -> dict[str, object]:

completion_kwargs = self.adapter.validate({**self.kwargs, **kwargs})
completion_kwargs.pop("cost_model", None)
self._set_litellm_retry_params(completion_kwargs)

response = await litellm.acompletion(
messages=messages,
Expand Down Expand Up @@ -454,6 +457,26 @@ def get_metrics(self) -> dict[str, object]:
def get_usage_reason(self) -> object:
return self.platform_kwargs.get("llm_usage_reason")

@staticmethod
def _set_litellm_retry_params(completion_kwargs: dict) -> None:
"""Activate litellm's wrapper-level retry for all providers.

litellm's retry mechanism (completion_with_retries) only activates when
num_retries is set. Our adapters pass max_retries (from user UI config)
which only works for SDK-based providers (OpenAI, Azure). This bridges
the gap by copying max_retries into num_retries so httpx-based providers
(Anthropic, Vertex, Bedrock, Mistral, etc.) also get retries.

SDK-based providers (OpenAI, Azure) default to max_retries=2 internally,
which would multiply with wrapper retries. Setting max_retries=0 ensures
all retries go through the wrapper uniformly.
"""
max_retries = completion_kwargs.get("max_retries")
if max_retries:
completion_kwargs["num_retries"] = max_retries
completion_kwargs["max_retries"] = 0
completion_kwargs["retry_strategy"] = "exponential_backoff_retry"
Comment on lines +474 to +478
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Handle zero and invalid retry counts explicitly.

At Line 474, using a truthy check skips an explicit max_retries=0 and allows negative values through. Please branch on is not None and validate bounds before copying.

Suggested patch
     max_retries = completion_kwargs.get("max_retries")
-    if max_retries:
-        completion_kwargs["num_retries"] = max_retries
-        completion_kwargs["retry_strategy"] = "exponential_backoff_retry"
+    if max_retries is None:
+        return
+    if not isinstance(max_retries, int) or max_retries < 0:
+        raise SdkError("Invalid max_retries: expected a non-negative integer")
+    completion_kwargs["num_retries"] = max_retries
+    if max_retries > 0:
+        completion_kwargs["retry_strategy"] = "exponential_backoff_retry"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@unstract/sdk1/src/unstract/sdk1/llm.py` around lines 473 - 476, The current
truthy check for max_retries skips explicit zero and permits negatives; change
the branch to test "max_retries is not None", validate that max_retries is an
integer and >= 0 (or raise a ValueError for invalid values), then set
completion_kwargs["num_retries"] = max_retries and
completion_kwargs["retry_strategy"] = "exponential_backoff_retry"; apply this
logic around the max_retries handling (completion_kwargs, max_retries,
num_retries, retry_strategy) so zero is honored and negatives/non-integers are
rejected.


def _record_usage(
self,
model: str,
Expand Down
Loading