feat: add server persistence (#8)

Reviewed-on: twhelp/corev3#8
This commit is contained in:
Dawid Wysokiński 2023-12-25 09:13:42 +00:00
parent b4e95f3267
commit 482870d3a8
12 changed files with 475 additions and 124 deletions

View File

@ -7,6 +7,7 @@ import (
"sync" "sync"
"time" "time"
"gitea.dwysokinski.me/twhelp/corev3/internal/adapter"
"gitea.dwysokinski.me/twhelp/corev3/internal/app" "gitea.dwysokinski.me/twhelp/corev3/internal/app"
"gitea.dwysokinski.me/twhelp/corev3/internal/health" "gitea.dwysokinski.me/twhelp/corev3/internal/health"
"gitea.dwysokinski.me/twhelp/corev3/internal/health/healthfile" "gitea.dwysokinski.me/twhelp/corev3/internal/health/healthfile"
@ -47,7 +48,7 @@ var cmdConsumer = &cli.Command{
} }
consumer := port.NewServerWatermillConsumer( consumer := port.NewServerWatermillConsumer(
app.NewServerService(twSvc), app.NewServerService(adapter.NewServerBunRepository(db), twSvc),
subscriber, subscriber,
logger, logger,
marshaler, marshaler,

View File

@ -81,6 +81,7 @@ type listServersParamsApplier struct {
params domain.ListServersParams params domain.ListServersParams
} }
//nolint:gocyclo
func (a listServersParamsApplier) apply(q *bun.SelectQuery) *bun.SelectQuery { func (a listServersParamsApplier) apply(q *bun.SelectQuery) *bun.SelectQuery {
if keys := a.params.Keys(); len(keys) > 0 { if keys := a.params.Keys(); len(keys) > 0 {
q = q.Where("server.key IN (?)", bun.In(keys)) q = q.Where("server.key IN (?)", bun.In(keys))
@ -90,6 +91,10 @@ func (a listServersParamsApplier) apply(q *bun.SelectQuery) *bun.SelectQuery {
q = q.Where("server.key > ?", keyGT.Value) q = q.Where("server.key > ?", keyGT.Value)
} }
if versionCodes := a.params.VersionCodes(); len(versionCodes) > 0 {
q = q.Where("server.version_code IN (?)", bun.In(versionCodes))
}
if open := a.params.Open(); open.Valid { if open := a.params.Open(); open.Valid {
q = q.Where("server.open = ?", open.Value) q = q.Where("server.open = ?", open.Value)
} }

View File

@ -3,6 +3,7 @@ package adapter_test
import ( import (
"cmp" "cmp"
"context" "context"
"fmt"
"net/url" "net/url"
"slices" "slices"
"testing" "testing"
@ -33,8 +34,8 @@ func testServerRepository(t *testing.T, newRepos func(t *testing.T) repositories
version := versions[0] version := versions[0]
serversToCreate := domain.BaseServers{ serversToCreate := domain.BaseServers{
domaintest.NewBaseServer(t, domaintest.BaseServerConfig{Open: true}), domaintest.NewBaseServer(t),
domaintest.NewBaseServer(t, domaintest.BaseServerConfig{Open: true}), domaintest.NewBaseServer(t),
} }
createParams, err := domain.NewCreateServerParams(serversToCreate, version.Code()) createParams, err := domain.NewCreateServerParams(serversToCreate, version.Code())
@ -63,10 +64,10 @@ func testServerRepository(t *testing.T, newRepos func(t *testing.T) repositories
} }
serversToUpdate := domain.BaseServers{ serversToUpdate := domain.BaseServers{
domaintest.NewBaseServer(t, domaintest.BaseServerConfig{ domaintest.NewBaseServer(t, func(cfg *domaintest.BaseServerConfig) {
Key: serversToCreate[0].Key(), cfg.Key = serversToCreate[0].Key()
URL: randURL(t), cfg.URL = randURL(t)
Open: !serversToCreate[0].Open(), cfg.Open = !serversToCreate[0].Open()
}), }),
} }
@ -102,6 +103,11 @@ func testServerRepository(t *testing.T, newRepos func(t *testing.T) repositories
repos := newRepos(t) repos := newRepos(t)
servers, listServersErr := repos.server.List(context.Background(), domain.NewListServersParams())
require.NoError(t, listServersErr)
require.NotEmpty(t, servers)
randServer := servers[0]
tests := []struct { tests := []struct {
name string name string
params func(t *testing.T) domain.ListServersParams params func(t *testing.T) domain.ListServersParams
@ -196,11 +202,11 @@ func testServerRepository(t *testing.T, newRepos func(t *testing.T) repositories
}, },
}, },
{ {
name: "OK: keys=[de188, en113]", name: fmt.Sprintf("OK: keys=[%s]", randServer.Key()),
params: func(t *testing.T) domain.ListServersParams { params: func(t *testing.T) domain.ListServersParams {
t.Helper() t.Helper()
params := domain.NewListServersParams() params := domain.NewListServersParams()
require.NoError(t, params.SetKeys([]string{"de188", "en113"})) require.NoError(t, params.SetKeys([]string{randServer.Key()}))
return params return params
}, },
assertServers: func(t *testing.T, params domain.ListServersParams, servers domain.Servers) { assertServers: func(t *testing.T, params domain.ListServersParams, servers domain.Servers) {
@ -225,12 +231,41 @@ func testServerRepository(t *testing.T, newRepos func(t *testing.T) repositories
}, },
}, },
{ {
name: "OK: keyGT=de188", name: fmt.Sprintf("OK: version code=[%s]", randServer.VersionCode()),
params: func(t *testing.T) domain.ListServersParams {
t.Helper()
params := domain.NewListServersParams()
require.NoError(t, params.SetVersionCodes([]string{randServer.VersionCode()}))
return params
},
assertServers: func(t *testing.T, params domain.ListServersParams, servers domain.Servers) {
t.Helper()
versionCodes := params.VersionCodes()
assert.NotEmpty(t, servers)
for _, c := range versionCodes {
assert.True(t, slices.ContainsFunc(servers, func(server domain.Server) bool {
return server.VersionCode() == c
}), c)
}
},
assertError: func(t *testing.T, err error) {
t.Helper()
require.NoError(t, err)
},
assertTotal: func(t *testing.T, params domain.ListServersParams, total int) {
t.Helper()
assert.Equal(t, len(params.Keys()), total) //nolint:testifylint
},
},
{
name: fmt.Sprintf("OK: keyGT=%s", randServer.Key()),
params: func(t *testing.T) domain.ListServersParams { params: func(t *testing.T) domain.ListServersParams {
t.Helper() t.Helper()
params := domain.NewListServersParams() params := domain.NewListServersParams()
require.NoError(t, params.SetKeyGT(domain.NullString{ require.NoError(t, params.SetKeyGT(domain.NullString{
Value: "de188", Value: randServer.Key(),
Valid: true, Valid: true,
})) }))
return params return params

View File

@ -3,26 +3,125 @@ package app
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"gitea.dwysokinski.me/twhelp/corev3/internal/domain" "gitea.dwysokinski.me/twhelp/corev3/internal/domain"
) )
type ServerRepository interface {
CreateOrUpdate(ctx context.Context, params ...domain.CreateServerParams) error
List(ctx context.Context, params domain.ListServersParams) (domain.Servers, error)
}
type ServerService struct { type ServerService struct {
repo ServerRepository
twSvc TWService twSvc TWService
} }
func NewServerService(twSvc TWService) *ServerService { func NewServerService(repo ServerRepository, twSvc TWService) *ServerService {
return &ServerService{twSvc: twSvc} return &ServerService{repo: repo, twSvc: twSvc}
} }
func (svc *ServerService) Sync(ctx context.Context, payload domain.SyncServersCmdPayload) error { func (svc *ServerService) Sync(ctx context.Context, payload domain.SyncServersCmdPayload) error {
versionCode := payload.VersionCode()
openServers, err := svc.twSvc.GetOpenServers(ctx, payload.URL().String()) openServers, err := svc.twSvc.GetOpenServers(ctx, payload.URL().String())
if err != nil { if err != nil {
return fmt.Errorf("couldn't get open servers for version code '%s': %w", payload.VersionCode(), err) return fmt.Errorf("couldn't get open servers for version code %s: %w", versionCode, err)
} }
log.Println(openServers) specialServers, err := svc.listAllSpecial(ctx, versionCode)
if err != nil {
return fmt.Errorf("couldn't list special servers with version code %s: %w", versionCode, err)
}
return nil currentlyStoredOpenServers, err := svc.listAllOpen(ctx, versionCode)
if err != nil {
return fmt.Errorf("couldn't list open servers with version code %s: %w", versionCode, err)
}
openServersWithoutSpecial := openServers.FilterOutSpecial(specialServers)
serversToBeClosed, err := currentlyStoredOpenServers.Close(openServersWithoutSpecial)
if err != nil {
return fmt.Errorf("couldn't close servers: %w", err)
}
params, err := domain.NewCreateServerParams(append(openServersWithoutSpecial, serversToBeClosed...), versionCode)
if err != nil {
return err
}
return svc.repo.CreateOrUpdate(ctx, params...)
}
func (svc *ServerService) listAllSpecial(ctx context.Context, versionCode string) (domain.Servers, error) {
params := domain.NewListServersParams()
if err := params.SetVersionCodes([]string{versionCode}); err != nil {
return nil, err
}
if err := params.SetSpecial(domain.NullBool{
Value: true,
Valid: true,
}); err != nil {
return nil, err
}
return svc.ListAll(ctx, params)
}
func (svc *ServerService) listAllOpen(ctx context.Context, versionCode string) (domain.Servers, error) {
params := domain.NewListServersParams()
if err := params.SetVersionCodes([]string{versionCode}); err != nil {
return nil, err
}
if err := params.SetOpen(domain.NullBool{
Value: true,
Valid: true,
}); err != nil {
return nil, err
}
return svc.ListAll(ctx, params)
}
// ListAll retrieves all servers from the database based on the given params in an optimal way.
// You can't specify a custom limit/offset/sort/keyGT for this operation.
func (svc *ServerService) ListAll(ctx context.Context, params domain.ListServersParams) (domain.Servers, error) {
if err := params.SetOffset(0); err != nil {
return nil, err
}
if err := params.SetLimit(domain.ServerListMaxLimit); err != nil {
return nil, err
}
if err := params.SetSort([]domain.ServerSort{domain.ServerSortKeyASC}); err != nil {
return nil, err
}
var servers domain.Servers
for {
ss, err := svc.repo.List(ctx, params)
if err != nil {
return nil, err
}
if len(ss) == 0 {
return servers, nil
}
servers = append(servers, ss...)
if err = params.SetKeyGT(domain.NullString{
Value: ss[len(ss)-1].Key(),
Valid: true,
}); err != nil {
return nil, err
}
}
} }

View File

@ -1,6 +1,9 @@
package domain package domain
import "net/url" import (
"net/url"
"slices"
)
type BaseServer struct { type BaseServer struct {
key string key string
@ -48,3 +51,19 @@ func (b BaseServer) IsZero() bool {
} }
type BaseServers []BaseServer type BaseServers []BaseServer
func (ss BaseServers) FilterOutSpecial(specials Servers) BaseServers {
res := make(BaseServers, 0, len(ss))
for _, s := range ss {
if slices.ContainsFunc(specials, func(special Server) bool {
return special.Special() && special.Key() == s.Key()
}) {
continue
}
res = append(res, s)
}
return res
}

View File

@ -1,6 +1,7 @@
package domain_test package domain_test
import ( import (
"slices"
"testing" "testing"
"gitea.dwysokinski.me/twhelp/corev3/internal/domain" "gitea.dwysokinski.me/twhelp/corev3/internal/domain"
@ -12,7 +13,7 @@ import (
func TestNewBaseServer(t *testing.T) { func TestNewBaseServer(t *testing.T) {
t.Parallel() t.Parallel()
validBaseServer := domaintest.NewBaseServer(t, domaintest.BaseServerConfig{}) validBaseServer := domaintest.NewBaseServer(t)
type args struct { type args struct {
key string key string
@ -81,3 +82,37 @@ func TestNewBaseServer(t *testing.T) {
}) })
} }
} }
func TestBaseServers_FilterOutSpecial(t *testing.T) {
t.Parallel()
servers := domain.BaseServers{
domaintest.NewBaseServer(t),
domaintest.NewBaseServer(t),
domaintest.NewBaseServer(t),
domaintest.NewBaseServer(t),
}
special := domain.Servers{
domaintest.NewServer(t, func(cfg *domaintest.ServerConfig) {
cfg.Key = servers[0].Key()
cfg.Special = true
}),
domaintest.NewServer(t, func(cfg *domaintest.ServerConfig) {
cfg.Key = servers[2].Key()
cfg.Special = true
}),
}
res := servers.FilterOutSpecial(special)
assert.Len(t, res, len(servers)-len(special))
for _, s := range servers {
if slices.ContainsFunc(special, func(server domain.Server) bool {
return server.Key() == s.Key()
}) {
return
}
assert.Contains(t, res, s)
}
}

View File

@ -14,9 +14,16 @@ type BaseServerConfig struct {
Open bool Open bool
} }
func (cfg *BaseServerConfig) init() { func NewBaseServer(tb TestingTB, opts ...func(cfg *BaseServerConfig)) domain.BaseServer {
if cfg.Key == "" { tb.Helper()
cfg.Key = RandServerKey()
cfg := &BaseServerConfig{
Key: RandServerKey(),
Open: true,
}
for _, opt := range opts {
opt(cfg)
} }
if cfg.URL == nil { if cfg.URL == nil {
@ -25,11 +32,9 @@ func (cfg *BaseServerConfig) init() {
Host: cfg.Key + "." + gofakeit.DomainName(), Host: cfg.Key + "." + gofakeit.DomainName(),
} }
} }
}
func NewBaseServer(tb TestingTB, cfg BaseServerConfig) domain.BaseServer {
cfg.init()
s, err := domain.NewBaseServer(cfg.Key, cfg.URL.String(), cfg.Open) s, err := domain.NewBaseServer(cfg.Key, cfg.URL.String(), cfg.Open)
require.NoError(tb, err) require.NoError(tb, err)
return s return s
} }

View File

@ -1,7 +1,71 @@
package domaintest package domaintest
import "github.com/brianvoe/gofakeit/v6" import (
"net/url"
"time"
"gitea.dwysokinski.me/twhelp/corev3/internal/domain"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/require"
)
func RandServerKey() string { func RandServerKey() string {
return gofakeit.LetterN(5) return gofakeit.LetterN(5)
} }
type ServerConfig struct {
Key string
VersionCode string
URL *url.URL
Open bool
Special bool
}
func NewServer(tb TestingTB, opts ...func(cfg *ServerConfig)) domain.Server {
tb.Helper()
cfg := &ServerConfig{
Key: RandServerKey(),
VersionCode: RandVersionCode(),
Open: true,
Special: false,
}
for _, opt := range opts {
opt(cfg)
}
if cfg.URL == nil {
cfg.URL = &url.URL{
Scheme: "https",
Host: cfg.Key + "." + gofakeit.DomainName(),
}
}
s, err := domain.UnmarshalServerFromDatabase(
cfg.Key,
cfg.VersionCode,
cfg.URL.String(),
cfg.Open,
cfg.Special,
0,
0,
0,
0,
0,
0,
domain.ServerConfig{},
domain.BuildingInfo{},
domain.UnitInfo{},
time.Now(),
time.Time{},
time.Time{},
time.Time{},
time.Time{},
time.Time{},
time.Time{},
)
require.NoError(tb, err)
return s
}

View File

@ -10,14 +10,17 @@ type VersionConfig struct {
Code string Code string
} }
func (cfg *VersionConfig) init() { func NewVersion(tb TestingTB, opts ...func(cfg *VersionConfig)) domain.Version {
if cfg.Code == "" { tb.Helper()
cfg.Code = RandVersionCode()
} cfg := &VersionConfig{
} Code: RandVersionCode(),
}
for _, opt := range opts {
opt(cfg)
}
func NewVersion(tb TestingTB, cfg VersionConfig) domain.Version {
cfg.init()
s, err := domain.UnmarshalVersionFromDatabase( s, err := domain.UnmarshalVersionFromDatabase(
cfg.Code, cfg.Code,
gofakeit.LetterN(10), gofakeit.LetterN(10),
@ -25,6 +28,7 @@ func NewVersion(tb TestingTB, cfg VersionConfig) domain.Version {
gofakeit.TimeZoneRegion(), gofakeit.TimeZoneRegion(),
) )
require.NoError(tb, err) require.NoError(tb, err)
return s return s
} }

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
"slices"
"time" "time"
) )
@ -182,6 +183,29 @@ func (s Server) EnnoblementDataSyncedAt() time.Time {
type Servers []Server type Servers []Server
// Close finds all servers with Server.Open returning true that are not in the given slice with open servers
// and then converts them to BaseServers with the corrected open value.
func (ss Servers) Close(open BaseServers) (BaseServers, error) {
res := make(BaseServers, 0, len(ss))
for _, s := range ss {
if !s.Open() || slices.ContainsFunc(open, func(openServer BaseServer) bool {
return openServer.Key() == s.Key() && openServer.Open() == s.Open()
}) {
continue
}
base, err := NewBaseServer(s.Key(), s.URL().String(), false)
if err != nil {
return nil, fmt.Errorf("couldn't construct BaseServer for server with key '%s': %w", s.Key(), err)
}
res = append(res, base)
}
return res, nil
}
type CreateServerParams struct { type CreateServerParams struct {
base BaseServer base BaseServer
versionCode string versionCode string
@ -231,13 +255,14 @@ const (
const ServerListMaxLimit = 500 const ServerListMaxLimit = 500
type ListServersParams struct { type ListServersParams struct {
keys []string keys []string
keyGT NullString keyGT NullString
open NullBool versionCodes []string
special NullBool open NullBool
sort []ServerSort special NullBool
limit int sort []ServerSort
offset int limit int
offset int
} }
func NewListServersParams() ListServersParams { func NewListServersParams() ListServersParams {
@ -269,6 +294,15 @@ func (params *ListServersParams) SetKeyGT(keyGT NullString) error {
return nil return nil
} }
func (params *ListServersParams) VersionCodes() []string {
return params.versionCodes
}
func (params *ListServersParams) SetVersionCodes(versionCodes []string) error {
params.versionCodes = versionCodes
return nil
}
func (params *ListServersParams) Open() NullBool { func (params *ListServersParams) Open() NullBool {
return params.open return params.open
} }

View File

@ -0,0 +1,101 @@
package domain_test
import (
"net/url"
"testing"
"gitea.dwysokinski.me/twhelp/corev3/internal/domain"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewSyncServersCmdPayload(t *testing.T) {
t.Parallel()
type args struct {
versionCode string
url *url.URL
}
tests := []struct {
name string
args args
expectedErr error
}{
{
name: "OK",
args: args{
versionCode: "pl",
url: &url.URL{
Scheme: "https",
Host: "plemiona.pl",
},
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
payload, err := domain.NewSyncServersCmdPayload(tt.args.versionCode, tt.args.url)
require.ErrorIs(t, err, tt.expectedErr)
if tt.expectedErr != nil {
return
}
assert.Equal(t, tt.args.versionCode, payload.VersionCode())
assert.Equal(t, tt.args.url, payload.URL())
})
}
}
func TestNewSyncServersCmdPayloadWithStringURL(t *testing.T) {
t.Parallel()
type args struct {
versionCode string
url string
}
tests := []struct {
name string
args args
expectedErr error
}{
{
name: "OK",
args: args{
versionCode: "pl",
url: "https://plemiona.pl",
},
},
{
name: "ERR: invalid url",
args: args{
versionCode: "pl",
url: "plemiona.pl",
},
expectedErr: domain.InvalidURLError{
URL: "plemiona.pl",
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
payload, err := domain.NewSyncServersCmdPayloadWithStringURL(tt.args.versionCode, tt.args.url)
require.ErrorIs(t, err, tt.expectedErr)
if tt.expectedErr != nil {
return
}
assert.Equal(t, tt.args.versionCode, payload.VersionCode())
assert.Equal(t, tt.args.url, payload.URL().String())
})
}
}

View File

@ -2,7 +2,7 @@ package domain_test
import ( import (
"fmt" "fmt"
"net/url" "slices"
"testing" "testing"
"gitea.dwysokinski.me/twhelp/corev3/internal/domain" "gitea.dwysokinski.me/twhelp/corev3/internal/domain"
@ -11,102 +11,51 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestNewSyncServersCmdPayload(t *testing.T) { func TestServers_Close(t *testing.T) {
t.Parallel() t.Parallel()
type args struct { open := domain.BaseServers{
versionCode string domaintest.NewBaseServer(t),
url *url.URL domaintest.NewBaseServer(t),
domaintest.NewBaseServer(t),
} }
tests := []struct { servers := domain.Servers{
name string domaintest.NewServer(t, func(cfg *domaintest.ServerConfig) {
args args cfg.Key = open[0].Key()
expectedErr error }),
}{ domaintest.NewServer(t, func(cfg *domaintest.ServerConfig) {
{ cfg.Key = open[2].Key()
name: "OK", }),
args: args{ domaintest.NewServer(t),
versionCode: "pl", domaintest.NewServer(t, func(cfg *domaintest.ServerConfig) {
url: &url.URL{ cfg.Open = false
Scheme: "https", }),
Host: "plemiona.pl",
},
},
},
} }
for _, tt := range tests { res, err := servers.Close(open)
tt := tt require.NoError(t, err)
assert.NotEmpty(t, res)
for _, s := range servers {
if !s.Open() || slices.ContainsFunc(open, func(server domain.BaseServer) bool {
return server.Key() == s.Key()
}) {
continue
}
t.Run(tt.name, func(t *testing.T) { assert.Contains(t, res, domaintest.NewBaseServer(t, func(cfg *domaintest.BaseServerConfig) {
t.Parallel() cfg.Key = s.Key()
cfg.URL = s.URL()
payload, err := domain.NewSyncServersCmdPayload(tt.args.versionCode, tt.args.url) cfg.Open = false
require.ErrorIs(t, err, tt.expectedErr) }))
if tt.expectedErr != nil {
return
}
assert.Equal(t, tt.args.versionCode, payload.VersionCode())
assert.Equal(t, tt.args.url, payload.URL())
})
}
}
func TestNewSyncServersCmdPayloadWithStringURL(t *testing.T) {
t.Parallel()
type args struct {
versionCode string
url string
}
tests := []struct {
name string
args args
expectedErr error
}{
{
name: "OK",
args: args{
versionCode: "pl",
url: "https://plemiona.pl",
},
},
{
name: "ERR: invalid url",
args: args{
versionCode: "pl",
url: "plemiona.pl",
},
expectedErr: domain.InvalidURLError{
URL: "plemiona.pl",
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
payload, err := domain.NewSyncServersCmdPayloadWithStringURL(tt.args.versionCode, tt.args.url)
require.ErrorIs(t, err, tt.expectedErr)
if tt.expectedErr != nil {
return
}
assert.Equal(t, tt.args.versionCode, payload.VersionCode())
assert.Equal(t, tt.args.url, payload.URL().String())
})
} }
} }
func TestNewCreateServerParams(t *testing.T) { func TestNewCreateServerParams(t *testing.T) {
t.Parallel() t.Parallel()
validVersion := domaintest.NewVersion(t, domaintest.VersionConfig{}) validVersion := domaintest.NewVersion(t)
validBaseServer := domaintest.NewBaseServer(t, domaintest.BaseServerConfig{}) validBaseServer := domaintest.NewBaseServer(t)
type args struct { type args struct {
servers domain.BaseServers servers domain.BaseServers