Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .schemas/authenticators.cookie_session.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,33 @@
"title": "Preserve Path",
"type": "boolean",
"description": "When set to true, any path specified in `check_session_url` will be preserved instead of overwriting the path with the path from the original request"
},
"cache": {
"type": "object",
"title": "Session Cache Configuration",
"description": "Configures the session cache which reduces the number of requests to the session store.",
"properties": {
"enabled": {
"type": "boolean",
"title": "Enable Cache",
"description": "Enable the session cache. Defaults to false.",
"default": false
},
"ttl": {
"type": "string",
"title": "Cache TTL",
"description": "The time to live for cached sessions. Defaults to 1s.",
"default": "1s",
"examples": ["1s", "60s", "5m", "1h"]
},
"max_cost": {
"type": "integer",
"title": "Maximum Cache Cost",
"description": "The maximum cost of the cache in bytes. Defaults to 100000000 (100MB).",
"default": 100000000
}
},
"additionalProperties": false
}
},
"required": ["check_session_url"],
Expand Down
8 changes: 4 additions & 4 deletions driver/configuration/provider_koanf_public_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/rs/cors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/otel/trace/noop"

"github.com/ory/x/configx"
"github.com/ory/x/logrusx"
Expand Down Expand Up @@ -248,7 +248,7 @@ func TestKoanfProvider(t *testing.T) {
})

t.Run("authenticator=cookie_session", func(t *testing.T) {
a := authn.NewAuthenticatorCookieSession(p, trace.NewNoopTracerProvider())
a := authn.NewAuthenticatorCookieSession(p, logger, noop.NewTracerProvider())
assert.True(t, p.AuthenticatorIsEnabled(a.GetID()))
require.NoError(t, a.Validate(nil))

Expand Down Expand Up @@ -286,7 +286,7 @@ func TestKoanfProvider(t *testing.T) {
})

t.Run("authenticator=oauth2_introspection", func(t *testing.T) {
a := authn.NewAuthenticatorOAuth2Introspection(p, logger, trace.NewNoopTracerProvider())
a := authn.NewAuthenticatorOAuth2Introspection(p, logger, noop.NewTracerProvider())
assert.True(t, p.AuthenticatorIsEnabled(a.GetID()))
require.NoError(t, a.Validate(nil))

Expand Down Expand Up @@ -434,7 +434,7 @@ func TestAuthenticatorOAuth2TokenIntrospectionPreAuthorization(t *testing.T) {
{enabled: true, id: "a", secret: "b", turl: "https://some-url", err: false},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
a := authn.NewAuthenticatorOAuth2Introspection(p, logrusx.New("", ""), trace.NewNoopTracerProvider())
a := authn.NewAuthenticatorOAuth2Introspection(p, logrusx.New("", ""), noop.NewTracerProvider())

config, _, err := a.Config(json.RawMessage(fmt.Sprintf(`{
"pre_authorization": {
Expand Down
2 changes: 1 addition & 1 deletion driver/registry_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ func (r *RegistryMemory) prepareAuthn() {
if r.authenticators == nil {
interim := []authn.Authenticator{
authn.NewAuthenticatorAnonymous(r.c),
authn.NewAuthenticatorCookieSession(r.c, r.trc.Provider()),
authn.NewAuthenticatorCookieSession(r.c, r.Logger(), r.trc.Provider()),
authn.NewAuthenticatorBearerToken(r.c, r.trc.Provider()),
authn.NewAuthenticatorJWT(r.c, r),
authn.NewAuthenticatorNoOp(r.c),
Expand Down
5 changes: 5 additions & 0 deletions helper/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,9 @@ var (
CodeField: http.StatusNotFound,
StatusField: http.StatusText(http.StatusNotFound),
}
ErrTooManyRequests = &herodot.DefaultError{
ErrorField: "The upstream service rate limit was exceeded",
CodeField: http.StatusTooManyRequests,
StatusField: http.StatusText(http.StatusTooManyRequests),
}
)
160 changes: 144 additions & 16 deletions pipeline/authn/authenticator_cookie_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,24 @@
package authn

import (
"crypto/md5" //#nosec G501 -- MD5 is used for cache key generation, not cryptography
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"

"github.com/dgraph-io/ristretto"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel/trace"

"github.com/pkg/errors"
"github.com/tidwall/gjson"

"github.com/ory/oathkeeper/x/header"
"github.com/ory/x/logrusx"
"github.com/ory/x/otelx"
"github.com/ory/x/stringsx"

Expand All @@ -36,16 +42,23 @@ type AuthenticatorCookieSessionFilter struct {
}

type AuthenticatorCookieSessionConfiguration struct {
Only []string `json:"only"`
CheckSessionURL string `json:"check_session_url"`
PreserveQuery bool `json:"preserve_query"`
PreservePath bool `json:"preserve_path"`
ExtraFrom string `json:"extra_from"`
SubjectFrom string `json:"subject_from"`
PreserveHost bool `json:"preserve_host"`
ForwardHTTPHeaders []string `json:"forward_http_headers"`
SetHeaders map[string]string `json:"additional_headers"`
ForceMethod string `json:"force_method"`
Only []string `json:"only"`
CheckSessionURL string `json:"check_session_url"`
PreserveQuery bool `json:"preserve_query"`
PreservePath bool `json:"preserve_path"`
ExtraFrom string `json:"extra_from"`
SubjectFrom string `json:"subject_from"`
PreserveHost bool `json:"preserve_host"`
ForwardHTTPHeaders []string `json:"forward_http_headers"`
SetHeaders map[string]string `json:"additional_headers"`
ForceMethod string `json:"force_method"`
Cache cookieSessionCacheConfig `json:"cache"`
}

type cookieSessionCacheConfig struct {
Enabled bool `json:"enabled"`
TTL string `json:"ttl"`
MaxCost int64 `json:"max_cost"`
}

func (a *AuthenticatorCookieSessionConfiguration) GetCheckSessionURL() string {
Expand Down Expand Up @@ -77,16 +90,20 @@ func (a *AuthenticatorCookieSessionConfiguration) GetForceMethod() string {
}

type AuthenticatorCookieSession struct {
c configuration.Provider
client *http.Client
tracer trace.Tracer
c configuration.Provider
client *http.Client
tracer trace.Tracer
sessionCache *ristretto.Cache[string, []byte]
cacheTTL *time.Duration
logger *logrusx.Logger
}

var _ AuthenticatorForwardConfig = new(AuthenticatorCookieSessionConfiguration)

func NewAuthenticatorCookieSession(c configuration.Provider, provider trace.TracerProvider) *AuthenticatorCookieSession {
func NewAuthenticatorCookieSession(c configuration.Provider, logger *logrusx.Logger, provider trace.TracerProvider) *AuthenticatorCookieSession {
return &AuthenticatorCookieSession{
c: c,
c: c,
logger: logger,
client: &http.Client{
Transport: otelhttp.NewTransport(
http.DefaultTransport,
Expand Down Expand Up @@ -127,9 +144,97 @@ func (a *AuthenticatorCookieSession) Config(config json.RawMessage) (*Authentica
// Add Authorization and Cookie headers for backward compatibility
c.ForwardHTTPHeaders = append(c.ForwardHTTPHeaders, []string{header.Cookie}...)

if c.Cache.TTL != "" {
cacheTTL, err := time.ParseDuration(c.Cache.TTL)
if err != nil {
return nil, err
}

if a.sessionCache != nil {
if a.cacheTTL == nil || (a.cacheTTL != nil && a.cacheTTL.Seconds() > cacheTTL.Seconds()) {
a.sessionCache.Clear()
}
}

a.cacheTTL = &cacheTTL
}

if a.sessionCache == nil {
cost := c.Cache.MaxCost
if cost == 0 {
cost = 10000000
}
a.logger.Debugf("Creating session cache with max cost: %d", cost)
cache, err := ristretto.NewCache(&ristretto.Config[string, []byte]{
NumCounters: cost * 10,
MaxCost: cost,
BufferItems: 64,
Cost: func(value []byte) int64 {
return 1
},
IgnoreInternalCost: true,
})
if err != nil {
return nil, err
}

a.sessionCache = cache
}

return &c, nil
}

func cookiesToCacheKey(cookies []*http.Cookie) string {
var parts []string
for _, cookie := range cookies {
parts = append(parts, fmt.Sprintf("%s=%s", cookie.Name, cookie.Value))
}
return fmt.Sprintf("%x", md5.Sum([]byte(strings.Join(parts, "|")))) //#nosec G401 -- MD5 is used for cache key generation, not cryptography
}

type cachedSessionData struct {
Subject string `json:"subject"`
Extra map[string]interface{} `json:"extra"`
}

func (a *AuthenticatorCookieSession) sessionFromCache(config *AuthenticatorCookieSessionConfiguration, r *http.Request) *cachedSessionData {
if !config.Cache.Enabled {
return nil
}

key := cookiesToCacheKey(r.Cookies())
i, found := a.sessionCache.Get(key)
if !found {
return nil
}

var v cachedSessionData
if err := json.Unmarshal(i, &v); err != nil {
return nil
}
return &v
}

func (a *AuthenticatorCookieSession) sessionToCache(config *AuthenticatorCookieSessionConfiguration, r *http.Request, subject string, extra map[string]interface{}) {
if !config.Cache.Enabled {
return
}

key := cookiesToCacheKey(r.Cookies())
data := cachedSessionData{
Subject: subject,
Extra: extra,
}

if v, err := json.Marshal(data); err != nil {
return
} else if a.cacheTTL != nil {
a.sessionCache.SetWithTTL(key, v, 1, *a.cacheTTL)
} else {
a.sessionCache.Set(key, v, 1)
}
}

func (a *AuthenticatorCookieSession) Authenticate(r *http.Request, session *AuthenticationSession, config json.RawMessage, _ pipeline.Rule) (err error) {
ctx, span := a.tracer.Start(r.Context(), "pipeline.authn.AuthenticatorCookieSession.Authenticate")
defer otelx.End(span, &err)
Expand All @@ -144,6 +249,13 @@ func (a *AuthenticatorCookieSession) Authenticate(r *http.Request, session *Auth
return errors.WithStack(ErrAuthenticatorNotResponsible)
}

cachedSession := a.sessionFromCache(cf, r)
if cachedSession != nil {
session.Subject = cachedSession.Subject
session.Extra = cachedSession.Extra
return nil
}

body, err := forwardRequestToSessionStore(a.client, r, cf)
if err != nil {
return err
Expand All @@ -167,6 +279,9 @@ func (a *AuthenticatorCookieSession) Authenticate(r *http.Request, session *Auth

session.Subject = subject
session.Extra = extra

a.sessionToCache(cf, r, subject, extra)

return nil
}

Expand Down Expand Up @@ -204,7 +319,20 @@ func forwardRequestToSessionStore(client *http.Client, r *http.Request, cf Authe
return json.RawMessage{}, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to fetch cookie session context from remote: %+v", err))
}
return body, nil
} else {
}

switch res.StatusCode {
case http.StatusTooManyRequests:
return json.RawMessage{}, errors.WithStack(helper.ErrTooManyRequests.WithReason("Session store rate limit exceeded"))
case http.StatusServiceUnavailable:
return json.RawMessage{}, errors.WithStack(helper.ErrUpstreamServiceNotAvailable.WithReason("Session store is unavailable"))
case http.StatusInternalServerError:
return json.RawMessage{}, errors.WithStack(helper.ErrUpstreamServiceInternalServerError.WithReason("Session store returned internal server error"))
case http.StatusGatewayTimeout:
return json.RawMessage{}, errors.WithStack(helper.ErrUpstreamServiceTimeout.WithReason("Session store request timed out"))
case http.StatusNotFound:
return json.RawMessage{}, errors.WithStack(helper.ErrUpstreamServiceNotFound.WithReason("Session store endpoint not found"))
default:
return json.RawMessage{}, errors.WithStack(helper.ErrUnauthorized)
}
}
Expand Down
Loading
Loading