Skip to content

Commit 6297a98

Browse files
authored
feat(spark): Add multi-tenant Databricks token support via cross-namespace K8s secrets (#3394)
Enable per-project Databricks authentication by reading tokens from Kubernetes secrets in workflow namespaces, with backward-compatible fallback to the FLYTE_DATABRICKS_ACCESS_TOKEN environment variable. Changes: - Add get_secret_from_k8s() for cross-namespace K8s secret reading - Add get_databricks_token() with tiered resolution (K8s -> env var) - Update DatabricksJobMetadata to persist auth_token across lifecycle - Update DatabricksConnector.create/get/delete to use per-project tokens - Add DatabricksV2.databricks_token_secret for custom secret names - Add 31 comprehensive tests covering all token resolution paths Tracking: flyteorg/flyte#6911 Signed-off-by: Rohit Sharma <[email protected]>
1 parent a735a62 commit 6297a98

File tree

4 files changed

+888
-58
lines changed

4 files changed

+888
-58
lines changed

plugins/flytekit-spark/flytekitplugins/spark/connector.py

Lines changed: 169 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import http
22
import json
3+
import logging
34
import os
45
import typing
56
from dataclasses import dataclass
@@ -13,12 +14,14 @@
1314
from flytekit.extend.backend.utils import convert_to_flyte_phase, get_connector_secret
1415
from flytekit.models.core.execution import TaskLog
1516
from flytekit.models.literals import LiteralMap
16-
from flytekit.models.task import TaskTemplate
17+
from flytekit.models.task import TaskExecutionMetadata, TaskTemplate
1718

1819
from .utils import is_serverless_config as _is_serverless_config
1920

2021
aiohttp = lazy_module("aiohttp")
2122

23+
logger = logging.getLogger(__name__)
24+
2225
DATABRICKS_API_ENDPOINT = "/api/2.1/jobs"
2326
DEFAULT_DATABRICKS_INSTANCE_ENV_KEY = "FLYTE_DATABRICKS_INSTANCE"
2427
DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY = "FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER"
@@ -28,6 +31,7 @@
2831
class DatabricksJobMetadata(ResourceMeta):
2932
databricks_instance: str
3033
run_id: str
34+
auth_token: Optional[str] = None # Store auth token for get/delete operations
3135

3236

3337
def _configure_serverless(databricks_job: dict, envs: dict) -> str:
@@ -252,7 +256,11 @@ def __init__(self):
252256
super().__init__(task_type_name="spark", metadata_type=DatabricksJobMetadata)
253257

254258
async def create(
255-
self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs
259+
self,
260+
task_template: TaskTemplate,
261+
inputs: Optional[LiteralMap] = None,
262+
task_execution_metadata: Optional[TaskExecutionMetadata] = None,
263+
**kwargs,
256264
) -> DatabricksJobMetadata:
257265
data = json.dumps(_get_databricks_job_spec(task_template))
258266
databricks_instance = task_template.custom.get(
@@ -264,24 +272,43 @@ async def create(
264272
f"Missing databricks instance. Please set the value through the task config or set the {DEFAULT_DATABRICKS_INSTANCE_ENV_KEY} environment variable in the connector."
265273
)
266274

275+
# Get workflow-specific token or fall back to default
276+
namespace = task_execution_metadata.namespace if task_execution_metadata else None
277+
278+
# Extract custom secret name from task template (if provided)
279+
custom_secret_name = task_template.custom.get("databricksTokenSecret")
280+
281+
logger.info(f"Creating Databricks job for namespace: {namespace or 'unknown'}")
282+
if custom_secret_name:
283+
logger.info(f"Using custom secret name: {custom_secret_name}")
284+
285+
auth_token = get_databricks_token(
286+
namespace=namespace, task_template=task_template, secret_name=custom_secret_name
287+
)
267288
databricks_url = f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/submit"
268289

269290
async with aiohttp.ClientSession() as session:
270-
async with session.post(databricks_url, headers=get_header(), data=data) as resp:
291+
async with session.post(databricks_url, headers=get_header(auth_token=auth_token), data=data) as resp:
271292
response = await resp.json()
272293
if resp.status != http.HTTPStatus.OK:
273294
raise RuntimeError(f"Failed to create databricks job with error: {response}")
274295

275-
return DatabricksJobMetadata(databricks_instance=databricks_instance, run_id=str(response["run_id"]))
296+
logger.info(f"Successfully created Databricks job with run_id: {response['run_id']}")
297+
return DatabricksJobMetadata(
298+
databricks_instance=databricks_instance, run_id=str(response["run_id"]), auth_token=auth_token
299+
)
276300

277301
async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource:
278302
databricks_instance = resource_meta.databricks_instance
279303
databricks_url = (
280304
f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/get?run_id={resource_meta.run_id}"
281305
)
282306

307+
# Use the stored auth token if available, otherwise fall back to default
308+
headers = get_header(auth_token=resource_meta.auth_token)
309+
283310
async with aiohttp.ClientSession() as session:
284-
async with session.get(databricks_url, headers=get_header()) as resp:
311+
async with session.get(databricks_url, headers=headers) as resp:
285312
if resp.status != http.HTTPStatus.OK:
286313
raise RuntimeError(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}")
287314
response = await resp.json()
@@ -312,8 +339,11 @@ async def delete(self, resource_meta: DatabricksJobMetadata, **kwargs):
312339
databricks_url = f"https://{resource_meta.databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/cancel"
313340
data = json.dumps({"run_id": resource_meta.run_id})
314341

342+
# Use the stored auth token if available, otherwise fall back to default
343+
headers = get_header(auth_token=resource_meta.auth_token)
344+
315345
async with aiohttp.ClientSession() as session:
316-
async with session.post(databricks_url, headers=get_header(), data=data) as resp:
346+
async with session.post(databricks_url, headers=headers, data=data) as resp:
317347
if resp.status != http.HTTPStatus.OK:
318348
raise RuntimeError(
319349
f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}"
@@ -334,9 +364,139 @@ def __init__(self):
334364
super(DatabricksConnector, self).__init__(task_type_name="databricks", metadata_type=DatabricksJobMetadata)
335365

336366

337-
def get_header() -> typing.Dict[str, str]:
338-
token = get_connector_secret("FLYTE_DATABRICKS_ACCESS_TOKEN")
339-
return {"Authorization": f"Bearer {token}", "content-type": "application/json"}
367+
def get_secret_from_k8s(secret_name: str, secret_key: str, namespace: str) -> Optional[str]:
    """Fetch one key of a Kubernetes secret from the given namespace.

    Args:
        secret_name (str): Name of the Kubernetes secret (e.g., "databricks-token").
        secret_key (str): Key within the secret (e.g., "token").
        namespace (str): Kubernetes namespace where the secret is stored.

    Returns:
        Optional[str]: The decoded secret value, or None on any failure
        (missing package, unloadable config, missing secret/key, API error).
    """
    try:
        import base64

        from kubernetes import client, config

        # Prefer in-cluster credentials (when running inside K8s); otherwise
        # fall back to the local kubeconfig (for local testing).
        try:
            config.load_incluster_config()
        except config.ConfigException:
            try:
                config.load_kube_config()
            except Exception as e:
                logger.warning(f"Failed to load Kubernetes config: {e}")
                return None

        api = client.CoreV1Api()

        try:
            secret = api.read_namespaced_secret(name=secret_name, namespace=namespace)
        except client.exceptions.ApiException as e:
            if e.status == 404:
                logger.debug(f"Secret '{secret_name}' not found in namespace '{namespace}'")
            else:
                logger.warning(f"Error reading secret '{secret_name}' from namespace '{namespace}': {e}")
            return None

        encoded = (secret.data or {}).get(secret_key)
        if encoded is None:
            logger.debug(
                f"Secret '{secret_name}' exists but key '{secret_key}' not found in namespace '{namespace}'"
            )
            return None

        # Kubernetes secret values are base64 encoded.
        return base64.b64decode(encoded).decode("utf-8")

    except ImportError:
        logger.warning("kubernetes Python package not installed - cannot read namespace secrets")
        return None
    except Exception as e:
        logger.warning(f"Unexpected error reading K8s secret: {e}")
        return None
422+
def get_databricks_token(
    namespace: Optional[str] = None,
    task_template: Optional["TaskTemplate"] = None,
    secret_name: Optional[str] = None,
) -> str:
    """Get the Databricks access token with multi-tenant support.

    Token resolution: namespace K8s secret -> FLYTE_DATABRICKS_ACCESS_TOKEN env var.

    Args:
        namespace (Optional[str]): Kubernetes namespace for workflow-specific token lookup.
        task_template (Optional[TaskTemplate]): Optional TaskTemplate (kept for API compatibility;
            not consulted by the current resolution logic).
        secret_name (Optional[str]): Custom secret name. Defaults to 'databricks-token'.

    Returns:
        str: The Databricks access token.

    Raises:
        ValueError: If no non-empty token is found from any source.
    """
    token: Optional[str] = None
    token_source = "unknown"

    # Use custom secret name or default to 'databricks-token'
    k8s_secret_name = secret_name or "databricks-token"

    # Defensive guard: a caller that passes task_template positionally binds it to
    # `namespace`; treat any non-string value as "no namespace" instead of handing
    # an arbitrary object to the K8s client.
    if namespace is not None and not isinstance(namespace, str):
        logger.warning(f"Ignoring non-string namespace argument of type {type(namespace).__name__}")
        namespace = None

    # Step 1: Try namespace-specific K8s secret (cross-namespace lookup)
    if namespace:
        logger.info(f"Looking for Databricks token in workflow namespace: {namespace} (secret: {k8s_secret_name})")
        token = get_secret_from_k8s(secret_name=k8s_secret_name, secret_key="token", namespace=namespace)

        if token:
            logger.info(f"Found Databricks token in namespace '{namespace}' from secret '{k8s_secret_name}'")
            token_source = f"k8s_namespace:{namespace}/secret:{k8s_secret_name}"
        else:
            logger.info(
                f"Databricks token not found in secret '{k8s_secret_name}' in namespace '{namespace}' - trying fallback"
            )
            # Normalize an empty-string secret value to None so the env-var
            # fallback below still runs (previously it raised "token is empty").
            token = None

    else:
        logger.info("No namespace provided for cross-namespace lookup")

    # Step 2: Fall back to environment variable (backward compatibility)
    if token is None:
        logger.info("Falling back to default Databricks token (FLYTE_DATABRICKS_ACCESS_TOKEN)")
        try:
            token = get_connector_secret("FLYTE_DATABRICKS_ACCESS_TOKEN")
            token_source = "env_variable"
        except Exception as e:
            logger.error(f"Failed to get default Databricks token: {e}")
            # Chain the original cause so the lookup failure is not lost.
            raise ValueError(
                "No Databricks token found from any source:\n"
                f"1. Namespace-specific K8s secret '{k8s_secret_name}'\n"
                "2. FLYTE_DATABRICKS_ACCESS_TOKEN environment variable\n"
                f"Workflow namespace: {namespace or 'N/A'}"
            ) from e

    if not token:
        raise ValueError("Databricks token is empty")

    # Security: never log token material (the previous 8-char preview leaked part
    # of the credential); record only the source and length.
    logger.info(f"Using Databricks token from: {token_source} (length: {len(token)})")

    return token
486+
def get_header(
    task_template: Optional["TaskTemplate"] = None, auth_token: Optional[str] = None
) -> typing.Dict[str, str]:
    """Get the authorization header for Databricks API calls.

    Args:
        task_template (Optional[TaskTemplate]): TaskTemplate forwarded to token
            resolution when no pre-fetched token is supplied.
        auth_token (Optional[str]): Pre-fetched auth token to use directly;
            when provided, no token resolution is performed.

    Returns:
        typing.Dict[str, str]: Authorization and content-type headers.
    """
    if auth_token is None:
        # Bug fix: the template was previously passed positionally, binding it to
        # get_databricks_token's first parameter (namespace). Pass it by keyword.
        auth_token = get_databricks_token(task_template=task_template)

    return {"Authorization": f"Bearer {auth_token}", "content-type": "application/json"}
340500

341501

342502
def result_state_is_available(life_cycle_state: str) -> bool:

plugins/flytekit-spark/flytekitplugins/spark/task.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ class DatabricksV2(Spark):
8787
Use the form <account>.cloud.databricks.com.
8888
databricks_service_credential_provider (Optional[str]): Provider name for Databricks
8989
Service Credentials for S3 access. Falls back to FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER env var.
90+
databricks_token_secret (Optional[str]): Custom name for the K8s secret containing
91+
the Databricks token. Defaults to 'databricks-token' if not specified.
9092
notebook_path (Optional[str]): Path to Databricks notebook
9193
(e.g., "/Users/[email protected]/notebook").
9294
notebook_base_parameters (Optional[Dict[str, str]]): Parameters to pass to the notebook.
@@ -194,12 +196,11 @@ class DatabricksV2(Spark):
194196
"""
195197

196198
databricks_conf: Optional[Dict[str, Union[str, dict]]] = None
197-
databricks_instance: Optional[str] = None # Falls back to FLYTE_DATABRICKS_INSTANCE env var
198-
databricks_service_credential_provider: Optional[str] = (
199-
None # Falls back to FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER env var
200-
)
201-
notebook_path: Optional[str] = None # Path to Databricks notebook (e.g., "/Users/[email protected]/notebook")
202-
notebook_base_parameters: Optional[Dict[str, str]] = None # Parameters to pass to the notebook
199+
databricks_instance: Optional[str] = None
200+
databricks_service_credential_provider: Optional[str] = None
201+
databricks_token_secret: Optional[str] = None
202+
notebook_path: Optional[str] = None
203+
notebook_base_parameters: Optional[Dict[str, str]] = None
203204

204205

205206
# This method does not reset the SparkSession since it's a bit hard to handle multiple
@@ -311,6 +312,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
311312
cfg = cast(DatabricksV2, self.task_config)
312313
if cfg.databricks_service_credential_provider:
313314
custom_dict["databricksServiceCredentialProvider"] = cfg.databricks_service_credential_provider
315+
if cfg.databricks_token_secret:
316+
custom_dict["databricksTokenSecret"] = cfg.databricks_token_secret
314317
if cfg.notebook_path:
315318
custom_dict["notebookPath"] = cfg.notebook_path
316319
if cfg.notebook_base_parameters:
@@ -479,7 +482,7 @@ def execute(self, **kwargs) -> Any:
479482
if ctx.execution_state and ctx.execution_state.is_local_execution():
480483
return AsyncConnectorExecutorMixin.execute(self, **kwargs)
481484
except Exception as e:
482-
click.secho(f"Connector failed to run the task with error: {e}", fg="red")
485+
click.secho(f"Connector failed to run the task with error: {e}", fg="red")
483486
click.secho("Falling back to local execution", fg="red")
484487
return PythonFunctionTask.execute(self, **kwargs)
485488

0 commit comments

Comments
 (0)