local_kms.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568
  1. package local
  2. import (
  3. "context"
  4. "crypto/aes"
  5. "crypto/cipher"
  6. "crypto/rand"
  7. "encoding/json"
  8. "fmt"
  9. "io"
  10. "sort"
  11. "strings"
  12. "sync"
  13. "time"
  14. "github.com/seaweedfs/seaweedfs/weed/glog"
  15. "github.com/seaweedfs/seaweedfs/weed/kms"
  16. "github.com/seaweedfs/seaweedfs/weed/util"
  17. )
// LocalKMSProvider implements a local, in-memory KMS for development and testing.
// WARNING: This is NOT suitable for production use - keys are stored in memory.
type LocalKMSProvider struct {
	// mu guards access to the keys map.
	mu sync.RWMutex
	// keys indexes every key by its key ID and, additionally, by each
	// registered alias; alias entries point at the same *LocalKey.
	keys map[string]*LocalKey
	// defaultKeyID is the key used when a lookup is made with an empty identifier.
	defaultKeyID string
	enableOnDemandCreate bool // Whether to create keys on-demand for missing key IDs
}
// LocalKey represents a key stored in the local KMS.
// The same *LocalKey may be reachable from the provider's key map under its
// KeyID and under each of its aliases.
type LocalKey struct {
	// KeyID is the identifier the key is registered under.
	KeyID string `json:"keyId"`
	// ARN is an ARN-style name; keys created in this file use
	// "arn:aws:kms:local:000000000000:key/<KeyID>".
	ARN string `json:"arn"`
	Description string `json:"description"`
	KeyMaterial []byte `json:"keyMaterial"` // 256-bit master key
	KeyUsage kms.KeyUsage `json:"keyUsage"`
	// KeyState must be kms.KeyStateEnabled for the key to be usable by
	// GenerateDataKey and Decrypt.
	KeyState kms.KeyState `json:"keyState"`
	Origin kms.KeyOrigin `json:"origin"`
	CreatedAt time.Time `json:"createdAt"`
	// Aliases lists the alias names recorded for this key.
	Aliases []string `json:"aliases"`
	Metadata map[string]string `json:"metadata"`
}
// LocalKMSConfig contains configuration for the local KMS provider.
// NOTE(review): nothing in the visible part of this file reads this struct;
// loadConfig pulls "enableOnDemandCreate" straight from util.Configuration.
// Confirm whether it is consumed elsewhere.
type LocalKMSConfig struct {
	DefaultKeyID string `json:"defaultKeyId"`
	Keys map[string]*LocalKey `json:"keys"`
	EnableOnDemandCreate bool `json:"enableOnDemandCreate"`
}
// init registers NewLocalKMSProvider with the kms package under the
// provider name "local".
func init() {
	// Register the local KMS provider
	kms.RegisterProvider("local", NewLocalKMSProvider)
}
  49. // NewLocalKMSProvider creates a new local KMS provider
  50. func NewLocalKMSProvider(config util.Configuration) (kms.KMSProvider, error) {
  51. provider := &LocalKMSProvider{
  52. keys: make(map[string]*LocalKey),
  53. enableOnDemandCreate: true, // Default to true for development/testing convenience
  54. }
  55. // Load configuration if provided
  56. if config != nil {
  57. if err := provider.loadConfig(config); err != nil {
  58. return nil, fmt.Errorf("failed to load local KMS config: %v", err)
  59. }
  60. }
  61. // Create a default key if none exists
  62. if len(provider.keys) == 0 {
  63. defaultKey, err := provider.createDefaultKey()
  64. if err != nil {
  65. return nil, fmt.Errorf("failed to create default key: %v", err)
  66. }
  67. provider.defaultKeyID = defaultKey.KeyID
  68. glog.V(1).Infof("Local KMS: Created default key %s", defaultKey.KeyID)
  69. }
  70. return provider, nil
  71. }
  72. // loadConfig loads configuration from the provided config
  73. func (p *LocalKMSProvider) loadConfig(config util.Configuration) error {
  74. if config == nil {
  75. return nil
  76. }
  77. p.enableOnDemandCreate = config.GetBool("enableOnDemandCreate")
  78. // TODO: Load pre-existing keys from configuration if provided
  79. // For now, rely on default key creation in constructor
  80. glog.V(2).Infof("Local KMS: enableOnDemandCreate = %v", p.enableOnDemandCreate)
  81. return nil
  82. }
  83. // createDefaultKey creates a default master key for the local KMS
  84. func (p *LocalKMSProvider) createDefaultKey() (*LocalKey, error) {
  85. keyID, err := generateKeyID()
  86. if err != nil {
  87. return nil, fmt.Errorf("failed to generate key ID: %w", err)
  88. }
  89. keyMaterial := make([]byte, 32) // 256-bit key
  90. if _, err := io.ReadFull(rand.Reader, keyMaterial); err != nil {
  91. return nil, fmt.Errorf("failed to generate key material: %w", err)
  92. }
  93. key := &LocalKey{
  94. KeyID: keyID,
  95. ARN: fmt.Sprintf("arn:aws:kms:local:000000000000:key/%s", keyID),
  96. Description: "Default local KMS key for SeaweedFS",
  97. KeyMaterial: keyMaterial,
  98. KeyUsage: kms.KeyUsageEncryptDecrypt,
  99. KeyState: kms.KeyStateEnabled,
  100. Origin: kms.KeyOriginLocal,
  101. CreatedAt: time.Now(),
  102. Aliases: []string{"alias/seaweedfs-default"},
  103. Metadata: make(map[string]string),
  104. }
  105. p.mu.Lock()
  106. defer p.mu.Unlock()
  107. p.keys[keyID] = key
  108. // Also register aliases
  109. for _, alias := range key.Aliases {
  110. p.keys[alias] = key
  111. }
  112. return key, nil
  113. }
  114. // GenerateDataKey implements the KMSProvider interface
  115. func (p *LocalKMSProvider) GenerateDataKey(ctx context.Context, req *kms.GenerateDataKeyRequest) (*kms.GenerateDataKeyResponse, error) {
  116. if req.KeySpec != kms.KeySpecAES256 {
  117. return nil, &kms.KMSError{
  118. Code: kms.ErrCodeInvalidKeyUsage,
  119. Message: fmt.Sprintf("Unsupported key spec: %s", req.KeySpec),
  120. KeyID: req.KeyID,
  121. }
  122. }
  123. // Resolve the key
  124. key, err := p.getKey(req.KeyID)
  125. if err != nil {
  126. return nil, err
  127. }
  128. if key.KeyState != kms.KeyStateEnabled {
  129. return nil, &kms.KMSError{
  130. Code: kms.ErrCodeKeyUnavailable,
  131. Message: fmt.Sprintf("Key %s is in state %s", key.KeyID, key.KeyState),
  132. KeyID: key.KeyID,
  133. }
  134. }
  135. // Generate a random 256-bit data key
  136. dataKey := make([]byte, 32)
  137. if _, err := io.ReadFull(rand.Reader, dataKey); err != nil {
  138. return nil, &kms.KMSError{
  139. Code: kms.ErrCodeKMSInternalFailure,
  140. Message: "Failed to generate data key",
  141. KeyID: key.KeyID,
  142. }
  143. }
  144. // Encrypt the data key with the master key
  145. encryptedDataKey, err := p.encryptDataKey(dataKey, key, req.EncryptionContext)
  146. if err != nil {
  147. kms.ClearSensitiveData(dataKey)
  148. return nil, &kms.KMSError{
  149. Code: kms.ErrCodeKMSInternalFailure,
  150. Message: fmt.Sprintf("Failed to encrypt data key: %v", err),
  151. KeyID: key.KeyID,
  152. }
  153. }
  154. return &kms.GenerateDataKeyResponse{
  155. KeyID: key.KeyID,
  156. Plaintext: dataKey,
  157. CiphertextBlob: encryptedDataKey,
  158. }, nil
  159. }
  160. // Decrypt implements the KMSProvider interface
  161. func (p *LocalKMSProvider) Decrypt(ctx context.Context, req *kms.DecryptRequest) (*kms.DecryptResponse, error) {
  162. // Parse the encrypted data key to extract metadata
  163. metadata, err := p.parseEncryptedDataKey(req.CiphertextBlob)
  164. if err != nil {
  165. return nil, &kms.KMSError{
  166. Code: kms.ErrCodeInvalidCiphertext,
  167. Message: fmt.Sprintf("Invalid ciphertext format: %v", err),
  168. }
  169. }
  170. // Verify encryption context matches
  171. if !p.encryptionContextMatches(metadata.EncryptionContext, req.EncryptionContext) {
  172. return nil, &kms.KMSError{
  173. Code: kms.ErrCodeInvalidCiphertext,
  174. Message: "Encryption context mismatch",
  175. KeyID: metadata.KeyID,
  176. }
  177. }
  178. // Get the master key
  179. key, err := p.getKey(metadata.KeyID)
  180. if err != nil {
  181. return nil, err
  182. }
  183. if key.KeyState != kms.KeyStateEnabled {
  184. return nil, &kms.KMSError{
  185. Code: kms.ErrCodeKeyUnavailable,
  186. Message: fmt.Sprintf("Key %s is in state %s", key.KeyID, key.KeyState),
  187. KeyID: key.KeyID,
  188. }
  189. }
  190. // Decrypt the data key
  191. dataKey, err := p.decryptDataKey(metadata, key)
  192. if err != nil {
  193. return nil, &kms.KMSError{
  194. Code: kms.ErrCodeInvalidCiphertext,
  195. Message: fmt.Sprintf("Failed to decrypt data key: %v", err),
  196. KeyID: key.KeyID,
  197. }
  198. }
  199. return &kms.DecryptResponse{
  200. KeyID: key.KeyID,
  201. Plaintext: dataKey,
  202. }, nil
  203. }
  204. // DescribeKey implements the KMSProvider interface
  205. func (p *LocalKMSProvider) DescribeKey(ctx context.Context, req *kms.DescribeKeyRequest) (*kms.DescribeKeyResponse, error) {
  206. key, err := p.getKey(req.KeyID)
  207. if err != nil {
  208. return nil, err
  209. }
  210. return &kms.DescribeKeyResponse{
  211. KeyID: key.KeyID,
  212. ARN: key.ARN,
  213. Description: key.Description,
  214. KeyUsage: key.KeyUsage,
  215. KeyState: key.KeyState,
  216. Origin: key.Origin,
  217. }, nil
  218. }
  219. // GetKeyID implements the KMSProvider interface
  220. func (p *LocalKMSProvider) GetKeyID(ctx context.Context, keyIdentifier string) (string, error) {
  221. key, err := p.getKey(keyIdentifier)
  222. if err != nil {
  223. return "", err
  224. }
  225. return key.KeyID, nil
  226. }
  227. // Close implements the KMSProvider interface
  228. func (p *LocalKMSProvider) Close() error {
  229. p.mu.Lock()
  230. defer p.mu.Unlock()
  231. // Clear all key material from memory
  232. for _, key := range p.keys {
  233. kms.ClearSensitiveData(key.KeyMaterial)
  234. }
  235. p.keys = make(map[string]*LocalKey)
  236. return nil
  237. }
  238. // getKey retrieves a key by ID or alias, creating it on-demand if it doesn't exist
  239. func (p *LocalKMSProvider) getKey(keyIdentifier string) (*LocalKey, error) {
  240. p.mu.RLock()
  241. // Try direct lookup first
  242. if key, exists := p.keys[keyIdentifier]; exists {
  243. p.mu.RUnlock()
  244. return key, nil
  245. }
  246. // Try with default key if no identifier provided
  247. if keyIdentifier == "" && p.defaultKeyID != "" {
  248. if key, exists := p.keys[p.defaultKeyID]; exists {
  249. p.mu.RUnlock()
  250. return key, nil
  251. }
  252. }
  253. p.mu.RUnlock()
  254. // Key doesn't exist - create on-demand if enabled and key identifier is reasonable
  255. if keyIdentifier != "" && p.enableOnDemandCreate && p.isReasonableKeyIdentifier(keyIdentifier) {
  256. glog.V(1).Infof("Creating on-demand local KMS key: %s", keyIdentifier)
  257. key, err := p.CreateKeyWithID(keyIdentifier, fmt.Sprintf("Auto-created local KMS key: %s", keyIdentifier))
  258. if err != nil {
  259. return nil, &kms.KMSError{
  260. Code: kms.ErrCodeKMSInternalFailure,
  261. Message: fmt.Sprintf("Failed to create on-demand key %s: %v", keyIdentifier, err),
  262. KeyID: keyIdentifier,
  263. }
  264. }
  265. return key, nil
  266. }
  267. return nil, &kms.KMSError{
  268. Code: kms.ErrCodeNotFoundException,
  269. Message: fmt.Sprintf("Key not found: %s", keyIdentifier),
  270. KeyID: keyIdentifier,
  271. }
  272. }
  273. // isReasonableKeyIdentifier determines if a key identifier is reasonable for on-demand creation
  274. func (p *LocalKMSProvider) isReasonableKeyIdentifier(keyIdentifier string) bool {
  275. // Basic validation: reasonable length and character set
  276. if len(keyIdentifier) < 3 || len(keyIdentifier) > 100 {
  277. return false
  278. }
  279. // Allow alphanumeric characters, hyphens, underscores, and forward slashes
  280. // This covers most reasonable key identifier formats without being overly restrictive
  281. for _, r := range keyIdentifier {
  282. if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
  283. (r >= '0' && r <= '9') || r == '-' || r == '_' || r == '/') {
  284. return false
  285. }
  286. }
  287. // Reject keys that start or end with separators
  288. if keyIdentifier[0] == '-' || keyIdentifier[0] == '_' || keyIdentifier[0] == '/' ||
  289. keyIdentifier[len(keyIdentifier)-1] == '-' || keyIdentifier[len(keyIdentifier)-1] == '_' || keyIdentifier[len(keyIdentifier)-1] == '/' {
  290. return false
  291. }
  292. return true
  293. }
// encryptedDataKeyMetadata represents the metadata stored with encrypted data keys.
// Its JSON encoding is the ciphertext blob produced by encryptDataKey and
// consumed by parseEncryptedDataKey.
type encryptedDataKeyMetadata struct {
	// KeyID identifies the master key the data key was encrypted under.
	KeyID string `json:"keyId"`
	// EncryptionContext is the context supplied at encryption time; when
	// non-empty it is also fed to AES-GCM as additional authenticated data.
	EncryptionContext map[string]string `json:"encryptionContext"`
	// EncryptedData is the AES-GCM output for the data key.
	EncryptedData []byte `json:"encryptedData"`
	Nonce []byte `json:"nonce"` // Renamed from IV to be more explicit about AES-GCM usage
}
  301. // encryptDataKey encrypts a data key using the master key with AES-GCM for authenticated encryption
  302. func (p *LocalKMSProvider) encryptDataKey(dataKey []byte, masterKey *LocalKey, encryptionContext map[string]string) ([]byte, error) {
  303. block, err := aes.NewCipher(masterKey.KeyMaterial)
  304. if err != nil {
  305. return nil, err
  306. }
  307. gcm, err := cipher.NewGCM(block)
  308. if err != nil {
  309. return nil, err
  310. }
  311. // Generate a random nonce
  312. nonce := make([]byte, gcm.NonceSize())
  313. if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
  314. return nil, err
  315. }
  316. // Prepare additional authenticated data (AAD) from the encryption context
  317. // Use deterministic marshaling to ensure consistent AAD
  318. var aad []byte
  319. if len(encryptionContext) > 0 {
  320. var err error
  321. aad, err = marshalEncryptionContextDeterministic(encryptionContext)
  322. if err != nil {
  323. return nil, fmt.Errorf("failed to marshal encryption context for AAD: %w", err)
  324. }
  325. }
  326. // Encrypt using AES-GCM
  327. encryptedData := gcm.Seal(nil, nonce, dataKey, aad)
  328. // Create metadata structure
  329. metadata := &encryptedDataKeyMetadata{
  330. KeyID: masterKey.KeyID,
  331. EncryptionContext: encryptionContext,
  332. EncryptedData: encryptedData,
  333. Nonce: nonce,
  334. }
  335. // Serialize metadata to JSON
  336. return json.Marshal(metadata)
  337. }
  338. // decryptDataKey decrypts a data key using the master key with AES-GCM for authenticated decryption
  339. func (p *LocalKMSProvider) decryptDataKey(metadata *encryptedDataKeyMetadata, masterKey *LocalKey) ([]byte, error) {
  340. block, err := aes.NewCipher(masterKey.KeyMaterial)
  341. if err != nil {
  342. return nil, err
  343. }
  344. gcm, err := cipher.NewGCM(block)
  345. if err != nil {
  346. return nil, err
  347. }
  348. // Prepare additional authenticated data (AAD)
  349. var aad []byte
  350. if len(metadata.EncryptionContext) > 0 {
  351. var err error
  352. aad, err = marshalEncryptionContextDeterministic(metadata.EncryptionContext)
  353. if err != nil {
  354. return nil, fmt.Errorf("failed to marshal encryption context for AAD: %w", err)
  355. }
  356. }
  357. // Decrypt using AES-GCM
  358. nonce := metadata.Nonce
  359. if len(nonce) != gcm.NonceSize() {
  360. return nil, fmt.Errorf("invalid nonce size: expected %d, got %d", gcm.NonceSize(), len(nonce))
  361. }
  362. dataKey, err := gcm.Open(nil, nonce, metadata.EncryptedData, aad)
  363. if err != nil {
  364. return nil, fmt.Errorf("failed to decrypt with GCM: %w", err)
  365. }
  366. return dataKey, nil
  367. }
  368. // parseEncryptedDataKey parses the encrypted data key blob
  369. func (p *LocalKMSProvider) parseEncryptedDataKey(ciphertextBlob []byte) (*encryptedDataKeyMetadata, error) {
  370. var metadata encryptedDataKeyMetadata
  371. if err := json.Unmarshal(ciphertextBlob, &metadata); err != nil {
  372. return nil, fmt.Errorf("failed to parse ciphertext blob: %v", err)
  373. }
  374. return &metadata, nil
  375. }
  376. // encryptionContextMatches checks if two encryption contexts match
  377. func (p *LocalKMSProvider) encryptionContextMatches(ctx1, ctx2 map[string]string) bool {
  378. if len(ctx1) != len(ctx2) {
  379. return false
  380. }
  381. for k, v := range ctx1 {
  382. if ctx2[k] != v {
  383. return false
  384. }
  385. }
  386. return true
  387. }
  388. // generateKeyID generates a random key ID
  389. func generateKeyID() (string, error) {
  390. // Generate a UUID-like key ID
  391. b := make([]byte, 16)
  392. if _, err := io.ReadFull(rand.Reader, b); err != nil {
  393. return "", fmt.Errorf("failed to generate random bytes for key ID: %w", err)
  394. }
  395. return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
  396. b[0:4], b[4:6], b[6:8], b[8:10], b[10:16]), nil
  397. }
  398. // CreateKey creates a new key in the local KMS (for testing)
  399. func (p *LocalKMSProvider) CreateKey(description string, aliases []string) (*LocalKey, error) {
  400. keyID, err := generateKeyID()
  401. if err != nil {
  402. return nil, fmt.Errorf("failed to generate key ID: %w", err)
  403. }
  404. keyMaterial := make([]byte, 32)
  405. if _, err := io.ReadFull(rand.Reader, keyMaterial); err != nil {
  406. return nil, err
  407. }
  408. key := &LocalKey{
  409. KeyID: keyID,
  410. ARN: fmt.Sprintf("arn:aws:kms:local:000000000000:key/%s", keyID),
  411. Description: description,
  412. KeyMaterial: keyMaterial,
  413. KeyUsage: kms.KeyUsageEncryptDecrypt,
  414. KeyState: kms.KeyStateEnabled,
  415. Origin: kms.KeyOriginLocal,
  416. CreatedAt: time.Now(),
  417. Aliases: aliases,
  418. Metadata: make(map[string]string),
  419. }
  420. p.mu.Lock()
  421. defer p.mu.Unlock()
  422. p.keys[keyID] = key
  423. for _, alias := range aliases {
  424. // Ensure alias has proper format
  425. if !strings.HasPrefix(alias, "alias/") {
  426. alias = "alias/" + alias
  427. }
  428. p.keys[alias] = key
  429. }
  430. return key, nil
  431. }
  432. // CreateKeyWithID creates a key with a specific keyID (for testing only)
  433. func (p *LocalKMSProvider) CreateKeyWithID(keyID, description string) (*LocalKey, error) {
  434. keyMaterial := make([]byte, 32)
  435. if _, err := io.ReadFull(rand.Reader, keyMaterial); err != nil {
  436. return nil, fmt.Errorf("failed to generate key material: %w", err)
  437. }
  438. key := &LocalKey{
  439. KeyID: keyID,
  440. ARN: fmt.Sprintf("arn:aws:kms:local:000000000000:key/%s", keyID),
  441. Description: description,
  442. KeyMaterial: keyMaterial,
  443. KeyUsage: kms.KeyUsageEncryptDecrypt,
  444. KeyState: kms.KeyStateEnabled,
  445. Origin: kms.KeyOriginLocal,
  446. CreatedAt: time.Now(),
  447. Aliases: []string{}, // No aliases by default
  448. Metadata: make(map[string]string),
  449. }
  450. p.mu.Lock()
  451. defer p.mu.Unlock()
  452. // Register key with the exact keyID provided
  453. p.keys[keyID] = key
  454. return key, nil
  455. }
  456. // marshalEncryptionContextDeterministic creates a deterministic byte representation of encryption context
  457. // This ensures that the same encryption context always produces the same AAD for AES-GCM
  458. func marshalEncryptionContextDeterministic(encryptionContext map[string]string) ([]byte, error) {
  459. if len(encryptionContext) == 0 {
  460. return nil, nil
  461. }
  462. // Sort keys to ensure deterministic output
  463. keys := make([]string, 0, len(encryptionContext))
  464. for k := range encryptionContext {
  465. keys = append(keys, k)
  466. }
  467. sort.Strings(keys)
  468. // Build deterministic representation with proper JSON escaping
  469. var buf strings.Builder
  470. buf.WriteString("{")
  471. for i, k := range keys {
  472. if i > 0 {
  473. buf.WriteString(",")
  474. }
  475. // Marshal key and value to get proper JSON string escaping
  476. keyBytes, err := json.Marshal(k)
  477. if err != nil {
  478. return nil, fmt.Errorf("failed to marshal encryption context key '%s': %w", k, err)
  479. }
  480. valueBytes, err := json.Marshal(encryptionContext[k])
  481. if err != nil {
  482. return nil, fmt.Errorf("failed to marshal encryption context value for key '%s': %w", k, err)
  483. }
  484. buf.Write(keyBytes)
  485. buf.WriteString(":")
  486. buf.Write(valueBytes)
  487. }
  488. buf.WriteString("}")
  489. return []byte(buf.String()), nil
  490. }