diff --git a/authentik/outposts/channels.py b/authentik/outposts/consumer.py similarity index 79% rename from authentik/outposts/channels.py rename to authentik/outposts/consumer.py index f0b656a47..e8c2ee127 100644 --- a/authentik/outposts/channels.py +++ b/authentik/outposts/consumer.py @@ -4,6 +4,7 @@ from datetime import datetime from enum import IntEnum from typing import Any, Optional +from asgiref.sync import async_to_sync from channels.exceptions import DenyConnection from dacite.core import from_dict from dacite.data import Data @@ -14,6 +15,8 @@ from authentik.core.channels import AuthJsonConsumer from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState +OUTPOST_GROUP = "group_outpost_%(outpost_pk)s" + class WebsocketMessageInstruction(IntEnum): """Commands which can be triggered over Websocket""" @@ -47,8 +50,6 @@ class OutpostConsumer(AuthJsonConsumer): last_uid: Optional[str] = None - first_msg = False - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.logger = get_logger() @@ -71,22 +72,26 @@ class OutpostConsumer(AuthJsonConsumer): raise DenyConnection() self.outpost = outpost self.last_uid = self.channel_name + async_to_sync(self.channel_layer.group_add)( + OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name + ) + GAUGE_OUTPOSTS_CONNECTED.labels( + outpost=self.outpost.name, + uid=self.last_uid, + expected=self.outpost.config.kubernetes_replicas, + ).inc() def disconnect(self, code): + if self.outpost: + async_to_sync(self.channel_layer.group_discard)( + OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name + ) if self.outpost and self.last_uid: - state = OutpostState.for_instance_uid(self.outpost, self.last_uid) - if self.channel_name in state.channel_ids: - state.channel_ids.remove(self.channel_name) - state.save() GAUGE_OUTPOSTS_CONNECTED.labels( outpost=self.outpost.name, uid=self.last_uid, expected=self.outpost.config.kubernetes_replicas, ).dec() - self.logger.debug( - "removed outpost instance from cache", - instance_uuid=self.last_uid, - ) def receive_json(self, content: Data): msg = from_dict(WebsocketMessage, content) @@ -97,26 +102,13 @@ class OutpostConsumer(AuthJsonConsumer): raise DenyConnection() state = OutpostState.for_instance_uid(self.outpost, uid) - if self.channel_name not in state.channel_ids: - state.channel_ids.append(self.channel_name) state.last_seen = datetime.now() - state.hostname = msg.args.get("hostname", "") - - if not self.first_msg: - GAUGE_OUTPOSTS_CONNECTED.labels( - outpost=self.outpost.name, - uid=self.last_uid, - expected=self.outpost.config.kubernetes_replicas, - ).inc() - self.logger.debug( - "added outpost instance to cache", - instance_uuid=self.last_uid, - ) - self.first_msg = True + state.hostname = msg.args.pop("hostname", "") if msg.instruction == WebsocketMessageInstruction.HELLO: - state.version = msg.args.get("version", None) - state.build_hash = msg.args.get("buildHash", "") + state.version = msg.args.pop("version", None) + state.build_hash = msg.args.pop("buildHash", "") + state.args = msg.args elif msg.instruction == WebsocketMessageInstruction.ACK: return GAUGE_OUTPOSTS_LAST_UPDATE.labels( diff --git a/authentik/outposts/models.py b/authentik/outposts/models.py index 3caae7e73..878a3e9e6 100644 --- a/authentik/outposts/models.py +++ b/authentik/outposts/models.py @@ -411,12 +411,12 @@ class OutpostState: """Outpost instance state, last_seen and version""" uid: str - channel_ids: list[str] = field(default_factory=list) last_seen: Optional[datetime] = field(default=None) version: Optional[str] = field(default=None) version_should: Version = field(default=OUR_VERSION) build_hash: str = field(default="") hostname: str = field(default="") + args: dict = field(default_factory=dict) _outpost: Optional[Outpost] = field(default=None) diff --git a/authentik/outposts/tasks.py b/authentik/outposts/tasks.py index ddb0d5352..b6b3a9bab 100644 --- a/authentik/outposts/tasks.py +++ b/authentik/outposts/tasks.py @@ -25,6 +25,7 @@ from authentik.events.monitored_tasks import ( ) from authentik.lib.config import CONFIG from authentik.lib.utils.reflection import path_to_class +from authentik.outposts.consumer import OUTPOST_GROUP from authentik.outposts.controllers.base import BaseController, ControllerException from authentik.outposts.controllers.docker import DockerClient from authentik.outposts.controllers.kubernetes import KubernetesClient @@ -34,7 +35,6 @@ from authentik.outposts.models import ( Outpost, OutpostModel, OutpostServiceConnection, - OutpostState, OutpostType, ServiceConnectionInvalid, ) @@ -243,10 +243,9 @@ def _outpost_single_update(outpost: Outpost, layer=None): outpost.build_user_permissions(outpost.user) if not layer: # pragma: no cover layer = get_channel_layer() - for state in OutpostState.for_outpost(outpost): - for channel in state.channel_ids: - LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost) - async_to_sync(layer.send)(channel, {"type": "event.update"}) + group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} + LOGGER.debug("sending update", channel=group, outpost=outpost) + async_to_sync(layer.group_send)(group, {"type": "event.update"}) @CELERY_APP.task( diff --git a/authentik/outposts/tests/test_ws.py b/authentik/outposts/tests/test_ws.py index 9d8546044..b8fcba925 100644 --- a/authentik/outposts/tests/test_ws.py +++ b/authentik/outposts/tests/test_ws.py @@ -7,7 +7,7 @@ from django.test import TransactionTestCase from authentik import __version__ from authentik.core.tests.utils import create_test_flow -from authentik.outposts.channels import WebsocketMessage, WebsocketMessageInstruction +from authentik.outposts.consumer import WebsocketMessage, WebsocketMessageInstruction from authentik.outposts.models import Outpost, OutpostType from authentik.providers.proxy.models import ProxyProvider from authentik.root import websocket diff --git a/authentik/outposts/urls.py b/authentik/outposts/urls.py index 353dfd13c..cd7ba3bf8 100644 --- a/authentik/outposts/urls.py +++ b/authentik/outposts/urls.py @@ -7,7 +7,7 @@ from authentik.outposts.api.service_connections import ( KubernetesServiceConnectionViewSet, ServiceConnectionViewSet, ) -from authentik.outposts.channels import OutpostConsumer +from authentik.outposts.consumer import OutpostConsumer from authentik.root.middleware import ChannelsLoggingMiddleware websocket_urlpatterns = [ diff --git a/authentik/providers/proxy/tasks.py b/authentik/providers/proxy/tasks.py index 630b0d186..aec8e669a 100644 --- a/authentik/providers/proxy/tasks.py +++ b/authentik/providers/proxy/tasks.py @@ -3,7 +3,8 @@ from asgiref.sync import async_to_sync from channels.layers import get_channel_layer from django.db import DatabaseError, InternalError, ProgrammingError -from authentik.outposts.models import Outpost, OutpostState, OutpostType +from authentik.outposts.consumer import OUTPOST_GROUP +from authentik.outposts.models import Outpost, OutpostType from authentik.providers.proxy.models import ProxyProvider from authentik.root.celery import CELERY_APP @@ -23,13 +24,12 @@ def proxy_on_logout(session_id: str): """Update outpost instances connected to a single outpost""" layer = get_channel_layer() for outpost in Outpost.objects.filter(type=OutpostType.PROXY): - for state in OutpostState.for_outpost(outpost): - for channel in state.channel_ids: - async_to_sync(layer.send)( - channel, - { - "type": "event.provider.specific", - "sub_type": "logout", - "session_id": session_id, - }, - ) + group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} + async_to_sync(layer.group_send)( + group, + { + "type": "event.provider.specific", + "sub_type": "logout", + "session_id": session_id, + }, + ) diff --git a/authentik/root/settings.py b/authentik/root/settings.py index a7ed583ae..3cb081862 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -253,10 +253,10 @@ ASGI_APPLICATION = "authentik.root.asgi.application" CHANNEL_LAYERS = { "default": { - "BACKEND": "channels_redis.core.RedisChannelLayer", + "BACKEND": "channels_redis.pubsub.RedisPubSubChannelLayer", "CONFIG": { "hosts": [f"{_redis_url}/{CONFIG.get('redis.db')}"], - "prefix": "authentik_channels", + "prefix": "authentik_channels_", }, }, }