@@ -6,12 +6,14 @@ import core.*
66import Contexts .*
77import Symbols .*
88import Types .*
9+ import Denotations .Denotation
910import StdNames .*
11+ import Names .TermName
1012import NameKinds .OuterSelectName
1113import NameKinds .SuperAccessorName
1214
1315import ast .tpd .*
14- import util .SourcePosition
16+ import util .{ SourcePosition , NoSourcePosition }
1517import config .Printers .init as printer
1618import reporting .StoreReporter
1719import reporting .trace as log
@@ -1176,6 +1178,16 @@ object Objects:
11761178 * @param klass The enclosing class where the type `tp` is located.
11771179 */
11781180 def patternMatch (scrutinee : Value , cases : List [CaseDef ], thisV : Value , klass : ClassSymbol ): Contextual [Value ] =
1181+ // expected member types for `unapplySeq`
1182+ def lengthType = ExprType (defn.IntType )
1183+ def lengthCompareType = MethodType (List (defn.IntType ), defn.IntType )
1184+ def applyType (elemTp : Type ) = MethodType (List (defn.IntType ), elemTp)
1185+ def dropType (elemTp : Type ) = MethodType (List (defn.IntType ), defn.CollectionSeqType .appliedTo(elemTp))
1186+ def toSeqType (elemTp : Type ) = ExprType (defn.CollectionSeqType .appliedTo(elemTp))
1187+
1188+ def getMemberMethod (receiver : Type , name : TermName , tp : Type ): Denotation =
1189+ receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp)
1190+
11791191 def evalCase (caseDef : CaseDef ): Value =
11801192 evalPattern(scrutinee, caseDef.pat)
11811193 eval(caseDef.guard, thisV, klass)
@@ -1206,18 +1218,59 @@ object Objects:
12061218 case UnApply (fun, implicits, pats) =>
12071219 val fun1 = funPart(fun)
12081220 val funRef = fun1.tpe.asInstanceOf [TermRef ]
1221+ val unapplyResTp = funRef.widen.finalResultType
1222+
1223+ val receiver = evalType(funRef.prefix, thisV, klass)
1224+ val implicitValues = evalArgs(implicits.map(Arg .apply), thisV, klass)
1225+ // TODO: implicit values may appear before and/or after the scrutinee parameter.
1226+ val unapplyRes = call(receiver, funRef.symbol, TraceValue (scrutinee, summon[Trace ]) :: implicitValues, funRef.prefix, superType = NoType , needResolve = true )
1227+
12091228 if fun.symbol.name == nme.unapplySeq then
1210- // TODO: handle unapplySeq
1211- ()
1229+ var resultTp = unapplyResTp
1230+ var elemTp = unapplySeqTypeElemTp(resultTp)
1231+ var arity = productArity(resultTp, NoSourcePosition )
1232+ var needsGet = false
1233+ if (! elemTp.exists && arity <= 0 ) {
1234+ needsGet = true
1235+ resultTp = resultTp.select(nme.get).finalResultType
1236+ elemTp = unapplySeqTypeElemTp(resultTp.widen)
1237+ arity = productSelectorTypes(resultTp, NoSourcePosition ).size
1238+ }
1239+
1240+ var resToMatch = unapplyRes
1241+
1242+ if needsGet then
1243+ // Get match
1244+ val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless)
1245+ call(unapplyRes, isEmptyDenot.symbol, Nil , unapplyResTp, superType = NoType , needResolve = true )
1246+
1247+ val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless)
1248+ resToMatch = call(unapplyRes, getDenot.symbol, Nil , unapplyResTp, superType = NoType , needResolve = true )
1249+ end if
1250+
1251+ if elemTp.exists then
1252+ // sequence match
1253+ evalSeqPatterns(resToMatch, resultTp, elemTp, pats)
1254+ else
1255+ // product sequence match
1256+ val selectors = productSelectors(resultTp)
1257+ assert(selectors.length <= pats.length)
1258+ selectors.init.zip(pats).map { (sel, pat) =>
1259+ val selectRes = call(resToMatch, sel, Nil , resultTp, superType = NoType , needResolve = true )
1260+ evalPattern(selectRes, pat)
1261+ }
1262+ val seqPats = pats.drop(selectors.length - 1 )
1263+ val toSeqRes = call(resToMatch, selectors.last, Nil , resultTp, superType = NoType , needResolve = true )
1264+ val toSeqResTp = resultTp.memberInfo(selectors.last).finalResultType
1265+ evalSeqPatterns(toSeqRes, toSeqResTp, elemTp, seqPats)
1266+ end if
1267+
12121268 else
1213- val receiver = evalType(funRef.prefix, thisV, klass)
1214- val implicitValues = evalArgs(implicits.map(Arg .apply), thisV, klass)
1215- val unapplyRes = call(receiver, funRef.symbol, TraceValue (scrutinee, summon[Trace ]) :: implicitValues, funRef.prefix, superType = NoType , needResolve = true )
12161269 // distribute unapply to patterns
1217- val unapplyResTp = funRef.widen.finalResultType
12181270 if isProductMatch(unapplyResTp, pats.length) then
12191271 // product match
1220- val selectors = productSelectors(unapplyResTp).take(pats.length)
1272+ val selectors = productSelectors(unapplyResTp)
1273+ assert(selectors.length == pats.length)
12211274 selectors.zip(pats).map { (sel, pat) =>
12221275 val selectRes = call(unapplyRes, sel, Nil , unapplyResTp, superType = NoType , needResolve = true )
12231276 evalPattern(selectRes, pat)
@@ -1239,7 +1292,7 @@ object Objects:
12391292 val getResTp = getDenot.info.finalResultType
12401293 val selectors = productSelectors(getResTp).take(pats.length)
12411294 selectors.zip(pats).map { (sel, pat) =>
1242- val selectRes = call(unapplyRes, sel, Nil , unapplyResTp , superType = NoType , needResolve = true )
1295+ val selectRes = call(unapplyRes, sel, Nil , getResTp , superType = NoType , needResolve = true )
12431296 evalPattern(selectRes, pat)
12441297 }
12451298 end if
@@ -1259,6 +1312,42 @@ object Objects:
12591312
12601313 end evalPattern
12611314
1315+ /**
1316+ * Evaluate a sequence value against sequence patterns.
1317+ */
1318+ def evalSeqPatterns (scrutinee : Value , scrutineeType : Type , elemType : Type , pats : List [Tree ]): Unit =
1319+ // call .lengthCompare or .length
1320+ val lengthCompareDenot = getMemberMethod(scrutineeType, nme.lengthCompare, lengthCompareType)
1321+ if lengthCompareDenot.exists then
1322+ call(scrutinee, lengthCompareDenot.symbol, TraceValue (Bottom , summon[Trace ]) :: Nil , scrutineeType, superType = NoType , needResolve = true )
1323+ else
1324+ val lengthDenot = getMemberMethod(scrutineeType, nme.length, lengthType)
1325+ call(scrutinee, lengthDenot.symbol, Nil , scrutineeType, superType = NoType , needResolve = true )
1326+ end if
1327+
1328+ // call .apply
1329+ val applyDenot = getMemberMethod(scrutineeType, nme.apply, applyType(elemType))
1330+ val applyRes = call(scrutinee, applyDenot.symbol, TraceValue (Bottom , summon[Trace ]) :: Nil , scrutineeType, superType = NoType , needResolve = true )
1331+
1332+ if isWildcardStarArg(pats.last) then
1333+ if pats.size == 1 then
1334+ // call .toSeq
1335+ val toSeqDenot = scrutineeType.member(nme.toSeq).suchThat(_.info.isParameterless)
1336+ val toSeqRes = call(scrutinee, toSeqDenot.symbol, Nil , scrutineeType, superType = NoType , needResolve = true )
1337+ evalPattern(toSeqRes, pats.head)
1338+ else
1339+ // call .drop
1340+ val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType))
1341+ val dropRes = call(scrutinee, dropDenot.symbol, TraceValue (Bottom , summon[Trace ]) :: Nil , scrutineeType, superType = NoType , needResolve = true )
1342+ for pat <- pats.init do evalPattern(applyRes, pat)
1343+ evalPattern(dropRes, pats.last)
1344+ end if
1345+ else
1346+ // no patterns like `xs*`
1347+ for pat <- pats do evalPattern(applyRes, pat)
1348+ end evalSeqPatterns
1349+
1350+
12621351 cases.map(evalCase).join
12631352
12641353
0 commit comments