Skip to content

Commit ca59a0d

Browse files
fix: memory leak on reload
1 parent b1c3e4d commit ca59a0d

File tree

5 files changed

+233
-16
lines changed

5 files changed

+233
-16
lines changed

collector.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import (
1818
type Collector interface {
1919
// Collect is the equivalent of prometheus.Collector.Collect() but takes a context to run in and a database to run on.
2020
Collect(context.Context, *sql.DB, chan<- Metric)
21+
// Close releases any resources held by the collector (e.g. prepared statements).
22+
Close() error
2123
}
2224

2325
// collector implements Collector. It wraps a collection of queries, metrics and the database to collect them from.
@@ -84,6 +86,25 @@ func (c *collector) Collect(ctx context.Context, conn *sql.DB, ch chan<- Metric)
8486
wg.Wait()
8587
}
8688

89+
// Close releases all prepared statements held by this collector's queries.
90+
func (c *collector) Close() error {
91+
var errs []error
92+
for _, q := range c.queries {
93+
if err := q.Close(); err != nil {
94+
errs = append(errs, err)
95+
}
96+
}
97+
if len(errs) > 0 {
98+
return fmt.Errorf("collector %s close errors: %v", c.logContext, errs)
99+
}
100+
return nil
101+
}
102+
103+
// Close implements Collector for cachingCollector by delegating to the
// wrapped raw collector, which releases its queries' prepared statements.
func (cc *cachingCollector) Close() error {
	return cc.rawColl.Close()
}
107+
87108
// newCachingCollector returns a new Collector wrapping the provided raw Collector.
88109
func newCachingCollector(rawColl *collector) Collector {
89110
cc := &cachingCollector{

query.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,14 @@ func (q *Query) scanRow(rows *sql.Rows, dest []any) (map[string]any, errors.With
235235
}
236236
return result, nil
237237
}
238+
239+
// Close releases the prepared statement if one was cached.
240+
func (q *Query) Close() error {
241+
if q.stmt != nil {
242+
err := q.stmt.Close()
243+
q.stmt = nil
244+
q.conn = nil
245+
return err
246+
}
247+
return nil
248+
}

reload.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ func Reload(e Exporter, configFile *string) error {
2525
configCurrent.Collectors = configCurrent.Collectors[:0]
2626
}
2727
configCurrent.Collectors = configNext.Collectors
28-
slog.Debug("Total collector size change", "from", len(configCurrent.Collectors), "to", len(configNext.Collectors))
28+
slog.Debug("Total collector size change", "from", len(configCurrent.Collectors), "to",
29+
len(configNext.Collectors))
2930

3031
// Reload targets
3132
switch {
@@ -49,27 +50,24 @@ func Reload(e Exporter, configFile *string) error {
4950
// reloadTarget rebuilds the single configured target from the next config
// revision nc, preserving the DSN from the current config cc, then swaps the
// new target into the exporter.
func reloadTarget(e Exporter, nc, cc *cfg.Config) error {
	slog.Warn("Recreating target...")

	// Preserve connection details (DSN) from the previous config revision;
	// collectors and other target settings come from the new revision.
	nc.Target.DSN = cc.Target.DSN
	cc.Target = nc.Target
	// Recreate the target object from the merged configuration.
	target, err := NewTarget("", cc.Target.Name, "", string(cc.Target.DSN),
		cc.Target.Collectors(), nil, cc.Globals, cc.Target.EnablePing)
	if err != nil {
		slog.Error("Error recreating a target", "error", err)
		return err
	}

	// Close old targets before replacing — releases sql.DB pools and sql.Stmts.
	closeTargets(e.Targets())
	e.UpdateTarget([]Target{target})
	slog.Warn("Collectors have been successfully updated for the target")
	return nil
}
6968

7069
func reloadJobs(e Exporter, nc, cc *cfg.Config) error {
7170
slog.Warn("Recreating jobs...")
72-
// We want to preserve `static_configs`` from the previous config revision to avoid any connection changes
7371
for _, currentJob := range cc.Jobs {
7472
for _, newJob := range nc.Jobs {
7573
if newJob.Name == currentJob.Name {
@@ -96,7 +94,19 @@ func reloadJobs(e Exporter, nc, cc *cfg.Config) error {
9694
return updateErr
9795
}
9896

97+
// Close old targets before replacing — releases sql.DB pools and sql.Stmts.
98+
closeTargets(e.Targets())
9999
e.UpdateTarget(targets)
100100
slog.Warn("Collectors have been successfully updated for the jobs")
101101
return nil
102102
}
103+
104+
// closeTargets closes each target's database connection and prepared statements, logging but not propagating errors so
105+
// a single bad close does not block the rest.
106+
func closeTargets(targets []Target) {
107+
for _, t := range targets {
108+
if err := t.Close(); err != nil {
109+
slog.Warn("Error closing target during reload", "error", err)
110+
}
111+
}
112+
}

reload_test.go

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package sql_exporter
2+
3+
import (
4+
"fmt"
5+
_ "net/http/pprof"
6+
"os"
7+
"path/filepath"
8+
"runtime"
9+
"strings"
10+
"testing"
11+
12+
_ "github.com/mithrandie/csvq-driver"
13+
14+
"github.com/prometheus/client_golang/prometheus"
15+
)
16+
17+
// setupCSVDir creates a temp directory with a minimal CSV file usable as a table.
18+
func setupCSVDirs(t *testing.T, n int) []string {
19+
t.Helper()
20+
base := t.TempDir()
21+
dirs := make([]string, n)
22+
for i := range n {
23+
dir := filepath.Join(base, fmt.Sprintf("csv_%d", i))
24+
if err := os.MkdirAll(dir, 0o755); err != nil {
25+
t.Fatalf("mkdir CSV dir %d: %v", i, err)
26+
}
27+
if err := os.WriteFile(filepath.Join(dir, "metrics.csv"), []byte("value\n1\n"), 0o644); err != nil {
28+
t.Fatalf("write CSV %d: %v", i, err)
29+
}
30+
dirs[i] = dir
31+
}
32+
return dirs
33+
}
34+
35+
// writeConfig writes a sql_exporter YAML config file pointing at csvDir with
36+
// n targets, returning the config file path.
37+
func writeConfig(t *testing.T, dirs []string, n int) string {
38+
t.Helper()
39+
40+
var sb strings.Builder
41+
for i := range n {
42+
fmt.Fprintf(&sb, `
43+
- collector_name: col%d
44+
metrics:
45+
- metric_name: csvq_value_%d
46+
type: gauge
47+
help: "test metric %d"
48+
values: [value]
49+
query: "SELECT value FROM metrics"
50+
`, i, i, i)
51+
}
52+
collectors := sb.String()
53+
54+
sb.Reset()
55+
for i := range n {
56+
fmt.Fprintf(&sb, " target%d: csvq:%s\n", i, dirs[i])
57+
}
58+
targets := sb.String()
59+
60+
content := fmt.Sprintf(`
61+
global:
62+
scrape_timeout: 10s
63+
scrape_timeout_offset: 500ms
64+
min_interval: 0s
65+
max_connections: 3
66+
max_idle_connections: 3
67+
68+
collector_files: []
69+
70+
collectors:%s
71+
72+
jobs:
73+
- job_name: test_job
74+
collectors: [col0, col1, col2, col3, col4, col5, col6, col7, col8, col9]
75+
static_configs:
76+
- targets:
77+
%s`, collectors, targets)
78+
79+
cfgFile := filepath.Join(t.TempDir(), "sql_exporter.yml")
80+
if err := os.WriteFile(cfgFile, []byte(content), 0o644); err != nil {
81+
t.Fatalf("write config: %v", err)
82+
}
83+
84+
return cfgFile
85+
}
86+
87+
// printMemStats logs the current heap allocation (in MB) and live heap object
// count under the given label, for eyeballing memory growth across cycles.
func printMemStats(t *testing.T, label string) {
	t.Helper()
	ms := runtime.MemStats{}
	runtime.ReadMemStats(&ms)
	heapMB := float64(ms.HeapAlloc) / 1024 / 1024
	t.Logf("[%s] HeapAlloc=%.2f MB HeapObjects=%d", label, heapMB, ms.HeapObjects)
}
94+
95+
// runReloadCycles performs numCycles config reloads against the exporter and
// returns the goroutine-count delta versus the start. A near-zero delta after
// many cycles indicates targets (and their sql.DB pools) are being released.
func runReloadCycles(t *testing.T, e Exporter, cfgFile string, numCycles int) int {
	t.Helper()

	printMemStats(t, "initial")
	initialGoroutines := runtime.NumGoroutine()
	t.Logf("initial goroutines: %d", initialGoroutines)

	for cycle := 1; cycle <= numCycles; cycle++ {
		// Explicitly close the current targets before reloading.
		// NOTE(review): Reload itself appears to close old targets via
		// closeTargets, which would make this a double Close — presumably
		// Target.Close is idempotent; verify, or drop one of the two.
		for _, old := range e.Targets() {
			if err := old.Close(); err != nil {
				t.Logf("cycle %02d close error: %v", cycle, err)
			}
		}

		if err := Reload(e, &cfgFile); err != nil {
			t.Fatalf("cycle %02d Reload: %v", cycle, err)
		}

		// Every 10 cycles, force a GC so the logged numbers reflect live
		// objects/goroutines rather than garbage awaiting collection.
		if cycle%10 == 0 {
			runtime.GC()
			goroutines := runtime.NumGoroutine()
			printMemStats(t, fmt.Sprintf("cycle %02d", cycle))
			t.Logf("cycle %02d | goroutines: %d (+%d vs initial)",
				cycle, goroutines, goroutines-initialGoroutines)
		}
	}

	return runtime.NumGoroutine() - initialGoroutines
}
124+
125+
func TestReloadMemoryLeak(t *testing.T) {
126+
const (
127+
numTargets = 10
128+
numCycles = 500
129+
tolerance = 5
130+
)
131+
132+
dirs := setupCSVDirs(t, numTargets)
133+
cfgFile := writeConfig(t, dirs, numTargets)
134+
135+
e, err := NewExporter(cfgFile, prometheus.NewRegistry())
136+
if err != nil {
137+
t.Fatalf("NewExporter: %v", err)
138+
}
139+
140+
delta := runReloadCycles(t, e, cfgFile, numCycles)
141+
142+
t.Logf("goroutine delta=%d (expected <= %d)", delta, tolerance)
143+
if delta > tolerance {
144+
t.Errorf("expected goroutine delta <= %d, got %d — leak still present", tolerance, delta)
145+
}
146+
}

target.go

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ const (
3232
type Target interface {
3333
// Collect is the equivalent of prometheus.Collector.Collect(), but takes a context to run in.
3434
Collect(ctx context.Context, ch chan<- Metric)
35+
Close() error
3536
JobGroup() string
3637
}
3738

@@ -88,8 +89,10 @@ func NewTarget(
8889
collectors = append(collectors, c)
8990
}
9091

91-
upDesc := NewAutomaticMetricDesc(logContext, upMetricName, upMetricHelp, prometheus.GaugeValue, constLabelPairs)
92-
scrapeDurationDesc := NewAutomaticMetricDesc(logContext, scrapeDurationName, scrapeDurationHelp, prometheus.GaugeValue, constLabelPairs)
92+
upDesc := NewAutomaticMetricDesc(logContext, upMetricName, upMetricHelp,
93+
prometheus.GaugeValue, constLabelPairs)
94+
scrapeDurationDesc := NewAutomaticMetricDesc(logContext, scrapeDurationName, scrapeDurationHelp,
95+
prometheus.GaugeValue, constLabelPairs)
9396
t := target{
9497
name: tname,
9598
jobGroup: jg,
@@ -143,12 +146,37 @@ func (t *target) Collect(ctx context.Context, ch chan<- Metric) {
143146
}
144147
}
145148

149+
// Close closes all collectors' prepared statements and the underlying *sql.DB connection pool. Safe to call even if
150+
// the connection was never opened.
151+
func (t *target) Close() error {
152+
var errs []error
153+
// Close prepared statements first — before the db handle they reference is gone.
154+
for _, c := range t.collectors {
155+
if err := c.Close(); err != nil {
156+
errs = append(errs, err)
157+
}
158+
}
159+
// Close the connection pool, which terminates all internal sql.DB goroutines (connectionOpener,
160+
// connectionResetter) and releases idle connections.
161+
if t.conn != nil {
162+
if err := t.conn.Close(); err != nil {
163+
errs = append(errs, err)
164+
}
165+
t.conn = nil
166+
}
167+
if len(errs) > 0 {
168+
return fmt.Errorf("target %s close errors: %v", t.logContext, errs)
169+
}
170+
return nil
171+
}
172+
146173
func (t *target) ping(ctx context.Context) errors.WithContext {
147-
// Create the DB handle, if necessary. It won't usually open an actual connection, so we'll need to ping afterwards.
148-
// We cannot do this only once at creation time because the sql.Open() documentation says it "may" open an actual
149-
// connection, so it "may" actually fail to open a handle to a DB that's initially down.
174+
// Create the DB handle, if necessary. It won't usually open an actual connection, so we'll need to ping
175+
// afterwards. We cannot do this only once at creation time because the sql.Open() documentation says it "may" open
176+
// an actual connection, so it "may" actually fail to open a handle to a DB that's initially down.
150177
if t.conn == nil {
151-
conn, err := OpenConnection(ctx, t.logContext, t.dsn, t.globalConfig.MaxConns, t.globalConfig.MaxIdleConns, t.globalConfig.MaxConnLifetime)
178+
conn, err := OpenConnection(ctx, t.logContext, t.dsn, t.globalConfig.MaxConns,
179+
t.globalConfig.MaxIdleConns, t.globalConfig.MaxConnLifetime)
152180
if err != nil {
153181
if err != ctx.Err() {
154182
return errors.Wrap(t.logContext, err)
@@ -160,12 +188,13 @@ func (t *target) ping(ctx context.Context) errors.WithContext {
160188
}
161189

162190
// If we have a handle and the context is not closed, test whether the database is up.
163-
// FIXME: we ping the database during each request even with cacheCollector. It leads
164-
// to additional charges for paid database services.
191+
// FIXME: we ping the database during each request even with cacheCollector. It leads to additional charges for
192+
// paid database services.
165193
if t.conn != nil && ctx.Err() == nil && *t.enablePing {
166194
var err error
167-
// Ping up to max_connections + 1 times as long as the returned error is driver.ErrBadConn, to purge the connection
168-
// pool of bad connections. This might happen if the previous scrape timed out and in-flight queries got canceled.
195+
// Ping up to max_connections + 1 times as long as the returned error is driver.ErrBadConn, to purge the
196+
// connection pool of bad connections. This might happen if the previous scrape timed out and in-flight queries
197+
// got canceled.
169198
for i := 0; i <= t.globalConfig.MaxConns; i++ {
170199
if err = PingDB(ctx, t.conn); err != driver.ErrBadConn {
171200
break

0 commit comments

Comments
 (0)