Full Mattermost server source with integrated Community Enterprise features. Includes vendor directory for offline/air-gapped builds. Structure: - enterprise-impl/: Enterprise feature implementations - enterprise-community/: Init files that register implementations - enterprise/: Bridge imports (community_imports.go) - vendor/: All dependencies for offline builds Build (online): go build ./cmd/mattermost Build (offline/air-gapped): go build -mod=vendor ./cmd/mattermost 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
412 lines
14 KiB
Go
412 lines
14 KiB
Go
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
|
|
// See LICENSE.txt for license information.
|
|
|
|
package storetest
|
|
|
|
import (
|
|
"sort"
|
|
"testing"
|
|
|
|
"github.com/mattermost/mattermost/server/public/model"
|
|
"github.com/mattermost/mattermost/server/public/shared/request"
|
|
"github.com/mattermost/mattermost/server/v8/channels/store"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func newValidOutgoingOAuthConnection() *model.OutgoingOAuthConnection {
|
|
return &model.OutgoingOAuthConnection{
|
|
CreatorId: model.NewId(),
|
|
Name: "Test Connection",
|
|
ClientId: model.NewId(),
|
|
ClientSecret: model.NewId(),
|
|
OAuthTokenURL: "https://nowhere.com/oauth/token",
|
|
GrantType: model.OutgoingOAuthConnectionGrantTypeClientCredentials,
|
|
Audiences: []string{"https://nowhere.com"},
|
|
}
|
|
}
|
|
|
|
func cleanupOutgoingOAuthConnections(t *testing.T, ss store.Store) func() {
|
|
return func() {
|
|
// Delete all outgoing connections
|
|
connections, err := ss.OutgoingOAuthConnection().GetConnections(request.TestContext(t), model.OutgoingOAuthConnectionGetConnectionsFilter{
|
|
Limit: 100,
|
|
})
|
|
require.NoError(t, err)
|
|
for _, conn := range connections {
|
|
err := ss.OutgoingOAuthConnection().DeleteConnection(request.TestContext(t), conn.Id)
|
|
require.NoError(t, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestOutgoingOAuthConnectionStore(t *testing.T, rctx request.CTX, ss store.Store) {
|
|
t.Run("SaveConnection", func(t *testing.T) {
|
|
t.Cleanup(cleanupOutgoingOAuthConnections(t, ss))
|
|
testSaveOutgoingOAuthConnection(t, ss)
|
|
})
|
|
t.Run("UpdateConnection", func(t *testing.T) {
|
|
t.Cleanup(cleanupOutgoingOAuthConnections(t, ss))
|
|
testUpdateOutgoingOAuthConnection(t, ss)
|
|
})
|
|
t.Run("GetConnection", func(t *testing.T) {
|
|
t.Cleanup(cleanupOutgoingOAuthConnections(t, ss))
|
|
testGetOutgoingOAuthConnection(t, ss)
|
|
})
|
|
t.Run("GetConnectionsByAudience", func(t *testing.T) {
|
|
t.Cleanup(cleanupOutgoingOAuthConnections(t, ss))
|
|
testGetOutgoingOAuthConnectionByAudience(t, ss)
|
|
})
|
|
t.Run("GetConnections", func(t *testing.T) {
|
|
t.Cleanup(cleanupOutgoingOAuthConnections(t, ss))
|
|
testGetOutgoingOAuthConnections(t, ss)
|
|
})
|
|
t.Run("DeleteConnection", func(t *testing.T) {
|
|
t.Cleanup(cleanupOutgoingOAuthConnections(t, ss))
|
|
testDeleteOutgoingOAuthConnection(t, ss)
|
|
})
|
|
}
|
|
|
|
func testSaveOutgoingOAuthConnection(t *testing.T, ss store.Store) {
|
|
c := request.TestContext(t)
|
|
|
|
t.Run("save/get", func(t *testing.T) {
|
|
// Define test data
|
|
connection := newValidOutgoingOAuthConnection()
|
|
|
|
// Save the connection
|
|
_, err := ss.OutgoingOAuthConnection().SaveConnection(c, connection)
|
|
require.NoError(t, err)
|
|
|
|
// Retrieve the connection
|
|
storeConn, err := ss.OutgoingOAuthConnection().GetConnection(c, connection.Id)
|
|
require.NoError(t, err)
|
|
require.Equal(t, connection, storeConn)
|
|
})
|
|
|
|
t.Run("save without id should fail", func(t *testing.T) {
|
|
connection := &model.OutgoingOAuthConnection{
|
|
Id: model.NewId(),
|
|
}
|
|
|
|
_, err := ss.OutgoingOAuthConnection().SaveConnection(c, connection)
|
|
require.Error(t, err)
|
|
})
|
|
|
|
t.Run("save with incorrect grant type should fail", func(t *testing.T) {
|
|
connection := newValidOutgoingOAuthConnection()
|
|
connection.GrantType = "incorrect"
|
|
|
|
_, err := ss.OutgoingOAuthConnection().SaveConnection(c, connection)
|
|
require.Error(t, err)
|
|
})
|
|
}
|
|
|
|
func testUpdateOutgoingOAuthConnection(t *testing.T, ss store.Store) {
|
|
c := request.TestContext(t)
|
|
|
|
t.Run("update/get", func(t *testing.T) {
|
|
// Define test data
|
|
connection := newValidOutgoingOAuthConnection()
|
|
|
|
// Save the connection
|
|
_, err := ss.OutgoingOAuthConnection().SaveConnection(c, connection)
|
|
require.NoError(t, err)
|
|
|
|
// Update the connection
|
|
connection.Name = "Updated Name"
|
|
_, err = ss.OutgoingOAuthConnection().UpdateConnection(c, connection)
|
|
require.NoError(t, err)
|
|
|
|
// Retrieve the connection
|
|
storeConn, err := ss.OutgoingOAuthConnection().GetConnection(c, connection.Id)
|
|
require.NoError(t, err)
|
|
require.Equal(t, connection, storeConn)
|
|
})
|
|
|
|
t.Run("update non-existing", func(t *testing.T) {
|
|
connection := newValidOutgoingOAuthConnection()
|
|
connection.Id = model.NewId()
|
|
|
|
_, err := ss.OutgoingOAuthConnection().UpdateConnection(c, connection)
|
|
require.Error(t, err)
|
|
})
|
|
|
|
t.Run("update without id should fail", func(t *testing.T) {
|
|
connection := &model.OutgoingOAuthConnection{
|
|
Id: model.NewId(),
|
|
}
|
|
|
|
_, err := ss.OutgoingOAuthConnection().UpdateConnection(c, connection)
|
|
require.Error(t, err)
|
|
})
|
|
|
|
t.Run("update should update all fields", func(t *testing.T) {
|
|
// Define test data
|
|
connection := newValidOutgoingOAuthConnection()
|
|
|
|
// Save the connection
|
|
_, err := ss.OutgoingOAuthConnection().SaveConnection(c, connection)
|
|
require.NoError(t, err)
|
|
|
|
// Update the connection
|
|
connection.Name = "Updated Name"
|
|
connection.ClientId = "Updated ClientId"
|
|
connection.ClientSecret = "Updated ClientSecret"
|
|
connection.OAuthTokenURL = "https://nowhere.com/updated"
|
|
// connection.GrantType = "client_credentials" // ignoring since we only allow one for now
|
|
connection.Audiences = []string{"https://nowhere.com/updated"}
|
|
_, err = ss.OutgoingOAuthConnection().UpdateConnection(c, connection)
|
|
require.NoError(t, err)
|
|
|
|
// Retrieve the connection
|
|
storeConn, err := ss.OutgoingOAuthConnection().GetConnection(c, connection.Id)
|
|
require.NoError(t, err)
|
|
require.Equal(t, connection, storeConn)
|
|
})
|
|
|
|
t.Run("patch", func(t *testing.T) {
|
|
t.Run("name", func(t *testing.T) {
|
|
connection := newValidOutgoingOAuthConnection()
|
|
_, err := ss.OutgoingOAuthConnection().SaveConnection(c, connection)
|
|
require.NoError(t, err)
|
|
|
|
connection.Name = "Updated Name"
|
|
|
|
updated, err := ss.OutgoingOAuthConnection().UpdateConnection(c, connection)
|
|
require.NoError(t, err)
|
|
require.Equal(t, connection, updated)
|
|
})
|
|
|
|
t.Run("client id", func(t *testing.T) {
|
|
connection := newValidOutgoingOAuthConnection()
|
|
_, err := ss.OutgoingOAuthConnection().SaveConnection(c, connection)
|
|
require.NoError(t, err)
|
|
|
|
connection.ClientId = "Updated ClientId"
|
|
|
|
updated, err := ss.OutgoingOAuthConnection().UpdateConnection(c, connection)
|
|
require.NoError(t, err)
|
|
require.Equal(t, connection, updated)
|
|
})
|
|
|
|
t.Run("client secret", func(t *testing.T) {
|
|
connection := newValidOutgoingOAuthConnection()
|
|
_, err := ss.OutgoingOAuthConnection().SaveConnection(c, connection)
|
|
require.NoError(t, err)
|
|
|
|
connection.ClientSecret = "Updated ClientSecret"
|
|
|
|
updated, err := ss.OutgoingOAuthConnection().UpdateConnection(c, connection)
|
|
require.NoError(t, err)
|
|
require.Equal(t, connection, updated)
|
|
})
|
|
|
|
t.Run("oauth token url", func(t *testing.T) {
|
|
connection := newValidOutgoingOAuthConnection()
|
|
_, err := ss.OutgoingOAuthConnection().SaveConnection(c, connection)
|
|
require.NoError(t, err)
|
|
|
|
connection.OAuthTokenURL = "https://nowhere.com/updated"
|
|
|
|
updated, err := ss.OutgoingOAuthConnection().UpdateConnection(c, connection)
|
|
require.NoError(t, err)
|
|
require.Equal(t, connection, updated)
|
|
})
|
|
|
|
t.Run("grant type", func(t *testing.T) {
|
|
connection := newValidOutgoingOAuthConnection()
|
|
_, err := ss.OutgoingOAuthConnection().SaveConnection(c, connection)
|
|
require.NoError(t, err)
|
|
|
|
connection.GrantType = model.OutgoingOAuthConnectionGrantTypeClientCredentials
|
|
|
|
updated, err := ss.OutgoingOAuthConnection().UpdateConnection(c, connection)
|
|
require.NoError(t, err)
|
|
require.Equal(t, connection, updated)
|
|
})
|
|
|
|
t.Run("audiences", func(t *testing.T) {
|
|
connection := newValidOutgoingOAuthConnection()
|
|
_, err := ss.OutgoingOAuthConnection().SaveConnection(c, connection)
|
|
require.NoError(t, err)
|
|
|
|
connection.Audiences = model.StringArray{"https://nowhere.com/updated"}
|
|
|
|
updated, err := ss.OutgoingOAuthConnection().UpdateConnection(c, connection)
|
|
require.NoError(t, err)
|
|
require.Equal(t, connection, updated)
|
|
})
|
|
|
|
t.Run("credentials username", func(t *testing.T) {
|
|
connection := newValidOutgoingOAuthConnection()
|
|
_, err := ss.OutgoingOAuthConnection().SaveConnection(c, connection)
|
|
require.NoError(t, err)
|
|
|
|
username := "updated username"
|
|
connection.CredentialsUsername = &username
|
|
|
|
updated, err := ss.OutgoingOAuthConnection().UpdateConnection(c, connection)
|
|
require.NoError(t, err)
|
|
require.Equal(t, connection, updated)
|
|
})
|
|
|
|
t.Run("credentials password", func(t *testing.T) {
|
|
connection := newValidOutgoingOAuthConnection()
|
|
_, err := ss.OutgoingOAuthConnection().SaveConnection(c, connection)
|
|
require.NoError(t, err)
|
|
|
|
password := "updated password"
|
|
connection.CredentialsPassword = &password
|
|
|
|
updated, err := ss.OutgoingOAuthConnection().UpdateConnection(c, connection)
|
|
require.NoError(t, err)
|
|
require.Equal(t, connection, updated)
|
|
})
|
|
})
|
|
}
|
|
|
|
func testGetOutgoingOAuthConnection(t *testing.T, ss store.Store) {
|
|
c := request.TestContext(t)
|
|
|
|
t.Run("get non-existing", func(t *testing.T) {
|
|
nonExistingID := model.NewId()
|
|
var expected *store.ErrNotFound
|
|
_, err := ss.OutgoingOAuthConnection().GetConnection(c, nonExistingID)
|
|
require.ErrorAs(t, err, &expected)
|
|
})
|
|
}
|
|
|
|
func runAudienceTests(t *testing.T, ss store.Store, connection *model.OutgoingOAuthConnection) {
|
|
c := request.TestContext(t)
|
|
|
|
t.Run("find by host only", func(t *testing.T) {
|
|
conn, err := ss.OutgoingOAuthConnection().GetConnections(c, model.OutgoingOAuthConnectionGetConnectionsFilter{Audience: "knowhere.com"})
|
|
require.NoError(t, err)
|
|
require.Len(t, conn, 1)
|
|
require.Equal(t, []*model.OutgoingOAuthConnection{connection}, conn)
|
|
})
|
|
|
|
t.Run("find by host and path", func(t *testing.T) {
|
|
conn, err := ss.OutgoingOAuthConnection().GetConnections(c, model.OutgoingOAuthConnectionGetConnectionsFilter{Audience: "knowhere.com/audience"})
|
|
require.NoError(t, err)
|
|
require.Len(t, conn, 1)
|
|
require.Equal(t, []*model.OutgoingOAuthConnection{connection}, conn)
|
|
})
|
|
|
|
t.Run("find by full url", func(t *testing.T) {
|
|
conn, err := ss.OutgoingOAuthConnection().GetConnections(c, model.OutgoingOAuthConnectionGetConnectionsFilter{Audience: "https://knowhere.com/audience"})
|
|
require.NoError(t, err)
|
|
require.Len(t, conn, 1)
|
|
require.Equal(t, []*model.OutgoingOAuthConnection{connection}, conn)
|
|
})
|
|
|
|
t.Run("non-existent", func(t *testing.T) {
|
|
conn, err := ss.OutgoingOAuthConnection().GetConnections(c, model.OutgoingOAuthConnectionGetConnectionsFilter{Audience: "https://mattermost.com"})
|
|
require.NoError(t, err)
|
|
require.Empty(t, conn)
|
|
})
|
|
}
|
|
|
|
func testGetOutgoingOAuthConnectionByAudience(t *testing.T, ss store.Store) {
|
|
t.Run("get non-existing", func(t *testing.T) {
|
|
c := request.TestContext(t)
|
|
|
|
nonExistingID := model.NewId()
|
|
var expected *store.ErrNotFound
|
|
_, err := ss.OutgoingOAuthConnection().GetConnection(c, nonExistingID)
|
|
require.ErrorAs(t, err, &expected)
|
|
})
|
|
|
|
t.Run("get existing (single audience)", func(t *testing.T) {
|
|
t.Cleanup(cleanupOutgoingOAuthConnections(t, ss))
|
|
c := request.TestContext(t)
|
|
|
|
connection := newValidOutgoingOAuthConnection()
|
|
connection.Audiences = []string{"https://knowhere.com/audience"}
|
|
var err error
|
|
connection, err = ss.OutgoingOAuthConnection().SaveConnection(c, connection)
|
|
require.NoError(t, err)
|
|
|
|
runAudienceTests(t, ss, connection)
|
|
})
|
|
|
|
t.Run("get existing (multiple audiences)", func(t *testing.T) {
|
|
t.Cleanup(cleanupOutgoingOAuthConnections(t, ss))
|
|
c := request.TestContext(t)
|
|
|
|
connection := newValidOutgoingOAuthConnection()
|
|
connection.Audiences = []string{"https://knowhere.com/audience", "https://example.com"}
|
|
var err error
|
|
connection, err = ss.OutgoingOAuthConnection().SaveConnection(c, connection)
|
|
require.NoError(t, err)
|
|
|
|
runAudienceTests(t, ss, connection)
|
|
})
|
|
}
|
|
|
|
func testGetOutgoingOAuthConnections(t *testing.T, ss store.Store) {
|
|
c := request.TestContext(t)
|
|
|
|
// Define test data
|
|
connection1 := newValidOutgoingOAuthConnection()
|
|
connection2 := newValidOutgoingOAuthConnection()
|
|
connection3 := newValidOutgoingOAuthConnection()
|
|
|
|
// Save the connections
|
|
connection1, err := ss.OutgoingOAuthConnection().SaveConnection(c, connection1)
|
|
require.NoError(t, err)
|
|
connection2, err = ss.OutgoingOAuthConnection().SaveConnection(c, connection2)
|
|
require.NoError(t, err)
|
|
connection3, err = ss.OutgoingOAuthConnection().SaveConnection(c, connection3)
|
|
require.NoError(t, err)
|
|
|
|
connections := []*model.OutgoingOAuthConnection{connection1, connection2, connection3}
|
|
sort.Slice(connections, func(i, j int) bool {
|
|
return connections[i].Id < connections[j].Id
|
|
})
|
|
|
|
t.Run("get all", func(t *testing.T) {
|
|
// Retrieve the connections
|
|
conns, err := ss.OutgoingOAuthConnection().GetConnections(c, model.OutgoingOAuthConnectionGetConnectionsFilter{Limit: 3})
|
|
require.NoError(t, err)
|
|
require.Len(t, conns, 3)
|
|
})
|
|
|
|
t.Run("get connections using pagination", func(t *testing.T) {
|
|
// Retrieve the first page
|
|
conns, err := ss.OutgoingOAuthConnection().GetConnections(c, model.OutgoingOAuthConnectionGetConnectionsFilter{Limit: 1})
|
|
require.NoError(t, err)
|
|
require.Len(t, conns, 1)
|
|
require.Equal(t, connections[0].Id, conns[0].Id, "should return the first connection")
|
|
|
|
// Retrieve the second page
|
|
conns, err = ss.OutgoingOAuthConnection().GetConnections(c, model.OutgoingOAuthConnectionGetConnectionsFilter{OffsetId: connections[0].Id})
|
|
require.NoError(t, err)
|
|
require.Len(t, conns, 2)
|
|
require.Equal(t, connections[1].Id, conns[0].Id, "should return the second connection")
|
|
require.Equal(t, connections[2].Id, conns[1].Id, "should return the third connection")
|
|
})
|
|
}
|
|
|
|
func testDeleteOutgoingOAuthConnection(t *testing.T, ss store.Store) {
|
|
c := request.TestContext(t)
|
|
|
|
t.Run("delete", func(t *testing.T) {
|
|
// Define test data
|
|
connection := newValidOutgoingOAuthConnection()
|
|
|
|
// Save the connection
|
|
_, err := ss.OutgoingOAuthConnection().SaveConnection(c, connection)
|
|
require.NoError(t, err)
|
|
|
|
// Delete the connection
|
|
err = ss.OutgoingOAuthConnection().DeleteConnection(c, connection.Id)
|
|
require.NoError(t, err)
|
|
|
|
// Retrieve the connection
|
|
_, err = ss.OutgoingOAuthConnection().GetConnection(c, connection.Id)
|
|
var expected *store.ErrNotFound
|
|
require.ErrorAs(t, err, &expected)
|
|
})
|
|
}
|