@@ -3,6 +3,7 @@ package backend
33package jvm
44
55import scala .annotation .switch
6+ import scala .collection .mutable .SortedMap
67
78import scala .tools .asm
89import scala .tools .asm .{Handle , Label , Opcodes }
@@ -840,61 +841,170 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
840841 generatedType
841842 }
842843
843- /*
844- * A Match node contains one or more case clauses,
845- * each case clause lists one or more Int values to use as keys, and a code block.
846- * Except the "default" case clause which (if it exists) doesn't list any Int key.
847- *
848- * On a first pass over the case clauses, we flatten the keys and their targets (the latter represented with asm.Labels).
849- * That representation allows JCodeMethodV to emit a lookupswitch or a tableswitch.
850- *
851- * On a second pass, we emit the switch blocks, one for each different target.
844+ /* A Match node contains one or more case clauses, each case clause lists one or more
845+ * Int/String values to use as keys, and a code block. The exception is the "default" case
846+ * clause which doesn't list any key (there is exactly one of these per match).
852847 */
853848 private def genMatch (tree : Match ): BType = tree match {
854849 case Match (selector, cases) =>
855850 lineNumber(tree)
856- genLoad(selector, INT )
857851 val generatedType = tpeTK(tree)
852+ val postMatch = new asm.Label
858853
859- var flatKeys : List [Int ] = Nil
860- var targets : List [asm.Label ] = Nil
861- var default : asm.Label = null
862- var switchBlocks : List [(asm.Label , Tree )] = Nil
863-
864- // collect switch blocks and their keys, but don't emit yet any switch-block.
865- for (caze @ CaseDef (pat, guard, body) <- cases) {
866- assert(guard == tpd.EmptyTree , guard)
867- val switchBlockPoint = new asm.Label
868- switchBlocks ::= (switchBlockPoint, body)
869- pat match {
870- case Literal (value) =>
871- flatKeys ::= value.intValue
872- targets ::= switchBlockPoint
873- case Ident (nme.WILDCARD ) =>
874- assert(default == null , s " multiple default targets in a Match node, at ${tree.span}" )
875- default = switchBlockPoint
876- case Alternative (alts) =>
877- alts foreach {
878- case Literal (value) =>
879- flatKeys ::= value.intValue
880- targets ::= switchBlockPoint
881- case _ =>
882- abort(s " Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}" )
883- }
884- case _ =>
885- abort(s " Invalid pattern in Match node: $tree at: ${tree.span}" )
854+ // Only two possible selector types exist in `Match` trees at this point: Int and String
855+ if (tpeTK(selector) == INT ) {
856+
857+ /* On a first pass over the case clauses, we flatten the keys and their
858+ * targets (the latter represented with asm.Labels). That representation
859+ * allows JCodeMethodV to emit a lookupswitch or a tableswitch.
860+ *
861+ * On a second pass, we emit the switch blocks, one for each different target.
862+ */
863+
864+ var flatKeys : List [Int ] = Nil
865+ var targets : List [asm.Label ] = Nil
866+ var default : asm.Label = null
867+ var switchBlocks : List [(asm.Label , Tree )] = Nil
868+
869+ genLoad(selector, INT )
870+
871+ // collect switch blocks and their keys, but don't emit yet any switch-block.
872+ for (caze @ CaseDef (pat, guard, body) <- cases) {
873+ assert(guard == tpd.EmptyTree , guard)
874+ val switchBlockPoint = new asm.Label
875+ switchBlocks ::= (switchBlockPoint, body)
876+ pat match {
877+ case Literal (value) =>
878+ flatKeys ::= value.intValue
879+ targets ::= switchBlockPoint
880+ case Ident (nme.WILDCARD ) =>
881+ assert(default == null , s " multiple default targets in a Match node, at ${tree.span}" )
882+ default = switchBlockPoint
883+ case Alternative (alts) =>
884+ alts foreach {
885+ case Literal (value) =>
886+ flatKeys ::= value.intValue
887+ targets ::= switchBlockPoint
888+ case _ =>
889+ abort(s " Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}" )
890+ }
891+ case _ =>
892+ abort(s " Invalid pattern in Match node: $tree at: ${tree.span}" )
893+ }
886894 }
887- }
888895
889- bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY )
896+ bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY )
890897
891- // emit switch-blocks.
892- val postMatch = new asm.Label
893- for (sb <- switchBlocks.reverse) {
894- val (caseLabel, caseBody) = sb
895- markProgramPoint(caseLabel)
896- genLoad(caseBody, generatedType)
897- bc goTo postMatch
898+ // emit switch-blocks.
899+ for (sb <- switchBlocks.reverse) {
900+ val (caseLabel, caseBody) = sb
901+ markProgramPoint(caseLabel)
902+ genLoad(caseBody, generatedType)
903+ bc goTo postMatch
904+ }
905+ } else {
906+
907+ /* Since the JVM doesn't have a way to switch on a string, we switch
908+ * on the `hashCode` of the string then do an `equals` check (with a
909+ * possible second set of jumps if blocks can be reached from multiple
910+ * string alternatives).
911+ *
912+ * This mirrors the way that Java compiles `switch` on Strings.
913+ */
914+
915+ var default : asm.Label = null
916+ var indirectBlocks : List [(asm.Label , Tree )] = Nil
917+
918+ import scala .collection .mutable
919+
920+ // Cases grouped by their hashCode
921+ val casesByHash = SortedMap .empty[Int , List [(String , Either [asm.Label , Tree ])]]
922+ var caseFallback : Tree = null
923+
924+ for (caze @ CaseDef (pat, guard, body) <- cases) {
925+ assert(guard == tpd.EmptyTree , guard)
926+ pat match {
927+ case Literal (value) =>
928+ val strValue = value.stringValue
929+ casesByHash.updateWith(strValue.## ) { existingCasesOpt =>
930+ val newCase = (strValue, Right (body))
931+ Some (newCase :: existingCasesOpt.getOrElse(Nil ))
932+ }
933+ case Ident (nme.WILDCARD ) =>
934+ assert(default == null , s " multiple default targets in a Match node, at ${tree.span}" )
935+ default = new asm.Label
936+ indirectBlocks ::= (default, body)
937+ case Alternative (alts) =>
938+ // We need an extra basic block since multiple strings can lead to this code
939+ val indirectCaseGroupLabel = new asm.Label
940+ indirectBlocks ::= (indirectCaseGroupLabel, body)
941+ alts foreach {
942+ case Literal (value) =>
943+ val strValue = value.stringValue
944+ casesByHash.updateWith(strValue.## ) { existingCasesOpt =>
945+ val newCase = (strValue, Left (indirectCaseGroupLabel))
946+ Some (newCase :: existingCasesOpt.getOrElse(Nil ))
947+ }
948+ case _ =>
949+ abort(s " Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}" )
950+ }
951+
952+ case _ =>
953+ abort(s " Invalid pattern in Match node: $tree at: ${tree.span}" )
954+ }
955+ }
956+
957+ // Organize the hashCode options into switch cases
958+ var flatKeys : List [Int ] = Nil
959+ var targets : List [asm.Label ] = Nil
960+ var hashBlocks : List [(asm.Label , List [(String , Either [asm.Label , Tree ])])] = Nil
961+ for ((hashValue, hashCases) <- casesByHash) {
962+ val switchBlockPoint = new asm.Label
963+ hashBlocks ::= (switchBlockPoint, hashCases)
964+ flatKeys ::= hashValue
965+ targets ::= switchBlockPoint
966+ }
967+
968+ // Push the hashCode of the string (or `0` if it is `null`) onto the stack and switch on it
969+ genLoadIf(
970+ If (
971+ tree.selector.select(defn.Any_== ).appliedTo(nullLiteral),
972+ Literal (Constant (0 )),
973+ tree.selector.select(defn.Any_hashCode ).appliedToNone
974+ ),
975+ INT
976+ )
977+ bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY )
978+
979+ // emit blocks for each hash case
980+ for ((hashLabel, caseAlternatives) <- hashBlocks.reverse) {
981+ markProgramPoint(hashLabel)
982+ for ((caseString, indirectLblOrBody) <- caseAlternatives) {
983+ val comparison = if (caseString == null ) defn.Any_== else defn.Any_equals
984+ val condp = Literal (Constant (caseString)).select(defn.Any_== ).appliedTo(tree.selector)
985+ val keepGoing = new asm.Label
986+ indirectLblOrBody match {
987+ case Left (jump) =>
988+ genCond(condp, jump, keepGoing, targetIfNoJump = keepGoing)
989+
990+ case Right (caseBody) =>
991+ val thisCaseMatches = new asm.Label
992+ genCond(condp, thisCaseMatches, keepGoing, targetIfNoJump = thisCaseMatches)
993+ markProgramPoint(thisCaseMatches)
994+ genLoad(caseBody, generatedType)
995+ bc goTo postMatch
996+ }
997+ markProgramPoint(keepGoing)
998+ }
999+ bc goTo default
1000+ }
1001+
1002+ // emit blocks for common patterns
1003+ for ((caseLabel, caseBody) <- indirectBlocks.reverse) {
1004+ markProgramPoint(caseLabel)
1005+ genLoad(caseBody, generatedType)
1006+ bc goTo postMatch
1007+ }
8981008 }
8991009
9001010 markProgramPoint(postMatch)
0 commit comments