@@ -4196,6 +4196,11 @@ object Types {
41964196 case tycon : TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
41974197 extension (tp : Type ) def fixForEvaluation : Type =
41984198 tp.normalized.dealias match {
4199+ // enable operations for constant singleton terms. E.g.:
4200+ // ```
4201+ // final val one = 1
4202+ // type Two = one.type + one.type
4203+ // ```
41994204 case tp : TermRef => tp.underlying
42004205 case tp => tp
42014206 }
@@ -4234,163 +4239,170 @@ object Types {
42344239 case ConstantType (Constant (n : String )) => Some (n)
42354240 case _ => None
42364241 }
4237- def isConst : Option [Type ] = args.head.fixForEvaluation match {
4242+
4243+ def isConst (tp : Type ) : Option [Type ] = tp.fixForEvaluation match {
42384244 case ConstantType (_) => Some (ConstantType (Constant (true )))
42394245 case _ => Some (ConstantType (Constant (false )))
42404246 }
4247+
4248+ def expectArgsNum (expectedNum : Int ) : Unit =
4249+ // We can use assert instead of a compiler type error because this error should not
4250+ // occur since the type signature of the operation enforces the proper number of args.
4251+ assert(args.length == expectedNum, s " Type operation expects $expectedNum arguments but found ${args.length}" )
4252+
42414253 def natValue (tp : Type ): Option [Int ] = intValue(tp).filter(n => n >= 0 && n < Int .MaxValue )
42424254
4255+ // Runs the op and returns the result as a constant type.
4256+ // If the op throws an exception, then this exception is converted into a type error.
4257+ def runConstantOp (op : => Any ): Type =
4258+ val result = try {
4259+ op
4260+ } catch {
4261+ case e : Throwable =>
4262+ throw new TypeError (e.getMessage)
4263+ }
4264+ ConstantType (Constant (result))
4265+
42434266 def constantFold1 [T ](extractor : Type => Option [T ], op : T => Any ): Option [Type ] =
4244- extractor(args.head).map(a => ConstantType (Constant (op(a))))
4267+ expectArgsNum(1 )
4268+ extractor(args.head).map(a => runConstantOp(op(a)))
42454269
42464270 def constantFold2 [T ](extractor : Type => Option [T ], op : (T , T ) => Any ): Option [Type ] =
42474271 constantFold2AB(extractor, extractor, op)
42484272
42494273 def constantFold2AB [TA , TB ](extractorA : Type => Option [TA ], extractorB : Type => Option [TB ], op : (TA , TB ) => Any ): Option [Type ] =
4274+ expectArgsNum(2 )
42504275 for {
4251- a <- extractorA(args.head )
4252- b <- extractorB(args.last )
4253- } yield ConstantType ( Constant ( op(a, b) ))
4276+ a <- extractorA(args( 0 ) )
4277+ b <- extractorB(args( 1 ) )
4278+ } yield runConstantOp( op(a, b))
42544279
42554280 def constantFold3 [TA , TB , TC ](
42564281 extractorA : Type => Option [TA ],
42574282 extractorB : Type => Option [TB ],
42584283 extractorC : Type => Option [TC ],
42594284 op : (TA , TB , TC ) => Any
42604285 ): Option [Type ] =
4286+ expectArgsNum(3 )
42614287 for {
4262- a <- extractorA(args.head )
4288+ a <- extractorA(args( 0 ) )
42634289 b <- extractorB(args(1 ))
4264- c <- extractorC(args.last )
4265- } yield ConstantType ( Constant ( op(a, b, c) ))
4290+ c <- extractorC(args( 2 ) )
4291+ } yield runConstantOp( op(a, b, c))
42664292
42674293 trace(i " compiletime constant fold $this" , typr, show = true ) {
42684294 val name = tycon.symbol.name
42694295 val owner = tycon.symbol.owner
4270- val nArgs = args.length
42714296 val constantType =
42724297 if (defn.isCompiletime_S(tycon.symbol)) {
4273- if (nArgs == 1 ) constantFold1(natValue, _ + 1 )
4274- else None
4298+ constantFold1(natValue, _ + 1 )
42754299 } else if (owner == defn.CompiletimeOpsAnyModuleClass ) name match {
4276- case tpnme.Equals if nArgs == 2 => constantFold2(constValue, _ == _)
4277- case tpnme.NotEquals if nArgs == 2 => constantFold2(constValue, _ != _)
4278- case tpnme.ToString if nArgs == 1 => constantFold1(constValue, _.toString)
4279- case tpnme.IsConst if nArgs == 1 => isConst
4300+ case tpnme.Equals => constantFold2(constValue, _ == _)
4301+ case tpnme.NotEquals => constantFold2(constValue, _ != _)
4302+ case tpnme.ToString => constantFold1(constValue, _.toString)
4303+ case tpnme.IsConst => isConst(args.head)
42804304 case _ => None
42814305 } else if (owner == defn.CompiletimeOpsIntModuleClass ) name match {
4282- case tpnme.Abs if nArgs == 1 => constantFold1(intValue, _.abs)
4283- case tpnme.Negate if nArgs == 1 => constantFold1(intValue, x => - x)
4306+ case tpnme.Abs => constantFold1(intValue, _.abs)
4307+ case tpnme.Negate => constantFold1(intValue, x => - x)
42844308 // ToString is deprecated for ops.int, and moved to ops.any
4285- case tpnme.ToString if nArgs == 1 => constantFold1(intValue, _.toString)
4286- case tpnme.Plus if nArgs == 2 => constantFold2(intValue, _ + _)
4287- case tpnme.Minus if nArgs == 2 => constantFold2(intValue, _ - _)
4288- case tpnme.Times if nArgs == 2 => constantFold2(intValue, _ * _)
4289- case tpnme.Div if nArgs == 2 => constantFold2(intValue, {
4290- case (_, 0 ) => throw new TypeError (" Division by 0" )
4291- case (a, b) => a / b
4292- })
4293- case tpnme.Mod if nArgs == 2 => constantFold2(intValue, {
4294- case (_, 0 ) => throw new TypeError (" Modulo by 0" )
4295- case (a, b) => a % b
4296- })
4297- case tpnme.Lt if nArgs == 2 => constantFold2(intValue, _ < _)
4298- case tpnme.Gt if nArgs == 2 => constantFold2(intValue, _ > _)
4299- case tpnme.Ge if nArgs == 2 => constantFold2(intValue, _ >= _)
4300- case tpnme.Le if nArgs == 2 => constantFold2(intValue, _ <= _)
4301- case tpnme.Xor if nArgs == 2 => constantFold2(intValue, _ ^ _)
4302- case tpnme.BitwiseAnd if nArgs == 2 => constantFold2(intValue, _ & _)
4303- case tpnme.BitwiseOr if nArgs == 2 => constantFold2(intValue, _ | _)
4304- case tpnme.ASR if nArgs == 2 => constantFold2(intValue, _ >> _)
4305- case tpnme.LSL if nArgs == 2 => constantFold2(intValue, _ << _)
4306- case tpnme.LSR if nArgs == 2 => constantFold2(intValue, _ >>> _)
4307- case tpnme.Min if nArgs == 2 => constantFold2(intValue, _ min _)
4308- case tpnme.Max if nArgs == 2 => constantFold2(intValue, _ max _)
4309- case tpnme.NumberOfLeadingZeros if nArgs == 1 => constantFold1(intValue, Integer .numberOfLeadingZeros(_))
4310- case tpnme.ToLong if nArgs == 1 => constantFold1(intValue, _.toLong)
4311- case tpnme.ToFloat if nArgs == 1 => constantFold1(intValue, _.toFloat)
4312- case tpnme.ToDouble if nArgs == 1 => constantFold1(intValue, _.toDouble)
4309+ case tpnme.ToString => constantFold1(intValue, _.toString)
4310+ case tpnme.Plus => constantFold2(intValue, _ + _)
4311+ case tpnme.Minus => constantFold2(intValue, _ - _)
4312+ case tpnme.Times => constantFold2(intValue, _ * _)
4313+ case tpnme.Div => constantFold2(intValue, _ / _)
4314+ case tpnme.Mod => constantFold2(intValue, _ % _)
4315+ case tpnme.Lt => constantFold2(intValue, _ < _)
4316+ case tpnme.Gt => constantFold2(intValue, _ > _)
4317+ case tpnme.Ge => constantFold2(intValue, _ >= _)
4318+ case tpnme.Le => constantFold2(intValue, _ <= _)
4319+ case tpnme.Xor => constantFold2(intValue, _ ^ _)
4320+ case tpnme.BitwiseAnd => constantFold2(intValue, _ & _)
4321+ case tpnme.BitwiseOr => constantFold2(intValue, _ | _)
4322+ case tpnme.ASR => constantFold2(intValue, _ >> _)
4323+ case tpnme.LSL => constantFold2(intValue, _ << _)
4324+ case tpnme.LSR => constantFold2(intValue, _ >>> _)
4325+ case tpnme.Min => constantFold2(intValue, _ min _)
4326+ case tpnme.Max => constantFold2(intValue, _ max _)
4327+ case tpnme.NumberOfLeadingZeros => constantFold1(intValue, Integer .numberOfLeadingZeros(_))
4328+ case tpnme.ToLong => constantFold1(intValue, _.toLong)
4329+ case tpnme.ToFloat => constantFold1(intValue, _.toFloat)
4330+ case tpnme.ToDouble => constantFold1(intValue, _.toDouble)
43134331 case _ => None
43144332 } else if (owner == defn.CompiletimeOpsLongModuleClass ) name match {
4315- case tpnme.Abs if nArgs == 1 => constantFold1(longValue, _.abs)
4316- case tpnme.Negate if nArgs == 1 => constantFold1(longValue, x => - x)
4317- case tpnme.Plus if nArgs == 2 => constantFold2(longValue, _ + _)
4318- case tpnme.Minus if nArgs == 2 => constantFold2(longValue, _ - _)
4319- case tpnme.Times if nArgs == 2 => constantFold2(longValue, _ * _)
4320- case tpnme.Div if nArgs == 2 => constantFold2(longValue, {
4321- case (_, 0L ) => throw new TypeError (" Division by 0" )
4322- case (a, b) => a / b
4323- })
4324- case tpnme.Mod if nArgs == 2 => constantFold2(longValue, {
4325- case (_, 0L ) => throw new TypeError (" Modulo by 0" )
4326- case (a, b) => a % b
4327- })
4328- case tpnme.Lt if nArgs == 2 => constantFold2(longValue, _ < _)
4329- case tpnme.Gt if nArgs == 2 => constantFold2(longValue, _ > _)
4330- case tpnme.Ge if nArgs == 2 => constantFold2(longValue, _ >= _)
4331- case tpnme.Le if nArgs == 2 => constantFold2(longValue, _ <= _)
4332- case tpnme.Xor if nArgs == 2 => constantFold2(longValue, _ ^ _)
4333- case tpnme.BitwiseAnd if nArgs == 2 => constantFold2(longValue, _ & _)
4334- case tpnme.BitwiseOr if nArgs == 2 => constantFold2(longValue, _ | _)
4335- case tpnme.ASR if nArgs == 2 => constantFold2(longValue, _ >> _)
4336- case tpnme.LSL if nArgs == 2 => constantFold2(longValue, _ << _)
4337- case tpnme.LSR if nArgs == 2 => constantFold2(longValue, _ >>> _)
4338- case tpnme.Min if nArgs == 2 => constantFold2(longValue, _ min _)
4339- case tpnme.Max if nArgs == 2 => constantFold2(longValue, _ max _)
4340- case tpnme.NumberOfLeadingZeros if nArgs == 1 =>
4333+ case tpnme.Abs => constantFold1(longValue, _.abs)
4334+ case tpnme.Negate => constantFold1(longValue, x => - x)
4335+ case tpnme.Plus => constantFold2(longValue, _ + _)
4336+ case tpnme.Minus => constantFold2(longValue, _ - _)
4337+ case tpnme.Times => constantFold2(longValue, _ * _)
4338+ case tpnme.Div => constantFold2(longValue, _ / _)
4339+ case tpnme.Mod => constantFold2(longValue, _ % _)
4340+ case tpnme.Lt => constantFold2(longValue, _ < _)
4341+ case tpnme.Gt => constantFold2(longValue, _ > _)
4342+ case tpnme.Ge => constantFold2(longValue, _ >= _)
4343+ case tpnme.Le => constantFold2(longValue, _ <= _)
4344+ case tpnme.Xor => constantFold2(longValue, _ ^ _)
4345+ case tpnme.BitwiseAnd => constantFold2(longValue, _ & _)
4346+ case tpnme.BitwiseOr => constantFold2(longValue, _ | _)
4347+ case tpnme.ASR => constantFold2(longValue, _ >> _)
4348+ case tpnme.LSL => constantFold2(longValue, _ << _)
4349+ case tpnme.LSR => constantFold2(longValue, _ >>> _)
4350+ case tpnme.Min => constantFold2(longValue, _ min _)
4351+ case tpnme.Max => constantFold2(longValue, _ max _)
4352+ case tpnme.NumberOfLeadingZeros =>
43414353 constantFold1(longValue, java.lang.Long .numberOfLeadingZeros(_))
4342- case tpnme.ToInt if nArgs == 1 => constantFold1(longValue, _.toInt)
4343- case tpnme.ToFloat if nArgs == 1 => constantFold1(longValue, _.toFloat)
4344- case tpnme.ToDouble if nArgs == 1 => constantFold1(longValue, _.toDouble)
4354+ case tpnme.ToInt => constantFold1(longValue, _.toInt)
4355+ case tpnme.ToFloat => constantFold1(longValue, _.toFloat)
4356+ case tpnme.ToDouble => constantFold1(longValue, _.toDouble)
43454357 case _ => None
43464358 } else if (owner == defn.CompiletimeOpsFloatModuleClass ) name match {
4347- case tpnme.Abs if nArgs == 1 => constantFold1(floatValue, _.abs)
4348- case tpnme.Negate if nArgs == 1 => constantFold1(floatValue, x => - x)
4349- case tpnme.Plus if nArgs == 2 => constantFold2(floatValue, _ + _)
4350- case tpnme.Minus if nArgs == 2 => constantFold2(floatValue, _ - _)
4351- case tpnme.Times if nArgs == 2 => constantFold2(floatValue, _ * _)
4352- case tpnme.Div if nArgs == 2 => constantFold2(floatValue, _ / _)
4353- case tpnme.Mod if nArgs == 2 => constantFold2(floatValue, _ % _)
4354- case tpnme.Lt if nArgs == 2 => constantFold2(floatValue, _ < _)
4355- case tpnme.Gt if nArgs == 2 => constantFold2(floatValue, _ > _)
4356- case tpnme.Ge if nArgs == 2 => constantFold2(floatValue, _ >= _)
4357- case tpnme.Le if nArgs == 2 => constantFold2(floatValue, _ <= _)
4358- case tpnme.Min if nArgs == 2 => constantFold2(floatValue, _ min _)
4359- case tpnme.Max if nArgs == 2 => constantFold2(floatValue, _ max _)
4360- case tpnme.ToInt if nArgs == 1 => constantFold1(floatValue, _.toInt)
4361- case tpnme.ToLong if nArgs == 1 => constantFold1(floatValue, _.toLong)
4362- case tpnme.ToDouble if nArgs == 1 => constantFold1(floatValue, _.toDouble)
4359+ case tpnme.Abs => constantFold1(floatValue, _.abs)
4360+ case tpnme.Negate => constantFold1(floatValue, x => - x)
4361+ case tpnme.Plus => constantFold2(floatValue, _ + _)
4362+ case tpnme.Minus => constantFold2(floatValue, _ - _)
4363+ case tpnme.Times => constantFold2(floatValue, _ * _)
4364+ case tpnme.Div => constantFold2(floatValue, _ / _)
4365+ case tpnme.Mod => constantFold2(floatValue, _ % _)
4366+ case tpnme.Lt => constantFold2(floatValue, _ < _)
4367+ case tpnme.Gt => constantFold2(floatValue, _ > _)
4368+ case tpnme.Ge => constantFold2(floatValue, _ >= _)
4369+ case tpnme.Le => constantFold2(floatValue, _ <= _)
4370+ case tpnme.Min => constantFold2(floatValue, _ min _)
4371+ case tpnme.Max => constantFold2(floatValue, _ max _)
4372+ case tpnme.ToInt => constantFold1(floatValue, _.toInt)
4373+ case tpnme.ToLong => constantFold1(floatValue, _.toLong)
4374+ case tpnme.ToDouble => constantFold1(floatValue, _.toDouble)
43634375 case _ => None
43644376 } else if (owner == defn.CompiletimeOpsDoubleModuleClass ) name match {
4365- case tpnme.Abs if nArgs == 1 => constantFold1(doubleValue, _.abs)
4366- case tpnme.Negate if nArgs == 1 => constantFold1(doubleValue, x => - x)
4367- case tpnme.Plus if nArgs == 2 => constantFold2(doubleValue, _ + _)
4368- case tpnme.Minus if nArgs == 2 => constantFold2(doubleValue, _ - _)
4369- case tpnme.Times if nArgs == 2 => constantFold2(doubleValue, _ * _)
4370- case tpnme.Div if nArgs == 2 => constantFold2(doubleValue, _ / _)
4371- case tpnme.Mod if nArgs == 2 => constantFold2(doubleValue, _ % _)
4372- case tpnme.Lt if nArgs == 2 => constantFold2(doubleValue, _ < _)
4373- case tpnme.Gt if nArgs == 2 => constantFold2(doubleValue, _ > _)
4374- case tpnme.Ge if nArgs == 2 => constantFold2(doubleValue, _ >= _)
4375- case tpnme.Le if nArgs == 2 => constantFold2(doubleValue, _ <= _)
4376- case tpnme.Min if nArgs == 2 => constantFold2(doubleValue, _ min _)
4377- case tpnme.Max if nArgs == 2 => constantFold2(doubleValue, _ max _)
4378- case tpnme.ToInt if nArgs == 1 => constantFold1(doubleValue, _.toInt)
4379- case tpnme.ToLong if nArgs == 1 => constantFold1(doubleValue, _.toLong)
4380- case tpnme.ToFloat if nArgs == 1 => constantFold1(doubleValue, _.toFloat)
4377+ case tpnme.Abs => constantFold1(doubleValue, _.abs)
4378+ case tpnme.Negate => constantFold1(doubleValue, x => - x)
4379+ case tpnme.Plus => constantFold2(doubleValue, _ + _)
4380+ case tpnme.Minus => constantFold2(doubleValue, _ - _)
4381+ case tpnme.Times => constantFold2(doubleValue, _ * _)
4382+ case tpnme.Div => constantFold2(doubleValue, _ / _)
4383+ case tpnme.Mod => constantFold2(doubleValue, _ % _)
4384+ case tpnme.Lt => constantFold2(doubleValue, _ < _)
4385+ case tpnme.Gt => constantFold2(doubleValue, _ > _)
4386+ case tpnme.Ge => constantFold2(doubleValue, _ >= _)
4387+ case tpnme.Le => constantFold2(doubleValue, _ <= _)
4388+ case tpnme.Min => constantFold2(doubleValue, _ min _)
4389+ case tpnme.Max => constantFold2(doubleValue, _ max _)
4390+ case tpnme.ToInt => constantFold1(doubleValue, _.toInt)
4391+ case tpnme.ToLong => constantFold1(doubleValue, _.toLong)
4392+ case tpnme.ToFloat => constantFold1(doubleValue, _.toFloat)
43814393 case _ => None
43824394 } else if (owner == defn.CompiletimeOpsStringModuleClass ) name match {
4383- case tpnme.Plus if nArgs == 2 => constantFold2(stringValue, _ + _)
4384- case tpnme.Length if nArgs == 1 => constantFold1(stringValue, _.length)
4385- case tpnme.Matches if nArgs == 2 => constantFold2(stringValue, _ matches _)
4386- case tpnme.Substring if nArgs == 3 =>
4395+ case tpnme.Plus => constantFold2(stringValue, _ + _)
4396+ case tpnme.Length => constantFold1(stringValue, _.length)
4397+ case tpnme.Matches => constantFold2(stringValue, _ matches _)
4398+ case tpnme.Substring =>
43874399 constantFold3(stringValue, intValue, intValue, (s, b, e) => s.substring(b, e))
43884400 case _ => None
43894401 } else if (owner == defn.CompiletimeOpsBooleanModuleClass ) name match {
4390- case tpnme.Not if nArgs == 1 => constantFold1(boolValue, x => ! x)
4391- case tpnme.And if nArgs == 2 => constantFold2(boolValue, _ && _)
4392- case tpnme.Or if nArgs == 2 => constantFold2(boolValue, _ || _)
4393- case tpnme.Xor if nArgs == 2 => constantFold2(boolValue, _ ^ _)
4402+ case tpnme.Not => constantFold1(boolValue, x => ! x)
4403+ case tpnme.And => constantFold2(boolValue, _ && _)
4404+ case tpnme.Or => constantFold2(boolValue, _ || _)
4405+ case tpnme.Xor => constantFold2(boolValue, _ ^ _)
43944406 case _ => None
43954407 } else None
43964408
0 commit comments