mattermost-community-enterp.../vendor/github.com/mattermost/morph/test_helper.go

206 lines
5.9 KiB
Go

package morph
import (
"bytes"
"context"
"database/sql"
"fmt"
"os"
"path/filepath"
"testing"
"text/template"
"time"
"github.com/mattermost/morph/drivers"
"github.com/mattermost/morph/drivers/mysql"
"github.com/mattermost/morph/drivers/postgres"
"github.com/mattermost/morph/drivers/sqlite"
"github.com/mattermost/morph/models"
"github.com/mattermost/morph/sources"
"github.com/mattermost/morph/testlib"
"github.com/stretchr/testify/require"
)
// Default data source names for the database servers the test helper opens
// connections to.
const (
// defaultPostgresDSN is the DSN handed to sql.Open for the "postgres" driver.
defaultPostgresDSN = "postgres://morph:morph@localhost:6432/morph_test?sslmode=disable"
// defaultMySQLDSN is the DSN handed to sql.Open for the "mysql" driver.
defaultMySQLDSN = "morph:morph@tcp(127.0.0.1:3307)/morph_test?multiStatements=true"
)
// queries maps a driver name to the dummy SQL statement used for each
// migration direction. Every statement is a text/template whose {{.Name}}
// placeholder is replaced with the table name when the migration is rendered.
var queries = map[string]map[models.Direction]string{
"postgres": {
models.Up: `CREATE TABLE IF NOT EXISTS {{.Name}} (id serial PRIMARY KEY, name text)`,
models.Down: `DROP TABLE IF EXISTS {{.Name}}`,
},
"mysql": {
models.Up: `CREATE TABLE IF NOT EXISTS {{.Name}} (id int(11) NOT NULL AUTO_INCREMENT, name varchar(255), PRIMARY KEY (id))`,
models.Down: `DROP TABLE IF EXISTS {{.Name}}`,
},
"sqlite": {
models.Up: `CREATE TABLE IF NOT EXISTS {{.Name}} (id integer PRIMARY KEY AUTOINCREMENT, name text)`,
models.Down: `DROP TABLE IF EXISTS {{.Name}}`,
},
}
// testHelper is a helper struct for testing morph engine.
// It contains all the necessary information to run tests for all drivers.
// It also provides helper functions to create dummy migrations.
type testHelper struct {
// drivers holds one initialized driver per driver name
// ("postgres", "mysql", "sqlite").
drivers map[string]drivers.Driver
// dbInstances holds the *sql.DB behind each driver, keyed by driver name,
// so the connections can be closed during teardown.
dbInstances map[string]*sql.DB
// sqliteFile is the path of the temporary SQLite database file; it is
// removed during teardown.
sqliteFile string
// options are the engine options passed to New for every driver.
options []EngineOption
// migrations holds the dummy migrations registered for each driver name.
migrations map[string][]*models.Migration
}
// testSource is a dummy source for testing purposes.
// It serves a fixed, in-memory list of migrations.
type testSource struct {
// migrations is the list returned as-is by Migrations.
migrations []*models.Migration
}
// Migrations returns the migrations stored in the source, unchanged.
func (src *testSource) Migrations() []*models.Migration {
	return src.migrations
}
// source wraps the migrations registered for driverName in a testSource so
// they can be handed to the engine as a sources.Source.
func (h *testHelper) source(driverName string) sources.Source {
	return &testSource{migrations: h.migrations[driverName]}
}
// newTestHelper builds a testHelper that carries the given engine options and
// starts with empty driver, migration and database-instance maps, then
// populates the drivers by calling initializeDrivers.
func newTestHelper(t *testing.T, options ...EngineOption) *testHelper {
	h := new(testHelper)
	h.options = options
	h.drivers = make(map[string]drivers.Driver)
	h.migrations = make(map[string][]*models.Migration)
	h.dbInstances = make(map[string]*sql.DB)

	h.initializeDrivers(t)

	return h
}
// CreateBasicMigrations registers three dummy migrations ("create_table_1",
// "create_table_2" and "create_table_3") through AddMigration and returns the
// helper so calls can be chained.
func (h *testHelper) CreateBasicMigrations(t *testing.T) *testHelper {
	basic := []string{"create_table_1", "create_table_2", "create_table_3"}
	for _, migrationName := range basic {
		h.AddMigration(t, migrationName)
	}
	return h
}
// AddMigration registers a dummy migration named migrationName for every
// initialized driver. Each call appends two entries per driver: the Up
// migration followed by the matching Down migration, both sharing one version.
//
// The version is one more than the number of entries already registered for
// the driver, so consecutive calls produce versions 1, 3, 5, ...
//
// Migrations should be added before RunForAllDrivers is called, because the
// engine receives the registered migrations when it is created there.
func (h *testHelper) AddMigration(t *testing.T, migrationName string) {
	// The table name combines the migration name with the current Unix
	// timestamp; the same table is used for every driver.
	table := fmt.Sprintf("test_%s_%d", migrationName, time.Now().Unix())

	for driverName := range h.drivers {
		version := uint32(len(h.migrations[driverName])) + 1

		// register appends a single migration entry for the current driver.
		register := func(direction models.Direction, rawName string) {
			h.migrations[driverName] = append(h.migrations[driverName], &models.Migration{
				Name:      migrationName,
				Direction: direction,
				Version:   version,
				Bytes:     getMigration(t, driverName, direction, table),
				RawName:   rawName,
			})
		}

		register(models.Up, fmt.Sprintf("%d_%s.up.sql", version, migrationName))
		register(models.Down, fmt.Sprintf("%d_%s.down.sql", version, migrationName))
	}
}
// getMigration renders the dummy query registered in queries for the given
// driver and direction, substituting tableName for the template's {{.Name}}
// placeholder, and returns the resulting SQL bytes. The test is failed
// immediately if the template cannot be parsed or executed.
func getMigration(t *testing.T, driver string, direction models.Direction, tableName string) []byte {
	text := queries[driver][direction]

	tpl, err := template.New("query").Parse(text)
	require.NoError(t, err)

	data := struct{ Name string }{Name: tableName}
	rendered := new(bytes.Buffer)
	err = tpl.Execute(rendered, data)
	require.NoError(t, err)

	return rendered.Bytes()
}
// RunForAllDrivers runs f as a subtest once per initialized driver. For each
// driver a new engine is created from the driver, the migrations registered
// for it and the helper's engine options; the subtest fails immediately if
// the engine cannot be created.
//
// Subtests are named after the driver. When name is supplied, only its first
// element is used, as a "<name>/" prefix to the driver name.
func (h *testHelper) RunForAllDrivers(t *testing.T, f func(*testing.T, *Morph), name ...string) {
	prefix := ""
	if len(name) > 0 {
		prefix = name[0] + "/"
	}

	for driverName, driver := range h.drivers {
		subtest := func(t *testing.T) {
			engine, err := New(context.Background(), driver, h.source(driverName), h.options...)
			require.NoError(t, err)
			f(t, engine)
		}
		t.Run(prefix+driverName, subtest)
	}
}
// Teardown applies each driver's "<driver>_drop_all_tables.sql" script from
// the testlib assets, closes every database connection and removes the
// SQLite database file.
//
// Reading or applying a drop script fails the test immediately. The
// connection and file cleanup is deferred so it still runs in that case
// (FailNow exits via runtime.Goexit, which executes deferred calls). Cleanup
// errors are reported with t.Errorf so that one failure does not prevent the
// remaining resources from being released.
func (h *testHelper) Teardown(t *testing.T) {
	// Deferred calls run in reverse order: connections are closed first,
	// then the SQLite file is removed.
	defer func() {
		if err := os.RemoveAll(h.sqliteFile); err != nil {
			t.Errorf("removing sqlite file %q: %v", h.sqliteFile, err)
		}
	}()
	defer func() {
		for name, instance := range h.dbInstances {
			if err := instance.Close(); err != nil {
				t.Errorf("closing %s database: %v", name, err)
			}
		}
	}()

	assets := testlib.Assets()
	for name, driver := range h.drivers {
		script, err := assets.ReadFile(filepath.Join("scripts", name+"_drop_all_tables.sql"))
		require.NoError(t, err)

		// Only the script bytes are set on the migration handed to Apply.
		err = driver.Apply(&models.Migration{Bytes: script}, false)
		require.NoError(t, err)
	}
}
// initializeDrivers opens a database handle for postgres, mysql and sqlite,
// wraps each one in its morph driver and stores both the driver and the
// handle on the helper under the driver's name. Any error fails the test
// immediately.
//
// postgres and mysql use the default DSNs; sqlite uses a freshly created
// temporary file whose path is recorded in h.sqliteFile.
func (h *testHelper) initializeDrivers(t *testing.T) {
	// postgres
	pgDB, err := sql.Open("postgres", defaultPostgresDSN)
	require.NoError(t, err)

	pgDriver, err := postgres.WithInstance(pgDB)
	require.NoError(t, err)

	h.drivers["postgres"] = pgDriver
	h.dbInstances["postgres"] = pgDB

	// mysql
	mysqlDB, err := sql.Open("mysql", defaultMySQLDSN)
	require.NoError(t, err)

	mysqlDriver, err := mysql.WithInstance(mysqlDB)
	require.NoError(t, err)

	h.drivers["mysql"] = mysqlDriver
	h.dbInstances["mysql"] = mysqlDB

	// sqlite
	testDBFile, err := os.CreateTemp("", "morph-test.db")
	require.NoError(t, err)

	// With an empty dir argument CreateTemp creates the file in os.TempDir(),
	// and Name() reports the path it was opened with.
	h.sqliteFile = testDBFile.Name()

	// Only the path is needed from here on; close the handle so it does not
	// leak for the rest of the test run.
	require.NoError(t, testDBFile.Close())

	sqliteDB, err := sql.Open("sqlite", h.sqliteFile)
	require.NoError(t, err)

	sqliteDriver, err := sqlite.WithInstance(sqliteDB)
	require.NoError(t, err)

	h.drivers["sqlite"] = sqliteDriver
	h.dbInstances["sqlite"] = sqliteDB
}