diff --git a/rewrite-go/.gitignore b/rewrite-go/.gitignore index cc77b22f021..f50e74755d3 100644 --- a/rewrite-go/.gitignore +++ b/rewrite-go/.gitignore @@ -1 +1 @@ -rewrite/test-classpath.txt +test-classpath.txt diff --git a/rewrite-go/Makefile b/rewrite-go/Makefile new file mode 100644 index 00000000000..00a25569b87 --- /dev/null +++ b/rewrite-go/Makefile @@ -0,0 +1,16 @@ +.PHONY: test parity + +# Default test run — what CI executes. Build-tag-gated tests stay out. +test: + go test ./... + +# Printer fidelity audit. Loads every fixture under +# test/printer-corpus/, parses + prints, and asserts byte-equality. +# Gated behind the `parityaudit` build tag so it never runs in CI; +# devs invoke it manually when investigating printer regressions. +# +# Trade-off (per the eng review's P2 decision): corpus regressions ship +# undetected by automation in exchange for fast iteration and simple +# test infra. Re-evaluate if a corpus-detectable bug ever lands. +parity: + go test -tags parityaudit ./pkg/printer/... diff --git a/rewrite-go/PLAN.md b/rewrite-go/PLAN.md deleted file mode 100644 index 0261f53de39..00000000000 --- a/rewrite-go/PLAN.md +++ /dev/null @@ -1,63 +0,0 @@ -# Go Language Support — Implementation Plan - -## Current State - -The Go module supports **parse + print + GetObject + InstallRecipes + Reset** over RPC. All 15 integration tests pass (helloWorld, structs, slices, interfaces, for/range loops, switch, channels, maps, etc.). The recipe bundle resolver infrastructure (`GolangRecipeBundleResolver`/`GolangRecipeBundleReader`) is wired up on the Java side with stub handling on the Go side. - -## What's Left - -### 1. Go Server: GetMarketplace Handler -The Go RPC server needs to handle `GetMarketplace` requests so Java can discover what recipes the Go process has available. Python/JS servers return a `GetMarketplaceResponse` containing recipe descriptors organized by category. - -### 2. Go Server: PrepareRecipe Handler -Handle `PrepareRecipe` requests to instantiate a recipe with options. This is how the Java host asks the Go process to create a configured recipe instance ready for execution. - -### 3. Go Server: Visit Handler -Handle `Visit` requests — the core recipe execution path. Java sends a tree ID and a prepared recipe ID; the Go server applies the visitor and returns the modified tree. This requires bidirectional RPC (Go calls back to Java's `GetObject` to fetch the tree, applies the visitor, then Java calls `GetObject` to get the result). - -### 4. Go Server: Generate Handler -Handle `Generate` requests for recipes that create new source files rather than modifying existing ones. - -### 5. Go Server: ParseProject Handler -Handle `ParseProject` for bulk project parsing. Java sends a project directory path; Go discovers and parses all `.go` files, resolving the module structure (`go.mod`). Python/JS have 3 overloads supporting exclusion patterns and relative path configuration. - -### 6. Go Server: TraceGetObject Handler -Handle `TraceGetObject` to toggle verbose RPC message tracing for debugging. - -### 7. Go Server: InstallRecipes — Actual Implementation -The current `InstallRecipes` handler is a stub. It needs to: -- For local paths: discover and load Go recipe plugins from a local module -- For package specs: fetch a Go module from a Git repository, build it, and load its recipes -- Return accurate `recipesInstalled` count and resolved `version` - -### 8. Go Recipe Framework -Build the Go-side recipe/visitor infrastructure so Go recipes can be authored: -- Recipe interface/struct (name, description, options, visitor factory) -- Visitor base types for Go AST traversal -- Recipe registration/discovery mechanism -- Marketplace integration (expose registered recipes via GetMarketplace) - -### 9. Java Client: parseProject() -Add `parseProject()` methods to `GoRewriteRpc` (following Python's 3 overloads pattern) that send `ParseProject` RPC requests and return parsed source files with appropriate markers. - -### 10. Java Client: Builder Enhancements -Add builder options to `GoRewriteRpc.Builder`: -- `environment(Map)` — environment variables for the Go subprocess -- `workingDirectory(Path)` — working directory for the Go subprocess -- `metricsCsv(Path)` — metrics output -- `recipeInstallDir(Path)` — where installed recipe modules live - -### 11. Java Client: resetCurrent() -Add static `resetCurrent()` convenience method (mirrors Python/JS pattern). - -## Priority Order - -The most impactful order for enabling end-to-end recipe execution: - -1. **Go Recipe Framework** (#8) — foundation for everything else -2. **GetMarketplace** (#1) + **PrepareRecipe** (#2) — recipe discovery -3. **Visit** (#3) — recipe execution -4. **InstallRecipes actual impl** (#7) — load recipes from Git -5. **parseProject** (#5, #9) — project-level tooling -6. **Generate** (#4) — new file generation -7. **Builder/client polish** (#10, #11, #6) — ergonomics diff --git a/rewrite-go/build.gradle.kts b/rewrite-go/build.gradle.kts index d3d3830e6ce..76af81b1e39 100644 --- a/rewrite-go/build.gradle.kts +++ b/rewrite-go/build.gradle.kts @@ -30,6 +30,12 @@ tasks.withType().configureEach { exclude("**/G.java") } +// Test fixtures for the GoMod conformance corpus are .gomod / .gosum / .json +// files; none of these formats accept a leading license header. +configure { + excludePatterns.addAll(listOf("**/*.gomod", "**/*.gosum", "**/*.json")) +} + val goBuild = tasks.register("goBuild") { workingDir = projectDir // Use relative path to avoid absolute paths in cache key (Exec args are cache inputs) diff --git a/rewrite-go/cmd/rpc/main.go b/rewrite-go/cmd/rpc/main.go index a3e2ad6aa2c..038cd3550ff 100644 --- a/rewrite-go/cmd/rpc/main.go +++ b/rewrite-go/cmd/rpc/main.go @@ -20,7 +20,9 @@ package main import ( "bufio" + "encoding/csv" "encoding/json" + "flag" "fmt" "io" "log" @@ -30,6 +32,8 @@ import ( "runtime" "strconv" "strings" + "sync" + "time" "github.com/google/uuid" @@ -40,6 +44,7 @@ import ( "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe/installer" "github.com/openrewrite/rewrite/rewrite-go/pkg/rpc" "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" ) // jsonRPCRequest represents an incoming JSON-RPC 2.0 message (request or response). @@ -82,10 +87,31 @@ type server struct { // Prepared recipe instances keyed by unique ID preparedRecipes map[string]recipe.Recipe + // Per-prepared-recipe accumulator for ScanningRecipe. Lazily created on + // the first scan Visit call. Lifetime = prepared recipe instance; freed + // only on Reset (per the engineering review's D2 decision). + preparedAccumulators map[string]any + + // ExecutionContext fetched from Java on first visit, cached for the + // lifetime of the prepared recipe. Keyed by the ExecutionContext's + // remote object id (the `p` value in Visit/Generate/BatchVisit). + preparedContexts map[string]*recipe.ExecutionContext + // Tracing toggles for GetObject traceReceive bool traceSend bool + // Server configuration from CLI flags (see serverConfig) + metricsCsv string + dataTablesCsvDir string + + // Per-RPC metrics writer. Lazily opened in newServer when metricsCsv + // is set. Writes are guarded by metricsMu so concurrent dispatch + // (e.g. parallel BatchVisit handlers) can't interleave rows. + metricsFile *os.File + metricsWriter *csv.Writer + metricsMu sync.Mutex + reader *bufio.Reader writer io.Writer logger *log.Logger @@ -93,21 +119,52 @@ type server struct { installer *installer.Installer } -func newServer() *server { - logFile, err := os.CreateTemp("", "go-rpc-*.log") - if err != nil { - logFile = os.Stderr +// serverConfig holds CLI-driven configuration applied to the server at startup. +type serverConfig struct { + logFile string + traceRpcMessages bool + metricsCsv string + recipeInstallDir string + dataTablesCsvDir string +} + +func newServer(cfg serverConfig) *server { + var logOut io.Writer + if cfg.logFile != "" { + f, err := os.OpenFile(cfg.logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + logOut = os.Stderr + } else { + logOut = f + } + } else { + f, err := os.CreateTemp("", "go-rpc-*.log") + if err != nil { + logOut = os.Stderr + } else { + logOut = f + } } + // Register the empty-body codec for ExecutionContext under the FQN + // the Java side uses. Matches JS execution.ts:25-35: rpcSend writes + // nothing, rpcReceive returns a fresh ctx. + rpc.RegisterFactory("org.openrewrite.InMemoryExecutionContext", func() any { + return recipe.NewExecutionContext() + }) + reg := recipe.NewRegistry() reg.Activate(golang.Activate) - logger := log.New(logFile, "", log.LstdFlags) + logger := log.New(logOut, "", log.LstdFlags) inst := installer.NewInstaller() inst.Logger = logger.Printf + if cfg.recipeInstallDir != "" { + inst.WorkspaceDir = cfg.recipeInstallDir + } - return &server{ + s := &server{ localObjects: make(map[string]any), remoteObjects: make(map[string]any), localRefs: make(map[uintptr]int), @@ -115,17 +172,94 @@ func newServer() *server { reverseRemoteObjects: make(map[string]any), reverseRemoteRefs: make(map[int]any), preparedRecipes: make(map[string]recipe.Recipe), + preparedAccumulators: make(map[string]any), + preparedContexts: make(map[string]*recipe.ExecutionContext), batchSize: 1000, + traceReceive: cfg.traceRpcMessages, + traceSend: cfg.traceRpcMessages, + metricsCsv: cfg.metricsCsv, + dataTablesCsvDir: cfg.dataTablesCsvDir, reader: bufio.NewReader(os.Stdin), writer: os.Stdout, logger: logger, registry: reg, installer: inst, } + + if cfg.metricsCsv != "" { + f, err := os.OpenFile(cfg.metricsCsv, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + logger.Printf("metrics-csv: cannot open %q: %v — metrics disabled", cfg.metricsCsv, err) + } else { + s.metricsFile = f + s.metricsWriter = csv.NewWriter(f) + if err := s.metricsWriter.Write([]string{"timestamp", "method", "duration_ms", "error"}); err != nil { + logger.Printf("metrics-csv: cannot write header: %v", err) + } + s.metricsWriter.Flush() + } + } + + return s +} + +// closeMetrics flushes and closes the metrics CSV writer if open. Idempotent. +func (s *server) closeMetrics() { + s.metricsMu.Lock() + defer s.metricsMu.Unlock() + if s.metricsWriter != nil { + s.metricsWriter.Flush() + s.metricsWriter = nil + } + if s.metricsFile != nil { + if err := s.metricsFile.Close(); err != nil { + s.logger.Printf("metrics-csv: close failed: %v", err) + } + s.metricsFile = nil + } +} + +// recordMetric appends one row to the metrics CSV. Safe to call from +// concurrent goroutines; rows are written under metricsMu so they don't +// interleave. Errors are logged and dropped — metrics emission must never +// take down a request. +func (s *server) recordMetric(method string, duration time.Duration, rpcErr *rpcError) { + s.metricsMu.Lock() + defer s.metricsMu.Unlock() + if s.metricsWriter == nil { + return + } + errMsg := "" + if rpcErr != nil { + errMsg = rpcErr.Message + } + row := []string{ + time.Now().UTC().Format(time.RFC3339Nano), + method, + strconv.FormatInt(duration.Milliseconds(), 10), + errMsg, + } + if err := s.metricsWriter.Write(row); err != nil { + s.logger.Printf("metrics-csv: write row failed: %v", err) + return + } + s.metricsWriter.Flush() +} + +func parseFlags() serverConfig { + var cfg serverConfig + flag.StringVar(&cfg.logFile, "log-file", "", "path to write server log; empty = OS temp file") + flag.BoolVar(&cfg.traceRpcMessages, "trace-rpc-messages", false, "log every GetObject batch send/receive") + flag.StringVar(&cfg.metricsCsv, "metrics-csv", "", "path to write per-RPC metrics as CSV") + flag.StringVar(&cfg.recipeInstallDir, "recipe-install-dir", "", "directory used as the installer workspace; defaults to ~/.rewrite/go-recipes") + flag.StringVar(&cfg.dataTablesCsvDir, "data-tables-csv-dir", "", "directory where DataTable rows are written as CSV; empty = in-memory only") + flag.Parse() + return cfg } func main() { - s := newServer() + cfg := parseFlags() + s := newServer(cfg) s.logger.Println("Go RPC server starting...") for { @@ -146,6 +280,7 @@ func main() { } s.logger.Println("Go RPC server shutting down...") + s.closeMetrics() } // readMessage reads a Content-Length framed JSON-RPC message from stdin. @@ -194,8 +329,11 @@ func (s *server) writeMessage(resp *jsonRPCResponse) error { return err } -// safeHandleRequest wraps handleRequest with panic recovery. +// safeHandleRequest wraps handleRequest with panic recovery and per-RPC +// metrics capture. The metric row is written exactly once per request, +// after the response is determined (panic-recovered or not). func (s *server) safeHandleRequest(req *jsonRPCRequest) (resp *jsonRPCResponse) { + start := time.Now() defer func() { if r := recover(); r != nil { buf := make([]byte, 4096) @@ -207,6 +345,7 @@ func (s *server) safeHandleRequest(req *jsonRPCRequest) (resp *jsonRPCResponse) Error: &rpcError{Code: -32603, Message: fmt.Sprintf("Internal error: %v", r)}, } } + s.recordMetric(req.Method, time.Since(start), resp.Error) }() return s.handleRequest(req) } @@ -237,6 +376,8 @@ func (s *server) handleRequest(req *jsonRPCRequest) *jsonRPCResponse { result, rpcErr = s.handlePrepareRecipe(req.Params) case "Visit": result, rpcErr = s.handleVisit(req.Params) + case "BatchVisit": + result, rpcErr = s.handleBatchVisit(req.Params) case "Generate": result, rpcErr = s.handleGenerate(req.Params) case "TraceGetObject": @@ -264,9 +405,19 @@ func (s *server) handleGetLanguages() []string { } // parseRequest is the parameter type for Parse. +// +// `Module` and `GoModContent` are optional and let callers establish a +// project context for the batch: when present, the server parses the +// go.mod content into a GoResolutionResult, builds a ProjectImporter +// with the module's `require` entries plus all sibling .go inputs as +// known sources, and uses that importer for type attribution. Without +// them, the server falls back to per-file parsing with the stdlib +// importer (today's behavior). type parseRequest struct { - Inputs []parseInput `json:"inputs"` - RelativeTo *string `json:"relativeTo"` + Inputs []parseInput `json:"inputs"` + RelativeTo *string `json:"relativeTo"` + Module string `json:"module,omitempty"` + GoModContent string `json:"goModContent,omitempty"` } // parseInput can be a path-based or text-based input. @@ -297,27 +448,37 @@ func (p *parseInput) UnmarshalJSON(data []byte) error { } // handleParse parses Go source files and returns their IDs. +// +// When req.Module + req.GoModContent are set, the handler builds a +// ProjectImporter from the parsed go.mod (requires) plus every .go input +// in the batch (siblings) and uses it for type attribution. Inputs in the +// same package directory are parsed together so cross-file references +// resolve. Without module context the handler parses each input in +// isolation with the stdlib-only importer. func (s *server) handleParse(params json.RawMessage) (any, *rpcError) { var req parseRequest if err := json.Unmarshal(params, &req); err != nil { return nil, &rpcError{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)} } - p := goparser.NewGoParser() - ids := make([]string, 0, len(req.Inputs)) - for _, input := range req.Inputs { - var sourcePath string - var source string - + // Resolve every input to (sourcePath, source) before deciding how to + // parse. This lets us build the ProjectImporter with knowledge of all + // siblings and lets us group by package directory. + type resolved struct { + idx int + sourcePath string + source string + } + resolvedInputs := make([]resolved, 0, len(req.Inputs)) + for i, input := range req.Inputs { + var sourcePath, source string if input.Text != "" { - // Text-based input (inline source from tests or recipe framework) source = input.Text sourcePath = input.SourcePath if sourcePath == "" { sourcePath = "" } } else { - // File-path-based input (mod build sends file paths) filePath := input.Path if filePath == "" { filePath = input.SourcePath @@ -332,8 +493,7 @@ func (s *server) handleParse(params json.RawMessage) (any, *rpcError) { } source = string(data) if req.RelativeTo != nil && *req.RelativeTo != "" { - rel, err := filepath.Rel(*req.RelativeTo, absPath) - if err == nil { + if rel, err := filepath.Rel(*req.RelativeTo, absPath); err == nil { sourcePath = rel } else { sourcePath = absPath @@ -342,25 +502,113 @@ func (s *server) handleParse(params json.RawMessage) (any, *rpcError) { sourcePath = absPath } } + resolvedInputs = append(resolvedInputs, resolved{idx: i, sourcePath: sourcePath, source: source}) + } + + p := goparser.NewGoParser() + + // Build a ProjectImporter when module context is provided. Recognize + // the requires from go.mod content; register every .go input as a + // sibling source so intra-project imports type-check against real + // sources, and third-party imports declared in `require` resolve to + // stub *types.Package objects. When req.RelativeTo is set, the vendor + // walker scans `/vendor//` for real + // resolution — replace directives in the go.mod redirect that walk. + if req.Module != "" { + pi := goparser.NewProjectImporter(req.Module, nil) + if req.RelativeTo != nil && *req.RelativeTo != "" { + pi.SetProjectRoot(*req.RelativeTo) + } + if req.GoModContent != "" { + if mrr, err := goparser.ParseGoMod("go.mod", req.GoModContent); err == nil && mrr != nil { + for _, r := range mrr.Requires { + pi.AddRequire(r.ModulePath) + } + for _, r := range mrr.Replaces { + pi.AddReplace(r.OldPath, r.NewPath, r.NewVersion) + } + } + } + for _, r := range resolvedInputs { + if strings.HasSuffix(r.sourcePath, ".go") { + pi.AddSource(r.sourcePath, r.source) + } + } + p.Importer = pi + } + + // Group .go inputs by package directory. Each group parses together + // via parser.ParsePackage so file-A-references-file-B resolves. + type fileEntry struct { + idx int + input goparser.FileInput + } + groups := map[string][]fileEntry{} + for _, r := range resolvedInputs { + if !strings.HasSuffix(r.sourcePath, ".go") { + continue + } + dir := filepath.Dir(r.sourcePath) + groups[dir] = append(groups[dir], fileEntry{idx: r.idx, input: goparser.FileInput{Path: r.sourcePath, Content: r.source}}) + } - cu, parseErr := func() (cu *tree.CompilationUnit, err error) { + // Parse each group; collect CUs by their original input index so the + // returned IDs land in input-order. Pre-filter against the parser's + // BuildContext so the post-parse `cus` slice aligns 1:1 with the + // `included` subset of the group. + cuByIdx := make(map[int]*tree.CompilationUnit, len(resolvedInputs)) + parseErrByIdx := make(map[int]error) + for _, group := range groups { + included := make([]fileEntry, 0, len(group)) + files := make([]goparser.FileInput, 0, len(group)) + for _, g := range group { + if !goparser.MatchBuildContext(p.BuildContext, filepath.Base(g.input.Path), g.input.Content) { + continue + } + included = append(included, g) + files = append(files, g.input) + } + if len(files) == 0 { + continue + } + cus, err := func() (out []*tree.CompilationUnit, err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("panic: %v", r) } }() - return p.Parse(sourcePath, source) + return p.ParsePackage(files) }() - if parseErr != nil { - s.logger.Printf("Parse error for %s: %v", sourcePath, parseErr) - pe := tree.NewParseError(sourcePath, source, parseErr) - id := pe.Ident.String() - s.localObjects[id] = pe + if err != nil { + // Whole-package parse failure — record per-file ParseErrors + // for every file the build context didn't exclude. + for _, g := range included { + parseErrByIdx[g.idx] = err + } + continue + } + for i, cu := range cus { + cuByIdx[included[i].idx] = cu + } + } + + // Emit results in input order. + ids := make([]string, 0, len(req.Inputs)) + for _, r := range resolvedInputs { + if cu, ok := cuByIdx[r.idx]; ok && cu != nil { + id := cu.ID.String() + s.localObjects[id] = cu ids = append(ids, id) continue } - id := cu.ID.String() - s.localObjects[id] = cu + err := parseErrByIdx[r.idx] + if err == nil { + err = fmt.Errorf("no compilation unit produced") + } + s.logger.Printf("Parse error for %s: %v", r.sourcePath, err) + pe := tree.NewParseError(r.sourcePath, r.source, err) + id := pe.Ident.String() + s.localObjects[id] = pe ids = append(ids, id) } @@ -401,7 +649,9 @@ func (s *server) handleGetObject(params json.RawMessage) (any, *rpcError) { sender := rpc.NewGoSender() q.Send(obj, before, func(v any) { - sender.Visit(v, q) + if t, ok := v.(tree.Tree); ok { + sender.Visit(t, q) + } }) q.Put(rpc.RpcObjectData{State: rpc.EndOfObject}) q.Flush() @@ -536,7 +786,15 @@ func (s *server) getObjectFromJava(id string, sourceFileType string) any { receiver := rpc.NewGoReceiver() obj := q.Receive(before, func(v any) any { - return receiver.Visit(v, q) + // ExecutionContext uses an empty-body codec (matches JS execution.ts): + // the type tag arrives via the queue envelope; no field messages follow. + if ctx, ok := v.(*recipe.ExecutionContext); ok { + return ctx + } + if t, ok := v.(tree.Tree); ok { + return receiver.Visit(t, q) + } + return v }) // Consume the END_OF_OBJECT sentinel if present @@ -625,9 +883,95 @@ func (s *server) handleReset() bool { s.reverseRemoteObjects = make(map[string]any) s.reverseRemoteRefs = make(map[int]any) s.preparedRecipes = make(map[string]recipe.Recipe) + s.preparedAccumulators = make(map[string]any) + s.preparedContexts = make(map[string]*recipe.ExecutionContext) return true } +// resolveExecutionContext returns a usable ExecutionContext for a Visit / +// Generate / BatchVisit call. If pid is nil or empty, a fresh local ctx is +// returned. Otherwise the ctx is fetched from Java once (via the empty-body +// codec) and cached under the pid for subsequent calls in the same recipe run. +// +// When --data-tables-csv-dir is set, a CsvDataTableStore is installed into +// the ctx so any recipe that emits data-table rows writes them to that +// directory. Otherwise an InMemoryDataTableStore is created lazily on first +// InsertRow. +func (s *server) resolveExecutionContext(pid *string) *recipe.ExecutionContext { + var ctx *recipe.ExecutionContext + if pid == nil || *pid == "" { + ctx = recipe.NewExecutionContext() + } else if cached, ok := s.preparedContexts[*pid]; ok { + return cached + } else { + obj := s.getObjectFromJava(*pid, "org.openrewrite.InMemoryExecutionContext") + var ok bool + ctx, ok = obj.(*recipe.ExecutionContext) + if !ok || ctx == nil { + ctx = recipe.NewExecutionContext() + } + s.preparedContexts[*pid] = ctx + } + s.installDataTableStore(ctx) + return ctx +} + +// getOrCreateAccumulator returns the accumulator for a ScanningRecipe, +// creating it lazily on the first call. The accumulator's lifetime is +// tied to the prepared recipe instance — freed only on Reset. +func (s *server) getOrCreateAccumulator(recipeID string, sr recipe.ScanningRecipe, ctx *recipe.ExecutionContext) any { + if acc, ok := s.preparedAccumulators[recipeID]; ok { + return acc + } + acc := sr.InitialValue(ctx) + s.preparedAccumulators[recipeID] = acc + return acc +} + +// seedCursor reconstructs the cursor chain from RPC cursor IDs (root +// first) and seeds it onto the visitor via SetCursor. Visitors that +// don't expose SetCursor (e.g., aren't GoVisitor-derived) silently +// skip. Each cursor ID points to a tree node Java has; fetched via +// the existing reverse-RPC GetObject path. Mirrors how Java's RpcRecipe +// seeds the JavaVisitor cursor before traversal. +func (s *server) seedCursor(v recipe.TreeVisitor, ids []string) { + type cursorAware interface { + SetCursor(c *visitor.Cursor) + } + ca, ok := v.(cursorAware) + if !ok || len(ids) == 0 { + return + } + values := make([]tree.Tree, 0, len(ids)) + for _, id := range ids { + obj := s.getObjectFromJava(id, "") + if t, ok := obj.(tree.Tree); ok { + values = append(values, t) + } + } + if len(values) > 0 { + ca.SetCursor(visitor.BuildChain(values)) + } +} + +// installDataTableStore puts a DataTableStore into the ctx if one isn't +// already present. Choice driven by --data-tables-csv-dir. +func (s *server) installDataTableStore(ctx *recipe.ExecutionContext) { + if _, ok := ctx.GetMessage(recipe.DataTableStoreKey); ok { + return + } + if s.dataTablesCsvDir != "" { + store, err := recipe.NewCsvDataTableStore(s.dataTablesCsvDir) + if err != nil { + s.logger.Printf("CsvDataTableStore unavailable, falling back to in-memory: %v", err) + } else { + ctx.PutMessage(recipe.DataTableStoreKey, store) + return + } + } + ctx.PutMessage(recipe.DataTableStoreKey, recipe.NewInMemoryDataTableStore()) +} + // marketplaceRow matches Java's GetMarketplaceResponse.Row. type marketplaceRow struct { Descriptor marketplaceDescriptor `json:"descriptor"` @@ -680,25 +1024,6 @@ func (s *server) handleGetMarketplace(params json.RawMessage) (any, *rpcError) { for _, reg := range s.registry.AllRegistrations() { desc := reg.Descriptor - var options []marketplaceOption - for _, opt := range desc.Options { - var example *string - if opt.Example != "" { - example = &opt.Example - } - options = append(options, marketplaceOption{ - Name: opt.Name, - DisplayName: opt.DisplayName, - Description: opt.Description, - Example: example, - Required: opt.Required, - Type: "String", - }) - } - if options == nil { - options = []marketplaceOption{} - } - var categoryPath []marketplaceCategory for _, cat := range reg.Categories { categoryPath = append(categoryPath, marketplaceCategory{ @@ -712,20 +1037,7 @@ func (s *server) handleGetMarketplace(params json.RawMessage) (any, *rpcError) { } rows = append(rows, marketplaceRow{ - Descriptor: marketplaceDescriptor{ - Name: desc.Name, - DisplayName: desc.DisplayName, - InstanceName: desc.DisplayName, - Description: desc.Description, - Tags: nonNil(desc.Tags), - Options: options, - Preconditions: []marketplaceDescriptor{}, - RecipeList: []marketplaceDescriptor{}, - DataTables: []any{}, - Maintainers: []any{}, - Contributors: []any{}, - Examples: []any{}, - }, + Descriptor: marketplaceDescriptorFromRecipe(desc), CategoryPaths: [][]marketplaceCategory{categoryPath}, }) } @@ -735,6 +1047,149 @@ func (s *server) handleGetMarketplace(params json.RawMessage) (any, *rpcError) { return rows, nil } +// marketplaceDescriptorFromRecipe converts a recipe.RecipeDescriptor to the +// wire-format marketplaceDescriptor expected by Java. Recursive fields +// (recipeList, preconditions) are populated. Cycle protection is handled +// upstream by recipe.Describe. +func marketplaceDescriptorFromRecipe(desc recipe.RecipeDescriptor) marketplaceDescriptor { + options := make([]marketplaceOption, 0, len(desc.Options)) + for _, opt := range desc.Options { + var example *string + if opt.Example != "" { + example = &opt.Example + } + valid := make([]any, 0, len(opt.Valid)) + for _, v := range opt.Valid { + valid = append(valid, v) + } + options = append(options, marketplaceOption{ + Name: opt.Name, + DisplayName: opt.DisplayName, + Description: opt.Description, + Example: example, + Required: opt.Required, + Type: opt.TypeName(), + Value: opt.Value, + Valid: valid, + }) + } + + recipeList := make([]marketplaceDescriptor, 0, len(desc.RecipeList)) + for _, sub := range desc.RecipeList { + recipeList = append(recipeList, marketplaceDescriptorFromRecipe(sub)) + } + + preconditions := make([]marketplaceDescriptor, 0, len(desc.Preconditions)) + for _, pre := range desc.Preconditions { + preconditions = append(preconditions, marketplaceDescriptorFromRecipe(pre)) + } + + dataTables := make([]any, 0, len(desc.DataTables)) + for _, dt := range desc.DataTables { + dataTables = append(dataTables, marketplaceDataTable{ + Name: dt.Name, + DisplayName: dt.DisplayName, + Description: dt.Description, + Columns: marketplaceColumns(dt.Columns), + }) + } + + maintainers := make([]any, 0, len(desc.Maintainers)) + for _, m := range desc.Maintainers { + maintainers = append(maintainers, marketplaceMaintainer{ + Name: m.Name, Email: m.Email, Logo: m.Logo, + }) + } + + contributors := make([]any, 0, len(desc.Contributors)) + for _, c := range desc.Contributors { + contributors = append(contributors, marketplaceContributor{ + Name: c.Name, Email: c.Email, LineCount: c.LineCount, + }) + } + + examples := make([]any, 0, len(desc.Examples)) + for _, ex := range desc.Examples { + sources := make([]any, 0, len(ex.Sources)) + for _, src := range ex.Sources { + sources = append(sources, marketplaceExampleSource{ + Before: src.Before, After: src.After, + Path: src.Path, Language: src.Language, + }) + } + examples = append(examples, marketplaceExample{ + Description: ex.Description, + Sources: sources, + Parameters: nonNil(ex.Parameters), + }) + } + + return marketplaceDescriptor{ + Name: desc.Name, + DisplayName: desc.DisplayName, + InstanceName: desc.DisplayName, + Description: desc.Description, + Tags: nonNil(desc.Tags), + Options: options, + Preconditions: preconditions, + RecipeList: recipeList, + DataTables: dataTables, + Maintainers: maintainers, + Contributors: contributors, + Examples: examples, + } +} + +func marketplaceColumns(cols []recipe.ColumnDescriptor) []marketplaceColumn { + out := make([]marketplaceColumn, 0, len(cols)) + for _, c := range cols { + out = append(out, marketplaceColumn{ + Name: c.Name, DisplayName: c.DisplayName, + Description: c.Description, Type: c.Type, + }) + } + return out +} + +type marketplaceDataTable struct { + Name string `json:"name"` + DisplayName string `json:"displayName"` + Description string `json:"description"` + Columns []marketplaceColumn `json:"columns"` +} + +type marketplaceColumn struct { + Name string `json:"name"` + DisplayName string `json:"displayName"` + Description string `json:"description"` + Type string `json:"type"` +} + +type marketplaceMaintainer struct { + Name string `json:"name"` + Email string `json:"email,omitempty"` + Logo string `json:"logo,omitempty"` +} + +type marketplaceContributor struct { + Name string `json:"name"` + Email string `json:"email,omitempty"` + LineCount int `json:"lineCount"` +} + +type marketplaceExample struct { + Description string `json:"description"` + Sources []any `json:"sources"` + Parameters []string `json:"parameters"` +} + +type marketplaceExampleSource struct { + Before string `json:"before"` + After string `json:"after"` + Path string `json:"path,omitempty"` + Language string `json:"language,omitempty"` +} + func nonNil(s []string) []string { if s == nil { return []string{} @@ -750,38 +1205,13 @@ type prepareRecipeRequest struct { // prepareRecipeResponse contains the prepared recipe info. type prepareRecipeResponse struct { - ID string `json:"id"` - Descriptor recipeDescResponse `json:"descriptor"` - EditVisitor string `json:"editVisitor"` - EditPreconditions []any `json:"editPreconditions"` - ScanVisitor *string `json:"scanVisitor,omitempty"` - ScanPreconditions []any `json:"scanPreconditions"` - DelegatesTo *delegatesToResponse `json:"delegatesTo,omitempty"` -} - -type recipeDescResponse struct { - Name string `json:"name"` - DisplayName string `json:"displayName"` - InstanceName string `json:"instanceName"` - Description string `json:"description"` - Tags []string `json:"tags"` - EstimatedEffortPerOccurrence *string `json:"estimatedEffortPerOccurrence"` - Options []optionDescResponse `json:"options"` - Preconditions []any `json:"preconditions"` - RecipeList []any `json:"recipeList"` - DataTables []any `json:"dataTables"` - Maintainers []any `json:"maintainers"` - Contributors []any `json:"contributors"` - Examples []any `json:"examples"` - Source *string `json:"source"` -} - -type optionDescResponse struct { - Name string `json:"name"` - DisplayName string `json:"displayName"` - Description string `json:"description"` - Example string `json:"example,omitempty"` - Required bool `json:"required"` + ID string `json:"id"` + Descriptor marketplaceDescriptor `json:"descriptor"` + EditVisitor string `json:"editVisitor"` + EditPreconditions []any `json:"editPreconditions"` + ScanVisitor *string `json:"scanVisitor,omitempty"` + ScanPreconditions []any `json:"scanPreconditions"` + DelegatesTo *delegatesToResponse `json:"delegatesTo,omitempty"` } type delegatesToResponse struct { @@ -812,26 +1242,14 @@ func (s *server) handlePrepareRecipe(params json.RawMessage) (any, *rpcError) { desc = reg.Descriptor } - // Generate unique ID and store the prepared recipe + // response.ID must be the per-instance UUID, not the recipe name — + // callers echo it back to identify the prepared instance. recipeID := uuid.New().String() s.preparedRecipes[recipeID] = instance resp := prepareRecipeResponse{ - ID: req.ID, - Descriptor: recipeDescResponse{ - Name: desc.Name, - DisplayName: desc.DisplayName, - InstanceName: desc.DisplayName, - Description: desc.Description, - Tags: []string{}, - Options: []optionDescResponse{}, - Preconditions: []any{}, - RecipeList: []any{}, - DataTables: []any{}, - Maintainers: []any{}, - Contributors: []any{}, - Examples: []any{}, - }, + ID: recipeID, + Descriptor: marketplaceDescriptorFromRecipe(desc), EditVisitor: "edit:" + recipeID, EditPreconditions: []any{}, ScanPreconditions: []any{}, @@ -845,17 +1263,6 @@ func (s *server) handlePrepareRecipe(params json.RawMessage) (any, *rpcError) { } } - // Map options - for _, opt := range desc.Options { - resp.Descriptor.Options = append(resp.Descriptor.Options, optionDescResponse{ - Name: opt.Name, - DisplayName: opt.DisplayName, - Description: opt.Description, - Example: opt.Example, - Required: opt.Required, - }) - } - // Check for delegation if instance != nil { if del, ok := instance.(recipe.DelegatesTo); ok { @@ -914,14 +1321,29 @@ func (s *server) handleVisit(params json.RawMessage) (any, *rpcError) { return &visitResponse{Modified: false}, nil } - // Get the visitor based on phase + // Resolve ctx first because ScanningRecipe.InitialValue may need it. + ctx := s.resolveExecutionContext(req.PID) + + // Get the visitor based on phase. For ScanningRecipe, both scan and edit + // phases need access to the accumulator; the accumulator is created + // lazily on the first scan visit. var v recipe.TreeVisitor switch phase { case "edit": - v = r.Editor() + if sr, ok := r.(recipe.ScanningRecipe); ok { + acc := s.getOrCreateAccumulator(recipeID, sr, ctx) + v = sr.EditorWithData(acc) + } else { + v = r.Editor() + } case "scan": - // ScanningRecipe not yet supported - return &visitResponse{Modified: false}, nil + sr, ok := r.(recipe.ScanningRecipe) + if !ok { + // scan visitor for a non-scanning recipe is a no-op + return &visitResponse{Modified: false}, nil + } + acc := s.getOrCreateAccumulator(recipeID, sr, ctx) + v = sr.Scanner(acc) default: return nil, &rpcError{Code: -32602, Message: "Unknown phase: " + phase} } @@ -931,13 +1353,20 @@ func (s *server) handleVisit(params json.RawMessage) (any, *rpcError) { } // Apply the visitor - ctx := recipe.NewExecutionContext() treeNode, ok := treeObj.(tree.Tree) if !ok { return &visitResponse{Modified: false}, nil } + s.seedCursor(v, req.Cursor) before := treeNode after := v.Visit(treeNode, ctx) + if after == nil { + after = before + } + // Drain any after-visits queued via GoVisitor.DoAfterVisit during + // the main visit. Mirrors JavaVisitor's afterVisit drain — recipes + // use this to compose follow-ups (e.g. AddImport as a side-effect). + after = visitor.DrainAfterVisits(v, after, ctx) // Check if modified by pointer identity (not value equality, // since tree nodes contain slices which are not comparable). @@ -967,18 +1396,221 @@ type generateResponse struct { SourceFileTypes []string `json:"sourceFileTypes"` } +// stringSet builds a set from a list of strings; used as a snapshot +// helper for hasNewMessages tracking in BatchVisit. +func stringSet(xs []string) map[string]struct{} { + out := make(map[string]struct{}, len(xs)) + for _, x := range xs { + out[x] = struct{}{} + } + return out +} + +// instantiateVisitor parses a visitor name like "edit:UUID" or "scan:UUID" +// and returns the configured visitor for that prepared recipe. Returns nil +// for installer-loaded recipes (no Go-side implementation) or unknown phases. +func (s *server) instantiateVisitor(visitorName string, ctx *recipe.ExecutionContext) recipe.TreeVisitor { + parts := strings.SplitN(visitorName, ":", 2) + if len(parts) != 2 { + return nil + } + phase := parts[0] + recipeID := parts[1] + + r, ok := s.preparedRecipes[recipeID] + if !ok || r == nil { + return nil + } + + switch phase { + case "edit": + if sr, ok := r.(recipe.ScanningRecipe); ok { + acc := s.getOrCreateAccumulator(recipeID, sr, ctx) + return sr.EditorWithData(acc) + } + return r.Editor() + case "scan": + if sr, ok := r.(recipe.ScanningRecipe); ok { + acc := s.getOrCreateAccumulator(recipeID, sr, ctx) + return sr.Scanner(acc) + } + return nil + } + return nil +} + +// batchVisitRequest is the parameter type for BatchVisit. +// Wire shape mirrors JS rewrite-javascript/rewrite/src/rpc/request/batch-visit.ts. +type batchVisitRequest struct { + SourceFileType string `json:"sourceFileType"` + TreeID string `json:"treeId"` + PID *string `json:"p"` + Cursor []string `json:"cursor"` + Visitors []batchVisitItem `json:"visitors"` +} + +type batchVisitItem struct { + Visitor string `json:"visitor"` + VisitorOptions map[string]any `json:"visitorOptions"` +} + +// batchVisitResponse must use the four-field per-result shape Java expects: +// {modified, deleted, hasNewMessages, searchResultIds}. +type batchVisitResponse struct { + Results []batchVisitResult `json:"results"` +} + +type batchVisitResult struct { + Modified bool `json:"modified"` + Deleted bool `json:"deleted"` + HasNewMessages bool `json:"hasNewMessages"` + SearchResultIDs []string `json:"searchResultIds"` +} + +// handleBatchVisit runs N visitors sequentially against a single tree, +// piping the output of visitor N into visitor N+1. On deletion, the +// pipeline stops and remaining visitors are not run. +func (s *server) handleBatchVisit(params json.RawMessage) (any, *rpcError) { + var req batchVisitRequest + if err := json.Unmarshal(params, &req); err != nil { + return nil, &rpcError{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)} + } + + ctx := s.resolveExecutionContext(req.PID) + + treeObj := s.getObjectFromJava(req.TreeID, req.SourceFileType) + current, _ := treeObj.(tree.Tree) + if current == nil { + return &batchVisitResponse{Results: []batchVisitResult{}}, nil + } + + // Track search-result marker IDs already visible on the tree before any + // visitor runs. Per-visitor, we diff against this set so each result + // only carries the IDs *newly added* by that visitor (matches + // JS batch-visit.ts:46+). + knownIDs := map[string]struct{}{} + for _, id := range tree.CollectSearchResultIDs(current) { + knownIDs[id.String()] = struct{}{} + } + + results := make([]batchVisitResult, 0, len(req.Visitors)) + for _, item := range req.Visitors { + v := s.instantiateVisitor(item.Visitor, ctx) + if v == nil { + results = append(results, batchVisitResult{SearchResultIDs: []string{}}) + continue + } + s.seedCursor(v, req.Cursor) + before := current + // Snapshot the ctx message keys so we can detect whether the + // visitor added any new ones (`hasNewMessages`). + preKeys := stringSet(ctx.MessageKeys()) + after := v.Visit(current, ctx) + + deleted := after == nil + modified := !deleted && !treeIdentical(before, after) + + hasNewMessages := false + for _, k := range ctx.MessageKeys() { + if _, ok := preKeys[k]; !ok { + hasNewMessages = true + break + } + } + + var newSearchResultIDs []string + if !deleted { + afterTree, _ := after.(tree.Tree) + if afterTree != nil { + for _, id := range tree.CollectSearchResultIDs(afterTree) { + sid := id.String() + if _, seen := knownIDs[sid]; seen { + continue + } + knownIDs[sid] = struct{}{} + newSearchResultIDs = append(newSearchResultIDs, sid) + } + } + } + if newSearchResultIDs == nil { + newSearchResultIDs = []string{} + } + + results = append(results, batchVisitResult{ + Modified: modified, + Deleted: deleted, + HasNewMessages: hasNewMessages, + SearchResultIDs: newSearchResultIDs, + }) + + if deleted { + delete(s.localObjects, req.TreeID) + current = nil + break + } + if modified { + if t, ok := after.(tree.Tree); ok { + current = t + } + } + } + + // Store the final tree under both req.treeId and its own id (if different), + // matching the JS pattern. + if current != nil { + s.localObjects[req.TreeID] = current + s.reverseRemoteObjects[req.TreeID] = current + } + + return &batchVisitResponse{Results: results}, nil +} + // handleGenerate returns any new source files generated by a scanning recipe. +// req.ID is the per-instance UUID returned by PrepareRecipe. func (s *server) handleGenerate(params json.RawMessage) (any, *rpcError) { var req generateRequest if err := json.Unmarshal(params, &req); err != nil { return nil, &rpcError{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)} } - // For now, Go doesn't support ScanningRecipe.generate() - return &generateResponse{ - IDs: []string{}, - SourceFileTypes: []string{}, - }, nil + r, ok := s.preparedRecipes[req.ID] + if !ok || r == nil { + return &generateResponse{IDs: []string{}, SourceFileTypes: []string{}}, nil + } + sr, ok := r.(recipe.ScanningRecipe) + if !ok { + return &generateResponse{IDs: []string{}, SourceFileTypes: []string{}}, nil + } + + ctx := s.resolveExecutionContext(req.PID) + acc := s.getOrCreateAccumulator(req.ID, sr, ctx) + + generated := sr.Generate(acc, ctx) + resp := &generateResponse{ + IDs: make([]string, 0, len(generated)), + SourceFileTypes: make([]string, 0, len(generated)), + } + for _, t := range generated { + if t == nil { + continue + } + // New trees get a fresh UUID; Java fetches them via GetObject. + newID := uuid.New().String() + s.localObjects[newID] = t + resp.IDs = append(resp.IDs, newID) + resp.SourceFileTypes = append(resp.SourceFileTypes, sourceFileTypeFor(t)) + } + return resp, nil +} + +// sourceFileTypeFor returns the FQN Java expects for a Go-side tree. +// Currently every Go source file is a CompilationUnit; expand if other +// SourceFile types are added. +func sourceFileTypeFor(t tree.Tree) string { + if _, ok := t.(*tree.CompilationUnit); ok { + return "org.openrewrite.golang.tree.Go$CompilationUnit" + } + return "" } // traceGetObjectRequest is the parameter type for TraceGetObject. @@ -1012,9 +1644,19 @@ type parseProjectRequest struct { type parseProjectResponseItem struct { ID string `json:"id"` SourceFileType string `json:"sourceFileType"` + SourcePath string `json:"sourcePath"` } -// handleParseProject discovers and parses all Go files in a project directory. +// handleParseProject discovers and parses all Go files in a project +// directory. When a sibling go.mod exists, files are grouped by their +// closest-ancestor module and parsed together so cross-file references +// inside a package resolve. Each parsed compilation unit gets the owning +// module's GoResolutionResult attached as a marker so Java-side recipes +// can read module dependency info without re-parsing go.mod themselves. +// +// Multi-module repos (root go.mod plus nested submodules) are honored: +// each .go file resolves against its closest-ancestor go.mod, not the +// project root's. func (s *server) handleParseProject(params json.RawMessage) (any, *rpcError) { var req parseProjectRequest if err := json.Unmarshal(params, &req); err != nil { @@ -1023,19 +1665,24 @@ func (s *server) handleParseProject(params json.RawMessage) (any, *rpcError) { s.logger.Printf("ParseProject: path=%s", req.ProjectPath) - // Discover all .go files in the project directory - var goFiles []string + // Discover all .go files AND every go.mod in the project tree. + type discovered struct { + goFiles []string + goMods []string + } + var disc discovered err := filepath.Walk(req.ProjectPath, func(path string, info os.FileInfo, err error) error { if err != nil { return err } if info.IsDir() { base := filepath.Base(path) - // Skip common non-source directories + // Skip common non-source directories. vendor/ is handled by + // the (3-tier) ProjectImporter for symbol resolution; we + // don't want to parse vendored code as project sources. if base == "vendor" || base == "node_modules" || base == ".git" || base == "testdata" { return filepath.SkipDir } - // Check exclusions for _, excl := range req.Exclusions { if matched, _ := filepath.Match(excl, base); matched { return filepath.SkipDir @@ -1043,8 +1690,11 @@ func (s *server) handleParseProject(params json.RawMessage) (any, *rpcError) { } return nil } - if strings.HasSuffix(path, ".go") && !strings.HasSuffix(path, "_test.go") { - goFiles = append(goFiles, path) + switch { + case filepath.Base(path) == "go.mod": + disc.goMods = append(disc.goMods, path) + case strings.HasSuffix(path, ".go") && !strings.HasSuffix(path, "_test.go"): + disc.goFiles = append(disc.goFiles, path) } return nil }) @@ -1052,38 +1702,196 @@ func (s *server) handleParseProject(params json.RawMessage) (any, *rpcError) { return nil, &rpcError{Code: -32603, Message: fmt.Sprintf("Walk error: %v", err)} } - // Parse each file - p := goparser.NewGoParser() - var items []parseProjectResponseItem - for _, goFile := range goFiles { + // Parse every go.mod once; index by directory so we can find the + // closest ancestor for each .go file. Failing to parse a go.mod is + // non-fatal — the affected files just lose module context and fall + // back to stdlib-only attribution. + type modCtx struct { + dir string // absolute directory containing go.mod + mrr *tree.GoResolutionResult + } + mods := make(map[string]*modCtx, len(disc.goMods)) + for _, modPath := range disc.goMods { + data, err := os.ReadFile(modPath) + if err != nil { + s.logger.Printf("ParseProject: skip go.mod %s: %v", modPath, err) + continue + } + mrr, err := goparser.ParseGoMod(modPath, string(data)) + if err != nil || mrr == nil { + s.logger.Printf("ParseProject: skip malformed go.mod %s: %v", modPath, err) + continue + } + // If a sibling go.sum exists, populate ResolvedDependencies too. + sumPath := filepath.Join(filepath.Dir(modPath), "go.sum") + if sumData, err := os.ReadFile(sumPath); err == nil { + mrr.ResolvedDependencies = goparser.ParseGoSum(string(sumData)) + } + mods[filepath.Dir(modPath)] = &modCtx{dir: filepath.Dir(modPath), mrr: mrr} + } + + // closestModule walks up `dir` looking for the deepest known go.mod + // directory. Returns nil when no ancestor module exists (tree-relative + // stdlib-only parse). + closestModule := func(dir string) *modCtx { + for cur := dir; ; { + if m, ok := mods[cur]; ok { + return m + } + parent := filepath.Dir(cur) + if parent == cur { + return nil + } + cur = parent + } + } + + // Pre-read every .go file so each is touched once even if used both + // as a ProjectImporter source and as a parse input. + contents := make(map[string]string, len(disc.goFiles)) + for _, goFile := range disc.goFiles { data, err := os.ReadFile(goFile) if err != nil { - s.logger.Printf("Skip %s: %v", goFile, err) + s.logger.Printf("ParseProject: skip %s: %v", goFile, err) continue } + contents[goFile] = string(data) + } + // Build a ProjectImporter per module, populated with every .go file + // that belongs to that module so cross-package resolution works. + // Project root is the module's go.mod dir so the vendor walker + // scans the right tree; replace directives are forwarded too. + piByModule := make(map[string]*goparser.ProjectImporter, len(mods)) + for _, m := range mods { + pi := goparser.NewProjectImporter(m.mrr.ModulePath, nil) + pi.SetProjectRoot(m.dir) + for _, r := range m.mrr.Requires { + pi.AddRequire(r.ModulePath) + } + for _, r := range m.mrr.Replaces { + pi.AddReplace(r.OldPath, r.NewPath, r.NewVersion) + } + piByModule[m.dir] = pi + } + for _, goFile := range disc.goFiles { + src, ok := contents[goFile] + if !ok { + continue + } + m := closestModule(filepath.Dir(goFile)) + if m == nil { + continue + } + piByModule[m.dir].AddSource(goFile, src) + } + + // Group files by (owning module, package directory). Each group + // parses together via ParsePackage so file-A-references-file-B + // resolves within a package. + type groupKey struct{ moduleDir, pkgDir string } + type fileEntry struct { + idx int + path string + sourcePath string + content string + } + groups := make(map[groupKey][]fileEntry) + type ordered struct { + idx int + sourcePath string + modCtx *modCtx + } + order := make([]ordered, 0, len(disc.goFiles)) + for i, goFile := range disc.goFiles { + src, ok := contents[goFile] + if !ok { + continue + } sourcePath := goFile if req.RelativeTo != nil && *req.RelativeTo != "" { if rel, err := filepath.Rel(*req.RelativeTo, goFile); err == nil { sourcePath = rel } } + m := closestModule(filepath.Dir(goFile)) + moduleDir := "" + if m != nil { + moduleDir = m.dir + } + key := groupKey{moduleDir: moduleDir, pkgDir: filepath.Dir(goFile)} + groups[key] = append(groups[key], fileEntry{ + idx: i, + path: goFile, + sourcePath: sourcePath, + content: src, + }) + order = append(order, ordered{idx: i, sourcePath: sourcePath, modCtx: m}) + } - cu, err := p.Parse(sourcePath, string(data)) + // Parse each group; collect CUs by original input index so the + // returned IDs land in input-order. Files filtered out by the + // parser's BuildContext (`//go:build` / `_GOOS_GOARCH.go` suffixes) + // don't appear in the response — handled here so the post-parse + // `cus` slice aligns with the `included` subset of entries. + cuByIdx := make(map[int]*tree.CompilationUnit, len(disc.goFiles)) + for key, entries := range groups { + p := goparser.NewGoParser() + if pi, ok := piByModule[key.moduleDir]; ok { + p.Importer = pi + } + included := make([]fileEntry, 0, len(entries)) + inputs := make([]goparser.FileInput, 0, len(entries)) + for _, e := range entries { + if !goparser.MatchBuildContext(p.BuildContext, filepath.Base(e.sourcePath), e.content) { + continue + } + included = append(included, e) + inputs = append(inputs, goparser.FileInput{Path: e.sourcePath, Content: e.content}) + } + if len(inputs) == 0 { + continue + } + cus, err := func() (out []*tree.CompilationUnit, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic: %v", r) + } + }() + return p.ParsePackage(inputs) + }() if err != nil { - s.logger.Printf("Parse error %s: %v", goFile, err) + for _, e := range included { + s.logger.Printf("ParseProject: parse error in %s: %v", e.path, err) + } continue } + for i, cu := range cus { + cuByIdx[included[i].idx] = cu + } + } + // Emit results in input order, attaching the owning module's marker + // to each cu so Java-side recipes can read module dependency info. + items := make([]parseProjectResponseItem, 0, len(disc.goFiles)) + for _, o := range order { + cu, ok := cuByIdx[o.idx] + if !ok || cu == nil { + continue + } + if o.modCtx != nil { + cu.Markers = tree.AddMarker(cu.Markers, *o.modCtx.mrr) + } id := cu.ID.String() s.localObjects[id] = cu items = append(items, parseProjectResponseItem{ ID: id, SourceFileType: "org.openrewrite.golang.tree.Go$CompilationUnit", + SourcePath: o.sourcePath, }) } - s.logger.Printf("ParseProject: parsed %d files", len(items)) + s.logger.Printf("ParseProject: parsed %d files across %d module(s)", len(items), len(mods)) return items, nil } diff --git a/rewrite-go/cmd/rpc/metrics_test.go b/rewrite-go/cmd/rpc/metrics_test.go new file mode 100644 index 00000000000..98dd58c88db --- /dev/null +++ b/rewrite-go/cmd/rpc/metrics_test.go @@ -0,0 +1,206 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "encoding/csv" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strconv" + "sync" + "testing" + "time" +) + +// newTestServer wires a server pointed at a temp metrics CSV. Cleanup is +// registered with t.Cleanup. The server's stdin/stdout aren't used by the +// tests — they only invoke safeHandleRequest directly. +func newTestServer(t *testing.T) (*server, string) { + t.Helper() + dir := t.TempDir() + csvPath := filepath.Join(dir, "metrics.csv") + logPath := filepath.Join(dir, "server.log") + s := newServer(serverConfig{ + logFile: logPath, + metricsCsv: csvPath, + }) + t.Cleanup(s.closeMetrics) + return s, csvPath +} + +// readMetricsCSV parses the metrics file and returns its header and rows. +// Caller must close the writer (via closeMetrics) before reading so the +// flush completes. +func readMetricsCSV(t *testing.T, path string) (header []string, rows [][]string) { + t.Helper() + f, err := os.Open(path) + if err != nil { + t.Fatalf("open metrics csv: %v", err) + } + defer f.Close() + r := csv.NewReader(f) + all, err := r.ReadAll() + if err != nil { + t.Fatalf("parse metrics csv: %v", err) + } + if len(all) == 0 { + t.Fatal("metrics csv is empty (no header)") + } + return all[0], all[1:] +} + +func TestMetricsCSVHeaderWritten(t *testing.T) { + s, csvPath := newTestServer(t) + s.closeMetrics() + + header, rows := readMetricsCSV(t, csvPath) + want := []string{"timestamp", "method", "duration_ms", "error"} + if len(header) != len(want) { + t.Fatalf("header columns: want %v, got %v", want, header) + } + for i, col := range want { + if header[i] != col { + t.Errorf("header[%d]: want %q, got %q", i, col, header[i]) + } + } + if len(rows) != 0 { + t.Errorf("want no rows before any RPC, got %d", len(rows)) + } +} + +func TestMetricsCSVRowPerRequest(t *testing.T) { + s, csvPath := newTestServer(t) + + // Fire a few requests serially. GetLanguages has no params and is + // always available; perfect for an isolation test. + calls := []string{"GetLanguages", "GetLanguages", "GetLanguages"} + for _, m := range calls { + s.safeHandleRequest(&jsonRPCRequest{JSONRPC: "2.0", ID: json.RawMessage("1"), Method: m}) + } + s.closeMetrics() + + _, rows := readMetricsCSV(t, csvPath) + if len(rows) != len(calls) { + t.Fatalf("rows: want %d, got %d (%v)", len(calls), len(rows), rows) + } + for i, row := range rows { + if len(row) != 4 { + t.Errorf("row[%d] columns: want 4, got %d (%v)", i, len(row), row) + continue + } + if row[1] != calls[i] { + t.Errorf("row[%d].method: want %q, got %q", i, calls[i], row[1]) + } + if row[3] != "" { + t.Errorf("row[%d].error: want empty, got %q", i, row[3]) + } + // duration_ms is non-negative integer; not necessarily > 0 for + // trivial methods on fast machines. + if _, err := strconv.Atoi(row[2]); err != nil { + t.Errorf("row[%d].duration_ms: not an int: %q", i, row[2]) + } + if _, err := time.Parse(time.RFC3339Nano, row[0]); err != nil { + t.Errorf("row[%d].timestamp: not RFC3339Nano: %q (%v)", i, row[0], err) + } + } +} + +func TestMetricsCSVCapturesErrors(t *testing.T) { + s, csvPath := newTestServer(t) + s.safeHandleRequest(&jsonRPCRequest{JSONRPC: "2.0", ID: json.RawMessage("1"), Method: "BogusMethodThatDoesNotExist"}) + s.closeMetrics() + + _, rows := readMetricsCSV(t, csvPath) + if len(rows) != 1 { + t.Fatalf("rows: want 1, got %d", len(rows)) + } + if rows[0][1] != "BogusMethodThatDoesNotExist" { + t.Errorf("method: %q", rows[0][1]) + } + if rows[0][3] == "" { + t.Errorf("expected error column populated for unknown method, got empty") + } +} + +func TestMetricsCSVConcurrentLoad(t *testing.T) { + // Concurrent RPC dispatch must serialize CSV writes; otherwise rows + // interleave and the parser fails or row counts drift. This test + // fires N parallel safeHandleRequest calls and asserts the file has + // exactly N well-formed rows. + s, csvPath := newTestServer(t) + + const goroutines = 16 + const perGoroutine = 50 + const total = goroutines * perGoroutine + + var wg sync.WaitGroup + wg.Add(goroutines) + for g := 0; g < goroutines; g++ { + gid := g + go func() { + defer wg.Done() + for i := 0; i < perGoroutine; i++ { + s.safeHandleRequest(&jsonRPCRequest{ + JSONRPC: "2.0", + ID: json.RawMessage(strconv.Itoa(gid*perGoroutine + i)), + Method: "GetLanguages", + }) + } + }() + } + wg.Wait() + s.closeMetrics() + + _, rows := readMetricsCSV(t, csvPath) + if len(rows) != total { + t.Fatalf("rows: want %d, got %d", total, len(rows)) + } + // Every row must have 4 columns and a parseable timestamp+duration. + // If writes interleaved, csv.NewReader.ReadAll above would fail or + // produce malformed rows; we re-validate here in case ReadAll silently + // padded. + for i, row := range rows { + if len(row) != 4 { + t.Fatalf("row[%d] columns: want 4, got %d (%v)", i, len(row), row) + } + if row[1] != "GetLanguages" { + t.Errorf("row[%d].method: want GetLanguages, got %q", i, row[1]) + } + if _, err := time.Parse(time.RFC3339Nano, row[0]); err != nil { + t.Errorf("row[%d].timestamp malformed: %q (%v)", i, row[0], err) + } + if _, err := strconv.Atoi(row[2]); err != nil { + t.Errorf("row[%d].duration_ms not an int: %q", i, row[2]) + } + } + fmt.Printf("concurrent metrics test: wrote %d rows across %d goroutines\n", total, goroutines) +} + +func TestMetricsCSVDisabledWhenFlagEmpty(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "server.log") + s := newServer(serverConfig{logFile: logPath}) + t.Cleanup(s.closeMetrics) + + // Should be a no-op; if writer-or-file leaks it would panic on close. + s.safeHandleRequest(&jsonRPCRequest{JSONRPC: "2.0", ID: json.RawMessage("1"), Method: "GetLanguages"}) + if s.metricsWriter != nil || s.metricsFile != nil { + t.Errorf("metrics writer should be nil when flag empty (writer=%v file=%v)", s.metricsWriter, s.metricsFile) + } +} diff --git a/rewrite-go/doc/recipe-authoring.md b/rewrite-go/doc/recipe-authoring.md new file mode 100644 index 00000000000..675b08819a6 --- /dev/null +++ b/rewrite-go/doc/recipe-authoring.md @@ -0,0 +1,302 @@ +# Authoring Go recipes + +A short reference for the patterns recipe authors need that aren't +obvious from the rest of the codebase. Snippets here are pulled from +real tests so they stay accurate as the API evolves. + +## A recipe in 30 lines + +A Go-native recipe embeds `recipe.Base`, names itself, and returns a +`TreeVisitor` from `Editor()`: + +```go +package golang + +import ( + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +type RenameXToFlag struct{ recipe.Base } + +func (r *RenameXToFlag) Name() string { return "org.openrewrite.golang.test.RenameXToFlag" } +func (r *RenameXToFlag) DisplayName() string { return "Rename x to flag" } +func (r *RenameXToFlag) Description() string { return "Test recipe." } + +func (r *RenameXToFlag) Editor() recipe.TreeVisitor { + return visitor.Init(&renameXVisitor{}) +} + +type renameXVisitor struct{ visitor.GoVisitor } + +func (v *renameXVisitor) VisitIdentifier(ident *tree.Identifier, _ any) tree.J { + if ident.Name == "x" { + c := *ident + c.Name = "flag" + return &c + } + return ident +} +``` + +Two patterns that come up everywhere: + +- `visitor.Init(...)` — sets the `Self` field on the embedded + `GoVisitor` so virtual dispatch works. Always use it. +- Return a fresh value, never mutate in place. Recipes get the same + cu reference as their parent visitor; in-place mutation breaks + no-change detection (and makes recipe diffs confusing to debug). + +## Test wrappers — `goProject(...)` + +Multi-file Go recipes test against a project layout that mirrors what +shows up in real codebases: a `go.mod` plus one or more `.go` files, +all SIBLINGS inside a project directory (not nested). The Go and Java +test harnesses both expose this: + +**Go side** (`pkg/test/spec.go`): + +```go +spec := test.NewRecipeSpec().WithRecipe(&MyRecipe{}) +spec.RewriteRun(t, + test.GoProject("foo", + test.GoMod(` + module example.com/foo + + go 1.22 + + require github.com/google/uuid v1.6.0 + `), + test.GoSum(` + github.com/google/uuid v1.6.0 h1:... + github.com/google/uuid v1.6.0/go.mod h1:... + `), + test.Golang(`package main + +func main() {} +`).WithPath("main.go"), + ), +) +``` + +**Java side** (`Assertions.goProject`): + +```java +rewriteRun( + spec -> spec.recipe(new MyRecipe()), + goProject("foo", + goMod("module example.com/foo\n\ngo 1.22\n"), + go("package main\n\nfunc main() {}\n") + ) +); +``` + +What `goProject(...)` does: + +1. Tags every child source with a `tree.GoProject` marker — recipes + that need to know "is this file part of project X?" read it. +2. Parses any `goMod(...)` sibling and attaches a + `tree.GoResolutionResult` marker holding requires / replaces / + excludes / retracts. A sibling `GoSum(...)` populates + `ResolvedDependencies` on the same marker. +3. On the Java side, files are written to a temp dir before parsing + so the on-disk vendor walker can resolve relative paths. On the + Go side, the test harness threads the same data through + `parser.ProjectImporter` in memory. + +## Reading module info from a recipe + +`tree.GoResolutionResult` lives on the sibling `go.mod` source in a +multi-file project. Recipes that need module-level information walk +the input set looking for the marker: + +```go +import "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + +func modulePath(cu *tree.CompilationUnit) string { + for _, m := range cu.Markers.Entries { + if mrr, ok := m.(tree.GoResolutionResult); ok { + return mrr.ModulePath + } + } + return "" +} +``` + +For project-aware parses (`handleParseProject` and Java's +`parseWithProject`), the marker is attached to **every** +`Go.CompilationUnit` so the lookup is local — no scanning siblings +needed. For test-harness parses, only the `goMod(...)` source carries +the marker (mirrors how Maven recipes read `MavenResolutionResult` off +the sibling `pom.xml`). + +The shared imports primitives in `pkg/recipe/golang/internal/imports.go` +expose this as `internal.FindModulePath(cu)` for recipes inside the +package (Add/Remove/OrderImports use it for grouping). + +## Cursor message map + +Visitors traverse depth-first; sometimes a child needs to leave a +breadcrumb for an ancestor (or vice versa). The cursor exposes a +small message map for this: + +```go +v.Cursor().PutMessage("foundError", true) + +// Later, in some ancestor visit: +if v.Cursor().GetNearestMessage("foundError") == true { + // ... +} +``` + +Available methods on `*visitor.Cursor`: + +| Method | What it does | +|---|---| +| `PutMessage(key, value)` | Store on this cursor's frame. | +| `GetMessage(key)` | Read from this frame only. | +| `GetNearestMessage(key)` | Walk parents looking for `key`; return first hit. | +| `GetNearestMessageOrDefault(key, default)` | Same, with a fallback. | +| `PollNearestMessage(key)` | Walk parents like `GetNearestMessage`, then **delete** the message from the frame it was found on. | +| `PutMessageOnFirstEnclosing(key, value, predicate)` | Walk up to the first ancestor matching predicate, store there. | + +All match the Java `Cursor.putMessage` / `getNearestMessage` semantics. + +## Module context for type attribution + +Type attribution depth depends on how the parser is constructed: + +| Construction | Stdlib | Intra-project | Third-party requires | Vendored sources | +|---|---|---|---|---| +| `parser.NewGoParser()` (default) | ✓ | (single file only) | ✗ | ✗ | +| `goparser.NewProjectImporter(modulePath, fallback)` + `AddSource(...)` | ✓ | ✓ (cross-file) | stub (typed) | ✗ | +| `+ SetProjectRoot(rootDir)` | ✓ | ✓ | stub | ✓ if `/vendor//` exists | +| `+ AddReplace(old, new, version)` | ✓ | ✓ | redirected to `new` (vendor or local) | ✓ | + +For Java-side parsing through RPC, build the parser with +`GolangParser.builder().module("...").goMod(content).build()`. The +server reconstructs a `ProjectImporter` from those plus sibling .go +files, and uses `RelativeTo` (when set) as the project root for the +vendor walker. + +The `parseProject` RPC handles this automatically — every `.go` file +under the project root resolves against its closest-ancestor `go.mod`. +Multi-module repos (root + nested submodules) are honored. + +## Composing import edits — `ImportService` and `DoAfterVisit` + +Most refactor recipes don't ONLY rewrite imports — they edit a method +body, then need to add an import as a follow-up. Composing this with +the `AddImport` recipe directly is awkward (you'd have to nest visitors +or mutate the tree twice). The canonical pattern uses two pieces: + +1. `recipe.Service[T](sourceFile)` — fetch a registered service by type. +2. `GoVisitor.DoAfterVisit(visitor)` — queue a follow-up visitor; the + recipe runner drains the queue once the main visit returns. + +Together they let a recipe queue an import as a side-effect: + +```go +import ( + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe/golang" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +type ReplaceTimeSinceCall struct{ recipe.Base } + +func (r *ReplaceTimeSinceCall) Editor() recipe.TreeVisitor { + return visitor.Init(&replaceTimeSinceVisitor{}) +} + +type replaceTimeSinceVisitor struct{ visitor.GoVisitor } + +func (v *replaceTimeSinceVisitor) VisitMethodInvocation(mi *tree.MethodInvocation, p any) tree.J { + mi = v.GoVisitor.VisitMethodInvocation(mi, p).(*tree.MethodInvocation) + if !looksLikeTimeSince(mi) { + return mi + } + // ... rewrite mi to xerrors-style call ... + + // Side-effect: queue an `import "xerrors"` to land after the main + // visit completes. The harness drains the queue automatically. + svc := recipe.Service[*golang.ImportService](nil) + v.DoAfterVisit(svc.AddImportVisitor("xerrors", nil, false /* unconditional */)) + return mi +} +``` + +`ImportService` exposes four visitors: + +| Method | Returns a visitor that... | +|---|---| +| `AddImportVisitor(path, alias, onlyIfReferenced)` | Adds `import [alias] "path"`. No-op when already present. | +| `RemoveImportVisitor(path)` | Deletes any import with the matching path. | +| `RemoveUnusedImportsVisitor()` | Drops imports the file doesn't reference. | +| `OrderImportsVisitor()` | Sorts into stdlib / third-party / local groups. | + +Each visitor is queueable via `DoAfterVisit` OR applicable directly via +`v.Visit(cu, ctx)`. After-visits can themselves queue more after-visits +(the runner drains transitively). + +**Service registration** happens at package `init()` time. As long as +your recipe imports `pkg/recipe/golang` (which most do), services are +registered before any test or RPC dispatch runs. Looking up a missing +service panics with a clear message. + +## Asserting types in tests + +Two helpers exist on both sides for common type assertions: + +**Go** (`pkg/test/expect_type.go`): + +```go +import . "github.com/openrewrite/rewrite/rewrite-go/pkg/test" + +ExpectType(t, cu, "p", "main.Point") // class/struct types +ExpectPrimitiveType(t, cu, "x", "int") // primitives +ExpectMethodType(t, cu, "Println", "fmt") // method's declaring FQN +``` + +**Java** (`org.openrewrite.golang.Assertions`): + +```java +expectType(cu, "p", "main.Point"); +expectPrimitiveType(cu, "x", "int"); +expectMethodType(cu, "Println", "fmt"); +``` + +Both walk the tree, find the first identifier (or method) matching the +name, and assert on its attributed type. Throw `AssertionError` with a +descriptive message on mismatch — drop them straight into an +`afterRecipe(cu -> ...)` lambda. + +## Surface boundaries + +When in doubt about what pattern to use: + +- **Refactoring an expression / statement / declaration** — write a + `recipe.Recipe` with a `GoVisitor` Editor. +- **Pattern → template rewrite** — use `template.Rewrite(before, after)` + from `pkg/template`. See `pkg/template/PARITY-AUDIT.md` for the full + surface against `JavaTemplate`. +- **Adding / removing / reordering imports** — use the existing + recipes in `pkg/recipe/golang/`: `AddImport`, `RemoveImport`, + `RemoveUnusedImports`, `OrderImports`. They share primitives + in `pkg/recipe/golang/internal/imports.go`. +- **Asserting on the parsed LST in a test** — `goProject(...)` for the + setup, `expectType` / `expectMethodType` for the assertion, and + `afterRecipe(cu -> ...)` for the callback. + +## Where to look next + +- `PLAN.md` — what's shipped vs. what's open work. +- `pkg/template/PARITY-AUDIT.md` — GoTemplate vs. JavaTemplate surface. +- `test/testdata/printer-corpus/README.md` — printer fidelity corpus + (run with `make parity`). +- `test/import_recipes_test.go` — worked examples for the Add/Remove/ + OrderImports recipes. +- `test/cross_package_generics_test.go` — what to expect from + cross-package type attribution. diff --git a/rewrite-go/go.mod b/rewrite-go/go.mod index db7369a0295..213ea00d7a2 100644 --- a/rewrite-go/go.mod +++ b/rewrite-go/go.mod @@ -1,5 +1,7 @@ module github.com/openrewrite/rewrite/rewrite-go -go 1.23 +go 1.25.0 require github.com/google/uuid v1.6.0 + +require golang.org/x/mod v0.35.0 // indirect diff --git a/rewrite-go/go.sum b/rewrite-go/go.sum index 7790d7c3e03..4ae3bf38875 100644 --- a/rewrite-go/go.sum +++ b/rewrite-go/go.sum @@ -1,2 +1,4 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM= +golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU= diff --git a/rewrite-go/pkg/format/auto_format.go b/rewrite-go/pkg/format/auto_format.go new file mode 100644 index 00000000000..618773271a5 --- /dev/null +++ b/rewrite-go/pkg/format/auto_format.go @@ -0,0 +1,67 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package format + +import ( + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// AutoFormatVisitor composes the per-responsibility format visitors +// into a single end-to-end pipeline. The pipeline applies passes in a +// fixed order — each later pass relies on the structure normalized by +// the earlier ones: +// +// 1. RemoveTrailingWhitespaceVisitor — strip trailing spaces/tabs +// 2. BlankLinesVisitor — collapse blank-line runs +// 3. TabsAndIndentsVisitor — re-indent post-newline whitespace +// 4. SpacesVisitor — normalize intra-line spacing +// +// Pass-1 runs first because pass-3 (re-indent) only touches the post- +// newline portion of a Whitespace; if pass-1 ran later, the trailing +// space/tabs of an earlier line could survive. Pass-2 runs before pass-3 +// because collapsing blank lines doesn't touch indents — it just +// removes whole `\n` characters — so pass-3 still has the right +// post-newline section to rewrite. Pass-4 runs last because changes +// within a binary/assignment don't affect whether a line carries a +// newline; spacing fixes can't disturb prior passes. +// +// stopAfter is forwarded to every member visitor; pass nil to format +// the entire visited subtree. +type AutoFormatVisitor struct { + visitor.GoVisitor + stopAfter tree.Tree +} + +// NewAutoFormatVisitor returns a composer visitor that, on its first +// Visit, queues the four per-responsibility passes via DoAfterVisit. +// The recipe runner's after-visit drain runs them in order. Each pass +// sees the partially-normalized tree from its predecessors. +func NewAutoFormatVisitor(stopAfter tree.Tree) *AutoFormatVisitor { + return visitor.Init(&AutoFormatVisitor{stopAfter: stopAfter}) +} + +func (v *AutoFormatVisitor) Visit(t tree.Tree, p any) tree.Tree { + if t == nil { + return nil + } + v.DoAfterVisit(NewRemoveTrailingWhitespaceVisitor(v.stopAfter)) + v.DoAfterVisit(NewBlankLinesVisitor(v.stopAfter)) + v.DoAfterVisit(NewTabsAndIndentsVisitor(v.stopAfter)) + v.DoAfterVisit(NewSpacesVisitor(v.stopAfter)) + return t +} diff --git a/rewrite-go/pkg/format/blank_lines.go b/rewrite-go/pkg/format/blank_lines.go new file mode 100644 index 00000000000..999b2286163 --- /dev/null +++ b/rewrite-go/pkg/format/blank_lines.go @@ -0,0 +1,141 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package format + +import ( + "strings" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// BlankLinesVisitor enforces gofmt's blank-line rules: +// +// - Any run of >1 blank line collapses to one (anywhere in the +// compilation unit — gofmt's `1 blank line max` rule applies +// uniformly inside blocks AND between top-level decls). +// - Block.End loses any leading blank line so the closing brace +// sits flush against the last statement. +// - The first statement of a block has no blank line above it. +// +// The first-statement rule reaches through the leftmost spine of the +// statement node — the parser places inter-statement whitespace on the +// *leftmost descendant* (Variable.Prefix, not Assignment.Prefix), so +// transformLeftmostPrefix walks down to find it. +type BlankLinesVisitor struct { + visitor.GoVisitor + stopAfterTracker +} + +// NewBlankLinesVisitor returns a visitor configured with the given +// stopAfter bound. Pass nil to format the entire visited tree. +func NewBlankLinesVisitor(stopAfter tree.Tree) *BlankLinesVisitor { + return visitor.Init(&BlankLinesVisitor{ + stopAfterTracker: stopAfterTracker{stopAfter: stopAfter}, + }) +} + +func (v *BlankLinesVisitor) Visit(t tree.Tree, p any) tree.Tree { + if v.shouldHalt() { + return t + } + out := v.GoVisitor.Visit(t, p) + v.noteVisited(t) + return out +} + +func (v *BlankLinesVisitor) VisitSpace(s tree.Space, p any) tree.Space { + if !strings.Contains(s.Whitespace, "\n\n\n") { + return s + } + s.Whitespace = capInternalBlankLines(s.Whitespace, 1) + return s +} + +func (v *BlankLinesVisitor) VisitBlock(block *tree.Block, p any) tree.J { + out := v.GoVisitor.VisitBlock(block, p).(*tree.Block) + out = out.WithEnd(adjustSpace(out.End, stripLeadingBlankLines)) + + // Strip any leading blank line above the first statement. The + // inter-statement whitespace lives on the statement's leftmost + // descendant, so we walk the spine to find it. + if len(out.Statements) > 0 && out.Statements[0].Element != nil { + first := out.Statements[0] + if updated, ok := transformLeftmostPrefix(first.Element, stripLeadingBlankLinesSpace).(tree.Statement); ok { + first.Element = updated + out.Statements[0] = first + } + } + return out +} + +func stripLeadingBlankLinesSpace(s tree.Space) tree.Space { + s.Whitespace = stripLeadingBlankLines(s.Whitespace) + return s +} + +func adjustSpace(s tree.Space, f func(string) string) tree.Space { + updated := f(s.Whitespace) + if updated == s.Whitespace { + return s + } + s.Whitespace = updated + return s +} + +// stripLeadingBlankLines collapses any "\n\n+" run at the start to a +// single "\n", preserving any trailing indent. +func stripLeadingBlankLines(ws string) string { + if !strings.HasPrefix(ws, "\n\n") { + return ws + } + for strings.HasPrefix(ws, "\n\n") { + ws = ws[1:] + } + return ws +} + +// capInternalBlankLines walks ws and collapses any internal run of +// 3+ newlines to exactly 2 (i.e., one blank line). Leaves runs of +// 0–2 newlines untouched. +func capInternalBlankLines(ws string, max int) string { + allowed := max + 1 + var b strings.Builder + b.Grow(len(ws)) + i := 0 + for i < len(ws) { + if ws[i] != '\n' { + b.WriteByte(ws[i]) + i++ + continue + } + // Count run length. + j := i + for j < len(ws) && ws[j] == '\n' { + j++ + } + run := j - i + if run > allowed { + run = allowed + } + for k := 0; k < run; k++ { + b.WriteByte('\n') + } + i = j + } + return b.String() +} diff --git a/rewrite-go/pkg/format/leftmost.go b/rewrite-go/pkg/format/leftmost.go new file mode 100644 index 00000000000..fc161a814ca --- /dev/null +++ b/rewrite-go/pkg/format/leftmost.go @@ -0,0 +1,162 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package format + +import ( + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +// transformLeftmostPrefix walks the leftmost-spine of t and applies f +// to the deepest leaf's Prefix. The "leftmost spine" follows the child +// whose source position is leftmost — Binary's Left, Assignment's +// Variable, FieldAccess's Target, ArrayAccess's Indexed, +// MethodInvocation's Select (or Name when Select is nil). Types whose +// source layout puts the operator/keyword first (Unary, Parentheses, +// TypeCast) carry their leading whitespace on their own Prefix and are +// treated as leaves. +// +// The parser places inter-statement / inter-expression whitespace on +// the leftmost descendant rather than on the enclosing node. Format +// passes that need to manipulate that whitespace (BlankLines' +// block-start strip, Spaces' "single space after `:=`" with a Binary +// operand) reach for it through this helper instead of guessing where +// the parser put it. +func transformLeftmostPrefix(t tree.Tree, f func(tree.Space) tree.Space) tree.Tree { + if t == nil { + return nil + } + switch n := t.(type) { + case *tree.Binary: + if n.Left == nil { + break + } + if updated, ok := transformLeftmostPrefix(n.Left, f).(tree.Expression); ok { + c := *n + c.Left = updated + return &c + } + case *tree.Assignment: + if n.Variable == nil { + break + } + if updated, ok := transformLeftmostPrefix(n.Variable, f).(tree.Expression); ok { + c := *n + c.Variable = updated + return &c + } + case *tree.AssignmentOperation: + if n.Variable == nil { + break + } + if updated, ok := transformLeftmostPrefix(n.Variable, f).(tree.Expression); ok { + c := *n + c.Variable = updated + return &c + } + case *tree.FieldAccess: + if n.Target == nil { + break + } + if updated, ok := transformLeftmostPrefix(n.Target, f).(tree.Expression); ok { + c := *n + c.Target = updated + return &c + } + case *tree.MethodInvocation: + c := *n + if n.Select != nil && n.Select.Element != nil { + if updated, ok := transformLeftmostPrefix(n.Select.Element, f).(tree.Expression); ok { + sp := *n.Select + sp.Element = updated + c.Select = &sp + return &c + } + } else if n.Name != nil { + if updated, ok := transformLeftmostPrefix(n.Name, f).(*tree.Identifier); ok { + c.Name = updated + return &c + } + } + case *tree.ArrayAccess: + if n.Indexed == nil { + break + } + if updated, ok := transformLeftmostPrefix(n.Indexed, f).(tree.Expression); ok { + c := *n + c.Indexed = updated + return &c + } + case *tree.StatementExpression: + if n.Statement == nil { + break + } + if updated, ok := transformLeftmostPrefix(n.Statement, f).(tree.Statement); ok { + c := *n + c.Statement = updated + return &c + } + case *tree.Composite: + if n.TypeExpr == nil { + break + } + if updated, ok := transformLeftmostPrefix(n.TypeExpr, f).(tree.Expression); ok { + c := *n + c.TypeExpr = updated + return &c + } + case *tree.ParameterizedType: + if n.Clazz == nil { + break + } + if updated, ok := transformLeftmostPrefix(n.Clazz, f).(tree.Expression); ok { + c := *n + c.Clazz = updated + return &c + } + case *tree.IndexList: + if n.Target == nil { + break + } + if updated, ok := transformLeftmostPrefix(n.Target, f).(tree.Expression); ok { + c := *n + c.Target = updated + return &c + } + case *tree.Slice: + if n.Indexed == nil { + break + } + if updated, ok := transformLeftmostPrefix(n.Indexed, f).(tree.Expression); ok { + c := *n + c.Indexed = updated + return &c + } + } + // Leaf — apply f to this node's own Prefix. + cur := getPrefix(t) + next := f(cur) + if next.Whitespace == cur.Whitespace && len(next.Comments) == len(cur.Comments) { + return t + } + return withPrefix(t, next) +} + +// setLeftmostPrefix is the "set" form of transformLeftmostPrefix: it +// replaces the leftmost leaf's Prefix with prefix unconditionally. +func setLeftmostPrefix(t tree.Tree, prefix tree.Space) tree.Tree { + return transformLeftmostPrefix(t, func(_ tree.Space) tree.Space { return prefix }) +} diff --git a/rewrite-go/pkg/format/prefix.go b/rewrite-go/pkg/format/prefix.go new file mode 100644 index 00000000000..ec0fedbda39 --- /dev/null +++ b/rewrite-go/pkg/format/prefix.go @@ -0,0 +1,73 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package format + +import ( + "reflect" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +// getPrefix returns the `Prefix` field of any tree node that carries +// one. Reads via reflection so format visitors can manipulate prefixes +// uniformly across the ~50 concrete LST types without an exhaustive +// type-switch. +// +// Returns the zero Space when the node has no Prefix field (which is +// vanishingly rare — every J-conformant type carries one). +func getPrefix(t tree.Tree) tree.Space { + if t == nil { + return tree.Space{} + } + rv := reflect.ValueOf(t) + if rv.Kind() == reflect.Ptr { + if rv.IsNil() { + return tree.Space{} + } + rv = rv.Elem() + } + if rv.Kind() != reflect.Struct { + return tree.Space{} + } + f := rv.FieldByName("Prefix") + if !f.IsValid() { + return tree.Space{} + } + if s, ok := f.Interface().(tree.Space); ok { + return s + } + return tree.Space{} +} + +// withPrefix calls the node's `WithPrefix(Space) ` method to produce +// a copy with the given prefix. Returns the original node if it has no +// such method (defensive — every J-conformant type ships one). +func withPrefix[T tree.Tree](t T, prefix tree.Space) T { + rv := reflect.ValueOf(t) + m := rv.MethodByName("WithPrefix") + if !m.IsValid() { + return t + } + results := m.Call([]reflect.Value{reflect.ValueOf(prefix)}) + if len(results) == 0 { + return t + } + if r, ok := results[0].Interface().(T); ok { + return r + } + return t +} diff --git a/rewrite-go/pkg/format/spaces.go b/rewrite-go/pkg/format/spaces.go new file mode 100644 index 00000000000..5ff34e19e1b --- /dev/null +++ b/rewrite-go/pkg/format/spaces.go @@ -0,0 +1,132 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package format + +import ( + "strings" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// SpacesVisitor enforces gofmt's intra-line spacing rules: +// +// - One space around binary operators (`a + b`, not `a+b`). +// - No space around unary operators (`!x`, `-y`). +// - One space after commas in argument/parameter lists +// (RightPadded.After fields don't get touched here — they precede +// the comma, not follow it). +// - One space around `=` and `:=` (Assignment / VariableDeclarator). +// +// Non-trivial whitespace (containing newlines) is preserved on the +// theory that it was authored deliberately. SpacesVisitor only fixes +// "0 spaces" and "2+ spaces" in single-line contexts; multi-line +// expressions are left to the author and the indent visitor. +type SpacesVisitor struct { + visitor.GoVisitor + stopAfterTracker +} + +// NewSpacesVisitor returns a visitor configured with the given +// stopAfter bound. Pass nil to format the entire visited tree. +func NewSpacesVisitor(stopAfter tree.Tree) *SpacesVisitor { + return visitor.Init(&SpacesVisitor{ + stopAfterTracker: stopAfterTracker{stopAfter: stopAfter}, + }) +} + +func (v *SpacesVisitor) Visit(t tree.Tree, p any) tree.Tree { + if v.shouldHalt() { + return t + } + out := v.GoVisitor.Visit(t, p) + v.noteVisited(t) + return out +} + +func (v *SpacesVisitor) VisitBinary(bin *tree.Binary, p any) tree.J { + bin = v.GoVisitor.VisitBinary(bin, p).(*tree.Binary) + bin.Operator.Before = ensureSingleSpace(bin.Operator.Before) + bin = bin.WithRight(ensureLeadingSingleSpace(bin.Right)) + return bin +} + +func (v *SpacesVisitor) VisitAssignment(a *tree.Assignment, p any) tree.J { + a = v.GoVisitor.VisitAssignment(a, p).(*tree.Assignment) + a.Value.Before = ensureSingleSpace(a.Value.Before) + a.Value.Element = ensureLeadingSingleSpace(a.Value.Element) + return a +} + +func (v *SpacesVisitor) VisitAssignmentOperation(ao *tree.AssignmentOperation, p any) tree.J { + ao = v.GoVisitor.VisitAssignmentOperation(ao, p).(*tree.AssignmentOperation) + ao.Operator.Before = ensureSingleSpace(ao.Operator.Before) + ao.Assignment = ensureLeadingSingleSpace(ao.Assignment) + return ao +} + +func (v *SpacesVisitor) VisitUnary(u *tree.Unary, p any) tree.J { + u = v.GoVisitor.VisitUnary(u, p).(*tree.Unary) + u.Operand = clearExpressionLeadingSpace(u.Operand) + return u +} + +// ensureSingleSpace returns the space unchanged if it contains a +// newline (deliberate multi-line layout), otherwise normalizes any +// 0-or-many-spaces to exactly one space. +func ensureSingleSpace(s tree.Space) tree.Space { + if strings.Contains(s.Whitespace, "\n") { + return s + } + if s.Whitespace == " " { + return s + } + s.Whitespace = " " + return s +} + +// ensureLeadingSingleSpace ensures the leftmost leaf of an Expression +// has a single-space Prefix. Walks the leftmost spine via +// transformLeftmostPrefix so the rule works whether the leading +// whitespace lives directly on the expression (Identifier, Literal, +// Unary, Parentheses) or on a deeper leftmost descendant (Binary's +// Left, FieldAccess's Target, etc.). +func ensureLeadingSingleSpace(e tree.Expression) tree.Expression { + if e == nil { + return e + } + out := transformLeftmostPrefix(e, ensureSingleSpace) + if r, ok := out.(tree.Expression); ok { + return r + } + return e +} + +func clearExpressionLeadingSpace(e tree.Expression) tree.Expression { + if e == nil { + return e + } + prefix := getPrefix(e) + if strings.Contains(prefix.Whitespace, "\n") { + return e + } + if prefix.Whitespace == "" { + return e + } + prefix.Whitespace = "" + return withPrefix(e, prefix) +} diff --git a/rewrite-go/pkg/format/stop_after.go b/rewrite-go/pkg/format/stop_after.go new file mode 100644 index 00000000000..24912438263 --- /dev/null +++ b/rewrite-go/pkg/format/stop_after.go @@ -0,0 +1,75 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package format provides single-responsibility visitors that together +// implement gofmt-style normalization. Each visitor handles one aspect: +// +// - TabsAndIndentsVisitor — re-indent post-newline whitespace +// - BlankLinesVisitor — collapse runs of >1 blank line and trim +// blank lines at block boundaries +// - SpacesVisitor — token-level spacing fixes (binary operators, +// commas, etc.) +// - RemoveTrailingWhitespaceVisitor — strip trailing space/tab from +// line endings inside Whitespace +// +// AutoFormatVisitor composes the four into a fixed pipeline via +// DoAfterVisit. Each is independently usable for recipes that only need +// one pass. +// +// All visitors honor a `stopAfter tree.Tree` parameter: when non-nil, +// traversal halts after the given node has been visited, leaving +// downstream subtrees untouched. When nil, the entire visited subtree +// is processed. +// +// The package is deliberately minimal. gofmt's full rule set (column +// alignment in const blocks, struct field tag alignment, etc.) is not +// implemented yet — recipes that need byte-exact gofmt output should +// pipe through the gofmt binary; AutoFormat targets the "splice in a +// synthesized subtree and put it at the right indent" use case. +package format + +import ( + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +// stopAfterTracker is embedded by every format visitor to share the +// "halt after visiting this node" semantics. Once stopAfter has been +// fully visited, halted flips to true and subsequent Visit calls +// short-circuit. +type stopAfterTracker struct { + stopAfter tree.Tree + halted bool +} + +// shouldHalt reports whether traversal should be skipped because the +// stopAfter node has already been visited. Call from any Visit override +// before doing work. +func (s *stopAfterTracker) shouldHalt() bool { + return s.halted +} + +// noteVisited records that the given node has been visited; flips +// halted once the configured stopAfter node is encountered. Call this +// at the END of each Visit method (after recursing into children) so +// the stopAfter subtree is processed before halting kicks in. +func (s *stopAfterTracker) noteVisited(t tree.Tree) { + if s.halted || s.stopAfter == nil || t == nil { + return + } + if t == s.stopAfter { + s.halted = true + } +} diff --git a/rewrite-go/pkg/format/tabs_and_indents.go b/rewrite-go/pkg/format/tabs_and_indents.go new file mode 100644 index 00000000000..a6e29b4b703 --- /dev/null +++ b/rewrite-go/pkg/format/tabs_and_indents.go @@ -0,0 +1,133 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package format + +import ( + "strings" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// TabsAndIndentsVisitor re-indents block-level whitespace to gofmt's +// `\t × depth` convention. +// +// Strategy: rather than touching every Space in the tree (which would +// rewrite continuation alignments inside multi-line argument lists), +// the visitor drives indentation from VisitBlock explicitly: +// +// - For each `Block.Statements[i].Element`, the leftmost leaf's +// Prefix carries the inter-statement whitespace; we rewrite its +// indent (the post-newline portion) to `\t × (depth+1)`. +// - `Block.End` (the whitespace before `}`) is re-indented to +// `\t × depth`. +// - VisitCase decrements depth around the case-clause Prefix so +// `case` keywords align with the enclosing `switch` (gofmt +// convention) while body statements remain at body depth. +// +// Whitespace that doesn't carry a newline, or whitespace inside an +// expression context (continuations, alignments) is left untouched. +type TabsAndIndentsVisitor struct { + visitor.GoVisitor + stopAfterTracker + depth int +} + +// NewTabsAndIndentsVisitor returns a visitor configured with the given +// stopAfter bound. Pass nil to format the entire visited tree. +func NewTabsAndIndentsVisitor(stopAfter tree.Tree) *TabsAndIndentsVisitor { + return visitor.Init(&TabsAndIndentsVisitor{ + stopAfterTracker: stopAfterTracker{stopAfter: stopAfter}, + }) +} + +func (v *TabsAndIndentsVisitor) Visit(t tree.Tree, p any) tree.Tree { + if v.shouldHalt() { + return t + } + out := v.GoVisitor.Visit(t, p) + v.noteVisited(t) + return out +} + +// VisitBlock dispatches the body at depth+1 and re-indents each +// statement's leftmost-leaf Prefix and the closing-brace `End`. +func (v *TabsAndIndentsVisitor) VisitBlock(block *tree.Block, p any) tree.J { + v.depth++ + stmts := make([]tree.RightPadded[tree.Statement], len(block.Statements)) + for i, rp := range block.Statements { + if rp.Element != nil { + fixed, _ := transformLeftmostPrefix(rp.Element, v.reindentSpace).(tree.Statement) + if fixed != nil { + rp.Element = fixed + } + if next, ok := v.Visit(rp.Element, p).(tree.Statement); ok { + rp.Element = next + } + } + stmts[i] = rp + } + block.Statements = stmts + v.depth-- + + block = block.WithEnd(v.reindentSpace(block.End)) + return block +} + +// VisitCase aligns the case keyword with the enclosing switch (one +// tab less than the switch body's depth) while keeping the case body +// statements at body depth. Also explicitly visits Body so nested +// blocks inside a case get their own indent fixes (the default +// GoVisitor.VisitCase doesn't recurse into Body). +func (v *TabsAndIndentsVisitor) VisitCase(c *tree.Case, p any) tree.J { + v.depth-- + c = c.WithPrefix(v.reindentSpace(c.Prefix)) + v.depth++ + + body := make([]tree.RightPadded[tree.Statement], len(c.Body)) + for i, rp := range c.Body { + if rp.Element != nil { + fixed, _ := transformLeftmostPrefix(rp.Element, v.reindentSpace).(tree.Statement) + if fixed != nil { + rp.Element = fixed + } + if next, ok := v.Visit(rp.Element, p).(tree.Statement); ok { + rp.Element = next + } + } + body[i] = rp + } + c.Body = body + return c +} + +// reindentSpace rewrites the indent (post-last-newline portion) of s +// to `\t × v.depth`. Whitespace without a newline is returned +// unchanged. The pre-newline portion (which can hold blank lines) is +// preserved — BlankLinesVisitor handles that. +func (v *TabsAndIndentsVisitor) reindentSpace(s tree.Space) tree.Space { + if !strings.Contains(s.Whitespace, "\n") { + return s + } + last := strings.LastIndex(s.Whitespace, "\n") + want := strings.Repeat("\t", v.depth) + if s.Whitespace[last+1:] == want { + return s + } + s.Whitespace = s.Whitespace[:last+1] + want + return s +} diff --git a/rewrite-go/pkg/format/trailing_whitespace.go b/rewrite-go/pkg/format/trailing_whitespace.go new file mode 100644 index 00000000000..ef34938260a --- /dev/null +++ b/rewrite-go/pkg/format/trailing_whitespace.go @@ -0,0 +1,74 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package format + +import ( + "strings" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// RemoveTrailingWhitespaceVisitor strips trailing space/tab characters +// from the end of each line inside `Space.Whitespace` fields. Indent +// whitespace at the start of the next line is preserved. +// +// Input: "foo \t\n bar" +// Output: "foo\n bar" +type RemoveTrailingWhitespaceVisitor struct { + visitor.GoVisitor + stopAfterTracker +} + +// NewRemoveTrailingWhitespaceVisitor returns a visitor configured with +// the given stopAfter bound. Pass nil to format the entire visited tree. +func NewRemoveTrailingWhitespaceVisitor(stopAfter tree.Tree) *RemoveTrailingWhitespaceVisitor { + return visitor.Init(&RemoveTrailingWhitespaceVisitor{ + stopAfterTracker: stopAfterTracker{stopAfter: stopAfter}, + }) +} + +func (v *RemoveTrailingWhitespaceVisitor) Visit(t tree.Tree, p any) tree.Tree { + if v.shouldHalt() { + return t + } + out := v.GoVisitor.Visit(t, p) + v.noteVisited(t) + return out +} + +func (v *RemoveTrailingWhitespaceVisitor) VisitSpace(s tree.Space, p any) tree.Space { + if s.Whitespace == "" || !strings.ContainsAny(s.Whitespace, " \t") { + return s + } + s.Whitespace = stripTrailingPerLine(s.Whitespace) + return s +} + +// stripTrailingPerLine removes trailing space/tab characters from each +// line. The trailing portion of the LAST line is preserved because it's +// the leading indent of whatever syntax follows. +func stripTrailingPerLine(ws string) string { + if !strings.Contains(ws, "\n") { + return ws + } + lines := strings.Split(ws, "\n") + for i := 0; i < len(lines)-1; i++ { + lines[i] = strings.TrimRight(lines[i], " \t") + } + return strings.Join(lines, "\n") +} diff --git a/rewrite-go/pkg/parser/build_tags.go b/rewrite-go/pkg/parser/build_tags.go new file mode 100644 index 00000000000..f0d63841069 --- /dev/null +++ b/rewrite-go/pkg/parser/build_tags.go @@ -0,0 +1,192 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package parser + +import ( + "go/build" + "go/build/constraint" + "strings" +) + +// MatchBuildContext returns true when (name, content) should be parsed +// under buildCtx — i.e. the filename suffix doesn't conflict with +// GOOS/GOARCH and any `//go:build` / `// +build` constraint lines +// evaluate to true. Mirrors build.Context.MatchFile for callers that +// have file content in-memory rather than on disk. +// +// The constraint evaluator recognizes: +// - GOOS/GOARCH name tags (matches when ctx.GOOS == tag etc.) +// - "cgo" if buildCtx.CgoEnabled is true +// - language version tags (e.g. "go1.21") if at or below the configured +// release; falls back to true to avoid spurious exclusions +// - any tag listed in buildCtx.BuildTags or buildCtx.ToolTags +// - the special "ignore" tag (always false — Go convention for +// manually excluding a file) +// +// Unknown tags evaluate to false. This matches build.Context's behavior +// for tags it doesn't recognize. +func MatchBuildContext(buildCtx build.Context, name, content string) bool { + if !matchOSArchFilename(buildCtx, name) { + return false + } + for _, line := range buildConstraintLines(content) { + expr, err := constraint.Parse(line) + if err != nil { + // Malformed constraint — skip the line rather than the whole + // file. Mirrors `build` package's tolerance for legacy tags. + continue + } + if !expr.Eval(func(tag string) bool { return matchTag(buildCtx, tag) }) { + return false + } + } + return true +} + +// matchOSArchFilename implements Go's filename build constraints: +// +// *_GOOS.go +// *_GOARCH.go +// *_GOOS_GOARCH.go +// +// where GOOS and GOARCH are known operating systems or architectures. +// `_test.go` is intentionally NOT excluded here (callers handle that +// upstream). +func matchOSArchFilename(buildCtx build.Context, name string) bool { + if !strings.HasSuffix(name, ".go") { + return true + } + stem := strings.TrimSuffix(name, ".go") + stem = strings.TrimSuffix(stem, "_test") + parts := strings.Split(stem, "_") + n := len(parts) + if n < 2 { + return true + } + + last := parts[n-1] + prev := "" + if n >= 3 { + prev = parts[n-2] + } + + if knownOS(prev) && knownArch(last) { + return prev == buildCtx.GOOS && last == buildCtx.GOARCH + } + if knownOS(last) { + return last == buildCtx.GOOS + } + if knownArch(last) { + return last == buildCtx.GOARCH + } + return true +} + +// buildConstraintLines extracts each `//go:build` or `// +build` line +// from the file's leading comment block (everything up to the first +// blank line after a comment, per Go's build constraint placement rules). +// +// We're conservative: scan up to the first non-comment, non-blank line +// after we've seen any meaningful content, then stop. This catches both +// new-style and legacy constraints in practice. +func buildConstraintLines(content string) []string { + var out []string + for _, raw := range strings.Split(content, "\n") { + line := strings.TrimSpace(raw) + if line == "" { + continue + } + // Stop at the package declaration — constraints must appear + // before package per Go convention. + if strings.HasPrefix(line, "package ") || line == "package" { + break + } + if !strings.HasPrefix(line, "//") { + continue + } + if constraint.IsGoBuild(line) || constraint.IsPlusBuild(line) { + out = append(out, line) + } + } + return out +} + +// matchTag evaluates a single build tag against the build context. +func matchTag(buildCtx build.Context, tag string) bool { + switch tag { + case "ignore": + return false + case buildCtx.GOOS: + return true + case buildCtx.GOARCH: + return true + case "cgo": + return buildCtx.CgoEnabled + case "unix": + return knownUnixOS(buildCtx.GOOS) + } + for _, t := range buildCtx.BuildTags { + if t == tag { + return true + } + } + for _, t := range buildCtx.ToolTags { + if t == tag { + return true + } + } + for _, t := range buildCtx.ReleaseTags { + if t == tag { + return true + } + } + return false +} + +// knownOS / knownArch / knownUnixOS mirror the lists Go's build package +// uses to identify GOOS/GOARCH filename suffixes. Limited to the +// recognized values so unrelated underscore-separated stems +// (`server_handler.go`) don't trigger filtering. +func knownOS(s string) bool { _, ok := knownOSes[s]; return ok } +func knownArch(s string) bool { _, ok := knownArches[s]; return ok } +func knownUnixOS(s string) bool { _, ok := unixOSes[s]; return ok } + +// knownOSes lists every Go-recognized GOOS value as of Go 1.22. New +// values can be added without breaking existing files since unknown +// suffixes already fall through to "match". +var knownOSes = map[string]bool{ + "aix": true, "android": true, "darwin": true, "dragonfly": true, + "freebsd": true, "hurd": true, "illumos": true, "ios": true, + "js": true, "linux": true, "nacl": true, "netbsd": true, + "openbsd": true, "plan9": true, "solaris": true, "wasip1": true, + "windows": true, "zos": true, +} + +var unixOSes = map[string]bool{ + "aix": true, "android": true, "darwin": true, "dragonfly": true, + "freebsd": true, "hurd": true, "illumos": true, "ios": true, + "linux": true, "netbsd": true, "openbsd": true, "solaris": true, +} + +var knownArches = map[string]bool{ + "386": true, "amd64": true, "amd64p32": true, "arm": true, + "arm64": true, "arm64be": true, "armbe": true, "loong64": true, + "mips": true, "mips64": true, "mips64le": true, "mips64p32": true, + "mips64p32le": true, "mipsle": true, "ppc": true, "ppc64": true, + "ppc64le": true, "riscv": true, "riscv64": true, "s390": true, + "s390x": true, "sparc": true, "sparc64": true, "wasm": true, +} diff --git a/rewrite-go/pkg/parser/go_parser.go b/rewrite-go/pkg/parser/go_parser.go index 523e842ce43..5121fb6544f 100644 --- a/rewrite-go/pkg/parser/go_parser.go +++ b/rewrite-go/pkg/parser/go_parser.go @@ -19,10 +19,13 @@ package parser import ( "fmt" "go/ast" + "go/build" "go/importer" "go/parser" "go/token" "go/types" + "path/filepath" + "strconv" "strings" "github.com/google/uuid" @@ -34,57 +37,127 @@ type GoParser struct { // Importer resolves imported packages for type checking. // Defaults to importer.Default() which resolves stdlib packages. Importer types.Importer + + // BuildContext drives `//go:build` and filename-suffix constraint + // evaluation in ParsePackage. Defaults to build.Default (the host's + // GOOS/GOARCH). Recipe authors that need cross-platform analysis can + // set this explicitly via NewGoParserWithBuildContext. + BuildContext build.Context } func NewGoParser() *GoParser { return &GoParser{ - Importer: importer.Default(), + Importer: importer.Default(), + BuildContext: build.Default, + } +} + +// NewGoParserWithBuildContext returns a parser that filters input files +// against the given build context. Useful for recipes that need to +// analyze code as it would compile under a specific GOOS/GOARCH/cgo +// configuration. To switch contexts, build a new parser — A3 keeps +// BuildContext immutable per parser to avoid cache-key complexity. +func NewGoParserWithBuildContext(buildCtx build.Context) *GoParser { + return &GoParser{ + Importer: importer.Default(), + BuildContext: buildCtx, } } -// Parse parses the given Go source code and returns a CompilationUnit. +// FileInput is one file given to ParsePackage. +type FileInput struct { + Path string + Content string +} + +// Parse parses a single Go source file and returns its CompilationUnit. +// Convenience wrapper around ParsePackage for the common one-file case; +// type attribution that depends on sibling files in the same package +// won't resolve here. Use ParsePackage when sibling files matter. func (gp *GoParser) Parse(sourcePath string, source string) (*tree.CompilationUnit, error) { - fset := token.NewFileSet() - file, err := parser.ParseFile(fset, sourcePath, source, parser.ParseComments) + cus, err := gp.ParsePackage([]FileInput{{Path: sourcePath, Content: source}}) if err != nil { - return nil, fmt.Errorf("parse error: %w", err) + return nil, err + } + if len(cus) == 0 { + return nil, fmt.Errorf("no compilation unit produced") + } + return cus[0], nil +} + +// ParsePackage parses every file in a single Go package together so +// type-checking sees them as one unit. File A's reference to file B's +// symbol resolves; the resulting CompilationUnits share a single +// types.Info populated by one types.Config.Check call. +// +// All files MUST belong to the same package (same `package` clause). +// Order in the returned slice matches the input order. +func (gp *GoParser) ParsePackage(files []FileInput) ([]*tree.CompilationUnit, error) { + if len(files) == 0 { + return nil, nil + } + + // Filter out files excluded by the build context — `//go:build` / + // `// +build` constraints and OS/arch filename suffixes. Skipped + // files don't appear in the output at all (they're as if they + // weren't passed in). + filtered := make([]FileInput, 0, len(files)) + for _, f := range files { + if MatchBuildContext(gp.BuildContext, filepath.Base(f.Path), f.Content) { + filtered = append(filtered, f) + } + } + files = filtered + if len(files) == 0 { + return nil, nil + } + + fset := token.NewFileSet() + asts := make([]*ast.File, 0, len(files)) + for _, f := range files { + a, err := parser.ParseFile(fset, f.Path, f.Content, parser.ParseComments) + if err != nil { + return nil, fmt.Errorf("parse %s: %w", f.Path, err) + } + asts = append(asts, a) } - // Run type checking to populate type information typeInfo := &types.Info{ Types: make(map[ast.Expr]types.TypeAndValue), Defs: make(map[*ast.Ident]types.Object), Uses: make(map[*ast.Ident]types.Object), Selections: make(map[*ast.SelectorExpr]*types.Selection), } - conf := types.Config{ Importer: gp.Importer, // Don't fail on type errors — we want partial type info even when - // imports can't be resolved (single-file mode). - Error: func(err error) {}, + // some imports can't be resolved. + Error: func(error) {}, } - // Determine package name from the parsed file + // Use the first file's package name as the type-checker hint; + // types.Config.Check validates that all files agree. pkgName := "main" - if file.Name != nil { - pkgName = file.Name.Name + if asts[0].Name != nil { + pkgName = asts[0].Name.Name } + _, _ = conf.Check(pkgName, fset, asts, typeInfo) - // Type-check; errors are non-fatal (unresolvable imports are expected) - _, _ = conf.Check(pkgName, fset, []*ast.File{file}, typeInfo) - - ctx := &parseContext{ - src: []byte(source), - fset: fset, - file: fset.File(file.Pos()), - astFile: file, - cursor: 0, - typeInfo: typeInfo, - mapper: newTypeMapper(), + mapper := newTypeMapper() + cus := make([]*tree.CompilationUnit, 0, len(files)) + for i, f := range files { + ctx := &parseContext{ + src: []byte(f.Content), + fset: fset, + file: fset.File(asts[i].Pos()), + astFile: asts[i], + cursor: 0, + typeInfo: typeInfo, + mapper: mapper, + } + cus = append(cus, ctx.mapFile(asts[i], f.Path)) } - - return ctx.mapFile(file, sourcePath), nil + return cus, nil } // parseContext holds the state needed during AST-to-LST mapping. @@ -332,13 +405,16 @@ func (ctx *parseContext) mapGenDecl(decl *ast.GenDecl) tree.Statement { // mapVarConstDecl maps `var x int`, `var x = 5`, `const x = 5`, etc. func (ctx *parseContext) mapVarConstDecl(decl *ast.GenDecl) tree.Statement { prefix := ctx.prefix(decl.Pos()) + leadingAnns, prefix := extractDirectives(prefix) keyword := decl.Tok.String() ctx.skip(len(keyword)) if len(decl.Specs) == 1 && !decl.Lparen.IsValid() { // Single declaration: var x int = 5 spec := decl.Specs[0].(*ast.ValueSpec) - return ctx.mapValueSpec(spec, prefix, keyword) + vd := ctx.mapValueSpec(spec, prefix, keyword) + vd.LeadingAnnotations = leadingAnns + return vd } // Grouped declaration: var ( ... ) or const ( ... ) @@ -375,10 +451,11 @@ func (ctx *parseContext) mapVarConstDecl(decl *ast.GenDecl) tree.Statement { specs := &tree.Container[tree.Statement]{Before: lparenPrefix, Elements: elements} return &tree.VariableDeclarations{ - ID: uuid.New(), - Prefix: prefix, - Markers: tree.Markers{ID: uuid.New(), Entries: markerEntries}, - Specs: specs, + ID: uuid.New(), + Prefix: prefix, + Markers: tree.Markers{ID: uuid.New(), Entries: markerEntries}, + LeadingAnnotations: leadingAnns, + Specs: specs, } } @@ -468,10 +545,13 @@ func (ctx *parseContext) mapValueSpec(spec *ast.ValueSpec, prefix tree.Space, ke // mapTypeDecl maps a `type Name ...` declaration. func (ctx *parseContext) mapTypeDecl(decl *ast.GenDecl) tree.Statement { prefix := ctx.prefixAndSkip(decl.Pos(), len("type")) + leadingAnns, prefix := extractDirectives(prefix) if len(decl.Specs) == 1 && !decl.Lparen.IsValid() { spec := decl.Specs[0].(*ast.TypeSpec) - return ctx.mapTypeSpec(spec, prefix) + td := ctx.mapTypeSpec(spec, prefix) + td.LeadingAnnotations = leadingAnns + return td } // Grouped type declaration: type ( ... ) @@ -501,9 +581,10 @@ func (ctx *parseContext) mapTypeDecl(decl *ast.GenDecl) tree.Statement { specs := &tree.Container[tree.Statement]{Before: lparenPrefix, Elements: elements} return &tree.TypeDecl{ - ID: uuid.New(), - Prefix: prefix, - Specs: specs, + ID: uuid.New(), + Prefix: prefix, + LeadingAnnotations: leadingAnns, + Specs: specs, } } @@ -533,6 +614,7 @@ func (ctx *parseContext) mapTypeSpec(spec *ast.TypeSpec, prefix tree.Space) *tre // mapFuncDecl maps a function declaration. func (ctx *parseContext) mapFuncDecl(decl *ast.FuncDecl) *tree.MethodDeclaration { prefix := ctx.prefixAndSkip(decl.Pos(), len("func")) + leadingAnns, prefix := extractDirectives(prefix) var receiver *tree.Container[tree.Statement] if decl.Recv != nil && len(decl.Recv.List) > 0 { @@ -550,13 +632,14 @@ func (ctx *parseContext) mapFuncDecl(decl *ast.FuncDecl) *tree.MethodDeclaration } md := &tree.MethodDeclaration{ - ID: uuid.New(), - Prefix: prefix, - Receiver: receiver, - Name: name, - Parameters: params, - ReturnType: returnType, - Body: body, + ID: uuid.New(), + Prefix: prefix, + LeadingAnnotations: leadingAnns, + Receiver: receiver, + Name: name, + Parameters: params, + ReturnType: returnType, + Body: body, } // Type attribution for method declaration @@ -772,16 +855,56 @@ func (ctx *parseContext) mapFieldListAsParams(fl *ast.FieldList) tree.Container[ } // mapBlockStmt maps a block statement. +// +// Multi-statements-per-line (`_ = 1; _ = 2`) carry a literal `;` that +// Go's tokenizer recognizes but doesn't surface as part of either +// statement's AST. To round-trip the source, we look for an inline `;` +// between this statement and the next, capture the leading whitespace +// as RightPadded.After, mark the entry with a Semicolon marker, and +// advance the cursor past the `;`. The Block printer emits `;` when +// the marker is present. func (ctx *parseContext) mapBlockStmt(block *ast.BlockStmt) *tree.Block { prefix := ctx.prefix(block.Lbrace) ctx.skip(1) // "{" var stmts []tree.RightPadded[tree.Statement] - for _, stmt := range block.List { + for i, stmt := range block.List { mapped := ctx.mapStmt(stmt) - if mapped != nil { - stmts = append(stmts, tree.RightPadded[tree.Statement]{Element: mapped}) + if mapped == nil { + continue + } + rp := tree.RightPadded[tree.Statement]{Element: mapped} + + // Detect inline `;` separator. Go inserts implicit semicolons + // at end-of-line, so a literal `;` between two statements only + // appears when they share a source line (or when a `;` appears + // before the closing `}` on the last statement's line). We + // avoid scanning comments/strings for stray `;` bytes by + // gating on line numbers from the tokenizer. + stmtEndLine := ctx.file.Position(stmt.End()).Line + var nextStartLine int + if i+1 < len(block.List) { + nextStartLine = ctx.file.Position(block.List[i+1].Pos()).Line + } else { + nextStartLine = ctx.file.Position(block.Rbrace).Line } + if stmtEndLine == nextStartLine { + // Same line — look for the explicit `;`. + boundary := 0 + if i+1 < len(block.List) { + boundary = ctx.file.Offset(block.List[i+1].Pos()) + } else { + boundary = ctx.file.Offset(block.Rbrace) + } + semiOffset := ctx.findNextBefore(';', boundary) + if semiOffset >= 0 { + rp.After = ctx.prefix(ctx.file.Pos(semiOffset)) + ctx.skip(1) // consume ";" + rp.Markers = tree.AddMarker(rp.Markers, tree.NewSemicolon()) + } + } + + stmts = append(stmts, rp) } end := ctx.prefix(block.Rbrace) @@ -2376,14 +2499,247 @@ func (ctx *parseContext) mapFieldListAsStructBody(fl *ast.FieldList) *tree.Block return &tree.Block{ID: uuid.New(), Prefix: blockPrefix, Statements: stmts, End: end} } -// mapStructTag maps a struct field tag (e.g., `json:"name"`). -// Tags are stored as markers on the VariableDeclarations. +// mapStructTag parses a struct field tag literal into a sequence of +// Annotations — one per `key:"value"` pair — and attaches them to +// vd.LeadingAnnotations. +// +// Mirrors `reflect.StructTag.Lookup` parsing semantics: leading spaces +// between pairs are skipped (any number), keys run up to a colon, +// values are double-quoted strings respecting Go escape sequences. +// +// Whitespace policy (Option 1, lossy on non-canonical input): +// - The space between the field type and the opening backtick goes +// onto the first annotation's Prefix. +// - Whitespace between two `key:"value"` pairs goes onto the next +// annotation's Prefix. +// - Whitespace IMMEDIATELY inside the backticks (e.g., +// ` json:"x" `) is dropped — gofmt produces zero inner padding, +// so this only affects hand-typed weird input. Roundtrip on +// gofmt'd input is exact. func (ctx *parseContext) mapStructTag(vd *tree.VariableDeclarations, tag *ast.BasicLit) { - tagLit := ctx.mapBasicLit(tag) - vd.Markers.Entries = append(vd.Markers.Entries, tree.StructTag{ - Ident: uuid.New(), - Tag: tagLit, - }) + outerPrefix := ctx.prefix(tag.Pos()) + ctx.skip(len(tag.Value)) + + // tag.Value includes the wrapping backticks (or quotes). + raw := tag.Value + if len(raw) >= 2 { + first, last := raw[0], raw[len(raw)-1] + if (first == '`' && last == '`') || (first == '"' && last == '"') { + raw = raw[1 : len(raw)-1] + } + } + + pairs := parseStructTagPairs(raw) + if len(pairs) == 0 { + return + } + + annotations := make([]*tree.Annotation, len(pairs)) + for i, p := range pairs { + var annPrefix tree.Space + if i == 0 { + annPrefix = outerPrefix + } else { + annPrefix = tree.Space{Whitespace: p.PrefixWS} + } + annotations[i] = &tree.Annotation{ + ID: uuid.New(), + Prefix: annPrefix, + AnnotationType: &tree.Identifier{ + ID: uuid.New(), + Name: p.Key, + }, + Arguments: &tree.Container[tree.Expression]{ + Elements: []tree.RightPadded[tree.Expression]{ + {Element: &tree.Literal{ + ID: uuid.New(), + Source: p.QuotedValue, + Value: p.UnquotedValue, + Kind: tree.StringLiteral, + }}, + }, + }, + } + } + vd.LeadingAnnotations = annotations +} + +// extractDirectives splits a Space's leading line-comments into Annotation +// nodes when they match Go directive syntax (`//go:NAME [args]`, +// `//lint:NAME [args]`). The returned residual Space holds whatever +// wasn't extracted: the whitespace after the last extracted directive, +// plus any comments past the first non-directive (or block-comment). +// +// Used by the parser at top-level decl entry points (func, type, +// var/const) to populate `LeadingAnnotations` and shrink the decl's +// own Prefix to the whitespace between last directive and keyword. +func extractDirectives(s tree.Space) (anns []*tree.Annotation, residual tree.Space) { + if len(s.Comments) == 0 { + return nil, s + } + pendingPrefixWS := s.Whitespace + i := 0 + for i < len(s.Comments) { + c := s.Comments[i] + if c.Kind != tree.LineComment { + break + } + name, args, ok := parseDirective(c.Text) + if !ok { + break + } + anns = append(anns, buildDirectiveAnnotation(name, args, tree.Space{Whitespace: pendingPrefixWS})) + pendingPrefixWS = c.Suffix + i++ + } + if len(anns) == 0 { + return nil, s + } + residual = tree.Space{ + Whitespace: pendingPrefixWS, + Comments: s.Comments[i:], + } + return anns, residual +} + +// parseDirective tries to parse a `//PREFIX:NAME [ARGS]` line into +// (name, args, ok). The full directive name returned is `PREFIX:NAME` +// — preserved exactly as authors write it (`go:noinline`, +// `lint:ignore`). `args` is the trimmed text after the first space (or +// "" when absent). +// +// Recognized prefixes: `go`, `lint`. (Other vendor-specific prefixes +// like `nolint` aren't of the form `PREFIX:NAME` and are left as +// regular comments.) +func parseDirective(text string) (name, args string, ok bool) { + if !strings.HasPrefix(text, "//") { + return "", "", false + } + inner := text[2:] + colonIdx := strings.Index(inner, ":") + if colonIdx <= 0 { + return "", "", false + } + prefix := inner[:colonIdx] + if !isDirectivePrefix(prefix) { + return "", "", false + } + rest := inner[colonIdx+1:] + spaceIdx := strings.IndexAny(rest, " \t") + if spaceIdx < 0 { + if rest == "" { + return "", "", false + } + return prefix + ":" + rest, "", true + } + dirName := rest[:spaceIdx] + if dirName == "" { + return "", "", false + } + dirArgs := strings.TrimLeft(rest[spaceIdx:], " \t") + return prefix + ":" + dirName, dirArgs, true +} + +func isDirectivePrefix(p string) bool { + switch p { + case "go", "lint": + return true + } + return false +} + +func buildDirectiveAnnotation(name, args string, prefix tree.Space) *tree.Annotation { + ann := &tree.Annotation{ + ID: uuid.New(), + Prefix: prefix, + AnnotationType: &tree.Identifier{ + ID: uuid.New(), + Name: name, + }, + } + if args != "" { + ann.Arguments = &tree.Container[tree.Expression]{ + Before: tree.Space{Whitespace: " "}, + Elements: []tree.RightPadded[tree.Expression]{ + {Element: &tree.Literal{ + ID: uuid.New(), + Source: args, + Value: args, + Kind: tree.StringLiteral, + }}, + }, + } + } + return ann +} + +// structTagPair is one parsed `key:"value"` pair from a struct tag. +type structTagPair struct { + PrefixWS string // whitespace consumed before this pair (used only for non-first pairs) + Key string + QuotedValue string // the value source including its surrounding quotes (e.g. `"name"`) + UnquotedValue string // the value contents after Go-string unquoting +} + +// parseStructTagPairs scans a struct tag's contents (without the +// surrounding backticks) into a sequence of `key:"value"` pairs. +// Mirrors `reflect.StructTag.Lookup`'s scanning loop: +// - Skip ASCII whitespace. +// - Read key (printable, non-quote, non-colon, non-control). +// - Expect `:"`, read quoted string respecting `\` escapes. +// +// Returns whatever pairs it parsed up to the first malformed section; +// gofmt'd input is always well-formed but defensive scanning matches +// stdlib behavior. +func parseStructTagPairs(tag string) []structTagPair { + var pairs []structTagPair + i := 0 + for i < len(tag) { + // Skip leading whitespace. + prefStart := i + for i < len(tag) && (tag[i] == ' ' || tag[i] == '\t' || tag[i] == '\n' || tag[i] == '\r') { + i++ + } + prefixWS := tag[prefStart:i] + if i == len(tag) { + break + } + // Read key. + keyStart := i + for i < len(tag) && tag[i] > ' ' && tag[i] != ':' && tag[i] != '"' && tag[i] != 0x7f { + i++ + } + if i == keyStart || i+1 >= len(tag) || tag[i] != ':' || tag[i+1] != '"' { + break + } + key := tag[keyStart:i] + i++ // skip `:` + // Read quoted value. + valueStart := i + i++ // skip opening `"` + for i < len(tag) && tag[i] != '"' { + if tag[i] == '\\' { + i++ + } + i++ + } + if i >= len(tag) { + break + } + i++ // skip closing `"` + quotedValue := tag[valueStart:i] + unquoted, err := strconv.Unquote(quotedValue) + if err != nil { + break + } + pairs = append(pairs, structTagPair{ + PrefixWS: prefixWS, + Key: key, + QuotedValue: quotedValue, + UnquotedValue: unquoted, + }) + } + return pairs } // mapFieldListAsInterfaceBody maps an interface's method list to a Block. diff --git a/rewrite-go/pkg/parser/gomod_parser.go b/rewrite-go/pkg/parser/gomod_parser.go new file mode 100644 index 00000000000..87991a97e2d --- /dev/null +++ b/rewrite-go/pkg/parser/gomod_parser.go @@ -0,0 +1,150 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package parser + +import ( + "log" + "regexp" + "strings" + + "golang.org/x/mod/modfile" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +// ParseGoMod parses go.mod content into a tree.GoResolutionResult. +// Mirrors org.openrewrite.golang.GoModParser on the Java side. +func ParseGoMod(path, content string) (*tree.GoResolutionResult, error) { + f, err := modfile.Parse(path, []byte(content), nil) + if err != nil { + return nil, err + } + + mrr := tree.NewGoResolutionResult("", "", "", path) + + if f.Module != nil { + mrr.ModulePath = f.Module.Mod.Path + } + if f.Go != nil { + mrr.GoVersion = f.Go.Version + } + if f.Toolchain != nil { + mrr.Toolchain = f.Toolchain.Name + } + + for _, r := range f.Require { + mrr.Requires = append(mrr.Requires, tree.GoRequire{ + ModulePath: r.Mod.Path, + Version: r.Mod.Version, + Indirect: r.Indirect, + }) + } + + for _, r := range f.Replace { + mrr.Replaces = append(mrr.Replaces, tree.GoReplace{ + OldPath: r.Old.Path, + OldVersion: r.Old.Version, + NewPath: r.New.Path, + NewVersion: r.New.Version, + }) + } + + for _, e := range f.Exclude { + mrr.Excludes = append(mrr.Excludes, tree.GoExclude{ + ModulePath: e.Mod.Path, + Version: e.Mod.Version, + }) + } + + for _, r := range f.Retract { + // modfile.Retract gives Low and High; if equal, it's a single + // version, otherwise it's a range. Format the range expression to + // match what the Java side stores. + var rng string + if r.Low == r.High { + rng = r.Low + } else { + rng = "[" + r.Low + ", " + r.High + "]" + } + mrr.Retracts = append(mrr.Retracts, tree.GoRetract{ + VersionRange: rng, + Rationale: strings.TrimSpace(r.Rationale), + }) + } + + return &mrr, nil +} + +// goSumLine matches one line of a go.sum file: +// +// [/go.mod] h1: +// +// Each module version appears on two lines — one for the module zip, one +// for its go.mod. Mirrors the Java GO_SUM_LINE pattern. +var goSumLine = regexp.MustCompile(`^\s*(\S+)\s+(\S+?)(/go\.mod)?\s+h1:(\S+)\s*$`) + +// ParseGoSum parses go.sum content into a slice of GoResolvedDependency, +// one per (module, version) pair. Bad lines are logged and skipped — go.sum +// is best-effort metadata, not an authoritative spec; a single malformed +// line should never tank a parse. +// +// Mirrors org.openrewrite.golang.GoModParser#parseSumSibling. The Go side +// is content-based (not filesystem-based) because the parser is invoked via +// RPC where sources are passed as strings. +func ParseGoSum(content string) []tree.GoResolvedDependency { + if content == "" { + return nil + } + type slot struct{ module, gomod string } + order := []string{} + byKey := map[string]*slot{} + for i, line := range strings.Split(content, "\n") { + if strings.TrimSpace(line) == "" { + continue + } + m := goSumLine.FindStringSubmatch(line) + if m == nil { + log.Printf("go.sum line %d: skipping malformed entry: %q", i+1, line) + continue + } + module, version, isGoMod, hash := m[1], m[2], m[3] != "", "h1:"+m[4] + key := module + "@" + version + s, ok := byKey[key] + if !ok { + s = &slot{} + byKey[key] = s + order = append(order, key) + } + if isGoMod { + s.gomod = hash + } else { + s.module = hash + } + } + out := make([]tree.GoResolvedDependency, 0, len(order)) + for _, key := range order { + s := byKey[key] + parts := strings.SplitN(key, "@", 2) + out = append(out, tree.GoResolvedDependency{ + ModulePath: parts[0], + Version: parts[1], + ModuleHash: s.module, + GoModHash: s.gomod, + }) + } + return out +} diff --git a/rewrite-go/pkg/parser/project_importer.go b/rewrite-go/pkg/parser/project_importer.go new file mode 100644 index 00000000000..d848f51a919 --- /dev/null +++ b/rewrite-go/pkg/parser/project_importer.go @@ -0,0 +1,306 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package parser + +import ( + "fmt" + "go/ast" + "go/importer" + "go/parser" + "go/token" + "go/types" + "log" + "os" + "path" + "path/filepath" + "strings" +) + +// ProjectImporter resolves Go imports for a parsed file against four +// layers, in order: +// +// 1. Sibling sources within the same project (registered via AddSource). +// Yields real *types.Package objects from parsing + type-checking. +// 2. Vendored sources at `/vendor//*.go`. When +// a `replace` directive applies, the lookup target is resolved +// accordingly: local replace (`./` / `../` prefix) walks the local +// path; module-path replace walks `vendor//`. +// Yields real *types.Package objects with full method/field types. +// 3. Modules declared in go.mod's `require` directives (registered via +// AddRequire). Yields a STUB *types.Package — right path and name, +// empty scope — so references like `import "github.com/x/y"` make +// the identifier `y` non-nil even when the module's source isn't +// present locally. +// 4. The fallback (importer.Default by default), which resolves stdlib +// packages from GOROOT. +// +// Mirrors the role of MavenProject/JavaSourceSet classpath resolution on +// the Java side: when a recipe parses a Go file inside a project, imports +// of `/` resolve against that sub-package's parsed +// sources, vendored deps resolve against on-disk files, and requires +// without vendor sources fall back to typed-but-empty stubs. +// +// Vendor walking is lazy on each Import() call (matching the existing +// 3-tier resolver's laziness). No eager startup walk. +type ProjectImporter struct { + modulePath string + projectRoot string // absolute or relative path to the dir containing go.mod + // sources keyed by full import path (e.g. "example.com/foo/sub"). + sources map[string][]projectFile + // requires lists module paths declared in go.mod `require` directives. + // An import path matches when it equals one of these OR is under one + // of them as a sub-path (require x/y also covers x/y/z imports). + requires []string + // replaces maps each old import path to its replacement target. + // Built by AddReplace from the parsed go.mod replace directives. + replaces map[string]replaceTarget + cache map[string]*types.Package + fset *token.FileSet + fallback types.Importer +} + +// replaceTarget mirrors a go.mod `replace ... => newPath [newVersion]` +// entry. NewVersion is empty when NewPath is a local filesystem path. +type replaceTarget struct { + NewPath string + NewVersion string +} + +type projectFile struct { + path string // module-relative, e.g. "main.go" or "sub/sub.go" + content string +} + +// NewProjectImporter creates an importer rooted at the given module path. +// Pass importer.Default() (or nil for the same default) as the stdlib +// fallback. Project root is unset and must be configured via +// SetProjectRoot for vendor walking to find anything. +func NewProjectImporter(modulePath string, fallback types.Importer) *ProjectImporter { + if fallback == nil { + fallback = importer.Default() + } + return &ProjectImporter{ + modulePath: modulePath, + sources: make(map[string][]projectFile), + replaces: make(map[string]replaceTarget), + cache: make(map[string]*types.Package), + fset: token.NewFileSet(), + fallback: fallback, + } +} + +// SetProjectRoot configures the directory the vendor walker scans +// relative to. Without this set, vendor lookups always miss and the +// resolver falls through to the require-stub tier. Pass the directory +// containing the project's go.mod. +func (p *ProjectImporter) SetProjectRoot(root string) { + p.projectRoot = root +} + +// AddReplace registers a go.mod `replace oldPath [oldVersion] => newPath [newVersion]` +// entry. At Import() time, requests for oldPath (or sub-paths under it) +// are redirected to newPath. Local-path replacements (`./` / `../`) +// resolve against the project root; module-path replacements resolve +// against `vendor//`. +func (p *ProjectImporter) AddReplace(oldPath, newPath, newVersion string) { + if oldPath == "" || newPath == "" { + return + } + p.replaces[oldPath] = replaceTarget{NewPath: newPath, NewVersion: newVersion} +} + +// AddRequire registers a module path declared in go.mod's `require` list. +// Imports of this path (or any sub-path under it) that aren't already +// satisfied by AddSource'd sibling sources resolve to a stub +// *types.Package — non-nil, with the right path and name, but with an +// empty scope. Real method/field types still need the module's actual +// sources (vendor dir or go-mod cache walk; not done yet). +func (p *ProjectImporter) AddRequire(modulePath string) { + if modulePath != "" { + p.requires = append(p.requires, modulePath) + } +} + +// AddSource registers a .go file with the importer. relPath is the file's +// path relative to the module root, e.g. "main.go" or "sub/sub.go". Only +// .go files are indexed; anything else is ignored. +func (p *ProjectImporter) AddSource(relPath, content string) { + if !strings.HasSuffix(relPath, ".go") { + return + } + dir := path.Dir(relPath) + importPath := p.modulePath + if dir != "" && dir != "." { + importPath = p.modulePath + "/" + dir + } + p.sources[importPath] = append(p.sources[importPath], projectFile{path: relPath, content: content}) +} + +// Import implements types.Importer. +func (p *ProjectImporter) Import(importPath string) (*types.Package, error) { + if cached, ok := p.cache[importPath]; ok { + return cached, nil + } + if files, ok := p.sources[importPath]; ok { + pkg, err := p.parsePackage(importPath, files) + if err != nil { + return nil, err + } + p.cache[importPath] = pkg + return pkg, nil + } + // Vendor walker: real on-disk source resolution. When a replace + // directive applies, the walker follows it; otherwise it looks for + // `/vendor//`. Parse failures are logged + // and the resolver falls through to the require-stub tier so a + // broken vendored dep doesn't tank the parent parse (per the C4 + // directive in the eng review). + if p.projectRoot != "" { + if pkg := p.walkVendor(importPath); pkg != nil { + p.cache[importPath] = pkg + return pkg, nil + } + } + // Stub-resolve any path declared in go.mod requires (or under one). + // Real symbols stay missing, but the package object itself is non-nil + // so identifiers referencing it have a Package type. The package must + // be marked complete; otherwise the type-checker treats the import + // as not yet loaded and reports `undefined: ` at use sites. + for _, req := range p.requires { + if importPath == req || strings.HasPrefix(importPath, req+"/") { + stub := types.NewPackage(importPath, path.Base(importPath)) + stub.MarkComplete() + p.cache[importPath] = stub + return stub, nil + } + } + return p.fallback.Import(importPath) +} + +// walkVendor attempts to resolve importPath against on-disk sources at +// `/vendor//` (or the equivalent path under a +// matching `replace` directive). Returns nil if no sources are found or +// if parsing failed (in which case the resolver should fall through to +// the stub tier). +func (p *ProjectImporter) walkVendor(importPath string) *types.Package { + dir := p.resolveVendorDir(importPath) + if dir == "" { + return nil + } + files, err := readGoFilesIn(dir) + if err != nil || len(files) == 0 { + return nil + } + pkg, err := p.parsePackage(importPath, files) + if err != nil { + log.Printf("vendor walker: skip %s (parse error: %v) — falling back to stub", importPath, err) + return nil + } + return pkg +} + +// resolveVendorDir maps an import path to the on-disk directory the +// vendor walker should scan. Returns "" when no resolution applies. +// +// Replace-directive resolution honored: +// - `replace foo => ./local/foo` → walk `/local/foo` +// - `replace foo => bar` (module) → walk `/vendor/bar` +// +// Without a matching replace, the default vendor dir is +// `/vendor/`. Sub-package imports +// (`require foo/bar` + `import "foo/bar/sub"`) walk the same way: +// `/vendor/foo/bar/sub`. +func (p *ProjectImporter) resolveVendorDir(importPath string) string { + target := importPath + for old, repl := range p.replaces { + if importPath == old || strings.HasPrefix(importPath, old+"/") { + suffix := strings.TrimPrefix(importPath, old) + if isLocalReplace(repl.NewPath) { + return filepath.Join(p.projectRoot, filepath.FromSlash(repl.NewPath+suffix)) + } + target = repl.NewPath + suffix + break + } + } + return filepath.Join(p.projectRoot, "vendor", filepath.FromSlash(target)) +} + +func isLocalReplace(newPath string) bool { + return strings.HasPrefix(newPath, "./") || + strings.HasPrefix(newPath, "../") || + newPath == "." || + filepath.IsAbs(newPath) +} + +// readGoFilesIn reads every non-test .go file in dir as a projectFile. +// Returns nil + error if dir doesn't exist or can't be listed; returns +// nil + nil (interpreted as "not vendored") when the dir is empty of .go +// files. +func readGoFilesIn(dir string) ([]projectFile, error) { + entries, err := os.ReadDir(dir) + if err != nil { + return nil, err + } + var out []projectFile + for _, e := range entries { + if e.IsDir() { + continue + } + name := e.Name() + if !strings.HasSuffix(name, ".go") || strings.HasSuffix(name, "_test.go") { + continue + } + full := filepath.Join(dir, name) + data, err := os.ReadFile(full) + if err != nil { + return nil, err + } + out = append(out, projectFile{path: full, content: string(data)}) + } + return out, nil +} + +// parsePackage parses + type-checks a sibling package's files. +func (p *ProjectImporter) parsePackage(importPath string, files []projectFile) (*types.Package, error) { + asts := make([]*ast.File, 0, len(files)) + for _, f := range files { + a, err := parser.ParseFile(p.fset, f.path, f.content, parser.ParseComments) + if err != nil { + return nil, fmt.Errorf("parse %s: %w", f.path, err) + } + asts = append(asts, a) + } + conf := types.Config{ + // Recurse: a sibling package may itself import another sibling. + Importer: p, + // Type errors in sibling code shouldn't break the caller's parse — + // we still want partial type info. + Error: func(error) {}, + } + info := &types.Info{ + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + } + // types.Config.Check's first argument becomes the resulting package's + // Path() — e.g. "github.com/x/y". This MUST be the full import path, + // not just the package name, because mapType sets a method's + // DeclaringType.FullyQualifiedName from pkg.Path() at type-mapping + // time. With "y" alone, vendored fmt-like methods would resolve to + // FQN "y" instead of "github.com/x/y". + pkg, _ := conf.Check(importPath, p.fset, asts, info) + return pkg, nil +} diff --git a/rewrite-go/pkg/parser/type_mapper.go b/rewrite-go/pkg/parser/type_mapper.go index 9cdd66caa2b..26578b1d792 100644 --- a/rewrite-go/pkg/parser/type_mapper.go +++ b/rewrite-go/pkg/parser/type_mapper.go @@ -333,6 +333,26 @@ func (m *typeMapper) mapObject(obj types.Object) tree.JavaType { if fn, ok := obj.(*types.Func); ok { return m.mapMethodObject(fn) } + // Package-alias identifiers (the `y` in `y.Hello()` after + // `import "github.com/x/y"`) come back as *types.PkgName, whose .Type() + // is the Invalid sentinel. Map them to a JavaTypeClass tagged with the + // imported package's path so recipes can recognize the reference even + // when the package's symbols aren't loaded. + if pn, ok := obj.(*types.PkgName); ok { + imported := pn.Imported() + if imported == nil { + return nil + } + // JavaType.FullyQualified.Kind is a fixed enum (Class, Enum, + // Interface, Annotation, Record, Value) without a Package value. + // Map package aliases to "Class" with the import path as the FQN; + // recipes can recognize package references by the FQN containing + // path separators (e.g. "github.com/x/y"). + return &tree.JavaTypeClass{ + Kind: "Class", + FullyQualifiedName: imported.Path(), + } + } return m.mapType(obj.Type()) } diff --git a/rewrite-go/pkg/printer/go_printer.go b/rewrite-go/pkg/printer/go_printer.go index 4f4748fbc84..23578e6da66 100644 --- a/rewrite-go/pkg/printer/go_printer.go +++ b/rewrite-go/pkg/printer/go_printer.go @@ -145,7 +145,13 @@ func (p *GoPrinter) VisitBlock(block *tree.Block, param any) tree.J { out.Append("{") for _, rp := range block.Statements { p.Visit(rp.Element, out) + // If the source had `;` separating this statement from the + // next, the parser captured the leading space as After and + // stamped a Semicolon marker on the RightPadded. p.visitSpace(rp.After, out) + if tree.FindMarker[tree.Semicolon](rp.Markers) != nil { + out.Append(";") + } } p.visitSpace(block.End, out) out.Append("}") @@ -205,6 +211,15 @@ func (p *GoPrinter) VisitAssignment(assign *tree.Assignment, param any) tree.J { func (p *GoPrinter) VisitMethodDeclaration(md *tree.MethodDeclaration, param any) tree.J { out := param.(*PrintOutputCapture) + // Each leading annotation emits as a `//[ ]` line. The + // annotation's Prefix supplies the whitespace before `//` (newline + // + indent for non-first directives). The newline that follows the + // last directive lives on md.Prefix below. + for _, ann := range md.LeadingAnnotations { + p.visitSpace(ann.Prefix, out) + out.Append("//") + p.printDirectiveBody(ann, out) + } p.beforeSyntax(md.Prefix, md.Markers, out) isInterfaceMethod := tree.FindMarker[tree.InterfaceMethod](md.Markers) != nil if !isInterfaceMethod { @@ -293,6 +308,17 @@ func (p *GoPrinter) VisitMethodInvocation(mi *tree.MethodInvocation, param any) func (p *GoPrinter) VisitVariableDeclarations(vd *tree.VariableDeclarations, param any) tree.J { out := param.(*PrintOutputCapture) + // Non-struct context: any leading annotations are `//go:`-style + // directives, emitted before the var/const keyword. Struct-field + // context handles annotations later (after the type) as a + // backtick-wrapped tag. + if !p.insideStructType() { + for _, ann := range vd.LeadingAnnotations { + p.visitSpace(ann.Prefix, out) + out.Append("//") + p.printDirectiveBody(ann, out) + } + } p.beforeSyntax(vd.Prefix, vd.Markers, out) isGroupedSpec := tree.FindMarker[tree.GroupedSpec](vd.Markers) != nil if !isGroupedSpec { @@ -336,9 +362,23 @@ func (p *GoPrinter) VisitVariableDeclarations(vd *tree.VariableDeclarations, par if vd.TypeExpr != nil { p.Visit(vd.TypeExpr, out) } - // Then struct tag if present - if tag := tree.FindMarker[tree.StructTag](vd.Markers); tag != nil { - p.Visit(tag.Tag, out) + // Then struct tag, reconstructed from LeadingAnnotations (one + // Annotation per `key:"value"` pair). Only emitted when this + // VariableDeclarations is a struct field — non-struct positions + // don't allow tags syntactically. Inner-leading / inner-trailing + // whitespace is normalized to gofmt's canonical zero-padding (we + // chose Option 1 in the design discussion: lossy on non-canonical + // input, exact on gofmt'd input). + if len(vd.LeadingAnnotations) > 0 && p.insideStructType() { + first := vd.LeadingAnnotations[0] + p.visitSpace(first.Prefix, out) + out.Append("`") + p.printAnnotationBody(first, out) + for _, ann := range vd.LeadingAnnotations[1:] { + p.visitSpace(ann.Prefix, out) + p.printAnnotationBody(ann, out) + } + out.Append("`") } // Then initializers firstInit := true @@ -508,6 +548,78 @@ func (p *GoPrinter) VisitForEachControl(control *tree.ForEachControl, param any) return control } +// VisitAnnotation prints an Annotation in struct-tag form +// (`key:"value"`) — including its leading whitespace via Prefix. +// Backtick wrapping is the VariableDeclarations printer's job for +// struct-field context; this method only emits the annotation's own +// substring. +func (p *GoPrinter) VisitAnnotation(ann *tree.Annotation, param any) tree.J { + out := param.(*PrintOutputCapture) + p.beforeSyntax(ann.Prefix, ann.Markers, out) + p.printAnnotationBody(ann, out) + p.afterSyntax(ann.Markers, out) + return ann +} + +// printAnnotationBody emits an annotation's body content in struct-tag +// form (type, colon, arguments) without the leading Prefix. Used by +// VariableDeclarations to lay out a backtick-wrapped struct tag where +// the first annotation's Prefix lives outside the backticks. +func (p *GoPrinter) printAnnotationBody(ann *tree.Annotation, out *PrintOutputCapture) { + if ann.AnnotationType != nil { + p.Visit(ann.AnnotationType, out) + } + if ann.Arguments != nil { + p.visitSpace(ann.Arguments.Before, out) + out.Append(":") + for _, rp := range ann.Arguments.Elements { + p.Visit(rp.Element, out) + p.visitSpace(rp.After, out) + } + } +} + +// printDirectiveBody emits an annotation's body in source-directive +// form: `[ ]`. Used to render `//go:noinline`, +// `//go:linkname x runtime.x`, `//lint:ignore`, etc., on +// MethodDeclaration / TypeDecl / top-level VariableDeclarations. The +// preceding `//` and the annotation's leading Prefix are emitted by +// the caller; this helper only produces the substring after `//`. +// +// Arguments are emitted as their raw source (single Literal whose +// Source field carries the rest-of-line text). The space between the +// directive name and its arguments lives on the Arguments.Before slot +// — typically a single space. +func (p *GoPrinter) printDirectiveBody(ann *tree.Annotation, out *PrintOutputCapture) { + if ann.AnnotationType != nil { + p.Visit(ann.AnnotationType, out) + } + if ann.Arguments != nil { + p.visitSpace(ann.Arguments.Before, out) + for _, rp := range ann.Arguments.Elements { + p.Visit(rp.Element, out) + p.visitSpace(rp.After, out) + } + } +} + +// insideStructType reports whether the cursor's value sits inside a +// StructType ancestor — i.e., it's a struct field rather than a +// top-level / local / parameter declaration. Drives the struct-tag +// rendering decision in VisitVariableDeclarations. +func (p *GoPrinter) insideStructType() bool { + c := p.Cursor() + if c == nil { + return false + } + for cur := c.Parent(); cur != nil; cur = cur.Parent() { + if _, ok := cur.Value().(*tree.StructType); ok { + return true + } + } + return false +} + func (p *GoPrinter) VisitUnary(unary *tree.Unary, param any) tree.J { out := param.(*PrintOutputCapture) p.beforeSyntax(unary.Prefix, unary.Markers, out) @@ -920,6 +1032,11 @@ func (p *GoPrinter) VisitInterfaceType(it *tree.InterfaceType, param any) tree.J func (p *GoPrinter) VisitTypeDecl(td *tree.TypeDecl, param any) tree.J { out := param.(*PrintOutputCapture) + for _, ann := range td.LeadingAnnotations { + p.visitSpace(ann.Prefix, out) + out.Append("//") + p.printDirectiveBody(ann, out) + } p.beforeSyntax(td.Prefix, td.Markers, out) if tree.FindMarker[tree.GroupedSpec](td.Markers) == nil { out.Append("type") diff --git a/rewrite-go/pkg/printer/parity_test.go b/rewrite-go/pkg/printer/parity_test.go new file mode 100644 index 00000000000..26023447dfc --- /dev/null +++ b/rewrite-go/pkg/printer/parity_test.go @@ -0,0 +1,108 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//go:build parityaudit + +package printer_test + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + "github.com/openrewrite/rewrite/rewrite-go/pkg/printer" +) + +// TestPrinterCorpus walks every .go file under +// test/printer-corpus/ and asserts parse → print byte-equality. Each +// failure is a printer bug to fix in pkg/printer/go_printer.go. +// +// Gated behind the `parityaudit` build tag (see Makefile target +// `make parity`). Default test runs skip this entirely so corpus +// triage doesn't block CI. +func TestPrinterCorpus(t *testing.T) { + root := findCorpusRoot(t) + + var fixtures []string + if err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + if strings.HasSuffix(path, ".go") { + fixtures = append(fixtures, path) + } + return nil + }); err != nil { + t.Fatalf("walk corpus: %v", err) + } + if len(fixtures) == 0 { + t.Fatalf("no .go fixtures found under %s", root) + } + + for _, f := range fixtures { + f := f + name, _ := filepath.Rel(root, f) + t.Run(name, func(t *testing.T) { + content, err := os.ReadFile(f) + if err != nil { + t.Fatalf("read: %v", err) + } + cu, err := parser.NewGoParser().Parse(filepath.Base(f), string(content)) + if err != nil { + t.Fatalf("parse: %v", err) + } + printed := printer.Print(cu) + if printed != string(content) { + t.Errorf("byte-equality failed\n--- expected ---\n%s\n--- actual ---\n%s", + string(content), printed) + } + }) + } +} + +// findCorpusRoot resolves test/testdata/printer-corpus relative to this +// test file. The corpus lives under testdata/ so the standard Go tooling +// (`go test ./...`) skips it — Go treats `testdata/` as a magic +// directory that's ignored when discovering packages, even though our +// fixtures are valid `.go` files. +func findCorpusRoot(t *testing.T) string { + t.Helper() + // pkg/printer/parity_test.go → ../../test/testdata/printer-corpus + candidates := []string{ + filepath.Join("..", "..", "test", "testdata", "printer-corpus"), + } + for _, c := range candidates { + if info, err := os.Stat(c); err == nil && info.IsDir() { + return c + } + } + t.Fatalf("printer-corpus directory not found relative to %s", mustGetwd(t)) + return "" +} + +func mustGetwd(t *testing.T) string { + t.Helper() + wd, err := os.Getwd() + if err != nil { + t.Fatalf("getwd: %v", err) + } + return wd +} diff --git a/rewrite-go/pkg/recipe/data_table.go b/rewrite-go/pkg/recipe/data_table.go new file mode 100644 index 00000000000..4d50c37c6cc --- /dev/null +++ b/rewrite-go/pkg/recipe/data_table.go @@ -0,0 +1,308 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package recipe + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "os" + "path/filepath" + "reflect" + "regexp" + "strings" + "sync" +) + +// DataTableStoreKey is the ExecutionContext message key under which a +// DataTableStore is installed. Mirrors JS DATA_TABLE_STORE. +const DataTableStoreKey = "org.openrewrite.dataTables.store" + +// DataTableLike is the type-erased view of a DataTable that DataTableStore +// implementations work against. Generic DataTable[Row] satisfies this. +type DataTableLike interface { + Descriptor() DataTableDescriptor + InstanceName() string + Group() string +} + +// DataTableStore is where rows emitted by a recipe end up. The default +// in-memory store is created lazily on first InsertRow when no store has +// been installed in the ExecutionContext. +type DataTableStore interface { + InsertRow(dt DataTableLike, ctx *ExecutionContext, row any) + GetRows(dataTableName, group string) []any + GetDataTables() []DataTableLike +} + +// DataTable is a typed, strongly-bound handle a recipe holds onto and writes +// rows into. Row is the row type; recipes typically declare: +// +// var Findings = recipe.NewDataTable[FindingsRow]( +// "org.example.MyRecipe.Findings", +// "My findings", "What MyRecipe found", +// []recipe.ColumnDescriptor{...}) +type DataTable[Row any] struct { + descriptor DataTableDescriptor + instanceName string + group string +} + +// NewDataTable creates a DataTable handle. The instance name defaults to +// the display name; recipes may override it via SetInstanceName / +// SetGroup if multiple buckets within one recipe run are needed. +func NewDataTable[Row any](name, displayName, description string, columns []ColumnDescriptor) *DataTable[Row] { + return &DataTable[Row]{ + descriptor: DataTableDescriptor{ + Name: name, + DisplayName: displayName, + Description: description, + Columns: columns, + }, + instanceName: displayName, + } +} + +func (dt *DataTable[Row]) Descriptor() DataTableDescriptor { return dt.descriptor } +func (dt *DataTable[Row]) InstanceName() string { return dt.instanceName } +func (dt *DataTable[Row]) Group() string { return dt.group } +func (dt *DataTable[Row]) SetInstanceName(name string) { dt.instanceName = name } +func (dt *DataTable[Row]) SetGroup(group string) { dt.group = group } + +// InsertRow appends a row to the data table. If no DataTableStore has been +// installed in ctx, an InMemoryDataTableStore is created lazily. +func (dt *DataTable[Row]) InsertRow(ctx *ExecutionContext, row Row) { + store, ok := ctx.GetMessage(DataTableStoreKey) + if !ok { + store = NewInMemoryDataTableStore() + ctx.PutMessage(DataTableStoreKey, store) + } + if s, ok := store.(DataTableStore); ok { + s.InsertRow(dt, ctx, row) + } +} + +// --- InMemoryDataTableStore --- + +type bucket struct { + dt DataTableLike + rows []any +} + +// InMemoryDataTableStore holds rows in memory keyed by (dataTableName, group). +// Rows can be read back via GetRows. Default for tests and recipes that +// don't need disk-backed output. +type InMemoryDataTableStore struct { + mu sync.Mutex + buckets map[string]*bucket +} + +func NewInMemoryDataTableStore() *InMemoryDataTableStore { + return &InMemoryDataTableStore{buckets: map[string]*bucket{}} +} + +func bucketKey(dt DataTableLike) string { + suffix := dt.Group() + if suffix == "" { + suffix = dt.InstanceName() + } + return dt.Descriptor().Name + "\x00" + suffix +} + +func (s *InMemoryDataTableStore) InsertRow(dt DataTableLike, _ *ExecutionContext, row any) { + s.mu.Lock() + defer s.mu.Unlock() + key := bucketKey(dt) + b, ok := s.buckets[key] + if !ok { + b = &bucket{dt: dt} + s.buckets[key] = b + } + b.rows = append(b.rows, row) +} + +func (s *InMemoryDataTableStore) GetRows(dataTableName, group string) []any { + s.mu.Lock() + defer s.mu.Unlock() + if group != "" { + if b, ok := s.buckets[dataTableName+"\x00"+group]; ok { + return append([]any{}, b.rows...) + } + return nil + } + for _, b := range s.buckets { + if b.dt.Descriptor().Name == dataTableName && b.dt.Group() == "" { + return append([]any{}, b.rows...) + } + } + return nil +} + +func (s *InMemoryDataTableStore) GetDataTables() []DataTableLike { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]DataTableLike, 0, len(s.buckets)) + for _, b := range s.buckets { + out = append(out, b.dt) + } + return out +} + +// --- CsvDataTableStore --- + +var nonAlnum = regexp.MustCompile(`[^a-z0-9]`) +var dashRun = regexp.MustCompile(`-+`) +var leadingTrailingDash = regexp.MustCompile(`^-|-$`) + +// SanitizeScope produces a filesystem-safe identifier from a scope string, +// matching JS data-table.ts:85-103. Lowercase, non-alnum→dash, collapse +// dashes, truncate to 30 at a word boundary, suffix with a 4-char sha256. +func SanitizeScope(scope string) string { + s := strings.ToLower(scope) + s = nonAlnum.ReplaceAllString(s, "-") + s = dashRun.ReplaceAllString(s, "-") + s = leadingTrailingDash.ReplaceAllString(s, "") + if len(s) > 30 { + s = s[:30] + if i := strings.LastIndex(s, "-"); i > 0 { + s = s[:i] + } + } + sum := sha256.Sum256([]byte(scope)) + return s + "-" + hex.EncodeToString(sum[:2]) +} + +// CsvDataTableStore writes rows directly to CSV files as they are inserted. +// One file per (recipe, dataTable, group) tuple, opened append-mode and +// kept open for the store's lifetime. +type CsvDataTableStore struct { + outputDir string + mu sync.Mutex + files map[string]*os.File + dataTables map[string]DataTableLike +} + +func NewCsvDataTableStore(outputDir string) (*CsvDataTableStore, error) { + if err := os.MkdirAll(outputDir, 0755); err != nil { + return nil, fmt.Errorf("create dataTables output dir: %w", err) + } + return &CsvDataTableStore{ + outputDir: outputDir, + files: map[string]*os.File{}, + dataTables: map[string]DataTableLike{}, + }, nil +} + +// Close flushes and closes all CSV files. Call on server Reset / shutdown. +func (s *CsvDataTableStore) Close() { + s.mu.Lock() + defer s.mu.Unlock() + for _, f := range s.files { + _ = f.Close() + } + s.files = map[string]*os.File{} +} + +func fileKey(dt DataTableLike) string { + scope := dt.InstanceName() + if dt.Group() != "" { + scope = scope + "-" + dt.Group() + } + return SanitizeScope(dt.Descriptor().Name + "-" + scope) +} + +func (s *CsvDataTableStore) InsertRow(dt DataTableLike, _ *ExecutionContext, row any) { + s.mu.Lock() + defer s.mu.Unlock() + + key := fileKey(dt) + f, ok := s.files[key] + if !ok { + csvPath := filepath.Join(s.outputDir, key+".csv") + newFile, err := os.OpenFile(csvPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + return // best-effort; fail silently rather than crash a recipe + } + // Header preamble + column header row, written once when the file is created. + desc := dt.Descriptor() + preamble := fmt.Sprintf("# @name %s\n# @instanceName %s\n# @group %s\n", + desc.Name, dt.InstanceName(), dt.Group()) + header := make([]string, 0, len(desc.Columns)) + for _, c := range desc.Columns { + header = append(header, csvEscape(c.DisplayName)) + } + _, _ = newFile.WriteString(preamble + strings.Join(header, ",") + "\n") + f = newFile + s.files[key] = f + s.dataTables[key] = dt + } + + // Reflect the row by column name. Rows are typically structs with field + // names matching column names (case-insensitive). + desc := dt.Descriptor() + values := make([]string, 0, len(desc.Columns)) + rv := reflect.ValueOf(row) + if rv.Kind() == reflect.Ptr { + rv = rv.Elem() + } + for _, col := range desc.Columns { + var v any + if rv.Kind() == reflect.Struct { + fv := rv.FieldByNameFunc(func(name string) bool { + return strings.EqualFold(name, col.Name) + }) + if fv.IsValid() { + v = fv.Interface() + } + } else if rv.Kind() == reflect.Map { + mv := rv.MapIndex(reflect.ValueOf(col.Name)) + if mv.IsValid() { + v = mv.Interface() + } + } + values = append(values, csvEscape(v)) + } + _, _ = f.WriteString(strings.Join(values, ",") + "\n") +} + +func (s *CsvDataTableStore) GetRows(_, _ string) []any { + // CSV store is write-only; reads happen by parsing the CSV files + // from outside the process (the saas mounts the dir). + return nil +} + +func (s *CsvDataTableStore) GetDataTables() []DataTableLike { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]DataTableLike, 0, len(s.dataTables)) + for _, dt := range s.dataTables { + out = append(out, dt) + } + return out +} + +// csvEscape formats a value following RFC 4180. +func csvEscape(v any) string { + if v == nil { + return `""` + } + s := fmt.Sprintf("%v", v) + if strings.ContainsAny(s, ",\"\n\r") { + return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` + } + return s +} diff --git a/rewrite-go/pkg/recipe/execution_context.go b/rewrite-go/pkg/recipe/execution_context.go index e285c97b737..d8f9bf4f8e4 100644 --- a/rewrite-go/pkg/recipe/execution_context.go +++ b/rewrite-go/pkg/recipe/execution_context.go @@ -57,3 +57,16 @@ func (ctx *ExecutionContext) GetMessageOrDefault(key string, defaultValue any) a } return defaultValue } + +// MessageKeys returns a snapshot of the current message keys. Used by the +// BatchVisit handler to detect whether a visitor pass added new keys +// (`hasNewMessages` in the per-visitor result). +func (ctx *ExecutionContext) MessageKeys() []string { + ctx.mu.RLock() + defer ctx.mu.RUnlock() + keys := make([]string, 0, len(ctx.messages)) + for k := range ctx.messages { + keys = append(keys, k) + } + return keys +} diff --git a/rewrite-go/pkg/recipe/golang/activate.go b/rewrite-go/pkg/recipe/golang/activate.go index 8d5a3b82db2..b12ecb220fd 100644 --- a/rewrite-go/pkg/recipe/golang/activate.go +++ b/rewrite-go/pkg/recipe/golang/activate.go @@ -28,4 +28,9 @@ func Activate(r *recipe.Registry) { r.Register(&FindTypes{}, golangCategory, searchCategory) r.Register(&FindMethods{}, golangCategory, searchCategory) r.Register(&RenameXToFlag{}, golangCategory) + r.Register(&AddImport{}, golangCategory) + r.Register(&RemoveImport{}, golangCategory) + r.Register(&RemoveUnusedImports{}, golangCategory) + r.Register(&OrderImports{}, golangCategory) + r.Register(&RenamePackage{}, golangCategory) } diff --git a/rewrite-go/pkg/recipe/golang/add_import.go b/rewrite-go/pkg/recipe/golang/add_import.go new file mode 100644 index 00000000000..d4b9bea876b --- /dev/null +++ b/rewrite-go/pkg/recipe/golang/add_import.go @@ -0,0 +1,93 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +import ( + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe/golang/internal" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// AddImport adds an `import` statement to a Go compilation unit. +// +// Mirrors the Java AddImport recipe with full surface parity (per the +// C1 directive in the eng review): +// +// - PackagePath is the import path to add (e.g. "fmt", "github.com/x/y"). +// - Alias is the import alias: nil = regular import; "_" = blank import; +// "." = dot import; any other identifier = aliased import. +// - OnlyIfReferenced gates the add: true means add only when something +// in the file references the package (typed identifier with the +// matching FQN); false means add unconditionally. +// +// Cross-form idempotence: if the file already has an import of PackagePath +// in any form (regular, aliased, dot, or blank) AND the existing form is +// compatible with the requested form, AddImport is a no-op. Compatible +// means: requested alias == existing alias; or requested alias == nil and +// existing import is non-blank/non-dot. +type AddImport struct { + recipe.Base + PackagePath string + Alias *string + OnlyIfReferenced bool +} + +func (r *AddImport) Name() string { return "org.openrewrite.golang.AddImport" } +func (r *AddImport) DisplayName() string { return "Add import" } +func (r *AddImport) Description() string { + return "Add an `import` statement to a Go compilation unit. No-op if the import is already present in a compatible form." +} + +func (r *AddImport) Options() []recipe.OptionDescriptor { + opts := []recipe.OptionDescriptor{ + recipe.Option("packagePath", "Package path", "The import path to add."). + WithExample("fmt").WithValue(r.PackagePath), + } + if r.Alias != nil { + opts = append(opts, recipe.Option("alias", "Alias", `The import alias. Use "_" for blank imports, "." for dot imports.`). + AsOptional().WithExample("fmtutil").WithValue(*r.Alias)) + } + opts = append(opts, recipe.Option("onlyIfReferenced", "Only if referenced", + "When true, add the import only if some identifier in the file already references the package via type attribution."). + AsOptional().WithValue(r.OnlyIfReferenced)) + return opts +} + +func (r *AddImport) Editor() recipe.TreeVisitor { + return visitor.Init(&addImportVisitor{cfg: r}) +} + +type addImportVisitor struct { + visitor.GoVisitor + cfg *AddImport +} + +func (v *addImportVisitor) VisitCompilationUnit(cu *tree.CompilationUnit, p any) tree.J { + cu = v.GoVisitor.VisitCompilationUnit(cu, p).(*tree.CompilationUnit) + if v.cfg.PackagePath == "" { + return cu + } + if internal.HasImport(cu, v.cfg.PackagePath, v.cfg.Alias) { + return cu + } + if v.cfg.OnlyIfReferenced && !internal.ReferencedPackages(cu)[v.cfg.PackagePath] { + return cu + } + imp := internal.NewImport(v.cfg.PackagePath, v.cfg.Alias) + return internal.AddToBlock(cu, imp, internal.FindModulePath(cu)) +} diff --git a/rewrite-go/pkg/recipe/golang/annotation_service.go b/rewrite-go/pkg/recipe/golang/annotation_service.go new file mode 100644 index 00000000000..762241dc9e3 --- /dev/null +++ b/rewrite-go/pkg/recipe/golang/annotation_service.go @@ -0,0 +1,342 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +import ( + "strings" + + "github.com/google/uuid" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// AnnotationService exposes Go's annotation-equivalents (struct field +// tags + `//go:` / `//lint:` directives) as a uniform J.Annotation +// surface. Mirrors org.openrewrite.java.service.AnnotationService. +// +// In Go, annotation-equivalents live on: +// - Struct fields: each `key:"value"` tag pair is one Annotation +// on the field's VariableDeclarations.LeadingAnnotations. +// - Top-level decls: each `//go:noinline`, `//go:linkname`, +// `//lint:ignore`, etc. is one Annotation on the enclosing +// MethodDeclaration / TypeDecl / VariableDeclarations +// LeadingAnnotations. +// +// Recipes get one via recipe.Service: +// +// svc := recipe.Service[*golang.AnnotationService](cu) +// if svc.IsAnnotatedWith(node, "json") { ... } +// +// Or for matcher-based queries: +// +// if svc.Matches(v.Cursor(), golang.NewAnnotationMatcher("go:linkname")) { ... } +type AnnotationService struct{} + +// AllAnnotations returns every annotation attached to the cursor's +// nearest-enclosing decl (the cursor's value if it's a decl, else the +// closest decl ancestor). Mirrors Java's getAllAnnotations(Cursor). +// +// Supported decl types: VariableDeclarations, MethodDeclaration, +// TypeDecl. For other node types, returns nil. +func (s *AnnotationService) AllAnnotations(c *visitor.Cursor) []*tree.Annotation { + if c == nil { + return nil + } + for cur := c; cur != nil; cur = cur.Parent() { + switch n := cur.Value().(type) { + case *tree.VariableDeclarations: + return n.LeadingAnnotations + case *tree.MethodDeclaration: + return n.LeadingAnnotations + case *tree.TypeDecl: + return n.LeadingAnnotations + } + } + return nil +} + +// AnnotationMatcher matches annotations by name. The pattern can be: +// - An exact name match: "json", "go:noinline", "lint:ignore". +// - A wildcard pattern with "*" at the end: "go:*" matches any +// `//go:` directive; "*" matches any annotation. +type AnnotationMatcher struct { + pattern string +} + +// NewAnnotationMatcher constructs a matcher for the given pattern. +// See AnnotationMatcher for pattern semantics. +func NewAnnotationMatcher(pattern string) AnnotationMatcher { + return AnnotationMatcher{pattern: pattern} +} + +// Matches reports whether the given annotation's type-name matches +// the pattern. Returns false if the annotation has no resolvable name +// (defensive — every annotation we emit has a Identifier or +// FieldAccess as its AnnotationType). +func (m AnnotationMatcher) Matches(ann *tree.Annotation) bool { + if ann == nil { + return false + } + name := annotationName(ann) + if name == "" { + return false + } + if m.pattern == "*" { + return true + } + if strings.HasSuffix(m.pattern, "*") { + return strings.HasPrefix(name, m.pattern[:len(m.pattern)-1]) + } + return name == m.pattern +} + +// annotationName extracts the annotation's type-name. For Identifier +// returns its Name; for FieldAccess (qualified) returns "Target.Name". +func annotationName(ann *tree.Annotation) string { + switch t := ann.AnnotationType.(type) { + case *tree.Identifier: + return t.Name + case *tree.FieldAccess: + // Walk the FieldAccess chain — uncommon for Go annotations but + // included for cross-language symmetry with Java FQNs. + var b strings.Builder + appendFieldAccessName(&b, t) + return b.String() + } + return "" +} + +func appendFieldAccessName(b *strings.Builder, fa *tree.FieldAccess) { + switch t := fa.Target.(type) { + case *tree.Identifier: + b.WriteString(t.Name) + case *tree.FieldAccess: + appendFieldAccessName(b, t) + } + if fa.Name.Element != nil { + if b.Len() > 0 { + b.WriteByte('.') + } + b.WriteString(fa.Name.Element.Name) + } +} + +// Matches reports whether the cursor's enclosing decl carries an +// annotation matching the matcher. +func (s *AnnotationService) Matches(c *visitor.Cursor, matcher AnnotationMatcher) bool { + for _, a := range s.AllAnnotations(c) { + if matcher.Matches(a) { + return true + } + } + return false +} + +// IsAnnotatedWith reports whether the given decl carries an annotation +// whose name equals (or wildcard-matches) the given pattern. Convenience +// over Matches with an inline matcher. Common patterns: +// +// svc.IsAnnotatedWith(field, "json") // exact key match on a struct field +// svc.IsAnnotatedWith(funcDecl, "go:noinline") // exact directive match +// svc.IsAnnotatedWith(funcDecl, "go:*") // any //go: directive +func (s *AnnotationService) IsAnnotatedWith(t tree.Tree, pattern string) bool { + matcher := NewAnnotationMatcher(pattern) + for _, a := range annotationsOn(t) { + if matcher.Matches(a) { + return true + } + } + return false +} + +// FindAnnotations returns all annotations on the given decl that match +// the pattern. +func (s *AnnotationService) FindAnnotations(t tree.Tree, pattern string) []*tree.Annotation { + matcher := NewAnnotationMatcher(pattern) + var out []*tree.Annotation + for _, a := range annotationsOn(t) { + if matcher.Matches(a) { + out = append(out, a) + } + } + return out +} + +// annotationsOn returns LeadingAnnotations directly from a decl node, +// without cursor traversal. Used by the IsAnnotatedWith / FindAnnotations +// surface, which take the decl node directly rather than a cursor. +func annotationsOn(t tree.Tree) []*tree.Annotation { + switch n := t.(type) { + case *tree.VariableDeclarations: + return n.LeadingAnnotations + case *tree.MethodDeclaration: + return n.LeadingAnnotations + case *tree.TypeDecl: + return n.LeadingAnnotations + } + return nil +} + +// AddAnnotationVisitor returns a visitor that appends an Annotation to +// the LeadingAnnotations of every matching declaration in the visited +// tree. The annotation's `Prefix` is set to a single newline so it +// renders on its own line above the decl (directive convention). +// +// matcher selects which decls receive the annotation. Pass +// `func(t tree.Tree) bool { return true }` to apply universally +// (rare; most recipes scope by decl name or context). +// +// For struct field tags, recipes typically construct the annotation +// manually and append directly rather than using this visitor — tags +// have no leading-newline convention. +func (s *AnnotationService) AddAnnotationVisitor(matcher func(tree.Tree) bool, ann *tree.Annotation) recipe.TreeVisitor { + return visitor.Init(&addAnnotationVisitor{match: matcher, ann: ann}) +} + +// RemoveAnnotationVisitor returns a visitor that removes any +// annotation matching the given pattern from every decl in the +// visited tree. +func (s *AnnotationService) RemoveAnnotationVisitor(pattern string) recipe.TreeVisitor { + return visitor.Init(&removeAnnotationVisitor{matcher: NewAnnotationMatcher(pattern)}) +} + +type addAnnotationVisitor struct { + visitor.GoVisitor + match func(tree.Tree) bool + ann *tree.Annotation +} + +func (v *addAnnotationVisitor) VisitMethodDeclaration(md *tree.MethodDeclaration, p any) tree.J { + md = v.GoVisitor.VisitMethodDeclaration(md, p).(*tree.MethodDeclaration) + if v.match(md) { + clone := positionDirectiveAnnotation(v.ann, &md.Prefix, len(md.LeadingAnnotations) == 0) + md = md.WithLeadingAnnotations(append(append([]*tree.Annotation{}, md.LeadingAnnotations...), clone)) + } + return md +} + +func (v *addAnnotationVisitor) VisitTypeDecl(td *tree.TypeDecl, p any) tree.J { + td = v.GoVisitor.VisitTypeDecl(td, p).(*tree.TypeDecl) + if v.match(td) { + clone := positionDirectiveAnnotation(v.ann, &td.Prefix, len(td.LeadingAnnotations) == 0) + td = td.WithLeadingAnnotations(append(append([]*tree.Annotation{}, td.LeadingAnnotations...), clone)) + } + return td +} + +func (v *addAnnotationVisitor) VisitVariableDeclarations(vd *tree.VariableDeclarations, p any) tree.J { + vd = v.GoVisitor.VisitVariableDeclarations(vd, p).(*tree.VariableDeclarations) + if v.match(vd) { + clone := positionDirectiveAnnotation(v.ann, &vd.Prefix, len(vd.LeadingAnnotations) == 0) + vd = vd.WithLeadingAnnotations(append(append([]*tree.Annotation{}, vd.LeadingAnnotations...), clone)) + } + return vd +} + +// positionDirectiveAnnotation produces a clone of `template` with its +// Prefix set so it sits on its own line above the decl. Mutates +// `*declPrefix` in place when this annotation is the first to be +// added: the outer leading whitespace migrates from the decl's Prefix +// onto the new annotation, and the decl's Prefix becomes a single +// newline (the separator between the directive line and the +// keyword). For non-first annotations, the decl's Prefix is left +// alone and the new annotation gets a `\n` prefix so it stacks below +// existing directives. +func positionDirectiveAnnotation(template *tree.Annotation, declPrefix *tree.Space, isFirst bool) *tree.Annotation { + clone := cloneAnnotation(template) + if isFirst { + clone.Prefix = *declPrefix + *declPrefix = tree.Space{Whitespace: "\n"} + } else { + clone.Prefix = tree.Space{Whitespace: "\n"} + } + return clone +} + +type removeAnnotationVisitor struct { + visitor.GoVisitor + matcher AnnotationMatcher +} + +func (v *removeAnnotationVisitor) VisitMethodDeclaration(md *tree.MethodDeclaration, p any) tree.J { + md = v.GoVisitor.VisitMethodDeclaration(md, p).(*tree.MethodDeclaration) + return md.WithLeadingAnnotations(filterAnnotations(md.LeadingAnnotations, v.matcher)) +} + +func (v *removeAnnotationVisitor) VisitTypeDecl(td *tree.TypeDecl, p any) tree.J { + td = v.GoVisitor.VisitTypeDecl(td, p).(*tree.TypeDecl) + return td.WithLeadingAnnotations(filterAnnotations(td.LeadingAnnotations, v.matcher)) +} + +func (v *removeAnnotationVisitor) VisitVariableDeclarations(vd *tree.VariableDeclarations, p any) tree.J { + vd = v.GoVisitor.VisitVariableDeclarations(vd, p).(*tree.VariableDeclarations) + return vd.WithLeadingAnnotations(filterAnnotations(vd.LeadingAnnotations, v.matcher)) +} + +func filterAnnotations(in []*tree.Annotation, m AnnotationMatcher) []*tree.Annotation { + if len(in) == 0 { + return in + } + out := make([]*tree.Annotation, 0, len(in)) + for _, a := range in { + if !m.Matches(a) { + out = append(out, a) + } + } + if len(out) == len(in) { + return in + } + return out +} + +// cloneAnnotation produces a fresh Annotation with new UUIDs for the +// outer node and its inner Identifier/Literal so the same template +// can be applied to multiple decls without ID collisions. +func cloneAnnotation(ann *tree.Annotation) *tree.Annotation { + if ann == nil { + return nil + } + c := *ann + c.ID = uuid.New() + if id, ok := ann.AnnotationType.(*tree.Identifier); ok { + idClone := *id + idClone.ID = uuid.New() + c.AnnotationType = &idClone + } + if ann.Arguments != nil { + args := *ann.Arguments + newElems := make([]tree.RightPadded[tree.Expression], len(args.Elements)) + for i, rp := range args.Elements { + rp2 := rp + if lit, ok := rp.Element.(*tree.Literal); ok { + litClone := *lit + litClone.ID = uuid.New() + rp2.Element = &litClone + } + newElems[i] = rp2 + } + args.Elements = newElems + c.Arguments = &args + } + return &c +} + +func init() { + recipe.RegisterService[*AnnotationService](func() any { return &AnnotationService{} }) +} diff --git a/rewrite-go/pkg/recipe/golang/auto_format_service.go b/rewrite-go/pkg/recipe/golang/auto_format_service.go new file mode 100644 index 00000000000..d9927692bad --- /dev/null +++ b/rewrite-go/pkg/recipe/golang/auto_format_service.go @@ -0,0 +1,81 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +import ( + "github.com/openrewrite/rewrite/rewrite-go/pkg/format" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +// AutoFormatService exposes gofmt-style normalization as composable +// visitors. Mirrors org.openrewrite.java.service.AutoFormatService. +// +// Recipes get one via recipe.Service: +// +// svc := recipe.Service[*golang.AutoFormatService](cu) +// v.DoAfterVisit(svc.AutoFormatVisitor(nil)) +// +// AutoFormatVisitor runs the full pipeline (trailing whitespace → +// blank lines → tabs/indents → spaces). The single-pass visitors +// (TabsAndIndentsVisitor, BlankLinesVisitor, etc.) are exposed +// individually for recipes that only need one normalization pass. +// +// Each method accepts a `stopAfter` bound. When non-nil, traversal +// halts once that node has been fully visited — useful for formatting +// only a synthesized subtree without disturbing surrounding code. +// Pass nil to format the entire visited subtree. +type AutoFormatService struct{} + +// AutoFormatVisitor returns a visitor that applies the full gofmt-style +// pipeline. Composes via DoAfterVisit so individual passes can be +// inspected or replaced independently if needed. +func (s *AutoFormatService) AutoFormatVisitor(stopAfter tree.Tree) recipe.TreeVisitor { + return format.NewAutoFormatVisitor(stopAfter) +} + +// TabsAndIndentsVisitor returns just the indent-fix pass. Use when a +// recipe has spliced a subtree that needs re-indenting but already has +// correct internal spacing. +func (s *AutoFormatService) TabsAndIndentsVisitor(stopAfter tree.Tree) recipe.TreeVisitor { + return format.NewTabsAndIndentsVisitor(stopAfter) +} + +// BlankLinesVisitor returns just the blank-line collapse pass. Use to +// clean up after a delete-and-replace edit that left stray blank lines +// inside a block. +func (s *AutoFormatService) BlankLinesVisitor(stopAfter tree.Tree) recipe.TreeVisitor { + return format.NewBlankLinesVisitor(stopAfter) +} + +// SpacesVisitor returns just the intra-line spacing pass. Use after a +// recipe has built up a binary/assignment node from raw parts and the +// operator surrounds need normalizing to a single space. +func (s *AutoFormatService) SpacesVisitor(stopAfter tree.Tree) recipe.TreeVisitor { + return format.NewSpacesVisitor(stopAfter) +} + +// RemoveTrailingWhitespaceVisitor returns just the trailing-whitespace +// strip pass. Useful as a standalone cleanup over a tree the rest of +// the pipeline shouldn't touch. +func (s *AutoFormatService) RemoveTrailingWhitespaceVisitor(stopAfter tree.Tree) recipe.TreeVisitor { + return format.NewRemoveTrailingWhitespaceVisitor(stopAfter) +} + +func init() { + recipe.RegisterService[*AutoFormatService](func() any { return &AutoFormatService{} }) +} diff --git a/rewrite-go/pkg/recipe/golang/import_service.go b/rewrite-go/pkg/recipe/golang/import_service.go new file mode 100644 index 00000000000..b5912ea7226 --- /dev/null +++ b/rewrite-go/pkg/recipe/golang/import_service.go @@ -0,0 +1,102 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +import ( + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" +) + +// ImportService exposes the Go-side import-manipulation primitives as +// composable visitors. Mirrors org.openrewrite.java.service.ImportService. +// +// Recipes get one via recipe.Service: +// +// svc := recipe.Service[*golang.ImportService](cu) +// v.DoAfterVisit(svc.AddImportVisitor("fmt", nil, false)) +// +// Each method returns a recipe.TreeVisitor configured for the +// requested operation. Visitors can be applied directly via +// `v.Visit(cu, ctx)` OR queued for the after-visit drain via +// `GoVisitor.DoAfterVisit(v)`. The latter is the canonical way to +// compose import side-effects with a main edit: +// +// func (r *MyRecipe) Editor() recipe.TreeVisitor { +// return visitor.Init(&myVisitor{}) +// } +// +// type myVisitor struct{ visitor.GoVisitor } +// +// func (v *myVisitor) VisitMethodInvocation(mi *tree.MethodInvocation, p any) tree.J { +// if shouldRewrite(mi) { +// // ...rewrite mi... +// cu := v.Cursor().FirstEnclosing(reflect.TypeOf((*tree.CompilationUnit)(nil))) +// svc := recipe.Service[*golang.ImportService](cu) +// v.DoAfterVisit(svc.AddImportVisitor("fmt", nil, false)) +// } +// return mi +// } +type ImportService struct{} + +// AddImportVisitor returns a visitor that adds `import [alias] "packagePath"` +// to the visited compilation unit. No-op if the import is already +// present in a compatible form. +// +// - alias == nil → regular import (`import "fmt"`). +// - alias == "_" → blank import for init() side-effects. +// - alias == "." → dot import. +// - any other identifier → aliased import. +// +// When onlyIfReferenced is true, the import is added only when some +// identifier in the file already references the package via type +// attribution. That's the safe default for IDE-style refactors that +// shouldn't introduce dead imports. +func (s *ImportService) AddImportVisitor(packagePath string, alias *string, onlyIfReferenced bool) recipe.TreeVisitor { + return (&AddImport{ + PackagePath: packagePath, + Alias: alias, + OnlyIfReferenced: onlyIfReferenced, + }).Editor() +} + +// RemoveImportVisitor returns a visitor that deletes any `import` whose +// path matches packagePath. Aliased / blank / dot forms are all +// removed. Empty import containers are nil-ed out so the printer +// doesn't emit an empty `import ()` block. +func (s *ImportService) RemoveImportVisitor(packagePath string) recipe.TreeVisitor { + return (&RemoveImport{PackagePath: packagePath}).Editor() +} + +// RemoveUnusedImportsVisitor returns a visitor that drops imports +// whose alias is never referenced in the file. Mirrors `goimports -w`. +// Blank (`_`) and dot (`.`) imports are preserved by semantic rule. +func (s *ImportService) RemoveUnusedImportsVisitor() recipe.TreeVisitor { + return (&RemoveUnusedImports{}).Editor() +} + +// OrderImportsVisitor returns a visitor that sorts imports into +// stdlib / third-party / local groups. Local detection uses the +// sibling go.mod's module path (via the GoResolutionResult marker). +func (s *ImportService) OrderImportsVisitor() recipe.TreeVisitor { + return (&OrderImports{}).Editor() +} + +func init() { + // Register the service factory so callers can do + // `recipe.Service[*golang.ImportService](cu)`. Stateless — one + // instance per call is fine. + recipe.RegisterService[*ImportService](func() any { return &ImportService{} }) +} diff --git a/rewrite-go/pkg/recipe/golang/internal/imports.go b/rewrite-go/pkg/recipe/golang/internal/imports.go new file mode 100644 index 00000000000..6f6532996cf --- /dev/null +++ b/rewrite-go/pkg/recipe/golang/internal/imports.go @@ -0,0 +1,547 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package internal contains primitives shared across the import-management +// recipes (AddImport, RemoveImport, RemoveUnusedImports, OrderImports). It +// is not part of the public API — callers depend on the recipe types +// themselves, not on these helpers. +package internal + +import ( + "sort" + "strings" + + "github.com/google/uuid" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// ImportGroup identifies which gofmt-style group an import belongs to. +type ImportGroup int + +const ( + // Stdlib imports — paths with no `.` in their first segment + // (e.g. "fmt", "net/http"). + Stdlib ImportGroup = iota + // ThirdParty imports — paths with a `.` in their first segment + // (e.g. "github.com/x/y") that are not local to the current module. + ThirdParty + // Local imports — paths under the current module's path. + Local +) + +// ImportPath returns the unquoted import path of an Import (e.g. "fmt"). +// Returns "" when Qualid isn't a Literal (defensive — shouldn't happen +// for well-formed Go source). +// +// The Go parser stores the raw quoted source in Literal.Value (and +// .Source) — `"fmt"` not `fmt` — so this helper always strips the +// surrounding quote pair before returning. +func ImportPath(imp *tree.Import) string { + if imp == nil { + return "" + } + lit, ok := imp.Qualid.(*tree.Literal) + if !ok || lit == nil { + return "" + } + raw := "" + if s, ok := lit.Value.(string); ok { + raw = s + } else { + raw = lit.Source + } + return strings.Trim(raw, `"`+"`") +} + +// AliasName returns the alias used by an Import: a custom identifier for +// `import alias "path"`, "_" for blank imports, "." for dot imports, or +// "" when the import uses the default (last segment of the path). +func AliasName(imp *tree.Import) string { + if imp == nil || imp.Alias == nil { + return "" + } + if imp.Alias.Element == nil { + return "" + } + return imp.Alias.Element.Name +} + +// FindImport returns the first import in cu whose path equals importPath. +// Returns nil if no match. +func FindImport(cu *tree.CompilationUnit, importPath string) *tree.Import { + if cu == nil || cu.Imports == nil { + return nil + } + for _, rp := range cu.Imports.Elements { + if ImportPath(rp.Element) == importPath { + return rp.Element + } + } + return nil +} + +// HasImport returns true when an import with the given path (and matching +// alias semantics) is already present. +// - alias == nil: any non-blank, non-dot form counts as a hit +// - alias != nil: requires an exact alias match (use "_" for blank, +// "." for dot) +func HasImport(cu *tree.CompilationUnit, importPath string, alias *string) bool { + if cu == nil || cu.Imports == nil { + return false + } + for _, rp := range cu.Imports.Elements { + if ImportPath(rp.Element) != importPath { + continue + } + existingAlias := AliasName(rp.Element) + if alias == nil { + if existingAlias == "_" || existingAlias == "." { + continue + } + return true + } + if *alias == existingAlias { + return true + } + } + return false +} + +// GroupOf returns the ImportGroup an import path belongs to relative to +// the current module's path. modulePath may be empty when the file is +// outside any known module — local detection then falls back to "any +// non-stdlib path with a `.`" → ThirdParty. +func GroupOf(importPath, modulePath string) ImportGroup { + if IsLocal(importPath, modulePath) { + return Local + } + if IsThirdParty(importPath) { + return ThirdParty + } + return Stdlib +} + +// IsStdlib reports whether importPath looks like a stdlib package (no `.` +// in the first path segment). Mirrors goimports/gofmt's heuristic. +func IsStdlib(importPath string) bool { + first := importPath + if i := strings.Index(importPath, "/"); i >= 0 { + first = importPath[:i] + } + return !strings.Contains(first, ".") +} + +// IsThirdParty reports whether importPath is a non-stdlib, non-local +// dependency. Equivalent to `!IsStdlib && !IsLocal`. +func IsThirdParty(importPath string) bool { + return !IsStdlib(importPath) +} + +// IsLocal reports whether importPath belongs to the current module +// (modulePath itself or any sub-package under it). +func IsLocal(importPath, modulePath string) bool { + if modulePath == "" { + return false + } + return importPath == modulePath || strings.HasPrefix(importPath, modulePath+"/") +} + +// ReferencedPackages walks cu and returns the set of import paths that +// are referenced by some identifier in the file body. Used by +// RemoveUnusedImports to drop imports whose alias is never read. +// +// Detection is driven by the Type attribution that the parser threads +// onto each Identifier: +// - For an `Identifier` whose `Type` is a `JavaTypeClass` and whose +// `FullyQualifiedName` carries an import path (path-shaped FQN), +// that path is added to the set. +// - For a `MethodInvocation`, the `MethodType.DeclaringType.FullyQualifiedName` +// is used. +// +// Aliases and dot imports are handled uniformly — the package's import +// path is what we track, regardless of how the user named it. +func ReferencedPackages(cu *tree.CompilationUnit) map[string]bool { + refs := map[string]bool{} + if cu == nil { + return refs + } + v := visitor.Init(&referencedPackagesVisitor{refs: refs}) + v.Visit(cu, nil) + return refs +} + +type referencedPackagesVisitor struct { + visitor.GoVisitor + refs map[string]bool +} + +func (v *referencedPackagesVisitor) VisitIdentifier(ident *tree.Identifier, p any) tree.J { + if ident.Type != nil { + if fq, ok := ident.Type.(tree.FullyQualified); ok { + if path := pkgPathOf(fq.GetFullyQualifiedName()); path != "" { + v.refs[path] = true + } + } + } + return ident +} + +func (v *referencedPackagesVisitor) VisitMethodInvocation(mi *tree.MethodInvocation, p any) tree.J { + if mi.MethodType != nil && mi.MethodType.DeclaringType != nil { + if path := pkgPathOf(mi.MethodType.DeclaringType.FullyQualifiedName); path != "" { + v.refs[path] = true + } + } + return v.GoVisitor.VisitMethodInvocation(mi, p) +} + +// pkgPathOf returns the package import path implied by an FQN. The +// type-mapper produces FQNs in two shapes: +// - "" — for package aliases (the `y` in +// `y.Hello()` after `import "github.com/x/y"`). +// - "." — for named types in that package. +// +// Both shapes share an import-path prefix (the leading segment up to the +// last `.`); we return that prefix so RemoveUnusedImports can match +// against the literal import path. +func pkgPathOf(fqn string) string { + if fqn == "" { + return "" + } + // FQNs that are already an import path (no trailing `.TypeName`) + // contain a `/` and no `.` after the last `/`. Detect that shape and + // return the FQN as-is. + if strings.Contains(fqn, "/") { + lastSlash := strings.LastIndex(fqn, "/") + tail := fqn[lastSlash+1:] + if !strings.Contains(tail, ".") { + return fqn + } + // `.../pkg.TypeName` shape — strip the trailing `.TypeName`. + return fqn[:lastSlash+1+strings.Index(tail, ".")] + } + // Stdlib paths (e.g. "fmt") and "fmt.Println"-style FQNs. + if dot := strings.Index(fqn, "."); dot >= 0 { + return fqn[:dot] + } + return fqn +} + +// AddToBlock returns a copy of cu with imp inserted into the existing +// import container. If cu has no imports container yet, one is created +// with default formatting (single ungrouped `import "path"` line). +// +// The insertion preserves group ordering: imp is placed at the end of +// its own group block (stdlib / third-party / local) so OrderImports' +// invariant holds even when AddImport is invoked alone. +// +// Whitespace for the empty-container case: the printer concatenates +// `import`. To produce +// `\n\nimport "fmt"` between `package main` and the first statement, +// the new Container's Before is `"\n\n"`, the Import's Prefix is empty, +// and the Qualid Literal carries a leading space (printed as the space +// between `import` and the path string). The first Statement's existing +// Prefix supplies the trailing blank line before `func`. +func AddToBlock(cu *tree.CompilationUnit, imp *tree.Import, modulePath string) *tree.CompilationUnit { + if cu == nil { + return cu + } + c := *cu + if c.Imports == nil { + c.Imports = &tree.Container[*tree.Import]{ + Before: tree.Space{Whitespace: "\n\n"}, + } + // Wire the leading-space convention onto the new import: the + // space between `import` and the path lives on the Qualid + // literal's Prefix. + if lit, ok := imp.Qualid.(*tree.Literal); ok { + cloned := *lit + cloned.Prefix = tree.Space{Whitespace: " "} + imp.Qualid = &cloned + } + } + imps := *c.Imports + + // Adding a second import to an ungrouped single-import file + // promotes it to the grouped `import (...)` form — that's the + // only legal Go syntax for multiple imports in one block. + if len(imps.Elements) == 1 && tree.FindMarker[tree.GroupedImport](imps.Markers) == nil { + promoteToGrouped(&imps) + } + + imps.Elements = insertGrouped(imps.Elements, imp, modulePath) + c.Imports = &imps + return &c +} + +// promoteToGrouped converts an ungrouped single-import block to the +// grouped `import (...)` form. Adds the GroupedImport marker and +// rewrites the existing element's Prefix / After so the printer emits +// it indented inside parens. +func promoteToGrouped(imps *tree.Container[*tree.Import]) { + imps.Markers = tree.AddMarker(imps.Markers, tree.GroupedImport{ + Ident: uuid.New(), + Before: tree.Space{Whitespace: " "}, // space between `import` and `(` + }) + if len(imps.Elements) == 0 { + return + } + // The previously-ungrouped element had Qualid.Prefix=" " (space + // between `import` and the path). Inside parens we want the import + // indented onto its own line: imp.Prefix="\n\t", Qualid.Prefix="". + rp := &imps.Elements[0] + if rp.Element != nil { + imp := *rp.Element + imp.Prefix = tree.Space{Whitespace: "\n\t"} + if lit, ok := imp.Qualid.(*tree.Literal); ok { + cloned := *lit + cloned.Prefix = tree.EmptySpace + imp.Qualid = &cloned + } + rp.Element = &imp + } + rp.After = tree.Space{Whitespace: "\n"} // newline before `)` +} + +// RemoveFromBlock returns a copy of cu with imp deleted from the imports +// container. If the container becomes empty as a result, it's nil-ed out +// so the printer doesn't emit an empty `import ()` block. +// +// Whitespace handling: the removed entry's trailing space (the +// `RightPadded.After` field, which contains the newline before the next +// element or the closing `)`) is donated to the new last element so the +// block keeps its closing-paren-on-its-own-line shape. +func RemoveFromBlock(cu *tree.CompilationUnit, imp *tree.Import) *tree.CompilationUnit { + if cu == nil || cu.Imports == nil || imp == nil { + return cu + } + c := *cu + imps := *c.Imports + removedLastAfter := tree.Space{} + removedWasLast := false + out := make([]tree.RightPadded[*tree.Import], 0, len(imps.Elements)) + for i, rp := range imps.Elements { + if rp.Element != nil && rp.Element.ID == imp.ID { + if i == len(imps.Elements)-1 { + removedLastAfter = rp.After + removedWasLast = true + } + continue + } + out = append(out, rp) + } + if removedWasLast && len(out) > 0 { + // Donate the closing-paren-on-own-line whitespace to the new + // last element so the block keeps its tidy shape. + out[len(out)-1].After = removedLastAfter + } + imps.Elements = out + if len(out) == 0 { + c.Imports = nil + } else { + c.Imports = &imps + } + return &c +} + +// insertGrouped places imp at the end of its own group while preserving +// the relative order of pre-existing imports. New groups appear in +// stdlib / third-party / local order. +// +// Whitespace handling: the new import inherits a sibling's Prefix (the +// `\n\t` indent inside an `import (...)` block) so the printer renders +// it on its own line. When inserting into an empty block (no siblings), +// a sensible default is used. +func insertGrouped(elements []tree.RightPadded[*tree.Import], imp *tree.Import, modulePath string) []tree.RightPadded[*tree.Import] { + target := GroupOf(ImportPath(imp), modulePath) + insertAt := len(elements) + for i, rp := range elements { + g := GroupOf(ImportPath(rp.Element), modulePath) + if g > target { + insertAt = i + break + } + } + if imp.Prefix.Whitespace == "" && len(elements) > 0 { + // Borrow the surrounding indent. If we're inserting in front, + // take the first sibling's prefix; otherwise the previous + // sibling's. Both reliably end with `\n\t` in a grouped block. + var donor *tree.Import + if insertAt < len(elements) { + donor = elements[insertAt].Element + } else { + donor = elements[len(elements)-1].Element + } + if donor != nil { + imp.Prefix = donor.Prefix + } + } + wrapped := tree.RightPadded[*tree.Import]{Element: imp} + out := make([]tree.RightPadded[*tree.Import], 0, len(elements)+1) + out = append(out, elements[:insertAt]...) + out = append(out, wrapped) + out = append(out, elements[insertAt:]...) + + // Re-balance trailing whitespace: in a grouped block the last + // element's After holds the space before `)`. When we appended at + // the end, the previous tail's After is now between two siblings + // (and is redundant with the new sibling's leading Prefix); donate + // it forward so the new tail sits before `)` correctly. + if insertAt == len(elements) && len(out) >= 2 { + prev := &out[len(out)-2] + newTail := &out[len(out)-1] + newTail.After = prev.After + prev.After = tree.Space{} + } + return out +} + +// NewImport builds an Import LST node for `import [alias] "path"`. Pass +// alias=nil for a regular import, "_" for a blank import, "." for a dot +// import, or any identifier name for an aliased import. +func NewImport(path string, alias *string) *tree.Import { + imp := &tree.Import{ + ID: uuid.New(), + Qualid: &tree.Literal{ID: uuid.New(), Source: `"` + path + `"`, Value: path, Kind: tree.StringLiteral}, + } + if alias != nil { + imp.Alias = &tree.LeftPadded[*tree.Identifier]{ + Before: tree.Space{Whitespace: " "}, + Element: &tree.Identifier{ + ID: uuid.New(), + Name: *alias, + }, + } + } + return imp +} + +// SortByGroup returns the imports re-ordered into stdlib / third-party / +// local sequence, alphabetized within each group, with a blank line +// inserted between non-empty groups. Mirrors `goimports -w` output. +// +// Whitespace handling: +// - Element `Prefix` carries the per-line indent (typically `\n\t`). +// The first element of each non-leading non-empty group gets a leading +// `\n` prepended to its indent so the group separator is a blank line. +// - `RightPadded.After` is anchored to a position (between-elements vs. +// before-`)`) rather than to its element. When the order changes the +// anchor changes too — the element that was last is no longer last. +// This re-balances After so all but the new tail use the original +// between-element spacing and the new tail uses the original close-paren +// spacing. +func SortByGroup(elements []tree.RightPadded[*tree.Import], modulePath string) []tree.RightPadded[*tree.Import] { + if len(elements) <= 1 { + return elements + } + betweenAfter := elements[0].After + closingAfter := elements[len(elements)-1].After + + // Re-derive the per-line indent prefix from the first non-blank-line + // element so the blank-line separator below can prepend a single \n + // to it. The smallest existing prefix that ends in \n + indent wins. + indentPrefix := tree.Space{Whitespace: "\n\t"} + for _, rp := range elements { + if rp.Element == nil { + continue + } + ws := rp.Element.Prefix.Whitespace + // Strip a leading blank-line newline so we get the canonical + // "\n\t" indent rather than "\n\n\t". + canonical := ws + for strings.HasPrefix(canonical, "\n\n") { + canonical = canonical[1:] + } + if strings.HasPrefix(canonical, "\n") { + indentPrefix = tree.Space{Whitespace: canonical} + break + } + } + + type bucket struct { + group ImportGroup + items []tree.RightPadded[*tree.Import] + } + buckets := []bucket{ + {group: Stdlib}, + {group: ThirdParty}, + {group: Local}, + } + for _, rp := range elements { + g := GroupOf(ImportPath(rp.Element), modulePath) + buckets[int(g)].items = append(buckets[int(g)].items, rp) + } + for i := range buckets { + items := buckets[i].items + sort.SliceStable(items, func(a, b int) bool { + return ImportPath(items[a].Element) < ImportPath(items[b].Element) + }) + } + + out := make([]tree.RightPadded[*tree.Import], 0, len(elements)) + groupSeparatorPrefix := tree.Space{Whitespace: "\n" + indentPrefix.Whitespace} + for _, b := range buckets { + if len(b.items) == 0 { + continue + } + for j, item := range b.items { + if item.Element == nil { + continue + } + cloned := *item.Element + switch { + case j == 0 && len(out) > 0: + cloned.Prefix = groupSeparatorPrefix + default: + cloned.Prefix = indentPrefix + } + item.Element = &cloned + b.items[j] = item + } + out = append(out, b.items...) + } + for i := range out { + if i == len(out)-1 { + out[i].After = closingAfter + } else { + out[i].After = betweenAfter + } + } + return out +} + +// FindModulePath extracts the GoResolutionResult marker's ModulePath +// from the cu (or its sibling go.mod, if attached). Returns "" when no +// marker is present (which is fine — IsLocal handles empty modulePath +// by reporting false uniformly). +func FindModulePath(cu *tree.CompilationUnit) string { + if cu == nil { + return "" + } + for _, m := range cu.Markers.Entries { + if mrr, ok := m.(tree.GoResolutionResult); ok { + return mrr.ModulePath + } + } + return "" +} + +// _ tree.GoResolutionResult is referenced via FindModulePath; this +// silences the unused-import linter when callers don't pull in the tree +// package explicitly. +var _ = uuid.UUID{} diff --git a/rewrite-go/pkg/recipe/golang/naming_service.go b/rewrite-go/pkg/recipe/golang/naming_service.go new file mode 100644 index 00000000000..36b0aa1935a --- /dev/null +++ b/rewrite-go/pkg/recipe/golang/naming_service.go @@ -0,0 +1,135 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +import ( + "go/token" + "unicode" + "unicode/utf8" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" +) + +// NamingService bundles Go-style identifier helpers behind the +// recipe.Service registry. Recipes that synthesize identifiers — e.g. +// when generating accessors or renaming via a target style — can call +// these without re-implementing the rune-level logic each time. +// +// Recipes get one via recipe.Service: +// +// svc := recipe.Service[*golang.NamingService](cu) +// if svc.IsPredeclared(newName) { /* refuse: shadowing builtin */ } +// exportedName := svc.ToPascalCase(field) +// +// All helpers operate on the *first rune* per the Go spec — visibility +// is determined by the first rune's case, not by the rest of the +// identifier — and use unicode-aware case mapping so non-ASCII +// identifiers (rare but legal) round-trip correctly. +type NamingService struct{} + +// ToPascalCase returns the identifier with its first rune upper-cased +// (the "exported" form). Non-letter first runes are returned unchanged +// — caller's job to ensure the result is a valid identifier. +// +// ToPascalCase("fooBar") == "FooBar" +// ToPascalCase("Bar") == "Bar" +// ToPascalCase("") == "" +func (s *NamingService) ToPascalCase(name string) string { + r, size := utf8.DecodeRuneInString(name) + if r == utf8.RuneError || size == 0 { + return name + } + upper := unicode.ToUpper(r) + if upper == r { + return name + } + return string(upper) + name[size:] +} + +// ToCamelCase returns the identifier with its first rune lower-cased +// (the "unexported" form). Mirror of ToPascalCase. +// +// ToCamelCase("FooBar") == "fooBar" +// ToCamelCase("foo") == "foo" +func (s *NamingService) ToCamelCase(name string) string { + r, size := utf8.DecodeRuneInString(name) + if r == utf8.RuneError || size == 0 { + return name + } + lower := unicode.ToLower(r) + if lower == r { + return name + } + return string(lower) + name[size:] +} + +// IsExported reports whether name is exported per Go's rule: the first +// rune is an uppercase Unicode letter. Empty strings are not exported. +// Mirrors go/token.IsExported but kept inside the service surface so +// recipes don't need a separate import. +func (s *NamingService) IsExported(name string) bool { + return token.IsExported(name) +} + +// IsValidIdentifier reports whether name parses as a Go identifier +// (token.IsIdentifier — letter or `_` followed by letters/digits/_, +// not a reserved keyword). +func (s *NamingService) IsValidIdentifier(name string) bool { + return token.IsIdentifier(name) +} + +// IsPredeclared reports whether name is one of Go's predeclared +// identifiers (built-in types, constants, or functions). Recipes +// generating new names should refuse to shadow these — even though +// shadowing is technically legal, it produces confusing code. +// +// The set is taken from the Go spec +// (https://go.dev/ref/spec#Predeclared_identifiers); it does NOT +// include reserved keywords (those are caught by IsValidIdentifier). +func (s *NamingService) IsPredeclared(name string) bool { + _, ok := predeclaredIdentifiers[name] + return ok +} + +// predeclaredIdentifiers is the universe-block name set per the Go +// spec. Update when the language adds new builtins (e.g. `min` / +// `max` / `clear` were added in 1.21). +var predeclaredIdentifiers = map[string]struct{}{ + // Types. + "any": {}, "bool": {}, "byte": {}, "comparable": {}, + "complex64": {}, "complex128": {}, "error": {}, + "float32": {}, "float64": {}, + "int": {}, "int8": {}, "int16": {}, "int32": {}, "int64": {}, + "rune": {}, "string": {}, + "uint": {}, "uint8": {}, "uint16": {}, "uint32": {}, "uint64": {}, "uintptr": {}, + + // Constants. + "true": {}, "false": {}, "iota": {}, + + // Zero value. + "nil": {}, + + // Functions. + "append": {}, "cap": {}, "clear": {}, "close": {}, "complex": {}, + "copy": {}, "delete": {}, "imag": {}, "len": {}, "make": {}, + "max": {}, "min": {}, "new": {}, "panic": {}, "print": {}, + "println": {}, "real": {}, "recover": {}, +} + +func init() { + recipe.RegisterService[*NamingService](func() any { return &NamingService{} }) +} diff --git a/rewrite-go/pkg/recipe/golang/order_imports.go b/rewrite-go/pkg/recipe/golang/order_imports.go new file mode 100644 index 00000000000..6fe186cf718 --- /dev/null +++ b/rewrite-go/pkg/recipe/golang/order_imports.go @@ -0,0 +1,85 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +import ( + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe/golang/internal" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// OrderImports normalizes the order of `import` lines into the +// `goimports -w` convention: stdlib first, third-party second, local +// last; within each group entries are sorted alphabetically by import +// path; a blank line separates non-empty groups. +// +// Local imports are detected via the sibling go.mod's +// `GoResolutionResult.ModulePath` marker (attached by `parseProject` and +// the Java parseWithProject path). Without a module marker, every +// non-stdlib import is treated as third-party. +// +// Idempotent: running OrderImports twice yields the same result as once. +type OrderImports struct { + recipe.Base +} + +func (r *OrderImports) Name() string { return "org.openrewrite.golang.OrderImports" } +func (r *OrderImports) DisplayName() string { return "Order imports" } +func (r *OrderImports) Description() string { + return "Sort `import` lines into stdlib / third-party / local groups. Within each group, entries are alphabetized; non-empty groups are separated by a blank line. Mirrors `goimports -w`. Local detection uses the sibling go.mod's module path." +} + +func (r *OrderImports) Editor() recipe.TreeVisitor { + return visitor.Init(&orderImportsVisitor{}) +} + +type orderImportsVisitor struct { + visitor.GoVisitor +} + +func (v *orderImportsVisitor) VisitCompilationUnit(cu *tree.CompilationUnit, p any) tree.J { + cu = v.GoVisitor.VisitCompilationUnit(cu, p).(*tree.CompilationUnit) + if cu.Imports == nil || len(cu.Imports.Elements) <= 1 { + return cu + } + modulePath := internal.FindModulePath(cu) + sorted := internal.SortByGroup(cu.Imports.Elements, modulePath) + if sameOrder(cu.Imports.Elements, sorted) { + return cu + } + c := *cu + imps := *c.Imports + imps.Elements = sorted + c.Imports = &imps + return &c +} + +func sameOrder(before, after []tree.RightPadded[*tree.Import]) bool { + if len(before) != len(after) { + return false + } + for i := range before { + if before[i].Element == nil || after[i].Element == nil { + return before[i].Element == after[i].Element + } + if before[i].Element.ID != after[i].Element.ID { + return false + } + } + return true +} diff --git a/rewrite-go/pkg/recipe/golang/remove_import.go b/rewrite-go/pkg/recipe/golang/remove_import.go new file mode 100644 index 00000000000..92d0da17dfa --- /dev/null +++ b/rewrite-go/pkg/recipe/golang/remove_import.go @@ -0,0 +1,74 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +import ( + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe/golang/internal" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// RemoveImport deletes a single import from a Go compilation unit. +// +// Matches by import path: any form (regular, aliased, dot, or blank) +// that imports PackagePath is removed. If the imports container becomes +// empty as a result, it is nil-ed out so the printer doesn't emit an +// empty `import ()` block. +// +// Removing an actively-referenced import will of course break the file; +// this recipe trusts the caller to know that. For unused-import cleanup, +// use RemoveUnusedImports. +type RemoveImport struct { + recipe.Base + PackagePath string +} + +func (r *RemoveImport) Name() string { return "org.openrewrite.golang.RemoveImport" } +func (r *RemoveImport) DisplayName() string { return "Remove import" } +func (r *RemoveImport) Description() string { + return "Remove an `import` statement from a Go compilation unit. Matches by import path; any form (regular, aliased, dot, blank) is removed." +} + +func (r *RemoveImport) Options() []recipe.OptionDescriptor { + return []recipe.OptionDescriptor{ + recipe.Option("packagePath", "Package path", "The import path to remove."). + WithExample("fmt").WithValue(r.PackagePath), + } +} + +func (r *RemoveImport) Editor() recipe.TreeVisitor { + return visitor.Init(&removeImportVisitor{cfg: r}) +} + +type removeImportVisitor struct { + visitor.GoVisitor + cfg *RemoveImport +} + +func (v *removeImportVisitor) VisitCompilationUnit(cu *tree.CompilationUnit, p any) tree.J { + cu = v.GoVisitor.VisitCompilationUnit(cu, p).(*tree.CompilationUnit) + if v.cfg.PackagePath == "" || cu.Imports == nil { + return cu + } + for _, rp := range cu.Imports.Elements { + if internal.ImportPath(rp.Element) == v.cfg.PackagePath { + cu = internal.RemoveFromBlock(cu, rp.Element) + } + } + return cu +} diff --git a/rewrite-go/pkg/recipe/golang/remove_unused_imports.go b/rewrite-go/pkg/recipe/golang/remove_unused_imports.go new file mode 100644 index 00000000000..f9e0da31c27 --- /dev/null +++ b/rewrite-go/pkg/recipe/golang/remove_unused_imports.go @@ -0,0 +1,80 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +import ( + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe/golang/internal" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// RemoveUnusedImports drops imports whose alias is never referenced by an +// identifier in the file. Mirrors `goimports -w` behavior. +// +// Per the C3 directive: a single identifier walker driven by Type.Owner +// computes the set of referenced packages. Aliased and dot imports +// resolve uniformly through the identifier's type FQN. Blank imports +// (`import _ "path"`) are preserved by semantic rule — they exist for +// their init() side-effects, not for any user-visible reference. +// +// Recipe authors who want to drop blank imports should use RemoveImport +// targeting the specific path. +type RemoveUnusedImports struct { + recipe.Base +} + +func (r *RemoveUnusedImports) Name() string { return "org.openrewrite.golang.RemoveUnusedImports" } +func (r *RemoveUnusedImports) DisplayName() string { return "Remove unused imports" } +func (r *RemoveUnusedImports) Description() string { + return "Remove imports for packages that are not referenced by any identifier in the file. Blank (`_`) imports are preserved." +} + +func (r *RemoveUnusedImports) Editor() recipe.TreeVisitor { + return visitor.Init(&removeUnusedImportsVisitor{}) +} + +type removeUnusedImportsVisitor struct { + visitor.GoVisitor +} + +func (v *removeUnusedImportsVisitor) VisitCompilationUnit(cu *tree.CompilationUnit, p any) tree.J { + cu = v.GoVisitor.VisitCompilationUnit(cu, p).(*tree.CompilationUnit) + if cu.Imports == nil || len(cu.Imports.Elements) == 0 { + return cu + } + refs := internal.ReferencedPackages(cu) + for _, rp := range cu.Imports.Elements { + imp := rp.Element + if imp == nil { + continue + } + // Blank imports stay — they exist for init() side-effects. + if internal.AliasName(imp) == "_" { + continue + } + // Dot imports also stay — referenced packages can't be tracked + // by FQN because dot-imported names enter the local scope. + if internal.AliasName(imp) == "." { + continue + } + if !refs[internal.ImportPath(imp)] { + cu = internal.RemoveFromBlock(cu, imp) + } + } + return cu +} diff --git a/rewrite-go/pkg/recipe/golang/rename_package.go b/rewrite-go/pkg/recipe/golang/rename_package.go new file mode 100644 index 00000000000..8d331f62e6b --- /dev/null +++ b/rewrite-go/pkg/recipe/golang/rename_package.go @@ -0,0 +1,172 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +import ( + "path" + "strings" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe/golang/internal" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// RenamePackage renames a Go package across a project. +// +// On every visited compilation unit: +// - If the file's `package` declaration matches the last segment of +// OldPackagePath AND the file is in the renamed package, the +// declaration is rewritten to the last segment of NewPackagePath. +// - Every import whose path equals OldPackagePath OR is under +// OldPackagePath as a sub-path is rewritten to the corresponding +// NewPackagePath path. Sub-paths preserve their suffix: +// `import "old/foo/sub"` becomes `import "new/foo/sub"` when +// renaming `old/foo` to `new/foo`. +// +// The recipe is idempotent: re-running it on a file that's already at +// the new package name is a no-op. +type RenamePackage struct { + recipe.Base + OldPackagePath string + NewPackagePath string +} + +func (r *RenamePackage) Name() string { return "org.openrewrite.golang.RenamePackage" } +func (r *RenamePackage) DisplayName() string { return "Rename package" } +func (r *RenamePackage) Description() string { + return "Rename a Go package across a project — rewrites the `package` declaration in files that own the package, and rewrites import paths in every file that references it." +} + +func (r *RenamePackage) Options() []recipe.OptionDescriptor { + return []recipe.OptionDescriptor{ + recipe.Option("oldPackagePath", "Old package path", "The fully qualified package path to rename."). + WithExample("github.com/old/foo").WithValue(r.OldPackagePath), + recipe.Option("newPackagePath", "New package path", "The fully qualified package path to use instead."). + WithExample("github.com/new/foo").WithValue(r.NewPackagePath), + } +} + +func (r *RenamePackage) Editor() recipe.TreeVisitor { + return visitor.Init(&renamePackageVisitor{cfg: r}) +} + +type renamePackageVisitor struct { + visitor.GoVisitor + cfg *RenamePackage +} + +func (v *renamePackageVisitor) VisitCompilationUnit(cu *tree.CompilationUnit, p any) tree.J { + cu = v.GoVisitor.VisitCompilationUnit(cu, p).(*tree.CompilationUnit) + if v.cfg.OldPackagePath == "" || v.cfg.NewPackagePath == "" { + return cu + } + + // Rewrite the package declaration when this file owns the renamed + // package. The name is only the last path segment, so we compare + // segment-by-segment. + oldName := path.Base(v.cfg.OldPackagePath) + newName := path.Base(v.cfg.NewPackagePath) + if cu.PackageDecl != nil && cu.PackageDecl.Element != nil && + cu.PackageDecl.Element.Name == oldName && + v.fileBelongsTo(cu, v.cfg.OldPackagePath) { + pkg := *cu.PackageDecl + ident := *pkg.Element + ident.Name = newName + pkg.Element = &ident + cu.PackageDecl = &pkg + } + + // Rewrite import paths. Match `OldPackagePath` exactly OR as a + // strict prefix (`OldPackagePath/...`); anything else is left + // alone. Aliased / blank / dot imports are rewritten the same way. + if cu.Imports != nil { + imps := *cu.Imports + out := make([]tree.RightPadded[*tree.Import], len(imps.Elements)) + for i, rp := range imps.Elements { + imp := rp.Element + oldPath := internal.ImportPath(imp) + newPath := rewritePath(oldPath, v.cfg.OldPackagePath, v.cfg.NewPackagePath) + if newPath == oldPath { + out[i] = rp + continue + } + imp = withImportPath(imp, newPath) + rp.Element = imp + out[i] = rp + } + imps.Elements = out + cu.Imports = &imps + } + + return cu +} + +// fileBelongsTo reports whether cu lives inside the package at +// candidatePath. The check uses the parsed `GoResolutionResult` marker's +// ModulePath plus the file's source-relative subdirectory: if the file's +// dir under the module equals candidatePath, the file belongs to it. +// +// When module context is unavailable (no `GoResolutionResult` marker), +// the file's location can't be confidently determined, so we report +// false — matching the conservative default: don't rewrite a `package` +// declaration based solely on a name match (the same name might be +// reused across unrelated directories). Tests that exercise this path +// must wrap their sources in `GoProject(...)` with a `GoMod(...)` +// sibling so the marker is propagated. +func (v *renamePackageVisitor) fileBelongsTo(cu *tree.CompilationUnit, candidatePath string) bool { + modulePath := internal.FindModulePath(cu) + if modulePath == "" { + return false + } + if !strings.HasPrefix(candidatePath, modulePath) { + return false + } + relCandidate := strings.TrimPrefix(strings.TrimPrefix(candidatePath, modulePath), "/") + relFile := path.Dir(cu.SourcePath) + if relFile == "." { + relFile = "" + } + return relFile == relCandidate +} + +func rewritePath(p, oldPath, newPath string) string { + if p == oldPath { + return newPath + } + if strings.HasPrefix(p, oldPath+"/") { + return newPath + strings.TrimPrefix(p, oldPath) + } + return p +} + +// withImportPath returns a copy of imp with its Qualid Literal source + +// value updated to the new import path. Preserves Prefix and Markers +// so the printer keeps the surrounding whitespace. +func withImportPath(imp *tree.Import, newPath string) *tree.Import { + if imp == nil { + return imp + } + c := *imp + if lit, ok := imp.Qualid.(*tree.Literal); ok { + ln := *lit + ln.Value = `"` + newPath + `"` + ln.Source = `"` + newPath + `"` + c.Qualid = &ln + } + return &c +} diff --git a/rewrite-go/pkg/recipe/golang/whitespace_validation_service.go b/rewrite-go/pkg/recipe/golang/whitespace_validation_service.go new file mode 100644 index 00000000000..771b2ac04c8 --- /dev/null +++ b/rewrite-go/pkg/recipe/golang/whitespace_validation_service.go @@ -0,0 +1,97 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package golang + +import ( + "fmt" + "strings" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// WhitespaceValidationService walks a tree and reports any Space whose +// Whitespace or Comment.Suffix contains non-whitespace characters, or +// any Comment.Text that doesn't begin with `//` or `/*`. Such content +// indicates a parser bug — the printer would otherwise re-emit +// non-whitespace as if it were spacing, silently corrupting the source. +// +// Recipes get one via recipe.Service: +// +// svc := recipe.Service[*golang.WhitespaceValidationService](cu) +// if errs := svc.Validate(cu); len(errs) > 0 { /* fail loudly */ } +// +// The test harness uses this via pkg/test, which delegates here so the +// validation logic has a single home and stays callable from recipes +// that want to self-check synthesized subtrees. +type WhitespaceValidationService struct{} + +// Validate walks the tree rooted at root and returns one descriptive +// error per offending Space / Comment. Returns nil when the tree is +// well-formed. +func (s *WhitespaceValidationService) Validate(root tree.Tree) []string { + v := visitor.Init(&whitespaceValidator{}) + v.Visit(root, nil) + return v.errs +} + +// IsValid is the boolean shorthand. Recipes that just want to assert +// "no parser corruption" can write `if !svc.IsValid(cu) { ... }`. +func (s *WhitespaceValidationService) IsValid(root tree.Tree) bool { + return len(s.Validate(root)) == 0 +} + +type whitespaceValidator struct { + visitor.GoVisitor + errs []string +} + +func (v *whitespaceValidator) VisitSpace(space tree.Space, p any) tree.Space { + if space.Whitespace != "" && !isWhitespaceOnly(space.Whitespace) { + v.errs = append(v.errs, fmt.Sprintf("Space.Whitespace contains non-whitespace: %q", truncateForError(space.Whitespace, 80))) + } + for i, c := range space.Comments { + if c.Suffix != "" && !isWhitespaceOnly(c.Suffix) { + v.errs = append(v.errs, fmt.Sprintf("Comment[%d].Suffix contains non-whitespace: %q", i, truncateForError(c.Suffix, 80))) + } + if c.Text != "" && !strings.HasPrefix(c.Text, "//") && !strings.HasPrefix(c.Text, "/*") { + v.errs = append(v.errs, fmt.Sprintf("Comment[%d].Text is not a comment: %q", i, truncateForError(c.Text, 80))) + } + } + return space +} + +func isWhitespaceOnly(s string) bool { + for _, c := range s { + if c != ' ' && c != '\t' && c != '\n' && c != '\r' { + return false + } + } + return true +} + +func truncateForError(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +} + +func init() { + recipe.RegisterService[*WhitespaceValidationService](func() any { return &WhitespaceValidationService{} }) +} diff --git a/rewrite-go/pkg/recipe/recipe.go b/rewrite-go/pkg/recipe/recipe.go index b12488898d8..6c971c30a43 100644 --- a/rewrite-go/pkg/recipe/recipe.go +++ b/rewrite-go/pkg/recipe/recipe.go @@ -17,12 +17,16 @@ package recipe import ( + "reflect" "time" "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" ) -// TreeVisitor can visit and transform a tree node. +// TreeVisitor can visit and transform a tree node. Visitors that need +// ancestor context (a "cursor") expose it as state — typically by +// embedding visitor.GoVisitor and calling its Cursor() accessor — to +// match the Java OpenRewrite visitor pattern. type TreeVisitor interface { Visit(t tree.Tree, p any) tree.Tree } @@ -53,6 +57,26 @@ type Recipe interface { // Options returns descriptors for this recipe's configurable options. Options() []OptionDescriptor + + // Preconditions returns sub-recipes that gate execution: this recipe only runs + // when all preconditions match. May be nil. + Preconditions() []Recipe + + // DataTables returns descriptors for any DataTable rows this recipe emits. + // May be nil. The runtime for writing rows is provided by the DataTable + // type and DataTableStore (separate from these descriptors). + DataTables() []DataTableDescriptor + + // Maintainers returns the people responsible for the recipe. May be nil. + Maintainers() []Maintainer + + // Contributors returns the people who have contributed code or design. + // May be nil. + Contributors() []Contributor + + // Examples returns before/after examples illustrating the recipe. + // May be nil. + Examples() []Example } // Base provides default implementations for optional Recipe methods. @@ -68,6 +92,11 @@ func (Base) EstimatedEffortPerOccurrence() time.Duration { return 5 * time.Minut func (Base) Editor() TreeVisitor { return nil } func (Base) RecipeList() []Recipe { return nil } func (Base) Options() []OptionDescriptor { return nil } +func (Base) Preconditions() []Recipe { return nil } +func (Base) DataTables() []DataTableDescriptor { return nil } +func (Base) Maintainers() []Maintainer { return nil } +func (Base) Contributors() []Contributor { return nil } +func (Base) Examples() []Example { return nil } // DelegatesTo marks a recipe that delegates entirely to a Java-side recipe. // When the Java host calls PrepareRecipe, the Go server includes the @@ -130,6 +159,73 @@ func (o OptionDescriptor) WithValid(v ...string) OptionDescriptor { o.Valid = v; // AsOptional marks the option as not required. func (o OptionDescriptor) AsOptional() OptionDescriptor { o.Required = false; return o } +// TypeName returns a Java-style type name derived from the option's Value. +// Used to populate the marketplace option `type` wire field. +func (o OptionDescriptor) TypeName() string { + if o.Value == nil { + return "String" + } + switch reflect.TypeOf(o.Value).Kind() { + case reflect.Bool: + return "Boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: + return "Integer" + case reflect.Int64: + return "Long" + case reflect.Float32, reflect.Float64: + return "Double" + case reflect.Slice, reflect.Array: + return "List" + default: + return "String" + } +} + +// DataTableDescriptor describes a DataTable a recipe emits. +type DataTableDescriptor struct { + Name string + DisplayName string + Description string + Columns []ColumnDescriptor +} + +// ColumnDescriptor describes a column within a DataTable. +type ColumnDescriptor struct { + Name string + DisplayName string + Description string + Type string // e.g. "String", "Integer", "Long" +} + +// Maintainer represents a recipe maintainer. +type Maintainer struct { + Name string + Email string + Logo string +} + +// Contributor represents someone who has contributed to the recipe. +type Contributor struct { + Name string + Email string + LineCount int +} + +// Example is a before/after example illustrating the recipe. +type Example struct { + Description string + Sources []ExampleSource + Parameters []string +} + +// ExampleSource is one before/after pair within an Example. +type ExampleSource struct { + Before string + After string + Path string + Language string +} + // RecipeDescriptor provides metadata about a recipe for display and serialization. type RecipeDescriptor struct { Name string @@ -139,22 +235,51 @@ type RecipeDescriptor struct { EstimatedEffortPerOccurrence time.Duration Options []OptionDescriptor RecipeList []RecipeDescriptor + Preconditions []RecipeDescriptor + DataTables []DataTableDescriptor + Maintainers []Maintainer + Contributors []Contributor + Examples []Example } -// Describe creates a RecipeDescriptor from a Recipe. +// Describe creates a RecipeDescriptor from a Recipe. Recursive descriptors +// (RecipeList, Preconditions) are protected against cycles via a visited set +// keyed by recipe name; if a recipe name re-appears in the descent, a stub +// descriptor with just the name and display name is returned in its place. func Describe(r Recipe) RecipeDescriptor { + return describe(r, map[string]bool{}) +} + +func describe(r Recipe, seen map[string]bool) RecipeDescriptor { + name := r.Name() + if seen[name] { + return RecipeDescriptor{Name: name, DisplayName: r.DisplayName()} + } + seen[name] = true + defer delete(seen, name) + desc := RecipeDescriptor{ - Name: r.Name(), + Name: name, DisplayName: r.DisplayName(), Description: r.Description(), Tags: r.Tags(), EstimatedEffortPerOccurrence: r.EstimatedEffortPerOccurrence(), Options: r.Options(), + DataTables: r.DataTables(), + Maintainers: r.Maintainers(), + Contributors: r.Contributors(), + Examples: r.Examples(), } if subs := r.RecipeList(); len(subs) > 0 { desc.RecipeList = make([]RecipeDescriptor, len(subs)) for i, sub := range subs { - desc.RecipeList[i] = Describe(sub) + desc.RecipeList[i] = describe(sub, seen) + } + } + if pres := r.Preconditions(); len(pres) > 0 { + desc.Preconditions = make([]RecipeDescriptor, len(pres)) + for i, pre := range pres { + desc.Preconditions[i] = describe(pre, seen) } } return desc diff --git a/rewrite-go/pkg/recipe/service.go b/rewrite-go/pkg/recipe/service.go new file mode 100644 index 00000000000..4cee054ff40 --- /dev/null +++ b/rewrite-go/pkg/recipe/service.go @@ -0,0 +1,85 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package recipe + +import ( + "reflect" + "sync" +) + +// Mirrors org.openrewrite.SourceFile.service(Class): an extension +// point for languages to expose helper services keyed by service type. +// Recipes call Service[T](sourceFile) (or use the package-specific +// helper in pkg/recipe/golang/service) to compose follow-up visitors +// without coupling to specific recipe constructors. +// +// This file holds the Go-side mechanism: a process-wide registry of +// service factories keyed by the service type's reflect.Type. Languages +// register their factories at init() time; callers look them up via +// Service[T](anything) — the lookup ignores the source-file argument +// for now, since Go-side services are stateless. Future stateful +// services (like a ModuleResolutionService keyed by go.mod) can keep +// the same surface and route via the source-file's markers. + +var ( + registryMu sync.RWMutex + registry = map[reflect.Type]func() any{} +) + +// RegisterService installs a factory that produces a service of type T. +// Languages call this from their package init(). +// +// func init() { +// recipe.RegisterService[ImportService](func() any { return &importService{} }) +// } +func RegisterService[T any](factory func() any) { + var zero T + registryMu.Lock() + defer registryMu.Unlock() + registry[reflect.TypeOf(&zero).Elem()] = factory +} + +// Service returns the registered service of type T. Panics with a +// descriptive error when no factory is registered — recipes that depend +// on a service should let the panic surface (it indicates a missing +// language activation). +// +// svc := recipe.Service[golangservice.ImportService](cu) +// v.DoAfterVisit(svc.AddImportVisitor("fmt", nil, false)) +// +// `sourceFile` is unused today; the parameter exists so the call site +// reads naturally and so future stateful services can route on it +// without breaking callers. +func Service[T any](sourceFile any) T { + _ = sourceFile + var zero T + t := reflect.TypeOf(&zero).Elem() + registryMu.RLock() + factory, ok := registry[t] + registryMu.RUnlock() + if !ok { + panic("recipe: no service registered for " + t.String() + + " — did the language's package init() run? " + + "Import the language package (e.g. pkg/recipe/golang) to register services.") + } + v, ok := factory().(T) + if !ok { + panic("recipe: service factory for " + t.String() + + " produced an instance that doesn't satisfy the requested type") + } + return v +} diff --git a/rewrite-go/pkg/rpc/annotation_rpc_test.go b/rewrite-go/pkg/rpc/annotation_rpc_test.go new file mode 100644 index 00000000000..c8193053a1b --- /dev/null +++ b/rewrite-go/pkg/rpc/annotation_rpc_test.go @@ -0,0 +1,153 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rpc + +import ( + "reflect" + "testing" + + "github.com/google/uuid" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +// roundTripNode serializes `before` via GoSender, then feeds the +// emitted RpcObjectData stream into a ReceiveQueue and reads it via +// GoReceiver. The `seed` argument is the empty node skeleton the +// receiver starts from (matching how a real session has a baseline +// from a prior GET_OBJECT cycle). +func roundTripNode(t *testing.T, before tree.Tree, seed tree.Tree) any { + t.Helper() + var messages []RpcObjectData + sendQ := NewSendQueue(1000, func(batch []RpcObjectData) { + messages = append(messages, batch...) + }, make(map[uintptr]int)) + NewGoSender().Visit(before, sendQ) + sendQ.Flush() + + delivered := false + recvQ := NewReceiveQueue(make(map[int]any), func() []RpcObjectData { + if delivered { + return nil + } + delivered = true + return messages + }) + return NewGoReceiver().Visit(seed, recvQ) +} + +func TestAnnotationRpcRoundTrip_BasicTag(t *testing.T) { + // Mirror of the `json:"name"` shape the parser will eventually emit + // for struct field tags. + annID := uuid.MustParse("aaaaaaaa-1111-2222-3333-aaaaaaaaaaaa") + typeID := uuid.MustParse("bbbbbbbb-1111-2222-3333-bbbbbbbbbbbb") + litID := uuid.MustParse("cccccccc-1111-2222-3333-cccccccccccc") + before := &tree.Annotation{ + ID: annID, + AnnotationType: &tree.Identifier{ID: typeID, Name: "json"}, + Arguments: &tree.Container[tree.Expression]{ + Elements: []tree.RightPadded[tree.Expression]{ + {Element: &tree.Literal{ + ID: litID, + Source: `"name"`, + Value: "name", + Kind: tree.StringLiteral, + }}, + }, + }, + } + + seed := &tree.Annotation{ID: annID} + got := roundTripNode(t, before, seed).(*tree.Annotation) + + if got.ID != annID { + t.Errorf("ID: got %s, want %s", got.ID, annID) + } + gotType, ok := got.AnnotationType.(*tree.Identifier) + if !ok { + t.Fatalf("AnnotationType: got %T, want *Identifier", got.AnnotationType) + } + if gotType.Name != "json" { + t.Errorf("AnnotationType.Name: got %q, want %q", gotType.Name, "json") + } + if got.Arguments == nil { + t.Fatal("Arguments: got nil, want non-nil") + } + if len(got.Arguments.Elements) != 1 { + t.Fatalf("Arguments.Elements: got %d, want 1", len(got.Arguments.Elements)) + } + gotLit, ok := got.Arguments.Elements[0].Element.(*tree.Literal) + if !ok { + t.Fatalf("Arguments[0]: got %T, want *Literal", got.Arguments.Elements[0].Element) + } + if gotLit.Source != `"name"` { + t.Errorf("Arguments[0].Source: got %q, want %q", gotLit.Source, `"name"`) + } + if v, _ := gotLit.Value.(string); v != "name" { + t.Errorf("Arguments[0].Value: got %v, want %q", gotLit.Value, "name") + } +} + +func TestAnnotationRpcRoundTrip_NoArguments(t *testing.T) { + // Bare-args case (Arguments == nil) — what `//go:noinline` will + // produce. Receiver must produce nil Arguments, not an empty + // Container. + annID := uuid.MustParse("dddddddd-1111-2222-3333-dddddddddddd") + typeID := uuid.MustParse("eeeeeeee-1111-2222-3333-eeeeeeeeeeee") + before := &tree.Annotation{ + ID: annID, + AnnotationType: &tree.Identifier{ID: typeID, Name: "go:noinline"}, + } + + seed := &tree.Annotation{ID: annID} + got := roundTripNode(t, before, seed).(*tree.Annotation) + + if got.Arguments != nil { + t.Errorf("Arguments: got %+v, want nil", got.Arguments) + } + if !reflect.DeepEqual(got.AnnotationType.(*tree.Identifier).Name, "go:noinline") { + t.Errorf("AnnotationType.Name: got %q, want %q", got.AnnotationType.(*tree.Identifier).Name, "go:noinline") + } +} + +func TestAnnotationRpcRoundTrip_PrefixPreserved(t *testing.T) { + annID := uuid.MustParse("ffffffff-1111-2222-3333-ffffffffffff") + typeID := uuid.MustParse("00000000-1111-2222-3333-000000000000") + litID := uuid.MustParse("11111111-aaaa-bbbb-cccc-111111111111") + before := &tree.Annotation{ + ID: annID, + Prefix: tree.Space{Whitespace: " "}, + AnnotationType: &tree.Identifier{ID: typeID, Name: "validate"}, + Arguments: &tree.Container[tree.Expression]{ + Elements: []tree.RightPadded[tree.Expression]{ + {Element: &tree.Literal{ + ID: litID, + Source: `"required"`, + Value: "required", + Kind: tree.StringLiteral, + }}, + }, + }, + } + + seed := &tree.Annotation{ID: annID} + got := roundTripNode(t, before, seed).(*tree.Annotation) + + if got.Prefix.Whitespace != " " { + t.Errorf("Prefix.Whitespace: got %q, want %q", got.Prefix.Whitespace, " ") + } +} diff --git a/rewrite-go/pkg/rpc/go_receiver.go b/rewrite-go/pkg/rpc/go_receiver.go index cdf4ad5b825..6474c3c98e0 100644 --- a/rewrite-go/pkg/rpc/go_receiver.go +++ b/rewrite-go/pkg/rpc/go_receiver.go @@ -19,224 +19,52 @@ package rpc import ( "github.com/google/uuid" "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" ) -// GoReceiver deserializes Go AST nodes from the receive queue. -// Handles G (Go-specific) nodes and delegates J nodes to JavaReceiver. +// GoReceiver deserializes Go AST nodes via the visitor pattern. +// Mirrors org.openrewrite.golang.internal.rpc.GolangReceiver, which +// extends JavaReceiver, which extends JavaVisitor. +// +// GoReceiver embeds JavaReceiver to inherit J-node Visit overrides +// and the PreVisit hook (id/prefix/markers), and adds VisitX +// overrides for G-specific node types. Self is set by visitor.Init +// so the framework's type-switch dispatch routes to the most-derived +// override. type GoReceiver struct { - java JavaReceiver + JavaReceiver } -// NewGoReceiver creates a GoReceiver with its JavaReceiver properly wired. +// NewGoReceiver creates a GoReceiver ready to deserialize trees. func NewGoReceiver() *GoReceiver { gr := &GoReceiver{} - gr.java = JavaReceiver{ - typeReceiver: NewJavaTypeReceiver(), - parent: gr, - } - return gr + gr.typeReceiver = NewJavaTypeReceiver() + return visitor.Init(gr) } -// Visit dispatches to the appropriate receive method based on node type. -func (r *GoReceiver) Visit(node any, q *ReceiveQueue) any { - if node == nil { +// Visit overrides the framework dispatch to special-case ParseError — +// it isn't a J node, has no Prefix/Markers, and uses its own codec. +// All other tree types fall through to the framework's switch. +func (r *GoReceiver) Visit(t tree.Tree, p any) tree.Tree { + if t == nil { return nil } - - // ParseError has its own codec — handle before preVisit (no prefix field) - if pe, ok := node.(*tree.ParseError); ok { + if pe, ok := t.(*tree.ParseError); ok { c := *pe - return r.receiveParseError(&c, q) - } - - // preVisit: receive ID, prefix, markers - node = r.preVisit(node, q) - - switch v := node.(type) { - // G nodes - case *tree.CompilationUnit: - return r.receiveCompilationUnit(v, q) - case *tree.GoStmt: - return r.receiveGoStmt(v, q) - case *tree.Defer: - return r.receiveDefer(v, q) - case *tree.Send: - return r.receiveSend(v, q) - case *tree.Goto: - return r.receiveGoto(v, q) - case *tree.Fallthrough: - return v - case *tree.Composite: - return r.receiveComposite(v, q) - case *tree.KeyValue: - return r.receiveKeyValue(v, q) - case *tree.Slice: - return r.receiveSlice(v, q) - case *tree.MapType: - return r.receiveMapType(v, q) - case *tree.StatementExpression: - return r.receiveStatementExpression(v, q) - case *tree.PointerType: - return r.receivePointerType(v, q) - case *tree.Channel: - return r.receiveChannel(v, q) - case *tree.FuncType: - return r.receiveFuncType(v, q) - case *tree.StructType: - return r.receiveStructType(v, q) - case *tree.InterfaceType: - return r.receiveInterfaceType(v, q) - case *tree.TypeList: - return r.receiveTypeList(v, q) - case *tree.TypeDecl: - return r.receiveTypeDecl(v, q) - case *tree.MultiAssignment: - return r.receiveMultiAssignment(v, q) - case *tree.CommClause: - return r.receiveCommClause(v, q) - case *tree.IndexList: - return r.receiveIndexList(v, q) - - default: - // Delegate all J nodes to JavaReceiver - return r.java.visitJ(node, q) + return r.receiveParseError(&c, p.(*ReceiveQueue)) } + return r.GoVisitor.Visit(t, p) } -func (r *GoReceiver) preVisit(node any, q *ReceiveQueue) any { - // shallow copy to avoid mutating remoteObjects baseline - switch n := node.(type) { - // G nodes - case *tree.CompilationUnit: - c := *n; node = &c - case *tree.GoStmt: - c := *n; node = &c - case *tree.Defer: - c := *n; node = &c - case *tree.Send: - c := *n; node = &c - case *tree.Goto: - c := *n; node = &c - case *tree.Fallthrough: - c := *n; node = &c - case *tree.Composite: - c := *n; node = &c - case *tree.KeyValue: - c := *n; node = &c - case *tree.Slice: - c := *n; node = &c - case *tree.MapType: - c := *n; node = &c - case *tree.StatementExpression: - c := *n; node = &c - case *tree.PointerType: - c := *n; node = &c - case *tree.Channel: - c := *n; node = &c - case *tree.FuncType: - c := *n; node = &c - case *tree.StructType: - c := *n; node = &c - case *tree.InterfaceType: - c := *n; node = &c - case *tree.TypeList: - c := *n; node = &c - case *tree.TypeDecl: - c := *n; node = &c - case *tree.MultiAssignment: - c := *n; node = &c - case *tree.CommClause: - c := *n; node = &c - case *tree.IndexList: - c := *n; node = &c - // J nodes - case *tree.Identifier: - c := *n; node = &c - case *tree.Literal: - c := *n; node = &c - case *tree.Binary: - c := *n; node = &c - case *tree.Block: - c := *n; node = &c - case *tree.Empty: - c := *n; node = &c - case *tree.Unary: - c := *n; node = &c - case *tree.FieldAccess: - c := *n; node = &c - case *tree.MethodInvocation: - c := *n; node = &c - case *tree.Assignment: - c := *n; node = &c - case *tree.AssignmentOperation: - c := *n; node = &c - case *tree.MethodDeclaration: - c := *n; node = &c - case *tree.VariableDeclarations: - c := *n; node = &c - case *tree.VariableDeclarator: - c := *n; node = &c - case *tree.Return: - c := *n; node = &c - case *tree.If: - c := *n; node = &c - case *tree.Else: - c := *n; node = &c - case *tree.ForLoop: - c := *n; node = &c - case *tree.ForControl: - c := *n; node = &c - case *tree.ForEachLoop: - c := *n; node = &c - case *tree.ForEachControl: - c := *n; node = &c - case *tree.Switch: - c := *n; node = &c - case *tree.Case: - c := *n; node = &c - case *tree.Break: - c := *n; node = &c - case *tree.Continue: - c := *n; node = &c - case *tree.Label: - c := *n; node = &c - case *tree.ArrayType: - c := *n; node = &c - case *tree.ArrayAccess: - c := *n; node = &c - case *tree.ArrayDimension: - c := *n; node = &c - case *tree.Parentheses: - c := *n; node = &c - case *tree.TypeCast: - c := *n; node = &c - case *tree.ControlParentheses: - c := *n; node = &c - case *tree.Import: - c := *n; node = &c - } - - // ID - q.Receive(extractID(node), nil) - // Prefix - if result := q.Receive(nodePrefix(node), func(v any) any { - return receiveSpace(v.(tree.Space), q) - }); result != nil { - setPrefix(node, result.(tree.Space)) - } - // Markers - if result := q.Receive(nodeMarkers(node), func(v any) any { - return receiveMarkersCodec(q, v.(tree.Markers)) - }); result != nil { - setMarkers(node, result.(tree.Markers)) - } - return node -} // --- G nodes --- // receiveParseError deserializes a ParseError matching Java's ParseError.rpcReceive field order: -// id, markers, sourcePath, charsetName, charsetBomMarked, checksum, fileAttributes, text +// id, markers, sourcePath, charsetName, charsetBomMarked, checksum, fileAttributes, text. +// +// ParseError isn't a J node — no Prefix/Markers handling via PreVisit. +// Special-cased at the top of Visit() instead of dispatched through +// the framework switch. func (r *GoReceiver) receiveParseError(pe *tree.ParseError, q *ReceiveQueue) *tree.ParseError { idStr := receiveScalar[string](q, pe.Ident.String()) if idStr != "" { @@ -254,7 +82,8 @@ func (r *GoReceiver) receiveParseError(pe *tree.ParseError, q *ReceiveQueue) *tr return pe } -func (r *GoReceiver) receiveCompilationUnit(cu *tree.CompilationUnit, q *ReceiveQueue) *tree.CompilationUnit { +func (r *GoReceiver) VisitCompilationUnit(cu *tree.CompilationUnit, p any) tree.J { + q := p.(*ReceiveQueue) c := *cu // shallow copy to avoid mutating remoteObjects baseline cu = &c cu.SourcePath = receiveScalar[string](q, cu.SourcePath) @@ -320,30 +149,33 @@ func (r *GoReceiver) receiveCompilationUnit(cu *tree.CompilationUnit, q *Receive return cu } -func (r *GoReceiver) receiveGoStmt(gs *tree.GoStmt, q *ReceiveQueue) *tree.GoStmt { +func (r *GoReceiver) VisitGoStmt(gs *tree.GoStmt, p any) tree.J { + q := p.(*ReceiveQueue) c := *gs // shallow copy to avoid mutating remoteObjects baseline gs = &c - result := q.Receive(gs.Expr, func(v any) any { return r.Visit(v, q) }) + result := q.Receive(gs.Expr, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { gs.Expr = result.(tree.Expression) } return gs } -func (r *GoReceiver) receiveDefer(d *tree.Defer, q *ReceiveQueue) *tree.Defer { +func (r *GoReceiver) VisitDefer(d *tree.Defer, p any) tree.J { + q := p.(*ReceiveQueue) c := *d // shallow copy to avoid mutating remoteObjects baseline d = &c - result := q.Receive(d.Expr, func(v any) any { return r.Visit(v, q) }) + result := q.Receive(d.Expr, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { d.Expr = result.(tree.Expression) } return d } -func (r *GoReceiver) receiveSend(sn *tree.Send, q *ReceiveQueue) *tree.Send { +func (r *GoReceiver) VisitSend(sn *tree.Send, p any) tree.J { + q := p.(*ReceiveQueue) c := *sn // shallow copy to avoid mutating remoteObjects baseline sn = &c - result := q.Receive(sn.Channel, func(v any) any { return r.Visit(v, q) }) + result := q.Receive(sn.Channel, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { sn.Channel = result.(tree.Expression) } @@ -353,20 +185,22 @@ func (r *GoReceiver) receiveSend(sn *tree.Send, q *ReceiveQueue) *tree.Send { return sn } -func (r *GoReceiver) receiveGoto(g *tree.Goto, q *ReceiveQueue) *tree.Goto { +func (r *GoReceiver) VisitGoto(g *tree.Goto, p any) tree.J { + q := p.(*ReceiveQueue) c := *g // shallow copy to avoid mutating remoteObjects baseline g = &c - result := q.Receive(g.Label, func(v any) any { return r.Visit(v, q) }) + result := q.Receive(g.Label, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { g.Label = result.(*tree.Identifier) } return g } -func (r *GoReceiver) receiveComposite(comp *tree.Composite, q *ReceiveQueue) *tree.Composite { +func (r *GoReceiver) VisitComposite(comp *tree.Composite, p any) tree.J { + q := p.(*ReceiveQueue) c := *comp // shallow copy to avoid mutating remoteObjects baseline comp = &c - result := q.Receive(comp.TypeExpr, func(v any) any { return r.Visit(v, q) }) + result := q.Receive(comp.TypeExpr, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { comp.TypeExpr = result.(tree.Expression) } @@ -376,10 +210,11 @@ func (r *GoReceiver) receiveComposite(comp *tree.Composite, q *ReceiveQueue) *tr return comp } -func (r *GoReceiver) receiveKeyValue(kv *tree.KeyValue, q *ReceiveQueue) *tree.KeyValue { +func (r *GoReceiver) VisitKeyValue(kv *tree.KeyValue, p any) tree.J { + q := p.(*ReceiveQueue) c := *kv // shallow copy to avoid mutating remoteObjects baseline kv = &c - result := q.Receive(kv.Key, func(v any) any { return r.Visit(v, q) }) + result := q.Receive(kv.Key, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { kv.Key = result.(tree.Expression) } @@ -389,10 +224,11 @@ func (r *GoReceiver) receiveKeyValue(kv *tree.KeyValue, q *ReceiveQueue) *tree.K return kv } -func (r *GoReceiver) receiveSlice(sl *tree.Slice, q *ReceiveQueue) *tree.Slice { +func (r *GoReceiver) VisitSlice(sl *tree.Slice, p any) tree.J { + q := p.(*ReceiveQueue) c := *sl // shallow copy to avoid mutating remoteObjects baseline sl = &c - result := q.Receive(sl.Indexed, func(v any) any { return r.Visit(v, q) }) + result := q.Receive(sl.Indexed, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { sl.Indexed = result.(tree.Expression) } @@ -405,7 +241,7 @@ func (r *GoReceiver) receiveSlice(sl *tree.Slice, q *ReceiveQueue) *tree.Slice { if result := q.Receive(sl.High, func(v any) any { return receiveRightPadded(r, q, v) }); result != nil { sl.High = coerceToExpressionRP(result) } - max := q.Receive(sl.Max, func(v any) any { return r.Visit(v, q) }) + max := q.Receive(sl.Max, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if max != nil { sl.Max = max.(tree.Expression) } @@ -415,7 +251,8 @@ func (r *GoReceiver) receiveSlice(sl *tree.Slice, q *ReceiveQueue) *tree.Slice { return sl } -func (r *GoReceiver) receiveMapType(mt *tree.MapType, q *ReceiveQueue) *tree.MapType { +func (r *GoReceiver) VisitMapType(mt *tree.MapType, p any) tree.J { + q := p.(*ReceiveQueue) c := *mt // shallow copy to avoid mutating remoteObjects baseline mt = &c if result := q.Receive(mt.OpenBracket, func(v any) any { return receiveSpace(v.(tree.Space), q) }); result != nil { @@ -424,34 +261,37 @@ func (r *GoReceiver) receiveMapType(mt *tree.MapType, q *ReceiveQueue) *tree.Map if result := q.Receive(mt.Key, func(v any) any { return receiveRightPadded(r, q, v) }); result != nil { mt.Key = coerceToExpressionRP(result) } - result := q.Receive(mt.Value, func(v any) any { return r.Visit(v, q) }) + result := q.Receive(mt.Value, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { mt.Value = result.(tree.Expression) } return mt } -func (r *GoReceiver) receiveStatementExpression(se *tree.StatementExpression, q *ReceiveQueue) *tree.StatementExpression { +func (r *GoReceiver) VisitStatementExpression(se *tree.StatementExpression, p any) tree.J { + q := p.(*ReceiveQueue) c := *se se = &c - result := q.Receive(se.Statement, func(v any) any { return r.Visit(v, q) }) + result := q.Receive(se.Statement, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { se.Statement = result.(tree.Statement) } return se } -func (r *GoReceiver) receivePointerType(pt *tree.PointerType, q *ReceiveQueue) *tree.PointerType { +func (r *GoReceiver) VisitPointerType(pt *tree.PointerType, p any) tree.J { + q := p.(*ReceiveQueue) c := *pt pt = &c - result := q.Receive(pt.Elem, func(v any) any { return r.Visit(v, q) }) + result := q.Receive(pt.Elem, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { pt.Elem = result.(tree.Expression) } return pt } -func (r *GoReceiver) receiveChannel(ch *tree.Channel, q *ReceiveQueue) *tree.Channel { +func (r *GoReceiver) VisitChannel(ch *tree.Channel, p any) tree.J { + q := p.(*ReceiveQueue) c := *ch // shallow copy to avoid mutating remoteObjects baseline ch = &c dirStr := receiveScalar[string](q, "") @@ -463,47 +303,51 @@ func (r *GoReceiver) receiveChannel(ch *tree.Channel, q *ReceiveQueue) *tree.Cha case "RECV_ONLY": ch.Dir = tree.ChanRecvOnly } - result := q.Receive(ch.Value, func(v any) any { return r.Visit(v, q) }) + result := q.Receive(ch.Value, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { ch.Value = result.(tree.Expression) } return ch } -func (r *GoReceiver) receiveFuncType(ft *tree.FuncType, q *ReceiveQueue) *tree.FuncType { +func (r *GoReceiver) VisitFuncType(ft *tree.FuncType, p any) tree.J { + q := p.(*ReceiveQueue) c := *ft // shallow copy to avoid mutating remoteObjects baseline ft = &c if result := q.Receive(ft.Parameters, func(v any) any { return receiveContainerAs(r, q, v, ContainerStatement) }); result != nil { ft.Parameters = result.(tree.Container[tree.Statement]) } - result := q.Receive(ft.ReturnType, func(v any) any { return r.Visit(v, q) }) + result := q.Receive(ft.ReturnType, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { ft.ReturnType = result.(tree.Expression) } return ft } -func (r *GoReceiver) receiveStructType(st *tree.StructType, q *ReceiveQueue) *tree.StructType { +func (r *GoReceiver) VisitStructType(st *tree.StructType, p any) tree.J { + q := p.(*ReceiveQueue) c := *st // shallow copy to avoid mutating remoteObjects baseline st = &c - result := q.Receive(st.Body, func(v any) any { return r.Visit(v, q) }) + result := q.Receive(st.Body, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { st.Body = result.(*tree.Block) } return st } -func (r *GoReceiver) receiveInterfaceType(it *tree.InterfaceType, q *ReceiveQueue) *tree.InterfaceType { +func (r *GoReceiver) VisitInterfaceType(it *tree.InterfaceType, p any) tree.J { + q := p.(*ReceiveQueue) c := *it // shallow copy to avoid mutating remoteObjects baseline it = &c - result := q.Receive(it.Body, func(v any) any { return r.Visit(v, q) }) + result := q.Receive(it.Body, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { it.Body = result.(*tree.Block) } return it } -func (r *GoReceiver) receiveTypeList(tl *tree.TypeList, q *ReceiveQueue) *tree.TypeList { +func (r *GoReceiver) VisitTypeList(tl *tree.TypeList, p any) tree.J { + q := p.(*ReceiveQueue) c := *tl // shallow copy to avoid mutating remoteObjects baseline tl = &c if result := q.Receive(tl.Types, func(v any) any { return receiveContainerAs(r, q, v, ContainerStatement) }); result != nil { @@ -512,10 +356,25 @@ func (r *GoReceiver) receiveTypeList(tl *tree.TypeList, q *ReceiveQueue) *tree.T return tl } -func (r *GoReceiver) receiveTypeDecl(td *tree.TypeDecl, q *ReceiveQueue) *tree.TypeDecl { +func (r *GoReceiver) VisitTypeDecl(td *tree.TypeDecl, p any) tree.J { + q := p.(*ReceiveQueue) c := *td // shallow copy to avoid mutating remoteObjects baseline td = &c - result := q.Receive(td.Name, func(v any) any { return r.Visit(v, q) }) + // leadingAnnotations + beforeAnns := make([]any, len(td.LeadingAnnotations)) + for i, a := range td.LeadingAnnotations { + beforeAnns[i] = a + } + afterAnns := q.ReceiveList(beforeAnns, func(v any) any { return r.Visit(v.(tree.Tree), q) }) + if afterAnns != nil { + td.LeadingAnnotations = make([]*tree.Annotation, 0, len(afterAnns)) + for _, a := range afterAnns { + if a != nil { + td.LeadingAnnotations = append(td.LeadingAnnotations, a.(*tree.Annotation)) + } + } + } + result := q.Receive(td.Name, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { td.Name = result.(*tree.Identifier) } @@ -529,7 +388,7 @@ func (r *GoReceiver) receiveTypeDecl(td *tree.TypeDecl, q *ReceiveQueue) *tree.T } else { td.Assign = nil } - defResult := q.Receive(td.Definition, func(v any) any { return r.Visit(v, q) }) + defResult := q.Receive(td.Definition, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if defResult != nil { td.Definition = defResult.(tree.Expression) } @@ -546,7 +405,8 @@ func (r *GoReceiver) receiveTypeDecl(td *tree.TypeDecl, q *ReceiveQueue) *tree.T return td } -func (r *GoReceiver) receiveMultiAssignment(ma *tree.MultiAssignment, q *ReceiveQueue) *tree.MultiAssignment { +func (r *GoReceiver) VisitMultiAssignment(ma *tree.MultiAssignment, p any) tree.J { + q := p.(*ReceiveQueue) c := *ma // shallow copy to avoid mutating remoteObjects baseline ma = &c // Variables @@ -580,10 +440,11 @@ func (r *GoReceiver) receiveMultiAssignment(ma *tree.MultiAssignment, q *Receive return ma } -func (r *GoReceiver) receiveCommClause(cc *tree.CommClause, q *ReceiveQueue) *tree.CommClause { +func (r *GoReceiver) VisitCommClause(cc *tree.CommClause, p any) tree.J { + q := p.(*ReceiveQueue) c := *cc // shallow copy to avoid mutating remoteObjects baseline cc = &c - result := q.Receive(cc.Comm, func(v any) any { return r.Visit(v, q) }) + result := q.Receive(cc.Comm, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { cc.Comm = result.(tree.Statement) } @@ -605,10 +466,11 @@ func (r *GoReceiver) receiveCommClause(cc *tree.CommClause, q *ReceiveQueue) *tr return cc } -func (r *GoReceiver) receiveIndexList(il *tree.IndexList, q *ReceiveQueue) *tree.IndexList { +func (r *GoReceiver) VisitIndexList(il *tree.IndexList, p any) tree.J { + q := p.(*ReceiveQueue) c := *il // shallow copy to avoid mutating remoteObjects baseline il = &c - result := q.Receive(il.Target, func(v any) any { return r.Visit(v, q) }) + result := q.Receive(il.Target, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { il.Target = result.(tree.Expression) } diff --git a/rewrite-go/pkg/rpc/go_resolution_result_codec.go b/rewrite-go/pkg/rpc/go_resolution_result_codec.go new file mode 100644 index 00000000000..7ef43250da1 --- /dev/null +++ b/rewrite-go/pkg/rpc/go_resolution_result_codec.go @@ -0,0 +1,294 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rpc + +import ( + "github.com/google/uuid" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +// sendGoResolutionResult mirrors Java's +// org.openrewrite.golang.marker.GoResolutionResult#rpcSend. +// +// Field order MUST match the Java side exactly, otherwise the cross- +// language round-trip queue desyncs: +// +// 1. id (UUID string) +// 2. modulePath (String) +// 3. goVersion (String, nullable) +// 4. toolchain (String, nullable) +// 5. path (String) +// 6. requires (List, ref-by-key) +// 7. replaces (List, ref-by-key) +// 8. excludes (List, ref-by-key) +// 9. retracts (List, ref-by-key) +// 10. resolvedDependencies (List, ref-by-key) +// +// Each list element invokes its own rpcSend on the Java side; we mirror +// the same field order in the per-element onChange callback. +func sendGoResolutionResult(m tree.GoResolutionResult, q *SendQueue) { + q.GetAndSend(m, func(x any) any { return x.(tree.GoResolutionResult).Ident.String() }, nil) + q.GetAndSend(m, func(x any) any { return x.(tree.GoResolutionResult).ModulePath }, nil) + q.GetAndSend(m, func(x any) any { return emptyAsNil(x.(tree.GoResolutionResult).GoVersion) }, nil) + q.GetAndSend(m, func(x any) any { return emptyAsNil(x.(tree.GoResolutionResult).Toolchain) }, nil) + q.GetAndSend(m, func(x any) any { return x.(tree.GoResolutionResult).Path }, nil) + + q.GetAndSendListAsRef(m, + func(x any) []any { return requireSlice(x.(tree.GoResolutionResult).Requires) }, + func(x any) any { + r := x.(tree.GoRequire) + return r.ModulePath + "@" + r.Version + }, + func(x any) { + r := x.(tree.GoRequire) + q.GetAndSend(r, func(y any) any { return y.(tree.GoRequire).ModulePath }, nil) + q.GetAndSend(r, func(y any) any { return y.(tree.GoRequire).Version }, nil) + q.GetAndSend(r, func(y any) any { return y.(tree.GoRequire).Indirect }, nil) + }) + + q.GetAndSendListAsRef(m, + func(x any) []any { return replaceSlice(x.(tree.GoResolutionResult).Replaces) }, + func(x any) any { + r := x.(tree.GoReplace) + return r.OldPath + "@" + r.OldVersion + "=>" + r.NewPath + "@" + r.NewVersion + }, + func(x any) { + r := x.(tree.GoReplace) + q.GetAndSend(r, func(y any) any { return y.(tree.GoReplace).OldPath }, nil) + q.GetAndSend(r, func(y any) any { return emptyAsNil(y.(tree.GoReplace).OldVersion) }, nil) + q.GetAndSend(r, func(y any) any { return y.(tree.GoReplace).NewPath }, nil) + q.GetAndSend(r, func(y any) any { return emptyAsNil(y.(tree.GoReplace).NewVersion) }, nil) + }) + + q.GetAndSendListAsRef(m, + func(x any) []any { return excludeSlice(x.(tree.GoResolutionResult).Excludes) }, + func(x any) any { + e := x.(tree.GoExclude) + return e.ModulePath + "@" + e.Version + }, + func(x any) { + e := x.(tree.GoExclude) + q.GetAndSend(e, func(y any) any { return y.(tree.GoExclude).ModulePath }, nil) + q.GetAndSend(e, func(y any) any { return y.(tree.GoExclude).Version }, nil) + }) + + q.GetAndSendListAsRef(m, + func(x any) []any { return retractSlice(x.(tree.GoResolutionResult).Retracts) }, + func(x any) any { return x.(tree.GoRetract).VersionRange }, + func(x any) { + r := x.(tree.GoRetract) + q.GetAndSend(r, func(y any) any { return y.(tree.GoRetract).VersionRange }, nil) + q.GetAndSend(r, func(y any) any { return emptyAsNil(y.(tree.GoRetract).Rationale) }, nil) + }) + + q.GetAndSendListAsRef(m, + func(x any) []any { return resolvedSlice(x.(tree.GoResolutionResult).ResolvedDependencies) }, + func(x any) any { + d := x.(tree.GoResolvedDependency) + return d.ModulePath + "@" + d.Version + }, + func(x any) { + d := x.(tree.GoResolvedDependency) + q.GetAndSend(d, func(y any) any { return y.(tree.GoResolvedDependency).ModulePath }, nil) + q.GetAndSend(d, func(y any) any { return y.(tree.GoResolvedDependency).Version }, nil) + q.GetAndSend(d, func(y any) any { return emptyAsNil(y.(tree.GoResolvedDependency).ModuleHash) }, nil) + q.GetAndSend(d, func(y any) any { return emptyAsNil(y.(tree.GoResolvedDependency).GoModHash) }, nil) + }) +} + +// receiveGoResolutionResult mirrors Java's +// org.openrewrite.golang.marker.GoResolutionResult#rpcReceive. +func receiveGoResolutionResult(before tree.GoResolutionResult, q *ReceiveQueue) tree.GoResolutionResult { + idStr := receiveScalar[string](q, before.Ident.String()) + if idStr != "" { + if parsed, err := uuid.Parse(idStr); err == nil { + before.Ident = parsed + } + } + before.ModulePath = receiveScalar[string](q, before.ModulePath) + before.GoVersion = receiveNullableString(q, before.GoVersion) + before.Toolchain = receiveNullableString(q, before.Toolchain) + before.Path = receiveScalar[string](q, before.Path) + + before.Requires = recvRequires(q, before.Requires) + before.Replaces = recvReplaces(q, before.Replaces) + before.Excludes = recvExcludes(q, before.Excludes) + before.Retracts = recvRetracts(q, before.Retracts) + before.ResolvedDependencies = recvResolvedDeps(q, before.ResolvedDependencies) + return before +} + +func recvRequires(q *ReceiveQueue, before []tree.GoRequire) []tree.GoRequire { + beforeAny := requireSlice(before) + afterAny := q.ReceiveList(beforeAny, func(v any) any { + r := v.(tree.GoRequire) + r.ModulePath = receiveScalar[string](q, r.ModulePath) + r.Version = receiveScalar[string](q, r.Version) + r.Indirect = receiveScalar[bool](q, r.Indirect) + return r + }) + if afterAny == nil { + return nil + } + out := make([]tree.GoRequire, len(afterAny)) + for i, v := range afterAny { + out[i] = v.(tree.GoRequire) + } + return out +} + +func recvReplaces(q *ReceiveQueue, before []tree.GoReplace) []tree.GoReplace { + beforeAny := replaceSlice(before) + afterAny := q.ReceiveList(beforeAny, func(v any) any { + r := v.(tree.GoReplace) + r.OldPath = receiveScalar[string](q, r.OldPath) + r.OldVersion = receiveNullableString(q, r.OldVersion) + r.NewPath = receiveScalar[string](q, r.NewPath) + r.NewVersion = receiveNullableString(q, r.NewVersion) + return r + }) + if afterAny == nil { + return nil + } + out := make([]tree.GoReplace, len(afterAny)) + for i, v := range afterAny { + out[i] = v.(tree.GoReplace) + } + return out +} + +func recvExcludes(q *ReceiveQueue, before []tree.GoExclude) []tree.GoExclude { + beforeAny := excludeSlice(before) + afterAny := q.ReceiveList(beforeAny, func(v any) any { + e := v.(tree.GoExclude) + e.ModulePath = receiveScalar[string](q, e.ModulePath) + e.Version = receiveScalar[string](q, e.Version) + return e + }) + if afterAny == nil { + return nil + } + out := make([]tree.GoExclude, len(afterAny)) + for i, v := range afterAny { + out[i] = v.(tree.GoExclude) + } + return out +} + +func recvRetracts(q *ReceiveQueue, before []tree.GoRetract) []tree.GoRetract { + beforeAny := retractSlice(before) + afterAny := q.ReceiveList(beforeAny, func(v any) any { + r := v.(tree.GoRetract) + r.VersionRange = receiveScalar[string](q, r.VersionRange) + r.Rationale = receiveNullableString(q, r.Rationale) + return r + }) + if afterAny == nil { + return nil + } + out := make([]tree.GoRetract, len(afterAny)) + for i, v := range afterAny { + out[i] = v.(tree.GoRetract) + } + return out +} + +func recvResolvedDeps(q *ReceiveQueue, before []tree.GoResolvedDependency) []tree.GoResolvedDependency { + beforeAny := resolvedSlice(before) + afterAny := q.ReceiveList(beforeAny, func(v any) any { + d := v.(tree.GoResolvedDependency) + d.ModulePath = receiveScalar[string](q, d.ModulePath) + d.Version = receiveScalar[string](q, d.Version) + d.ModuleHash = receiveNullableString(q, d.ModuleHash) + d.GoModHash = receiveNullableString(q, d.GoModHash) + return d + }) + if afterAny == nil { + return nil + } + out := make([]tree.GoResolvedDependency, len(afterAny)) + for i, v := range afterAny { + out[i] = v.(tree.GoResolvedDependency) + } + return out +} + +func requireSlice(s []tree.GoRequire) []any { + if s == nil { + return nil + } + out := make([]any, len(s)) + for i, v := range s { + out[i] = v + } + return out +} + +func replaceSlice(s []tree.GoReplace) []any { + if s == nil { + return nil + } + out := make([]any, len(s)) + for i, v := range s { + out[i] = v + } + return out +} + +func excludeSlice(s []tree.GoExclude) []any { + if s == nil { + return nil + } + out := make([]any, len(s)) + for i, v := range s { + out[i] = v + } + return out +} + +func retractSlice(s []tree.GoRetract) []any { + if s == nil { + return nil + } + out := make([]any, len(s)) + for i, v := range s { + out[i] = v + } + return out +} + +func resolvedSlice(s []tree.GoResolvedDependency) []any { + if s == nil { + return nil + } + out := make([]any, len(s)) + for i, v := range s { + out[i] = v + } + return out +} + +// receiveNullableString reads a value that may be null on the wire and +// returns "" if so. Mirrors how emptyAsNil is sent on the send side. +func receiveNullableString(q *ReceiveQueue, before string) string { + v := q.Receive(emptyAsNil(before), nil) + if v == nil { + return "" + } + return v.(string) +} diff --git a/rewrite-go/pkg/rpc/go_sender.go b/rewrite-go/pkg/rpc/go_sender.go index 5c4a63db8aa..5b3f40d35f7 100644 --- a/rewrite-go/pkg/rpc/go_sender.go +++ b/rewrite-go/pkg/rpc/go_sender.go @@ -18,99 +18,49 @@ package rpc import ( "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" ) -// GoSender serializes Go AST nodes into the send queue. -// Handles G (Go-specific) nodes and delegates J nodes to JavaSender. +// GoSender serializes Go AST nodes via the visitor pattern. Mirrors +// org.openrewrite.golang.internal.rpc.GolangSender, which extends +// JavaSender, which extends JavaVisitor. +// +// GoSender embeds JavaSender to inherit J-node Visit overrides + the +// PreVisit hook, and adds VisitX overrides for the G-specific node +// types. The framework's type-switch dispatch in visitor.GoVisitor.Visit +// routes to the most-derived override via the Self field set by +// visitor.Init. type GoSender struct { - java JavaSender + JavaSender } -// NewGoSender creates a GoSender with its JavaSender properly wired. +// NewGoSender creates a GoSender ready to serialize trees. Sets Self +// so the framework's dispatch routes to G-node and J-node overrides +// across the embedding chain. func NewGoSender() *GoSender { gs := &GoSender{} - gs.java = JavaSender{ - typeSender: NewJavaTypeSender(), - parent: gs, - } - return gs + gs.typeSender = NewJavaTypeSender() + return visitor.Init(gs) } -// Visit dispatches to the appropriate send method based on node type. -func (s *GoSender) Visit(node any, q *SendQueue) { - if node == nil { - return - } - - // ParseError has its own codec — handle before preVisit (no prefix field) - if pe, ok := node.(*tree.ParseError); ok { - s.sendParseError(pe, q) - return +// Visit overrides the framework dispatch to special-case ParseError — +// it isn't a J node, has no Prefix/Markers, and uses its own codec. +// All other tree types fall through to the framework's switch. +func (s *GoSender) Visit(t tree.Tree, p any) tree.Tree { + if t == nil { + return nil } - - // preVisit: send ID, prefix, markers - s.preVisit(node, q) - - switch v := node.(type) { - // G nodes (Go-specific) - case *tree.CompilationUnit: - s.sendCompilationUnit(v, q) - case *tree.GoStmt: - s.sendGoStmt(v, q) - case *tree.Defer: - s.sendDefer(v, q) - case *tree.Send: - s.sendSend(v, q) - case *tree.Goto: - s.sendGoto(v, q) - case *tree.Fallthrough: - // No fields - case *tree.Composite: - s.sendComposite(v, q) - case *tree.KeyValue: - s.sendKeyValue(v, q) - case *tree.Slice: - s.sendSlice(v, q) - case *tree.MapType: - s.sendMapType(v, q) - case *tree.StatementExpression: - s.sendStatementExpression(v, q) - case *tree.PointerType: - s.sendPointerType(v, q) - case *tree.Channel: - s.sendChannel(v, q) - case *tree.FuncType: - s.sendFuncType(v, q) - case *tree.StructType: - s.sendStructType(v, q) - case *tree.InterfaceType: - s.sendInterfaceType(v, q) - case *tree.TypeList: - s.sendTypeList(v, q) - case *tree.TypeDecl: - s.sendTypeDecl(v, q) - case *tree.MultiAssignment: - s.sendMultiAssignment(v, q) - case *tree.CommClause: - s.sendCommClause(v, q) - case *tree.IndexList: - s.sendIndexList(v, q) - - default: - // Delegate all J nodes to JavaSender - s.java.visitJ(node, q) + if pe, ok := t.(*tree.ParseError); ok { + s.sendParseError(pe, p.(*SendQueue)) + return pe } -} - -func (s *GoSender) preVisit(node any, q *SendQueue) { - q.GetAndSend(node, nodeID, nil) - q.GetAndSend(node, nodePrefix, func(v any) { sendSpace(v.(tree.Space), q) }) - q.GetAndSend(node, nodeMarkers, func(v any) { SendMarkersCodec(v.(tree.Markers), q) }) + return s.GoVisitor.Visit(t, p) } // --- G nodes --- -func (s *GoSender) sendCompilationUnit(cu *tree.CompilationUnit, q *SendQueue) { +func (s *GoSender) VisitCompilationUnit(cu *tree.CompilationUnit, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(cu, func(v any) any { return v.(*tree.CompilationUnit).SourcePath }, nil) // charset - Go doesn't track this, send empty/default q.GetAndSend(cu, func(_ any) any { return "UTF-8" }, nil) @@ -151,47 +101,61 @@ func (s *GoSender) sendCompilationUnit(cu *tree.CompilationUnit, q *SendQueue) { // EOF space q.GetAndSend(cu, func(v any) any { return v.(*tree.CompilationUnit).EOF }, func(v any) { sendSpace(v.(tree.Space), q) }) + return cu } -func (s *GoSender) sendGoStmt(gs *tree.GoStmt, q *SendQueue) { +func (s *GoSender) VisitGoStmt(gs *tree.GoStmt, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(gs, func(v any) any { return v.(*tree.GoStmt).Expr }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return gs } -func (s *GoSender) sendDefer(d *tree.Defer, q *SendQueue) { +func (s *GoSender) VisitDefer(d *tree.Defer, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(d, func(v any) any { return v.(*tree.Defer).Expr }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return d } -func (s *GoSender) sendSend(sn *tree.Send, q *SendQueue) { +func (s *GoSender) VisitSend(sn *tree.Send, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(sn, func(v any) any { return v.(*tree.Send).Channel }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(sn, func(v any) any { return v.(*tree.Send).Arrow }, func(v any) { sendLeftPadded(s, v, q) }) + return sn } -func (s *GoSender) sendGoto(g *tree.Goto, q *SendQueue) { +func (s *GoSender) VisitGoto(g *tree.Goto, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(g, func(v any) any { return v.(*tree.Goto).Label }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return g } -func (s *GoSender) sendComposite(c *tree.Composite, q *SendQueue) { +func (s *GoSender) VisitComposite(c *tree.Composite, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(c, func(v any) any { return v.(*tree.Composite).TypeExpr }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(c, func(v any) any { return v.(*tree.Composite).Elements }, func(v any) { sendContainer(s, v, q) }) + return c } -func (s *GoSender) sendKeyValue(kv *tree.KeyValue, q *SendQueue) { +func (s *GoSender) VisitKeyValue(kv *tree.KeyValue, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(kv, func(v any) any { return v.(*tree.KeyValue).Key }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(kv, func(v any) any { return v.(*tree.KeyValue).Value }, func(v any) { sendLeftPadded(s, v, q) }) + return kv } -func (s *GoSender) sendSlice(sl *tree.Slice, q *SendQueue) { +func (s *GoSender) VisitSlice(sl *tree.Slice, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(sl, func(v any) any { return v.(*tree.Slice).Indexed }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(sl, func(v any) any { return v.(*tree.Slice).OpenBracket }, func(v any) { sendSpace(v.(tree.Space), q) }) q.GetAndSend(sl, func(v any) any { return v.(*tree.Slice).Low }, @@ -199,31 +163,39 @@ func (s *GoSender) sendSlice(sl *tree.Slice, q *SendQueue) { q.GetAndSend(sl, func(v any) any { return v.(*tree.Slice).High }, func(v any) { sendRightPadded(s, v, q) }) q.GetAndSend(sl, func(v any) any { return v.(*tree.Slice).Max }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(sl, func(v any) any { return v.(*tree.Slice).CloseBracket }, func(v any) { sendSpace(v.(tree.Space), q) }) + return sl } -func (s *GoSender) sendMapType(mt *tree.MapType, q *SendQueue) { +func (s *GoSender) VisitMapType(mt *tree.MapType, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(mt, func(v any) any { return v.(*tree.MapType).OpenBracket }, func(v any) { sendSpace(v.(tree.Space), q) }) q.GetAndSend(mt, func(v any) any { return v.(*tree.MapType).Key }, func(v any) { sendRightPadded(s, v, q) }) q.GetAndSend(mt, func(v any) any { return v.(*tree.MapType).Value }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return mt } -func (s *GoSender) sendStatementExpression(se *tree.StatementExpression, q *SendQueue) { +func (s *GoSender) VisitStatementExpression(se *tree.StatementExpression, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(se, func(v any) any { return v.(*tree.StatementExpression).Statement }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return se } -func (s *GoSender) sendPointerType(pt *tree.PointerType, q *SendQueue) { +func (s *GoSender) VisitPointerType(pt *tree.PointerType, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(pt, func(v any) any { return v.(*tree.PointerType).Elem }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return pt } -func (s *GoSender) sendChannel(ch *tree.Channel, q *SendQueue) { +func (s *GoSender) VisitChannel(ch *tree.Channel, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(ch, func(v any) any { switch v.(*tree.Channel).Dir { case tree.ChanBidi: @@ -237,34 +209,56 @@ func (s *GoSender) sendChannel(ch *tree.Channel, q *SendQueue) { } }, nil) q.GetAndSend(ch, func(v any) any { return v.(*tree.Channel).Value }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return ch } -func (s *GoSender) sendFuncType(ft *tree.FuncType, q *SendQueue) { +func (s *GoSender) VisitFuncType(ft *tree.FuncType, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(ft, func(v any) any { return v.(*tree.FuncType).Parameters }, func(v any) { sendContainer(s, v, q) }) q.GetAndSend(ft, func(v any) any { return v.(*tree.FuncType).ReturnType }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return ft } -func (s *GoSender) sendStructType(st *tree.StructType, q *SendQueue) { +func (s *GoSender) VisitStructType(st *tree.StructType, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(st, func(v any) any { return v.(*tree.StructType).Body }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return st } -func (s *GoSender) sendInterfaceType(it *tree.InterfaceType, q *SendQueue) { +func (s *GoSender) VisitInterfaceType(it *tree.InterfaceType, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(it, func(v any) any { return v.(*tree.InterfaceType).Body }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return it } -func (s *GoSender) sendTypeList(tl *tree.TypeList, q *SendQueue) { +func (s *GoSender) VisitTypeList(tl *tree.TypeList, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(tl, func(v any) any { return v.(*tree.TypeList).Types }, func(v any) { sendContainer(s, v, q) }) + return tl } -func (s *GoSender) sendTypeDecl(td *tree.TypeDecl, q *SendQueue) { +func (s *GoSender) VisitTypeDecl(td *tree.TypeDecl, p any) tree.J { + q := p.(*SendQueue) + // leadingAnnotations (`//go:` directives modeled as J.Annotation) + q.GetAndSendList(td, + func(v any) []any { + anns := v.(*tree.TypeDecl).LeadingAnnotations + result := make([]any, len(anns)) + for i, a := range anns { + result[i] = a + } + return result + }, + func(v any) any { return extractID(v) }, + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(td, func(v any) any { return v.(*tree.TypeDecl).Name }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) // Assign — dereference pointer so sendLeftPadded gets a value type q.GetAndSend(td, func(v any) any { a := v.(*tree.TypeDecl).Assign @@ -272,16 +266,18 @@ func (s *GoSender) sendTypeDecl(td *tree.TypeDecl, q *SendQueue) { return *a }, func(v any) { sendLeftPadded(s, v, q) }) q.GetAndSend(td, func(v any) any { return v.(*tree.TypeDecl).Definition }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) // Specs — dereference pointer so sendContainer gets a value type q.GetAndSend(td, func(v any) any { sp := v.(*tree.TypeDecl).Specs if sp == nil { return nil } return *sp }, func(v any) { sendContainer(s, v, q) }) + return td } -func (s *GoSender) sendMultiAssignment(ma *tree.MultiAssignment, q *SendQueue) { +func (s *GoSender) VisitMultiAssignment(ma *tree.MultiAssignment, p any) tree.J { + q := p.(*SendQueue) q.GetAndSendList(ma, func(v any) []any { vars := v.(*tree.MultiAssignment).Variables @@ -306,11 +302,13 @@ func (s *GoSender) sendMultiAssignment(ma *tree.MultiAssignment, q *SendQueue) { }, func(v any) any { return containerElementID(v) }, func(v any) { sendRightPadded(s, v, q) }) + return ma } -func (s *GoSender) sendCommClause(cc *tree.CommClause, q *SendQueue) { +func (s *GoSender) VisitCommClause(cc *tree.CommClause, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(cc, func(v any) any { return v.(*tree.CommClause).Comm }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(cc, func(v any) any { return v.(*tree.CommClause).Colon }, func(v any) { sendSpace(v.(tree.Space), q) }) q.GetAndSendList(cc, @@ -324,10 +322,15 @@ func (s *GoSender) sendCommClause(cc *tree.CommClause, q *SendQueue) { }, func(v any) any { return containerElementID(v) }, func(v any) { sendRightPadded(s, v, q) }) + return cc } // sendParseError serializes a ParseError matching Java's ParseError.rpcSend field order: // id, markers, sourcePath, charsetName, charsetBomMarked, checksum, fileAttributes, text +// +// ParseError isn't a J node — no Prefix/Markers handling via PreVisit. +// Special-cased at the top of Visit() instead of dispatched through +// the framework switch. func (s *GoSender) sendParseError(pe *tree.ParseError, q *SendQueue) { q.GetAndSend(pe, func(v any) any { return v.(*tree.ParseError).Ident.String() }, nil) q.GetAndSend(pe, func(v any) any { return v.(*tree.ParseError).Markers }, @@ -340,10 +343,12 @@ func (s *GoSender) sendParseError(pe *tree.ParseError, q *SendQueue) { q.GetAndSend(pe, func(v any) any { return v.(*tree.ParseError).Text }, nil) } -func (s *GoSender) sendIndexList(il *tree.IndexList, q *SendQueue) { +func (s *GoSender) VisitIndexList(il *tree.IndexList, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(il, func(v any) any { return v.(*tree.IndexList).Target }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(il, func(v any) any { return v.(*tree.IndexList).Indices }, func(v any) { sendContainer(s, v, q) }) + return il } diff --git a/rewrite-go/pkg/rpc/java_receiver.go b/rewrite-go/pkg/rpc/java_receiver.go index e44ff3b655b..e64e902c93e 100644 --- a/rewrite-go/pkg/rpc/java_receiver.go +++ b/rewrite-go/pkg/rpc/java_receiver.go @@ -19,92 +19,58 @@ package rpc import ( "github.com/google/uuid" "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" ) -// Receiver can deserialize any AST node. +// Receiver deserializes a tree node from a ReceiveQueue. Mirrors the +// visitor.VisitorI signature so receivers slot into the framework's +// dispatch — same shape as rewrite-java's JavaReceiver, which extends +// JavaVisitor. type Receiver interface { - Visit(node any, q *ReceiveQueue) any + Visit(t tree.Tree, p any) tree.Tree } -// JavaReceiver deserializes J (shared Java-like) AST nodes from the receive queue. -// Mirrors JavaReceiver.java for J nodes. +// JavaReceiver deserializes J (shared Java-like) AST nodes via the +// visitor pattern. Mirrors org.openrewrite.java.internal.rpc.JavaReceiver. +// +// JavaReceiver embeds visitor.GoVisitor; the framework's type-switch +// dispatch routes calls to JavaReceiver's VisitX overrides via the +// Self field. PreVisit handles the cross-cutting fields (id, prefix, +// markers) once per node, mirroring Java's JavaVisitor.preVisit. type JavaReceiver struct { + visitor.GoVisitor typeReceiver *JavaTypeReceiver - parent Receiver // the GoReceiver (or other language receiver) that delegates to us -} - -// visitJ dispatches J-node field deserialization (after preVisit has been called by the parent). -func (r *JavaReceiver) visitJ(node any, q *ReceiveQueue) any { - switch v := node.(type) { - case *tree.Identifier: - return r.receiveIdentifier(v, q) - case *tree.Literal: - return r.receiveLiteral(v, q) - case *tree.Binary: - return r.receiveBinary(v, q) - case *tree.Block: - return r.receiveBlock(v, q) - case *tree.Empty: - return v - case *tree.Unary: - return r.receiveUnary(v, q) - case *tree.FieldAccess: - return r.receiveFieldAccess(v, q) - case *tree.MethodInvocation: - return r.receiveMethodInvocation(v, q) - case *tree.Assignment: - return r.receiveAssignment(v, q) - case *tree.AssignmentOperation: - return r.receiveAssignmentOperation(v, q) - case *tree.MethodDeclaration: - return r.receiveMethodDeclaration(v, q) - case *tree.VariableDeclarations: - return r.receiveVariableDeclarations(v, q) - case *tree.VariableDeclarator: - return r.receiveVariableDeclarator(v, q) - case *tree.Return: - return r.receiveReturn(v, q) - case *tree.If: - return r.receiveIf(v, q) - case *tree.Else: - return r.receiveElse(v, q) - case *tree.ForLoop: - return r.receiveForLoop(v, q) - case *tree.ForControl: - return r.receiveForControl(v, q) - case *tree.ForEachLoop: - return r.receiveForEachLoop(v, q) - case *tree.ForEachControl: - return r.receiveForEachControl(v, q) - case *tree.Switch: - return r.receiveSwitch(v, q) - case *tree.Case: - return r.receiveCase(v, q) - case *tree.Break: - return r.receiveBreak(v, q) - case *tree.Continue: - return r.receiveContinue(v, q) - case *tree.Label: - return r.receiveLabel(v, q) - case *tree.ArrayType: - return r.receiveArrayType(v, q) - case *tree.ArrayAccess: - return r.receiveArrayAccess(v, q) - case *tree.ParameterizedType: - return r.receiveParameterizedType(v, q) - case *tree.ArrayDimension: - return r.receiveArrayDimension(v, q) - case *tree.Parentheses: - return r.receiveParentheses(v, q) - case *tree.TypeCast: - return r.receiveTypeCast(v, q) - case *tree.ControlParentheses: - return r.receiveControlParentheses(v, q) - case *tree.Import: - return r.receiveImport(v, q) - default: - return node +} + +// PreVisit deserializes the cross-cutting fields of every J node: +// id, prefix, markers. Returns a (possibly new) instance of t with +// those fields populated from the queue. ParseError isn't a J node +// and is special-cased at the GoReceiver layer. +// +// Each field is updated via the immutable typed wither (WithPrefix / +// WithMarkers, called polymorphically via reflection) so the input +// tree is never mutated and the visitor framework's pointer-identity +// change detection sees the right thing — same input pointer means +// nothing changed; different output pointer means something changed. +// Mirrors rewrite-java's `j.withPrefix(...).withMarkers(...)` chain. +func (r *JavaReceiver) PreVisit(t tree.Tree, p any) tree.Tree { + j, isJ := t.(tree.J) + if !isJ { + return t + } + q := p.(*ReceiveQueue) + q.Receive(j.GetID().String(), nil) + if result := q.Receive(j.GetPrefix(), func(v any) any { + return receiveSpace(v.(tree.Space), q) + }); result != nil { + t = withPrefixViaReflection(t, result.(tree.Space)) } + if result := q.Receive(j.GetMarkers(), func(v any) any { + return receiveMarkersCodec(q, v.(tree.Markers)) + }); result != nil { + t = withMarkersViaReflection(t, result.(tree.Markers)) + } + return t } // receiveType receives a JavaType from the queue with null/Unknown handling. @@ -120,7 +86,8 @@ func (r *JavaReceiver) receiveType(before tree.JavaType, q *ReceiveQueue) tree.J // --- J nodes --- -func (r *JavaReceiver) receiveIdentifier(id *tree.Identifier, q *ReceiveQueue) *tree.Identifier { +func (r *JavaReceiver) VisitIdentifier(id *tree.Identifier, p any) tree.J { + q := p.(*ReceiveQueue) c := *id // shallow copy to avoid mutating remoteObjects baseline id = &c // annotations @@ -128,7 +95,7 @@ func (r *JavaReceiver) receiveIdentifier(id *tree.Identifier, q *ReceiveQueue) * for i, a := range id.Annotations { beforeAnns[i] = a } - afterAnns := q.ReceiveList(beforeAnns, func(v any) any { return r.parent.Visit(v, q) }) + afterAnns := q.ReceiveList(beforeAnns, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if afterAnns != nil { id.Annotations = make([]tree.Tree, len(afterAnns)) for i, a := range afterAnns { @@ -149,7 +116,8 @@ func (r *JavaReceiver) receiveIdentifier(id *tree.Identifier, q *ReceiveQueue) * return id } -func (r *JavaReceiver) receiveLiteral(lit *tree.Literal, q *ReceiveQueue) *tree.Literal { +func (r *JavaReceiver) VisitLiteral(lit *tree.Literal, p any) tree.J { + q := p.(*ReceiveQueue) c := *lit // shallow copy to avoid mutating remoteObjects baseline lit = &c // value @@ -163,17 +131,18 @@ func (r *JavaReceiver) receiveLiteral(lit *tree.Literal, q *ReceiveQueue) *tree. return lit } -func (r *JavaReceiver) receiveBinary(b *tree.Binary, q *ReceiveQueue) *tree.Binary { +func (r *JavaReceiver) VisitBinary(b *tree.Binary, p any) tree.J { + q := p.(*ReceiveQueue) c := *b // shallow copy to avoid mutating remoteObjects baseline b = &c - result := q.Receive(b.Left, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(b.Left, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { b.Left = result.(tree.Expression) } - if result := q.Receive(b.Operator, func(v any) any { return receiveLeftPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(b.Operator, func(v any) any { return receiveLeftPadded(r, q, v) }); result != nil { b.Operator = result.(tree.LeftPadded[tree.BinaryOperator]) } - rightResult := q.Receive(b.Right, func(v any) any { return r.parent.Visit(v, q) }) + rightResult := q.Receive(b.Right, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if rightResult != nil { b.Right = rightResult.(tree.Expression) } @@ -181,17 +150,18 @@ func (r *JavaReceiver) receiveBinary(b *tree.Binary, q *ReceiveQueue) *tree.Bina return b } -func (r *JavaReceiver) receiveBlock(b *tree.Block, q *ReceiveQueue) *tree.Block { +func (r *JavaReceiver) VisitBlock(b *tree.Block, p any) tree.J { + q := p.(*ReceiveQueue) c := *b // shallow copy to avoid mutating remoteObjects baseline b = &c // static (right-padded) - Java-only field, not stored in Go Block - q.Receive(nil, func(v any) any { return receiveRightPadded(r.parent, q, v) }) + q.Receive(nil, func(v any) any { return receiveRightPadded(r, q, v) }) // statements beforeStmts := make([]any, len(b.Statements)) for i, s := range b.Statements { beforeStmts[i] = s } - afterStmts := q.ReceiveList(beforeStmts, func(v any) any { return receiveRightPadded(r.parent, q, v) }) + afterStmts := q.ReceiveList(beforeStmts, func(v any) any { return receiveRightPadded(r, q, v) }) if afterStmts != nil { b.Statements = make([]tree.RightPadded[tree.Statement], len(afterStmts)) for i, s := range afterStmts { @@ -207,13 +177,36 @@ func (r *JavaReceiver) receiveBlock(b *tree.Block, q *ReceiveQueue) *tree.Block return b } -func (r *JavaReceiver) receiveUnary(u *tree.Unary, q *ReceiveQueue) *tree.Unary { +// receiveAnnotation matches JavaReceiver.visitAnnotation field order: +// annotationType, then nullable arguments container. +func (r *JavaReceiver) VisitAnnotation(ann *tree.Annotation, p any) tree.J { + q := p.(*ReceiveQueue) + c := *ann // shallow copy to avoid mutating remoteObjects baseline + ann = &c + if result := q.Receive(ann.AnnotationType, func(v any) any { return r.Visit(v.(tree.Tree), q) }); result != nil { + ann.AnnotationType = result.(tree.Expression) + } + var beforeArgs any + if ann.Arguments != nil { + beforeArgs = *ann.Arguments + } + if result := q.Receive(beforeArgs, func(v any) any { return receiveContainer(r, q, v) }); result != nil { + container := result.(tree.Container[tree.Expression]) + ann.Arguments = &container + } else { + ann.Arguments = nil + } + return ann +} + +func (r *JavaReceiver) VisitUnary(u *tree.Unary, p any) tree.J { + q := p.(*ReceiveQueue) c := *u // shallow copy to avoid mutating remoteObjects baseline u = &c - if result := q.Receive(u.Operator, func(v any) any { return receiveLeftPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(u.Operator, func(v any) any { return receiveLeftPadded(r, q, v) }); result != nil { u.Operator = result.(tree.LeftPadded[tree.UnaryOperator]) } - result := q.Receive(u.Operand, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(u.Operand, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { u.Operand = result.(tree.Expression) } @@ -221,21 +214,23 @@ func (r *JavaReceiver) receiveUnary(u *tree.Unary, q *ReceiveQueue) *tree.Unary return u } -func (r *JavaReceiver) receiveFieldAccess(fa *tree.FieldAccess, q *ReceiveQueue) *tree.FieldAccess { +func (r *JavaReceiver) VisitFieldAccess(fa *tree.FieldAccess, p any) tree.J { + q := p.(*ReceiveQueue) c := *fa // shallow copy to avoid mutating remoteObjects baseline fa = &c - result := q.Receive(fa.Target, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(fa.Target, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { fa.Target = result.(tree.Expression) } - if result := q.Receive(fa.Name, func(v any) any { return receiveLeftPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(fa.Name, func(v any) any { return receiveLeftPadded(r, q, v) }); result != nil { fa.Name = result.(tree.LeftPadded[*tree.Identifier]) } fa.Type = r.receiveType(fa.Type, q) return fa } -func (r *JavaReceiver) receiveMethodInvocation(mi *tree.MethodInvocation, q *ReceiveQueue) *tree.MethodInvocation { +func (r *JavaReceiver) VisitMethodInvocation(mi *tree.MethodInvocation, p any) tree.J { + q := p.(*ReceiveQueue) c := *mi // shallow copy to avoid mutating remoteObjects baseline mi = &c // select @@ -243,7 +238,7 @@ func (r *JavaReceiver) receiveMethodInvocation(mi *tree.MethodInvocation, q *Rec if mi.Select != nil { beforeSelect = *mi.Select } - if result := q.Receive(beforeSelect, func(v any) any { return receiveRightPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(beforeSelect, func(v any) any { return receiveRightPadded(r, q, v) }); result != nil { rp := coerceToExpressionRP(result) mi.Select = &rp } else { @@ -252,12 +247,12 @@ func (r *JavaReceiver) receiveMethodInvocation(mi *tree.MethodInvocation, q *Rec // typeParameters (nil for Go) q.Receive(nil, nil) // name - result := q.Receive(mi.Name, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(mi.Name, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { mi.Name = result.(*tree.Identifier) } // arguments - if result := q.Receive(mi.Arguments, func(v any) any { return receiveContainer(r.parent, q, v) }); result != nil { + if result := q.Receive(mi.Arguments, func(v any) any { return receiveContainer(r, q, v) }); result != nil { mi.Arguments = result.(tree.Container[tree.Expression]) } // methodType @@ -270,31 +265,33 @@ func (r *JavaReceiver) receiveMethodInvocation(mi *tree.MethodInvocation, q *Rec return mi } -func (r *JavaReceiver) receiveAssignment(a *tree.Assignment, q *ReceiveQueue) *tree.Assignment { +func (r *JavaReceiver) VisitAssignment(a *tree.Assignment, p any) tree.J { + q := p.(*ReceiveQueue) c := *a // shallow copy to avoid mutating remoteObjects baseline a = &c - result := q.Receive(a.Variable, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(a.Variable, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { a.Variable = result.(tree.Expression) } - if result := q.Receive(a.Value, func(v any) any { return receiveLeftPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(a.Value, func(v any) any { return receiveLeftPadded(r, q, v) }); result != nil { a.Value = result.(tree.LeftPadded[tree.Expression]) } a.Type = r.receiveType(a.Type, q) return a } -func (r *JavaReceiver) receiveAssignmentOperation(a *tree.AssignmentOperation, q *ReceiveQueue) *tree.AssignmentOperation { +func (r *JavaReceiver) VisitAssignmentOperation(a *tree.AssignmentOperation, p any) tree.J { + q := p.(*ReceiveQueue) c := *a // shallow copy to avoid mutating remoteObjects baseline a = &c - result := q.Receive(a.Variable, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(a.Variable, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { a.Variable = result.(tree.Expression) } - if result := q.Receive(a.Operator, func(v any) any { return receiveLeftPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(a.Operator, func(v any) any { return receiveLeftPadded(r, q, v) }); result != nil { a.Operator = result.(tree.LeftPadded[tree.AssignmentOperator]) } - assignResult := q.Receive(a.Assignment, func(v any) any { return r.parent.Visit(v, q) }) + assignResult := q.Receive(a.Assignment, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if assignResult != nil { a.Assignment = assignResult.(tree.Expression) } @@ -302,35 +299,48 @@ func (r *JavaReceiver) receiveAssignmentOperation(a *tree.AssignmentOperation, q return a } -func (r *JavaReceiver) receiveMethodDeclaration(md *tree.MethodDeclaration, q *ReceiveQueue) *tree.MethodDeclaration { +func (r *JavaReceiver) VisitMethodDeclaration(md *tree.MethodDeclaration, p any) tree.J { + q := p.(*ReceiveQueue) c := *md // shallow copy to avoid mutating remoteObjects baseline md = &c // leadingAnnotations - q.ReceiveList(nil, nil) + beforeAnns := make([]any, len(md.LeadingAnnotations)) + for i, a := range md.LeadingAnnotations { + beforeAnns[i] = a + } + afterAnns := q.ReceiveList(beforeAnns, func(v any) any { return r.Visit(v.(tree.Tree), q) }) + if afterAnns != nil { + md.LeadingAnnotations = make([]*tree.Annotation, 0, len(afterAnns)) + for _, a := range afterAnns { + if a != nil { + md.LeadingAnnotations = append(md.LeadingAnnotations, a.(*tree.Annotation)) + } + } + } // modifiers q.ReceiveList(nil, nil) // typeParameters q.Receive(nil, nil) // returnTypeExpression - result := q.Receive(md.ReturnType, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(md.ReturnType, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { md.ReturnType = result.(tree.Expression) } // name annotations q.ReceiveList(nil, nil) // name - nameResult := q.Receive(md.Name, func(v any) any { return r.parent.Visit(v, q) }) + nameResult := q.Receive(md.Name, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if nameResult != nil { md.Name = nameResult.(*tree.Identifier) } // parameters - if result := q.Receive(md.Parameters, func(v any) any { return receiveContainerAs(r.parent, q, v, ContainerStatement) }); result != nil { + if result := q.Receive(md.Parameters, func(v any) any { return receiveContainerAs(r, q, v, ContainerStatement) }); result != nil { md.Parameters = result.(tree.Container[tree.Statement]) } // throws q.Receive(nil, nil) // body - bodyResult := q.Receive(md.Body, func(v any) any { return r.parent.Visit(v, q) }) + bodyResult := q.Receive(md.Body, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if bodyResult != nil { md.Body = bodyResult.(*tree.Block) } @@ -346,15 +356,28 @@ func (r *JavaReceiver) receiveMethodDeclaration(md *tree.MethodDeclaration, q *R return md } -func (r *JavaReceiver) receiveVariableDeclarations(vd *tree.VariableDeclarations, q *ReceiveQueue) *tree.VariableDeclarations { +func (r *JavaReceiver) VisitVariableDeclarations(vd *tree.VariableDeclarations, p any) tree.J { + q := p.(*ReceiveQueue) c := *vd // shallow copy to avoid mutating remoteObjects baseline vd = &c // leadingAnnotations - q.ReceiveList(nil, nil) + beforeAnns := make([]any, len(vd.LeadingAnnotations)) + for i, a := range vd.LeadingAnnotations { + beforeAnns[i] = a + } + afterAnns := q.ReceiveList(beforeAnns, func(v any) any { return r.Visit(v.(tree.Tree), q) }) + if afterAnns != nil { + vd.LeadingAnnotations = make([]*tree.Annotation, 0, len(afterAnns)) + for _, a := range afterAnns { + if a != nil { + vd.LeadingAnnotations = append(vd.LeadingAnnotations, a.(*tree.Annotation)) + } + } + } // modifiers q.ReceiveList(nil, nil) // typeExpression - result := q.Receive(vd.TypeExpr, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(vd.TypeExpr, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { vd.TypeExpr = result.(tree.Expression) } @@ -373,7 +396,7 @@ func (r *JavaReceiver) receiveVariableDeclarations(vd *tree.VariableDeclarations for i, v := range vd.Variables { beforeVars[i] = v } - afterVars := q.ReceiveList(beforeVars, func(v any) any { return receiveRightPadded(r.parent, q, v) }) + afterVars := q.ReceiveList(beforeVars, func(v any) any { return receiveRightPadded(r, q, v) }) if afterVars != nil { vd.Variables = make([]tree.RightPadded[*tree.VariableDeclarator], len(afterVars)) for i, v := range afterVars { @@ -383,11 +406,12 @@ func (r *JavaReceiver) receiveVariableDeclarations(vd *tree.VariableDeclarations return vd } -func (r *JavaReceiver) receiveVariableDeclarator(vd *tree.VariableDeclarator, q *ReceiveQueue) *tree.VariableDeclarator { +func (r *JavaReceiver) VisitVariableDeclarator(vd *tree.VariableDeclarator, p any) tree.J { + q := p.(*ReceiveQueue) c := *vd // shallow copy to avoid mutating remoteObjects baseline vd = &c // declarator (name) - result := q.Receive(vd.Name, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(vd.Name, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { vd.Name = result.(*tree.Identifier) } @@ -398,7 +422,7 @@ func (r *JavaReceiver) receiveVariableDeclarator(vd *tree.VariableDeclarator, q if vd.Initializer != nil { beforeInit = *vd.Initializer } - if result := q.Receive(beforeInit, func(v any) any { return receiveLeftPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(beforeInit, func(v any) any { return receiveLeftPadded(r, q, v) }); result != nil { lp := result.(tree.LeftPadded[tree.Expression]) vd.Initializer = &lp } else { @@ -409,14 +433,15 @@ func (r *JavaReceiver) receiveVariableDeclarator(vd *tree.VariableDeclarator, q return vd } -func (r *JavaReceiver) receiveReturn(ret *tree.Return, q *ReceiveQueue) *tree.Return { +func (r *JavaReceiver) VisitReturn(ret *tree.Return, p any) tree.J { + q := p.(*ReceiveQueue) c := *ret ret = &c var beforeExpr any if len(ret.Expressions) > 0 { beforeExpr = ret.Expressions[0].Element } - result := q.Receive(beforeExpr, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(beforeExpr, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { expr := result.(tree.Expression) if len(ret.Expressions) > 0 { @@ -436,7 +461,8 @@ func (r *JavaReceiver) receiveReturn(ret *tree.Return, q *ReceiveQueue) *tree.Re return ret } -func (r *JavaReceiver) receiveIf(i *tree.If, q *ReceiveQueue) *tree.If { +func (r *JavaReceiver) VisitIf(i *tree.If, p any) tree.J { + q := p.(*ReceiveQueue) c := *i // shallow copy to avoid mutating remoteObjects baseline i = &c // ifCondition - Java sends ControlParentheses; cache it for future round-trips @@ -450,7 +476,7 @@ func (r *JavaReceiver) receiveIf(i *tree.If, q *ReceiveQueue) *tree.If { Tree: tree.RightPadded[tree.Expression]{Element: i.Condition}, } } - cpResult := q.Receive(beforeCP, func(v any) any { return r.parent.Visit(v, q) }) + cpResult := q.Receive(beforeCP, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if cpResult != nil { if cp, ok := cpResult.(*tree.ControlParentheses); ok { i.ConditionCP = cp @@ -460,14 +486,14 @@ func (r *JavaReceiver) receiveIf(i *tree.If, q *ReceiveQueue) *tree.If { } } // thenPart - Java sends RightPadded wrapping the Block - if thenResult := q.Receive(nil, func(v any) any { return receiveRightPadded(r.parent, q, v) }); thenResult != nil { + if thenResult := q.Receive(nil, func(v any) any { return receiveRightPadded(r, q, v) }); thenResult != nil { rp := coerceToStatementRP(thenResult) if blk, ok := rp.Element.(*tree.Block); ok { i.Then = blk } } // elsePart - Java sends Else node, convert to RightPadded - elseResult := q.Receive(nil, func(v any) any { return r.parent.Visit(v, q) }) + elseResult := q.Receive(nil, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if elseResult != nil { el := elseResult.(*tree.Else) i.ElsePart = &tree.RightPadded[tree.J]{ @@ -480,25 +506,27 @@ func (r *JavaReceiver) receiveIf(i *tree.If, q *ReceiveQueue) *tree.If { return i } -func (r *JavaReceiver) receiveElse(el *tree.Else, q *ReceiveQueue) *tree.Else { +func (r *JavaReceiver) VisitElse(el *tree.Else, p any) tree.J { + q := p.(*ReceiveQueue) c := *el // shallow copy to avoid mutating remoteObjects baseline el = &c - if result := q.Receive(el.Body, func(v any) any { return receiveRightPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(el.Body, func(v any) any { return receiveRightPadded(r, q, v) }); result != nil { el.Body = result.(tree.RightPadded[tree.Statement]) } return el } -func (r *JavaReceiver) receiveForLoop(f *tree.ForLoop, q *ReceiveQueue) *tree.ForLoop { +func (r *JavaReceiver) VisitForLoop(f *tree.ForLoop, p any) tree.J { + q := p.(*ReceiveQueue) c := *f // shallow copy to avoid mutating remoteObjects baseline f = &c ctrl := &f.Control - result := q.Receive(ctrl, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(ctrl, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { f.Control = *result.(*tree.ForControl) } // body - Java sends RightPadded wrapping the Block - if bodyResult := q.Receive(nil, func(v any) any { return receiveRightPadded(r.parent, q, v) }); bodyResult != nil { + if bodyResult := q.Receive(nil, func(v any) any { return receiveRightPadded(r, q, v) }); bodyResult != nil { rp := coerceToStatementRP(bodyResult) if blk, ok := rp.Element.(*tree.Block); ok { f.Body = blk @@ -507,7 +535,8 @@ func (r *JavaReceiver) receiveForLoop(f *tree.ForLoop, q *ReceiveQueue) *tree.Fo return f } -func (r *JavaReceiver) receiveForControl(fc *tree.ForControl, q *ReceiveQueue) *tree.ForControl { +func (r *JavaReceiver) VisitForControl(fc *tree.ForControl, p any) tree.J { + q := p.(*ReceiveQueue) c := *fc // shallow copy to avoid mutating remoteObjects baseline fc = &c // init (list of right-padded) @@ -515,7 +544,7 @@ func (r *JavaReceiver) receiveForControl(fc *tree.ForControl, q *ReceiveQueue) * if fc.Init != nil { beforeInit = []any{*fc.Init} } - initList := q.ReceiveList(beforeInit, func(v any) any { return receiveRightPadded(r.parent, q, v) }) + initList := q.ReceiveList(beforeInit, func(v any) any { return receiveRightPadded(r, q, v) }) if len(initList) > 0 { rp := initList[0].(tree.RightPadded[tree.Statement]) fc.Init = &rp @@ -527,7 +556,7 @@ func (r *JavaReceiver) receiveForControl(fc *tree.ForControl, q *ReceiveQueue) * if fc.Condition != nil { beforeCond = *fc.Condition } - if result := q.Receive(beforeCond, func(v any) any { return receiveRightPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(beforeCond, func(v any) any { return receiveRightPadded(r, q, v) }); result != nil { rp := coerceToExpressionRP(result) fc.Condition = &rp } else { @@ -538,7 +567,7 @@ func (r *JavaReceiver) receiveForControl(fc *tree.ForControl, q *ReceiveQueue) * if fc.Update != nil { beforeUpdate = []any{*fc.Update} } - updateList := q.ReceiveList(beforeUpdate, func(v any) any { return receiveRightPadded(r.parent, q, v) }) + updateList := q.ReceiveList(beforeUpdate, func(v any) any { return receiveRightPadded(r, q, v) }) if len(updateList) > 0 { rp := updateList[0].(tree.RightPadded[tree.Statement]) fc.Update = &rp @@ -548,16 +577,17 @@ func (r *JavaReceiver) receiveForControl(fc *tree.ForControl, q *ReceiveQueue) * return fc } -func (r *JavaReceiver) receiveForEachLoop(f *tree.ForEachLoop, q *ReceiveQueue) *tree.ForEachLoop { +func (r *JavaReceiver) VisitForEachLoop(f *tree.ForEachLoop, p any) tree.J { + q := p.(*ReceiveQueue) c := *f // shallow copy to avoid mutating remoteObjects baseline f = &c ctrl := &f.Control - result := q.Receive(ctrl, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(ctrl, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { f.Control = *result.(*tree.ForEachControl) } // body - Java sends RightPadded wrapping the Block - if bodyResult := q.Receive(nil, func(v any) any { return receiveRightPadded(r.parent, q, v) }); bodyResult != nil { + if bodyResult := q.Receive(nil, func(v any) any { return receiveRightPadded(r, q, v) }); bodyResult != nil { rp := coerceToStatementRP(bodyResult) if blk, ok := rp.Element.(*tree.Block); ok { f.Body = blk @@ -566,7 +596,8 @@ func (r *JavaReceiver) receiveForEachLoop(f *tree.ForEachLoop, q *ReceiveQueue) return f } -func (r *JavaReceiver) receiveForEachControl(fc *tree.ForEachControl, q *ReceiveQueue) *tree.ForEachControl { +func (r *JavaReceiver) VisitForEachControl(fc *tree.ForEachControl, p any) tree.J { + q := p.(*ReceiveQueue) c := *fc // shallow copy to avoid mutating remoteObjects baseline fc = &c // key (right-padded, nullable) @@ -574,7 +605,7 @@ func (r *JavaReceiver) receiveForEachControl(fc *tree.ForEachControl, q *Receive if fc.Key != nil { beforeKey = *fc.Key } - if result := q.Receive(beforeKey, func(v any) any { return receiveRightPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(beforeKey, func(v any) any { return receiveRightPadded(r, q, v) }); result != nil { rp := coerceToExpressionRP(result) fc.Key = &rp } else { @@ -585,29 +616,30 @@ func (r *JavaReceiver) receiveForEachControl(fc *tree.ForEachControl, q *Receive if fc.Value != nil { beforeValue = *fc.Value } - if result := q.Receive(beforeValue, func(v any) any { return receiveRightPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(beforeValue, func(v any) any { return receiveRightPadded(r, q, v) }); result != nil { rp := coerceToExpressionRP(result) fc.Value = &rp } else { fc.Value = nil } // operator (left-padded AssignOp as string) - if result := q.Receive(fc.Operator, func(v any) any { return receiveLeftPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(fc.Operator, func(v any) any { return receiveLeftPadded(r, q, v) }); result != nil { fc.Operator = result.(tree.LeftPadded[tree.AssignOp]) } // iterable - result := q.Receive(fc.Iterable, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(fc.Iterable, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { fc.Iterable = result.(tree.Expression) } return fc } -func (r *JavaReceiver) receiveSwitch(sw *tree.Switch, q *ReceiveQueue) *tree.Switch { +func (r *JavaReceiver) VisitSwitch(sw *tree.Switch, p any) tree.J { + q := p.(*ReceiveQueue) c := *sw // shallow copy to avoid mutating remoteObjects baseline sw = &c // selector - Java sends ControlParentheses, extract inner Expression for Tag - cpResult := q.Receive(nil, func(v any) any { return r.parent.Visit(v, q) }) + cpResult := q.Receive(nil, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if cpResult != nil { if cp, ok := cpResult.(*tree.ControlParentheses); ok { if _, isEmpty := cp.Tree.Element.(*tree.Empty); !isEmpty { @@ -618,22 +650,23 @@ func (r *JavaReceiver) receiveSwitch(sw *tree.Switch, q *ReceiveQueue) *tree.Swi } } } - result := q.Receive(sw.Body, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(sw.Body, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { sw.Body = result.(*tree.Block) } return sw } -func (r *JavaReceiver) receiveCase(cs *tree.Case, q *ReceiveQueue) *tree.Case { +func (r *JavaReceiver) VisitCase(cs *tree.Case, p any) tree.J { + q := p.(*ReceiveQueue) c := *cs // shallow copy to avoid mutating remoteObjects baseline cs = &c q.Receive(nil, nil) // type enum - if result := q.Receive(cs.Expressions, func(v any) any { return receiveContainer(r.parent, q, v) }); result != nil { + if result := q.Receive(cs.Expressions, func(v any) any { return receiveContainer(r, q, v) }); result != nil { cs.Expressions = result.(tree.Container[tree.Expression]) } // statements - Java sends Container>, extract to Go's []RightPadded[Statement] - if result := q.Receive(nil, func(v any) any { return receiveContainerAs(r.parent, q, v, ContainerStatement) }); result != nil { + if result := q.Receive(nil, func(v any) any { return receiveContainerAs(r, q, v, ContainerStatement) }); result != nil { cont := result.(tree.Container[tree.Statement]) cs.Body = cont.Elements } @@ -642,61 +675,66 @@ func (r *JavaReceiver) receiveCase(cs *tree.Case, q *ReceiveQueue) *tree.Case { return cs } -func (r *JavaReceiver) receiveBreak(b *tree.Break, q *ReceiveQueue) *tree.Break { +func (r *JavaReceiver) VisitBreak(b *tree.Break, p any) tree.J { + q := p.(*ReceiveQueue) c := *b // shallow copy to avoid mutating remoteObjects baseline b = &c - result := q.Receive(b.Label, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(b.Label, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { b.Label = result.(*tree.Identifier) } return b } -func (r *JavaReceiver) receiveContinue(cont *tree.Continue, q *ReceiveQueue) *tree.Continue { +func (r *JavaReceiver) VisitContinue(cont *tree.Continue, p any) tree.J { + q := p.(*ReceiveQueue) c := *cont // shallow copy to avoid mutating remoteObjects baseline cont = &c - result := q.Receive(cont.Label, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(cont.Label, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { cont.Label = result.(*tree.Identifier) } return cont } -func (r *JavaReceiver) receiveLabel(l *tree.Label, q *ReceiveQueue) *tree.Label { +func (r *JavaReceiver) VisitLabel(l *tree.Label, p any) tree.J { + q := p.(*ReceiveQueue) c := *l // shallow copy to avoid mutating remoteObjects baseline l = &c - if result := q.Receive(l.Name, func(v any) any { return receiveRightPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(l.Name, func(v any) any { return receiveRightPadded(r, q, v) }); result != nil { l.Name = coerceRightPaddedIdent(result) } - result := q.Receive(l.Statement, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(l.Statement, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { l.Statement = result.(tree.Statement) } return l } -func (r *JavaReceiver) receiveArrayType(at *tree.ArrayType, q *ReceiveQueue) *tree.ArrayType { +func (r *JavaReceiver) VisitArrayType(at *tree.ArrayType, p any) tree.J { + q := p.(*ReceiveQueue) c := *at // shallow copy to avoid mutating remoteObjects baseline at = &c - result := q.Receive(at.ElementType, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(at.ElementType, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { at.ElementType = result.(tree.Expression) } q.ReceiveList(nil, nil) // annotations - if result := q.Receive(at.Dimension, func(v any) any { return receiveLeftPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(at.Dimension, func(v any) any { return receiveLeftPadded(r, q, v) }); result != nil { at.Dimension = result.(tree.LeftPadded[tree.Space]) } at.Type = r.receiveType(at.Type, q) return at } -func (r *JavaReceiver) receiveParameterizedType(pt *tree.ParameterizedType, q *ReceiveQueue) *tree.ParameterizedType { +func (r *JavaReceiver) VisitParameterizedType(pt *tree.ParameterizedType, p any) tree.J { + q := p.(*ReceiveQueue) c := *pt pt = &c - if result := q.Receive(pt.Clazz, func(v any) any { return r.parent.Visit(v, q) }); result != nil { + if result := q.Receive(pt.Clazz, func(v any) any { return r.Visit(v.(tree.Tree), q) }); result != nil { pt.Clazz = result.(tree.Expression) } - if result := q.Receive(pt.TypeParameters, func(v any) any { return receiveContainer(r.parent, q, v) }); result != nil { + if result := q.Receive(pt.TypeParameters, func(v any) any { return receiveContainer(r, q, v) }); result != nil { container := result.(tree.Container[tree.Expression]) pt.TypeParameters = &container } @@ -704,69 +742,75 @@ func (r *JavaReceiver) receiveParameterizedType(pt *tree.ParameterizedType, q *R return pt } -func (r *JavaReceiver) receiveArrayAccess(aa *tree.ArrayAccess, q *ReceiveQueue) *tree.ArrayAccess { +func (r *JavaReceiver) VisitArrayAccess(aa *tree.ArrayAccess, p any) tree.J { + q := p.(*ReceiveQueue) c := *aa // shallow copy to avoid mutating remoteObjects baseline aa = &c - result := q.Receive(aa.Indexed, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(aa.Indexed, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { aa.Indexed = result.(tree.Expression) } - dimResult := q.Receive(aa.Dimension, func(v any) any { return r.parent.Visit(v, q) }) + dimResult := q.Receive(aa.Dimension, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if dimResult != nil { aa.Dimension = dimResult.(*tree.ArrayDimension) } return aa } -func (r *JavaReceiver) receiveArrayDimension(ad *tree.ArrayDimension, q *ReceiveQueue) *tree.ArrayDimension { +func (r *JavaReceiver) VisitArrayDimension(ad *tree.ArrayDimension, p any) tree.J { + q := p.(*ReceiveQueue) c := *ad // shallow copy to avoid mutating remoteObjects baseline ad = &c - if result := q.Receive(ad.Index, func(v any) any { return receiveRightPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(ad.Index, func(v any) any { return receiveRightPadded(r, q, v) }); result != nil { ad.Index = coerceToExpressionRP(result) } return ad } -func (r *JavaReceiver) receiveParentheses(p *tree.Parentheses, q *ReceiveQueue) *tree.Parentheses { - c := *p // shallow copy to avoid mutating remoteObjects baseline - p = &c - if result := q.Receive(p.Tree, func(v any) any { return receiveRightPadded(r.parent, q, v) }); result != nil { - p.Tree = coerceToExpressionRP(result) +func (r *JavaReceiver) VisitParentheses(parens *tree.Parentheses, p any) tree.J { + q := p.(*ReceiveQueue) + c := *parens // shallow copy to avoid mutating remoteObjects baseline + parens = &c + if result := q.Receive(parens.Tree, func(v any) any { return receiveRightPadded(r, q, v) }); result != nil { + parens.Tree = coerceToExpressionRP(result) } - return p + return parens } -func (r *JavaReceiver) receiveTypeCast(tc *tree.TypeCast, q *ReceiveQueue) *tree.TypeCast { +func (r *JavaReceiver) VisitTypeCast(tc *tree.TypeCast, p any) tree.J { + q := p.(*ReceiveQueue) c := *tc // shallow copy to avoid mutating remoteObjects baseline tc = &c - result := q.Receive(tc.Clazz, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(tc.Clazz, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { tc.Clazz = result.(*tree.ControlParentheses) } - exprResult := q.Receive(tc.Expr, func(v any) any { return r.parent.Visit(v, q) }) + exprResult := q.Receive(tc.Expr, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if exprResult != nil { tc.Expr = exprResult.(tree.Expression) } return tc } -func (r *JavaReceiver) receiveControlParentheses(cp *tree.ControlParentheses, q *ReceiveQueue) *tree.ControlParentheses { +func (r *JavaReceiver) VisitControlParentheses(cp *tree.ControlParentheses, p any) tree.J { + q := p.(*ReceiveQueue) c := *cp // shallow copy to avoid mutating remoteObjects baseline cp = &c - if result := q.Receive(cp.Tree, func(v any) any { return receiveRightPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(cp.Tree, func(v any) any { return receiveRightPadded(r, q, v) }); result != nil { cp.Tree = coerceToExpressionRP(result) } return cp } -func (r *JavaReceiver) receiveImport(imp *tree.Import, q *ReceiveQueue) *tree.Import { +func (r *JavaReceiver) VisitImport(imp *tree.Import, p any) tree.J { + q := p.(*ReceiveQueue) c := *imp // shallow copy to avoid mutating remoteObjects baseline imp = &c // static (always false for Go, but must receive full LeftPadded protocol) staticBefore := tree.LeftPadded[bool]{Before: tree.EmptySpace, Element: false} - q.Receive(staticBefore, func(v any) any { return receiveLeftPadded(r.parent, q, v) }) + q.Receive(staticBefore, func(v any) any { return receiveLeftPadded(r, q, v) }) // qualid (Expression - could be Literal or FieldAccess depending on direction) - result := q.Receive(imp.Qualid, func(v any) any { return r.parent.Visit(v, q) }) + result := q.Receive(imp.Qualid, func(v any) any { return r.Visit(v.(tree.Tree), q) }) if result != nil { imp.Qualid = result.(tree.Expression) } @@ -775,7 +819,7 @@ func (r *JavaReceiver) receiveImport(imp *tree.Import, q *ReceiveQueue) *tree.Im if imp.Alias != nil { beforeAlias = *imp.Alias } - if result := q.Receive(beforeAlias, func(v any) any { return receiveLeftPadded(r.parent, q, v) }); result != nil { + if result := q.Receive(beforeAlias, func(v any) any { return receiveLeftPadded(r, q, v) }); result != nil { lp := result.(tree.LeftPadded[*tree.Identifier]) imp.Alias = &lp } else { diff --git a/rewrite-go/pkg/rpc/java_sender.go b/rewrite-go/pkg/rpc/java_sender.go index 243aa64063e..2a51d7bdc2c 100644 --- a/rewrite-go/pkg/rpc/java_sender.go +++ b/rewrite-go/pkg/rpc/java_sender.go @@ -19,90 +19,53 @@ package rpc import ( "github.com/google/uuid" "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" ) -// Sender can serialize any AST node. +// Sender serializes a tree node into a SendQueue. Implementations are +// visitor.VisitorI-conformant so they slot into the standard visitor +// dispatch — mirrors rewrite-java's JavaSender, which extends +// JavaVisitor. type Sender interface { - Visit(node any, q *SendQueue) -} - -// JavaSender serializes J (shared Java-like) AST nodes into the send queue. -// Mirrors JavaSender.java for J nodes. + Visit(t tree.Tree, p any) tree.Tree +} + +// JavaSender serializes J (shared Java-like) AST nodes via the visitor +// pattern. Mirrors org.openrewrite.java.internal.rpc.JavaSender. +// +// Architecture: JavaSender embeds visitor.GoVisitor so the framework's +// type-switch dispatch routes calls into JavaSender's VisitX +// overrides. PreVisit handles the cross-cutting fields (id, prefix, +// markers) once per node, mirroring Java's JavaVisitor.preVisit. +// +// Language-specific senders (GoSender) embed JavaSender and override +// the additional G-node Visit methods on top. type JavaSender struct { + visitor.GoVisitor typeSender *JavaTypeSender - parent Sender // the GoSender (or other language sender) that delegates to us -} - -// visitJ dispatches J-node field serialization (after preVisit has been called by the parent). -func (s *JavaSender) visitJ(node any, q *SendQueue) { - switch v := node.(type) { - case *tree.Identifier: - s.sendIdentifier(v, q) - case *tree.Literal: - s.sendLiteral(v, q) - case *tree.Binary: - s.sendBinary(v, q) - case *tree.Block: - s.sendBlock(v, q) - case *tree.Return: - s.sendReturn(v, q) - case *tree.If: - s.sendIf(v, q) - case *tree.Else: - s.sendElse(v, q) - case *tree.Assignment: - s.sendAssignment(v, q) - case *tree.AssignmentOperation: - s.sendAssignmentOperation(v, q) - case *tree.MethodDeclaration: - s.sendMethodDeclaration(v, q) - case *tree.ForLoop: - s.sendForLoop(v, q) - case *tree.ForControl: - s.sendForControl(v, q) - case *tree.ForEachLoop: - s.sendForEachLoop(v, q) - case *tree.ForEachControl: - s.sendForEachControl(v, q) - case *tree.Switch: - s.sendSwitch(v, q) - case *tree.Case: - s.sendCase(v, q) - case *tree.Break: - s.sendBreak(v, q) - case *tree.Continue: - s.sendContinue(v, q) - case *tree.Label: - s.sendLabel(v, q) - case *tree.Empty: - // No fields - case *tree.Unary: - s.sendUnary(v, q) - case *tree.FieldAccess: - s.sendFieldAccess(v, q) - case *tree.MethodInvocation: - s.sendMethodInvocation(v, q) - case *tree.VariableDeclarations: - s.sendVariableDeclarations(v, q) - case *tree.VariableDeclarator: - s.sendVariableDeclarator(v, q) - case *tree.ArrayType: - s.sendArrayType(v, q) - case *tree.ArrayAccess: - s.sendArrayAccess(v, q) - case *tree.ParameterizedType: - s.sendParameterizedType(v, q) - case *tree.ArrayDimension: - s.sendArrayDimension(v, q) - case *tree.Parentheses: - s.sendParentheses(v, q) - case *tree.TypeCast: - s.sendTypeCast(v, q) - case *tree.ControlParentheses: - s.sendControlParentheses(v, q) - case *tree.Import: - s.sendImport(v, q) +} + +// PreVisit serializes the cross-cutting fields of every J node: +// id, prefix, markers. Called by the framework before the +// type-specific VisitX dispatch. ParseError isn't a J node and is +// special-cased at the GoSender layer. +// +// Field access goes through the polymorphic J-interface methods +// (GetID / GetPrefix / GetMarkers), mirroring rewrite-java's +// JavaVisitor.preVisit pattern. +func (s *JavaSender) PreVisit(t tree.Tree, p any) tree.Tree { + j, ok := t.(tree.J) + if !ok { + return t } + q := p.(*SendQueue) + q.GetAndSend(t, func(v any) any { return v.(tree.J).GetID().String() }, nil) + q.GetAndSend(t, func(v any) any { return v.(tree.J).GetPrefix() }, + func(v any) { sendSpace(v.(tree.Space), q) }) + q.GetAndSend(t, func(v any) any { return v.(tree.J).GetMarkers() }, + func(v any) { SendMarkersCodec(v.(tree.Markers), q) }) + _ = j + return t } // visitType sends a JavaType through the type sender with null/Unknown handling. @@ -118,7 +81,8 @@ func (s *JavaSender) visitType(t tree.JavaType, q *SendQueue) { // --- J nodes --- -func (s *JavaSender) sendIdentifier(id *tree.Identifier, q *SendQueue) { +func (s *JavaSender) VisitIdentifier(id *tree.Identifier, p any) tree.J { + q := p.(*SendQueue) // annotations (list) q.GetAndSendList(id, func(v any) []any { @@ -133,7 +97,7 @@ func (s *JavaSender) sendIdentifier(id *tree.Identifier, q *SendQueue) { return result }, func(v any) any { return extractID(v) }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) // simpleName q.GetAndSend(id, func(v any) any { return v.(*tree.Identifier).Name }, nil) // type (as ref) @@ -142,9 +106,11 @@ func (s *JavaSender) sendIdentifier(id *tree.Identifier, q *SendQueue) { // fieldType (as ref) q.GetAndSend(id, func(v any) any { return AsRef(v.(*tree.Identifier).FieldType) }, func(v any) { s.visitType(GetValueNonNull(v).(tree.JavaType), q) }) + return id } -func (s *JavaSender) sendLiteral(lit *tree.Literal, q *SendQueue) { +func (s *JavaSender) VisitLiteral(lit *tree.Literal, p any) tree.J { + q := p.(*SendQueue) // value q.GetAndSend(lit, func(v any) any { return v.(*tree.Literal).Value }, nil) // valueSource (source text) @@ -154,22 +120,26 @@ func (s *JavaSender) sendLiteral(lit *tree.Literal, q *SendQueue) { // type (as ref) q.GetAndSend(lit, func(v any) any { return AsRef(v.(*tree.Literal).Type) }, func(v any) { s.visitType(GetValueNonNull(v).(tree.JavaType), q) }) + return lit } -func (s *JavaSender) sendBinary(b *tree.Binary, q *SendQueue) { +func (s *JavaSender) VisitBinary(b *tree.Binary, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(b, func(v any) any { return v.(*tree.Binary).Left }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(b, func(v any) any { op := v.(*tree.Binary).Operator return tree.LeftPadded[string]{Before: op.Before, Element: op.Element.String(), Markers: op.Markers} - }, func(v any) { sendLeftPadded(s.parent, v, q) }) + }, func(v any) { sendLeftPadded(s, v, q) }) q.GetAndSend(b, func(v any) any { return v.(*tree.Binary).Right }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(b, func(v any) any { return AsRef(v.(*tree.Binary).Type) }, func(v any) { s.visitType(GetValueNonNull(v).(tree.JavaType), q) }) + return b } -func (s *JavaSender) sendBlock(b *tree.Block, q *SendQueue) { +func (s *JavaSender) VisitBlock(b *tree.Block, p any) tree.J { + q := p.(*SendQueue) // static (right-padded bool) - Java's JRightPadded with element=false // Send manually since Go doesn't have RightPadded[bool] sendRightPaddedBool(false, tree.EmptySpace, tree.Markers{}, q) @@ -184,13 +154,15 @@ func (s *JavaSender) sendBlock(b *tree.Block, q *SendQueue) { return result }, func(v any) any { return containerElementID(v) }, - func(v any) { sendRightPadded(s.parent, v, q) }) + func(v any) { sendRightPadded(s, v, q) }) // end space q.GetAndSend(b, func(v any) any { return v.(*tree.Block).End }, func(v any) { sendSpace(v.(tree.Space), q) }) + return b } -func (s *JavaSender) sendReturn(r *tree.Return, q *SendQueue) { +func (s *JavaSender) VisitReturn(r *tree.Return, p any) tree.J { + q := p.(*SendQueue) // Java's J.Return has a single expression; Go has multiple // The first expression maps to J.Return.expression q.GetAndSend(r, func(v any) any { @@ -199,10 +171,12 @@ func (s *JavaSender) sendReturn(r *tree.Return, q *SendQueue) { return exprs[0].Element } return nil - }, func(v any) { s.parent.Visit(v, q) }) + }, func(v any) { s.Visit(v.(tree.Tree), q) }) + return r } -func (s *JavaSender) sendIf(i *tree.If, q *SendQueue) { +func (s *JavaSender) VisitIf(i *tree.If, p any) tree.J { + q := p.(*SendQueue) // ifCondition - reuse cached ControlParentheses if available, otherwise create new q.GetAndSend(i, func(v any) any { ifNode := v.(*tree.If) @@ -216,14 +190,14 @@ func (s *JavaSender) sendIf(i *tree.If, q *SendQueue) { Markers: tree.Markers{ID: uuid.New()}, Tree: tree.RightPadded[tree.Expression]{Element: ifNode.Condition, After: tree.EmptySpace}, } - }, func(v any) { s.parent.Visit(v, q) }) + }, func(v any) { s.Visit(v.(tree.Tree), q) }) // thenPart (right-padded) q.GetAndSend(i, func(v any) any { return tree.RightPadded[tree.Statement]{ Element: v.(*tree.If).Then, After: tree.EmptySpace, } - }, func(v any) { sendRightPadded(s.parent, v, q) }) + }, func(v any) { sendRightPadded(s, v, q) }) // elsePart - wrap in Else node for Java's J.If.Else model q.GetAndSend(i, func(v any) any { ep := v.(*tree.If).ElsePart @@ -236,83 +210,105 @@ func (s *JavaSender) sendIf(i *tree.If, q *SendQueue) { Markers: tree.Markers{ID: uuid.New()}, Body: tree.RightPadded[tree.Statement]{Element: ep.Element.(tree.Statement), After: tree.EmptySpace}, } - }, func(v any) { s.parent.Visit(v, q) }) + }, func(v any) { s.Visit(v.(tree.Tree), q) }) + return i } -func (s *JavaSender) sendElse(el *tree.Else, q *SendQueue) { +func (s *JavaSender) VisitElse(el *tree.Else, p any) tree.J { + q := p.(*SendQueue) // body (right-padded Statement) q.GetAndSend(el, func(v any) any { return v.(*tree.Else).Body }, - func(v any) { sendRightPadded(s.parent, v, q) }) + func(v any) { sendRightPadded(s, v, q) }) + return el } -func (s *JavaSender) sendAssignment(a *tree.Assignment, q *SendQueue) { +func (s *JavaSender) VisitAssignment(a *tree.Assignment, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(a, func(v any) any { return v.(*tree.Assignment).Variable }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(a, func(v any) any { return v.(*tree.Assignment).Value }, - func(v any) { sendLeftPadded(s.parent, v, q) }) + func(v any) { sendLeftPadded(s, v, q) }) q.GetAndSend(a, func(v any) any { return AsRef(v.(*tree.Assignment).Type) }, func(v any) { s.visitType(GetValueNonNull(v).(tree.JavaType), q) }) + return a } -func (s *JavaSender) sendAssignmentOperation(a *tree.AssignmentOperation, q *SendQueue) { +func (s *JavaSender) VisitAssignmentOperation(a *tree.AssignmentOperation, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(a, func(v any) any { return v.(*tree.AssignmentOperation).Variable }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(a, func(v any) any { op := v.(*tree.AssignmentOperation).Operator return tree.LeftPadded[string]{Before: op.Before, Element: op.Element.String(), Markers: op.Markers} - }, func(v any) { sendLeftPadded(s.parent, v, q) }) + }, func(v any) { sendLeftPadded(s, v, q) }) q.GetAndSend(a, func(v any) any { return v.(*tree.AssignmentOperation).Assignment }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(a, func(v any) any { return AsRef(v.(*tree.AssignmentOperation).Type) }, func(v any) { s.visitType(GetValueNonNull(v).(tree.JavaType), q) }) + return a } -func (s *JavaSender) sendMethodDeclaration(md *tree.MethodDeclaration, q *SendQueue) { +func (s *JavaSender) VisitMethodDeclaration(md *tree.MethodDeclaration, p any) tree.J { + q := p.(*SendQueue) // Go's MethodDeclaration maps to parts of Java's MethodDeclaration // Java sends: leadingAnnotations, modifiers, typeParameters, returnTypeExpression, // name annotations, name, parameters, throws, body, defaultValue, methodType // Go: receiver, name, parameters, returnType, body, methodType - // leadingAnnotations (empty for Go) - q.GetAndSendList(md, func(_ any) []any { return nil }, func(_ any) any { return nil }, nil) + // leadingAnnotations (`//go:` directives modeled as J.Annotation) + q.GetAndSendList(md, + func(v any) []any { + anns := v.(*tree.MethodDeclaration).LeadingAnnotations + result := make([]any, len(anns)) + for i, a := range anns { + result[i] = a + } + return result + }, + func(v any) any { return extractID(v) }, + func(v any) { s.Visit(v.(tree.Tree), q) }) // modifiers (empty for Go) q.GetAndSendList(md, func(_ any) []any { return nil }, func(_ any) any { return nil }, nil) // typeParameters (nil for Go) q.GetAndSend(md, func(_ any) any { return nil }, nil) // returnTypeExpression q.GetAndSend(md, func(v any) any { return v.(*tree.MethodDeclaration).ReturnType }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) // name annotations (empty) q.GetAndSendList(md, func(_ any) []any { return nil }, func(_ any) any { return nil }, nil) // name q.GetAndSend(md, func(v any) any { return v.(*tree.MethodDeclaration).Name }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) // parameters (container) q.GetAndSend(md, func(v any) any { return v.(*tree.MethodDeclaration).Parameters }, - func(v any) { sendContainer(s.parent, v, q) }) + func(v any) { sendContainer(s, v, q) }) // throws (nil for Go) q.GetAndSend(md, func(_ any) any { return nil }, nil) // body q.GetAndSend(md, func(v any) any { return v.(*tree.MethodDeclaration).Body }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) // defaultValue (nil for Go) q.GetAndSend(md, func(_ any) any { return nil }, nil) // methodType (as ref) q.GetAndSend(md, func(v any) any { return AsRef(v.(*tree.MethodDeclaration).MethodType) }, func(v any) { s.visitType(GetValueNonNull(v).(tree.JavaType), q) }) + return md } -func (s *JavaSender) sendForLoop(f *tree.ForLoop, q *SendQueue) { +func (s *JavaSender) VisitForLoop(f *tree.ForLoop, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(f, func(v any) any { ctrl := v.(*tree.ForLoop).Control return &ctrl - }, func(v any) { s.parent.Visit(v, q) }) + }, func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(f, func(v any) any { return tree.RightPadded[tree.Statement]{Element: v.(*tree.ForLoop).Body, After: tree.EmptySpace} - }, func(v any) { sendRightPadded(s.parent, v, q) }) + }, func(v any) { sendRightPadded(s, v, q) }) + return f } -func (s *JavaSender) sendForControl(fc *tree.ForControl, q *SendQueue) { +func (s *JavaSender) VisitForControl(fc *tree.ForControl, p any) tree.J { + q := p.(*SendQueue) // init (list of right-padded) q.GetAndSendList(fc, func(v any) []any { @@ -323,13 +319,13 @@ func (s *JavaSender) sendForControl(fc *tree.ForControl, q *SendQueue) { return []any{*init} }, func(v any) any { return containerElementID(v) }, - func(v any) { sendRightPadded(s.parent, v, q) }) + func(v any) { sendRightPadded(s, v, q) }) // condition (right-padded) — dereference pointer q.GetAndSend(fc, func(v any) any { cond := v.(*tree.ForControl).Condition if cond == nil { return nil } return *cond - }, func(v any) { sendRightPadded(s.parent, v, q) }) + }, func(v any) { sendRightPadded(s, v, q) }) // update (list of right-padded) q.GetAndSendList(fc, func(v any) []any { @@ -340,41 +336,47 @@ func (s *JavaSender) sendForControl(fc *tree.ForControl, q *SendQueue) { return []any{*update} }, func(v any) any { return containerElementID(v) }, - func(v any) { sendRightPadded(s.parent, v, q) }) + func(v any) { sendRightPadded(s, v, q) }) + return fc } -func (s *JavaSender) sendForEachLoop(f *tree.ForEachLoop, q *SendQueue) { +func (s *JavaSender) VisitForEachLoop(f *tree.ForEachLoop, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(f, func(v any) any { ctrl := v.(*tree.ForEachLoop).Control return &ctrl - }, func(v any) { s.parent.Visit(v, q) }) + }, func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(f, func(v any) any { return tree.RightPadded[tree.Statement]{Element: v.(*tree.ForEachLoop).Body, After: tree.EmptySpace} - }, func(v any) { sendRightPadded(s.parent, v, q) }) + }, func(v any) { sendRightPadded(s, v, q) }) + return f } -func (s *JavaSender) sendForEachControl(fc *tree.ForEachControl, q *SendQueue) { +func (s *JavaSender) VisitForEachControl(fc *tree.ForEachControl, p any) tree.J { + q := p.(*SendQueue) // Go sends: key (right-padded), value (right-padded), operator (left-padded string), iterable // Java GolangReceiver override reads this format q.GetAndSend(fc, func(v any) any { k := v.(*tree.ForEachControl).Key if k == nil { return nil } return *k - }, func(v any) { sendRightPadded(s.parent, v, q) }) + }, func(v any) { sendRightPadded(s, v, q) }) q.GetAndSend(fc, func(v any) any { val := v.(*tree.ForEachControl).Value if val == nil { return nil } return *val - }, func(v any) { sendRightPadded(s.parent, v, q) }) + }, func(v any) { sendRightPadded(s, v, q) }) q.GetAndSend(fc, func(v any) any { op := v.(*tree.ForEachControl).Operator return tree.LeftPadded[string]{Before: op.Before, Element: op.Element.String(), Markers: op.Markers} - }, func(v any) { sendLeftPadded(s.parent, v, q) }) + }, func(v any) { sendLeftPadded(s, v, q) }) q.GetAndSend(fc, func(v any) any { return v.(*tree.ForEachControl).Iterable }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return fc } -func (s *JavaSender) sendSwitch(sw *tree.Switch, q *SendQueue) { +func (s *JavaSender) VisitSwitch(sw *tree.Switch, p any) tree.J { + q := p.(*SendQueue) // selector - wrap tag in ControlParentheses for Java's J.Switch model q.GetAndSend(sw, func(v any) any { tag := v.(*tree.Switch).Tag @@ -390,96 +392,139 @@ func (s *JavaSender) sendSwitch(sw *tree.Switch, q *SendQueue) { Markers: tree.Markers{ID: uuid.New()}, Tree: tree.RightPadded[tree.Expression]{Element: inner, After: tree.EmptySpace}, } - }, func(v any) { s.parent.Visit(v, q) }) + }, func(v any) { s.Visit(v.(tree.Tree), q) }) // cases (Block) q.GetAndSend(sw, func(v any) any { return v.(*tree.Switch).Body }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return sw } -func (s *JavaSender) sendCase(c *tree.Case, q *SendQueue) { +func (s *JavaSender) VisitCase(c *tree.Case, p any) tree.J { + q := p.(*SendQueue) // type (enum value) q.GetAndSend(c, func(_ any) any { return "Statement" }, nil) // caseLabels (container) q.GetAndSend(c, func(v any) any { return v.(*tree.Case).Expressions }, - func(v any) { sendContainer(s.parent, v, q) }) + func(v any) { sendContainer(s, v, q) }) // statements (container) q.GetAndSend(c, func(v any) any { body := v.(*tree.Case).Body result := make([]tree.RightPadded[tree.Statement], len(body)) copy(result, body) return tree.Container[tree.Statement]{Elements: result} - }, func(v any) { sendContainer(s.parent, v, q) }) + }, func(v any) { sendContainer(s, v, q) }) // body (right-padded, nil for Go-style case) q.GetAndSend(c, func(_ any) any { return nil }, nil) // guard (nil for Go) q.GetAndSend(c, func(_ any) any { return nil }, nil) + return c } -func (s *JavaSender) sendBreak(b *tree.Break, q *SendQueue) { +func (s *JavaSender) VisitBreak(b *tree.Break, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(b, func(v any) any { return v.(*tree.Break).Label }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return b } -func (s *JavaSender) sendContinue(c *tree.Continue, q *SendQueue) { +func (s *JavaSender) VisitContinue(c *tree.Continue, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(c, func(v any) any { return v.(*tree.Continue).Label }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return c } -func (s *JavaSender) sendLabel(l *tree.Label, q *SendQueue) { +func (s *JavaSender) VisitLabel(l *tree.Label, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(l, func(v any) any { return v.(*tree.Label).Name }, - func(v any) { sendRightPadded(s.parent, v, q) }) + func(v any) { sendRightPadded(s, v, q) }) q.GetAndSend(l, func(v any) any { return v.(*tree.Label).Statement }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return l +} + +// sendAnnotation matches JavaSender.visitAnnotation field order: +// annotationType, then nullable arguments container. +func (s *JavaSender) VisitAnnotation(ann *tree.Annotation, p any) tree.J { + q := p.(*SendQueue) + q.GetAndSend(ann, func(v any) any { return v.(*tree.Annotation).AnnotationType }, + func(v any) { s.Visit(v.(tree.Tree), q) }) + q.GetAndSend(ann, func(v any) any { + args := v.(*tree.Annotation).Arguments + if args == nil { + return nil + } + return *args + }, func(v any) { sendContainer(s, v, q) }) + return ann } -func (s *JavaSender) sendUnary(u *tree.Unary, q *SendQueue) { +func (s *JavaSender) VisitUnary(u *tree.Unary, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(u, func(v any) any { op := v.(*tree.Unary).Operator return tree.LeftPadded[string]{Before: op.Before, Element: op.Element.String(), Markers: op.Markers} - }, func(v any) { sendLeftPadded(s.parent, v, q) }) + }, func(v any) { sendLeftPadded(s, v, q) }) q.GetAndSend(u, func(v any) any { return v.(*tree.Unary).Operand }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(u, func(v any) any { return AsRef(v.(*tree.Unary).Type) }, func(v any) { s.visitType(GetValueNonNull(v).(tree.JavaType), q) }) + return u } -func (s *JavaSender) sendFieldAccess(fa *tree.FieldAccess, q *SendQueue) { +func (s *JavaSender) VisitFieldAccess(fa *tree.FieldAccess, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(fa, func(v any) any { return v.(*tree.FieldAccess).Target }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(fa, func(v any) any { return v.(*tree.FieldAccess).Name }, - func(v any) { sendLeftPadded(s.parent, v, q) }) + func(v any) { sendLeftPadded(s, v, q) }) q.GetAndSend(fa, func(v any) any { return AsRef(v.(*tree.FieldAccess).Type) }, func(v any) { s.visitType(GetValueNonNull(v).(tree.JavaType), q) }) + return fa } -func (s *JavaSender) sendMethodInvocation(mi *tree.MethodInvocation, q *SendQueue) { +func (s *JavaSender) VisitMethodInvocation(mi *tree.MethodInvocation, p any) tree.J { + q := p.(*SendQueue) // select (right-padded, nullable) — dereference pointer q.GetAndSend(mi, func(v any) any { sel := v.(*tree.MethodInvocation).Select if sel == nil { return nil } return *sel - }, func(v any) { sendRightPadded(s.parent, v, q) }) + }, func(v any) { sendRightPadded(s, v, q) }) // typeParameters (nil for Go) q.GetAndSend(mi, func(_ any) any { return nil }, nil) // name q.GetAndSend(mi, func(v any) any { return v.(*tree.MethodInvocation).Name }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) // arguments (container) q.GetAndSend(mi, func(v any) any { return v.(*tree.MethodInvocation).Arguments }, - func(v any) { sendContainer(s.parent, v, q) }) + func(v any) { sendContainer(s, v, q) }) // methodType (as ref) q.GetAndSend(mi, func(v any) any { return AsRef(v.(*tree.MethodInvocation).MethodType) }, func(v any) { s.visitType(GetValueNonNull(v).(tree.JavaType), q) }) + return mi } -func (s *JavaSender) sendVariableDeclarations(vd *tree.VariableDeclarations, q *SendQueue) { - // leadingAnnotations (empty -- Go has no annotations) - q.GetAndSendList(vd, func(_ any) []any { return []any{} }, func(_ any) any { return nil }, nil) +func (s *JavaSender) VisitVariableDeclarations(vd *tree.VariableDeclarations, p any) tree.J { + q := p.(*SendQueue) + // leadingAnnotations (struct field tags + `//go:` directives, + // modeled as J.Annotation per the Java contract) + q.GetAndSendList(vd, + func(v any) []any { + anns := v.(*tree.VariableDeclarations).LeadingAnnotations + result := make([]any, len(anns)) + for i, a := range anns { + result[i] = a + } + return result + }, + func(v any) any { return extractID(v) }, + func(v any) { s.Visit(v.(tree.Tree), q) }) // modifiers (empty -- Go has no modifiers) q.GetAndSendList(vd, func(_ any) []any { return []any{} }, func(_ any) any { return nil }, nil) // typeExpression q.GetAndSend(vd, func(v any) any { return v.(*tree.VariableDeclarations).TypeExpr }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) // varargs q.GetAndSend(vd, func(v any) any { va := v.(*tree.VariableDeclarations).Varargs @@ -499,14 +544,16 @@ func (s *JavaSender) sendVariableDeclarations(vd *tree.VariableDeclarations, q * return result }, func(v any) any { return containerElementID(v) }, - func(v any) { sendRightPadded(s.parent, v, q) }) + func(v any) { sendRightPadded(s, v, q) }) + return vd } -func (s *JavaSender) sendVariableDeclarator(vd *tree.VariableDeclarator, q *SendQueue) { +func (s *JavaSender) VisitVariableDeclarator(vd *tree.VariableDeclarator, p any) tree.J { + q := p.(*SendQueue) // Java's NamedVariable: declarator (Identifier), dimensionsAfterName, initializer, variableType // Go: Name, Initializer q.GetAndSend(vd, func(v any) any { return v.(*tree.VariableDeclarator).Name }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) // dimensionsAfterName (empty for Go) q.GetAndSendList(vd, func(_ any) []any { return nil }, func(_ any) any { return nil }, nil) // initializer (left-padded, nullable) — dereference pointer @@ -514,75 +561,92 @@ func (s *JavaSender) sendVariableDeclarator(vd *tree.VariableDeclarator, q *Send init := v.(*tree.VariableDeclarator).Initializer if init == nil { return nil } return *init - }, func(v any) { sendLeftPadded(s.parent, v, q) }) + }, func(v any) { sendLeftPadded(s, v, q) }) // variableType (as ref) - not yet on Go VariableDeclarator q.GetAndSend(vd, func(_ any) any { return nil }, nil) + return vd } -func (s *JavaSender) sendArrayType(at *tree.ArrayType, q *SendQueue) { +func (s *JavaSender) VisitArrayType(at *tree.ArrayType, p any) tree.J { + q := p.(*SendQueue) // elementType q.GetAndSend(at, func(v any) any { return v.(*tree.ArrayType).ElementType }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) // annotations (empty for Go) q.GetAndSendList(at, func(_ any) []any { return nil }, func(_ any) any { return nil }, nil) // dimension (left-padded) q.GetAndSend(at, func(v any) any { return v.(*tree.ArrayType).Dimension }, - func(v any) { sendLeftPadded(s.parent, v, q) }) + func(v any) { sendLeftPadded(s, v, q) }) // type q.GetAndSend(at, func(v any) any { return v.(*tree.ArrayType).Type }, nil) + return at } -func (s *JavaSender) sendArrayAccess(aa *tree.ArrayAccess, q *SendQueue) { +func (s *JavaSender) VisitArrayAccess(aa *tree.ArrayAccess, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(aa, func(v any) any { return v.(*tree.ArrayAccess).Indexed }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(aa, func(v any) any { return v.(*tree.ArrayAccess).Dimension }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return aa } -func (s *JavaSender) sendParameterizedType(pt *tree.ParameterizedType, q *SendQueue) { +func (s *JavaSender) VisitParameterizedType(pt *tree.ParameterizedType, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(pt, func(v any) any { return v.(*tree.ParameterizedType).Clazz }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(pt, func(v any) any { return v.(*tree.ParameterizedType).TypeParameters }, - func(v any) { sendContainer(s.parent, v, q) }) + func(v any) { sendContainer(s, v, q) }) q.GetAndSend(pt, func(v any) any { return AsRef(v.(*tree.ParameterizedType).Type) }, func(v any) { s.visitType(GetValueNonNull(v).(tree.JavaType), q) }) + return pt } -func (s *JavaSender) sendArrayDimension(ad *tree.ArrayDimension, q *SendQueue) { +func (s *JavaSender) VisitArrayDimension(ad *tree.ArrayDimension, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(ad, func(v any) any { return ad.Index }, - func(v any) { sendRightPadded(s.parent, v, q) }) + func(v any) { sendRightPadded(s, v, q) }) + return ad } -func (s *JavaSender) sendParentheses(p *tree.Parentheses, q *SendQueue) { - q.GetAndSend(p, func(v any) any { return v.(*tree.Parentheses).Tree }, - func(v any) { sendRightPadded(s.parent, v, q) }) +func (s *JavaSender) VisitParentheses(parens *tree.Parentheses, p any) tree.J { + q := p.(*SendQueue) + q.GetAndSend(parens, func(v any) any { return v.(*tree.Parentheses).Tree }, + func(v any) { sendRightPadded(s, v, q) }) + return parens } -func (s *JavaSender) sendTypeCast(tc *tree.TypeCast, q *SendQueue) { +func (s *JavaSender) VisitTypeCast(tc *tree.TypeCast, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(tc, func(v any) any { return v.(*tree.TypeCast).Clazz }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) q.GetAndSend(tc, func(v any) any { return v.(*tree.TypeCast).Expr }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) + return tc } -func (s *JavaSender) sendControlParentheses(cp *tree.ControlParentheses, q *SendQueue) { +func (s *JavaSender) VisitControlParentheses(cp *tree.ControlParentheses, p any) tree.J { + q := p.(*SendQueue) q.GetAndSend(cp, func(v any) any { return v.(*tree.ControlParentheses).Tree }, - func(v any) { sendRightPadded(s.parent, v, q) }) + func(v any) { sendRightPadded(s, v, q) }) + return cp } -func (s *JavaSender) sendImport(imp *tree.Import, q *SendQueue) { +func (s *JavaSender) VisitImport(imp *tree.Import, p any) tree.J { + q := p.(*SendQueue) // Java Import: static (left-padded), qualid, alias (left-padded) // Static is always false for Go q.GetAndSend(imp, func(_ any) any { return tree.LeftPadded[bool]{Before: tree.EmptySpace, Element: false} - }, func(v any) { sendLeftPadded(s.parent, v, q) }) + }, func(v any) { sendLeftPadded(s, v, q) }) // qualid q.GetAndSend(imp, func(v any) any { return v.(*tree.Import).Qualid }, - func(v any) { s.parent.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) // alias — dereference pointer q.GetAndSend(imp, func(v any) any { a := v.(*tree.Import).Alias if a == nil { return nil } return *a - }, func(v any) { sendLeftPadded(s.parent, v, q) }) + }, func(v any) { sendLeftPadded(s, v, q) }) + return imp } diff --git a/rewrite-go/pkg/rpc/marker_codec_test.go b/rewrite-go/pkg/rpc/marker_codec_test.go new file mode 100644 index 00000000000..fec3e2ff959 --- /dev/null +++ b/rewrite-go/pkg/rpc/marker_codec_test.go @@ -0,0 +1,138 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rpc + +import ( + "reflect" + "testing" + + "github.com/google/uuid" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +// roundTripMarkers serializes `before` via SendMarkersCodec, then feeds the +// emitted RpcObjectData stream back into a ReceiveQueue and reads it via +// receiveMarkersCodec. Returns whatever the receiver produced. +func roundTripMarkers(t *testing.T, before tree.Markers) tree.Markers { + t.Helper() + var messages []RpcObjectData + sendQ := NewSendQueue(1000, func(batch []RpcObjectData) { + messages = append(messages, batch...) + }, make(map[uintptr]int)) + SendMarkersCodec(before, sendQ) + sendQ.Flush() + + // Pretend the wire delivers the captured stream as a single batch. + delivered := false + recvQ := NewReceiveQueue(make(map[int]any), func() []RpcObjectData { + if delivered { + return nil + } + delivered = true + return messages + }) + // receiveMarkersCodec expects the receive queue positioned at the + // Markers ID slot, matching what SendMarkersCodec emits. + return receiveMarkersCodec(recvQ, tree.Markers{}) +} + +func TestGoProjectMarkerRoundTrip(t *testing.T) { + id := uuid.MustParse("11111111-2222-3333-4444-555555555555") + gp := tree.GoProject{Ident: id, ProjectName: "example/foo"} + before := tree.Markers{ID: uuid.New(), Entries: []tree.Marker{gp}} + + after := roundTripMarkers(t, before) + if len(after.Entries) != 1 { + t.Fatalf("entries: want 1, got %d", len(after.Entries)) + } + got, ok := after.Entries[0].(tree.GoProject) + if !ok { + t.Fatalf("entry is %T, want tree.GoProject", after.Entries[0]) + } + if got.Ident != id { + t.Errorf("Ident: want %s, got %s", id, got.Ident) + } + if got.ProjectName != "example/foo" { + t.Errorf("ProjectName: want %q, got %q", "example/foo", got.ProjectName) + } +} + +func TestGoResolutionResultMarkerRoundTrip(t *testing.T) { + id := uuid.MustParse("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") + mrr := tree.GoResolutionResult{ + Ident: id, + ModulePath: "example.com/foo", + GoVersion: "1.22", + Toolchain: "go1.22.5", + Path: "/tmp/go.mod", + Requires: []tree.GoRequire{ + {ModulePath: "github.com/google/uuid", Version: "v1.6.0"}, + {ModulePath: "golang.org/x/mod", Version: "v0.35.0", Indirect: true}, + }, + Replaces: []tree.GoReplace{ + {OldPath: "github.com/x/y", NewPath: "../local/y"}, + {OldPath: "github.com/a/b", OldVersion: "v1.0.0", NewPath: "github.com/forked/b", NewVersion: "v1.0.1"}, + }, + Excludes: []tree.GoExclude{ + {ModulePath: "github.com/bad", Version: "v0.0.1"}, + }, + Retracts: []tree.GoRetract{ + {VersionRange: "v0.0.5", Rationale: "deleted main.go"}, + {VersionRange: "[v1.0.0, v1.0.5]"}, + }, + ResolvedDependencies: []tree.GoResolvedDependency{ + {ModulePath: "github.com/google/uuid", Version: "v1.6.0", ModuleHash: "h1:abc=", GoModHash: "h1:def="}, + }, + } + before := tree.Markers{ID: uuid.New(), Entries: []tree.Marker{mrr}} + + after := roundTripMarkers(t, before) + if len(after.Entries) != 1 { + t.Fatalf("entries: want 1, got %d", len(after.Entries)) + } + got, ok := after.Entries[0].(tree.GoResolutionResult) + if !ok { + t.Fatalf("entry is %T, want tree.GoResolutionResult", after.Entries[0]) + } + if !reflect.DeepEqual(mrr, got) { + t.Errorf("round-trip mismatch\nbefore: %+v\nafter: %+v", mrr, got) + } +} + +func TestGoResolutionResultEmptyListsRoundTrip(t *testing.T) { + // Mirrors the recent rewrite-core fix where empty descriptor collections + // must serialize as empty arrays, not be omitted. After round-trip an + // initially-empty list should still be present and empty. + id := uuid.MustParse("99999999-0000-0000-0000-000000000000") + mrr := tree.GoResolutionResult{ + Ident: id, + ModulePath: "example.com/empty", + Path: "go.mod", + Requires: []tree.GoRequire{}, + Replaces: []tree.GoReplace{}, + Excludes: []tree.GoExclude{}, + Retracts: []tree.GoRetract{}, + } + before := tree.Markers{ID: uuid.New(), Entries: []tree.Marker{mrr}} + + after := roundTripMarkers(t, before) + got := after.Entries[0].(tree.GoResolutionResult) + if got.ModulePath != "example.com/empty" { + t.Errorf("ModulePath: want %q, got %q", "example.com/empty", got.ModulePath) + } +} diff --git a/rewrite-go/pkg/rpc/node_helpers.go b/rewrite-go/pkg/rpc/node_helpers.go index 9f0bdf3ec6f..6e7ea4a278a 100644 --- a/rewrite-go/pkg/rpc/node_helpers.go +++ b/rewrite-go/pkg/rpc/node_helpers.go @@ -17,594 +17,64 @@ package rpc import ( + "reflect" + "github.com/google/uuid" "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" ) -// nodeID extracts and formats the ID from any AST node for serialization. -func nodeID(v any) any { - id := extractID(v) - if id == nil { - return nil - } - // Convert uuid.UUID to string for JSON serialization - if u, ok := id.(uuid.UUID); ok { - return u.String() +// extractID returns the ID of any AST node, polymorphically via the J +// interface. Used as a list-element ID extractor in +// `q.GetAndSendList(...)`. Non-J values return uuid.Nil. +func extractID(v any) any { + if j, ok := v.(tree.J); ok { + return j.GetID() } - return id + return uuid.Nil } -// nodePrefix extracts the prefix Space from any AST node (G or J). -func nodePrefix(v any) any { - switch n := v.(type) { - // G nodes - case *tree.CompilationUnit: - return n.Prefix - case *tree.GoStmt: - return n.Prefix - case *tree.Defer: - return n.Prefix - case *tree.Send: - return n.Prefix - case *tree.Goto: - return n.Prefix - case *tree.Fallthrough: - return n.Prefix - case *tree.Composite: - return n.Prefix - case *tree.KeyValue: - return n.Prefix - case *tree.Slice: - return n.Prefix - case *tree.MapType: - return n.Prefix - case *tree.PointerType: - return n.Prefix - case *tree.StatementExpression: - return n.Prefix - case *tree.Channel: - return n.Prefix - case *tree.FuncType: - return n.Prefix - case *tree.StructType: - return n.Prefix - case *tree.InterfaceType: - return n.Prefix - case *tree.TypeList: - return n.Prefix - case *tree.TypeDecl: - return n.Prefix - case *tree.MultiAssignment: - return n.Prefix - case *tree.CommClause: - return n.Prefix - case *tree.IndexList: - return n.Prefix - // J nodes - case *tree.Identifier: - return n.Prefix - case *tree.Literal: - return n.Prefix - case *tree.Binary: - return n.Prefix - case *tree.Block: - return n.Prefix - case *tree.Return: - return n.Prefix - case *tree.If: - return n.Prefix - case *tree.Else: - return n.Prefix - case *tree.Assignment: - return n.Prefix - case *tree.AssignmentOperation: - return n.Prefix - case *tree.MethodDeclaration: - return n.Prefix - case *tree.ForLoop: - return n.Prefix - case *tree.ForControl: - return n.Prefix - case *tree.ForEachLoop: - return n.Prefix - case *tree.ForEachControl: - return n.Prefix - case *tree.Switch: - return n.Prefix - case *tree.Case: - return n.Prefix - case *tree.Break: - return n.Prefix - case *tree.Continue: - return n.Prefix - case *tree.Label: - return n.Prefix - case *tree.Empty: - return n.Prefix - case *tree.Unary: - return n.Prefix - case *tree.FieldAccess: - return n.Prefix - case *tree.MethodInvocation: - return n.Prefix - case *tree.VariableDeclarations: - return n.Prefix - case *tree.VariableDeclarator: - return n.Prefix - case *tree.ArrayType: - return n.Prefix - case *tree.ArrayAccess: - return n.Prefix - case *tree.ArrayDimension: - return n.Prefix - case *tree.Parentheses: - return n.Prefix - case *tree.TypeCast: - return n.Prefix - case *tree.ControlParentheses: - return n.Prefix - case *tree.Import: - return n.Prefix - default: - return tree.EmptySpace +// withPrefixViaReflection invokes the concrete type's typed +// `WithPrefix(Space) *T` method via reflection and returns the result +// as tree.Tree. Mirrors what rewrite-java's `j.withPrefix(...)` does +// polymorphically — Go can't express that on the J interface because +// the typed method returns the concrete pointer (not J), and Go +// doesn't allow covariant return types on interface methods. Reflection +// is the cleanest way to call it without exposing in-place mutators +// on the public J interface (which would silently bypass the visitor +// framework's pointer-identity change detection). +// +// Used only by JavaReceiver.PreVisit. Returns t unchanged if the type +// doesn't implement WithPrefix (defensive — every J-conformant type +// does). +func withPrefixViaReflection(t tree.Tree, prefix tree.Space) tree.Tree { + rv := reflect.ValueOf(t) + m := rv.MethodByName("WithPrefix") + if !m.IsValid() { + return t } -} - -// nodeMarkers extracts the Markers from any AST node (G or J). -func nodeMarkers(v any) any { - switch n := v.(type) { - // G nodes - case *tree.CompilationUnit: - return n.Markers - case *tree.GoStmt: - return n.Markers - case *tree.Defer: - return n.Markers - case *tree.Send: - return n.Markers - case *tree.Goto: - return n.Markers - case *tree.Fallthrough: - return n.Markers - case *tree.Composite: - return n.Markers - case *tree.KeyValue: - return n.Markers - case *tree.Slice: - return n.Markers - case *tree.MapType: - return n.Markers - case *tree.PointerType: - return n.Markers - case *tree.StatementExpression: - return n.Markers - case *tree.Channel: - return n.Markers - case *tree.FuncType: - return n.Markers - case *tree.StructType: - return n.Markers - case *tree.InterfaceType: - return n.Markers - case *tree.TypeList: - return n.Markers - case *tree.TypeDecl: - return n.Markers - case *tree.MultiAssignment: - return n.Markers - case *tree.CommClause: - return n.Markers - case *tree.IndexList: - return n.Markers - // J nodes - case *tree.Identifier: - return n.Markers - case *tree.Literal: - return n.Markers - case *tree.Binary: - return n.Markers - case *tree.Block: - return n.Markers - case *tree.Return: - return n.Markers - case *tree.If: - return n.Markers - case *tree.Else: - return n.Markers - case *tree.Assignment: - return n.Markers - case *tree.AssignmentOperation: - return n.Markers - case *tree.MethodDeclaration: - return n.Markers - case *tree.ForLoop: - return n.Markers - case *tree.ForControl: - return n.Markers - case *tree.ForEachLoop: - return n.Markers - case *tree.ForEachControl: - return n.Markers - case *tree.Switch: - return n.Markers - case *tree.Case: - return n.Markers - case *tree.Break: - return n.Markers - case *tree.Continue: - return n.Markers - case *tree.Label: - return n.Markers - case *tree.Empty: - return n.Markers - case *tree.Unary: - return n.Markers - case *tree.FieldAccess: - return n.Markers - case *tree.MethodInvocation: - return n.Markers - case *tree.VariableDeclarations: - return n.Markers - case *tree.VariableDeclarator: - return n.Markers - case *tree.ArrayType: - return n.Markers - case *tree.ArrayAccess: - return n.Markers - case *tree.ArrayDimension: - return n.Markers - case *tree.Parentheses: - return n.Markers - case *tree.TypeCast: - return n.Markers - case *tree.ControlParentheses: - return n.Markers - case *tree.Import: - return n.Markers - default: - return tree.Markers{} + results := m.Call([]reflect.Value{reflect.ValueOf(prefix)}) + if len(results) == 0 { + return t } -} - -// setPrefix sets the prefix Space on any AST node (G or J). -func setPrefix(v any, prefix tree.Space) { - switch n := v.(type) { - // G nodes - case *tree.CompilationUnit: - n.Prefix = prefix - case *tree.GoStmt: - n.Prefix = prefix - case *tree.Defer: - n.Prefix = prefix - case *tree.Send: - n.Prefix = prefix - case *tree.Goto: - n.Prefix = prefix - case *tree.Fallthrough: - n.Prefix = prefix - case *tree.Composite: - n.Prefix = prefix - case *tree.KeyValue: - n.Prefix = prefix - case *tree.Slice: - n.Prefix = prefix - case *tree.MapType: - n.Prefix = prefix - case *tree.PointerType: - n.Prefix = prefix - case *tree.StatementExpression: - n.Prefix = prefix - case *tree.Channel: - n.Prefix = prefix - case *tree.FuncType: - n.Prefix = prefix - case *tree.StructType: - n.Prefix = prefix - case *tree.InterfaceType: - n.Prefix = prefix - case *tree.TypeList: - n.Prefix = prefix - case *tree.TypeDecl: - n.Prefix = prefix - case *tree.MultiAssignment: - n.Prefix = prefix - case *tree.CommClause: - n.Prefix = prefix - case *tree.IndexList: - n.Prefix = prefix - // J nodes - case *tree.Identifier: - n.Prefix = prefix - case *tree.Literal: - n.Prefix = prefix - case *tree.Binary: - n.Prefix = prefix - case *tree.Block: - n.Prefix = prefix - case *tree.Return: - n.Prefix = prefix - case *tree.If: - n.Prefix = prefix - case *tree.Else: - n.Prefix = prefix - case *tree.Assignment: - n.Prefix = prefix - case *tree.AssignmentOperation: - n.Prefix = prefix - case *tree.MethodDeclaration: - n.Prefix = prefix - case *tree.ForLoop: - n.Prefix = prefix - case *tree.ForControl: - n.Prefix = prefix - case *tree.ForEachLoop: - n.Prefix = prefix - case *tree.ForEachControl: - n.Prefix = prefix - case *tree.Switch: - n.Prefix = prefix - case *tree.Case: - n.Prefix = prefix - case *tree.Break: - n.Prefix = prefix - case *tree.Continue: - n.Prefix = prefix - case *tree.Label: - n.Prefix = prefix - case *tree.Empty: - n.Prefix = prefix - case *tree.Unary: - n.Prefix = prefix - case *tree.FieldAccess: - n.Prefix = prefix - case *tree.MethodInvocation: - n.Prefix = prefix - case *tree.VariableDeclarations: - n.Prefix = prefix - case *tree.VariableDeclarator: - n.Prefix = prefix - case *tree.ArrayType: - n.Prefix = prefix - case *tree.ArrayAccess: - n.Prefix = prefix - case *tree.ArrayDimension: - n.Prefix = prefix - case *tree.Parentheses: - n.Prefix = prefix - case *tree.TypeCast: - n.Prefix = prefix - case *tree.ControlParentheses: - n.Prefix = prefix - case *tree.Import: - n.Prefix = prefix + if r, ok := results[0].Interface().(tree.Tree); ok { + return r } + return t } -// setMarkers sets the Markers on any AST node (G or J). -func setMarkers(v any, markers tree.Markers) { - switch n := v.(type) { - // G nodes - case *tree.CompilationUnit: - n.Markers = markers - case *tree.GoStmt: - n.Markers = markers - case *tree.Defer: - n.Markers = markers - case *tree.Send: - n.Markers = markers - case *tree.Goto: - n.Markers = markers - case *tree.Fallthrough: - n.Markers = markers - case *tree.Composite: - n.Markers = markers - case *tree.KeyValue: - n.Markers = markers - case *tree.Slice: - n.Markers = markers - case *tree.MapType: - n.Markers = markers - case *tree.PointerType: - n.Markers = markers - case *tree.StatementExpression: - n.Markers = markers - case *tree.Channel: - n.Markers = markers - case *tree.FuncType: - n.Markers = markers - case *tree.StructType: - n.Markers = markers - case *tree.InterfaceType: - n.Markers = markers - case *tree.TypeList: - n.Markers = markers - case *tree.TypeDecl: - n.Markers = markers - case *tree.MultiAssignment: - n.Markers = markers - case *tree.CommClause: - n.Markers = markers - case *tree.IndexList: - n.Markers = markers - // J nodes - case *tree.Identifier: - n.Markers = markers - case *tree.Literal: - n.Markers = markers - case *tree.Binary: - n.Markers = markers - case *tree.Block: - n.Markers = markers - case *tree.Return: - n.Markers = markers - case *tree.If: - n.Markers = markers - case *tree.Else: - n.Markers = markers - case *tree.Assignment: - n.Markers = markers - case *tree.AssignmentOperation: - n.Markers = markers - case *tree.MethodDeclaration: - n.Markers = markers - case *tree.ForLoop: - n.Markers = markers - case *tree.ForControl: - n.Markers = markers - case *tree.ForEachLoop: - n.Markers = markers - case *tree.ForEachControl: - n.Markers = markers - case *tree.Switch: - n.Markers = markers - case *tree.Case: - n.Markers = markers - case *tree.Break: - n.Markers = markers - case *tree.Continue: - n.Markers = markers - case *tree.Label: - n.Markers = markers - case *tree.Empty: - n.Markers = markers - case *tree.Unary: - n.Markers = markers - case *tree.FieldAccess: - n.Markers = markers - case *tree.MethodInvocation: - n.Markers = markers - case *tree.VariableDeclarations: - n.Markers = markers - case *tree.VariableDeclarator: - n.Markers = markers - case *tree.ArrayType: - n.Markers = markers - case *tree.ArrayAccess: - n.Markers = markers - case *tree.ArrayDimension: - n.Markers = markers - case *tree.Parentheses: - n.Markers = markers - case *tree.TypeCast: - n.Markers = markers - case *tree.ControlParentheses: - n.Markers = markers - case *tree.Import: - n.Markers = markers +// withMarkersViaReflection is the WithMarkers counterpart. +func withMarkersViaReflection(t tree.Tree, markers tree.Markers) tree.Tree { + rv := reflect.ValueOf(t) + m := rv.MethodByName("WithMarkers") + if !m.IsValid() { + return t } -} - -// extractID extracts the uuid.UUID ID from any AST node (G or J). -func extractID(v any) any { - switch t := v.(type) { - // J nodes - case *tree.Identifier: - return t.ID - case *tree.Literal: - return t.ID - case *tree.Binary: - return t.ID - case *tree.Block: - return t.ID - case *tree.MethodDeclaration: - return t.ID - case *tree.MethodInvocation: - return t.ID - case *tree.VariableDeclarations: - return t.ID - case *tree.VariableDeclarator: - return t.ID - case *tree.Assignment: - return t.ID - case *tree.Return: - return t.ID - case *tree.If: - return t.ID - case *tree.Else: - return t.ID - case *tree.ForLoop: - return t.ID - case *tree.ForEachLoop: - return t.ID - case *tree.Switch: - return t.ID - case *tree.Case: - return t.ID - case *tree.Import: - return t.ID - case *tree.Empty: - return t.ID - case *tree.Unary: - return t.ID - case *tree.FieldAccess: - return t.ID - case *tree.ArrayAccess: - return t.ID - case *tree.ArrayType: - return t.ID - case *tree.Parentheses: - return t.ID - case *tree.TypeCast: - return t.ID - // G nodes - case *tree.CompilationUnit: - return t.ID - case *tree.GoStmt: - return t.ID - case *tree.Defer: - return t.ID - case *tree.Send: - return t.ID - case *tree.Goto: - return t.ID - case *tree.Fallthrough: - return t.ID - case *tree.Composite: - return t.ID - case *tree.KeyValue: - return t.ID - case *tree.Slice: - return t.ID - case *tree.MapType: - return t.ID - case *tree.PointerType: - return t.ID - case *tree.Channel: - return t.ID - case *tree.FuncType: - return t.ID - case *tree.StructType: - return t.ID - case *tree.InterfaceType: - return t.ID - case *tree.TypeList: - return t.ID - case *tree.TypeDecl: - return t.ID - case *tree.MultiAssignment: - return t.ID - case *tree.CommClause: - return t.ID - case *tree.IndexList: - return t.ID - // Additional J nodes - case *tree.AssignmentOperation: - return t.ID - case *tree.ForControl: - return t.ID - case *tree.ForEachControl: - return t.ID - case *tree.Break: - return t.ID - case *tree.Continue: - return t.ID - case *tree.Label: - return t.ID - case *tree.ArrayDimension: - return t.ID - case *tree.ControlParentheses: - return t.ID - default: - return uuid.Nil + results := m.Call([]reflect.Value{reflect.ValueOf(markers)}) + if len(results) == 0 { + return t + } + if r, ok := results[0].Interface().(tree.Tree); ok { + return r } + return t } diff --git a/rewrite-go/pkg/rpc/padding_rpc.go b/rewrite-go/pkg/rpc/padding_rpc.go index 196bc0fffc9..2cfef814a48 100644 --- a/rewrite-go/pkg/rpc/padding_rpc.go +++ b/rewrite-go/pkg/rpc/padding_rpc.go @@ -27,7 +27,7 @@ func sendRightPadded(s Sender, rp any, q *SendQueue) { elem := rightPaddedElement(rp) if _, ok := elem.(tree.J); ok { q.GetAndSend(rp, func(v any) any { return rightPaddedElement(v) }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) } else { // Non-J elements (primitives, etc.) are sent as raw values q.GetAndSend(rp, func(v any) any { return rightPaddedElement(v) }, nil) @@ -67,7 +67,7 @@ func sendLeftPadded(s Sender, lp any, q *SendQueue) { func(v any) { sendSpace(v.(tree.Space), q) }) case tree.J: q.GetAndSend(lp, func(v any) any { return leftPaddedElement(v) }, - func(v any) { s.Visit(v, q) }) + func(v any) { s.Visit(v.(tree.Tree), q) }) default: // Primitives (strings, enums, bools) are sent as raw values with nil onChange q.GetAndSend(lp, func(v any) any { return leftPaddedElement(v) }, nil) @@ -98,7 +98,7 @@ func receiveRightPadded(r Receiver, q *ReceiveQueue, before any) any { // Element elem := q.Receive(rightPaddedElement(before), func(v any) any { if _, ok := v.(tree.J); ok { - return r.Visit(v, q) + return r.Visit(v.(tree.Tree), q) } return v }) @@ -136,7 +136,7 @@ func receiveLeftPadded(r Receiver, q *ReceiveQueue, before any) any { return receiveSpace(v.(tree.Space), q) } if _, ok := v.(tree.J); ok { - return r.Visit(v, q) + return r.Visit(v.(tree.Tree), q) } return v }) diff --git a/rewrite-go/pkg/rpc/receive_queue.go b/rewrite-go/pkg/rpc/receive_queue.go index 9edb8151484..98e2e128011 100644 --- a/rewrite-go/pkg/rpc/receive_queue.go +++ b/rewrite-go/pkg/rpc/receive_queue.go @@ -93,7 +93,11 @@ func (q *ReceiveQueue) Receive(before any, onChange func(any) any) any { if onChange != nil { after = onChange(before) } else if !isNilValue(before) && getValueType(before) != nil { - after = defaultReceiver.Visit(before, q) + if t, ok := before.(tree.Tree); ok { + after = defaultReceiver.Visit(t, q) + } else { + after = before + } } else if msg.Value != nil { after = msg.Value } else { diff --git a/rewrite-go/pkg/rpc/send_queue.go b/rewrite-go/pkg/rpc/send_queue.go index 2db9db57298..06f2afa8e7e 100644 --- a/rewrite-go/pkg/rpc/send_queue.go +++ b/rewrite-go/pkg/rpc/send_queue.go @@ -219,7 +219,9 @@ func (q *SendQueue) doChange(after, before any, onChange func(any)) { if onChange != nil && !isNilValue(after) { onChange(after) } else if onChange == nil && !isNilValue(after) && getValueType(after) != nil { - defaultSender.Visit(after, q) + if t, ok := after.(tree.Tree); ok { + defaultSender.Visit(t, q) + } } } diff --git a/rewrite-go/pkg/rpc/space_rpc.go b/rewrite-go/pkg/rpc/space_rpc.go index 7962e5678a2..c9230533722 100644 --- a/rewrite-go/pkg/rpc/space_rpc.go +++ b/rewrite-go/pkg/rpc/space_rpc.go @@ -29,6 +29,17 @@ func nilableString(s *string) any { return *s } +// emptyAsNil converts an empty Go string to wire-null, mirroring how Java +// serializes a {@code @Nullable String} field whose value is null. Use this +// for marker fields that are typed `string` on the Go side but `@Nullable` +// on the Java side (where the Java-side empty case is null, not ""). +func emptyAsNil(s string) any { + if s == "" { + return nil + } + return s +} + // sendSpace serializes a Space to the send queue. // Matches JavaSender.visitSpace field order: comments (list), whitespace. func sendSpace(s tree.Space, q *SendQueue) { @@ -173,6 +184,17 @@ func sendMarkerCodecFields(v any, q *SendQueue) { q.GetAndSend(m, func(x any) any { return x.(tree.TrailingComma).Ident.String() }, nil) q.GetAndSend(m, func(x any) any { return x.(tree.TrailingComma).Before.Whitespace }, nil) q.GetAndSend(m, func(x any) any { return x.(tree.TrailingComma).After.Whitespace }, nil) + case tree.Semicolon: + // Semicolon.rpcSend sends: id (UUID string) + q.GetAndSend(m, func(x any) any { return x.(tree.Semicolon).Ident.String() }, nil) + case tree.GoProject: + // GoProject.rpcSend sends: id (UUID string), projectName (string) + q.GetAndSend(m, func(x any) any { return x.(tree.GoProject).Ident.String() }, nil) + q.GetAndSend(m, func(x any) any { return x.(tree.GoProject).ProjectName }, nil) + case tree.GoResolutionResult: + // Field order mirrors Java's GoResolutionResult#rpcSend exactly; + // see go_resolution_result_codec.go for the per-field commentary. + sendGoResolutionResult(m, q) case tree.GenericMarker: // Send codec sub-fields matching what Java expects d := m.Data @@ -382,6 +404,25 @@ func receiveMarkersCodec(q *ReceiveQueue, before tree.Markers) tree.Markers { afterWs := receiveScalar[string](q, m.After.Whitespace) m.After = tree.Space{Whitespace: afterWs} return m + case tree.Semicolon: + idStr := receiveScalar[string](q, m.Ident.String()) + if idStr != "" { + if parsed, err := uuid.Parse(idStr); err == nil { + m.Ident = parsed + } + } + return m + case tree.GoProject: + idStr := receiveScalar[string](q, m.Ident.String()) + if idStr != "" { + if parsed, err := uuid.Parse(idStr); err == nil { + m.Ident = parsed + } + } + m.ProjectName = receiveScalar[string](q, m.ProjectName) + return m + case tree.GoResolutionResult: + return receiveGoResolutionResult(m, q) case tree.GenericMarker: // Read codec sub-fields based on the original Java marker type. // Each Java marker's rpcSend sends specific sub-fields that we must consume. diff --git a/rewrite-go/pkg/rpc/value_types.go b/rewrite-go/pkg/rpc/value_types.go index 6e5182f4332..003a14a3367 100644 --- a/rewrite-go/pkg/rpc/value_types.go +++ b/rewrite-go/pkg/rpc/value_types.go @@ -68,6 +68,7 @@ func init() { RegisterValueType(reflect.TypeOf((*tree.Continue)(nil)), "org.openrewrite.java.tree.J$Continue") RegisterValueType(reflect.TypeOf((*tree.Label)(nil)), "org.openrewrite.java.tree.J$Label") RegisterValueType(reflect.TypeOf((*tree.Empty)(nil)), "org.openrewrite.java.tree.J$Empty") + RegisterValueType(reflect.TypeOf((*tree.Annotation)(nil)), "org.openrewrite.java.tree.J$Annotation") RegisterValueType(reflect.TypeOf((*tree.Unary)(nil)), "org.openrewrite.java.tree.J$Unary") RegisterValueType(reflect.TypeOf((*tree.FieldAccess)(nil)), "org.openrewrite.java.tree.J$FieldAccess") RegisterValueType(reflect.TypeOf((*tree.MethodInvocation)(nil)), "org.openrewrite.java.tree.J$MethodInvocation") @@ -104,6 +105,16 @@ func init() { RegisterValueType(reflect.TypeOf(tree.TrailingComma{}), "org.openrewrite.golang.marker.TrailingComma") RegisterValueType(reflect.TypeOf(tree.SearchResult{}), "org.openrewrite.marker.SearchResult") RegisterValueType(reflect.TypeOf(tree.ParseExceptionResult{}), "org.openrewrite.ParseExceptionResult") + RegisterValueType(reflect.TypeOf(tree.Semicolon{}), "org.openrewrite.java.marker.Semicolon") + RegisterValueType(reflect.TypeOf(tree.GoProject{}), "org.openrewrite.golang.marker.GoProject") + RegisterValueType(reflect.TypeOf(tree.GoResolutionResult{}), "org.openrewrite.golang.marker.GoResolutionResult") + // Inner-class types of GoResolutionResult; Java emits them via + // getAndSendListAsRef so each item's wire shape carries a valueType. + RegisterValueType(reflect.TypeOf(tree.GoRequire{}), "org.openrewrite.golang.marker.GoResolutionResult$Require") + RegisterValueType(reflect.TypeOf(tree.GoReplace{}), "org.openrewrite.golang.marker.GoResolutionResult$Replace") + RegisterValueType(reflect.TypeOf(tree.GoExclude{}), "org.openrewrite.golang.marker.GoResolutionResult$Exclude") + RegisterValueType(reflect.TypeOf(tree.GoRetract{}), "org.openrewrite.golang.marker.GoResolutionResult$Retract") + RegisterValueType(reflect.TypeOf(tree.GoResolvedDependency{}), "org.openrewrite.golang.marker.GoResolutionResult$ResolvedDependency") // JavaType types RegisterValueType(reflect.TypeOf((*tree.JavaTypeClass)(nil)), "org.openrewrite.java.tree.JavaType$Class") @@ -161,6 +172,7 @@ func init() { RegisterFactory("org.openrewrite.java.tree.J$Continue", func() any { return &tree.Continue{} }) RegisterFactory("org.openrewrite.java.tree.J$Label", func() any { return &tree.Label{} }) RegisterFactory("org.openrewrite.java.tree.J$Empty", func() any { return &tree.Empty{} }) + RegisterFactory("org.openrewrite.java.tree.J$Annotation", func() any { return &tree.Annotation{} }) RegisterFactory("org.openrewrite.java.tree.J$Unary", func() any { return &tree.Unary{} }) RegisterFactory("org.openrewrite.java.tree.J$FieldAccess", func() any { return &tree.FieldAccess{} }) RegisterFactory("org.openrewrite.java.tree.J$MethodInvocation", func() any { return &tree.MethodInvocation{} }) @@ -209,6 +221,18 @@ func init() { RegisterFactory("org.openrewrite.golang.marker.TypeSwitchGuard", func() any { return tree.TypeSwitchGuard{} }) RegisterFactory("org.openrewrite.golang.marker.StructTag", func() any { return tree.StructTag{} }) RegisterFactory("org.openrewrite.golang.marker.TrailingComma", func() any { return tree.TrailingComma{} }) + // Semicolon: RpcCodec on the Java side; sends only `id`. Replaces the + // previous GenericMarker fallback for the same Java FQN. + RegisterFactory("org.openrewrite.java.marker.Semicolon", func() any { return tree.Semicolon{} }) + // GoProject + GoResolutionResult are RpcCodec on the Java side; codec + // dispatch lives in space_rpc.go's sendMarkerCodecFields / receiveMarkersCodec. + RegisterFactory("org.openrewrite.golang.marker.GoProject", func() any { return tree.GoProject{} }) + RegisterFactory("org.openrewrite.golang.marker.GoResolutionResult", func() any { return tree.GoResolutionResult{} }) + RegisterFactory("org.openrewrite.golang.marker.GoResolutionResult$Require", func() any { return tree.GoRequire{} }) + RegisterFactory("org.openrewrite.golang.marker.GoResolutionResult$Replace", func() any { return tree.GoReplace{} }) + RegisterFactory("org.openrewrite.golang.marker.GoResolutionResult$Exclude", func() any { return tree.GoExclude{} }) + RegisterFactory("org.openrewrite.golang.marker.GoResolutionResult$Retract", func() any { return tree.GoRetract{} }) + RegisterFactory("org.openrewrite.golang.marker.GoResolutionResult$ResolvedDependency", func() any { return tree.GoResolvedDependency{} }) RegisterFactory("org.openrewrite.java.tree.Space", func() any { return tree.Space{} }) RegisterFactory("org.openrewrite.marker.Markers", func() any { return tree.Markers{} }) diff --git a/rewrite-go/pkg/template/PARITY-AUDIT.md b/rewrite-go/pkg/template/PARITY-AUDIT.md new file mode 100644 index 00000000000..5ed74bd021d --- /dev/null +++ b/rewrite-go/pkg/template/PARITY-AUDIT.md @@ -0,0 +1,106 @@ +# GoTemplate ↔ JavaTemplate parity audit + +Item (10) of the rewrite-go parity plan asked for ergonomic parity between +`pkg/template/GoTemplate` and `org.openrewrite.java.JavaTemplate`. This +document lists every public method on `JavaTemplate` (Java) and maps it to +the equivalent surface on `GoTemplate` (Go), noting what was already +present, what was added in this PR, and what is intentionally deferred. + +## Audit summary + +| Surface | JavaTemplate | GoTemplate | Status | +|---|---|---|---| +| Builder | `JavaTemplate.builder(code)` | `template.ExpressionTemplate(code)` / `StatementTemplate(code)` / `TopLevelTemplate(code)` | ✓ shipped (kind-explicit factories preferred over a single overloaded builder) | +| Build the template | `.build()` | `.Build()` | ✓ shipped | +| Required imports | `.imports(String...)` | `.Imports(...string)` | ✓ shipped | +| Static imports | `.staticImports(String...)` | n/a | not applicable — Go has no static-import concept | +| Coordinate-based substitution | `.apply(JavaCoordinates, params...)` | `.Apply(cursor, *MatchResult)` | ✓ shipped via match captures (see deferred note below) | +| Pattern → template rewrite | `JavaIsoVisitor` + `JavaTemplate.apply` per visit | `template.Rewrite(before, after)` returns a `RewriteVisitor` | ✓ shipped (single-call ergonomic that matches and replaces in one step — Go-side delta over Java) | +| Context-sensitive parsing | `.contextSensitive()` | not yet | deferred (recipes in the wild rarely flip this on for refactoring; revisit if a real recipe asks) | +| Named placeholders | `#{name}` substitution by name + type constraint | positional `#{X}` capture-by-name through `*Capture` | ✓ named via `*Capture` already; type constraints are deferred (see below) | +| Type-checked named placeholders | `#{name:any(java.util.List)}` | not yet | deferred — out-of-scope per the eng review's v1 scope cut | +| Cursor-aware insertion | parameter to `.apply(cursor, ...)` | parameter to `.Apply(cursor, ...)` | ✓ shipped — cursor is threaded but unused in the v1 substitution engine; placeholder for future block/scope-aware substitution | + +## Already present before this PR (no delta required) + +The Go-side template engine is ~740 LOC and pre-dates the parity work. +Surface that was already at parity: + +- `TemplateBuilder` with a fluent API (`Captures`, `Imports`, `Build`). +- Three template kinds (`ExpressionTemplate`, `StatementTemplate`, + `TopLevelTemplate`) — this is more explicit than Java's overloaded + `JavaTemplate.builder` (which infers the kind from the substitution + coordinate). Recipe authors don't need to know coordinate semantics. +- `Apply(cursor, *MatchResult)` returns the substituted subtree with + capture values spliced in. +- `Rewrite(before, after)` packages match-and-replace into a single + `RewriteVisitor` — convenient for 1:1 rewrites. +- `getLeadingPrefix` / `setLeadingPrefix` preserve formatting on the + outer node when a template replaces an existing subtree (e.g. the + prefix on a `MethodInvocation.Select.Element` survives the swap). +- Scaffold-based parser (`pkg/template/scaffold.go`) compiles a template + string into an AST that's cached per `GoTemplate` instance. + +## Deferred (intentional out-of-scope items) + +These are explicitly out-of-scope per the eng review: + +1. **Type-checked named placeholders** (`#{name:any(...)}`). v1 keeps + capture-typed placeholders. Rationale: Go's lighter type system makes + the constraint syntax less load-bearing for refactor recipes; revisit + when a real recipe asks for it. +2. **`contextSensitive()` parse mode.** JavaTemplate flips this on when + the template references symbols only resolvable from the surrounding + cursor (e.g. inner-class names). The Go scaffold parser is already + "context-light" by default (it doesn't attempt to resolve the + template's own references against the call site's environment), so + the explicit toggle adds little until Go-specific use cases surface. +3. **Static imports.** Java's `staticImports` adds `import static …` + declarations. Go has no static-import concept; the Go template + compiler ignores the surface entirely. +4. **Coordinate API surface (`JavaCoordinates`).** Java's + `apply(coordinates, params)` lets recipes splice templates *before* / + *after* / *replace* a target node. The Go equivalent is the + pattern-match approach: write a `GoPattern` for the target, write a + `GoTemplate` for the replacement, and use `Rewrite(before, after)`. + Adding a coordinate API on top is feasible but would duplicate the + pattern surface; we'll add it only if a recipe actually needs splice + semantics that Pattern→Template doesn't cover. + +## What recipe authors should know + +For most refactors, the Go-side template surface is what you want: + +```go +import "github.com/openrewrite/rewrite/rewrite-go/pkg/template" + +before := template.ExpressionPattern(`errors.Is(#{X}, #{Y})`).Build() +after := template.ExpressionTemplate(`xerrors.Is(#{X}, #{Y})`).Imports("xerrors").Build() +visitor := template.Rewrite(before, after) +``` + +For inserting a fresh statement (no before-match): + +```go +tmpl := template.StatementTemplate(`fmt.Println("hi")`).Imports("fmt").Build() +result := tmpl.Apply(cursor, nil) +``` + +The Java → Go porting cheat-sheet: + +| Java | Go | +|----------------------------------------|----------------------------------------------------------| +| `JavaTemplate.builder("…").build()` | `template.StatementTemplate("…").Build()` | +| `.imports("foo")` | `.Imports("foo")` | +| `.apply(getCursor(), JavaCoordinates.replace(target), arg1, arg2)` | match `target` with a `GoPattern`, then `template.Rewrite(before, after)` — the `RewriteVisitor` does the splice | +| `#{any()}` as a wildcard placeholder | `template.ExpressionPattern("#{X}").Build()` with an unconstrained `Capture` | +| `#{name:any(java.util.List)}` | not yet — file an issue if you hit this | + +## Conclusion + +GoTemplate's surface is at functional parity with `JavaTemplate` for the +common refactor patterns. The surface differences are mostly stylistic +(explicit kind factories vs. a single overloaded builder) or +intentionally narrowed (no `staticImports`, no coordinate API). The +deferred items are all opt-in features; the default Go template +experience covers what recipes-go authors need today. diff --git a/rewrite-go/pkg/template/template_recipe.go b/rewrite-go/pkg/template/template_recipe.go index cfae189c1fd..a116f448141 100644 --- a/rewrite-go/pkg/template/template_recipe.go +++ b/rewrite-go/pkg/template/template_recipe.go @@ -276,6 +276,11 @@ func (r *builtTemplateRecipe) EstimatedEffortPerOccurrence() time.Duration { ret func (r *builtTemplateRecipe) Editor() recipe.TreeVisitor { return r.editor } func (r *builtTemplateRecipe) RecipeList() []recipe.Recipe { return nil } func (r *builtTemplateRecipe) Options() []recipe.OptionDescriptor { return nil } +func (r *builtTemplateRecipe) Preconditions() []recipe.Recipe { return nil } +func (r *builtTemplateRecipe) DataTables() []recipe.DataTableDescriptor { return nil } +func (r *builtTemplateRecipe) Maintainers() []recipe.Maintainer { return nil } +func (r *builtTemplateRecipe) Contributors() []recipe.Contributor { return nil } +func (r *builtTemplateRecipe) Examples() []recipe.Example { return nil } // --- Embeddable TemplateRecipe struct --- diff --git a/rewrite-go/pkg/test/expect_type.go b/rewrite-go/pkg/test/expect_type.go new file mode 100644 index 00000000000..075b7c76c2f --- /dev/null +++ b/rewrite-go/pkg/test/expect_type.go @@ -0,0 +1,137 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// ExpectType walks the tree rooted at root and asserts that the first +// identifier whose Name == name carries a fully-qualified type whose FQN +// matches expectedFQN. Use this for class/struct/parameterized types; for +// primitives use ExpectPrimitiveType. +// +// Fails the test if no matching identifier is found, if its Type is nil, +// or if the type does not implement tree.FullyQualified. +func ExpectType(t *testing.T, root tree.Tree, name string, expectedFQN string) { + t.Helper() + c := visitor.Init(&identifierTypeCollector{name: name}) + c.Visit(root, nil) + if !c.found { + t.Fatalf("ExpectType(%q): no identifier with that name in tree", name) + } + if c.typ == nil { + t.Fatalf("ExpectType(%q): identifier has nil Type", name) + } + fq, ok := c.typ.(tree.FullyQualified) + if !ok { + t.Fatalf("ExpectType(%q): identifier Type is %T, want FullyQualified", name, c.typ) + } + if got := fq.GetFullyQualifiedName(); got != expectedFQN { + t.Errorf("ExpectType(%q): FQN = %q, want %q", name, got, expectedFQN) + } +} + +// ExpectPrimitiveType asserts that the first identifier named `name` has a +// JavaTypePrimitive whose Keyword matches expectedKeyword (e.g. "int", +// "String", "bool"). Mirrors ExpectType for primitive type attribution. +func ExpectPrimitiveType(t *testing.T, root tree.Tree, name string, expectedKeyword string) { + t.Helper() + c := visitor.Init(&identifierTypeCollector{name: name}) + c.Visit(root, nil) + if !c.found { + t.Fatalf("ExpectPrimitiveType(%q): no identifier with that name in tree", name) + } + if c.typ == nil { + t.Fatalf("ExpectPrimitiveType(%q): identifier has nil Type", name) + } + prim, ok := c.typ.(*tree.JavaTypePrimitive) + if !ok { + t.Fatalf("ExpectPrimitiveType(%q): identifier Type is %T, want *JavaTypePrimitive", name, c.typ) + } + if prim.Keyword != expectedKeyword { + t.Errorf("ExpectPrimitiveType(%q): keyword = %q, want %q", name, prim.Keyword, expectedKeyword) + } +} + +// ExpectMethodType walks the tree rooted at root and asserts that the +// first MethodInvocation or MethodDeclaration whose name matches `name` +// carries a non-nil JavaTypeMethod whose DeclaringType.FullyQualifiedName +// equals expectedDeclaringFQN. +// +// For invocations across packages, expectedDeclaringFQN is the import path +// of the owning package (e.g. "fmt" for fmt.Println). For methods declared +// in the file under test, it is the package's full path +// (e.g. "main.Point" for a method on Point in package main). +func ExpectMethodType(t *testing.T, root tree.Tree, name string, expectedDeclaringFQN string) { + t.Helper() + c := visitor.Init(&methodTypeCollector{name: name}) + c.Visit(root, nil) + if !c.found { + t.Fatalf("ExpectMethodType(%q): no method with that name in tree", name) + } + if c.methodType == nil { + t.Fatalf("ExpectMethodType(%q): method has nil MethodType", name) + } + if c.methodType.DeclaringType == nil { + t.Fatalf("ExpectMethodType(%q): method has nil DeclaringType", name) + } + if got := c.methodType.DeclaringType.FullyQualifiedName; got != expectedDeclaringFQN { + t.Errorf("ExpectMethodType(%q): declaring FQN = %q, want %q", name, got, expectedDeclaringFQN) + } +} + +type identifierTypeCollector struct { + visitor.GoVisitor + name string + found bool + typ tree.JavaType +} + +func (v *identifierTypeCollector) VisitIdentifier(ident *tree.Identifier, p any) tree.J { + if !v.found && ident.Name == v.name { + v.found = true + v.typ = ident.Type + } + return ident +} + +type methodTypeCollector struct { + visitor.GoVisitor + name string + found bool + methodType *tree.JavaTypeMethod +} + +func (v *methodTypeCollector) VisitMethodInvocation(mi *tree.MethodInvocation, p any) tree.J { + if !v.found && mi.Name != nil && mi.Name.Name == v.name { + v.found = true + v.methodType = mi.MethodType + } + return v.GoVisitor.VisitMethodInvocation(mi, p) +} + +func (v *methodTypeCollector) VisitMethodDeclaration(md *tree.MethodDeclaration, p any) tree.J { + if !v.found && md.Name != nil && md.Name.Name == v.name { + v.found = true + v.methodType = md.MethodType + } + return v.GoVisitor.VisitMethodDeclaration(md, p) +} diff --git a/rewrite-go/pkg/test/expect_type_test.go b/rewrite-go/pkg/test/expect_type_test.go new file mode 100644 index 00000000000..8f40b7d4d1c --- /dev/null +++ b/rewrite-go/pkg/test/expect_type_test.go @@ -0,0 +1,91 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" +) + +func TestExpectType_ClassType(t *testing.T) { + p := parser.NewGoParser() + cu, err := p.Parse("test.go", `package main + +type Point struct { + X int + Y int +} + +func main() { + p := Point{X: 1, Y: 2} + _ = p +} +`) + if err != nil { + t.Fatal(err) + } + ExpectType(t, cu, "p", "main.Point") +} + +func TestExpectPrimitiveType_LocalVar(t *testing.T) { + p := parser.NewGoParser() + cu, err := p.Parse("test.go", `package main + +func main() { + x := 42 + y := "hello" + _ = x + _ = y +} +`) + if err != nil { + t.Fatal(err) + } + ExpectPrimitiveType(t, cu, "x", "int") + ExpectPrimitiveType(t, cu, "y", "String") +} + +func TestExpectMethodType_StdlibInvocation(t *testing.T) { + p := parser.NewGoParser() + cu, err := p.Parse("test.go", `package main + +import "fmt" + +func main() { + fmt.Println("hello") +} +`) + if err != nil { + t.Fatal(err) + } + ExpectMethodType(t, cu, "Println", "fmt") +} + +func TestExpectMethodType_LocalDeclaration(t *testing.T) { + p := parser.NewGoParser() + cu, err := p.Parse("test.go", `package main + +func add(a int, b int) int { + return a + b +} +`) + if err != nil { + t.Fatal(err) + } + ExpectMethodType(t, cu, "add", "main") +} diff --git a/rewrite-go/pkg/test/spec.go b/rewrite-go/pkg/test/spec.go index e9d1289bdf3..c32c3154d7a 100644 --- a/rewrite-go/pkg/test/spec.go +++ b/rewrite-go/pkg/test/spec.go @@ -17,14 +17,15 @@ package test import ( - "fmt" "math" + "path" "strings" "testing" "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" "github.com/openrewrite/rewrite/rewrite-go/pkg/printer" "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe/golang" "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" ) @@ -34,7 +35,277 @@ type SourceSpec struct { Before string After *string // nil means no change expected (parse-print idempotence only) Path string - AfterRecipe func(t *testing.T, cu *tree.CompilationUnit) // optional post-parse assertion callback + Markers []tree.Marker // markers attached to the parsed source after parse + AfterRecipe func(t *testing.T, cu *tree.CompilationUnit) // optional post-parse assertion callback (.go files only) +} + +// Sources is the unit RewriteRun consumes. Single SourceSpecs and project +// wrappers (GoProject, etc.) both satisfy it so the same harness call can +// mix flat .go files with multi-file projects. +type Sources interface { + Expand() []SourceSpec +} + +// Expand makes a SourceSpec usable wherever a Sources is expected. +func (s SourceSpec) Expand() []SourceSpec { return []SourceSpec{s} } + +// WithPath returns a copy of s with Path set. Use this in multi-package +// projects so the test harness can locate each file's package directory. +func (s SourceSpec) WithPath(p string) SourceSpec { + s.Path = p + return s +} + +// WithAfterRecipe returns a copy of s with the post-parse callback set. +// The callback fires after parsing (with type attribution wired) and +// before recipe application; recipes can read s.Markers off the cu and +// assert on resolved types. +func (s SourceSpec) WithAfterRecipe(fn func(t *testing.T, cu *tree.CompilationUnit)) SourceSpec { + s.AfterRecipe = fn + return s +} + +// project wraps a set of sources and tags each with a GoProject marker on +// expansion. Mirrors Assertions.goProject(name, ...) on the Java side. +type project struct { + name string + inner []Sources +} + +func (p project) Expand() []SourceSpec { + marker := tree.NewGoProject(p.name) + var out []SourceSpec + for _, s := range p.inner { + for _, ss := range s.Expand() { + ss.Markers = append(append([]tree.Marker{}, ss.Markers...), marker) + out = append(out, ss) + } + } + mergeGoSumIntoGoMod(out) + propagateModuleResolution(out) + return out +} + +// propagateModuleResolution copies the go.mod's GoResolutionResult +// marker onto every sibling .go SourceSpec in the project. Recipes +// that need module context per file (e.g. RenamePackage's +// fileBelongsTo check) can then read it directly off the +// CompilationUnit's Markers without re-walking the project. Java does +// the equivalent at parse time by attaching the parsed-go.mod marker +// to each CU. +func propagateModuleResolution(specs []SourceSpec) { + var mrr *tree.GoResolutionResult + for i := range specs { + if specs[i].Path == "go.mod" { + if found := FindGoResolutionResult(specs[i]); found != nil && found.ModulePath != "" { + mrr = found + break + } + } + } + if mrr == nil { + return + } + for i := range specs { + if !strings.HasSuffix(specs[i].Path, ".go") { + continue + } + // Skip if already present (e.g. caller pre-attached one). + alreadyHas := false + for _, m := range specs[i].Markers { + if _, ok := m.(tree.GoResolutionResult); ok { + alreadyHas = true + break + } + } + if alreadyHas { + continue + } + specs[i].Markers = append(append([]tree.Marker{}, specs[i].Markers...), *mrr) + } +} + +// mergeGoSumIntoGoMod finds a sibling go.sum spec inside the same expanded +// project and merges its parsed ResolvedDependencies into the sibling +// go.mod's GoResolutionResult marker. Mirrors the Java side, where +// GoModParser#parseSumSibling reads go.sum off disk during parse — but Go +// tests don't write to disk, so we do the merge in memory. +func mergeGoSumIntoGoMod(specs []SourceSpec) { + var sumIdx, modIdx, markerIdx = -1, -1, -1 + for i, s := range specs { + switch s.Path { + case "go.sum": + sumIdx = i + case "go.mod": + modIdx = i + for j, m := range s.Markers { + if _, ok := m.(tree.GoResolutionResult); ok { + markerIdx = j + } + } + } + } + if sumIdx < 0 || modIdx < 0 || markerIdx < 0 { + return + } + resolved := parser.ParseGoSum(specs[sumIdx].Before) + if len(resolved) == 0 { + return + } + mrr := specs[modIdx].Markers[markerIdx].(tree.GoResolutionResult) + mrr.ResolvedDependencies = resolved + specs[modIdx].Markers[markerIdx] = mrr +} + +// GoProject groups a go.mod and one or more .go SourceSpecs as siblings of +// a single project. Every child receives a tree.GoProject marker. Mirrors +// the Java-side Assertions.goProject(name, sources...). +// +// Example: +// +// spec.RewriteRun(t, +// test.GoProject("foo", +// test.GoMod("module example.com/foo\ngo 1.22\n"), +// test.Golang("package main\nfunc main(){}\n"), +// ), +// ) +func GoProject(name string, sources ...Sources) Sources { + return project{name: name, inner: sources} +} + +// GoMod creates a SourceSpec for go.mod content. The content is dedented +// the same way Golang(...) is and parsed (via parser.ParseGoMod) at +// construction time so the resulting tree.GoResolutionResult marker is +// already attached to spec.Markers — recipes / tests can read module +// path, requires, replaces, etc. without re-parsing. +// +// If the content fails to parse the spec is returned with no +// GoResolutionResult marker; the test still round-trips the content +// verbatim, mirroring the Java goMod test helper's behavior on bad input. +// +// When a sibling GoSum(...) exists in the same project, its parsed +// ResolvedDependencies are merged into the GoResolutionResult marker at +// project-expansion time (see project.Expand). +func GoMod(before string, after ...string) SourceSpec { + content := TrimIndent(before) + spec := SourceSpec{ + Before: content, + Path: "go.mod", + } + if mrr, err := parser.ParseGoMod("go.mod", content); err == nil && mrr != nil { + spec.Markers = append(spec.Markers, *mrr) + } + if len(after) > 0 { + a := TrimIndent(after[0]) + spec.After = &a + } + return spec +} + +// GoSum creates a SourceSpec for go.sum content. The harness round-trips +// the content verbatim (no recipe processing today) and, when a sibling +// GoMod(...) is present in the same GoProject, merges the parsed +// ResolvedDependencies into the GoMod's GoResolutionResult marker. +func GoSum(before string, after ...string) SourceSpec { + content := TrimIndent(before) + spec := SourceSpec{ + Before: content, + Path: "go.sum", + } + if len(after) > 0 { + a := TrimIndent(after[0]) + spec.After = &a + } + return spec +} + +// FindGoResolutionResult walks a SourceSpec's markers for the parsed +// go.mod marker. Returns nil if not present (e.g. on a Golang(...) source). +func FindGoResolutionResult(spec SourceSpec) *tree.GoResolutionResult { + for _, m := range spec.Markers { + if mrr, ok := m.(tree.GoResolutionResult); ok { + return &mrr + } + } + return nil +} + +// parsePackageGroups groups .go SourceSpecs by package directory and +// parses each group together via parser.ParsePackage so files in the same +// package share a types.Info — file A can see file B's symbols. Returns +// a map from the spec's index in `flat` to its CompilationUnit so two +// specs sharing the same Path don't clobber each other in the result. +func parsePackageGroups(t *testing.T, p *parser.GoParser, flat []SourceSpec) map[int]*tree.CompilationUnit { + t.Helper() + type indexed struct { + idx int + input parser.FileInput + } + byDir := map[string][]indexed{} + for i, s := range flat { + if !strings.HasSuffix(s.Path, ".go") { + continue + } + dir := path.Dir(s.Path) + byDir[dir] = append(byDir[dir], indexed{idx: i, input: parser.FileInput{Path: s.Path, Content: s.Before}}) + } + + out := map[int]*tree.CompilationUnit{} + for dir, group := range byDir { + // Pre-filter against BuildContext so post-parse `cus` aligns + // with the included subset of `group`. + included := make([]indexed, 0, len(group)) + files := make([]parser.FileInput, 0, len(group)) + for _, g := range group { + if !parser.MatchBuildContext(p.BuildContext, path.Base(g.input.Path), g.input.Content) { + continue + } + included = append(included, g) + files = append(files, g.input) + } + if len(files) == 0 { + continue + } + cus, err := p.ParsePackage(files) + if err != nil { + t.Fatalf("parse error in package %s: %v", dir, err) + } + for i, cu := range cus { + out[included[i].idx] = cu + } + } + return out +} + +// buildProjectImporter scans a flattened source list for a go.mod with +// a GoResolutionResult marker. If found, registers every sibling .go +// file AND every go.mod-declared require with a ProjectImporter so: +// - intra-project imports type-check against real sources; +// - imports of declared third-party modules resolve to stub packages. +// Returns nil when there's no module context — the caller should fall +// back to importer.Default(). +func buildProjectImporter(flat []SourceSpec) *parser.ProjectImporter { + var mrr *tree.GoResolutionResult + for _, s := range flat { + if found := FindGoResolutionResult(s); found != nil && found.ModulePath != "" { + mrr = found + break + } + } + if mrr == nil { + return nil + } + pi := parser.NewProjectImporter(mrr.ModulePath, nil) + for _, req := range mrr.Requires { + pi.AddRequire(req.ModulePath) + } + for _, s := range flat { + if !strings.HasSuffix(s.Path, ".go") { + continue + } + pi.AddSource(s.Path, s.Before) + } + return pi } // Golang creates a SourceSpec for Go source code. @@ -114,50 +385,12 @@ func TrimIndent(s string) string { return strings.Join(lines, "\n") + "\n" } -// spaceValidator is a visitor that checks every Space it encounters -// for non-whitespace content that would indicate a parser bug. -type spaceValidator struct { - visitor.GoVisitor - errs []string -} - -func (v *spaceValidator) VisitSpace(space tree.Space, p any) tree.Space { - if space.Whitespace != "" && !isWhitespaceOnly(space.Whitespace) { - v.errs = append(v.errs, fmt.Sprintf("Space.Whitespace contains non-whitespace: %q", truncate(space.Whitespace, 80))) - } - for i, c := range space.Comments { - if c.Suffix != "" && !isWhitespaceOnly(c.Suffix) { - v.errs = append(v.errs, fmt.Sprintf("Comment[%d].Suffix contains non-whitespace: %q", i, truncate(c.Suffix, 80))) - } - if c.Text != "" && !strings.HasPrefix(c.Text, "//") && !strings.HasPrefix(c.Text, "/*") { - v.errs = append(v.errs, fmt.Sprintf("Comment[%d].Text is not a comment: %q", i, truncate(c.Text, 80))) - } - } - return space -} - // ValidateSpaces walks the tree and returns errors for any Space that // contains non-whitespace content (which would indicate a parser bug). +// Thin shim over golang.WhitespaceValidationService — recipes call the +// service directly; tests use this name for source-level continuity. func ValidateSpaces(root tree.Tree) []string { - v := visitor.Init(&spaceValidator{}) - v.Visit(root, nil) - return v.errs -} - -func isWhitespaceOnly(s string) bool { - for _, c := range s { - if c != ' ' && c != '\t' && c != '\n' && c != '\r' { - return false - } - } - return true -} - -func truncate(s string, n int) string { - if len(s) <= n { - return s - } - return s[:n] + "..." + return (&golang.WhitespaceValidationService{}).Validate(root) } // JavaRecipeConfig holds config for a Java-delegated recipe test. @@ -203,17 +436,62 @@ func (spec *RecipeSpec) WithJavaRpcClient(client *JavaRpcClient) *RecipeSpec { return spec } -// RewriteRun parses the source specs, checks parse-print idempotence, -// and (if configured) applies a recipe and checks the result. -func (spec *RecipeSpec) RewriteRun(t *testing.T, sources ...SourceSpec) { +// RewriteRun parses each source, checks parse-print idempotence, attaches +// any markers contributed by project wrappers (GoProject), and (if +// configured) applies a recipe and checks the result. Accepts both bare +// SourceSpec values and project wrappers like GoProject — they both +// implement Sources. +func (spec *RecipeSpec) RewriteRun(t *testing.T, sources ...Sources) { t.Helper() + // Flatten any project wrappers into a flat list of SourceSpecs, with + // project markers already attached. + var flat []SourceSpec + for _, s := range sources { + flat = append(flat, s.Expand()...) + } + p := parser.NewGoParser() - for _, src := range sources { - cu, err := p.Parse(src.Path, src.Before) - if err != nil { - t.Fatalf("parse error: %v", err) + // Only treat sources as a multi-file package when there's an explicit + // project context (a goMod sibling). Without it, bare Golang(...) + // specs may share the default Path="test.go" without intending to be + // siblings — preserve per-file parsing in that case. + var parsedByIdx map[int]*tree.CompilationUnit + if pi := buildProjectImporter(flat); pi != nil { + p.Importer = pi + parsedByIdx = parsePackageGroups(t, p, flat) + } + + for i, src := range flat { + // Non-Go sources (e.g. go.mod) are not yet parsed on the Go side. + // We round-trip them verbatim so project layouts compose, but skip + // tree-walks and recipe application. + if !strings.HasSuffix(src.Path, ".go") { + if src.After != nil && *src.After != src.Before { + t.Errorf("non-Go source %q: harness cannot apply recipes to it yet", src.Path) + } + continue + } + + var cu *tree.CompilationUnit + if parsedByIdx != nil { + cu = parsedByIdx[i] + } + if cu == nil { + // No project context (or project parse missed this source) — + // parse this file in isolation so two bare specs sharing a + // default Path don't clobber each other. + parsed, err := p.Parse(src.Path, src.Before) + if err != nil { + t.Fatalf("parse error: %v", err) + } + cu = parsed + } + + // Attach any project markers contributed by GoProject(...) wrappers. + for _, m := range src.Markers { + cu.Markers = tree.AddMarker(cu.Markers, m) } // Validate that no Space contains non-whitespace syntax @@ -266,12 +544,14 @@ func (spec *RecipeSpec) RewriteRun(t *testing.T, sources ...SourceSpec) { func runRecipe(r recipe.Recipe, t tree.Tree) tree.Tree { ctx := recipe.NewExecutionContext() - // Apply this recipe's own editor + // Apply this recipe's own editor, then drain any queued after-visits + // (e.g. ImportService.AddImportVisitor inserted via DoAfterVisit). if editor := r.Editor(); editor != nil { result := editor.Visit(t, ctx) if result != nil { t = result } + t = visitor.DrainAfterVisits(editor, t, ctx) } // Apply sub-recipes diff --git a/rewrite-go/pkg/tree/go.go b/rewrite-go/pkg/tree/go.go index 82c679ee15e..e9c2738fd6a 100644 --- a/rewrite-go/pkg/tree/go.go +++ b/rewrite-go/pkg/tree/go.go @@ -456,13 +456,14 @@ func (n *TypeList) WithMarkers(markers Markers) *TypeList { // Covers: `type Foo struct{...}`, `type Foo interface{...}`, `type Foo int`, `type Foo = Bar`. // For grouped declarations `type ( ... )`, Specs is non-nil and Name/Definition are unused. type TypeDecl struct { - ID uuid.UUID - Prefix Space - Markers Markers - Name *Identifier - Assign *LeftPadded[Space] // non-nil for `type Foo = Bar`; Before = space before `=` - Definition Expression // the type expression (nil for grouped) - Specs *Container[Statement] // non-nil for grouped `type ( ... )`; Before = space before `(` + ID uuid.UUID + Prefix Space + Markers Markers + LeadingAnnotations []*Annotation // `//go:generate ...` etc. + Name *Identifier + Assign *LeftPadded[Space] // non-nil for `type Foo = Bar`; Before = space before `=` + Definition Expression // the type expression (nil for grouped) + Specs *Container[Statement] // non-nil for grouped `type ( ... )`; Before = space before `(` } func (*TypeDecl) isTree() {} @@ -481,6 +482,12 @@ func (n *TypeDecl) WithMarkers(markers Markers) *TypeDecl { return &c } +func (n *TypeDecl) WithLeadingAnnotations(anns []*Annotation) *TypeDecl { + c := *n + c.LeadingAnnotations = anns + return &c +} + // ShortVarDecl is a marker on Assignment indicating `:=` instead of `=`. type ShortVarDecl struct { Ident uuid.UUID diff --git a/rewrite-go/pkg/tree/go_resolution_result.go b/rewrite-go/pkg/tree/go_resolution_result.go new file mode 100644 index 00000000000..5d47cedda92 --- /dev/null +++ b/rewrite-go/pkg/tree/go_resolution_result.go @@ -0,0 +1,108 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tree + +import "github.com/google/uuid" + +// GoResolutionResult mirrors org.openrewrite.golang.marker.GoResolutionResult +// on the Java side: the metadata parsed from a Go module's go.mod file. +// Attached as a Marker to a source representing a go.mod (in tests, to the +// test SourceSpec; at runtime, to whatever tree the Go parser produces for +// go.mod content). +type GoResolutionResult struct { + Ident uuid.UUID + ModulePath string + GoVersion string // empty if no `go` directive + Toolchain string // empty if no `toolchain` directive + Path string // path to the go.mod file + Requires []GoRequire + Replaces []GoReplace + Excludes []GoExclude + Retracts []GoRetract + ResolvedDependencies []GoResolvedDependency +} + +func (m GoResolutionResult) ID() uuid.UUID { return m.Ident } + +// FindRequire returns the require entry for a module, or nil. +func (m GoResolutionResult) FindRequire(modulePath string) *GoRequire { + for i := range m.Requires { + if m.Requires[i].ModulePath == modulePath { + return &m.Requires[i] + } + } + return nil +} + +// FindResolved returns the resolved dependency for a module, or nil. +func (m GoResolutionResult) FindResolved(modulePath string) *GoResolvedDependency { + for i := range m.ResolvedDependencies { + if m.ResolvedDependencies[i].ModulePath == modulePath { + return &m.ResolvedDependencies[i] + } + } + return nil +} + +// GoRequire is one entry in the go.mod `require` list. +type GoRequire struct { + ModulePath string + Version string + Indirect bool // true if marked `// indirect` +} + +// GoReplace is one entry in the go.mod `replace` list. +// OldVersion is empty if the replace targets all versions of OldPath. +// NewVersion is empty if NewPath is a local filesystem path. +type GoReplace struct { + OldPath string + OldVersion string + NewPath string + NewVersion string +} + +// GoExclude is one entry in the go.mod `exclude` list. +type GoExclude struct { + ModulePath string + Version string +} + +// GoRetract is one entry in the go.mod `retract` list. +// VersionRange is the raw expression as written, e.g. "v1.0.0" or "[v1.0.0, v1.1.0]". +type GoRetract struct { + VersionRange string + Rationale string // empty if no `// ...` comment +} + +// GoResolvedDependency is one entry from go.sum. +type GoResolvedDependency struct { + ModulePath string + Version string + ModuleHash string // h1:... — empty if only the go.mod hash is recorded + GoModHash string +} + +// NewGoResolutionResult creates a GoResolutionResult marker with a fresh UUID. +func NewGoResolutionResult(modulePath, goVersion, toolchain, path string) GoResolutionResult { + return GoResolutionResult{ + Ident: uuid.New(), + ModulePath: modulePath, + GoVersion: goVersion, + Toolchain: toolchain, + Path: path, + } +} diff --git a/rewrite-go/pkg/tree/j.go b/rewrite-go/pkg/tree/j.go index 9863b3603ca..bcf18b5976a 100644 --- a/rewrite-go/pkg/tree/j.go +++ b/rewrite-go/pkg/tree/j.go @@ -544,15 +544,16 @@ func (n *AssignmentOperation) WithVariable(variable Expression) *AssignmentOpera // MethodDeclaration represents a function or method declaration. type MethodDeclaration struct { - ID uuid.UUID - Prefix Space - Markers Markers - Receiver *Container[Statement] // nil for free functions; `(r *Type)` receiver - Name *Identifier - Parameters Container[Statement] // parameter list in parentheses - ReturnType Expression // nil for void functions; single type or *TypeList for multiple - Body *Block // nil for forward declarations - MethodType *JavaTypeMethod // the method type signature (nullable) + ID uuid.UUID + Prefix Space + Markers Markers + LeadingAnnotations []*Annotation // `//go:noinline` / `//go:nosplit` etc. on funcs + Receiver *Container[Statement] // nil for free functions; `(r *Type)` receiver + Name *Identifier + Parameters Container[Statement] // parameter list in parentheses + ReturnType Expression // nil for void functions; single type or *TypeList for multiple + Body *Block // nil for forward declarations + MethodType *JavaTypeMethod // the method type signature (nullable) } func (*MethodDeclaration) isTree() {} @@ -572,6 +573,12 @@ func (n *MethodDeclaration) WithMarkers(markers Markers) *MethodDeclaration { return &c } +func (n *MethodDeclaration) WithLeadingAnnotations(anns []*Annotation) *MethodDeclaration { + c := *n + c.LeadingAnnotations = anns + return &c +} + func (n *MethodDeclaration) WithName(name *Identifier) *MethodDeclaration { c := *n c.Name = name @@ -867,6 +874,68 @@ func (n *Label) WithMarkers(markers Markers) *Label { return &c } +// Annotation represents annotation metadata attached to a declaration. +// Mirrors org.openrewrite.java.tree.J.Annotation. +// +// Java has first-class `@Annotation(args)` syntax. Go has no `@`, but +// has two analogous concepts that this type models uniformly: +// +// 1. Struct field tags. Each `key:"value"` pair in a struct field +// tag becomes one Annotation on the field's VariableDeclarations: +// AnnotationType = Identifier{Name: key}, +// Arguments = [Literal{Value: value}]. +// The printer renders the run of struct-tag annotations on a +// VariableDeclarations whose parent is a StructType as a single +// backtick-wrapped tag. +// +// 2. Source directives like `//go:noinline`, `//go:generate`, +// `//lint:ignore`. Each directive becomes one Annotation on the +// enclosing MethodDeclaration / TypeDecl / VariableDeclarations: +// AnnotationType = Identifier{Name: "go:noinline"}, +// Arguments = [Literal(args)] when the directive carries +// text after the keyword, else nil. +// +// In both cases, AnnotationType is an Expression (typically Identifier; +// FieldAccess for qualified directives like `lint:ignore`). +// +// Recipes use AnnotationService to inspect / match / mutate, mirroring +// Java's AnnotationService surface. +type Annotation struct { + ID uuid.UUID + Prefix Space + Markers Markers + AnnotationType Expression // NameTree — Identifier or FieldAccess + Arguments *Container[Expression] // nullable; tags always have one Literal +} + +func (*Annotation) isTree() {} +func (*Annotation) isJ() {} +func (*Annotation) isExpression() {} + +func (n *Annotation) WithPrefix(prefix Space) *Annotation { + c := *n + c.Prefix = prefix + return &c +} + +func (n *Annotation) WithMarkers(markers Markers) *Annotation { + c := *n + c.Markers = markers + return &c +} + +func (n *Annotation) WithAnnotationType(annotationType Expression) *Annotation { + c := *n + c.AnnotationType = annotationType + return &c +} + +func (n *Annotation) WithArguments(arguments *Container[Expression]) *Annotation { + c := *n + c.Arguments = arguments + return &c +} + // Empty represents an empty statement or expression placeholder. type Empty struct { ID uuid.UUID @@ -1073,13 +1142,14 @@ func (n *MethodInvocation) WithName(name *Identifier) *MethodInvocation { // For grouped declarations `var ( ... )` or `const ( ... )`, Specs is non-nil and // Variables/TypeExpr are unused. type VariableDeclarations struct { - ID uuid.UUID - Prefix Space - Markers Markers - TypeExpr Expression // the declared type (nil if inferred) - Varargs *Space // non-nil for variadic params (`...T`); holds prefix of `...` - Variables []RightPadded[*VariableDeclarator] // the declared variables - Specs *Container[Statement] // non-nil for grouped `var ( ... )`; Before = space before `(` + ID uuid.UUID + Prefix Space + Markers Markers + LeadingAnnotations []*Annotation // struct field tags (one per `key:"value"` pair) or `//go:` directives + TypeExpr Expression // the declared type (nil if inferred) + Varargs *Space // non-nil for variadic params (`...T`); holds prefix of `...` + Variables []RightPadded[*VariableDeclarator] // the declared variables + Specs *Container[Statement] // non-nil for grouped `var ( ... )`; Before = space before `(` } func (*VariableDeclarations) isTree() {} @@ -1098,6 +1168,12 @@ func (n *VariableDeclarations) WithMarkers(markers Markers) *VariableDeclaration return &c } +func (n *VariableDeclarations) WithLeadingAnnotations(anns []*Annotation) *VariableDeclarations { + c := *n + c.LeadingAnnotations = anns + return &c +} + // VariableDeclarator represents a single variable with optional initializer. type VariableDeclarator struct { ID uuid.UUID diff --git a/rewrite-go/pkg/tree/j_methods.go b/rewrite-go/pkg/tree/j_methods.go new file mode 100644 index 00000000000..b08867a3183 --- /dev/null +++ b/rewrite-go/pkg/tree/j_methods.go @@ -0,0 +1,250 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by hand for J interface conformance — keep in sync +// with the J types declared in j.go and go.go. Each J node satisfies +// the polymorphic READ accessors GetID / GetPrefix / GetMarkers, +// mirroring rewrite-java's J.getId() / J.getPrefix() / J.getMarkers(). +// +// Mutation is intentionally not on the interface: each type's typed +// `WithPrefix(Space) *T` / `WithMarkers(Markers) *T` methods (in j.go +// / go.go) return a fresh instance, which is what the visitor +// framework's pointer-identity change detection requires. Framework +// code that needs to invoke withers polymorphically (RPC receiver) +// goes through reflection — see pkg/rpc/node_helpers.go. +package tree + +import "github.com/google/uuid" + +func (n *Annotation) GetID() uuid.UUID { return n.ID } +func (n *Annotation) GetPrefix() Space { return n.Prefix } +func (n *Annotation) GetMarkers() Markers { return n.Markers } + +func (n *ArrayAccess) GetID() uuid.UUID { return n.ID } +func (n *ArrayAccess) GetPrefix() Space { return n.Prefix } +func (n *ArrayAccess) GetMarkers() Markers { return n.Markers } + +func (n *ArrayDimension) GetID() uuid.UUID { return n.ID } +func (n *ArrayDimension) GetPrefix() Space { return n.Prefix } +func (n *ArrayDimension) GetMarkers() Markers { return n.Markers } + +func (n *ArrayType) GetID() uuid.UUID { return n.ID } +func (n *ArrayType) GetPrefix() Space { return n.Prefix } +func (n *ArrayType) GetMarkers() Markers { return n.Markers } + +func (n *Assignment) GetID() uuid.UUID { return n.ID } +func (n *Assignment) GetPrefix() Space { return n.Prefix } +func (n *Assignment) GetMarkers() Markers { return n.Markers } + +func (n *AssignmentOperation) GetID() uuid.UUID { return n.ID } +func (n *AssignmentOperation) GetPrefix() Space { return n.Prefix } +func (n *AssignmentOperation) GetMarkers() Markers { return n.Markers } + +func (n *Binary) GetID() uuid.UUID { return n.ID } +func (n *Binary) GetPrefix() Space { return n.Prefix } +func (n *Binary) GetMarkers() Markers { return n.Markers } + +func (n *Block) GetID() uuid.UUID { return n.ID } +func (n *Block) GetPrefix() Space { return n.Prefix } +func (n *Block) GetMarkers() Markers { return n.Markers } + +func (n *Break) GetID() uuid.UUID { return n.ID } +func (n *Break) GetPrefix() Space { return n.Prefix } +func (n *Break) GetMarkers() Markers { return n.Markers } + +func (n *Case) GetID() uuid.UUID { return n.ID } +func (n *Case) GetPrefix() Space { return n.Prefix } +func (n *Case) GetMarkers() Markers { return n.Markers } + +func (n *Channel) GetID() uuid.UUID { return n.ID } +func (n *Channel) GetPrefix() Space { return n.Prefix } +func (n *Channel) GetMarkers() Markers { return n.Markers } + +func (n *CommClause) GetID() uuid.UUID { return n.ID } +func (n *CommClause) GetPrefix() Space { return n.Prefix } +func (n *CommClause) GetMarkers() Markers { return n.Markers } + +func (n *CompilationUnit) GetID() uuid.UUID { return n.ID } +func (n *CompilationUnit) GetPrefix() Space { return n.Prefix } +func (n *CompilationUnit) GetMarkers() Markers { return n.Markers } + +func (n *Composite) GetID() uuid.UUID { return n.ID } +func (n *Composite) GetPrefix() Space { return n.Prefix } +func (n *Composite) GetMarkers() Markers { return n.Markers } + +func (n *Continue) GetID() uuid.UUID { return n.ID } +func (n *Continue) GetPrefix() Space { return n.Prefix } +func (n *Continue) GetMarkers() Markers { return n.Markers } + +func (n *ControlParentheses) GetID() uuid.UUID { return n.ID } +func (n *ControlParentheses) GetPrefix() Space { return n.Prefix } +func (n *ControlParentheses) GetMarkers() Markers { return n.Markers } + +func (n *Defer) GetID() uuid.UUID { return n.ID } +func (n *Defer) GetPrefix() Space { return n.Prefix } +func (n *Defer) GetMarkers() Markers { return n.Markers } + +func (n *Else) GetID() uuid.UUID { return n.ID } +func (n *Else) GetPrefix() Space { return n.Prefix } +func (n *Else) GetMarkers() Markers { return n.Markers } + +func (n *Empty) GetID() uuid.UUID { return n.ID } +func (n *Empty) GetPrefix() Space { return n.Prefix } +func (n *Empty) GetMarkers() Markers { return n.Markers } + +func (n *Fallthrough) GetID() uuid.UUID { return n.ID } +func (n *Fallthrough) GetPrefix() Space { return n.Prefix } +func (n *Fallthrough) GetMarkers() Markers { return n.Markers } + +func (n *FieldAccess) GetID() uuid.UUID { return n.ID } +func (n *FieldAccess) GetPrefix() Space { return n.Prefix } +func (n *FieldAccess) GetMarkers() Markers { return n.Markers } + +func (n *ForControl) GetID() uuid.UUID { return n.ID } +func (n *ForControl) GetPrefix() Space { return n.Prefix } +func (n *ForControl) GetMarkers() Markers { return n.Markers } + +func (n *ForEachControl) GetID() uuid.UUID { return n.ID } +func (n *ForEachControl) GetPrefix() Space { return n.Prefix } +func (n *ForEachControl) GetMarkers() Markers { return n.Markers } + +func (n *ForEachLoop) GetID() uuid.UUID { return n.ID } +func (n *ForEachLoop) GetPrefix() Space { return n.Prefix } +func (n *ForEachLoop) GetMarkers() Markers { return n.Markers } + +func (n *ForLoop) GetID() uuid.UUID { return n.ID } +func (n *ForLoop) GetPrefix() Space { return n.Prefix } +func (n *ForLoop) GetMarkers() Markers { return n.Markers } + +func (n *FuncType) GetID() uuid.UUID { return n.ID } +func (n *FuncType) GetPrefix() Space { return n.Prefix } +func (n *FuncType) GetMarkers() Markers { return n.Markers } + +func (n *GoStmt) GetID() uuid.UUID { return n.ID } +func (n *GoStmt) GetPrefix() Space { return n.Prefix } +func (n *GoStmt) GetMarkers() Markers { return n.Markers } + +func (n *Goto) GetID() uuid.UUID { return n.ID } +func (n *Goto) GetPrefix() Space { return n.Prefix } +func (n *Goto) GetMarkers() Markers { return n.Markers } + +func (n *Identifier) GetID() uuid.UUID { return n.ID } +func (n *Identifier) GetPrefix() Space { return n.Prefix } +func (n *Identifier) GetMarkers() Markers { return n.Markers } + +func (n *If) GetID() uuid.UUID { return n.ID } +func (n *If) GetPrefix() Space { return n.Prefix } +func (n *If) GetMarkers() Markers { return n.Markers } + +func (n *Import) GetID() uuid.UUID { return n.ID } +func (n *Import) GetPrefix() Space { return n.Prefix } +func (n *Import) GetMarkers() Markers { return n.Markers } + +func (n *IndexList) GetID() uuid.UUID { return n.ID } +func (n *IndexList) GetPrefix() Space { return n.Prefix } +func (n *IndexList) GetMarkers() Markers { return n.Markers } + +func (n *InterfaceType) GetID() uuid.UUID { return n.ID } +func (n *InterfaceType) GetPrefix() Space { return n.Prefix } +func (n *InterfaceType) GetMarkers() Markers { return n.Markers } + +func (n *KeyValue) GetID() uuid.UUID { return n.ID } +func (n *KeyValue) GetPrefix() Space { return n.Prefix } +func (n *KeyValue) GetMarkers() Markers { return n.Markers } + +func (n *Label) GetID() uuid.UUID { return n.ID } +func (n *Label) GetPrefix() Space { return n.Prefix } +func (n *Label) GetMarkers() Markers { return n.Markers } + +func (n *Literal) GetID() uuid.UUID { return n.ID } +func (n *Literal) GetPrefix() Space { return n.Prefix } +func (n *Literal) GetMarkers() Markers { return n.Markers } + +func (n *MapType) GetID() uuid.UUID { return n.ID } +func (n *MapType) GetPrefix() Space { return n.Prefix } +func (n *MapType) GetMarkers() Markers { return n.Markers } + +func (n *MethodDeclaration) GetID() uuid.UUID { return n.ID } +func (n *MethodDeclaration) GetPrefix() Space { return n.Prefix } +func (n *MethodDeclaration) GetMarkers() Markers { return n.Markers } + +func (n *MethodInvocation) GetID() uuid.UUID { return n.ID } +func (n *MethodInvocation) GetPrefix() Space { return n.Prefix } +func (n *MethodInvocation) GetMarkers() Markers { return n.Markers } + +func (n *MultiAssignment) GetID() uuid.UUID { return n.ID } +func (n *MultiAssignment) GetPrefix() Space { return n.Prefix } +func (n *MultiAssignment) GetMarkers() Markers { return n.Markers } + +func (n *ParameterizedType) GetID() uuid.UUID { return n.ID } +func (n *ParameterizedType) GetPrefix() Space { return n.Prefix } +func (n *ParameterizedType) GetMarkers() Markers { return n.Markers } + +func (n *Parentheses) GetID() uuid.UUID { return n.ID } +func (n *Parentheses) GetPrefix() Space { return n.Prefix } +func (n *Parentheses) GetMarkers() Markers { return n.Markers } + +func (n *PointerType) GetID() uuid.UUID { return n.ID } +func (n *PointerType) GetPrefix() Space { return n.Prefix } +func (n *PointerType) GetMarkers() Markers { return n.Markers } + +func (n *Return) GetID() uuid.UUID { return n.ID } +func (n *Return) GetPrefix() Space { return n.Prefix } +func (n *Return) GetMarkers() Markers { return n.Markers } + +func (n *Send) GetID() uuid.UUID { return n.ID } +func (n *Send) GetPrefix() Space { return n.Prefix } +func (n *Send) GetMarkers() Markers { return n.Markers } + +func (n *Slice) GetID() uuid.UUID { return n.ID } +func (n *Slice) GetPrefix() Space { return n.Prefix } +func (n *Slice) GetMarkers() Markers { return n.Markers } + +func (n *StatementExpression) GetID() uuid.UUID { return n.ID } +func (n *StatementExpression) GetPrefix() Space { return n.Prefix } +func (n *StatementExpression) GetMarkers() Markers { return n.Markers } + +func (n *StructType) GetID() uuid.UUID { return n.ID } +func (n *StructType) GetPrefix() Space { return n.Prefix } +func (n *StructType) GetMarkers() Markers { return n.Markers } + +func (n *Switch) GetID() uuid.UUID { return n.ID } +func (n *Switch) GetPrefix() Space { return n.Prefix } +func (n *Switch) GetMarkers() Markers { return n.Markers } + +func (n *TypeCast) GetID() uuid.UUID { return n.ID } +func (n *TypeCast) GetPrefix() Space { return n.Prefix } +func (n *TypeCast) GetMarkers() Markers { return n.Markers } + +func (n *TypeDecl) GetID() uuid.UUID { return n.ID } +func (n *TypeDecl) GetPrefix() Space { return n.Prefix } +func (n *TypeDecl) GetMarkers() Markers { return n.Markers } + +func (n *TypeList) GetID() uuid.UUID { return n.ID } +func (n *TypeList) GetPrefix() Space { return n.Prefix } +func (n *TypeList) GetMarkers() Markers { return n.Markers } + +func (n *Unary) GetID() uuid.UUID { return n.ID } +func (n *Unary) GetPrefix() Space { return n.Prefix } +func (n *Unary) GetMarkers() Markers { return n.Markers } + +func (n *VariableDeclarations) GetID() uuid.UUID { return n.ID } +func (n *VariableDeclarations) GetPrefix() Space { return n.Prefix } +func (n *VariableDeclarations) GetMarkers() Markers { return n.Markers } + +func (n *VariableDeclarator) GetID() uuid.UUID { return n.ID } +func (n *VariableDeclarator) GetPrefix() Space { return n.Prefix } +func (n *VariableDeclarator) GetMarkers() Markers { return n.Markers } diff --git a/rewrite-go/pkg/tree/markers.go b/rewrite-go/pkg/tree/markers.go index f427c659394..65c5287ddc4 100644 --- a/rewrite-go/pkg/tree/markers.go +++ b/rewrite-go/pkg/tree/markers.go @@ -109,6 +109,40 @@ type Markup struct { func (m Markup) ID() uuid.UUID { return m.Ident } +// GoProject identifies the Go project (logical grouping of go.mod + .go +// files) a source belongs to. Mirrors org.openrewrite.golang.marker.GoProject +// on the Java side. +type GoProject struct { + Ident uuid.UUID + ProjectName string +} + +func (m GoProject) ID() uuid.UUID { return m.Ident } + +// Semicolon marks a RightPadded element that is followed by an explicit +// `;` separator in the source — i.e. multiple statements on one line: +// `_ = 1; _ = 2`. Go inserts implicit semicolons at end-of-line so most +// files don't need this marker; it's only emitted when the source +// literally has a `;` between statements that the printer must +// reproduce. +// +// Mirrors org.openrewrite.java.marker.Semicolon on the Java side. +type Semicolon struct { + Ident uuid.UUID +} + +func (m Semicolon) ID() uuid.UUID { return m.Ident } + +// NewSemicolon creates a Semicolon marker with a fresh UUID. +func NewSemicolon() Semicolon { + return Semicolon{Ident: uuid.New()} +} + +// NewGoProject creates a GoProject marker with a new UUID. +func NewGoProject(projectName string) GoProject { + return GoProject{Ident: uuid.New(), ProjectName: projectName} +} + // NewSearchResult creates a SearchResult marker with a new UUID. func NewSearchResult(description string) SearchResult { return SearchResult{Ident: uuid.New(), Description: description} diff --git a/rewrite-go/pkg/tree/parse_error.go b/rewrite-go/pkg/tree/parse_error.go index b94c88d5ddb..5bbe524c284 100644 --- a/rewrite-go/pkg/tree/parse_error.go +++ b/rewrite-go/pkg/tree/parse_error.go @@ -29,6 +29,15 @@ type ParseError struct { Text string } +// ParseError isn't a J node (no Prefix, no acceptVisitor double- +// dispatch in the J hierarchy) but it does flow through the same +// Tree-typed Visit pipeline as a SourceFile alternate, so it satisfies +// Tree to allow visitor.GoVisitor.Visit to receive it. The visitor +// framework's switch has no case for ParseError; it falls through to +// the default arm. RPC senders/receivers special-case it ahead of the +// dispatch. +func (*ParseError) isTree() {} + // NewParseError creates a ParseError from a source path, source text, and error. func NewParseError(sourcePath string, source string, err error) *ParseError { marker := ParseExceptionResult{ diff --git a/rewrite-go/pkg/tree/search_walker.go b/rewrite-go/pkg/tree/search_walker.go new file mode 100644 index 00000000000..56be6a383d6 --- /dev/null +++ b/rewrite-go/pkg/tree/search_walker.go @@ -0,0 +1,109 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tree + +import ( + "reflect" + + "github.com/google/uuid" +) + +// CollectSearchResultIDs walks the given tree and returns the IDs of every +// SearchResult and SearchResultMarker found in any node's Markers. The +// returned slice has stable insertion order; duplicates are dropped. +// +// Implementation: reflection-based descent. Every LST node has a `Markers +// Markers` field, plus zero or more child fields that are themselves +// trees (or wrappers like RightPadded / LeftPadded / Container holding +// trees). A visitor-based walker would be cleaner but would require every +// concrete node type to opt in; reflection lets BatchVisit collect search +// markers without touching the 60+ node definitions. +func CollectSearchResultIDs(t Tree) []uuid.UUID { + w := &searchWalker{seen: make(map[uuid.UUID]struct{})} + w.walk(reflect.ValueOf(t)) + return w.ids +} + +type searchWalker struct { + ids []uuid.UUID + seen map[uuid.UUID]struct{} +} + +var ( + treeIface = reflect.TypeOf((*Tree)(nil)).Elem() + markersType = reflect.TypeOf(Markers{}) + uuidType = reflect.TypeOf(uuid.UUID{}) +) + +func (w *searchWalker) walk(v reflect.Value) { + if !v.IsValid() { + return + } + switch v.Kind() { + case reflect.Ptr, reflect.Interface: + if v.IsNil() { + return + } + w.walk(v.Elem()) + case reflect.Struct: + // Check for a Markers field on this struct and harvest search-result IDs. + if mf := v.FieldByName("Markers"); mf.IsValid() && mf.Type() == markersType { + w.collectFrom(mf.Interface().(Markers)) + } + // Descend into every other field. We skip Markers (already handled) + // and Prefix (a Space — never carries markers we care about). + for i := 0; i < v.NumField(); i++ { + f := v.Field(i) + name := v.Type().Field(i).Name + if name == "Markers" || name == "Prefix" || name == "ID" { + continue + } + // Skip uuid.UUID values — they're not trees. + if f.Type() == uuidType { + continue + } + w.walk(f) + } + case reflect.Slice, reflect.Array: + for i := 0; i < v.Len(); i++ { + w.walk(v.Index(i)) + } + } +} + +func (w *searchWalker) collectFrom(m Markers) { + for _, marker := range m.Entries { + var id uuid.UUID + switch x := marker.(type) { + case SearchResult: + id = x.Ident + case *SearchResult: + id = x.Ident + case SearchResultMarker: + id = x.Ident + case *SearchResultMarker: + id = x.Ident + default: + continue + } + if _, dup := w.seen[id]; dup { + continue + } + w.seen[id] = struct{}{} + w.ids = append(w.ids, id) + } +} diff --git a/rewrite-go/pkg/tree/tree.go b/rewrite-go/pkg/tree/tree.go index ac18dcdacc2..5b44a1eedbf 100644 --- a/rewrite-go/pkg/tree/tree.go +++ b/rewrite-go/pkg/tree/tree.go @@ -16,13 +16,37 @@ package tree +import "github.com/google/uuid" + // Tree is the root interface for all LST nodes. type Tree interface{ isTree() } -// J is the interface for all Java-like AST nodes that carry a prefix space. +// J is the interface for all Java-like AST nodes that carry a prefix +// space. Mirrors org.openrewrite.java.tree.J — `getId`, `getPrefix`, +// `getMarkers` are polymorphic accessors so RPC senders, receivers, +// and other framework code can read the cross-cutting fields without +// per-type switches. +// +// Concrete impls live in j_methods.go (keep in sync with the J types +// in j.go and go.go). +// +// Mutation: the J interface is read-only. To produce a modified node, +// use the typed `WithPrefix(Space) *T` / `WithMarkers(Markers) *T` +// per-type methods — they return a *new* instance so the visitor +// framework's change-detection (pointer identity) works correctly. +// Framework code that needs to invoke them polymorphically (e.g. the +// RPC receiver's PreVisit) goes through reflection — see the +// `withPrefixViaReflection` helper in pkg/rpc. +// +// In-place mutation of Prefix / Markers / ID is never safe because it +// would silently bypass RecipeScheduler's "did this recipe change the +// tree?" check. type J interface { Tree isJ() + GetID() uuid.UUID + GetPrefix() Space + GetMarkers() Markers } // Expression is a J node that evaluates to a value. diff --git a/rewrite-go/pkg/visitor/cursor.go b/rewrite-go/pkg/visitor/cursor.go index 04549ce7b15..086f1524976 100644 --- a/rewrite-go/pkg/visitor/cursor.go +++ b/rewrite-go/pkg/visitor/cursor.go @@ -18,22 +18,122 @@ package visitor import "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" -// Cursor tracks the path from root to the currently visited node, -// providing context during tree traversal. +// Cursor tracks the path from root to the currently visited node and +// carries a per-frame message map so passes can stash state for +// ancestors / descendants. Mirrors org.openrewrite.Cursor in Java. type Cursor struct { - parent *Cursor - value tree.Tree + parent *Cursor + value tree.Tree + messages map[string]any } func NewCursor(parent *Cursor, value tree.Tree) *Cursor { return &Cursor{parent: parent, value: value} } -func (c *Cursor) Parent() *Cursor { return c.parent } +func (c *Cursor) Parent() *Cursor { return c.parent } func (c *Cursor) Value() tree.Tree { return c.value } +// PutMessage stores a value on this cursor's frame, keyed by name. +// Mirrors Java Cursor.putMessage(String, Object). +func (c *Cursor) PutMessage(key string, value any) { + if c.messages == nil { + c.messages = make(map[string]any, 4) + } + c.messages[key] = value +} + +// GetMessage returns the value previously stored on THIS frame, or nil +// if no value exists. Does not walk up the parent chain. Mirrors Java +// Cursor.getMessage(String). +func (c *Cursor) GetMessage(key string) any { + if c.messages == nil { + return nil + } + return c.messages[key] +} + +// GetNearestMessage walks up the cursor chain (starting from this frame) +// returning the first matching value. Returns nil if no frame has a +// value for the key. Mirrors Java Cursor.getNearestMessage(String). +func (c *Cursor) GetNearestMessage(key string) any { + for cur := c; cur != nil; cur = cur.parent { + if cur.messages != nil { + if v, ok := cur.messages[key]; ok { + return v + } + } + } + return nil +} + +// GetNearestMessageOrDefault is GetNearestMessage with a fallback when +// no frame holds a value for the key. +func (c *Cursor) GetNearestMessageOrDefault(key string, defaultValue any) any { + if v := c.GetNearestMessage(key); v != nil { + return v + } + return defaultValue +} + +// PollNearestMessage walks up the chain like GetNearestMessage, but +// REMOVES the value from the frame where it was found. Returns nil if +// no frame had it. Mirrors Java Cursor.pollNearestMessage(String). +func (c *Cursor) PollNearestMessage(key string) any { + for cur := c; cur != nil; cur = cur.parent { + if cur.messages != nil { + if v, ok := cur.messages[key]; ok { + delete(cur.messages, key) + return v + } + } + } + return nil +} + +// ComputeMessageIfAbsent returns the value for the key on THIS frame, +// computing and storing it via the supplier if absent. Mirrors Java +// Cursor.computeMessageIfAbsent. +func (c *Cursor) ComputeMessageIfAbsent(key string, supplier func() any) any { + if c.messages == nil { + c.messages = make(map[string]any, 4) + } + if v, ok := c.messages[key]; ok { + return v + } + v := supplier() + c.messages[key] = v + return v +} + +// PutMessageOnFirstEnclosing walks up looking for the first ancestor +// whose value matches the predicate, and stores the message on that +// frame. No-op if no ancestor matches. Mirrors Java +// Cursor.putMessageOnFirstEnclosing(Class, String, Object), generalized +// to a predicate so callers can match on any type or condition. +func (c *Cursor) PutMessageOnFirstEnclosing(match func(t tree.Tree) bool, key string, value any) { + for cur := c; cur != nil; cur = cur.parent { + if cur.value != nil && match(cur.value) { + cur.PutMessage(key, value) + return + } + } +} + +// BuildChain constructs a cursor chain from a list of tree values, root first. +// Returns nil for an empty input. Used by the RPC layer to reconstruct the +// cursor from a Visit request's `cursor` field (a list of tree IDs whose +// values have already been fetched in order). +func BuildChain(values []tree.Tree) *Cursor { + var c *Cursor + for _, v := range values { + c = NewCursor(c, v) + } + return c +} + // FirstEnclosing walks up the cursor chain to find the first ancestor -// matching the given type. +// matching the given type. The cursor itself is not considered — only ancestors. func FirstEnclosing[T tree.Tree](c *Cursor) (T, bool) { for cur := c.parent; cur != nil; cur = cur.parent { if v, ok := cur.value.(T); ok { diff --git a/rewrite-go/pkg/visitor/drain.go b/rewrite-go/pkg/visitor/drain.go new file mode 100644 index 00000000000..671b3c0aad6 --- /dev/null +++ b/rewrite-go/pkg/visitor/drain.go @@ -0,0 +1,62 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package visitor + +import "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + +// AfterVisitsProvider is the structural interface a visitor implements +// to expose its DoAfterVisit queue. GoVisitor satisfies it; user-defined +// visitors that embed GoVisitor inherit the methods automatically. +type AfterVisitsProvider interface { + AfterVisits() []AfterVisitor + DoAfterVisit(AfterVisitor) +} + +// DrainAfterVisits applies any visitors that `editor` queued via +// GoVisitor.DoAfterVisit. After-visits can themselves queue more +// after-visits (transitive); this loops until no provider has anything +// left in its queue. Mirrors JavaVisitor's afterVisit drain semantics. +// +// Returns the (possibly modified) tree. Callers should ALWAYS run this +// after the main editor.Visit so DoAfterVisit-queued follow-ups land. +// Pass `editor` as the recipe's TreeVisitor and the current tree + ctx. +func DrainAfterVisits(editor any, t tree.Tree, ctx any) tree.Tree { + parent, ok := editor.(AfterVisitsProvider) + if !ok { + return t + } + for { + batch := parent.AfterVisits() + if len(batch) == 0 { + return t + } + for _, v := range batch { + result := v.Visit(t, ctx) + if result != nil { + t = result + } + // Forward any after-visits the queued visitor itself + // produced back onto the parent so the outer loop drains + // them next iteration. + if pv, ok := v.(AfterVisitsProvider); ok { + for _, m := range pv.AfterVisits() { + parent.DoAfterVisit(m) + } + } + } + } +} diff --git a/rewrite-go/pkg/visitor/go_visitor.go b/rewrite-go/pkg/visitor/go_visitor.go index 0082a8f233f..905d33a12b7 100644 --- a/rewrite-go/pkg/visitor/go_visitor.go +++ b/rewrite-go/pkg/visitor/go_visitor.go @@ -21,15 +21,78 @@ import "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" // GoVisitor traverses and optionally transforms an OpenRewrite LST. // Embed GoVisitor in a struct and override visit methods to customize behavior. // Set Self to the outer struct to enable virtual dispatch. +// +// Cursor: GoVisitor maintains the current cursor as state. Recipes that +// need ancestor context call v.Cursor() inside any Visit* override (matches +// the JavaVisitor.getCursor() pattern). The RPC layer seeds an initial +// cursor via SetCursor before traversal begins. +// +// After-visits: a visit method can queue follow-up visitors via +// DoAfterVisit. The recipe runner drains the queue once the main visit +// returns and re-applies each queued visitor (transitively — afters can +// queue afters). This mirrors JavaVisitor.doAfterVisit and is the canonical +// way to compose side-effects like "add an import" after the main edit. type GoVisitor struct { // Self must point to the outermost embedding struct for virtual dispatch. // If nil, dispatches to the default implementations on GoVisitor itself. Self interface{} - cursor *Cursor + cursor *Cursor + afterVisits []AfterVisitor +} + +// AfterVisitor is the interface a follow-up visitor must satisfy to +// participate in DoAfterVisit. It's a structural alias for the +// recipe.TreeVisitor interface — duplicated here to avoid an import +// cycle between pkg/visitor and pkg/recipe. +type AfterVisitor interface { + Visit(t tree.Tree, p any) tree.Tree +} + +// Cursor returns the current cursor (the path from root to the node +// currently being visited). Mirrors JavaVisitor.getCursor(). +func (v *GoVisitor) Cursor() *Cursor { return v.cursor } + +// SetCursor seeds the visitor with an initial cursor chain. The RPC layer +// calls this with the chain reconstructed from a Visit request's cursor +// IDs before invoking Visit. Recipes typically don't call this directly. +func (v *GoVisitor) SetCursor(c *Cursor) { v.cursor = c } + +// DoAfterVisit queues a follow-up visitor to run after the main visit +// completes. Mirrors JavaVisitor.doAfterVisit. Use this from inside any +// Visit* override to compose side-effects like adding an import: +// +// svc := service.ImportServiceFor(cu) +// v.DoAfterVisit(svc.AddImportVisitor("fmt", nil, false)) +// +// The recipe runner drains the queue after Visit returns; queued +// visitors can themselves queue more after-visitors (transitive). +func (v *GoVisitor) DoAfterVisit(other AfterVisitor) { + v.afterVisits = append(v.afterVisits, other) +} + +// AfterVisits returns the queued follow-up visitors, then clears the +// queue. The recipe runner calls this once the main Visit returns and +// applies each visitor to the modified tree, looping until empty. +func (v *GoVisitor) AfterVisits() []AfterVisitor { + out := v.afterVisits + v.afterVisits = nil + return out } // Visit dispatches to the appropriate visit method based on the node's concrete type. +// +// Lifecycle: +// 1. cursor is pushed for `t`. +// 2. PreVisit(t, p) is called via virtual dispatch — subclasses +// (e.g. RPC sender/receiver) override it to handle cross-cutting +// fields (id, prefix, markers) once per node. +// 3. The type-specific Visit* method is dispatched via virtual +// dispatch on the (possibly modified) tree returned by PreVisit. +// 4. cursor pops on return. +// +// PreVisit returning nil short-circuits the visit — useful for +// receivers that get DELETE state from the wire. func (v *GoVisitor) Visit(t tree.Tree, p any) tree.Tree { if t == nil { return nil @@ -38,6 +101,11 @@ func (v *GoVisitor) Visit(t tree.Tree, p any) tree.Tree { v.cursor = NewCursor(v.cursor, t) defer func() { v.cursor = v.cursor.parent }() + t = v.self().PreVisit(t, p) + if t == nil { + return nil + } + switch n := t.(type) { case *tree.CompilationUnit: return v.self().VisitCompilationUnit(n, p) @@ -101,6 +169,8 @@ func (v *GoVisitor) Visit(t tree.Tree, p any) tree.Tree { return v.self().VisitFallthrough(n, p) case *tree.Empty: return v.self().VisitEmpty(n, p) + case *tree.Annotation: + return v.self().VisitAnnotation(n, p) case *tree.ArrayType: return v.self().VisitArrayType(n, p) case *tree.Parentheses: @@ -162,6 +232,7 @@ func (v *GoVisitor) self() VisitorI { // VisitorI defines all overridable visit methods. type VisitorI interface { Visit(t tree.Tree, p any) tree.Tree + PreVisit(t tree.Tree, p any) tree.Tree VisitCompilationUnit(cu *tree.CompilationUnit, p any) tree.J VisitIdentifier(ident *tree.Identifier, p any) tree.J VisitLiteral(lit *tree.Literal, p any) tree.J @@ -193,6 +264,7 @@ type VisitorI interface { VisitGoto(g *tree.Goto, p any) tree.J VisitFallthrough(f *tree.Fallthrough, p any) tree.J VisitEmpty(empty *tree.Empty, p any) tree.J + VisitAnnotation(ann *tree.Annotation, p any) tree.J VisitArrayType(at *tree.ArrayType, p any) tree.J VisitParentheses(paren *tree.Parentheses, p any) tree.J VisitTypeCast(tc *tree.TypeCast, p any) tree.J @@ -224,6 +296,14 @@ var _ VisitorI = (*GoVisitor)(nil) // --- Default visit implementations --- +// PreVisit is the per-node hook called by Visit() before dispatching +// to the type-specific Visit* method. The default implementation is +// the identity function. RPC senders/receivers override it to +// serialize/deserialize the cross-cutting `id`, `prefix`, and +// `markers` fields once per node, mirroring Java's +// JavaVisitor.preVisit pattern. +func (v *GoVisitor) PreVisit(t tree.Tree, p any) tree.Tree { return t } + func (v *GoVisitor) VisitCompilationUnit(cu *tree.CompilationUnit, p any) tree.J { cu = cu.WithPrefix(v.self().VisitSpace(cu.Prefix, p)) cu = cu.WithMarkers(v.visitMarkers(cu.Markers, p)) @@ -312,6 +392,17 @@ func (v *GoVisitor) VisitAssignment(assign *tree.Assignment, p any) tree.J { func (v *GoVisitor) VisitMethodDeclaration(md *tree.MethodDeclaration, p any) tree.J { md = md.WithPrefix(v.self().VisitSpace(md.Prefix, p)) md = md.WithMarkers(v.visitMarkers(md.Markers, p)) + if len(md.LeadingAnnotations) > 0 { + anns := make([]*tree.Annotation, 0, len(md.LeadingAnnotations)) + for _, a := range md.LeadingAnnotations { + visited := v.self().Visit(a, p) + if visited == nil { + continue + } + anns = append(anns, visited.(*tree.Annotation)) + } + md = md.WithLeadingAnnotations(anns) + } md = md.WithName(visitAndCast[*tree.Identifier](v, md.Name, p)) if md.Body != nil { md = md.WithBody(visitAndCast[*tree.Block](v, md.Body, p)) @@ -323,6 +414,13 @@ func (v *GoVisitor) VisitFieldAccess(fa *tree.FieldAccess, p any) tree.J { fa = fa.WithPrefix(v.self().VisitSpace(fa.Prefix, p)) fa = fa.WithMarkers(v.visitMarkers(fa.Markers, p)) fa = fa.WithTarget(visitExpression(v, fa.Target, p)) + // Visit the selector identifier so recipes that traverse identifiers + // see the right-hand side of `target.Name` (e.g. the `Box` in + // `a.Box[int]{...}`). Mirrors JavaIsoVisitor.visitFieldAccess. + name := fa.Name + name.Before = v.self().VisitSpace(name.Before, p) + name.Element = visitAndCast[*tree.Identifier](v, name.Element, p) + fa.Name = name return fa } @@ -344,6 +442,17 @@ func (v *GoVisitor) VisitMethodInvocation(mi *tree.MethodInvocation, p any) tree func (v *GoVisitor) VisitVariableDeclarations(vd *tree.VariableDeclarations, p any) tree.J { vd = vd.WithPrefix(v.self().VisitSpace(vd.Prefix, p)) vd = vd.WithMarkers(v.visitMarkers(vd.Markers, p)) + if len(vd.LeadingAnnotations) > 0 { + anns := make([]*tree.Annotation, 0, len(vd.LeadingAnnotations)) + for _, a := range vd.LeadingAnnotations { + visited := v.self().Visit(a, p) + if visited == nil { + continue + } + anns = append(anns, visited.(*tree.Annotation)) + } + vd = vd.WithLeadingAnnotations(anns) + } if vd.TypeExpr != nil { vd.TypeExpr = visitExpression(v, vd.TypeExpr, p) } @@ -478,6 +587,22 @@ func (v *GoVisitor) VisitEmpty(empty *tree.Empty, p any) tree.J { return empty } +func (v *GoVisitor) VisitAnnotation(ann *tree.Annotation, p any) tree.J { + ann = ann.WithPrefix(v.self().VisitSpace(ann.Prefix, p)) + ann = ann.WithMarkers(v.visitMarkers(ann.Markers, p)) + if ann.AnnotationType != nil { + ann = ann.WithAnnotationType(visitExpression(v, ann.AnnotationType, p)) + } + if ann.Arguments != nil { + args := *ann.Arguments + args.Before = v.self().VisitSpace(args.Before, p) + args.Markers = v.visitMarkers(args.Markers, p) + args.Elements = visitRightPaddedExpressionList(v, args.Elements, p) + ann = ann.WithArguments(&args) + } + return ann +} + func (v *GoVisitor) VisitArrayType(at *tree.ArrayType, p any) tree.J { at = at.WithPrefix(v.self().VisitSpace(at.Prefix, p)) at = at.WithMarkers(v.visitMarkers(at.Markers, p)) @@ -511,12 +636,24 @@ func (v *GoVisitor) VisitArrayAccess(aa *tree.ArrayAccess, p any) tree.J { func (v *GoVisitor) VisitParameterizedType(pt *tree.ParameterizedType, p any) tree.J { pt = pt.WithPrefix(v.self().VisitSpace(pt.Prefix, p)) pt = pt.WithMarkers(v.visitMarkers(pt.Markers, p)) + if pt.Clazz != nil { + pt.Clazz = visitExpression(v, pt.Clazz, p) + } + if pt.TypeParameters != nil { + pt.TypeParameters.Before = v.self().VisitSpace(pt.TypeParameters.Before, p) + pt.TypeParameters.Elements = visitRightPaddedList(v, pt.TypeParameters.Elements, p) + } return pt } func (v *GoVisitor) VisitIndexList(il *tree.IndexList, p any) tree.J { il = il.WithPrefix(v.self().VisitSpace(il.Prefix, p)) il = il.WithMarkers(v.visitMarkers(il.Markers, p)) + if il.Target != nil { + il.Target = visitExpression(v, il.Target, p) + } + il.Indices.Before = v.self().VisitSpace(il.Indices.Before, p) + il.Indices.Elements = visitRightPaddedList(v, il.Indices.Elements, p) return il } @@ -529,6 +666,11 @@ func (v *GoVisitor) VisitArrayDimension(ad *tree.ArrayDimension, p any) tree.J { func (v *GoVisitor) VisitComposite(c *tree.Composite, p any) tree.J { c = c.WithPrefix(v.self().VisitSpace(c.Prefix, p)) c = c.WithMarkers(v.visitMarkers(c.Markers, p)) + if c.TypeExpr != nil { + c.TypeExpr = visitExpression(v, c.TypeExpr, p) + } + c.Elements.Before = v.self().VisitSpace(c.Elements.Before, p) + c.Elements.Elements = visitRightPaddedList(v, c.Elements.Elements, p) return c } @@ -587,6 +729,17 @@ func (v *GoVisitor) VisitTypeList(tl *tree.TypeList, p any) tree.J { func (v *GoVisitor) VisitTypeDecl(td *tree.TypeDecl, p any) tree.J { td = td.WithPrefix(v.self().VisitSpace(td.Prefix, p)) td = td.WithMarkers(v.visitMarkers(td.Markers, p)) + if len(td.LeadingAnnotations) > 0 { + anns := make([]*tree.Annotation, 0, len(td.LeadingAnnotations)) + for _, a := range td.LeadingAnnotations { + visited := v.self().Visit(a, p) + if visited == nil { + continue + } + anns = append(anns, visited.(*tree.Annotation)) + } + td = td.WithLeadingAnnotations(anns) + } return td } diff --git a/rewrite-go/rpc b/rewrite-go/rpc index b987241bea6..9293bbbea1e 100755 Binary files a/rewrite-go/rpc and b/rewrite-go/rpc differ diff --git a/rewrite-go/src/integTest/java/org/openrewrite/golang/GoProjectTest.java b/rewrite-go/src/integTest/java/org/openrewrite/golang/GoProjectTest.java new file mode 100644 index 00000000000..f41d26f5dfa --- /dev/null +++ b/rewrite-go/src/integTest/java/org/openrewrite/golang/GoProjectTest.java @@ -0,0 +1,154 @@ +/* + * Copyright 2026 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.golang; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.io.TempDir; +import org.openrewrite.SourceFile; +import org.openrewrite.golang.marker.GoProject; +import org.openrewrite.golang.marker.GoResolutionResult; +import org.openrewrite.golang.rpc.GoRewriteRpc; +import org.openrewrite.golang.tree.Go; +import org.openrewrite.java.tree.J; +import org.openrewrite.test.RewriteTest; +import org.openrewrite.test.TypeValidation; + +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.concurrent.TimeUnit; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.openrewrite.golang.Assertions.go; +import static org.openrewrite.golang.Assertions.goMod; +import static org.openrewrite.golang.Assertions.goProject; + +@Timeout(value = 120, unit = TimeUnit.SECONDS) +class GoProjectTest implements RewriteTest { + + @TempDir + Path tempDir; + + @BeforeEach + void before() { + Path binaryPath = Paths.get("build/rewrite-go-rpc").toAbsolutePath(); + GoRewriteRpc.setFactory(GoRewriteRpc.builder() + .goBinaryPath(binaryPath) + .log(tempDir.resolve("go-rpc.log")) + .traceRpcMessages()); + } + + @AfterEach + void after() { + GoRewriteRpc.shutdownCurrent(); + } + + @Override + public void defaults(org.openrewrite.test.RecipeSpec spec) { + spec.typeValidationOptions(TypeValidation.builder() + .allowNonWhitespaceInWhitespace(true) + .identifiers(false) + .methodInvocations(false) + .build()); + } + + /** + * Direct test of the project-aware Parse RPC: a parser configured with + * module + go.mod content sends the project context to the Go server, + * which builds a ProjectImporter so the parsed Go.CompilationUnit gets + * type attribution on third-party imports declared in `require`. + */ + @Test + void thirdPartyImportResolvesViaModuleContext() { + String goModContent = + "module example.com/foo\n\n" + + "go 1.22\n\n" + + "require github.com/x/y v1.2.3\n"; + + GolangParser parser = GolangParser.builder() + .module("example.com/foo") + .goMod(goModContent) + .build(); + + SourceFile sf = parser.parse( + "package main\n\n" + + "import \"github.com/x/y\"\n\n" + + "func main() { _ = y.Hello() }\n" + ).findFirst().orElseThrow(); + + assertThat(sf).isInstanceOf(Go.CompilationUnit.class); + // Walk the parsed CU for the `y` package-alias identifier and + // confirm its type came back resolved. Without module context it + // would be null because importer.Default() doesn't know about + // github.com/x/y. + boolean[] sawResolvedY = {false}; + new org.openrewrite.java.JavaIsoVisitor() { + @Override + public J.Identifier visitIdentifier(J.Identifier ident, Integer p) { + if ("y".equals(ident.getSimpleName()) && ident.getType() != null) { + sawResolvedY[0] = true; + } + return super.visitIdentifier(ident, p); + } + }.visit(sf, 0); + assertThat(sawResolvedY[0]) + .as("expected `y` import identifier to have a non-nil Type via project context") + .isTrue(); + } + + @Test + void goModAndGoFilesAreSiblingsTaggedWithProjectMarker() { + rewriteRun( + goProject("foo", + goMod( + """ + module example.com/foo + + go 1.22 + + require github.com/x/y v1.2.3 + """, + s -> s.afterRecipe(pt -> { + // The goMod source carries the resolution result. + GoResolutionResult mrr = pt.getMarkers().findFirst(GoResolutionResult.class).orElseThrow(); + assertThat(mrr.getModulePath()).isEqualTo("example.com/foo"); + assertThat(mrr.getGoVersion()).isEqualTo("1.22"); + // And the project marker added by goProject(...). + GoProject project = pt.getMarkers().findFirst(GoProject.class).orElseThrow(); + assertThat(project.getProjectName()).isEqualTo("foo"); + })), + go( + """ + package main + + import "github.com/x/y" + + func main() { y.Hello() } + """, + s -> s.afterRecipe(cu -> { + // The .go file carries GoProject (from the wrapper) but NOT + // GoResolutionResult — that lives on the sibling go.mod, just + // as MavenResolutionResult lives on the sibling pom.xml. + GoProject project = cu.getMarkers().findFirst(GoProject.class).orElseThrow(); + assertThat(project.getProjectName()).isEqualTo("foo"); + assertThat(cu.getMarkers().findFirst(GoResolutionResult.class)).isEmpty(); + })) + ) + ); + } +} diff --git a/rewrite-go/src/integTest/java/org/openrewrite/golang/rpc/GolangParserIntegTest.java b/rewrite-go/src/integTest/java/org/openrewrite/golang/rpc/GolangParserIntegTest.java index db34a20bdab..3ad48bb6b44 100644 --- a/rewrite-go/src/integTest/java/org/openrewrite/golang/rpc/GolangParserIntegTest.java +++ b/rewrite-go/src/integTest/java/org/openrewrite/golang/rpc/GolangParserIntegTest.java @@ -29,6 +29,8 @@ import java.nio.file.Paths; import java.util.concurrent.TimeUnit; +import static org.openrewrite.golang.Assertions.expectMethodType; +import static org.openrewrite.golang.Assertions.expectType; import static org.openrewrite.golang.Assertions.go; /** @@ -106,12 +108,29 @@ func main() { \tfmt.Println("Hello") } """, - spec -> spec.afterRecipe(cu -> { - var methods = org.openrewrite.java.search.FindMethods.find(cu, "fmt Println(..)"); - org.assertj.core.api.Assertions.assertThat(methods) - .as("FindMethods should find fmt.Println invocation via type attribution") - .isNotEmpty(); - }) + spec -> spec.afterRecipe(cu -> expectMethodType(cu, "Println", "fmt")) + ) + ); + } + + @Test + void verifyStructTypeAttribution() { + rewriteRun( + go( + """ + package main + + type Point struct { + \tX int + \tY int + } + + func main() { + \tp := Point{X: 1, Y: 2} + \t_ = p + } + """, + spec -> spec.afterRecipe(cu -> expectType(cu, "p", "main.Point")) ) ); } diff --git a/rewrite-go/src/integTest/java/org/openrewrite/golang/rpc/GolangRecipeIntegTest.java b/rewrite-go/src/integTest/java/org/openrewrite/golang/rpc/GolangRecipeIntegTest.java index 6fe2b3d0489..f37c3527b9b 100644 --- a/rewrite-go/src/integTest/java/org/openrewrite/golang/rpc/GolangRecipeIntegTest.java +++ b/rewrite-go/src/integTest/java/org/openrewrite/golang/rpc/GolangRecipeIntegTest.java @@ -15,7 +15,9 @@ */ package org.openrewrite.golang.rpc; +import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; @@ -174,8 +176,10 @@ void goNativeRecipeViaRpcFullCliPath() { // Reset to simulate fresh session, then install recipes (like CLI does) rpc.reset(); - rpc.installRecipes( - new java.io.File("/Users/jonathan/Projects/github/moderneinc/recipes-go/recipes-code-quality")); + java.io.File recipesPath = resolveRecipesGoPath(); + Assumptions.assumeTrue(recipesPath != null, + "recipes-go checkout not found; set -Drecipes.go.path= or RECIPES_GO_PATH env to enable"); + rpc.installRecipes(recipesPath); var recipe = rpc.prepareRecipe("org.openrewrite.golang.test.RenameXToFlag"); Tree result = recipe.getVisitor().visit(cu, new InMemoryExecutionContext()); @@ -363,4 +367,68 @@ func world() { ) ); } + + /** + * Locate a local recipes-code-quality checkout for the full-CLI-path test. + * Resolution order: -Drecipes.go.path system property, RECIPES_GO_PATH env + * var, then a sibling lookup walking up from the current working dir. + * Returns null if none found OR if its go.mod replace points at a + * rewrite-go directory that doesn't exist on this machine — caller skips. + */ + private static java.io.@Nullable File resolveRecipesGoPath() { + java.io.File candidate = locateRecipesGoPath(); + return candidate != null && replaceTargetExists(candidate) ? candidate : null; + } + + private static java.io.@Nullable File locateRecipesGoPath() { + String prop = System.getProperty("recipes.go.path"); + if (prop == null || prop.isEmpty()) { + prop = System.getenv("RECIPES_GO_PATH"); + } + if (prop != null && !prop.isEmpty()) { + java.io.File f = new java.io.File(prop); + return f.isDirectory() ? f : null; + } + java.io.File cur = new java.io.File(System.getProperty("user.dir")).getAbsoluteFile(); + while (cur != null) { + for (String rel : new String[]{ + "moderneinc/recipes-go/.worktrees/golang/recipes-code-quality", + "recipes-go/recipes-code-quality" + }) { + java.io.File c = new java.io.File(cur, rel); + if (c.isDirectory()) { + return c; + } + } + cur = cur.getParentFile(); + } + return null; + } + + /** + * Confirm that the rewrite-go directory referenced by the recipes-go + * go.mod replace actually exists on this machine. Worktree layouts can + * cause the relative replace to point at a non-existent path; in that + * case the install would fail mid-`go mod tidy` with a confusing error. + */ + private static boolean replaceTargetExists(java.io.File recipesGoDir) { + try { + java.nio.file.Path goMod = recipesGoDir.toPath().resolve("go.mod"); + for (String line : java.nio.file.Files.readAllLines(goMod)) { + String t = line.trim(); + if (!t.startsWith("replace ")) continue; + int arrow = t.indexOf("=>"); + if (arrow < 0) continue; + String target = t.substring(arrow + 2).trim(); + if (target.contains("@")) continue; + java.nio.file.Path resolved = recipesGoDir.toPath().resolve(target).normalize(); + if (target.contains("rewrite-go") && !java.nio.file.Files.isDirectory(resolved)) { + return false; + } + } + return true; + } catch (java.io.IOException e) { + return false; + } + } } diff --git a/rewrite-go/src/integTest/java/org/openrewrite/golang/rpc/MarkerRoundTripTest.java b/rewrite-go/src/integTest/java/org/openrewrite/golang/rpc/MarkerRoundTripTest.java new file mode 100644 index 00000000000..be288434d96 --- /dev/null +++ b/rewrite-go/src/integTest/java/org/openrewrite/golang/rpc/MarkerRoundTripTest.java @@ -0,0 +1,229 @@ +/* + * Copyright 2026 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.golang.rpc; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.io.TempDir; +import org.openrewrite.SourceFile; +import org.openrewrite.Tree; +import org.openrewrite.golang.GolangParser; +import org.openrewrite.golang.marker.GoProject; +import org.openrewrite.golang.marker.GoResolutionResult; +import org.openrewrite.marker.Markers; + +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Collections; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Cross-language round-trip for {@link GoProject} and {@link GoResolutionResult}. + *

+ * Both markers implement {@code RpcCodec} on the Java side and have matching + * dispatch in the Go {@code pkg/rpc} codec (see + * {@code go_resolution_result_codec.go}). This test exercises the + * Java → Go → Java path by: + *

    + *
  1. Parsing a Go source via the RPC subprocess.
  2. + *
  3. Attaching populated marker instances to the resulting LST.
  4. + *
  5. Sending the LST back to Go via {@code rpc.print(...)} — Go must + * deserialize the markers without error.
  6. + *
  7. Asserting the printed source matches the input, proving the + * round-trip didn't truncate or lose data.
  8. + *
+ * Field-order or name-mapping divergence between the two languages causes + * receive-side panics or empty markers, both of which fail this test. + */ +@Timeout(value = 120, unit = TimeUnit.SECONDS) +class MarkerRoundTripTest { + + @TempDir + Path tempDir; + + @BeforeEach + void before() { + Path binaryPath = Paths.get("build/rewrite-go-rpc").toAbsolutePath(); + GoRewriteRpc.setFactory(GoRewriteRpc.builder() + .goBinaryPath(binaryPath) + .log(tempDir.resolve("go-rpc.log")) + .traceRpcMessages()); + } + + @AfterEach + void after() { + GoRewriteRpc.shutdownCurrent(); + } + + @Test + void goProjectMarkerRoundTripsViaPrint() { + GoRewriteRpc rpc = GoRewriteRpc.getOrStart(); + String source = "package main\n\nfunc main() {\n}\n"; + SourceFile cu = GolangParser.builder().build() + .parse(source).findFirst().orElseThrow(); + + GoProject marker = new GoProject(UUID.randomUUID(), "example/foo"); + cu = cu.withMarkers(cu.getMarkers().addIfAbsent(marker)); + + // Force the marker through Go's receive codec. + String printed = rpc.print(cu); + assertThat(printed).isEqualTo(source); + } + + @Test + void goResolutionResultMarkerRoundTripsViaPrint() { + GoRewriteRpc rpc = GoRewriteRpc.getOrStart(); + String source = "package main\n\nfunc main() {\n}\n"; + SourceFile cu = GolangParser.builder().build() + .parse(source).findFirst().orElseThrow(); + + GoResolutionResult marker = new GoResolutionResult( + UUID.randomUUID(), + "example.com/foo", + "1.22", + "go1.22.5", + "/tmp/go.mod", + Arrays.asList( + new GoResolutionResult.Require("github.com/google/uuid", "v1.6.0", false), + new GoResolutionResult.Require("golang.org/x/mod", "v0.35.0", true) + ), + Arrays.asList( + new GoResolutionResult.Replace("github.com/x/y", null, "../local/y", null), + new GoResolutionResult.Replace("github.com/a/b", "v1.0.0", "github.com/forked/b", "v1.0.1") + ), + Collections.singletonList( + new GoResolutionResult.Exclude("github.com/bad", "v0.0.1") + ), + Arrays.asList( + new GoResolutionResult.Retract("v0.0.5", "deleted main.go"), + new GoResolutionResult.Retract("[v1.0.0, v1.0.5]", null) + ), + Collections.singletonList( + new GoResolutionResult.ResolvedDependency( + "github.com/google/uuid", "v1.6.0", + "h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=", + "h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=") + ) + ); + cu = cu.withMarkers(cu.getMarkers().addIfAbsent(marker)); + + String printed = rpc.print(cu); + assertThat(printed).isEqualTo(source); + } + + @Test + void emptyGoResolutionResultRoundTripsViaPrint() { + // Mirrors the descriptor empty-list fix: collections that are empty + // must serialize as empty arrays so the receive side doesn't read + // null and panic. + GoRewriteRpc rpc = GoRewriteRpc.getOrStart(); + String source = "package main\n\nfunc main() {\n}\n"; + SourceFile cu = GolangParser.builder().build() + .parse(source).findFirst().orElseThrow(); + + GoResolutionResult marker = new GoResolutionResult( + UUID.randomUUID(), + "example.com/empty", + null, + null, + "go.mod", + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList() + ); + cu = cu.withMarkers(cu.getMarkers().addIfAbsent(marker)); + + String printed = rpc.print(cu); + assertThat(printed).isEqualTo(source); + } + + @Test + void bothMarkersTogetherRoundTripViaPrint() { + GoRewriteRpc rpc = GoRewriteRpc.getOrStart(); + String source = "package main\n\nfunc main() {\n}\n"; + SourceFile cu = GolangParser.builder().build() + .parse(source).findFirst().orElseThrow(); + + Markers markers = cu.getMarkers() + .addIfAbsent(new GoProject(UUID.randomUUID(), "example/foo")) + .addIfAbsent(new GoResolutionResult( + UUID.randomUUID(), + "example.com/foo", + "1.22", + null, + "go.mod", + Collections.singletonList( + new GoResolutionResult.Require("github.com/google/uuid", "v1.6.0", false) + ), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList() + )); + cu = cu.withMarkers(markers); + + String printed = rpc.print(cu); + assertThat(printed).isEqualTo(source); + } + + @Test + @SuppressWarnings("unused") + void roundTripPreservesGoResolutionResultFieldsViaVisit() { + // Exercises the full Java → Go → Java path: Visit RPC sends the LST + // (with attached markers) to Go, runs a no-op recipe, and ships the + // result back. Markers on the input must be preserved. + GoRewriteRpc rpc = GoRewriteRpc.getOrStart(); + String source = "package main\n\nfunc f() {\n\tvar x = true\n\t_ = x\n}\n"; + SourceFile cu = GolangParser.builder().build() + .parse(source).findFirst().orElseThrow(); + + UUID projectId = UUID.randomUUID(); + UUID gomodId = UUID.randomUUID(); + cu = cu.withMarkers(cu.getMarkers() + .addIfAbsent(new GoProject(projectId, "example/foo")) + .addIfAbsent(new GoResolutionResult( + gomodId, "example.com/foo", "1.22", null, "go.mod", + Collections.singletonList( + new GoResolutionResult.Require("github.com/google/uuid", "v1.6.0", false)), + Collections.emptyList(), Collections.emptyList(), + Collections.emptyList(), Collections.emptyList()))); + + var recipe = rpc.prepareRecipe("org.openrewrite.golang.test.RenameXToFlag"); + Tree result = recipe.getVisitor().visit(cu, new org.openrewrite.InMemoryExecutionContext()); + assertThat(result).isInstanceOf(SourceFile.class); + + Markers resultMarkers = ((SourceFile) result).getMarkers(); + GoProject project = resultMarkers.findFirst(GoProject.class).orElseThrow( + () -> new AssertionError("GoProject marker missing from round-trip result")); + assertThat(project.getId()).isEqualTo(projectId); + assertThat(project.getProjectName()).isEqualTo("example/foo"); + + GoResolutionResult mrr = resultMarkers.findFirst(GoResolutionResult.class).orElseThrow( + () -> new AssertionError("GoResolutionResult marker missing from round-trip result")); + assertThat(mrr.getId()).isEqualTo(gomodId); + assertThat(mrr.getModulePath()).isEqualTo("example.com/foo"); + assertThat(mrr.getRequires()).hasSize(1); + assertThat(mrr.getRequires().get(0).getModulePath()).isEqualTo("github.com/google/uuid"); + } +} diff --git a/rewrite-go/src/integTest/java/org/openrewrite/golang/rpc/ParseProjectModuleContextTest.java b/rewrite-go/src/integTest/java/org/openrewrite/golang/rpc/ParseProjectModuleContextTest.java new file mode 100644 index 00000000000..cc75671f056 --- /dev/null +++ b/rewrite-go/src/integTest/java/org/openrewrite/golang/rpc/ParseProjectModuleContextTest.java @@ -0,0 +1,181 @@ +/* + * Copyright 2026 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.golang.rpc; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.io.TempDir; +import org.openrewrite.InMemoryExecutionContext; +import org.openrewrite.SourceFile; +import org.openrewrite.golang.marker.GoResolutionResult; +import org.openrewrite.golang.tree.Go; + +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * End-to-end test for {@code handleParseProject}'s module-context + * inference (item 8 of the rewrite-go parity plan). + *

+ * Each .go file in the discovered project tree must resolve against its + * closest-ancestor go.mod, not the project root's. The owning + * {@link GoResolutionResult} marker is attached to each compilation unit + * so Java-side recipes can read module dependencies without re-parsing + * go.mod themselves. + */ +@Timeout(value = 120, unit = TimeUnit.SECONDS) +class ParseProjectModuleContextTest { + + @TempDir + Path tempDir; + + @TempDir + Path projectDir; + + @BeforeEach + void before() { + Path binaryPath = Paths.get("build/rewrite-go-rpc").toAbsolutePath(); + GoRewriteRpc.setFactory(GoRewriteRpc.builder() + .goBinaryPath(binaryPath) + .log(tempDir.resolve("go-rpc.log")) + .traceRpcMessages()); + } + + @AfterEach + void after() { + GoRewriteRpc.shutdownCurrent(); + } + + @Test + void singleModuleAttachesResolutionResult() throws Exception { + // Single root go.mod + two .go files in different packages. + write(projectDir.resolve("go.mod"), """ + module example.com/foo + + go 1.22 + + require github.com/google/uuid v1.6.0 + """); + write(projectDir.resolve("main.go"), """ + package main + + func main() {} + """); + write(projectDir.resolve("sub/sub.go"), """ + package sub + + func Hello() string { return "hi" } + """); + + GoRewriteRpc rpc = GoRewriteRpc.getOrStart(); + List sources = rpc.parseProject(projectDir, new InMemoryExecutionContext()).collect(Collectors.toList()); + + // Both .go files should be parsed and carry the same GoResolutionResult. + List cus = sources.stream() + .filter(s -> s instanceof Go.CompilationUnit) + .map(s -> (Go.CompilationUnit) s) + .collect(Collectors.toList()); + assertThat(cus).as("expected 2 .go compilation units").hasSize(2); + + for (Go.CompilationUnit cu : cus) { + GoResolutionResult mrr = cu.getMarkers().findFirst(GoResolutionResult.class).orElseThrow( + () -> new AssertionError("missing GoResolutionResult on " + cu.getSourcePath())); + assertThat(mrr.getModulePath()).isEqualTo("example.com/foo"); + assertThat(mrr.getGoVersion()).isEqualTo("1.22"); + assertThat(mrr.getRequires()) + .extracting(GoResolutionResult.Require::getModulePath) + .contains("github.com/google/uuid"); + } + } + + @Test + void nestedSubmoduleResolvesAgainstClosestAncestor() throws Exception { + // Root go.mod + nested submodule with its own go.mod. Each .go file + // must resolve against its closest-ancestor go.mod, not the root. + write(projectDir.resolve("go.mod"), """ + module example.com/root + + go 1.22 + """); + write(projectDir.resolve("main.go"), """ + package main + + func main() {} + """); + write(projectDir.resolve("nested/go.mod"), """ + module example.com/nested + + go 1.22 + """); + write(projectDir.resolve("nested/lib.go"), """ + package nested + + func Lib() string { return "nested" } + """); + + GoRewriteRpc rpc = GoRewriteRpc.getOrStart(); + List sources = rpc.parseProject(projectDir, new InMemoryExecutionContext()).collect(Collectors.toList()); + + Go.CompilationUnit rootCu = findBySuffix(sources, "main.go"); + Go.CompilationUnit nestedCu = findBySuffix(sources, "nested/lib.go"); + + assertThat(rootCu.getMarkers().findFirst(GoResolutionResult.class).orElseThrow().getModulePath()) + .isEqualTo("example.com/root"); + assertThat(nestedCu.getMarkers().findFirst(GoResolutionResult.class).orElseThrow().getModulePath()) + .isEqualTo("example.com/nested"); + } + + @Test + void noGoModLeavesCompilationUnitsUnattributed() throws Exception { + // Without any go.mod the project still parses, but no module + // resolution marker is attached. + write(projectDir.resolve("main.go"), """ + package main + + func main() {} + """); + + GoRewriteRpc rpc = GoRewriteRpc.getOrStart(); + List sources = rpc.parseProject(projectDir, new InMemoryExecutionContext()).collect(Collectors.toList()); + + Go.CompilationUnit cu = findBySuffix(sources, "main.go"); + assertThat(cu.getMarkers().findFirst(GoResolutionResult.class)).isEmpty(); + } + + private static Go.CompilationUnit findBySuffix(List sources, String suffix) { + return sources.stream() + .filter(s -> s instanceof Go.CompilationUnit) + .map(s -> (Go.CompilationUnit) s) + .filter(cu -> cu.getSourcePath().toString().replace('\\', '/').endsWith(suffix)) + .findFirst() + .orElseThrow(() -> new AssertionError("no source ending with " + suffix + + "; got " + sources.stream().map(SourceFile::getSourcePath).collect(Collectors.toList()))); + } + + private static void write(Path path, String content) throws java.io.IOException { + Files.createDirectories(path.getParent()); + Files.write(path, content.getBytes(StandardCharsets.UTF_8)); + } +} diff --git a/rewrite-go/src/main/java/org/openrewrite/golang/Assertions.java b/rewrite-go/src/main/java/org/openrewrite/golang/Assertions.java index 0a5e99882c6..3d88ee54273 100644 --- a/rewrite-go/src/main/java/org/openrewrite/golang/Assertions.java +++ b/rewrite-go/src/main/java/org/openrewrite/golang/Assertions.java @@ -16,7 +16,13 @@ package org.openrewrite.golang; import org.jspecify.annotations.Nullable; +import org.openrewrite.SourceFile; +import org.openrewrite.Tree; +import org.openrewrite.golang.marker.GoProject; import org.openrewrite.golang.tree.Go; +import org.openrewrite.java.JavaIsoVisitor; +import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.JavaType; import org.openrewrite.test.SourceSpec; import org.openrewrite.test.SourceSpecs; import org.openrewrite.text.PlainText; @@ -74,4 +80,200 @@ public static SourceSpecs goMod(@Nullable String before, String after, spec.accept(goMod); return goMod; } + + /** + * Wrap go.sum content as a sibling SourceSpec. When placed inside a + * {@link #goProject(String, SourceSpecs...)} alongside a + * {@link #goMod(String)}, the parser reads the sibling go.sum off disk + * during parse (via {@link GoModParser#parseSumContent}) and populates + * {@code GoResolutionResult.resolvedDependencies}. + *

+ * Today there is no dedicated parser for go.sum — content round-trips as + * a {@link PlainText}. The marker side-effect happens during go.mod + * parsing. + */ + public static SourceSpecs goSum(@Nullable String before) { + return goSum(before, s -> { + }); + } + + public static SourceSpecs goSum(@Nullable String before, Consumer> spec) { + SourceSpec

goSum = new SourceSpec<>(PlainText.class, null, + org.openrewrite.text.PlainTextParser.builder(), before, null); + goSum.path("go.sum"); + spec.accept(goSum); + return goSum; + } + + public static SourceSpecs goSum(@Nullable String before, String after) { + return goSum(before, after, s -> { + }); + } + + public static SourceSpecs goSum(@Nullable String before, String after, + Consumer<SourceSpec<PlainText>> spec) { + SourceSpec<PlainText> goSum = new SourceSpec<>(PlainText.class, null, + org.openrewrite.text.PlainTextParser.builder(), before, s -> after); + goSum.path("go.sum"); + spec.accept(goSum); + return goSum; + } + + /** + * Wrap sibling sources in a Go project directory and tag each with a + * {@link GoProject} marker. Mirrors {@code mavenProject(name, sources...)}: + * the {@code go.mod} (via {@link #goMod(String)}) and {@code .go} files + * (via {@link #go(String)}) are SIBLINGS inside the project directory, + * not nested inside one another. Recipes that need module-level + * dependency information look up the sibling go.mod's + * {@link org.openrewrite.golang.marker.GoResolutionResult}. + */ + public static SourceSpecs goProject(String project, SourceSpecs... sources) { + return goProject(project, spec -> project(spec, project), sources); + } + + public static SourceSpecs goProject(String project, Consumer<SourceSpec<SourceFile>> spec, SourceSpecs... sources) { + return SourceSpecs.dir(project, spec, sources); + } + + /** + * Tag a single SourceSpec with a {@link GoProject} marker. Used by the + * {@code goProject(name, ...)} consumer to apply the marker to every + * child source. + */ + public static SourceSpec<?> project(SourceSpec<?> sourceSpec, String projectName) { + return sourceSpec.markers(new GoProject(Tree.randomId(), projectName)); + } + + /** + * Walk {@code root} and assert that the first {@link J.Identifier} whose + * {@code simpleName} equals {@code name} carries a non-null + * {@link JavaType.FullyQualified} type whose fully-qualified name equals + * {@code expectedFqn}. Use this for class/struct/parameterized types; for + * primitives use {@link #expectPrimitiveType(Tree, String, String)}. + * + * @throws AssertionError when no such identifier exists, its type is + * null, or the type is not fully qualified. + */ + public static void expectType(Tree root, String name, String expectedFqn) { + IdentifierTypeFinder finder = new IdentifierTypeFinder(name); + finder.visit(root, 0); + if (!finder.found) { + throw new AssertionError("expectType(\"" + name + "\"): no identifier with that name in tree"); + } + if (finder.type == null) { + throw new AssertionError("expectType(\"" + name + "\"): identifier has null type"); + } + if (!(finder.type instanceof JavaType.FullyQualified)) { + throw new AssertionError("expectType(\"" + name + "\"): identifier type is " + + finder.type.getClass().getSimpleName() + ", want FullyQualified"); + } + String got = ((JavaType.FullyQualified) finder.type).getFullyQualifiedName(); + if (!got.equals(expectedFqn)) { + throw new AssertionError("expectType(\"" + name + "\"): FQN = \"" + got + "\", want \"" + expectedFqn + "\""); + } + } + + /** + * Walk {@code root} and assert that the first {@link J.Identifier} whose + * {@code simpleName} equals {@code name} carries a {@link JavaType.Primitive} + * whose keyword equals {@code expectedKeyword} (e.g. {@code "int"}, + * {@code "String"}, {@code "boolean"}). + */ + public static void expectPrimitiveType(Tree root, String name, String expectedKeyword) { + IdentifierTypeFinder finder = new IdentifierTypeFinder(name); + finder.visit(root, 0); + if (!finder.found) { + throw new AssertionError("expectPrimitiveType(\"" + name + "\"): no identifier with that name in tree"); + } + if (finder.type == null) { + throw new AssertionError("expectPrimitiveType(\"" + name + "\"): identifier has null type"); + } + if (!(finder.type instanceof JavaType.Primitive)) { + throw new AssertionError("expectPrimitiveType(\"" + name + "\"): identifier type is " + + finder.type.getClass().getSimpleName() + ", want Primitive"); + } + String got = ((JavaType.Primitive) finder.type).getKeyword(); + if (!got.equals(expectedKeyword)) { + throw new AssertionError("expectPrimitiveType(\"" + name + "\"): keyword = \"" + got + "\", want \"" + expectedKeyword + "\""); + } + } + + /** + * Walk {@code root} and assert that the first + * {@link J.MethodInvocation} or {@link J.MethodDeclaration} whose name + * equals {@code name} carries a non-null {@link JavaType.Method} whose + * {@code declaringType} fully-qualified name equals + * {@code expectedDeclaringFqn}. + * + * <p>For invocations across packages, {@code expectedDeclaringFqn} is the + * import path of the owning package (e.g. {@code "fmt"} for + * {@code fmt.Println}). For methods declared in the file under test it is + * the package's full path (e.g. {@code "main.Point"}). + */ + public static void expectMethodType(Tree root, String name, String expectedDeclaringFqn) { + MethodTypeFinder finder = new MethodTypeFinder(name); + finder.visit(root, 0); + if (!finder.found) { + throw new AssertionError("expectMethodType(\"" + name + "\"): no method with that name in tree"); + } + if (finder.methodType == null) { + throw new AssertionError("expectMethodType(\"" + name + "\"): method has null methodType"); + } + if (finder.methodType.getDeclaringType() == null) { + throw new AssertionError("expectMethodType(\"" + name + "\"): method has null declaringType"); + } + String got = finder.methodType.getDeclaringType().getFullyQualifiedName(); + if (!got.equals(expectedDeclaringFqn)) { + throw new AssertionError("expectMethodType(\"" + name + "\"): declaring FQN = \"" + got + + "\", want \"" + expectedDeclaringFqn + "\""); + } + } + + private static final class IdentifierTypeFinder extends JavaIsoVisitor<Integer> { + private final String name; + boolean found; + @Nullable JavaType type; + + IdentifierTypeFinder(String name) { + this.name = name; + } + + @Override + public J.Identifier visitIdentifier(J.Identifier identifier, Integer p) { + if (!found && name.equals(identifier.getSimpleName())) { + found = true; + type = identifier.getType(); + } + return identifier; + } + } + + private static final class MethodTypeFinder extends JavaIsoVisitor<Integer> { + private final String name; + boolean found; + JavaType.@Nullable Method methodType; + + MethodTypeFinder(String name) { + this.name = name; + } + + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Integer p) { + if (!found && name.equals(method.getSimpleName())) { + found = true; + methodType = method.getMethodType(); + } + return super.visitMethodInvocation(method, p); + } + + @Override + public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Integer p) { + if (!found && name.equals(method.getSimpleName())) { + found = true; + methodType = method.getMethodType(); + } + return super.visitMethodDeclaration(method, p); + } + } } diff --git a/rewrite-go/src/main/java/org/openrewrite/golang/GoModParser.java b/rewrite-go/src/main/java/org/openrewrite/golang/GoModParser.java index 71b189debee..9c737c1e0a5 100644 --- a/rewrite-go/src/main/java/org/openrewrite/golang/GoModParser.java +++ b/rewrite-go/src/main/java/org/openrewrite/golang/GoModParser.java @@ -253,40 +253,63 @@ private static void parseBlockEntry(BlockState block, String rawLine, String lin } private static List<ResolvedDependency> parseSumSibling(Path goModPath) { - List<ResolvedDependency> resolved = new ArrayList<>(); Path sumPath = goModPath.resolveSibling("go.sum"); java.io.File sumFile = sumPath.toFile(); if (!sumFile.isFile()) { + return new ArrayList<>(); + } + try { + String content = new String(java.nio.file.Files.readAllBytes(sumPath), java.nio.charset.StandardCharsets.UTF_8); + return parseSumContent(content); + } catch (java.io.IOException ignored) { + // go.sum read failures are non-fatal — return empty list + return new ArrayList<>(); + } + } + + /** + * Parse go.sum content (string) into the same shape as + * {@link #parseSumSibling(Path)}. Mirrors the Go-side + * {@code parser.ParseGoSum} for cross-language parity. + * <p> + * Malformed lines are logged and skipped — go.sum is best-effort + * metadata, not authoritative; one bad line shouldn't tank a parse. + */ + public static List<ResolvedDependency> parseSumContent(@Nullable String content) { + List<ResolvedDependency> resolved = new ArrayList<>(); + if (content == null || content.isEmpty()) { return resolved; } - try (java.io.BufferedReader reader = new java.io.BufferedReader(new java.io.FileReader(sumFile))) { - // go.sum format: "<module> <version>[/go.mod] h1:<hash>" - // Each module version appears on two lines: one for the module zip, one for its go.mod file. - java.util.Map<String, String[]> byKey = new java.util.LinkedHashMap<>(); - String line; - while ((line = reader.readLine()) != null) { - Matcher m = GO_SUM_LINE.matcher(line); - if (!m.matches()) { - continue; - } - String module = m.group(1); - String version = m.group(2); - boolean isGoMod = m.group(3) != null; - String hash = m.group(4); - String key = module + "@" + version; - String[] slot = byKey.computeIfAbsent(key, k -> new String[2]); - if (isGoMod) { - slot[1] = "h1:" + hash; - } else { - slot[0] = "h1:" + hash; - } + // go.sum format: "<module> <version>[/go.mod] h1:<hash>" + // Each module version appears on two lines: one for the module zip, one for its go.mod file. + java.util.Map<String, String[]> byKey = new java.util.LinkedHashMap<>(); + String[] lines = content.split("\\r?\\n", -1); + for (int i = 0; i < lines.length; i++) { + String line = lines[i]; + if (line.trim().isEmpty()) { + continue; } - for (java.util.Map.Entry<String, String[]> e : byKey.entrySet()) { - String[] parts = e.getKey().split("@", 2); - resolved.add(new ResolvedDependency(parts[0], parts[1], e.getValue()[0], e.getValue()[1])); + Matcher m = GO_SUM_LINE.matcher(line); + if (!m.matches()) { + java.util.logging.Logger.getLogger(GoModParser.class.getName()) + .fine("go.sum line " + (i + 1) + ": skipping malformed entry: " + line); + continue; } - } catch (java.io.IOException ignored) { - // go.sum read failures are non-fatal — return whatever we collected + String module = m.group(1); + String version = m.group(2); + boolean isGoMod = m.group(3) != null; + String hash = m.group(4); + String key = module + "@" + version; + String[] slot = byKey.computeIfAbsent(key, k -> new String[2]); + if (isGoMod) { + slot[1] = "h1:" + hash; + } else { + slot[0] = "h1:" + hash; + } + } + for (java.util.Map.Entry<String, String[]> e : byKey.entrySet()) { + String[] parts = e.getKey().split("@", 2); + resolved.add(new ResolvedDependency(parts[0], parts[1], e.getValue()[0], e.getValue()[1])); } return resolved; } diff --git a/rewrite-go/src/main/java/org/openrewrite/golang/GolangParser.java b/rewrite-go/src/main/java/org/openrewrite/golang/GolangParser.java index 325cf7f5c12..f94e532d0b9 100644 --- a/rewrite-go/src/main/java/org/openrewrite/golang/GolangParser.java +++ b/rewrite-go/src/main/java/org/openrewrite/golang/GolangParser.java @@ -31,12 +31,36 @@ * This parser uses RPC to communicate with a Go process that performs the actual parsing. * The Go process uses Go's standard library parser to parse source code and converts it to * the OpenRewrite LST format. + * <p> + * When constructed with module + go.mod context (via {@link Builder#module(String)} and + * {@link Builder#goMod(String)}), the parser routes batches through + * {@link GoRewriteRpc#parseWithProject} so the Go server builds a ProjectImporter + * for type attribution: intra-project imports resolve against sibling sources, and + * imports of go.mod-declared third-party modules resolve to stub + * {@code *types.Package} objects. */ public class GolangParser implements Parser { + private final @Nullable String module; + private final @Nullable String goModContent; + + GolangParser(@Nullable String module, @Nullable String goModContent) { + this.module = module; + this.goModContent = goModContent; + } + + public GolangParser() { + this(null, null); + } + @Override public Stream<SourceFile> parseInputs(Iterable<Input> sources, @Nullable Path relativeTo, ExecutionContext ctx) { - return GoRewriteRpc.getOrStart().parse(sources, relativeTo, this, + GoRewriteRpc rpc = GoRewriteRpc.getOrStart(); + if (module != null && !module.isEmpty()) { + return rpc.parseWithProject(sources, relativeTo, this, + Go.CompilationUnit.class.getName(), ctx, module, goModContent); + } + return rpc.parse(sources, relativeTo, this, Go.CompilationUnit.class.getName(), ctx); } @@ -56,13 +80,35 @@ public static Builder builder() { public static class Builder extends Parser.Builder { + private @Nullable String module; + private @Nullable String goModContent; + public Builder() { super(Go.CompilationUnit.class); } + /** + * Set the Go module path (e.g. {@code example.com/foo}) so the + * parser asks the Go server to build a project-aware Importer. + */ + public Builder module(@Nullable String module) { + this.module = module; + return this; + } + + /** + * Set the raw go.mod content. Used by the Go server to register + * {@code require} directives so imports of those modules resolve + * to stub packages even when their sources aren't present. + */ + public Builder goMod(@Nullable String goModContent) { + this.goModContent = goModContent; + return this; + } + @Override public GolangParser build() { - return new GolangParser(); + return new GolangParser(module, goModContent); } @Override diff --git a/rewrite-go/src/main/java/org/openrewrite/golang/GolangVisitor.java b/rewrite-go/src/main/java/org/openrewrite/golang/GolangVisitor.java index fbea58e77ca..c5689b2b352 100644 --- a/rewrite-go/src/main/java/org/openrewrite/golang/GolangVisitor.java +++ b/rewrite-go/src/main/java/org/openrewrite/golang/GolangVisitor.java @@ -223,6 +223,7 @@ public J visitTypeDecl(Go.TypeDecl typeDecl, P p) { Go.TypeDecl t = typeDecl; t = t.withPrefix(visitSpace(t.getPrefix(), Space.Location.LANGUAGE_EXTENSION, p)); t = t.withMarkers(visitMarkers(t.getMarkers(), p)); + t = t.withLeadingAnnotations(ListUtils.map(t.getLeadingAnnotations(), a -> visitAndCast(a, p))); t = t.withName((J.Identifier) visitAndCast(t.getName(), p)); if (t.getDefinition() != null) { t = t.withDefinition((Expression) visitAndCast(t.getDefinition(), p)); diff --git a/rewrite-go/src/main/java/org/openrewrite/golang/internal/rpc/GolangReceiver.java b/rewrite-go/src/main/java/org/openrewrite/golang/internal/rpc/GolangReceiver.java index 5369773cf75..56313f5556a 100644 --- a/rewrite-go/src/main/java/org/openrewrite/golang/internal/rpc/GolangReceiver.java +++ b/rewrite-go/src/main/java/org/openrewrite/golang/internal/rpc/GolangReceiver.java @@ -173,6 +173,7 @@ public J visitTypeList(Go.TypeList typeList, RpcReceiveQueue q) { @Override public J visitTypeDecl(Go.TypeDecl typeDecl, RpcReceiveQueue q) { return typeDecl + .withLeadingAnnotations(q.receiveList(typeDecl.getLeadingAnnotations(), a -> (J.Annotation) visitNonNull(a, q))) .withName(q.receive(typeDecl.getName(), el -> (J.Identifier) visitNonNull(el, q))) .getPadding().withAssign(q.receive(typeDecl.getPadding().getAssign(), el -> visitLeftPadded(el, q))) .withDefinition(q.receive(typeDecl.getDefinition(), el -> (Expression) visitNonNull(el, q))) diff --git a/rewrite-go/src/main/java/org/openrewrite/golang/internal/rpc/GolangSender.java b/rewrite-go/src/main/java/org/openrewrite/golang/internal/rpc/GolangSender.java index 7eaea970e37..6dcc445525c 100644 --- a/rewrite-go/src/main/java/org/openrewrite/golang/internal/rpc/GolangSender.java +++ b/rewrite-go/src/main/java/org/openrewrite/golang/internal/rpc/GolangSender.java @@ -167,6 +167,7 @@ public J visitTypeList(Go.TypeList typeList, RpcSendQueue q) { @Override public J visitTypeDecl(Go.TypeDecl typeDecl, RpcSendQueue q) { + q.getAndSendList(typeDecl, Go.TypeDecl::getLeadingAnnotations, Tree::getId, a -> visit(a, q)); q.getAndSend(typeDecl, Go.TypeDecl::getName, el -> visit(el, q)); q.getAndSend(typeDecl, t -> t.getPadding().getAssign(), el -> visitLeftPadded(el, q)); q.getAndSend(typeDecl, Go.TypeDecl::getDefinition, el -> visit(el, q)); diff --git a/rewrite-go/src/main/java/org/openrewrite/golang/marker/GoProject.java b/rewrite-go/src/main/java/org/openrewrite/golang/marker/GoProject.java new file mode 100644 index 00000000000..e7f9bfab501 --- /dev/null +++ b/rewrite-go/src/main/java/org/openrewrite/golang/marker/GoProject.java @@ -0,0 +1,57 @@ +/* + * Copyright 2026 the original author or authors. + * <p> + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * <p> + * https://docs.moderne.io/licensing/moderne-source-available-license + * <p> + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.golang.marker; + +import lombok.EqualsAndHashCode; +import lombok.Value; +import lombok.With; +import org.openrewrite.marker.Marker; +import org.openrewrite.rpc.RpcCodec; +import org.openrewrite.rpc.RpcReceiveQueue; +import org.openrewrite.rpc.RpcSendQueue; + +import java.util.UUID; + +/** + * Identifies the Go project a source file belongs to. Mirrors + * {@link org.openrewrite.java.marker.JavaProject}. + * <p> + * Recipes that need module-level dependency information look up the + * sibling {@code go.mod} source by path and read its + * {@link GoResolutionResult} marker — the same connector-by-path pattern + * that Maven uses for {@code pom.xml} beside {@code src/main/java}. + */ +@Value +@With +public class GoProject implements Marker, RpcCodec<GoProject> { + @EqualsAndHashCode.Exclude + UUID id; + + String projectName; + + @Override + public void rpcSend(GoProject after, RpcSendQueue q) { + q.getAndSend(after, Marker::getId); + q.getAndSend(after, GoProject::getProjectName); + } + + @Override + public GoProject rpcReceive(GoProject before, RpcReceiveQueue q) { + return before + .withId(q.receiveAndGet(before.getId(), UUID::fromString)) + .withProjectName(q.receive(before.getProjectName())); + } +} diff --git a/rewrite-go/src/main/java/org/openrewrite/golang/rpc/GoParseRequest.java b/rewrite-go/src/main/java/org/openrewrite/golang/rpc/GoParseRequest.java new file mode 100644 index 00000000000..f53365863a8 --- /dev/null +++ b/rewrite-go/src/main/java/org/openrewrite/golang/rpc/GoParseRequest.java @@ -0,0 +1,55 @@ +/* + * Copyright 2026 the original author or authors. + * <p> + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * <p> + * https://docs.moderne.io/licensing/moderne-source-available-license + * <p> + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.golang.rpc; + +import lombok.Value; +import org.jspecify.annotations.Nullable; +import org.openrewrite.rpc.request.Parse; +import org.openrewrite.rpc.request.RpcRequest; + +import java.util.List; + +/** + * Go-specific Parse RPC payload. Mirrors the rewrite-core {@code Parse} + * shape but adds optional {@code module} and {@code goModContent} fields + * the Go server uses to build a project-aware {@code ProjectImporter} + * for type attribution. Other languages don't need these fields and the + * Go server ignores them when absent. + */ +@Value +public class GoParseRequest implements RpcRequest { + List<Parse.Input> inputs; + + @Nullable + String relativeTo; + + /** + * Module path declared by the project's go.mod, e.g. {@code example.com/foo}. + * When set, the Go server constructs a ProjectImporter and uses it for + * type attribution; the inputs in this batch are treated as siblings + * of that module. + */ + @Nullable + String module; + + /** + * Raw go.mod content. The Go server parses it for {@code require} + * directives, registering each as a known module path so imports of + * those modules resolve to stub {@code *types.Package} objects. + */ + @Nullable + String goModContent; +} diff --git a/rewrite-go/src/main/java/org/openrewrite/golang/rpc/GoRewriteRpc.java b/rewrite-go/src/main/java/org/openrewrite/golang/rpc/GoRewriteRpc.java index 32e9dabd4f9..2946a9cf218 100644 --- a/rewrite-go/src/main/java/org/openrewrite/golang/rpc/GoRewriteRpc.java +++ b/rewrite-go/src/main/java/org/openrewrite/golang/rpc/GoRewriteRpc.java @@ -17,11 +17,20 @@ import lombok.Getter; import org.jspecify.annotations.Nullable; +import org.openrewrite.ExecutionContext; +import org.openrewrite.Parser; +import org.openrewrite.SourceFile; +import org.openrewrite.golang.GolangParser; import org.openrewrite.marketplace.RecipeBundleResolver; import org.openrewrite.marketplace.RecipeMarketplace; import org.openrewrite.rpc.RewriteRpc; import org.openrewrite.rpc.RewriteRpcProcess; import org.openrewrite.rpc.RewriteRpcProcessManager; +import org.openrewrite.rpc.request.Parse; +import org.openrewrite.rpc.request.ParseResponse; +import org.openrewrite.tree.ParseError; +import org.openrewrite.tree.ParsingEventListener; +import org.openrewrite.tree.ParsingExecutionContextView; import java.io.File; import java.io.IOException; @@ -29,13 +38,19 @@ import java.io.UncheckedIOException; import java.nio.file.Files; import java.nio.file.Path; +import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.time.Duration; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.Spliterator; +import java.util.function.Consumer; import java.util.function.Supplier; import java.util.stream.Stream; +import java.util.stream.StreamSupport; /** * RPC client that communicates with a Go process for parsing and printing Go source code. @@ -46,11 +61,14 @@ public class GoRewriteRpc extends RewriteRpc { private static final RewriteRpcProcessManager<GoRewriteRpc> MANAGER = new RewriteRpcProcessManager<>(builder()); private final String command; + private final Map<String, String> commandEnv; private final RewriteRpcProcess process; - GoRewriteRpc(RewriteRpcProcess process, RecipeMarketplace marketplace, List<RecipeBundleResolver> resolvers, String command) { + GoRewriteRpc(RewriteRpcProcess process, RecipeMarketplace marketplace, List<RecipeBundleResolver> resolvers, + String command, Map<String, String> commandEnv) { super(process.getRpcClient(), marketplace, resolvers); this.command = command; + this.commandEnv = commandEnv; this.process = process; } @@ -76,6 +94,72 @@ public static void shutdownCurrent() { MANAGER.shutdown(); } + public static void resetCurrent() { + MANAGER.reset(); + } + + /** + * Parse a batch of Go source inputs with project (module) context. + * The Go server constructs a {@code ProjectImporter} from the module + * path + go.mod content, registers every input as a sibling, and uses + * it for type attribution. Files in the same directory are parsed + * together so cross-file references inside a package resolve. + * <p> + * Use the regular {@link #parse} path when no project context is + * available — that falls back to per-file, stdlib-only type + * attribution (today's behavior). + */ + public java.util.stream.Stream<SourceFile> parseWithProject( + Iterable<Parser.Input> inputs, + @Nullable Path relativeTo, + Parser parser, + String sourceFileType, + ExecutionContext ctx, + String module, + @Nullable String goModContent) { + java.util.List<Parser.Input> inputList = new java.util.ArrayList<>(); + java.util.List<Parse.Input> mappedInputs = new java.util.ArrayList<>(); + for (Parser.Input input : inputs) { + inputList.add(input); + if (input.isSynthetic() || !java.nio.file.Files.isRegularFile(input.getPath())) { + mappedInputs.add(new Parse.Input(input.getSource(ctx).readFully(), input.getPath())); + } else { + mappedInputs.add(new Parse.Input(null, input.getPath())); + } + } + if (inputList.isEmpty()) { + return java.util.stream.Stream.empty(); + } + + ParsingEventListener parsingListener = ParsingExecutionContextView.view(ctx).getParsingListener(); + parsingListener.intermediateMessage(String.format("Starting parsing of %,d files (module=%s)", inputList.size(), module)); + + java.util.List<String> ids = send("Parse", new GoParseRequest( + mappedInputs, + relativeTo != null ? relativeTo.toString() : null, + module, + goModContent + ), ParseResponse.class); + if (ids.size() != inputList.size()) { + throw new IllegalStateException("Parse response size " + ids.size() + " != input size " + inputList.size()); + } + + java.util.List<SourceFile> result = new java.util.ArrayList<>(ids.size()); + for (int i = 0; i < ids.size(); i++) { + Parser.Input input = inputList.get(i); + parsingListener.startedParsing(input); + SourceFile sf; + try { + sf = getObject(ids.get(i), sourceFileType); + } catch (Exception e) { + sf = ParseError.build(parser, input, relativeTo, ctx, e); + } + result.add(sf); + parsingListener.parsed(input, sf); + } + return result.stream(); + } + /** * Install recipes from a local file path (e.g., a local Go module). * @@ -116,6 +200,111 @@ public InstallRecipesResponse installRecipes(String packageName, @Nullable Strin ); } + /** + * Parses an entire Go project directory. + * + * @param projectPath Path to the project directory to parse + * @param ctx Execution context for parsing + * @return Stream of parsed source files + */ + public Stream<SourceFile> parseProject(Path projectPath, ExecutionContext ctx) { + return parseProject(projectPath, null, null, ctx); + } + + /** + * Parses an entire Go project directory. + * + * @param projectPath Path to the project directory to parse + * @param exclusions Optional glob patterns to exclude from parsing + * @param ctx Execution context for parsing + * @return Stream of parsed source files + */ + public Stream<SourceFile> parseProject(Path projectPath, @Nullable List<String> exclusions, ExecutionContext ctx) { + return parseProject(projectPath, exclusions, null, ctx); + } + + /** + * Parses an entire Go project directory. + * + * @param projectPath Path to the project directory to parse + * @param exclusions Optional glob patterns to exclude from parsing + * @param relativeTo Optional path to make source file paths relative to. If not specified, + * paths are relative to projectPath. Use this when parsing a subdirectory + * but wanting paths relative to the repository root. + * @param ctx Execution context for parsing + * @return Stream of parsed source files + */ + public Stream<SourceFile> parseProject(Path projectPath, @Nullable List<String> exclusions, @Nullable Path relativeTo, ExecutionContext ctx) { + ParsingEventListener parsingListener = ParsingExecutionContextView.view(ctx).getParsingListener(); + + return StreamSupport.stream(new Spliterator<SourceFile>() { + private int index = 0; + private @Nullable ParseProjectResponse response; + + @Override + public boolean tryAdvance(Consumer<? super SourceFile> action) { + if (response == null) { + parsingListener.intermediateMessage("Starting project parsing: " + projectPath); + response = send("ParseProject", new ParseProject(projectPath, exclusions, relativeTo), ParseProjectResponse.class); + parsingListener.intermediateMessage(String.format("Discovered %,d files to parse", response.size())); + } + + if (index >= response.size()) { + return false; + } + + ParseProjectResponse.Item item = response.get(index); + index++; + + SourceFile sourceFile; + try { + sourceFile = getObject(item.getId(), item.getSourceFileType()); + parsingListener.startedParsing(Parser.Input.fromFile(sourceFile.getSourcePath())); + } catch (Exception e) { + // A single file's RPC deserialization failed. Convert it to a ParseError + // pointing at the offending file and keep the stream going. + // + // `item.sourcePath` may be null when talking to an older Go peer — fall + // back to the RPC object id so we still produce a usable ParseError. + String relativePath = item.getSourcePath() != null ? item.getSourcePath() : item.getId(); + Path sourcePath = Paths.get(relativePath); + Path absoluteSourcePath = projectPath.resolve(sourcePath); + try { + sourceFile = ParseError.build( + GolangParser.builder().build(), + Parser.Input.fromFile(absoluteSourcePath), + projectPath, + ctx, + e); + } catch (Exception readFailure) { + // If the file can't be read (e.g. the fallback path from id doesn't + // exist on disk), wrap the original exception in a runtime so the + // stream continues. Without the source text the error is less useful + // but the alternative — aborting the whole stream — is worse. + throw new RuntimeException("ParseProject item " + item.getId() + " failed; readback also failed", e); + } + } + action.accept(sourceFile); + return true; + } + + @Override + public @Nullable Spliterator<SourceFile> trySplit() { + return null; + } + + @Override + public long estimateSize() { + return response == null ? Long.MAX_VALUE : response.size() - index; + } + + @Override + public int characteristics() { + return response == null ? ORDERED : ORDERED | SIZED | SUBSIZED; + } + }, false); + } + public static Builder builder() { return new Builder(); } @@ -123,9 +312,14 @@ public static Builder builder() { public static class Builder implements Supplier<GoRewriteRpc> { private RecipeMarketplace marketplace = new RecipeMarketplace(); private List<RecipeBundleResolver> resolvers = Collections.emptyList(); + private final Map<String, String> environment = new HashMap<>(); private Supplier<@Nullable Path> goBinaryPathSupplier = () -> null; private Duration timeout = Duration.ofSeconds(60); private @Nullable Path log; + private @Nullable Path metricsCsv; + private @Nullable Path recipeInstallDir; + private @Nullable Path dataTablesCsvDir; + private @Nullable Path workingDirectory; private boolean traceRpcMessages; public Builder marketplace(RecipeMarketplace marketplace) { @@ -166,11 +360,40 @@ public Builder log(@Nullable Path log) { return this; } - public Builder traceRpcMessages() { - this.traceRpcMessages = true; + public Builder metricsCsv(@Nullable Path metricsCsv) { + this.metricsCsv = metricsCsv; + return this; + } + + public Builder recipeInstallDir(@Nullable Path recipeInstallDir) { + this.recipeInstallDir = recipeInstallDir; return this; } + public Builder dataTablesCsvDir(@Nullable Path dataTablesCsvDir) { + this.dataTablesCsvDir = dataTablesCsvDir; + return this; + } + + public Builder environment(Map<String, String> environment) { + this.environment.putAll(environment); + return this; + } + + public Builder workingDirectory(@Nullable Path workingDirectory) { + this.workingDirectory = workingDirectory; + return this; + } + + public Builder traceRpcMessages(boolean verboseLogging) { + this.traceRpcMessages = verboseLogging; + return this; + } + + public Builder traceRpcMessages() { + return traceRpcMessages(true); + } + @Override public GoRewriteRpc get() { @Nullable Path goBinaryPath = goBinaryPathSupplier.get(); @@ -179,9 +402,9 @@ public GoRewriteRpc get() { binaryPath = goBinaryPath.toString(); } else { // Check for a custom binary with installed recipes - java.nio.file.Path customBin = java.nio.file.Paths.get( + Path customBin = Paths.get( System.getProperty("user.home"), ".rewrite", "go-recipes", "rewrite-go-rpc"); - if (java.nio.file.Files.isExecutable(customBin)) { + if (Files.isExecutable(customBin)) { binaryPath = customBin.toString(); } else { binaryPath = "rewrite-go-rpc"; @@ -191,15 +414,25 @@ public GoRewriteRpc get() { Stream<@Nullable String> cmd = Stream.of( binaryPath, log == null ? null : "--log-file=" + log.toAbsolutePath().normalize(), + metricsCsv == null ? null : "--metrics-csv=" + metricsCsv.toAbsolutePath().normalize(), + recipeInstallDir == null ? null : "--recipe-install-dir=" + recipeInstallDir.toAbsolutePath().normalize(), + dataTablesCsvDir == null ? null : "--data-tables-csv-dir=" + dataTablesCsvDir.toAbsolutePath().normalize(), traceRpcMessages ? "--trace-rpc-messages" : null ); String[] cmdArr = cmd.filter(Objects::nonNull).toArray(String[]::new); RewriteRpcProcess process = new RewriteRpcProcess(cmdArr); + + if (workingDirectory != null) { + process.setWorkingDirectory(workingDirectory); + } + process.setStderrRedirect(log); + process.environment().putAll(environment); process.start(); try { - return (GoRewriteRpc) new GoRewriteRpc(process, marketplace, resolvers, String.join(" ", cmdArr)) + return (GoRewriteRpc) new GoRewriteRpc(process, marketplace, resolvers, + String.join(" ", cmdArr), process.environment()) .livenessCheck(process::getLivenessCheck) .timeout(timeout) .log(log == null ? null : new PrintStream(Files.newOutputStream(log, StandardOpenOption.APPEND, StandardOpenOption.CREATE))); diff --git a/rewrite-go/src/main/java/org/openrewrite/golang/rpc/ParseProject.java b/rewrite-go/src/main/java/org/openrewrite/golang/rpc/ParseProject.java new file mode 100644 index 00000000000..bfe153db96a --- /dev/null +++ b/rewrite-go/src/main/java/org/openrewrite/golang/rpc/ParseProject.java @@ -0,0 +1,52 @@ +/* + * Copyright 2025 the original author or authors. + * <p> + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * <p> + * https://docs.moderne.io/licensing/moderne-source-available-license + * <p> + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.golang.rpc; + +import lombok.Value; +import org.jspecify.annotations.Nullable; +import org.openrewrite.rpc.request.RpcRequest; + +import java.nio.file.Path; +import java.util.List; + +/** + * RPC request to parse an entire Go project. + * Discovers and parses all .go files under the project directory. + */ +@Value +class ParseProject implements RpcRequest { + Path projectPath; + + @Nullable + List<String> exclusions; + + @Nullable + Path relativeTo; + + ParseProject(Path projectPath) { + this(projectPath, null, null); + } + + ParseProject(Path projectPath, @Nullable List<String> exclusions) { + this(projectPath, exclusions, null); + } + + ParseProject(Path projectPath, @Nullable List<String> exclusions, @Nullable Path relativeTo) { + this.projectPath = projectPath; + this.exclusions = exclusions; + this.relativeTo = relativeTo; + } +} diff --git a/rewrite-go/src/main/java/org/openrewrite/golang/rpc/ParseProjectResponse.java b/rewrite-go/src/main/java/org/openrewrite/golang/rpc/ParseProjectResponse.java new file mode 100644 index 00000000000..7b263932822 --- /dev/null +++ b/rewrite-go/src/main/java/org/openrewrite/golang/rpc/ParseProjectResponse.java @@ -0,0 +1,37 @@ +/* + * Copyright 2025 the original author or authors. + * <p> + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * <p> + * https://docs.moderne.io/licensing/moderne-source-available-license + * <p> + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.golang.rpc; + +import lombok.Value; +import org.jspecify.annotations.Nullable; + +import java.util.ArrayList; + +class ParseProjectResponse extends ArrayList<ParseProjectResponse.Item> { + + @Value + static class Item { + String id; + String sourceFileType; + + /** + * The relative source path of the file. May be null when talking to an + * older Go peer that doesn't populate it yet — callers fall back to + * {@link #id} for error reporting in that case. + */ + @Nullable String sourcePath; + } +} diff --git a/rewrite-go/src/main/java/org/openrewrite/golang/tree/Go.java b/rewrite-go/src/main/java/org/openrewrite/golang/tree/Go.java index 580aaf1144a..807a412e027 100644 --- a/rewrite-go/src/main/java/org/openrewrite/golang/tree/Go.java +++ b/rewrite-go/src/main/java/org/openrewrite/golang/tree/Go.java @@ -1411,6 +1411,10 @@ final class TypeDecl implements Go, Statement { @Getter Markers markers; + @With + @Getter + List<J.Annotation> leadingAnnotations; + @With @Getter J.Identifier name; @@ -1468,7 +1472,7 @@ public static class Padding { } public Go.TypeDecl withAssign(@Nullable JLeftPadded<Space> assign) { - return t.assign == assign ? t : new Go.TypeDecl(t.padding, t.id, t.prefix, t.markers, t.name, assign, t.definition, t.specs); + return t.assign == assign ? t : new Go.TypeDecl(t.padding, t.id, t.prefix, t.markers, t.leadingAnnotations, t.name, assign, t.definition, t.specs); } public @Nullable JContainer<Statement> getSpecs() { @@ -1476,7 +1480,7 @@ public Go.TypeDecl withAssign(@Nullable JLeftPadded<Space> assign) { } public Go.TypeDecl withSpecs(@Nullable JContainer<Statement> specs) { - return t.specs == specs ? t : new Go.TypeDecl(t.padding, t.id, t.prefix, t.markers, t.name, t.assign, t.definition, specs); + return t.specs == specs ? t : new Go.TypeDecl(t.padding, t.id, t.prefix, t.markers, t.leadingAnnotations, t.name, t.assign, t.definition, specs); } } } diff --git a/rewrite-go/src/test/java/org/openrewrite/golang/GoModConformanceTest.java b/rewrite-go/src/test/java/org/openrewrite/golang/GoModConformanceTest.java new file mode 100644 index 00000000000..724387ad892 --- /dev/null +++ b/rewrite-go/src/test/java/org/openrewrite/golang/GoModConformanceTest.java @@ -0,0 +1,196 @@ +/* + * Copyright 2026 the original author or authors. + * <p> + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * <p> + * https://docs.moderne.io/licensing/moderne-source-available-license + * <p> + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.golang; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.fasterxml.jackson.databind.SerializationFeature; +import org.junit.jupiter.api.DynamicTest; +import org.junit.jupiter.api.TestFactory; +import org.openrewrite.golang.marker.GoResolutionResult; +import org.openrewrite.text.PlainText; +import org.openrewrite.text.PlainTextParser; +import org.openrewrite.ExecutionContext; +import org.openrewrite.InMemoryExecutionContext; +import org.openrewrite.SourceFile; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Java-side driver for the shared go.mod / go.sum conformance corpus. + * <p> + * Iterates every {@code *.gomod} under + * {@code src/test/resources/gomod-conformance/}, parses it (with sibling + * {@code *.gosum} when present), and asserts the resulting marker matches + * the canonical JSON shape in the case's {@code *.gomod.json} golden. + * <p> + * The Go-side {@code TestGoModConformanceCorpus} runs the same corpus. + * Drift between the two indicates a parser parity bug. + */ +class GoModConformanceTest { + + private static final Path CORPUS_DIR = Paths.get("src/test/resources/gomod-conformance"); + private static final ObjectMapper MAPPER = new ObjectMapper() + .configure(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS, true); + + @TestFactory + Stream<DynamicTest> conformanceCorpus() throws IOException { + try (Stream<Path> entries = Files.list(CORPUS_DIR)) { + List<Path> gomods = entries + .filter(p -> p.toString().endsWith(".gomod")) + .sorted(Comparator.comparing(Path::getFileName)) + .collect(Collectors.toList()); + assertThat(gomods).as("conformance corpus is non-empty").isNotEmpty(); + return gomods.stream().map(p -> DynamicTest.dynamicTest(stripGomod(p.getFileName().toString()), () -> runCase(p))); + } + } + + private static String stripGomod(String name) { + return name.substring(0, name.length() - ".gomod".length()); + } + + private static void runCase(Path goModPath) throws IOException { + String caseName = stripGomod(goModPath.getFileName().toString()); + String modContent = new String(Files.readAllBytes(goModPath), StandardCharsets.UTF_8); + + // Parse via the production GoModParser path (PlainText delegate + + // marker decoration). Use a virtual source path so parseSumSibling + // doesn't accidentally read a real go.sum on disk. + PlainTextParser delegate = new PlainTextParser(); + ExecutionContext ctx = new InMemoryExecutionContext(); + SourceFile sf = delegate.parse(modContent).iterator().next(); + PlainText pt = (PlainText) sf; + GoResolutionResult marker = GoModParser.parseMarker(pt); + assertThat(marker).as("parseMarker for case %s", caseName).isNotNull(); + + // If a sibling .gosum is present, attach its parsed contents. + Path sumPath = goModPath.resolveSibling(caseName + ".gosum"); + if (Files.isRegularFile(sumPath)) { + String sumContent = new String(Files.readAllBytes(sumPath), StandardCharsets.UTF_8); + marker = marker.withResolvedDependencies(GoModParser.parseSumContent(sumContent)); + } + + ObjectNode actual = toConformance(marker); + + Path goldenPath = goModPath.resolveSibling(caseName + ".gomod.json"); + ObjectNode expected = (ObjectNode) MAPPER.readTree(goldenPath.toFile()); + + assertThat(actual).as("conformance shape for case %s", caseName).isEqualTo(expected); + } + + /** + * Convert a {@link GoResolutionResult} marker into the canonical + * conformance JSON shape. Fields and ordering MUST stay in sync with + * the Go-side {@code conformanceShape} in + * {@code test/gomod_conformance_test.go}. + */ + private static ObjectNode toConformance(GoResolutionResult m) { + ObjectNode out = MAPPER.createObjectNode(); + out.put("modulePath", m.getModulePath()); + out.put("goVersion", m.getGoVersion() == null ? "" : m.getGoVersion()); + out.put("toolchain", m.getToolchain() == null ? "" : m.getToolchain()); + out.set("requires", MAPPER.valueToTree(toRequires(m.getRequires()))); + out.set("replaces", MAPPER.valueToTree(toReplaces(m.getReplaces()))); + out.set("excludes", MAPPER.valueToTree(toExcludes(m.getExcludes()))); + out.set("retracts", MAPPER.valueToTree(toRetracts(m.getRetracts()))); + out.set("resolvedDependencies", MAPPER.valueToTree(toResolved(m.getResolvedDependencies()))); + return out; + } + + private static List<ObjectNode> toRequires(List<GoResolutionResult.Require> reqs) { + List<ObjectNode> out = new ArrayList<>(); + if (reqs == null) return out; + for (GoResolutionResult.Require r : reqs) { + ObjectNode n = MAPPER.createObjectNode(); + n.put("modulePath", r.getModulePath()); + n.put("version", r.getVersion()); + n.put("indirect", r.isIndirect()); + out.add(n); + } + return out; + } + + private static List<ObjectNode> toReplaces(List<GoResolutionResult.Replace> reps) { + List<ObjectNode> out = new ArrayList<>(); + if (reps == null) return out; + for (GoResolutionResult.Replace r : reps) { + ObjectNode n = MAPPER.createObjectNode(); + n.put("oldPath", r.getOldPath()); + putNullable(n, "oldVersion", r.getOldVersion()); + n.put("newPath", r.getNewPath()); + putNullable(n, "newVersion", r.getNewVersion()); + out.add(n); + } + return out; + } + + private static List<ObjectNode> toExcludes(List<GoResolutionResult.Exclude> excs) { + List<ObjectNode> out = new ArrayList<>(); + if (excs == null) return out; + for (GoResolutionResult.Exclude e : excs) { + ObjectNode n = MAPPER.createObjectNode(); + n.put("modulePath", e.getModulePath()); + n.put("version", e.getVersion()); + out.add(n); + } + return out; + } + + private static List<ObjectNode> toRetracts(List<GoResolutionResult.Retract> rets) { + List<ObjectNode> out = new ArrayList<>(); + if (rets == null) return out; + for (GoResolutionResult.Retract r : rets) { + ObjectNode n = MAPPER.createObjectNode(); + n.put("versionRange", r.getVersionRange()); + putNullable(n, "rationale", r.getRationale()); + out.add(n); + } + return out; + } + + private static List<ObjectNode> toResolved(List<GoResolutionResult.ResolvedDependency> deps) { + List<ObjectNode> out = new ArrayList<>(); + if (deps == null) return out; + for (GoResolutionResult.ResolvedDependency d : deps) { + ObjectNode n = MAPPER.createObjectNode(); + n.put("modulePath", d.getModulePath()); + n.put("version", d.getVersion()); + putNullable(n, "moduleHash", d.getModuleHash()); + putNullable(n, "goModHash", d.getGoModHash()); + out.add(n); + } + return out; + } + + private static void putNullable(ObjectNode n, String key, String value) { + if (value == null) { + n.putNull(key); + } else { + n.put(key, value); + } + } +} diff --git a/rewrite-go/src/test/resources/gomod-conformance/README.md b/rewrite-go/src/test/resources/gomod-conformance/README.md new file mode 100644 index 00000000000..d97687e9097 --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/README.md @@ -0,0 +1,40 @@ +# GoMod conformance corpus + +A shared corpus of `go.mod` (and occasionally sibling `go.sum`) inputs used +to enforce field-for-field parity between the Java `GoModParser` and the +Go `parser.ParseGoMod` / `parser.ParseGoSum`. + +Each case has: + +- `<case>.gomod` — the input go.mod text. +- `<case>.gomod.json` — the expected `GoResolutionResult`, serialized as + the canonical conformance shape (lower-camelCase JSON, defined by both + language test harnesses). +- (optional) `<case>.gosum` — sibling go.sum content. When present, the + test suite parses it via `parseSumContent` (Java) / `ParseGoSum` (Go) and + attaches the resulting `resolvedDependencies` to the marker before + comparison. + +Both languages run the same test corpus: + +- Java: `org.openrewrite.golang.GoModConformanceTest` +- Go: `test/gomod_conformance_test.go` + +The canonical JSON shape is: + +```json +{ + "modulePath": "example.com/foo", + "goVersion": "1.22", + "toolchain": "go1.22.3", + "requires": [{"modulePath": "github.com/x/y", "version": "v1.2.3", "indirect": false}], + "replaces": [{"oldPath": "...", "oldVersion": null, "newPath": "...", "newVersion": null}], + "excludes": [{"modulePath": "...", "version": "..."}], + "retracts": [{"versionRange": "v1.0.0", "rationale": "..."}], + "resolvedDependencies": [{"modulePath": "...", "version": "...", "moduleHash": "h1:...", "goModHash": "h1:..."}] +} +``` + +Optional fields are present in all cases; empty lists are `[]` (not omitted). +String fields default to `""` when absent on the parsed marker; nullable +string fields are written as JSON `null`. diff --git a/rewrite-go/src/test/resources/gomod-conformance/basic.gomod b/rewrite-go/src/test/resources/gomod-conformance/basic.gomod new file mode 100644 index 00000000000..f5cd3c0dcfc --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/basic.gomod @@ -0,0 +1,5 @@ +module example.com/foo + +go 1.22 + +require github.com/google/uuid v1.6.0 diff --git a/rewrite-go/src/test/resources/gomod-conformance/basic.gomod.json b/rewrite-go/src/test/resources/gomod-conformance/basic.gomod.json new file mode 100644 index 00000000000..98c011158c6 --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/basic.gomod.json @@ -0,0 +1,12 @@ +{ + "modulePath": "example.com/foo", + "goVersion": "1.22", + "toolchain": "", + "requires": [ + {"modulePath": "github.com/google/uuid", "version": "v1.6.0", "indirect": false} + ], + "replaces": [], + "excludes": [], + "retracts": [], + "resolvedDependencies": [] +} diff --git a/rewrite-go/src/test/resources/gomod-conformance/block-requires.gomod b/rewrite-go/src/test/resources/gomod-conformance/block-requires.gomod new file mode 100644 index 00000000000..27ff5f93bd5 --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/block-requires.gomod @@ -0,0 +1,9 @@ +module example.com/multi + +go 1.22 + +require ( + github.com/x/y v1.2.3 + github.com/z/w v0.5.0 // indirect + github.com/a/b/v2 v2.0.0 +) diff --git a/rewrite-go/src/test/resources/gomod-conformance/block-requires.gomod.json b/rewrite-go/src/test/resources/gomod-conformance/block-requires.gomod.json new file mode 100644 index 00000000000..3fd79d0f40a --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/block-requires.gomod.json @@ -0,0 +1,14 @@ +{ + "modulePath": "example.com/multi", + "goVersion": "1.22", + "toolchain": "", + "requires": [ + {"modulePath": "github.com/x/y", "version": "v1.2.3", "indirect": false}, + {"modulePath": "github.com/z/w", "version": "v0.5.0", "indirect": true}, + {"modulePath": "github.com/a/b/v2", "version": "v2.0.0", "indirect": false} + ], + "replaces": [], + "excludes": [], + "retracts": [], + "resolvedDependencies": [] +} diff --git a/rewrite-go/src/test/resources/gomod-conformance/exclude.gomod b/rewrite-go/src/test/resources/gomod-conformance/exclude.gomod new file mode 100644 index 00000000000..9fa164f89de --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/exclude.gomod @@ -0,0 +1,7 @@ +module example.com/withexclude + +go 1.22 + +exclude github.com/bad v0.0.1 + +require github.com/x/y v1.0.0 diff --git a/rewrite-go/src/test/resources/gomod-conformance/exclude.gomod.json b/rewrite-go/src/test/resources/gomod-conformance/exclude.gomod.json new file mode 100644 index 00000000000..7bdfa84232d --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/exclude.gomod.json @@ -0,0 +1,14 @@ +{ + "modulePath": "example.com/withexclude", + "goVersion": "1.22", + "toolchain": "", + "requires": [ + {"modulePath": "github.com/x/y", "version": "v1.0.0", "indirect": false} + ], + "replaces": [], + "excludes": [ + {"modulePath": "github.com/bad", "version": "v0.0.1"} + ], + "retracts": [], + "resolvedDependencies": [] +} diff --git a/rewrite-go/src/test/resources/gomod-conformance/incompatible.gomod b/rewrite-go/src/test/resources/gomod-conformance/incompatible.gomod new file mode 100644 index 00000000000..2ace44d18f1 --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/incompatible.gomod @@ -0,0 +1,5 @@ +module example.com/withincompat + +go 1.22 + +require github.com/legacy/lib v2.0.0+incompatible diff --git a/rewrite-go/src/test/resources/gomod-conformance/incompatible.gomod.json b/rewrite-go/src/test/resources/gomod-conformance/incompatible.gomod.json new file mode 100644 index 00000000000..c515241df18 --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/incompatible.gomod.json @@ -0,0 +1,12 @@ +{ + "modulePath": "example.com/withincompat", + "goVersion": "1.22", + "toolchain": "", + "requires": [ + {"modulePath": "github.com/legacy/lib", "version": "v2.0.0+incompatible", "indirect": false} + ], + "replaces": [], + "excludes": [], + "retracts": [], + "resolvedDependencies": [] +} diff --git a/rewrite-go/src/test/resources/gomod-conformance/replace-local.gomod b/rewrite-go/src/test/resources/gomod-conformance/replace-local.gomod new file mode 100644 index 00000000000..d418c41d404 --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/replace-local.gomod @@ -0,0 +1,7 @@ +module example.com/withlocal + +go 1.22 + +replace github.com/x/y => ../local/y + +require github.com/x/y v0.0.0 diff --git a/rewrite-go/src/test/resources/gomod-conformance/replace-local.gomod.json b/rewrite-go/src/test/resources/gomod-conformance/replace-local.gomod.json new file mode 100644 index 00000000000..3ad56729d9e --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/replace-local.gomod.json @@ -0,0 +1,14 @@ +{ + "modulePath": "example.com/withlocal", + "goVersion": "1.22", + "toolchain": "", + "requires": [ + {"modulePath": "github.com/x/y", "version": "v0.0.0", "indirect": false} + ], + "replaces": [ + {"oldPath": "github.com/x/y", "oldVersion": null, "newPath": "../local/y", "newVersion": null} + ], + "excludes": [], + "retracts": [], + "resolvedDependencies": [] +} diff --git a/rewrite-go/src/test/resources/gomod-conformance/replace-module.gomod b/rewrite-go/src/test/resources/gomod-conformance/replace-module.gomod new file mode 100644 index 00000000000..f8e22414a09 --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/replace-module.gomod @@ -0,0 +1,7 @@ +module example.com/withreplace + +go 1.22 + +replace github.com/x/y v1.2.3 => github.com/forked/y v1.2.4 + +require github.com/x/y v1.2.3 diff --git a/rewrite-go/src/test/resources/gomod-conformance/replace-module.gomod.json b/rewrite-go/src/test/resources/gomod-conformance/replace-module.gomod.json new file mode 100644 index 00000000000..7cc158d19eb --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/replace-module.gomod.json @@ -0,0 +1,14 @@ +{ + "modulePath": "example.com/withreplace", + "goVersion": "1.22", + "toolchain": "", + "requires": [ + {"modulePath": "github.com/x/y", "version": "v1.2.3", "indirect": false} + ], + "replaces": [ + {"oldPath": "github.com/x/y", "oldVersion": "v1.2.3", "newPath": "github.com/forked/y", "newVersion": "v1.2.4"} + ], + "excludes": [], + "retracts": [], + "resolvedDependencies": [] +} diff --git a/rewrite-go/src/test/resources/gomod-conformance/retract-range.gomod b/rewrite-go/src/test/resources/gomod-conformance/retract-range.gomod new file mode 100644 index 00000000000..865b311d48a --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/retract-range.gomod @@ -0,0 +1,5 @@ +module example.com/withretractrange + +go 1.22 + +retract [v1.0.0, v1.0.5] diff --git a/rewrite-go/src/test/resources/gomod-conformance/retract-range.gomod.json b/rewrite-go/src/test/resources/gomod-conformance/retract-range.gomod.json new file mode 100644 index 00000000000..2f4c6254021 --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/retract-range.gomod.json @@ -0,0 +1,12 @@ +{ + "modulePath": "example.com/withretractrange", + "goVersion": "1.22", + "toolchain": "", + "requires": [], + "replaces": [], + "excludes": [], + "retracts": [ + {"versionRange": "[v1.0.0, v1.0.5]", "rationale": null} + ], + "resolvedDependencies": [] +} diff --git a/rewrite-go/src/test/resources/gomod-conformance/retract-single.gomod b/rewrite-go/src/test/resources/gomod-conformance/retract-single.gomod new file mode 100644 index 00000000000..5ee2a05e379 --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/retract-single.gomod @@ -0,0 +1,5 @@ +module example.com/withretract + +go 1.22 + +retract v0.0.5 // accidentally deleted main.go diff --git a/rewrite-go/src/test/resources/gomod-conformance/retract-single.gomod.json b/rewrite-go/src/test/resources/gomod-conformance/retract-single.gomod.json new file mode 100644 index 00000000000..da862dacff4 --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/retract-single.gomod.json @@ -0,0 +1,12 @@ +{ + "modulePath": "example.com/withretract", + "goVersion": "1.22", + "toolchain": "", + "requires": [], + "replaces": [], + "excludes": [], + "retracts": [ + {"versionRange": "v0.0.5", "rationale": "accidentally deleted main.go"} + ], + "resolvedDependencies": [] +} diff --git a/rewrite-go/src/test/resources/gomod-conformance/toolchain.gomod b/rewrite-go/src/test/resources/gomod-conformance/toolchain.gomod new file mode 100644 index 00000000000..2ff72979875 --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/toolchain.gomod @@ -0,0 +1,7 @@ +module example.com/withtoolchain + +go 1.22 + +toolchain go1.22.5 + +require github.com/x/y v1.0.0 diff --git a/rewrite-go/src/test/resources/gomod-conformance/toolchain.gomod.json b/rewrite-go/src/test/resources/gomod-conformance/toolchain.gomod.json new file mode 100644 index 00000000000..47e1f579520 --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/toolchain.gomod.json @@ -0,0 +1,12 @@ +{ + "modulePath": "example.com/withtoolchain", + "goVersion": "1.22", + "toolchain": "go1.22.5", + "requires": [ + {"modulePath": "github.com/x/y", "version": "v1.0.0", "indirect": false} + ], + "replaces": [], + "excludes": [], + "retracts": [], + "resolvedDependencies": [] +} diff --git a/rewrite-go/src/test/resources/gomod-conformance/with-gosum.gomod b/rewrite-go/src/test/resources/gomod-conformance/with-gosum.gomod new file mode 100644 index 00000000000..bef7920bcd1 --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/with-gosum.gomod @@ -0,0 +1,8 @@ +module example.com/withsum + +go 1.22 + +require ( + github.com/google/uuid v1.6.0 + golang.org/x/mod v0.35.0 // indirect +) diff --git a/rewrite-go/src/test/resources/gomod-conformance/with-gosum.gomod.json b/rewrite-go/src/test/resources/gomod-conformance/with-gosum.gomod.json new file mode 100644 index 00000000000..e5f3b7b314d --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/with-gosum.gomod.json @@ -0,0 +1,16 @@ +{ + "modulePath": "example.com/withsum", + "goVersion": "1.22", + "toolchain": "", + "requires": [ + {"modulePath": "github.com/google/uuid", "version": "v1.6.0", "indirect": false}, + {"modulePath": "golang.org/x/mod", "version": "v0.35.0", "indirect": true} + ], + "replaces": [], + "excludes": [], + "retracts": [], + "resolvedDependencies": [ + {"modulePath": "github.com/google/uuid", "version": "v1.6.0", "moduleHash": "h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=", "goModHash": "h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo="}, + {"modulePath": "golang.org/x/mod", "version": "v0.35.0", "moduleHash": "h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM=", "goModHash": "h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU="} + ] +} diff --git a/rewrite-go/src/test/resources/gomod-conformance/with-gosum.gosum b/rewrite-go/src/test/resources/gomod-conformance/with-gosum.gosum new file mode 100644 index 00000000000..4ae3bf38875 --- /dev/null +++ b/rewrite-go/src/test/resources/gomod-conformance/with-gosum.gosum @@ -0,0 +1,4 @@ +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM= +golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU= diff --git a/rewrite-go/test/annotation_service_test.go b/rewrite-go/test/annotation_service_test.go new file mode 100644 index 00000000000..8442e239411 --- /dev/null +++ b/rewrite-go/test/annotation_service_test.go @@ -0,0 +1,172 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/google/uuid" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + "github.com/openrewrite/rewrite/rewrite-go/pkg/printer" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe/golang" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// Step 4 of AnnotationService rollout: the public service surface. +// Recipes use AllAnnotations / IsAnnotatedWith / FindAnnotations to +// inspect, and AddAnnotationVisitor / RemoveAnnotationVisitor (via +// DoAfterVisit) to mutate. + +func TestAnnotationService_Registered(t *testing.T) { + svc := recipe.Service[*golang.AnnotationService](nil) + if svc == nil { + t.Fatal("expected AnnotationService to be registered") + } +} + +func TestAnnotationService_IsAnnotatedWith_StructTag(t *testing.T) { + src := "package main\n\ntype User struct {\n\tName string `json:\"name\"`\n}\n" + field := parseStructAndFindField(t, src, "Name") + svc := &golang.AnnotationService{} + if !svc.IsAnnotatedWith(field, "json") { + t.Errorf("expected struct field with json tag to match \"json\"") + } + if svc.IsAnnotatedWith(field, "validate") { + t.Errorf("did not expect match for absent tag \"validate\"") + } +} + +func TestAnnotationService_IsAnnotatedWith_Directive(t *testing.T) { + src := "package main\n\n//go:noinline\nfunc slow() {}\n" + md := parseAndFindMethod(t, src, "slow") + svc := &golang.AnnotationService{} + if !svc.IsAnnotatedWith(md, "go:noinline") { + t.Errorf("expected method with go:noinline to match") + } +} + +func TestAnnotationService_IsAnnotatedWith_WildcardPrefix(t *testing.T) { + src := "package main\n\n//go:noinline\n//go:nosplit\nfunc slow() {}\n" + md := parseAndFindMethod(t, src, "slow") + svc := &golang.AnnotationService{} + if !svc.IsAnnotatedWith(md, "go:*") { + t.Errorf("expected method with go: directives to match \"go:*\"") + } + if !svc.IsAnnotatedWith(md, "*") { + t.Errorf("expected universal match \"*\" to succeed") + } + if svc.IsAnnotatedWith(md, "lint:*") { + t.Errorf("did not expect match for \"lint:*\" on go-only directives") + } +} + +func TestAnnotationService_FindAnnotations(t *testing.T) { + src := "package main\n\ntype User struct {\n\tEmail string `json:\"email\" db:\"email_address\" validate:\"required\"`\n}\n" + field := parseStructAndFindField(t, src, "Email") + svc := &golang.AnnotationService{} + + jsonAnns := svc.FindAnnotations(field, "json") + if len(jsonAnns) != 1 { + t.Fatalf("expected 1 json annotation, got %d", len(jsonAnns)) + } + if v, _ := jsonAnns[0].Arguments.Elements[0].Element.(*tree.Literal).Value.(string); v != "email" { + t.Errorf("json value: got %q, want \"email\"", v) + } +} + +func TestAnnotationService_AllAnnotations_ViaCursor(t *testing.T) { + src := "package main\n\n//go:noinline\nfunc slow() {}\n" + md := parseAndFindMethod(t, src, "slow") + svc := &golang.AnnotationService{} + + // Build a cursor positioned AT the MethodDeclaration. + c := buildCursor(md) + anns := svc.AllAnnotations(c) + if len(anns) != 1 { + t.Fatalf("AllAnnotations: got %d, want 1", len(anns)) + } + if anns[0].AnnotationType.(*tree.Identifier).Name != "go:noinline" { + t.Errorf("annotation: got %+v", anns[0].AnnotationType) + } +} + +func TestAnnotationService_AddAnnotationVisitor_OnFunc(t *testing.T) { + src := "package main\n\nfunc slow() { _ = 1 }\n" + cu, err := parser.NewGoParser().Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + + svc := &golang.AnnotationService{} + ann := &tree.Annotation{ + ID: uuid.New(), + Prefix: tree.Space{Whitespace: "\n"}, + AnnotationType: &tree.Identifier{ID: uuid.New(), Name: "go:noinline"}, + } + v := svc.AddAnnotationVisitor(func(t tree.Tree) bool { + md, ok := t.(*tree.MethodDeclaration) + return ok && md.Name != nil && md.Name.Name == "slow" + }, ann) + + out := v.Visit(cu, nil).(tree.Tree) + + want := "package main\n\n//go:noinline\nfunc slow() { _ = 1 }\n" + if got := printer.Print(out); got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func TestAnnotationService_RemoveAnnotationVisitor(t *testing.T) { + // Start with two go: directives, remove one specifically. + src := "package main\n\n//go:noinline\n//go:nosplit\nfunc slow() {}\n" + cu, err := parser.NewGoParser().Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + + svc := &golang.AnnotationService{} + v := svc.RemoveAnnotationVisitor("go:nosplit") + out := v.Visit(cu, nil).(tree.Tree) + + want := "package main\n\n//go:noinline\nfunc slow() {}\n" + if got := printer.Print(out); got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func TestAnnotationService_Matches_ViaCursor(t *testing.T) { + src := "package main\n\n//go:noinline\nfunc slow() {}\n" + md := parseAndFindMethod(t, src, "slow") + c := buildCursor(md) + svc := &golang.AnnotationService{} + if !svc.Matches(c, golang.NewAnnotationMatcher("go:noinline")) { + t.Error("expected matcher \"go:noinline\" to match") + } + if svc.Matches(c, golang.NewAnnotationMatcher("go:nosplit")) { + t.Error("did not expect matcher \"go:nosplit\" to match") + } +} + +// buildCursor wraps a node in a single-element cursor for testing +// AnnotationService.AllAnnotations / Matches without going through the +// full visitor dispatch. +func buildCursor(t tree.Tree) *visitor.Cursor { + return visitor.NewCursor(nil, t) +} diff --git a/rewrite-go/test/annotation_test.go b/rewrite-go/test/annotation_test.go new file mode 100644 index 00000000000..f655f68e737 --- /dev/null +++ b/rewrite-go/test/annotation_test.go @@ -0,0 +1,163 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/google/uuid" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/printer" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// Step 1 of AnnotationService rollout: tree.Annotation type + visitor + +// printer entry. No parser wiring yet — these tests construct +// annotations programmatically and verify the visitor walks them and +// the printer emits the expected struct-tag form. + +// newJSONTagAnnotation builds an Annotation that represents a single +// struct-tag pair `json:"name"` for the test fixtures below. +func newJSONTagAnnotation(key, value string) *tree.Annotation { + return &tree.Annotation{ + ID: uuid.New(), + AnnotationType: &tree.Identifier{ID: uuid.New(), Name: key}, + Arguments: &tree.Container[tree.Expression]{ + Elements: []tree.RightPadded[tree.Expression]{ + {Element: &tree.Literal{ + ID: uuid.New(), + Source: `"` + value + `"`, + Value: value, + Kind: tree.StringLiteral, + }}, + }, + }, + } +} + +func TestAnnotation_PrintsBasicTagShape(t *testing.T) { + ann := newJSONTagAnnotation("json", "name") + out := printer.Print(ann) + want := `json:"name"` + if out != want { + t.Errorf("got %q, want %q", out, want) + } +} + +func TestAnnotation_PrintsWithoutArguments(t *testing.T) { + // An Annotation with nil Arguments should print just the type + // expression (mirrors Java's bare `@Override`-style annotation). + ann := &tree.Annotation{ + ID: uuid.New(), + AnnotationType: &tree.Identifier{ID: uuid.New(), Name: "go:noinline"}, + } + out := printer.Print(ann) + want := `go:noinline` + if out != want { + t.Errorf("got %q, want %q", out, want) + } +} + +func TestAnnotation_PrintsPrefixWhitespace(t *testing.T) { + // The leading space (between the previous syntax and the annotation) + // lives on the Annotation's Prefix. + ann := newJSONTagAnnotation("validate", "required") + ann.Prefix = tree.Space{Whitespace: " "} + out := printer.Print(ann) + want := ` validate:"required"` + if out != want { + t.Errorf("got %q, want %q", out, want) + } +} + +func TestAnnotation_VisitorRoundtripIdentity(t *testing.T) { + // A no-op visitor over an Annotation should produce a tree whose + // printed form is identical to the input's. + ann := newJSONTagAnnotation("json", "user_id") + ann.Prefix = tree.Space{Whitespace: " "} + + v := visitor.Init(&visitor.GoVisitor{}) + out := v.Visit(ann, nil) + if out == nil { + t.Fatal("visitor returned nil") + } + got := printer.Print(out.(tree.Tree)) + want := ` json:"user_id"` + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func TestAnnotation_VisitorReachesAnnotationType(t *testing.T) { + // Custom visitor that renames the AnnotationType identifier + // confirms the visitor recurses into the type child. + ann := newJSONTagAnnotation("json", "x") + v := visitor.Init(&renamingVisitor{from: "json", to: "yaml"}) + out := v.Visit(ann, nil).(*tree.Annotation) + + got := printer.Print(out) + want := `yaml:"x"` + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func TestAnnotation_VisitorReachesArguments(t *testing.T) { + // Custom visitor that mutates the Literal value in Arguments + // confirms the visitor recurses into the arguments container. + ann := newJSONTagAnnotation("json", "x") + v := visitor.Init(&literalRewriter{want: "x", repl: "y"}) + out := v.Visit(ann, nil).(*tree.Annotation) + + got := printer.Print(out) + want := `json:"y"` + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +type renamingVisitor struct { + visitor.GoVisitor + from, to string +} + +func (v *renamingVisitor) VisitIdentifier(id *tree.Identifier, p any) tree.J { + id = v.GoVisitor.VisitIdentifier(id, p).(*tree.Identifier) + if id.Name == v.from { + c := *id + c.Name = v.to + return &c + } + return id +} + +type literalRewriter struct { + visitor.GoVisitor + want, repl string +} + +func (v *literalRewriter) VisitLiteral(lit *tree.Literal, p any) tree.J { + lit = v.GoVisitor.VisitLiteral(lit, p).(*tree.Literal) + if s, ok := lit.Value.(string); ok && s == v.want { + c := *lit + c.Value = v.repl + c.Source = `"` + v.repl + `"` + return &c + } + return lit +} diff --git a/rewrite-go/test/auto_format_test.go b/rewrite-go/test/auto_format_test.go new file mode 100644 index 00000000000..b53b62b0163 --- /dev/null +++ b/rewrite-go/test/auto_format_test.go @@ -0,0 +1,301 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/format" + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + "github.com/openrewrite/rewrite/rewrite-go/pkg/printer" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe/golang" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// applyVisitor parses src, runs visitor (and any DoAfterVisit-queued +// follow-ups), and returns the printed result. +func applyVisitor(t *testing.T, src string, v recipe.TreeVisitor) string { + t.Helper() + p := parser.NewGoParser() + cu, err := p.Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + result := v.Visit(cu, nil) + if result == nil { + t.Fatal("visit returned nil") + } + + final := visitor.DrainAfterVisits(v, result.(tree.Tree), nil) + return printer.Print(final) +} + +// ---- Service registry ---- + +func TestAutoFormatService_RegisteredOnInit(t *testing.T) { + svc := recipe.Service[*golang.AutoFormatService](nil) + if svc == nil { + t.Fatal("expected AutoFormatService to be registered, got nil") + } +} + +// ---- RemoveTrailingWhitespaceVisitor ---- + +func TestRemoveTrailingWhitespace_StripsTrailingTabsFromLines(t *testing.T) { + src := "package main \n\nfunc main() {}\n" + out := applyVisitor(t, src, format.NewRemoveTrailingWhitespaceVisitor(nil)) + want := "package main\n\nfunc main() {}\n" + if out != want { + t.Errorf("got %q, want %q", out, want) + } +} + +// ---- BlankLinesVisitor ---- + +// Regression: the leading blank line above the first statement of a +// block lives on the *leftmost descendant* of that statement (e.g. +// Variable.Prefix), not on Asg.Prefix. The visitor walks the leftmost +// spine via transformLeftmostPrefix to find it. +func TestBlankLines_StripsLeadingBlankLineAtBlockStart(t *testing.T) { + src := `package main + +func main() { + + a := 1 + _ = a +} +` + want := `package main + +func main() { + a := 1 + _ = a +} +` + out := applyVisitor(t, src, format.NewBlankLinesVisitor(nil)) + if out != want { + t.Errorf("got:\n%s\nwant:\n%s", out, want) + } +} + +// Regression: the trailing blank line above the closing brace lives on +// Block.End — straightforward direct manipulation. +func TestBlankLines_StripsTrailingBlankLineAtBlockEnd(t *testing.T) { + src := `package main + +func main() { + a := 1 + _ = a + +} +` + want := `package main + +func main() { + a := 1 + _ = a +} +` + out := applyVisitor(t, src, format.NewBlankLinesVisitor(nil)) + if out != want { + t.Errorf("got:\n%s\nwant:\n%s", out, want) + } +} + +func TestBlankLines_CapsRunOfBlankLinesInBlock(t *testing.T) { + src := `package main + +func main() { + a := 1 + + + + b := 2 + _ = a + b +} +` + want := `package main + +func main() { + a := 1 + + b := 2 + _ = a + b +} +` + out := applyVisitor(t, src, format.NewBlankLinesVisitor(nil)) + if out != want { + t.Errorf("got:\n%s\nwant:\n%s", out, want) + } +} + +// ---- TabsAndIndentsVisitor ---- + +func TestTabsAndIndents_ReindentsFunctionBody(t *testing.T) { + src := "package main\n\nfunc main() {\n\t\t a := 1\n\t_ = a\n}\n" + want := "package main\n\nfunc main() {\n\ta := 1\n\t_ = a\n}\n" + out := applyVisitor(t, src, format.NewTabsAndIndentsVisitor(nil)) + if out != want { + t.Errorf("got %q, want %q", out, want) + } +} + +func TestTabsAndIndents_NestedBlockGetsTwoTabs(t *testing.T) { + src := `package main + +func main() { + if true { + a := 1 + _ = a + } +} +` + want := `package main + +func main() { + if true { + a := 1 + _ = a + } +} +` + out := applyVisitor(t, src, format.NewTabsAndIndentsVisitor(nil)) + if out != want { + t.Errorf("got:\n%s\nwant:\n%s", out, want) + } +} + +// ---- SpacesVisitor ---- + +func TestSpaces_NormalizesBinaryOperatorSpacing(t *testing.T) { + src := `package main + +func main() { + a := 1+2 + _ = a +} +` + want := `package main + +func main() { + a := 1 + 2 + _ = a +} +` + out := applyVisitor(t, src, format.NewSpacesVisitor(nil)) + if out != want { + t.Errorf("got:\n%s\nwant:\n%s", out, want) + } +} + +// Regression: when the right operand of `:=` is itself a Binary, the +// leading single-space-after-`:=` lives on the leftmost leaf of the +// Binary tree (e.g., the Literal `1` in `1+2+3`). Setting Binary.Prefix +// directly would double the space. +func TestSpaces_NoSpaceDoublingWithBinaryOperand(t *testing.T) { + src := `package main + +func main() { + a := 1+2+3 + _ = a +} +` + want := `package main + +func main() { + a := 1 + 2 + 3 + _ = a +} +` + out := applyVisitor(t, src, format.NewSpacesVisitor(nil)) + if out != want { + t.Errorf("got:\n%s\nwant:\n%s", out, want) + } +} + +// Regression: same delegation rule when the assigned expression is a +// FieldAccess — the space-after-`:=` lives on FieldAccess.Target.Prefix. +func TestSpaces_FieldAccessLeadingSpace(t *testing.T) { + src := `package main + +func main() { + x := struct{ a int }{a: 1} + y :=x.a + _ = y +} +` + want := `package main + +func main() { + x := struct{ a int }{a: 1} + y := x.a + _ = y +} +` + out := applyVisitor(t, src, format.NewSpacesVisitor(nil)) + if out != want { + t.Errorf("got:\n%s\nwant:\n%s", out, want) + } +} + +// Regression: TabsAndIndentsVisitor places `case` clauses at the +// switch-keyword's depth (gofmt convention) and case bodies one tab +// deeper. +func TestTabsAndIndents_SwitchCaseAlignsWithSwitch(t *testing.T) { + src := `package main + +func main() { + switch x := 1; x { + case 1: + println("one") + case 2: + println("two") + } +} +` + want := `package main + +func main() { + switch x := 1; x { + case 1: + println("one") + case 2: + println("two") + } +} +` + out := applyVisitor(t, src, format.NewTabsAndIndentsVisitor(nil)) + if out != want { + t.Errorf("got:\n%s\nwant:\n%s", out, want) + } +} + +// ---- AutoFormatVisitor (composition) ---- + +func TestAutoFormat_FullPipelineEndToEnd(t *testing.T) { + // Combines: trailing whitespace on `func main() {`, blank line at + // start of body, wrong indent on nested block + its body, + // missing space around `+`. Expect all four passes to fire. + src := "package main\n\nfunc main() { \n\n\n\tif true {\n\ta := 1+2\n\t_ = a\n\t}\n}\n" + want := "package main\n\nfunc main() {\n\tif true {\n\t\ta := 1 + 2\n\t\t_ = a\n\t}\n}\n" + out := applyVisitor(t, src, format.NewAutoFormatVisitor(nil)) + if out != want { + t.Errorf("got %q, want %q", out, want) + } +} diff --git a/rewrite-go/test/build_tags_test.go b/rewrite-go/test/build_tags_test.go new file mode 100644 index 00000000000..e5a50c7477d --- /dev/null +++ b/rewrite-go/test/build_tags_test.go @@ -0,0 +1,161 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "go/build" + "sort" + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" +) + +// parsedNames returns the file names included by ParsePackage for the +// given build context — the names of files that survived `//go:build` +// and filename-suffix constraint evaluation. +func parsedNames(t *testing.T, buildCtx build.Context, files []parser.FileInput) []string { + t.Helper() + p := parser.NewGoParserWithBuildContext(buildCtx) + cus, err := p.ParsePackage(files) + if err != nil { + t.Fatalf("ParsePackage: %v", err) + } + out := make([]string, 0, len(cus)) + for _, cu := range cus { + out = append(out, cu.SourcePath) + } + sort.Strings(out) + return out +} + +func ctx(goos, goarch string) build.Context { + c := build.Default + c.GOOS = goos + c.GOARCH = goarch + return c +} + +// Case 1: //go:build linux — included on Linux, excluded on macOS. +func TestBuildTags_GoBuildLinuxOnly(t *testing.T) { + files := []parser.FileInput{ + {Path: "main.go", Content: "package p\n\nfunc Main() {}\n"}, + {Path: "lin.go", Content: "//go:build linux\n\npackage p\n\nfunc Lin() {}\n"}, + } + if got := parsedNames(t, ctx("linux", "amd64"), files); !equal(got, []string{"lin.go", "main.go"}) { + t.Errorf("on linux: got %v, want [lin.go main.go]", got) + } + if got := parsedNames(t, ctx("darwin", "amd64"), files); !equal(got, []string{"main.go"}) { + t.Errorf("on darwin: got %v, want [main.go]", got) + } +} + +// Case 2: filename suffix matching. `_linux.go`, `_amd64.go`, and +// `_linux_amd64.go` exclude themselves on the wrong platform. +func TestBuildTags_FilenameSuffix(t *testing.T) { + files := []parser.FileInput{ + {Path: "main.go", Content: "package p\n\nfunc Main() {}\n"}, + {Path: "extra_linux.go", Content: "package p\n\nfunc Lin() {}\n"}, + {Path: "extra_amd64.go", Content: "package p\n\nfunc Amd() {}\n"}, + {Path: "extra_linux_amd64.go", Content: "package p\n\nfunc Both() {}\n"}, + } + got := parsedNames(t, ctx("linux", "amd64"), files) + want := []string{"extra_amd64.go", "extra_linux.go", "extra_linux_amd64.go", "main.go"} + if !equal(got, want) { + t.Errorf("on linux/amd64: got %v, want %v", got, want) + } + + got = parsedNames(t, ctx("darwin", "arm64"), files) + want = []string{"main.go"} + if !equal(got, want) { + t.Errorf("on darwin/arm64: got %v, want %v", got, want) + } +} + +// Case 3: combined constraints — `//go:build linux && amd64`. +func TestBuildTags_CombinedConstraint(t *testing.T) { + files := []parser.FileInput{ + {Path: "main.go", Content: "package p\n\nfunc Main() {}\n"}, + {Path: "both.go", Content: "//go:build linux && amd64\n\npackage p\n\nfunc Both() {}\n"}, + } + if got := parsedNames(t, ctx("linux", "amd64"), files); !equal(got, []string{"both.go", "main.go"}) { + t.Errorf("linux/amd64: got %v", got) + } + if got := parsedNames(t, ctx("linux", "arm64"), files); !equal(got, []string{"main.go"}) { + t.Errorf("linux/arm64: got %v", got) + } + if got := parsedNames(t, ctx("darwin", "amd64"), files); !equal(got, []string{"main.go"}) { + t.Errorf("darwin/amd64: got %v", got) + } +} + +// Case 4: negated constraints — `//go:build !windows`. +func TestBuildTags_NegatedConstraint(t *testing.T) { + files := []parser.FileInput{ + {Path: "main.go", Content: "package p\n\nfunc Main() {}\n"}, + {Path: "nowin.go", Content: "//go:build !windows\n\npackage p\n\nfunc NoWin() {}\n"}, + } + if got := parsedNames(t, ctx("windows", "amd64"), files); !equal(got, []string{"main.go"}) { + t.Errorf("on windows: got %v", got) + } + if got := parsedNames(t, ctx("linux", "amd64"), files); !equal(got, []string{"main.go", "nowin.go"}) { + t.Errorf("on linux: got %v", got) + } +} + +// Case 5: legacy `// +build` syntax — still recognized. +func TestBuildTags_LegacyPlusBuild(t *testing.T) { + files := []parser.FileInput{ + {Path: "main.go", Content: "package p\n\nfunc Main() {}\n"}, + {Path: "lin.go", Content: "// +build linux\n\npackage p\n\nfunc Lin() {}\n"}, + } + if got := parsedNames(t, ctx("linux", "amd64"), files); !equal(got, []string{"lin.go", "main.go"}) { + t.Errorf("on linux: got %v", got) + } + if got := parsedNames(t, ctx("darwin", "amd64"), files); !equal(got, []string{"main.go"}) { + t.Errorf("on darwin: got %v", got) + } +} + +// Case 6: mixed filename + //go:build. Filename says linux, content +// constraint says amd64; the file is included only when BOTH match. +func TestBuildTags_MixedFilenameAndGoBuild(t *testing.T) { + files := []parser.FileInput{ + {Path: "main.go", Content: "package p\n\nfunc Main() {}\n"}, + {Path: "x_linux.go", Content: "//go:build amd64\n\npackage p\n\nfunc Both() {}\n"}, + } + if got := parsedNames(t, ctx("linux", "amd64"), files); !equal(got, []string{"main.go", "x_linux.go"}) { + t.Errorf("linux/amd64: got %v", got) + } + if got := parsedNames(t, ctx("linux", "arm64"), files); !equal(got, []string{"main.go"}) { + t.Errorf("linux/arm64 (filename matches, constraint does not): got %v", got) + } + if got := parsedNames(t, ctx("darwin", "amd64"), files); !equal(got, []string{"main.go"}) { + t.Errorf("darwin/amd64 (constraint matches, filename does not): got %v", got) + } +} + +func equal(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/rewrite-go/test/cross_package_generics_test.go b/rewrite-go/test/cross_package_generics_test.go new file mode 100644 index 00000000000..bd054570177 --- /dev/null +++ b/rewrite-go/test/cross_package_generics_test.go @@ -0,0 +1,193 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + . "github.com/openrewrite/rewrite/rewrite-go/pkg/test" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +// genericsScaffold builds a multi-package project rooted at a temp dir and +// returns the parsed compilation units keyed by their relative path. Each +// case stages files via the on-disk vendor walker pattern so the parser +// resolves cross-package generic references the same way it does in +// production parses. +func genericsScaffold(t *testing.T, files map[string]string, modulePath, mainRel string) *tree.CompilationUnit { + t.Helper() + root := t.TempDir() + for rel, content := range files { + full := filepath.Join(root, rel) + if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", full, err) + } + if err := os.WriteFile(full, []byte(content), 0o644); err != nil { + t.Fatalf("write %s: %v", full, err) + } + } + pi := parser.NewProjectImporter(modulePath, nil) + pi.SetProjectRoot(root) + for rel, content := range files { + pi.AddSource(rel, content) + } + p := parser.NewGoParser() + p.Importer = pi + + mainContent := files[mainRel] + cu, err := p.Parse(mainRel, mainContent) + if err != nil { + t.Fatalf("parse %s: %v", mainRel, err) + } + return cu +} + +// Case 1: Generic function used across packages. +// +// package a +// func Map[T any](xs []T, f func(T) T) []T { ... } +// +// package main +// import "example.com/foo/a" +// func main() { a.Map([]int{1, 2, 3}, double) } +// +// The MethodInvocation for `a.Map` should carry a non-nil MethodType +// whose DeclaringType.FullyQualifiedName == the imported package's path. +// Without cross-package generic resolution the MethodType would be nil. +func TestCrossPackageGenerics_GenericFunc(t *testing.T) { + cu := genericsScaffold(t, map[string]string{ + "a/a.go": `package a + +func Map[T any](xs []T, f func(T) T) []T { + out := make([]T, len(xs)) + for i, x := range xs { + out[i] = f(x) + } + return out +} +`, + "main.go": `package main + +import "example.com/foo/a" + +func double(x int) int { return x * 2 } + +func main() { + _ = a.Map([]int{1, 2, 3}, double) +} +`, + }, "example.com/foo", "main.go") + ExpectMethodType(t, cu, "Map", "example.com/foo/a") +} + +// Case 2: Generic struct used across packages. +// +// package a +// type Box[T any] struct{ V T } +// +// package main +// import "example.com/foo/a" +// var b a.Box[int] +// +// The `Box` identifier in main should resolve to a FullyQualified type +// whose FQN is the package's full path. +func TestCrossPackageGenerics_GenericStruct(t *testing.T) { + cu := genericsScaffold(t, map[string]string{ + "a/a.go": `package a + +type Box[T any] struct{ V T } +`, + "main.go": `package main + +import "example.com/foo/a" + +func main() { + _ = a.Box[int]{V: 42} +} +`, + }, "example.com/foo", "main.go") + ExpectType(t, cu, "Box", "example.com/foo/a.Box") +} + +// Case 3: Multi-parameter generics across packages. +// +// package a +// type Pair[K, V comparable] struct{ K K; V V } +// +// package main +// import "example.com/foo/a" +// var p a.Pair[string, int] +// +// Pair should resolve to a FullyQualified type whose FQN ends in `Pair` +// and the multi-param shape doesn't break attribution. +func TestCrossPackageGenerics_MultiParam(t *testing.T) { + cu := genericsScaffold(t, map[string]string{ + "a/a.go": `package a + +type Pair[K, V comparable] struct { + Key K + Val V +} +`, + "main.go": `package main + +import "example.com/foo/a" + +func main() { + _ = a.Pair[string, int]{Key: "x", Val: 1} +} +`, + }, "example.com/foo", "main.go") + ExpectType(t, cu, "Pair", "example.com/foo/a.Pair") +} + +// Case 4: Bounded type parameters across packages. +// +// package a +// func Sum[T int | float64](xs []T) T { ... } +// +// package main +// import "example.com/foo/a" +// func main() { _ = a.Sum([]int{1, 2}) } +// +// Even with a union constraint, the cross-package call should resolve. +func TestCrossPackageGenerics_BoundedTypeParam(t *testing.T) { + cu := genericsScaffold(t, map[string]string{ + "a/a.go": `package a + +func Sum[T int | float64](xs []T) T { + var total T + for _, x := range xs { + total += x + } + return total +} +`, + "main.go": `package main + +import "example.com/foo/a" + +func main() { + _ = a.Sum([]int{1, 2, 3}) +} +`, + }, "example.com/foo", "main.go") + ExpectMethodType(t, cu, "Sum", "example.com/foo/a") +} diff --git a/rewrite-go/test/cursor_messages_test.go b/rewrite-go/test/cursor_messages_test.go new file mode 100644 index 00000000000..9d13c58a9e0 --- /dev/null +++ b/rewrite-go/test/cursor_messages_test.go @@ -0,0 +1,115 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +func threeFrameCursor(t *testing.T) (*visitor.Cursor, *visitor.Cursor, *visitor.Cursor) { + t.Helper() + cu, err := parser.NewGoParser().Parse("a.go", "package main\n") + if err != nil { + t.Fatal(err) + } + root := visitor.NewCursor(nil, cu) + mid := visitor.NewCursor(root, cu) + leaf := visitor.NewCursor(mid, cu) + return root, mid, leaf +} + +func TestCursorPutGetMessageThisFrameOnly(t *testing.T) { + _, _, leaf := threeFrameCursor(t) + leaf.PutMessage("k", 42) + if got := leaf.GetMessage("k"); got != 42 { + t.Fatalf("expected 42 on leaf, got %v", got) + } + // GetMessage does NOT walk up — value stays on the frame it was set. + if got := leaf.Parent().GetMessage("k"); got != nil { + t.Fatalf("expected nil on parent, got %v", got) + } +} + +func TestCursorGetNearestMessageWalksUp(t *testing.T) { + root, _, leaf := threeFrameCursor(t) + root.PutMessage("k", "from-root") + if got := leaf.GetNearestMessage("k"); got != "from-root" { + t.Fatalf("expected leaf to find root's message, got %v", got) + } + if got := leaf.GetNearestMessage("missing"); got != nil { + t.Fatalf("expected nil for missing key, got %v", got) + } + if got := leaf.GetNearestMessageOrDefault("missing", "fallback"); got != "fallback" { + t.Fatalf("expected fallback, got %v", got) + } +} + +func TestCursorPollNearestMessageRemoves(t *testing.T) { + root, _, leaf := threeFrameCursor(t) + root.PutMessage("k", "v") + if got := leaf.PollNearestMessage("k"); got != "v" { + t.Fatalf("expected to poll 'v' from root, got %v", got) + } + if got := root.GetMessage("k"); got != nil { + t.Fatalf("expected key to be removed from root, got %v", got) + } + if got := leaf.PollNearestMessage("k"); got != nil { + t.Fatalf("expected nil after poll, got %v", got) + } +} + +func TestCursorComputeMessageIfAbsent(t *testing.T) { + _, _, leaf := threeFrameCursor(t) + calls := 0 + v1 := leaf.ComputeMessageIfAbsent("k", func() any { + calls++ + return "computed" + }) + v2 := leaf.ComputeMessageIfAbsent("k", func() any { + calls++ + return "second" + }) + if v1 != "computed" || v2 != "computed" { + t.Fatalf("expected stable computed value, got v1=%v v2=%v", v1, v2) + } + if calls != 1 { + t.Fatalf("expected supplier to fire once, fired %d", calls) + } +} + +func TestCursorPutMessageOnFirstEnclosing(t *testing.T) { + root, _, leaf := threeFrameCursor(t) + // Stash a message on the first ancestor whose value is a CompilationUnit. + leaf.PutMessageOnFirstEnclosing(func(t tree.Tree) bool { + _, ok := t.(*tree.CompilationUnit) + return ok + }, "tag", "matched") + // Found ancestor (every frame in this fixture is the same CU); the + // leaf itself matches first. + if got := leaf.GetMessage("tag"); got != "matched" { + t.Fatalf("expected leaf to receive the message, got %v", got) + } + // Sanity: a predicate matching nothing leaves the chain untouched. + leaf.PutMessageOnFirstEnclosing(func(tree.Tree) bool { return false }, "x", 1) + if got := root.GetMessage("x"); got != nil { + t.Fatalf("expected no-op when predicate matches nothing, got %v", got) + } +} diff --git a/rewrite-go/test/cursor_test.go b/rewrite-go/test/cursor_test.go new file mode 100644 index 00000000000..27d4c7f95ae --- /dev/null +++ b/rewrite-go/test/cursor_test.go @@ -0,0 +1,82 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +func TestCursorBuildChain(t *testing.T) { + cu, err := parser.NewGoParser().Parse("a.go", "package main\n") + if err != nil { + t.Fatal(err) + } + chain := visitor.BuildChain([]tree.Tree{cu}) + if chain == nil || chain.Value() != cu || chain.Parent() != nil { + t.Fatalf("expected single-element chain rooted at cu; got parent=%v value=%v", chain.Parent(), chain.Value()) + } + + chain2 := visitor.BuildChain(nil) + if chain2 != nil { + t.Fatalf("expected nil chain for empty input, got %v", chain2) + } +} + +// TestVisitorCursorState confirms that GoVisitor exposes its cursor as +// state via Cursor() / SetCursor(), matching the JavaVisitor pattern. +// The RPC layer seeds an initial cursor before traversal; recipes read +// it from inside any Visit* override. +func TestVisitorCursorState(t *testing.T) { + cu, err := parser.NewGoParser().Parse("a.go", "package main\nfunc f(){}\n") + if err != nil { + t.Fatal(err) + } + + v := &cursorObservingVisitor{} + visitor.Init(v) + + outer := visitor.BuildChain([]tree.Tree{cu}) + v.SetCursor(outer) + if v.Cursor() != outer { + t.Fatalf("Cursor() should return what SetCursor seeded") + } + + v.Visit(cu, recipe.NewExecutionContext()) + if !v.observedCU { + t.Fatal("VisitCompilationUnit was never invoked") + } + if v.cuCursor == nil || v.cuCursor.Value() != cu { + t.Fatalf("expected v.Cursor().Value() == cu inside VisitCompilationUnit; got %v", v.cuCursor) + } +} + +type cursorObservingVisitor struct { + visitor.GoVisitor + observedCU bool + cuCursor *visitor.Cursor +} + +func (v *cursorObservingVisitor) VisitCompilationUnit(cu *tree.CompilationUnit, p any) tree.J { + v.observedCU = true + v.cuCursor = v.Cursor() + return cu +} diff --git a/rewrite-go/test/data_table_test.go b/rewrite-go/test/data_table_test.go new file mode 100644 index 00000000000..7fff62b8a9d --- /dev/null +++ b/rewrite-go/test/data_table_test.go @@ -0,0 +1,124 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" +) + +type findingsRow struct { + File string + Count int +} + +var findings = recipe.NewDataTable[findingsRow]( + "org.example.MyRecipe.Findings", + "My findings", + "What MyRecipe found", + []recipe.ColumnDescriptor{ + {Name: "File", DisplayName: "File", Description: "The file"}, + {Name: "Count", DisplayName: "Count", Description: "Number of hits", Type: "Integer"}, + }, +) + +func TestInMemoryDataTableStoreRoundTrip(t *testing.T) { + ctx := recipe.NewExecutionContext() + ctx.PutMessage(recipe.DataTableStoreKey, recipe.NewInMemoryDataTableStore()) + + findings.InsertRow(ctx, findingsRow{File: "a.go", Count: 3}) + findings.InsertRow(ctx, findingsRow{File: "b.go", Count: 1}) + + store, _ := ctx.GetMessage(recipe.DataTableStoreKey) + rows := store.(recipe.DataTableStore).GetRows("org.example.MyRecipe.Findings", "") + if len(rows) != 2 { + t.Fatalf("expected 2 rows, got %d", len(rows)) + } + if r, ok := rows[0].(findingsRow); !ok || r.File != "a.go" || r.Count != 3 { + t.Fatalf("unexpected row 0: %+v", rows[0]) + } +} + +func TestDataTableLazyStoreCreation(t *testing.T) { + ctx := recipe.NewExecutionContext() + // No store installed — InsertRow should create an InMemory one. + findings.InsertRow(ctx, findingsRow{File: "c.go", Count: 7}) + + store, ok := ctx.GetMessage(recipe.DataTableStoreKey) + if !ok { + t.Fatal("expected store to be lazily created") + } + if _, ok := store.(*recipe.InMemoryDataTableStore); !ok { + t.Fatalf("expected InMemoryDataTableStore, got %T", store) + } + rows := store.(recipe.DataTableStore).GetRows("org.example.MyRecipe.Findings", "") + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } +} + +func TestCsvDataTableStoreWritesFile(t *testing.T) { + dir := t.TempDir() + store, err := recipe.NewCsvDataTableStore(dir) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + ctx := recipe.NewExecutionContext() + ctx.PutMessage(recipe.DataTableStoreKey, store) + + findings.InsertRow(ctx, findingsRow{File: "x.go", Count: 5}) + findings.InsertRow(ctx, findingsRow{File: "y,go", Count: 9}) // tests CSV escaping + store.Close() + + entries, err := os.ReadDir(dir) + if err != nil { + t.Fatal(err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 CSV file, got %d", len(entries)) + } + csvPath := filepath.Join(dir, entries[0].Name()) + data, err := os.ReadFile(csvPath) + if err != nil { + t.Fatal(err) + } + got := string(data) + for _, want := range []string{ + "# @name org.example.MyRecipe.Findings", + "File,Count", + "x.go,5", + `"y,go",9`, // comma in field forces quoting + } { + if !strings.Contains(got, want) { + t.Errorf("CSV missing expected line %q; got:\n%s", want, got) + } + } +} + +func TestSanitizeScope(t *testing.T) { + got := recipe.SanitizeScope("org.openrewrite.Foo$Bar") + // lowercase + non-alnum→dash + 4-char hash suffix (matching JS spec) + if !strings.HasPrefix(got, "org-openrewrite-foo-bar-") { + t.Errorf("unexpected sanitized value: %s", got) + } +} diff --git a/rewrite-go/test/directive_test.go b/rewrite-go/test/directive_test.go new file mode 100644 index 00000000000..3279bec944e --- /dev/null +++ b/rewrite-go/test/directive_test.go @@ -0,0 +1,217 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + "github.com/openrewrite/rewrite/rewrite-go/pkg/printer" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +// Step 3 of AnnotationService rollout: the parser extracts `//go:` +// and `//lint:` directives from leading comments above top-level +// MethodDeclaration / TypeDecl / VariableDeclarations into +// LeadingAnnotations. The printer reassembles them as `//<name> <args>` +// lines on roundtrip. + +func parseAndFindMethod(t *testing.T, src, name string) *tree.MethodDeclaration { + t.Helper() + cu, err := parser.NewGoParser().Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + for _, rp := range cu.Statements { + if md, ok := rp.Element.(*tree.MethodDeclaration); ok && md.Name != nil && md.Name.Name == name { + return md + } + } + t.Fatalf("method %q not found", name) + return nil +} + +func parseAndFindType(t *testing.T, src, name string) *tree.TypeDecl { + t.Helper() + cu, err := parser.NewGoParser().Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + for _, rp := range cu.Statements { + if td, ok := rp.Element.(*tree.TypeDecl); ok && td.Name != nil && td.Name.Name == name { + return td + } + } + t.Fatalf("type %q not found", name) + return nil +} + +func parseAndFindVar(t *testing.T, src string) *tree.VariableDeclarations { + t.Helper() + cu, err := parser.NewGoParser().Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + for _, rp := range cu.Statements { + if vd, ok := rp.Element.(*tree.VariableDeclarations); ok { + return vd + } + } + t.Fatalf("no var declaration found") + return nil +} + +func TestDirective_BareGoNoinline(t *testing.T) { + src := "package main\n\n//go:noinline\nfunc slow() {}\n" + md := parseAndFindMethod(t, src, "slow") + if got := len(md.LeadingAnnotations); got != 1 { + t.Fatalf("LeadingAnnotations: got %d, want 1", got) + } + ann := md.LeadingAnnotations[0] + if id, _ := ann.AnnotationType.(*tree.Identifier); id == nil || id.Name != "go:noinline" { + t.Errorf("AnnotationType: got %+v, want Identifier{Name:\"go:noinline\"}", ann.AnnotationType) + } + if ann.Arguments != nil { + t.Errorf("Arguments: got %+v, want nil for bare directive", ann.Arguments) + } +} + +func TestDirective_GoLinknameWithArgs(t *testing.T) { + src := "package main\n\n//go:linkname x runtime.x\nvar x int = 1\n" + vd := parseAndFindVar(t, src) + if got := len(vd.LeadingAnnotations); got != 1 { + t.Fatalf("LeadingAnnotations: got %d, want 1", got) + } + ann := vd.LeadingAnnotations[0] + if id, _ := ann.AnnotationType.(*tree.Identifier); id == nil || id.Name != "go:linkname" { + t.Errorf("AnnotationType: got %+v, want Identifier{Name:\"go:linkname\"}", ann.AnnotationType) + } + if ann.Arguments == nil || len(ann.Arguments.Elements) != 1 { + t.Fatalf("Arguments: got %+v, want one Literal", ann.Arguments) + } + lit, _ := ann.Arguments.Elements[0].Element.(*tree.Literal) + if lit == nil || lit.Source != "x runtime.x" { + t.Errorf("Args: got %+v, want \"x runtime.x\"", lit) + } +} + +func TestDirective_MultipleDirectivesOnFunc(t *testing.T) { + src := "package main\n\n//go:noinline\n//go:nosplit\nfunc slow() {}\n" + md := parseAndFindMethod(t, src, "slow") + if got := len(md.LeadingAnnotations); got != 2 { + t.Fatalf("LeadingAnnotations: got %d, want 2", got) + } + if md.LeadingAnnotations[0].AnnotationType.(*tree.Identifier).Name != "go:noinline" { + t.Errorf("[0]: got %+v", md.LeadingAnnotations[0].AnnotationType) + } + if md.LeadingAnnotations[1].AnnotationType.(*tree.Identifier).Name != "go:nosplit" { + t.Errorf("[1]: got %+v", md.LeadingAnnotations[1].AnnotationType) + } +} + +func TestDirective_LintIgnoreOnType(t *testing.T) { + src := "package main\n\n//lint:ignore U1000 unused but kept\ntype Foo struct{}\n" + td := parseAndFindType(t, src, "Foo") + if got := len(td.LeadingAnnotations); got != 1 { + t.Fatalf("LeadingAnnotations: got %d, want 1", got) + } + ann := td.LeadingAnnotations[0] + if ann.AnnotationType.(*tree.Identifier).Name != "lint:ignore" { + t.Errorf("AnnotationType: got %+v", ann.AnnotationType) + } + if ann.Arguments == nil { + t.Fatal("Arguments: got nil") + } + if got := ann.Arguments.Elements[0].Element.(*tree.Literal).Source; got != "U1000 unused but kept" { + t.Errorf("Args: got %q, want %q", got, "U1000 unused but kept") + } +} + +func TestDirective_RegularCommentNotExtracted(t *testing.T) { + src := "package main\n\n// regular doc comment\nfunc f() {}\n" + md := parseAndFindMethod(t, src, "f") + if got := len(md.LeadingAnnotations); got != 0 { + t.Errorf("LeadingAnnotations: got %d, want 0 — regular doc comments stay as comments", got) + } +} + +func TestDirective_DirectivePrecedingRegularComment(t *testing.T) { + // Directive first, regular comment after, then func: the directive + // is extracted, the regular comment stays in the func's Prefix. + src := "package main\n\n//go:noinline\n// regular doc\nfunc f() {}\n" + md := parseAndFindMethod(t, src, "f") + if got := len(md.LeadingAnnotations); got != 1 { + t.Fatalf("LeadingAnnotations: got %d, want 1", got) + } + if got := len(md.Prefix.Comments); got != 1 { + t.Errorf("Prefix.Comments: got %d, want 1 (the regular doc)", got) + } +} + +func TestDirective_RegularCommentBeforeDirectiveStops(t *testing.T) { + // When a regular comment appears BEFORE the directive, extraction + // stops at the regular comment — the directive stays in Prefix. + src := "package main\n\n// regular doc\n//go:noinline\nfunc f() {}\n" + md := parseAndFindMethod(t, src, "f") + if got := len(md.LeadingAnnotations); got != 0 { + t.Errorf("LeadingAnnotations: got %d, want 0 (extraction halts at first non-directive)", got) + } +} + +func TestDirective_RoundtripFunc(t *testing.T) { + src := "package main\n\n//go:noinline\n//go:nosplit\nfunc slow() {}\n" + cu, err := parser.NewGoParser().Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if got := printer.Print(cu); got != src { + t.Errorf("roundtrip mismatch\nexpected: %q\nactual: %q", src, got) + } +} + +func TestDirective_RoundtripType(t *testing.T) { + src := "package main\n\n//go:generate go run gen.go\ntype Foo struct{}\n" + cu, err := parser.NewGoParser().Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if got := printer.Print(cu); got != src { + t.Errorf("roundtrip mismatch\nexpected: %q\nactual: %q", src, got) + } +} + +func TestDirective_RoundtripVar(t *testing.T) { + src := "package main\n\n//go:linkname x runtime.x\nvar x int = 1\n" + cu, err := parser.NewGoParser().Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if got := printer.Print(cu); got != src { + t.Errorf("roundtrip mismatch\nexpected: %q\nactual: %q", src, got) + } +} + +func TestDirective_RoundtripMixed(t *testing.T) { + src := "package main\n\n//go:noinline\n//go:nosplit\nfunc slow() { _ = 1 }\n\n//go:generate go run gen.go\ntype Foo struct{}\n\n//go:linkname x runtime.x\nvar x int = 1\n" + cu, err := parser.NewGoParser().Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if got := printer.Print(cu); got != src { + t.Errorf("roundtrip mismatch\nexpected: %q\nactual: %q", src, got) + } +} diff --git a/rewrite-go/test/go_project_test.go b/rewrite-go/test/go_project_test.go new file mode 100644 index 00000000000..b527d844e8d --- /dev/null +++ b/rewrite-go/test/go_project_test.go @@ -0,0 +1,104 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/test" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +// TestGoProjectTagsGoSiblingsButNotMod confirms the Go-side test harness +// matches the Java-side Assertions.goProject(...) shape: a project wrapper +// tags every .go sibling with a tree.GoProject marker, and the go.mod +// sibling round-trips verbatim (Go-side go.mod parsing is a follow-up). +func TestGoProjectTagsGoSiblingsButNotMod(t *testing.T) { + goSrc := test.Golang(` + package main + + func main() {} + `) + goSrc.AfterRecipe = func(t *testing.T, cu *tree.CompilationUnit) { + t.Helper() + project, ok := findGoProject(cu.Markers) + if !ok { + t.Fatal("expected GoProject marker on .go file but none was attached") + } + if project.ProjectName != "foo" { + t.Fatalf("expected GoProject name=%q, got %q", "foo", project.ProjectName) + } + } + + spec := test.NewRecipeSpec() + spec.RewriteRun(t, + test.GoProject("foo", + test.GoMod(` + module example.com/foo + + go 1.22 + `), + goSrc, + ), + ) +} + +// TestGoProjectMixesWithBareGolangSpecs confirms a single RewriteRun call +// can take both project wrappers and bare Golang(...) sources side-by-side. +// The bare source carries no GoProject marker. +func TestGoProjectMixesWithBareGolangSpecs(t *testing.T) { + wrapped := test.Golang(` + package main + + func main() {} + `) + wrapped.AfterRecipe = func(t *testing.T, cu *tree.CompilationUnit) { + t.Helper() + if _, ok := findGoProject(cu.Markers); !ok { + t.Fatal("wrapped source should carry GoProject marker") + } + } + + bare := test.Golang(` + package main + + func main() {} + `) + bare.AfterRecipe = func(t *testing.T, cu *tree.CompilationUnit) { + t.Helper() + if _, ok := findGoProject(cu.Markers); ok { + t.Fatal("bare source should NOT carry GoProject marker") + } + } + + spec := test.NewRecipeSpec() + spec.RewriteRun(t, + test.GoProject("foo", wrapped), + bare, + ) +} + +// findGoProject is a small lookup helper that mirrors what real recipes +// would do: scan a tree's Markers for a GoProject and return it. +func findGoProject(m tree.Markers) (tree.GoProject, bool) { + for _, e := range m.Entries { + if p, ok := e.(tree.GoProject); ok { + return p, true + } + } + return tree.GoProject{}, false +} diff --git a/rewrite-go/test/gomod_conformance_test.go b/rewrite-go/test/gomod_conformance_test.go new file mode 100644 index 00000000000..d81d4266120 --- /dev/null +++ b/rewrite-go/test/gomod_conformance_test.go @@ -0,0 +1,211 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "encoding/json" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +// conformanceShape is the canonical JSON form used by both Java and Go +// conformance tests. Field names and ordering MUST stay in sync with the +// Java GoModConformanceTest's matching record / class. +type conformanceShape struct { + ModulePath string `json:"modulePath"` + GoVersion string `json:"goVersion"` + Toolchain string `json:"toolchain"` + Requires []conformanceRequire `json:"requires"` + Replaces []conformanceReplace `json:"replaces"` + Excludes []conformanceExclude `json:"excludes"` + Retracts []conformanceRetract `json:"retracts"` + ResolvedDependencies []conformanceResolvedDep `json:"resolvedDependencies"` +} + +type conformanceRequire struct { + ModulePath string `json:"modulePath"` + Version string `json:"version"` + Indirect bool `json:"indirect"` +} + +type conformanceReplace struct { + OldPath string `json:"oldPath"` + OldVersion *string `json:"oldVersion"` + NewPath string `json:"newPath"` + NewVersion *string `json:"newVersion"` +} + +type conformanceExclude struct { + ModulePath string `json:"modulePath"` + Version string `json:"version"` +} + +type conformanceRetract struct { + VersionRange string `json:"versionRange"` + Rationale *string `json:"rationale"` +} + +type conformanceResolvedDep struct { + ModulePath string `json:"modulePath"` + Version string `json:"version"` + ModuleHash *string `json:"moduleHash"` + GoModHash *string `json:"goModHash"` +} + +func toConformance(mrr *tree.GoResolutionResult) conformanceShape { + out := conformanceShape{ + ModulePath: mrr.ModulePath, + GoVersion: mrr.GoVersion, + Toolchain: mrr.Toolchain, + Requires: []conformanceRequire{}, + Replaces: []conformanceReplace{}, + Excludes: []conformanceExclude{}, + Retracts: []conformanceRetract{}, + ResolvedDependencies: []conformanceResolvedDep{}, + } + for _, r := range mrr.Requires { + out.Requires = append(out.Requires, conformanceRequire{ + ModulePath: r.ModulePath, + Version: r.Version, + Indirect: r.Indirect, + }) + } + for _, r := range mrr.Replaces { + out.Replaces = append(out.Replaces, conformanceReplace{ + OldPath: r.OldPath, + OldVersion: nilIfEmpty(r.OldVersion), + NewPath: r.NewPath, + NewVersion: nilIfEmpty(r.NewVersion), + }) + } + for _, e := range mrr.Excludes { + out.Excludes = append(out.Excludes, conformanceExclude{ + ModulePath: e.ModulePath, + Version: e.Version, + }) + } + for _, r := range mrr.Retracts { + out.Retracts = append(out.Retracts, conformanceRetract{ + VersionRange: r.VersionRange, + Rationale: nilIfEmpty(r.Rationale), + }) + } + for _, d := range mrr.ResolvedDependencies { + out.ResolvedDependencies = append(out.ResolvedDependencies, conformanceResolvedDep{ + ModulePath: d.ModulePath, + Version: d.Version, + ModuleHash: nilIfEmpty(d.ModuleHash), + GoModHash: nilIfEmpty(d.GoModHash), + }) + } + return out +} + +func nilIfEmpty(s string) *string { + if s == "" { + return nil + } + return &s +} + +// TestGoModConformanceCorpus iterates every .gomod under +// src/test/resources/gomod-conformance/, parses it (with sibling .gosum if +// present), and compares the result to the corresponding .gomod.json +// golden. The Java GoModConformanceTest runs the same corpus. +func TestGoModConformanceCorpus(t *testing.T) { + corpusDir := filepath.Join("..", "src", "test", "resources", "gomod-conformance") + entries, err := os.ReadDir(corpusDir) + if err != nil { + t.Fatalf("read corpus dir: %v", err) + } + cases := 0 + for _, ent := range entries { + if ent.IsDir() || !strings.HasSuffix(ent.Name(), ".gomod") { + continue + } + cases++ + caseName := strings.TrimSuffix(ent.Name(), ".gomod") + t.Run(caseName, func(t *testing.T) { + modContent := mustRead(t, filepath.Join(corpusDir, ent.Name())) + mrr, err := parser.ParseGoMod("go.mod", modContent) + if err != nil { + t.Fatalf("parse: %v", err) + } + if sumPath := filepath.Join(corpusDir, caseName+".gosum"); fileExists(sumPath) { + mrr.ResolvedDependencies = parser.ParseGoSum(mustRead(t, sumPath)) + } + actual := toConformance(mrr) + + goldenContent := mustRead(t, filepath.Join(corpusDir, caseName+".gomod.json")) + var expected conformanceShape + if err := json.Unmarshal([]byte(goldenContent), &expected); err != nil { + t.Fatalf("unmarshal golden: %v", err) + } + normalize(&expected) + + if !reflect.DeepEqual(actual, expected) { + actualJSON, _ := json.MarshalIndent(actual, "", " ") + expectedJSON, _ := json.MarshalIndent(expected, "", " ") + t.Errorf("conformance mismatch\nactual:\n%s\n\nexpected:\n%s", + string(actualJSON), string(expectedJSON)) + } + }) + } + if cases == 0 { + t.Fatal("no .gomod cases found in corpus") + } +} + +// normalize replaces nil slices with empty slices so reflect.DeepEqual +// matches the conformance shape produced by toConformance. +func normalize(c *conformanceShape) { + if c.Requires == nil { + c.Requires = []conformanceRequire{} + } + if c.Replaces == nil { + c.Replaces = []conformanceReplace{} + } + if c.Excludes == nil { + c.Excludes = []conformanceExclude{} + } + if c.Retracts == nil { + c.Retracts = []conformanceRetract{} + } + if c.ResolvedDependencies == nil { + c.ResolvedDependencies = []conformanceResolvedDep{} + } +} + +func mustRead(t *testing.T, path string) string { + t.Helper() + b, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read %s: %v", path, err) + } + return string(b) +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} diff --git a/rewrite-go/test/gomod_parser_test.go b/rewrite-go/test/gomod_parser_test.go new file mode 100644 index 00000000000..68e051ef77e --- /dev/null +++ b/rewrite-go/test/gomod_parser_test.go @@ -0,0 +1,214 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + "github.com/openrewrite/rewrite/rewrite-go/pkg/test" +) + +func TestParseGoModBasicFields(t *testing.T) { + mrr, err := parser.ParseGoMod("go.mod", `module example.com/foo + +go 1.22 + +toolchain go1.22.3 + +require ( + github.com/x/y v1.2.3 + github.com/z/w v0.5.0 // indirect +) + +replace github.com/x/y => github.com/x/y v1.2.4 + +exclude github.com/bad v0.0.1 + +retract v0.0.5 // accidentally deleted main.go +retract [v1.0.0, v1.0.5] +`) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + if mrr.ModulePath != "example.com/foo" { + t.Errorf("ModulePath: want %q, got %q", "example.com/foo", mrr.ModulePath) + } + if mrr.GoVersion != "1.22" { + t.Errorf("GoVersion: want %q, got %q", "1.22", mrr.GoVersion) + } + if mrr.Toolchain != "go1.22.3" { + t.Errorf("Toolchain: want %q, got %q", "go1.22.3", mrr.Toolchain) + } + if len(mrr.Requires) != 2 { + t.Fatalf("Requires len: want 2, got %d", len(mrr.Requires)) + } + if mrr.Requires[0].ModulePath != "github.com/x/y" || mrr.Requires[0].Version != "v1.2.3" || mrr.Requires[0].Indirect { + t.Errorf("Requires[0]: %+v", mrr.Requires[0]) + } + if mrr.Requires[1].ModulePath != "github.com/z/w" || !mrr.Requires[1].Indirect { + t.Errorf("Requires[1]: %+v", mrr.Requires[1]) + } + if len(mrr.Replaces) != 1 || mrr.Replaces[0].OldPath != "github.com/x/y" || mrr.Replaces[0].NewVersion != "v1.2.4" { + t.Errorf("Replaces: %+v", mrr.Replaces) + } + if len(mrr.Excludes) != 1 || mrr.Excludes[0].ModulePath != "github.com/bad" { + t.Errorf("Excludes: %+v", mrr.Excludes) + } + if len(mrr.Retracts) != 2 { + t.Fatalf("Retracts len: want 2, got %d", len(mrr.Retracts)) + } + if mrr.Retracts[0].VersionRange != "v0.0.5" || mrr.Retracts[0].Rationale == "" { + t.Errorf("Retracts[0]: %+v", mrr.Retracts[0]) + } + if mrr.Retracts[1].VersionRange != "[v1.0.0, v1.0.5]" { + t.Errorf("Retracts[1] range: %+v", mrr.Retracts[1]) + } +} + +func TestGoModSourceSpecCarriesParsedMarker(t *testing.T) { + spec := test.GoMod(` + module example.com/foo + + go 1.22 + + require github.com/x/y v1.2.3 + `) + mrr := test.FindGoResolutionResult(spec) + if mrr == nil { + t.Fatal("expected GoResolutionResult marker on the GoMod SourceSpec") + } + if mrr.ModulePath != "example.com/foo" { + t.Errorf("ModulePath: want %q, got %q", "example.com/foo", mrr.ModulePath) + } + if r := mrr.FindRequire("github.com/x/y"); r == nil || r.Version != "v1.2.3" { + t.Errorf("FindRequire: %+v", r) + } +} + +func TestGoModBadInputDoesNotAttachMarker(t *testing.T) { + spec := test.GoMod(`this is not a valid go.mod`) + if mrr := test.FindGoResolutionResult(spec); mrr != nil { + t.Errorf("expected no GoResolutionResult marker on malformed input, got %+v", mrr) + } +} + +func TestParseGoSumBasic(t *testing.T) { + resolved := parser.ParseGoSum(`github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM= +golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU= +`) + if len(resolved) != 2 { + t.Fatalf("ParseGoSum: want 2 entries, got %d (%+v)", len(resolved), resolved) + } + uuid := resolved[0] + if uuid.ModulePath != "github.com/google/uuid" || uuid.Version != "v1.6.0" { + t.Errorf("entry[0]: want github.com/google/uuid@v1.6.0, got %+v", uuid) + } + if uuid.ModuleHash != "h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=" { + t.Errorf("entry[0].ModuleHash: %q", uuid.ModuleHash) + } + if uuid.GoModHash != "h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=" { + t.Errorf("entry[0].GoModHash: %q", uuid.GoModHash) + } + if mod := resolved[1]; mod.ModulePath != "golang.org/x/mod" || mod.Version != "v0.35.0" { + t.Errorf("entry[1]: want golang.org/x/mod@v0.35.0, got %+v", mod) + } +} + +func TestParseGoSumOnlyGoModHashRecorded(t *testing.T) { + // When go.sum has only the /go.mod line for a dependency (i.e. the + // module zip wasn't downloaded), ModuleHash is empty but GoModHash is + // set. This is what go.sum looks like for indirect deps that only the + // build graph knows about. + resolved := parser.ParseGoSum(`example.com/indirect v1.0.0/go.mod h1:abc123= +`) + if len(resolved) != 1 { + t.Fatalf("want 1 entry, got %d", len(resolved)) + } + if resolved[0].ModuleHash != "" { + t.Errorf("ModuleHash: want empty, got %q", resolved[0].ModuleHash) + } + if resolved[0].GoModHash != "h1:abc123=" { + t.Errorf("GoModHash: %q", resolved[0].GoModHash) + } +} + +func TestParseGoSumMalformedLineSkipped(t *testing.T) { + // A malformed line in the middle of a valid go.sum should be skipped + // (logged, not fatal) and adjacent entries should still parse. + resolved := parser.ParseGoSum(`github.com/a/b v1.0.0 h1:hashA= +this is not a valid go.sum line +github.com/c/d v2.0.0 h1:hashC= +`) + if len(resolved) != 2 { + t.Fatalf("want 2 entries (malformed skipped), got %d (%+v)", len(resolved), resolved) + } + if resolved[0].ModulePath != "github.com/a/b" || resolved[1].ModulePath != "github.com/c/d" { + t.Errorf("unexpected modules: %+v", resolved) + } +} + +func TestParseGoSumEmptyInput(t *testing.T) { + if got := parser.ParseGoSum(""); got != nil { + t.Errorf("want nil for empty input, got %+v", got) + } +} + +func TestGoProjectMergesGoSumIntoGoModMarker(t *testing.T) { + // Sibling go.mod + go.sum inside a GoProject: harness should merge + // the parsed ResolvedDependencies into the GoResolutionResult marker + // at expansion time. + expanded := test.GoProject("foo", + test.GoMod(` + module example.com/foo + + go 1.22 + + require github.com/google/uuid v1.6.0 + `), + test.GoSum(` + github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= + github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= + `), + ).Expand() + + var modSpec *test.SourceSpec + for i, s := range expanded { + if s.Path == "go.mod" { + modSpec = &expanded[i] + } + } + if modSpec == nil { + t.Fatal("no go.mod spec in expanded project") + } + mrr := test.FindGoResolutionResult(*modSpec) + if mrr == nil { + t.Fatal("no GoResolutionResult marker on go.mod") + } + if len(mrr.ResolvedDependencies) != 1 { + t.Fatalf("want 1 resolved dep, got %d (%+v)", len(mrr.ResolvedDependencies), mrr.ResolvedDependencies) + } + rd := mrr.ResolvedDependencies[0] + if rd.ModulePath != "github.com/google/uuid" || rd.Version != "v1.6.0" { + t.Errorf("unexpected resolved dep: %+v", rd) + } + if rd.ModuleHash == "" || rd.GoModHash == "" { + t.Errorf("expected both ModuleHash and GoModHash populated: %+v", rd) + } +} diff --git a/rewrite-go/test/import_recipes_test.go b/rewrite-go/test/import_recipes_test.go new file mode 100644 index 00000000000..be97b90ee0a --- /dev/null +++ b/rewrite-go/test/import_recipes_test.go @@ -0,0 +1,306 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe/golang" + . "github.com/openrewrite/rewrite/rewrite-go/pkg/test" +) + +// strPtr returns a pointer to s for use as an Alias option. +func strPtr(s string) *string { return &s } + +// ---- AddImport ---- + +func TestAddImport_NoOpWhenAlreadyImported(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.AddImport{PackagePath: "fmt"}) + spec.RewriteRun(t, + Golang(` + package main + + import "fmt" + + func main() { fmt.Println("hi") } + `), + ) +} + +func TestAddImport_AddsToExistingBlock(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.AddImport{PackagePath: "strings"}) + before := ` + package main + + import ( + "fmt" + ) + + func main() { fmt.Println("hi") } + ` + after := ` + package main + + import ( + "fmt" + "strings" + ) + + func main() { fmt.Println("hi") } + ` + spec.RewriteRun(t, Golang(before, after)) +} + +func TestAddImport_AddsToFileWithNoImports(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.AddImport{PackagePath: "fmt"}) + before := ` + package main + + func main() {} + ` + after := ` + package main + + import "fmt" + + func main() {} + ` + spec.RewriteRun(t, Golang(before, after)) +} + +func TestAddImport_OnlyIfReferenced_NoOpWhenNotReferenced(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.AddImport{ + PackagePath: "github.com/x/y", + OnlyIfReferenced: true, + }) + spec.RewriteRun(t, + Golang(` + package main + + func main() {} + `), + ) +} + +func TestAddImport_AliasedFormDoesNotMatchRegular(t *testing.T) { + // `import yy "github.com/x/y"` is present; AddImport(github.com/x/y, alias=nil) + // should treat it as MISSING the regular form because the alias differs. + // (This mirrors the Java AddImport semantics for explicit alias asks.) + spec := NewRecipeSpec().WithRecipe(&golang.AddImport{ + PackagePath: "github.com/x/y", + Alias: strPtr("yy"), + }) + spec.RewriteRun(t, + Golang(` + package main + + import yy "github.com/x/y" + + func main() { _ = yy.Hello() } + `), + ) +} + +// ---- RemoveImport ---- + +func TestRemoveImport_DeletesMatching(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.RemoveImport{PackagePath: "strings"}) + before := ` + package main + + import ( + "fmt" + "strings" + ) + + func main() { fmt.Println(strings.ToUpper("hi")) } + ` + after := ` + package main + + import ( + "fmt" + ) + + func main() { fmt.Println(strings.ToUpper("hi")) } + ` + spec.RewriteRun(t, Golang(before, after)) +} + +func TestRemoveImport_NoOpWhenAbsent(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.RemoveImport{PackagePath: "strings"}) + spec.RewriteRun(t, + Golang(` + package main + + import "fmt" + + func main() { fmt.Println("hi") } + `), + ) +} + +// ---- RemoveUnusedImports ---- + +func TestRemoveUnusedImports_DropsUnreferenced(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.RemoveUnusedImports{}) + before := ` + package main + + import ( + "fmt" + "strings" + ) + + func main() { fmt.Println("hi") } + ` + after := ` + package main + + import ( + "fmt" + ) + + func main() { fmt.Println("hi") } + ` + spec.RewriteRun(t, Golang(before, after)) +} + +func TestRemoveUnusedImports_PreservesBlankImports(t *testing.T) { + // Blank imports stay — they exist for init() side-effects. + spec := NewRecipeSpec().WithRecipe(&golang.RemoveUnusedImports{}) + spec.RewriteRun(t, + Golang(` + package main + + import ( + _ "github.com/x/y" + "fmt" + ) + + func main() { fmt.Println("hi") } + `), + ) +} + +func TestRemoveUnusedImports_NoOpWhenAllUsed(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.RemoveUnusedImports{}) + spec.RewriteRun(t, + Golang(` + package main + + import ( + "fmt" + "strings" + ) + + func main() { fmt.Println(strings.ToUpper("hi")) } + `), + ) +} + +// ---- OrderImports ---- + +func TestOrderImports_IdempotentOnAlreadyOrdered(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.OrderImports{}) + spec.RewriteRun(t, + Golang(` + package main + + import ( + "fmt" + + "github.com/x/y" + ) + + func main() { + fmt.Println("hi") + _ = y.Hello() + } + `), + ) +} + +func TestOrderImports_ReorderJumbledBlock(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.OrderImports{}) + before := ` + package main + + import ( + "github.com/x/y" + "fmt" + ) + + func main() { + fmt.Println("hi") + _ = y.Hello() + } + ` + after := ` + package main + + import ( + "fmt" + + "github.com/x/y" + ) + + func main() { + fmt.Println("hi") + _ = y.Hello() + } + ` + spec.RewriteRun(t, Golang(before, after)) +} + +// goimports orders within each group alphabetically and inserts a blank +// line between groups. Run on a 3-import block that needs both: cross-group +// reorder of "github.com/x/y" → tail, and within-stdlib reorder of +// "strings" before "fmt". +func TestOrderImports_AlphabeticalWithinGroupAndBlankLineBetween(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.OrderImports{}) + before := ` + package main + + import ( + "github.com/x/y" + "strings" + "fmt" + ) + + func main() { + fmt.Println(strings.ToUpper("hi")) + _ = y.Hello() + } + ` + after := ` + package main + + import ( + "fmt" + "strings" + + "github.com/x/y" + ) + + func main() { + fmt.Println(strings.ToUpper("hi")) + _ = y.Hello() + } + ` + spec.RewriteRun(t, Golang(before, after)) +} diff --git a/rewrite-go/test/import_service_test.go b/rewrite-go/test/import_service_test.go new file mode 100644 index 00000000000..35e35789f20 --- /dev/null +++ b/rewrite-go/test/import_service_test.go @@ -0,0 +1,135 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe/golang" + . "github.com/openrewrite/rewrite/rewrite-go/pkg/test" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// TestImportService_RegisteredOnInit verifies that simply importing +// pkg/recipe/golang triggers registration of *golang.ImportService. +func TestImportService_RegisteredOnInit(t *testing.T) { + svc := recipe.Service[*golang.ImportService](nil) + if svc == nil { + t.Fatal("recipe.Service returned nil for *golang.ImportService") + } +} + +// addStringsImportRecipe is a recipe that uses ImportService via +// DoAfterVisit to add a "strings" import as a side-effect of finding a +// MethodInvocation matching a target name. Demonstrates the canonical +// composition pattern. +type addStringsImportRecipe struct { + recipe.Base +} + +func (r *addStringsImportRecipe) Name() string { return "test.AddStringsImport" } +func (r *addStringsImportRecipe) DisplayName() string { return "Add strings import via service" } +func (r *addStringsImportRecipe) Description() string { return "Test recipe." } + +func (r *addStringsImportRecipe) Editor() recipe.TreeVisitor { + return visitor.Init(&addStringsVisitor{}) +} + +type addStringsVisitor struct{ visitor.GoVisitor } + +func (v *addStringsVisitor) VisitMethodInvocation(mi *tree.MethodInvocation, p any) tree.J { + mi = v.GoVisitor.VisitMethodInvocation(mi, p).(*tree.MethodInvocation) + if mi.Name != nil && mi.Name.Name == "Println" { + // Recipe edits don't change the call here — just queue a + // follow-up visitor that adds the import. The harness drains + // the after-visits queue after the main visit completes. + svc := recipe.Service[*golang.ImportService](nil) + v.DoAfterVisit(svc.AddImportVisitor("strings", nil, false)) + } + return mi +} + +// TestImportService_AddImportViaDoAfterVisit demonstrates the full +// pattern: a recipe edits a method body, then queues an AddImport +// visitor via DoAfterVisit. The after-visit drain in the recipe +// runner applies it without explicit caller orchestration. +func TestImportService_AddImportViaDoAfterVisit(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&addStringsImportRecipe{}) + before := ` + package main + + import "fmt" + + func main() { fmt.Println("hi") } + ` + after := ` + package main + + import ( + "fmt" + "strings" + ) + + func main() { fmt.Println("hi") } + ` + spec.RewriteRun(t, Golang(before, after)) +} + +// TestImportService_RemoveImport via DoAfterVisit. +func TestImportService_RemoveImportViaDoAfterVisit(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&removeFmtImportRecipe{}) + before := ` + package main + + import ( + "fmt" + "strings" + ) + + func main() { strings.ToUpper("hi") } + ` + after := ` + package main + + import ( + "strings" + ) + + func main() { strings.ToUpper("hi") } + ` + spec.RewriteRun(t, Golang(before, after)) +} + +type removeFmtImportRecipe struct{ recipe.Base } + +func (r *removeFmtImportRecipe) Name() string { return "test.RemoveFmtImport" } +func (r *removeFmtImportRecipe) DisplayName() string { return "Remove fmt import via service" } +func (r *removeFmtImportRecipe) Description() string { return "Test recipe." } +func (r *removeFmtImportRecipe) Editor() recipe.TreeVisitor { + return visitor.Init(&removeFmtVisitor{}) +} + +type removeFmtVisitor struct{ visitor.GoVisitor } + +func (v *removeFmtVisitor) VisitCompilationUnit(cu *tree.CompilationUnit, p any) tree.J { + cu = v.GoVisitor.VisitCompilationUnit(cu, p).(*tree.CompilationUnit) + svc := recipe.Service[*golang.ImportService](nil) + v.DoAfterVisit(svc.RemoveImportVisitor("fmt")) + return cu +} diff --git a/rewrite-go/test/multi_file_package_test.go b/rewrite-go/test/multi_file_package_test.go new file mode 100644 index 00000000000..b9ffcd2738e --- /dev/null +++ b/rewrite-go/test/multi_file_package_test.go @@ -0,0 +1,85 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + "github.com/openrewrite/rewrite/rewrite-go/pkg/test" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +// TestParsePackageResolvesCrossFileSymbols directly exercises ParsePackage +// without the test harness: file A calls a function defined in file B, +// both in the same package. The shared types.Info should populate B's +// definition AND A's reference with the same types.Object. +func TestParsePackageResolvesCrossFileSymbols(t *testing.T) { + cus, err := parser.NewGoParser().ParsePackage([]parser.FileInput{ + {Path: "main.go", Content: "package main\n\nfunc main() { helper() }\n"}, + {Path: "helper.go", Content: "package main\n\nfunc helper() {}\n"}, + }) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if len(cus) != 2 { + t.Fatalf("expected 2 CUs, got %d", len(cus)) + } + + mainTypes := collectIdentTypes(cus[0]) + if mainTypes["helper"] == nil { + t.Errorf("expected `helper` reference in main.go to have a non-nil Type after multi-file parse; got nil (cross-file resolution still broken)") + } +} + +// TestGoProjectMultiFilePackageResolves is the harness-level integration: +// inside a GoProject with a go.mod, two files at the project root both +// declare `package main` and reference each other. +func TestGoProjectMultiFilePackageResolves(t *testing.T) { + mainSrc := test.Golang(` + package main + + func main() { helper() } + `).WithPath("main.go") + mainSrc.AfterRecipe = func(t *testing.T, cu *tree.CompilationUnit) { + t.Helper() + ids := collectIdentTypes(cu) + if ids["helper"] == nil { + t.Errorf("`helper` reference should have a resolved Type when parsed alongside helper.go; got nil") + } + } + + helperSrc := test.Golang(` + package main + + func helper() {} + `).WithPath("helper.go") + + spec := test.NewRecipeSpec() + spec.RewriteRun(t, + test.GoProject("foo", + test.GoMod(` + module example.com/foo + + go 1.22 + `), + helperSrc, + mainSrc, + ), + ) +} + diff --git a/rewrite-go/test/naming_service_test.go b/rewrite-go/test/naming_service_test.go new file mode 100644 index 00000000000..59ff80e1488 --- /dev/null +++ b/rewrite-go/test/naming_service_test.go @@ -0,0 +1,119 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe/golang" +) + +func TestNamingService_RegisteredOnInit(t *testing.T) { + svc := recipe.Service[*golang.NamingService](nil) + if svc == nil { + t.Fatal("recipe.Service returned nil for *golang.NamingService") + } +} + +func TestNamingService_ToPascalCase(t *testing.T) { + svc := &golang.NamingService{} + cases := []struct{ in, want string }{ + {"fooBar", "FooBar"}, + {"foo", "Foo"}, + {"FooBar", "FooBar"}, // already exported + {"", ""}, + {"_priv", "_priv"}, // first rune isn't a letter; passthrough + {"αlpha", "Αlpha"}, // unicode (Greek alpha) + } + for _, c := range cases { + if got := svc.ToPascalCase(c.in); got != c.want { + t.Errorf("ToPascalCase(%q) = %q, want %q", c.in, got, c.want) + } + } +} + +func TestNamingService_ToCamelCase(t *testing.T) { + svc := &golang.NamingService{} + cases := []struct{ in, want string }{ + {"FooBar", "fooBar"}, + {"Foo", "foo"}, + {"foo", "foo"}, // already camel + {"", ""}, + {"Αlpha", "αlpha"}, // unicode + } + for _, c := range cases { + if got := svc.ToCamelCase(c.in); got != c.want { + t.Errorf("ToCamelCase(%q) = %q, want %q", c.in, got, c.want) + } + } +} + +func TestNamingService_IsExported(t *testing.T) { + svc := &golang.NamingService{} + cases := []struct { + in string + want bool + }{ + {"Foo", true}, + {"foo", false}, + {"_Foo", false}, // underscore is not an uppercase letter + {"", false}, + {"Αlpha", true}, // unicode upper + {"αlpha", false}, // unicode lower + } + for _, c := range cases { + if got := svc.IsExported(c.in); got != c.want { + t.Errorf("IsExported(%q) = %v, want %v", c.in, got, c.want) + } + } +} + +func TestNamingService_IsValidIdentifier(t *testing.T) { + svc := &golang.NamingService{} + cases := []struct { + in string + want bool + }{ + {"foo", true}, + {"Foo123", true}, + {"_x", true}, + {"123foo", false}, // can't start with digit + {"foo-bar", false}, + {"", false}, + {"func", false}, // reserved keyword + } + for _, c := range cases { + if got := svc.IsValidIdentifier(c.in); got != c.want { + t.Errorf("IsValidIdentifier(%q) = %v, want %v", c.in, got, c.want) + } + } +} + +func TestNamingService_IsPredeclared(t *testing.T) { + svc := &golang.NamingService{} + for _, name := range []string{"int", "string", "true", "false", "nil", "iota", "len", "make", "new", "any", "comparable", "min", "max", "clear", "error"} { + if !svc.IsPredeclared(name) { + t.Errorf("IsPredeclared(%q) = false, want true", name) + } + } + for _, name := range []string{"Foo", "foo", "MyType", "Println", "func", "if"} { + if svc.IsPredeclared(name) { + t.Errorf("IsPredeclared(%q) = true, want false", name) + } + } +} diff --git a/rewrite-go/test/project_type_attribution_test.go b/rewrite-go/test/project_type_attribution_test.go new file mode 100644 index 00000000000..84ac3e0c48f --- /dev/null +++ b/rewrite-go/test/project_type_attribution_test.go @@ -0,0 +1,130 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + "github.com/openrewrite/rewrite/rewrite-go/pkg/test" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// identTypeWalker collects each visited Identifier's Type into a map. +type identTypeWalker struct { + visitor.GoVisitor + types map[string]tree.JavaType +} + +func (v *identTypeWalker) VisitIdentifier(ident *tree.Identifier, p any) tree.J { + if ident.Type != nil { + v.types[ident.Name] = ident.Type + } + return ident +} + +// TestProjectImporterResolvesIntraProjectImport directly exercises the +// ProjectImporter without going through the test harness: a sub-package +// is registered, then the importer is asked for it and we inspect the +// resulting *types.Package. +func TestProjectImporterResolvesIntraProjectImport(t *testing.T) { + pi := parser.NewProjectImporter("example.com/foo", nil) + pi.AddSource("sub/sub.go", "package sub\n\nfunc Hello() string { return \"hi\" }\n") + + pkg, err := pi.Import("example.com/foo/sub") + if err != nil { + t.Fatalf("Import returned error: %v", err) + } + if pkg == nil { + t.Fatal("Import returned nil package") + } + if pkg.Name() != "sub" { + t.Errorf("package name: want %q, got %q", "sub", pkg.Name()) + } + if hello := pkg.Scope().Lookup("Hello"); hello == nil { + t.Fatal("expected sub.Hello to be defined in the resolved package") + } +} + +// TestProjectImporterFallsBackToStdlib confirms that paths the importer +// doesn't have sources for fall through to importer.Default() so stdlib +// imports keep working. +func TestProjectImporterFallsBackToStdlib(t *testing.T) { + pi := parser.NewProjectImporter("example.com/foo", nil) + pkg, err := pi.Import("fmt") + if err != nil { + t.Fatalf("stdlib fallback failed: %v", err) + } + if pkg == nil || pkg.Name() != "fmt" { + t.Fatalf("expected pkg=fmt, got %v", pkg) + } +} + +// TestGoProjectWiresImporterIntoHarness is the integration assertion: a +// project with go.mod + sub-package + main importing the sub gets type +// attribution on the import. +func TestGoProjectWiresImporterIntoHarness(t *testing.T) { + mainSrc := test.Golang(` + package main + + import "example.com/foo/sub" + + func main() { _ = sub.Hello() } + `).WithPath("main.go") + mainSrc.AfterRecipe = func(t *testing.T, cu *tree.CompilationUnit) { + t.Helper() + // Find an Identifier referencing "sub" or "Hello" and assert its + // Type came back resolved (non-nil). Without the project importer, + // these would all be nil because importer.Default() doesn't know + // about example.com/foo/sub. + identTypes := collectIdentTypes(cu) + if identTypes["sub"] == nil { + t.Errorf("expected `sub` identifier in main.go to have a resolved Type, got nil") + } + if identTypes["Hello"] == nil { + t.Errorf("expected `Hello` identifier in main.go to have a resolved Type, got nil") + } + } + + spec := test.NewRecipeSpec() + spec.RewriteRun(t, + test.GoProject("foo", + test.GoMod(` + module example.com/foo + + go 1.22 + `), + test.Golang(` + package sub + + func Hello() string { return "hi" } + `).WithPath("sub/sub.go"), + mainSrc, + ), + ) +} + +// collectIdentTypes walks the tree and returns a map of identifier name +// → its Type (whichever last assignment wins; sufficient for these tests). +func collectIdentTypes(cu *tree.CompilationUnit) map[string]tree.JavaType { + out := map[string]tree.JavaType{} + collector := &identTypeWalker{types: out} + visitor.Init(collector) + collector.Visit(cu, nil) + return out +} diff --git a/rewrite-go/test/rename_package_test.go b/rewrite-go/test/rename_package_test.go new file mode 100644 index 00000000000..d8a6cdbb8f6 --- /dev/null +++ b/rewrite-go/test/rename_package_test.go @@ -0,0 +1,225 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe/golang" + . "github.com/openrewrite/rewrite/rewrite-go/pkg/test" +) + +func TestRenamePackage_RewritesImportPath(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.RenamePackage{ + OldPackagePath: "github.com/old/foo", + NewPackagePath: "github.com/new/foo", + }) + before := ` + package main + + import "github.com/old/foo" + + func main() { _ = foo.Hello() } + ` + after := ` + package main + + import "github.com/new/foo" + + func main() { _ = foo.Hello() } + ` + spec.RewriteRun(t, Golang(before, after)) +} + +func TestRenamePackage_RewritesSubPackageImports(t *testing.T) { + // `import "old/foo/sub"` rewrites to `import "new/foo/sub"` when + // renaming `old/foo` to `new/foo`. + spec := NewRecipeSpec().WithRecipe(&golang.RenamePackage{ + OldPackagePath: "github.com/old/foo", + NewPackagePath: "github.com/new/foo", + }) + before := ` + package main + + import ( + "github.com/old/foo" + "github.com/old/foo/sub" + ) + + func main() { + _ = foo.A() + _ = sub.B() + } + ` + after := ` + package main + + import ( + "github.com/new/foo" + "github.com/new/foo/sub" + ) + + func main() { + _ = foo.A() + _ = sub.B() + } + ` + spec.RewriteRun(t, Golang(before, after)) +} + +func TestRenamePackage_LeavesUnrelatedImports(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.RenamePackage{ + OldPackagePath: "github.com/old/foo", + NewPackagePath: "github.com/new/foo", + }) + spec.RewriteRun(t, + Golang(` + package main + + import ( + "fmt" + "github.com/other/bar" + ) + + func main() { fmt.Println(bar.X) } + `), + ) +} + +func TestRenamePackage_IdempotentOnRenamedPath(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.RenamePackage{ + OldPackagePath: "github.com/old/foo", + NewPackagePath: "github.com/new/foo", + }) + spec.RewriteRun(t, + Golang(` + package main + + import "github.com/new/foo" + + func main() { _ = foo.Hello() } + `), + ) +} + +// Files that own the renamed package — i.e. live at the directory +// matching OldPackagePath under the module — must have their `package` +// declaration rewritten to the new last-segment name. This is the +// "cross-file scope" half of the recipe, distinct from import +// rewriting in caller files. +func TestRenamePackage_RewritesOwnedPackageDecl(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.RenamePackage{ + OldPackagePath: "example.com/myapp/internal/old", + NewPackagePath: "example.com/myapp/internal/new", + }) + + owned := Golang(` + package old + + func Util() {} + `, ` + package new + + func Util() {} + `).WithPath("internal/old/util.go") + + spec.RewriteRun(t, + GoProject("myapp", + GoMod(` + module example.com/myapp + + go 1.22 + `), + owned, + ), + ) +} + +// A file with a coincidentally-matching package name but living +// elsewhere in the module must NOT have its package declaration +// rewritten. fileBelongsTo gates this on the file's source-relative +// directory matching the candidate's module-relative path. +func TestRenamePackage_LeavesCoincidentPackageDecl(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.RenamePackage{ + OldPackagePath: "example.com/myapp/internal/old", + NewPackagePath: "example.com/myapp/internal/new", + }) + + // Same package name `old`, but at a different directory — must be + // left alone. + unrelated := Golang(` + package old + + func Other() {} + `).WithPath("vendor/something/old/x.go") + + spec.RewriteRun(t, + GoProject("myapp", + GoMod(` + module example.com/myapp + + go 1.22 + `), + unrelated, + ), + ) +} + +// End-to-end: a single recipe run handles both the owning file's +// package declaration AND the consumer's import path in one project. +func TestRenamePackage_RewritesProjectWide(t *testing.T) { + spec := NewRecipeSpec().WithRecipe(&golang.RenamePackage{ + OldPackagePath: "example.com/myapp/internal/old", + NewPackagePath: "example.com/myapp/internal/new", + }) + + owned := Golang(` + package old + + func Util() {} + `, ` + package new + + func Util() {} + `).WithPath("internal/old/util.go") + + consumer := Golang(` + package main + + import "example.com/myapp/internal/old" + + func main() { old.Util() } + `, ` + package main + + import "example.com/myapp/internal/new" + + func main() { old.Util() } + `).WithPath("cmd/app/main.go") + + spec.RewriteRun(t, + GoProject("myapp", + GoMod(` + module example.com/myapp + + go 1.22 + `), + owned, + consumer, + ), + ) +} diff --git a/rewrite-go/test/require_resolution_test.go b/rewrite-go/test/require_resolution_test.go new file mode 100644 index 00000000000..2957e80a99f --- /dev/null +++ b/rewrite-go/test/require_resolution_test.go @@ -0,0 +1,107 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + "github.com/openrewrite/rewrite/rewrite-go/pkg/test" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +func TestProjectImporterStubsRequiredModule(t *testing.T) { + pi := parser.NewProjectImporter("example.com/foo", nil) + pi.AddRequire("github.com/x/y") + + pkg, err := pi.Import("github.com/x/y") + if err != nil { + t.Fatalf("Import returned error: %v", err) + } + if pkg == nil { + t.Fatal("expected stub package, got nil") + } + if pkg.Path() != "github.com/x/y" { + t.Errorf("Path: want %q, got %q", "github.com/x/y", pkg.Path()) + } + if pkg.Name() != "y" { + t.Errorf("Name: want %q, got %q", "y", pkg.Name()) + } +} + +func TestProjectImporterStubMatchesSubPath(t *testing.T) { + pi := parser.NewProjectImporter("example.com/foo", nil) + pi.AddRequire("github.com/x/y") + + // `import "github.com/x/y/sub"` should also stub-resolve, because the + // require covers the whole module subtree. + pkg, err := pi.Import("github.com/x/y/sub") + if err != nil { + t.Fatalf("Import returned error: %v", err) + } + if pkg == nil { + t.Fatal("expected stub package for sub-path, got nil") + } + if pkg.Name() != "sub" { + t.Errorf("Name: want %q, got %q", "sub", pkg.Name()) + } +} + +func TestProjectImporterUnknownPathFallsThroughToError(t *testing.T) { + pi := parser.NewProjectImporter("example.com/foo", nil) + pi.AddRequire("github.com/x/y") + + // A module not in requires and not stdlib should not stub-resolve — + // importer.Default() returns an error for it. + if _, err := pi.Import("github.com/never/heard/of"); err == nil { + t.Errorf("expected error for unrequired non-stdlib path, got success") + } +} + +func TestGoProjectThirdPartyImportResolves(t *testing.T) { + mainSrc := test.Golang(` + package main + + import "github.com/x/y" + + func main() { _ = y.Hello() } + `).WithPath("main.go") + mainSrc.AfterRecipe = func(t *testing.T, cu *tree.CompilationUnit) { + t.Helper() + // `y` references the imported package; with the stub in place the + // identifier should now have a non-nil Type. Without require-driven + // stubbing this would be nil. + ids := collectIdentTypes(cu) + if ids["y"] == nil { + t.Errorf("expected `y` import identifier to have a non-nil Type via the require stub; got nil") + } + } + + spec := test.NewRecipeSpec() + spec.RewriteRun(t, + test.GoProject("foo", + test.GoMod(` + module example.com/foo + + go 1.22 + + require github.com/x/y v1.2.3 + `), + mainSrc, + ), + ) +} diff --git a/rewrite-go/test/search_walker_test.go b/rewrite-go/test/search_walker_test.go new file mode 100644 index 00000000000..3be34a29a38 --- /dev/null +++ b/rewrite-go/test/search_walker_test.go @@ -0,0 +1,85 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" + "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" +) + +// markBinaryVisitor attaches a SearchResult marker to every binary expression. +type markBinaryVisitor struct { + visitor.GoVisitor + marker tree.SearchResult +} + +func (v *markBinaryVisitor) VisitBinary(bin *tree.Binary, p any) tree.J { + bin = v.GoVisitor.VisitBinary(bin, p).(*tree.Binary) + bin = bin.WithMarkers(tree.AddMarker(bin.Markers, v.marker)) + return bin +} + +func TestCollectSearchResultIDsEmpty(t *testing.T) { + cu, err := parser.NewGoParser().Parse("a.go", "package main\n") + if err != nil { + t.Fatal(err) + } + if got := tree.CollectSearchResultIDs(cu); len(got) != 0 { + t.Fatalf("expected no search results, got %v", got) + } +} + +func TestCollectSearchResultIDsAfterMark(t *testing.T) { + cu, err := parser.NewGoParser().Parse("a.go", "package main\n\nvar x = 1 + 2\n") + if err != nil { + t.Fatal(err) + } + mark := tree.NewSearchResult("found a binary expr") + v := &markBinaryVisitor{marker: mark} + visitor.Init(v) + + result := v.Visit(cu, recipe.NewExecutionContext()).(tree.Tree) + ids := tree.CollectSearchResultIDs(result) + if len(ids) != 1 { + t.Fatalf("expected exactly one search result id, got %d (%v)", len(ids), ids) + } + if ids[0] != mark.Ident { + t.Fatalf("collected id %v does not match marker id %v", ids[0], mark.Ident) + } +} + +func TestCollectSearchResultIDsDedupes(t *testing.T) { + cu, err := parser.NewGoParser().Parse("a.go", "package main\n\nvar x = 1 + 2 + 3\n") + if err != nil { + t.Fatal(err) + } + // Same marker (same UUID) attached to two binary expressions: collector + // should only return it once. + mark := tree.NewSearchResult("dup") + v := &markBinaryVisitor{marker: mark} + visitor.Init(v) + + result := v.Visit(cu, recipe.NewExecutionContext()).(tree.Tree) + ids := tree.CollectSearchResultIDs(result) + if len(ids) != 1 { + t.Fatalf("expected dedup to produce 1 id, got %d (%v)", len(ids), ids) + } +} diff --git a/rewrite-go/test/struct_tag_test.go b/rewrite-go/test/struct_tag_test.go new file mode 100644 index 00000000000..30793aa16e7 --- /dev/null +++ b/rewrite-go/test/struct_tag_test.go @@ -0,0 +1,188 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + "github.com/openrewrite/rewrite/rewrite-go/pkg/printer" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +// Step 2 of AnnotationService rollout: the parser splits struct field +// tags into one Annotation per `key:"value"` pair on +// VariableDeclarations.LeadingAnnotations. The printer reassembles the +// run, wrapping it in backticks. Roundtrip on gofmt'd input is exact +// (Option 1 in the design discussion: lossy on inner-padding only). + +func parseStructAndFindField(t *testing.T, src, fieldName string) *tree.VariableDeclarations { + t.Helper() + cu, err := parser.NewGoParser().Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + for _, rp := range cu.Statements { + td, ok := rp.Element.(*tree.TypeDecl) + if !ok { + continue + } + st, ok := td.Definition.(*tree.StructType) + if !ok || st.Body == nil { + continue + } + for _, fr := range st.Body.Statements { + vd, ok := fr.Element.(*tree.VariableDeclarations) + if !ok { + continue + } + for _, vr := range vd.Variables { + if vr.Element != nil && vr.Element.Name != nil && vr.Element.Name.Name == fieldName { + return vd + } + } + } + } + t.Fatalf("field %q not found in struct", fieldName) + return nil +} + +func TestStructTag_SingleKeyParsesIntoOneAnnotation(t *testing.T) { + src := "package main\n\ntype User struct {\n\tName string `json:\"name\"`\n}\n" + vd := parseStructAndFindField(t, src, "Name") + + if got := len(vd.LeadingAnnotations); got != 1 { + t.Fatalf("LeadingAnnotations: got %d, want 1", got) + } + ann := vd.LeadingAnnotations[0] + if id, ok := ann.AnnotationType.(*tree.Identifier); !ok || id.Name != "json" { + t.Errorf("AnnotationType: got %+v, want Identifier{Name:\"json\"}", ann.AnnotationType) + } + if ann.Arguments == nil || len(ann.Arguments.Elements) != 1 { + t.Fatalf("Arguments: got %+v, want 1 element", ann.Arguments) + } + lit, ok := ann.Arguments.Elements[0].Element.(*tree.Literal) + if !ok { + t.Fatalf("Arguments[0]: got %T, want *Literal", ann.Arguments.Elements[0].Element) + } + if lit.Source != `"name"` { + t.Errorf("Arguments[0].Source: got %q, want %q", lit.Source, `"name"`) + } + if v, _ := lit.Value.(string); v != "name" { + t.Errorf("Arguments[0].Value: got %v, want %q", lit.Value, "name") + } +} + +func TestStructTag_MultipleKeysParseIntoMultipleAnnotations(t *testing.T) { + src := "package main\n\ntype User struct {\n\tEmail string `json:\"email,omitempty\" db:\"email_address\"`\n}\n" + vd := parseStructAndFindField(t, src, "Email") + + if got := len(vd.LeadingAnnotations); got != 2 { + t.Fatalf("LeadingAnnotations: got %d, want 2", got) + } + + first := vd.LeadingAnnotations[0] + if id, ok := first.AnnotationType.(*tree.Identifier); !ok || id.Name != "json" { + t.Errorf("[0] AnnotationType: got %+v, want json", first.AnnotationType) + } + if lit := first.Arguments.Elements[0].Element.(*tree.Literal); lit.Source != `"email,omitempty"` { + t.Errorf("[0] Source: got %q, want %q", lit.Source, `"email,omitempty"`) + } + + second := vd.LeadingAnnotations[1] + if id, ok := second.AnnotationType.(*tree.Identifier); !ok || id.Name != "db" { + t.Errorf("[1] AnnotationType: got %+v, want db", second.AnnotationType) + } + if lit := second.Arguments.Elements[0].Element.(*tree.Literal); lit.Source != `"email_address"` { + t.Errorf("[1] Source: got %q, want %q", lit.Source, `"email_address"`) + } + // Inter-pair whitespace lives on the second annotation's Prefix. + if second.Prefix.Whitespace != " " { + t.Errorf("[1] Prefix.Whitespace: got %q, want %q", second.Prefix.Whitespace, " ") + } +} + +func TestStructTag_NoMarkerLeftBehind(t *testing.T) { + src := "package main\n\ntype User struct {\n\tName string `json:\"name\"`\n}\n" + vd := parseStructAndFindField(t, src, "Name") + + for _, m := range vd.Markers.Entries { + if _, ok := m.(tree.StructTag); ok { + t.Errorf("StructTag marker should no longer be emitted; LeadingAnnotations is the canonical pathway") + } + } +} + +func TestStructTag_RoundtripGofmtdInput(t *testing.T) { + src := "package main\n\ntype User struct {\n\tName string `json:\"name\"`\n\tEmail string `json:\"email,omitempty\" db:\"email_address\"`\n\tID int `json:\"-\"`\n}\n" + cu, err := parser.NewGoParser().Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + got := printer.Print(cu) + if got != src { + t.Errorf("roundtrip mismatch\nexpected:\n%s\nactual:\n%s", src, got) + } +} + +func TestStructTag_RoundtripWithEscapes(t *testing.T) { + // A backslash-escaped quote inside the value must roundtrip. + src := "package main\n\ntype X struct {\n\tField string `json:\"a\\\"b\"`\n}\n" + cu, err := parser.NewGoParser().Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + got := printer.Print(cu) + if got != src { + t.Errorf("roundtrip mismatch\nexpected: %q\nactual: %q", src, got) + } +} + +func TestStructTag_DashValueRoundtrip(t *testing.T) { + // `json:"-"` is the convention for "skip this field". + src := "package main\n\ntype X struct {\n\tField string `json:\"-\"`\n}\n" + cu, err := parser.NewGoParser().Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + got := printer.Print(cu) + if got != src { + t.Errorf("roundtrip mismatch\nexpected: %q\nactual: %q", src, got) + } + vd := parseStructAndFindField(t, src, "Field") + lit := vd.LeadingAnnotations[0].Arguments.Elements[0].Element.(*tree.Literal) + if v, _ := lit.Value.(string); v != "-" { + t.Errorf("Value: got %v, want %q", lit.Value, "-") + } +} + +func TestStructTag_NonStructDoesNotEmitAnnotations(t *testing.T) { + // Top-level / local var declarations don't have struct-tag syntax; + // the parser should not emit any LeadingAnnotations on them. + src := "package main\n\nvar x int = 1\n\nfunc f() {\n\ty := 2\n\t_ = y\n}\n" + cu, err := parser.NewGoParser().Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + for _, rp := range cu.Statements { + if vd, ok := rp.Element.(*tree.VariableDeclarations); ok { + if len(vd.LeadingAnnotations) > 0 { + t.Errorf("top-level VariableDeclarations got %d LeadingAnnotations, want 0", len(vd.LeadingAnnotations)) + } + } + } +} diff --git a/rewrite-go/test/testdata/printer-corpus/README.md b/rewrite-go/test/testdata/printer-corpus/README.md new file mode 100644 index 00000000000..b7e63c32085 --- /dev/null +++ b/rewrite-go/test/testdata/printer-corpus/README.md @@ -0,0 +1,47 @@ +# Printer fidelity corpus + +A growing collection of `.go` fixtures used to detect regressions in +`pkg/printer/go_printer.go`. Each fixture is parsed and printed; the +output must be byte-equal to the input. + +## Layout + +``` +test/testdata/printer-corpus/ + gofmt/ ← non-gofmt'd inputs (mixed tabs/spaces, brace placement, etc.) + generics/ ← multi-line type parameters, union constraints, nested generics + README.md + TODO.md ← known failures with notes on the suspected fix area +``` + +Lives under `testdata/` so `go test ./...` skips it (Go treats +`testdata/` as a magic directory). Every `.go` file under any +subdirectory is included automatically by the corpus driver in +`pkg/printer/parity_test.go`. + +## Running + +The corpus is gated behind the `parityaudit` build tag so it never runs +in CI. Locally: + +```sh +make parity +``` + +That target invokes `go test -tags parityaudit ./pkg/printer/...`, +which picks up the corpus driver and walks the fixtures. + +## Adding cases + +1. Drop a `.go` file under `gofmt/` or `generics/` with whatever shape + you suspect breaks the printer. +2. Run `make parity`. If your case fails, file the diff in `TODO.md` + alongside a one-line guess at the broken printer code path. +3. Fix the printer; re-run; the test passes. + +## Why isn't this in CI? + +P2 in the eng review: corpus runs are open-ended (a new bug can land +without a corpus regression, and a corpus diff can take longer to triage +than tests like `go test`). Keeping it manual gives fast iteration on +real bug reports without making the CI pipeline noisy. diff --git a/rewrite-go/test/testdata/printer-corpus/TODO.md b/rewrite-go/test/testdata/printer-corpus/TODO.md new file mode 100644 index 00000000000..2ee203c8d9d --- /dev/null +++ b/rewrite-go/test/testdata/printer-corpus/TODO.md @@ -0,0 +1,30 @@ +# Printer corpus — known failures + +Track parse → print byte-equality failures here. Each entry should +describe the failing fixture, paste a minimal diff, and guess at the +broken printer code path so the next dev (or you, in a week) can pick +up the trail. + +## Status + +As of the initial corpus (13 fixtures across `gofmt/` and `generics/`), +**all cases pass byte-equality**. The list below is empty. + +When `make parity` fails on a new fixture: + +1. Add a heading here with the fixture path. +2. Paste the minimal expected/actual diff from the test output. +3. Note one or two suspect locations in `pkg/printer/go_printer.go`. +4. Open a PR that adds the failing fixture **and** the printer fix in + the same change so the corpus stays green. + +## Open + +_(none)_ + +## Adjacent work that surfaced through the corpus driver + +_(none)_ — the previously-noted semicolon-between-statements parser bug, +the AddImport-on-empty-file whitespace bug, and the OrderImports +reorder-loses-newlines bug have all shipped. See +`test/import_recipes_test.go` for the regression coverage. diff --git a/rewrite-go/test/testdata/printer-corpus/generics/01_generic_func.go b/rewrite-go/test/testdata/printer-corpus/generics/01_generic_func.go new file mode 100644 index 00000000000..3b4d75b5573 --- /dev/null +++ b/rewrite-go/test/testdata/printer-corpus/generics/01_generic_func.go @@ -0,0 +1,9 @@ +package main + +func Map[T any](xs []T, f func(T) T) []T { + out := make([]T, len(xs)) + for i, x := range xs { + out[i] = f(x) + } + return out +} diff --git a/rewrite-go/test/testdata/printer-corpus/generics/02_generic_struct.go b/rewrite-go/test/testdata/printer-corpus/generics/02_generic_struct.go new file mode 100644 index 00000000000..93d8f46df65 --- /dev/null +++ b/rewrite-go/test/testdata/printer-corpus/generics/02_generic_struct.go @@ -0,0 +1,9 @@ +package main + +type Box[T any] struct { + V T +} + +func (b Box[T]) Get() T { + return b.V +} diff --git a/rewrite-go/test/testdata/printer-corpus/generics/03_multi_param.go b/rewrite-go/test/testdata/printer-corpus/generics/03_multi_param.go new file mode 100644 index 00000000000..b277710ef61 --- /dev/null +++ b/rewrite-go/test/testdata/printer-corpus/generics/03_multi_param.go @@ -0,0 +1,6 @@ +package main + +type Pair[K comparable, V any] struct { + Key K + Val V +} diff --git a/rewrite-go/test/testdata/printer-corpus/generics/04_union_constraint.go b/rewrite-go/test/testdata/printer-corpus/generics/04_union_constraint.go new file mode 100644 index 00000000000..b8a00861907 --- /dev/null +++ b/rewrite-go/test/testdata/printer-corpus/generics/04_union_constraint.go @@ -0,0 +1,9 @@ +package main + +func Sum[T int | float64](xs []T) T { + var total T + for _, x := range xs { + total += x + } + return total +} diff --git a/rewrite-go/test/testdata/printer-corpus/generics/05_multiline_type_params.go b/rewrite-go/test/testdata/printer-corpus/generics/05_multiline_type_params.go new file mode 100644 index 00000000000..835a1344880 --- /dev/null +++ b/rewrite-go/test/testdata/printer-corpus/generics/05_multiline_type_params.go @@ -0,0 +1,9 @@ +package main + +func Combine[ + K comparable, + V any, + R any, +](k K, v V, fn func(K, V) R) R { + return fn(k, v) +} diff --git a/rewrite-go/test/testdata/printer-corpus/generics/06_nested_generic.go b/rewrite-go/test/testdata/printer-corpus/generics/06_nested_generic.go new file mode 100644 index 00000000000..35e022d6ecf --- /dev/null +++ b/rewrite-go/test/testdata/printer-corpus/generics/06_nested_generic.go @@ -0,0 +1,13 @@ +package main + +type Container[T any] struct { + Items []Box[T] +} + +type Box[T any] struct { + V T +} + +func New[T any]() Container[T] { + return Container[T]{Items: []Box[T]{}} +} diff --git a/rewrite-go/test/testdata/printer-corpus/gofmt/01_basic_main.go b/rewrite-go/test/testdata/printer-corpus/gofmt/01_basic_main.go new file mode 100644 index 00000000000..635db7ae6c1 --- /dev/null +++ b/rewrite-go/test/testdata/printer-corpus/gofmt/01_basic_main.go @@ -0,0 +1,7 @@ +package main + +import "fmt" + +func main() { + fmt.Println("hello, world") +} diff --git a/rewrite-go/test/testdata/printer-corpus/gofmt/02_grouped_imports.go b/rewrite-go/test/testdata/printer-corpus/gofmt/02_grouped_imports.go new file mode 100644 index 00000000000..347502b300a --- /dev/null +++ b/rewrite-go/test/testdata/printer-corpus/gofmt/02_grouped_imports.go @@ -0,0 +1,10 @@ +package main + +import ( + "fmt" + "strings" +) + +func main() { + fmt.Println(strings.ToUpper("hi")) +} diff --git a/rewrite-go/test/testdata/printer-corpus/gofmt/03_no_trailing_newline.go b/rewrite-go/test/testdata/printer-corpus/gofmt/03_no_trailing_newline.go new file mode 100644 index 00000000000..fd69b0c219d --- /dev/null +++ b/rewrite-go/test/testdata/printer-corpus/gofmt/03_no_trailing_newline.go @@ -0,0 +1,3 @@ +package main + +func main() {} \ No newline at end of file diff --git a/rewrite-go/test/testdata/printer-corpus/gofmt/04_struct_with_tags.go b/rewrite-go/test/testdata/printer-corpus/gofmt/04_struct_with_tags.go new file mode 100644 index 00000000000..f8e477a709a --- /dev/null +++ b/rewrite-go/test/testdata/printer-corpus/gofmt/04_struct_with_tags.go @@ -0,0 +1,6 @@ +package main + +type Point struct { + X int `json:"x"` + Y int `json:"y"` +} diff --git a/rewrite-go/test/testdata/printer-corpus/gofmt/05_method_decl.go b/rewrite-go/test/testdata/printer-corpus/gofmt/05_method_decl.go new file mode 100644 index 00000000000..703c80b2171 --- /dev/null +++ b/rewrite-go/test/testdata/printer-corpus/gofmt/05_method_decl.go @@ -0,0 +1,13 @@ +package main + +type Counter struct { + count int +} + +func (c *Counter) Inc() { + c.count++ +} + +func (c Counter) Value() int { + return c.count +} diff --git a/rewrite-go/test/testdata/printer-corpus/gofmt/06_weird_spacing.go b/rewrite-go/test/testdata/printer-corpus/gofmt/06_weird_spacing.go new file mode 100644 index 00000000000..a3dd29a2103 --- /dev/null +++ b/rewrite-go/test/testdata/printer-corpus/gofmt/06_weird_spacing.go @@ -0,0 +1,7 @@ +package main + +import "fmt" + +func main() { + fmt.Println( "spacing" ) +} diff --git a/rewrite-go/test/testdata/printer-corpus/gofmt/07_block_comments.go b/rewrite-go/test/testdata/printer-corpus/gofmt/07_block_comments.go new file mode 100644 index 00000000000..99a693d6ece --- /dev/null +++ b/rewrite-go/test/testdata/printer-corpus/gofmt/07_block_comments.go @@ -0,0 +1,11 @@ +package main + +/* + * A block comment spanning + * multiple lines. + */ +func main() { + // line comment + x := 1 // trailing + _ = x +} diff --git a/rewrite-go/test/testdata/printer-corpus/gofmt/08_inline_semicolon.go b/rewrite-go/test/testdata/printer-corpus/gofmt/08_inline_semicolon.go new file mode 100644 index 00000000000..b1e5a6a7f0a --- /dev/null +++ b/rewrite-go/test/testdata/printer-corpus/gofmt/08_inline_semicolon.go @@ -0,0 +1,6 @@ +package main + +func main() { + _ = 1; _ = 2 + x := 10; y := 20; _ = x + y +} diff --git a/rewrite-go/test/type_attribution_test.go b/rewrite-go/test/type_attribution_test.go index ac3f9114a2e..126f3f27eda 100644 --- a/rewrite-go/test/type_attribution_test.go +++ b/rewrite-go/test/type_attribution_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + . "github.com/openrewrite/rewrite/rewrite-go/pkg/test" "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" "github.com/openrewrite/rewrite/rewrite-go/pkg/visitor" ) @@ -50,19 +51,6 @@ func (v *methodTypeCollector) VisitMethodDeclaration(md *tree.MethodDeclaration, return v.GoVisitor.VisitMethodDeclaration(md, p) } -// methodInvocationCollector visits a tree and collects method invocation types. -type methodInvocationCollector struct { - visitor.GoVisitor - methodTypes map[string]*tree.JavaTypeMethod -} - -func (v *methodInvocationCollector) VisitMethodInvocation(mi *tree.MethodInvocation, p any) tree.J { - if mi.MethodType != nil && mi.Name != nil { - v.methodTypes[mi.Name.Name] = mi.MethodType - } - return v.GoVisitor.VisitMethodInvocation(mi, p) -} - func TestTypeAttributionLocalVars(t *testing.T) { p := parser.NewGoParser() cu, err := p.Parse("test.go", `package main @@ -156,22 +144,7 @@ func main() { t.Fatal(err) } - v := visitor.Init(&methodInvocationCollector{methodTypes: make(map[string]*tree.JavaTypeMethod)}) - v.Visit(cu, nil) - - printlnType, ok := v.methodTypes["Println"] - if !ok { - t.Fatal("no method type for fmt.Println()") - } - if printlnType.Name != "Println" { - t.Errorf("expected method name 'Println', got '%s'", printlnType.Name) - } - if printlnType.DeclaringType == nil { - t.Fatal("expected declaring type for Println") - } - if printlnType.DeclaringType.FullyQualifiedName != "fmt" { - t.Errorf("expected declaring type 'fmt', got '%s'", printlnType.DeclaringType.FullyQualifiedName) - } + ExpectMethodType(t, cu, "Println", "fmt") } func TestTypeAttributionStructType(t *testing.T) { @@ -192,24 +165,7 @@ func main() { t.Fatal(err) } - v := visitor.Init(&typeCollector{identTypes: make(map[string]tree.JavaType)}) - v.Visit(cu, nil) - - // "p" should have a type attributed to it - if pType, ok := v.identTypes["p"]; ok { - if cls, ok := pType.(*tree.JavaTypeClass); ok { - if cls.FullyQualifiedName != "main.Point" { - t.Errorf("expected p to be main.Point, got %s", cls.FullyQualifiedName) - } - if cls.Kind != "Class" { - t.Errorf("expected kind Class, got %s", cls.Kind) - } - } else { - t.Errorf("expected p to be class type, got %T", pType) - } - } else { - t.Error("no type attribution for p") - } + ExpectType(t, cu, "p", "main.Point") } func TestTypeAttributionGracefulDegradation(t *testing.T) { diff --git a/rewrite-go/test/vendor_walker_test.go b/rewrite-go/test/vendor_walker_test.go new file mode 100644 index 00000000000..7b2c92bb718 --- /dev/null +++ b/rewrite-go/test/vendor_walker_test.go @@ -0,0 +1,190 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + . "github.com/openrewrite/rewrite/rewrite-go/pkg/test" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +// vendorScaffold builds a vendor directory layout under root from a map of +// relative path → file content. Used by the T2 corpus. +func vendorScaffold(t *testing.T, root string, files map[string]string) { + t.Helper() + for rel, content := range files { + full := filepath.Join(root, rel) + if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", full, err) + } + if err := os.WriteFile(full, []byte(content), 0o644); err != nil { + t.Fatalf("write %s: %v", full, err) + } + } +} + +// parseInProject sets up a ProjectImporter with the given project root + +// require/replace metadata, then parses src as the file at sourcePath. +// Returns the parsed compilation unit so tests can assert on resolved +// types via the ExpectType helpers. +func parseInProject(t *testing.T, root string, modulePath string, requires []string, replaces map[string]string, sourcePath, src string) *tree.CompilationUnit { + t.Helper() + pi := parser.NewProjectImporter(modulePath, nil) + pi.SetProjectRoot(root) + for _, r := range requires { + pi.AddRequire(r) + } + for old, newPath := range replaces { + pi.AddReplace(old, newPath, "") + } + pi.AddSource(sourcePath, src) + p := parser.NewGoParser() + p.Importer = pi + cu, err := p.Parse(sourcePath, src) + if err != nil { + t.Fatalf("parse: %v", err) + } + return cu +} + +// Case 1: vendor happy path — vendored package's func type-resolves. +func TestVendorWalker_HappyPath(t *testing.T) { + root := t.TempDir() + vendorScaffold(t, root, map[string]string{ + "vendor/github.com/x/y/y.go": "package y\n\nfunc Hello() string { return \"hi\" }\n", + }) + cu := parseInProject(t, root, + "example.com/foo", + []string{"github.com/x/y"}, + nil, + "main.go", + "package main\n\nimport \"github.com/x/y\"\n\nfunc main() { _ = y.Hello() }\n", + ) + ExpectMethodType(t, cu, "Hello", "github.com/x/y") +} + +// Case 2: replace → local path. `replace foo => ./local/foo` walks the +// local dir relative to the project root. +func TestVendorWalker_ReplaceLocal(t *testing.T) { + root := t.TempDir() + vendorScaffold(t, root, map[string]string{ + "local/y/y.go": "package y\n\nfunc Hello() string { return \"hi\" }\n", + }) + cu := parseInProject(t, root, + "example.com/foo", + []string{"github.com/x/y"}, + map[string]string{"github.com/x/y": "./local/y"}, + "main.go", + "package main\n\nimport \"github.com/x/y\"\n\nfunc main() { _ = y.Hello() }\n", + ) + ExpectMethodType(t, cu, "Hello", "github.com/x/y") +} + +// Case 3: replace → other module path. `replace foo => bar` walks +// `vendor/bar/` instead of `vendor/foo/`. +func TestVendorWalker_ReplaceModule(t *testing.T) { + root := t.TempDir() + vendorScaffold(t, root, map[string]string{ + "vendor/github.com/forked/y/y.go": "package y\n\nfunc Hello() string { return \"forked\" }\n", + }) + cu := parseInProject(t, root, + "example.com/foo", + []string{"github.com/x/y"}, + map[string]string{"github.com/x/y": "github.com/forked/y"}, + "main.go", + "package main\n\nimport \"github.com/x/y\"\n\nfunc main() { _ = y.Hello() }\n", + ) + // The package's import path is still github.com/x/y — Go's importer + // is told that's the path the user wrote. Method resolves on it. + ExpectMethodType(t, cu, "Hello", "github.com/x/y") +} + +// Case 4: aliased third-party import. `import yy "github.com/x/y"` — +// the alias should still resolve because Go's parser handles aliasing +// independently of the package's actual name. +func TestVendorWalker_AliasedImport(t *testing.T) { + root := t.TempDir() + vendorScaffold(t, root, map[string]string{ + "vendor/github.com/x/y/y.go": "package y\n\nfunc Hello() string { return \"hi\" }\n", + }) + cu := parseInProject(t, root, + "example.com/foo", + []string{"github.com/x/y"}, + nil, + "main.go", + "package main\n\nimport yy \"github.com/x/y\"\n\nfunc main() { _ = yy.Hello() }\n", + ) + ExpectMethodType(t, cu, "Hello", "github.com/x/y") +} + +// Case 5: multi-level transitive third-party. Vendored A imports +// vendored B; the importer must recursively walk vendor for B too. +func TestVendorWalker_MultiLevelTransitive(t *testing.T) { + root := t.TempDir() + vendorScaffold(t, root, map[string]string{ + "vendor/github.com/a/a/a.go": "package a\n\nimport \"github.com/b/b\"\n\nfunc Use() string { return b.World() }\n", + "vendor/github.com/b/b/b.go": "package b\n\nfunc World() string { return \"world\" }\n", + }) + cu := parseInProject(t, root, + "example.com/foo", + []string{"github.com/a/a", "github.com/b/b"}, + nil, + "main.go", + "package main\n\nimport \"github.com/a/a\"\n\nfunc main() { _ = a.Use() }\n", + ) + ExpectMethodType(t, cu, "Use", "github.com/a/a") +} + +// Case 6: vendor parse error → fallback to stub. Per C4 directive: a +// broken vendored file logs + falls back to the stub tier so the parent +// parse still succeeds. +func TestVendorWalker_ParseErrorFallsBack(t *testing.T) { + root := t.TempDir() + vendorScaffold(t, root, map[string]string{ + // Intentionally malformed: missing function body close brace. + "vendor/github.com/x/y/y.go": "package y\n\nfunc Broken() string { return\n", + }) + cu := parseInProject(t, root, + "example.com/foo", + []string{"github.com/x/y"}, + nil, + "main.go", + "package main\n\nimport \"github.com/x/y\"\n\nvar _ = y\n", + ) + // The parent parse must succeed (stub fallback). The `y` package + // alias should be a FullyQualified type with the import path as FQN. + ExpectType(t, cu, "y", "github.com/x/y") +} + +// Case 7: missing vendor → fallback to stub. No vendor dir exists for +// the imported package; importer falls through to the stub tier so the +// parent parse still succeeds and the package alias resolves. +func TestVendorWalker_MissingVendorFallsBack(t *testing.T) { + root := t.TempDir() + cu := parseInProject(t, root, + "example.com/foo", + []string{"github.com/x/y"}, + nil, + "main.go", + "package main\n\nimport \"github.com/x/y\"\n\nvar _ = y\n", + ) + ExpectType(t, cu, "y", "github.com/x/y") +} diff --git a/rewrite-go/test/whitespace_validation_service_test.go b/rewrite-go/test/whitespace_validation_service_test.go new file mode 100644 index 00000000000..729a4ce57c4 --- /dev/null +++ b/rewrite-go/test/whitespace_validation_service_test.go @@ -0,0 +1,90 @@ +/* + * Copyright 2026 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "strings" + "testing" + + "github.com/openrewrite/rewrite/rewrite-go/pkg/parser" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe" + "github.com/openrewrite/rewrite/rewrite-go/pkg/recipe/golang" + "github.com/openrewrite/rewrite/rewrite-go/pkg/tree" +) + +// TestWhitespaceValidationService_RegisteredOnInit verifies that +// importing pkg/recipe/golang registers the service. +func TestWhitespaceValidationService_RegisteredOnInit(t *testing.T) { + svc := recipe.Service[*golang.WhitespaceValidationService](nil) + if svc == nil { + t.Fatal("recipe.Service returned nil for *golang.WhitespaceValidationService") + } +} + +// TestWhitespaceValidationService_CleanTree confirms a freshly parsed +// CU validates clean. +func TestWhitespaceValidationService_CleanTree(t *testing.T) { + src := "package main\n\nfunc main() {\n\tprintln(\"hi\")\n}\n" + p := parser.NewGoParser() + cu, err := p.Parse("test.go", src) + if err != nil { + t.Fatalf("parse: %v", err) + } + svc := &golang.WhitespaceValidationService{} + if errs := svc.Validate(cu); len(errs) != 0 { + t.Fatalf("expected clean tree to validate, got %d errs:\n%s", len(errs), strings.Join(errs, "\n")) + } + if !svc.IsValid(cu) { + t.Error("IsValid disagrees with Validate on a clean tree") + } +} + +// TestWhitespaceValidationService_DetectsCorruption hand-crafts a tree +// containing a Space whose Whitespace contains source text that should +// have been parsed into a node. Verifies the validator flags it. +func TestWhitespaceValidationService_DetectsCorruption(t *testing.T) { + cu := &tree.CompilationUnit{ + Prefix: tree.Space{Whitespace: "package main"}, // non-whitespace stowed away + } + svc := &golang.WhitespaceValidationService{} + errs := svc.Validate(cu) + if len(errs) == 0 { + t.Fatal("expected validator to flag non-whitespace in Space.Whitespace") + } + if !strings.Contains(errs[0], "non-whitespace") { + t.Errorf("error should mention non-whitespace, got: %s", errs[0]) + } + if svc.IsValid(cu) { + t.Error("IsValid should be false when Validate returned errors") + } +} + +// TestWhitespaceValidationService_DetectsBadComment crafts a Comment +// whose Text doesn't begin with `//` or `/*` — the printer would emit +// it raw, so the validator must catch it. +func TestWhitespaceValidationService_DetectsBadComment(t *testing.T) { + cu := &tree.CompilationUnit{ + Prefix: tree.Space{ + Comments: []tree.Comment{{Text: "this is not a comment", Suffix: "\n"}}, + }, + } + svc := &golang.WhitespaceValidationService{} + errs := svc.Validate(cu) + if len(errs) == 0 { + t.Fatal("expected validator to flag a non-comment Text") + } +} diff --git a/rewrite-javascript/rewrite/package-lock.json b/rewrite-javascript/rewrite/package-lock.json index 6d136f82310..89dd92a6a26 100644 --- a/rewrite-javascript/rewrite/package-lock.json +++ b/rewrite-javascript/rewrite/package-lock.json @@ -1053,6 +1053,7 @@ "resolved": "https://registry.npmjs.org/@types/node/-/node-22.19.1.tgz", "integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==", "license": "MIT", + "peer": true, "dependencies": { "undici-types": "~6.21.0" } @@ -1455,6 +1456,7 @@ "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "license": "MIT", + "peer": true, "engines": { "node": ">=12" }, @@ -1689,6 +1691,7 @@ "integrity": "sha512-w+N7Hifpc3gRjZ63vYBXA56dvvRlNWRczTdmCBBa+CotUzAPf5b7YMdMR/8CQoeYE5LX3W4wj6RYTgonm1b9DA==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "esbuild": "^0.27.0", "fdir": "^6.5.0", @@ -1867,6 +1870,7 @@ "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.1.tgz", "integrity": "sha512-lcYcMxX2PO9XMGvAJkJ3OsNMw+/7FKes7/hgerGUYWIoWu5j/+YQqcZr5JnPZWzOsEBgMbSbiSTn/dv/69Mkpw==", "license": "ISC", + "peer": true, "bin": { "yaml": "bin.mjs" }, diff --git a/rewrite-javascript/rewrite/src/rpc/request/batch-visit.ts b/rewrite-javascript/rewrite/src/rpc/request/batch-visit.ts index f4ee7ae4962..c2a4c3c47ae 100644 --- a/rewrite-javascript/rewrite/src/rpc/request/batch-visit.ts +++ b/rewrite-javascript/rewrite/src/rpc/request/batch-visit.ts @@ -86,11 +86,24 @@ export class BatchVisit { const visitor = await Visit.instantiateVisitor( {visitor: item.visitor, visitorOptions: item.visitorOptions}, preparedRecipes, recipeCursors, p); + // Snapshot ctx message keys so we can flag whether + // the visitor put anything new into the context. + const preKeys = new Set<string | symbol>( + Reflect.ownKeys(p.messages) as (string | symbol)[] + ); const after = await visitor.visit(tree, p, cursor); const modified = after !== tree; const deleted = after == null; + let hasNewMessages = false; + for (const k of Reflect.ownKeys(p.messages)) { + if (!preKeys.has(k)) { + hasNewMessages = true; + break; + } + } + // Diff SearchResult IDs against the running set let searchResultIds: string[]; if (deleted) { @@ -101,7 +114,7 @@ export class BatchVisit { for (const id of searchResultIds) knownIds.add(id); } - results.push({modified, deleted, hasNewMessages: false, searchResultIds}); + results.push({modified, deleted, hasNewMessages, searchResultIds}); if (deleted) { localObjects.delete(request.treeId);