rename two utils (sqlutils.SanitizeSortExpression -> sqlutils.SanitizeSort, sqlutils.SanitizeSortExpressions -> sqlutils.SanitizeSorts), add a new helper type to db package (Sort)

This commit is contained in:
Dawid Wysokiński 2021-03-06 08:44:49 +01:00
parent 838a93b8b5
commit 56c8c9160b
8 changed files with 53 additions and 24 deletions

View File

@ -14,12 +14,6 @@ import (
"github.com/zdam-egzamin-zawodowy/backend/internal/models"
)
const (
extensions = `
CREATE EXTENSION IF NOT EXISTS tsm_system_rows;
`
)
var log = logrus.WithField("package", "internal/db")
type Config struct {
@ -60,10 +54,6 @@ func prepareOptions() *pg.Options {
func createSchema(db *pg.DB) error {
return db.RunInTransaction(context.Background(), func(tx *pg.Tx) error {
if _, err := tx.Exec(extensions); err != nil {
return errors.Wrap(err, "createSchema")
}
modelsToCreate := []interface{}{
(*models.User)(nil),
(*models.Profession)(nil),

28
internal/db/sort.go Normal file
View File

@ -0,0 +1,28 @@
package db
import (
"strings"
"github.com/go-pg/pg/v10/orm"
)
type Sort struct {
Relationships map[string]string
Orders []string
}
func (s Sort) Apply(q *orm.Query) (*orm.Query, error) {
for _, order := range s.Orders {
if alias := s.extractAlias(order); alias != "" && s.Relationships[alias] != "" {
q = q.Relation(s.Relationships[alias])
}
}
return q.Order(s.Orders...), nil
}
func (s Sort) extractAlias(order string) string {
if strings.Contains(order, ".") {
return strings.Split(order, ".")[0]
}
return ""
}

View File

@ -65,7 +65,7 @@ func (ucase *usecase) Fetch(ctx context.Context, cfg *profession.FetchConfig) ([
Count: true,
}
}
cfg.Sort = sqlutils.SanitizeSortExpressions(cfg.Sort)
cfg.Sort = sqlutils.SanitizeSorts(cfg.Sort)
return ucase.professionRepository.Fetch(ctx, cfg)
}

View File

@ -65,7 +65,7 @@ func (ucase *usecase) Fetch(ctx context.Context, cfg *qualification.FetchConfig)
Count: true,
}
}
cfg.Sort = sqlutils.SanitizeSortExpressions(cfg.Sort)
cfg.Sort = sqlutils.SanitizeSorts(cfg.Sort)
return ucase.qualificationRepository.Fetch(ctx, cfg)
}

View File

@ -11,6 +11,7 @@ import (
sqlutils "github.com/zdam-egzamin-zawodowy/backend/pkg/utils/sql"
"github.com/go-pg/pg/v10"
"github.com/zdam-egzamin-zawodowy/backend/internal/db"
"github.com/zdam-egzamin-zawodowy/backend/internal/models"
"github.com/zdam-egzamin-zawodowy/backend/internal/question"
)
@ -116,6 +117,12 @@ func (repo *pgRepository) Fetch(ctx context.Context, cfg *question.FetchConfig)
Context(ctx).
Limit(cfg.Limit).
Offset(cfg.Offset).
Apply(db.Sort{
Relationships: map[string]string{
"qualification": "qualification",
},
Orders: cfg.Sort,
}.Apply).
Apply(cfg.Filter.Where)
if cfg.Count {

View File

@ -71,7 +71,7 @@ func (ucase *usecase) Fetch(ctx context.Context, cfg *question.FetchConfig) ([]*
Count: true,
}
}
cfg.Sort = sqlutils.SanitizeSortExpressions(cfg.Sort)
cfg.Sort = sqlutils.SanitizeSorts(cfg.Sort)
return ucase.questionRepository.Fetch(ctx, cfg)
}

View File

@ -79,7 +79,7 @@ func (ucase *usecase) Fetch(ctx context.Context, cfg *user.FetchConfig) ([]*mode
Count: true,
}
}
cfg.Sort = sqlutils.SanitizeSortExpressions(cfg.Sort)
cfg.Sort = sqlutils.SanitizeSorts(cfg.Sort)
return ucase.userRepository.Fetch(ctx, cfg)
}

View File

@ -8,16 +8,18 @@ import (
)
var (
sortexprRegex = regexp.MustCompile(`^[\p{L}\_\.]+$`)
sortRegex = regexp.MustCompile(`^[\p{L}\_\.]+$`)
)
func SanitizeSortExpression(expr string) string {
trimmed := strings.TrimSpace(expr)
func SanitizeSort(sort string) string {
trimmed := strings.TrimSpace(sort)
splitted := strings.Split(trimmed, " ")
length := len(splitted)
if length != 2 || !sortexprRegex.Match([]byte(splitted[0])) {
if length != 2 || !sortRegex.Match([]byte(splitted[0])) {
return ""
}
table := ""
column := splitted[0]
if strings.Contains(splitted[0], ".") {
@ -25,17 +27,19 @@ func SanitizeSortExpression(expr string) string {
table = underscore(columnAndTable[0]) + "."
column = columnAndTable[1]
}
keyword := "ASC"
direction := "ASC"
if strings.ToUpper(splitted[1]) == "DESC" {
keyword = "DESC"
direction = "DESC"
}
return strings.ToLower(table+underscore(column)) + " " + keyword
return strings.ToLower(table+underscore(column)) + " " + direction
}
func SanitizeSortExpressions(exprs []string) []string {
func SanitizeSorts(sorts []string) []string {
sanitizedExprs := []string{}
for _, expr := range exprs {
sanitized := SanitizeSortExpression(expr)
for _, sort := range sorts {
sanitized := SanitizeSort(sort)
if sanitized != "" {
sanitizedExprs = append(sanitizedExprs, sanitized)
}