Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions amber/src/main/python/pytexera/storage/dataset_file_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,34 @@
import os
import requests
import urllib.parse
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

# (connect, read) timeout and retry settings for the file-service GETs below.
_CONNECT_TIMEOUT_SECONDS = 10
_READ_TIMEOUT_SECONDS = 60
_REQUEST_TIMEOUT = (_CONNECT_TIMEOUT_SECONDS, _READ_TIMEOUT_SECONDS)
_MAX_RETRIES = 3
_RETRY_BACKOFF_FACTOR = 0.5
_RETRY_STATUS_FORCELIST = (500, 502, 503, 504)


def _build_session() -> requests.Session:
    """Create a ``requests.Session`` whose GET requests are retried.

    A single ``HTTPAdapter`` carrying the module's retry policy is mounted
    for both the ``http://`` and ``https://`` prefixes. The policy allows up
    to ``_MAX_RETRIES`` attempts overall (and per connect/read error), backs
    off using ``_RETRY_BACKOFF_FACTOR``, and also retries responses whose
    status is listed in ``_RETRY_STATUS_FORCELIST``.

    :return: a new Session; the caller is responsible for closing it.
    """
    http_session = requests.Session()
    retry_policy = Retry(
        total=_MAX_RETRIES,
        connect=_MAX_RETRIES,
        read=_MAX_RETRIES,
        backoff_factor=_RETRY_BACKOFF_FACTOR,
        status_forcelist=_RETRY_STATUS_FORCELIST,
        # Restrict retries to GET so no other HTTP method is ever replayed.
        allowed_methods=frozenset({"GET"}),
        # Hand back the last response once status retries are exhausted
        # instead of raising, so callers can inspect its status code.
        raise_on_status=False,
    )
    retrying_adapter = HTTPAdapter(max_retries=retry_policy)
    # The same adapter instance serves both schemes.
    for scheme_prefix in ("http://", "https://"):
        http_session.mount(scheme_prefix, retrying_adapter)
    return http_session


class DatasetFileDocument:
Expand Down Expand Up @@ -69,7 +97,18 @@ def get_presigned_url(self) -> str:

params = {"filePath": encoded_file_path}

response = requests.get(self.presign_endpoint, headers=headers, params=params)
try:
with _build_session() as session:
response = session.get(
self.presign_endpoint,
headers=headers,
params=params,
timeout=_REQUEST_TIMEOUT,
)
except requests.exceptions.RequestException as e:
raise RuntimeError(
f"Failed to get presigned URL: request failed: {e}"
) from e

if response.status_code != 200:
raise RuntimeError(
Expand Down Expand Up @@ -100,7 +139,13 @@ def read_file(self) -> io.BytesIO:
:raises: RuntimeError if the retrieval fails.
"""
presigned_url = self.get_presigned_url()
response = requests.get(presigned_url)
try:
with _build_session() as session:
response = session.get(presigned_url, timeout=_REQUEST_TIMEOUT)
except requests.exceptions.RequestException as e:
raise RuntimeError(
f"Failed to retrieve file content: request failed: {e}"
) from e

if response.status_code != 200:
raise RuntimeError(
Expand Down
129 changes: 114 additions & 15 deletions amber/src/test/python/pytexera/storage/test_dataset_file_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,16 @@
import io

import pytest
import requests
from unittest.mock import patch, MagicMock

from pytexera.storage.dataset_file_document import DatasetFileDocument

from pytexera.storage.dataset_file_document import (
DatasetFileDocument,
_build_session,
_REQUEST_TIMEOUT,
_MAX_RETRIES,
_RETRY_STATUS_FORCELIST,
)

DEFAULT_ENDPOINT = "http://localhost:9092/api/dataset/presign-download"
CUSTOM_ENDPOINT = "https://example.test/api/presign"
Expand Down Expand Up @@ -95,15 +101,19 @@ def _make_doc(self, monkeypatch, path="/bob@x.com/ds/v1/file.csv"):

def test_returns_presigned_url_field_from_json_body(self, monkeypatch):
doc = self._make_doc(monkeypatch)
with patch("pytexera.storage.dataset_file_document.requests.get") as mock_get:
with patch(
"pytexera.storage.dataset_file_document.requests.Session.get"
) as mock_get:
mock_get.return_value = make_response(
200, body={"presignedUrl": "https://signed.test/x"}
)
assert doc.get_presigned_url() == "https://signed.test/x"

def test_sends_bearer_authorization_header_with_jwt(self, monkeypatch):
doc = self._make_doc(monkeypatch)
with patch("pytexera.storage.dataset_file_document.requests.get") as mock_get:
with patch(
"pytexera.storage.dataset_file_document.requests.Session.get"
) as mock_get:
mock_get.return_value = make_response(200, body={"presignedUrl": "u"})
doc.get_presigned_url()
_, kwargs = mock_get.call_args
Expand All @@ -113,7 +123,9 @@ def test_url_encodes_filepath_query_parameter(self, monkeypatch):
# urllib.parse.quote keeps "/" as safe by default, but encodes "@"
# and " " — pin both pieces so the contract is explicit.
doc = self._make_doc(monkeypatch, path="/bob@x.com/ds/v1/data file.csv")
with patch("pytexera.storage.dataset_file_document.requests.get") as mock_get:
with patch(
"pytexera.storage.dataset_file_document.requests.Session.get"
) as mock_get:
mock_get.return_value = make_response(200, body={"presignedUrl": "u"})
doc.get_presigned_url()
_, kwargs = mock_get.call_args
Expand All @@ -124,29 +136,37 @@ def test_url_encodes_filepath_query_parameter(self, monkeypatch):

def test_calls_configured_endpoint(self, monkeypatch):
doc = self._make_doc(monkeypatch)
with patch("pytexera.storage.dataset_file_document.requests.get") as mock_get:
with patch(
"pytexera.storage.dataset_file_document.requests.Session.get"
) as mock_get:
mock_get.return_value = make_response(200, body={"presignedUrl": "u"})
doc.get_presigned_url()
args, _ = mock_get.call_args
assert args[0] == CUSTOM_ENDPOINT

def test_raises_runtime_error_with_status_and_body_on_failure(self, monkeypatch):
doc = self._make_doc(monkeypatch)
with patch("pytexera.storage.dataset_file_document.requests.get") as mock_get:
with patch(
"pytexera.storage.dataset_file_document.requests.Session.get"
) as mock_get:
mock_get.return_value = make_response(403, body="forbidden")
with pytest.raises(RuntimeError, match=r"403.*forbidden"):
doc.get_presigned_url()

def test_raises_when_response_body_lacks_presigned_url_key(self, monkeypatch):
doc = self._make_doc(monkeypatch)
with patch("pytexera.storage.dataset_file_document.requests.get") as mock_get:
with patch(
"pytexera.storage.dataset_file_document.requests.Session.get"
) as mock_get:
mock_get.return_value = make_response(200, body={"other": "value"})
with pytest.raises(RuntimeError, match="'presignedUrl' missing"):
doc.get_presigned_url()

def test_raises_when_response_body_is_not_valid_json(self, monkeypatch):
doc = self._make_doc(monkeypatch)
with patch("pytexera.storage.dataset_file_document.requests.get") as mock_get:
with patch(
"pytexera.storage.dataset_file_document.requests.Session.get"
) as mock_get:
response = MagicMock()
response.status_code = 200
response.json.side_effect = ValueError("Expecting value")
Expand All @@ -157,14 +177,18 @@ def test_raises_when_response_body_is_not_valid_json(self, monkeypatch):

def test_raises_when_presigned_url_is_empty_string(self, monkeypatch):
doc = self._make_doc(monkeypatch)
with patch("pytexera.storage.dataset_file_document.requests.get") as mock_get:
with patch(
"pytexera.storage.dataset_file_document.requests.Session.get"
) as mock_get:
mock_get.return_value = make_response(200, body={"presignedUrl": ""})
with pytest.raises(RuntimeError, match="'presignedUrl' missing"):
doc.get_presigned_url()

def test_raises_when_presigned_url_is_not_a_string(self, monkeypatch):
doc = self._make_doc(monkeypatch)
with patch("pytexera.storage.dataset_file_document.requests.get") as mock_get:
with patch(
"pytexera.storage.dataset_file_document.requests.Session.get"
) as mock_get:
mock_get.return_value = make_response(200, body={"presignedUrl": None})
with pytest.raises(RuntimeError, match="'presignedUrl' missing"):
doc.get_presigned_url()
Expand All @@ -178,7 +202,9 @@ def _make_doc(self, monkeypatch):

def test_returns_bytesio_with_downloaded_content(self, monkeypatch):
doc = self._make_doc(monkeypatch)
with patch("pytexera.storage.dataset_file_document.requests.get") as mock_get:
with patch(
"pytexera.storage.dataset_file_document.requests.Session.get"
) as mock_get:
mock_get.side_effect = [
make_response(200, body={"presignedUrl": "https://signed.test/x"}),
make_response(200, content=b"hello-bytes"),
Expand All @@ -189,14 +215,18 @@ def test_returns_bytesio_with_downloaded_content(self, monkeypatch):

def test_propagates_presigned_url_failure(self, monkeypatch):
doc = self._make_doc(monkeypatch)
with patch("pytexera.storage.dataset_file_document.requests.get") as mock_get:
with patch(
"pytexera.storage.dataset_file_document.requests.Session.get"
) as mock_get:
mock_get.return_value = make_response(500, body="upstream down")
with pytest.raises(RuntimeError, match=r"500.*upstream down"):
doc.read_file()

def test_raises_runtime_error_when_download_fails(self, monkeypatch):
doc = self._make_doc(monkeypatch)
with patch("pytexera.storage.dataset_file_document.requests.get") as mock_get:
with patch(
"pytexera.storage.dataset_file_document.requests.Session.get"
) as mock_get:
mock_get.side_effect = [
make_response(200, body={"presignedUrl": "https://signed.test/x"}),
make_response(404, body="missing"),
Expand All @@ -206,11 +236,80 @@ def test_raises_runtime_error_when_download_fails(self, monkeypatch):

def test_downloads_from_presigned_url_returned_by_first_call(self, monkeypatch):
doc = self._make_doc(monkeypatch)
with patch("pytexera.storage.dataset_file_document.requests.get") as mock_get:
with patch(
"pytexera.storage.dataset_file_document.requests.Session.get"
) as mock_get:
mock_get.side_effect = [
make_response(200, body={"presignedUrl": "https://signed.test/x"}),
make_response(200, content=b""),
]
doc.read_file()
second_call_args, _ = mock_get.call_args_list[1]
assert second_call_args[0] == "https://signed.test/x"


class TestTimeoutsAndRetries:
    """Checks for request timeouts, the retry adapter, and error wrapping."""

    # Patch target for the GETs issued by get_presigned_url and read_file.
    _SESSION_GET = "pytexera.storage.dataset_file_document.requests.Session.get"

    def _make_doc(self, monkeypatch):
        """Return a document created with the JWT and endpoint env vars set."""
        for env_name, env_value in (
            ("USER_JWT_TOKEN", "test-jwt-token"),
            ("FILE_SERVICE_GET_PRESIGNED_URL_ENDPOINT", CUSTOM_ENDPOINT),
        ):
            monkeypatch.setenv(env_name, env_value)
        return DatasetFileDocument("/bob@x.com/ds/v1/file.csv")

    def test_presigned_url_request_passes_request_timeout(self, monkeypatch):
        document = self._make_doc(monkeypatch)
        with patch(self._SESSION_GET) as session_get:
            session_get.return_value = make_response(
                200, body={"presignedUrl": "u"}
            )
            document.get_presigned_url()
            presign_call = session_get.call_args
            assert presign_call.kwargs["timeout"] == _REQUEST_TIMEOUT

    def test_download_request_passes_request_timeout(self, monkeypatch):
        document = self._make_doc(monkeypatch)
        with patch(self._SESSION_GET) as session_get:
            # First GET fetches the presigned URL, second GET downloads.
            session_get.side_effect = [
                make_response(200, body={"presignedUrl": "https://signed.test/x"}),
                make_response(200, content=b"data"),
            ]
            document.read_file()
            download_call = session_get.call_args_list[1]
            assert download_call.kwargs["timeout"] == _REQUEST_TIMEOUT

    def test_session_mounts_retry_adapter_for_http_and_https(self):
        # The context manager closes the session even when an assert fails.
        with _build_session() as session:
            for prefix in ("http://", "https://"):
                policy = session.get_adapter(prefix).max_retries
                assert policy.total == _MAX_RETRIES
                assert policy.connect == _MAX_RETRIES
                assert policy.read == _MAX_RETRIES
                assert set(policy.status_forcelist) == set(_RETRY_STATUS_FORCELIST)
                # Only idempotent GETs should be retried.
                assert policy.allowed_methods == frozenset({"GET"})

    def test_presigned_url_request_timeout_is_wrapped_in_runtime_error(
        self, monkeypatch
    ):
        document = self._make_doc(monkeypatch)
        with patch(self._SESSION_GET) as session_get:
            session_get.side_effect = requests.exceptions.ReadTimeout("timed out")
            with pytest.raises(RuntimeError, match="request failed"):
                document.get_presigned_url()

    def test_download_request_timeout_is_wrapped_in_runtime_error(self, monkeypatch):
        document = self._make_doc(monkeypatch)
        with patch(self._SESSION_GET) as session_get:
            # Presign succeeds; the download GET then fails at the transport.
            session_get.side_effect = [
                make_response(200, body={"presignedUrl": "https://signed.test/x"}),
                requests.exceptions.ConnectionError("connection reset"),
            ]
            with pytest.raises(RuntimeError, match="Failed to retrieve file content"):
                document.read_file()
Loading