Skip to content

Commit

Permalink
Merge pull request #18 from berty/dev/moul/basic-auth
Browse files Browse the repository at this point in the history
feat: basic auth
  • Loading branch information
moul authored Feb 11, 2020
2 parents b87dd13 + 48a0cdc commit ac196e5
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
6 changes: 6 additions & 0 deletions cmd/yolo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ func main() {
corsAllowedOrigins string
requestTimeout time.Duration
shutdownTimeout time.Duration
basicAuth string
realm string
)
var (
rootFlagSet = flag.NewFlagSet("yolo", flag.ExitOnError)
Expand All @@ -51,6 +53,8 @@ func main() {
serverFlagSet.StringVar(&corsAllowedOrigins, "cors-allowed-origins", "", "CORS allowed origins (*.domain.tld)")
serverFlagSet.DurationVar(&requestTimeout, "request-timeout", 5*time.Second, "request timeout")
serverFlagSet.DurationVar(&shutdownTimeout, "shutdown-timeout", 6*time.Second, "server shutdown timeout")
serverFlagSet.StringVar(&basicAuth, "basic-auth-password", "", "if set, enables basic authentication")
serverFlagSet.StringVar(&realm, "realm", "Yolo", "authentication Realm")

server := &ffcli.Command{
Name: `server`,
Expand Down Expand Up @@ -93,6 +97,8 @@ func main() {
RequestTimeout: requestTimeout,
ShutdownTimeout: shutdownTimeout,
CORSAllowedOrigins: corsAllowedOrigins,
BasicAuth: basicAuth,
Realm: realm,
})
if err != nil {
return err
Expand Down
27 changes: 27 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package yolo

import (
"context"
"crypto/subtle"
fmt "fmt"
"net"
"net/http"
"strings"
Expand Down Expand Up @@ -39,6 +41,8 @@ type ServerOpts struct {
CORSAllowedOrigins string
RequestTimeout time.Duration
ShutdownTimeout time.Duration
BasicAuth string
Realm string
}

func NewServer(ctx context.Context, svc Service, opts ServerOpts) (*Server, error) {
Expand All @@ -51,6 +55,9 @@ func NewServer(ctx context.Context, svc Service, opts ServerOpts) (*Server, erro
if opts.GRPCBind == "" {
opts.GRPCBind = ":0"
}
if opts.Realm == "" {
opts.Realm = "Yolo"
}

// gRPC internal server
srv := Server{
Expand Down Expand Up @@ -103,6 +110,11 @@ func NewServer(ctx context.Context, svc Service, opts ServerOpts) (*Server, erro
r.Use(chilogger.Logger(srv.logger))
r.Use(middleware.Timeout(opts.RequestTimeout))
r.Use(middleware.Recoverer)

if opts.BasicAuth != "" {
r.Use(basicAuth(opts.BasicAuth, opts.Realm))
}

gwmux := runtime.NewServeMux(
runtime.WithMarshalerOption(runtime.MIMEWildcard, &gateway.JSONPb{EmitDefaults: false, Indent: " ", OrigName: true}),
runtime.WithProtoErrorHandler(runtime.DefaultHTTPProtoErrorHandler),
Expand Down Expand Up @@ -153,3 +165,18 @@ func (srv *Server) Start() error {
func (srv *Server) Stop() {
srv.grpcServer.GracefulStop()
}

func basicAuth(basicAuth string, realm string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, password, ok := r.BasicAuth()
if !ok || subtle.ConstantTimeCompare([]byte(password), []byte(basicAuth)) != 1 {
w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, realm))
w.WriteHeader(http.StatusUnauthorized)
fmt.Fprintf(w, "invalid credentials\n")
return
}
next.ServeHTTP(w, r)
})
}
}

0 comments on commit ac196e5

Please sign in to comment.