fix(security): server id must be verified in the group repository #15

Merged
Kichiyaki merged 1 commits from fix/security-set-unset into master 2022-10-12 04:45:17 +00:00
6 changed files with 49 additions and 48 deletions

View File

@ -39,7 +39,7 @@ func (g *Group) Create(ctx context.Context, params domain.CreateGroupParams) (do
return group.ToDomain(), nil
}
func (g *Group) UpdateByID(ctx context.Context, id string, params domain.UpdateGroupParams) (domain.Group, error) {
func (g *Group) Update(ctx context.Context, id, serverID string, params domain.UpdateGroupParams) (domain.Group, error) {
if params.IsZero() {
return domain.Group{}, domain.ErrNothingToUpdate
}
@ -54,6 +54,7 @@ func (g *Group) UpdateByID(ctx context.Context, id string, params domain.UpdateG
Model(&group).
Returning("*").
Where("id = ?", id).
Where("server_id = ?", serverID).
Apply(updateGroupsParamsApplier{params}.apply).
Exec(ctx)
if err != nil && !errors.Is(err, sql.ErrNoRows) {

View File

@ -44,17 +44,17 @@ func TestGroup_Create(t *testing.T) {
})
}
func TestGroup_UpdateByID(t *testing.T) {
func TestGroup_Update(t *testing.T) {
t.Parallel()
db := newDB(t)
fixture := loadFixtures(t, db)
repo := bundb.NewGroup(db)
group := getGroupFromFixture(t, fixture, "group-server-1-1")
t.Run("OK", func(t *testing.T) {
t.Parallel()
group := getGroupFromFixture(t, fixture, "group-server-1-1")
params := domain.UpdateGroupParams{
ChannelGains: domain.NullString{
String: group.ChannelGains + "update",
@ -74,7 +74,7 @@ func TestGroup_UpdateByID(t *testing.T) {
},
}
updatedGroup, err := repo.UpdateByID(context.Background(), group.ID.String(), params)
updatedGroup, err := repo.Update(context.Background(), group.ID.String(), group.ServerID, params)
assert.NoError(t, err)
assert.Equal(t, params.ChannelGains.String, updatedGroup.ChannelGains)
assert.Equal(t, params.ChannelLosses.String, updatedGroup.ChannelLosses)
@ -85,7 +85,7 @@ func TestGroup_UpdateByID(t *testing.T) {
t.Run("ERR: nothing to update", func(t *testing.T) {
t.Parallel()
updatedGroup, err := repo.UpdateByID(context.Background(), "", domain.UpdateGroupParams{})
updatedGroup, err := repo.Update(context.Background(), "", "", domain.UpdateGroupParams{})
assert.ErrorIs(t, err, domain.ErrNothingToUpdate)
assert.Zero(t, updatedGroup)
})
@ -94,7 +94,7 @@ func TestGroup_UpdateByID(t *testing.T) {
t.Parallel()
id := "12345"
updatedGroup, err := repo.UpdateByID(context.Background(), id, domain.UpdateGroupParams{
updatedGroup, err := repo.Update(context.Background(), id, "", domain.UpdateGroupParams{
ChannelGains: domain.NullString{
String: "update",
Valid: true,
@ -104,11 +104,11 @@ func TestGroup_UpdateByID(t *testing.T) {
assert.Zero(t, updatedGroup)
})
t.Run("ERR: group not found", func(t *testing.T) {
t.Run("ERR: group not found (unknown ID)", func(t *testing.T) {
t.Parallel()
id := uuid.NewString()
updatedGroup, err := repo.UpdateByID(context.Background(), id, domain.UpdateGroupParams{
updatedGroup, err := repo.Update(context.Background(), id, group.ServerID, domain.UpdateGroupParams{
ChannelGains: domain.NullString{
String: "update",
Valid: true,
@ -117,6 +117,19 @@ func TestGroup_UpdateByID(t *testing.T) {
assert.ErrorIs(t, err, domain.GroupNotFoundError{ID: id})
assert.Zero(t, updatedGroup)
})
t.Run("ERR: group not found (unknown ServerID)", func(t *testing.T) {
t.Parallel()
updatedGroup, err := repo.Update(context.Background(), group.ID.String(), uuid.NewString(), domain.UpdateGroupParams{
ChannelGains: domain.NullString{
String: "update",
Valid: true,
},
})
assert.ErrorIs(t, err, domain.GroupNotFoundError{ID: group.ID.String()})
assert.Zero(t, updatedGroup)
})
}
func TestGroup_List(t *testing.T) {

View File

@ -11,9 +11,9 @@ import (
type GroupService interface {
Create(ctx context.Context, params domain.CreateGroupParams) (domain.Group, error)
SetTWServer(ctx context.Context, id, versionCode, serverKey string) (domain.Group, error)
SetChannelGains(ctx context.Context, id, channel string) (domain.Group, error)
SetChannelLosses(ctx context.Context, id, channel string) (domain.Group, error)
SetTWServer(ctx context.Context, id, serverID, versionCode, serverKey string) (domain.Group, error)
SetChannelGains(ctx context.Context, id, serverID, channel string) (domain.Group, error)
SetChannelLosses(ctx context.Context, id, serverID, channel string) (domain.Group, error)
List(ctx context.Context, params domain.ListGroupsParams) ([]domain.Group, error)
}

View File

@ -366,7 +366,7 @@ func (c *groupCommand) handleSetServer(s *discordgo.Session, i *discordgo.Intera
}
}
_, err := c.svc.SetTWServer(ctx, group, version, server)
_, err := c.svc.SetTWServer(ctx, group, i.GuildID, version, server)
if err != nil {
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
@ -402,7 +402,7 @@ func (c *groupCommand) handleSetChannelGains(s *discordgo.Session, i *discordgo.
}
}
_, err := c.svc.SetChannelGains(ctx, group, channel)
_, err := c.svc.SetChannelGains(ctx, group, i.GuildID, channel)
if err != nil {
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
@ -438,7 +438,7 @@ func (c *groupCommand) handleSetChannelLosses(s *discordgo.Session, i *discordgo
}
}
_, err := c.svc.SetChannelLosses(ctx, group, channel)
_, err := c.svc.SetChannelLosses(ctx, group, i.GuildID, channel)
if err != nil {
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
@ -473,7 +473,7 @@ func (c *groupCommand) handleUnsetChannelGains(s *discordgo.Session, i *discordg
ctx := context.Background()
group := i.ApplicationCommandData().Options[0].Options[0].Options[0].StringValue()
_, err := c.svc.SetChannelGains(ctx, group, "")
_, err := c.svc.SetChannelGains(ctx, group, i.GuildID, "")
if err != nil {
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
@ -496,7 +496,7 @@ func (c *groupCommand) handleUnsetChannelLosses(s *discordgo.Session, i *discord
ctx := context.Background()
group := i.ApplicationCommandData().Options[0].Options[0].Options[0].StringValue()
_, err := c.svc.SetChannelLosses(ctx, group, "")
_, err := c.svc.SetChannelLosses(ctx, group, i.GuildID, "")
if err != nil {
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,

View File

@ -16,7 +16,7 @@ const (
//counterfeiter:generate -o internal/mock/group_repository.gen.go . GroupRepository
type GroupRepository interface {
Create(ctx context.Context, params domain.CreateGroupParams) (domain.Group, error)
UpdateByID(ctx context.Context, id string, params domain.UpdateGroupParams) (domain.Group, error)
Update(ctx context.Context, id, serverID string, params domain.UpdateGroupParams) (domain.Group, error)
List(ctx context.Context, params domain.ListGroupsParams) ([]domain.Group, error)
}
@ -44,23 +44,8 @@ func (g *Group) Create(ctx context.Context, params domain.CreateGroupParams) (do
}
}
server, err := g.client.GetServer(ctx, params.VersionCode(), params.ServerKey())
if err != nil {
var apiErr twhelp.APIError
if !errors.As(err, &apiErr) {
return domain.Group{}, fmt.Errorf("TWHelpClient.GetServer: %w", err)
}
return domain.Group{}, domain.ServerDoesNotExistError{
VersionCode: params.VersionCode(),
Key: params.ServerKey(),
}
}
if !server.Open {
return domain.Group{}, domain.ServerIsClosedError{
VersionCode: params.VersionCode(),
Key: params.ServerKey(),
}
if err = g.checkTWServer(ctx, params.VersionCode(), params.ServerKey()); err != nil {
return domain.Group{}, err
}
group, err := g.repo.Create(ctx, params)
@ -71,12 +56,12 @@ func (g *Group) Create(ctx context.Context, params domain.CreateGroupParams) (do
return group, nil
}
func (g *Group) SetTWServer(ctx context.Context, id, versionCode, serverKey string) (domain.Group, error) {
func (g *Group) SetTWServer(ctx context.Context, id, serverID, versionCode, serverKey string) (domain.Group, error) {
if err := g.checkTWServer(ctx, versionCode, serverKey); err != nil {
return domain.Group{}, err
}
group, err := g.repo.UpdateByID(ctx, id, domain.UpdateGroupParams{
group, err := g.repo.Update(ctx, id, serverID, domain.UpdateGroupParams{
VersionCode: domain.NullString{
String: versionCode,
Valid: true,
@ -87,34 +72,34 @@ func (g *Group) SetTWServer(ctx context.Context, id, versionCode, serverKey stri
},
})
if err != nil {
return domain.Group{}, fmt.Errorf("GroupRepository.UpdateByID: %w", err)
return domain.Group{}, fmt.Errorf("GroupRepository.Update: %w", err)
}
return group, nil
}
func (g *Group) SetChannelGains(ctx context.Context, id, channel string) (domain.Group, error) {
group, err := g.repo.UpdateByID(ctx, id, domain.UpdateGroupParams{
func (g *Group) SetChannelGains(ctx context.Context, id, serverID, channel string) (domain.Group, error) {
group, err := g.repo.Update(ctx, id, serverID, domain.UpdateGroupParams{
ChannelGains: domain.NullString{
String: channel,
Valid: true,
},
})
if err != nil {
return domain.Group{}, fmt.Errorf("GroupRepository.UpdateByID: %w", err)
return domain.Group{}, fmt.Errorf("GroupRepository.Update: %w", err)
}
return group, nil
}
func (g *Group) SetChannelLosses(ctx context.Context, id, channel string) (domain.Group, error) {
group, err := g.repo.UpdateByID(ctx, id, domain.UpdateGroupParams{
func (g *Group) SetChannelLosses(ctx context.Context, id, serverID, channel string) (domain.Group, error) {
group, err := g.repo.Update(ctx, id, serverID, domain.UpdateGroupParams{
ChannelLosses: domain.NullString{
String: channel,
Valid: true,
},
})
if err != nil {
return domain.Group{}, fmt.Errorf("GroupRepository.UpdateByID: %w", err)
return domain.Group{}, fmt.Errorf("GroupRepository.Update: %w", err)
}
return group, nil
}

View File

@ -130,6 +130,7 @@ func TestGroup_SetTWServer(t *testing.T) {
t.Parallel()
id := uuid.NewString()
serverID := uuid.NewString()
versionCode := "pl"
serverKey := "pl181"
@ -137,10 +138,10 @@ func TestGroup_SetTWServer(t *testing.T) {
t.Parallel()
repo := &mock.FakeGroupRepository{}
repo.UpdateByIDCalls(func(_ context.Context, id string, p domain.UpdateGroupParams) (domain.Group, error) {
repo.UpdateCalls(func(_ context.Context, id, serverID string, p domain.UpdateGroupParams) (domain.Group, error) {
return domain.Group{
ID: id,
ServerID: uuid.NewString(),
ServerID: serverID,
ChannelGains: p.ChannelGains.String,
ChannelLosses: p.ChannelLosses.String,
ServerKey: p.ServerKey.String,
@ -159,9 +160,10 @@ func TestGroup_SetTWServer(t *testing.T) {
}, nil
})
g, err := service.NewGroup(repo, client).SetTWServer(context.Background(), id, versionCode, serverKey)
g, err := service.NewGroup(repo, client).SetTWServer(context.Background(), id, serverID, versionCode, serverKey)
assert.NoError(t, err)
assert.Equal(t, id, g.ID)
assert.Equal(t, serverID, g.ServerID)
assert.Equal(t, serverKey, g.ServerKey)
assert.Equal(t, versionCode, g.VersionCode)
})
@ -177,7 +179,7 @@ func TestGroup_SetTWServer(t *testing.T) {
}
})
g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, versionCode, serverKey)
g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, serverID, versionCode, serverKey)
assert.ErrorIs(t, err, domain.ServerDoesNotExistError{
VersionCode: versionCode,
Key: serverKey,
@ -197,7 +199,7 @@ func TestGroup_SetTWServer(t *testing.T) {
}, nil
})
g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, versionCode, serverKey)
g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, serverID, versionCode, serverKey)
assert.ErrorIs(t, err, domain.ServerIsClosedError{
VersionCode: versionCode,
Key: serverKey,