// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package platform
import (
"fmt"
"hash/maphash"
"iter"
"maps"
"runtime"
"runtime/debug"
"strconv"
"sync/atomic"
"time"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/mattermost/mattermost/server/public/shared/request"
"github.com/mattermost/mattermost/server/v8/channels/store"
)
const (
// broadcastQueueSize is the buffer size of each hub's broadcast channel.
broadcastQueueSize = 4096
// inactiveConnReaperInterval is how often stale, inactive connections are removed.
inactiveConnReaperInterval = 5 * time.Minute
)
// SuiteIFace exposes the session and permission helpers that
// websocket connections need.
type SuiteIFace interface {
GetSession(token string) (*model.Session, *model.AppError)
RolesGrantPermission(roleNames []string, permissionId string) bool
HasPermissionToReadChannel(rctx request.CTX, userID string, channel *model.Channel) bool
UserCanSeeOtherUser(rctx request.CTX, userID string, otherUserId string) (bool, *model.AppError)
MFARequired(rctx request.CTX) *model.AppError
}
// webConnActivityMessage updates the last-activity timestamp for a session.
type webConnActivityMessage struct {
userID string
sessionToken string
activityAt int64
}
// webConnDirectMessage sends a message to a single connection.
type webConnDirectMessage struct {
conn *WebConn
msg model.WebSocketMessage
}
// webConnSessionMessage asks whether a session has a registered connection.
type webConnSessionMessage struct {
userID string
sessionToken string
isRegistered chan bool
}
// webConnRegisterMessage registers a connection and reports any error.
type webConnRegisterMessage struct {
conn *WebConn
err chan error
}
// webConnCheckMessage looks up reusable queue state for a connection ID.
type webConnCheckMessage struct {
userID string
connectionID string
result chan *CheckConnResult
}
// webConnCountMessage asks for the active connection count of a user.
type webConnCountMessage struct {
userID string
result chan int
}
// hubSemaphoreCount caps the number of concurrent ProcessAsync goroutines per hub.
var hubSemaphoreCount = runtime.NumCPU() * 4
// Hub is the central place to manage all websocket connections in the server.
// It handles websocket events and sends messages to individual
// user connections.
type Hub struct {
// connectionCount should be kept first.
// See https://github.com/mattermost/mattermost-server/pull/7281
connectionCount int64
platform *PlatformService
connectionIndex int
register chan *webConnRegisterMessage
unregister chan *WebConn
broadcast chan *model.WebSocketEvent
stop chan struct{}
didStop chan struct{}
invalidateUser chan string
activity chan *webConnActivityMessage
directMsg chan *webConnDirectMessage
explicitStop bool
checkRegistered chan *webConnSessionMessage
checkConn chan *webConnCheckMessage
connCount chan *webConnCountMessage
broadcastHooks map[string]BroadcastHook
// Hub-specific semaphore for limiting concurrent goroutines
hubSemaphore chan struct{}
}
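// Illustrative note (not in the original file): on 32-bit platforms,
// sync/atomic 64-bit operations require 8-byte alignment, and Go only
// guarantees that for the first word of an allocated struct. Keeping
// connectionCount first is what makes calls like these safe everywhere:
//
//	atomic.StoreInt64(&h.connectionCount, n) // first field, always aligned
//	cnt := atomic.LoadInt64(&h.connectionCount)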
// newWebHub creates a new Hub.
func newWebHub(ps *PlatformService) *Hub {
return &Hub{
platform: ps,
register: make(chan *webConnRegisterMessage),
unregister: make(chan *WebConn),
broadcast: make(chan *model.WebSocketEvent, broadcastQueueSize),
stop: make(chan struct{}),
didStop: make(chan struct{}),
invalidateUser: make(chan string),
activity: make(chan *webConnActivityMessage),
directMsg: make(chan *webConnDirectMessage),
checkRegistered: make(chan *webConnSessionMessage),
checkConn: make(chan *webConnCheckMessage),
connCount: make(chan *webConnCountMessage),
hubSemaphore: make(chan struct{}, hubSemaphoreCount),
}
}
// hubStart starts all the hubs.
func (ps *PlatformService) hubStart(broadcastHooks map[string]BroadcastHook) {
// After running some tests, we found using the same number of hubs
// as CPUs to be the ideal in terms of performance.
// https://github.com/mattermost/mattermost/pull/25798#issuecomment-1889386454
numberOfHubs := runtime.NumCPU()
ps.logger.Info("Starting websocket hubs", mlog.Int("number_of_hubs", numberOfHubs))
hubs := make([]*Hub, numberOfHubs)
for i := range numberOfHubs {
hubs[i] = newWebHub(ps)
hubs[i].connectionIndex = i
hubs[i].broadcastHooks = broadcastHooks
hubs[i].Start()
}
// Assigning to the hubs slice without any mutex is fine because it is only assigned once
// during the start of the program and always read from after that.
ps.hubs = hubs
}
func (ps *PlatformService) InvalidateCacheForWebhook(webhookID string) {
ps.Store.Webhook().InvalidateWebhookCache(webhookID)
}
// HubStop stops all the hubs.
func (ps *PlatformService) HubStop() {
ps.logger.Info("stopping websocket hub connections")
for _, hub := range ps.hubs {
hub.Stop()
}
}
// GetHubForUserId returns the hub for a given user id.
func (ps *PlatformService) GetHubForUserId(userID string) *Hub {
if len(ps.hubs) == 0 {
return nil
}
// TODO: check if caching the userID -> hub mapping
// is worth the memory tradeoff.
// https://mattermost.atlassian.net/browse/MM-26629.
var hash maphash.Hash
hash.SetSeed(ps.hashSeed)
_, err := hash.Write([]byte(userID))
if err != nil {
ps.logger.Error("Unable to write userID to hash", mlog.String("userID", userID), mlog.Err(err))
}
index := hash.Sum64() % uint64(len(ps.hubs))
return ps.hubs[int(index)]
}
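// Sharding sketch (illustrative; hubFor is a hypothetical helper and seed
// stands in for ps.hashSeed): a given userID always hashes to the same hub
// for a fixed seed, so all of a user's connections are handled by a single
// hub goroutine without locking:
//
//	var seed = maphash.MakeSeed()
//	func hubFor(userID string, numHubs int) int {
//		var h maphash.Hash
//		h.SetSeed(seed)
//		h.WriteString(userID)
//		return int(h.Sum64() % uint64(numHubs))
//	}
//	// hubFor("user-a", 8) is stable for the lifetime of this process.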
// HubRegister registers a connection to a hub.
func (ps *PlatformService) HubRegister(webConn *WebConn) error {
hub := ps.GetHubForUserId(webConn.UserId)
if hub != nil {
if metrics := ps.metricsIFace; metrics != nil {
metrics.IncrementWebSocketBroadcastUsersRegistered(strconv.Itoa(hub.connectionIndex), 1)
}
return hub.Register(webConn)
}
return nil
}
// HubUnregister unregisters a connection from a hub.
func (ps *PlatformService) HubUnregister(webConn *WebConn) {
hub := ps.GetHubForUserId(webConn.UserId)
if hub != nil {
if metrics := ps.metricsIFace; metrics != nil {
metrics.DecrementWebSocketBroadcastUsersRegistered(strconv.Itoa(hub.connectionIndex), 1)
}
hub.Unregister(webConn)
}
}
func (ps *PlatformService) InvalidateCacheForChannel(channel *model.Channel) {
ps.Store.Channel().InvalidateChannel(channel.Id)
teamID := channel.TeamId
if teamID == "" {
teamID = "dm"
}
ps.Store.Channel().InvalidateChannelByName(teamID, channel.Name)
}
func (ps *PlatformService) InvalidateCacheForChannelMembers(channelID string) {
ps.Store.User().InvalidateProfilesInChannelCache(channelID)
ps.Store.Channel().InvalidateMemberCount(channelID)
ps.Store.Channel().InvalidateGuestCount(channelID)
}
func (ps *PlatformService) InvalidateCacheForChannelMembersNotifyProps(channelID string) {
ps.Store.Channel().InvalidateCacheForChannelMembersNotifyProps(channelID)
}
func (ps *PlatformService) InvalidateCacheForChannelPosts(channelID string) {
ps.Store.Channel().InvalidatePinnedPostCount(channelID)
ps.Store.Post().InvalidateLastPostTimeCache(channelID)
}
func (ps *PlatformService) InvalidateCacheForUser(userID string) {
ps.InvalidateChannelCacheForUser(userID)
ps.Store.User().InvalidateProfileCacheForUser(userID)
}
func (ps *PlatformService) invalidateWebConnSessionCacheForUser(userID string) {
ps.invalidateWebConnSessionCacheForUserSkipClusterSend(userID)
if ps.clusterIFace != nil {
msg := &model.ClusterMessage{
Event: model.ClusterEventInvalidateWebConnCacheForUser,
SendType: model.ClusterSendBestEffort,
Data: []byte(userID),
}
ps.clusterIFace.SendClusterMessage(msg)
}
}
func (ps *PlatformService) InvalidateChannelCacheForUser(userID string) {
ps.Store.Channel().InvalidateAllChannelMembersForUser(userID)
ps.invalidateWebConnSessionCacheForUser(userID)
ps.Store.User().InvalidateProfilesInChannelCacheByUser(userID)
}
func (ps *PlatformService) InvalidateCacheForUserTeams(userID string) {
ps.invalidateWebConnSessionCacheForUser(userID)
// This method has its own cluster broadcast hidden inside localcachelayer.
ps.Store.Team().InvalidateAllTeamIdsForUser(userID)
}
// UpdateWebConnUserActivity sets the LastUserActivityAt of the hub for the given session.
func (ps *PlatformService) UpdateWebConnUserActivity(session model.Session, activityAt int64) {
hub := ps.GetHubForUserId(session.UserId)
if hub != nil {
hub.UpdateActivity(session.UserId, session.Token, activityAt)
}
}
// SessionIsRegistered determines if a specific session has been registered
func (ps *PlatformService) SessionIsRegistered(session model.Session) bool {
hub := ps.GetHubForUserId(session.UserId)
if hub != nil {
return hub.IsRegistered(session.UserId, session.Token)
}
return false
}
func (ps *PlatformService) CheckWebConn(userID, connectionID string, seqNum int64) *CheckConnResult {
if ps.Cluster() == nil || seqNum == 0 {
hub := ps.GetHubForUserId(userID)
if hub != nil {
return hub.CheckConn(userID, connectionID)
}
return nil
}
// HA needs extra care: check the other nodes first.
// If any node returns an active queue (aq) and/or dead queue (dq), use that.
// If all nodes return empty, fall back to the local case.
// We have to do this because a client might reconnect with an older
// sequence number to a node it was connected to before. Checking only
// the local queue would lead the server to believe there was no message
// loss when there actually was.
queueMap, err := ps.Cluster().GetWSQueues(userID, connectionID, seqNum)
if err != nil {
// If there is an error we do not have enough data to say anything reliably.
// Fall back to unreliable case.
ps.Log().Error("Error while getting websocket queues",
mlog.String("connection_id", connectionID),
mlog.String("user_id", userID),
mlog.Int("sequence_number", seqNum),
mlog.Err(err))
return nil
}
connRes := &CheckConnResult{
ConnectionID: connectionID,
UserID: userID,
}
for _, queues := range queueMap {
if queues == nil || queues.ActiveQ == nil {
continue
}
// parse the activeq
aq := make(chan model.WebSocketMessage, sendQueueSize)
for _, aqItem := range queues.ActiveQ {
item, err := ps.UnmarshalAQItem(aqItem)
if err != nil {
ps.Log().Error("Error while unmarshalling websocket message from active queue",
mlog.String("connection_id", connectionID),
mlog.String("user_id", userID),
mlog.Err(err))
return nil
}
// This cannot block because all send queues are at most sendQueueSize long.
// TODO: There could be a case where there's severe message loss, and to
// reliably get the messages, we need to get send queues from multiple nodes.
// We leave that case for Redis.
aq <- item
}
connRes.ActiveQueue = aq
connRes.ReuseCount = queues.ReuseCount
// parse the deadq
if queues.DeadQ != nil {
dq, dqPtr, err := ps.UnmarshalDQ(queues.DeadQ)
if err != nil {
ps.Log().Error("Error while unmarshalling websocket message from dead queue",
mlog.String("connection_id", connectionID),
mlog.String("user_id", userID),
mlog.Err(err))
return nil
}
// We check if at least one item has been written.
// Length of dq is always guaranteed to be deadQueueSize.
if dq[0] != nil {
connRes.DeadQueue = dq
connRes.DeadQueuePointer = dqPtr
}
}
return connRes
}
// Now check the local queue.
hub := ps.GetHubForUserId(userID)
if hub != nil {
return hub.CheckConn(userID, connectionID)
}
return nil
}
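// Reconnect sketch (illustrative): a caller resuming a client connection
// with a sequence number can use CheckConnResult to decide whether the
// buffered queues are reusable:
//
//	if res := ps.CheckWebConn(userID, connID, seqNum); res != nil {
//		// Reuse res.ActiveQueue/res.DeadQueue; the webconn write pump
//		// handles the case where seqNum is no longer in the dead queue.
//	} else {
//		// No reusable state: treat this as a fresh connection.
//	}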
// WebConnCountForUser returns the number of active websocket connections
// for a given userID.
func (ps *PlatformService) WebConnCountForUser(userID string) int {
hub := ps.GetHubForUserId(userID)
if hub != nil {
return hub.WebConnCountForUser(userID)
}
return 0
}
// Register registers a connection to the hub.
func (h *Hub) Register(webConn *WebConn) error {
wr := &webConnRegisterMessage{
conn: webConn,
err: make(chan error),
}
select {
case h.register <- wr:
return <-wr.err
case <-h.stop:
}
return nil
}
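// All of the synchronous hub methods (Register, Unregister, IsRegistered,
// CheckConn, WebConnCountForUser, ...) share one request/reply idiom:
// select between handing the request to the hub goroutine and the stop
// channel, so callers can never deadlock against a hub that is shutting
// down. The shape of the pattern, as used by WebConnCountForUser below:
//
//	select {
//	case h.connCount <- req: // hand the request to the hub goroutine
//		return <-req.result // block for its reply
//	case <-h.stop: // hub is stopping; give up
//	}
//	return 0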
// Unregister unregisters a connection from the hub.
func (h *Hub) Unregister(webConn *WebConn) {
select {
case h.unregister <- webConn:
case <-h.stop:
}
}
// IsRegistered determines whether a user's session has a registered
// connection on the hub.
func (h *Hub) IsRegistered(userID, sessionToken string) bool {
ws := &webConnSessionMessage{
userID: userID,
sessionToken: sessionToken,
isRegistered: make(chan bool),
}
select {
case h.checkRegistered <- ws:
return <-ws.isRegistered
case <-h.stop:
}
return false
}
func (h *Hub) CheckConn(userID, connectionID string) *CheckConnResult {
req := &webConnCheckMessage{
userID: userID,
connectionID: connectionID,
result: make(chan *CheckConnResult),
}
select {
case h.checkConn <- req:
return <-req.result
case <-h.stop:
}
return nil
}
func (h *Hub) WebConnCountForUser(userID string) int {
req := &webConnCountMessage{
userID: userID,
result: make(chan int),
}
select {
case h.connCount <- req:
return <-req.result
case <-h.stop:
}
return 0
}
// Broadcast broadcasts the message to all connections in the hub.
func (h *Hub) Broadcast(message *model.WebSocketEvent) {
// XXX: The hub nil check is because of the way we set up our tests. We call
// `app.NewServer()`, which returns a server, and only after that do we call
// `wsapi.Init()` to initialize the hub. But the `NewServer` call itself
// proceeds to happily broadcast some messages. This needs to be fixed once
// the wsapi cyclic dependency with server/app goes away. And possibly, we
// can look into doing the hub initialization inside NewServer itself.
if h != nil && message != nil {
if metrics := h.platform.metricsIFace; metrics != nil {
metrics.IncrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1)
}
select {
case h.broadcast <- message:
case <-h.stop:
}
}
}
// InvalidateUser invalidates the cache for the given user.
func (h *Hub) InvalidateUser(userID string) {
select {
case h.invalidateUser <- userID:
case <-h.stop:
}
}
// UpdateActivity sets the LastUserActivityAt field for the connection
// of the user.
func (h *Hub) UpdateActivity(userID, sessionToken string, activityAt int64) {
select {
case h.activity <- &webConnActivityMessage{
userID: userID,
sessionToken: sessionToken,
activityAt: activityAt,
}:
case <-h.stop:
}
}
// SendMessage sends the given message to the given connection.
func (h *Hub) SendMessage(conn *WebConn, msg model.WebSocketMessage) {
select {
case h.directMsg <- &webConnDirectMessage{
conn: conn,
msg: msg,
}:
case <-h.stop:
}
}
// ProcessAsync executes a function on a new goroutine, blocking the caller
// until a hub semaphore slot is free so that concurrency stays bounded.
func (h *Hub) ProcessAsync(f func()) {
h.hubSemaphore <- struct{}{}
go func() {
defer func() {
<-h.hubSemaphore
}()
// Watchdog: run f in its own goroutine so that a stuck task can be
// detected instead of silently holding a semaphore slot forever.
done := make(chan struct{})
go func() {
defer close(done)
f()
}()
select {
case <-done:
// Function completed normally.
case <-time.After(5 * time.Second):
// f is still running at this point; we only log here, and the
// deferred receive above releases the semaphore slot.
h.platform.Log().Warn("ProcessAsync function timed out after 5 seconds")
}
}()
}
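// Usage sketch (illustrative; doExpensiveLookup is a hypothetical helper):
// ProcessAsync bounds the number of concurrent background tasks per hub,
// so bursty work such as cluster lookups on unregister cannot spawn an
// unbounded number of goroutines:
//
//	h.ProcessAsync(func() {
//		// At most hubSemaphoreCount of these run at once per hub; callers
//		// beyond that block in ProcessAsync until a slot frees up.
//		doExpensiveLookup()
//	})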
// Stop stops the hub.
func (h *Hub) Stop() {
close(h.stop)
<-h.didStop
// Acquire every semaphore slot so that all in-flight ProcessAsync
// goroutines have finished before shutting down.
for range hubSemaphoreCount {
h.hubSemaphore <- struct{}{}
}
}
// Start starts the hub.
func (h *Hub) Start() {
var doStart func()
var doRecoverableStart func()
var doRecover func()
doStart = func() {
mlog.Debug("Hub is starting", mlog.Int("index", h.connectionIndex))
ticker := time.NewTicker(inactiveConnReaperInterval)
defer ticker.Stop()
connIndex := newHubConnectionIndex(inactiveConnReaperInterval,
h.platform.Store,
h.platform.logger,
*h.platform.Config().ServiceSettings.EnableWebHubChannelIteration,
)
for {
select {
case webSessionMessage := <-h.checkRegistered:
var isRegistered bool
for conn := range connIndex.ForUser(webSessionMessage.userID) {
if !conn.Active.Load() {
continue
}
if conn.GetSessionToken() == webSessionMessage.sessionToken {
isRegistered = true
}
}
webSessionMessage.isRegistered <- isRegistered
case req := <-h.checkConn:
var res *CheckConnResult
conn := connIndex.RemoveInactiveByConnectionID(req.userID, req.connectionID)
if conn != nil {
res = &CheckConnResult{
ConnectionID: req.connectionID,
UserID: req.userID,
ActiveQueue: conn.send,
DeadQueue: conn.deadQueue,
DeadQueuePointer: conn.deadQueuePointer,
ReuseCount: conn.reuseCount + 1,
}
}
req.result <- res
case req := <-h.connCount:
req.result <- connIndex.ForUserActiveCount(req.userID)
case <-ticker.C:
connIndex.RemoveInactiveConnections()
case webConnReg := <-h.register:
// Mark the current one as active.
// There is no need to check whether it was inactive;
// we need to make it active anyway.
webConnReg.conn.Active.Store(true)
err := connIndex.Add(webConnReg.conn)
if err != nil {
webConnReg.err <- err
continue
}
atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive()))
if webConnReg.conn.IsBasicAuthenticated() && webConnReg.conn.reuseCount == 0 {
// The hello message should only be sent when the reuseCount is 0,
// i.e. on server restart, long timeout, or a fresh connection.
// The case of a seq number not found in the dead queue is handled
// by the webconn write pump.
webConnReg.conn.send <- webConnReg.conn.createHelloMessage()
}
webConnReg.err <- nil
case webConn := <-h.unregister:
// If it was already removed (e.g. because its queue filled up),
// removing it again is a no-op. Otherwise, mark it inactive.
webConn.Active.Store(false)
atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive()))
if webConn.UserId == "" {
continue
}
conns := connIndex.ForUser(webConn.UserId)
// areAllInactive also returns true if there are no connections,
// which is intentional.
if areAllInactive(conns) {
userID := webConn.UserId
h.ProcessAsync(func() {
// If this is an HA setup, get count for this user
// from other nodes.
var clusterCnt int
var appErr *model.AppError
if h.platform.Cluster() != nil {
clusterCnt, appErr = h.platform.Cluster().WebConnCountForUser(userID)
}
if appErr != nil {
mlog.Error("Error in trying to get the webconn count from cluster", mlog.Err(appErr))
// We take a conservative approach: when there's an
// error, we do not set the status to offline rather
// than risk marking a user offline incorrectly.
return
}
// Only set to offline if there are no
// active connections in other nodes as well.
if clusterCnt == 0 {
h.platform.QueueSetStatusOffline(userID, false)
}
})
continue
}
var latestActivity int64
for conn := range conns {
if !conn.Active.Load() {
continue
}
if conn.lastUserActivityAt > latestActivity {
latestActivity = conn.lastUserActivityAt
}
}
if h.platform.isUserAway(latestActivity) {
userID := webConn.UserId
h.platform.Go(func() {
h.platform.SetStatusLastActivityAt(userID, latestActivity)
})
}
case userID := <-h.invalidateUser:
for webConn := range connIndex.ForUser(userID) {
webConn.InvalidateCache()
}
if !*h.platform.Config().ServiceSettings.EnableWebHubChannelIteration {
continue
}
err := connIndex.InvalidateCMCacheForUser(userID)
if err != nil {
h.platform.Log().Error("Error while invalidating channel member cache", mlog.String("user_id", userID), mlog.Err(err))
for webConn := range connIndex.ForUser(userID) {
closeAndRemoveConn(connIndex, webConn)
}
}
case activity := <-h.activity:
for webConn := range connIndex.ForUser(activity.userID) {
if !webConn.Active.Load() {
continue
}
if webConn.GetSessionToken() == activity.sessionToken {
webConn.lastUserActivityAt = activity.activityAt
}
}
case directMsg := <-h.directMsg:
if !connIndex.Has(directMsg.conn) {
continue
}
select {
case directMsg.conn.send <- directMsg.msg:
default:
// Don't log the error if it's an inactive connection.
if directMsg.conn.Active.Load() {
mlog.Error("webhub.broadcast: cannot send, closing websocket for user",
mlog.String("user_id", directMsg.conn.UserId),
mlog.String("conn_id", directMsg.conn.GetConnectionID()))
}
closeAndRemoveConn(connIndex, directMsg.conn)
}
case msg := <-h.broadcast:
if metrics := h.platform.metricsIFace; metrics != nil {
metrics.DecrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1)
}
// Remove the broadcast hook information before precomputing the JSON so that it isn't included in the payload.
msg, broadcastHooks, broadcastHookArgs := msg.WithoutBroadcastHooks()
msg = msg.PrecomputeJSON()
broadcast := func(webConn *WebConn) {
if !connIndex.Has(webConn) {
return
}
if webConn.ShouldSendEvent(msg) {
select {
case webConn.send <- h.runBroadcastHooks(msg, webConn, broadcastHooks, broadcastHookArgs):
default:
// Don't log the error if it's an inactive connection.
if webConn.Active.Load() {
mlog.Error("webhub.broadcast: cannot send, closing websocket for user",
mlog.String("user_id", webConn.UserId),
mlog.String("conn_id", webConn.GetConnectionID()))
}
closeAndRemoveConn(connIndex, webConn)
}
}
}
// Quick return for a single connection.
if webConn := connIndex.ForConnection(msg.GetBroadcast().ConnectionId); webConn != nil {
broadcast(webConn)
continue
}
fastIteration := *h.platform.Config().ServiceSettings.EnableWebHubChannelIteration
var targetConns iter.Seq[*WebConn]
if userID := msg.GetBroadcast().UserId; userID != "" {
targetConns = connIndex.ForUser(userID)
} else if channelID := msg.GetBroadcast().ChannelId; channelID != "" && fastIteration {
targetConns = connIndex.ForChannel(channelID)
}
if targetConns != nil {
for webConn := range targetConns {
broadcast(webConn)
}
continue
}
// There are multiple hubs in the system, so while we support both
// channel-based iteration and the old method, channel-scoped events are
// sent to multiple hubs, and only one of them has the targetConns.
// Therefore, we must stop here if channel-based iteration is enabled
// and this is a channel-scoped event.
if channelID := msg.GetBroadcast().ChannelId; channelID != "" && fastIteration {
continue
}
for webConn := range connIndex.All() {
broadcast(webConn)
}
case <-h.stop:
for webConn := range connIndex.All() {
webConn.Close()
h.platform.SetStatusOffline(webConn.UserId, false, false)
}
h.explicitStop = true
close(h.didStop)
return
}
}
}
doRecoverableStart = func() {
defer doRecover()
doStart()
}
doRecover = func() {
if !h.explicitStop {
if r := recover(); r != nil {
mlog.Error("Recovering from Hub panic.", mlog.Any("panic", r))
} else {
mlog.Error("Webhub stopped unexpectedly. Recovering.")
}
mlog.Error(string(debug.Stack()))
go doRecoverableStart()
}
}
go doRecoverableStart()
}
// areAllInactive reports whether all of the connections
// are inactive. It also returns true if there are no
// connections, which is intentional.
func areAllInactive(conns iter.Seq[*WebConn]) bool {
for conn := range conns {
if conn.Active.Load() {
return false
}
}
return true
}
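// Iterator sketch (illustrative; anyActive is a hypothetical helper):
// ForUser and ForChannel return standard Go 1.23 iterators, so predicates
// over a user's connections are plain range loops, mirroring areAllInactive:
//
//	func anyActive(conns iter.Seq[*WebConn]) bool {
//		for conn := range conns {
//			if conn.Active.Load() {
//				return true
//			}
//		}
//		return false
//	}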
// closeAndRemoveConn closes the send channel which will close the
// websocket connection, and then it removes the webConn from the conn index.
func closeAndRemoveConn(connIndex *hubConnectionIndex, conn *WebConn) {
close(conn.send)
connIndex.Remove(conn)
}
// hubConnectionIndex provides fast addition, removal, and iteration of web connections.
// It supports four operations that need to be very fast:
// - check whether a connection exists.
// - get all connections for a given userID.
// - get all connections for a given channelID.
// - get all connections.
type hubConnectionIndex struct {
// byUserId stores the set of connections for a given userID
byUserId map[string]map[*WebConn]struct{}
// byChannelID stores the set of connections for a given channelID
byChannelID map[string]map[*WebConn]struct{}
// byConnection serves the dual purpose of storing each connection's
// channelIDs and of enumerating all connections.
byConnection map[*WebConn][]string
byConnectionId map[string]*WebConn
// staleThreshold is the limit beyond which inactive connections
// will be deleted.
staleThreshold time.Duration
fastIteration bool
store store.Store
logger mlog.LoggerIFace
}
func newHubConnectionIndex(interval time.Duration,
store store.Store,
logger mlog.LoggerIFace,
fastIteration bool,
) *hubConnectionIndex {
return &hubConnectionIndex{
byUserId: make(map[string]map[*WebConn]struct{}),
byChannelID: make(map[string]map[*WebConn]struct{}),
byConnection: make(map[*WebConn][]string),
byConnectionId: make(map[string]*WebConn),
staleThreshold: interval,
store: store,
logger: logger,
fastIteration: fastIteration,
}
}
func (i *hubConnectionIndex) Add(wc *WebConn) error {
var channelIDs []string
if i.fastIteration {
cm, err := i.store.Channel().GetAllChannelMembersForUser(request.EmptyContext(i.logger), wc.UserId, false, false)
if err != nil {
return fmt.Errorf("error getChannelMembersForUser: %v", err)
}
// Store channel IDs and add to byChannelID
channelIDs = make([]string, 0, len(cm))
for chID := range cm {
channelIDs = append(channelIDs, chID)
// Initialize the channel's map if it doesn't exist
if _, ok := i.byChannelID[chID]; !ok {
i.byChannelID[chID] = make(map[*WebConn]struct{})
}
i.byChannelID[chID][wc] = struct{}{}
}
}
// Initialize the user's map if it doesn't exist
if _, ok := i.byUserId[wc.UserId]; !ok {
i.byUserId[wc.UserId] = make(map[*WebConn]struct{})
}
i.byUserId[wc.UserId][wc] = struct{}{}
i.byConnection[wc] = channelIDs
i.byConnectionId[wc.GetConnectionID()] = wc
return nil
}
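// Consistency sketch (illustrative; st and logger are hypothetical values):
// after Add, a connection is reachable through every map, and Remove can
// undo all entries using byConnection alone:
//
//	idx := newHubConnectionIndex(time.Minute, st, logger, true)
//	_ = idx.Add(wc)   // joins byUserId, byChannelID, byConnection, byConnectionId
//	ok := idx.Has(wc) // true: keyed by *WebConn identity
//	same := idx.ForConnection(wc.GetConnectionID())
//	idx.Remove(wc)    // byConnection[wc] drives cleanup of all maps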
func (i *hubConnectionIndex) Remove(wc *WebConn) {
channelIDs, ok := i.byConnection[wc]
if !ok {
return
}
// Remove from byUserId
if userConns, ok := i.byUserId[wc.UserId]; ok {
delete(userConns, wc)
}
if i.fastIteration {
// Remove from byChannelID for each channel
for _, chID := range channelIDs {
if channelConns, ok := i.byChannelID[chID]; ok {
delete(channelConns, wc)
}
}
}
delete(i.byConnection, wc)
delete(i.byConnectionId, wc.GetConnectionID())
}
func (i *hubConnectionIndex) InvalidateCMCacheForUser(userID string) error {
// We make this query first to fail fast in case of an error.
cm, err := i.store.Channel().GetAllChannelMembersForUser(request.EmptyContext(i.logger), userID, false, false)
if err != nil {
return err
}
// Get all connections for this user
conns := i.ForUser(userID)
// Remove all user connections from existing channels
for conn := range conns {
if channelIDs, ok := i.byConnection[conn]; ok {
// Remove from old channels
for _, chID := range channelIDs {
if channelConns, ok := i.byChannelID[chID]; ok {
delete(channelConns, conn)
}
}
}
}
// Add connections to new channels
for conn := range conns {
newChannelIDs := make([]string, 0, len(cm))
for chID := range cm {
newChannelIDs = append(newChannelIDs, chID)
// Initialize channel map if needed
if _, ok := i.byChannelID[chID]; !ok {
i.byChannelID[chID] = make(map[*WebConn]struct{})
}
i.byChannelID[chID][conn] = struct{}{}
}
// Update connection metadata
if _, ok := i.byConnection[conn]; ok {
i.byConnection[conn] = newChannelIDs
}
}
return nil
}
func (i *hubConnectionIndex) Has(wc *WebConn) bool {
_, ok := i.byConnection[wc]
return ok
}
// ForUser returns all connections for a user ID.
func (i *hubConnectionIndex) ForUser(id string) iter.Seq[*WebConn] {
return maps.Keys(i.byUserId[id])
}
// ForChannel returns all connections for a channelID.
func (i *hubConnectionIndex) ForChannel(channelID string) iter.Seq[*WebConn] {
return maps.Keys(i.byChannelID[channelID])
}
// ForUserActiveCount returns the number of active connections for a userID
func (i *hubConnectionIndex) ForUserActiveCount(id string) int {
cnt := 0
for conn := range i.ForUser(id) {
if conn.Active.Load() {
cnt++
}
}
return cnt
}
// ForConnection returns the connection from its ID.
func (i *hubConnectionIndex) ForConnection(id string) *WebConn {
return i.byConnectionId[id]
}
// All returns the full webConn index.
func (i *hubConnectionIndex) All() map[*WebConn][]string {
return i.byConnection
}
// RemoveInactiveByConnectionID removes an inactive connection for the given
// userID and connectionID.
func (i *hubConnectionIndex) RemoveInactiveByConnectionID(userID, connectionID string) *WebConn {
// To handle empty sessions.
if userID == "" {
return nil
}
for conn := range i.ForUser(userID) {
if conn.GetConnectionID() == connectionID && !conn.Active.Load() {
i.Remove(conn)
return conn
}
}
return nil
}
// RemoveInactiveConnections removes all inactive connections whose
// time since lastUserActivityAt exceeds staleThreshold.
func (i *hubConnectionIndex) RemoveInactiveConnections() {
now := model.GetMillis()
for conn := range i.byConnection {
if !conn.Active.Load() && now-conn.lastUserActivityAt > i.staleThreshold.Milliseconds() {
i.Remove(conn)
}
}
}
// AllActive returns the number of active connections.
// This is only called during register/unregister, so we
// can take a bit of a perf hit here.
func (i *hubConnectionIndex) AllActive() int {
cnt := 0
for conn := range i.byConnection {
if conn.Active.Load() {
cnt++
}
}
return cnt
}
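// End-to-end sketch (illustrative): the surface of this file is typically
// exercised in this order over a server's lifetime:
//
//	ps.hubStart(hooks)        // one hub goroutine per CPU
//	err := ps.HubRegister(wc) // routed via GetHubForUserId(wc.UserId)
//	hub := ps.GetHubForUserId(userID)
//	hub.Broadcast(event)      // fans out through the hub's broadcast channel
//	ps.HubUnregister(wc)      // may queue a status update if no conns remain
//	ps.HubStop()              // closes every connection and stops the hubs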