refactor: rest.New - introduce required/optional args
continuous-integration/drone/push Build is passing Details

This commit is contained in:
Dawid Wysokiński 2023-01-27 07:13:55 +01:00
parent 02f5e24c6f
commit a6cb80ab75
Signed by: Kichiyaki
GPG Key ID: B5445E357FB8B892
8 changed files with 119 additions and 73 deletions

View File

@ -90,23 +90,23 @@ func newServer(logger *zap.Logger, db *bun.DB) (*http.Server, error) {
// router // router
r := chi.NewRouter() r := chi.NewRouter()
r.Use(getMiddlewares(logger)...) r.Use(getMiddlewares(logger)...)
r.Mount(metaEndpointsPrefix, meta.NewRouter([]meta.Checker{bundb.NewChecker(db)})) r.Mount(metaEndpointsPrefix, meta.New([]meta.Checker{bundb.NewChecker(db)}))
r.Mount("/api", rest.NewRouter(rest.RouterConfig{ r.Mount("/api", rest.New(
APIKeyVerifier: apiKeySvc, apiKeySvc,
SessionService: sessionSvc, sessionSvc,
CORS: rest.CORSConfig{ rest.WithCORSConfig(rest.CORSConfig{
Enabled: apiCfg.CORSEnabled, Enabled: apiCfg.CORSEnabled,
AllowedOrigins: apiCfg.CORSAllowedOrigins, AllowedOrigins: apiCfg.CORSAllowedOrigins,
AllowedMethods: apiCfg.CORSAllowedMethods, AllowedMethods: apiCfg.CORSAllowedMethods,
AllowCredentials: apiCfg.CORSAllowCredentials, AllowCredentials: apiCfg.CORSAllowCredentials,
MaxAge: int(apiCfg.CORSMaxAge), MaxAge: int(apiCfg.CORSMaxAge),
}, }),
Swagger: rest.SwaggerConfig{ rest.WithSwaggerConfig(rest.SwaggerConfig{
Enabled: apiCfg.SwaggerEnabled, Enabled: apiCfg.SwaggerEnabled,
Host: apiCfg.SwaggerHost, Host: apiCfg.SwaggerHost,
Schemes: apiCfg.SwaggerSchemes, Schemes: apiCfg.SwaggerSchemes,
}, }),
})) ))
return &http.Server{ return &http.Server{
Addr: ":" + defaultPort, Addr: ":" + defaultPort,

View File

@ -21,7 +21,7 @@ type Checker interface {
Check(ctx context.Context) error Check(ctx context.Context) error
} }
func NewRouter(checkers []Checker) *chi.Mux { func New(checkers []Checker) *chi.Mux {
router := chi.NewRouter() router := chi.NewRouter()
router.Get("/livez", func(w http.ResponseWriter, r *http.Request) { router.Get("/livez", func(w http.ResponseWriter, r *http.Request) {

View File

@ -20,7 +20,7 @@ func TestLivez(t *testing.T) {
t.Parallel() t.Parallel()
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
meta.NewRouter(nil).ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/livez", nil)) meta.New(nil).ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/livez", nil))
assert.Equal(t, http.StatusNoContent, rr.Code) assert.Equal(t, http.StatusNoContent, rr.Code)
assert.Equal(t, 0, rr.Body.Len()) assert.Equal(t, 0, rr.Body.Len())
@ -99,7 +99,7 @@ func TestReadyz(t *testing.T) {
t.Parallel() t.Parallel()
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
meta.NewRouter(tt.checkers).ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/readyz", nil)) meta.New(tt.checkers).ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/readyz", nil))
assert.Equal(t, tt.expectedStatus, rr.Code) assert.Equal(t, tt.expectedStatus, rr.Code)
assert.Equal(t, "application/json", rr.Header().Get("Content-Type")) assert.Equal(t, "application/json", rr.Header().Get("Content-Type"))

View File

@ -0,0 +1,42 @@
package rest
type CORSConfig struct {
Enabled bool
AllowedOrigins []string
AllowedMethods []string
AllowCredentials bool
MaxAge int
}
type SwaggerConfig struct {
Enabled bool
Host string
Schemes []string
}
type Option func(cfg *config)
func WithCORSConfig(cors CORSConfig) Option {
return func(cfg *config) {
cfg.cors = cors
}
}
func WithSwaggerConfig(swagger SwaggerConfig) Option {
return func(cfg *config) {
cfg.swagger = swagger
}
}
type config struct {
cors CORSConfig
swagger SwaggerConfig
}
func newConfig(opts ...Option) *config {
cfg := &config{}
for _, opt := range opts {
opt(cfg)
}
return cfg
}

View File

@ -15,43 +15,28 @@ import (
//go:generate counterfeiter -generate //go:generate counterfeiter -generate
type CORSConfig struct { func New(
Enabled bool apiKeyVerifier APIKeyVerifier,
AllowedOrigins []string sessionSvc SessionService,
AllowedMethods []string opts ...Option,
AllowCredentials bool ) *chi.Mux {
MaxAge int cfg := newConfig(opts...)
}
type SwaggerConfig struct {
Enabled bool
Host string
Schemes []string
}
type RouterConfig struct {
APIKeyVerifier APIKeyVerifier
SessionService SessionService
CORS CORSConfig
Swagger SwaggerConfig
}
func NewRouter(cfg RouterConfig) *chi.Mux {
// handlers // handlers
uh := userHandler{} uh := userHandler{}
sh := sessionHandler{ sh := sessionHandler{
svc: cfg.SessionService, svc: sessionSvc,
} }
router := chi.NewRouter() router := chi.NewRouter()
if cfg.CORS.Enabled { if cfg.cors.Enabled {
router.Use(cors.Handler(cors.Options{ router.Use(cors.Handler(cors.Options{
AllowedOrigins: cfg.CORS.AllowedOrigins, AllowedOrigins: cfg.cors.AllowedOrigins,
AllowCredentials: cfg.CORS.AllowCredentials, AllowCredentials: cfg.cors.AllowCredentials,
AllowedMethods: cfg.CORS.AllowedMethods, AllowedMethods: cfg.cors.AllowedMethods,
AllowedHeaders: []string{"Origin", "Content-Length", "Content-Type", "X-API-Key"}, AllowedHeaders: []string{"Origin", "Content-Length", "Content-Type", "X-API-Key"},
MaxAge: cfg.CORS.MaxAge, MaxAge: cfg.cors.MaxAge,
})) }))
} }
@ -59,17 +44,17 @@ func NewRouter(cfg RouterConfig) *chi.Mux {
router.MethodNotAllowed(methodNotAllowedHandler) router.MethodNotAllowed(methodNotAllowedHandler)
router.Route("/v1", func(r chi.Router) { router.Route("/v1", func(r chi.Router) {
if cfg.Swagger.Enabled { if cfg.swagger.Enabled {
if cfg.Swagger.Host != "" { if cfg.swagger.Host != "" {
docs.SwaggerInfo.Host = cfg.Swagger.Host docs.SwaggerInfo.Host = cfg.swagger.Host
} }
if len(cfg.Swagger.Schemes) > 0 { if len(cfg.swagger.Schemes) > 0 {
docs.SwaggerInfo.Schemes = cfg.Swagger.Schemes docs.SwaggerInfo.Schemes = cfg.swagger.Schemes
} }
r.Get("/swagger/*", httpSwagger.Handler()) r.Get("/swagger/*", httpSwagger.Handler())
} }
authMw := authMiddleware(cfg.APIKeyVerifier) authMw := authMiddleware(apiKeyVerifier)
r.With(authMw).Get("/user", uh.getCurrent) r.With(authMw).Get("/user", uh.getCurrent)
r.With(authMw).Put("/user/sessions/{serverKey}", sh.createOrUpdate) r.With(authMw).Put("/user/sessions/{serverKey}", sh.createOrUpdate)

View File

@ -21,7 +21,7 @@ import (
func TestRouteNotFound(t *testing.T) { func TestRouteNotFound(t *testing.T) {
t.Parallel() t.Parallel()
resp := doRequest(rest.NewRouter(rest.RouterConfig{}), http.MethodGet, "/v1/"+uuid.NewString(), "", nil) resp := doRequest(newRouter(), http.MethodGet, "/v1/"+uuid.NewString(), "", nil)
defer resp.Body.Close() defer resp.Body.Close()
assertJSONResponse(t, resp, http.StatusNotFound, &model.ErrorResp{ assertJSONResponse(t, resp, http.StatusNotFound, &model.ErrorResp{
Error: model.APIError{ Error: model.APIError{
@ -34,11 +34,10 @@ func TestRouteNotFound(t *testing.T) {
func TestMethodNotAllowed(t *testing.T) { func TestMethodNotAllowed(t *testing.T) {
t.Parallel() t.Parallel()
resp := doRequest(rest.NewRouter(rest.RouterConfig{ router := newRouter(withOption(rest.WithSwaggerConfig(rest.SwaggerConfig{
Swagger: rest.SwaggerConfig{ Enabled: true,
Enabled: true, })))
}, resp := doRequest(router, http.MethodPost, "/v1/swagger/index.html", "", nil)
}), http.MethodPost, "/v1/swagger/index.html", "", nil)
defer resp.Body.Close() defer resp.Body.Close()
assertJSONResponse(t, resp, http.StatusMethodNotAllowed, &model.ErrorResp{ assertJSONResponse(t, resp, http.StatusMethodNotAllowed, &model.ErrorResp{
Error: model.APIError{ Error: model.APIError{
@ -127,16 +126,14 @@ func TestCORS(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
r := rest.NewRouter(rest.RouterConfig{ router := newRouter(withOption(rest.WithCORSConfig(tt.cfg)))
CORS: tt.cfg,
})
req := httptest.NewRequest(http.MethodOptions, "/url", nil) req := httptest.NewRequest(http.MethodOptions, "/url", nil)
for k, v := range tt.reqHeaders { for k, v := range tt.reqHeaders {
req.Header.Set(k, v) req.Header.Set(k, v)
} }
resp := doCustomRequest(r, req) resp := doCustomRequest(router, req)
defer resp.Body.Close() defer resp.Body.Close()
assert.Equal(t, tt.expectedStatus, resp.StatusCode) assert.Equal(t, tt.expectedStatus, resp.StatusCode)
assert.Len(t, resp.Header, len(tt.resHeaders)) assert.Len(t, resp.Header, len(tt.resHeaders))
@ -166,11 +163,9 @@ func TestSwagger(t *testing.T) {
expectedContentType: "application/json; charset=utf-8", expectedContentType: "application/json; charset=utf-8",
}, },
} }
router := rest.NewRouter(rest.RouterConfig{ router := newRouter(withOption(rest.WithSwaggerConfig(rest.SwaggerConfig{
Swagger: rest.SwaggerConfig{ Enabled: true,
Enabled: true, })))
},
})
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
@ -186,6 +181,40 @@ func TestSwagger(t *testing.T) {
} }
} }
type routerConfig struct {
apiKeyVerifier rest.APIKeyVerifier
sessionSvc rest.SessionService
opts []rest.Option
}
type routerOption func(cfg *routerConfig)
func withSessionService(svc rest.SessionService) routerOption {
return func(cfg *routerConfig) {
cfg.sessionSvc = svc
}
}
func withAPIKeyVerifier(svc rest.APIKeyVerifier) routerOption {
return func(cfg *routerConfig) {
cfg.apiKeyVerifier = svc
}
}
func withOption(opt rest.Option) routerOption {
return func(cfg *routerConfig) {
cfg.opts = append(cfg.opts, opt)
}
}
func newRouter(opts ...routerOption) chi.Router {
cfg := &routerConfig{}
for _, opt := range opts {
opt(cfg)
}
return rest.New(cfg.apiKeyVerifier, cfg.sessionSvc, cfg.opts...)
}
func doRequest(mux chi.Router, method, target, apiKey string, body io.Reader) *http.Response { func doRequest(mux chi.Router, method, target, apiKey string, body io.Reader) *http.Response {
req := httptest.NewRequest(method, target, body) req := httptest.NewRequest(method, target, body)
if apiKey != "" { if apiKey != "" {

View File

@ -9,7 +9,6 @@ import (
"time" "time"
"gitea.dwysokinski.me/twhelp/sessions/internal/domain" "gitea.dwysokinski.me/twhelp/sessions/internal/domain"
"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest"
"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/mock" "gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/mock"
"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/model" "gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/model"
"github.com/google/uuid" "github.com/google/uuid"
@ -163,10 +162,7 @@ func TestSession_createOrUpdate(t *testing.T) {
sessionSvc := &mock.FakeSessionService{} sessionSvc := &mock.FakeSessionService{}
tt.setup(apiKeySvc, sessionSvc) tt.setup(apiKeySvc, sessionSvc)
router := rest.NewRouter(rest.RouterConfig{ router := newRouter(withAPIKeyVerifier(apiKeySvc), withSessionService(sessionSvc))
APIKeyVerifier: apiKeySvc,
SessionService: sessionSvc,
})
resp := doRequest( resp := doRequest(
router, router,
@ -309,10 +305,7 @@ func TestSession_getCurrentUser(t *testing.T) {
sessionSvc := &mock.FakeSessionService{} sessionSvc := &mock.FakeSessionService{}
tt.setup(apiKeySvc, sessionSvc) tt.setup(apiKeySvc, sessionSvc)
router := rest.NewRouter(rest.RouterConfig{ router := newRouter(withAPIKeyVerifier(apiKeySvc), withSessionService(sessionSvc))
APIKeyVerifier: apiKeySvc,
SessionService: sessionSvc,
})
resp := doRequest(router, http.MethodGet, "/v1/user/sessions/"+tt.serverKey, tt.apiKey, nil) resp := doRequest(router, http.MethodGet, "/v1/user/sessions/"+tt.serverKey, tt.apiKey, nil)
defer resp.Body.Close() defer resp.Body.Close()

View File

@ -7,7 +7,6 @@ import (
"time" "time"
"gitea.dwysokinski.me/twhelp/sessions/internal/domain" "gitea.dwysokinski.me/twhelp/sessions/internal/domain"
"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest"
"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/mock" "gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/mock"
"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/model" "gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/model"
"github.com/google/uuid" "github.com/google/uuid"
@ -105,9 +104,7 @@ func TestUser_getAuthenticated(t *testing.T) {
apiKeySvc := &mock.FakeAPIKeyVerifier{} apiKeySvc := &mock.FakeAPIKeyVerifier{}
tt.setup(apiKeySvc) tt.setup(apiKeySvc)
router := rest.NewRouter(rest.RouterConfig{ router := newRouter(withAPIKeyVerifier(apiKeySvc))
APIKeyVerifier: apiKeySvc,
})
resp := doRequest(router, http.MethodGet, "/v1/user", tt.apiKey, nil) resp := doRequest(router, http.MethodGet, "/v1/user", tt.apiKey, nil)
defer resp.Body.Close() defer resp.Body.Close()