diff --git a/cardinal/cardinal_test.go b/cardinal/cardinal_test.go index a4eab7501..355c1f8f0 100644 --- a/cardinal/cardinal_test.go +++ b/cardinal/cardinal_test.go @@ -1,6 +1,7 @@ package cardinal_test import ( + "net/http" "os" "os/exec" "strconv" @@ -28,9 +29,8 @@ func TestNewWorld(t *testing.T) { func TestCanQueryInsideSystem(t *testing.T) { testutils.SetTestTimeout(t, 10*time.Second) - world, doTick := testutils.MakeWorldAndTicker(t) + world, doTick := testutils.MakeWorldAndTicker(t, cardinal.WithCORS()) assert.NilError(t, cardinal.RegisterComponent[Foo](world)) - wantNumOfEntities := 10 world.Init(func(worldCtx cardinal.WorldContext) { _, err := cardinal.CreateMany(worldCtx, wantNumOfEntities, Foo{}) @@ -59,7 +59,7 @@ func TestShutdownViaSignal(t *testing.T) { // If this test is frozen then it failed to shut down, create a failure with panic. var wg sync.WaitGroup testutils.SetTestTimeout(t, 10*time.Second) - world, err := cardinal.NewMockWorld() + world, err := cardinal.NewMockWorld(cardinal.WithCORS()) assert.NilError(t, cardinal.RegisterComponent[Foo](world)) assert.NilError(t, err) wantNumOfEntities := 10 @@ -77,6 +77,16 @@ func TestShutdownViaSignal(t *testing.T) { // wait until game loop is running time.Sleep(500 * time.Millisecond) } + // test CORS with cardinal + client := &http.Client{} + req, err := http.NewRequest(http.MethodPost, "http://localhost:4040/query/http/endpoints", nil) + assert.NilError(t, err) + req.Header.Set("Origin", "http://www.bullshit.com") // test CORS + resp, err := client.Do(req) + assert.NilError(t, err) + v := resp.Header.Get("Access-Control-Allow-Origin") + assert.Equal(t, v, "*") + assert.Equal(t, resp.StatusCode, 200) conn, _, err := websocket.DefaultDialer.Dial("ws://localhost:4040/events", nil) assert.NilError(t, err) diff --git a/cardinal/go.mod b/cardinal/go.mod index 7bc25602f..66f25e8ee 100644 --- a/cardinal/go.mod +++ b/cardinal/go.mod @@ -25,6 +25,7 @@ require ( github.com/invopop/jsonschema v0.7.0 github.com/mitchellh/mapstructure v1.5.0 github.com/redis/go-redis/v9 v9.0.2 + github.com/rs/cors v1.10.1 github.com/rs/zerolog v1.30.0 github.com/stretchr/testify v1.8.4 google.golang.org/grpc v1.58.3 diff --git a/cardinal/option.go b/cardinal/option.go index 23e819bfd..8c8d295fb 100644 --- a/cardinal/option.go +++ b/cardinal/option.go @@ -82,3 +82,9 @@ func WithPrettyLog() WorldOption { ecsOption: ecs.WithPrettyLog(), } } + +func WithCORS() WorldOption { + return WorldOption{ + serverOption: server.WithCORS(), + } +} diff --git a/cardinal/server/option.go b/cardinal/server/option.go index 2705c448f..dfb01f39d 100644 --- a/cardinal/server/option.go +++ b/cardinal/server/option.go @@ -21,3 +21,9 @@ func WithAdapter(a shard.Adapter) Option { th.adapter = a } } + +func WithCORS() Option { + return func(th *Handler) { + th.withCORS = true + } +} diff --git a/cardinal/server/server.go b/cardinal/server/server.go index 6aa051382..a3921611e 100644 --- a/cardinal/server/server.go +++ b/cardinal/server/server.go @@ -15,6 +15,7 @@ import ( "github.com/go-openapi/runtime/middleware" "github.com/go-openapi/runtime/middleware/untyped" "github.com/mitchellh/mapstructure" + "github.com/rs/cors" "github.com/rs/zerolog/log" "pkg.world.dev/world-engine/cardinal/ecs" "pkg.world.dev/world-engine/cardinal/shard" @@ -27,6 +28,7 @@ type Handler struct { server *http.Server disableSigVerification bool Port string + withCORS bool // plugins adapter shard.WriteAdapter @@ -61,8 +63,9 @@ var swaggerData []byte func newSwaggerHandlerEmbed(w *ecs.World, builder middleware.Builder, opts ...Option) (*Handler, error) { th := &Handler{ - w: w, - Mux: http.NewServeMux(), + w: w, + Mux: http.NewServeMux(), + withCORS: false, } for _, opt := range opts { opt(th) @@ -95,8 +98,11 @@ func newSwaggerHandlerEmbed(w *ecs.World, builder middleware.Builder, opts ...Op } app := middleware.NewContext(specDoc, api, nil) - - th.Mux.Handle("/", app.APIHandler(builder)) + var handler = app.APIHandler(builder) + if th.withCORS { + handler = cors.AllowAll().Handler(handler) + } + th.Mux.Handle("/", handler) th.Initialize() return th, nil diff --git a/cardinal/server/server_test.go b/cardinal/server/server_test.go index f534dc4fb..f52eaf623 100644 --- a/cardinal/server/server_test.go +++ b/cardinal/server/server_test.go @@ -203,10 +203,15 @@ func TestCanListTransactionEndpoints(t *testing.T) { betaTx := ecs.NewTransactionType[SendEnergyTx, SendEnergyTxResult]("beta") gammaTx := ecs.NewTransactionType[SendEnergyTx, SendEnergyTxResult]("gamma") assert.NilError(t, w.RegisterTransactions(alphaTx, betaTx, gammaTx)) - txh := testutils.MakeTestTransactionHandler(t, w, server.DisableSignatureVerification()) - - resp, err := http.Post(txh.MakeHTTPURL("query/http/endpoints"), "application/json", nil) + txh := testutils.MakeTestTransactionHandler(t, w, server.DisableSignatureVerification(), server.WithCORS()) + client := &http.Client{} + req, err := http.NewRequest(http.MethodPost, txh.MakeHTTPURL("query/http/endpoints"), nil) + assert.NilError(t, err) + req.Header.Set("Origin", "http://www.bullshit.com") // test CORS + resp, err := client.Do(req) assert.NilError(t, err) + v := resp.Header.Get("Access-Control-Allow-Origin") + assert.Equal(t, v, "*") assert.Equal(t, resp.StatusCode, 200) var gotEndpoints map[string][]string assert.NilError(t, json.NewDecoder(resp.Body).Decode(&gotEndpoints))