sources/oauth: use GitHub's dedicated email API when no public email address is configured

closes #3472

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-08-26 21:21:39 +02:00
parent 2868331976
commit 83eaac375d
4 changed files with 80 additions and 17 deletions

View file

@ -12,8 +12,6 @@ from authentik.events.models import Event, EventAction
from authentik.lib.utils.http import get_http_session
from authentik.sources.oauth.models import OAuthSource
LOGGER = get_logger()
class BaseOAuthClient:
"""Base OAuth Client"""
@ -30,6 +28,7 @@ class BaseOAuthClient:
self.session = get_http_session()
self.request = request
self.callback = callback
self.logger = get_logger().bind(source=source.slug)
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
"""Fetch access token from callback request."""
@ -44,7 +43,7 @@ class BaseOAuthClient:
response = self.do_request("get", profile_url, token=token)
response.raise_for_status()
except RequestException as exc:
LOGGER.warning("Unable to fetch user profile", exc=exc, body=response.text)
self.logger.warning("Unable to fetch user profile", exc=exc, body=response.text)
return None
else:
return response.json()
@ -73,7 +72,7 @@ class BaseOAuthClient:
# to make additional scopes easier
args["scope"] = " ".join(sorted(set(args["scope"])))
params = urlencode(args, quote_via=quote, doseq=True)
LOGGER.info("redirect args", **args)
self.logger.info("redirect args", **args)
return urlunparse(parsed_url._replace(query=params))
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:

View file

@ -1,6 +1,10 @@
"""GitHub Type tests"""
from django.test import TestCase
from copy import copy
from django.test import RequestFactory, TestCase
from requests_mock import Mocker
from authentik.lib.generators import generate_id
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.github import GitHubOAuth2Callback
@ -55,11 +59,9 @@ class TestTypeGitHub(TestCase):
self.source = OAuthSource.objects.create(
name="test",
slug="test",
provider_type="openidconnect",
authorization_url="",
profile_url="",
consumer_key="",
provider_type="github",
)
self.factory = RequestFactory()
def test_enroll_context(self):
"""Test GitHub Enrollment context"""
@ -67,3 +69,30 @@ class TestTypeGitHub(TestCase):
self.assertEqual(ak_context["username"], GITHUB_USER["login"])
self.assertEqual(ak_context["email"], GITHUB_USER["email"])
self.assertEqual(ak_context["name"], GITHUB_USER["name"])
def test_enroll_context_email(self):
"""Test GitHub Enrollment context"""
email = generate_id()
user = copy(GITHUB_USER)
del user["email"]
with Mocker() as mocker:
mocker.get(
"https://api.github.com/user/emails",
json=[
{
"primary": True,
"email": email,
}
],
)
ak_context = GitHubOAuth2Callback(
source=self.source,
request=self.factory.get("/"),
token={
"access_token": generate_id(),
"token_type": generate_id(),
},
).get_user_enroll_context(user)
self.assertEqual(ak_context["username"], GITHUB_USER["login"])
self.assertEqual(ak_context["email"], email)
self.assertEqual(ak_context["name"], GITHUB_USER["name"])

View file

@ -1,6 +1,9 @@
"""GitHub OAuth Views"""
from typing import Any
from typing import Any, Optional
from requests.exceptions import RequestException
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -15,16 +18,47 @@ class GitHubOAuthRedirect(OAuthRedirect):
}
class GitHubOAuth2Client(OAuth2Client):
"""GitHub OAuth2 Client"""
def get_github_emails(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
"""Get Emails from the GitHub API"""
profile_url = self.source.type.profile_url or ""
if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url
profile_url += "/emails"
try:
response = self.do_request("get", profile_url, token=token)
response.raise_for_status()
except RequestException as exc:
self.logger.warning("Unable to fetch github emails", exc=exc)
return []
else:
return response.json()
class GitHubOAuth2Callback(OAuthCallback):
"""GitHub OAuth2 Callback"""
client_class = GitHubOAuth2Client
def get_user_enroll_context(
self,
info: dict[str, Any],
) -> dict[str, Any]:
chosen_email = info.get("email")
if not chosen_email:
# The GitHub Userprofile API only returns an email address if the profile
# has a public email address set (despite us asking for user:email, this behaviour
# doesn't change.). So we fetch all the user's email addresses
client: GitHubOAuth2Client = self.get_client(self.source)
emails = client.get_github_emails(self.token)
for email in emails:
if email.get("primary", False):
chosen_email = email.get("email", None)
return {
"username": info.get("login"),
"email": info.get("email"),
"email": chosen_email,
"name": info.get("name"),
}

View file

@ -22,6 +22,7 @@ class OAuthCallback(OAuthClientMixin, View):
"Base OAuth callback view."
source: OAuthSource
token: Optional[dict] = None
# pylint: disable=too-many-return-statements
def dispatch(self, request: HttpRequest, *_, **kwargs) -> HttpResponse:
@ -36,14 +37,14 @@ class OAuthCallback(OAuthClientMixin, View):
raise Http404(f"Source {slug} is not enabled.")
client = self.get_client(self.source, callback=self.get_callback_url(self.source))
# Fetch access token
token = client.get_access_token()
if token is None:
self.token = client.get_access_token()
if self.token is None:
return self.handle_login_failure("Could not retrieve token.")
if "error" in token:
return self.handle_login_failure(token["error"])
if "error" in self.token:
return self.handle_login_failure(self.token["error"])
# Fetch profile info
try:
raw_info = client.get_profile_info(token)
raw_info = client.get_profile_info(self.token)
if raw_info is None:
return self.handle_login_failure("Could not retrieve profile.")
except JSONDecodeError as exc:
@ -66,7 +67,7 @@ class OAuthCallback(OAuthClientMixin, View):
)
sfm.policy_context = {"oauth_userinfo": raw_info}
return sfm.get_flow(
access_token=token.get("access_token"),
access_token=self.token.get("access_token"),
)
# pylint: disable=unused-argument