@@ -16,14 +16,17 @@ import dotty.tools.dotc.core.Symbols
1616import dotty .tools .dotc .core .Symbols .Symbol
1717import dotty .tools .dotc .core .Types .AndType
1818import dotty .tools .dotc .core .Types .AppliedType
19+ import dotty .tools .dotc .core .Types .MethodType
1920import dotty .tools .dotc .core .Types .OrType
21+ import dotty .tools .dotc .core .Types .RefinedType
2022import dotty .tools .dotc .core .Types .TermRef
2123import dotty .tools .dotc .core .Types .Type
2224import dotty .tools .dotc .core .Types .TypeBounds
2325import dotty .tools .dotc .core .Types .WildcardType
2426import dotty .tools .dotc .util .SourcePosition
2527import dotty .tools .pc .IndexedContext
2628import dotty .tools .pc .utils .MtagsEnrichments .*
29+ import scala .annotation .tailrec
2730
2831object NamedArgCompletions :
2932
@@ -195,9 +198,40 @@ object NamedArgCompletions:
195198 // def curry(x: Int)(apple: String, banana: String) = ???
196199 // curry(1)(apple = "test", b@@)
197200 // ```
198- val (baseParams , baseArgs) =
201+ val (baseParams0 , baseArgs) =
199202 vparamss.zip(argss).lastOption.getOrElse((Nil , Nil ))
200203
204+ val baseParams : List [ParamSymbol ] =
205+ def defaultBaseParams = baseParams0.map(JustSymbol (_))
206+ @ tailrec
207+ def getRefinedParams (refinedType : Type , level : Int ): List [ParamSymbol ] =
208+ if level > 0 then
209+ val resultTypeOpt =
210+ refinedType match
211+ case RefinedType (AppliedType (_, args), _, _) => args.lastOption
212+ case AppliedType (_, args) => args.lastOption
213+ case _ => None
214+ resultTypeOpt match
215+ case Some (resultType) => getRefinedParams(resultType, level - 1 )
216+ case _ => defaultBaseParams
217+ else
218+ refinedType match
219+ case RefinedType (AppliedType (_, args), _, MethodType (ri)) =>
220+ baseParams0.zip(ri).zip(args).map { case ((sym, name), arg) =>
221+ RefinedSymbol (sym, name, arg)
222+ }
223+ case _ => defaultBaseParams
224+ // finds param refinements for lambda expressions
225+ // val hello: (x: Int, y: Int) => Unit = (x, _) => println(x)
226+ @ tailrec
227+ def refineParams (method : Tree , level : Int ): List [ParamSymbol ] =
228+ method match
229+ case Select (Apply (f, _), _) => refineParams(f, level + 1 )
230+ case Select (h, v) => getRefinedParams(h.symbol.info, level)
231+ case _ => defaultBaseParams
232+ refineParams(method, 0 )
233+ end baseParams
234+
201235 val args = ident
202236 .map(i => baseArgs.filterNot(_ == i))
203237 .getOrElse(baseArgs)
@@ -221,7 +255,7 @@ object NamedArgCompletions:
221255
222256 baseParams.filterNot(param =>
223257 isNamed(param.name) ||
224- param.denot.is(
258+ param.symbol. denot.is(
225259 Flags .Synthetic
226260 ) // filter out synthesized param, like evidence
227261 )
@@ -232,7 +266,7 @@ object NamedArgCompletions:
232266 .map(_.name.toString)
233267 .getOrElse(" " )
234268 .replace(Cursor .value, " " )
235- val params : List [Symbol ] =
269+ val params : List [ParamSymbol ] =
236270 allParams
237271 .filter(param => param.name.startsWith(prefix))
238272 .distinctBy(sym => (sym.name, sym.info))
@@ -249,7 +283,7 @@ object NamedArgCompletions:
249283 .filter(name => name != " Nil" && name != " None" )
250284 .sorted
251285
252- def findDefaultValue (param : Symbol ): String =
286+ def findDefaultValue (param : ParamSymbol ): String =
253287 val matchingType = matchingTypesInScope(param.info)
254288 if matchingType.size == 1 then s " : ${matchingType.head}"
255289 else if matchingType.size > 1 then s " |???, ${matchingType.mkString(" ," )}| "
@@ -260,12 +294,12 @@ object NamedArgCompletions:
260294 def shouldShow =
261295 allParams.exists(param => param.name.startsWith(prefix))
262296 def isExplicitlyCalled = suffix.startsWith(prefix)
263- def hasParamsToFill = allParams.count(! _.is(Flags .HasDefault )) > 1
297+ def hasParamsToFill = allParams.count(! _.symbol. is(Flags .HasDefault )) > 1
264298 if clientSupportsSnippets && matchingMethods.length == 1 && (shouldShow || isExplicitlyCalled) && hasParamsToFill
265299 then
266300 val editText = allParams.zipWithIndex
267301 .collect {
268- case (param, index) if ! param.is(Flags .HasDefault ) =>
302+ case (param, index) if ! param.symbol. is(Flags .HasDefault ) =>
269303 s " ${param.nameBackticked.replace(" $" , " $$" )} = $$ { ${index + 1 }${findDefaultValue(param)}} "
270304 }
271305 .mkString(" , " )
@@ -355,3 +389,16 @@ class FuzzyArgMatcher(tparams: List[Symbols.Symbol])(using Context):
355389 case _ => t
356390
357391end FuzzyArgMatcher
392+
393+ sealed trait ParamSymbol :
394+ def name : Name
395+ def info : Type
396+ def symbol : Symbol
397+ def nameBackticked (using Context ) = name.decoded.backticked
398+
399+ case class JustSymbol (symbol : Symbol )(using Context ) extends ParamSymbol :
400+ def name : Name = symbol.name
401+ def info : Type = symbol.info
402+
403+ case class RefinedSymbol (symbol : Symbol , name : Name , info : Type )
404+ extends ParamSymbol
0 commit comments