providers/oauth2: launch url: if URL parsing fails, return no launch URL (#5918)

* providers/oauth2: launch url: if URL parsing fails, return no launch URL

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* add test

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* only get provider launch URL when no url is set

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* only catch value error

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* format

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
risson 2023-06-09 21:56:34 +02:00 committed by GitHub
parent 587385587c
commit 0041cf88f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 4 deletions

View File

@ -376,10 +376,10 @@ class Application(SerializerModel, PolicyBindingModel):
def get_launch_url(self, user: Optional["User"] = None) -> Optional[str]: def get_launch_url(self, user: Optional["User"] = None) -> Optional[str]:
"""Get launch URL if set, otherwise attempt to get launch URL based on provider.""" """Get launch URL if set, otherwise attempt to get launch URL based on provider."""
url = None url = None
if provider := self.get_provider():
url = provider.launch_url
if self.meta_launch_url: if self.meta_launch_url:
url = self.meta_launch_url url = self.meta_launch_url
elif provider := self.get_provider():
url = provider.launch_url
if user and url: if user and url:
if isinstance(user, SimpleLazyObject): if isinstance(user, SimpleLazyObject):
user._setup() user._setup()

View File

@ -17,6 +17,7 @@ from django.urls import reverse
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from jwt import encode from jwt import encode
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
from structlog.stdlib import get_logger
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
@ -26,6 +27,8 @@ from authentik.lib.utils.time import timedelta_string_validator
from authentik.providers.oauth2.id_token import IDToken, SubModes from authentik.providers.oauth2.id_token import IDToken, SubModes
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import OAuthSource
LOGGER = get_logger()
def generate_client_secret() -> str: def generate_client_secret() -> str:
"""Generate client secret with adequate length""" """Generate client secret with adequate length"""
@ -251,8 +254,12 @@ class OAuth2Provider(Provider):
if self.redirect_uris == "": if self.redirect_uris == "":
return None return None
main_url = self.redirect_uris.split("\n", maxsplit=1)[0] main_url = self.redirect_uris.split("\n", maxsplit=1)[0]
launch_url = urlparse(main_url)._replace(path="") try:
return urlunparse(launch_url) launch_url = urlparse(main_url)._replace(path="")
return urlunparse(launch_url)
except ValueError as exc:
LOGGER.warning("Failed to format launch url", exc=exc)
return None
@property @property
def component(self) -> str: def component(self) -> str:

View File

@ -1,5 +1,7 @@
"""Test OAuth2 API""" """Test OAuth2 API"""
from json import loads from json import loads
from sys import version_info
from unittest import skipUnless
from django.urls import reverse from django.urls import reverse
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
@ -42,3 +44,14 @@ class TestAPI(APITestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
body = loads(response.content.decode()) body = loads(response.content.decode())
self.assertEqual(body["issuer"], "http://testserver/application/o/test/") self.assertEqual(body["issuer"], "http://testserver/application/o/test/")
# https://github.com/goauthentik/authentik/pull/5918
@skipUnless(version_info >= (3, 11, 4), "This behaviour is only Python 3.11.4 and up")
def test_launch_url(self):
"""Test launch_url"""
self.provider.redirect_uris = (
"https://[\\d\\w]+.pr.test.goauthentik.io/source/oauth/callback/authentik/\n"
)
self.provider.save()
self.provider.refresh_from_db()
self.assertIsNone(self.provider.launch_url)