Skip to content

Commit 4b3d12f

Browse files
authored
Merge pull request #8 from Aust1n46/nested
Support nested test classes. Supports any depth of nesting
2 parents 3b4b505 + 966c306 commit 4b3d12f

File tree

1 file changed

+41
-32
lines changed

1 file changed

+41
-32
lines changed

src/main/java/com/github/exabrial/junit5/injectmap/InjectExtension.java

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
1515
import org.junit.jupiter.api.extension.ExtensionContext;
16+
import org.junit.jupiter.api.extension.TestInstances;
1617
import org.mockito.InjectMocks;
1718

1819
import javassist.util.proxy.MethodFilter;
@@ -37,31 +38,44 @@ public static boolean status() {
3738

3839
@Override
3940
public void beforeTestExecution(ExtensionContext context) throws Exception {
40-
Object testInstance = context.getTestInstance().get();
41-
if (testInstance != null) {
42-
final Map<String, Field> injectMap = new HashMap<>();
43-
for (Field testClassField : testInstance.getClass().getDeclaredFields()) {
44-
if (testClassField.getAnnotation(InjectMocks.class) != null) {
45-
testClassField.setAccessible(true);
46-
final Object injectionTarget = testClassField.get(testInstance);
47-
final ProxyFactory proxyFactory = new ProxyFactory();
48-
proxyFactory.setSuperclass(injectionTarget.getClass());
49-
proxyFactory.setFilter(createMethodFilter());
50-
final Class<?> proxyClass = proxyFactory.createClass();
51-
final Object proxy = proxyClass.newInstance();
52-
final Map<String, List<Field>> fieldMap = createFieldMap(injectionTarget.getClass());
53-
Method postConstructMethod;
54-
if (testClassField.getAnnotation(InvokePostConstruct.class) != null) {
55-
postConstructMethod = findPostConstructMethod(injectionTarget);
56-
} else {
57-
postConstructMethod = null;
58-
}
59-
final MethodHandler handler = createMethodHandler(injectMap, injectionTarget, fieldMap, testInstance, postConstructMethod);
60-
((Proxy) proxy).setHandler(handler);
61-
testClassField.set(testInstance, proxy);
62-
} else if (testClassField.getAnnotation(InjectionSource.class) != null) {
63-
injectMap.put(testClassField.getName(), testClassField);
41+
final Object testInstance = context.getTestInstance().get();
42+
final Class<?> actualTestClazz = testInstance.getClass();
43+
Class<?> topLevelClass = actualTestClazz;
44+
while (topLevelClass.getEnclosingClass() != null) {
45+
topLevelClass = topLevelClass.getEnclosingClass();
46+
}
47+
final Class<?> parentTestClass = topLevelClass != actualTestClazz ? topLevelClass : null;
48+
49+
final Object parentClassInstance = context
50+
.getTestInstances()
51+
.get()
52+
.getEnclosingInstances()
53+
.stream()
54+
.filter(i -> i.getClass() == parentTestClass)
55+
.findFirst()
56+
.orElse(testInstance);
57+
final Map<String, Field> injectMap = new HashMap<>();
58+
for (Field testClassField : parentClassInstance.getClass().getDeclaredFields()) {
59+
if (testClassField.getAnnotation(InjectMocks.class) != null) {
60+
testClassField.setAccessible(true);
61+
final Object injectionTarget = testClassField.get(parentClassInstance);
62+
final ProxyFactory proxyFactory = new ProxyFactory();
63+
proxyFactory.setSuperclass(injectionTarget.getClass());
64+
proxyFactory.setFilter(createMethodFilter());
65+
final Class<?> proxyClass = proxyFactory.createClass();
66+
final Object proxy = proxyClass.newInstance();
67+
final Map<String, List<Field>> fieldMap = createFieldMap(injectionTarget.getClass());
68+
Method postConstructMethod;
69+
if (testClassField.getAnnotation(InvokePostConstruct.class) != null) {
70+
postConstructMethod = findPostConstructMethod(injectionTarget);
71+
} else {
72+
postConstructMethod = null;
6473
}
74+
final MethodHandler handler = createMethodHandler(injectMap, injectionTarget, fieldMap, parentClassInstance, postConstructMethod);
75+
((Proxy) proxy).setHandler(handler);
76+
testClassField.set(parentClassInstance, proxy);
77+
} else if (testClassField.getAnnotation(InjectionSource.class) != null) {
78+
injectMap.put(testClassField.getName(), testClassField);
6579
}
6680
}
6781
}
@@ -73,21 +87,16 @@ private Method findPostConstructMethod(Object injectionTarget) {
7387
}
7488
}
7589
throw new RuntimeException(
76-
"@InvokePostConstruct is delcared on:" + injectionTarget + " however no method annotated with @PostConstruct found");
90+
"@InvokePostConstruct is declared on:" + injectionTarget + " however no method annotated with @PostConstruct found");
7791
}
7892

79-
private Map<String, List<Field>> createFieldMap(Class<? extends Object> targetClass) {
93+
private Map<String, List<Field>> createFieldMap(Class<?> targetClass) {
8094
if (targetClass == Object.class) {
8195
return new HashMap<>();
8296
} else {
8397
Map<String, List<Field>> fieldMap = createFieldMap(targetClass.getSuperclass());
8498
for (Field field : targetClass.getDeclaredFields()) {
85-
List<Field> fieldList = fieldMap.get(field.getName());
86-
if (fieldList == null) {
87-
fieldList = new LinkedList<>();
88-
fieldMap.put(field.getName(), fieldList);
89-
}
90-
fieldList.add(field);
99+
fieldMap.computeIfAbsent(field.getName(), k -> new LinkedList<>()).add(field);
91100
}
92101
return fieldMap;
93102
}

0 commit comments

Comments
 (0)