479 lines
14 KiB
Go
479 lines
14 KiB
Go
// Copyright 2016 Russell Haering et al.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// https://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package saml2
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/flate"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"io"
|
|
|
|
"encoding/xml"
|
|
|
|
"github.com/beevik/etree"
|
|
"github.com/mattermost/gosaml2/types"
|
|
rtvalidator "github.com/mattermost/xml-roundtrip-validator"
|
|
dsig "github.com/russellhaering/goxmldsig"
|
|
"github.com/russellhaering/goxmldsig/etreeutils"
|
|
)
|
|
|
|
// defaultMaxDecompressedResponseSize is the inflated-size cap, in bytes
// (5 MiB), that maybeDeflate applies when it is not given a positive limit.
// The unverified decode helpers always pass it explicitly.
const defaultMaxDecompressedResponseSize = 5 << 20
|
|
|
|
func (sp *SAMLServiceProvider) validationContext() *dsig.ValidationContext {
|
|
ctx := dsig.NewDefaultValidationContext(sp.IDPCertificateStore)
|
|
ctx.Clock = sp.Clock
|
|
return ctx
|
|
}
|
|
|
|
// validateResponseAttributes validates a SAML Response's tag and attributes. It does
|
|
// not inspect child elements of the Response at all.
|
|
func (sp *SAMLServiceProvider) validateResponseAttributes(response *types.Response) error {
|
|
if response.Destination != "" && response.Destination != sp.AssertionConsumerServiceURL {
|
|
return ErrInvalidValue{
|
|
Key: DestinationAttr,
|
|
Expected: sp.AssertionConsumerServiceURL,
|
|
Actual: response.Destination,
|
|
}
|
|
}
|
|
|
|
if response.Version != "2.0" {
|
|
return ErrInvalidValue{
|
|
Reason: ReasonUnsupported,
|
|
Key: "SAML version",
|
|
Expected: "2.0",
|
|
Actual: response.Version,
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateLogoutResponseAttributes validates a SAML Response's tag and attributes. It does
|
|
// not inspect child elements of the Response at all.
|
|
func (sp *SAMLServiceProvider) validateLogoutResponseAttributes(response *types.LogoutResponse) error {
|
|
if response.Destination != "" && response.Destination != sp.ServiceProviderSLOURL {
|
|
return ErrInvalidValue{
|
|
Key: DestinationAttr,
|
|
Expected: sp.ServiceProviderSLOURL,
|
|
Actual: response.Destination,
|
|
}
|
|
}
|
|
|
|
if response.Version != "2.0" {
|
|
return ErrInvalidValue{
|
|
Reason: ReasonUnsupported,
|
|
Key: "SAML version",
|
|
Expected: "2.0",
|
|
Actual: response.Version,
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func xmlUnmarshalElement(el *etree.Element, obj interface{}) error {
|
|
doc := etree.NewDocument()
|
|
doc.SetRoot(el)
|
|
data, err := doc.WriteToBytes()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = xml.Unmarshal(data, obj)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (sp *SAMLServiceProvider) getDecryptCert() (*tls.Certificate, error) {
|
|
if sp.SPKeyStore == nil {
|
|
return nil, fmt.Errorf("no decryption certs available")
|
|
}
|
|
|
|
//This is the tls.Certificate we'll use to decrypt any encrypted assertions
|
|
var decryptCert tls.Certificate
|
|
|
|
switch crt := sp.SPKeyStore.(type) {
|
|
case dsig.TLSCertKeyStore:
|
|
// Get the tls.Certificate directly if possible
|
|
decryptCert = tls.Certificate(crt)
|
|
|
|
default:
|
|
|
|
//Otherwise, construct one from the results of GetKeyPair
|
|
pk, cert, err := sp.SPKeyStore.GetKeyPair()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting keypair: %v", err)
|
|
}
|
|
|
|
decryptCert = tls.Certificate{
|
|
Certificate: [][]byte{cert},
|
|
PrivateKey: pk,
|
|
}
|
|
}
|
|
|
|
if sp.ValidateEncryptionCert {
|
|
// Check Validity period of certificate
|
|
if len(decryptCert.Certificate) < 1 || len(decryptCert.Certificate[0]) < 1 {
|
|
return nil, fmt.Errorf("empty decryption cert")
|
|
} else if cert, err := x509.ParseCertificate(decryptCert.Certificate[0]); err != nil {
|
|
return nil, fmt.Errorf("invalid x509 decryption cert: %v", err)
|
|
} else {
|
|
now := sp.Clock.Now()
|
|
if now.Before(cert.NotBefore) || now.After(cert.NotAfter) {
|
|
return nil, fmt.Errorf("decryption cert is not valid at this time")
|
|
}
|
|
}
|
|
}
|
|
|
|
return &decryptCert, nil
|
|
}
|
|
|
|
func (sp *SAMLServiceProvider) decryptAssertions(el *etree.Element) error {
|
|
var decryptCert *tls.Certificate
|
|
|
|
decryptAssertion := func(ctx etreeutils.NSContext, encryptedElement *etree.Element) error {
|
|
if encryptedElement.Parent() != el {
|
|
return fmt.Errorf("found encrypted assertion with unexpected parent element: %s", encryptedElement.Parent().Tag)
|
|
}
|
|
|
|
detached, err := etreeutils.NSDetatch(ctx, encryptedElement) // make a detached copy
|
|
if err != nil {
|
|
return fmt.Errorf("unable to detach encrypted assertion: %v", err)
|
|
}
|
|
|
|
encryptedAssertion := &types.EncryptedAssertion{}
|
|
err = xmlUnmarshalElement(detached, encryptedAssertion)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to unmarshal encrypted assertion: %v", err)
|
|
}
|
|
|
|
if decryptCert == nil {
|
|
decryptCert, err = sp.getDecryptCert()
|
|
if err != nil {
|
|
return fmt.Errorf("unable to get decryption certificate: %v", err)
|
|
}
|
|
}
|
|
|
|
raw, derr := encryptedAssertion.DecryptBytes(decryptCert)
|
|
if derr != nil {
|
|
return fmt.Errorf("unable to decrypt encrypted assertion: %v", derr)
|
|
}
|
|
|
|
doc, _, err := parseResponse(raw, sp.MaximumDecompressedBodySize)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to create element from decrypted assertion bytes: %v", derr)
|
|
}
|
|
|
|
// Replace the original encrypted assertion with the decrypted one.
|
|
if el.RemoveChild(encryptedElement) == nil {
|
|
// Out of an abundance of caution, make sure removed worked
|
|
panic("unable to remove encrypted assertion")
|
|
}
|
|
|
|
el.AddChild(doc.Root())
|
|
return nil
|
|
}
|
|
|
|
return etreeutils.NSFindIterate(el, SAMLAssertionNamespace, EncryptedAssertionTag, decryptAssertion)
|
|
}
|
|
|
|
func (sp *SAMLServiceProvider) validateElementSignature(el *etree.Element) (*etree.Element, error) {
|
|
return sp.validationContext().Validate(el)
|
|
}
|
|
|
|
func (sp *SAMLServiceProvider) validateAssertionSignatures(el *etree.Element) error {
|
|
signedAssertions := 0
|
|
unsignedAssertions := 0
|
|
validateAssertion := func(ctx etreeutils.NSContext, unverifiedAssertion *etree.Element) error {
|
|
parent := unverifiedAssertion.Parent()
|
|
if parent == nil {
|
|
return fmt.Errorf("parent is nil")
|
|
}
|
|
if parent != el {
|
|
return fmt.Errorf("found assertion with unexpected parent element: %s", unverifiedAssertion.Parent().Tag)
|
|
}
|
|
|
|
detached, err := etreeutils.NSDetatch(ctx, unverifiedAssertion) // make a detached copy
|
|
if err != nil {
|
|
return fmt.Errorf("unable to detach unverified assertion: %v", err)
|
|
}
|
|
|
|
assertion, err := sp.validationContext().Validate(detached)
|
|
if err == dsig.ErrMissingSignature {
|
|
unsignedAssertions++
|
|
return nil
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Replace the original unverified Assertion with the verified one. Note that
|
|
// if the Response is not signed, only signed Assertions (and not the parent Response) can be trusted.
|
|
if el.RemoveChild(unverifiedAssertion) == nil {
|
|
// Out of an abundance of caution, check to make sure an Assertion was actually
|
|
// removed. If it wasn't a programming error has occurred.
|
|
panic("unable to remove assertion")
|
|
}
|
|
|
|
el.AddChild(assertion)
|
|
signedAssertions++
|
|
|
|
return nil
|
|
}
|
|
|
|
if err := etreeutils.NSFindIterate(el, SAMLAssertionNamespace, AssertionTag, validateAssertion); err != nil {
|
|
return err
|
|
} else if signedAssertions > 0 && unsignedAssertions > 0 {
|
|
return fmt.Errorf("invalid to have both signed and unsigned assertions")
|
|
} else if signedAssertions < 1 {
|
|
return dsig.ErrMissingSignature
|
|
} else {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// ValidateEncodedResponse both decodes and validates, based on SP
|
|
// configuration, an encoded, signed response. It will also appropriately
|
|
// decrypt a response if the assertion was encrypted
|
|
func (sp *SAMLServiceProvider) ValidateEncodedResponse(encodedResponse string) (*types.Response, error) {
|
|
raw, err := base64.StdEncoding.DecodeString(encodedResponse)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse the raw response
|
|
doc, el, err := parseResponse(raw, sp.MaximumDecompressedBodySize)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
elAssertion, err := etreeutils.NSFindOne(el, SAMLAssertionNamespace, AssertionTag)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
elEncAssertion, err := etreeutils.NSFindOne(el, SAMLAssertionNamespace, EncryptedAssertionTag)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// We verify that either one of assertion or encrypted assertion elements are present,
|
|
// but not both.
|
|
if (elAssertion == nil) == (elEncAssertion == nil) {
|
|
return nil, fmt.Errorf("found both or no assertion and encrypted assertion elements")
|
|
}
|
|
// And if a decryptCert is present, then it's only encrypted assertion elements.
|
|
if sp.SPKeyStore != nil && elAssertion != nil {
|
|
return nil, fmt.Errorf("all assertions are not encrypted")
|
|
}
|
|
|
|
var responseSignatureValidated bool
|
|
if !sp.SkipSignatureValidation {
|
|
el, err = sp.validateElementSignature(el)
|
|
if err == dsig.ErrMissingSignature {
|
|
// Unfortunately we just blew away our Response
|
|
el = doc.Root()
|
|
} else if err != nil {
|
|
return nil, err
|
|
} else if el == nil {
|
|
return nil, fmt.Errorf("missing transformed response")
|
|
} else {
|
|
responseSignatureValidated = true
|
|
}
|
|
}
|
|
|
|
err = sp.decryptAssertions(el)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var assertionSignaturesValidated bool
|
|
if !sp.SkipSignatureValidation {
|
|
err = sp.validateAssertionSignatures(el)
|
|
if err == dsig.ErrMissingSignature {
|
|
if !responseSignatureValidated {
|
|
return nil, fmt.Errorf("response and/or assertions must be signed")
|
|
}
|
|
} else if err != nil {
|
|
return nil, err
|
|
} else {
|
|
assertionSignaturesValidated = true
|
|
}
|
|
}
|
|
|
|
decodedResponse := &types.Response{}
|
|
err = xmlUnmarshalElement(el, decodedResponse)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to unmarshal response: %v", err)
|
|
}
|
|
decodedResponse.SignatureValidated = responseSignatureValidated
|
|
if assertionSignaturesValidated {
|
|
for idx := 0; idx < len(decodedResponse.Assertions); idx++ {
|
|
decodedResponse.Assertions[idx].SignatureValidated = true
|
|
}
|
|
}
|
|
|
|
err = sp.Validate(decodedResponse)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return decodedResponse, nil
|
|
}
|
|
|
|
// DecodeUnverifiedBaseResponse decodes several attributes from a SAML response for the purpose
|
|
// of determining how to validate the response. This is useful for Service Providers which
|
|
// expose a single Assertion Consumer Service URL but consume Responses from many IdPs.
|
|
func DecodeUnverifiedBaseResponse(encodedResponse string) (*types.UnverifiedBaseResponse, error) {
|
|
raw, err := base64.StdEncoding.DecodeString(encodedResponse)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var response *types.UnverifiedBaseResponse
|
|
|
|
err = maybeDeflate(raw, defaultMaxDecompressedResponseSize, func(maybeXML []byte) error {
|
|
response = &types.UnverifiedBaseResponse{}
|
|
return xml.Unmarshal(maybeXML, response)
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// maybeDeflate invokes the passed decoder over the passed data. If an error is
|
|
// returned, it then attempts to deflate the passed data before re-invoking
|
|
// the decoder over the deflated data.
|
|
func maybeDeflate(data []byte, maxSize int64, decoder func([]byte) error) error {
|
|
err := decoder(data)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
// Default to 5MB max size
|
|
if maxSize == 0 {
|
|
maxSize = defaultMaxDecompressedResponseSize
|
|
}
|
|
|
|
lr := io.LimitReader(flate.NewReader(bytes.NewReader(data)), maxSize+1)
|
|
|
|
deflated, err := io.ReadAll(lr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if int64(len(deflated)) > maxSize {
|
|
return fmt.Errorf("deflated response exceeds maximum size of %d bytes", maxSize)
|
|
}
|
|
|
|
return decoder(deflated)
|
|
}
|
|
|
|
// parseResponse is a helper function that was refactored out so that the XML parsing behavior can be isolated and unit tested
|
|
func parseResponse(xml []byte, maxSize int64) (*etree.Document, *etree.Element, error) {
|
|
var doc *etree.Document
|
|
var rawXML []byte
|
|
|
|
err := maybeDeflate(xml, maxSize, func(xml []byte) error {
|
|
doc = etree.NewDocument()
|
|
rawXML = xml
|
|
return doc.ReadFromBytes(xml)
|
|
})
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
el := doc.Root()
|
|
if el == nil {
|
|
return nil, nil, fmt.Errorf("unable to parse response")
|
|
}
|
|
|
|
// Examine the response for attempts to exploit weaknesses in Go's encoding/xml
|
|
err = rtvalidator.Validate(bytes.NewReader(rawXML))
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
return doc, el, nil
|
|
}
|
|
|
|
// DecodeUnverifiedLogoutResponse decodes several attributes from a SAML Logout response, without doing any verifications.
|
|
func DecodeUnverifiedLogoutResponse(encodedResponse string) (*types.LogoutResponse, error) {
|
|
raw, err := base64.StdEncoding.DecodeString(encodedResponse)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var response *types.LogoutResponse
|
|
|
|
err = maybeDeflate(raw, defaultMaxDecompressedResponseSize, func(maybeXML []byte) error {
|
|
response = &types.LogoutResponse{}
|
|
return xml.Unmarshal(maybeXML, response)
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
func (sp *SAMLServiceProvider) ValidateEncodedLogoutResponsePOST(encodedResponse string) (*types.LogoutResponse, error) {
|
|
raw, err := base64.StdEncoding.DecodeString(encodedResponse)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse the raw response
|
|
doc, el, err := parseResponse(raw, sp.MaximumDecompressedBodySize)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var responseSignatureValidated bool
|
|
if !sp.SkipSignatureValidation {
|
|
el, err = sp.validateElementSignature(el)
|
|
if err == dsig.ErrMissingSignature {
|
|
// Unfortunately we just blew away our Response
|
|
el = doc.Root()
|
|
} else if err != nil {
|
|
return nil, err
|
|
} else if el == nil {
|
|
return nil, fmt.Errorf("missing transformed logout response")
|
|
} else {
|
|
responseSignatureValidated = true
|
|
}
|
|
}
|
|
|
|
decodedResponse := &types.LogoutResponse{}
|
|
err = xmlUnmarshalElement(el, decodedResponse)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to unmarshal logout response: %v", err)
|
|
}
|
|
decodedResponse.SignatureValidated = responseSignatureValidated
|
|
|
|
err = sp.ValidateDecodedLogoutResponse(decodedResponse)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return decodedResponse, nil
|
|
}
|