diff --git a/e2e/test_enroll_2_step.py b/e2e/test_enroll_2_step.py index 143bf1cfc..111284475 100644 --- a/e2e/test_enroll_2_step.py +++ b/e2e/test_enroll_2_step.py @@ -6,7 +6,7 @@ from selenium.webdriver.common.desired_capabilities import DesiredCapabilities from selenium.webdriver.support import expected_conditions as ec from selenium.webdriver.support.ui import WebDriverWait -from e2e.utils import apply_default_data +from e2e.utils import SeleniumTestCase from passbook.flows.models import Flow, FlowDesignation, FlowStageBinding from passbook.policies.expression.models import ExpressionPolicy from passbook.policies.models import PolicyBinding @@ -16,22 +16,9 @@ from passbook.stages.user_login.models import UserLoginStage from passbook.stages.user_write.models import UserWriteStage -class TestEnroll2Step(StaticLiveServerTestCase): +class TestEnroll2Step(SeleniumTestCase): """Test 2-step enroll flow""" - def setUp(self): - self.driver = webdriver.Remote( - command_executor="http://localhost:4444/wd/hub", - desired_capabilities=DesiredCapabilities.CHROME, - ) - self.wait = WebDriverWait(self.driver, 10) - self.driver.implicitly_wait(5) - apply_default_data() - - def tearDown(self): - super().tearDown() - self.driver.quit() - def test_enroll_2_step(self): """Test 2-step enroll flow""" # First stage fields diff --git a/e2e/test_login_default.py b/e2e/test_login_default.py index c8bf5f445..0925d7b3f 100644 --- a/e2e/test_login_default.py +++ b/e2e/test_login_default.py @@ -1,28 +1,15 @@ """test default login flow""" -from django.contrib.staticfiles.testing import StaticLiveServerTestCase from selenium import webdriver from selenium.webdriver.common.by import By from selenium.webdriver.common.desired_capabilities import DesiredCapabilities from selenium.webdriver.common.keys import Keys -from e2e.utils import apply_default_data +from e2e.utils import SeleniumTestCase -class TestLogin(StaticLiveServerTestCase): +class TestLogin(SeleniumTestCase): """test default login flow""" - def setUp(self): - self.driver = webdriver.Remote( - command_executor="http://localhost:4444/wd/hub", - desired_capabilities=DesiredCapabilities.CHROME, - ) - self.driver.implicitly_wait(5) - apply_default_data() - - def tearDown(self): - super().tearDown() - self.driver.quit() - def test_login(self): """test default login flow""" self.driver.get(f"{self.live_server_url}/flows/default-authentication-flow/") diff --git a/e2e/test_provider_oidc.py b/e2e/test_provider_oidc.py index 52c9fa68d..127f6fd40 100644 --- a/e2e/test_provider_oidc.py +++ b/e2e/test_provider_oidc.py @@ -12,22 +12,17 @@ from selenium.webdriver.common.keys import Keys from docker import DockerClient, from_env from docker.models.containers import Container from docker.types import Healthcheck -from e2e.utils import apply_default_data, ensure_rsa_key +from e2e.utils import SeleniumTestCase, ensure_rsa_key from passbook.core.models import Application from passbook.flows.models import Flow from passbook.providers.oidc.models import OpenIDProvider -class TestProviderOIDC(StaticLiveServerTestCase): +class TestProviderOIDC(SeleniumTestCase): """test OpenID Provider flow""" def setUp(self): - self.driver = webdriver.Remote( - command_executor="http://localhost:4444/wd/hub", - desired_capabilities=DesiredCapabilities.CHROME, - ) - self.driver.implicitly_wait(5) - apply_default_data() + super().setUp() self.client_id = generate_client_id() self.client_secret = generate_client_secret() self.container = self.setup_client() diff --git a/e2e/utils.py b/e2e/utils.py index 5434e7da5..61fdf9b1b 100644 --- a/e2e/utils.py +++ b/e2e/utils.py @@ -1,41 +1,17 @@ """passbook e2e testing utilities""" - from glob import glob from importlib.util import module_from_spec, spec_from_file_location from inspect import getmembers, isfunction from Cryptodome.PublicKey import RSA from django.apps import apps +from django.contrib.staticfiles.testing import StaticLiveServerTestCase from django.db import connection, transaction from django.db.utils import IntegrityError - - -def apply_default_data(): - """apply objects created by migrations after tables have been truncated""" - # Find all migration files - # load all functions - migration_files = glob("**/migrations/*.py", recursive=True) - matches = [] - for migration in migration_files: - with open(migration, "r+") as migration_file: - # Check if they have a `RunPython` - if "RunPython" in migration_file.read(): - matches.append(migration) - - with connection.schema_editor() as schema_editor: - for match in matches: - # Load module from file path - spec = spec_from_file_location("", match) - migration_module = module_from_spec(spec) - # pyright: reportGeneralTypeIssues=false - spec.loader.exec_module(migration_module) - # Call all functions from module - for _, func in getmembers(migration_module, isfunction): - with transaction.atomic(): - try: - func(apps, schema_editor) - except IntegrityError: - pass +from selenium import webdriver +from selenium.webdriver.common.desired_capabilities import DesiredCapabilities +from selenium.webdriver.remote.webdriver import WebDriver +from selenium.webdriver.support.ui import WebDriverWait def ensure_rsa_key(): @@ -46,3 +22,49 @@ def ensure_rsa_key(): key = RSA.generate(2048) rsakey = RSAKey(key=key.exportKey("PEM").decode("utf8")) rsakey.save() + + +class SeleniumTestCase(StaticLiveServerTestCase): + def setUp(self): + super().setUp() + self.driver = self._get_driver() + self.driver.implicitly_wait(5) + self.wait = WebDriverWait(self.driver, 10) + self.apply_default_data() + + def _get_driver(self) -> WebDriver: + return webdriver.Remote( + command_executor="http://localhost:4444/wd/hub", + desired_capabilities=DesiredCapabilities.CHROME, + ) + + def tearDown(self): + super().tearDown() + self.driver.quit() + + def apply_default_data(self): + """apply objects created by migrations after tables have been truncated""" + # Find all migration files + # load all functions + migration_files = glob("**/migrations/*.py", recursive=True) + matches = [] + for migration in migration_files: + with open(migration, "r+") as migration_file: + # Check if they have a `RunPython` + if "RunPython" in migration_file.read(): + matches.append(migration) + + with connection.schema_editor() as schema_editor: + for match in matches: + # Load module from file path + spec = spec_from_file_location("", match) + migration_module = module_from_spec(spec) + # pyright: reportGeneralTypeIssues=false + spec.loader.exec_module(migration_module) + # Call all functions from module + for _, func in getmembers(migration_module, isfunction): + with transaction.atomic(): + try: + func(apps, schema_editor) + except IntegrityError: + pass diff --git a/passbook/flows/views.py b/passbook/flows/views.py index 1625e2c01..eea803b43 100644 --- a/passbook/flows/views.py +++ b/passbook/flows/views.py @@ -215,7 +215,9 @@ class FlowExecutorShellView(TemplateView): kwargs["exec_url"] = reverse("passbook_flows:flow-executor", kwargs=self.kwargs) kwargs["msg_url"] = reverse("passbook_api:messages-list") if NEXT_ARG_NAME in self.request.GET: - self.request.session[SESSION_KEY_NEXT] = self.request.GET[NEXT_ARG_NAME] + next_arg = self.request.GET[NEXT_ARG_NAME] + LOGGER.debug("f(exec/shell): Saved next param", next=next_arg) + self.request.session[SESSION_KEY_NEXT] = next_arg return kwargs