fix(security): server id must be verified in the group repository #15
|
@ -39,7 +39,7 @@ func (g *Group) Create(ctx context.Context, params domain.CreateGroupParams) (do
|
||||||
return group.ToDomain(), nil
|
return group.ToDomain(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Group) UpdateByID(ctx context.Context, id string, params domain.UpdateGroupParams) (domain.Group, error) {
|
func (g *Group) Update(ctx context.Context, id, serverID string, params domain.UpdateGroupParams) (domain.Group, error) {
|
||||||
if params.IsZero() {
|
if params.IsZero() {
|
||||||
return domain.Group{}, domain.ErrNothingToUpdate
|
return domain.Group{}, domain.ErrNothingToUpdate
|
||||||
}
|
}
|
||||||
|
@ -54,6 +54,7 @@ func (g *Group) UpdateByID(ctx context.Context, id string, params domain.UpdateG
|
||||||
Model(&group).
|
Model(&group).
|
||||||
Returning("*").
|
Returning("*").
|
||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
|
Where("server_id = ?", serverID).
|
||||||
Apply(updateGroupsParamsApplier{params}.apply).
|
Apply(updateGroupsParamsApplier{params}.apply).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
|
|
@ -44,17 +44,17 @@ func TestGroup_Create(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGroup_UpdateByID(t *testing.T) {
|
func TestGroup_Update(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
db := newDB(t)
|
db := newDB(t)
|
||||||
fixture := loadFixtures(t, db)
|
fixture := loadFixtures(t, db)
|
||||||
repo := bundb.NewGroup(db)
|
repo := bundb.NewGroup(db)
|
||||||
|
group := getGroupFromFixture(t, fixture, "group-server-1-1")
|
||||||
|
|
||||||
t.Run("OK", func(t *testing.T) {
|
t.Run("OK", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
group := getGroupFromFixture(t, fixture, "group-server-1-1")
|
|
||||||
params := domain.UpdateGroupParams{
|
params := domain.UpdateGroupParams{
|
||||||
ChannelGains: domain.NullString{
|
ChannelGains: domain.NullString{
|
||||||
String: group.ChannelGains + "update",
|
String: group.ChannelGains + "update",
|
||||||
|
@ -74,7 +74,7 @@ func TestGroup_UpdateByID(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedGroup, err := repo.UpdateByID(context.Background(), group.ID.String(), params)
|
updatedGroup, err := repo.Update(context.Background(), group.ID.String(), group.ServerID, params)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, params.ChannelGains.String, updatedGroup.ChannelGains)
|
assert.Equal(t, params.ChannelGains.String, updatedGroup.ChannelGains)
|
||||||
assert.Equal(t, params.ChannelLosses.String, updatedGroup.ChannelLosses)
|
assert.Equal(t, params.ChannelLosses.String, updatedGroup.ChannelLosses)
|
||||||
|
@ -85,7 +85,7 @@ func TestGroup_UpdateByID(t *testing.T) {
|
||||||
t.Run("ERR: nothing to update", func(t *testing.T) {
|
t.Run("ERR: nothing to update", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
updatedGroup, err := repo.UpdateByID(context.Background(), "", domain.UpdateGroupParams{})
|
updatedGroup, err := repo.Update(context.Background(), "", "", domain.UpdateGroupParams{})
|
||||||
assert.ErrorIs(t, err, domain.ErrNothingToUpdate)
|
assert.ErrorIs(t, err, domain.ErrNothingToUpdate)
|
||||||
assert.Zero(t, updatedGroup)
|
assert.Zero(t, updatedGroup)
|
||||||
})
|
})
|
||||||
|
@ -94,7 +94,7 @@ func TestGroup_UpdateByID(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
id := "12345"
|
id := "12345"
|
||||||
updatedGroup, err := repo.UpdateByID(context.Background(), id, domain.UpdateGroupParams{
|
updatedGroup, err := repo.Update(context.Background(), id, "", domain.UpdateGroupParams{
|
||||||
ChannelGains: domain.NullString{
|
ChannelGains: domain.NullString{
|
||||||
String: "update",
|
String: "update",
|
||||||
Valid: true,
|
Valid: true,
|
||||||
|
@ -104,11 +104,11 @@ func TestGroup_UpdateByID(t *testing.T) {
|
||||||
assert.Zero(t, updatedGroup)
|
assert.Zero(t, updatedGroup)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("ERR: group not found", func(t *testing.T) {
|
t.Run("ERR: group not found (unknown ID)", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
id := uuid.NewString()
|
id := uuid.NewString()
|
||||||
updatedGroup, err := repo.UpdateByID(context.Background(), id, domain.UpdateGroupParams{
|
updatedGroup, err := repo.Update(context.Background(), id, group.ServerID, domain.UpdateGroupParams{
|
||||||
ChannelGains: domain.NullString{
|
ChannelGains: domain.NullString{
|
||||||
String: "update",
|
String: "update",
|
||||||
Valid: true,
|
Valid: true,
|
||||||
|
@ -117,6 +117,19 @@ func TestGroup_UpdateByID(t *testing.T) {
|
||||||
assert.ErrorIs(t, err, domain.GroupNotFoundError{ID: id})
|
assert.ErrorIs(t, err, domain.GroupNotFoundError{ID: id})
|
||||||
assert.Zero(t, updatedGroup)
|
assert.Zero(t, updatedGroup)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("ERR: group not found (unknown ServerID)", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
updatedGroup, err := repo.Update(context.Background(), group.ID.String(), uuid.NewString(), domain.UpdateGroupParams{
|
||||||
|
ChannelGains: domain.NullString{
|
||||||
|
String: "update",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
assert.ErrorIs(t, err, domain.GroupNotFoundError{ID: group.ID.String()})
|
||||||
|
assert.Zero(t, updatedGroup)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGroup_List(t *testing.T) {
|
func TestGroup_List(t *testing.T) {
|
||||||
|
|
|
@ -11,9 +11,9 @@ import (
|
||||||
|
|
||||||
type GroupService interface {
|
type GroupService interface {
|
||||||
Create(ctx context.Context, params domain.CreateGroupParams) (domain.Group, error)
|
Create(ctx context.Context, params domain.CreateGroupParams) (domain.Group, error)
|
||||||
SetTWServer(ctx context.Context, id, versionCode, serverKey string) (domain.Group, error)
|
SetTWServer(ctx context.Context, id, serverID, versionCode, serverKey string) (domain.Group, error)
|
||||||
SetChannelGains(ctx context.Context, id, channel string) (domain.Group, error)
|
SetChannelGains(ctx context.Context, id, serverID, channel string) (domain.Group, error)
|
||||||
SetChannelLosses(ctx context.Context, id, channel string) (domain.Group, error)
|
SetChannelLosses(ctx context.Context, id, serverID, channel string) (domain.Group, error)
|
||||||
List(ctx context.Context, params domain.ListGroupsParams) ([]domain.Group, error)
|
List(ctx context.Context, params domain.ListGroupsParams) ([]domain.Group, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -366,7 +366,7 @@ func (c *groupCommand) handleSetServer(s *discordgo.Session, i *discordgo.Intera
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := c.svc.SetTWServer(ctx, group, version, server)
|
_, err := c.svc.SetTWServer(ctx, group, i.GuildID, version, server)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||||
|
@ -402,7 +402,7 @@ func (c *groupCommand) handleSetChannelGains(s *discordgo.Session, i *discordgo.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := c.svc.SetChannelGains(ctx, group, channel)
|
_, err := c.svc.SetChannelGains(ctx, group, i.GuildID, channel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||||
|
@ -438,7 +438,7 @@ func (c *groupCommand) handleSetChannelLosses(s *discordgo.Session, i *discordgo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := c.svc.SetChannelLosses(ctx, group, channel)
|
_, err := c.svc.SetChannelLosses(ctx, group, i.GuildID, channel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||||
|
@ -473,7 +473,7 @@ func (c *groupCommand) handleUnsetChannelGains(s *discordgo.Session, i *discordg
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
group := i.ApplicationCommandData().Options[0].Options[0].Options[0].StringValue()
|
group := i.ApplicationCommandData().Options[0].Options[0].Options[0].StringValue()
|
||||||
_, err := c.svc.SetChannelGains(ctx, group, "")
|
_, err := c.svc.SetChannelGains(ctx, group, i.GuildID, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||||
|
@ -496,7 +496,7 @@ func (c *groupCommand) handleUnsetChannelLosses(s *discordgo.Session, i *discord
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
group := i.ApplicationCommandData().Options[0].Options[0].Options[0].StringValue()
|
group := i.ApplicationCommandData().Options[0].Options[0].Options[0].StringValue()
|
||||||
_, err := c.svc.SetChannelLosses(ctx, group, "")
|
_, err := c.svc.SetChannelLosses(ctx, group, i.GuildID, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
_ = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
|
||||||
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
Type: discordgo.InteractionResponseChannelMessageWithSource,
|
||||||
|
|
|
@ -16,7 +16,7 @@ const (
|
||||||
//counterfeiter:generate -o internal/mock/group_repository.gen.go . GroupRepository
|
//counterfeiter:generate -o internal/mock/group_repository.gen.go . GroupRepository
|
||||||
type GroupRepository interface {
|
type GroupRepository interface {
|
||||||
Create(ctx context.Context, params domain.CreateGroupParams) (domain.Group, error)
|
Create(ctx context.Context, params domain.CreateGroupParams) (domain.Group, error)
|
||||||
UpdateByID(ctx context.Context, id string, params domain.UpdateGroupParams) (domain.Group, error)
|
Update(ctx context.Context, id, serverID string, params domain.UpdateGroupParams) (domain.Group, error)
|
||||||
List(ctx context.Context, params domain.ListGroupsParams) ([]domain.Group, error)
|
List(ctx context.Context, params domain.ListGroupsParams) ([]domain.Group, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,23 +44,8 @@ func (g *Group) Create(ctx context.Context, params domain.CreateGroupParams) (do
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
server, err := g.client.GetServer(ctx, params.VersionCode(), params.ServerKey())
|
if err = g.checkTWServer(ctx, params.VersionCode(), params.ServerKey()); err != nil {
|
||||||
if err != nil {
|
return domain.Group{}, err
|
||||||
var apiErr twhelp.APIError
|
|
||||||
if !errors.As(err, &apiErr) {
|
|
||||||
return domain.Group{}, fmt.Errorf("TWHelpClient.GetServer: %w", err)
|
|
||||||
}
|
|
||||||
return domain.Group{}, domain.ServerDoesNotExistError{
|
|
||||||
VersionCode: params.VersionCode(),
|
|
||||||
Key: params.ServerKey(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !server.Open {
|
|
||||||
return domain.Group{}, domain.ServerIsClosedError{
|
|
||||||
VersionCode: params.VersionCode(),
|
|
||||||
Key: params.ServerKey(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
group, err := g.repo.Create(ctx, params)
|
group, err := g.repo.Create(ctx, params)
|
||||||
|
@ -71,12 +56,12 @@ func (g *Group) Create(ctx context.Context, params domain.CreateGroupParams) (do
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Group) SetTWServer(ctx context.Context, id, versionCode, serverKey string) (domain.Group, error) {
|
func (g *Group) SetTWServer(ctx context.Context, id, serverID, versionCode, serverKey string) (domain.Group, error) {
|
||||||
if err := g.checkTWServer(ctx, versionCode, serverKey); err != nil {
|
if err := g.checkTWServer(ctx, versionCode, serverKey); err != nil {
|
||||||
return domain.Group{}, err
|
return domain.Group{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
group, err := g.repo.UpdateByID(ctx, id, domain.UpdateGroupParams{
|
group, err := g.repo.Update(ctx, id, serverID, domain.UpdateGroupParams{
|
||||||
VersionCode: domain.NullString{
|
VersionCode: domain.NullString{
|
||||||
String: versionCode,
|
String: versionCode,
|
||||||
Valid: true,
|
Valid: true,
|
||||||
|
@ -87,34 +72,34 @@ func (g *Group) SetTWServer(ctx context.Context, id, versionCode, serverKey stri
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return domain.Group{}, fmt.Errorf("GroupRepository.UpdateByID: %w", err)
|
return domain.Group{}, fmt.Errorf("GroupRepository.Update: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Group) SetChannelGains(ctx context.Context, id, channel string) (domain.Group, error) {
|
func (g *Group) SetChannelGains(ctx context.Context, id, serverID, channel string) (domain.Group, error) {
|
||||||
group, err := g.repo.UpdateByID(ctx, id, domain.UpdateGroupParams{
|
group, err := g.repo.Update(ctx, id, serverID, domain.UpdateGroupParams{
|
||||||
ChannelGains: domain.NullString{
|
ChannelGains: domain.NullString{
|
||||||
String: channel,
|
String: channel,
|
||||||
Valid: true,
|
Valid: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return domain.Group{}, fmt.Errorf("GroupRepository.UpdateByID: %w", err)
|
return domain.Group{}, fmt.Errorf("GroupRepository.Update: %w", err)
|
||||||
}
|
}
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Group) SetChannelLosses(ctx context.Context, id, channel string) (domain.Group, error) {
|
func (g *Group) SetChannelLosses(ctx context.Context, id, serverID, channel string) (domain.Group, error) {
|
||||||
group, err := g.repo.UpdateByID(ctx, id, domain.UpdateGroupParams{
|
group, err := g.repo.Update(ctx, id, serverID, domain.UpdateGroupParams{
|
||||||
ChannelLosses: domain.NullString{
|
ChannelLosses: domain.NullString{
|
||||||
String: channel,
|
String: channel,
|
||||||
Valid: true,
|
Valid: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return domain.Group{}, fmt.Errorf("GroupRepository.UpdateByID: %w", err)
|
return domain.Group{}, fmt.Errorf("GroupRepository.Update: %w", err)
|
||||||
}
|
}
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -130,6 +130,7 @@ func TestGroup_SetTWServer(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
id := uuid.NewString()
|
id := uuid.NewString()
|
||||||
|
serverID := uuid.NewString()
|
||||||
versionCode := "pl"
|
versionCode := "pl"
|
||||||
serverKey := "pl181"
|
serverKey := "pl181"
|
||||||
|
|
||||||
|
@ -137,10 +138,10 @@ func TestGroup_SetTWServer(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
repo := &mock.FakeGroupRepository{}
|
repo := &mock.FakeGroupRepository{}
|
||||||
repo.UpdateByIDCalls(func(_ context.Context, id string, p domain.UpdateGroupParams) (domain.Group, error) {
|
repo.UpdateCalls(func(_ context.Context, id, serverID string, p domain.UpdateGroupParams) (domain.Group, error) {
|
||||||
return domain.Group{
|
return domain.Group{
|
||||||
ID: id,
|
ID: id,
|
||||||
ServerID: uuid.NewString(),
|
ServerID: serverID,
|
||||||
ChannelGains: p.ChannelGains.String,
|
ChannelGains: p.ChannelGains.String,
|
||||||
ChannelLosses: p.ChannelLosses.String,
|
ChannelLosses: p.ChannelLosses.String,
|
||||||
ServerKey: p.ServerKey.String,
|
ServerKey: p.ServerKey.String,
|
||||||
|
@ -159,9 +160,10 @@ func TestGroup_SetTWServer(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
g, err := service.NewGroup(repo, client).SetTWServer(context.Background(), id, versionCode, serverKey)
|
g, err := service.NewGroup(repo, client).SetTWServer(context.Background(), id, serverID, versionCode, serverKey)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, id, g.ID)
|
assert.Equal(t, id, g.ID)
|
||||||
|
assert.Equal(t, serverID, g.ServerID)
|
||||||
assert.Equal(t, serverKey, g.ServerKey)
|
assert.Equal(t, serverKey, g.ServerKey)
|
||||||
assert.Equal(t, versionCode, g.VersionCode)
|
assert.Equal(t, versionCode, g.VersionCode)
|
||||||
})
|
})
|
||||||
|
@ -177,7 +179,7 @@ func TestGroup_SetTWServer(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, versionCode, serverKey)
|
g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, serverID, versionCode, serverKey)
|
||||||
assert.ErrorIs(t, err, domain.ServerDoesNotExistError{
|
assert.ErrorIs(t, err, domain.ServerDoesNotExistError{
|
||||||
VersionCode: versionCode,
|
VersionCode: versionCode,
|
||||||
Key: serverKey,
|
Key: serverKey,
|
||||||
|
@ -197,7 +199,7 @@ func TestGroup_SetTWServer(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, versionCode, serverKey)
|
g, err := service.NewGroup(nil, client).SetTWServer(context.Background(), id, serverID, versionCode, serverKey)
|
||||||
assert.ErrorIs(t, err, domain.ServerIsClosedError{
|
assert.ErrorIs(t, err, domain.ServerIsClosedError{
|
||||||
VersionCode: versionCode,
|
VersionCode: versionCode,
|
||||||
Key: serverKey,
|
Key: serverKey,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user