From 83eaac375dfe0b46f2e7fcb9b3a508a5ea547d65 Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Fri, 26 Aug 2022 21:21:39 +0200 Subject: [PATCH] sources/oauth: use GitHub's dedicated email API when no public email address is configured closes #3472 Signed-off-by: Jens Langhammer --- authentik/sources/oauth/clients/base.py | 7 ++-- .../sources/oauth/tests/test_type_github.py | 39 ++++++++++++++++--- authentik/sources/oauth/types/github.py | 38 +++++++++++++++++- authentik/sources/oauth/views/callback.py | 13 ++++--- 4 files changed, 80 insertions(+), 17 deletions(-) diff --git a/authentik/sources/oauth/clients/base.py b/authentik/sources/oauth/clients/base.py index 4ae0023d4..a53a515cb 100644 --- a/authentik/sources/oauth/clients/base.py +++ b/authentik/sources/oauth/clients/base.py @@ -12,8 +12,6 @@ from authentik.events.models import Event, EventAction from authentik.lib.utils.http import get_http_session from authentik.sources.oauth.models import OAuthSource -LOGGER = get_logger() - class BaseOAuthClient: """Base OAuth Client""" @@ -30,6 +28,7 @@ class BaseOAuthClient: self.session = get_http_session() self.request = request self.callback = callback + self.logger = get_logger().bind(source=source.slug) def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]: """Fetch access token from callback request.""" @@ -44,7 +43,7 @@ class BaseOAuthClient: response = self.do_request("get", profile_url, token=token) response.raise_for_status() except RequestException as exc: - LOGGER.warning("Unable to fetch user profile", exc=exc, body=response.text) + self.logger.warning("Unable to fetch user profile", exc=exc, body=response.text) return None else: return response.json() @@ -73,7 +72,7 @@ class BaseOAuthClient: # to make additional scopes easier args["scope"] = " ".join(sorted(set(args["scope"]))) params = urlencode(args, quote_via=quote, doseq=True) - LOGGER.info("redirect args", **args) + self.logger.info("redirect args", **args) return urlunparse(parsed_url._replace(query=params)) def parse_raw_token(self, raw_token: str) -> dict[str, Any]: diff --git a/authentik/sources/oauth/tests/test_type_github.py b/authentik/sources/oauth/tests/test_type_github.py index 2ff26c846..ad62a6f50 100644 --- a/authentik/sources/oauth/tests/test_type_github.py +++ b/authentik/sources/oauth/tests/test_type_github.py @@ -1,6 +1,10 @@ """GitHub Type tests""" -from django.test import TestCase +from copy import copy +from django.test import RequestFactory, TestCase +from requests_mock import Mocker + +from authentik.lib.generators import generate_id from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.types.github import GitHubOAuth2Callback @@ -55,11 +59,9 @@ class TestTypeGitHub(TestCase): self.source = OAuthSource.objects.create( name="test", slug="test", - provider_type="openidconnect", - authorization_url="", - profile_url="", - consumer_key="", + provider_type="github", ) + self.factory = RequestFactory() def test_enroll_context(self): """Test GitHub Enrollment context""" @@ -67,3 +69,30 @@ class TestTypeGitHub(TestCase): self.assertEqual(ak_context["username"], GITHUB_USER["login"]) self.assertEqual(ak_context["email"], GITHUB_USER["email"]) self.assertEqual(ak_context["name"], GITHUB_USER["name"]) + + def test_enroll_context_email(self): + """Test GitHub Enrollment context""" + email = generate_id() + user = copy(GITHUB_USER) + del user["email"] + with Mocker() as mocker: + mocker.get( + "https://api.github.com/user/emails", + json=[ + { + "primary": True, + "email": email, + } + ], + ) + ak_context = GitHubOAuth2Callback( + source=self.source, + request=self.factory.get("/"), + token={ + "access_token": generate_id(), + "token_type": generate_id(), + }, + ).get_user_enroll_context(user) + self.assertEqual(ak_context["username"], GITHUB_USER["login"]) + self.assertEqual(ak_context["email"], email) + self.assertEqual(ak_context["name"], GITHUB_USER["name"]) diff --git a/authentik/sources/oauth/types/github.py b/authentik/sources/oauth/types/github.py index a1b4bcff1..f800edb4e 100644 --- a/authentik/sources/oauth/types/github.py +++ b/authentik/sources/oauth/types/github.py @@ -1,6 +1,9 @@ """GitHub OAuth Views""" -from typing import Any +from typing import Any, Optional +from requests.exceptions import RequestException + +from authentik.sources.oauth.clients.oauth2 import OAuth2Client from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.redirect import OAuthRedirect @@ -15,16 +18,47 @@ class GitHubOAuthRedirect(OAuthRedirect): } +class GitHubOAuth2Client(OAuth2Client): + """GitHub OAuth2 Client""" + + def get_github_emails(self, token: dict[str, str]) -> Optional[dict[str, Any]]: + """Get Emails from the GitHub API""" + profile_url = self.source.type.profile_url or "" + if self.source.type.urls_customizable and self.source.profile_url: + profile_url = self.source.profile_url + profile_url += "/emails" + try: + response = self.do_request("get", profile_url, token=token) + response.raise_for_status() + except RequestException as exc: + self.logger.warning("Unable to fetch github emails", exc=exc) + return [] + else: + return response.json() + + class GitHubOAuth2Callback(OAuthCallback): """GitHub OAuth2 Callback""" + client_class = GitHubOAuth2Client + def get_user_enroll_context( self, info: dict[str, Any], ) -> dict[str, Any]: + chosen_email = info.get("email") + if not chosen_email: + # The GitHub Userprofile API only returns an email address if the profile + # has a public email address set (despite us asking for user:email, this behaviour + # doesn't change.). So we fetch all the user's email addresses + client: GitHubOAuth2Client = self.get_client(self.source) + emails = client.get_github_emails(self.token) + for email in emails: + if email.get("primary", False): + chosen_email = email.get("email", None) return { "username": info.get("login"), - "email": info.get("email"), + "email": chosen_email, "name": info.get("name"), } diff --git a/authentik/sources/oauth/views/callback.py b/authentik/sources/oauth/views/callback.py index b4ddf2948..0dc0390d7 100644 --- a/authentik/sources/oauth/views/callback.py +++ b/authentik/sources/oauth/views/callback.py @@ -22,6 +22,7 @@ class OAuthCallback(OAuthClientMixin, View): "Base OAuth callback view." source: OAuthSource + token: Optional[dict] = None # pylint: disable=too-many-return-statements def dispatch(self, request: HttpRequest, *_, **kwargs) -> HttpResponse: @@ -36,14 +37,14 @@ class OAuthCallback(OAuthClientMixin, View): raise Http404(f"Source {slug} is not enabled.") client = self.get_client(self.source, callback=self.get_callback_url(self.source)) # Fetch access token - token = client.get_access_token() - if token is None: + self.token = client.get_access_token() + if self.token is None: return self.handle_login_failure("Could not retrieve token.") - if "error" in token: - return self.handle_login_failure(token["error"]) + if "error" in self.token: + return self.handle_login_failure(self.token["error"]) # Fetch profile info try: - raw_info = client.get_profile_info(token) + raw_info = client.get_profile_info(self.token) if raw_info is None: return self.handle_login_failure("Could not retrieve profile.") except JSONDecodeError as exc: @@ -66,7 +67,7 @@ class OAuthCallback(OAuthClientMixin, View): ) sfm.policy_context = {"oauth_userinfo": raw_info} return sfm.get_flow( - access_token=token.get("access_token"), + access_token=self.token.get("access_token"), ) # pylint: disable=unused-argument