@staticmethod
def _set_litellm_retry_params(completion_kwargs: dict[str, object]) -> None:
    """Route the configured retry budget through litellm's wrapper-level retry.

    litellm's retry mechanism (``completion_with_retries``) only activates
    when ``num_retries`` is set. Our adapters pass ``max_retries`` (from the
    user's UI config), which only works for SDK-based providers (OpenAI,
    Azure). Copying ``max_retries`` into ``num_retries`` gives httpx-based
    providers (Anthropic, Vertex, Bedrock, Mistral, etc.) retries as well.

    SDK-based providers default to ``max_retries=2`` internally, which would
    multiply with the wrapper retries, so ``max_retries`` is reset to ``0``
    to make every retry go through the wrapper uniformly.

    An explicit, truthy ``num_retries`` or ``retry_strategy`` already present
    in ``completion_kwargs`` is kept rather than overwritten, so a per-call
    override is not silently discarded.

    Args:
        completion_kwargs: Keyword arguments about to be handed to
            ``litellm.completion`` / ``litellm.acompletion``. Mutated in
            place; left untouched when ``max_retries`` is missing, ``None``
            or ``0``.
    """
    max_retries = completion_kwargs.get("max_retries")
    if not max_retries:
        # No retry budget configured: nothing to bridge.
        return

    # Fall back to the configured budget only when the caller did not give
    # litellm's own retry count explicitly.
    completion_kwargs["num_retries"] = (
        completion_kwargs.get("num_retries") or max_retries
    )
    # Disable SDK-internal retries so they do not multiply with the wrapper's.
    completion_kwargs["max_retries"] = 0
    # Default to exponential backoff unless a strategy was already chosen.
    completion_kwargs["retry_strategy"] = (
        completion_kwargs.get("retry_strategy") or "exponential_backoff_retry"
    )