From 7f614877c56617670d620819ed7bd4316dc600eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dawid=20Wysoki=C5=84ski?= Date: Tue, 6 Feb 2024 07:49:01 +0100 Subject: [PATCH] refactor: version cursor - don't allow for nullable values --- internal/adapter/repository_bun_version.go | 11 +++--- internal/adapter/repository_version_test.go | 14 ++----- internal/domain/domaintest/version.go | 7 +--- internal/domain/version.go | 43 ++++++++------------- internal/domain/version_test.go | 30 +++----------- 5 files changed, 34 insertions(+), 71 deletions(-) diff --git a/internal/adapter/repository_bun_version.go b/internal/adapter/repository_bun_version.go index 0e4e54b..7a4effd 100644 --- a/internal/adapter/repository_bun_version.go +++ b/internal/adapter/repository_bun_version.go @@ -50,18 +50,19 @@ func (a listVersionsParamsApplier) apply(q *bun.SelectQuery) *bun.SelectQuery { } for _, s := range a.params.Sort() { - codeCursor := a.params.Cursor().Code() + cursorZero := a.params.Cursor().IsZero() + cursorCode := a.params.Cursor().Code() switch s { case domain.VersionSortCodeASC: - if codeCursor.Valid { - q = q.Where("version.code >= ?", codeCursor.Value) + if !cursorZero { + q = q.Where("version.code >= ?", cursorCode) } q = q.Order("version.code ASC") case domain.VersionSortCodeDESC: - if codeCursor.Valid { - q = q.Where("version.code <= ?", codeCursor.Value) + if !cursorZero { + q = q.Where("version.code <= ?", cursorCode) } q = q.Order("version.code DESC") diff --git a/internal/adapter/repository_version_test.go b/internal/adapter/repository_version_test.go index 1144646..4ab73e0 100644 --- a/internal/adapter/repository_version_test.go +++ b/internal/adapter/repository_version_test.go @@ -80,10 +80,7 @@ func testVersionRepository(t *testing.T, newRepos func(t *testing.T) repositorie require.Greater(t, len(res.Versions()), 2) require.NoError(t, params.SetCursor(domaintest.NewVersionCursor(t, func(cfg *domaintest.VersionCursorConfig) { - cfg.Code = domain.NullString{ - Value: res.Versions()[1].Code(), - Valid: true, - } + cfg.Code = res.Versions()[1].Code() }))) return params @@ -95,7 +92,7 @@ func testVersionRepository(t *testing.T, newRepos func(t *testing.T) repositorie return cmp.Compare(a.Code(), b.Code()) })) for _, v := range res.Versions() { - assert.GreaterOrEqual(t, v.Code(), params.Cursor().Code().Value, v.Code()) + assert.GreaterOrEqual(t, v.Code(), params.Cursor().Code(), v.Code()) } }, assertError: func(t *testing.T, err error) { @@ -116,10 +113,7 @@ func testVersionRepository(t *testing.T, newRepos func(t *testing.T) repositorie require.Greater(t, len(res.Versions()), 2) require.NoError(t, params.SetCursor(domaintest.NewVersionCursor(t, func(cfg *domaintest.VersionCursorConfig) { - cfg.Code = domain.NullString{ - Value: res.Versions()[1].Code(), - Valid: true, - } + cfg.Code = res.Versions()[1].Code() }))) return params @@ -131,7 +125,7 @@ func testVersionRepository(t *testing.T, newRepos func(t *testing.T) repositorie return cmp.Compare(a.Code(), b.Code()) * -1 })) for _, v := range res.Versions() { - assert.LessOrEqual(t, v.Code(), params.Cursor().Code().Value, v.Code()) + assert.LessOrEqual(t, v.Code(), params.Cursor().Code(), v.Code()) } }, assertError: func(t *testing.T, err error) { diff --git a/internal/domain/domaintest/version.go b/internal/domain/domaintest/version.go index 2bf5480..b977caf 100644 --- a/internal/domain/domaintest/version.go +++ b/internal/domain/domaintest/version.go @@ -7,17 +7,14 @@ import ( ) type VersionCursorConfig struct { - Code domain.NullString + Code string } func NewVersionCursor(tb TestingTB, opts ...func(cfg *VersionCursorConfig)) domain.VersionCursor { tb.Helper() cfg := &VersionCursorConfig{ - Code: domain.NullString{ - Value: RandVersionCode(), - Valid: true, - }, + Code: RandVersionCode(), } for _, opt := range opts { diff --git a/internal/domain/version.go b/internal/domain/version.go index c444a75..b950958 100644 --- a/internal/domain/version.go +++ b/internal/domain/version.go @@ -108,17 +108,13 @@ const ( ) type VersionCursor struct { - code NullString + code string } const versionCursorModelName = "VersionCursor" -func NewVersionCursor(code NullString) (VersionCursor, error) { - if !code.Valid { - return VersionCursor{}, nil - } - - if err := validateVersionCode(code.Value); err != nil { +func NewVersionCursor(code string) (VersionCursor, error) { + if err := validateVersionCode(code); err != nil { return VersionCursor{}, ValidationError{ Model: versionCursorModelName, Field: "code", @@ -152,10 +148,7 @@ func decodeVersionCursor(encoded string) (VersionCursor, error) { return VersionCursor{}, ErrInvalidCursor } - vc, err := NewVersionCursor(NullString{ - Value: code, - Valid: true, - }) + vc, err := NewVersionCursor(code) if err != nil { return VersionCursor{}, ErrInvalidCursor } @@ -163,12 +156,12 @@ func decodeVersionCursor(encoded string) (VersionCursor, error) { return vc, nil } -func (vc VersionCursor) Code() NullString { +func (vc VersionCursor) Code() string { return vc.code } func (vc VersionCursor) IsZero() bool { - return !vc.code.Valid + return vc == VersionCursor{} } func (vc VersionCursor) Encode() string { @@ -176,7 +169,7 @@ func (vc VersionCursor) Encode() string { return "" } - return base64.StdEncoding.EncodeToString([]byte("code=" + vc.code.Value)) + return base64.StdEncoding.EncodeToString([]byte("code=" + vc.code)) } type ListVersionsParams struct { @@ -287,10 +280,7 @@ func NewListVersionsResult(versions Versions, next Version) (ListVersionsResult, } if len(versions) > 0 { - res.self, err = NewVersionCursor(NullString{ - Value: versions[0].Code(), - Valid: true, - }) + res.self, err = NewVersionCursor(versions[0].Code()) if err != nil { return ListVersionsResult{}, ValidationError{ Model: listVersionsResultModelName, @@ -300,15 +290,14 @@ func NewListVersionsResult(versions Versions, next Version) (ListVersionsResult, } } - res.next, err = NewVersionCursor(NullString{ - Value: next.Code(), - Valid: !next.IsZero(), - }) - if err != nil { - return ListVersionsResult{}, ValidationError{ - Model: listVersionsResultModelName, - Field: "next", - Err: err, + if !next.IsZero() { + res.next, err = NewVersionCursor(next.Code()) + if err != nil { + return ListVersionsResult{}, ValidationError{ + Model: listVersionsResultModelName, + Field: "next", + Err: err, + } } } diff --git a/internal/domain/version_test.go b/internal/domain/version_test.go index 3f89120..5e8eec3 100644 --- a/internal/domain/version_test.go +++ b/internal/domain/version_test.go @@ -17,7 +17,7 @@ func TestNewVersionCursor(t *testing.T) { validVersionCursor := domaintest.NewVersionCursor(t) type args struct { - code domain.NullString + code string } type test struct { @@ -34,23 +34,13 @@ func TestNewVersionCursor(t *testing.T) { }, expectedErr: nil, }, - { - name: "OK: nil code", - args: args{ - code: domain.NullString{}, - }, - expectedErr: nil, - }, } for _, versionCodeTest := range newVersionCodeValidationTests() { tests = append(tests, test{ name: versionCodeTest.name, args: args{ - code: domain.NullString{ - Value: versionCodeTest.code, - Valid: true, - }, + code: versionCodeTest.code, }, expectedErr: domain.ValidationError{ Model: "VersionCursor", @@ -72,11 +62,7 @@ func TestNewVersionCursor(t *testing.T) { return } assert.Equal(t, tt.args.code, vc.Code()) - if tt.args.code.Valid { - assert.NotEmpty(t, vc.Encode()) - } else { - assert.Empty(t, vc.Encode()) - } + assert.NotEmpty(t, vc.Encode()) }) } } @@ -231,7 +217,6 @@ func TestListVersionsParams_SetEncodedCursor(t *testing.T) { fmt.Println(tt.expectedErr.Error()) return } - assert.True(t, params.Cursor().Code().Valid) assert.Equal(t, tt.expectedCursor.Code(), params.Cursor().Code()) assert.Equal(t, tt.args.cursor, params.Cursor().Encode()) }) @@ -319,10 +304,8 @@ func TestNewListVersionsResult(t *testing.T) { res, err := domain.NewListVersionsResult(versions, next) require.NoError(t, err) assert.Equal(t, versions, res.Versions()) - assert.True(t, res.Self().Code().Valid) - assert.Equal(t, versions[0].Code(), res.Self().Code().Value) - assert.True(t, res.Next().Code().Valid) - assert.Equal(t, next.Code(), res.Next().Code().Value) + assert.Equal(t, versions[0].Code(), res.Self().Code()) + assert.Equal(t, next.Code(), res.Next().Code()) }) t.Run("OK: without next", func(t *testing.T) { @@ -331,8 +314,7 @@ func TestNewListVersionsResult(t *testing.T) { res, err := domain.NewListVersionsResult(versions, domain.Version{}) require.NoError(t, err) assert.Equal(t, versions, res.Versions()) - assert.True(t, res.Self().Code().Valid) - assert.Equal(t, versions[0].Code(), res.Self().Code().Value) + assert.Equal(t, versions[0].Code(), res.Self().Code()) assert.True(t, res.Next().IsZero()) })