token_utils.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. package sts
  2. import (
  3. "crypto/rand"
  4. "crypto/sha256"
  5. "encoding/base64"
  6. "encoding/hex"
  7. "fmt"
  8. "time"
  9. "github.com/golang-jwt/jwt/v5"
  10. "github.com/seaweedfs/seaweedfs/weed/iam/utils"
  11. )
  // TokenGenerator signs and verifies STS session JWTs using a shared HMAC key.
type TokenGenerator struct {
	// signingKey is the HMAC secret used both to sign new tokens and to
	// verify incoming ones.
	signingKey []byte
	// issuer fills in an empty issuer claim when signing and is the only
	// issuer value accepted when validating.
	issuer string
}
  17. // NewTokenGenerator creates a new token generator
  18. func NewTokenGenerator(signingKey []byte, issuer string) *TokenGenerator {
  19. return &TokenGenerator{
  20. signingKey: signingKey,
  21. issuer: issuer,
  22. }
  23. }
  24. // GenerateSessionToken creates a signed JWT session token (legacy method for compatibility)
  25. func (t *TokenGenerator) GenerateSessionToken(sessionId string, expiresAt time.Time) (string, error) {
  26. claims := NewSTSSessionClaims(sessionId, t.issuer, expiresAt)
  27. return t.GenerateJWTWithClaims(claims)
  28. }
  29. // GenerateJWTWithClaims creates a signed JWT token with comprehensive session claims
  30. func (t *TokenGenerator) GenerateJWTWithClaims(claims *STSSessionClaims) (string, error) {
  31. if claims == nil {
  32. return "", fmt.Errorf("claims cannot be nil")
  33. }
  34. // Ensure issuer is set from token generator
  35. if claims.Issuer == "" {
  36. claims.Issuer = t.issuer
  37. }
  38. token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
  39. return token.SignedString(t.signingKey)
  40. }
  41. // ValidateSessionToken validates and extracts claims from a session token
  42. func (t *TokenGenerator) ValidateSessionToken(tokenString string) (*SessionTokenClaims, error) {
  43. token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
  44. if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
  45. return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
  46. }
  47. return t.signingKey, nil
  48. })
  49. if err != nil {
  50. return nil, fmt.Errorf(ErrInvalidToken, err)
  51. }
  52. if !token.Valid {
  53. return nil, fmt.Errorf(ErrTokenNotValid)
  54. }
  55. claims, ok := token.Claims.(jwt.MapClaims)
  56. if !ok {
  57. return nil, fmt.Errorf(ErrInvalidTokenClaims)
  58. }
  59. // Verify issuer
  60. if iss, ok := claims[JWTClaimIssuer].(string); !ok || iss != t.issuer {
  61. return nil, fmt.Errorf(ErrInvalidIssuer)
  62. }
  63. // Extract session ID
  64. sessionId, ok := claims[JWTClaimSubject].(string)
  65. if !ok {
  66. return nil, fmt.Errorf(ErrMissingSessionID)
  67. }
  68. return &SessionTokenClaims{
  69. SessionId: sessionId,
  70. ExpiresAt: time.Unix(int64(claims[JWTClaimExpiration].(float64)), 0),
  71. IssuedAt: time.Unix(int64(claims[JWTClaimIssuedAt].(float64)), 0),
  72. }, nil
  73. }
  74. // ValidateJWTWithClaims validates and extracts comprehensive session claims from a JWT token
  75. func (t *TokenGenerator) ValidateJWTWithClaims(tokenString string) (*STSSessionClaims, error) {
  76. token, err := jwt.ParseWithClaims(tokenString, &STSSessionClaims{}, func(token *jwt.Token) (interface{}, error) {
  77. if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
  78. return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
  79. }
  80. return t.signingKey, nil
  81. })
  82. if err != nil {
  83. return nil, fmt.Errorf(ErrInvalidToken, err)
  84. }
  85. if !token.Valid {
  86. return nil, fmt.Errorf(ErrTokenNotValid)
  87. }
  88. claims, ok := token.Claims.(*STSSessionClaims)
  89. if !ok {
  90. return nil, fmt.Errorf(ErrInvalidTokenClaims)
  91. }
  92. // Validate issuer
  93. if claims.Issuer != t.issuer {
  94. return nil, fmt.Errorf(ErrInvalidIssuer)
  95. }
  96. // Validate that required fields are present
  97. if claims.SessionId == "" {
  98. return nil, fmt.Errorf(ErrMissingSessionID)
  99. }
  100. // Additional validation using the claims' own validation method
  101. if !claims.IsValid() {
  102. return nil, fmt.Errorf(ErrTokenNotValid)
  103. }
  104. return claims, nil
  105. }
  // SessionTokenClaims holds the subset of session-token claims produced by
// ValidateSessionToken.
type SessionTokenClaims struct {
	// SessionId is taken from the token's subject claim.
	SessionId string
	// ExpiresAt is the token's expiration claim at whole-second precision.
	ExpiresAt time.Time
	// IssuedAt is the token's issued-at claim at whole-second precision.
	IssuedAt time.Time
}
  // CredentialGenerator produces AWS-compatible temporary credentials: an
// access key ID, a secret access key and a session token. It has no fields,
// so its zero value is ready to use.
type CredentialGenerator struct{}
  114. // NewCredentialGenerator creates a new credential generator
  115. func NewCredentialGenerator() *CredentialGenerator {
  116. return &CredentialGenerator{}
  117. }
  118. // GenerateTemporaryCredentials creates temporary AWS credentials
  119. func (c *CredentialGenerator) GenerateTemporaryCredentials(sessionId string, expiration time.Time) (*Credentials, error) {
  120. accessKeyId, err := c.generateAccessKeyId(sessionId)
  121. if err != nil {
  122. return nil, fmt.Errorf("failed to generate access key ID: %w", err)
  123. }
  124. secretAccessKey, err := c.generateSecretAccessKey()
  125. if err != nil {
  126. return nil, fmt.Errorf("failed to generate secret access key: %w", err)
  127. }
  128. sessionToken, err := c.generateSessionTokenId(sessionId)
  129. if err != nil {
  130. return nil, fmt.Errorf("failed to generate session token: %w", err)
  131. }
  132. return &Credentials{
  133. AccessKeyId: accessKeyId,
  134. SecretAccessKey: secretAccessKey,
  135. SessionToken: sessionToken,
  136. Expiration: expiration,
  137. }, nil
  138. }
  139. // generateAccessKeyId generates an AWS-style access key ID
  140. func (c *CredentialGenerator) generateAccessKeyId(sessionId string) (string, error) {
  141. // Create a deterministic but unique access key ID based on session
  142. hash := sha256.Sum256([]byte("access-key:" + sessionId))
  143. return "AKIA" + hex.EncodeToString(hash[:8]), nil // AWS format: AKIA + 16 chars
  144. }
  145. // generateSecretAccessKey generates a random secret access key
  146. func (c *CredentialGenerator) generateSecretAccessKey() (string, error) {
  147. // Generate 32 random bytes for secret key
  148. secretBytes := make([]byte, 32)
  149. _, err := rand.Read(secretBytes)
  150. if err != nil {
  151. return "", err
  152. }
  153. return base64.StdEncoding.EncodeToString(secretBytes), nil
  154. }
  155. // generateSessionTokenId generates a session token identifier
  156. func (c *CredentialGenerator) generateSessionTokenId(sessionId string) (string, error) {
  157. // Create session token with session ID embedded
  158. hash := sha256.Sum256([]byte("session-token:" + sessionId))
  159. return "ST" + hex.EncodeToString(hash[:16]), nil // Custom format
  160. }
  // GenerateSessionId returns a new random session identifier: 16 bytes from
// crypto/rand, hex-encoded to a 32-character lowercase string. It returns an
// error only if the random source fails.
func GenerateSessionId() (string, error) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}
  170. // generateAssumedRoleArn generates the ARN for an assumed role user
  171. func GenerateAssumedRoleArn(roleArn, sessionName string) string {
  172. // Convert role ARN to assumed role user ARN
  173. // arn:seaweed:iam::role/RoleName -> arn:seaweed:sts::assumed-role/RoleName/SessionName
  174. roleName := utils.ExtractRoleNameFromArn(roleArn)
  175. if roleName == "" {
  176. // This should not happen if validation is done properly upstream
  177. return fmt.Sprintf("arn:seaweed:sts::assumed-role/INVALID-ARN/%s", sessionName)
  178. }
  179. return fmt.Sprintf("arn:seaweed:sts::assumed-role/%s/%s", roleName, sessionName)
  180. }