add a new method to the Qualification usecase/repository - GetSimilar

This commit is contained in:
Dawid Wysokiński 2021-03-27 16:03:43 +01:00
parent 48d2adb7ae
commit c3c57b780c
7 changed files with 109 additions and 51 deletions

View File

@ -14,9 +14,18 @@ type FetchConfig struct {
Count bool
}
type GetSimilarConfig struct {
Limit int
Offset int
QualificationID int
Sort []string
Count bool
}
type Repository interface {
Store(ctx context.Context, input *models.QualificationInput) (*models.Qualification, error)
UpdateMany(ctx context.Context, f *models.QualificationFilter, input *models.QualificationInput) ([]*models.Qualification, error)
Delete(ctx context.Context, f *models.QualificationFilter) ([]*models.Qualification, error)
Fetch(ctx context.Context, cfg *FetchConfig) ([]*models.Qualification, int, error)
GetSimilar(ctx context.Context, cfg *GetSimilarConfig) ([]*models.Qualification, int, error)
}

View File

@ -39,21 +39,13 @@ func (repo *pgRepository) Store(ctx context.Context, input *models.Qualification
Context(ctx).
Returning("*").
Insert(); err != nil {
if strings.Contains(err.Error(), "name") {
return errorutils.Wrap(err, messageNameIsAlreadyTaken)
} else if strings.Contains(err.Error(), "code") || strings.Contains(err.Error(), "slug") {
return errorutils.Wrap(err, messageCodeIsAlreadyTaken)
}
return errorutils.Wrap(err, messageFailedToSaveModel)
return handleInsertAndUpdateError(err)
}
for _, professionID := range input.AssociateProfession {
tx.
Model(&models.QualificationToProfession{
QualificationID: item.ID,
ProfessionID: professionID,
}).
Insert()
if len(input.AssociateProfession) > 0 {
if err := repo.associateQualificationWithProfession(tx, []int{item.ID}, input.AssociateProfession); err != nil {
return handleInsertAndUpdateError(err)
}
}
return nil
@ -75,12 +67,7 @@ func (repo *pgRepository) UpdateMany(
Apply(input.ApplyUpdate).
Apply(f.Where).
Update(); err != nil && err != pg.ErrNoRows {
if strings.Contains(err.Error(), "name") {
return errorutils.Wrap(err, messageNameIsAlreadyTaken)
} else if strings.Contains(err.Error(), "code") || strings.Contains(err.Error(), "slug") {
return errorutils.Wrap(err, messageCodeIsAlreadyTaken)
}
return errorutils.Wrap(err, messageFailedToSaveModel)
return handleInsertAndUpdateError(err)
}
}
@ -89,7 +76,7 @@ func (repo *pgRepository) UpdateMany(
Context(ctx).
Apply(f.Where).
Select(); err != nil && err != pg.ErrNoRows {
return errorutils.Wrap(err, messageFailedToFetchModel)
return handleInsertAndUpdateError(err)
}
qualificationIDs := make([]int, len(items))
@ -99,24 +86,20 @@ func (repo *pgRepository) UpdateMany(
if len(qualificationIDs) > 0 {
if len(input.DissociateProfession) > 0 {
tx.
_, err := tx.
Model(&models.QualificationToProfession{}).
Where(sqlutils.BuildConditionArray("profession_id"), pg.Array(input.DissociateProfession)).
Where(sqlutils.BuildConditionArray("qualification_id"), pg.Array(qualificationIDs)).
Delete()
if err != nil {
return handleInsertAndUpdateError(err)
}
}
if len(input.AssociateProfession) > 0 {
toInsert := []*models.QualificationToProfession{}
for _, professionID := range input.AssociateProfession {
for _, qualificationID := range qualificationIDs {
toInsert = append(toInsert, &models.QualificationToProfession{
ProfessionID: professionID,
QualificationID: qualificationID,
})
}
if err := repo.associateQualificationWithProfession(tx, qualificationIDs, input.AssociateProfession); err != nil {
return handleInsertAndUpdateError(err)
}
tx.Model(&toInsert).Insert()
}
}
@ -160,3 +143,57 @@ func (repo *pgRepository) Fetch(ctx context.Context, cfg *qualification.FetchCon
}
return items, total, nil
}
func (repo *pgRepository) GetSimilar(ctx context.Context, cfg *qualification.GetSimilarConfig) ([]*models.Qualification, int, error) {
var err error
subquery := repo.
Model(&models.QualificationToProfession{}).
Context(ctx).
Where(sqlutils.BuildConditionEquals("qualification_id"), cfg.QualificationID).
Column("profession_id")
qualificationIDs := []int{}
err = repo.
Model(&models.QualificationToProfession{}).
Context(ctx).
Column("qualification_id").
With("prof", subquery).
Where(sqlutils.BuildConditionIn("profession_id"), pg.Safe("SELECT profession_id FROM prof")).
Where(sqlutils.BuildConditionNEQ("qualification_id"), cfg.QualificationID).
Select(&qualificationIDs)
if err != nil {
return nil, 0, errorutils.Wrap(err, messageFailedToFetchModel)
}
return repo.Fetch(ctx, &qualification.FetchConfig{
Sort: cfg.Sort,
Limit: cfg.Limit,
Offset: cfg.Offset,
Filter: &models.QualificationFilter{
ID: qualificationIDs,
},
Count: cfg.Count,
})
}
func (repo *pgRepository) associateQualificationWithProfession(tx *pg.Tx, qualificationIDs, professionIDs []int) error {
toInsert := []*models.QualificationToProfession{}
for _, professionID := range professionIDs {
for _, qualificationID := range qualificationIDs {
toInsert = append(toInsert, &models.QualificationToProfession{
ProfessionID: professionID,
QualificationID: qualificationID,
})
}
}
_, err := tx.Model(&toInsert).Insert()
return err
}
func handleInsertAndUpdateError(err error) error {
if strings.Contains(err.Error(), "name") {
return errorutils.Wrap(err, messageNameIsAlreadyTaken)
} else if strings.Contains(err.Error(), "code") || strings.Contains(err.Error(), "slug") {
return errorutils.Wrap(err, messageCodeIsAlreadyTaken)
}
return errorutils.Wrap(err, messageFailedToSaveModel)
}

View File

@ -13,4 +13,5 @@ type Usecase interface {
Fetch(ctx context.Context, cfg *FetchConfig) ([]*models.Qualification, int, error)
GetByID(ctx context.Context, id int) (*models.Qualification, error)
GetBySlug(ctx context.Context, slug string) (*models.Qualification, error)
GetSimilar(ctx context.Context, cfg *GetSimilarConfig) ([]*models.Qualification, int, error)
}

View File

@ -1,10 +1,11 @@
package usecase
const (
messageInvalidID = "Niepoprawne ID."
messageItemNotFound = "Nie znaleziono kwalifikacji."
messageEmptyPayload = "Nie wprowadzono jakichkolwiek danych."
messageNameIsRequired = "Nazwa kwalifikacji jest wymagana."
messageCodeIsRequired = "Oznaczenie kwalifikacji jest wymagane."
messageNameIsTooLong = "Nazwa kwalifikacji może się składać z maksymalnie %d znaków."
messageInvalidID = "Niepoprawne ID."
messageItemNotFound = "Nie znaleziono kwalifikacji."
messageEmptyPayload = "Nie wprowadzono jakichkolwiek danych."
messageNameIsRequired = "Nazwa kwalifikacji jest wymagana."
messageCodeIsRequired = "Oznaczenie kwalifikacji jest wymagane."
messageNameIsTooLong = "Nazwa kwalifikacji może się składać z maksymalnie %d znaków."
messageQualificationIDIsRequired = "ID kwalifikacji jest wymagane."
)

View File

@ -2,7 +2,7 @@ package usecase
import (
"context"
"fmt"
"github.com/pkg/errors"
"github.com/zdam-egzamin-zawodowy/backend/internal/models"
"github.com/zdam-egzamin-zawodowy/backend/internal/qualification"
@ -19,7 +19,7 @@ type Config struct {
func New(cfg *Config) (qualification.Usecase, error) {
if cfg == nil || cfg.QualificationRepository == nil {
return nil, fmt.Errorf("qualification/usecase: QualificationRepository is required")
return nil, errors.New("qualification/usecase: QualificationRepository is required")
}
return &usecase{
cfg.QualificationRepository,
@ -35,7 +35,7 @@ func (ucase *usecase) Store(ctx context.Context, input *models.QualificationInpu
func (ucase *usecase) UpdateOneByID(ctx context.Context, id int, input *models.QualificationInput) (*models.Qualification, error) {
if id <= 0 {
return nil, fmt.Errorf(messageInvalidID)
return nil, errors.New(messageInvalidID)
}
if err := validateInput(input.Sanitize(), validateOptions{true}); err != nil {
return nil, err
@ -49,7 +49,7 @@ func (ucase *usecase) UpdateOneByID(ctx context.Context, id int, input *models.Q
return nil, err
}
if len(items) == 0 {
return nil, fmt.Errorf(messageItemNotFound)
return nil, errors.New(messageItemNotFound)
}
return items[0], nil
}
@ -81,7 +81,7 @@ func (ucase *usecase) GetByID(ctx context.Context, id int) (*models.Qualificatio
return nil, err
}
if len(items) == 0 {
return nil, fmt.Errorf(messageItemNotFound)
return nil, errors.New(messageItemNotFound)
}
return items[0], nil
}
@ -98,36 +98,44 @@ func (ucase *usecase) GetBySlug(ctx context.Context, slug string) (*models.Quali
return nil, err
}
if len(items) == 0 {
return nil, fmt.Errorf(messageItemNotFound)
return nil, errors.New(messageItemNotFound)
}
return items[0], nil
}
func (ucase *usecase) GetSimilar(ctx context.Context, cfg *qualification.GetSimilarConfig) ([]*models.Qualification, int, error) {
if cfg == nil || cfg.QualificationID <= 0 {
return nil, 0, errors.New(messageQualificationIDIsRequired)
}
cfg.Sort = sqlutils.SanitizeSorts(cfg.Sort)
return ucase.qualificationRepository.GetSimilar(ctx, cfg)
}
type validateOptions struct {
allowNilValues bool
}
func validateInput(input *models.QualificationInput, opts validateOptions) error {
if input.IsEmpty() {
return fmt.Errorf(messageEmptyPayload)
return errors.New(messageEmptyPayload)
}
if input.Name != nil {
if *input.Name == "" {
return fmt.Errorf(messageNameIsRequired)
return errors.New(messageNameIsRequired)
} else if len(*input.Name) > qualification.MaxNameLength {
return fmt.Errorf(messageNameIsTooLong, qualification.MaxNameLength)
return errors.Errorf(messageNameIsTooLong, qualification.MaxNameLength)
}
} else if !opts.allowNilValues {
return fmt.Errorf(messageNameIsRequired)
return errors.New(messageNameIsRequired)
}
if input.Code != nil {
if *input.Code == "" {
return fmt.Errorf(messageCodeIsRequired)
return errors.New(messageCodeIsRequired)
}
} else if !opts.allowNilValues {
return fmt.Errorf(messageCodeIsRequired)
return errors.New(messageCodeIsRequired)
}
return nil

View File

@ -1,8 +1,6 @@
package errorutils
import (
"fmt"
"github.com/pkg/errors"
"github.com/zdam-egzamin-zawodowy/backend/pkg/mode"
)
@ -15,5 +13,5 @@ func Wrapf(details error, message string, args ...interface{}) error {
if mode.Get() != mode.ProductionMode {
return errors.Wrapf(details, message, args...)
}
return fmt.Errorf(message, args...)
return errors.Errorf(message, args...)
}

View File

@ -22,6 +22,10 @@ func BuildConditionEquals(column string) string {
return column + " = ?"
}
func BuildConditionNEQ(column string) string {
return column + " != ?"
}
func BuildConditionLT(column string) string {
return column + " < ?"
}