| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460 |
- package oidc
- import (
- "context"
- "crypto/rand"
- "crypto/rsa"
- "encoding/base64"
- "encoding/json"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "time"
- "github.com/golang-jwt/jwt/v5"
- "github.com/seaweedfs/seaweedfs/weed/iam/providers"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- )
- // TestOIDCProviderInitialization tests OIDC provider initialization
- func TestOIDCProviderInitialization(t *testing.T) {
- tests := []struct {
- name string
- config *OIDCConfig
- wantErr bool
- }{
- {
- name: "valid config",
- config: &OIDCConfig{
- Issuer: "https://accounts.google.com",
- ClientID: "test-client-id",
- JWKSUri: "https://www.googleapis.com/oauth2/v3/certs",
- },
- wantErr: false,
- },
- {
- name: "missing issuer",
- config: &OIDCConfig{
- ClientID: "test-client-id",
- },
- wantErr: true,
- },
- {
- name: "missing client id",
- config: &OIDCConfig{
- Issuer: "https://accounts.google.com",
- },
- wantErr: true,
- },
- {
- name: "invalid issuer url",
- config: &OIDCConfig{
- Issuer: "not-a-url",
- ClientID: "test-client-id",
- },
- wantErr: true,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- provider := NewOIDCProvider("test-provider")
- err := provider.Initialize(tt.config)
- if tt.wantErr {
- assert.Error(t, err)
- } else {
- assert.NoError(t, err)
- assert.Equal(t, "test-provider", provider.Name())
- }
- })
- }
- }
- // TestOIDCProviderJWTValidation tests JWT token validation
- func TestOIDCProviderJWTValidation(t *testing.T) {
- // Set up test server with JWKS endpoint
- privateKey, publicKey := generateTestKeys(t)
- jwks := map[string]interface{}{
- "keys": []map[string]interface{}{
- {
- "kty": "RSA",
- "kid": "test-key-id",
- "use": "sig",
- "alg": "RS256",
- "n": encodePublicKey(t, publicKey),
- "e": "AQAB",
- },
- },
- }
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.URL.Path == "/.well-known/openid_configuration" {
- config := map[string]interface{}{
- "issuer": "http://" + r.Host,
- "jwks_uri": "http://" + r.Host + "/jwks",
- }
- json.NewEncoder(w).Encode(config)
- } else if r.URL.Path == "/jwks" {
- json.NewEncoder(w).Encode(jwks)
- }
- }))
- defer server.Close()
- provider := NewOIDCProvider("test-oidc")
- config := &OIDCConfig{
- Issuer: server.URL,
- ClientID: "test-client",
- JWKSUri: server.URL + "/jwks",
- }
- err := provider.Initialize(config)
- require.NoError(t, err)
- t.Run("valid token", func(t *testing.T) {
- // Create valid JWT token
- token := createTestJWT(t, privateKey, jwt.MapClaims{
- "iss": server.URL,
- "aud": "test-client",
- "sub": "user123",
- "exp": time.Now().Add(time.Hour).Unix(),
- "iat": time.Now().Unix(),
- "email": "user@example.com",
- "name": "Test User",
- })
- claims, err := provider.ValidateToken(context.Background(), token)
- require.NoError(t, err)
- require.NotNil(t, claims)
- assert.Equal(t, "user123", claims.Subject)
- assert.Equal(t, server.URL, claims.Issuer)
- email, exists := claims.GetClaimString("email")
- assert.True(t, exists)
- assert.Equal(t, "user@example.com", email)
- })
- t.Run("valid token with array audience", func(t *testing.T) {
- // Create valid JWT token with audience as an array (per RFC 7519)
- token := createTestJWT(t, privateKey, jwt.MapClaims{
- "iss": server.URL,
- "aud": []string{"test-client", "another-client"},
- "sub": "user456",
- "exp": time.Now().Add(time.Hour).Unix(),
- "iat": time.Now().Unix(),
- "email": "user2@example.com",
- "name": "Test User 2",
- })
- claims, err := provider.ValidateToken(context.Background(), token)
- require.NoError(t, err)
- require.NotNil(t, claims)
- assert.Equal(t, "user456", claims.Subject)
- assert.Equal(t, server.URL, claims.Issuer)
- email, exists := claims.GetClaimString("email")
- assert.True(t, exists)
- assert.Equal(t, "user2@example.com", email)
- })
- t.Run("expired token", func(t *testing.T) {
- // Create expired JWT token
- token := createTestJWT(t, privateKey, jwt.MapClaims{
- "iss": server.URL,
- "aud": "test-client",
- "sub": "user123",
- "exp": time.Now().Add(-time.Hour).Unix(), // Expired
- "iat": time.Now().Add(-time.Hour * 2).Unix(),
- })
- _, err := provider.ValidateToken(context.Background(), token)
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "expired")
- })
- t.Run("invalid signature", func(t *testing.T) {
- // Create token with wrong key
- wrongKey, _ := generateTestKeys(t)
- token := createTestJWT(t, wrongKey, jwt.MapClaims{
- "iss": server.URL,
- "aud": "test-client",
- "sub": "user123",
- "exp": time.Now().Add(time.Hour).Unix(),
- "iat": time.Now().Unix(),
- })
- _, err := provider.ValidateToken(context.Background(), token)
- assert.Error(t, err)
- })
- }
- // TestOIDCProviderAuthentication tests authentication flow
- func TestOIDCProviderAuthentication(t *testing.T) {
- // Set up test OIDC provider
- privateKey, publicKey := generateTestKeys(t)
- server := setupOIDCTestServer(t, publicKey)
- defer server.Close()
- provider := NewOIDCProvider("test-oidc")
- config := &OIDCConfig{
- Issuer: server.URL,
- ClientID: "test-client",
- JWKSUri: server.URL + "/jwks",
- RoleMapping: &providers.RoleMapping{
- Rules: []providers.MappingRule{
- {
- Claim: "email",
- Value: "*@example.com",
- Role: "arn:seaweed:iam::role/UserRole",
- },
- {
- Claim: "groups",
- Value: "admins",
- Role: "arn:seaweed:iam::role/AdminRole",
- },
- },
- DefaultRole: "arn:seaweed:iam::role/GuestRole",
- },
- }
- err := provider.Initialize(config)
- require.NoError(t, err)
- t.Run("successful authentication", func(t *testing.T) {
- token := createTestJWT(t, privateKey, jwt.MapClaims{
- "iss": server.URL,
- "aud": "test-client",
- "sub": "user123",
- "exp": time.Now().Add(time.Hour).Unix(),
- "iat": time.Now().Unix(),
- "email": "user@example.com",
- "name": "Test User",
- "groups": []string{"users", "developers"},
- })
- identity, err := provider.Authenticate(context.Background(), token)
- require.NoError(t, err)
- require.NotNil(t, identity)
- assert.Equal(t, "user123", identity.UserID)
- assert.Equal(t, "user@example.com", identity.Email)
- assert.Equal(t, "Test User", identity.DisplayName)
- assert.Equal(t, "test-oidc", identity.Provider)
- assert.Contains(t, identity.Groups, "users")
- assert.Contains(t, identity.Groups, "developers")
- })
- t.Run("authentication with invalid token", func(t *testing.T) {
- _, err := provider.Authenticate(context.Background(), "invalid-token")
- assert.Error(t, err)
- })
- }
- // TestOIDCProviderUserInfo tests user info retrieval
- func TestOIDCProviderUserInfo(t *testing.T) {
- // Set up test server with UserInfo endpoint
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.URL.Path == "/userinfo" {
- // Check for Authorization header
- authHeader := r.Header.Get("Authorization")
- if !strings.HasPrefix(authHeader, "Bearer ") {
- w.WriteHeader(http.StatusUnauthorized)
- w.Write([]byte(`{"error": "unauthorized"}`))
- return
- }
- accessToken := strings.TrimPrefix(authHeader, "Bearer ")
- // Return 401 for explicitly invalid tokens
- if accessToken == "invalid-token" {
- w.WriteHeader(http.StatusUnauthorized)
- w.Write([]byte(`{"error": "invalid_token"}`))
- return
- }
- // Mock user info response
- userInfo := map[string]interface{}{
- "sub": "user123",
- "email": "user@example.com",
- "name": "Test User",
- "groups": []string{"users", "developers"},
- }
- // Customize response based on token
- if strings.Contains(accessToken, "admin") {
- userInfo["groups"] = []string{"admins"}
- }
- w.Header().Set("Content-Type", "application/json")
- json.NewEncoder(w).Encode(userInfo)
- }
- }))
- defer server.Close()
- provider := NewOIDCProvider("test-oidc")
- config := &OIDCConfig{
- Issuer: server.URL,
- ClientID: "test-client",
- UserInfoUri: server.URL + "/userinfo",
- }
- err := provider.Initialize(config)
- require.NoError(t, err)
- t.Run("get user info with access token", func(t *testing.T) {
- // Test using access token (real UserInfo endpoint call)
- identity, err := provider.GetUserInfoWithToken(context.Background(), "valid-access-token")
- require.NoError(t, err)
- require.NotNil(t, identity)
- assert.Equal(t, "user123", identity.UserID)
- assert.Equal(t, "user@example.com", identity.Email)
- assert.Equal(t, "Test User", identity.DisplayName)
- assert.Contains(t, identity.Groups, "users")
- assert.Contains(t, identity.Groups, "developers")
- assert.Equal(t, "test-oidc", identity.Provider)
- })
- t.Run("get admin user info", func(t *testing.T) {
- // Test admin token response
- identity, err := provider.GetUserInfoWithToken(context.Background(), "admin-access-token")
- require.NoError(t, err)
- require.NotNil(t, identity)
- assert.Equal(t, "user123", identity.UserID)
- assert.Contains(t, identity.Groups, "admins")
- })
- t.Run("get user info without token", func(t *testing.T) {
- // Test without access token (should fail)
- _, err := provider.GetUserInfoWithToken(context.Background(), "")
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "access token cannot be empty")
- })
- t.Run("get user info with invalid token", func(t *testing.T) {
- // Test with invalid access token (should get 401)
- _, err := provider.GetUserInfoWithToken(context.Background(), "invalid-token")
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "UserInfo endpoint returned status 401")
- })
- t.Run("get user info with custom claims mapping", func(t *testing.T) {
- // Create provider with custom claims mapping
- customProvider := NewOIDCProvider("test-custom-oidc")
- customConfig := &OIDCConfig{
- Issuer: server.URL,
- ClientID: "test-client",
- UserInfoUri: server.URL + "/userinfo",
- ClaimsMapping: map[string]string{
- "customEmail": "email",
- "customName": "name",
- },
- }
- err := customProvider.Initialize(customConfig)
- require.NoError(t, err)
- identity, err := customProvider.GetUserInfoWithToken(context.Background(), "valid-access-token")
- require.NoError(t, err)
- require.NotNil(t, identity)
- // Standard claims should still work
- assert.Equal(t, "user123", identity.UserID)
- assert.Equal(t, "user@example.com", identity.Email)
- assert.Equal(t, "Test User", identity.DisplayName)
- })
- t.Run("get user info with empty id", func(t *testing.T) {
- _, err := provider.GetUserInfo(context.Background(), "")
- assert.Error(t, err)
- })
- }
- // Helper functions for testing
- func generateTestKeys(t *testing.T) (*rsa.PrivateKey, *rsa.PublicKey) {
- privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
- require.NoError(t, err)
- return privateKey, &privateKey.PublicKey
- }
- func createTestJWT(t *testing.T, privateKey *rsa.PrivateKey, claims jwt.MapClaims) string {
- token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
- token.Header["kid"] = "test-key-id"
- tokenString, err := token.SignedString(privateKey)
- require.NoError(t, err)
- return tokenString
- }
// encodePublicKey returns the RSA modulus N of publicKey as unpadded
// base64url text over its big-endian bytes, the form used for the "n"
// member of the JWKS documents built in this file. The t parameter is
// accepted but not used.
func encodePublicKey(t *testing.T, publicKey *rsa.PublicKey) string {
	modulus := publicKey.N.Bytes()
	encoded := base64.RawURLEncoding.EncodeToString(modulus)
	return encoded
}
- func setupOIDCTestServer(t *testing.T, publicKey *rsa.PublicKey) *httptest.Server {
- jwks := map[string]interface{}{
- "keys": []map[string]interface{}{
- {
- "kty": "RSA",
- "kid": "test-key-id",
- "use": "sig",
- "alg": "RS256",
- "n": encodePublicKey(t, publicKey),
- "e": "AQAB",
- },
- },
- }
- return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- switch r.URL.Path {
- case "/.well-known/openid_configuration":
- config := map[string]interface{}{
- "issuer": "http://" + r.Host,
- "jwks_uri": "http://" + r.Host + "/jwks",
- "userinfo_endpoint": "http://" + r.Host + "/userinfo",
- }
- json.NewEncoder(w).Encode(config)
- case "/jwks":
- json.NewEncoder(w).Encode(jwks)
- case "/userinfo":
- // Mock UserInfo endpoint
- authHeader := r.Header.Get("Authorization")
- if !strings.HasPrefix(authHeader, "Bearer ") {
- w.WriteHeader(http.StatusUnauthorized)
- w.Write([]byte(`{"error": "unauthorized"}`))
- return
- }
- accessToken := strings.TrimPrefix(authHeader, "Bearer ")
- // Return 401 for explicitly invalid tokens
- if accessToken == "invalid-token" {
- w.WriteHeader(http.StatusUnauthorized)
- w.Write([]byte(`{"error": "invalid_token"}`))
- return
- }
- // Mock user info response based on access token
- userInfo := map[string]interface{}{
- "sub": "user123",
- "email": "user@example.com",
- "name": "Test User",
- "groups": []string{"users", "developers"},
- }
- // Customize response based on token
- if strings.Contains(accessToken, "admin") {
- userInfo["groups"] = []string{"admins"}
- }
- w.Header().Set("Content-Type", "application/json")
- json.NewEncoder(w).Encode(userInfo)
- default:
- http.NotFound(w, r)
- }
- }))
- }
|