fix(security): server id must be verified in the group repository
All checks were successful
continuous-integration/drone/pr Build is passing

This commit is contained in:
Dawid Wysokiński 2022-10-12 06:39:27 +02:00
parent 9c1884d624
commit 90f7a5476b
Signed by: Kichiyaki
GPG Key ID: B5445E357FB8B892
6 changed files with 49 additions and 48 deletions

View File

@ -39,7 +39,7 @@ func (g *Group) Create(ctx context.Context, params domain.CreateGroupParams) (do
return group.ToDomain(), nil
}
func (g *Group) UpdateByID(ctx context.Context, id string, params domain.UpdateGroupParams) (domain.Group, error) {
func (g *Group) Update(ctx context.Context, id, serverID string, params domain.UpdateGroupParams) (domain.Group, error) {
if params.IsZero() {
return domain.Group{}, domain.ErrNothingToUpdate
}
@ -54,6 +54,7 @@ func (g *Group) UpdateByID(ctx context.Context, id string, params domain.UpdateG
Model(&group).
Returning("*").
Where("id = ?", id).
Where("server_id = ?", serverID).
Apply(updateGroupsParamsApplier{params}.apply).
Exec(ctx)
if err != nil && !errors.Is(err, sql.ErrNoRows) {

View File

@ -44,17 +44,17 @@ func TestGroup_Create(t *testing.T) {
})
}
func TestGroup_UpdateByID(t *testing.T) {
func TestGroup_Update(t *testing.T) {
t.Parallel()
db := newDB(t)
fixture := loadFixtures(t, db)
repo := bundb.NewGroup(db)
group := getGroupFromFixture(t, fixture, "group-server-1-1")
t.Run("OK", func(t *testing.T) {
t.Parallel()
group := getGroupFromFixture(t, fixture, "group-server-1-1")
params := domain.UpdateGroupParams{
ChannelGains: domain.NullString{
String: group.ChannelGains + "update",
@ -74,7 +74,7 @@ func TestGroup_UpdateByID(t *testing.T) {
},
}
updatedGroup, err := repo.UpdateByID(context.Background(), group.ID.String(), params)
updatedGroup, err := repo.Update(context.Background(), group.ID.String(), group.ServerID, params)
assert.NoError(t, err)
assert.Equal(t, params.ChannelGains.String, updatedGroup.ChannelGains)
assert.Equal(t, params.ChannelLosses.String, updatedGroup.ChannelLosses)
@ -85,7 +85,7 @@ func TestGroup_UpdateByID(t *testing.T) {
t.Run("ERR: nothing to update", func(t *testing.T) {
t.Parallel()
updatedGroup, err := repo.UpdateByID(context.Background(), "", domain.UpdateGroupParams{})
updatedGroup, err := repo.Update(context.Background(), "", "", domain.UpdateGroupParams{})
assert.ErrorIs(t, err, domain.ErrNothingToUpdate)
assert.Zero(t, updatedGroup)
})
@ -94,7 +94,7 @@ func TestGroup_UpdateByID(t *testing.T) {
t.Parallel()
id := "12345"
updatedGroup, err := repo.UpdateByID(context.Background(), id, domain.UpdateGroupParams{
updatedGroup, err := repo.Update(context.Background(), id, "", domain.UpdateGroupParams{
ChannelGains: domain.NullString{
String: "update",
Valid: true,
@ -104,11 +104,11 @@ func TestGroup_UpdateByID(t *testing.T) {
assert.Zero(t, updatedGroup)
})
t.Run("ERR: group not found", func(t *testing.T) {
t.Run("ERR: group not found (unknown ID)", func(t *testing.T) {
t.Parallel()
id := uuid.NewString()
updatedGroup, err := repo.UpdateByID(context.Background(), id, domain.UpdateGroupParams{
updatedGroup, err := repo.Update(context.Background(), id, group.ServerID, domain.UpdateGroupParams{
ChannelGains: domain.NullString{
String: "update",
Valid: true,
@ -117,6 +117,19 @@ func TestGroup_UpdateByID(t *testing.T) {
assert.ErrorIs(t, err, domain.GroupNotFoundError{ID: id})
assert.Zero(t, updatedGroup)
})
t.Run("ERR: group not found (unknown ServerID)", func(t *testing.T) {
t.Parallel()
updatedGroup, err := repo.Update(context.Background(), group.ID.String(), uuid.NewString(), domain.UpdateGroupParams{
ChannelGains: domain.NullString{
String: "update",
Valid: true,
},
})
assert.ErrorIs(t, err, domain.GroupNotFoundError{ID: group.ID.String()})
assert.Zero(t, updatedGroup)
})
}
func TestGroup_List(t *testing.T) {

View File

@ -11,9 +11,9 @@ import (
type GroupService interface {
Create(ctx context.Context, params domain.CreateGroupParams) (domain.Group, error)
SetTWServer(ctx context.Context, id, versionCode, serverKey string) (domain.Group, error)
SetChannelGains(ctx context.Context, id, channel string) (domain.Group, error)
SetChannelLosses(ctx context.Context, id, channel string) (domain.Group, error)
SetTWServer(ctx context.Context, id, serverID, versionCode, serverKey string) (domain.Group, error)
SetChannelGains(ctx context.Context, id, serverID, channel string) (domain.Group, error)
SetChannelLosses(ctx context.Context, id, serverID, channel string) (domain.Group, error)
List(ctx context.Context, params domain.ListGroupsParams) ([]domain.Group, error)
}

View File

@ -366,7 +366,7 @@ func (c *groupCommand) handleSetServer(s *discordgo.Session, i *discordgo.Intera
}
}
_, err := c.svc.SetTWServer(ctx, group, version, server)
_, err := c.svc.SetTWServer(ctx, group, i.GuildID, version, server)
if err != nil {
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
@ -402,7 +402,7 @@ func (c *groupCommand) handleSetChannelGains(s *discordgo.Session, i *discordgo.
}
}
_, err := c.svc.SetChannelGains(ctx, group, channel)
_, err := c.svc.SetChannelGains(ctx, group, i.GuildID, channel)
if err != nil {
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
@ -438,7 +438,7 @@ func (c *groupCommand) handleSetChannelLosses(s *discordgo.Session, i *discordgo
}
}
_, err := c.svc.SetChannelLosses(ctx, group, channel)
_, err := c.svc.SetChannelLosses(ctx, group, i.GuildID, channel)
if err != nil {
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
@ -473,7 +473,7 @@ func (c *groupCommand) handleUnsetChannelGains(s *discordgo.Session, i *discordg
ctx := context.Background()
group := i.ApplicationCommandData().Options[0].Options[0].Options[0].StringValue()
_, err := c.svc.SetChannelGains(ctx, group, "")
_, err := c.svc.SetChannelGains(ctx, group, i.GuildID, "")
if err != nil {
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
@ -496,7 +496,7 @@ func (c *groupCommand) handleUnsetChannelLosses(s *discordgo.Session, i *discord
ctx := context.Background()
group := i.ApplicationCommandData().Options[0].Options[0].Options[0].StringValue()
_, err := c.svc.SetChannelLosses(ctx, group, "")
_, err := c.svc.SetChannelLosses(ctx, group, i.GuildID, "")
if err != nil {
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,

View File

@ -16,7 +16,7 @@ const (
//counterfeiter:generate -o internal/mock/group_repository.gen.go . GroupRepository
type GroupRepository interface {
Create(ctx context.Context, params domain.CreateGroupParams) (domain.Group, error)
UpdateByID(ctx context.Context, id string, params domain.UpdateGroupParams) (domain.Group, error)
Update(ctx context.Context, id, serverID string, params domain.UpdateGroupParams) (domain.Group, error)
List(ctx context.Context, params domain.ListGroupsParams) ([]domain.Group, error)
}
@ -44,23 +44,8 @@ func (g *Group) Create(ctx context.Context, params domain.CreateGroupParams) (do
}
}
server, err := g.client.GetServer(ctx, params.VersionCode(), params.ServerKey())
if err != nil {
var apiErr twhelp.APIError
if !errors.As(err, &apiErr) {
return domain.Group{}, fmt.Errorf("TWHelpClient.GetServer: %w", err)
}
return domain.Group{}, domain.ServerDoesNotExistError{
VersionCode: params.VersionCode(),
Key: params.ServerKey(),
}
}
if !server.Open {
return domain.Group{}, domain.ServerIsClosedError{
VersionCode: params.VersionCode(),
Key: params.ServerKey(),
}
if err = g.checkTWServer(ctx, params.VersionCode(), params.ServerKey()); err != nil {
return domain.Group{}, err
}
group, err := g.repo.Create(ctx, params)
@ -71,12 +56,12 @@ func (g *Group) Create(ctx context.Context, params domain.CreateGroupParams) (do
return group, nil
}
func (g *Group) SetTWServer(ctx context.Context, id, versionCode, serverKey string) (domain.Group, error) {
func (g *Group) SetTWServer(ctx context.Context, id, serverID, versionCode, serverKey string) (domain.Group, error) {
if err := g.checkTWServer(ctx, versionCode, serverKey); err != nil {
return domain.Group{}, err
}
group, err := g.repo.UpdateByID(ctx, id, domain.UpdateGroupParams{
group, err := g.repo.Update(ctx, id, serverID, domain.UpdateGroupParams{
VersionCode: domain.NullString{
String: versionCode,
Valid: true,
@ -87,34 +72,34 @@ func (g *Group) SetTWServer(ctx context.Context, id, versionCode, serverKey stri
},
})
if err != nil {
return domain.Group{}, fmt.Errorf("GroupRepository.UpdateByID: %w", err)
return domain.Group{}, fmt.Errorf("GroupRepository.Update: %w", err)
}
return group, nil
}
func (g *Group) SetChannelGains(ctx context.Context, id, channel string) (domain.Group, error) {
group, err := g.repo.UpdateByID(ctx, id, domain.UpdateGroupParams{
func (g *Group) SetChannelGains(ctx context.Context, id, serverID, channel string) (domain.Group, error) {
group, err := g.repo.Update(ctx, id, serverID, domain.UpdateGroupParams{
ChannelGains: domain.NullString{
String: channel,
Valid: true,
},
})
if err != nil {
return domain.Group{}, fmt.Errorf("GroupRepository.UpdateByID: %w", err)
return domain.Group{}, fmt.Errorf("GroupRepository.Update: %w", err)
}
return group, nil
}
func (g *Group) SetChannelLosses(ctx context.Context, id, channel string) (domain.Group, error) {
group, err := g.repo.UpdateByID(ctx, id, domain.UpdateGroupParams{
func (g *Group) SetChannelLosses(ctx context.Context, id, serverID, channel string) (domain.Group, error) {
group, err := g.repo.Update(ctx, id, serverID, domain.UpdateGroupParams{
ChannelLosses: domain.NullString{
String: channel,
Valid: true,
},
})
if err != nil {
return domain.Group{}, fmt.Errorf("GroupRepository.UpdateByID: %w", err)
return domain.Group{}, fmt.Errorf("GroupRepository.Update: %w", err)
}
return group, nil
}

View File

@ -130,6 +130,7 @@ func TestGroup_SetTWServer(t *testing.T) {
t.Parallel()
id := uuid.NewString()
serverID := uuid.NewString()
versionCode := "pl"
serverKey := "pl181"
@ -137,10 +138,10 @@ func TestGroup_SetTWServer(t *testing.T) {
t.Parallel()
repo := &mock.FakeGroupRepository{}
repo.UpdateByIDCalls(func(_ context.Context, id string, p domain.UpdateGroupParams) (domain.Group, error) {
repo.UpdateCalls(func(_ context.Context, id, serverID string, p domain.UpdateGroupParams) (domain.Group, error) {
return domain.Group{
ID: id,
ServerID: uuid.NewString(),
ServerID: serverID,
ChannelGains: p.ChannelGains.String,
ChannelLosses: p.ChannelLosses.String,
ServerKey: p.ServerKey.String,
@ -159,9 +160,10 @@ func TestGroup_SetTWServer(t *testing.T) {
}, nil
})
g, err := service.NewGroup(repo, client).SetTWServer(context.Background(), id, versionCode, serverKey)
g, err := service.NewGroup(repo, client).SetTWServer(context.Background(), id, serverID, versionCode, serverKey)
assert.NoError(t, err)
assert.Equal(t, id, g.ID)
assert.Equal(t, serverID, g.ServerID)
assert.Equal(t, serverKey, g.ServerKey)
assert.Equal(t, versionCode, g.VersionCode)
})
@ -177,7 +179,7 @@ func TestGroup_SetTWServer(t *testing.T) {
}
})
g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, versionCode, serverKey)
g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, serverID, versionCode, serverKey)
assert.ErrorIs(t, err, domain.ServerDoesNotExistError{
VersionCode: versionCode,
Key: serverKey,
@ -197,7 +199,7 @@ func TestGroup_SetTWServer(t *testing.T) {
}, nil
})
g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, versionCode, serverKey)
g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, serverID, versionCode, serverKey)
assert.ErrorIs(t, err, domain.ServerIsClosedError{
VersionCode: versionCode,
Key: serverKey,