Skip to content

Commit 2e883d4

Browse files
Fix multiple bugs in rewrite-go parser, visitor, printer, and RPC (#7212)
Motivation: Code review of the rewrite-go module identified several bugs and missing visitor coverage that would impact correctness and robustness. Summary: - Fix panic recovery in safeHandleRequest to return a proper JSON-RPC error response instead of nil when a panic is caught - Fix parser to handle multiple import blocks (e.g., `import "fmt"` followed by `import "os"`), using a new ImportBlock marker to track block boundaries - Fix visitAndCast and visitExpression to handle nil returns from visitors instead of panicking on the type assertion - Fix VisitCompilationUnit to visit PackageDecl and Imports, which were previously skipped entirely (preventing recipe visitors from transforming package names or imports) - Remove dead code in mapArrayType (unreachable closePrefix computation)
1 parent 5eaabf3 commit 2e883d4

File tree

10 files changed

+358
-27
lines changed

10 files changed

+358
-27
lines changed

rewrite-go/rewrite/cmd/rpc/main.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,17 @@ func (s *server) writeMessage(resp *jsonRPCResponse) error {
194194
}
195195

196196
// safeHandleRequest wraps handleRequest with panic recovery.
197-
func (s *server) safeHandleRequest(req *jsonRPCRequest) *jsonRPCResponse {
197+
func (s *server) safeHandleRequest(req *jsonRPCRequest) (resp *jsonRPCResponse) {
198198
defer func() {
199199
if r := recover(); r != nil {
200200
buf := make([]byte, 4096)
201201
n := runtime.Stack(buf, false)
202202
s.logger.Printf("PANIC in %s: %v\n%s", req.Method, r, buf[:n])
203+
resp = &jsonRPCResponse{
204+
JSONRPC: "2.0",
205+
ID: req.ID,
206+
Error: &rpcError{Code: -32603, Message: fmt.Sprintf("Internal error: %v", r)},
207+
}
203208
}
204209
}()
205210
return s.handleRequest(req)

rewrite-go/rewrite/pkg/parser/go_parser.go

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -171,59 +171,102 @@ func (ctx *parseContext) mapFile(file *ast.File, sourcePath string) *tree.Compil
171171
}
172172
}
173173

174-
// mapImports maps the import declarations in the file.
174+
// mapImports maps all import declarations in the file into a single Container.
175+
// Go allows multiple import blocks; subsequent blocks are tracked via ImportBlock markers.
175176
func (ctx *parseContext) mapImports(file *ast.File) *tree.Container[*tree.Import] {
176-
var importDecl *ast.GenDecl
177+
// Collect all import GenDecls in order.
178+
var importDecls []*ast.GenDecl
177179
for _, decl := range file.Decls {
178180
if gd, ok := decl.(*ast.GenDecl); ok && gd.Tok == token.IMPORT {
179-
importDecl = gd
180-
break
181+
importDecls = append(importDecls, gd)
181182
}
182183
}
183-
if importDecl == nil {
184+
if len(importDecls) == 0 {
184185
return nil
185186
}
186187

187-
before := ctx.prefixAndSkip(importDecl.Pos(), len("import"))
188-
189188
var elements []tree.RightPadded[*tree.Import]
190-
191189
var containerMarkers tree.Markers
192-
if importDecl.Lparen.IsValid() {
193-
openParenPrefix := ctx.prefix(importDecl.Lparen)
194-
ctx.skip(1) // skip "("
190+
prevGrouped := false
195191

192+
// First import block: captured into Container.Before and Container.Markers
193+
first := importDecls[0]
194+
before := ctx.prefixAndSkip(first.Pos(), len("import"))
195+
196+
if first.Lparen.IsValid() {
197+
prevGrouped = true
198+
openParenPrefix := ctx.prefix(first.Lparen)
199+
ctx.skip(1) // skip "("
196200
containerMarkers = tree.Markers{
197201
ID: uuid.New(),
198202
Entries: []tree.Marker{
199203
tree.GroupedImport{Ident: uuid.New(), Before: openParenPrefix},
200204
},
201205
}
206+
}
202207

203-
for _, spec := range importDecl.Specs {
204-
is := spec.(*ast.ImportSpec)
205-
imp := ctx.mapImportSpec(is)
206-
elements = append(elements, tree.RightPadded[*tree.Import]{Element: imp})
207-
}
208+
for _, spec := range first.Specs {
209+
is := spec.(*ast.ImportSpec)
210+
imp := ctx.mapImportSpec(is)
211+
elements = append(elements, tree.RightPadded[*tree.Import]{Element: imp})
212+
}
208213

209-
closeParen := ctx.prefix(importDecl.Rparen)
214+
if first.Lparen.IsValid() {
215+
closeParen := ctx.prefix(first.Rparen)
210216
ctx.skip(1) // skip ")"
211-
212217
if len(elements) > 0 {
213218
elements[len(elements)-1].After = closeParen
214219
}
215-
} else {
216-
for _, spec := range importDecl.Specs {
217-
is := spec.(*ast.ImportSpec)
218-
imp := ctx.mapImportSpec(is)
219-
elements = append(elements, tree.RightPadded[*tree.Import]{Element: imp})
220+
}
221+
222+
// Subsequent import blocks: attach ImportBlock marker to first import of each
223+
for _, importDecl := range importDecls[1:] {
224+
blockBefore := ctx.prefixAndSkip(importDecl.Pos(), len("import"))
225+
grouped := importDecl.Lparen.IsValid()
226+
var groupedBefore tree.Space
227+
if grouped {
228+
groupedBefore = ctx.prefix(importDecl.Lparen)
229+
ctx.skip(1) // skip "("
220230
}
231+
232+
ctx.mapImportBlockSpecs(importDecl, &elements, tree.ImportBlock{
233+
Ident: uuid.New(),
234+
ClosePrevious: prevGrouped,
235+
Before: blockBefore,
236+
Grouped: grouped,
237+
GroupedBefore: groupedBefore,
238+
})
239+
240+
if grouped {
241+
closeParen := ctx.prefix(importDecl.Rparen)
242+
ctx.skip(1) // skip ")"
243+
if len(elements) > 0 {
244+
elements[len(elements)-1].After = closeParen
245+
}
246+
}
247+
prevGrouped = grouped
221248
}
222249

223250
container := tree.Container[*tree.Import]{Before: before, Elements: elements, Markers: containerMarkers}
224251
return &container
225252
}
226253

254+
// mapImportBlockSpecs maps the specs of a subsequent import block, attaching
255+
// the ImportBlock marker to the first spec's Import node.
256+
func (ctx *parseContext) mapImportBlockSpecs(decl *ast.GenDecl, elements *[]tree.RightPadded[*tree.Import], marker tree.ImportBlock) {
257+
for j, spec := range decl.Specs {
258+
is := spec.(*ast.ImportSpec)
259+
imp := ctx.mapImportSpec(is)
260+
if j == 0 {
261+
imp.Markers = tree.Markers{
262+
ID: uuid.New(),
263+
Entries: []tree.Marker{marker},
264+
}
265+
}
266+
*elements = append(*elements, tree.RightPadded[*tree.Import]{Element: imp})
267+
}
268+
}
269+
227270
// mapImportSpec maps a single import spec.
228271
func (ctx *parseContext) mapImportSpec(spec *ast.ImportSpec) *tree.Import {
229272
prefix := ctx.prefix(spec.Pos())
@@ -1761,8 +1804,8 @@ func (ctx *parseContext) mapArrayType(expr *ast.ArrayType) tree.Expression {
17611804
length = ctx.mapExpr(expr.Len)
17621805
}
17631806

1764-
closePrefix := ctx.prefix(expr.Lbrack + token.Pos(ctx.findNextFrom('[', ctx.file.Offset(expr.Lbrack)) - ctx.file.Offset(expr.Lbrack)))
17651807
// Find the `]`
1808+
var closePrefix tree.Space
17661809
rbrackOff := ctx.findNext(']')
17671810
if rbrackOff >= 0 {
17681811
closePrefix = ctx.prefix(ctx.file.Pos(rbrackOff))

rewrite-go/rewrite/pkg/printer/go_printer.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,29 @@ func (p *GoPrinter) VisitCompilationUnit(cu *tree.CompilationUnit, param any) tr
7575
out.Append("import")
7676

7777
grouped := tree.FindMarker[tree.GroupedImport](cu.Imports.Markers)
78-
if grouped != nil {
78+
isGrouped := grouped != nil
79+
if isGrouped {
7980
p.visitSpace(grouped.Before, out)
8081
out.Append("(")
8182
}
8283
for _, rp := range cu.Imports.Elements {
84+
block := tree.FindMarker[tree.ImportBlock](rp.Element.Markers)
85+
if block != nil {
86+
if block.ClosePrevious {
87+
out.Append(")")
88+
}
89+
p.visitSpace(block.Before, out)
90+
out.Append("import")
91+
if block.Grouped {
92+
p.visitSpace(block.GroupedBefore, out)
93+
out.Append("(")
94+
}
95+
isGrouped = block.Grouped
96+
}
8397
p.Visit(rp.Element, out)
8498
p.visitSpace(rp.After, out)
8599
}
86-
if grouped != nil {
100+
if isGrouped {
87101
out.Append(")")
88102
}
89103
}

rewrite-go/rewrite/pkg/rpc/space_rpc.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,13 @@ func sendMarkerCodecFields(v any, q *SendQueue) {
119119
// GroupedImport.rpcSend sends: id (UUID string), before whitespace (string)
120120
q.GetAndSend(m, func(x any) any { return x.(tree.GroupedImport).Ident.String() }, nil)
121121
q.GetAndSend(m, func(x any) any { return x.(tree.GroupedImport).Before.Whitespace }, nil)
122+
case tree.ImportBlock:
123+
// ImportBlock.rpcSend sends: id, closePrevious, before, grouped, groupedBefore
124+
q.GetAndSend(m, func(x any) any { return x.(tree.ImportBlock).Ident.String() }, nil)
125+
q.GetAndSend(m, func(x any) any { return x.(tree.ImportBlock).ClosePrevious }, nil)
126+
q.GetAndSend(m, func(x any) any { return x.(tree.ImportBlock).Before.Whitespace }, nil)
127+
q.GetAndSend(m, func(x any) any { return x.(tree.ImportBlock).Grouped }, nil)
128+
q.GetAndSend(m, func(x any) any { return x.(tree.ImportBlock).GroupedBefore.Whitespace }, nil)
122129
case tree.ShortVarDecl:
123130
q.GetAndSend(m, func(x any) any { return x.(tree.ShortVarDecl).Ident.String() }, nil)
124131
case tree.VarKeyword:
@@ -193,6 +200,21 @@ func receiveMarkersCodec(q *ReceiveQueue, before tree.Markers) tree.Markers {
193200
ws := receiveScalar[string](q, m.Before.Whitespace)
194201
m.Before = tree.Space{Whitespace: ws}
195202
return m
203+
case tree.ImportBlock:
204+
// ImportBlock.rpcReceive: id, closePrevious, before, grouped, groupedBefore
205+
idStr := receiveScalar[string](q, m.Ident.String())
206+
if idStr != "" {
207+
if parsed, err := uuid.Parse(idStr); err == nil {
208+
m.Ident = parsed
209+
}
210+
}
211+
m.ClosePrevious = receiveScalar[bool](q, m.ClosePrevious)
212+
ws := receiveScalar[string](q, m.Before.Whitespace)
213+
m.Before = tree.Space{Whitespace: ws}
214+
m.Grouped = receiveScalar[bool](q, m.Grouped)
215+
gbWs := receiveScalar[string](q, m.GroupedBefore.Whitespace)
216+
m.GroupedBefore = tree.Space{Whitespace: gbWs}
217+
return m
196218
case tree.ShortVarDecl:
197219
idStr := receiveScalar[string](q, m.Ident.String())
198220
if idStr != "" {

rewrite-go/rewrite/pkg/rpc/value_types.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ func init() {
8686

8787
// Go-specific marker valueType registrations (for send-side type resolution)
8888
RegisterValueType(reflect.TypeOf(tree.GroupedImport{}), "org.openrewrite.golang.marker.GroupedImport")
89+
RegisterValueType(reflect.TypeOf(tree.ImportBlock{}), "org.openrewrite.golang.marker.ImportBlock")
8990
RegisterValueType(reflect.TypeOf(tree.ShortVarDecl{}), "org.openrewrite.golang.marker.ShortVarDecl")
9091
RegisterValueType(reflect.TypeOf(tree.VarKeyword{}), "org.openrewrite.golang.marker.VarKeyword")
9192
RegisterValueType(reflect.TypeOf(tree.ConstDecl{}), "org.openrewrite.golang.marker.ConstDecl")
@@ -175,6 +176,8 @@ func init() {
175176
RegisterFactory("org.openrewrite.marker.SearchResult", func() any { return tree.SearchResult{} })
176177
// GroupedImport: IS an RpcCodec, sends 2 sub-fields (id, before whitespace)
177178
RegisterFactory("org.openrewrite.golang.marker.GroupedImport", func() any { return tree.GroupedImport{} })
179+
// ImportBlock: IS an RpcCodec, sends 5 sub-fields (id, closePrevious, before, grouped, groupedBefore)
180+
RegisterFactory("org.openrewrite.golang.marker.ImportBlock", func() any { return tree.ImportBlock{} })
178181
// Go-specific markers: all are RpcCodec
179182
RegisterFactory("org.openrewrite.golang.marker.ShortVarDecl", func() any { return tree.ShortVarDecl{} })
180183
RegisterFactory("org.openrewrite.golang.marker.VarKeyword", func() any { return tree.VarKeyword{} })

rewrite-go/rewrite/pkg/tree/go.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,18 @@ func (n *CompilationUnit) WithStatements(statements []RightPadded[Statement]) *C
5252
return &c
5353
}
5454

55+
func (n *CompilationUnit) WithPackageDecl(pkg *RightPadded[*Identifier]) *CompilationUnit {
56+
c := *n
57+
c.PackageDecl = pkg
58+
return &c
59+
}
60+
61+
func (n *CompilationUnit) WithImports(imports *Container[*Import]) *CompilationUnit {
62+
c := *n
63+
c.Imports = imports
64+
return &c
65+
}
66+
5567
func (n *CompilationUnit) WithEOF(eof Space) *CompilationUnit {
5668
c := *n
5769
c.EOF = eof
@@ -462,6 +474,19 @@ type GroupedImport struct {
462474

463475
func (g GroupedImport) ID() uuid.UUID { return g.Ident }
464476

477+
// ImportBlock is a marker on the first Import of a subsequent import block
478+
// (2nd, 3rd, etc.) in files with multiple import declarations. It carries
479+
// the information needed to print the block boundary.
480+
type ImportBlock struct {
481+
Ident uuid.UUID
482+
ClosePrevious bool // true if the previous block was grouped (need to print ")")
483+
Before Space // space before the "import" keyword
484+
Grouped bool // true if this block uses import (...)
485+
GroupedBefore Space // space between "import" and "(" (only if Grouped)
486+
}
487+
488+
func (b ImportBlock) ID() uuid.UUID { return b.Ident }
489+
465490
// MultiAssignment represents a multi-value assignment: `x, y = 1, 2` or `x, y := f()`.
466491
type MultiAssignment struct {
467492
ID uuid.UUID

rewrite-go/rewrite/pkg/visitor/go_visitor.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,19 @@ var _ VisitorI = (*GoVisitor)(nil)
218218
func (v *GoVisitor) VisitCompilationUnit(cu *tree.CompilationUnit, p any) tree.J {
219219
cu = cu.WithPrefix(v.self().VisitSpace(cu.Prefix, p))
220220
cu = cu.WithMarkers(v.visitMarkers(cu.Markers, p))
221+
if cu.PackageDecl != nil {
222+
pkg := *cu.PackageDecl
223+
pkg.Element = visitAndCast[*tree.Identifier](v, pkg.Element, p)
224+
pkg.After = v.self().VisitSpace(pkg.After, p)
225+
cu = cu.WithPackageDecl(&pkg)
226+
}
227+
if cu.Imports != nil {
228+
imports := *cu.Imports
229+
imports.Before = v.self().VisitSpace(imports.Before, p)
230+
imports.Markers = v.visitMarkers(imports.Markers, p)
231+
imports.Elements = visitRightPaddedList(v, imports.Elements, p)
232+
cu = cu.WithImports(&imports)
233+
}
221234
cu = cu.WithStatements(visitRightPaddedList(v, cu.Statements, p))
222235
cu = cu.WithEOF(v.self().VisitSpace(cu.EOF, p))
223236
return cu
@@ -552,11 +565,18 @@ func (v *GoVisitor) visitMarkers(markers tree.Markers, p any) tree.Markers {
552565

553566
func visitAndCast[T tree.Tree](v *GoVisitor, t tree.Tree, p any) T {
554567
result := v.self().Visit(t, p)
568+
if result == nil {
569+
var zero T
570+
return zero
571+
}
555572
return result.(T)
556573
}
557574

558575
func visitExpression(v *GoVisitor, expr tree.Expression, p any) tree.Expression {
559576
result := v.self().Visit(expr, p)
577+
if result == nil {
578+
return nil
579+
}
560580
return result.(tree.Expression)
561581
}
562582

rewrite-go/rewrite/test/import_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,35 @@ func TestParseGroupedImports(t *testing.T) {
4848
}
4949
`))
5050
}
51+
52+
func TestParseMultipleImportBlocks(t *testing.T) {
53+
NewRecipeSpec().RewriteRun(t,
54+
Golang(`
55+
package main
56+
57+
import "fmt"
58+
import "os"
59+
60+
func hello() {
61+
}
62+
`))
63+
}
64+
65+
func TestParseMultipleGroupedImportBlocks(t *testing.T) {
66+
NewRecipeSpec().RewriteRun(t,
67+
Golang(`
68+
package main
69+
70+
import (
71+
"fmt"
72+
)
73+
74+
import (
75+
"os"
76+
"strings"
77+
)
78+
79+
func hello() {
80+
}
81+
`))
82+
}

0 commit comments

Comments
 (0)