feat: add a new REST endpoint - PUT /api/v1/user/sessions/:server (#8)
continuous-integration/drone/push Build is passing Details

Reviewed-on: #8
This commit is contained in:
Dawid Wysokiński 2022-11-24 05:17:32 +00:00
parent 5d5ca74719
commit 8e8e7e1f94
23 changed files with 947 additions and 86 deletions

View File

@ -91,10 +91,12 @@ func newServer(cfg serverConfig) (*http.Server, error) {
// repos
userRepo := bundb.NewUser(cfg.db)
apiKeyRepo := bundb.NewAPIKey(cfg.db)
sessionRepo := bundb.NewSession(cfg.db)
// services
userSvc := service.NewUser(userRepo)
apiKeySvc := service.NewAPIKey(apiKeyRepo, userSvc)
sessionSvc := service.NewSession(sessionRepo, userSvc)
// router
r := chi.NewRouter()
@ -102,6 +104,7 @@ func newServer(cfg serverConfig) (*http.Server, error) {
r.Mount(metaEndpointsPrefix, meta.NewRouter([]meta.Checker{bundb.NewChecker(cfg.db)}))
r.Mount("/api", rest.NewRouter(rest.RouterConfig{
APIKeyVerifier: apiKeySvc,
SessionService: sessionSvc,
CORS: rest.CORSConfig{
Enabled: apiCfg.CORSEnabled,
AllowedOrigins: apiCfg.CORSAllowedOrigins,

View File

@ -25,7 +25,7 @@ func TestAPIKey_Create(t *testing.T) {
t.Parallel()
params, err := domain.NewCreateAPIKeyParams(uuid.NewString(), fixture.user(t, "user-1").ID)
require.Nil(t, err)
require.NoError(t, err)
apiKey, err := repo.Create(context.Background(), params)
assert.NoError(t, err)
@ -35,11 +35,11 @@ func TestAPIKey_Create(t *testing.T) {
assert.WithinDuration(t, time.Now(), apiKey.CreatedAt, time.Second)
})
t.Run("ERR: player doesn't exist", func(t *testing.T) {
t.Run("ERR: user doesn't exist", func(t *testing.T) {
t.Parallel()
params, err := domain.NewCreateAPIKeyParams(uuid.NewString(), fixture.user(t, "user-1").ID+1)
require.Nil(t, err)
require.NoError(t, err)
apiKey, err := repo.Create(context.Background(), params)
assert.ErrorIs(t, err, domain.UserDoesNotExistError{
@ -55,7 +55,7 @@ func TestAPIKey_Create(t *testing.T) {
fixture.apiKey(t, "user-1-api-key-1").Key,
fixture.user(t, "user-1").ID,
)
require.Nil(t, err)
require.NoError(t, err)
apiKey, err := repo.Create(context.Background(), params)
var pgError pgdriver.Error

View File

@ -24,7 +24,7 @@ func NewAPIKey(p domain.CreateAPIKeyParams) APIKey {
}
}
func (a *APIKey) ToDomain() domain.APIKey {
func (a APIKey) ToDomain() domain.APIKey {
return domain.APIKey{
ID: a.ID,
Key: a.Key,

View File

@ -0,0 +1,40 @@
package model
import (
"time"
"gitea.dwysokinski.me/twhelp/sessions/internal/domain"
"github.com/uptrace/bun"
)
type Session struct {
bun.BaseModel `bun:"base_model,table:sessions,alias:session"`
ID int64 `bun:"id,pk,autoincrement,identity"`
UserID int64 `bun:"user_id,nullzero,notnull,unique:sessions_user_id_server_key_key"`
ServerKey string `bun:"server_key,type:varchar(10),nullzero,notnull,unique:sessions_user_id_server_key_key"`
SID string `bun:"sid,type:varchar(512),nullzero,notnull"`
CreatedAt time.Time `bun:"created_at,nullzero,notnull,default:current_timestamp"`
UpdatedAt time.Time `bun:"updated_at,nullzero,notnull,default:current_timestamp"`
}
func NewSession(p domain.CreateSessionParams) Session {
return Session{
UserID: p.UserID(),
ServerKey: p.ServerKey(),
SID: p.SID(),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
}
func (s Session) ToDomain() domain.Session {
return domain.Session{
ID: s.ID,
UserID: s.UserID,
ServerKey: s.ServerKey,
SID: s.SID,
CreatedAt: s.CreatedAt,
UpdatedAt: s.UpdatedAt,
}
}

View File

@ -0,0 +1,35 @@
package model_test
import (
"encoding/base64"
"testing"
"time"
"gitea.dwysokinski.me/twhelp/sessions/internal/bundb/internal/model"
"gitea.dwysokinski.me/twhelp/sessions/internal/domain"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func TestSession(t *testing.T) {
t.Parallel()
var id int64 = 123
params, err := domain.NewCreateSessionParams(
"pl151",
base64.StdEncoding.EncodeToString([]byte(uuid.NewString())),
1,
)
assert.NoError(t, err)
result := model.NewSession(params)
result.ID = id
apiKey := result.ToDomain()
assert.Equal(t, id, apiKey.ID)
assert.Equal(t, params.ServerKey(), apiKey.ServerKey)
assert.Equal(t, params.SID(), apiKey.SID)
assert.Equal(t, params.UserID(), apiKey.UserID)
assert.WithinDuration(t, time.Now(), apiKey.CreatedAt, 10*time.Millisecond)
assert.WithinDuration(t, time.Now(), apiKey.UpdatedAt, 10*time.Millisecond)
}

View File

@ -0,0 +1,31 @@
package migrations
import (
"context"
"fmt"
"gitea.dwysokinski.me/twhelp/sessions/internal/bundb/internal/model"
"github.com/uptrace/bun"
)
func init() {
Migrations.MustRegister(func(ctx context.Context, db *bun.DB) error {
if _, err := db.NewCreateTable().
Model(&model.Session{}).
Varchar(defaultVarcharLength).
ForeignKey(`(user_id) REFERENCES users (id) ON DELETE CASCADE`).
Exec(ctx); err != nil {
return fmt.Errorf("couldn't create the 'sessions' table: %w", err)
}
return nil
}, func(ctx context.Context, db *bun.DB) error {
if _, err := db.NewDropTable().
Model(&model.Session{}).
IfExists().
Cascade().
Exec(ctx); err != nil {
return fmt.Errorf("couldn't drop the 'sessions' table: %w", err)
}
return nil
})
}

58
internal/bundb/session.go Normal file
View File

@ -0,0 +1,58 @@
package bundb
import (
"context"
"errors"
"fmt"
"gitea.dwysokinski.me/twhelp/sessions/internal/bundb/internal/model"
"gitea.dwysokinski.me/twhelp/sessions/internal/domain"
"github.com/jackc/pgerrcode"
"github.com/uptrace/bun"
"github.com/uptrace/bun/driver/pgdriver"
)
type Session struct {
db *bun.DB
}
func NewSession(db *bun.DB) *Session {
return &Session{db: db}
}
func (u *Session) CreateOrUpdate(ctx context.Context, params domain.CreateSessionParams) (domain.Session, error) {
sess := model.NewSession(params)
if _, err := u.db.NewInsert().
Model(&sess).
On("CONFLICT ON CONSTRAINT sessions_user_id_server_key_key DO UPDATE").
Set("sid = EXCLUDED.sid").
Set("updated_at = EXCLUDED.updated_at").
Returning("*").
Exec(ctx); err != nil {
return domain.Session{}, fmt.Errorf(
"something went wrong while inserting session into the db: %w",
mapCreateSessionError(err, params),
)
}
return sess.ToDomain(), nil
}
func mapCreateSessionError(err error, params domain.CreateSessionParams) error {
var pgError pgdriver.Error
if !errors.As(err, &pgError) {
return err
}
code := pgError.Field('C')
constraint := pgError.Field('n')
switch {
case code == pgerrcode.ForeignKeyViolation && constraint == "sessions_user_id_fkey":
return domain.UserDoesNotExistError{
ID: params.UserID(),
}
default:
return err
}
}

View File

@ -0,0 +1,75 @@
package bundb_test
import (
"context"
"encoding/base64"
"testing"
"time"
"gitea.dwysokinski.me/twhelp/sessions/internal/bundb"
"gitea.dwysokinski.me/twhelp/sessions/internal/domain"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSession_CreateOrUpdate(t *testing.T) {
t.Parallel()
db := newDB(t)
fixture := loadFixtures(t, db)
repo := bundb.NewSession(db)
t.Run("OK", func(t *testing.T) {
t.Parallel()
createParams, err := domain.NewCreateSessionParams(
"pl151",
base64.StdEncoding.EncodeToString([]byte(uuid.NewString())),
fixture.user(t, "user-1").ID,
)
require.NoError(t, err)
createdSess, err := repo.CreateOrUpdate(context.Background(), createParams)
assert.NoError(t, err)
assert.Greater(t, createdSess.ID, int64(0))
assert.Equal(t, createParams.UserID(), createdSess.UserID)
assert.Equal(t, createParams.SID(), createdSess.SID)
assert.Equal(t, createParams.ServerKey(), createdSess.ServerKey)
assert.WithinDuration(t, time.Now(), createdSess.CreatedAt, time.Second)
assert.WithinDuration(t, time.Now(), createdSess.UpdatedAt, time.Second)
updateParams, err := domain.NewCreateSessionParams(
createParams.ServerKey(),
base64.StdEncoding.EncodeToString([]byte(uuid.NewString())),
createParams.UserID(),
)
require.NoError(t, err)
updatedSess, err := repo.CreateOrUpdate(context.Background(), updateParams)
assert.NoError(t, err)
assert.Equal(t, createdSess.ID, updatedSess.ID)
assert.Equal(t, createdSess.UserID, updatedSess.UserID)
assert.Equal(t, updateParams.SID(), updatedSess.SID)
assert.Equal(t, createdSess.ServerKey, updatedSess.ServerKey)
assert.Equal(t, createdSess.CreatedAt, updatedSess.CreatedAt)
assert.True(t, updatedSess.UpdatedAt.After(createdSess.UpdatedAt))
})
t.Run("ERR: user doesn't exist", func(t *testing.T) {
t.Parallel()
params, err := domain.NewCreateSessionParams(
"pl151",
base64.StdEncoding.EncodeToString([]byte(uuid.NewString())),
fixture.user(t, "user-1").ID+1111,
)
require.NoError(t, err)
sess, err := repo.CreateOrUpdate(context.Background(), params)
assert.ErrorIs(t, err, domain.UserDoesNotExistError{
ID: params.UserID(),
})
assert.Zero(t, sess)
})
}

View File

@ -23,7 +23,7 @@ func TestUser_Create(t *testing.T) {
t.Parallel()
params, err := domain.NewCreateUserParams("nameName")
require.Nil(t, err)
require.NoError(t, err)
user, err := repo.Create(context.Background(), params)
assert.NoError(t, err)

View File

@ -5,10 +5,6 @@ import (
"time"
)
const (
apiKeyUserIDMin = 1
)
type APIKey struct {
ID int64
Key string
@ -29,11 +25,11 @@ func NewCreateAPIKeyParams(key string, userID int64) (CreateAPIKeyParams, error)
}
}
if userID < apiKeyUserIDMin {
if userID < userIDMin {
return CreateAPIKeyParams{}, ValidationError{
Field: "UserID",
Err: MinError{
Min: apiKeyUserIDMin,
Min: userIDMin,
},
}
}

View File

@ -7,6 +7,7 @@ import (
var (
ErrRequired = errors.New("cannot be blank")
ErrBase64 = errors.New("must be encoded in Base64")
)
type ErrorCode uint8

View File

@ -0,0 +1,85 @@
package domain
import (
"encoding/base64"
"time"
)
const (
serverMaxLen = 10
)
type Session struct {
ID int64
UserID int64
ServerKey string
SID string
CreatedAt time.Time
UpdatedAt time.Time
}
type CreateSessionParams struct {
userID int64
serverKey string
sid string
}
func NewCreateSessionParams(serverKey, sid string, userID int64) (CreateSessionParams, error) {
if serverKey == "" {
return CreateSessionParams{}, ValidationError{
Field: "ServerKey",
Err: ErrRequired,
}
}
if len(serverKey) > serverMaxLen {
return CreateSessionParams{}, ValidationError{
Field: "ServerKey",
Err: MaxLengthError{
Max: serverMaxLen,
},
}
}
if sid == "" {
return CreateSessionParams{}, ValidationError{
Field: "SID",
Err: ErrRequired,
}
}
if !isBase64(sid) {
return CreateSessionParams{}, ValidationError{
Field: "SID",
Err: ErrBase64,
}
}
if userID < userIDMin {
return CreateSessionParams{}, ValidationError{
Field: "UserID",
Err: MinError{
Min: userIDMin,
},
}
}
return CreateSessionParams{userID: userID, serverKey: serverKey, sid: sid}, nil
}
func (c CreateSessionParams) UserID() int64 {
return c.userID
}
func (c CreateSessionParams) ServerKey() string {
return c.serverKey
}
func (c CreateSessionParams) SID() string {
return c.sid
}
func isBase64(s string) bool {
_, err := base64.StdEncoding.DecodeString(s)
return err == nil
}

View File

@ -0,0 +1,103 @@
package domain_test
import (
"encoding/base64"
"testing"
"gitea.dwysokinski.me/twhelp/sessions/internal/domain"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func TestNewCreateSessionParams(t *testing.T) {
t.Parallel()
tests := []struct {
name string
serverKey string
sid string
userID int64
expectedErr error
}{
{
name: "OK",
serverKey: "pl151",
sid: base64.StdEncoding.EncodeToString([]byte(uuid.NewString())),
userID: 1,
expectedErr: nil,
},
{
name: "ERR: serverKey is required",
serverKey: "",
sid: base64.StdEncoding.EncodeToString([]byte(uuid.NewString())),
userID: 1,
expectedErr: domain.ValidationError{
Field: "ServerKey",
Err: domain.ErrRequired,
},
},
{
name: "ERR: len(serverKey) > 10",
serverKey: randString(11),
sid: base64.StdEncoding.EncodeToString([]byte(uuid.NewString())),
userID: 1,
expectedErr: domain.ValidationError{
Field: "ServerKey",
Err: domain.MaxLengthError{
Max: 10,
},
},
},
{
name: "ERR: SID is required",
serverKey: "pl151",
sid: "",
userID: 1,
expectedErr: domain.ValidationError{
Field: "SID",
Err: domain.ErrRequired,
},
},
{
name: "ERR: SID is not a valid base64",
serverKey: "pl151",
sid: uuid.NewString(),
userID: 1,
expectedErr: domain.ValidationError{
Field: "SID",
Err: domain.ErrBase64,
},
},
{
name: "ERR: userID < 1",
serverKey: "pl151",
sid: base64.StdEncoding.EncodeToString([]byte(uuid.NewString())),
userID: 0,
expectedErr: domain.ValidationError{
Field: "UserID",
Err: domain.MinError{
Min: 1,
},
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
params, err := domain.NewCreateSessionParams(tt.serverKey, tt.sid, tt.userID)
if tt.expectedErr != nil {
assert.Equal(t, tt.expectedErr, err)
assert.Zero(t, params)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.serverKey, params.ServerKey())
assert.Equal(t, tt.sid, params.SID())
assert.Equal(t, tt.userID, params.UserID())
})
}
}

View File

@ -8,8 +8,9 @@ import (
)
const (
UsernameMinLength = 2
UsernameMaxLength = 40
usernameMinLength = 2
usernameMaxLength = 40
userIDMin = 1
)
var (
@ -36,20 +37,20 @@ func NewCreateUserParams(name string) (CreateUserParams, error) {
}
}
if len(name) < UsernameMinLength {
if len(name) < usernameMinLength {
return CreateUserParams{}, ValidationError{
Field: "Name",
Err: MinLengthError{
Min: UsernameMinLength,
Min: usernameMinLength,
},
}
}
if len(name) > UsernameMaxLength {
if len(name) > usernameMaxLength {
return CreateUserParams{}, ValidationError{
Field: "Name",
Err: MaxLengthError{
Max: UsernameMaxLength,
Max: usernameMaxLength,
},
}
}

View File

@ -0,0 +1,40 @@
package model_test
import (
"testing"
"time"
"gitea.dwysokinski.me/twhelp/sessions/internal/domain"
"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/model"
"github.com/stretchr/testify/assert"
)
func TestNewUser(t *testing.T) {
t.Parallel()
user := domain.User{
ID: 111,
Name: "Name 111",
CreatedAt: time.Now(),
}
assertUser(t, user, model.NewUser(user))
}
func TestNewGetUserResp(t *testing.T) {
t.Parallel()
user := domain.User{
ID: 111,
Name: "Name 111",
CreatedAt: time.Now(),
}
assertUser(t, user, model.NewGetUserResp(user).Data)
}
func assertUser(tb testing.TB, du domain.User, ru model.User) {
tb.Helper()
assert.Equal(tb, du.ID, ru.ID)
assert.Equal(tb, du.Name, ru.Name)
assert.Equal(tb, du.CreatedAt, ru.CreatedAt)
}

View File

@ -10,6 +10,11 @@ import (
type authCtxKey struct{}
//counterfeiter:generate -o internal/mock/api_key_verifier.gen.go . APIKeyVerifier
type APIKeyVerifier interface {
Verify(ctx context.Context, key string) (domain.User, error)
}
func authMiddleware(verifier APIKeyVerifier) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View File

@ -1,8 +1,8 @@
package rest
import (
"context"
"encoding/json"
"errors"
"net/http"
"gitea.dwysokinski.me/twhelp/sessions/internal/domain"
@ -29,13 +29,9 @@ type SwaggerConfig struct {
Schemes []string
}
//counterfeiter:generate -o internal/mock/api_key_verifier.gen.go . APIKeyVerifier
type APIKeyVerifier interface {
Verify(ctx context.Context, key string) (domain.User, error)
}
type RouterConfig struct {
APIKeyVerifier APIKeyVerifier
SessionService SessionService
CORS CORSConfig
Swagger SwaggerConfig
}
@ -43,6 +39,9 @@ type RouterConfig struct {
func NewRouter(cfg RouterConfig) *chi.Mux {
// handlers
uh := userHandler{}
sh := sessionHandler{
svc: cfg.SessionService,
}
router := chi.NewRouter()
@ -73,6 +72,7 @@ func NewRouter(cfg RouterConfig) *chi.Mux {
authMw := authMiddleware(cfg.APIKeyVerifier)
r.With(authMw).Get("/user", uh.getAuthenticated)
r.With(authMw).Put("/user/sessions/{serverKey}", sh.createOrUpdate)
})
return router
@ -96,53 +96,53 @@ func methodNotAllowedHandler(w http.ResponseWriter, _ *http.Request) {
})
}
// type internalServerError struct {
// err error
// }
//
// func (e internalServerError) Error() string {
// return e.err.Error()
// }
//
// func (e internalServerError) UserError() string {
// return "internal server error"
// }
//
// func (e internalServerError) Code() domain.ErrorCode {
// return domain.ErrorCodeUnknown
// }
//
// func (e internalServerError) Unwrap() error {
// return e.err
// }
//
// func renderErr(w http.ResponseWriter, err error) {
// var userError domain.Error
// if !errors.As(err, &userError) {
// userError = internalServerError{err: err}
// }
// renderJSON(w, errorCodeToHTTPStatus(userError.Code()), model.ErrorResp{
// Error: model.APIError{
// Code: userError.Code().String(),
// Message: userError.UserError(),
// },
// })
// }
//
// func errorCodeToHTTPStatus(code domain.ErrorCode) int {
// switch code {
// case domain.ErrorCodeValidationError:
// return http.StatusBadRequest
// case domain.ErrorCodeEntityNotFound:
// return http.StatusNotFound
// case domain.ErrorCodeAlreadyExists:
// return http.StatusConflict
// case domain.ErrorCodeUnknown:
// fallthrough
// default:
// return http.StatusInternalServerError
// }
// }
type internalServerError struct {
err error
}
func (e internalServerError) Error() string {
return e.err.Error()
}
func (e internalServerError) UserError() string {
return "internal server error"
}
func (e internalServerError) Code() domain.ErrorCode {
return domain.ErrorCodeUnknown
}
func (e internalServerError) Unwrap() error {
return e.err
}
func renderErr(w http.ResponseWriter, err error) {
var domainErr domain.Error
if !errors.As(err, &domainErr) {
domainErr = internalServerError{err: err}
}
renderJSON(w, errorCodeToHTTPStatus(domainErr.Code()), model.ErrorResp{
Error: model.APIError{
Code: domainErr.Code().String(),
Message: domainErr.UserError(),
},
})
}
func errorCodeToHTTPStatus(code domain.ErrorCode) int {
switch code {
case domain.ErrorCodeValidationError:
return http.StatusBadRequest
case domain.ErrorCodeEntityNotFound:
return http.StatusNotFound
case domain.ErrorCodeAlreadyExists:
return http.StatusConflict
case domain.ErrorCodeUnknown:
fallthrough
default:
return http.StatusInternalServerError
}
}
func renderJSON(w http.ResponseWriter, status int, data any) {
w.Header().Set("Content-Type", "application/json")

View File

@ -20,35 +20,31 @@ import (
func TestRouteNotFound(t *testing.T) {
t.Parallel()
expectedResp := &model.ErrorResp{
resp := doRequest(rest.NewRouter(rest.RouterConfig{}), http.MethodGet, "/v1/"+uuid.NewString(), "", nil)
defer resp.Body.Close()
assertJSONResponse(t, resp, http.StatusNotFound, &model.ErrorResp{
Error: model.APIError{
Code: "route-not-found",
Message: "route not found",
},
}
resp := doRequest(rest.NewRouter(rest.RouterConfig{}), http.MethodGet, "/v1/"+uuid.NewString(), "", nil)
defer resp.Body.Close()
assertJSONResponse(t, resp, http.StatusNotFound, expectedResp, &model.ErrorResp{})
}, &model.ErrorResp{})
}
func TestMethodNotAllowed(t *testing.T) {
t.Parallel()
expectedResp := &model.ErrorResp{
Error: model.APIError{
Code: "method-not-allowed",
Message: "method not allowed",
},
}
resp := doRequest(rest.NewRouter(rest.RouterConfig{
Swagger: rest.SwaggerConfig{
Enabled: true,
},
}), http.MethodPost, "/v1/swagger/index.html", "", nil)
defer resp.Body.Close()
assertJSONResponse(t, resp, http.StatusMethodNotAllowed, expectedResp, &model.ErrorResp{})
assertJSONResponse(t, resp, http.StatusMethodNotAllowed, &model.ErrorResp{
Error: model.APIError{
Code: "method-not-allowed",
Message: "method not allowed",
},
}, &model.ErrorResp{})
}
func TestSwagger(t *testing.T) {

View File

@ -0,0 +1,62 @@
package rest
import (
"context"
"io"
"net/http"
"gitea.dwysokinski.me/twhelp/sessions/internal/domain"
"github.com/go-chi/chi/v5"
)
const (
sidMaxBytes = 512
)
//counterfeiter:generate -o internal/mock/session_service.gen.go . SessionService
type SessionService interface {
CreateOrUpdate(ctx context.Context, params domain.CreateSessionParams) (domain.Session, error)
}
type sessionHandler struct {
svc SessionService
}
// @ID createOrUpdateSession
// @Summary Create or update a session
// @Description Create or update a session
// @Tags users,sessions
// @Accept plain
// @Success 204
// @Success 400 {object} model.ErrorResp
// @Success 401 {object} model.ErrorResp
// @Failure 500 {object} model.ErrorResp
// @Security ApiKeyAuth
// @Param serverKey path string true "Server key"
// @Param request body string true "SID"
// @Router /user/sessions/{serverKey} [put]
func (h *sessionHandler) createOrUpdate(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
chiCtx := chi.RouteContext(ctx)
user, _ := userFromContext(ctx)
b, err := io.ReadAll(io.LimitReader(r.Body, sidMaxBytes))
if err != nil {
renderErr(w, err)
return
}
params, err := domain.NewCreateSessionParams(chiCtx.URLParam("serverKey"), string(b), user.ID)
if err != nil {
renderErr(w, err)
return
}
_, err = h.svc.CreateOrUpdate(ctx, params)
if err != nil {
renderErr(w, err)
return
}
w.WriteHeader(http.StatusNoContent)
}

View File

@ -0,0 +1,218 @@
package rest_test
import (
"context"
"encoding/base64"
"net/http"
"strings"
"testing"
"time"
"gitea.dwysokinski.me/twhelp/sessions/internal/domain"
"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest"
"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/mock"
"gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/model"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func TestSession_createOrUpdate(t *testing.T) {
t.Parallel()
now := time.Now()
apiKey := uuid.NewString()
tests := []struct {
name string
setup func(*mock.FakeAPIKeyVerifier, *mock.FakeSessionService)
apiKey string
serverKey string
sid string
expectedStatus int
target any
expectedResponse any
}{
{
name: "OK",
setup: func(apiKeySvc *mock.FakeAPIKeyVerifier, sessionSvc *mock.FakeSessionService) {
apiKeySvc.VerifyCalls(func(ctx context.Context, key string) (domain.User, error) {
if key != apiKey {
return domain.User{}, domain.APIKeyNotFoundError{
Key: key,
}
}
return domain.User{
ID: 111,
Name: "name",
CreatedAt: now,
}, nil
})
sessionSvc.CreateOrUpdateReturns(domain.Session{}, nil)
},
apiKey: apiKey,
serverKey: "pl151",
sid: base64.StdEncoding.EncodeToString([]byte(uuid.NewString())),
expectedStatus: http.StatusNoContent,
},
{
name: "ERR: len(serverKey) > 10",
setup: func(apiKeySvc *mock.FakeAPIKeyVerifier, sessionSvc *mock.FakeSessionService) {
apiKeySvc.VerifyCalls(func(ctx context.Context, key string) (domain.User, error) {
if key != apiKey {
return domain.User{}, domain.APIKeyNotFoundError{
Key: key,
}
}
return domain.User{
ID: 111,
Name: "name",
CreatedAt: now,
}, nil
})
},
apiKey: apiKey,
serverKey: "012345678890",
sid: base64.StdEncoding.EncodeToString([]byte(uuid.NewString())),
expectedStatus: http.StatusBadRequest,
target: &model.ErrorResp{},
expectedResponse: &model.ErrorResp{
Error: model.APIError{
Code: domain.ErrorCodeValidationError.String(),
Message: "ServerKey: the length must be no more than 10",
},
},
},
{
name: "ERR: SID is required",
setup: func(apiKeySvc *mock.FakeAPIKeyVerifier, sessionSvc *mock.FakeSessionService) {
apiKeySvc.VerifyCalls(func(ctx context.Context, key string) (domain.User, error) {
if key != apiKey {
return domain.User{}, domain.APIKeyNotFoundError{
Key: key,
}
}
return domain.User{
ID: 111,
Name: "name",
CreatedAt: now,
}, nil
})
},
apiKey: apiKey,
serverKey: "pl151",
sid: "",
expectedStatus: http.StatusBadRequest,
target: &model.ErrorResp{},
expectedResponse: &model.ErrorResp{
Error: model.APIError{
Code: domain.ErrorCodeValidationError.String(),
Message: "SID: cannot be blank",
},
},
},
{
name: "ERR: SID is not a valid base64",
setup: func(apiKeySvc *mock.FakeAPIKeyVerifier, sessionSvc *mock.FakeSessionService) {
apiKeySvc.VerifyCalls(func(ctx context.Context, key string) (domain.User, error) {
if key != apiKey {
return domain.User{}, domain.APIKeyNotFoundError{
Key: key,
}
}
return domain.User{
ID: 111,
Name: "name",
CreatedAt: now,
}, nil
})
},
apiKey: apiKey,
serverKey: "pl151",
sid: uuid.NewString(),
expectedStatus: http.StatusBadRequest,
target: &model.ErrorResp{},
expectedResponse: &model.ErrorResp{
Error: model.APIError{
Code: domain.ErrorCodeValidationError.String(),
Message: "SID: must be encoded in Base64",
},
},
},
{
name: "ERR: apiKey == \"\"",
setup: func(apiKeySvc *mock.FakeAPIKeyVerifier, sessionSvc *mock.FakeSessionService) {},
apiKey: "",
expectedStatus: http.StatusUnauthorized,
serverKey: "pl151",
target: &model.ErrorResp{},
expectedResponse: &model.ErrorResp{
Error: model.APIError{
Code: "unauthorized",
Message: "invalid API key",
},
},
},
{
name: "ERR: unexpected API key",
setup: func(apiKeySvc *mock.FakeAPIKeyVerifier, sessionSvc *mock.FakeSessionService) {
apiKeySvc.VerifyCalls(func(ctx context.Context, key string) (domain.User, error) {
if key != apiKey {
return domain.User{}, domain.APIKeyNotFoundError{
Key: key,
}
}
return domain.User{
ID: 111,
Name: "name",
CreatedAt: now,
}, nil
})
},
apiKey: uuid.NewString(),
serverKey: "pl151",
expectedStatus: http.StatusUnauthorized,
target: &model.ErrorResp{},
expectedResponse: &model.ErrorResp{
Error: model.APIError{
Code: "unauthorized",
Message: "invalid API key",
},
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
apiKeySvc := &mock.FakeAPIKeyVerifier{}
sessionSvc := &mock.FakeSessionService{}
tt.setup(apiKeySvc, sessionSvc)
router := rest.NewRouter(rest.RouterConfig{
APIKeyVerifier: apiKeySvc,
SessionService: sessionSvc,
})
resp := doRequest(
router,
http.MethodPut,
"/v1/user/sessions/"+tt.serverKey,
tt.apiKey,
strings.NewReader(tt.sid),
)
defer resp.Body.Close()
if tt.expectedStatus == http.StatusNoContent {
assert.Equal(t, tt.expectedStatus, resp.StatusCode)
return
}
assertJSONResponse(t, resp, tt.expectedStatus, tt.expectedResponse, tt.target)
})
}
}

View File

@ -114,5 +114,4 @@ func TestUser_getAuthenticated(t *testing.T) {
assertJSONResponse(t, resp, tt.expectedStatus, tt.expectedResponse, tt.target)
})
}
}

View File

@ -0,0 +1,39 @@
package service
import (
"context"
"errors"
"fmt"
"gitea.dwysokinski.me/twhelp/sessions/internal/domain"
)
//counterfeiter:generate -o internal/mock/session_repository.gen.go . SessionRepository
type SessionRepository interface {
CreateOrUpdate(ctx context.Context, params domain.CreateSessionParams) (domain.Session, error)
}
type Session struct {
repo SessionRepository
userSvc UserGetter
}
func NewSession(repo SessionRepository, userSvc UserGetter) *Session {
return &Session{repo: repo, userSvc: userSvc}
}
func (s *Session) CreateOrUpdate(ctx context.Context, params domain.CreateSessionParams) (domain.Session, error) {
if _, err := s.userSvc.Get(ctx, params.UserID()); err != nil {
if errors.Is(err, domain.UserNotFoundError{ID: params.UserID()}) {
return domain.Session{}, domain.UserDoesNotExistError{ID: params.UserID()}
}
return domain.Session{}, fmt.Errorf("UserService.Get: %w", err)
}
sess, err := s.repo.CreateOrUpdate(ctx, params)
if err != nil {
return domain.Session{}, fmt.Errorf("SessionRepository.CreateOrUpdate: %w", err)
}
return sess, nil
}

View File

@ -0,0 +1,74 @@
package service_test
import (
"context"
"encoding/base64"
"testing"
"time"
"gitea.dwysokinski.me/twhelp/sessions/internal/domain"
"gitea.dwysokinski.me/twhelp/sessions/internal/service"
"gitea.dwysokinski.me/twhelp/sessions/internal/service/internal/mock"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSession_CreateOrUpdate(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
params, err := domain.NewCreateSessionParams(
"pl151",
base64.StdEncoding.EncodeToString([]byte(uuid.NewString())),
111,
)
require.NoError(t, err)
repo := &mock.FakeSessionRepository{}
repo.CreateOrUpdateCalls(func(ctx context.Context, params domain.CreateSessionParams) (domain.Session, error) {
return domain.Session{
ID: 1111,
UserID: params.UserID(),
ServerKey: params.ServerKey(),
SID: params.SID(),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}, nil
})
userSvc := &mock.FakeUserGetter{}
userSvc.GetReturns(domain.User{
ID: params.UserID(),
}, nil)
sess, err := service.NewSession(repo, userSvc).CreateOrUpdate(context.Background(), params)
assert.NoError(t, err)
assert.Greater(t, sess.ID, int64(0))
assert.Equal(t, params.UserID(), sess.UserID)
assert.Equal(t, params.SID(), sess.SID)
assert.Equal(t, params.ServerKey(), sess.ServerKey)
assert.WithinDuration(t, time.Now(), sess.CreatedAt, 10*time.Millisecond)
assert.WithinDuration(t, time.Now(), sess.UpdatedAt, 10*time.Millisecond)
})
t.Run("ERR: user doesnt exist", func(t *testing.T) {
t.Parallel()
params, err := domain.NewCreateSessionParams(
"pl151",
base64.StdEncoding.EncodeToString([]byte(uuid.NewString())),
111,
)
require.NoError(t, err)
userSvc := &mock.FakeUserGetter{}
userSvc.GetReturns(domain.User{}, domain.UserNotFoundError{ID: params.UserID()})
sess, err := service.NewSession(nil, userSvc).CreateOrUpdate(context.Background(), params)
assert.ErrorIs(t, err, domain.UserDoesNotExistError{ID: params.UserID()})
assert.Zero(t, sess)
})
}