diff --git a/authentik/api/auth.py b/authentik/api/auth.py index d25db3629..27060a7e8 100644 --- a/authentik/api/auth.py +++ b/authentik/api/auth.py @@ -4,6 +4,7 @@ from binascii import Error from typing import Any, Optional, Union from rest_framework.authentication import BaseAuthentication, get_authorization_header +from rest_framework.exceptions import AuthenticationFailed from rest_framework.request import Request from structlog.stdlib import get_logger @@ -14,7 +15,7 @@ LOGGER = get_logger() # pylint: disable=too-many-return-statements def token_from_header(raw_header: bytes) -> Optional[Token]: - """raw_header in the Format of `Basic dGVzdDp0ZXN0`""" + """raw_header in the Format of `Bearer dGVzdDp0ZXN0`""" auth_credentials = raw_header.decode() if auth_credentials == "": return None @@ -25,28 +26,27 @@ def token_from_header(raw_header: bytes) -> Optional[Token]: auth_type, body = plain.split() auth_credentials = f"{auth_type} {b64encode(body.encode()).decode()}" except (UnicodeDecodeError, Error): - return None + raise AuthenticationFailed("Malformed header") auth_type, auth_credentials = auth_credentials.split() if auth_type.lower() not in ["basic", "bearer"]: LOGGER.debug("Unsupported authentication type, denying", type=auth_type.lower()) - return None + raise AuthenticationFailed("Unsupported authentication type") password = auth_credentials if auth_type.lower() == "basic": try: auth_credentials = b64decode(auth_credentials.encode()).decode() except (UnicodeDecodeError, Error): - return None + raise AuthenticationFailed("Malformed header") # Accept credentials with username and without if ":" in auth_credentials: _, password = auth_credentials.split(":") else: password = auth_credentials if password == "": # nosec - return None + raise AuthenticationFailed("Malformed header") tokens = Token.filter_not_expired(key=password, intent=TokenIntents.INTENT_API) if not tokens.exists(): - LOGGER.debug("Token not found") - return None + raise AuthenticationFailed("Token invalid/expired") return tokens.first() @@ -58,6 +58,7 @@ class AuthentikTokenAuthentication(BaseAuthentication): auth = get_authorization_header(request) token = token_from_header(auth) + # None is only returned when the header isn't set. if not token: return None diff --git a/authentik/api/tests/test_auth.py b/authentik/api/tests/test_auth.py index 558bf603c..4f6a6120f 100644 --- a/authentik/api/tests/test_auth.py +++ b/authentik/api/tests/test_auth.py @@ -3,6 +3,7 @@ from base64 import b64encode from django.test import TestCase from guardian.shortcuts import get_anonymous_user +from rest_framework.exceptions import AuthenticationFailed from authentik.api.auth import token_from_header from authentik.core.models import Token, TokenIntents @@ -28,17 +29,21 @@ class TestAPIAuth(TestCase): def test_invalid_type(self): """Test invalid type""" - self.assertIsNone(token_from_header("foo bar".encode())) + with self.assertRaises(AuthenticationFailed): + token_from_header("foo bar".encode()) def test_invalid_decode(self): """Test invalid bas64""" - self.assertIsNone(token_from_header("Basic bar".encode())) + with self.assertRaises(AuthenticationFailed): + token_from_header("Basic bar".encode()) def test_invalid_empty_password(self): """Test invalid with empty password""" - self.assertIsNone(token_from_header("Basic :".encode())) + with self.assertRaises(AuthenticationFailed): + token_from_header("Basic :".encode()) def test_invalid_no_token(self): """Test invalid with no token""" - auth = b64encode(":abc".encode()).decode() - self.assertIsNone(token_from_header(f"Basic :{auth}".encode())) + with self.assertRaises(AuthenticationFailed): + auth = b64encode(":abc".encode()).decode() + self.assertIsNone(token_from_header(f"Basic :{auth}".encode())) diff --git a/authentik/core/channels.py b/authentik/core/channels.py index 9b21d7da5..cb124e4b4 100644 --- a/authentik/core/channels.py +++ b/authentik/core/channels.py @@ -1,6 +1,7 @@ """Channels base classes""" from channels.exceptions import DenyConnection from channels.generic.websocket import JsonWebsocketConsumer +from rest_framework.exceptions import AuthenticationFailed from structlog.stdlib import get_logger from authentik.api.auth import token_from_header @@ -22,9 +23,13 @@ class AuthJsonConsumer(JsonWebsocketConsumer): raw_header = headers[b"authorization"] - token = token_from_header(raw_header) - if not token: - LOGGER.warning("Failed to authenticate") + try: + token = token_from_header(raw_header) + # token is only None when no header was given, in which case we deny too + if not token: + raise DenyConnection() + except AuthenticationFailed as exc: + LOGGER.warning("Failed to authenticate", exc=exc) raise DenyConnection() self.user = token.user