-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtest_shift.go
169 lines (155 loc) · 4 KB
/
test_shift.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
package shift
import (
"context"
"database/sql"
"encoding/hex"
"fmt"
"math/rand"
"reflect"
"testing"
"time"
"github.com/luno/jettison/errors"
)
// TODO: Implement TestArcFSM

// TestFSM drives fsm through every simple path of its state graph that starts
// at the insert status, using requests whose fields are filled with random
// values. For each path it calls fsm.Insert once with a random insert request
// for the first status, then calls fsm.Update for every following transition
// with a random update request bound to the id returned by the insert.
//
// It returns the first Insert/Update/request-building error (wrapped with the
// path it occurred on), or an error if some key of fsm.states was never
// visited by any path. The testing.TB argument is accepted but not used.
func TestFSM(_ testing.TB, dbc *sql.DB, fsm *FSM) error {
	if fsm.insertStatus == nil {
		return errors.New("fsm without insert status not supported")
	}

	ctx := context.Background()
	reached := make(map[int]bool)
	reached[fsm.insertStatus.ShiftStatus()] = true

	for idx, path := range buildPaths(fsm.states, fsm.insertStatus) {
		first, last := path[0], path[len(path)-1]
		label := "error in path " + fmt.Sprintf("%d_from_%d_to_%d_len_%d", idx, first.st, last.st, len(path))

		ins, err := randomInsert(first.req)
		if err != nil {
			return errors.Wrap(err, label)
		}
		id, err := fsm.Insert(ctx, dbc, ins)
		if err != nil {
			return errors.Wrap(err, label)
		}

		// Walk the remaining hops, each one transitioning from the previous status.
		prev := first.st
		for _, step := range path[1:] {
			upd, err := randomUpdate(step.req, id)
			if err != nil {
				return errors.Wrap(err, label)
			}
			if err := fsm.Update(ctx, dbc, prev, step.st, upd); err != nil {
				return errors.Wrap(err, label)
			}
			prev = step.st
			reached[step.st.ShiftStatus()] = true
		}
	}

	for code := range fsm.states {
		if !reached[code] {
			return errors.New("status not reachable")
		}
	}
	return nil
}
// randomUpdate builds a fresh value of req's concrete type and fills every
// settable field with randVal, except the field named "ID", which is set to id.
// req itself is only used for its type; it is never modified.
//
// It returns an error, instead of panicking inside reflect, when req does not
// implement Updater[int64] or when its concrete type is not a struct (for
// example a pointer). Unexported fields cannot be set through reflection and
// are left at their zero value.
func randomUpdate(req any, id int64) (Updater[int64], error) {
	if _, ok := req.(Updater[int64]); !ok {
		return nil, errors.New("req not of type Updater")
	}

	typ := reflect.TypeOf(req)
	if typ.Kind() != reflect.Struct {
		return nil, errors.New("req not a struct")
	}

	s := reflect.New(typ).Elem()
	for i := 0; i < s.NumField(); i++ {
		f := s.Field(i)
		if !f.CanSet() {
			continue // unexported field: reflect cannot set it
		}
		if typ.Field(i).Name == "ID" {
			f.SetInt(id)
			continue
		}
		f.Set(randVal(f.Type()))
	}
	return s.Interface().(Updater[int64]), nil
}
// randomInsert builds a fresh value of req's concrete type and fills every
// settable field with randVal. req itself is only used for its type; it is
// never modified.
//
// It returns an error, instead of panicking inside reflect, when req does not
// implement Inserter[int64] or when its concrete type is not a struct (for
// example a pointer). Unexported fields cannot be set through reflection and
// are left at their zero value.
func randomInsert(req any) (Inserter[int64], error) {
	if _, ok := req.(Inserter[int64]); !ok {
		return nil, errors.New("req not of type Inserter")
	}

	typ := reflect.TypeOf(req)
	if typ.Kind() != reflect.Struct {
		return nil, errors.New("req not a struct")
	}

	s := reflect.New(typ).Elem()
	for i := 0; i < s.NumField(); i++ {
		if f := s.Field(i); f.CanSet() {
			f.Set(randVal(f.Type()))
		}
	}
	return s.Interface().(Inserter[int64]), nil
}
// buildPaths returns every simple path through the state graph that starts at
// from, each path being the statuses visited in order.
//
// A status is emitted as the end of a path when it has no outgoing transitions,
// or when one of its transitions leads to a status already on the current path
// or to a status that is not a key of states. This is how cycles are broken.
//
// states is only read, never modified. Visited statuses are tracked in a
// separate set rather than by temporarily deleting entries from states.
func buildPaths(states map[int]status, from Status) [][]status {
	return buildPathsVisiting(states, from, make(map[int]bool))
}

// buildPathsVisiting is the recursive worker for buildPaths; onPath holds the
// ShiftStatus keys of the statuses already on the current path.
func buildPathsVisiting(states map[int]status, from Status, onPath map[int]bool) [][]status {
	key := from.ShiftStatus()
	here := states[key]
	onPath[key] = true
	defer delete(onPath, key) // backtrack so sibling branches may revisit this status

	hasEnd := len(here.next) == 0
	var res [][]status
	for next := range here.next {
		nextKey := next.ShiftStatus()
		if _, ok := states[nextKey]; !ok || onPath[nextKey] {
			hasEnd = true // cycle or unknown status: stop the path here
			continue
		}
		for _, path := range buildPathsVisiting(states, next, onPath) {
			res = append(res, append([]status{here}, path...))
		}
	}
	if hasEnd {
		res = append(res, []status{here})
	}
	return res
}
// Reflect types that randVal knows how to populate with random data. randVal
// compares types by identity, so named types derived from these (for example
// `type Amount float64`) do not match and receive their zero value instead.
var (
	boolType       = reflect.TypeOf(true)
	intType        = reflect.TypeOf(int(0))
	int64Type      = reflect.TypeOf(int64(0))
	float64Type    = reflect.TypeOf(float64(0))
	stringType     = reflect.TypeOf("")
	sliceByteType  = reflect.TypeOf([]byte{})
	timeType       = reflect.TypeOf(time.Time{})
	nullStringType = reflect.TypeOf(sql.NullString{})
	nullTimeType   = reflect.TypeOf(sql.NullTime{})
)

// randVal returns a pseudo-random value of type t:
//   - int, int64: an integer in [0, 1000)
//   - float64: a float in [0, 1000)
//   - time.Time: the current time minus a random whole number of hours in [0, 1000)
//   - []byte: between 0 and 63 random bytes
//   - bool: true or false with equal probability
//   - string: 10 to 18 lowercase hex characters
//   - sql.NullTime: the current time, Valid with probability 1/2
//   - sql.NullString: a random hex string, Valid with probability 1/2
//
// Any other type yields an addressable zero value of t.
func randVal(t reflect.Type) reflect.Value {
	switch t {
	case intType:
		return reflect.ValueOf(rand.Intn(1000))
	case int64Type:
		return reflect.ValueOf(int64(rand.Intn(1000)))
	case float64Type:
		return reflect.ValueOf(rand.Float64() * 1000)
	case timeType:
		back := time.Duration(rand.Intn(1000)) * time.Hour
		return reflect.ValueOf(time.Now().Add(-back))
	case sliceByteType:
		return reflect.ValueOf(randBytes(rand.Intn(64)))
	case boolType:
		return reflect.ValueOf(rand.Float64() < 0.5)
	case stringType:
		return reflect.ValueOf(hex.EncodeToString(randBytes(rand.Intn(5) + 5)))
	case nullTimeType:
		valid := rand.Float64() < 0.5
		return reflect.ValueOf(sql.NullTime{Time: time.Now(), Valid: valid})
	case nullStringType:
		valid := rand.Float64() < 0.5
		str := hex.EncodeToString(randBytes(rand.Intn(5) + 5))
		return reflect.ValueOf(sql.NullString{String: str, Valid: valid})
	}
	return reflect.New(t).Elem()
}

// randBytes returns size pseudo-random bytes from math/rand. It is used here
// only to generate random test data and is not suitable for secrets.
func randBytes(size int) []byte {
	out := make([]byte, size)
	// math/rand.Read always fills the whole slice and returns a nil error.
	_, _ = rand.Read(out)
	return out
}