Skip to content

Commit 9b33fde

Browse files
Merge branch 'main' into fix-float64-type-inference
2 parents ca22468 + a0d5e75 commit 9b33fde

File tree

12 files changed

+2003
-98
lines changed

12 files changed

+2003
-98
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
.vscode
22
.idea
3+
.claude
34

45
# Binaries for programs and plugins
56
*.exe

.golangci.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ linters:
5353
# - wsl
5454

5555
linters-settings:
56+
depguard:
57+
rules:
58+
main:
59+
allow:
60+
- $gostd
61+
- github.com/databricks/databricks-sql-go
5662
gosec:
5763
exclude-generated: true
5864
severity: "low"

auth/tokenprovider/cached.go

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package tokenprovider
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync"
7+
"time"
8+
9+
"github.com/rs/zerolog/log"
10+
)
11+
12+
// CachedTokenProvider wraps another provider and caches tokens
13+
type CachedTokenProvider struct {
14+
provider TokenProvider
15+
cache *Token
16+
mutex sync.RWMutex
17+
refreshing bool // prevents thundering herd
18+
// RefreshThreshold determines when to refresh (default 5 minutes before expiry)
19+
RefreshThreshold time.Duration
20+
}
21+
22+
// NewCachedTokenProvider creates a caching wrapper around any token provider
23+
func NewCachedTokenProvider(provider TokenProvider) *CachedTokenProvider {
24+
return &CachedTokenProvider{
25+
provider: provider,
26+
RefreshThreshold: 5 * time.Minute,
27+
}
28+
}
29+
30+
// GetToken retrieves a token, using cache if available and valid
31+
func (p *CachedTokenProvider) GetToken(ctx context.Context) (*Token, error) {
32+
// Check if context is already cancelled
33+
if err := ctx.Err(); err != nil {
34+
return nil, fmt.Errorf("cached token provider: context cancelled: %w", err)
35+
}
36+
37+
// Try to get from cache first
38+
p.mutex.RLock()
39+
cached := p.cache
40+
needsRefresh := p.shouldRefresh(cached)
41+
isRefreshing := p.refreshing
42+
p.mutex.RUnlock()
43+
44+
// If cache is valid and not being refreshed, return a copy
45+
if cached != nil && !needsRefresh {
46+
log.Debug().Msgf("cached token provider: using cached token for provider %s", p.provider.Name())
47+
// Return a copy to avoid concurrent modification issues
48+
return copyToken(cached), nil
49+
}
50+
51+
// If another goroutine is already refreshing, wait briefly and retry
52+
if isRefreshing {
53+
time.Sleep(50 * time.Millisecond)
54+
p.mutex.RLock()
55+
cached = p.cache
56+
needsRefresh = p.shouldRefresh(cached)
57+
p.mutex.RUnlock()
58+
59+
if cached != nil && !needsRefresh {
60+
return copyToken(cached), nil
61+
}
62+
}
63+
64+
// Need to refresh - acquire write lock
65+
p.mutex.Lock()
66+
67+
// Double-check after acquiring write lock
68+
if p.cache != nil && !p.shouldRefresh(p.cache) {
69+
p.mutex.Unlock()
70+
return copyToken(p.cache), nil
71+
}
72+
73+
// Mark as refreshing to prevent other goroutines from also refreshing
74+
p.refreshing = true
75+
p.mutex.Unlock()
76+
77+
// Fetch new token WITHOUT holding the lock
78+
log.Debug().Msgf("cached token provider: fetching new token from provider %s", p.provider.Name())
79+
token, err := p.provider.GetToken(ctx)
80+
81+
// Update cache with result
82+
p.mutex.Lock()
83+
p.refreshing = false
84+
if err != nil {
85+
p.mutex.Unlock()
86+
return nil, fmt.Errorf("cached token provider: failed to get token: %w", err)
87+
}
88+
89+
p.cache = token
90+
p.mutex.Unlock()
91+
92+
return copyToken(token), nil
93+
}
94+
95+
// copyToken creates a copy of a token to avoid concurrent modification issues
96+
func copyToken(t *Token) *Token {
97+
if t == nil {
98+
return nil
99+
}
100+
101+
scopesCopy := make([]string, len(t.Scopes))
102+
copy(scopesCopy, t.Scopes)
103+
104+
return &Token{
105+
AccessToken: t.AccessToken,
106+
TokenType: t.TokenType,
107+
ExpiresAt: t.ExpiresAt,
108+
Scopes: scopesCopy,
109+
}
110+
}
111+
112+
// shouldRefresh determines if a token should be refreshed based on expiry time.
113+
// Returns true if:
114+
// - token is nil
115+
// - token has expired
116+
// - token will expire within RefreshThreshold (default 5 minutes)
117+
//
118+
// Returns false if:
119+
// - token has no expiry time (never expires)
120+
// - token is still valid and not close to expiry
121+
func (p *CachedTokenProvider) shouldRefresh(token *Token) bool {
122+
if token == nil {
123+
return true
124+
}
125+
126+
// If no expiry time, assume token doesn't expire
127+
if token.ExpiresAt.IsZero() {
128+
return false
129+
}
130+
131+
// Refresh if within threshold of expiry
132+
refreshAt := token.ExpiresAt.Add(-p.RefreshThreshold)
133+
return time.Now().After(refreshAt)
134+
}
135+
136+
// Name returns the provider name
137+
func (p *CachedTokenProvider) Name() string {
138+
return fmt.Sprintf("cached[%s]", p.provider.Name())
139+
}
140+
141+
// ClearCache clears the cached token
142+
func (p *CachedTokenProvider) ClearCache() {
143+
p.mutex.Lock()
144+
p.cache = nil
145+
p.mutex.Unlock()
146+
}

0 commit comments

Comments
 (0)