diff --git a/passbook/policy/engine.py b/passbook/policy/engine.py index 7c3af7b4c..341518afb 100644 --- a/passbook/policy/engine.py +++ b/passbook/policy/engine.py @@ -1,14 +1,14 @@ """passbook policy engine""" from multiprocessing import Pipe from multiprocessing.connection import Connection -from typing import List, Tuple +from typing import List, Tuple, Dict from django.core.cache import cache from django.http import HttpRequest from structlog import get_logger from passbook.core.models import Policy, User -from passbook.policy.struct import PolicyRequest +from passbook.policy.struct import PolicyRequest, PolicyResult from passbook.policy.process import PolicyProcess LOGGER = get_logger() @@ -16,6 +16,18 @@ LOGGER = get_logger() def _cache_key(policy, user): return f"policy_{policy.pk}#{user.pk}" +class PolicyProcessInfo: + + process: PolicyProcess + connection: Connection + result: PolicyResult = None + policy: Policy + + def __init__(self, process: PolicyProcess, connection: Connection, policy: Policy): + self.process = process + self.connection = connection + self.policy = policy + class PolicyEngine: """Orchestrate policy checking, launch tasks and return result""" @@ -23,12 +35,13 @@ class PolicyEngine: __request: HttpRequest __user: User - __proc_list: List[Tuple[Connection, PolicyProcess]] = [] + __processes: List[PolicyProcessInfo] = [] def __init__(self, policies, user: User = None, request: HttpRequest = None): self.policies = policies self.__request = request self.__user = user + self.__processes = [] def for_user(self, user: User) -> 'PolicyEngine': """Check policies for user""" @@ -57,36 +70,35 @@ class PolicyEngine: for policy in self._select_subclasses(): cached_policy = cache.get(_cache_key(policy, self.__user), None) if cached_policy: - LOGGER.debug("Taking result from cache", policy=policy.pk.hex) + LOGGER.debug("Taking result from cache", policy=policy) cached_policies.append(cached_policy) else: - LOGGER.debug("Evaluating policy", policy=policy.pk.hex) + LOGGER.debug("Evaluating policy", policy=policy) our_end, task_end = Pipe(False) - task = PolicyProcess() - task.ret = task_end - task.request = request - task.policy = policy - LOGGER.debug("Starting Process", class_name=task.__class__.__name__) + task = PolicyProcess(policy, request, task_end) + LOGGER.debug("Starting Process", for_policy=policy) task.start() - self.__proc_list.append((our_end, task)) + self.__processes.append(PolicyProcessInfo(process=task, + connection=our_end, policy=policy)) # If all policies are cached, we have an empty list here. - if self.__proc_list: - for _, running_proc in self.__proc_list: - running_proc.join() + for proc_info in self.__processes: + proc_info.process.join(proc_info.policy.timeout) + # Only call .recv() if no result is saved, otherwise we just deadlock here + if not proc_info.result: + proc_info.result = proc_info.connection.recv() return self @property def result(self) -> Tuple[bool, List[str]]: """Get policy-checking result""" messages: List[str] = [] - for our_end, _ in self.__proc_list: - policy_result = our_end.recv() + for proc_info in self.__processes: # passing = (policy_action == Policy.ACTION_ALLOW and policy_result) or \ # (policy_action == Policy.ACTION_DENY and not policy_result) - LOGGER.debug('Result=%r => %r', policy_result, policy_result.passing) - if policy_result.messages: - messages += policy_result.messages - if not policy_result.passing: + LOGGER.debug("Result", passing=proc_info.result.passing) + if proc_info.result.messages: + messages += proc_info.result.messages + if not proc_info.result.passing: return False, messages return True, messages diff --git a/passbook/policy/process.py b/passbook/policy/process.py index 8b09d0cf3..9427e93f2 100644 --- a/passbook/policy/process.py +++ b/passbook/policy/process.py @@ -17,10 +17,16 @@ def _cache_key(policy, user): class PolicyProcess(Process): """Evaluate a single policy within a seprate process""" - ret: Connection + connection: Connection policy: Policy request: PolicyRequest + def __init__(self, policy: Policy, request: PolicyRequest, connection: Connection): + super().__init__() + self.policy = policy + self.request = request + self.connection = connection + def run(self): """Task wrapper to run policy checking""" LOGGER.debug("Running policy", policy=self.policy, @@ -38,5 +44,4 @@ class PolicyProcess(Process): # cache_key = _cache_key(self.policy, self.request.user) # cache.set(cache_key, (self.policy.action, policy_result, message)) # LOGGER.debug("Cached entry as %s", cache_key) - self.ret.send(policy_result) - self.ret.close() + self.connection.send(policy_result)