11package love.forte.plugin.suspendtrans.ir
22
3- import love.forte.plugin.suspendtrans.*
3+ import love.forte.plugin.suspendtrans.SuspendTransformConfiguration
4+ import love.forte.plugin.suspendtrans.SuspendTransformUserData
5+ import love.forte.plugin.suspendtrans.SuspendTransformUserDataKey
6+ import love.forte.plugin.suspendtrans.fqn
47import love.forte.plugin.suspendtrans.utils.*
58import org.jetbrains.kotlin.backend.common.IrElementTransformerVoidWithContext
69import org.jetbrains.kotlin.backend.common.extensions.FirIncompatiblePluginAPI
@@ -9,16 +12,16 @@ import org.jetbrains.kotlin.descriptors.CallableDescriptor
912import org.jetbrains.kotlin.descriptors.SimpleFunctionDescriptor
1013import org.jetbrains.kotlin.ir.IrStatement
1114import org.jetbrains.kotlin.ir.ObsoleteDescriptorBasedAPI
12- import org.jetbrains.kotlin.ir.builders.irBlockBody
13- import org.jetbrains.kotlin.ir.builders.irCall
14- import org.jetbrains.kotlin.ir.builders.irGet
15- import org.jetbrains.kotlin.ir.builders.irReturn
15+ import org.jetbrains.kotlin.ir.builders.*
1616import org.jetbrains.kotlin.ir.declarations.*
1717import org.jetbrains.kotlin.ir.expressions.IrBody
18+ import org.jetbrains.kotlin.ir.expressions.IrCall
19+ import org.jetbrains.kotlin.ir.expressions.IrTypeOperator
20+ import org.jetbrains.kotlin.ir.expressions.impl.IrTypeOperatorCallImpl
1821import org.jetbrains.kotlin.ir.symbols.IrSimpleFunctionSymbol
19- import org.jetbrains.kotlin.ir.types.isSubtypeOfClass
20- import org.jetbrains.kotlin.ir.types.typeWith
21- import org.jetbrains.kotlin.ir.util.*
22+ import org.jetbrains.kotlin.ir.types.*
23+ import org.jetbrains.kotlin.ir.util.isAnnotationWithEqualFqName
24+ import org.jetbrains.kotlin.ir.util.primaryConstructor
2225import org.jetbrains.kotlin.name.ClassId
2326import org.jetbrains.kotlin.name.FqName
2427
@@ -92,7 +95,8 @@ class SuspendTransformTransformer(
9295
9396 private fun resolveFunctionBodyByDescriptor (declaration : IrFunction , descriptor : CallableDescriptor ): IrFunction ? {
9497 val userData = descriptor.getUserData(SuspendTransformUserDataKey ) ? : return null
95- val callableFunction = pluginContext.referenceFunctions(userData.transformer.transformFunctionInfo.toCallableId()).firstOrNull()
98+ val callableFunction =
99+ pluginContext.referenceFunctions(userData.transformer.transformFunctionInfo.toCallableId()).firstOrNull()
96100 ? : throw IllegalStateException (" Transform function ${userData.transformer.transformFunctionInfo} not found" )
97101
98102 val generatedOriginFunction = resolveFunctionBody(declaration, userData.originFunction, callableFunction)
@@ -112,7 +116,7 @@ class SuspendTransformTransformer(
112116 currentAnnotations.any { a -> a.isAnnotationWithEqualFqName(name) }
113117 addAll(currentAnnotations)
114118
115- val syntheticFunctionIncludes = userData.transformer.originFunctionIncludeAnnotations
119+ val syntheticFunctionIncludes = userData.transformer.originFunctionIncludeAnnotations
116120
117121 syntheticFunctionIncludes.forEach { include ->
118122 val classId = include.classInfo.toClassId()
@@ -205,24 +209,69 @@ private fun generateTransformBodyForFunction(
205209 // println(transformTargetFunctionCall.owner.valueParameters)
206210 val owner = transformTargetFunctionCall.owner
207211
208- if (owner.valueParameters.size > 1 ) {
209- val secondType = owner.valueParameters[1 ].type
210- val coroutineScopeTypeName = " kotlinx.coroutines.CoroutineScope" .fqn
211- val coroutineScopeTypeClassId = ClassId .topLevel(" kotlinx.coroutines.CoroutineScope" .fqn)
212- val coroutineScopeTypeNameUnsafe = coroutineScopeTypeName.toUnsafe()
213- if (secondType.isClassType(coroutineScopeTypeNameUnsafe)) {
214- function.dispatchReceiverParameter?.also { dispatchReceiverParameter ->
215- context.referenceClass(coroutineScopeTypeClassId)?.also { coroutineScopeRef ->
216- if (dispatchReceiverParameter.type.isSubtypeOfClass(coroutineScopeRef)) {
217- // put 'this' to second arg
218- putValueArgument(1 , irGet(dispatchReceiverParameter))
219- }
220- }
221- }
222- }
212+ // CoroutineScope
213+ val ownerValueParameters = owner.valueParameters
223214
215+ if (ownerValueParameters.size > 1 ) {
216+ for (index in 1 .. ownerValueParameters.lastIndex) {
217+ val valueParameter = ownerValueParameters[index]
218+ val type = valueParameter.type
219+ tryResolveCoroutineScopeValueParameter(type, context, function, owner, this @irBlockBody, index)
220+ }
224221 }
225222
226223 })
227224 }
228225}
226+
227+ private val coroutineScopeTypeName = " kotlinx.coroutines.CoroutineScope" .fqn
228+ private val coroutineScopeTypeClassId = ClassId .topLevel(" kotlinx.coroutines.CoroutineScope" .fqn)
229+ private val coroutineScopeTypeNameUnsafe = coroutineScopeTypeName.toUnsafe()
230+
231+ /* *
232+ * 解析类型为 CoroutineScope 的参数。
233+ * 如果当前参数类型为 CoroutineScope:
234+ * - 如果当前 receiver 即为 CoroutineScope 类型,将其填充
235+ * - 如果当前 receiver 不是 CoroutineScope 类型,但是此参数可以为 null,
236+ * 则使用 safe-cast 将 receiver 转化为 CoroutineScope ( `dispatcher as? CoroutineScope` )
237+ * - 其他情况忽略此参数(适用于此参数有默认值的情况)
238+ */
239+ private fun IrCall.tryResolveCoroutineScopeValueParameter (
240+ type : IrType ,
241+ context : IrPluginContext ,
242+ function : IrFunction ,
243+ owner : IrSimpleFunction ,
244+ builderWithScope : IrBuilderWithScope ,
245+ index : Int
246+ ) {
247+ if (! type.isClassType(coroutineScopeTypeNameUnsafe)) {
248+ return
249+ }
250+
251+ function.dispatchReceiverParameter?.also { dispatchReceiverParameter ->
252+ context.referenceClass(coroutineScopeTypeClassId)?.also { coroutineScopeRef ->
253+ if (dispatchReceiverParameter.type.isSubtypeOfClass(coroutineScopeRef)) {
254+ // put 'this' to the arg
255+ putValueArgument(index, builderWithScope.irGet(dispatchReceiverParameter))
256+ } else {
257+ val scopeType = coroutineScopeRef.defaultType
258+
259+ val scopeParameter = owner.valueParameters.getOrNull(1 )
260+
261+ if (scopeParameter?.type?.isNullable() == true ) {
262+ val irSafeAs = IrTypeOperatorCallImpl (
263+ startOffset,
264+ endOffset,
265+ scopeType,
266+ IrTypeOperator .SAFE_CAST ,
267+ scopeType,
268+ builderWithScope.irGet(dispatchReceiverParameter)
269+ )
270+
271+ putValueArgument(index, irSafeAs)
272+ }
273+ // irAs(irGet(dispatchReceiverParameter), coroutineScopeRef.defaultType)
274+ }
275+ }
276+ }
277+ }
0 commit comments