refactor: cors - add tests
continuous-integration/drone/push Build is passing
Details
continuous-integration/drone/push Build is passing
Details
This commit is contained in:
parent
7127be6f5f
commit
bd25ee0fed
|
@ -5,6 +5,7 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -47,6 +48,105 @@ func TestMethodNotAllowed(t *testing.T) {
|
|||
}, &model.ErrorResp{})
|
||||
}
|
||||
|
||||
func TestCORS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg rest.CORSConfig
|
||||
reqHeaders map[string]string
|
||||
resHeaders map[string]string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Disabled",
|
||||
cfg: rest.CORSConfig{
|
||||
Enabled: false,
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowCredentials: false,
|
||||
MaxAge: 600,
|
||||
},
|
||||
reqHeaders: map[string]string{
|
||||
"Origin": "https://sessions.tribalwarshelp.com",
|
||||
"Access-Control-Request-Method": http.MethodGet,
|
||||
},
|
||||
resHeaders: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "*",
|
||||
cfg: rest.CORSConfig{
|
||||
Enabled: true,
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowCredentials: false,
|
||||
MaxAge: 600,
|
||||
},
|
||||
reqHeaders: map[string]string{
|
||||
"Origin": "https://sessions.tribalwarshelp.com",
|
||||
"Access-Control-Request-Method": http.MethodGet,
|
||||
},
|
||||
resHeaders: map[string]string{
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": http.MethodGet,
|
||||
"Access-Control-Max-Age": "600",
|
||||
"Vary": "Origin Access-Control-Request-Method Access-Control-Request-Headers",
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "https://sessions.tribalwarshelp.com",
|
||||
cfg: rest.CORSConfig{
|
||||
Enabled: true,
|
||||
AllowedOrigins: []string{"https://sessions.tribalwarshelp.com"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 300,
|
||||
},
|
||||
reqHeaders: map[string]string{
|
||||
"Origin": "https://sessions.tribalwarshelp.com",
|
||||
"Access-Control-Request-Method": http.MethodGet,
|
||||
},
|
||||
resHeaders: map[string]string{
|
||||
"Access-Control-Allow-Origin": "https://sessions.tribalwarshelp.com",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Allow-Methods": http.MethodGet,
|
||||
"Access-Control-Max-Age": "300",
|
||||
"Vary": "Origin Access-Control-Request-Method Access-Control-Request-Headers",
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := rest.NewRouter(rest.RouterConfig{
|
||||
CORS: tt.cfg,
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/url", nil)
|
||||
for k, v := range tt.reqHeaders {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp := doCustomRequest(r, req)
|
||||
defer resp.Body.Close()
|
||||
assert.Equal(t, tt.expectedStatus, resp.StatusCode)
|
||||
assert.Len(t, resp.Header, len(tt.resHeaders))
|
||||
for k, expected := range tt.resHeaders {
|
||||
assert.Equal(t, expected, strings.Join(resp.Header.Values(k), " "))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwagger(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -87,12 +187,16 @@ func TestSwagger(t *testing.T) {
|
|||
}
|
||||
|
||||
func doRequest(mux chi.Router, method, target, apiKey string, body io.Reader) *http.Response {
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(method, target, body)
|
||||
if apiKey != "" {
|
||||
req.Header.Set("X-Api-Key", apiKey)
|
||||
}
|
||||
mux.ServeHTTP(rr, req)
|
||||
return doCustomRequest(mux, req)
|
||||
}
|
||||
|
||||
func doCustomRequest(mux chi.Router, r *http.Request) *http.Response {
|
||||
rr := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rr, r)
|
||||
return rr.Result()
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue