402 lines
11 KiB
Go
402 lines
11 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/pkg/errors"
|
|
|
|
_ "github.com/lib/pq"
|
|
"github.com/mattermost/morph/drivers"
|
|
"github.com/mattermost/morph/models"
|
|
)
|
|
|
|
var (
	// driverName is the database/sql driver name handed to sql.Open and
	// reported in every driver error.
	driverName = "postgres"

	// defaultMigrationMaxSize is 10 << 20 (10 MB).
	defaultMigrationMaxSize = 10 << 20

	// configParams are the morph-specific connection URL parameters. Open
	// extracts them, strips them from the URL and merges them into the config.
	configParams = []string{
		"x-migration-max-size",
		"x-migrations-table",
		"x-statement-timeout",
	}
)
|
|
|
|
// nonTransactionalPrefix is the marker that opts a migration out of the
// implicit transaction. Apply checks whether the migration query begins with
// it as a SQL comment, i.e. "-- morph:nontransactional".
//
// The marker format is "morph:" followed by a comma-separated list of values.
// Only one value exists today, so the whole marker lives in a single constant;
// if more values are added, "morph:" can be split into its own constant.
const nonTransactionalPrefix = "morph:nontransactional"
|
|
|
|
type driverConfig struct {
|
|
drivers.Config
|
|
databaseName string
|
|
schemaName string
|
|
closeDBonClose bool
|
|
}
|
|
|
|
type postgres struct {
|
|
conn *sql.Conn
|
|
db *sql.DB
|
|
config *driverConfig
|
|
}
|
|
|
|
func WithInstance(dbInstance *sql.DB) (drivers.Driver, error) {
|
|
conn, err := dbInstance.Conn(context.Background())
|
|
if err != nil {
|
|
return nil, &drivers.DatabaseError{Driver: driverName, Command: "grabbing_connection", OrigErr: err, Message: "failed to grab connection to the database"}
|
|
}
|
|
|
|
driverConfig := getDefaultConfig()
|
|
if driverConfig.databaseName, err = currentDatabaseNameFromDB(conn, driverConfig); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if driverConfig.schemaName, err = currentSchema(conn, driverConfig); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &postgres{
|
|
conn: conn,
|
|
db: dbInstance,
|
|
config: driverConfig,
|
|
}, nil
|
|
}
|
|
|
|
func Open(connURL string) (drivers.Driver, error) {
|
|
customParams, err := drivers.ExtractCustomParams(connURL, configParams)
|
|
if err != nil {
|
|
return nil, &drivers.AppError{Driver: driverName, OrigErr: err, Message: "failed to parse custom parameters from url"}
|
|
}
|
|
|
|
sanitizedConnURL, err := drivers.RemoveParamsFromURL(connURL, configParams)
|
|
if err != nil {
|
|
return nil, &drivers.AppError{Driver: driverName, OrigErr: err, Message: "failed to sanitize url from custom parameters"}
|
|
}
|
|
|
|
driverConfig, err := mergeConfigWithParams(customParams, getDefaultConfig())
|
|
if err != nil {
|
|
return nil, &drivers.AppError{Driver: driverName, OrigErr: err, Message: "failed to merge custom params to driver config"}
|
|
}
|
|
|
|
db, err := sql.Open(driverName, sanitizedConnURL)
|
|
if err != nil {
|
|
return nil, &drivers.DatabaseError{Driver: driverName, Command: "opening_connection", OrigErr: err, Message: "failed to open connection with the database"}
|
|
}
|
|
|
|
conn, err := db.Conn(context.Background())
|
|
if err != nil {
|
|
return nil, &drivers.DatabaseError{Driver: driverName, Command: "grabbing_connection", OrigErr: err, Message: "failed to grab connection to the database"}
|
|
}
|
|
|
|
if driverConfig.databaseName, err = extractDatabaseNameFromURL(connURL); err != nil {
|
|
return nil, &drivers.AppError{Driver: driverName, OrigErr: err, Message: "failed to extract database name from connection url"}
|
|
}
|
|
|
|
if driverConfig.schemaName, err = currentSchema(conn, driverConfig); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
driverConfig.closeDBonClose = true
|
|
|
|
return &postgres{
|
|
db: db,
|
|
config: driverConfig,
|
|
conn: conn,
|
|
}, nil
|
|
}
|
|
|
|
func currentSchema(conn *sql.Conn, config *driverConfig) (string, error) {
|
|
query := "SELECT CURRENT_SCHEMA()"
|
|
|
|
ctx, cancel := drivers.GetContext(config.StatementTimeoutInSecs)
|
|
defer cancel()
|
|
|
|
var schemaName string
|
|
if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil {
|
|
return "", &drivers.DatabaseError{
|
|
OrigErr: err,
|
|
Driver: driverName,
|
|
Message: "failed to fetch current schema",
|
|
Command: "current_schema",
|
|
Query: []byte(query),
|
|
}
|
|
}
|
|
return schemaName, nil
|
|
}
|
|
|
|
func mergeConfigWithParams(params map[string]string, config *driverConfig) (*driverConfig, error) {
|
|
var err error
|
|
|
|
for _, configKey := range configParams {
|
|
if v, ok := params[configKey]; ok {
|
|
switch configKey {
|
|
case "x-migration-max-size":
|
|
if config.MigrationMaxSize, err = strconv.Atoi(v); err != nil {
|
|
return nil, errors.New(fmt.Sprintf("failed to cast config param %s of %s", configKey, v))
|
|
}
|
|
case "x-migrations-table":
|
|
config.MigrationsTable = v
|
|
case "x-statement-timeout":
|
|
if config.StatementTimeoutInSecs, err = strconv.Atoi(v); err != nil {
|
|
return nil, errors.New(fmt.Sprintf("failed to cast config param %s of %s", configKey, v))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return config, nil
|
|
}
|
|
|
|
func (pg *postgres) Ping() error {
|
|
ctx, cancel := drivers.GetContext(pg.config.StatementTimeoutInSecs)
|
|
defer cancel()
|
|
|
|
return pg.conn.PingContext(ctx)
|
|
}
|
|
|
|
func (pg *postgres) createSchemaTableIfNotExists() (err error) {
|
|
ctx, cancel := drivers.GetContext(pg.config.StatementTimeoutInSecs)
|
|
defer cancel()
|
|
|
|
createTableIfNotExistsQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint not null primary key, name varchar not null)", pg.config.MigrationsTable)
|
|
if _, err = pg.conn.ExecContext(ctx, createTableIfNotExistsQuery); err != nil {
|
|
return &drivers.DatabaseError{
|
|
OrigErr: err,
|
|
Driver: driverName,
|
|
Message: "failed while executing query",
|
|
Command: "create_migrations_table_if_not_exists",
|
|
Query: []byte(createTableIfNotExistsQuery),
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (postgres) DriverName() string {
|
|
return driverName
|
|
}
|
|
|
|
func (pg *postgres) Close() error {
|
|
if pg.conn != nil {
|
|
if err := pg.conn.Close(); err != nil {
|
|
return &drivers.DatabaseError{
|
|
OrigErr: err,
|
|
Driver: driverName,
|
|
Message: "failed to close database connection",
|
|
Command: "pg_conn_close",
|
|
Query: nil,
|
|
}
|
|
}
|
|
}
|
|
|
|
if pg.db != nil && pg.config.closeDBonClose {
|
|
if err := pg.db.Close(); err != nil {
|
|
return &drivers.DatabaseError{
|
|
OrigErr: err,
|
|
Driver: driverName,
|
|
Message: "failed to close database",
|
|
Command: "pg_db_close",
|
|
Query: nil,
|
|
}
|
|
}
|
|
pg.db = nil
|
|
}
|
|
|
|
pg.conn = nil
|
|
return nil
|
|
}
|
|
|
|
func (pg *postgres) Apply(migration *models.Migration, saveVersion bool) (err error) {
|
|
query := migration.Query()
|
|
|
|
ctx, cancel := drivers.GetContext(pg.config.StatementTimeoutInSecs)
|
|
defer cancel()
|
|
|
|
nonTransactional := strings.HasPrefix(query, "-- "+nonTransactionalPrefix)
|
|
// We wrap with a transaction only when there is no non-transactional prefix.
|
|
if !nonTransactional {
|
|
transaction, err := pg.conn.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return &drivers.DatabaseError{
|
|
OrigErr: err,
|
|
Driver: driverName,
|
|
Message: "error while opening a transaction to the database",
|
|
Command: "begin_transaction",
|
|
}
|
|
}
|
|
|
|
if err = executeQuery(ctx, transaction, query); err != nil {
|
|
return err
|
|
}
|
|
|
|
if saveVersion {
|
|
if err = executeQuery(ctx, transaction, pg.addMigrationQuery(migration)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
err = transaction.Commit()
|
|
if err != nil {
|
|
return &drivers.DatabaseError{
|
|
OrigErr: err,
|
|
Driver: driverName,
|
|
Message: "error while committing a transaction to the database",
|
|
Command: "commit_transaction",
|
|
}
|
|
}
|
|
} else {
|
|
_, err := pg.conn.ExecContext(ctx, query)
|
|
if err != nil {
|
|
return &drivers.DatabaseError{
|
|
OrigErr: err,
|
|
Driver: driverName,
|
|
Message: "failed to execute migration",
|
|
Command: "executing_query",
|
|
Query: []byte(query),
|
|
}
|
|
}
|
|
|
|
if saveVersion {
|
|
_, err = pg.conn.ExecContext(ctx, pg.addMigrationQuery(migration))
|
|
if err != nil {
|
|
return &drivers.DatabaseError{
|
|
OrigErr: err,
|
|
Driver: driverName,
|
|
Message: "failed to save version",
|
|
Command: "executing_query",
|
|
Query: []byte(query),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (pg *postgres) AppliedMigrations() (migrations []*models.Migration, err error) {
|
|
if pg.conn == nil {
|
|
return nil, &drivers.AppError{
|
|
OrigErr: errors.New("driver has no connection established"),
|
|
Message: "database connection is missing",
|
|
Driver: driverName,
|
|
}
|
|
}
|
|
|
|
if err := pg.createSchemaTableIfNotExists(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
query := fmt.Sprintf("SELECT version, name FROM %s", pg.config.MigrationsTable)
|
|
ctx, cancel := drivers.GetContext(pg.config.StatementTimeoutInSecs)
|
|
defer cancel()
|
|
var appliedMigrations []*models.Migration
|
|
var version uint32
|
|
var name string
|
|
|
|
rows, err := pg.conn.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, &drivers.DatabaseError{
|
|
OrigErr: err,
|
|
Driver: driverName,
|
|
Message: "failed to fetch applied migrations",
|
|
Command: "select_applied_migrations",
|
|
Query: []byte(query),
|
|
}
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
if err := rows.Scan(&version, &name); err != nil {
|
|
return nil, &drivers.DatabaseError{
|
|
OrigErr: err,
|
|
Driver: driverName,
|
|
Message: "failed to scan applied migration row",
|
|
Command: "scan_applied_migrations",
|
|
}
|
|
}
|
|
|
|
appliedMigrations = append(appliedMigrations, &models.Migration{
|
|
Name: name,
|
|
Version: version,
|
|
Direction: models.Up,
|
|
})
|
|
}
|
|
|
|
return appliedMigrations, nil
|
|
}
|
|
|
|
func (pg *postgres) addMigrationQuery(migration *models.Migration) string {
|
|
if migration.Direction == models.Down {
|
|
return fmt.Sprintf("DELETE FROM %s WHERE (Version=%d AND NAME='%s')", pg.config.MigrationsTable, migration.Version, migration.Name)
|
|
}
|
|
return fmt.Sprintf("INSERT INTO %s (version, name) VALUES (%d, '%s')", pg.config.MigrationsTable, migration.Version, migration.Name)
|
|
}
|
|
|
|
func executeQuery(ctx context.Context, transaction *sql.Tx, query string) error {
|
|
if _, err := transaction.ExecContext(ctx, query); err != nil {
|
|
if txErr := transaction.Rollback(); txErr != nil {
|
|
err = errors.Wrap(errors.New(err.Error()+txErr.Error()), "failed to execute query in migration transaction")
|
|
|
|
return &drivers.DatabaseError{
|
|
OrigErr: err,
|
|
Driver: driverName,
|
|
Command: "rollback_transaction",
|
|
}
|
|
}
|
|
|
|
return &drivers.DatabaseError{
|
|
OrigErr: err,
|
|
Driver: driverName,
|
|
Message: "failed to execute migration",
|
|
Command: "executing_query",
|
|
Query: []byte(query),
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func currentDatabaseNameFromDB(conn *sql.Conn, config *driverConfig) (string, error) {
|
|
query := "SELECT CURRENT_DATABASE()"
|
|
|
|
ctx, cancel := drivers.GetContext(config.StatementTimeoutInSecs)
|
|
defer cancel()
|
|
|
|
var databaseName string
|
|
if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
|
|
return "", &drivers.DatabaseError{
|
|
OrigErr: err,
|
|
Driver: driverName,
|
|
Message: "failed to fetch database name",
|
|
Command: "current_database",
|
|
Query: []byte(query),
|
|
}
|
|
}
|
|
return databaseName, nil
|
|
}
|
|
|
|
func (pg *postgres) SetConfig(key string, value interface{}) error {
|
|
if pg.config != nil {
|
|
switch key {
|
|
case "StatementTimeoutInSecs":
|
|
n, ok := value.(int)
|
|
if ok {
|
|
pg.config.StatementTimeoutInSecs = n
|
|
return nil
|
|
}
|
|
return fmt.Errorf("incorrect value type for %s", key)
|
|
case "MigrationsTable":
|
|
n, ok := value.(string)
|
|
if ok {
|
|
pg.config.MigrationsTable = n
|
|
return nil
|
|
}
|
|
return fmt.Errorf("incorrect value type for %s", key)
|
|
}
|
|
}
|
|
|
|
return fmt.Errorf("incorrect key name %q", key)
|
|
}
|