refactor: rest.New - introduce required/optional args
continuous-integration/drone/push Build is passing
Details
continuous-integration/drone/push Build is passing
Details
This commit is contained in:
parent
02f5e24c6f
commit
a6cb80ab75
|
@ -90,23 +90,23 @@ func newServer(logger *zap.Logger, db *bun.DB) (*http.Server, error) {
|
|||
// router
|
||||
r := chi.NewRouter()
|
||||
r.Use(getMiddlewares(logger)...)
|
||||
r.Mount(metaEndpointsPrefix, meta.NewRouter([]meta.Checker{bundb.NewChecker(db)}))
|
||||
r.Mount("/api", rest.NewRouter(rest.RouterConfig{
|
||||
APIKeyVerifier: apiKeySvc,
|
||||
SessionService: sessionSvc,
|
||||
CORS: rest.CORSConfig{
|
||||
r.Mount(metaEndpointsPrefix, meta.New([]meta.Checker{bundb.NewChecker(db)}))
|
||||
r.Mount("/api", rest.New(
|
||||
apiKeySvc,
|
||||
sessionSvc,
|
||||
rest.WithCORSConfig(rest.CORSConfig{
|
||||
Enabled: apiCfg.CORSEnabled,
|
||||
AllowedOrigins: apiCfg.CORSAllowedOrigins,
|
||||
AllowedMethods: apiCfg.CORSAllowedMethods,
|
||||
AllowCredentials: apiCfg.CORSAllowCredentials,
|
||||
MaxAge: int(apiCfg.CORSMaxAge),
|
||||
},
|
||||
Swagger: rest.SwaggerConfig{
|
||||
}),
|
||||
rest.WithSwaggerConfig(rest.SwaggerConfig{
|
||||
Enabled: apiCfg.SwaggerEnabled,
|
||||
Host: apiCfg.SwaggerHost,
|
||||
Schemes: apiCfg.SwaggerSchemes,
|
||||
},
|
||||
}))
|
||||
}),
|
||||
))
|
||||
|
||||
return &http.Server{
|
||||
Addr: ":" + defaultPort,
|
||||
|
|
|
@ -21,7 +21,7 @@ type Checker interface {
|
|||
Check(ctx context.Context) error
|
||||
}
|
||||
|
||||
func NewRouter(checkers []Checker) *chi.Mux {
|
||||
func New(checkers []Checker) *chi.Mux {
|
||||
router := chi.NewRouter()
|
||||
|
||||
router.Get("/livez", func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
|
@ -20,7 +20,7 @@ func TestLivez(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
meta.NewRouter(nil).ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/livez", nil))
|
||||
meta.New(nil).ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/livez", nil))
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, rr.Code)
|
||||
assert.Equal(t, 0, rr.Body.Len())
|
||||
|
@ -99,7 +99,7 @@ func TestReadyz(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
meta.NewRouter(tt.checkers).ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/readyz", nil))
|
||||
meta.New(tt.checkers).ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/readyz", nil))
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, rr.Code)
|
||||
assert.Equal(t, "application/json", rr.Header().Get("Content-Type"))
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
package rest
|
||||
|
||||
type CORSConfig struct {
|
||||
Enabled bool
|
||||
AllowedOrigins []string
|
||||
AllowedMethods []string
|
||||
AllowCredentials bool
|
||||
MaxAge int
|
||||
}
|
||||
|
||||
type SwaggerConfig struct {
|
||||
Enabled bool
|
||||
Host string
|
||||
Schemes []string
|
||||
}
|
||||
|
||||
type Option func(cfg *config)
|
||||
|
||||
func WithCORSConfig(cors CORSConfig) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.cors = cors
|
||||
}
|
||||
}
|
||||
|
||||
func WithSwaggerConfig(swagger SwaggerConfig) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.swagger = swagger
|
||||
}
|
||||
}
|
||||
|
||||
type config struct {
|
||||
cors CORSConfig
|
||||
swagger SwaggerConfig
|
||||
}
|
||||
|
||||
func newConfig(opts ...Option) *config {
|
||||
cfg := &config{}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
|
@ -15,43 +15,28 @@ import (
|
|||
|
||||
//go:generate counterfeiter -generate
|
||||
|
||||
type CORSConfig struct {
|
||||
Enabled bool
|
||||
AllowedOrigins []string
|
||||
AllowedMethods []string
|
||||
AllowCredentials bool
|
||||
MaxAge int
|
||||
}
|
||||
func New(
|
||||
apiKeyVerifier APIKeyVerifier,
|
||||
sessionSvc SessionService,
|
||||
opts ...Option,
|
||||
) *chi.Mux {
|
||||
cfg := newConfig(opts...)
|
||||
|
||||
type SwaggerConfig struct {
|
||||
Enabled bool
|
||||
Host string
|
||||
Schemes []string
|
||||
}
|
||||
|
||||
type RouterConfig struct {
|
||||
APIKeyVerifier APIKeyVerifier
|
||||
SessionService SessionService
|
||||
CORS CORSConfig
|
||||
Swagger SwaggerConfig
|
||||
}
|
||||
|
||||
func NewRouter(cfg RouterConfig) *chi.Mux {
|
||||
// handlers
|
||||
uh := userHandler{}
|
||||
sh := sessionHandler{
|
||||
svc: cfg.SessionService,
|
||||
svc: sessionSvc,
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
|
||||
if cfg.CORS.Enabled {
|
||||
if cfg.cors.Enabled {
|
||||
router.Use(cors.Handler(cors.Options{
|
||||
AllowedOrigins: cfg.CORS.AllowedOrigins,
|
||||
AllowCredentials: cfg.CORS.AllowCredentials,
|
||||
AllowedMethods: cfg.CORS.AllowedMethods,
|
||||
AllowedOrigins: cfg.cors.AllowedOrigins,
|
||||
AllowCredentials: cfg.cors.AllowCredentials,
|
||||
AllowedMethods: cfg.cors.AllowedMethods,
|
||||
AllowedHeaders: []string{"Origin", "Content-Length", "Content-Type", "X-API-Key"},
|
||||
MaxAge: cfg.CORS.MaxAge,
|
||||
MaxAge: cfg.cors.MaxAge,
|
||||
}))
|
||||
}
|
||||
|
||||
|
@ -59,17 +44,17 @@ func NewRouter(cfg RouterConfig) *chi.Mux {
|
|||
router.MethodNotAllowed(methodNotAllowedHandler)
|
||||
|
||||
router.Route("/v1", func(r chi.Router) {
|
||||
if cfg.Swagger.Enabled {
|
||||
if cfg.Swagger.Host != "" {
|
||||
docs.SwaggerInfo.Host = cfg.Swagger.Host
|
||||
if cfg.swagger.Enabled {
|
||||
if cfg.swagger.Host != "" {
|
||||
docs.SwaggerInfo.Host = cfg.swagger.Host
|
||||
}
|
||||
if len(cfg.Swagger.Schemes) > 0 {
|
||||
docs.SwaggerInfo.Schemes = cfg.Swagger.Schemes
|
||||
if len(cfg.swagger.Schemes) > 0 {
|
||||
docs.SwaggerInfo.Schemes = cfg.swagger.Schemes
|
||||
}
|
||||
r.Get("/swagger/*", httpSwagger.Handler())
|
||||
}
|
||||
|
||||
authMw := authMiddleware(cfg.APIKeyVerifier)
|
||||
authMw := authMiddleware(apiKeyVerifier)
|
||||
|
||||
r.With(authMw).Get("/user", uh.getCurrent)
|
||||
r.With(authMw).Put("/user/sessions/{serverKey}", sh.createOrUpdate)
|
||||
|
|
|
@ -21,7 +21,7 @@ import (
|
|||
func TestRouteNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
resp := doRequest(rest.NewRouter(rest.RouterConfig{}), http.MethodGet, "/v1/"+uuid.NewString(), "", nil)
|
||||
resp := doRequest(newRouter(), http.MethodGet, "/v1/"+uuid.NewString(), "", nil)
|
||||
defer resp.Body.Close()
|
||||
assertJSONResponse(t, resp, http.StatusNotFound, &model.ErrorResp{
|
||||
Error: model.APIError{
|
||||
|
@ -34,11 +34,10 @@ func TestRouteNotFound(t *testing.T) {
|
|||
func TestMethodNotAllowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
resp := doRequest(rest.NewRouter(rest.RouterConfig{
|
||||
Swagger: rest.SwaggerConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}), http.MethodPost, "/v1/swagger/index.html", "", nil)
|
||||
router := newRouter(withOption(rest.WithSwaggerConfig(rest.SwaggerConfig{
|
||||
Enabled: true,
|
||||
})))
|
||||
resp := doRequest(router, http.MethodPost, "/v1/swagger/index.html", "", nil)
|
||||
defer resp.Body.Close()
|
||||
assertJSONResponse(t, resp, http.StatusMethodNotAllowed, &model.ErrorResp{
|
||||
Error: model.APIError{
|
||||
|
@ -127,16 +126,14 @@ func TestCORS(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := rest.NewRouter(rest.RouterConfig{
|
||||
CORS: tt.cfg,
|
||||
})
|
||||
router := newRouter(withOption(rest.WithCORSConfig(tt.cfg)))
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/url", nil)
|
||||
for k, v := range tt.reqHeaders {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp := doCustomRequest(r, req)
|
||||
resp := doCustomRequest(router, req)
|
||||
defer resp.Body.Close()
|
||||
assert.Equal(t, tt.expectedStatus, resp.StatusCode)
|
||||
assert.Len(t, resp.Header, len(tt.resHeaders))
|
||||
|
@ -166,11 +163,9 @@ func TestSwagger(t *testing.T) {
|
|||
expectedContentType: "application/json; charset=utf-8",
|
||||
},
|
||||
}
|
||||
router := rest.NewRouter(rest.RouterConfig{
|
||||
Swagger: rest.SwaggerConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
})
|
||||
router := newRouter(withOption(rest.WithSwaggerConfig(rest.SwaggerConfig{
|
||||
Enabled: true,
|
||||
})))
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
@ -186,6 +181,40 @@ func TestSwagger(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type routerConfig struct {
|
||||
apiKeyVerifier rest.APIKeyVerifier
|
||||
sessionSvc rest.SessionService
|
||||
opts []rest.Option
|
||||
}
|
||||
|
||||
type routerOption func(cfg *routerConfig)
|
||||
|
||||
func withSessionService(svc rest.SessionService) routerOption {
|
||||
return func(cfg *routerConfig) {
|
||||
cfg.sessionSvc = svc
|
||||
}
|
||||
}
|
||||
|
||||
func withAPIKeyVerifier(svc rest.APIKeyVerifier) routerOption {
|
||||
return func(cfg *routerConfig) {
|
||||
cfg.apiKeyVerifier = svc
|
||||
}
|
||||
}
|
||||
|
||||
func withOption(opt rest.Option) routerOption {
|
||||
return func(cfg *routerConfig) {
|
||||
cfg.opts = append(cfg.opts, opt)
|
||||
}
|
||||
}
|
||||
|
||||
func newRouter(opts ...routerOption) chi.Router {
|
||||
cfg := &routerConfig{}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
return rest.New(cfg.apiKeyVerifier, cfg.sessionSvc, cfg.opts...)
|
||||
}
|
||||
|
||||
func doRequest(mux chi.Router, method, target, apiKey string, body io.Reader) *http.Response {
|
||||
req := httptest.NewRequest(method, target, body)
|
||||
if apiKey != "" {
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"time"
|
||||
|
||||
"gitea.dwysokinski.me/twhelp/sessions/internal/domain"
|
||||
"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest"
|
||||
"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/mock"
|
||||
"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/model"
|
||||
"github.com/google/uuid"
|
||||
|
@ -163,10 +162,7 @@ func TestSession_createOrUpdate(t *testing.T) {
|
|||
sessionSvc := &mock.FakeSessionService{}
|
||||
tt.setup(apiKeySvc, sessionSvc)
|
||||
|
||||
router := rest.NewRouter(rest.RouterConfig{
|
||||
APIKeyVerifier: apiKeySvc,
|
||||
SessionService: sessionSvc,
|
||||
})
|
||||
router := newRouter(withAPIKeyVerifier(apiKeySvc), withSessionService(sessionSvc))
|
||||
|
||||
resp := doRequest(
|
||||
router,
|
||||
|
@ -309,10 +305,7 @@ func TestSession_getCurrentUser(t *testing.T) {
|
|||
sessionSvc := &mock.FakeSessionService{}
|
||||
tt.setup(apiKeySvc, sessionSvc)
|
||||
|
||||
router := rest.NewRouter(rest.RouterConfig{
|
||||
APIKeyVerifier: apiKeySvc,
|
||||
SessionService: sessionSvc,
|
||||
})
|
||||
router := newRouter(withAPIKeyVerifier(apiKeySvc), withSessionService(sessionSvc))
|
||||
|
||||
resp := doRequest(router, http.MethodGet, "/v1/user/sessions/"+tt.serverKey, tt.apiKey, nil)
|
||||
defer resp.Body.Close()
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"time"
|
||||
|
||||
"gitea.dwysokinski.me/twhelp/sessions/internal/domain"
|
||||
"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest"
|
||||
"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/mock"
|
||||
"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/model"
|
||||
"github.com/google/uuid"
|
||||
|
@ -105,9 +104,7 @@ func TestUser_getAuthenticated(t *testing.T) {
|
|||
apiKeySvc := &mock.FakeAPIKeyVerifier{}
|
||||
tt.setup(apiKeySvc)
|
||||
|
||||
router := rest.NewRouter(rest.RouterConfig{
|
||||
APIKeyVerifier: apiKeySvc,
|
||||
})
|
||||
router := newRouter(withAPIKeyVerifier(apiKeySvc))
|
||||
|
||||
resp := doRequest(router, http.MethodGet, "/v1/user", tt.apiKey, nil)
|
||||
defer resp.Body.Close()
|
||||
|
|
Loading…
Reference in New Issue