diff --git a/passbook/flows/views.py b/passbook/flows/views.py index 499e7413c..856da2b7b 100644 --- a/passbook/flows/views.py +++ b/passbook/flows/views.py @@ -29,6 +29,7 @@ LOGGER = get_logger() # Argument used to redirect user after login NEXT_ARG_NAME = "next" SESSION_KEY_PLAN = "passbook_flows_plan" +SESSION_KEY_APPLICATION_PRE = "passbook_flows_application_pre" SESSION_KEY_GET = "passbook_flows_get" @@ -198,8 +199,14 @@ class FlowExecutorView(View): def cancel(self): """Cancel current execution and return a redirect""" - if SESSION_KEY_PLAN in self.request.session: - del self.request.session[SESSION_KEY_PLAN] + keys_to_delete = [ + SESSION_KEY_APPLICATION_PRE, + SESSION_KEY_PLAN, + SESSION_KEY_GET, + ] + for key in keys_to_delete: + if key in self.request.session: + del self.request.session[key] class FlowPermissionDeniedView(PermissionDeniedView): diff --git a/passbook/policies/mixins.py b/passbook/policies/mixins.py index 14e6f6e6c..4712b0df0 100644 --- a/passbook/policies/mixins.py +++ b/passbook/policies/mixins.py @@ -3,12 +3,14 @@ from typing import Optional from django.contrib import messages from django.contrib.auth.mixins import AccessMixin +from django.contrib.auth.views import redirect_to_login from django.http import HttpRequest, HttpResponse from django.shortcuts import redirect from django.utils.translation import gettext as _ from structlog import get_logger from passbook.core.models import Application, Provider, User +from passbook.flows.views import SESSION_KEY_APPLICATION_PRE from passbook.policies.engine import PolicyEngine from passbook.policies.types import PolicyResult @@ -25,6 +27,15 @@ class PolicyAccessMixin(BaseMixin, AccessMixin): """Mixin class for usage in Authorization views. Provider functions to check application access, etc""" + def handle_no_permission(self, application: Optional[Application] = None): + if application: + self.request.session[SESSION_KEY_APPLICATION_PRE] = application + return redirect_to_login( + self.request.get_full_path(), + self.get_login_url(), + self.get_redirect_field_name(), + ) + def handle_no_permission_authorized(self) -> HttpResponse: """Function called when user has no permissions but is authorized""" # TODO: Remove this URL and render the view instead diff --git a/passbook/providers/oauth/views/oauth2.py b/passbook/providers/oauth/views/oauth2.py index b6a122352..1291a2a16 100644 --- a/passbook/providers/oauth/views/oauth2.py +++ b/passbook/providers/oauth/views/oauth2.py @@ -49,6 +49,10 @@ class AuthorizationFlowInitView(PolicyAccessMixin, LoginRequiredMixin, View): application = self.provider_to_application(provider) except Application.DoesNotExist: return self.handle_no_permission_authorized() + # Check if user is unauthenticated, so we pass the application + # for the identification stage + if not request.user.is_authenticated: + return self.handle_no_permission(application) # Check permissions result = self.user_has_access(application) if not result.passing: diff --git a/passbook/providers/oidc/views.py b/passbook/providers/oidc/views.py index 5bdbe7351..1640445d3 100644 --- a/passbook/providers/oidc/views.py +++ b/passbook/providers/oidc/views.py @@ -42,6 +42,10 @@ class AuthorizationFlowInitView(PolicyAccessMixin, LoginRequiredMixin, View): application = self.provider_to_application(provider) except Application.DoesNotExist: return self.handle_no_permission_authorized() + # Check if user is unauthenticated, so we pass the application + # for the identification stage + if not request.user.is_authenticated: + return self.handle_no_permission(application) # Check permissions result = self.user_has_access(application) if not result.passing: diff --git a/passbook/providers/saml/views.py b/passbook/providers/saml/views.py index a43f37c8e..7d4751f55 100644 --- a/passbook/providers/saml/views.py +++ b/passbook/providers/saml/views.py @@ -47,7 +47,7 @@ REQUEST_KEY_RELAY_STATE = "RelayState" SESSION_KEY_AUTH_N_REQUEST = "authn_request" -class SAMLSSOView(LoginRequiredMixin, PolicyAccessMixin, View): +class SAMLSSOView(PolicyAccessMixin, LoginRequiredMixin, View): """"SAML SSO Base View, which plans a flow and injects our final stage. Calls get/post handler.""" @@ -62,7 +62,7 @@ class SAMLSSOView(LoginRequiredMixin, PolicyAccessMixin, View): SAMLProvider, pk=self.application.provider_id ) if not request.user.is_authenticated: - return self.handle_no_permission() + return self.handle_no_permission(self.application) if not self.user_has_access(self.application).passing: return self.handle_no_permission_authorized() # Call the method handler, which checks the SAML Request diff --git a/passbook/stages/identification/stage.py b/passbook/stages/identification/stage.py index bc3375843..0394a9e47 100644 --- a/passbook/stages/identification/stage.py +++ b/passbook/stages/identification/stage.py @@ -12,6 +12,7 @@ from structlog import get_logger from passbook.core.models import Source, User from passbook.flows.planner import PLAN_CONTEXT_PENDING_USER from passbook.flows.stage import StageView +from passbook.flows.views import SESSION_KEY_APPLICATION_PRE from passbook.stages.identification.forms import IdentificationForm from passbook.stages.identification.models import IdentificationStage @@ -34,6 +35,12 @@ class IdentificationStageView(FormView, StageView): def get_context_data(self, **kwargs): current_stage: IdentificationStage = self.executor.current_stage + # If the user has been redirected to us whilst trying to access an + # application, SESSION_KEY_APPLICATION_PRE is set in the session + if SESSION_KEY_APPLICATION_PRE in self.request.session: + kwargs["application_pre"] = self.request.session[ + SESSION_KEY_APPLICATION_PRE + ] # Check for related enrollment and recovery flow, add URL to view if current_stage.enrollment_flow: kwargs["enroll_url"] = reverse( diff --git a/passbook/stages/identification/templates/stages/identification/login.html b/passbook/stages/identification/templates/stages/identification/login.html index 95079ae13..86c7fe8a5 100644 --- a/passbook/stages/identification/templates/stages/identification/login.html +++ b/passbook/stages/identification/templates/stages/identification/login.html @@ -11,6 +11,13 @@
{% block above_form %} + {% if application_pre %} +

+ {% blocktrans with app_name=application_pre.name %} + Login to continue to {{ app_name }}. + {% endblocktrans %} +

+ {% endif %} {% endblock %} {% include 'partials/form.html' %}