package chiclientip_test

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"gitea.dwysokinski.me/Kichiyaki/chiclientip"
	"github.com/go-chi/chi/v5"
	"github.com/realclientip/realclientip-go"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestClientIP mounts chiclientip.ClientIP on a chi router with a given
// realclientip.Strategy and checks that the handler behind it can read the
// expected client IP via chiclientip.ClientIPFromContext.
//
// Each case supplies:
//   - strategy: builds the realclientip.Strategy handed to the middleware;
//   - remoteAddr: the value assigned to the request's RemoteAddr;
//   - header: transforms a clone of the request headers before dispatch;
//   - expectedClientIP: the exact body the handler is expected to write.
func TestClientIP(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name             string
		strategy         func(t *testing.T) realclientip.Strategy
		remoteAddr       string
		header           func(header http.Header) http.Header
		expectedClientIP string
	}{
		{
			name: "OK: X-Forwarded-For",
			strategy: func(t *testing.T) realclientip.Strategy {
				t.Helper()

				strategy, err := realclientip.NewRightmostNonPrivateStrategy("X-Forwarded-For")
				require.NoError(t, err)

				return strategy
			},
			remoteAddr: "192.0.2.1:1234",
			header: func(header http.Header) http.Header {
				// The two 10.0.42.x entries are private addresses; the
				// expected result is the public address between them.
				header["X-Forwarded-For"] = []string{"10.0.42.1", "94.123.222.111", "10.0.42.3"}
				return header
			},
			expectedClientIP: "94.123.222.111",
		},
		{
			name: "OK: X-Forwarded-For, RemoteAddr",
			strategy: func(t *testing.T) realclientip.Strategy {
				t.Helper()

				stratXForwardedFor, err := realclientip.NewRightmostNonPrivateStrategy("X-Forwarded-For")
				require.NoError(t, err)

				return realclientip.NewChainStrategy(
					stratXForwardedFor,
					realclientip.RemoteAddrStrategy{},
				)
			},
			remoteAddr: "192.0.2.1:1234",
			header: func(header http.Header) http.Header {
				// No X-Forwarded-For header is set, so the expected IP is
				// the host part of remoteAddr.
				return header
			},
			expectedClientIP: "192.0.2.1",
		},
	}

	for _, tt := range tests {
		tt := tt // capture per-iteration copy for the parallel subtest (needed before Go 1.22)

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			router := chi.NewRouter()
			router.Use(chiclientip.ClientIP(tt.strategy(t)))
			router.Get("/", func(w http.ResponseWriter, r *http.Request) {
				clientIP, _ := chiclientip.ClientIPFromContext(r.Context())
				// Write errors on an httptest.ResponseRecorder are not
				// interesting here; the body is asserted on below.
				_, _ = w.Write([]byte(clientIP))
			})

			rr := httptest.NewRecorder()
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.RemoteAddr = tt.remoteAddr
			req.Header = tt.header(req.Header.Clone())

			router.ServeHTTP(rr, req)

			assert.Equal(t, http.StatusOK, rr.Code)
			assert.Equal(t, tt.expectedClientIP, rr.Body.String())
		})
	}
}