add complexity limit

This commit is contained in:
Dawid Wysokiński 2021-03-20 16:44:56 +01:00
parent f86160c8b3
commit 91a35dd2e7
5 changed files with 65 additions and 10 deletions

View File

@ -2,6 +2,7 @@ package httpdelivery
import (
"fmt"
"github.com/zdam-egzamin-zawodowy/backend/internal/models"
"time"
"github.com/99designs/gqlgen/graphql/handler"
@ -20,6 +21,7 @@ const (
playgroundTTL = time.Hour / time.Second
graphqlEndpoint = "/graphql"
playgroundEndpoint = "/"
complexityLimit = 1000
)
type Config struct {
@ -31,7 +33,7 @@ func Attach(group *gin.RouterGroup, cfg Config) error {
if cfg.Resolver == nil {
return fmt.Errorf("Graphql resolver cannot be nil")
}
gqlHandler := graphqlHandler(cfg.Resolver, cfg.Directive)
gqlHandler := graphqlHandler(prepareConfig(cfg.Resolver, cfg.Directive))
group.GET(graphqlEndpoint, gqlHandler)
group.POST(graphqlEndpoint, gqlHandler)
if mode.Get() == mode.DevelopmentMode {
@ -41,10 +43,7 @@ func Attach(group *gin.RouterGroup, cfg Config) error {
}
// Defining the GraphQL handler
func graphqlHandler(r *resolvers.Resolver, d *directive.Directive) gin.HandlerFunc {
cfg := generated.Config{Resolvers: r}
cfg.Directives.Authenticated = d.Authenticated
cfg.Directives.HasRole = d.HasRole
func graphqlHandler(cfg generated.Config) gin.HandlerFunc {
srv := handler.New(generated.NewExecutableSchema(cfg))
srv.AddTransport(transport.GET{})
@ -56,6 +55,8 @@ func graphqlHandler(r *resolvers.Resolver, d *directive.Directive) gin.HandlerFu
srv.Use(extension.AutomaticPersistedQuery{
Cache: lru.New(100),
})
srv.SetQueryCache(lru.New(100))
srv.Use(extension.FixedComplexityLimit(complexityLimit))
if mode.Get() == mode.DevelopmentMode {
srv.Use(extension.Introspection{})
}
@ -75,3 +76,55 @@ func playgroundHandler() gin.HandlerFunc {
h.ServeHTTP(c.Writer, c.Request)
}
}
func prepareConfig(r *resolvers.Resolver, d *directive.Directive) generated.Config {
cfg := generated.Config{Resolvers: r}
cfg.Directives.Authenticated = d.Authenticated
cfg.Directives.HasRole = d.HasRole
cfg.Complexity = getComplexityRoot()
return cfg
}
func getComplexityRoot() generated.ComplexityRoot {
complexityRoot := generated.ComplexityRoot{}
complexityRoot.Query.GenerateTest = func(childComplexity int, qualificationIDs []int, limit *int) int {
return 300 + childComplexity
}
complexityRoot.Query.Professions = func(
childComplexity int,
filter *models.ProfessionFilter,
limit *int,
offset *int,
sort []string,
) int {
return 200 + childComplexity
}
complexityRoot.Query.Qualifications = func(
childComplexity int,
filter *models.QualificationFilter,
limit *int,
offset *int,
sort []string,
) int {
return 200 + childComplexity
}
complexityRoot.Query.Questions = func(
childComplexity int,
filter *models.QuestionFilter,
limit *int,
offset *int,
sort []string,
) int {
return 200 + childComplexity
}
complexityRoot.Query.Users = func(
childComplexity int,
filter *models.UserFilter,
limit *int,
offset *int,
sort []string,
) int {
return 200 + childComplexity
}
return complexityRoot
}

View File

@ -71,7 +71,7 @@ func (r *queryResolver) Users(
&user.FetchConfig{
Count: shouldCount(ctx),
Filter: filter,
Limit: utils.SafeIntPointer(limit, user.DefaultLimit),
Limit: utils.SafeIntPointer(limit, user.FetchMaxLimit),
Offset: utils.SafeIntPointer(offset, 0),
Sort: sort,
},

View File

@ -1,7 +1,7 @@
package user
const (
DefaultLimit = 100
FetchMaxLimit = 100
MinDisplayNameLength = 2
MaxDisplayNameLength = 32
MinPasswordLength = 6

View File

@ -75,10 +75,13 @@ func (ucase *usecase) Delete(ctx context.Context, f *models.UserFilter) ([]*mode
func (ucase *usecase) Fetch(ctx context.Context, cfg *user.FetchConfig) ([]*models.User, int, error) {
if cfg == nil {
cfg = &user.FetchConfig{
Limit: user.DefaultLimit,
Limit: user.FetchMaxLimit,
Count: true,
}
}
if cfg.Limit > user.FetchMaxLimit || cfg.Limit <= 0 {
cfg.Limit = user.FetchMaxLimit
}
cfg.Sort = sqlutils.SanitizeSorts(cfg.Sort)
return ucase.userRepository.Fetch(ctx, cfg)
}

View File

@ -12,8 +12,7 @@ var (
)
func SanitizeSort(sort string) string {
trimmed := strings.TrimSpace(sort)
splitted := strings.Split(trimmed, " ")
splitted := strings.Split(strings.TrimSpace(sort), " ")
length := len(splitted)
if length != 2 || !sortRegex.Match([]byte(splitted[0])) {