Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions graphistry/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def __init__(self) -> None:

self.idp_name: Optional[str] = None
self.sso_state: Optional[str] = None
self.sso_state_created_at: Optional[float] = None
self.sso_state_ttl_s: int = get_from_env("GRAPHISTRY_SSO_STATE_TTL_S", int, 300)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


self.personal_key: Optional[str] = None
self.personal_key_id: Optional[str] = None
Expand Down
7 changes: 7 additions & 0 deletions graphistry/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ class SsoStateInvalidException(SsoException):
"""
pass

class SsoStateExpiredException(SsoException):
"""
Raised when the SSO state has exceeded the client-side TTL,
meaning the server's PKCE verifier has likely expired.
"""
pass



class TokenExpireException(Exception):
Expand Down
16 changes: 15 additions & 1 deletion graphistry/pygraphistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from . import bolt_util
from .plotter import Plotter
from .util import in_databricks, setup_logger, in_ipython
from .exceptions import SsoRetrieveTokenTimeoutException, TokenExpireException, SsoStateInvalidException
from .exceptions import SsoRetrieveTokenTimeoutException, TokenExpireException, SsoStateInvalidException, SsoStateExpiredException

from .messages import (
MSG_REGISTER_MISSING_PASSWORD,
Expand Down Expand Up @@ -315,6 +315,7 @@ def _handle_auth_url(self, auth_url: str, sso_timeout: Optional[int], sso_opt_in
self.session.org_name = org_name
# finish, set back to None
self.session.sso_state = None
self.session.sso_state_created_at = None
print("Successfully logged in")
self._maybe_switch_org(org_name)
return self.api_token()
Expand All @@ -334,6 +335,8 @@ def _handle_auth_url(self, auth_url: str, sso_timeout: Optional[int], sso_opt_in
# print("Keep trying to get token ...")
# time.sleep(5)

ttl_s = self.session.sso_state_ttl_s
print(f"SSO link expires in {ttl_s // 60} minutes. If you wait longer, re-run graphistry.register() for a fresh link.")
print("Please run graphistry.sso_get_token() to complete the authentication")
return None

Expand All @@ -352,6 +355,16 @@ def _sso_get_token(self) -> Tuple[Optional[str], Optional[str]]:
if state is None:
raise SsoStateInvalidException("[SSO] Invalid SSO state: NoneType encountered")

created_at = self.session.sso_state_created_at
ttl = self.session.sso_state_ttl_s
if created_at is not None and (time.time() - created_at) > ttl:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. especially as we're not confident of SSO expiry, instead of being preemptive and risk being wrong, probably better to instead limit this to a nicer error message in the case of a failure

Ex:

  • try to do the action against the server
  • server raises an exn
  • we catch the exn, and if flagged as a potential sso expiry, stack the exns so they get both messages (important to get both, not just potentially incorrectly cloud with this)
  1. i'm unsure if this is the right place to detect such an expiry; it'd be good to test that the correct library point is being tested (live test, vs synthetic)

self.session.sso_state = None
self.session.sso_state_created_at = None
raise SsoStateExpiredException(
f"[SSO] SSO link expired (older than {ttl}s). "
f"Run graphistry.register(..., is_sso_login=True) to get a new link."
)

# print("_sso_get_token : {}".format(state))
arrow_uploader = ArrowUploader(
client_session=self.session,
Expand Down Expand Up @@ -2494,6 +2507,7 @@ def sso_state(self, value: Optional[str] = None):

# setter
self.session.sso_state = value.strip()
self.session.sso_state_created_at = time.time()

def scene_settings(self,
menu: Optional[bool] = None,
Expand Down
61 changes: 61 additions & 0 deletions graphistry/tests/test_arrow_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,3 +489,64 @@ def test_sso_get_token_missing_org_raises(self, mock_get):

with pytest.raises(Exception):
au.sso_get_token(state='ignored-valid')


class TestSsoStateTtl(unittest.TestCase):

def setUp(self):
self.client = graphistry.PyGraphistry
self.client.session.sso_state = None
self.client.session.sso_state_created_at = None

def tearDown(self):
self.client.session.sso_state = None
self.client.session.sso_state_created_at = None

def test_sso_state_created_at_set_on_state_assignment(self):
import time
before = time.time()
self.client.sso_state('test-state-123')
after = time.time()
assert self.client.session.sso_state == 'test-state-123'
assert self.client.session.sso_state_created_at is not None
assert before <= self.client.session.sso_state_created_at <= after

def test_sso_state_ttl_default(self):
assert self.client.session.sso_state_ttl_s == 300

def test_expired_state_raises(self):
import time
from graphistry.exceptions import SsoStateExpiredException
self.client.sso_state('expired-state')
# Backdate created_at so it appears expired
self.client.session.sso_state_created_at = time.time() - 400
with pytest.raises(SsoStateExpiredException, match="expired"):
self.client._sso_get_token()
# State should be cleared
assert self.client.session.sso_state is None
assert self.client.session.sso_state_created_at is None

@mock.patch('requests.get')
def test_fresh_state_not_expired(self, mock_get):
import time
mock_resp = mock.Mock()
mock_resp.status_code = 200
mock_resp.raise_for_status = mock.Mock()
mock_resp.json.return_value = {
'status': 'OK',
'data': {
'token': 'tok123',
'active_organization': {
'slug': 'test-org',
'is_found': True,
'is_member': True,
}
}
}
mock_get.return_value = mock_resp

self.client.sso_state('fresh-state')
# Just created, should not be expired
with mock.patch.object(type(self.client), '_maybe_switch_org'):
token, org_name = self.client._sso_get_token()
assert token is not None
Loading