sources/oauth: use GitHub's dedicated email API when no public email address is configured

closes #3472

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-08-26 21:21:39 +02:00
parent 2868331976
commit 83eaac375d
4 changed files with 80 additions and 17 deletions

View file

@ -12,8 +12,6 @@ from authentik.events.models import Event, EventAction
from authentik.lib.utils.http import get_http_session from authentik.lib.utils.http import get_http_session
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import OAuthSource
LOGGER = get_logger()
class BaseOAuthClient: class BaseOAuthClient:
"""Base OAuth Client""" """Base OAuth Client"""
@ -30,6 +28,7 @@ class BaseOAuthClient:
self.session = get_http_session() self.session = get_http_session()
self.request = request self.request = request
self.callback = callback self.callback = callback
self.logger = get_logger().bind(source=source.slug)
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]: def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
"""Fetch access token from callback request.""" """Fetch access token from callback request."""
@ -44,7 +43,7 @@ class BaseOAuthClient:
response = self.do_request("get", profile_url, token=token) response = self.do_request("get", profile_url, token=token)
response.raise_for_status() response.raise_for_status()
except RequestException as exc: except RequestException as exc:
LOGGER.warning("Unable to fetch user profile", exc=exc, body=response.text) self.logger.warning("Unable to fetch user profile", exc=exc, body=response.text)
return None return None
else: else:
return response.json() return response.json()
@ -73,7 +72,7 @@ class BaseOAuthClient:
# to make additional scopes easier # to make additional scopes easier
args["scope"] = " ".join(sorted(set(args["scope"]))) args["scope"] = " ".join(sorted(set(args["scope"])))
params = urlencode(args, quote_via=quote, doseq=True) params = urlencode(args, quote_via=quote, doseq=True)
LOGGER.info("redirect args", **args) self.logger.info("redirect args", **args)
return urlunparse(parsed_url._replace(query=params)) return urlunparse(parsed_url._replace(query=params))
def parse_raw_token(self, raw_token: str) -> dict[str, Any]: def parse_raw_token(self, raw_token: str) -> dict[str, Any]:

View file

@ -1,6 +1,10 @@
"""GitHub Type tests""" """GitHub Type tests"""
from django.test import TestCase from copy import copy
from django.test import RequestFactory, TestCase
from requests_mock import Mocker
from authentik.lib.generators import generate_id
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.github import GitHubOAuth2Callback from authentik.sources.oauth.types.github import GitHubOAuth2Callback
@ -55,11 +59,9 @@ class TestTypeGitHub(TestCase):
self.source = OAuthSource.objects.create( self.source = OAuthSource.objects.create(
name="test", name="test",
slug="test", slug="test",
provider_type="openidconnect", provider_type="github",
authorization_url="",
profile_url="",
consumer_key="",
) )
self.factory = RequestFactory()
def test_enroll_context(self): def test_enroll_context(self):
"""Test GitHub Enrollment context""" """Test GitHub Enrollment context"""
@ -67,3 +69,30 @@ class TestTypeGitHub(TestCase):
self.assertEqual(ak_context["username"], GITHUB_USER["login"]) self.assertEqual(ak_context["username"], GITHUB_USER["login"])
self.assertEqual(ak_context["email"], GITHUB_USER["email"]) self.assertEqual(ak_context["email"], GITHUB_USER["email"])
self.assertEqual(ak_context["name"], GITHUB_USER["name"]) self.assertEqual(ak_context["name"], GITHUB_USER["name"])
def test_enroll_context_email(self):
"""Test GitHub Enrollment context"""
email = generate_id()
user = copy(GITHUB_USER)
del user["email"]
with Mocker() as mocker:
mocker.get(
"https://api.github.com/user/emails",
json=[
{
"primary": True,
"email": email,
}
],
)
ak_context = GitHubOAuth2Callback(
source=self.source,
request=self.factory.get("/"),
token={
"access_token": generate_id(),
"token_type": generate_id(),
},
).get_user_enroll_context(user)
self.assertEqual(ak_context["username"], GITHUB_USER["login"])
self.assertEqual(ak_context["email"], email)
self.assertEqual(ak_context["name"], GITHUB_USER["name"])

View file

@ -1,6 +1,9 @@
"""GitHub OAuth Views""" """GitHub OAuth Views"""
from typing import Any from typing import Any, Optional
from requests.exceptions import RequestException
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -15,16 +18,47 @@ class GitHubOAuthRedirect(OAuthRedirect):
} }
class GitHubOAuth2Client(OAuth2Client):
"""GitHub OAuth2 Client"""
def get_github_emails(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
"""Get Emails from the GitHub API"""
profile_url = self.source.type.profile_url or ""
if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url
profile_url += "/emails"
try:
response = self.do_request("get", profile_url, token=token)
response.raise_for_status()
except RequestException as exc:
self.logger.warning("Unable to fetch github emails", exc=exc)
return []
else:
return response.json()
class GitHubOAuth2Callback(OAuthCallback): class GitHubOAuth2Callback(OAuthCallback):
"""GitHub OAuth2 Callback""" """GitHub OAuth2 Callback"""
client_class = GitHubOAuth2Client
def get_user_enroll_context( def get_user_enroll_context(
self, self,
info: dict[str, Any], info: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
chosen_email = info.get("email")
if not chosen_email:
# The GitHub Userprofile API only returns an email address if the profile
# has a public email address set (despite us asking for user:email, this behaviour
# doesn't change.). So we fetch all the user's email addresses
client: GitHubOAuth2Client = self.get_client(self.source)
emails = client.get_github_emails(self.token)
for email in emails:
if email.get("primary", False):
chosen_email = email.get("email", None)
return { return {
"username": info.get("login"), "username": info.get("login"),
"email": info.get("email"), "email": chosen_email,
"name": info.get("name"), "name": info.get("name"),
} }

View file

@ -22,6 +22,7 @@ class OAuthCallback(OAuthClientMixin, View):
"Base OAuth callback view." "Base OAuth callback view."
source: OAuthSource source: OAuthSource
token: Optional[dict] = None
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
def dispatch(self, request: HttpRequest, *_, **kwargs) -> HttpResponse: def dispatch(self, request: HttpRequest, *_, **kwargs) -> HttpResponse:
@ -36,14 +37,14 @@ class OAuthCallback(OAuthClientMixin, View):
raise Http404(f"Source {slug} is not enabled.") raise Http404(f"Source {slug} is not enabled.")
client = self.get_client(self.source, callback=self.get_callback_url(self.source)) client = self.get_client(self.source, callback=self.get_callback_url(self.source))
# Fetch access token # Fetch access token
token = client.get_access_token() self.token = client.get_access_token()
if token is None: if self.token is None:
return self.handle_login_failure("Could not retrieve token.") return self.handle_login_failure("Could not retrieve token.")
if "error" in token: if "error" in self.token:
return self.handle_login_failure(token["error"]) return self.handle_login_failure(self.token["error"])
# Fetch profile info # Fetch profile info
try: try:
raw_info = client.get_profile_info(token) raw_info = client.get_profile_info(self.token)
if raw_info is None: if raw_info is None:
return self.handle_login_failure("Could not retrieve profile.") return self.handle_login_failure("Could not retrieve profile.")
except JSONDecodeError as exc: except JSONDecodeError as exc:
@ -66,7 +67,7 @@ class OAuthCallback(OAuthClientMixin, View):
) )
sfm.policy_context = {"oauth_userinfo": raw_info} sfm.policy_context = {"oauth_userinfo": raw_info}
return sfm.get_flow( return sfm.get_flow(
access_token=token.get("access_token"), access_token=self.token.get("access_token"),
) )
# pylint: disable=unused-argument # pylint: disable=unused-argument