registry.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package kms
import (
	"context"
	"errors"
	"fmt"
	"sort"
	"sync"

	"github.com/seaweedfs/seaweedfs/weed/util"
)
  9. // ProviderRegistry manages KMS provider implementations
  10. type ProviderRegistry struct {
  11. mu sync.RWMutex
  12. providers map[string]ProviderFactory
  13. instances map[string]KMSProvider
  14. }
  15. // ProviderFactory creates a new KMS provider instance
  16. type ProviderFactory func(config util.Configuration) (KMSProvider, error)
  17. var defaultRegistry = NewProviderRegistry()
  18. // NewProviderRegistry creates a new provider registry
  19. func NewProviderRegistry() *ProviderRegistry {
  20. return &ProviderRegistry{
  21. providers: make(map[string]ProviderFactory),
  22. instances: make(map[string]KMSProvider),
  23. }
  24. }
  25. // RegisterProvider registers a new KMS provider factory
  26. func RegisterProvider(name string, factory ProviderFactory) {
  27. defaultRegistry.RegisterProvider(name, factory)
  28. }
  29. // RegisterProvider registers a new KMS provider factory in this registry
  30. func (r *ProviderRegistry) RegisterProvider(name string, factory ProviderFactory) {
  31. r.mu.Lock()
  32. defer r.mu.Unlock()
  33. r.providers[name] = factory
  34. }
  35. // GetProvider returns a KMS provider instance, creating it if necessary
  36. func GetProvider(name string, config util.Configuration) (KMSProvider, error) {
  37. return defaultRegistry.GetProvider(name, config)
  38. }
  39. // GetProvider returns a KMS provider instance, creating it if necessary
  40. func (r *ProviderRegistry) GetProvider(name string, config util.Configuration) (KMSProvider, error) {
  41. r.mu.Lock()
  42. defer r.mu.Unlock()
  43. // Return existing instance if available
  44. if instance, exists := r.instances[name]; exists {
  45. return instance, nil
  46. }
  47. // Find the factory
  48. factory, exists := r.providers[name]
  49. if !exists {
  50. return nil, fmt.Errorf("KMS provider '%s' not registered", name)
  51. }
  52. // Create new instance
  53. instance, err := factory(config)
  54. if err != nil {
  55. return nil, fmt.Errorf("failed to create KMS provider '%s': %v", name, err)
  56. }
  57. // Cache the instance
  58. r.instances[name] = instance
  59. return instance, nil
  60. }
  61. // ListProviders returns the names of all registered providers
  62. func ListProviders() []string {
  63. return defaultRegistry.ListProviders()
  64. }
  65. // ListProviders returns the names of all registered providers
  66. func (r *ProviderRegistry) ListProviders() []string {
  67. r.mu.RLock()
  68. defer r.mu.RUnlock()
  69. names := make([]string, 0, len(r.providers))
  70. for name := range r.providers {
  71. names = append(names, name)
  72. }
  73. return names
  74. }
  75. // CloseAll closes all provider instances
  76. func CloseAll() error {
  77. return defaultRegistry.CloseAll()
  78. }
  79. // CloseAll closes all provider instances in this registry
  80. func (r *ProviderRegistry) CloseAll() error {
  81. r.mu.Lock()
  82. defer r.mu.Unlock()
  83. var allErrors []error
  84. for name, instance := range r.instances {
  85. if err := instance.Close(); err != nil {
  86. allErrors = append(allErrors, fmt.Errorf("failed to close KMS provider '%s': %w", name, err))
  87. }
  88. }
  89. // Clear the instances map
  90. r.instances = make(map[string]KMSProvider)
  91. return errors.Join(allErrors...)
  92. }
  93. // WithKMSProvider is a helper function to execute code with a KMS provider
  94. func WithKMSProvider(name string, config util.Configuration, fn func(KMSProvider) error) error {
  95. provider, err := GetProvider(name, config)
  96. if err != nil {
  97. return err
  98. }
  99. return fn(provider)
  100. }
  101. // TestKMSConnection tests the connection to a KMS provider
  102. func TestKMSConnection(ctx context.Context, provider KMSProvider, testKeyID string) error {
  103. if provider == nil {
  104. return fmt.Errorf("KMS provider is nil")
  105. }
  106. // Try to describe a test key to verify connectivity
  107. _, err := provider.DescribeKey(ctx, &DescribeKeyRequest{
  108. KeyID: testKeyID,
  109. })
  110. if err != nil {
  111. // If the key doesn't exist, that's still a successful connection test
  112. if kmsErr, ok := err.(*KMSError); ok && kmsErr.Code == ErrCodeNotFoundException {
  113. return nil
  114. }
  115. return fmt.Errorf("KMS connection test failed: %v", err)
  116. }
  117. return nil
  118. }