rename two utils (sqlutils.SanitizeSortExpression -> sqlutils.SanitizeSort, sqlutils.SanitizeSortExpressions -> sqlutils.SanitizeSorts), add a new helper type to db package (Sort)
This commit is contained in:
parent
838a93b8b5
commit
56c8c9160b
|
@ -14,12 +14,6 @@ import (
|
||||||
"github.com/zdam-egzamin-zawodowy/backend/internal/models"
|
"github.com/zdam-egzamin-zawodowy/backend/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
extensions = `
|
|
||||||
CREATE EXTENSION IF NOT EXISTS tsm_system_rows;
|
|
||||||
`
|
|
||||||
)
|
|
||||||
|
|
||||||
var log = logrus.WithField("package", "internal/db")
|
var log = logrus.WithField("package", "internal/db")
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
@ -60,10 +54,6 @@ func prepareOptions() *pg.Options {
|
||||||
|
|
||||||
func createSchema(db *pg.DB) error {
|
func createSchema(db *pg.DB) error {
|
||||||
return db.RunInTransaction(context.Background(), func(tx *pg.Tx) error {
|
return db.RunInTransaction(context.Background(), func(tx *pg.Tx) error {
|
||||||
if _, err := tx.Exec(extensions); err != nil {
|
|
||||||
return errors.Wrap(err, "createSchema")
|
|
||||||
}
|
|
||||||
|
|
||||||
modelsToCreate := []interface{}{
|
modelsToCreate := []interface{}{
|
||||||
(*models.User)(nil),
|
(*models.User)(nil),
|
||||||
(*models.Profession)(nil),
|
(*models.Profession)(nil),
|
||||||
|
|
28
internal/db/sort.go
Normal file
28
internal/db/sort.go
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-pg/pg/v10/orm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Sort struct {
|
||||||
|
Relationships map[string]string
|
||||||
|
Orders []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Sort) Apply(q *orm.Query) (*orm.Query, error) {
|
||||||
|
for _, order := range s.Orders {
|
||||||
|
if alias := s.extractAlias(order); alias != "" && s.Relationships[alias] != "" {
|
||||||
|
q = q.Relation(s.Relationships[alias])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return q.Order(s.Orders...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Sort) extractAlias(order string) string {
|
||||||
|
if strings.Contains(order, ".") {
|
||||||
|
return strings.Split(order, ".")[0]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
|
@ -65,7 +65,7 @@ func (ucase *usecase) Fetch(ctx context.Context, cfg *profession.FetchConfig) ([
|
||||||
Count: true,
|
Count: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cfg.Sort = sqlutils.SanitizeSortExpressions(cfg.Sort)
|
cfg.Sort = sqlutils.SanitizeSorts(cfg.Sort)
|
||||||
return ucase.professionRepository.Fetch(ctx, cfg)
|
return ucase.professionRepository.Fetch(ctx, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -65,7 +65,7 @@ func (ucase *usecase) Fetch(ctx context.Context, cfg *qualification.FetchConfig)
|
||||||
Count: true,
|
Count: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cfg.Sort = sqlutils.SanitizeSortExpressions(cfg.Sort)
|
cfg.Sort = sqlutils.SanitizeSorts(cfg.Sort)
|
||||||
return ucase.qualificationRepository.Fetch(ctx, cfg)
|
return ucase.qualificationRepository.Fetch(ctx, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
sqlutils "github.com/zdam-egzamin-zawodowy/backend/pkg/utils/sql"
|
sqlutils "github.com/zdam-egzamin-zawodowy/backend/pkg/utils/sql"
|
||||||
|
|
||||||
"github.com/go-pg/pg/v10"
|
"github.com/go-pg/pg/v10"
|
||||||
|
"github.com/zdam-egzamin-zawodowy/backend/internal/db"
|
||||||
"github.com/zdam-egzamin-zawodowy/backend/internal/models"
|
"github.com/zdam-egzamin-zawodowy/backend/internal/models"
|
||||||
"github.com/zdam-egzamin-zawodowy/backend/internal/question"
|
"github.com/zdam-egzamin-zawodowy/backend/internal/question"
|
||||||
)
|
)
|
||||||
|
@ -116,6 +117,12 @@ func (repo *pgRepository) Fetch(ctx context.Context, cfg *question.FetchConfig)
|
||||||
Context(ctx).
|
Context(ctx).
|
||||||
Limit(cfg.Limit).
|
Limit(cfg.Limit).
|
||||||
Offset(cfg.Offset).
|
Offset(cfg.Offset).
|
||||||
|
Apply(db.Sort{
|
||||||
|
Relationships: map[string]string{
|
||||||
|
"qualification": "qualification",
|
||||||
|
},
|
||||||
|
Orders: cfg.Sort,
|
||||||
|
}.Apply).
|
||||||
Apply(cfg.Filter.Where)
|
Apply(cfg.Filter.Where)
|
||||||
|
|
||||||
if cfg.Count {
|
if cfg.Count {
|
||||||
|
|
|
@ -71,7 +71,7 @@ func (ucase *usecase) Fetch(ctx context.Context, cfg *question.FetchConfig) ([]*
|
||||||
Count: true,
|
Count: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cfg.Sort = sqlutils.SanitizeSortExpressions(cfg.Sort)
|
cfg.Sort = sqlutils.SanitizeSorts(cfg.Sort)
|
||||||
return ucase.questionRepository.Fetch(ctx, cfg)
|
return ucase.questionRepository.Fetch(ctx, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -79,7 +79,7 @@ func (ucase *usecase) Fetch(ctx context.Context, cfg *user.FetchConfig) ([]*mode
|
||||||
Count: true,
|
Count: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cfg.Sort = sqlutils.SanitizeSortExpressions(cfg.Sort)
|
cfg.Sort = sqlutils.SanitizeSorts(cfg.Sort)
|
||||||
return ucase.userRepository.Fetch(ctx, cfg)
|
return ucase.userRepository.Fetch(ctx, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,16 +8,18 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
sortexprRegex = regexp.MustCompile(`^[\p{L}\_\.]+$`)
|
sortRegex = regexp.MustCompile(`^[\p{L}\_\.]+$`)
|
||||||
)
|
)
|
||||||
|
|
||||||
func SanitizeSortExpression(expr string) string {
|
func SanitizeSort(sort string) string {
|
||||||
trimmed := strings.TrimSpace(expr)
|
trimmed := strings.TrimSpace(sort)
|
||||||
splitted := strings.Split(trimmed, " ")
|
splitted := strings.Split(trimmed, " ")
|
||||||
length := len(splitted)
|
length := len(splitted)
|
||||||
if length != 2 || !sortexprRegex.Match([]byte(splitted[0])) {
|
|
||||||
|
if length != 2 || !sortRegex.Match([]byte(splitted[0])) {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
table := ""
|
table := ""
|
||||||
column := splitted[0]
|
column := splitted[0]
|
||||||
if strings.Contains(splitted[0], ".") {
|
if strings.Contains(splitted[0], ".") {
|
||||||
|
@ -25,17 +27,19 @@ func SanitizeSortExpression(expr string) string {
|
||||||
table = underscore(columnAndTable[0]) + "."
|
table = underscore(columnAndTable[0]) + "."
|
||||||
column = columnAndTable[1]
|
column = columnAndTable[1]
|
||||||
}
|
}
|
||||||
keyword := "ASC"
|
|
||||||
|
direction := "ASC"
|
||||||
if strings.ToUpper(splitted[1]) == "DESC" {
|
if strings.ToUpper(splitted[1]) == "DESC" {
|
||||||
keyword = "DESC"
|
direction = "DESC"
|
||||||
}
|
}
|
||||||
return strings.ToLower(table+underscore(column)) + " " + keyword
|
|
||||||
|
return strings.ToLower(table+underscore(column)) + " " + direction
|
||||||
}
|
}
|
||||||
|
|
||||||
func SanitizeSortExpressions(exprs []string) []string {
|
func SanitizeSorts(sorts []string) []string {
|
||||||
sanitizedExprs := []string{}
|
sanitizedExprs := []string{}
|
||||||
for _, expr := range exprs {
|
for _, sort := range sorts {
|
||||||
sanitized := SanitizeSortExpression(expr)
|
sanitized := SanitizeSort(sort)
|
||||||
if sanitized != "" {
|
if sanitized != "" {
|
||||||
sanitizedExprs = append(sanitizedExprs, sanitized)
|
sanitizedExprs = append(sanitizedExprs, sanitized)
|
||||||
}
|
}
|
Reference in New Issue
Block a user