"""Outpost websocket handler"""
from dataclasses import asdict, dataclass, field
from enum import IntEnum
from time import time
from typing import Any, Dict

from channels.generic.websocket import JsonWebsocketConsumer
from dacite import from_dict
from dacite.data import Data
from django.core.cache import cache
from django.core.exceptions import ValidationError
from structlog import get_logger

from passbook.core.models import Token, TokenIntents
from passbook.outposts.models import Outpost

# Module-level structlog logger used by OutpostConsumer for connect/disconnect/auth events.
LOGGER = get_logger()


class WebsocketMessageInstruction(IntEnum):
    """Integer codes placed in the ``instruction`` field of a websocket message.

    ACK (0)            -- acknowledgement; may be sent by either side.
    HELLO (1)          -- sent by an outpost to report that it is alive.
    TRIGGER_UPDATE (2) -- sent by this side to ask the outpost to update.
    """

    ACK = 0  # acknowledgement, either direction
    HELLO = 1  # outpost -> here: liveness report
    TRIGGER_UPDATE = 2  # here -> outpost: please update


@dataclass
class WebsocketMessage:
    """One message exchanged over the outpost websocket.

    ``instruction`` holds a ``WebsocketMessageInstruction`` value; ``args``
    carries optional instruction-specific parameters.
    """

    # Numeric instruction code (see WebsocketMessageInstruction).
    instruction: int
    # Extra parameters; a fresh empty dict per instance, never shared.
    args: Dict[str, Any] = field(default_factory=dict)


class OutpostConsumer(JsonWebsocketConsumer):
    """Handler for Outposts that connect over websockets for health checks and live updates"""

    # Set in connect() only after the connection has been authenticated and
    # accepted; absent on rejected connections.
    outpost: Outpost

    def connect(self):
        """Authenticate the client and register this channel on its outpost.

        The connection is rejected (closed without being accepted) when the
        ``authorization`` header is missing or does not hold a valid,
        non-expired API token, or when no outpost matches the ``pk`` URL kwarg.
        """
        # TODO: This authentication block could be handled in middleware
        if not self._authenticate():
            self.close()
            return

        outpost_pk = self.scope["url_route"]["kwargs"]["pk"]
        outpost = Outpost.objects.filter(pk=outpost_pk).first()
        if outpost is None:
            self.close()
            return
        self.accept()
        self.outpost = outpost
        self.outpost.channels.append(self.channel_name)
        LOGGER.debug("added channel to outpost", channel_name=self.channel_name)
        self.outpost.save()

    def _authenticate(self) -> bool:
        """Return True if the request carries a valid, non-expired API token."""
        headers = dict(self.scope["headers"])
        token = headers.get(b"authorization")
        if token is None:
            LOGGER.warning("WS Request without authorization header")
            return False
        try:
            token_uuid = token.decode("utf-8")
            tokens = Token.filter_not_expired(
                token_uuid=token_uuid, intent=TokenIntents.INTENT_API
            )
            if not tokens.exists():
                LOGGER.warning("WS Request with invalid token")
                return False
        # ValidationError: header is not a valid UUID for the lookup;
        # UnicodeDecodeError: header bytes are not UTF-8.
        except (ValidationError, UnicodeDecodeError):
            LOGGER.warning("WS Invalid UUID")
            return False
        return True

    # pylint: disable=unused-argument
    def disconnect(self, close_code):
        """Unregister this channel from its outpost, if it was ever registered.

        Channels calls this for rejected connections too, in which case
        ``self.outpost`` was never set and there is nothing to clean up.
        """
        outpost = getattr(self, "outpost", None)
        if outpost is None:
            return
        if self.channel_name in outpost.channels:
            outpost.channels.remove(self.channel_name)
            outpost.save()
            LOGGER.debug("removed channel from outpost", channel_name=self.channel_name)

    def receive_json(self, content: Data):
        """Handle an incoming JSON message from the outpost.

        A HELLO refreshes the outpost's health cache entry (60s timeout).
        Every message except an ACK is answered with an ACK.
        """
        msg = from_dict(WebsocketMessage, content)
        if msg.instruction == WebsocketMessageInstruction.ACK:
            # Acknowledgements are not themselves acknowledged; this prevents
            # an ACK ping-pong if the peer also ACKs what it receives.
            return
        if msg.instruction == WebsocketMessageInstruction.HELLO:
            cache.set(self.outpost.health_cache_key, time(), timeout=60)

        response = WebsocketMessage(instruction=WebsocketMessageInstruction.ACK)
        self.send_json(asdict(response))

    # pylint: disable=unused-argument
    def event_update(self, event):
        """Tell the outpost to update (sends a TRIGGER_UPDATE message).

        Channels routes channel-layer messages of type ``event.update`` here;
        per the original note, these are dispatched from post_save signals.
        """
        self.send_json(
            asdict(
                WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE)
            )
        )