@@ -73,17 +73,17 @@ trait PatternTypeConstrainer { self: TypeComparer =>
7373 * scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables,
7474 * in which case the subtyping relationship "heals" the type.
7575 */
76- def constrainPatternType (pat : Type , scrut : Type , widenParams : Boolean = true ): Boolean = trace(i " constrainPatternType( $scrut, $pat) " , gadts) {
76+ def constrainPatternType (pat : Type , scrut : Type , forceInvariantRefinement : Boolean = false ): Boolean = trace(i " constrainPatternType( $scrut, $pat) " , gadts) {
7777
7878 def classesMayBeCompatible : Boolean = {
7979 import Flags ._
80- val patClassSym = pat.classSymbol
81- val scrutClassSym = scrut.classSymbol
82- ! patClassSym .exists || ! scrutClassSym .exists || {
83- if (patClassSym .is(Final )) patClassSym .derivesFrom(scrutClassSym )
84- else if (scrutClassSym .is(Final )) scrutClassSym .derivesFrom(patClassSym )
85- else if (! patClassSym .is(Flags .Trait ) && ! scrutClassSym .is(Flags .Trait ))
86- patClassSym .derivesFrom(scrutClassSym ) || scrutClassSym .derivesFrom(patClassSym )
80+ val patCls = pat.classSymbol
81+ val scrCls = scrut.classSymbol
82+ ! patCls .exists || ! scrCls .exists || {
83+ if (patCls .is(Final )) patCls .derivesFrom(scrCls )
84+ else if (scrCls .is(Final )) scrCls .derivesFrom(patCls )
85+ else if (! patCls .is(Flags .Trait ) && ! scrCls .is(Flags .Trait ))
86+ patCls .derivesFrom(scrCls ) || scrCls .derivesFrom(patCls )
8787 else true
8888 }
8989 }
@@ -93,6 +93,14 @@ trait PatternTypeConstrainer { self: TypeComparer =>
9393 case tp => tp
9494 }
9595
96+ def tryConstrainSimplePatternType (pat : Type , scrut : Type ) = {
97+ val patCls = pat.classSymbol
98+ val scrCls = scrut.classSymbol
99+ patCls.exists && scrCls.exists
100+ && (patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls))
101+ && constrainSimplePatternType(pat, scrut, forceInvariantRefinement)
102+ }
103+
96104 def constrainUpcasted (scrut : Type ): Boolean = trace(i " constrainUpcasted( $scrut) " , gadts) {
97105 // Fold a list of types into an AndType
98106 def buildAndType (xs : List [Type ]): Type = {
@@ -113,15 +121,15 @@ trait PatternTypeConstrainer { self: TypeComparer =>
113121 val andType = buildAndType(parents)
114122 ! andType.exists || constrainPatternType(pat, andType)
115123 case scrut @ AppliedType (tycon : TypeRef , _) if tycon.symbol.isClass =>
116- val patClassSym = pat.classSymbol
124+ val patCls = pat.classSymbol
117125 // find all shared parents in the inheritance hierarchy between pat and scrut
118126 def allParentsSharedWithPat (tp : Type , tpClassSym : ClassSymbol ): List [Symbol ] = {
119127 var parents = tpClassSym.info.parents
120128 if parents.nonEmpty && parents.head.classSymbol == defn.ObjectClass then
121129 parents = parents.tail
122130 parents flatMap { tp =>
123131 val sym = tp.classSymbol.asClass
124- if patClassSym .derivesFrom(sym) then List (sym)
132+ if patCls .derivesFrom(sym) then List (sym)
125133 else allParentsSharedWithPat(tp, sym)
126134 }
127135 }
@@ -135,27 +143,39 @@ trait PatternTypeConstrainer { self: TypeComparer =>
135143 case _ => NoType
136144 }
137145 if (upcasted.exists)
138- constrainSimplePatternType (pat, upcasted, widenParams ) || constrainUpcasted(upcasted)
146+ tryConstrainSimplePatternType (pat, upcasted) || constrainUpcasted(upcasted)
139147 else true
140148 }
141149 }
142150
143- scrut.dealias match {
151+ def dealiasDropNonmoduleRefs (tp : Type ) = tp.dealias match {
152+ case tp : TermRef =>
153+ // we drop TermRefs that don't have a class symbol, as they can't
154+ // meaningfully participate in GADT reasoning and just get in the way.
155+ // Their info could, for an example, be an AndType. One example where
156+ // this is important is an enum case that extends its parent and an
157+ // additional trait - argument-less enum cases desugar to vals.
158+ if tp.classSymbol.exists then tp else tp.info
159+ case tp => tp
160+ }
161+
162+ dealiasDropNonmoduleRefs(scrut) match {
144163 case OrType (scrut1, scrut2) =>
145164 either(constrainPatternType(pat, scrut1), constrainPatternType(pat, scrut2))
146165 case AndType (scrut1, scrut2) =>
147166 constrainPatternType(pat, scrut1) && constrainPatternType(pat, scrut2)
148167 case scrut : RefinedOrRecType =>
149168 constrainPatternType(pat, stripRefinement(scrut))
150- case scrut => pat.dealias match {
169+ case scrut => dealiasDropNonmoduleRefs( pat) match {
151170 case OrType (pat1, pat2) =>
152171 either(constrainPatternType(pat1, scrut), constrainPatternType(pat2, scrut))
153172 case AndType (pat1, pat2) =>
154173 constrainPatternType(pat1, scrut) && constrainPatternType(pat2, scrut)
155174 case pat : RefinedOrRecType =>
156175 constrainPatternType(stripRefinement(pat), scrut)
157176 case pat =>
158- constrainSimplePatternType(pat, scrut, widenParams) || classesMayBeCompatible && constrainUpcasted(scrut)
177+ tryConstrainSimplePatternType(pat, scrut)
178+ || classesMayBeCompatible && constrainUpcasted(scrut)
159179 }
160180 }
161181 }
@@ -194,7 +214,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
194214 * case classes without also appropriately extending the relevant case class
195215 * (see `RefChecks#checkCaseClassInheritanceInvariant`).
196216 */
197- def constrainSimplePatternType (patternTp : Type , scrutineeTp : Type , widenParams : Boolean ): Boolean = {
217+ def constrainSimplePatternType (patternTp : Type , scrutineeTp : Type , forceInvariantRefinement : Boolean ): Boolean = {
198218 def refinementIsInvariant (tp : Type ): Boolean = tp match {
199219 case tp : SingletonType => true
200220 case tp : ClassInfo => tp.cls.is(Final ) || tp.cls.is(Case )
@@ -212,13 +232,44 @@ trait PatternTypeConstrainer { self: TypeComparer =>
212232 tp
213233 }
214234
215- val widePt =
216- if migrateTo3 || refinementIsInvariant(patternTp) then scrutineeTp
217- else if widenParams then widenVariantParams(scrutineeTp)
218- else scrutineeTp
219- val narrowTp = SkolemType (patternTp)
220- trace(i " constraining simple pattern type $narrowTp <:< $widePt" , gadts, res => s " $res\n gadt = ${ctx.gadt.debugBoundsDescription}" ) {
221- isSubType(narrowTp, widePt)
235+ val patternCls = patternTp.classSymbol
236+ val scrutineeCls = scrutineeTp.classSymbol
237+
238+ // NOTE: we already know that there is a derives-from relationship in either direction
239+ val upcastPattern =
240+ patternCls.derivesFrom(scrutineeCls)
241+
242+ val pt = if upcastPattern then patternTp.baseType(scrutineeCls) else patternTp
243+ val tp = if ! upcastPattern then scrutineeTp.baseType(patternCls) else scrutineeTp
244+
245+ val assumeInvariantRefinement =
246+ migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp)
247+
248+ trace(i " constraining simple pattern type $tp >:< $pt" , gadts, res => s " $res\n gadt = ${ctx.gadt.debugBoundsDescription}" ) {
249+ (tp, pt) match {
250+ case (AppliedType (tyconS, argsS), AppliedType (tyconP, argsP)) =>
251+ val saved = state.constraint
252+ val savedGadt = ctx.gadt.fresh
253+ val result =
254+ tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
255+ val variance = param.paramVarianceSign
256+ if variance != 0 && ! assumeInvariantRefinement then true
257+ else if argS.isInstanceOf [TypeBounds ] || argP.isInstanceOf [TypeBounds ] then true
258+ else {
259+ var res = true
260+ if variance < 1 then res &&= isSubType(argS, argP)
261+ if variance > - 1 then res &&= isSubType(argP, argS)
262+ res
263+ }
264+ }
265+ if ! result then
266+ constraint = saved
267+ ctx.gadt.restore(savedGadt)
268+ result
269+ case _ =>
270+ // give up if we don't get AppliedType, e.g. if we upcasted to Any.
271+ false
272+ }
222273 }
223274 }
224275}
0 commit comments