Merge branch 'master' into outpost-ldap

This commit is contained in:
Jens Langhammer 2021-04-27 17:08:26 +02:00
commit 4d858c64e0
30 changed files with 288 additions and 85 deletions

View File

@ -1,5 +1,5 @@
"""API Authentication""" """API Authentication"""
from base64 import b64decode, b64encode from base64 import b64decode
from binascii import Error from binascii import Error
from typing import Any, Optional, Union from typing import Any, Optional, Union
@ -19,14 +19,6 @@ def token_from_header(raw_header: bytes) -> Optional[Token]:
auth_credentials = raw_header.decode() auth_credentials = raw_header.decode()
if auth_credentials == "": if auth_credentials == "":
return None return None
# Legacy, accept basic auth thats fully encoded (2021.3 outposts)
if " " not in auth_credentials:
try:
plain = b64decode(auth_credentials.encode()).decode()
auth_type, body = plain.split()
auth_credentials = f"{auth_type} {b64encode(body.encode()).decode()}"
except (UnicodeDecodeError, Error):
raise AuthenticationFailed("Malformed header")
auth_type, auth_credentials = auth_credentials.split() auth_type, auth_credentials = auth_credentials.split()
if auth_type.lower() not in ["basic", "bearer"]: if auth_type.lower() not in ["basic", "bearer"]:
LOGGER.debug("Unsupported authentication type, denying", type=auth_type.lower()) LOGGER.debug("Unsupported authentication type, denying", type=auth_type.lower())

View File

@ -0,0 +1,16 @@
"""Test config API"""
from json import loads
from django.urls import reverse
from rest_framework.test import APITestCase
class TestConfig(APITestCase):
"""Test config API"""
def test_config(self):
"""Test YAML generation"""
response = self.client.get(
reverse("authentik_api:configs-list"),
)
self.assertTrue(loads(response.content.decode()))

View File

@ -0,0 +1,33 @@
"""test decorators api"""
from django.urls import reverse
from guardian.shortcuts import assign_perm
from rest_framework.test import APITestCase
from authentik.core.models import Application, User
class TestAPIDecorators(APITestCase):
"""test decorators api"""
def setUp(self) -> None:
super().setUp()
self.user = User.objects.create(username="test-user")
def test_obj_perm_denied(self):
"""Test object perm denied"""
self.client.force_login(self.user)
app = Application.objects.create(name="denied", slug="denied")
response = self.client.get(
reverse("authentik_api:application-metrics", kwargs={"slug": app.slug})
)
self.assertEqual(response.status_code, 403)
def test_other_perm_denied(self):
"""Test other perm denied"""
self.client.force_login(self.user)
app = Application.objects.create(name="denied", slug="denied")
assign_perm("authentik_core.view_application", self.user, app)
response = self.client.get(
reverse("authentik_api:application-metrics", kwargs={"slug": app.slug})
)
self.assertEqual(response.status_code, 403)

View File

@ -20,10 +20,12 @@ def is_dict(value: Any):
class PassiveSerializer(Serializer): class PassiveSerializer(Serializer):
"""Base serializer class which doesn't implement create/update methods""" """Base serializer class which doesn't implement create/update methods"""
def create(self, validated_data: dict) -> Model: def create(self, validated_data: dict) -> Model: # pragma: no cover
return Model() return Model()
def update(self, instance: Model, validated_data: dict) -> Model: def update(
self, instance: Model, validated_data: dict
) -> Model: # pragma: no cover
return Model() return Model()

View File

@ -33,7 +33,7 @@ class CertificateBuilder:
def save(self) -> Optional[CertificateKeyPair]: def save(self) -> Optional[CertificateKeyPair]:
"""Save generated certificate as model""" """Save generated certificate as model"""
if not self.__certificate: if not self.__certificate:
return None raise ValueError("Certificated hasn't been built yet")
return CertificateKeyPair.objects.create( return CertificateKeyPair.objects.create(
name=self.common_name, name=self.common_name,
certificate_data=self.certificate, certificate_data=self.certificate,

View File

@ -37,6 +37,8 @@ class TestCrypto(TestCase):
"""Test Builder""" """Test Builder"""
builder = CertificateBuilder() builder = CertificateBuilder()
builder.common_name = "test-cert" builder.common_name = "test-cert"
with self.assertRaises(ValueError):
builder.save()
builder.build( builder.build(
subject_alt_names=[], subject_alt_names=[],
validity_days=3, validity_days=3,

View File

@ -1,4 +1,5 @@
"""Notification API Views""" """Notification API Views"""
from guardian.utils import get_anonymous_user
from rest_framework import mixins from rest_framework import mixins
from rest_framework.fields import ReadOnlyField from rest_framework.fields import ReadOnlyField
from rest_framework.serializers import ModelSerializer from rest_framework.serializers import ModelSerializer
@ -48,6 +49,5 @@ class NotificationViewSet(
] ]
def get_queryset(self): def get_queryset(self):
if not self.request: user = self.request.user if self.request else get_anonymous_user()
return super().get_queryset() return Notification.objects.filter(user=user)
return Notification.objects.filter(user=self.request.user)

View File

@ -0,0 +1,35 @@
"""base model tests"""
from typing import Callable, Type
from django.test import TestCase
from authentik.flows.models import Stage
from authentik.flows.stage import StageView
from authentik.lib.utils.reflection import all_subclasses
class TestModels(TestCase):
"""Generic model properties tests"""
def model_tester_factory(test_model: Type[Stage]) -> Callable:
"""Test a form"""
def tester(self: TestModels):
try:
model_class = None
if test_model._meta.abstract:
model_class = test_model.__bases__[0]()
else:
model_class = test_model()
self.assertTrue(issubclass(model_class.type, StageView))
self.assertIsNotNone(test_model.component)
_ = test_model.ui_user_settings
except NotImplementedError:
pass
return tester
for model in all_subclasses(Stage):
setattr(TestModels, f"test_model_{model.__name__}", model_tester_factory(model))

View File

@ -160,7 +160,7 @@ class FlowImporter:
try: try:
model: SerializerModel = apps.get_model(model_app_label, model_name) model: SerializerModel = apps.get_model(model_app_label, model_name)
except LookupError: except LookupError:
self.logger.error( self.logger.warning(
"app or model does not exist", app=model_app_label, model=model_name "app or model does not exist", app=model_app_label, model=model_name
) )
return False return False
@ -168,7 +168,7 @@ class FlowImporter:
try: try:
serializer = self._validate_single(entry) serializer = self._validate_single(entry)
except EntryInvalidError as exc: except EntryInvalidError as exc:
self.logger.error("entry not valid", entry=entry, error=exc) self.logger.warning("entry not valid", entry=entry, error=exc)
return False return False
model = serializer.save() model = serializer.save()

View File

@ -5,6 +5,7 @@ from aioredis.errors import ConnectionClosedError, ReplyError
from billiard.exceptions import WorkerLostError from billiard.exceptions import WorkerLostError
from botocore.client import ClientError from botocore.client import ClientError
from celery.exceptions import CeleryError from celery.exceptions import CeleryError
from channels.middleware import BaseMiddleware
from channels_redis.core import ChannelFull from channels_redis.core import ChannelFull
from django.core.exceptions import SuspiciousOperation, ValidationError from django.core.exceptions import SuspiciousOperation, ValidationError
from django.db import InternalError, OperationalError, ProgrammingError from django.db import InternalError, OperationalError, ProgrammingError
@ -14,12 +15,28 @@ from ldap3.core.exceptions import LDAPException
from redis.exceptions import ConnectionError as RedisConnectionError from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import RedisError, ResponseError from redis.exceptions import RedisError, ResponseError
from rest_framework.exceptions import APIException from rest_framework.exceptions import APIException
from sentry_sdk import Hub
from sentry_sdk.tracing import Transaction
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from websockets.exceptions import WebSocketException from websockets.exceptions import WebSocketException
from authentik.lib.utils.reflection import class_to_path
LOGGER = get_logger() LOGGER = get_logger()
class SentryWSMiddleware(BaseMiddleware):
"""Sentry Websocket middleweare to set the transaction name based on
consumer class path"""
async def __call__(self, scope, receive, send):
transaction: Optional[Transaction] = Hub.current.scope.transaction
class_path = class_to_path(self.inner.consumer_class)
if transaction:
transaction.name = class_path
return await self.inner(scope, receive, send)
class SentryIgnoredException(Exception): class SentryIgnoredException(Exception):
"""Base Class for all errors that are suppressed, and not sent to sentry.""" """Base Class for all errors that are suppressed, and not sent to sentry."""

View File

@ -0,0 +1,16 @@
"""Test Reflection utils"""
from datetime import datetime
from django.test import TestCase
from authentik.lib.utils.reflection import path_to_class
class TestReflectionUtils(TestCase):
"""Test Reflection-utils"""
def test_path_to_class(self):
"""Test path_to_class"""
self.assertIsNone(path_to_class(None))
self.assertEqual(path_to_class("datetime.datetime"), datetime)

View File

@ -203,7 +203,7 @@ class DockerServiceConnection(OutpostServiceConnection):
) )
client.containers.list() client.containers.list()
except DockerException as exc: except DockerException as exc:
LOGGER.error(exc) LOGGER.warning(exc)
raise ServiceConnectionInvalid from exc raise ServiceConnectionInvalid from exc
return client return client

View File

@ -1,4 +1,5 @@
"""OAuth2Provider API Views""" """OAuth2Provider API Views"""
from guardian.utils import get_anonymous_user
from rest_framework import mixins from rest_framework import mixins
from rest_framework.fields import CharField, ListField from rest_framework.fields import CharField, ListField
from rest_framework.serializers import ModelSerializer from rest_framework.serializers import ModelSerializer
@ -38,11 +39,10 @@ class AuthorizationCodeViewSet(
ordering = ["provider", "expires"] ordering = ["provider", "expires"]
def get_queryset(self): def get_queryset(self):
if not self.request: user = self.request.user if self.request else get_anonymous_user()
if user.is_superuser:
return super().get_queryset() return super().get_queryset()
if self.request.user.is_superuser: return super().get_queryset().filter(user=user)
return super().get_queryset()
return super().get_queryset().filter(user=self.request.user)
class RefreshTokenViewSet( class RefreshTokenViewSet(
@ -59,8 +59,7 @@ class RefreshTokenViewSet(
ordering = ["provider", "expires"] ordering = ["provider", "expires"]
def get_queryset(self): def get_queryset(self):
if not self.request: user = self.request.user if self.request else get_anonymous_user()
if user.is_superuser:
return super().get_queryset() return super().get_queryset()
if self.request.user.is_superuser: return super().get_queryset().filter(user=user)
return super().get_queryset()
return super().get_queryset().filter(user=self.request.user)

View File

@ -1,8 +1,7 @@
"""Test authorize view""" """Test authorize view"""
from django.test import RequestFactory, TestCase from django.test import RequestFactory
from django.urls import reverse from django.urls import reverse
from django.utils.encoding import force_str from django.utils.encoding import force_str
from jwt import decode
from authentik.core.models import Application, User from authentik.core.models import Application, User
from authentik.flows.challenge import ChallengeTypes from authentik.flows.challenge import ChallengeTypes
@ -22,10 +21,11 @@ from authentik.providers.oauth2.models import (
OAuth2Provider, OAuth2Provider,
RefreshToken, RefreshToken,
) )
from authentik.providers.oauth2.tests.utils import OAuthTestCase
from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams from authentik.providers.oauth2.views.authorize import OAuthAuthorizationParams
class TestAuthorize(TestCase): class TestAuthorize(OAuthTestCase):
"""Test authorize view""" """Test authorize view"""
def setUp(self) -> None: def setUp(self) -> None:
@ -238,23 +238,4 @@ class TestAuthorize(TestCase):
), ),
}, },
) )
jwt = decode( self.validate_jwt(token, provider)
token.access_token,
provider.client_secret,
algorithms=[provider.jwt_alg],
audience=provider.client_id,
)
self.assertIsNotNone(jwt["exp"])
self.assertIsNotNone(jwt["iat"])
self.assertIsNotNone(jwt["auth_time"])
self.assertIsNotNone(jwt["acr"])
self.assertIsNotNone(jwt["sub"])
self.assertIsNotNone(jwt["iss"])
# Check id_token
id_token = token.id_token.to_dict()
self.assertIsNotNone(id_token["exp"])
self.assertIsNotNone(id_token["iat"])
self.assertIsNotNone(id_token["auth_time"])
self.assertIsNotNone(id_token["acr"])
self.assertIsNotNone(id_token["sub"])
self.assertIsNotNone(id_token["iss"])

View File

@ -1,11 +1,11 @@
"""Test token view""" """Test token view"""
from base64 import b64encode from base64 import b64encode
from django.test import RequestFactory, TestCase from django.test import RequestFactory
from django.urls import reverse from django.urls import reverse
from django.utils.encoding import force_str from django.utils.encoding import force_str
from authentik.core.models import User from authentik.core.models import Application, User
from authentik.flows.models import Flow from authentik.flows.models import Flow
from authentik.providers.oauth2.constants import ( from authentik.providers.oauth2.constants import (
GRANT_TYPE_AUTHORIZATION_CODE, GRANT_TYPE_AUTHORIZATION_CODE,
@ -20,15 +20,17 @@ from authentik.providers.oauth2.models import (
OAuth2Provider, OAuth2Provider,
RefreshToken, RefreshToken,
) )
from authentik.providers.oauth2.tests.utils import OAuthTestCase
from authentik.providers.oauth2.views.token import TokenParams from authentik.providers.oauth2.views.token import TokenParams
class TestToken(TestCase): class TestToken(OAuthTestCase):
"""Test token view""" """Test token view"""
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
self.factory = RequestFactory() self.factory = RequestFactory()
self.app = Application.objects.create(name="test", slug="test")
def test_request_auth_code(self): def test_request_auth_code(self):
"""test request param""" """test request param"""
@ -97,12 +99,15 @@ class TestToken(TestCase):
authorization_flow=Flow.objects.first(), authorization_flow=Flow.objects.first(),
redirect_uris="http://local.invalid", redirect_uris="http://local.invalid",
) )
# Needs to be assigned to an application for iss to be set
self.app.provider = provider
self.app.save()
header = b64encode( header = b64encode(
f"{provider.client_id}:{provider.client_secret}".encode() f"{provider.client_id}:{provider.client_secret}".encode()
).decode() ).decode()
user = User.objects.get(username="akadmin") user = User.objects.get(username="akadmin")
code = AuthorizationCode.objects.create( code = AuthorizationCode.objects.create(
code="foobar", provider=provider, user=user code="foobar", provider=provider, user=user, is_open_id=True
) )
response = self.client.post( response = self.client.post(
reverse("authentik_providers_oauth2:token"), reverse("authentik_providers_oauth2:token"),
@ -126,6 +131,7 @@ class TestToken(TestCase):
), ),
}, },
) )
self.validate_jwt(new_token, provider)
def test_refresh_token_view(self): def test_refresh_token_view(self):
"""test request param""" """test request param"""
@ -136,6 +142,9 @@ class TestToken(TestCase):
authorization_flow=Flow.objects.first(), authorization_flow=Flow.objects.first(),
redirect_uris="http://local.invalid", redirect_uris="http://local.invalid",
) )
# Needs to be assigned to an application for iss to be set
self.app.provider = provider
self.app.save()
header = b64encode( header = b64encode(
f"{provider.client_id}:{provider.client_secret}".encode() f"{provider.client_id}:{provider.client_secret}".encode()
).decode() ).decode()
@ -174,6 +183,7 @@ class TestToken(TestCase):
), ),
}, },
) )
self.validate_jwt(new_token, provider)
def test_refresh_token_view_invalid_origin(self): def test_refresh_token_view_invalid_origin(self):
"""test request param""" """test request param"""

View File

@ -0,0 +1,31 @@
"""OAuth test helpers"""
from django.test import TestCase
from jwt import decode
from authentik.providers.oauth2.models import OAuth2Provider, RefreshToken
class OAuthTestCase(TestCase):
"""OAuth test helpers"""
required_jwt_keys = [
"exp",
"iat",
"auth_time",
"acr",
"sub",
"iss",
]
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider):
"""Validate that all required fields are set"""
jwt = decode(
token.access_token,
provider.client_secret,
algorithms=[provider.jwt_alg],
audience=provider.client_id,
)
id_token = token.id_token.to_dict()
for key in self.required_jwt_keys:
self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token")
self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token")

View File

@ -16,6 +16,7 @@ from authentik.providers.oauth2.constants import (
from authentik.providers.oauth2.errors import TokenError, UserAuthError from authentik.providers.oauth2.errors import TokenError, UserAuthError
from authentik.providers.oauth2.models import ( from authentik.providers.oauth2.models import (
AuthorizationCode, AuthorizationCode,
ClientTypes,
OAuth2Provider, OAuth2Provider,
RefreshToken, RefreshToken,
) )
@ -75,7 +76,7 @@ class TokenParams:
LOGGER.warning("OAuth2Provider does not exist", client_id=self.client_id) LOGGER.warning("OAuth2Provider does not exist", client_id=self.client_id)
raise TokenError("invalid_client") raise TokenError("invalid_client")
if self.provider.client_type == "confidential": if self.provider.client_type == ClientTypes.CONFIDENTIAL:
if self.provider.client_secret != self.client_secret: if self.provider.client_secret != self.client_secret:
LOGGER.warning( LOGGER.warning(
"Invalid client secret: client does not have secret", "Invalid client secret: client does not have secret",

View File

@ -2,10 +2,13 @@
from channels.auth import AuthMiddlewareStack from channels.auth import AuthMiddlewareStack
from django.urls import path from django.urls import path
from authentik.lib.sentry import SentryWSMiddleware
from authentik.outposts.channels import OutpostConsumer from authentik.outposts.channels import OutpostConsumer
from authentik.root.messages.consumer import MessageConsumer from authentik.root.messages.consumer import MessageConsumer
websocket_urlpatterns = [ websocket_urlpatterns = [
path("ws/outpost/<uuid:pk>/", OutpostConsumer.as_asgi()), path("ws/outpost/<uuid:pk>/", SentryWSMiddleware(OutpostConsumer.as_asgi())),
path("ws/client/", AuthMiddlewareStack(MessageConsumer.as_asgi())), path(
"ws/client/", AuthMiddlewareStack(SentryWSMiddleware(MessageConsumer.as_asgi()))
),
] ]

View File

@ -1,4 +1,5 @@
"""OAuth Source Serializer""" """OAuth Source Serializer"""
from guardian.utils import get_anonymous_user
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
from authentik.core.api.sources import SourceSerializer from authentik.core.api.sources import SourceSerializer
@ -26,8 +27,7 @@ class UserOAuthSourceConnectionViewSet(ModelViewSet):
filterset_fields = ["source__slug"] filterset_fields = ["source__slug"]
def get_queryset(self): def get_queryset(self):
if not self.request: user = self.request.user if self.request else get_anonymous_user()
if user.is_superuser:
return super().get_queryset() return super().get_queryset()
if self.request.user.is_superuser: return super().get_queryset().filter(user=user)
return super().get_queryset()
return super().get_queryset().filter(user=self.request.user)

View File

@ -18,7 +18,7 @@ DISCORD_USER = {
} }
class TestTypeGitHub(TestCase): class TestTypeDiscord(TestCase):
"""OAuth Source tests""" """OAuth Source tests"""
def setUp(self): def setUp(self):
@ -32,7 +32,7 @@ class TestTypeGitHub(TestCase):
) )
def test_enroll_context(self): def test_enroll_context(self):
"""Test GitHub Enrollment context""" """Test discord Enrollment context"""
ak_context = DiscordOAuth2Callback().get_user_enroll_context( ak_context = DiscordOAuth2Callback().get_user_enroll_context(
self.source, UserOAuthSourceConnection(), DISCORD_USER self.source, UserOAuthSourceConnection(), DISCORD_USER
) )

View File

@ -0,0 +1,40 @@
"""google Type tests"""
from django.test import TestCase
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.google import GoogleOAuth2Callback
# https://developers.google.com/identity/protocols/oauth2/openid-connect?hl=en
GOOGLE_USER = {
"id": "1324813249123401234",
"email": "foo@bar.baz",
"verified_email": True,
"name": "foo bar",
"given_name": "foo",
"family_name": "bar",
"picture": "",
"locale": "en",
}
class TestTypeGoogle(TestCase):
"""OAuth Source tests"""
def setUp(self):
self.source = OAuthSource.objects.create(
name="test",
slug="test",
provider_type="google",
authorization_url="",
profile_url="",
consumer_key="",
)
def test_enroll_context(self):
"""Test Google Enrollment context"""
ak_context = GoogleOAuth2Callback().get_user_enroll_context(
self.source, UserOAuthSourceConnection(), GOOGLE_USER
)
self.assertEqual(ak_context["username"], GOOGLE_USER["email"])
self.assertEqual(ak_context["email"], GOOGLE_USER["email"])
self.assertEqual(ak_context["name"], GOOGLE_USER["name"])

View File

@ -1,5 +1,4 @@
"""Dispatch OAuth views to respective views""" """Dispatch OAuth views to respective views"""
from django.http import Http404
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.views import View from django.views import View
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -15,12 +14,9 @@ class DispatcherView(View):
kind = "" kind = ""
def dispatch(self, *args, **kwargs): def dispatch(self, *args, source_slug: str, **kwargs):
"""Find Source by slug and forward request""" """Find Source by slug and forward request"""
slug = kwargs.get("source_slug", None) source = get_object_or_404(OAuthSource, slug=source_slug)
if not slug:
raise Http404
source = get_object_or_404(OAuthSource, slug=slug)
view = MANAGER.find(source.provider_type, kind=RequestKind(self.kind)) view = MANAGER.find(source.provider_type, kind=RequestKind(self.kind))
LOGGER.debug("dispatching OAuth2 request to", view=view, kind=self.kind) LOGGER.debug("dispatching OAuth2 request to", view=view, kind=self.kind)
return view.as_view()(*args, **kwargs) return view.as_view()(*args, source_slug=source_slug, **kwargs)

View File

@ -1,5 +1,6 @@
"""AuthenticatorStaticStage API Views""" """AuthenticatorStaticStage API Views"""
from django_otp.plugins.otp_static.models import StaticDevice from django_otp.plugins.otp_static.models import StaticDevice
from guardian.utils import get_anonymous_user
from rest_framework.permissions import IsAdminUser from rest_framework.permissions import IsAdminUser
from rest_framework.serializers import ModelSerializer from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet
@ -44,9 +45,8 @@ class StaticDeviceViewSet(ModelViewSet):
ordering = ["name"] ordering = ["name"]
def get_queryset(self): def get_queryset(self):
if not self.request: user = self.request.user if self.request else get_anonymous_user()
return super().get_queryset() return StaticDevice.objects.filter(user=user)
return StaticDevice.objects.filter(user=self.request.user)
class StaticAdminDeviceViewSet(ReadOnlyModelViewSet): class StaticAdminDeviceViewSet(ReadOnlyModelViewSet):

View File

@ -1,5 +1,6 @@
"""AuthenticatorTOTPStage API Views""" """AuthenticatorTOTPStage API Views"""
from django_otp.plugins.otp_totp.models import TOTPDevice from django_otp.plugins.otp_totp.models import TOTPDevice
from guardian.utils import get_anonymous_user
from rest_framework.permissions import IsAdminUser from rest_framework.permissions import IsAdminUser
from rest_framework.serializers import ModelSerializer from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet
@ -47,9 +48,8 @@ class TOTPDeviceViewSet(ModelViewSet):
ordering = ["name"] ordering = ["name"]
def get_queryset(self): def get_queryset(self):
if not self.request: user = self.request.user if self.request else get_anonymous_user()
return super().get_queryset() return TOTPDevice.objects.filter(user=user)
return TOTPDevice.objects.filter(user=self.request.user)
class TOTPAdminDeviceViewSet(ReadOnlyModelViewSet): class TOTPAdminDeviceViewSet(ReadOnlyModelViewSet):

View File

@ -1,4 +1,5 @@
"""AuthenticateWebAuthnStage API Views""" """AuthenticateWebAuthnStage API Views"""
from guardian.utils import get_anonymous_user
from rest_framework.permissions import IsAdminUser from rest_framework.permissions import IsAdminUser
from rest_framework.serializers import ModelSerializer from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet
@ -46,9 +47,8 @@ class WebAuthnDeviceViewSet(ModelViewSet):
ordering = ["name"] ordering = ["name"]
def get_queryset(self): def get_queryset(self):
if not self.request: user = self.request.user if self.request else get_anonymous_user()
return super().get_queryset() return WebAuthnDevice.objects.filter(user=user)
return WebAuthnDevice.objects.filter(user=self.request.user)
class WebAuthnAdminDeviceViewSet(ReadOnlyModelViewSet): class WebAuthnAdminDeviceViewSet(ReadOnlyModelViewSet):

View File

@ -1,4 +1,5 @@
"""ConsentStage API Views""" """ConsentStage API Views"""
from guardian.utils import get_anonymous_user
from rest_framework import mixins from rest_framework import mixins
from rest_framework.viewsets import GenericViewSet, ModelViewSet from rest_framework.viewsets import GenericViewSet, ModelViewSet
@ -50,8 +51,7 @@ class UserConsentViewSet(
ordering = ["application", "expires"] ordering = ["application", "expires"]
def get_queryset(self): def get_queryset(self):
if not self.request: user = self.request.user if self.request else get_anonymous_user()
if user.is_superuser:
return super().get_queryset() return super().get_queryset()
if self.request.user.is_superuser: return super().get_queryset().filter(user=user)
return super().get_queryset()
return super().get_queryset().filter(user=self.request.user)

View File

@ -68,7 +68,7 @@ def send_mail(
messages=["Successfully sent Mail."], messages=["Successfully sent Mail."],
) )
) )
except (SMTPException, ConnectionError) as exc: except (SMTPException, ConnectionError, ValueError) as exc:
LOGGER.debug("Error sending email, retrying...", exc=exc) LOGGER.debug("Error sending email, retrying...", exc=exc)
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc)) self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
raise exc raise exc

View File

@ -1,4 +1,5 @@
"""login tests""" """login tests"""
from time import sleep
from unittest.mock import patch from unittest.mock import patch
from django.test import Client, TestCase from django.test import Client, TestCase
@ -51,6 +52,31 @@ class TestUserLoginStage(TestCase):
{"to": reverse("authentik_core:root-redirect"), "type": "redirect"}, {"to": reverse("authentik_core:root-redirect"), "type": "redirect"},
) )
def test_expiry(self):
"""Test with expiry"""
self.stage.session_duration = "seconds=2"
self.stage.save()
plan = FlowPlan(
flow_pk=self.flow.pk.hex, stages=[self.stage], markers=[StageMarker()]
)
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session.save()
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
)
self.assertEqual(response.status_code, 200)
self.assertJSONEqual(
force_str(response.content),
{"to": reverse("authentik_core:root-redirect"), "type": "redirect"},
)
self.assertNotEqual(list(self.client.session.keys()), [])
sleep(3)
self.client.session.clear_expired()
self.assertEqual(list(self.client.session.keys()), [])
@patch( @patch(
"authentik.flows.views.to_stage_response", "authentik.flows.views.to_stage_response",
TO_STAGE_RESPONSE_MOCK, TO_STAGE_RESPONSE_MOCK,

View File

@ -23,6 +23,11 @@ export function configureSentry(canDoPpi: boolean = false): Promise<Config> {
if (hint.originalException instanceof SentryIgnoredError) { if (hint.originalException instanceof SentryIgnoredError) {
return null; return null;
} }
if (hint.originalException instanceof Error) {
if (hint.originalException.name == 'NetworkError') {
return null;
}
}
if (hint.originalException instanceof Response) { if (hint.originalException instanceof Response) {
const response = hint.originalException as Response; const response = hint.originalException as Response;
// We only care about server errors // We only care about server errors

View File

@ -70,7 +70,6 @@ export class ServiceConnectionDockerForm extends Form<DockerServiceConnection> {
</ak-form-element-horizontal> </ak-form-element-horizontal>
<ak-form-element-horizontal <ak-form-element-horizontal
label=${t`TLS Verification Certificate`} label=${t`TLS Verification Certificate`}
?required=${true}
name="tlsVerification"> name="tlsVerification">
<select class="pf-c-form-control"> <select class="pf-c-form-control">
<option value="" ?selected=${this.sc?.tlsVerification === undefined}>---------</option> <option value="" ?selected=${this.sc?.tlsVerification === undefined}>---------</option>
@ -86,7 +85,6 @@ export class ServiceConnectionDockerForm extends Form<DockerServiceConnection> {
</ak-form-element-horizontal> </ak-form-element-horizontal>
<ak-form-element-horizontal <ak-form-element-horizontal
label=${t`TLS Authentication Certificate`} label=${t`TLS Authentication Certificate`}
?required=${true}
name="tlsAuthentication"> name="tlsAuthentication">
<select class="pf-c-form-control"> <select class="pf-c-form-control">
<option value="" ?selected=${this.sc?.tlsAuthentication === undefined}>---------</option> <option value="" ?selected=${this.sc?.tlsAuthentication === undefined}>---------</option>