add two middlewares - Authenticate and GinContextToContext

This commit is contained in:
Dawid Wysokiński 2021-03-06 10:05:54 +01:00
parent acbc4defe2
commit 6251cacded
4 changed files with 99 additions and 0 deletions

View File

@ -0,0 +1,48 @@
package middleware
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/zdam-egzamin-zawodowy/backend/internal/auth"
"github.com/zdam-egzamin-zawodowy/backend/internal/models"
)
const (
authorizationHeader = "Authorization"
)
var (
authenticateKey contextKey = "current_user"
)
func Authenticate(ucase auth.Usecase) gin.HandlerFunc {
return func(c *gin.Context) {
token := extractToken(c.GetHeader(authorizationHeader))
if token != "" {
ctx := c.Request.Context()
user, err := ucase.ExtractAccessTokenMetadata(ctx, token)
if err == nil && user != nil {
ctx = context.WithValue(ctx, authenticateKey, user)
c.Request = c.Request.WithContext(ctx)
}
}
c.Next()
}
}
func UserFromContext(ctx context.Context) (*models.User, error) {
user := ctx.Value(authenticateKey)
if user == nil {
err := fmt.Errorf("Could not retrieve *models.User")
return nil, err
}
gc, ok := user.(*models.User)
if !ok {
err := fmt.Errorf("*models.User has wrong type")
return nil, err
}
return gc, nil
}

View File

@ -0,0 +1,3 @@
package middleware
type contextKey string

View File

@ -0,0 +1,35 @@
package middleware
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
)
var (
ginContextToContextKey contextKey = "gin_context"
)
func GinContextToContext() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := context.WithValue(c.Request.Context(), ginContextToContextKey, c)
c.Request = c.Request.WithContext(ctx)
c.Next()
}
}
func GinContextFromContext(ctx context.Context) (*gin.Context, error) {
ginContext := ctx.Value(ginContextToContextKey)
if ginContext == nil {
err := fmt.Errorf("could not retrieve gin.Context")
return nil, err
}
gc, ok := ginContext.(*gin.Context)
if !ok {
err := fmt.Errorf("gin.Context has wrong type")
return nil, err
}
return gc, nil
}

View File

@ -0,0 +1,13 @@
package middleware
import (
"strings"
)
func extractToken(bearToken string) string {
strArr := strings.Split(bearToken, " ")
if len(strArr) == 2 {
return strArr[1]
}
return ""
}