sts_service_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
  1. package sts
  2. import (
  3. "context"
  4. "fmt"
  5. "strings"
  6. "testing"
  7. "time"
  8. "github.com/golang-jwt/jwt/v5"
  9. "github.com/seaweedfs/seaweedfs/weed/iam/providers"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/stretchr/testify/require"
  12. )
  13. // createSTSTestJWT creates a test JWT token for STS service tests
  14. func createSTSTestJWT(t *testing.T, issuer, subject string) string {
  15. token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
  16. "iss": issuer,
  17. "sub": subject,
  18. "aud": "test-client",
  19. "exp": time.Now().Add(time.Hour).Unix(),
  20. "iat": time.Now().Unix(),
  21. })
  22. tokenString, err := token.SignedString([]byte("test-signing-key"))
  23. require.NoError(t, err)
  24. return tokenString
  25. }
  26. // TestSTSServiceInitialization tests STS service initialization
  27. func TestSTSServiceInitialization(t *testing.T) {
  28. tests := []struct {
  29. name string
  30. config *STSConfig
  31. wantErr bool
  32. }{
  33. {
  34. name: "valid config",
  35. config: &STSConfig{
  36. TokenDuration: FlexibleDuration{time.Hour},
  37. MaxSessionLength: FlexibleDuration{time.Hour * 12},
  38. Issuer: "seaweedfs-sts",
  39. SigningKey: []byte("test-signing-key"),
  40. },
  41. wantErr: false,
  42. },
  43. {
  44. name: "missing signing key",
  45. config: &STSConfig{
  46. TokenDuration: FlexibleDuration{time.Hour},
  47. Issuer: "seaweedfs-sts",
  48. },
  49. wantErr: true,
  50. },
  51. {
  52. name: "invalid token duration",
  53. config: &STSConfig{
  54. TokenDuration: FlexibleDuration{-time.Hour},
  55. Issuer: "seaweedfs-sts",
  56. SigningKey: []byte("test-key"),
  57. },
  58. wantErr: true,
  59. },
  60. }
  61. for _, tt := range tests {
  62. t.Run(tt.name, func(t *testing.T) {
  63. service := NewSTSService()
  64. err := service.Initialize(tt.config)
  65. if tt.wantErr {
  66. assert.Error(t, err)
  67. } else {
  68. assert.NoError(t, err)
  69. assert.True(t, service.IsInitialized())
  70. }
  71. })
  72. }
  73. }
  74. // TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens
  75. func TestAssumeRoleWithWebIdentity(t *testing.T) {
  76. service := setupTestSTSService(t)
  77. tests := []struct {
  78. name string
  79. roleArn string
  80. webIdentityToken string
  81. sessionName string
  82. durationSeconds *int64
  83. wantErr bool
  84. expectedSubject string
  85. }{
  86. {
  87. name: "successful role assumption",
  88. roleArn: "arn:seaweed:iam::role/TestRole",
  89. webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user-id"),
  90. sessionName: "test-session",
  91. durationSeconds: nil, // Use default
  92. wantErr: false,
  93. expectedSubject: "test-user-id",
  94. },
  95. {
  96. name: "invalid web identity token",
  97. roleArn: "arn:seaweed:iam::role/TestRole",
  98. webIdentityToken: "invalid-token",
  99. sessionName: "test-session",
  100. wantErr: true,
  101. },
  102. {
  103. name: "non-existent role",
  104. roleArn: "arn:seaweed:iam::role/NonExistentRole",
  105. webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
  106. sessionName: "test-session",
  107. wantErr: true,
  108. },
  109. {
  110. name: "custom session duration",
  111. roleArn: "arn:seaweed:iam::role/TestRole",
  112. webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
  113. sessionName: "test-session",
  114. durationSeconds: int64Ptr(7200), // 2 hours
  115. wantErr: false,
  116. },
  117. }
  118. for _, tt := range tests {
  119. t.Run(tt.name, func(t *testing.T) {
  120. ctx := context.Background()
  121. request := &AssumeRoleWithWebIdentityRequest{
  122. RoleArn: tt.roleArn,
  123. WebIdentityToken: tt.webIdentityToken,
  124. RoleSessionName: tt.sessionName,
  125. DurationSeconds: tt.durationSeconds,
  126. }
  127. response, err := service.AssumeRoleWithWebIdentity(ctx, request)
  128. if tt.wantErr {
  129. assert.Error(t, err)
  130. assert.Nil(t, response)
  131. } else {
  132. assert.NoError(t, err)
  133. assert.NotNil(t, response)
  134. assert.NotNil(t, response.Credentials)
  135. assert.NotNil(t, response.AssumedRoleUser)
  136. // Verify credentials
  137. creds := response.Credentials
  138. assert.NotEmpty(t, creds.AccessKeyId)
  139. assert.NotEmpty(t, creds.SecretAccessKey)
  140. assert.NotEmpty(t, creds.SessionToken)
  141. assert.True(t, creds.Expiration.After(time.Now()))
  142. // Verify assumed role user
  143. user := response.AssumedRoleUser
  144. assert.Equal(t, tt.roleArn, user.AssumedRoleId)
  145. assert.Contains(t, user.Arn, tt.sessionName)
  146. if tt.expectedSubject != "" {
  147. assert.Equal(t, tt.expectedSubject, user.Subject)
  148. }
  149. }
  150. })
  151. }
  152. }
  153. // TestAssumeRoleWithLDAP tests role assumption with LDAP credentials
  154. func TestAssumeRoleWithLDAP(t *testing.T) {
  155. service := setupTestSTSService(t)
  156. tests := []struct {
  157. name string
  158. roleArn string
  159. username string
  160. password string
  161. sessionName string
  162. wantErr bool
  163. }{
  164. {
  165. name: "successful LDAP role assumption",
  166. roleArn: "arn:seaweed:iam::role/LDAPRole",
  167. username: "testuser",
  168. password: "testpass",
  169. sessionName: "ldap-session",
  170. wantErr: false,
  171. },
  172. {
  173. name: "invalid LDAP credentials",
  174. roleArn: "arn:seaweed:iam::role/LDAPRole",
  175. username: "testuser",
  176. password: "wrongpass",
  177. sessionName: "ldap-session",
  178. wantErr: true,
  179. },
  180. }
  181. for _, tt := range tests {
  182. t.Run(tt.name, func(t *testing.T) {
  183. ctx := context.Background()
  184. request := &AssumeRoleWithCredentialsRequest{
  185. RoleArn: tt.roleArn,
  186. Username: tt.username,
  187. Password: tt.password,
  188. RoleSessionName: tt.sessionName,
  189. ProviderName: "test-ldap",
  190. }
  191. response, err := service.AssumeRoleWithCredentials(ctx, request)
  192. if tt.wantErr {
  193. assert.Error(t, err)
  194. assert.Nil(t, response)
  195. } else {
  196. assert.NoError(t, err)
  197. assert.NotNil(t, response)
  198. assert.NotNil(t, response.Credentials)
  199. }
  200. })
  201. }
  202. }
  203. // TestSessionTokenValidation tests session token validation
  204. func TestSessionTokenValidation(t *testing.T) {
  205. service := setupTestSTSService(t)
  206. ctx := context.Background()
  207. // First, create a session
  208. request := &AssumeRoleWithWebIdentityRequest{
  209. RoleArn: "arn:seaweed:iam::role/TestRole",
  210. WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
  211. RoleSessionName: "test-session",
  212. }
  213. response, err := service.AssumeRoleWithWebIdentity(ctx, request)
  214. require.NoError(t, err)
  215. require.NotNil(t, response)
  216. sessionToken := response.Credentials.SessionToken
  217. tests := []struct {
  218. name string
  219. token string
  220. wantErr bool
  221. }{
  222. {
  223. name: "valid session token",
  224. token: sessionToken,
  225. wantErr: false,
  226. },
  227. {
  228. name: "invalid session token",
  229. token: "invalid-session-token",
  230. wantErr: true,
  231. },
  232. {
  233. name: "empty session token",
  234. token: "",
  235. wantErr: true,
  236. },
  237. }
  238. for _, tt := range tests {
  239. t.Run(tt.name, func(t *testing.T) {
  240. session, err := service.ValidateSessionToken(ctx, tt.token)
  241. if tt.wantErr {
  242. assert.Error(t, err)
  243. assert.Nil(t, session)
  244. } else {
  245. assert.NoError(t, err)
  246. assert.NotNil(t, session)
  247. assert.Equal(t, "test-session", session.SessionName)
  248. assert.Equal(t, "arn:seaweed:iam::role/TestRole", session.RoleArn)
  249. }
  250. })
  251. }
  252. }
  253. // TestSessionTokenPersistence tests that JWT tokens remain valid throughout their lifetime
  254. // Note: In the stateless JWT design, tokens cannot be revoked and remain valid until expiration
  255. func TestSessionTokenPersistence(t *testing.T) {
  256. service := setupTestSTSService(t)
  257. ctx := context.Background()
  258. // Create a session first
  259. request := &AssumeRoleWithWebIdentityRequest{
  260. RoleArn: "arn:seaweed:iam::role/TestRole",
  261. WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
  262. RoleSessionName: "test-session",
  263. }
  264. response, err := service.AssumeRoleWithWebIdentity(ctx, request)
  265. require.NoError(t, err)
  266. sessionToken := response.Credentials.SessionToken
  267. // Verify token is valid initially
  268. session, err := service.ValidateSessionToken(ctx, sessionToken)
  269. assert.NoError(t, err)
  270. assert.NotNil(t, session)
  271. assert.Equal(t, "test-session", session.SessionName)
  272. // In a stateless JWT system, tokens remain valid throughout their lifetime
  273. // Multiple validations should all succeed as long as the token hasn't expired
  274. session2, err := service.ValidateSessionToken(ctx, sessionToken)
  275. assert.NoError(t, err, "Token should remain valid in stateless system")
  276. assert.NotNil(t, session2, "Session should be returned from JWT token")
  277. assert.Equal(t, session.SessionId, session2.SessionId, "Session ID should be consistent")
  278. }
  279. // Helper functions
  280. func setupTestSTSService(t *testing.T) *STSService {
  281. service := NewSTSService()
  282. config := &STSConfig{
  283. TokenDuration: FlexibleDuration{time.Hour},
  284. MaxSessionLength: FlexibleDuration{time.Hour * 12},
  285. Issuer: "test-sts",
  286. SigningKey: []byte("test-signing-key-32-characters-long"),
  287. }
  288. err := service.Initialize(config)
  289. require.NoError(t, err)
  290. // Set up mock trust policy validator (required for STS testing)
  291. mockValidator := &MockTrustPolicyValidator{}
  292. service.SetTrustPolicyValidator(mockValidator)
  293. // Register test providers
  294. mockOIDCProvider := &MockIdentityProvider{
  295. name: "test-oidc",
  296. validTokens: map[string]*providers.TokenClaims{
  297. createSTSTestJWT(t, "test-issuer", "test-user"): {
  298. Subject: "test-user-id",
  299. Issuer: "test-issuer",
  300. Claims: map[string]interface{}{
  301. "email": "test@example.com",
  302. "name": "Test User",
  303. },
  304. },
  305. },
  306. }
  307. mockLDAPProvider := &MockIdentityProvider{
  308. name: "test-ldap",
  309. validCredentials: map[string]string{
  310. "testuser": "testpass",
  311. },
  312. }
  313. service.RegisterProvider(mockOIDCProvider)
  314. service.RegisterProvider(mockLDAPProvider)
  315. return service
  316. }
  // int64Ptr returns a pointer to a freshly allocated int64 holding v, for
  // filling optional *int64 fields from a literal.
  func int64Ptr(v int64) *int64 {
  	p := new(int64)
  	*p = v
  	return p
  }
  320. // Mock identity provider for testing
  321. type MockIdentityProvider struct {
  322. name string
  323. validTokens map[string]*providers.TokenClaims
  324. validCredentials map[string]string
  325. }
  326. func (m *MockIdentityProvider) Name() string {
  327. return m.name
  328. }
  329. func (m *MockIdentityProvider) GetIssuer() string {
  330. return "test-issuer" // This matches the issuer in the token claims
  331. }
  332. func (m *MockIdentityProvider) Initialize(config interface{}) error {
  333. return nil
  334. }
  335. func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
  336. // First try to parse as JWT token
  337. if len(token) > 20 && strings.Count(token, ".") >= 2 {
  338. parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
  339. if err == nil {
  340. if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
  341. issuer, _ := claims["iss"].(string)
  342. subject, _ := claims["sub"].(string)
  343. // Verify the issuer matches what we expect
  344. if issuer == "test-issuer" && subject != "" {
  345. return &providers.ExternalIdentity{
  346. UserID: subject,
  347. Email: subject + "@test-domain.com",
  348. DisplayName: "Test User " + subject,
  349. Provider: m.name,
  350. }, nil
  351. }
  352. }
  353. }
  354. }
  355. // Handle legacy OIDC tokens (for backwards compatibility)
  356. if claims, exists := m.validTokens[token]; exists {
  357. email, _ := claims.GetClaimString("email")
  358. name, _ := claims.GetClaimString("name")
  359. return &providers.ExternalIdentity{
  360. UserID: claims.Subject,
  361. Email: email,
  362. DisplayName: name,
  363. Provider: m.name,
  364. }, nil
  365. }
  366. // Handle LDAP credentials (username:password format)
  367. if m.validCredentials != nil {
  368. parts := strings.Split(token, ":")
  369. if len(parts) == 2 {
  370. username, password := parts[0], parts[1]
  371. if expectedPassword, exists := m.validCredentials[username]; exists && expectedPassword == password {
  372. return &providers.ExternalIdentity{
  373. UserID: username,
  374. Email: username + "@" + m.name + ".com",
  375. DisplayName: "Test User " + username,
  376. Provider: m.name,
  377. }, nil
  378. }
  379. }
  380. }
  381. return nil, fmt.Errorf("unknown test token: %s", token)
  382. }
  383. func (m *MockIdentityProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
  384. return &providers.ExternalIdentity{
  385. UserID: userID,
  386. Email: userID + "@" + m.name + ".com",
  387. Provider: m.name,
  388. }, nil
  389. }
  390. func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
  391. if claims, exists := m.validTokens[token]; exists {
  392. return claims, nil
  393. }
  394. return nil, fmt.Errorf("invalid token")
  395. }