core/internal/adapter/adaptertest/postgres.go

186 lines
4.4 KiB
Go

package adaptertest
import (
"database/sql"
"fmt"
"net/url"
"os"
"strings"
"time"
"unicode"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"github.com/stretchr/testify/require"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/driver/pgdriver"
)
// Postgres describes a PostgreSQL database available to tests: either one running in a Docker
// container started by NewPostgres, or an existing database whose URL was supplied through the
// TESTS_POSTGRES_CONNECTION_STRING env variable. Use NewBunDB to obtain *bun.DB instances and
// Close to release the container (if any).
type Postgres struct {
	// connectionString is the URL used to connect to the database.
	connectionString *url.URL
	// resource is the Docker container backing the database; it is nil when the connection
	// string came from TESTS_POSTGRES_CONNECTION_STRING.
	resource *dockertest.Resource
}
// postgresConfig holds the settings NewPostgres uses when it starts a Docker container.
// It is built by newPostgresConfig and customized through PostgresOption values.
type postgresConfig struct {
	// repo and tag identify the Docker image (repository and tag) to run.
	repo, tag string
	// ttl is the container lifetime in seconds, passed to dockertest's Resource.Expire.
	ttl uint
}
// PostgresOption customizes the configuration used by NewPostgres
// (see WithPostgresTTL and WithPostgresImage).
type PostgresOption func(cfg *postgresConfig)
// postgresDefaultTTL is the default container lifetime, in seconds, that NewPostgres passes to
// Resource.Expire unless overridden with WithPostgresTTL.
const postgresDefaultTTL = 60
// newPostgresConfig returns the default configuration (image "postgres:14", TTL
// postgresDefaultTTL) with opts applied in order, so later options win over earlier ones.
func newPostgresConfig(opts ...PostgresOption) *postgresConfig {
	cfg := new(postgresConfig)
	cfg.repo, cfg.tag, cfg.ttl = "postgres", "14", postgresDefaultTTL
	for i := range opts {
		opts[i](cfg)
	}
	return cfg
}
// WithPostgresTTL sets how long, in seconds, the container started by NewPostgres may live;
// the value is passed to dockertest's Resource.Expire.
func WithPostgresTTL(ttlSeconds uint) PostgresOption {
	setTTL := func(cfg *postgresConfig) {
		cfg.ttl = ttlSeconds
	}
	return setTTL
}
// WithPostgresImage selects the Docker image NewPostgres runs. image is a reference such as
// "postgres:16"; it is split into repository and tag with docker.ParseRepositoryTag.
func WithPostgresImage(image string) PostgresOption {
	repo, tag := docker.ParseRepositoryTag(image)
	return func(cfg *postgresConfig) {
		cfg.repo = repo
		cfg.tag = tag
	}
}
// NewPostgres constructs a new Postgres resource. If the env variable 'TESTS_POSTGRES_CONNECTION_STRING' is set,
// this function doesn't run a new Docker container and uses the value of this variable as a connection string,
// otherwise this function runs a new PostgreSQL database in a Docker container.
//
// The container is started with AutoRemove enabled and restart policy "no", and is set to expire after the
// configured TTL (see WithPostgresTTL). If a setup step fails after the container has started, the container
// is removed before the error is returned, so a failed call does not leave a stray container behind.
//
// This function is intended for use in TestMain.
func NewPostgres(pool *dockertest.Pool, opts ...PostgresOption) (*Postgres, error) {
	cfg := newPostgresConfig(opts...)

	if connString := os.Getenv("TESTS_POSTGRES_CONNECTION_STRING"); connString != "" {
		u, err := url.ParseRequestURI(connString)
		if err != nil {
			return nil, fmt.Errorf("couldn't parse TESTS_POSTGRES_CONNECTION_STRING: %w", err)
		}
		return &Postgres{
			connectionString: u,
		}, nil
	}

	// Path has no leading slash on purpose: it doubles as POSTGRES_DB below, and url.URL.String
	// inserts the "/" separator itself once Host is set.
	u := &url.URL{
		Scheme: "postgres",
		User:   url.UserPassword("postgres", "postgres"),
		Path:   "twhelpdb",
		RawQuery: url.Values{
			"sslmode": []string{"disable"},
		}.Encode(),
	}
	// The password was set just above via url.UserPassword, so it is always present.
	pw, _ := u.User.Password()

	resource, err := pool.RunWithOptions(&dockertest.RunOptions{
		Repository: cfg.repo,
		Tag:        cfg.tag,
		Env: []string{
			fmt.Sprintf("POSTGRES_USER=%s", u.User.Username()),
			fmt.Sprintf("POSTGRES_PASSWORD=%s", pw),
			fmt.Sprintf("POSTGRES_DB=%s", u.Path),
		},
	}, func(config *docker.HostConfig) {
		config.AutoRemove = true
		config.RestartPolicy = docker.RestartPolicy{
			Name: "no",
		}
	})
	if err != nil {
		return nil, fmt.Errorf("couldn't run postgres: %w", err)
	}

	// fail removes the container that was just started and then reports cause. Without this, an
	// error below (notably from Expire, which is what bounds the container's lifetime) would leave
	// a running container that nobody holds a reference to.
	fail := func(cause error) (*Postgres, error) {
		if closeErr := resource.Close(); closeErr != nil {
			return nil, fmt.Errorf("%w (also couldn't remove postgres container: %v)", cause, closeErr)
		}
		return nil, cause
	}

	if err = resource.Expire(cfg.ttl); err != nil {
		return fail(fmt.Errorf("couldn't set postgres container expiration: %w", err))
	}

	u.Host, err = getHostPort(resource, "5432/tcp")
	if err != nil {
		return fail(fmt.Errorf("couldn't get postgres host:port: %w", err))
	}

	return &Postgres{
		connectionString: u,
		resource:         resource,
	}, nil
}
// Limits for the exponential backoff NewBunDB uses while pinging the database: the delay between
// attempts grows up to postgresPingBackOffMaxInterval, and the backoff gives up once
// postgresPingBackOffMaxElapsedTime has passed.
const (
	postgresPingBackOffMaxInterval    = 5 * time.Second
	postgresPingBackOffMaxElapsedTime = 30 * time.Second
)
// NewBunDB initializes a new instance of *bun.DB, which is ready for use (all required migrations are applied).
// This method guarantees data separation through PostgreSQL schemas
// (https://www.postgresql.org/docs/current/ddl-schemas.html)
// and it is safe to call Postgres.NewBunDB multiple times.
//
// Each call generates a fresh random schema name, sets it as the connection's search_path, creates
// the schema and runs the migrations. The returned DB is closed automatically via tb.Cleanup.
//
// It fails the test if Postgres hasn't been properly initialized (via NewPostgres), if the database
// can't be pinged (retried with exponential backoff), or if schema creation fails.
func (p *Postgres) NewBunDB(tb TestingTB) *bun.DB {
	tb.Helper()

	require.NotNil(tb, p, "postgres resource not properly initialized")
	require.NotNil(tb, p.connectionString, "postgres resource not properly initialized")

	schema := generatePostgresSchema()

	// search_path is sent as a connection parameter, so every pooled connection resolves
	// unqualified names (including tables created by the migrations) inside this call's schema.
	sqlDB := sql.OpenDB(
		pgdriver.NewConnector(
			pgdriver.WithDSN(p.connectionString.String()),
			pgdriver.WithConnParams(map[string]any{
				"search_path": schema,
			}),
		),
	)

	bunDB := bun.NewDB(sqlDB, pgdialect.New())
	tb.Cleanup(func() {
		// Best-effort: a close failure during test teardown is not actionable.
		_ = bunDB.Close()
	})

	bunDB.AddQueryHook(newBunDebugHook())

	// A freshly started container may not accept connections yet, so the ping is retried.
	bo := backoff.NewExponentialBackOff()
	bo.MaxInterval = postgresPingBackOffMaxInterval
	bo.MaxElapsedTime = postgresPingBackOffMaxElapsedTime
	require.NoError(tb, retry(bo, bunDB.Ping), "couldn't ping DB")

	// bun.Safe inlines the name without quoting; the name comes from generatePostgresSchema (a random
	// UUID), not from user input.
	_, err := bunDB.Exec("CREATE SCHEMA ?", bun.Safe(schema))
	require.NoError(tb, err, "couldn't create postgres schema")

	runMigrations(tb, bunDB)

	return bunDB
}
// Close removes the Docker container started by NewPostgres and returns any error from doing so.
// It returns nil without doing anything for a nil receiver, or when no container is owned
// (the database came from TESTS_POSTGRES_CONNECTION_STRING).
func (p *Postgres) Close() error {
	if p == nil || p.resource == nil {
		return nil
	}
	return p.resource.Close()
}
// generatePostgresSchema returns a random schema name derived from a UUID. Dashes become
// underscores and leading digits are removed, because an unquoted PostgreSQL identifier must
// begin with a letter or an underscore.
func generatePostgresSchema() string {
	name := strings.ReplaceAll(uuid.NewString(), "-", "_")
	// Only a leading digit makes the identifier invalid; trailing digits are legal and are kept
	// so that as much of the UUID's randomness as possible survives.
	return strings.TrimLeftFunc(name, unicode.IsNumber)
}