// Copyright 2016 Russell Haering et al. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package saml2 import ( "bytes" "compress/flate" "crypto/tls" "crypto/x509" "encoding/base64" "fmt" "io" "encoding/xml" "github.com/beevik/etree" "github.com/mattermost/gosaml2/types" rtvalidator "github.com/mattermost/xml-roundtrip-validator" dsig "github.com/russellhaering/goxmldsig" "github.com/russellhaering/goxmldsig/etreeutils" ) const ( defaultMaxDecompressedResponseSize = 5 * 1024 * 1024 ) func (sp *SAMLServiceProvider) validationContext() *dsig.ValidationContext { ctx := dsig.NewDefaultValidationContext(sp.IDPCertificateStore) ctx.Clock = sp.Clock return ctx } // validateResponseAttributes validates a SAML Response's tag and attributes. It does // not inspect child elements of the Response at all. func (sp *SAMLServiceProvider) validateResponseAttributes(response *types.Response) error { if response.Destination != "" && response.Destination != sp.AssertionConsumerServiceURL { return ErrInvalidValue{ Key: DestinationAttr, Expected: sp.AssertionConsumerServiceURL, Actual: response.Destination, } } if response.Version != "2.0" { return ErrInvalidValue{ Reason: ReasonUnsupported, Key: "SAML version", Expected: "2.0", Actual: response.Version, } } return nil } // validateLogoutResponseAttributes validates a SAML Response's tag and attributes. It does // not inspect child elements of the Response at all. func (sp *SAMLServiceProvider) validateLogoutResponseAttributes(response *types.LogoutResponse) error { if response.Destination != "" && response.Destination != sp.ServiceProviderSLOURL { return ErrInvalidValue{ Key: DestinationAttr, Expected: sp.ServiceProviderSLOURL, Actual: response.Destination, } } if response.Version != "2.0" { return ErrInvalidValue{ Reason: ReasonUnsupported, Key: "SAML version", Expected: "2.0", Actual: response.Version, } } return nil } func xmlUnmarshalElement(el *etree.Element, obj interface{}) error { doc := etree.NewDocument() doc.SetRoot(el) data, err := doc.WriteToBytes() if err != nil { return err } err = xml.Unmarshal(data, obj) if err != nil { return err } return nil } func (sp *SAMLServiceProvider) getDecryptCert() (*tls.Certificate, error) { if sp.SPKeyStore == nil { return nil, fmt.Errorf("no decryption certs available") } //This is the tls.Certificate we'll use to decrypt any encrypted assertions var decryptCert tls.Certificate switch crt := sp.SPKeyStore.(type) { case dsig.TLSCertKeyStore: // Get the tls.Certificate directly if possible decryptCert = tls.Certificate(crt) default: //Otherwise, construct one from the results of GetKeyPair pk, cert, err := sp.SPKeyStore.GetKeyPair() if err != nil { return nil, fmt.Errorf("error getting keypair: %v", err) } decryptCert = tls.Certificate{ Certificate: [][]byte{cert}, PrivateKey: pk, } } if sp.ValidateEncryptionCert { // Check Validity period of certificate if len(decryptCert.Certificate) < 1 || len(decryptCert.Certificate[0]) < 1 { return nil, fmt.Errorf("empty decryption cert") } else if cert, err := x509.ParseCertificate(decryptCert.Certificate[0]); err != nil { return nil, fmt.Errorf("invalid x509 decryption cert: %v", err) } else { now := sp.Clock.Now() if now.Before(cert.NotBefore) || now.After(cert.NotAfter) { return nil, fmt.Errorf("decryption cert is not valid at this time") } } } return &decryptCert, nil } func (sp *SAMLServiceProvider) decryptAssertions(el *etree.Element) error { var decryptCert *tls.Certificate decryptAssertion := func(ctx etreeutils.NSContext, encryptedElement *etree.Element) error { if encryptedElement.Parent() != el { return fmt.Errorf("found encrypted assertion with unexpected parent element: %s", encryptedElement.Parent().Tag) } detached, err := etreeutils.NSDetatch(ctx, encryptedElement) // make a detached copy if err != nil { return fmt.Errorf("unable to detach encrypted assertion: %v", err) } encryptedAssertion := &types.EncryptedAssertion{} err = xmlUnmarshalElement(detached, encryptedAssertion) if err != nil { return fmt.Errorf("unable to unmarshal encrypted assertion: %v", err) } if decryptCert == nil { decryptCert, err = sp.getDecryptCert() if err != nil { return fmt.Errorf("unable to get decryption certificate: %v", err) } } raw, derr := encryptedAssertion.DecryptBytes(decryptCert) if derr != nil { return fmt.Errorf("unable to decrypt encrypted assertion: %v", derr) } doc, _, err := parseResponse(raw, sp.MaximumDecompressedBodySize) if err != nil { return fmt.Errorf("unable to create element from decrypted assertion bytes: %v", derr) } // Replace the original encrypted assertion with the decrypted one. if el.RemoveChild(encryptedElement) == nil { // Out of an abundance of caution, make sure removed worked panic("unable to remove encrypted assertion") } el.AddChild(doc.Root()) return nil } return etreeutils.NSFindIterate(el, SAMLAssertionNamespace, EncryptedAssertionTag, decryptAssertion) } func (sp *SAMLServiceProvider) validateElementSignature(el *etree.Element) (*etree.Element, error) { return sp.validationContext().Validate(el) } func (sp *SAMLServiceProvider) validateAssertionSignatures(el *etree.Element) error { signedAssertions := 0 unsignedAssertions := 0 validateAssertion := func(ctx etreeutils.NSContext, unverifiedAssertion *etree.Element) error { parent := unverifiedAssertion.Parent() if parent == nil { return fmt.Errorf("parent is nil") } if parent != el { return fmt.Errorf("found assertion with unexpected parent element: %s", unverifiedAssertion.Parent().Tag) } detached, err := etreeutils.NSDetatch(ctx, unverifiedAssertion) // make a detached copy if err != nil { return fmt.Errorf("unable to detach unverified assertion: %v", err) } assertion, err := sp.validationContext().Validate(detached) if err == dsig.ErrMissingSignature { unsignedAssertions++ return nil } else if err != nil { return err } // Replace the original unverified Assertion with the verified one. Note that // if the Response is not signed, only signed Assertions (and not the parent Response) can be trusted. if el.RemoveChild(unverifiedAssertion) == nil { // Out of an abundance of caution, check to make sure an Assertion was actually // removed. If it wasn't a programming error has occurred. panic("unable to remove assertion") } el.AddChild(assertion) signedAssertions++ return nil } if err := etreeutils.NSFindIterate(el, SAMLAssertionNamespace, AssertionTag, validateAssertion); err != nil { return err } else if signedAssertions > 0 && unsignedAssertions > 0 { return fmt.Errorf("invalid to have both signed and unsigned assertions") } else if signedAssertions < 1 { return dsig.ErrMissingSignature } else { return nil } } // ValidateEncodedResponse both decodes and validates, based on SP // configuration, an encoded, signed response. It will also appropriately // decrypt a response if the assertion was encrypted func (sp *SAMLServiceProvider) ValidateEncodedResponse(encodedResponse string) (*types.Response, error) { raw, err := base64.StdEncoding.DecodeString(encodedResponse) if err != nil { return nil, err } // Parse the raw response doc, el, err := parseResponse(raw, sp.MaximumDecompressedBodySize) if err != nil { return nil, err } elAssertion, err := etreeutils.NSFindOne(el, SAMLAssertionNamespace, AssertionTag) if err != nil { return nil, err } elEncAssertion, err := etreeutils.NSFindOne(el, SAMLAssertionNamespace, EncryptedAssertionTag) if err != nil { return nil, err } // We verify that either one of assertion or encrypted assertion elements are present, // but not both. if (elAssertion == nil) == (elEncAssertion == nil) { return nil, fmt.Errorf("found both or no assertion and encrypted assertion elements") } // And if a decryptCert is present, then it's only encrypted assertion elements. if sp.SPKeyStore != nil && elAssertion != nil { return nil, fmt.Errorf("all assertions are not encrypted") } var responseSignatureValidated bool if !sp.SkipSignatureValidation { el, err = sp.validateElementSignature(el) if err == dsig.ErrMissingSignature { // Unfortunately we just blew away our Response el = doc.Root() } else if err != nil { return nil, err } else if el == nil { return nil, fmt.Errorf("missing transformed response") } else { responseSignatureValidated = true } } err = sp.decryptAssertions(el) if err != nil { return nil, err } var assertionSignaturesValidated bool if !sp.SkipSignatureValidation { err = sp.validateAssertionSignatures(el) if err == dsig.ErrMissingSignature { if !responseSignatureValidated { return nil, fmt.Errorf("response and/or assertions must be signed") } } else if err != nil { return nil, err } else { assertionSignaturesValidated = true } } decodedResponse := &types.Response{} err = xmlUnmarshalElement(el, decodedResponse) if err != nil { return nil, fmt.Errorf("unable to unmarshal response: %v", err) } decodedResponse.SignatureValidated = responseSignatureValidated if assertionSignaturesValidated { for idx := 0; idx < len(decodedResponse.Assertions); idx++ { decodedResponse.Assertions[idx].SignatureValidated = true } } err = sp.Validate(decodedResponse) if err != nil { return nil, err } return decodedResponse, nil } // DecodeUnverifiedBaseResponse decodes several attributes from a SAML response for the purpose // of determining how to validate the response. This is useful for Service Providers which // expose a single Assertion Consumer Service URL but consume Responses from many IdPs. func DecodeUnverifiedBaseResponse(encodedResponse string) (*types.UnverifiedBaseResponse, error) { raw, err := base64.StdEncoding.DecodeString(encodedResponse) if err != nil { return nil, err } var response *types.UnverifiedBaseResponse err = maybeDeflate(raw, defaultMaxDecompressedResponseSize, func(maybeXML []byte) error { response = &types.UnverifiedBaseResponse{} return xml.Unmarshal(maybeXML, response) }) if err != nil { return nil, err } return response, nil } // maybeDeflate invokes the passed decoder over the passed data. If an error is // returned, it then attempts to deflate the passed data before re-invoking // the decoder over the deflated data. func maybeDeflate(data []byte, maxSize int64, decoder func([]byte) error) error { err := decoder(data) if err == nil { return nil } // Default to 5MB max size if maxSize == 0 { maxSize = defaultMaxDecompressedResponseSize } lr := io.LimitReader(flate.NewReader(bytes.NewReader(data)), maxSize+1) deflated, err := io.ReadAll(lr) if err != nil { return err } if int64(len(deflated)) > maxSize { return fmt.Errorf("deflated response exceeds maximum size of %d bytes", maxSize) } return decoder(deflated) } // parseResponse is a helper function that was refactored out so that the XML parsing behavior can be isolated and unit tested func parseResponse(xml []byte, maxSize int64) (*etree.Document, *etree.Element, error) { var doc *etree.Document var rawXML []byte err := maybeDeflate(xml, maxSize, func(xml []byte) error { doc = etree.NewDocument() rawXML = xml return doc.ReadFromBytes(xml) }) if err != nil { return nil, nil, err } el := doc.Root() if el == nil { return nil, nil, fmt.Errorf("unable to parse response") } // Examine the response for attempts to exploit weaknesses in Go's encoding/xml err = rtvalidator.Validate(bytes.NewReader(rawXML)) if err != nil { return nil, nil, err } return doc, el, nil } // DecodeUnverifiedLogoutResponse decodes several attributes from a SAML Logout response, without doing any verifications. func DecodeUnverifiedLogoutResponse(encodedResponse string) (*types.LogoutResponse, error) { raw, err := base64.StdEncoding.DecodeString(encodedResponse) if err != nil { return nil, err } var response *types.LogoutResponse err = maybeDeflate(raw, defaultMaxDecompressedResponseSize, func(maybeXML []byte) error { response = &types.LogoutResponse{} return xml.Unmarshal(maybeXML, response) }) if err != nil { return nil, err } return response, nil } func (sp *SAMLServiceProvider) ValidateEncodedLogoutResponsePOST(encodedResponse string) (*types.LogoutResponse, error) { raw, err := base64.StdEncoding.DecodeString(encodedResponse) if err != nil { return nil, err } // Parse the raw response doc, el, err := parseResponse(raw, sp.MaximumDecompressedBodySize) if err != nil { return nil, err } var responseSignatureValidated bool if !sp.SkipSignatureValidation { el, err = sp.validateElementSignature(el) if err == dsig.ErrMissingSignature { // Unfortunately we just blew away our Response el = doc.Root() } else if err != nil { return nil, err } else if el == nil { return nil, fmt.Errorf("missing transformed logout response") } else { responseSignatureValidated = true } } decodedResponse := &types.LogoutResponse{} err = xmlUnmarshalElement(el, decodedResponse) if err != nil { return nil, fmt.Errorf("unable to unmarshal logout response: %v", err) } decodedResponse.SignatureValidated = responseSignatureValidated err = sp.ValidateDecodedLogoutResponse(decodedResponse) if err != nil { return nil, err } return decodedResponse, nil }