provider_test.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. package providers
  2. import (
  3. "context"
  4. "testing"
  5. "time"
  6. "github.com/stretchr/testify/assert"
  7. "github.com/stretchr/testify/require"
  8. )
  9. // TestIdentityProviderInterface tests the core identity provider interface
  10. func TestIdentityProviderInterface(t *testing.T) {
  11. tests := []struct {
  12. name string
  13. provider IdentityProvider
  14. wantErr bool
  15. }{
  16. // We'll add test cases as we implement providers
  17. }
  18. for _, tt := range tests {
  19. t.Run(tt.name, func(t *testing.T) {
  20. // Test provider name
  21. name := tt.provider.Name()
  22. assert.NotEmpty(t, name, "Provider name should not be empty")
  23. // Test initialization
  24. err := tt.provider.Initialize(nil)
  25. if tt.wantErr {
  26. assert.Error(t, err)
  27. return
  28. }
  29. require.NoError(t, err)
  30. // Test authentication with invalid token
  31. ctx := context.Background()
  32. _, err = tt.provider.Authenticate(ctx, "invalid-token")
  33. assert.Error(t, err, "Should fail with invalid token")
  34. })
  35. }
  36. }
  37. // TestExternalIdentityValidation tests external identity structure validation
  38. func TestExternalIdentityValidation(t *testing.T) {
  39. tests := []struct {
  40. name string
  41. identity *ExternalIdentity
  42. wantErr bool
  43. }{
  44. {
  45. name: "valid identity",
  46. identity: &ExternalIdentity{
  47. UserID: "user123",
  48. Email: "user@example.com",
  49. DisplayName: "Test User",
  50. Groups: []string{"group1", "group2"},
  51. Attributes: map[string]string{"dept": "engineering"},
  52. Provider: "test-provider",
  53. },
  54. wantErr: false,
  55. },
  56. {
  57. name: "missing user id",
  58. identity: &ExternalIdentity{
  59. Email: "user@example.com",
  60. Provider: "test-provider",
  61. },
  62. wantErr: true,
  63. },
  64. {
  65. name: "missing provider",
  66. identity: &ExternalIdentity{
  67. UserID: "user123",
  68. Email: "user@example.com",
  69. },
  70. wantErr: true,
  71. },
  72. {
  73. name: "invalid email",
  74. identity: &ExternalIdentity{
  75. UserID: "user123",
  76. Email: "invalid-email",
  77. Provider: "test-provider",
  78. },
  79. wantErr: true,
  80. },
  81. }
  82. for _, tt := range tests {
  83. t.Run(tt.name, func(t *testing.T) {
  84. err := tt.identity.Validate()
  85. if tt.wantErr {
  86. assert.Error(t, err)
  87. } else {
  88. assert.NoError(t, err)
  89. }
  90. })
  91. }
  92. }
  93. // TestTokenClaimsValidation tests token claims structure
  94. func TestTokenClaimsValidation(t *testing.T) {
  95. tests := []struct {
  96. name string
  97. claims *TokenClaims
  98. valid bool
  99. }{
  100. {
  101. name: "valid claims",
  102. claims: &TokenClaims{
  103. Subject: "user123",
  104. Issuer: "https://provider.example.com",
  105. Audience: "seaweedfs",
  106. ExpiresAt: time.Now().Add(time.Hour),
  107. IssuedAt: time.Now().Add(-time.Minute),
  108. Claims: map[string]interface{}{"email": "user@example.com"},
  109. },
  110. valid: true,
  111. },
  112. {
  113. name: "expired token",
  114. claims: &TokenClaims{
  115. Subject: "user123",
  116. Issuer: "https://provider.example.com",
  117. Audience: "seaweedfs",
  118. ExpiresAt: time.Now().Add(-time.Hour), // Expired
  119. IssuedAt: time.Now().Add(-time.Hour * 2),
  120. Claims: map[string]interface{}{"email": "user@example.com"},
  121. },
  122. valid: false,
  123. },
  124. {
  125. name: "future issued token",
  126. claims: &TokenClaims{
  127. Subject: "user123",
  128. Issuer: "https://provider.example.com",
  129. Audience: "seaweedfs",
  130. ExpiresAt: time.Now().Add(time.Hour),
  131. IssuedAt: time.Now().Add(time.Hour), // Future
  132. Claims: map[string]interface{}{"email": "user@example.com"},
  133. },
  134. valid: false,
  135. },
  136. }
  137. for _, tt := range tests {
  138. t.Run(tt.name, func(t *testing.T) {
  139. valid := tt.claims.IsValid()
  140. assert.Equal(t, tt.valid, valid)
  141. })
  142. }
  143. }
  144. // TestProviderRegistry tests provider registration and discovery
  145. func TestProviderRegistry(t *testing.T) {
  146. // Clear registry for test
  147. registry := NewProviderRegistry()
  148. t.Run("register provider", func(t *testing.T) {
  149. mockProvider := &MockProvider{name: "test-provider"}
  150. err := registry.RegisterProvider(mockProvider)
  151. assert.NoError(t, err)
  152. // Test duplicate registration
  153. err = registry.RegisterProvider(mockProvider)
  154. assert.Error(t, err, "Should not allow duplicate registration")
  155. })
  156. t.Run("get provider", func(t *testing.T) {
  157. provider, exists := registry.GetProvider("test-provider")
  158. assert.True(t, exists)
  159. assert.Equal(t, "test-provider", provider.Name())
  160. // Test non-existent provider
  161. _, exists = registry.GetProvider("non-existent")
  162. assert.False(t, exists)
  163. })
  164. t.Run("list providers", func(t *testing.T) {
  165. providers := registry.ListProviders()
  166. assert.Len(t, providers, 1)
  167. assert.Equal(t, "test-provider", providers[0])
  168. })
  169. }
  // MockProvider is a configurable fake identity provider used by the tests
// in this file.
type MockProvider struct {
	// name is returned by Name and copied into the identities the mock returns.
	name string
	// initialized is set by a successful Initialize; the lookup methods fail
	// until it is true. shouldError forces Initialize to fail.
	initialized, shouldError bool
}
  176. func (m *MockProvider) Name() string {
  177. return m.name
  178. }
  179. func (m *MockProvider) Initialize(config interface{}) error {
  180. if m.shouldError {
  181. return assert.AnError
  182. }
  183. m.initialized = true
  184. return nil
  185. }
  186. func (m *MockProvider) Authenticate(ctx context.Context, token string) (*ExternalIdentity, error) {
  187. if !m.initialized {
  188. return nil, assert.AnError
  189. }
  190. if token == "invalid-token" {
  191. return nil, assert.AnError
  192. }
  193. return &ExternalIdentity{
  194. UserID: "test-user",
  195. Email: "test@example.com",
  196. DisplayName: "Test User",
  197. Provider: m.name,
  198. }, nil
  199. }
  200. func (m *MockProvider) GetUserInfo(ctx context.Context, userID string) (*ExternalIdentity, error) {
  201. if !m.initialized || userID == "" {
  202. return nil, assert.AnError
  203. }
  204. return &ExternalIdentity{
  205. UserID: userID,
  206. Email: userID + "@example.com",
  207. DisplayName: "User " + userID,
  208. Provider: m.name,
  209. }, nil
  210. }
  211. func (m *MockProvider) ValidateToken(ctx context.Context, token string) (*TokenClaims, error) {
  212. if !m.initialized || token == "invalid-token" {
  213. return nil, assert.AnError
  214. }
  215. return &TokenClaims{
  216. Subject: "test-user",
  217. Issuer: "test-issuer",
  218. Audience: "seaweedfs",
  219. ExpiresAt: time.Now().Add(time.Hour),
  220. IssuedAt: time.Now(),
  221. Claims: map[string]interface{}{"email": "test@example.com"},
  222. }, nil
  223. }