Skip to content

Commit 6c56a7a

Browse files
TomerBincarljm
andauthored
[red-knot] Implement type narrowing for boolean conditionals (#14037)
## Summary This PR enables red-knot to support type narrowing based on `and` and `or` conditionals, including nested combinations and their negation (for `elif` / `else` blocks and for `not` operator). Part of #13694. In order to address this properly (hopefully 😅), I had to run `NarrowingConstraintsBuilder` functions recursively. In the first commit I introduced a minor refactor - instead of mutating `self.constraints`, the new constraints are now returned as function return values. I also modified the constraints map to be optional, preventing unnecessary hashmap allocations. Thanks @carljm for your support on this :) The second commit contains the logic and tests for handling boolean ops, with improvements to intersections handling in `is_subtype_of` . As I'm still new to Rust and the internals of type checkers, I’d be more than happy to hear any insights or suggestions. Thank you! --------- Co-authored-by: Carl Meyer <[email protected]>
1 parent bb25bd9 commit 6c56a7a

File tree

4 files changed

+591
-59
lines changed

4 files changed

+591
-59
lines changed
Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
# Narrowing for conditionals with boolean expressions
2+
3+
## Narrowing in `and` conditional
4+
5+
```py
6+
class A: ...
7+
class B: ...
8+
9+
def instance() -> A | B:
10+
return A()
11+
12+
x = instance()
13+
14+
if isinstance(x, A) and isinstance(x, B):
15+
reveal_type(x) # revealed: A & B
16+
else:
17+
reveal_type(x) # revealed: B & ~A | A & ~B
18+
```
19+
20+
## Arms might not add narrowing constraints
21+
22+
```py
23+
class A: ...
24+
class B: ...
25+
26+
def bool_instance() -> bool:
27+
return True
28+
29+
def instance() -> A | B:
30+
return A()
31+
32+
x = instance()
33+
34+
if isinstance(x, A) and bool_instance():
35+
reveal_type(x) # revealed: A
36+
else:
37+
reveal_type(x) # revealed: A | B
38+
39+
if bool_instance() and isinstance(x, A):
40+
reveal_type(x) # revealed: A
41+
else:
42+
reveal_type(x) # revealed: A | B
43+
44+
reveal_type(x) # revealed: A | B
45+
```
46+
47+
## Statically known arms
48+
49+
```py
50+
class A: ...
51+
class B: ...
52+
53+
def instance() -> A | B:
54+
return A()
55+
56+
x = instance()
57+
58+
if isinstance(x, A) and True:
59+
reveal_type(x) # revealed: A
60+
else:
61+
reveal_type(x) # revealed: B & ~A
62+
63+
if True and isinstance(x, A):
64+
reveal_type(x) # revealed: A
65+
else:
66+
reveal_type(x) # revealed: B & ~A
67+
68+
if False and isinstance(x, A):
69+
# TODO: should emit an `unreachable code` diagnostic
70+
reveal_type(x) # revealed: A
71+
else:
72+
reveal_type(x) # revealed: A | B
73+
74+
if False or isinstance(x, A):
75+
reveal_type(x) # revealed: A
76+
else:
77+
reveal_type(x) # revealed: B & ~A
78+
79+
if True or isinstance(x, A):
80+
reveal_type(x) # revealed: A | B
81+
else:
82+
# TODO: should emit an `unreachable code` diagnostic
83+
reveal_type(x) # revealed: B & ~A
84+
85+
reveal_type(x) # revealed: A | B
86+
```
87+
88+
## The type of multiple symbols can be narrowed down
89+
90+
```py
91+
class A: ...
92+
class B: ...
93+
94+
def instance() -> A | B:
95+
return A()
96+
97+
x = instance()
98+
y = instance()
99+
100+
if isinstance(x, A) and isinstance(y, B):
101+
reveal_type(x) # revealed: A
102+
reveal_type(y) # revealed: B
103+
else:
104+
# No narrowing: Only-one or both checks might have failed
105+
reveal_type(x) # revealed: A | B
106+
reveal_type(y) # revealed: A | B
107+
108+
reveal_type(x) # revealed: A | B
109+
reveal_type(y) # revealed: A | B
110+
```
111+
112+
## Narrowing in `or` conditional
113+
114+
```py
115+
class A: ...
116+
class B: ...
117+
class C: ...
118+
119+
def instance() -> A | B | C:
120+
return A()
121+
122+
x = instance()
123+
124+
if isinstance(x, A) or isinstance(x, B):
125+
reveal_type(x) # revealed: A | B
126+
else:
127+
reveal_type(x) # revealed: C & ~A & ~B
128+
```
129+
130+
## In `or`, all arms should add constraint in order to narrow
131+
132+
```py
133+
class A: ...
134+
class B: ...
135+
class C: ...
136+
137+
def instance() -> A | B | C:
138+
return A()
139+
140+
def bool_instance() -> bool:
141+
return True
142+
143+
x = instance()
144+
145+
if isinstance(x, A) or isinstance(x, B) or bool_instance():
146+
reveal_type(x) # revealed: A | B | C
147+
else:
148+
reveal_type(x) # revealed: C & ~A & ~B
149+
```
150+
151+
## in `or`, all arms should narrow the same set of symbols
152+
153+
```py
154+
class A: ...
155+
class B: ...
156+
class C: ...
157+
158+
def instance() -> A | B | C:
159+
return A()
160+
161+
x = instance()
162+
y = instance()
163+
164+
if isinstance(x, A) or isinstance(y, A):
165+
# The predicate might be satisfied by the right side, so the type of `x` can’t be narrowed down here.
166+
reveal_type(x) # revealed: A | B | C
167+
# The same for `y`
168+
reveal_type(y) # revealed: A | B | C
169+
else:
170+
reveal_type(x) # revealed: B & ~A | C & ~A
171+
reveal_type(y) # revealed: B & ~A | C & ~A
172+
173+
if (isinstance(x, A) and isinstance(y, A)) or (isinstance(x, B) and isinstance(y, B)):
174+
# Here, types of `x` and `y` can be narrowd since all `or` arms constraint them.
175+
reveal_type(x) # revealed: A | B
176+
reveal_type(y) # revealed: A | B
177+
else:
178+
reveal_type(x) # revealed: A | B | C
179+
reveal_type(y) # revealed: A | B | C
180+
```
181+
182+
## mixing `and` and `not`
183+
184+
```py
185+
class A: ...
186+
class B: ...
187+
class C: ...
188+
189+
def instance() -> A | B | C:
190+
return A()
191+
192+
x = instance()
193+
194+
if isinstance(x, B) and not isinstance(x, C):
195+
reveal_type(x) # revealed: B & ~C
196+
else:
197+
# ~(B & ~C) -> ~B | C -> (A & ~B) | (C & ~B) | C -> (A & ~B) | C
198+
reveal_type(x) # revealed: A & ~B | C
199+
```
200+
201+
## mixing `or` and `not`
202+
203+
```py
204+
class A: ...
205+
class B: ...
206+
class C: ...
207+
208+
def instance() -> A | B | C:
209+
return A()
210+
211+
x = instance()
212+
213+
if isinstance(x, B) or not isinstance(x, C):
214+
reveal_type(x) # revealed: B | A & ~C
215+
else:
216+
reveal_type(x) # revealed: C & ~B
217+
```
218+
219+
## `or` with nested `and`
220+
221+
```py
222+
class A: ...
223+
class B: ...
224+
class C: ...
225+
226+
def instance() -> A | B | C:
227+
return A()
228+
229+
x = instance()
230+
231+
if isinstance(x, A) or (isinstance(x, B) and not isinstance(x, C)):
232+
reveal_type(x) # revealed: A | B & ~C
233+
else:
234+
# ~(A | (B & ~C)) -> ~A & ~(B & ~C) -> ~A & (~B | C) -> (~A & C) | (~A ~ B)
235+
reveal_type(x) # revealed: C & ~A
236+
```
237+
238+
## `and` with nested `or`
239+
240+
```py
241+
class A: ...
242+
class B: ...
243+
class C: ...
244+
245+
def instance() -> A | B | C:
246+
return A()
247+
248+
x = instance()
249+
250+
if isinstance(x, A) and (isinstance(x, B) or not isinstance(x, C)):
251+
# A & (B | ~C) -> (A & B) | (A & ~C)
252+
reveal_type(x) # revealed: A & B | A & ~C
253+
else:
254+
# ~((A & B) | (A & ~C)) ->
255+
# ~(A & B) & ~(A & ~C) ->
256+
# (~A | ~B) & (~A | C) ->
257+
# [(~A | ~B) & ~A] | [(~A | ~B) & C] ->
258+
# ~A | (~A & C) | (~B & C) ->
259+
# ~A | (C & ~B) ->
260+
# ~A | (C & ~B) The positive side of ~A is A | B | C ->
261+
reveal_type(x) # revealed: B & ~A | C & ~A | C & ~B
262+
```
263+
264+
## Boolean expression internal narrowing
265+
266+
```py
267+
def optional_string() -> str | None:
268+
return None
269+
270+
x = optional_string()
271+
y = optional_string()
272+
273+
if x is None and y is not x:
274+
reveal_type(y) # revealed: str
275+
276+
# Neither of the conditions alone is sufficient for narrowing y's type:
277+
if x is None:
278+
reveal_type(y) # revealed: str | None
279+
280+
if y is not x:
281+
reveal_type(y) # revealed: str | None
282+
```

crates/red_knot_python_semantic/src/types.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,46 @@ impl<'db> Type<'db> {
528528
.elements(db)
529529
.iter()
530530
.any(|&elem_ty| ty.is_subtype_of(db, elem_ty)),
531+
(Type::Intersection(self_intersection), Type::Intersection(target_intersection)) => {
532+
// Check that all target positive values are covered in self positive values
533+
target_intersection
534+
.positive(db)
535+
.iter()
536+
.all(|&target_pos_elem| {
537+
self_intersection
538+
.positive(db)
539+
.iter()
540+
.any(|&self_pos_elem| self_pos_elem.is_subtype_of(db, target_pos_elem))
541+
})
542+
// Check that all target negative values are excluded in self, either by being
543+
// subtypes of a self negative value or being disjoint from a self positive value.
544+
&& target_intersection
545+
.negative(db)
546+
.iter()
547+
.all(|&target_neg_elem| {
548+
// Is target negative value is subtype of a self negative value
549+
self_intersection.negative(db).iter().any(|&self_neg_elem| {
550+
target_neg_elem.is_subtype_of(db, self_neg_elem)
551+
// Is target negative value is disjoint from a self positive value?
552+
}) || self_intersection.positive(db).iter().any(|&self_pos_elem| {
553+
target_neg_elem.is_disjoint_from(db, self_pos_elem)
554+
})
555+
})
556+
}
557+
(Type::Intersection(intersection), ty) => intersection
558+
.positive(db)
559+
.iter()
560+
.any(|&elem_ty| elem_ty.is_subtype_of(db, ty)),
561+
(ty, Type::Intersection(intersection)) => {
562+
intersection
563+
.positive(db)
564+
.iter()
565+
.all(|&pos_ty| ty.is_subtype_of(db, pos_ty))
566+
&& intersection
567+
.negative(db)
568+
.iter()
569+
.all(|&neg_ty| neg_ty.is_disjoint_from(db, ty))
570+
}
531571
(Type::Instance(self_class), Type::Instance(target_class)) => {
532572
self_class.is_subclass_of(db, target_class)
533573
}
@@ -2190,6 +2230,11 @@ mod tests {
21902230
Ty::BuiltinInstance("FloatingPointError"),
21912231
Ty::BuiltinInstance("Exception")
21922232
)]
2233+
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]}, Ty::BuiltinInstance("int"))]
2234+
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})]
2235+
#[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::BuiltinInstance("int")]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})]
2236+
#[test_case(Ty::IntLiteral(1), Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]})]
2237+
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("str")], neg: vec![Ty::StringLiteral("foo")]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})]
21932238
fn is_subtype_of(from: Ty, to: Ty) {
21942239
let db = setup_db();
21952240
assert!(from.into_type(&db).is_subtype_of(&db, to.into_type(&db)));
@@ -2210,6 +2255,11 @@ mod tests {
22102255
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(42)]), Ty::Tuple(vec![Ty::BuiltinInstance("str")]))]
22112256
#[test_case(Ty::Tuple(vec![Ty::Todo]), Ty::Tuple(vec![Ty::IntLiteral(2)]))]
22122257
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(2)]), Ty::Tuple(vec![Ty::Todo]))]
2258+
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(3)]})]
2259+
#[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(3)]})]
2260+
#[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![], neg: vec![Ty::BuiltinInstance("int")]})]
2261+
#[test_case(Ty::BuiltinInstance("int"), Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(3)]})]
2262+
#[test_case(Ty::IntLiteral(1), Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(1)]})]
22132263
fn is_not_subtype_of(from: Ty, to: Ty) {
22142264
let db = setup_db();
22152265
assert!(!from.into_type(&db).is_subtype_of(&db, to.into_type(&db)));
@@ -2241,6 +2291,34 @@ mod tests {
22412291
assert!(type_u.is_subtype_of(&db, Ty::BuiltinInstance("object").into_type(&db)));
22422292
}
22432293

2294+
#[test]
2295+
fn is_subtype_of_intersection_of_class_instances() {
2296+
let mut db = setup_db();
2297+
db.write_dedented(
2298+
"/src/module.py",
2299+
"
2300+
class A: ...
2301+
a = A()
2302+
class B: ...
2303+
b = B()
2304+
",
2305+
)
2306+
.unwrap();
2307+
let module = ruff_db::files::system_path_to_file(&db, "/src/module.py").unwrap();
2308+
2309+
let a_ty = super::global_symbol(&db, module, "a").expect_type();
2310+
let b_ty = super::global_symbol(&db, module, "b").expect_type();
2311+
let intersection = IntersectionBuilder::new(&db)
2312+
.add_positive(a_ty)
2313+
.add_positive(b_ty)
2314+
.build();
2315+
2316+
assert_eq!(intersection.display(&db).to_string(), "A & B");
2317+
assert!(!a_ty.is_subtype_of(&db, b_ty));
2318+
assert!(intersection.is_subtype_of(&db, b_ty));
2319+
assert!(intersection.is_subtype_of(&db, a_ty));
2320+
}
2321+
22442322
#[test_case(
22452323
Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]),
22462324
Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)])

0 commit comments

Comments
 (0)