@@ -73,17 +73,17 @@ trait PatternTypeConstrainer { self: TypeComparer =>
7373 * scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables,
7474 * in which case the subtyping relationship "heals" the type.
7575 */
76- def constrainPatternType (pat : Type , scrut : Type , widenParams : Boolean = true ): Boolean = trace(i " constrainPatternType( $scrut, $pat) " , gadts) {
76+ def constrainPatternType (pat : Type , scrut : Type , forceInvariantRefinement : Boolean = false ): Boolean = trace(i " constrainPatternType( $scrut, $pat) " , gadts) {
7777
7878 def classesMayBeCompatible : Boolean = {
7979 import Flags ._
80- val patClassSym = pat.classSymbol
81- val scrutClassSym = scrut.classSymbol
82- ! patClassSym .exists || ! scrutClassSym .exists || {
83- if (patClassSym .is(Final )) patClassSym .derivesFrom(scrutClassSym )
84- else if (scrutClassSym .is(Final )) scrutClassSym .derivesFrom(patClassSym )
85- else if (! patClassSym .is(Flags .Trait ) && ! scrutClassSym .is(Flags .Trait ))
86- patClassSym .derivesFrom(scrutClassSym ) || scrutClassSym .derivesFrom(patClassSym )
80+ val patCls = pat.classSymbol
81+ val scrCls = scrut.classSymbol
82+ ! patCls .exists || ! scrCls .exists || {
83+ if (patCls .is(Final )) patCls .derivesFrom(scrCls )
84+ else if (scrCls .is(Final )) scrCls .derivesFrom(patCls )
85+ else if (! patCls .is(Flags .Trait ) && ! scrCls .is(Flags .Trait ))
86+ patCls .derivesFrom(scrCls ) || scrCls .derivesFrom(patCls )
8787 else true
8888 }
8989 }
@@ -93,6 +93,14 @@ trait PatternTypeConstrainer { self: TypeComparer =>
9393 case tp => tp
9494 }
9595
96+ def tryConstrainSimplePatternType (pat : Type , scrut : Type ) = {
97+ val patCls = pat.classSymbol
98+ val scrCls = scrut.classSymbol
99+ patCls.exists && scrCls.exists
100+ && (patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls))
101+ && constrainSimplePatternType(pat, scrut, forceInvariantRefinement)
102+ }
103+
96104 def constrainUpcasted (scrut : Type ): Boolean = trace(i " constrainUpcasted( $scrut) " , gadts) {
97105 // Fold a list of types into an AndType
98106 def buildAndType (xs : List [Type ]): Type = {
@@ -113,15 +121,15 @@ trait PatternTypeConstrainer { self: TypeComparer =>
113121 val andType = buildAndType(parents)
114122 ! andType.exists || constrainPatternType(pat, andType)
115123 case scrut @ AppliedType (tycon : TypeRef , _) if tycon.symbol.isClass =>
116- val patClassSym = pat.classSymbol
124+ val patCls = pat.classSymbol
117125 // find all shared parents in the inheritance hierarchy between pat and scrut
118126 def allParentsSharedWithPat (tp : Type , tpClassSym : ClassSymbol ): List [Symbol ] = {
119127 var parents = tpClassSym.info.parents
120128 if parents.nonEmpty && parents.head.classSymbol == defn.ObjectClass then
121129 parents = parents.tail
122130 parents flatMap { tp =>
123131 val sym = tp.classSymbol.asClass
124- if patClassSym .derivesFrom(sym) then List (sym)
132+ if patCls .derivesFrom(sym) then List (sym)
125133 else allParentsSharedWithPat(tp, sym)
126134 }
127135 }
@@ -135,42 +143,55 @@ trait PatternTypeConstrainer { self: TypeComparer =>
135143 case _ => NoType
136144 }
137145 if (upcasted.exists)
138- constrainSimplePatternType (pat, upcasted, widenParams ) || constrainUpcasted(upcasted)
146+ tryConstrainSimplePatternType (pat, upcasted) || constrainUpcasted(upcasted)
139147 else true
140148 }
141149 }
142150
143- scrut.dealias match {
151+ def dealiasDropNonmoduleRefs (tp : Type ) = tp.dealias match {
152+ case tp : TermRef =>
153+ // we drop TermRefs that don't have a class symbol, as they can't
154+ // meaningfully participate in GADT reasoning and just get in the way.
155+ // Their info could, for an example, be an AndType. One example where
156+ // this is important is an enum case that extends its parent and an
157+ // additional trait - argument-less enum cases desugar to vals.
158+ // See run/enum-Tree.scala.
159+ if tp.classSymbol.exists then tp else tp.info
160+ case tp => tp
161+ }
162+
163+ dealiasDropNonmoduleRefs(scrut) match {
144164 case OrType (scrut1, scrut2) =>
145165 either(constrainPatternType(pat, scrut1), constrainPatternType(pat, scrut2))
146166 case AndType (scrut1, scrut2) =>
147167 constrainPatternType(pat, scrut1) && constrainPatternType(pat, scrut2)
148168 case scrut : RefinedOrRecType =>
149169 constrainPatternType(pat, stripRefinement(scrut))
150- case scrut => pat.dealias match {
170+ case scrut => dealiasDropNonmoduleRefs( pat) match {
151171 case OrType (pat1, pat2) =>
152172 either(constrainPatternType(pat1, scrut), constrainPatternType(pat2, scrut))
153173 case AndType (pat1, pat2) =>
154174 constrainPatternType(pat1, scrut) && constrainPatternType(pat2, scrut)
155175 case pat : RefinedOrRecType =>
156176 constrainPatternType(stripRefinement(pat), scrut)
157177 case pat =>
158- constrainSimplePatternType(pat, scrut, widenParams) || classesMayBeCompatible && constrainUpcasted(scrut)
178+ tryConstrainSimplePatternType(pat, scrut)
179+ || classesMayBeCompatible && constrainUpcasted(scrut)
159180 }
160181 }
161182 }
162183
163184 /** Constrain "simple" patterns (see `constrainPatternType`).
164185 *
165- * This function attempts to modify pattern and scrutinee type s.t. the pattern must be a subtype of the scrutinee,
166- * or otherwise it cannot possibly match. In order to do that, we:
167- *
168- * 1. Rely on `constrainPatternType` to break the actual scrutinee/pattern types into subcomponents
169- * 2. Widen type parameters of scrutinee type that are not invariantly refined (see below) by the pattern type.
170- * 3. Wrap the pattern type in a skolem to avoid overconstraining top-level abstract types in scrutinee type
171- * 4. Check that `WidenedScrutineeType <: NarrowedPatternType`
186+ * This function expects to receive two types (scrutinee and pattern), both
187+ * of which have class symbols, one of which is derived from another. If the
188+ * type "being derived from" is an applied type, it will 1) "upcast" the
189+ * deriving type to an applied type with the same constructor and 2) infer
190+ * constraints for the applied types' arguments that follow from both
191+ * types being inhabited by one value (the scrutinee).
172192 *
173- * Importantly, note that the pattern type may contain type variables.
193+ * Importantly, note that the pattern type may contain type variables, which
194+ * are used to infer type arguments to Unapply trees.
174195 *
175196 * ## Invariant refinement
176197 * Essentially, we say that `D[B] extends C[B]` s.t. refines parameter `A` of `trait C[A]` invariantly if
@@ -194,7 +215,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
194215 * case classes without also appropriately extending the relevant case class
195216 * (see `RefChecks#checkCaseClassInheritanceInvariant`).
196217 */
197- def constrainSimplePatternType (patternTp : Type , scrutineeTp : Type , widenParams : Boolean ): Boolean = {
218+ def constrainSimplePatternType (patternTp : Type , scrutineeTp : Type , forceInvariantRefinement : Boolean ): Boolean = {
198219 def refinementIsInvariant (tp : Type ): Boolean = tp match {
199220 case tp : SingletonType => true
200221 case tp : ClassInfo => tp.cls.is(Final ) || tp.cls.is(Case )
@@ -212,13 +233,53 @@ trait PatternTypeConstrainer { self: TypeComparer =>
212233 tp
213234 }
214235
215- val widePt =
216- if migrateTo3 || refinementIsInvariant(patternTp) then scrutineeTp
217- else if widenParams then widenVariantParams(scrutineeTp)
218- else scrutineeTp
219- val narrowTp = SkolemType (patternTp)
220- trace(i " constraining simple pattern type $narrowTp <:< $widePt" , gadts, res => s " $res\n gadt = ${ctx.gadt.debugBoundsDescription}" ) {
221- isSubType(narrowTp, widePt)
236+ val patternCls = patternTp.classSymbol
237+ val scrutineeCls = scrutineeTp.classSymbol
238+
239+ // NOTE: we already know that there is a derives-from relationship in either direction
240+ val upcastPattern =
241+ patternCls.derivesFrom(scrutineeCls)
242+
243+ val pt = if upcastPattern then patternTp.baseType(scrutineeCls) else patternTp
244+ val tp = if ! upcastPattern then scrutineeTp.baseType(patternCls) else scrutineeTp
245+
246+ val assumeInvariantRefinement =
247+ migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp)
248+
249+ trace(i " constraining simple pattern type $tp >:< $pt" , gadts, res => s " $res\n gadt = ${ctx.gadt.debugBoundsDescription}" ) {
250+ (tp, pt) match {
251+ case (AppliedType (tyconS, argsS), AppliedType (tyconP, argsP)) =>
252+ val saved = state.constraint
253+ val savedGadt = ctx.gadt.fresh
254+ val result =
255+ tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
256+ val variance = param.paramVarianceSign
257+ if variance != 0 && ! assumeInvariantRefinement then true
258+ else if argS.isInstanceOf [TypeBounds ] || argP.isInstanceOf [TypeBounds ] then
259+ // Passing TypeBounds to isSubType on LHS or RHS does the
260+ // incorrect thing and infers unsound constraints, while simply
261+ // returning true is sound. However, I believe that it should
262+ // still be possible to extract useful constraints here.
263+ // TODO extract GADT information out of wildcard type arguments
264+ true
265+ else {
266+ var res = true
267+ if variance < 1 then res &&= isSubType(argS, argP)
268+ if variance > - 1 then res &&= isSubType(argP, argS)
269+ res
270+ }
271+ }
272+ if ! result then
273+ constraint = saved
274+ ctx.gadt.restore(savedGadt)
275+ result
276+ case _ =>
277+ // Give up if we don't get AppliedType, e.g. if we upcasted to Any.
278+ // Note that this doesn't mean that patternTp, scrutineeTp cannot possibly
279+ // be co-inhabited, just that we cannot extract information out of them directly
280+ // and should upcast.
281+ false
282+ }
222283 }
223284 }
224285}
0 commit comments