From a8b8ffea58faf4f65c779c0c08301106719ae4d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dawid=20Wysoki=C5=84ski?= Date: Fri, 23 Feb 2024 07:51:37 +0100 Subject: [PATCH] refactor: api - error handling refactor --- internal/port/handler_http_api.go | 31 +++++++--------- internal/port/handler_http_api_error.go | 44 ++++++++++++++--------- internal/port/handler_http_api_server.go | 16 ++++----- internal/port/handler_http_api_tribe.go | 28 ++++++--------- internal/port/handler_http_api_version.go | 10 +++--- 5 files changed, 63 insertions(+), 66 deletions(-) diff --git a/internal/port/handler_http_api.go b/internal/port/handler_http_api.go index 8484774..eb12c11 100644 --- a/internal/port/handler_http_api.go +++ b/internal/port/handler_http_api.go @@ -16,6 +16,7 @@ type apiHTTPHandler struct { versionSvc *app.VersionService serverSvc *app.ServerService tribeSvc *app.TribeService + errorRenderer apiErrorRenderer openAPISchema func() (*openapi3.T, error) } @@ -73,31 +74,23 @@ func NewAPIHTTPHandler( h.tribeMiddleware, }, ErrorHandlerFunc: func(w http.ResponseWriter, r *http.Request, err error) { - apiErrorRenderer{errors: []error{err}}.render(w, r) + h.errorRenderer.render(w, r, err) }, }) } func (h *apiHTTPHandler) handleNotFound(w http.ResponseWriter, r *http.Request) { - apiErrorRenderer{ - errors: []error{ - apiError{ - status: http.StatusNotFound, - code: "route-not-found", - message: "route not found", - }, - }, - }.render(w, r) + h.errorRenderer.render(w, r, apiError{ + status: http.StatusNotFound, + code: "route-not-found", + message: "route not found", + }) } func (h *apiHTTPHandler) handleMethodNotAllowed(w http.ResponseWriter, r *http.Request) { - apiErrorRenderer{ - errors: []error{ - apiError{ - status: http.StatusMethodNotAllowed, - code: "method-not-allowed", - message: "method not allowed", - }, - }, - }.render(w, r) + h.errorRenderer.render(w, r, apiError{ + status: http.StatusMethodNotAllowed, + code: "method-not-allowed", + message: "method not allowed", + }) } diff --git a/internal/port/handler_http_api_error.go b/internal/port/handler_http_api_error.go index dadeb9b..c227335 100644 --- a/internal/port/handler_http_api_error.go +++ b/internal/port/handler_http_api_error.go @@ -71,12 +71,20 @@ type errorPathSegment struct { index int // index may be <0 and this means that it is unset } +type errorPathFormatter func(segments []errorPathSegment) []string + type apiErrorRenderer struct { - errors []error - // formatErrorPath allows to override the default path formatter + // errorPathFormatter allows to override the default path formatter // for domain.ValidationError and domain.SliceElementValidationError. - // If formatErrorPath returns an empty slice, an internal server error is rendered. - formatErrorPath func(segments []errorPathSegment) []string + // If errorPathFormatter returns an empty slice, an internal server error is rendered. + errorPathFormatter errorPathFormatter +} + +func (re apiErrorRenderer) withErrorPathFormatter(formatter errorPathFormatter) apiErrorRenderer { + // this assignment is done on purpose + //nolint:revive + re.errorPathFormatter = formatter + return re } var errAPIInternalServerError = apiError{ @@ -85,10 +93,14 @@ var errAPIInternalServerError = apiError{ message: "internal server error", } -func (re apiErrorRenderer) render(w http.ResponseWriter, r *http.Request) { - errs := make(apiErrors, 0, len(re.errors)) +func (re apiErrorRenderer) render(w http.ResponseWriter, r *http.Request, errs ...error) { + apiErrs := make(apiErrors, 0, len(errs)) + + for _, err := range errs { + if err == nil { + continue + } - for _, err := range re.errors { var apiErr apiError var domainErr domain.Error var paramFormatErr *apimodel.InvalidParamFormatError @@ -103,16 +115,16 @@ func (re apiErrorRenderer) render(w http.ResponseWriter, r *http.Request) { apiErr = errAPIInternalServerError } - errs = append(errs, apiErr) + apiErrs = append(apiErrs, apiErr) } - renderJSON(w, r, errs.status(), errs.toResponse()) + renderJSON(w, r, apiErrs.status(), apiErrs.toResponse()) } func (re apiErrorRenderer) invalidParamFormatErrorToAPIError( paramFormatErr *apimodel.InvalidParamFormatError, ) apiError { - var location string + location := "$unknown" switch paramFormatErr.Location { case runtime.ParamLocationPath: @@ -124,9 +136,7 @@ func (re apiErrorRenderer) invalidParamFormatErrorToAPIError( case runtime.ParamLocationCookie: location = "$cookie" case runtime.ParamLocationUndefined: - fallthrough - default: - location = "$unknown" + // do nothing } return apiError{ @@ -171,11 +181,11 @@ func (re apiErrorRenderer) domainErrorToAPIError(domainErr domain.Error) apiErro var path []string if len(pathSegments) > 0 { - if re.formatErrorPath == nil { + if re.errorPathFormatter == nil { return errAPIInternalServerError } - path = re.formatErrorPath(pathSegments) + path = re.errorPathFormatter(pathSegments) if len(path) == 0 { return errAPIInternalServerError @@ -193,7 +203,7 @@ func (re apiErrorRenderer) domainErrorToAPIError(domainErr domain.Error) apiErro } return apiError{ - status: errorTypeToStatusCode(domainErr.Type()), + status: re.domainErrorTypeToStatusCode(domainErr.Type()), code: domainErr.Code(), path: path, params: cloned, @@ -201,7 +211,7 @@ func (re apiErrorRenderer) domainErrorToAPIError(domainErr domain.Error) apiErro } } -func errorTypeToStatusCode(code domain.ErrorType) int { +func (re apiErrorRenderer) domainErrorTypeToStatusCode(code domain.ErrorType) int { switch code { case domain.ErrorTypeIncorrectInput: return http.StatusBadRequest diff --git a/internal/port/handler_http_api_server.go b/internal/port/handler_http_api_server.go index 6f304b1..6fb64bf 100644 --- a/internal/port/handler_http_api_server.go +++ b/internal/port/handler_http_api_server.go @@ -19,12 +19,12 @@ func (h *apiHTTPHandler) ListServers( domainParams := domain.NewListServersParams() if err := domainParams.SetSort([]domain.ServerSort{domain.ServerSortOpenDESC, domain.ServerSortKeyASC}); err != nil { - apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListServersParamsErrorPath}.render(w, r) + h.errorRenderer.withErrorPathFormatter(formatListServersErrorPath).render(w, r, err) return } if err := domainParams.SetVersionCodes([]string{versionCode}); err != nil { - apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListServersParamsErrorPath}.render(w, r) + h.errorRenderer.withErrorPathFormatter(formatListServersErrorPath).render(w, r, err) return } @@ -33,28 +33,28 @@ func (h *apiHTTPHandler) ListServers( Value: *params.Open, Valid: true, }); err != nil { - apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListServersParamsErrorPath}.render(w, r) + h.errorRenderer.withErrorPathFormatter(formatListServersErrorPath).render(w, r, err) return } } if params.Limit != nil { if err := domainParams.SetLimit(*params.Limit); err != nil { - apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListServersParamsErrorPath}.render(w, r) + h.errorRenderer.withErrorPathFormatter(formatListServersErrorPath).render(w, r, err) return } } if params.Cursor != nil { if err := domainParams.SetEncodedCursor(*params.Cursor); err != nil { - apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListServersParamsErrorPath}.render(w, r) + h.errorRenderer.withErrorPathFormatter(formatListServersErrorPath).render(w, r, err) return } } res, err := h.serverSvc.List(r.Context(), domainParams) if err != nil { - apiErrorRenderer{errors: []error{err}}.render(w, r) + h.errorRenderer.render(w, r, err) return } @@ -119,7 +119,7 @@ func (h *apiHTTPHandler) serverMiddleware(next http.Handler) http.Handler { routeCtx.URLParams.Values[serverKeyIdx], ) if err != nil { - apiErrorRenderer{errors: []error{err}}.render(w, r) + h.errorRenderer.render(w, r, err) return } @@ -138,7 +138,7 @@ func serverFromContext(ctx context.Context) (domain.Server, bool) { return s, ok } -func formatListServersParamsErrorPath(segments []errorPathSegment) []string { +func formatListServersErrorPath(segments []errorPathSegment) []string { if segments[0].model != "ListServersParams" { return nil } diff --git a/internal/port/handler_http_api_tribe.go b/internal/port/handler_http_api_tribe.go index 79003ab..6370ee6 100644 --- a/internal/port/handler_http_api_tribe.go +++ b/internal/port/handler_http_api_tribe.go @@ -2,7 +2,6 @@ package port import ( "context" - "fmt" "net/http" "slices" "strconv" @@ -23,25 +22,25 @@ func (h *apiHTTPHandler) ListTribes( domainParams := domain.NewListTribesParams() if err := domainParams.SetSort([]domain.TribeSort{domain.TribeSortIDASC}); err != nil { - apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListTribesParamsErrorPath}.render(w, r) + h.errorRenderer.withErrorPathFormatter(formatListTribesErrorPath).render(w, r, err) return } if params.Sort != nil { if err := domainParams.PrependSortString(*params.Sort); err != nil { - apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListTribesParamsErrorPath}.render(w, r) + h.errorRenderer.withErrorPathFormatter(formatListTribesErrorPath).render(w, r, err) return } } if err := domainParams.SetServerKeys([]string{serverKey}); err != nil { - apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListTribesParamsErrorPath}.render(w, r) + h.errorRenderer.withErrorPathFormatter(formatListTribesErrorPath).render(w, r, err) return } if params.Tag != nil { if err := domainParams.SetTags(*params.Tag); err != nil { - apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListTribesParamsErrorPath}.render(w, r) + h.errorRenderer.withErrorPathFormatter(formatListTribesErrorPath).render(w, r, err) return } } @@ -51,28 +50,28 @@ func (h *apiHTTPHandler) ListTribes( Value: *params.Deleted, Valid: true, }); err != nil { - apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListTribesParamsErrorPath}.render(w, r) + h.errorRenderer.withErrorPathFormatter(formatListTribesErrorPath).render(w, r, err) return } } if params.Limit != nil { if err := domainParams.SetLimit(*params.Limit); err != nil { - apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListTribesParamsErrorPath}.render(w, r) + h.errorRenderer.withErrorPathFormatter(formatListTribesErrorPath).render(w, r, err) return } } if params.Cursor != nil { if err := domainParams.SetEncodedCursor(*params.Cursor); err != nil { - apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListTribesParamsErrorPath}.render(w, r) + h.errorRenderer.withErrorPathFormatter(formatListTribesErrorPath).render(w, r, err) return } } res, err := h.tribeSvc.List(r.Context(), domainParams) if err != nil { - apiErrorRenderer{errors: []error{err}}.render(w, r) + h.errorRenderer.render(w, r, err) return } @@ -113,8 +112,7 @@ func (h *apiHTTPHandler) tribeMiddleware(next http.Handler) http.Handler { server.Key(), ) if err != nil { - fmt.Println(err) - apiErrorRenderer{errors: []error{err}, formatErrorPath: formatGetTribeErrorPath}.render(w, r) + h.errorRenderer.withErrorPathFormatter(formatGetTribeErrorPath).render(w, r, err) return } @@ -133,7 +131,7 @@ func tribeFromContext(ctx context.Context) (domain.Tribe, bool) { return t, ok } -func formatListTribesParamsErrorPath(segments []errorPathSegment) []string { +func formatListTribesErrorPath(segments []errorPathSegment) []string { if segments[0].model != "ListTribesParams" { return nil } @@ -163,11 +161,7 @@ func formatGetTribeErrorPath(segments []errorPathSegment) []string { switch segments[0].field { case "ids": - path := []string{"$path", "ids"} - if segments[0].index >= 0 { - path = append(path, strconv.Itoa(segments[0].index)) - } - return path + return []string{"$path", "tribeId"} default: return nil } diff --git a/internal/port/handler_http_api_version.go b/internal/port/handler_http_api_version.go index 3fee7a2..e3bf42a 100644 --- a/internal/port/handler_http_api_version.go +++ b/internal/port/handler_http_api_version.go @@ -15,21 +15,21 @@ func (h *apiHTTPHandler) ListVersions(w http.ResponseWriter, r *http.Request, pa if params.Limit != nil { if err := domainParams.SetLimit(*params.Limit); err != nil { - apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListVersionsParamsErrorPath}.render(w, r) + h.errorRenderer.withErrorPathFormatter(formatListVersionsErrorPath).render(w, r, err) return } } if params.Cursor != nil { if err := domainParams.SetEncodedCursor(*params.Cursor); err != nil { - apiErrorRenderer{errors: []error{err}, formatErrorPath: formatListVersionsParamsErrorPath}.render(w, r) + h.errorRenderer.withErrorPathFormatter(formatListVersionsErrorPath).render(w, r, err) return } } res, err := h.versionSvc.List(r.Context(), domainParams) if err != nil { - apiErrorRenderer{errors: []error{err}}.render(w, r) + h.errorRenderer.render(w, r, err) return } @@ -54,7 +54,7 @@ func (h *apiHTTPHandler) versionMiddleware(next http.Handler) http.Handler { version, err := h.versionSvc.Get(ctx, routeCtx.URLParams.Values[idx]) if err != nil { - apiErrorRenderer{errors: []error{err}}.render(w, r) + h.errorRenderer.render(w, r, err) return } @@ -73,7 +73,7 @@ func versionFromContext(ctx context.Context) (domain.Version, bool) { return v, ok } -func formatListVersionsParamsErrorPath(segments []errorPathSegment) []string { +func formatListVersionsErrorPath(segments []errorPathSegment) []string { if segments[0].model != "ListVersionsParams" { return nil }