diff --git a/docs/expressions/index.md b/docs/expressions/index.md index 9f34050ef..b4dcf80fe 100644 --- a/docs/expressions/index.md +++ b/docs/expressions/index.md @@ -53,3 +53,14 @@ Example: ```python other_user = pb_user_by(username="other_user") ``` + +## Comparing IP Addresses + +To compare IP Addresses or check if an IP Address is within a given subnet, you can use the functions `ip_address('192.0.2.1')` and `ip_network('192.0.2.0/24')`. With these objects you can do [arithmetic operations](https://docs.python.org/3/library/ipaddress.html#operators). + +You can also check if an IP Address is within a subnet by writing the following: + +```python +ip_address('192.0.2.1') in ip_network('192.0.2.0/24') +# evaluates to True +``` diff --git a/docs/policies/expression.md b/docs/policies/expression.md index c057c9e41..e4fbe05aa 100644 --- a/docs/policies/expression.md +++ b/docs/policies/expression.md @@ -26,5 +26,5 @@ return False - `request.obj`: A Django Model instance. This is only set if the policy is ran against an object. - `request.context`: A dictionary with dynamic data. This depends on the origin of the execution. - `pb_is_sso_flow`: Boolean which is true if request was initiated by authenticating through an external provider. -- `pb_client_ip`: Client's IP Address or '255.255.255.255' if no IP Address could be extracted. +- `pb_client_ip`: Client's IP Address or '255.255.255.255' if no IP Address could be extracted. Can be [compared](../expressions/index.md#comparing-ip-addresses) - `pb_flow_plan`: Current Plan if Policy is called from the Flow Planner. diff --git a/e2e/test_provider_oauth.py b/e2e/test_provider_oauth.py index 622ac9cdb..722872377 100644 --- a/e2e/test_provider_oauth.py +++ b/e2e/test_provider_oauth.py @@ -165,6 +165,7 @@ class TestProviderOAuth(SeleniumTestCase): By.XPATH, "/html/body/div[2]/div/main/div/form/div[2]/ul/li[1]" ).text, ) + sleep(1) self.driver.find_element(By.CSS_SELECTOR, "[type=submit]").click() self.wait_for_url("http://localhost:3000/?orgId=1") diff --git a/passbook/lib/expression/evaluator.py b/passbook/lib/expression/evaluator.py index c8595c9c8..df954d92c 100644 --- a/passbook/lib/expression/evaluator.py +++ b/passbook/lib/expression/evaluator.py @@ -64,7 +64,9 @@ class BaseEvaluator: def wrap_expression(self, expression: str, params: Iterable[str]) -> str: """Wrap expression in a function, call it, and save the result as `result`""" handler_signature = ",".join(params) - full_expression = f"def handler({handler_signature}):\n" + full_expression = "" + full_expression += "from ipaddress import ip_address, ip_network\n" + full_expression += f"def handler({handler_signature}):\n" full_expression += indent(expression, " ") full_expression += f"\nresult = handler({handler_signature})" return full_expression diff --git a/passbook/policies/expression/evaluator.py b/passbook/policies/expression/evaluator.py index 968e29834..a8bcf6d06 100644 --- a/passbook/policies/expression/evaluator.py +++ b/passbook/policies/expression/evaluator.py @@ -1,4 +1,5 @@ """passbook expression policy evaluator""" +from ipaddress import ip_address from typing import List from django.http import HttpRequest @@ -41,7 +42,9 @@ class PolicyEvaluator(BaseEvaluator): """Update context based on http request""" # update passbook/policies/expression/templates/policy/expression/form.html # update docs/policies/expression/index.md - self._context["pb_client_ip"] = get_client_ip(request) or "255.255.255.255" + self._context["pb_client_ip"] = ip_address( + get_client_ip(request) or "255.255.255.255" + ) self._context["request"] = request if SESSION_KEY_PLAN in request.session: self._context["pb_flow_plan"] = request.session[SESSION_KEY_PLAN] diff --git a/passbook/policies/expression/tests.py b/passbook/policies/expression/tests.py index 0526d6c9c..54b049813 100644 --- a/passbook/policies/expression/tests.py +++ b/passbook/policies/expression/tests.py @@ -36,7 +36,7 @@ class TestEvaluator(TestCase): evaluator.set_policy_request(self.request) result = evaluator.evaluate(template) self.assertEqual(result.passing, False) - self.assertEqual(result.messages, ("invalid syntax (test, line 2)",)) + self.assertEqual(result.messages, ("invalid syntax (test, line 3)",)) def test_undefined(self): """test undefined result"""