Skip to content

Commit 934ca1f

Browse files
committed
Go: add receiver overrides, integ tests, and fix RPC parse input order
Add GolangReceiverDelegate overrides for visitArrayType, visitLiteral, visitMethodDeclaration, visitReturn, and visitVariableDeclarations to handle Go-specific type mismatches in the Java receiver. Go pointer types, function literals, and Go primitives sometimes arrive as nodes that don't satisfy Java's TypeTree/Expression/ Primitive constraints. Fix RPC parse input handling: check text-based inputs before file-path inputs so that the Java integration tests (which send inline source code) work correctly. Previously, SourcePath was tried as a file path even when Text was present. Add 8 integration tests covering: pointer receivers, func literals in return position, deref in short var decl, pointer slice types, maps with pointer values, func type declarations, nested func literals, and variadic pointer params.
1 parent eae89a1 commit 934ca1f

3 files changed

Lines changed: 299 additions & 13 deletions

File tree

rewrite-go/rewrite/cmd/rpc/main.go

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -308,11 +308,20 @@ func (s *server) handleParse(params json.RawMessage) (any, *rpcError) {
308308
var sourcePath string
309309
var source string
310310

311-
filePath := input.Path
312-
if filePath == "" {
313-
filePath = input.SourcePath
314-
}
315-
if filePath != "" {
311+
if input.Text != "" {
312+
source = input.Text
313+
sourcePath = input.SourcePath
314+
if sourcePath == "" {
315+
sourcePath = "<unknown>"
316+
}
317+
} else {
318+
filePath := input.Path
319+
if filePath == "" {
320+
filePath = input.SourcePath
321+
}
322+
if filePath == "" {
323+
continue
324+
}
316325
absPath := filePath
317326
data, err := os.ReadFile(absPath)
318327
if err != nil {
@@ -329,14 +338,6 @@ func (s *server) handleParse(params json.RawMessage) (any, *rpcError) {
329338
} else {
330339
sourcePath = absPath
331340
}
332-
} else if input.Text != "" {
333-
source = input.Text
334-
sourcePath = input.SourcePath
335-
if sourcePath == "" {
336-
sourcePath = "<unknown>"
337-
}
338-
} else {
339-
continue
340341
}
341342

342343
cu, parseErr := func() (cu *tree.CompilationUnit, err error) {

rewrite-go/src/integTest/java/org/openrewrite/golang/rpc/GolangParserIntegTest.java

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,160 @@ func f() {
393393
);
394394
}
395395

396+
@Test
397+
void methodWithPointerReceiver() {
398+
rewriteRun(
399+
go(
400+
"""
401+
package main
402+
403+
type Router struct {
404+
\troutes []string
405+
}
406+
407+
func (r *Router) AddRoute(path string) {
408+
\tr.routes = append(r.routes, path)
409+
}
410+
"""
411+
)
412+
);
413+
}
414+
415+
@Test
416+
void funcLiteralInReturn() {
417+
rewriteRun(
418+
go(
419+
"""
420+
package main
421+
422+
func makeAdder(x int) func(int) int {
423+
\treturn func(y int) int {
424+
\t\treturn x + y
425+
\t}
426+
}
427+
"""
428+
)
429+
);
430+
}
431+
432+
@Test
433+
void derefInShortVarDecl() {
434+
rewriteRun(
435+
go(
436+
"""
437+
package main
438+
439+
type Foo struct{ X int }
440+
441+
func copyFoo(r *Foo) Foo {
442+
\tc := *r
443+
\treturn c
444+
}
445+
"""
446+
)
447+
);
448+
}
449+
450+
@Test
451+
void pointerSliceType() {
452+
rewriteRun(
453+
go(
454+
"""
455+
package main
456+
457+
type Route struct{ Path string }
458+
459+
func collect(routes []*Route) []string {
460+
\tvar result []string
461+
\tfor _, r := range routes {
462+
\t\tresult = append(result, r.Path)
463+
\t}
464+
\treturn result
465+
}
466+
"""
467+
)
468+
);
469+
}
470+
471+
@Test
472+
void mapWithPointerValue() {
473+
rewriteRun(
474+
go(
475+
"""
476+
package main
477+
478+
type Entry struct{ Name string }
479+
480+
func lookup() map[string]*Entry {
481+
\treturn map[string]*Entry{}
482+
}
483+
"""
484+
)
485+
);
486+
}
487+
488+
@Test
489+
void funcTypeDeclaration() {
490+
rewriteRun(
491+
go(
492+
"""
493+
package main
494+
495+
type MiddlewareFunc func(int) int
496+
497+
func (mw MiddlewareFunc) Apply(x int) int {
498+
\treturn mw(x)
499+
}
500+
"""
501+
)
502+
);
503+
}
504+
505+
@Test
506+
void nestedFuncLiteral() {
507+
rewriteRun(
508+
go(
509+
"""
510+
package main
511+
512+
func middleware() func(int) int {
513+
\treturn func(next int) int {
514+
\t\treturn next + 1
515+
\t}
516+
}
517+
518+
func wrapTwice() func(int) int {
519+
\treturn func(x int) int {
520+
\t\tinner := func(y int) int {
521+
\t\t\treturn y * 2
522+
\t\t}
523+
\t\treturn inner(x)
524+
\t}
525+
}
526+
"""
527+
)
528+
);
529+
}
530+
531+
@Test
532+
void variadicWithPointer() {
533+
rewriteRun(
534+
go(
535+
"""
536+
package main
537+
538+
type Handler struct{ Name string }
539+
540+
func Register(handlers ...*Handler) {
541+
\tfor _, h := range handlers {
542+
\t\t_ = h.Name
543+
\t}
544+
}
545+
"""
546+
)
547+
);
548+
}
549+
396550
@Test
397551
void compositeAndKeyValue() {
398552
rewriteRun(

rewrite-go/src/main/java/org/openrewrite/golang/internal/rpc/GolangReceiver.java

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,137 @@ public J visitForEachControl(J.ForEachLoop.Control control, RpcReceiveQueue q) {
286286
.getPadding().withIterable(JRightPadded.build(iterable));
287287
}
288288

289+
/**
290+
* Sets a final field to a value whose type doesn't match the declared field type.
291+
* Go's AST uses nodes in positions where Java's model requires stricter types.
292+
*/
293+
private static void setFinalField(Object target, String fieldName, Object value) {
294+
try {
295+
java.lang.reflect.Field f = target.getClass().getDeclaredField(fieldName);
296+
f.setAccessible(true);
297+
f.set(target, value);
298+
} catch (Exception ignored) {
299+
}
300+
}
301+
302+
@SuppressWarnings({"unchecked", "rawtypes"})
303+
private J receiveAsJ(@Nullable Object before, RpcReceiveQueue q) {
304+
return (J) ((RpcReceiveQueue) q).receive(
305+
before,
306+
(java.util.function.UnaryOperator) t -> visitNonNull((J) t, q));
307+
}
308+
309+
@Override
310+
public J visitArrayType(J.ArrayType arrayType, RpcReceiveQueue q) {
311+
J elementType = receiveAsJ(arrayType.getElementType(), q);
312+
if (elementType instanceof TypeTree) {
313+
arrayType = arrayType.withElementType((TypeTree) elementType);
314+
} else if (elementType != null) {
315+
setFinalField(arrayType, "elementType", elementType);
316+
}
317+
return arrayType
318+
.withAnnotations(q.receiveList(arrayType.getAnnotations(), a -> (J.Annotation) visitNonNull(a, q)))
319+
.withDimension(q.receive(arrayType.getDimension(), d -> visitLeftPadded(d, q)))
320+
.withType(q.receive(arrayType.getType(), t -> visitType(t, q)));
321+
}
322+
323+
@Override
324+
public J visitLiteral(J.Literal literal, RpcReceiveQueue q) {
325+
literal = literal
326+
.withValue(q.receive(literal.getValue()))
327+
.withValueSource(q.receive(literal.getValueSource()))
328+
.withUnicodeEscapes(q.receiveList(literal.getUnicodeEscapes(), s -> {
329+
int valueSourceIndex = q.receive(s != null ? s.getValueSourceIndex() : 0);
330+
String codePoint = q.receive(s != null ? s.getCodePoint() : null);
331+
return new J.Literal.UnicodeEscape(valueSourceIndex, codePoint);
332+
}));
333+
@SuppressWarnings({"unchecked", "rawtypes"})
334+
JavaType type = (JavaType) ((RpcReceiveQueue) q).receive(
335+
(Object) literal.getType(),
336+
(java.util.function.UnaryOperator) t -> visitType((JavaType) t, q));
337+
if (type instanceof JavaType.Primitive) {
338+
literal = literal.withType((JavaType.Primitive) type);
339+
} else if (type != null) {
340+
setFinalField(literal, "type", type);
341+
}
342+
return literal;
343+
}
344+
345+
@Override
346+
public J visitMethodDeclaration(J.MethodDeclaration method, RpcReceiveQueue q) {
347+
if (method.getAnnotations().getName() == null) {
348+
method = method.getAnnotations().withName(
349+
new J.MethodDeclaration.IdentifierWithAnnotations(null, null));
350+
}
351+
method = method
352+
.withLeadingAnnotations(q.receiveList(method.getLeadingAnnotations(),
353+
a -> (J.Annotation) visitNonNull(a, q)))
354+
.withModifiers(q.receiveList(method.getModifiers(),
355+
m -> (J.Modifier) visitNonNull(m, q)))
356+
.getPadding().withTypeParameters(q.receive(
357+
method.getPadding().getTypeParameters(),
358+
tp -> (J.TypeParameters) visitNonNull(tp, q)));
359+
360+
J returnType = receiveAsJ(method.getReturnTypeExpression(), q);
361+
if (returnType instanceof TypeTree) {
362+
method = method.withReturnTypeExpression((TypeTree) returnType);
363+
} else if (returnType != null) {
364+
setFinalField(method, "returnTypeExpression", returnType);
365+
}
366+
367+
return method
368+
.getAnnotations().withName(method.getAnnotations().getName()
369+
.withAnnotations(q.receiveList(
370+
method.getAnnotations().getName().getAnnotations(),
371+
a -> (J.Annotation) visitNonNull(a, q))))
372+
.withName(q.receive(method.getName(),
373+
n -> (J.Identifier) visitNonNull(n, q)))
374+
.getPadding().withParameters(q.receive(
375+
method.getPadding().getParameters(), p -> visitContainer(p, q)))
376+
.getPadding().withThrows(q.receive(
377+
method.getPadding().getThrows(), t -> visitContainer(t, q)))
378+
.withBody(q.receive(method.getBody(),
379+
b -> (J.Block) visitNonNull(b, q)))
380+
.getPadding().withDefaultValue(q.receive(
381+
method.getPadding().getDefaultValue(),
382+
d -> visitLeftPadded(d, q)))
383+
.withMethodType(q.receive(method.getMethodType(),
384+
t -> (JavaType.Method) visitType(t, q)));
385+
}
386+
387+
@Override
388+
public J visitReturn(J.Return retrn, RpcReceiveQueue q) {
389+
J expr = receiveAsJ(retrn.getExpression(), q);
390+
if (expr instanceof Expression) {
391+
return retrn.withExpression((Expression) expr);
392+
} else if (expr != null) {
393+
setFinalField(retrn, "expression", expr);
394+
}
395+
return retrn;
396+
}
397+
398+
@Override
399+
public J visitVariableDeclarations(J.VariableDeclarations variableDecls, RpcReceiveQueue q) {
400+
variableDecls = variableDecls
401+
.withLeadingAnnotations(q.receiveList(variableDecls.getLeadingAnnotations(),
402+
a -> (J.Annotation) visitNonNull(a, q)))
403+
.withModifiers(q.receiveList(variableDecls.getModifiers(),
404+
m -> (J.Modifier) visitNonNull(m, q)));
405+
406+
J typeExpr = receiveAsJ(variableDecls.getTypeExpression(), q);
407+
if (typeExpr instanceof TypeTree) {
408+
variableDecls = variableDecls.withTypeExpression((TypeTree) typeExpr);
409+
} else if (typeExpr != null) {
410+
setFinalField(variableDecls, "typeExpression", typeExpr);
411+
}
412+
413+
variableDecls = variableDecls.withVarargs(
414+
q.receive(variableDecls.getVarargs(), v -> visitSpace(v, q)));
415+
return variableDecls.getPadding().withVariables(
416+
q.receiveList(variableDecls.getPadding().getVariables(),
417+
v -> visitRightPadded(v, q)));
418+
}
419+
289420
@Override
290421
public J visitImport(J.Import importStmt, RpcReceiveQueue q) {
291422
importStmt = importStmt.getPadding().withStatic(

0 commit comments

Comments
 (0)