sources/oauth: fix missing get_user_id for OIDC-like sources (Azure AD) (#7970)

* lib: add debug requests session that shows all sent requests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* sources/oauth: fix missing get_user_id for OIDC-like OAuth Sources

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-12-22 00:10:47 +01:00 committed by GitHub
parent 48e5823ad6
commit 06df705240
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 30 additions and 16 deletions

View file

@ -1,5 +1,8 @@
"""http helpers"""
from requests.sessions import Session
from uuid import uuid4
from django.conf import settings
from requests.sessions import PreparedRequest, Session
from structlog.stdlib import get_logger
from authentik import get_full_version
@ -12,8 +15,25 @@ def authentik_user_agent() -> str:
return f"authentik@{get_full_version()}"
class DebugSession(Session):
"""requests session which logs http requests and responses"""
def send(self, req: PreparedRequest, *args, **kwargs):
request_id = str(uuid4())
LOGGER.debug("HTTP request sent", uid=request_id, path=req.path_url, headers=req.headers)
resp = super().send(req, *args, **kwargs)
LOGGER.debug(
"HTTP response received",
uid=request_id,
status=resp.status_code,
body=resp.text,
headers=resp.headers,
)
return resp
def get_http_session() -> Session:
"""Get a requests session with common headers"""
session = Session()
session = DebugSession() if settings.DEBUG else Session()
session.headers["User-Agent"] = authentik_user_agent()
return session

View file

@ -4,8 +4,8 @@ from typing import Any
from structlog.stdlib import get_logger
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
LOGGER = get_logger()
@ -20,7 +20,7 @@ class AzureADOAuthRedirect(OAuthRedirect):
}
class AzureADOAuthCallback(OAuthCallback):
class AzureADOAuthCallback(OpenIDConnectOAuth2Callback):
"""AzureAD OAuth2 Callback"""
client_class = UserprofileHeaderAuthClient
@ -50,7 +50,7 @@ class AzureADType(SourceType):
authorization_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
access_token_url = "https://login.microsoftonline.com/common/oauth2/v2.0/token" # nosec
profile_url = "https://graph.microsoft.com/v1.0/me"
profile_url = "https://login.microsoftonline.com/common/openid/userinfo"
oidc_well_known_url = (
"https://login.microsoftonline.com/common/.well-known/openid-configuration"
)

View file

@ -23,7 +23,7 @@ class OpenIDConnectOAuth2Callback(OAuthCallback):
client_class = UserprofileHeaderAuthClient
def get_user_id(self, info: dict[str, str]) -> str:
return info.get("sub", "")
return info.get("sub", None)
def get_user_enroll_context(
self,

View file

@ -3,8 +3,8 @@ from typing import Any
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -17,7 +17,7 @@ class OktaOAuthRedirect(OAuthRedirect):
}
class OktaOAuth2Callback(OAuthCallback):
class OktaOAuth2Callback(OpenIDConnectOAuth2Callback):
"""Okta OAuth2 Callback"""
# Okta has the same quirk as azure and throws an error if the access token
@ -25,9 +25,6 @@ class OktaOAuth2Callback(OAuthCallback):
# see https://github.com/goauthentik/authentik/issues/1910
client_class = UserprofileHeaderAuthClient
def get_user_id(self, info: dict[str, str]) -> str:
return info.get("sub", "")
def get_user_enroll_context(
self,
info: dict[str, Any],

View file

@ -3,8 +3,8 @@ from json import dumps
from typing import Any, Optional
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -27,14 +27,11 @@ class TwitchOAuthRedirect(OAuthRedirect):
}
class TwitchOAuth2Callback(OAuthCallback):
class TwitchOAuth2Callback(OpenIDConnectOAuth2Callback):
"""Twitch OAuth2 Callback"""
client_class = TwitchClient
def get_user_id(self, info: dict[str, str]) -> str:
return info.get("sub", "")
def get_user_enroll_context(
self,
info: dict[str, Any],