Skip to content

Commit d5e8012

Browse files
committed
retry schema validation error
1 parent 4c7f3fc commit d5e8012

File tree

2 files changed

+62
-13
lines changed

2 files changed

+62
-13
lines changed

libs/core/kiln_ai/adapters/eval/eval_runner.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -263,21 +263,38 @@ async def run_job(self, job: EvalJob) -> bool:
263263
eval_run.save_to_file()
264264

265265
return True
266-
except (
267-
litellm.RateLimitError,
268-
litellm.APIConnectionError,
269-
litellm.InternalServerError,
270-
litellm.ServiceUnavailableError,
271-
litellm.BadGatewayError,
272-
) as e:
273-
logger.error(
274-
f"Transient error running eval job for dataset item {job.item.id}: {e}",
275-
exc_info=True,
276-
)
277-
raise RetryableError(str(e)) from e
278266
except Exception as e:
267+
if _is_retryable_error(e):
268+
logger.error(
269+
f"Transient error running eval job for dataset item {job.item.id}: {e}",
270+
exc_info=True,
271+
)
272+
raise RetryableError(str(e)) from e
279273
logger.error(
280274
f"Error running eval job for dataset item {job.item.id}: {e}",
281275
exc_info=True,
282276
)
283277
raise
278+
279+
280+
def _is_retryable_error(e: BaseException) -> bool:
    """Return True when *e* is a transient failure that is worth retrying.

    Two categories qualify:
    1. litellm's transient provider/network errors (rate limits, connection
       failures, 5xx responses, and JSON-schema validation failures).
    2. The ValueError raised by Kiln's adapter when the model's structured
       output doesn't match the task's required schema.
    """
    transient_litellm_errors = (
        litellm.RateLimitError,
        litellm.APIConnectionError,
        litellm.InternalServerError,
        litellm.ServiceUnavailableError,
        litellm.BadGatewayError,
        litellm.JSONSchemaValidationError,
    )
    if isinstance(e, transient_litellm_errors):
        return True

    # ValueError thrown by Kiln's adapter when structured output doesn't match schema
    schema_mismatch = isinstance(e, ValueError) and (
        "This task requires a specific output schema" in str(e)
    )
    return schema_mismatch

libs/core/kiln_ai/adapters/eval/test_eval_runner.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from typing import Dict
22
from unittest.mock import AsyncMock, patch
33

4+
import litellm
45
import pytest
56

67
from kiln_ai.adapters.eval.base_eval import BaseEval
7-
from kiln_ai.adapters.eval.eval_runner import EvalJob, EvalRunner
8+
from kiln_ai.adapters.eval.eval_runner import EvalJob, EvalRunner, _is_retryable_error
89
from kiln_ai.adapters.ml_model_list import ModelProviderName
910
from kiln_ai.datamodel import (
1011
DataSource,
@@ -831,3 +832,34 @@ async def run_task_and_eval(self, eval_job_item: TaskRun):
831832
# For full_trace evals, None trace should fail and not save a run
832833
eval_runs = mock_eval_config.runs()
833834
assert len(eval_runs) == 0
835+
836+
837+
@pytest.mark.parametrize(
    "error",
    [
        # Transient litellm provider errors — all should trigger a retry.
        litellm.RateLimitError("rate limited", "provider", "model", None),
        litellm.APIConnectionError("connection failed", "provider", "model", None),
        litellm.InternalServerError("server error", "provider", "model", None),
        litellm.ServiceUnavailableError("unavailable", "provider", "model", None),
        litellm.BadGatewayError("bad gateway", "provider", "model", None),
        litellm.JSONSchemaValidationError("schema error", "provider", "model", None),
        # Kiln adapter's schema-mismatch ValueError is also retryable.
        ValueError(
            "This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema."
        ),
    ],
)
def test_is_retryable_error_returns_true(error):
    """Each transient error class and the schema-mismatch ValueError retries."""
    result = _is_retryable_error(error)
    assert result is True
853+
854+
855+
@pytest.mark.parametrize(
    "error",
    [
        # A ValueError without the schema-mismatch message is NOT retryable.
        ValueError("some other value error"),
        RuntimeError("runtime error"),
        KeyError("missing key"),
        TypeError("type error"),
    ],
)
def test_is_retryable_error_returns_false(error):
    """Ordinary, non-transient exceptions must not be classified as retryable."""
    result = _is_retryable_error(error)
    assert result is False

0 commit comments

Comments
 (0)