152 lines
3.6 KiB
Go
152 lines
3.6 KiB
Go
package rest
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"net/http"

	"gitea.dwysokinski.me/twhelp/sessions/internal/domain"
	"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/docs"
	"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/model"
	"github.com/go-chi/chi/v5"
	"github.com/go-chi/cors"
	httpSwagger "github.com/swaggo/http-swagger"
)
|
|
|
|
//go:generate counterfeiter -generate
|
|
|
|
// CORSConfig controls the CORS middleware that NewRouter installs on the
// root router.
type CORSConfig struct {
	// Enabled turns the CORS middleware on. When false, NewRouter installs no
	// CORS handling and the remaining fields are not read.
	Enabled bool

	// AllowedOrigins is forwarded to cors.Options.AllowedOrigins.
	AllowedOrigins []string

	// AllowedMethods is forwarded to cors.Options.AllowedMethods.
	AllowedMethods []string

	// AllowCredentials is forwarded to cors.Options.AllowCredentials.
	AllowCredentials bool

	// MaxAge is forwarded to cors.Options.MaxAge.
	// NOTE(review): go-chi/cors documents this value as seconds — confirm
	// against the version pinned in go.mod.
	MaxAge int
}
|
|
|
|
// SwaggerConfig controls whether and how NewRouter exposes the Swagger UI.
type SwaggerConfig struct {
	// Enabled registers the GET /v1/swagger/* route. When false, the route is
	// not registered and the remaining fields are not read.
	Enabled bool

	// Host, when non-empty, replaces docs.SwaggerInfo.Host.
	Host string

	// Schemes, when non-empty, replaces docs.SwaggerInfo.Schemes.
	Schemes []string
}
|
|
|
|
//counterfeiter:generate -o internal/mock/api_key_verifier.gen.go . APIKeyVerifier
|
|
type APIKeyVerifier interface {
|
|
Verify(ctx context.Context, key string) (domain.User, error)
|
|
}
|
|
|
|
type RouterConfig struct {
|
|
APIKeyVerifier APIKeyVerifier
|
|
CORS CORSConfig
|
|
Swagger SwaggerConfig
|
|
}
|
|
|
|
func NewRouter(cfg RouterConfig) *chi.Mux {
|
|
// handlers
|
|
uh := userHandler{}
|
|
|
|
router := chi.NewRouter()
|
|
|
|
if cfg.CORS.Enabled {
|
|
router.Use(cors.Handler(cors.Options{
|
|
AllowedOrigins: cfg.CORS.AllowedOrigins,
|
|
AllowCredentials: cfg.CORS.AllowCredentials,
|
|
AllowedMethods: cfg.CORS.AllowedMethods,
|
|
AllowedHeaders: []string{"Origin", "Content-Length", "Content-Type", "X-API-Key"},
|
|
MaxAge: cfg.CORS.MaxAge,
|
|
}))
|
|
}
|
|
|
|
router.NotFound(routeNotFoundHandler)
|
|
router.MethodNotAllowed(methodNotAllowedHandler)
|
|
|
|
router.Route("/v1", func(r chi.Router) {
|
|
if cfg.Swagger.Enabled {
|
|
if cfg.Swagger.Host != "" {
|
|
docs.SwaggerInfo.Host = cfg.Swagger.Host
|
|
}
|
|
if len(cfg.Swagger.Schemes) > 0 {
|
|
docs.SwaggerInfo.Schemes = cfg.Swagger.Schemes
|
|
}
|
|
r.Get("/swagger/*", httpSwagger.Handler())
|
|
}
|
|
|
|
authMw := authMiddleware(cfg.APIKeyVerifier)
|
|
|
|
r.With(authMw).Get("/user", uh.getAuthenticated)
|
|
})
|
|
|
|
return router
|
|
}
|
|
|
|
func routeNotFoundHandler(w http.ResponseWriter, _ *http.Request) {
|
|
renderJSON(w, http.StatusNotFound, model.ErrorResp{
|
|
Error: model.APIError{
|
|
Code: "route-not-found",
|
|
Message: "route not found",
|
|
},
|
|
})
|
|
}
|
|
|
|
func methodNotAllowedHandler(w http.ResponseWriter, _ *http.Request) {
|
|
renderJSON(w, http.StatusMethodNotAllowed, model.ErrorResp{
|
|
Error: model.APIError{
|
|
Code: "method-not-allowed",
|
|
Message: "method not allowed",
|
|
},
|
|
})
|
|
}
|
|
|
|
// type internalServerError struct {
// 	err error
// }
//
// func (e internalServerError) Error() string {
// 	return e.err.Error()
// }
//
// func (e internalServerError) UserError() string {
// 	return "internal server error"
// }
//
// func (e internalServerError) Code() domain.ErrorCode {
// 	return domain.ErrorCodeUnknown
// }
//
// func (e internalServerError) Unwrap() error {
// 	return e.err
// }
//
// func renderErr(w http.ResponseWriter, err error) {
// 	var userError domain.Error
// 	if !errors.As(err, &userError) {
// 		userError = internalServerError{err: err}
// 	}
// 	renderJSON(w, errorCodeToHTTPStatus(userError.Code()), model.ErrorResp{
// 		Error: model.APIError{
// 			Code:    userError.Code().String(),
// 			Message: userError.UserError(),
// 		},
// 	})
// }
//
// func errorCodeToHTTPStatus(code domain.ErrorCode) int {
// 	switch code {
// 	case domain.ErrorCodeValidationError:
// 		return http.StatusBadRequest
// 	case domain.ErrorCodeEntityNotFound:
// 		return http.StatusNotFound
// 	case domain.ErrorCodeAlreadyExists:
// 		return http.StatusConflict
// 	case domain.ErrorCodeUnknown:
// 		fallthrough
// 	default:
// 		return http.StatusInternalServerError
// 	}
// }
|
|
|
|
// renderJSON encodes data as JSON and writes it to w with the given HTTP
// status and a Content-Type of application/json.
//
// The payload is encoded into a buffer before any header is sent. If encoding
// fails (for example, data contains a value encoding/json cannot represent),
// the client receives a plain 500 response instead of the requested status
// followed by an empty or truncated body.
func renderJSON(w http.ResponseWriter, status int, data any) {
	var buf bytes.Buffer
	if err := json.NewEncoder(&buf).Encode(data); err != nil {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status line is already sent at this point, so a failed write (most
	// likely a disconnected client) cannot be reported to anyone.
	_, _ = w.Write(buf.Bytes())
}
|