providers/oauth2: if a redirect_uri cannot be parsed as regex, compare strict (#3070)

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens L 2022-06-10 23:32:57 +02:00 committed by Jens Langhammer
parent d0eb6af7e9
commit caed306346
7 changed files with 108 additions and 49 deletions

View File

@ -47,11 +47,11 @@ def create_test_tenant() -> Tenant:
def create_test_cert() -> CertificateKeyPair:
"""Generate a certificate for testing"""
CertificateKeyPair.objects.filter(name="goauthentik.io").delete()
builder = CertificateBuilder()
builder.common_name = "goauthentik.io"
builder.build(
subject_alt_names=["goauthentik.io"],
validity_days=360,
)
builder.name = generate_id()
return builder.save()

View File

@ -53,10 +53,7 @@ class CertificateBuilder:
.subject_name(
x509.Name(
[
x509.NameAttribute(
NameOID.COMMON_NAME,
self.common_name,
),
x509.NameAttribute(NameOID.COMMON_NAME, self.common_name),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "authentik"),
x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed"),
]
@ -65,10 +62,7 @@ class CertificateBuilder:
.issuer_name(
x509.Name(
[
x509.NameAttribute(
NameOID.COMMON_NAME,
f"authentik {__version__}",
),
x509.NameAttribute(NameOID.COMMON_NAME, f"authentik {__version__}"),
]
)
)

View File

@ -3,7 +3,7 @@ from django.test import RequestFactory
from django.urls import reverse
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.flows.challenge import ChallengeTypes
from authentik.lib.generators import generate_id, generate_key
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
@ -39,7 +39,7 @@ class TestAuthorize(OAuthTestCase):
def test_request(self):
"""test request param"""
OAuth2Provider.objects.create(
name="test",
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid/Foo",
@ -59,7 +59,7 @@ class TestAuthorize(OAuthTestCase):
def test_invalid_redirect_uri(self):
"""test missing/invalid redirect URI"""
OAuth2Provider.objects.create(
name="test",
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid",
@ -78,10 +78,55 @@ class TestAuthorize(OAuthTestCase):
)
OAuthAuthorizationParams.from_request(request)
def test_invalid_redirect_uri_empty(self):
"""test missing/invalid redirect URI"""
provider = OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris="",
)
with self.assertRaises(RedirectUriError):
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
request = self.factory.get(
"/",
data={
"response_type": "code",
"client_id": "test",
"redirect_uri": "+",
},
)
OAuthAuthorizationParams.from_request(request)
provider.refresh_from_db()
self.assertEqual(provider.redirect_uris, "+")
def test_invalid_redirect_uri_regex(self):
"""test missing/invalid redirect URI"""
OAuth2Provider.objects.create(
name="test",
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid?",
)
with self.assertRaises(RedirectUriError):
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
with self.assertRaises(RedirectUriError):
request = self.factory.get(
"/",
data={
"response_type": "code",
"client_id": "test",
"redirect_uri": "http://localhost",
},
)
OAuthAuthorizationParams.from_request(request)
def test_redirect_uri_invalid_regex(self):
"""test missing/invalid redirect URI (invalid regex)"""
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris="+",
@ -103,7 +148,7 @@ class TestAuthorize(OAuthTestCase):
def test_empty_redirect_uri(self):
"""test empty redirect URI (configure in provider)"""
OAuth2Provider.objects.create(
name="test",
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
)
@ -123,7 +168,7 @@ class TestAuthorize(OAuthTestCase):
def test_response_type(self):
"""test response_type"""
OAuth2Provider.objects.create(
name="test",
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid/Foo",
@ -201,7 +246,7 @@ class TestAuthorize(OAuthTestCase):
"""Test full authorization"""
flow = create_test_flow()
provider = OAuth2Provider.objects.create(
name="test",
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris="foo://localhost",
@ -237,12 +282,12 @@ class TestAuthorize(OAuthTestCase):
"""Test full authorization"""
flow = create_test_flow()
provider = OAuth2Provider.objects.create(
name="test",
name=generate_id(),
client_id="test",
client_secret=generate_key(),
authorization_flow=flow,
redirect_uris="http://localhost",
signing_key=create_test_cert(),
signing_key=self.keypair,
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
@ -281,12 +326,12 @@ class TestAuthorize(OAuthTestCase):
"""Test full authorization (form_post response)"""
flow = create_test_flow()
provider = OAuth2Provider.objects.create(
name="test",
name=generate_id(),
client_id="test",
client_secret=generate_key(),
authorization_flow=flow,
redirect_uris="http://localhost",
signing_key=create_test_cert(),
signing_key=self.keypair,
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()

View File

@ -5,7 +5,7 @@ from django.test import RequestFactory
from django.urls import reverse
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id, generate_key
from authentik.providers.oauth2.constants import (
@ -24,17 +24,17 @@ class TestToken(OAuthTestCase):
def setUp(self) -> None:
super().setUp()
self.factory = RequestFactory()
self.app = Application.objects.create(name="test", slug="test")
self.app = Application.objects.create(name=generate_id(), slug="test")
def test_request_auth_code(self):
"""test request param"""
provider = OAuth2Provider.objects.create(
name="test",
name=generate_id(),
client_id=generate_id(),
client_secret=generate_key(),
authorization_flow=create_test_flow(),
redirect_uris="http://testserver",
signing_key=create_test_cert(),
signing_key=self.keypair,
)
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
user = create_test_admin_user()
@ -56,12 +56,12 @@ class TestToken(OAuthTestCase):
def test_request_auth_code_invalid(self):
"""test request param"""
provider = OAuth2Provider.objects.create(
name="test",
name=generate_id(),
client_id=generate_id(),
client_secret=generate_key(),
authorization_flow=create_test_flow(),
redirect_uris="http://testserver",
signing_key=create_test_cert(),
signing_key=self.keypair,
)
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
request = self.factory.post(
@ -79,12 +79,12 @@ class TestToken(OAuthTestCase):
def test_request_refresh_token(self):
"""test request param"""
provider = OAuth2Provider.objects.create(
name="test",
name=generate_id(),
client_id=generate_id(),
client_secret=generate_key(),
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid",
signing_key=create_test_cert(),
signing_key=self.keypair,
)
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
user = create_test_admin_user()
@ -108,12 +108,12 @@ class TestToken(OAuthTestCase):
def test_auth_code_view(self):
"""test request param"""
provider = OAuth2Provider.objects.create(
name="test",
name=generate_id(),
client_id=generate_id(),
client_secret=generate_key(),
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid",
signing_key=create_test_cert(),
signing_key=self.keypair,
)
# Needs to be assigned to an application for iss to be set
self.app.provider = provider
@ -150,12 +150,12 @@ class TestToken(OAuthTestCase):
def test_refresh_token_view(self):
"""test request param"""
provider = OAuth2Provider.objects.create(
name="test",
name=generate_id(),
client_id=generate_id(),
client_secret=generate_key(),
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid",
signing_key=create_test_cert(),
signing_key=self.keypair,
)
# Needs to be assigned to an application for iss to be set
self.app.provider = provider
@ -199,12 +199,12 @@ class TestToken(OAuthTestCase):
def test_refresh_token_view_invalid_origin(self):
"""test request param"""
provider = OAuth2Provider.objects.create(
name="test",
name=generate_id(),
client_id=generate_id(),
client_secret=generate_key(),
authorization_flow=create_test_flow(),
redirect_uris="http://local.invalid",
signing_key=create_test_cert(),
signing_key=self.keypair,
)
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
user = create_test_admin_user()
@ -244,12 +244,12 @@ class TestToken(OAuthTestCase):
def test_refresh_token_revoke(self):
"""test request param"""
provider = OAuth2Provider.objects.create(
name="test",
name=generate_id(),
client_id=generate_id(),
client_secret=generate_key(),
authorization_flow=create_test_flow(),
redirect_uris="http://testserver",
signing_key=create_test_cert(),
signing_key=self.keypair,
)
# Needs to be assigned to an application for iss to be set
self.app.provider = provider

View File

@ -2,12 +2,15 @@
from django.test import TestCase
from jwt import decode
from authentik.core.tests.utils import create_test_cert
from authentik.crypto.models import CertificateKeyPair
from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider, RefreshToken
class OAuthTestCase(TestCase):
"""OAuth test helpers"""
keypair: CertificateKeyPair
required_jwt_keys = [
"exp",
"iat",
@ -17,6 +20,11 @@ class OAuthTestCase(TestCase):
"iss",
]
@classmethod
def setUpClass(cls) -> None:
cls.keypair = create_test_cert()
super().setUpClass()
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider):
"""Validate that all required fields are set"""
key, alg = provider.get_jwt_key()

View File

@ -2,7 +2,7 @@
from dataclasses import dataclass, field
from datetime import timedelta
from re import error as RegexError
from re import escape, fullmatch
from re import fullmatch
from typing import Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlsplit, urlunsplit
from uuid import uuid4
@ -181,7 +181,7 @@ class OAuthAuthorizationParams:
if self.provider.redirect_uris == "":
LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
self.provider.redirect_uris = escape(self.redirect_uri)
self.provider.redirect_uris = self.redirect_uri
self.provider.save()
allowed_redirect_urls = self.provider.redirect_uris.split()
@ -194,13 +194,19 @@ class OAuthAuthorizationParams:
try:
if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls):
LOGGER.warning(
"Invalid redirect uri",
"Invalid redirect uri (regex comparison)",
redirect_uri=self.redirect_uri,
expected=allowed_redirect_urls,
)
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
except RegexError as exc:
LOGGER.warning("Invalid regular expression configured", exc=exc)
LOGGER.info("Failed to parse regular expression, checking directly", exc=exc)
if not any(x == self.redirect_uri for x in allowed_redirect_urls):
LOGGER.warning(
"Invalid redirect uri (strict comparison)",
redirect_uri=self.redirect_uri,
expected=allowed_redirect_urls,
)
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
if self.request:
raise AuthorizeError(

View File

@ -154,7 +154,7 @@ class TokenParams:
try:
if not any(fullmatch(x, self.redirect_uri) for x in allowed_redirect_urls):
LOGGER.warning(
"Invalid redirect uri",
"Invalid redirect uri (regex comparison)",
redirect_uri=self.redirect_uri,
expected=allowed_redirect_urls,
)
@ -167,10 +167,16 @@ class TokenParams:
).from_http(request)
raise TokenError("invalid_client")
except RegexError as exc:
LOGGER.warning("Invalid regular expression configured", exc=exc)
LOGGER.info("Failed to parse regular expression, checking directly", exc=exc)
if not any(x == self.redirect_uri for x in allowed_redirect_urls):
LOGGER.warning(
"Invalid redirect uri (strict comparison)",
redirect_uri=self.redirect_uri,
expected=allowed_redirect_urls,
)
Event.new(
EventAction.CONFIGURATION_ERROR,
message="Invalid redirect_uri RegEx configured",
message="Invalid redirect_uri configured",
provider=self.provider,
).from_http(request)
raise TokenError("invalid_client")