diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go new file mode 100644 index 0000000..fc0c9e1 --- /dev/null +++ b/internal/middleware/middleware.go @@ -0,0 +1,98 @@ +package middleware + +// Copied from chi middleware + +// Ported from Goji's middleware, source: +// https://github.com/zenazn/goji/tree/master/web/middleware + +import ( + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "net/http" + "os" + "strings" + "sync/atomic" +) + +// Key to use when setting the request ID. +type ctxKeyRequestID int + +// RequestIDKey is the key that holds the unique request ID in a request context. +const RequestIDKey ctxKeyRequestID = 0 + +// RequestIDHeader is the name of the HTTP Header which contains the request id. +// Exported so that it can be changed by developers +var RequestIDHeader = "X-Request-Id" + +var prefix string +var reqid uint64 + +// A quick note on the statistics here: we're trying to calculate the chance that +// two randomly generated base62 prefixes will collide. We use the formula from +// http://en.wikipedia.org/wiki/Birthday_problem +// +// P[m, n] \approx 1 - e^{-m^2/2n} +// +// We ballpark an upper bound for $m$ by imagining (for whatever reason) a server +// that restarts every second over 10 years, for $m = 86400 * 365 * 10 = 315360000$ +// +// For a $k$ character base-62 identifier, we have $n(k) = 62^k$ +// +// Plugging this in, we find $P[m, n(10)] \approx 5.75%$, which is good enough for +// our purposes, and is surely more than anyone would ever need in practice -- a +// process that is rebooted a handful of times a day for a hundred years has less +// than a millionth of a percent chance of generating two colliding IDs. + +func init() { + hostname, err := os.Hostname() + if hostname == "" || err != nil { + hostname = "localhost" + } + var buf [12]byte + var b64 string + for len(b64) < 10 { + rand.Read(buf[:]) + b64 = base64.StdEncoding.EncodeToString(buf[:]) + b64 = strings.NewReplacer("+", "", "/", "").Replace(b64) + } + + prefix = fmt.Sprintf("%s/%s", hostname, b64[0:10]) +} + +// RequestID is a middleware that injects a request ID into the context of each +// request. A request ID is a string of the form "host.example.com/random-0001", +// where "random" is a base62 random string that uniquely identifies this go +// process, and where the last number is an atomically incremented request +// counter. +func RequestID(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + requestID := r.Header.Get(RequestIDHeader) + if requestID == "" { + myid := atomic.AddUint64(&reqid, 1) + requestID = fmt.Sprintf("%s-%06d", prefix, myid) + } + ctx = context.WithValue(ctx, RequestIDKey, requestID) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) +} + +// GetReqID returns a request ID from the given context if one is present. +// Returns the empty string if a request ID cannot be found. +func GetReqID(ctx context.Context) string { + if ctx == nil { + return "" + } + if reqID, ok := ctx.Value(RequestIDKey).(string); ok { + return reqID + } + return "" +} + +// NextRequestID generates the next request ID in the sequence. +func NextRequestID() uint64 { + return atomic.AddUint64(&reqid, 1) +} diff --git a/lmw.go b/lmw.go index a2dbfb6..5843236 100644 --- a/lmw.go +++ b/lmw.go @@ -7,9 +7,34 @@ import ( "io" "net/http" "time" + + "github.com/interline-io/log/internal/middleware" ) -// log request and duration +// Copy of chi request id middleware +func RequestIDMiddleware(next http.Handler) http.Handler { + return middleware.RequestID(next) +} + +// Glue between chi RequestID and Zerolog +func RequestIDLoggingMiddleware(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + rlog := For(ctx). + With(). + Str("request_id", middleware.GetReqID(ctx)).Logger() + next.ServeHTTP(w, r.WithContext(WithLogger(ctx, rlog))) + } + return http.HandlerFunc(fn) +} + +// Log request and duration +// Renamed from LoggingMiddleware +func DurationLoggingMiddleware(longQueryDuration int, getUserName func(context.Context) string) func(http.Handler) http.Handler { + return LoggingMiddleware(longQueryDuration, getUserName) +} + +// Log request and duration func LoggingMiddleware(longQueryDuration int, getUserName func(context.Context) string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {