package main import ( "context" "database/sql" "fmt" "log/slog" "runtime" "time" "github.com/uptrace/bun" "github.com/uptrace/bun/dialect/pgdialect" "github.com/uptrace/bun/driver/pgdriver" "github.com/urfave/cli/v2" ) var ( dbFlagConnectionString = &cli.StringFlag{ Name: "db.connectionString", Required: true, EnvVars: []string{"DB_CONNECTION_STRING"}, Usage: "https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING", } dbFlagMaxIdleConns = &cli.IntFlag{ Name: "db.maxIdleConns", Value: 2, //nolint:gomnd EnvVars: []string{"DB_MAX_IDLE_CONNS"}, Usage: "https://pkg.go.dev/database/sql#DB.SetMaxIdleConns", } dbFlagMaxOpenConns = &cli.IntFlag{ Name: "db.maxOpenConns", Value: runtime.NumCPU() * 4, //nolint:gomnd EnvVars: []string{"DB_MAX_OPEN_CONNS"}, Usage: "https://pkg.go.dev/database/sql#DB.SetMaxOpenConns", } dbFlagConnMaxLifetime = &cli.DurationFlag{ Name: "db.connMaxLifetime", Value: 30 * time.Minute, //nolint:gomnd EnvVars: []string{"DB_CONN_MAX_LIFETIME"}, Usage: "https://pkg.go.dev/database/sql#DB.SetConnMaxLifetime", } dbFlagReadTimeout = &cli.DurationFlag{ Name: "db.readTimeout", Value: 10 * time.Second, //nolint:gomnd EnvVars: []string{"DB_READ_TIMEOUT"}, } dbFlagWriteTimeout = &cli.DurationFlag{ Name: "db.writeTimeout", Value: 5 * time.Second, //nolint:gomnd EnvVars: []string{"DB_WRITE_TIMEOUT"}, } dbFlags = []cli.Flag{ dbFlagConnectionString, dbFlagMaxIdleConns, dbFlagMaxOpenConns, dbFlagConnMaxLifetime, dbFlagReadTimeout, dbFlagWriteTimeout, } ) func newBunDBFromFlags(c *cli.Context) (*bun.DB, error) { return newBunDB(bundDBConfig{ connectionString: c.String(dbFlagConnectionString.Name), maxOpenConns: c.Int(dbFlagMaxOpenConns.Name), maxIdleConns: c.Int(dbFlagMaxIdleConns.Name), connMaxLifetime: c.Duration(dbFlagConnMaxLifetime.Name), writeTimeout: c.Duration(dbFlagWriteTimeout.Name), readTimeout: c.Duration(dbFlagReadTimeout.Name), }) } type bundDBConfig struct { connectionString string maxOpenConns int maxIdleConns int connMaxLifetime time.Duration readTimeout time.Duration writeTimeout time.Duration } const dbPingTimeout = 10 * time.Second func newBunDB(cfg bundDBConfig) (*bun.DB, error) { db := bun.NewDB(newSQLDB(cfg), pgdialect.New()) ctx, cancel := context.WithTimeout(context.Background(), dbPingTimeout) defer cancel() if err := db.PingContext(ctx); err != nil { return nil, fmt.Errorf("couldn't ping db: %w", err) } return db, nil } func newSQLDB(cfg bundDBConfig) *sql.DB { db := sql.OpenDB(pgdriver.NewConnector( pgdriver.WithDSN(cfg.connectionString), pgdriver.WithReadTimeout(cfg.readTimeout), pgdriver.WithWriteTimeout(cfg.writeTimeout), )) db.SetMaxOpenConns(cfg.maxOpenConns) db.SetMaxIdleConns(cfg.maxIdleConns) db.SetConnMaxLifetime(cfg.connMaxLifetime) return db } func closeBunDB(db *bun.DB, logger *slog.Logger) { logger.Debug("closing db connections...", slog.Int("db.openConnections", db.Stats().OpenConnections)) if dbCloseErr := db.Close(); dbCloseErr != nil { logger.Warn("couldn't close db connections", slog.Any("error", dbCloseErr)) } else { logger.Debug("db connections closed") } }