Skip to content

Commit 10366d2

Browse files
author
Austin Brolly
committed
Support nested test classes
1 parent 3b4b505 commit 10366d2

File tree

1 file changed

+43
-32
lines changed

1 file changed

+43
-32
lines changed

src/main/java/com/github/exabrial/junit5/injectmap/InjectExtension.java

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
1515
import org.junit.jupiter.api.extension.ExtensionContext;
16+
import org.junit.jupiter.api.extension.TestInstances;
1617
import org.mockito.InjectMocks;
1718

1819
import javassist.util.proxy.MethodFilter;
@@ -37,31 +38,46 @@ public static boolean status() {
3738

3839
@Override
3940
public void beforeTestExecution(ExtensionContext context) throws Exception {
40-
Object testInstance = context.getTestInstance().get();
41-
if (testInstance != null) {
42-
final Map<String, Field> injectMap = new HashMap<>();
43-
for (Field testClassField : testInstance.getClass().getDeclaredFields()) {
44-
if (testClassField.getAnnotation(InjectMocks.class) != null) {
45-
testClassField.setAccessible(true);
46-
final Object injectionTarget = testClassField.get(testInstance);
47-
final ProxyFactory proxyFactory = new ProxyFactory();
48-
proxyFactory.setSuperclass(injectionTarget.getClass());
49-
proxyFactory.setFilter(createMethodFilter());
50-
final Class<?> proxyClass = proxyFactory.createClass();
51-
final Object proxy = proxyClass.newInstance();
52-
final Map<String, List<Field>> fieldMap = createFieldMap(injectionTarget.getClass());
53-
Method postConstructMethod;
54-
if (testClassField.getAnnotation(InvokePostConstruct.class) != null) {
55-
postConstructMethod = findPostConstructMethod(injectionTarget);
56-
} else {
57-
postConstructMethod = null;
58-
}
59-
final MethodHandler handler = createMethodHandler(injectMap, injectionTarget, fieldMap, testInstance, postConstructMethod);
60-
((Proxy) proxy).setHandler(handler);
61-
testClassField.set(testInstance, proxy);
62-
} else if (testClassField.getAnnotation(InjectionSource.class) != null) {
63-
injectMap.put(testClassField.getName(), testClassField);
41+
final Object testInstance = context.getTestInstance().get();
42+
final Class<?> actualTestClazz = testInstance.getClass();
43+
System.out.println("actualTestClass: " + actualTestClazz.getName());
44+
Class<?> topLevelClass = actualTestClazz;
45+
while (topLevelClass.getEnclosingClass() != null) {
46+
topLevelClass = topLevelClass.getEnclosingClass();
47+
}
48+
final Class<?> parentTestClass = topLevelClass != actualTestClazz ? topLevelClass : null;
49+
50+
final Object parentClassInstance = context
51+
.getTestInstances()
52+
.get()
53+
.getEnclosingInstances()
54+
.stream()
55+
.filter(i -> i.getClass() == parentTestClass)
56+
.findFirst()
57+
.orElse(testInstance);
58+
System.out.println("enclosing instance class name: " + parentClassInstance.getClass().getName());
59+
final Map<String, Field> injectMap = new HashMap<>();
60+
for (Field testClassField : parentClassInstance.getClass().getDeclaredFields()) {
61+
if (testClassField.getAnnotation(InjectMocks.class) != null) {
62+
testClassField.setAccessible(true);
63+
final Object injectionTarget = testClassField.get(parentClassInstance);
64+
final ProxyFactory proxyFactory = new ProxyFactory();
65+
proxyFactory.setSuperclass(injectionTarget.getClass());
66+
proxyFactory.setFilter(createMethodFilter());
67+
final Class<?> proxyClass = proxyFactory.createClass();
68+
final Object proxy = proxyClass.newInstance();
69+
final Map<String, List<Field>> fieldMap = createFieldMap(injectionTarget.getClass());
70+
Method postConstructMethod;
71+
if (testClassField.getAnnotation(InvokePostConstruct.class) != null) {
72+
postConstructMethod = findPostConstructMethod(injectionTarget);
73+
} else {
74+
postConstructMethod = null;
6475
}
76+
final MethodHandler handler = createMethodHandler(injectMap, injectionTarget, fieldMap, parentClassInstance, postConstructMethod);
77+
((Proxy) proxy).setHandler(handler);
78+
testClassField.set(parentClassInstance, proxy);
79+
} else if (testClassField.getAnnotation(InjectionSource.class) != null) {
80+
injectMap.put(testClassField.getName(), testClassField);
6581
}
6682
}
6783
}
@@ -73,21 +89,16 @@ private Method findPostConstructMethod(Object injectionTarget) {
7389
}
7490
}
7591
throw new RuntimeException(
76-
"@InvokePostConstruct is delcared on:" + injectionTarget + " however no method annotated with @PostConstruct found");
92+
"@InvokePostConstruct is declared on:" + injectionTarget + " however no method annotated with @PostConstruct found");
7793
}
7894

79-
private Map<String, List<Field>> createFieldMap(Class<? extends Object> targetClass) {
95+
private Map<String, List<Field>> createFieldMap(Class<?> targetClass) {
8096
if (targetClass == Object.class) {
8197
return new HashMap<>();
8298
} else {
8399
Map<String, List<Field>> fieldMap = createFieldMap(targetClass.getSuperclass());
84100
for (Field field : targetClass.getDeclaredFields()) {
85-
List<Field> fieldList = fieldMap.get(field.getName());
86-
if (fieldList == null) {
87-
fieldList = new LinkedList<>();
88-
fieldMap.put(field.getName(), fieldList);
89-
}
90-
fieldList.add(field);
101+
fieldMap.computeIfAbsent(field.getName(), k -> new LinkedList<>()).add(field);
91102
}
92103
return fieldMap;
93104
}

0 commit comments

Comments
 (0)