@@ -1234,7 +1234,55 @@ class Typer extends Namer
12341234 if (tree.isInline) checkInInlineContext(" inline match" , tree.posd)
12351235 val sel1 = typedExpr(tree.selector)
12361236 val selType = fullyDefinedType(sel1.tpe, " pattern selector" , tree.span).widen
1237- val result = typedMatchFinish(tree, sel1, selType, tree.cases, pt)
1237+
1238+ /** Extractor for match types hidden behind an AppliedType/MatchAlias */
1239+ object MatchTypeInDisguise {
1240+ def unapply (tp : AppliedType ): Option [MatchType ] = tp match {
1241+ case AppliedType (tycon : TypeRef , args) =>
1242+ tycon.info match {
1243+ case MatchAlias (alias) =>
1244+ alias.applyIfParameterized(args) match {
1245+ case mt : MatchType => Some (mt)
1246+ case _ => None
1247+ }
1248+ case _ => None
1249+ }
1250+ case _ => None
1251+ }
1252+ }
1253+
1254+ /** Does `tree` has the same shape as the given match type?
1255+ * We only support typed patterns with empty guards, but
1256+ * that could potentially be extended in the future.
1257+ */
1258+ def isMatchTypeShaped (mt : MatchType ): Boolean =
1259+ mt.cases.size == tree.cases.size
1260+ && sel1.tpe.frozen_<:< (mt.scrutinee)
1261+ && tree.cases.forall(_.guard.isEmpty)
1262+ && tree.cases
1263+ .map(cas => untpd.unbind(untpd.unsplice(cas.pat)))
1264+ .zip(mt.cases)
1265+ .forall {
1266+ case (pat : Typed , pt) =>
1267+ // To check that pattern types correspond we need to type
1268+ // check `pat` here and throw away the result.
1269+ val gadtCtx : Context = ctx.fresh.setFreshGADTBounds
1270+ val pat1 = typedPattern(pat, selType)(using gadtCtx)
1271+ val Typed (_, tpt) = tpd.unbind(tpd.unsplice(pat1))
1272+ instantiateMatchTypeProto(pat1, pt) match {
1273+ case defn.MatchCase (patternTp, _) => tpt.tpe frozen_=:= patternTp
1274+ case _ => false
1275+ }
1276+ case _ => false
1277+ }
1278+
1279+ val result = pt match {
1280+ case MatchTypeInDisguise (mt) if isMatchTypeShaped(mt) =>
1281+ typedDependentMatchFinish(tree, sel1, selType, tree.cases, mt)
1282+ case _ =>
1283+ typedMatchFinish(tree, sel1, selType, tree.cases, pt)
1284+ }
1285+
12381286 result match {
12391287 case Match (sel, CaseDef (pat, _, _) :: _) =>
12401288 tree.selector.removeAttachment(desugar.CheckIrrefutable ) match {
@@ -1250,6 +1298,21 @@ class Typer extends Namer
12501298 result
12511299 }
12521300
1301+ /** Special typing of Match tree when the expected type is a MatchType,
1302+ * and the patterns of the Match tree and the MatchType correspond.
1303+ */
1304+ def typedDependentMatchFinish (tree : untpd.Match , sel : Tree , wideSelType : Type , cases : List [untpd.CaseDef ], pt : MatchType )(using Context ): Tree = {
1305+ var caseCtx = ctx
1306+ val cases1 = tree.cases.zip(pt.cases)
1307+ .map { case (cas, tpe) =>
1308+ val case1 = typedCase(cas, sel, wideSelType, tpe)(using caseCtx)
1309+ caseCtx = Nullables .afterPatternContext(sel, case1.pat)
1310+ case1
1311+ }
1312+ .asInstanceOf [List [CaseDef ]]
1313+ assignType(cpy.Match (tree)(sel, cases1), sel, cases1).cast(pt)
1314+ }
1315+
12531316 // Overridden in InlineTyper for inline matches
12541317 def typedMatchFinish (tree : untpd.Match , sel : Tree , wideSelType : Type , cases : List [untpd.CaseDef ], pt : Type )(using Context ): Tree = {
12551318 val cases1 = harmonic(harmonize, pt)(typedCases(cases, sel, wideSelType, pt.dropIfProto))
@@ -1290,17 +1353,33 @@ class Typer extends Namer
12901353 }
12911354 }
12921355
1356+ /** If the prototype `pt` is the type lambda (when doing a dependent
1357+ * typing of a match), instantiate that type lambda with the pattern
1358+ * variables found in the pattern `pat`.
1359+ */
1360+ def instantiateMatchTypeProto (pat : Tree , pt : Type )(implicit ctx : Context ) = pt match {
1361+ case caseTp : HKTypeLambda =>
1362+ val bindingsSyms = tpd.patVars(pat).reverse
1363+ val bindingsTps = bindingsSyms.collect { case sym if sym.isType => sym.typeRef }
1364+ caseTp.appliedTo(bindingsTps)
1365+ case pt => pt
1366+ }
1367+
12931368 /** Type a case. */
12941369 def typedCase (tree : untpd.CaseDef , sel : Tree , wideSelType : Type , pt : Type )(using Context ): CaseDef = {
12951370 val originalCtx = ctx
12961371 val gadtCtx : Context = ctx.fresh.setFreshGADTBounds
12971372
12981373 def caseRest (pat : Tree )(using Context ) = {
1374+ val pt1 = instantiateMatchTypeProto(pat, pt) match {
1375+ case defn.MatchCase (_, bodyPt) => bodyPt
1376+ case pt => pt
1377+ }
12991378 val pat1 = indexPattern(tree).transform(pat)
13001379 val guard1 = typedExpr(tree.guard, defn.BooleanType )
1301- var body1 = ensureNoLocalRefs(typedExpr(tree.body, pt ), pt , ctx.scope.toList)
1302- if (pt .isValueType) // insert a cast if body does not conform to expected type if we disregard gadt bounds
1303- body1 = body1.ensureConforms(pt )(originalCtx)
1380+ var body1 = ensureNoLocalRefs(typedExpr(tree.body, pt1 ), pt1 , ctx.scope.toList)
1381+ if (pt1 .isValueType) // insert a cast if body does not conform to expected type if we disregard gadt bounds
1382+ body1 = body1.ensureConforms(pt1 )(originalCtx)
13041383 assignType(cpy.CaseDef (tree)(pat1, guard1, body1), pat1, body1)
13051384 }
13061385
0 commit comments