Skip to content

Commit 07e28b7

Browse files
committed
Consolidate serviceToServiceVisitor and serviceToServiceVisitorWithFallback into one function
1 parent 41ef7b4 commit 07e28b7

File tree

7 files changed

+17
-34
lines changed

7 files changed

+17
-34
lines changed

config/auth_azure_cli.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func (c AzureCliCredentials) getVisitor(ctx context.Context, cfg *Config, inner
5252
return azureVisitor(cfg, refreshableVisitor(inner, opts...)), nil
5353
}
5454
management := azureReuseTokenSource(t, ts, opts...)
55-
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken, opts...)), nil
55+
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken, false, opts...)), nil
5656
}
5757

5858
func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {

config/auth_azure_client_secret.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,6 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config
5757
opts := cacheOptions(cfg)
5858
inner := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, aadEndpoint, env.AzureApplicationID), opts...)
5959
management := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, aadEndpoint, managementEndpoint), opts...)
60-
visitor := azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken, opts...))
60+
visitor := azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken, false, opts...))
6161
return credentials.NewOAuthCredentialsProvider(visitor, inner.Token), nil
6262
}

config/auth_azure_msi.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (creden
4646
opts := cacheOptions(cfg)
4747
inner := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, "", env.AzureApplicationID), opts...)
4848
management := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, "", env.AzureServiceManagementEndpoint()), opts...)
49-
visitor := azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken, opts...))
49+
visitor := azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken, false, opts...))
5050
return credentials.NewOAuthCredentialsProvider(visitor, inner.Token), nil
5151
}
5252

config/auth_gcp_google_credentials.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func (c GoogleCredentials) Configure(ctx context.Context, cfg *Config) (credenti
4242
return nil, fmt.Errorf("could not obtain OAuth2 token from JSON: %w", err)
4343
}
4444
logger.Infof(ctx, "Using Google Credentials")
45-
visitor := serviceToServiceVisitorWithFallback(inner, creds.TokenSource, "X-Databricks-GCP-SA-Access-Token", cacheOptions(cfg)...)
45+
visitor := serviceToServiceVisitor(inner, creds.TokenSource, "X-Databricks-GCP-SA-Access-Token", true, cacheOptions(cfg)...)
4646
return credentials.NewOAuthCredentialsProvider(visitor, inner.Token), nil
4747
}
4848

config/auth_gcp_google_id.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (c
4444
return credentials.CredentialsProviderFn(visitor), nil
4545
}
4646
logger.Infof(ctx, "Using Google Default Application Credentials")
47-
visitor := serviceToServiceVisitorWithFallback(inner, platform, "X-Databricks-GCP-SA-Access-Token", opts...)
47+
visitor := serviceToServiceVisitor(inner, platform, "X-Databricks-GCP-SA-Access-Token", true, opts...)
4848
return credentials.NewOAuthCredentialsProvider(visitor, inner.Token), nil
4949
}
5050

config/oauth_visitors.go

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ func cacheOptions(cfg *Config) []auth.Option {
2020
}
2121

2222
// serviceToServiceVisitor returns a visitor that sets the Authorization header
23-
// to the token from the auth token source and the provided secondary header to
24-
// the token from the secondary token source.
25-
func serviceToServiceVisitor(primary, secondary oauth2.TokenSource, secondaryHeader string, opts ...auth.Option) func(r *http.Request) error {
23+
// to the token from the primary token source and the provided secondary header
24+
// to the token from the secondary token source. If secondaryOptional is true,
25+
// a failure to get the secondary token logs a warning and skips the header
26+
// instead of returning an error.
27+
func serviceToServiceVisitor(primary, secondary oauth2.TokenSource, secondaryHeader string, secondaryOptional bool, opts ...auth.Option) func(r *http.Request) error {
2628
refreshableAuth := auth.NewCachedTokenSource(authconv.AuthTokenSource(primary), opts...)
2729
refreshableSecondary := auth.NewCachedTokenSource(authconv.AuthTokenSource(secondary), opts...)
2830
return func(r *http.Request) error {
@@ -34,36 +36,17 @@ func serviceToServiceVisitor(primary, secondary oauth2.TokenSource, secondaryHea
3436

3537
cloud, err := refreshableSecondary.Token(context.Background())
3638
if err != nil {
39+
if secondaryOptional {
40+
logger.Warnf(r.Context(), "Failed to get secondary token for %s header: %v. Skipping.", secondaryHeader, err)
41+
return nil
42+
}
3743
return fmt.Errorf("cloud token: %w", err)
3844
}
3945
r.Header.Set(secondaryHeader, cloud.AccessToken)
4046
return nil
4147
}
4248
}
4349

44-
// serviceToServiceVisitorWithFallback is like serviceToServiceVisitor but
45-
// logs a warning and skips the secondary header when the secondary token
46-
// source fails, instead of returning an error.
47-
func serviceToServiceVisitorWithFallback(primary, secondary oauth2.TokenSource, secondaryHeader string, opts ...auth.Option) func(r *http.Request) error {
48-
refreshableAuth := auth.NewCachedTokenSource(authconv.AuthTokenSource(primary), opts...)
49-
refreshableSecondary := auth.NewCachedTokenSource(authconv.AuthTokenSource(secondary), opts...)
50-
return func(r *http.Request) error {
51-
inner, err := refreshableAuth.Token(context.Background())
52-
if err != nil {
53-
return fmt.Errorf("inner token: %w", err)
54-
}
55-
inner.SetAuthHeader(r)
56-
57-
cloud, err := refreshableSecondary.Token(context.Background())
58-
if err != nil {
59-
logger.Warnf(r.Context(), "Failed to get secondary token for %s header: %v. Skipping.", secondaryHeader, err)
60-
return nil
61-
}
62-
r.Header.Set(secondaryHeader, cloud.AccessToken)
63-
return nil
64-
}
65-
}
66-
6750
// The same as serviceToServiceVisitor, but without a secondary token source.
6851
func refreshableVisitor(inner oauth2.TokenSource, opts ...auth.Option) func(r *http.Request) error {
6952
return refreshableAuthVisitor(authconv.AuthTokenSource(inner), opts...)

config/oauth_visitors_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func (s *staticTokenSource) Token() (*oauth2.Token, error) {
4747
func TestServiceToServiceVisitorWithFallback_BothSucceed(t *testing.T) {
4848
primary := &staticTokenSource{token: &oauth2.Token{AccessToken: "primary-token"}}
4949
secondary := &staticTokenSource{token: &oauth2.Token{AccessToken: "secondary-token"}}
50-
visitor := serviceToServiceVisitorWithFallback(primary, secondary, "X-Secondary")
50+
visitor := serviceToServiceVisitor(primary, secondary, "X-Secondary", true)
5151

5252
req, err := http.NewRequest("GET", "https://example.com", nil)
5353
require.NoError(t, err)
@@ -60,7 +60,7 @@ func TestServiceToServiceVisitorWithFallback_BothSucceed(t *testing.T) {
6060
func TestServiceToServiceVisitorWithFallback_SecondaryFails_SkipsHeader(t *testing.T) {
6161
primary := &staticTokenSource{token: &oauth2.Token{AccessToken: "primary-token"}}
6262
secondary := &staticTokenSource{err: fmt.Errorf("secondary failed")}
63-
visitor := serviceToServiceVisitorWithFallback(primary, secondary, "X-Secondary")
63+
visitor := serviceToServiceVisitor(primary, secondary, "X-Secondary", true)
6464

6565
req, err := http.NewRequest("GET", "https://example.com", nil)
6666
require.NoError(t, err)
@@ -73,7 +73,7 @@ func TestServiceToServiceVisitorWithFallback_SecondaryFails_SkipsHeader(t *testi
7373
func TestServiceToServiceVisitorWithFallback_PrimaryFails_ReturnsError(t *testing.T) {
7474
primary := &staticTokenSource{err: fmt.Errorf("primary failed")}
7575
secondary := &staticTokenSource{token: &oauth2.Token{AccessToken: "secondary-token"}}
76-
visitor := serviceToServiceVisitorWithFallback(primary, secondary, "X-Secondary")
76+
visitor := serviceToServiceVisitor(primary, secondary, "X-Secondary", true)
7777

7878
req, err := http.NewRequest("GET", "https://example.com", nil)
7979
require.NoError(t, err)

0 commit comments

Comments
 (0)