// Copyright (c) 2024 Mattermost Community Enterprise // Outgoing OAuth Connection Implementation package outgoing_oauth_connection import ( "context" "encoding/json" "net/http" "net/url" "strings" "sync" "time" "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/public/shared/request" "github.com/mattermost/mattermost/server/v8/channels/store" ) // OutgoingOAuthConnectionConfig holds configuration for the outgoing OAuth connection interface type OutgoingOAuthConnectionConfig struct { Store store.Store Config func() *model.Config Logger mlog.LoggerIFace } // OutgoingOAuthConnectionImpl implements the OutgoingOAuthConnectionInterface type OutgoingOAuthConnectionImpl struct { store store.Store config func() *model.Config logger mlog.LoggerIFace // In-memory storage for connections (in production, this would use the store) connections map[string]*model.OutgoingOAuthConnection // Token cache tokenCache map[string]*cachedToken tokenCacheMutex sync.RWMutex mutex sync.RWMutex } type cachedToken struct { token *model.OutgoingOAuthConnectionToken expiresAt time.Time } // NewOutgoingOAuthConnectionInterface creates a new outgoing OAuth connection interface func NewOutgoingOAuthConnectionInterface(cfg *OutgoingOAuthConnectionConfig) *OutgoingOAuthConnectionImpl { return &OutgoingOAuthConnectionImpl{ store: cfg.Store, config: cfg.Config, logger: cfg.Logger, connections: make(map[string]*model.OutgoingOAuthConnection), tokenCache: make(map[string]*cachedToken), } } // DeleteConnection deletes an outgoing OAuth connection func (o *OutgoingOAuthConnectionImpl) DeleteConnection(rctx request.CTX, id string) *model.AppError { o.mutex.Lock() defer o.mutex.Unlock() if _, ok := o.connections[id]; !ok { return model.NewAppError("DeleteConnection", "outgoing_oauth.connection_not_found", map[string]any{"Id": id}, "", http.StatusNotFound) } delete(o.connections, id) // Clear token cache for this connection o.tokenCacheMutex.Lock() delete(o.tokenCache, id) o.tokenCacheMutex.Unlock() o.logger.Info("Deleted outgoing OAuth connection", mlog.String("connection_id", id), ) return nil } // GetConnection retrieves an outgoing OAuth connection by ID func (o *OutgoingOAuthConnectionImpl) GetConnection(rctx request.CTX, id string) (*model.OutgoingOAuthConnection, *model.AppError) { o.mutex.RLock() defer o.mutex.RUnlock() conn, ok := o.connections[id] if !ok { return nil, model.NewAppError("GetConnection", "outgoing_oauth.connection_not_found", map[string]any{"Id": id}, "", http.StatusNotFound) } // Return a copy to prevent external modification connCopy := *conn return &connCopy, nil } // GetConnections retrieves outgoing OAuth connections based on filters func (o *OutgoingOAuthConnectionImpl) GetConnections(rctx request.CTX, filters model.OutgoingOAuthConnectionGetConnectionsFilter) ([]*model.OutgoingOAuthConnection, *model.AppError) { o.mutex.RLock() defer o.mutex.RUnlock() var result []*model.OutgoingOAuthConnection startFound := filters.OffsetId == "" for _, conn := range o.connections { // Handle offset if !startFound { if conn.Id == filters.OffsetId { startFound = true } continue } // Filter by audience if specified if filters.Audience != "" { found := false for _, audience := range conn.Audiences { if audience == filters.Audience || strings.HasPrefix(filters.Audience, audience) { found = true break } } if !found { continue } } // Return a copy connCopy := *conn result = append(result, &connCopy) // Check limit if filters.Limit > 0 && len(result) >= filters.Limit { break } } return result, nil } // SaveConnection saves a new outgoing OAuth connection func (o *OutgoingOAuthConnectionImpl) SaveConnection(rctx request.CTX, conn *model.OutgoingOAuthConnection) (*model.OutgoingOAuthConnection, *model.AppError) { o.mutex.Lock() defer o.mutex.Unlock() if conn.Id == "" { conn.Id = model.NewId() } if _, exists := o.connections[conn.Id]; exists { return nil, model.NewAppError("SaveConnection", "outgoing_oauth.connection_exists", map[string]any{"Id": conn.Id}, "", http.StatusConflict) } now := model.GetMillis() conn.CreateAt = now conn.UpdateAt = now // Validate connection if err := o.validateConnection(conn); err != nil { return nil, err } // Store a copy connCopy := *conn o.connections[conn.Id] = &connCopy o.logger.Info("Saved outgoing OAuth connection", mlog.String("connection_id", conn.Id), mlog.String("name", conn.Name), ) // Return sanitized copy result := connCopy o.SanitizeConnection(&result) return &result, nil } // UpdateConnection updates an existing outgoing OAuth connection func (o *OutgoingOAuthConnectionImpl) UpdateConnection(rctx request.CTX, conn *model.OutgoingOAuthConnection) (*model.OutgoingOAuthConnection, *model.AppError) { o.mutex.Lock() defer o.mutex.Unlock() existing, ok := o.connections[conn.Id] if !ok { return nil, model.NewAppError("UpdateConnection", "outgoing_oauth.connection_not_found", map[string]any{"Id": conn.Id}, "", http.StatusNotFound) } // Preserve original creation info conn.CreateAt = existing.CreateAt conn.CreatorId = existing.CreatorId conn.UpdateAt = model.GetMillis() // Validate connection if err := o.validateConnection(conn); err != nil { return nil, err } // Store a copy connCopy := *conn o.connections[conn.Id] = &connCopy // Clear token cache for this connection o.tokenCacheMutex.Lock() delete(o.tokenCache, conn.Id) o.tokenCacheMutex.Unlock() o.logger.Info("Updated outgoing OAuth connection", mlog.String("connection_id", conn.Id), mlog.String("name", conn.Name), ) // Return sanitized copy result := connCopy o.SanitizeConnection(&result) return &result, nil } // SanitizeConnection removes sensitive data from a connection func (o *OutgoingOAuthConnectionImpl) SanitizeConnection(conn *model.OutgoingOAuthConnection) { conn.ClientSecret = "" conn.CredentialsPassword = nil } // SanitizeConnections removes sensitive data from multiple connections func (o *OutgoingOAuthConnectionImpl) SanitizeConnections(conns []*model.OutgoingOAuthConnection) { for _, conn := range conns { o.SanitizeConnection(conn) } } // GetConnectionForAudience finds a connection that matches the given URL func (o *OutgoingOAuthConnectionImpl) GetConnectionForAudience(rctx request.CTX, targetURL string) (*model.OutgoingOAuthConnection, *model.AppError) { o.mutex.RLock() defer o.mutex.RUnlock() parsedURL, err := url.Parse(targetURL) if err != nil { return nil, model.NewAppError("GetConnectionForAudience", "outgoing_oauth.invalid_url", nil, err.Error(), http.StatusBadRequest) } // Normalize URL for comparison normalizedURL := parsedURL.Scheme + "://" + parsedURL.Host if parsedURL.Path != "" { normalizedURL += parsedURL.Path } for _, conn := range o.connections { for _, audience := range conn.Audiences { // Check if the URL matches or starts with the audience if normalizedURL == audience || strings.HasPrefix(normalizedURL, audience) { connCopy := *conn return &connCopy, nil } } } return nil, model.NewAppError("GetConnectionForAudience", "outgoing_oauth.no_matching_connection", map[string]any{"URL": targetURL}, "", http.StatusNotFound) } // RetrieveTokenForConnection retrieves an OAuth token for the given connection func (o *OutgoingOAuthConnectionImpl) RetrieveTokenForConnection(rctx request.CTX, conn *model.OutgoingOAuthConnection) (*model.OutgoingOAuthConnectionToken, *model.AppError) { // Check cache first o.tokenCacheMutex.RLock() cached, ok := o.tokenCache[conn.Id] o.tokenCacheMutex.RUnlock() if ok && time.Now().Before(cached.expiresAt) { return cached.token, nil } // Need to fetch a new token token, expiresIn, err := o.fetchToken(rctx, conn) if err != nil { return nil, err } // Cache the token o.tokenCacheMutex.Lock() o.tokenCache[conn.Id] = &cachedToken{ token: token, expiresAt: time.Now().Add(time.Duration(expiresIn-60) * time.Second), // Expire 60 seconds early } o.tokenCacheMutex.Unlock() return token, nil } func (o *OutgoingOAuthConnectionImpl) validateConnection(conn *model.OutgoingOAuthConnection) *model.AppError { if conn.Name == "" { return model.NewAppError("validateConnection", "outgoing_oauth.name_required", nil, "", http.StatusBadRequest) } if conn.OAuthTokenURL == "" { return model.NewAppError("validateConnection", "outgoing_oauth.token_url_required", nil, "", http.StatusBadRequest) } if len(conn.Audiences) == 0 { return model.NewAppError("validateConnection", "outgoing_oauth.audiences_required", nil, "", http.StatusBadRequest) } switch conn.GrantType { case model.OutgoingOAuthConnectionGrantTypeClientCredentials: if conn.ClientId == "" || conn.ClientSecret == "" { return model.NewAppError("validateConnection", "outgoing_oauth.client_credentials_required", nil, "", http.StatusBadRequest) } case model.OutgoingOAuthConnectionGrantTypePassword: if conn.CredentialsUsername == nil || conn.CredentialsPassword == nil { return model.NewAppError("validateConnection", "outgoing_oauth.password_credentials_required", nil, "", http.StatusBadRequest) } default: return model.NewAppError("validateConnection", "outgoing_oauth.invalid_grant_type", map[string]any{"GrantType": conn.GrantType}, "", http.StatusBadRequest) } return nil } func (o *OutgoingOAuthConnectionImpl) fetchToken(rctx request.CTX, conn *model.OutgoingOAuthConnection) (*model.OutgoingOAuthConnectionToken, int64, *model.AppError) { // Build token request data := url.Values{} switch conn.GrantType { case model.OutgoingOAuthConnectionGrantTypeClientCredentials: data.Set("grant_type", "client_credentials") data.Set("client_id", conn.ClientId) data.Set("client_secret", conn.ClientSecret) case model.OutgoingOAuthConnectionGrantTypePassword: data.Set("grant_type", "password") data.Set("client_id", conn.ClientId) data.Set("client_secret", conn.ClientSecret) if conn.CredentialsUsername != nil { data.Set("username", *conn.CredentialsUsername) } if conn.CredentialsPassword != nil { data.Set("password", *conn.CredentialsPassword) } } // Make request ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, "POST", conn.OAuthTokenURL, strings.NewReader(data.Encode())) if err != nil { return nil, 0, model.NewAppError("fetchToken", "outgoing_oauth.request_failed", nil, err.Error(), http.StatusInternalServerError) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { return nil, 0, model.NewAppError("fetchToken", "outgoing_oauth.request_failed", nil, err.Error(), http.StatusInternalServerError) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, 0, model.NewAppError("fetchToken", "outgoing_oauth.token_request_failed", map[string]any{"Status": resp.StatusCode}, "", http.StatusInternalServerError) } // Parse response var tokenResponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` ExpiresIn int64 `json:"expires_in"` } if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { return nil, 0, model.NewAppError("fetchToken", "outgoing_oauth.parse_token_failed", nil, err.Error(), http.StatusInternalServerError) } token := &model.OutgoingOAuthConnectionToken{ AccessToken: tokenResponse.AccessToken, TokenType: tokenResponse.TokenType, } // Default expires_in if not provided if tokenResponse.ExpiresIn == 0 { tokenResponse.ExpiresIn = 3600 // 1 hour default } o.logger.Debug("Retrieved OAuth token for connection", mlog.String("connection_id", conn.Id), mlog.Int("expires_in", int(tokenResponse.ExpiresIn)), ) return token, tokenResponse.ExpiresIn, nil }