Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/snowflake/connector/aio/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
from ._cursor import SnowflakeCursor, SnowflakeCursorBase
from ._description import CLIENT_NAME
from ._direct_file_operation_utils import FileOperationParser, StreamDownloader
from ._network import SnowflakeRestful
from ._network import SnowflakeRestful, create_restful_client
from ._session_manager import (
AioHttpConfig,
SessionManager,
Expand Down Expand Up @@ -218,7 +218,7 @@ async def __open_connection(self):
use_numpy=self._numpy, support_negative_year=self._support_negative_year
)

self._rest = SnowflakeRestful(
self._rest = create_restful_client(
host=self.host,
port=self.port,
protocol=self._protocol,
Expand Down
52 changes: 52 additions & 0 deletions src/snowflake/connector/aio/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
)
from ..time_util import TimeoutBackoffCtx
from ..xp import is_xp_environment
from ._description import CLIENT_NAME
from ._session_manager import (
SessionManager,
Expand Down Expand Up @@ -858,3 +859,54 @@ async def use_session(
) -> AsyncGenerator[aiohttp.ClientSession]:
async with self._session_manager.use_session(url) as session:
yield session


def create_restful_client(
    host: str = "127.0.0.1",
    port: int = 8080,
    protocol: str = "http",
    inject_client_pause: int = 0,
    connection: SnowflakeConnection | None = None,
    session_manager: SessionManager | None = None,
) -> SnowflakeRestful | SnowflakeRestfulSync:
    """Build the REST client that matches the current runtime.

    When ``is_xp_environment()`` reports an XP environment, an ``XPRestful``
    instance is returned; it is a synchronous client and is returned as such
    even when the caller is async code. In every other case the async
    ``SnowflakeRestful`` defined in this module is returned.

    Both client classes are constructed with exactly the same keyword
    arguments, all of which are forwarded unchanged.

    Args:
        host: Server hostname.
        port: Server port.
        protocol: Protocol (http/https).
        inject_client_pause: Client pause injection for testing.
        connection: Snowflake connection object.
        session_manager: Session manager for HTTP requests.

    Returns:
        The REST client instance appropriate for the environment.
    """
    # Collected once because both constructors receive the identical set.
    client_kwargs = {
        "host": host,
        "port": port,
        "protocol": protocol,
        "inject_client_pause": inject_client_pause,
        "connection": connection,
        "session_manager": session_manager,
    }

    if not is_xp_environment():
        logger.debug("Creating async SnowflakeRestful client for standard environment")
        return SnowflakeRestful(**client_kwargs)

    logger.debug(
        "Creating XPRestful client for XP environment (sync, even in async context)"
    )
    # Imported only on the XP path, after the debug message has been logged.
    from ..xp.network import XPRestful

    return XPRestful(**client_kwargs)
52 changes: 52 additions & 0 deletions src/snowflake/connector/auth/_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from ..sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
from ..token_cache import TokenCache, TokenKey, TokenType
from ..version import VERSION
from ..xp import is_xp_environment
from .no_auth import AuthNoAuth
from .oauth import AuthByOAuth

Expand Down Expand Up @@ -187,6 +188,10 @@ def authenticate(
if isinstance(auth_instance, AuthNoAuth):
return {}

# XP environment: bypass HTTP authentication and use XP API
if is_xp_environment():
return self._authenticate_xp(session_parameters or {})

if timeout is None:
timeout = auth_instance.timeout

Expand Down Expand Up @@ -615,6 +620,53 @@ def get_token_cache(self) -> TokenCache:
)
return self._token_cache

def _authenticate_xp(
    self, session_parameters: dict[Any, Any]
) -> dict[str, str | int | bool]:
    """Collect session parameters through the XP API instead of HTTP login.

    A describe-only ``SELECT 1`` is issued via ``_snowflake.execute_sql`` and
    every ``{"name": ..., "value": ...}`` entry found under the result's
    ``"parameters"`` key is copied into ``session_parameters``. If the
    resulting dict is non-empty it is also pushed to the connection with
    ``_update_parameters``.

    This is best-effort: any exception is logged with its traceback and then
    swallowed, and ``session_parameters`` is returned in whatever state it
    reached (possibly unchanged, possibly partially filled).

    Args:
        session_parameters: Dict that is updated in place with the parameters
            reported by the XP environment.

    Returns:
        The same ``session_parameters`` dict that was passed in.
    """
    logger.debug("Authenticating in XP environment")

    try:
        import _snowflake

        # The query itself is irrelevant; the describe-only call is made for
        # the session parameters attached to its response.
        describe_result = _snowflake.execute_sql(
            "SELECT 1",
            is_describe_only=True,
            stmt_params=None,
            binding_params=None,
            _no_results=False,
        )

        has_parameters = (
            isinstance(describe_result, dict) and "parameters" in describe_result
        )
        entries = describe_result["parameters"] if has_parameters else ()
        for entry in entries:
            # Ignore anything that is not a well-formed name/value record.
            if not isinstance(entry, dict):
                continue
            if "name" not in entry or "value" not in entry:
                continue
            session_parameters[entry["name"]] = entry["value"]

        # Propagate to the connection only when there is something to apply.
        if session_parameters:
            self._rest._connection._update_parameters(session_parameters)

        logger.debug("XP authentication completed")
    except Exception as e:
        # Deliberate best-effort: report the failure and fall through to
        # return whatever parameters are available.
        logger.error(f"XP authentication failed: {e}", exc_info=True)

    return session_parameters


def get_token_from_private_key(
user: str, account: str, privatekey_path: str, key_password: str | None
Expand Down
3 changes: 2 additions & 1 deletion src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
WORKLOAD_IDENTITY_AUTHENTICATOR,
ReauthenticationRequest,
SnowflakeRestful,
create_restful_client,
)
from .session_manager import (
HttpConfig,
Expand Down Expand Up @@ -1336,7 +1337,7 @@ def __open_connection(self):
use_numpy=self._numpy, support_negative_year=self._support_negative_year
)

self._rest = SnowflakeRestful(
self._rest = create_restful_client(
host=self.host,
port=self.port,
protocol=self._protocol,
Expand Down
80 changes: 80 additions & 0 deletions src/snowflake/connector/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from .sqlstate import SQLSTATE_FEATURE_NOT_SUPPORTED
from .telemetry import TelemetryData, TelemetryField
from .time_util import get_time_millis
from .xp import is_xp_environment

if TYPE_CHECKING: # pragma: no cover
from pandas import DataFrame
Expand Down Expand Up @@ -1867,6 +1868,42 @@ def _upload(
file_transfer_agent.execute()
self._init_result_and_meta(file_transfer_agent.result())

def _download_stream_xp(
    self, stage_location: str, decompress: bool = False
) -> IO[bytes]:
    """Downloads from stage location as a stream using XP _sfstream.

    Without ``decompress`` the freshly opened ``SfStream`` is handed back to
    the caller, who is then responsible for closing it. With ``decompress``
    the whole stream is read into memory, the stream is closed, and an
    in-memory ``io.BytesIO`` holding the gunzipped bytes is returned.

    Args:
        stage_location (str): The location of the stage to download from.
        decompress (bool, optional): Whether to gzip-decompress the file, by
            default we do not decompress.

    Returns:
        IO[bytes]: A stream to read from.
    """
    import _sfstream

    # Optional attribute; None is passed through when the connection has none.
    rso_id = getattr(self.connection, "_rso_id", None)

    # Open stream for reading
    stream = _sfstream.SfStream(
        stage_location,
        file_type=_sfstream.FileType.STAGE,
        mode=_sfstream.Mode.READ,
        rso_id=rso_id,
    )

    if not decompress:
        return stream

    import gzip
    import io

    # Close the stage stream even if read() fails, so a failed download
    # does not leak the underlying handle.
    try:
        data = stream.read()
    finally:
        stream.close()
    return io.BytesIO(gzip.decompress(data))

def _download_stream(
self, stage_location: str, decompress: bool = False
) -> IO[bytes]:
Expand All @@ -1880,6 +1917,11 @@ def _download_stream(
Returns:
IO[bytes]: A stream to read from.
"""
# XP environment: use _sfstream for direct stage access
if is_xp_environment():
return self._download_stream_xp(stage_location, decompress)

# Standard environment: use HTTP/cloud storage
# Interpret the file operation.
ret = self.connection._file_operation_parser.parse_file_operation(
stage_location=stage_location,
Expand All @@ -1893,6 +1935,38 @@ def _download_stream(
# Set up stream downloading based on the interpretation and return the stream for reading.
return self.connection._stream_downloader.download_as_stream(ret, decompress)

def _upload_stream_xp(
    self,
    input_stream: IO[bytes],
    stage_location: str,
) -> None:
    """Uploads content in the input stream to stage location using XP _sfstream.

    The input stream is rewound to its start, read in full, and written to
    an ``SfStream`` opened in write mode on ``stage_location``. The stage
    stream is always closed, including when rewinding, reading or writing
    raises; the exception then propagates to the caller.

    Args:
        input_stream (IO[bytes]): A seekable stream to read from.
        stage_location (str): The location of the stage to upload to.
    """
    import _sfstream

    # Optional attribute; None is passed through when the connection has none.
    rso_id = getattr(self.connection, "_rso_id", None)

    # Open stream for writing
    stream = _sfstream.SfStream(
        stage_location,
        file_type=_sfstream.FileType.STAGE,
        mode=_sfstream.Mode.WRITE,
        rso_id=rso_id,
    )

    # try/finally guarantees the stage stream is released on failure too.
    try:
        input_stream.seek(0)
        stream.write(input_stream.read())
    finally:
        stream.close()

    logger.debug(f"Successfully uploaded stream to {stage_location}")

def _upload_stream(
self,
input_stream: IO[bytes],
Expand All @@ -1913,6 +1987,12 @@ def _upload_stream(
if _do_reset:
self.reset()

# XP environment: use _sfstream for direct stage access
if is_xp_environment():
self._upload_stream_xp(input_stream, stage_location)
return

# Standard environment: use HTTP/cloud storage
# Interpret the file operation.
ret = self.connection._file_operation_parser.parse_file_operation(
stage_location=stage_location,
Expand Down
39 changes: 35 additions & 4 deletions src/snowflake/connector/file_transfer_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,17 @@
from .local_storage_client import SnowflakeLocalStorageClient
from .s3_storage_client import SnowflakeS3RestClient
from .storage_client import SnowflakeFileEncryptionMaterial, SnowflakeStorageClient
from .xp import is_xp_environment

if TYPE_CHECKING: # pragma: no cover
from .connection import SnowflakeConnection
from .cursor import SnowflakeCursor
from .file_compression_type import CompressionType

VALID_STORAGE = [LOCAL_FS, S3_FS, AZURE_FS, GCS_FS]
# XP storage type constant
STORED_PROC_FS = "STORED_PROC_FS"

VALID_STORAGE = [LOCAL_FS, S3_FS, AZURE_FS, GCS_FS, STORED_PROC_FS]

INJECT_WAIT_IN_PUT = 0

Expand Down Expand Up @@ -732,6 +736,17 @@ def _create_file_transfer_client(
self._command,
unsafe_file_write=self._unsafe_file_write,
)
elif self._stage_location_type == STORED_PROC_FS:
# XP storage client for stored procedures
from .xp.storage_client import XPStorageClient

return XPStorageClient(
meta,
self._stage_info,
4 * megabyte,
credentials=self._credentials,
unsafe_file_write=self._unsafe_file_write,
)
raise Exception(f"{self._stage_location_type} is an unknown stage type")

def _transfer_accelerate_config(self) -> None:
Expand Down Expand Up @@ -1237,8 +1252,24 @@ def _process_file_compression_type(self) -> None:
def _strip_stage_prefix_from_dst_file_name_for_download(self, dst_file_name):
    """Return ``dst_file_name`` without its leading ``@stage/`` segment.

    Outside an XP environment this is a no-op and the name is returned as
    is. Inside one, a name of the form ``@stage_name/path/to/file`` is cut
    after its first ``/`` so that only ``path/to/file`` remains. Names that
    are empty, do not start with ``@``, or contain no ``/`` are returned
    unchanged.

    Args:
        dst_file_name: The destination file name, possibly with stage prefix.

    Returns:
        The file name with the stage prefix removed when the conditions
        above are met, otherwise ``dst_file_name`` itself.
    """
    if is_xp_environment() and dst_file_name and dst_file_name.startswith("@"):
        # Split at the first slash; everything before it is the stage name.
        _stage, slash, remainder = dst_file_name.partition("/")
        if slash:
            return remainder

    return dst_file_name
Loading
Loading