refactor: version - split the List function into two
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
Dawid Wysokiński 2022-12-22 08:03:33 +01:00
parent 9dd1804d02
commit 1c75179c12
Signed by: Kichiyaki
GPG Key ID: B5445E357FB8B892
7 changed files with 59 additions and 69 deletions

View File

@ -19,19 +19,26 @@ func NewVersion(db *bun.DB) *Version {
return &Version{db: db}
}
func (v *Version) List(ctx context.Context, params domain.ListVersionsParams) ([]domain.Version, int64, error) {
func (v *Version) List(ctx context.Context) ([]domain.Version, error) {
var versions []model.Version
var count int
var err error
q := v.db.NewSelect().
Model(&versions).
Order("code ASC")
if params.Count {
count, err = q.ScanAndCount(ctx)
} else {
err = q.Scan(ctx)
err := v.baseListQuery(&versions).Scan(ctx)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("couldn't select versions from the db: %w", err)
}
result := make([]domain.Version, 0, len(versions))
for _, version := range versions {
result = append(result, version.ToDomain())
}
return result, nil
}
func (v *Version) ListCount(ctx context.Context) ([]domain.Version, int64, error) {
var versions []model.Version
count, err := v.baseListQuery(&versions).ScanAndCount(ctx)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, 0, fmt.Errorf("couldn't select versions from the db: %w", err)
}
@ -44,6 +51,12 @@ func (v *Version) List(ctx context.Context, params domain.ListVersionsParams) ([
return result, int64(count), nil
}
func (v *Version) baseListQuery(versions *[]model.Version) *bun.SelectQuery {
return v.db.NewSelect().
Model(versions).
Order("code ASC")
}
func (v *Version) GetByCode(ctx context.Context, code string) (domain.Version, error) {
var version model.Version

View File

@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/assert"
)
func TestVersion_List(t *testing.T) {
func TestVersion_List_ListCount(t *testing.T) {
t.Parallel()
repo := bundb.NewVersion(newDB(t))
@ -38,52 +38,24 @@ func TestVersion_List(t *testing.T) {
"sk",
}
tests := []struct {
name string
params domain.ListVersionsParams
expectedVersions []string
expectedCount int64
}{
{
name: "Count=false",
params: domain.ListVersionsParams{
Count: false,
},
expectedVersions: allVersions,
expectedCount: 0,
},
{
name: "Count=true",
params: domain.ListVersionsParams{
Count: true,
},
expectedVersions: allVersions,
expectedCount: int64(len(allVersions)),
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
res, count, err := repo.List(context.Background(), tt.params)
assert.NoError(t, err)
assert.Equal(t, tt.expectedCount, count)
assert.Len(t, res, len(tt.expectedVersions))
for _, versionCode := range tt.expectedVersions {
found := false
for _, version := range res {
if version.Code == versionCode {
found = true
break
}
}
assert.True(t, found, "version (code=%s) not found", versionCode)
listCountRes, count, err := repo.ListCount(context.Background())
assert.NoError(t, err)
assert.Equal(t, int64(len(allVersions)), count)
assert.Len(t, listCountRes, len(allVersions))
for _, versionCode := range allVersions {
found := false
for _, version := range listCountRes {
if version.Code == versionCode {
found = true
break
}
})
}
assert.True(t, found, "version (code=%s) not found", versionCode)
}
listRes, err := repo.List(context.Background())
assert.NoError(t, err)
assert.Equal(t, listCountRes, listRes)
}
func TestVersion_GetByCode(t *testing.T) {

View File

@ -9,10 +9,6 @@ type Version struct {
Timezone string
}
type ListVersionsParams struct {
Count bool
}
type VersionNotFoundError struct {
VerCode string
}

View File

@ -13,7 +13,7 @@ import (
//counterfeiter:generate -o internal/mock/version_service.gen.go . VersionService
type VersionService interface {
List(ctx context.Context, params domain.ListVersionsParams) ([]domain.Version, int64, error)
ListCount(ctx context.Context) ([]domain.Version, int64, error)
GetByCode(ctx context.Context, code string) (domain.Version, error)
}
@ -31,7 +31,7 @@ type version struct {
// @Header 200 {integer} X-Total-Count "Total number of records"
// @Router /versions [get]
func (v *version) list(w http.ResponseWriter, r *http.Request) {
versions, count, err := v.svc.List(r.Context(), domain.ListVersionsParams{Count: true})
versions, count, err := v.svc.ListCount(r.Context())
if err != nil {
renderErr(w, err)
return

View File

@ -45,7 +45,7 @@ func TestVersionHandler_list(t *testing.T) {
Timezone: "Europe/Budapest",
},
}
svc.ListReturns(versions, int64(len(versions)), nil)
svc.ListCountReturns(versions, int64(len(versions)), nil)
},
expectedStatus: http.StatusOK,
target: &model.ListVersionsResp{},

View File

@ -9,7 +9,7 @@ import (
)
type VersionLister interface {
List(ctx context.Context, params domain.ListVersionsParams) ([]domain.Version, int64, error)
List(ctx context.Context) ([]domain.Version, error)
}
type ServerLister interface {
@ -54,7 +54,7 @@ func NewJob(
}
func (j *Job) UpdateData(ctx context.Context) error {
versions, _, err := j.versionSvc.List(ctx, domain.ListVersionsParams{Count: false})
versions, err := j.versionSvc.List(ctx)
if err != nil {
return fmt.Errorf("VersionService.List: %w", err)
}
@ -111,7 +111,7 @@ func (j *Job) UpdateEnnoblements(ctx context.Context) error {
}
func (j *Job) CreateSnapshots(ctx context.Context) error {
versions, _, err := j.versionSvc.List(ctx, domain.ListVersionsParams{Count: false})
versions, err := j.versionSvc.List(ctx)
if err != nil {
return fmt.Errorf("VersionService.List: %w", err)
}

View File

@ -8,7 +8,8 @@ import (
)
type VersionRepository interface {
List(ctx context.Context, params domain.ListVersionsParams) ([]domain.Version, int64, error)
List(ctx context.Context) ([]domain.Version, error)
ListCount(ctx context.Context) ([]domain.Version, int64, error)
GetByCode(ctx context.Context, code string) (domain.Version, error)
}
@ -20,10 +21,18 @@ func NewVersion(repo VersionRepository) *Version {
return &Version{repo: repo}
}
func (v *Version) List(ctx context.Context, params domain.ListVersionsParams) ([]domain.Version, int64, error) {
versions, count, err := v.repo.List(ctx, params)
func (v *Version) List(ctx context.Context) ([]domain.Version, error) {
versions, err := v.repo.List(ctx)
if err != nil {
return nil, 0, fmt.Errorf("VersionRepository.List: %w", err)
return nil, fmt.Errorf("VersionRepository.List: %w", err)
}
return versions, nil
}
func (v *Version) ListCount(ctx context.Context) ([]domain.Version, int64, error) {
versions, count, err := v.repo.ListCount(ctx)
if err != nil {
return nil, 0, fmt.Errorf("VersionRepository.ListCount: %w", err)
}
return versions, count, nil
}