Skip to content

Commit

Permalink
refactor: register query
Browse files Browse the repository at this point in the history
  • Loading branch information
smsunarto committed Nov 10, 2023
1 parent 4fc6230 commit c6d27dd
Show file tree
Hide file tree
Showing 14 changed files with 248 additions and 274 deletions.
99 changes: 54 additions & 45 deletions cardinal/ecs/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,68 +33,48 @@ type IQuery interface {
IsEVMCompatible() bool
}

type QueryType[Request any, Reply any] struct {
type Query[Request any, Reply any] struct {
name string
handler func(wCtx WorldContext, req Request) (Reply, error)
handler func(wCtx WorldContext, req *Request) (*Reply, error)
requestABI *ethereumAbi.Type
replyABI *ethereumAbi.Type
}

func WithQueryEVMSupport[Request, Reply any]() func(transactionType *QueryType[Request, Reply]) {
return func(query *QueryType[Request, Reply]) {
func WithQueryEVMSupport[Request, Reply any]() func(transactionType *Query[Request, Reply]) {
return func(query *Query[Request, Reply]) {
err := query.generateABIBindings()
if err != nil {
panic(err)
}
}
}

var _ IQuery = &QueryType[struct{}, struct{}]{}

func NewQueryType[Request any, Reply any](
name string,
handler func(wCtx WorldContext, req Request) (Reply, error),
opts ...func() func(queryType *QueryType[Request, Reply]),
) *QueryType[Request, Reply] {
if name == "" {
panic("cannot create query without name")
}
if handler == nil {
panic("cannot create query without handler")
}
var req Request
var rep Reply
reqType := reflect.TypeOf(req)
reqKind := reqType.Kind()
reqValid := false
if (reqKind == reflect.Pointer && reqType.Elem().Kind() == reflect.Struct) || reqKind == reflect.Struct {
reqValid = true
}
repType := reflect.TypeOf(rep)
repKind := reqType.Kind()
repValid := false
if (repKind == reflect.Pointer && repType.Elem().Kind() == reflect.Struct) || repKind == reflect.Struct {
repValid = true
handler func(wCtx WorldContext, req *Request) (*Reply, error),
opts ...func() func(queryType *Query[Request, Reply]),
) (IQuery, error) {
err := validateQuery[Request, Reply](name, handler)
if err != nil {
return nil, err
}

if !repValid || !reqValid {
panic(fmt.Sprintf("Invalid QueryType: %s: The Request and Reply must be both structs", name))
}
r := &QueryType[Request, Reply]{
r := &Query[Request, Reply]{
name: name,
handler: handler,
}
for _, opt := range opts {
opt()(r)
}
return r

return r, nil
}

func (r *QueryType[Request, Reply]) IsEVMCompatible() bool {
func (r *Query[Request, Reply]) IsEVMCompatible() bool {
return r.requestABI != nil && r.replyABI != nil
}

func (r *QueryType[Request, Reply]) generateABIBindings() error {
func (r *Query[Request, Reply]) generateABIBindings() error {
var req Request
reqABI, err := abi.GenerateABIType(req)
if err != nil {
Expand All @@ -110,30 +90,30 @@ func (r *QueryType[Request, Reply]) generateABIBindings() error {
return nil
}

func (r *QueryType[req, rep]) Name() string {
func (r *Query[req, rep]) Name() string {
return r.name
}

func (r *QueryType[req, rep]) Schema() (request, reply *jsonschema.Schema) {
func (r *Query[req, rep]) Schema() (request, reply *jsonschema.Schema) {
return jsonschema.Reflect(new(req)), jsonschema.Reflect(new(rep))
}

func (r *QueryType[req, rep]) HandleQuery(wCtx WorldContext, a any) (any, error) {
func (r *Query[req, rep]) HandleQuery(wCtx WorldContext, a any) (any, error) {
request, ok := a.(req)
if !ok {
return nil, fmt.Errorf("cannot cast %T to this query request type %T", a, new(req))
}
reply, err := r.handler(wCtx, request)
reply, err := r.handler(wCtx, &request)
return reply, err
}

func (r *QueryType[req, rep]) HandleQueryRaw(wCtx WorldContext, bz []byte) ([]byte, error) {
func (r *Query[req, rep]) HandleQueryRaw(wCtx WorldContext, bz []byte) ([]byte, error) {
request := new(req)
err := json.Unmarshal(bz, request)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal query request into type %T: %w", *request, err)
}
res, err := r.handler(wCtx, *request)
res, err := r.handler(wCtx, request)
if err != nil {
return nil, err
}
Expand All @@ -144,7 +124,7 @@ func (r *QueryType[req, rep]) HandleQueryRaw(wCtx WorldContext, bz []byte) ([]by
return bz, nil
}

func (r *QueryType[req, rep]) DecodeEVMRequest(bz []byte) (any, error) {
func (r *Query[req, rep]) DecodeEVMRequest(bz []byte) (any, error) {
if r.requestABI == nil {
return nil, ErrEVMTypeNotSet
}
Expand All @@ -163,7 +143,7 @@ func (r *QueryType[req, rep]) DecodeEVMRequest(bz []byte) (any, error) {
return request, nil
}

func (r *QueryType[req, rep]) DecodeEVMReply(bz []byte) (any, error) {
func (r *Query[req, rep]) DecodeEVMReply(bz []byte) (any, error) {
if r.replyABI == nil {
return nil, ErrEVMTypeNotSet
}
Expand All @@ -182,7 +162,7 @@ func (r *QueryType[req, rep]) DecodeEVMReply(bz []byte) (any, error) {
return reply, nil
}

func (r *QueryType[req, rep]) EncodeEVMReply(a any) ([]byte, error) {
func (r *Query[req, rep]) EncodeEVMReply(a any) ([]byte, error) {
if r.replyABI == nil {
return nil, ErrEVMTypeNotSet
}
Expand All @@ -191,7 +171,7 @@ func (r *QueryType[req, rep]) EncodeEVMReply(a any) ([]byte, error) {
return bz, err
}

func (r *QueryType[Request, Reply]) EncodeAsABI(input any) ([]byte, error) {
func (r *Query[Request, Reply]) EncodeAsABI(input any) ([]byte, error) {
if r.requestABI == nil || r.replyABI == nil {
return nil, ErrEVMTypeNotSet
}
Expand All @@ -217,3 +197,32 @@ func (r *QueryType[Request, Reply]) EncodeAsABI(input any) ([]byte, error) {
}
return bz, nil
}

func validateQuery[Request any, Reply any](name string, handler func(wCtx WorldContext, req *Request) (*Reply, error)) error {
if name == "" {
return errors.New("cannot create query without name")
}
if handler == nil {
return errors.New("cannot create query without handler")
}

var req Request
var rep Reply
reqType := reflect.TypeOf(req)
reqKind := reqType.Kind()
reqValid := false
if (reqKind == reflect.Pointer && reqType.Elem().Kind() == reflect.Struct) || reqKind == reflect.Struct {
reqValid = true
}
repType := reflect.TypeOf(rep)
repKind := reqType.Kind()
repValid := false
if (repKind == reflect.Pointer && repType.Elem().Kind() == reflect.Struct) || repKind == reflect.Struct {
repValid = true
}

if !repValid || !reqValid {
return errors.New(fmt.Sprintf("Invalid Query: %s: The Request and Reply must be both structs", name))
}
return nil
}
72 changes: 24 additions & 48 deletions cardinal/ecs/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package ecs_test

import (
"context"
"github.com/stretchr/testify/require"
"pkg.world.dev/world-engine/cardinal/ecs"
"pkg.world.dev/world-engine/cardinal/testutils"
"testing"

Expand All @@ -11,41 +11,14 @@ import (
"gotest.tools/v3/assert"

routerv1 "pkg.world.dev/world-engine/rift/router/v1"

"pkg.world.dev/world-engine/cardinal/ecs"
)

func TestQueryTypeNotStructs(t *testing.T) {
type FooRequest struct {
ID string
}
type FooReply struct {
Name string
Age uint64
}

expectedReply := FooReply{
Name: "Chad",
Age: 22,
}

defer func() {
// test should trigger a panic.
panicValue := recover()
assert.Assert(t, panicValue != nil)
ecs.NewQueryType[FooRequest, FooReply]("foo", func(wCtx ecs.WorldContext, req FooRequest) (FooReply, error) {
return expectedReply, nil
})
defer func() {
// deferred function should not fail
panicValue = recover()
assert.Assert(t, panicValue == nil)
}()
}()

ecs.NewQueryType[string, string]("foo", func(wCtx ecs.WorldContext, req string) (string, error) {
return "blah", nil
str := "blah"
err := ecs.RegisterQuery[string, string](testutils.NewTestWorld(t).Instance(), "foo", func(wCtx ecs.WorldContext, req *string) (*string, error) {
return &str, nil
})
assert.Assert(t, err != nil)
}

func TestQueryEVM(t *testing.T) {
Expand All @@ -62,20 +35,22 @@ func TestQueryEVM(t *testing.T) {
Name: "Chad",
Age: 22,
}
fooQuery := ecs.NewQueryType[FooRequest, FooReply]("foo", func(wCtx ecs.WorldContext, req FooRequest,
) (FooReply, error) {
return expectedReply, nil
}, ecs.WithQueryEVMSupport[FooRequest, FooReply])

w := testutils.NewTestWorld(t).Instance()
err := w.RegisterQueries(fooQuery)
err := ecs.RegisterQuery[FooRequest, FooReply](w, "foo", func(wCtx ecs.WorldContext, req *FooRequest,
) (*FooReply, error) {
return &expectedReply, nil
}, ecs.WithQueryEVMSupport[FooRequest, FooReply])

assert.NilError(t, err)
err = w.RegisterMessages(ecs.NewMessageType[struct{}, struct{}]("blah"))
assert.NilError(t, err)
s, err := evm.NewServer(w)
assert.NilError(t, err)

// create the abi encoded bytes that the EVM would send.
fooQuery, err := w.GetQueryByName("foo")
assert.NilError(t, err)
bz, err := fooQuery.EncodeAsABI(FooRequest{ID: "foo"})
assert.NilError(t, err)

Expand All @@ -101,31 +76,32 @@ func TestPanicsOnNoNameOrHandler(t *testing.T) {
type foo struct{}
testCases := []struct {
name string
createQuery func()
shouldPanic bool
createQuery func() error
shouldErr bool
}{
{
name: "panic on no name",
createQuery: func() {
ecs.NewQueryType[foo, foo]("", nil)
createQuery: func() error {
return ecs.RegisterQuery[foo, foo](testutils.NewTestWorld(t).Instance(), "", nil)
},
shouldPanic: true,
shouldErr: true,
},
{
name: "panic on no handler",
createQuery: func() {
ecs.NewQueryType[foo, foo]("foo", nil)
createQuery: func() error {
return ecs.RegisterQuery[foo, foo](testutils.NewTestWorld(t).Instance(), "foo", nil)
},
shouldPanic: true,
shouldErr: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.shouldPanic {
require.Panics(t, tc.createQuery)
if tc.shouldErr {
err := tc.createQuery()
assert.Assert(t, err != nil)
} else {
require.NotPanics(t, tc.createQuery)
assert.NilError(t, tc.createQuery())
}
})
}
Expand Down
39 changes: 29 additions & 10 deletions cardinal/ecs/world.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type World struct {
systemNames []string
tick uint64
nameToComponent map[string]metadata.ComponentMetadata
nameToQuery map[string]IQuery
registeredComponents []metadata.ComponentMetadata
registeredMessages []message.Message
registeredQueries []IQuery
Expand Down Expand Up @@ -192,22 +193,39 @@ func (w *World) GetComponentByName(name string) (metadata.ComponentMetadata, err
return componentType, nil
}

func (w *World) RegisterQueries(queries ...IQuery) error {
if w.stateIsLoaded {
func RegisterQuery[Request any, Reply any](
world *World,
name string,
handler func(wCtx WorldContext, req *Request) (*Reply, error),
opts ...func() func(queryType *Query[Request, Reply]),
) error {
if world.stateIsLoaded {
panic("cannot register queries after loading game state")
}
w.registeredQueries = append(w.registeredQueries, queries...)
seenQueryNames := map[string]struct{}{}
for _, t := range w.registeredQueries {
name := t.Name()
if _, ok := seenQueryNames[name]; ok {
return fmt.Errorf("duplicate query %q: %w", name, ErrDuplicateQueryName)
}
seenQueryNames[name] = struct{}{}

if _, ok := world.nameToQuery[name]; ok {
return fmt.Errorf("query with name %s is already registered", name)
}

q, err := NewQueryType[Request, Reply](name, handler, opts...)
if err != nil {
return err
}

world.registeredQueries = append(world.registeredQueries, q)
world.nameToQuery[q.Name()] = q

return nil
}

func (w *World) GetQueryByName(name string) (IQuery, error) {
if q, ok := w.nameToQuery[name]; ok {
return q, nil
} else {
return nil, fmt.Errorf("query with name %s not found", name)
}
}

func (w *World) RegisterMessages(txs ...message.Message) error {
if w.stateIsLoaded {
panic("cannot register messages after loading game state")
Expand Down Expand Up @@ -272,6 +290,7 @@ func NewWorld(
systems: make([]System, 0),
initSystem: func(_ WorldContext) error { return nil },
nameToComponent: make(map[string]metadata.ComponentMetadata),
nameToQuery: make(map[string]IQuery),
txQueue: message.NewTxQueue(),
Logger: logger,
isGameLoopRunning: atomic.Bool{},
Expand Down
Loading

0 comments on commit c6d27dd

Please sign in to comment.