|
3 | 3 | """ |
4 | 4 |
|
5 | 5 | from itertools import product |
6 | | -from unittest import TestCase |
| 6 | +import pytest # Add pytest back |
| 7 | +import os |
| 8 | + |
| 9 | +# import logging # No longer needed if logger is removed |
| 10 | +# import sys # No longer used |
7 | 11 |
|
8 | 12 | from numpy import array, empty |
9 | 13 |
|
|
15 | 19 | check_arrays, |
16 | 20 | make_alternating_boolean_array, |
17 | 21 | make_cascading_boolean_array, |
18 | | - parameter_space, |
| 22 | + # parameter_space, # This was commented out, so it's unused |
19 | 23 | ) |
20 | | -from zipline.testing.fixtures import ( |
| 24 | +from zipline.testing.fixtures import ( # Assuming this is where ZiplineTestCase and others are |
21 | 25 | WithConstantEquityMinuteBarData, |
22 | 26 | WithDataPortal, |
23 | | - ZiplineTestCase, |
| 27 | + ZiplineTestCase, # Add back ZiplineTestCase import |
24 | 28 | ) |
25 | 29 | from zipline.testing.slippage import TestingSlippage |
26 | 30 | from zipline.testing.predicates import wildcard, instance_of |
27 | 31 | from zipline.utils.numpy_utils import bool_dtype |
28 | 32 |
|
| 33 | +# Group all tests in this module to run on the same worker |
| 34 | +pytestmark = pytest.mark.xdist_group(name="module_group_test_testing") |
29 | 35 |
|
30 | | -class TestParameterSpace(TestCase): |
| 36 | +# Configure a logger for this module |
| 37 | +# logger = logging.getLogger(__name__) # Removed as it's no longer used |
31 | 38 |
|
32 | | - x_args = [1, 2] |
33 | | - y_args = [3, 4] |
34 | 39 |
|
35 | | - @classmethod |
36 | | - def setup_class(cls): |
37 | | - cls.xy_invocations = [] |
38 | | - cls.yx_invocations = [] |
| 40 | +@pytest.fixture(scope="class") |
| 41 | +def invocations_state(request): |
| 42 | + request.cls.xy_invocations = [] |
| 43 | + request.cls.yx_invocations = [] |
| 44 | + yield |
39 | 45 |
|
40 | | - @classmethod |
41 | | - def teardown_class(cls): |
42 | | - # This is the only actual test here. |
43 | | - assert cls.xy_invocations == list(product(cls.x_args, cls.y_args)) |
44 | | - assert cls.yx_invocations == list(product(cls.y_args, cls.x_args)) |
| 46 | + actual_xy_invocations = sorted(request.cls.xy_invocations) |
| 47 | + actual_yx_invocations = sorted(request.cls.yx_invocations) |
| 48 | + |
| 49 | + expected_xy = sorted( |
| 50 | + list(product(request.cls.x_args_vals, request.cls.y_args_vals)) |
| 51 | + ) |
| 52 | + expected_yx = sorted( |
| 53 | + list(product(request.cls.y_args_vals, request.cls.x_args_vals)) |
| 54 | + ) |
| 55 | + |
| 56 | + worker = os.environ.get("PYTEST_XDIST_WORKER", "main") |
| 57 | + |
| 58 | + assert ( |
| 59 | + actual_xy_invocations == expected_xy |
| 60 | + ), f"[{worker}] XY invocations do not match. Expected: {expected_xy}, Got: {actual_xy_invocations}" |
| 61 | + assert ( |
| 62 | + actual_yx_invocations == expected_yx |
| 63 | + ), f"[{worker}] YX invocations do not match. Expected: {expected_yx}, Got: {actual_yx_invocations}" |
45 | 64 |
|
46 | | - @parameter_space(x=x_args, y=y_args) |
| 65 | + |
| 66 | +@pytest.mark.usefixtures("invocations_state") |
| 67 | +class TestParameterSpace: |
| 68 | + """Test class for parametrized tests using a shared state via fixture.""" |
| 69 | + |
| 70 | + x_args_vals = [1, 2] |
| 71 | + y_args_vals = [3, 4] |
| 72 | + |
| 73 | + @pytest.mark.parametrize("x", x_args_vals) |
| 74 | + @pytest.mark.parametrize("y", y_args_vals) |
47 | 75 | def test_xy(self, x, y): |
48 | | - self.xy_invocations.append((x, y)) |
| 76 | + """Test xy parameter combinations.""" |
| 77 | + self.__class__.xy_invocations.append((x, y)) |
49 | 78 |
|
50 | | - @parameter_space(x=x_args, y=y_args) |
| 79 | + @pytest.mark.parametrize("y", y_args_vals) |
| 80 | + @pytest.mark.parametrize("x", x_args_vals) |
51 | 81 | def test_yx(self, y, x): |
52 | | - # Ensure that product is called with args in the order that they appear |
53 | | - # in the function's parameter list. |
54 | | - self.yx_invocations.append((y, x)) |
| 82 | + """Test yx parameter combinations.""" |
| 83 | + self.__class__.yx_invocations.append((y, x)) |
55 | 84 |
|
56 | 85 | def test_nothing(self): |
57 | | - # Ensure that there's at least one "real" test in the class, or else |
58 | | - # our {setUp,tearDown}Class won't be called if, for example, |
59 | | - # `parameter_space` returns None. |
| 86 | + """A simple test that does nothing but ensures fixture setup/teardown works.""" |
60 | 87 | pass |
61 | 88 |
|
62 | 89 |
|
@@ -107,7 +134,9 @@ def test_make_cascading_boolean_array(self): |
107 | 134 |
|
108 | 135 |
|
109 | 136 | class TestTestingSlippage( |
110 | | - WithConstantEquityMinuteBarData, WithDataPortal, ZiplineTestCase |
| 137 | + WithConstantEquityMinuteBarData, |
| 138 | + WithDataPortal, |
| 139 | + ZiplineTestCase, # Add ZiplineTestCase back as a base class |
111 | 140 | ): |
112 | 141 | ASSET_FINDER_EQUITY_SYMBOLS = ("A",) |
113 | 142 | ASSET_FINDER_EQUITY_SIDS = (1,) |
|
0 commit comments