Skip to content

Commit 10e5d16

Browse files
Merge pull request #1200 from Kiln-AI/dchiang/KIL-478/retry-eval-fail
Add retry to AsyncJobRunner
2 parents d8440c8 + 6a1a7fd commit 10e5d16

File tree

4 files changed

+375
-18
lines changed

4 files changed

+375
-18
lines changed

libs/core/kiln_ai/adapters/eval/eval_runner.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from dataclasses import dataclass
44
from typing import AsyncGenerator, Dict, List, Literal, Set
55

6+
import litellm
7+
68
from kiln_ai.adapters.adapter_registry import load_skills_for_task
79
from kiln_ai.adapters.eval.base_eval import BaseEval
810
from kiln_ai.adapters.eval.registry import eval_adapter_from_type
@@ -12,7 +14,7 @@
1214
from kiln_ai.datamodel.eval import EvalConfig, EvalDataType, EvalRun, EvalScores
1315
from kiln_ai.datamodel.task import TaskRunConfig
1416
from kiln_ai.datamodel.task_run import TaskRun, Usage
15-
from kiln_ai.utils.async_job_runner import AsyncJobRunner, Progress
17+
from kiln_ai.utils.async_job_runner import AsyncJobRunner, Progress, RetryableError
1618

1719
logger = logging.getLogger(__name__)
1820

@@ -188,6 +190,7 @@ async def run(self, concurrency: int = 25) -> AsyncGenerator[Progress, None]:
188190
concurrency=concurrency,
189191
jobs=jobs,
190192
run_job_fn=self.run_job,
193+
max_retries=2,
191194
)
192195
async for progress in runner.run():
193196
yield progress
@@ -261,8 +264,37 @@ async def run_job(self, job: EvalJob) -> bool:
261264

262265
return True
263266
except Exception as e:
267+
if _is_retryable_error(e):
268+
logger.error(
269+
f"Transient error running eval job for dataset item {job.item.id}: {e}",
270+
exc_info=True,
271+
)
272+
raise RetryableError(str(e)) from e
264273
logger.error(
265274
f"Error running eval job for dataset item {job.item.id}: {e}",
266275
exc_info=True,
267276
)
268-
return False
277+
raise
278+
279+
280+
def _is_retryable_error(e: BaseException) -> bool:
    """Decide whether an eval-job exception represents a transient failure.

    Returns True for litellm provider errors that are typically transient
    (rate limits, connectivity problems, 5xx-style responses, JSON-schema
    validation failures), and for the specific ValueError Kiln's adapter
    raises when structured output doesn't match the required schema.
    All other exceptions are treated as permanent failures.
    """
    transient_litellm_errors = (
        litellm.RateLimitError,
        litellm.APIConnectionError,
        litellm.InternalServerError,
        litellm.ServiceUnavailableError,
        litellm.BadGatewayError,
        litellm.JSONSchemaValidationError,
    )
    if isinstance(e, transient_litellm_errors):
        return True

    # ValueError thrown by Kiln's adapter when structured output doesn't match schema
    is_schema_mismatch = isinstance(e, ValueError) and (
        "This task requires a specific output schema" in str(e)
    )
    return is_schema_mismatch

libs/core/kiln_ai/adapters/eval/test_eval_runner.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from typing import Dict
22
from unittest.mock import AsyncMock, patch
33

4+
import litellm
45
import pytest
56

67
from kiln_ai.adapters.eval.base_eval import BaseEval
7-
from kiln_ai.adapters.eval.eval_runner import EvalJob, EvalRunner
8+
from kiln_ai.adapters.eval.eval_runner import EvalJob, EvalRunner, _is_retryable_error
89
from kiln_ai.adapters.ml_model_list import ModelProviderName
910
from kiln_ai.datamodel import (
1011
DataSource,
@@ -608,9 +609,9 @@ async def test_run_job_invalid_evaluator(
608609
"kiln_ai.adapters.eval.eval_runner.eval_adapter_from_type",
609610
return_value=lambda *args, **kwargs: object(),
610611
):
611-
success = await mock_eval_runner.run_job(job)
612+
with pytest.raises(ValueError):
613+
await mock_eval_runner.run_job(job)
612614

613-
assert success is False
614615
assert len(mock_eval_config.runs()) == 0
615616

616617

@@ -640,9 +641,9 @@ async def run_task_and_eval(self, eval_job_item: TaskRun):
640641
"kiln_ai.adapters.eval.eval_runner.eval_adapter_from_type",
641642
return_value=lambda *args, **kwargs: ErrorEvaluator(*args, **kwargs),
642643
):
643-
success = await mock_eval_runner.run_job(job)
644+
with pytest.raises(ValueError):
645+
await mock_eval_runner.run_job(job)
644646

645-
assert success is False
646647
assert len(mock_eval_config.runs()) == 0
647648

648649

@@ -825,9 +826,40 @@ async def run_task_and_eval(self, eval_job_item: TaskRun):
825826
"kiln_ai.adapters.eval.eval_runner.eval_adapter_from_type",
826827
return_value=lambda *args, **kwargs: MockEvaluator(*args, **kwargs),
827828
):
828-
success = await mock_eval_runner.run_job(job)
829+
with pytest.raises(ValueError):
830+
await mock_eval_runner.run_job(job)
829831

830832
# For full_trace evals, None trace should fail and not save a run
831-
assert success is False
832833
eval_runs = mock_eval_config.runs()
833834
assert len(eval_runs) == 0
835+
836+
837+
# Transient provider errors (and the adapter's schema-mismatch ValueError)
# that _is_retryable_error must classify as retryable.
_RETRYABLE_ERROR_CASES = [
    litellm.RateLimitError("rate limited", "provider", "model", None),
    litellm.APIConnectionError("connection failed", "provider", "model", None),
    litellm.InternalServerError("server error", "provider", "model", None),
    litellm.ServiceUnavailableError("unavailable", "provider", "model", None),
    litellm.BadGatewayError("bad gateway", "provider", "model", None),
    litellm.JSONSchemaValidationError("schema error", "provider", "model", None),
    ValueError(
        "This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema."
    ),
]


@pytest.mark.parametrize("error", _RETRYABLE_ERROR_CASES)
def test_is_retryable_error_returns_true(error):
    """Every transient error case is classified as retryable."""
    assert _is_retryable_error(error) is True
853+
854+
855+
# Ordinary exceptions that must NOT trigger a retry.
_NON_RETRYABLE_ERROR_CASES = [
    ValueError("some other value error"),
    RuntimeError("runtime error"),
    KeyError("missing key"),
    TypeError("type error"),
]


@pytest.mark.parametrize("error", _NON_RETRYABLE_ERROR_CASES)
def test_is_retryable_error_returns_false(error):
    """Non-transient errors are classified as permanent (no retry)."""
    assert _is_retryable_error(error) is False

libs/core/kiln_ai/utils/async_job_runner.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ class Progress:
1515
errors: int
1616

1717

18+
class RetryableError(Exception):
    """Raised from run_job_fn to mark a transient failure the runner should retry."""
22+
23+
1824
class AsyncJobRunnerObserver(Generic[T]):
1925
async def on_error(self, job: T, error: Exception):
2026
"""
@@ -42,10 +48,18 @@ def __init__(
4248
run_job_fn: Callable[[T], Awaitable[bool]],
4349
concurrency: int = 1,
4450
observers: List[AsyncJobRunnerObserver[T]] | None = None,
51+
max_retries: int = 0,
52+
retry_delay: float = 1.0, # in seconds
4553
):
4654
if concurrency < 1:
4755
raise ValueError("concurrency must be ≥ 1")
56+
if max_retries < 0:
57+
raise ValueError("max_retries must be >= 0")
58+
if retry_delay < 0:
59+
raise ValueError("retry_delay must be >= 0")
4860
self.concurrency = concurrency
61+
self.max_retries = max_retries
62+
self.retry_delay = retry_delay
4963
self.jobs = jobs
5064
self.run_job_fn = run_job_fn
5165
self.observers = observers or []
@@ -132,15 +146,32 @@ async def _run_worker(
132146
# worker can end when the queue is empty
133147
break
134148

135-
try:
136-
await self.notify_job_start(job)
137-
result = await run_job_fn(job)
138-
if result:
139-
await self.notify_success(job)
140-
except Exception as e:
141-
logger.error("Job failed to complete", exc_info=True)
142-
await self.notify_error(job, e)
143-
result = False
149+
await self.notify_job_start(job)
150+
result = False
151+
last_error: Exception | None = None
152+
for attempt in range(1 + self.max_retries):
153+
is_last_attempt = attempt == self.max_retries
154+
try:
155+
result = await run_job_fn(job)
156+
last_error = None
157+
break
158+
except RetryableError as e:
159+
result = False
160+
last_error = e
161+
if is_last_attempt:
162+
logger.error("Job failed to complete", exc_info=e)
163+
break
164+
await asyncio.sleep(self.retry_delay)
165+
except Exception as e:
166+
result = False
167+
last_error = e
168+
logger.error("Job failed to complete", exc_info=e)
169+
break
170+
171+
if result:
172+
await self.notify_success(job)
173+
elif last_error is not None:
174+
await self.notify_error(job, last_error)
144175

145176
try:
146177
await status_queue.put(result)

0 commit comments

Comments
 (0)