Merge branch 'main' into multi-tenant-django-tenants

This commit is contained in:
Marc 'risson' Schmitt 2024-01-03 12:22:25 +01:00
commit f35f86442b
No known key found for this signature in database
GPG key ID: 9C3FA22FABF1AA8D
148 changed files with 77795 additions and 49039 deletions

View file

@ -9,3 +9,4 @@ blueprints/local
.git
!gen-ts-api/node_modules
!gen-ts-api/dist/**
!gen-go-api/

View file

@ -2,3 +2,4 @@ keypair
keypairs
hass
warmup
ontext

View file

@ -249,12 +249,6 @@ jobs:
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Comment on PR
if: github.event_name == 'pull_request'
continue-on-error: true
uses: ./.github/actions/comment-pr-instructions
with:
tag: gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}
build-arm64:
needs: ci-core-mark
runs-on: ubuntu-latest
@ -303,3 +297,26 @@ jobs:
platforms: linux/arm64
cache-from: type=gha
cache-to: type=gha,mode=max
pr-comment:
needs:
- build
- build-arm64
runs-on: ubuntu-latest
if: ${{ github.event_name == 'pull_request' }}
permissions:
# Needed to write comments on PRs
pull-requests: write
timeout-minutes: 120
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: prepare variables
uses: ./.github/actions/docker-push-variables
id: ev
env:
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
- name: Comment on PR
uses: ./.github/actions/comment-pr-instructions
with:
tag: gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}

View file

@ -65,6 +65,7 @@ jobs:
- proxy
- ldap
- radius
- rac
runs-on: ubuntu-latest
permissions:
# Needed to upload contianer images to ghcr.io
@ -119,6 +120,7 @@ jobs:
- proxy
- ldap
- radius
- rac
goos: [linux]
goarch: [amd64, arm64]
steps:

View file

@ -65,6 +65,7 @@ jobs:
- proxy
- ldap
- radius
- rac
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5

View file

@ -58,7 +58,7 @@ test: ## Run the server tests and produce a coverage report (locally)
lint-fix: ## Lint and automatically fix errors in the python source code. Reports spelling errors.
isort $(PY_SOURCES)
black $(PY_SOURCES)
ruff $(PY_SOURCES)
ruff --fix $(PY_SOURCES)
codespell -w $(CODESPELL_ARGS)
lint: ## Lint the python and golang sources

View file

@ -42,7 +42,7 @@ class ManagedAppConfig(AppConfig):
meth()
self._logger.debug("Successfully reconciled", name=name)
except (DatabaseError, ProgrammingError, InternalError) as exc:
self._logger.debug("Failed to run reconcile", name=name, exc=exc)
self._logger.warning("Failed to run reconcile", name=name, exc=exc)
def reconcile_tenant(self) -> None:
"""reconcile ourselves for tenanted methods"""

View file

@ -1,22 +1,29 @@
"""Channels base classes"""
from channels.db import database_sync_to_async
from channels.exceptions import DenyConnection
from channels.generic.websocket import JsonWebsocketConsumer
from rest_framework.exceptions import AuthenticationFailed
from structlog.stdlib import get_logger
from authentik.api.authentication import bearer_auth
from authentik.core.models import User
LOGGER = get_logger()
class AuthJsonConsumer(JsonWebsocketConsumer):
class TokenOutpostMiddleware:
"""Authorize a client with a token"""
user: User
def __init__(self, inner):
self.inner = inner
def connect(self):
headers = dict(self.scope["headers"])
async def __call__(self, scope, receive, send):
scope = dict(scope)
await self.auth(scope)
return await self.inner(scope, receive, send)
@database_sync_to_async
def auth(self, scope):
"""Authenticate request from header"""
headers = dict(scope["headers"])
if b"authorization" not in headers:
LOGGER.warning("WS Request without authorization header")
raise DenyConnection()
@ -32,4 +39,4 @@ class AuthJsonConsumer(JsonWebsocketConsumer):
LOGGER.warning("Failed to authenticate", exc=exc)
raise DenyConnection()
self.user = user
scope["user"] = user

View file

@ -22,6 +22,7 @@ class InterfaceView(TemplateView):
kwargs["version_family"] = f"{LOCAL_VERSION.major}.{LOCAL_VERSION.minor}"
kwargs["version_subdomain"] = f"version-{LOCAL_VERSION.major}-{LOCAL_VERSION.minor}"
kwargs["build"] = get_build_hash()
kwargs["url_kwargs"] = self.kwargs
return super().get_context_data(**kwargs)

View file

@ -1,6 +1,8 @@
"""Enterprise license policies"""
from typing import Optional
from django.utils.translation import gettext_lazy as _
from authentik.core.models import User, UserTypes
from authentik.enterprise.models import LicenseKey
from authentik.policies.types import PolicyRequest, PolicyResult
@ -13,10 +15,10 @@ class EnterprisePolicyAccessView(PolicyAccessView):
def check_license(self):
"""Check license"""
if not LicenseKey.get_total().is_valid():
return False
return PolicyResult(False, _("Enterprise required to access this feature."))
if self.request.user.type != UserTypes.INTERNAL:
return False
return True
return PolicyResult(False, _("Feature only accessible for internal users."))
return PolicyResult(True)
def user_has_access(self, user: Optional[User] = None) -> PolicyResult:
user = user or self.request.user
@ -24,7 +26,7 @@ class EnterprisePolicyAccessView(PolicyAccessView):
request.http_request = self.request
result = super().user_has_access(user)
enterprise_result = self.check_license()
if not enterprise_result:
if not enterprise_result.passing:
return enterprise_result
return result

View file

@ -0,0 +1,133 @@
"""RAC Provider API Views"""
from typing import Optional
from django.core.cache import cache
from django.db.models import QuerySet
from django.urls import reverse
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
from rest_framework.fields import SerializerMethodField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import ModelViewSet
from structlog.stdlib import get_logger
from authentik.core.api.used_by import UsedByMixin
from authentik.core.models import Provider
from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer
from authentik.enterprise.providers.rac.models import Endpoint
from authentik.policies.engine import PolicyEngine
from authentik.rbac.filters import ObjectFilter
LOGGER = get_logger()
def user_endpoint_cache_key(user_pk: str) -> str:
"""Cache key where endpoint list for user is saved"""
return f"goauthentik.io/providers/rac/endpoint_access/{user_pk}"
class EndpointSerializer(ModelSerializer):
"""Endpoint Serializer"""
provider_obj = RACProviderSerializer(source="provider", read_only=True)
launch_url = SerializerMethodField()
def get_launch_url(self, endpoint: Endpoint) -> Optional[str]:
"""Build actual launch URL (the provider itself does not have one, just
individual endpoints)"""
try:
# pylint: disable=no-member
return reverse(
"authentik_providers_rac:start",
kwargs={"app": endpoint.provider.application.slug, "endpoint": endpoint.pk},
)
except Provider.application.RelatedObjectDoesNotExist:
return None
class Meta:
model = Endpoint
fields = [
"pk",
"name",
"provider",
"provider_obj",
"protocol",
"host",
"settings",
"property_mappings",
"auth_mode",
"launch_url",
]
class EndpointViewSet(UsedByMixin, ModelViewSet):
"""Endpoint Viewset"""
queryset = Endpoint.objects.all()
serializer_class = EndpointSerializer
filterset_fields = ["name", "provider"]
search_fields = ["name", "protocol"]
ordering = ["name", "protocol"]
def _filter_queryset_for_list(self, queryset: QuerySet) -> QuerySet:
"""Custom filter_queryset method which ignores guardian, but still supports sorting"""
for backend in list(self.filter_backends):
if backend == ObjectFilter:
continue
queryset = backend().filter_queryset(self.request, queryset, self)
return queryset
def _get_allowed_endpoints(self, queryset: QuerySet) -> list[Endpoint]:
endpoints = []
for endpoint in queryset:
engine = PolicyEngine(endpoint, self.request.user, self.request)
engine.build()
if engine.passing:
endpoints.append(endpoint)
return endpoints
@extend_schema(
parameters=[
OpenApiParameter(
"search",
OpenApiTypes.STR,
),
OpenApiParameter(
name="superuser_full_list",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.BOOL,
),
],
responses={
200: EndpointSerializer(many=True),
400: OpenApiResponse(description="Bad request"),
},
)
def list(self, request: Request, *args, **kwargs) -> Response:
"""List accessible endpoints"""
should_cache = request.GET.get("search", "") == ""
superuser_full_list = str(request.GET.get("superuser_full_list", "false")).lower() == "true"
if superuser_full_list and request.user.is_superuser:
return super().list(request)
queryset = self._filter_queryset_for_list(self.get_queryset())
self.paginate_queryset(queryset)
allowed_endpoints = []
if not should_cache:
allowed_endpoints = self._get_allowed_endpoints(queryset)
if should_cache:
allowed_endpoints = cache.get(user_endpoint_cache_key(self.request.user.pk))
if not allowed_endpoints:
LOGGER.debug("Caching allowed endpoint list")
allowed_endpoints = self._get_allowed_endpoints(queryset)
cache.set(
user_endpoint_cache_key(self.request.user.pk),
allowed_endpoints,
timeout=86400,
)
serializer = self.get_serializer(allowed_endpoints, many=True)
return self.get_paginated_response(serializer.data)

View file

@ -0,0 +1,35 @@
"""RAC Provider API Views"""
from rest_framework.fields import CharField
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.propertymappings import PropertyMappingSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import JSONDictField
from authentik.enterprise.providers.rac.models import RACPropertyMapping
class RACPropertyMappingSerializer(PropertyMappingSerializer):
"""RACPropertyMapping Serializer"""
static_settings = JSONDictField()
expression = CharField(allow_blank=True, required=False)
def validate_expression(self, expression: str) -> str:
"""Test Syntax"""
if expression == "":
return expression
return super().validate_expression(expression)
class Meta:
model = RACPropertyMapping
fields = PropertyMappingSerializer.Meta.fields + ["static_settings"]
class RACPropertyMappingViewSet(UsedByMixin, ModelViewSet):
"""RACPropertyMapping Viewset"""
queryset = RACPropertyMapping.objects.all()
serializer_class = RACPropertyMappingSerializer
search_fields = ["name"]
ordering = ["name"]
filterset_fields = ["name", "managed"]

View file

@ -0,0 +1,31 @@
"""RAC Provider API Views"""
from rest_framework.fields import CharField, ListField
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.enterprise.providers.rac.models import RACProvider
class RACProviderSerializer(ProviderSerializer):
"""RACProvider Serializer"""
outpost_set = ListField(child=CharField(), read_only=True, source="outpost_set.all")
class Meta:
model = RACProvider
fields = ProviderSerializer.Meta.fields + ["settings", "outpost_set", "connection_expiry"]
extra_kwargs = ProviderSerializer.Meta.extra_kwargs
class RACProviderViewSet(UsedByMixin, ModelViewSet):
"""RACProvider Viewset"""
queryset = RACProvider.objects.all()
serializer_class = RACProviderSerializer
filterset_fields = {
"application": ["isnull"],
"name": ["iexact"],
}
search_fields = ["name"]
ordering = ["name"]

View file

@ -0,0 +1,17 @@
"""RAC app config"""
from authentik.blueprints.apps import ManagedAppConfig
class AuthentikEnterpriseProviderRAC(ManagedAppConfig):
"""authentik enterprise rac app config"""
name = "authentik.enterprise.providers.rac"
label = "authentik_providers_rac"
verbose_name = "authentik Enterprise.Providers.RAC"
default = True
mountpoint = ""
ws_mountpoint = "authentik.enterprise.providers.rac.urls"
def reconcile_load_rac_signals(self):
"""Load rac signals"""
self.import_module("authentik.enterprise.providers.rac.signals")

View file

@ -0,0 +1,163 @@
"""RAC Client consumer"""
from asgiref.sync import async_to_sync
from channels.db import database_sync_to_async
from channels.exceptions import ChannelFull, DenyConnection
from channels.generic.websocket import AsyncWebsocketConsumer
from django.http.request import QueryDict
from structlog.stdlib import BoundLogger, get_logger
from authentik.enterprise.providers.rac.models import ConnectionToken, RACProvider
from authentik.outposts.consumer import OUTPOST_GROUP_INSTANCE
from authentik.outposts.models import Outpost, OutpostState, OutpostType
# Global broadcast group, which messages are sent to when the outpost connects back
# to authentik for a specific connection
# The `RACClientConsumer` consumer adds itself to this group on connection,
# and removes itself once it has been assigned a specific outpost channel
RAC_CLIENT_GROUP = "group_enterprise_rac_client"
# A group for all connections in a given authentik session ID
# A disconnect message is sent to this group when the session expires/is deleted
RAC_CLIENT_GROUP_SESSION = "group_enterprise_rac_client_%(session)s"
# A group for all connections with a specific token, which in almost all cases
# is just one connection, however this is used to disconnect the connection
# when the token is deleted
RAC_CLIENT_GROUP_TOKEN = "group_enterprise_rac_token_%(token)s" # nosec
# Step 1: Client connects to this websocket endpoint
# Step 2: We prepare all the connection args for Guac
# Step 3: Send a websocket message to a single outpost that has this provider assigned
# (Currently sending to all of them)
# (Should probably do different load balancing algorithms)
# Step 4: Outpost creates a websocket connection back to authentik
# with /ws/outpost_rac/<our_channel_id>/
# Step 5: This consumer transfers data between the two channels
class RACClientConsumer(AsyncWebsocketConsumer):
"""RAC client consumer the browser connects to"""
dest_channel_id: str = ""
provider: RACProvider
token: ConnectionToken
logger: BoundLogger
async def connect(self):
await self.accept("guacamole")
await self.channel_layer.group_add(RAC_CLIENT_GROUP, self.channel_name)
await self.channel_layer.group_add(
RAC_CLIENT_GROUP_SESSION % {"session": self.scope["session"].session_key},
self.channel_name,
)
await self.init_outpost_connection()
async def disconnect(self, code):
self.logger.debug("Disconnecting")
# Tell the outpost we're disconnecting
await self.channel_layer.send(
self.dest_channel_id,
{
"type": "event.disconnect",
},
)
@database_sync_to_async
def init_outpost_connection(self):
"""Initialize guac connection settings"""
self.token = ConnectionToken.filter_not_expired(
token=self.scope["url_route"]["kwargs"]["token"]
).first()
if not self.token:
raise DenyConnection()
self.provider = self.token.provider
params = self.token.get_settings()
self.logger = get_logger().bind(
endpoint=self.token.endpoint.name, user=self.scope["user"].username
)
msg = {
"type": "event.provider.specific",
"sub_type": "init_connection",
"dest_channel_id": self.channel_name,
"params": params,
"protocol": self.token.endpoint.protocol,
}
query = QueryDict(self.scope["query_string"].decode())
for key in ["screen_width", "screen_height", "screen_dpi", "audio"]:
value = query.get(key, None)
if not value:
continue
msg[key] = str(value)
outposts = Outpost.objects.filter(
type=OutpostType.RAC,
providers__in=[self.provider],
)
if not outposts.exists():
self.logger.warning("Provider has no outpost")
raise DenyConnection()
for outpost in outposts:
# Sort all states for the outpost by connection count
states = sorted(
OutpostState.for_outpost(outpost),
key=lambda state: int(state.args.get("active_connections", 0)),
)
if len(states) < 1:
continue
self.logger.debug("Sending out connection broadcast")
async_to_sync(self.channel_layer.group_send)(
OUTPOST_GROUP_INSTANCE % {"outpost_pk": str(outpost.pk), "instance": states[0].uid},
msg,
)
async def receive(self, text_data=None, bytes_data=None):
"""Mirror data received from client to the dest_channel_id
which is the channel talking to guacd"""
if self.dest_channel_id == "":
return
if self.token.is_expired:
await self.event_disconnect({"reason": "token_expiry"})
return
try:
await self.channel_layer.send(
self.dest_channel_id,
{
"type": "event.send",
"text_data": text_data,
"bytes_data": bytes_data,
},
)
except ChannelFull:
pass
async def event_outpost_connected(self, event: dict):
"""Handle event broadcasted from outpost consumer, and check if they
created a connection for us"""
outpost_channel = event.get("outpost_channel")
if event.get("client_channel") != self.channel_name:
return
if self.dest_channel_id != "":
# We've already selected an outpost channel, so tell the other channel to disconnect
# This should never happen since we remove ourselves from the broadcast group
await self.channel_layer.send(
outpost_channel,
{
"type": "event.disconnect",
},
)
return
self.logger.debug("Connected to a single outpost instance")
self.dest_channel_id = outpost_channel
# Since we have a specific outpost channel now, we can remove
# ourselves from the global broadcast group
await self.channel_layer.group_discard(RAC_CLIENT_GROUP, self.channel_name)
async def event_send(self, event: dict):
"""Handler called by outpost websocket that sends data to this specific
client connection"""
if self.token.is_expired:
await self.event_disconnect({"reason": "token_expiry"})
return
await self.send(text_data=event.get("text_data"), bytes_data=event.get("bytes_data"))
async def event_disconnect(self, event: dict):
"""Disconnect when the session ends"""
self.logger.info("Disconnecting RAC connection", reason=event.get("reason"))
await self.close()

View file

@ -0,0 +1,48 @@
"""RAC consumer"""
from channels.exceptions import ChannelFull
from channels.generic.websocket import AsyncWebsocketConsumer
from authentik.enterprise.providers.rac.consumer_client import RAC_CLIENT_GROUP
class RACOutpostConsumer(AsyncWebsocketConsumer):
"""Consumer the outpost connects to, to send specific data back to a client connection"""
dest_channel_id: str
async def connect(self):
self.dest_channel_id = self.scope["url_route"]["kwargs"]["channel"]
await self.accept()
await self.channel_layer.group_send(
RAC_CLIENT_GROUP,
{
"type": "event.outpost.connected",
"outpost_channel": self.channel_name,
"client_channel": self.dest_channel_id,
},
)
async def receive(self, text_data=None, bytes_data=None):
"""Mirror data received from guacd running in the outpost
to the dest_channel_id which is the channel talking to the browser"""
try:
await self.channel_layer.send(
self.dest_channel_id,
{
"type": "event.send",
"text_data": text_data,
"bytes_data": bytes_data,
},
)
except ChannelFull:
pass
async def event_send(self, event: dict):
"""Handler called by client websocket that sends data to this specific
outpost connection"""
await self.send(text_data=event.get("text_data"), bytes_data=event.get("bytes_data"))
async def event_disconnect(self, event: dict):
"""Tell outpost we're about to disconnect"""
await self.send(text_data="0.authentik.disconnect")
await self.close()

View file

@ -0,0 +1,11 @@
"""RAC Provider Docker Controller"""
from authentik.outposts.controllers.docker import DockerController
from authentik.outposts.models import DockerServiceConnection, Outpost
class RACDockerController(DockerController):
"""RAC Provider Docker Controller"""
def __init__(self, outpost: Outpost, connection: DockerServiceConnection):
super().__init__(outpost, connection)
self.deployment_ports = []

View file

@ -0,0 +1,13 @@
"""RAC Provider Kubernetes Controller"""
from authentik.outposts.controllers.k8s.service import ServiceReconciler
from authentik.outposts.controllers.kubernetes import KubernetesController
from authentik.outposts.models import KubernetesServiceConnection, Outpost
class RACKubernetesController(KubernetesController):
"""RAC Provider Kubernetes Controller"""
def __init__(self, outpost: Outpost, connection: KubernetesServiceConnection):
super().__init__(outpost, connection)
self.deployment_ports = []
del self.reconcilers[ServiceReconciler.reconciler_name()]

View file

@ -0,0 +1,164 @@
# Generated by Django 4.2.8 on 2023-12-29 15:58
import uuid
import django.db.models.deletion
from django.db import migrations, models
import authentik.core.models
import authentik.lib.utils.time
class Migration(migrations.Migration):
initial = True
dependencies = [
("authentik_policies", "0011_policybinding_failure_result_and_more"),
("authentik_core", "0032_group_roles"),
]
operations = [
migrations.CreateModel(
name="RACPropertyMapping",
fields=[
(
"propertymapping_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.propertymapping",
),
),
("static_settings", models.JSONField(default=dict)),
],
options={
"verbose_name": "RAC Property Mapping",
"verbose_name_plural": "RAC Property Mappings",
},
bases=("authentik_core.propertymapping",),
),
migrations.CreateModel(
name="RACProvider",
fields=[
(
"provider_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.provider",
),
),
("settings", models.JSONField(default=dict)),
(
"auth_mode",
models.TextField(
choices=[("static", "Static"), ("prompt", "Prompt")], default="prompt"
),
),
(
"connection_expiry",
models.TextField(
default="hours=8",
help_text="Determines how long a session lasts. Default of 0 means that the sessions lasts until the browser is closed. (Format: hours=-1;minutes=-2;seconds=-3)",
validators=[authentik.lib.utils.time.timedelta_string_validator],
),
),
],
options={
"verbose_name": "RAC Provider",
"verbose_name_plural": "RAC Providers",
},
bases=("authentik_core.provider",),
),
migrations.CreateModel(
name="Endpoint",
fields=[
(
"policybindingmodel_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_policies.policybindingmodel",
),
),
("name", models.TextField()),
("host", models.TextField()),
(
"protocol",
models.TextField(choices=[("rdp", "Rdp"), ("vnc", "Vnc"), ("ssh", "Ssh")]),
),
("settings", models.JSONField(default=dict)),
(
"auth_mode",
models.TextField(choices=[("static", "Static"), ("prompt", "Prompt")]),
),
(
"property_mappings",
models.ManyToManyField(
blank=True, default=None, to="authentik_core.propertymapping"
),
),
(
"provider",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_providers_rac.racprovider",
),
),
],
options={
"verbose_name": "RAC Endpoint",
"verbose_name_plural": "RAC Endpoints",
},
bases=("authentik_policies.policybindingmodel", models.Model),
),
migrations.CreateModel(
name="ConnectionToken",
fields=[
(
"expires",
models.DateTimeField(default=authentik.core.models.default_token_duration),
),
("expiring", models.BooleanField(default=True)),
(
"connection_token_uuid",
models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False),
),
("token", models.TextField(default=authentik.core.models.default_token_key)),
("settings", models.JSONField(default=dict)),
(
"endpoint",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_providers_rac.endpoint",
),
),
(
"provider",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_providers_rac.racprovider",
),
),
(
"session",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_core.authenticatedsession",
),
),
],
options={
"abstract": False,
},
),
]

View file

@ -0,0 +1,191 @@
"""RAC Models"""
from typing import Optional
from uuid import uuid4
from deepmerge import always_merger
from django.db import models
from django.db.models import QuerySet
from django.utils.translation import gettext as _
from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
from authentik.core.exceptions import PropertyMappingExpressionException
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, default_token_key
from authentik.events.models import Event, EventAction
from authentik.lib.models import SerializerModel
from authentik.lib.utils.time import timedelta_string_validator
from authentik.policies.models import PolicyBindingModel
LOGGER = get_logger()
class Protocols(models.TextChoices):
"""Supported protocols"""
RDP = "rdp"
VNC = "vnc"
SSH = "ssh"
class AuthenticationMode(models.TextChoices):
"""Authentication modes"""
STATIC = "static"
PROMPT = "prompt"
class RACProvider(Provider):
"""Remotely access computers/servers"""
settings = models.JSONField(default=dict)
auth_mode = models.TextField(
choices=AuthenticationMode.choices, default=AuthenticationMode.PROMPT
)
connection_expiry = models.TextField(
default="hours=8",
validators=[timedelta_string_validator],
help_text=_(
"Determines how long a session lasts. Default of 0 means "
"that the sessions lasts until the browser is closed. "
"(Format: hours=-1;minutes=-2;seconds=-3)"
),
)
@property
def launch_url(self) -> Optional[str]:
"""URL to this provider and initiate authorization for the user.
Can return None for providers that are not URL-based"""
return "goauthentik.io://providers/rac/launch"
@property
def component(self) -> str:
return "ak-provider-rac-form"
@property
def serializer(self) -> type[Serializer]:
from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer
return RACProviderSerializer
class Meta:
verbose_name = _("RAC Provider")
verbose_name_plural = _("RAC Providers")
class Endpoint(SerializerModel, PolicyBindingModel):
"""Remote-accessible endpoint"""
name = models.TextField()
host = models.TextField()
protocol = models.TextField(choices=Protocols.choices)
settings = models.JSONField(default=dict)
auth_mode = models.TextField(choices=AuthenticationMode.choices)
provider = models.ForeignKey("RACProvider", on_delete=models.CASCADE)
property_mappings = models.ManyToManyField(
"authentik_core.PropertyMapping", default=None, blank=True
)
@property
def serializer(self) -> type[Serializer]:
from authentik.enterprise.providers.rac.api.endpoints import EndpointSerializer
return EndpointSerializer
def __str__(self):
return f"RAC Endpoint {self.name}"
class Meta:
verbose_name = _("RAC Endpoint")
verbose_name_plural = _("RAC Endpoints")
class RACPropertyMapping(PropertyMapping):
"""Configure settings for remote access endpoints."""
static_settings = models.JSONField(default=dict)
@property
def component(self) -> str:
return "ak-property-mapping-rac-form"
@property
def serializer(self) -> type[Serializer]:
from authentik.enterprise.providers.rac.api.property_mappings import (
RACPropertyMappingSerializer,
)
return RACPropertyMappingSerializer
class Meta:
verbose_name = _("RAC Property Mapping")
verbose_name_plural = _("RAC Property Mappings")
class ConnectionToken(ExpiringModel):
"""Token for a single connection to a specified endpoint"""
connection_token_uuid = models.UUIDField(default=uuid4, primary_key=True)
provider = models.ForeignKey(RACProvider, on_delete=models.CASCADE)
endpoint = models.ForeignKey(Endpoint, on_delete=models.CASCADE)
token = models.TextField(default=default_token_key)
settings = models.JSONField(default=dict)
session = models.ForeignKey("authentik_core.AuthenticatedSession", on_delete=models.CASCADE)
def get_settings(self) -> dict:
"""Get settings"""
default_settings = {}
if ":" in self.endpoint.host:
host, _, port = self.endpoint.host.partition(":")
default_settings["hostname"] = host
default_settings["port"] = str(port)
else:
default_settings["hostname"] = self.endpoint.host
default_settings["client-name"] = "authentik"
# default_settings["enable-drive"] = "true"
# default_settings["drive-name"] = "authentik"
settings = {}
always_merger.merge(settings, default_settings)
always_merger.merge(settings, self.endpoint.provider.settings)
always_merger.merge(settings, self.endpoint.settings)
always_merger.merge(settings, self.settings)
def mapping_evaluator(mappings: QuerySet):
for mapping in mappings:
mapping: RACPropertyMapping
if len(mapping.static_settings) > 0:
always_merger.merge(settings, mapping.static_settings)
continue
try:
mapping_settings = mapping.evaluate(
self.session.user, None, endpoint=self.endpoint, provider=self.provider
)
always_merger.merge(settings, mapping_settings)
except PropertyMappingExpressionException as exc:
Event.new(
EventAction.CONFIGURATION_ERROR,
message=f"Failed to evaluate property-mapping: '{mapping.name}'",
provider=self.provider,
mapping=mapping,
).set_user(self.session.user).save()
LOGGER.warning("Failed to evaluate property mapping", exc=exc)
mapping_evaluator(
RACPropertyMapping.objects.filter(provider__in=[self.provider]).order_by("name")
)
mapping_evaluator(
RACPropertyMapping.objects.filter(endpoint__in=[self.endpoint]).order_by("name")
)
settings["drive-path"] = f"/tmp/connection/{self.token}" # nosec
settings["create-drive-path"] = "true"
# Ensure all values of the settings dict are strings
for key, value in settings.items():
if isinstance(value, str):
continue
# Special case for bools
if isinstance(value, bool):
settings[key] = str(value).lower()
continue
settings[key] = str(value)
return settings

View file

@ -0,0 +1,54 @@
"""RAC Signals"""
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.contrib.auth.signals import user_logged_out
from django.core.cache import cache
from django.db.models import Model
from django.db.models.signals import post_save, pre_delete
from django.dispatch import receiver
from django.http import HttpRequest
from authentik.core.models import User
from authentik.enterprise.providers.rac.api.endpoints import user_endpoint_cache_key
from authentik.enterprise.providers.rac.consumer_client import (
RAC_CLIENT_GROUP_SESSION,
RAC_CLIENT_GROUP_TOKEN,
)
from authentik.enterprise.providers.rac.models import ConnectionToken, Endpoint
@receiver(user_logged_out)
def user_logged_out_session(sender, request: HttpRequest, user: User, **_):
"""Disconnect any open RAC connections"""
layer = get_channel_layer()
async_to_sync(layer.group_send)(
RAC_CLIENT_GROUP_SESSION
% {
"session": request.session.session_key,
},
{"type": "event.disconnect", "reason": "session_logout"},
)
@receiver(pre_delete, sender=ConnectionToken)
def pre_delete_connection_token_disconnect(sender, instance: ConnectionToken, **_):
"""Disconnect session when connection token is deleted"""
layer = get_channel_layer()
async_to_sync(layer.group_send)(
RAC_CLIENT_GROUP_TOKEN
% {
"token": instance.token,
},
{"type": "event.disconnect", "reason": "token_delete"},
)
@receiver(post_save, sender=Endpoint)
def post_save_application(sender: type[Model], instance, created: bool, **_):
"""Clear user's application cache upon application creation"""
if not created: # pragma: no cover
return
# Delete user endpoint cache
keys = cache.keys(user_endpoint_cache_key("*"))
cache.delete_many(keys)

View file

@ -0,0 +1,18 @@
{% extends "base/skeleton.html" %}
{% load static %}
{% block head %}
<script src="{% static 'dist/enterprise/rac/index.js' %}?version={{ version }}" type="module"></script>
<meta name="theme-color" content="#18191a" media="(prefers-color-scheme: dark)">
<meta name="theme-color" content="#ffffff" media="(prefers-color-scheme: light)">
<link rel="icon" href="{{ tenant.branding_favicon }}">
<link rel="shortcut icon" href="{{ tenant.branding_favicon }}">
{% include "base/header_js.html" %}
{% endblock %}
{% block body %}
<ak-rac token="{{ url_kwargs.token }}" endpointName="{{ token.endpoint.name }}">
<ak-loading></ak-loading>
</ak-rac>
{% endblock %}

View file

@ -0,0 +1,168 @@
"""Test Endpoints API"""
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user
from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.models import PolicyBinding
class TestEndpointsAPI(APITestCase):
"""Test endpoints API"""
def setUp(self) -> None:
self.user = create_test_admin_user()
self.provider = RACProvider.objects.create(
name=generate_id(),
)
self.app = Application.objects.create(
name=generate_id(),
slug=generate_id(),
provider=self.provider,
)
self.allowed = Endpoint.objects.create(
name=f"a-{generate_id()}",
host=generate_id(),
protocol=Protocols.RDP,
provider=self.provider,
)
self.denied = Endpoint.objects.create(
name=f"b-{generate_id()}",
host=generate_id(),
protocol=Protocols.RDP,
provider=self.provider,
)
PolicyBinding.objects.create(
target=self.denied,
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
order=0,
)
def test_list(self):
"""Test list operation without superuser_full_list"""
self.client.force_login(self.user)
response = self.client.get(reverse("authentik_api:endpoint-list"))
self.assertJSONEqual(
response.content.decode(),
{
"pagination": {
"next": 0,
"previous": 0,
"count": 2,
"current": 1,
"total_pages": 1,
"start_index": 1,
"end_index": 2,
},
"results": [
{
"pk": str(self.allowed.pk),
"name": self.allowed.name,
"provider": self.provider.pk,
"provider_obj": {
"pk": self.provider.pk,
"name": self.provider.name,
"authentication_flow": None,
"authorization_flow": None,
"property_mappings": [],
"connection_expiry": "hours=8",
"component": "ak-provider-rac-form",
"assigned_application_slug": self.app.slug,
"assigned_application_name": self.app.name,
"verbose_name": "RAC Provider",
"verbose_name_plural": "RAC Providers",
"meta_model_name": "authentik_providers_rac.racprovider",
"settings": {},
"outpost_set": [],
},
"protocol": "rdp",
"host": self.allowed.host,
"settings": {},
"property_mappings": [],
"auth_mode": "",
"launch_url": f"/application/rac/{self.app.slug}/{str(self.allowed.pk)}/",
},
],
},
)
def test_list_superuser_full_list(self):
"""Test list operation with superuser_full_list"""
self.client.force_login(self.user)
response = self.client.get(
reverse("authentik_api:endpoint-list") + "?superuser_full_list=true"
)
self.assertJSONEqual(
response.content.decode(),
{
"pagination": {
"next": 0,
"previous": 0,
"count": 2,
"current": 1,
"total_pages": 1,
"start_index": 1,
"end_index": 2,
},
"results": [
{
"pk": str(self.allowed.pk),
"name": self.allowed.name,
"provider": self.provider.pk,
"provider_obj": {
"pk": self.provider.pk,
"name": self.provider.name,
"authentication_flow": None,
"authorization_flow": None,
"property_mappings": [],
"component": "ak-provider-rac-form",
"assigned_application_slug": self.app.slug,
"assigned_application_name": self.app.name,
"connection_expiry": "hours=8",
"verbose_name": "RAC Provider",
"verbose_name_plural": "RAC Providers",
"meta_model_name": "authentik_providers_rac.racprovider",
"settings": {},
"outpost_set": [],
},
"protocol": "rdp",
"host": self.allowed.host,
"settings": {},
"property_mappings": [],
"auth_mode": "",
"launch_url": f"/application/rac/{self.app.slug}/{str(self.allowed.pk)}/",
},
{
"pk": str(self.denied.pk),
"name": self.denied.name,
"provider": self.provider.pk,
"provider_obj": {
"pk": self.provider.pk,
"name": self.provider.name,
"authentication_flow": None,
"authorization_flow": None,
"property_mappings": [],
"component": "ak-provider-rac-form",
"assigned_application_slug": self.app.slug,
"assigned_application_name": self.app.name,
"connection_expiry": "hours=8",
"verbose_name": "RAC Provider",
"verbose_name_plural": "RAC Providers",
"meta_model_name": "authentik_providers_rac.racprovider",
"settings": {},
"outpost_set": [],
},
"protocol": "rdp",
"host": self.denied.host,
"settings": {},
"property_mappings": [],
"auth_mode": "",
"launch_url": f"/application/rac/{self.app.slug}/{str(self.denied.pk)}/",
},
],
},
)

View file

@ -0,0 +1,144 @@
"""Test RAC Models"""
from django.test import TransactionTestCase
from authentik.core.models import Application, AuthenticatedSession
from authentik.core.tests.utils import create_test_admin_user
from authentik.enterprise.providers.rac.models import (
ConnectionToken,
Endpoint,
Protocols,
RACPropertyMapping,
RACProvider,
)
from authentik.lib.generators import generate_id
class TestModels(TransactionTestCase):
"""Test RAC Models"""
def setUp(self):
self.user = create_test_admin_user()
self.provider = RACProvider.objects.create(
name=generate_id(),
)
self.app = Application.objects.create(
name=generate_id(),
slug=generate_id(),
provider=self.provider,
)
self.endpoint = Endpoint.objects.create(
name=generate_id(),
host=f"{generate_id()}:1324",
protocol=Protocols.RDP,
provider=self.provider,
)
def test_settings_merge(self):
"""Test settings merge"""
token = ConnectionToken.objects.create(
provider=self.provider,
endpoint=self.endpoint,
session=AuthenticatedSession.objects.create(
user=self.user,
session_key=generate_id(),
),
)
path = f"/tmp/connection/{token.token}" # nosec
self.assertEqual(
token.get_settings(),
{
"hostname": self.endpoint.host.split(":")[0],
"port": "1324",
"client-name": "authentik",
"drive-path": path,
"create-drive-path": "true",
},
)
# Set settings in provider
self.provider.settings = {"level": "provider"}
self.provider.save()
self.assertEqual(
token.get_settings(),
{
"hostname": self.endpoint.host.split(":")[0],
"port": "1324",
"client-name": "authentik",
"drive-path": path,
"create-drive-path": "true",
"level": "provider",
},
)
# Set settings in endpoint
self.endpoint.settings = {
"level": "endpoint",
}
self.endpoint.save()
self.assertEqual(
token.get_settings(),
{
"hostname": self.endpoint.host.split(":")[0],
"port": "1324",
"client-name": "authentik",
"drive-path": path,
"create-drive-path": "true",
"level": "endpoint",
},
)
# Set settings in token
token.settings = {
"level": "token",
}
token.save()
self.assertEqual(
token.get_settings(),
{
"hostname": self.endpoint.host.split(":")[0],
"port": "1324",
"client-name": "authentik",
"drive-path": path,
"create-drive-path": "true",
"level": "token",
},
)
# Set settings in property mapping (provider)
mapping = RACPropertyMapping.objects.create(
name=generate_id(),
expression="""return {
"level": "property_mapping_provider"
}""",
)
self.provider.property_mappings.add(mapping)
self.assertEqual(
token.get_settings(),
{
"hostname": self.endpoint.host.split(":")[0],
"port": "1324",
"client-name": "authentik",
"drive-path": path,
"create-drive-path": "true",
"level": "property_mapping_provider",
},
)
# Set settings in property mapping (endpoint)
mapping = RACPropertyMapping.objects.create(
name=generate_id(),
static_settings={
"level": "property_mapping_endpoint",
"foo": True,
"bar": 6,
},
)
self.endpoint.property_mappings.add(mapping)
self.assertEqual(
token.get_settings(),
{
"hostname": self.endpoint.host.split(":")[0],
"port": "1324",
"client-name": "authentik",
"drive-path": path,
"create-drive-path": "true",
"level": "property_mapping_endpoint",
"foo": "true",
"bar": "6",
},
)

View file

@ -0,0 +1,132 @@
"""RAC Views tests"""
from datetime import timedelta
from json import loads
from time import mktime
from unittest.mock import MagicMock, patch
from django.urls import reverse
from django.utils.timezone import now
from rest_framework.test import APITestCase
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.enterprise.models import License, LicenseKey
from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider
from authentik.lib.generators import generate_id
from authentik.policies.denied import AccessDeniedResponse
from authentik.policies.dummy.models import DummyPolicy
from authentik.policies.models import PolicyBinding
class TestRACViews(APITestCase):
"""RAC Views tests"""
def setUp(self):
self.user = create_test_admin_user()
self.flow = create_test_flow()
self.provider = RACProvider.objects.create(name=generate_id(), authorization_flow=self.flow)
self.app = Application.objects.create(
name=generate_id(),
slug=generate_id(),
provider=self.provider,
)
self.endpoint = Endpoint.objects.create(
name=generate_id(),
host=f"{generate_id()}:1324",
protocol=Protocols.RDP,
provider=self.provider,
)
@patch(
"authentik.enterprise.models.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=int(mktime((now() + timedelta(days=3000)).timetuple())),
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
def test_no_policy(self):
"""Test request"""
License.objects.create(key=generate_id())
self.client.force_login(self.user)
response = self.client.get(
reverse(
"authentik_providers_rac:start",
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
)
)
self.assertEqual(response.status_code, 302)
flow_response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
)
body = loads(flow_response.content)
next_url = body["to"]
final_response = self.client.get(next_url)
self.assertEqual(final_response.status_code, 200)
@patch(
"authentik.enterprise.models.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=int(mktime((now() + timedelta(days=3000)).timetuple())),
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
def test_app_deny(self):
"""Test request (deny on app level)"""
PolicyBinding.objects.create(
target=self.app,
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
order=0,
)
License.objects.create(key=generate_id())
self.client.force_login(self.user)
response = self.client.get(
reverse(
"authentik_providers_rac:start",
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
)
)
self.assertIsInstance(response, AccessDeniedResponse)
@patch(
"authentik.enterprise.models.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=int(mktime((now() + timedelta(days=3000)).timetuple())),
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
def test_endpoint_deny(self):
"""Test request (deny on endpoint level)"""
PolicyBinding.objects.create(
target=self.endpoint,
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
order=0,
)
License.objects.create(key=generate_id())
self.client.force_login(self.user)
response = self.client.get(
reverse(
"authentik_providers_rac:start",
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
)
)
self.assertEqual(response.status_code, 302)
flow_response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
)
body = loads(flow_response.content)
self.assertEqual(body["component"], "ak-stage-access-denied")

View file

@ -0,0 +1,47 @@
"""rac urls"""
from channels.auth import AuthMiddleware
from channels.sessions import CookieMiddleware
from django.urls import path
from django.views.decorators.csrf import ensure_csrf_cookie
from authentik.core.channels import TokenOutpostMiddleware
from authentik.enterprise.providers.rac.api.endpoints import EndpointViewSet
from authentik.enterprise.providers.rac.api.property_mappings import RACPropertyMappingViewSet
from authentik.enterprise.providers.rac.api.providers import RACProviderViewSet
from authentik.enterprise.providers.rac.consumer_client import RACClientConsumer
from authentik.enterprise.providers.rac.consumer_outpost import RACOutpostConsumer
from authentik.enterprise.providers.rac.views import RACInterface, RACStartView
from authentik.root.asgi_middleware import SessionMiddleware
from authentik.root.middleware import ChannelsLoggingMiddleware
urlpatterns = [
path(
"application/rac/<slug:app>/<uuid:endpoint>/",
ensure_csrf_cookie(RACStartView.as_view()),
name="start",
),
path(
"if/rac/<str:token>/",
ensure_csrf_cookie(RACInterface.as_view()),
name="if-rac",
),
]
websocket_urlpatterns = [
path(
"ws/rac/<str:token>/",
ChannelsLoggingMiddleware(
CookieMiddleware(SessionMiddleware(AuthMiddleware(RACClientConsumer.as_asgi())))
),
),
path(
"ws/outpost_rac/<str:channel>/",
ChannelsLoggingMiddleware(TokenOutpostMiddleware(RACOutpostConsumer.as_asgi())),
),
]
api_urlpatterns = [
("providers/rac", RACProviderViewSet),
("propertymappings/rac", RACPropertyMappingViewSet),
("rac/endpoints", EndpointViewSet),
]

View file

@ -0,0 +1,115 @@
"""RAC Views"""
from typing import Any
from django.http import Http404, HttpRequest, HttpResponse
from django.shortcuts import get_object_or_404, redirect
from django.urls import reverse
from django.utils.timezone import now
from authentik.core.models import Application, AuthenticatedSession
from authentik.core.views.interface import InterfaceView
from authentik.enterprise.policy import EnterprisePolicyAccessView
from authentik.enterprise.providers.rac.models import ConnectionToken, Endpoint, RACProvider
from authentik.flows.challenge import RedirectChallenge
from authentik.flows.exceptions import FlowNonApplicableException
from authentik.flows.models import in_memory_stage
from authentik.flows.planner import FlowPlanner
from authentik.flows.stage import RedirectStage
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.utils.time import timedelta_from_string
from authentik.lib.utils.urls import redirect_with_qs
from authentik.policies.engine import PolicyEngine
class RACStartView(EnterprisePolicyAccessView):
"""Start a RAC connection by checking access and creating a connection token"""
endpoint: Endpoint
def resolve_provider_application(self):
self.application = get_object_or_404(Application, slug=self.kwargs["app"])
# Endpoint permissions are validated in the RACFinalStage below
self.endpoint = get_object_or_404(Endpoint, pk=self.kwargs["endpoint"])
self.provider = RACProvider.objects.get(application=self.application)
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Start flow planner for RAC provider"""
planner = FlowPlanner(self.provider.authorization_flow)
planner.allow_empty_flows = True
try:
plan = planner.plan(self.request)
except FlowNonApplicableException:
raise Http404
plan.insert_stage(
in_memory_stage(
RACFinalStage,
endpoint=self.endpoint,
provider=self.provider,
)
)
request.session[SESSION_KEY_PLAN] = plan
return redirect_with_qs(
"authentik_core:if-flow",
request.GET,
flow_slug=self.provider.authorization_flow.slug,
)
class RACInterface(InterfaceView):
"""Start RAC connection"""
template_name = "if/rac.html"
token: ConnectionToken
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
# Early sanity check to ensure token still exists
token = ConnectionToken.filter_not_expired(token=self.kwargs["token"]).first()
if not token:
return redirect("authentik_core:if-user")
self.token = token
return super().dispatch(request, *args, **kwargs)
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
kwargs["token"] = self.token
return super().get_context_data(**kwargs)
class RACFinalStage(RedirectStage):
"""RAC Connection final stage, set the connection token in the stage"""
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
endpoint: Endpoint = self.executor.current_stage.endpoint
engine = PolicyEngine(endpoint, self.request.user, self.request)
engine.use_cache = False
engine.build()
passing = engine.result
if not passing.passing:
return self.executor.stage_invalid(", ".join(passing.messages))
return super().dispatch(request, *args, **kwargs)
def get_challenge(self, *args, **kwargs) -> RedirectChallenge:
endpoint: Endpoint = self.executor.current_stage.endpoint
provider: RACProvider = self.executor.current_stage.provider
token = ConnectionToken.objects.create(
provider=provider,
endpoint=endpoint,
settings=self.executor.plan.context.get("connection_settings", {}),
session=AuthenticatedSession.objects.filter(
session_key=self.request.session.session_key
).first(),
expires=now() + timedelta_from_string(provider.connection_expiry),
expiring=True,
)
setattr(
self.executor.current_stage,
"destination",
self.request.build_absolute_uri(
reverse(
"authentik_providers_rac:if-rac",
kwargs={
"token": str(token.token),
},
)
),
)
return super().get_challenge(*args, **kwargs)

View file

@ -10,3 +10,7 @@ CELERY_BEAT_SCHEDULE = {
"options": {"queue": "authentik_scheduled"},
}
}
INSTALLED_APPS = [
"authentik.enterprise.providers.rac",
]

View file

@ -6,6 +6,7 @@ import django_filters
from django.db.models.aggregates import Count
from django.db.models.fields.json import KeyTextTransform, KeyTransform
from django.db.models.functions import ExtractDay, ExtractHour
from django.db.models.query_utils import Q
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema
from guardian.shortcuts import get_objects_for_user
@ -87,7 +88,12 @@ class EventsFilter(django_filters.FilterSet):
we need to remove the dashes that a client may send. We can't use a
UUIDField for this, as some models might not have a UUID PK"""
value = str(value).replace("-", "")
return queryset.filter(context__model__pk=value)
query = Q(context__model__pk=value)
try:
query |= Q(context__model__pk=int(value))
except ValueError:
pass
return queryset.filter(query)
class Meta:
model = Event

View file

@ -1,4 +1,5 @@
"""Event API tests"""
from json import loads
from django.urls import reverse
from rest_framework.test import APITestCase
@ -11,6 +12,9 @@ from authentik.events.models import (
NotificationSeverity,
TransportMode,
)
from authentik.events.utils import model_to_dict
from authentik.lib.generators import generate_id
from authentik.providers.oauth2.models import OAuth2Provider
class TestEventsAPI(APITestCase):
@ -20,6 +24,25 @@ class TestEventsAPI(APITestCase):
self.user = create_test_admin_user()
self.client.force_login(self.user)
def test_filter_model_pk_int(self):
"""Test event list with context_model_pk and integer PKs"""
provider = OAuth2Provider.objects.create(
name=generate_id(),
)
event = Event.new(EventAction.MODEL_CREATED, model=model_to_dict(provider))
event.save()
response = self.client.get(
reverse("authentik_api:event-list"),
data={
"context_model_pk": provider.pk,
"context_model_app": "authentik_providers_oauth2",
"context_model_name": "oauth2provider",
},
)
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertEqual(body["pagination"]["count"], 1)
def test_top_n(self):
"""Test top_per_user"""
event = Event.new(EventAction.AUTHORIZE_APPLICATION)

View file

@ -17,8 +17,9 @@ from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import JSONDictField, PassiveSerializer
from authentik.core.models import Provider
from authentik.enterprise.providers.rac.models import RACProvider
from authentik.outposts.api.service_connections import ServiceConnectionSerializer
from authentik.outposts.apps import MANAGED_OUTPOST
from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME
from authentik.outposts.models import (
Outpost,
OutpostConfig,
@ -47,12 +48,23 @@ class OutpostSerializer(ModelSerializer):
source="service_connection", read_only=True
)
def validate_name(self, name: str) -> str:
"""Validate name (especially for embedded outpost)"""
if not self.instance:
return name
if self.instance.managed == MANAGED_OUTPOST and name != MANAGED_OUTPOST_NAME:
raise ValidationError("Embedded outpost's name cannot be changed")
if self.instance.name == MANAGED_OUTPOST_NAME:
self.instance.managed = MANAGED_OUTPOST
return name
def validate_providers(self, providers: list[Provider]) -> list[Provider]:
"""Check that all providers match the type of the outpost"""
type_map = {
OutpostType.LDAP: LDAPProvider,
OutpostType.PROXY: ProxyProvider,
OutpostType.RADIUS: RadiusProvider,
OutpostType.RAC: RACProvider,
None: Provider,
}
for provider in providers:

View file

@ -18,6 +18,7 @@ GAUGE_OUTPOSTS_LAST_UPDATE = Gauge(
["tenant", "outpost", "uid", "version"],
)
MANAGED_OUTPOST = "goauthentik.io/outposts/embedded"
MANAGED_OUTPOST_NAME = "authentik Embedded Outpost"
class AuthentikOutpostConfig(ManagedAppConfig):
@ -38,15 +39,18 @@ class AuthentikOutpostConfig(ManagedAppConfig):
DockerServiceConnection,
KubernetesServiceConnection,
Outpost,
OutpostConfig,
OutpostType,
)
if not CONFIG.get_bool("outposts.disable_embedded_outpost", False):
if outpost := Outpost.objects.filter(name=MANAGED_OUTPOST_NAME, managed="").first():
outpost.managed = MANAGED_OUTPOST
outpost.save()
return
outpost, updated = Outpost.objects.update_or_create(
defaults={
"name": "authentik Embedded Outpost",
"type": OutpostType.PROXY,
"name": MANAGED_OUTPOST_NAME,
},
managed=MANAGED_OUTPOST,
)
@ -55,12 +59,6 @@ class AuthentikOutpostConfig(ManagedAppConfig):
outpost.service_connection = KubernetesServiceConnection.objects.first()
elif DockerServiceConnection.objects.exists():
outpost.service_connection = DockerServiceConnection.objects.first()
outpost.config = OutpostConfig(
kubernetes_disabled_components=[
"deployment",
"secret",
]
)
outpost.save()
else:
Outpost.objects.filter(managed=MANAGED_OUTPOST).delete()

View file

@ -6,17 +6,19 @@ from typing import Any, Optional
from asgiref.sync import async_to_sync
from channels.exceptions import DenyConnection
from channels.generic.websocket import JsonWebsocketConsumer
from dacite.core import from_dict
from dacite.data import Data
from django.db import connection
from django.http.request import QueryDict
from guardian.shortcuts import get_objects_for_user
from structlog.stdlib import BoundLogger, get_logger
from authentik.core.channels import AuthJsonConsumer
from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState
OUTPOST_GROUP = "group_outpost_%(outpost_pk)s"
OUTPOST_GROUP_INSTANCE = "group_outpost_%(outpost_pk)s_%(instance)s"
class WebsocketMessageInstruction(IntEnum):
@ -43,25 +45,23 @@ class WebsocketMessage:
args: dict[str, Any] = field(default_factory=dict)
class OutpostConsumer(AuthJsonConsumer):
class OutpostConsumer(JsonWebsocketConsumer):
"""Handler for Outposts that connect over websockets for health checks and live updates"""
outpost: Optional[Outpost] = None
logger: BoundLogger
last_uid: Optional[str] = None
instance_uid: Optional[str] = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.logger = get_logger()
def connect(self):
super().connect()
uuid = self.scope["url_route"]["kwargs"]["pk"]
user = self.scope["user"]
outpost = (
get_objects_for_user(self.user, "authentik_outposts.view_outpost")
.filter(pk=uuid)
.first()
get_objects_for_user(user, "authentik_outposts.view_outpost").filter(pk=uuid).first()
)
if not outpost:
raise DenyConnection()
@ -72,14 +72,20 @@ class OutpostConsumer(AuthJsonConsumer):
self.logger.warning("runtime error during accept", exc=exc)
raise DenyConnection()
self.outpost = outpost
self.last_uid = self.channel_name
query = QueryDict(self.scope["query_string"].decode())
self.instance_uid = query.get("instance_uuid", self.channel_name)
async_to_sync(self.channel_layer.group_add)(
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
)
async_to_sync(self.channel_layer.group_add)(
OUTPOST_GROUP_INSTANCE
% {"outpost_pk": str(self.outpost.pk), "instance": self.instance_uid},
self.channel_name,
)
GAUGE_OUTPOSTS_CONNECTED.labels(
tenant=connection.schema_name,
outpost=self.outpost.name,
uid=self.last_uid,
uid=self.instance_uid,
expected=self.outpost.config.kubernetes_replicas,
).inc()
@ -88,36 +94,39 @@ class OutpostConsumer(AuthJsonConsumer):
async_to_sync(self.channel_layer.group_discard)(
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
)
if self.outpost and self.last_uid:
if self.instance_uid:
async_to_sync(self.channel_layer.group_discard)(
OUTPOST_GROUP_INSTANCE
% {"outpost_pk": str(self.outpost.pk), "instance": self.instance_uid},
self.channel_name,
)
if self.outpost and self.instance_uid:
GAUGE_OUTPOSTS_CONNECTED.labels(
tenant=connection.schema_name,
outpost=self.outpost.name,
uid=self.last_uid,
uid=self.instance_uid,
expected=self.outpost.config.kubernetes_replicas,
).dec()
def receive_json(self, content: Data, **kwargs):
msg = from_dict(WebsocketMessage, content)
uid = msg.args.get("uuid", self.channel_name)
self.last_uid = uid
if not self.outpost:
raise DenyConnection()
state = OutpostState.for_instance_uid(self.outpost, uid)
state = OutpostState.for_instance_uid(self.outpost, self.instance_uid)
state.last_seen = datetime.now()
state.hostname = msg.args.pop("hostname", "")
if msg.instruction == WebsocketMessageInstruction.HELLO:
state.version = msg.args.pop("version", None)
state.build_hash = msg.args.pop("buildHash", "")
state.args = msg.args
state.args.update(msg.args)
elif msg.instruction == WebsocketMessageInstruction.ACK:
return
GAUGE_OUTPOSTS_LAST_UPDATE.labels(
tenant=connection.schema_name,
outpost=self.outpost.name,
uid=self.last_uid or "",
uid=self.instance_uid or "",
version=state.version or "",
).set_to_current_time()
state.save(timeout=OUTPOST_HELLO_INTERVAL * 1.5)

View file

@ -43,6 +43,10 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]):
self.api = AppsV1Api(controller.client)
self.outpost = self.controller.outpost
@property
def noop(self) -> bool:
return self.is_embedded
@staticmethod
def reconciler_name() -> str:
return "deployment"

View file

@ -24,6 +24,10 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]):
super().__init__(controller)
self.api = CoreV1Api(controller.client)
@property
def noop(self) -> bool:
return self.is_embedded
@staticmethod
def reconciler_name() -> str:
return "secret"

View file

@ -77,7 +77,10 @@ class PrometheusServiceMonitorReconciler(KubernetesObjectReconciler[PrometheusSe
@property
def noop(self) -> bool:
return (not self._crd_exists()) or (self.is_embedded)
if not self._crd_exists():
self.logger.debug("CRD doesn't exist")
return True
return self.is_embedded
def _crd_exists(self) -> bool:
"""Check if the Prometheus ServiceMonitor exists"""

View file

@ -1,5 +1,6 @@
"""k8s utils"""
from pathlib import Path
from typing import Optional
from kubernetes.client.models.v1_container_port import V1ContainerPort
from kubernetes.client.models.v1_service_port import V1ServicePort
@ -37,9 +38,12 @@ def compare_port(
def compare_ports(
current: list[V1ServicePort | V1ContainerPort], reference: list[V1ServicePort | V1ContainerPort]
current: Optional[list[V1ServicePort | V1ContainerPort]],
reference: Optional[list[V1ServicePort | V1ContainerPort]],
):
"""Compare ports of a list"""
if not current or not reference:
raise NeedsRecreate()
if len(current) != len(reference):
raise NeedsRecreate()
for port in reference:

View file

@ -81,7 +81,10 @@ class KubernetesController(BaseController):
def up(self):
try:
for reconcile_key in self.reconcile_order:
reconciler = self.reconcilers[reconcile_key](self)
reconciler_cls = self.reconcilers.get(reconcile_key)
if not reconciler_cls:
continue
reconciler = reconciler_cls(self)
reconciler.up()
except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc:
@ -95,7 +98,10 @@ class KubernetesController(BaseController):
all_logs += [f"{reconcile_key.title()}: Disabled"]
continue
with capture_logs() as logs:
reconciler = self.reconcilers[reconcile_key](self)
reconciler_cls = self.reconcilers.get(reconcile_key)
if not reconciler_cls:
continue
reconciler = reconciler_cls(self)
reconciler.up()
all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs]
return all_logs
@ -105,7 +111,10 @@ class KubernetesController(BaseController):
def down(self):
try:
for reconcile_key in self.reconcile_order:
reconciler = self.reconcilers[reconcile_key](self)
reconciler_cls = self.reconcilers.get(reconcile_key)
if not reconciler_cls:
continue
reconciler = reconciler_cls(self)
self.logger.debug("Tearing down object", name=reconcile_key)
reconciler.down()
@ -120,7 +129,10 @@ class KubernetesController(BaseController):
all_logs += [f"{reconcile_key.title()}: Disabled"]
continue
with capture_logs() as logs:
reconciler = self.reconcilers[reconcile_key](self)
reconciler_cls = self.reconcilers.get(reconcile_key)
if not reconciler_cls:
continue
reconciler = reconciler_cls(self)
reconciler.down()
all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs]
return all_logs
@ -130,7 +142,10 @@ class KubernetesController(BaseController):
def get_static_deployment(self) -> str:
documents = []
for reconcile_key in self.reconcile_order:
reconciler = self.reconcilers[reconcile_key](self)
reconciler_cls = self.reconcilers.get(reconcile_key)
if not reconciler_cls:
continue
reconciler = reconciler_cls(self)
if reconciler.noop:
continue
documents.append(reconciler.get_reference_object().to_dict())

View file

@ -0,0 +1,25 @@
# Generated by Django 4.2.6 on 2023-10-14 19:23
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_outposts", "0020_alter_outpost_type"),
]
operations = [
migrations.AlterField(
model_name="outpost",
name="type",
field=models.TextField(
choices=[
("proxy", "Proxy"),
("ldap", "Ldap"),
("radius", "Radius"),
("rac", "Rac"),
],
default="proxy",
),
),
]

View file

@ -90,11 +90,12 @@ class OutpostModel(Model):
class OutpostType(models.TextChoices):
"""Outpost types, currently only the reverse proxy is available"""
"""Outpost types"""
PROXY = "proxy"
LDAP = "ldap"
RADIUS = "radius"
RAC = "rac"
def default_outpost_config(host: Optional[str] = None):
@ -459,7 +460,7 @@ class OutpostState:
def for_instance_uid(outpost: Outpost, uid: str) -> "OutpostState":
"""Get state for a single instance"""
key = f"{outpost.state_cache_prefix}/{uid}"
default_data = {"uid": uid, "channel_ids": []}
default_data = {"uid": uid}
data = cache.get(key, default_data)
if isinstance(data, str):
cache.delete(key)

View file

@ -17,6 +17,8 @@ from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION
from structlog.stdlib import get_logger
from yaml import safe_load
from authentik.enterprise.providers.rac.controllers.docker import RACDockerController
from authentik.enterprise.providers.rac.controllers.kubernetes import RACKubernetesController
from authentik.events.monitored_tasks import (
MonitoredTask,
TaskResult,
@ -71,6 +73,11 @@ def controller_for_outpost(outpost: Outpost) -> Optional[type[BaseController]]:
return RadiusDockerController
if isinstance(service_connection, KubernetesServiceConnection):
return RadiusKubernetesController
if outpost.type == OutpostType.RAC:
if isinstance(service_connection, DockerServiceConnection):
return RACDockerController
if isinstance(service_connection, KubernetesServiceConnection):
return RACKubernetesController
return None

View file

@ -2,11 +2,13 @@
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.blueprints.tests import reconcile_app
from authentik.core.models import PropertyMapping
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.lib.generators import generate_id
from authentik.outposts.api.outposts import OutpostSerializer
from authentik.outposts.models import OutpostType, default_outpost_config
from authentik.outposts.apps import MANAGED_OUTPOST
from authentik.outposts.models import Outpost, OutpostType, default_outpost_config
from authentik.providers.ldap.models import LDAPProvider
from authentik.providers.proxy.models import ProxyProvider
@ -22,7 +24,36 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
self.user = create_test_admin_user()
self.client.force_login(self.user)
def test_outpost_validaton(self):
@reconcile_app("authentik_outposts")
def test_managed_name_change(self):
"""Test name change for embedded outpost"""
embedded_outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
self.assertIsNotNone(embedded_outpost)
response = self.client.patch(
reverse("authentik_api:outpost-detail", kwargs={"pk": embedded_outpost.pk}),
{"name": "foo"},
)
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content, {"name": ["Embedded outpost's name cannot be changed"]}
)
@reconcile_app("authentik_outposts")
def test_managed_without_managed(self):
"""Test name change for embedded outpost"""
embedded_outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
self.assertIsNotNone(embedded_outpost)
embedded_outpost.managed = ""
embedded_outpost.save()
response = self.client.patch(
reverse("authentik_api:outpost-detail", kwargs={"pk": embedded_outpost.pk}),
{"name": "foo"},
)
self.assertEqual(response.status_code, 200)
embedded_outpost.refresh_from_db()
self.assertEqual(embedded_outpost.managed, MANAGED_OUTPOST)
def test_outpost_validation(self):
"""Test Outpost validation"""
valid = OutpostSerializer(
data={

View file

@ -1,6 +1,7 @@
"""Websocket tests"""
from dataclasses import asdict
from channels.exceptions import DenyConnection
from channels.routing import URLRouter
from channels.testing import WebsocketCommunicator
from django.test import TransactionTestCase
@ -35,8 +36,9 @@ class TestOutpostWS(TransactionTestCase):
communicator = WebsocketCommunicator(
URLRouter(websocket.websocket_urlpatterns), f"/ws/outpost/{self.outpost.pk}/"
)
connected, _ = await communicator.connect()
self.assertFalse(connected)
with self.assertRaises(DenyConnection):
connected, _ = await communicator.connect()
self.assertFalse(connected)
async def test_auth_valid(self):
"""Test auth with token"""

View file

@ -1,6 +1,7 @@
"""Outpost Websocket URLS"""
from django.urls import path
from authentik.core.channels import TokenOutpostMiddleware
from authentik.outposts.api.outposts import OutpostViewSet
from authentik.outposts.api.service_connections import (
DockerServiceConnectionViewSet,
@ -11,7 +12,10 @@ from authentik.outposts.consumer import OutpostConsumer
from authentik.root.middleware import ChannelsLoggingMiddleware
websocket_urlpatterns = [
path("ws/outpost/<uuid:pk>/", ChannelsLoggingMiddleware(OutpostConsumer.as_asgi())),
path(
"ws/outpost/<uuid:pk>/",
ChannelsLoggingMiddleware(TokenOutpostMiddleware(OutpostConsumer.as_asgi())),
),
]
api_urlpatterns = [

View file

@ -38,13 +38,12 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover
"outposts.container_image_base",
f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}",
)
CONFIG.set("error_reporting.sample_rate", 0)
CONFIG.set("tenants.enabled", False)
CONFIG.set("outposts.disable_embedded_outpost", False)
sentry_init(
environment="testing",
send_default_pii=True,
)
CONFIG.set("error_reporting.sample_rate", 0)
CONFIG.set("error_reporting.environment", "testing")
CONFIG.set("error_reporting.send_pii", True)
sentry_init()
@classmethod
def add_arguments(cls, parser: ArgumentParser):

View file

@ -99,7 +99,9 @@ class OAuthSourceSerializer(SourceSerializer):
]:
if getattr(provider_type, url, None) is None:
if url not in attrs:
raise ValidationError(f"{url} is required for provider {provider_type.name}")
raise ValidationError(
f"{url} is required for provider {provider_type.verbose_name}"
)
return attrs
class Meta:

View file

@ -104,8 +104,8 @@ class AppleType(SourceType):
callback_view = AppleOAuth2Callback
redirect_view = AppleOAuthRedirect
name = "Apple"
slug = "apple"
verbose_name = "Apple"
name = "apple"
authorization_url = "https://appleid.apple.com/auth/authorize"
access_token_url = "https://appleid.apple.com/auth/token" # nosec

View file

@ -43,8 +43,8 @@ class AzureADType(SourceType):
callback_view = AzureADOAuthCallback
redirect_view = AzureADOAuthRedirect
name = "Azure AD"
slug = "azuread"
verbose_name = "Azure AD"
name = "azuread"
urls_customizable = True

View file

@ -36,8 +36,8 @@ class DiscordType(SourceType):
callback_view = DiscordOAuth2Callback
redirect_view = DiscordOAuthRedirect
name = "Discord"
slug = "discord"
verbose_name = "Discord"
name = "discord"
authorization_url = "https://discord.com/api/oauth2/authorize"
access_token_url = "https://discord.com/api/oauth2/token" # nosec

View file

@ -48,8 +48,8 @@ class FacebookType(SourceType):
callback_view = FacebookOAuth2Callback
redirect_view = FacebookOAuthRedirect
name = "Facebook"
slug = "facebook"
verbose_name = "Facebook"
name = "facebook"
authorization_url = "https://www.facebook.com/v7.0/dialog/oauth"
access_token_url = "https://graph.facebook.com/v7.0/oauth/access_token" # nosec

View file

@ -68,8 +68,8 @@ class GitHubType(SourceType):
callback_view = GitHubOAuth2Callback
redirect_view = GitHubOAuthRedirect
name = "GitHub"
slug = "github"
verbose_name = "GitHub"
name = "github"
urls_customizable = True

View file

@ -34,8 +34,8 @@ class GoogleType(SourceType):
callback_view = GoogleOAuth2Callback
redirect_view = GoogleOAuthRedirect
name = "Google"
slug = "google"
verbose_name = "Google"
name = "google"
authorization_url = "https://accounts.google.com/o/oauth2/auth"
access_token_url = "https://oauth2.googleapis.com/token" # nosec

View file

@ -63,7 +63,7 @@ class MailcowType(SourceType):
callback_view = MailcowOAuth2Callback
redirect_view = MailcowOAuthRedirect
name = "Mailcow"
slug = "mailcow"
verbose_name = "Mailcow"
name = "mailcow"
urls_customizable = True

View file

@ -42,7 +42,7 @@ class OpenIDConnectType(SourceType):
callback_view = OpenIDConnectOAuth2Callback
redirect_view = OpenIDConnectOAuthRedirect
name = "OpenID Connect"
slug = "openidconnect"
verbose_name = "OpenID Connect"
name = "openidconnect"
urls_customizable = True

View file

@ -42,7 +42,7 @@ class OktaType(SourceType):
callback_view = OktaOAuth2Callback
redirect_view = OktaOAuthRedirect
name = "Okta"
slug = "okta"
verbose_name = "Okta"
name = "okta"
urls_customizable = True

View file

@ -43,8 +43,8 @@ class PatreonType(SourceType):
callback_view = PatreonOAuthCallback
redirect_view = PatreonOAuthRedirect
name = "Patreon"
slug = "patreon"
verbose_name = "Patreon"
name = "patreon"
authorization_url = "https://www.patreon.com/oauth2/authorize"
access_token_url = "https://www.patreon.com/api/oauth2/token" # nosec

View file

@ -51,8 +51,8 @@ class RedditType(SourceType):
callback_view = RedditOAuth2Callback
redirect_view = RedditOAuthRedirect
name = "Reddit"
slug = "reddit"
verbose_name = "Reddit"
name = "reddit"
authorization_url = "https://www.reddit.com/api/v1/authorize"
access_token_url = "https://www.reddit.com/api/v1/access_token" # nosec

View file

@ -28,7 +28,7 @@ class SourceType:
callback_view = OAuthCallback
redirect_view = OAuthRedirect
name: str = "default"
slug: str = "default"
verbose_name: str = "Default source type"
urls_customizable = False
@ -41,7 +41,7 @@ class SourceType:
def icon_url(self) -> str:
"""Get Icon URL for login"""
return static(f"authentik/sources/{self.slug}.svg")
return static(f"authentik/sources/{self.name}.svg")
def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge:
"""Allow types to return custom challenges"""
@ -77,20 +77,20 @@ class SourceTypeRegistry:
def get_name_tuple(self):
"""Get list of tuples of all registered names"""
return [(x.slug, x.name) for x in self.__sources]
return [(x.name, x.verbose_name) for x in self.__sources]
def find_type(self, type_name: str) -> Type[SourceType]:
"""Find type based on source"""
found_type = None
for src_type in self.__sources:
if src_type.slug == type_name:
if src_type.name == type_name:
return src_type
if not found_type:
found_type = SourceType
LOGGER.warning(
"no matching type found, using default",
wanted=type_name,
have=[x.slug for x in self.__sources],
have=[x.name for x in self.__sources],
)
return found_type

View file

@ -49,8 +49,8 @@ class TwitchType(SourceType):
callback_view = TwitchOAuth2Callback
redirect_view = TwitchOAuthRedirect
name = "Twitch"
slug = "twitch"
verbose_name = "Twitch"
name = "twitch"
authorization_url = "https://id.twitch.tv/oauth2/authorize"
access_token_url = "https://id.twitch.tv/oauth2/token" # nosec

View file

@ -66,8 +66,8 @@ class TwitterType(SourceType):
callback_view = TwitterOAuthCallback
redirect_view = TwitterOAuthRedirect
name = "Twitter"
slug = "twitter"
verbose_name = "Twitter"
name = "twitter"
authorization_url = "https://twitter.com/i/oauth2/authorize"
access_token_url = "https://api.twitter.com/2/oauth2/token" # nosec

View file

@ -2816,6 +2816,117 @@
}
}
},
{
"type": "object",
"required": [
"model",
"identifiers"
],
"properties": {
"model": {
"const": "authentik_providers_rac.racprovider"
},
"id": {
"type": "string"
},
"state": {
"type": "string",
"enum": [
"absent",
"present",
"created",
"must_created"
],
"default": "present"
},
"conditions": {
"type": "array",
"items": {
"type": "boolean"
}
},
"attrs": {
"$ref": "#/$defs/model_authentik_providers_rac.racprovider"
},
"identifiers": {
"$ref": "#/$defs/model_authentik_providers_rac.racprovider"
}
}
},
{
"type": "object",
"required": [
"model",
"identifiers"
],
"properties": {
"model": {
"const": "authentik_providers_rac.endpoint"
},
"id": {
"type": "string"
},
"state": {
"type": "string",
"enum": [
"absent",
"present",
"created",
"must_created"
],
"default": "present"
},
"conditions": {
"type": "array",
"items": {
"type": "boolean"
}
},
"attrs": {
"$ref": "#/$defs/model_authentik_providers_rac.endpoint"
},
"identifiers": {
"$ref": "#/$defs/model_authentik_providers_rac.endpoint"
}
}
},
{
"type": "object",
"required": [
"model",
"identifiers"
],
"properties": {
"model": {
"const": "authentik_providers_rac.racpropertymapping"
},
"id": {
"type": "string"
},
"state": {
"type": "string",
"enum": [
"absent",
"present",
"created",
"must_created"
],
"default": "present"
},
"conditions": {
"type": "array",
"items": {
"type": "boolean"
}
},
"attrs": {
"$ref": "#/$defs/model_authentik_providers_rac.racpropertymapping"
},
"identifiers": {
"$ref": "#/$defs/model_authentik_providers_rac.racpropertymapping"
}
}
},
{
"type": "object",
"required": [
@ -3353,7 +3464,8 @@
"enum": [
"proxy",
"ldap",
"radius"
"radius",
"rac"
],
"title": "Type"
},
@ -3534,7 +3646,8 @@
"authentik.brands",
"authentik.blueprints",
"authentik.core",
"authentik.enterprise"
"authentik.enterprise",
"authentik.enterprise.providers.rac"
],
"title": "App",
"description": "Match events created by selected application. When left empty, all applications are matched."
@ -3620,7 +3733,10 @@
"authentik_core.user",
"authentik_core.application",
"authentik_core.token",
"authentik_enterprise.license"
"authentik_enterprise.license",
"authentik_providers_rac.racprovider",
"authentik_providers_rac.endpoint",
"authentik_providers_rac.racpropertymapping"
],
"title": "Model",
"description": "Match events created by selected model. When left empty, all models are matched. When an app is selected, all the application's models are matched."
@ -8811,6 +8927,123 @@
},
"required": []
},
"model_authentik_providers_rac.racprovider": {
"type": "object",
"properties": {
"name": {
"type": "string",
"minLength": 1,
"title": "Name"
},
"authentication_flow": {
"type": "integer",
"title": "Authentication flow",
"description": "Flow used for authentication when the associated application is accessed by an un-authenticated user."
},
"authorization_flow": {
"type": "integer",
"title": "Authorization flow",
"description": "Flow used when authorizing this provider."
},
"property_mappings": {
"type": "array",
"items": {
"type": "integer"
},
"title": "Property mappings"
},
"settings": {
"type": "object",
"additionalProperties": true,
"title": "Settings"
},
"connection_expiry": {
"type": "string",
"minLength": 1,
"title": "Connection expiry",
"description": "Determines how long a session lasts. Default of 0 means that the sessions lasts until the browser is closed. (Format: hours=-1;minutes=-2;seconds=-3)"
}
},
"required": []
},
"model_authentik_providers_rac.endpoint": {
"type": "object",
"properties": {
"name": {
"type": "string",
"minLength": 1,
"title": "Name"
},
"provider": {
"type": "integer",
"title": "Provider"
},
"protocol": {
"type": "string",
"enum": [
"rdp",
"vnc",
"ssh"
],
"title": "Protocol"
},
"host": {
"type": "string",
"minLength": 1,
"title": "Host"
},
"settings": {
"type": "object",
"additionalProperties": true,
"title": "Settings"
},
"property_mappings": {
"type": "array",
"items": {
"type": "integer"
},
"title": "Property mappings"
},
"auth_mode": {
"type": "string",
"enum": [
"static",
"prompt"
],
"title": "Auth mode"
}
},
"required": []
},
"model_authentik_providers_rac.racpropertymapping": {
"type": "object",
"properties": {
"managed": {
"type": [
"string",
"null"
],
"minLength": 1,
"title": "Managed by authentik",
"description": "Objects that are managed by authentik. These objects are created and updated automatically. This flag only indicates that an object can be overwritten by migrations. You can still modify the objects via the API, but expect changes to be overwritten in a later update."
},
"name": {
"type": "string",
"minLength": 1,
"title": "Name"
},
"expression": {
"type": "string",
"title": "Expression"
},
"static_settings": {
"type": "object",
"additionalProperties": true,
"title": "Static settings"
}
},
"required": []
},
"model_authentik_blueprints.metaapplyblueprint": {
"type": "object",
"properties": {

View file

@ -0,0 +1,32 @@
version: 1
metadata:
labels:
blueprints.goauthentik.io/system: "true"
name: System - RAC Provider - Mappings
entries:
- identifiers:
managed: goauthentik.io/providers/rac/rdp-default
model: authentik_providers_rac.racpropertymapping
attrs:
name: "authentik default RAC Mapping: RDP Default settings"
static_settings:
resize-method: "display-update"
enable-wallpaper: "true"
enable-font-smoothing: "true"
- identifiers:
managed: goauthentik.io/providers/rac/rdp-high-fidelity
model: authentik_providers_rac.racpropertymapping
attrs:
name: "authentik default RAC Mapping: RDP High Fidelity"
static_settings:
enable-theming: "true"
enable-full-window-drag: "true"
enable-desktop-composition: "true"
enable-menu-animations: "true"
- identifiers:
managed: goauthentik.io/providers/rac/ssh-default
model: authentik_providers_rac.racpropertymapping
attrs:
name: "authentik default RAC Mapping: SSH Default settings"
static_settings:
terminal-type: "xterm-256color"

93
cmd/rac/main.go Normal file
View file

@ -0,0 +1,93 @@
package main
import (
"fmt"
"net/url"
"os"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"goauthentik.io/internal/common"
"goauthentik.io/internal/debug"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/ak/healthcheck"
"goauthentik.io/internal/outpost/rac"
)
const helpMessage = `authentik RAC
Required environment variables:
- AUTHENTIK_HOST: URL to connect to (format "http://authentik.company")
- AUTHENTIK_TOKEN: Token to authenticate with
- AUTHENTIK_INSECURE: Skip SSL Certificate verification`
var rootCmd = &cobra.Command{
Long: helpMessage,
PersistentPreRun: func(cmd *cobra.Command, args []string) {
log.SetLevel(log.DebugLevel)
log.SetFormatter(&log.JSONFormatter{
FieldMap: log.FieldMap{
log.FieldKeyMsg: "event",
log.FieldKeyTime: "timestamp",
},
DisableHTMLEscape: true,
})
},
Run: func(cmd *cobra.Command, args []string) {
debug.EnableDebugServer()
akURL, found := os.LookupEnv("AUTHENTIK_HOST")
if !found {
fmt.Println("env AUTHENTIK_HOST not set!")
fmt.Println(helpMessage)
os.Exit(1)
}
akToken, found := os.LookupEnv("AUTHENTIK_TOKEN")
if !found {
fmt.Println("env AUTHENTIK_TOKEN not set!")
fmt.Println(helpMessage)
os.Exit(1)
}
akURLActual, err := url.Parse(akURL)
if err != nil {
fmt.Println(err)
fmt.Println(helpMessage)
os.Exit(1)
}
ex := common.Init()
defer common.Defer()
go func() {
for {
<-ex
os.Exit(0)
}
}()
ac := ak.NewAPIController(*akURLActual, akToken)
if ac == nil {
os.Exit(1)
}
defer ac.Shutdown()
ac.Server = rac.NewServer(ac)
err = ac.Start()
if err != nil {
log.WithError(err).Panic("Failed to run server")
}
for {
<-ex
}
},
}
func main() {
rootCmd.AddCommand(healthcheck.Command)
err := rootCmd.Execute()
if err != nil {
os.Exit(1)
}
}

15
go.mod
View file

@ -10,7 +10,7 @@ require (
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
github.com/go-ldap/ldap/v3 v3.4.6
github.com/go-openapi/runtime v0.26.2
github.com/go-openapi/strfmt v0.21.10
github.com/go-openapi/strfmt v0.22.0
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.5.0
github.com/gorilla/handlers v1.5.2
@ -22,12 +22,13 @@ require (
github.com/mitchellh/mapstructure v1.5.0
github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484
github.com/pires/go-proxyproto v0.7.0
github.com/prometheus/client_golang v1.17.0
github.com/prometheus/client_golang v1.18.0
github.com/redis/go-redis/v9 v9.3.1
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.8.0
github.com/stretchr/testify v1.8.4
goauthentik.io/api/v3 v3.2023105.2
github.com/wwt/guac v1.3.2
goauthentik.io/api/v3 v3.2023105.3
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
golang.org/x/oauth2 v0.15.0
golang.org/x/sync v0.5.0
@ -60,14 +61,14 @@ require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect
github.com/oklog/ulid v1.3.1 // indirect
github.com/opentracing/opentracing-go v1.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac // indirect
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect
github.com/prometheus/common v0.44.0 // indirect
github.com/prometheus/procfs v0.11.1 // indirect
github.com/prometheus/client_model v0.5.0 // indirect
github.com/prometheus/common v0.45.0 // indirect
github.com/prometheus/procfs v0.12.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
go.mongodb.org/mongo-driver v1.13.1 // indirect
go.opentelemetry.io/otel v1.17.0 // indirect

36
go.sum
View file

@ -116,8 +116,8 @@ github.com/go-openapi/spec v0.20.6/go.mod h1:2OpW+JddWPrpXSCIX8eOx7lZ5iyuWj3RYR6
github.com/go-openapi/spec v0.20.11 h1:J/TzFDLTt4Rcl/l1PmyErvkqlJDncGvPTMnCI39I4gY=
github.com/go-openapi/spec v0.20.11/go.mod h1:2OpW+JddWPrpXSCIX8eOx7lZ5iyuWj3RYR6VaaBKcWA=
github.com/go-openapi/strfmt v0.21.3/go.mod h1:k+RzNO0Da+k3FrrynSNN8F7n/peCmQQqbbXjtDfvmGg=
github.com/go-openapi/strfmt v0.21.10 h1:JIsly3KXZB/Qf4UzvzJpg4OELH/0ASDQsyk//TTBDDk=
github.com/go-openapi/strfmt v0.21.10/go.mod h1:vNDMwbilnl7xKiO/Ve/8H8Bb2JIInBnH+lqiw6QWgis=
github.com/go-openapi/strfmt v0.22.0 h1:Ew9PnEYc246TwrEspvBdDHS4BVKXy/AOVsfqGDgAcaI=
github.com/go-openapi/strfmt v0.22.0/go.mod h1:HzJ9kokGIju3/K6ap8jL+OlGAbjpSv27135Yr9OivU4=
github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ=
github.com/go-openapi/swag v0.21.1/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ=
@ -195,6 +195,7 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.2.2 h1:lqzMYz6bOfvn2WriPUjNByzeXIlVzURcPmgMczkmTjY=
github.com/gorilla/sessions v1.2.2/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
@ -210,6 +211,7 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@ -223,8 +225,8 @@ github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN
github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg=
github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k=
github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
@ -247,21 +249,22 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac h1:jWKYCNlX4J5s8M0nHYkh7Y7c9gRVDEb3mq51j5J0F5M=
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac/go.mod h1:hoLfEwdY11HjRfKFH6KqnPsfxlo3BP6bJehpDv8t6sQ=
github.com/prometheus/client_golang v1.17.0 h1:rl2sfwZMtSthVU752MqfjQozy7blglC+1SOtjMAMh+Q=
github.com/prometheus/client_golang v1.17.0/go.mod h1:VeL+gMmOAxkS2IqfCq0ZmHSL+LjWfWDUmp1mBz9JgUY=
github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk=
github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 h1:v7DLqVdK4VrYkVD5diGdl4sxJurKJEMnODWRJlxV9oM=
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU=
github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY=
github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY=
github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwaUuI=
github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY=
github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw=
github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI=
github.com/prometheus/common v0.45.0 h1:2BGz0eBc2hdMDLnO/8n0jeB3oPrt2D08CekT0lneoxM=
github.com/prometheus/common v0.45.0/go.mod h1:YJmSTw9BoKxJplESWWxlbyttQR4uaEcGyv9MZjVOJsY=
github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
github.com/redis/go-redis/v9 v9.3.1 h1:KqdY8U+3X6z+iACvumCNxnoluToB+9Me+TvyFa21Mds=
github.com/redis/go-redis/v9 v9.3.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
@ -269,8 +272,10 @@ github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyh
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@ -281,6 +286,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/wwt/guac v1.3.2 h1:sH6OFGa/1tBs7ieWBVlZe7t6F5JAOWBry/tqQL/Vup4=
github.com/wwt/guac v1.3.2/go.mod h1:eKm+NrnK7A88l4UBEcYNpZQGMpZRryYKoz4D/0/n1C0=
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g=
github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
@ -309,8 +316,8 @@ go.opentelemetry.io/otel/trace v1.17.0 h1:/SWhSRHmDPOImIAetP1QAeMnZYiQXrTy4fMMYO
go.opentelemetry.io/otel/trace v1.17.0/go.mod h1:I/4vKTgFclIsXRVucpH25X0mpFSczM7aHeaz0ZBLWjY=
go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4=
goauthentik.io/api/v3 v3.2023105.2 h1:ZUblqN5LidnCSlEZ/L19h7OnwppnAA3m5AGC7wUN0Ew=
goauthentik.io/api/v3 v3.2023105.2/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
goauthentik.io/api/v3 v3.2023105.3 h1:x0pMJIKkbN198OOssqA94h8bO6ft9gwG8bpZqZL7WVg=
goauthentik.io/api/v3 v3.2023105.3/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -414,6 +421,7 @@ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

View file

@ -159,8 +159,8 @@ func (a *APIController) AddRefreshHandler(handler func()) {
a.refreshHandlers = append(a.refreshHandlers, handler)
}
func (a *APIController) AddWSHandler(handler WSHandler) {
a.wsHandlers = append(a.wsHandlers, handler)
func (a *APIController) Token() string {
return a.token
}
func (a *APIController) OnRefresh() error {
@ -182,7 +182,7 @@ func (a *APIController) OnRefresh() error {
return err
}
func (a *APIController) getWebsocketArgs() map[string]interface{} {
func (a *APIController) getWebsocketPingArgs() map[string]interface{} {
args := map[string]interface{}{
"version": constants.VERSION,
"buildHash": constants.BUILD("tagged"),

View file

@ -18,6 +18,8 @@ import (
func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
pathTemplate := "%s://%s/ws/outpost/%s/?%s"
query := akURL.Query()
query.Set("instance_uuid", ac.instanceUUID.String())
scheme := strings.ReplaceAll(akURL.Scheme, "http", "ws")
authHeader := fmt.Sprintf("Bearer %s", ac.token)
@ -45,7 +47,7 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
// Send hello message with our version
msg := websocketMessage{
Instruction: WebsocketInstructionHello,
Args: ac.getWebsocketArgs(),
Args: ac.getWebsocketPingArgs(),
}
err = ws.WriteJSON(msg)
if err != nil {
@ -53,7 +55,7 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
return err
}
ac.lastWsReconnect = time.Now()
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID).Debug("Successfully connected websocket")
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID).Info("Successfully connected websocket")
return nil
}
@ -157,23 +159,19 @@ func (ac *APIController) startWSHandler() {
func (ac *APIController) startWSHealth() {
ticker := time.NewTicker(time.Second * 10)
for ; true; <-ticker.C {
aliveMsg := websocketMessage{
Instruction: WebsocketInstructionHello,
Args: ac.getWebsocketArgs(),
}
if ac.wsConn == nil {
go ac.reconnectWS()
time.Sleep(time.Second * 5)
continue
}
err := ac.wsConn.WriteJSON(aliveMsg)
ac.logger.WithField("loop", "ws-health").Trace("hello'd")
err := ac.SendWSHello(map[string]interface{}{})
if err != nil {
ac.logger.WithField("loop", "ws-health").WithError(err).Warning("ws write error")
go ac.reconnectWS()
time.Sleep(time.Second * 5)
continue
} else {
ac.logger.WithField("loop", "ws-health").Trace("hello'd")
ConnectionStatus.With(prometheus.Labels{
"outpost_name": ac.Outpost.Name,
"outpost_type": ac.Server.Type(),
@ -202,3 +200,20 @@ func (ac *APIController) startIntervalUpdater() {
}
}
}
func (a *APIController) AddWSHandler(handler WSHandler) {
a.wsHandlers = append(a.wsHandlers, handler)
}
func (a *APIController) SendWSHello(args map[string]interface{}) error {
allArgs := a.getWebsocketPingArgs()
for key, value := range args {
allArgs[key] = value
}
aliveMsg := websocketMessage{
Instruction: WebsocketInstructionHello,
Args: allArgs,
}
err := a.wsConn.WriteJSON(aliveMsg)
return err
}

View file

@ -6,6 +6,7 @@ import (
"strings"
"beryju.io/ldap"
"goauthentik.io/api/v3"
"goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/utils"
@ -49,8 +50,8 @@ func (pi *ProviderInstance) UserEntry(u api.User) *ldap.Entry {
constants.OCPosixAccount,
constants.OCAKUser,
},
"uidNumber": {pi.GetUidNumber(u)},
"gidNumber": {pi.GetUidNumber(u)},
"uidNumber": {pi.GetUserUidNumber(u)},
"gidNumber": {pi.GetUserGidNumber(u)},
"homeDirectory": {fmt.Sprintf("/home/%s", u.Username)},
"sn": {u.Name},
})

View file

@ -4,6 +4,7 @@ import (
"strconv"
"beryju.io/ldap"
"goauthentik.io/api/v3"
"goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/server"
@ -50,7 +51,7 @@ func FromAPIGroup(g api.Group, si server.LDAPServerInstance) *LDAPGroup {
DN: si.GetGroupDN(g.Name),
CN: g.Name,
Uid: string(g.Pk),
GidNumber: si.GetGidNumber(g),
GidNumber: si.GetGroupGidNumber(g),
Member: si.UsersForGroup(g),
IsVirtualGroup: false,
IsSuperuser: *g.IsSuperuser,
@ -63,7 +64,7 @@ func FromAPIUser(u api.User, si server.LDAPServerInstance) *LDAPGroup {
DN: si.GetVirtualGroupDN(u.Username),
CN: u.Username,
Uid: u.Uid,
GidNumber: si.GetUidNumber(u),
GidNumber: si.GetUserGidNumber(u),
Member: []string{si.GetUserDN(u.Username)},
IsVirtualGroup: true,
IsSuperuser: false,

View file

@ -3,6 +3,7 @@ package server
import (
"beryju.io/ldap"
"github.com/go-openapi/strfmt"
"goauthentik.io/api/v3"
"goauthentik.io/internal/outpost/ldap/flags"
)
@ -28,8 +29,9 @@ type LDAPServerInstance interface {
GetGroupDN(string) string
GetVirtualGroupDN(string) string
GetUidNumber(api.User) string
GetGidNumber(api.Group) string
GetUserUidNumber(api.User) string
GetUserGidNumber(api.User) string
GetGroupGidNumber(api.Group) string
UsersForGroup(api.Group) []string

View file

@ -35,7 +35,7 @@ func (pi *ProviderInstance) GetVirtualGroupDN(group string) string {
return fmt.Sprintf("cn=%s,%s", group, pi.VirtualGroupDN)
}
func (pi *ProviderInstance) GetUidNumber(user api.User) string {
func (pi *ProviderInstance) GetUserUidNumber(user api.User) string {
uidNumber, ok := user.GetAttributes()["uidNumber"].(string)
if ok {
@ -45,7 +45,17 @@ func (pi *ProviderInstance) GetUidNumber(user api.User) string {
return strconv.FormatInt(int64(pi.uidStartNumber+user.Pk), 10)
}
func (pi *ProviderInstance) GetGidNumber(group api.Group) string {
func (pi *ProviderInstance) GetUserGidNumber(user api.User) string {
gidNumber, ok := user.GetAttributes()["gidNumber"].(string)
if ok {
return gidNumber
}
return pi.GetUserUidNumber(user)
}
func (pi *ProviderInstance) GetGroupGidNumber(group api.Group) string {
gidNumber, ok := group.GetAttributes()["gidNumber"].(string)
if ok {

View file

@ -31,16 +31,11 @@ func (a *Application) redeemCallback(savedState string, u *url.URL, c context.Co
return nil, err
}
// Extract the ID Token from OAuth2 token.
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("missing id_token")
}
a.log.WithField("id_token", rawIDToken).Trace("id_token")
jwt := oauth2Token.AccessToken
a.log.WithField("jwt", jwt).Trace("access_token")
// Parse and verify ID Token payload.
idToken, err := a.tokenVerifier.Verify(ctx, rawIDToken)
idToken, err := a.tokenVerifier.Verify(ctx, jwt)
if err != nil {
return nil, err
}
@ -53,6 +48,6 @@ func (a *Application) redeemCallback(savedState string, u *url.URL, c context.Co
if claims.Proxy == nil {
claims.Proxy = &ProxyClaims{}
}
claims.RawToken = rawIDToken
claims.RawToken = jwt
return claims, nil
}

View file

@ -13,6 +13,7 @@ import (
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
"github.com/redis/go-redis/v9"
"goauthentik.io/api/v3"
"goauthentik.io/internal/config"
"goauthentik.io/internal/outpost/proxyv2/codecs"
@ -40,7 +41,7 @@ func (a *Application) getStore(p api.ProxyOutpostConfig, externalHost *url.URL)
// New default RedisStore
rs, err := redisstore.NewRedisStore(context.Background(), client)
if err != nil {
panic(err)
a.log.WithError(err).Panic("failed to connect to redis")
}
rs.KeyPrefix(RedisKeyPrefix)
@ -62,7 +63,7 @@ func (a *Application) getStore(p api.ProxyOutpostConfig, externalHost *url.URL)
// https://github.com/markbates/goth/commit/7276be0fdf719ddff753f3574ef0f967e4a5a5f7
// set the maxLength of the cookies stored on the disk to a larger number to prevent issues with:
// securecookie: the value is too long
// when using OpenID Connect , since this can contain a large amount of extra information in the id_token
// when using OpenID Connect, since this can contain a large amount of extra information in the id_token
// Note, when using the FilesystemStore only the session.ID is written to a browser cookie, so this is explicit for the storage on disk
cs.MaxLength(math.MaxInt)

View file

@ -0,0 +1,124 @@
package connection
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"strings"
"time"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
"github.com/wwt/guac"
"goauthentik.io/internal/config"
"goauthentik.io/internal/constants"
"goauthentik.io/internal/outpost/ak"
)
const guacAddr = "0.0.0.0:4822"
type Connection struct {
log *log.Entry
st *guac.SimpleTunnel
ac *ak.APIController
ws *websocket.Conn
ctx context.Context
ctxCancel context.CancelFunc
OnError func(error)
closing bool
}
func NewConnection(ac *ak.APIController, forChannel string, cfg *guac.Config) (*Connection, error) {
ctx, canc := context.WithCancel(context.Background())
c := &Connection{
ac: ac,
log: log.WithField("connection", forChannel),
ctx: ctx,
ctxCancel: canc,
OnError: func(err error) {},
closing: false,
}
err := c.initGuac(cfg)
if err != nil {
return nil, err
}
err = c.initSocket(forChannel)
if err != nil {
_ = c.st.Close()
return nil, err
}
c.initMirror()
return c, nil
}
func (c *Connection) initSocket(forChannel string) error {
pathTemplate := "%s://%s/ws/outpost_rac/%s/"
scheme := strings.ReplaceAll(c.ac.Client.GetConfig().Scheme, "http", "ws")
authHeader := fmt.Sprintf("Bearer %s", c.ac.Token())
header := http.Header{
"Authorization": []string{authHeader},
"User-Agent": []string{constants.OutpostUserAgent()},
}
dialer := websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: config.Get().AuthentikInsecure,
},
}
url := fmt.Sprintf(pathTemplate, scheme, c.ac.Client.GetConfig().Host, forChannel)
ws, _, err := dialer.Dial(url, header)
if err != nil {
c.log.WithError(err).Warning("failed to connect websocket")
return err
}
c.ws = ws
return nil
}
func (c *Connection) initGuac(cfg *guac.Config) error {
addr, err := net.ResolveTCPAddr("tcp", guacAddr)
if err != nil {
return err
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return err
}
stream := guac.NewStream(conn, guac.SocketTimeout)
err = stream.Handshake(cfg)
if err != nil {
return err
}
st := guac.NewSimpleTunnel(stream)
c.st = st
return nil
}
func (c *Connection) initMirror() {
go c.wsToGuacd()
go c.guacdToWs()
}
func (c *Connection) onError(err error) {
if c.closing {
return
}
c.closing = true
e := c.st.Close()
if e != nil {
c.log.WithError(e).Warning("failed to close guacd connection")
}
c.log.WithError(err).Info("removing connection")
c.ctxCancel()
c.OnError(err)
}

View file

@ -0,0 +1,103 @@
package connection
import (
"bytes"
"fmt"
"github.com/gorilla/websocket"
"github.com/wwt/guac"
)
var (
internalOpcodeIns = []byte(fmt.Sprint(len(guac.InternalDataOpcode), ".", guac.InternalDataOpcode))
authentikOpcode = []byte("0.authentik.")
)
// MessageReader wraps a websocket connection and only permits Reading
type MessageReader interface {
// ReadMessage should return a single complete message to send to guac
ReadMessage() (int, []byte, error)
}
func (c *Connection) wsToGuacd() {
w := c.st.AcquireWriter()
for {
select {
default:
_, data, e := c.ws.ReadMessage()
if e != nil {
c.log.WithError(e).Trace("Error reading message from ws")
c.onError(e)
return
}
if bytes.HasPrefix(data, internalOpcodeIns) {
if bytes.HasPrefix(data, authentikOpcode) {
switch string(bytes.Replace(data, authentikOpcode, []byte{}, 1)) {
case "disconnect":
_, e := w.Write([]byte(guac.NewInstruction("disconnect").String()))
c.onError(e)
return
}
}
// messages starting with the InternalDataOpcode are never sent to guacd
continue
}
if _, e = w.Write(data); e != nil {
c.log.WithError(e).Trace("Failed writing to guacd")
c.onError(e)
return
}
case <-c.ctx.Done():
return
}
}
}
// MessageWriter wraps a websocket connection and only permits Writing
type MessageWriter interface {
// WriteMessage writes one or more complete guac commands to the websocket
WriteMessage(int, []byte) error
}
func (c *Connection) guacdToWs() {
r := c.st.AcquireReader()
buf := bytes.NewBuffer(make([]byte, 0, guac.MaxGuacMessage*2))
for {
select {
default:
ins, e := r.ReadSome()
if e != nil {
c.log.WithError(e).Trace("Error reading from guacd")
c.onError(e)
return
}
if bytes.HasPrefix(ins, internalOpcodeIns) {
// messages starting with the InternalDataOpcode are never sent to the websocket
continue
}
if _, e = buf.Write(ins); e != nil {
c.log.WithError(e).Trace("Failed to buffer guacd to ws")
c.onError(e)
return
}
// if the buffer has more data in it or we've reached the max buffer size, send the data and reset
if !r.Available() || buf.Len() >= guac.MaxGuacMessage {
if e = c.ws.WriteMessage(1, buf.Bytes()); e != nil {
if e == websocket.ErrCloseSent {
return
}
c.log.WithError(e).Trace("Failed sending message to ws")
c.onError(e)
return
}
buf.Reset()
}
case <-c.ctx.Done():
return
}
}
}

View file

@ -0,0 +1,26 @@
package rac
import (
"os"
"os/exec"
"strings"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/ak"
)
const (
guacdPath = "/opt/guacamole/sbin/guacd"
guacdDefaultArgs = " -b 0.0.0.0 -f"
)
func (rs *RACServer) startGuac() error {
guacdArgs := strings.Split(guacdDefaultArgs, " ")
guacdArgs = append(guacdArgs, "-L", rs.ac.Outpost.Config[ak.ConfigLogLevel].(string))
rs.guacd = exec.Command(guacdPath, guacdArgs...)
rs.guacd.Env = os.Environ()
rs.guacd.Stdout = rs.log.WithField("logger", "authentik.outpost.rac.guacd").WriterLevel(log.InfoLevel)
rs.guacd.Stderr = rs.log.WithField("logger", "authentik.outpost.rac.guacd").WriterLevel(log.InfoLevel)
rs.log.Info("starting guacd")
return rs.guacd.Start()
}

View file

@ -0,0 +1,28 @@
package metrics
import (
"net/http"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/config"
"goauthentik.io/internal/utils/sentry"
"github.com/gorilla/mux"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
func RunServer() {
m := mux.NewRouter()
l := log.WithField("logger", "authentik.outpost.metrics")
m.Use(sentry.SentryNoSampleMiddleware)
m.HandleFunc("/outpost.goauthentik.io/ping", func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(204)
})
m.Path("/metrics").Handler(promhttp.Handler())
listen := config.Get().Listen.Metrics
l.WithField("listen", listen).Info("Starting Metrics server")
err := http.ListenAndServe(listen, m)
if err != nil {
l.WithError(err).Warning("Failed to start metrics listener")
}
}

126
internal/outpost/rac/rac.go Normal file
View file

@ -0,0 +1,126 @@
package rac
import (
"context"
"os/exec"
"strconv"
"sync"
"github.com/mitchellh/mapstructure"
log "github.com/sirupsen/logrus"
"github.com/wwt/guac"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/rac/connection"
"goauthentik.io/internal/outpost/rac/metrics"
)
type RACServer struct {
log *log.Entry
ac *ak.APIController
guacd *exec.Cmd
connm sync.RWMutex
conns map[string]connection.Connection
}
func NewServer(ac *ak.APIController) *RACServer {
rs := &RACServer{
log: log.WithField("logger", "authentik.outpost.rac"),
ac: ac,
connm: sync.RWMutex{},
conns: map[string]connection.Connection{},
}
ac.AddWSHandler(rs.wsHandler)
return rs
}
type WSMessage struct {
ConnID string `mapstructure:"conn_id"`
DestChannelID string `mapstructure:"dest_channel_id"`
Params map[string]string `mapstructure:"params"`
Protocol string `mapstructure:"protocol"`
OptimalScreenWidth string `mapstructure:"screen_width"`
OptimalScreenHeight string `mapstructure:"screen_height"`
OptimalScreenDPI string `mapstructure:"screen_dpi"`
}
func parseIntOrZero(input string) int {
x, err := strconv.Atoi(input)
if err != nil {
return 0
}
return x
}
func (rs *RACServer) wsHandler(ctx context.Context, args map[string]interface{}) {
wsm := WSMessage{}
err := mapstructure.Decode(args, &wsm)
if err != nil {
rs.log.WithError(err).Warning("invalid ws message")
return
}
config := guac.NewGuacamoleConfiguration()
config.Protocol = wsm.Protocol
config.Parameters = wsm.Params
config.OptimalScreenWidth = parseIntOrZero(wsm.OptimalScreenWidth)
config.OptimalScreenHeight = parseIntOrZero(wsm.OptimalScreenHeight)
config.OptimalResolution = parseIntOrZero(wsm.OptimalScreenDPI)
config.AudioMimetypes = []string{
"audio/L8",
"audio/L16",
}
cc, err := connection.NewConnection(rs.ac, wsm.DestChannelID, config)
if err != nil {
rs.log.WithError(err).Warning("failed to setup connection")
return
}
cc.OnError = func(err error) {
rs.connm.Lock()
delete(rs.conns, wsm.ConnID)
_ = rs.ac.SendWSHello(map[string]interface{}{
"active_connections": len(rs.conns),
})
rs.connm.Unlock()
}
rs.connm.Lock()
rs.conns[wsm.ConnID] = *cc
_ = rs.ac.SendWSHello(map[string]interface{}{
"active_connections": len(rs.conns),
})
rs.connm.Unlock()
}
func (rs *RACServer) Start() error {
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
metrics.RunServer()
}()
go func() {
defer wg.Done()
err := rs.startGuac()
if err != nil {
panic(err)
}
}()
wg.Wait()
return nil
}
func (rs *RACServer) Stop() error {
if rs.guacd != nil {
return rs.guacd.Process.Kill()
}
return nil
}
func (rs *RACServer) TimerFlowCacheExpiry(context.Context) {}
func (rs *RACServer) Type() string {
return "rac"
}
func (rs *RACServer) Refresh() error {
return nil
}

View file

@ -33,6 +33,11 @@ func (ws *WebServer) configureStatic() {
})
indexLessRouter.PathPrefix("/if/admin/assets").Handler(http.StripPrefix("/if/admin", distFs))
indexLessRouter.PathPrefix("/if/user/assets").Handler(http.StripPrefix("/if/user", distFs))
indexLessRouter.PathPrefix("/if/rac/{app_slug}/assets").HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
web.DisableIndex(http.StripPrefix(fmt.Sprintf("/if/rac/%s", vars["app_slug"]), distFs)).ServeHTTP(rw, r)
})
// Media files, if backend is file
if config.Get().Storage.Media.Backend == "file" {

View file

@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-12-27 10:56+0000\n"
"POT-Creation-Date: 2024-01-03 11:22+0000\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
@ -388,6 +388,45 @@ msgstr ""
msgid "License Usage Records"
msgstr ""
#: authentik/enterprise/policy.py:18
msgid "Enterprise required to access this feature."
msgstr ""
#: authentik/enterprise/policy.py:20
msgid "Feature only accessible for internal users."
msgstr ""
#: authentik/enterprise/providers/rac/models.py:48
#: authentik/stages/user_login/models.py:39
msgid ""
"Determines how long a session lasts. Default of 0 means that the sessions "
"lasts until the browser is closed. (Format: hours=-1;minutes=-2;seconds=-3)"
msgstr ""
#: authentik/enterprise/providers/rac/models.py:71
msgid "RAC Provider"
msgstr ""
#: authentik/enterprise/providers/rac/models.py:72
msgid "RAC Providers"
msgstr ""
#: authentik/enterprise/providers/rac/models.py:99
msgid "RAC Endpoint"
msgstr ""
#: authentik/enterprise/providers/rac/models.py:100
msgid "RAC Endpoints"
msgstr ""
#: authentik/enterprise/providers/rac/models.py:121
msgid "RAC Property Mapping"
msgstr ""
#: authentik/enterprise/providers/rac/models.py:122
msgid "RAC Property Mappings"
msgstr ""
#: authentik/events/models.py:289
msgid "Event"
msgstr ""
@ -490,7 +529,7 @@ msgstr ""
msgid "Webhook Mappings"
msgstr ""
#: authentik/events/monitored_tasks.py:205
#: authentik/events/monitored_tasks.py:207
msgid "Task has not been run yet."
msgstr ""
@ -669,75 +708,75 @@ msgstr ""
msgid "Invalid kubeconfig"
msgstr ""
#: authentik/outposts/models.py:122
#: authentik/outposts/models.py:123
msgid ""
"If enabled, use the local connection. Required Docker socket/Kubernetes "
"Integration"
msgstr ""
#: authentik/outposts/models.py:152
#: authentik/outposts/models.py:153
msgid "Outpost Service-Connection"
msgstr ""
#: authentik/outposts/models.py:153
#: authentik/outposts/models.py:154
msgid "Outpost Service-Connections"
msgstr ""
#: authentik/outposts/models.py:161
#: authentik/outposts/models.py:162
msgid ""
"Can be in the format of 'unix://<path>' when connecting to a local docker "
"daemon, or 'https://<hostname>:2376' when connecting to a remote system."
msgstr ""
#: authentik/outposts/models.py:173
#: authentik/outposts/models.py:174
msgid ""
"CA which the endpoint's Certificate is verified against. Can be left empty "
"for no validation."
msgstr ""
#: authentik/outposts/models.py:185
#: authentik/outposts/models.py:186
msgid ""
"Certificate/Key used for authentication. Can be left empty for no "
"authentication."
msgstr ""
#: authentik/outposts/models.py:203
#: authentik/outposts/models.py:204
msgid "Docker Service-Connection"
msgstr ""
#: authentik/outposts/models.py:204
#: authentik/outposts/models.py:205
msgid "Docker Service-Connections"
msgstr ""
#: authentik/outposts/models.py:212
#: authentik/outposts/models.py:213
msgid ""
"Paste your kubeconfig here. authentik will automatically use the currently "
"selected context."
msgstr ""
#: authentik/outposts/models.py:218
#: authentik/outposts/models.py:219
msgid "Verify SSL Certificates of the Kubernetes API endpoint"
msgstr ""
#: authentik/outposts/models.py:235
#: authentik/outposts/models.py:236
msgid "Kubernetes Service-Connection"
msgstr ""
#: authentik/outposts/models.py:236
#: authentik/outposts/models.py:237
msgid "Kubernetes Service-Connections"
msgstr ""
#: authentik/outposts/models.py:252
#: authentik/outposts/models.py:253
msgid ""
"Select Service-Connection authentik should use to manage this outpost. Leave "
"empty if authentik should not handle the deployment."
msgstr ""
#: authentik/outposts/models.py:419
#: authentik/outposts/models.py:420
msgid "Outpost"
msgstr ""
#: authentik/outposts/models.py:420
#: authentik/outposts/models.py:421
msgid "Outposts"
msgstr ""
@ -1591,11 +1630,11 @@ msgstr ""
msgid "Can edit system settings"
msgstr ""
#: authentik/recovery/management/commands/create_admin_group.py:11
#: authentik/recovery/management/commands/create_admin_group.py:12
msgid "Create admin group if the default group gets deleted."
msgstr ""
#: authentik/recovery/management/commands/create_recovery_key.py:17
#: authentik/recovery/management/commands/create_recovery_key.py:16
msgid "Create a Key which can be used to restore access to authentik."
msgstr ""
@ -2618,12 +2657,6 @@ msgstr ""
msgid "No Pending User."
msgstr ""
#: authentik/stages/user_login/models.py:39
msgid ""
"Determines how long a session lasts. Default of 0 means that the sessions "
"lasts until the browser is closed. (Format: hours=-1;minutes=-2;seconds=-3)"
msgstr ""
#: authentik/stages/user_login/models.py:47
msgid "Bind sessions created by this stage to the configured network"
msgstr ""

Binary file not shown.

File diff suppressed because it is too large Load diff

Binary file not shown.

1271
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -149,7 +149,12 @@ geoip2 = "*"
gunicorn = "*"
kubernetes = "*"
ldap3 = "*"
lxml = "*"
lxml = [
# 5.0.0 works with libxml2 2.11.x, which is standard on brew
{ version = "5.0.0", platform = "darwin" },
# 4.9.x works with previous libxml2 versions, which is what we get on linux
{ version = "4.9.4", platform = "linux" },
]
opencontainers = { extras = ["reggie"], version = "*" }
packaging = "*"
paramiko = "*"

38
rac.Dockerfile Normal file
View file

@ -0,0 +1,38 @@
# syntax=docker/dockerfile:1
# Stage 1: Build
FROM docker.io/golang:1.21.5-bookworm AS builder
WORKDIR /go/src/goauthentik.io
RUN --mount=type=bind,target=/go/src/goauthentik.io/go.mod,src=./go.mod \
--mount=type=bind,target=/go/src/goauthentik.io/go.sum,src=./go.sum \
--mount=type=bind,target=/go/src/goauthentik.io/gen-go-api,src=./gen-go-api \
--mount=type=cache,target=/go/pkg/mod \
go mod download
ENV CGO_ENABLED=0
COPY . .
RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
--mount=type=cache,id=go-build-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/root/.cache/go-build \
go build -o /go/rac ./cmd/rac
# Stage 2: Run
FROM ghcr.io/beryju/guacd:1.5.3
ARG GIT_BUILD_HASH
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
LABEL org.opencontainers.image.url https://goauthentik.io
LABEL org.opencontainers.image.description goauthentik.io RAC outpost, see https://goauthentik.io for more info.
LABEL org.opencontainers.image.source https://github.com/goauthentik/authentik
LABEL org.opencontainers.image.version ${VERSION}
LABEL org.opencontainers.image.revision ${GIT_BUILD_HASH}
COPY --from=builder /go/rac /
HEALTHCHECK --interval=5s --retries=20 --start-period=3s CMD [ "/rac", "healthcheck" ]
USER 1000
ENTRYPOINT ["/rac"]

1231
schema.yml

File diff suppressed because it is too large Load diff

View file

@ -1,8 +1,6 @@
"""LDAP and Outpost e2e tests"""
from dataclasses import asdict
from sys import platform
from time import sleep
from unittest.case import skipUnless
from docker.client import DockerClient, from_env
from docker.models.containers import Container
@ -14,13 +12,13 @@ from authentik.blueprints.tests import apply_blueprint, reconcile_app
from authentik.core.models import Application, User
from authentik.events.models import Event, EventAction
from authentik.flows.models import Flow
from authentik.lib.generators import generate_id
from authentik.outposts.apps import MANAGED_OUTPOST
from authentik.outposts.models import Outpost, OutpostConfig, OutpostType
from authentik.providers.ldap.models import APIAccessMode, LDAPProvider
from tests.e2e.utils import SeleniumTestCase, retry
@skipUnless(platform.startswith("linux"), "requires local docker")
class TestProviderLDAP(SeleniumTestCase):
"""LDAP and Outpost e2e tests"""
@ -37,7 +35,10 @@ class TestProviderLDAP(SeleniumTestCase):
container = client.containers.run(
image=self.get_container_image("ghcr.io/goauthentik/dev-ldap"),
detach=True,
network_mode="host",
ports={
"3389": "3389",
"6636": "6636",
},
environment={
"AUTHENTIK_HOST": self.live_server_url,
"AUTHENTIK_TOKEN": outpost.token.key,
@ -51,15 +52,15 @@ class TestProviderLDAP(SeleniumTestCase):
self.user.save()
ldap: LDAPProvider = LDAPProvider.objects.create(
name="ldap_provider",
name=generate_id(),
authorization_flow=Flow.objects.get(slug="default-authentication-flow"),
search_group=self.user.ak_groups.first(),
search_mode=APIAccessMode.CACHED,
)
# we need to create an application to actually access the ldap
Application.objects.create(name="ldap", slug="ldap", provider=ldap)
Application.objects.create(name=generate_id(), slug=generate_id(), provider=ldap)
outpost: Outpost = Outpost.objects.create(
name="ldap_outpost",
name=generate_id(),
type=OutpostType.LDAP,
_config=asdict(OutpostConfig(log_level="debug")),
)

View file

@ -1,8 +1,6 @@
"""test OAuth Provider flow"""
from sys import platform
from time import sleep
from typing import Any, Optional
from unittest.case import skipUnless
from docker.types import Healthcheck
from selenium.webdriver.common.by import By
@ -18,7 +16,6 @@ from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider
from tests.e2e.utils import SeleniumTestCase, retry
@skipUnless(platform.startswith("linux"), "requires local docker")
class TestProviderOAuth2Github(SeleniumTestCase):
"""test OAuth Provider flow"""
@ -32,7 +29,9 @@ class TestProviderOAuth2Github(SeleniumTestCase):
return {
"image": "grafana/grafana:7.1.0",
"detach": True,
"network_mode": "host",
"ports": {
"3000": "3000",
},
"auto_remove": True,
"healthcheck": Healthcheck(
test=["CMD", "wget", "--spider", "http://localhost:3000"],

View file

@ -1,8 +1,6 @@
"""test OAuth2 OpenID Provider flow"""
from sys import platform
from time import sleep
from typing import Any, Optional
from unittest.case import skipUnless
from docker.types import Healthcheck
from selenium.webdriver.common.by import By
@ -24,7 +22,6 @@ from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider, Scope
from tests.e2e.utils import SeleniumTestCase, retry
@skipUnless(platform.startswith("linux"), "requires local docker")
class TestProviderOAuth2OAuth(SeleniumTestCase):
"""test OAuth with OAuth Provider flow"""
@ -38,13 +35,15 @@ class TestProviderOAuth2OAuth(SeleniumTestCase):
return {
"image": "grafana/grafana:7.1.0",
"detach": True,
"network_mode": "host",
"auto_remove": True,
"healthcheck": Healthcheck(
test=["CMD", "wget", "--spider", "http://localhost:3000"],
interval=5 * 1_000 * 1_000_000,
start_period=1 * 1_000 * 1_000_000,
),
"ports": {
"3000": "3000",
},
"environment": {
"GF_AUTH_GENERIC_OAUTH_ENABLED": "true",
"GF_AUTH_GENERIC_OAUTH_CLIENT_ID": self.client_id,

View file

@ -1,8 +1,6 @@
"""test OAuth2 OpenID Provider flow"""
from json import loads
from sys import platform
from time import sleep
from unittest.case import skipUnless
from docker import DockerClient, from_env
from docker.models.containers import Container
@ -25,7 +23,6 @@ from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider, Scope
from tests.e2e.utils import SeleniumTestCase, retry
@skipUnless(platform.startswith("linux"), "requires local docker")
class TestProviderOAuth2OIDC(SeleniumTestCase):
"""test OAuth with OpenID Provider flow"""
@ -36,13 +33,15 @@ class TestProviderOAuth2OIDC(SeleniumTestCase):
super().setUp()
def setup_client(self) -> Container:
"""Setup client saml-sp container which we test SAML against"""
"""Setup client oidc-test-client container which we test OIDC against"""
sleep(1)
client: DockerClient = from_env()
container = client.containers.run(
image="ghcr.io/beryju/oidc-test-client:1.3",
detach=True,
network_mode="host",
ports={
"9009": "9009",
},
environment={
"OIDC_CLIENT_ID": self.client_id,
"OIDC_CLIENT_SECRET": self.client_secret,

View file

@ -1,8 +1,6 @@
"""test OAuth2 OpenID Provider flow"""
from json import loads
from sys import platform
from time import sleep
from unittest.case import skipUnless
from docker import DockerClient, from_env
from docker.models.containers import Container
@ -25,7 +23,6 @@ from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider, Scope
from tests.e2e.utils import SeleniumTestCase, retry
@skipUnless(platform.startswith("linux"), "requires local docker")
class TestProviderOAuth2OIDCImplicit(SeleniumTestCase):
"""test OAuth with OpenID Provider flow"""
@ -36,13 +33,15 @@ class TestProviderOAuth2OIDCImplicit(SeleniumTestCase):
super().setUp()
def setup_client(self) -> Container:
"""Setup client saml-sp container which we test SAML against"""
"""Setup client oidc-test-client container which we test OIDC against"""
sleep(1)
client: DockerClient = from_env()
container = client.containers.run(
image="ghcr.io/beryju/oidc-test-client:1.3",
detach=True,
network_mode="host",
ports={
"9009": "9009",
},
environment={
"OIDC_CLIENT_ID": self.client_id,
"OIDC_CLIENT_SECRET": self.client_secret,

View file

@ -21,7 +21,6 @@ from authentik.providers.proxy.models import ProxyProvider
from tests.e2e.utils import SeleniumTestCase, retry
@skipUnless(platform.startswith("linux"), "requires local docker")
class TestProviderProxy(SeleniumTestCase):
"""Proxy and Outpost e2e tests"""
@ -36,7 +35,9 @@ class TestProviderProxy(SeleniumTestCase):
return {
"image": "traefik/whoami:latest",
"detach": True,
"network_mode": "host",
"ports": {
"80": "80",
},
"auto_remove": True,
}
@ -46,7 +47,9 @@ class TestProviderProxy(SeleniumTestCase):
container = client.containers.run(
image=self.get_container_image("ghcr.io/goauthentik/dev-proxy"),
detach=True,
network_mode="host",
ports={
"9000": "9000",
},
environment={
"AUTHENTIK_HOST": self.live_server_url,
"AUTHENTIK_TOKEN": outpost.token.key,
@ -78,7 +81,7 @@ class TestProviderProxy(SeleniumTestCase):
authorization_flow=Flow.objects.get(
slug="default-provider-authorization-implicit-consent"
),
internal_host="http://localhost",
internal_host=f"http://{self.host}",
external_host="http://localhost:9000",
)
# Ensure OAuth2 Params are set
@ -145,7 +148,7 @@ class TestProviderProxy(SeleniumTestCase):
authorization_flow=Flow.objects.get(
slug="default-provider-authorization-implicit-consent"
),
internal_host="http://localhost",
internal_host=f"http://{self.host}",
external_host="http://localhost:9000",
basic_auth_enabled=True,
basic_auth_user_attribute="basic-username",

View file

@ -1,8 +1,6 @@
"""Radius e2e tests"""
from dataclasses import asdict
from sys import platform
from time import sleep
from unittest.case import skipUnless
from docker.client import DockerClient, from_env
from docker.models.containers import Container
@ -19,7 +17,6 @@ from authentik.providers.radius.models import RadiusProvider
from tests.e2e.utils import SeleniumTestCase, retry
@skipUnless(platform.startswith("linux"), "requires local docker")
class TestProviderRadius(SeleniumTestCase):
"""Radius Outpost e2e tests"""
@ -40,7 +37,7 @@ class TestProviderRadius(SeleniumTestCase):
container = client.containers.run(
image=self.get_container_image("ghcr.io/goauthentik/dev-radius"),
detach=True,
network_mode="host",
ports={"1812/udp": "1812/udp"},
environment={
"AUTHENTIK_HOST": self.live_server_url,
"AUTHENTIK_TOKEN": outpost.token.key,

Some files were not shown because too many files have changed in this diff Show more