Skip to content

Commit d2a3339

Browse files
committed
Added functionality for is_disconnect()
1 parent a6a3564 commit d2a3339

File tree

3 files changed

+152
-0
lines changed

3 files changed

+152
-0
lines changed

src/databricks/sqlalchemy/base.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,10 @@ def do_ping(self, dbapi_connection):
342342
This method is called by SQLAlchemy when pool_pre_ping=True to verify
343343
connections are still valid before using them from the pool.
344344
345+
This implementation improves upon SQLAlchemy's default do_ping() by
346+
wrapping the cursor creation in a try block, which properly handles
347+
cases where the connection is closed and cursor() itself raises an exception.
348+
345349
Args:
346350
dbapi_connection: A raw DBAPI connection (from databricks-sql-connector)
347351
@@ -361,6 +365,39 @@ def do_ping(self, dbapi_connection):
361365
# SQLAlchemy will discard it and create a new one
362366
return False
363367

368+
def is_disconnect(self, e, connection, cursor):
369+
"""Determine if an exception indicates the connection was lost.
370+
371+
This method is called by SQLAlchemy after exceptions occur during query
372+
execution to determine if the error was due to a lost connection. If this
373+
returns True, SQLAlchemy will invalidate the connection and create a new
374+
one for the next operation.
375+
376+
This is complementary to do_ping():
377+
- do_ping() is proactive: checks connection health BEFORE queries
378+
- is_disconnect() is reactive: classifies errors AFTER they occur
379+
380+
Args:
381+
e: The exception that was raised
382+
connection: The connection that raised the exception (may be None)
383+
cursor: The cursor that raised the exception (may be None)
384+
385+
Returns:
386+
True if the error indicates a disconnect, False otherwise
387+
"""
388+
from databricks.sql.exc import InterfaceError, DatabaseError
389+
390+
# InterfaceError: Client-side errors (e.g., connection already closed)
391+
if isinstance(e, InterfaceError):
392+
return True
393+
394+
# DatabaseError: Server-side errors with invalid handle indicate session expired
395+
if isinstance(e, DatabaseError):
396+
error_msg = str(e).lower()
397+
return "invalid" in error_msg and "handle" in error_msg
398+
399+
return False
400+
364401
@reflection.cache
365402
def has_table(
366403
self, connection, table_name, schema=None, catalog=None, **kwargs

tests/test_local/e2e/test_basic.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,3 +596,73 @@ def test_pool_pre_ping_with_closed_connection(connection_details):
596596

597597
# Cleanup
598598
engine.dispose()
599+
600+
601+
def test_is_disconnect_handles_runtime_errors(db_engine):
602+
"""Test that is_disconnect() properly classifies disconnect errors during query execution.
603+
604+
This tests the reactive error handling (complementary to pool_pre_ping's proactive checking).
605+
When a connection fails DURING a query, is_disconnect() should recognize the error
606+
and tell SQLAlchemy to invalidate the connection.
607+
"""
608+
from sqlalchemy import create_engine, text
609+
from sqlalchemy.exc import DBAPIError
610+
611+
engine = create_engine(
612+
db_engine.url,
613+
pool_pre_ping=False, # Disabled - we want to test is_disconnect, not do_ping
614+
pool_size=1,
615+
max_overflow=0,
616+
)
617+
618+
# Step 1: Execute a successful query
619+
with engine.connect() as conn:
620+
result = conn.execute(text("SELECT VERSION()")).scalar()
621+
assert result is not None
622+
623+
# Get session ID of working connection
624+
raw_conn = conn.connection.dbapi_connection
625+
session_id_1 = raw_conn.get_session_id_hex()
626+
assert session_id_1 is not None
627+
628+
# Step 2: Manually close the connection to simulate server-side session expiration
629+
pooled_conn = engine.pool._pool.queue[0]
630+
pooled_conn.driver_connection.close()
631+
632+
# Step 3: Try to execute query on closed connection
633+
# This should:
634+
# 1. Fail with an exception
635+
# 2. is_disconnect() gets called by SQLAlchemy
636+
# 3. Returns True (recognizes it as disconnect error)
637+
# 4. SQLAlchemy invalidates the connection
638+
# 5. Next operation gets a fresh connection
639+
640+
# First query will fail because connection is closed
641+
try:
642+
with engine.connect() as conn:
643+
conn.execute(text("SELECT VERSION()")).scalar()
644+
# If we get here without exception, the connection wasn't actually closed
645+
pytest.skip("Connection wasn't properly closed - cannot test is_disconnect")
646+
except DBAPIError as e:
647+
# Expected - connection was closed
648+
# is_disconnect() should have been called and returned True
649+
# This causes SQLAlchemy to invalidate the connection
650+
assert "closed" in str(e).lower() or "invalid" in str(e).lower()
651+
652+
# Step 4: Next query should work because is_disconnect() invalidated the bad connection
653+
with engine.connect() as conn:
654+
result = conn.execute(text("SELECT VERSION()")).scalar()
655+
assert result is not None
656+
657+
# Verify we got a NEW connection
658+
raw_conn = conn.connection.dbapi_connection
659+
session_id_2 = raw_conn.get_session_id_hex()
660+
assert session_id_2 is not None
661+
662+
# Different session ID proves connection was invalidated and recreated
663+
assert session_id_1 != session_id_2, (
664+
"is_disconnect() should have invalidated the bad connection, "
665+
"causing SQLAlchemy to create a new one with different session ID"
666+
)
667+
668+
engine.dispose()
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""Tests for DatabricksDialect.is_disconnect() method."""
2+
import pytest
3+
from databricks.sqlalchemy import DatabricksDialect
4+
from databricks.sql.exc import InterfaceError, DatabaseError, OperationalError
5+
6+
7+
class TestIsDisconnect:
8+
@pytest.fixture
9+
def dialect(self):
10+
return DatabricksDialect()
11+
12+
def test_interface_error_is_disconnect(self, dialect):
13+
"""InterfaceError (client-side) is always a disconnect."""
14+
error = InterfaceError("Cannot create cursor from closed connection")
15+
assert dialect.is_disconnect(error, None, None) is True
16+
17+
def test_database_error_with_invalid_handle(self, dialect):
18+
"""DatabaseError with 'invalid handle' is a disconnect."""
19+
test_cases = [
20+
DatabaseError("Invalid SessionHandle"),
21+
DatabaseError("[Errno INVALID_HANDLE] Session does not exist"),
22+
DatabaseError("INVALID HANDLE"),
23+
DatabaseError("invalid handle"),
24+
]
25+
for error in test_cases:
26+
assert dialect.is_disconnect(error, None, None) is True
27+
28+
def test_database_error_without_invalid_handle(self, dialect):
29+
"""DatabaseError without 'invalid handle' is not a disconnect."""
30+
test_cases = [
31+
DatabaseError("Syntax error in SQL"),
32+
DatabaseError("Table not found"),
33+
DatabaseError("Permission denied"),
34+
]
35+
for error in test_cases:
36+
assert dialect.is_disconnect(error, None, None) is False
37+
38+
def test_other_errors_not_disconnect(self, dialect):
39+
"""Other exception types are not disconnects."""
40+
test_cases = [
41+
OperationalError("Timeout waiting for query"),
42+
Exception("Some random error"),
43+
]
44+
for error in test_cases:
45+
assert dialect.is_disconnect(error, None, None) is False

0 commit comments

Comments
 (0)