api: make 401 messages clearer
closes #755 Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
837d2f6fab
commit
464a1c0536
|
@ -4,6 +4,7 @@ from binascii import Error
|
|||
from typing import Any, Optional, Union
|
||||
|
||||
from rest_framework.authentication import BaseAuthentication, get_authorization_header
|
||||
from rest_framework.exceptions import AuthenticationFailed
|
||||
from rest_framework.request import Request
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
|
@ -14,7 +15,7 @@ LOGGER = get_logger()
|
|||
|
||||
# pylint: disable=too-many-return-statements
|
||||
def token_from_header(raw_header: bytes) -> Optional[Token]:
|
||||
"""raw_header in the Format of `Basic dGVzdDp0ZXN0`"""
|
||||
"""raw_header in the Format of `Bearer dGVzdDp0ZXN0`"""
|
||||
auth_credentials = raw_header.decode()
|
||||
if auth_credentials == "":
|
||||
return None
|
||||
|
@ -25,28 +26,27 @@ def token_from_header(raw_header: bytes) -> Optional[Token]:
|
|||
auth_type, body = plain.split()
|
||||
auth_credentials = f"{auth_type} {b64encode(body.encode()).decode()}"
|
||||
except (UnicodeDecodeError, Error):
|
||||
return None
|
||||
raise AuthenticationFailed("Malformed header")
|
||||
auth_type, auth_credentials = auth_credentials.split()
|
||||
if auth_type.lower() not in ["basic", "bearer"]:
|
||||
LOGGER.debug("Unsupported authentication type, denying", type=auth_type.lower())
|
||||
return None
|
||||
raise AuthenticationFailed("Unsupported authentication type")
|
||||
password = auth_credentials
|
||||
if auth_type.lower() == "basic":
|
||||
try:
|
||||
auth_credentials = b64decode(auth_credentials.encode()).decode()
|
||||
except (UnicodeDecodeError, Error):
|
||||
return None
|
||||
raise AuthenticationFailed("Malformed header")
|
||||
# Accept credentials with username and without
|
||||
if ":" in auth_credentials:
|
||||
_, password = auth_credentials.split(":")
|
||||
else:
|
||||
password = auth_credentials
|
||||
if password == "": # nosec
|
||||
return None
|
||||
raise AuthenticationFailed("Malformed header")
|
||||
tokens = Token.filter_not_expired(key=password, intent=TokenIntents.INTENT_API)
|
||||
if not tokens.exists():
|
||||
LOGGER.debug("Token not found")
|
||||
return None
|
||||
raise AuthenticationFailed("Token invalid/expired")
|
||||
return tokens.first()
|
||||
|
||||
|
||||
|
@ -58,6 +58,7 @@ class AuthentikTokenAuthentication(BaseAuthentication):
|
|||
auth = get_authorization_header(request)
|
||||
|
||||
token = token_from_header(auth)
|
||||
# None is only returned when the header isn't set.
|
||||
if not token:
|
||||
return None
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ from base64 import b64encode
|
|||
|
||||
from django.test import TestCase
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
from rest_framework.exceptions import AuthenticationFailed
|
||||
|
||||
from authentik.api.auth import token_from_header
|
||||
from authentik.core.models import Token, TokenIntents
|
||||
|
@ -28,17 +29,21 @@ class TestAPIAuth(TestCase):
|
|||
|
||||
def test_invalid_type(self):
|
||||
"""Test invalid type"""
|
||||
self.assertIsNone(token_from_header("foo bar".encode()))
|
||||
with self.assertRaises(AuthenticationFailed):
|
||||
token_from_header("foo bar".encode())
|
||||
|
||||
def test_invalid_decode(self):
|
||||
"""Test invalid bas64"""
|
||||
self.assertIsNone(token_from_header("Basic bar".encode()))
|
||||
with self.assertRaises(AuthenticationFailed):
|
||||
token_from_header("Basic bar".encode())
|
||||
|
||||
def test_invalid_empty_password(self):
|
||||
"""Test invalid with empty password"""
|
||||
self.assertIsNone(token_from_header("Basic :".encode()))
|
||||
with self.assertRaises(AuthenticationFailed):
|
||||
token_from_header("Basic :".encode())
|
||||
|
||||
def test_invalid_no_token(self):
|
||||
"""Test invalid with no token"""
|
||||
auth = b64encode(":abc".encode()).decode()
|
||||
self.assertIsNone(token_from_header(f"Basic :{auth}".encode()))
|
||||
with self.assertRaises(AuthenticationFailed):
|
||||
auth = b64encode(":abc".encode()).decode()
|
||||
self.assertIsNone(token_from_header(f"Basic :{auth}".encode()))
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Channels base classes"""
|
||||
from channels.exceptions import DenyConnection
|
||||
from channels.generic.websocket import JsonWebsocketConsumer
|
||||
from rest_framework.exceptions import AuthenticationFailed
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.api.auth import token_from_header
|
||||
|
@ -22,9 +23,13 @@ class AuthJsonConsumer(JsonWebsocketConsumer):
|
|||
|
||||
raw_header = headers[b"authorization"]
|
||||
|
||||
token = token_from_header(raw_header)
|
||||
if not token:
|
||||
LOGGER.warning("Failed to authenticate")
|
||||
try:
|
||||
token = token_from_header(raw_header)
|
||||
# token is only None when no header was given, in which case we deny too
|
||||
if not token:
|
||||
raise DenyConnection()
|
||||
except AuthenticationFailed as exc:
|
||||
LOGGER.warning("Failed to authenticate", exc=exc)
|
||||
raise DenyConnection()
|
||||
|
||||
self.user = token.user
|
||||
|
|
Reference in a new issue