s3_sse_c_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. package s3api
  2. import (
  3. "bytes"
  4. "crypto/md5"
  5. "encoding/base64"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "testing"
  10. "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
  11. )
  // base64MD5 returns the standard (padded) base64 encoding of the MD5 digest
  // of b.
  func base64MD5(b []byte) string {
  	hasher := md5.New()
  	hasher.Write(b) // hash.Hash.Write is documented to never return an error
  	return base64.StdEncoding.EncodeToString(hasher.Sum(nil))
  }
  16. func TestSSECHeaderValidation(t *testing.T) {
  17. // Test valid SSE-C headers
  18. req := &http.Request{Header: make(http.Header)}
  19. key := make([]byte, 32) // 256-bit key
  20. for i := range key {
  21. key[i] = byte(i)
  22. }
  23. keyBase64 := base64.StdEncoding.EncodeToString(key)
  24. md5sum := md5.Sum(key)
  25. keyMD5 := base64.StdEncoding.EncodeToString(md5sum[:])
  26. req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256")
  27. req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, keyBase64)
  28. req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, keyMD5)
  29. // Test validation
  30. err := ValidateSSECHeaders(req)
  31. if err != nil {
  32. t.Errorf("Expected valid headers, got error: %v", err)
  33. }
  34. // Test parsing
  35. customerKey, err := ParseSSECHeaders(req)
  36. if err != nil {
  37. t.Errorf("Expected successful parsing, got error: %v", err)
  38. }
  39. if customerKey == nil {
  40. t.Error("Expected customer key, got nil")
  41. }
  42. if customerKey.Algorithm != "AES256" {
  43. t.Errorf("Expected algorithm AES256, got %s", customerKey.Algorithm)
  44. }
  45. if !bytes.Equal(customerKey.Key, key) {
  46. t.Error("Key doesn't match original")
  47. }
  48. if customerKey.KeyMD5 != keyMD5 {
  49. t.Errorf("Expected key MD5 %s, got %s", keyMD5, customerKey.KeyMD5)
  50. }
  51. }
  52. func TestSSECCopySourceHeaders(t *testing.T) {
  53. // Test valid SSE-C copy source headers
  54. req := &http.Request{Header: make(http.Header)}
  55. key := make([]byte, 32) // 256-bit key
  56. for i := range key {
  57. key[i] = byte(i) + 1 // Different from regular test
  58. }
  59. keyBase64 := base64.StdEncoding.EncodeToString(key)
  60. md5sum2 := md5.Sum(key)
  61. keyMD5 := base64.StdEncoding.EncodeToString(md5sum2[:])
  62. req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerAlgorithm, "AES256")
  63. req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerKey, keyBase64)
  64. req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerKeyMD5, keyMD5)
  65. // Test parsing copy source headers
  66. customerKey, err := ParseSSECCopySourceHeaders(req)
  67. if err != nil {
  68. t.Errorf("Expected successful copy source parsing, got error: %v", err)
  69. }
  70. if customerKey == nil {
  71. t.Error("Expected customer key from copy source headers, got nil")
  72. }
  73. if customerKey.Algorithm != "AES256" {
  74. t.Errorf("Expected algorithm AES256, got %s", customerKey.Algorithm)
  75. }
  76. if !bytes.Equal(customerKey.Key, key) {
  77. t.Error("Copy source key doesn't match original")
  78. }
  79. // Test that regular headers don't interfere with copy source headers
  80. regularKey, err := ParseSSECHeaders(req)
  81. if err != nil {
  82. t.Errorf("Regular header parsing should not fail: %v", err)
  83. }
  84. if regularKey != nil {
  85. t.Error("Expected nil for regular headers when only copy source headers are present")
  86. }
  87. }
  88. func TestSSECHeaderValidationErrors(t *testing.T) {
  89. tests := []struct {
  90. name string
  91. algorithm string
  92. key string
  93. keyMD5 string
  94. wantErr error
  95. }{
  96. {
  97. name: "invalid algorithm",
  98. algorithm: "AES128",
  99. key: base64.StdEncoding.EncodeToString(make([]byte, 32)),
  100. keyMD5: base64MD5(make([]byte, 32)),
  101. wantErr: ErrInvalidEncryptionAlgorithm,
  102. },
  103. {
  104. name: "invalid key length",
  105. algorithm: "AES256",
  106. key: base64.StdEncoding.EncodeToString(make([]byte, 16)),
  107. keyMD5: base64MD5(make([]byte, 16)),
  108. wantErr: ErrInvalidEncryptionKey,
  109. },
  110. {
  111. name: "mismatched MD5",
  112. algorithm: "AES256",
  113. key: base64.StdEncoding.EncodeToString(make([]byte, 32)),
  114. keyMD5: "wrong==md5",
  115. wantErr: ErrSSECustomerKeyMD5Mismatch,
  116. },
  117. {
  118. name: "incomplete headers",
  119. algorithm: "AES256",
  120. key: "",
  121. keyMD5: "",
  122. wantErr: ErrInvalidRequest,
  123. },
  124. }
  125. for _, tt := range tests {
  126. t.Run(tt.name, func(t *testing.T) {
  127. req := &http.Request{Header: make(http.Header)}
  128. if tt.algorithm != "" {
  129. req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, tt.algorithm)
  130. }
  131. if tt.key != "" {
  132. req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, tt.key)
  133. }
  134. if tt.keyMD5 != "" {
  135. req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, tt.keyMD5)
  136. }
  137. err := ValidateSSECHeaders(req)
  138. if err != tt.wantErr {
  139. t.Errorf("Expected error %v, got %v", tt.wantErr, err)
  140. }
  141. })
  142. }
  143. }
  144. func TestSSECEncryptionDecryption(t *testing.T) {
  145. // Create customer key
  146. key := make([]byte, 32)
  147. for i := range key {
  148. key[i] = byte(i)
  149. }
  150. md5sumKey := md5.Sum(key)
  151. customerKey := &SSECustomerKey{
  152. Algorithm: "AES256",
  153. Key: key,
  154. KeyMD5: base64.StdEncoding.EncodeToString(md5sumKey[:]),
  155. }
  156. // Test data
  157. testData := []byte("Hello, World! This is a test of SSE-C encryption.")
  158. // Create encrypted reader
  159. dataReader := bytes.NewReader(testData)
  160. encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, customerKey)
  161. if err != nil {
  162. t.Fatalf("Failed to create encrypted reader: %v", err)
  163. }
  164. // Read encrypted data
  165. encryptedData, err := io.ReadAll(encryptedReader)
  166. if err != nil {
  167. t.Fatalf("Failed to read encrypted data: %v", err)
  168. }
  169. // Verify data is actually encrypted (different from original)
  170. if bytes.Equal(encryptedData[16:], testData) { // Skip IV
  171. t.Error("Data doesn't appear to be encrypted")
  172. }
  173. // Create decrypted reader
  174. encryptedReader2 := bytes.NewReader(encryptedData)
  175. decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey, iv)
  176. if err != nil {
  177. t.Fatalf("Failed to create decrypted reader: %v", err)
  178. }
  179. // Read decrypted data
  180. decryptedData, err := io.ReadAll(decryptedReader)
  181. if err != nil {
  182. t.Fatalf("Failed to read decrypted data: %v", err)
  183. }
  184. // Verify decrypted data matches original
  185. if !bytes.Equal(decryptedData, testData) {
  186. t.Errorf("Decrypted data doesn't match original.\nOriginal: %s\nDecrypted: %s", testData, decryptedData)
  187. }
  188. }
  189. func TestSSECIsSSECRequest(t *testing.T) {
  190. // Test with SSE-C headers
  191. req := &http.Request{Header: make(http.Header)}
  192. req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256")
  193. if !IsSSECRequest(req) {
  194. t.Error("Expected IsSSECRequest to return true when SSE-C headers are present")
  195. }
  196. // Test without SSE-C headers
  197. req2 := &http.Request{Header: make(http.Header)}
  198. if IsSSECRequest(req2) {
  199. t.Error("Expected IsSSECRequest to return false when no SSE-C headers are present")
  200. }
  201. }
  202. // Test encryption with different data sizes (similar to s3tests)
  203. func TestSSECEncryptionVariousSizes(t *testing.T) {
  204. sizes := []int{1, 13, 1024, 1024 * 1024} // 1B, 13B, 1KB, 1MB
  205. for _, size := range sizes {
  206. t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) {
  207. // Create customer key
  208. key := make([]byte, 32)
  209. for i := range key {
  210. key[i] = byte(i + size) // Make key unique per test
  211. }
  212. md5sumDyn := md5.Sum(key)
  213. customerKey := &SSECustomerKey{
  214. Algorithm: "AES256",
  215. Key: key,
  216. KeyMD5: base64.StdEncoding.EncodeToString(md5sumDyn[:]),
  217. }
  218. // Create test data of specified size
  219. testData := make([]byte, size)
  220. for i := range testData {
  221. testData[i] = byte('A' + (i % 26)) // Pattern of A-Z
  222. }
  223. // Encrypt
  224. dataReader := bytes.NewReader(testData)
  225. encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, customerKey)
  226. if err != nil {
  227. t.Fatalf("Failed to create encrypted reader: %v", err)
  228. }
  229. encryptedData, err := io.ReadAll(encryptedReader)
  230. if err != nil {
  231. t.Fatalf("Failed to read encrypted data: %v", err)
  232. }
  233. // Verify encrypted data has same size as original (IV is stored in metadata, not in stream)
  234. if len(encryptedData) != size {
  235. t.Errorf("Expected encrypted data length %d (same as original), got %d", size, len(encryptedData))
  236. }
  237. // Decrypt
  238. encryptedReader2 := bytes.NewReader(encryptedData)
  239. decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey, iv)
  240. if err != nil {
  241. t.Fatalf("Failed to create decrypted reader: %v", err)
  242. }
  243. decryptedData, err := io.ReadAll(decryptedReader)
  244. if err != nil {
  245. t.Fatalf("Failed to read decrypted data: %v", err)
  246. }
  247. // Verify decrypted data matches original
  248. if !bytes.Equal(decryptedData, testData) {
  249. t.Errorf("Decrypted data doesn't match original for size %d", size)
  250. }
  251. })
  252. }
  253. }
  254. func TestSSECEncryptionWithNilKey(t *testing.T) {
  255. testData := []byte("test data")
  256. dataReader := bytes.NewReader(testData)
  257. // Test encryption with nil key (should pass through)
  258. encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, nil)
  259. if err != nil {
  260. t.Fatalf("Failed to create encrypted reader with nil key: %v", err)
  261. }
  262. result, err := io.ReadAll(encryptedReader)
  263. if err != nil {
  264. t.Fatalf("Failed to read from pass-through reader: %v", err)
  265. }
  266. if !bytes.Equal(result, testData) {
  267. t.Error("Data should pass through unchanged when key is nil")
  268. }
  269. // Test decryption with nil key (should pass through)
  270. dataReader2 := bytes.NewReader(testData)
  271. decryptedReader, err := CreateSSECDecryptedReader(dataReader2, nil, iv)
  272. if err != nil {
  273. t.Fatalf("Failed to create decrypted reader with nil key: %v", err)
  274. }
  275. result2, err := io.ReadAll(decryptedReader)
  276. if err != nil {
  277. t.Fatalf("Failed to read from pass-through reader: %v", err)
  278. }
  279. if !bytes.Equal(result2, testData) {
  280. t.Error("Data should pass through unchanged when key is nil")
  281. }
  282. }
  283. // TestSSECEncryptionSmallBuffers tests the fix for the critical bug where small buffers
  284. // could corrupt the data stream when reading in chunks smaller than the IV size
  285. func TestSSECEncryptionSmallBuffers(t *testing.T) {
  286. testData := []byte("This is a test message for small buffer reads")
  287. // Create customer key
  288. key := make([]byte, 32)
  289. for i := range key {
  290. key[i] = byte(i)
  291. }
  292. md5sumKey3 := md5.Sum(key)
  293. customerKey := &SSECustomerKey{
  294. Algorithm: "AES256",
  295. Key: key,
  296. KeyMD5: base64.StdEncoding.EncodeToString(md5sumKey3[:]),
  297. }
  298. // Create encrypted reader
  299. dataReader := bytes.NewReader(testData)
  300. encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, customerKey)
  301. if err != nil {
  302. t.Fatalf("Failed to create encrypted reader: %v", err)
  303. }
  304. // Read with very small buffers (smaller than IV size of 16 bytes)
  305. var encryptedData []byte
  306. smallBuffer := make([]byte, 5) // Much smaller than 16-byte IV
  307. for {
  308. n, err := encryptedReader.Read(smallBuffer)
  309. if n > 0 {
  310. encryptedData = append(encryptedData, smallBuffer[:n]...)
  311. }
  312. if err == io.EOF {
  313. break
  314. }
  315. if err != nil {
  316. t.Fatalf("Error reading encrypted data: %v", err)
  317. }
  318. }
  319. // Verify we have some encrypted data (IV is in metadata, not in stream)
  320. if len(encryptedData) == 0 && len(testData) > 0 {
  321. t.Fatal("Expected encrypted data but got none")
  322. }
  323. // Expected size: same as original data (IV is stored in metadata, not in stream)
  324. if len(encryptedData) != len(testData) {
  325. t.Errorf("Expected encrypted data size %d (same as original), got %d", len(testData), len(encryptedData))
  326. }
  327. // Decrypt and verify
  328. encryptedReader2 := bytes.NewReader(encryptedData)
  329. decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey, iv)
  330. if err != nil {
  331. t.Fatalf("Failed to create decrypted reader: %v", err)
  332. }
  333. decryptedData, err := io.ReadAll(decryptedReader)
  334. if err != nil {
  335. t.Fatalf("Failed to read decrypted data: %v", err)
  336. }
  337. if !bytes.Equal(decryptedData, testData) {
  338. t.Errorf("Decrypted data doesn't match original.\nOriginal: %s\nDecrypted: %s", testData, decryptedData)
  339. }
  340. }