|
1 | 1 | package dotty.tools.dotc |
2 | 2 | package transform |
3 | 3 |
|
4 | | -import TreeTransforms.{ MiniPhaseTransform, TransformerInfo } |
5 | 4 | import ast.Trees._, ast.tpd, core._ |
6 | 5 | import Contexts.Context, Types._, Decorators._, Symbols._, DenotTransformers._ |
7 | 6 | import SymDenotations._, Scopes._, StdNames._, NameOps._, Names._ |
| 7 | +import MegaPhase.MiniPhase |
8 | 8 |
|
9 | 9 | import scala.collection.mutable |
10 | 10 |
|
11 | 11 | /** Specializes classes that inherit from `FunctionN` where there exists a |
12 | 12 | * specialized form. |
13 | 13 | */ |
14 | | -class SpecializeFunctions extends MiniPhaseTransform with InfoTransformer { |
| 14 | +class SpecializeFunctions extends MiniPhase with InfoTransformer { |
15 | 15 | import ast.tpd._ |
16 | 16 | val phaseName = "specializeFunctions" |
| 17 | + override def runsAfter = Set(classOf[ElimByName]) |
17 | 18 |
|
18 | | - private[this] var _blacklistedSymbols: List[Symbol] = _ |
| 19 | + private val jFunction = "scala.compat.java8.JFunction".toTermName |
19 | 20 |
|
20 | | - private def blacklistedSymbols(implicit ctx: Context): List[Symbol] = { |
21 | | - if (_blacklistedSymbols eq null) _blacklistedSymbols = List( |
22 | | - ctx.getClassIfDefined("scala.math.Ordering").asClass.membersNamed("Ops".toTypeName).first.symbol |
23 | | - ) |
24 | | - |
25 | | - _blacklistedSymbols |
26 | | - } |
27 | | - |
28 | | - /** Transforms the type to include decls for specialized applys and replace |
29 | | - * the class parents with specialized versions. |
30 | | - */ |
31 | | - def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match { |
32 | | - case tp: ClassInfo if !sym.is(Flags.Package) && (tp.decls ne EmptyScope) => { |
| 21 | + /** Transforms the type to include decls for specialized applys */ |
| 22 | + override def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match { |
| 23 | + case tp: ClassInfo if !sym.is(Flags.Package) && (tp.decls ne EmptyScope) && derivesFromFn012(sym) => |
33 | 24 | var newApplys = Map.empty[Name, Symbol] |
34 | 25 |
|
35 | | - val newParents = tp.parents.mapConserve { parent => |
36 | | - List(0, 1, 2, 3).flatMap { arity => |
37 | | - val func = defn.FunctionClass(arity) |
38 | | - if (!parent.derivesFrom(func)) Nil |
39 | | - else { |
40 | | - val typeParams = tp.typeRef.baseArgInfos(func) |
41 | | - val interface = specInterface(typeParams) |
42 | | - |
43 | | - if (interface.exists) { |
44 | | - if (tp.decls.lookup(nme.apply).exists) { |
45 | | - val specializedMethodName = nme.apply.specializedFunction(typeParams.last, typeParams.init) |
46 | | - newApplys = newApplys + (specializedMethodName -> interface) |
47 | | - } |
| 26 | + var arity = 0 |
| 27 | + while (arity < 3) { |
| 28 | + val func = defn.FunctionClass(arity) |
| 29 | + if (tp.derivesFrom(func)) { |
| 30 | + val typeParams = tp.cls.typeRef.baseType(func).argInfos |
| 31 | + val isSpecializable = |
| 32 | + defn.isSpecializableFunction( |
| 33 | + sym.asClass, |
| 34 | + typeParams.init, |
| 35 | + typeParams.last |
| 36 | + ) |
48 | 37 |
|
49 | | - if (parent.isRef(func)) List(interface.typeRef) |
50 | | - else Nil |
51 | | - } |
52 | | - else Nil |
| 38 | + if (isSpecializable && tp.decls.lookup(nme.apply).exists) { |
| 39 | + val interface = specInterface(typeParams) |
| 40 | + val specializedMethodName = nme.apply.specializedFunction(typeParams.last, typeParams.init) |
| 41 | + newApplys += (specializedMethodName -> interface) |
53 | 42 | } |
54 | 43 | } |
55 | | - .headOption |
56 | | - .getOrElse(parent) |
| 44 | + arity += 1 |
57 | 45 | } |
58 | 46 |
|
59 | 47 | def newDecls = |
60 | | - if (newApplys.isEmpty) tp.decls |
61 | | - else |
62 | | - newApplys.toList.map { case (name, interface) => |
63 | | - ctx.newSymbol( |
64 | | - sym, |
65 | | - name, |
66 | | - Flags.Override | Flags.Method, |
67 | | - interface.info.decls.lookup(name).info |
68 | | - ) |
69 | | - } |
70 | | - .foldLeft(tp.decls.cloneScope) { |
71 | | - (scope, sym) => scope.enter(sym); scope |
72 | | - } |
| 48 | + newApplys.toList.map { case (name, interface) => |
| 49 | + ctx.newSymbol( |
| 50 | + sym, |
| 51 | + name, |
| 52 | + Flags.Override | Flags.Method | Flags.Synthetic, |
| 53 | + interface.info.decls.lookup(name).info |
| 54 | + ) |
| 55 | + } |
| 56 | + .foldLeft(tp.decls.cloneScope) { |
| 57 | + (scope, sym) => scope.enter(sym); scope |
| 58 | + } |
73 | 59 |
|
74 | | - tp.derivedClassInfo( |
75 | | - classParents = newParents, |
76 | | - decls = newDecls |
77 | | - ) |
78 | | - } |
| 60 | + if (newApplys.isEmpty) tp |
| 61 | + else tp.derivedClassInfo(decls = newDecls) |
79 | 62 |
|
80 | 63 | case _ => tp |
81 | 64 | } |
82 | 65 |
|
83 | 66 | /** Transforms the `Template` of the classes to contain forwarders from the |
84 | | - * generic applys to the specialized ones. Also replaces parents of the |
85 | | - * class on the tree level and inserts the specialized applys in the |
86 | | - * template body. |
| 67 | + * generic applys to the specialized ones. Also inserts the specialized applys |
| 68 | + * in the template body. |
87 | 69 | */ |
88 | | - override def transformTemplate(tree: Template)(implicit ctx: Context, info: TransformerInfo) = { |
89 | | - val applyBuf = new mutable.ListBuffer[Tree] |
90 | | - val newBody = tree.body.mapConserve { |
91 | | - case dt: DefDef if dt.name == nme.apply && dt.vparamss.length == 1 => { |
92 | | - val specName = nme.apply.specializedFunction( |
93 | | - dt.tpe.widen.finalResultType, |
94 | | - dt.vparamss.head.map(_.symbol.info) |
95 | | - ) |
96 | | - |
97 | | - val specializedApply = tree.symbol.enclosingClass.info.decls.lookup(specName)//member(specName).symbol |
98 | | - //val specializedApply = tree.symbol.enclosingClass.info.member(specName).symbol |
99 | | - |
100 | | - if (false) { |
101 | | - println(tree.symbol.enclosingClass.show) |
102 | | - println("'" + specName.show + "'") |
103 | | - println(specializedApply) |
104 | | - println(specializedApply.exists) |
105 | | - } |
106 | | - |
107 | | - |
108 | | - if (specializedApply.exists) { |
109 | | - val apply = specializedApply.asTerm |
110 | | - val specializedDecl = |
111 | | - polyDefDef(apply, trefs => vrefss => { |
112 | | - dt.rhs |
113 | | - .changeOwner(dt.symbol, apply) |
114 | | - .subst(dt.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol)) |
| 70 | + override def transformTemplate(tree: Template)(implicit ctx: Context) = { |
| 71 | + val cls = tree.symbol.enclosingClass.asClass |
| 72 | + if (derivesFromFn012(cls)) { |
| 73 | + val applyBuf = new mutable.ListBuffer[Tree] |
| 74 | + val newBody = tree.body.mapConserve { |
| 75 | + case dt: DefDef if dt.name == nme.apply && dt.vparamss.length == 1 => |
| 76 | + val typeParams = dt.vparamss.head.map(_.symbol.info) |
| 77 | + val retType = dt.tpe.widen.finalResultType |
| 78 | + |
| 79 | + val specName = specializedName(nme.apply, typeParams :+ retType) |
| 80 | + val specializedApply = cls.info.decls.lookup(specName) |
| 81 | + if (specializedApply.exists) { |
| 82 | + val apply = specializedApply.asTerm |
| 83 | + val specializedDecl = |
| 84 | + polyDefDef(apply, trefs => vrefss => { |
| 85 | + dt.rhs |
| 86 | + .changeOwner(dt.symbol, apply) |
| 87 | + .subst(dt.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol)) |
| 88 | + }) |
| 89 | + applyBuf += specializedDecl |
| 90 | + |
| 91 | + // create a forwarding to the specialized apply |
| 92 | + cpy.DefDef(dt)(rhs = { |
| 93 | + tpd |
| 94 | + .ref(apply) |
| 95 | + .appliedToArgs(dt.vparamss.head.map(vparam => ref(vparam.symbol))) |
115 | 96 | }) |
116 | | - applyBuf += specializedDecl |
117 | | - |
118 | | - // create a forwarding to the specialized apply |
119 | | - cpy.DefDef(dt)(rhs = { |
120 | | - tpd |
121 | | - .ref(apply) |
122 | | - .appliedToArgs(dt.vparamss.head.map(vparam => ref(vparam.symbol))) |
123 | | - }) |
124 | | - } else dt |
125 | | - } |
126 | | - case x => x |
127 | | - } |
128 | | - |
129 | | - val missing: List[TypeTree] = List(0, 1, 2, 3).flatMap { arity => |
130 | | - val func = defn.FunctionClass(arity) |
131 | | - val tr = tree.symbol.enclosingClass.typeRef |
| 97 | + } else dt |
132 | 98 |
|
133 | | - if (!tr.parents.exists(_.isRef(func))) Nil |
134 | | - else { |
135 | | - val typeParams = tr.baseArgInfos(func) |
136 | | - val interface = specInterface(typeParams) |
137 | | - |
138 | | - if (interface.exists) List(interface.info) |
139 | | - else Nil |
| 99 | + case x => x |
140 | 100 | } |
141 | | - }.map(TypeTree) |
142 | 101 |
|
143 | | - cpy.Template(tree)( |
144 | | - parents = tree.parents ++ missing, |
145 | | - body = applyBuf.toList ++ newBody |
146 | | - ) |
| 102 | + cpy.Template(tree)( |
| 103 | + body = applyBuf.toList ::: newBody |
| 104 | + ) |
| 105 | + } else tree |
147 | 106 | } |
148 | 107 |
|
149 | 108 | /** Dispatch to specialized `apply`s in user code when available */ |
150 | | - override def transformApply(tree: Apply)(implicit ctx: Context, info: TransformerInfo) = |
| 109 | + override def transformApply(tree: Apply)(implicit ctx: Context) = |
151 | 110 | tree match { |
152 | | - case app @ Apply(fun, args) |
| 111 | + case Apply(fun, args) |
153 | 112 | if fun.symbol.name == nme.apply && |
154 | 113 | fun.symbol.owner.derivesFrom(defn.FunctionClass(args.length)) |
155 | | - => { |
| 114 | + => |
156 | 115 | val params = (fun.tpe.widen.firstParamTypes :+ tree.tpe).map(_.widenSingleton.dealias) |
157 | | - val specializedApply = specializedName(nme.apply, params) |
158 | | - |
159 | | - if (!params.exists(_.isInstanceOf[ExprType]) && fun.symbol.owner.info.decls.lookup(specializedApply).exists) { |
| 116 | + val isSpecializable = |
| 117 | + defn.isSpecializableFunction( |
| 118 | + fun.symbol.owner.asClass, |
| 119 | + params.init, |
| 120 | + params.last) |
| 121 | + |
| 122 | + if (isSpecializable && !params.exists(_.isInstanceOf[ExprType])) { |
| 123 | + val specializedApply = specializedName(nme.apply, params) |
160 | 124 | val newSel = fun match { |
161 | 125 | case Select(qual, _) => |
162 | 126 | qual.select(specializedApply) |
163 | | - case _ => { |
| 127 | + case _ => |
164 | 128 | (fun.tpe: @unchecked) match { |
165 | 129 | case TermRef(prefix: ThisType, name) => |
166 | 130 | tpd.This(prefix.cls).select(specializedApply) |
167 | 131 | case TermRef(prefix: NamedType, name) => |
168 | 132 | tpd.ref(prefix).select(specializedApply) |
169 | 133 | } |
170 | | - } |
171 | 134 | } |
172 | 135 |
|
173 | 136 | newSel.appliedToArgs(args) |
174 | 137 | } |
175 | 138 | else tree |
176 | | - } |
| 139 | + |
177 | 140 | case _ => tree |
178 | 141 | } |
179 | 142 |
|
180 | | - @inline private def specializedName(name: Name, args: List[Type])(implicit ctx: Context) = |
181 | | - name.specializedFor(args, args.map(_.typeSymbol.name), Nil, Nil) |
| 143 | + private def specializedName(name: Name, args: List[Type])(implicit ctx: Context) = |
| 144 | + name.specializedFunction(args.last, args.init) |
182 | 145 |
|
183 | | - @inline private def specInterface(typeParams: List[Type])(implicit ctx: Context) = { |
184 | | - val specName = |
185 | | - ("JFunction" + (typeParams.length - 1)).toTermName |
186 | | - .specializedFunction(typeParams.last, typeParams.init) |
| 146 | + private def functionName(typeParams: List[Type])(implicit ctx: Context) = |
| 147 | + jFunction ++ (typeParams.length - 1).toString |
187 | 148 |
|
188 | | - ctx.getClassIfDefined("scala.compat.java8.".toTermName ++ specName) |
189 | | - } |
| 149 | + private def specInterface(typeParams: List[Type])(implicit ctx: Context) = |
| 150 | + ctx.getClassIfDefined(functionName(typeParams).specializedFunction(typeParams.last, typeParams.init)) |
| 151 | + |
| 152 | + private def derivesFromFn012(sym: Symbol)(implicit ctx: Context): Boolean = |
| 153 | + sym.derivesFrom(defn.FunctionClass(0)) || |
| 154 | + sym.derivesFrom(defn.FunctionClass(1)) || |
| 155 | + sym.derivesFrom(defn.FunctionClass(2)) |
190 | 156 | } |
0 commit comments