Skip to content

Commit c433e88

Browse files
committed
Fix nontransitivity pointed out by property tests
1 parent cfa122a commit c433e88

File tree

1 file changed

+66
-27
lines changed
  • crates/red_knot_python_semantic/src

1 file changed

+66
-27
lines changed

crates/red_knot_python_semantic/src/types.rs

Lines changed: 66 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -655,10 +655,12 @@ impl<'db> Type<'db> {
655655
.elements(db)
656656
.iter()
657657
.all(|&elem_ty| elem_ty.is_subtype_of(db, target)),
658+
658659
(_, Type::Union(union)) => union
659660
.elements(db)
660661
.iter()
661662
.any(|&elem_ty| self.is_subtype_of(db, elem_ty)),
663+
662664
(Type::Intersection(self_intersection), Type::Intersection(target_intersection)) => {
663665
// Check that all target positive values are covered in self positive values
664666
target_intersection
@@ -685,10 +687,12 @@ impl<'db> Type<'db> {
685687
})
686688
})
687689
}
690+
688691
(Type::Intersection(intersection), _) => intersection
689692
.positive(db)
690693
.iter()
691694
.any(|&elem_ty| elem_ty.is_subtype_of(db, target)),
695+
692696
(_, Type::Intersection(intersection)) => {
693697
intersection
694698
.positive(db)
@@ -1006,6 +1010,8 @@ impl<'db> Type<'db> {
10061010
}
10071011
}
10081012

1013+
// any single-valued type is disjoint from another single-valued type
1014+
// iff the two types are nonequal
10091015
(
10101016
left @ (Type::BooleanLiteral(..)
10111017
| Type::IntLiteral(..)
@@ -1014,17 +1020,47 @@ impl<'db> Type<'db> {
10141020
| Type::SliceLiteral(..)
10151021
| Type::FunctionLiteral(..)
10161022
| Type::ModuleLiteral(..)
1017-
| Type::ClassLiteral(..)),
1023+
| Type::ClassLiteral(..)
1024+
| Type::KnownInstance(..)),
10181025
right @ (Type::BooleanLiteral(..)
10191026
| Type::IntLiteral(..)
10201027
| Type::StringLiteral(..)
10211028
| Type::BytesLiteral(..)
10221029
| Type::SliceLiteral(..)
10231030
| Type::FunctionLiteral(..)
10241031
| Type::ModuleLiteral(..)
1025-
| Type::ClassLiteral(..)),
1032+
| Type::ClassLiteral(..)
1033+
| Type::KnownInstance(..)),
10261034
) => left != right,
10271035

1036+
// One tuple type can be a subtype of another tuple type,
1037+
// but we know for sure that any given tuple type is disjoint from all single-valued types
1038+
(
1039+
Type::Tuple(..),
1040+
Type::ClassLiteral(..)
1041+
| Type::ModuleLiteral(..)
1042+
| Type::BooleanLiteral(..)
1043+
| Type::BytesLiteral(..)
1044+
| Type::FunctionLiteral(..)
1045+
| Type::IntLiteral(..)
1046+
| Type::SliceLiteral(..)
1047+
| Type::StringLiteral(..)
1048+
| Type::LiteralString,
1049+
) => true,
1050+
1051+
(
1052+
Type::ClassLiteral(..)
1053+
| Type::ModuleLiteral(..)
1054+
| Type::BooleanLiteral(..)
1055+
| Type::BytesLiteral(..)
1056+
| Type::FunctionLiteral(..)
1057+
| Type::IntLiteral(..)
1058+
| Type::SliceLiteral(..)
1059+
| Type::StringLiteral(..)
1060+
| Type::LiteralString,
1061+
Type::Tuple(..),
1062+
) => true,
1063+
10281064
(
10291065
Type::SubclassOf(SubclassOfType {
10301066
base: ClassBase::Class(class_a),
@@ -1037,10 +1073,13 @@ impl<'db> Type<'db> {
10371073
base: ClassBase::Class(class_a),
10381074
}),
10391075
) => !class_b.is_subclass_of(db, class_a),
1076+
10401077
(Type::SubclassOf(_), Type::SubclassOf(_)) => false,
1078+
10411079
(Type::SubclassOf(_), Type::Instance(_)) | (Type::Instance(_), Type::SubclassOf(_)) => {
10421080
false
10431081
}
1082+
10441083
(
10451084
Type::SubclassOf(_),
10461085
Type::BooleanLiteral(..)
@@ -1061,20 +1100,22 @@ impl<'db> Type<'db> {
10611100
| Type::ModuleLiteral(..),
10621101
Type::SubclassOf(_),
10631102
) => true,
1103+
10641104
(Type::SubclassOf(_), _) | (_, Type::SubclassOf(_)) => {
10651105
// TODO: Once we have support for final classes, we can determine disjointness in some cases
10661106
// here. However, note that it might be better to turn `Type::SubclassOf('FinalClass')` into
10671107
// `Type::ClassLiteral('FinalClass')` during construction, instead of adding special cases for
10681108
// final classes inside `Type::SubclassOf` everywhere.
10691109
false
10701110
}
1071-
(Type::KnownInstance(left), Type::KnownInstance(right)) => left != right,
1111+
10721112
(Type::KnownInstance(left), right) => {
10731113
left.instance_fallback(db).is_disjoint_from(db, right)
10741114
}
10751115
(left, Type::KnownInstance(right)) => {
10761116
left.is_disjoint_from(db, right.instance_fallback(db))
10771117
}
1118+
10781119
(
10791120
Type::Instance(InstanceType { class: class_none }),
10801121
Type::Instance(InstanceType { class: class_other }),
@@ -1086,6 +1127,7 @@ impl<'db> Type<'db> {
10861127
class_other.known(db),
10871128
Some(KnownClass::NoneType | KnownClass::Object)
10881129
),
1130+
10891131
(Type::Instance(InstanceType { class: class_none }), _)
10901132
| (_, Type::Instance(InstanceType { class: class_none }))
10911133
if class_none.is_known(db, KnownClass::NoneType) =>
@@ -1112,7 +1154,6 @@ impl<'db> Type<'db> {
11121154
| (Type::Instance(InstanceType { class }), Type::StringLiteral(..)) => {
11131155
!matches!(class.known(db), Some(KnownClass::Str | KnownClass::Object))
11141156
}
1115-
(Type::StringLiteral(..), _) | (_, Type::StringLiteral(..)) => true,
11161157

11171158
(Type::LiteralString, Type::LiteralString) => false,
11181159
(Type::LiteralString, Type::Instance(InstanceType { class }))
@@ -1126,14 +1167,12 @@ impl<'db> Type<'db> {
11261167
class.known(db),
11271168
Some(KnownClass::Bytes | KnownClass::Object)
11281169
),
1129-
(Type::BytesLiteral(..), _) | (_, Type::BytesLiteral(..)) => true,
11301170

11311171
(Type::SliceLiteral(..), Type::Instance(InstanceType { class }))
11321172
| (Type::Instance(InstanceType { class }), Type::SliceLiteral(..)) => !matches!(
11331173
class.known(db),
11341174
Some(KnownClass::Slice | KnownClass::Object)
11351175
),
1136-
(Type::SliceLiteral(..), _) | (_, Type::SliceLiteral(..)) => true,
11371176

11381177
(Type::ClassLiteral(..), Type::Instance(InstanceType { class }))
11391178
| (Type::Instance(InstanceType { class }), Type::ClassLiteral(..)) => {
@@ -1161,30 +1200,29 @@ impl<'db> Type<'db> {
11611200
false
11621201
}
11631202

1164-
(Type::Tuple(tuple), other) | (other, Type::Tuple(tuple)) => {
1165-
if let Type::Tuple(other_tuple) = other {
1166-
if tuple.len(db) == other_tuple.len(db) {
1167-
tuple
1168-
.elements(db)
1169-
.iter()
1170-
.zip(other_tuple.elements(db))
1171-
.any(|(e1, e2)| e1.is_disjoint_from(db, *e2))
1172-
} else {
1173-
true
1174-
}
1203+
(Type::Tuple(tuple), Type::Tuple(other_tuple)) => {
1204+
if tuple.len(db) == other_tuple.len(db) {
1205+
tuple
1206+
.elements(db)
1207+
.iter()
1208+
.zip(other_tuple.elements(db))
1209+
.any(|(e1, e2)| e1.is_disjoint_from(db, *e2))
11751210
} else {
1176-
// We can not be sure if the tuple is disjoint from 'other' because:
1177-
// - 'other' might be the homogeneous arbitrary-length tuple type
1178-
// tuple[T, ...] (which we don't have support for yet); if all of
1179-
// our element types are not disjoint with T, this is not disjoint
1180-
// - 'other' might be a user subtype of tuple, which, if generic
1181-
// over the same or compatible *Ts, would overlap with tuple.
1182-
//
1183-
// TODO: add checks for the above cases once we support them
1184-
1185-
false
1211+
true
11861212
}
11871213
}
1214+
1215+
(Type::Tuple(..), Type::Instance(..)) | (Type::Instance(..), Type::Tuple(..)) => {
1216+
// We can not be sure if the tuple is disjoint from the instance because:
1217+
// - 'other' might be the homogeneous arbitrary-length tuple type
1218+
// tuple[T, ...] (which we don't have support for yet); if all of
1219+
// our element types are not disjoint with T, this is not disjoint
1220+
// - 'other' might be a user subtype of tuple, which, if generic
1221+
// over the same or compatible *Ts, would overlap with tuple.
1222+
//
1223+
// TODO: add checks for the above cases once we support them
1224+
false
1225+
}
11881226
}
11891227
}
11901228

@@ -3677,6 +3715,7 @@ pub(crate) mod tests {
36773715
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(1)]), Ty::Tuple(vec![Ty::IntLiteral(2)]))]
36783716
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::Tuple(vec![Ty::IntLiteral(1)]))]
36793717
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::Tuple(vec![Ty::IntLiteral(1), Ty::IntLiteral(3)]))]
3718+
#[test_case(Ty::Tuple(vec![]), Ty::BuiltinClassLiteral("object"))]
36803719
fn is_disjoint_from(a: Ty, b: Ty) {
36813720
let db = setup_db();
36823721
let a = a.into_type(&db);

0 commit comments

Comments
 (0)