66
77from __future__ import annotations
88
9- import sys
10- from collections . abc import Callable , Iterator
9+ from collections . abc import Generator
10+ from typing import Any , TypeVar
1111
1212from astroid .context import InferenceContext
1313from astroid .exceptions import InferenceOverwriteError , UseInferenceDefault
1414from astroid .nodes import NodeNG
15- from astroid .typing import InferenceResult , InferFn
16-
17- if sys .version_info >= (3 , 11 ):
18- from typing import ParamSpec
19- else :
20- from typing_extensions import ParamSpec
21-
22- _P = ParamSpec ("_P" )
15+ from astroid .typing import (
16+ InferenceResult ,
17+ InferFn ,
18+ TransformFn ,
19+ )
2320
# Memoization table for inference tips: maps (tip function, node, context)
# to the fully materialized list of inference results, so the same tip is
# not re-run for the same node/context pair.
_cache: dict[
    tuple[InferFn[Any], NodeNG, InferenceContext | None], list[InferenceResult]
] = {}

# Recursion guard: (tip function, node) pairs whose inference is currently
# in progress.  Deliberately keyed WITHOUT the context (a partial key) — per
# the note in _inference_tip_cached, the full key caused recursion errors
# on PyPy.
_CURRENTLY_INFERRING: set[tuple[InferFn[Any], NodeNG]] = set()

# Node subtype a given inference tip handles (bound so tips stay NodeNG-typed).
_NodesT = TypeVar("_NodesT", bound=NodeNG)
2928
3029
def clear_inference_tip_cache() -> None:
    """Clear the inference tips cache.

    Drops every memoized (function, node, context) -> results entry; the
    next inference of any node will re-run its tip from scratch.
    """
    _cache.clear()
3433
3534
def _inference_tip_cached(func: InferFn[_NodesT]) -> InferFn[_NodesT]:
    """Cache decorator used for inference tips.

    Wraps an inference tip *func* so that its (node, context) results are
    computed once, stored in the module-level ``_cache``, and replayed on
    subsequent calls.  A partial-key recursion guard makes re-entrant
    inference of the same (func, node) pair fall back to default inference
    instead of recursing forever.
    """

    def inner(
        node: _NodesT,
        context: InferenceContext | None = None,
        **kwargs: Any,
    ) -> Generator[InferenceResult, None, None]:
        partial_cache_key = (func, node)
        if partial_cache_key in _CURRENTLY_INFERRING:
            # If through recursion we end up trying to infer the same
            # func + node we raise here, deferring to default inference.
            raise UseInferenceDefault
        try:
            yield from _cache[func, node, context]
            return
        except KeyError:
            # Recursion guard with a partial cache key.
            # Using the full key causes a recursion error on PyPy.
            # It's a pragmatic compromise to avoid so much recursive inference
            # with slightly different contexts while still passing the simple
            # test cases included with this commit.
            _CURRENTLY_INFERRING.add(partial_cache_key)
            try:
                result = _cache[func, node, context] = list(
                    func(node, context, **kwargs)
                )
            finally:
                # Remove the recursion guard even when the tip raises
                # (e.g. UseInferenceDefault); otherwise the stale guard entry
                # would make every future inference of this (func, node) pair
                # raise UseInferenceDefault, permanently disabling the tip
                # for that node.
                _CURRENTLY_INFERRING.discard(partial_cache_key)

        # Yield outside the except block so exceptions raised by consumers
        # don't carry the KeyError above as implicit __context__.
        yield from result

    return inner
6565
6666
67- def inference_tip (infer_function : InferFn , raise_on_overwrite : bool = False ) -> InferFn :
67+ def inference_tip (
68+ infer_function : InferFn [_NodesT ], raise_on_overwrite : bool = False
69+ ) -> TransformFn [_NodesT ]:
6870 """Given an instance specific inference function, return a function to be
6971 given to AstroidManager().register_transform to set this inference function.
7072
@@ -86,7 +88,9 @@ def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) ->
8688 excess overwrites.
8789 """
8890
89- def transform (node : NodeNG , infer_function : InferFn = infer_function ) -> NodeNG :
91+ def transform (
92+ node : _NodesT , infer_function : InferFn [_NodesT ] = infer_function
93+ ) -> _NodesT :
9094 if (
9195 raise_on_overwrite
9296 and node ._explicit_inference is not None
@@ -100,7 +104,6 @@ def transform(node: NodeNG, infer_function: InferFn = infer_function) -> NodeNG:
100104 node = node ,
101105 )
102106 )
103- # pylint: disable=no-value-for-parameter
104107 node ._explicit_inference = _inference_tip_cached (infer_function )
105108 return node
106109
0 commit comments