From 9bbe8e6c57e0a7ef0eb2ff1c1cfecf12824a21ee Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Sat, 8 Oct 2022 15:04:43 +0300 Subject: [PATCH] providers/oauth2: save full IDToken to database, only use to_dict for encoding final token Signed-off-by: Jens Langhammer --- authentik/providers/oauth2/models.py | 2 +- .../providers/oauth2/tests/test_introspect.py | 5 +++- .../providers/oauth2/tests/test_revoke.py | 7 ++++- .../providers/oauth2/tests/test_userinfo.py | 7 ++++- authentik/providers/oauth2/views/token.py | 30 ++++++------------- 5 files changed, 26 insertions(+), 25 deletions(-) diff --git a/authentik/providers/oauth2/models.py b/authentik/providers/oauth2/models.py index c4f685534..9f885c808 100644 --- a/authentik/providers/oauth2/models.py +++ b/authentik/providers/oauth2/models.py @@ -439,7 +439,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel): @id_token.setter def id_token(self, value: IDToken): - self._id_token = json.dumps(value.to_dict()) + self._id_token = json.dumps(asdict(value)) def __str__(self): return f"Refresh Token for {self.provider} for user {self.user}" diff --git a/authentik/providers/oauth2/tests/test_introspect.py b/authentik/providers/oauth2/tests/test_introspect.py index 63e77b7db..40dea7bc5 100644 --- a/authentik/providers/oauth2/tests/test_introspect.py +++ b/authentik/providers/oauth2/tests/test_introspect.py @@ -1,6 +1,7 @@ """Test introspect view""" import json from base64 import b64encode +from dataclasses import asdict from django.urls import reverse @@ -36,7 +37,9 @@ class TesOAuth2Introspection(OAuthTestCase): refresh_token=generate_id(), _scope="openid user profile", _id_token=json.dumps( - IDToken("foo", "bar").to_dict(), + asdict( + IDToken("foo", "bar"), + ) ), ) self.auth = b64encode( diff --git a/authentik/providers/oauth2/tests/test_revoke.py b/authentik/providers/oauth2/tests/test_revoke.py index 046f1c74d..0e474d8dc 100644 --- a/authentik/providers/oauth2/tests/test_revoke.py +++ b/authentik/providers/oauth2/tests/test_revoke.py @@ -1,6 +1,7 @@ """Test revoke view""" import json from base64 import b64encode +from dataclasses import asdict from django.urls import reverse @@ -35,7 +36,11 @@ class TesOAuth2Revoke(OAuthTestCase): access_token=generate_id(), refresh_token=generate_id(), _scope="openid user profile", - _id_token=json.dumps(IDToken("foo", "bar").to_dict()), + _id_token=json.dumps( + asdict( + IDToken("foo", "bar"), + ) + ), ) self.auth = b64encode( f"{self.provider.client_id}:{self.provider.client_secret}".encode() diff --git a/authentik/providers/oauth2/tests/test_userinfo.py b/authentik/providers/oauth2/tests/test_userinfo.py index e6f420f6a..007db5c40 100644 --- a/authentik/providers/oauth2/tests/test_userinfo.py +++ b/authentik/providers/oauth2/tests/test_userinfo.py @@ -1,5 +1,6 @@ """Test userinfo view""" import json +from dataclasses import asdict from django.urls import reverse @@ -38,7 +39,11 @@ class TestUserinfo(OAuthTestCase): access_token=generate_id(), refresh_token=generate_id(), _scope="openid user profile", - _id_token=json.dumps(IDToken("foo", "bar").to_dict()), + _id_token=json.dumps( + asdict( + IDToken("foo", "bar"), + ) + ), ) def test_userinfo_normal(self): diff --git a/authentik/providers/oauth2/views/token.py b/authentik/providers/oauth2/views/token.py index 11a2aba42..b04131455 100644 --- a/authentik/providers/oauth2/views/token.py +++ b/authentik/providers/oauth2/views/token.py @@ -421,8 +421,7 @@ class TokenView(View): def create_code_response(self) -> dict[str, Any]: """See https://tools.ietf.org/html/rfc6749#section-4.1""" - - refresh_token = self.params.authorization_code.provider.create_refresh_token( + refresh_token = self.provider.create_refresh_token( user=self.params.authorization_code.user, scope=self.params.authorization_code.scope, request=self.request, @@ -447,22 +446,17 @@ class TokenView(View): "access_token": refresh_token.access_token, "refresh_token": refresh_token.refresh_token, "token_type": "bearer", - "expires_in": int( - timedelta_from_string(self.params.provider.token_validity).total_seconds() - ), - "id_token": refresh_token.provider.encode(refresh_token.id_token.to_dict()), + "expires_in": int(timedelta_from_string(self.provider.token_validity).total_seconds()), + "id_token": self.provider.encode(refresh_token.id_token.to_dict()), } def create_refresh_response(self) -> dict[str, Any]: """See https://tools.ietf.org/html/rfc6749#section-6""" - unauthorized_scopes = set(self.params.scope) - set(self.params.refresh_token.scope) if unauthorized_scopes: raise TokenError("invalid_scope") - provider: OAuth2Provider = self.params.refresh_token.provider - - refresh_token: RefreshToken = provider.create_refresh_token( + refresh_token: RefreshToken = self.provider.create_refresh_token( user=self.params.refresh_token.user, scope=self.params.scope, request=self.request, @@ -487,17 +481,13 @@ class TokenView(View): "access_token": refresh_token.access_token, "refresh_token": refresh_token.refresh_token, "token_type": "bearer", - "expires_in": int( - timedelta_from_string(refresh_token.provider.token_validity).total_seconds() - ), - "id_token": self.params.provider.encode(refresh_token.id_token.to_dict()), + "expires_in": int(timedelta_from_string(self.provider.token_validity).total_seconds()), + "id_token": self.provider.encode(refresh_token.id_token.to_dict()), } def create_client_credentials_response(self) -> dict[str, Any]: """See https://datatracker.ietf.org/doc/html/rfc6749#section-4.4""" - provider: OAuth2Provider = self.params.provider - - refresh_token: RefreshToken = provider.create_refresh_token( + refresh_token: RefreshToken = self.provider.create_refresh_token( user=self.params.user, scope=self.params.scope, request=self.request, @@ -514,8 +504,6 @@ class TokenView(View): return { "access_token": refresh_token.access_token, "token_type": "bearer", - "expires_in": int( - timedelta_from_string(refresh_token.provider.token_validity).total_seconds() - ), - "id_token": self.params.provider.encode(refresh_token.id_token.to_dict()), + "expires_in": int(timedelta_from_string(self.provider.token_validity).total_seconds()), + "id_token": self.provider.encode(refresh_token.id_token.to_dict()), }