@@ -14,6 +14,7 @@ import org.jetbrains.kotlin.ir.declarations.IrDeclarationOrigin
1414import org.jetbrains.kotlin.ir.declarations.IrField
1515import org.jetbrains.kotlin.ir.declarations.IrFile
1616import org.jetbrains.kotlin.ir.declarations.IrFunction
17+ import org.jetbrains.kotlin.ir.declarations.IrParameterKind
1718import org.jetbrains.kotlin.ir.declarations.path
1819import org.jetbrains.kotlin.ir.expressions.IrBlockBody
1920import org.jetbrains.kotlin.ir.expressions.IrBody
@@ -23,7 +24,7 @@ import org.jetbrains.kotlin.ir.expressions.IrExpression
2324import org.jetbrains.kotlin.ir.expressions.IrExpressionBody
2425import org.jetbrains.kotlin.ir.expressions.IrGetValue
2526import org.jetbrains.kotlin.ir.expressions.IrStatementOrigin
26- import org.jetbrains.kotlin.ir.expressions.impl.IrCallImpl
27+ import org.jetbrains.kotlin.ir.expressions.impl.IrCallImplWithShape
2728import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl
2829import org.jetbrains.kotlin.ir.expressions.impl.IrFunctionExpressionImpl
2930import org.jetbrains.kotlin.ir.expressions.impl.IrGetObjectValueImpl
@@ -38,8 +39,8 @@ import org.jetbrains.kotlin.ir.types.typeWith
3839import org.jetbrains.kotlin.ir.util.SetDeclarationsParentVisitor
3940import org.jetbrains.kotlin.ir.util.fqNameWhenAvailable
4041import org.jetbrains.kotlin.ir.util.isLocal
41- import org.jetbrains.kotlin.ir.visitors.IrElementTransformer
4242import org.jetbrains.kotlin.ir.visitors.IrElementTransformerVoid
43+ import org.jetbrains.kotlin.ir.visitors.IrTransformer
4344import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
4445import org.jetbrains.kotlin.name.CallableId
4546import org.jetbrains.kotlin.name.ClassId
@@ -51,8 +52,8 @@ data class ContainingDeclarations(val clazz: IrClass?, val function: IrFunction?
5152
5253@OptIn(UnsafeDuringIrConstructionAPI ::class )
5354class ExplainerIrTransformer (val pluginContext : IrPluginContext ) :
54- FileLoweringPass ,
55- IrElementTransformer < ContainingDeclarations > {
55+ IrTransformer < ContainingDeclarations >() ,
56+ FileLoweringPass {
5657 lateinit var file: IrFile
5758 lateinit var source: String
5859
@@ -151,12 +152,15 @@ class ExplainerIrTransformer(val pluginContext: IrPluginContext) :
151152 if (expression.startOffset < 0 ) return expression
152153 if (expression.type.classFqName in dataFrameLike) {
153154 if (expression.symbol.owner.name == Name .identifier(" component1" )) return expression
154- var receiver = expression.extensionReceiver
155- // expression.extensionReceiver = extension callables,
155+ val extensionReceiverIndex =
156+ expression.symbol.owner.parameters.indexOfFirst { it.kind == IrParameterKind .ExtensionReceiver }
157+ var receiver: IrExpression ?
158+ // expression.arguments[extensionReceiverIndex] = extension callables,
156159 // expression.dispatchReceiver = member callables such as "GroupBy.aggregate"
157- if (receiver != null ) {
158- val transformedExtensionReceiver = expression.extensionReceiver?.transform(this , data)
159- expression.extensionReceiver = transformedExtensionReceiver
160+ if (extensionReceiverIndex >= 0 ) {
161+ receiver = expression.arguments[extensionReceiverIndex]!!
162+ val transformedExtensionReceiver = receiver.transform(this , data)
163+ expression.arguments[extensionReceiverIndex] = transformedExtensionReceiver
160164 } else {
161165 receiver = expression.dispatchReceiver
162166 val transformedExtensionReceiver = expression.dispatchReceiver?.transform(this , data)
@@ -179,9 +183,26 @@ class ExplainerIrTransformer(val pluginContext: IrPluginContext) :
179183 CallableId (FqName (" kotlin" ), Name .identifier(" also" )),
180184 ).single()
181185
182- val result = IrCallImpl (- 1 , - 1 , expression.type, alsoReference, 1 , 1 ).apply {
183- this .extensionReceiver = expression
184- putTypeArgument(0 , expression.type)
186+ val result = IrCallImplWithShape (
187+ startOffset = - 1 ,
188+ endOffset = - 1 ,
189+ type = expression.type,
190+ symbol = alsoReference,
191+ typeArgumentsCount = 1 ,
192+ valueArgumentsCount = 1 ,
193+ contextParameterCount = 0 ,
194+ hasDispatchReceiver = true ,
195+ hasExtensionReceiver = true ,
196+ ).apply {
197+ val extensionReceiverIndex =
198+ this .symbol.owner.parameters.indexOfFirst { it.kind == IrParameterKind .ExtensionReceiver }
199+ if (extensionReceiverIndex >= 0 ) {
200+ this .arguments[extensionReceiverIndex] = expression
201+ } else {
202+ this .insertExtensionReceiver(expression)
203+ }
204+
205+ typeArguments[0 ] = expression.type
185206
186207 val symbol = IrSimpleFunctionSymbolImpl ()
187208 val alsoLambda = pluginContext.irFactory
@@ -202,25 +223,24 @@ class ExplainerIrTransformer(val pluginContext: IrPluginContext) :
202223 isInfix = false ,
203224 isExpect = false ,
204225 ).apply {
205- valueParameters = buildList {
206- add(
207- pluginContext.irFactory.createValueParameter(
208- startOffset = - 1 ,
209- endOffset = - 1 ,
210- origin = IrDeclarationOrigin .DEFINED ,
211- symbol = IrValueParameterSymbolImpl (),
212- name = Name .identifier(" it" ),
213- index = 0 ,
214- type = expression.type,
215- varargElementType = null ,
216- isCrossinline = false ,
217- isNoinline = false ,
218- isHidden = false ,
219- isAssignable = false ,
220- ),
226+ // replace all regular value parameters with a single one `it`
227+ parameters = parameters.filterNot { it.kind == IrParameterKind .Regular } +
228+ pluginContext.irFactory.createValueParameter(
229+ startOffset = - 1 ,
230+ endOffset = - 1 ,
231+ origin = IrDeclarationOrigin .DEFINED ,
232+ kind = IrParameterKind .Regular ,
233+ name = Name .identifier(" it" ),
234+ type = expression.type,
235+ isAssignable = false ,
236+ symbol = IrValueParameterSymbolImpl (),
237+ varargElementType = null ,
238+ isCrossinline = false ,
239+ isNoinline = false ,
240+ isHidden = false ,
221241 )
222- }
223- val itSymbol = valueParameters[ 0 ] .symbol
242+
243+ val itSymbol = parameters.first { it.kind == IrParameterKind . Regular } .symbol
224244 val source = try {
225245 source.substring(expression.startOffset, expression.endOffset)
226246 } catch (e: Exception ) {
@@ -229,14 +249,21 @@ class ExplainerIrTransformer(val pluginContext: IrPluginContext) :
229249 val expressionId = expressionId(expression)
230250 val receiverId = receiver?.let { expressionId(it) }
231251 val valueArguments = buildList<IrExpression ?> {
232- add(source.irConstImpl())
233- add(ownerName.asStringStripSpecialMarkers().irConstImpl())
234- add(IrGetValueImpl (- 1 , - 1 , itSymbol))
235- add(expressionId.irConstImpl())
236- add(receiverId.irConstImpl())
237- add(data.clazz?.fqNameWhenAvailable?.asString().irConstImpl())
238- add(data.function?.name?.asString().irConstImpl())
239- add(IrConstImpl .int(- 1 , - 1 , pluginContext.irBuiltIns.intType, data.statementIndex))
252+ add(source.irConstImpl()) // source: String
253+ add(ownerName.asStringStripSpecialMarkers().irConstImpl()) // name: String
254+ add(IrGetValueImpl (- 1 , - 1 , itSymbol)) // df: Any
255+ add(expressionId.irConstImpl()) // id: String
256+ add(receiverId.irConstImpl()) // receiverId: String?
257+ add(data.clazz?.fqNameWhenAvailable?.asString().irConstImpl()) // containingClassFqName: String?
258+ add(data.function?.name?.asString().irConstImpl()) // containingFunName: String?
259+ add(
260+ IrConstImpl .int(
261+ - 1 ,
262+ - 1 ,
263+ pluginContext.irBuiltIns.intType,
264+ data.statementIndex,
265+ ),
266+ ) // statementIndex: Int
240267 }
241268 body = pluginContext.irFactory.createBlockBody(- 1 , - 1 ).apply {
242269 val callableId = CallableId (
@@ -245,19 +272,24 @@ class ExplainerIrTransformer(val pluginContext: IrPluginContext) :
245272 Name .identifier(" doAction" ),
246273 )
247274 val doAction = pluginContext.referenceFunctions(callableId).single()
248- statements + = IrCallImpl (
275+ statements + = IrCallImplWithShape (
249276 startOffset = - 1 ,
250277 endOffset = - 1 ,
251278 type = doAction.owner.returnType,
252279 symbol = doAction,
253280 typeArgumentsCount = 0 ,
254281 valueArgumentsCount = valueArguments.size,
282+ contextParameterCount = 0 ,
283+ hasDispatchReceiver = true ,
284+ hasExtensionReceiver = false ,
255285 ).apply {
256286 val clazz = ClassId (explainerPackage, Name .identifier(" PluginCallbackProxy" ))
257287 val plugin = pluginContext.referenceClass(clazz)!!
258288 dispatchReceiver = IrGetObjectValueImpl (- 1 , - 1 , plugin.defaultType, plugin)
289+
290+ val firstValueArgumentIndex = 1 // skipping dispatch receiver
259291 valueArguments.forEachIndexed { i, argument ->
260- putValueArgument(i, argument)
292+ this .arguments[firstValueArgumentIndex + i] = argument
261293 }
262294 }
263295 }
@@ -271,12 +303,16 @@ class ExplainerIrTransformer(val pluginContext: IrPluginContext) :
271303 function = alsoLambda,
272304 origin = IrStatementOrigin .LAMBDA ,
273305 )
274- putValueArgument(0 , alsoLambdaExpression)
306+
307+ val firstValueArgumentIndex = this .symbol.owner.parameters
308+ .indexOfFirst { it.kind == IrParameterKind .Regular }
309+ .takeUnless { it < 0 } ? : this .symbol.owner.parameters.size
310+ this .arguments[firstValueArgumentIndex] = alsoLambdaExpression
275311 }
276312 return result
277313 }
278314
279- private fun String?.irConstImpl (): IrConstImpl < out String ?> {
315+ private fun String?.irConstImpl (): IrConstImpl {
280316 val nullableString = pluginContext.irBuiltIns.stringType.makeNullable()
281317 val argument = if (this == null ) {
282318 IrConstImpl .constNull(- 1 , - 1 , nullableString)
0 commit comments