diff --git a/cmd/sessions/internal/serve/serve.go b/cmd/sessions/internal/serve/serve.go index 3a49257..73ba442 100644 --- a/cmd/sessions/internal/serve/serve.go +++ b/cmd/sessions/internal/serve/serve.go @@ -91,10 +91,12 @@ func newServer(cfg serverConfig) (*http.Server, error) { // repos userRepo := bundb.NewUser(cfg.db) apiKeyRepo := bundb.NewAPIKey(cfg.db) + sessionRepo := bundb.NewSession(cfg.db) // services userSvc := service.NewUser(userRepo) apiKeySvc := service.NewAPIKey(apiKeyRepo, userSvc) + sessionSvc := service.NewSession(sessionRepo, userSvc) // router r := chi.NewRouter() @@ -102,6 +104,7 @@ func newServer(cfg serverConfig) (*http.Server, error) { r.Mount(metaEndpointsPrefix, meta.NewRouter([]meta.Checker{bundb.NewChecker(cfg.db)})) r.Mount("/api", rest.NewRouter(rest.RouterConfig{ APIKeyVerifier: apiKeySvc, + SessionService: sessionSvc, CORS: rest.CORSConfig{ Enabled: apiCfg.CORSEnabled, AllowedOrigins: apiCfg.CORSAllowedOrigins, diff --git a/internal/bundb/api_key_test.go b/internal/bundb/api_key_test.go index e7ae798..80d596b 100644 --- a/internal/bundb/api_key_test.go +++ b/internal/bundb/api_key_test.go @@ -25,7 +25,7 @@ func TestAPIKey_Create(t *testing.T) { t.Parallel() params, err := domain.NewCreateAPIKeyParams(uuid.NewString(), fixture.user(t, "user-1").ID) - require.Nil(t, err) + require.NoError(t, err) apiKey, err := repo.Create(context.Background(), params) assert.NoError(t, err) @@ -35,11 +35,11 @@ func TestAPIKey_Create(t *testing.T) { assert.WithinDuration(t, time.Now(), apiKey.CreatedAt, time.Second) }) - t.Run("ERR: player doesn't exist", func(t *testing.T) { + t.Run("ERR: user doesn't exist", func(t *testing.T) { t.Parallel() params, err := domain.NewCreateAPIKeyParams(uuid.NewString(), fixture.user(t, "user-1").ID+1) - require.Nil(t, err) + require.NoError(t, err) apiKey, err := repo.Create(context.Background(), params) assert.ErrorIs(t, err, domain.UserDoesNotExistError{ @@ -55,7 +55,7 @@ func TestAPIKey_Create(t *testing.T) { fixture.apiKey(t, "user-1-api-key-1").Key, fixture.user(t, "user-1").ID, ) - require.Nil(t, err) + require.NoError(t, err) apiKey, err := repo.Create(context.Background(), params) var pgError pgdriver.Error diff --git a/internal/bundb/internal/model/api_key.go b/internal/bundb/internal/model/api_key.go index 3d54246..761d2ed 100644 --- a/internal/bundb/internal/model/api_key.go +++ b/internal/bundb/internal/model/api_key.go @@ -24,7 +24,7 @@ func NewAPIKey(p domain.CreateAPIKeyParams) APIKey { } } -func (a *APIKey) ToDomain() domain.APIKey { +func (a APIKey) ToDomain() domain.APIKey { return domain.APIKey{ ID: a.ID, Key: a.Key, diff --git a/internal/bundb/internal/model/session.go b/internal/bundb/internal/model/session.go new file mode 100644 index 0000000..b1de00f --- /dev/null +++ b/internal/bundb/internal/model/session.go @@ -0,0 +1,40 @@ +package model + +import ( + "time" + + "gitea.dwysokinski.me/twhelp/sessions/internal/domain" + "github.com/uptrace/bun" +) + +type Session struct { + bun.BaseModel `bun:"base_model,table:sessions,alias:session"` + + ID int64 `bun:"id,pk,autoincrement,identity"` + UserID int64 `bun:"user_id,nullzero,notnull,unique:sessions_user_id_server_key_key"` + ServerKey string `bun:"server_key,type:varchar(10),nullzero,notnull,unique:sessions_user_id_server_key_key"` + SID string `bun:"sid,type:varchar(512),nullzero,notnull"` + CreatedAt time.Time `bun:"created_at,nullzero,notnull,default:current_timestamp"` + UpdatedAt time.Time `bun:"updated_at,nullzero,notnull,default:current_timestamp"` +} + +func NewSession(p domain.CreateSessionParams) Session { + return Session{ + UserID: p.UserID(), + ServerKey: p.ServerKey(), + SID: p.SID(), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } +} + +func (s Session) ToDomain() domain.Session { + return domain.Session{ + ID: s.ID, + UserID: s.UserID, + ServerKey: s.ServerKey, + SID: s.SID, + CreatedAt: s.CreatedAt, + UpdatedAt: s.UpdatedAt, + } +} diff --git a/internal/bundb/internal/model/session_test.go b/internal/bundb/internal/model/session_test.go new file mode 100644 index 0000000..b251380 --- /dev/null +++ b/internal/bundb/internal/model/session_test.go @@ -0,0 +1,35 @@ +package model_test + +import ( + "encoding/base64" + "testing" + "time" + + "gitea.dwysokinski.me/twhelp/sessions/internal/bundb/internal/model" + "gitea.dwysokinski.me/twhelp/sessions/internal/domain" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestSession(t *testing.T) { + t.Parallel() + + var id int64 = 123 + params, err := domain.NewCreateSessionParams( + "pl151", + base64.StdEncoding.EncodeToString([]byte(uuid.NewString())), + 1, + ) + assert.NoError(t, err) + + result := model.NewSession(params) + result.ID = id + + apiKey := result.ToDomain() + assert.Equal(t, id, apiKey.ID) + assert.Equal(t, params.ServerKey(), apiKey.ServerKey) + assert.Equal(t, params.SID(), apiKey.SID) + assert.Equal(t, params.UserID(), apiKey.UserID) + assert.WithinDuration(t, time.Now(), apiKey.CreatedAt, 10*time.Millisecond) + assert.WithinDuration(t, time.Now(), apiKey.UpdatedAt, 10*time.Millisecond) +} diff --git a/internal/bundb/migrations/20221122050217_create_sessions_table.go b/internal/bundb/migrations/20221122050217_create_sessions_table.go new file mode 100644 index 0000000..91eaee4 --- /dev/null +++ b/internal/bundb/migrations/20221122050217_create_sessions_table.go @@ -0,0 +1,31 @@ +package migrations + +import ( + "context" + "fmt" + + "gitea.dwysokinski.me/twhelp/sessions/internal/bundb/internal/model" + "github.com/uptrace/bun" +) + +func init() { + Migrations.MustRegister(func(ctx context.Context, db *bun.DB) error { + if _, err := db.NewCreateTable(). + Model(&model.Session{}). + Varchar(defaultVarcharLength). + ForeignKey(`(user_id) REFERENCES users (id) ON DELETE CASCADE`). + Exec(ctx); err != nil { + return fmt.Errorf("couldn't create the 'sessions' table: %w", err) + } + return nil + }, func(ctx context.Context, db *bun.DB) error { + if _, err := db.NewDropTable(). + Model(&model.Session{}). + IfExists(). + Cascade(). + Exec(ctx); err != nil { + return fmt.Errorf("couldn't drop the 'sessions' table: %w", err) + } + return nil + }) +} diff --git a/internal/bundb/session.go b/internal/bundb/session.go new file mode 100644 index 0000000..de1a242 --- /dev/null +++ b/internal/bundb/session.go @@ -0,0 +1,58 @@ +package bundb + +import ( + "context" + "errors" + "fmt" + + "gitea.dwysokinski.me/twhelp/sessions/internal/bundb/internal/model" + "gitea.dwysokinski.me/twhelp/sessions/internal/domain" + "github.com/jackc/pgerrcode" + "github.com/uptrace/bun" + "github.com/uptrace/bun/driver/pgdriver" +) + +type Session struct { + db *bun.DB +} + +func NewSession(db *bun.DB) *Session { + return &Session{db: db} +} + +func (u *Session) CreateOrUpdate(ctx context.Context, params domain.CreateSessionParams) (domain.Session, error) { + sess := model.NewSession(params) + + if _, err := u.db.NewInsert(). + Model(&sess). + On("CONFLICT ON CONSTRAINT sessions_user_id_server_key_key DO UPDATE"). + Set("sid = EXCLUDED.sid"). + Set("updated_at = EXCLUDED.updated_at"). + Returning("*"). + Exec(ctx); err != nil { + return domain.Session{}, fmt.Errorf( + "something went wrong while inserting session into the db: %w", + mapCreateSessionError(err, params), + ) + } + + return sess.ToDomain(), nil +} + +func mapCreateSessionError(err error, params domain.CreateSessionParams) error { + var pgError pgdriver.Error + if !errors.As(err, &pgError) { + return err + } + + code := pgError.Field('C') + constraint := pgError.Field('n') + switch { + case code == pgerrcode.ForeignKeyViolation && constraint == "sessions_user_id_fkey": + return domain.UserDoesNotExistError{ + ID: params.UserID(), + } + default: + return err + } +} diff --git a/internal/bundb/session_test.go b/internal/bundb/session_test.go new file mode 100644 index 0000000..6012819 --- /dev/null +++ b/internal/bundb/session_test.go @@ -0,0 +1,75 @@ +package bundb_test + +import ( + "context" + "encoding/base64" + "testing" + "time" + + "gitea.dwysokinski.me/twhelp/sessions/internal/bundb" + "gitea.dwysokinski.me/twhelp/sessions/internal/domain" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSession_CreateOrUpdate(t *testing.T) { + t.Parallel() + + db := newDB(t) + fixture := loadFixtures(t, db) + repo := bundb.NewSession(db) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + createParams, err := domain.NewCreateSessionParams( + "pl151", + base64.StdEncoding.EncodeToString([]byte(uuid.NewString())), + fixture.user(t, "user-1").ID, + ) + require.NoError(t, err) + + createdSess, err := repo.CreateOrUpdate(context.Background(), createParams) + assert.NoError(t, err) + assert.Greater(t, createdSess.ID, int64(0)) + assert.Equal(t, createParams.UserID(), createdSess.UserID) + assert.Equal(t, createParams.SID(), createdSess.SID) + assert.Equal(t, createParams.ServerKey(), createdSess.ServerKey) + assert.WithinDuration(t, time.Now(), createdSess.CreatedAt, time.Second) + assert.WithinDuration(t, time.Now(), createdSess.UpdatedAt, time.Second) + + updateParams, err := domain.NewCreateSessionParams( + createParams.ServerKey(), + base64.StdEncoding.EncodeToString([]byte(uuid.NewString())), + createParams.UserID(), + ) + require.NoError(t, err) + + updatedSess, err := repo.CreateOrUpdate(context.Background(), updateParams) + assert.NoError(t, err) + assert.Equal(t, createdSess.ID, updatedSess.ID) + assert.Equal(t, createdSess.UserID, updatedSess.UserID) + assert.Equal(t, updateParams.SID(), updatedSess.SID) + assert.Equal(t, createdSess.ServerKey, updatedSess.ServerKey) + assert.Equal(t, createdSess.CreatedAt, updatedSess.CreatedAt) + assert.True(t, updatedSess.UpdatedAt.After(createdSess.UpdatedAt)) + }) + + t.Run("ERR: user doesn't exist", func(t *testing.T) { + t.Parallel() + + params, err := domain.NewCreateSessionParams( + "pl151", + base64.StdEncoding.EncodeToString([]byte(uuid.NewString())), + fixture.user(t, "user-1").ID+1111, + ) + require.NoError(t, err) + + sess, err := repo.CreateOrUpdate(context.Background(), params) + assert.ErrorIs(t, err, domain.UserDoesNotExistError{ + ID: params.UserID(), + }) + assert.Zero(t, sess) + }) +} diff --git a/internal/bundb/user_test.go b/internal/bundb/user_test.go index cfd5160..74a6ea7 100644 --- a/internal/bundb/user_test.go +++ b/internal/bundb/user_test.go @@ -23,7 +23,7 @@ func TestUser_Create(t *testing.T) { t.Parallel() params, err := domain.NewCreateUserParams("nameName") - require.Nil(t, err) + require.NoError(t, err) user, err := repo.Create(context.Background(), params) assert.NoError(t, err) diff --git a/internal/domain/api_key.go b/internal/domain/api_key.go index b3e73e5..01f100b 100644 --- a/internal/domain/api_key.go +++ b/internal/domain/api_key.go @@ -5,10 +5,6 @@ import ( "time" ) -const ( - apiKeyUserIDMin = 1 -) - type APIKey struct { ID int64 Key string @@ -29,11 +25,11 @@ func NewCreateAPIKeyParams(key string, userID int64) (CreateAPIKeyParams, error) } } - if userID < apiKeyUserIDMin { + if userID < userIDMin { return CreateAPIKeyParams{}, ValidationError{ Field: "UserID", Err: MinError{ - Min: apiKeyUserIDMin, + Min: userIDMin, }, } } diff --git a/internal/domain/error.go b/internal/domain/error.go index cea78f9..9ea60b6 100644 --- a/internal/domain/error.go +++ b/internal/domain/error.go @@ -7,6 +7,7 @@ import ( var ( ErrRequired = errors.New("cannot be blank") + ErrBase64 = errors.New("must be encoded in Base64") ) type ErrorCode uint8 diff --git a/internal/domain/session.go b/internal/domain/session.go new file mode 100644 index 0000000..02640f9 --- /dev/null +++ b/internal/domain/session.go @@ -0,0 +1,85 @@ +package domain + +import ( + "encoding/base64" + "time" +) + +const ( + serverMaxLen = 10 +) + +type Session struct { + ID int64 + UserID int64 + ServerKey string + SID string + CreatedAt time.Time + UpdatedAt time.Time +} + +type CreateSessionParams struct { + userID int64 + serverKey string + sid string +} + +func NewCreateSessionParams(serverKey, sid string, userID int64) (CreateSessionParams, error) { + if serverKey == "" { + return CreateSessionParams{}, ValidationError{ + Field: "ServerKey", + Err: ErrRequired, + } + } + + if len(serverKey) > serverMaxLen { + return CreateSessionParams{}, ValidationError{ + Field: "ServerKey", + Err: MaxLengthError{ + Max: serverMaxLen, + }, + } + } + + if sid == "" { + return CreateSessionParams{}, ValidationError{ + Field: "SID", + Err: ErrRequired, + } + } + + if !isBase64(sid) { + return CreateSessionParams{}, ValidationError{ + Field: "SID", + Err: ErrBase64, + } + } + + if userID < userIDMin { + return CreateSessionParams{}, ValidationError{ + Field: "UserID", + Err: MinError{ + Min: userIDMin, + }, + } + } + + return CreateSessionParams{userID: userID, serverKey: serverKey, sid: sid}, nil +} + +func (c CreateSessionParams) UserID() int64 { + return c.userID +} + +func (c CreateSessionParams) ServerKey() string { + return c.serverKey +} + +func (c CreateSessionParams) SID() string { + return c.sid +} + +func isBase64(s string) bool { + _, err := base64.StdEncoding.DecodeString(s) + return err == nil +} diff --git a/internal/domain/session_test.go b/internal/domain/session_test.go new file mode 100644 index 0000000..ff391d6 --- /dev/null +++ b/internal/domain/session_test.go @@ -0,0 +1,103 @@ +package domain_test + +import ( + "encoding/base64" + "testing" + + "gitea.dwysokinski.me/twhelp/sessions/internal/domain" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestNewCreateSessionParams(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + serverKey string + sid string + userID int64 + expectedErr error + }{ + { + name: "OK", + serverKey: "pl151", + sid: base64.StdEncoding.EncodeToString([]byte(uuid.NewString())), + userID: 1, + expectedErr: nil, + }, + { + name: "ERR: serverKey is required", + serverKey: "", + sid: base64.StdEncoding.EncodeToString([]byte(uuid.NewString())), + userID: 1, + expectedErr: domain.ValidationError{ + Field: "ServerKey", + Err: domain.ErrRequired, + }, + }, + { + name: "ERR: len(serverKey) > 10", + serverKey: randString(11), + sid: base64.StdEncoding.EncodeToString([]byte(uuid.NewString())), + userID: 1, + expectedErr: domain.ValidationError{ + Field: "ServerKey", + Err: domain.MaxLengthError{ + Max: 10, + }, + }, + }, + { + name: "ERR: SID is required", + serverKey: "pl151", + sid: "", + userID: 1, + expectedErr: domain.ValidationError{ + Field: "SID", + Err: domain.ErrRequired, + }, + }, + { + name: "ERR: SID is not a valid base64", + serverKey: "pl151", + sid: uuid.NewString(), + userID: 1, + expectedErr: domain.ValidationError{ + Field: "SID", + Err: domain.ErrBase64, + }, + }, + { + name: "ERR: userID < 1", + serverKey: "pl151", + sid: base64.StdEncoding.EncodeToString([]byte(uuid.NewString())), + userID: 0, + expectedErr: domain.ValidationError{ + Field: "UserID", + Err: domain.MinError{ + Min: 1, + }, + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + params, err := domain.NewCreateSessionParams(tt.serverKey, tt.sid, tt.userID) + if tt.expectedErr != nil { + assert.Equal(t, tt.expectedErr, err) + assert.Zero(t, params) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.serverKey, params.ServerKey()) + assert.Equal(t, tt.sid, params.SID()) + assert.Equal(t, tt.userID, params.UserID()) + }) + } +} diff --git a/internal/domain/user.go b/internal/domain/user.go index 1ee0153..b29ed4a 100644 --- a/internal/domain/user.go +++ b/internal/domain/user.go @@ -8,8 +8,9 @@ import ( ) const ( - UsernameMinLength = 2 - UsernameMaxLength = 40 + usernameMinLength = 2 + usernameMaxLength = 40 + userIDMin = 1 ) var ( @@ -36,20 +37,20 @@ func NewCreateUserParams(name string) (CreateUserParams, error) { } } - if len(name) < UsernameMinLength { + if len(name) < usernameMinLength { return CreateUserParams{}, ValidationError{ Field: "Name", Err: MinLengthError{ - Min: UsernameMinLength, + Min: usernameMinLength, }, } } - if len(name) > UsernameMaxLength { + if len(name) > usernameMaxLength { return CreateUserParams{}, ValidationError{ Field: "Name", Err: MaxLengthError{ - Max: UsernameMaxLength, + Max: usernameMaxLength, }, } } diff --git a/internal/router/rest/internal/model/user_test.go b/internal/router/rest/internal/model/user_test.go new file mode 100644 index 0000000..56a44a8 --- /dev/null +++ b/internal/router/rest/internal/model/user_test.go @@ -0,0 +1,40 @@ +package model_test + +import ( + "testing" + "time" + + "gitea.dwysokinski.me/twhelp/sessions/internal/domain" + "gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/model" + "github.com/stretchr/testify/assert" +) + +func TestNewUser(t *testing.T) { + t.Parallel() + + user := domain.User{ + ID: 111, + Name: "Name 111", + CreatedAt: time.Now(), + } + assertUser(t, user, model.NewUser(user)) +} + +func TestNewGetUserResp(t *testing.T) { + t.Parallel() + + user := domain.User{ + ID: 111, + Name: "Name 111", + CreatedAt: time.Now(), + } + assertUser(t, user, model.NewGetUserResp(user).Data) +} + +func assertUser(tb testing.TB, du domain.User, ru model.User) { + tb.Helper() + + assert.Equal(tb, du.ID, ru.ID) + assert.Equal(tb, du.Name, ru.Name) + assert.Equal(tb, du.CreatedAt, ru.CreatedAt) +} diff --git a/internal/router/rest/mw_auth.go b/internal/router/rest/mw_auth.go index 9b181c7..db6f20b 100644 --- a/internal/router/rest/mw_auth.go +++ b/internal/router/rest/mw_auth.go @@ -10,6 +10,11 @@ import ( type authCtxKey struct{} +//counterfeiter:generate -o internal/mock/api_key_verifier.gen.go . APIKeyVerifier +type APIKeyVerifier interface { + Verify(ctx context.Context, key string) (domain.User, error) +} + func authMiddleware(verifier APIKeyVerifier) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/router/rest/rest.go b/internal/router/rest/rest.go index 78000a8..fca2929 100644 --- a/internal/router/rest/rest.go +++ b/internal/router/rest/rest.go @@ -1,8 +1,8 @@ package rest import ( - "context" "encoding/json" + "errors" "net/http" "gitea.dwysokinski.me/twhelp/sessions/internal/domain" @@ -29,13 +29,9 @@ type SwaggerConfig struct { Schemes []string } -//counterfeiter:generate -o internal/mock/api_key_verifier.gen.go . APIKeyVerifier -type APIKeyVerifier interface { - Verify(ctx context.Context, key string) (domain.User, error) -} - type RouterConfig struct { APIKeyVerifier APIKeyVerifier + SessionService SessionService CORS CORSConfig Swagger SwaggerConfig } @@ -43,6 +39,9 @@ type RouterConfig struct { func NewRouter(cfg RouterConfig) *chi.Mux { // handlers uh := userHandler{} + sh := sessionHandler{ + svc: cfg.SessionService, + } router := chi.NewRouter() @@ -73,6 +72,7 @@ func NewRouter(cfg RouterConfig) *chi.Mux { authMw := authMiddleware(cfg.APIKeyVerifier) r.With(authMw).Get("/user", uh.getAuthenticated) + r.With(authMw).Put("/user/sessions/{serverKey}", sh.createOrUpdate) }) return router @@ -96,53 +96,53 @@ func methodNotAllowedHandler(w http.ResponseWriter, _ *http.Request) { }) } -// type internalServerError struct { -// err error -// } -// -// func (e internalServerError) Error() string { -// return e.err.Error() -// } -// -// func (e internalServerError) UserError() string { -// return "internal server error" -// } -// -// func (e internalServerError) Code() domain.ErrorCode { -// return domain.ErrorCodeUnknown -// } -// -// func (e internalServerError) Unwrap() error { -// return e.err -// } -// -// func renderErr(w http.ResponseWriter, err error) { -// var userError domain.Error -// if !errors.As(err, &userError) { -// userError = internalServerError{err: err} -// } -// renderJSON(w, errorCodeToHTTPStatus(userError.Code()), model.ErrorResp{ -// Error: model.APIError{ -// Code: userError.Code().String(), -// Message: userError.UserError(), -// }, -// }) -// } -// -// func errorCodeToHTTPStatus(code domain.ErrorCode) int { -// switch code { -// case domain.ErrorCodeValidationError: -// return http.StatusBadRequest -// case domain.ErrorCodeEntityNotFound: -// return http.StatusNotFound -// case domain.ErrorCodeAlreadyExists: -// return http.StatusConflict -// case domain.ErrorCodeUnknown: -// fallthrough -// default: -// return http.StatusInternalServerError -// } -// } +type internalServerError struct { + err error +} + +func (e internalServerError) Error() string { + return e.err.Error() +} + +func (e internalServerError) UserError() string { + return "internal server error" +} + +func (e internalServerError) Code() domain.ErrorCode { + return domain.ErrorCodeUnknown +} + +func (e internalServerError) Unwrap() error { + return e.err +} + +func renderErr(w http.ResponseWriter, err error) { + var domainErr domain.Error + if !errors.As(err, &domainErr) { + domainErr = internalServerError{err: err} + } + renderJSON(w, errorCodeToHTTPStatus(domainErr.Code()), model.ErrorResp{ + Error: model.APIError{ + Code: domainErr.Code().String(), + Message: domainErr.UserError(), + }, + }) +} + +func errorCodeToHTTPStatus(code domain.ErrorCode) int { + switch code { + case domain.ErrorCodeValidationError: + return http.StatusBadRequest + case domain.ErrorCodeEntityNotFound: + return http.StatusNotFound + case domain.ErrorCodeAlreadyExists: + return http.StatusConflict + case domain.ErrorCodeUnknown: + fallthrough + default: + return http.StatusInternalServerError + } +} func renderJSON(w http.ResponseWriter, status int, data any) { w.Header().Set("Content-Type", "application/json") diff --git a/internal/router/rest/rest_test.go b/internal/router/rest/rest_test.go index a2b80fd..4e9bfcb 100644 --- a/internal/router/rest/rest_test.go +++ b/internal/router/rest/rest_test.go @@ -20,35 +20,31 @@ import ( func TestRouteNotFound(t *testing.T) { t.Parallel() - expectedResp := &model.ErrorResp{ + resp := doRequest(rest.NewRouter(rest.RouterConfig{}), http.MethodGet, "/v1/"+uuid.NewString(), "", nil) + defer resp.Body.Close() + assertJSONResponse(t, resp, http.StatusNotFound, &model.ErrorResp{ Error: model.APIError{ Code: "route-not-found", Message: "route not found", }, - } - - resp := doRequest(rest.NewRouter(rest.RouterConfig{}), http.MethodGet, "/v1/"+uuid.NewString(), "", nil) - defer resp.Body.Close() - assertJSONResponse(t, resp, http.StatusNotFound, expectedResp, &model.ErrorResp{}) + }, &model.ErrorResp{}) } func TestMethodNotAllowed(t *testing.T) { t.Parallel() - expectedResp := &model.ErrorResp{ - Error: model.APIError{ - Code: "method-not-allowed", - Message: "method not allowed", - }, - } - resp := doRequest(rest.NewRouter(rest.RouterConfig{ Swagger: rest.SwaggerConfig{ Enabled: true, }, }), http.MethodPost, "/v1/swagger/index.html", "", nil) defer resp.Body.Close() - assertJSONResponse(t, resp, http.StatusMethodNotAllowed, expectedResp, &model.ErrorResp{}) + assertJSONResponse(t, resp, http.StatusMethodNotAllowed, &model.ErrorResp{ + Error: model.APIError{ + Code: "method-not-allowed", + Message: "method not allowed", + }, + }, &model.ErrorResp{}) } func TestSwagger(t *testing.T) { diff --git a/internal/router/rest/session.go b/internal/router/rest/session.go new file mode 100644 index 0000000..e8a532d --- /dev/null +++ b/internal/router/rest/session.go @@ -0,0 +1,62 @@ +package rest + +import ( + "context" + "io" + "net/http" + + "gitea.dwysokinski.me/twhelp/sessions/internal/domain" + "github.com/go-chi/chi/v5" +) + +const ( + sidMaxBytes = 512 +) + +//counterfeiter:generate -o internal/mock/session_service.gen.go . SessionService +type SessionService interface { + CreateOrUpdate(ctx context.Context, params domain.CreateSessionParams) (domain.Session, error) +} + +type sessionHandler struct { + svc SessionService +} + +// @ID createOrUpdateSession +// @Summary Create or update a session +// @Description Create or update a session +// @Tags users,sessions +// @Accept plain +// @Success 204 +// @Success 400 {object} model.ErrorResp +// @Success 401 {object} model.ErrorResp +// @Failure 500 {object} model.ErrorResp +// @Security ApiKeyAuth +// @Param serverKey path string true "Server key" +// @Param request body string true "SID" +// @Router /user/sessions/{serverKey} [put] +func (h *sessionHandler) createOrUpdate(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + chiCtx := chi.RouteContext(ctx) + user, _ := userFromContext(ctx) + + b, err := io.ReadAll(io.LimitReader(r.Body, sidMaxBytes)) + if err != nil { + renderErr(w, err) + return + } + + params, err := domain.NewCreateSessionParams(chiCtx.URLParam("serverKey"), string(b), user.ID) + if err != nil { + renderErr(w, err) + return + } + + _, err = h.svc.CreateOrUpdate(ctx, params) + if err != nil { + renderErr(w, err) + return + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/router/rest/session_test.go b/internal/router/rest/session_test.go new file mode 100644 index 0000000..a334ae3 --- /dev/null +++ b/internal/router/rest/session_test.go @@ -0,0 +1,218 @@ +package rest_test + +import ( + "context" + "encoding/base64" + "net/http" + "strings" + "testing" + "time" + + "gitea.dwysokinski.me/twhelp/sessions/internal/domain" + "gitea.dwysokinski.me/twhelp/sessions/internal/router/rest" + "gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/mock" + "gitea.dwysokinski.me/twhelp/sessions/internal/router/rest/internal/model" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestSession_createOrUpdate(t *testing.T) { + t.Parallel() + + now := time.Now() + apiKey := uuid.NewString() + tests := []struct { + name string + setup func(*mock.FakeAPIKeyVerifier, *mock.FakeSessionService) + apiKey string + serverKey string + sid string + expectedStatus int + target any + expectedResponse any + }{ + { + name: "OK", + setup: func(apiKeySvc *mock.FakeAPIKeyVerifier, sessionSvc *mock.FakeSessionService) { + apiKeySvc.VerifyCalls(func(ctx context.Context, key string) (domain.User, error) { + if key != apiKey { + return domain.User{}, domain.APIKeyNotFoundError{ + Key: key, + } + } + + return domain.User{ + ID: 111, + Name: "name", + CreatedAt: now, + }, nil + }) + sessionSvc.CreateOrUpdateReturns(domain.Session{}, nil) + }, + apiKey: apiKey, + serverKey: "pl151", + sid: base64.StdEncoding.EncodeToString([]byte(uuid.NewString())), + expectedStatus: http.StatusNoContent, + }, + { + name: "ERR: len(serverKey) > 10", + setup: func(apiKeySvc *mock.FakeAPIKeyVerifier, sessionSvc *mock.FakeSessionService) { + apiKeySvc.VerifyCalls(func(ctx context.Context, key string) (domain.User, error) { + if key != apiKey { + return domain.User{}, domain.APIKeyNotFoundError{ + Key: key, + } + } + + return domain.User{ + ID: 111, + Name: "name", + CreatedAt: now, + }, nil + }) + }, + apiKey: apiKey, + serverKey: "012345678890", + sid: base64.StdEncoding.EncodeToString([]byte(uuid.NewString())), + expectedStatus: http.StatusBadRequest, + target: &model.ErrorResp{}, + expectedResponse: &model.ErrorResp{ + Error: model.APIError{ + Code: domain.ErrorCodeValidationError.String(), + Message: "ServerKey: the length must be no more than 10", + }, + }, + }, + { + name: "ERR: SID is required", + setup: func(apiKeySvc *mock.FakeAPIKeyVerifier, sessionSvc *mock.FakeSessionService) { + apiKeySvc.VerifyCalls(func(ctx context.Context, key string) (domain.User, error) { + if key != apiKey { + return domain.User{}, domain.APIKeyNotFoundError{ + Key: key, + } + } + + return domain.User{ + ID: 111, + Name: "name", + CreatedAt: now, + }, nil + }) + }, + apiKey: apiKey, + serverKey: "pl151", + sid: "", + expectedStatus: http.StatusBadRequest, + target: &model.ErrorResp{}, + expectedResponse: &model.ErrorResp{ + Error: model.APIError{ + Code: domain.ErrorCodeValidationError.String(), + Message: "SID: cannot be blank", + }, + }, + }, + { + name: "ERR: SID is not a valid base64", + setup: func(apiKeySvc *mock.FakeAPIKeyVerifier, sessionSvc *mock.FakeSessionService) { + apiKeySvc.VerifyCalls(func(ctx context.Context, key string) (domain.User, error) { + if key != apiKey { + return domain.User{}, domain.APIKeyNotFoundError{ + Key: key, + } + } + + return domain.User{ + ID: 111, + Name: "name", + CreatedAt: now, + }, nil + }) + }, + apiKey: apiKey, + serverKey: "pl151", + sid: uuid.NewString(), + expectedStatus: http.StatusBadRequest, + target: &model.ErrorResp{}, + expectedResponse: &model.ErrorResp{ + Error: model.APIError{ + Code: domain.ErrorCodeValidationError.String(), + Message: "SID: must be encoded in Base64", + }, + }, + }, + { + name: "ERR: apiKey == \"\"", + setup: func(apiKeySvc *mock.FakeAPIKeyVerifier, sessionSvc *mock.FakeSessionService) {}, + apiKey: "", + expectedStatus: http.StatusUnauthorized, + serverKey: "pl151", + target: &model.ErrorResp{}, + expectedResponse: &model.ErrorResp{ + Error: model.APIError{ + Code: "unauthorized", + Message: "invalid API key", + }, + }, + }, + { + name: "ERR: unexpected API key", + setup: func(apiKeySvc *mock.FakeAPIKeyVerifier, sessionSvc *mock.FakeSessionService) { + apiKeySvc.VerifyCalls(func(ctx context.Context, key string) (domain.User, error) { + if key != apiKey { + return domain.User{}, domain.APIKeyNotFoundError{ + Key: key, + } + } + + return domain.User{ + ID: 111, + Name: "name", + CreatedAt: now, + }, nil + }) + }, + apiKey: uuid.NewString(), + serverKey: "pl151", + expectedStatus: http.StatusUnauthorized, + target: &model.ErrorResp{}, + expectedResponse: &model.ErrorResp{ + Error: model.APIError{ + Code: "unauthorized", + Message: "invalid API key", + }, + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + apiKeySvc := &mock.FakeAPIKeyVerifier{} + sessionSvc := &mock.FakeSessionService{} + tt.setup(apiKeySvc, sessionSvc) + + router := rest.NewRouter(rest.RouterConfig{ + APIKeyVerifier: apiKeySvc, + SessionService: sessionSvc, + }) + + resp := doRequest( + router, + http.MethodPut, + "/v1/user/sessions/"+tt.serverKey, + tt.apiKey, + strings.NewReader(tt.sid), + ) + defer resp.Body.Close() + if tt.expectedStatus == http.StatusNoContent { + assert.Equal(t, tt.expectedStatus, resp.StatusCode) + return + } + assertJSONResponse(t, resp, tt.expectedStatus, tt.expectedResponse, tt.target) + }) + } +} diff --git a/internal/router/rest/user_test.go b/internal/router/rest/user_test.go index 521ac7a..9407b93 100644 --- a/internal/router/rest/user_test.go +++ b/internal/router/rest/user_test.go @@ -114,5 +114,4 @@ func TestUser_getAuthenticated(t *testing.T) { assertJSONResponse(t, resp, tt.expectedStatus, tt.expectedResponse, tt.target) }) } - } diff --git a/internal/service/session.go b/internal/service/session.go new file mode 100644 index 0000000..797cee5 --- /dev/null +++ b/internal/service/session.go @@ -0,0 +1,39 @@ +package service + +import ( + "context" + "errors" + "fmt" + + "gitea.dwysokinski.me/twhelp/sessions/internal/domain" +) + +//counterfeiter:generate -o internal/mock/session_repository.gen.go . SessionRepository +type SessionRepository interface { + CreateOrUpdate(ctx context.Context, params domain.CreateSessionParams) (domain.Session, error) +} + +type Session struct { + repo SessionRepository + userSvc UserGetter +} + +func NewSession(repo SessionRepository, userSvc UserGetter) *Session { + return &Session{repo: repo, userSvc: userSvc} +} + +func (s *Session) CreateOrUpdate(ctx context.Context, params domain.CreateSessionParams) (domain.Session, error) { + if _, err := s.userSvc.Get(ctx, params.UserID()); err != nil { + if errors.Is(err, domain.UserNotFoundError{ID: params.UserID()}) { + return domain.Session{}, domain.UserDoesNotExistError{ID: params.UserID()} + } + return domain.Session{}, fmt.Errorf("UserService.Get: %w", err) + } + + sess, err := s.repo.CreateOrUpdate(ctx, params) + if err != nil { + return domain.Session{}, fmt.Errorf("SessionRepository.CreateOrUpdate: %w", err) + } + + return sess, nil +} diff --git a/internal/service/session_test.go b/internal/service/session_test.go new file mode 100644 index 0000000..f1233cb --- /dev/null +++ b/internal/service/session_test.go @@ -0,0 +1,74 @@ +package service_test + +import ( + "context" + "encoding/base64" + "testing" + "time" + + "gitea.dwysokinski.me/twhelp/sessions/internal/domain" + "gitea.dwysokinski.me/twhelp/sessions/internal/service" + "gitea.dwysokinski.me/twhelp/sessions/internal/service/internal/mock" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSession_CreateOrUpdate(t *testing.T) { + t.Parallel() + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + params, err := domain.NewCreateSessionParams( + "pl151", + base64.StdEncoding.EncodeToString([]byte(uuid.NewString())), + 111, + ) + require.NoError(t, err) + + repo := &mock.FakeSessionRepository{} + repo.CreateOrUpdateCalls(func(ctx context.Context, params domain.CreateSessionParams) (domain.Session, error) { + return domain.Session{ + ID: 1111, + UserID: params.UserID(), + ServerKey: params.ServerKey(), + SID: params.SID(), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, nil + }) + + userSvc := &mock.FakeUserGetter{} + userSvc.GetReturns(domain.User{ + ID: params.UserID(), + }, nil) + + sess, err := service.NewSession(repo, userSvc).CreateOrUpdate(context.Background(), params) + assert.NoError(t, err) + assert.Greater(t, sess.ID, int64(0)) + assert.Equal(t, params.UserID(), sess.UserID) + assert.Equal(t, params.SID(), sess.SID) + assert.Equal(t, params.ServerKey(), sess.ServerKey) + assert.WithinDuration(t, time.Now(), sess.CreatedAt, 10*time.Millisecond) + assert.WithinDuration(t, time.Now(), sess.UpdatedAt, 10*time.Millisecond) + }) + + t.Run("ERR: user doesnt exist", func(t *testing.T) { + t.Parallel() + + params, err := domain.NewCreateSessionParams( + "pl151", + base64.StdEncoding.EncodeToString([]byte(uuid.NewString())), + 111, + ) + require.NoError(t, err) + + userSvc := &mock.FakeUserGetter{} + userSvc.GetReturns(domain.User{}, domain.UserNotFoundError{ID: params.UserID()}) + + sess, err := service.NewSession(nil, userSvc).CreateOrUpdate(context.Background(), params) + assert.ErrorIs(t, err, domain.UserDoesNotExistError{ID: params.UserID()}) + assert.Zero(t, sess) + }) +}