Skip to content

Commit

Permalink
Use a DTO instead of the actual Goyave request
Browse files Browse the repository at this point in the history
  • Loading branch information
System-Glitch committed Nov 2, 2023
1 parent ff1d8d1 commit 7c17576
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 237 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
go: ["1.20", "1.21"]
go: ["1.21"]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-go@v3
Expand All @@ -36,5 +36,5 @@ jobs:
- name: Run lint
uses: golangci/golangci-lint-action@v3
with:
version: v1.54
version: v1.55
args: --timeout 5m
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import "goyave.dev/filter"

func (ctrl *UserController) Index(response *goyave.Response, request *goyave.Request) {
var users []*model.User
paginator, tx := filter.Scope(ctrl.DB(), request, &users)
paginator, tx := filter.Scope(ctrl.DB(), filter.NewRequest(request.Query), &users)
if response.WriteDBError(tx.Error) {
return
}
Expand All @@ -44,7 +44,7 @@ And **that's it**! Now your front-end can add query parameters to filter as it w
You can also find records without paginating using `ScopeUnpaginated()`:
```go
var users []*model.User
tx := filter.ScopeUnpaginated(ctrl.DB(), request, &users)
tx := filter.ScopeUnpaginated(ctrl.DB(), filter.NewRequest(request.Query), &users)
if response.WriteDBError(tx.Error) {
return
}
Expand Down Expand Up @@ -90,7 +90,7 @@ settings := &filter.Settings[*model.User]{
},
}
results := []*model.User{}
paginator, tx := settings.Scope(ctrl.DB(), request, &results)
paginator, tx := settings.Scope(ctrl.DB(), filter.NewRequest(request.Query), &results)
```

### Filter
Expand Down Expand Up @@ -298,7 +298,7 @@ If you want to add static conditions (not automatically defined by the library),
users := []model.User{}
db := ctrl.DB()
db = db.Where(db.Session(&gorm.Session{NewDB: true}).Where("username LIKE ?", "%Miss%").Or("username LIKE ?", "%Ms.%"))
paginator, tx := filter.Scope(db, request, &users)
paginator, tx := filter.Scope(db, filter.NewRequest(request.Query), &users)
if response.WriteDBError(tx.Error) {
return
}
Expand Down Expand Up @@ -422,7 +422,7 @@ func (ctrl *UserController) Index(response *goyave.Response, request *goyave.Req

db := ctrl.DB().Joins("Relation")

paginator, tx := filter.Scope(db, request, &users)
paginator, tx := filter.Scope(db, filter.NewRequest(request.Query), &users)
if response.WriteDBError(tx.Error) {
return
}
Expand Down
173 changes: 100 additions & 73 deletions settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,65 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"goyave.dev/goyave/v5"
"goyave.dev/goyave/v5/database"
"goyave.dev/goyave/v5/util/errors"
"goyave.dev/goyave/v5/util/typeutil"
)

// Request DTO for a filter query. Any non-present option will be ignored.
type Request struct {
Search typeutil.Undefined[string]
Filter typeutil.Undefined[[]*Filter]
Or typeutil.Undefined[[]*Filter]
Sort typeutil.Undefined[[]*Sort]
Join typeutil.Undefined[[]*Join]
Fields typeutil.Undefined[[]string]
Page typeutil.Undefined[int]
PerPage typeutil.Undefined[int]
}

// NewRequest creates a filter request from an HTTP request's query.
// Uses the following entries in the query, expected to be validated:
// - search
// - filter
// - or
// - sort
// - join
// - fields
// - page
// - per_page
//
// If a field in the query doesn't match the expected type (non-validated) for the
// filtering option, it will be ignored without an error.
func NewRequest(query map[string]any) *Request {
r := &Request{}
if search, ok := query["search"].(string); ok {
r.Search = typeutil.NewUndefined(search)
}
if filter, ok := query["filter"].([]*Filter); ok {
r.Filter = typeutil.NewUndefined(filter)
}
if or, ok := query["or"].([]*Filter); ok {
r.Or = typeutil.NewUndefined(or)
}
if sort, ok := query["sort"].([]*Sort); ok {
r.Sort = typeutil.NewUndefined(sort)
}
if join, ok := query["join"].([]*Join); ok {
r.Join = typeutil.NewUndefined(join)
}
if fields, ok := query["fields"].([]string); ok {
r.Fields = typeutil.NewUndefined(fields)
}
if page, ok := query["page"].(int); ok {
r.Page = typeutil.NewUndefined(page)
}
if perPage, ok := query["per_page"].(int); ok {
r.PerPage = typeutil.NewUndefined(perPage)
}
return r
}

// Settings settings to disable certain features and/or blacklist fields
// and relations.
// The generic type is the pointer type of the model.
Expand Down Expand Up @@ -71,30 +125,24 @@ func parseModel(db *gorm.DB, model any) (*schema.Schema, error) {
}

// Scope using the default FilterSettings. See `FilterSettings.Scope()` for more details.
func Scope[T any](db *gorm.DB, request *goyave.Request, dest *[]T) (*database.Paginator[T], *gorm.DB) {
func Scope[T any](db *gorm.DB, request *Request, dest *[]T) (*database.Paginator[T], *gorm.DB) {
return (&Settings[T]{}).Scope(db, request, dest)
}

// ScopeUnpaginated using the default FilterSettings. See `FilterSettings.ScopeUnpaginated()` for more details.
func ScopeUnpaginated[T any](db *gorm.DB, request *goyave.Request, dest *[]T) *gorm.DB {
func ScopeUnpaginated[T any](db *gorm.DB, request *Request, dest *[]T) *gorm.DB {
return (&Settings[T]{}).ScopeUnpaginated(db, request, dest)
}

// Scope apply all filters, sorts and joins defined in the request's data to the given `*gorm.DB`
// and process pagination. Returns the resulting `*database.Paginator` and the `*gorm.DB` result,
// which can be used to check for database errors.
// The given request is expected to be validated using `ApplyValidation`.
func (s *Settings[T]) Scope(db *gorm.DB, request *goyave.Request, dest *[]T) (*database.Paginator[T], *gorm.DB) {
func (s *Settings[T]) Scope(db *gorm.DB, request *Request, dest *[]T) (*database.Paginator[T], *gorm.DB) {
db, schema, hasJoins := s.scopeCommon(db, request, dest)

page := 1
if queryPage, ok := request.Query["page"]; ok {
page = queryPage.(int)
}
pageSize := DefaultPageSize
if queryPerPage, ok := request.Query["per_page"]; ok {
pageSize = queryPerPage.(int)
}
page := request.Page.Default(1)
pageSize := request.PerPage.Default(DefaultPageSize)

paginator := database.NewPaginator(db, page, pageSize, dest)
paginator.UpdatePageInfo()
Expand All @@ -114,7 +162,7 @@ func (s *Settings[T]) Scope(db *gorm.DB, request *goyave.Request, dest *[]T) (*d
// Returns the `*gorm.DB` result, which can be used to check for database errors.
// The records will be added in the given `dest` slice.
// The given request is expected to be validated using `ApplyValidation`.
func (s *Settings[T]) ScopeUnpaginated(db *gorm.DB, request *goyave.Request, dest any) *gorm.DB {
func (s *Settings[T]) ScopeUnpaginated(db *gorm.DB, request *Request, dest any) *gorm.DB {
db, schema, hasJoins := s.scopeCommon(db, request, dest)
db = s.scopeSort(db, request, schema)
if fieldsDB := s.scopeFields(db, request, schema, hasJoins); fieldsDB != nil {
Expand All @@ -127,7 +175,7 @@ func (s *Settings[T]) ScopeUnpaginated(db *gorm.DB, request *goyave.Request, des

// scopeCommon applies all scopes common to both the paginated and non-paginated requests.
// The third returned valued indicates if the query contains joins.
func (s *Settings[T]) scopeCommon(db *gorm.DB, request *goyave.Request, dest any) (*gorm.DB, *schema.Schema, bool) {
func (s *Settings[T]) scopeCommon(db *gorm.DB, request *Request, dest any) (*gorm.DB, *schema.Schema, bool) {
schema, err := parseModel(db, dest)
if err != nil {
panic(errors.New(err))
Expand All @@ -137,24 +185,20 @@ func (s *Settings[T]) scopeCommon(db *gorm.DB, request *goyave.Request, dest any
db = s.applyFilters(db, request, schema)

hasJoins := false
queryJoin, queryHasJoin := request.Query["join"]
if !s.DisableJoin && queryHasJoin {
joins, ok := queryJoin.([]*Join)
if ok {
selectCache := map[string][]string{}
for _, j := range joins {
hasJoins = true
j.selectCache = selectCache
if s := j.Scopes(s.Blacklist, schema); s != nil {
db = db.Scopes(s...)
}
if !s.DisableJoin && request.Join.Present {
joins := request.Join.Val
selectCache := map[string][]string{}
for _, j := range joins {
hasJoins = true
j.selectCache = selectCache
if s := j.Scopes(s.Blacklist, schema); s != nil {
db = db.Scopes(s...)
}
}
}

querySearch, queryHasSearch := request.Query["search"]
if !s.DisableSearch && queryHasSearch {
if search := s.applySearch(querySearch.(string), schema); search != nil {
if !s.DisableSearch && request.Search.Present {
if search := s.applySearch(request.Search.Val, schema); search != nil {
if scope := search.Scope(schema); scope != nil {
db = db.Scopes(scope)
}
Expand All @@ -171,10 +215,9 @@ func (s *Settings[T]) scopeCommon(db *gorm.DB, request *goyave.Request, dest any
return db, schema, hasJoins
}

func (s *Settings[T]) scopeFields(db *gorm.DB, request *goyave.Request, schema *schema.Schema, hasJoins bool) *gorm.DB {
queryFields, queryHasFields := request.Query["fields"].([]string)
if !s.DisableFields && queryHasFields {
fields := slices.Clone(queryFields)
func (s *Settings[T]) scopeFields(db *gorm.DB, request *Request, schema *schema.Schema, hasJoins bool) *gorm.DB {
if !s.DisableFields && request.Fields.Present {
fields := slices.Clone(request.Fields.Val)
if hasJoins {
if len(schema.PrimaryFieldDBNames) == 0 {
db.AddError(errors.New("could not find primary key. Add `gorm:\"primaryKey\"` to your model"))
Expand All @@ -188,14 +231,12 @@ func (s *Settings[T]) scopeFields(db *gorm.DB, request *goyave.Request, schema *
return db.Scopes(selectScope(schema.Table, getSelectableFields(&s.Blacklist, schema), false))
}

func (s *Settings[T]) scopeSort(db *gorm.DB, request *goyave.Request, schema *schema.Schema) *gorm.DB {
querySort, queryHasSort := request.Query["sort"]

func (s *Settings[T]) scopeSort(db *gorm.DB, request *Request, schema *schema.Schema) *gorm.DB {
var sorts []*Sort
if !queryHasSort {
if !request.Sort.Present {
sorts = s.DefaultSort
} else if s, ok := querySort.([]*Sort); ok {
sorts = s
} else {
sorts = request.Sort.Val
}

if !s.DisableSort {
Expand All @@ -208,42 +249,38 @@ func (s *Settings[T]) scopeSort(db *gorm.DB, request *goyave.Request, schema *sc
return db
}

func (s *Settings[T]) applyFilters(db *gorm.DB, request *goyave.Request, schema *schema.Schema) *gorm.DB {
func (s *Settings[T]) applyFilters(db *gorm.DB, request *Request, schema *schema.Schema) *gorm.DB {
if s.DisableFilter {
return db
}
filterScopes := make([]func(*gorm.DB) *gorm.DB, 0, 2)
joinScopes := make([]func(*gorm.DB) *gorm.DB, 0, 2)

andLen := filterLen[T](request, "filter")
orLen := filterLen[T](request, "or")
andLen := len(request.Filter.Default([]*Filter{}))
orLen := len(request.Or.Default([]*Filter{}))
mixed := orLen > 1 && andLen > 0

for _, queryParam := range []string{"filter", "or"} {
query, has := request.Query[queryParam]
if has {
filters, ok := query.([]*Filter)
if ok {
group := make([]func(*gorm.DB) *gorm.DB, 0, 4)
for _, f := range filters {
if mixed {
f = &Filter{
Field: f.Field,
Operator: f.Operator,
Args: f.Args,
Or: false,
}
}
joinScope, conditionScope := f.Scope(s.Blacklist, schema)
if conditionScope != nil {
group = append(group, conditionScope)
}
if joinScope != nil {
joinScopes = append(joinScopes, joinScope)
for _, filters := range []typeutil.Undefined[[]*Filter]{request.Filter, request.Or} {
if filters.Present {
group := make([]func(*gorm.DB) *gorm.DB, 0, 4)
for _, f := range filters.Val {
if mixed {
f = &Filter{
Field: f.Field,
Operator: f.Operator,
Args: f.Args,
Or: false,
}
}
filterScopes = append(filterScopes, groupFilters(group, false))
joinScope, conditionScope := f.Scope(s.Blacklist, schema)
if conditionScope != nil {
group = append(group, conditionScope)
}
if joinScope != nil {
joinScopes = append(joinScopes, joinScope)
}
}
filterScopes = append(filterScopes, groupFilters(group, false))
}
}
if len(joinScopes) > 0 {
Expand All @@ -255,16 +292,6 @@ func (s *Settings[T]) applyFilters(db *gorm.DB, request *goyave.Request, schema
return db
}

func filterLen[T any](request *goyave.Request, name string) int {
count := 0
if data, ok := request.Query[name]; ok {
if filters, ok := data.([]*Filter); ok {
count = len(filters)
}
}
return count
}

func groupFilters(scopes []func(*gorm.DB) *gorm.DB, and bool) func(*gorm.DB) *gorm.DB {
return func(tx *gorm.DB) *gorm.DB {
processedFilters := tx.Session(&gorm.Session{NewDB: true})
Expand Down
Loading

0 comments on commit 7c17576

Please sign in to comment.