#!/usr/bin/env python

# This script works around a common issue where xmlsec1 is not compiled with
# support for the `<RetrievalMethod>` to discover where the encrypted key is
# located, which causes problems for SAML2 responses that might place the
# `<EncryptedKey>` outside of the `<EncryptedData>` tree. This script avoids
# this issue by manually resolving the `<RetrievalMethod>`s and moving the
# `<EncryptedKey>` under the `<EncryptedData>`.
#
# Note: If xmlsec1 is not on the PATH, its location can be specified explicitly
# with the environment variable `XMLSEC1_PATH=...`.

import os
import subprocess
import sys
import tempfile

try:
    from lxml import etree
except ImportError:
    try:
        # Python 2.5+: C-accelerated ElementTree from the standard library
        import xml.etree.cElementTree as etree
    except ImportError:
        try:
            # Python 2.5+: pure-Python ElementTree from the standard library
            import xml.etree.ElementTree as etree
        except ImportError:
            try:
                # normal cElementTree install
                import cElementTree as etree
            except ImportError:
                # normal ElementTree install
                import elementtree.ElementTree as etree

NS_SAML = 'urn:oasis:names:tc:SAML:2.0:assertion'
NS_XENC = 'http://www.w3.org/2001/04/xmlenc#'
NS_DS = 'http://www.w3.org/2000/09/xmldsig#'

NSMAP = {
    'saml': NS_SAML,
    'ds': NS_DS,
    'xenc': NS_XENC
}

XMLSEC1_PATH = os.environ.get('XMLSEC1_PATH', 'xmlsec1')


def run_xmlsec(args):
    cmd = [XMLSEC1_PATH] + args
    p = subprocess.Popen(cmd)
    return p.wait()


def modify_assertion(args, root):
    encrypted_assertion_nodes = root.xpath(
        '//saml:EncryptedAssertion',
        namespaces=NSMAP)

    if encrypted_assertion_nodes:
        encrypted_data_nodes = encrypted_assertion_nodes[0].xpath(
            '//saml:EncryptedAssertion/xenc:EncryptedData',
            namespaces=NSMAP)

        if encrypted_data_nodes:
            keyinfo = encrypted_assertion_nodes[0].xpath(
                '//saml:EncryptedAssertion/xenc:EncryptedData/ds:KeyInfo',
                namespaces=NSMAP)

            if not keyinfo:
                raise Exception('No KeyInfo present, invalid Assertion')

            keyinfo = keyinfo[0]
            children = keyinfo.getchildren()
            if not children:
                raise Exception('No child to KeyInfo, invalid Assertion')

            for child in children:
                if 'RetrievalMethod' in child.tag:
                    if child.attrib['Type'] != \
                            'http://www.w3.org/2001/04/xmlenc#EncryptedKey':
                        raise Exception('Unsupported Retrieval Method found')

                    uri = child.attrib['URI']
                    if not uri.startswith('#'):
                        break

                    uri = uri.split('#')[1]
                    encrypted_key = encrypted_assertion_nodes[0].xpath(
                        './xenc:EncryptedKey[@Id="' + uri + '"]',
                        namespaces=NSMAP)

                    if encrypted_key:
                        keyinfo.append(encrypted_key[0])

    with tempfile.NamedTemporaryFile(suffix='.xml') as f:
        xml = etree.tostring(root,
                             pretty_print=True,
                             xml_declaration=True,
                             encoding='utf-8')
        f.write(xml)
        f.flush()

        return run_xmlsec(args + [f.name])


def main():
    args = sys.argv[1:]

    if len(args) > 0:
        # Only try to modify the assertion if we are decrypting.
        if args[0] == '--decrypt':
            filename = args[-1]

            with open(filename) as f:
                parser = etree.XMLParser(remove_blank_text=True)
                root = etree.parse(f, parser)
                return modify_assertion(args[:-1], root)

    # Otherwise, fall back to just running xmlsec.
    return run_xmlsec(args)


if __name__ == '__main__':
    sys.exit(main())