Skip to content

Commit 0ea5975

Browse files
timtebeekclaude
andauthored
Prevent ReplaceLambdaWithMethodReference in nested generic/overloaded contexts (#776)
* Prevent ReplaceLambdaWithMethodReference in nested generic/overloaded contexts (#774) Adds a check to prevent converting lambdas to method references when the lambda is passed through method calls where one of the enclosing methods is overloaded and type inference depends on the lambda's explicit return type. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Slight polish --------- Co-authored-by: Claude <noreply@anthropic.com>
1 parent 95e3cad commit 0ea5975

File tree

2 files changed

+124
-2
lines changed

2 files changed

+124
-2
lines changed

src/main/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReference.java

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
7272
}
7373

7474
private static class ReplaceLambdaWithMethodReferenceKotlinVisitor extends KotlinVisitor<ExecutionContext> {
75-
// Implement Me
75+
// XXX Implement Me
7676
}
7777

7878
private static class ReplaceLambdaWithMethodReferenceJavaVisitor extends JavaVisitor<ExecutionContext> {
@@ -195,6 +195,10 @@ public J visitLambda(J.Lambda lambda, ExecutionContext ctx) {
195195
.anyMatch(JavaType.GenericTypeVariable.class::isInstance)) {
196196
return l;
197197
}
198+
// Check if transforming would break type inference in nested generic/overloaded context
199+
if (isLambdaInGenericAndOverloadedContext()) {
200+
return l;
201+
}
198202
J.MemberReference updated = newStaticMethodReference(methodType, true, lambda.getType()).withPrefix(lambda.getPrefix());
199203
doAfterVisit(service(ImportService.class).shortenFullyQualifiedTypeReferencesIn(updated));
200204
return updated;
@@ -358,10 +362,68 @@ private boolean isMethodReferenceAmbiguous(JavaType.Method method) {
358362
}
359363
return false;
360364
}
365+
366+
/**
367+
* Check if the lambda is in a context where converting it to a method reference
368+
* would break type inference. This occurs when a lambda's return type depends on
369+
* generic type inference, and the lambda is passed through method calls where
370+
* one of the enclosing methods is overloaded.
371+
* <p>
372+
* Example: foo(fold(() -> Optional.empty()))
373+
* where fold is generic and foo is overloaded.
374+
*/
375+
private boolean isLambdaInGenericAndOverloadedContext() {
376+
// Walk up the cursor tree to find enclosing method invocations
377+
Cursor cursor = getCursor();
378+
379+
// Find the first method invocation that the lambda is an argument to
380+
Cursor parent = cursor.dropParentUntil(p -> p instanceof J.MethodInvocation || p instanceof SourceFile);
381+
if (!(parent.getValue() instanceof J.MethodInvocation)) {
382+
return false;
383+
}
384+
385+
J.MethodInvocation innerMethod = parent.getValue();
386+
387+
// Now check if there's an enclosing overloaded method where innerMethod is an argument
388+
// Start from the parent of the first method invocation
389+
Cursor grandparent = parent.getParent();
390+
if (grandparent == null) {
391+
return false;
392+
}
393+
394+
grandparent = grandparent.dropParentUntil(p -> p instanceof J.MethodInvocation || p instanceof SourceFile);
395+
if (!(grandparent.getValue() instanceof J.MethodInvocation)) {
396+
return false;
397+
}
398+
399+
J.MethodInvocation outerMethod = grandparent.getValue();
400+
401+
// Check that innerMethod is actually an ARGUMENT to outerMethod, not just chained
402+
boolean isInnerMethodAnArgument = outerMethod.getArguments().stream()
403+
.anyMatch(arg -> arg == innerMethod);
404+
405+
if (!isInnerMethodAnArgument) {
406+
return false;
407+
}
408+
409+
JavaType.Method outerType = outerMethod.getMethodType();
410+
if (outerType == null) {
411+
return false;
412+
}
413+
414+
// Check if the outer method is overloaded
415+
long overloadCount = outerType.getDeclaringType().getMethods().stream()
416+
.filter(m -> m.getName().equals(outerType.getName()) && !m.isConstructor())
417+
.count();
418+
419+
// If we have nested method calls where the outer one is overloaded,
420+
// be conservative and don't transform to avoid breaking type inference
421+
return overloadCount > 1;
422+
}
361423
}
362424

363425
private static boolean isAMethodInvocationArgument(J.Lambda lambda, Cursor cursor) {
364-
Cursor parent = cursor.dropParentUntil(p -> p instanceof J.MethodInvocation || p instanceof J.CompilationUnit);
426+
Cursor parent = cursor.dropParentUntil(p -> p instanceof J.MethodInvocation || p instanceof SourceFile);
365427
if (parent.getValue() instanceof J.MethodInvocation) {
366428
J.MethodInvocation m = parent.getValue();
367429
return m.getArguments().stream().anyMatch(arg -> arg == lambda);

src/test/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReferenceTest.java

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,4 +1576,64 @@ private void foo() {
15761576
)
15771577
);
15781578
}
1579+
1580+
@Issue("https://github.com/openrewrite/rewrite-static-analysis/issues/774")
1581+
@Test
1582+
void methodRefWithGenerics() {
1583+
rewriteRun(
1584+
//language=java
1585+
java(
1586+
"""
1587+
import java.util.Optional;
1588+
import java.util.function.Supplier;
1589+
1590+
class Foo {
1591+
<R> R fold(final Supplier<R> supplier) {return null;}
1592+
1593+
void foo(String l) {}
1594+
void foo(Optional<String> l) {}
1595+
1596+
void bar() {
1597+
foo(fold(() -> Optional.empty()));
1598+
}
1599+
}
1600+
"""
1601+
)
1602+
);
1603+
}
1604+
1605+
@Test
1606+
void simpleGenericMethodTest() {
1607+
rewriteRun(
1608+
//language=java
1609+
java(
1610+
"""
1611+
import java.util.function.Supplier;
1612+
1613+
class Foo {
1614+
<R> R fold(final Supplier<R> supplier) {return null;}
1615+
1616+
String getString() { return "test"; }
1617+
1618+
void bar() {
1619+
String result = fold(() -> getString());
1620+
}
1621+
}
1622+
""",
1623+
"""
1624+
import java.util.function.Supplier;
1625+
1626+
class Foo {
1627+
<R> R fold(final Supplier<R> supplier) {return null;}
1628+
1629+
String getString() { return "test"; }
1630+
1631+
void bar() {
1632+
String result = fold(this::getString);
1633+
}
1634+
}
1635+
"""
1636+
)
1637+
);
1638+
}
15791639
}

0 commit comments

Comments
 (0)