Skip to content

Commit 172484a

Browse files
fix: refactor duplicated guardrail and callback logic (fixes #1961)
Extract identical guardrail validation (~45 lines) and task callback logic (~13 lines) that was duplicated between sync run_task() and async arun_task() into shared helper methods: - _apply_task_guardrail(): Consolidates guardrail validation and retry logic - _run_task_callback(): Consolidates task callback execution (sync/async) This eliminates 58 lines of character-identical duplication and prevents future guardrail/callback bugs from silently diverging between sync and async paths. All existing behavior preserved, no breaking changes. Co-authored-by: MervinPraison <MervinPraison@users.noreply.github.com>
1 parent 1ad58ca commit 172484a

1 file changed

Lines changed: 95 additions & 119 deletions

File tree

  • src/praisonai-agents/praisonaiagents/agents

src/praisonai-agents/praisonaiagents/agents/agents.py

Lines changed: 95 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,87 @@ async def aexecute_task(self, task_id):
10411041
task_result = _process_task_result(self, context, agent_output)
10421042
return task_result.task_output
10431043

1044+
def _apply_task_guardrail(self, task, task_id, task_output):
1045+
"""Apply guardrail validation to task output.
1046+
1047+
Returns:
1048+
tuple: (task_output, should_retry) where should_retry is True if task should be retried
1049+
1050+
Raises:
1051+
Exception: If guardrail validation fails after max retries
1052+
"""
1053+
if not task._guardrail_fn:
1054+
return task_output, False
1055+
1056+
try:
1057+
guardrail_result = task._process_guardrail(task_output)
1058+
if not guardrail_result.success:
1059+
if task.retry_count >= task.max_retries:
1060+
raise Exception(
1061+
f"Task failed guardrail validation after {task.max_retries} retries. "
1062+
f"Last error: {guardrail_result.error}"
1063+
)
1064+
1065+
task.retry_count += 1
1066+
task.status = "in progress" # Keep task in progress for retry
1067+
logger.warning(f"Task {task_id}: Guardrail validation failed (retry {task.retry_count}/{task.max_retries}): {guardrail_result.error}")
1068+
return task_output, True # Signal retry needed
1069+
1070+
# If guardrail passed and returned a modified result
1071+
if guardrail_result.result is not None:
1072+
if isinstance(guardrail_result.result, str):
1073+
# Update the task output with the modified result
1074+
task_output.raw = guardrail_result.result
1075+
# Clear structured fields to avoid stale cache
1076+
if hasattr(task_output, 'json_dict'):
1077+
task_output.json_dict = None
1078+
if hasattr(task_output, 'pydantic'):
1079+
task_output.pydantic = None
1080+
task.result = task_output
1081+
elif hasattr(guardrail_result.result, 'raw'):
1082+
# Replace with the new task output
1083+
task_output = guardrail_result.result
1084+
task.result = task_output
1085+
1086+
logger.info(f"Task {task_id}: Guardrail validation passed")
1087+
return task_output, False # No retry needed
1088+
1089+
except Exception as e:
1090+
logger.error(f"Task {task_id}: Error in guardrail processing: {e}")
1091+
# Handle guardrail failure with retry logic
1092+
if task.retry_count >= task.max_retries:
1093+
raise Exception(
1094+
f"Task failed due to guardrail processing error after {task.max_retries} retries. "
1095+
f"Last error: {e}"
1096+
) from e
1097+
task.retry_count += 1
1098+
task.status = "in progress"
1099+
logger.warning(f"Task {task_id}: Guardrail processing error (retry {task.retry_count}/{task.max_retries}): {e}")
1100+
return task_output, True # Signal retry needed
1101+
1102+
def _run_task_callback(self, task, task_id, task_output):
1103+
"""Execute task callback (handles both sync and async callbacks).
1104+
1105+
Args:
1106+
task: The task object
1107+
task_id: Task identifier for logging
1108+
task_output: Task output to pass to callback
1109+
"""
1110+
if not task.callback:
1111+
return
1112+
1113+
try:
1114+
if asyncio.iscoroutinefunction(task.callback):
1115+
if run_coroutine_safely:
1116+
run_coroutine_safely(task.callback(task_output))
1117+
else:
1118+
logger.warning("run_coroutine_safely not available, skipping async callback")
1119+
else:
1120+
task.callback(task_output)
1121+
except Exception as e:
1122+
logger.error(f"Error executing task callback for task {task_id}: {e}")
1123+
logger.exception(e)
1124+
10441125
async def arun_task(self, task_id):
10451126
"""Async version of run_task method"""
10461127
if task_id not in self.tasks:
@@ -1059,53 +1140,11 @@ async def arun_task(self, task_id):
10591140
if task.status in ["not started", "in progress"]:
10601141
task_output = await self.aexecute_task(task_id)
10611142
if task_output and self.completion_checker(task, task_output.raw):
1062-
# Run guardrail validation BEFORE marking task complete
1063-
if task._guardrail_fn:
1064-
try:
1065-
guardrail_result = task._process_guardrail(task_output)
1066-
if not guardrail_result.success:
1067-
if task.retry_count >= task.max_retries:
1068-
raise Exception(
1069-
f"Task failed guardrail validation after {task.max_retries} retries. "
1070-
f"Last error: {guardrail_result.error}"
1071-
)
1072-
1073-
task.retry_count += 1
1074-
task.status = "in progress" # Keep task in progress for retry
1075-
logger.warning(f"Task {task_id}: Guardrail validation failed (retry {task.retry_count}/{task.max_retries}): {guardrail_result.error}")
1076-
retries += 1
1077-
continue # Actually retry the task
1078-
1079-
# If guardrail passed and returned a modified result
1080-
if guardrail_result.result is not None:
1081-
if isinstance(guardrail_result.result, str):
1082-
# Update the task output with the modified result
1083-
task_output.raw = guardrail_result.result
1084-
# Clear structured fields to avoid stale cache
1085-
if hasattr(task_output, 'json_dict'):
1086-
task_output.json_dict = None
1087-
if hasattr(task_output, 'pydantic'):
1088-
task_output.pydantic = None
1089-
task.result = task_output
1090-
elif hasattr(guardrail_result.result, 'raw'):
1091-
# Replace with the new task output
1092-
task_output = guardrail_result.result
1093-
task.result = task_output
1094-
1095-
logger.info(f"Task {task_id}: Guardrail validation passed")
1096-
except Exception as e:
1097-
logger.error(f"Task {task_id}: Error in guardrail processing: {e}")
1098-
# Handle guardrail failure with retry logic
1099-
if task.retry_count >= task.max_retries:
1100-
raise Exception(
1101-
f"Task failed due to guardrail processing error after {task.max_retries} retries. "
1102-
f"Last error: {e}"
1103-
) from e
1104-
task.retry_count += 1
1105-
task.status = "in progress"
1106-
logger.warning(f"Task {task_id}: Guardrail processing error (retry {task.retry_count}/{task.max_retries}): {e}")
1107-
retries += 1
1108-
continue # Retry the task
1143+
# Apply guardrail validation using shared helper
1144+
task_output, should_retry = self._apply_task_guardrail(task, task_id, task_output)
1145+
if should_retry:
1146+
retries += 1
1147+
continue
11091148

11101149
task.status = "completed"
11111150
# Run execute_callback for memory operations
@@ -1120,19 +1159,8 @@ async def arun_task(self, task_id):
11201159
if hasattr(task, 'fail_on_memory_error') and task.fail_on_memory_error:
11211160
raise
11221161

1123-
# Run task callback if exists
1124-
if task.callback:
1125-
try:
1126-
if asyncio.iscoroutinefunction(task.callback):
1127-
if run_coroutine_safely:
1128-
run_coroutine_safely(task.callback(task_output))
1129-
else:
1130-
logger.warning("run_coroutine_safely not available, skipping async callback")
1131-
else:
1132-
task.callback(task_output)
1133-
except Exception as e:
1134-
logger.error(f"Error executing task callback for task {task_id}: {e}")
1135-
logger.exception(e)
1162+
# Run task callback using shared helper
1163+
self._run_task_callback(task, task_id, task_output)
11361164

11371165
self.save_output_to_file(task, task_output)
11381166
if self.verbose >= 1:
@@ -1347,52 +1375,11 @@ def run_task(self, task_id):
13471375
if task.status in ["not started", "in progress"]:
13481376
task_output = self.execute_task(task_id)
13491377
if task_output and self.completion_checker(task, task_output.raw):
1350-
# Add guardrail validation (matches arun_task logic)
1351-
if task._guardrail_fn:
1352-
try:
1353-
guardrail_result = task._process_guardrail(task_output)
1354-
if not guardrail_result.success:
1355-
if task.retry_count >= task.max_retries:
1356-
raise Exception(
1357-
f"Task failed guardrail validation after {task.max_retries} retries. "
1358-
f"Last error: {guardrail_result.error}"
1359-
)
1360-
task.retry_count += 1
1361-
task.status = "in progress" # Keep task in progress for retry
1362-
logger.warning(f"Task {task_id}: Guardrail validation failed (retry {task.retry_count}/{task.max_retries}): {guardrail_result.error}")
1363-
retries += 1
1364-
continue # Retry the task
1365-
1366-
# If guardrail passed and returned a modified result
1367-
if guardrail_result.result is not None:
1368-
if isinstance(guardrail_result.result, str):
1369-
# Update the task output with the modified result
1370-
task_output.raw = guardrail_result.result
1371-
# Clear structured fields to avoid stale cache
1372-
if hasattr(task_output, 'json_dict'):
1373-
task_output.json_dict = None
1374-
if hasattr(task_output, 'pydantic'):
1375-
task_output.pydantic = None
1376-
task.result = task_output
1377-
elif hasattr(guardrail_result.result, 'raw'):
1378-
# Replace with the new task output
1379-
task_output = guardrail_result.result
1380-
task.result = task_output
1381-
1382-
logger.info(f"Task {task_id}: Guardrail validation passed")
1383-
except Exception as e:
1384-
logger.error(f"Task {task_id}: Error in guardrail processing: {e}")
1385-
# Handle guardrail failure with retry logic
1386-
if task.retry_count >= task.max_retries:
1387-
raise Exception(
1388-
f"Task failed due to guardrail processing error after {task.max_retries} retries. "
1389-
f"Last error: {e}"
1390-
) from e
1391-
task.retry_count += 1
1392-
task.status = "in progress"
1393-
logger.warning(f"Task {task_id}: Guardrail processing error (retry {task.retry_count}/{task.max_retries}): {e}")
1394-
retries += 1
1395-
continue # Retry the task
1378+
# Apply guardrail validation using shared helper
1379+
task_output, should_retry = self._apply_task_guardrail(task, task_id, task_output)
1380+
if should_retry:
1381+
retries += 1
1382+
continue
13961383

13971384
task.status = "completed"
13981385
# Run execute_callback for memory operations
@@ -1403,19 +1390,8 @@ def run_task(self, task_id):
14031390
logger.error(f"Error executing memory callback for task {task_id}: {e}")
14041391
logger.exception(e)
14051392

1406-
# Run task callback if exists
1407-
if task.callback:
1408-
try:
1409-
if asyncio.iscoroutinefunction(task.callback):
1410-
if run_coroutine_safely:
1411-
run_coroutine_safely(task.callback(task_output))
1412-
else:
1413-
logger.warning("run_coroutine_safely not available, skipping async callback")
1414-
else:
1415-
task.callback(task_output)
1416-
except Exception as e:
1417-
logger.error(f"Error executing task callback for task {task_id}: {e}")
1418-
logger.exception(e)
1393+
# Run task callback using shared helper
1394+
self._run_task_callback(task, task_id, task_output)
14191395

14201396
self.save_output_to_file(task, task_output)
14211397

0 commit comments

Comments
 (0)