Skip to content

Commit ebd0fbf

Browse files
Feat/change execute to execute string (#14)
feat: using snowflake `conn.execute_string()` instead of `cursor.execute()` to run sql queries! --------- Co-authored-by: avr2002 <avr13405@gmail.com>
1 parent 9052ef1 commit ebd0fbf

File tree

9 files changed

+184
-55
lines changed

9 files changed

+184
-55
lines changed

.github/workflows/ci-cd-ds-platform-utils.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Publish DS Projen
1+
name: Publish DS Platform Utils
22

33
on:
44
workflow_dispatch:
@@ -16,7 +16,7 @@ jobs:
1616
- name: Checkout Repository
1717
uses: actions/checkout@v4
1818
with:
19-
fetch-depth: 0 # Fetch all history for version tagging
19+
fetch-depth: 0 # Fetch all history for version tagging
2020

2121
- name: Set up uv
2222
uses: astral-sh/setup-uv@v5
@@ -44,7 +44,7 @@ jobs:
4444
cache-dependency-glob: "${{ github.workspace }}/uv.lock"
4545

4646
- name: Run pre-commit hooks
47-
run: SKIP=no-commit-to-branch uv run poe lint # using poethepoet needs to be setup before using poe lint
47+
run: SKIP=no-commit-to-branch uv run poe lint # using poethepoet needs to be setup before using poe lint
4848

4949
build-wheel:
5050
name: Build Wheel

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "ds-platform-utils"
3-
version = "0.2.3"
3+
version = "0.3.0"
44
description = "Utility library for Pattern Data Science."
55
readme = "README.md"
66
authors = [
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Shared Snowflake utility functions."""
2+
3+
import warnings
4+
from typing import Iterable, Optional
5+
6+
from snowflake.connector import SnowflakeConnection
7+
from snowflake.connector.cursor import SnowflakeCursor
8+
from snowflake.connector.errors import ProgrammingError
9+
10+
11+
def _execute_sql(conn: SnowflakeConnection, sql: str) -> Optional[SnowflakeCursor]:
    """Execute SQL statement(s) via ``conn.execute_string()`` and return the *last* resulting cursor.

    Snowflake's ``execute_string`` accepts a single string containing multiple SQL
    statements (separated by semicolons) and executes them as one batch. Unlike
    ``cursor.execute()``, which handles exactly one statement and returns a single
    cursor object, ``execute_string`` returns a **list of cursors** — one cursor for
    each individual SQL statement in the batch.

    :param conn: Snowflake connection object
    :param sql: SQL query or batch of semicolon-delimited SQL statements
    :return: The cursor corresponding to the last executed statement, or ``None``
        if the input is blank, the batch produced no cursors, or Snowflake rejected
        the input as an empty statement (e.g. only semicolons or only comments)
    """
    # Strip once up front; a blank / whitespace-only string has nothing to run.
    statements = sql.strip()
    if not statements:
        return None

    try:
        cursors = conn.execute_string(statements)
    except ProgrammingError as e:
        # Snowflake raises "Empty SQL statement" for input that contains only
        # semicolons or comments; treat that as "nothing to do" (warn + None),
        # consistent with the blank-string early return above.
        if "Empty SQL statement" in str(e):
            warnings.warn("Empty SQL statement encountered; returning None.", category=UserWarning, stacklevel=2)
            return None
        raise

    # BUGFIX: the previous `*_, last = cursors` unpacking raised an uncaught
    # ValueError whenever execute_string yielded an empty list of cursors.
    cursor_list = list(cursors or [])
    if not cursor_list:
        return None

    # Only the final statement's cursor is handed back to the caller; close the
    # rest so their client/server-side resources are released promptly.
    for cur in cursor_list[:-1]:
        cur.close()
    return cursor_list[-1]

src/ds_platform_utils/_snowflake/write_audit_publish.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from jinja2 import DebugUndefined, Template
88
from snowflake.connector.cursor import SnowflakeCursor
99

10+
from ds_platform_utils._snowflake.run_query import _execute_sql
1011
from ds_platform_utils.metaflow._consts import NON_PROD_SCHEMA, PROD_SCHEMA
1112

1213

@@ -200,8 +201,8 @@ def run_query(query: str, cursor: Optional[SnowflakeCursor] = None) -> None:
200201
print(f"Would execute query:\n{query}")
201202
return
202203

203-
# Count statements so we can tell Snowflake exactly how many to expect
204-
cursor.execute(query, num_statements=0) # 0 means any number of statements
204+
# run the query using _execute_sql utility which handles multiple statements via execute_string
205+
_execute_sql(cursor.connection, query)
205206
cursor.connection.commit()
206207

207208

@@ -216,7 +217,10 @@ def run_audit_query(query: str, cursor: Optional[SnowflakeCursor] = None) -> dic
216217
if cursor is None:
217218
return {"mock_result": True}
218219

219-
cursor.execute(query)
220+
cursor = _execute_sql(cursor.connection, query)
221+
if cursor is None:
222+
return {}
223+
220224
result = cursor.fetchone()
221225
if not result:
222226
return {}
@@ -243,11 +247,17 @@ def fetch_table_preview(
243247
if not cursor:
244248
return [{"mock_col": "mock_val"}]
245249

246-
cursor.execute(f"""
250+
cursor = _execute_sql(
251+
cursor.connection,
252+
f"""
247253
SELECT *
248254
FROM {database}.{schema}.{table_name}
249255
LIMIT {n_rows};
250-
""")
256+
""",
257+
)
258+
if cursor is None:
259+
return []
260+
251261
columns = [col[0] for col in cursor.description]
252262
rows = cursor.fetchall()
253263
return [dict(zip(columns, row)) for row in rows]

src/ds_platform_utils/metaflow/get_snowflake_connection.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from metaflow import Snowflake, current
55
from snowflake.connector import SnowflakeConnection
66

7+
from ds_platform_utils._snowflake.run_query import _execute_sql
8+
79
####################
810
# --- Metaflow --- #
911
####################
@@ -41,7 +43,12 @@ def get_snowflake_connection(
4143
In metaflow, each step is a separate Python process, so the connection will automatically be
4244
closed at the end of any steps that use this singleton.
4345
"""
44-
return _create_snowflake_connection(use_utc=use_utc, query_tag=current.project_name)
46+
if current and hasattr(current, "project_name"):
47+
query_tag = current.project_name
48+
else:
49+
query_tag = None
50+
51+
return _create_snowflake_connection(use_utc=use_utc, query_tag=query_tag)
4552

4653

4754
#####################
@@ -66,11 +73,10 @@ def _create_snowflake_connection(
6673
if query_tag:
6774
queries.append(f"ALTER SESSION SET QUERY_TAG = '{query_tag}';")
6875

69-
# Execute all queries in single batch
70-
with conn.cursor() as cursor:
71-
sql = "\n".join(queries)
72-
_debug_print_query(sql)
73-
cursor.execute(sql, num_statements=0)
76+
# Merge into single SQL batch
77+
sql = "\n".join(queries)
78+
_debug_print_query(sql)
79+
_execute_sql(conn, sql)
7480

7581
return conn
7682

src/ds_platform_utils/metaflow/pandas.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from snowflake.connector import SnowflakeConnection
1212
from snowflake.connector.pandas_tools import write_pandas
1313

14+
from ds_platform_utils._snowflake.run_query import _execute_sql
1415
from ds_platform_utils.metaflow._consts import NON_PROD_SCHEMA, PROD_SCHEMA
1516
from ds_platform_utils.metaflow.get_snowflake_connection import _debug_print_query, get_snowflake_connection
1617
from ds_platform_utils.metaflow.write_audit_publish import (
@@ -111,15 +112,14 @@ def publish_pandas( # noqa: PLR0913 (too many arguments)
111112

112113
# set warehouse
113114
if warehouse is not None:
114-
with conn.cursor() as cur:
115-
cur.execute(f"USE WAREHOUSE {warehouse};")
115+
_execute_sql(conn, f"USE WAREHOUSE {warehouse};")
116116

117-
# set query tag for cost tracking in select.dev
118-
# REASON: because write_pandas() doesn't allow modifying the SQL query to add SQL comments in it directly,
119-
# so we set a session query tag instead.
120-
tags = get_select_dev_query_tags()
121-
query_tag_str = json.dumps(tags)
122-
cur.execute(f"ALTER SESSION SET QUERY_TAG = '{query_tag_str}';")
117+
# set query tag for cost tracking in select.dev
118+
# REASON: because write_pandas() doesn't allow modifying the SQL query to add SQL comments in it directly,
119+
# so we set a session query tag instead.
120+
tags = get_select_dev_query_tags()
121+
query_tag_str = json.dumps(tags)
122+
_execute_sql(conn, f"ALTER SESSION SET QUERY_TAG = '{query_tag_str}';")
123123

124124
# https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/snowpark/api/snowflake.snowpark.Session.write_pandas
125125
write_pandas(
@@ -198,16 +198,20 @@ def query_pandas_from_snowflake(
198198
current.card.append(Markdown(f"```sql\n{query}\n```"))
199199

200200
conn: SnowflakeConnection = get_snowflake_connection(use_utc)
201-
with conn.cursor() as cur:
202-
if warehouse is not None:
203-
cur.execute(f"USE WAREHOUSE {warehouse};")
201+
if warehouse is not None:
202+
_execute_sql(conn, f"USE WAREHOUSE {warehouse};")
204203

204+
cursor_result = _execute_sql(conn, query)
205+
if cursor_result is None:
206+
# No statements to execute, return empty DataFrame
207+
df = pd.DataFrame()
208+
else:
205209
# force_return_table=True -- returns a Pyarrow Table always even if the result is empty
206-
result: pyarrow.Table = cur.execute(query).fetch_arrow_all(force_return_table=True)
207-
210+
result: pyarrow.Table = cursor_result.fetch_arrow_all(force_return_table=True)
208211
df = result.to_pandas()
209212
df.columns = df.columns.str.lower()
210213

211-
current.card.append(Markdown("### Query Result"))
212-
current.card.append(Table.from_dataframe(df.head()))
213-
return df
214+
current.card.append(Markdown("### Query Result"))
215+
current.card.append(Table.from_dataframe(df.head()))
216+
217+
return df

src/ds_platform_utils/metaflow/write_audit_publish.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from metaflow.cards import Artifact, Markdown, Table
1111
from snowflake.connector.cursor import SnowflakeCursor
1212

13+
from ds_platform_utils._snowflake.run_query import _execute_sql
1314
from ds_platform_utils.metaflow.get_snowflake_connection import get_snowflake_connection
1415

1516
if TYPE_CHECKING:
@@ -97,7 +98,7 @@ def get_select_dev_query_tags() -> Dict[str, str]:
9798
stacklevel=2,
9899
)
99100

100-
def extract(prefix: str, default: str = "unknown") -> str:
101+
def _extract(prefix: str, default: str = "unknown") -> str:
101102
for tag in fetched_tags:
102103
if tag.startswith(prefix + ":"):
103104
return tag.split(":", 1)[1]
@@ -106,19 +107,19 @@ def extract(prefix: str, default: str = "unknown") -> str:
106107
# most of these will be unknown if no tags are set on the flow
107108
# (most likely for the flow runs which are triggered manually locally)
108109
return {
109-
"app": extract(
110+
"app": _extract(
110111
"ds.domain"
111112
), # first tag after 'app:', is the domain of the flow, fetched from current tags of the flow
112-
"workload_id": extract(
113+
"workload_id": _extract(
113114
"ds.project"
114115
), # second tag after 'workload_id:', is the project of the flow which it belongs to
115-
"flow_name": current.flow_name, # name of the metaflow flow
116+
"flow_name": current.flow_name,
116117
"project": current.project_name, # Project name from the @project decorator, lets us
117118
# identify the flow’s project without relying on user tags (added via --tag).
118119
"step_name": current.step_name, # name of the current step
119120
"run_id": current.run_id, # run_id: unique id of the current run
120121
"user": current.username, # username of user who triggered the run (argo-workflows if its a deployed flow)
121-
"domain": extract("ds.domain"), # business unit (domain) of the flow, same as app
122+
"domain": _extract("ds.domain"), # business unit (domain) of the flow, same as app
122123
"namespace": current.namespace, # namespace of the flow
123124
"perimeter": str(os.environ.get("OB_CURRENT_PERIMETER") or os.environ.get("OBP_PERIMETER")),
124125
"is_production": str(
@@ -216,7 +217,7 @@ def publish( # noqa: PLR0913, D417
216217

217218
with conn.cursor() as cur:
218219
if warehouse is not None:
219-
cur.execute(f"USE WAREHOUSE {warehouse}")
220+
_execute_sql(conn, f"USE WAREHOUSE {warehouse}")
220221

221222
last_op_was_write = False
222223
for operation in write_audit_publish(
@@ -334,20 +335,28 @@ def fetch_table_preview(
334335
:param table_name: Table name
335336
:param cursor: Snowflake cursor
336337
"""
337-
cursor.execute(f"""
338-
SELECT *
339-
FROM {database}.{schema}.{table_name}
340-
LIMIT {n_rows};
341-
""")
342-
columns = [col[0] for col in cursor.description]
343-
rows = cursor.fetchall()
344-
345-
# Create header row plus data rows
346-
table_rows = [[Artifact(col) for col in columns]] # Header row
347-
for row in rows:
348-
table_rows.append([Artifact(val) for val in row]) # Data rows
349-
350-
return [
351-
Markdown(f"### Table Preview: ({database}.{schema}.{table_name})"),
352-
Table(table_rows),
353-
]
338+
if cursor is None:
339+
return []
340+
else:
341+
result_cursor = _execute_sql(
342+
cursor.connection,
343+
f"""
344+
SELECT *
345+
FROM {database}.{schema}.{table_name}
346+
LIMIT {n_rows};
347+
""",
348+
)
349+
if result_cursor is None:
350+
return []
351+
columns = [col[0] for col in result_cursor.description]
352+
rows = result_cursor.fetchall()
353+
354+
# Create header row plus data rows
355+
table_rows = [[Artifact(col) for col in columns]] # Header row
356+
for row in rows:
357+
table_rows.append([Artifact(val) for val in row]) # Data rows
358+
359+
return [
360+
Markdown(f"### Table Preview: ({database}.{schema}.{table_name})"),
361+
Table(table_rows),
362+
]
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Functional test for _execute_sql."""
2+
3+
from typing import Generator
4+
5+
import pytest
6+
from snowflake.connector import SnowflakeConnection
7+
8+
from ds_platform_utils._snowflake.run_query import _execute_sql
9+
from ds_platform_utils.metaflow.get_snowflake_connection import get_snowflake_connection
10+
11+
12+
@pytest.fixture(scope="module")
def snowflake_conn() -> Generator[SnowflakeConnection, None, None]:
    """Yield one Snowflake connection shared by every test in this module."""
    conn = get_snowflake_connection(use_utc=True)
    yield conn
16+
17+
18+
def test_execute_sql_empty_string(snowflake_conn):
    """An empty SQL string yields None without executing anything."""
    assert _execute_sql(snowflake_conn, "") is None
22+
23+
24+
def test_execute_sql_whitespace_only(snowflake_conn):
    """A whitespace-only SQL string yields None without executing anything."""
    assert _execute_sql(snowflake_conn, " \n\t ") is None
28+
29+
30+
def test_execute_sql_only_semicolons(snowflake_conn):
    """A batch containing only semicolons warns and yields None."""
    with pytest.warns(UserWarning, match="Empty SQL statement encountered"):
        result = _execute_sql(snowflake_conn, " ; ;")
    assert result is None
35+
36+
37+
def test_execute_sql_only_comments(snowflake_conn):
    """A batch containing only a SQL comment warns and yields None."""
    with pytest.warns(UserWarning, match="Empty SQL statement encountered"):
        result = _execute_sql(snowflake_conn, "/* only comments */")
    assert result is None
42+
43+
44+
def test_execute_sql_single_statement(snowflake_conn):
    """A single SELECT returns a usable cursor holding the expected row."""
    result = _execute_sql(snowflake_conn, "SELECT 1 AS x;")
    assert result is not None
    fetched = result.fetchall()
    assert len(fetched) == 1
    assert fetched[0][0] == 1
51+
52+
53+
def test_execute_sql_multi_statement(snowflake_conn):
    """A multi-statement batch returns the cursor of the final statement only."""
    result = _execute_sql(snowflake_conn, "SELECT 1 AS x; SELECT 2 AS x;")
    assert result is not None
    fetched = result.fetchall()
    assert len(fetched) == 1
    assert fetched[0][0] == 2  # value from the last statement in the batch

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)