// Copyright (c) 2024 Mattermost Community Enterprise // SAML 2.0 SSO Implementation package saml import ( "bytes" "compress/flate" "crypto/rsa" "crypto/tls" "crypto/x509" "encoding/base64" "encoding/pem" "encoding/xml" "fmt" "io" "net/http" "net/url" "os" "strings" "sync" "text/template" "time" saml2 "github.com/mattermost/gosaml2" "github.com/mattermost/gosaml2/types" dsig "github.com/russellhaering/goxmldsig" "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/public/shared/request" ) // SamlConfig holds configuration for SAML interface type SamlConfig struct { Config func() *model.Config Logger mlog.LoggerIFace ConfigDir string } // SamlImpl implements the SamlInterface type SamlImpl struct { config func() *model.Config logger mlog.LoggerIFace configDir string sp *saml2.SAMLServiceProvider idpMetadata *types.EntityDescriptor spPrivateKey *rsa.PrivateKey spCert *x509.Certificate idpCert *x509.Certificate mutex sync.RWMutex } // NewSamlInterface creates a new SAML interface func NewSamlInterface(cfg *SamlConfig) *SamlImpl { return &SamlImpl{ config: cfg.Config, logger: cfg.Logger, configDir: cfg.ConfigDir, } } // ConfigureSP configures the SAML Service Provider func (s *SamlImpl) ConfigureSP(rctx request.CTX) error { s.mutex.Lock() defer s.mutex.Unlock() cfg := s.config() samlSettings := cfg.SamlSettings if samlSettings.Enable == nil || !*samlSettings.Enable { return nil } // Load IdP certificate idpCert, err := s.loadCertificate(s.getFilePath(*samlSettings.IdpCertificateFile)) if err != nil { return fmt.Errorf("failed to load IdP certificate: %w", err) } s.idpCert = idpCert // Load SP private key if encryption/signing is enabled if (samlSettings.Encrypt != nil && *samlSettings.Encrypt) || (samlSettings.SignRequest != nil && *samlSettings.SignRequest) { privateKey, err := s.loadPrivateKey(s.getFilePath(*samlSettings.PrivateKeyFile)) if err != nil { return fmt.Errorf("failed to load SP private key: %w", err) } s.spPrivateKey = privateKey spCert, err := s.loadCertificate(s.getFilePath(*samlSettings.PublicCertificateFile)) if err != nil { return fmt.Errorf("failed to load SP certificate: %w", err) } s.spCert = spCert } // Create certificate store for IdP verification certStore := dsig.MemoryX509CertificateStore{ Roots: []*x509.Certificate{idpCert}, } // Configure SP sp := &saml2.SAMLServiceProvider{ IdentityProviderSSOURL: *samlSettings.IdpURL, IdentityProviderIssuer: *samlSettings.IdpDescriptorURL, ServiceProviderIssuer: *samlSettings.ServiceProviderIdentifier, AssertionConsumerServiceURL: *samlSettings.AssertionConsumerServiceURL, IDPCertificateStore: &certStore, SkipSignatureValidation: samlSettings.Verify == nil || !*samlSettings.Verify, } // Configure signing if samlSettings.SignRequest != nil && *samlSettings.SignRequest && s.spPrivateKey != nil { sp.SignAuthnRequests = true sp.SPKeyStore = dsig.TLSCertKeyStore(tls.Certificate{ Certificate: [][]byte{s.spCert.Raw}, PrivateKey: s.spPrivateKey, }) // Set signature algorithm if samlSettings.SignatureAlgorithm != nil { switch *samlSettings.SignatureAlgorithm { case model.SamlSettingsSignatureAlgorithmSha256: sp.SignAuthnRequestsAlgorithm = dsig.RSASHA256SignatureMethod case model.SamlSettingsSignatureAlgorithmSha512: sp.SignAuthnRequestsAlgorithm = dsig.RSASHA512SignatureMethod default: sp.SignAuthnRequestsAlgorithm = dsig.RSASHA1SignatureMethod } } } s.sp = sp s.logger.Info("SAML Service Provider configured successfully", mlog.String("issuer", sp.ServiceProviderIssuer), mlog.String("acs_url", sp.AssertionConsumerServiceURL), ) return nil } // BuildRequest builds a SAML authentication request func (s *SamlImpl) BuildRequest(rctx request.CTX, relayState string) (*model.SamlAuthRequest, *model.AppError) { s.mutex.RLock() defer s.mutex.RUnlock() if s.sp == nil { return nil, model.NewAppError("BuildRequest", "saml.sp_not_configured", nil, "SAML Service Provider not configured", http.StatusInternalServerError) } cfg := s.config() samlSettings := cfg.SamlSettings // Build AuthnRequest authnRequest := s.buildAuthnRequest(&samlSettings) // Serialize to XML xmlBytes, err := xml.MarshalIndent(authnRequest, "", " ") if err != nil { return nil, model.NewAppError("BuildRequest", "saml.build_request.marshal", nil, err.Error(), http.StatusInternalServerError) } // Deflate compress var compressed bytes.Buffer writer, err := flate.NewWriter(&compressed, flate.DefaultCompression) if err != nil { return nil, model.NewAppError("BuildRequest", "saml.build_request.compress", nil, err.Error(), http.StatusInternalServerError) } writer.Write(xmlBytes) writer.Close() // Base64 encode base64Request := base64.StdEncoding.EncodeToString(compressed.Bytes()) // Build redirect URL redirectURL, err := url.Parse(*samlSettings.IdpURL) if err != nil { return nil, model.NewAppError("BuildRequest", "saml.build_request.parse_url", nil, err.Error(), http.StatusInternalServerError) } query := redirectURL.Query() query.Set("SAMLRequest", base64Request) if relayState != "" { query.Set("RelayState", relayState) } redirectURL.RawQuery = query.Encode() return &model.SamlAuthRequest{ Base64AuthRequest: base64Request, URL: redirectURL.String(), RelayState: relayState, }, nil } // DoLogin processes a SAML login response func (s *SamlImpl) DoLogin(rctx request.CTX, encodedXML string, relayState map[string]string) (*model.User, *saml2.AssertionInfo, *model.AppError) { s.mutex.RLock() defer s.mutex.RUnlock() if s.sp == nil { return nil, nil, model.NewAppError("DoLogin", "saml.sp_not_configured", nil, "SAML Service Provider not configured", http.StatusInternalServerError) } cfg := s.config() samlSettings := cfg.SamlSettings // Decode the SAML response rawXML, err := base64.StdEncoding.DecodeString(encodedXML) if err != nil { return nil, nil, model.NewAppError("DoLogin", "saml.do_login.decode", nil, err.Error(), http.StatusBadRequest) } // Parse and validate the assertion assertionInfo, err := s.sp.RetrieveAssertionInfo(encodedXML) if err != nil { s.logger.Error("Failed to retrieve assertion info", mlog.Err(err), mlog.String("raw_response", string(rawXML)), ) return nil, nil, model.NewAppError("DoLogin", "saml.do_login.validate", nil, err.Error(), http.StatusBadRequest) } if assertionInfo.WarningInfo.InvalidTime { return nil, nil, model.NewAppError("DoLogin", "saml.do_login.invalid_time", nil, "SAML assertion has invalid time", http.StatusBadRequest) } if assertionInfo.WarningInfo.NotInAudience { return nil, nil, model.NewAppError("DoLogin", "saml.do_login.invalid_audience", nil, "SAML assertion audience mismatch", http.StatusBadRequest) } // Extract user attributes user, appErr := s.extractUserFromAssertion(assertionInfo, &samlSettings) if appErr != nil { return nil, nil, appErr } return user, assertionInfo, nil } // GetMetadata returns the SP metadata XML func (s *SamlImpl) GetMetadata(rctx request.CTX) (string, *model.AppError) { s.mutex.RLock() defer s.mutex.RUnlock() cfg := s.config() samlSettings := cfg.SamlSettings if samlSettings.Enable == nil || !*samlSettings.Enable { return "", model.NewAppError("GetMetadata", "saml.not_enabled", nil, "SAML is not enabled", http.StatusNotImplemented) } metadata, err := s.generateMetadataXML(&samlSettings) if err != nil { return "", model.NewAppError("GetMetadata", "saml.get_metadata.generate", nil, err.Error(), http.StatusInternalServerError) } return metadata, nil } // CheckProviderAttributes validates provider attributes and returns warnings func (s *SamlImpl) CheckProviderAttributes(rctx request.CTX, SS *model.SamlSettings, ouser *model.User, patch *model.UserPatch) string { var warnings []string // Check email attribute if SS.EmailAttribute == nil || *SS.EmailAttribute == "" { warnings = append(warnings, "Email attribute is not configured") } // Check username attribute if SS.UsernameAttribute == nil || *SS.UsernameAttribute == "" { warnings = append(warnings, "Username attribute is not configured") } // Check if user email would be changed if ouser != nil && patch != nil && patch.Email != nil { if *patch.Email != ouser.Email { warnings = append(warnings, fmt.Sprintf("User email would change from %s to %s", ouser.Email, *patch.Email)) } } // Check first name attribute if SS.FirstNameAttribute == nil || *SS.FirstNameAttribute == "" { warnings = append(warnings, "First name attribute is not configured") } // Check last name attribute if SS.LastNameAttribute == nil || *SS.LastNameAttribute == "" { warnings = append(warnings, "Last name attribute is not configured") } // Check nickname attribute if SS.NicknameAttribute == nil || *SS.NicknameAttribute == "" { warnings = append(warnings, "Nickname attribute is not configured (optional)") } // Check position attribute if SS.PositionAttribute == nil || *SS.PositionAttribute == "" { warnings = append(warnings, "Position attribute is not configured (optional)") } // Check locale attribute if SS.LocaleAttribute == nil || *SS.LocaleAttribute == "" { warnings = append(warnings, "Locale attribute is not configured (optional)") } if len(warnings) > 0 { return strings.Join(warnings, "; ") } return "" } // Helper methods func (s *SamlImpl) getFilePath(path string) string { if strings.HasPrefix(path, "/") { return path } return s.configDir + "/" + path } func (s *SamlImpl) loadCertificate(path string) (*x509.Certificate, error) { data, err := os.ReadFile(path) if err != nil { return nil, err } block, _ := pem.Decode(data) if block == nil { // Try parsing as DER return x509.ParseCertificate(data) } return x509.ParseCertificate(block.Bytes) } func (s *SamlImpl) loadPrivateKey(path string) (*rsa.PrivateKey, error) { data, err := os.ReadFile(path) if err != nil { return nil, err } block, _ := pem.Decode(data) if block == nil { return nil, fmt.Errorf("failed to decode PEM block") } key, err := x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { // Try PKCS8 pkcs8Key, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { return nil, err } rsaKey, ok := pkcs8Key.(*rsa.PrivateKey) if !ok { return nil, fmt.Errorf("key is not RSA") } return rsaKey, nil } return key, nil } // AuthnRequest represents a SAML AuthnRequest type AuthnRequest struct { XMLName xml.Name `xml:"samlp:AuthnRequest"` XMLNsSamlp string `xml:"xmlns:samlp,attr"` XMLNsSaml string `xml:"xmlns:saml,attr"` ID string `xml:"ID,attr"` Version string `xml:"Version,attr"` IssueInstant string `xml:"IssueInstant,attr"` Destination string `xml:"Destination,attr"` AssertionConsumerServiceURL string `xml:"AssertionConsumerServiceURL,attr"` ProtocolBinding string `xml:"ProtocolBinding,attr"` Issuer Issuer `xml:"saml:Issuer"` NameIDPolicy *NameIDPolicy RequestedAuthnContext *RequestedAuthnContext } type Issuer struct { XMLName xml.Name `xml:"saml:Issuer"` Value string `xml:",chardata"` } type NameIDPolicy struct { XMLName xml.Name `xml:"samlp:NameIDPolicy"` Format string `xml:"Format,attr,omitempty"` AllowCreate bool `xml:"AllowCreate,attr"` } type RequestedAuthnContext struct { XMLName xml.Name `xml:"samlp:RequestedAuthnContext"` Comparison string `xml:"Comparison,attr"` AuthnContextClassRef string `xml:"saml:AuthnContextClassRef"` } func (s *SamlImpl) buildAuthnRequest(samlSettings *model.SamlSettings) *AuthnRequest { id := "_" + model.NewId() now := time.Now().UTC().Format(time.RFC3339) return &AuthnRequest{ XMLNsSamlp: "urn:oasis:names:tc:SAML:2.0:protocol", XMLNsSaml: "urn:oasis:names:tc:SAML:2.0:assertion", ID: id, Version: "2.0", IssueInstant: now, Destination: *samlSettings.IdpURL, AssertionConsumerServiceURL: *samlSettings.AssertionConsumerServiceURL, ProtocolBinding: "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST", Issuer: Issuer{ Value: *samlSettings.ServiceProviderIdentifier, }, NameIDPolicy: &NameIDPolicy{ Format: "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified", AllowCreate: true, }, } } func (s *SamlImpl) extractUserFromAssertion(assertionInfo *saml2.AssertionInfo, samlSettings *model.SamlSettings) (*model.User, *model.AppError) { user := &model.User{ AuthService: model.UserAuthServiceSaml, AuthData: model.NewPointer(assertionInfo.NameID), } // Extract attributes attrs := assertionInfo.Values // Email (required) if samlSettings.EmailAttribute != nil && *samlSettings.EmailAttribute != "" { if email := getFirstAttributeValue(attrs, *samlSettings.EmailAttribute); email != "" { user.Email = email } else { return nil, model.NewAppError("extractUserFromAssertion", "saml.missing_email", nil, "Email attribute not found in SAML response", http.StatusBadRequest) } } // Username if samlSettings.UsernameAttribute != nil && *samlSettings.UsernameAttribute != "" { if username := getFirstAttributeValue(attrs, *samlSettings.UsernameAttribute); username != "" { user.Username = username } } // First name if samlSettings.FirstNameAttribute != nil && *samlSettings.FirstNameAttribute != "" { if firstName := getFirstAttributeValue(attrs, *samlSettings.FirstNameAttribute); firstName != "" { user.FirstName = firstName } } // Last name if samlSettings.LastNameAttribute != nil && *samlSettings.LastNameAttribute != "" { if lastName := getFirstAttributeValue(attrs, *samlSettings.LastNameAttribute); lastName != "" { user.LastName = lastName } } // Nickname if samlSettings.NicknameAttribute != nil && *samlSettings.NicknameAttribute != "" { if nickname := getFirstAttributeValue(attrs, *samlSettings.NicknameAttribute); nickname != "" { user.Nickname = nickname } } // Position if samlSettings.PositionAttribute != nil && *samlSettings.PositionAttribute != "" { if position := getFirstAttributeValue(attrs, *samlSettings.PositionAttribute); position != "" { user.Position = position } } // Locale if samlSettings.LocaleAttribute != nil && *samlSettings.LocaleAttribute != "" { if locale := getFirstAttributeValue(attrs, *samlSettings.LocaleAttribute); locale != "" { user.Locale = locale } } // ID attribute (for AuthData) if samlSettings.IdAttribute != nil && *samlSettings.IdAttribute != "" { if id := getFirstAttributeValue(attrs, *samlSettings.IdAttribute); id != "" { user.AuthData = model.NewPointer(id) } } return user, nil } func getFirstAttributeValue(attrs saml2.Values, name string) string { if values, ok := attrs[name]; ok && len(values.Values) > 0 { return values.Values[0].Value } return "" } const metadataTemplate = ` {{if .Certificate}} {{.Certificate}} {{.Certificate}} {{end}} urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified ` func (s *SamlImpl) generateMetadataXML(samlSettings *model.SamlSettings) (string, error) { tmpl, err := template.New("metadata").Parse(metadataTemplate) if err != nil { return "", err } var cert string if s.spCert != nil { cert = base64.StdEncoding.EncodeToString(s.spCert.Raw) } data := struct { EntityID string SignRequests bool WantAssertionsSigned bool Certificate string ACSUrl string }{ EntityID: *samlSettings.ServiceProviderIdentifier, SignRequests: samlSettings.SignRequest != nil && *samlSettings.SignRequest, WantAssertionsSigned: samlSettings.Verify != nil && *samlSettings.Verify, Certificate: cert, ACSUrl: *samlSettings.AssertionConsumerServiceURL, } var buf bytes.Buffer if err := tmpl.Execute(&buf, data); err != nil { return "", err } return buf.String(), nil } // FetchIdPMetadata fetches metadata from IdP URL (helper function) func (s *SamlImpl) FetchIdPMetadata(metadataURL string) (*types.EntityDescriptor, error) { resp, err := http.Get(metadataURL) if err != nil { return nil, fmt.Errorf("failed to fetch IdP metadata: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to fetch IdP metadata: status %d", resp.StatusCode) } body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read IdP metadata: %w", err) } var metadata types.EntityDescriptor if err := xml.Unmarshal(body, &metadata); err != nil { return nil, fmt.Errorf("failed to parse IdP metadata: %w", err) } return &metadata, nil }