From a6cb80ab75bc08dd501a09966f0a9abd0a6d7b5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dawid=20Wysoki=C5=84ski?= Date: Fri, 27 Jan 2023 07:13:55 +0100 Subject: [PATCH] refactor: rest.New - introduce required/optional args --- cmd/sessions/internal/serve/serve.go | 18 ++++----- internal/router/meta/meta.go | 2 +- internal/router/meta/meta_test.go | 4 +- internal/router/rest/config.go | 42 ++++++++++++++++++++ internal/router/rest/rest.go | 51 +++++++++--------------- internal/router/rest/rest_test.go | 59 +++++++++++++++++++++------- internal/router/rest/session_test.go | 11 +----- internal/router/rest/user_test.go | 5 +-- 8 files changed, 119 insertions(+), 73 deletions(-) create mode 100644 internal/router/rest/config.go diff --git a/cmd/sessions/internal/serve/serve.go b/cmd/sessions/internal/serve/serve.go index 8cae3fc..c9b43f0 100644 --- a/cmd/sessions/internal/serve/serve.go +++ b/cmd/sessions/internal/serve/serve.go @@ -90,23 +90,23 @@ func newServer(logger *zap.Logger, db *bun.DB) (*http.Server, error) { // router r := chi.NewRouter() r.Use(getMiddlewares(logger)...) - r.Mount(metaEndpointsPrefix, meta.NewRouter([]meta.Checker{bundb.NewChecker(db)})) - r.Mount("/api", rest.NewRouter(rest.RouterConfig{ - APIKeyVerifier: apiKeySvc, - SessionService: sessionSvc, - CORS: rest.CORSConfig{ + r.Mount(metaEndpointsPrefix, meta.New([]meta.Checker{bundb.NewChecker(db)})) + r.Mount("/api", rest.New( + apiKeySvc, + sessionSvc, + rest.WithCORSConfig(rest.CORSConfig{ Enabled: apiCfg.CORSEnabled, AllowedOrigins: apiCfg.CORSAllowedOrigins, AllowedMethods: apiCfg.CORSAllowedMethods, AllowCredentials: apiCfg.CORSAllowCredentials, MaxAge: int(apiCfg.CORSMaxAge), - }, - Swagger: rest.SwaggerConfig{ + }), + rest.WithSwaggerConfig(rest.SwaggerConfig{ Enabled: apiCfg.SwaggerEnabled, Host: apiCfg.SwaggerHost, Schemes: apiCfg.SwaggerSchemes, - }, - })) + }), + )) return &http.Server{ Addr: ":" + defaultPort, diff --git a/internal/router/meta/meta.go b/internal/router/meta/meta.go index 04ba2a2..f05af5e 100644 --- a/internal/router/meta/meta.go +++ b/internal/router/meta/meta.go @@ -21,7 +21,7 @@ type Checker interface { Check(ctx context.Context) error } -func NewRouter(checkers []Checker) *chi.Mux { +func New(checkers []Checker) *chi.Mux { router := chi.NewRouter() router.Get("/livez", func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/router/meta/meta_test.go b/internal/router/meta/meta_test.go index 2a22936..0b8c568 100644 --- a/internal/router/meta/meta_test.go +++ b/internal/router/meta/meta_test.go @@ -20,7 +20,7 @@ func TestLivez(t *testing.T) { t.Parallel() rr := httptest.NewRecorder() - meta.NewRouter(nil).ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/livez", nil)) + meta.New(nil).ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/livez", nil)) assert.Equal(t, http.StatusNoContent, rr.Code) assert.Equal(t, 0, rr.Body.Len()) @@ -99,7 +99,7 @@ func TestReadyz(t *testing.T) { t.Parallel() rr := httptest.NewRecorder() - meta.NewRouter(tt.checkers).ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/readyz", nil)) + meta.New(tt.checkers).ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/readyz", nil)) assert.Equal(t, tt.expectedStatus, rr.Code) assert.Equal(t, "application/json", rr.Header().Get("Content-Type")) diff --git a/internal/router/rest/config.go b/internal/router/rest/config.go new file mode 100644 index 0000000..0f0a097 --- /dev/null +++ b/internal/router/rest/config.go @@ -0,0 +1,42 @@ +package rest + +type CORSConfig struct { + Enabled bool + AllowedOrigins []string + AllowedMethods []string + AllowCredentials bool + MaxAge int +} + +type SwaggerConfig struct { + Enabled bool + Host string + Schemes []string +} + +type Option func(cfg *config) + +func WithCORSConfig(cors CORSConfig) Option { + return func(cfg *config) { + cfg.cors = cors + } +} + +func WithSwaggerConfig(swagger SwaggerConfig) Option { + return func(cfg *config) { + cfg.swagger = swagger + } +} + +type config struct { + cors CORSConfig + swagger SwaggerConfig +} + +func newConfig(opts ...Option) *config { + cfg := &config{} + for _, opt := range opts { + opt(cfg) + } + return cfg +} diff --git a/internal/router/rest/rest.go b/internal/router/rest/rest.go index 16ac750..2cac02a 100644 --- a/internal/router/rest/rest.go +++ b/internal/router/rest/rest.go @@ -15,43 +15,28 @@ import ( //go:generate counterfeiter -generate -type CORSConfig struct { - Enabled bool - AllowedOrigins []string - AllowedMethods []string - AllowCredentials bool - MaxAge int -} +func New( + apiKeyVerifier APIKeyVerifier, + sessionSvc SessionService, + opts ...Option, +) *chi.Mux { + cfg := newConfig(opts...) -type SwaggerConfig struct { - Enabled bool - Host string - Schemes []string -} - -type RouterConfig struct { - APIKeyVerifier APIKeyVerifier - SessionService SessionService - CORS CORSConfig - Swagger SwaggerConfig -} - -func NewRouter(cfg RouterConfig) *chi.Mux { // handlers uh := userHandler{} sh := sessionHandler{ - svc: cfg.SessionService, + svc: sessionSvc, } router := chi.NewRouter() - if cfg.CORS.Enabled { + if cfg.cors.Enabled { router.Use(cors.Handler(cors.Options{ - AllowedOrigins: cfg.CORS.AllowedOrigins, - AllowCredentials: cfg.CORS.AllowCredentials, - AllowedMethods: cfg.CORS.AllowedMethods, + AllowedOrigins: cfg.cors.AllowedOrigins, + AllowCredentials: cfg.cors.AllowCredentials, + AllowedMethods: cfg.cors.AllowedMethods, AllowedHeaders: []string{"Origin", "Content-Length", "Content-Type", "X-API-Key"}, - MaxAge: cfg.CORS.MaxAge, + MaxAge: cfg.cors.MaxAge, })) } @@ -59,17 +44,17 @@ func NewRouter(cfg RouterConfig) *chi.Mux { router.MethodNotAllowed(methodNotAllowedHandler) router.Route("/v1", func(r chi.Router) { - if cfg.Swagger.Enabled { - if cfg.Swagger.Host != "" { - docs.SwaggerInfo.Host = cfg.Swagger.Host + if cfg.swagger.Enabled { + if cfg.swagger.Host != "" { + docs.SwaggerInfo.Host = cfg.swagger.Host } - if len(cfg.Swagger.Schemes) > 0 { - docs.SwaggerInfo.Schemes = cfg.Swagger.Schemes + if len(cfg.swagger.Schemes) > 0 { + docs.SwaggerInfo.Schemes = cfg.swagger.Schemes } r.Get("/swagger/*", httpSwagger.Handler()) } - authMw := authMiddleware(cfg.APIKeyVerifier) + authMw := authMiddleware(apiKeyVerifier) r.With(authMw).Get("/user", uh.getCurrent) r.With(authMw).Put("/user/sessions/{serverKey}", sh.createOrUpdate) diff --git a/internal/router/rest/rest_test.go b/internal/router/rest/rest_test.go index a5ce49c..e7b5a28 100644 --- a/internal/router/rest/rest_test.go +++ b/internal/router/rest/rest_test.go @@ -21,7 +21,7 @@ import ( func TestRouteNotFound(t *testing.T) { t.Parallel() - resp := doRequest(rest.NewRouter(rest.RouterConfig{}), http.MethodGet, "/v1/"+uuid.NewString(), "", nil) + resp := doRequest(newRouter(), http.MethodGet, "/v1/"+uuid.NewString(), "", nil) defer resp.Body.Close() assertJSONResponse(t, resp, http.StatusNotFound, &model.ErrorResp{ Error: model.APIError{ @@ -34,11 +34,10 @@ func TestRouteNotFound(t *testing.T) { func TestMethodNotAllowed(t *testing.T) { t.Parallel() - resp := doRequest(rest.NewRouter(rest.RouterConfig{ - Swagger: rest.SwaggerConfig{ - Enabled: true, - }, - }), http.MethodPost, "/v1/swagger/index.html", "", nil) + router := newRouter(withOption(rest.WithSwaggerConfig(rest.SwaggerConfig{ + Enabled: true, + }))) + resp := doRequest(router, http.MethodPost, "/v1/swagger/index.html", "", nil) defer resp.Body.Close() assertJSONResponse(t, resp, http.StatusMethodNotAllowed, &model.ErrorResp{ Error: model.APIError{ @@ -127,16 +126,14 @@ func TestCORS(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - r := rest.NewRouter(rest.RouterConfig{ - CORS: tt.cfg, - }) + router := newRouter(withOption(rest.WithCORSConfig(tt.cfg))) req := httptest.NewRequest(http.MethodOptions, "/url", nil) for k, v := range tt.reqHeaders { req.Header.Set(k, v) } - resp := doCustomRequest(r, req) + resp := doCustomRequest(router, req) defer resp.Body.Close() assert.Equal(t, tt.expectedStatus, resp.StatusCode) assert.Len(t, resp.Header, len(tt.resHeaders)) @@ -166,11 +163,9 @@ func TestSwagger(t *testing.T) { expectedContentType: "application/json; charset=utf-8", }, } - router := rest.NewRouter(rest.RouterConfig{ - Swagger: rest.SwaggerConfig{ - Enabled: true, - }, - }) + router := newRouter(withOption(rest.WithSwaggerConfig(rest.SwaggerConfig{ + Enabled: true, + }))) for _, tt := range tests { tt := tt @@ -186,6 +181,40 @@ func TestSwagger(t *testing.T) { } } +type routerConfig struct { + apiKeyVerifier rest.APIKeyVerifier + sessionSvc rest.SessionService + opts []rest.Option +} + +type routerOption func(cfg *routerConfig) + +func withSessionService(svc rest.SessionService) routerOption { + return func(cfg *routerConfig) { + cfg.sessionSvc = svc + } +} + +func withAPIKeyVerifier(svc rest.APIKeyVerifier) routerOption { + return func(cfg *routerConfig) { + cfg.apiKeyVerifier = svc + } +} + +func withOption(opt rest.Option) routerOption { + return func(cfg *routerConfig) { + cfg.opts = append(cfg.opts, opt) + } +} + +func newRouter(opts ...routerOption) chi.Router { + cfg := &routerConfig{} + for _, opt := range opts { + opt(cfg) + } + return rest.New(cfg.apiKeyVerifier, cfg.sessionSvc, cfg.opts...) +} + func doRequest(mux chi.Router, method, target, apiKey string, body io.Reader) *http.Response { req := httptest.NewRequest(method, target, body) if apiKey != "" { diff --git a/internal/router/rest/session_test.go b/internal/router/rest/session_test.go index 7a35989..ef6ac19 100644 --- a/internal/router/rest/session_test.go +++ b/internal/router/rest/session_test.go @@ -9,7 +9,6 @@ import ( "time" "gitea.dwysokinski.me/twhelp/sessions/internal/domain" - "gitea.dwysokinski.me/twhelp/sessions/internal/router/rest" "gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/mock" "gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/model" "github.com/google/uuid" @@ -163,10 +162,7 @@ func TestSession_createOrUpdate(t *testing.T) { sessionSvc := &mock.FakeSessionService{} tt.setup(apiKeySvc, sessionSvc) - router := rest.NewRouter(rest.RouterConfig{ - APIKeyVerifier: apiKeySvc, - SessionService: sessionSvc, - }) + router := newRouter(withAPIKeyVerifier(apiKeySvc), withSessionService(sessionSvc)) resp := doRequest( router, @@ -309,10 +305,7 @@ func TestSession_getCurrentUser(t *testing.T) { sessionSvc := &mock.FakeSessionService{} tt.setup(apiKeySvc, sessionSvc) - router := rest.NewRouter(rest.RouterConfig{ - APIKeyVerifier: apiKeySvc, - SessionService: sessionSvc, - }) + router := newRouter(withAPIKeyVerifier(apiKeySvc), withSessionService(sessionSvc)) resp := doRequest(router, http.MethodGet, "/v1/user/sessions/"+tt.serverKey, tt.apiKey, nil) defer resp.Body.Close() diff --git a/internal/router/rest/user_test.go b/internal/router/rest/user_test.go index 9407b93..d50a0e9 100644 --- a/internal/router/rest/user_test.go +++ b/internal/router/rest/user_test.go @@ -7,7 +7,6 @@ import ( "time" "gitea.dwysokinski.me/twhelp/sessions/internal/domain" - "gitea.dwysokinski.me/twhelp/sessions/internal/router/rest" "gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/mock" "gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/model" "github.com/google/uuid" @@ -105,9 +104,7 @@ func TestUser_getAuthenticated(t *testing.T) { apiKeySvc := &mock.FakeAPIKeyVerifier{} tt.setup(apiKeySvc) - router := rest.NewRouter(rest.RouterConfig{ - APIKeyVerifier: apiKeySvc, - }) + router := newRouter(withAPIKeyVerifier(apiKeySvc)) resp := doRequest(router, http.MethodGet, "/v1/user", tt.apiKey, nil) defer resp.Body.Close()