@@ -1,5 +1,6 @@
 import http
 import json
+import logging
 import os
 import typing
 from dataclasses import dataclass
@@ -13,12 +14,14 @@
 from flytekit.extend.backend.utils import convert_to_flyte_phase, get_connector_secret
 from flytekit.models.core.execution import TaskLog
 from flytekit.models.literals import LiteralMap
-from flytekit.models.task import TaskTemplate
+from flytekit.models.task import TaskExecutionMetadata, TaskTemplate

 from .utils import is_serverless_config as _is_serverless_config

 aiohttp = lazy_module("aiohttp")

+logger = logging.getLogger(__name__)
+
 DATABRICKS_API_ENDPOINT = "/api/2.1/jobs"
 DEFAULT_DATABRICKS_INSTANCE_ENV_KEY = "FLYTE_DATABRICKS_INSTANCE"
 DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY = "FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER"
@@ -28,6 +31,7 @@
 class DatabricksJobMetadata(ResourceMeta):
     databricks_instance: str
     run_id: str
+    auth_token: Optional[str] = None  # Store auth token for get/delete operations


 def _configure_serverless(databricks_job: dict, envs: dict) -> str:
@@ -252,7 +256,11 @@ def __init__(self):
         super().__init__(task_type_name="spark", metadata_type=DatabricksJobMetadata)

     async def create(
-        self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs
+        self,
+        task_template: TaskTemplate,
+        inputs: Optional[LiteralMap] = None,
+        task_execution_metadata: Optional[TaskExecutionMetadata] = None,
+        **kwargs,
     ) -> DatabricksJobMetadata:
         data = json.dumps(_get_databricks_job_spec(task_template))
         databricks_instance = task_template.custom.get(
@@ -264,24 +272,43 @@ async def create(
264272 f"Missing databricks instance. Please set the value through the task config or set the { DEFAULT_DATABRICKS_INSTANCE_ENV_KEY } environment variable in the connector."
265273 )
266274
275+ # Get workflow-specific token or fall back to default
276+ namespace = task_execution_metadata .namespace if task_execution_metadata else None
277+
278+ # Extract custom secret name from task template (if provided)
279+ custom_secret_name = task_template .custom .get ("databricksTokenSecret" )
280+
281+ logger .info (f"Creating Databricks job for namespace: { namespace or 'unknown' } " )
282+ if custom_secret_name :
283+ logger .info (f"Using custom secret name: { custom_secret_name } " )
284+
+        auth_token = get_databricks_token(
+            namespace=namespace, task_template=task_template, secret_name=custom_secret_name
+        )
         databricks_url = f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/submit"

         async with aiohttp.ClientSession() as session:
-            async with session.post(databricks_url, headers=get_header(), data=data) as resp:
+            async with session.post(databricks_url, headers=get_header(auth_token=auth_token), data=data) as resp:
                 response = await resp.json()
                 if resp.status != http.HTTPStatus.OK:
                     raise RuntimeError(f"Failed to create databricks job with error: {response}")

-        return DatabricksJobMetadata(databricks_instance=databricks_instance, run_id=str(response["run_id"]))
+        logger.info(f"Successfully created Databricks job with run_id: {response['run_id']}")
+        return DatabricksJobMetadata(
+            databricks_instance=databricks_instance, run_id=str(response["run_id"]), auth_token=auth_token
+        )

     async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource:
         databricks_instance = resource_meta.databricks_instance
         databricks_url = (
             f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/get?run_id={resource_meta.run_id}"
         )

+        # Use the stored auth token if available, otherwise fall back to default
+        headers = get_header(auth_token=resource_meta.auth_token)
+
         async with aiohttp.ClientSession() as session:
-            async with session.get(databricks_url, headers=get_header()) as resp:
+            async with session.get(databricks_url, headers=headers) as resp:
                 if resp.status != http.HTTPStatus.OK:
                     raise RuntimeError(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}")
                 response = await resp.json()
@@ -312,8 +339,11 @@ async def delete(self, resource_meta: DatabricksJobMetadata, **kwargs):
         databricks_url = f"https://{resource_meta.databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/cancel"
         data = json.dumps({"run_id": resource_meta.run_id})

+        # Use the stored auth token if available, otherwise fall back to default
+        headers = get_header(auth_token=resource_meta.auth_token)
+
         async with aiohttp.ClientSession() as session:
-            async with session.post(databricks_url, headers=get_header(), data=data) as resp:
+            async with session.post(databricks_url, headers=headers, data=data) as resp:
                 if resp.status != http.HTTPStatus.OK:
                     raise RuntimeError(
                         f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}"
@@ -334,9 +364,139 @@ def __init__(self):
         super(DatabricksConnector, self).__init__(task_type_name="databricks", metadata_type=DatabricksJobMetadata)


-def get_header() -> typing.Dict[str, str]:
-    token = get_connector_secret("FLYTE_DATABRICKS_ACCESS_TOKEN")
-    return {"Authorization": f"Bearer {token}", "content-type": "application/json"}
+def get_secret_from_k8s(secret_name: str, secret_key: str, namespace: str) -> Optional[str]:
+    """Read a secret from Kubernetes using the Kubernetes Python client.
+
+    Args:
+        secret_name (str): Name of the Kubernetes secret (e.g., "databricks-token").
+        secret_key (str): Key within the secret (e.g., "token").
+        namespace (str): Kubernetes namespace where the secret is stored.
+
+    Returns:
+        Optional[str]: The secret value as a string, or None if not found.
+    """
+    try:
+        import base64
+
+        from kubernetes import client, config
+
+        # Try to load in-cluster config first (when running in K8s)
+        try:
+            config.load_incluster_config()
+        except config.ConfigException:
+            # Fall back to kubeconfig (for local testing)
+            try:
+                config.load_kube_config()
+            except Exception as e:
+                logger.warning(f"Failed to load Kubernetes config: {e}")
+                return None
+
+        v1 = client.CoreV1Api()
+
+        try:
+            secret = v1.read_namespaced_secret(name=secret_name, namespace=namespace)
+            if secret.data and secret_key in secret.data:
+                # Kubernetes secrets are base64 encoded
+                secret_value = base64.b64decode(secret.data[secret_key]).decode("utf-8")
+                return secret_value
+            else:
+                logger.debug(
+                    f"Secret '{secret_name}' exists but key '{secret_key}' not found in namespace '{namespace}'"
+                )
+                return None
+        except client.exceptions.ApiException as e:
+            if e.status == 404:
+                logger.debug(f"Secret '{secret_name}' not found in namespace '{namespace}'")
+            else:
+                logger.warning(f"Error reading secret '{secret_name}' from namespace '{namespace}': {e}")
+            return None
+
+    except ImportError:
+        logger.warning("kubernetes Python package not installed - cannot read namespace secrets")
+        return None
+    except Exception as e:
+        logger.warning(f"Unexpected error reading K8s secret: {e}")
+        return None
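+
+# Illustrative note (not part of the connector code path): a per-namespace token
+# secret that the lookup above would find can be created with, e.g.:
+#   kubectl create secret generic databricks-token --from-literal=token=<PAT> -n <namespace>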
+
+
+def get_databricks_token(
+    namespace: Optional[str] = None, task_template: Optional[TaskTemplate] = None, secret_name: Optional[str] = None
+) -> str:
+    """Get the Databricks access token with multi-tenant support.
+
+    Token resolution: namespace K8s secret -> FLYTE_DATABRICKS_ACCESS_TOKEN env var.
+
+    Args:
+        namespace (Optional[str]): Kubernetes namespace for workflow-specific token lookup.
+        task_template (Optional[TaskTemplate]): Optional TaskTemplate (kept for API compatibility).
+        secret_name (Optional[str]): Custom secret name. Defaults to 'databricks-token'.
+
+    Returns:
+        str: The Databricks access token.
+
+    Raises:
+        ValueError: If no token is found from any source.
+    """
+    token = None
+    token_source = "unknown"
+
+    # Use custom secret name or default to 'databricks-token'
+    k8s_secret_name = secret_name or "databricks-token"
+
+    # Step 1: Try namespace-specific K8s secret (cross-namespace lookup)
+    if namespace:
+        logger.info(f"Looking for Databricks token in workflow namespace: {namespace} (secret: {k8s_secret_name})")
+        token = get_secret_from_k8s(secret_name=k8s_secret_name, secret_key="token", namespace=namespace)
+
+        if token:
+            logger.info(f"Found Databricks token in namespace '{namespace}' from secret '{k8s_secret_name}'")
+            token_source = f"k8s_namespace:{namespace}/secret:{k8s_secret_name}"
+        else:
+            logger.info(
+                f"Databricks token not found in secret '{k8s_secret_name}' in namespace '{namespace}' - trying fallback"
+            )
+    else:
+        logger.info("No namespace provided for cross-namespace lookup")
+
+    # Step 2: Fall back to environment variable (backward compatibility)
+    if token is None:
+        logger.info("Falling back to default Databricks token (FLYTE_DATABRICKS_ACCESS_TOKEN)")
+        try:
+            token = get_connector_secret("FLYTE_DATABRICKS_ACCESS_TOKEN")
+            token_source = "env_variable"
+        except Exception as e:
+            logger.error(f"Failed to get default Databricks token: {e}")
+            raise ValueError(
+                "No Databricks token found from any source:\n"
+                f"1. Namespace-specific K8s secret '{k8s_secret_name}'\n"
+                "2. FLYTE_DATABRICKS_ACCESS_TOKEN environment variable\n"
+                f"Workflow namespace: {namespace or 'N/A'}"
+            )
+
+    if not token:
+        raise ValueError("Databricks token is empty")
+
+    # Log token info without exposing the actual token value
+    token_preview = f"{token[:8]}..." if len(token) > 8 else "***"
+    logger.info(f"Using Databricks token from: {token_source} (preview: {token_preview})")
+
+    return token
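+
+# Hedged usage sketch (hypothetical values): get_databricks_token(namespace="team-a",
+# secret_name="team-a-token") reads K8s secret team-a/team-a-token (key "token");
+# with no namespace match it falls back to the FLYTE_DATABRICKS_ACCESS_TOKEN connector secret.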
+
+
+def get_header(task_template: Optional[TaskTemplate] = None, auth_token: Optional[str] = None) -> typing.Dict[str, str]:
+    """Get the authorization header for Databricks API calls.
+
+    Args:
+        task_template (Optional[TaskTemplate]): TaskTemplate with workflow-specific secret requests.
+        auth_token (Optional[str]): Pre-fetched auth token to use directly.
+
+    Returns:
+        typing.Dict[str, str]: Authorization and content-type headers.
+    """
+    if auth_token is None:
+        auth_token = get_databricks_token(task_template=task_template)
+
+    return {"Authorization": f"Bearer {auth_token}", "content-type": "application/json"}
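+
+# Hedged usage sketch: create() passes the resolved token explicitly, e.g.
+# get_header(auth_token=auth_token), while get()/delete() reuse the token stored
+# on the resource metadata via get_header(auth_token=resource_meta.auth_token).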


 def result_state_is_available(life_cycle_state: str) -> bool: