add LimitWhitelist middleware

This commit is contained in:
Dawid Wysokiński 2020-08-09 14:32:46 +02:00
parent 7d2f75db52
commit 73390475fc
14 changed files with 70 additions and 12 deletions

View File

@ -4,6 +4,7 @@ import (
"context"
"github.com/tribalwarshelp/api/dailyplayerstats"
"github.com/tribalwarshelp/api/middleware"
"github.com/tribalwarshelp/api/utils"
"github.com/tribalwarshelp/shared/models"
)
@ -20,7 +21,7 @@ func (ucase *usecase) Fetch(ctx context.Context, server string, filter *models.D
if filter == nil {
filter = &models.DailyPlayerStatsFilter{}
}
if filter.Limit > dailyplayerstats.PaginationLimit || filter.Limit <= 0 {
if !middleware.MayExceedLimit(ctx) && (filter.Limit > dailyplayerstats.PaginationLimit || filter.Limit <= 0) {
filter.Limit = dailyplayerstats.PaginationLimit
}
filter.Sort = utils.SanitizeSort(filter.Sort)

View File

@ -4,6 +4,7 @@ import (
"context"
"github.com/tribalwarshelp/api/dailytribestats"
"github.com/tribalwarshelp/api/middleware"
"github.com/tribalwarshelp/api/utils"
"github.com/tribalwarshelp/shared/models"
)
@ -20,7 +21,7 @@ func (ucase *usecase) Fetch(ctx context.Context, server string, filter *models.D
if filter == nil {
filter = &models.DailyTribeStatsFilter{}
}
if filter.Limit > dailytribestats.PaginationLimit || filter.Limit <= 0 {
if !middleware.MayExceedLimit(ctx) && (filter.Limit > dailytribestats.PaginationLimit || filter.Limit <= 0) {
filter.Limit = dailytribestats.PaginationLimit
}
filter.Sort = utils.SanitizeSort(filter.Sort)

View File

@ -4,6 +4,7 @@ import (
"context"
"github.com/tribalwarshelp/api/ennoblement"
"github.com/tribalwarshelp/api/middleware"
"github.com/tribalwarshelp/api/utils"
"github.com/tribalwarshelp/shared/models"
)
@ -20,7 +21,7 @@ func (ucase *usecase) Fetch(ctx context.Context, server string, filter *models.E
if filter == nil {
filter = &models.EnnoblementFilter{}
}
if filter.Limit > ennoblement.PaginationLimit || filter.Limit <= 0 {
if !middleware.MayExceedLimit(ctx) && (filter.Limit > ennoblement.PaginationLimit || filter.Limit <= 0) {
filter.Limit = ennoblement.PaginationLimit
}
filter.Sort = utils.SanitizeSort(filter.Sort)

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"github.com/tribalwarshelp/api/middleware"
"github.com/tribalwarshelp/api/utils"
"github.com/tribalwarshelp/api/langversion"
@ -24,7 +25,7 @@ func (ucase *usecase) Fetch(ctx context.Context, filter *models.LangVersionFilte
if filter == nil {
filter = &models.LangVersionFilter{}
}
if filter.Limit > langversion.PaginationLimit || filter.Limit <= 0 {
if !middleware.MayExceedLimit(ctx) && (filter.Limit > langversion.PaginationLimit || filter.Limit <= 0) {
filter.Limit = langversion.PaginationLimit
}
filter.Sort = utils.SanitizeSort(filter.Sort)

View File

@ -6,6 +6,7 @@ import (
"net/http"
"os"
"os/signal"
"strings"
"time"
"github.com/gin-contrib/cors"
@ -143,6 +144,9 @@ func main() {
VillageRepo: villageRepo,
LangVersionRepo: langversionRepo,
}))
graphql.Use(middleware.LimitWhitelist(middleware.LimitWhitelistConfig{
IPAddresses: strings.Split(os.Getenv("LIMIT_WHITELIST"), ","),
}))
httpdelivery.Attach(httpdelivery.Config{
RouterGroup: graphql,
Resolver: &resolvers.Resolver{

View File

@ -0,0 +1,42 @@
package middleware
import (
"context"
"github.com/gin-gonic/gin"
)
var limitWhitelistContextKey ContextKey = "limitWhitelist"
type LimitWhitelistConfig struct {
IPAddresses []string
}
func LimitWhitelist(cfg LimitWhitelistConfig) gin.HandlerFunc {
return func(c *gin.Context) {
ctx := c.Request.Context()
clientIP := c.ClientIP()
mayExceedLimit := false
for _, ip := range cfg.IPAddresses {
if ip == clientIP {
mayExceedLimit = true
break
}
}
ctx = StoreLimitWhitelistDataInContext(ctx, mayExceedLimit)
c.Request = c.Request.WithContext(ctx)
c.Next()
}
}
func StoreLimitWhitelistDataInContext(ctx context.Context, mayExceedLimit bool) context.Context {
return context.WithValue(ctx, limitWhitelistContextKey, mayExceedLimit)
}
func MayExceedLimit(ctx context.Context) bool {
whitelisted := ctx.Value(limitWhitelistContextKey)
if whitelisted == nil {
return false
}
return whitelisted.(bool)
}

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"github.com/tribalwarshelp/api/middleware"
"github.com/tribalwarshelp/api/player"
"github.com/tribalwarshelp/api/utils"
"github.com/tribalwarshelp/shared/models"
@ -21,7 +22,7 @@ func (ucase *usecase) Fetch(ctx context.Context, server string, filter *models.P
if filter == nil {
filter = &models.PlayerFilter{}
}
if filter.Limit > player.PaginationLimit || filter.Limit <= 0 {
if !middleware.MayExceedLimit(ctx) && (filter.Limit > player.PaginationLimit || filter.Limit <= 0) {
filter.Limit = player.PaginationLimit
}
filter.Sort = utils.SanitizeSort(filter.Sort)

View File

@ -3,6 +3,7 @@ package usecase
import (
"context"
"github.com/tribalwarshelp/api/middleware"
"github.com/tribalwarshelp/api/playerhistory"
"github.com/tribalwarshelp/api/utils"
"github.com/tribalwarshelp/shared/models"
@ -20,7 +21,7 @@ func (ucase *usecase) Fetch(ctx context.Context, server string, filter *models.P
if filter == nil {
filter = &models.PlayerHistoryFilter{}
}
if filter.Limit > playerhistory.PaginationLimit || filter.Limit <= 0 {
if !middleware.MayExceedLimit(ctx) && (filter.Limit > playerhistory.PaginationLimit || filter.Limit <= 0) {
filter.Limit = playerhistory.PaginationLimit
}
filter.Sort = utils.SanitizeSort(filter.Sort)

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"github.com/tribalwarshelp/api/middleware"
"github.com/tribalwarshelp/api/server"
"github.com/tribalwarshelp/api/utils"
"github.com/tribalwarshelp/shared/models"
@ -21,7 +22,7 @@ func (ucase *usecase) Fetch(ctx context.Context, filter *models.ServerFilter) ([
if filter == nil {
filter = &models.ServerFilter{}
}
if filter.Limit > server.PaginationLimit || filter.Limit <= 0 {
if !middleware.MayExceedLimit(ctx) && (filter.Limit > server.PaginationLimit || filter.Limit <= 0) {
filter.Limit = server.PaginationLimit
}
filter.Sort = utils.SanitizeSort(filter.Sort)

View File

@ -3,6 +3,7 @@ package usecase
import (
"context"
"github.com/tribalwarshelp/api/middleware"
"github.com/tribalwarshelp/api/serverstats"
"github.com/tribalwarshelp/api/utils"
"github.com/tribalwarshelp/shared/models"
@ -20,7 +21,7 @@ func (ucase *usecase) Fetch(ctx context.Context, server string, filter *models.S
if filter == nil {
filter = &models.ServerStatsFilter{}
}
if filter.Limit > serverstats.PaginationLimit || filter.Limit <= 0 {
if !middleware.MayExceedLimit(ctx) && (filter.Limit > serverstats.PaginationLimit || filter.Limit <= 0) {
filter.Limit = serverstats.PaginationLimit
}
filter.Sort = utils.SanitizeSort(filter.Sort)

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"github.com/tribalwarshelp/api/middleware"
"github.com/tribalwarshelp/api/tribe"
"github.com/tribalwarshelp/api/utils"
"github.com/tribalwarshelp/shared/models"
@ -21,7 +22,7 @@ func (ucase *usecase) Fetch(ctx context.Context, server string, filter *models.T
if filter == nil {
filter = &models.TribeFilter{}
}
if filter.Limit > tribe.PaginationLimit || filter.Limit <= 0 {
if !middleware.MayExceedLimit(ctx) && (filter.Limit > tribe.PaginationLimit || filter.Limit <= 0) {
filter.Limit = tribe.PaginationLimit
}
filter.Sort = utils.SanitizeSort(filter.Sort)

View File

@ -3,6 +3,7 @@ package usecase
import (
"context"
"github.com/tribalwarshelp/api/middleware"
"github.com/tribalwarshelp/api/tribechange"
"github.com/tribalwarshelp/api/utils"
"github.com/tribalwarshelp/shared/models"
@ -20,7 +21,7 @@ func (ucase *usecase) Fetch(ctx context.Context, server string, filter *models.T
if filter == nil {
filter = &models.TribeChangeFilter{}
}
if filter.Limit > tribechange.PaginationLimit || filter.Limit <= 0 {
if !middleware.MayExceedLimit(ctx) && (filter.Limit > tribechange.PaginationLimit || filter.Limit <= 0) {
filter.Limit = tribechange.PaginationLimit
}
filter.Sort = utils.SanitizeSort(filter.Sort)

View File

@ -3,6 +3,7 @@ package usecase
import (
"context"
"github.com/tribalwarshelp/api/middleware"
"github.com/tribalwarshelp/api/tribehistory"
"github.com/tribalwarshelp/api/utils"
"github.com/tribalwarshelp/shared/models"
@ -20,7 +21,7 @@ func (ucase *usecase) Fetch(ctx context.Context, server string, filter *models.T
if filter == nil {
filter = &models.TribeHistoryFilter{}
}
if filter.Limit > tribehistory.PaginationLimit || filter.Limit <= 0 {
if !middleware.MayExceedLimit(ctx) && (filter.Limit > tribehistory.PaginationLimit || filter.Limit <= 0) {
filter.Limit = tribehistory.PaginationLimit
}
filter.Sort = utils.SanitizeSort(filter.Sort)

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"github.com/tribalwarshelp/api/middleware"
"github.com/tribalwarshelp/api/utils"
"github.com/tribalwarshelp/api/village"
"github.com/tribalwarshelp/shared/models"
@ -21,7 +22,7 @@ func (ucase *usecase) Fetch(ctx context.Context, server string, filter *models.V
if filter == nil {
filter = &models.VillageFilter{}
}
if filter.Limit > village.PaginationLimit || filter.Limit <= 0 {
if !middleware.MayExceedLimit(ctx) && (filter.Limit > village.PaginationLimit || filter.Limit <= 0) {
filter.Limit = village.PaginationLimit
}
filter.Sort = utils.SanitizeSort(filter.Sort)