From d9428dc104ab8aa86f98f76b53da00bdc3043416 Mon Sep 17 00:00:00 2001
From: Jens Langhammer
Date: Sun, 30 Jul 2023 17:02:33 +0200
Subject: [PATCH] sources/oauth: add initial group sync
Signed-off-by: Jens Langhammer
---
.../core/migrations/0032_alter_group_name.py | 17 ++++++++
authentik/core/models.py | 2 +-
authentik/core/sources/flow_manager.py | 8 ++--
authentik/sources/oauth/api/source.py | 2 +
.../0008_oauthsource_groups_claim.py | 24 +++++++++++
authentik/sources/oauth/models.py | 10 +++++
.../sources/oauth/tests/test_type_openid.py | 5 +++
authentik/sources/oauth/types/oidc.py | 3 ++
authentik/sources/oauth/views/callback.py | 43 ++++++++++++++++++-
blueprints/schema.json | 10 ++++-
schema.yml | 29 ++++++++++---
.../admin/sources/oauth/OAuthSourceForm.ts | 19 ++++++++
12 files changed, 160 insertions(+), 12 deletions(-)
create mode 100644 authentik/core/migrations/0032_alter_group_name.py
create mode 100644 authentik/sources/oauth/migrations/0008_oauthsource_groups_claim.py
diff --git a/authentik/core/migrations/0032_alter_group_name.py b/authentik/core/migrations/0032_alter_group_name.py
new file mode 100644
index 000000000..097a9bc70
--- /dev/null
+++ b/authentik/core/migrations/0032_alter_group_name.py
@@ -0,0 +1,17 @@
+# Generated by Django 4.1.10 on 2023-07-30 14:48
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("authentik_core", "0031_alter_user_type"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="group",
+ name="name",
+ field=models.TextField(verbose_name="name"),
+ ),
+ ]
diff --git a/authentik/core/models.py b/authentik/core/models.py
index 72ede3d43..7e38bf9bd 100644
--- a/authentik/core/models.py
+++ b/authentik/core/models.py
@@ -83,7 +83,7 @@ class Group(SerializerModel):
group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
- name = models.CharField(_("name"), max_length=80)
+ name = models.TextField(_("name"))
is_superuser = models.BooleanField(
default=False, help_text=_("Users added to this group will be superusers.")
)
diff --git a/authentik/core/sources/flow_manager.py b/authentik/core/sources/flow_manager.py
index 0339fee2f..aef3bfd0c 100644
--- a/authentik/core/sources/flow_manager.py
+++ b/authentik/core/sources/flow_manager.py
@@ -220,7 +220,7 @@ class SourceFlowManager:
flow: Flow,
connection: UserSourceConnection,
stages: Optional[list[StageView]] = None,
- **kwargs,
+ **flow_context,
) -> HttpResponse:
"""Prepare Authentication Plan, redirect user FlowExecutor"""
# Ensure redirect is carried through when user was trying to
@@ -228,7 +228,7 @@ class SourceFlowManager:
final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get(
NEXT_ARG_NAME, "authentik_core:if-user"
)
- kwargs.update(
+ flow_context.update(
{
# Since we authenticate the user by their token, they have no backend set
PLAN_CONTEXT_AUTHENTICATION_BACKEND: BACKEND_INBUILT,
@@ -238,7 +238,7 @@ class SourceFlowManager:
PLAN_CONTEXT_SOURCES_CONNECTION: connection,
}
)
- kwargs.update(self.policy_context)
+ flow_context.update(self.policy_context)
if not flow:
return bad_request_message(
self.request,
@@ -246,7 +246,7 @@ class SourceFlowManager:
)
# We run the Flow planner here so we can pass the Pending user in the context
planner = FlowPlanner(flow)
- plan = planner.plan(self.request, kwargs)
+ plan = planner.plan(self.request, flow_context)
for stage in self.get_stages_to_append(flow):
plan.append_stage(stage)
if stages:
diff --git a/authentik/sources/oauth/api/source.py b/authentik/sources/oauth/api/source.py
index 6c9399f95..7121c2ee9 100644
--- a/authentik/sources/oauth/api/source.py
+++ b/authentik/sources/oauth/api/source.py
@@ -105,6 +105,7 @@ class OAuthSourceSerializer(SourceSerializer):
"consumer_secret",
"callback_url",
"additional_scopes",
+ "groups_claim",
"type",
"oidc_well_known_url",
"oidc_jwks_url",
@@ -137,6 +138,7 @@ class OAuthSourceFilter(FilterSet):
"authorization_url",
"access_token_url",
"profile_url",
+ "groups_claim",
"consumer_key",
"additional_scopes",
]
diff --git a/authentik/sources/oauth/migrations/0008_oauthsource_groups_claim.py b/authentik/sources/oauth/migrations/0008_oauthsource_groups_claim.py
new file mode 100644
index 000000000..eb0c9b0e4
--- /dev/null
+++ b/authentik/sources/oauth/migrations/0008_oauthsource_groups_claim.py
@@ -0,0 +1,24 @@
+# Generated by Django 4.1.10 on 2023-07-30 14:48
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ (
+ "authentik_sources_oauth",
+ "0007_oauthsource_oidc_jwks_oauthsource_oidc_jwks_url_and_more",
+ ),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name="oauthsource",
+ name="groups_claim",
+ field=models.TextField(
+ default=None,
+ help_text="Sync groups and group membership from the source. Only use this option with sources that you control, as otherwise unwanted users might get added to groups with superuser permissions.",
+ null=True,
+ ),
+ ),
+ ]
diff --git a/authentik/sources/oauth/models.py b/authentik/sources/oauth/models.py
index 7d8dd2fb7..fbb01f400 100644
--- a/authentik/sources/oauth/models.py
+++ b/authentik/sources/oauth/models.py
@@ -54,6 +54,16 @@ class OAuthSource(Source):
oidc_jwks_url = models.TextField(default="", blank=True)
oidc_jwks = models.JSONField(default=dict, blank=True)
+ groups_claim = models.TextField(
+ default=None,
+ null=True,
+ help_text=_(
+ "Sync groups and group membership from the source. Only use this option with "
+ "sources that you control, as otherwise unwanted users might get added to "
+ "groups with superuser permissions."
+ ),
+ )
+
@property
def type(self) -> type["SourceType"]:
"""Return the provider instance for this source"""
diff --git a/authentik/sources/oauth/tests/test_type_openid.py b/authentik/sources/oauth/tests/test_type_openid.py
index e04bea4b0..e597cda55 100644
--- a/authentik/sources/oauth/tests/test_type_openid.py
+++ b/authentik/sources/oauth/tests/test_type_openid.py
@@ -14,6 +14,10 @@ OPENID_USER = {
"department": "Engineering",
"birthdate": "1975-12-31",
"nickname": "foo",
+ "groups": [
+ "foo",
+ "bar",
+ ]
}
@@ -28,6 +32,7 @@ class TestTypeOpenID(TestCase):
authorization_url="",
profile_url="http://localhost/userinfo",
consumer_key="",
+ groups_claim="groups",
)
self.factory = RequestFactory()
diff --git a/authentik/sources/oauth/types/oidc.py b/authentik/sources/oauth/types/oidc.py
index 7ebd24579..72a16bb84 100644
--- a/authentik/sources/oauth/types/oidc.py
+++ b/authentik/sources/oauth/types/oidc.py
@@ -35,6 +35,9 @@ class OpenIDConnectOAuth2Callback(OAuthCallback):
"name": info.get("name"),
}
+ def get_user_group_names(self, info: dict[str, Any]) -> list[str]:
+ return info.get(self.source.groups_claim, [])
+
@registry.register()
class OpenIDConnectType(SourceType):
diff --git a/authentik/sources/oauth/views/callback.py b/authentik/sources/oauth/views/callback.py
index 893a5003c..7c96bb849 100644
--- a/authentik/sources/oauth/views/callback.py
+++ b/authentik/sources/oauth/views/callback.py
@@ -10,12 +10,17 @@ from django.utils.translation import gettext as _
from django.views.generic import View
from structlog.stdlib import get_logger
+from authentik.core.models import Group, User
from authentik.core.sources.flow_manager import SourceFlowManager
from authentik.events.models import Event, EventAction
+from authentik.flows.models import Flow, Stage, in_memory_stage
+from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
+from authentik.flows.stage import StageView
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.views.base import OAuthClientMixin
LOGGER = get_logger()
+PLAN_CONTEXT_GROUPS = "goauthentik.io/sources/oauth/groups"
class OAuthCallback(OAuthClientMixin, View):
@@ -59,13 +64,17 @@ class OAuthCallback(OAuthClientMixin, View):
return self.handle_login_failure("Could not determine id.")
# Get or create access record
enroll_info = self.get_user_enroll_context(raw_info)
+ group_info = self.get_user_group_names(raw_info)
sfm = OAuthSourceFlowManager(
source=self.source,
request=self.request,
identifier=identifier,
enroll_info=enroll_info,
)
- sfm.policy_context = {"oauth_userinfo": raw_info}
+ sfm.policy_context = {
+ "oauth_userinfo": raw_info,
+ PLAN_CONTEXT_GROUPS: group_info,
+ }
return sfm.get_flow(
access_token=self.token.get("access_token"),
)
@@ -85,6 +94,10 @@ class OAuthCallback(OAuthClientMixin, View):
"""Create a dict of User data"""
raise NotImplementedError()
+ def get_user_group_names(self, info: dict[str, Any]) -> list[str]:
+ """Return a list of all groups the user is member of"""
+ return []
+
def get_user_id(self, info: dict[str, Any]) -> Optional[str]:
"""Return unique identifier from the profile info."""
if "id" in info:
@@ -111,6 +124,13 @@ class OAuthSourceFlowManager(SourceFlowManager):
connection_type = UserOAuthSourceConnection
+ def get_stages_to_append(self, flow: Flow) -> list[Stage]:
+ return super().get_stages_to_append(flow) + [
+ # Always run this stage after the default `PostUserEnrollmentStage` stage
+ # as it relies on the user object existing
+ in_memory_stage(OAuthUserUpdateStage),
+ ]
+
def update_connection(
self,
connection: UserOAuthSourceConnection,
@@ -119,3 +139,24 @@ class OAuthSourceFlowManager(SourceFlowManager):
"""Set the access_token on the connection"""
connection.access_token = access_token
return connection
+
+
+class OAuthUserUpdateStage(StageView):
+ """Dynamically injected stage which updates the user after enrollment/authentication."""
+
+ def handle_groups(self):
+ """Sync users' groups from oauth data"""
+ user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
+ group_names: list[str] = self.executor.plan.context[PLAN_CONTEXT_GROUPS]
+ for group_name in group_names:
+ Group.objects.update_or_create(name=group_name, defaults={})
+ user.ak_groups.set(Group.objects.filter(name__in=[group_names]))
+
+ def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
+ """Stage used after the user has been enrolled"""
+ self.handle_groups()
+ return self.executor.stage_ok()
+
+ def post(self, request: HttpRequest) -> HttpResponse:
+ """Wrapper for post requests"""
+ return self.get(request)
diff --git a/blueprints/schema.json b/blueprints/schema.json
index 7637beec4..b122b0455 100644
--- a/blueprints/schema.json
+++ b/blueprints/schema.json
@@ -5117,6 +5117,15 @@
"type": "string",
"title": "Additional Scopes"
},
+ "groups_claim": {
+ "type": [
+ "string",
+ "null"
+ ],
+ "minLength": 1,
+ "title": "Groups claim",
+ "description": "Sync groups and group membership from the source. Only use this option with sources that you control, as otherwise unwanted users might get added to groups with superuser permissions."
+ },
"oidc_well_known_url": {
"type": "string",
"title": "Oidc well known url"
@@ -8305,7 +8314,6 @@
"properties": {
"name": {
"type": "string",
- "maxLength": 80,
"minLength": 1,
"title": "Name"
},
diff --git a/schema.yml b/schema.yml
index dcd8b0814..c6958d3f2 100644
--- a/schema.yml
+++ b/schema.yml
@@ -17943,6 +17943,10 @@ paths:
schema:
type: string
format: uuid
+ - in: query
+ name: groups_claim
+ schema:
+ type: string
- in: query
name: has_jwks
schema:
@@ -30470,7 +30474,6 @@ components:
readOnly: true
name:
type: string
- maxLength: 80
is_superuser:
type: boolean
description: Users added to this group will be superusers.
@@ -30583,7 +30586,6 @@ components:
name:
type: string
minLength: 1
- maxLength: 80
is_superuser:
type: boolean
description: Users added to this group will be superusers.
@@ -32649,6 +32651,12 @@ components:
readOnly: true
additional_scopes:
type: string
+ groups_claim:
+ type: string
+ nullable: true
+ description: Sync groups and group membership from the source. Only use
+ this option with sources that you control, as otherwise unwanted users
+ might get added to groups with superuser permissions.
type:
allOf:
- $ref: '#/components/schemas/SourceType'
@@ -32752,6 +32760,13 @@ components:
minLength: 1
additional_scopes:
type: string
+ groups_claim:
+ type: string
+ nullable: true
+ minLength: 1
+ description: Sync groups and group membership from the source. Only use
+ this option with sources that you control, as otherwise unwanted users
+ might get added to groups with superuser permissions.
oidc_well_known_url:
type: string
oidc_jwks_url:
@@ -36979,7 +36994,6 @@ components:
name:
type: string
minLength: 1
- maxLength: 80
is_superuser:
type: boolean
description: Users added to this group will be superusers.
@@ -37560,6 +37574,13 @@ components:
minLength: 1
additional_scopes:
type: string
+ groups_claim:
+ type: string
+ nullable: true
+ minLength: 1
+ description: Sync groups and group membership from the source. Only use
+ this option with sources that you control, as otherwise unwanted users
+ might get added to groups with superuser permissions.
oidc_well_known_url:
type: string
oidc_jwks_url:
@@ -42029,7 +42050,6 @@ components:
readOnly: true
name:
type: string
- maxLength: 80
is_superuser:
type: boolean
description: Users added to this group will be superusers.
@@ -42055,7 +42075,6 @@ components:
name:
type: string
minLength: 1
- maxLength: 80
is_superuser:
type: boolean
description: Users added to this group will be superusers.
diff --git a/web/src/admin/sources/oauth/OAuthSourceForm.ts b/web/src/admin/sources/oauth/OAuthSourceForm.ts
index fdf4d22b0..cdc4864bd 100644
--- a/web/src/admin/sources/oauth/OAuthSourceForm.ts
+++ b/web/src/admin/sources/oauth/OAuthSourceForm.ts
@@ -70,6 +70,9 @@ export class OAuthSourceForm extends ModelForm {
async send(data: OAuthSource): Promise {
data.providerType = (this.providerType?.slug || "") as ProviderTypeEnum;
+ if (data.groupsClaim === "") {
+ data.groupsClaim = null;
+ }
let source: OAuthSource;
if (this.instance) {
source = await new SourcesApi(DEFAULT_CONFIG).sourcesOauthPartialUpdate({
@@ -185,6 +188,7 @@ export class OAuthSourceForm extends ModelForm {
: html``}
${this.providerType.slug === ProviderTypeEnum.Openidconnect
? html`
+
{
+
+
+
+ ${msg(
+ "Sync groups and group membership from the source. Only use this option with sources that you control, as otherwise unwanted users might get added to groups with superuser permissions.",
+ )}
+