aws_kms.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. package aws
import (
	"context"
	"encoding/base64"
	"fmt"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/kms"

	"github.com/seaweedfs/seaweedfs/weed/glog"
	seaweedkms "github.com/seaweedfs/seaweedfs/weed/kms"
	"github.com/seaweedfs/seaweedfs/weed/util"
)
  18. func init() {
  19. // Register the AWS KMS provider
  20. seaweedkms.RegisterProvider("aws", NewAWSKMSProvider)
  21. }
  22. // AWSKMSProvider implements the KMSProvider interface using AWS KMS
  23. type AWSKMSProvider struct {
  24. client *kms.KMS
  25. region string
  26. endpoint string // For testing with LocalStack or custom endpoints
  27. }
  28. // AWSKMSConfig contains configuration for the AWS KMS provider
  29. type AWSKMSConfig struct {
  30. Region string `json:"region"` // AWS region (e.g., "us-east-1")
  31. AccessKey string `json:"access_key"` // AWS access key (optional if using IAM roles)
  32. SecretKey string `json:"secret_key"` // AWS secret key (optional if using IAM roles)
  33. SessionToken string `json:"session_token"` // AWS session token (optional for STS)
  34. Endpoint string `json:"endpoint"` // Custom endpoint (optional, for LocalStack/testing)
  35. Profile string `json:"profile"` // AWS profile name (optional)
  36. RoleARN string `json:"role_arn"` // IAM role ARN to assume (optional)
  37. ExternalID string `json:"external_id"` // External ID for role assumption (optional)
  38. ConnectTimeout int `json:"connect_timeout"` // Connection timeout in seconds (default: 10)
  39. RequestTimeout int `json:"request_timeout"` // Request timeout in seconds (default: 30)
  40. MaxRetries int `json:"max_retries"` // Maximum number of retries (default: 3)
  41. }
  42. // NewAWSKMSProvider creates a new AWS KMS provider
  43. func NewAWSKMSProvider(config util.Configuration) (seaweedkms.KMSProvider, error) {
  44. if config == nil {
  45. return nil, fmt.Errorf("AWS KMS configuration is required")
  46. }
  47. // Extract configuration
  48. region := config.GetString("region")
  49. if region == "" {
  50. region = "us-east-1" // Default region
  51. }
  52. accessKey := config.GetString("access_key")
  53. secretKey := config.GetString("secret_key")
  54. sessionToken := config.GetString("session_token")
  55. endpoint := config.GetString("endpoint")
  56. profile := config.GetString("profile")
  57. // Timeouts and retries
  58. connectTimeout := config.GetInt("connect_timeout")
  59. if connectTimeout == 0 {
  60. connectTimeout = 10 // Default 10 seconds
  61. }
  62. requestTimeout := config.GetInt("request_timeout")
  63. if requestTimeout == 0 {
  64. requestTimeout = 30 // Default 30 seconds
  65. }
  66. maxRetries := config.GetInt("max_retries")
  67. if maxRetries == 0 {
  68. maxRetries = 3 // Default 3 retries
  69. }
  70. // Create AWS session
  71. awsConfig := &aws.Config{
  72. Region: aws.String(region),
  73. MaxRetries: aws.Int(maxRetries),
  74. HTTPClient: &http.Client{
  75. Timeout: time.Duration(requestTimeout) * time.Second,
  76. },
  77. }
  78. // Set custom endpoint if provided (for testing with LocalStack)
  79. if endpoint != "" {
  80. awsConfig.Endpoint = aws.String(endpoint)
  81. awsConfig.DisableSSL = aws.Bool(strings.HasPrefix(endpoint, "http://"))
  82. }
  83. // Configure credentials
  84. if accessKey != "" && secretKey != "" {
  85. awsConfig.Credentials = credentials.NewStaticCredentials(accessKey, secretKey, sessionToken)
  86. } else if profile != "" {
  87. awsConfig.Credentials = credentials.NewSharedCredentials("", profile)
  88. }
  89. // If neither are provided, use default credential chain (IAM roles, etc.)
  90. sess, err := session.NewSession(awsConfig)
  91. if err != nil {
  92. return nil, fmt.Errorf("failed to create AWS session: %w", err)
  93. }
  94. provider := &AWSKMSProvider{
  95. client: kms.New(sess),
  96. region: region,
  97. endpoint: endpoint,
  98. }
  99. glog.V(1).Infof("AWS KMS provider initialized for region %s", region)
  100. return provider, nil
  101. }
  102. // GenerateDataKey generates a new data encryption key using AWS KMS
  103. func (p *AWSKMSProvider) GenerateDataKey(ctx context.Context, req *seaweedkms.GenerateDataKeyRequest) (*seaweedkms.GenerateDataKeyResponse, error) {
  104. if req == nil {
  105. return nil, fmt.Errorf("GenerateDataKeyRequest cannot be nil")
  106. }
  107. if req.KeyID == "" {
  108. return nil, fmt.Errorf("KeyID is required")
  109. }
  110. // Validate key spec
  111. var keySpec string
  112. switch req.KeySpec {
  113. case seaweedkms.KeySpecAES256:
  114. keySpec = "AES_256"
  115. default:
  116. return nil, fmt.Errorf("unsupported key spec: %s", req.KeySpec)
  117. }
  118. // Build KMS request
  119. kmsReq := &kms.GenerateDataKeyInput{
  120. KeyId: aws.String(req.KeyID),
  121. KeySpec: aws.String(keySpec),
  122. }
  123. // Add encryption context if provided
  124. if len(req.EncryptionContext) > 0 {
  125. kmsReq.EncryptionContext = aws.StringMap(req.EncryptionContext)
  126. }
  127. // Call AWS KMS
  128. glog.V(4).Infof("AWS KMS: Generating data key for key ID %s", req.KeyID)
  129. result, err := p.client.GenerateDataKeyWithContext(ctx, kmsReq)
  130. if err != nil {
  131. return nil, p.convertAWSError(err, req.KeyID)
  132. }
  133. // Extract the actual key ID from the response (resolves aliases)
  134. actualKeyID := ""
  135. if result.KeyId != nil {
  136. actualKeyID = *result.KeyId
  137. }
  138. // Create standardized envelope format for consistent API behavior
  139. envelopeBlob, err := seaweedkms.CreateEnvelope("aws", actualKeyID, base64.StdEncoding.EncodeToString(result.CiphertextBlob), nil)
  140. if err != nil {
  141. return nil, fmt.Errorf("failed to create ciphertext envelope: %w", err)
  142. }
  143. response := &seaweedkms.GenerateDataKeyResponse{
  144. KeyID: actualKeyID,
  145. Plaintext: result.Plaintext,
  146. CiphertextBlob: envelopeBlob, // Store in standardized envelope format
  147. }
  148. glog.V(4).Infof("AWS KMS: Generated data key for key ID %s (actual: %s)", req.KeyID, actualKeyID)
  149. return response, nil
  150. }
  151. // Decrypt decrypts an encrypted data key using AWS KMS
  152. func (p *AWSKMSProvider) Decrypt(ctx context.Context, req *seaweedkms.DecryptRequest) (*seaweedkms.DecryptResponse, error) {
  153. if req == nil {
  154. return nil, fmt.Errorf("DecryptRequest cannot be nil")
  155. }
  156. if len(req.CiphertextBlob) == 0 {
  157. return nil, fmt.Errorf("CiphertextBlob cannot be empty")
  158. }
  159. // Parse the ciphertext envelope to extract key information
  160. envelope, err := seaweedkms.ParseEnvelope(req.CiphertextBlob)
  161. if err != nil {
  162. return nil, fmt.Errorf("failed to parse ciphertext envelope: %w", err)
  163. }
  164. if envelope.Provider != "aws" {
  165. return nil, fmt.Errorf("invalid provider in envelope: expected 'aws', got '%s'", envelope.Provider)
  166. }
  167. ciphertext, err := base64.StdEncoding.DecodeString(envelope.Ciphertext)
  168. if err != nil {
  169. return nil, fmt.Errorf("failed to decode ciphertext from envelope: %w", err)
  170. }
  171. // Build KMS request
  172. kmsReq := &kms.DecryptInput{
  173. CiphertextBlob: ciphertext,
  174. }
  175. // Add encryption context if provided
  176. if len(req.EncryptionContext) > 0 {
  177. kmsReq.EncryptionContext = aws.StringMap(req.EncryptionContext)
  178. }
  179. // Call AWS KMS
  180. glog.V(4).Infof("AWS KMS: Decrypting data key (blob size: %d bytes)", len(req.CiphertextBlob))
  181. result, err := p.client.DecryptWithContext(ctx, kmsReq)
  182. if err != nil {
  183. return nil, p.convertAWSError(err, "")
  184. }
  185. // Extract the key ID that was used for encryption
  186. keyID := ""
  187. if result.KeyId != nil {
  188. keyID = *result.KeyId
  189. }
  190. response := &seaweedkms.DecryptResponse{
  191. KeyID: keyID,
  192. Plaintext: result.Plaintext,
  193. }
  194. glog.V(4).Infof("AWS KMS: Decrypted data key using key ID %s", keyID)
  195. return response, nil
  196. }
  197. // DescribeKey validates that a key exists and returns its metadata
  198. func (p *AWSKMSProvider) DescribeKey(ctx context.Context, req *seaweedkms.DescribeKeyRequest) (*seaweedkms.DescribeKeyResponse, error) {
  199. if req == nil {
  200. return nil, fmt.Errorf("DescribeKeyRequest cannot be nil")
  201. }
  202. if req.KeyID == "" {
  203. return nil, fmt.Errorf("KeyID is required")
  204. }
  205. // Build KMS request
  206. kmsReq := &kms.DescribeKeyInput{
  207. KeyId: aws.String(req.KeyID),
  208. }
  209. // Call AWS KMS
  210. glog.V(4).Infof("AWS KMS: Describing key %s", req.KeyID)
  211. result, err := p.client.DescribeKeyWithContext(ctx, kmsReq)
  212. if err != nil {
  213. return nil, p.convertAWSError(err, req.KeyID)
  214. }
  215. if result.KeyMetadata == nil {
  216. return nil, fmt.Errorf("no key metadata returned from AWS KMS")
  217. }
  218. metadata := result.KeyMetadata
  219. response := &seaweedkms.DescribeKeyResponse{
  220. KeyID: aws.StringValue(metadata.KeyId),
  221. ARN: aws.StringValue(metadata.Arn),
  222. Description: aws.StringValue(metadata.Description),
  223. }
  224. // Convert AWS key usage to our enum
  225. if metadata.KeyUsage != nil {
  226. switch *metadata.KeyUsage {
  227. case "ENCRYPT_DECRYPT":
  228. response.KeyUsage = seaweedkms.KeyUsageEncryptDecrypt
  229. case "GENERATE_DATA_KEY":
  230. response.KeyUsage = seaweedkms.KeyUsageGenerateDataKey
  231. }
  232. }
  233. // Convert AWS key state to our enum
  234. if metadata.KeyState != nil {
  235. switch *metadata.KeyState {
  236. case "Enabled":
  237. response.KeyState = seaweedkms.KeyStateEnabled
  238. case "Disabled":
  239. response.KeyState = seaweedkms.KeyStateDisabled
  240. case "PendingDeletion":
  241. response.KeyState = seaweedkms.KeyStatePendingDeletion
  242. case "Unavailable":
  243. response.KeyState = seaweedkms.KeyStateUnavailable
  244. }
  245. }
  246. // Convert AWS origin to our enum
  247. if metadata.Origin != nil {
  248. switch *metadata.Origin {
  249. case "AWS_KMS":
  250. response.Origin = seaweedkms.KeyOriginAWS
  251. case "EXTERNAL":
  252. response.Origin = seaweedkms.KeyOriginExternal
  253. case "AWS_CLOUDHSM":
  254. response.Origin = seaweedkms.KeyOriginCloudHSM
  255. }
  256. }
  257. glog.V(4).Infof("AWS KMS: Described key %s (actual: %s, state: %s)", req.KeyID, response.KeyID, response.KeyState)
  258. return response, nil
  259. }
  260. // GetKeyID resolves a key alias or ARN to the actual key ID
  261. func (p *AWSKMSProvider) GetKeyID(ctx context.Context, keyIdentifier string) (string, error) {
  262. if keyIdentifier == "" {
  263. return "", fmt.Errorf("key identifier cannot be empty")
  264. }
  265. // Use DescribeKey to resolve the key identifier
  266. descReq := &seaweedkms.DescribeKeyRequest{KeyID: keyIdentifier}
  267. descResp, err := p.DescribeKey(ctx, descReq)
  268. if err != nil {
  269. return "", fmt.Errorf("failed to resolve key identifier %s: %w", keyIdentifier, err)
  270. }
  271. return descResp.KeyID, nil
  272. }
  273. // Close cleans up any resources used by the provider
  274. func (p *AWSKMSProvider) Close() error {
  275. // AWS SDK clients don't require explicit cleanup
  276. glog.V(2).Infof("AWS KMS provider closed")
  277. return nil
  278. }
  279. // convertAWSError converts AWS KMS errors to our standard KMS errors
  280. func (p *AWSKMSProvider) convertAWSError(err error, keyID string) error {
  281. if awsErr, ok := err.(awserr.Error); ok {
  282. switch awsErr.Code() {
  283. case "NotFoundException":
  284. return &seaweedkms.KMSError{
  285. Code: seaweedkms.ErrCodeNotFoundException,
  286. Message: awsErr.Message(),
  287. KeyID: keyID,
  288. }
  289. case "DisabledException", "KeyUnavailableException":
  290. return &seaweedkms.KMSError{
  291. Code: seaweedkms.ErrCodeKeyUnavailable,
  292. Message: awsErr.Message(),
  293. KeyID: keyID,
  294. }
  295. case "AccessDeniedException":
  296. return &seaweedkms.KMSError{
  297. Code: seaweedkms.ErrCodeAccessDenied,
  298. Message: awsErr.Message(),
  299. KeyID: keyID,
  300. }
  301. case "InvalidKeyUsageException":
  302. return &seaweedkms.KMSError{
  303. Code: seaweedkms.ErrCodeInvalidKeyUsage,
  304. Message: awsErr.Message(),
  305. KeyID: keyID,
  306. }
  307. case "InvalidCiphertextException":
  308. return &seaweedkms.KMSError{
  309. Code: seaweedkms.ErrCodeInvalidCiphertext,
  310. Message: awsErr.Message(),
  311. KeyID: keyID,
  312. }
  313. case "KMSInternalException", "KMSInvalidStateException":
  314. return &seaweedkms.KMSError{
  315. Code: seaweedkms.ErrCodeKMSInternalFailure,
  316. Message: awsErr.Message(),
  317. KeyID: keyID,
  318. }
  319. default:
  320. // For unknown AWS errors, wrap them as internal failures
  321. return &seaweedkms.KMSError{
  322. Code: seaweedkms.ErrCodeKMSInternalFailure,
  323. Message: fmt.Sprintf("AWS KMS error %s: %s", awsErr.Code(), awsErr.Message()),
  324. KeyID: keyID,
  325. }
  326. }
  327. }
  328. // For non-AWS errors (network issues, etc.), wrap as internal failure
  329. return &seaweedkms.KMSError{
  330. Code: seaweedkms.ErrCodeKMSInternalFailure,
  331. Message: fmt.Sprintf("AWS KMS provider error: %v", err),
  332. KeyID: keyID,
  333. }
  334. }