core: cleanup channels code, fix error when server side close

This commit is contained in:
Jens Langhammer 2020-12-13 17:46:34 +01:00
parent 3e49acf7ae
commit 3b5e1c7b34
4 changed files with 26 additions and 14 deletions

View File

@ -1,4 +1,5 @@
"""Channels base classes""" """Channels base classes"""
from channels.exceptions import DenyConnection
from channels.generic.websocket import JsonWebsocketConsumer from channels.generic.websocket import JsonWebsocketConsumer
from structlog import get_logger from structlog import get_logger
@ -17,16 +18,13 @@ class AuthJsonConsumer(JsonWebsocketConsumer):
headers = dict(self.scope["headers"]) headers = dict(self.scope["headers"])
if b"authorization" not in headers: if b"authorization" not in headers:
LOGGER.warning("WS Request without authorization header") LOGGER.warning("WS Request without authorization header")
self.close() raise DenyConnection()
return False
raw_header = headers[b"authorization"] raw_header = headers[b"authorization"]
token = token_from_header(raw_header) token = token_from_header(raw_header)
if not token: if not token:
LOGGER.warning("Failed to authenticate") LOGGER.warning("Failed to authenticate")
self.close() raise DenyConnection()
return False
self.user = token.user self.user = token.user
return True

View File

@ -2,8 +2,9 @@
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import datetime from datetime import datetime
from enum import IntEnum from enum import IntEnum
from typing import Any, Dict from typing import Any, Dict, Optional
from channels.exceptions import DenyConnection
from dacite import from_dict from dacite import from_dict
from dacite.data import Data from dacite.data import Data
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user
@ -39,18 +40,16 @@ class WebsocketMessage:
class OutpostConsumer(AuthJsonConsumer): class OutpostConsumer(AuthJsonConsumer):
"""Handler for Outposts that connect over websockets for health checks and live updates""" """Handler for Outposts that connect over websockets for health checks and live updates"""
outpost: Outpost outpost: Optional[Outpost] = None
def connect(self): def connect(self):
if not super().connect(): super().connect()
return
uuid = self.scope["url_route"]["kwargs"]["pk"] uuid = self.scope["url_route"]["kwargs"]["pk"]
outpost = get_objects_for_user( outpost = get_objects_for_user(
self.user, "authentik_outposts.view_outpost" self.user, "authentik_outposts.view_outpost"
).filter(pk=uuid) ).filter(pk=uuid)
if not outpost.exists(): if not outpost.exists():
self.close() raise DenyConnection()
return
self.accept() self.accept()
self.outpost = outpost.first() self.outpost = outpost.first()
OutpostState( OutpostState(
@ -60,6 +59,7 @@ class OutpostConsumer(AuthJsonConsumer):
# pylint: disable=unused-argument # pylint: disable=unused-argument
def disconnect(self, close_code): def disconnect(self, close_code):
if self.outpost:
OutpostState.for_channel(self.outpost, self.channel_name).delete() OutpostState.for_channel(self.outpost, self.channel_name).delete()
LOGGER.debug("removed channel from cache", channel_name=self.channel_name) LOGGER.debug("removed channel from cache", channel_name=self.channel_name)

View File

@ -97,7 +97,12 @@ def outpost_token_ensurer(self: MonitoredTask):
all_outposts = Outpost.objects.all() all_outposts = Outpost.objects.all()
for outpost in all_outposts: for outpost in all_outposts:
_ = outpost.token _ = outpost.token
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, f"Successfully checked {len(all_outposts)} Outposts.")) self.set_status(
TaskResult(
TaskResultStatus.SUCCESSFUL,
[f"Successfully checked {len(all_outposts)} Outposts."],
)
)
@CELERY_APP.task() @CELERY_APP.task()

View File

@ -105,7 +105,16 @@ class ASGILogger:
# https://code.djangoproject.com/ticket/31508 # https://code.djangoproject.com/ticket/31508
# https://github.com/encode/uvicorn/issues/266 # https://github.com/encode/uvicorn/issues/266
return return
try:
await self.app(scope, receive, send_hooked) await self.app(scope, receive, send_hooked)
except TypeError as exc:
# https://github.com/encode/uvicorn/issues/244
if exc.args == (
"An asyncio.Future, a coroutine or an awaitable is required",
):
pass
else:
raise exc
def _get_ip(self) -> str: def _get_ip(self) -> str:
client_ip = None client_ip = None