From 400d51c9279d741c1df9301315b73271f6ffeb82 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 12 Nov 2024 12:23:58 +0700 Subject: [PATCH] feat(expression): support case-when expression --- clause/expr_case.go | 54 ++++++++++++++++++++++++++++++++++++ clause/expr_case_test.go | 60 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 clause/expr_case.go create mode 100644 clause/expr_case_test.go diff --git a/clause/expr_case.go b/clause/expr_case.go new file mode 100644 index 000000000..bdcce9ab6 --- /dev/null +++ b/clause/expr_case.go @@ -0,0 +1,54 @@ +package clause + +type ExprCaseCondition struct { + When string + Then string + Vars []any +} + +type ExprCaseElse struct { + Then string + Vars []any +} + +type ExprCase struct { + Cases []*ExprCaseCondition + Else *ExprCaseElse +} + +func (expr ExprCase) Name() string { + return "CASE" +} + +func (expr ExprCase) Build(builder Builder) { + var vars []any + for idx, condition := range expr.Cases { + if idx > 0 { + _ = builder.WriteByte(' ') + } + _, _ = builder.WriteString("WHEN ") + _, _ = builder.WriteString(condition.When) + _, _ = builder.WriteString(" THEN ") + _, _ = builder.WriteString(condition.Then) + if len(condition.Vars) > 0 { + vars = append(vars, condition.Vars...) + } + } + + if expr.Else != nil { + elseExpr := expr.Else + _, _ = builder.WriteString(" ELSE ") + _, _ = builder.WriteString(elseExpr.Then) + if len(elseExpr.Vars) > 0 { + vars = append(vars, elseExpr.Vars...) + } + } + _, _ = builder.WriteString(" END") + + clauseExpr := Expr{SQL: "", Vars: vars} + clauseExpr.Build(builder) +} + +func (expr ExprCase) MergeClause(clause *Clause) { + clause.Expression = expr +} diff --git a/clause/expr_case_test.go b/clause/expr_case_test.go new file mode 100644 index 000000000..c06d09e2e --- /dev/null +++ b/clause/expr_case_test.go @@ -0,0 +1,60 @@ +package clause_test + +import ( + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func Test_ExprCase(t *testing.T) { + type exampleUser struct { + ID string + Name string + } + + inputUsers := []*exampleUser{ + { + ID: "user-001", + Name: "user-name-001", + }, + { + ID: "user-002", + Name: "user-name-002", + }, + } + + userIDs := make([]string, len(inputUsers)) + userNameCases := make([]*clause.ExprCaseCondition, len(inputUsers)) + for idx, user := range inputUsers { + userIDs[idx] = user.ID + userNameCases[idx] = &clause.ExprCaseCondition{ + When: "user_id=?", + Then: "?", + Vars: []any{ + user.ID, + user.Name, + }, + } + } + + sqlQuery := db.ToSQL(func(db *gorm.DB) *gorm.DB { + return db. + Table("users"). + Where("user_id IN (?)", userIDs). + UpdateColumns(map[string]any{ + "user_name": clause.ExprCase{ + Cases: userNameCases, + Else: &clause.ExprCaseElse{ + Then: "user_name", + Vars: nil, + }, + }, + }) + }) + + expectedSQLQuery := "UPDATE `users` SET `user_name`=CASE WHEN user_id=\"user-001\" THEN \"user-name-001\" WHEN user_id=\"user-002\" THEN \"user-name-002\" ELSE user_name END WHERE user_id IN (\"user-001\",\"user-002\")" + if sqlQuery != expectedSQLQuery { + t.Errorf("SQLQuery is mismatch actual: %v expected:%v\n", sqlQuery, expectedSQLQuery) + } +}