| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564 |
- package engine
- import (
- "testing"
- )
- func TestParseSQL_COUNT_Functions(t *testing.T) {
- tests := []struct {
- name string
- sql string
- wantErr bool
- validate func(t *testing.T, stmt Statement)
- }{
- {
- name: "COUNT(*) basic",
- sql: "SELECT COUNT(*) FROM test_table",
- wantErr: false,
- validate: func(t *testing.T, stmt Statement) {
- selectStmt, ok := stmt.(*SelectStatement)
- if !ok {
- t.Fatalf("Expected *SelectStatement, got %T", stmt)
- }
- if len(selectStmt.SelectExprs) != 1 {
- t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs))
- }
- aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr)
- if !ok {
- t.Fatalf("Expected *AliasedExpr, got %T", selectStmt.SelectExprs[0])
- }
- funcExpr, ok := aliasedExpr.Expr.(*FuncExpr)
- if !ok {
- t.Fatalf("Expected *FuncExpr, got %T", aliasedExpr.Expr)
- }
- if funcExpr.Name.String() != "COUNT" {
- t.Errorf("Expected function name 'COUNT', got '%s'", funcExpr.Name.String())
- }
- if len(funcExpr.Exprs) != 1 {
- t.Fatalf("Expected 1 function argument, got %d", len(funcExpr.Exprs))
- }
- starExpr, ok := funcExpr.Exprs[0].(*StarExpr)
- if !ok {
- t.Errorf("Expected *StarExpr argument, got %T", funcExpr.Exprs[0])
- }
- _ = starExpr // Use the variable to avoid unused variable error
- },
- },
- {
- name: "COUNT(column_name)",
- sql: "SELECT COUNT(user_id) FROM users",
- wantErr: false,
- validate: func(t *testing.T, stmt Statement) {
- selectStmt, ok := stmt.(*SelectStatement)
- if !ok {
- t.Fatalf("Expected *SelectStatement, got %T", stmt)
- }
- aliasedExpr := selectStmt.SelectExprs[0].(*AliasedExpr)
- funcExpr := aliasedExpr.Expr.(*FuncExpr)
- if funcExpr.Name.String() != "COUNT" {
- t.Errorf("Expected function name 'COUNT', got '%s'", funcExpr.Name.String())
- }
- if len(funcExpr.Exprs) != 1 {
- t.Fatalf("Expected 1 function argument, got %d", len(funcExpr.Exprs))
- }
- argExpr, ok := funcExpr.Exprs[0].(*AliasedExpr)
- if !ok {
- t.Errorf("Expected *AliasedExpr argument, got %T", funcExpr.Exprs[0])
- }
- colName, ok := argExpr.Expr.(*ColName)
- if !ok {
- t.Errorf("Expected *ColName, got %T", argExpr.Expr)
- }
- if colName.Name.String() != "user_id" {
- t.Errorf("Expected column name 'user_id', got '%s'", colName.Name.String())
- }
- },
- },
- {
- name: "Multiple aggregate functions",
- sql: "SELECT COUNT(*), SUM(amount), AVG(score) FROM transactions",
- wantErr: false,
- validate: func(t *testing.T, stmt Statement) {
- selectStmt, ok := stmt.(*SelectStatement)
- if !ok {
- t.Fatalf("Expected *SelectStatement, got %T", stmt)
- }
- if len(selectStmt.SelectExprs) != 3 {
- t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs))
- }
- // Verify COUNT(*)
- countExpr := selectStmt.SelectExprs[0].(*AliasedExpr)
- countFunc := countExpr.Expr.(*FuncExpr)
- if countFunc.Name.String() != "COUNT" {
- t.Errorf("Expected first function to be COUNT, got %s", countFunc.Name.String())
- }
- // Verify SUM(amount)
- sumExpr := selectStmt.SelectExprs[1].(*AliasedExpr)
- sumFunc := sumExpr.Expr.(*FuncExpr)
- if sumFunc.Name.String() != "SUM" {
- t.Errorf("Expected second function to be SUM, got %s", sumFunc.Name.String())
- }
- // Verify AVG(score)
- avgExpr := selectStmt.SelectExprs[2].(*AliasedExpr)
- avgFunc := avgExpr.Expr.(*FuncExpr)
- if avgFunc.Name.String() != "AVG" {
- t.Errorf("Expected third function to be AVG, got %s", avgFunc.Name.String())
- }
- },
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- stmt, err := ParseSQL(tt.sql)
- if tt.wantErr {
- if err == nil {
- t.Errorf("Expected error, but got none")
- }
- return
- }
- if err != nil {
- t.Errorf("Unexpected error: %v", err)
- return
- }
- if tt.validate != nil {
- tt.validate(t, stmt)
- }
- })
- }
- }
- func TestParseSQL_SELECT_Expressions(t *testing.T) {
- tests := []struct {
- name string
- sql string
- wantErr bool
- validate func(t *testing.T, stmt Statement)
- }{
- {
- name: "SELECT * FROM table",
- sql: "SELECT * FROM users",
- wantErr: false,
- validate: func(t *testing.T, stmt Statement) {
- selectStmt := stmt.(*SelectStatement)
- if len(selectStmt.SelectExprs) != 1 {
- t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs))
- }
- _, ok := selectStmt.SelectExprs[0].(*StarExpr)
- if !ok {
- t.Errorf("Expected *StarExpr, got %T", selectStmt.SelectExprs[0])
- }
- },
- },
- {
- name: "SELECT column FROM table",
- sql: "SELECT user_id FROM users",
- wantErr: false,
- validate: func(t *testing.T, stmt Statement) {
- selectStmt := stmt.(*SelectStatement)
- if len(selectStmt.SelectExprs) != 1 {
- t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs))
- }
- aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr)
- if !ok {
- t.Fatalf("Expected *AliasedExpr, got %T", selectStmt.SelectExprs[0])
- }
- colName, ok := aliasedExpr.Expr.(*ColName)
- if !ok {
- t.Fatalf("Expected *ColName, got %T", aliasedExpr.Expr)
- }
- if colName.Name.String() != "user_id" {
- t.Errorf("Expected column name 'user_id', got '%s'", colName.Name.String())
- }
- },
- },
- {
- name: "SELECT multiple columns",
- sql: "SELECT user_id, name, email FROM users",
- wantErr: false,
- validate: func(t *testing.T, stmt Statement) {
- selectStmt := stmt.(*SelectStatement)
- if len(selectStmt.SelectExprs) != 3 {
- t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs))
- }
- expectedColumns := []string{"user_id", "name", "email"}
- for i, expected := range expectedColumns {
- aliasedExpr := selectStmt.SelectExprs[i].(*AliasedExpr)
- colName := aliasedExpr.Expr.(*ColName)
- if colName.Name.String() != expected {
- t.Errorf("Expected column %d to be '%s', got '%s'", i, expected, colName.Name.String())
- }
- }
- },
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- stmt, err := ParseSQL(tt.sql)
- if tt.wantErr {
- if err == nil {
- t.Errorf("Expected error, but got none")
- }
- return
- }
- if err != nil {
- t.Errorf("Unexpected error: %v", err)
- return
- }
- if tt.validate != nil {
- tt.validate(t, stmt)
- }
- })
- }
- }
- func TestParseSQL_WHERE_Clauses(t *testing.T) {
- tests := []struct {
- name string
- sql string
- wantErr bool
- validate func(t *testing.T, stmt Statement)
- }{
- {
- name: "WHERE with simple comparison",
- sql: "SELECT * FROM users WHERE age > 18",
- wantErr: false,
- validate: func(t *testing.T, stmt Statement) {
- selectStmt := stmt.(*SelectStatement)
- if selectStmt.Where == nil {
- t.Fatal("Expected WHERE clause, got nil")
- }
- // Just verify we have a WHERE clause with an expression
- if selectStmt.Where.Expr == nil {
- t.Error("Expected WHERE expression, got nil")
- }
- },
- },
- {
- name: "WHERE with AND condition",
- sql: "SELECT * FROM users WHERE age > 18 AND status = 'active'",
- wantErr: false,
- validate: func(t *testing.T, stmt Statement) {
- selectStmt := stmt.(*SelectStatement)
- if selectStmt.Where == nil {
- t.Fatal("Expected WHERE clause, got nil")
- }
- // Verify we have an AND expression
- andExpr, ok := selectStmt.Where.Expr.(*AndExpr)
- if !ok {
- t.Errorf("Expected *AndExpr, got %T", selectStmt.Where.Expr)
- }
- _ = andExpr // Use variable to avoid unused error
- },
- },
- {
- name: "WHERE with OR condition",
- sql: "SELECT * FROM users WHERE age < 18 OR age > 65",
- wantErr: false,
- validate: func(t *testing.T, stmt Statement) {
- selectStmt := stmt.(*SelectStatement)
- if selectStmt.Where == nil {
- t.Fatal("Expected WHERE clause, got nil")
- }
- // Verify we have an OR expression
- orExpr, ok := selectStmt.Where.Expr.(*OrExpr)
- if !ok {
- t.Errorf("Expected *OrExpr, got %T", selectStmt.Where.Expr)
- }
- _ = orExpr // Use variable to avoid unused error
- },
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- stmt, err := ParseSQL(tt.sql)
- if tt.wantErr {
- if err == nil {
- t.Errorf("Expected error, but got none")
- }
- return
- }
- if err != nil {
- t.Errorf("Unexpected error: %v", err)
- return
- }
- if tt.validate != nil {
- tt.validate(t, stmt)
- }
- })
- }
- }
- func TestParseSQL_LIMIT_Clauses(t *testing.T) {
- tests := []struct {
- name string
- sql string
- wantErr bool
- validate func(t *testing.T, stmt Statement)
- }{
- {
- name: "LIMIT with number",
- sql: "SELECT * FROM users LIMIT 10",
- wantErr: false,
- validate: func(t *testing.T, stmt Statement) {
- selectStmt := stmt.(*SelectStatement)
- if selectStmt.Limit == nil {
- t.Fatal("Expected LIMIT clause, got nil")
- }
- if selectStmt.Limit.Rowcount == nil {
- t.Error("Expected LIMIT rowcount, got nil")
- }
- // Verify no OFFSET is set
- if selectStmt.Limit.Offset != nil {
- t.Error("Expected OFFSET to be nil for LIMIT-only query")
- }
- sqlVal, ok := selectStmt.Limit.Rowcount.(*SQLVal)
- if !ok {
- t.Errorf("Expected *SQLVal, got %T", selectStmt.Limit.Rowcount)
- }
- if sqlVal.Type != IntVal {
- t.Errorf("Expected IntVal type, got %d", sqlVal.Type)
- }
- if string(sqlVal.Val) != "10" {
- t.Errorf("Expected limit value '10', got '%s'", string(sqlVal.Val))
- }
- },
- },
- {
- name: "LIMIT with OFFSET",
- sql: "SELECT * FROM users LIMIT 10 OFFSET 5",
- wantErr: false,
- validate: func(t *testing.T, stmt Statement) {
- selectStmt := stmt.(*SelectStatement)
- if selectStmt.Limit == nil {
- t.Fatal("Expected LIMIT clause, got nil")
- }
- // Verify LIMIT value
- if selectStmt.Limit.Rowcount == nil {
- t.Error("Expected LIMIT rowcount, got nil")
- }
- limitVal, ok := selectStmt.Limit.Rowcount.(*SQLVal)
- if !ok {
- t.Errorf("Expected *SQLVal for LIMIT, got %T", selectStmt.Limit.Rowcount)
- }
- if limitVal.Type != IntVal {
- t.Errorf("Expected IntVal type for LIMIT, got %d", limitVal.Type)
- }
- if string(limitVal.Val) != "10" {
- t.Errorf("Expected limit value '10', got '%s'", string(limitVal.Val))
- }
- // Verify OFFSET value
- if selectStmt.Limit.Offset == nil {
- t.Fatal("Expected OFFSET clause, got nil")
- }
- offsetVal, ok := selectStmt.Limit.Offset.(*SQLVal)
- if !ok {
- t.Errorf("Expected *SQLVal for OFFSET, got %T", selectStmt.Limit.Offset)
- }
- if offsetVal.Type != IntVal {
- t.Errorf("Expected IntVal type for OFFSET, got %d", offsetVal.Type)
- }
- if string(offsetVal.Val) != "5" {
- t.Errorf("Expected offset value '5', got '%s'", string(offsetVal.Val))
- }
- },
- },
- {
- name: "LIMIT with OFFSET zero",
- sql: "SELECT * FROM users LIMIT 5 OFFSET 0",
- wantErr: false,
- validate: func(t *testing.T, stmt Statement) {
- selectStmt := stmt.(*SelectStatement)
- if selectStmt.Limit == nil {
- t.Fatal("Expected LIMIT clause, got nil")
- }
- // Verify OFFSET is 0
- if selectStmt.Limit.Offset == nil {
- t.Fatal("Expected OFFSET clause, got nil")
- }
- offsetVal, ok := selectStmt.Limit.Offset.(*SQLVal)
- if !ok {
- t.Errorf("Expected *SQLVal for OFFSET, got %T", selectStmt.Limit.Offset)
- }
- if string(offsetVal.Val) != "0" {
- t.Errorf("Expected offset value '0', got '%s'", string(offsetVal.Val))
- }
- },
- },
- {
- name: "LIMIT with large OFFSET",
- sql: "SELECT * FROM users LIMIT 100 OFFSET 1000",
- wantErr: false,
- validate: func(t *testing.T, stmt Statement) {
- selectStmt := stmt.(*SelectStatement)
- if selectStmt.Limit == nil {
- t.Fatal("Expected LIMIT clause, got nil")
- }
- // Verify large OFFSET value
- offsetVal, ok := selectStmt.Limit.Offset.(*SQLVal)
- if !ok {
- t.Errorf("Expected *SQLVal for OFFSET, got %T", selectStmt.Limit.Offset)
- }
- if string(offsetVal.Val) != "1000" {
- t.Errorf("Expected offset value '1000', got '%s'", string(offsetVal.Val))
- }
- },
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- stmt, err := ParseSQL(tt.sql)
- if tt.wantErr {
- if err == nil {
- t.Errorf("Expected error, but got none")
- }
- return
- }
- if err != nil {
- t.Errorf("Unexpected error: %v", err)
- return
- }
- if tt.validate != nil {
- tt.validate(t, stmt)
- }
- })
- }
- }
- func TestParseSQL_SHOW_Statements(t *testing.T) {
- tests := []struct {
- name string
- sql string
- wantErr bool
- validate func(t *testing.T, stmt Statement)
- }{
- {
- name: "SHOW DATABASES",
- sql: "SHOW DATABASES",
- wantErr: false,
- validate: func(t *testing.T, stmt Statement) {
- showStmt, ok := stmt.(*ShowStatement)
- if !ok {
- t.Fatalf("Expected *ShowStatement, got %T", stmt)
- }
- if showStmt.Type != "databases" {
- t.Errorf("Expected type 'databases', got '%s'", showStmt.Type)
- }
- },
- },
- {
- name: "SHOW TABLES",
- sql: "SHOW TABLES",
- wantErr: false,
- validate: func(t *testing.T, stmt Statement) {
- showStmt, ok := stmt.(*ShowStatement)
- if !ok {
- t.Fatalf("Expected *ShowStatement, got %T", stmt)
- }
- if showStmt.Type != "tables" {
- t.Errorf("Expected type 'tables', got '%s'", showStmt.Type)
- }
- },
- },
- {
- name: "SHOW TABLES FROM database",
- sql: "SHOW TABLES FROM \"test_db\"",
- wantErr: false,
- validate: func(t *testing.T, stmt Statement) {
- showStmt, ok := stmt.(*ShowStatement)
- if !ok {
- t.Fatalf("Expected *ShowStatement, got %T", stmt)
- }
- if showStmt.Type != "tables" {
- t.Errorf("Expected type 'tables', got '%s'", showStmt.Type)
- }
- if showStmt.Schema != "test_db" {
- t.Errorf("Expected schema 'test_db', got '%s'", showStmt.Schema)
- }
- },
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- stmt, err := ParseSQL(tt.sql)
- if tt.wantErr {
- if err == nil {
- t.Errorf("Expected error, but got none")
- }
- return
- }
- if err != nil {
- t.Errorf("Unexpected error: %v", err)
- return
- }
- if tt.validate != nil {
- tt.validate(t, stmt)
- }
- })
- }
- }
|