Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion astroid/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def infer_name(
return bases._infer_stmts(stmts, context, frame)


# pylint: disable=no-value-for-parameter
# The order of the decorators here is important
# See https://github.com/pylint-dev/astroid/commit/0a8a75db30da060a24922e05048bc270230f5
nodes.Name._infer = decorators.raise_if_nothing_inferred(
Expand Down
32 changes: 19 additions & 13 deletions astroid/inference_tip.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,19 @@

from __future__ import annotations

import sys
from collections.abc import Callable, Iterator
from typing import TYPE_CHECKING
Comment thread
jacobtylerwalls marked this conversation as resolved.
Outdated

from astroid.context import InferenceContext
from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault
from astroid.nodes import NodeNG
from astroid.typing import InferenceResult, InferFn

if sys.version_info >= (3, 11):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
from astroid.typing import (
_P,
Comment thread
jacobtylerwalls marked this conversation as resolved.
Outdated
InferenceResult,
InferFn,
InferFnExplicit,
InferFnTransform,
)

_cache: dict[
tuple[InferFn, NodeNG, InferenceContext | None], list[InferenceResult]
Expand All @@ -35,12 +34,18 @@ def clear_inference_tip_cache() -> None:

def _inference_tip_cached(
func: Callable[_P, Iterator[InferenceResult]],
) -> Callable[_P, Iterator[InferenceResult]]:
) -> InferFnExplicit:
"""Cache decorator used for inference tips."""

def inner(*args: _P.args, **kwargs: _P.kwargs) -> Iterator[InferenceResult]:
def inner(
*args: _P.args, **kwargs: _P.kwargs
) -> Iterator[InferenceResult] | list[InferenceResult]:
node = args[0]
context = args[1]
if TYPE_CHECKING:
assert isinstance(node, NodeNG)
assert context is None or isinstance(context, InferenceContext)
Comment thread
jacobtylerwalls marked this conversation as resolved.
Outdated

partial_cache_key = (func, node)
if partial_cache_key in _CURRENTLY_INFERRING:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated, but isn't this what the path wrapper does? Can we use that here?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps, but at glance, it looks like that one is sensitive to specific InferenceContexts. Here, we're not ready to unleash recursive inference with every slightly different context.

# If through recursion we end up trying to infer the same
Expand All @@ -64,7 +69,9 @@ def inner(*args: _P.args, **kwargs: _P.kwargs) -> Iterator[InferenceResult]:
return inner


def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) -> InferFn:
def inference_tip(
infer_function: InferFn, raise_on_overwrite: bool = False
) -> InferFnTransform:
"""Given an instance specific inference function, return a function to be
given to AstroidManager().register_transform to set this inference function.

Expand Down Expand Up @@ -100,7 +107,6 @@ def transform(node: NodeNG, infer_function: InferFn = infer_function) -> NodeNG:
node=node,
)
)
# pylint: disable=no-value-for-parameter
node._explicit_inference = _inference_tip_cached(infer_function)
return node

Expand Down
4 changes: 2 additions & 2 deletions astroid/nodes/node_ng.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from astroid.nodes.as_string import AsStringVisitor
from astroid.nodes.const import OP_PRECEDENCE
from astroid.nodes.utils import Position
from astroid.typing import InferenceErrorInfo, InferenceResult, InferFn
from astroid.typing import InferenceErrorInfo, InferenceResult, InferFnExplicit

if TYPE_CHECKING:
from astroid import nodes
Expand Down Expand Up @@ -80,7 +80,7 @@ class NodeNG:
_other_other_fields: ClassVar[tuple[str, ...]] = ()
"""Attributes that contain AST-dependent fields."""
# instance specific inference function infer(node, context)
_explicit_inference: InferFn | None = None
_explicit_inference: InferFnExplicit | None = None

def __init__(
self,
Expand Down
23 changes: 19 additions & 4 deletions astroid/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,28 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Generator, TypedDict, TypeVar, Union
import sys
from typing import (
TYPE_CHECKING,
Callable,
Generator,
Iterator,
TypedDict,
TypeVar,
Union,
)

if TYPE_CHECKING:
from astroid import bases, exceptions, nodes, transforms, util
from astroid.context import InferenceContext
from astroid.interpreter._import import spec

if sys.version_info >= (3, 11):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_NodesT = TypeVar("_NodesT", bound="nodes.NodeNG")


Expand All @@ -24,9 +38,6 @@ class InferenceErrorInfo(TypedDict):
context: InferenceContext | None


InferFn = Callable[..., Any]


class AstroidManagerBrain(TypedDict):
"""Dictionary to store relevant information for a AstroidManager class."""

Expand Down Expand Up @@ -67,3 +78,7 @@ class AstroidManagerBrain(TypedDict):
],
Generator[InferenceResult, None, None],
]

InferFn = Callable[..., Iterator[InferenceResult]]
InferFnExplicit = Callable[_P, Union[Iterator[InferenceResult], list[InferenceResult]]]
InferFnTransform = Callable[[_NodesT, InferFn], _NodesT]