policies/expression: add ak_call_policy

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-12-09 09:37:41 +01:00
parent 1ed2bddba7
commit 6209714f87
4 changed files with 42 additions and 9 deletions

View File

@ -11,6 +11,8 @@ from authentik.flows.planner import PLAN_CONTEXT_SSO
from authentik.lib.expression.evaluator import BaseEvaluator from authentik.lib.expression.evaluator import BaseEvaluator
from authentik.lib.utils.http import get_client_ip from authentik.lib.utils.http import get_client_ip
from authentik.policies.exceptions import PolicyException from authentik.policies.exceptions import PolicyException
from authentik.policies.models import Policy, PolicyBinding
from authentik.policies.process import PolicyProcess
from authentik.policies.types import PolicyRequest, PolicyResult from authentik.policies.types import PolicyRequest, PolicyResult
LOGGER = get_logger() LOGGER = get_logger()
@ -31,6 +33,7 @@ class PolicyEvaluator(BaseEvaluator):
self._context["ak_logger"] = get_logger(policy_name) self._context["ak_logger"] = get_logger(policy_name)
self._context["ak_message"] = self.expr_func_message self._context["ak_message"] = self.expr_func_message
self._context["ak_user_has_authenticator"] = self.expr_func_user_has_authenticator self._context["ak_user_has_authenticator"] = self.expr_func_user_has_authenticator
self._context["ak_call_policy"] = self.expr_func_call_policy
self._context["ip_address"] = ip_address self._context["ip_address"] = ip_address
self._context["ip_network"] = ip_network self._context["ip_network"] = ip_network
self._filename = policy_name or "PolicyEvaluator" self._filename = policy_name or "PolicyEvaluator"
@ -39,6 +42,16 @@ class PolicyEvaluator(BaseEvaluator):
"""Wrapper to append to messages list, which is returned with PolicyResult""" """Wrapper to append to messages list, which is returned with PolicyResult"""
self._messages.append(message) self._messages.append(message)
def expr_func_call_policy(self, name: str, **kwargs) -> PolicyResult:
"""Call policy by name, with current request"""
policy = Policy.objects.filter(name=name).select_subclasses().first()
if not policy:
raise ValueError(f"Policy '{name}' not found.")
req: PolicyRequest = self._context["request"]
req.context.update(kwargs)
proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None)
return proc.profiling_wrapper()
def expr_func_user_has_authenticator( def expr_func_user_has_authenticator(
self, user: User, device_type: Optional[str] = None self, user: User, device_type: Optional[str] = None
) -> bool: ) -> bool:

View File

@ -127,8 +127,8 @@ class PolicyProcess(PROCESS_CLASS):
) )
return policy_result return policy_result
def run(self): # pragma: no cover def profiling_wrapper(self):
"""Task wrapper to run policy checking""" """Run with profiling enabled"""
with Hub.current.start_span( with Hub.current.start_span(
op="policy.process.execute", op="policy.process.execute",
) as span, HIST_POLICIES_EXECUTION_TIME.labels( ) as span, HIST_POLICIES_EXECUTION_TIME.labels(
@ -142,8 +142,12 @@ class PolicyProcess(PROCESS_CLASS):
span: Span span: Span
span.set_data("policy", self.binding.policy) span.set_data("policy", self.binding.policy)
span.set_data("request", self.request) span.set_data("request", self.request)
return self.execute()
def run(self): # pragma: no cover
"""Task wrapper to run policy checking"""
try: try:
self.connection.send(self.execute()) self.connection.send(self.profiling_wrapper())
except Exception as exc: # pylint: disable=broad-except except Exception as exc: # pylint: disable=broad-except
LOGGER.warning(str(exc)) LOGGER.warning(str(exc))
self.connection.send(PolicyResult(False, str(exc))) self.connection.send(PolicyResult(False, str(exc)))

View File

@ -1,7 +1,6 @@
"""authentik core celery""" """authentik core celery"""
import os import os
from logging.config import dictConfig from logging.config import dictConfig
from uuid import uuid4
from celery import Celery from celery import Celery
from celery.signals import ( from celery.signals import (

View File

@ -25,7 +25,7 @@ ak_message("Access denied")
return False return False
``` ```
### `ak_user_has_authenticator(user: User, device_type: Optional[str] = None)` (2021.9+) ### `ak_user_has_authenticator(user: User, device_type: Optional[str] = None) -> bool` (2021.9+)
Check if a user has any authenticator devices. Only fully validated devices are counted. Check if a user has any authenticator devices. Only fully validated devices are counted.
@ -42,6 +42,23 @@ Example:
return ak_user_has_authenticator(request.user) return ak_user_has_authenticator(request.user)
``` ```
### `ak_call_policy(name: str, **kwargs) -> PolicyResult` (2021.12+)
Call another policy with the name *name*. Current request is passed to policy. Key-word arguments
can be used to modify the request's context.
Example:
```python
result = ak_call_policy("test-policy")
# result is a PolicyResult object, so you can access `.passing` and `.messages`.
return result.passing
result = ak_call_policy("test-policy-2", foo="bar")
# Inside the `test-policy-2` you can then use `request.context["foo"]`
return result.passing
```
import Functions from '../expressions/_functions.md' import Functions from '../expressions/_functions.md'
<Functions /> <Functions />