diff --git a/e2e/test_source_oauth.py b/e2e/test_source_oauth.py index 696833e64..7af40227b 100644 --- a/e2e/test_source_oauth.py +++ b/e2e/test_source_oauth.py @@ -1,6 +1,7 @@ """test OAuth Source""" from os.path import abspath from sys import platform +from time import sleep from typing import Any, Dict, Optional from unittest.case import skipUnless @@ -200,10 +201,11 @@ class TestSourceOAuth(SeleniumTestCase): (By.CLASS_NAME, "pf-c-login__main-footer-links-item-link") ) ) + sleep(1) self.driver.find_element( By.CLASS_NAME, "pf-c-login__main-footer-links-item-link" ).click() - + sleep(1) # Now we should be at the IDP, wait for the login field self.wait.until(ec.presence_of_element_located((By.ID, "login"))) self.driver.find_element(By.ID, "login").send_keys("admin@example.com") diff --git a/passbook/sources/saml/processors/response.py b/passbook/sources/saml/processors/response.py index 51f0f443e..d1b2d98c8 100644 --- a/passbook/sources/saml/processors/response.py +++ b/passbook/sources/saml/processors/response.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Dict from defusedxml import ElementTree +from django.core.cache import cache +from django.core.exceptions import SuspiciousOperation from django.http import HttpRequest, HttpResponse from signxml import XMLVerifier from structlog import get_logger @@ -37,6 +39,8 @@ from passbook.stages.prompt.stage import PLAN_CONTEXT_PROMPT LOGGER = get_logger() if TYPE_CHECKING: from xml.etree.ElementTree import Element # nosec + +CACHE_SEEN_REQUEST_ID = "passbook_saml_seen_ids_%s" DEFAULT_BACKEND = "django.contrib.auth.backends.ModelBackend" @@ -75,6 +79,13 @@ class ResponseProcessor: def _verify_request_id(self, request: HttpRequest): if self._source.allow_idp_initiated: + # If IdP-initiated SSO flows are enabled, we want to cache the Response ID + # somewhat mitigate replay attacks + seen_ids = cache.get(CACHE_SEEN_REQUEST_ID % self._source.pk, []) + if self._root.attrib["ID"] in seen_ids: + raise SuspiciousOperation("Replay attack detected") + seen_ids.append(self._root.attrib["ID"]) + cache.set(CACHE_SEEN_REQUEST_ID % self._source.pk, seen_ids) return if ( SESSION_REQUEST_ID not in request.session