Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@
import static org.apiguardian.api.API.Status.INTERNAL;
import static org.junit.platform.commons.util.KotlinReflectionUtils.getKotlinSuspendingFunctionGenericReturnType;
import static org.junit.platform.commons.util.KotlinReflectionUtils.getKotlinSuspendingFunctionReturnType;
import static org.junit.platform.commons.util.KotlinReflectionUtils.invokeKotlinFunction;
import static org.junit.platform.commons.util.KotlinReflectionUtils.invokeKotlinSuspendingFunction;
import static org.junit.platform.commons.util.KotlinReflectionUtils.isKotlinSuspendingFunction;
import static org.junit.platform.commons.util.KotlinReflectionUtils.isKotlinType;

import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.Arrays;

import org.apiguardian.api.API;
import org.jspecify.annotations.Nullable;
import org.junit.platform.commons.support.ReflectionSupport;
import org.junit.platform.commons.util.KotlinReflectionUtils;

@API(status = INTERNAL, since = "6.0")
public class MethodReflectionUtils {
Expand All @@ -42,9 +46,17 @@ public static Type getGenericReturnType(Method method) {
if (isKotlinSuspendingFunction(method)) {
return invokeKotlinSuspendingFunction(method, target, arguments);
}
if (isKotlinType(method.getDeclaringClass()) && hasInlineTypeArgument(arguments)) {
return invokeKotlinFunction(method, target, arguments);
}
return ReflectionSupport.invokeMethod(method, target, arguments);
}

private static boolean hasInlineTypeArgument(@Nullable Object[] arguments) {
return arguments.length > 0 //
&& Arrays.stream(arguments).anyMatch(KotlinReflectionUtils::isInstanceOfInlineType);
}

private MethodReflectionUtils() {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import kotlin.reflect.KParameter;
import kotlin.reflect.jvm.ReflectJvmMapping;

class KotlinSuspendingFunctionUtils {
class KotlinFunctionUtils {

static Class<?> getReturnType(Method method) {
var returnType = getJavaClass(getJvmErasure(getKotlinFunction(method).getReturnType()));
Expand Down Expand Up @@ -67,17 +67,35 @@ static Class<?>[] getParameterTypes(Method method) {
return Arrays.stream(method.getParameterTypes()).limit(parameterCount - 1).toArray(Class<?>[]::new);
}

static @Nullable Object invoke(Method method, @Nullable Object target, @Nullable Object[] args) {
static @Nullable Object invokeKotlinFunction(Method method, @Nullable Object target, @Nullable Object[] args) {
try {
return invoke(getKotlinFunction(method), target, args);
return invokeKotlinFunction(getKotlinFunction(method), target, args);
}
catch (InterruptedException e) {
throw throwAsUncheckedException(e);
}
}

private static <T> @Nullable T invoke(KFunction<T> function, @Nullable Object target, @Nullable Object[] args)
throws InterruptedException {
private static <T extends @Nullable Object> T invokeKotlinFunction(KFunction<T> function, @Nullable Object target,
@Nullable Object[] args) throws InterruptedException {
if (!isAccessible(function)) {
setAccessible(function, true);
}
return function.callBy(toArgumentMap(target, args, function));
}

static @Nullable Object invokeKotlinSuspendingFunction(Method method, @Nullable Object target,
@Nullable Object[] args) {
try {
return invokeKotlinSuspendingFunction(getKotlinFunction(method), target, args);
}
catch (InterruptedException e) {
throw throwAsUncheckedException(e);
}
}

private static <T extends @Nullable Object> T invokeKotlinSuspendingFunction(KFunction<T> function,
@Nullable Object target, @Nullable Object[] args) throws InterruptedException {
if (!isAccessible(function)) {
setAccessible(function, true);
}
Expand Down Expand Up @@ -113,6 +131,6 @@ private static KFunction<?> getKotlinFunction(Method method) {
() -> "Failed to get Kotlin function for method: " + method);
}

private KotlinSuspendingFunctionUtils() {
private KotlinFunctionUtils() {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@ public class KotlinReflectionUtils {
private static final String DEFAULT_IMPLS_CLASS_NAME = "DefaultImpls";

private static final @Nullable Class<? extends Annotation> kotlinMetadata;
private static final @Nullable Class<? extends Annotation> jvmInline;
private static final @Nullable Class<?> kotlinCoroutineContinuation;
private static final boolean kotlinReflectPresent;
private static final boolean kotlinxCoroutinesPresent;

static {
var metadata = tryToLoadKotlinMetadataClass();
kotlinMetadata = metadata.toOptional().orElse(null);
jvmInline = tryToLoadJvmInlineClass().toOptional().orElse(null);
kotlinCoroutineContinuation = metadata //
.andThen(__ -> tryToLoadClass("kotlin.coroutines.Continuation")) //
.toOptional() //
Expand All @@ -62,6 +64,12 @@ private static Try<Class<? extends Annotation>> tryToLoadKotlinMetadataClass() {
.andThenTry(it -> (Class<? extends Annotation>) it);
}

@SuppressWarnings("unchecked")
private static Try<Class<? extends Annotation>> tryToLoadJvmInlineClass() {
return tryToLoadClass("kotlin.jvm.JvmInline") //
.andThenTry(it -> (Class<? extends Annotation>) it);
}

/**
* @since 6.0
*/
Expand Down Expand Up @@ -117,36 +125,48 @@ private static Class<?>[] copyWithoutFirst(Class<?>[] values) {
return result;
}

private static boolean isKotlinType(Class<?> clazz) {
public static boolean isKotlinType(Class<?> clazz) {
return kotlinMetadata != null //
&& clazz.getDeclaredAnnotation(kotlinMetadata) != null;
}

public static Class<?> getKotlinSuspendingFunctionReturnType(Method method) {
requireKotlinReflect(method);
return KotlinSuspendingFunctionUtils.getReturnType(method);
return KotlinFunctionUtils.getReturnType(method);
}

public static Type getKotlinSuspendingFunctionGenericReturnType(Method method) {
requireKotlinReflect(method);
return KotlinSuspendingFunctionUtils.getGenericReturnType(method);
return KotlinFunctionUtils.getGenericReturnType(method);
}

public static Parameter[] getKotlinSuspendingFunctionParameters(Method method) {
requireKotlinReflect(method);
return KotlinSuspendingFunctionUtils.getParameters(method);
return KotlinFunctionUtils.getParameters(method);
}

public static Class<?>[] getKotlinSuspendingFunctionParameterTypes(Method method) {
requireKotlinReflect(method);
return KotlinSuspendingFunctionUtils.getParameterTypes(method);
return KotlinFunctionUtils.getParameterTypes(method);
}

public static @Nullable Object invokeKotlinSuspendingFunction(Method method, @Nullable Object target,
@Nullable Object[] args) {
requireKotlinReflect(method);
requireKotlinxCoroutines(method);
return KotlinSuspendingFunctionUtils.invoke(method, target, args);
return KotlinFunctionUtils.invokeKotlinSuspendingFunction(method, target, args);
}

public static boolean isInstanceOfInlineType(@Nullable Object value) {
return jvmInline != null //
&& value != null //
&& value.getClass().getDeclaredAnnotation(jvmInline) != null;
}

public static @Nullable Object invokeKotlinFunction(Method method, @Nullable Object target,
@Nullable Object... args) {
requireKotlinReflect(method);
return KotlinFunctionUtils.invokeKotlinFunction(method, target, args);
}

private static void requireKotlinReflect(Method method) {
Expand All @@ -159,7 +179,7 @@ private static void requireKotlinxCoroutines(Method method) {

private static void requireDependency(Method method, boolean condition, String dependencyNotation) {
Preconditions.condition(condition,
() -> ("Kotlin suspending function [%s] requires %s to be on the classpath or module path. "
() -> ("Kotlin function [%s] requires %s to be on the classpath or module path. "
+ "Please add a corresponding dependency.").formatted(method.toGenericString(),
dependencyNotation));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package org.junit.jupiter.api.kotlin

import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.MethodSource

/**
* Tests for custom inline value classes.
*
* Currently, disabled: The POC only supports kotlin.Result.
* Support for arbitrary inline value classes needs to be added.
*
* @see <a href="https://github.com/junit-team/junit-framework/issues/5081">Issue #5081</a>
*/
@Disabled("POC only supports kotlin.Result, not custom inline value classes")
class CustomInlineValueClassTest {

@MethodSource("userIdProvider")
@ParameterizedTest
fun testUserId(userId: UserId) {
assertEquals(123L, userId.value)
}

@MethodSource("emailProvider")
@ParameterizedTest
fun testEmail(email: Email) {
assertEquals("test@example.com", email.value)
}

companion object {
@JvmStatic
fun userIdProvider() = listOf(Arguments.of(UserId(123L)))

@JvmStatic
fun emailProvider() = listOf(Arguments.of(Email("test@example.com")))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package org.junit.jupiter.api.kotlin

import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.MethodSource

class MultipleInlineValueClassTest {

@MethodSource("mixedProvider")
@ParameterizedTest
fun testMultipleValueClasses(
userId: UserId,
email: Email,
result: Result<String>
) {
assertEquals(100L, userId.value)
assertEquals("user@test.com", email.value)
assertEquals("data", result.getOrThrow())
}

@MethodSource("normalAndValueClassProvider")
@ParameterizedTest
fun testMixedParameters(
normalString: String,
userId: UserId
) {
assertEquals("normal", normalString)
assertEquals(200L, userId.value)
}

companion object {
@JvmStatic
fun mixedProvider() = listOf(
Arguments.of(
UserId(100L),
Email("user@test.com"),
Result.success("data")
)
)

@JvmStatic
fun normalAndValueClassProvider() = listOf(
Arguments.of("normal", UserId(200L))
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.junit.jupiter.api.kotlin

import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.assertNotNull
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.MethodSource

class NullableInlineValueClassTest {

@MethodSource("nullableResultProvider")
@ParameterizedTest
fun testNullableResult(result: Result<String>?) {
assertNotNull(result)
assertEquals("test", result.getOrNull())
}

@MethodSource("nullableUserIdProvider")
@ParameterizedTest
fun testNullableUserId(userId: UserId?) {
assertNotNull(userId)
assertEquals(999L, userId.value)
}

companion object {
@JvmStatic
fun nullableResultProvider() =
listOf(
Arguments.of(Result.success("test"))
)

@JvmStatic
fun nullableUserIdProvider() =
listOf(
Arguments.of(UserId(999L))
)
}
}
Loading