@@ -29,12 +29,12 @@ object ConstFold:
2929 def Apply [T <: Apply ](tree : T )(using Context ): T =
3030 tree.fun match
3131 case Select (xt, op) if foldedBinops.contains(op) =>
32- xt.tpe.widenTermRefExpr.normalized match
33- case ConstantType (x) =>
32+ treeConstant(xt) match
33+ case Some (x) =>
3434 tree.args match
3535 case yt :: Nil =>
36- yt.tpe.widenTermRefExpr.normalized match
37- case ConstantType (y) => tree.withFoldedType(foldBinop(op, x, y))
36+ treeConstant(yt) match
37+ case Some (y) => tree.withFoldedType(foldBinop(op, x, y))
3838 case _ => tree
3939 case _ => tree
4040 case _ => tree
@@ -46,8 +46,8 @@ object ConstFold:
4646
4747 def Select [T <: Select ](tree : T )(using Context ): T =
4848 if foldedUnops.contains(tree.name) then
49- tree.qualifier.tpe.widenTermRefExpr.normalized match
50- case ConstantType (x) => tree.withFoldedType(foldUnop(tree.name, x))
49+ treeConstant( tree.qualifier) match
50+ case Some (x) => tree.withFoldedType(foldUnop(tree.name, x))
5151 case _ => tree
5252 else tree
5353
@@ -59,6 +59,16 @@ object ConstFold:
5959 tree.withFoldedType(Constant (targ.tpe))
6060 case _ => tree
6161
62+ private def treeConstant (tree : Tree )(using Context ): Option [Constant ] =
63+ tree match
64+ case Inlined (_, Nil , expr) => treeConstant(expr)
65+ case Typed (expr, _) => treeConstant(expr)
66+ case Literal (c) if c.tag == Constants .NullTag => Some (c)
67+ case _ =>
68+ tree.tpe.widenTermRefExpr.normalized.simplified match
69+ case ConstantType (c) => Some (c)
70+ case _ => None
71+
6272 extension [T <: Tree ](tree : T )(using Context )
6373 private def withFoldedType (c : Constant | Null ): T =
6474 if c == null then tree else tree.withType(ConstantType (c)).asInstanceOf [T ]
@@ -164,15 +174,24 @@ object ConstFold:
164174 case _ => null
165175 }
166176 private def foldStringOp (op : Name , x : Constant , y : Constant ): Constant = op match {
167- case nme.ADD => Constant (x.stringValue + y.stringValue)
177+ case nme.ADD => Constant (x.stringValue + y.stringValue)
168178 case nme.EQ => Constant (x.stringValue == y.stringValue)
179+ case nme.NE => Constant (x.stringValue != y.stringValue)
169180 case _ => null
170181 }
171182
183+ private def foldNullOp (op : Name , x : Constant , y : Constant ): Constant =
184+ assert(x.tag == NullTag || y.tag == NullTag )
185+ op match
186+ case nme.EQ => Constant (x.tag == y.tag)
187+ case nme.NE => Constant (x.tag != y.tag)
188+ case _ => null
189+
172190 private def foldBinop (op : Name , x : Constant , y : Constant ): Constant =
173191 val optag =
174192 if (x.tag == y.tag) x.tag
175193 else if (x.isNumeric && y.isNumeric) math.max(x.tag, y.tag)
194+ else if (x.tag == NullTag || y.tag == NullTag ) NullTag
176195 else NoTag
177196
178197 try optag match
@@ -182,6 +201,7 @@ object ConstFold:
182201 case FloatTag => foldFloatOp(op, x, y)
183202 case DoubleTag => foldDoubleOp(op, x, y)
184203 case StringTag => foldStringOp(op, x, y)
204+ case NullTag => foldNullOp(op, x, y)
185205 case _ => null
186206 catch case ex : ArithmeticException => null // the code will crash at runtime,
187207 // but that is better than the
0 commit comments