// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.

package platform

import (
	"fmt"
	"hash/maphash"
	"iter"
	"maps"
	"runtime"
	"runtime/debug"
	"strconv"
	"sync/atomic"
	"time"

	"github.com/mattermost/mattermost/server/public/model"
	"github.com/mattermost/mattermost/server/public/shared/mlog"
	"github.com/mattermost/mattermost/server/public/shared/request"
	"github.com/mattermost/mattermost/server/v8/channels/store"
)

const (
	broadcastQueueSize         = 4096
	inactiveConnReaperInterval = 5 * time.Minute
)

type SuiteIFace interface {
	GetSession(token string) (*model.Session, *model.AppError)
	RolesGrantPermission(roleNames []string, permissionId string) bool
	HasPermissionToReadChannel(rctx request.CTX, userID string, channel *model.Channel) bool
	UserCanSeeOtherUser(rctx request.CTX, userID string, otherUserId string) (bool, *model.AppError)
	MFARequired(rctx request.CTX) *model.AppError
}

type webConnActivityMessage struct {
	userID       string
	sessionToken string
	activityAt   int64
}

type webConnDirectMessage struct {
	conn *WebConn
	msg  model.WebSocketMessage
}

type webConnSessionMessage struct {
	userID       string
	sessionToken string
	isRegistered chan bool
}

type webConnRegisterMessage struct {
	conn *WebConn
	err  chan error
}

type webConnCheckMessage struct {
	userID       string
	connectionID string
	result       chan *CheckConnResult
}

type webConnCountMessage struct {
	userID string
	result chan int
}
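
// hubSemaphoreCount caps how many ProcessAsync callbacks a single hub can have
// in flight at once; Hub.Stop also drains this many slots to wait for them.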
var hubSemaphoreCount = runtime.NumCPU() * 4

// Hub is the central place to manage all websocket connections in the server.
// It handles different websocket events and sends messages to individual
// user connections.
type Hub struct {
	// connectionCount should be kept first.
	// See https://github.com/mattermost/mattermost-server/pull/7281
	connectionCount int64
	platform        *PlatformService
	connectionIndex int
	register        chan *webConnRegisterMessage
	unregister      chan *WebConn
	broadcast       chan *model.WebSocketEvent
	stop            chan struct{}
	didStop         chan struct{}
	invalidateUser  chan string
	activity        chan *webConnActivityMessage
	directMsg       chan *webConnDirectMessage
	explicitStop    bool
	checkRegistered chan *webConnSessionMessage
	checkConn       chan *webConnCheckMessage
	connCount       chan *webConnCountMessage
	broadcastHooks  map[string]BroadcastHook

	// Hub-specific semaphore for limiting concurrent goroutines.
	hubSemaphore chan struct{}
}

// newWebHub creates a new Hub.
func newWebHub(ps *PlatformService) *Hub {
	return &Hub{
		platform:        ps,
		register:        make(chan *webConnRegisterMessage),
		unregister:      make(chan *WebConn),
		broadcast:       make(chan *model.WebSocketEvent, broadcastQueueSize),
		stop:            make(chan struct{}),
		didStop:         make(chan struct{}),
		invalidateUser:  make(chan string),
		activity:        make(chan *webConnActivityMessage),
		directMsg:       make(chan *webConnDirectMessage),
		checkRegistered: make(chan *webConnSessionMessage),
		checkConn:       make(chan *webConnCheckMessage),
		connCount:       make(chan *webConnCountMessage),
		hubSemaphore:    make(chan struct{}, hubSemaphoreCount),
	}
}

// hubStart starts all the hubs.
func (ps *PlatformService) hubStart(broadcastHooks map[string]BroadcastHook) {
	// After running some tests, we found that using the same number of hubs
	// as CPUs is ideal in terms of performance.
	// https://github.com/mattermost/mattermost/pull/25798#issuecomment-1889386454
	numberOfHubs := runtime.NumCPU()
	ps.logger.Info("Starting websocket hubs", mlog.Int("number_of_hubs", numberOfHubs))

	hubs := make([]*Hub, numberOfHubs)

	for i := range numberOfHubs {
		hubs[i] = newWebHub(ps)
		hubs[i].connectionIndex = i
		hubs[i].broadcastHooks = broadcastHooks
		hubs[i].Start()
	}
	// Assigning to the hubs slice without any mutex is fine because it is only assigned once
	// during the start of the program and always read from after that.
	ps.hubs = hubs
}
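
// InvalidateCacheForWebhook invalidates the cached webhook entry for the given webhook ID.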
func (ps *PlatformService) InvalidateCacheForWebhook(webhookID string) {
	ps.Store.Webhook().InvalidateWebhookCache(webhookID)
}

// HubStop stops all the hubs.
func (ps *PlatformService) HubStop() {
	ps.logger.Info("stopping websocket hub connections")

	for _, hub := range ps.hubs {
		hub.Stop()
	}
}

// GetHubForUserId returns the hub for a given user id.
func (ps *PlatformService) GetHubForUserId(userID string) *Hub {
	if len(ps.hubs) == 0 {
		return nil
	}

	// TODO: check if caching the userID -> hub mapping
	// is worth the memory tradeoff.
	// https://mattermost.atlassian.net/browse/MM-26629.
	var hash maphash.Hash
	hash.SetSeed(ps.hashSeed)
	_, err := hash.Write([]byte(userID))
	if err != nil {
		ps.logger.Error("Unable to write userID to hash", mlog.String("userID", userID), mlog.Err(err))
	}
	index := hash.Sum64() % uint64(len(ps.hubs))

	return ps.hubs[int(index)]
}
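
// Note that GetHubForUserId always maps a given userID to the same hub for the
// lifetime of the process (the maphash seed is fixed at startup), so all
// register, unregister, activity, and invalidation work for one user is
// serialized onto a single hub goroutine.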

// HubRegister registers a connection to a hub.
func (ps *PlatformService) HubRegister(webConn *WebConn) error {
	hub := ps.GetHubForUserId(webConn.UserId)
	if hub != nil {
		if metrics := ps.metricsIFace; metrics != nil {
			metrics.IncrementWebSocketBroadcastUsersRegistered(strconv.Itoa(hub.connectionIndex), 1)
		}
		return hub.Register(webConn)
	}
	return nil
}

// HubUnregister unregisters a connection from a hub.
func (ps *PlatformService) HubUnregister(webConn *WebConn) {
	hub := ps.GetHubForUserId(webConn.UserId)
	if hub != nil {
		if metrics := ps.metricsIFace; metrics != nil {
			metrics.DecrementWebSocketBroadcastUsersRegistered(strconv.Itoa(hub.connectionIndex), 1)
		}
		hub.Unregister(webConn)
	}
}
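
// InvalidateCacheForChannel invalidates the cached channel by ID and by
// team/name. Channels without a team (DMs and GMs) are cached under the
// "dm" pseudo-team.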
func (ps *PlatformService) InvalidateCacheForChannel(channel *model.Channel) {
	ps.Store.Channel().InvalidateChannel(channel.Id)
	teamID := channel.TeamId
	if teamID == "" {
		teamID = "dm"
	}

	ps.Store.Channel().InvalidateChannelByName(teamID, channel.Name)
}

func (ps *PlatformService) InvalidateCacheForChannelMembers(channelID string) {
	ps.Store.User().InvalidateProfilesInChannelCache(channelID)
	ps.Store.Channel().InvalidateMemberCount(channelID)
	ps.Store.Channel().InvalidateGuestCount(channelID)
}

func (ps *PlatformService) InvalidateCacheForChannelMembersNotifyProps(channelID string) {
	ps.Store.Channel().InvalidateCacheForChannelMembersNotifyProps(channelID)
}

func (ps *PlatformService) InvalidateCacheForChannelPosts(channelID string) {
	ps.Store.Channel().InvalidatePinnedPostCount(channelID)
	ps.Store.Post().InvalidateLastPostTimeCache(channelID)
}
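
// InvalidateCacheForUser invalidates the user's channel-membership and profile
// caches and clears the session cache on the user's web connections, notifying
// other cluster nodes when clustering is enabled.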
func (ps *PlatformService) InvalidateCacheForUser(userID string) {
	ps.InvalidateChannelCacheForUser(userID)
	ps.Store.User().InvalidateProfileCacheForUser(userID)
}

func (ps *PlatformService) invalidateWebConnSessionCacheForUser(userID string) {
	ps.invalidateWebConnSessionCacheForUserSkipClusterSend(userID)
	if ps.clusterIFace != nil {
		msg := &model.ClusterMessage{
			Event:    model.ClusterEventInvalidateWebConnCacheForUser,
			SendType: model.ClusterSendBestEffort,
			Data:     []byte(userID),
		}
		ps.clusterIFace.SendClusterMessage(msg)
	}
}

func (ps *PlatformService) InvalidateChannelCacheForUser(userID string) {
	ps.Store.Channel().InvalidateAllChannelMembersForUser(userID)
	ps.invalidateWebConnSessionCacheForUser(userID)
	ps.Store.User().InvalidateProfilesInChannelCacheByUser(userID)
}

func (ps *PlatformService) InvalidateCacheForUserTeams(userID string) {
	ps.invalidateWebConnSessionCacheForUser(userID)
	// This method has its own cluster broadcast hidden inside localcachelayer.
	ps.Store.Team().InvalidateAllTeamIdsForUser(userID)
}

// UpdateWebConnUserActivity sets the LastUserActivityAt on the user's web
// connection for the given session.
func (ps *PlatformService) UpdateWebConnUserActivity(session model.Session, activityAt int64) {
	hub := ps.GetHubForUserId(session.UserId)
	if hub != nil {
		hub.UpdateActivity(session.UserId, session.Token, activityAt)
	}
}

// SessionIsRegistered determines if a specific session has been registered.
func (ps *PlatformService) SessionIsRegistered(session model.Session) bool {
	hub := ps.GetHubForUserId(session.UserId)
	if hub != nil {
		return hub.IsRegistered(session.UserId, session.Token)
	}

	return false
}
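
// CheckWebConn looks for an inactive connection matching the given userID and
// connectionID so that a reconnecting client can resume its websocket message
// queues. In HA mode, when the client supplies a non-zero sequence number, the
// other cluster nodes are consulted first before falling back to the local hub.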
func (ps *PlatformService) CheckWebConn(userID, connectionID string, seqNum int64) *CheckConnResult {
	if ps.Cluster() == nil || seqNum == 0 {
		hub := ps.GetHubForUserId(userID)
		if hub != nil {
			return hub.CheckConn(userID, connectionID)
		}
		return nil
	}

	// We need some extra care for HA.
	// Check other nodes:
	// If any node returns an active queue and/or a dead queue, use that.
	// If all nodes return empty, proceed with the local case.
	// We have to do this because a client might reconnect with an older seq num to a node
	// which it had connected to before. So checking only its local queue would lead the server
	// to believe that there was no message loss, whereas there actually was.
	queueMap, err := ps.Cluster().GetWSQueues(userID, connectionID, seqNum)
	if err != nil {
		// If there is an error we do not have enough data to say anything reliably.
		// Fall back to the unreliable case.
		ps.Log().Error("Error while getting websocket queues",
			mlog.String("connection_id", connectionID),
			mlog.String("user_id", userID),
			mlog.Int("sequence_number", seqNum),
			mlog.Err(err))
		return nil
	}

	connRes := &CheckConnResult{
		ConnectionID: connectionID,
		UserID:       userID,
	}
	for _, queues := range queueMap {
		if queues == nil || queues.ActiveQ == nil {
			continue
		}
		// Parse the active queue.
		aq := make(chan model.WebSocketMessage, sendQueueSize)
		for _, aqItem := range queues.ActiveQ {
			item, err := ps.UnmarshalAQItem(aqItem)
			if err != nil {
				ps.Log().Error("Error while unmarshalling websocket message from active queue",
					mlog.String("connection_id", connectionID),
					mlog.String("user_id", userID),
					mlog.Err(err))
				return nil
			}
			// This cannot block because all send queues are at most of size sendQueueSize.
			// TODO: There could be a case where there's severe message loss, and to
			// reliably get the messages, we need to get send queues from multiple nodes.
			// We leave that case for Redis.
			aq <- item
		}

		connRes.ActiveQueue = aq
		connRes.ReuseCount = queues.ReuseCount

		// Parse the dead queue.
		if queues.DeadQ != nil {
			dq, dqPtr, err := ps.UnmarshalDQ(queues.DeadQ)
			if err != nil {
				ps.Log().Error("Error while unmarshalling websocket message from dead queue",
					mlog.String("connection_id", connectionID),
					mlog.String("user_id", userID),
					mlog.Err(err))
				return nil
			}

			// We check if at least one item has been written.
			// The length of dq is always guaranteed to be deadQueueSize.
			if dq[0] != nil {
				connRes.DeadQueue = dq
				connRes.DeadQueuePointer = dqPtr
			}
		}

		return connRes
	}

	// Now we check the local queue.
	hub := ps.GetHubForUserId(userID)
	if hub != nil {
		return hub.CheckConn(userID, connectionID)
	}
	return nil
}

// WebConnCountForUser returns the number of active websocket connections
// for a given userID.
func (ps *PlatformService) WebConnCountForUser(userID string) int {
	hub := ps.GetHubForUserId(userID)
	if hub != nil {
		return hub.WebConnCountForUser(userID)
	}
	return 0
}

// Register registers a connection to the hub.
func (h *Hub) Register(webConn *WebConn) error {
	wr := &webConnRegisterMessage{
		conn: webConn,
		err:  make(chan error),
	}
	select {
	case h.register <- wr:
		return <-wr.err
	case <-h.stop:
	}
	return nil
}

// Unregister unregisters a connection from the hub.
func (h *Hub) Unregister(webConn *WebConn) {
	select {
	case h.unregister <- webConn:
	case <-h.stop:
	}
}

// IsRegistered determines whether a user's session has a registered connection
// in the hub.
func (h *Hub) IsRegistered(userID, sessionToken string) bool {
	ws := &webConnSessionMessage{
		userID:       userID,
		sessionToken: sessionToken,
		isRegistered: make(chan bool),
	}
	select {
	case h.checkRegistered <- ws:
		return <-ws.isRegistered
	case <-h.stop:
	}
	return false
}
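
// CheckConn removes a matching inactive connection from the hub's index and
// returns its queues so that a resuming connection can take them over. It
// returns nil if no such inactive connection exists.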
func (h *Hub) CheckConn(userID, connectionID string) *CheckConnResult {
	req := &webConnCheckMessage{
		userID:       userID,
		connectionID: connectionID,
		result:       make(chan *CheckConnResult),
	}
	select {
	case h.checkConn <- req:
		return <-req.result
	case <-h.stop:
	}
	return nil
}
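
// WebConnCountForUser returns the number of active connections in this hub for
// the given userID.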
func (h *Hub) WebConnCountForUser(userID string) int {
	req := &webConnCountMessage{
		userID: userID,
		result: make(chan int),
	}
	select {
	case h.connCount <- req:
		return <-req.result
	case <-h.stop:
	}
	return 0
}

// Broadcast broadcasts the message to all connections in the hub.
func (h *Hub) Broadcast(message *model.WebSocketEvent) {
	// XXX: The hub nil check is because of the way we set up our tests. We call
	// `app.NewServer()` which returns a server, but only after that do we call
	// `wsapi.Init()` to initialize the hub. The `NewServer` call itself, however,
	// already proceeds to broadcast some messages. This needs to be
	// fixed once the wsapi cyclic dependency with server/app goes away.
	// And possibly, we can look into doing the hub initialization inside
	// NewServer itself.
	if h != nil && message != nil {
		if metrics := h.platform.metricsIFace; metrics != nil {
			metrics.IncrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1)
		}
		select {
		case h.broadcast <- message:
		case <-h.stop:
		}
	}
}

// InvalidateUser invalidates the cache for the given user.
func (h *Hub) InvalidateUser(userID string) {
	select {
	case h.invalidateUser <- userID:
	case <-h.stop:
	}
}

// UpdateActivity sets the LastUserActivityAt field for the connection
// of the user.
func (h *Hub) UpdateActivity(userID, sessionToken string, activityAt int64) {
	select {
	case h.activity <- &webConnActivityMessage{
		userID:       userID,
		sessionToken: sessionToken,
		activityAt:   activityAt,
	}:
	case <-h.stop:
	}
}

// SendMessage sends the given message to the given connection.
func (h *Hub) SendMessage(conn *WebConn, msg model.WebSocketMessage) {
	select {
	case h.directMsg <- &webConnDirectMessage{
		conn: conn,
		msg:  msg,
	}:
	case <-h.stop:
	}
}

// ProcessAsync executes a function with hub-specific concurrency control.
// The caller blocks until a semaphore slot is free; the spawned goroutine logs
// a warning if f has not returned within 5 seconds, but f is not cancelled.
func (h *Hub) ProcessAsync(f func()) {
	h.hubSemaphore <- struct{}{}
	go func() {
		defer func() {
			<-h.hubSemaphore
		}()

		// Add timeout protection.
		done := make(chan struct{})
		go func() {
			defer close(done)
			f()
		}()

		select {
		case <-done:
			// Function completed normally.
		case <-time.After(5 * time.Second):
			h.platform.Log().Warn("ProcessAsync function timed out after 5 seconds")
		}
	}()
}

// Stop stops the hub.
func (h *Hub) Stop() {
	close(h.stop)
	<-h.didStop
	// Acquire every semaphore slot so that all in-flight ProcessAsync callbacks
	// are accounted for before shutting down.
	for range hubSemaphoreCount {
		h.hubSemaphore <- struct{}{}
	}
}

// Start starts the hub.
func (h *Hub) Start() {
	var doStart func()
	var doRecoverableStart func()
	var doRecover func()

	doStart = func() {
		mlog.Debug("Hub is starting", mlog.Int("index", h.connectionIndex))

		ticker := time.NewTicker(inactiveConnReaperInterval)
		defer ticker.Stop()

		connIndex := newHubConnectionIndex(inactiveConnReaperInterval,
			h.platform.Store,
			h.platform.logger,
			*h.platform.Config().ServiceSettings.EnableWebHubChannelIteration,
		)

		for {
			select {
			case webSessionMessage := <-h.checkRegistered:
				var isRegistered bool
				for conn := range connIndex.ForUser(webSessionMessage.userID) {
					if !conn.Active.Load() {
						continue
					}
					if conn.GetSessionToken() == webSessionMessage.sessionToken {
						isRegistered = true
					}
				}
				webSessionMessage.isRegistered <- isRegistered
			case req := <-h.checkConn:
				var res *CheckConnResult
				conn := connIndex.RemoveInactiveByConnectionID(req.userID, req.connectionID)
				if conn != nil {
					res = &CheckConnResult{
						ConnectionID:     req.connectionID,
						UserID:           req.userID,
						ActiveQueue:      conn.send,
						DeadQueue:        conn.deadQueue,
						DeadQueuePointer: conn.deadQueuePointer,
						ReuseCount:       conn.reuseCount + 1,
					}
				}
				req.result <- res
			case req := <-h.connCount:
				req.result <- connIndex.ForUserActiveCount(req.userID)
			case <-ticker.C:
				connIndex.RemoveInactiveConnections()
			case webConnReg := <-h.register:
				// Mark the current one as active.
				// There is no need to check if it was inactive or not;
				// we will need to make it active anyway.
				webConnReg.conn.Active.Store(true)

				err := connIndex.Add(webConnReg.conn)
				if err != nil {
					webConnReg.err <- err
					continue
				}
				atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive()))

				if webConnReg.conn.IsBasicAuthenticated() && webConnReg.conn.reuseCount == 0 {
					// The hello message should only be sent when the reuseCount is 0,
					// i.e. in the server restart, long timeout, or fresh connection case.
					// In case the seq number is not found in the dead queue, it is handled by
					// the webconn write pump.
					webConnReg.conn.send <- webConnReg.conn.createHelloMessage()
				}
				webConnReg.err <- nil
			case webConn := <-h.unregister:
				// If already removed (via queue full), then removing again becomes a noop.
				// But if not removed, mark inactive.
				webConn.Active.Store(false)

				atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive()))

				if webConn.UserId == "" {
					continue
				}

				conns := connIndex.ForUser(webConn.UserId)
				// areAllInactive also returns true if there are no connections,
				// which is intentional.
				if areAllInactive(conns) {
					userID := webConn.UserId
					h.ProcessAsync(func() {
						// If this is an HA setup, get the count for this user
						// from the other nodes.
						var clusterCnt int
						var appErr *model.AppError
						if h.platform.Cluster() != nil {
							clusterCnt, appErr = h.platform.Cluster().WebConnCountForUser(userID)
						}
						if appErr != nil {
							mlog.Error("Error in trying to get the webconn count from cluster", mlog.Err(appErr))
							// We take a conservative approach and do not set the status
							// to offline when there's an error, rather than risk
							// incorrectly marking the user offline.
							return
						}
						// Only set to offline if there are no
						// active connections in other nodes as well.
						if clusterCnt == 0 {
							h.platform.QueueSetStatusOffline(userID, false)
						}
					})
					continue
				}
				var latestActivity int64
				for conn := range conns {
					if !conn.Active.Load() {
						continue
					}
					if conn.lastUserActivityAt > latestActivity {
						latestActivity = conn.lastUserActivityAt
					}
				}

				if h.platform.isUserAway(latestActivity) {
					userID := webConn.UserId
					h.platform.Go(func() {
						h.platform.SetStatusLastActivityAt(userID, latestActivity)
					})
				}
			case userID := <-h.invalidateUser:
				for webConn := range connIndex.ForUser(userID) {
					webConn.InvalidateCache()
				}

				if !*h.platform.Config().ServiceSettings.EnableWebHubChannelIteration {
					continue
				}

				err := connIndex.InvalidateCMCacheForUser(userID)
				if err != nil {
					h.platform.Log().Error("Error while invalidating channel member cache", mlog.String("user_id", userID), mlog.Err(err))
					for webConn := range connIndex.ForUser(userID) {
						closeAndRemoveConn(connIndex, webConn)
					}
				}
			case activity := <-h.activity:
				for webConn := range connIndex.ForUser(activity.userID) {
					if !webConn.Active.Load() {
						continue
					}
					if webConn.GetSessionToken() == activity.sessionToken {
						webConn.lastUserActivityAt = activity.activityAt
					}
				}
			case directMsg := <-h.directMsg:
				if !connIndex.Has(directMsg.conn) {
					continue
				}
				select {
				case directMsg.conn.send <- directMsg.msg:
				default:
					// Don't log the warning if it's an inactive connection.
					if directMsg.conn.Active.Load() {
						mlog.Error("webhub.broadcast: cannot send, closing websocket for user",
							mlog.String("user_id", directMsg.conn.UserId),
							mlog.String("conn_id", directMsg.conn.GetConnectionID()))
					}
					closeAndRemoveConn(connIndex, directMsg.conn)
				}
			case msg := <-h.broadcast:
				if metrics := h.platform.metricsIFace; metrics != nil {
					metrics.DecrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1)
				}

				// Remove the broadcast hook information before precomputing the JSON so that those aren't included in it.
				msg, broadcastHooks, broadcastHookArgs := msg.WithoutBroadcastHooks()

				msg = msg.PrecomputeJSON()

				broadcast := func(webConn *WebConn) {
					if !connIndex.Has(webConn) {
						return
					}
					if webConn.ShouldSendEvent(msg) {
						select {
						case webConn.send <- h.runBroadcastHooks(msg, webConn, broadcastHooks, broadcastHookArgs):
						default:
							// Don't log the warning if it's an inactive connection.
							if webConn.Active.Load() {
								mlog.Error("webhub.broadcast: cannot send, closing websocket for user",
									mlog.String("user_id", webConn.UserId),
									mlog.String("conn_id", webConn.GetConnectionID()))
							}
							closeAndRemoveConn(connIndex, webConn)
						}
					}
				}

				// Quick return for a single connection.
				if webConn := connIndex.ForConnection(msg.GetBroadcast().ConnectionId); webConn != nil {
					broadcast(webConn)
					continue
				}

				fastIteration := *h.platform.Config().ServiceSettings.EnableWebHubChannelIteration
				var targetConns iter.Seq[*WebConn]
				if userID := msg.GetBroadcast().UserId; userID != "" {
					targetConns = connIndex.ForUser(userID)
				} else if channelID := msg.GetBroadcast().ChannelId; channelID != "" && fastIteration {
					targetConns = connIndex.ForChannel(channelID)
				}
				if targetConns != nil {
					for webConn := range targetConns {
						broadcast(webConn)
					}
					continue
				}

				// There are multiple hubs in the system, so while supporting both channel-based
				// iteration and the old method, events scoped to a channel would be sent to
				// multiple hubs, and only one hub would have the targetConns. Therefore, we need
				// to stop here if channel-based iteration is enabled and it's a channel-scoped event.
				if channelID := msg.GetBroadcast().ChannelId; channelID != "" && fastIteration {
					continue
				}

				for webConn := range connIndex.All() {
					broadcast(webConn)
				}
			case <-h.stop:
				for webConn := range connIndex.All() {
					webConn.Close()
					h.platform.SetStatusOffline(webConn.UserId, false, false)
				}

				h.explicitStop = true
				close(h.didStop)

				return
			}
		}
	}

	doRecoverableStart = func() {
		defer doRecover()
		doStart()
	}

	doRecover = func() {
		if !h.explicitStop {
			if r := recover(); r != nil {
				mlog.Error("Recovering from Hub panic.", mlog.Any("panic", r))
			} else {
				mlog.Error("Webhub stopped unexpectedly. Recovering.")
			}

			mlog.Error(string(debug.Stack()))

			go doRecoverableStart()
		}
	}

	go doRecoverableStart()
}

// areAllInactive returns whether all of the connections
// are inactive or not. It also returns true if there are
// no connections, which is intentional.
func areAllInactive(conns iter.Seq[*WebConn]) bool {
	for conn := range conns {
		if conn.Active.Load() {
			return false
		}
	}
	return true
}

// closeAndRemoveConn closes the send channel, which will close the
// websocket connection, and then removes the webConn from the conn index.
func closeAndRemoveConn(connIndex *hubConnectionIndex, conn *WebConn) {
	close(conn.send)
	connIndex.Remove(conn)
}

// hubConnectionIndex provides fast addition, removal, and iteration of web connections.
// It requires 4 functionalities which need to be very fast:
// - check if a connection exists or not.
// - get all connections for a given userID.
// - get all connections for a given channelID.
// - get all connections.
type hubConnectionIndex struct {
	// byUserId stores the set of connections for a given userID
	byUserId map[string]map[*WebConn]struct{}
	// byChannelID stores the set of connections for a given channelID
	byChannelID map[string]map[*WebConn]struct{}
	// byConnection serves the dual purpose of storing the channelIDs
	// and also to get all connections
	byConnection   map[*WebConn][]string
	byConnectionId map[string]*WebConn
	// staleThreshold is the limit beyond which inactive connections
	// will be deleted.
	staleThreshold time.Duration

	fastIteration bool
	store         store.Store
	logger        mlog.LoggerIFace
}
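
// newHubConnectionIndex creates an empty index. The interval is used as the
// staleness threshold beyond which inactive connections are reaped.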
func newHubConnectionIndex(interval time.Duration,
	store store.Store,
	logger mlog.LoggerIFace,
	fastIteration bool,
) *hubConnectionIndex {
	return &hubConnectionIndex{
		byUserId:       make(map[string]map[*WebConn]struct{}),
		byChannelID:    make(map[string]map[*WebConn]struct{}),
		byConnection:   make(map[*WebConn][]string),
		byConnectionId: make(map[string]*WebConn),
		staleThreshold: interval,
		store:          store,
		logger:         logger,
		fastIteration:  fastIteration,
	}
}
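
// Add indexes a new connection by user and by connection ID and, when fast
// channel iteration is enabled, by every channel the user is a member of.
// The channel lookup hits the store and can therefore fail.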
func (i *hubConnectionIndex) Add(wc *WebConn) error {
	var channelIDs []string
	if i.fastIteration {
		cm, err := i.store.Channel().GetAllChannelMembersForUser(request.EmptyContext(i.logger), wc.UserId, false, false)
		if err != nil {
			return fmt.Errorf("error getChannelMembersForUser: %v", err)
		}

		// Store channel IDs and add to byChannelID.
		channelIDs = make([]string, 0, len(cm))
		for chID := range cm {
			channelIDs = append(channelIDs, chID)

			// Initialize the channel's map if it doesn't exist.
			if _, ok := i.byChannelID[chID]; !ok {
				i.byChannelID[chID] = make(map[*WebConn]struct{})
			}
			i.byChannelID[chID][wc] = struct{}{}
		}
	}

	// Initialize the user's map if it doesn't exist.
	if _, ok := i.byUserId[wc.UserId]; !ok {
		i.byUserId[wc.UserId] = make(map[*WebConn]struct{})
	}
	i.byUserId[wc.UserId][wc] = struct{}{}
	i.byConnection[wc] = channelIDs
	i.byConnectionId[wc.GetConnectionID()] = wc
	return nil
}
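
// Remove deletes the connection from all of the indexes. It is a no-op if the
// connection was never added or has already been removed.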
func (i *hubConnectionIndex) Remove(wc *WebConn) {
	channelIDs, ok := i.byConnection[wc]
	if !ok {
		return
	}

	// Remove from byUserId.
	if userConns, ok := i.byUserId[wc.UserId]; ok {
		delete(userConns, wc)
	}

	if i.fastIteration {
		// Remove from byChannelID for each channel.
		for _, chID := range channelIDs {
			if channelConns, ok := i.byChannelID[chID]; ok {
				delete(channelConns, wc)
			}
		}
	}

	delete(i.byConnection, wc)
	delete(i.byConnectionId, wc.GetConnectionID())
}
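
// InvalidateCMCacheForUser re-resolves the user's channel memberships from the
// store and rebuilds the byChannelID index for all of the user's connections.
// If the memberships cannot be loaded, an error is returned and the index is
// left untouched.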
func (i *hubConnectionIndex) InvalidateCMCacheForUser(userID string) error {
	// We make this query first to fail fast in case of an error.
	cm, err := i.store.Channel().GetAllChannelMembersForUser(request.EmptyContext(i.logger), userID, false, false)
	if err != nil {
		return err
	}

	// Get all connections for this user.
	conns := i.ForUser(userID)

	// Remove all user connections from existing channels.
	for conn := range conns {
		if channelIDs, ok := i.byConnection[conn]; ok {
			// Remove from old channels.
			for _, chID := range channelIDs {
				if channelConns, ok := i.byChannelID[chID]; ok {
					delete(channelConns, conn)
				}
			}
		}
	}

	// Add connections to new channels.
	for conn := range conns {
		newChannelIDs := make([]string, 0, len(cm))
		for chID := range cm {
			newChannelIDs = append(newChannelIDs, chID)
			// Initialize the channel map if needed.
			if _, ok := i.byChannelID[chID]; !ok {
				i.byChannelID[chID] = make(map[*WebConn]struct{})
			}
			i.byChannelID[chID][conn] = struct{}{}
		}

		// Update connection metadata.
		if _, ok := i.byConnection[conn]; ok {
			i.byConnection[conn] = newChannelIDs
		}
	}

	return nil
}
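
// Has reports whether the connection is present in the index.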
func (i *hubConnectionIndex) Has(wc *WebConn) bool {
	_, ok := i.byConnection[wc]
	return ok
}

// ForUser returns all connections for a user ID.
func (i *hubConnectionIndex) ForUser(id string) iter.Seq[*WebConn] {
	return maps.Keys(i.byUserId[id])
}

// ForChannel returns all connections for a channelID.
func (i *hubConnectionIndex) ForChannel(channelID string) iter.Seq[*WebConn] {
	return maps.Keys(i.byChannelID[channelID])
}

// ForUserActiveCount returns the number of active connections for a userID.
func (i *hubConnectionIndex) ForUserActiveCount(id string) int {
	cnt := 0
	for conn := range i.ForUser(id) {
		if conn.Active.Load() {
			cnt++
		}
	}
	return cnt
}

// ForConnection returns the connection from its ID.
func (i *hubConnectionIndex) ForConnection(id string) *WebConn {
	return i.byConnectionId[id]
}

// All returns the full webConn index.
func (i *hubConnectionIndex) All() map[*WebConn][]string {
	return i.byConnection
}

// RemoveInactiveByConnectionID removes an inactive connection for the given
// userID and connectionID.
func (i *hubConnectionIndex) RemoveInactiveByConnectionID(userID, connectionID string) *WebConn {
	// To handle empty sessions.
	if userID == "" {
		return nil
	}
	for conn := range i.ForUser(userID) {
		if conn.GetConnectionID() == connectionID && !conn.Active.Load() {
			i.Remove(conn)
			return conn
		}
	}
	return nil
}

// RemoveInactiveConnections removes all inactive connections whose lastUserActivityAt
// exceeded staleThreshold.
func (i *hubConnectionIndex) RemoveInactiveConnections() {
	now := model.GetMillis()
	for conn := range i.byConnection {
		if !conn.Active.Load() && now-conn.lastUserActivityAt > i.staleThreshold.Milliseconds() {
			i.Remove(conn)
		}
	}
}

// AllActive returns the number of active connections.
// This is only called during register/unregister, so we can take
// a bit of a performance hit here.
func (i *hubConnectionIndex) AllActive() int {
	cnt := 0
	for conn := range i.byConnection {
		if conn.Active.Load() {
			cnt++
		}
	}
	return cnt
}