provider_factory_test.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. package sts
  2. import (
  3. "testing"
  4. "time"
  5. "github.com/stretchr/testify/assert"
  6. "github.com/stretchr/testify/require"
  7. )
  8. func TestProviderFactory_CreateOIDCProvider(t *testing.T) {
  9. factory := NewProviderFactory()
  10. config := &ProviderConfig{
  11. Name: "test-oidc",
  12. Type: "oidc",
  13. Enabled: true,
  14. Config: map[string]interface{}{
  15. "issuer": "https://test-issuer.com",
  16. "clientId": "test-client",
  17. "clientSecret": "test-secret",
  18. "jwksUri": "https://test-issuer.com/.well-known/jwks.json",
  19. "scopes": []string{"openid", "profile", "email"},
  20. },
  21. }
  22. provider, err := factory.CreateProvider(config)
  23. require.NoError(t, err)
  24. assert.NotNil(t, provider)
  25. assert.Equal(t, "test-oidc", provider.Name())
  26. }
  27. // Note: Mock provider tests removed - mock providers are now test-only
  28. // and not available through the production ProviderFactory
  29. func TestProviderFactory_DisabledProvider(t *testing.T) {
  30. factory := NewProviderFactory()
  31. config := &ProviderConfig{
  32. Name: "disabled-provider",
  33. Type: "oidc",
  34. Enabled: false,
  35. Config: map[string]interface{}{
  36. "issuer": "https://test-issuer.com",
  37. "clientId": "test-client",
  38. },
  39. }
  40. provider, err := factory.CreateProvider(config)
  41. require.NoError(t, err)
  42. assert.Nil(t, provider) // Should return nil for disabled providers
  43. }
  44. func TestProviderFactory_InvalidProviderType(t *testing.T) {
  45. factory := NewProviderFactory()
  46. config := &ProviderConfig{
  47. Name: "invalid-provider",
  48. Type: "unsupported-type",
  49. Enabled: true,
  50. Config: map[string]interface{}{},
  51. }
  52. provider, err := factory.CreateProvider(config)
  53. assert.Error(t, err)
  54. assert.Nil(t, provider)
  55. assert.Contains(t, err.Error(), "unsupported provider type")
  56. }
  57. func TestProviderFactory_LoadMultipleProviders(t *testing.T) {
  58. factory := NewProviderFactory()
  59. configs := []*ProviderConfig{
  60. {
  61. Name: "oidc-provider",
  62. Type: "oidc",
  63. Enabled: true,
  64. Config: map[string]interface{}{
  65. "issuer": "https://oidc-issuer.com",
  66. "clientId": "oidc-client",
  67. },
  68. },
  69. {
  70. Name: "disabled-provider",
  71. Type: "oidc",
  72. Enabled: false,
  73. Config: map[string]interface{}{
  74. "issuer": "https://disabled-issuer.com",
  75. "clientId": "disabled-client",
  76. },
  77. },
  78. }
  79. providers, err := factory.LoadProvidersFromConfig(configs)
  80. require.NoError(t, err)
  81. assert.Len(t, providers, 1) // Only enabled providers should be loaded
  82. assert.Contains(t, providers, "oidc-provider")
  83. assert.NotContains(t, providers, "disabled-provider")
  84. }
  85. func TestProviderFactory_ValidateOIDCConfig(t *testing.T) {
  86. factory := NewProviderFactory()
  87. t.Run("valid config", func(t *testing.T) {
  88. config := &ProviderConfig{
  89. Name: "valid-oidc",
  90. Type: "oidc",
  91. Enabled: true,
  92. Config: map[string]interface{}{
  93. "issuer": "https://valid-issuer.com",
  94. "clientId": "valid-client",
  95. },
  96. }
  97. err := factory.ValidateProviderConfig(config)
  98. assert.NoError(t, err)
  99. })
  100. t.Run("missing issuer", func(t *testing.T) {
  101. config := &ProviderConfig{
  102. Name: "invalid-oidc",
  103. Type: "oidc",
  104. Enabled: true,
  105. Config: map[string]interface{}{
  106. "clientId": "valid-client",
  107. },
  108. }
  109. err := factory.ValidateProviderConfig(config)
  110. assert.Error(t, err)
  111. assert.Contains(t, err.Error(), "issuer")
  112. })
  113. t.Run("missing clientId", func(t *testing.T) {
  114. config := &ProviderConfig{
  115. Name: "invalid-oidc",
  116. Type: "oidc",
  117. Enabled: true,
  118. Config: map[string]interface{}{
  119. "issuer": "https://valid-issuer.com",
  120. },
  121. }
  122. err := factory.ValidateProviderConfig(config)
  123. assert.Error(t, err)
  124. assert.Contains(t, err.Error(), "clientId")
  125. })
  126. }
  127. func TestProviderFactory_ConvertToStringSlice(t *testing.T) {
  128. factory := NewProviderFactory()
  129. t.Run("string slice", func(t *testing.T) {
  130. input := []string{"a", "b", "c"}
  131. result, err := factory.convertToStringSlice(input)
  132. require.NoError(t, err)
  133. assert.Equal(t, []string{"a", "b", "c"}, result)
  134. })
  135. t.Run("interface slice", func(t *testing.T) {
  136. input := []interface{}{"a", "b", "c"}
  137. result, err := factory.convertToStringSlice(input)
  138. require.NoError(t, err)
  139. assert.Equal(t, []string{"a", "b", "c"}, result)
  140. })
  141. t.Run("invalid type", func(t *testing.T) {
  142. input := "not-a-slice"
  143. result, err := factory.convertToStringSlice(input)
  144. assert.Error(t, err)
  145. assert.Nil(t, result)
  146. })
  147. }
  148. func TestProviderFactory_ConfigConversionErrors(t *testing.T) {
  149. factory := NewProviderFactory()
  150. t.Run("invalid scopes type", func(t *testing.T) {
  151. config := &ProviderConfig{
  152. Name: "invalid-scopes",
  153. Type: "oidc",
  154. Enabled: true,
  155. Config: map[string]interface{}{
  156. "issuer": "https://test-issuer.com",
  157. "clientId": "test-client",
  158. "scopes": "invalid-not-array", // Should be array
  159. },
  160. }
  161. provider, err := factory.CreateProvider(config)
  162. assert.Error(t, err)
  163. assert.Nil(t, provider)
  164. assert.Contains(t, err.Error(), "failed to convert scopes")
  165. })
  166. t.Run("invalid claimsMapping type", func(t *testing.T) {
  167. config := &ProviderConfig{
  168. Name: "invalid-claims",
  169. Type: "oidc",
  170. Enabled: true,
  171. Config: map[string]interface{}{
  172. "issuer": "https://test-issuer.com",
  173. "clientId": "test-client",
  174. "claimsMapping": "invalid-not-map", // Should be map
  175. },
  176. }
  177. provider, err := factory.CreateProvider(config)
  178. assert.Error(t, err)
  179. assert.Nil(t, provider)
  180. assert.Contains(t, err.Error(), "failed to convert claimsMapping")
  181. })
  182. t.Run("invalid roleMapping type", func(t *testing.T) {
  183. config := &ProviderConfig{
  184. Name: "invalid-roles",
  185. Type: "oidc",
  186. Enabled: true,
  187. Config: map[string]interface{}{
  188. "issuer": "https://test-issuer.com",
  189. "clientId": "test-client",
  190. "roleMapping": "invalid-not-map", // Should be map
  191. },
  192. }
  193. provider, err := factory.CreateProvider(config)
  194. assert.Error(t, err)
  195. assert.Nil(t, provider)
  196. assert.Contains(t, err.Error(), "failed to convert roleMapping")
  197. })
  198. }
  199. func TestProviderFactory_ConvertToStringMap(t *testing.T) {
  200. factory := NewProviderFactory()
  201. t.Run("string map", func(t *testing.T) {
  202. input := map[string]string{"key1": "value1", "key2": "value2"}
  203. result, err := factory.convertToStringMap(input)
  204. require.NoError(t, err)
  205. assert.Equal(t, map[string]string{"key1": "value1", "key2": "value2"}, result)
  206. })
  207. t.Run("interface map", func(t *testing.T) {
  208. input := map[string]interface{}{"key1": "value1", "key2": "value2"}
  209. result, err := factory.convertToStringMap(input)
  210. require.NoError(t, err)
  211. assert.Equal(t, map[string]string{"key1": "value1", "key2": "value2"}, result)
  212. })
  213. t.Run("invalid type", func(t *testing.T) {
  214. input := "not-a-map"
  215. result, err := factory.convertToStringMap(input)
  216. assert.Error(t, err)
  217. assert.Nil(t, result)
  218. })
  219. }
  220. func TestProviderFactory_GetSupportedProviderTypes(t *testing.T) {
  221. factory := NewProviderFactory()
  222. supportedTypes := factory.GetSupportedProviderTypes()
  223. assert.Contains(t, supportedTypes, "oidc")
  224. assert.Len(t, supportedTypes, 1) // Currently only OIDC is supported in production
  225. }
  226. func TestSTSService_LoadProvidersFromConfig(t *testing.T) {
  227. stsConfig := &STSConfig{
  228. TokenDuration: FlexibleDuration{3600 * time.Second},
  229. MaxSessionLength: FlexibleDuration{43200 * time.Second},
  230. Issuer: "test-issuer",
  231. SigningKey: []byte("test-signing-key-32-characters-long"),
  232. Providers: []*ProviderConfig{
  233. {
  234. Name: "test-provider",
  235. Type: "oidc",
  236. Enabled: true,
  237. Config: map[string]interface{}{
  238. "issuer": "https://test-issuer.com",
  239. "clientId": "test-client",
  240. },
  241. },
  242. },
  243. }
  244. stsService := NewSTSService()
  245. err := stsService.Initialize(stsConfig)
  246. require.NoError(t, err)
  247. // Check that provider was loaded
  248. assert.Len(t, stsService.providers, 1)
  249. assert.Contains(t, stsService.providers, "test-provider")
  250. assert.Equal(t, "test-provider", stsService.providers["test-provider"].Name())
  251. }
  252. func TestSTSService_NoProvidersConfig(t *testing.T) {
  253. stsConfig := &STSConfig{
  254. TokenDuration: FlexibleDuration{3600 * time.Second},
  255. MaxSessionLength: FlexibleDuration{43200 * time.Second},
  256. Issuer: "test-issuer",
  257. SigningKey: []byte("test-signing-key-32-characters-long"),
  258. // No providers configured
  259. }
  260. stsService := NewSTSService()
  261. err := stsService.Initialize(stsConfig)
  262. require.NoError(t, err)
  263. // Should initialize successfully with no providers
  264. assert.Len(t, stsService.providers, 0)
  265. }