mattermost-community-enterp.../vendor/github.com/mattermost-community/enterprise/saml/saml.go
Claude ec1f89217a Merge: Complete Mattermost Server with Community Enterprise
Full Mattermost server source with integrated Community Enterprise features.
Includes vendor directory for offline/air-gapped builds.

Structure:
- enterprise-impl/: Enterprise feature implementations
- enterprise-community/: Init files that register implementations
- enterprise/: Bridge imports (community_imports.go)
- vendor/: All dependencies for offline builds

Build (online):
  go build ./cmd/mattermost

Build (offline/air-gapped):
  go build -mod=vendor ./cmd/mattermost

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-17 23:59:07 +09:00

582 lines
18 KiB
Go

// Copyright (c) 2024 Mattermost Community Enterprise

// Package saml implements SAML 2.0 single sign-on (SSO) for the Mattermost
// Community Enterprise server.
package saml
import (
	"bytes"
	"compress/flate"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"encoding/xml"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"text/template"
	"time"

	saml2 "github.com/mattermost/gosaml2"
	"github.com/mattermost/gosaml2/types"
	dsig "github.com/russellhaering/goxmldsig"

	"github.com/mattermost/mattermost/server/public/model"
	"github.com/mattermost/mattermost/server/public/shared/mlog"
	"github.com/mattermost/mattermost/server/public/shared/request"
)
// SamlConfig bundles the dependencies handed to NewSamlInterface.
type SamlConfig struct {
	// Config returns the current server configuration. SamlImpl calls it on
	// every operation rather than caching the result.
	Config func() *model.Config
	// Logger receives SAML diagnostics.
	Logger mlog.LoggerIFace
	// ConfigDir is the directory that relative certificate and key file
	// paths in SamlSettings are resolved against.
	ConfigDir string
}
// SamlImpl implements the SamlInterface on top of gosaml2.
type SamlImpl struct {
	// Dependencies copied from SamlConfig by NewSamlInterface.
	config    func() *model.Config
	logger    mlog.LoggerIFace
	configDir string

	// Provider state. ConfigureSP writes these under mutex.Lock; the
	// request-handling methods read them under mutex.RLock.
	sp           *saml2.SAMLServiceProvider
	idpMetadata  *types.EntityDescriptor // NOTE(review): not assigned anywhere in this file
	spPrivateKey *rsa.PrivateKey
	spCert       *x509.Certificate
	idpCert      *x509.Certificate

	mutex sync.RWMutex
}
// NewSamlInterface returns a SamlImpl wired to the dependencies in cfg.
// No Service Provider exists yet: BuildRequest and DoLogin report
// "not configured" until ConfigureSP has run successfully.
func NewSamlInterface(cfg *SamlConfig) *SamlImpl {
	impl := new(SamlImpl)
	impl.config, impl.logger, impl.configDir = cfg.Config, cfg.Logger, cfg.ConfigDir
	return impl
}
// ConfigureSP configures the SAML Service Provider
func (s *SamlImpl) ConfigureSP(rctx request.CTX) error {
s.mutex.Lock()
defer s.mutex.Unlock()
cfg := s.config()
samlSettings := cfg.SamlSettings
if samlSettings.Enable == nil || !*samlSettings.Enable {
return nil
}
// Load IdP certificate
idpCert, err := s.loadCertificate(s.getFilePath(*samlSettings.IdpCertificateFile))
if err != nil {
return fmt.Errorf("failed to load IdP certificate: %w", err)
}
s.idpCert = idpCert
// Load SP private key if encryption/signing is enabled
if (samlSettings.Encrypt != nil && *samlSettings.Encrypt) ||
(samlSettings.SignRequest != nil && *samlSettings.SignRequest) {
privateKey, err := s.loadPrivateKey(s.getFilePath(*samlSettings.PrivateKeyFile))
if err != nil {
return fmt.Errorf("failed to load SP private key: %w", err)
}
s.spPrivateKey = privateKey
spCert, err := s.loadCertificate(s.getFilePath(*samlSettings.PublicCertificateFile))
if err != nil {
return fmt.Errorf("failed to load SP certificate: %w", err)
}
s.spCert = spCert
}
// Create certificate store for IdP verification
certStore := dsig.MemoryX509CertificateStore{
Roots: []*x509.Certificate{idpCert},
}
// Configure SP
sp := &saml2.SAMLServiceProvider{
IdentityProviderSSOURL: *samlSettings.IdpURL,
IdentityProviderIssuer: *samlSettings.IdpDescriptorURL,
ServiceProviderIssuer: *samlSettings.ServiceProviderIdentifier,
AssertionConsumerServiceURL: *samlSettings.AssertionConsumerServiceURL,
IDPCertificateStore: &certStore,
SkipSignatureValidation: samlSettings.Verify == nil || !*samlSettings.Verify,
}
// Configure signing
if samlSettings.SignRequest != nil && *samlSettings.SignRequest && s.spPrivateKey != nil {
sp.SignAuthnRequests = true
sp.SPKeyStore = dsig.TLSCertKeyStore(tls.Certificate{
Certificate: [][]byte{s.spCert.Raw},
PrivateKey: s.spPrivateKey,
})
// Set signature algorithm
if samlSettings.SignatureAlgorithm != nil {
switch *samlSettings.SignatureAlgorithm {
case model.SamlSettingsSignatureAlgorithmSha256:
sp.SignAuthnRequestsAlgorithm = dsig.RSASHA256SignatureMethod
case model.SamlSettingsSignatureAlgorithmSha512:
sp.SignAuthnRequestsAlgorithm = dsig.RSASHA512SignatureMethod
default:
sp.SignAuthnRequestsAlgorithm = dsig.RSASHA1SignatureMethod
}
}
}
s.sp = sp
s.logger.Info("SAML Service Provider configured successfully",
mlog.String("issuer", sp.ServiceProviderIssuer),
mlog.String("acs_url", sp.AssertionConsumerServiceURL),
)
return nil
}
// BuildRequest builds a SAML AuthnRequest for the HTTP-Redirect binding.
//
// The request XML is DEFLATE-compressed and base64-encoded. It is set as the
// SAMLRequest query parameter on the configured IdP URL, merged with any query
// parameters that URL already carries. A non-empty relayState is forwarded as
// RelayState. It returns a 500 AppError when the Service Provider has not been
// configured or any encoding step fails.
//
// NOTE(review): the request is built by hand and is never signed, even when
// SignRequest is enabled. ConfigureSP sets sp.SignAuthnRequests, but this
// method does not use sp to build the URL, so IdPs that require signed
// requests will reject it. Consider gosaml2's BuildAuthURL, which adds
// SigAlg/Signature.
func (s *SamlImpl) BuildRequest(rctx request.CTX, relayState string) (*model.SamlAuthRequest, *model.AppError) {
	s.mutex.RLock()
	defer s.mutex.RUnlock()

	if s.sp == nil {
		return nil, model.NewAppError("BuildRequest", "saml.sp_not_configured", nil, "SAML Service Provider not configured", http.StatusInternalServerError)
	}

	samlSettings := s.config().SamlSettings

	// Compact XML: the request travels in a URL, so indentation only adds length.
	xmlBytes, err := xml.Marshal(s.buildAuthnRequest(&samlSettings))
	if err != nil {
		return nil, model.NewAppError("BuildRequest", "saml.build_request.marshal", nil, err.Error(), http.StatusInternalServerError)
	}

	// The HTTP-Redirect binding uses raw DEFLATE (RFC 1951), which is what
	// compress/flate produces.
	var compressed bytes.Buffer
	writer, err := flate.NewWriter(&compressed, flate.DefaultCompression)
	if err != nil {
		return nil, model.NewAppError("BuildRequest", "saml.build_request.compress", nil, err.Error(), http.StatusInternalServerError)
	}
	if _, err = writer.Write(xmlBytes); err != nil {
		return nil, model.NewAppError("BuildRequest", "saml.build_request.compress", nil, err.Error(), http.StatusInternalServerError)
	}
	// Close flushes the final DEFLATE block. Ignoring its error could send a
	// truncated stream to the IdP.
	if err = writer.Close(); err != nil {
		return nil, model.NewAppError("BuildRequest", "saml.build_request.compress", nil, err.Error(), http.StatusInternalServerError)
	}
	base64Request := base64.StdEncoding.EncodeToString(compressed.Bytes())

	redirectURL, err := url.Parse(*samlSettings.IdpURL)
	if err != nil {
		return nil, model.NewAppError("BuildRequest", "saml.build_request.parse_url", nil, err.Error(), http.StatusInternalServerError)
	}
	query := redirectURL.Query()
	query.Set("SAMLRequest", base64Request)
	if relayState != "" {
		query.Set("RelayState", relayState)
	}
	redirectURL.RawQuery = query.Encode()

	return &model.SamlAuthRequest{
		Base64AuthRequest: base64Request,
		URL:               redirectURL.String(),
		RelayState:        relayState,
	}, nil
}
// DoLogin validates an IdP's base64-encoded SAML Response and maps its
// assertion onto a model.User built in memory (nothing is persisted here).
//
// relayState is accepted for interface compatibility but is not used. It
// returns a 500 AppError when the Service Provider is not configured. It
// returns a 400 AppError when the response cannot be decoded or validated,
// when the assertion is outside its validity window, or when the assertion is
// not addressed to this SP's audience. Attribute-mapping errors from
// extractUserFromAssertion are passed through unchanged.
func (s *SamlImpl) DoLogin(rctx request.CTX, encodedXML string, relayState map[string]string) (*model.User, *saml2.AssertionInfo, *model.AppError) {
	s.mutex.RLock()
	defer s.mutex.RUnlock()

	if s.sp == nil {
		return nil, nil, model.NewAppError("DoLogin", "saml.sp_not_configured", nil, "SAML Service Provider not configured", http.StatusInternalServerError)
	}

	samlSettings := s.config().SamlSettings

	// RetrieveAssertionInfo decodes encodedXML itself. This pre-check only
	// rejects malformed base64 with a 400 before it reaches the library.
	if _, err := base64.StdEncoding.DecodeString(encodedXML); err != nil {
		return nil, nil, model.NewAppError("DoLogin", "saml.do_login.decode", nil, err.Error(), http.StatusBadRequest)
	}

	assertionInfo, err := s.sp.RetrieveAssertionInfo(encodedXML)
	if err != nil {
		// Never log the response body: it carries user attributes and a signed
		// assertion that anyone with log access could replay.
		s.logger.Error("Failed to retrieve assertion info", mlog.Err(err))
		return nil, nil, model.NewAppError("DoLogin", "saml.do_login.validate", nil, err.Error(), http.StatusBadRequest)
	}

	if assertionInfo.WarningInfo.InvalidTime {
		return nil, nil, model.NewAppError("DoLogin", "saml.do_login.invalid_time", nil, "SAML assertion has invalid time", http.StatusBadRequest)
	}
	if assertionInfo.WarningInfo.NotInAudience {
		return nil, nil, model.NewAppError("DoLogin", "saml.do_login.invalid_audience", nil, "SAML assertion audience mismatch", http.StatusBadRequest)
	}

	user, appErr := s.extractUserFromAssertion(assertionInfo, &samlSettings)
	if appErr != nil {
		return nil, nil, appErr
	}
	return user, assertionInfo, nil
}
// GetMetadata renders this Service Provider's SAML metadata document.
// It returns a 501 AppError while SAML is disabled and a 500 AppError if
// rendering fails.
func (s *SamlImpl) GetMetadata(rctx request.CTX) (string, *model.AppError) {
	s.mutex.RLock()
	defer s.mutex.RUnlock()

	settings := s.config().SamlSettings
	if enabled := settings.Enable; enabled == nil || !*enabled {
		return "", model.NewAppError("GetMetadata", "saml.not_enabled", nil, "SAML is not enabled", http.StatusNotImplemented)
	}

	xmlDoc, genErr := s.generateMetadataXML(&settings)
	if genErr != nil {
		return "", model.NewAppError("GetMetadata", "saml.get_metadata.generate", nil, genErr.Error(), http.StatusInternalServerError)
	}
	return xmlDoc, nil
}
// CheckProviderAttributes reports configuration gaps in SS, plus a pending
// email change when both ouser and patch are given. The warnings are joined
// with "; ", and the result is "" when there is nothing to report.
func (s *SamlImpl) CheckProviderAttributes(rctx request.CTX, SS *model.SamlSettings, ouser *model.User, patch *model.UserPatch) string {
	var warnings []string

	// warnIfUnset appends msg when the attribute name is nil or empty.
	warnIfUnset := func(attr *string, msg string) {
		if attr == nil || *attr == "" {
			warnings = append(warnings, msg)
		}
	}

	warnIfUnset(SS.EmailAttribute, "Email attribute is not configured")
	warnIfUnset(SS.UsernameAttribute, "Username attribute is not configured")

	if ouser != nil && patch != nil && patch.Email != nil && *patch.Email != ouser.Email {
		warnings = append(warnings, fmt.Sprintf("User email would change from %s to %s", ouser.Email, *patch.Email))
	}

	warnIfUnset(SS.FirstNameAttribute, "First name attribute is not configured")
	warnIfUnset(SS.LastNameAttribute, "Last name attribute is not configured")
	warnIfUnset(SS.NicknameAttribute, "Nickname attribute is not configured (optional)")
	warnIfUnset(SS.PositionAttribute, "Position attribute is not configured (optional)")
	warnIfUnset(SS.LocaleAttribute, "Locale attribute is not configured (optional)")

	// Joining an empty slice yields "", which is the "no warnings" result.
	return strings.Join(warnings, "; ")
}
// Helper methods

// getFilePath resolves a configured certificate or key path. Absolute paths
// are returned unchanged, and relative ones are joined onto the config
// directory.
//
// filepath.Join replaces the old configDir+"/"+path concatenation. That
// concatenation turned a relative path into one rooted at "/" when configDir
// was empty, and it ignored OS-specific separators and absolute-path forms.
func (s *SamlImpl) getFilePath(path string) string {
	if filepath.IsAbs(path) {
		return path
	}
	return filepath.Join(s.configDir, path)
}
// loadCertificate reads an X.509 certificate from path. If the file holds
// PEM, the first PEM block is parsed; otherwise the whole file is parsed
// as DER.
func (s *SamlImpl) loadCertificate(path string) (*x509.Certificate, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	der := raw
	if block, _ := pem.Decode(raw); block != nil {
		der = block.Bytes
	}
	return x509.ParseCertificate(der)
}
// loadPrivateKey reads a PEM-encoded RSA private key from path. It tries
// PKCS#1 first and falls back to PKCS#8. It fails if the file has no PEM
// block, neither format parses, or the PKCS#8 key is not RSA. On a
// double-parse failure the PKCS#8 error is returned.
func (s *SamlImpl) loadPrivateKey(path string) (*rsa.PrivateKey, error) {
	pemData, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}

	block, _ := pem.Decode(pemData)
	if block == nil {
		return nil, fmt.Errorf("failed to decode PEM block")
	}

	if pkcs1Key, pkcs1Err := x509.ParsePKCS1PrivateKey(block.Bytes); pkcs1Err == nil {
		return pkcs1Key, nil
	}

	// Not PKCS#1. PKCS#8 may wrap any key type, so the result must be
	// narrowed to RSA.
	parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		return nil, err
	}
	rsaKey, isRSA := parsed.(*rsa.PrivateKey)
	if !isRSA {
		return nil, fmt.Errorf("key is not RSA")
	}
	return rsaKey, nil
}
// XML shapes for the hand-built SAML AuthnRequest.
//
// The element names carry their namespace prefix literally (e.g.
// "samlp:AuthnRequest"). encoding/xml writes such names verbatim, and the
// prefixes are bound by the XMLNsSamlp/XMLNsSaml attributes on the root
// element. Field order matters: encoding/xml emits attributes and child
// elements in declaration order.
type (
	// AuthnRequest is the <samlp:AuthnRequest> root element.
	AuthnRequest struct {
		XMLName                     xml.Name `xml:"samlp:AuthnRequest"`
		XMLNsSamlp                  string   `xml:"xmlns:samlp,attr"`
		XMLNsSaml                   string   `xml:"xmlns:saml,attr"`
		ID                          string   `xml:"ID,attr"`
		Version                     string   `xml:"Version,attr"`
		IssueInstant                string   `xml:"IssueInstant,attr"`
		Destination                 string   `xml:"Destination,attr"`
		AssertionConsumerServiceURL string   `xml:"AssertionConsumerServiceURL,attr"`
		ProtocolBinding             string   `xml:"ProtocolBinding,attr"`
		Issuer                      Issuer   `xml:"saml:Issuer"`
		// Optional children; a nil pointer omits the element. Their element
		// names come from each type's XMLName tag.
		NameIDPolicy          *NameIDPolicy
		RequestedAuthnContext *RequestedAuthnContext
	}

	// Issuer is the <saml:Issuer> element; Value becomes its text content.
	Issuer struct {
		XMLName xml.Name `xml:"saml:Issuer"`
		Value   string   `xml:",chardata"`
	}

	// NameIDPolicy is the <samlp:NameIDPolicy> element; Format is omitted
	// when empty.
	NameIDPolicy struct {
		XMLName     xml.Name `xml:"samlp:NameIDPolicy"`
		Format      string   `xml:"Format,attr,omitempty"`
		AllowCreate bool     `xml:"AllowCreate,attr"`
	}

	// RequestedAuthnContext is the <samlp:RequestedAuthnContext> element.
	RequestedAuthnContext struct {
		XMLName              xml.Name `xml:"samlp:RequestedAuthnContext"`
		Comparison           string   `xml:"Comparison,attr"`
		AuthnContextClassRef string   `xml:"saml:AuthnContextClassRef"`
	}
)
// buildAuthnRequest assembles an AuthnRequest addressed to the configured
// IdP. It asks for the response over HTTP-POST to our ACS URL, with an
// unspecified NameID format that the IdP may create on demand. The request
// ID is "_" followed by a new model ID, and IssueInstant is the current UTC
// time in RFC 3339 form.
func (s *SamlImpl) buildAuthnRequest(samlSettings *model.SamlSettings) *AuthnRequest {
	issuedAt := time.Now().UTC()

	return &AuthnRequest{
		XMLNsSamlp:                  "urn:oasis:names:tc:SAML:2.0:protocol",
		XMLNsSaml:                   "urn:oasis:names:tc:SAML:2.0:assertion",
		ID:                          "_" + model.NewId(),
		Version:                     "2.0",
		IssueInstant:                issuedAt.Format(time.RFC3339),
		Destination:                 *samlSettings.IdpURL,
		AssertionConsumerServiceURL: *samlSettings.AssertionConsumerServiceURL,
		ProtocolBinding:             "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
		Issuer:                      Issuer{Value: *samlSettings.ServiceProviderIdentifier},
		NameIDPolicy: &NameIDPolicy{
			Format:      "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified",
			AllowCreate: true,
		},
	}
}
// extractUserFromAssertion maps a validated assertion onto a new SAML user.
//
// AuthData starts as the assertion's NameID. It is replaced by the IdAttribute
// value when that attribute is configured and present. If EmailAttribute is
// configured but absent or empty in the assertion, a 400 AppError is returned.
// Every other mapped attribute is optional: an unset attribute name, or a
// missing value, leaves the corresponding field empty.
func (s *SamlImpl) extractUserFromAssertion(assertionInfo *saml2.AssertionInfo, samlSettings *model.SamlSettings) (*model.User, *model.AppError) {
	attrs := assertionInfo.Values
	configured := func(attr *string) bool { return attr != nil && *attr != "" }

	user := &model.User{
		AuthService: model.UserAuthServiceSaml,
		AuthData:    model.NewPointer(assertionInfo.NameID),
	}

	if configured(samlSettings.EmailAttribute) {
		user.Email = getFirstAttributeValue(attrs, *samlSettings.EmailAttribute)
		if user.Email == "" {
			return nil, model.NewAppError("extractUserFromAssertion", "saml.missing_email", nil, "Email attribute not found in SAML response", http.StatusBadRequest)
		}
	}

	// Optional string fields: copy the first attribute value when present.
	optional := []struct {
		attr *string
		dst  *string
	}{
		{samlSettings.UsernameAttribute, &user.Username},
		{samlSettings.FirstNameAttribute, &user.FirstName},
		{samlSettings.LastNameAttribute, &user.LastName},
		{samlSettings.NicknameAttribute, &user.Nickname},
		{samlSettings.PositionAttribute, &user.Position},
		{samlSettings.LocaleAttribute, &user.Locale},
	}
	for _, field := range optional {
		if !configured(field.attr) {
			continue
		}
		if value := getFirstAttributeValue(attrs, *field.attr); value != "" {
			*field.dst = value
		}
	}

	// A configured ID attribute overrides NameID as the stable AuthData key.
	if configured(samlSettings.IdAttribute) {
		if id := getFirstAttributeValue(attrs, *samlSettings.IdAttribute); id != "" {
			user.AuthData = model.NewPointer(id)
		}
	}

	return user, nil
}
// getFirstAttributeValue returns the first value of the attribute called name.
// It returns "" when the attribute is absent or has no values.
func getFirstAttributeValue(attrs saml2.Values, name string) string {
	attr, found := attrs[name]
	if !found || len(attr.Values) == 0 {
		return ""
	}
	return attr.Values[0].Value
}
// metadataTemplate is the text/template source for this Service Provider's
// SAML metadata (an md:EntityDescriptor). generateMetadataXML renders it for
// GetMetadata. The data value must provide the fields EntityID, SignRequests,
// WantAssertionsSigned, Certificate and ACSUrl. When Certificate is non-empty,
// the same certificate is advertised in two KeyDescriptors, one for signing
// and one for encryption.
//
// NOTE(review): text/template performs no escaping, so any string field must
// already be XML-safe when it is passed in.
const metadataTemplate = `<?xml version="1.0" encoding="UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="{{.EntityID}}">
<md:SPSSODescriptor AuthnRequestsSigned="{{.SignRequests}}" WantAssertionsSigned="{{.WantAssertionsSigned}}" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
{{if .Certificate}}
<md:KeyDescriptor use="signing">
<ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
<ds:X509Data>
<ds:X509Certificate>{{.Certificate}}</ds:X509Certificate>
</ds:X509Data>
</ds:KeyInfo>
</md:KeyDescriptor>
<md:KeyDescriptor use="encryption">
<ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
<ds:X509Data>
<ds:X509Certificate>{{.Certificate}}</ds:X509Certificate>
</ds:X509Data>
</ds:KeyInfo>
</md:KeyDescriptor>
{{end}}
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="{{.ACSUrl}}" index="0" isDefault="true"/>
</md:SPSSODescriptor>
</md:EntityDescriptor>`
// generateMetadataXML renders metadataTemplate for samlSettings. The SP
// certificate, if one was loaded, is embedded as base64 DER.
//
// text/template does no escaping, so the string values interpolated into XML
// attributes are escaped here first. Without that, an ACS URL containing '&'
// (any query string) or a quote produced malformed metadata.
func (s *SamlImpl) generateMetadataXML(samlSettings *model.SamlSettings) (string, error) {
	tmpl, err := template.New("metadata").Parse(metadataTemplate)
	if err != nil {
		return "", err
	}

	var cert string
	if s.spCert != nil {
		// Standard base64 uses only [A-Za-z0-9+/=], so it needs no escaping.
		cert = base64.StdEncoding.EncodeToString(s.spCert.Raw)
	}

	data := struct {
		EntityID             string
		SignRequests         bool
		WantAssertionsSigned bool
		Certificate          string
		ACSUrl               string
	}{
		EntityID:             escapeXMLAttr(*samlSettings.ServiceProviderIdentifier),
		SignRequests:         samlSettings.SignRequest != nil && *samlSettings.SignRequest,
		WantAssertionsSigned: samlSettings.Verify != nil && *samlSettings.Verify,
		Certificate:          cert,
		ACSUrl:               escapeXMLAttr(*samlSettings.AssertionConsumerServiceURL),
	}

	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, data); err != nil {
		return "", err
	}
	return buf.String(), nil
}

// escapeXMLAttr escapes value so it can be placed verbatim inside a quoted
// XML attribute (&, <, >, quotes and control whitespace become entities).
func escapeXMLAttr(value string) string {
	var buf bytes.Buffer
	// EscapeText only fails when the writer does, and bytes.Buffer writes
	// never fail.
	_ = xml.EscapeText(&buf, []byte(value))
	return buf.String()
}
// FetchIdPMetadata fetches metadata from IdP URL (helper function)
func (s *SamlImpl) FetchIdPMetadata(metadataURL string) (*types.EntityDescriptor, error) {
resp, err := http.Get(metadataURL)
if err != nil {
return nil, fmt.Errorf("failed to fetch IdP metadata: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to fetch IdP metadata: status %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read IdP metadata: %w", err)
}
var metadata types.EntityDescriptor
if err := xml.Unmarshal(body, &metadata); err != nil {
return nil, fmt.Errorf("failed to parse IdP metadata: %w", err)
}
return &metadata, nil
}