policies: add unittests for evaluator

This commit is contained in:
Jens Langhammer 2020-02-23 15:54:26 +01:00
parent b99d23c119
commit 2b5fddb7bf
4 changed files with 72 additions and 11 deletions

View file

@ -19,7 +19,7 @@ LOGGER = get_logger()
class Evaluator: class Evaluator:
"""Validate and evaulate jinja2-based expressions""" """Validate and evaluate jinja2-based expressions"""
_env: NativeEnvironment _env: NativeEnvironment
@ -51,14 +51,15 @@ class Evaluator:
"""Return dictionary with additional global variables passed to expression""" """Return dictionary with additional global variables passed to expression"""
# update passbook/policies/expression/templates/policy/expression/form.html # update passbook/policies/expression/templates/policy/expression/form.html
# update docs/policies/expression/index.md # update docs/policies/expression/index.md
kwargs["pb_is_sso_flow"] = request.http_request.session.get(
AuthenticationView.SESSION_IS_SSO_LOGIN, False
)
kwargs["pb_is_group_member"] = Evaluator.jinja2_func_is_group_member kwargs["pb_is_group_member"] = Evaluator.jinja2_func_is_group_member
kwargs["pb_logger"] = get_logger() kwargs["pb_logger"] = get_logger()
kwargs["pb_client_ip"] = ( if request.http_request:
get_client_ip(request.http_request) or "255.255.255.255" kwargs["pb_is_sso_flow"] = request.http_request.session.get(
) AuthenticationView.SESSION_IS_SSO_LOGIN, False
)
kwargs["pb_client_ip"] = (
get_client_ip(request.http_request) or "255.255.255.255"
)
return kwargs return kwargs
def evaluate(self, expression_source: str, request: PolicyRequest) -> PolicyResult: def evaluate(self, expression_source: str, request: PolicyRequest) -> PolicyResult:
@ -81,7 +82,7 @@ class Evaluator:
req=request, req=request,
) )
return PolicyResult(False) return PolicyResult(False)
if isinstance(result, list) and len(result) == 2: if isinstance(result, (list, tuple)) and len(result) == 2:
return PolicyResult(*result) return PolicyResult(*result)
if result: if result:
return PolicyResult(result) return PolicyResult(result)

View file

@ -0,0 +1,58 @@
"""evaluator tests"""
from django.core.exceptions import ValidationError
from django.test import TestCase
from guardian.shortcuts import get_anonymous_user
from passbook.policies.expression.evaluator import Evaluator
from passbook.policies.types import PolicyRequest
class TestEvaluator(TestCase):
"""Evaluator tests"""
def setUp(self):
self.request = PolicyRequest(user=get_anonymous_user())
def test_valid(self):
"""test simple value expression"""
template = "True"
evaluator = Evaluator()
self.assertEqual(evaluator.evaluate(template, self.request).passing, True)
def test_messages(self):
"""test expression with message return"""
template = "False, 'some message'"
evaluator = Evaluator()
result = evaluator.evaluate(template, self.request)
self.assertEqual(result.passing, False)
self.assertEqual(result.messages, ("some message",))
def test_invalid_syntax(self):
"""test invalid syntax"""
template = "{%"
evaluator = Evaluator()
result = evaluator.evaluate(template, self.request)
self.assertEqual(result.passing, False)
self.assertEqual(result.messages, ("tag name expected",))
def test_undefined(self):
"""test undefined result"""
template = "{{ foo.bar }}"
evaluator = Evaluator()
result = evaluator.evaluate(template, self.request)
self.assertEqual(result.passing, False)
self.assertEqual(result.messages, ("'foo' is undefined",))
def test_validate(self):
"""test validate"""
template = "True"
evaluator = Evaluator()
result = evaluator.validate(template)
self.assertEqual(result, True)
def test_validate_invalid(self):
"""test validate"""
template = "{%"
evaluator = Evaluator()
with self.assertRaises(ValidationError):
evaluator.validate(template)

View file

@ -1,7 +1,7 @@
"""policy structures""" """policy structures"""
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Optional, Tuple
from django.db.models import Model from django.db.models import Model
from django.http import HttpRequest from django.http import HttpRequest
@ -14,11 +14,13 @@ class PolicyRequest:
"""Data-class to hold policy request data""" """Data-class to hold policy request data"""
user: User user: User
http_request: HttpRequest http_request: Optional[HttpRequest]
obj: Model obj: Optional[Model]
def __init__(self, user: User): def __init__(self, user: User):
self.user = user self.user = user
self.http_request = None
self.obj = None
def __str__(self): def __str__(self):
return f"<PolicyRequest user={self.user}>" return f"<PolicyRequest user={self.user}>"