fix(security): server id must be verified in the group repository
All checks were successful
continuous-integration/drone/pr Build is passing
All checks were successful
continuous-integration/drone/pr Build is passing
This commit is contained in:
parent
9c1884d624
commit
90f7a5476b
|
@ -39,7 +39,7 @@ func (g *Group) Create(ctx context.Context, params domain.CreateGroupParams) (do
|
|||
return group.ToDomain(), nil
|
||||
}
|
||||
|
||||
func (g *Group) UpdateByID(ctx context.Context, id string, params domain.UpdateGroupParams) (domain.Group, error) {
|
||||
func (g *Group) Update(ctx context.Context, id, serverID string, params domain.UpdateGroupParams) (domain.Group, error) {
|
||||
if params.IsZero() {
|
||||
return domain.Group{}, domain.ErrNothingToUpdate
|
||||
}
|
||||
|
@ -54,6 +54,7 @@ func (g *Group) UpdateByID(ctx context.Context, id string, params domain.UpdateG
|
|||
Model(&group).
|
||||
Returning("*").
|
||||
Where("id = ?", id).
|
||||
Where("server_id = ?", serverID).
|
||||
Apply(updateGroupsParamsApplier{params}.apply).
|
||||
Exec(ctx)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
|
|
|
@ -44,17 +44,17 @@ func TestGroup_Create(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestGroup_UpdateByID(t *testing.T) {
|
||||
func TestGroup_Update(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := newDB(t)
|
||||
fixture := loadFixtures(t, db)
|
||||
repo := bundb.NewGroup(db)
|
||||
group := getGroupFromFixture(t, fixture, "group-server-1-1")
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
group := getGroupFromFixture(t, fixture, "group-server-1-1")
|
||||
params := domain.UpdateGroupParams{
|
||||
ChannelGains: domain.NullString{
|
||||
String: group.ChannelGains + "update",
|
||||
|
@ -74,7 +74,7 @@ func TestGroup_UpdateByID(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
updatedGroup, err := repo.UpdateByID(context.Background(), group.ID.String(), params)
|
||||
updatedGroup, err := repo.Update(context.Background(), group.ID.String(), group.ServerID, params)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, params.ChannelGains.String, updatedGroup.ChannelGains)
|
||||
assert.Equal(t, params.ChannelLosses.String, updatedGroup.ChannelLosses)
|
||||
|
@ -85,7 +85,7 @@ func TestGroup_UpdateByID(t *testing.T) {
|
|||
t.Run("ERR: nothing to update", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
updatedGroup, err := repo.UpdateByID(context.Background(), "", domain.UpdateGroupParams{})
|
||||
updatedGroup, err := repo.Update(context.Background(), "", "", domain.UpdateGroupParams{})
|
||||
assert.ErrorIs(t, err, domain.ErrNothingToUpdate)
|
||||
assert.Zero(t, updatedGroup)
|
||||
})
|
||||
|
@ -94,7 +94,7 @@ func TestGroup_UpdateByID(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
id := "12345"
|
||||
updatedGroup, err := repo.UpdateByID(context.Background(), id, domain.UpdateGroupParams{
|
||||
updatedGroup, err := repo.Update(context.Background(), id, "", domain.UpdateGroupParams{
|
||||
ChannelGains: domain.NullString{
|
||||
String: "update",
|
||||
Valid: true,
|
||||
|
@ -104,11 +104,11 @@ func TestGroup_UpdateByID(t *testing.T) {
|
|||
assert.Zero(t, updatedGroup)
|
||||
})
|
||||
|
||||
t.Run("ERR: group not found", func(t *testing.T) {
|
||||
t.Run("ERR: group not found (unknown ID)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
id := uuid.NewString()
|
||||
updatedGroup, err := repo.UpdateByID(context.Background(), id, domain.UpdateGroupParams{
|
||||
updatedGroup, err := repo.Update(context.Background(), id, group.ServerID, domain.UpdateGroupParams{
|
||||
ChannelGains: domain.NullString{
|
||||
String: "update",
|
||||
Valid: true,
|
||||
|
@ -117,6 +117,19 @@ func TestGroup_UpdateByID(t *testing.T) {
|
|||
assert.ErrorIs(t, err, domain.GroupNotFoundError{ID: id})
|
||||
assert.Zero(t, updatedGroup)
|
||||
})
|
||||
|
||||
t.Run("ERR: group not found (unknown ServerID)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
updatedGroup, err := repo.Update(context.Background(), group.ID.String(), uuid.NewString(), domain.UpdateGroupParams{
|
||||
ChannelGains: domain.NullString{
|
||||
String: "update",
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
assert.ErrorIs(t, err, domain.GroupNotFoundError{ID: group.ID.String()})
|
||||
assert.Zero(t, updatedGroup)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGroup_List(t *testing.T) {
|
||||
|
|
|
@ -11,9 +11,9 @@ import (
|
|||
|
||||
type GroupService interface {
|
||||
Create(ctx context.Context, params domain.CreateGroupParams) (domain.Group, error)
|
||||
SetTWServer(ctx context.Context, id, versionCode, serverKey string) (domain.Group, error)
|
||||
SetChannelGains(ctx context.Context, id, channel string) (domain.Group, error)
|
||||
SetChannelLosses(ctx context.Context, id, channel string) (domain.Group, error)
|
||||
SetTWServer(ctx context.Context, id, serverID, versionCode, serverKey string) (domain.Group, error)
|
||||
SetChannelGains(ctx context.Context, id, serverID, channel string) (domain.Group, error)
|
||||
SetChannelLosses(ctx context.Context, id, serverID, channel string) (domain.Group, error)
|
||||
List(ctx context.Context, params domain.ListGroupsParams) ([]domain.Group, error)
|
||||
}
|
||||
|
||||
|
|
|
@ -366,7 +366,7 @@ func (c *groupCommand) handleSetServer(s *discordgo.Session, i *discordgo.Intera
|
|||
}
|
||||
}
|
||||
|
||||
_, err := c.svc.SetTWServer(ctx, group, version, server)
|
||||
_, err := c.svc.SetTWServer(ctx, group, i.GuildID, version, server)
|
||||
if err != nil {
|
||||
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||
|
@ -402,7 +402,7 @@ func (c *groupCommand) handleSetChannelGains(s *discordgo.Session, i *discordgo.
|
|||
}
|
||||
}
|
||||
|
||||
_, err := c.svc.SetChannelGains(ctx, group, channel)
|
||||
_, err := c.svc.SetChannelGains(ctx, group, i.GuildID, channel)
|
||||
if err != nil {
|
||||
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||
|
@ -438,7 +438,7 @@ func (c *groupCommand) handleSetChannelLosses(s *discordgo.Session, i *discordgo
|
|||
}
|
||||
}
|
||||
|
||||
_, err := c.svc.SetChannelLosses(ctx, group, channel)
|
||||
_, err := c.svc.SetChannelLosses(ctx, group, i.GuildID, channel)
|
||||
if err != nil {
|
||||
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||
|
@ -473,7 +473,7 @@ func (c *groupCommand) handleUnsetChannelGains(s *discordgo.Session, i *discordg
|
|||
ctx := context.Background()
|
||||
|
||||
group := i.ApplicationCommandData().Options[0].Options[0].Options[0].StringValue()
|
||||
_, err := c.svc.SetChannelGains(ctx, group, "")
|
||||
_, err := c.svc.SetChannelGains(ctx, group, i.GuildID, "")
|
||||
if err != nil {
|
||||
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||
|
@ -496,7 +496,7 @@ func (c *groupCommand) handleUnsetChannelLosses(s *discordgo.Session, i *discord
|
|||
ctx := context.Background()
|
||||
|
||||
group := i.ApplicationCommandData().Options[0].Options[0].Options[0].StringValue()
|
||||
_, err := c.svc.SetChannelLosses(ctx, group, "")
|
||||
_, err := c.svc.SetChannelLosses(ctx, group, i.GuildID, "")
|
||||
if err != nil {
|
||||
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||
|
|
|
@ -16,7 +16,7 @@ const (
|
|||
//counterfeiter:generate -o internal/mock/group_repository.gen.go . GroupRepository
|
||||
type GroupRepository interface {
|
||||
Create(ctx context.Context, params domain.CreateGroupParams) (domain.Group, error)
|
||||
UpdateByID(ctx context.Context, id string, params domain.UpdateGroupParams) (domain.Group, error)
|
||||
Update(ctx context.Context, id, serverID string, params domain.UpdateGroupParams) (domain.Group, error)
|
||||
List(ctx context.Context, params domain.ListGroupsParams) ([]domain.Group, error)
|
||||
}
|
||||
|
||||
|
@ -44,23 +44,8 @@ func (g *Group) Create(ctx context.Context, params domain.CreateGroupParams) (do
|
|||
}
|
||||
}
|
||||
|
||||
server, err := g.client.GetServer(ctx, params.VersionCode(), params.ServerKey())
|
||||
if err != nil {
|
||||
var apiErr twhelp.APIError
|
||||
if !errors.As(err, &apiErr) {
|
||||
return domain.Group{}, fmt.Errorf("TWHelpClient.GetServer: %w", err)
|
||||
}
|
||||
return domain.Group{}, domain.ServerDoesNotExistError{
|
||||
VersionCode: params.VersionCode(),
|
||||
Key: params.ServerKey(),
|
||||
}
|
||||
}
|
||||
|
||||
if !server.Open {
|
||||
return domain.Group{}, domain.ServerIsClosedError{
|
||||
VersionCode: params.VersionCode(),
|
||||
Key: params.ServerKey(),
|
||||
}
|
||||
if err = g.checkTWServer(ctx, params.VersionCode(), params.ServerKey()); err != nil {
|
||||
return domain.Group{}, err
|
||||
}
|
||||
|
||||
group, err := g.repo.Create(ctx, params)
|
||||
|
@ -71,12 +56,12 @@ func (g *Group) Create(ctx context.Context, params domain.CreateGroupParams) (do
|
|||
return group, nil
|
||||
}
|
||||
|
||||
func (g *Group) SetTWServer(ctx context.Context, id, versionCode, serverKey string) (domain.Group, error) {
|
||||
func (g *Group) SetTWServer(ctx context.Context, id, serverID, versionCode, serverKey string) (domain.Group, error) {
|
||||
if err := g.checkTWServer(ctx, versionCode, serverKey); err != nil {
|
||||
return domain.Group{}, err
|
||||
}
|
||||
|
||||
group, err := g.repo.UpdateByID(ctx, id, domain.UpdateGroupParams{
|
||||
group, err := g.repo.Update(ctx, id, serverID, domain.UpdateGroupParams{
|
||||
VersionCode: domain.NullString{
|
||||
String: versionCode,
|
||||
Valid: true,
|
||||
|
@ -87,34 +72,34 @@ func (g *Group) SetTWServer(ctx context.Context, id, versionCode, serverKey stri
|
|||
},
|
||||
})
|
||||
if err != nil {
|
||||
return domain.Group{}, fmt.Errorf("GroupRepository.UpdateByID: %w", err)
|
||||
return domain.Group{}, fmt.Errorf("GroupRepository.Update: %w", err)
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (g *Group) SetChannelGains(ctx context.Context, id, channel string) (domain.Group, error) {
|
||||
group, err := g.repo.UpdateByID(ctx, id, domain.UpdateGroupParams{
|
||||
func (g *Group) SetChannelGains(ctx context.Context, id, serverID, channel string) (domain.Group, error) {
|
||||
group, err := g.repo.Update(ctx, id, serverID, domain.UpdateGroupParams{
|
||||
ChannelGains: domain.NullString{
|
||||
String: channel,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return domain.Group{}, fmt.Errorf("GroupRepository.UpdateByID: %w", err)
|
||||
return domain.Group{}, fmt.Errorf("GroupRepository.Update: %w", err)
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (g *Group) SetChannelLosses(ctx context.Context, id, channel string) (domain.Group, error) {
|
||||
group, err := g.repo.UpdateByID(ctx, id, domain.UpdateGroupParams{
|
||||
func (g *Group) SetChannelLosses(ctx context.Context, id, serverID, channel string) (domain.Group, error) {
|
||||
group, err := g.repo.Update(ctx, id, serverID, domain.UpdateGroupParams{
|
||||
ChannelLosses: domain.NullString{
|
||||
String: channel,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return domain.Group{}, fmt.Errorf("GroupRepository.UpdateByID: %w", err)
|
||||
return domain.Group{}, fmt.Errorf("GroupRepository.Update: %w", err)
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
|
|
@ -130,6 +130,7 @@ func TestGroup_SetTWServer(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
id := uuid.NewString()
|
||||
serverID := uuid.NewString()
|
||||
versionCode := "pl"
|
||||
serverKey := "pl181"
|
||||
|
||||
|
@ -137,10 +138,10 @@ func TestGroup_SetTWServer(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
repo := &mock.FakeGroupRepository{}
|
||||
repo.UpdateByIDCalls(func(_ context.Context, id string, p domain.UpdateGroupParams) (domain.Group, error) {
|
||||
repo.UpdateCalls(func(_ context.Context, id, serverID string, p domain.UpdateGroupParams) (domain.Group, error) {
|
||||
return domain.Group{
|
||||
ID: id,
|
||||
ServerID: uuid.NewString(),
|
||||
ServerID: serverID,
|
||||
ChannelGains: p.ChannelGains.String,
|
||||
ChannelLosses: p.ChannelLosses.String,
|
||||
ServerKey: p.ServerKey.String,
|
||||
|
@ -159,9 +160,10 @@ func TestGroup_SetTWServer(t *testing.T) {
|
|||
}, nil
|
||||
})
|
||||
|
||||
g, err := service.NewGroup(repo, client).SetTWServer(context.Background(), id, versionCode, serverKey)
|
||||
g, err := service.NewGroup(repo, client).SetTWServer(context.Background(), id, serverID, versionCode, serverKey)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, id, g.ID)
|
||||
assert.Equal(t, serverID, g.ServerID)
|
||||
assert.Equal(t, serverKey, g.ServerKey)
|
||||
assert.Equal(t, versionCode, g.VersionCode)
|
||||
})
|
||||
|
@ -177,7 +179,7 @@ func TestGroup_SetTWServer(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, versionCode, serverKey)
|
||||
g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, serverID, versionCode, serverKey)
|
||||
assert.ErrorIs(t, err, domain.ServerDoesNotExistError{
|
||||
VersionCode: versionCode,
|
||||
Key: serverKey,
|
||||
|
@ -197,7 +199,7 @@ func TestGroup_SetTWServer(t *testing.T) {
|
|||
}, nil
|
||||
})
|
||||
|
||||
g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, versionCode, serverKey)
|
||||
g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, serverID, versionCode, serverKey)
|
||||
assert.ErrorIs(t, err, domain.ServerIsClosedError{
|
||||
VersionCode: versionCode,
|
||||
Key: serverKey,
|
||||
|
|
Loading…
Reference in New Issue
Block a user