Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion astroid/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def infer_name(
return bases._infer_stmts(stmts, context, frame)


# pylint: disable=no-value-for-parameter
# The order of the decorators here is important
# See https://github.com/pylint-dev/astroid/commit/0a8a75db30da060a24922e05048bc270230f5
nodes.Name._infer = decorators.raise_if_nothing_inferred(
Expand Down
51 changes: 27 additions & 24 deletions astroid/inference_tip.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,65 +6,67 @@

from __future__ import annotations

import sys
from collections.abc import Callable, Iterator
from collections.abc import Generator
from typing import Any, TypeVar

from astroid.context import InferenceContext
from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault
from astroid.nodes import NodeNG
from astroid.typing import InferenceResult, InferFn

if sys.version_info >= (3, 11):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
from astroid.typing import (
InferenceResult,
InferFn,
TransformFn,
)

_cache: dict[
tuple[InferFn, NodeNG, InferenceContext | None], list[InferenceResult]
tuple[InferFn[Any], NodeNG, InferenceContext | None], list[InferenceResult]
] = {}

_CURRENTLY_INFERRING: set[tuple[InferFn, NodeNG]] = set()
_CURRENTLY_INFERRING: set[tuple[InferFn[Any], NodeNG]] = set()

_NodesT = TypeVar("_NodesT", bound=NodeNG)


def clear_inference_tip_cache() -> None:
"""Clear the inference tips cache."""
_cache.clear()


def _inference_tip_cached(
func: Callable[_P, Iterator[InferenceResult]],
) -> Callable[_P, Iterator[InferenceResult]]:
def _inference_tip_cached(func: InferFn[_NodesT]) -> InferFn[_NodesT]:
"""Cache decorator used for inference tips."""

def inner(*args: _P.args, **kwargs: _P.kwargs) -> Iterator[InferenceResult]:
node = args[0]
context = args[1]
def inner(
node: _NodesT,
context: InferenceContext | None = None,
**kwargs: Any,
) -> Generator[InferenceResult, None, None]:
partial_cache_key = (func, node)
if partial_cache_key in _CURRENTLY_INFERRING:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated, but isn't this what the path wrapper does? Can we use that here?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps, but at glance, it looks like that one is sensitive to specific InferenceContexts. Here, we're not ready to unleash recursive inference with every slightly different context.

# If through recursion we end up trying to infer the same
# func + node we raise here.
raise UseInferenceDefault
try:
return _cache[func, node, context]
yield from _cache[func, node, context]
return
Comment thread
jacobtylerwalls marked this conversation as resolved.
except KeyError:
# Recursion guard with a partial cache key.
# Using the full key causes a recursion error on PyPy.
# It's a pragmatic compromise to avoid so much recursive inference
# with slightly different contexts while still passing the simple
# test cases included with this commit.
_CURRENTLY_INFERRING.add(partial_cache_key)
result = _cache[func, node, context] = list(func(*args, **kwargs))
_cache[func, node, context] = list(func(node, context, **kwargs))
# Remove recursion guard.
_CURRENTLY_INFERRING.remove(partial_cache_key)

return iter(result)
yield from _cache[func, node, context]
Comment thread
jacobtylerwalls marked this conversation as resolved.
Outdated

return inner


def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) -> InferFn:
def inference_tip(
infer_function: InferFn[_NodesT], raise_on_overwrite: bool = False
) -> TransformFn[_NodesT]:
"""Given an instance specific inference function, return a function to be
given to AstroidManager().register_transform to set this inference function.

Expand All @@ -86,7 +88,9 @@ def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) ->
excess overwrites.
"""

def transform(node: NodeNG, infer_function: InferFn = infer_function) -> NodeNG:
def transform(
node: _NodesT, infer_function: InferFn[_NodesT] = infer_function
) -> _NodesT:
if (
raise_on_overwrite
and node._explicit_inference is not None
Expand All @@ -100,7 +104,6 @@ def transform(node: NodeNG, infer_function: InferFn = infer_function) -> NodeNG:
node=node,
)
)
# pylint: disable=no-value-for-parameter
node._explicit_inference = _inference_tip_cached(infer_function)
return node

Expand Down
9 changes: 8 additions & 1 deletion astroid/nodes/node_ng.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import pprint
import sys
import warnings
from collections.abc import Generator, Iterator
from functools import cached_property
Expand Down Expand Up @@ -37,6 +38,12 @@
from astroid.nodes.utils import Position
from astroid.typing import InferenceErrorInfo, InferenceResult, InferFn

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


if TYPE_CHECKING:
from astroid import nodes

Expand Down Expand Up @@ -80,7 +87,7 @@ class NodeNG:
_other_other_fields: ClassVar[tuple[str, ...]] = ()
"""Attributes that contain AST-dependent fields."""
# instance specific inference function infer(node, context)
_explicit_inference: InferFn | None = None
_explicit_inference: InferFn[Self] | None = None

def __init__(
self,
Expand Down
11 changes: 4 additions & 7 deletions astroid/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,14 @@
from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar, Union, cast, overload

from astroid.context import _invalidate_cache
from astroid.typing import SuccessfulInferenceResult
from astroid.typing import SuccessfulInferenceResult, TransformFn

if TYPE_CHECKING:
from astroid import nodes

_SuccessfulInferenceResultT = TypeVar(
"_SuccessfulInferenceResultT", bound=SuccessfulInferenceResult
)
_Transform = Callable[
[_SuccessfulInferenceResultT], Optional[SuccessfulInferenceResult]
]
_Predicate = Optional[Callable[[_SuccessfulInferenceResultT], bool]]

_Vistables = Union[
Expand Down Expand Up @@ -52,7 +49,7 @@ def __init__(self) -> None:
type[SuccessfulInferenceResult],
list[
tuple[
_Transform[SuccessfulInferenceResult],
TransformFn[SuccessfulInferenceResult],
_Predicate[SuccessfulInferenceResult],
]
],
Expand Down Expand Up @@ -123,7 +120,7 @@ def _visit_generic(self, node: _Vistables) -> _VisitReturns:
def register_transform(
self,
node_class: type[_SuccessfulInferenceResultT],
transform: _Transform[_SuccessfulInferenceResultT],
transform: TransformFn[_SuccessfulInferenceResultT],
predicate: _Predicate[_SuccessfulInferenceResultT] | None = None,
) -> None:
"""Register `transform(node)` function to be applied on the given node.
Expand All @@ -139,7 +136,7 @@ def register_transform(
def unregister_transform(
self,
node_class: type[_SuccessfulInferenceResultT],
transform: _Transform[_SuccessfulInferenceResultT],
transform: TransformFn[_SuccessfulInferenceResultT],
predicate: _Predicate[_SuccessfulInferenceResultT] | None = None,
) -> None:
"""Unregister the given transform."""
Expand Down
42 changes: 35 additions & 7 deletions astroid/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,24 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Generator, TypedDict, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generator,
Generic,
Protocol,
TypedDict,
TypeVar,
Union,
)

if TYPE_CHECKING:
from astroid import bases, exceptions, nodes, transforms, util
from astroid.context import InferenceContext
from astroid.interpreter._import import spec


_NodesT = TypeVar("_NodesT", bound="nodes.NodeNG")


class InferenceErrorInfo(TypedDict):
"""Store additional Inference error information
raised with StopIteration exception.
Expand All @@ -24,9 +31,6 @@ class InferenceErrorInfo(TypedDict):
context: InferenceContext | None


InferFn = Callable[..., Any]


class AstroidManagerBrain(TypedDict):
"""Dictionary to store relevant information for a AstroidManager class."""

Expand All @@ -46,6 +50,11 @@ class AstroidManagerBrain(TypedDict):
_SuccessfulInferenceResultT = TypeVar(
"_SuccessfulInferenceResultT", bound=SuccessfulInferenceResult
)
_SuccessfulInferenceResultT_contra = TypeVar(
"_SuccessfulInferenceResultT_contra",
bound=SuccessfulInferenceResult,
contravariant=True,
)

ConstFactoryResult = Union[
"nodes.List",
Expand All @@ -67,3 +76,22 @@ class AstroidManagerBrain(TypedDict):
],
Generator[InferenceResult, None, None],
]


class InferFn(Protocol, Generic[_SuccessfulInferenceResultT_contra]):
def __call__(
self,
node: _SuccessfulInferenceResultT_contra,
context: InferenceContext | None = None,
**kwargs: Any,
) -> Generator[InferenceResult, None, None]:
...
Comment thread
jacobtylerwalls marked this conversation as resolved.
Outdated


class TransformFn(Protocol, Generic[_SuccessfulInferenceResultT]):
def __call__(
self,
node: _SuccessfulInferenceResultT,
infer_function: InferFn[_SuccessfulInferenceResultT] = ...,
) -> _SuccessfulInferenceResultT | None:
...