-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgorm.go
175 lines (152 loc) · 5.74 KB
/
gorm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
package repositorysdk
import (
"gorm.io/gorm"
"math"
)
// Entity is the type constraint shared by the generic helpers and the
// repository in this package. TableName reports the table the model is stored
// in; GORM uses a model's TableName() string method, when present, to choose
// the table for queries on that model.
type Entity interface {
// TableName returns the database table name for this model.
TableName() string
}
// Pagination runs a COUNT over the rows selected by scopes and returns a GORM
// scope that restricts a query to the page described by meta.
//
// The count runs immediately, when Pagination is called. It uses db and the
// element type of value; the slice contents are not used. It stores the number
// of matching rows in meta.TotalItem and the number of pages in meta.TotalPage.
// TotalPage is 0 when meta.GetItemPerPage() is not positive.
//
// The returned scope only adds OFFSET meta.GetOffset() and LIMIT
// meta.GetItemPerPage(). It does not re-apply scopes, so the query it is
// applied to must carry the same scopes for the page and the totals to agree.
//
// If the COUNT fails, the returned scope records that error on the query it is
// applied to instead of paginating it. The failure is then reported by that
// query's finisher (Find, First, ...) rather than being silently dropped.
func Pagination[T Entity](value *[]T, meta *PaginationMetadata, db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) func(db *gorm.DB) *gorm.DB {
	var totalItems int64
	// Model(value) only lets GORM resolve the table from T.
	countErr := db.Model(value).
		Scopes(scopes...).
		Count(&totalItems).
		Error

	itemsPerPage := meta.GetItemPerPage()

	// Guard the division. Dividing by zero yields Inf/NaN, and converting
	// those to int is implementation-defined in Go.
	totalPages := 0
	if itemsPerPage > 0 {
		totalPages = int(math.Ceil(float64(totalItems) / float64(itemsPerPage)))
	}
	meta.TotalItem = int(totalItems)
	meta.TotalPage = totalPages

	return func(db *gorm.DB) *gorm.DB {
		if countErr != nil {
			_ = db.AddError(countErr)
			return db
		}
		// Use the same page size TotalPage was computed from, so the page
		// boundaries match the reported page count.
		return db.Offset(meta.GetOffset()).Limit(itemsPerPage)
	}
}
// FindOneByID builds a scope that runs First for the row whose id column
// equals id, loading it into the closure's copy of entity.
//
// entity is captured by value. The loaded row is therefore only visible to
// the caller when T is a pointer type.
func FindOneByID[T Entity](id string, entity T) func(db *gorm.DB) *gorm.DB {
	findFirst := func(db *gorm.DB) *gorm.DB {
		target := &entity
		return db.First(target, "id = ?", id)
	}
	return findFirst
}
// UpdateWithoutResult builds a scope that updates the row whose id column
// equals id using the fields of entity (GORM's Updates with a struct, which
// skips zero-valued fields). The row is not reloaded afterwards.
//
// entity is captured by value. Any values GORM writes back into it are only
// visible to the caller when T is a pointer type.
func UpdateWithoutResult[T Entity](id string, entity T) func(db *gorm.DB) *gorm.DB {
	return func(db *gorm.DB) *gorm.DB {
		// Where takes the SQL condition first and its bind values after it.
		// Passing id as the first argument would make GORM treat the id itself
		// as the condition (a primary-key IN list that also contains the
		// literal "id = ?").
		return db.
			Where("id = ?", id).
			Updates(&entity)
	}
}
// UpdateByIDWithResult builds a scope that updates the row whose id column
// equals id using the fields of entity (GORM's Updates with a struct, which
// skips zero-valued fields). It then reloads that row into entity with First.
//
// The reload is chained on the same statement as the update. If no row has
// that id, First reports gorm.ErrRecordNotFound. entity is captured by value,
// so the reloaded row is only visible to the caller when T is a pointer type.
func UpdateByIDWithResult[T Entity](id string, entity T) func(db *gorm.DB) *gorm.DB {
	return func(db *gorm.DB) *gorm.DB {
		// Where takes the SQL condition first and its bind values after it.
		// Passing id first would turn the id itself into the condition.
		return db.
			Where("id = ?", id).
			Updates(&entity).
			First(&entity, "id = ?", id)
	}
}
// DeleteWithResult builds a scope that first loads the row whose id column
// equals id into entity, then deletes it on the same chained statement.
//
// A failed lookup (for example gorm.ErrRecordNotFound) stays on the returned
// *gorm.DB. entity is captured by value, so the loaded row is only visible to
// the caller when T is a pointer type.
func DeleteWithResult[T Entity](id string, entity T) func(db *gorm.DB) *gorm.DB {
	return func(db *gorm.DB) *gorm.DB {
		found := db.First(&entity, "id = ?", id)
		return found.Delete(&entity, "id = ?", id)
	}
}
// DeleteWithoutResult builds a scope that deletes the row whose id column
// equals id. It issues the delete directly, without loading the row first.
// entity is used only to tell GORM which model/table to delete from.
func DeleteWithoutResult[T Entity](id string, entity T) func(db *gorm.DB) *gorm.DB {
	deleteByID := func(db *gorm.DB) *gorm.DB {
		return db.Delete(&entity, "id = ?", id)
	}
	return deleteByID
}
// GormRepository is a generic CRUD repository for entities of type T backed by
// GORM. Every method except GetDB accepts optional GORM scopes
// (func(*gorm.DB) *gorm.DB) for the query it runs. Entities are looked up by
// their "id" column.
type GormRepository[T Entity] interface {
// FindAll loads one page of entities into entities and updates metadata with
// the pagination counts.
FindAll(metadata *PaginationMetadata, entities *[]T, scope ...func(db *gorm.DB) *gorm.DB) error
// FindOne loads the entity whose id column equals id into entity.
FindOne(id string, entity T, scope ...func(db *gorm.DB) *gorm.DB) error
// Create inserts entity.
Create(entity T, scope ...func(db *gorm.DB) *gorm.DB) error
// Update writes entity's fields to the row with the given id and reloads that
// row into entity.
Update(id string, entity T, scope ...func(db *gorm.DB) *gorm.DB) error
// Delete loads the row with the given id into entity, then deletes it.
Delete(id string, entity T, scope ...func(db *gorm.DB) *gorm.DB) error
// GetDB returns the underlying *gorm.DB handle.
GetDB() *gorm.DB
}
// gormRepository is the GormRepository implementation returned by
// NewGormRepository. Every method builds its query from db.
type gormRepository[T Entity] struct {
// db is the GORM handle all queries start from.
db *gorm.DB
}
// NewGormRepository returns a GormRepository for entities of type T whose
// queries all run through db.
func NewGormRepository[T Entity](db *gorm.DB) GormRepository[T] {
	repo := new(gormRepository[T])
	repo.db = db
	return repo
}
// GetDB returns the *gorm.DB handle stored in the repository.
func (r *gormRepository[T]) GetDB() *gorm.DB { return r.db }
// FindAll loads one page of entities into entities and fills metadata.
//
// Pagination runs a COUNT query that sets metadata.TotalItem and
// metadata.TotalPage. The optional scopes are applied both to that count and
// to the page query itself, so the returned rows and the totals describe the
// same filtered set. On success, metadata.ItemCount holds the number of rows
// loaded into entities.
func (r *gormRepository[T]) FindAll(metadata *PaginationMetadata, entities *[]T, scope ...func(db *gorm.DB) *gorm.DB) error {
	// Pagination only adds OFFSET/LIMIT to the query it is applied to. The
	// caller's scopes must therefore be applied here as well; otherwise they
	// filter only the count, not the rows returned.
	err := r.db.
		Scopes(scope...).
		Scopes(Pagination[T](entities, metadata, r.db, scope...)).
		Find(entities).
		Error
	if err != nil {
		return err
	}
	metadata.ItemCount = len(*entities)
	return nil
}
// FindOne applies the optional scopes, then runs First to load the row whose
// id column equals id into entity.
//
// entity is handed to GORM as-is (not &entity), so T is expected to be a
// pointer type. GORM needs an addressable destination to write the row into.
func (r *gormRepository[T]) FindOne(id string, entity T, scope ...func(db *gorm.DB) *gorm.DB) error {
	scoped := r.db.Scopes(scope...)
	result := scoped.First(entity, "id = ?", id)
	return result.Error
}
// Create applies the optional scopes and inserts entity, returning GORM's
// error, if any.
func (r *gormRepository[T]) Create(entity T, scope ...func(db *gorm.DB) *gorm.DB) error {
	scoped := r.db.Scopes(scope...)
	if err := scoped.Create(entity).Error; err != nil {
		return err
	}
	return nil
}
// Update applies the optional scopes and updates the row whose id column
// equals id using the fields of entity (GORM's Updates with a struct, which
// skips zero-valued fields). It then reloads that row into entity.
//
// The update itself does not fail when no row matches. The chained First then
// returns gorm.ErrRecordNotFound, which is how a missing id is reported.
// entity is captured by value, so the reloaded row is only visible to the
// caller when T is a pointer type.
func (r *gormRepository[T]) Update(id string, entity T, scope ...func(db *gorm.DB) *gorm.DB) error {
	// Where takes the SQL condition first and its bind values after it.
	// Passing id first would turn the id itself into the condition.
	return r.db.
		Scopes(scope...).
		Where("id = ?", id).
		Updates(&entity).
		First(&entity, "id = ?", id).
		Error
}
// Delete applies the optional scopes, loads the row whose id column equals id
// into entity, and then deletes that row on the same chained statement.
//
// A missing row surfaces as the lookup's error (gorm.ErrRecordNotFound), and
// that error is what Delete returns.
func (r *gormRepository[T]) Delete(id string, entity T, scope ...func(db *gorm.DB) *gorm.DB) error {
	lookup := r.db.Scopes(scope...).First(&entity, "id = ?", id)
	return lookup.Delete(&entity).Error
}
// WithTransaction runs fns, in order, inside a single database transaction.
//
// Each fn receives the transaction handle and should issue its queries through
// it. Commit and rollback behave as follows:
//   - The transaction is committed only if every fn returns nil.
//   - The first non-nil error stops execution, rolls the transaction back and
//     is returned.
//   - A panic inside a fn rolls the transaction back and keeps propagating.
//   - Failures to begin or commit the transaction are returned as errors.
//
// If r.db is itself already a transaction, GORM uses a SAVEPOINT for the
// nested transaction (unless nested transactions are disabled in its config).
//
// NOTE(review): WithTransaction is not part of the GormRepository interface,
// so callers holding that interface need a type assertion to reach it.
//
// Parameters:
//   - fns: functions executed, in order, within one transaction.
//
// Returns:
//   - error: the first error returned by a fn, or the begin/commit error; nil on success.
func (r *gormRepository[T]) WithTransaction(fns ...func(tx *gorm.DB) error) error {
	// gorm.DB.Transaction owns begin/commit/rollback. It commits only when the
	// callback returns nil and rolls back on error or panic. It never attempts
	// a commit after a rollback, so a fn error is returned rather than turned
	// into a panic.
	return r.db.Transaction(func(tx *gorm.DB) error {
		for _, fn := range fns {
			if err := fn(tx); err != nil {
				return err
			}
		}
		return nil
	})
}