diff --git a/internal/gin/middleware/authenticate.go b/internal/gin/middleware/authenticate.go new file mode 100644 index 0000000..2cfe18a --- /dev/null +++ b/internal/gin/middleware/authenticate.go @@ -0,0 +1,48 @@ +package middleware + +import ( + "context" + "fmt" + + "github.com/gin-gonic/gin" + "github.com/zdam-egzamin-zawodowy/backend/internal/auth" + "github.com/zdam-egzamin-zawodowy/backend/internal/models" +) + +const ( + authorizationHeader = "Authorization" +) + +var ( + authenticateKey contextKey = "current_user" +) + +func Authenticate(ucase auth.Usecase) gin.HandlerFunc { + return func(c *gin.Context) { + token := extractToken(c.GetHeader(authorizationHeader)) + if token != "" { + ctx := c.Request.Context() + user, err := ucase.ExtractAccessTokenMetadata(ctx, token) + if err == nil && user != nil { + ctx = context.WithValue(ctx, authenticateKey, user) + c.Request = c.Request.WithContext(ctx) + } + } + c.Next() + } +} + +func UserFromContext(ctx context.Context) (*models.User, error) { + user := ctx.Value(authenticateKey) + if user == nil { + err := fmt.Errorf("Could not retrieve *models.User") + return nil, err + } + + gc, ok := user.(*models.User) + if !ok { + err := fmt.Errorf("*models.User has wrong type") + return nil, err + } + return gc, nil +} diff --git a/internal/gin/middleware/context_key.go b/internal/gin/middleware/context_key.go new file mode 100644 index 0000000..4bab584 --- /dev/null +++ b/internal/gin/middleware/context_key.go @@ -0,0 +1,3 @@ +package middleware + +type contextKey string diff --git a/internal/gin/middleware/gin_context_to_context.go b/internal/gin/middleware/gin_context_to_context.go new file mode 100644 index 0000000..bfeeb0a --- /dev/null +++ b/internal/gin/middleware/gin_context_to_context.go @@ -0,0 +1,35 @@ +package middleware + +import ( + "context" + "fmt" + + "github.com/gin-gonic/gin" +) + +var ( + ginContextToContextKey contextKey = "gin_context" +) + +func GinContextToContext() gin.HandlerFunc { + return func(c *gin.Context) { + ctx := context.WithValue(c.Request.Context(), ginContextToContextKey, c) + c.Request = c.Request.WithContext(ctx) + c.Next() + } +} + +func GinContextFromContext(ctx context.Context) (*gin.Context, error) { + ginContext := ctx.Value(ginContextToContextKey) + if ginContext == nil { + err := fmt.Errorf("could not retrieve gin.Context") + return nil, err + } + + gc, ok := ginContext.(*gin.Context) + if !ok { + err := fmt.Errorf("gin.Context has wrong type") + return nil, err + } + return gc, nil +} diff --git a/internal/gin/middleware/helpers.go b/internal/gin/middleware/helpers.go new file mode 100644 index 0000000..108bb5f --- /dev/null +++ b/internal/gin/middleware/helpers.go @@ -0,0 +1,13 @@ +package middleware + +import ( + "strings" +) + +func extractToken(bearToken string) string { + strArr := strings.Split(bearToken, " ") + if len(strArr) == 2 { + return strArr[1] + } + return "" +}