1313
1414import org .junit .jupiter .api .extension .BeforeTestExecutionCallback ;
1515import org .junit .jupiter .api .extension .ExtensionContext ;
16+ import org .junit .jupiter .api .extension .TestInstances ;
1617import org .mockito .InjectMocks ;
1718
1819import javassist .util .proxy .MethodFilter ;
@@ -37,31 +38,46 @@ public static boolean status() {
3738
3839 @ Override
3940 public void beforeTestExecution (ExtensionContext context ) throws Exception {
40- Object testInstance = context .getTestInstance ().get ();
41- if (testInstance != null ) {
42- final Map <String , Field > injectMap = new HashMap <>();
43- for (Field testClassField : testInstance .getClass ().getDeclaredFields ()) {
44- if (testClassField .getAnnotation (InjectMocks .class ) != null ) {
45- testClassField .setAccessible (true );
46- final Object injectionTarget = testClassField .get (testInstance );
47- final ProxyFactory proxyFactory = new ProxyFactory ();
48- proxyFactory .setSuperclass (injectionTarget .getClass ());
49- proxyFactory .setFilter (createMethodFilter ());
50- final Class <?> proxyClass = proxyFactory .createClass ();
51- final Object proxy = proxyClass .newInstance ();
52- final Map <String , List <Field >> fieldMap = createFieldMap (injectionTarget .getClass ());
53- Method postConstructMethod ;
54- if (testClassField .getAnnotation (InvokePostConstruct .class ) != null ) {
55- postConstructMethod = findPostConstructMethod (injectionTarget );
56- } else {
57- postConstructMethod = null ;
58- }
59- final MethodHandler handler = createMethodHandler (injectMap , injectionTarget , fieldMap , testInstance , postConstructMethod );
60- ((Proxy ) proxy ).setHandler (handler );
61- testClassField .set (testInstance , proxy );
62- } else if (testClassField .getAnnotation (InjectionSource .class ) != null ) {
63- injectMap .put (testClassField .getName (), testClassField );
41+ final Object testInstance = context .getTestInstance ().get ();
42+ final Class <?> actualTestClazz = testInstance .getClass ();
43+ System .out .println ("actualTestClass: " + actualTestClazz .getName ());
44+ Class <?> topLevelClass = actualTestClazz ;
45+ while (topLevelClass .getEnclosingClass () != null ) {
46+ topLevelClass = topLevelClass .getEnclosingClass ();
47+ }
48+ final Class <?> parentTestClass = topLevelClass != actualTestClazz ? topLevelClass : null ;
49+
50+ final Object parentClassInstance = context
51+ .getTestInstances ()
52+ .get ()
53+ .getEnclosingInstances ()
54+ .stream ()
55+ .filter (i -> i .getClass () == parentTestClass )
56+ .findFirst ()
57+ .orElse (testInstance );
58+ System .out .println ("enclosing instance class name: " + parentClassInstance .getClass ().getName ());
59+ final Map <String , Field > injectMap = new HashMap <>();
60+ for (Field testClassField : parentClassInstance .getClass ().getDeclaredFields ()) {
61+ if (testClassField .getAnnotation (InjectMocks .class ) != null ) {
62+ testClassField .setAccessible (true );
63+ final Object injectionTarget = testClassField .get (parentClassInstance );
64+ final ProxyFactory proxyFactory = new ProxyFactory ();
65+ proxyFactory .setSuperclass (injectionTarget .getClass ());
66+ proxyFactory .setFilter (createMethodFilter ());
67+ final Class <?> proxyClass = proxyFactory .createClass ();
68+ final Object proxy = proxyClass .newInstance ();
69+ final Map <String , List <Field >> fieldMap = createFieldMap (injectionTarget .getClass ());
70+ Method postConstructMethod ;
71+ if (testClassField .getAnnotation (InvokePostConstruct .class ) != null ) {
72+ postConstructMethod = findPostConstructMethod (injectionTarget );
73+ } else {
74+ postConstructMethod = null ;
6475 }
76+ final MethodHandler handler = createMethodHandler (injectMap , injectionTarget , fieldMap , parentClassInstance , postConstructMethod );
77+ ((Proxy ) proxy ).setHandler (handler );
78+ testClassField .set (parentClassInstance , proxy );
79+ } else if (testClassField .getAnnotation (InjectionSource .class ) != null ) {
80+ injectMap .put (testClassField .getName (), testClassField );
6581 }
6682 }
6783 }
@@ -73,21 +89,16 @@ private Method findPostConstructMethod(Object injectionTarget) {
7389 }
7490 }
7591 throw new RuntimeException (
76- "@InvokePostConstruct is delcared on:" + injectionTarget + " however no method annotated with @PostConstruct found" );
92+ "@InvokePostConstruct is declared on:" + injectionTarget + " however no method annotated with @PostConstruct found" );
7793 }
7894
79- private Map <String , List <Field >> createFieldMap (Class <? extends Object > targetClass ) {
95+ private Map <String , List <Field >> createFieldMap (Class <?> targetClass ) {
8096 if (targetClass == Object .class ) {
8197 return new HashMap <>();
8298 } else {
8399 Map <String , List <Field >> fieldMap = createFieldMap (targetClass .getSuperclass ());
84100 for (Field field : targetClass .getDeclaredFields ()) {
85- List <Field > fieldList = fieldMap .get (field .getName ());
86- if (fieldList == null ) {
87- fieldList = new LinkedList <>();
88- fieldMap .put (field .getName (), fieldList );
89- }
90- fieldList .add (field );
101+ fieldMap .computeIfAbsent (field .getName (), k -> new LinkedList <>()).add (field );
91102 }
92103 return fieldMap ;
93104 }
0 commit comments