@@ -4,6 +4,7 @@ package tasty
44import scala .jdk .CollectionConverters ._
55
66import scala .quoted ._
7+ import scala .util .control .NonFatal
78
89import NameNormalizer ._
910import SyntheticsSupport ._
@@ -124,6 +125,12 @@ trait TypesSupport:
124125 ++ keyword(" =>> " ).l
125126 ++ inner(resType)
126127
128+ case Refinement (parent, " apply" , mt : MethodType ) if isPolyOrEreased(parent) =>
129+ val isCtx = isContextualMethod(mt)
130+ val sym = defn.FunctionClass (mt.paramTypes.length, isCtx)
131+ val at = AppliedType (TypeTree .ref(sym).tpe, mt.paramTypes :+ mt.resType)
132+ inner(Refinement (at, " apply" , mt))
133+
127134 case r : Refinement => { // (parent, name, info)
128135 def getRefinementInformation (t : TypeRepr ): List [TypeRepr ] = t match {
129136 case r : Refinement => getRefinementInformation(r.parent) :+ r
@@ -164,16 +171,22 @@ trait TypesSupport:
164171 case t : PolyType =>
165172 val paramBounds = getParamBounds(t)
166173 val method = t.resType.asInstanceOf [MethodType ]
167- val paramList = getParamList(method)
168- val resType = inner(method.resType)
169- plain(" [" ).l ++ paramBounds ++ plain(" ]" ).l ++ keyword(" => " ).l ++ paramList ++ keyword(" => " ).l ++ resType
174+ val rest = parseDependentFunctionType(method)
175+ plain(" [" ).l ++ paramBounds ++ plain(" ]" ).l ++ keyword(" => " ).l ++ rest
170176 case other => noSupported(s " Not supported type in refinement $info" )
171177 }
172178
173179 def parseDependentFunctionType (info : TypeRepr ): SSignature = info match {
174180 case m : MethodType =>
175- val paramList = getParamList(m)
176- paramList ++ keyword(" => " ).l ++ inner(m.resType)
181+ val isCtx = isContextualMethod(m)
182+ if isDependentMethod(m) then
183+ val paramList = getParamList(m)
184+ val arrow = keyword(if isCtx then " ?=> " else " => " ).l
185+ val resType = inner(m.resType)
186+ paramList ++ arrow ++ resType
187+ else
188+ val sym = defn.FunctionClass (m.paramTypes.length, isCtx)
189+ inner(sym.typeRef.appliedTo(m.paramTypes :+ m.resType))
177190 case other => noSupported(" Dependent function type without MethodType refinement" )
178191 }
179192
@@ -213,8 +226,9 @@ trait TypesSupport:
213226 case Seq (rtpe) =>
214227 plain(" ()" ).l ++ keyword(arrow).l ++ inner(rtpe)
215228 case Seq (arg, rtpe) =>
216- val partOfSignature = arg match
229+ val partOfSignature = stripAnnotated( arg) match
217230 case _ : TermRef | _ : TypeRef | _ : ConstantType | _ : ParamRef => inner(arg)
231+ case at : AppliedType if ! isInfix(at) && ! at.isFunctionType && ! at.isTupleN => inner(arg)
218232 case _ => inParens(inner(arg))
219233 partOfSignature ++ keyword(arrow).l ++ inner(rtpe)
220234 case args =>
@@ -385,3 +399,21 @@ trait TypesSupport:
385399 case _ => false
386400
387401 at.args.size == 2 && (! at.typeSymbol.name.forall(isIdentifierPart) || infixAnnot)
402+
403+ private def isPolyOrEreased (using Quotes )(tr : reflect.TypeRepr ) =
404+ Set (" scala.PolyFunction" , " scala.runtime.ErasedFunction" )
405+ .contains(tr.typeSymbol.fullName)
406+
407+ private def isContextualMethod (using Quotes )(mt : reflect.MethodType ) =
408+ mt.asInstanceOf [dotty.tools.dotc.core.Types .MethodType ].isContextualMethod
409+
410+ private def isDependentMethod (using Quotes )(mt : reflect.MethodType ) =
411+ val method = mt.asInstanceOf [dotty.tools.dotc.core.Types .MethodType ]
412+ try method.isParamDependent || method.isResultDependent
413+ catch case NonFatal (_) => true
414+
415+ private def stripAnnotated (using Quotes )(tr : reflect.TypeRepr ): reflect.TypeRepr =
416+ import reflect .*
417+ tr match
418+ case AnnotatedType (tr, _) => stripAnnotated(tr)
419+ case other => other
0 commit comments