Skip to content

Commit f9def54

Browse files
authored
Merge branch 'master' into small-map-task-change
2 parents 2920718 + 6297a98 commit f9def54

File tree

8 files changed

+2041
-94
lines changed

8 files changed

+2041
-94
lines changed

Dockerfile.connector

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,15 @@ ARG VERSION
77

88
RUN apt-get update && apt-get install build-essential -y \
99
&& pip install uv
10-
11-
RUN uv pip install --system --no-cache-dir -U flytekit[connector]==$VERSION \
10+
# Pin pendulum<3.0: Apache Airflow (via flytekitplugins-airflow) imports
11+
# pendulum.tz.timezone() at module load time (airflow/settings.py).
12+
# Pendulum 3.x changed the tz API, causing the connector to crash on startup:
13+
# airflow/settings.py → TIMEZONE = pendulum.tz.timezone("UTC") → AttributeError
14+
# Without this pin, uv resolves to pendulum 3.x which breaks the import chain:
15+
# pyflyte serve connector → load_implicit_plugins → airflow → pendulum → crash
16+
RUN uv pip install --system --no-cache-dir -U \
17+
"pendulum>=2.0.0,<3.0" \
18+
flytekit[connector]==$VERSION \
1219
flytekitplugins-airflow==$VERSION \
1320
flytekitplugins-bigquery==$VERSION \
1421
flytekitplugins-k8sdataservice==$VERSION \

plugins/flytekit-spark/flytekitplugins/spark/connector.py

Lines changed: 378 additions & 34 deletions
Large diffs are not rendered by default.

plugins/flytekit-spark/flytekitplugins/spark/task.py

Lines changed: 231 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from flytekit.models.task import K8sPod
1919

2020
from .models import SparkJob, SparkType
21+
from .utils import is_serverless_config
2122

2223
pyspark_sql = lazy_module("pyspark.sql")
2324
SparkSession = pyspark_sql.SparkSession
@@ -73,17 +74,133 @@ def __post_init__(self):
7374
class DatabricksV2(Spark):
7475
"""
7576
Use this to configure a Databricks task. Task's marked with this will automatically execute
76-
natively onto databricks platform as a distributed execution of spark
77+
natively onto databricks platform as a distributed execution of spark.
7778
78-
Args:
79-
databricks_conf: Databricks job configuration compliant with API version 2.1, supporting 2.0 use cases.
80-
For the configuration structure, visit here.https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure
81-
For updates in API 2.1, refer to: https://docs.databricks.com/en/workflows/jobs/jobs-api-updates.html
82-
databricks_instance: Domain name of your deployment. Use the form <account>.cloud.databricks.com.
79+
Supports both classic compute (clusters) and serverless compute.
80+
81+
Attributes:
82+
databricks_conf (Optional[Dict[str, Union[str, dict]]]): Databricks job configuration
83+
compliant with API version 2.1, supporting 2.0 use cases.
84+
For the configuration structure, visit: https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure
85+
For updates in API 2.1, refer to: https://docs.databricks.com/en/workflows/jobs/jobs-api-updates.html
86+
databricks_instance (Optional[str]): Domain name of your deployment.
87+
Use the form <account>.cloud.databricks.com.
88+
databricks_service_credential_provider (Optional[str]): Provider name for Databricks
89+
Service Credentials for S3 access. Falls back to FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER env var.
90+
databricks_token_secret (Optional[str]): Custom name for the K8s secret containing
91+
the Databricks token. Defaults to 'databricks-token' if not specified.
92+
notebook_path (Optional[str]): Path to Databricks notebook
93+
(e.g., "/Users/[email protected]/notebook").
94+
notebook_base_parameters (Optional[Dict[str, str]]): Parameters to pass to the notebook.
95+
96+
Compute Modes:
97+
The connector auto-detects the compute mode based on the databricks_conf contents:
98+
99+
1. Classic Compute (existing cluster):
100+
Provide `existing_cluster_id` in databricks_conf.
101+
102+
2. Classic Compute (new cluster):
103+
Provide `new_cluster` configuration in databricks_conf.
104+
105+
3. Serverless Compute (pre-configured environment):
106+
Provide `environment_key` referencing a pre-configured environment in Databricks.
107+
Do not include `existing_cluster_id` or `new_cluster`.
108+
109+
4. Serverless Compute (inline environment spec):
110+
Provide `environments` array with environment specifications.
111+
Optionally include `environment_key` to specify which environment to use.
112+
Do not include `existing_cluster_id` or `new_cluster`.
113+
114+
Example - Classic Compute with new cluster::
115+
116+
DatabricksV2(
117+
databricks_conf={
118+
"run_name": "my-spark-job",
119+
"new_cluster": {
120+
"spark_version": "13.3.x-scala2.12",
121+
"node_type_id": "m5.xlarge",
122+
"num_workers": 2,
123+
},
124+
},
125+
databricks_instance="my-workspace.cloud.databricks.com",
126+
)
127+
128+
Example - Serverless Compute with pre-configured environment::
129+
130+
DatabricksV2(
131+
databricks_conf={
132+
"run_name": "my-serverless-job",
133+
"environment_key": "my-preconfigured-env",
134+
},
135+
databricks_instance="my-workspace.cloud.databricks.com",
136+
)
137+
138+
Example - Serverless Compute with inline environment spec::
139+
140+
DatabricksV2(
141+
databricks_conf={
142+
"run_name": "my-serverless-job",
143+
"environment_key": "default",
144+
"environments": [{
145+
"environment_key": "default",
146+
"spec": {
147+
"client": "1",
148+
"dependencies": ["pandas==2.0.0", "numpy==1.24.0"],
149+
}
150+
}],
151+
},
152+
databricks_instance="my-workspace.cloud.databricks.com",
153+
)
154+
155+
Note:
156+
Serverless compute has certain limitations compared to classic compute:
157+
- Only Python and SQL are supported (no Scala or R)
158+
- Only Spark Connect APIs are supported (no RDD APIs)
159+
- Must use Unity Catalog for external data sources
160+
- No support for compute-scoped init scripts or libraries
161+
For full details, see: https://docs.databricks.com/en/compute/serverless/limitations.html
162+
163+
Serverless Entrypoint:
164+
Both classic and serverless use the same ``flytetools`` repo for their entrypoints.
165+
Classic uses ``flytekitplugins/databricks/entrypoint.py`` and serverless uses
166+
``flytekitplugins/databricks/entrypoint_serverless.py``. No additional configuration needed.
167+
168+
To override the default, provide ``git_source`` and ``python_file`` in ``databricks_conf``.
169+
170+
AWS Credentials for Serverless:
171+
Databricks serverless does not provide AWS credentials via instance metadata.
172+
To access S3 (for Flyte data), configure a Databricks Service Credential.
173+
174+
The provider name is resolved in this order:
175+
1. ``databricks_service_credential_provider`` in the task config (per-task override)
176+
2. ``FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER`` environment variable on the connector (default for all tasks)
177+
178+
The entrypoint will use this to obtain AWS credentials via:
179+
dbutils.credentials.getServiceCredentialsProvider(provider_name)
180+
181+
Notebook Support:
182+
To run a Databricks notebook instead of a Python file, set `notebook_path`.
183+
Parameters can be passed via `notebook_base_parameters`.
184+
185+
Example - Running a notebook::
186+
187+
DatabricksV2(
188+
databricks_conf={
189+
"run_name": "my-notebook-job",
190+
"new_cluster": {...},
191+
},
192+
databricks_instance="my-workspace.cloud.databricks.com",
193+
notebook_path="/Users/[email protected]/my-notebook",
194+
notebook_base_parameters={"param1": "value1"},
195+
)
83196
"""
84197

85198
databricks_conf: Optional[Dict[str, Union[str, dict]]] = None
86199
databricks_instance: Optional[str] = None
200+
databricks_service_credential_provider: Optional[str] = None
201+
databricks_token_secret: Optional[str] = None
202+
notebook_path: Optional[str] = None
203+
notebook_base_parameters: Optional[Dict[str, str]] = None
87204

88205

89206
# This method does not reset the SparkSession since it's a bit hard to handle multiple
@@ -187,7 +304,22 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
187304
job._databricks_conf = cfg.databricks_conf
188305
job._databricks_instance = cfg.databricks_instance
189306

190-
return MessageToDict(job.to_flyte_idl())
307+
# Serialize to dict
308+
custom_dict = MessageToDict(job.to_flyte_idl())
309+
310+
# Add DatabricksV2-specific fields (not part of protobuf)
311+
if isinstance(self.task_config, DatabricksV2):
312+
cfg = cast(DatabricksV2, self.task_config)
313+
if cfg.databricks_service_credential_provider:
314+
custom_dict["databricksServiceCredentialProvider"] = cfg.databricks_service_credential_provider
315+
if cfg.databricks_token_secret:
316+
custom_dict["databricksTokenSecret"] = cfg.databricks_token_secret
317+
if cfg.notebook_path:
318+
custom_dict["notebookPath"] = cfg.notebook_path
319+
if cfg.notebook_base_parameters:
320+
custom_dict["notebookBaseParameters"] = cfg.notebook_base_parameters
321+
322+
return custom_dict
191323

192324
def to_k8s_pod(self, pod_template: Optional[PodTemplate] = None) -> Optional[K8sPod]:
193325
"""
@@ -210,10 +342,101 @@ def to_k8s_pod(self, pod_template: Optional[PodTemplate] = None) -> Optional[K8s
210342

211343
return K8sPod.from_pod_template(pod_template)
212344

345+
def _is_databricks_serverless(self) -> bool:
    """Detect whether we are executing inside Databricks serverless compute.

    Serverless runs on Spark Connect and therefore needs different
    SparkSession handling than the classic runtime.
    """
    env = os.environ

    # Explicit markers set by our serverless entrypoint win outright.
    if env.get("DATABRICKS_SERVERLESS") == "true" or env.get("SPARK_CONNECT_MODE") == "true":
        return True

    # Otherwise we must at least be running on a Databricks runtime.
    if "DATABRICKS_RUNTIME_VERSION" not in env:
        return False

    # Serverless when the task config targets serverless compute, or when
    # the classic runtime's SPARK_HOME variable is absent.
    cfg = self.task_config
    serverless_cfg = isinstance(cfg, DatabricksV2) and is_serverless_config(cfg.databricks_conf or {})
    return serverless_cfg or "SPARK_HOME" not in env
366+
367+
def _get_databricks_serverless_spark_session(self) -> Optional[SparkSession]:
    """Locate the SparkSession injected by the serverless entrypoint.

    The entrypoint publishes the session in several places; probe them in
    decreasing order of reliability:
      1. the synthetic module ``_flyte_spark_session`` in ``sys.modules``
         (survives module reloads)
      2. ``builtins.spark``
      3. the ``spark`` attribute of ``__main__``
      4. pyspark's currently active session

    Returns:
        Optional[SparkSession]: the session, or None if none can be found.
    """
    import sys

    # 1) Synthetic module injected by the entrypoint.
    try:
        mod = sys.modules.get("_flyte_spark_session")
        if mod is not None and getattr(mod, "spark", None) is not None:
            logger.info("Got SparkSession from _flyte_spark_session module")
            return mod.spark
    except Exception as e:
        logger.debug(f"Could not get spark from _flyte_spark_session: {e}")

    # 2) builtins fallback.
    try:
        import builtins

        if getattr(builtins, "spark", None) is not None:
            logger.info("Got SparkSession from builtins")
            return builtins.spark
    except Exception as e:
        logger.debug(f"Could not get spark from builtins: {e}")

    # 3) __main__ fallback.
    try:
        import __main__

        if getattr(__main__, "spark", None) is not None:
            logger.info("Got SparkSession from __main__")
            return __main__.spark
    except Exception as e:
        logger.debug(f"Could not get spark from __main__: {e}")

    # 4) Whatever pyspark reports as the active session.
    try:
        from pyspark.sql import SparkSession

        session = SparkSession.getActiveSession()
        if session:
            logger.info("Got active SparkSession")
            return session
    except Exception as e:
        logger.debug(f"Could not get active SparkSession: {e}")

    logger.warning("Could not obtain SparkSession in serverless environment")
    return None
423+
213424
def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
214425
import pyspark as _pyspark
215426

216427
ctx = FlyteContextManager.current_context()
428+
429+
# Databricks serverless uses Spark Connect - SparkSession is pre-configured
430+
if self._is_databricks_serverless():
431+
logger.info("Detected Databricks serverless environment - using pre-configured SparkSession")
432+
self.sess = self._get_databricks_serverless_spark_session()
433+
434+
if self.sess is None:
435+
logger.warning("No SparkSession available - task will run without Spark")
436+
437+
return user_params.builder().add_attr("SPARK_SESSION", self.sess).build()
438+
439+
# Standard Spark session creation for non-serverless environments
217440
sess_builder = _pyspark.sql.SparkSession.builder.appName(f"FlyteSpark: {user_params.execution_id}")
218441
if not (ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION):
219442
# If either of above cases is not true, then we are in local execution of this task
@@ -259,7 +482,7 @@ def execute(self, **kwargs) -> Any:
259482
if ctx.execution_state and ctx.execution_state.is_local_execution():
260483
return AsyncConnectorExecutorMixin.execute(self, **kwargs)
261484
except Exception as e:
262-
click.secho(f"Connector failed to run the task with error: {e}", fg="red")
485+
click.secho(f"Connector failed to run the task with error: {e}", fg="red")
263486
click.secho("Falling back to local execution", fg="red")
264487
return PythonFunctionTask.execute(self, **kwargs)
265488

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
def is_serverless_config(databricks_conf: dict) -> bool:
    """
    Detect if the Databricks configuration is for serverless compute.

    Serverless is indicated by having ``environment_key`` or ``environments``
    without any cluster config (``existing_cluster_id`` or ``new_cluster``).

    Args:
        databricks_conf (dict): The databricks job configuration dict.

    Returns:
        bool: True if the configuration targets serverless compute.
    """
    # Any explicit cluster reference means classic compute, regardless of
    # whether serverless-style keys are also present.
    if any(databricks_conf.get(key) is not None for key in ("existing_cluster_id", "new_cluster")):
        return False
    # Serverless requires at least one non-empty serverless indicator.
    return bool(databricks_conf.get("environment_key") or databricks_conf.get("environments"))

0 commit comments

Comments
 (0)