Skip to content

Commit 8c63bce

Browse files
committed
[ty] Include conditional symbols (like datetime.UTC) in auto-import in more cases
This is a quick fix to make conditional symbols in auto-import work. In essence, this flips the current failure mode of "don't suggest a symbol that is available" to "possibly suggest a symbol that isn't available." I think suggesting a symbol that isn't available is probably the better failure mode. Fixes #1758, Ref #2795
1 parent 46be943 commit 8c63bce

2 files changed

Lines changed: 120 additions & 11 deletions

File tree

crates/ty_ide/src/symbols.rs

Lines changed: 108 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,12 +1101,30 @@ impl<'db> SymbolVisitor<'db> {
11011101
}
11021102

11031103
/// Updates the origin of `__all__` in the current module.
1104-
///
1105-
/// This will clear existing names if the origin is changed to
1106-
/// mimic the behavior of overriding `__all__` in the current
1107-
/// module.
11081104
fn update_all_origin(&mut self, origin: DunderAllOrigin) {
1109-
if self.all_origin.is_some() {
1105+
// N.B. This used to clear `all_names` whenever there
1106+
// was *any* previous origin set. Now we skip clearing
1107+
// it if the previous origin and the new origin are
1108+
// both "current module." This tends to arise in situation
1109+
// like this:
1110+
//
1111+
// if sys.version > ...:
1112+
// __all__ = ['SomeFancyNewSymbol']
1113+
// else:
1114+
// __all__ = []
1115+
//
1116+
// Clearing is arguably correct here, but auto-import
1117+
// (unlike ty's own __all__ handling) doesn't yet know
1118+
// how to evaluate conditionals. So instead of over-writing
1119+
// __all__, we union it. This will produce incorrect
1120+
// results in some cases, but the failure mode will be
1121+
// "suggests symbol that doesn't exist" instead of
1122+
// "doesn't suggest symbol that does exist." The former
1123+
// seems preferable (until we know how to evaluate at
1124+
// least some rudimentary conditionals).
1125+
if !(matches!(self.all_origin, Some(DunderAllOrigin::CurrentModule))
1126+
&& matches!(origin, DunderAllOrigin::CurrentModule))
1127+
{
11101128
self.all_names.clear();
11111129
}
11121130
self.all_origin = Some(origin);
@@ -1448,6 +1466,7 @@ mod tests {
14481466
use ruff_db::Db;
14491467
use ruff_db::files::{FileRootKind, system_path_to_file};
14501468
use ruff_db::system::{DbWithWritableSystem, SystemPath, SystemPathBuf};
1469+
use ruff_python_ast::PythonVersion;
14511470
use ruff_python_trivia::textwrap::dedent;
14521471
use ty_project::{ProjectMetadata, TestDb};
14531472

@@ -2713,6 +2732,79 @@ class X:
27132732
);
27142733
}
27152734

2735+
/// Tests that a work-around which unions `__all__` values lets
2736+
/// us find conditionally exported symbols.
2737+
///
2738+
/// However, this also means that we may suggest exported symbols
2739+
/// even when they aren't available. What we should ideally do is
2740+
/// evaluate the conditional like ty does, but this requires some
2741+
/// work. (And it's unlikely auto-import will ever get the full
2742+
/// evaluation capabilities as ty, so it's likely this sort of
2743+
/// union work-around will always be something we do in at least
2744+
/// some cases.)
2745+
#[test]
2746+
fn union_all_to_work_around_conditional_symbols_py311() {
2747+
let test = PublicTestBuilder::default()
2748+
.python_version(PythonVersion::PY311)
2749+
.source(
2750+
"test.py",
2751+
"
2752+
import sys
2753+
if sys.version_info >= (3, 11):
2754+
ZQZQZQ = 1
2755+
else:
2756+
ZYZYZY = 1
2757+
2758+
if sys.version_info >= (3, 11):
2759+
__all__ = ['ZQZQZQ']
2760+
else:
2761+
__all__ = ['ZYZYZY']
2762+
",
2763+
)
2764+
.build();
2765+
// Ideally this would only have `ZQZQZQ`.
2766+
insta::assert_snapshot!(
2767+
test.exports(),
2768+
@r"
2769+
ZQZQZQ :: Constant
2770+
ZYZYZY :: Constant
2771+
",
2772+
);
2773+
}
2774+
2775+
/// Like `union_all_to_work_around_conditional_symbols_py311`, but
2776+
/// sets the environment Python version to 3.10 so that the conditional
2777+
/// should evaluate to false.
2778+
#[test]
2779+
fn union_all_to_work_around_conditional_symbols_py310() {
2780+
let test = PublicTestBuilder::default()
2781+
.python_version(PythonVersion::PY310)
2782+
.source(
2783+
"test.py",
2784+
"
2785+
import sys
2786+
if sys.version_info >= (3, 11):
2787+
ZQZQZQ = 1
2788+
else:
2789+
ZYZYZY = 1
2790+
2791+
if sys.version_info >= (3, 11):
2792+
__all__ = ['ZQZQZQ']
2793+
else:
2794+
__all__ = ['ZYZYZY']
2795+
",
2796+
)
2797+
.build();
2798+
// Ideally this would only have `ZYZYZY`.
2799+
insta::assert_snapshot!(
2800+
test.exports(),
2801+
@r"
2802+
ZQZQZQ :: Constant
2803+
ZYZYZY :: Constant
2804+
",
2805+
);
2806+
}
2807+
27162808
#[test]
27172809
fn deprecated_function() {
27182810
let test = public_test(
@@ -2848,16 +2940,17 @@ class C: ...
28482940
/// A list of source files, corresponding to the
28492941
/// file's path and its contents.
28502942
sources: Vec<Source>,
2943+
/// The python version to use.
2944+
python_version: Option<PythonVersion>,
28512945
}
28522946

28532947
impl PublicTestBuilder {
28542948
pub(super) fn build(&self) -> PublicTest {
2855-
let mut db = TestDb::new(ProjectMetadata::new(
2856-
"test".into(),
2857-
SystemPathBuf::from("/"),
2858-
));
2949+
let metadata = ProjectMetadata::new("test".into(), SystemPathBuf::from("/"));
2950+
let mut db = TestDb::new(metadata);
28592951

2860-
db.init_program().unwrap();
2952+
db.init_program_with_python_version(self.python_version.unwrap_or_default())
2953+
.unwrap();
28612954

28622955
for Source { path, contents } in &self.sources {
28632956
db.write_file(path, contents)
@@ -2897,6 +2990,11 @@ class C: ...
28972990
self.sources.push(Source { path, contents });
28982991
self
28992992
}
2993+
2994+
pub(super) fn python_version(&mut self, version: PythonVersion) -> &mut PublicTestBuilder {
2995+
self.python_version = Some(version);
2996+
self
2997+
}
29002998
}
29012999

29023000
struct Source {

crates/ty_project/src/db.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,7 @@ pub(crate) mod tests {
542542
use ruff_db::files::{FileRootKind, Files};
543543
use ruff_db::system::{DbWithTestSystem, System, TestSystem};
544544
use ruff_db::vendored::VendoredFileSystem;
545+
use ruff_python_ast::PythonVersion;
545546
use ty_module_resolver::SearchPathSettings;
546547
use ty_python_semantic::lint::{LintRegistry, RuleSelection};
547548
use ty_python_semantic::{
@@ -588,6 +589,13 @@ pub(crate) mod tests {
588589
}
589590

590591
pub fn init_program(&mut self) -> anyhow::Result<()> {
592+
self.init_program_with_python_version(PythonVersion::latest_ty())
593+
}
594+
595+
pub fn init_program_with_python_version(
596+
&mut self,
597+
python_version: PythonVersion,
598+
) -> anyhow::Result<()> {
591599
let root = self.project().root(self);
592600

593601
let search_paths = SearchPathSettings::new(vec![root.to_path_buf()])
@@ -597,7 +605,10 @@ pub(crate) mod tests {
597605
Program::from_settings(
598606
self,
599607
ProgramSettings {
600-
python_version: PythonVersionWithSource::default(),
608+
python_version: PythonVersionWithSource {
609+
source: ty_python_semantic::PythonVersionSource::Default,
610+
version: python_version,
611+
},
601612
python_platform: PythonPlatform::default(),
602613
search_paths,
603614
},

0 commit comments

Comments
 (0)