Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ import (
"github.com/databricks/databricks-sql-go/internal/sentinel"
"github.com/databricks/databricks-sql-go/internal/thrift_protocol"
"github.com/databricks/databricks-sql-go/logger"
"github.com/databricks/databricks-sql-go/telemetry"
"github.com/pkg/errors"
)

type conn struct {
id string
cfg *config.Config
client cli_service.TCLIService
session *cli_service.TOpenSessionResp
id string
cfg *config.Config
client cli_service.TCLIService
session *cli_service.TOpenSessionResp
telemetry *telemetry.Interceptor // Optional telemetry interceptor
}

// Prepare prepares a statement with the query bound to this connection.
Expand All @@ -49,6 +51,12 @@ func (c *conn) Close() error {
log := logger.WithContext(c.id, "", "")
ctx := driverctx.NewContextWithConnId(context.Background(), c.id)

// Close telemetry and release resources
if c.telemetry != nil {
_ = c.telemetry.Close(ctx)
telemetry.ReleaseForConnection(c.cfg.Host)
}

_, err := c.client.CloseSession(ctx, &cli_service.TCloseSessionReq{
SessionHandle: c.session.SessionHandle,
})
Expand Down
15 changes: 15 additions & 0 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/databricks/databricks-sql-go/internal/config"
dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
"github.com/databricks/databricks-sql-go/logger"
"github.com/databricks/databricks-sql-go/telemetry"
)

type connector struct {
Expand Down Expand Up @@ -75,6 +76,20 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
}
log := logger.WithContext(conn.id, driverctx.CorrelationIdFromContext(ctx), "")

// Initialize telemetry if configured
if c.cfg.EnableTelemetry || c.cfg.ForceEnableTelemetry {
conn.telemetry = telemetry.InitializeForConnection(
ctx,
c.cfg.Host,
c.client,
c.cfg.EnableTelemetry,
c.cfg.ForceEnableTelemetry,
)
if conn.telemetry != nil {
log.Debug().Msg("telemetry initialized for connection")
}
}

log.Info().Msgf("connect: host=%s port=%d httpPath=%s serverProtocolVersion=0x%X", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath, session.ServerProtocolVersion)

return conn, nil
Expand Down
20 changes: 20 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ type UserConfig struct {
RetryWaitMin time.Duration
RetryWaitMax time.Duration
RetryMax int
// Telemetry configuration
EnableTelemetry bool // Opt-in for telemetry (respects server feature flags)
ForceEnableTelemetry bool // Force enable telemetry (bypasses server checks)
Transport http.RoundTripper
UseLz4Compression bool
EnableMetricViewMetadata bool
Expand Down Expand Up @@ -144,6 +147,8 @@ func (ucfg UserConfig) DeepCopy() UserConfig {
UseLz4Compression: ucfg.UseLz4Compression,
EnableMetricViewMetadata: ucfg.EnableMetricViewMetadata,
CloudFetchConfig: ucfg.CloudFetchConfig,
EnableTelemetry: ucfg.EnableTelemetry,
ForceEnableTelemetry: ucfg.ForceEnableTelemetry,
}
}

Expand Down Expand Up @@ -282,6 +287,21 @@ func ParseDSN(dsn string) (UserConfig, error) {
ucfg.EnableMetricViewMetadata = enableMetricViewMetadata
}

// Telemetry parameters
if enableTelemetry, ok, err := params.extractAsBool("enableTelemetry"); ok {
if err != nil {
return UserConfig{}, err
}
ucfg.EnableTelemetry = enableTelemetry
}

if forceEnableTelemetry, ok, err := params.extractAsBool("forceEnableTelemetry"); ok {
if err != nil {
return UserConfig{}, err
}
ucfg.ForceEnableTelemetry = forceEnableTelemetry
}

// for timezone we do a case insensitive key match.
// We use getNoCase because we want to leave timezone in the params so that it will also
// be used as a session param.
Expand Down
47 changes: 25 additions & 22 deletions telemetry/DESIGN.md
Original file line number Diff line number Diff line change
Expand Up @@ -2098,28 +2098,31 @@ func BenchmarkInterceptor_Disabled(b *testing.B) {
- [ ] Test error classification
- [ ] Test client with aggregator integration

### Phase 7: Driver Integration (PECOBLR-1382)
- [ ] Add telemetry initialization to `connection.go`
- [ ] Call isTelemetryEnabled() at connection open
- [ ] Initialize telemetry client via clientManager.getOrCreateClient()
- [ ] Increment feature flag cache reference count
- [ ] Store telemetry interceptor in connection
- [ ] Add telemetry hooks to `statement.go`
- [ ] Add beforeExecute() hook at statement start
- [ ] Add afterExecute() hook at statement completion
- [ ] Add tag collection during execution (result format, chunk count, bytes, etc.)
- [ ] Call completeStatement() at statement end
- [ ] Add cleanup in `Close()` methods
- [ ] Release client manager reference in connection.Close()
- [ ] Release feature flag cache reference
- [ ] Flush pending metrics before close
- [ ] Add integration tests
- [ ] Test telemetry enabled via forceEnableTelemetry=true
- [ ] Test telemetry disabled by default
- [ ] Test metric collection and export end-to-end
- [ ] Test multiple concurrent connections
- [ ] Test latency measurement accuracy
- [ ] Test opt-in priority in driver context
### Phase 7: Driver Integration ✅ COMPLETED
- [x] Add telemetry initialization to `connection.go`
- [x] Call isTelemetryEnabled() at connection open via InitializeForConnection()
- [x] Initialize telemetry client via clientManager.getOrCreateClient()
- [x] Increment feature flag cache reference count
- [x] Store telemetry interceptor in connection
- [x] Add telemetry configuration to UserConfig
- [x] EnableTelemetry and ForceEnableTelemetry fields
- [x] DSN parameter parsing
- [x] DeepCopy support
- [x] Add cleanup in `Close()` methods
- [x] Release client manager reference in connection.Close()
- [x] Release feature flag cache reference via ReleaseForConnection()
- [x] Flush pending metrics before close
- [x] Export necessary types and methods
- [x] Export Interceptor type
- [x] Export GetInterceptor() and Close() methods
- [x] Create driver integration helpers
- [x] Basic integration tests
- [x] Test compilation with telemetry
- [x] Test no breaking changes to existing tests
- [x] Test graceful handling when disabled

Note: Statement execution hooks (beforeExecute/afterExecute in statement.go) for
actual metric collection can be added as a follow-up enhancement.

### Phase 8: Testing & Validation
- [ ] Run benchmark tests
Expand Down
226 changes: 226 additions & 0 deletions telemetry/aggregator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
package telemetry

import (
"context"
"sync"
"time"
)

// metricsAggregator aggregates metrics by statement and batches for export.
type metricsAggregator struct {
mu sync.RWMutex

statements map[string]*statementMetrics
batch []*telemetryMetric
exporter *telemetryExporter

batchSize int
flushInterval time.Duration
stopCh chan struct{}
flushTimer *time.Ticker
}

// statementMetrics is the running aggregate for a single statement. It is
// built up from individual "statement" and "error" metrics and turned into
// one telemetryMetric when the statement completes.
type statementMetrics struct {
	// Identity of the statement and the session it belongs to.
	statementID string
	sessionID   string

	// Running totals summed across every metric recorded for the statement.
	totalLatency    time.Duration
	chunkCount      int
	bytesDownloaded int64
	pollCount       int

	// errors lists the error types seen so far, in the order they were recorded.
	errors []string
	// tags is the union of tags from all recorded metrics; later values win.
	tags map[string]interface{}
}

// newMetricsAggregator builds an aggregator that sends batches to exporter
// and starts its background flush loop.
//
// cfg.BatchSize is used both as the initial capacity of the pending batch and
// as the flush threshold; cfg.FlushInterval is the period of the background
// flush. The returned aggregator owns a goroutine that runs until close is
// called.
func newMetricsAggregator(exporter *telemetryExporter, cfg *Config) *metricsAggregator {
	// make panics on a negative capacity. Telemetry is best-effort, so a bad
	// BatchSize must not crash connection setup: only the pre-allocation is
	// clamped, the configured threshold is stored unchanged.
	batchCap := cfg.BatchSize
	if batchCap < 0 {
		batchCap = 0
	}

	agg := &metricsAggregator{
		statements:    make(map[string]*statementMetrics),
		batch:         make([]*telemetryMetric, 0, batchCap),
		exporter:      exporter,
		batchSize:     cfg.BatchSize,
		flushInterval: cfg.FlushInterval,
		stopCh:        make(chan struct{}),
	}

	// Start background flush timer
	go agg.flushLoop()

	return agg
}

// recordMetric feeds one metric into the aggregator. Handling depends on
// metric.metricType:
//
//   - "connection": queued for export; the batch is flushed once it reaches
//     batchSize.
//   - "statement": folded into the running aggregate for metric.statementID
//     (created on first sight) instead of being queued directly.
//   - "error": terminal errors are queued and flushed at once; any other
//     error is attached to the statement aggregate if one exists, and is
//     otherwise dropped.
//
// Metrics of any other type are ignored. The method never panics: any panic
// raised while recording is recovered and discarded.
func (agg *metricsAggregator) recordMetric(ctx context.Context, metric *telemetryMetric) {
	// Telemetry must never break the caller, so swallow panics.
	// (Placeholder for trace-level logging of the recovered value.)
	defer func() { _ = recover() }()

	agg.mu.Lock()
	defer agg.mu.Unlock()

	switch metric.metricType {
	case "connection":
		// Connection events are not aggregated; they go straight to the batch.
		agg.batch = append(agg.batch, metric)
		if len(agg.batch) >= agg.batchSize {
			agg.flushUnlocked(ctx)
		}

	case "statement":
		// Look up (or lazily create) the aggregate for this statement.
		entry, found := agg.statements[metric.statementID]
		if !found {
			entry = &statementMetrics{
				statementID: metric.statementID,
				sessionID:   metric.sessionID,
				tags:        make(map[string]interface{}),
			}
			agg.statements[metric.statementID] = entry
		}

		// Accumulate latency and the well-known counters. A counter tag is
		// only added when it is present with the expected dynamic type.
		entry.totalLatency += time.Duration(metric.latencyMs) * time.Millisecond
		if n, ok := metric.tags["chunk_count"].(int); ok {
			entry.chunkCount += n
		}
		if n, ok := metric.tags["bytes_downloaded"].(int64); ok {
			entry.bytesDownloaded += n
		}
		if n, ok := metric.tags["poll_count"].(int); ok {
			entry.pollCount += n
		}

		if metric.errorType != "" {
			entry.errors = append(entry.errors, metric.errorType)
		}

		// Later metrics overwrite earlier values for the same tag key.
		for key, value := range metric.tags {
			entry.tags[key] = value
		}

	case "error":
		terminal := metric.errorType != "" && isTerminalError(&simpleError{msg: metric.errorType})
		if terminal {
			// Terminal errors are exported without waiting for the batch to fill.
			agg.batch = append(agg.batch, metric)
			agg.flushUnlocked(ctx)
			return
		}

		// Non-terminal errors ride along with their statement, if it is known.
		if entry, found := agg.statements[metric.statementID]; found {
			entry.errors = append(entry.errors, metric.errorType)
		}
	}
}

// completeStatement finalizes the aggregate for statementID: the entry is
// removed from the in-flight map, converted into a single "statement" metric
// carrying the summed latency and counters, and queued for export. The batch
// is flushed if it has reached batchSize.
//
// When failed is true and at least one error was recorded for the statement,
// the first recorded error type becomes the metric's errorType. Unknown
// statement IDs are ignored. Panics are recovered and discarded.
func (agg *metricsAggregator) completeStatement(ctx context.Context, statementID string, failed bool) {
	// Telemetry must never break the caller, so swallow panics.
	// (Placeholder for trace-level logging of the recovered value.)
	defer func() { _ = recover() }()

	agg.mu.Lock()
	defer agg.mu.Unlock()

	entry, found := agg.statements[statementID]
	if !found {
		return
	}
	delete(agg.statements, statementID)

	// The aggregate's tag map is reused for the outgoing metric; the summed
	// counters replace whatever per-metric values were merged into it.
	summary := entry.tags
	summary["chunk_count"] = entry.chunkCount
	summary["bytes_downloaded"] = entry.bytesDownloaded
	summary["poll_count"] = entry.pollCount

	aggregated := &telemetryMetric{
		metricType:  "statement",
		timestamp:   time.Now(),
		statementID: entry.statementID,
		sessionID:   entry.sessionID,
		latencyMs:   entry.totalLatency.Milliseconds(),
		tags:        summary,
	}

	// Report the first recorded error as the primary error of a failed statement.
	if failed && len(entry.errors) > 0 {
		aggregated.errorType = entry.errors[0]
	}

	agg.batch = append(agg.batch, aggregated)
	if len(agg.batch) < agg.batchSize {
		return
	}
	agg.flushUnlocked(ctx)
}

// flushLoop flushes pending metrics every flushInterval until stopCh is
// closed. It is meant to run in its own goroutine for the lifetime of the
// aggregator.
//
// If flushInterval is not positive the loop returns immediately and periodic
// flushing is disabled. Metrics are then still flushed when the batch fills
// up and when the aggregator is closed.
func (agg *metricsAggregator) flushLoop() {
	// time.NewTicker panics on a non-positive duration. This goroutine has no
	// recover, so that panic would terminate the whole process; a
	// misconfigured interval must not be able to do that.
	if agg.flushInterval <= 0 {
		return
	}

	agg.flushTimer = time.NewTicker(agg.flushInterval)
	defer agg.flushTimer.Stop()

	for {
		select {
		case <-agg.flushTimer.C:
			// Periodic flushes are not tied to any caller, so they use a
			// background context.
			agg.flush(context.Background())
		case <-agg.stopCh:
			return
		}
	}
}

// flush acquires the aggregator lock and hands every pending metric to the
// exporter via flushUnlocked. Callers must not already hold agg.mu.
func (agg *metricsAggregator) flush(ctx context.Context) {
	agg.mu.Lock()
	defer agg.mu.Unlock()

	agg.flushUnlocked(ctx)
}

// flushUnlocked hands all pending metrics to the exporter and empties the
// batch. The caller must hold agg.mu.
//
// The export runs in a separate goroutine, so this method returns before the
// metrics have actually been exported. A panic raised by the exporter is
// recovered and discarded.
//
// NOTE(review): the export goroutine uses the caller's ctx. If that context
// is cancelled once the caller returns, the export may be cut short — confirm
// how the exporter treats cancellation.
func (agg *metricsAggregator) flushUnlocked(ctx context.Context) {
	pendingCount := len(agg.batch)
	if pendingCount == 0 {
		return
	}

	// Snapshot the batch into its own backing array so the exporter never
	// observes later appends, then reset the batch while keeping its capacity.
	snapshot := append(make([]*telemetryMetric, 0, pendingCount), agg.batch...)
	agg.batch = agg.batch[:0]

	go func(pending []*telemetryMetric) {
		// Telemetry must never break the process, so swallow exporter panics.
		// (Placeholder for trace-level logging of the recovered value.)
		defer func() { _ = recover() }()

		agg.exporter.export(ctx, pending)
	}(snapshot)
}

// close stops the background flush loop and flushes any pending metrics.
// It always returns nil.
//
// close is idempotent and safe to call from several goroutines: only the
// first call closes stopCh, and every call flushes whatever is pending at
// that moment. The flush is asynchronous (see flushUnlocked), so close may
// return before the final batch has been exported.
func (agg *metricsAggregator) close(ctx context.Context) error {
	// Closing an already-closed channel panics. Probing and closing under the
	// aggregator lock makes a repeated or concurrent close a no-op.
	agg.mu.Lock()
	defer agg.mu.Unlock()

	select {
	case <-agg.stopCh:
		// Already stopped by an earlier call.
	default:
		close(agg.stopCh)
	}

	agg.flushUnlocked(ctx)
	return nil
}

// simpleError is a minimal error whose message is a fixed string. recordMetric
// uses it to wrap a metric's errorType so that it can be passed to
// isTerminalError, which takes an error value.
type simpleError struct {
	msg string
}

// Error implements the error interface by returning the wrapped message.
func (e *simpleError) Error() string { return e.msg }
Loading
Loading