package rest

import (
	"encoding/json"
	"errors"
	"net/http"

	"gitea.dwysokinski.me/twhelp/sessions/internal/domain"
	"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/docs"
	"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/model"
	"github.com/go-chi/chi/v5"
	"github.com/go-chi/cors"
	httpSwagger "github.com/swaggo/http-swagger"
)

//go:generate counterfeiter -generate

// CORSConfig holds the CORS settings that NewRouter forwards to cors.Options.
// The CORS middleware is only installed when Enabled is true.
type CORSConfig struct {
	Enabled          bool
	AllowedOrigins   []string
	AllowedMethods   []string
	AllowCredentials bool
	MaxAge           int
}

// SwaggerConfig controls whether the Swagger UI route is registered and lets
// the caller override the host and schemes advertised by the generated docs.
type SwaggerConfig struct {
	Enabled bool
	// Host replaces docs.SwaggerInfo.Host when non-empty.
	Host string
	// Schemes replaces docs.SwaggerInfo.Schemes when it has at least one entry.
	Schemes []string
}

// RouterConfig bundles the dependencies and settings consumed by NewRouter.
type RouterConfig struct {
	// APIKeyVerifier is handed to the auth middleware guarding the /user routes.
	APIKeyVerifier APIKeyVerifier
	// SessionService backs the session handlers.
	SessionService SessionService
	CORS           CORSConfig
	Swagger        SwaggerConfig
}

// NewRouter builds the chi router for the REST API.
//
// It optionally installs the CORS middleware, registers JSON handlers for
// unknown routes and disallowed methods, and mounts the API under /v1:
// the Swagger UI (when enabled) and the authenticated /user endpoints.
func NewRouter(cfg RouterConfig) *chi.Mux {
	// handlers
	uh := userHandler{}
	sh := sessionHandler{
		svc: cfg.SessionService,
	}

	router := chi.NewRouter()

	if cfg.CORS.Enabled {
		router.Use(cors.Handler(cors.Options{
			AllowedOrigins:   cfg.CORS.AllowedOrigins,
			AllowCredentials: cfg.CORS.AllowCredentials,
			AllowedMethods:   cfg.CORS.AllowedMethods,
			AllowedHeaders:   []string{"Origin", "Content-Length", "Content-Type", "X-API-Key"},
			MaxAge:           cfg.CORS.MaxAge,
		}))
	}

	router.NotFound(routeNotFoundHandler)
	router.MethodNotAllowed(methodNotAllowedHandler)

	router.Route("/v1", func(r chi.Router) {
		if cfg.Swagger.Enabled {
			// The generated docs keep their defaults unless an override is given.
			if cfg.Swagger.Host != "" {
				docs.SwaggerInfo.Host = cfg.Swagger.Host
			}
			if len(cfg.Swagger.Schemes) > 0 {
				docs.SwaggerInfo.Schemes = cfg.Swagger.Schemes
			}
			r.Get("/swagger/*", httpSwagger.Handler())
		}

		// Every /user route goes through the same auth middleware, so build
		// the middleware-wrapped router once and register all routes on it.
		authed := r.With(authMiddleware(cfg.APIKeyVerifier))
		authed.Get("/user", uh.getCurrent)
		authed.Put("/user/sessions/{serverKey}", sh.createOrUpdate)
		authed.Get("/user/sessions/{serverKey}", sh.getCurrentUser)
	})

	return router
}

// routeNotFoundHandler responds with a 404 JSON error for unmatched routes.
func routeNotFoundHandler(w http.ResponseWriter, _ *http.Request) {
	renderJSON(w, http.StatusNotFound, model.ErrorResp{
		Error: model.APIError{
			Code:    "route-not-found",
			Message: "route not found",
		},
	})
}

// methodNotAllowedHandler responds with a 405 JSON error when a route exists
// but not for the requested HTTP method.
func methodNotAllowedHandler(w http.ResponseWriter, _ *http.Request) {
	renderJSON(w, http.StatusMethodNotAllowed, model.ErrorResp{
		Error: model.APIError{
			Code:    "method-not-allowed",
			Message: "method not allowed",
		},
	})
}

// internalServerError adapts an arbitrary error to domain.Error so that it can
// be rendered without exposing its message to the client.
type internalServerError struct {
	err error
}

// Error returns the message of the wrapped error.
func (e internalServerError) Error() string {
	return e.err.Error()
}

// UserError returns the generic message that is safe to show to clients.
func (e internalServerError) UserError() string {
	return "internal server error"
}

// Code always reports domain.ErrorCodeUnknown.
func (e internalServerError) Code() domain.ErrorCode {
	return domain.ErrorCodeUnknown
}

// Unwrap exposes the wrapped error to errors.Is / errors.As.
func (e internalServerError) Unwrap() error {
	return e.err
}

// renderErr writes err as a JSON error response. Errors that do not carry a
// domain.Error in their chain are reported as a generic internal server error.
func renderErr(w http.ResponseWriter, err error) {
	var domainErr domain.Error
	if !errors.As(err, &domainErr) {
		domainErr = internalServerError{err: err}
	}
	renderJSON(w, errorCodeToHTTPStatus(domainErr.Code()), model.ErrorResp{
		Error: model.APIError{
			Code:    domainErr.Code().String(),
			Message: domainErr.UserError(),
		},
	})
}

// errorCodeToHTTPStatus maps a domain error code to the HTTP status used in
// the response. Unknown and unrecognized codes map to 500.
func errorCodeToHTTPStatus(code domain.ErrorCode) int {
	switch code {
	case domain.ErrorCodeValidationError:
		return http.StatusBadRequest
	case domain.ErrorCodeEntityNotFound:
		return http.StatusNotFound
	case domain.ErrorCodeAlreadyExists:
		return http.StatusConflict
	case domain.ErrorCodeUnknown:
		fallthrough
	default:
		return http.StatusInternalServerError
	}
}

// renderJSON writes data as a JSON response with the given status code.
//
// The payload is marshalled before the header is sent: once WriteHeader has
// been called the status can no longer be changed, so encoding afterwards
// would turn a marshalling failure into the requested status with an empty
// body. If data cannot be marshalled, a 500 internal-server-error payload is
// sent instead.
func renderJSON(w http.ResponseWriter, status int, data any) {
	payload, err := json.Marshal(data)
	if err != nil {
		fallback := internalServerError{err: err}
		status = http.StatusInternalServerError
		// The fallback consists of plain strings only; should marshalling it
		// fail as well, the client still receives the 500 status.
		payload, _ = json.Marshal(model.ErrorResp{
			Error: model.APIError{
				Code:    fallback.Code().String(),
				Message: fallback.UserError(),
			},
		})
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// json.Encoder.Encode terminates its output with a newline; keep the
	// response bytes identical to that form.
	// A failed write means the client went away; nothing more can be reported
	// on this connection, so the error is intentionally ignored.
	_, _ = w.Write(append(payload, '\n'))
}