diff --git a/authentik/sources/oauth/clients/base.py b/authentik/sources/oauth/clients/base.py index fdece64ac..df1e07054 100644 --- a/authentik/sources/oauth/clients/base.py +++ b/authentik/sources/oauth/clients/base.py @@ -41,11 +41,7 @@ class BaseOAuthClient: if self.source.type.urls_customizable and self.source.profile_url: profile_url = self.source.profile_url try: - response = self.do_request( - "get", - profile_url, - headers={"Authorization": f"{token['token_type']} {token['access_token']}"}, - ) + response = self.do_request("get", profile_url, token=token) response.raise_for_status() except RequestException as exc: LOGGER.warning("Unable to fetch user profile", exc=exc) diff --git a/authentik/sources/oauth/types/azure_ad.py b/authentik/sources/oauth/types/azure_ad.py index f5a7587b6..cb0252454 100644 --- a/authentik/sources/oauth/types/azure_ad.py +++ b/authentik/sources/oauth/types/azure_ad.py @@ -1,8 +1,10 @@ """AzureAD OAuth2 Views""" -from typing import Any +from typing import Any, Optional +from requests.exceptions import RequestException from structlog.stdlib import get_logger +from authentik.sources.oauth.clients.oauth2 import OAuth2Client from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.redirect import OAuthRedirect @@ -19,9 +21,33 @@ class AzureADOAuthRedirect(OAuthRedirect): } +class AzureADClient(OAuth2Client): + """Azure AD Oauth client, azure ad doesn't like the ?access_token that is sent by default""" + + def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]: + "Fetch user profile information." + profile_url = self.source.type.profile_url or "" + if self.source.type.urls_customizable and self.source.profile_url: + profile_url = self.source.profile_url + try: + response = self.session.request( + "get", + profile_url, + headers={"Authorization": f"{token['token_type']} {token['access_token']}"}, + ) + response.raise_for_status() + except RequestException as exc: + LOGGER.warning("Unable to fetch user profile", exc=exc) + return None + else: + return response.json() + + class AzureADOAuthCallback(OAuthCallback): """AzureAD OAuth2 Callback""" + client_class = AzureADClient + def get_user_enroll_context( self, info: dict[str, Any],