Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion rewrite-go/rewrite/cmd/rpc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,17 @@ func (s *server) writeMessage(resp *jsonRPCResponse) error {
}

// safeHandleRequest wraps handleRequest with panic recovery.
func (s *server) safeHandleRequest(req *jsonRPCRequest) *jsonRPCResponse {
func (s *server) safeHandleRequest(req *jsonRPCRequest) (resp *jsonRPCResponse) {
defer func() {
if r := recover(); r != nil {
buf := make([]byte, 4096)
n := runtime.Stack(buf, false)
s.logger.Printf("PANIC in %s: %v\n%s", req.Method, r, buf[:n])
resp = &jsonRPCResponse{
JSONRPC: "2.0",
ID: req.ID,
Error: &rpcError{Code: -32603, Message: fmt.Sprintf("Internal error: %v", r)},
}
}
}()
return s.handleRequest(req)
Expand Down
91 changes: 67 additions & 24 deletions rewrite-go/rewrite/pkg/parser/go_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,59 +171,102 @@ func (ctx *parseContext) mapFile(file *ast.File, sourcePath string) *tree.Compil
}
}

// mapImports maps the import declarations in the file.
// mapImports maps all import declarations in the file into a single Container.
// Go allows multiple import blocks; subsequent blocks are tracked via ImportBlock markers.
func (ctx *parseContext) mapImports(file *ast.File) *tree.Container[*tree.Import] {
var importDecl *ast.GenDecl
// Collect all import GenDecls in order.
var importDecls []*ast.GenDecl
for _, decl := range file.Decls {
if gd, ok := decl.(*ast.GenDecl); ok && gd.Tok == token.IMPORT {
importDecl = gd
break
importDecls = append(importDecls, gd)
}
}
if importDecl == nil {
if len(importDecls) == 0 {
return nil
}

before := ctx.prefixAndSkip(importDecl.Pos(), len("import"))

var elements []tree.RightPadded[*tree.Import]

var containerMarkers tree.Markers
if importDecl.Lparen.IsValid() {
openParenPrefix := ctx.prefix(importDecl.Lparen)
ctx.skip(1) // skip "("
prevGrouped := false

// First import block: captured into Container.Before and Container.Markers
first := importDecls[0]
before := ctx.prefixAndSkip(first.Pos(), len("import"))

if first.Lparen.IsValid() {
prevGrouped = true
openParenPrefix := ctx.prefix(first.Lparen)
ctx.skip(1) // skip "("
containerMarkers = tree.Markers{
ID: uuid.New(),
Entries: []tree.Marker{
tree.GroupedImport{Ident: uuid.New(), Before: openParenPrefix},
},
}
}

for _, spec := range importDecl.Specs {
is := spec.(*ast.ImportSpec)
imp := ctx.mapImportSpec(is)
elements = append(elements, tree.RightPadded[*tree.Import]{Element: imp})
}
for _, spec := range first.Specs {
is := spec.(*ast.ImportSpec)
imp := ctx.mapImportSpec(is)
elements = append(elements, tree.RightPadded[*tree.Import]{Element: imp})
}

closeParen := ctx.prefix(importDecl.Rparen)
if first.Lparen.IsValid() {
closeParen := ctx.prefix(first.Rparen)
ctx.skip(1) // skip ")"

if len(elements) > 0 {
elements[len(elements)-1].After = closeParen
}
} else {
for _, spec := range importDecl.Specs {
is := spec.(*ast.ImportSpec)
imp := ctx.mapImportSpec(is)
elements = append(elements, tree.RightPadded[*tree.Import]{Element: imp})
}

// Subsequent import blocks: attach ImportBlock marker to first import of each
for _, importDecl := range importDecls[1:] {
blockBefore := ctx.prefixAndSkip(importDecl.Pos(), len("import"))
grouped := importDecl.Lparen.IsValid()
var groupedBefore tree.Space
if grouped {
groupedBefore = ctx.prefix(importDecl.Lparen)
ctx.skip(1) // skip "("
}

ctx.mapImportBlockSpecs(importDecl, &elements, tree.ImportBlock{
Ident: uuid.New(),
ClosePrevious: prevGrouped,
Before: blockBefore,
Grouped: grouped,
GroupedBefore: groupedBefore,
})

if grouped {
closeParen := ctx.prefix(importDecl.Rparen)
ctx.skip(1) // skip ")"
if len(elements) > 0 {
elements[len(elements)-1].After = closeParen
}
}
prevGrouped = grouped
}

container := tree.Container[*tree.Import]{Before: before, Elements: elements, Markers: containerMarkers}
return &container
}

// mapImportBlockSpecs maps the specs of a subsequent import block, attaching
// the ImportBlock marker to the first spec's Import node.
func (ctx *parseContext) mapImportBlockSpecs(decl *ast.GenDecl, elements *[]tree.RightPadded[*tree.Import], marker tree.ImportBlock) {
for j, spec := range decl.Specs {
is := spec.(*ast.ImportSpec)
imp := ctx.mapImportSpec(is)
if j == 0 {
imp.Markers = tree.Markers{
ID: uuid.New(),
Entries: []tree.Marker{marker},
}
}
*elements = append(*elements, tree.RightPadded[*tree.Import]{Element: imp})
}
}

// mapImportSpec maps a single import spec.
func (ctx *parseContext) mapImportSpec(spec *ast.ImportSpec) *tree.Import {
prefix := ctx.prefix(spec.Pos())
Expand Down Expand Up @@ -1761,8 +1804,8 @@ func (ctx *parseContext) mapArrayType(expr *ast.ArrayType) tree.Expression {
length = ctx.mapExpr(expr.Len)
}

closePrefix := ctx.prefix(expr.Lbrack + token.Pos(ctx.findNextFrom('[', ctx.file.Offset(expr.Lbrack)) - ctx.file.Offset(expr.Lbrack)))
// Find the `]`
var closePrefix tree.Space
rbrackOff := ctx.findNext(']')
if rbrackOff >= 0 {
closePrefix = ctx.prefix(ctx.file.Pos(rbrackOff))
Expand Down
18 changes: 16 additions & 2 deletions rewrite-go/rewrite/pkg/printer/go_printer.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,29 @@ func (p *GoPrinter) VisitCompilationUnit(cu *tree.CompilationUnit, param any) tr
out.Append("import")

grouped := tree.FindMarker[tree.GroupedImport](cu.Imports.Markers)
if grouped != nil {
isGrouped := grouped != nil
if isGrouped {
p.visitSpace(grouped.Before, out)
out.Append("(")
}
for _, rp := range cu.Imports.Elements {
block := tree.FindMarker[tree.ImportBlock](rp.Element.Markers)
if block != nil {
if block.ClosePrevious {
out.Append(")")
}
p.visitSpace(block.Before, out)
out.Append("import")
if block.Grouped {
p.visitSpace(block.GroupedBefore, out)
out.Append("(")
}
isGrouped = block.Grouped
}
p.Visit(rp.Element, out)
p.visitSpace(rp.After, out)
}
if grouped != nil {
if isGrouped {
out.Append(")")
}
}
Expand Down
22 changes: 22 additions & 0 deletions rewrite-go/rewrite/pkg/rpc/space_rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ func sendMarkerCodecFields(v any, q *SendQueue) {
// GroupedImport.rpcSend sends: id (UUID string), before whitespace (string)
q.GetAndSend(m, func(x any) any { return x.(tree.GroupedImport).Ident.String() }, nil)
q.GetAndSend(m, func(x any) any { return x.(tree.GroupedImport).Before.Whitespace }, nil)
case tree.ImportBlock:
// ImportBlock.rpcSend sends: id, closePrevious, before, grouped, groupedBefore
q.GetAndSend(m, func(x any) any { return x.(tree.ImportBlock).Ident.String() }, nil)
q.GetAndSend(m, func(x any) any { return x.(tree.ImportBlock).ClosePrevious }, nil)
q.GetAndSend(m, func(x any) any { return x.(tree.ImportBlock).Before.Whitespace }, nil)
q.GetAndSend(m, func(x any) any { return x.(tree.ImportBlock).Grouped }, nil)
q.GetAndSend(m, func(x any) any { return x.(tree.ImportBlock).GroupedBefore.Whitespace }, nil)
case tree.ShortVarDecl:
q.GetAndSend(m, func(x any) any { return x.(tree.ShortVarDecl).Ident.String() }, nil)
case tree.VarKeyword:
Expand Down Expand Up @@ -193,6 +200,21 @@ func receiveMarkersCodec(q *ReceiveQueue, before tree.Markers) tree.Markers {
ws := receiveScalar[string](q, m.Before.Whitespace)
m.Before = tree.Space{Whitespace: ws}
return m
case tree.ImportBlock:
// ImportBlock.rpcReceive: id, closePrevious, before, grouped, groupedBefore
idStr := receiveScalar[string](q, m.Ident.String())
if idStr != "" {
if parsed, err := uuid.Parse(idStr); err == nil {
m.Ident = parsed
}
}
m.ClosePrevious = receiveScalar[bool](q, m.ClosePrevious)
ws := receiveScalar[string](q, m.Before.Whitespace)
m.Before = tree.Space{Whitespace: ws}
m.Grouped = receiveScalar[bool](q, m.Grouped)
gbWs := receiveScalar[string](q, m.GroupedBefore.Whitespace)
m.GroupedBefore = tree.Space{Whitespace: gbWs}
return m
case tree.ShortVarDecl:
idStr := receiveScalar[string](q, m.Ident.String())
if idStr != "" {
Expand Down
3 changes: 3 additions & 0 deletions rewrite-go/rewrite/pkg/rpc/value_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ func init() {

// Go-specific marker valueType registrations (for send-side type resolution)
RegisterValueType(reflect.TypeOf(tree.GroupedImport{}), "org.openrewrite.golang.marker.GroupedImport")
RegisterValueType(reflect.TypeOf(tree.ImportBlock{}), "org.openrewrite.golang.marker.ImportBlock")
RegisterValueType(reflect.TypeOf(tree.ShortVarDecl{}), "org.openrewrite.golang.marker.ShortVarDecl")
RegisterValueType(reflect.TypeOf(tree.VarKeyword{}), "org.openrewrite.golang.marker.VarKeyword")
RegisterValueType(reflect.TypeOf(tree.ConstDecl{}), "org.openrewrite.golang.marker.ConstDecl")
Expand Down Expand Up @@ -175,6 +176,8 @@ func init() {
RegisterFactory("org.openrewrite.marker.SearchResult", func() any { return tree.SearchResult{} })
// GroupedImport: IS an RpcCodec, sends 2 sub-fields (id, before whitespace)
RegisterFactory("org.openrewrite.golang.marker.GroupedImport", func() any { return tree.GroupedImport{} })
// ImportBlock: IS an RpcCodec, sends 5 sub-fields (id, closePrevious, before, grouped, groupedBefore)
RegisterFactory("org.openrewrite.golang.marker.ImportBlock", func() any { return tree.ImportBlock{} })
// Go-specific markers: all are RpcCodec
RegisterFactory("org.openrewrite.golang.marker.ShortVarDecl", func() any { return tree.ShortVarDecl{} })
RegisterFactory("org.openrewrite.golang.marker.VarKeyword", func() any { return tree.VarKeyword{} })
Expand Down
25 changes: 25 additions & 0 deletions rewrite-go/rewrite/pkg/tree/go.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,18 @@ func (n *CompilationUnit) WithStatements(statements []RightPadded[Statement]) *C
return &c
}

func (n *CompilationUnit) WithPackageDecl(pkg *RightPadded[*Identifier]) *CompilationUnit {
c := *n
c.PackageDecl = pkg
return &c
}

func (n *CompilationUnit) WithImports(imports *Container[*Import]) *CompilationUnit {
c := *n
c.Imports = imports
return &c
}

func (n *CompilationUnit) WithEOF(eof Space) *CompilationUnit {
c := *n
c.EOF = eof
Expand Down Expand Up @@ -462,6 +474,19 @@ type GroupedImport struct {

func (g GroupedImport) ID() uuid.UUID { return g.Ident }

// ImportBlock is a marker on the first Import of a subsequent import block
// (2nd, 3rd, etc.) in files with multiple import declarations. It carries
// the information needed to print the block boundary.
type ImportBlock struct {
Ident uuid.UUID
ClosePrevious bool // true if the previous block was grouped (need to print ")")
Before Space // space before the "import" keyword
Grouped bool // true if this block uses import (...)
GroupedBefore Space // space between "import" and "(" (only if Grouped)
}

func (b ImportBlock) ID() uuid.UUID { return b.Ident }

// MultiAssignment represents a multi-value assignment: `x, y = 1, 2` or `x, y := f()`.
type MultiAssignment struct {
ID uuid.UUID
Expand Down
20 changes: 20 additions & 0 deletions rewrite-go/rewrite/pkg/visitor/go_visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,19 @@ var _ VisitorI = (*GoVisitor)(nil)
func (v *GoVisitor) VisitCompilationUnit(cu *tree.CompilationUnit, p any) tree.J {
cu = cu.WithPrefix(v.self().VisitSpace(cu.Prefix, p))
cu = cu.WithMarkers(v.visitMarkers(cu.Markers, p))
if cu.PackageDecl != nil {
pkg := *cu.PackageDecl
pkg.Element = visitAndCast[*tree.Identifier](v, pkg.Element, p)
pkg.After = v.self().VisitSpace(pkg.After, p)
cu = cu.WithPackageDecl(&pkg)
}
if cu.Imports != nil {
imports := *cu.Imports
imports.Before = v.self().VisitSpace(imports.Before, p)
imports.Markers = v.visitMarkers(imports.Markers, p)
imports.Elements = visitRightPaddedList(v, imports.Elements, p)
cu = cu.WithImports(&imports)
}
cu = cu.WithStatements(visitRightPaddedList(v, cu.Statements, p))
cu = cu.WithEOF(v.self().VisitSpace(cu.EOF, p))
return cu
Expand Down Expand Up @@ -552,11 +565,18 @@ func (v *GoVisitor) visitMarkers(markers tree.Markers, p any) tree.Markers {

func visitAndCast[T tree.Tree](v *GoVisitor, t tree.Tree, p any) T {
result := v.self().Visit(t, p)
if result == nil {
var zero T
return zero
}
return result.(T)
}

func visitExpression(v *GoVisitor, expr tree.Expression, p any) tree.Expression {
result := v.self().Visit(expr, p)
if result == nil {
return nil
}
return result.(tree.Expression)
}

Expand Down
32 changes: 32 additions & 0 deletions rewrite-go/rewrite/test/import_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,35 @@ func TestParseGroupedImports(t *testing.T) {
}
`))
}

func TestParseMultipleImportBlocks(t *testing.T) {
NewRecipeSpec().RewriteRun(t,
Golang(`
package main

import "fmt"
import "os"

func hello() {
}
`))
}

func TestParseMultipleGroupedImportBlocks(t *testing.T) {
NewRecipeSpec().RewriteRun(t,
Golang(`
package main

import (
"fmt"
)

import (
"os"
"strings"
)

func hello() {
}
`))
}
Loading