"""SAML AuthnRequest Processor"""
from base64 import b64encode
from typing import Dict
from urllib.parse import quote_plus

import xmlsec
from django.http import HttpRequest
from lxml import etree  # nosec
from lxml.etree import Element  # nosec

from passbook.providers.saml.utils import get_random_id
from passbook.providers.saml.utils.encoding import deflate_and_base64_encode
from passbook.providers.saml.utils.time import get_time_string
from passbook.sources.saml.models import SAMLSource
from passbook.sources.saml.processors.constants import (
    DIGEST_ALGORITHM_TRANSLATION_MAP,
    NS_MAP,
    NS_SAML_ASSERTION,
    NS_SAML_PROTOCOL,
    SIGN_ALGORITHM_TRANSFORM_MAP,
)

SESSION_REQUEST_ID = "passbook_source_saml_request_id"


class RequestProcessor:
    """SAML AuthnRequest Processor

    Builds the AuthnRequest this SAML source sends to the IdP, either as an XML
    string with an enveloped signature (POST binding, ``build_auth_n``) or as a
    dict of query parameters with a detached signature (Redirect binding,
    ``build_auth_n_detached``). Signing only happens when the source has a
    ``signing_kp`` configured.
    """

    source: SAMLSource
    http_request: HttpRequest

    relay_state: str

    request_id: str
    issue_instant: str

    def __init__(self, source: SAMLSource, request: HttpRequest, relay_state: str):
        self.source = source
        self.http_request = request
        self.relay_state = relay_state
        self.request_id = get_random_id()
        # Persist the generated request ID in the session. Presumably read back
        # when the IdP's response arrives (InResponseTo matching) -- verify in
        # the response processor.
        self.http_request.session[SESSION_REQUEST_ID] = self.request_id
        self.issue_instant = get_time_string()

    def get_issuer(self) -> Element:
        """Get Issuer Element (saml:Issuer carrying the source's issuer value)."""
        issuer = Element(f"{{{NS_SAML_ASSERTION}}}Issuer")
        issuer.text = self.source.get_issuer(self.http_request)
        return issuer

    def get_name_id_policy(self) -> Element:
        """Get NameID Policy Element with Format set from the source config."""
        name_id_policy = Element(f"{{{NS_SAML_PROTOCOL}}}NameIDPolicy")
        name_id_policy.attrib["Format"] = self.source.name_id_policy
        return name_id_policy

    def _get_sign_algorithm_transform(self):
        """Return the xmlsec transform for the source's signature_algorithm.

        Falls back to RSA-SHA1 when the configured algorithm is not present in
        ``SIGN_ALGORITHM_TRANSFORM_MAP``.
        """
        return SIGN_ALGORITHM_TRANSFORM_MAP.get(
            self.source.signature_algorithm, xmlsec.constants.TransformRsaSha1
        )

    def _load_signing_key(self) -> xmlsec.Key:
        """Load the source's signing keypair (PEM private key plus PEM certificate).

        Only valid when ``self.source.signing_kp`` is set.
        """
        key = xmlsec.Key.from_memory(
            self.source.signing_kp.key_data, xmlsec.constants.KeyDataFormatPem, None
        )
        key.load_cert_from_memory(
            self.source.signing_kp.certificate_data,
            xmlsec.constants.KeyDataFormatCertPem,
        )
        return key

    def get_auth_n(self, *, include_signature: bool = True) -> Element:
        """Get full AuthnRequest

        Args:
            include_signature: when True (the default) and the source has a
                signing keypair, an empty ``ds:Signature`` template is appended
                right after the Issuer so ``build_auth_n`` can fill it in. Pass
                False for bindings that carry the signature outside the XML.

        Returns:
            The unsigned samlp:AuthnRequest element.
        """
        auth_n_request = Element(f"{{{NS_SAML_PROTOCOL}}}AuthnRequest", nsmap=NS_MAP)
        auth_n_request.attrib[
            "AssertionConsumerServiceURL"
        ] = self.source.build_full_url(self.http_request)
        auth_n_request.attrib["Destination"] = self.source.sso_url
        auth_n_request.attrib["ID"] = self.request_id
        auth_n_request.attrib["IssueInstant"] = self.issue_instant
        auth_n_request.attrib["ProtocolBinding"] = self.source.binding_type
        auth_n_request.attrib["Version"] = "2.0"
        # Create issuer object
        auth_n_request.append(self.get_issuer())

        # The SAML protocol schema orders children as Issuer, Signature, ...,
        # NameIDPolicy, so the signature template must go between them.
        if include_signature and self.source.signing_kp:
            signature = xmlsec.template.create(
                auth_n_request,
                xmlsec.constants.TransformExclC14N,
                self._get_sign_algorithm_transform(),
                ns="ds",  # type: ignore
            )
            auth_n_request.append(signature)

        # Create NameID Policy Object
        auth_n_request.append(self.get_name_id_policy())
        return auth_n_request

    def build_auth_n(self) -> str:
        """Get Signed string representation of AuthN Request (used for POST Bindings)

        When the source has a signing keypair, the Signature template added by
        ``get_auth_n`` is completed with a reference to the request ID
        (enveloped + exclusive C14N transforms) and X509 key info, then signed.
        """
        auth_n_request = self.get_auth_n()

        if self.source.signing_kp:
            # Register "ID" as an ID attribute so the "#<ID>" reference resolves.
            xmlsec.tree.add_ids(auth_n_request, ["ID"])

            ctx = xmlsec.SignatureContext()
            ctx.key = self._load_signing_key()

            digest_algorithm_transform = DIGEST_ALGORITHM_TRANSLATION_MAP.get(
                self.source.digest_algorithm, xmlsec.constants.TransformSha1
            )

            signature_node = xmlsec.tree.find_node(
                auth_n_request, xmlsec.constants.NodeSignature
            )

            ref = xmlsec.template.add_reference(
                signature_node,
                digest_algorithm_transform,
                uri="#" + auth_n_request.attrib["ID"],
            )
            xmlsec.template.add_transform(ref, xmlsec.constants.TransformEnveloped)
            xmlsec.template.add_transform(ref, xmlsec.constants.TransformExclC14N)
            key_info = xmlsec.template.ensure_key_info(signature_node)
            xmlsec.template.add_x509_data(key_info)

            ctx.sign(signature_node)

        return etree.tostring(auth_n_request).decode()

    def build_auth_n_detached(self) -> Dict[str, str]:
        """Get Dict AuthN Request for Redirect bindings, with detached Signature.
        See https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf

        Returns a dict with "SAMLRequest" (deflated, base64-encoded XML), plus
        "RelayState" when one is set, plus "Signature" and "SigAlg" when the
        source has a signing keypair.
        """
        # In the Redirect binding the signature travels in the query string and
        # any <ds:Signature> element MUST be removed from the XML itself, so do
        # not emit the (otherwise empty) signature template.
        auth_n_request = self.get_auth_n(include_signature=False)

        saml_request = deflate_and_base64_encode(
            etree.tostring(auth_n_request).decode()
        )

        response_dict = {
            "SAMLRequest": saml_request,
        }

        # Skip both "" and None so quote_plus below never sees a non-string.
        if self.relay_state:
            response_dict["RelayState"] = self.relay_state

        if self.source.signing_kp:
            # The signed octets are "SAMLRequest=..&RelayState=..&SigAlg=.." in
            # exactly this order. NOTE(review): assumes the caller URL-encodes
            # response_dict with quote_plus semantics (e.g. urlencode) so the
            # signed string matches the URL the IdP receives -- confirm in view.
            querystring = f"SAMLRequest={quote_plus(saml_request)}&"
            if "RelayState" in response_dict:
                querystring += f"RelayState={quote_plus(response_dict['RelayState'])}&"
            querystring += f"SigAlg={quote_plus(self.source.signature_algorithm)}"

            ctx = xmlsec.SignatureContext()
            ctx.key = self._load_signing_key()

            signature = ctx.sign_binary(
                querystring.encode("utf-8"), self._get_sign_algorithm_transform()
            )
            response_dict["Signature"] = b64encode(signature).decode()
            response_dict["SigAlg"] = self.source.signature_algorithm

        return response_dict