events: rewrite GeoIP to a wrapper, reload file every 8 hours

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-06-06 00:38:14 +02:00
parent f5dbdbd48b
commit 17326615b7
7 changed files with 117 additions and 61 deletions

View file

@ -43,7 +43,7 @@ class ConfigView(APIView):
deb_test = settings.DEBUG or settings.TEST deb_test = settings.DEBUG or settings.TEST
if path.ismount(settings.MEDIA_ROOT) or deb_test: if path.ismount(settings.MEDIA_ROOT) or deb_test:
caps.append(Capabilities.CAN_SAVE_MEDIA) caps.append(Capabilities.CAN_SAVE_MEDIA)
if GEOIP_READER: if GEOIP_READER.enabled:
caps.append(Capabilities.CAN_GEO_IP) caps.append(Capabilities.CAN_GEO_IP)
return caps return caps

View file

@ -2,7 +2,6 @@
from typing import Optional, TypedDict from typing import Optional, TypedDict
from django_filters.rest_framework import DjangoFilterBackend from django_filters.rest_framework import DjangoFilterBackend
from geoip2.errors import GeoIP2Error
from guardian.utils import get_anonymous_user from guardian.utils import get_anonymous_user
from rest_framework import mixins from rest_framework import mixins
from rest_framework.fields import SerializerMethodField from rest_framework.fields import SerializerMethodField
@ -13,7 +12,7 @@ from rest_framework.viewsets import GenericViewSet
from ua_parser import user_agent_parser from ua_parser import user_agent_parser
from authentik.core.models import AuthenticatedSession from authentik.core.models import AuthenticatedSession
from authentik.events.geo import GEOIP_READER from authentik.events.geo import GEOIP_READER, GeoIPDict
class UserAgentDeviceDict(TypedDict): class UserAgentDeviceDict(TypedDict):
@ -52,15 +51,6 @@ class UserAgentDict(TypedDict):
string: str string: str
class GeoIPDict(TypedDict):
"""GeoIP Details"""
continent: str
country: str
lat: float
long: float
class AuthenticatedSessionSerializer(ModelSerializer): class AuthenticatedSessionSerializer(ModelSerializer):
"""AuthenticatedSession Serializer""" """AuthenticatedSession Serializer"""
@ -81,18 +71,7 @@ class AuthenticatedSessionSerializer(ModelSerializer):
self, instance: AuthenticatedSession self, instance: AuthenticatedSession
) -> Optional[GeoIPDict]: # pragma: no cover ) -> Optional[GeoIPDict]: # pragma: no cover
"""Get parsed user agent""" """Get parsed user agent"""
if not GEOIP_READER: return GEOIP_READER.city_dict(instance.last_ip)
return None
try:
city = GEOIP_READER.city(instance.last_ip)
return {
"continent": city.continent.code,
"country": city.country.iso_code,
"lat": city.location.latitude,
"long": city.location.longitude,
}
except (GeoIP2Error, ValueError):
return None
class Meta: class Meta:

View file

@ -1,7 +1,12 @@
"""events GeoIP Reader""" """events GeoIP Reader"""
from typing import Optional from datetime import datetime
from os import stat
from time import time
from typing import Optional, TypedDict
from geoip2.database import Reader from geoip2.database import Reader
from geoip2.errors import GeoIP2Error
from geoip2.models import City
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
@ -9,17 +14,78 @@ from authentik.lib.config import CONFIG
LOGGER = get_logger() LOGGER = get_logger()
def get_geoip_reader() -> Optional[Reader]: class GeoIPDict(TypedDict):
"""Get GeoIP Reader, if configured, otherwise none""" """GeoIP Details"""
path = CONFIG.y("authentik.geoip")
if path == "" or not path: continent: str
return None country: str
try: lat: float
reader = Reader(path) long: float
LOGGER.info("Enabled GeoIP support") city: str
return reader
except OSError:
return None
GEOIP_READER = get_geoip_reader() class GeoIPReader:
"""Slim wrapper around GeoIP API"""
__reader: Optional[Reader] = None
__last_mtime: float = 0.0
def __init__(self):
self.__open()
def __open(self):
"""Get GeoIP Reader, if configured, otherwise none"""
path = CONFIG.y("authentik.geoip")
if path == "" or not path:
return
try:
reader = Reader(path)
LOGGER.info("Loaded GeoIP database")
self.__reader = reader
self.__last_mtime = stat(path).st_mtime
except OSError as exc:
LOGGER.warning("Failed to load GeoIP database", exc=exc)
def __check_expired(self):
"""Check if the geoip database has been opened longer than 8 hours,
and re-open it, as it will probably will have been re-downloaded"""
now = time()
diff = datetime.fromtimestamp(now) - datetime.fromtimestamp(self.__last_mtime)
diff_hours = diff.total_seconds() // 3600
if diff_hours >= 8:
LOGGER.info("GeoIP databased loaded too long, re-opening", diff=diff)
self.__open()
@property
def enabled(self) -> bool:
"""Check if GeoIP is enabled"""
return bool(self.__reader)
def city(self, ip_address: str) -> Optional[City]:
"""Wrapper for Reader.city"""
if not self.enabled:
return None
self.__check_expired()
try:
return self.__reader.city(ip_address)
except (GeoIP2Error, ValueError):
return None
def city_dict(self, ip_address: str) -> Optional[GeoIPDict]:
"""Wrapper for self.city that returns a dict"""
city = self.city(ip_address)
if not city:
return None
city_dict: GeoIPDict = {
"continent": city.continent.code,
"country": city.country.iso_code,
"lat": city.location.latitude,
"long": city.location.longitude,
"city": "",
}
if city.city.name:
city_dict["city"] = city.city.name
return city_dict
GEOIP_READER = GeoIPReader()

View file

@ -10,7 +10,6 @@ from django.db import models
from django.http import HttpRequest from django.http import HttpRequest
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from geoip2.errors import GeoIP2Error
from prometheus_client import Gauge from prometheus_client import Gauge
from requests import RequestException, post from requests import RequestException, post
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -160,20 +159,10 @@ class Event(ExpiringModel):
def with_geoip(self): # pragma: no cover def with_geoip(self): # pragma: no cover
"""Apply GeoIP Data, when enabled""" """Apply GeoIP Data, when enabled"""
if not GEOIP_READER: city = GEOIP_READER.city_dict(self.client_ip)
if not city:
return return
try: self.context["geo"] = city
response = GEOIP_READER.city(self.client_ip)
self.context["geo"] = {
"continent": response.continent.code,
"country": response.country.iso_code,
"lat": response.location.latitude,
"long": response.location.longitude,
}
if response.city.name:
self.context["geo"]["city"] = response.city.name
except (GeoIP2Error, ValueError) as exc:
LOGGER.warning("Failed to add geoIP Data to event", exc=exc)
def _set_prom_metrics(self): def _set_prom_metrics(self):
GAUGE_EVENTS.labels( GAUGE_EVENTS.labels(

View file

@ -0,0 +1,26 @@
"""Test GeoIP Wrapper"""
from django.test import TestCase
from authentik.events.geo import GeoIPReader
class TestGeoIP(TestCase):
"""Test GeoIP Wrapper"""
def setUp(self) -> None:
self.reader = GeoIPReader()
def test_simple(self):
"""Test simple city wrapper"""
# IPs from
# https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
self.assertEqual(
self.reader.city_dict("2.125.160.216"),
{
"city": "Boxford",
"continent": "EU",
"country": "GB",
"lat": 51.75,
"long": -1.25,
},
)

View file

@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Optional
from django.db.models import Model from django.db.models import Model
from django.http import HttpRequest from django.http import HttpRequest
from geoip2.errors import GeoIP2Error
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.events.geo import GEOIP_READER from authentik.events.geo import GEOIP_READER
@ -39,16 +38,12 @@ class PolicyRequest:
def set_http_request(self, request: HttpRequest): # pragma: no cover def set_http_request(self, request: HttpRequest): # pragma: no cover
"""Load data from HTTP request, including geoip when enabled""" """Load data from HTTP request, including geoip when enabled"""
self.http_request = request self.http_request = request
if not GEOIP_READER: if not GEOIP_READER.enabled:
return return
try: client_ip = get_client_ip(request)
client_ip = get_client_ip(request) if not client_ip:
if not client_ip: return
return self.context["geoip"] = GEOIP_READER.city(client_ip)
response = GEOIP_READER.city(client_ip)
self.context["geoip"] = response
except (GeoIP2Error, ValueError) as exc:
LOGGER.warning("failed to get geoip data", exc=exc)
def __str__(self): def __str__(self):
text = f"<PolicyRequest user={self.user}" text = f"<PolicyRequest user={self.user}"

View file

@ -14,6 +14,7 @@ class PytestTestRunner: # pragma: no cover
settings.TEST = True settings.TEST = True
settings.CELERY_TASK_ALWAYS_EAGER = True settings.CELERY_TASK_ALWAYS_EAGER = True
CONFIG.y_set("authentik.avatars", "none") CONFIG.y_set("authentik.avatars", "none")
CONFIG.y_set("authentik.geoip", "tests/GeoLite2-City-Test.mmdb")
def run_tests(self, test_labels): def run_tests(self, test_labels):
"""Run pytest and return the exitcode. """Run pytest and return the exitcode.