Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/smolagents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,7 +1176,7 @@ def __init__(
wait_seconds=RETRY_WAIT,
exponential_base=RETRY_EXPONENTIAL_BASE,
jitter=RETRY_JITTER,
retry_predicate=is_rate_limit_error,
retry_predicate=is_transient_error,
reraise=True,
before_sleep_logger=(logger, logging.INFO),
after_logger=(logger, logging.INFO),
Expand All @@ -1191,6 +1191,36 @@ def _apply_rate_limit(self):
self.rate_limiter.throttle()


def is_transient_error(exception: BaseException) -> bool:
"""Return True if the exception is a transient API error worth retrying."""
error_str = str(exception).lower()
return any(
signal in error_str
for signal in (
# Rate limits
"429",
"rate limit",
"too many requests",
"rate_limit",
# Transient server errors
"500",
"502",
"503",
"504",
"service unavailable",
"bad gateway",
"gateway timeout",
"internal server error",
# Connection failures
"connection error",
"connection reset",
"connection refused",
"timed out",
"timeout",
)
)


def is_rate_limit_error(exception: BaseException) -> bool:
"""Check if the exception is a rate limit error."""
error_str = str(exception).lower()
Expand Down
36 changes: 36 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
get_clean_message_list,
get_tool_call_from_text,
get_tool_json_schema,
is_transient_error,
parse_json_if_needed,
remove_content_after_stop_sequences,
supports_stop_parameter,
Expand Down Expand Up @@ -1077,3 +1078,38 @@ def test_tool_calls_json_serialization(model_class, model_id):
assert len(data["tool_calls"]) > 0
assert data["tool_calls"][0]["function"]["name"] == "final_answer"
assert data["tool_calls"][0]["function"]["arguments"] == "test_result"


class TestIsTransientError:
def test_rate_limit_429(self):
assert is_transient_error(Exception("HTTP 429 Too Many Requests"))

def test_rate_limit_string(self):
assert is_transient_error(Exception("rate limit exceeded"))

def test_503_service_unavailable(self):
assert is_transient_error(Exception("503 Service Unavailable"))

def test_502_bad_gateway(self):
assert is_transient_error(Exception("502 Bad Gateway"))

def test_500_internal_server_error(self):
assert is_transient_error(Exception("500 Internal Server Error"))

def test_504_gateway_timeout(self):
assert is_transient_error(Exception("504 Gateway Timeout"))

def test_connection_reset(self):
assert is_transient_error(Exception("Connection reset by peer"))

def test_timeout(self):
assert is_transient_error(Exception("request timed out"))

def test_non_retryable_400(self):
assert not is_transient_error(Exception("400 Bad Request"))

def test_non_retryable_401(self):
assert not is_transient_error(Exception("401 Unauthorized"))

def test_non_retryable_404(self):
assert not is_transient_error(Exception("404 Not Found"))