policies/*: remove Policy.negate, order, timeout (#39)
policies: rewrite engine to use PolicyBinding for order/negate/timeout policies: rewrite engine to use PolicyResult instead of tuple
This commit is contained in:
parent
fdfc6472d2
commit
df8995deed
|
@ -1,6 +1,7 @@
|
|||
"""passbook administration forms"""
|
||||
from django import forms
|
||||
|
||||
from passbook.admin.fields import CodeMirrorWidget, YAMLField
|
||||
from passbook.core.models import User
|
||||
|
||||
|
||||
|
@ -8,3 +9,4 @@ class PolicyTestForm(forms.Form):
|
|||
"""Form to test policies against user"""
|
||||
|
||||
user = forms.ModelChoiceField(queryset=User.objects.all())
|
||||
context = YAMLField(widget=CodeMirrorWidget())
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
"""passbook Policy administration"""
|
||||
from typing import Any, Dict
|
||||
|
||||
from django.contrib import messages
|
||||
from django.contrib.auth.mixins import LoginRequiredMixin
|
||||
from django.contrib.auth.mixins import (
|
||||
PermissionRequiredMixin as DjangoPermissionRequiredMixin,
|
||||
)
|
||||
from django.contrib.messages.views import SuccessMessageMixin
|
||||
from django.http import Http404
|
||||
from django.db.models import QuerySet
|
||||
from django.forms import Form
|
||||
from django.http import Http404, HttpRequest, HttpResponse
|
||||
from django.urls import reverse_lazy
|
||||
from django.utils.translation import ugettext as _
|
||||
from django.views.generic import DeleteView, FormView, ListView, UpdateView
|
||||
|
@ -15,8 +19,8 @@ from guardian.mixins import PermissionListMixin, PermissionRequiredMixin
|
|||
from passbook.admin.forms.policies import PolicyTestForm
|
||||
from passbook.lib.utils.reflection import all_subclasses, path_to_class
|
||||
from passbook.lib.views import CreateAssignPermView
|
||||
from passbook.policies.engine import PolicyEngine
|
||||
from passbook.policies.models import Policy
|
||||
from passbook.policies.models import Policy, PolicyBinding
|
||||
from passbook.policies.process import PolicyProcess, PolicyRequest
|
||||
|
||||
|
||||
class PolicyListView(LoginRequiredMixin, PermissionListMixin, ListView):
|
||||
|
@ -25,14 +29,14 @@ class PolicyListView(LoginRequiredMixin, PermissionListMixin, ListView):
|
|||
model = Policy
|
||||
permission_required = "passbook_policies.view_policy"
|
||||
paginate_by = 10
|
||||
ordering = "order"
|
||||
ordering = "name"
|
||||
template_name = "administration/policy/list.html"
|
||||
|
||||
def get_context_data(self, **kwargs):
|
||||
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
kwargs["types"] = {x.__name__: x for x in all_subclasses(Policy)}
|
||||
return super().get_context_data(**kwargs)
|
||||
|
||||
def get_queryset(self):
|
||||
def get_queryset(self) -> QuerySet:
|
||||
return super().get_queryset().select_subclasses()
|
||||
|
||||
|
||||
|
@ -51,14 +55,14 @@ class PolicyCreateView(
|
|||
success_url = reverse_lazy("passbook_admin:policies")
|
||||
success_message = _("Successfully created Policy")
|
||||
|
||||
def get_context_data(self, **kwargs):
|
||||
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
kwargs = super().get_context_data(**kwargs)
|
||||
form_cls = self.get_form_class()
|
||||
if hasattr(form_cls, "template_name"):
|
||||
kwargs["base_template"] = form_cls.template_name
|
||||
return kwargs
|
||||
|
||||
def get_form_class(self):
|
||||
def get_form_class(self) -> Form:
|
||||
policy_type = self.request.GET.get("type")
|
||||
try:
|
||||
model = next(x for x in all_subclasses(Policy) if x.__name__ == policy_type)
|
||||
|
@ -79,19 +83,19 @@ class PolicyUpdateView(
|
|||
success_url = reverse_lazy("passbook_admin:policies")
|
||||
success_message = _("Successfully updated Policy")
|
||||
|
||||
def get_context_data(self, **kwargs):
|
||||
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
kwargs = super().get_context_data(**kwargs)
|
||||
form_cls = self.get_form_class()
|
||||
if hasattr(form_cls, "template_name"):
|
||||
kwargs["base_template"] = form_cls.template_name
|
||||
return kwargs
|
||||
|
||||
def get_form_class(self):
|
||||
def get_form_class(self) -> Form:
|
||||
form_class_path = self.get_object().form
|
||||
form_class = path_to_class(form_class_path)
|
||||
return form_class
|
||||
|
||||
def get_object(self, queryset=None):
|
||||
def get_object(self, queryset=None) -> Policy:
|
||||
return (
|
||||
Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first()
|
||||
)
|
||||
|
@ -109,12 +113,12 @@ class PolicyDeleteView(
|
|||
success_url = reverse_lazy("passbook_admin:policies")
|
||||
success_message = _("Successfully deleted Policy")
|
||||
|
||||
def get_object(self, queryset=None):
|
||||
def get_object(self, queryset=None) -> Policy:
|
||||
return (
|
||||
Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first()
|
||||
)
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
def delete(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||
messages.success(self.request, self.success_message)
|
||||
return super().delete(request, *args, **kwargs)
|
||||
|
||||
|
@ -128,26 +132,29 @@ class PolicyTestView(LoginRequiredMixin, DetailView, PermissionRequiredMixin, Fo
|
|||
template_name = "administration/policy/test.html"
|
||||
object = None
|
||||
|
||||
def get_object(self, queryset=None):
|
||||
def get_object(self, queryset=None) -> QuerySet:
|
||||
return (
|
||||
Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first()
|
||||
)
|
||||
|
||||
def get_context_data(self, **kwargs):
|
||||
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
kwargs["policy"] = self.get_object()
|
||||
return super().get_context_data(**kwargs)
|
||||
|
||||
def post(self, *args, **kwargs):
|
||||
def post(self, *args, **kwargs) -> HttpResponse:
|
||||
self.object = self.get_object()
|
||||
return super().post(*args, **kwargs)
|
||||
|
||||
def form_valid(self, form):
|
||||
def form_valid(self, form: PolicyTestForm) -> HttpResponse:
|
||||
policy = self.get_object()
|
||||
user = form.cleaned_data.get("user")
|
||||
policy_engine = PolicyEngine([policy], user, self.request)
|
||||
policy_engine.use_cache = False
|
||||
policy_engine.build()
|
||||
result = policy_engine.passing
|
||||
|
||||
p_request = PolicyRequest(user)
|
||||
p_request.http_request = self.request
|
||||
p_request.context = form.cleaned_data
|
||||
|
||||
proc = PolicyProcess(PolicyBinding(policy=policy), p_request, None)
|
||||
result = proc.execute()
|
||||
if result:
|
||||
messages.success(self.request, _("User successfully passed policy."))
|
||||
else:
|
||||
|
|
|
@ -17,12 +17,15 @@ password_changed = Signal(providing_args=["user", "password"])
|
|||
# pylint: disable=unused-argument
|
||||
def invalidate_policy_cache(sender, instance, **_):
|
||||
"""Invalidate Policy cache when policy is updated"""
|
||||
from passbook.policies.models import Policy
|
||||
from passbook.policies.models import Policy, PolicyBinding
|
||||
from passbook.policies.process import cache_key
|
||||
|
||||
if isinstance(instance, Policy):
|
||||
LOGGER.debug("Invalidating policy cache", policy=instance)
|
||||
prefix = cache_key(instance) + "*"
|
||||
keys = cache.keys(prefix)
|
||||
cache.delete_many(keys)
|
||||
LOGGER.debug("Deleted %d keys", len(keys))
|
||||
total = 0
|
||||
for binding in PolicyBinding.objects.filter(policy=instance):
|
||||
prefix = cache_key(binding) + "*"
|
||||
keys = cache.keys(prefix)
|
||||
total += len(keys)
|
||||
cache.delete_many(keys)
|
||||
LOGGER.debug("Deleted keys", len=total)
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
"""passbook access helper classes"""
|
||||
from typing import List, Tuple
|
||||
|
||||
from django.contrib import messages
|
||||
from django.http import HttpRequest
|
||||
from django.utils.translation import gettext as _
|
||||
|
@ -8,6 +6,7 @@ from structlog import get_logger
|
|||
|
||||
from passbook.core.models import Application, Provider, User
|
||||
from passbook.policies.engine import PolicyEngine
|
||||
from passbook.policies.types import PolicyResult
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
@ -33,9 +32,7 @@ class AccessMixin:
|
|||
)
|
||||
raise exc
|
||||
|
||||
def user_has_access(
|
||||
self, application: Application, user: User
|
||||
) -> Tuple[bool, List[str]]:
|
||||
def user_has_access(self, application: Application, user: User) -> PolicyResult:
|
||||
"""Check if user has access to application."""
|
||||
LOGGER.debug("Checking permissions", user=user, application=application)
|
||||
policy_engine = PolicyEngine(application.policies.all(), user, self.request)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""Flows Planner"""
|
||||
from dataclasses import dataclass, field
|
||||
from time import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.http import HttpRequest
|
||||
|
@ -11,6 +11,7 @@ from passbook.core.models import User
|
|||
from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException
|
||||
from passbook.flows.models import Flow, Stage
|
||||
from passbook.policies.engine import PolicyEngine
|
||||
from passbook.policies.types import PolicyResult
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
@ -51,8 +52,8 @@ class FlowPlanner:
|
|||
self.use_cache = True
|
||||
self.flow = flow
|
||||
|
||||
def _check_flow_root_policies(self, request: HttpRequest) -> Tuple[bool, List[str]]:
|
||||
engine = PolicyEngine(self.flow.policies.all(), request.user, request)
|
||||
def _check_flow_root_policies(self, request: HttpRequest) -> PolicyResult:
|
||||
engine = PolicyEngine(self.flow, request.user, request)
|
||||
engine.build()
|
||||
return engine.result
|
||||
|
||||
|
@ -64,9 +65,9 @@ class FlowPlanner:
|
|||
LOGGER.debug("f(plan): Starting planning process", flow=self.flow)
|
||||
# First off, check the flow's direct policy bindings
|
||||
# to make sure the user even has access to the flow
|
||||
root_passing, root_passing_messages = self._check_flow_root_policies(request)
|
||||
if not root_passing:
|
||||
raise FlowNonApplicableException(root_passing_messages)
|
||||
root_result = self._check_flow_root_policies(request)
|
||||
if not root_result.passing:
|
||||
raise FlowNonApplicableException(*root_result.messages)
|
||||
# Bit of a workaround here, if there is a pending user set in the default context
|
||||
# we use that user for our cache key
|
||||
# to make sure they don't get the generic response
|
||||
|
@ -106,11 +107,10 @@ class FlowPlanner:
|
|||
.select_related()
|
||||
):
|
||||
binding = stage.flowstagebinding_set.get(flow__pk=self.flow.pk)
|
||||
engine = PolicyEngine(binding.policies.all(), user, request)
|
||||
engine = PolicyEngine(binding, user, request)
|
||||
engine.request.context = plan.context
|
||||
engine.build()
|
||||
passing, _ = engine.result
|
||||
if passing:
|
||||
if engine.passing:
|
||||
LOGGER.debug("f(plan): Stage passing", stage=stage, flow=self.flow)
|
||||
plan.stages.append(stage)
|
||||
end_time = time()
|
||||
|
|
|
@ -8,9 +8,10 @@ from guardian.shortcuts import get_anonymous_user
|
|||
from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException
|
||||
from passbook.flows.models import Flow, FlowDesignation, FlowStageBinding
|
||||
from passbook.flows.planner import FlowPlanner
|
||||
from passbook.policies.types import PolicyResult
|
||||
from passbook.stages.dummy.models import DummyStage
|
||||
|
||||
POLICY_RESULT_MOCK = MagicMock(return_value=(False, [""],))
|
||||
POLICY_RESULT_MOCK = MagicMock(return_value=PolicyResult(False))
|
||||
TIME_NOW_MOCK = MagicMock(return_value=3)
|
||||
|
||||
|
||||
|
|
|
@ -9,9 +9,10 @@ from passbook.flows.models import Flow, FlowDesignation, FlowStageBinding
|
|||
from passbook.flows.planner import FlowPlan
|
||||
from passbook.flows.views import NEXT_ARG_NAME, SESSION_KEY_PLAN
|
||||
from passbook.lib.config import CONFIG
|
||||
from passbook.policies.types import PolicyResult
|
||||
from passbook.stages.dummy.models import DummyStage
|
||||
|
||||
POLICY_RESULT_MOCK = MagicMock(return_value=(False, [""],))
|
||||
POLICY_RESULT_MOCK = MagicMock(return_value=PolicyResult(False))
|
||||
|
||||
|
||||
class TestFlowExecutor(TestCase):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Generic models"""
|
||||
from django.db import models
|
||||
from model_utils.managers import InheritanceManager
|
||||
|
||||
|
||||
class CreatedUpdatedModel(models.Model):
|
||||
|
@ -10,3 +11,27 @@ class CreatedUpdatedModel(models.Model):
|
|||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
|
||||
class InheritanceAutoManager(InheritanceManager):
|
||||
"""Object manager which automatically selects the subclass"""
|
||||
|
||||
def get_queryset(self):
|
||||
return super().get_queryset().select_subclasses()
|
||||
|
||||
|
||||
class InheritanceForwardManyToOneDescriptor(
|
||||
models.fields.related.ForwardManyToOneDescriptor
|
||||
):
|
||||
"""Forward ManyToOne Descriptor that selects subclass. Requires InheritanceAutoManager."""
|
||||
|
||||
def get_queryset(self, **hints):
|
||||
return self.field.remote_field.model.objects.db_manager(
|
||||
hints=hints
|
||||
).select_subclasses()
|
||||
|
||||
|
||||
class InheritanceForeignKey(models.ForeignKey):
|
||||
"""Custom ForeignKey that uses InheritanceForwardManyToOneDescriptor"""
|
||||
|
||||
forward_related_accessor_class = InheritanceForwardManyToOneDescriptor
|
||||
|
|
|
@ -12,7 +12,7 @@ class PolicyBindingSerializer(ModelSerializer):
|
|||
class Meta:
|
||||
|
||||
model = PolicyBinding
|
||||
fields = ["policy", "target", "enabled", "order"]
|
||||
fields = ["policy", "target", "enabled", "order", "timeout"]
|
||||
|
||||
|
||||
class PolicyBindingViewSet(ModelViewSet):
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
"""passbook policy engine"""
|
||||
from multiprocessing import Pipe, set_start_method
|
||||
from multiprocessing.connection import Connection
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.http import HttpRequest
|
||||
from structlog import get_logger
|
||||
|
||||
from passbook.core.models import User
|
||||
from passbook.policies.models import Policy
|
||||
from passbook.policies.models import Policy, PolicyBinding, PolicyBindingModel
|
||||
from passbook.policies.process import PolicyProcess, cache_key
|
||||
from passbook.policies.types import PolicyRequest, PolicyResult
|
||||
|
||||
|
@ -24,12 +24,14 @@ class PolicyProcessInfo:
|
|||
process: PolicyProcess
|
||||
connection: Connection
|
||||
result: Optional[PolicyResult]
|
||||
policy: Policy
|
||||
binding: PolicyBinding
|
||||
|
||||
def __init__(self, process: PolicyProcess, connection: Connection, policy: Policy):
|
||||
def __init__(
|
||||
self, process: PolicyProcess, connection: Connection, binding: PolicyBinding
|
||||
):
|
||||
self.process = process
|
||||
self.connection = connection
|
||||
self.policy = policy
|
||||
self.binding = binding
|
||||
self.result = None
|
||||
|
||||
|
||||
|
@ -37,54 +39,64 @@ class PolicyEngine:
|
|||
"""Orchestrate policy checking, launch tasks and return result"""
|
||||
|
||||
use_cache: bool = True
|
||||
policies: List[Policy] = []
|
||||
request: PolicyRequest
|
||||
|
||||
__pbm: PolicyBindingModel
|
||||
__cached_policies: List[PolicyResult]
|
||||
__processes: List[PolicyProcessInfo]
|
||||
|
||||
def __init__(self, policies, user: User, request: HttpRequest = None):
|
||||
self.policies = policies
|
||||
def __init__(
|
||||
self, pbm: PolicyBindingModel, user: User, request: HttpRequest = None
|
||||
):
|
||||
if not isinstance(pbm, PolicyBindingModel):
|
||||
raise ValueError(f"{pbm} is not instance of PolicyBindingModel")
|
||||
self.__pbm = pbm
|
||||
self.request = PolicyRequest(user)
|
||||
if request:
|
||||
self.request.http_request = request
|
||||
self.__cached_policies = []
|
||||
self.__processes = []
|
||||
|
||||
def _select_subclasses(self) -> List[Policy]:
|
||||
def _iter_bindings(self) -> List[PolicyBinding]:
|
||||
"""Make sure all Policies are their respective classes"""
|
||||
return (
|
||||
Policy.objects.filter(pk__in=[x.pk for x in self.policies])
|
||||
.select_subclasses()
|
||||
.order_by("order")
|
||||
return PolicyBinding.objects.filter(target=self.__pbm, enabled=True).order_by(
|
||||
"order"
|
||||
)
|
||||
|
||||
def _check_policy_type(self, policy: Policy):
|
||||
"""Check policy type, make sure it's not the root class as that has no logic implemented"""
|
||||
# policy_type = type(policy)
|
||||
if policy.__class__ == Policy:
|
||||
raise TypeError(f"Policy '{policy}' is root type")
|
||||
|
||||
def build(self) -> "PolicyEngine":
|
||||
"""Build task group"""
|
||||
for policy in self._select_subclasses():
|
||||
cached_policy = cache.get(cache_key(policy, self.request.user), None)
|
||||
for binding in self._iter_bindings():
|
||||
self._check_policy_type(binding.policy)
|
||||
policy = binding.policy
|
||||
cached_policy = cache.get(cache_key(binding, self.request.user), None)
|
||||
if cached_policy and self.use_cache:
|
||||
LOGGER.debug("P_ENG: Taking result from cache", policy=policy)
|
||||
self.__cached_policies.append(cached_policy)
|
||||
continue
|
||||
LOGGER.debug("P_ENG: Evaluating policy", policy=policy)
|
||||
our_end, task_end = Pipe(False)
|
||||
task = PolicyProcess(policy, self.request, task_end)
|
||||
task = PolicyProcess(binding, self.request, task_end)
|
||||
LOGGER.debug("P_ENG: Starting Process", policy=policy)
|
||||
task.start()
|
||||
self.__processes.append(
|
||||
PolicyProcessInfo(process=task, connection=our_end, policy=policy)
|
||||
PolicyProcessInfo(process=task, connection=our_end, binding=binding)
|
||||
)
|
||||
# If all policies are cached, we have an empty list here.
|
||||
for proc_info in self.__processes:
|
||||
proc_info.process.join(proc_info.policy.timeout)
|
||||
proc_info.process.join(proc_info.binding.timeout)
|
||||
# Only call .recv() if no result is saved, otherwise we just deadlock here
|
||||
if not proc_info.result:
|
||||
proc_info.result = proc_info.connection.recv()
|
||||
return self
|
||||
|
||||
@property
|
||||
def result(self) -> Tuple[bool, List[str]]:
|
||||
def result(self) -> PolicyResult:
|
||||
"""Get policy-checking result"""
|
||||
messages: List[str] = []
|
||||
process_results: List[PolicyResult] = [
|
||||
|
@ -95,10 +107,10 @@ class PolicyEngine:
|
|||
if result.messages:
|
||||
messages += result.messages
|
||||
if not result.passing:
|
||||
return False, messages
|
||||
return True, messages
|
||||
return PolicyResult(False, *messages)
|
||||
return PolicyResult(True, *messages)
|
||||
|
||||
@property
|
||||
def passing(self) -> bool:
|
||||
"""Only get true/false if user passes"""
|
||||
return self.result[0]
|
||||
return self.result.passing
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional
|
|||
|
||||
from django.core.exceptions import ValidationError
|
||||
from jinja2 import Undefined
|
||||
from jinja2.exceptions import TemplateSyntaxError, UndefinedError
|
||||
from jinja2.exceptions import TemplateSyntaxError
|
||||
from jinja2.nativetypes import NativeEnvironment
|
||||
from requests import Session
|
||||
from structlog import get_logger
|
||||
|
@ -90,7 +90,8 @@ class Evaluator:
|
|||
if result:
|
||||
return PolicyResult(bool(result))
|
||||
return PolicyResult(False)
|
||||
except UndefinedError as exc:
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
LOGGER.warning("Expression error", exc=exc)
|
||||
return PolicyResult(False, str(exc))
|
||||
|
||||
def validate(self, expression: str):
|
||||
|
|
|
@ -3,8 +3,8 @@ from django import forms
|
|||
|
||||
from passbook.policies.models import PolicyBinding, PolicyBindingModel
|
||||
|
||||
GENERAL_FIELDS = ["name", "negate", "order", "timeout"]
|
||||
GENERAL_SERIALIZER_FIELDS = ["pk", "name", "negate", "order", "timeout"]
|
||||
GENERAL_FIELDS = ["name"]
|
||||
GENERAL_SERIALIZER_FIELDS = ["pk", "name"]
|
||||
|
||||
|
||||
class PolicyBindingForm(forms.ModelForm):
|
||||
|
@ -18,9 +18,4 @@ class PolicyBindingForm(forms.ModelForm):
|
|||
class Meta:
|
||||
|
||||
model = PolicyBinding
|
||||
fields = [
|
||||
"enabled",
|
||||
"policy",
|
||||
"target",
|
||||
"order",
|
||||
]
|
||||
fields = ["enabled", "policy", "target", "order", "timeout"]
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
# Generated by Django 3.0.6 on 2020-05-28 16:47
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import passbook.lib.models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("passbook_policies", "0001_initial"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name="policy",
|
||||
options={
|
||||
"base_manager_name": "objects",
|
||||
"verbose_name": "Policy",
|
||||
"verbose_name_plural": "Policies",
|
||||
},
|
||||
),
|
||||
migrations.RemoveField(model_name="policy", name="negate",),
|
||||
migrations.RemoveField(model_name="policy", name="order",),
|
||||
migrations.RemoveField(model_name="policy", name="timeout",),
|
||||
migrations.AddField(
|
||||
model_name="policybinding",
|
||||
name="negate",
|
||||
field=models.BooleanField(
|
||||
default=False,
|
||||
help_text="Negates the outcome of the policy. Messages are unaffected.",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="policybinding",
|
||||
name="timeout",
|
||||
field=models.IntegerField(
|
||||
default=30,
|
||||
help_text="Timeout after which Policy execution is terminated.",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="policybinding", name="order", field=models.IntegerField(),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="policybinding",
|
||||
name="policy",
|
||||
field=passbook.lib.models.InheritanceForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="+",
|
||||
to="passbook_policies.Policy",
|
||||
),
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name="policybinding", unique_together={("policy", "target", "order")},
|
||||
),
|
||||
]
|
|
@ -5,7 +5,11 @@ from django.db import models
|
|||
from django.utils.translation import gettext_lazy as _
|
||||
from model_utils.managers import InheritanceManager
|
||||
|
||||
from passbook.lib.models import CreatedUpdatedModel
|
||||
from passbook.lib.models import (
|
||||
CreatedUpdatedModel,
|
||||
InheritanceAutoManager,
|
||||
InheritanceForeignKey,
|
||||
)
|
||||
from passbook.policies.exceptions import PolicyException
|
||||
from passbook.policies.types import PolicyRequest, PolicyResult
|
||||
|
||||
|
@ -22,7 +26,6 @@ class PolicyBindingModel(models.Model):
|
|||
objects = InheritanceManager()
|
||||
|
||||
class Meta:
|
||||
|
||||
verbose_name = _("Policy Binding Model")
|
||||
verbose_name_plural = _("Policy Binding Models")
|
||||
|
||||
|
@ -36,13 +39,19 @@ class PolicyBinding(models.Model):
|
|||
|
||||
enabled = models.BooleanField(default=True)
|
||||
|
||||
policy = models.ForeignKey("Policy", on_delete=models.CASCADE, related_name="+")
|
||||
policy = InheritanceForeignKey("Policy", on_delete=models.CASCADE, related_name="+")
|
||||
target = models.ForeignKey(
|
||||
PolicyBindingModel, on_delete=models.CASCADE, related_name="+"
|
||||
)
|
||||
negate = models.BooleanField(
|
||||
default=False,
|
||||
help_text=_("Negates the outcome of the policy. Messages are unaffected."),
|
||||
)
|
||||
timeout = models.IntegerField(
|
||||
default=30, help_text=_("Timeout after which Policy execution is terminated.")
|
||||
)
|
||||
|
||||
# default value and non-unique for compatibility
|
||||
order = models.IntegerField(default=0)
|
||||
order = models.IntegerField()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"PolicyBinding policy={self.policy} target={self.target} order={self.order}"
|
||||
|
@ -51,6 +60,7 @@ class PolicyBinding(models.Model):
|
|||
|
||||
verbose_name = _("Policy Binding")
|
||||
verbose_name_plural = _("Policy Bindings")
|
||||
unique_together = ("policy", "target", "order")
|
||||
|
||||
|
||||
class Policy(CreatedUpdatedModel):
|
||||
|
@ -60,11 +70,8 @@ class Policy(CreatedUpdatedModel):
|
|||
policy_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||
|
||||
name = models.TextField(blank=True, null=True)
|
||||
negate = models.BooleanField(default=False)
|
||||
order = models.IntegerField(default=0)
|
||||
timeout = models.IntegerField(default=30)
|
||||
|
||||
objects = InheritanceManager()
|
||||
objects = InheritanceAutoManager()
|
||||
|
||||
def __str__(self):
|
||||
return f"Policy {self.name}"
|
||||
|
@ -72,3 +79,9 @@ class Policy(CreatedUpdatedModel):
|
|||
def passes(self, request: PolicyRequest) -> PolicyResult:
|
||||
"""Check if user instance passes this policy"""
|
||||
raise PolicyException()
|
||||
|
||||
class Meta:
|
||||
base_manager_name = "objects"
|
||||
|
||||
verbose_name = _("Policy")
|
||||
verbose_name_plural = _("Policies")
|
||||
|
|
|
@ -8,15 +8,15 @@ from structlog import get_logger
|
|||
|
||||
from passbook.core.models import User
|
||||
from passbook.policies.exceptions import PolicyException
|
||||
from passbook.policies.models import Policy
|
||||
from passbook.policies.models import PolicyBinding
|
||||
from passbook.policies.types import PolicyRequest, PolicyResult
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def cache_key(policy: Policy, user: Optional[User] = None) -> str:
|
||||
def cache_key(binding: PolicyBinding, user: Optional[User] = None) -> str:
|
||||
"""Generate Cache key for policy"""
|
||||
prefix = f"policy_{policy.pk}"
|
||||
prefix = f"policy_{binding.policy_binding_uuid.hex}_{binding.policy.pk.hex}"
|
||||
if user:
|
||||
prefix += f"#{user.pk}"
|
||||
return prefix
|
||||
|
@ -26,40 +26,50 @@ class PolicyProcess(Process):
|
|||
"""Evaluate a single policy within a seprate process"""
|
||||
|
||||
connection: Connection
|
||||
policy: Policy
|
||||
binding: PolicyBinding
|
||||
request: PolicyRequest
|
||||
|
||||
def __init__(self, policy: Policy, request: PolicyRequest, connection: Connection):
|
||||
def __init__(
|
||||
self,
|
||||
binding: PolicyBinding,
|
||||
request: PolicyRequest,
|
||||
connection: Optional[Connection],
|
||||
):
|
||||
super().__init__()
|
||||
self.policy = policy
|
||||
self.binding = binding
|
||||
self.request = request
|
||||
self.connection = connection
|
||||
if connection:
|
||||
self.connection = connection
|
||||
|
||||
def run(self):
|
||||
"""Task wrapper to run policy checking"""
|
||||
def execute(self) -> PolicyResult:
|
||||
"""Run actual policy, returns result"""
|
||||
LOGGER.debug(
|
||||
"P_ENG(proc): Running policy",
|
||||
policy=self.policy,
|
||||
policy=self.binding.policy,
|
||||
user=self.request.user,
|
||||
process="PolicyProcess",
|
||||
)
|
||||
try:
|
||||
policy_result = self.policy.passes(self.request)
|
||||
policy_result = self.binding.policy.passes(self.request)
|
||||
except PolicyException as exc:
|
||||
LOGGER.debug("P_ENG(proc): error", exc=exc)
|
||||
policy_result = PolicyResult(False, str(exc))
|
||||
# Invert result if policy.negate is set
|
||||
if self.policy.negate:
|
||||
if self.binding.negate:
|
||||
policy_result.passing = not policy_result.passing
|
||||
LOGGER.debug(
|
||||
"P_ENG(proc): Finished",
|
||||
policy=self.policy,
|
||||
policy=self.binding.policy,
|
||||
result=policy_result,
|
||||
process="PolicyProcess",
|
||||
passing=policy_result.passing,
|
||||
user=self.request.user,
|
||||
)
|
||||
key = cache_key(self.policy, self.request.user)
|
||||
key = cache_key(self.binding, self.request.user)
|
||||
cache.set(key, policy_result)
|
||||
LOGGER.debug("P_ENG(proc): Cached policy evaluation", key=key)
|
||||
self.connection.send(policy_result)
|
||||
return policy_result
|
||||
|
||||
def run(self):
|
||||
"""Task wrapper to run policy checking"""
|
||||
self.connection.send(self.execute())
|
||||
|
|
|
@ -5,7 +5,8 @@ from django.test import TestCase
|
|||
from passbook.core.models import User
|
||||
from passbook.policies.dummy.models import DummyPolicy
|
||||
from passbook.policies.engine import PolicyEngine
|
||||
from passbook.policies.models import Policy
|
||||
from passbook.policies.expression.models import ExpressionPolicy
|
||||
from passbook.policies.models import Policy, PolicyBinding, PolicyBindingModel
|
||||
|
||||
|
||||
class PolicyTestEngine(TestCase):
|
||||
|
@ -20,40 +21,64 @@ class PolicyTestEngine(TestCase):
|
|||
self.policy_true = DummyPolicy.objects.create(
|
||||
result=True, wait_min=0, wait_max=1
|
||||
)
|
||||
self.policy_negate = DummyPolicy.objects.create(
|
||||
negate=True, result=True, wait_min=0, wait_max=1
|
||||
self.policy_wrong_type = Policy.objects.create(name="wrong_type")
|
||||
self.policy_raises = ExpressionPolicy.objects.create(
|
||||
name="raises", expression="{{ 0/0 }}"
|
||||
)
|
||||
self.policy_raises = Policy.objects.create(name="raises")
|
||||
|
||||
def test_engine_empty(self):
|
||||
"""Ensure empty policy list passes"""
|
||||
engine = PolicyEngine([], self.user)
|
||||
self.assertEqual(engine.build().passing, True)
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
result = engine.build().result
|
||||
self.assertEqual(result.passing, True)
|
||||
self.assertEqual(result.messages, ())
|
||||
|
||||
def test_engine(self):
|
||||
"""Ensure all policies passes (Mix of false and true -> false)"""
|
||||
engine = PolicyEngine(
|
||||
DummyPolicy.objects.filter(negate__exact=False), self.user
|
||||
)
|
||||
self.assertEqual(engine.build().passing, False)
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0)
|
||||
PolicyBinding.objects.create(target=pbm, policy=self.policy_true, order=1)
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
result = engine.build().result
|
||||
self.assertEqual(result.passing, False)
|
||||
self.assertEqual(result.messages, ("dummy",))
|
||||
|
||||
def test_engine_negate(self):
|
||||
"""Test negate flag"""
|
||||
engine = PolicyEngine(DummyPolicy.objects.filter(negate__exact=True), self.user)
|
||||
self.assertEqual(engine.build().passing, False)
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
PolicyBinding.objects.create(
|
||||
target=pbm, policy=self.policy_true, negate=True, order=0
|
||||
)
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
result = engine.build().result
|
||||
self.assertEqual(result.passing, False)
|
||||
self.assertEqual(result.messages, ("dummy",))
|
||||
|
||||
def test_engine_policy_error(self):
|
||||
"""Test negate flag"""
|
||||
engine = PolicyEngine(Policy.objects.filter(name="raises"), self.user)
|
||||
self.assertEqual(engine.build().passing, False)
|
||||
"""Test policy raising an error flag"""
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
PolicyBinding.objects.create(target=pbm, policy=self.policy_raises, order=0)
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
result = engine.build().result
|
||||
self.assertEqual(result.passing, False)
|
||||
self.assertEqual(result.messages, ("division by zero",))
|
||||
|
||||
def test_engine_policy_type(self):
|
||||
"""Test invalid policy type"""
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
PolicyBinding.objects.create(target=pbm, policy=self.policy_wrong_type, order=0)
|
||||
with self.assertRaises(TypeError):
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
engine.build()
|
||||
|
||||
def test_engine_cache(self):
|
||||
"""Ensure empty policy list passes"""
|
||||
engine = PolicyEngine(
|
||||
DummyPolicy.objects.filter(negate__exact=False), self.user
|
||||
)
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0)
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
self.assertEqual(len(cache.keys("policy_*")), 0)
|
||||
self.assertEqual(engine.build().passing, False)
|
||||
self.assertEqual(len(cache.keys("policy_*")), 2)
|
||||
self.assertEqual(len(cache.keys("policy_*")), 1)
|
||||
self.assertEqual(engine.build().passing, False)
|
||||
self.assertEqual(len(cache.keys("policy_*")), 2)
|
||||
self.assertEqual(len(cache.keys("policy_*")), 1)
|
||||
|
|
|
@ -50,9 +50,9 @@ class PassbookAuthorizationView(AccessMixin, AuthorizationView):
|
|||
provider.save()
|
||||
self._application = application
|
||||
# Check permissions
|
||||
passing, policy_messages = self.user_has_access(self._application, request.user)
|
||||
if not passing:
|
||||
for policy_message in policy_messages:
|
||||
result = self.user_has_access(self._application, request.user)
|
||||
if not result.passing:
|
||||
for policy_message in result.messages:
|
||||
messages.error(request, policy_message)
|
||||
return redirect("passbook_providers_oauth:oauth2-permission-denied")
|
||||
# Some clients don't pass response_type, so we default to code
|
||||
|
|
|
@ -18,7 +18,7 @@ LOGGER = get_logger()
|
|||
def client_related_provider(client: Client) -> Optional[Provider]:
|
||||
"""Lookup related Application from Client"""
|
||||
# because oidc_provider is also used by app_gw, we can't be
|
||||
# sure an OpenIDPRovider instance exists. hence we look through all related models
|
||||
# sure an OpenIDProvider instance exists. hence we look through all related models
|
||||
# and choose the one that inherits from Provider, which is guaranteed to
|
||||
# have the application property
|
||||
collector = Collector(using="default")
|
||||
|
@ -50,9 +50,9 @@ def check_permissions(
|
|||
policy_engine.build()
|
||||
|
||||
# Check permissions
|
||||
passing, policy_messages = policy_engine.result
|
||||
if not passing:
|
||||
for policy_message in policy_messages:
|
||||
result = policy_engine.result
|
||||
if not result.passing:
|
||||
for policy_message in result.messages:
|
||||
messages.error(request, policy_message)
|
||||
return redirect("passbook_providers_oauth:oauth2-permission-denied")
|
||||
|
||||
|
|
|
@ -55,9 +55,9 @@ class PromptForm(forms.Form):
|
|||
def clean(self):
|
||||
cleaned_data = super().clean()
|
||||
user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user())
|
||||
engine = PolicyEngine(self.stage.policies.all(), user)
|
||||
engine = PolicyEngine(self.stage, user)
|
||||
engine.request.context = cleaned_data
|
||||
engine.build()
|
||||
passing, messages = engine.result
|
||||
if not passing:
|
||||
raise forms.ValidationError(messages)
|
||||
result = engine.result
|
||||
if not result.passing:
|
||||
raise forms.ValidationError(result.messages)
|
||||
|
|
|
@ -139,7 +139,7 @@ class TestPromptStage(TestCase):
|
|||
expr_policy = ExpressionPolicy.objects.create(
|
||||
name="validate-form", expression=expr
|
||||
)
|
||||
PolicyBinding.objects.create(policy=expr_policy, target=self.stage)
|
||||
PolicyBinding.objects.create(policy=expr_policy, target=self.stage, order=0)
|
||||
form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data)
|
||||
self.assertEqual(form.is_valid(), True)
|
||||
return form
|
||||
|
@ -151,7 +151,7 @@ class TestPromptStage(TestCase):
|
|||
expr_policy = ExpressionPolicy.objects.create(
|
||||
name="validate-form", expression=expr
|
||||
)
|
||||
PolicyBinding.objects.create(policy=expr_policy, target=self.stage)
|
||||
PolicyBinding.objects.create(policy=expr_policy, target=self.stage, order=0)
|
||||
form = PromptForm(stage=self.stage, plan=plan, data=self.prompt_data)
|
||||
self.assertEqual(form.is_valid(), False)
|
||||
return form
|
||||
|
|
100
swagger.yaml
100
swagger.yaml
|
@ -837,7 +837,7 @@ paths:
|
|||
parameters:
|
||||
- name: policy_uuid
|
||||
in: path
|
||||
description: A UUID string identifying this policy.
|
||||
description: A UUID string identifying this Policy.
|
||||
required: true
|
||||
type: string
|
||||
format: uuid
|
||||
|
@ -5079,19 +5079,6 @@ definitions:
|
|||
title: Name
|
||||
type: string
|
||||
x-nullable: true
|
||||
negate:
|
||||
title: Negate
|
||||
type: boolean
|
||||
order:
|
||||
title: Order
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
minimum: -2147483648
|
||||
timeout:
|
||||
title: Timeout
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
minimum: -2147483648
|
||||
__type__:
|
||||
title: 'type '
|
||||
type: string
|
||||
|
@ -5100,6 +5087,7 @@ definitions:
|
|||
required:
|
||||
- policy
|
||||
- target
|
||||
- order
|
||||
type: object
|
||||
properties:
|
||||
policy:
|
||||
|
@ -5118,6 +5106,12 @@ definitions:
|
|||
type: integer
|
||||
maximum: 2147483647
|
||||
minimum: -2147483648
|
||||
timeout:
|
||||
title: Timeout
|
||||
description: Timeout after which Policy execution is terminated.
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
minimum: -2147483648
|
||||
DummyPolicy:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -5130,19 +5124,6 @@ definitions:
|
|||
title: Name
|
||||
type: string
|
||||
x-nullable: true
|
||||
negate:
|
||||
title: Negate
|
||||
type: boolean
|
||||
order:
|
||||
title: Order
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
minimum: -2147483648
|
||||
timeout:
|
||||
title: Timeout
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
minimum: -2147483648
|
||||
result:
|
||||
title: Result
|
||||
type: boolean
|
||||
|
@ -5170,19 +5151,6 @@ definitions:
|
|||
title: Name
|
||||
type: string
|
||||
x-nullable: true
|
||||
negate:
|
||||
title: Negate
|
||||
type: boolean
|
||||
order:
|
||||
title: Order
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
minimum: -2147483648
|
||||
timeout:
|
||||
title: Timeout
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
minimum: -2147483648
|
||||
expression:
|
||||
title: Expression
|
||||
type: string
|
||||
|
@ -5199,19 +5167,6 @@ definitions:
|
|||
title: Name
|
||||
type: string
|
||||
x-nullable: true
|
||||
negate:
|
||||
title: Negate
|
||||
type: boolean
|
||||
order:
|
||||
title: Order
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
minimum: -2147483648
|
||||
timeout:
|
||||
title: Timeout
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
minimum: -2147483648
|
||||
allowed_count:
|
||||
title: Allowed count
|
||||
type: integer
|
||||
|
@ -5231,19 +5186,6 @@ definitions:
|
|||
title: Name
|
||||
type: string
|
||||
x-nullable: true
|
||||
negate:
|
||||
title: Negate
|
||||
type: boolean
|
||||
order:
|
||||
title: Order
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
minimum: -2147483648
|
||||
timeout:
|
||||
title: Timeout
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
minimum: -2147483648
|
||||
amount_uppercase:
|
||||
title: Amount uppercase
|
||||
type: integer
|
||||
|
@ -5286,19 +5228,6 @@ definitions:
|
|||
title: Name
|
||||
type: string
|
||||
x-nullable: true
|
||||
negate:
|
||||
title: Negate
|
||||
type: boolean
|
||||
order:
|
||||
title: Order
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
minimum: -2147483648
|
||||
timeout:
|
||||
title: Timeout
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
minimum: -2147483648
|
||||
days:
|
||||
title: Days
|
||||
type: integer
|
||||
|
@ -5319,19 +5248,6 @@ definitions:
|
|||
title: Name
|
||||
type: string
|
||||
x-nullable: true
|
||||
negate:
|
||||
title: Negate
|
||||
type: boolean
|
||||
order:
|
||||
title: Order
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
minimum: -2147483648
|
||||
timeout:
|
||||
title: Timeout
|
||||
type: integer
|
||||
maximum: 2147483647
|
||||
minimum: -2147483648
|
||||
check_ip:
|
||||
title: Check ip
|
||||
type: boolean
|
||||
|
|
Reference in New Issue