diff --git a/config.go b/config.go index cc6aa11..c5d297a 100644 --- a/config.go +++ b/config.go @@ -1,6 +1,7 @@ package chizap import ( + "net" "net/http" "go.uber.org/zap" @@ -12,13 +13,19 @@ type AdditionalFieldExtractor func(r *http.Request) []zap.Field type config struct { filters []Filter + ipFn func(r *http.Request) string additionalFieldExtractors []AdditionalFieldExtractor } type Option func(*config) func newConfig(opts ...Option) *config { - cfg := &config{} + cfg := &config{ + ipFn: func(r *http.Request) string { + ip, _, _ := net.SplitHostPort(r.RemoteAddr) + return ip + }, + } for _, opt := range opts { opt(cfg) @@ -44,3 +51,12 @@ func WithAdditionalFieldExtractor(extractor AdditionalFieldExtractor) Option { c.additionalFieldExtractors = append(c.additionalFieldExtractors, extractor) } } + +// WithIPFn takes a function that will be called on every +// request and the returned ip will be added to the log entry. +// http.Request RemoteAddr is logged by default. +func WithIPFn(fn func(r *http.Request) string) Option { + return func(c *config) { + c.ipFn = fn + } +} diff --git a/logger.go b/logger.go index 7775229..6387a1f 100644 --- a/logger.go +++ b/logger.go @@ -1,7 +1,6 @@ package chizap import ( - "net" "net/http" "time" @@ -31,12 +30,11 @@ func Logger(logger *zap.Logger, opts ...Option) func(next http.Handler) http.Han end := time.Now() statusCode := ww.Status() - ip, _, _ := net.SplitHostPort(r.RemoteAddr) fields := []zap.Field{ zap.Int("statusCode", statusCode), zap.Duration("duration", end.Sub(start)), - zap.String("ip", ip), + zap.String("ip", cfg.ipFn(r)), zap.String("method", r.Method), zap.String("query", query), zap.String("path", path), @@ -47,8 +45,8 @@ func Logger(logger *zap.Logger, opts ...Option) func(next http.Handler) http.Han zap.String("routePattern", chi.RouteContext(r.Context()).RoutePattern()), } - for _, fn := range cfg.additionalFieldExtractors { - fields = append(fields, fn(r)...) + for _, f := range cfg.additionalFieldExtractors { + fields = append(fields, f(r)...) } if statusCode >= http.StatusInternalServerError { diff --git a/logger_test.go b/logger_test.go index 807e12d..38cdee2 100644 --- a/logger_test.go +++ b/logger_test.go @@ -23,6 +23,7 @@ func TestLogger(t *testing.T) { name string req *http.Request excluded bool + expectedIP string expectedLevel zapcore.Level expectedRoutePattern string expectedAdditionalFields []zap.Field @@ -30,18 +31,21 @@ func TestLogger(t *testing.T) { { name: "/info?test=true", req: httptest.NewRequest(http.MethodGet, "/info?test=true", nil), + expectedIP: "192.0.2.1", expectedLevel: zap.InfoLevel, expectedRoutePattern: "/info", }, { name: "/warn?test=true", req: httptest.NewRequest(http.MethodGet, "/warn?test=true", nil), + expectedIP: "192.0.2.1", expectedLevel: zap.WarnLevel, expectedRoutePattern: "/warn", }, { name: "/error?test=true", req: httptest.NewRequest(http.MethodGet, "/error?test=true", nil), + expectedIP: "192.0.2.1", expectedLevel: zap.ErrorLevel, expectedRoutePattern: "/error", }, @@ -53,12 +57,24 @@ func TestLogger(t *testing.T) { { name: "/delete/123", req: httptest.NewRequest(http.MethodDelete, "/delete/123", nil), + expectedIP: "192.0.2.1", expectedLevel: zap.InfoLevel, expectedRoutePattern: "/delete/{id}", expectedAdditionalFields: []zap.Field{ zap.String("id", "123"), }, }, + { + name: "/x-forwarded-for", + req: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/x-forwarded-for", nil) + req.Header.Set("X-Forwarded-For", "94.222.111.115") + return req + }(), + expectedIP: "94.222.111.115", + expectedLevel: zap.InfoLevel, + expectedRoutePattern: "/x-forwarded-for", + }, } for _, tt := range tests { @@ -84,8 +100,7 @@ func TestLogger(t *testing.T) { assert.Equal(t, tt.req.URL.Path, entry.Message) assert.Equal(t, tt.expectedLevel, entry.Level) require.Len(t, entry.Context, 11+len(tt.expectedAdditionalFields)) - ip, _, _ := net.SplitHostPort(tt.req.RemoteAddr) - assert.Contains(t, entry.Context, zap.String("ip", ip)) + assert.Contains(t, entry.Context, zap.String("ip", tt.expectedIP)) assert.Contains(t, entry.Context, zap.Int("statusCode", rr.Code)) assert.Contains(t, entry.Context, zap.String("method", tt.req.Method)) assert.Contains(t, entry.Context, zap.String("path", tt.req.URL.Path)) @@ -128,6 +143,13 @@ func newRouter(logger *zap.Logger) *chi.Mux { zap.String("id", chi.URLParam(r, "id")), } }), + chizap.WithIPFn(func(r *http.Request) string { + if r.URL.Path != "/x-forwarded-for" { + ip, _, _ := net.SplitHostPort(r.RemoteAddr) + return ip + } + return r.Header.Get(http.CanonicalHeaderKey("X-Forwarded-For")) + }), )) r.Get("/info", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -141,6 +163,9 @@ func newRouter(logger *zap.Logger) *chi.Mux { r.Get("/excluded", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) + r.Get("/x-forwarded-for", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) r.Delete("/delete/{id}", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })