From f636d6ac942ef7e8c67a29143bb3cc2ad71c3c79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dawid=20Wysoki=C5=84ski?= Date: Sat, 24 Dec 2022 07:44:25 +0000 Subject: [PATCH] refactor: village - split the List method into two (#145) Reviewed-on: https://gitea.dwysokinski.me/twhelp/core/pulls/145 --- internal/bundb/bundb.go | 26 ++--- internal/bundb/ennoblement.go | 15 ++- internal/bundb/player.go | 38 +++---- internal/bundb/player_test.go | 2 +- internal/bundb/server.go | 3 +- internal/bundb/tribe.go | 3 +- internal/bundb/version.go | 3 +- internal/bundb/village.go | 64 ++++++----- internal/bundb/village_test.go | 163 ++++++--------------------- internal/domain/village.go | 9 +- internal/router/rest/village.go | 24 +--- internal/router/rest/village_test.go | 18 +-- internal/service/village.go | 24 ++-- internal/service/village_test.go | 28 ++--- 14 files changed, 149 insertions(+), 271 deletions(-) diff --git a/internal/bundb/bundb.go b/internal/bundb/bundb.go index 8916fcb..7b2393c 100644 --- a/internal/bundb/bundb.go +++ b/internal/bundb/bundb.go @@ -56,19 +56,6 @@ func scanAndCount( var count int var err error - wg.Add(1) - go func() { - defer wg.Done() - - if scanErr := scanQ.Scan(ctx); scanErr == nil { - mu.Lock() - if err == nil { - err = scanErr - } - mu.Unlock() - } - }() - wg.Add(1) go func() { defer wg.Done() @@ -83,6 +70,19 @@ func scanAndCount( } }() + wg.Add(1) + go func() { + defer wg.Done() + + if scanErr := scanQ.Scan(ctx); scanErr == nil { + mu.Lock() + if err == nil { + err = scanErr + } + mu.Unlock() + } + }() + wg.Wait() return count, err diff --git a/internal/bundb/ennoblement.go b/internal/bundb/ennoblement.go index cefc0e0..bf27096 100644 --- a/internal/bundb/ennoblement.go +++ b/internal/bundb/ennoblement.go @@ -43,7 +43,6 @@ func (e *Ennoblement) List( ctx context.Context, params domain.ListEnnoblementsParams, ) ([]domain.EnnoblementWithRelations, int64, error) { - var ennoblements []model.Ennoblement var count int var err error @@ -53,18 +52,18 @@ func (e *Ennoblement) List( base := e.db.NewSelect(). Model(&model.Ennoblement{}). Order("ennoblement.server_key ASC"). - Apply(paramsApplier.applyFilters). - Apply(paramsApplier.applyPagination) - base, err = paramsApplier.applySort(base) - if err != nil { - return nil, 0, fmt.Errorf("listEnnoblementsParamsApplier.applySort: %w", err) - } + Apply(paramsApplier.applyFilters) q := e.db.NewSelect(). With("ennoblements_base", base). Model(&ennoblements). ModelTableExpr("ennoblements_base AS ennoblement"). - Apply(paramsApplier.applyRelations) + Apply(paramsApplier.applyRelations). + Apply(paramsApplier.applyPagination) + q, err = paramsApplier.applySort(q) + if err != nil { + return nil, 0, err + } if params.Count { count, err = scanAndCount(ctx, base, q) diff --git a/internal/bundb/player.go b/internal/bundb/player.go index 3fecf99..0e97506 100644 --- a/internal/bundb/player.go +++ b/internal/bundb/player.go @@ -14,6 +14,7 @@ import ( var ( playerMetaColumns = []string{"id", "name", "profile_url"} + playerOrders = []string{"player.server_key ASC"} ) type Player struct { @@ -82,13 +83,15 @@ func (p *Player) Delete(ctx context.Context, serverKey string, ids ...int64) err func (p *Player) List(ctx context.Context, params domain.ListPlayersParams) ([]domain.Player, error) { var players []model.Player - q, err := p.baseListQuery(&players, params) + q := p.db.NewSelect(). + Model(&players). + Order(playerOrders...) + q, err := listPlayersParamsApplier{params: params}.apply(q) if err != nil { return nil, err } - err = q.Scan(ctx) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err = q.Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, fmt.Errorf("couldn't select players from the db: %w", err) } @@ -103,19 +106,27 @@ func (p *Player) List(ctx context.Context, params domain.ListPlayersParams) ([]d func (p *Player) ListCountWithRelations(ctx context.Context, params domain.ListPlayersParams) ([]domain.PlayerWithRelations, int64, error) { var players []model.Player - base, err := p.baseListQuery(&model.Player{}, params) - if err != nil { - return nil, 0, err - } + paramsApplier := listPlayersParamsApplier{params} + + base := p.db.NewSelect(). + Model(&model.Player{}). + Apply(paramsApplier.applyFilters) q := p.db.NewSelect(). With("players_base", base). Model(&players). + Order(playerOrders...). + Apply(paramsApplier.applyPagination). ModelTableExpr("players_base AS player"). Relation("Tribe", func(q *bun.SelectQuery) *bun.SelectQuery { return q.Column(tribeMetaColumns...) }) + q, err := paramsApplier.applySort(q) + if err != nil { + return nil, 0, err + } + count, err := scanAndCount(ctx, base, q) if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, 0, fmt.Errorf("couldn't select players from the db: %w", err) @@ -129,19 +140,6 @@ func (p *Player) ListCountWithRelations(ctx context.Context, params domain.ListP return result, int64(count), nil } -func (p *Player) baseListQuery(model any, params domain.ListPlayersParams) (*bun.SelectQuery, error) { - q := p.db.NewSelect(). - Model(model). - Order("player.server_key ASC") - - q, err := listPlayersParamsApplier{params}.apply(q) - if err != nil { - return nil, fmt.Errorf("listPlayersParamsApplier.apply: %w", err) - } - - return q, nil -} - type listPlayersParamsApplier struct { params domain.ListPlayersParams } diff --git a/internal/bundb/player_test.go b/internal/bundb/player_test.go index 20885dd..22d1f96 100644 --- a/internal/bundb/player_test.go +++ b/internal/bundb/player_test.go @@ -600,7 +600,7 @@ func TestPlayer_List_ListCountWithRelations(t *testing.T) { continue } - if player.TribeID > 0 && (!player.Tribe.Valid || player.Tribe.Tribe.ID != player.TribeID) { + if player.TribeID != 0 && (!player.Tribe.Valid || player.Tribe.Tribe.ID != player.TribeID) { continue } diff --git a/internal/bundb/server.go b/internal/bundb/server.go index 569db35..03b06de 100644 --- a/internal/bundb/server.go +++ b/internal/bundb/server.go @@ -58,8 +58,7 @@ func (s *Server) CreateOrUpdate(ctx context.Context, params ...domain.CreateServ func (s *Server) List(ctx context.Context, params domain.ListServersParams) ([]domain.Server, error) { var servers []model.Server - err := s.baseListQuery(&servers, params).Scan(ctx) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err := s.baseListQuery(&servers, params).Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, fmt.Errorf("couldn't select servers from the db: %w", err) } diff --git a/internal/bundb/tribe.go b/internal/bundb/tribe.go index d258458..cbf6eac 100644 --- a/internal/bundb/tribe.go +++ b/internal/bundb/tribe.go @@ -105,8 +105,7 @@ func (t *Tribe) List(ctx context.Context, params domain.ListTribesParams) ([]dom return nil, err } - err = q.Scan(ctx) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err = q.Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, fmt.Errorf("couldn't select tribes from the db: %w", err) } diff --git a/internal/bundb/version.go b/internal/bundb/version.go index 36fdea4..9b23f18 100644 --- a/internal/bundb/version.go +++ b/internal/bundb/version.go @@ -22,8 +22,7 @@ func NewVersion(db *bun.DB) *Version { func (v *Version) List(ctx context.Context) ([]domain.Version, error) { var versions []model.Version - err := v.baseListQuery(&versions).Scan(ctx) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err := v.baseListQuery(&versions).Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, fmt.Errorf("couldn't select versions from the db: %w", err) } diff --git a/internal/bundb/village.go b/internal/bundb/village.go index 2468d8b..5990b79 100644 --- a/internal/bundb/village.go +++ b/internal/bundb/village.go @@ -13,6 +13,7 @@ import ( var ( villageMetaColumns = []string{"id", "name", "x", "y", "continent", "profile_url"} + villageOrders = []string{"village.server_key ASC", "village.id ASC"} ) type Village struct { @@ -51,30 +52,49 @@ func (v *Village) CreateOrUpdate(ctx context.Context, params ...domain.CreateVil return nil } -func (v *Village) List(ctx context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) { +func (v *Village) List(ctx context.Context, params domain.ListVillagesParams) ([]domain.Village, error) { + var villages []model.Village + + err := v.db.NewSelect(). + Model(&villages). + Order(villageOrders...). + Apply(listVillagesParamsApplier{params}.apply). + Scan(ctx) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("couldn't select villages from the db: %w", err) + } + + result := make([]domain.Village, 0, len(villages)) + for _, village := range villages { + result = append(result, village.ToDomain()) + } + + return result, nil +} + +func (v *Village) ListCountWithRelations(ctx context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) { var villages []model.Village - var count int - var err error paramsApplier := listVillagesParamsApplier{params} base := v.db.NewSelect(). Model(&model.Village{}). - Order("village.server_key ASC", "village.id ASC"). - Apply(paramsApplier.applyFilters). - Apply(paramsApplier.applyPagination) + Apply(paramsApplier.applyFilters) q := v.db.NewSelect(). With("villages_base", base). Model(&villages). ModelTableExpr("villages_base AS village"). - Apply(paramsApplier.applyRelations) + Order(villageOrders...). + Apply(paramsApplier.applyPagination). + Relation("Player", func(q *bun.SelectQuery) *bun.SelectQuery { + return q.Column(playerMetaColumns...) + }). + Relation("Player.Tribe", func(q *bun.SelectQuery) *bun.SelectQuery { + return q.Column(tribeMetaColumns...) + }) - if params.Count { - count, err = scanAndCount(ctx, base, q) - } else { - err = q.Scan(ctx) - } + count, err := scanAndCount(ctx, base, q) if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, 0, fmt.Errorf("couldn't select villages from the db: %w", err) } @@ -91,6 +111,10 @@ type listVillagesParamsApplier struct { params domain.ListVillagesParams } +func (l listVillagesParamsApplier) apply(q *bun.SelectQuery) *bun.SelectQuery { + return l.applyPagination(l.applyFilters(q)) +} + func (l listVillagesParamsApplier) applyFilters(q *bun.SelectQuery) *bun.SelectQuery { if l.params.IDs != nil { q = q.Where("village.id IN (?)", bun.In(l.params.IDs)) @@ -103,22 +127,6 @@ func (l listVillagesParamsApplier) applyFilters(q *bun.SelectQuery) *bun.SelectQ return q } -func (l listVillagesParamsApplier) applyRelations(q *bun.SelectQuery) *bun.SelectQuery { - if l.params.IncludePlayer { - q = q.Relation("Player", func(q *bun.SelectQuery) *bun.SelectQuery { - return q.Column(playerMetaColumns...) - }) - } - - if l.params.IncludePlayerTribe { - q = q.Relation("Player.Tribe", func(q *bun.SelectQuery) *bun.SelectQuery { - return q.Column(tribeMetaColumns...) - }) - } - - return q -} - func (l listVillagesParamsApplier) applyPagination(q *bun.SelectQuery) *bun.SelectQuery { return (paginationApplier{pagination: l.params.Pagination}).apply(q) } diff --git a/internal/bundb/village_test.go b/internal/bundb/village_test.go index 183689f..5ab7e10 100644 --- a/internal/bundb/village_test.go +++ b/internal/bundb/village_test.go @@ -84,12 +84,12 @@ func TestVillage_CreateOrUpdate(t *testing.T) { ctx := context.Background() assert.NoError(t, repo.CreateOrUpdate(ctx, createParams...)) - createdVillages, _, err := repo.List(ctx, listParams) + createdVillages, err := repo.List(ctx, listParams) assert.NoError(t, err) assertCreatedVillages(t, createParams, createdVillages) assert.NoError(t, repo.CreateOrUpdate(ctx, updateParams...)) - updatedVillages, _, err := repo.List(ctx, listParams) + updatedVillages, err := repo.List(ctx, listParams) assert.NoError(t, err) assertCreatedVillages(t, updateParams, updatedVillages) }) @@ -121,7 +121,7 @@ func TestVillage_CreateOrUpdate(t *testing.T) { }) } -func TestVillage_List(t *testing.T) { +func TestVillage_List_ListCountWithRelations(t *testing.T) { t.Parallel() db := newDB(t) @@ -132,8 +132,6 @@ func TestVillage_List(t *testing.T) { type expectedVillage struct { id int64 serverKey string - playerID int64 - tribeID int64 } allVillages := make([]expectedVillage, 0, len(villages)) @@ -164,26 +162,20 @@ func TestVillage_List(t *testing.T) { expectedCount int64 }{ { - name: "Count=true", - params: domain.ListVillagesParams{Count: true}, + name: "Empty struct", + params: domain.ListVillagesParams{}, expectedVillages: allVillages, expectedCount: int64(len(allVillages)), }, { - name: "Count=false", - params: domain.ListVillagesParams{Count: false}, - expectedVillages: allVillages, - expectedCount: 0, - }, - { - name: "ServerKeys=[it70],Count=true", - params: domain.ListVillagesParams{ServerKeys: []string{"it70"}, Count: true}, + name: "ServerKeys=[it70]", + params: domain.ListVillagesParams{ServerKeys: []string{"it70"}}, expectedVillages: villagesIT70, expectedCount: int64(len(villagesIT70)), }, { - name: "IDs=[10022, 1114],Count=true", - params: domain.ListVillagesParams{IDs: []int64{10022, 1114}, Count: true}, + name: "IDs=[10022, 1114]", + params: domain.ListVillagesParams{IDs: []int64{10022, 1114}}, expectedVillages: []expectedVillage{ { id: 10022, @@ -197,8 +189,8 @@ func TestVillage_List(t *testing.T) { expectedCount: 2, }, { - name: "IDs=[1113, 1114],ServerKeys=[pl169],Count=true", - params: domain.ListVillagesParams{IDs: []int64{1113, 1114}, ServerKeys: []string{"pl169"}, Count: true}, + name: "IDs=[1113, 1114],ServerKeys=[pl169]", + params: domain.ListVillagesParams{IDs: []int64{1113, 1114}, ServerKeys: []string{"pl169"}}, expectedVillages: []expectedVillage{ { id: 1113, @@ -212,105 +204,13 @@ func TestVillage_List(t *testing.T) { expectedCount: 2, }, { - name: "IDs=[10022, 10023, 10024],ServerKeys=[it70],IncludePlayer=true,Count=true", - params: domain.ListVillagesParams{ - IDs: []int64{10022, 10023, 10024}, - ServerKeys: []string{"it70"}, - IncludePlayer: true, - Count: true, - }, - expectedVillages: []expectedVillage{ - { - id: 10022, - serverKey: "it70", - playerID: 578014, - tribeID: 0, - }, - { - id: 10023, - serverKey: "it70", - playerID: 578014, - tribeID: 0, - }, - { - id: 10024, - serverKey: "it70", - playerID: 0, - tribeID: 0, - }, - }, - expectedCount: 3, - }, - { - name: "IDs=[10022, 10023, 10024],ServerKeys=[it70],IncludePlayerTribe=true,Count=true", - params: domain.ListVillagesParams{ - IDs: []int64{10022, 10023, 10024}, - ServerKeys: []string{"it70"}, - IncludePlayerTribe: true, - Count: true, - }, - expectedVillages: []expectedVillage{ - { - id: 10022, - serverKey: "it70", - playerID: 578014, - tribeID: 1, - }, - { - id: 10023, - serverKey: "it70", - playerID: 578014, - tribeID: 1, - }, - { - id: 10024, - serverKey: "it70", - playerID: 0, - tribeID: 0, - }, - }, - expectedCount: 3, - }, - { - name: "IDs=[10022, 10023, 10024],ServerKeys=[it70],IncludePlayer=true,IncludePlayerTribe=true,Count=true", - params: domain.ListVillagesParams{ - IDs: []int64{10022, 10023, 10024}, - ServerKeys: []string{"it70"}, - IncludePlayer: true, - IncludePlayerTribe: true, - Count: true, - }, - expectedVillages: []expectedVillage{ - { - id: 10022, - serverKey: "it70", - playerID: 578014, - tribeID: 1, - }, - { - id: 10023, - serverKey: "it70", - playerID: 578014, - tribeID: 1, - }, - { - id: 10024, - serverKey: "it70", - playerID: 0, - tribeID: 0, - }, - }, - expectedCount: 3, - }, - { - name: "IDs=[10022, 10023, 10024],ServerKeys=[it70],Limit=1,Count=true", + name: "IDs=[10022, 10023, 10024],ServerKeys=[it70],Limit=1", params: domain.ListVillagesParams{ IDs: []int64{10022, 10023, 10024}, ServerKeys: []string{"it70"}, Pagination: domain.Pagination{ Limit: 1, }, - Count: true, }, expectedVillages: []expectedVillage{ { @@ -321,7 +221,7 @@ func TestVillage_List(t *testing.T) { expectedCount: 3, }, { - name: "IDs=[10022, 10023, 10024],ServerKeys=[it70],Limit=100,Offset=1,Count=true", + name: "IDs=[10022, 10023, 10024],ServerKeys=[it70],Limit=100,Offset=1", params: domain.ListVillagesParams{ IDs: []int64{10022, 10023, 10024}, ServerKeys: []string{"it70"}, @@ -329,7 +229,6 @@ func TestVillage_List(t *testing.T) { Limit: 100, Offset: 1, }, - Count: true, }, expectedVillages: []expectedVillage{ { @@ -351,14 +250,14 @@ func TestVillage_List(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - res, count, err := repo.List(context.Background(), tt.params) + resListCountWithRelations, count, err := repo.ListCountWithRelations(context.Background(), tt.params) assert.NoError(t, err) assert.Equal(t, tt.expectedCount, count) - assert.Len(t, res, len(tt.expectedVillages)) + assert.Len(t, resListCountWithRelations, len(tt.expectedVillages)) for _, expVillage := range tt.expectedVillages { found := false - for _, village := range res { + for _, village := range resListCountWithRelations { if village.ID != expVillage.id { continue } @@ -367,11 +266,19 @@ func TestVillage_List(t *testing.T) { continue } - if village.Player.Player.ID != expVillage.playerID { + if village.PlayerID != 0 && (!village.Player.Valid || village.Player.Player.ID != village.PlayerID) { continue } - if village.Player.Player.Tribe.Tribe.ID != expVillage.tribeID { + if village.Player.Player.Tribe.Valid && village.Player.Player.Tribe.Tribe.ID == 0 { + continue + } + + if !village.Player.Player.Tribe.Valid && village.Player.Player.Tribe.Tribe.ID != 0 { + continue + } + + if village.PlayerID == 0 && (village.Player.Valid || village.Player.Player.ID != village.PlayerID) { continue } @@ -379,23 +286,25 @@ func TestVillage_List(t *testing.T) { break } - assert.True(t, found, "village (id=%d,playerID=%d,tribeID=%d,serverkey=%s) not found", - expVillage.id, - expVillage.playerID, - expVillage.tribeID, - expVillage.serverKey, - ) + assert.True(t, found, "village (id=%d,serverkey=%s) not found", expVillage.id, expVillage.serverKey) + } + + resList, err := repo.List(context.Background(), tt.params) + assert.NoError(t, err) + assert.Len(t, resList, len(resListCountWithRelations)) + for i, village := range resList { + assert.Equal(t, resListCountWithRelations[i].Village, village) } }) } } -func assertCreatedVillages(tb testing.TB, params []domain.CreateVillageParams, villages []domain.VillageWithRelations) { +func assertCreatedVillages(tb testing.TB, params []domain.CreateVillageParams, villages []domain.Village) { tb.Helper() assert.Len(tb, villages, len(params)) for _, p := range params { - var village domain.VillageWithRelations + var village domain.Village for _, v := range villages { if v.ID == p.ID && v.ServerKey == p.ServerKey { village = v diff --git a/internal/domain/village.go b/internal/domain/village.go index b375c36..4297816 100644 --- a/internal/domain/village.go +++ b/internal/domain/village.go @@ -61,12 +61,9 @@ type CreateVillageParams struct { } type ListVillagesParams struct { - IDs []int64 - ServerKeys []string - IncludePlayer bool // IncludePlayer doesn't include tribe player is in. Check IncludePlayerTribe for that. - IncludePlayerTribe bool // IncludePlayerTribe also automatically includes player, even if IncludePlayer == false. - Pagination Pagination - Count bool + IDs []int64 + ServerKeys []string + Pagination Pagination } type RefreshVillagesResult struct { diff --git a/internal/router/rest/village.go b/internal/router/rest/village.go index 6a9670e..97cc5e4 100644 --- a/internal/router/rest/village.go +++ b/internal/router/rest/village.go @@ -17,13 +17,8 @@ const ( //counterfeiter:generate -o internal/mock/village_service.gen.go . VillageService type VillageService interface { - List(ctx context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) - GetByServerKeyAndID( - ctx context.Context, - serverKey string, - id int64, - includePlayer, includePlayerTribe bool, - ) (domain.VillageWithRelations, error) + ListCountWithRelations(ctx context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) + GetByServerKeyAndIDWithRelations(ctx context.Context, serverKey string, id int64) (domain.VillageWithRelations, error) } type village struct { @@ -49,10 +44,7 @@ func (v *village) list(w http.ResponseWriter, r *http.Request) { var err error ctx := r.Context() params := domain.ListVillagesParams{ - ServerKeys: []string{chi.URLParamFromCtx(ctx, "serverKey")}, - IncludePlayer: true, - IncludePlayerTribe: true, - Count: true, + ServerKeys: []string{chi.URLParamFromCtx(ctx, "serverKey")}, } query := queryParams{r.URL.Query()} @@ -62,7 +54,7 @@ func (v *village) list(w http.ResponseWriter, r *http.Request) { return } - villages, count, err := v.svc.List(ctx, params) + villages, count, err := v.svc.ListCountWithRelations(ctx, params) if err != nil { renderErr(w, err) return @@ -95,13 +87,7 @@ func (v *village) getByID(w http.ResponseWriter, r *http.Request) { return } - vlg, err := v.svc.GetByServerKeyAndID( - ctx, - routeCtx.URLParam("serverKey"), - villageID, - true, - true, - ) + vlg, err := v.svc.GetByServerKeyAndIDWithRelations(ctx, routeCtx.URLParam("serverKey"), villageID) if err != nil { renderErr(w, err) return diff --git a/internal/router/rest/village_test.go b/internal/router/rest/village_test.go index 410d3a8..f5f59ff 100644 --- a/internal/router/rest/village_test.go +++ b/internal/router/rest/village_test.go @@ -50,16 +50,13 @@ func TestVillage_list(t *testing.T) { setup: func(versionSvc *mock.FakeVersionService, serverSvc *mock.FakeServerService, villageSvc *mock.FakeVillageService) { versionSvc.GetByCodeReturns(version, nil) serverSvc.GetNormalByVersionCodeAndKeyReturns(server, nil) - villageSvc.ListCalls(func(_ context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) { + villageSvc.ListCountWithRelationsCalls(func(_ context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) { expectedParams := domain.ListVillagesParams{ - ServerKeys: []string{server.Key}, - IncludePlayer: true, - IncludePlayerTribe: true, + ServerKeys: []string{server.Key}, Pagination: domain.Pagination{ Limit: 500, Offset: 0, }, - Count: true, } if diff := cmp.Diff(params, expectedParams); diff != "" { @@ -181,19 +178,16 @@ func TestVillage_list(t *testing.T) { setup: func(versionSvc *mock.FakeVersionService, serverSvc *mock.FakeServerService, villageSvc *mock.FakeVillageService) { versionSvc.GetByCodeReturns(version, nil) serverSvc.GetNormalByVersionCodeAndKeyReturns(server, nil) - villageSvc.ListCalls(func( + villageSvc.ListCountWithRelationsCalls(func( _ context.Context, params domain.ListVillagesParams, ) ([]domain.VillageWithRelations, int64, error) { expectedParams := domain.ListVillagesParams{ - ServerKeys: []string{server.Key}, - IncludePlayer: true, - IncludePlayerTribe: true, + ServerKeys: []string{server.Key}, Pagination: domain.Pagination{ Limit: 1, Offset: 15, }, - Count: true, } if diff := cmp.Diff(params, expectedParams); diff != "" { @@ -393,7 +387,7 @@ func TestVillage_getByID(t *testing.T) { setup: func(versionSvc *mock.FakeVersionService, serverSvc *mock.FakeServerService, villageSvc *mock.FakeVillageService) { versionSvc.GetByCodeReturns(version, nil) serverSvc.GetNormalByVersionCodeAndKeyReturns(server, nil) - villageSvc.GetByServerKeyAndIDReturns(domain.VillageWithRelations{ + villageSvc.GetByServerKeyAndIDWithRelationsReturns(domain.VillageWithRelations{ Village: domain.Village{ ID: 1234, Name: "name", @@ -523,7 +517,7 @@ func TestVillage_getByID(t *testing.T) { setup: func(versionSvc *mock.FakeVersionService, serverSvc *mock.FakeServerService, villageSvc *mock.FakeVillageService) { versionSvc.GetByCodeReturns(version, nil) serverSvc.GetNormalByVersionCodeAndKeyReturns(server, nil) - villageSvc.GetByServerKeyAndIDReturns(domain.VillageWithRelations{}, domain.VillageNotFoundError{ID: 12345551}) + villageSvc.GetByServerKeyAndIDWithRelationsReturns(domain.VillageWithRelations{}, domain.VillageNotFoundError{ID: 12345551}) }, versionCode: version.Code, serverKey: server.Key, diff --git a/internal/service/village.go b/internal/service/village.go index 208173e..8f2be62 100644 --- a/internal/service/village.go +++ b/internal/service/village.go @@ -16,7 +16,8 @@ const ( //counterfeiter:generate -o internal/mock/village_repository.gen.go . VillageRepository type VillageRepository interface { CreateOrUpdate(ctx context.Context, params ...domain.CreateVillageParams) error - List(ctx context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) + List(ctx context.Context, params domain.ListVillagesParams) ([]domain.Village, error) + ListCountWithRelations(ctx context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) } //counterfeiter:generate -o internal/mock/village_getter.gen.go . VillageGetter @@ -82,7 +83,7 @@ func (v *Village) Refresh(ctx context.Context, key, url string) (domain.RefreshV return res, nil } -func (v *Village) List(ctx context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) { +func (v *Village) ListCountWithRelations(ctx context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) { if params.Pagination.Limit == 0 { params.Pagination.Limit = villageMaxLimit } @@ -91,25 +92,18 @@ func (v *Village) List(ctx context.Context, params domain.ListVillagesParams) ([ return nil, 0, fmt.Errorf("validatePagination: %w", err) } - villages, count, err := v.repo.List(ctx, params) + villages, count, err := v.repo.ListCountWithRelations(ctx, params) if err != nil { - return nil, 0, fmt.Errorf("VillageRepository.List: %w", err) + return nil, 0, fmt.Errorf("VillageRepository.ListCountWithRelations: %w", err) } return villages, count, nil } -func (v *Village) GetByServerKeyAndID( - ctx context.Context, - serverKey string, - id int64, - includePlayer, includePlayerTribe bool, -) (domain.VillageWithRelations, error) { - villages, _, err := v.repo.List(ctx, domain.ListVillagesParams{ - IDs: []int64{id}, - ServerKeys: []string{serverKey}, - IncludePlayer: includePlayer, - IncludePlayerTribe: includePlayerTribe, +func (v *Village) GetByServerKeyAndIDWithRelations(ctx context.Context, serverKey string, id int64) (domain.VillageWithRelations, error) { + villages, _, err := v.repo.ListCountWithRelations(ctx, domain.ListVillagesParams{ + IDs: []int64{id}, + ServerKeys: []string{serverKey}, Pagination: domain.Pagination{ Limit: 1, }, diff --git a/internal/service/village_test.go b/internal/service/village_test.go index 92fa4e9..127a36d 100644 --- a/internal/service/village_test.go +++ b/internal/service/village_test.go @@ -98,7 +98,7 @@ func TestVillage_Refresh(t *testing.T) { } } -func TestVillage_List(t *testing.T) { +func TestVillage_ListCountWithRelations(t *testing.T) { t.Parallel() t.Run("OK", func(t *testing.T) { @@ -130,7 +130,7 @@ func TestVillage_List(t *testing.T) { } repo := &mock.FakeVillageRepository{} - repo.ListCalls(func( + repo.ListCountWithRelationsCalls(func( _ context.Context, params domain.ListVillagesParams, ) ([]domain.VillageWithRelations, int64, error) { @@ -149,7 +149,7 @@ func TestVillage_List(t *testing.T) { client := &mock.FakeVillageGetter{} players, count, err := service.NewVillage(repo, client). - List(context.Background(), domain.ListVillagesParams{ + ListCountWithRelations(context.Background(), domain.ListVillagesParams{ Pagination: domain.Pagination{ Limit: tt.limit, }, @@ -223,11 +223,11 @@ func TestVillage_List(t *testing.T) { client := &mock.FakeVillageGetter{} players, count, err := service.NewVillage(repo, client). - List(context.Background(), tt.params) + ListCountWithRelations(context.Background(), tt.params) assert.ErrorIs(t, err, tt.expectedErr) assert.Zero(t, players) assert.Zero(t, count) - assert.Equal(t, 0, repo.ListCallCount()) + assert.Equal(t, 0, repo.ListCountWithRelationsCallCount()) }) } }) @@ -261,20 +261,17 @@ func TestVillage_GetByServerKeyAndID(t *testing.T) { }, } repo := &mock.FakeVillageRepository{} - repo.ListReturns([]domain.VillageWithRelations{village}, 0, nil) + repo.ListCountWithRelationsReturns([]domain.VillageWithRelations{village}, 1, nil) client := &mock.FakeVillageGetter{} - res, err := service.NewVillage(repo, client). - GetByServerKeyAndID(context.Background(), village.ServerKey, village.ID, true, true) + res, err := service.NewVillage(repo, client).GetByServerKeyAndIDWithRelations(context.Background(), village.ServerKey, village.ID) assert.NoError(t, err) assert.Equal(t, village, res) - require.Equal(t, 1, repo.ListCallCount()) - _, params := repo.ListArgsForCall(0) + require.Equal(t, 1, repo.ListCountWithRelationsCallCount()) + _, params := repo.ListCountWithRelationsArgsForCall(0) assert.Equal(t, domain.ListVillagesParams{ - IDs: []int64{village.ID}, - ServerKeys: []string{village.ServerKey}, - IncludePlayer: true, - IncludePlayerTribe: true, + IDs: []int64{village.ID}, + ServerKeys: []string{village.ServerKey}, Pagination: domain.Pagination{ Limit: 1, }, @@ -289,8 +286,7 @@ func TestVillage_GetByServerKeyAndID(t *testing.T) { client := &mock.FakeVillageGetter{} var id int64 = 123 - res, err := service.NewVillage(repo, client). - GetByServerKeyAndID(context.Background(), "pl151", id, false, false) + res, err := service.NewVillage(repo, client).GetByServerKeyAndIDWithRelations(context.Background(), "pl151", id) assert.ErrorIs(t, err, domain.VillageNotFoundError{ID: id}) assert.Zero(t, res) })