package adaptertest

import (
	"database/sql"
	"fmt"
	"net/url"
	"os"
	"strings"
	"time"
	"unicode"

	"github.com/cenkalti/backoff/v4"
	"github.com/google/uuid"
	"github.com/ory/dockertest/v3"
	"github.com/ory/dockertest/v3/docker"
	"github.com/stretchr/testify/require"
	"github.com/uptrace/bun"
	"github.com/uptrace/bun/dialect/pgdialect"
	"github.com/uptrace/bun/driver/pgdriver"
)

// Postgres is a handle to a PostgreSQL database used by tests.
type Postgres struct {
	// connectionString is the URL used to open connections to the database.
	connectionString *url.URL
	// resource is the Docker container started by NewPostgres. It stays nil
	// when the connection string was taken from the environment.
	resource *dockertest.Resource
}

// postgresConfig collects the settings that PostgresOption functions can change.
type postgresConfig struct {
	repo string // Docker image repository
	tag  string // Docker image tag
	ttl  uint   // value passed to resource.Expire, in seconds
}

// PostgresOption customizes the configuration used by NewPostgres.
type PostgresOption func(cfg *postgresConfig)

// postgresDefaultTTL is the container TTL (in seconds) used when
// WithPostgresTTL is not given.
const postgresDefaultTTL = 60

// newPostgresConfig returns the default configuration (image "postgres:14",
// TTL postgresDefaultTTL) with opts applied in order.
func newPostgresConfig(opts ...PostgresOption) *postgresConfig {
	cfg := &postgresConfig{
		repo: "postgres",
		tag:  "14",
		ttl:  postgresDefaultTTL,
	}

	for _, opt := range opts {
		opt(cfg)
	}

	return cfg
}

// WithPostgresTTL overrides the container TTL, expressed in seconds.
func WithPostgresTTL(ttlSeconds uint) PostgresOption {
	return func(cfg *postgresConfig) {
		cfg.ttl = ttlSeconds
	}
}

// WithPostgresImage overrides the Docker image. The image string is split
// into repository and tag with docker.ParseRepositoryTag.
func WithPostgresImage(image string) PostgresOption {
	return func(cfg *postgresConfig) {
		cfg.repo, cfg.tag = docker.ParseRepositoryTag(image)
	}
}

// NewPostgres constructs a new Postgres resource. If the env variable 'TESTS_POSTGRES_CONNECTION_STRING' is set,
// this function doesn't run a new Docker container and uses the value of this variable as a connection string,
// otherwise this function runs a new PostgreSQL database in a Docker container.
// This function is intended for use in TestMain.
//
// If the container is started but a later setup step fails, the container is
// closed before the error is returned so that it is not leaked.
func NewPostgres(pool *dockertest.Pool, opts ...PostgresOption) (*Postgres, error) {
	cfg := newPostgresConfig(opts...)

	if connString := os.Getenv("TESTS_POSTGRES_CONNECTION_STRING"); connString != "" {
		u, err := url.ParseRequestURI(connString)
		if err != nil {
			return nil, fmt.Errorf("couldn't parse TESTS_POSTGRES_CONNECTION_STRING: %w", err)
		}

		return &Postgres{
			connectionString: u,
		}, nil
	}

	// The same URL serves two purposes: its parts seed the container's
	// environment below, and (once Host is filled in) it becomes the
	// connection string handed to clients.
	u := &url.URL{
		Scheme: "postgres",
		User:   url.UserPassword("postgres", "postgres"),
		Path:   "twhelpdb",
		RawQuery: url.Values{
			"sslmode": []string{"disable"},
		}.Encode(),
	}

	pw, _ := u.User.Password()

	resource, err := pool.RunWithOptions(&dockertest.RunOptions{
		Repository: cfg.repo,
		Tag:        cfg.tag,
		Env: []string{
			fmt.Sprintf("POSTGRES_USER=%s", u.User.Username()),
			fmt.Sprintf("POSTGRES_PASSWORD=%s", pw),
			fmt.Sprintf("POSTGRES_DB=%s", u.Path),
		},
	}, func(config *docker.HostConfig) {
		config.AutoRemove = true
		config.RestartPolicy = docker.RestartPolicy{
			Name: "no",
		}
	})
	if err != nil {
		return nil, fmt.Errorf("couldn't run postgres: %w", err)
	}

	if err = resource.Expire(cfg.ttl); err != nil {
		// Best-effort cleanup: the caller never receives the resource, so it
		// must be released here. The Expire error is the one worth reporting.
		_ = resource.Close()
		return nil, err
	}

	u.Host, err = getHostPort(resource, "5432/tcp")
	if err != nil {
		// Best-effort cleanup, see above.
		_ = resource.Close()
		return nil, err
	}

	return &Postgres{
		connectionString: u,
		resource:         resource,
	}, nil
}

// Back-off limits applied while waiting for the database to answer a ping.
const (
	postgresPingBackOffMaxInterval    = 5 * time.Second
	postgresPingBackOffMaxElapsedTime = 30 * time.Second
)

// NewBunDB initializes a new instance of *bun.DB, which is ready for use (all required migrations are applied).
// This method guarantees data separation through PostgreSQL schemas
// (https://www.postgresql.org/docs/current/ddl-schemas.html)
// and it is safe to call Postgres.NewBunDB multiple times.
//
// It fails if Postgres hasn't been properly initialized (via NewPostgres).
func (p *Postgres) NewBunDB(tb TestingTB) *bun.DB { tb.Helper() require.NotNil(tb, p, "postgres resource not property initialized") require.NotNil(tb, p.connectionString, "postgres resource not properly initialized") schema := generatePostgresSchema() sqlDB := sql.OpenDB( pgdriver.NewConnector( pgdriver.WithDSN(p.connectionString.String()), pgdriver.WithConnParams(map[string]any{ "search_path": schema, }), ), ) bunDB := bun.NewDB(sqlDB, pgdialect.New()) tb.Cleanup(func() { _ = bunDB.Close() }) bunDB.AddQueryHook(newBunDebugHook()) bo := backoff.NewExponentialBackOff() bo.MaxInterval = postgresPingBackOffMaxInterval bo.MaxElapsedTime = postgresPingBackOffMaxElapsedTime require.NoError(tb, retry(bo, bunDB.Ping), "couldn't ping DB") _, err := bunDB.Exec("CREATE SCHEMA ?", bun.Safe(schema)) require.NoError(tb, err, "couldn't create postgres schema") runMigrations(tb, bunDB) return bunDB } func (p *Postgres) Close() error { if p != nil && p.resource != nil { if err := p.resource.Close(); err != nil { return err } } return nil } func generatePostgresSchema() string { return strings.TrimFunc(strings.ReplaceAll(uuid.NewString(), "-", "_"), unicode.IsNumber) }