Skip to content

Commit de8ece6

Browse files
feat: Implement type flattener (#419)
* feat: Implement structs package to allow flattening an autogenerated struct
* feat: Add tests
* feat: Add cmd for go:generate commands
* chore: lint
* fix: Handle code review comments
* fix: Allow multiname type params
1 parent 71d532b commit de8ece6

File tree

13 files changed

+907
-0
lines changed

13 files changed

+907
-0
lines changed

cmd/flatten/main.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// flatten generates an optimized Go struct from a source file.
2+
//
3+
// Usage:
4+
//
5+
// //go:generate go run github.com/cloudquery/codegen/cmd/flatten -source=../api/models.go -from=Finding -to=flatFinding -output=finding_generated.go -package=services -sort
6+
// //go:generate go run github.com/cloudquery/codegen/cmd/flatten -source=../api/models.go -from=Finding -to=flatFinding -output=finding_generated.go -package=services -sort -extra=AssetType:string:assetType,omitempty
7+
package main
8+
9+
import (
10+
"flag"
11+
"fmt"
12+
"log"
13+
"path/filepath"
14+
"strings"
15+
16+
"github.com/cloudquery/codegen/structs"
17+
)
18+
19+
func main() {
20+
source := flag.String("source", "", "path to Go source file containing the struct")
21+
from := flag.String("from", "", "name of the source struct")
22+
to := flag.String("to", "", "name of the generated struct")
23+
output := flag.String("output", "", "output filename")
24+
pkg := flag.String("package", "", "Go package name for the generated file")
25+
sort := flag.Bool("sort", false, "sort fields: ID first, then alphabetically")
26+
extra := flag.String("extra", "", "extra fields to prepend (Name:Type:JSONTag)")
27+
flag.Parse()
28+
29+
if *source == "" || *from == "" || *to == "" || *output == "" {
30+
flag.Usage()
31+
log.Fatal("required flags: -source, -from, -to, -output")
32+
}
33+
34+
cfg := structs.StructConfig{
35+
SourceName: *from,
36+
OutputName: *to,
37+
OutputFile: filepath.Base(*output),
38+
ExtraFields: parseExtraFields(*extra),
39+
}
40+
41+
outputDir := filepath.Dir(*output)
42+
if outputDir == "" || outputDir == "." {
43+
outputDir = "."
44+
}
45+
46+
var opts []structs.Option
47+
if *pkg != "" {
48+
opts = append(opts, structs.WithPackageName(*pkg))
49+
}
50+
if *sort {
51+
opts = append(opts, structs.WithSortFields())
52+
}
53+
54+
if err := structs.Flatten(*source, []structs.StructConfig{cfg}, outputDir, opts...); err != nil {
55+
log.Fatal(err)
56+
}
57+
}
58+
59+
func parseExtraFields(s string) []structs.Field {
60+
if s == "" {
61+
return nil
62+
}
63+
parts := strings.SplitN(s, ":", 3)
64+
if len(parts) != 3 {
65+
log.Fatalf("invalid -extra %q: expected Name:Type:JSONTag", s)
66+
}
67+
return []structs.Field{{
68+
Name: strings.TrimSpace(parts[0]),
69+
Type: strings.TrimSpace(parts[1]),
70+
JSONTag: strings.TrimSpace(parts[2]),
71+
}}
72+
}
73+
74+
// init replaces the default flag.Usage with one that prints a short
// description of the tool ahead of the flag defaults.
func init() {
	const banner = "Usage: flatten [flags]\n\nGenerates an optimized Go struct with nested fields replaced by map[string]any.\n\nFlags:\n"

	flag.Usage = func() {
		w := flag.CommandLine.Output()
		// A failed write of the usage text is deliberately ignored.
		_, _ = fmt.Fprint(w, banner)
		flag.PrintDefaults()
	}
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Code generated by codegen/structs; DO NOT EDIT.
2+
3+
package testpkg
4+
5+
// flatFinding is an optimized version of Finding
6+
// where deeply nested struct fields are replaced with map[string]any to avoid
7+
// the decode -> allocate -> re-encode cycle.
8+
type flatFinding struct {
9+
ID string `json:"id"`
10+
Name string `json:"name"`
11+
Score *float64 `json:"score,omitempty"`
12+
Severity string `json:"severity"`
13+
Status string `json:"status"`
14+
Deleted bool `json:"deleted"`
15+
Count int64 `json:"count"`
16+
Address map[string]any `json:"address,omitempty"`
17+
Tags []map[string]any `json:"tags,omitempty"`
18+
Connection map[string]any `json:"connection"`
19+
Categories []string `json:"categories,omitempty"`
20+
Entity map[string]any `json:"entity"`
21+
}
22+
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// Code generated by codegen/structs; DO NOT EDIT.
2+
3+
package testpkg
4+
5+
// extendedSimple is an optimized version of Simple
6+
// where deeply nested struct fields are replaced with map[string]any to avoid
7+
// the decode -> allocate -> re-encode cycle.
8+
type extendedSimple struct {
9+
AssetType string `json:"assetType,omitempty"`
10+
ID string `json:"id"`
11+
Value int64 `json:"value"`
12+
}
13+
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Code generated by codegen/structs; DO NOT EDIT.
2+
3+
package testpkg
4+
5+
// flatFinding is an optimized version of Finding
6+
// where deeply nested struct fields are replaced with map[string]any to avoid
7+
// the decode -> allocate -> re-encode cycle.
8+
type flatFinding struct {
9+
ID string `json:"id"`
10+
Name string `json:"name"`
11+
Score *float64 `json:"score,omitempty"`
12+
Severity string `json:"severity"`
13+
Status string `json:"status"`
14+
Deleted bool `json:"deleted"`
15+
Count int64 `json:"count"`
16+
Address map[string]any `json:"address,omitempty"`
17+
Tags []map[string]any `json:"tags,omitempty"`
18+
Connection map[string]any `json:"connection"`
19+
Categories []string `json:"categories,omitempty"`
20+
Entity map[string]any `json:"entity"`
21+
}
22+
23+
// Code generated by codegen/structs; DO NOT EDIT.
24+
25+
package testpkg
26+
27+
// flatSimple is an optimized version of Simple
28+
// where deeply nested struct fields are replaced with map[string]any to avoid
29+
// the decode -> allocate -> re-encode cycle.
30+
type flatSimple struct {
31+
ID string `json:"id"`
32+
Value int64 `json:"value"`
33+
}
34+
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Code generated by codegen/structs; DO NOT EDIT.
2+
3+
package testpkg
4+
5+
// sortedFinding is an optimized version of Finding
6+
// where deeply nested struct fields are replaced with map[string]any to avoid
7+
// the decode -> allocate -> re-encode cycle.
8+
type sortedFinding struct {
9+
ID string `json:"id"`
10+
Address map[string]any `json:"address,omitempty"`
11+
Categories []string `json:"categories,omitempty"`
12+
Connection map[string]any `json:"connection"`
13+
Count int64 `json:"count"`
14+
Deleted bool `json:"deleted"`
15+
Entity map[string]any `json:"entity"`
16+
Name string `json:"name"`
17+
Score *float64 `json:"score,omitempty"`
18+
Severity string `json:"severity"`
19+
Status string `json:"status"`
20+
Tags []map[string]any `json:"tags,omitempty"`
21+
}
22+
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Code generated by codegen/structs; DO NOT EDIT.
2+
3+
package testpkg
4+
5+
// sortedFinding is an optimized version of Finding
6+
// where deeply nested struct fields are replaced with map[string]any to avoid
7+
// the decode -> allocate -> re-encode cycle.
8+
type sortedFinding struct {
9+
AssetType string `json:"assetType,omitempty"`
10+
ID string `json:"id"`
11+
Address map[string]any `json:"address,omitempty"`
12+
Categories []string `json:"categories,omitempty"`
13+
Connection map[string]any `json:"connection"`
14+
Count int64 `json:"count"`
15+
Deleted bool `json:"deleted"`
16+
Entity map[string]any `json:"entity"`
17+
Name string `json:"name"`
18+
Score *float64 `json:"score,omitempty"`
19+
Severity string `json:"severity"`
20+
Status string `json:"status"`
21+
Tags []map[string]any `json:"tags,omitempty"`
22+
}
23+
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Code generated by codegen/structs; DO NOT EDIT.
2+
3+
package testpkg
4+
5+
// flatFinding is an optimized version of Finding
6+
// where deeply nested struct fields are replaced with map[string]any to avoid
7+
// the decode -> allocate -> re-encode cycle.
8+
type flatFinding struct {
9+
ID string `json:"id"`
10+
Name string `json:"name"`
11+
Score *float64 `json:"score,omitempty"`
12+
Severity string `json:"severity"`
13+
Status string `json:"status"`
14+
Deleted bool `json:"deleted"`
15+
Count int64 `json:"count"`
16+
Address *Address `json:"address,omitempty"`
17+
Tags []map[string]any `json:"tags,omitempty"`
18+
Connection map[string]any `json:"connection"`
19+
Categories []string `json:"categories,omitempty"`
20+
Entity map[string]any `json:"entity"`
21+
}
22+

structs/flatten.go

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
// Package structs provides code generation utilities for optimizing Go struct
2+
// types. The primary use case is flattening deeply nested API model structs
3+
// into decode-efficient versions where nested struct fields are replaced with
4+
// map[string]any, eliminating the decode -> allocate -> re-encode cycle that occurs
5+
// when the CloudQuery SDK serializes nested fields as JSON columns.
6+
package structs
7+
8+
import (
9+
"bytes"
10+
"embed"
11+
"errors"
12+
"fmt"
13+
"go/ast"
14+
"go/format"
15+
"os"
16+
"path/filepath"
17+
"slices"
18+
"strings"
19+
"text/template"
20+
)
21+
22+
//go:embed templates/*.go.tpl
23+
var templatesFS embed.FS
24+
25+
var tmpl = template.Must(template.ParseFS(templatesFS, "templates/*.go.tpl"))
26+
27+
// StructConfig defines how to transform one source struct into an optimized version.
28+
type StructConfig struct {
29+
// SourceName is the struct name in the source file (e.g. "VulnerabilityFinding").
30+
SourceName string
31+
// OutputName is the generated struct name (e.g. "customVulnerabilityFinding").
32+
OutputName string
33+
// OutputFile is the filename for the generated file (e.g. "vulnerability_finding_generated.go").
34+
OutputFile string
35+
// ExtraFields are additional fields prepended to the generated struct.
36+
ExtraFields []Field
37+
}
38+
39+
// Field describes a single field of a generated struct.
type Field struct {
	// Name is the Go field name, e.g. "AssetType".
	Name string
	// Type is the field's Go type written as source text, e.g. "string"
	// or "map[string]any".
	Type string
	// JSONTag is the value used for the field's json struct tag,
	// e.g. "assetType,omitempty".
	JSONTag string
}
45+
46+
type templateData struct {
47+
HeaderComment string
48+
PackageName string
49+
SourceName string
50+
OutputName string
51+
Fields []Field
52+
}
53+
54+
// Flatten parses the Go source file at srcPath, finds the structs described by
55+
// configs, replaces deeply nested struct fields with map[string]any, and writes
56+
// generated files into outputDir.
57+
//
58+
// Scalar types and string-based enum types
59+
// are kept typed. All other types (structs, interfaces, cross-package types,
60+
// maps) are flattened to map[string]any. Pointer wrappers are preserved for
61+
// scalar types but collapsed for complex types.
62+
func Flatten(srcPath string, configs []StructConfig, outputDir string, opts ...Option) error {
63+
var o options
64+
for _, opt := range opts {
65+
opt(&o)
66+
}
67+
o.setDefaults(outputDir)
68+
69+
f, err := parseFile(srcPath)
70+
if err != nil {
71+
return err
72+
}
73+
74+
typeKinds := buildTypeKinds(f)
75+
scalars := o.scalars()
76+
77+
var errs []error
78+
for _, cfg := range configs {
79+
if err := flattenOne(f, cfg, typeKinds, scalars, outputDir, &o); err != nil {
80+
errs = append(errs, fmt.Errorf("%s: %w", cfg.SourceName, err))
81+
}
82+
}
83+
return errors.Join(errs...)
84+
}
85+
86+
func flattenOne(f *ast.File, cfg StructConfig, typeKinds map[string]string, scalars map[string]bool, outputDir string, o *options) error {
87+
st, err := findStruct(f, cfg.SourceName)
88+
if err != nil {
89+
return err
90+
}
91+
92+
fields := make([]Field, 0, len(cfg.ExtraFields)+len(st.Fields.List))
93+
fields = append(fields, cfg.ExtraFields...)
94+
sourceFields := extractFields(st, typeKinds, scalars)
95+
if o.sortFields {
96+
sortFields(sourceFields)
97+
}
98+
fields = append(fields, sourceFields...)
99+
100+
data := templateData{
101+
HeaderComment: o.headerComment,
102+
PackageName: o.packageName,
103+
SourceName: cfg.SourceName,
104+
OutputName: cfg.OutputName,
105+
Fields: fields,
106+
}
107+
108+
var buf bytes.Buffer
109+
if err := tmpl.ExecuteTemplate(&buf, "struct.go.tpl", data); err != nil {
110+
return fmt.Errorf("executing template: %w", err)
111+
}
112+
113+
formatted, err := format.Source(buf.Bytes())
114+
if err != nil {
115+
fmt.Fprintf(os.Stderr, "warning: could not format generated code for %s: %v\n", cfg.OutputFile, err)
116+
formatted = buf.Bytes()
117+
}
118+
119+
outputPath := filepath.Join(outputDir, cfg.OutputFile)
120+
if err := os.MkdirAll(filepath.Dir(outputPath), 0o755); err != nil {
121+
return fmt.Errorf("creating output directory: %w", err)
122+
}
123+
if err := os.WriteFile(outputPath, formatted, 0o644); err != nil {
124+
return fmt.Errorf("writing %s: %w", outputPath, err)
125+
}
126+
127+
fmt.Printf("Generated %s with %d fields\n", outputPath, len(fields))
128+
return nil
129+
}
130+
131+
// sortFields sorts fields with "ID" first, then alphabetically by name (case-insensitive).
132+
func sortFields(fields []Field) {
133+
slices.SortStableFunc(fields, func(a, b Field) int {
134+
aIsID := strings.EqualFold(a.Name, "ID")
135+
bIsID := strings.EqualFold(b.Name, "ID")
136+
switch {
137+
case aIsID && !bIsID:
138+
return -1
139+
case !aIsID && bIsID:
140+
return 1
141+
default:
142+
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
143+
}
144+
})
145+
}

0 commit comments

Comments
 (0)