Skip to content

Commit 5cf1021

Browse files
N5N3vtjnash
andauthored
Subtype: Fix some diagonal rule related false alarm (JuliaLang#53034)
close JuliaLang#33137 close JuliaLang#53021 --------- Co-authored-by: Jameson Nash <vtjnash@gmail.com>
1 parent 1e45aba commit 5cf1021

File tree

5 files changed

+77
-45
lines changed

5 files changed

+77
-45
lines changed

src/jltypes.c

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,43 @@ static void isort_union(jl_value_t **a, size_t len) JL_NOTSAFEPOINT
556556
}
557557
}
558558

559+
static int simple_subtype(jl_value_t *a, jl_value_t *b, int hasfree, int isUnion)
560+
{
561+
if (a == jl_bottom_type || b == (jl_value_t*)jl_any_type)
562+
return 1;
563+
if (jl_egal(a, b))
564+
return 1;
565+
if (hasfree == 0) {
566+
int mergeable = isUnion;
567+
if (!mergeable) // issue #24521: don't merge Type{T} where typeof(T) varies
568+
mergeable = !(jl_is_type_type(a) && jl_is_type_type(b) &&
569+
jl_typeof(jl_tparam0(a)) != jl_typeof(jl_tparam0(b)));
570+
return mergeable && jl_subtype(a, b);
571+
}
572+
if (jl_is_typevar(a)) {
573+
jl_value_t *na = ((jl_tvar_t*)a)->ub;
574+
hasfree &= jl_has_free_typevars(na);
575+
return simple_subtype(na, b, hasfree, isUnion);
576+
}
577+
if (jl_is_typevar(b)) {
578+
jl_value_t *nb = ((jl_tvar_t*)b)->lb;
579+
// This branch is not valid if `b` obeys diagonal rule,
580+
// as it might normalize `Union` into a single `TypeVar`, e.g.
581+
// Tuple{Union{Int,T},T} where {T>:Int} != Tuple{T,T} where {T>:Int}
582+
if (is_leaf_bound(nb))
583+
return 0;
584+
hasfree &= jl_has_free_typevars(nb) << 1;
585+
return simple_subtype(a, nb, hasfree, isUnion);
586+
}
587+
if (b==(jl_value_t*)jl_datatype_type || b==(jl_value_t*)jl_typeofbottom_type) {
588+
// This branch is not valid for `Union`/`UnionAll`, e.g.
589+
// (Type{Union{Int,T2} where {T2<:T1}} where {T1}){Int} == Type{Int64}
590+
// (Type{Union{Int,T1}} where {T1}){Int} == Type{Int64}
591+
return jl_is_type_type(a) && jl_typeof(jl_tparam0(a)) == b;
592+
}
593+
return 0;
594+
}
595+
559596
JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n)
560597
{
561598
if (n == 0)
@@ -580,13 +617,9 @@ JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n)
580617
int has_free = temp[i] != NULL && jl_has_free_typevars(temp[i]);
581618
for (j = 0; j < nt; j++) {
582619
if (j != i && temp[i] && temp[j]) {
583-
if (temp[i] == jl_bottom_type ||
584-
temp[j] == (jl_value_t*)jl_any_type ||
585-
jl_egal(temp[i], temp[j]) ||
586-
(!has_free && !jl_has_free_typevars(temp[j]) &&
587-
jl_subtype(temp[i], temp[j]))) {
620+
int has_free2 = has_free | (jl_has_free_typevars(temp[j]) << 1);
621+
if (simple_subtype(temp[i], temp[j], has_free2, 1))
588622
temp[i] = NULL;
589-
}
590623
}
591624
}
592625
}
@@ -608,17 +641,7 @@ JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n)
608641
return tu;
609642
}
610643

611-
// note: this is turned off as `Union` doesn't do such normalization.
612-
// static int simple_subtype(jl_value_t *a, jl_value_t *b)
613-
// {
614-
// if (jl_is_kind(b) && jl_is_type_type(a) && jl_typeof(jl_tparam0(a)) == b)
615-
// return 1;
616-
// if (jl_is_typevar(b) && obviously_egal(a, ((jl_tvar_t*)b)->lb))
617-
// return 1;
618-
// return 0;
619-
// }
620-
621-
static int simple_subtype2(jl_value_t *a, jl_value_t *b, int hasfree)
644+
static int simple_subtype2(jl_value_t *a, jl_value_t *b, int hasfree, int isUnion)
622645
{
623646
int subab = 0, subba = 0;
624647
if (jl_egal(a, b)) {
@@ -630,9 +653,9 @@ static int simple_subtype2(jl_value_t *a, jl_value_t *b, int hasfree)
630653
else if (b == jl_bottom_type || a == (jl_value_t*)jl_any_type) {
631654
subba = 1;
632655
}
633-
else if (hasfree) {
634-
// subab = simple_subtype(a, b);
635-
// subba = simple_subtype(b, a);
656+
else if (hasfree != 0) {
657+
subab = simple_subtype(a, b, hasfree, isUnion);
658+
subba = simple_subtype(b, a, hasfree, isUnion);
636659
}
637660
else if (jl_is_type_type(a) && jl_is_type_type(b) &&
638661
jl_typeof(jl_tparam0(a)) != jl_typeof(jl_tparam0(b))) {
@@ -664,10 +687,11 @@ jl_value_t *simple_union(jl_value_t *a, jl_value_t *b)
664687
// first remove cross-redundancy and check if `a >: b` or `a <: b`.
665688
for (i = 0; i < nta; i++) {
666689
if (temp[i] == NULL) continue;
667-
int hasfree = jl_has_free_typevars(temp[i]);
690+
int has_free = jl_has_free_typevars(temp[i]);
668691
for (j = nta; j < nt; j++) {
669692
if (temp[j] == NULL) continue;
670-
int subs = simple_subtype2(temp[i], temp[j], hasfree || jl_has_free_typevars(temp[j]));
693+
int has_free2 = has_free | (jl_has_free_typevars(temp[j]) << 1);
694+
int subs = simple_subtype2(temp[i], temp[j], has_free2, 0);
671695
int subab = subs & 1, subba = subs >> 1;
672696
if (subab) {
673697
temp[i] = NULL;
@@ -697,15 +721,9 @@ jl_value_t *simple_union(jl_value_t *a, jl_value_t *b)
697721
size_t jmax = i < nta ? nta : nt;
698722
for (j = jmin; j < jmax; j++) {
699723
if (j != i && temp[i] && temp[j]) {
700-
if (temp[i] == jl_bottom_type ||
701-
temp[j] == (jl_value_t*)jl_any_type ||
702-
jl_egal(temp[i], temp[j]) ||
703-
(!has_free && !jl_has_free_typevars(temp[j]) &&
704-
// issue #24521: don't merge Type{T} where typeof(T) varies
705-
!(jl_is_type_type(temp[i]) && jl_is_type_type(temp[j]) && jl_typeof(jl_tparam0(temp[i])) != jl_typeof(jl_tparam0(temp[j]))) &&
706-
jl_subtype(temp[i], temp[j]))) {
724+
int has_free2 = has_free | (jl_has_free_typevars(temp[j]) << 1);
725+
if (simple_subtype(temp[i], temp[j], has_free2, 0))
707726
temp[i] = NULL;
708-
}
709727
}
710728
}
711729
}
@@ -769,7 +787,7 @@ jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi)
769787
int hasfree = jl_has_free_typevars(temp[i]);
770788
for (j = nta; j < nt; j++) {
771789
if (temp[j] == NULL) continue;
772-
int subs = simple_subtype2(temp[i], temp[j], hasfree || jl_has_free_typevars(temp[j]));
790+
int subs = simple_subtype2(temp[i], temp[j], hasfree || jl_has_free_typevars(temp[j]), 0);
773791
int subab = subs & 1, subba = subs >> 1;
774792
if (subba && !subab) {
775793
stemp[i] = -1;

src/julia.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1496,6 +1496,8 @@ static inline int jl_field_isconst(jl_datatype_t *st, int i) JL_NOTSAFEPOINT
14961496

14971497
JL_DLLEXPORT int jl_subtype(jl_value_t *a, jl_value_t *b);
14981498

1499+
int is_leaf_bound(jl_value_t *v) JL_NOTSAFEPOINT;
1500+
14991501
STATIC_INLINE int jl_is_kind(jl_value_t *v) JL_NOTSAFEPOINT
15001502
{
15011503
return (v==(jl_value_t*)jl_uniontype_type || v==(jl_value_t*)jl_datatype_type ||

src/subtype.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ static int subtype_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int R, int pa
805805
// check that a type is concrete or quasi-concrete (Type{T}).
806806
// this is used to check concrete typevars:
807807
// issubtype is false if the lower bound of a concrete type var is not concrete.
808-
static int is_leaf_bound(jl_value_t *v) JL_NOTSAFEPOINT
808+
int is_leaf_bound(jl_value_t *v) JL_NOTSAFEPOINT
809809
{
810810
if (v == jl_bottom_type)
811811
return 1;
@@ -1997,7 +1997,7 @@ static int obvious_subtype(jl_value_t *x, jl_value_t *y, jl_value_t *y0, int *su
19971997
if (var_occurs_invariant(body, (jl_tvar_t*)b))
19981998
return 0;
19991999
}
2000-
if (nparams_expanded_x > npy && jl_is_typevar(b) && concrete_min(a1) > 1) {
2000+
if (nparams_expanded_x > npy && jl_is_typevar(b) && is_leaf_typevar((jl_tvar_t *)b) && concrete_min(a1) > 1) {
20012001
// diagonal rule for 2 or more elements: they must all be concrete on the LHS
20022002
*subtype = 0;
20032003
return 1;
@@ -2008,7 +2008,7 @@ static int obvious_subtype(jl_value_t *x, jl_value_t *y, jl_value_t *y0, int *su
20082008
}
20092009
for (; i < nparams_expanded_x; i++) {
20102010
jl_value_t *a = (vx != JL_VARARG_NONE && i >= npx - 1) ? vxt : jl_tparam(x, i);
2011-
if (i > npy && jl_is_typevar(b)) { // i == npy implies a == a1
2011+
if (i > npy && jl_is_typevar(b) && is_leaf_typevar((jl_tvar_t *)b)) { // i == npy implies a == a1
20122012
// diagonal rule: all the later parameters are also constrained to be type-equal to the first
20132013
jl_value_t *a2 = a;
20142014
jl_value_t *au = jl_unwrap_unionall(a);

test/core.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ k11840(::Type{Union{Tuple{Int32}, Tuple{Int64}}}) = '2'
239239
# issue #20511
240240
f20511(x::DataType) = 0
241241
f20511(x) = 1
242-
Type{Integer} # cache this
243-
@test f20511(Union{Integer,T} where T <: Unsigned) == 1
242+
Type{AbstractSet} # cache this
243+
@test f20511(Union{AbstractSet,Set{T}} where T) == 1
244244

245245
# join
246246
@test typejoin(Int8,Int16) === Signed
@@ -8101,3 +8101,14 @@ end
81018101

81028102
# #52433
81038103
@test_throws ErrorException Core.Intrinsics.pointerref(Ptr{Vector{Int64}}(C_NULL), 1, 0)
8104+
8105+
# #53034 (Union normalization for typevar elimination)
8106+
@test Tuple{Int,Any} <: Tuple{Union{Int,T},T} where {T>:Int}
8107+
@test Tuple{Int,Any} <: Tuple{Union{Int,T},T} where {T>:Integer}
8108+
# #53034 (Union normalization for Type elimination)
8109+
@test Int isa Type{Union{Int,T2} where {T2<:T1}} where {T1}
8110+
@test Int isa Type{Union{Int,T1}} where {T1}
8111+
@test Int isa Union{UnionAll, Type{Union{Int,T2} where {T2<:T1}}} where {T1}
8112+
@test Int isa Union{Union, Type{Union{Int,T1}}} where {T1}
8113+
@test_broken Int isa Union{UnionAll, Type{Union{Int,T2} where {T2<:T1}} where {T1}}
8114+
@test_broken Int isa Union{Union, Type{Union{Int,T1}} where {T1}}

test/subtype.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,14 @@ function test_diagonal()
146146
@test isequal_type(Ref{Tuple{T, T} where Int<:T<:Int},
147147
Ref{Tuple{S, S}} where Int<:S<:Int)
148148

149+
# issue #53021
150+
@test Tuple{X, X} where {X<:Union{}} <: Tuple{X, X, Vararg{Any}} where {Int<:X<:Int}
151+
@test Tuple{Integer, X, Vararg{X}} where {X<:Int} <: Tuple{Any, Vararg{X}} where {X>:Int}
152+
@test Tuple{Any, X, Vararg{X}} where {X<:Int} <: Tuple{Vararg{X}} where X>:Integer
153+
@test Tuple{Integer, Integer, Any, Vararg{Any}} <: Tuple{Vararg{X}} where X>:Integer
154+
# issue #53019
155+
@test Tuple{T,T} where {T<:Int} <: Tuple{T,T} where {T>:Int}
156+
149157
let A = Tuple{Int,Int8,Vector{Integer}},
150158
B = Tuple{T,T,Vector{T}} where T>:Integer,
151159
C = Tuple{T,T,Vector{Union{Integer,T}}} where T
@@ -1260,14 +1268,7 @@ let a = Tuple{Tuple{T2,4},T6} where T2 where T6,
12601268
end
12611269
let a = Tuple{T3,Int64,Tuple{T3}} where T3,
12621270
b = Tuple{S3,S3,S4} where S4 where S3
1263-
I1 = typeintersect(a, b)
1264-
I2 = typeintersect(b, a)
1265-
@test I1 <: I2
1266-
@test I2 <: I1
1267-
@test_broken I1 <: a
1268-
@test I2 <: a
1269-
@test I1 <: b
1270-
@test I2 <: b
1271+
@testintersect(a, b, Tuple{Int64, Int64, Tuple{Int64}})
12711272
end
12721273
let a = Tuple{T1,Val{T2},T2} where T2 where T1,
12731274
b = Tuple{Float64,S1,S2} where S2 where S1
@@ -2445,7 +2446,7 @@ abstract type P47654{A} end
24452446
@test_broken typeintersect(Type{Tuple{Array{T,1} where T}}, UnionAll) != Union{}
24462447

24472448
#issue 33137
2448-
@test_broken (Tuple{Q,Int} where Q<:Int) <: Tuple{T,T} where T
2449+
@test (Tuple{Q,Int} where Q<:Int) <: Tuple{T,T} where T
24492450

24502451
# issue 24333
24512452
@test (Type{Union{Ref,Cvoid}} <: Type{Union{T,Cvoid}} where T)

0 commit comments

Comments
 (0)