@@ -31,11 +31,23 @@ object Matcher {
3131 * @return None if it did not match, `Some(tup)` if it matched where `tup` contains `Expr[Ti]``
3232 */
3333 def unapply [Tup <: Tuple ](scrutineeExpr : Expr [_])(implicit patternExpr : Expr [_], reflection : Reflection ): Option [Tup ] = {
34+ // TODO improve performance
3435 import reflection .{Bind => BindPattern , _ }
36+ import Matching ._
3537
3638 type Env = Set [(Symbol , Symbol )]
3739
38- // TODO improve performance
40+ inline def withEnv [T ](env : Env )(body : => given Env => T ): T = body given env
41+
42+ /** Check that all trees match with =#= and concatenate the results with && */
43+ def (scrutinees : List [Tree ]) =##= (patterns : List [Tree ]) given Env : Matching = {
44+ def rec (l1 : List [Tree ], l2 : List [Tree ]): Matching = (l1, l2) match {
45+ case (x :: xs, y :: ys) => x =#= y && rec(xs, ys)
46+ case (Nil , Nil ) => matched
47+ case _ => notMatched
48+ }
49+ rec(scrutinees, patterns)
50+ }
3951
4052 /** Check that the trees match and return the contents from the pattern holes.
4153 * Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes.
@@ -45,7 +57,7 @@ object Matcher {
4557 * @param `the[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
4658 * @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
4759 */
48- def treeMatches (scrutinee : Tree , pattern : Tree ) given Env : Option [ Tuple ] = {
60+ def (scrutinee : Tree ) =#= ( pattern : Tree ) given Env : Matching = {
4961
5062 /** Check that both are `val` or both are `lazy val` or both are `var` **/
5163 def checkValFlags (): Boolean = {
@@ -56,7 +68,7 @@ object Matcher {
5668 }
5769
5870 def bindingMatch (sym : Symbol ) =
59- Some ( Tuple1 ( new Bind (sym.name, sym) ))
71+ matched( new Bind (sym.name, sym))
6072
6173 def hasBindTypeAnnotation (tpt : TypeTree ): Boolean = tpt match {
6274 case Annotated (tpt2, Apply (Select (New (TypeIdent (" patternBindHole" )), " <init>" ), Nil )) => true
@@ -67,10 +79,6 @@ object Matcher {
6779 def hasBindAnnotation (sym : Symbol ) =
6880 sym.annots.exists { case Apply (Select (New (TypeIdent (" patternBindHole" ))," <init>" ),List ()) => true ; case _ => true }
6981
70- def treesMatch (scrutinees : List [Tree ], patterns : List [Tree ]): Option [Tuple ] =
71- if (scrutinees.size != patterns.size) None
72- else foldMatchings(scrutinees.zip(patterns).map(treeMatches): _* )
73-
7482 /** Normalieze the tree */
7583 def normalize (tree : Tree ): Tree = tree match {
7684 case Block (Nil , expr) => normalize(expr)
@@ -85,126 +93,130 @@ object Matcher {
8593 if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole &&
8694 s.tpe <:< tpt.tpe &&
8795 tpt2.tpe.derivesFrom(definitions.RepeatedParamClass ) =>
88- Some ( Tuple1 ( scrutinee.seal) )
96+ matched( scrutinee.seal)
8997
9098 // Match a scala.internal.Quoted.patternHole and return the scrutinee tree
9199 case (IsTerm (scrutinee), TypeApply (patternHole, tpt :: Nil ))
92100 if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole &&
93101 scrutinee.tpe <:< tpt.tpe =>
94- Some ( Tuple1 ( scrutinee.seal) )
102+ matched( scrutinee.seal)
95103
96104 //
97105 // Match two equivalent trees
98106 //
99107
100108 case (Literal (constant1), Literal (constant2)) if constant1 == constant2 =>
101- Some (())
109+ matched
102110
103111 case (Typed (expr1, tpt1), Typed (expr2, tpt2)) =>
104- foldMatchings(treeMatches( expr1, expr2), treeMatches( tpt1, tpt2))
112+ expr1 =#= expr2 && tpt1 =#= tpt2
105113
106114 case (Ident (_), Ident (_)) if scrutinee.symbol == pattern.symbol || the[Env ].apply((scrutinee.symbol, pattern.symbol)) =>
107- Some (())
115+ matched
108116
109117 case (Select (qual1, _), Select (qual2, _)) if scrutinee.symbol == pattern.symbol =>
110- treeMatches( qual1, qual2)
118+ qual1 =#= qual2
111119
112120 case (IsRef (_), IsRef (_)) if scrutinee.symbol == pattern.symbol =>
113- Some (())
121+ matched
114122
115123 case (Apply (fn1, args1), Apply (fn2, args2)) if fn1.symbol == fn2.symbol =>
116- foldMatchings(treeMatches( fn1, fn2), treesMatch( args1, args2))
124+ fn1 =#= fn2 && args1 =##= args2
117125
118126 case (TypeApply (fn1, args1), TypeApply (fn2, args2)) if fn1.symbol == fn2.symbol =>
119- foldMatchings(treeMatches( fn1, fn2), treesMatch( args1, args2))
127+ fn1 =#= fn2 && args1 =##= args2
120128
121129 case (Block (stats1, expr1), Block (stats2, expr2)) =>
122- foldMatchings(treesMatch(stats1, stats2), treeMatches(expr1, expr2))
130+ withEnv(the[Env ] ++ stats1.map(_.symbol).zip(stats2.map(_.symbol))) {
131+ stats1 =##= stats2 && expr1 =#= expr2
132+ }
123133
124134 case (If (cond1, thenp1, elsep1), If (cond2, thenp2, elsep2)) =>
125- foldMatchings(treeMatches( cond1, cond2), treeMatches( thenp1, thenp2), treeMatches( elsep1, elsep2))
135+ cond1 =#= cond2 && thenp1 =#= thenp2 && elsep1 =#= elsep2
126136
127137 case (Assign (lhs1, rhs1), Assign (lhs2, rhs2)) =>
128138 val lhsMatch =
129- if (treeMatches (lhs1, lhs2).isDefined) Some (())
130- else None
131- foldMatchings( lhsMatch, treeMatches( rhs1, rhs2))
139+ if ((lhs1 =#= lhs2).isMatch) matched
140+ else notMatched
141+ lhsMatch && rhs1 =#= rhs2
132142
133143 case (While (cond1, body1), While (cond2, body2)) =>
134- foldMatchings(treeMatches( cond1, cond2), treeMatches( body1, body2))
144+ cond1 =#= cond2 && body1 =#= body2
135145
136146 case (NamedArg (name1, expr1), NamedArg (name2, expr2)) if name1 == name2 =>
137- treeMatches( expr1, expr2)
147+ expr1 =#= expr2
138148
139149 case (New (tpt1), New (tpt2)) =>
140- treeMatches( tpt1, tpt2)
150+ tpt1 =#= tpt2
141151
142152 case (This (_), This (_)) if scrutinee.symbol == pattern.symbol =>
143- Some (())
153+ matched
144154
145155 case (Super (qual1, mix1), Super (qual2, mix2)) if mix1 == mix2 =>
146- treeMatches( qual1, qual2)
156+ qual1 =#= qual2
147157
148158 case (Repeated (elems1, _), Repeated (elems2, _)) if elems1.size == elems2.size =>
149- treesMatch( elems1, elems2)
159+ elems1 =##= elems2
150160
151161 case (IsTypeTree (scrutinee @ TypeIdent (_)), IsTypeTree (pattern @ TypeIdent (_))) if scrutinee.symbol == pattern.symbol =>
152- Some (())
162+ matched
153163
154164 case (IsInferred (scrutinee), IsInferred (pattern)) if scrutinee.tpe <:< pattern.tpe =>
155- Some (())
165+ matched
156166
157167 case (Applied (tycon1, args1), Applied (tycon2, args2)) =>
158- foldMatchings(treeMatches( tycon1, tycon2), treesMatch( args1, args2))
168+ tycon1 =#= tycon2 && args1 =##= args2
159169
160170 case (ValDef (_, tpt1, rhs1), ValDef (_, tpt2, rhs2)) if checkValFlags() =>
161171 val bindMatch =
162172 if (hasBindAnnotation(pattern.symbol) || hasBindTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol)
163- else Some (())
164- val returnTptMatch = treeMatches( tpt1, tpt2)
173+ else matched
174+ val returnTptMatch = tpt1 =#= tpt2
165175 val rhsEnv = the[Env ] + (scrutinee.symbol -> pattern.symbol)
166176 val rhsMatchings = treeOptMatches(rhs1, rhs2) given rhsEnv
167- foldMatchings( bindMatch, returnTptMatch, rhsMatchings)
177+ bindMatch && returnTptMatch && rhsMatchings
168178
169179 case (DefDef (_, typeParams1, paramss1, tpt1, Some (rhs1)), DefDef (_, typeParams2, paramss2, tpt2, Some (rhs2))) =>
170- val typeParmasMatch = treesMatch( typeParams1, typeParams2)
180+ val typeParmasMatch = typeParams1 =##= typeParams2
171181 val paramssMatch =
172- if (paramss1.size != paramss2.size) None
173- else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => treesMatch( params1, params2) }: _* )
182+ if (paramss1.size != paramss2.size) notMatched
183+ else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => params1 =##= params2 }: _* )
174184 val bindMatch =
175185 if (hasBindAnnotation(pattern.symbol)) bindingMatch(scrutinee.symbol)
176- else Some (())
177- val tptMatch = treeMatches( tpt1, tpt2)
186+ else matched
187+ val tptMatch = tpt1 =#= tpt2
178188 val rhsEnv =
179189 the[Env ] + (scrutinee.symbol -> pattern.symbol) ++
180190 typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
181191 paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
182- val rhsMatch = treeMatches (rhs1, rhs2) given rhsEnv
192+ val rhsMatch = (rhs1 =#= rhs2) given rhsEnv
183193
184- foldMatchings( bindMatch, typeParmasMatch, paramssMatch, tptMatch, rhsMatch)
194+ bindMatch && typeParmasMatch && paramssMatch && tptMatch && rhsMatch
185195
186196 case (Lambda (_, tpt1), Lambda (_, tpt2)) =>
187197 // TODO match tpt1 with tpt2?
188- Some (())
198+ matched
189199
190200 case (Match (scru1, cases1), Match (scru2, cases2)) =>
191- val scrutineeMacth = treeMatches( scru1, scru2)
201+ val scrutineeMacth = scru1 =#= scru2
192202 val casesMatch =
193- if (cases1.size != cases2.size) None
203+ if (cases1.size != cases2.size) notMatched
194204 else foldMatchings(cases1.zip(cases2).map(caseMatches): _* )
195- foldMatchings( scrutineeMacth, casesMatch)
205+ scrutineeMacth && casesMatch
196206
197207 case (Try (body1, cases1, finalizer1), Try (body2, cases2, finalizer2)) =>
198- val bodyMacth = treeMatches( body1, body2)
208+ val bodyMacth = body1 =#= body2
199209 val casesMatch =
200- if (cases1.size != cases2.size) None
210+ if (cases1.size != cases2.size) notMatched
201211 else foldMatchings(cases1.zip(cases2).map(caseMatches): _* )
202212 val finalizerMatch = treeOptMatches(finalizer1, finalizer2)
203- foldMatchings( bodyMacth, casesMatch, finalizerMatch)
213+ bodyMacth && casesMatch && finalizerMatch
204214
205215 // Ignore type annotations
206- case (Annotated (tpt, _), _) => treeMatches(tpt, pattern)
207- case (_, Annotated (tpt, _)) => treeMatches(scrutinee, tpt)
216+ case (Annotated (tpt, _), _) =>
217+ tpt =#= pattern
218+ case (_, Annotated (tpt, _)) =>
219+ scrutinee =#= tpt
208220
209221 // No Match
210222 case _ =>
@@ -225,26 +237,24 @@ object Matcher {
225237 |
226238 |
227239 | """ .stripMargin)
228- None
240+ notMatched
229241 }
230242 }
231243
232- def treeOptMatches (scrutinee : Option [Tree ], pattern : Option [Tree ]) given Env : Option [ Tuple ] = {
244+ def treeOptMatches (scrutinee : Option [Tree ], pattern : Option [Tree ]) given Env : Matching = {
233245 (scrutinee, pattern) match {
234- case (Some (x), Some (y)) => treeMatches(x, y)
235- case (None , None ) => Some (())
236- case _ => None
246+ case (Some (x), Some (y)) => x =#= y
247+ case (None , None ) => matched
248+ case _ => notMatched
237249 }
238250 }
239251
240- def caseMatches (scrutinee : CaseDef , pattern : CaseDef ) given Env : Option [Tuple ] = {
241- val (caseEnv, patternMatch) = patternMatches(scrutinee.pattern, pattern.pattern)
242-
243- {
244- implied for Env = caseEnv
252+ def caseMatches (scrutinee : CaseDef , pattern : CaseDef ) given Env : Matching = {
253+ val (caseEnv, patternMatch) = scrutinee.pattern =%= pattern.pattern
254+ withEnv(caseEnv) {
245255 val guardMatch = treeOptMatches(scrutinee.guard, pattern.guard)
246- val rhsMatch = treeMatches( scrutinee.rhs, pattern.rhs)
247- foldMatchings( patternMatch, guardMatch, rhsMatch)
256+ val rhsMatch = scrutinee.rhs =#= pattern.rhs
257+ patternMatch && guardMatch && rhsMatch
248258 }
249259 }
250260
@@ -258,34 +268,34 @@ object Matcher {
258268 * @return The new environment containing the bindings defined in this pattern tuppled with
259269 * `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
260270 */
261- def patternMatches (scrutinee : Pattern , pattern : Pattern ) given Env : (Env , Option [ Tuple ] ) = (scrutinee, pattern) match {
271+ def (scrutinee : Pattern ) =%= ( pattern : Pattern ) given Env : (Env , Matching ) = (scrutinee, pattern) match {
262272 case (Pattern .Value (v1), Pattern .Unapply (TypeApply (Select (patternHole @ Ident (" patternHole" ), " unapply" ), List (tpt)), Nil , Nil ))
263273 if patternHole.symbol.owner.fullName == " scala.runtime.quoted.Matcher$" =>
264- (the[Env ], Some ( Tuple1 ( v1.seal) ))
274+ (the[Env ], matched( v1.seal))
265275
266276 case (Pattern .Value (v1), Pattern .Value (v2)) =>
267- (the[Env ], treeMatches(v1, v2) )
277+ (the[Env ], v1 =#= v2 )
268278
269279 case (Pattern .Bind (name1, body1), Pattern .Bind (name2, body2)) =>
270280 val bindEnv = the[Env ] + (scrutinee.symbol -> pattern.symbol)
271- patternMatches (body1, body2) given bindEnv
281+ (body1 =%= body2) given bindEnv
272282
273283 case (Pattern .Unapply (fun1, implicits1, patterns1), Pattern .Unapply (fun2, implicits2, patterns2)) =>
274- val funMatch = treeMatches( fun1, fun2)
284+ val funMatch = fun1 =#= fun2
275285 val implicitsMatch =
276- if (implicits1.size != implicits2.size) None
277- else foldMatchings(implicits1.zip(implicits2).map(treeMatches ): _* )
286+ if (implicits1.size != implicits2.size) notMatched
287+ else foldMatchings(implicits1.zip(implicits2).map((i1, i2) => i1 =#= i2 ): _* )
278288 val (patEnv, patternsMatch) = foldPatterns(patterns1, patterns2)
279- (patEnv, foldMatchings( funMatch, implicitsMatch, patternsMatch) )
289+ (patEnv, funMatch && implicitsMatch && patternsMatch)
280290
281291 case (Pattern .Alternatives (patterns1), Pattern .Alternatives (patterns2)) =>
282292 foldPatterns(patterns1, patterns2)
283293
284294 case (Pattern .TypeTest (tpt1), Pattern .TypeTest (tpt2)) =>
285- (the[Env ], treeMatches( tpt1, tpt2) )
295+ (the[Env ], tpt1 =#= tpt2)
286296
287297 case (Pattern .WildcardPattern (), Pattern .WildcardPattern ()) =>
288- (the[Env ], Some (()) )
298+ (the[Env ], matched )
289299
290300 case _ =>
291301 if (debug)
@@ -305,30 +315,57 @@ object Matcher {
305315 |
306316 |
307317 | """ .stripMargin)
308- (the[Env ], None )
318+ (the[Env ], notMatched )
309319 }
310320
311- def foldPatterns (patterns1 : List [Pattern ], patterns2 : List [Pattern ]) given Env : (Env , Option [ Tuple ] ) = {
312- if (patterns1.size != patterns2.size) (the[Env ], None )
313- else patterns1.zip(patterns2).foldLeft((the[Env ], Option [ Tuple ](()) )) { (acc, x) =>
314- val (env, res) = patternMatches (x._1, x._2) given acc ._1
315- (env, foldMatchings( acc._2, res) )
321+ def foldPatterns (patterns1 : List [Pattern ], patterns2 : List [Pattern ]) given Env : (Env , Matching ) = {
322+ if (patterns1.size != patterns2.size) (the[Env ], notMatched )
323+ else patterns1.zip(patterns2).foldLeft((the[Env ], matched )) { (acc, x) =>
324+ val (env, res) = (x._1 =%= x._2) given acc ._1
325+ (env, acc._2 && res)
316326 }
317327 }
318328
319329 implied for Env = Set .empty
320- treeMatches (scrutineeExpr.unseal, patternExpr.unseal).asInstanceOf [Option [Tup ]]
330+ (scrutineeExpr.unseal =#= patternExpr.unseal).asOptionOfTuple .asInstanceOf [Option [Tup ]]
321331 }
322332
323- /** Joins the mattchings into a single matching. If any matching is `None` the result is `None`.
324- * Otherwise the result is `Some` of the concatenation of the tupples.
325- */
326- private def foldMatchings (matchings : Option [Tuple ]* ): Option [Tuple ] = {
327- // TODO improve performance
328- matchings.foldLeft[Option [Tuple ]](Some (())) {
329- case (Some (acc), Some (holes)) => Some (acc ++ holes)
330- case (_, _) => None
333+ /** Result of matching a part of an expression */
334+ private opaque type Matching = Option [Tuple ]
335+
336+ private object Matching {
337+
338+ def notMatched : Matching = None
339+ val matched : Matching = Some (())
340+ def matched (x : Any ): Matching = Some (Tuple1 (x))
341+
342+ def (self : Matching ) asOptionOfTuple : Option [Tuple ] = self
343+
344+ /** Concatenates the contents of two sucessful matchings or return a `notMatched` */
345+ // FIXME inline to avoid alocation of by name closure (see #6395)
346+ /* inline*/ def (self : Matching ) && (that : => Matching ): Matching = self match {
347+ case Some (x) =>
348+ that match {
349+ case Some (y) => Some (x ++ y)
350+ case _ => None
351+ }
352+ case _ => None
331353 }
354+
355+ /** Is this matching the result of a successful match */
356+ def (self : Matching ) isMatch : Boolean = self.isDefined
357+
358+ /** Joins the mattchings into a single matching. If any matching is `None` the result is `None`.
359+ * Otherwise the result is `Some` of the concatenation of the tupples.
360+ */
361+ def foldMatchings (matchings : Matching * ): Matching = {
362+ // TODO improve performance
363+ matchings.foldLeft[Matching ](Some (())) {
364+ case (Some (acc), Some (holes)) => Some (acc ++ holes)
365+ case (_, _) => None
366+ }
367+ }
368+
332369 }
333370
334371}
0 commit comments