providers/oauth2: launch url: if URL parsing fails, return no launch URL (#5918)
* providers/oauth2: launch url: if URL parsing fails, return no launch URL Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> * add test Signed-off-by: Jens Langhammer <jens@goauthentik.io> * only get provider launch URL when no url is set Signed-off-by: Jens Langhammer <jens@goauthentik.io> * only catch value error Signed-off-by: Jens Langhammer <jens@goauthentik.io> * format Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space> Signed-off-by: Jens Langhammer <jens@goauthentik.io> Co-authored-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
587385587c
commit
0041cf88f4
|
@ -376,10 +376,10 @@ class Application(SerializerModel, PolicyBindingModel):
|
||||||
def get_launch_url(self, user: Optional["User"] = None) -> Optional[str]:
|
def get_launch_url(self, user: Optional["User"] = None) -> Optional[str]:
|
||||||
"""Get launch URL if set, otherwise attempt to get launch URL based on provider."""
|
"""Get launch URL if set, otherwise attempt to get launch URL based on provider."""
|
||||||
url = None
|
url = None
|
||||||
if provider := self.get_provider():
|
|
||||||
url = provider.launch_url
|
|
||||||
if self.meta_launch_url:
|
if self.meta_launch_url:
|
||||||
url = self.meta_launch_url
|
url = self.meta_launch_url
|
||||||
|
elif provider := self.get_provider():
|
||||||
|
url = provider.launch_url
|
||||||
if user and url:
|
if user and url:
|
||||||
if isinstance(user, SimpleLazyObject):
|
if isinstance(user, SimpleLazyObject):
|
||||||
user._setup()
|
user._setup()
|
||||||
|
|
|
@ -17,6 +17,7 @@ from django.urls import reverse
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from jwt import encode
|
from jwt import encode
|
||||||
from rest_framework.serializers import Serializer
|
from rest_framework.serializers import Serializer
|
||||||
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
|
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
|
||||||
from authentik.crypto.models import CertificateKeyPair
|
from authentik.crypto.models import CertificateKeyPair
|
||||||
|
@ -26,6 +27,8 @@ from authentik.lib.utils.time import timedelta_string_validator
|
||||||
from authentik.providers.oauth2.id_token import IDToken, SubModes
|
from authentik.providers.oauth2.id_token import IDToken, SubModes
|
||||||
from authentik.sources.oauth.models import OAuthSource
|
from authentik.sources.oauth.models import OAuthSource
|
||||||
|
|
||||||
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
|
||||||
def generate_client_secret() -> str:
|
def generate_client_secret() -> str:
|
||||||
"""Generate client secret with adequate length"""
|
"""Generate client secret with adequate length"""
|
||||||
|
@ -251,8 +254,12 @@ class OAuth2Provider(Provider):
|
||||||
if self.redirect_uris == "":
|
if self.redirect_uris == "":
|
||||||
return None
|
return None
|
||||||
main_url = self.redirect_uris.split("\n", maxsplit=1)[0]
|
main_url = self.redirect_uris.split("\n", maxsplit=1)[0]
|
||||||
launch_url = urlparse(main_url)._replace(path="")
|
try:
|
||||||
return urlunparse(launch_url)
|
launch_url = urlparse(main_url)._replace(path="")
|
||||||
|
return urlunparse(launch_url)
|
||||||
|
except ValueError as exc:
|
||||||
|
LOGGER.warning("Failed to format launch url", exc=exc)
|
||||||
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def component(self) -> str:
|
def component(self) -> str:
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
"""Test OAuth2 API"""
|
"""Test OAuth2 API"""
|
||||||
from json import loads
|
from json import loads
|
||||||
|
from sys import version_info
|
||||||
|
from unittest import skipUnless
|
||||||
|
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from rest_framework.test import APITestCase
|
from rest_framework.test import APITestCase
|
||||||
|
@ -42,3 +44,14 @@ class TestAPI(APITestCase):
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
body = loads(response.content.decode())
|
body = loads(response.content.decode())
|
||||||
self.assertEqual(body["issuer"], "http://testserver/application/o/test/")
|
self.assertEqual(body["issuer"], "http://testserver/application/o/test/")
|
||||||
|
|
||||||
|
# https://github.com/goauthentik/authentik/pull/5918
|
||||||
|
@skipUnless(version_info >= (3, 11, 4), "This behaviour is only Python 3.11.4 and up")
|
||||||
|
def test_launch_url(self):
|
||||||
|
"""Test launch_url"""
|
||||||
|
self.provider.redirect_uris = (
|
||||||
|
"https://[\\d\\w]+.pr.test.goauthentik.io/source/oauth/callback/authentik/\n"
|
||||||
|
)
|
||||||
|
self.provider.save()
|
||||||
|
self.provider.refresh_from_db()
|
||||||
|
self.assertIsNone(self.provider.launch_url)
|
||||||
|
|
Reference in New Issue