@@ -29,12 +29,12 @@ object ConstFold:
2929 def Apply [T <: Apply ](tree : T )(using Context ): T =
3030 tree.fun match
3131 case Select (xt, op) if foldedBinops.contains(op) =>
32- xt.tpe.widenTermRefExpr.normalized match
33- case ConstantType (x) =>
32+ xt match
33+ case ConstantTree (x) =>
3434 tree.args match
3535 case yt :: Nil =>
36- yt.tpe.widenTermRefExpr.normalized match
37- case ConstantType (y) => tree.withFoldedType(foldBinop(op, x, y))
36+ yt match
37+ case ConstantTree (y) => tree.withFoldedType(foldBinop(op, x, y))
3838 case _ => tree
3939 case _ => tree
4040 case _ => tree
@@ -46,8 +46,8 @@ object ConstFold:
4646
4747 def Select [T <: Select ](tree : T )(using Context ): T =
4848 if foldedUnops.contains(tree.name) then
49- tree.qualifier.tpe.widenTermRefExpr.normalized match
50- case ConstantType (x) => tree.withFoldedType(foldUnop(tree.name, x))
49+ tree.qualifier match
50+ case ConstantTree (x) => tree.withFoldedType(foldUnop(tree.name, x))
5151 case _ => tree
5252 else tree
5353
@@ -59,6 +59,17 @@ object ConstFold:
5959 tree.withFoldedType(Constant (targ.tpe))
6060 case _ => tree
6161
62+ private object ConstantTree :
63+ def unapply (tree : Tree )(using Context ): Option [Constant ] =
64+ tree match
65+ case Inlined (_, Nil , expr) => unapply(expr)
66+ case Typed (expr, _) => unapply(expr)
67+ case Literal (c) if c.tag == Constants .NullTag => Some (c)
68+ case _ =>
69+ tree.tpe.widenTermRefExpr.normalized.simplified match
70+ case ConstantType (c) => Some (c)
71+ case _ => None
72+
6273 extension [T <: Tree ](tree : T )(using Context )
6374 private def withFoldedType (c : Constant | Null ): T =
6475 if c == null then tree else tree.withType(ConstantType (c)).asInstanceOf [T ]
@@ -164,15 +175,24 @@ object ConstFold:
164175 case _ => null
165176 }
166177 private def foldStringOp (op : Name , x : Constant , y : Constant ): Constant = op match {
167- case nme.ADD => Constant (x.stringValue + y.stringValue)
178+ case nme.ADD => Constant (x.stringValue + y.stringValue)
168179 case nme.EQ => Constant (x.stringValue == y.stringValue)
180+ case nme.NE => Constant (x.stringValue != y.stringValue)
169181 case _ => null
170182 }
171183
184+ private def foldNullOp (op : Name , x : Constant , y : Constant ): Constant =
185+ assert(x.tag == NullTag || y.tag == NullTag )
186+ op match
187+ case nme.EQ => Constant (x.tag == y.tag)
188+ case nme.NE => Constant (x.tag != y.tag)
189+ case _ => null
190+
172191 private def foldBinop (op : Name , x : Constant , y : Constant ): Constant =
173192 val optag =
174193 if (x.tag == y.tag) x.tag
175194 else if (x.isNumeric && y.isNumeric) math.max(x.tag, y.tag)
195+ else if (x.tag == NullTag || y.tag == NullTag ) NullTag
176196 else NoTag
177197
178198 try optag match
@@ -182,6 +202,7 @@ object ConstFold:
182202 case FloatTag => foldFloatOp(op, x, y)
183203 case DoubleTag => foldDoubleOp(op, x, y)
184204 case StringTag => foldStringOp(op, x, y)
205+ case NullTag => foldNullOp(op, x, y)
185206 case _ => null
186207 catch case ex : ArithmeticException => null // the code will crash at runtime,
187208 // but that is better than the
0 commit comments