Skip to content

Commit

Permalink
Added GetTotalRows to mock, and a function to make mock queries easier
Browse files Browse the repository at this point in the history
  • Loading branch information
adampresley committed May 3, 2023
1 parent 423e417 commit a51452f
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 0 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ go 1.16

require (
github.com/jackc/pgconn v1.8.1
github.com/jackc/pgproto3/v2 v2.0.6
github.com/jackc/pgx/v4 v4.11.0
)
5 changes: 5 additions & 0 deletions mock-postgresr.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type MockRows struct {
CommandTagFunc func() pgconn.CommandTag
ErrFunc func() error
FieldDescriptionsFunc func() []pgproto3.FieldDescription
GetTotalRowsFunc func() uint64
NextFunc func() bool
ScanFunc func(dest ...interface{}) error
ValuesFunc func() ([]interface{}, error)
Expand Down Expand Up @@ -67,6 +68,10 @@ func (m *MockRows) FieldDescriptions() []pgproto3.FieldDescription {
return m.FieldDescriptionsFunc()
}

func (m *MockRows) GetTotalRows() uint64 {
return m.GetTotalRowsFunc()
}

func (m *MockRows) Next() bool {
return m.NextFunc()
}
Expand Down
116 changes: 116 additions & 0 deletions testhelpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package postgresr

import (
"context"
"database/sql"
"fmt"
"time"

"github.com/jackc/pgx/v4"
)

func MockQuerySuccessHelper(rows [][]interface{}) func(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error) {
rowIndex := -1

return func(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error) {
return &MockRows{
CloseFunc: func() {},
GetTotalRowsFunc: func() uint64 {
return uint64(len(rows))
},
NextFunc: func() bool {
rowIndex++
return rowIndex < len(rows)
},
ScanFunc: func(dest ...interface{}) error {
data := rows[rowIndex]

for index, d := range dest {
switch t := d.(type) {
case *bool:
p := d.(*bool)
*p = data[index].(bool)

case *string:
p := d.(*string)
*p = data[index].(string)

case *int:
p := d.(*int)
*p = data[index].(int)

case *int32:
p := d.(*int32)
*p = data[index].(int32)

case *int64:
p := d.(*int64)
*p = data[index].(int64)

case *float32:
p := d.(*float32)
*p = data[index].(float32)

case *float64:
p := d.(*float64)
*p = data[index].(float64)

case *time.Time:
p := d.(*time.Time)
*p = data[index].(time.Time)

case *sql.NullBool:
v, _ := data[index].(bool)
result := sql.NullBool{Bool: v, Valid: true}
p := d.(*sql.NullBool)
*p = result

case *sql.NullTime:
v, _ := data[index].(time.Time)
result := sql.NullTime{Time: v, Valid: true}
p := d.(*sql.NullTime)
*p = result

case *sql.NullInt16:
v, _ := data[index].(int16)
result := sql.NullInt16{Int16: v, Valid: true}
p := d.(*sql.NullInt16)
*p = result

case *sql.NullInt64:
v, _ := data[index].(int64)
result := sql.NullInt64{Int64: v, Valid: true}
p := d.(*sql.NullInt64)
*p = result

case *sql.NullInt32:
v, _ := data[index].(int32)
result := sql.NullInt32{Int32: v, Valid: true}
p := d.(*sql.NullInt32)
*p = result

case *sql.NullString:
v, _ := data[index].(string)
result := sql.NullString{String: v, Valid: true}
p := d.(*sql.NullString)
*p = result

case *sql.NullFloat64:
v, _ := data[index].(float64)
result := sql.NullFloat64{Float64: v, Valid: true}
p := d.(*sql.NullFloat64)
*p = result

default:
fmt.Printf("undefined type '%T' for value '%v', skipping.", t, data[index])
}
}

return nil
},
ValuesFunc: func() ([]interface{}, error) {
return rows[rowIndex], nil
},
}, nil
}
}

0 comments on commit a51452f

Please sign in to comment.