providers/oauth2: save full IDToken to database, only use to_dict for encoding final token

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-10-08 15:04:43 +03:00
parent b2a658d091
commit 9bbe8e6c57
5 changed files with 26 additions and 25 deletions

View file

@ -439,7 +439,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
@id_token.setter @id_token.setter
def id_token(self, value: IDToken): def id_token(self, value: IDToken):
self._id_token = json.dumps(value.to_dict()) self._id_token = json.dumps(asdict(value))
def __str__(self): def __str__(self):
return f"Refresh Token for {self.provider} for user {self.user}" return f"Refresh Token for {self.provider} for user {self.user}"

View file

@ -1,6 +1,7 @@
"""Test introspect view""" """Test introspect view"""
import json import json
from base64 import b64encode from base64 import b64encode
from dataclasses import asdict
from django.urls import reverse from django.urls import reverse
@ -36,7 +37,9 @@ class TesOAuth2Introspection(OAuthTestCase):
refresh_token=generate_id(), refresh_token=generate_id(),
_scope="openid user profile", _scope="openid user profile",
_id_token=json.dumps( _id_token=json.dumps(
IDToken("foo", "bar").to_dict(), asdict(
IDToken("foo", "bar"),
)
), ),
) )
self.auth = b64encode( self.auth = b64encode(

View file

@ -1,6 +1,7 @@
"""Test revoke view""" """Test revoke view"""
import json import json
from base64 import b64encode from base64 import b64encode
from dataclasses import asdict
from django.urls import reverse from django.urls import reverse
@ -35,7 +36,11 @@ class TesOAuth2Revoke(OAuthTestCase):
access_token=generate_id(), access_token=generate_id(),
refresh_token=generate_id(), refresh_token=generate_id(),
_scope="openid user profile", _scope="openid user profile",
_id_token=json.dumps(IDToken("foo", "bar").to_dict()), _id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
) )
self.auth = b64encode( self.auth = b64encode(
f"{self.provider.client_id}:{self.provider.client_secret}".encode() f"{self.provider.client_id}:{self.provider.client_secret}".encode()

View file

@ -1,5 +1,6 @@
"""Test userinfo view""" """Test userinfo view"""
import json import json
from dataclasses import asdict
from django.urls import reverse from django.urls import reverse
@ -38,7 +39,11 @@ class TestUserinfo(OAuthTestCase):
access_token=generate_id(), access_token=generate_id(),
refresh_token=generate_id(), refresh_token=generate_id(),
_scope="openid user profile", _scope="openid user profile",
_id_token=json.dumps(IDToken("foo", "bar").to_dict()), _id_token=json.dumps(
asdict(
IDToken("foo", "bar"),
)
),
) )
def test_userinfo_normal(self): def test_userinfo_normal(self):

View file

@ -421,8 +421,7 @@ class TokenView(View):
def create_code_response(self) -> dict[str, Any]: def create_code_response(self) -> dict[str, Any]:
"""See https://tools.ietf.org/html/rfc6749#section-4.1""" """See https://tools.ietf.org/html/rfc6749#section-4.1"""
refresh_token = self.provider.create_refresh_token(
refresh_token = self.params.authorization_code.provider.create_refresh_token(
user=self.params.authorization_code.user, user=self.params.authorization_code.user,
scope=self.params.authorization_code.scope, scope=self.params.authorization_code.scope,
request=self.request, request=self.request,
@ -447,22 +446,17 @@ class TokenView(View):
"access_token": refresh_token.access_token, "access_token": refresh_token.access_token,
"refresh_token": refresh_token.refresh_token, "refresh_token": refresh_token.refresh_token,
"token_type": "bearer", "token_type": "bearer",
"expires_in": int( "expires_in": int(timedelta_from_string(self.provider.token_validity).total_seconds()),
timedelta_from_string(self.params.provider.token_validity).total_seconds() "id_token": self.provider.encode(refresh_token.id_token.to_dict()),
),
"id_token": refresh_token.provider.encode(refresh_token.id_token.to_dict()),
} }
def create_refresh_response(self) -> dict[str, Any]: def create_refresh_response(self) -> dict[str, Any]:
"""See https://tools.ietf.org/html/rfc6749#section-6""" """See https://tools.ietf.org/html/rfc6749#section-6"""
unauthorized_scopes = set(self.params.scope) - set(self.params.refresh_token.scope) unauthorized_scopes = set(self.params.scope) - set(self.params.refresh_token.scope)
if unauthorized_scopes: if unauthorized_scopes:
raise TokenError("invalid_scope") raise TokenError("invalid_scope")
provider: OAuth2Provider = self.params.refresh_token.provider refresh_token: RefreshToken = self.provider.create_refresh_token(
refresh_token: RefreshToken = provider.create_refresh_token(
user=self.params.refresh_token.user, user=self.params.refresh_token.user,
scope=self.params.scope, scope=self.params.scope,
request=self.request, request=self.request,
@ -487,17 +481,13 @@ class TokenView(View):
"access_token": refresh_token.access_token, "access_token": refresh_token.access_token,
"refresh_token": refresh_token.refresh_token, "refresh_token": refresh_token.refresh_token,
"token_type": "bearer", "token_type": "bearer",
"expires_in": int( "expires_in": int(timedelta_from_string(self.provider.token_validity).total_seconds()),
timedelta_from_string(refresh_token.provider.token_validity).total_seconds() "id_token": self.provider.encode(refresh_token.id_token.to_dict()),
),
"id_token": self.params.provider.encode(refresh_token.id_token.to_dict()),
} }
def create_client_credentials_response(self) -> dict[str, Any]: def create_client_credentials_response(self) -> dict[str, Any]:
"""See https://datatracker.ietf.org/doc/html/rfc6749#section-4.4""" """See https://datatracker.ietf.org/doc/html/rfc6749#section-4.4"""
provider: OAuth2Provider = self.params.provider refresh_token: RefreshToken = self.provider.create_refresh_token(
refresh_token: RefreshToken = provider.create_refresh_token(
user=self.params.user, user=self.params.user,
scope=self.params.scope, scope=self.params.scope,
request=self.request, request=self.request,
@ -514,8 +504,6 @@ class TokenView(View):
return { return {
"access_token": refresh_token.access_token, "access_token": refresh_token.access_token,
"token_type": "bearer", "token_type": "bearer",
"expires_in": int( "expires_in": int(timedelta_from_string(self.provider.token_validity).total_seconds()),
timedelta_from_string(refresh_token.provider.token_validity).total_seconds() "id_token": self.provider.encode(refresh_token.id_token.to_dict()),
),
"id_token": self.params.provider.encode(refresh_token.id_token.to_dict()),
} }