Skip to content

Commit 70b058e

Browse files
asukaminato0721meta-codesync[bot]
authored andcommitted
fix narrowing on dict (and typed dict) .get (#1390)
Summary: fix #238 Taught the narrowing logic to treat dict.get("literal") calls as key facets, so truthiness and is not None checks now refine the corresponding dictionary entry Added regression coverage for both explicit is not None and plain truthiness checks on literal-key .get calls to confirm the dictionary entry narrows as expected Pull Request resolved: #1390 Reviewed By: stroxler Differential Revision: D85959099 Pulled By: yangdanny97 fbshipit-source-id: a7d1b4f5983f1f21486a837f80e0001e85128adf
1 parent f476808 commit 70b058e

File tree

5 files changed

+276
-26
lines changed

5 files changed

+276
-26
lines changed

crates/pyrefly_types/src/types.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,10 @@ impl Type {
888888
matches!(self, Type::Any(_))
889889
}
890890

891+
pub fn is_typed_dict(&self) -> bool {
892+
matches!(self, Type::TypedDict(_) | Type::PartialTypedDict(_))
893+
}
894+
891895
pub fn is_error(&self) -> bool {
892896
matches!(self, Type::Any(AnyStyle::Error))
893897
}

pyrefly/lib/alt/expr.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ use ruff_python_ast::ExprTuple;
3838
use ruff_python_ast::Identifier;
3939
use ruff_python_ast::Keyword;
4040
use ruff_python_ast::Number;
41+
use ruff_python_ast::StringLiteralValue;
4142
use ruff_python_ast::name::Name;
4243
use ruff_text_size::Ranged;
4344
use ruff_text_size::TextRange;
@@ -542,6 +543,14 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
542543
let callee_ty = self.expr_infer(&x.func, errors);
543544
if let Some(d) = self.call_to_dict(&callee_ty, &x.arguments) {
544545
self.dict_infer(&d, hint, x.range, errors)
546+
} else if let Some((obj_ty, key)) =
547+
self.is_dict_get_with_literal(&x.func, &x.arguments, errors)
548+
{
549+
obj_ty
550+
.at_facet(&FacetKind::Key(key.to_string()), || {
551+
self.expr_call_infer(x, callee_ty.clone(), hint, errors)
552+
})
553+
.into_ty()
545554
} else {
546555
self.expr_call_infer(x, callee_ty, hint, errors)
547556
}
@@ -971,6 +980,48 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
971980
}))
972981
}
973982

983+
// Is this a call to `dict.get` with a single string literal argument
984+
fn is_dict_get_with_literal(
985+
&self,
986+
func: &Expr,
987+
args: &Arguments,
988+
errors: &ErrorCollector,
989+
) -> Option<(TypeInfo, StringLiteralValue)> {
990+
let Expr::Attribute(attr_expr) = func else {
991+
return None;
992+
};
993+
if attr_expr.attr.id.as_str() != "get" {
994+
return None;
995+
}
996+
if args.args.len() != 1 {
997+
return None;
998+
}
999+
let Expr::StringLiteral(ExprStringLiteral { value: key, .. }) = &args.args[0] else {
1000+
return None;
1001+
};
1002+
let obj_ty = self.expr_infer_type_info_with_hint(&attr_expr.value, None, errors);
1003+
if self.is_dict_like(obj_ty.ty()) {
1004+
Some((obj_ty, key.clone()))
1005+
} else {
1006+
None
1007+
}
1008+
}
1009+
1010+
// Is this type a `TypedDict` or subtype of `dict`, but not `Any`?
1011+
pub fn is_dict_like(&self, ty: &Type) -> bool {
1012+
if ty.is_any() {
1013+
return false;
1014+
}
1015+
if ty.is_typed_dict() {
1016+
return true;
1017+
}
1018+
let dict_type = self
1019+
.stdlib
1020+
.dict(Type::any_implicit(), Type::any_implicit())
1021+
.to_type();
1022+
self.is_subset_eq(ty, &dict_type)
1023+
}
1024+
9741025
/// Determine the boolean behavior of a type:
9751026
/// - `Some(true)` or `Some(false)` when it is known to be statically truthy
9761027
/// or falsey (as determined by some baked in rules for literals

pyrefly/lib/alt/narrow.rs

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ use crate::alt::answers_solver::AnswersSolver;
2626
use crate::alt::callable::CallArg;
2727
use crate::alt::callable::CallKeyword;
2828
use crate::binding::narrow::AtomicNarrowOp;
29+
use crate::binding::narrow::FacetOrigin;
30+
use crate::binding::narrow::FacetSubject;
2931
use crate::binding::narrow::NarrowOp;
3032
use crate::error::collector::ErrorCollector;
3133
use crate::error::style::ErrorStyle;
@@ -821,15 +823,17 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
821823
match op {
822824
NarrowOp::Atomic(subject, AtomicNarrowOp::HasAttr(attr)) => {
823825
let base_ty = match subject {
824-
Some(facet_chain) => self.get_facet_chain_type(type_info, facet_chain, range),
826+
Some(facet_subject) => {
827+
self.get_facet_chain_type(type_info, &facet_subject.chain, range)
828+
}
825829
None => type_info.ty().clone(),
826830
};
827831
// We only narrow the attribute to `Any` if the attribute does not exist
828832
if !self.has_attr(&base_ty, attr) {
829833
let attr_facet = FacetKind::Attribute(attr.clone());
830834
let facets = match subject {
831-
Some(chain) => {
832-
let mut new_facets = chain.facets().clone();
835+
Some(facet_subject) => {
836+
let mut new_facets = facet_subject.chain.facets().clone();
833837
new_facets.push(attr_facet);
834838
new_facets
835839
}
@@ -851,15 +855,17 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
851855
return type_info.clone();
852856
}
853857
let base_ty = match subject {
854-
Some(facet_chain) => self.get_facet_chain_type(type_info, facet_chain, range),
858+
Some(facet_subject) => {
859+
self.get_facet_chain_type(type_info, &facet_subject.chain, range)
860+
}
855861
None => type_info.ty().clone(),
856862
};
857863
let attr_ty =
858864
self.attr_infer_for_type(&base_ty, attr, range, &suppress_errors, None);
859865
let attr_facet = FacetKind::Attribute(attr.clone());
860866
let facets = match subject {
861-
Some(chain) => {
862-
let mut new_facets = chain.facets().clone();
867+
Some(facet_subject) => {
868+
let mut new_facets = facet_subject.chain.facets().clone();
863869
new_facets.push(attr_facet);
864870
new_facets
865871
}
@@ -884,16 +890,21 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
884890
let ty = self.atomic_narrow(type_info.ty(), op, range, errors);
885891
type_info.clone().with_ty(ty)
886892
}
887-
NarrowOp::Atomic(Some(facet_chain), op) => {
893+
NarrowOp::Atomic(Some(facet_subject), op) => {
894+
if facet_subject.origin == FacetOrigin::GetMethod
895+
&& !self.supports_dict_get_subject(type_info, facet_subject, range)
896+
{
897+
return type_info.clone();
898+
}
888899
let ty = self.atomic_narrow(
889-
&self.get_facet_chain_type(type_info, facet_chain, range),
900+
&self.get_facet_chain_type(type_info, &facet_subject.chain, range),
890901
op,
891902
range,
892903
errors,
893904
);
894-
let mut narrowed = type_info.with_narrow(facet_chain.facets(), ty);
905+
let mut narrowed = type_info.with_narrow(facet_subject.chain.facets(), ty);
895906
// For certain types of narrows, we can also narrow the parent of the current subject
896-
if let Some((last, prefix)) = facet_chain.facets().split_last() {
907+
if let Some((last, prefix)) = facet_subject.chain.facets().split_last() {
897908
match Vec1::try_from(prefix) {
898909
Ok(prefix_facets) => {
899910
let prefix_chain = FacetChain::new(prefix_facets);
@@ -939,4 +950,32 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
939950
),
940951
}
941952
}
953+
954+
/// We only narrow `x.get("key")` if `x` resolves to a `dict`
955+
fn supports_dict_get_subject(
956+
&self,
957+
type_info: &TypeInfo,
958+
subject: &FacetSubject,
959+
range: TextRange,
960+
) -> bool {
961+
let base_ty = if subject.chain.facets().len() == 1 {
962+
type_info.ty().clone()
963+
} else {
964+
let prefix: Vec<_> = subject
965+
.chain
966+
.facets()
967+
.iter()
968+
.take(subject.chain.facets().len() - 1)
969+
.cloned()
970+
.collect();
971+
match Vec1::try_from_vec(prefix) {
972+
Ok(vec1) => {
973+
let prefix_chain = FacetChain::new(vec1);
974+
self.get_facet_chain_type(type_info, &prefix_chain, range)
975+
}
976+
Err(_) => return false,
977+
}
978+
};
979+
self.is_dict_like(&base_ty)
980+
}
942981
}

pyrefly/lib/binding/narrow.rs

Lines changed: 93 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ use crate::types::facet::FacetKind;
4444
use crate::types::types::Type;
4545

4646
assert_words!(AtomicNarrowOp, 11);
47-
assert_words!(NarrowOp, 12);
47+
assert_words!(NarrowOp, 13);
4848

4949
#[derive(Clone, Debug)]
5050
pub enum AtomicNarrowOp {
@@ -92,7 +92,7 @@ pub enum AtomicNarrowOp {
9292

9393
#[derive(Clone, Debug)]
9494
pub enum NarrowOp {
95-
Atomic(Option<FacetChain>, AtomicNarrowOp),
95+
Atomic(Option<FacetSubject>, AtomicNarrowOp),
9696
And(Vec<NarrowOp>),
9797
Or(Vec<NarrowOp>),
9898
}
@@ -176,7 +176,7 @@ impl DisplayWith<ModuleInfo> for NarrowOp {
176176
match self {
177177
Self::Atomic(prop, op) => match prop {
178178
None => write!(f, "{}", op.display_with(ctx)),
179-
Some(prop) => write!(f, "[{prop}] {}", op.display_with(ctx)),
179+
Some(prop) => write!(f, "[{}] {}", prop.chain, op.display_with(ctx)),
180180
},
181181
Self::And(ops) => {
182182
write!(
@@ -234,19 +234,45 @@ impl AtomicNarrowOp {
234234
}
235235
}
236236

237+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
238+
pub enum FacetOrigin {
239+
// This facet is a direct access, like `x.y`, `x[0]`, `x["key"]`
240+
Direct,
241+
// This facet came from a call to a `get` method, like `x.get("key")`
242+
GetMethod,
243+
}
244+
245+
#[derive(Clone, Debug)]
246+
pub struct FacetSubject {
247+
pub chain: FacetChain,
248+
pub origin: FacetOrigin,
249+
}
250+
237251
#[derive(Clone, Debug)]
238252
pub enum NarrowingSubject {
239253
Name(Name),
240-
Facets(Name, FacetChain),
254+
Facets(Name, FacetSubject),
241255
}
242256

243257
impl NarrowingSubject {
244258
pub fn with_facet(&self, prop: FacetKind) -> Self {
245259
match self {
246-
Self::Name(name) => Self::Facets(name.clone(), FacetChain::new(Vec1::new(prop))),
247-
Self::Facets(name, props) => {
248-
let props = Vec1::from_vec_push(props.facets().to_vec(), prop);
249-
Self::Facets(name.clone(), FacetChain::new(props))
260+
Self::Name(name) => Self::Facets(
261+
name.clone(),
262+
FacetSubject {
263+
chain: FacetChain::new(Vec1::new(prop)),
264+
origin: FacetOrigin::Direct,
265+
},
266+
),
267+
Self::Facets(name, facets) => {
268+
let props = Vec1::from_vec_push(facets.chain.facets().to_vec(), prop);
269+
Self::Facets(
270+
name.clone(),
271+
FacetSubject {
272+
chain: FacetChain::new(props),
273+
origin: facets.origin,
274+
},
275+
)
250276
}
251277
}
252278
}
@@ -372,7 +398,7 @@ impl NarrowOps {
372398
for subject in expr_to_subjects(left) {
373399
let (name, prop) = match subject {
374400
NarrowingSubject::Name(name) => (name, None),
375-
NarrowingSubject::Facets(name, prop) => (name, Some(prop)),
401+
NarrowingSubject::Facets(name, facets) => (name, Some(facets)),
376402
};
377403
if let Some((existing, _)) = narrow_ops.0.get_mut(&name) {
378404
existing.and(NarrowOp::Atomic(prop, op.clone()));
@@ -393,7 +419,7 @@ impl NarrowOps {
393419
let mut narrow_ops = Self::new();
394420
let (name, prop) = match subject {
395421
NarrowingSubject::Name(name) => (name, None),
396-
NarrowingSubject::Facets(name, prop) => (name, Some(prop)),
422+
NarrowingSubject::Facets(name, facets) => (name, Some(facets)),
397423
};
398424
if let Some((existing, _)) = narrow_ops.0.get_mut(&name) {
399425
existing.and(NarrowOp::Atomic(prop, op.clone()));
@@ -587,6 +613,12 @@ impl NarrowOps {
587613
*range,
588614
)
589615
}
616+
Some(e @ Expr::Call(call)) if dict_get_subject_for_call_expr(call).is_some() => {
617+
// When the guard is something like `x.get("key")`, we narrow it like `x["key"]` if `x` resolves to a dict
618+
// in the answers step.
619+
// This cannot be a TypeGuard/TypeIs function call, since the first argument is a string literal
620+
Self::from_single_narrow_op(e, AtomicNarrowOp::IsTruthy, e.range())
621+
}
590622
Some(Expr::Call(ExprCall {
591623
node_index: _,
592624
range,
@@ -770,16 +802,63 @@ pub fn identifier_and_chain_prefix_for_expr(expr: &Expr) -> Option<(Identifier,
770802
f(expr, Vec::new())
771803
}
772804

773-
fn subject_for_expr(expr: &Expr) -> Option<NarrowingSubject> {
774-
identifier_and_chain_for_expr(expr)
775-
.map(|(identifier, attr)| NarrowingSubject::Facets(identifier.id, attr))
805+
// Handle narrowing on `dict.get("key")`. During solving, if the resolved
806+
// type of the object is not a subtype of `dict`, we will not perform any narrowing.
807+
fn dict_get_subject_for_call_expr(call_expr: &ExprCall) -> Option<NarrowingSubject> {
808+
let func = &call_expr.func;
809+
let arguments = &call_expr.arguments;
810+
if arguments.keywords.is_empty()
811+
&& arguments.args.len() == 1
812+
&& let Some(first_arg) = arguments.args.first()
813+
&& let Expr::Attribute(attr) = &**func
814+
&& attr.attr.id.as_str() == "get"
815+
&& let Expr::StringLiteral(ExprStringLiteral { value, .. }) = first_arg
816+
{
817+
let key = value.to_string();
818+
if let Some((identifier, facets)) = identifier_and_chain_for_expr(&attr.value) {
819+
// x.y.z.get("key")
820+
let props = Vec1::from_vec_push(facets.facets().to_vec(), FacetKind::Key(key.clone()));
821+
return Some(NarrowingSubject::Facets(
822+
identifier.id,
823+
FacetSubject {
824+
chain: FacetChain::new(props),
825+
origin: FacetOrigin::GetMethod,
826+
},
827+
));
828+
} else if let Expr::Name(name) = &*attr.value {
829+
// x.get("key")
830+
return Some(NarrowingSubject::Facets(
831+
name.id.clone(),
832+
FacetSubject {
833+
chain: FacetChain::new(Vec1::new(FacetKind::Key(key))),
834+
origin: FacetOrigin::GetMethod,
835+
},
836+
));
837+
}
838+
}
839+
None
776840
}
777841

778842
pub fn expr_to_subjects(expr: &Expr) -> Vec<NarrowingSubject> {
779843
fn f(expr: &Expr, res: &mut Vec<NarrowingSubject>) {
780844
match expr {
781845
Expr::Name(name) => res.push(NarrowingSubject::Name(name.id.clone())),
782-
Expr::Attribute(_) | Expr::Subscript(_) => res.extend(subject_for_expr(expr)),
846+
Expr::Attribute(_) | Expr::Subscript(_) => {
847+
if let Some((identifier, facets)) = identifier_and_chain_for_expr(expr) {
848+
res.push(NarrowingSubject::Facets(
849+
identifier.id,
850+
FacetSubject {
851+
chain: facets,
852+
origin: FacetOrigin::Direct,
853+
},
854+
));
855+
}
856+
}
857+
Expr::Call(call) => {
858+
if let Some(subject) = dict_get_subject_for_call_expr(call) {
859+
res.push(subject);
860+
}
861+
}
783862
Expr::Named(ExprNamed { target, value, .. }) => {
784863
f(target, res);
785864
f(value, res);

0 commit comments

Comments
 (0)