Skip to content

Commit

Permalink
feat: restrict access to all tables by default
Browse files Browse the repository at this point in the history
for #10
  • Loading branch information
bcho committed May 29, 2022
1 parent 1511b88 commit 0b9c5af
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 14 deletions.
5 changes: 5 additions & 0 deletions fixture_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"k8s.io/klog/v2/ktesting"
)

var enabledTestTables = []string{"test", "test_view"}

type TestContext struct {
server *httptest.Server
db *sqlx.DB
Expand Down Expand Up @@ -135,6 +137,7 @@ func createTestContextUsingInMemoryDB(t testing.TB) *TestContext {
Execer: db,
}
serverOpts.AuthOptions.disableAuth = true
serverOpts.SecurityOptions.EnabledTableOrViews = enabledTestTables
server, err := NewServer(serverOpts)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -190,6 +193,7 @@ func createTestContextWithHMACTokenAuth(t testing.TB) *TestContext {
Execer: db,
}
serverOpts.AuthOptions.TokenFilePath = testTokenFile
serverOpts.SecurityOptions.EnabledTableOrViews = enabledTestTables
server, err := NewServer(serverOpts)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -265,6 +269,7 @@ func createTestContextWithRSATokenAuth(t testing.TB) *TestContext {
Execer: db,
}
serverOpts.AuthOptions.RSAPublicKeyFilePath = testTokenFile
serverOpts.SecurityOptions.EnabledTableOrViews = enabledTestTables
server, err := NewServer(serverOpts)
if err != nil {
t.Fatal(err)
Expand Down
30 changes: 30 additions & 0 deletions integration_security_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package main

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestSecurityNegativeCases(t *testing.T) {
t.Run("Unauthorized", func(t *testing.T) {
tc := createTestContextWithHMACTokenAuth(t)
defer tc.CleanUp(t)

tc.authToken = "" // disable auth
client := tc.Client()
_, _, err := client.From("test").Select("id", "", false).Execute()
assert.Error(t, err)
assert.Contains(t, err.Error(), "Unauthorized")
})

t.Run("TableAccessRestricted", func(t *testing.T) {
tc := createTestContextWithHMACTokenAuth(t)
defer tc.CleanUp(t)

client := tc.Client()
_, _, err := client.From(tableNameMigrations).Select("id", "", false).Execute()
assert.Error(t, err)
assert.Contains(t, err.Error(), "Access Restricted")
})
}
37 changes: 23 additions & 14 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,27 @@ const (
)

type ServerOptions struct {
Logger logr.Logger
Addr string
AuthOptions ServerAuthOptions
Queryer sqlx.QueryerContext
Execer sqlx.ExecerContext
Logger logr.Logger
Addr string
AuthOptions ServerAuthOptions
SecurityOptions ServerSecurityOptions
Queryer sqlx.QueryerContext
Execer sqlx.ExecerContext
}

func (opts *ServerOptions) bindCLIFlags(fs *pflag.FlagSet) {
fs.StringVar(&opts.Addr, "http-addr", ":8080", "server listen addr")
opts.AuthOptions.bindCLIFlags(fs)
opts.SecurityOptions.bindCLIFlags(fs)
}

func (opts *ServerOptions) defaults() error {
if err := opts.AuthOptions.defaults(); err != nil {
return err
}
if err := opts.SecurityOptions.defaults(); err != nil {
return err
}

if opts.Logger.GetSink() == nil {
opts.Logger = logr.Discard()
Expand Down Expand Up @@ -82,17 +87,21 @@ func NewServer(opts *ServerOptions) (*dbServer, error) {

// TODO: allow specifying cors config from cli / table
serverMux.Use(cors.AllowAll().Handler)
authMiddleware := opts.AuthOptions.createAuthMiddleware(rv.responseError)

{
serverMux.With(authMiddleware).Group(func(r chi.Router) {
routePattern := fmt.Sprintf("/{%s:[^/]+}", routeVarTableOrView)
r.Get(routePattern, rv.handleQueryTableOrView)
r.Post(routePattern, rv.handleInsertTable)
r.Patch(routePattern, rv.handleUpdateTable)
r.Put(routePattern, rv.handleUpdateSingleEntity)
r.Delete(routePattern, rv.handleDeleteTable)
})
serverMux.
With(
opts.AuthOptions.createAuthMiddleware(rv.responseError),
opts.SecurityOptions.createTableOrViewAccessCheckMiddleware(rv.responseError, routeVarTableOrView),
).
Group(func(r chi.Router) {
routePattern := fmt.Sprintf("/{%s:[^/]+}", routeVarTableOrView)
r.Get(routePattern, rv.handleQueryTableOrView)
r.Post(routePattern, rv.handleInsertTable)
r.Patch(routePattern, rv.handleUpdateTable)
r.Put(routePattern, rv.handleUpdateSingleEntity)
r.Delete(routePattern, rv.handleDeleteTable)
})
}

rv.server.Handler = serverMux
Expand Down
5 changes: 5 additions & 0 deletions server_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ var (
Message: "Unauthorized",
StatusCode: http.StatusUnauthorized,
}

ErrAccessRestricted = &ServerError{
Message: "Access Restricted",
StatusCode: http.StatusForbidden,
}
)

func ErrUnsupportedOperator(op string) *ServerError {
Expand Down
53 changes: 53 additions & 0 deletions server_security.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package main

import (
"fmt"
"net/http"

"github.com/go-chi/chi/v5"
"github.com/spf13/pflag"
)

// TODO: generally speaking, we need a fine-grained RBAC system.

type ServerSecurityOptions struct {
// EnabledTableOrViews list of table or view names that are accessible (read & write).
EnabledTableOrViews []string
}

func (opts *ServerSecurityOptions) bindCLIFlags(fs *pflag.FlagSet) {
fs.StringSliceVar(
&opts.EnabledTableOrViews,
"--security-allow-table",
[]string{},
"list of table or view names that are accessible (read & write)",
)
}

func (opts *ServerSecurityOptions) defaults() error {
return nil
}

func (opts *ServerSecurityOptions) createTableOrViewAccessCheckMiddleware(
responseErr func(w http.ResponseWriter, err error),
routeVarTableOrView string,
) func(http.Handler) http.Handler {
accesibleTableOrViews := make(map[string]struct{})
for _, t := range opts.EnabledTableOrViews {
accesibleTableOrViews[t] = struct{}{}
}
fmt.Println(accesibleTableOrViews)

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
target := chi.URLParam(req, routeVarTableOrView)

if _, ok := accesibleTableOrViews[target]; !ok {
responseErr(w, ErrAccessRestricted)
return
}

next.ServeHTTP(w, req)
})
}
}

0 comments on commit 0b9c5af

Please sign in to comment.