core: add tests for authenticated sessions
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
fc45d35699
commit
b9773d39c0
|
@ -1,10 +1,12 @@
|
||||||
"""AuthenticatedSessions API Viewset"""
|
"""AuthenticatedSessions API Viewset"""
|
||||||
from typing import Optional, TypedDict
|
from typing import Optional, TypedDict
|
||||||
|
|
||||||
|
from django_filters.rest_framework import DjangoFilterBackend
|
||||||
from geoip2.errors import GeoIP2Error
|
from geoip2.errors import GeoIP2Error
|
||||||
from guardian.utils import get_anonymous_user
|
from guardian.utils import get_anonymous_user
|
||||||
from rest_framework import mixins
|
from rest_framework import mixins
|
||||||
from rest_framework.fields import SerializerMethodField
|
from rest_framework.fields import SerializerMethodField
|
||||||
|
from rest_framework.filters import OrderingFilter, SearchFilter
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.serializers import ModelSerializer
|
from rest_framework.serializers import ModelSerializer
|
||||||
from rest_framework.viewsets import GenericViewSet
|
from rest_framework.viewsets import GenericViewSet
|
||||||
|
@ -75,7 +77,9 @@ class AuthenticatedSessionSerializer(ModelSerializer):
|
||||||
"""Get parsed user agent"""
|
"""Get parsed user agent"""
|
||||||
return user_agent_parser.Parse(instance.last_user_agent)
|
return user_agent_parser.Parse(instance.last_user_agent)
|
||||||
|
|
||||||
def get_geo_ip(self, instance: AuthenticatedSession) -> Optional[GeoIPDict]:
|
def get_geo_ip(
|
||||||
|
self, instance: AuthenticatedSession
|
||||||
|
) -> Optional[GeoIPDict]: # pragma: no cover
|
||||||
"""Get parsed user agent"""
|
"""Get parsed user agent"""
|
||||||
if not GEOIP_READER:
|
if not GEOIP_READER:
|
||||||
return None
|
return None
|
||||||
|
@ -87,7 +91,7 @@ class AuthenticatedSessionSerializer(ModelSerializer):
|
||||||
"lat": city.location.latitude,
|
"lat": city.location.latitude,
|
||||||
"long": city.location.longitude,
|
"long": city.location.longitude,
|
||||||
}
|
}
|
||||||
except GeoIP2Error:
|
except (GeoIP2Error, ValueError):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
@ -119,6 +123,11 @@ class AuthenticatedSessionViewSet(
|
||||||
search_fields = ["user__username", "last_ip", "last_user_agent"]
|
search_fields = ["user__username", "last_ip", "last_user_agent"]
|
||||||
filterset_fields = ["user__username", "last_ip", "last_user_agent"]
|
filterset_fields = ["user__username", "last_ip", "last_user_agent"]
|
||||||
ordering = ["user__username"]
|
ordering = ["user__username"]
|
||||||
|
filter_backends = [
|
||||||
|
DjangoFilterBackend,
|
||||||
|
OrderingFilter,
|
||||||
|
SearchFilter,
|
||||||
|
]
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
user = self.request.user if self.request else get_anonymous_user()
|
user = self.request.user if self.request else get_anonymous_user()
|
||||||
|
|
31
authentik/core/tests/test_authenticated_sessions_api.py
Normal file
31
authentik/core/tests/test_authenticated_sessions_api.py
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
"""Test AuthenticatedSessions API"""
|
||||||
|
from json import loads
|
||||||
|
|
||||||
|
from django.urls.base import reverse
|
||||||
|
from django.utils.encoding import force_str
|
||||||
|
from rest_framework.test import APITestCase
|
||||||
|
|
||||||
|
from authentik.core.models import User
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthenticatedSessionsAPI(APITestCase):
|
||||||
|
"""Test AuthenticatedSessions API"""
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
super().setUp()
|
||||||
|
self.user = User.objects.get(username="akadmin")
|
||||||
|
self.other_user = User.objects.create(username="normal-user")
|
||||||
|
|
||||||
|
def test_list(self):
|
||||||
|
"""Test session list endpoint"""
|
||||||
|
self.client.force_login(self.user)
|
||||||
|
response = self.client.get(reverse("authentik_api:authenticatedsession-list"))
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
def test_non_admin_list(self):
|
||||||
|
"""Test non-admin list"""
|
||||||
|
self.client.force_login(self.other_user)
|
||||||
|
response = self.client.get(reverse("authentik_api:authenticatedsession-list"))
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
body = loads(force_str(response.content))
|
||||||
|
self.assertEqual(body["pagination"]["count"], 1)
|
|
@ -158,7 +158,7 @@ class Event(ExpiringModel):
|
||||||
self.save()
|
self.save()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_geoip(self):
|
def with_geoip(self): # pragma: no cover
|
||||||
"""Apply GeoIP Data, when enabled"""
|
"""Apply GeoIP Data, when enabled"""
|
||||||
if not GEOIP_READER:
|
if not GEOIP_READER:
|
||||||
return
|
return
|
||||||
|
@ -172,7 +172,7 @@ class Event(ExpiringModel):
|
||||||
}
|
}
|
||||||
if response.city.name:
|
if response.city.name:
|
||||||
self.context["geo"]["city"] = response.city.name
|
self.context["geo"]["city"] = response.city.name
|
||||||
except GeoIP2Error as exc:
|
except (GeoIP2Error, ValueError) as exc:
|
||||||
LOGGER.warning("Failed to add geoIP Data to event", exc=exc)
|
LOGGER.warning("Failed to add geoIP Data to event", exc=exc)
|
||||||
|
|
||||||
def _set_prom_metrics(self):
|
def _set_prom_metrics(self):
|
||||||
|
|
|
@ -36,7 +36,7 @@ class PolicyRequest:
|
||||||
self.obj = None
|
self.obj = None
|
||||||
self.context = {}
|
self.context = {}
|
||||||
|
|
||||||
def set_http_request(self, request: HttpRequest):
|
def set_http_request(self, request: HttpRequest): # pragma: no cover
|
||||||
"""Load data from HTTP request, including geoip when enabled"""
|
"""Load data from HTTP request, including geoip when enabled"""
|
||||||
self.http_request = request
|
self.http_request = request
|
||||||
if not GEOIP_READER:
|
if not GEOIP_READER:
|
||||||
|
@ -47,7 +47,7 @@ class PolicyRequest:
|
||||||
return
|
return
|
||||||
response = GEOIP_READER.city(client_ip)
|
response = GEOIP_READER.city(client_ip)
|
||||||
self.context["geoip"] = response
|
self.context["geoip"] = response
|
||||||
except GeoIP2Error as exc:
|
except (GeoIP2Error, ValueError) as exc:
|
||||||
LOGGER.warning("failed to get geoip data", exc=exc)
|
LOGGER.warning("failed to get geoip data", exc=exc)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
|
Reference in a new issue