Skip to content

Commit

Permalink
refactor(cardinal): register query with generics (#413)
Browse files Browse the repository at this point in the history
Closes: #XXX

## What is the purpose of the change

Making register query more consistent with register component

## Brief Changelog

- Refactor RegisterQuery to be similar to RegisterComponent

## Testing and Verifying

- All relevant tests have been adjusted accordingly
  • Loading branch information
smsunarto authored Nov 13, 2023
1 parent f091e73 commit 27f1dc7
Show file tree
Hide file tree
Showing 15 changed files with 320 additions and 289 deletions.
79 changes: 49 additions & 30 deletions cardinal/ecs/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"pkg.world.dev/world-engine/cardinal/ecs/abi"
)

type IQuery interface {
type Query interface {
// Name returns the name of the query.
Name() string
// HandleQuery handles queries with concrete types, rather than encoded bytes.
Expand All @@ -35,7 +35,7 @@ type IQuery interface {

type QueryType[Request any, Reply any] struct {
name string
handler func(wCtx WorldContext, req Request) (Reply, error)
handler func(wCtx WorldContext, req *Request) (*Reply, error)
requestABI *ethereumAbi.Type
replyABI *ethereumAbi.Type
}
Expand All @@ -49,45 +49,27 @@ func WithQueryEVMSupport[Request, Reply any]() func(transactionType *QueryType[R
}
}

var _ IQuery = &QueryType[struct{}, struct{}]{}
var _ Query = &QueryType[struct{}, struct{}]{}

func NewQueryType[Request any, Reply any](
name string,
handler func(wCtx WorldContext, req Request) (Reply, error),
handler func(wCtx WorldContext, req *Request) (*Reply, error),
opts ...func() func(queryType *QueryType[Request, Reply]),
) *QueryType[Request, Reply] {
if name == "" {
panic("cannot create query without name")
}
if handler == nil {
panic("cannot create query without handler")
}
var req Request
var rep Reply
reqType := reflect.TypeOf(req)
reqKind := reqType.Kind()
reqValid := false
if (reqKind == reflect.Pointer && reqType.Elem().Kind() == reflect.Struct) || reqKind == reflect.Struct {
reqValid = true
}
repType := reflect.TypeOf(rep)
repKind := reqType.Kind()
repValid := false
if (repKind == reflect.Pointer && repType.Elem().Kind() == reflect.Struct) || repKind == reflect.Struct {
repValid = true
) (Query, error) {
err := validateQuery[Request, Reply](name, handler)
if err != nil {
return nil, err
}

if !repValid || !reqValid {
panic(fmt.Sprintf("Invalid QueryType: %s: The Request and Reply must be both structs", name))
}
r := &QueryType[Request, Reply]{
name: name,
handler: handler,
}
for _, opt := range opts {
opt()(r)
}
return r

return r, nil
}

func (r *QueryType[Request, Reply]) IsEVMCompatible() bool {
Expand Down Expand Up @@ -123,7 +105,7 @@ func (r *QueryType[req, rep]) HandleQuery(wCtx WorldContext, a any) (any, error)
if !ok {
return nil, fmt.Errorf("cannot cast %T to this query request type %T", a, new(req))
}
reply, err := r.handler(wCtx, request)
reply, err := r.handler(wCtx, &request)
return reply, err
}

Expand All @@ -133,7 +115,7 @@ func (r *QueryType[req, rep]) HandleQueryRaw(wCtx WorldContext, bz []byte) ([]by
if err != nil {
return nil, fmt.Errorf("unable to unmarshal query request into type %T: %w", *request, err)
}
res, err := r.handler(wCtx, *request)
res, err := r.handler(wCtx, request)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -217,3 +199,40 @@ func (r *QueryType[Request, Reply]) EncodeAsABI(input any) ([]byte, error) {
}
return bz, nil
}

func validateQuery[Request any, Reply any](
name string,
handler func(wCtx WorldContext, req *Request) (*Reply, error),
) error {
if name == "" {
return errors.New("cannot create query without name")
}
if handler == nil {
return errors.New("cannot create query without handler")
}

var req Request
var rep Reply
reqType := reflect.TypeOf(req)
reqKind := reqType.Kind()
reqValid := false
if (reqKind == reflect.Pointer && reqType.Elem().Kind() == reflect.Struct) ||
reqKind == reflect.Struct {
reqValid = true
}
repType := reflect.TypeOf(rep)
repKind := reqType.Kind()
repValid := false
if (repKind == reflect.Pointer && repType.Elem().Kind() == reflect.Struct) ||
repKind == reflect.Struct {
repValid = true
}

if !repValid || !reqValid {
return fmt.Errorf(
"invalid query: %s: the Request and Reply generics must be both structs",
name,
)
}
return nil
}
92 changes: 39 additions & 53 deletions cardinal/ecs/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,50 +2,28 @@ package ecs_test

import (
"context"
"github.com/stretchr/testify/require"
"pkg.world.dev/world-engine/cardinal/testutils"
"testing"

"pkg.world.dev/world-engine/cardinal/ecs"
"pkg.world.dev/world-engine/cardinal/testutils"

"pkg.world.dev/world-engine/cardinal/evm"

"gotest.tools/v3/assert"

routerv1 "pkg.world.dev/world-engine/rift/router/v1"

"pkg.world.dev/world-engine/cardinal/ecs"
)

func TestQueryTypeNotStructs(t *testing.T) {
type FooRequest struct {
ID string
}
type FooReply struct {
Name string
Age uint64
}

expectedReply := FooReply{
Name: "Chad",
Age: 22,
}

defer func() {
// test should trigger a panic.
panicValue := recover()
assert.Assert(t, panicValue != nil)
ecs.NewQueryType[FooRequest, FooReply]("foo", func(wCtx ecs.WorldContext, req FooRequest) (FooReply, error) {
return expectedReply, nil
})
defer func() {
// deferred function should not fail
panicValue = recover()
assert.Assert(t, panicValue == nil)
}()
}()

ecs.NewQueryType[string, string]("foo", func(wCtx ecs.WorldContext, req string) (string, error) {
return "blah", nil
})
str := "blah"
err := ecs.RegisterQuery[string, string](
testutils.NewTestWorld(t).Instance(),
"foo",
func(wCtx ecs.WorldContext, req *string) (*string, error) {
return &str, nil
},
)
assert.ErrorContains(t, err, "the Request and Reply generics must be both structs")
}

func TestQueryEVM(t *testing.T) {
Expand All @@ -62,20 +40,27 @@ func TestQueryEVM(t *testing.T) {
Name: "Chad",
Age: 22,
}
fooQuery := ecs.NewQueryType[FooRequest, FooReply]("foo", func(wCtx ecs.WorldContext, req FooRequest,
) (FooReply, error) {
return expectedReply, nil
}, ecs.WithQueryEVMSupport[FooRequest, FooReply])

w := testutils.NewTestWorld(t).Instance()
err := w.RegisterQueries(fooQuery)
err := ecs.RegisterQuery[FooRequest, FooReply](
w,
"foo",
func(wCtx ecs.WorldContext, req *FooRequest,
) (*FooReply, error) {
return &expectedReply, nil
},
ecs.WithQueryEVMSupport[FooRequest, FooReply],
)

assert.NilError(t, err)
err = w.RegisterMessages(ecs.NewMessageType[struct{}, struct{}]("blah"))
assert.NilError(t, err)
s, err := evm.NewServer(w)
assert.NilError(t, err)

// create the abi encoded bytes that the EVM would send.
fooQuery, err := w.GetQueryByName("foo")
assert.NilError(t, err)
bz, err := fooQuery.EncodeAsABI(FooRequest{ID: "foo"})
assert.NilError(t, err)

Expand All @@ -97,35 +82,36 @@ func TestQueryEVM(t *testing.T) {
assert.Equal(t, reply, expectedReply)
}

func TestPanicsOnNoNameOrHandler(t *testing.T) {
func TestErrOnNoNameOrHandler(t *testing.T) {
type foo struct{}
testCases := []struct {
name string
createQuery func()
shouldPanic bool
createQuery func() error
shouldErr bool
}{
{
name: "panic on no name",
createQuery: func() {
ecs.NewQueryType[foo, foo]("", nil)
name: "error on no name",
createQuery: func() error {
return ecs.RegisterQuery[foo, foo](testutils.NewTestWorld(t).Instance(), "", nil)
},
shouldPanic: true,
shouldErr: true,
},
{
name: "panic on no handler",
createQuery: func() {
ecs.NewQueryType[foo, foo]("foo", nil)
name: "error on no handler",
createQuery: func() error {
return ecs.RegisterQuery[foo, foo](testutils.NewTestWorld(t).Instance(), "foo", nil)
},
shouldPanic: true,
shouldErr: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.shouldPanic {
require.Panics(t, tc.createQuery)
if tc.shouldErr {
err := tc.createQuery()
assert.Assert(t, err != nil)
} else {
require.NotPanics(t, tc.createQuery)
assert.NilError(t, tc.createQuery())
}
})
}
Expand Down
48 changes: 34 additions & 14 deletions cardinal/ecs/world.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ type World struct {
systemNames []string
tick uint64
nameToComponent map[string]metadata.ComponentMetadata
nameToQuery map[string]Query
registeredComponents []metadata.ComponentMetadata
registeredMessages []message.Message
registeredQueries []IQuery
registeredQueries []Query
isComponentsRegistered bool
isMessagesRegistered bool
stateIsLoaded bool
Expand Down Expand Up @@ -208,22 +209,38 @@ func (w *World) GetComponentByName(name string) (metadata.ComponentMetadata, err
return componentType, nil
}

func (w *World) RegisterQueries(queries ...IQuery) error {
if w.stateIsLoaded {
func RegisterQuery[Request any, Reply any](
world *World,
name string,
handler func(wCtx WorldContext, req *Request) (*Reply, error),
opts ...func() func(queryType *QueryType[Request, Reply]),
) error {
if world.stateIsLoaded {
panic("cannot register queries after loading game state")
}
w.registeredQueries = append(w.registeredQueries, queries...)
seenQueryNames := map[string]struct{}{}
for _, t := range w.registeredQueries {
name := t.Name()
if _, ok := seenQueryNames[name]; ok {
return fmt.Errorf("duplicate query %q: %w", name, ErrDuplicateQueryName)
}
seenQueryNames[name] = struct{}{}

if _, ok := world.nameToQuery[name]; ok {
return fmt.Errorf("query with name %s is already registered", name)
}

q, err := NewQueryType[Request, Reply](name, handler, opts...)
if err != nil {
return err
}

world.registeredQueries = append(world.registeredQueries, q)
world.nameToQuery[q.Name()] = q

return nil
}

func (w *World) GetQueryByName(name string) (Query, error) {
if q, ok := w.nameToQuery[name]; ok {
return q, nil
}
return nil, fmt.Errorf("query with name %s not found", name)
}

func (w *World) RegisterMessages(txs ...message.Message) error {
if w.stateIsLoaded {
panic("cannot register messages after loading game state")
Expand Down Expand Up @@ -258,7 +275,7 @@ func (w *World) registerInternalMessages() {
)
}

func (w *World) ListQueries() []IQuery {
func (w *World) ListQueries() []Query {
return w.registeredQueries
}

Expand Down Expand Up @@ -288,6 +305,7 @@ func NewWorld(
systems: make([]System, 0),
initSystem: func(_ WorldContext) error { return nil },
nameToComponent: make(map[string]metadata.ComponentMetadata),
nameToQuery: make(map[string]Query),
txQueue: message.NewTxQueue(),
Logger: logger,
isGameLoopRunning: atomic.Bool{},
Expand Down Expand Up @@ -630,8 +648,10 @@ func (w *World) RecoverFromChain(ctx context.Context) error {
"be sure to use the `WithAdapter` option when creating the world")
}
if w.CurrentTick() > 0 {
return fmt.Errorf("world recovery should not occur in a world with existing state. please verify all " +
"state has been cleared before running recovery")
return fmt.Errorf(
"world recovery should not occur in a world with existing state. please verify all " +
"state has been cleared before running recovery",
)
}

w.isRecovering = true
Expand Down
Loading

0 comments on commit 27f1dc7

Please sign in to comment.