diff --git a/cmd/twhelp/cmd_consumer.go b/cmd/twhelp/cmd_consumer.go index 8a54425..66745e7 100644 --- a/cmd/twhelp/cmd_consumer.go +++ b/cmd/twhelp/cmd_consumer.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "gitea.dwysokinski.me/twhelp/corev3/internal/adapter" "gitea.dwysokinski.me/twhelp/corev3/internal/app" "gitea.dwysokinski.me/twhelp/corev3/internal/health" "gitea.dwysokinski.me/twhelp/corev3/internal/health/healthfile" @@ -47,7 +48,7 @@ var cmdConsumer = &cli.Command{ } consumer := port.NewServerWatermillConsumer( - app.NewServerService(twSvc), + app.NewServerService(adapter.NewServerBunRepository(db), twSvc), subscriber, logger, marshaler, diff --git a/internal/adapter/repository_bun_server.go b/internal/adapter/repository_bun_server.go index 7a94c09..388355b 100644 --- a/internal/adapter/repository_bun_server.go +++ b/internal/adapter/repository_bun_server.go @@ -81,6 +81,7 @@ type listServersParamsApplier struct { params domain.ListServersParams } +//nolint:gocyclo func (a listServersParamsApplier) apply(q *bun.SelectQuery) *bun.SelectQuery { if keys := a.params.Keys(); len(keys) > 0 { q = q.Where("server.key IN (?)", bun.In(keys)) @@ -90,6 +91,10 @@ func (a listServersParamsApplier) apply(q *bun.SelectQuery) *bun.SelectQuery { q = q.Where("server.key > ?", keyGT.Value) } + if versionCodes := a.params.VersionCodes(); len(versionCodes) > 0 { + q = q.Where("server.version_code IN (?)", bun.In(versionCodes)) + } + if open := a.params.Open(); open.Valid { q = q.Where("server.open = ?", open.Value) } diff --git a/internal/adapter/repository_server_test.go b/internal/adapter/repository_server_test.go index db764c1..989c4cb 100644 --- a/internal/adapter/repository_server_test.go +++ b/internal/adapter/repository_server_test.go @@ -3,6 +3,7 @@ package adapter_test import ( "cmp" "context" + "fmt" "net/url" "slices" "testing" @@ -33,8 +34,8 @@ func testServerRepository(t *testing.T, newRepos func(t *testing.T) repositories version := versions[0] serversToCreate := domain.BaseServers{ - domaintest.NewBaseServer(t, domaintest.BaseServerConfig{Open: true}), - domaintest.NewBaseServer(t, domaintest.BaseServerConfig{Open: true}), + domaintest.NewBaseServer(t), + domaintest.NewBaseServer(t), } createParams, err := domain.NewCreateServerParams(serversToCreate, version.Code()) @@ -63,10 +64,10 @@ func testServerRepository(t *testing.T, newRepos func(t *testing.T) repositories } serversToUpdate := domain.BaseServers{ - domaintest.NewBaseServer(t, domaintest.BaseServerConfig{ - Key: serversToCreate[0].Key(), - URL: randURL(t), - Open: !serversToCreate[0].Open(), + domaintest.NewBaseServer(t, func(cfg *domaintest.BaseServerConfig) { + cfg.Key = serversToCreate[0].Key() + cfg.URL = randURL(t) + cfg.Open = !serversToCreate[0].Open() }), } @@ -102,6 +103,11 @@ func testServerRepository(t *testing.T, newRepos func(t *testing.T) repositories repos := newRepos(t) + servers, listServersErr := repos.server.List(context.Background(), domain.NewListServersParams()) + require.NoError(t, listServersErr) + require.NotEmpty(t, servers) + randServer := servers[0] + tests := []struct { name string params func(t *testing.T) domain.ListServersParams @@ -196,11 +202,11 @@ func testServerRepository(t *testing.T, newRepos func(t *testing.T) repositories }, }, { - name: "OK: keys=[de188, en113]", + name: fmt.Sprintf("OK: keys=[%s]", randServer.Key()), params: func(t *testing.T) domain.ListServersParams { t.Helper() params := domain.NewListServersParams() - require.NoError(t, params.SetKeys([]string{"de188", "en113"})) + require.NoError(t, params.SetKeys([]string{randServer.Key()})) return params }, assertServers: func(t *testing.T, params domain.ListServersParams, servers domain.Servers) { @@ -225,12 +231,41 @@ func testServerRepository(t *testing.T, newRepos func(t *testing.T) repositories }, }, { - name: "OK: keyGT=de188", + name: fmt.Sprintf("OK: version code=[%s]", randServer.VersionCode()), + params: func(t *testing.T) domain.ListServersParams { + t.Helper() + params := domain.NewListServersParams() + require.NoError(t, params.SetVersionCodes([]string{randServer.VersionCode()})) + return params + }, + assertServers: func(t *testing.T, params domain.ListServersParams, servers domain.Servers) { + t.Helper() + + versionCodes := params.VersionCodes() + + assert.NotEmpty(t, servers) + for _, c := range versionCodes { + assert.True(t, slices.ContainsFunc(servers, func(server domain.Server) bool { + return server.VersionCode() == c + }), c) + } + }, + assertError: func(t *testing.T, err error) { + t.Helper() + require.NoError(t, err) + }, + assertTotal: func(t *testing.T, params domain.ListServersParams, total int) { + t.Helper() + assert.Equal(t, len(params.Keys()), total) //nolint:testifylint + }, + }, + { + name: fmt.Sprintf("OK: keyGT=%s", randServer.Key()), params: func(t *testing.T) domain.ListServersParams { t.Helper() params := domain.NewListServersParams() require.NoError(t, params.SetKeyGT(domain.NullString{ - Value: "de188", + Value: randServer.Key(), Valid: true, })) return params diff --git a/internal/app/service_server.go b/internal/app/service_server.go index 0f08b9c..2432e01 100644 --- a/internal/app/service_server.go +++ b/internal/app/service_server.go @@ -3,26 +3,125 @@ package app import ( "context" "fmt" - "log" "gitea.dwysokinski.me/twhelp/corev3/internal/domain" ) +type ServerRepository interface { + CreateOrUpdate(ctx context.Context, params ...domain.CreateServerParams) error + List(ctx context.Context, params domain.ListServersParams) (domain.Servers, error) +} + type ServerService struct { + repo ServerRepository twSvc TWService } -func NewServerService(twSvc TWService) *ServerService { - return &ServerService{twSvc: twSvc} +func NewServerService(repo ServerRepository, twSvc TWService) *ServerService { + return &ServerService{repo: repo, twSvc: twSvc} } func (svc *ServerService) Sync(ctx context.Context, payload domain.SyncServersCmdPayload) error { + versionCode := payload.VersionCode() + openServers, err := svc.twSvc.GetOpenServers(ctx, payload.URL().String()) if err != nil { - return fmt.Errorf("couldn't get open servers for version code '%s': %w", payload.VersionCode(), err) + return fmt.Errorf("couldn't get open servers for version code %s: %w", versionCode, err) } - log.Println(openServers) + specialServers, err := svc.listAllSpecial(ctx, versionCode) + if err != nil { + return fmt.Errorf("couldn't list special servers with version code %s: %w", versionCode, err) + } - return nil + currentlyStoredOpenServers, err := svc.listAllOpen(ctx, versionCode) + if err != nil { + return fmt.Errorf("couldn't list open servers with version code %s: %w", versionCode, err) + } + + openServersWithoutSpecial := openServers.FilterOutSpecial(specialServers) + + serversToBeClosed, err := currentlyStoredOpenServers.Close(openServersWithoutSpecial) + if err != nil { + return fmt.Errorf("couldn't close servers: %w", err) + } + + params, err := domain.NewCreateServerParams(append(openServersWithoutSpecial, serversToBeClosed...), versionCode) + if err != nil { + return err + } + + return svc.repo.CreateOrUpdate(ctx, params...) +} + +func (svc *ServerService) listAllSpecial(ctx context.Context, versionCode string) (domain.Servers, error) { + params := domain.NewListServersParams() + + if err := params.SetVersionCodes([]string{versionCode}); err != nil { + return nil, err + } + + if err := params.SetSpecial(domain.NullBool{ + Value: true, + Valid: true, + }); err != nil { + return nil, err + } + + return svc.ListAll(ctx, params) +} + +func (svc *ServerService) listAllOpen(ctx context.Context, versionCode string) (domain.Servers, error) { + params := domain.NewListServersParams() + + if err := params.SetVersionCodes([]string{versionCode}); err != nil { + return nil, err + } + + if err := params.SetOpen(domain.NullBool{ + Value: true, + Valid: true, + }); err != nil { + return nil, err + } + + return svc.ListAll(ctx, params) +} + +// ListAll retrieves all servers from the database based on the given params in an optimal way. +// You can't specify a custom limit/offset/sort/keyGT for this operation. +func (svc *ServerService) ListAll(ctx context.Context, params domain.ListServersParams) (domain.Servers, error) { + if err := params.SetOffset(0); err != nil { + return nil, err + } + + if err := params.SetLimit(domain.ServerListMaxLimit); err != nil { + return nil, err + } + + if err := params.SetSort([]domain.ServerSort{domain.ServerSortKeyASC}); err != nil { + return nil, err + } + + var servers domain.Servers + + for { + ss, err := svc.repo.List(ctx, params) + if err != nil { + return nil, err + } + + if len(ss) == 0 { + return servers, nil + } + + servers = append(servers, ss...) + + if err = params.SetKeyGT(domain.NullString{ + Value: ss[len(ss)-1].Key(), + Valid: true, + }); err != nil { + return nil, err + } + } } diff --git a/internal/domain/base_server.go b/internal/domain/base_server.go index be85f7c..941d8da 100644 --- a/internal/domain/base_server.go +++ b/internal/domain/base_server.go @@ -1,6 +1,9 @@ package domain -import "net/url" +import ( + "net/url" + "slices" +) type BaseServer struct { key string @@ -48,3 +51,19 @@ func (b BaseServer) IsZero() bool { } type BaseServers []BaseServer + +func (ss BaseServers) FilterOutSpecial(specials Servers) BaseServers { + res := make(BaseServers, 0, len(ss)) + + for _, s := range ss { + if slices.ContainsFunc(specials, func(special Server) bool { + return special.Special() && special.Key() == s.Key() + }) { + continue + } + + res = append(res, s) + } + + return res +} diff --git a/internal/domain/base_server_test.go b/internal/domain/base_server_test.go index 6cbe3a3..2cc0b71 100644 --- a/internal/domain/base_server_test.go +++ b/internal/domain/base_server_test.go @@ -1,6 +1,7 @@ package domain_test import ( + "slices" "testing" "gitea.dwysokinski.me/twhelp/corev3/internal/domain" @@ -12,7 +13,7 @@ import ( func TestNewBaseServer(t *testing.T) { t.Parallel() - validBaseServer := domaintest.NewBaseServer(t, domaintest.BaseServerConfig{}) + validBaseServer := domaintest.NewBaseServer(t) type args struct { key string @@ -81,3 +82,37 @@ func TestNewBaseServer(t *testing.T) { }) } } + +func TestBaseServers_FilterOutSpecial(t *testing.T) { + t.Parallel() + + servers := domain.BaseServers{ + domaintest.NewBaseServer(t), + domaintest.NewBaseServer(t), + domaintest.NewBaseServer(t), + domaintest.NewBaseServer(t), + } + + special := domain.Servers{ + domaintest.NewServer(t, func(cfg *domaintest.ServerConfig) { + cfg.Key = servers[0].Key() + cfg.Special = true + }), + domaintest.NewServer(t, func(cfg *domaintest.ServerConfig) { + cfg.Key = servers[2].Key() + cfg.Special = true + }), + } + + res := servers.FilterOutSpecial(special) + assert.Len(t, res, len(servers)-len(special)) + for _, s := range servers { + if slices.ContainsFunc(special, func(server domain.Server) bool { + return server.Key() == s.Key() + }) { + return + } + + assert.Contains(t, res, s) + } +} diff --git a/internal/domain/domaintest/base_server.go b/internal/domain/domaintest/base_server.go index 91dbdfa..d752607 100644 --- a/internal/domain/domaintest/base_server.go +++ b/internal/domain/domaintest/base_server.go @@ -14,9 +14,16 @@ type BaseServerConfig struct { Open bool } -func (cfg *BaseServerConfig) init() { - if cfg.Key == "" { - cfg.Key = RandServerKey() +func NewBaseServer(tb TestingTB, opts ...func(cfg *BaseServerConfig)) domain.BaseServer { + tb.Helper() + + cfg := &BaseServerConfig{ + Key: RandServerKey(), + Open: true, + } + + for _, opt := range opts { + opt(cfg) } if cfg.URL == nil { @@ -25,11 +32,9 @@ func (cfg *BaseServerConfig) init() { Host: cfg.Key + "." + gofakeit.DomainName(), } } -} -func NewBaseServer(tb TestingTB, cfg BaseServerConfig) domain.BaseServer { - cfg.init() s, err := domain.NewBaseServer(cfg.Key, cfg.URL.String(), cfg.Open) require.NoError(tb, err) + return s } diff --git a/internal/domain/domaintest/server.go b/internal/domain/domaintest/server.go index 90dc7bd..c97e0b2 100644 --- a/internal/domain/domaintest/server.go +++ b/internal/domain/domaintest/server.go @@ -1,7 +1,71 @@ package domaintest -import "github.com/brianvoe/gofakeit/v6" +import ( + "net/url" + "time" + + "gitea.dwysokinski.me/twhelp/corev3/internal/domain" + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/require" +) func RandServerKey() string { return gofakeit.LetterN(5) } + +type ServerConfig struct { + Key string + VersionCode string + URL *url.URL + Open bool + Special bool +} + +func NewServer(tb TestingTB, opts ...func(cfg *ServerConfig)) domain.Server { + tb.Helper() + + cfg := &ServerConfig{ + Key: RandServerKey(), + VersionCode: RandVersionCode(), + Open: true, + Special: false, + } + + for _, opt := range opts { + opt(cfg) + } + + if cfg.URL == nil { + cfg.URL = &url.URL{ + Scheme: "https", + Host: cfg.Key + "." + gofakeit.DomainName(), + } + } + + s, err := domain.UnmarshalServerFromDatabase( + cfg.Key, + cfg.VersionCode, + cfg.URL.String(), + cfg.Open, + cfg.Special, + 0, + 0, + 0, + 0, + 0, + 0, + domain.ServerConfig{}, + domain.BuildingInfo{}, + domain.UnitInfo{}, + time.Now(), + time.Time{}, + time.Time{}, + time.Time{}, + time.Time{}, + time.Time{}, + time.Time{}, + ) + require.NoError(tb, err) + + return s +} diff --git a/internal/domain/domaintest/version.go b/internal/domain/domaintest/version.go index 81a4206..eed7efa 100644 --- a/internal/domain/domaintest/version.go +++ b/internal/domain/domaintest/version.go @@ -10,14 +10,17 @@ type VersionConfig struct { Code string } -func (cfg *VersionConfig) init() { - if cfg.Code == "" { - cfg.Code = RandVersionCode() - } -} +func NewVersion(tb TestingTB, opts ...func(cfg *VersionConfig)) domain.Version { + tb.Helper() + + cfg := &VersionConfig{ + Code: RandVersionCode(), + } + + for _, opt := range opts { + opt(cfg) + } -func NewVersion(tb TestingTB, cfg VersionConfig) domain.Version { - cfg.init() s, err := domain.UnmarshalVersionFromDatabase( cfg.Code, gofakeit.LetterN(10), @@ -25,6 +28,7 @@ func NewVersion(tb TestingTB, cfg VersionConfig) domain.Version { gofakeit.TimeZoneRegion(), ) require.NoError(tb, err) + return s } diff --git a/internal/domain/server.go b/internal/domain/server.go index 3da74eb..43c1ea2 100644 --- a/internal/domain/server.go +++ b/internal/domain/server.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/url" + "slices" "time" ) @@ -182,6 +183,29 @@ func (s Server) EnnoblementDataSyncedAt() time.Time { type Servers []Server +// Close finds all servers with Server.Open returning true that are not in the given slice with open servers +// and then converts them to BaseServers with the corrected open value. +func (ss Servers) Close(open BaseServers) (BaseServers, error) { + res := make(BaseServers, 0, len(ss)) + + for _, s := range ss { + if !s.Open() || slices.ContainsFunc(open, func(openServer BaseServer) bool { + return openServer.Key() == s.Key() && openServer.Open() == s.Open() + }) { + continue + } + + base, err := NewBaseServer(s.Key(), s.URL().String(), false) + if err != nil { + return nil, fmt.Errorf("couldn't construct BaseServer for server with key '%s': %w", s.Key(), err) + } + + res = append(res, base) + } + + return res, nil +} + type CreateServerParams struct { base BaseServer versionCode string @@ -231,13 +255,14 @@ const ( const ServerListMaxLimit = 500 type ListServersParams struct { - keys []string - keyGT NullString - open NullBool - special NullBool - sort []ServerSort - limit int - offset int + keys []string + keyGT NullString + versionCodes []string + open NullBool + special NullBool + sort []ServerSort + limit int + offset int } func NewListServersParams() ListServersParams { @@ -269,6 +294,15 @@ func (params *ListServersParams) SetKeyGT(keyGT NullString) error { return nil } +func (params *ListServersParams) VersionCodes() []string { + return params.versionCodes +} + +func (params *ListServersParams) SetVersionCodes(versionCodes []string) error { + params.versionCodes = versionCodes + return nil +} + func (params *ListServersParams) Open() NullBool { return params.open } diff --git a/internal/domain/server_message_payloads_test.go b/internal/domain/server_message_payloads_test.go new file mode 100644 index 0000000..a45ed75 --- /dev/null +++ b/internal/domain/server_message_payloads_test.go @@ -0,0 +1,101 @@ +package domain_test + +import ( + "net/url" + "testing" + + "gitea.dwysokinski.me/twhelp/corev3/internal/domain" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewSyncServersCmdPayload(t *testing.T) { + t.Parallel() + + type args struct { + versionCode string + url *url.URL + } + + tests := []struct { + name string + args args + expectedErr error + }{ + { + name: "OK", + args: args{ + versionCode: "pl", + url: &url.URL{ + Scheme: "https", + Host: "plemiona.pl", + }, + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + payload, err := domain.NewSyncServersCmdPayload(tt.args.versionCode, tt.args.url) + require.ErrorIs(t, err, tt.expectedErr) + if tt.expectedErr != nil { + return + } + assert.Equal(t, tt.args.versionCode, payload.VersionCode()) + assert.Equal(t, tt.args.url, payload.URL()) + }) + } +} + +func TestNewSyncServersCmdPayloadWithStringURL(t *testing.T) { + t.Parallel() + + type args struct { + versionCode string + url string + } + + tests := []struct { + name string + args args + expectedErr error + }{ + { + name: "OK", + args: args{ + versionCode: "pl", + url: "https://plemiona.pl", + }, + }, + { + name: "ERR: invalid url", + args: args{ + versionCode: "pl", + url: "plemiona.pl", + }, + expectedErr: domain.InvalidURLError{ + URL: "plemiona.pl", + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + payload, err := domain.NewSyncServersCmdPayloadWithStringURL(tt.args.versionCode, tt.args.url) + require.ErrorIs(t, err, tt.expectedErr) + if tt.expectedErr != nil { + return + } + assert.Equal(t, tt.args.versionCode, payload.VersionCode()) + assert.Equal(t, tt.args.url, payload.URL().String()) + }) + } +} diff --git a/internal/domain/server_test.go b/internal/domain/server_test.go index 807979c..927ce67 100644 --- a/internal/domain/server_test.go +++ b/internal/domain/server_test.go @@ -2,7 +2,7 @@ package domain_test import ( "fmt" - "net/url" + "slices" "testing" "gitea.dwysokinski.me/twhelp/corev3/internal/domain" @@ -11,102 +11,51 @@ import ( "github.com/stretchr/testify/require" ) -func TestNewSyncServersCmdPayload(t *testing.T) { +func TestServers_Close(t *testing.T) { t.Parallel() - type args struct { - versionCode string - url *url.URL + open := domain.BaseServers{ + domaintest.NewBaseServer(t), + domaintest.NewBaseServer(t), + domaintest.NewBaseServer(t), } - tests := []struct { - name string - args args - expectedErr error - }{ - { - name: "OK", - args: args{ - versionCode: "pl", - url: &url.URL{ - Scheme: "https", - Host: "plemiona.pl", - }, - }, - }, + servers := domain.Servers{ + domaintest.NewServer(t, func(cfg *domaintest.ServerConfig) { + cfg.Key = open[0].Key() + }), + domaintest.NewServer(t, func(cfg *domaintest.ServerConfig) { + cfg.Key = open[2].Key() + }), + domaintest.NewServer(t), + domaintest.NewServer(t, func(cfg *domaintest.ServerConfig) { + cfg.Open = false + }), } - for _, tt := range tests { - tt := tt + res, err := servers.Close(open) + require.NoError(t, err) + assert.NotEmpty(t, res) + for _, s := range servers { + if !s.Open() || slices.ContainsFunc(open, func(server domain.BaseServer) bool { + return server.Key() == s.Key() + }) { + continue + } - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - payload, err := domain.NewSyncServersCmdPayload(tt.args.versionCode, tt.args.url) - require.ErrorIs(t, err, tt.expectedErr) - if tt.expectedErr != nil { - return - } - assert.Equal(t, tt.args.versionCode, payload.VersionCode()) - assert.Equal(t, tt.args.url, payload.URL()) - }) - } -} - -func TestNewSyncServersCmdPayloadWithStringURL(t *testing.T) { - t.Parallel() - - type args struct { - versionCode string - url string - } - - tests := []struct { - name string - args args - expectedErr error - }{ - { - name: "OK", - args: args{ - versionCode: "pl", - url: "https://plemiona.pl", - }, - }, - { - name: "ERR: invalid url", - args: args{ - versionCode: "pl", - url: "plemiona.pl", - }, - expectedErr: domain.InvalidURLError{ - URL: "plemiona.pl", - }, - }, - } - - for _, tt := range tests { - tt := tt - - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - payload, err := domain.NewSyncServersCmdPayloadWithStringURL(tt.args.versionCode, tt.args.url) - require.ErrorIs(t, err, tt.expectedErr) - if tt.expectedErr != nil { - return - } - assert.Equal(t, tt.args.versionCode, payload.VersionCode()) - assert.Equal(t, tt.args.url, payload.URL().String()) - }) + assert.Contains(t, res, domaintest.NewBaseServer(t, func(cfg *domaintest.BaseServerConfig) { + cfg.Key = s.Key() + cfg.URL = s.URL() + cfg.Open = false + })) } } func TestNewCreateServerParams(t *testing.T) { t.Parallel() - validVersion := domaintest.NewVersion(t, domaintest.VersionConfig{}) - validBaseServer := domaintest.NewBaseServer(t, domaintest.BaseServerConfig{}) + validVersion := domaintest.NewVersion(t) + validBaseServer := domaintest.NewBaseServer(t) type args struct { servers domain.BaseServers