arithmetic_test.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. package engine
  2. import (
  3. "fmt"
  4. "testing"
  5. "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
  6. )
  7. func TestArithmeticExpressionParsing(t *testing.T) {
  8. tests := []struct {
  9. name string
  10. expression string
  11. expectNil bool
  12. leftCol string
  13. rightCol string
  14. operator string
  15. }{
  16. {
  17. name: "simple addition",
  18. expression: "id+user_id",
  19. expectNil: false,
  20. leftCol: "id",
  21. rightCol: "user_id",
  22. operator: "+",
  23. },
  24. {
  25. name: "simple subtraction",
  26. expression: "col1-col2",
  27. expectNil: false,
  28. leftCol: "col1",
  29. rightCol: "col2",
  30. operator: "-",
  31. },
  32. {
  33. name: "multiplication with spaces",
  34. expression: "a * b",
  35. expectNil: false,
  36. leftCol: "a",
  37. rightCol: "b",
  38. operator: "*",
  39. },
  40. {
  41. name: "string concatenation",
  42. expression: "first_name||last_name",
  43. expectNil: false,
  44. leftCol: "first_name",
  45. rightCol: "last_name",
  46. operator: "||",
  47. },
  48. {
  49. name: "string concatenation with spaces",
  50. expression: "prefix || suffix",
  51. expectNil: false,
  52. leftCol: "prefix",
  53. rightCol: "suffix",
  54. operator: "||",
  55. },
  56. {
  57. name: "not arithmetic",
  58. expression: "simple_column",
  59. expectNil: true,
  60. },
  61. }
  62. for _, tt := range tests {
  63. t.Run(tt.name, func(t *testing.T) {
  64. // Use CockroachDB parser to parse the expression
  65. cockroachParser := NewCockroachSQLParser()
  66. dummySelect := fmt.Sprintf("SELECT %s", tt.expression)
  67. stmt, err := cockroachParser.ParseSQL(dummySelect)
  68. var result *ArithmeticExpr
  69. if err == nil {
  70. if selectStmt, ok := stmt.(*SelectStatement); ok && len(selectStmt.SelectExprs) > 0 {
  71. if aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr); ok {
  72. if arithmeticExpr, ok := aliasedExpr.Expr.(*ArithmeticExpr); ok {
  73. result = arithmeticExpr
  74. }
  75. }
  76. }
  77. }
  78. if tt.expectNil {
  79. if result != nil {
  80. t.Errorf("Expected nil for %s, got %v", tt.expression, result)
  81. }
  82. return
  83. }
  84. if result == nil {
  85. t.Errorf("Expected arithmetic expression for %s, got nil", tt.expression)
  86. return
  87. }
  88. if result.Operator != tt.operator {
  89. t.Errorf("Expected operator %s, got %s", tt.operator, result.Operator)
  90. }
  91. // Check left operand
  92. if leftCol, ok := result.Left.(*ColName); ok {
  93. if leftCol.Name.String() != tt.leftCol {
  94. t.Errorf("Expected left column %s, got %s", tt.leftCol, leftCol.Name.String())
  95. }
  96. } else {
  97. t.Errorf("Expected left operand to be ColName, got %T", result.Left)
  98. }
  99. // Check right operand
  100. if rightCol, ok := result.Right.(*ColName); ok {
  101. if rightCol.Name.String() != tt.rightCol {
  102. t.Errorf("Expected right column %s, got %s", tt.rightCol, rightCol.Name.String())
  103. }
  104. } else {
  105. t.Errorf("Expected right operand to be ColName, got %T", result.Right)
  106. }
  107. })
  108. }
  109. }
  110. func TestArithmeticExpressionEvaluation(t *testing.T) {
  111. engine := NewSQLEngine("")
  112. // Create test data
  113. result := HybridScanResult{
  114. Values: map[string]*schema_pb.Value{
  115. "id": {Kind: &schema_pb.Value_Int64Value{Int64Value: 10}},
  116. "user_id": {Kind: &schema_pb.Value_Int64Value{Int64Value: 5}},
  117. "price": {Kind: &schema_pb.Value_DoubleValue{DoubleValue: 25.5}},
  118. "qty": {Kind: &schema_pb.Value_Int64Value{Int64Value: 3}},
  119. "first_name": {Kind: &schema_pb.Value_StringValue{StringValue: "John"}},
  120. "last_name": {Kind: &schema_pb.Value_StringValue{StringValue: "Doe"}},
  121. "prefix": {Kind: &schema_pb.Value_StringValue{StringValue: "Hello"}},
  122. "suffix": {Kind: &schema_pb.Value_StringValue{StringValue: "World"}},
  123. },
  124. }
  125. tests := []struct {
  126. name string
  127. expression string
  128. expected interface{}
  129. }{
  130. {
  131. name: "integer addition",
  132. expression: "id+user_id",
  133. expected: int64(15),
  134. },
  135. {
  136. name: "integer subtraction",
  137. expression: "id-user_id",
  138. expected: int64(5),
  139. },
  140. {
  141. name: "mixed types multiplication",
  142. expression: "price*qty",
  143. expected: float64(76.5),
  144. },
  145. {
  146. name: "string concatenation",
  147. expression: "first_name||last_name",
  148. expected: "JohnDoe",
  149. },
  150. {
  151. name: "string concatenation with spaces",
  152. expression: "prefix || suffix",
  153. expected: "HelloWorld",
  154. },
  155. }
  156. for _, tt := range tests {
  157. t.Run(tt.name, func(t *testing.T) {
  158. // Parse the arithmetic expression using CockroachDB parser
  159. cockroachParser := NewCockroachSQLParser()
  160. dummySelect := fmt.Sprintf("SELECT %s", tt.expression)
  161. stmt, err := cockroachParser.ParseSQL(dummySelect)
  162. if err != nil {
  163. t.Fatalf("Failed to parse expression %s: %v", tt.expression, err)
  164. }
  165. var arithmeticExpr *ArithmeticExpr
  166. if selectStmt, ok := stmt.(*SelectStatement); ok && len(selectStmt.SelectExprs) > 0 {
  167. if aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr); ok {
  168. if arithExpr, ok := aliasedExpr.Expr.(*ArithmeticExpr); ok {
  169. arithmeticExpr = arithExpr
  170. }
  171. }
  172. }
  173. if arithmeticExpr == nil {
  174. t.Fatalf("Failed to parse arithmetic expression: %s", tt.expression)
  175. }
  176. // Evaluate the expression
  177. value, err := engine.evaluateArithmeticExpression(arithmeticExpr, result)
  178. if err != nil {
  179. t.Fatalf("Failed to evaluate expression: %v", err)
  180. }
  181. if value == nil {
  182. t.Fatalf("Got nil value for expression: %s", tt.expression)
  183. }
  184. // Check the result
  185. switch expected := tt.expected.(type) {
  186. case int64:
  187. if intVal, ok := value.Kind.(*schema_pb.Value_Int64Value); ok {
  188. if intVal.Int64Value != expected {
  189. t.Errorf("Expected %d, got %d", expected, intVal.Int64Value)
  190. }
  191. } else {
  192. t.Errorf("Expected int64 result, got %T", value.Kind)
  193. }
  194. case float64:
  195. if doubleVal, ok := value.Kind.(*schema_pb.Value_DoubleValue); ok {
  196. if doubleVal.DoubleValue != expected {
  197. t.Errorf("Expected %f, got %f", expected, doubleVal.DoubleValue)
  198. }
  199. } else {
  200. t.Errorf("Expected double result, got %T", value.Kind)
  201. }
  202. case string:
  203. if stringVal, ok := value.Kind.(*schema_pb.Value_StringValue); ok {
  204. if stringVal.StringValue != expected {
  205. t.Errorf("Expected %s, got %s", expected, stringVal.StringValue)
  206. }
  207. } else {
  208. t.Errorf("Expected string result, got %T", value.Kind)
  209. }
  210. }
  211. })
  212. }
  213. }
  214. func TestSelectArithmeticExpression(t *testing.T) {
  215. // Test parsing a SELECT with arithmetic and string concatenation expressions
  216. stmt, err := ParseSQL("SELECT id+user_id, user_id*2, first_name||last_name FROM test_table")
  217. if err != nil {
  218. t.Fatalf("Failed to parse SQL: %v", err)
  219. }
  220. selectStmt := stmt.(*SelectStatement)
  221. if len(selectStmt.SelectExprs) != 3 {
  222. t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs))
  223. }
  224. // Check first expression (id+user_id)
  225. aliasedExpr1 := selectStmt.SelectExprs[0].(*AliasedExpr)
  226. if arithmeticExpr1, ok := aliasedExpr1.Expr.(*ArithmeticExpr); ok {
  227. if arithmeticExpr1.Operator != "+" {
  228. t.Errorf("Expected + operator, got %s", arithmeticExpr1.Operator)
  229. }
  230. } else {
  231. t.Errorf("Expected arithmetic expression, got %T", aliasedExpr1.Expr)
  232. }
  233. // Check second expression (user_id*2)
  234. aliasedExpr2 := selectStmt.SelectExprs[1].(*AliasedExpr)
  235. if arithmeticExpr2, ok := aliasedExpr2.Expr.(*ArithmeticExpr); ok {
  236. if arithmeticExpr2.Operator != "*" {
  237. t.Errorf("Expected * operator, got %s", arithmeticExpr2.Operator)
  238. }
  239. } else {
  240. t.Errorf("Expected arithmetic expression, got %T", aliasedExpr2.Expr)
  241. }
  242. // Check third expression (first_name||last_name)
  243. aliasedExpr3 := selectStmt.SelectExprs[2].(*AliasedExpr)
  244. if arithmeticExpr3, ok := aliasedExpr3.Expr.(*ArithmeticExpr); ok {
  245. if arithmeticExpr3.Operator != "||" {
  246. t.Errorf("Expected || operator, got %s", arithmeticExpr3.Operator)
  247. }
  248. } else {
  249. t.Errorf("Expected string concatenation expression, got %T", aliasedExpr3.Expr)
  250. }
  251. }