sources/oauth: only send header authentication for OIDC source

closes #3327

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-07-29 18:20:17 +02:00
parent b41acebf5b
commit 1dcec17a58
5 changed files with 35 additions and 31 deletions

View file

@ -128,3 +128,25 @@ class OAuth2Client(BaseOAuthClient):
@property @property
def session_key(self): def session_key(self):
return f"oauth-client-{self.source.name}-request-state" return f"oauth-client-{self.source.name}-request-state"
class UserprofileHeaderAuthClient(OAuth2Client):
"""OAuth client which only sends authentication via header, not querystring"""
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
"Fetch user profile information."
profile_url = self.source.type.profile_url or ""
if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url
try:
response = self.session.request(
"get",
profile_url,
headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
)
response.raise_for_status()
except RequestException as exc:
LOGGER.warning("Unable to fetch user profile", exc=exc, body=response.text)
return None
else:
return response.json()

View file

@ -1,10 +1,9 @@
"""AzureAD OAuth2 Views""" """AzureAD OAuth2 Views"""
from typing import Any, Optional from typing import Any
from requests.exceptions import RequestException
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.sources.oauth.clients.oauth2 import OAuth2Client from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -21,32 +20,10 @@ class AzureADOAuthRedirect(OAuthRedirect):
} }
class AzureADClient(OAuth2Client):
"""Azure AD Oauth client, azure ad doesn't like the ?access_token that is sent by default"""
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
"Fetch user profile information."
profile_url = self.source.type.profile_url or ""
if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url
try:
response = self.session.request(
"get",
profile_url,
headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
)
response.raise_for_status()
except RequestException as exc:
LOGGER.warning("Unable to fetch user profile", exc=exc, body=response.text)
return None
else:
return response.json()
class AzureADOAuthCallback(OAuthCallback): class AzureADOAuthCallback(OAuthCallback):
"""AzureAD OAuth2 Callback""" """AzureAD OAuth2 Callback"""
client_class = AzureADClient client_class = UserprofileHeaderAuthClient
def get_user_enroll_context( def get_user_enroll_context(
self, self,

View file

@ -1,6 +1,7 @@
"""OpenID Connect OAuth Views""" """OpenID Connect OAuth Views"""
from typing import Any from typing import Any
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
@ -19,6 +20,8 @@ class OpenIDConnectOAuthRedirect(OAuthRedirect):
class OpenIDConnectOAuth2Callback(OAuthCallback): class OpenIDConnectOAuth2Callback(OAuthCallback):
"""OpenIDConnect OAuth2 Callback""" """OpenIDConnect OAuth2 Callback"""
client_class: UserprofileHeaderAuthClient
def get_user_id(self, info: dict[str, str]) -> str: def get_user_id(self, info: dict[str, str]) -> str:
return info.get("sub", "") return info.get("sub", "")

View file

@ -1,8 +1,8 @@
"""Okta OAuth Views""" """Okta OAuth Views"""
from typing import Any from typing import Any
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.azure_ad import AzureADClient
from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -23,7 +23,7 @@ class OktaOAuth2Callback(OAuthCallback):
# Okta has the same quirk as azure and throws an error if the access token # Okta has the same quirk as azure and throws an error if the access token
# is set via query parameter, so we re-use the azure client # is set via query parameter, so we re-use the azure client
# see https://github.com/goauthentik/authentik/issues/1910 # see https://github.com/goauthentik/authentik/issues/1910
client_class = AzureADClient client_class = UserprofileHeaderAuthClient
def get_user_id(self, info: dict[str, str]) -> str: def get_user_id(self, info: dict[str, str]) -> str:
return info.get("sub", "") return info.get("sub", "")

View file

@ -2,14 +2,16 @@
from typing import Any, Optional from typing import Any, Optional
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.sources.oauth.clients.oauth2 import SESSION_KEY_OAUTH_PKCE from authentik.sources.oauth.clients.oauth2 import (
from authentik.sources.oauth.types.azure_ad import AzureADClient SESSION_KEY_OAUTH_PKCE,
UserprofileHeaderAuthClient,
)
from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
class TwitterClient(AzureADClient): class TwitterClient(UserprofileHeaderAuthClient):
"""Twitter has similar quirks to Azure AD, and additionally requires Basic auth on """Twitter has similar quirks to Azure AD, and additionally requires Basic auth on
the access token endpoint for some reason.""" the access token endpoint for some reason."""