security_test.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. package sts
  2. import (
  3. "context"
  4. "fmt"
  5. "strings"
  6. "testing"
  7. "time"
  8. "github.com/golang-jwt/jwt/v5"
  9. "github.com/seaweedfs/seaweedfs/weed/iam/providers"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/stretchr/testify/require"
  12. )
  13. // TestSecurityIssuerToProviderMapping tests the security fix that ensures JWT tokens
  14. // with specific issuer claims can only be validated by the provider registered for that issuer
  15. func TestSecurityIssuerToProviderMapping(t *testing.T) {
  16. ctx := context.Background()
  17. // Create STS service with two mock providers
  18. service := NewSTSService()
  19. config := &STSConfig{
  20. TokenDuration: FlexibleDuration{time.Hour},
  21. MaxSessionLength: FlexibleDuration{time.Hour * 12},
  22. Issuer: "test-sts",
  23. SigningKey: []byte("test-signing-key-32-characters-long"),
  24. }
  25. err := service.Initialize(config)
  26. require.NoError(t, err)
  27. // Set up mock trust policy validator
  28. mockValidator := &MockTrustPolicyValidator{}
  29. service.SetTrustPolicyValidator(mockValidator)
  30. // Create two mock providers with different issuers
  31. providerA := &MockIdentityProviderWithIssuer{
  32. name: "provider-a",
  33. issuer: "https://provider-a.com",
  34. validTokens: map[string]bool{
  35. "token-for-provider-a": true,
  36. },
  37. }
  38. providerB := &MockIdentityProviderWithIssuer{
  39. name: "provider-b",
  40. issuer: "https://provider-b.com",
  41. validTokens: map[string]bool{
  42. "token-for-provider-b": true,
  43. },
  44. }
  45. // Register both providers
  46. err = service.RegisterProvider(providerA)
  47. require.NoError(t, err)
  48. err = service.RegisterProvider(providerB)
  49. require.NoError(t, err)
  50. // Create JWT tokens with specific issuer claims
  51. tokenForProviderA := createTestJWT(t, "https://provider-a.com", "user-a")
  52. tokenForProviderB := createTestJWT(t, "https://provider-b.com", "user-b")
  53. t.Run("jwt_token_with_issuer_a_only_validated_by_provider_a", func(t *testing.T) {
  54. // This should succeed - token has issuer A and provider A is registered
  55. identity, provider, err := service.validateWebIdentityToken(ctx, tokenForProviderA)
  56. assert.NoError(t, err)
  57. assert.NotNil(t, identity)
  58. assert.Equal(t, "provider-a", provider.Name())
  59. })
  60. t.Run("jwt_token_with_issuer_b_only_validated_by_provider_b", func(t *testing.T) {
  61. // This should succeed - token has issuer B and provider B is registered
  62. identity, provider, err := service.validateWebIdentityToken(ctx, tokenForProviderB)
  63. assert.NoError(t, err)
  64. assert.NotNil(t, identity)
  65. assert.Equal(t, "provider-b", provider.Name())
  66. })
  67. t.Run("jwt_token_with_unregistered_issuer_fails", func(t *testing.T) {
  68. // Create token with unregistered issuer
  69. tokenWithUnknownIssuer := createTestJWT(t, "https://unknown-issuer.com", "user-x")
  70. // This should fail - no provider registered for this issuer
  71. identity, provider, err := service.validateWebIdentityToken(ctx, tokenWithUnknownIssuer)
  72. assert.Error(t, err)
  73. assert.Nil(t, identity)
  74. assert.Nil(t, provider)
  75. assert.Contains(t, err.Error(), "no identity provider registered for issuer: https://unknown-issuer.com")
  76. })
  77. t.Run("non_jwt_tokens_are_rejected", func(t *testing.T) {
  78. // Non-JWT tokens should be rejected - no fallback mechanism exists for security
  79. identity, provider, err := service.validateWebIdentityToken(ctx, "token-for-provider-a")
  80. assert.Error(t, err)
  81. assert.Nil(t, identity)
  82. assert.Nil(t, provider)
  83. assert.Contains(t, err.Error(), "web identity token must be a valid JWT token")
  84. })
  85. }
  86. // createTestJWT creates a test JWT token with the specified issuer and subject
  87. func createTestJWT(t *testing.T, issuer, subject string) string {
  88. token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
  89. "iss": issuer,
  90. "sub": subject,
  91. "aud": "test-client",
  92. "exp": time.Now().Add(time.Hour).Unix(),
  93. "iat": time.Now().Unix(),
  94. })
  95. tokenString, err := token.SignedString([]byte("test-signing-key"))
  96. require.NoError(t, err)
  97. return tokenString
  98. }
  99. // MockIdentityProviderWithIssuer is a mock provider that supports issuer mapping
  100. type MockIdentityProviderWithIssuer struct {
  101. name string
  102. issuer string
  103. validTokens map[string]bool
  104. }
  105. func (m *MockIdentityProviderWithIssuer) Name() string {
  106. return m.name
  107. }
  108. func (m *MockIdentityProviderWithIssuer) GetIssuer() string {
  109. return m.issuer
  110. }
  111. func (m *MockIdentityProviderWithIssuer) Initialize(config interface{}) error {
  112. return nil
  113. }
  114. func (m *MockIdentityProviderWithIssuer) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
  115. // For JWT tokens, parse and validate the token format
  116. if len(token) > 50 && strings.Contains(token, ".") {
  117. // This looks like a JWT - parse it to get the subject
  118. parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
  119. if err != nil {
  120. return nil, fmt.Errorf("invalid JWT token")
  121. }
  122. claims, ok := parsedToken.Claims.(jwt.MapClaims)
  123. if !ok {
  124. return nil, fmt.Errorf("invalid claims")
  125. }
  126. issuer, _ := claims["iss"].(string)
  127. subject, _ := claims["sub"].(string)
  128. // Verify the issuer matches what we expect
  129. if issuer != m.issuer {
  130. return nil, fmt.Errorf("token issuer %s does not match provider issuer %s", issuer, m.issuer)
  131. }
  132. return &providers.ExternalIdentity{
  133. UserID: subject,
  134. Email: subject + "@" + m.name + ".com",
  135. Provider: m.name,
  136. }, nil
  137. }
  138. // For non-JWT tokens, check our simple token list
  139. if m.validTokens[token] {
  140. return &providers.ExternalIdentity{
  141. UserID: "test-user",
  142. Email: "test@" + m.name + ".com",
  143. Provider: m.name,
  144. }, nil
  145. }
  146. return nil, fmt.Errorf("invalid token")
  147. }
  148. func (m *MockIdentityProviderWithIssuer) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
  149. return &providers.ExternalIdentity{
  150. UserID: userID,
  151. Email: userID + "@" + m.name + ".com",
  152. Provider: m.name,
  153. }, nil
  154. }
  155. func (m *MockIdentityProviderWithIssuer) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
  156. if m.validTokens[token] {
  157. return &providers.TokenClaims{
  158. Subject: "test-user",
  159. Issuer: m.issuer,
  160. }, nil
  161. }
  162. return nil, fmt.Errorf("invalid token")
  163. }