83 lines
2.1 KiB
Go
83 lines
2.1 KiB
Go
package chiclientip_test
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"gitea.dwysokinski.me/Kichiyaki/chiclientip"
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/realclientip/realclientip-go"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestClientIP(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
strategy func(t *testing.T) realclientip.Strategy
|
|
remoteAddr string
|
|
header func(header http.Header) http.Header
|
|
expectedClientIP string
|
|
}{
|
|
{
|
|
name: "OK: X-Forwarded-For",
|
|
strategy: func(t *testing.T) realclientip.Strategy {
|
|
t.Helper()
|
|
strategy, err := realclientip.NewRightmostNonPrivateStrategy("X-Forwarded-For")
|
|
require.NoError(t, err)
|
|
return strategy
|
|
},
|
|
remoteAddr: "192.0.2.1:1234",
|
|
header: func(header http.Header) http.Header {
|
|
header["X-Forwarded-For"] = []string{"10.0.42.1", "94.123.222.111", "10.0.42.3"}
|
|
return header
|
|
},
|
|
expectedClientIP: "94.123.222.111",
|
|
},
|
|
{
|
|
name: "OK: X-Forwarded-For, RemoteAddr",
|
|
strategy: func(t *testing.T) realclientip.Strategy {
|
|
t.Helper()
|
|
stratXForwardedFor, err := realclientip.NewRightmostNonPrivateStrategy("X-Forwarded-For")
|
|
require.NoError(t, err)
|
|
return realclientip.NewChainStrategy(
|
|
stratXForwardedFor,
|
|
realclientip.RemoteAddrStrategy{},
|
|
)
|
|
},
|
|
remoteAddr: "192.0.2.1:1234",
|
|
header: func(header http.Header) http.Header {
|
|
return header
|
|
},
|
|
expectedClientIP: "192.0.2.1",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
router := chi.NewRouter()
|
|
router.Use(chiclientip.ClientIP(tt.strategy(t)))
|
|
router.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
|
clientIP, _ := chiclientip.ClientIPFromContext(r.Context())
|
|
_, _ = w.Write([]byte(clientIP))
|
|
})
|
|
|
|
rr := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = tt.remoteAddr
|
|
req.Header = tt.header(req.Header.Clone())
|
|
router.ServeHTTP(rr, req)
|
|
|
|
assert.Equal(t, http.StatusOK, rr.Code)
|
|
assert.Equal(t, tt.expectedClientIP, rr.Body.String())
|
|
})
|
|
}
|
|
}
|