99import sys
1010from collections .abc import Callable , Iterator
1111
12+ from astroid .context import InferenceContext
1213from astroid .exceptions import InferenceOverwriteError , UseInferenceDefault
1314from astroid .nodes import NodeNG
1415from astroid .typing import InferenceResult , InferFn
2021
2122_P = ParamSpec ("_P" )
2223
23- _cache : dict [tuple [InferFn , NodeNG ], list [InferenceResult ] | None ] = {}
24+ _cache : dict [
25+ tuple [InferFn , NodeNG , InferenceContext | None ], list [InferenceResult ]
26+ ] = {}
27+
28+ _CURRENTLY_INFERRING : set [tuple [InferFn , NodeNG ]] = set ()
2429
2530
2631def clear_inference_tip_cache () -> None :
@@ -35,16 +40,25 @@ def _inference_tip_cached(
3540
3641 def inner (* args : _P .args , ** kwargs : _P .kwargs ) -> Iterator [InferenceResult ]:
3742 node = args [0 ]
38- try :
39- result = _cache [func , node ]
43+ context = args [1 ]
44+ partial_cache_key = (func , node )
45+ if partial_cache_key in _CURRENTLY_INFERRING :
4046 # If through recursion we end up trying to infer the same
4147 # func + node we raise here.
42- if result is None :
43- raise UseInferenceDefault ()
48+ raise UseInferenceDefault
49+ try :
50+ return _cache [func , node , context ]
4451 except KeyError :
45- _cache [func , node ] = None
46- result = _cache [func , node ] = list (func (* args , ** kwargs ))
47- assert result
52+ # Recursion guard with a partial cache key.
53+ # Using the full key causes a recursion error on PyPy.
54+ # It's a pragmatic compromise to avoid so much recursive inference
55+ # with slightly different contexts while still passing the simple
56+ # test cases included with this commit.
57+ _CURRENTLY_INFERRING .add (partial_cache_key )
58+ result = _cache [func , node , context ] = list (func (* args , ** kwargs ))
59+ # Remove recursion guard.
60+ _CURRENTLY_INFERRING .remove (partial_cache_key )
61+
4862 return iter (result )
4963
5064 return inner
0 commit comments