*: use common user agent for all outgoing requests

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-09-11 20:54:58 +02:00
parent 7e7ef289ba
commit c779ad2e3b
10 changed files with 82 additions and 88 deletions

View file

@ -6,13 +6,14 @@ from django.core.cache import cache
from django.core.validators import URLValidator from django.core.validators import URLValidator
from packaging.version import parse from packaging.version import parse
from prometheus_client import Info from prometheus_client import Info
from requests import RequestException, get from requests import RequestException
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik import ENV_GIT_HASH_KEY, __version__ from authentik import ENV_GIT_HASH_KEY, __version__
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.lib.utils.http import get_http_session
from authentik.root.celery import CELERY_APP from authentik.root.celery import CELERY_APP
LOGGER = get_logger() LOGGER = get_logger()
@ -42,7 +43,9 @@ def update_latest_version(self: MonitoredTask):
self.set_status(TaskResult(TaskResultStatus.WARNING, messages=["Version check disabled."])) self.set_status(TaskResult(TaskResultStatus.WARNING, messages=["Version check disabled."]))
return return
try: try:
response = get("https://version.goauthentik.io/version.json") response = get_http_session().get(
"https://version.goauthentik.io/version.json",
)
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
upstream_version = data.get("stable", {}).get("version") upstream_version = data.get("stable", {}).get("version")
@ -62,7 +65,7 @@ def update_latest_version(self: MonitoredTask):
).exists(): ).exists():
return return
event_dict = {"new_version": upstream_version} event_dict = {"new_version": upstream_version}
if match := re.search(URL_FINDER, data.get("body", "")): if match := re.search(URL_FINDER, data.get("stable", {}).get("changelog", "")):
event_dict["message"] = f"Changelog: {match.group()}" event_dict["message"] = f"Changelog: {match.group()}"
Event.new(EventAction.UPDATE_AVAILABLE, **event_dict).save() Event.new(EventAction.UPDATE_AVAILABLE, **event_dict).save()
except (RequestException, IndexError) as exc: except (RequestException, IndexError) as exc:

View file

@ -1,56 +1,28 @@
"""test admin tasks""" """test admin tasks"""
import json
from dataclasses import dataclass
from unittest.mock import Mock, patch
from django.core.cache import cache from django.core.cache import cache
from django.test import TestCase from django.test import TestCase
from requests.exceptions import RequestException from requests_mock import Mocker
from authentik.admin.tasks import VERSION_CACHE_KEY, update_latest_version from authentik.admin.tasks import VERSION_CACHE_KEY, update_latest_version
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
RESPONSE_VALID = {
@dataclass
class MockResponse:
"""Mock class to emulate the methods of requests's Response we need"""
status_code: int
response: str
def json(self) -> dict:
"""Get json parsed response"""
return json.loads(self.response)
def raise_for_status(self):
"""raise RequestException if status code is 400 or more"""
if self.status_code >= 400:
raise RequestException
REQUEST_MOCK_VALID = Mock(
return_value=MockResponse(
200,
"""{
"$schema": "https://version.goauthentik.io/schema.json", "$schema": "https://version.goauthentik.io/schema.json",
"stable": { "stable": {
"version": "99999999.9999999", "version": "99999999.9999999",
"changelog": "See https://goauthentik.io/test", "changelog": "See https://goauthentik.io/test",
"reason": "bugfix" "reason": "bugfix",
} },
}""", }
)
)
REQUEST_MOCK_INVALID = Mock(return_value=MockResponse(400, "{}"))
class TestAdminTasks(TestCase): class TestAdminTasks(TestCase):
"""test admin tasks""" """test admin tasks"""
@patch("authentik.admin.tasks.get", REQUEST_MOCK_VALID)
def test_version_valid_response(self): def test_version_valid_response(self):
"""Test Update checker with valid response""" """Test Update checker with valid response"""
with Mocker() as mocker:
mocker.get("https://version.goauthentik.io/version.json", json=RESPONSE_VALID)
update_latest_version.delay().get() update_latest_version.delay().get()
self.assertEqual(cache.get(VERSION_CACHE_KEY), "99999999.9999999") self.assertEqual(cache.get(VERSION_CACHE_KEY), "99999999.9999999")
self.assertTrue( self.assertTrue(
@ -73,9 +45,10 @@ class TestAdminTasks(TestCase):
1, 1,
) )
@patch("authentik.admin.tasks.get", REQUEST_MOCK_INVALID)
def test_version_error(self): def test_version_error(self):
"""Test Update checker with invalid response""" """Test Update checker with invalid response"""
with Mocker() as mocker:
mocker.get("https://version.goauthentik.io/version.json", status_code=400)
update_latest_version.delay().get() update_latest_version.delay().get()
self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0") self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0")
self.assertFalse( self.assertFalse(

View file

@ -4,7 +4,6 @@ from json import loads
from django.conf import settings from django.conf import settings
from django.http.request import HttpRequest from django.http.request import HttpRequest
from django.http.response import HttpResponse from django.http.response import HttpResponse
from requests import post
from requests.exceptions import RequestException from requests.exceptions import RequestException
from rest_framework.authentication import SessionAuthentication from rest_framework.authentication import SessionAuthentication
from rest_framework.parsers import BaseParser from rest_framework.parsers import BaseParser
@ -14,6 +13,9 @@ from rest_framework.throttling import AnonRateThrottle
from rest_framework.views import APIView from rest_framework.views import APIView
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.lib.utils.http import get_http_session
SENTRY_SESSION = get_http_session()
class PlainTextParser(BaseParser): class PlainTextParser(BaseParser):
@ -54,10 +56,12 @@ class SentryTunnelView(APIView):
dsn = header.get("dsn", "") dsn = header.get("dsn", "")
if dsn != settings.SENTRY_DSN: if dsn != settings.SENTRY_DSN:
return HttpResponse(status=400) return HttpResponse(status=400)
response = post( response = SENTRY_SESSION.post(
"https://sentry.beryju.org/api/8/envelope/", "https://sentry.beryju.org/api/8/envelope/",
data=full_body, data=full_body,
headers={"Content-Type": "application/octet-stream"}, headers={
"Content-Type": "application/octet-stream",
},
) )
try: try:
response.raise_for_status() response.raise_for_status()

View file

@ -10,7 +10,7 @@ from django.db import models
from django.http import HttpRequest from django.http import HttpRequest
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from requests import RequestException, post from requests import RequestException
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik import __version__ from authentik import __version__
@ -19,7 +19,7 @@ from authentik.core.models import ExpiringModel, Group, User
from authentik.events.geo import GEOIP_READER from authentik.events.geo import GEOIP_READER
from authentik.events.utils import cleanse_dict, get_user, model_to_dict, sanitize_dict from authentik.events.utils import cleanse_dict, get_user, model_to_dict, sanitize_dict
from authentik.lib.sentry import SentryIgnoredException from authentik.lib.sentry import SentryIgnoredException
from authentik.lib.utils.http import get_client_ip from authentik.lib.utils.http import get_client_ip, get_http_session
from authentik.lib.utils.time import timedelta_from_string from authentik.lib.utils.time import timedelta_from_string
from authentik.policies.models import PolicyBindingModel from authentik.policies.models import PolicyBindingModel
from authentik.stages.email.utils import TemplateEmailMessage from authentik.stages.email.utils import TemplateEmailMessage
@ -240,7 +240,7 @@ class NotificationTransport(models.Model):
def send_webhook(self, notification: "Notification") -> list[str]: def send_webhook(self, notification: "Notification") -> list[str]:
"""Send notification to generic webhook""" """Send notification to generic webhook"""
try: try:
response = post( response = get_http_session().post(
self.webhook_url, self.webhook_url,
json={ json={
"body": notification.body, "body": notification.body,
@ -297,7 +297,7 @@ class NotificationTransport(models.Model):
if notification.event: if notification.event:
body["attachments"][0]["title"] = notification.event.action body["attachments"][0]["title"] = notification.event.action
try: try:
response = post(self.webhook_url, json=body) response = get_http_session().post(self.webhook_url, json=body)
response.raise_for_status() response.raise_for_status()
except RequestException as exc: except RequestException as exc:
text = exc.response.text if exc.response else str(exc) text = exc.response.text if exc.response else str(exc)

View file

@ -4,13 +4,13 @@ from textwrap import indent
from typing import Any, Iterable, Optional from typing import Any, Iterable, Optional
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from requests import Session
from rest_framework.serializers import ValidationError from rest_framework.serializers import ValidationError
from sentry_sdk.hub import Hub from sentry_sdk.hub import Hub
from sentry_sdk.tracing import Span from sentry_sdk.tracing import Span
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.core.models import User from authentik.core.models import User
from authentik.lib.utils.http import get_http_session
LOGGER = get_logger() LOGGER = get_logger()
@ -35,7 +35,7 @@ class BaseEvaluator:
"ak_is_group_member": BaseEvaluator.expr_is_group_member, "ak_is_group_member": BaseEvaluator.expr_is_group_member,
"ak_user_by": BaseEvaluator.expr_user_by, "ak_user_by": BaseEvaluator.expr_user_by,
"ak_logger": get_logger(), "ak_logger": get_logger(),
"requests": Session(), "requests": get_http_session(),
} }
self._context = {} self._context = {}
self._filename = "BaseEvalautor" self._filename = "BaseEvalautor"

View file

@ -1,9 +1,13 @@
"""http helpers""" """http helpers"""
from os import environ
from typing import Any, Optional from typing import Any, Optional
from django.http import HttpRequest from django.http import HttpRequest
from requests.sessions import Session
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik import ENV_GIT_HASH_KEY, __version__
OUTPOST_REMOTE_IP_HEADER = "HTTP_X_AUTHENTIK_REMOTE_IP" OUTPOST_REMOTE_IP_HEADER = "HTTP_X_AUTHENTIK_REMOTE_IP"
OUTPOST_TOKEN_HEADER = "HTTP_X_AUTHENTIK_OUTPOST_TOKEN" # nosec OUTPOST_TOKEN_HEADER = "HTTP_X_AUTHENTIK_OUTPOST_TOKEN" # nosec
DEFAULT_IP = "255.255.255.255" DEFAULT_IP = "255.255.255.255"
@ -60,3 +64,16 @@ def get_client_ip(request: Optional[HttpRequest]) -> str:
if override: if override:
return override return override
return _get_client_ip_from_meta(request.META) return _get_client_ip_from_meta(request.META)
def authentik_user_agent() -> str:
"""Get a common user agent"""
build = environ.get(ENV_GIT_HASH_KEY, "tagged")
return f"authentik@{__version__} (build={build})"
def get_http_session() -> Session:
"""Get a requests session with common headers"""
session = Session()
session.headers["User-Agent"] = authentik_user_agent()
return session

View file

@ -3,10 +3,10 @@ from hashlib import sha1
from django.db import models from django.db import models
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from requests import get
from rest_framework.serializers import BaseSerializer from rest_framework.serializers import BaseSerializer
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.lib.utils.http import get_http_session
from authentik.policies.models import Policy, PolicyResult from authentik.policies.models import Policy, PolicyResult
from authentik.policies.types import PolicyRequest from authentik.policies.types import PolicyRequest
@ -49,7 +49,7 @@ class HaveIBeenPwendPolicy(Policy):
pw_hash = sha1(password.encode("utf-8")).hexdigest() # nosec pw_hash = sha1(password.encode("utf-8")).hexdigest() # nosec
url = f"https://api.pwnedpasswords.com/range/{pw_hash[:5]}" url = f"https://api.pwnedpasswords.com/range/{pw_hash[:5]}"
result = get(url).text result = get_http_session().get(url).text
final_count = 0 final_count = 0
for line in result.split("\r\n"): for line in result.split("\r\n"):
full_hash, count = line.split(":") full_hash, count = line.split(":")

View file

@ -8,8 +8,8 @@ from requests.exceptions import RequestException
from requests.models import Response from requests.models import Response
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik import __version__
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.lib.utils.http import get_http_session
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import OAuthSource
LOGGER = get_logger() LOGGER = get_logger()
@ -27,10 +27,9 @@ class BaseOAuthClient:
def __init__(self, source: OAuthSource, request: HttpRequest, callback: Optional[str] = None): def __init__(self, source: OAuthSource, request: HttpRequest, callback: Optional[str] = None):
self.source = source self.source = source
self.session = Session() self.session = get_http_session()
self.request = request self.request = request
self.callback = callback self.callback = callback
self.session.headers.update({"User-Agent": f"authentik {__version__}"})
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]: def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
"Fetch access token from callback request." "Fetch access token from callback request."

View file

@ -2,12 +2,12 @@
from urllib.parse import urlencode from urllib.parse import urlencode
from django.http.response import Http404 from django.http.response import Http404
from requests import Session
from requests.exceptions import RequestException from requests.exceptions import RequestException
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik import __version__ from authentik import __version__
from authentik.core.sources.flow_manager import SourceFlowManager from authentik.core.sources.flow_manager import SourceFlowManager
from authentik.lib.utils.http import get_http_session
from authentik.sources.plex.models import PlexSource, PlexSourceConnection from authentik.sources.plex.models import PlexSource, PlexSourceConnection
LOGGER = get_logger() LOGGER = get_logger()
@ -24,7 +24,7 @@ class PlexAuth:
def __init__(self, source: PlexSource, token: str): def __init__(self, source: PlexSource, token: str):
self._source = source self._source = source
self._token = token self._token = token
self._session = Session() self._session = get_http_session()
self._session.headers.update( self._session.headers.update(
{"Accept": "application/json", "Content-Type": "application/json"} {"Accept": "application/json", "Content-Type": "application/json"}
) )

View file

@ -1,11 +1,10 @@
"""authentik captcha stage""" """authentik captcha stage"""
from django.http.response import HttpResponse from django.http.response import HttpResponse
from requests import RequestException, post from requests import RequestException
from rest_framework.fields import CharField from rest_framework.fields import CharField
from rest_framework.serializers import ValidationError from rest_framework.serializers import ValidationError
from authentik import __version__
from authentik.flows.challenge import ( from authentik.flows.challenge import (
Challenge, Challenge,
ChallengeResponse, ChallengeResponse,
@ -13,7 +12,7 @@ from authentik.flows.challenge import (
WithUserInfoChallenge, WithUserInfoChallenge,
) )
from authentik.flows.stage import ChallengeStageView from authentik.flows.stage import ChallengeStageView
from authentik.lib.utils.http import get_client_ip from authentik.lib.utils.http import get_client_ip, get_http_session
from authentik.stages.captcha.models import CaptchaStage from authentik.stages.captcha.models import CaptchaStage
@ -34,11 +33,10 @@ class CaptchaChallengeResponse(ChallengeResponse):
"""Validate captcha token""" """Validate captcha token"""
stage: CaptchaStage = self.stage.executor.current_stage stage: CaptchaStage = self.stage.executor.current_stage
try: try:
response = post( response = get_http_session().post(
"https://www.google.com/recaptcha/api/siteverify", "https://www.google.com/recaptcha/api/siteverify",
headers={ headers={
"Content-type": "application/x-www-form-urlencoded", "Content-type": "application/x-www-form-urlencoded",
"User-agent": f"authentik {__version__} ReCaptcha",
}, },
data={ data={
"secret": stage.private_key, "secret": stage.private_key,