Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/snowflake/connector/platform_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,21 @@ def is_github_action():
)


def is_aws_wif_outbound_token_enabled():
"""
Check if AWS WIF outbound token is enabled via environment variable.

Returns:
_DetectionState: DETECTED if ENABLE_AWS_WIF_OUTBOUND_TOKEN env var is true,
NOT_DETECTED otherwise.
"""
return (
_DetectionState.DETECTED
if os.environ.get("ENABLE_AWS_WIF_OUTBOUND_TOKEN", "false").lower() == "true"
else _DetectionState.NOT_DETECTED
)


@cache
def detect_platforms(
platform_detection_timeout_seconds: float | None,
Expand Down Expand Up @@ -490,6 +505,7 @@ def detect_platforms(
"is_gce_cloud_run_service": is_gcp_cloud_run_service(),
"is_gce_cloud_run_job": is_gcp_cloud_run_job(),
"is_github_action": is_github_action(),
"is_aws_wif_outbound_token_enabled": is_aws_wif_outbound_token_enabled(),
}

# Run network-calling functions in parallel
Expand Down
60 changes: 39 additions & 21 deletions src/snowflake/connector/wif_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,29 +195,47 @@ def create_aws_attestation(
)
region = get_aws_region()
partition = session.get_partition_for_region(region)
sts_hostname = get_aws_sts_hostname(region, partition)
request = AWSRequest(
method="POST",
url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15",
headers={
"Host": sts_hostname,
"X-Snowflake-Audience": SNOWFLAKE_AUDIENCE,
},
)
# TODO: Remove this environment variable check once AWS WIF outbound token is fully released
# and make it the default behavior (SNOW-2919437)
if os.environ.get("ENABLE_AWS_WIF_OUTBOUND_TOKEN", "false").lower() == "true":
sts_client = session.client("sts", region_name=region)
response = sts_client.get_web_identity_token(
Audience=[SNOWFLAKE_AUDIENCE], SigningAlgorithm="ES384"
)
jwt_token = response["WebIdentityToken"]
return WorkloadIdentityAttestation(
AttestationProvider.AWS,
jwt_token,
{"region": region, "partition": partition},
)
else:
sts_hostname = get_aws_sts_hostname(region, partition)
request = AWSRequest(
method="POST",
url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15",
headers={
"Host": sts_hostname,
"X-Snowflake-Audience": SNOWFLAKE_AUDIENCE,
},
)

SigV4Auth(aws_creds, "sts", region).add_auth(request)
SigV4Auth(aws_creds, "sts", region).add_auth(request)

assertion_dict = {
"url": request.url,
"method": request.method,
"headers": dict(request.headers.items()),
}
credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode("utf-8")
# Unlike other providers, for AWS, we only include general identifiers (region and partition)
# rather than specific user identifiers, since we don't actually execute a GetCallerIdentity call.
return WorkloadIdentityAttestation(
AttestationProvider.AWS, credential, {"region": region, "partition": partition}
)
assertion_dict = {
"url": request.url,
"method": request.method,
"headers": dict(request.headers.items()),
}
credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode(
"utf-8"
)
# Unlike other providers, for AWS, we only include general identifiers (region and partition)
# rather than specific user identifiers, since we don't actually execute a GetCallerIdentity call.
return WorkloadIdentityAttestation(
AttestationProvider.AWS,
credential,
{"region": region, "partition": partition},
)


def get_gcp_access_token(session_manager: SessionManager) -> str:
Expand Down
4 changes: 4 additions & 0 deletions test/csp_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def __init__(self):
b'{"region": "us-east-1", "instanceId": "i-1234567890abcdef0"}'
)
self.metadata_token = "test-token"
self.web_identity_token = "fake.jwt.token-for-testing-only"

def assume_role(self, **kwargs):
if (
Expand Down Expand Up @@ -423,6 +424,9 @@ def boto3_client(self, *args, **kwargs):
mock_client = mock.Mock()
mock_client.get_caller_identity.return_value = self.caller_identity
mock_client.assume_role = self.assume_role
mock_client.get_web_identity_token.return_value = {
"WebIdentityToken": self.web_identity_token
}
return mock_client

def __enter__(self):
Expand Down
32 changes: 32 additions & 0 deletions test/unit/test_auth_workload_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,38 @@ def test_aws_impersonation_calls_correct_apis_for_each_role_in_impersonation_pat
assert fake_aws_environment.assume_role_call_count == 2


@pytest.mark.parametrize(
"env_value,expected_format",
[
("true", "jwt"),
("false", "old"),
(None, "old"),
],
)
def test_aws_token_format_based_on_env_variable(
fake_aws_environment: FakeAwsEnvironment,
monkeypatch,
env_value,
expected_format,
):
"""Test that AWS uses correct token format based on ENABLE_AWS_WIF_OUTBOUND_TOKEN environment variable."""
if env_value is not None:
monkeypatch.setenv("ENABLE_AWS_WIF_OUTBOUND_TOKEN", env_value)

auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS)
auth_class.prepare(conn=None)

data = extract_api_data(auth_class)

assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY"
assert data["PROVIDER"] == "AWS"

if expected_format == "jwt":
assert data["TOKEN"] == fake_aws_environment.web_identity_token
else:
verify_aws_token(data["TOKEN"], fake_aws_environment.region)


# -- GCP Tests --


Expand Down
Loading