refactor: village - split the List method into two (#145)
All checks were successful
continuous-integration/drone/push Build is passing

Reviewed-on: twhelp/core#145
This commit is contained in:
Dawid Wysokiński 2022-12-24 07:44:25 +00:00
parent 6c8b30d061
commit f636d6ac94
14 changed files with 149 additions and 271 deletions

View File

@ -56,19 +56,6 @@ func scanAndCount(
var count int var count int
var err error var err error
wg.Add(1)
go func() {
defer wg.Done()
if scanErr := scanQ.Scan(ctx); scanErr == nil {
mu.Lock()
if err == nil {
err = scanErr
}
mu.Unlock()
}
}()
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
@ -83,6 +70,19 @@ func scanAndCount(
} }
}() }()
wg.Add(1)
go func() {
defer wg.Done()
if scanErr := scanQ.Scan(ctx); scanErr == nil {
mu.Lock()
if err == nil {
err = scanErr
}
mu.Unlock()
}
}()
wg.Wait() wg.Wait()
return count, err return count, err

View File

@ -43,7 +43,6 @@ func (e *Ennoblement) List(
ctx context.Context, ctx context.Context,
params domain.ListEnnoblementsParams, params domain.ListEnnoblementsParams,
) ([]domain.EnnoblementWithRelations, int64, error) { ) ([]domain.EnnoblementWithRelations, int64, error) {
var ennoblements []model.Ennoblement var ennoblements []model.Ennoblement
var count int var count int
var err error var err error
@ -53,18 +52,18 @@ func (e *Ennoblement) List(
base := e.db.NewSelect(). base := e.db.NewSelect().
Model(&model.Ennoblement{}). Model(&model.Ennoblement{}).
Order("ennoblement.server_key ASC"). Order("ennoblement.server_key ASC").
Apply(paramsApplier.applyFilters). Apply(paramsApplier.applyFilters)
Apply(paramsApplier.applyPagination)
base, err = paramsApplier.applySort(base)
if err != nil {
return nil, 0, fmt.Errorf("listEnnoblementsParamsApplier.applySort: %w", err)
}
q := e.db.NewSelect(). q := e.db.NewSelect().
With("ennoblements_base", base). With("ennoblements_base", base).
Model(&ennoblements). Model(&ennoblements).
ModelTableExpr("ennoblements_base AS ennoblement"). ModelTableExpr("ennoblements_base AS ennoblement").
Apply(paramsApplier.applyRelations) Apply(paramsApplier.applyRelations).
Apply(paramsApplier.applyPagination)
q, err = paramsApplier.applySort(q)
if err != nil {
return nil, 0, err
}
if params.Count { if params.Count {
count, err = scanAndCount(ctx, base, q) count, err = scanAndCount(ctx, base, q)

View File

@ -14,6 +14,7 @@ import (
var ( var (
playerMetaColumns = []string{"id", "name", "profile_url"} playerMetaColumns = []string{"id", "name", "profile_url"}
playerOrders = []string{"player.server_key ASC"}
) )
type Player struct { type Player struct {
@ -82,13 +83,15 @@ func (p *Player) Delete(ctx context.Context, serverKey string, ids ...int64) err
func (p *Player) List(ctx context.Context, params domain.ListPlayersParams) ([]domain.Player, error) { func (p *Player) List(ctx context.Context, params domain.ListPlayersParams) ([]domain.Player, error) {
var players []model.Player var players []model.Player
q, err := p.baseListQuery(&players, params) q := p.db.NewSelect().
Model(&players).
Order(playerOrders...)
q, err := listPlayersParamsApplier{params: params}.apply(q)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = q.Scan(ctx) if err = q.Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("couldn't select players from the db: %w", err) return nil, fmt.Errorf("couldn't select players from the db: %w", err)
} }
@ -103,19 +106,27 @@ func (p *Player) List(ctx context.Context, params domain.ListPlayersParams) ([]d
func (p *Player) ListCountWithRelations(ctx context.Context, params domain.ListPlayersParams) ([]domain.PlayerWithRelations, int64, error) { func (p *Player) ListCountWithRelations(ctx context.Context, params domain.ListPlayersParams) ([]domain.PlayerWithRelations, int64, error) {
var players []model.Player var players []model.Player
base, err := p.baseListQuery(&model.Player{}, params) paramsApplier := listPlayersParamsApplier{params}
if err != nil {
return nil, 0, err base := p.db.NewSelect().
} Model(&model.Player{}).
Apply(paramsApplier.applyFilters)
q := p.db.NewSelect(). q := p.db.NewSelect().
With("players_base", base). With("players_base", base).
Model(&players). Model(&players).
Order(playerOrders...).
Apply(paramsApplier.applyPagination).
ModelTableExpr("players_base AS player"). ModelTableExpr("players_base AS player").
Relation("Tribe", func(q *bun.SelectQuery) *bun.SelectQuery { Relation("Tribe", func(q *bun.SelectQuery) *bun.SelectQuery {
return q.Column(tribeMetaColumns...) return q.Column(tribeMetaColumns...)
}) })
q, err := paramsApplier.applySort(q)
if err != nil {
return nil, 0, err
}
count, err := scanAndCount(ctx, base, q) count, err := scanAndCount(ctx, base, q)
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, 0, fmt.Errorf("couldn't select players from the db: %w", err) return nil, 0, fmt.Errorf("couldn't select players from the db: %w", err)
@ -129,19 +140,6 @@ func (p *Player) ListCountWithRelations(ctx context.Context, params domain.ListP
return result, int64(count), nil return result, int64(count), nil
} }
func (p *Player) baseListQuery(model any, params domain.ListPlayersParams) (*bun.SelectQuery, error) {
q := p.db.NewSelect().
Model(model).
Order("player.server_key ASC")
q, err := listPlayersParamsApplier{params}.apply(q)
if err != nil {
return nil, fmt.Errorf("listPlayersParamsApplier.apply: %w", err)
}
return q, nil
}
type listPlayersParamsApplier struct { type listPlayersParamsApplier struct {
params domain.ListPlayersParams params domain.ListPlayersParams
} }

View File

@ -600,7 +600,7 @@ func TestPlayer_List_ListCountWithRelations(t *testing.T) {
continue continue
} }
if player.TribeID > 0 && (!player.Tribe.Valid || player.Tribe.Tribe.ID != player.TribeID) { if player.TribeID != 0 && (!player.Tribe.Valid || player.Tribe.Tribe.ID != player.TribeID) {
continue continue
} }

View File

@ -58,8 +58,7 @@ func (s *Server) CreateOrUpdate(ctx context.Context, params ...domain.CreateServ
func (s *Server) List(ctx context.Context, params domain.ListServersParams) ([]domain.Server, error) { func (s *Server) List(ctx context.Context, params domain.ListServersParams) ([]domain.Server, error) {
var servers []model.Server var servers []model.Server
err := s.baseListQuery(&servers, params).Scan(ctx) if err := s.baseListQuery(&servers, params).Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("couldn't select servers from the db: %w", err) return nil, fmt.Errorf("couldn't select servers from the db: %w", err)
} }

View File

@ -105,8 +105,7 @@ func (t *Tribe) List(ctx context.Context, params domain.ListTribesParams) ([]dom
return nil, err return nil, err
} }
err = q.Scan(ctx) if err = q.Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("couldn't select tribes from the db: %w", err) return nil, fmt.Errorf("couldn't select tribes from the db: %w", err)
} }

View File

@ -22,8 +22,7 @@ func NewVersion(db *bun.DB) *Version {
func (v *Version) List(ctx context.Context) ([]domain.Version, error) { func (v *Version) List(ctx context.Context) ([]domain.Version, error) {
var versions []model.Version var versions []model.Version
err := v.baseListQuery(&versions).Scan(ctx) if err := v.baseListQuery(&versions).Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("couldn't select versions from the db: %w", err) return nil, fmt.Errorf("couldn't select versions from the db: %w", err)
} }

View File

@ -13,6 +13,7 @@ import (
var ( var (
villageMetaColumns = []string{"id", "name", "x", "y", "continent", "profile_url"} villageMetaColumns = []string{"id", "name", "x", "y", "continent", "profile_url"}
villageOrders = []string{"village.server_key ASC", "village.id ASC"}
) )
type Village struct { type Village struct {
@ -51,30 +52,49 @@ func (v *Village) CreateOrUpdate(ctx context.Context, params ...domain.CreateVil
return nil return nil
} }
func (v *Village) List(ctx context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) { func (v *Village) List(ctx context.Context, params domain.ListVillagesParams) ([]domain.Village, error) {
var villages []model.Village
err := v.db.NewSelect().
Model(&villages).
Order(villageOrders...).
Apply(listVillagesParamsApplier{params}.apply).
Scan(ctx)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("couldn't select villages from the db: %w", err)
}
result := make([]domain.Village, 0, len(villages))
for _, village := range villages {
result = append(result, village.ToDomain())
}
return result, nil
}
func (v *Village) ListCountWithRelations(ctx context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) {
var villages []model.Village var villages []model.Village
var count int
var err error
paramsApplier := listVillagesParamsApplier{params} paramsApplier := listVillagesParamsApplier{params}
base := v.db.NewSelect(). base := v.db.NewSelect().
Model(&model.Village{}). Model(&model.Village{}).
Order("village.server_key ASC", "village.id ASC"). Apply(paramsApplier.applyFilters)
Apply(paramsApplier.applyFilters).
Apply(paramsApplier.applyPagination)
q := v.db.NewSelect(). q := v.db.NewSelect().
With("villages_base", base). With("villages_base", base).
Model(&villages). Model(&villages).
ModelTableExpr("villages_base AS village"). ModelTableExpr("villages_base AS village").
Apply(paramsApplier.applyRelations) Order(villageOrders...).
Apply(paramsApplier.applyPagination).
Relation("Player", func(q *bun.SelectQuery) *bun.SelectQuery {
return q.Column(playerMetaColumns...)
}).
Relation("Player.Tribe", func(q *bun.SelectQuery) *bun.SelectQuery {
return q.Column(tribeMetaColumns...)
})
if params.Count { count, err := scanAndCount(ctx, base, q)
count, err = scanAndCount(ctx, base, q)
} else {
err = q.Scan(ctx)
}
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, 0, fmt.Errorf("couldn't select villages from the db: %w", err) return nil, 0, fmt.Errorf("couldn't select villages from the db: %w", err)
} }
@ -91,6 +111,10 @@ type listVillagesParamsApplier struct {
params domain.ListVillagesParams params domain.ListVillagesParams
} }
func (l listVillagesParamsApplier) apply(q *bun.SelectQuery) *bun.SelectQuery {
return l.applyPagination(l.applyFilters(q))
}
func (l listVillagesParamsApplier) applyFilters(q *bun.SelectQuery) *bun.SelectQuery { func (l listVillagesParamsApplier) applyFilters(q *bun.SelectQuery) *bun.SelectQuery {
if l.params.IDs != nil { if l.params.IDs != nil {
q = q.Where("village.id IN (?)", bun.In(l.params.IDs)) q = q.Where("village.id IN (?)", bun.In(l.params.IDs))
@ -103,22 +127,6 @@ func (l listVillagesParamsApplier) applyFilters(q *bun.SelectQuery) *bun.SelectQ
return q return q
} }
func (l listVillagesParamsApplier) applyRelations(q *bun.SelectQuery) *bun.SelectQuery {
if l.params.IncludePlayer {
q = q.Relation("Player", func(q *bun.SelectQuery) *bun.SelectQuery {
return q.Column(playerMetaColumns...)
})
}
if l.params.IncludePlayerTribe {
q = q.Relation("Player.Tribe", func(q *bun.SelectQuery) *bun.SelectQuery {
return q.Column(tribeMetaColumns...)
})
}
return q
}
func (l listVillagesParamsApplier) applyPagination(q *bun.SelectQuery) *bun.SelectQuery { func (l listVillagesParamsApplier) applyPagination(q *bun.SelectQuery) *bun.SelectQuery {
return (paginationApplier{pagination: l.params.Pagination}).apply(q) return (paginationApplier{pagination: l.params.Pagination}).apply(q)
} }

View File

@ -84,12 +84,12 @@ func TestVillage_CreateOrUpdate(t *testing.T) {
ctx := context.Background() ctx := context.Background()
assert.NoError(t, repo.CreateOrUpdate(ctx, createParams...)) assert.NoError(t, repo.CreateOrUpdate(ctx, createParams...))
createdVillages, _, err := repo.List(ctx, listParams) createdVillages, err := repo.List(ctx, listParams)
assert.NoError(t, err) assert.NoError(t, err)
assertCreatedVillages(t, createParams, createdVillages) assertCreatedVillages(t, createParams, createdVillages)
assert.NoError(t, repo.CreateOrUpdate(ctx, updateParams...)) assert.NoError(t, repo.CreateOrUpdate(ctx, updateParams...))
updatedVillages, _, err := repo.List(ctx, listParams) updatedVillages, err := repo.List(ctx, listParams)
assert.NoError(t, err) assert.NoError(t, err)
assertCreatedVillages(t, updateParams, updatedVillages) assertCreatedVillages(t, updateParams, updatedVillages)
}) })
@ -121,7 +121,7 @@ func TestVillage_CreateOrUpdate(t *testing.T) {
}) })
} }
func TestVillage_List(t *testing.T) { func TestVillage_List_ListCountWithRelations(t *testing.T) {
t.Parallel() t.Parallel()
db := newDB(t) db := newDB(t)
@ -132,8 +132,6 @@ func TestVillage_List(t *testing.T) {
type expectedVillage struct { type expectedVillage struct {
id int64 id int64
serverKey string serverKey string
playerID int64
tribeID int64
} }
allVillages := make([]expectedVillage, 0, len(villages)) allVillages := make([]expectedVillage, 0, len(villages))
@ -164,26 +162,20 @@ func TestVillage_List(t *testing.T) {
expectedCount int64 expectedCount int64
}{ }{
{ {
name: "Count=true", name: "Empty struct",
params: domain.ListVillagesParams{Count: true}, params: domain.ListVillagesParams{},
expectedVillages: allVillages, expectedVillages: allVillages,
expectedCount: int64(len(allVillages)), expectedCount: int64(len(allVillages)),
}, },
{ {
name: "Count=false", name: "ServerKeys=[it70]",
params: domain.ListVillagesParams{Count: false}, params: domain.ListVillagesParams{ServerKeys: []string{"it70"}},
expectedVillages: allVillages,
expectedCount: 0,
},
{
name: "ServerKeys=[it70],Count=true",
params: domain.ListVillagesParams{ServerKeys: []string{"it70"}, Count: true},
expectedVillages: villagesIT70, expectedVillages: villagesIT70,
expectedCount: int64(len(villagesIT70)), expectedCount: int64(len(villagesIT70)),
}, },
{ {
name: "IDs=[10022, 1114],Count=true", name: "IDs=[10022, 1114]",
params: domain.ListVillagesParams{IDs: []int64{10022, 1114}, Count: true}, params: domain.ListVillagesParams{IDs: []int64{10022, 1114}},
expectedVillages: []expectedVillage{ expectedVillages: []expectedVillage{
{ {
id: 10022, id: 10022,
@ -197,8 +189,8 @@ func TestVillage_List(t *testing.T) {
expectedCount: 2, expectedCount: 2,
}, },
{ {
name: "IDs=[1113, 1114],ServerKeys=[pl169],Count=true", name: "IDs=[1113, 1114],ServerKeys=[pl169]",
params: domain.ListVillagesParams{IDs: []int64{1113, 1114}, ServerKeys: []string{"pl169"}, Count: true}, params: domain.ListVillagesParams{IDs: []int64{1113, 1114}, ServerKeys: []string{"pl169"}},
expectedVillages: []expectedVillage{ expectedVillages: []expectedVillage{
{ {
id: 1113, id: 1113,
@ -212,105 +204,13 @@ func TestVillage_List(t *testing.T) {
expectedCount: 2, expectedCount: 2,
}, },
{ {
name: "IDs=[10022, 10023, 10024],ServerKeys=[it70],IncludePlayer=true,Count=true", name: "IDs=[10022, 10023, 10024],ServerKeys=[it70],Limit=1",
params: domain.ListVillagesParams{
IDs: []int64{10022, 10023, 10024},
ServerKeys: []string{"it70"},
IncludePlayer: true,
Count: true,
},
expectedVillages: []expectedVillage{
{
id: 10022,
serverKey: "it70",
playerID: 578014,
tribeID: 0,
},
{
id: 10023,
serverKey: "it70",
playerID: 578014,
tribeID: 0,
},
{
id: 10024,
serverKey: "it70",
playerID: 0,
tribeID: 0,
},
},
expectedCount: 3,
},
{
name: "IDs=[10022, 10023, 10024],ServerKeys=[it70],IncludePlayerTribe=true,Count=true",
params: domain.ListVillagesParams{
IDs: []int64{10022, 10023, 10024},
ServerKeys: []string{"it70"},
IncludePlayerTribe: true,
Count: true,
},
expectedVillages: []expectedVillage{
{
id: 10022,
serverKey: "it70",
playerID: 578014,
tribeID: 1,
},
{
id: 10023,
serverKey: "it70",
playerID: 578014,
tribeID: 1,
},
{
id: 10024,
serverKey: "it70",
playerID: 0,
tribeID: 0,
},
},
expectedCount: 3,
},
{
name: "IDs=[10022, 10023, 10024],ServerKeys=[it70],IncludePlayer=true,IncludePlayerTribe=true,Count=true",
params: domain.ListVillagesParams{
IDs: []int64{10022, 10023, 10024},
ServerKeys: []string{"it70"},
IncludePlayer: true,
IncludePlayerTribe: true,
Count: true,
},
expectedVillages: []expectedVillage{
{
id: 10022,
serverKey: "it70",
playerID: 578014,
tribeID: 1,
},
{
id: 10023,
serverKey: "it70",
playerID: 578014,
tribeID: 1,
},
{
id: 10024,
serverKey: "it70",
playerID: 0,
tribeID: 0,
},
},
expectedCount: 3,
},
{
name: "IDs=[10022, 10023, 10024],ServerKeys=[it70],Limit=1,Count=true",
params: domain.ListVillagesParams{ params: domain.ListVillagesParams{
IDs: []int64{10022, 10023, 10024}, IDs: []int64{10022, 10023, 10024},
ServerKeys: []string{"it70"}, ServerKeys: []string{"it70"},
Pagination: domain.Pagination{ Pagination: domain.Pagination{
Limit: 1, Limit: 1,
}, },
Count: true,
}, },
expectedVillages: []expectedVillage{ expectedVillages: []expectedVillage{
{ {
@ -321,7 +221,7 @@ func TestVillage_List(t *testing.T) {
expectedCount: 3, expectedCount: 3,
}, },
{ {
name: "IDs=[10022, 10023, 10024],ServerKeys=[it70],Limit=100,Offset=1,Count=true", name: "IDs=[10022, 10023, 10024],ServerKeys=[it70],Limit=100,Offset=1",
params: domain.ListVillagesParams{ params: domain.ListVillagesParams{
IDs: []int64{10022, 10023, 10024}, IDs: []int64{10022, 10023, 10024},
ServerKeys: []string{"it70"}, ServerKeys: []string{"it70"},
@ -329,7 +229,6 @@ func TestVillage_List(t *testing.T) {
Limit: 100, Limit: 100,
Offset: 1, Offset: 1,
}, },
Count: true,
}, },
expectedVillages: []expectedVillage{ expectedVillages: []expectedVillage{
{ {
@ -351,14 +250,14 @@ func TestVillage_List(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
res, count, err := repo.List(context.Background(), tt.params) resListCountWithRelations, count, err := repo.ListCountWithRelations(context.Background(), tt.params)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, tt.expectedCount, count) assert.Equal(t, tt.expectedCount, count)
assert.Len(t, res, len(tt.expectedVillages)) assert.Len(t, resListCountWithRelations, len(tt.expectedVillages))
for _, expVillage := range tt.expectedVillages { for _, expVillage := range tt.expectedVillages {
found := false found := false
for _, village := range res { for _, village := range resListCountWithRelations {
if village.ID != expVillage.id { if village.ID != expVillage.id {
continue continue
} }
@ -367,11 +266,19 @@ func TestVillage_List(t *testing.T) {
continue continue
} }
if village.Player.Player.ID != expVillage.playerID { if village.PlayerID != 0 && (!village.Player.Valid || village.Player.Player.ID != village.PlayerID) {
continue continue
} }
if village.Player.Player.Tribe.Tribe.ID != expVillage.tribeID { if village.Player.Player.Tribe.Valid && village.Player.Player.Tribe.Tribe.ID == 0 {
continue
}
if !village.Player.Player.Tribe.Valid && village.Player.Player.Tribe.Tribe.ID != 0 {
continue
}
if village.PlayerID == 0 && (village.Player.Valid || village.Player.Player.ID != village.PlayerID) {
continue continue
} }
@ -379,23 +286,25 @@ func TestVillage_List(t *testing.T) {
break break
} }
assert.True(t, found, "village (id=%d,playerID=%d,tribeID=%d,serverkey=%s) not found", assert.True(t, found, "village (id=%d,serverkey=%s) not found", expVillage.id, expVillage.serverKey)
expVillage.id, }
expVillage.playerID,
expVillage.tribeID, resList, err := repo.List(context.Background(), tt.params)
expVillage.serverKey, assert.NoError(t, err)
) assert.Len(t, resList, len(resListCountWithRelations))
for i, village := range resList {
assert.Equal(t, resListCountWithRelations[i].Village, village)
} }
}) })
} }
} }
func assertCreatedVillages(tb testing.TB, params []domain.CreateVillageParams, villages []domain.VillageWithRelations) { func assertCreatedVillages(tb testing.TB, params []domain.CreateVillageParams, villages []domain.Village) {
tb.Helper() tb.Helper()
assert.Len(tb, villages, len(params)) assert.Len(tb, villages, len(params))
for _, p := range params { for _, p := range params {
var village domain.VillageWithRelations var village domain.Village
for _, v := range villages { for _, v := range villages {
if v.ID == p.ID && v.ServerKey == p.ServerKey { if v.ID == p.ID && v.ServerKey == p.ServerKey {
village = v village = v

View File

@ -63,10 +63,7 @@ type CreateVillageParams struct {
type ListVillagesParams struct { type ListVillagesParams struct {
IDs []int64 IDs []int64
ServerKeys []string ServerKeys []string
IncludePlayer bool // IncludePlayer doesn't include tribe player is in. Check IncludePlayerTribe for that.
IncludePlayerTribe bool // IncludePlayerTribe also automatically includes player, even if IncludePlayer == false.
Pagination Pagination Pagination Pagination
Count bool
} }
type RefreshVillagesResult struct { type RefreshVillagesResult struct {

View File

@ -17,13 +17,8 @@ const (
//counterfeiter:generate -o internal/mock/village_service.gen.go . VillageService //counterfeiter:generate -o internal/mock/village_service.gen.go . VillageService
type VillageService interface { type VillageService interface {
List(ctx context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) ListCountWithRelations(ctx context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error)
GetByServerKeyAndID( GetByServerKeyAndIDWithRelations(ctx context.Context, serverKey string, id int64) (domain.VillageWithRelations, error)
ctx context.Context,
serverKey string,
id int64,
includePlayer, includePlayerTribe bool,
) (domain.VillageWithRelations, error)
} }
type village struct { type village struct {
@ -50,9 +45,6 @@ func (v *village) list(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
params := domain.ListVillagesParams{ params := domain.ListVillagesParams{
ServerKeys: []string{chi.URLParamFromCtx(ctx, "serverKey")}, ServerKeys: []string{chi.URLParamFromCtx(ctx, "serverKey")},
IncludePlayer: true,
IncludePlayerTribe: true,
Count: true,
} }
query := queryParams{r.URL.Query()} query := queryParams{r.URL.Query()}
@ -62,7 +54,7 @@ func (v *village) list(w http.ResponseWriter, r *http.Request) {
return return
} }
villages, count, err := v.svc.List(ctx, params) villages, count, err := v.svc.ListCountWithRelations(ctx, params)
if err != nil { if err != nil {
renderErr(w, err) renderErr(w, err)
return return
@ -95,13 +87,7 @@ func (v *village) getByID(w http.ResponseWriter, r *http.Request) {
return return
} }
vlg, err := v.svc.GetByServerKeyAndID( vlg, err := v.svc.GetByServerKeyAndIDWithRelations(ctx, routeCtx.URLParam("serverKey"), villageID)
ctx,
routeCtx.URLParam("serverKey"),
villageID,
true,
true,
)
if err != nil { if err != nil {
renderErr(w, err) renderErr(w, err)
return return

View File

@ -50,16 +50,13 @@ func TestVillage_list(t *testing.T) {
setup: func(versionSvc *mock.FakeVersionService, serverSvc *mock.FakeServerService, villageSvc *mock.FakeVillageService) { setup: func(versionSvc *mock.FakeVersionService, serverSvc *mock.FakeServerService, villageSvc *mock.FakeVillageService) {
versionSvc.GetByCodeReturns(version, nil) versionSvc.GetByCodeReturns(version, nil)
serverSvc.GetNormalByVersionCodeAndKeyReturns(server, nil) serverSvc.GetNormalByVersionCodeAndKeyReturns(server, nil)
villageSvc.ListCalls(func(_ context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) { villageSvc.ListCountWithRelationsCalls(func(_ context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) {
expectedParams := domain.ListVillagesParams{ expectedParams := domain.ListVillagesParams{
ServerKeys: []string{server.Key}, ServerKeys: []string{server.Key},
IncludePlayer: true,
IncludePlayerTribe: true,
Pagination: domain.Pagination{ Pagination: domain.Pagination{
Limit: 500, Limit: 500,
Offset: 0, Offset: 0,
}, },
Count: true,
} }
if diff := cmp.Diff(params, expectedParams); diff != "" { if diff := cmp.Diff(params, expectedParams); diff != "" {
@ -181,19 +178,16 @@ func TestVillage_list(t *testing.T) {
setup: func(versionSvc *mock.FakeVersionService, serverSvc *mock.FakeServerService, villageSvc *mock.FakeVillageService) { setup: func(versionSvc *mock.FakeVersionService, serverSvc *mock.FakeServerService, villageSvc *mock.FakeVillageService) {
versionSvc.GetByCodeReturns(version, nil) versionSvc.GetByCodeReturns(version, nil)
serverSvc.GetNormalByVersionCodeAndKeyReturns(server, nil) serverSvc.GetNormalByVersionCodeAndKeyReturns(server, nil)
villageSvc.ListCalls(func( villageSvc.ListCountWithRelationsCalls(func(
_ context.Context, _ context.Context,
params domain.ListVillagesParams, params domain.ListVillagesParams,
) ([]domain.VillageWithRelations, int64, error) { ) ([]domain.VillageWithRelations, int64, error) {
expectedParams := domain.ListVillagesParams{ expectedParams := domain.ListVillagesParams{
ServerKeys: []string{server.Key}, ServerKeys: []string{server.Key},
IncludePlayer: true,
IncludePlayerTribe: true,
Pagination: domain.Pagination{ Pagination: domain.Pagination{
Limit: 1, Limit: 1,
Offset: 15, Offset: 15,
}, },
Count: true,
} }
if diff := cmp.Diff(params, expectedParams); diff != "" { if diff := cmp.Diff(params, expectedParams); diff != "" {
@ -393,7 +387,7 @@ func TestVillage_getByID(t *testing.T) {
setup: func(versionSvc *mock.FakeVersionService, serverSvc *mock.FakeServerService, villageSvc *mock.FakeVillageService) { setup: func(versionSvc *mock.FakeVersionService, serverSvc *mock.FakeServerService, villageSvc *mock.FakeVillageService) {
versionSvc.GetByCodeReturns(version, nil) versionSvc.GetByCodeReturns(version, nil)
serverSvc.GetNormalByVersionCodeAndKeyReturns(server, nil) serverSvc.GetNormalByVersionCodeAndKeyReturns(server, nil)
villageSvc.GetByServerKeyAndIDReturns(domain.VillageWithRelations{ villageSvc.GetByServerKeyAndIDWithRelationsReturns(domain.VillageWithRelations{
Village: domain.Village{ Village: domain.Village{
ID: 1234, ID: 1234,
Name: "name", Name: "name",
@ -523,7 +517,7 @@ func TestVillage_getByID(t *testing.T) {
setup: func(versionSvc *mock.FakeVersionService, serverSvc *mock.FakeServerService, villageSvc *mock.FakeVillageService) { setup: func(versionSvc *mock.FakeVersionService, serverSvc *mock.FakeServerService, villageSvc *mock.FakeVillageService) {
versionSvc.GetByCodeReturns(version, nil) versionSvc.GetByCodeReturns(version, nil)
serverSvc.GetNormalByVersionCodeAndKeyReturns(server, nil) serverSvc.GetNormalByVersionCodeAndKeyReturns(server, nil)
villageSvc.GetByServerKeyAndIDReturns(domain.VillageWithRelations{}, domain.VillageNotFoundError{ID: 12345551}) villageSvc.GetByServerKeyAndIDWithRelationsReturns(domain.VillageWithRelations{}, domain.VillageNotFoundError{ID: 12345551})
}, },
versionCode: version.Code, versionCode: version.Code,
serverKey: server.Key, serverKey: server.Key,

View File

@ -16,7 +16,8 @@ const (
//counterfeiter:generate -o internal/mock/village_repository.gen.go . VillageRepository //counterfeiter:generate -o internal/mock/village_repository.gen.go . VillageRepository
type VillageRepository interface { type VillageRepository interface {
CreateOrUpdate(ctx context.Context, params ...domain.CreateVillageParams) error CreateOrUpdate(ctx context.Context, params ...domain.CreateVillageParams) error
List(ctx context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) List(ctx context.Context, params domain.ListVillagesParams) ([]domain.Village, error)
ListCountWithRelations(ctx context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error)
} }
//counterfeiter:generate -o internal/mock/village_getter.gen.go . VillageGetter //counterfeiter:generate -o internal/mock/village_getter.gen.go . VillageGetter
@ -82,7 +83,7 @@ func (v *Village) Refresh(ctx context.Context, key, url string) (domain.RefreshV
return res, nil return res, nil
} }
func (v *Village) List(ctx context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) { func (v *Village) ListCountWithRelations(ctx context.Context, params domain.ListVillagesParams) ([]domain.VillageWithRelations, int64, error) {
if params.Pagination.Limit == 0 { if params.Pagination.Limit == 0 {
params.Pagination.Limit = villageMaxLimit params.Pagination.Limit = villageMaxLimit
} }
@ -91,25 +92,18 @@ func (v *Village) List(ctx context.Context, params domain.ListVillagesParams) ([
return nil, 0, fmt.Errorf("validatePagination: %w", err) return nil, 0, fmt.Errorf("validatePagination: %w", err)
} }
villages, count, err := v.repo.List(ctx, params) villages, count, err := v.repo.ListCountWithRelations(ctx, params)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("VillageRepository.List: %w", err) return nil, 0, fmt.Errorf("VillageRepository.ListCountWithRelations: %w", err)
} }
return villages, count, nil return villages, count, nil
} }
func (v *Village) GetByServerKeyAndID( func (v *Village) GetByServerKeyAndIDWithRelations(ctx context.Context, serverKey string, id int64) (domain.VillageWithRelations, error) {
ctx context.Context, villages, _, err := v.repo.ListCountWithRelations(ctx, domain.ListVillagesParams{
serverKey string,
id int64,
includePlayer, includePlayerTribe bool,
) (domain.VillageWithRelations, error) {
villages, _, err := v.repo.List(ctx, domain.ListVillagesParams{
IDs: []int64{id}, IDs: []int64{id},
ServerKeys: []string{serverKey}, ServerKeys: []string{serverKey},
IncludePlayer: includePlayer,
IncludePlayerTribe: includePlayerTribe,
Pagination: domain.Pagination{ Pagination: domain.Pagination{
Limit: 1, Limit: 1,
}, },

View File

@ -98,7 +98,7 @@ func TestVillage_Refresh(t *testing.T) {
} }
} }
func TestVillage_List(t *testing.T) { func TestVillage_ListCountWithRelations(t *testing.T) {
t.Parallel() t.Parallel()
t.Run("OK", func(t *testing.T) { t.Run("OK", func(t *testing.T) {
@ -130,7 +130,7 @@ func TestVillage_List(t *testing.T) {
} }
repo := &mock.FakeVillageRepository{} repo := &mock.FakeVillageRepository{}
repo.ListCalls(func( repo.ListCountWithRelationsCalls(func(
_ context.Context, _ context.Context,
params domain.ListVillagesParams, params domain.ListVillagesParams,
) ([]domain.VillageWithRelations, int64, error) { ) ([]domain.VillageWithRelations, int64, error) {
@ -149,7 +149,7 @@ func TestVillage_List(t *testing.T) {
client := &mock.FakeVillageGetter{} client := &mock.FakeVillageGetter{}
players, count, err := service.NewVillage(repo, client). players, count, err := service.NewVillage(repo, client).
List(context.Background(), domain.ListVillagesParams{ ListCountWithRelations(context.Background(), domain.ListVillagesParams{
Pagination: domain.Pagination{ Pagination: domain.Pagination{
Limit: tt.limit, Limit: tt.limit,
}, },
@ -223,11 +223,11 @@ func TestVillage_List(t *testing.T) {
client := &mock.FakeVillageGetter{} client := &mock.FakeVillageGetter{}
players, count, err := service.NewVillage(repo, client). players, count, err := service.NewVillage(repo, client).
List(context.Background(), tt.params) ListCountWithRelations(context.Background(), tt.params)
assert.ErrorIs(t, err, tt.expectedErr) assert.ErrorIs(t, err, tt.expectedErr)
assert.Zero(t, players) assert.Zero(t, players)
assert.Zero(t, count) assert.Zero(t, count)
assert.Equal(t, 0, repo.ListCallCount()) assert.Equal(t, 0, repo.ListCountWithRelationsCallCount())
}) })
} }
}) })
@ -261,20 +261,17 @@ func TestVillage_GetByServerKeyAndID(t *testing.T) {
}, },
} }
repo := &mock.FakeVillageRepository{} repo := &mock.FakeVillageRepository{}
repo.ListReturns([]domain.VillageWithRelations{village}, 0, nil) repo.ListCountWithRelationsReturns([]domain.VillageWithRelations{village}, 1, nil)
client := &mock.FakeVillageGetter{} client := &mock.FakeVillageGetter{}
res, err := service.NewVillage(repo, client). res, err := service.NewVillage(repo, client).GetByServerKeyAndIDWithRelations(context.Background(), village.ServerKey, village.ID)
GetByServerKeyAndID(context.Background(), village.ServerKey, village.ID, true, true)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, village, res) assert.Equal(t, village, res)
require.Equal(t, 1, repo.ListCallCount()) require.Equal(t, 1, repo.ListCountWithRelationsCallCount())
_, params := repo.ListArgsForCall(0) _, params := repo.ListCountWithRelationsArgsForCall(0)
assert.Equal(t, domain.ListVillagesParams{ assert.Equal(t, domain.ListVillagesParams{
IDs: []int64{village.ID}, IDs: []int64{village.ID},
ServerKeys: []string{village.ServerKey}, ServerKeys: []string{village.ServerKey},
IncludePlayer: true,
IncludePlayerTribe: true,
Pagination: domain.Pagination{ Pagination: domain.Pagination{
Limit: 1, Limit: 1,
}, },
@ -289,8 +286,7 @@ func TestVillage_GetByServerKeyAndID(t *testing.T) {
client := &mock.FakeVillageGetter{} client := &mock.FakeVillageGetter{}
var id int64 = 123 var id int64 = 123
res, err := service.NewVillage(repo, client). res, err := service.NewVillage(repo, client).GetByServerKeyAndIDWithRelations(context.Background(), "pl151", id)
GetByServerKeyAndID(context.Background(), "pl151", id, false, false)
assert.ErrorIs(t, err, domain.VillageNotFoundError{ID: id}) assert.ErrorIs(t, err, domain.VillageNotFoundError{ID: id})
assert.Zero(t, res) assert.Zero(t, res)
}) })