1616
1717import com .bruce .intellijplugin .generatesetter .CommonConstants ;
1818import com .bruce .intellijplugin .generatesetter .GenerateAllHandlerAdapter ;
19+ import com .google .common .collect .ImmutableMap ;
1920import com .intellij .openapi .editor .Editor ;
21+ import com .intellij .openapi .module .Module ;
2022import com .intellij .openapi .project .Project ;
23+ import com .intellij .openapi .roots .ProjectFileIndex ;
2124import com .intellij .openapi .roots .ProjectRootManager ;
2225import com .intellij .openapi .vfs .VirtualFile ;
2326import com .intellij .psi .PsiClass ;
2730import com .intellij .psi .PsiImportStatement ;
2831import com .intellij .psi .PsiImportStaticStatement ;
2932import com .intellij .psi .PsiJavaFile ;
33+ import com .intellij .psi .search .GlobalSearchScope ;
34+ import com .intellij .psi .search .PsiShortNamesCache ;
3035import org .jetbrains .annotations .NotNull ;
3136
32- import java .util .HashMap ;
3337import java .util .HashSet ;
3438import java .util .Map ;
3539import java .util .Set ;
3640
37- import static com .bruce .intellijplugin .generatesetter .actions .AssertAllGetterAction .TestEngine .*;
41+ import static com .bruce .intellijplugin .generatesetter .actions .AssertAllGetterAction .TestEngine .ASSERT ;
42+ import static com .bruce .intellijplugin .generatesetter .actions .AssertAllGetterAction .TestEngine .ASSERTJ ;
43+ import static com .bruce .intellijplugin .generatesetter .actions .AssertAllGetterAction .TestEngine .JUNIT4 ;
44+ import static com .bruce .intellijplugin .generatesetter .actions .AssertAllGetterAction .TestEngine .JUNIT5 ;
45+ import static com .bruce .intellijplugin .generatesetter .actions .AssertAllGetterAction .TestEngine .TESTNG ;
3846
3947/**
4048 * @author bruce ge
4149 */
4250public class AssertAllGetterAction extends GenerateAllSetterBase {
4351 enum TestEngine {ASSERT , JUNIT4 , JUNIT5 , TESTNG , ASSERTJ }
4452
45- private TestEngine currentFileTestEngine = TestEngine .ASSERT ;
46- private Set <TestEngine > currentFileAssertsImported = new HashSet <>();
47-
48- private static final Map <TestEngine , String > engineImports = new HashMap <>();
49- // className -> engine
50- private static final Map <String , TestEngine > enginePlainImportsReversed = new HashMap <>();
51- // static className without static and method -> engine
52- private static final Map <String , TestEngine > engineStaticImportsReversed = new HashMap <>();
53- // engine -> static method like assertEquals
54- private static final Map <TestEngine , String > engineStaticImportsMethod = new HashMap <>();
55-
56- static {
57- engineImports .put (JUNIT4 , "static org.junit.Assert.assertEquals" );
58- engineImports .put (JUNIT5 , "static org.junit.jupiter.api.Assertions.assertEquals" );
59- engineImports .put (TESTNG , "static org.testng.Assert.assertEquals" );
60- engineImports .put (ASSERTJ , "static org.assertj.core.api.Assertions.assertThat" );
61- engineImports .put (ASSERT , "java.util.Objects" );
62-
63- engineImports .forEach ((a , b ) -> {
64- if (!b .startsWith ("static " )) {
65- enginePlainImportsReversed .put (b , a );
66- } else {
67- engineStaticImportsReversed .put (b
68- .substring (0 , b .lastIndexOf ("." ))
69- .replace ("static " , "" ), a );
70- engineStaticImportsMethod .put (a , b .substring (b .lastIndexOf ("." ) + 1 ));
71- }
72- });
73- }
53+ // imports to add when generating asserts.
54+ private static final Map <TestEngine , String > engineImports = ImmutableMap .<TestEngine , String >builder ()
55+ .put (JUNIT4 , "static org.junit.Assert.assertEquals" )
56+ .put (JUNIT5 , "static org.junit.jupiter.api.Assertions.assertEquals" )
57+ .put (TESTNG , "static org.testng.Assert.assertEquals" )
58+ .put (ASSERTJ , "static org.assertj.core.api.Assertions.assertThat" )
59+ .put (ASSERT , "java.util.Objects" )
60+ .build ();
61+
62+ // className like 'java.util.Objects' -> engine (only java.util.Objects)
63+ private static final Map <String , TestEngine > enginePlainImportsReversed = ImmutableMap .<String , TestEngine >builder ()
64+ .put ("java.util.Objects" , ASSERT )
65+ .build ();
66+
67+ // static className -> engine
68+ private static final Map <String , TestEngine > engineStaticImportsReversed = ImmutableMap .<String , TestEngine >builder ()
69+ .put ("org.junit.Assert" , JUNIT4 )
70+ .put ("org.junit.jupiter.api.Assertions" , JUNIT5 )
71+ .put ("org.testng.Assert" , TESTNG )
72+ .put ("org.assertj.core.api.Assertions" , ASSERTJ )
73+ .build ();
74+
75+ // engine -> assert static method
76+ private static final Map <TestEngine , String > engineStaticImportsMethod = ImmutableMap .<TestEngine , String >builder ()
77+ .put (JUNIT4 , "assertEquals" )
78+ .put (JUNIT5 , "assertEquals" )
79+ .put (TESTNG , "assertEquals" )
80+ .put (ASSERTJ , "assertThat" )
81+ .build ();
82+
83+ private Project project ;
84+ private PsiFile containingFile ;
7485
7586 public AssertAllGetterAction () {
7687 setGenerateAllHandler (new GenerateAllAssertsHandlerAdapter (true ));
@@ -97,74 +108,114 @@ public boolean isAvailable(@NotNull Project project, Editor editor, @NotNull Psi
97108 boolean inTestSourceContent = instance .getFileIndex ().isInTestSourceContent (virtualFile );
98109
99110 if (inTestSourceContent ) {
100- currentFileAssertsImported = new HashSet <>() ;
101- currentFileTestEngine = detectCurrentTestEngine ( containingFile ) ;
111+ this . project = project ;
112+ this . containingFile = containingFile ;
102113 return super .isAvailable (project , editor , element );
103114 }
104115 return false ;
105116 }
106117
107- private TestEngine detectCurrentTestEngine (PsiFile containingFile ) {
108- if (containingFile instanceof PsiJavaFile ) {
109- PsiJavaFile javaFile = (PsiJavaFile ) containingFile ;
110- PsiImportList importList = javaFile .getImportList ();
111118
112- if (importList != null ) {
113- PsiImportStaticStatement [] importStaticStatements = importList .getImportStaticStatements ();
114- for (PsiImportStaticStatement importStaticStatement : importStaticStatements ) {
119+ class GenerateAllAssertsHandlerAdapter extends GenerateAllHandlerAdapter {
120+ private final boolean generateWithDefaultValues ;
121+ private final Set <TestEngine > currentFileImportedEngines = new HashSet <>();
122+ private TestEngine currentFileTestEngine = TestEngine .ASSERT ;
115123
116- PsiClass psiClass = importStaticStatement .resolveTargetClass ();
117- if (psiClass == null ) {
118- continue ;
119- }
124+ public GenerateAllAssertsHandlerAdapter (boolean generateWithDefaultValues ) {
125+ this .generateWithDefaultValues = generateWithDefaultValues ;
126+ }
127+
128+ private TestEngine detectCurrentTestEngine (Project project , PsiFile containingFile ) {
129+
130+ if (containingFile instanceof PsiJavaFile ) {
131+ PsiJavaFile javaFile = (PsiJavaFile ) containingFile ;
132+ PsiImportList importList = javaFile .getImportList ();
120133
121- String qualifiedName = psiClass .getQualifiedName ();
122- TestEngine testEngine = engineStaticImportsReversed .get (qualifiedName );
123- if (testEngine != null ) {
124- String referenceName = importStaticStatement .getReferenceName ();
125- if (referenceName == null || referenceName .equals (engineStaticImportsMethod .get (testEngine ))) {
126- currentFileAssertsImported .add (testEngine );
134+ // prefer AssertJ if it is in classpath
135+ ProjectFileIndex index = ProjectFileIndex .getInstance (project );
136+ Module module = index .getModuleForFile (containingFile .getVirtualFile ());
137+ if (module != null ) {
138+ GlobalSearchScope searchScope = GlobalSearchScope .moduleRuntimeScope (module , true );
139+ PsiClass [] lists = PsiShortNamesCache .getInstance (project )
140+ .getClassesByName ("Assertions" , searchScope );
141+
142+ for (PsiClass psiClass : lists ) {
143+ if ("org.assertj.core.api.Assertions" .equals (psiClass .getName ())) {
144+ detectImportedEngines (importList );
145+ return ASSERTJ ;
127146 }
128147 }
129148 }
130149
131- PsiImportStatement [] importStatements = importList .getImportStatements ();
132150
133- for (PsiImportStatement importStatement : importStatements ) {
134- String qualifiedName = importStatement .getQualifiedName ();
135- if (qualifiedName == null ) {
136- continue ;
137- }
151+ if (importList != null ) {
152+ detectImportedEngines (importList );
138153
139- TestEngine testEngine = enginePlainImportsReversed .get (qualifiedName );
140- if (testEngine != null ) {
141- currentFileAssertsImported .add (testEngine );
142- }
154+ PsiImportStatement [] importStatements = importList .getImportStatements ();
143155
144- if (qualifiedName .startsWith ("org.junit.jupiter.api." )) {
145- return TestEngine .JUNIT5 ;
146- }
147- if (qualifiedName .startsWith ("org.junit." )) {
148- return TestEngine .JUNIT4 ;
149- }
150- if (qualifiedName .startsWith ("org.assertj." )) {
151- return TestEngine .ASSERTJ ;
152- }
153- if (qualifiedName .startsWith ("org.testng." )) {
154- return TestEngine .TESTNG ;
156+ for (PsiImportStatement importStatement : importStatements ) {
157+ String qualifiedName = importStatement .getQualifiedName ();
158+ if (qualifiedName == null ) {
159+ continue ;
160+ }
161+
162+ if (qualifiedName .startsWith ("org.junit.jupiter.api." )) {
163+ return TestEngine .JUNIT5 ;
164+ }
165+ if (qualifiedName .startsWith ("org.junit." )) {
166+ return TestEngine .JUNIT4 ;
167+ }
168+ if (qualifiedName .startsWith ("org.assertj." )) {
169+ return TestEngine .ASSERTJ ;
170+ }
171+ if (qualifiedName .startsWith ("org.testng." )) {
172+ return TestEngine .TESTNG ;
173+ }
155174 }
156175 }
157176 }
177+
178+ return TestEngine .ASSERT ;
158179 }
159180
160- return TestEngine . ASSERT ;
161- }
181+ private void detectImportedEngines ( PsiImportList importList ) {
182+ currentFileImportedEngines . clear ();
162183
163- class GenerateAllAssertsHandlerAdapter extends GenerateAllHandlerAdapter {
164- private final boolean generateWithDefaultValues ;
184+ if (importList == null ) {
185+ return ;
186+ }
165187
166- public GenerateAllAssertsHandlerAdapter (boolean generateWithDefaultValues ) {
167- this .generateWithDefaultValues = generateWithDefaultValues ;
188+ PsiImportStaticStatement [] importStaticStatements = importList .getImportStaticStatements ();
189+
190+ for (PsiImportStaticStatement importStaticStatement : importStaticStatements ) {
191+ PsiClass psiClass = importStaticStatement .resolveTargetClass ();
192+ if (psiClass == null ) {
193+ continue ;
194+ }
195+
196+ String qualifiedName = psiClass .getQualifiedName ();
197+ TestEngine testEngine = engineStaticImportsReversed .get (qualifiedName );
198+ if (testEngine != null ) {
199+ String referenceName = importStaticStatement .getReferenceName (); // like assertEquals
200+ if (referenceName == null || referenceName .equals (engineStaticImportsMethod .get (testEngine ))) {
201+ currentFileImportedEngines .add (testEngine );
202+ }
203+ }
204+ }
205+
206+ PsiImportStatement [] importStatements = importList .getImportStatements ();
207+
208+ for (PsiImportStatement importStatement : importStatements ) {
209+ String qualifiedName = importStatement .getQualifiedName ();
210+ if (qualifiedName == null ) {
211+ continue ;
212+ }
213+
214+ TestEngine testEngine = enginePlainImportsReversed .get (qualifiedName );
215+ if (testEngine != null ) {
216+ currentFileImportedEngines .add (testEngine );
217+ }
218+ }
168219 }
169220
170221 @ Override
@@ -195,6 +246,8 @@ public String formatLine(String line) {
195246 value = "" ;
196247 }
197248
249+ currentFileTestEngine = detectCurrentTestEngine (project , containingFile );
250+
198251 switch (currentFileTestEngine ) {
199252 case JUNIT4 :
200253 case JUNIT5 :
@@ -212,7 +265,7 @@ public String formatLine(String line) {
212265
213266 @ Override
214267 public void appendImportList (Set <String > newImportList ) {
215- if (!currentFileAssertsImported .contains (currentFileTestEngine )) {
268+ if (!currentFileImportedEngines .contains (currentFileTestEngine )) {
216269 newImportList .add (engineImports .get (currentFileTestEngine ));
217270 }
218271 }
0 commit comments