diff --git a/authentik/sources/oauth/tests/test_type_openid.py b/authentik/sources/oauth/tests/test_type_openid.py index 2f7395071..e04bea4b0 100644 --- a/authentik/sources/oauth/tests/test_type_openid.py +++ b/authentik/sources/oauth/tests/test_type_openid.py @@ -1,6 +1,8 @@ """OpenID Type tests""" -from django.test import TestCase +from django.test import RequestFactory, TestCase +from requests_mock import Mocker +from authentik.lib.generators import generate_id from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback @@ -24,9 +26,10 @@ class TestTypeOpenID(TestCase): slug="test", provider_type="openidconnect", authorization_url="", - profile_url="", + profile_url="http://localhost/userinfo", consumer_key="", ) + self.factory = RequestFactory() def test_enroll_context(self): """Test OpenID Enrollment context""" @@ -34,3 +37,19 @@ class TestTypeOpenID(TestCase): self.assertEqual(ak_context["username"], OPENID_USER["nickname"]) self.assertEqual(ak_context["email"], OPENID_USER["email"]) self.assertEqual(ak_context["name"], OPENID_USER["name"]) + + @Mocker() + def test_userinfo(self, mock: Mocker): + """Test userinfo API call""" + mock.get("http://localhost/userinfo", json=OPENID_USER) + token = generate_id() + OpenIDConnectOAuth2Callback(request=self.factory.get("/")).get_client( + self.source + ).get_profile_info( + { + "token_type": "foo", + "access_token": token, + } + ) + self.assertEqual(mock.last_request.query, "") + self.assertEqual(mock.last_request.headers["Authorization"], f"foo {token}") diff --git a/authentik/sources/oauth/types/oidc.py b/authentik/sources/oauth/types/oidc.py index 189209c43..7ebd24579 100644 --- a/authentik/sources/oauth/types/oidc.py +++ b/authentik/sources/oauth/types/oidc.py @@ -20,7 +20,7 @@ class OpenIDConnectOAuthRedirect(OAuthRedirect): class OpenIDConnectOAuth2Callback(OAuthCallback): """OpenIDConnect OAuth2 Callback""" - client_class: UserprofileHeaderAuthClient + client_class = UserprofileHeaderAuthClient def get_user_id(self, info: dict[str, str]) -> str: return info.get("sub", "")