"""passbook core policy engine"""
import cProfile
from functools import wraps
from logging import getLogger

from amqp.exceptions import UnexpectedFrame
from celery import group
from celery.exceptions import TimeoutError as CeleryTimeoutError
from django.core.cache import cache
from ipware import get_client_ip

from passbook.core.celery import CELERY_APP
from passbook.core.models import Policy, User


def profileit(func):
    """Decorator that profiles every call of `func` with cProfile.

    The collected stats are dumped to ``<func.__name__>.profile`` (a path
    relative to the current working directory) after each successful call.
    The wrapped function's return value is passed through unchanged.
    """
    # NOTE(review): this writes a file on every call and concurrent callers
    # overwrite each other's dump -- looks like debugging residue; confirm
    # whether it should stay applied to PolicyEngine.build.
    @wraps(func)  # keep the wrapped function's __name__ and __doc__
    def wrapper(*args, **kwargs):
        datafn = func.__name__ + ".profile"  # Name the data file sensibly
        prof = cProfile.Profile()
        retval = prof.runcall(func, *args, **kwargs)
        prof.dump_stats(datafn)
        return retval

    return wrapper


LOGGER = getLogger(__name__)


def _cache_key(policy, user):
    """Return the cache key for the result of `policy` evaluated for `user`.

    The key is ``<policy.uuid>#<user.pk>``.
    """
    return "%s#%s" % (policy.uuid, user.pk)


@CELERY_APP.task()
def _policy_engine_task(user_pk, policy_pk, **kwargs):
    """Task wrapper to run policy checking

    Loads the policy and the user, sets every item of `kwargs` as an
    attribute on the user object, evaluates ``policy.passes(user)`` and
    stores the outcome in the cache under `_cache_key`.

    Returns:
        Tuple ``(policy action, result, message)``; `message` is None when
        the policy returned only a result.

    Raises:
        ValueError: if `user_pk` is falsy or no policy matches `policy_pk`.
    """
    if not user_pk:
        raise ValueError("User not set.")
    policy_obj = Policy.objects.filter(pk=policy_pk).select_subclasses().first()
    if policy_obj is None:
        # .first() returns None for an empty queryset. Raise ValueError so
        # PolicyEngine.result turns this into a denial instead of an
        # unhandled AttributeError.
        raise ValueError("Policy %s does not exist." % policy_pk)
    user_obj = User.objects.get(pk=user_pk)
    for key, value in kwargs.items():
        setattr(user_obj, key, value)
    LOGGER.debug("Running policy `%s`#%s for user %s...",
                 policy_obj.name, policy_obj.pk.hex, user_obj)
    policy_result = policy_obj.passes(user_obj)
    # Handle policy result correctly if result, message or just result
    message = None
    if isinstance(policy_result, (tuple, list)):
        policy_result, message = policy_result
    # Invert result if policy.negate is set
    if policy_obj.negate:
        policy_result = not policy_result
    LOGGER.debug("Policy %r#%s got %s", policy_obj.name, policy_obj.pk.hex, policy_result)
    cache_key = _cache_key(policy_obj, user_obj)
    cache.set(cache_key, (policy_obj.action, policy_result, message))
    LOGGER.debug("Cached entry as %s", cache_key)
    return policy_obj.action, policy_result, message


class PolicyEngine:
    """Orchestrate policy checking, launch tasks and return result

    Usage is fluent: ``PolicyEngine(policies).for_user(user).build().result``.
    """

    # Class-level defaults; the per-instance values are set in __init__.
    __group = None
    __cached = None

    policies = None
    __get_timeout = 0
    __request = None
    __user = None

    def __init__(self, policies):
        self.policies = policies
        self.__request = None
        self.__user = None
        self.__group = None
        self.__cached = None
        self.__get_timeout = 0

    def for_user(self, user):
        """Check policies for user"""
        self.__user = user
        return self

    def with_request(self, request):
        """Set request"""
        self.__request = request
        return self

    @profileit
    def build(self):
        """Build task group

        Policies with a cached result are collected directly; for every
        other policy a task signature is created and the group of all
        signatures is launched.

        Returns:
            This engine, so calls can be chained.

        Raises:
            ValueError: if no user was set via `for_user`.
        """
        if not self.__user:
            raise ValueError("User not set.")
        # Start from a clean slate so a repeated build() neither accumulates
        # timeouts nor keeps waiting on a previously launched group.
        self.__group = None
        self.__get_timeout = 0
        signatures = []
        cached_policies = []
        kwargs = {
            '__password__': getattr(self.__user, '__password__', None),
        }
        if self.__request:
            kwargs['remote_ip'], _ = get_client_ip(self.__request)
            if not kwargs['remote_ip']:
                # Fallback used when no client IP could be determined
                kwargs['remote_ip'] = '255.255.255.255'
        for policy in self.policies:
            cached_policy = cache.get(_cache_key(policy, self.__user), None)
            if cached_policy:
                LOGGER.warning("Taking result from cache for %s", policy.pk.hex)
                cached_policies.append(cached_policy)
            else:
                LOGGER.warning("Evaluating policy %s", policy.pk.hex)
                signatures.append(_policy_engine_task.signature(
                    args=(self.__user.pk, policy.pk.hex),
                    kwargs=kwargs,
                    time_limit=policy.timeout))
                self.__get_timeout += policy.timeout
        LOGGER.warning("Set total policy timeout to %r", self.__get_timeout)
        # If all policies are cached, we have an empty list here.
        if signatures:
            self.__group = group(signatures)()
            self.__get_timeout += 3
            # Average over the tasks actually launched: the sum above only
            # contains timeouts of uncached policies, so dividing by the
            # number of all policies would shrink the wait below the tasks'
            # own time limits.
            self.__get_timeout = (self.__get_timeout / len(signatures)) * 1.5
        self.__cached = cached_policies
        return self

    @property
    def result(self):
        """Get policy-checking result

        `build` must have been called first; until then the cached-result
        list is still None and cannot be merged into the results.

        Returns:
            Tuple ``(passing, messages)``. Evaluation stops at the first
            policy that does not pass; `messages` holds the messages
            collected up to and including that policy. When waiting for the
            task group raises ValueError, UnexpectedFrame or a Celery
            timeout, the result is ``(False, [str(exception)])``.
        """
        messages = []
        result = []
        try:
            if self.__group:
                # ValueError can be thrown from _policy_engine_task when user is None
                result += self.__group.get(timeout=self.__get_timeout)
            result += self.__cached
        except (ValueError, UnexpectedFrame, CeleryTimeoutError) as exc:
            return False, [str(exc)]
        for policy_action, policy_result, policy_message in result:
            # An ALLOW policy passes when it matched, a DENY policy passes
            # when it did not match.
            passing = (policy_action == Policy.ACTION_ALLOW and policy_result) or \
                      (policy_action == Policy.ACTION_DENY and not policy_result)
            LOGGER.debug('Action=%s, Result=%r => %r', policy_action, policy_result, passing)
            if policy_message:
                messages.append(policy_message)
            if not passing:
                return False, messages
        return True, messages

    @property
    def passing(self):
        """Only get true/false if user passes"""
        return self.result[0]