oidc_provider_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. package oidc
  2. import (
  3. "context"
  4. "crypto/rand"
  5. "crypto/rsa"
  6. "encoding/base64"
  7. "encoding/json"
  8. "net/http"
  9. "net/http/httptest"
  10. "strings"
  11. "testing"
  12. "time"
  13. "github.com/golang-jwt/jwt/v5"
  14. "github.com/seaweedfs/seaweedfs/weed/iam/providers"
  15. "github.com/stretchr/testify/assert"
  16. "github.com/stretchr/testify/require"
  17. )
  18. // TestOIDCProviderInitialization tests OIDC provider initialization
  19. func TestOIDCProviderInitialization(t *testing.T) {
  20. tests := []struct {
  21. name string
  22. config *OIDCConfig
  23. wantErr bool
  24. }{
  25. {
  26. name: "valid config",
  27. config: &OIDCConfig{
  28. Issuer: "https://accounts.google.com",
  29. ClientID: "test-client-id",
  30. JWKSUri: "https://www.googleapis.com/oauth2/v3/certs",
  31. },
  32. wantErr: false,
  33. },
  34. {
  35. name: "missing issuer",
  36. config: &OIDCConfig{
  37. ClientID: "test-client-id",
  38. },
  39. wantErr: true,
  40. },
  41. {
  42. name: "missing client id",
  43. config: &OIDCConfig{
  44. Issuer: "https://accounts.google.com",
  45. },
  46. wantErr: true,
  47. },
  48. {
  49. name: "invalid issuer url",
  50. config: &OIDCConfig{
  51. Issuer: "not-a-url",
  52. ClientID: "test-client-id",
  53. },
  54. wantErr: true,
  55. },
  56. }
  57. for _, tt := range tests {
  58. t.Run(tt.name, func(t *testing.T) {
  59. provider := NewOIDCProvider("test-provider")
  60. err := provider.Initialize(tt.config)
  61. if tt.wantErr {
  62. assert.Error(t, err)
  63. } else {
  64. assert.NoError(t, err)
  65. assert.Equal(t, "test-provider", provider.Name())
  66. }
  67. })
  68. }
  69. }
  70. // TestOIDCProviderJWTValidation tests JWT token validation
  71. func TestOIDCProviderJWTValidation(t *testing.T) {
  72. // Set up test server with JWKS endpoint
  73. privateKey, publicKey := generateTestKeys(t)
  74. jwks := map[string]interface{}{
  75. "keys": []map[string]interface{}{
  76. {
  77. "kty": "RSA",
  78. "kid": "test-key-id",
  79. "use": "sig",
  80. "alg": "RS256",
  81. "n": encodePublicKey(t, publicKey),
  82. "e": "AQAB",
  83. },
  84. },
  85. }
  86. server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  87. if r.URL.Path == "/.well-known/openid_configuration" {
  88. config := map[string]interface{}{
  89. "issuer": "http://" + r.Host,
  90. "jwks_uri": "http://" + r.Host + "/jwks",
  91. }
  92. json.NewEncoder(w).Encode(config)
  93. } else if r.URL.Path == "/jwks" {
  94. json.NewEncoder(w).Encode(jwks)
  95. }
  96. }))
  97. defer server.Close()
  98. provider := NewOIDCProvider("test-oidc")
  99. config := &OIDCConfig{
  100. Issuer: server.URL,
  101. ClientID: "test-client",
  102. JWKSUri: server.URL + "/jwks",
  103. }
  104. err := provider.Initialize(config)
  105. require.NoError(t, err)
  106. t.Run("valid token", func(t *testing.T) {
  107. // Create valid JWT token
  108. token := createTestJWT(t, privateKey, jwt.MapClaims{
  109. "iss": server.URL,
  110. "aud": "test-client",
  111. "sub": "user123",
  112. "exp": time.Now().Add(time.Hour).Unix(),
  113. "iat": time.Now().Unix(),
  114. "email": "user@example.com",
  115. "name": "Test User",
  116. })
  117. claims, err := provider.ValidateToken(context.Background(), token)
  118. require.NoError(t, err)
  119. require.NotNil(t, claims)
  120. assert.Equal(t, "user123", claims.Subject)
  121. assert.Equal(t, server.URL, claims.Issuer)
  122. email, exists := claims.GetClaimString("email")
  123. assert.True(t, exists)
  124. assert.Equal(t, "user@example.com", email)
  125. })
  126. t.Run("valid token with array audience", func(t *testing.T) {
  127. // Create valid JWT token with audience as an array (per RFC 7519)
  128. token := createTestJWT(t, privateKey, jwt.MapClaims{
  129. "iss": server.URL,
  130. "aud": []string{"test-client", "another-client"},
  131. "sub": "user456",
  132. "exp": time.Now().Add(time.Hour).Unix(),
  133. "iat": time.Now().Unix(),
  134. "email": "user2@example.com",
  135. "name": "Test User 2",
  136. })
  137. claims, err := provider.ValidateToken(context.Background(), token)
  138. require.NoError(t, err)
  139. require.NotNil(t, claims)
  140. assert.Equal(t, "user456", claims.Subject)
  141. assert.Equal(t, server.URL, claims.Issuer)
  142. email, exists := claims.GetClaimString("email")
  143. assert.True(t, exists)
  144. assert.Equal(t, "user2@example.com", email)
  145. })
  146. t.Run("expired token", func(t *testing.T) {
  147. // Create expired JWT token
  148. token := createTestJWT(t, privateKey, jwt.MapClaims{
  149. "iss": server.URL,
  150. "aud": "test-client",
  151. "sub": "user123",
  152. "exp": time.Now().Add(-time.Hour).Unix(), // Expired
  153. "iat": time.Now().Add(-time.Hour * 2).Unix(),
  154. })
  155. _, err := provider.ValidateToken(context.Background(), token)
  156. assert.Error(t, err)
  157. assert.Contains(t, err.Error(), "expired")
  158. })
  159. t.Run("invalid signature", func(t *testing.T) {
  160. // Create token with wrong key
  161. wrongKey, _ := generateTestKeys(t)
  162. token := createTestJWT(t, wrongKey, jwt.MapClaims{
  163. "iss": server.URL,
  164. "aud": "test-client",
  165. "sub": "user123",
  166. "exp": time.Now().Add(time.Hour).Unix(),
  167. "iat": time.Now().Unix(),
  168. })
  169. _, err := provider.ValidateToken(context.Background(), token)
  170. assert.Error(t, err)
  171. })
  172. }
  173. // TestOIDCProviderAuthentication tests authentication flow
  174. func TestOIDCProviderAuthentication(t *testing.T) {
  175. // Set up test OIDC provider
  176. privateKey, publicKey := generateTestKeys(t)
  177. server := setupOIDCTestServer(t, publicKey)
  178. defer server.Close()
  179. provider := NewOIDCProvider("test-oidc")
  180. config := &OIDCConfig{
  181. Issuer: server.URL,
  182. ClientID: "test-client",
  183. JWKSUri: server.URL + "/jwks",
  184. RoleMapping: &providers.RoleMapping{
  185. Rules: []providers.MappingRule{
  186. {
  187. Claim: "email",
  188. Value: "*@example.com",
  189. Role: "arn:seaweed:iam::role/UserRole",
  190. },
  191. {
  192. Claim: "groups",
  193. Value: "admins",
  194. Role: "arn:seaweed:iam::role/AdminRole",
  195. },
  196. },
  197. DefaultRole: "arn:seaweed:iam::role/GuestRole",
  198. },
  199. }
  200. err := provider.Initialize(config)
  201. require.NoError(t, err)
  202. t.Run("successful authentication", func(t *testing.T) {
  203. token := createTestJWT(t, privateKey, jwt.MapClaims{
  204. "iss": server.URL,
  205. "aud": "test-client",
  206. "sub": "user123",
  207. "exp": time.Now().Add(time.Hour).Unix(),
  208. "iat": time.Now().Unix(),
  209. "email": "user@example.com",
  210. "name": "Test User",
  211. "groups": []string{"users", "developers"},
  212. })
  213. identity, err := provider.Authenticate(context.Background(), token)
  214. require.NoError(t, err)
  215. require.NotNil(t, identity)
  216. assert.Equal(t, "user123", identity.UserID)
  217. assert.Equal(t, "user@example.com", identity.Email)
  218. assert.Equal(t, "Test User", identity.DisplayName)
  219. assert.Equal(t, "test-oidc", identity.Provider)
  220. assert.Contains(t, identity.Groups, "users")
  221. assert.Contains(t, identity.Groups, "developers")
  222. })
  223. t.Run("authentication with invalid token", func(t *testing.T) {
  224. _, err := provider.Authenticate(context.Background(), "invalid-token")
  225. assert.Error(t, err)
  226. })
  227. }
  228. // TestOIDCProviderUserInfo tests user info retrieval
  229. func TestOIDCProviderUserInfo(t *testing.T) {
  230. // Set up test server with UserInfo endpoint
  231. server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  232. if r.URL.Path == "/userinfo" {
  233. // Check for Authorization header
  234. authHeader := r.Header.Get("Authorization")
  235. if !strings.HasPrefix(authHeader, "Bearer ") {
  236. w.WriteHeader(http.StatusUnauthorized)
  237. w.Write([]byte(`{"error": "unauthorized"}`))
  238. return
  239. }
  240. accessToken := strings.TrimPrefix(authHeader, "Bearer ")
  241. // Return 401 for explicitly invalid tokens
  242. if accessToken == "invalid-token" {
  243. w.WriteHeader(http.StatusUnauthorized)
  244. w.Write([]byte(`{"error": "invalid_token"}`))
  245. return
  246. }
  247. // Mock user info response
  248. userInfo := map[string]interface{}{
  249. "sub": "user123",
  250. "email": "user@example.com",
  251. "name": "Test User",
  252. "groups": []string{"users", "developers"},
  253. }
  254. // Customize response based on token
  255. if strings.Contains(accessToken, "admin") {
  256. userInfo["groups"] = []string{"admins"}
  257. }
  258. w.Header().Set("Content-Type", "application/json")
  259. json.NewEncoder(w).Encode(userInfo)
  260. }
  261. }))
  262. defer server.Close()
  263. provider := NewOIDCProvider("test-oidc")
  264. config := &OIDCConfig{
  265. Issuer: server.URL,
  266. ClientID: "test-client",
  267. UserInfoUri: server.URL + "/userinfo",
  268. }
  269. err := provider.Initialize(config)
  270. require.NoError(t, err)
  271. t.Run("get user info with access token", func(t *testing.T) {
  272. // Test using access token (real UserInfo endpoint call)
  273. identity, err := provider.GetUserInfoWithToken(context.Background(), "valid-access-token")
  274. require.NoError(t, err)
  275. require.NotNil(t, identity)
  276. assert.Equal(t, "user123", identity.UserID)
  277. assert.Equal(t, "user@example.com", identity.Email)
  278. assert.Equal(t, "Test User", identity.DisplayName)
  279. assert.Contains(t, identity.Groups, "users")
  280. assert.Contains(t, identity.Groups, "developers")
  281. assert.Equal(t, "test-oidc", identity.Provider)
  282. })
  283. t.Run("get admin user info", func(t *testing.T) {
  284. // Test admin token response
  285. identity, err := provider.GetUserInfoWithToken(context.Background(), "admin-access-token")
  286. require.NoError(t, err)
  287. require.NotNil(t, identity)
  288. assert.Equal(t, "user123", identity.UserID)
  289. assert.Contains(t, identity.Groups, "admins")
  290. })
  291. t.Run("get user info without token", func(t *testing.T) {
  292. // Test without access token (should fail)
  293. _, err := provider.GetUserInfoWithToken(context.Background(), "")
  294. assert.Error(t, err)
  295. assert.Contains(t, err.Error(), "access token cannot be empty")
  296. })
  297. t.Run("get user info with invalid token", func(t *testing.T) {
  298. // Test with invalid access token (should get 401)
  299. _, err := provider.GetUserInfoWithToken(context.Background(), "invalid-token")
  300. assert.Error(t, err)
  301. assert.Contains(t, err.Error(), "UserInfo endpoint returned status 401")
  302. })
  303. t.Run("get user info with custom claims mapping", func(t *testing.T) {
  304. // Create provider with custom claims mapping
  305. customProvider := NewOIDCProvider("test-custom-oidc")
  306. customConfig := &OIDCConfig{
  307. Issuer: server.URL,
  308. ClientID: "test-client",
  309. UserInfoUri: server.URL + "/userinfo",
  310. ClaimsMapping: map[string]string{
  311. "customEmail": "email",
  312. "customName": "name",
  313. },
  314. }
  315. err := customProvider.Initialize(customConfig)
  316. require.NoError(t, err)
  317. identity, err := customProvider.GetUserInfoWithToken(context.Background(), "valid-access-token")
  318. require.NoError(t, err)
  319. require.NotNil(t, identity)
  320. // Standard claims should still work
  321. assert.Equal(t, "user123", identity.UserID)
  322. assert.Equal(t, "user@example.com", identity.Email)
  323. assert.Equal(t, "Test User", identity.DisplayName)
  324. })
  325. t.Run("get user info with empty id", func(t *testing.T) {
  326. _, err := provider.GetUserInfo(context.Background(), "")
  327. assert.Error(t, err)
  328. })
  329. }
  330. // Helper functions for testing
  331. func generateTestKeys(t *testing.T) (*rsa.PrivateKey, *rsa.PublicKey) {
  332. privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
  333. require.NoError(t, err)
  334. return privateKey, &privateKey.PublicKey
  335. }
  336. func createTestJWT(t *testing.T, privateKey *rsa.PrivateKey, claims jwt.MapClaims) string {
  337. token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
  338. token.Header["kid"] = "test-key-id"
  339. tokenString, err := token.SignedString(privateKey)
  340. require.NoError(t, err)
  341. return tokenString
  342. }
  343. func encodePublicKey(t *testing.T, publicKey *rsa.PublicKey) string {
  344. // Properly encode the RSA modulus (N) as base64url
  345. return base64.RawURLEncoding.EncodeToString(publicKey.N.Bytes())
  346. }
  347. func setupOIDCTestServer(t *testing.T, publicKey *rsa.PublicKey) *httptest.Server {
  348. jwks := map[string]interface{}{
  349. "keys": []map[string]interface{}{
  350. {
  351. "kty": "RSA",
  352. "kid": "test-key-id",
  353. "use": "sig",
  354. "alg": "RS256",
  355. "n": encodePublicKey(t, publicKey),
  356. "e": "AQAB",
  357. },
  358. },
  359. }
  360. return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  361. switch r.URL.Path {
  362. case "/.well-known/openid_configuration":
  363. config := map[string]interface{}{
  364. "issuer": "http://" + r.Host,
  365. "jwks_uri": "http://" + r.Host + "/jwks",
  366. "userinfo_endpoint": "http://" + r.Host + "/userinfo",
  367. }
  368. json.NewEncoder(w).Encode(config)
  369. case "/jwks":
  370. json.NewEncoder(w).Encode(jwks)
  371. case "/userinfo":
  372. // Mock UserInfo endpoint
  373. authHeader := r.Header.Get("Authorization")
  374. if !strings.HasPrefix(authHeader, "Bearer ") {
  375. w.WriteHeader(http.StatusUnauthorized)
  376. w.Write([]byte(`{"error": "unauthorized"}`))
  377. return
  378. }
  379. accessToken := strings.TrimPrefix(authHeader, "Bearer ")
  380. // Return 401 for explicitly invalid tokens
  381. if accessToken == "invalid-token" {
  382. w.WriteHeader(http.StatusUnauthorized)
  383. w.Write([]byte(`{"error": "invalid_token"}`))
  384. return
  385. }
  386. // Mock user info response based on access token
  387. userInfo := map[string]interface{}{
  388. "sub": "user123",
  389. "email": "user@example.com",
  390. "name": "Test User",
  391. "groups": []string{"users", "developers"},
  392. }
  393. // Customize response based on token
  394. if strings.Contains(accessToken, "admin") {
  395. userInfo["groups"] = []string{"admins"}
  396. }
  397. w.Header().Set("Content-Type", "application/json")
  398. json.NewEncoder(w).Encode(userInfo)
  399. default:
  400. http.NotFound(w, r)
  401. }
  402. }))
  403. }