Skip to content

Commit

Permalink
feat(cardinal): allow cors on dev mode (#403)
Browse files Browse the repository at this point in the history
  • Loading branch information
pyrofolium authored Nov 8, 2023
1 parent f92d50d commit d63eaaa
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 10 deletions.
16 changes: 13 additions & 3 deletions cardinal/cardinal_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cardinal_test

import (
"net/http"
"os"
"os/exec"
"strconv"
Expand Down Expand Up @@ -28,9 +29,8 @@ func TestNewWorld(t *testing.T) {
func TestCanQueryInsideSystem(t *testing.T) {
testutils.SetTestTimeout(t, 10*time.Second)

world, doTick := testutils.MakeWorldAndTicker(t)
world, doTick := testutils.MakeWorldAndTicker(t, cardinal.WithCORS())
assert.NilError(t, cardinal.RegisterComponent[Foo](world))

wantNumOfEntities := 10
world.Init(func(worldCtx cardinal.WorldContext) {
_, err := cardinal.CreateMany(worldCtx, wantNumOfEntities, Foo{})
Expand Down Expand Up @@ -59,7 +59,7 @@ func TestShutdownViaSignal(t *testing.T) {
// If this test is frozen then it failed to shut down, create a failure with panic.
var wg sync.WaitGroup
testutils.SetTestTimeout(t, 10*time.Second)
world, err := cardinal.NewMockWorld()
world, err := cardinal.NewMockWorld(cardinal.WithCORS())
assert.NilError(t, cardinal.RegisterComponent[Foo](world))
assert.NilError(t, err)
wantNumOfEntities := 10
Expand All @@ -77,6 +77,16 @@ func TestShutdownViaSignal(t *testing.T) {
// wait until game loop is running
time.Sleep(500 * time.Millisecond)
}
// test CORS with cardinal
client := &http.Client{}
req, err := http.NewRequest(http.MethodPost, "http://localhost:4040/query/http/endpoints", nil)
assert.NilError(t, err)
req.Header.Set("Origin", "http://www.bullshit.com") // test CORS
resp, err := client.Do(req)
assert.NilError(t, err)
v := resp.Header.Get("Access-Control-Allow-Origin")
assert.Equal(t, v, "*")
assert.Equal(t, resp.StatusCode, 200)

conn, _, err := websocket.DefaultDialer.Dial("ws://localhost:4040/events", nil)
assert.NilError(t, err)
Expand Down
1 change: 1 addition & 0 deletions cardinal/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ require (
github.com/invopop/jsonschema v0.7.0
github.com/mitchellh/mapstructure v1.5.0
github.com/redis/go-redis/v9 v9.0.2
github.com/rs/cors v1.10.1
github.com/rs/zerolog v1.30.0
github.com/stretchr/testify v1.8.4
google.golang.org/grpc v1.58.3
Expand Down
6 changes: 6 additions & 0 deletions cardinal/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,9 @@ func WithPrettyLog() WorldOption {
ecsOption: ecs.WithPrettyLog(),
}
}

func WithCORS() WorldOption {
return WorldOption{
serverOption: server.WithCORS(),
}
}
6 changes: 6 additions & 0 deletions cardinal/server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,9 @@ func WithAdapter(a shard.Adapter) Option {
th.adapter = a
}
}

func WithCORS() Option {
return func(th *Handler) {
th.withCORS = true
}
}
14 changes: 10 additions & 4 deletions cardinal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/go-openapi/runtime/middleware"
"github.com/go-openapi/runtime/middleware/untyped"
"github.com/mitchellh/mapstructure"
"github.com/rs/cors"
"github.com/rs/zerolog/log"
"pkg.world.dev/world-engine/cardinal/ecs"
"pkg.world.dev/world-engine/cardinal/shard"
Expand All @@ -27,6 +28,7 @@ type Handler struct {
server *http.Server
disableSigVerification bool
Port string
withCORS bool

// plugins
adapter shard.WriteAdapter
Expand Down Expand Up @@ -61,8 +63,9 @@ var swaggerData []byte

func newSwaggerHandlerEmbed(w *ecs.World, builder middleware.Builder, opts ...Option) (*Handler, error) {
th := &Handler{
w: w,
Mux: http.NewServeMux(),
w: w,
Mux: http.NewServeMux(),
withCORS: false,
}
for _, opt := range opts {
opt(th)
Expand Down Expand Up @@ -95,8 +98,11 @@ func newSwaggerHandlerEmbed(w *ecs.World, builder middleware.Builder, opts ...Op
}

app := middleware.NewContext(specDoc, api, nil)

th.Mux.Handle("/", app.APIHandler(builder))
var handler = app.APIHandler(builder)
if th.withCORS {
handler = cors.AllowAll().Handler(handler)
}
th.Mux.Handle("/", handler)
th.Initialize()

return th, nil
Expand Down
11 changes: 8 additions & 3 deletions cardinal/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,15 @@ func TestCanListTransactionEndpoints(t *testing.T) {
betaTx := ecs.NewTransactionType[SendEnergyTx, SendEnergyTxResult]("beta")
gammaTx := ecs.NewTransactionType[SendEnergyTx, SendEnergyTxResult]("gamma")
assert.NilError(t, w.RegisterTransactions(alphaTx, betaTx, gammaTx))
txh := testutils.MakeTestTransactionHandler(t, w, server.DisableSignatureVerification())

resp, err := http.Post(txh.MakeHTTPURL("query/http/endpoints"), "application/json", nil)
txh := testutils.MakeTestTransactionHandler(t, w, server.DisableSignatureVerification(), server.WithCORS())
client := &http.Client{}
req, err := http.NewRequest(http.MethodPost, txh.MakeHTTPURL("query/http/endpoints"), nil)
assert.NilError(t, err)
req.Header.Set("Origin", "http://www.bullshit.com") // test CORS
resp, err := client.Do(req)
assert.NilError(t, err)
v := resp.Header.Get("Access-Control-Allow-Origin")
assert.Equal(t, v, "*")
assert.Equal(t, resp.StatusCode, 200)
var gotEndpoints map[string][]string
assert.NilError(t, json.NewDecoder(resp.Body).Decode(&gotEndpoints))
Expand Down

0 comments on commit d63eaaa

Please sign in to comment.