@@ -2,30 +2,40 @@ package dotty.tools.dotc
22package ast
33
44import core ._
5- import Symbols ._ , Types ._ , Contexts ._ , Flags ._ , Constants ._
6- import StdNames .nme
7-
8- /** Generate proxy classes for @main functions.
9- * A function like
10- *
11- * @main def f(x: S, ys: T*) = ...
12- *
13- * would be translated to something like
14- *
15- * import CommandLineParser._
16- * class f {
17- * @static def main(args: Array[String]): Unit =
18- * try
19- * f(
20- * parseArgument[S](args, 0),
21- * parseRemainingArguments[T](args, 1): _*
22- * )
23- * catch case err: ParseError => showError(err)
24- * }
25- */
5+ import Symbols ._ , Types ._ , Contexts ._ , Decorators ._ , util .Spans ._ , Flags ._ , Constants ._
6+ import StdNames .{nme , tpnme }
7+ import ast .Trees ._
8+ import Names .Name
9+ import Comments .Comment
10+ import NameKinds .DefaultGetterName
11+ import Annotations .Annotation
12+
2613object MainProxies {
2714
28- def mainProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
15+ /** Generate proxy classes for @main functions and @myMain functions where myMain <:< MainAnnotation */
16+ def proxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
17+ mainAnnotationProxies(stats) ++ mainProxies(stats)
18+ }
19+
20+ /** Generate proxy classes for @main functions.
21+ * A function like
22+ *
23+ * @main def f(x: S, ys: T*) = ...
24+ *
25+ * would be translated to something like
26+ *
27+ * import CommandLineParser._
28+ * class f {
29+ * @static def main(args: Array[String]): Unit =
30+ * try
31+ * f(
32+ * parseArgument[S](args, 0),
33+ * parseRemainingArguments[T](args, 1): _*
34+ * )
35+ * catch case err: ParseError => showError(err)
36+ * }
37+ */
38+ private def mainProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
2939 import tpd ._
3040 def mainMethods (stats : List [Tree ]): List [Symbol ] = stats.flatMap {
3141 case stat : DefDef if stat.symbol.hasAnnotation(defn.MainAnnot ) =>
@@ -39,7 +49,7 @@ object MainProxies {
3949 }
4050
4151 import untpd ._
42- def mainProxy (mainFun : Symbol )(using Context ): List [TypeDef ] = {
52+ private def mainProxy (mainFun : Symbol )(using Context ): List [TypeDef ] = {
4353 val mainAnnotSpan = mainFun.getAnnotation(defn.MainAnnot ).get.tree.span
4454 def pos = mainFun.sourcePos
4555 val argsRef = Ident (nme.args)
@@ -116,4 +126,298 @@ object MainProxies {
116126 }
117127 result
118128 }
129+
130+ private type DefaultValueSymbols = Map [Int , Symbol ]
131+ private type ParameterAnnotationss = Seq [Seq [Annotation ]]
132+
133+ /**
134+ * Generate proxy classes for main functions.
135+ * A function like
136+ *
137+ * /* *
138+ * * Lorem ipsum dolor sit amet
139+ * * consectetur adipiscing elit.
140+ * *
141+ * * @param x my param x
142+ * * @param ys all my params y
143+ * */
144+ * @myMain(80) def f(
145+ * @myMain.Alias("myX") x: S,
146+ * ys: T*
147+ * ) = ...
148+ *
149+ * would be translated to something like
150+ *
151+ * final class f {
152+ * static def main(args: Array[String]): Unit = {
153+ * val cmd = new myMain(80).command(
154+ * info = new CommandInfo(
155+ * name = "f",
156+ * documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
157+ * parameters = Seq(
158+ * new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
159+ * new scala.annotation.MainAnnotation.ParameterInfo("ys", "T", false, false, "all my params y", Seq())
160+ * )
161+ * )
162+ * args = args
163+ * )
164+ *
165+ * val args0: () => S = cmd.argGetter[S](0, None)
166+ * val args1: () => Seq[T] = cmd.varargGetter[T]
167+ *
168+ * cmd.run(() => f(args0(), args1()*))
169+ * }
170+ * }
171+ */
172+ private def mainAnnotationProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
173+ import tpd ._
174+
175+ /**
176+ * Computes the symbols of the default values of the function. Since they cannot be inferred anymore at this
177+ * point of the compilation, they must be explicitly passed by [[mainProxy ]].
178+ */
179+ def defaultValueSymbols (scope : Tree , funSymbol : Symbol ): DefaultValueSymbols =
180+ scope match {
181+ case TypeDef (_, template : Template ) =>
182+ template.body.flatMap((_ : Tree ) match {
183+ case dd : DefDef if dd.name.is(DefaultGetterName ) && dd.name.firstPart == funSymbol.name =>
184+ val DefaultGetterName .NumberedInfo (index) = dd.name.info
185+ List (index -> dd.symbol)
186+ case _ => Nil
187+ }).toMap
188+ case _ => Map .empty
189+ }
190+
191+ /** Computes the list of main methods present in the code. */
192+ def mainMethods (scope : Tree , stats : List [Tree ]): List [(Symbol , ParameterAnnotationss , DefaultValueSymbols , Option [Comment ])] = stats.flatMap {
193+ case stat : DefDef =>
194+ val sym = stat.symbol
195+ sym.annotations.filter(_.matches(defn.MainAnnotationClass )) match {
196+ case Nil =>
197+ Nil
198+ case _ :: Nil =>
199+ val paramAnnotations = stat.paramss.flatMap(_.map(
200+ valdef => valdef.symbol.annotations.filter(_.matches(defn.MainAnnotationParameterAnnotation ))
201+ ))
202+ (sym, paramAnnotations.toVector, defaultValueSymbols(scope, sym), stat.rawComment) :: Nil
203+ case mainAnnot :: others =>
204+ report.error(s " method cannot have multiple main annotations " , mainAnnot.tree)
205+ Nil
206+ }
207+ case stat @ TypeDef (_, impl : Template ) if stat.symbol.is(Module ) =>
208+ mainMethods(stat, impl.body)
209+ case _ =>
210+ Nil
211+ }
212+
213+ // Assuming that the top-level object was already generated, all main methods will have a scope
214+ mainMethods(EmptyTree , stats).flatMap(mainAnnotationProxy)
215+ }
216+
217+ private def mainAnnotationProxy (mainFun : Symbol , paramAnnotations : ParameterAnnotationss , defaultValueSymbols : DefaultValueSymbols , docComment : Option [Comment ])(using Context ): Option [TypeDef ] = {
218+ val mainAnnot = mainFun.getAnnotation(defn.MainAnnotationClass ).get
219+ def pos = mainFun.sourcePos
220+
221+ val documentation = new Documentation (docComment)
222+
223+ /** () => value */
224+ def unitToValue (value : Tree ): Tree =
225+ val defDef = DefDef (nme.ANON_FUN , List (Nil ), TypeTree (), value)
226+ Block (defDef, Closure (Nil , Ident (nme.ANON_FUN ), EmptyTree ))
227+
228+ /** Generate a list of trees containing the ParamInfo instantiations.
229+ *
230+ * A ParamInfo has the following shape
231+ * ```
232+ * new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
233+ * ```
234+ */
235+ def parameterInfos (mt : MethodType ): List [Tree ] =
236+ extension (tree : Tree ) def withProperty (sym : Symbol , args : List [Tree ]) =
237+ Apply (Select (tree, sym.name), args)
238+
239+ for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
240+ val param = paramName.toString
241+ val paramType0 = if formal.isRepeatedParam then formal.argTypes.head.dealias else formal.dealias
242+ val paramType = paramType0.dealias
243+
244+ val paramTypeStr = formal.dealias.typeSymbol.owner.showFullName + " ." + paramType.show
245+ val hasDefault = defaultValueSymbols.contains(idx)
246+ val isRepeated = formal.isRepeatedParam
247+ val paramDoc = documentation.argDocs.getOrElse(param, " " )
248+ val paramAnnots =
249+ val annotationTrees = paramAnnotations(idx).map(instantiateAnnotation).toList
250+ Apply (ref(defn.SeqModule .termRef), annotationTrees)
251+
252+ val constructorArgs = List (param, paramTypeStr, hasDefault, isRepeated, paramDoc)
253+ .map(value => Literal (Constant (value)))
254+
255+ New (TypeTree (defn.MainAnnotationParameterInfo .typeRef), List (constructorArgs :+ paramAnnots))
256+
257+ end parameterInfos
258+
259+ /**
260+ * Creates a list of references and definitions of arguments.
261+ * The goal is to create the
262+ * `val args0: () => S = cmd.argGetter[S](0, None)`
263+ * part of the code.
264+ */
265+ def argValDefs (mt : MethodType ): List [ValDef ] =
266+ for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
267+ val argName = nme.args ++ idx.toString
268+ val isRepeated = formal.isRepeatedParam
269+ val formalType = if isRepeated then formal.argTypes.head else formal
270+ val getterName = if isRepeated then nme.varargGetter else nme.argGetter
271+ val defaultValueGetterOpt = defaultValueSymbols.get(idx) match
272+ case None => ref(defn.NoneModule .termRef)
273+ case Some (dvSym) =>
274+ val value = unitToValue(ref(dvSym.termRef))
275+ Apply (ref(defn.SomeClass .companionModule.termRef), value)
276+ val argGetter0 = TypeApply (Select (Ident (nme.cmd), getterName), TypeTree (formalType) :: Nil )
277+ val argGetter =
278+ if isRepeated then argGetter0
279+ else Apply (argGetter0, List (Literal (Constant (idx)), defaultValueGetterOpt))
280+
281+ ValDef (argName, TypeTree (), argGetter)
282+ end argValDefs
283+
284+
285+ /** Create a list of argument references that will be passed as argument to the main method.
286+ * `args0`, ...`argn*`
287+ */
288+ def argRefs (mt : MethodType ): List [Tree ] =
289+ for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
290+ val argRef = Apply (Ident (nme.args ++ idx.toString), Nil )
291+ if formal.isRepeatedParam then repeated(argRef) else argRef
292+ end argRefs
293+
294+
295+ /** Turns an annotation (e.g. `@main(40)`) into an instance of the class (e.g. `new scala.main(40)`). */
296+ def instantiateAnnotation (annot : Annotation ): Tree =
297+ val argss = {
298+ def recurse (t : tpd.Tree , acc : List [List [Tree ]]): List [List [Tree ]] = t match {
299+ case Apply (t, args : List [tpd.Tree ]) => recurse(t, extractArgs(args) :: acc)
300+ case _ => acc
301+ }
302+
303+ def extractArgs (args : List [tpd.Tree ]): List [Tree ] =
304+ args.flatMap {
305+ case Typed (SeqLiteral (varargs, _), _) => varargs.map(arg => TypedSplice (arg))
306+ case arg : Select if arg.name.is(DefaultGetterName ) => Nil // Ignore default values, they will be added later by the compiler
307+ case arg => List (TypedSplice (arg))
308+ }
309+
310+ recurse(annot.tree, Nil )
311+ }
312+
313+ New (TypeTree (annot.symbol.typeRef), argss)
314+ end instantiateAnnotation
315+
316+ def generateMainClass (mainCall : Tree , args : List [Tree ], parameterInfos : List [Tree ]): TypeDef =
317+ val cmdInfo =
318+ val nameTree = Literal (Constant (mainFun.showName))
319+ val docTree = Literal (Constant (documentation.mainDoc))
320+ val paramInfos = Apply (ref(defn.SeqModule .termRef), parameterInfos)
321+ New (TypeTree (defn.MainAnnotationCommandInfo .typeRef), List (List (nameTree, docTree, paramInfos)))
322+
323+ val cmd = ValDef (
324+ nme.cmd,
325+ TypeTree (),
326+ Apply (
327+ Select (instantiateAnnotation(mainAnnot), nme.command),
328+ List (cmdInfo, Ident (nme.args))
329+ )
330+ )
331+ val run = Apply (Select (Ident (nme.cmd), nme.run), mainCall)
332+ val body = Block (cmdInfo :: cmd :: args, run)
333+ val mainArg = ValDef (nme.args, TypeTree (defn.ArrayType .appliedTo(defn.StringType )), EmptyTree )
334+ .withFlags(Param )
335+ /** Replace typed `Ident`s that have been typed with a TypeSplice with the reference to the symbol.
336+ * The annotations will be retype-checked in another scope that may not have the same imports.
337+ */
338+ def insertTypeSplices = new TreeMap {
339+ override def transform (tree : Tree )(using Context ): Tree = tree match
340+ case tree : tpd.Ident @ unchecked => TypedSplice (tree)
341+ case tree => super .transform(tree)
342+ }
343+ val annots = mainFun.annotations
344+ .filterNot(_.matches(defn.MainAnnotationClass ))
345+ .map(annot => insertTypeSplices.transform(annot.tree))
346+ val mainMeth = DefDef (nme.main, (mainArg :: Nil ) :: Nil , TypeTree (defn.UnitType ), body)
347+ .withFlags(JavaStatic )
348+ .withAnnotations(annots)
349+ val mainTempl = Template (emptyConstructor, Nil , Nil , EmptyValDef , mainMeth :: Nil )
350+ val mainCls = TypeDef (mainFun.name.toTypeName, mainTempl)
351+ .withFlags(Final | Invisible )
352+ mainCls.withSpan(mainAnnot.tree.span.toSynthetic)
353+ end generateMainClass
354+
355+ if (! mainFun.owner.isStaticOwner)
356+ report.error(s " main method is not statically accessible " , pos)
357+ None
358+ else mainFun.info match {
359+ case _ : ExprType =>
360+ Some (generateMainClass(unitToValue(ref(mainFun.termRef)), Nil , Nil ))
361+ case mt : MethodType =>
362+ if (mt.isImplicitMethod)
363+ report.error(s " main method cannot have implicit parameters " , pos)
364+ None
365+ else mt.resType match
366+ case restpe : MethodType =>
367+ report.error(s " main method cannot be curried " , pos)
368+ None
369+ case _ =>
370+ Some (generateMainClass(unitToValue(Apply (ref(mainFun.termRef), argRefs(mt))), argValDefs(mt), parameterInfos(mt)))
371+ case _ : PolyType =>
372+ report.error(s " main method cannot have type parameters " , pos)
373+ None
374+ case _ =>
375+ report.error(s " main can only annotate a method " , pos)
376+ None
377+ }
378+ }
379+
380+ /** A class responsible for extracting the docstrings of a method. */
381+ private class Documentation (docComment : Option [Comment ]):
382+ import util .CommentParsing ._
383+
384+ /** The main part of the documentation. */
385+ lazy val mainDoc : String = _mainDoc
386+ /** The parameters identified by @param. Maps from parameter name to its documentation. */
387+ lazy val argDocs : Map [String , String ] = _argDocs
388+
389+ private var _mainDoc : String = " "
390+ private var _argDocs : Map [String , String ] = Map ()
391+
392+ docComment match {
393+ case Some (comment) => if comment.isDocComment then parseDocComment(comment.raw) else _mainDoc = comment.raw
394+ case None =>
395+ }
396+
397+ private def cleanComment (raw : String ): String =
398+ var lines : Seq [String ] = raw.trim.nn.split('\n ' ).nn.toSeq
399+ lines = lines.map(l => l.substring(skipLineLead(l, - 1 ), l.length).nn.trim.nn)
400+ var s = lines.foldLeft(" " ) {
401+ case (" " , s2) => s2
402+ case (s1, " " ) if s1.last == '\n ' => s1 // Multiple newlines are kept as single newlines
403+ case (s1, " " ) => s1 + '\n '
404+ case (s1, s2) if s1.last == '\n ' => s1 + s2
405+ case (s1, s2) => s1 + ' ' + s2
406+ }
407+ s.replaceAll(raw " \[\[ " , " " ).nn.replaceAll(raw " \]\] " , " " ).nn.trim.nn
408+
409+ private def parseDocComment (raw : String ): Unit =
410+ // Positions of the sections (@) in the docstring
411+ val tidx : List [(Int , Int )] = tagIndex(raw)
412+
413+ // Parse main comment
414+ var mainComment : String = raw.substring(skipLineLead(raw, 0 ), startTag(raw, tidx)).nn
415+ _mainDoc = cleanComment(mainComment)
416+
417+ // Parse arguments comments
418+ val argsCommentsSpans : Map [String , (Int , Int )] = paramDocs(raw, " @param" , tidx)
419+ val argsCommentsTextSpans = argsCommentsSpans.view.mapValues(extractSectionText(raw, _))
420+ val argsCommentsTexts = argsCommentsTextSpans.mapValues({ case (beg, end) => raw.substring(beg, end).nn })
421+ _argDocs = argsCommentsTexts.mapValues(cleanComment(_)).toMap
422+ end Documentation
119423}
0 commit comments