Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions core/authenticate/authenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ const (
// AccessTokenClientAssertion is used to authenticate using access token generated
// by the system for the user
AccessTokenClientAssertion ClientAssertion = "access_token"
// OpaqueTokenClientAssertion is used to authenticate using opaque token generated
// for API clients
OpaqueTokenClientAssertion ClientAssertion = "opaque"
// JWTGrantClientAssertion is used to authenticate using JWT token generated
// using public/private key pair that provides access token for the client
JWTGrantClientAssertion ClientAssertion = "jwt_grant"
Expand All @@ -60,7 +57,6 @@ var APIAssertions = []ClientAssertion{
PATClientAssertion,
AccessTokenClientAssertion,
JWTGrantClientAssertion,
OpaqueTokenClientAssertion,
// ClientCredentialsClientAssertion should be removed in future to avoid DDOS attacks on CPU
// and should only be allowed to be used get access token for the client
ClientCredentialsClientAssertion,
Expand Down Expand Up @@ -138,6 +134,8 @@ type Principal struct {
// Type is the namespace of principal
// E.g. app/user, app/serviceuser, app/pat
Type string
// AuthVia is the credential type that authenticated this principal
AuthVia ClientAssertion

User *user.User
ServiceUser *serviceuser.ServiceUser
Expand Down
15 changes: 12 additions & 3 deletions core/authenticate/authenticators.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/lestrrat-go/jwx/v2/jwt"
frontiersession "github.com/raystack/frontier/core/authenticate/session"
"github.com/raystack/frontier/core/authenticate/token"
"github.com/raystack/frontier/core/serviceuser"
patErrors "github.com/raystack/frontier/core/userpat/errors"
"github.com/raystack/frontier/internal/bootstrap/schema"
"github.com/raystack/frontier/pkg/errors"
Expand All @@ -26,7 +27,6 @@ var authenticators = map[ClientAssertion]AuthenticatorFunc{
AccessTokenClientAssertion: authenticateWithAccessToken,
JWTGrantClientAssertion: authenticateWithJWTGrant,
ClientCredentialsClientAssertion: authenticateWithClientCredentials,
OpaqueTokenClientAssertion: authenticateWithClientCredentials,
PassthroughHeaderClientAssertion: authenticateWithPassthroughHeader,
}

Expand Down Expand Up @@ -189,8 +189,17 @@ func authenticateWithJWTGrant(ctx context.Context, s *Service) (Principal, error
ServiceUser: &serviceUser,
}, nil
}
s.log.DebugContext(ctx, "failed to parse as user token ", "err", err)
return Principal{}, errors.ErrUnauthenticated
switch {
case errors.Is(err, serviceuser.ErrTokenNotJWT),
errors.Is(err, serviceuser.ErrInvalidKeyID),
errors.Is(err, serviceuser.ErrCredNotExist):
return Principal{}, errSkip
case errors.Is(err, serviceuser.ErrInvalidCred):
s.log.DebugContext(ctx, "service user grant failed verification", "err", err)
return Principal{}, errors.ErrUnauthenticated
default:
return Principal{}, err
}
}

// authenticateWithClientCredentials validates client_id:client_secret credentials.
Expand Down
4 changes: 4 additions & 0 deletions core/authenticate/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,9 @@ func (s Service) GetPrincipal(ctx context.Context, assertions ...ClientAssertion
}

if val, ok := GetPrincipalFromContext(ctx); ok {
if len(assertions) > 0 && !slices.Contains(assertions, val.AuthVia) {
return Principal{}, errors.ErrUnauthenticated
}
return *val, nil
}

Expand All @@ -772,6 +775,7 @@ func (s Service) GetPrincipal(ctx context.Context, assertions ...ClientAssertion
}
principal, err := authenticator(ctx, &s)
if err == nil {
principal.AuthVia = assertion
return principal, nil
}
if !errors.Is(err, errSkip) {
Expand Down
123 changes: 88 additions & 35 deletions core/authenticate/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/raystack/frontier/core/authenticate/token"
"github.com/raystack/frontier/core/serviceuser"
"github.com/raystack/frontier/core/user"
patModels "github.com/raystack/frontier/core/userpat/models"
"github.com/raystack/frontier/internal/bootstrap/schema"
mailerMock "github.com/raystack/frontier/pkg/mailer/mocks"
pkgMetadata "github.com/raystack/frontier/pkg/metadata"
Expand Down Expand Up @@ -88,8 +89,9 @@ func TestService_GetPrincipal(t *testing.T) {
assertions: []authenticate.ClientAssertion{authenticate.SessionClientAssertion},
},
want: authenticate.Principal{
ID: userID.String(),
Type: schema.UserPrincipal,
ID: userID.String(),
Type: schema.UserPrincipal,
AuthVia: authenticate.SessionClientAssertion,
User: &user.User{
ID: userID.String(),
},
Expand Down Expand Up @@ -149,8 +151,9 @@ func TestService_GetPrincipal(t *testing.T) {
assertions: []authenticate.ClientAssertion{authenticate.AccessTokenClientAssertion},
},
want: authenticate.Principal{
ID: userID.String(),
Type: schema.UserPrincipal,
ID: userID.String(),
Type: schema.UserPrincipal,
AuthVia: authenticate.AccessTokenClientAssertion,
User: &user.User{
ID: userID.String(),
},
Expand Down Expand Up @@ -195,8 +198,9 @@ func TestService_GetPrincipal(t *testing.T) {
assertions: []authenticate.ClientAssertion{authenticate.JWTGrantClientAssertion},
},
want: authenticate.Principal{
ID: userID.String(),
Type: schema.ServiceUserPrincipal,
ID: userID.String(),
Type: schema.ServiceUserPrincipal,
AuthVia: authenticate.JWTGrantClientAssertion,
ServiceUser: &serviceuser.ServiceUser{
ID: userID.String(),
},
Expand Down Expand Up @@ -240,35 +244,9 @@ func TestService_GetPrincipal(t *testing.T) {
assertions: []authenticate.ClientAssertion{authenticate.ClientCredentialsClientAssertion},
},
want: authenticate.Principal{
ID: userID.String(),
Type: schema.ServiceUserPrincipal,
ServiceUser: &serviceuser.ServiceUser{
ID: userID.String(),
},
},
wantErr: false,
setup: func() *authenticate.Service {
mockFlow, mockUserService, mockTokenService, mockSessionService, mockServiceUserService := createMocks(t)

mockServiceUserService.EXPECT().GetBySecret(mock.Anything, "user", "password").Return(serviceuser.ServiceUser{
ID: userID.String(),
}, nil)

return authenticate.NewService(nil, authenticate.Config{},
mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil)
},
},
{
name: "fetch principal from opaque token",
args: args{
ctx: metadata.NewIncomingContext(context.Background(), map[string][]string{
consts.UserSecretGatewayKey: {userToken},
}),
assertions: []authenticate.ClientAssertion{authenticate.OpaqueTokenClientAssertion},
},
want: authenticate.Principal{
ID: userID.String(),
Type: schema.ServiceUserPrincipal,
ID: userID.String(),
Type: schema.ServiceUserPrincipal,
AuthVia: authenticate.ClientCredentialsClientAssertion,
ServiceUser: &serviceuser.ServiceUser{
ID: userID.String(),
},
Expand Down Expand Up @@ -457,3 +435,78 @@ func TestService_StartFlow(t *testing.T) {
})
}
}

func TestService_GetPrincipal_JWTGrantSkipsNonGrantToken(t *testing.T) {
userID := uuid.New()
patValue := "fpt_opaque-not-a-jwt"

mockFlow, mockUserService, mockTokenService, mockSessionService, mockServiceUserService := createMocks(t)
mockPATService := mocks.NewUserPATService(t)

mockServiceUserService.EXPECT().GetByJWT(mock.Anything, patValue).
Return(serviceuser.ServiceUser{}, serviceuser.ErrTokenNotJWT)
pat := patModels.PAT{ID: "pat-1", UserID: userID.String(), ExpiresAt: time.Now().Add(time.Hour)}
mockPATService.EXPECT().Validate(mock.Anything, patValue).Return(pat, nil)
mockUserService.EXPECT().GetByID(mock.Anything, userID.String()).
Return(user.User{ID: userID.String()}, nil)

svc := authenticate.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), authenticate.Config{},
mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, mockPATService)

ctx := metadata.NewIncomingContext(context.Background(), map[string][]string{
consts.UserTokenGatewayKey: {patValue},
})

got, err := svc.GetPrincipal(ctx,
authenticate.JWTGrantClientAssertion, authenticate.PATClientAssertion)
require.NoError(t, err)
assert.Equal(t, schema.PATPrincipal, got.Type)
require.NotNil(t, got.PAT)
assert.Equal(t, "pat-1", got.ID)
}

func TestService_GetPrincipal_RestrictsByAuthVia(t *testing.T) {
// lists mirror what the handlers pass: session.go uses {Session}; AuthToken uses the token-exchange set.
sessionOnly := []authenticate.ClientAssertion{authenticate.SessionClientAssertion}
authTokenSet := []authenticate.ClientAssertion{
authenticate.SessionClientAssertion,
authenticate.ClientCredentialsClientAssertion,
authenticate.JWTGrantClientAssertion,
authenticate.PATClientAssertion,
}

tests := []struct {
name string
authVia authenticate.ClientAssertion
allowed []authenticate.ClientAssertion
wantErr bool
}{
{"session endpoints accept a session", authenticate.SessionClientAssertion, sessionOnly, false},
{"session endpoints reject a PAT", authenticate.PATClientAssertion, sessionOnly, true},
{"session endpoints reject an access token", authenticate.AccessTokenClientAssertion, sessionOnly, true},
{"session endpoints reject client credentials", authenticate.ClientCredentialsClientAssertion, sessionOnly, true},
{"session endpoints reject a jwt grant", authenticate.JWTGrantClientAssertion, sessionOnly, true},

{"authtoken accepts a session", authenticate.SessionClientAssertion, authTokenSet, false},
{"authtoken accepts client credentials", authenticate.ClientCredentialsClientAssertion, authTokenSet, false},
{"authtoken accepts a jwt grant", authenticate.JWTGrantClientAssertion, authTokenSet, false},
{"authtoken accepts a PAT", authenticate.PATClientAssertion, authTokenSet, false},
{"authtoken rejects an access token", authenticate.AccessTokenClientAssertion, authTokenSet, true},
{"authtoken rejects passthrough", authenticate.PassthroughHeaderClientAssertion, authTokenSet, true},
}

svc := authenticate.NewService(nil, authenticate.Config{}, nil, nil, nil, nil, nil, nil, nil, nil)

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := authenticate.SetContextWithPrincipal(context.Background(), &authenticate.Principal{
ID: "principal-1",
Type: schema.UserPrincipal,
AuthVia: tt.authVia,
})
if _, err := svc.GetPrincipal(ctx, tt.allowed...); (err != nil) != tt.wantErr {
t.Errorf("GetPrincipal() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
1 change: 1 addition & 0 deletions core/serviceuser/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ var (
ErrInvalidCred = errors.New("service user credential is invalid")
ErrInvalidID = errors.New("service user id is invalid")
ErrInvalidKeyID = errors.New("service user key is invalid")
ErrTokenNotJWT = errors.New("token is not a jwt")
ErrConflict = errors.New("service user already exist")
ErrEmptyKey = errors.New("empty key")
ErrDisabled = errors.New("service user is disabled")
Expand Down
17 changes: 12 additions & 5 deletions core/serviceuser/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,21 +363,28 @@ func (s Service) GetBySecret(ctx context.Context, credID string, reqSecret strin
func (s Service) GetByJWT(ctx context.Context, token string) (ServiceUser, error) {
insecureToken, err := jwt.ParseInsecure([]byte(token))
if err != nil {
return ServiceUser{}, fmt.Errorf("invalid serviceuser token: %w", err)
return ServiceUser{}, fmt.Errorf("%w: %v", ErrTokenNotJWT, err)
}
tokenKID, ok := insecureToken.Get(jwk.KeyIDKey)
if !ok {
return ServiceUser{}, fmt.Errorf("invalid key id from token")
return ServiceUser{}, fmt.Errorf("missing key id in token: %w", ErrInvalidKeyID)
}
cred, err := s.credRepo.Get(ctx, tokenKID.(string))
kid, ok := tokenKID.(string)
if !ok {
return ServiceUser{}, fmt.Errorf("key id is not a string: %w", ErrInvalidKeyID)
}
if _, err := uuid.Parse(kid); err != nil {
return ServiceUser{}, fmt.Errorf("key id is not a valid uuid: %w", ErrInvalidKeyID)
}
cred, err := s.credRepo.Get(ctx, kid)
if err != nil {
return ServiceUser{}, fmt.Errorf("credential invalid of kid %s: %w", tokenKID.(string), err)
return ServiceUser{}, fmt.Errorf("credential invalid of kid %s: %w", kid, err)
}

// verify token
_, err = jwt.Parse([]byte(token), jwt.WithKeySet(cred.PublicKey))
if err != nil {
return ServiceUser{}, fmt.Errorf("invalid serviceuser token: %w", err)
return ServiceUser{}, fmt.Errorf("%w: %v", ErrInvalidCred, err)
}
return s.repo.GetByID(ctx, cred.ServiceUserID)
}
Expand Down
88 changes: 88 additions & 0 deletions core/serviceuser/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@ import (
"io"
"log/slog"
"testing"
"time"

"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/raystack/frontier/core/relation"
"github.com/raystack/frontier/core/serviceuser"
"github.com/raystack/frontier/core/serviceuser/mocks"
"github.com/raystack/frontier/internal/bootstrap/schema"
"github.com/raystack/frontier/pkg/utils"
)

func newTestService(t *testing.T) (*serviceuser.Service, *mocks.Repository, *mocks.CredentialRepository, *mocks.RelationService, *mocks.MembershipService) {
Expand Down Expand Up @@ -227,3 +232,86 @@ func TestService_ListByOrg(t *testing.T) {
})
}
}

func TestService_GetByJWT_Classification(t *testing.T) {
ctx := context.Background()

// buildToken signs a jwt carrying kid in its claims and returns the matching public set.
buildToken := func(t *testing.T, kid string) ([]byte, jwk.Set) {
t.Helper()
key, err := utils.CreateJWKWithKID(kid)
if err != nil {
t.Fatalf("CreateJWKWithKID: %v", err)
}
tok, err := utils.BuildToken(key, "issuer", "subject", time.Hour, nil)
if err != nil {
t.Fatalf("BuildToken: %v", err)
}
set := jwk.NewSet()
if err := set.AddKey(key); err != nil {
t.Fatalf("AddKey: %v", err)
}
pub, err := utils.GetPublicKeySet(ctx, set)
if err != nil {
t.Fatalf("GetPublicKeySet: %v", err)
}
return tok, pub
}

t.Run("not a jwt skips with ErrTokenNotJWT", func(t *testing.T) {
svc, _, _, _, _ := newTestService(t)
if _, err := svc.GetByJWT(ctx, "fpt_not-a-jwt"); !errors.Is(err, serviceuser.ErrTokenNotJWT) {
t.Errorf("GetByJWT() error = %v, want errors.Is(ErrTokenNotJWT)", err)
}
})

t.Run("malformed (non-uuid) kid skips with ErrInvalidKeyID", func(t *testing.T) {
svc, _, _, _, _ := newTestService(t)
tok, _ := buildToken(t, "not-a-uuid")
// credRepo.Get must not be called for a malformed kid
if _, err := svc.GetByJWT(ctx, string(tok)); !errors.Is(err, serviceuser.ErrInvalidKeyID) {
t.Errorf("GetByJWT() error = %v, want errors.Is(ErrInvalidKeyID)", err)
}
})

t.Run("non-string kid skips with ErrInvalidKeyID", func(t *testing.T) {
svc, _, _, _, _ := newTestService(t)
key, err := utils.CreateJWKWithKID("33333333-3333-3333-3333-333333333333")
if err != nil {
t.Fatalf("CreateJWKWithKID: %v", err)
}
tok, err := jwt.NewBuilder().Claim(jwk.KeyIDKey, 12345).Build()
if err != nil {
t.Fatalf("Build: %v", err)
}
signed, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, key))
if err != nil {
t.Fatalf("Sign: %v", err)
}
if _, err := svc.GetByJWT(ctx, string(signed)); !errors.Is(err, serviceuser.ErrInvalidKeyID) {
t.Errorf("GetByJWT() error = %v, want errors.Is(ErrInvalidKeyID)", err)
}
})

t.Run("unknown kid skips with ErrCredNotExist", func(t *testing.T) {
svc, _, credRepo, _, _ := newTestService(t)
kid := "11111111-1111-1111-1111-111111111111"
tok, _ := buildToken(t, kid)
credRepo.On("Get", ctx, kid).Return(serviceuser.Credential{}, serviceuser.ErrCredNotExist)
if _, err := svc.GetByJWT(ctx, string(tok)); !errors.Is(err, serviceuser.ErrCredNotExist) {
t.Errorf("GetByJWT() error = %v, want errors.Is(ErrCredNotExist)", err)
}
})

t.Run("bad signature stops with ErrInvalidCred", func(t *testing.T) {
svc, _, credRepo, _, _ := newTestService(t)
kid := "22222222-2222-2222-2222-222222222222"
tok, _ := buildToken(t, kid) // signed by one key
_, otherPub := buildToken(t, kid) // verified against a different key with the same kid
credRepo.On("Get", ctx, kid).Return(
serviceuser.Credential{ID: kid, ServiceUserID: "su-1", PublicKey: otherPub}, nil)
if _, err := svc.GetByJWT(ctx, string(tok)); !errors.Is(err, serviceuser.ErrInvalidCred) {
t.Errorf("GetByJWT() error = %v, want errors.Is(ErrInvalidCred)", err)
}
})
}
Loading
Loading