1313
1414import org .junit .jupiter .api .extension .BeforeTestExecutionCallback ;
1515import org .junit .jupiter .api .extension .ExtensionContext ;
16+ import org .junit .jupiter .api .extension .TestInstances ;
1617import org .mockito .InjectMocks ;
1718
1819import javassist .util .proxy .MethodFilter ;
@@ -37,31 +38,44 @@ public static boolean status() {
3738
3839 @ Override
3940 public void beforeTestExecution (ExtensionContext context ) throws Exception {
40- Object testInstance = context .getTestInstance ().get ();
41- if (testInstance != null ) {
42- final Map <String , Field > injectMap = new HashMap <>();
43- for (Field testClassField : testInstance .getClass ().getDeclaredFields ()) {
44- if (testClassField .getAnnotation (InjectMocks .class ) != null ) {
45- testClassField .setAccessible (true );
46- final Object injectionTarget = testClassField .get (testInstance );
47- final ProxyFactory proxyFactory = new ProxyFactory ();
48- proxyFactory .setSuperclass (injectionTarget .getClass ());
49- proxyFactory .setFilter (createMethodFilter ());
50- final Class <?> proxyClass = proxyFactory .createClass ();
51- final Object proxy = proxyClass .newInstance ();
52- final Map <String , List <Field >> fieldMap = createFieldMap (injectionTarget .getClass ());
53- Method postConstructMethod ;
54- if (testClassField .getAnnotation (InvokePostConstruct .class ) != null ) {
55- postConstructMethod = findPostConstructMethod (injectionTarget );
56- } else {
57- postConstructMethod = null ;
58- }
59- final MethodHandler handler = createMethodHandler (injectMap , injectionTarget , fieldMap , testInstance , postConstructMethod );
60- ((Proxy ) proxy ).setHandler (handler );
61- testClassField .set (testInstance , proxy );
62- } else if (testClassField .getAnnotation (InjectionSource .class ) != null ) {
63- injectMap .put (testClassField .getName (), testClassField );
41+ final Object testInstance = context .getTestInstance ().get ();
42+ final Class <?> actualTestClazz = testInstance .getClass ();
43+ Class <?> topLevelClass = actualTestClazz ;
44+ while (topLevelClass .getEnclosingClass () != null ) {
45+ topLevelClass = topLevelClass .getEnclosingClass ();
46+ }
47+ final Class <?> parentTestClass = topLevelClass != actualTestClazz ? topLevelClass : null ;
48+
49+ final Object parentClassInstance = context
50+ .getTestInstances ()
51+ .get ()
52+ .getEnclosingInstances ()
53+ .stream ()
54+ .filter (i -> i .getClass () == parentTestClass )
55+ .findFirst ()
56+ .orElse (testInstance );
57+ final Map <String , Field > injectMap = new HashMap <>();
58+ for (Field testClassField : parentClassInstance .getClass ().getDeclaredFields ()) {
59+ if (testClassField .getAnnotation (InjectMocks .class ) != null ) {
60+ testClassField .setAccessible (true );
61+ final Object injectionTarget = testClassField .get (parentClassInstance );
62+ final ProxyFactory proxyFactory = new ProxyFactory ();
63+ proxyFactory .setSuperclass (injectionTarget .getClass ());
64+ proxyFactory .setFilter (createMethodFilter ());
65+ final Class <?> proxyClass = proxyFactory .createClass ();
66+ final Object proxy = proxyClass .newInstance ();
67+ final Map <String , List <Field >> fieldMap = createFieldMap (injectionTarget .getClass ());
68+ Method postConstructMethod ;
69+ if (testClassField .getAnnotation (InvokePostConstruct .class ) != null ) {
70+ postConstructMethod = findPostConstructMethod (injectionTarget );
71+ } else {
72+ postConstructMethod = null ;
6473 }
74+ final MethodHandler handler = createMethodHandler (injectMap , injectionTarget , fieldMap , parentClassInstance , postConstructMethod );
75+ ((Proxy ) proxy ).setHandler (handler );
76+ testClassField .set (parentClassInstance , proxy );
77+ } else if (testClassField .getAnnotation (InjectionSource .class ) != null ) {
78+ injectMap .put (testClassField .getName (), testClassField );
6579 }
6680 }
6781 }
@@ -73,21 +87,16 @@ private Method findPostConstructMethod(Object injectionTarget) {
7387 }
7488 }
7589 throw new RuntimeException (
76- "@InvokePostConstruct is delcared on:" + injectionTarget + " however no method annotated with @PostConstruct found" );
90+ "@InvokePostConstruct is declared on:" + injectionTarget + " however no method annotated with @PostConstruct found" );
7791 }
7892
79- private Map <String , List <Field >> createFieldMap (Class <? extends Object > targetClass ) {
93+ private Map <String , List <Field >> createFieldMap (Class <?> targetClass ) {
8094 if (targetClass == Object .class ) {
8195 return new HashMap <>();
8296 } else {
8397 Map <String , List <Field >> fieldMap = createFieldMap (targetClass .getSuperclass ());
8498 for (Field field : targetClass .getDeclaredFields ()) {
85- List <Field > fieldList = fieldMap .get (field .getName ());
86- if (fieldList == null ) {
87- fieldList = new LinkedList <>();
88- fieldMap .put (field .getName (), fieldList );
89- }
90- fieldList .add (field );
99+ fieldMap .computeIfAbsent (field .getName (), k -> new LinkedList <>()).add (field );
91100 }
92101 return fieldMap ;
93102 }
0 commit comments