263 lines
6.2 KiB
Go
263 lines
6.2 KiB
Go
package mysql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
ms "github.com/go-sql-driver/mysql"
|
|
"github.com/mattermost/morph/drivers"
|
|
)
|
|
|
|
// Mutex is similar to sync.Mutex, except usable by morph to lock the db.
|
|
//
|
|
// Pick a unique name for each mutex your plugin requires.
|
|
//
|
|
// A Mutex must not be copied after first use.
|
|
type Mutex struct {
|
|
noCopy // nolint:unused
|
|
key string
|
|
|
|
// lock guards the variables used to manage the refresh task, and is not itself related to
|
|
// the db lock.
|
|
lock sync.Mutex
|
|
stopRefresh chan bool
|
|
refreshDone chan bool
|
|
conn *sql.Conn
|
|
|
|
logger drivers.Logger
|
|
}
|
|
|
|
// NewMutex creates a mutex with the given key name.
|
|
//
|
|
// returns error if key is empty.
|
|
func NewMutex(key string, driver drivers.Driver, logger drivers.Logger) (*Mutex, error) {
|
|
key, err := drivers.MakeLockKey(key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), drivers.TTL)
|
|
defer cancel()
|
|
|
|
ms, ok := driver.(*mysql)
|
|
if !ok {
|
|
return nil, errors.New("incorrect implementation of the driver")
|
|
}
|
|
|
|
conn, err := ms.db.Conn(context.Background())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
createTableIfNotExistsQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (Id varchar(64) NOT NULL, ExpireAt bigint(20) NOT NULL, PRIMARY KEY (Id)) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4", drivers.MutexTableName)
|
|
if _, err = conn.ExecContext(ctx, createTableIfNotExistsQuery); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Mutex{
|
|
key: key,
|
|
conn: conn,
|
|
logger: logger,
|
|
}, nil
|
|
}
|
|
|
|
// lock makes a single attempt to lock the mutex, returning true only if successful.
|
|
func (m *Mutex) tryLock(ctx context.Context) (bool, error) {
|
|
now := time.Now()
|
|
tx, err := m.conn.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
defer m.finalizeTx(tx)
|
|
|
|
query := fmt.Sprintf("INSERT INTO %s (Id, ExpireAt) VALUES (?, ?)", drivers.MutexTableName)
|
|
if _, err := tx.Exec(query, m.key, now.Add(drivers.TTL).Unix()); err != nil {
|
|
if mysqlErr, ok := err.(*ms.MySQLError); ok && mysqlErr.Number == 1062 {
|
|
m.logger.Println("DB is locked, going to try acquire the lock if it is expired.")
|
|
}
|
|
m.finalizeTx(tx)
|
|
|
|
err2 := m.releaseLock(ctx, now)
|
|
if err2 == nil { // lock has been released due to expiration
|
|
return true, nil
|
|
}
|
|
|
|
return false, fmt.Errorf("failed to lock mutex: %w", err)
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func (m *Mutex) releaseLock(ctx context.Context, t time.Time) error {
|
|
tx, err := m.conn.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer m.finalizeTx(tx)
|
|
|
|
e, err := m.getExpireAt(tx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if t.Unix() < e {
|
|
return errors.New("could not release the lock")
|
|
}
|
|
|
|
query := fmt.Sprintf("UPDATE %s SET ExpireAt = ? WHERE Id = ?", drivers.MutexTableName)
|
|
if err = executeTx(tx, query, t.Add(drivers.TTL).Unix(), m.key); err != nil {
|
|
return err
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return fmt.Errorf("unable to set new expireat for mutex: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Mutex) getExpireAt(tx *sql.Tx) (int64, error) {
|
|
var expireAt int64
|
|
query := fmt.Sprintf("SELECT ExpireAt FROM %s WHERE Id = ?", drivers.MutexTableName)
|
|
err := tx.QueryRow(query, m.key).Scan(&expireAt)
|
|
if err != nil {
|
|
return -1, fmt.Errorf("failed to fetch mutex from db: %w", err)
|
|
}
|
|
|
|
return expireAt, nil
|
|
}
|
|
|
|
// refreshLock rewrites the lock key value with a new expiry, returning nil only if successful.
|
|
func (m *Mutex) refreshLock(ctx context.Context) error {
|
|
tx, err := m.conn.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer m.finalizeTx(tx)
|
|
|
|
e, err := m.getExpireAt(tx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
tmp := time.Unix(e, 0)
|
|
query := fmt.Sprintf("UPDATE %s SET ExpireAt = ? WHERE Id = ?", drivers.MutexTableName)
|
|
if err = executeTx(tx, query, tmp.Add(drivers.TTL).Unix(), m.key); err != nil {
|
|
return err
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return fmt.Errorf("unable to refresh expireat for mutex: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Lock locks m unless the context is canceled. If the mutex is already locked by any other
|
|
// instance, including the current one, the calling goroutine blocks until the mutex can be locked,
|
|
// or the context is canceled.
|
|
//
|
|
// The mutex is locked only if a nil error is returned.
|
|
func (m *Mutex) Lock(ctx context.Context) error {
|
|
var waitInterval time.Duration
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-time.After(waitInterval):
|
|
}
|
|
|
|
ok, err := m.tryLock(ctx)
|
|
if err != nil || !ok {
|
|
waitInterval = drivers.NextWaitInterval(waitInterval, err)
|
|
continue
|
|
}
|
|
|
|
break
|
|
}
|
|
|
|
stop := make(chan bool)
|
|
done := make(chan bool)
|
|
go func() {
|
|
defer close(done)
|
|
t := time.NewTicker(drivers.RefreshInterval)
|
|
for {
|
|
select {
|
|
case <-t.C:
|
|
err := m.refreshLock(ctx)
|
|
if err != nil {
|
|
return
|
|
}
|
|
case <-stop:
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
m.lock.Lock()
|
|
m.stopRefresh = stop
|
|
m.refreshDone = done
|
|
m.lock.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Unlock unlocks m. It is a run-time error if m is not locked on entry to Unlock.
|
|
//
|
|
// Just like sync.Mutex, a locked Lock is not associated with a particular goroutine or a process.
|
|
func (m *Mutex) Unlock() error {
|
|
m.lock.Lock()
|
|
if m.stopRefresh == nil {
|
|
m.lock.Unlock()
|
|
panic("mutex has not been acquired")
|
|
}
|
|
|
|
close(m.stopRefresh)
|
|
m.stopRefresh = nil
|
|
<-m.refreshDone
|
|
m.lock.Unlock()
|
|
|
|
defer m.conn.Close()
|
|
|
|
// If an error occurs deleting, the mutex will still expire, allowing later retry.
|
|
query := fmt.Sprintf("DELETE FROM %s WHERE Id = ?", drivers.MutexTableName)
|
|
_, err := m.conn.ExecContext(context.Background(), query, m.key)
|
|
return err
|
|
}
|
|
|
|
func executeTx(tx *sql.Tx, query string, args ...interface{}) error {
|
|
if _, err := tx.Exec(query, args...); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Mutex) finalizeTx(tx *sql.Tx) {
|
|
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
|
m.logger.Printf("failed to rollback transaction: %s", err)
|
|
}
|
|
}
|
|
|
|
// noCopy may be embedded into structs which must not be copied
|
|
// after the first use.
|
|
//
|
|
// See https://golang.org/issues/8005#issuecomment-190753527
|
|
// for details.
|
|
type noCopy struct{} // nolint:unused
|
|
|
|
// Lock is a no-op used by -copylocks checker from `go vet`.
|
|
func (*noCopy) Lock() {} // nolint:unused
|