|
| 1 | +""" |
| 2 | +Tests for the auto-reset logic, specifically for the dictionary comprehension bug |
| 3 | +""" |
| 4 | + |
| 5 | +from __future__ import annotations |
| 6 | + |
| 7 | +from typing import TYPE_CHECKING |
| 8 | + |
| 9 | +import pytest |
| 10 | + |
| 11 | +from qcelemental.models import ComputeError, FailedOperation |
| 12 | +from qcarchivetesting.testing_classes import QCATestingSnowflake |
| 13 | +from qcfractal.components.record_db_models import BaseRecordORM |
| 14 | +from qcfractal.components.singlepoint.testing_helpers import load_procedure_data |
| 15 | +from qcfractalcompute.compress import compress_result |
| 16 | +from qcportal.record_models import PriorityEnum, RecordStatusEnum |
| 17 | + |
| 18 | +if TYPE_CHECKING: |
| 19 | + from qcfractal.db_socket import SQLAlchemySocket |
| 20 | + from sqlalchemy.orm.session import Session |
| 21 | + |
| 22 | + |
| 23 | +def test_reset_logic_dict_comprehension_bug(postgres_server, pytestconfig): |
| 24 | + |
| 25 | + pg_harness = postgres_server.get_new_harness("reset_logic_dict_bug") |
| 26 | + encoding = pytestconfig.getoption("--client-encoding") |
| 27 | + |
| 28 | + # Configure auto_reset with a limit of 2 unknown_errors |
| 29 | + extra_config = { |
| 30 | + "auto_reset": { |
| 31 | + "enabled": True, |
| 32 | + "unknown_error": 2 |
| 33 | + } |
| 34 | + } |
| 35 | + |
| 36 | + with QCATestingSnowflake(pg_harness, encoding=encoding, extra_config=extra_config) as snowflake: |
| 37 | + storage_socket = snowflake.get_storage_socket() |
| 38 | + activated_manager_name, _ = snowflake.activate_manager() |
| 39 | + activated_manager_programs = snowflake.activated_manager_programs() |
| 40 | + |
| 41 | + # Load test data and submit a singlepoint calculation |
| 42 | + input_spec, molecule, result_data = load_procedure_data("sp_psi4_water_energy") |
| 43 | + meta, record_ids = storage_socket.records.singlepoint.add( |
| 44 | + [molecule], input_spec, "tag1", PriorityEnum.normal, None, True |
| 45 | + ) |
| 46 | + record_id = record_ids[0] |
| 47 | + |
| 48 | + # Create FailedOperation objects for the two error types |
| 49 | + fop_bad_state = FailedOperation( |
| 50 | + error=ComputeError( |
| 51 | + error_type="BadStateException", |
| 52 | + error_message="QOSMaxSubmitJobPerUserLimit reached" |
| 53 | + ) |
| 54 | + ) |
| 55 | + fop_too_many = FailedOperation( |
| 56 | + error=ComputeError( |
| 57 | + error_type="TooManyJobFailuresError", |
| 58 | + error_message="Wrapped Parsl exception" |
| 59 | + ) |
| 60 | + ) |
| 61 | + |
| 62 | + with storage_socket.session_scope() as session: |
| 63 | + rec = session.get(BaseRecordORM, record_id) |
| 64 | + |
| 65 | + # Failure 1: BadStateException |
| 66 | + tasks = storage_socket.tasks.claim_tasks( |
| 67 | + activated_manager_name.fullname, |
| 68 | + activated_manager_programs, |
| 69 | + ["*"] |
| 70 | + ) |
| 71 | + assert len(tasks) == 1 |
| 72 | + storage_socket.tasks.update_finished( |
| 73 | + activated_manager_name.fullname, |
| 74 | + {tasks[0]["id"]: compress_result(fop_bad_state.dict())} |
| 75 | + ) |
| 76 | + session.expire(rec) |
| 77 | + assert rec.status == RecordStatusEnum.waiting |
| 78 | + assert len(rec.compute_history) == 1 |
| 79 | + |
| 80 | + # Failures 2-6: TooManyJobFailuresError |
| 81 | + for i in range(2): |
| 82 | + tasks = storage_socket.tasks.claim_tasks( |
| 83 | + activated_manager_name.fullname, |
| 84 | + activated_manager_programs, |
| 85 | + ["*"] |
| 86 | + ) |
| 87 | + print(f"Iteration {i}: rec.status = {rec.status}, history length = {len(rec.compute_history)}") |
| 88 | + assert len(tasks) == 1 |
| 89 | + storage_socket.tasks.update_finished( |
| 90 | + activated_manager_name.fullname, |
| 91 | + {tasks[0]["id"]: compress_result(fop_too_many.dict())} |
| 92 | + ) |
| 93 | + session.expire(rec) |
| 94 | + |
| 95 | + # after each failure, check the compute_history length |
| 96 | + assert len(rec.compute_history) == i + 2 |
| 97 | + |
| 98 | + session.expire(rec) |
| 99 | + assert len(rec.compute_history) == 3 |
| 100 | + |
| 101 | + assert rec.status == RecordStatusEnum.error, ( |
| 102 | + f"After 3 errors (1 BadStateException + 2 TooManyJobFailuresError), record should be set to errored." |
| 103 | + ) |
| 104 | + |
0 commit comments