Merge branch 'master' into outpost-ldap
This commit is contained in:
commit
4d858c64e0
|
@ -1,5 +1,5 @@
|
||||||
"""API Authentication"""
|
"""API Authentication"""
|
||||||
from base64 import b64decode, b64encode
|
from base64 import b64decode
|
||||||
from binascii import Error
|
from binascii import Error
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
@ -19,14 +19,6 @@ def token_from_header(raw_header: bytes) -> Optional[Token]:
|
||||||
auth_credentials = raw_header.decode()
|
auth_credentials = raw_header.decode()
|
||||||
if auth_credentials == "":
|
if auth_credentials == "":
|
||||||
return None
|
return None
|
||||||
# Legacy, accept basic auth thats fully encoded (2021.3 outposts)
|
|
||||||
if " " not in auth_credentials:
|
|
||||||
try:
|
|
||||||
plain = b64decode(auth_credentials.encode()).decode()
|
|
||||||
auth_type, body = plain.split()
|
|
||||||
auth_credentials = f"{auth_type} {b64encode(body.encode()).decode()}"
|
|
||||||
except (UnicodeDecodeError, Error):
|
|
||||||
raise AuthenticationFailed("Malformed header")
|
|
||||||
auth_type, auth_credentials = auth_credentials.split()
|
auth_type, auth_credentials = auth_credentials.split()
|
||||||
if auth_type.lower() not in ["basic", "bearer"]:
|
if auth_type.lower() not in ["basic", "bearer"]:
|
||||||
LOGGER.debug("Unsupported authentication type, denying", type=auth_type.lower())
|
LOGGER.debug("Unsupported authentication type, denying", type=auth_type.lower())
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
"""Test config API"""
|
||||||
|
from json import loads
|
||||||
|
|
||||||
|
from django.urls import reverse
|
||||||
|
from rest_framework.test import APITestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfig(APITestCase):
|
||||||
|
"""Test config API"""
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
"""Test YAML generation"""
|
||||||
|
response = self.client.get(
|
||||||
|
reverse("authentik_api:configs-list"),
|
||||||
|
)
|
||||||
|
self.assertTrue(loads(response.content.decode()))
|
|
@ -0,0 +1,33 @@
|
||||||
|
"""test decorators api"""
|
||||||
|
from django.urls import reverse
|
||||||
|
from guardian.shortcuts import assign_perm
|
||||||
|
from rest_framework.test import APITestCase
|
||||||
|
|
||||||
|
from authentik.core.models import Application, User
|
||||||
|
|
||||||
|
|
||||||
|
class TestAPIDecorators(APITestCase):
|
||||||
|
"""test decorators api"""
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
super().setUp()
|
||||||
|
self.user = User.objects.create(username="test-user")
|
||||||
|
|
||||||
|
def test_obj_perm_denied(self):
|
||||||
|
"""Test object perm denied"""
|
||||||
|
self.client.force_login(self.user)
|
||||||
|
app = Application.objects.create(name="denied", slug="denied")
|
||||||
|
response = self.client.get(
|
||||||
|
reverse("authentik_api:application-metrics", kwargs={"slug": app.slug})
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 403)
|
||||||
|
|
||||||
|
def test_other_perm_denied(self):
|
||||||
|
"""Test other perm denied"""
|
||||||
|
self.client.force_login(self.user)
|
||||||
|
app = Application.objects.create(name="denied", slug="denied")
|
||||||
|
assign_perm("authentik_core.view_application", self.user, app)
|
||||||
|
response = self.client.get(
|
||||||
|
reverse("authentik_api:application-metrics", kwargs={"slug": app.slug})
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 403)
|
|
@ -20,10 +20,12 @@ def is_dict(value: Any):
|
||||||
class PassiveSerializer(Serializer):
|
class PassiveSerializer(Serializer):
|
||||||
"""Base serializer class which doesn't implement create/update methods"""
|
"""Base serializer class which doesn't implement create/update methods"""
|
||||||
|
|
||||||
def create(self, validated_data: dict) -> Model:
|
def create(self, validated_data: dict) -> Model: # pragma: no cover
|
||||||
return Model()
|
return Model()
|
||||||
|
|
||||||
def update(self, instance: Model, validated_data: dict) -> Model:
|
def update(
|
||||||
|
self, instance: Model, validated_data: dict
|
||||||
|
) -> Model: # pragma: no cover
|
||||||
return Model()
|
return Model()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ class CertificateBuilder:
|
||||||
def save(self) -> Optional[CertificateKeyPair]:
|
def save(self) -> Optional[CertificateKeyPair]:
|
||||||
"""Save generated certificate as model"""
|
"""Save generated certificate as model"""
|
||||||
if not self.__certificate:
|
if not self.__certificate:
|
||||||
return None
|
raise ValueError("Certificated hasn't been built yet")
|
||||||
return CertificateKeyPair.objects.create(
|
return CertificateKeyPair.objects.create(
|
||||||
name=self.common_name,
|
name=self.common_name,
|
||||||
certificate_data=self.certificate,
|
certificate_data=self.certificate,
|
||||||
|
|
|
@ -37,6 +37,8 @@ class TestCrypto(TestCase):
|
||||||
"""Test Builder"""
|
"""Test Builder"""
|
||||||
builder = CertificateBuilder()
|
builder = CertificateBuilder()
|
||||||
builder.common_name = "test-cert"
|
builder.common_name = "test-cert"
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
builder.save()
|
||||||
builder.build(
|
builder.build(
|
||||||
subject_alt_names=[],
|
subject_alt_names=[],
|
||||||
validity_days=3,
|
validity_days=3,
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Notification API Views"""
|
"""Notification API Views"""
|
||||||
|
from guardian.utils import get_anonymous_user
|
||||||
from rest_framework import mixins
|
from rest_framework import mixins
|
||||||
from rest_framework.fields import ReadOnlyField
|
from rest_framework.fields import ReadOnlyField
|
||||||
from rest_framework.serializers import ModelSerializer
|
from rest_framework.serializers import ModelSerializer
|
||||||
|
@ -48,6 +49,5 @@ class NotificationViewSet(
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
if not self.request:
|
user = self.request.user if self.request else get_anonymous_user()
|
||||||
return super().get_queryset()
|
return Notification.objects.filter(user=user)
|
||||||
return Notification.objects.filter(user=self.request.user)
|
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
"""base model tests"""
|
||||||
|
from typing import Callable, Type
|
||||||
|
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
from authentik.flows.models import Stage
|
||||||
|
from authentik.flows.stage import StageView
|
||||||
|
from authentik.lib.utils.reflection import all_subclasses
|
||||||
|
|
||||||
|
|
||||||
|
class TestModels(TestCase):
|
||||||
|
"""Generic model properties tests"""
|
||||||
|
|
||||||
|
|
||||||
|
def model_tester_factory(test_model: Type[Stage]) -> Callable:
|
||||||
|
"""Test a form"""
|
||||||
|
|
||||||
|
def tester(self: TestModels):
|
||||||
|
try:
|
||||||
|
model_class = None
|
||||||
|
if test_model._meta.abstract:
|
||||||
|
model_class = test_model.__bases__[0]()
|
||||||
|
else:
|
||||||
|
model_class = test_model()
|
||||||
|
self.assertTrue(issubclass(model_class.type, StageView))
|
||||||
|
self.assertIsNotNone(test_model.component)
|
||||||
|
_ = test_model.ui_user_settings
|
||||||
|
except NotImplementedError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return tester
|
||||||
|
|
||||||
|
|
||||||
|
for model in all_subclasses(Stage):
|
||||||
|
setattr(TestModels, f"test_model_{model.__name__}", model_tester_factory(model))
|
|
@ -160,7 +160,7 @@ class FlowImporter:
|
||||||
try:
|
try:
|
||||||
model: SerializerModel = apps.get_model(model_app_label, model_name)
|
model: SerializerModel = apps.get_model(model_app_label, model_name)
|
||||||
except LookupError:
|
except LookupError:
|
||||||
self.logger.error(
|
self.logger.warning(
|
||||||
"app or model does not exist", app=model_app_label, model=model_name
|
"app or model does not exist", app=model_app_label, model=model_name
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
@ -168,7 +168,7 @@ class FlowImporter:
|
||||||
try:
|
try:
|
||||||
serializer = self._validate_single(entry)
|
serializer = self._validate_single(entry)
|
||||||
except EntryInvalidError as exc:
|
except EntryInvalidError as exc:
|
||||||
self.logger.error("entry not valid", entry=entry, error=exc)
|
self.logger.warning("entry not valid", entry=entry, error=exc)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
model = serializer.save()
|
model = serializer.save()
|
||||||
|
|
|
@ -5,6 +5,7 @@ from aioredis.errors import ConnectionClosedError, ReplyError
|
||||||
from billiard.exceptions import WorkerLostError
|
from billiard.exceptions import WorkerLostError
|
||||||
from botocore.client import ClientError
|
from botocore.client import ClientError
|
||||||
from celery.exceptions import CeleryError
|
from celery.exceptions import CeleryError
|
||||||
|
from channels.middleware import BaseMiddleware
|
||||||
from channels_redis.core import ChannelFull
|
from channels_redis.core import ChannelFull
|
||||||
from django.core.exceptions import SuspiciousOperation, ValidationError
|
from django.core.exceptions import SuspiciousOperation, ValidationError
|
||||||
from django.db import InternalError, OperationalError, ProgrammingError
|
from django.db import InternalError, OperationalError, ProgrammingError
|
||||||
|
@ -14,12 +15,28 @@ from ldap3.core.exceptions import LDAPException
|
||||||
from redis.exceptions import ConnectionError as RedisConnectionError
|
from redis.exceptions import ConnectionError as RedisConnectionError
|
||||||
from redis.exceptions import RedisError, ResponseError
|
from redis.exceptions import RedisError, ResponseError
|
||||||
from rest_framework.exceptions import APIException
|
from rest_framework.exceptions import APIException
|
||||||
|
from sentry_sdk import Hub
|
||||||
|
from sentry_sdk.tracing import Transaction
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
from websockets.exceptions import WebSocketException
|
from websockets.exceptions import WebSocketException
|
||||||
|
|
||||||
|
from authentik.lib.utils.reflection import class_to_path
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class SentryWSMiddleware(BaseMiddleware):
|
||||||
|
"""Sentry Websocket middleweare to set the transaction name based on
|
||||||
|
consumer class path"""
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send):
|
||||||
|
transaction: Optional[Transaction] = Hub.current.scope.transaction
|
||||||
|
class_path = class_to_path(self.inner.consumer_class)
|
||||||
|
if transaction:
|
||||||
|
transaction.name = class_path
|
||||||
|
return await self.inner(scope, receive, send)
|
||||||
|
|
||||||
|
|
||||||
class SentryIgnoredException(Exception):
|
class SentryIgnoredException(Exception):
|
||||||
"""Base Class for all errors that are suppressed, and not sent to sentry."""
|
"""Base Class for all errors that are suppressed, and not sent to sentry."""
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
"""Test Reflection utils"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
from authentik.lib.utils.reflection import path_to_class
|
||||||
|
|
||||||
|
|
||||||
|
class TestReflectionUtils(TestCase):
|
||||||
|
"""Test Reflection-utils"""
|
||||||
|
|
||||||
|
def test_path_to_class(self):
|
||||||
|
"""Test path_to_class"""
|
||||||
|
self.assertIsNone(path_to_class(None))
|
||||||
|
self.assertEqual(path_to_class("datetime.datetime"), datetime)
|
|
@ -203,7 +203,7 @@ class DockerServiceConnection(OutpostServiceConnection):
|
||||||
)
|
)
|
||||||
client.containers.list()
|
client.containers.list()
|
||||||
except DockerException as exc:
|
except DockerException as exc:
|
||||||
LOGGER.error(exc)
|
LOGGER.warning(exc)
|
||||||
raise ServiceConnectionInvalid from exc
|
raise ServiceConnectionInvalid from exc
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""OAuth2Provider API Views"""
|
"""OAuth2Provider API Views"""
|
||||||
|
from guardian.utils import get_anonymous_user
|
||||||
from rest_framework import mixins
|
from rest_framework import mixins
|
||||||
from rest_framework.fields import CharField, ListField
|
from rest_framework.fields import CharField, ListField
|
||||||
from rest_framework.serializers import ModelSerializer
|
from rest_framework.serializers import ModelSerializer
|
||||||
|
@ -38,11 +39,10 @@ class AuthorizationCodeViewSet(
|
||||||
ordering = ["provider", "expires"]
|
ordering = ["provider", "expires"]
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
if not self.request:
|
user = self.request.user if self.request else get_anonymous_user()
|
||||||
|
if user.is_superuser:
|
||||||
return super().get_queryset()
|
return super().get_queryset()
|
||||||
if self.request.user.is_superuser:
|
return super().get_queryset().filter(user=user)
|
||||||
return super().get_queryset()
|
|
||||||
return super().get_queryset().filter(user=self.request.user)
|
|
||||||
|
|
||||||
|
|
||||||
class RefreshTokenViewSet(
|
class RefreshTokenViewSet(
|
||||||
|
@ -59,8 +59,7 @@ class RefreshTokenViewSet(
|
||||||
ordering = ["provider", "expires"]
|
ordering = ["provider", "expires"]
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
if not self.request:
|
user = self.request.user if self.request else get_anonymous_user()
|
||||||
|
if user.is_superuser:
|
||||||
return super().get_queryset()
|
return super().get_queryset()
|
||||||
if self.request.user.is_superuser:
|
return super().get_queryset().filter(user=user)
|
||||||
return super().get_queryset()
|
|
||||||
return super().get_queryset().filter(user=self.request.user)
|
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
"""Test authorize view"""
|
"""Test authorize view"""
|
||||||
from django.test import RequestFactory, TestCase
|
from django.test import RequestFactory
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.utils.encoding import force_str
|
from django.utils.encoding import force_str
|
||||||
from jwt import decode
|
|
||||||
|
|
||||||
from authentik.core.models import Application, User
|
from authentik.core.models import Application, User
|
||||||
from authentik.flows.challenge import ChallengeTypes
|
from authentik.flows.challenge import ChallengeTypes
|
||||||
|
@ -22,10 +21,11 @@ from authentik.providers.oauth2.models import (
|
||||||
OAuth2Provider,
|
OAuth2Provider,
|
||||||
RefreshToken,
|
RefreshToken,
|
||||||
)
|
)
|
||||||
|
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||||
from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams
|
from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams
|
||||||
|
|
||||||
|
|
||||||
class TestAuthorize(TestCase):
|
class TestAuthorize(OAuthTestCase):
|
||||||
"""Test authorize view"""
|
"""Test authorize view"""
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
|
@ -238,23 +238,4 @@ class TestAuthorize(TestCase):
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
jwt = decode(
|
self.validate_jwt(token, provider)
|
||||||
token.access_token,
|
|
||||||
provider.client_secret,
|
|
||||||
algorithms=[provider.jwt_alg],
|
|
||||||
audience=provider.client_id,
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(jwt["exp"])
|
|
||||||
self.assertIsNotNone(jwt["iat"])
|
|
||||||
self.assertIsNotNone(jwt["auth_time"])
|
|
||||||
self.assertIsNotNone(jwt["acr"])
|
|
||||||
self.assertIsNotNone(jwt["sub"])
|
|
||||||
self.assertIsNotNone(jwt["iss"])
|
|
||||||
# Check id_token
|
|
||||||
id_token = token.id_token.to_dict()
|
|
||||||
self.assertIsNotNone(id_token["exp"])
|
|
||||||
self.assertIsNotNone(id_token["iat"])
|
|
||||||
self.assertIsNotNone(id_token["auth_time"])
|
|
||||||
self.assertIsNotNone(id_token["acr"])
|
|
||||||
self.assertIsNotNone(id_token["sub"])
|
|
||||||
self.assertIsNotNone(id_token["iss"])
|
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
"""Test token view"""
|
"""Test token view"""
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
|
|
||||||
from django.test import RequestFactory, TestCase
|
from django.test import RequestFactory
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.utils.encoding import force_str
|
from django.utils.encoding import force_str
|
||||||
|
|
||||||
from authentik.core.models import User
|
from authentik.core.models import Application, User
|
||||||
from authentik.flows.models import Flow
|
from authentik.flows.models import Flow
|
||||||
from authentik.providers.oauth2.constants import (
|
from authentik.providers.oauth2.constants import (
|
||||||
GRANT_TYPE_AUTHORIZATION_CODE,
|
GRANT_TYPE_AUTHORIZATION_CODE,
|
||||||
|
@ -20,15 +20,17 @@ from authentik.providers.oauth2.models import (
|
||||||
OAuth2Provider,
|
OAuth2Provider,
|
||||||
RefreshToken,
|
RefreshToken,
|
||||||
)
|
)
|
||||||
|
from authentik.providers.oauth2.tests.utils import OAuthTestCase
|
||||||
from authentik.providers.oauth2.views.token import TokenParams
|
from authentik.providers.oauth2.views.token import TokenParams
|
||||||
|
|
||||||
|
|
||||||
class TestToken(TestCase):
|
class TestToken(OAuthTestCase):
|
||||||
"""Test token view"""
|
"""Test token view"""
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.factory = RequestFactory()
|
self.factory = RequestFactory()
|
||||||
|
self.app = Application.objects.create(name="test", slug="test")
|
||||||
|
|
||||||
def test_request_auth_code(self):
|
def test_request_auth_code(self):
|
||||||
"""test request param"""
|
"""test request param"""
|
||||||
|
@ -97,12 +99,15 @@ class TestToken(TestCase):
|
||||||
authorization_flow=Flow.objects.first(),
|
authorization_flow=Flow.objects.first(),
|
||||||
redirect_uris="http://local.invalid",
|
redirect_uris="http://local.invalid",
|
||||||
)
|
)
|
||||||
|
# Needs to be assigned to an application for iss to be set
|
||||||
|
self.app.provider = provider
|
||||||
|
self.app.save()
|
||||||
header = b64encode(
|
header = b64encode(
|
||||||
f"{provider.client_id}:{provider.client_secret}".encode()
|
f"{provider.client_id}:{provider.client_secret}".encode()
|
||||||
).decode()
|
).decode()
|
||||||
user = User.objects.get(username="akadmin")
|
user = User.objects.get(username="akadmin")
|
||||||
code = AuthorizationCode.objects.create(
|
code = AuthorizationCode.objects.create(
|
||||||
code="foobar", provider=provider, user=user
|
code="foobar", provider=provider, user=user, is_open_id=True
|
||||||
)
|
)
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
reverse("authentik_providers_oauth2:token"),
|
reverse("authentik_providers_oauth2:token"),
|
||||||
|
@ -126,6 +131,7 @@ class TestToken(TestCase):
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.validate_jwt(new_token, provider)
|
||||||
|
|
||||||
def test_refresh_token_view(self):
|
def test_refresh_token_view(self):
|
||||||
"""test request param"""
|
"""test request param"""
|
||||||
|
@ -136,6 +142,9 @@ class TestToken(TestCase):
|
||||||
authorization_flow=Flow.objects.first(),
|
authorization_flow=Flow.objects.first(),
|
||||||
redirect_uris="http://local.invalid",
|
redirect_uris="http://local.invalid",
|
||||||
)
|
)
|
||||||
|
# Needs to be assigned to an application for iss to be set
|
||||||
|
self.app.provider = provider
|
||||||
|
self.app.save()
|
||||||
header = b64encode(
|
header = b64encode(
|
||||||
f"{provider.client_id}:{provider.client_secret}".encode()
|
f"{provider.client_id}:{provider.client_secret}".encode()
|
||||||
).decode()
|
).decode()
|
||||||
|
@ -174,6 +183,7 @@ class TestToken(TestCase):
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.validate_jwt(new_token, provider)
|
||||||
|
|
||||||
def test_refresh_token_view_invalid_origin(self):
|
def test_refresh_token_view_invalid_origin(self):
|
||||||
"""test request param"""
|
"""test request param"""
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
"""OAuth test helpers"""
|
||||||
|
from django.test import TestCase
|
||||||
|
from jwt import decode
|
||||||
|
|
||||||
|
from authentik.providers.oauth2.models import OAuth2Provider, RefreshToken
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthTestCase(TestCase):
|
||||||
|
"""OAuth test helpers"""
|
||||||
|
|
||||||
|
required_jwt_keys = [
|
||||||
|
"exp",
|
||||||
|
"iat",
|
||||||
|
"auth_time",
|
||||||
|
"acr",
|
||||||
|
"sub",
|
||||||
|
"iss",
|
||||||
|
]
|
||||||
|
|
||||||
|
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider):
|
||||||
|
"""Validate that all required fields are set"""
|
||||||
|
jwt = decode(
|
||||||
|
token.access_token,
|
||||||
|
provider.client_secret,
|
||||||
|
algorithms=[provider.jwt_alg],
|
||||||
|
audience=provider.client_id,
|
||||||
|
)
|
||||||
|
id_token = token.id_token.to_dict()
|
||||||
|
for key in self.required_jwt_keys:
|
||||||
|
self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token")
|
||||||
|
self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token")
|
|
@ -16,6 +16,7 @@ from authentik.providers.oauth2.constants import (
|
||||||
from authentik.providers.oauth2.errors import TokenError, UserAuthError
|
from authentik.providers.oauth2.errors import TokenError, UserAuthError
|
||||||
from authentik.providers.oauth2.models import (
|
from authentik.providers.oauth2.models import (
|
||||||
AuthorizationCode,
|
AuthorizationCode,
|
||||||
|
ClientTypes,
|
||||||
OAuth2Provider,
|
OAuth2Provider,
|
||||||
RefreshToken,
|
RefreshToken,
|
||||||
)
|
)
|
||||||
|
@ -75,7 +76,7 @@ class TokenParams:
|
||||||
LOGGER.warning("OAuth2Provider does not exist", client_id=self.client_id)
|
LOGGER.warning("OAuth2Provider does not exist", client_id=self.client_id)
|
||||||
raise TokenError("invalid_client")
|
raise TokenError("invalid_client")
|
||||||
|
|
||||||
if self.provider.client_type == "confidential":
|
if self.provider.client_type == ClientTypes.CONFIDENTIAL:
|
||||||
if self.provider.client_secret != self.client_secret:
|
if self.provider.client_secret != self.client_secret:
|
||||||
LOGGER.warning(
|
LOGGER.warning(
|
||||||
"Invalid client secret: client does not have secret",
|
"Invalid client secret: client does not have secret",
|
||||||
|
|
|
@ -2,10 +2,13 @@
|
||||||
from channels.auth import AuthMiddlewareStack
|
from channels.auth import AuthMiddlewareStack
|
||||||
from django.urls import path
|
from django.urls import path
|
||||||
|
|
||||||
|
from authentik.lib.sentry import SentryWSMiddleware
|
||||||
from authentik.outposts.channels import OutpostConsumer
|
from authentik.outposts.channels import OutpostConsumer
|
||||||
from authentik.root.messages.consumer import MessageConsumer
|
from authentik.root.messages.consumer import MessageConsumer
|
||||||
|
|
||||||
websocket_urlpatterns = [
|
websocket_urlpatterns = [
|
||||||
path("ws/outpost/<uuid:pk>/", OutpostConsumer.as_asgi()),
|
path("ws/outpost/<uuid:pk>/", SentryWSMiddleware(OutpostConsumer.as_asgi())),
|
||||||
path("ws/client/", AuthMiddlewareStack(MessageConsumer.as_asgi())),
|
path(
|
||||||
|
"ws/client/", AuthMiddlewareStack(SentryWSMiddleware(MessageConsumer.as_asgi()))
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""OAuth Source Serializer"""
|
"""OAuth Source Serializer"""
|
||||||
|
from guardian.utils import get_anonymous_user
|
||||||
from rest_framework.viewsets import ModelViewSet
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
|
||||||
from authentik.core.api.sources import SourceSerializer
|
from authentik.core.api.sources import SourceSerializer
|
||||||
|
@ -26,8 +27,7 @@ class UserOAuthSourceConnectionViewSet(ModelViewSet):
|
||||||
filterset_fields = ["source__slug"]
|
filterset_fields = ["source__slug"]
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
if not self.request:
|
user = self.request.user if self.request else get_anonymous_user()
|
||||||
|
if user.is_superuser:
|
||||||
return super().get_queryset()
|
return super().get_queryset()
|
||||||
if self.request.user.is_superuser:
|
return super().get_queryset().filter(user=user)
|
||||||
return super().get_queryset()
|
|
||||||
return super().get_queryset().filter(user=self.request.user)
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ DISCORD_USER = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class TestTypeGitHub(TestCase):
|
class TestTypeDiscord(TestCase):
|
||||||
"""OAuth Source tests"""
|
"""OAuth Source tests"""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -32,7 +32,7 @@ class TestTypeGitHub(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_enroll_context(self):
|
def test_enroll_context(self):
|
||||||
"""Test GitHub Enrollment context"""
|
"""Test discord Enrollment context"""
|
||||||
ak_context = DiscordOAuth2Callback().get_user_enroll_context(
|
ak_context = DiscordOAuth2Callback().get_user_enroll_context(
|
||||||
self.source, UserOAuthSourceConnection(), DISCORD_USER
|
self.source, UserOAuthSourceConnection(), DISCORD_USER
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
"""google Type tests"""
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||||
|
from authentik.sources.oauth.types.google import GoogleOAuth2Callback
|
||||||
|
|
||||||
|
# https://developers.google.com/identity/protocols/oauth2/openid-connect?hl=en
|
||||||
|
GOOGLE_USER = {
|
||||||
|
"id": "1324813249123401234",
|
||||||
|
"email": "foo@bar.baz",
|
||||||
|
"verified_email": True,
|
||||||
|
"name": "foo bar",
|
||||||
|
"given_name": "foo",
|
||||||
|
"family_name": "bar",
|
||||||
|
"picture": "",
|
||||||
|
"locale": "en",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestTypeGoogle(TestCase):
|
||||||
|
"""OAuth Source tests"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.source = OAuthSource.objects.create(
|
||||||
|
name="test",
|
||||||
|
slug="test",
|
||||||
|
provider_type="google",
|
||||||
|
authorization_url="",
|
||||||
|
profile_url="",
|
||||||
|
consumer_key="",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_enroll_context(self):
|
||||||
|
"""Test Google Enrollment context"""
|
||||||
|
ak_context = GoogleOAuth2Callback().get_user_enroll_context(
|
||||||
|
self.source, UserOAuthSourceConnection(), GOOGLE_USER
|
||||||
|
)
|
||||||
|
self.assertEqual(ak_context["username"], GOOGLE_USER["email"])
|
||||||
|
self.assertEqual(ak_context["email"], GOOGLE_USER["email"])
|
||||||
|
self.assertEqual(ak_context["name"], GOOGLE_USER["name"])
|
|
@ -1,5 +1,4 @@
|
||||||
"""Dispatch OAuth views to respective views"""
|
"""Dispatch OAuth views to respective views"""
|
||||||
from django.http import Http404
|
|
||||||
from django.shortcuts import get_object_or_404
|
from django.shortcuts import get_object_or_404
|
||||||
from django.views import View
|
from django.views import View
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
@ -15,12 +14,9 @@ class DispatcherView(View):
|
||||||
|
|
||||||
kind = ""
|
kind = ""
|
||||||
|
|
||||||
def dispatch(self, *args, **kwargs):
|
def dispatch(self, *args, source_slug: str, **kwargs):
|
||||||
"""Find Source by slug and forward request"""
|
"""Find Source by slug and forward request"""
|
||||||
slug = kwargs.get("source_slug", None)
|
source = get_object_or_404(OAuthSource, slug=source_slug)
|
||||||
if not slug:
|
|
||||||
raise Http404
|
|
||||||
source = get_object_or_404(OAuthSource, slug=slug)
|
|
||||||
view = MANAGER.find(source.provider_type, kind=RequestKind(self.kind))
|
view = MANAGER.find(source.provider_type, kind=RequestKind(self.kind))
|
||||||
LOGGER.debug("dispatching OAuth2 request to", view=view, kind=self.kind)
|
LOGGER.debug("dispatching OAuth2 request to", view=view, kind=self.kind)
|
||||||
return view.as_view()(*args, **kwargs)
|
return view.as_view()(*args, source_slug=source_slug, **kwargs)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
"""AuthenticatorStaticStage API Views"""
|
"""AuthenticatorStaticStage API Views"""
|
||||||
from django_otp.plugins.otp_static.models import StaticDevice
|
from django_otp.plugins.otp_static.models import StaticDevice
|
||||||
|
from guardian.utils import get_anonymous_user
|
||||||
from rest_framework.permissions import IsAdminUser
|
from rest_framework.permissions import IsAdminUser
|
||||||
from rest_framework.serializers import ModelSerializer
|
from rest_framework.serializers import ModelSerializer
|
||||||
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet
|
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet
|
||||||
|
@ -44,9 +45,8 @@ class StaticDeviceViewSet(ModelViewSet):
|
||||||
ordering = ["name"]
|
ordering = ["name"]
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
if not self.request:
|
user = self.request.user if self.request else get_anonymous_user()
|
||||||
return super().get_queryset()
|
return StaticDevice.objects.filter(user=user)
|
||||||
return StaticDevice.objects.filter(user=self.request.user)
|
|
||||||
|
|
||||||
|
|
||||||
class StaticAdminDeviceViewSet(ReadOnlyModelViewSet):
|
class StaticAdminDeviceViewSet(ReadOnlyModelViewSet):
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
"""AuthenticatorTOTPStage API Views"""
|
"""AuthenticatorTOTPStage API Views"""
|
||||||
from django_otp.plugins.otp_totp.models import TOTPDevice
|
from django_otp.plugins.otp_totp.models import TOTPDevice
|
||||||
|
from guardian.utils import get_anonymous_user
|
||||||
from rest_framework.permissions import IsAdminUser
|
from rest_framework.permissions import IsAdminUser
|
||||||
from rest_framework.serializers import ModelSerializer
|
from rest_framework.serializers import ModelSerializer
|
||||||
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet
|
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet
|
||||||
|
@ -47,9 +48,8 @@ class TOTPDeviceViewSet(ModelViewSet):
|
||||||
ordering = ["name"]
|
ordering = ["name"]
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
if not self.request:
|
user = self.request.user if self.request else get_anonymous_user()
|
||||||
return super().get_queryset()
|
return TOTPDevice.objects.filter(user=user)
|
||||||
return TOTPDevice.objects.filter(user=self.request.user)
|
|
||||||
|
|
||||||
|
|
||||||
class TOTPAdminDeviceViewSet(ReadOnlyModelViewSet):
|
class TOTPAdminDeviceViewSet(ReadOnlyModelViewSet):
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""AuthenticateWebAuthnStage API Views"""
|
"""AuthenticateWebAuthnStage API Views"""
|
||||||
|
from guardian.utils import get_anonymous_user
|
||||||
from rest_framework.permissions import IsAdminUser
|
from rest_framework.permissions import IsAdminUser
|
||||||
from rest_framework.serializers import ModelSerializer
|
from rest_framework.serializers import ModelSerializer
|
||||||
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet
|
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet
|
||||||
|
@ -46,9 +47,8 @@ class WebAuthnDeviceViewSet(ModelViewSet):
|
||||||
ordering = ["name"]
|
ordering = ["name"]
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
if not self.request:
|
user = self.request.user if self.request else get_anonymous_user()
|
||||||
return super().get_queryset()
|
return WebAuthnDevice.objects.filter(user=user)
|
||||||
return WebAuthnDevice.objects.filter(user=self.request.user)
|
|
||||||
|
|
||||||
|
|
||||||
class WebAuthnAdminDeviceViewSet(ReadOnlyModelViewSet):
|
class WebAuthnAdminDeviceViewSet(ReadOnlyModelViewSet):
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""ConsentStage API Views"""
|
"""ConsentStage API Views"""
|
||||||
|
from guardian.utils import get_anonymous_user
|
||||||
from rest_framework import mixins
|
from rest_framework import mixins
|
||||||
from rest_framework.viewsets import GenericViewSet, ModelViewSet
|
from rest_framework.viewsets import GenericViewSet, ModelViewSet
|
||||||
|
|
||||||
|
@ -50,8 +51,7 @@ class UserConsentViewSet(
|
||||||
ordering = ["application", "expires"]
|
ordering = ["application", "expires"]
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
if not self.request:
|
user = self.request.user if self.request else get_anonymous_user()
|
||||||
|
if user.is_superuser:
|
||||||
return super().get_queryset()
|
return super().get_queryset()
|
||||||
if self.request.user.is_superuser:
|
return super().get_queryset().filter(user=user)
|
||||||
return super().get_queryset()
|
|
||||||
return super().get_queryset().filter(user=self.request.user)
|
|
||||||
|
|
|
@ -68,7 +68,7 @@ def send_mail(
|
||||||
messages=["Successfully sent Mail."],
|
messages=["Successfully sent Mail."],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except (SMTPException, ConnectionError) as exc:
|
except (SMTPException, ConnectionError, ValueError) as exc:
|
||||||
LOGGER.debug("Error sending email, retrying...", exc=exc)
|
LOGGER.debug("Error sending email, retrying...", exc=exc)
|
||||||
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
|
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
|
||||||
raise exc
|
raise exc
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""login tests"""
|
"""login tests"""
|
||||||
|
from time import sleep
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from django.test import Client, TestCase
|
from django.test import Client, TestCase
|
||||||
|
@ -51,6 +52,31 @@ class TestUserLoginStage(TestCase):
|
||||||
{"to": reverse("authentik_core:root-redirect"), "type": "redirect"},
|
{"to": reverse("authentik_core:root-redirect"), "type": "redirect"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_expiry(self):
|
||||||
|
"""Test with expiry"""
|
||||||
|
self.stage.session_duration = "seconds=2"
|
||||||
|
self.stage.save()
|
||||||
|
plan = FlowPlan(
|
||||||
|
flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()]
|
||||||
|
)
|
||||||
|
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
|
||||||
|
session = self.client.session
|
||||||
|
session[SESSION_KEY_PLAN] = plan
|
||||||
|
session.save()
|
||||||
|
|
||||||
|
response = self.client.get(
|
||||||
|
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertJSONEqual(
|
||||||
|
force_str(response.content),
|
||||||
|
{"to": reverse("authentik_core:root-redirect"), "type": "redirect"},
|
||||||
|
)
|
||||||
|
self.assertNotEqual(list(self.client.session.keys()), [])
|
||||||
|
sleep(3)
|
||||||
|
self.client.session.clear_expired()
|
||||||
|
self.assertEqual(list(self.client.session.keys()), [])
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"authentik.flows.views.to_stage_response",
|
"authentik.flows.views.to_stage_response",
|
||||||
TO_STAGE_RESPONSE_MOCK,
|
TO_STAGE_RESPONSE_MOCK,
|
||||||
|
|
|
@ -23,6 +23,11 @@ export function configureSentry(canDoPpi: boolean = false): Promise<Config> {
|
||||||
if (hint.originalException instanceof SentryIgnoredError) {
|
if (hint.originalException instanceof SentryIgnoredError) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
if (hint.originalException instanceof Error) {
|
||||||
|
if (hint.originalException.name == 'NetworkError') {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
if (hint.originalException instanceof Response) {
|
if (hint.originalException instanceof Response) {
|
||||||
const response = hint.originalException as Response;
|
const response = hint.originalException as Response;
|
||||||
// We only care about server errors
|
// We only care about server errors
|
||||||
|
|
|
@ -70,7 +70,6 @@ export class ServiceConnectionDockerForm extends Form<DockerServiceConnection> {
|
||||||
</ak-form-element-horizontal>
|
</ak-form-element-horizontal>
|
||||||
<ak-form-element-horizontal
|
<ak-form-element-horizontal
|
||||||
label=${t`TLS Verification Certificate`}
|
label=${t`TLS Verification Certificate`}
|
||||||
?required=${true}
|
|
||||||
name="tlsVerification">
|
name="tlsVerification">
|
||||||
<select class="pf-c-form-control">
|
<select class="pf-c-form-control">
|
||||||
<option value="" ?selected=${this.sc?.tlsVerification === undefined}>---------</option>
|
<option value="" ?selected=${this.sc?.tlsVerification === undefined}>---------</option>
|
||||||
|
@ -86,7 +85,6 @@ export class ServiceConnectionDockerForm extends Form<DockerServiceConnection> {
|
||||||
</ak-form-element-horizontal>
|
</ak-form-element-horizontal>
|
||||||
<ak-form-element-horizontal
|
<ak-form-element-horizontal
|
||||||
label=${t`TLS Authentication Certificate`}
|
label=${t`TLS Authentication Certificate`}
|
||||||
?required=${true}
|
|
||||||
name="tlsAuthentication">
|
name="tlsAuthentication">
|
||||||
<select class="pf-c-form-control">
|
<select class="pf-c-form-control">
|
||||||
<option value="" ?selected=${this.sc?.tlsAuthentication === undefined}>---------</option>
|
<option value="" ?selected=${this.sc?.tlsAuthentication === undefined}>---------</option>
|
||||||
|
|
Reference in New Issue