|
4 | 4 |
|
5 | 5 | """Functional tests for the code examples in the messages documentation.""" |
6 | 6 |
|
7 | | -from collections import Counter |
| 7 | +import sys |
| 8 | + |
| 9 | +if sys.version_info[:2] >= (3, 9): |
| 10 | + from collections import Counter |
| 11 | +else: |
| 12 | + from collections import Counter as _Counter |
| 13 | + |
| 14 | + class Counter(_Counter): |
| 15 | + def total(self): |
| 16 | + return len(tuple(self.elements())) |
| 17 | + |
| 18 | + |
8 | 19 | from pathlib import Path |
9 | 20 | from typing import Counter as CounterType |
10 | 21 | from typing import List, Optional, TextIO, Tuple |
@@ -75,6 +86,12 @@ def __init__(self, test_file: Tuple[str, Path]) -> None: |
75 | 86 | def runTest(self) -> None: |
76 | 87 | self._runTest() |
77 | 88 |
|
| 89 | + def is_good_test_file(self) -> bool: |
| 90 | + return self._test_file[1].name == "good.py" |
| 91 | + |
| 92 | + def is_bad_test_file(self) -> bool: |
| 93 | + return self._test_file[1].name == "bad.py" |
| 94 | + |
78 | 95 | @staticmethod |
79 | 96 | def get_expected_messages(stream: TextIO) -> MessageCounter: |
80 | 97 | """Parse a file and get expected messages.""" |
@@ -114,6 +131,10 @@ def _runTest(self) -> None: |
114 | 131 | self._linter.check([str(self._test_file[1])]) |
115 | 132 | expected_messages = self._get_expected() |
116 | 133 | actual_messages = self._get_actual() |
| 134 | + if self.is_good_test_file(): |
| 135 | + assert actual_messages.total() == 0 # type: ignore[attr-defined] |
| 136 | + if self.is_bad_test_file(): |
| 137 | + assert actual_messages.total() > 0 # type: ignore[attr-defined] |
117 | 138 | assert expected_messages == actual_messages |
118 | 139 |
|
119 | 140 |
|
|
0 commit comments