Skip to content

Commit

Permalink
feat: support more factors in filter
Browse files Browse the repository at this point in the history
  • Loading branch information
johnnyjoygh committed Feb 2, 2025
1 parent 2a392b8 commit b9a0c56
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 17 deletions.
5 changes: 5 additions & 0 deletions plugin/filter/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@ import (
// MemoFilterCELAttributes are the CEL attributes for memo.
var MemoFilterCELAttributes = []cel.EnvOption{
cel.Variable("content", cel.StringType),
// As the built-in timestamp type is deprecated, we use string type for now.
// e.g., "2021-01-01T00:00:00Z"
cel.Variable("create_time", cel.StringType),
cel.Variable("tag", cel.StringType),
cel.Variable("update_time", cel.StringType),
cel.Variable("visibility", cel.StringType),
}

// Parse parses the filter string and returns the parsed expression.
Expand Down
1 change: 0 additions & 1 deletion server/router/api/v1/memo_service_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ var MemoFilterCELAttributes = []cel.EnvOption{
cel.Variable("display_time_before", cel.IntType),
cel.Variable("display_time_after", cel.IntType),
cel.Variable("creator", cel.StringType),
cel.Variable("uid", cel.StringType),
cel.Variable("state", cel.StringType),
cel.Variable("random", cel.BoolType),
cel.Variable("limit", cel.IntType),
Expand Down
15 changes: 15 additions & 0 deletions store/db/sqlite/memo.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"

"github.com/usememos/memos/plugin/filter"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
Expand Down Expand Up @@ -100,6 +101,20 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE")
}
}
if v := find.Filter; v != nil {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(*v, filter.MemoFilterCELAttributes...)
if err != nil {
return nil, err
}
// RestoreExprToSQL parses the expression and returns the SQL condition.
condition, err := RestoreExprToSQL(parsedExpr.GetExpr())
if err != nil {
return nil, err
}
where = append(where, condition)
}
if find.ExcludeComments {
where = append(where, "`parent_id` IS NULL")
}
Expand Down
79 changes: 65 additions & 14 deletions store/db/sqlite/memo_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"slices"
"strings"
"time"

"github.com/pkg/errors"
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
Expand Down Expand Up @@ -36,15 +37,55 @@ func RestoreExprToSQL(expr *exprv1.Expr) (string, error) {
if len(v.CallExpr.Args) != 2 {
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
// TODO(j): Implement this part.
identifier := v.CallExpr.Args[0].GetIdentExpr().GetName()
if !slices.Contains([]string{"create_time", "update_time"}, identifier) {
return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
}
value, err := filter.GetConstValue(v.CallExpr.Args[1])
if err != nil {
return "", err
}
operator := "="
switch v.CallExpr.Function {
case "_==_":
operator = "="
case "_!=_":
operator = "!="
case "_<_":
operator = "<"
case "_>_":
operator = ">"
case "_<=_":
operator = "<="
case "_>=_":
operator = ">="
}

if identifier == "create_time" || identifier == "update_time" {
timestampStr, ok := value.(string)
if !ok {
return "", errors.New("invalid timestamp value")
}
timestamp, err := time.Parse(time.RFC3339, timestampStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse timestamp")
}

if identifier == "create_time" {
condition = fmt.Sprintf("`memo`.`created_ts` %s %d", operator, timestamp.Unix())
} else if identifier == "update_time" {
condition = fmt.Sprintf("`memo`.`updated_ts` %s %d", operator, timestamp.Unix())
}
}
case "@in":
if len(v.CallExpr.Args) != 2 {
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
factor := v.CallExpr.Args[0].GetIdentExpr().Name
if !slices.Contains([]string{"tag"}, factor) {
return "", errors.Errorf("invalid factor for %s", v.CallExpr.Function)
identifier := v.CallExpr.Args[0].GetIdentExpr().GetName()
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
}

values := []any{}
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
value, err := filter.GetConstValue(element)
Expand All @@ -53,33 +94,43 @@ func RestoreExprToSQL(expr *exprv1.Expr) (string, error) {
}
values = append(values, value)
}
if factor == "tag" {
t := []string{}
if identifier == "tag" {
subcodition := []string{}
for _, v := range values {
subcodition = append(subcodition, fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %s", fmt.Sprintf(`%%"%s"%%`, v)))
}
if len(subcodition) == 1 {
condition = subcodition[0]
} else {
condition = fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))
}
} else if identifier == "visibility" {
vs := []string{}
for _, v := range values {
t = append(t, fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %s", fmt.Sprintf(`%%"%s"%%`, v)))
vs = append(vs, fmt.Sprintf(`"%s"`, v))
}
if len(t) == 1 {
condition = t[0]
if len(vs) == 1 {
condition = fmt.Sprintf("`memo`.`visibility` = %s", vs[0])
} else {
condition = fmt.Sprintf("(%s)", strings.Join(t, " OR "))
condition = fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(vs, ","))
}
}
case "contains":
if len(v.CallExpr.Args) != 1 {
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
}
factor, err := RestoreExprToSQL(v.CallExpr.Target)
identifier, err := RestoreExprToSQL(v.CallExpr.Target)
if err != nil {
return "", err
}
if factor != "content" {
return "", errors.Errorf("invalid factor for %s", v.CallExpr.Function)
if identifier != "content" {
return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
}
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
if err != nil {
return "", err
}
condition = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.content') LIKE %s", fmt.Sprintf(`%%"%s"%%`, arg))
condition = fmt.Sprintf("`memo`.`content` LIKE %s", fmt.Sprintf(`%%"%s"%%`, arg))
case "!_":
if len(v.CallExpr.Args) != 1 {
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
Expand Down
16 changes: 14 additions & 2 deletions store/db/sqlite/memo_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,20 @@ func TestRestoreExprToSQL(t *testing.T) {
want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag3\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag4\"%) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag3\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag4\"%))",
},
{
filter: `content.contains("hello")`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.content') LIKE %\"hello\"%",
filter: `content.contains("memos")`,
want: "`memo`.`content` LIKE %\"memos\"%",
},
{
filter: `visibility in ["PUBLIC"]`,
want: "`memo`.`visibility` = \"PUBLIC\"",
},
{
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
want: "`memo`.`visibility` IN (\"PUBLIC\",\"PRIVATE\")",
},
{
filter: `create_time == "2006-01-02T15:04:05+07:00"`,
want: "`memo`.`created_ts` = 1136189045",
},
}

Expand Down
1 change: 1 addition & 0 deletions store/memo.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ type FindMemo struct {
ExcludeContent bool
ExcludeComments bool
Random bool
Filter *string

// Pagination
Limit *int
Expand Down

0 comments on commit b9a0c56

Please sign in to comment.