diff --git a/authentik/api/v2/config.py b/authentik/api/v2/config.py index e6de567f7..fe291fe9d 100644 --- a/authentik/api/v2/config.py +++ b/authentik/api/v2/config.py @@ -43,7 +43,7 @@ class ConfigView(APIView): deb_test = settings.DEBUG or settings.TEST if path.ismount(settings.MEDIA_ROOT) or deb_test: caps.append(Capabilities.CAN_SAVE_MEDIA) - if GEOIP_READER: + if GEOIP_READER.enabled: caps.append(Capabilities.CAN_GEO_IP) return caps diff --git a/authentik/core/api/authenticated_sessions.py b/authentik/core/api/authenticated_sessions.py index 07e44e26a..c3cdb69a7 100644 --- a/authentik/core/api/authenticated_sessions.py +++ b/authentik/core/api/authenticated_sessions.py @@ -2,7 +2,6 @@ from typing import Optional, TypedDict from django_filters.rest_framework import DjangoFilterBackend -from geoip2.errors import GeoIP2Error from guardian.utils import get_anonymous_user from rest_framework import mixins from rest_framework.fields import SerializerMethodField @@ -13,7 +12,7 @@ from rest_framework.viewsets import GenericViewSet from ua_parser import user_agent_parser from authentik.core.models import AuthenticatedSession -from authentik.events.geo import GEOIP_READER +from authentik.events.geo import GEOIP_READER, GeoIPDict class UserAgentDeviceDict(TypedDict): @@ -52,15 +51,6 @@ class UserAgentDict(TypedDict): string: str -class GeoIPDict(TypedDict): - """GeoIP Details""" - - continent: str - country: str - lat: float - long: float - - class AuthenticatedSessionSerializer(ModelSerializer): """AuthenticatedSession Serializer""" @@ -81,18 +71,7 @@ class AuthenticatedSessionSerializer(ModelSerializer): self, instance: AuthenticatedSession ) -> Optional[GeoIPDict]: # pragma: no cover """Get parsed user agent""" - if not GEOIP_READER: - return None - try: - city = GEOIP_READER.city(instance.last_ip) - return { - "continent": city.continent.code, - "country": city.country.iso_code, - "lat": city.location.latitude, - "long": city.location.longitude, - } - except (GeoIP2Error, ValueError): - return None + return GEOIP_READER.city_dict(instance.last_ip) class Meta: diff --git a/authentik/events/geo.py b/authentik/events/geo.py index ac1eecabe..042cfb178 100644 --- a/authentik/events/geo.py +++ b/authentik/events/geo.py @@ -1,7 +1,12 @@ """events GeoIP Reader""" -from typing import Optional +from datetime import datetime +from os import stat +from time import time +from typing import Optional, TypedDict from geoip2.database import Reader +from geoip2.errors import GeoIP2Error +from geoip2.models import City from structlog.stdlib import get_logger from authentik.lib.config import CONFIG @@ -9,17 +14,78 @@ from authentik.lib.config import CONFIG LOGGER = get_logger() -def get_geoip_reader() -> Optional[Reader]: - """Get GeoIP Reader, if configured, otherwise none""" - path = CONFIG.y("authentik.geoip") - if path == "" or not path: - return None - try: - reader = Reader(path) - LOGGER.info("Enabled GeoIP support") - return reader - except OSError: - return None +class GeoIPDict(TypedDict): + """GeoIP Details""" + + continent: str + country: str + lat: float + long: float + city: str -GEOIP_READER = get_geoip_reader() +class GeoIPReader: + """Slim wrapper around GeoIP API""" + + __reader: Optional[Reader] = None + __last_mtime: float = 0.0 + + def __init__(self): + self.__open() + + def __open(self): + """Get GeoIP Reader, if configured, otherwise none""" + path = CONFIG.y("authentik.geoip") + if path == "" or not path: + return + try: + reader = Reader(path) + LOGGER.info("Loaded GeoIP database") + self.__reader = reader + self.__last_mtime = stat(path).st_mtime + except OSError as exc: + LOGGER.warning("Failed to load GeoIP database", exc=exc) + + def __check_expired(self): + """Check if the geoip database has been opened longer than 8 hours, + and re-open it, as it will probably will have been re-downloaded""" + now = time() + diff = datetime.fromtimestamp(now) - datetime.fromtimestamp(self.__last_mtime) + diff_hours = diff.total_seconds() // 3600 + if diff_hours >= 8: + LOGGER.info("GeoIP databased loaded too long, re-opening", diff=diff) + self.__open() + + @property + def enabled(self) -> bool: + """Check if GeoIP is enabled""" + return bool(self.__reader) + + def city(self, ip_address: str) -> Optional[City]: + """Wrapper for Reader.city""" + if not self.enabled: + return None + self.__check_expired() + try: + return self.__reader.city(ip_address) + except (GeoIP2Error, ValueError): + return None + + def city_dict(self, ip_address: str) -> Optional[GeoIPDict]: + """Wrapper for self.city that returns a dict""" + city = self.city(ip_address) + if not city: + return None + city_dict: GeoIPDict = { + "continent": city.continent.code, + "country": city.country.iso_code, + "lat": city.location.latitude, + "long": city.location.longitude, + "city": "", + } + if city.city.name: + city_dict["city"] = city.city.name + return city_dict + + +GEOIP_READER = GeoIPReader() diff --git a/authentik/events/models.py b/authentik/events/models.py index fa9392475..c543a3907 100644 --- a/authentik/events/models.py +++ b/authentik/events/models.py @@ -10,7 +10,6 @@ from django.db import models from django.http import HttpRequest from django.utils.timezone import now from django.utils.translation import gettext as _ -from geoip2.errors import GeoIP2Error from prometheus_client import Gauge from requests import RequestException, post from structlog.stdlib import get_logger @@ -160,20 +159,10 @@ class Event(ExpiringModel): def with_geoip(self): # pragma: no cover """Apply GeoIP Data, when enabled""" - if not GEOIP_READER: + city = GEOIP_READER.city_dict(self.client_ip) + if not city: return - try: - response = GEOIP_READER.city(self.client_ip) - self.context["geo"] = { - "continent": response.continent.code, - "country": response.country.iso_code, - "lat": response.location.latitude, - "long": response.location.longitude, - } - if response.city.name: - self.context["geo"]["city"] = response.city.name - except (GeoIP2Error, ValueError) as exc: - LOGGER.warning("Failed to add geoIP Data to event", exc=exc) + self.context["geo"] = city def _set_prom_metrics(self): GAUGE_EVENTS.labels( diff --git a/authentik/events/tests/test_geoip.py b/authentik/events/tests/test_geoip.py new file mode 100644 index 000000000..3120dacae --- /dev/null +++ b/authentik/events/tests/test_geoip.py @@ -0,0 +1,26 @@ +"""Test GeoIP Wrapper""" +from django.test import TestCase + +from authentik.events.geo import GeoIPReader + + +class TestGeoIP(TestCase): + """Test GeoIP Wrapper""" + + def setUp(self) -> None: + self.reader = GeoIPReader() + + def test_simple(self): + """Test simple city wrapper""" + # IPs from + # https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json + self.assertEqual( + self.reader.city_dict("2.125.160.216"), + { + "city": "Boxford", + "continent": "EU", + "country": "GB", + "lat": 51.75, + "long": -1.25, + }, + ) diff --git a/authentik/policies/types.py b/authentik/policies/types.py index 6307aba9c..4a6e05a2e 100644 --- a/authentik/policies/types.py +++ b/authentik/policies/types.py @@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Optional from django.db.models import Model from django.http import HttpRequest -from geoip2.errors import GeoIP2Error from structlog.stdlib import get_logger from authentik.events.geo import GEOIP_READER @@ -39,16 +38,12 @@ class PolicyRequest: def set_http_request(self, request: HttpRequest): # pragma: no cover """Load data from HTTP request, including geoip when enabled""" self.http_request = request - if not GEOIP_READER: + if not GEOIP_READER.enabled: return - try: - client_ip = get_client_ip(request) - if not client_ip: - return - response = GEOIP_READER.city(client_ip) - self.context["geoip"] = response - except (GeoIP2Error, ValueError) as exc: - LOGGER.warning("failed to get geoip data", exc=exc) + client_ip = get_client_ip(request) + if not client_ip: + return + self.context["geoip"] = GEOIP_READER.city(client_ip) def __str__(self): text = f"