Revert "sources/oauth: fix access_token being sent as query param and not authorization header"

This reverts commit 248f993541.
This commit is contained in:
Jens Langhammer 2021-09-14 11:59:32 +02:00
parent 248f993541
commit 942170f902
2 changed files with 28 additions and 6 deletions

View file

@ -41,11 +41,7 @@ class BaseOAuthClient:
if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url
try:
response = self.do_request(
"get",
profile_url,
headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
)
response = self.do_request("get", profile_url, token=token)
response.raise_for_status()
except RequestException as exc:
LOGGER.warning("Unable to fetch user profile", exc=exc)

View file

@ -1,8 +1,10 @@
"""AzureAD OAuth2 Views"""
from typing import Any
from typing import Any, Optional
from requests.exceptions import RequestException
from structlog.stdlib import get_logger
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -19,9 +21,33 @@ class AzureADOAuthRedirect(OAuthRedirect):
}
class AzureADClient(OAuth2Client):
"""Azure AD Oauth client, azure ad doesn't like the ?access_token that is sent by default"""
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
"Fetch user profile information."
profile_url = self.source.type.profile_url or ""
if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url
try:
response = self.session.request(
"get",
profile_url,
headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
)
response.raise_for_status()
except RequestException as exc:
LOGGER.warning("Unable to fetch user profile", exc=exc)
return None
else:
return response.json()
class AzureADOAuthCallback(OAuthCallback):
"""AzureAD OAuth2 Callback"""
client_class = AzureADClient
def get_user_enroll_context(
self,
info: dict[str, Any],