Skip to content

Commit 549eed7

Browse files
authored
Merge pull request #998 from vlita/retry
Fixed bug in reset logic: the dictionary comprehension overwrote per-category error counts instead of summing them when multiple error types mapped to the same category
2 parents 631ac07 + 3cfe602 commit 549eed7

File tree

2 files changed

+112
-1
lines changed

2 files changed

+112
-1
lines changed

qcfractal/qcfractal/components/tasks/reset_logic.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,14 @@ def should_reset(record_orm: BaseRecordORM, config: AutoResetConfig) -> bool:
5353
logger.debug(f" {k}: {v}")
5454

5555
# Map to more general error categories
56-
error_counts = {error_map.get(k, "unknown_error"): v for k, v in error_counts.items()}
56+
# error_counts = {error_map.get(k, "unknown_error"): v for k, v in error_counts.items()}
57+
58+
mapped_counts = {}
59+
for k, v in error_counts.items():
60+
category = error_map.get(k, "unknown_error")
61+
# add to dict instead of overwriting
62+
mapped_counts[category] = mapped_counts.get(category, 0) + v
63+
error_counts = mapped_counts
5764

5865
# Are we beyond any of the max on any?
5966
for err, count in error_counts.items():
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""
2+
Tests for the auto-reset logic, specifically for the dictionary comprehension bug
3+
"""
4+
5+
from __future__ import annotations
6+
7+
from typing import TYPE_CHECKING
8+
9+
import pytest
10+
11+
from qcelemental.models import ComputeError, FailedOperation
12+
from qcarchivetesting.testing_classes import QCATestingSnowflake
13+
from qcfractal.components.record_db_models import BaseRecordORM
14+
from qcfractal.components.singlepoint.testing_helpers import load_procedure_data
15+
from qcfractalcompute.compress import compress_result
16+
from qcportal.record_models import PriorityEnum, RecordStatusEnum
17+
18+
if TYPE_CHECKING:
19+
from qcfractal.db_socket import SQLAlchemySocket
20+
from sqlalchemy.orm.session import Session
21+
22+
23+
def test_reset_logic_dict_comprehension_bug(postgres_server, pytestconfig):
    """Regression test for the error-category counting bug in the auto-reset logic.

    The old dictionary comprehension mapped raw error types to general
    categories but silently *overwrote* counts whenever two distinct raw
    error types mapped to the same category.  Here both ``BadStateException``
    and ``TooManyJobFailuresError`` map to ``unknown_error``; with the
    ``unknown_error`` auto-reset limit set to 2, three total failures must
    leave the record in the ``error`` state (not reset back to waiting).
    """
    pg_harness = postgres_server.get_new_harness("reset_logic_dict_bug")
    encoding = pytestconfig.getoption("--client-encoding")

    # Allow at most 2 automatic resets for the "unknown_error" category
    extra_config = {
        "auto_reset": {
            "enabled": True,
            "unknown_error": 2,
        }
    }

    with QCATestingSnowflake(pg_harness, encoding=encoding, extra_config=extra_config) as snowflake:
        storage_socket = snowflake.get_storage_socket()
        activated_manager_name, _ = snowflake.activate_manager()
        activated_manager_programs = snowflake.activated_manager_programs()

        # Submit a single singlepoint calculation to fail repeatedly
        input_spec, molecule, result_data = load_procedure_data("sp_psi4_water_energy")
        meta, record_ids = storage_socket.records.singlepoint.add(
            [molecule], input_spec, "tag1", PriorityEnum.normal, None, True
        )
        record_id = record_ids[0]

        # Two *distinct* raw error types that both fall into "unknown_error"
        fop_bad_state = FailedOperation(
            error=ComputeError(
                error_type="BadStateException",
                error_message="QOSMaxSubmitJobPerUserLimit reached",
            )
        )
        fop_too_many = FailedOperation(
            error=ComputeError(
                error_type="TooManyJobFailuresError",
                error_message="Wrapped Parsl exception",
            )
        )

        with storage_socket.session_scope() as session:
            rec = session.get(BaseRecordORM, record_id)

            # Failure 1: BadStateException -> still under the limit, auto-reset to waiting
            tasks = storage_socket.tasks.claim_tasks(
                activated_manager_name.fullname,
                activated_manager_programs,
                ["*"],
            )
            assert len(tasks) == 1
            storage_socket.tasks.update_finished(
                activated_manager_name.fullname,
                {tasks[0]["id"]: compress_result(fop_bad_state.dict())},
            )
            session.expire(rec)
            assert rec.status == RecordStatusEnum.waiting
            assert len(rec.compute_history) == 1

            # Failures 2 and 3: TooManyJobFailuresError (a different raw type,
            # same "unknown_error" category)
            for i in range(2):
                tasks = storage_socket.tasks.claim_tasks(
                    activated_manager_name.fullname,
                    activated_manager_programs,
                    ["*"],
                )
                assert len(tasks) == 1
                storage_socket.tasks.update_finished(
                    activated_manager_name.fullname,
                    {tasks[0]["id"]: compress_result(fop_too_many.dict())},
                )
                session.expire(rec)

                # each failure appends one entry to the compute history
                assert len(rec.compute_history) == i + 2

            session.expire(rec)
            assert len(rec.compute_history) == 3

            # 1 BadStateException + 2 TooManyJobFailuresError = 3 failures in the
            # "unknown_error" category, exceeding the limit of 2 -> must stay errored.
            # The buggy comprehension would have counted only the last error type
            # and reset the record again.
            assert rec.status == RecordStatusEnum.error, (
                "After 3 errors (1 BadStateException + 2 TooManyJobFailuresError), "
                "record should be set to errored."
            )
104+

0 commit comments

Comments
 (0)