refactor: cors - add tests
continuous-integration/drone/push Build is passing
Details
continuous-integration/drone/push Build is passing
Details
This commit is contained in:
parent
7127be6f5f
commit
bd25ee0fed
|
@ -5,6 +5,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -47,6 +48,105 @@ func TestMethodNotAllowed(t *testing.T) {
|
||||||
}, &model.ErrorResp{})
|
}, &model.ErrorResp{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCORS(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg rest.CORSConfig
|
||||||
|
reqHeaders map[string]string
|
||||||
|
resHeaders map[string]string
|
||||||
|
expectedStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Disabled",
|
||||||
|
cfg: rest.CORSConfig{
|
||||||
|
Enabled: false,
|
||||||
|
AllowedOrigins: []string{"*"},
|
||||||
|
AllowedMethods: []string{"GET", "POST"},
|
||||||
|
AllowCredentials: false,
|
||||||
|
MaxAge: 600,
|
||||||
|
},
|
||||||
|
reqHeaders: map[string]string{
|
||||||
|
"Origin": "https://sessions.tribalwarshelp.com",
|
||||||
|
"Access-Control-Request-Method": http.MethodGet,
|
||||||
|
},
|
||||||
|
resHeaders: map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
expectedStatus: http.StatusNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "*",
|
||||||
|
cfg: rest.CORSConfig{
|
||||||
|
Enabled: true,
|
||||||
|
AllowedOrigins: []string{"*"},
|
||||||
|
AllowedMethods: []string{"GET", "POST"},
|
||||||
|
AllowCredentials: false,
|
||||||
|
MaxAge: 600,
|
||||||
|
},
|
||||||
|
reqHeaders: map[string]string{
|
||||||
|
"Origin": "https://sessions.tribalwarshelp.com",
|
||||||
|
"Access-Control-Request-Method": http.MethodGet,
|
||||||
|
},
|
||||||
|
resHeaders: map[string]string{
|
||||||
|
"Access-Control-Allow-Origin": "*",
|
||||||
|
"Access-Control-Allow-Methods": http.MethodGet,
|
||||||
|
"Access-Control-Max-Age": "600",
|
||||||
|
"Vary": "Origin Access-Control-Request-Method Access-Control-Request-Headers",
|
||||||
|
},
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "https://sessions.tribalwarshelp.com",
|
||||||
|
cfg: rest.CORSConfig{
|
||||||
|
Enabled: true,
|
||||||
|
AllowedOrigins: []string{"https://sessions.tribalwarshelp.com"},
|
||||||
|
AllowedMethods: []string{"GET", "POST"},
|
||||||
|
AllowCredentials: true,
|
||||||
|
MaxAge: 300,
|
||||||
|
},
|
||||||
|
reqHeaders: map[string]string{
|
||||||
|
"Origin": "https://sessions.tribalwarshelp.com",
|
||||||
|
"Access-Control-Request-Method": http.MethodGet,
|
||||||
|
},
|
||||||
|
resHeaders: map[string]string{
|
||||||
|
"Access-Control-Allow-Origin": "https://sessions.tribalwarshelp.com",
|
||||||
|
"Access-Control-Allow-Credentials": "true",
|
||||||
|
"Access-Control-Allow-Methods": http.MethodGet,
|
||||||
|
"Access-Control-Max-Age": "300",
|
||||||
|
"Vary": "Origin Access-Control-Request-Method Access-Control-Request-Headers",
|
||||||
|
},
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
r := rest.NewRouter(rest.RouterConfig{
|
||||||
|
CORS: tt.cfg,
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodOptions, "/url", nil)
|
||||||
|
for k, v := range tt.reqHeaders {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := doCustomRequest(r, req)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, tt.expectedStatus, resp.StatusCode)
|
||||||
|
assert.Len(t, resp.Header, len(tt.resHeaders))
|
||||||
|
for k, expected := range tt.resHeaders {
|
||||||
|
assert.Equal(t, expected, strings.Join(resp.Header.Values(k), " "))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSwagger(t *testing.T) {
|
func TestSwagger(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -87,12 +187,16 @@ func TestSwagger(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func doRequest(mux chi.Router, method, target, apiKey string, body io.Reader) *http.Response {
|
func doRequest(mux chi.Router, method, target, apiKey string, body io.Reader) *http.Response {
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
req := httptest.NewRequest(method, target, body)
|
req := httptest.NewRequest(method, target, body)
|
||||||
if apiKey != "" {
|
if apiKey != "" {
|
||||||
req.Header.Set("X-Api-Key", apiKey)
|
req.Header.Set("X-Api-Key", apiKey)
|
||||||
}
|
}
|
||||||
mux.ServeHTTP(rr, req)
|
return doCustomRequest(mux, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func doCustomRequest(mux chi.Router, r *http.Request) *http.Response {
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
mux.ServeHTTP(rr, r)
|
||||||
return rr.Result()
|
return rr.Result()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue