From 90f7a5476b98e661e906fa4f11b518c8d17abd4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dawid=20Wysoki=C5=84ski?= Date: Wed, 12 Oct 2022 06:39:27 +0200 Subject: [PATCH] fix(security): server id must be verified in the group repository --- internal/bundb/group.go | 3 ++- internal/bundb/group_test.go | 27 +++++++++++++++------ internal/discord/bot.go | 6 ++--- internal/discord/command_group.go | 10 ++++---- internal/service/group.go | 39 ++++++++++--------------------- internal/service/group_test.go | 12 ++++++---- 6 files changed, 49 insertions(+), 48 deletions(-) diff --git a/internal/bundb/group.go b/internal/bundb/group.go index 4b68665..8de511a 100644 --- a/internal/bundb/group.go +++ b/internal/bundb/group.go @@ -39,7 +39,7 @@ func (g *Group) Create(ctx context.Context, params domain.CreateGroupParams) (do return group.ToDomain(), nil } -func (g *Group) UpdateByID(ctx context.Context, id string, params domain.UpdateGroupParams) (domain.Group, error) { +func (g *Group) Update(ctx context.Context, id, serverID string, params domain.UpdateGroupParams) (domain.Group, error) { if params.IsZero() { return domain.Group{}, domain.ErrNothingToUpdate } @@ -54,6 +54,7 @@ func (g *Group) UpdateByID(ctx context.Context, id string, params domain.UpdateG Model(&group). Returning("*"). Where("id = ?", id). + Where("server_id = ?", serverID). Apply(updateGroupsParamsApplier{params}.apply). Exec(ctx) if err != nil && !errors.Is(err, sql.ErrNoRows) { diff --git a/internal/bundb/group_test.go b/internal/bundb/group_test.go index 0388a7e..22878b7 100644 --- a/internal/bundb/group_test.go +++ b/internal/bundb/group_test.go @@ -44,17 +44,17 @@ func TestGroup_Create(t *testing.T) { }) } -func TestGroup_UpdateByID(t *testing.T) { +func TestGroup_Update(t *testing.T) { t.Parallel() db := newDB(t) fixture := loadFixtures(t, db) repo := bundb.NewGroup(db) + group := getGroupFromFixture(t, fixture, "group-server-1-1") t.Run("OK", func(t *testing.T) { t.Parallel() - group := getGroupFromFixture(t, fixture, "group-server-1-1") params := domain.UpdateGroupParams{ ChannelGains: domain.NullString{ String: group.ChannelGains + "update", @@ -74,7 +74,7 @@ func TestGroup_UpdateByID(t *testing.T) { }, } - updatedGroup, err := repo.UpdateByID(context.Background(), group.ID.String(), params) + updatedGroup, err := repo.Update(context.Background(), group.ID.String(), group.ServerID, params) assert.NoError(t, err) assert.Equal(t, params.ChannelGains.String, updatedGroup.ChannelGains) assert.Equal(t, params.ChannelLosses.String, updatedGroup.ChannelLosses) @@ -85,7 +85,7 @@ func TestGroup_UpdateByID(t *testing.T) { t.Run("ERR: nothing to update", func(t *testing.T) { t.Parallel() - updatedGroup, err := repo.UpdateByID(context.Background(), "", domain.UpdateGroupParams{}) + updatedGroup, err := repo.Update(context.Background(), "", "", domain.UpdateGroupParams{}) assert.ErrorIs(t, err, domain.ErrNothingToUpdate) assert.Zero(t, updatedGroup) }) @@ -94,7 +94,7 @@ func TestGroup_UpdateByID(t *testing.T) { t.Parallel() id := "12345" - updatedGroup, err := repo.UpdateByID(context.Background(), id, domain.UpdateGroupParams{ + updatedGroup, err := repo.Update(context.Background(), id, "", domain.UpdateGroupParams{ ChannelGains: domain.NullString{ String: "update", Valid: true, @@ -104,11 +104,11 @@ func TestGroup_UpdateByID(t *testing.T) { assert.Zero(t, updatedGroup) }) - t.Run("ERR: group not found", func(t *testing.T) { + t.Run("ERR: group not found (unknown ID)", func(t *testing.T) { t.Parallel() id := uuid.NewString() - updatedGroup, err := repo.UpdateByID(context.Background(), id, domain.UpdateGroupParams{ + updatedGroup, err := repo.Update(context.Background(), id, group.ServerID, domain.UpdateGroupParams{ ChannelGains: domain.NullString{ String: "update", Valid: true, @@ -117,6 +117,19 @@ func TestGroup_UpdateByID(t *testing.T) { assert.ErrorIs(t, err, domain.GroupNotFoundError{ID: id}) assert.Zero(t, updatedGroup) }) + + t.Run("ERR: group not found (unknown ServerID)", func(t *testing.T) { + t.Parallel() + + updatedGroup, err := repo.Update(context.Background(), group.ID.String(), uuid.NewString(), domain.UpdateGroupParams{ + ChannelGains: domain.NullString{ + String: "update", + Valid: true, + }, + }) + assert.ErrorIs(t, err, domain.GroupNotFoundError{ID: group.ID.String()}) + assert.Zero(t, updatedGroup) + }) } func TestGroup_List(t *testing.T) { diff --git a/internal/discord/bot.go b/internal/discord/bot.go index e1ea27d..5de12ae 100644 --- a/internal/discord/bot.go +++ b/internal/discord/bot.go @@ -11,9 +11,9 @@ import ( type GroupService interface { Create(ctx context.Context, params domain.CreateGroupParams) (domain.Group, error) - SetTWServer(ctx context.Context, id, versionCode, serverKey string) (domain.Group, error) - SetChannelGains(ctx context.Context, id, channel string) (domain.Group, error) - SetChannelLosses(ctx context.Context, id, channel string) (domain.Group, error) + SetTWServer(ctx context.Context, id, serverID, versionCode, serverKey string) (domain.Group, error) + SetChannelGains(ctx context.Context, id, serverID, channel string) (domain.Group, error) + SetChannelLosses(ctx context.Context, id, serverID, channel string) (domain.Group, error) List(ctx context.Context, params domain.ListGroupsParams) ([]domain.Group, error) } diff --git a/internal/discord/command_group.go b/internal/discord/command_group.go index aae9d72..e29634b 100644 --- a/internal/discord/command_group.go +++ b/internal/discord/command_group.go @@ -366,7 +366,7 @@ func (c *groupCommand) handleSetServer(s *discordgo.Session, i *discordgo.Intera } } - _, err := c.svc.SetTWServer(ctx, group, version, server) + _, err := c.svc.SetTWServer(ctx, group, i.GuildID, version, server) if err != nil { _ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ Type: discordgo.InteractionResponseChannelMessageWithSource, @@ -402,7 +402,7 @@ func (c *groupCommand) handleSetChannelGains(s *discordgo.Session, i *discordgo. } } - _, err := c.svc.SetChannelGains(ctx, group, channel) + _, err := c.svc.SetChannelGains(ctx, group, i.GuildID, channel) if err != nil { _ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ Type: discordgo.InteractionResponseChannelMessageWithSource, @@ -438,7 +438,7 @@ func (c *groupCommand) handleSetChannelLosses(s *discordgo.Session, i *discordgo } } - _, err := c.svc.SetChannelLosses(ctx, group, channel) + _, err := c.svc.SetChannelLosses(ctx, group, i.GuildID, channel) if err != nil { _ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ Type: discordgo.InteractionResponseChannelMessageWithSource, @@ -473,7 +473,7 @@ func (c *groupCommand) handleUnsetChannelGains(s *discordgo.Session, i *discordg ctx := context.Background() group := i.ApplicationCommandData().Options[0].Options[0].Options[0].StringValue() - _, err := c.svc.SetChannelGains(ctx, group, "") + _, err := c.svc.SetChannelGains(ctx, group, i.GuildID, "") if err != nil { _ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ Type: discordgo.InteractionResponseChannelMessageWithSource, @@ -496,7 +496,7 @@ func (c *groupCommand) handleUnsetChannelLosses(s *discordgo.Session, i *discord ctx := context.Background() group := i.ApplicationCommandData().Options[0].Options[0].Options[0].StringValue() - _, err := c.svc.SetChannelLosses(ctx, group, "") + _, err := c.svc.SetChannelLosses(ctx, group, i.GuildID, "") if err != nil { _ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ Type: discordgo.InteractionResponseChannelMessageWithSource, diff --git a/internal/service/group.go b/internal/service/group.go index e5d1596..09a8dae 100644 --- a/internal/service/group.go +++ b/internal/service/group.go @@ -16,7 +16,7 @@ const ( //counterfeiter:generate -o internal/mock/group_repository.gen.go . GroupRepository type GroupRepository interface { Create(ctx context.Context, params domain.CreateGroupParams) (domain.Group, error) - UpdateByID(ctx context.Context, id string, params domain.UpdateGroupParams) (domain.Group, error) + Update(ctx context.Context, id, serverID string, params domain.UpdateGroupParams) (domain.Group, error) List(ctx context.Context, params domain.ListGroupsParams) ([]domain.Group, error) } @@ -44,23 +44,8 @@ func (g *Group) Create(ctx context.Context, params domain.CreateGroupParams) (do } } - server, err := g.client.GetServer(ctx, params.VersionCode(), params.ServerKey()) - if err != nil { - var apiErr twhelp.APIError - if !errors.As(err, &apiErr) { - return domain.Group{}, fmt.Errorf("TWHelpClient.GetServer: %w", err) - } - return domain.Group{}, domain.ServerDoesNotExistError{ - VersionCode: params.VersionCode(), - Key: params.ServerKey(), - } - } - - if !server.Open { - return domain.Group{}, domain.ServerIsClosedError{ - VersionCode: params.VersionCode(), - Key: params.ServerKey(), - } + if err = g.checkTWServer(ctx, params.VersionCode(), params.ServerKey()); err != nil { + return domain.Group{}, err } group, err := g.repo.Create(ctx, params) @@ -71,12 +56,12 @@ func (g *Group) Create(ctx context.Context, params domain.CreateGroupParams) (do return group, nil } -func (g *Group) SetTWServer(ctx context.Context, id, versionCode, serverKey string) (domain.Group, error) { +func (g *Group) SetTWServer(ctx context.Context, id, serverID, versionCode, serverKey string) (domain.Group, error) { if err := g.checkTWServer(ctx, versionCode, serverKey); err != nil { return domain.Group{}, err } - group, err := g.repo.UpdateByID(ctx, id, domain.UpdateGroupParams{ + group, err := g.repo.Update(ctx, id, serverID, domain.UpdateGroupParams{ VersionCode: domain.NullString{ String: versionCode, Valid: true, @@ -87,34 +72,34 @@ func (g *Group) SetTWServer(ctx context.Context, id, versionCode, serverKey stri }, }) if err != nil { - return domain.Group{}, fmt.Errorf("GroupRepository.UpdateByID: %w", err) + return domain.Group{}, fmt.Errorf("GroupRepository.Update: %w", err) } return group, nil } -func (g *Group) SetChannelGains(ctx context.Context, id, channel string) (domain.Group, error) { - group, err := g.repo.UpdateByID(ctx, id, domain.UpdateGroupParams{ +func (g *Group) SetChannelGains(ctx context.Context, id, serverID, channel string) (domain.Group, error) { + group, err := g.repo.Update(ctx, id, serverID, domain.UpdateGroupParams{ ChannelGains: domain.NullString{ String: channel, Valid: true, }, }) if err != nil { - return domain.Group{}, fmt.Errorf("GroupRepository.UpdateByID: %w", err) + return domain.Group{}, fmt.Errorf("GroupRepository.Update: %w", err) } return group, nil } -func (g *Group) SetChannelLosses(ctx context.Context, id, channel string) (domain.Group, error) { - group, err := g.repo.UpdateByID(ctx, id, domain.UpdateGroupParams{ +func (g *Group) SetChannelLosses(ctx context.Context, id, serverID, channel string) (domain.Group, error) { + group, err := g.repo.Update(ctx, id, serverID, domain.UpdateGroupParams{ ChannelLosses: domain.NullString{ String: channel, Valid: true, }, }) if err != nil { - return domain.Group{}, fmt.Errorf("GroupRepository.UpdateByID: %w", err) + return domain.Group{}, fmt.Errorf("GroupRepository.Update: %w", err) } return group, nil } diff --git a/internal/service/group_test.go b/internal/service/group_test.go index d55c46a..01d2198 100644 --- a/internal/service/group_test.go +++ b/internal/service/group_test.go @@ -130,6 +130,7 @@ func TestGroup_SetTWServer(t *testing.T) { t.Parallel() id := uuid.NewString() + serverID := uuid.NewString() versionCode := "pl" serverKey := "pl181" @@ -137,10 +138,10 @@ func TestGroup_SetTWServer(t *testing.T) { t.Parallel() repo := &mock.FakeGroupRepository{} - repo.UpdateByIDCalls(func(_ context.Context, id string, p domain.UpdateGroupParams) (domain.Group, error) { + repo.UpdateCalls(func(_ context.Context, id, serverID string, p domain.UpdateGroupParams) (domain.Group, error) { return domain.Group{ ID: id, - ServerID: uuid.NewString(), + ServerID: serverID, ChannelGains: p.ChannelGains.String, ChannelLosses: p.ChannelLosses.String, ServerKey: p.ServerKey.String, @@ -159,9 +160,10 @@ func TestGroup_SetTWServer(t *testing.T) { }, nil }) - g, err := service.NewGroup(repo, client).SetTWServer(context.Background(), id, versionCode, serverKey) + g, err := service.NewGroup(repo, client).SetTWServer(context.Background(), id, serverID, versionCode, serverKey) assert.NoError(t, err) assert.Equal(t, id, g.ID) + assert.Equal(t, serverID, g.ServerID) assert.Equal(t, serverKey, g.ServerKey) assert.Equal(t, versionCode, g.VersionCode) }) @@ -177,7 +179,7 @@ func TestGroup_SetTWServer(t *testing.T) { } }) - g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, versionCode, serverKey) + g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, serverID, versionCode, serverKey) assert.ErrorIs(t, err, domain.ServerDoesNotExistError{ VersionCode: versionCode, Key: serverKey, @@ -197,7 +199,7 @@ func TestGroup_SetTWServer(t *testing.T) { }, nil }) - g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, versionCode, serverKey) + g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, serverID, versionCode, serverKey) assert.ErrorIs(t, err, domain.ServerIsClosedError{ VersionCode: versionCode, Key: serverKey,