refactor: api - error handling refactor

This commit is contained in:
Dawid Wysokiński 2024-02-23 07:51:37 +01:00
parent 30be8bebc0
commit a8b8ffea58
Signed by: Kichiyaki
GPG Key ID: B5445E357FB8B892
5 changed files with 63 additions and 66 deletions

View File

@ -16,6 +16,7 @@ type apiHTTPHandler struct {
versionSvc *app.VersionService
serverSvc *app.ServerService
tribeSvc *app.TribeService
errorRenderer apiErrorRenderer
openAPISchema func() (*openapi3.T, error)
}
@ -73,31 +74,23 @@ func NewAPIHTTPHandler(
h.tribeMiddleware,
},
ErrorHandlerFunc: func(w http.ResponseWriter, r *http.Request, err error) {
apiErrorRenderer{errors: []error{err}}.render(w, r)
h.errorRenderer.render(w, r, err)
},
})
}
func (h *apiHTTPHandler) handleNotFound(w http.ResponseWriter, r *http.Request) {
apiErrorRenderer{
errors: []error{
apiError{
status: http.StatusNotFound,
code: "route-not-found",
message: "route not found",
},
},
}.render(w, r)
h.errorRenderer.render(w, r, apiError{
status: http.StatusNotFound,
code: "route-not-found",
message: "route not found",
})
}
func (h *apiHTTPHandler) handleMethodNotAllowed(w http.ResponseWriter, r *http.Request) {
apiErrorRenderer{
errors: []error{
apiError{
status: http.StatusMethodNotAllowed,
code: "method-not-allowed",
message: "method not allowed",
},
},
}.render(w, r)
h.errorRenderer.render(w, r, apiError{
status: http.StatusMethodNotAllowed,
code: "method-not-allowed",
message: "method not allowed",
})
}

View File

@ -71,12 +71,20 @@ type errorPathSegment struct {
index int // index may be <0 and this means that it is unset
}
type errorPathFormatter func(segments []errorPathSegment) []string
type apiErrorRenderer struct {
errors []error
// formatErrorPath allows to override the default path formatter
// errorPathFormatter allows to override the default path formatter
// for domain.ValidationError and domain.SliceElementValidationError.
// If formatErrorPath returns an empty slice, an internal server error is rendered.
formatErrorPath func(segments []errorPathSegment) []string
// If errorPathFormatter returns an empty slice, an internal server error is rendered.
errorPathFormatter errorPathFormatter
}
func (re apiErrorRenderer) withErrorPathFormatter(formatter errorPathFormatter) apiErrorRenderer {
// this assignment is done on purpose
//nolint:revive
re.errorPathFormatter = formatter
return re
}
var errAPIInternalServerError = apiError{
@ -85,10 +93,14 @@ var errAPIInternalServerError = apiError{
message: "internal server error",
}
func (re apiErrorRenderer) render(w http.ResponseWriter, r *http.Request) {
errs := make(apiErrors, 0, len(re.errors))
func (re apiErrorRenderer) render(w http.ResponseWriter, r *http.Request, errs ...error) {
apiErrs := make(apiErrors, 0, len(errs))
for _, err := range errs {
if err == nil {
continue
}
for _, err := range re.errors {
var apiErr apiError
var domainErr domain.Error
var paramFormatErr *apimodel.InvalidParamFormatError
@ -103,16 +115,16 @@ func (re apiErrorRenderer) render(w http.ResponseWriter, r *http.Request) {
apiErr = errAPIInternalServerError
}
errs = append(errs, apiErr)
apiErrs = append(apiErrs, apiErr)
}
renderJSON(w, r, errs.status(), errs.toResponse())
renderJSON(w, r, apiErrs.status(), apiErrs.toResponse())
}
func (re apiErrorRenderer) invalidParamFormatErrorToAPIError(
paramFormatErr *apimodel.InvalidParamFormatError,
) apiError {
var location string
location := "$unknown"
switch paramFormatErr.Location {
case runtime.ParamLocationPath:
@ -124,9 +136,7 @@ func (re apiErrorRenderer) invalidParamFormatErrorToAPIError(
case runtime.ParamLocationCookie:
location = "$cookie"
case runtime.ParamLocationUndefined:
fallthrough
default:
location = "$unknown"
// do nothing
}
return apiError{
@ -171,11 +181,11 @@ func (re apiErrorRenderer) domainErrorToAPIError(domainErr domain.Error) apiErro
var path []string
if len(pathSegments) > 0 {
if re.formatErrorPath == nil {
if re.errorPathFormatter == nil {
return errAPIInternalServerError
}
path = re.formatErrorPath(pathSegments)
path = re.errorPathFormatter(pathSegments)
if len(path) == 0 {
return errAPIInternalServerError
@ -193,7 +203,7 @@ func (re apiErrorRenderer) domainErrorToAPIError(domainErr domain.Error) apiErro
}
return apiError{
status: errorTypeToStatusCode(domainErr.Type()),
status: re.domainErrorTypeToStatusCode(domainErr.Type()),
code: domainErr.Code(),
path: path,
params: cloned,
@ -201,7 +211,7 @@ func (re apiErrorRenderer) domainErrorToAPIError(domainErr domain.Error) apiErro
}
}
func errorTypeToStatusCode(code domain.ErrorType) int {
func (re apiErrorRenderer) domainErrorTypeToStatusCode(code domain.ErrorType) int {
switch code {
case domain.ErrorTypeIncorrectInput:
return http.StatusBadRequest

View File

@ -19,12 +19,12 @@ func (h *apiHTTPHandler) ListServers(
domainParams := domain.NewListServersParams()
if err := domainParams.SetSort([]domain.ServerSort{domain.ServerSortOpenDESC, domain.ServerSortKeyASC}); err != nil {
apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListServersParamsErrorPath}.render(w, r)
h.errorRenderer.withErrorPathFormatter(formatListServersErrorPath).render(w, r, err)
return
}
if err := domainParams.SetVersionCodes([]string{versionCode}); err != nil {
apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListServersParamsErrorPath}.render(w, r)
h.errorRenderer.withErrorPathFormatter(formatListServersErrorPath).render(w, r, err)
return
}
@ -33,28 +33,28 @@ func (h *apiHTTPHandler) ListServers(
Value: *params.Open,
Valid: true,
}); err != nil {
apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListServersParamsErrorPath}.render(w, r)
h.errorRenderer.withErrorPathFormatter(formatListServersErrorPath).render(w, r, err)
return
}
}
if params.Limit != nil {
if err := domainParams.SetLimit(*params.Limit); err != nil {
apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListServersParamsErrorPath}.render(w, r)
h.errorRenderer.withErrorPathFormatter(formatListServersErrorPath).render(w, r, err)
return
}
}
if params.Cursor != nil {
if err := domainParams.SetEncodedCursor(*params.Cursor); err != nil {
apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListServersParamsErrorPath}.render(w, r)
h.errorRenderer.withErrorPathFormatter(formatListServersErrorPath).render(w, r, err)
return
}
}
res, err := h.serverSvc.List(r.Context(), domainParams)
if err != nil {
apiErrorRenderer{errors: []error{err}}.render(w, r)
h.errorRenderer.render(w, r, err)
return
}
@ -119,7 +119,7 @@ func (h *apiHTTPHandler) serverMiddleware(next http.Handler) http.Handler {
routeCtx.URLParams.Values[serverKeyIdx],
)
if err != nil {
apiErrorRenderer{errors: []error{err}}.render(w, r)
h.errorRenderer.render(w, r, err)
return
}
@ -138,7 +138,7 @@ func serverFromContext(ctx context.Context) (domain.Server, bool) {
return s, ok
}
func formatListServersParamsErrorPath(segments []errorPathSegment) []string {
func formatListServersErrorPath(segments []errorPathSegment) []string {
if segments[0].model != "ListServersParams" {
return nil
}

View File

@ -2,7 +2,6 @@ package port
import (
"context"
"fmt"
"net/http"
"slices"
"strconv"
@ -23,25 +22,25 @@ func (h *apiHTTPHandler) ListTribes(
domainParams := domain.NewListTribesParams()
if err := domainParams.SetSort([]domain.TribeSort{domain.TribeSortIDASC}); err != nil {
apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListTribesParamsErrorPath}.render(w, r)
h.errorRenderer.withErrorPathFormatter(formatListTribesErrorPath).render(w, r, err)
return
}
if params.Sort != nil {
if err := domainParams.PrependSortString(*params.Sort); err != nil {
apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListTribesParamsErrorPath}.render(w, r)
h.errorRenderer.withErrorPathFormatter(formatListTribesErrorPath).render(w, r, err)
return
}
}
if err := domainParams.SetServerKeys([]string{serverKey}); err != nil {
apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListTribesParamsErrorPath}.render(w, r)
h.errorRenderer.withErrorPathFormatter(formatListTribesErrorPath).render(w, r, err)
return
}
if params.Tag != nil {
if err := domainParams.SetTags(*params.Tag); err != nil {
apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListTribesParamsErrorPath}.render(w, r)
h.errorRenderer.withErrorPathFormatter(formatListTribesErrorPath).render(w, r, err)
return
}
}
@ -51,28 +50,28 @@ func (h *apiHTTPHandler) ListTribes(
Value: *params.Deleted,
Valid: true,
}); err != nil {
apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListTribesParamsErrorPath}.render(w, r)
h.errorRenderer.withErrorPathFormatter(formatListTribesErrorPath).render(w, r, err)
return
}
}
if params.Limit != nil {
if err := domainParams.SetLimit(*params.Limit); err != nil {
apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListTribesParamsErrorPath}.render(w, r)
h.errorRenderer.withErrorPathFormatter(formatListTribesErrorPath).render(w, r, err)
return
}
}
if params.Cursor != nil {
if err := domainParams.SetEncodedCursor(*params.Cursor); err != nil {
apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListTribesParamsErrorPath}.render(w, r)
h.errorRenderer.withErrorPathFormatter(formatListTribesErrorPath).render(w, r, err)
return
}
}
res, err := h.tribeSvc.List(r.Context(), domainParams)
if err != nil {
apiErrorRenderer{errors: []error{err}}.render(w, r)
h.errorRenderer.render(w, r, err)
return
}
@ -113,8 +112,7 @@ func (h *apiHTTPHandler) tribeMiddleware(next http.Handler) http.Handler {
server.Key(),
)
if err != nil {
fmt.Println(err)
apiErrorRenderer{errors: []error{err}, formatErrorPath: formatGetTribeErrorPath}.render(w, r)
h.errorRenderer.withErrorPathFormatter(formatGetTribeErrorPath).render(w, r, err)
return
}
@ -133,7 +131,7 @@ func tribeFromContext(ctx context.Context) (domain.Tribe, bool) {
return t, ok
}
func formatListTribesParamsErrorPath(segments []errorPathSegment) []string {
func formatListTribesErrorPath(segments []errorPathSegment) []string {
if segments[0].model != "ListTribesParams" {
return nil
}
@ -163,11 +161,7 @@ func formatGetTribeErrorPath(segments []errorPathSegment) []string {
switch segments[0].field {
case "ids":
path := []string{"$path", "ids"}
if segments[0].index >= 0 {
path = append(path, strconv.Itoa(segments[0].index))
}
return path
return []string{"$path", "tribeId"}
default:
return nil
}

View File

@ -15,21 +15,21 @@ func (h *apiHTTPHandler) ListVersions(w http.ResponseWriter, r *http.Request, pa
if params.Limit != nil {
if err := domainParams.SetLimit(*params.Limit); err != nil {
apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListVersionsParamsErrorPath}.render(w, r)
h.errorRenderer.withErrorPathFormatter(formatListVersionsErrorPath).render(w, r, err)
return
}
}
if params.Cursor != nil {
if err := domainParams.SetEncodedCursor(*params.Cursor); err != nil {
apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListVersionsParamsErrorPath}.render(w, r)
h.errorRenderer.withErrorPathFormatter(formatListVersionsErrorPath).render(w, r, err)
return
}
}
res, err := h.versionSvc.List(r.Context(), domainParams)
if err != nil {
apiErrorRenderer{errors: []error{err}}.render(w, r)
h.errorRenderer.render(w, r, err)
return
}
@ -54,7 +54,7 @@ func (h *apiHTTPHandler) versionMiddleware(next http.Handler) http.Handler {
version, err := h.versionSvc.Get(ctx, routeCtx.URLParams.Values[idx])
if err != nil {
apiErrorRenderer{errors: []error{err}}.render(w, r)
h.errorRenderer.render(w, r, err)
return
}
@ -73,7 +73,7 @@ func versionFromContext(ctx context.Context) (domain.Version, bool) {
return v, ok
}
func formatListVersionsParamsErrorPath(segments []errorPathSegment) []string {
func formatListVersionsErrorPath(segments []errorPathSegment) []string {
if segments[0].model != "ListVersionsParams" {
return nil
}