cross_instance_token_test.go 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. package sts
  2. import (
  3. "context"
  4. "testing"
  5. "time"
  6. "github.com/golang-jwt/jwt/v5"
  7. "github.com/seaweedfs/seaweedfs/weed/iam/oidc"
  8. "github.com/seaweedfs/seaweedfs/weed/iam/providers"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/stretchr/testify/require"
  11. )
  12. // Test-only constants for mock providers
  13. const (
  14. ProviderTypeMock = "mock"
  15. )
  16. // createMockOIDCProvider creates a mock OIDC provider for testing
  17. // This is only available in test builds
  18. func createMockOIDCProvider(name string, config map[string]interface{}) (providers.IdentityProvider, error) {
  19. // Convert config to OIDC format
  20. factory := NewProviderFactory()
  21. oidcConfig, err := factory.convertToOIDCConfig(config)
  22. if err != nil {
  23. return nil, err
  24. }
  25. // Set default values for mock provider if not provided
  26. if oidcConfig.Issuer == "" {
  27. oidcConfig.Issuer = "http://localhost:9999"
  28. }
  29. provider := oidc.NewMockOIDCProvider(name)
  30. if err := provider.Initialize(oidcConfig); err != nil {
  31. return nil, err
  32. }
  33. // Set up default test data for the mock provider
  34. provider.SetupDefaultTestData()
  35. return provider, nil
  36. }
  37. // createMockJWT creates a test JWT token with the specified issuer for mock provider testing
  38. func createMockJWT(t *testing.T, issuer, subject string) string {
  39. token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
  40. "iss": issuer,
  41. "sub": subject,
  42. "aud": "test-client",
  43. "exp": time.Now().Add(time.Hour).Unix(),
  44. "iat": time.Now().Unix(),
  45. })
  46. tokenString, err := token.SignedString([]byte("test-signing-key"))
  47. require.NoError(t, err)
  48. return tokenString
  49. }
  50. // TestCrossInstanceTokenUsage verifies that tokens generated by one STS instance
  51. // can be used and validated by other STS instances in a distributed environment
  52. func TestCrossInstanceTokenUsage(t *testing.T) {
  53. ctx := context.Background()
  54. // Dummy filer address for testing
  55. // Common configuration that would be shared across all instances in production
  56. sharedConfig := &STSConfig{
  57. TokenDuration: FlexibleDuration{time.Hour},
  58. MaxSessionLength: FlexibleDuration{12 * time.Hour},
  59. Issuer: "distributed-sts-cluster", // SAME across all instances
  60. SigningKey: []byte(TestSigningKey32Chars), // SAME across all instances
  61. Providers: []*ProviderConfig{
  62. {
  63. Name: "company-oidc",
  64. Type: ProviderTypeOIDC,
  65. Enabled: true,
  66. Config: map[string]interface{}{
  67. ConfigFieldIssuer: "https://sso.company.com/realms/production",
  68. ConfigFieldClientID: "seaweedfs-cluster",
  69. ConfigFieldJWKSUri: "https://sso.company.com/realms/production/protocol/openid-connect/certs",
  70. },
  71. },
  72. },
  73. }
  74. // Create multiple STS instances simulating different S3 gateway instances
  75. instanceA := NewSTSService() // e.g., s3-gateway-1
  76. instanceB := NewSTSService() // e.g., s3-gateway-2
  77. instanceC := NewSTSService() // e.g., s3-gateway-3
  78. // Initialize all instances with IDENTICAL configuration
  79. err := instanceA.Initialize(sharedConfig)
  80. require.NoError(t, err, "Instance A should initialize")
  81. err = instanceB.Initialize(sharedConfig)
  82. require.NoError(t, err, "Instance B should initialize")
  83. err = instanceC.Initialize(sharedConfig)
  84. require.NoError(t, err, "Instance C should initialize")
  85. // Set up mock trust policy validator for all instances (required for STS testing)
  86. mockValidator := &MockTrustPolicyValidator{}
  87. instanceA.SetTrustPolicyValidator(mockValidator)
  88. instanceB.SetTrustPolicyValidator(mockValidator)
  89. instanceC.SetTrustPolicyValidator(mockValidator)
  90. // Manually register mock provider for testing (not available in production)
  91. mockProviderConfig := map[string]interface{}{
  92. ConfigFieldIssuer: "http://test-mock:9999",
  93. ConfigFieldClientID: TestClientID,
  94. }
  95. mockProviderA, err := createMockOIDCProvider("test-mock", mockProviderConfig)
  96. require.NoError(t, err)
  97. mockProviderB, err := createMockOIDCProvider("test-mock", mockProviderConfig)
  98. require.NoError(t, err)
  99. mockProviderC, err := createMockOIDCProvider("test-mock", mockProviderConfig)
  100. require.NoError(t, err)
  101. instanceA.RegisterProvider(mockProviderA)
  102. instanceB.RegisterProvider(mockProviderB)
  103. instanceC.RegisterProvider(mockProviderC)
  104. // Test 1: Token generated on Instance A can be validated on Instance B & C
  105. t.Run("cross_instance_token_validation", func(t *testing.T) {
  106. // Generate session token on Instance A
  107. sessionId := TestSessionID
  108. expiresAt := time.Now().Add(time.Hour)
  109. tokenFromA, err := instanceA.tokenGenerator.GenerateSessionToken(sessionId, expiresAt)
  110. require.NoError(t, err, "Instance A should generate token")
  111. // Validate token on Instance B
  112. claimsFromB, err := instanceB.tokenGenerator.ValidateSessionToken(tokenFromA)
  113. require.NoError(t, err, "Instance B should validate token from Instance A")
  114. assert.Equal(t, sessionId, claimsFromB.SessionId, "Session ID should match")
  115. // Validate same token on Instance C
  116. claimsFromC, err := instanceC.tokenGenerator.ValidateSessionToken(tokenFromA)
  117. require.NoError(t, err, "Instance C should validate token from Instance A")
  118. assert.Equal(t, sessionId, claimsFromC.SessionId, "Session ID should match")
  119. // All instances should extract identical claims
  120. assert.Equal(t, claimsFromB.SessionId, claimsFromC.SessionId)
  121. assert.Equal(t, claimsFromB.ExpiresAt.Unix(), claimsFromC.ExpiresAt.Unix())
  122. assert.Equal(t, claimsFromB.IssuedAt.Unix(), claimsFromC.IssuedAt.Unix())
  123. })
  124. // Test 2: Complete assume role flow across instances
  125. t.Run("cross_instance_assume_role_flow", func(t *testing.T) {
  126. // Step 1: User authenticates and assumes role on Instance A
  127. // Create a valid JWT token for the mock provider
  128. mockToken := createMockJWT(t, "http://test-mock:9999", "test-user")
  129. assumeRequest := &AssumeRoleWithWebIdentityRequest{
  130. RoleArn: "arn:seaweed:iam::role/CrossInstanceTestRole",
  131. WebIdentityToken: mockToken, // JWT token for mock provider
  132. RoleSessionName: "cross-instance-test-session",
  133. DurationSeconds: int64ToPtr(3600),
  134. }
  135. // Instance A processes assume role request
  136. responseFromA, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
  137. require.NoError(t, err, "Instance A should process assume role")
  138. sessionToken := responseFromA.Credentials.SessionToken
  139. accessKeyId := responseFromA.Credentials.AccessKeyId
  140. secretAccessKey := responseFromA.Credentials.SecretAccessKey
  141. // Verify response structure
  142. assert.NotEmpty(t, sessionToken, "Should have session token")
  143. assert.NotEmpty(t, accessKeyId, "Should have access key ID")
  144. assert.NotEmpty(t, secretAccessKey, "Should have secret access key")
  145. assert.NotNil(t, responseFromA.AssumedRoleUser, "Should have assumed role user")
  146. // Step 2: Use session token on Instance B (different instance)
  147. sessionInfoFromB, err := instanceB.ValidateSessionToken(ctx, sessionToken)
  148. require.NoError(t, err, "Instance B should validate session token from Instance A")
  149. assert.Equal(t, assumeRequest.RoleSessionName, sessionInfoFromB.SessionName)
  150. assert.Equal(t, assumeRequest.RoleArn, sessionInfoFromB.RoleArn)
  151. // Step 3: Use same session token on Instance C (yet another instance)
  152. sessionInfoFromC, err := instanceC.ValidateSessionToken(ctx, sessionToken)
  153. require.NoError(t, err, "Instance C should validate session token from Instance A")
  154. // All instances should return identical session information
  155. assert.Equal(t, sessionInfoFromB.SessionId, sessionInfoFromC.SessionId)
  156. assert.Equal(t, sessionInfoFromB.SessionName, sessionInfoFromC.SessionName)
  157. assert.Equal(t, sessionInfoFromB.RoleArn, sessionInfoFromC.RoleArn)
  158. assert.Equal(t, sessionInfoFromB.Subject, sessionInfoFromC.Subject)
  159. assert.Equal(t, sessionInfoFromB.Provider, sessionInfoFromC.Provider)
  160. })
  161. // Test 3: Session revocation across instances
  162. t.Run("cross_instance_session_revocation", func(t *testing.T) {
  163. // Create session on Instance A
  164. mockToken := createMockJWT(t, "http://test-mock:9999", "test-user")
  165. assumeRequest := &AssumeRoleWithWebIdentityRequest{
  166. RoleArn: "arn:seaweed:iam::role/RevocationTestRole",
  167. WebIdentityToken: mockToken,
  168. RoleSessionName: "revocation-test-session",
  169. }
  170. response, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
  171. require.NoError(t, err)
  172. sessionToken := response.Credentials.SessionToken
  173. // Verify token works on Instance B
  174. _, err = instanceB.ValidateSessionToken(ctx, sessionToken)
  175. require.NoError(t, err, "Token should be valid on Instance B initially")
  176. // Validate session on Instance C to verify cross-instance token compatibility
  177. _, err = instanceC.ValidateSessionToken(ctx, sessionToken)
  178. require.NoError(t, err, "Instance C should be able to validate session token")
  179. // In a stateless JWT system, tokens remain valid on all instances since they're self-contained
  180. // No revocation is possible without breaking the stateless architecture
  181. _, err = instanceA.ValidateSessionToken(ctx, sessionToken)
  182. assert.NoError(t, err, "Token should still be valid on Instance A (stateless system)")
  183. // Verify token is still valid on Instance B
  184. _, err = instanceB.ValidateSessionToken(ctx, sessionToken)
  185. assert.NoError(t, err, "Token should still be valid on Instance B (stateless system)")
  186. })
  187. // Test 4: Provider consistency across instances
  188. t.Run("provider_consistency_affects_token_generation", func(t *testing.T) {
  189. // All instances should have same providers and be able to process same OIDC tokens
  190. providerNamesA := instanceA.getProviderNames()
  191. providerNamesB := instanceB.getProviderNames()
  192. providerNamesC := instanceC.getProviderNames()
  193. assert.ElementsMatch(t, providerNamesA, providerNamesB, "Instance A and B should have same providers")
  194. assert.ElementsMatch(t, providerNamesB, providerNamesC, "Instance B and C should have same providers")
  195. // All instances should be able to process same web identity token
  196. testToken := createMockJWT(t, "http://test-mock:9999", "test-user")
  197. // Try to assume role with same token on different instances
  198. assumeRequest := &AssumeRoleWithWebIdentityRequest{
  199. RoleArn: "arn:seaweed:iam::role/ProviderTestRole",
  200. WebIdentityToken: testToken,
  201. RoleSessionName: "provider-consistency-test",
  202. }
  203. // Should work on any instance
  204. responseA, errA := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
  205. responseB, errB := instanceB.AssumeRoleWithWebIdentity(ctx, assumeRequest)
  206. responseC, errC := instanceC.AssumeRoleWithWebIdentity(ctx, assumeRequest)
  207. require.NoError(t, errA, "Instance A should process OIDC token")
  208. require.NoError(t, errB, "Instance B should process OIDC token")
  209. require.NoError(t, errC, "Instance C should process OIDC token")
  210. // All should return valid responses (sessions will have different IDs but same structure)
  211. assert.NotEmpty(t, responseA.Credentials.SessionToken)
  212. assert.NotEmpty(t, responseB.Credentials.SessionToken)
  213. assert.NotEmpty(t, responseC.Credentials.SessionToken)
  214. })
  215. }
  216. // TestSTSDistributedConfigurationRequirements tests the configuration requirements
  217. // for cross-instance token compatibility
  218. func TestSTSDistributedConfigurationRequirements(t *testing.T) {
  219. _ = "localhost:8888" // Dummy filer address for testing (not used in these tests)
  220. t.Run("same_signing_key_required", func(t *testing.T) {
  221. // Instance A with signing key 1
  222. configA := &STSConfig{
  223. TokenDuration: FlexibleDuration{time.Hour},
  224. MaxSessionLength: FlexibleDuration{12 * time.Hour},
  225. Issuer: "test-sts",
  226. SigningKey: []byte("signing-key-1-32-characters-long"),
  227. }
  228. // Instance B with different signing key
  229. configB := &STSConfig{
  230. TokenDuration: FlexibleDuration{time.Hour},
  231. MaxSessionLength: FlexibleDuration{12 * time.Hour},
  232. Issuer: "test-sts",
  233. SigningKey: []byte("signing-key-2-32-characters-long"), // DIFFERENT!
  234. }
  235. instanceA := NewSTSService()
  236. instanceB := NewSTSService()
  237. err := instanceA.Initialize(configA)
  238. require.NoError(t, err)
  239. err = instanceB.Initialize(configB)
  240. require.NoError(t, err)
  241. // Generate token on Instance A
  242. sessionId := "test-session"
  243. expiresAt := time.Now().Add(time.Hour)
  244. tokenFromA, err := instanceA.tokenGenerator.GenerateSessionToken(sessionId, expiresAt)
  245. require.NoError(t, err)
  246. // Instance A should validate its own token
  247. _, err = instanceA.tokenGenerator.ValidateSessionToken(tokenFromA)
  248. assert.NoError(t, err, "Instance A should validate own token")
  249. // Instance B should REJECT token due to different signing key
  250. _, err = instanceB.tokenGenerator.ValidateSessionToken(tokenFromA)
  251. assert.Error(t, err, "Instance B should reject token with different signing key")
  252. assert.Contains(t, err.Error(), "invalid token", "Should be signature validation error")
  253. })
  254. t.Run("same_issuer_required", func(t *testing.T) {
  255. sharedSigningKey := []byte("shared-signing-key-32-characters-lo")
  256. // Instance A with issuer 1
  257. configA := &STSConfig{
  258. TokenDuration: FlexibleDuration{time.Hour},
  259. MaxSessionLength: FlexibleDuration{12 * time.Hour},
  260. Issuer: "sts-cluster-1",
  261. SigningKey: sharedSigningKey,
  262. }
  263. // Instance B with different issuer
  264. configB := &STSConfig{
  265. TokenDuration: FlexibleDuration{time.Hour},
  266. MaxSessionLength: FlexibleDuration{12 * time.Hour},
  267. Issuer: "sts-cluster-2", // DIFFERENT!
  268. SigningKey: sharedSigningKey,
  269. }
  270. instanceA := NewSTSService()
  271. instanceB := NewSTSService()
  272. err := instanceA.Initialize(configA)
  273. require.NoError(t, err)
  274. err = instanceB.Initialize(configB)
  275. require.NoError(t, err)
  276. // Generate token on Instance A
  277. sessionId := "test-session"
  278. expiresAt := time.Now().Add(time.Hour)
  279. tokenFromA, err := instanceA.tokenGenerator.GenerateSessionToken(sessionId, expiresAt)
  280. require.NoError(t, err)
  281. // Instance B should REJECT token due to different issuer
  282. _, err = instanceB.tokenGenerator.ValidateSessionToken(tokenFromA)
  283. assert.Error(t, err, "Instance B should reject token with different issuer")
  284. assert.Contains(t, err.Error(), "invalid issuer", "Should be issuer validation error")
  285. })
  286. t.Run("identical_configuration_required", func(t *testing.T) {
  287. // Identical configuration
  288. identicalConfig := &STSConfig{
  289. TokenDuration: FlexibleDuration{time.Hour},
  290. MaxSessionLength: FlexibleDuration{12 * time.Hour},
  291. Issuer: "production-sts-cluster",
  292. SigningKey: []byte("production-signing-key-32-chars-l"),
  293. }
  294. // Create multiple instances with identical config
  295. instances := make([]*STSService, 5)
  296. for i := 0; i < 5; i++ {
  297. instances[i] = NewSTSService()
  298. err := instances[i].Initialize(identicalConfig)
  299. require.NoError(t, err, "Instance %d should initialize", i)
  300. }
  301. // Generate token on Instance 0
  302. sessionId := "multi-instance-test"
  303. expiresAt := time.Now().Add(time.Hour)
  304. token, err := instances[0].tokenGenerator.GenerateSessionToken(sessionId, expiresAt)
  305. require.NoError(t, err)
  306. // All other instances should validate the token
  307. for i := 1; i < 5; i++ {
  308. claims, err := instances[i].tokenGenerator.ValidateSessionToken(token)
  309. require.NoError(t, err, "Instance %d should validate token", i)
  310. assert.Equal(t, sessionId, claims.SessionId, "Instance %d should extract correct session ID", i)
  311. }
  312. })
  313. }
  314. // TestSTSRealWorldDistributedScenarios tests realistic distributed deployment scenarios
  315. func TestSTSRealWorldDistributedScenarios(t *testing.T) {
  316. ctx := context.Background()
  317. t.Run("load_balanced_s3_gateway_scenario", func(t *testing.T) {
  318. // Simulate real production scenario:
  319. // 1. User authenticates with OIDC provider
  320. // 2. User calls AssumeRoleWithWebIdentity on S3 Gateway 1
  321. // 3. User makes S3 requests that hit S3 Gateway 2 & 3 via load balancer
  322. // 4. All instances should handle the session token correctly
  323. productionConfig := &STSConfig{
  324. TokenDuration: FlexibleDuration{2 * time.Hour},
  325. MaxSessionLength: FlexibleDuration{24 * time.Hour},
  326. Issuer: "seaweedfs-production-sts",
  327. SigningKey: []byte("prod-signing-key-32-characters-lon"),
  328. Providers: []*ProviderConfig{
  329. {
  330. Name: "corporate-oidc",
  331. Type: "oidc",
  332. Enabled: true,
  333. Config: map[string]interface{}{
  334. "issuer": "https://sso.company.com/realms/production",
  335. "clientId": "seaweedfs-prod-cluster",
  336. "clientSecret": "supersecret-prod-key",
  337. "scopes": []string{"openid", "profile", "email", "groups"},
  338. },
  339. },
  340. },
  341. }
  342. // Create 3 S3 Gateway instances behind load balancer
  343. gateway1 := NewSTSService()
  344. gateway2 := NewSTSService()
  345. gateway3 := NewSTSService()
  346. err := gateway1.Initialize(productionConfig)
  347. require.NoError(t, err)
  348. err = gateway2.Initialize(productionConfig)
  349. require.NoError(t, err)
  350. err = gateway3.Initialize(productionConfig)
  351. require.NoError(t, err)
  352. // Set up mock trust policy validator for all gateway instances
  353. mockValidator := &MockTrustPolicyValidator{}
  354. gateway1.SetTrustPolicyValidator(mockValidator)
  355. gateway2.SetTrustPolicyValidator(mockValidator)
  356. gateway3.SetTrustPolicyValidator(mockValidator)
  357. // Manually register mock provider for testing (not available in production)
  358. mockProviderConfig := map[string]interface{}{
  359. ConfigFieldIssuer: "http://test-mock:9999",
  360. ConfigFieldClientID: "test-client-id",
  361. }
  362. mockProvider1, err := createMockOIDCProvider("test-mock", mockProviderConfig)
  363. require.NoError(t, err)
  364. mockProvider2, err := createMockOIDCProvider("test-mock", mockProviderConfig)
  365. require.NoError(t, err)
  366. mockProvider3, err := createMockOIDCProvider("test-mock", mockProviderConfig)
  367. require.NoError(t, err)
  368. gateway1.RegisterProvider(mockProvider1)
  369. gateway2.RegisterProvider(mockProvider2)
  370. gateway3.RegisterProvider(mockProvider3)
  371. // Step 1: User authenticates and hits Gateway 1 for AssumeRole
  372. mockToken := createMockJWT(t, "http://test-mock:9999", "production-user")
  373. assumeRequest := &AssumeRoleWithWebIdentityRequest{
  374. RoleArn: "arn:seaweed:iam::role/ProductionS3User",
  375. WebIdentityToken: mockToken, // JWT token from mock provider
  376. RoleSessionName: "user-production-session",
  377. DurationSeconds: int64ToPtr(7200), // 2 hours
  378. }
  379. stsResponse, err := gateway1.AssumeRoleWithWebIdentity(ctx, assumeRequest)
  380. require.NoError(t, err, "Gateway 1 should handle AssumeRole")
  381. sessionToken := stsResponse.Credentials.SessionToken
  382. accessKey := stsResponse.Credentials.AccessKeyId
  383. secretKey := stsResponse.Credentials.SecretAccessKey
  384. // Step 2: User makes S3 requests that hit different gateways via load balancer
  385. // Simulate S3 request validation on Gateway 2
  386. sessionInfo2, err := gateway2.ValidateSessionToken(ctx, sessionToken)
  387. require.NoError(t, err, "Gateway 2 should validate session from Gateway 1")
  388. assert.Equal(t, "user-production-session", sessionInfo2.SessionName)
  389. assert.Equal(t, "arn:seaweed:iam::role/ProductionS3User", sessionInfo2.RoleArn)
  390. // Simulate S3 request validation on Gateway 3
  391. sessionInfo3, err := gateway3.ValidateSessionToken(ctx, sessionToken)
  392. require.NoError(t, err, "Gateway 3 should validate session from Gateway 1")
  393. assert.Equal(t, sessionInfo2.SessionId, sessionInfo3.SessionId, "Should be same session")
  394. // Step 3: Verify credentials are consistent
  395. assert.Equal(t, accessKey, stsResponse.Credentials.AccessKeyId, "Access key should be consistent")
  396. assert.Equal(t, secretKey, stsResponse.Credentials.SecretAccessKey, "Secret key should be consistent")
  397. // Step 4: Session expiration should be honored across all instances
  398. assert.True(t, sessionInfo2.ExpiresAt.After(time.Now()), "Session should not be expired")
  399. assert.True(t, sessionInfo3.ExpiresAt.After(time.Now()), "Session should not be expired")
  400. // Step 5: Token should be identical when parsed
  401. claims2, err := gateway2.tokenGenerator.ValidateSessionToken(sessionToken)
  402. require.NoError(t, err)
  403. claims3, err := gateway3.tokenGenerator.ValidateSessionToken(sessionToken)
  404. require.NoError(t, err)
  405. assert.Equal(t, claims2.SessionId, claims3.SessionId, "Session IDs should match")
  406. assert.Equal(t, claims2.ExpiresAt.Unix(), claims3.ExpiresAt.Unix(), "Expiration should match")
  407. })
  408. }
  409. // Helper function to convert int64 to pointer
  410. func int64ToPtr(i int64) *int64 {
  411. return &i
  412. }