| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453 |
- package sts
- import (
- "context"
- "fmt"
- "strings"
- "testing"
- "time"
- "github.com/golang-jwt/jwt/v5"
- "github.com/seaweedfs/seaweedfs/weed/iam/providers"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- )
- // createSTSTestJWT creates a test JWT token for STS service tests
- func createSTSTestJWT(t *testing.T, issuer, subject string) string {
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
- "iss": issuer,
- "sub": subject,
- "aud": "test-client",
- "exp": time.Now().Add(time.Hour).Unix(),
- "iat": time.Now().Unix(),
- })
- tokenString, err := token.SignedString([]byte("test-signing-key"))
- require.NoError(t, err)
- return tokenString
- }
- // TestSTSServiceInitialization tests STS service initialization
- func TestSTSServiceInitialization(t *testing.T) {
- tests := []struct {
- name string
- config *STSConfig
- wantErr bool
- }{
- {
- name: "valid config",
- config: &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{time.Hour * 12},
- Issuer: "seaweedfs-sts",
- SigningKey: []byte("test-signing-key"),
- },
- wantErr: false,
- },
- {
- name: "missing signing key",
- config: &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- Issuer: "seaweedfs-sts",
- },
- wantErr: true,
- },
- {
- name: "invalid token duration",
- config: &STSConfig{
- TokenDuration: FlexibleDuration{-time.Hour},
- Issuer: "seaweedfs-sts",
- SigningKey: []byte("test-key"),
- },
- wantErr: true,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- service := NewSTSService()
- err := service.Initialize(tt.config)
- if tt.wantErr {
- assert.Error(t, err)
- } else {
- assert.NoError(t, err)
- assert.True(t, service.IsInitialized())
- }
- })
- }
- }
- // TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens
- func TestAssumeRoleWithWebIdentity(t *testing.T) {
- service := setupTestSTSService(t)
- tests := []struct {
- name string
- roleArn string
- webIdentityToken string
- sessionName string
- durationSeconds *int64
- wantErr bool
- expectedSubject string
- }{
- {
- name: "successful role assumption",
- roleArn: "arn:seaweed:iam::role/TestRole",
- webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user-id"),
- sessionName: "test-session",
- durationSeconds: nil, // Use default
- wantErr: false,
- expectedSubject: "test-user-id",
- },
- {
- name: "invalid web identity token",
- roleArn: "arn:seaweed:iam::role/TestRole",
- webIdentityToken: "invalid-token",
- sessionName: "test-session",
- wantErr: true,
- },
- {
- name: "non-existent role",
- roleArn: "arn:seaweed:iam::role/NonExistentRole",
- webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
- sessionName: "test-session",
- wantErr: true,
- },
- {
- name: "custom session duration",
- roleArn: "arn:seaweed:iam::role/TestRole",
- webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
- sessionName: "test-session",
- durationSeconds: int64Ptr(7200), // 2 hours
- wantErr: false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
- request := &AssumeRoleWithWebIdentityRequest{
- RoleArn: tt.roleArn,
- WebIdentityToken: tt.webIdentityToken,
- RoleSessionName: tt.sessionName,
- DurationSeconds: tt.durationSeconds,
- }
- response, err := service.AssumeRoleWithWebIdentity(ctx, request)
- if tt.wantErr {
- assert.Error(t, err)
- assert.Nil(t, response)
- } else {
- assert.NoError(t, err)
- assert.NotNil(t, response)
- assert.NotNil(t, response.Credentials)
- assert.NotNil(t, response.AssumedRoleUser)
- // Verify credentials
- creds := response.Credentials
- assert.NotEmpty(t, creds.AccessKeyId)
- assert.NotEmpty(t, creds.SecretAccessKey)
- assert.NotEmpty(t, creds.SessionToken)
- assert.True(t, creds.Expiration.After(time.Now()))
- // Verify assumed role user
- user := response.AssumedRoleUser
- assert.Equal(t, tt.roleArn, user.AssumedRoleId)
- assert.Contains(t, user.Arn, tt.sessionName)
- if tt.expectedSubject != "" {
- assert.Equal(t, tt.expectedSubject, user.Subject)
- }
- }
- })
- }
- }
- // TestAssumeRoleWithLDAP tests role assumption with LDAP credentials
- func TestAssumeRoleWithLDAP(t *testing.T) {
- service := setupTestSTSService(t)
- tests := []struct {
- name string
- roleArn string
- username string
- password string
- sessionName string
- wantErr bool
- }{
- {
- name: "successful LDAP role assumption",
- roleArn: "arn:seaweed:iam::role/LDAPRole",
- username: "testuser",
- password: "testpass",
- sessionName: "ldap-session",
- wantErr: false,
- },
- {
- name: "invalid LDAP credentials",
- roleArn: "arn:seaweed:iam::role/LDAPRole",
- username: "testuser",
- password: "wrongpass",
- sessionName: "ldap-session",
- wantErr: true,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
- request := &AssumeRoleWithCredentialsRequest{
- RoleArn: tt.roleArn,
- Username: tt.username,
- Password: tt.password,
- RoleSessionName: tt.sessionName,
- ProviderName: "test-ldap",
- }
- response, err := service.AssumeRoleWithCredentials(ctx, request)
- if tt.wantErr {
- assert.Error(t, err)
- assert.Nil(t, response)
- } else {
- assert.NoError(t, err)
- assert.NotNil(t, response)
- assert.NotNil(t, response.Credentials)
- }
- })
- }
- }
- // TestSessionTokenValidation tests session token validation
- func TestSessionTokenValidation(t *testing.T) {
- service := setupTestSTSService(t)
- ctx := context.Background()
- // First, create a session
- request := &AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:seaweed:iam::role/TestRole",
- WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
- RoleSessionName: "test-session",
- }
- response, err := service.AssumeRoleWithWebIdentity(ctx, request)
- require.NoError(t, err)
- require.NotNil(t, response)
- sessionToken := response.Credentials.SessionToken
- tests := []struct {
- name string
- token string
- wantErr bool
- }{
- {
- name: "valid session token",
- token: sessionToken,
- wantErr: false,
- },
- {
- name: "invalid session token",
- token: "invalid-session-token",
- wantErr: true,
- },
- {
- name: "empty session token",
- token: "",
- wantErr: true,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- session, err := service.ValidateSessionToken(ctx, tt.token)
- if tt.wantErr {
- assert.Error(t, err)
- assert.Nil(t, session)
- } else {
- assert.NoError(t, err)
- assert.NotNil(t, session)
- assert.Equal(t, "test-session", session.SessionName)
- assert.Equal(t, "arn:seaweed:iam::role/TestRole", session.RoleArn)
- }
- })
- }
- }
- // TestSessionTokenPersistence tests that JWT tokens remain valid throughout their lifetime
- // Note: In the stateless JWT design, tokens cannot be revoked and remain valid until expiration
- func TestSessionTokenPersistence(t *testing.T) {
- service := setupTestSTSService(t)
- ctx := context.Background()
- // Create a session first
- request := &AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:seaweed:iam::role/TestRole",
- WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
- RoleSessionName: "test-session",
- }
- response, err := service.AssumeRoleWithWebIdentity(ctx, request)
- require.NoError(t, err)
- sessionToken := response.Credentials.SessionToken
- // Verify token is valid initially
- session, err := service.ValidateSessionToken(ctx, sessionToken)
- assert.NoError(t, err)
- assert.NotNil(t, session)
- assert.Equal(t, "test-session", session.SessionName)
- // In a stateless JWT system, tokens remain valid throughout their lifetime
- // Multiple validations should all succeed as long as the token hasn't expired
- session2, err := service.ValidateSessionToken(ctx, sessionToken)
- assert.NoError(t, err, "Token should remain valid in stateless system")
- assert.NotNil(t, session2, "Session should be returned from JWT token")
- assert.Equal(t, session.SessionId, session2.SessionId, "Session ID should be consistent")
- }
- // Helper functions
- func setupTestSTSService(t *testing.T) *STSService {
- service := NewSTSService()
- config := &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{time.Hour * 12},
- Issuer: "test-sts",
- SigningKey: []byte("test-signing-key-32-characters-long"),
- }
- err := service.Initialize(config)
- require.NoError(t, err)
- // Set up mock trust policy validator (required for STS testing)
- mockValidator := &MockTrustPolicyValidator{}
- service.SetTrustPolicyValidator(mockValidator)
- // Register test providers
- mockOIDCProvider := &MockIdentityProvider{
- name: "test-oidc",
- validTokens: map[string]*providers.TokenClaims{
- createSTSTestJWT(t, "test-issuer", "test-user"): {
- Subject: "test-user-id",
- Issuer: "test-issuer",
- Claims: map[string]interface{}{
- "email": "test@example.com",
- "name": "Test User",
- },
- },
- },
- }
- mockLDAPProvider := &MockIdentityProvider{
- name: "test-ldap",
- validCredentials: map[string]string{
- "testuser": "testpass",
- },
- }
- service.RegisterProvider(mockOIDCProvider)
- service.RegisterProvider(mockLDAPProvider)
- return service
- }
// int64Ptr returns a pointer to a fresh int64 holding v. It exists so that
// optional *int64 fields can be filled from a literal in table-driven tests.
func int64Ptr(v int64) *int64 {
	p := new(int64)
	*p = v
	return p
}
- // Mock identity provider for testing
- type MockIdentityProvider struct {
- name string
- validTokens map[string]*providers.TokenClaims
- validCredentials map[string]string
- }
- func (m *MockIdentityProvider) Name() string {
- return m.name
- }
- func (m *MockIdentityProvider) GetIssuer() string {
- return "test-issuer" // This matches the issuer in the token claims
- }
- func (m *MockIdentityProvider) Initialize(config interface{}) error {
- return nil
- }
- func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
- // First try to parse as JWT token
- if len(token) > 20 && strings.Count(token, ".") >= 2 {
- parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
- if err == nil {
- if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
- issuer, _ := claims["iss"].(string)
- subject, _ := claims["sub"].(string)
- // Verify the issuer matches what we expect
- if issuer == "test-issuer" && subject != "" {
- return &providers.ExternalIdentity{
- UserID: subject,
- Email: subject + "@test-domain.com",
- DisplayName: "Test User " + subject,
- Provider: m.name,
- }, nil
- }
- }
- }
- }
- // Handle legacy OIDC tokens (for backwards compatibility)
- if claims, exists := m.validTokens[token]; exists {
- email, _ := claims.GetClaimString("email")
- name, _ := claims.GetClaimString("name")
- return &providers.ExternalIdentity{
- UserID: claims.Subject,
- Email: email,
- DisplayName: name,
- Provider: m.name,
- }, nil
- }
- // Handle LDAP credentials (username:password format)
- if m.validCredentials != nil {
- parts := strings.Split(token, ":")
- if len(parts) == 2 {
- username, password := parts[0], parts[1]
- if expectedPassword, exists := m.validCredentials[username]; exists && expectedPassword == password {
- return &providers.ExternalIdentity{
- UserID: username,
- Email: username + "@" + m.name + ".com",
- DisplayName: "Test User " + username,
- Provider: m.name,
- }, nil
- }
- }
- }
- return nil, fmt.Errorf("unknown test token: %s", token)
- }
- func (m *MockIdentityProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
- return &providers.ExternalIdentity{
- UserID: userID,
- Email: userID + "@" + m.name + ".com",
- Provider: m.name,
- }, nil
- }
- func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
- if claims, exists := m.validTokens[token]; exists {
- return claims, nil
- }
- return nil, fmt.Errorf("invalid token")
- }
|