root: use channel send workaround for sync sending of websocket messages

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer 2023-02-15 16:06:17 +01:00
parent 7f009f6d02
commit bff34cc5dc
No known key found for this signature in database
3 changed files with 25 additions and 28 deletions

View File

@ -7,7 +7,6 @@ from urllib.parse import urlparse
import yaml
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.core.cache import cache
from django.db import DatabaseError, InternalError, ProgrammingError
from django.db.models.base import Model
@ -43,6 +42,7 @@ from authentik.providers.ldap.controllers.kubernetes import LDAPKubernetesContro
from authentik.providers.proxy.controllers.docker import ProxyDockerController
from authentik.providers.proxy.controllers.kubernetes import ProxyKubernetesController
from authentik.root.celery import CELERY_APP
from authentik.root.messages.storage import closing_send
LOGGER = get_logger()
CACHE_KEY_OUTPOST_DOWN = "outpost_teardown_%s"
@ -217,26 +217,23 @@ def outpost_post_save(model_class: str, model_pk: Any):
def outpost_send_update(model_instace: Model):
"""Send outpost update to all registered outposts, regardless to which authentik
instance they are connected"""
channel_layer = get_channel_layer()
if isinstance(model_instace, OutpostModel):
for outpost in model_instace.outpost_set.all():
_outpost_single_update(outpost, channel_layer)
_outpost_single_update(outpost)
elif isinstance(model_instace, Outpost):
_outpost_single_update(model_instace, channel_layer)
_outpost_single_update(model_instace)
def _outpost_single_update(outpost: Outpost, layer=None):
def _outpost_single_update(outpost: Outpost):
"""Update outpost instances connected to a single outpost"""
# Ensure token again, because this function is called when anything related to an
# OutpostModel is saved, so we can be sure permissions are right
_ = outpost.token
outpost.build_user_permissions(outpost.user)
if not layer: # pragma: no cover
layer = get_channel_layer()
for state in OutpostState.for_outpost(outpost):
for channel in state.channel_ids:
LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost)
async_to_sync(layer.send)(channel, {"type": "event.update"})
async_to_sync(closing_send)(channel, {"type": "event.update"})
@CELERY_APP.task()

View File

@ -1,6 +1,7 @@
"""Channels Messages storage"""
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from channels import DEFAULT_CHANNEL_LAYER
from channels.layers import channel_layers
from django.contrib.messages.storage.base import Message
from django.contrib.messages.storage.session import SessionStorage
from django.core.cache import cache
@ -10,13 +11,21 @@ SESSION_KEY = "_messages"
CACHE_PREFIX = "goauthentik.io/root/messages_"
async def closing_send(channel, message):
"""Wrapper around layer send that closes the connection"""
# See https://github.com/django/channels_redis/issues/332
# TODO: Remove this after channels_redis 4.1 is released
channel_layer = channel_layers.make_backend(DEFAULT_CHANNEL_LAYER)
await channel_layer.send(channel, message)
await channel_layer.close_pools()
class ChannelsStorage(SessionStorage):
"""Send contrib.messages over websocket"""
def __init__(self, request: HttpRequest) -> None:
# pyright: reportGeneralTypeIssues=false
super().__init__(request)
self.channel = get_channel_layer()
def _store(self, messages: list[Message], response, *args, **kwargs):
prefix = f"{CACHE_PREFIX}{self.request.session.session_key}_messages_"
@ -28,7 +37,7 @@ class ChannelsStorage(SessionStorage):
for key in keys:
uid = key.replace(prefix, "")
for message in messages:
async_to_sync(self.channel.send)(
async_to_sync(closing_send)(
uid,
{
"type": "event.update",

View File

@ -1,8 +1,5 @@
[tool.pyright]
ignore = [
"**/migrations/**",
"**/node_modules/**"
]
ignore = ["**/migrations/**", "**/node_modules/**"]
reportMissingTypeStubs = false
strictParameterNoneValue = true
strictDictionaryInference = true
@ -63,14 +60,7 @@ exclude_lines = [
show_missing = true
[tool.pylint.basic]
good-names = [
"pk",
"id",
"i",
"j",
"k",
"_",
]
good-names = ["pk", "id", "i", "j", "k", "_"]
[tool.pylint.master]
disable = [
@ -85,6 +75,7 @@ disable = [
"protected-access",
"unused-argument",
"raise-missing-from",
"fixme",
# To preserve django's translation function we need to use %-formatting
"consider-using-f-string",
]