query_parsing_test.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. package engine
  2. import (
  3. "testing"
  4. )
  5. func TestParseSQL_COUNT_Functions(t *testing.T) {
  6. tests := []struct {
  7. name string
  8. sql string
  9. wantErr bool
  10. validate func(t *testing.T, stmt Statement)
  11. }{
  12. {
  13. name: "COUNT(*) basic",
  14. sql: "SELECT COUNT(*) FROM test_table",
  15. wantErr: false,
  16. validate: func(t *testing.T, stmt Statement) {
  17. selectStmt, ok := stmt.(*SelectStatement)
  18. if !ok {
  19. t.Fatalf("Expected *SelectStatement, got %T", stmt)
  20. }
  21. if len(selectStmt.SelectExprs) != 1 {
  22. t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs))
  23. }
  24. aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr)
  25. if !ok {
  26. t.Fatalf("Expected *AliasedExpr, got %T", selectStmt.SelectExprs[0])
  27. }
  28. funcExpr, ok := aliasedExpr.Expr.(*FuncExpr)
  29. if !ok {
  30. t.Fatalf("Expected *FuncExpr, got %T", aliasedExpr.Expr)
  31. }
  32. if funcExpr.Name.String() != "COUNT" {
  33. t.Errorf("Expected function name 'COUNT', got '%s'", funcExpr.Name.String())
  34. }
  35. if len(funcExpr.Exprs) != 1 {
  36. t.Fatalf("Expected 1 function argument, got %d", len(funcExpr.Exprs))
  37. }
  38. starExpr, ok := funcExpr.Exprs[0].(*StarExpr)
  39. if !ok {
  40. t.Errorf("Expected *StarExpr argument, got %T", funcExpr.Exprs[0])
  41. }
  42. _ = starExpr // Use the variable to avoid unused variable error
  43. },
  44. },
  45. {
  46. name: "COUNT(column_name)",
  47. sql: "SELECT COUNT(user_id) FROM users",
  48. wantErr: false,
  49. validate: func(t *testing.T, stmt Statement) {
  50. selectStmt, ok := stmt.(*SelectStatement)
  51. if !ok {
  52. t.Fatalf("Expected *SelectStatement, got %T", stmt)
  53. }
  54. aliasedExpr := selectStmt.SelectExprs[0].(*AliasedExpr)
  55. funcExpr := aliasedExpr.Expr.(*FuncExpr)
  56. if funcExpr.Name.String() != "COUNT" {
  57. t.Errorf("Expected function name 'COUNT', got '%s'", funcExpr.Name.String())
  58. }
  59. if len(funcExpr.Exprs) != 1 {
  60. t.Fatalf("Expected 1 function argument, got %d", len(funcExpr.Exprs))
  61. }
  62. argExpr, ok := funcExpr.Exprs[0].(*AliasedExpr)
  63. if !ok {
  64. t.Errorf("Expected *AliasedExpr argument, got %T", funcExpr.Exprs[0])
  65. }
  66. colName, ok := argExpr.Expr.(*ColName)
  67. if !ok {
  68. t.Errorf("Expected *ColName, got %T", argExpr.Expr)
  69. }
  70. if colName.Name.String() != "user_id" {
  71. t.Errorf("Expected column name 'user_id', got '%s'", colName.Name.String())
  72. }
  73. },
  74. },
  75. {
  76. name: "Multiple aggregate functions",
  77. sql: "SELECT COUNT(*), SUM(amount), AVG(score) FROM transactions",
  78. wantErr: false,
  79. validate: func(t *testing.T, stmt Statement) {
  80. selectStmt, ok := stmt.(*SelectStatement)
  81. if !ok {
  82. t.Fatalf("Expected *SelectStatement, got %T", stmt)
  83. }
  84. if len(selectStmt.SelectExprs) != 3 {
  85. t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs))
  86. }
  87. // Verify COUNT(*)
  88. countExpr := selectStmt.SelectExprs[0].(*AliasedExpr)
  89. countFunc := countExpr.Expr.(*FuncExpr)
  90. if countFunc.Name.String() != "COUNT" {
  91. t.Errorf("Expected first function to be COUNT, got %s", countFunc.Name.String())
  92. }
  93. // Verify SUM(amount)
  94. sumExpr := selectStmt.SelectExprs[1].(*AliasedExpr)
  95. sumFunc := sumExpr.Expr.(*FuncExpr)
  96. if sumFunc.Name.String() != "SUM" {
  97. t.Errorf("Expected second function to be SUM, got %s", sumFunc.Name.String())
  98. }
  99. // Verify AVG(score)
  100. avgExpr := selectStmt.SelectExprs[2].(*AliasedExpr)
  101. avgFunc := avgExpr.Expr.(*FuncExpr)
  102. if avgFunc.Name.String() != "AVG" {
  103. t.Errorf("Expected third function to be AVG, got %s", avgFunc.Name.String())
  104. }
  105. },
  106. },
  107. }
  108. for _, tt := range tests {
  109. t.Run(tt.name, func(t *testing.T) {
  110. stmt, err := ParseSQL(tt.sql)
  111. if tt.wantErr {
  112. if err == nil {
  113. t.Errorf("Expected error, but got none")
  114. }
  115. return
  116. }
  117. if err != nil {
  118. t.Errorf("Unexpected error: %v", err)
  119. return
  120. }
  121. if tt.validate != nil {
  122. tt.validate(t, stmt)
  123. }
  124. })
  125. }
  126. }
  127. func TestParseSQL_SELECT_Expressions(t *testing.T) {
  128. tests := []struct {
  129. name string
  130. sql string
  131. wantErr bool
  132. validate func(t *testing.T, stmt Statement)
  133. }{
  134. {
  135. name: "SELECT * FROM table",
  136. sql: "SELECT * FROM users",
  137. wantErr: false,
  138. validate: func(t *testing.T, stmt Statement) {
  139. selectStmt := stmt.(*SelectStatement)
  140. if len(selectStmt.SelectExprs) != 1 {
  141. t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs))
  142. }
  143. _, ok := selectStmt.SelectExprs[0].(*StarExpr)
  144. if !ok {
  145. t.Errorf("Expected *StarExpr, got %T", selectStmt.SelectExprs[0])
  146. }
  147. },
  148. },
  149. {
  150. name: "SELECT column FROM table",
  151. sql: "SELECT user_id FROM users",
  152. wantErr: false,
  153. validate: func(t *testing.T, stmt Statement) {
  154. selectStmt := stmt.(*SelectStatement)
  155. if len(selectStmt.SelectExprs) != 1 {
  156. t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs))
  157. }
  158. aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr)
  159. if !ok {
  160. t.Fatalf("Expected *AliasedExpr, got %T", selectStmt.SelectExprs[0])
  161. }
  162. colName, ok := aliasedExpr.Expr.(*ColName)
  163. if !ok {
  164. t.Fatalf("Expected *ColName, got %T", aliasedExpr.Expr)
  165. }
  166. if colName.Name.String() != "user_id" {
  167. t.Errorf("Expected column name 'user_id', got '%s'", colName.Name.String())
  168. }
  169. },
  170. },
  171. {
  172. name: "SELECT multiple columns",
  173. sql: "SELECT user_id, name, email FROM users",
  174. wantErr: false,
  175. validate: func(t *testing.T, stmt Statement) {
  176. selectStmt := stmt.(*SelectStatement)
  177. if len(selectStmt.SelectExprs) != 3 {
  178. t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs))
  179. }
  180. expectedColumns := []string{"user_id", "name", "email"}
  181. for i, expected := range expectedColumns {
  182. aliasedExpr := selectStmt.SelectExprs[i].(*AliasedExpr)
  183. colName := aliasedExpr.Expr.(*ColName)
  184. if colName.Name.String() != expected {
  185. t.Errorf("Expected column %d to be '%s', got '%s'", i, expected, colName.Name.String())
  186. }
  187. }
  188. },
  189. },
  190. }
  191. for _, tt := range tests {
  192. t.Run(tt.name, func(t *testing.T) {
  193. stmt, err := ParseSQL(tt.sql)
  194. if tt.wantErr {
  195. if err == nil {
  196. t.Errorf("Expected error, but got none")
  197. }
  198. return
  199. }
  200. if err != nil {
  201. t.Errorf("Unexpected error: %v", err)
  202. return
  203. }
  204. if tt.validate != nil {
  205. tt.validate(t, stmt)
  206. }
  207. })
  208. }
  209. }
  210. func TestParseSQL_WHERE_Clauses(t *testing.T) {
  211. tests := []struct {
  212. name string
  213. sql string
  214. wantErr bool
  215. validate func(t *testing.T, stmt Statement)
  216. }{
  217. {
  218. name: "WHERE with simple comparison",
  219. sql: "SELECT * FROM users WHERE age > 18",
  220. wantErr: false,
  221. validate: func(t *testing.T, stmt Statement) {
  222. selectStmt := stmt.(*SelectStatement)
  223. if selectStmt.Where == nil {
  224. t.Fatal("Expected WHERE clause, got nil")
  225. }
  226. // Just verify we have a WHERE clause with an expression
  227. if selectStmt.Where.Expr == nil {
  228. t.Error("Expected WHERE expression, got nil")
  229. }
  230. },
  231. },
  232. {
  233. name: "WHERE with AND condition",
  234. sql: "SELECT * FROM users WHERE age > 18 AND status = 'active'",
  235. wantErr: false,
  236. validate: func(t *testing.T, stmt Statement) {
  237. selectStmt := stmt.(*SelectStatement)
  238. if selectStmt.Where == nil {
  239. t.Fatal("Expected WHERE clause, got nil")
  240. }
  241. // Verify we have an AND expression
  242. andExpr, ok := selectStmt.Where.Expr.(*AndExpr)
  243. if !ok {
  244. t.Errorf("Expected *AndExpr, got %T", selectStmt.Where.Expr)
  245. }
  246. _ = andExpr // Use variable to avoid unused error
  247. },
  248. },
  249. {
  250. name: "WHERE with OR condition",
  251. sql: "SELECT * FROM users WHERE age < 18 OR age > 65",
  252. wantErr: false,
  253. validate: func(t *testing.T, stmt Statement) {
  254. selectStmt := stmt.(*SelectStatement)
  255. if selectStmt.Where == nil {
  256. t.Fatal("Expected WHERE clause, got nil")
  257. }
  258. // Verify we have an OR expression
  259. orExpr, ok := selectStmt.Where.Expr.(*OrExpr)
  260. if !ok {
  261. t.Errorf("Expected *OrExpr, got %T", selectStmt.Where.Expr)
  262. }
  263. _ = orExpr // Use variable to avoid unused error
  264. },
  265. },
  266. }
  267. for _, tt := range tests {
  268. t.Run(tt.name, func(t *testing.T) {
  269. stmt, err := ParseSQL(tt.sql)
  270. if tt.wantErr {
  271. if err == nil {
  272. t.Errorf("Expected error, but got none")
  273. }
  274. return
  275. }
  276. if err != nil {
  277. t.Errorf("Unexpected error: %v", err)
  278. return
  279. }
  280. if tt.validate != nil {
  281. tt.validate(t, stmt)
  282. }
  283. })
  284. }
  285. }
  286. func TestParseSQL_LIMIT_Clauses(t *testing.T) {
  287. tests := []struct {
  288. name string
  289. sql string
  290. wantErr bool
  291. validate func(t *testing.T, stmt Statement)
  292. }{
  293. {
  294. name: "LIMIT with number",
  295. sql: "SELECT * FROM users LIMIT 10",
  296. wantErr: false,
  297. validate: func(t *testing.T, stmt Statement) {
  298. selectStmt := stmt.(*SelectStatement)
  299. if selectStmt.Limit == nil {
  300. t.Fatal("Expected LIMIT clause, got nil")
  301. }
  302. if selectStmt.Limit.Rowcount == nil {
  303. t.Error("Expected LIMIT rowcount, got nil")
  304. }
  305. // Verify no OFFSET is set
  306. if selectStmt.Limit.Offset != nil {
  307. t.Error("Expected OFFSET to be nil for LIMIT-only query")
  308. }
  309. sqlVal, ok := selectStmt.Limit.Rowcount.(*SQLVal)
  310. if !ok {
  311. t.Errorf("Expected *SQLVal, got %T", selectStmt.Limit.Rowcount)
  312. }
  313. if sqlVal.Type != IntVal {
  314. t.Errorf("Expected IntVal type, got %d", sqlVal.Type)
  315. }
  316. if string(sqlVal.Val) != "10" {
  317. t.Errorf("Expected limit value '10', got '%s'", string(sqlVal.Val))
  318. }
  319. },
  320. },
  321. {
  322. name: "LIMIT with OFFSET",
  323. sql: "SELECT * FROM users LIMIT 10 OFFSET 5",
  324. wantErr: false,
  325. validate: func(t *testing.T, stmt Statement) {
  326. selectStmt := stmt.(*SelectStatement)
  327. if selectStmt.Limit == nil {
  328. t.Fatal("Expected LIMIT clause, got nil")
  329. }
  330. // Verify LIMIT value
  331. if selectStmt.Limit.Rowcount == nil {
  332. t.Error("Expected LIMIT rowcount, got nil")
  333. }
  334. limitVal, ok := selectStmt.Limit.Rowcount.(*SQLVal)
  335. if !ok {
  336. t.Errorf("Expected *SQLVal for LIMIT, got %T", selectStmt.Limit.Rowcount)
  337. }
  338. if limitVal.Type != IntVal {
  339. t.Errorf("Expected IntVal type for LIMIT, got %d", limitVal.Type)
  340. }
  341. if string(limitVal.Val) != "10" {
  342. t.Errorf("Expected limit value '10', got '%s'", string(limitVal.Val))
  343. }
  344. // Verify OFFSET value
  345. if selectStmt.Limit.Offset == nil {
  346. t.Fatal("Expected OFFSET clause, got nil")
  347. }
  348. offsetVal, ok := selectStmt.Limit.Offset.(*SQLVal)
  349. if !ok {
  350. t.Errorf("Expected *SQLVal for OFFSET, got %T", selectStmt.Limit.Offset)
  351. }
  352. if offsetVal.Type != IntVal {
  353. t.Errorf("Expected IntVal type for OFFSET, got %d", offsetVal.Type)
  354. }
  355. if string(offsetVal.Val) != "5" {
  356. t.Errorf("Expected offset value '5', got '%s'", string(offsetVal.Val))
  357. }
  358. },
  359. },
  360. {
  361. name: "LIMIT with OFFSET zero",
  362. sql: "SELECT * FROM users LIMIT 5 OFFSET 0",
  363. wantErr: false,
  364. validate: func(t *testing.T, stmt Statement) {
  365. selectStmt := stmt.(*SelectStatement)
  366. if selectStmt.Limit == nil {
  367. t.Fatal("Expected LIMIT clause, got nil")
  368. }
  369. // Verify OFFSET is 0
  370. if selectStmt.Limit.Offset == nil {
  371. t.Fatal("Expected OFFSET clause, got nil")
  372. }
  373. offsetVal, ok := selectStmt.Limit.Offset.(*SQLVal)
  374. if !ok {
  375. t.Errorf("Expected *SQLVal for OFFSET, got %T", selectStmt.Limit.Offset)
  376. }
  377. if string(offsetVal.Val) != "0" {
  378. t.Errorf("Expected offset value '0', got '%s'", string(offsetVal.Val))
  379. }
  380. },
  381. },
  382. {
  383. name: "LIMIT with large OFFSET",
  384. sql: "SELECT * FROM users LIMIT 100 OFFSET 1000",
  385. wantErr: false,
  386. validate: func(t *testing.T, stmt Statement) {
  387. selectStmt := stmt.(*SelectStatement)
  388. if selectStmt.Limit == nil {
  389. t.Fatal("Expected LIMIT clause, got nil")
  390. }
  391. // Verify large OFFSET value
  392. offsetVal, ok := selectStmt.Limit.Offset.(*SQLVal)
  393. if !ok {
  394. t.Errorf("Expected *SQLVal for OFFSET, got %T", selectStmt.Limit.Offset)
  395. }
  396. if string(offsetVal.Val) != "1000" {
  397. t.Errorf("Expected offset value '1000', got '%s'", string(offsetVal.Val))
  398. }
  399. },
  400. },
  401. }
  402. for _, tt := range tests {
  403. t.Run(tt.name, func(t *testing.T) {
  404. stmt, err := ParseSQL(tt.sql)
  405. if tt.wantErr {
  406. if err == nil {
  407. t.Errorf("Expected error, but got none")
  408. }
  409. return
  410. }
  411. if err != nil {
  412. t.Errorf("Unexpected error: %v", err)
  413. return
  414. }
  415. if tt.validate != nil {
  416. tt.validate(t, stmt)
  417. }
  418. })
  419. }
  420. }
  421. func TestParseSQL_SHOW_Statements(t *testing.T) {
  422. tests := []struct {
  423. name string
  424. sql string
  425. wantErr bool
  426. validate func(t *testing.T, stmt Statement)
  427. }{
  428. {
  429. name: "SHOW DATABASES",
  430. sql: "SHOW DATABASES",
  431. wantErr: false,
  432. validate: func(t *testing.T, stmt Statement) {
  433. showStmt, ok := stmt.(*ShowStatement)
  434. if !ok {
  435. t.Fatalf("Expected *ShowStatement, got %T", stmt)
  436. }
  437. if showStmt.Type != "databases" {
  438. t.Errorf("Expected type 'databases', got '%s'", showStmt.Type)
  439. }
  440. },
  441. },
  442. {
  443. name: "SHOW TABLES",
  444. sql: "SHOW TABLES",
  445. wantErr: false,
  446. validate: func(t *testing.T, stmt Statement) {
  447. showStmt, ok := stmt.(*ShowStatement)
  448. if !ok {
  449. t.Fatalf("Expected *ShowStatement, got %T", stmt)
  450. }
  451. if showStmt.Type != "tables" {
  452. t.Errorf("Expected type 'tables', got '%s'", showStmt.Type)
  453. }
  454. },
  455. },
  456. {
  457. name: "SHOW TABLES FROM database",
  458. sql: "SHOW TABLES FROM \"test_db\"",
  459. wantErr: false,
  460. validate: func(t *testing.T, stmt Statement) {
  461. showStmt, ok := stmt.(*ShowStatement)
  462. if !ok {
  463. t.Fatalf("Expected *ShowStatement, got %T", stmt)
  464. }
  465. if showStmt.Type != "tables" {
  466. t.Errorf("Expected type 'tables', got '%s'", showStmt.Type)
  467. }
  468. if showStmt.Schema != "test_db" {
  469. t.Errorf("Expected schema 'test_db', got '%s'", showStmt.Schema)
  470. }
  471. },
  472. },
  473. }
  474. for _, tt := range tests {
  475. t.Run(tt.name, func(t *testing.T) {
  476. stmt, err := ParseSQL(tt.sql)
  477. if tt.wantErr {
  478. if err == nil {
  479. t.Errorf("Expected error, but got none")
  480. }
  481. return
  482. }
  483. if err != nil {
  484. t.Errorf("Unexpected error: %v", err)
  485. return
  486. }
  487. if tt.validate != nil {
  488. tt.validate(t, stmt)
  489. }
  490. })
  491. }
  492. }