diff --git a/cmd/yolo/main.go b/cmd/yolo/main.go index dcbf20ba..ddc35d9d 100644 --- a/cmd/yolo/main.go +++ b/cmd/yolo/main.go @@ -35,6 +35,8 @@ func main() { corsAllowedOrigins string requestTimeout time.Duration shutdownTimeout time.Duration + basicAuth string + realm string ) var ( rootFlagSet = flag.NewFlagSet("yolo", flag.ExitOnError) @@ -51,6 +53,8 @@ func main() { serverFlagSet.StringVar(&corsAllowedOrigins, "cors-allowed-origins", "", "CORS allowed origins (*.domain.tld)") serverFlagSet.DurationVar(&requestTimeout, "request-timeout", 5*time.Second, "request timeout") serverFlagSet.DurationVar(&shutdownTimeout, "shutdown-timeout", 6*time.Second, "server shutdown timeout") + serverFlagSet.StringVar(&basicAuth, "basic-auth-password", "", "if set, enables basic authentication") + serverFlagSet.StringVar(&realm, "realm", "Yolo", "authentication Realm") server := &ffcli.Command{ Name: `server`, @@ -93,6 +97,8 @@ func main() { RequestTimeout: requestTimeout, ShutdownTimeout: shutdownTimeout, CORSAllowedOrigins: corsAllowedOrigins, + BasicAuth: basicAuth, + Realm: realm, }) if err != nil { return err diff --git a/server.go b/server.go index a99cd69f..c90ba336 100644 --- a/server.go +++ b/server.go @@ -2,6 +2,8 @@ package yolo import ( "context" + "crypto/subtle" + fmt "fmt" "net" "net/http" "strings" @@ -39,6 +41,8 @@ type ServerOpts struct { CORSAllowedOrigins string RequestTimeout time.Duration ShutdownTimeout time.Duration + BasicAuth string + Realm string } func NewServer(ctx context.Context, svc Service, opts ServerOpts) (*Server, error) { @@ -51,6 +55,9 @@ func NewServer(ctx context.Context, svc Service, opts ServerOpts) (*Server, erro if opts.GRPCBind == "" { opts.GRPCBind = ":0" } + if opts.Realm == "" { + opts.Realm = "Yolo" + } // gRPC internal server srv := Server{ @@ -103,6 +110,11 @@ func NewServer(ctx context.Context, svc Service, opts ServerOpts) (*Server, erro r.Use(chilogger.Logger(srv.logger)) r.Use(middleware.Timeout(opts.RequestTimeout)) r.Use(middleware.Recoverer) + + if opts.BasicAuth != "" { + r.Use(basicAuth(opts.BasicAuth, opts.Realm)) + } + gwmux := runtime.NewServeMux( runtime.WithMarshalerOption(runtime.MIMEWildcard, &gateway.JSONPb{EmitDefaults: false, Indent: " ", OrigName: true}), runtime.WithProtoErrorHandler(runtime.DefaultHTTPProtoErrorHandler), @@ -153,3 +165,18 @@ func (srv *Server) Start() error { func (srv *Server) Stop() { srv.grpcServer.GracefulStop() } + +func basicAuth(basicAuth string, realm string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, password, ok := r.BasicAuth() + if !ok || subtle.ConstantTimeCompare([]byte(password), []byte(basicAuth)) != 1 { + w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, realm)) + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprintf(w, "invalid credentials\n") + return + } + next.ServeHTTP(w, r) + }) + } +}