Skip to content

Commit 4dd4f3b

Browse files
author
Chris Grierson
committed
fix: add timeout to requests.post()/get() calls in oauth.py
requests.post() and requests.get() calls in oauth.py's retrieve_token(), get_azure_entra_id_workspace_endpoints(), and PATOAuthTokenExchange.refresh() do not pass a timeout= parameter. When the OAuth endpoint is unreachable or slow, these calls block indefinitely. The SDK's per-request timeout (session.request(timeout=60)) does not protect against this because the token refresh runs inside session.auth, before the timeout takes effect.

Add an http_timeout_seconds field to ClientCredentials and PATOAuthTokenExchange dataclasses (default 60, matching _BaseClient), and a timeout parameter to retrieve_token() and get_azure_entra_id_workspace_endpoints(). All call sites in credentials_provider.py and config.py now pass cfg.http_timeout_seconds so the timeout is user-configurable via Config.

Fixes #1338

Signed-off-by: Chris Grierson <[email protected]>
1 parent 6c1da5b commit 4dd4f3b

File tree

3 files changed

+29
-16
lines changed

3 files changed

+29
-16
lines changed

databricks/sdk/config.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
from ._base_client import _fix_host_if_needed
1616
from .client_types import ClientType, HostType
1717
from .clock import Clock, RealClock
18-
from .credentials_provider import (CredentialsStrategy, DefaultCredentials,
19-
OAuthCredentialsProvider)
20-
from .environments import (ALL_ENVS, AzureEnvironment, Cloud,
21-
DatabricksEnvironment, get_environment_for_hostname)
22-
from .oauth import (OidcEndpoints, Token,
23-
get_azure_entra_id_workspace_endpoints,
24-
get_endpoints_from_url, get_host_metadata)
18+
from .credentials_provider import CredentialsStrategy, DefaultCredentials, OAuthCredentialsProvider
19+
from .environments import ALL_ENVS, AzureEnvironment, Cloud, DatabricksEnvironment, get_environment_for_hostname
20+
from .oauth import (
21+
OidcEndpoints,
22+
Token,
23+
get_azure_entra_id_workspace_endpoints,
24+
get_endpoints_from_url,
25+
get_host_metadata,
26+
)
2527

2628
logger = logging.getLogger("databricks.sdk")
2729

@@ -546,7 +548,7 @@ def oidc_endpoints(self) -> Optional[OidcEndpoints]:
546548
if not self.host:
547549
return None
548550
if self.is_azure and self.azure_client_id:
549-
return get_azure_entra_id_workspace_endpoints(self.host)
551+
return get_azure_entra_id_workspace_endpoints(self.host, timeout=self.http_timeout_seconds or 60)
550552
return self.databricks_oidc_endpoints
551553

552554
def debug_string(self) -> str:

databricks/sdk/credentials_provider.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,7 @@ def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]:
158158
# This import MUST be after the "DATABRICKS_RUNTIME_VERSION" check
159159
# above, so that we are not throwing import errors when not in
160160
# runtime and no config variables are set.
161-
from databricks.sdk.runtime import (init_runtime_legacy_auth,
162-
init_runtime_native_auth,
163-
init_runtime_repl_auth)
161+
from databricks.sdk.runtime import init_runtime_legacy_auth, init_runtime_native_auth, init_runtime_repl_auth
164162

165163
for init in [
166164
init_runtime_native_auth,
@@ -203,6 +201,7 @@ def get_notebook_pat_token() -> Optional[str]:
203201
host=cfg.host,
204202
scopes=cfg.get_scopes_as_string(),
205203
authorization_details=cfg.authorization_details,
204+
http_timeout_seconds=cfg.http_timeout_seconds or 60,
206205
)
207206

208207
def inner() -> Dict[str, str]:
@@ -232,6 +231,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
232231
use_header=True,
233232
disable_async=cfg.disable_async_token_refresh,
234233
authorization_details=cfg.authorization_details,
234+
http_timeout_seconds=cfg.http_timeout_seconds or 60,
235235
)
236236

237237
def inner() -> Dict[str, str]:
@@ -258,7 +258,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
258258
elif cfg.azure_client_id:
259259
client_id = cfg.azure_client_id
260260
client_secret = cfg.azure_client_secret
261-
oidc_endpoints = get_azure_entra_id_workspace_endpoints(cfg.host)
261+
oidc_endpoints = get_azure_entra_id_workspace_endpoints(cfg.host, timeout=cfg.http_timeout_seconds or 60)
262262
if not client_id:
263263
client_id = "databricks-cli"
264264
oidc_endpoints = cfg.databricks_oidc_endpoints
@@ -348,6 +348,7 @@ def token_source_for(resource: str) -> oauth.TokenSource:
348348
disable_async=cfg.disable_async_token_refresh,
349349
scopes=cfg.get_scopes_as_string(),
350350
authorization_details=cfg.authorization_details,
351+
http_timeout_seconds=cfg.http_timeout_seconds or 60,
351352
)
352353

353354
_ensure_host_present(cfg, token_source_for)
@@ -470,6 +471,7 @@ def token_source_for(audience: str) -> oauth.TokenSource:
470471
use_params=True,
471472
disable_async=cfg.disable_async_token_refresh,
472473
authorization_details=cfg.authorization_details,
474+
http_timeout_seconds=cfg.http_timeout_seconds or 60,
473475
)
474476

475477
def refreshed_headers() -> Dict[str, str]:
@@ -532,7 +534,9 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
532534
aad_endpoint = cfg.arm_environment.active_directory_endpoint
533535
if not cfg.azure_tenant_id:
534536
# detect Azure AD Tenant ID if it's not specified directly
535-
token_endpoint = get_azure_entra_id_workspace_endpoints(cfg.host).token_endpoint
537+
token_endpoint = get_azure_entra_id_workspace_endpoints(
538+
cfg.host, timeout=cfg.http_timeout_seconds or 60
539+
).token_endpoint
536540
cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, "").split("/")[0]
537541

538542
inner = oauth.ClientCredentials(
@@ -548,6 +552,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
548552
disable_async=cfg.disable_async_token_refresh,
549553
scopes=cfg.get_scopes_as_string(),
550554
authorization_details=cfg.authorization_details,
555+
http_timeout_seconds=cfg.http_timeout_seconds or 60,
551556
)
552557

553558
def refreshed_headers() -> Dict[str, str]:

databricks/sdk/oauth.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def retrieve_token(
194194
use_params=False,
195195
use_header=False,
196196
headers=None,
197+
timeout=60,
197198
) -> Token:
198199
logger.debug(f"Retrieving token for {client_id}")
199200
if use_params:
@@ -206,7 +207,7 @@ def retrieve_token(
206207
auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
207208
else:
208209
auth = IgnoreNetrcAuth()
209-
resp = requests.post(token_url, params, auth=auth, headers=headers)
210+
resp = requests.post(token_url, params, auth=auth, headers=headers, timeout=timeout)
210211
if not resp.ok:
211212
if resp.headers["Content-Type"].startswith("application/json"):
212213
err = resp.json()
@@ -513,16 +514,18 @@ def get_unified_endpoints(host: str, account_id: str, client: _BaseClient = _Bas
513514

514515
def get_azure_entra_id_workspace_endpoints(
515516
host: str,
517+
timeout: int = 60,
516518
) -> Optional[OidcEndpoints]:
517519
"""
518520
Get the Azure Entra ID endpoints for a given workspace. Can only be used when authenticating to Azure Databricks
519521
using an application registered in Azure Entra ID.
520522
:param host: The Databricks workspace host.
523+
:param timeout: HTTP request timeout in seconds.
521524
:return: The OIDC endpoints for the workspace's Azure Entra ID tenant.
522525
"""
523526
# In Azure, this workspace endpoint redirects to the Entra ID authorization endpoint
524527
host = _fix_host_if_needed(host)
525-
res = requests.get(f"{host}/oidc/oauth2/v2.0/authorize", allow_redirects=False)
528+
res = requests.get(f"{host}/oidc/oauth2/v2.0/authorize", allow_redirects=False, timeout=timeout)
526529
real_auth_url = res.headers.get("location")
527530
if not real_auth_url:
528531
return None
@@ -828,6 +831,7 @@ class ClientCredentials(Refreshable):
828831
use_header: bool = False
829832
disable_async: bool = True
830833
authorization_details: str = None
834+
http_timeout_seconds: int = 60
831835

832836
def __post_init__(self):
833837
super().__init__(disable_async=self.disable_async)
@@ -848,6 +852,7 @@ def refresh(self) -> Token:
848852
params,
849853
use_params=self.use_params,
850854
use_header=self.use_header,
855+
timeout=self.http_timeout_seconds,
851856
)
852857

853858

@@ -874,6 +879,7 @@ class PATOAuthTokenExchange(Refreshable):
874879
scopes: str
875880
authorization_details: str = None
876881
disable_async: bool = True
882+
http_timeout_seconds: int = 60
877883

878884
def __post_init__(self):
879885
super().__init__(disable_async=self.disable_async)
@@ -890,7 +896,7 @@ def refresh(self) -> Token:
890896
if self.authorization_details:
891897
params["authorization_details"] = self.authorization_details
892898

893-
resp = requests.post(token_exchange_url, params)
899+
resp = requests.post(token_exchange_url, params, timeout=self.http_timeout_seconds)
894900
if not resp.ok:
895901
if resp.headers["Content-Type"].startswith("application/json"):
896902
err = resp.json()

0 commit comments

Comments
 (0)