package morph import ( "context" "errors" "fmt" "log" "os" "sort" "strings" "sync" "time" "github.com/mattermost/morph/models" "github.com/mattermost/morph/drivers" "github.com/mattermost/morph/sources" ms "github.com/mattermost/morph/drivers/mysql" ps "github.com/mattermost/morph/drivers/postgres" _ "github.com/mattermost/morph/sources/embedded" _ "github.com/mattermost/morph/sources/file" ) var ( migrationProgressStart = "== %s: migrating (%s) =============================================" migrationProgressFinished = "== %s: migrated (%s) ========================================" migrationInterceptor = "== %s: running pre-migration function ==================================" ) const maxProgressLogLength = 100 type Morph struct { config *Config driver drivers.Driver source sources.Source mutex drivers.Locker interceptorLock sync.Mutex intercecptorsUp map[int]Interceptor intercecptorsDown map[int]Interceptor } type Config struct { Logger Logger LockKey string DryRun bool } type EngineOption func(*Morph) error // Interceptor is a handler function that being called just before the migration // applied. If the interceptor returns an error, migration will be aborted. type Interceptor func() error func WithLogger(logger Logger) EngineOption { return func(m *Morph) error { m.config.Logger = logger return nil } } func SetMigrationTableName(name string) EngineOption { return func(m *Morph) error { return m.driver.SetConfig("MigrationsTable", name) } } func SetStatementTimeoutInSeconds(n int) EngineOption { return func(m *Morph) error { return m.driver.SetConfig("StatementTimeoutInSecs", n) } } // WithLock creates a lock table in the database so that the migrations are // guaranteed to be executed from a single instance. The key is used for naming // the mutex. func WithLock(key string) EngineOption { return func(m *Morph) error { m.config.LockKey = key return nil } } // SetDryRun will not execute any migrations if set to true, but // will still log the migrations that would be executed. func SetDryRun(enable bool) EngineOption { return func(m *Morph) error { m.config.DryRun = enable return nil } } // New creates a new instance of the migrations engine from an existing db instance and a migrations source. // If the driver implements the Lockable interface, it will also wait until it has acquired a lock. // The context is propagated to the drivers lock method (if the driver implements divers.Locker interface) and // it can be used to cancel the lock acquisition. func New(ctx context.Context, driver drivers.Driver, source sources.Source, options ...EngineOption) (*Morph, error) { engine := &Morph{ config: &Config{ Logger: newColorLogger(log.New(os.Stderr, "", log.LstdFlags)), // add default logger }, source: source, driver: driver, intercecptorsUp: make(map[int]Interceptor), intercecptorsDown: make(map[int]Interceptor), } for _, option := range options { if err := option(engine); err != nil { return nil, fmt.Errorf("could not apply option: %w", err) } } if err := driver.Ping(); err != nil { return nil, err } if impl, ok := driver.(drivers.Lockable); ok && engine.config.LockKey != "" { var mx drivers.Locker var err error switch impl.DriverName() { case "mysql": mx, err = ms.NewMutex(engine.config.LockKey, driver, engine.config.Logger) case "postgres": mx, err = ps.NewMutex(engine.config.LockKey, driver, engine.config.Logger) default: err = errors.New("driver does not support locking") } if err != nil { return nil, err } engine.mutex = mx err = mx.Lock(ctx) if err != nil { return nil, err } } return engine, nil } // Close closes the underlying database connection of the engine. func (m *Morph) Close() error { if m.mutex != nil { err := m.mutex.Unlock() if err != nil { return err } } return m.driver.Close() } func (m *Morph) apply(migration *models.Migration, saveVersion, dryRun bool) error { start := time.Now() migrationName := migration.Name direction := migration.Direction f := m.getInterceptor(migration) if f != nil { m.config.Logger.Println(formatProgress(fmt.Sprintf(migrationInterceptor, migrationName))) err := f() if err != nil { return err } } m.config.Logger.Println(formatProgress(fmt.Sprintf(migrationProgressStart, migrationName, direction))) if !dryRun { if err := m.driver.Apply(migration, saveVersion); err != nil { return err } } elapsed := time.Since(start) m.config.Logger.Println(formatProgress(fmt.Sprintf(migrationProgressFinished, migrationName, fmt.Sprintf("%.4fs", elapsed.Seconds())))) return nil } // ApplyAll applies all pending migrations. func (m *Morph) ApplyAll() error { _, err := m.Apply(-1) return err } // Applies limited number of migrations upwards. func (m *Morph) Apply(limit int) (int, error) { appliedMigrations, err := m.driver.AppliedMigrations() if err != nil { return -1, err } pendingMigrations, err := computePendingMigrations(appliedMigrations, m.source.Migrations()) if err != nil { return -1, err } migrations := make([]*models.Migration, 0) sortedMigrations := sortMigrations(pendingMigrations) for _, migration := range sortedMigrations { if migration.Direction != models.Up { continue } migrations = append(migrations, migration) } steps := limit if len(migrations) < steps { return -1, fmt.Errorf("there are only %d migrations available, but you requested %d", len(migrations), steps) } if limit < 0 { steps = len(migrations) } var applied int for i := 0; i < steps; i++ { if err := m.apply(migrations[i], true, m.config.DryRun); err != nil { return applied, err } applied++ } return applied, nil } // ApplyDown rollbacks a limited number of migrations // if limit is given below zero, all down scripts are going to be applied. func (m *Morph) ApplyDown(limit int) (int, error) { appliedMigrations, err := m.driver.AppliedMigrations() if err != nil { return -1, err } sortedMigrations := reverseSortMigrations(appliedMigrations) downMigrations, err := findDownScripts(sortedMigrations, m.source.Migrations()) if err != nil { return -1, err } steps := limit if len(sortedMigrations) < steps { return -1, fmt.Errorf("there are only %d migrations available, but you requested %d", len(sortedMigrations), steps) } if limit < 0 { steps = len(sortedMigrations) } var applied int for i := 0; i < steps; i++ { migrationName := sortedMigrations[i].Name if err := m.apply(downMigrations[migrationName], true, m.config.DryRun); err != nil { return applied, err } applied++ } return applied, nil } // Diff returns the difference between the applied migrations and the available migrations. func (m *Morph) Diff(mode models.Direction) ([]*models.Migration, error) { appliedMigrations, err := m.driver.AppliedMigrations() if err != nil { return nil, err } if mode == models.Down { sortedMigrations := reverseSortMigrations(appliedMigrations) downMigrations, err := findDownScripts(sortedMigrations, m.source.Migrations()) if err != nil { return nil, err } diff := make([]*models.Migration, 0, len(downMigrations)) for i := 0; i < len(sortedMigrations); i++ { diff = append(diff, downMigrations[sortedMigrations[i].Name]) } return diff, nil } pendingMigrations, err := computePendingMigrations(appliedMigrations, m.source.Migrations()) if err != nil { return nil, err } var diff []*models.Migration for _, migration := range sortMigrations(pendingMigrations) { if migration.Direction != models.Up { continue } diff = append(diff, migration) } return diff, nil } func (m *Morph) GetOppositeMigrations(migrations []*models.Migration) ([]*models.Migration, error) { var direction models.Direction migrationsMap := make(map[string]*models.Migration) for _, migration := range migrations { if direction == "" { direction = migration.Direction } // check if the migrations has the same direction if direction != migration.Direction { return nil, errors.New("migrations have different directions") } migrationsMap[migration.Name] = migration } rollbackMigrations := make([]*models.Migration, 0, len(migrations)) availableMigrations := m.source.Migrations() for _, migration := range availableMigrations { // skip if we have the same direction for the migration // we are looking for opposite direction if migration.Direction == direction { continue } // we don't have the migration in the map // so we can't rollback it _, ok := migrationsMap[migration.Name] if !ok { continue } rollbackMigrations = append(rollbackMigrations, migration) } if len(migrations) != len(rollbackMigrations) { return nil, errors.New("not all migrations have opposite migrations") } return rollbackMigrations, nil } // GeneratePlan returns the plan to apply these migrations and also includes // the safe rollback steps for the given migrations. func (m *Morph) GeneratePlan(migrations []*models.Migration, auto bool) (*models.Plan, error) { rollbackMigrations, err := m.GetOppositeMigrations(migrations) if err != nil { return nil, fmt.Errorf("could not get opposite migrations: %w", err) } plan := models.NewPlan(migrations, rollbackMigrations, auto) return plan, nil } func (m *Morph) ApplyPlan(plan *models.Plan) error { if err := plan.Validate(); err != nil { return fmt.Errorf("invalid plan: %w", err) } revertMigrations := make([]*models.Migration, 0, len(plan.RevertMigrations)) var err error var failIndex int for i := range plan.Migrations { // add to the revert queue for _, migration := range plan.RevertMigrations { if migration.Name == plan.Migrations[i].Name && migration.Version == plan.Migrations[i].Version { revertMigrations = append(revertMigrations, migration) break } } err = m.apply(plan.Migrations[i], true, m.config.DryRun) if err != nil { break } failIndex = i } if err == nil { return nil } if !plan.Auto { return err } m.config.Logger.Printf("migration %s failed, starting rollback", plan.Migrations[failIndex].Name) for j := len(revertMigrations) - 1; j >= 0; j-- { // There is a special case when we are reverting a rollback // We shouldn't save the version if we are trying to restore the last applied migration // here is an example, lets say we have following migrations in the applied migrations table: // migration_1, migration_2, migration_3 // Once we initiate the rollback, we will have the following: // migration_3, migration_2, migration_1 (to rollback) // Let's say we have a bug in migration_2 and failed. // We don't remove that version from the database, because migration is not successfully rolled back. // So in this case, we need to apply the migration_2 (up) but it will be in the migrations table. // Therefore we are not saving the version in the database because it will fail on the save version step. skipSave := revertMigrations[j].Direction == models.Up && j == len(revertMigrations)-1 rErr := m.apply(revertMigrations[j], !skipSave, m.config.DryRun) if rErr != nil { return fmt.Errorf("could not rollback migrations after trying to migrate: %w", rErr) } m.config.Logger.Printf("successfully rolled back migration: %s", revertMigrations[j].Name) } // return error in any case return fmt.Errorf("could not apply migration: %w", err) } // AddInterceptor registers a handler function to be executed before the actual migration func (m *Morph) AddInterceptor(version int, direction models.Direction, handler Interceptor) { m.interceptorLock.Lock() switch direction { case models.Up: m.intercecptorsUp[version] = handler case models.Down: m.intercecptorsDown[version] = handler } m.interceptorLock.Unlock() } // RemoveInterceptor removes the handler function from the engine func (m *Morph) RemoveInterceptor(version int, direction models.Direction) { m.interceptorLock.Lock() switch direction { case models.Up: delete(m.intercecptorsUp, version) case models.Down: delete(m.intercecptorsDown, version) } m.interceptorLock.Unlock() } func (m *Morph) getInterceptor(migration *models.Migration) Interceptor { m.interceptorLock.Lock() var f Interceptor switch migration.Direction { case models.Up: fn, ok := m.intercecptorsUp[int(migration.Version)] if ok { f = fn } case models.Down: fn, ok := m.intercecptorsDown[int(migration.Version)] if ok { f = fn } } m.interceptorLock.Unlock() return f } // SwapPlanDirection alters the plan direction to the opposite direction. func SwapPlanDirection(plan *models.Plan) { // we need to ensure that the intended migrations for applying is in the // correct order. plan.RevertMigrations = sortMigrations(plan.RevertMigrations) if len(plan.RevertMigrations) > 0 && plan.RevertMigrations[0].Direction == models.Down { plan.RevertMigrations = reverseSortMigrations(plan.RevertMigrations) } // we copy the migrations to set them as revert migrations in the plan migrations := plan.Migrations plan.Migrations = plan.RevertMigrations plan.RevertMigrations = migrations } func reverseSortMigrations(migrations []*models.Migration) []*models.Migration { sort.Slice(migrations, func(i, j int) bool { return migrations[i].Version > migrations[j].Version }) return migrations } func sortMigrations(migrations []*models.Migration) []*models.Migration { sort.Slice(migrations, func(i, j int) bool { return migrations[i].RawName < migrations[j].RawName }) return migrations } func computePendingMigrations(appliedMigrations []*models.Migration, sourceMigrations []*models.Migration) ([]*models.Migration, error) { // sourceMigrations has to be greater or equal to databaseMigrations if len(appliedMigrations) > len(sourceMigrations) { return nil, errors.New("migration mismatch, there are more migrations applied than those were specified in source") } dict := make(map[string]*models.Migration) for _, appliedMigration := range appliedMigrations { dict[appliedMigration.Name] = appliedMigration } var pendingMigrations []*models.Migration for _, sourceMigration := range sourceMigrations { if _, ok := dict[sourceMigration.Name]; !ok { pendingMigrations = append(pendingMigrations, sourceMigration) } } return pendingMigrations, nil } func findDownScripts(appliedMigrations []*models.Migration, sourceMigrations []*models.Migration) (map[string]*models.Migration, error) { tmp := make(map[string]*models.Migration) for _, m := range sourceMigrations { if m.Direction != models.Down { continue } tmp[m.Name] = m } for _, m := range appliedMigrations { _, ok := tmp[m.Name] if !ok { return nil, fmt.Errorf("could not find down script for %s", m.Name) } } return tmp, nil } func formatProgress(p string) string { if len(p) < maxProgressLogLength { return p + strings.Repeat("=", maxProgressLogLength-len(p)) } if len(p) > maxProgressLogLength { return p[:maxProgressLogLength] } return p }