| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217 |
- package sts
- import (
- "crypto/rand"
- "crypto/sha256"
- "encoding/base64"
- "encoding/hex"
- "fmt"
- "time"
- "github.com/golang-jwt/jwt/v5"
- "github.com/seaweedfs/seaweedfs/weed/iam/utils"
- )
// TokenGenerator mints and verifies HMAC-signed JWT session tokens for a
// single issuer.
type TokenGenerator struct {
	// signingKey is the shared HMAC secret used both to sign and to verify.
	signingKey []byte
	// issuer is stamped into generated tokens and required on validation.
	issuer string
}

// NewTokenGenerator returns a TokenGenerator that signs with signingKey and
// identifies itself as issuer. The key slice is stored as given, not copied.
func NewTokenGenerator(signingKey []byte, issuer string) *TokenGenerator {
	gen := new(TokenGenerator)
	gen.signingKey = signingKey
	gen.issuer = issuer
	return gen
}
- // GenerateSessionToken creates a signed JWT session token (legacy method for compatibility)
- func (t *TokenGenerator) GenerateSessionToken(sessionId string, expiresAt time.Time) (string, error) {
- claims := NewSTSSessionClaims(sessionId, t.issuer, expiresAt)
- return t.GenerateJWTWithClaims(claims)
- }
- // GenerateJWTWithClaims creates a signed JWT token with comprehensive session claims
- func (t *TokenGenerator) GenerateJWTWithClaims(claims *STSSessionClaims) (string, error) {
- if claims == nil {
- return "", fmt.Errorf("claims cannot be nil")
- }
- // Ensure issuer is set from token generator
- if claims.Issuer == "" {
- claims.Issuer = t.issuer
- }
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- return token.SignedString(t.signingKey)
- }
- // ValidateSessionToken validates and extracts claims from a session token
- func (t *TokenGenerator) ValidateSessionToken(tokenString string) (*SessionTokenClaims, error) {
- token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
- if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
- return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
- }
- return t.signingKey, nil
- })
- if err != nil {
- return nil, fmt.Errorf(ErrInvalidToken, err)
- }
- if !token.Valid {
- return nil, fmt.Errorf(ErrTokenNotValid)
- }
- claims, ok := token.Claims.(jwt.MapClaims)
- if !ok {
- return nil, fmt.Errorf(ErrInvalidTokenClaims)
- }
- // Verify issuer
- if iss, ok := claims[JWTClaimIssuer].(string); !ok || iss != t.issuer {
- return nil, fmt.Errorf(ErrInvalidIssuer)
- }
- // Extract session ID
- sessionId, ok := claims[JWTClaimSubject].(string)
- if !ok {
- return nil, fmt.Errorf(ErrMissingSessionID)
- }
- return &SessionTokenClaims{
- SessionId: sessionId,
- ExpiresAt: time.Unix(int64(claims[JWTClaimExpiration].(float64)), 0),
- IssuedAt: time.Unix(int64(claims[JWTClaimIssuedAt].(float64)), 0),
- }, nil
- }
- // ValidateJWTWithClaims validates and extracts comprehensive session claims from a JWT token
- func (t *TokenGenerator) ValidateJWTWithClaims(tokenString string) (*STSSessionClaims, error) {
- token, err := jwt.ParseWithClaims(tokenString, &STSSessionClaims{}, func(token *jwt.Token) (interface{}, error) {
- if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
- return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
- }
- return t.signingKey, nil
- })
- if err != nil {
- return nil, fmt.Errorf(ErrInvalidToken, err)
- }
- if !token.Valid {
- return nil, fmt.Errorf(ErrTokenNotValid)
- }
- claims, ok := token.Claims.(*STSSessionClaims)
- if !ok {
- return nil, fmt.Errorf(ErrInvalidTokenClaims)
- }
- // Validate issuer
- if claims.Issuer != t.issuer {
- return nil, fmt.Errorf(ErrInvalidIssuer)
- }
- // Validate that required fields are present
- if claims.SessionId == "" {
- return nil, fmt.Errorf(ErrMissingSessionID)
- }
- // Additional validation using the claims' own validation method
- if !claims.IsValid() {
- return nil, fmt.Errorf(ErrTokenNotValid)
- }
- return claims, nil
- }
// SessionTokenClaims is the reduced set of claims that ValidateSessionToken
// extracts from a verified session token.
type SessionTokenClaims struct {
	// SessionId is read from the token's subject claim.
	SessionId string
	// ExpiresAt is the token's expiration time.
	ExpiresAt time.Time
	// IssuedAt is the time the token was issued.
	IssuedAt  time.Time
}
// CredentialGenerator produces AWS-style temporary credentials (access key ID,
// secret access key, session token). It carries no state.
type CredentialGenerator struct{}

// NewCredentialGenerator returns a ready-to-use CredentialGenerator.
func NewCredentialGenerator() *CredentialGenerator {
	gen := new(CredentialGenerator)
	return gen
}
- // GenerateTemporaryCredentials creates temporary AWS credentials
- func (c *CredentialGenerator) GenerateTemporaryCredentials(sessionId string, expiration time.Time) (*Credentials, error) {
- accessKeyId, err := c.generateAccessKeyId(sessionId)
- if err != nil {
- return nil, fmt.Errorf("failed to generate access key ID: %w", err)
- }
- secretAccessKey, err := c.generateSecretAccessKey()
- if err != nil {
- return nil, fmt.Errorf("failed to generate secret access key: %w", err)
- }
- sessionToken, err := c.generateSessionTokenId(sessionId)
- if err != nil {
- return nil, fmt.Errorf("failed to generate session token: %w", err)
- }
- return &Credentials{
- AccessKeyId: accessKeyId,
- SecretAccessKey: secretAccessKey,
- SessionToken: sessionToken,
- Expiration: expiration,
- }, nil
- }
- // generateAccessKeyId generates an AWS-style access key ID
- func (c *CredentialGenerator) generateAccessKeyId(sessionId string) (string, error) {
- // Create a deterministic but unique access key ID based on session
- hash := sha256.Sum256([]byte("access-key:" + sessionId))
- return "AKIA" + hex.EncodeToString(hash[:8]), nil // AWS format: AKIA + 16 chars
- }
- // generateSecretAccessKey generates a random secret access key
- func (c *CredentialGenerator) generateSecretAccessKey() (string, error) {
- // Generate 32 random bytes for secret key
- secretBytes := make([]byte, 32)
- _, err := rand.Read(secretBytes)
- if err != nil {
- return "", err
- }
- return base64.StdEncoding.EncodeToString(secretBytes), nil
- }
- // generateSessionTokenId generates a session token identifier
- func (c *CredentialGenerator) generateSessionTokenId(sessionId string) (string, error) {
- // Create session token with session ID embedded
- hash := sha256.Sum256([]byte("session-token:" + sessionId))
- return "ST" + hex.EncodeToString(hash[:16]), nil // Custom format
- }
// GenerateSessionId returns a new random session identifier: 16 bytes from
// crypto/rand, hex-encoded as a 32-character lowercase string.
//
// The returned error, if any, wraps the failure from the random source.
func GenerateSessionId() (string, error) {
	randomBytes := make([]byte, 16)
	if _, err := rand.Read(randomBytes); err != nil {
		return "", fmt.Errorf("generating session id: %w", err)
	}
	return hex.EncodeToString(randomBytes), nil
}
- // generateAssumedRoleArn generates the ARN for an assumed role user
- func GenerateAssumedRoleArn(roleArn, sessionName string) string {
- // Convert role ARN to assumed role user ARN
- // arn:seaweed:iam::role/RoleName -> arn:seaweed:sts::assumed-role/RoleName/SessionName
- roleName := utils.ExtractRoleNameFromArn(roleArn)
- if roleName == "" {
- // This should not happen if validation is done properly upstream
- return fmt.Sprintf("arn:seaweed:sts::assumed-role/INVALID-ARN/%s", sessionName)
- }
- return fmt.Sprintf("arn:seaweed:sts::assumed-role/%s/%s", roleName, sessionName)
- }
|