arithmetic_with_functions_test.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. package engine
  2. import (
  3. "context"
  4. "testing"
  5. )
  6. // TestArithmeticWithFunctions tests arithmetic operations with function calls
  7. // This validates the complete AST parser and evaluation system for column-level calculations
  8. func TestArithmeticWithFunctions(t *testing.T) {
  9. engine := NewTestSQLEngine()
  10. testCases := []struct {
  11. name string
  12. sql string
  13. expected string
  14. desc string
  15. }{
  16. {
  17. name: "Simple function arithmetic",
  18. sql: "SELECT LENGTH('hello') + 10 FROM user_events LIMIT 1",
  19. expected: "15",
  20. desc: "Basic function call with addition",
  21. },
  22. {
  23. name: "Nested functions with arithmetic",
  24. sql: "SELECT length(trim(' hello world ')) + 12 FROM user_events LIMIT 1",
  25. expected: "23",
  26. desc: "Complex nested functions with arithmetic operation (user's original failing query)",
  27. },
  28. {
  29. name: "Function subtraction",
  30. sql: "SELECT LENGTH('programming') - 5 FROM user_events LIMIT 1",
  31. expected: "6",
  32. desc: "Function call with subtraction",
  33. },
  34. {
  35. name: "Function multiplication",
  36. sql: "SELECT LENGTH('test') * 3 FROM user_events LIMIT 1",
  37. expected: "12",
  38. desc: "Function call with multiplication",
  39. },
  40. {
  41. name: "Multiple nested functions",
  42. sql: "SELECT LENGTH(UPPER(TRIM(' hello '))) FROM user_events LIMIT 1",
  43. expected: "5",
  44. desc: "Triple nested functions",
  45. },
  46. }
  47. for _, tc := range testCases {
  48. t.Run(tc.name, func(t *testing.T) {
  49. result, err := engine.ExecuteSQL(context.Background(), tc.sql)
  50. if err != nil {
  51. t.Errorf("Query failed: %v", err)
  52. return
  53. }
  54. if result.Error != nil {
  55. t.Errorf("Query result error: %v", result.Error)
  56. return
  57. }
  58. if len(result.Rows) == 0 {
  59. t.Error("Expected at least one row")
  60. return
  61. }
  62. actual := result.Rows[0][0].ToString()
  63. if actual != tc.expected {
  64. t.Errorf("%s: Expected '%s', got '%s'", tc.desc, tc.expected, actual)
  65. } else {
  66. t.Logf("PASS %s: %s → %s", tc.desc, tc.sql, actual)
  67. }
  68. })
  69. }
  70. }