2121import org .openrewrite .java .tree .*;
2222import org .openrewrite .staticanalysis .java .JavaFileChecker ;
2323
24+ import org .jspecify .annotations .Nullable ;
25+
2426import java .util .ArrayList ;
2527import java .util .List ;
2628
@@ -90,7 +92,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
9092 inferredType = methodDeclaration .getReturnTypeExpression ().getType ();
9193 }
9294 } else if (e instanceof J .Lambda ) {
93- inferredType = (( J .Lambda ) e ).getType ();
95+ inferredType = getLambdaReturnType ((( J .Lambda ) e ).getType () );
9496 }
9597 }
9698
@@ -101,6 +103,46 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
101103 return m ;
102104 }
103105
106+ private JavaType .@ Nullable Method findMethodIfUnambiguous (JavaType .FullyQualified type ) {
107+ JavaType .Method sam = null ;
108+ for (JavaType .Method candidate : type .getMethods ()) {
109+ if (candidate .hasFlags (Flag .Default ) || candidate .hasFlags (Flag .Static )) {
110+ continue ;
111+ }
112+ if (sam != null ) {
113+ return null ;
114+ }
115+ sam = candidate ;
116+ }
117+ return sam ;
118+ }
119+
120+ private @ Nullable JavaType getLambdaReturnType (@ Nullable JavaType lambdaType ) {
121+ JavaType .Parameterized parameterized = TypeUtils .asParameterized (lambdaType );
122+ if (parameterized == null ) {
123+ return null ;
124+ }
125+ JavaType .Method sam = findMethodIfUnambiguous (parameterized );
126+ if (sam == null ) {
127+ return null ;
128+ }
129+ JavaType samReturn = sam .getReturnType ();
130+ if (samReturn instanceof JavaType .GenericTypeVariable ) {
131+ String name = ((JavaType .GenericTypeVariable ) samReturn ).getName ();
132+ List <JavaType > formalParams = parameterized .getType ().getTypeParameters ();
133+ List <JavaType > actualParams = parameterized .getTypeParameters ();
134+ for (int i = 0 ; i < formalParams .size () && i < actualParams .size (); i ++) {
135+ JavaType formal = formalParams .get (i );
136+ if (formal instanceof JavaType .GenericTypeVariable &&
137+ name .equals (((JavaType .GenericTypeVariable ) formal ).getName ())) {
138+ return actualParams .get (i );
139+ }
140+ }
141+ return null ;
142+ }
143+ return samReturn ;
144+ }
145+
104146 private boolean shouldRetainOnStaticMethod (JavaType .Method methodType ) {
105147 if (!methodType .hasFlags (Flag .Static )) {
106148 return false ;
0 commit comments