diff --git a/internal/router/rest/rest_test.go b/internal/router/rest/rest_test.go index 4e9bfcb..a5ce49c 100644 --- a/internal/router/rest/rest_test.go +++ b/internal/router/rest/rest_test.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -47,6 +48,105 @@ func TestMethodNotAllowed(t *testing.T) { }, &model.ErrorResp{}) } +func TestCORS(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg rest.CORSConfig + reqHeaders map[string]string + resHeaders map[string]string + expectedStatus int + }{ + { + name: "Disabled", + cfg: rest.CORSConfig{ + Enabled: false, + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"GET", "POST"}, + AllowCredentials: false, + MaxAge: 600, + }, + reqHeaders: map[string]string{ + "Origin": "https://sessions.tribalwarshelp.com", + "Access-Control-Request-Method": http.MethodGet, + }, + resHeaders: map[string]string{ + "Content-Type": "application/json", + }, + expectedStatus: http.StatusNotFound, + }, + { + name: "*", + cfg: rest.CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"GET", "POST"}, + AllowCredentials: false, + MaxAge: 600, + }, + reqHeaders: map[string]string{ + "Origin": "https://sessions.tribalwarshelp.com", + "Access-Control-Request-Method": http.MethodGet, + }, + resHeaders: map[string]string{ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": http.MethodGet, + "Access-Control-Max-Age": "600", + "Vary": "Origin Access-Control-Request-Method Access-Control-Request-Headers", + }, + expectedStatus: http.StatusOK, + }, + { + name: "https://sessions.tribalwarshelp.com", + cfg: rest.CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"https://sessions.tribalwarshelp.com"}, + AllowedMethods: []string{"GET", "POST"}, + AllowCredentials: true, + MaxAge: 300, + }, + reqHeaders: map[string]string{ + "Origin": "https://sessions.tribalwarshelp.com", + "Access-Control-Request-Method": http.MethodGet, + }, + resHeaders: map[string]string{ + "Access-Control-Allow-Origin": "https://sessions.tribalwarshelp.com", + "Access-Control-Allow-Credentials": "true", + "Access-Control-Allow-Methods": http.MethodGet, + "Access-Control-Max-Age": "300", + "Vary": "Origin Access-Control-Request-Method Access-Control-Request-Headers", + }, + expectedStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + r := rest.NewRouter(rest.RouterConfig{ + CORS: tt.cfg, + }) + + req := httptest.NewRequest(http.MethodOptions, "/url", nil) + for k, v := range tt.reqHeaders { + req.Header.Set(k, v) + } + + resp := doCustomRequest(r, req) + defer resp.Body.Close() + assert.Equal(t, tt.expectedStatus, resp.StatusCode) + assert.Len(t, resp.Header, len(tt.resHeaders)) + for k, expected := range tt.resHeaders { + assert.Equal(t, expected, strings.Join(resp.Header.Values(k), " ")) + } + }) + } +} + func TestSwagger(t *testing.T) { t.Parallel() @@ -87,12 +187,16 @@ func TestSwagger(t *testing.T) { } func doRequest(mux chi.Router, method, target, apiKey string, body io.Reader) *http.Response { - rr := httptest.NewRecorder() req := httptest.NewRequest(method, target, body) if apiKey != "" { req.Header.Set("X-Api-Key", apiKey) } - mux.ServeHTTP(rr, req) + return doCustomRequest(mux, req) +} + +func doCustomRequest(mux chi.Router, r *http.Request) *http.Response { + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, r) return rr.Result() }