diff --git a/passbook/core/channels.py b/passbook/core/channels.py new file mode 100644 index 000000000..1f1d732f4 --- /dev/null +++ b/passbook/core/channels.py @@ -0,0 +1,37 @@ +"""Channels base classes""" +from channels.generic.websocket import JsonWebsocketConsumer +from django.core.exceptions import ValidationError +from structlog import get_logger + +from passbook.core.models import Token, TokenIntents, User + +LOGGER = get_logger() + + +class AuthJsonConsumer(JsonWebsocketConsumer): + """Authorize a client with a token""" + + user: User + + def connect(self): + headers = dict(self.scope["headers"]) + if b"authorization" not in headers: + LOGGER.warning("WS Request without authorization header") + self.close() + + token = headers[b"authorization"] + try: + token_uuid = token.decode("utf-8") + tokens = Token.filter_not_expired( + token_uuid=token_uuid, intent=TokenIntents.INTENT_API + ) + if not tokens.exists(): + LOGGER.warning("WS Request with invalid token") + self.close() + return False + except ValidationError: + LOGGER.warning("WS Invalid UUID") + self.close() + return False + self.user = tokens.first().user + return True diff --git a/passbook/outposts/channels.py b/passbook/outposts/channels.py index 8d92bcbbf..9c4b83c4c 100644 --- a/passbook/outposts/channels.py +++ b/passbook/outposts/channels.py @@ -4,14 +4,13 @@ from enum import IntEnum from time import time from typing import Any, Dict -from channels.generic.websocket import JsonWebsocketConsumer from dacite import from_dict from dacite.data import Data from django.core.cache import cache -from django.core.exceptions import ValidationError +from guardian.shortcuts import get_objects_for_user from structlog import get_logger -from passbook.core.models import Token, TokenIntents +from passbook.core.channels import AuthJsonConsumer from passbook.outposts.models import Outpost LOGGER = get_logger() @@ -38,33 +37,18 @@ class WebsocketMessage: args: Dict[str, Any] = field(default_factory=dict) -class OutpostConsumer(JsonWebsocketConsumer): +class OutpostConsumer(AuthJsonConsumer): """Handler for Outposts that connect over websockets for health checks and live updates""" outpost: Outpost def connect(self): - # TODO: This authentication block could be handeled in middleware - headers = dict(self.scope["headers"]) - if b"authorization" not in headers: - LOGGER.warning("WS Request without authorization header") - self.close() - - token = headers[b"authorization"] - try: - token_uuid = token.decode("utf-8") - tokens = Token.filter_not_expired( - token_uuid=token_uuid, intent=TokenIntents.INTENT_API - ) - if not tokens.exists(): - LOGGER.warning("WS Request with invalid token") - self.close() - except ValidationError: - LOGGER.warning("WS Invalid UUID") - self.close() - + if not super().connect(): + return uuid = self.scope["url_route"]["kwargs"]["pk"] - outpost = Outpost.objects.filter(pk=uuid) + outpost = get_objects_for_user( + self.user, "passbook_outposts.view_outpost" + ).filter(pk=uuid) if not outpost.exists(): self.close() return