@@ -18,13 +18,14 @@ import config.Printers.typr
1818import config .Feature
1919import util .{SrcPos , Stats }
2020import reporting .*
21- import NameKinds .WildcardParamName
21+ import NameKinds .{ WildcardParamName , TempResultName }
2222import typer .Applications .{spread , HasSpreads }
2323import typer .Implicits .SearchFailureType
2424import Constants .Constant
2525import cc .*
2626import dotty .tools .dotc .transform .MacroAnnotations .hasMacroAnnotation
2727import dotty .tools .dotc .core .NameKinds .DefaultGetterName
28+ import ast .TreeInfo
2829
2930object PostTyper {
3031 val name : String = " posttyper"
@@ -379,6 +380,25 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
379380 case _ =>
380381 tpt
381382
383+ private def evalSpreadsOnce (trees : List [Tree ])(within : List [Tree ] => Tree )(using Context ): Tree =
384+ if trees.exists:
385+ case spread(elem) => ! (exprPurity(elem) >= TreeInfo .Idempotent )
386+ case _ => false
387+ then
388+ val lifted = new mutable.ListBuffer [ValDef ]
389+ def liftIfImpure (tree : Tree ): Tree = tree match
390+ case tree @ Apply (fn, args) if fn.symbol == defn.spreadMethod =>
391+ cpy.Apply (tree)(fn, args.mapConserve(liftIfImpure))
392+ case _ if tpd.exprPurity(tree) >= TreeInfo .Idempotent =>
393+ tree
394+ case _ =>
395+ val vdef = SyntheticValDef (TempResultName .fresh(), tree)
396+ lifted += vdef
397+ Ident (vdef.namedType)
398+ val pureTrees = trees.mapConserve(liftIfImpure)
399+ Block (lifted.toList, within(pureTrees))
400+ else within(trees)
401+
382402 /** Translate sequence literal containing spread operators. Example:
383403 *
384404 * val xs, ys: List[Int]
@@ -400,50 +420,51 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
400420 * at typer, we don't have all type variables instantiated yet.
401421 */
402422 private def flattenSpreads [T ](tree : SeqLiteral )(using Context ): Tree =
403- val SeqLiteral (elems , elemtpt) = tree
423+ val SeqLiteral (rawElems , elemtpt) = tree
404424 val elemType = elemtpt.tpe
405425 val elemCls = elemType.classSymbol
406426
407- val lengthCalls = elems.collect:
408- case spread(elem) => elem.select(nme.length)
409- val singleElemCount : Tree = Literal (Constant (elems.length - lengthCalls.length))
410- val totalLength =
411- lengthCalls.foldLeft(singleElemCount): (acc, len) =>
412- acc.select(defn.Int_+ ).appliedTo(len)
413-
414- def makeBuilder (name : String ) =
415- ref(defn.ArraySeqBuilderModule ).select(name.toTermName)
416- def genericBuilder = makeBuilder(" generic" )
417- .appliedToType(elemType)
418- .appliedTo(totalLength)
419-
420- val builder =
421- if defn.ScalaValueClasses ().contains(elemCls) then
422- makeBuilder(s " of ${elemCls.name}" ).appliedTo(totalLength)
423- else if elemCls.derivesFrom(defn.ObjectClass ) then
424- val classTagType = defn.ClassTagClass .typeRef.appliedTo(elemType)
425- val classTag = atPhase(Phases .typerPhase):
426- ctx.typer.inferImplicitArg(classTagType, tree.span.startPos)
427- classTag.tpe match
428- case _ : SearchFailureType =>
429- genericBuilder
430- case _ =>
431- makeBuilder(" ofRef" )
432- .appliedToType(elemType)
433- .appliedTo(totalLength)
434- .appliedTo(classTag)
435- else
436- genericBuilder
437-
438- elems.foldLeft(builder): (bldr, elem) =>
439- elem match
440- case spread(arg) =>
441- val selector =
442- if arg.tpe.derivesFrom(defn.SeqClass ) then " addSeq"
443- else " addArray"
444- bldr.select(selector.toTermName).appliedTo(arg)
445- case _ => bldr.select(" add" .toTermName).appliedTo(elem)
446- .select(" result" .toTermName)
427+ evalSpreadsOnce(rawElems): elems =>
428+ val lengthCalls = elems.collect:
429+ case spread(elem) => elem.select(nme.length)
430+ val singleElemCount : Tree = Literal (Constant (elems.length - lengthCalls.length))
431+ val totalLength =
432+ lengthCalls.foldLeft(singleElemCount): (acc, len) =>
433+ acc.select(defn.Int_+ ).appliedTo(len)
434+
435+ def makeBuilder (name : String ) =
436+ ref(defn.ArraySeqBuilderModule ).select(name.toTermName)
437+ def genericBuilder = makeBuilder(" generic" )
438+ .appliedToType(elemType)
439+ .appliedTo(totalLength)
440+
441+ val builder =
442+ if defn.ScalaValueClasses ().contains(elemCls) then
443+ makeBuilder(s " of ${elemCls.name}" ).appliedTo(totalLength)
444+ else if elemCls.derivesFrom(defn.ObjectClass ) then
445+ val classTagType = defn.ClassTagClass .typeRef.appliedTo(elemType)
446+ val classTag = atPhase(Phases .typerPhase):
447+ ctx.typer.inferImplicitArg(classTagType, tree.span.startPos)
448+ classTag.tpe match
449+ case _ : SearchFailureType =>
450+ genericBuilder
451+ case _ =>
452+ makeBuilder(" ofRef" )
453+ .appliedToType(elemType)
454+ .appliedTo(totalLength)
455+ .appliedTo(classTag)
456+ else
457+ genericBuilder
458+
459+ elems.foldLeft(builder): (bldr, elem) =>
460+ elem match
461+ case spread(arg) =>
462+ val selector =
463+ if arg.tpe.derivesFrom(defn.SeqClass ) then " addSeq"
464+ else " addArray"
465+ bldr.select(selector.toTermName).appliedTo(arg)
466+ case _ => bldr.select(" add" .toTermName).appliedTo(elem)
467+ .select(" result" .toTermName)
447468 end flattenSpreads
448469
449470 override def transform (tree : Tree )(using Context ): Tree =
0 commit comments