Skip to content

Commit 167ef07

Browse files
committed
fix: Suport cross package types when Flattening
1 parent 562507e commit 167ef07

File tree

5 files changed

+82
-4
lines changed

5 files changed

+82
-4
lines changed

structs/flatten.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ type Field struct {
4646
type templateData struct {
4747
HeaderComment string
4848
PackageName string
49+
Imports []string
4950
SourceName string
5051
OutputName string
5152
Fields []Field
@@ -100,6 +101,7 @@ func flattenOne(f *ast.File, cfg StructConfig, typeKinds map[string]string, scal
100101
data := templateData{
101102
HeaderComment: o.headerComment,
102103
PackageName: o.packageName,
104+
Imports: collectImports(fields),
103105
SourceName: cfg.SourceName,
104106
OutputName: cfg.OutputName,
105107
Fields: fields,
@@ -128,6 +130,38 @@ func flattenOne(f *ast.File, cfg StructConfig, typeKinds map[string]string, scal
128130
return nil
129131
}
130132

133+
// collectImports extracts unique package import paths from field types that
134+
// contain a dot (e.g. "time.Time" -> "time"). Returns a sorted, deduplicated list.
135+
func collectImports(fields []Field) []string {
136+
// Map well-known qualified type prefixes to their import paths.
137+
pkgToImport := map[string]string{
138+
"time": "time",
139+
"net": "net",
140+
"url": "net/url",
141+
"json": "encoding/json",
142+
"uuid": "github.com/google/uuid",
143+
}
144+
145+
seen := map[string]bool{}
146+
for _, f := range fields {
147+
typ := strings.TrimPrefix(f.Type, "*")
148+
typ = strings.TrimPrefix(typ, "[]")
149+
if dot := strings.IndexByte(typ, '.'); dot > 0 {
150+
pkg := typ[:dot]
151+
if imp, ok := pkgToImport[pkg]; ok && !seen[imp] {
152+
seen[imp] = true
153+
}
154+
}
155+
}
156+
157+
imports := make([]string, 0, len(seen))
158+
for imp := range seen {
159+
imports = append(imports, imp)
160+
}
161+
slices.Sort(imports)
162+
return imports
163+
}
164+
131165
// sortFields sorts fields with "ID" first, then alphabetically by name (case-insensitive).
132166
func sortFields(fields []Field) {
133167
slices.SortStableFunc(fields, func(a, b Field) int {

structs/flatten_test.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,13 @@ func TestResolveType(t *testing.T) {
202202
want: "[]map[string]any",
203203
},
204204
{
205-
name: "cross-package type",
205+
name: "cross-package scalar (time.Time)",
206206
expr: &ast.SelectorExpr{X: &ast.Ident{Name: "time"}, Sel: &ast.Ident{Name: "Time"}},
207+
want: "time.Time",
208+
},
209+
{
210+
name: "cross-package struct",
211+
expr: &ast.SelectorExpr{X: &ast.Ident{Name: "api"}, Sel: &ast.Ident{Name: "Finding"}},
207212
want: "map[string]any",
208213
},
209214
{

structs/options.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,35 @@
11
package structs
22

33
// defaultScalarKinds are Go types that are inexpensive to decode and should stay typed.
4+
// Includes both unqualified names (for types in the same package) and
5+
// fully qualified names like "time.Time" (for cross-package types).
46
var defaultScalarKinds = map[string]bool{
57
"string": true,
68
"bool": true,
9+
"float32": true,
710
"float64": true,
8-
"int64": true,
911
"int": true,
12+
"int8": true,
13+
"int16": true,
14+
"int32": true,
15+
"int64": true,
16+
"uint": true,
17+
"uint8": true,
18+
"uint16": true,
19+
"uint32": true,
20+
"uint64": true,
21+
"byte": true,
22+
"rune": true,
23+
24+
"time.Time": true,
25+
"time.Duration": true,
26+
"net.IP": true,
27+
"net.HardwareAddr": true,
28+
"url.URL": true,
29+
"json.RawMessage": true,
30+
"json.Number": true,
31+
32+
"uuid.UUID": true,
1033
}
1134

1235
type options struct {

structs/resolve.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,18 @@ func resolveType(expr ast.Expr, typeKinds map[string]string, scalars map[string]
8080
}
8181
return "[]" + inner
8282

83+
case *ast.SelectorExpr:
84+
// Cross-package types like time.Time.
85+
if pkg, ok := t.X.(*ast.Ident); ok {
86+
qualified := pkg.Name + "." + t.Sel.Name
87+
if scalars[qualified] {
88+
return qualified
89+
}
90+
}
91+
return "map[string]any"
92+
8393
default:
84-
// SelectorExpr (cross-package types), MapType, InterfaceType, etc.
94+
// MapType, InterfaceType, etc.
8595
return "map[string]any"
8696
}
8797
}

structs/templates/struct.go.tpl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
// {{ .HeaderComment }}
22

33
package {{ .PackageName }}
4-
4+
{{ if .Imports }}
5+
import (
6+
{{- range .Imports }}
7+
"{{ . }}"
8+
{{- end }}
9+
)
10+
{{ end }}
511
// {{ .OutputName }} is an optimized version of {{ .SourceName }}
612
// where deeply nested struct fields are replaced with map[string]any to avoid
713
// the decode -> allocate -> re-encode cycle.

0 commit comments

Comments
 (0)