| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- package kms
- import (
- "context"
- "errors"
- "fmt"
- "sync"
- "github.com/seaweedfs/seaweedfs/weed/util"
- )
- // ProviderRegistry manages KMS provider implementations
- type ProviderRegistry struct {
- mu sync.RWMutex
- providers map[string]ProviderFactory
- instances map[string]KMSProvider
- }
- // ProviderFactory creates a new KMS provider instance
- type ProviderFactory func(config util.Configuration) (KMSProvider, error)
- var defaultRegistry = NewProviderRegistry()
- // NewProviderRegistry creates a new provider registry
- func NewProviderRegistry() *ProviderRegistry {
- return &ProviderRegistry{
- providers: make(map[string]ProviderFactory),
- instances: make(map[string]KMSProvider),
- }
- }
- // RegisterProvider registers a new KMS provider factory
- func RegisterProvider(name string, factory ProviderFactory) {
- defaultRegistry.RegisterProvider(name, factory)
- }
- // RegisterProvider registers a new KMS provider factory in this registry
- func (r *ProviderRegistry) RegisterProvider(name string, factory ProviderFactory) {
- r.mu.Lock()
- defer r.mu.Unlock()
- r.providers[name] = factory
- }
- // GetProvider returns a KMS provider instance, creating it if necessary
- func GetProvider(name string, config util.Configuration) (KMSProvider, error) {
- return defaultRegistry.GetProvider(name, config)
- }
- // GetProvider returns a KMS provider instance, creating it if necessary
- func (r *ProviderRegistry) GetProvider(name string, config util.Configuration) (KMSProvider, error) {
- r.mu.Lock()
- defer r.mu.Unlock()
- // Return existing instance if available
- if instance, exists := r.instances[name]; exists {
- return instance, nil
- }
- // Find the factory
- factory, exists := r.providers[name]
- if !exists {
- return nil, fmt.Errorf("KMS provider '%s' not registered", name)
- }
- // Create new instance
- instance, err := factory(config)
- if err != nil {
- return nil, fmt.Errorf("failed to create KMS provider '%s': %v", name, err)
- }
- // Cache the instance
- r.instances[name] = instance
- return instance, nil
- }
- // ListProviders returns the names of all registered providers
- func ListProviders() []string {
- return defaultRegistry.ListProviders()
- }
- // ListProviders returns the names of all registered providers
- func (r *ProviderRegistry) ListProviders() []string {
- r.mu.RLock()
- defer r.mu.RUnlock()
- names := make([]string, 0, len(r.providers))
- for name := range r.providers {
- names = append(names, name)
- }
- return names
- }
- // CloseAll closes all provider instances
- func CloseAll() error {
- return defaultRegistry.CloseAll()
- }
- // CloseAll closes all provider instances in this registry
- func (r *ProviderRegistry) CloseAll() error {
- r.mu.Lock()
- defer r.mu.Unlock()
- var allErrors []error
- for name, instance := range r.instances {
- if err := instance.Close(); err != nil {
- allErrors = append(allErrors, fmt.Errorf("failed to close KMS provider '%s': %w", name, err))
- }
- }
- // Clear the instances map
- r.instances = make(map[string]KMSProvider)
- return errors.Join(allErrors...)
- }
- // WithKMSProvider is a helper function to execute code with a KMS provider
- func WithKMSProvider(name string, config util.Configuration, fn func(KMSProvider) error) error {
- provider, err := GetProvider(name, config)
- if err != nil {
- return err
- }
- return fn(provider)
- }
- // TestKMSConnection tests the connection to a KMS provider
- func TestKMSConnection(ctx context.Context, provider KMSProvider, testKeyID string) error {
- if provider == nil {
- return fmt.Errorf("KMS provider is nil")
- }
- // Try to describe a test key to verify connectivity
- _, err := provider.DescribeKey(ctx, &DescribeKeyRequest{
- KeyID: testKeyID,
- })
- if err != nil {
- // If the key doesn't exist, that's still a successful connection test
- if kmsErr, ok := err.(*KMSError); ok && kmsErr.Code == ErrCodeNotFoundException {
- return nil
- }
- return fmt.Errorf("KMS connection test failed: %v", err)
- }
- return nil
- }
|