root: reformat to 100 line width
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
b87903a209
commit
77ed25ae34
|
@ -23,9 +23,7 @@ def get_events_per_1h(**filter_kwargs) -> list[dict[str, int]]:
|
||||||
date_from = now() - timedelta(days=1)
|
date_from = now() - timedelta(days=1)
|
||||||
result = (
|
result = (
|
||||||
Event.objects.filter(created__gte=date_from, **filter_kwargs)
|
Event.objects.filter(created__gte=date_from, **filter_kwargs)
|
||||||
.annotate(
|
.annotate(age=ExpressionWrapper(now() - F("created"), output_field=DurationField()))
|
||||||
age=ExpressionWrapper(now() - F("created"), output_field=DurationField())
|
|
||||||
)
|
|
||||||
.annotate(age_hours=ExtractHour("age"))
|
.annotate(age_hours=ExtractHour("age"))
|
||||||
.values("age_hours")
|
.values("age_hours")
|
||||||
.annotate(count=Count("pk"))
|
.annotate(count=Count("pk"))
|
||||||
|
@ -37,8 +35,7 @@ def get_events_per_1h(**filter_kwargs) -> list[dict[str, int]]:
|
||||||
for hour in range(0, -24, -1):
|
for hour in range(0, -24, -1):
|
||||||
results.append(
|
results.append(
|
||||||
{
|
{
|
||||||
"x_cord": time.mktime((_now + timedelta(hours=hour)).timetuple())
|
"x_cord": time.mktime((_now + timedelta(hours=hour)).timetuple()) * 1000,
|
||||||
* 1000,
|
|
||||||
"y_cord": data[hour * -1],
|
"y_cord": data[hour * -1],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -61,9 +61,7 @@ class SystemSerializer(PassiveSerializer):
|
||||||
return {
|
return {
|
||||||
"python_version": python_version,
|
"python_version": python_version,
|
||||||
"gunicorn_version": ".".join(str(x) for x in gunicorn_version),
|
"gunicorn_version": ".".join(str(x) for x in gunicorn_version),
|
||||||
"environment": "kubernetes"
|
"environment": "kubernetes" if SERVICE_HOST_ENV_NAME in os.environ else "compose",
|
||||||
if SERVICE_HOST_ENV_NAME in os.environ
|
|
||||||
else "compose",
|
|
||||||
"architecture": platform.machine(),
|
"architecture": platform.machine(),
|
||||||
"platform": platform.platform(),
|
"platform": platform.platform(),
|
||||||
"uname": " ".join(platform.uname()),
|
"uname": " ".join(platform.uname()),
|
||||||
|
|
|
@ -92,10 +92,7 @@ class TaskViewSet(ViewSet):
|
||||||
task_func.delay(*task.task_call_args, **task.task_call_kwargs)
|
task_func.delay(*task.task_call_args, **task.task_call_kwargs)
|
||||||
messages.success(
|
messages.success(
|
||||||
self.request,
|
self.request,
|
||||||
_(
|
_("Successfully re-scheduled Task %(name)s!" % {"name": task.task_name}),
|
||||||
"Successfully re-scheduled Task %(name)s!"
|
|
||||||
% {"name": task.task_name}
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
return Response(status=204)
|
return Response(status=204)
|
||||||
except ImportError: # pragma: no cover
|
except ImportError: # pragma: no cover
|
||||||
|
|
|
@ -41,9 +41,7 @@ class VersionSerializer(PassiveSerializer):
|
||||||
|
|
||||||
def get_outdated(self, instance) -> bool:
|
def get_outdated(self, instance) -> bool:
|
||||||
"""Check if we're running the latest version"""
|
"""Check if we're running the latest version"""
|
||||||
return parse(self.get_version_current(instance)) < parse(
|
return parse(self.get_version_current(instance)) < parse(self.get_version_latest(instance))
|
||||||
self.get_version_latest(instance)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class VersionView(APIView):
|
class VersionView(APIView):
|
||||||
|
|
|
@ -17,9 +17,7 @@ class WorkerView(APIView):
|
||||||
|
|
||||||
permission_classes = [IsAdminUser]
|
permission_classes = [IsAdminUser]
|
||||||
|
|
||||||
@extend_schema(
|
@extend_schema(responses=inline_serializer("Workers", fields={"count": IntegerField()}))
|
||||||
responses=inline_serializer("Workers", fields={"count": IntegerField()})
|
|
||||||
)
|
|
||||||
def get(self, request: Request) -> Response:
|
def get(self, request: Request) -> Response:
|
||||||
"""Get currently connected worker count."""
|
"""Get currently connected worker count."""
|
||||||
count = len(CELERY_APP.control.ping(timeout=0.5))
|
count = len(CELERY_APP.control.ping(timeout=0.5))
|
||||||
|
|
|
@ -37,18 +37,14 @@ def _set_prom_info():
|
||||||
def update_latest_version(self: MonitoredTask):
|
def update_latest_version(self: MonitoredTask):
|
||||||
"""Update latest version info"""
|
"""Update latest version info"""
|
||||||
try:
|
try:
|
||||||
response = get(
|
response = get("https://api.github.com/repos/goauthentik/authentik/releases/latest")
|
||||||
"https://api.github.com/repos/goauthentik/authentik/releases/latest"
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
tag_name = data.get("tag_name")
|
tag_name = data.get("tag_name")
|
||||||
upstream_version = tag_name.split("/")[1]
|
upstream_version = tag_name.split("/")[1]
|
||||||
cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT)
|
cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT)
|
||||||
self.set_status(
|
self.set_status(
|
||||||
TaskResult(
|
TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated latest Version"])
|
||||||
TaskResultStatus.SUCCESSFUL, ["Successfully updated latest Version"]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
_set_prom_info()
|
_set_prom_info()
|
||||||
# Check if upstream version is newer than what we're running,
|
# Check if upstream version is newer than what we're running,
|
||||||
|
|
|
@ -27,9 +27,7 @@ class TestAdminAPI(TestCase):
|
||||||
response = self.client.get(reverse("authentik_api:admin_system_tasks-list"))
|
response = self.client.get(reverse("authentik_api:admin_system_tasks-list"))
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
body = loads(response.content)
|
body = loads(response.content)
|
||||||
self.assertTrue(
|
self.assertTrue(any(task["task_name"] == "clean_expired_models" for task in body))
|
||||||
any(task["task_name"] == "clean_expired_models" for task in body)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_tasks_single(self):
|
def test_tasks_single(self):
|
||||||
"""Test Task API (read single)"""
|
"""Test Task API (read single)"""
|
||||||
|
@ -45,9 +43,7 @@ class TestAdminAPI(TestCase):
|
||||||
self.assertEqual(body["status"], TaskResultStatus.SUCCESSFUL.name)
|
self.assertEqual(body["status"], TaskResultStatus.SUCCESSFUL.name)
|
||||||
self.assertEqual(body["task_name"], "clean_expired_models")
|
self.assertEqual(body["task_name"], "clean_expired_models")
|
||||||
response = self.client.get(
|
response = self.client.get(
|
||||||
reverse(
|
reverse("authentik_api:admin_system_tasks-detail", kwargs={"pk": "qwerqwer"})
|
||||||
"authentik_api:admin_system_tasks-detail", kwargs={"pk": "qwerqwer"}
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 404)
|
self.assertEqual(response.status_code, 404)
|
||||||
|
|
||||||
|
|
|
@ -7,9 +7,7 @@ from rest_framework.response import Response
|
||||||
from rest_framework.viewsets import ModelViewSet
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
|
||||||
|
|
||||||
def permission_required(
|
def permission_required(perm: Optional[str] = None, other_perms: Optional[list[str]] = None):
|
||||||
perm: Optional[str] = None, other_perms: Optional[list[str]] = None
|
|
||||||
):
|
|
||||||
"""Check permissions for a single custom action"""
|
"""Check permissions for a single custom action"""
|
||||||
|
|
||||||
def wrapper_outter(func: Callable):
|
def wrapper_outter(func: Callable):
|
||||||
|
|
|
@ -63,9 +63,7 @@ def postprocess_schema_responses(result, generator, **kwargs): # noqa: W0613
|
||||||
method["responses"].setdefault("400", validation_error.ref)
|
method["responses"].setdefault("400", validation_error.ref)
|
||||||
method["responses"].setdefault("403", generic_error.ref)
|
method["responses"].setdefault("403", generic_error.ref)
|
||||||
|
|
||||||
result["components"] = generator.registry.build(
|
result["components"] = generator.registry.build(spectacular_settings.APPEND_COMPONENTS)
|
||||||
spectacular_settings.APPEND_COMPONENTS
|
|
||||||
)
|
|
||||||
|
|
||||||
# This is a workaround for authentik/stages/prompt/stage.py
|
# This is a workaround for authentik/stages/prompt/stage.py
|
||||||
# since the serializer PromptChallengeResponse
|
# since the serializer PromptChallengeResponse
|
||||||
|
|
|
@ -16,17 +16,13 @@ class TestAPIAuth(TestCase):
|
||||||
|
|
||||||
def test_valid_basic(self):
|
def test_valid_basic(self):
|
||||||
"""Test valid token"""
|
"""Test valid token"""
|
||||||
token = Token.objects.create(
|
token = Token.objects.create(intent=TokenIntents.INTENT_API, user=get_anonymous_user())
|
||||||
intent=TokenIntents.INTENT_API, user=get_anonymous_user()
|
|
||||||
)
|
|
||||||
auth = b64encode(f":{token.key}".encode()).decode()
|
auth = b64encode(f":{token.key}".encode()).decode()
|
||||||
self.assertEqual(bearer_auth(f"Basic {auth}".encode()), token.user)
|
self.assertEqual(bearer_auth(f"Basic {auth}".encode()), token.user)
|
||||||
|
|
||||||
def test_valid_bearer(self):
|
def test_valid_bearer(self):
|
||||||
"""Test valid token"""
|
"""Test valid token"""
|
||||||
token = Token.objects.create(
|
token = Token.objects.create(intent=TokenIntents.INTENT_API, user=get_anonymous_user())
|
||||||
intent=TokenIntents.INTENT_API, user=get_anonymous_user()
|
|
||||||
)
|
|
||||||
self.assertEqual(bearer_auth(f"Bearer {token.key}".encode()), token.user)
|
self.assertEqual(bearer_auth(f"Bearer {token.key}".encode()), token.user)
|
||||||
|
|
||||||
def test_invalid_type(self):
|
def test_invalid_type(self):
|
||||||
|
|
|
@ -52,20 +52,12 @@ from authentik.policies.reputation.api import (
|
||||||
from authentik.providers.ldap.api import LDAPOutpostConfigViewSet, LDAPProviderViewSet
|
from authentik.providers.ldap.api import LDAPOutpostConfigViewSet, LDAPProviderViewSet
|
||||||
from authentik.providers.oauth2.api.provider import OAuth2ProviderViewSet
|
from authentik.providers.oauth2.api.provider import OAuth2ProviderViewSet
|
||||||
from authentik.providers.oauth2.api.scope import ScopeMappingViewSet
|
from authentik.providers.oauth2.api.scope import ScopeMappingViewSet
|
||||||
from authentik.providers.oauth2.api.tokens import (
|
from authentik.providers.oauth2.api.tokens import AuthorizationCodeViewSet, RefreshTokenViewSet
|
||||||
AuthorizationCodeViewSet,
|
from authentik.providers.proxy.api import ProxyOutpostConfigViewSet, ProxyProviderViewSet
|
||||||
RefreshTokenViewSet,
|
|
||||||
)
|
|
||||||
from authentik.providers.proxy.api import (
|
|
||||||
ProxyOutpostConfigViewSet,
|
|
||||||
ProxyProviderViewSet,
|
|
||||||
)
|
|
||||||
from authentik.providers.saml.api import SAMLPropertyMappingViewSet, SAMLProviderViewSet
|
from authentik.providers.saml.api import SAMLPropertyMappingViewSet, SAMLProviderViewSet
|
||||||
from authentik.sources.ldap.api import LDAPPropertyMappingViewSet, LDAPSourceViewSet
|
from authentik.sources.ldap.api import LDAPPropertyMappingViewSet, LDAPSourceViewSet
|
||||||
from authentik.sources.oauth.api.source import OAuthSourceViewSet
|
from authentik.sources.oauth.api.source import OAuthSourceViewSet
|
||||||
from authentik.sources.oauth.api.source_connection import (
|
from authentik.sources.oauth.api.source_connection import UserOAuthSourceConnectionViewSet
|
||||||
UserOAuthSourceConnectionViewSet,
|
|
||||||
)
|
|
||||||
from authentik.sources.plex.api import PlexSourceViewSet
|
from authentik.sources.plex.api import PlexSourceViewSet
|
||||||
from authentik.sources.saml.api import SAMLSourceViewSet
|
from authentik.sources.saml.api import SAMLSourceViewSet
|
||||||
from authentik.stages.authenticator_duo.api import (
|
from authentik.stages.authenticator_duo.api import (
|
||||||
|
@ -83,9 +75,7 @@ from authentik.stages.authenticator_totp.api import (
|
||||||
TOTPAdminDeviceViewSet,
|
TOTPAdminDeviceViewSet,
|
||||||
TOTPDeviceViewSet,
|
TOTPDeviceViewSet,
|
||||||
)
|
)
|
||||||
from authentik.stages.authenticator_validate.api import (
|
from authentik.stages.authenticator_validate.api import AuthenticatorValidateStageViewSet
|
||||||
AuthenticatorValidateStageViewSet,
|
|
||||||
)
|
|
||||||
from authentik.stages.authenticator_webauthn.api import (
|
from authentik.stages.authenticator_webauthn.api import (
|
||||||
AuthenticateWebAuthnStageViewSet,
|
AuthenticateWebAuthnStageViewSet,
|
||||||
WebAuthnAdminDeviceViewSet,
|
WebAuthnAdminDeviceViewSet,
|
||||||
|
@ -122,9 +112,7 @@ router.register("core/tenants", TenantViewSet)
|
||||||
router.register("outposts/instances", OutpostViewSet)
|
router.register("outposts/instances", OutpostViewSet)
|
||||||
router.register("outposts/service_connections/all", ServiceConnectionViewSet)
|
router.register("outposts/service_connections/all", ServiceConnectionViewSet)
|
||||||
router.register("outposts/service_connections/docker", DockerServiceConnectionViewSet)
|
router.register("outposts/service_connections/docker", DockerServiceConnectionViewSet)
|
||||||
router.register(
|
router.register("outposts/service_connections/kubernetes", KubernetesServiceConnectionViewSet)
|
||||||
"outposts/service_connections/kubernetes", KubernetesServiceConnectionViewSet
|
|
||||||
)
|
|
||||||
router.register("outposts/proxy", ProxyOutpostConfigViewSet)
|
router.register("outposts/proxy", ProxyOutpostConfigViewSet)
|
||||||
router.register("outposts/ldap", LDAPOutpostConfigViewSet)
|
router.register("outposts/ldap", LDAPOutpostConfigViewSet)
|
||||||
|
|
||||||
|
@ -184,9 +172,7 @@ router.register(
|
||||||
StaticAdminDeviceViewSet,
|
StaticAdminDeviceViewSet,
|
||||||
basename="admin-staticdevice",
|
basename="admin-staticdevice",
|
||||||
)
|
)
|
||||||
router.register(
|
router.register("authenticators/admin/totp", TOTPAdminDeviceViewSet, basename="admin-totpdevice")
|
||||||
"authenticators/admin/totp", TOTPAdminDeviceViewSet, basename="admin-totpdevice"
|
|
||||||
)
|
|
||||||
router.register(
|
router.register(
|
||||||
"authenticators/admin/webauthn",
|
"authenticators/admin/webauthn",
|
||||||
WebAuthnAdminDeviceViewSet,
|
WebAuthnAdminDeviceViewSet,
|
||||||
|
|
|
@ -147,9 +147,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
|
||||||
"""Custom list method that checks Policy based access instead of guardian"""
|
"""Custom list method that checks Policy based access instead of guardian"""
|
||||||
should_cache = request.GET.get("search", "") == ""
|
should_cache = request.GET.get("search", "") == ""
|
||||||
|
|
||||||
superuser_full_list = (
|
superuser_full_list = str(request.GET.get("superuser_full_list", "false")).lower() == "true"
|
||||||
str(request.GET.get("superuser_full_list", "false")).lower() == "true"
|
|
||||||
)
|
|
||||||
if superuser_full_list and request.user.is_superuser:
|
if superuser_full_list and request.user.is_superuser:
|
||||||
return super().list(request)
|
return super().list(request)
|
||||||
|
|
||||||
|
@ -240,9 +238,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
|
||||||
app.save()
|
app.save()
|
||||||
return Response({})
|
return Response({})
|
||||||
|
|
||||||
@permission_required(
|
@permission_required("authentik_core.view_application", ["authentik_events.view_event"])
|
||||||
"authentik_core.view_application", ["authentik_events.view_event"]
|
|
||||||
)
|
|
||||||
@extend_schema(responses={200: CoordinateSerializer(many=True)})
|
@extend_schema(responses={200: CoordinateSerializer(many=True)})
|
||||||
@action(detail=True, pagination_class=None, filter_backends=[])
|
@action(detail=True, pagination_class=None, filter_backends=[])
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
|
|
|
@ -68,9 +68,7 @@ class AuthenticatedSessionSerializer(ModelSerializer):
|
||||||
"""Get parsed user agent"""
|
"""Get parsed user agent"""
|
||||||
return user_agent_parser.Parse(instance.last_user_agent)
|
return user_agent_parser.Parse(instance.last_user_agent)
|
||||||
|
|
||||||
def get_geo_ip(
|
def get_geo_ip(self, instance: AuthenticatedSession) -> Optional[GeoIPDict]: # pragma: no cover
|
||||||
self, instance: AuthenticatedSession
|
|
||||||
) -> Optional[GeoIPDict]: # pragma: no cover
|
|
||||||
"""Get parsed user agent"""
|
"""Get parsed user agent"""
|
||||||
return GEOIP_READER.city_dict(instance.last_ip)
|
return GEOIP_READER.city_dict(instance.last_ip)
|
||||||
|
|
||||||
|
|
|
@ -15,11 +15,7 @@ from rest_framework.viewsets import GenericViewSet
|
||||||
|
|
||||||
from authentik.api.decorators import permission_required
|
from authentik.api.decorators import permission_required
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import (
|
from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer
|
||||||
MetaNameSerializer,
|
|
||||||
PassiveSerializer,
|
|
||||||
TypeCreateSerializer,
|
|
||||||
)
|
|
||||||
from authentik.core.expression import PropertyMappingEvaluator
|
from authentik.core.expression import PropertyMappingEvaluator
|
||||||
from authentik.core.models import PropertyMapping
|
from authentik.core.models import PropertyMapping
|
||||||
from authentik.lib.utils.reflection import all_subclasses
|
from authentik.lib.utils.reflection import all_subclasses
|
||||||
|
@ -141,9 +137,7 @@ class PropertyMappingViewSet(
|
||||||
self.request,
|
self.request,
|
||||||
**test_params.validated_data.get("context", {}),
|
**test_params.validated_data.get("context", {}),
|
||||||
)
|
)
|
||||||
response_data["result"] = dumps(
|
response_data["result"] = dumps(result, indent=(4 if format_result else None))
|
||||||
result, indent=(4 if format_result else None)
|
|
||||||
)
|
|
||||||
except Exception as exc: # pylint: disable=broad-except
|
except Exception as exc: # pylint: disable=broad-except
|
||||||
response_data["result"] = str(exc)
|
response_data["result"] = str(exc)
|
||||||
response_data["successful"] = False
|
response_data["successful"] = False
|
||||||
|
|
|
@ -93,9 +93,7 @@ class SourceViewSet(
|
||||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||||
def user_settings(self, request: Request) -> Response:
|
def user_settings(self, request: Request) -> Response:
|
||||||
"""Get all sources the user can configure"""
|
"""Get all sources the user can configure"""
|
||||||
_all_sources: Iterable[Source] = Source.objects.filter(
|
_all_sources: Iterable[Source] = Source.objects.filter(enabled=True).select_subclasses()
|
||||||
enabled=True
|
|
||||||
).select_subclasses()
|
|
||||||
matching_sources: list[UserSettingSerializer] = []
|
matching_sources: list[UserSettingSerializer] = []
|
||||||
for source in _all_sources:
|
for source in _all_sources:
|
||||||
user_settings = source.ui_user_settings
|
user_settings = source.ui_user_settings
|
||||||
|
|
|
@ -70,9 +70,7 @@ class TokenViewSet(UsedByMixin, ModelViewSet):
|
||||||
serializer.save(
|
serializer.save(
|
||||||
user=self.request.user,
|
user=self.request.user,
|
||||||
intent=TokenIntents.INTENT_API,
|
intent=TokenIntents.INTENT_API,
|
||||||
expiring=self.request.user.attributes.get(
|
expiring=self.request.user.attributes.get(USER_ATTRIBUTE_TOKEN_EXPIRING, True),
|
||||||
USER_ATTRIBUTE_TOKEN_EXPIRING, True
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@permission_required("authentik_core.view_token_key")
|
@permission_required("authentik_core.view_token_key")
|
||||||
|
@ -89,7 +87,5 @@ class TokenViewSet(UsedByMixin, ModelViewSet):
|
||||||
token: Token = self.get_object()
|
token: Token = self.get_object()
|
||||||
if token.is_expired:
|
if token.is_expired:
|
||||||
raise Http404
|
raise Http404
|
||||||
Event.new(EventAction.SECRET_VIEW, secret=token).from_http( # noqa # nosec
|
Event.new(EventAction.SECRET_VIEW, secret=token).from_http(request) # noqa # nosec
|
||||||
request
|
|
||||||
)
|
|
||||||
return Response(TokenViewSerializer({"key": token.key}).data)
|
return Response(TokenViewSerializer({"key": token.key}).data)
|
||||||
|
|
|
@ -79,9 +79,7 @@ class UsedByMixin:
|
||||||
).all():
|
).all():
|
||||||
# Only merge shadows on first object
|
# Only merge shadows on first object
|
||||||
if first_object:
|
if first_object:
|
||||||
shadows += getattr(
|
shadows += getattr(manager.model._meta, "authentik_used_by_shadows", [])
|
||||||
manager.model._meta, "authentik_used_by_shadows", []
|
|
||||||
)
|
|
||||||
first_object = False
|
first_object = False
|
||||||
serializer = UsedBySerializer(
|
serializer = UsedBySerializer(
|
||||||
data={
|
data={
|
||||||
|
|
|
@ -26,10 +26,7 @@ from authentik.api.decorators import permission_required
|
||||||
from authentik.core.api.groups import GroupSerializer
|
from authentik.core.api.groups import GroupSerializer
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import LinkSerializer, PassiveSerializer, is_dict
|
from authentik.core.api.utils import LinkSerializer, PassiveSerializer, is_dict
|
||||||
from authentik.core.middleware import (
|
from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER
|
||||||
SESSION_IMPERSONATE_ORIGINAL_USER,
|
|
||||||
SESSION_IMPERSONATE_USER,
|
|
||||||
)
|
|
||||||
from authentik.core.models import Token, TokenIntents, User
|
from authentik.core.models import Token, TokenIntents, User
|
||||||
from authentik.events.models import EventAction
|
from authentik.events.models import EventAction
|
||||||
from authentik.tenants.models import Tenant
|
from authentik.tenants.models import Tenant
|
||||||
|
@ -87,17 +84,13 @@ class UserMetricsSerializer(PassiveSerializer):
|
||||||
def get_logins_failed_per_1h(self, _):
|
def get_logins_failed_per_1h(self, _):
|
||||||
"""Get failed logins per hour for the last 24 hours"""
|
"""Get failed logins per hour for the last 24 hours"""
|
||||||
user = self.context["user"]
|
user = self.context["user"]
|
||||||
return get_events_per_1h(
|
return get_events_per_1h(action=EventAction.LOGIN_FAILED, context__username=user.username)
|
||||||
action=EventAction.LOGIN_FAILED, context__username=user.username
|
|
||||||
)
|
|
||||||
|
|
||||||
@extend_schema_field(CoordinateSerializer(many=True))
|
@extend_schema_field(CoordinateSerializer(many=True))
|
||||||
def get_authorizations_per_1h(self, _):
|
def get_authorizations_per_1h(self, _):
|
||||||
"""Get failed logins per hour for the last 24 hours"""
|
"""Get failed logins per hour for the last 24 hours"""
|
||||||
user = self.context["user"]
|
user = self.context["user"]
|
||||||
return get_events_per_1h(
|
return get_events_per_1h(action=EventAction.AUTHORIZE_APPLICATION, user__pk=user.pk)
|
||||||
action=EventAction.AUTHORIZE_APPLICATION, user__pk=user.pk
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UsersFilter(FilterSet):
|
class UsersFilter(FilterSet):
|
||||||
|
@ -154,9 +147,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
def me(self, request: Request) -> Response:
|
def me(self, request: Request) -> Response:
|
||||||
"""Get information about current user"""
|
"""Get information about current user"""
|
||||||
serializer = SessionUserSerializer(
|
serializer = SessionUserSerializer(data={"user": UserSerializer(request.user).data})
|
||||||
data={"user": UserSerializer(request.user).data}
|
|
||||||
)
|
|
||||||
if SESSION_IMPERSONATE_USER in request._request.session:
|
if SESSION_IMPERSONATE_USER in request._request.session:
|
||||||
serializer.initial_data["original"] = UserSerializer(
|
serializer.initial_data["original"] = UserSerializer(
|
||||||
request._request.session[SESSION_IMPERSONATE_ORIGINAL_USER]
|
request._request.session[SESSION_IMPERSONATE_ORIGINAL_USER]
|
||||||
|
|
|
@ -3,20 +3,14 @@ from typing import Any
|
||||||
|
|
||||||
from django.db.models import Model
|
from django.db.models import Model
|
||||||
from rest_framework.fields import CharField, IntegerField
|
from rest_framework.fields import CharField, IntegerField
|
||||||
from rest_framework.serializers import (
|
from rest_framework.serializers import Serializer, SerializerMethodField, ValidationError
|
||||||
Serializer,
|
|
||||||
SerializerMethodField,
|
|
||||||
ValidationError,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def is_dict(value: Any):
|
def is_dict(value: Any):
|
||||||
"""Ensure a value is a dictionary, useful for JSONFields"""
|
"""Ensure a value is a dictionary, useful for JSONFields"""
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
return
|
return
|
||||||
raise ValidationError(
|
raise ValidationError("Value must be a dictionary, and not have any duplicate keys.")
|
||||||
"Value must be a dictionary, and not have any duplicate keys."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PassiveSerializer(Serializer):
|
class PassiveSerializer(Serializer):
|
||||||
|
@ -25,9 +19,7 @@ class PassiveSerializer(Serializer):
|
||||||
def create(self, validated_data: dict) -> Model: # pragma: no cover
|
def create(self, validated_data: dict) -> Model: # pragma: no cover
|
||||||
return Model()
|
return Model()
|
||||||
|
|
||||||
def update(
|
def update(self, instance: Model, validated_data: dict) -> Model: # pragma: no cover
|
||||||
self, instance: Model, validated_data: dict
|
|
||||||
) -> Model: # pragma: no cover
|
|
||||||
return Model()
|
return Model()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
|
@ -38,9 +38,7 @@ class Migration(migrations.Migration):
|
||||||
("password", models.CharField(max_length=128, verbose_name="password")),
|
("password", models.CharField(max_length=128, verbose_name="password")),
|
||||||
(
|
(
|
||||||
"last_login",
|
"last_login",
|
||||||
models.DateTimeField(
|
models.DateTimeField(blank=True, null=True, verbose_name="last login"),
|
||||||
blank=True, null=True, verbose_name="last login"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"is_superuser",
|
"is_superuser",
|
||||||
|
@ -53,35 +51,25 @@ class Migration(migrations.Migration):
|
||||||
(
|
(
|
||||||
"username",
|
"username",
|
||||||
models.CharField(
|
models.CharField(
|
||||||
error_messages={
|
error_messages={"unique": "A user with that username already exists."},
|
||||||
"unique": "A user with that username already exists."
|
|
||||||
},
|
|
||||||
help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.",
|
help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.",
|
||||||
max_length=150,
|
max_length=150,
|
||||||
unique=True,
|
unique=True,
|
||||||
validators=[
|
validators=[django.contrib.auth.validators.UnicodeUsernameValidator()],
|
||||||
django.contrib.auth.validators.UnicodeUsernameValidator()
|
|
||||||
],
|
|
||||||
verbose_name="username",
|
verbose_name="username",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"first_name",
|
"first_name",
|
||||||
models.CharField(
|
models.CharField(blank=True, max_length=30, verbose_name="first name"),
|
||||||
blank=True, max_length=30, verbose_name="first name"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"last_name",
|
"last_name",
|
||||||
models.CharField(
|
models.CharField(blank=True, max_length=150, verbose_name="last name"),
|
||||||
blank=True, max_length=150, verbose_name="last name"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"email",
|
"email",
|
||||||
models.EmailField(
|
models.EmailField(blank=True, max_length=254, verbose_name="email address"),
|
||||||
blank=True, max_length=254, verbose_name="email address"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"is_staff",
|
"is_staff",
|
||||||
|
@ -217,9 +205,7 @@ class Migration(migrations.Migration):
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"expires",
|
"expires",
|
||||||
models.DateTimeField(
|
models.DateTimeField(default=authentik.core.models.default_token_duration),
|
||||||
default=authentik.core.models.default_token_duration
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
("expiring", models.BooleanField(default=True)),
|
("expiring", models.BooleanField(default=True)),
|
||||||
("description", models.TextField(blank=True, default="")),
|
("description", models.TextField(blank=True, default="")),
|
||||||
|
@ -306,9 +292,7 @@ class Migration(migrations.Migration):
|
||||||
("name", models.TextField(help_text="Application's display Name.")),
|
("name", models.TextField(help_text="Application's display Name.")),
|
||||||
(
|
(
|
||||||
"slug",
|
"slug",
|
||||||
models.SlugField(
|
models.SlugField(help_text="Internal application name, used in URLs."),
|
||||||
help_text="Internal application name, used in URLs."
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
("skip_authorization", models.BooleanField(default=False)),
|
("skip_authorization", models.BooleanField(default=False)),
|
||||||
("meta_launch_url", models.URLField(blank=True, default="")),
|
("meta_launch_url", models.URLField(blank=True, default="")),
|
||||||
|
|
|
@ -17,9 +17,7 @@ def create_default_user(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
username="akadmin", email="root@localhost", name="authentik Default Admin"
|
username="akadmin", email="root@localhost", name="authentik Default Admin"
|
||||||
)
|
)
|
||||||
if "TF_BUILD" in environ or "AK_ADMIN_PASS" in environ or settings.TEST:
|
if "TF_BUILD" in environ or "AK_ADMIN_PASS" in environ or settings.TEST:
|
||||||
akadmin.set_password(
|
akadmin.set_password(environ.get("AK_ADMIN_PASS", "akadmin"), signal=False) # noqa # nosec
|
||||||
environ.get("AK_ADMIN_PASS", "akadmin"), signal=False
|
|
||||||
) # noqa # nosec
|
|
||||||
else:
|
else:
|
||||||
akadmin.set_unusable_password()
|
akadmin.set_unusable_password()
|
||||||
akadmin.save()
|
akadmin.save()
|
||||||
|
|
|
@ -13,8 +13,6 @@ class Migration(migrations.Migration):
|
||||||
migrations.AlterField(
|
migrations.AlterField(
|
||||||
model_name="source",
|
model_name="source",
|
||||||
name="slug",
|
name="slug",
|
||||||
field=models.SlugField(
|
field=models.SlugField(help_text="Internal source name, used in URLs.", unique=True),
|
||||||
help_text="Internal source name, used in URLs.", unique=True
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -13,8 +13,6 @@ class Migration(migrations.Migration):
|
||||||
migrations.AlterField(
|
migrations.AlterField(
|
||||||
model_name="user",
|
model_name="user",
|
||||||
name="first_name",
|
name="first_name",
|
||||||
field=models.CharField(
|
field=models.CharField(blank=True, max_length=150, verbose_name="first name"),
|
||||||
blank=True, max_length=150, verbose_name="first name"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -40,9 +40,7 @@ class Migration(migrations.Migration):
|
||||||
migrations.AlterField(
|
migrations.AlterField(
|
||||||
model_name="user",
|
model_name="user",
|
||||||
name="pb_groups",
|
name="pb_groups",
|
||||||
field=models.ManyToManyField(
|
field=models.ManyToManyField(related_name="users", to="authentik_core.Group"),
|
||||||
related_name="users", to="authentik_core.Group"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
migrations.AddField(
|
migrations.AddField(
|
||||||
model_name="group",
|
model_name="group",
|
||||||
|
|
|
@ -42,9 +42,7 @@ class Migration(migrations.Migration):
|
||||||
),
|
),
|
||||||
migrations.AddIndex(
|
migrations.AddIndex(
|
||||||
model_name="token",
|
model_name="token",
|
||||||
index=models.Index(
|
index=models.Index(fields=["identifier"], name="authentik_co_identif_1a34a8_idx"),
|
||||||
fields=["identifier"], name="authentik_co_identif_1a34a8_idx"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
migrations.RunPython(set_default_token_key),
|
migrations.RunPython(set_default_token_key),
|
||||||
]
|
]
|
||||||
|
|
|
@ -17,8 +17,6 @@ class Migration(migrations.Migration):
|
||||||
migrations.AddField(
|
migrations.AddField(
|
||||||
model_name="application",
|
model_name="application",
|
||||||
name="meta_icon",
|
name="meta_icon",
|
||||||
field=models.FileField(
|
field=models.FileField(blank=True, default="", upload_to="application-icons/"),
|
||||||
blank=True, default="", upload_to="application-icons/"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -25,9 +25,7 @@ class Migration(migrations.Migration):
|
||||||
),
|
),
|
||||||
migrations.AddIndex(
|
migrations.AddIndex(
|
||||||
model_name="token",
|
model_name="token",
|
||||||
index=models.Index(
|
index=models.Index(fields=["identifier"], name="authentik_c_identif_d9d032_idx"),
|
||||||
fields=["identifier"], name="authentik_c_identif_d9d032_idx"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
migrations.AddIndex(
|
migrations.AddIndex(
|
||||||
model_name="token",
|
model_name="token",
|
||||||
|
|
|
@ -32,16 +32,12 @@ class Migration(migrations.Migration):
|
||||||
fields=[
|
fields=[
|
||||||
(
|
(
|
||||||
"expires",
|
"expires",
|
||||||
models.DateTimeField(
|
models.DateTimeField(default=authentik.core.models.default_token_duration),
|
||||||
default=authentik.core.models.default_token_duration
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
("expiring", models.BooleanField(default=True)),
|
("expiring", models.BooleanField(default=True)),
|
||||||
(
|
(
|
||||||
"uuid",
|
"uuid",
|
||||||
models.UUIDField(
|
models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False),
|
||||||
default=uuid.uuid4, primary_key=True, serialize=False
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
("session_key", models.CharField(max_length=40)),
|
("session_key", models.CharField(max_length=40)),
|
||||||
("last_ip", models.TextField()),
|
("last_ip", models.TextField()),
|
||||||
|
|
|
@ -13,8 +13,6 @@ class Migration(migrations.Migration):
|
||||||
migrations.AlterField(
|
migrations.AlterField(
|
||||||
model_name="application",
|
model_name="application",
|
||||||
name="meta_icon",
|
name="meta_icon",
|
||||||
field=models.FileField(
|
field=models.FileField(default=None, null=True, upload_to="application-icons/"),
|
||||||
default=None, null=True, upload_to="application-icons/"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -154,9 +154,7 @@ class User(GuardianUserMixin, AbstractUser):
|
||||||
("s", "158"),
|
("s", "158"),
|
||||||
("r", "g"),
|
("r", "g"),
|
||||||
]
|
]
|
||||||
gravatar_url = (
|
gravatar_url = f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters, doseq=True)}"
|
||||||
f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters, doseq=True)}"
|
|
||||||
)
|
|
||||||
return escape(gravatar_url)
|
return escape(gravatar_url)
|
||||||
return mode % {
|
return mode % {
|
||||||
"username": self.username,
|
"username": self.username,
|
||||||
|
@ -186,9 +184,7 @@ class Provider(SerializerModel):
|
||||||
related_name="provider_authorization",
|
related_name="provider_authorization",
|
||||||
)
|
)
|
||||||
|
|
||||||
property_mappings = models.ManyToManyField(
|
property_mappings = models.ManyToManyField("PropertyMapping", default=None, blank=True)
|
||||||
"PropertyMapping", default=None, blank=True
|
|
||||||
)
|
|
||||||
|
|
||||||
objects = InheritanceManager()
|
objects = InheritanceManager()
|
||||||
|
|
||||||
|
@ -218,9 +214,7 @@ class Application(PolicyBindingModel):
|
||||||
add custom fields and other properties"""
|
add custom fields and other properties"""
|
||||||
|
|
||||||
name = models.TextField(help_text=_("Application's display Name."))
|
name = models.TextField(help_text=_("Application's display Name."))
|
||||||
slug = models.SlugField(
|
slug = models.SlugField(help_text=_("Internal application name, used in URLs."), unique=True)
|
||||||
help_text=_("Internal application name, used in URLs."), unique=True
|
|
||||||
)
|
|
||||||
provider = models.OneToOneField(
|
provider = models.OneToOneField(
|
||||||
"Provider", null=True, blank=True, default=None, on_delete=models.SET_DEFAULT
|
"Provider", null=True, blank=True, default=None, on_delete=models.SET_DEFAULT
|
||||||
)
|
)
|
||||||
|
@ -244,9 +238,7 @@ class Application(PolicyBindingModel):
|
||||||
it is returned as-is"""
|
it is returned as-is"""
|
||||||
if not self.meta_icon:
|
if not self.meta_icon:
|
||||||
return None
|
return None
|
||||||
if self.meta_icon.name.startswith("http") or self.meta_icon.name.startswith(
|
if self.meta_icon.name.startswith("http") or self.meta_icon.name.startswith("/static"):
|
||||||
"/static"
|
|
||||||
):
|
|
||||||
return self.meta_icon.name
|
return self.meta_icon.name
|
||||||
return self.meta_icon.url
|
return self.meta_icon.url
|
||||||
|
|
||||||
|
@ -301,14 +293,10 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
|
||||||
"""Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server"""
|
"""Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server"""
|
||||||
|
|
||||||
name = models.TextField(help_text=_("Source's display Name."))
|
name = models.TextField(help_text=_("Source's display Name."))
|
||||||
slug = models.SlugField(
|
slug = models.SlugField(help_text=_("Internal source name, used in URLs."), unique=True)
|
||||||
help_text=_("Internal source name, used in URLs."), unique=True
|
|
||||||
)
|
|
||||||
|
|
||||||
enabled = models.BooleanField(default=True)
|
enabled = models.BooleanField(default=True)
|
||||||
property_mappings = models.ManyToManyField(
|
property_mappings = models.ManyToManyField("PropertyMapping", default=None, blank=True)
|
||||||
"PropertyMapping", default=None, blank=True
|
|
||||||
)
|
|
||||||
|
|
||||||
authentication_flow = models.ForeignKey(
|
authentication_flow = models.ForeignKey(
|
||||||
Flow,
|
Flow,
|
||||||
|
@ -481,9 +469,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
|
||||||
"""Get serializer for this model"""
|
"""Get serializer for this model"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any:
|
||||||
self, user: Optional[User], request: Optional[HttpRequest], **kwargs
|
|
||||||
) -> Any:
|
|
||||||
"""Evaluate `self.expression` using `**kwargs` as Context."""
|
"""Evaluate `self.expression` using `**kwargs` as Context."""
|
||||||
from authentik.core.expression import PropertyMappingEvaluator
|
from authentik.core.expression import PropertyMappingEvaluator
|
||||||
|
|
||||||
|
@ -522,9 +508,7 @@ class AuthenticatedSession(ExpiringModel):
|
||||||
last_used = models.DateTimeField(auto_now=True)
|
last_used = models.DateTimeField(auto_now=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_request(
|
def from_request(request: HttpRequest, user: User) -> Optional["AuthenticatedSession"]:
|
||||||
request: HttpRequest, user: User
|
|
||||||
) -> Optional["AuthenticatedSession"]:
|
|
||||||
"""Create a new session from a http request"""
|
"""Create a new session from a http request"""
|
||||||
if not hasattr(request, "session") or not request.session.session_key:
|
if not hasattr(request, "session") or not request.session.session_key:
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -14,9 +14,7 @@ from prometheus_client import Gauge
|
||||||
# Arguments: user: User, password: str
|
# Arguments: user: User, password: str
|
||||||
password_changed = Signal()
|
password_changed = Signal()
|
||||||
|
|
||||||
GAUGE_MODELS = Gauge(
|
GAUGE_MODELS = Gauge("authentik_models", "Count of various objects", ["model_name", "app"])
|
||||||
"authentik_models", "Count of various objects", ["model_name", "app"]
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from authentik.core.models import AuthenticatedSession, User
|
from authentik.core.models import AuthenticatedSession, User
|
||||||
|
@ -60,15 +58,11 @@ def user_logged_out_session(sender, request: HttpRequest, user: "User", **_):
|
||||||
"""Delete AuthenticatedSession if it exists"""
|
"""Delete AuthenticatedSession if it exists"""
|
||||||
from authentik.core.models import AuthenticatedSession
|
from authentik.core.models import AuthenticatedSession
|
||||||
|
|
||||||
AuthenticatedSession.objects.filter(
|
AuthenticatedSession.objects.filter(session_key=request.session.session_key).delete()
|
||||||
session_key=request.session.session_key
|
|
||||||
).delete()
|
|
||||||
|
|
||||||
|
|
||||||
@receiver(pre_delete)
|
@receiver(pre_delete)
|
||||||
def authenticated_session_delete(
|
def authenticated_session_delete(sender: Type[Model], instance: "AuthenticatedSession", **_):
|
||||||
sender: Type[Model], instance: "AuthenticatedSession", **_
|
|
||||||
):
|
|
||||||
"""Delete session when authenticated session is deleted"""
|
"""Delete session when authenticated session is deleted"""
|
||||||
from authentik.core.models import AuthenticatedSession
|
from authentik.core.models import AuthenticatedSession
|
||||||
|
|
||||||
|
|
|
@ -11,16 +11,8 @@ from django.urls import reverse
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.core.models import (
|
from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection
|
||||||
Source,
|
from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostUserEnrollmentStage
|
||||||
SourceUserMatchingModes,
|
|
||||||
User,
|
|
||||||
UserSourceConnection,
|
|
||||||
)
|
|
||||||
from authentik.core.sources.stage import (
|
|
||||||
PLAN_CONTEXT_SOURCES_CONNECTION,
|
|
||||||
PostUserEnrollmentStage,
|
|
||||||
)
|
|
||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.flows.models import Flow, Stage, in_memory_stage
|
from authentik.flows.models import Flow, Stage, in_memory_stage
|
||||||
from authentik.flows.planner import (
|
from authentik.flows.planner import (
|
||||||
|
@ -76,9 +68,7 @@ class SourceFlowManager:
|
||||||
# pylint: disable=too-many-return-statements
|
# pylint: disable=too-many-return-statements
|
||||||
def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]:
|
def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]:
|
||||||
"""decide which action should be taken"""
|
"""decide which action should be taken"""
|
||||||
new_connection = self.connection_type(
|
new_connection = self.connection_type(source=self.source, identifier=self.identifier)
|
||||||
source=self.source, identifier=self.identifier
|
|
||||||
)
|
|
||||||
# When request is authenticated, always link
|
# When request is authenticated, always link
|
||||||
if self.request.user.is_authenticated:
|
if self.request.user.is_authenticated:
|
||||||
new_connection.user = self.request.user
|
new_connection.user = self.request.user
|
||||||
|
@ -113,9 +103,7 @@ class SourceFlowManager:
|
||||||
SourceUserMatchingModes.USERNAME_DENY,
|
SourceUserMatchingModes.USERNAME_DENY,
|
||||||
]:
|
]:
|
||||||
if not self.enroll_info.get("username", None):
|
if not self.enroll_info.get("username", None):
|
||||||
self._logger.warning(
|
self._logger.warning("Refusing to use none username", source=self.source)
|
||||||
"Refusing to use none username", source=self.source
|
|
||||||
)
|
|
||||||
return Action.DENY, None
|
return Action.DENY, None
|
||||||
query = Q(username__exact=self.enroll_info.get("username", None))
|
query = Q(username__exact=self.enroll_info.get("username", None))
|
||||||
self._logger.debug("trying to link with existing user", query=query)
|
self._logger.debug("trying to link with existing user", query=query)
|
||||||
|
@ -229,10 +217,7 @@ class SourceFlowManager:
|
||||||
"""Login user and redirect."""
|
"""Login user and redirect."""
|
||||||
messages.success(
|
messages.success(
|
||||||
self.request,
|
self.request,
|
||||||
_(
|
_("Successfully authenticated with %(source)s!" % {"source": self.source.name}),
|
||||||
"Successfully authenticated with %(source)s!"
|
|
||||||
% {"source": self.source.name}
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user}
|
flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user}
|
||||||
return self._handle_login_flow(self.source.authentication_flow, **flow_kwargs)
|
return self._handle_login_flow(self.source.authentication_flow, **flow_kwargs)
|
||||||
|
@ -270,10 +255,7 @@ class SourceFlowManager:
|
||||||
"""User was not authenticated and previous request was not authenticated."""
|
"""User was not authenticated and previous request was not authenticated."""
|
||||||
messages.success(
|
messages.success(
|
||||||
self.request,
|
self.request,
|
||||||
_(
|
_("Successfully authenticated with %(source)s!" % {"source": self.source.name}),
|
||||||
"Successfully authenticated with %(source)s!"
|
|
||||||
% {"source": self.source.name}
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# We run the Flow planner here so we can pass the Pending user in the context
|
# We run the Flow planner here so we can pass the Pending user in the context
|
||||||
|
|
|
@ -27,9 +27,7 @@ def clean_expired_models(self: MonitoredTask):
|
||||||
for cls in ExpiringModel.__subclasses__():
|
for cls in ExpiringModel.__subclasses__():
|
||||||
cls: ExpiringModel
|
cls: ExpiringModel
|
||||||
objects = (
|
objects = (
|
||||||
cls.objects.all()
|
cls.objects.all().exclude(expiring=False).exclude(expiring=True, expires__gt=now())
|
||||||
.exclude(expiring=False)
|
|
||||||
.exclude(expiring=True, expires__gt=now())
|
|
||||||
)
|
)
|
||||||
for obj in objects:
|
for obj in objects:
|
||||||
obj.expire_action()
|
obj.expire_action()
|
||||||
|
|
|
@ -17,9 +17,7 @@ class TestApplicationsAPI(APITestCase):
|
||||||
self.denied = Application.objects.create(name="denied", slug="denied")
|
self.denied = Application.objects.create(name="denied", slug="denied")
|
||||||
PolicyBinding.objects.create(
|
PolicyBinding.objects.create(
|
||||||
target=self.denied,
|
target=self.denied,
|
||||||
policy=DummyPolicy.objects.create(
|
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
|
||||||
name="deny", result=False, wait_min=1, wait_max=2
|
|
||||||
),
|
|
||||||
order=0,
|
order=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -33,9 +31,7 @@ class TestApplicationsAPI(APITestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(force_str(response.content), {"messages": [], "passing": True})
|
||||||
force_str(response.content), {"messages": [], "passing": True}
|
|
||||||
)
|
|
||||||
response = self.client.get(
|
response = self.client.get(
|
||||||
reverse(
|
reverse(
|
||||||
"authentik_api:application-check-access",
|
"authentik_api:application-check-access",
|
||||||
|
@ -43,9 +39,7 @@ class TestApplicationsAPI(APITestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(force_str(response.content), {"messages": ["dummy"], "passing": False})
|
||||||
force_str(response.content), {"messages": ["dummy"], "passing": False}
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_list(self):
|
def test_list(self):
|
||||||
"""Test list operation without superuser_full_list"""
|
"""Test list operation without superuser_full_list"""
|
||||||
|
|
|
@ -46,9 +46,7 @@ class TestImpersonation(TestCase):
|
||||||
self.client.force_login(self.other_user)
|
self.client.force_login(self.other_user)
|
||||||
|
|
||||||
self.client.get(
|
self.client.get(
|
||||||
reverse(
|
reverse("authentik_core:impersonate-init", kwargs={"user_id": self.akadmin.pk})
|
||||||
"authentik_core:impersonate-init", kwargs={"user_id": self.akadmin.pk}
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
response = self.client.get(reverse("authentik_api:user-me"))
|
response = self.client.get(reverse("authentik_api:user-me"))
|
||||||
|
|
|
@ -22,9 +22,7 @@ class TestModels(TestCase):
|
||||||
|
|
||||||
def test_token_expire_no_expire(self):
|
def test_token_expire_no_expire(self):
|
||||||
"""Test token expiring with "expiring" set"""
|
"""Test token expiring with "expiring" set"""
|
||||||
token = Token.objects.create(
|
token = Token.objects.create(expires=now(), user=get_anonymous_user(), expiring=False)
|
||||||
expires=now(), user=get_anonymous_user(), expiring=False
|
|
||||||
)
|
|
||||||
sleep(0.5)
|
sleep(0.5)
|
||||||
self.assertFalse(token.is_expired)
|
self.assertFalse(token.is_expired)
|
||||||
|
|
||||||
|
|
|
@ -16,9 +16,7 @@ class TestPropertyMappings(TestCase):
|
||||||
|
|
||||||
def test_expression(self):
|
def test_expression(self):
|
||||||
"""Test expression"""
|
"""Test expression"""
|
||||||
mapping = PropertyMapping.objects.create(
|
mapping = PropertyMapping.objects.create(name="test", expression="return 'test'")
|
||||||
name="test", expression="return 'test'"
|
|
||||||
)
|
|
||||||
self.assertEqual(mapping.evaluate(None, None), "test")
|
self.assertEqual(mapping.evaluate(None, None), "test")
|
||||||
|
|
||||||
def test_expression_syntax(self):
|
def test_expression_syntax(self):
|
||||||
|
|
|
@ -23,9 +23,7 @@ class TestPropertyMappingAPI(APITestCase):
|
||||||
def test_test_call(self):
|
def test_test_call(self):
|
||||||
"""Test PropertMappings's test endpoint"""
|
"""Test PropertMappings's test endpoint"""
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
reverse(
|
reverse("authentik_api:propertymapping-test", kwargs={"pk": self.mapping.pk}),
|
||||||
"authentik_api:propertymapping-test", kwargs={"pk": self.mapping.pk}
|
|
||||||
),
|
|
||||||
data={
|
data={
|
||||||
"user": self.user.pk,
|
"user": self.user.pk,
|
||||||
},
|
},
|
||||||
|
|
|
@ -4,12 +4,7 @@ from django.utils.timezone import now
|
||||||
from guardian.shortcuts import get_anonymous_user
|
from guardian.shortcuts import get_anonymous_user
|
||||||
from rest_framework.test import APITestCase
|
from rest_framework.test import APITestCase
|
||||||
|
|
||||||
from authentik.core.models import (
|
from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User
|
||||||
USER_ATTRIBUTE_TOKEN_EXPIRING,
|
|
||||||
Token,
|
|
||||||
TokenIntents,
|
|
||||||
User,
|
|
||||||
)
|
|
||||||
from authentik.core.tasks import clean_expired_models
|
from authentik.core.tasks import clean_expired_models
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -5,10 +5,7 @@ from django.shortcuts import get_object_or_404, redirect
|
||||||
from django.views import View
|
from django.views import View
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.core.middleware import (
|
from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER
|
||||||
SESSION_IMPERSONATE_ORIGINAL_USER,
|
|
||||||
SESSION_IMPERSONATE_USER,
|
|
||||||
)
|
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
|
|
||||||
|
@ -21,9 +18,7 @@ class ImpersonateInitView(View):
|
||||||
def get(self, request: HttpRequest, user_id: int) -> HttpResponse:
|
def get(self, request: HttpRequest, user_id: int) -> HttpResponse:
|
||||||
"""Impersonation handler, checks permissions"""
|
"""Impersonation handler, checks permissions"""
|
||||||
if not request.user.has_perm("impersonate"):
|
if not request.user.has_perm("impersonate"):
|
||||||
LOGGER.debug(
|
LOGGER.debug("User attempted to impersonate without permissions", user=request.user)
|
||||||
"User attempted to impersonate without permissions", user=request.user
|
|
||||||
)
|
|
||||||
return HttpResponse("Unauthorized", status=401)
|
return HttpResponse("Unauthorized", status=401)
|
||||||
|
|
||||||
user_to_be = get_object_or_404(User, pk=user_id)
|
user_to_be = get_object_or_404(User, pk=user_id)
|
||||||
|
|
|
@ -14,9 +14,7 @@ class EndSessionView(TemplateView, PolicyAccessView):
|
||||||
template_name = "if/end_session.html"
|
template_name = "if/end_session.html"
|
||||||
|
|
||||||
def resolve_provider_application(self):
|
def resolve_provider_application(self):
|
||||||
self.application = get_object_or_404(
|
self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"])
|
||||||
Application, slug=self.kwargs["application_slug"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
|
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
|
||||||
context = super().get_context_data(**kwargs)
|
context = super().get_context_data(**kwargs)
|
||||||
|
|
|
@ -10,12 +10,7 @@ from django_filters.filters import BooleanFilter
|
||||||
from drf_spectacular.types import OpenApiTypes
|
from drf_spectacular.types import OpenApiTypes
|
||||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
from rest_framework.fields import (
|
from rest_framework.fields import CharField, DateTimeField, IntegerField, SerializerMethodField
|
||||||
CharField,
|
|
||||||
DateTimeField,
|
|
||||||
IntegerField,
|
|
||||||
SerializerMethodField,
|
|
||||||
)
|
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.serializers import ModelSerializer, ValidationError
|
from rest_framework.serializers import ModelSerializer, ValidationError
|
||||||
|
@ -86,9 +81,7 @@ class CertificateKeyPairSerializer(ModelSerializer):
|
||||||
backend=default_backend(),
|
backend=default_backend(),
|
||||||
)
|
)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
raise ValidationError(
|
raise ValidationError("Unable to load private key (possibly encrypted?).")
|
||||||
"Unable to load private key (possibly encrypted?)."
|
|
||||||
)
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
@ -123,9 +116,7 @@ class CertificateGenerationSerializer(PassiveSerializer):
|
||||||
"""Certificate generation parameters"""
|
"""Certificate generation parameters"""
|
||||||
|
|
||||||
common_name = CharField()
|
common_name = CharField()
|
||||||
subject_alt_name = CharField(
|
subject_alt_name = CharField(required=False, allow_blank=True, label=_("Subject-alt name"))
|
||||||
required=False, allow_blank=True, label=_("Subject-alt name")
|
|
||||||
)
|
|
||||||
validity_days = IntegerField(initial=365)
|
validity_days = IntegerField(initial=365)
|
||||||
|
|
||||||
|
|
||||||
|
@ -170,9 +161,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
|
||||||
builder = CertificateBuilder()
|
builder = CertificateBuilder()
|
||||||
builder.common_name = data.validated_data["common_name"]
|
builder.common_name = data.validated_data["common_name"]
|
||||||
builder.build(
|
builder.build(
|
||||||
subject_alt_names=data.validated_data.get("subject_alt_name", "").split(
|
subject_alt_names=data.validated_data.get("subject_alt_name", "").split(","),
|
||||||
","
|
|
||||||
),
|
|
||||||
validity_days=int(data.validated_data["validity_days"]),
|
validity_days=int(data.validated_data["validity_days"]),
|
||||||
)
|
)
|
||||||
instance = builder.save()
|
instance = builder.save()
|
||||||
|
@ -208,9 +197,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
|
||||||
"Content-Disposition"
|
"Content-Disposition"
|
||||||
] = f'attachment; filename="{certificate.name}_certificate.pem"'
|
] = f'attachment; filename="{certificate.name}_certificate.pem"'
|
||||||
return response
|
return response
|
||||||
return Response(
|
return Response(CertificateDataSerializer({"data": certificate.certificate_data}).data)
|
||||||
CertificateDataSerializer({"data": certificate.certificate_data}).data
|
|
||||||
)
|
|
||||||
|
|
||||||
@extend_schema(
|
@extend_schema(
|
||||||
parameters=[
|
parameters=[
|
||||||
|
@ -234,9 +221,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
|
||||||
).from_http(request)
|
).from_http(request)
|
||||||
if "download" in request._request.GET:
|
if "download" in request._request.GET:
|
||||||
# Mime type from https://pki-tutorial.readthedocs.io/en/latest/mime.html
|
# Mime type from https://pki-tutorial.readthedocs.io/en/latest/mime.html
|
||||||
response = HttpResponse(
|
response = HttpResponse(certificate.key_data, content_type="application/x-pem-file")
|
||||||
certificate.key_data, content_type="application/x-pem-file"
|
|
||||||
)
|
|
||||||
response[
|
response[
|
||||||
"Content-Disposition"
|
"Content-Disposition"
|
||||||
] = f'attachment; filename="{certificate.name}_private_key.pem"'
|
] = f'attachment; filename="{certificate.name}_private_key.pem"'
|
||||||
|
|
|
@ -46,9 +46,7 @@ class CertificateBuilder:
|
||||||
public_exponent=65537, key_size=2048, backend=default_backend()
|
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||||
)
|
)
|
||||||
self.__public_key = self.__private_key.public_key()
|
self.__public_key = self.__private_key.public_key()
|
||||||
alt_names: list[x509.GeneralName] = [
|
alt_names: list[x509.GeneralName] = [x509.DNSName(x) for x in subject_alt_names or []]
|
||||||
x509.DNSName(x) for x in subject_alt_names or []
|
|
||||||
]
|
|
||||||
self.__builder = (
|
self.__builder = (
|
||||||
x509.CertificateBuilder()
|
x509.CertificateBuilder()
|
||||||
.subject_name(
|
.subject_name(
|
||||||
|
@ -59,9 +57,7 @@ class CertificateBuilder:
|
||||||
self.common_name,
|
self.common_name,
|
||||||
),
|
),
|
||||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "authentik"),
|
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "authentik"),
|
||||||
x509.NameAttribute(
|
x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed"),
|
||||||
NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed"
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -77,9 +73,7 @@ class CertificateBuilder:
|
||||||
)
|
)
|
||||||
.add_extension(x509.SubjectAlternativeName(alt_names), critical=True)
|
.add_extension(x509.SubjectAlternativeName(alt_names), critical=True)
|
||||||
.not_valid_before(datetime.datetime.today() - one_day)
|
.not_valid_before(datetime.datetime.today() - one_day)
|
||||||
.not_valid_after(
|
.not_valid_after(datetime.datetime.today() + datetime.timedelta(days=validity_days))
|
||||||
datetime.datetime.today() + datetime.timedelta(days=validity_days)
|
|
||||||
)
|
|
||||||
.serial_number(int(uuid.uuid4()))
|
.serial_number(int(uuid.uuid4()))
|
||||||
.public_key(self.__public_key)
|
.public_key(self.__public_key)
|
||||||
)
|
)
|
||||||
|
|
|
@ -57,9 +57,7 @@ class CertificateKeyPair(CreatedUpdatedModel):
|
||||||
if not self._private_key and self._private_key != "":
|
if not self._private_key and self._private_key != "":
|
||||||
try:
|
try:
|
||||||
self._private_key = load_pem_private_key(
|
self._private_key = load_pem_private_key(
|
||||||
str.encode(
|
str.encode("\n".join([x.strip() for x in self.key_data.split("\n")])),
|
||||||
"\n".join([x.strip() for x in self.key_data.split("\n")])
|
|
||||||
),
|
|
||||||
password=None,
|
password=None,
|
||||||
backend=default_backend(),
|
backend=default_backend(),
|
||||||
)
|
)
|
||||||
|
@ -70,25 +68,19 @@ class CertificateKeyPair(CreatedUpdatedModel):
|
||||||
@property
|
@property
|
||||||
def fingerprint_sha256(self) -> str:
|
def fingerprint_sha256(self) -> str:
|
||||||
"""Get SHA256 Fingerprint of certificate_data"""
|
"""Get SHA256 Fingerprint of certificate_data"""
|
||||||
return hexlify(self.certificate.fingerprint(hashes.SHA256()), ":").decode(
|
return hexlify(self.certificate.fingerprint(hashes.SHA256()), ":").decode("utf-8")
|
||||||
"utf-8"
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def fingerprint_sha1(self) -> str:
|
def fingerprint_sha1(self) -> str:
|
||||||
"""Get SHA1 Fingerprint of certificate_data"""
|
"""Get SHA1 Fingerprint of certificate_data"""
|
||||||
return hexlify(
|
return hexlify(self.certificate.fingerprint(hashes.SHA1()), ":").decode("utf-8") # nosec
|
||||||
self.certificate.fingerprint(hashes.SHA1()), ":" # nosec
|
|
||||||
).decode("utf-8")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def kid(self):
|
def kid(self):
|
||||||
"""Get Key ID used for JWKS"""
|
"""Get Key ID used for JWKS"""
|
||||||
return "{0}".format(
|
return "{0}".format(
|
||||||
md5(self.key_data.encode("utf-8")).hexdigest() # nosec
|
md5(self.key_data.encode("utf-8")).hexdigest() if self.key_data else ""
|
||||||
if self.key_data
|
) # nosec
|
||||||
else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"Certificate-Key Pair {self.name}"
|
return f"Certificate-Key Pair {self.name}"
|
||||||
|
|
|
@ -143,7 +143,5 @@ class EventViewSet(ModelViewSet):
|
||||||
"""Get all actions"""
|
"""Get all actions"""
|
||||||
data = []
|
data = []
|
||||||
for value, name in EventAction.choices:
|
for value, name in EventAction.choices:
|
||||||
data.append(
|
data.append({"name": name, "description": "", "component": value, "model_name": ""})
|
||||||
{"name": name, "description": "", "component": value, "model_name": ""}
|
|
||||||
)
|
|
||||||
return Response(TypeCreateSerializer(data, many=True).data)
|
return Response(TypeCreateSerializer(data, many=True).data)
|
||||||
|
|
|
@ -29,12 +29,8 @@ class AuditMiddleware:
|
||||||
|
|
||||||
def __call__(self, request: HttpRequest) -> HttpResponse:
|
def __call__(self, request: HttpRequest) -> HttpResponse:
|
||||||
# Connect signal for automatic logging
|
# Connect signal for automatic logging
|
||||||
if hasattr(request, "user") and getattr(
|
if hasattr(request, "user") and getattr(request.user, "is_authenticated", False):
|
||||||
request.user, "is_authenticated", False
|
post_save_handler = partial(self.post_save_handler, user=request.user, request=request)
|
||||||
):
|
|
||||||
post_save_handler = partial(
|
|
||||||
self.post_save_handler, user=request.user, request=request
|
|
||||||
)
|
|
||||||
pre_delete_handler = partial(
|
pre_delete_handler = partial(
|
||||||
self.pre_delete_handler, user=request.user, request=request
|
self.pre_delete_handler, user=request.user, request=request
|
||||||
)
|
)
|
||||||
|
@ -94,13 +90,9 @@ class AuditMiddleware:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def pre_delete_handler(
|
def pre_delete_handler(user: User, request: HttpRequest, sender, instance: Model, **_):
|
||||||
user: User, request: HttpRequest, sender, instance: Model, **_
|
|
||||||
):
|
|
||||||
"""Signal handler for all object's pre_delete"""
|
"""Signal handler for all object's pre_delete"""
|
||||||
if isinstance(
|
if isinstance(instance, (Event, Notification, UserObjectPermission)): # pragma: no cover
|
||||||
instance, (Event, Notification, UserObjectPermission)
|
|
||||||
): # pragma: no cover
|
|
||||||
return
|
return
|
||||||
|
|
||||||
EventNewThread(
|
EventNewThread(
|
||||||
|
|
|
@ -14,9 +14,7 @@ def convert_user_to_json(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
event.delete()
|
event.delete()
|
||||||
# Because event objects cannot be updated, we have to re-create them
|
# Because event objects cannot be updated, we have to re-create them
|
||||||
event.pk = None
|
event.pk = None
|
||||||
event.user_json = (
|
event.user_json = authentik.events.models.get_user(event.user) if event.user else {}
|
||||||
authentik.events.models.get_user(event.user) if event.user else {}
|
|
||||||
)
|
|
||||||
event._state.adding = True
|
event._state.adding = True
|
||||||
event.save()
|
event.save()
|
||||||
|
|
||||||
|
@ -58,7 +56,5 @@ class Migration(migrations.Migration):
|
||||||
model_name="event",
|
model_name="event",
|
||||||
name="user",
|
name="user",
|
||||||
),
|
),
|
||||||
migrations.RenameField(
|
migrations.RenameField(model_name="event", old_name="user_json", new_name="user"),
|
||||||
model_name="event", old_name="user_json", new_name="user"
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
|
|
@ -11,16 +11,12 @@ def notify_configuration_error(apps: Apps, schema_editor: BaseDatabaseSchemaEdit
|
||||||
db_alias = schema_editor.connection.alias
|
db_alias = schema_editor.connection.alias
|
||||||
Group = apps.get_model("authentik_core", "Group")
|
Group = apps.get_model("authentik_core", "Group")
|
||||||
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
|
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
|
||||||
EventMatcherPolicy = apps.get_model(
|
EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy")
|
||||||
"authentik_policies_event_matcher", "EventMatcherPolicy"
|
|
||||||
)
|
|
||||||
NotificationRule = apps.get_model("authentik_events", "NotificationRule")
|
NotificationRule = apps.get_model("authentik_events", "NotificationRule")
|
||||||
NotificationTransport = apps.get_model("authentik_events", "NotificationTransport")
|
NotificationTransport = apps.get_model("authentik_events", "NotificationTransport")
|
||||||
|
|
||||||
admin_group = (
|
admin_group = (
|
||||||
Group.objects.using(db_alias)
|
Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first()
|
||||||
.filter(name="authentik Admins", is_superuser=True)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create(
|
policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create(
|
||||||
|
@ -32,9 +28,7 @@ def notify_configuration_error(apps: Apps, schema_editor: BaseDatabaseSchemaEdit
|
||||||
defaults={"group": admin_group, "severity": NotificationSeverity.ALERT},
|
defaults={"group": admin_group, "severity": NotificationSeverity.ALERT},
|
||||||
)
|
)
|
||||||
trigger.transports.set(
|
trigger.transports.set(
|
||||||
NotificationTransport.objects.using(db_alias).filter(
|
NotificationTransport.objects.using(db_alias).filter(name="default-email-transport")
|
||||||
name="default-email-transport"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
trigger.save()
|
trigger.save()
|
||||||
PolicyBinding.objects.using(db_alias).update_or_create(
|
PolicyBinding.objects.using(db_alias).update_or_create(
|
||||||
|
@ -50,16 +44,12 @@ def notify_update(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
db_alias = schema_editor.connection.alias
|
db_alias = schema_editor.connection.alias
|
||||||
Group = apps.get_model("authentik_core", "Group")
|
Group = apps.get_model("authentik_core", "Group")
|
||||||
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
|
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
|
||||||
EventMatcherPolicy = apps.get_model(
|
EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy")
|
||||||
"authentik_policies_event_matcher", "EventMatcherPolicy"
|
|
||||||
)
|
|
||||||
NotificationRule = apps.get_model("authentik_events", "NotificationRule")
|
NotificationRule = apps.get_model("authentik_events", "NotificationRule")
|
||||||
NotificationTransport = apps.get_model("authentik_events", "NotificationTransport")
|
NotificationTransport = apps.get_model("authentik_events", "NotificationTransport")
|
||||||
|
|
||||||
admin_group = (
|
admin_group = (
|
||||||
Group.objects.using(db_alias)
|
Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first()
|
||||||
.filter(name="authentik Admins", is_superuser=True)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create(
|
policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create(
|
||||||
|
@ -71,9 +61,7 @@ def notify_update(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
defaults={"group": admin_group, "severity": NotificationSeverity.ALERT},
|
defaults={"group": admin_group, "severity": NotificationSeverity.ALERT},
|
||||||
)
|
)
|
||||||
trigger.transports.set(
|
trigger.transports.set(
|
||||||
NotificationTransport.objects.using(db_alias).filter(
|
NotificationTransport.objects.using(db_alias).filter(name="default-email-transport")
|
||||||
name="default-email-transport"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
trigger.save()
|
trigger.save()
|
||||||
PolicyBinding.objects.using(db_alias).update_or_create(
|
PolicyBinding.objects.using(db_alias).update_or_create(
|
||||||
|
@ -89,16 +77,12 @@ def notify_exception(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
db_alias = schema_editor.connection.alias
|
db_alias = schema_editor.connection.alias
|
||||||
Group = apps.get_model("authentik_core", "Group")
|
Group = apps.get_model("authentik_core", "Group")
|
||||||
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
|
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
|
||||||
EventMatcherPolicy = apps.get_model(
|
EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy")
|
||||||
"authentik_policies_event_matcher", "EventMatcherPolicy"
|
|
||||||
)
|
|
||||||
NotificationRule = apps.get_model("authentik_events", "NotificationRule")
|
NotificationRule = apps.get_model("authentik_events", "NotificationRule")
|
||||||
NotificationTransport = apps.get_model("authentik_events", "NotificationTransport")
|
NotificationTransport = apps.get_model("authentik_events", "NotificationTransport")
|
||||||
|
|
||||||
admin_group = (
|
admin_group = (
|
||||||
Group.objects.using(db_alias)
|
Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first()
|
||||||
.filter(name="authentik Admins", is_superuser=True)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
policy_policy_exc, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create(
|
policy_policy_exc, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create(
|
||||||
|
@ -114,9 +98,7 @@ def notify_exception(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
defaults={"group": admin_group, "severity": NotificationSeverity.ALERT},
|
defaults={"group": admin_group, "severity": NotificationSeverity.ALERT},
|
||||||
)
|
)
|
||||||
trigger.transports.set(
|
trigger.transports.set(
|
||||||
NotificationTransport.objects.using(db_alias).filter(
|
NotificationTransport.objects.using(db_alias).filter(name="default-email-transport")
|
||||||
name="default-email-transport"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
trigger.save()
|
trigger.save()
|
||||||
PolicyBinding.objects.using(db_alias).update_or_create(
|
PolicyBinding.objects.using(db_alias).update_or_create(
|
||||||
|
|
|
@ -38,9 +38,7 @@ def progress_bar(
|
||||||
|
|
||||||
def print_progress_bar(iteration):
|
def print_progress_bar(iteration):
|
||||||
"""Progress Bar Printing Function"""
|
"""Progress Bar Printing Function"""
|
||||||
percent = ("{0:." + str(decimals) + "f}").format(
|
percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
|
||||||
100 * (iteration / float(total))
|
|
||||||
)
|
|
||||||
filledLength = int(length * iteration // total)
|
filledLength = int(length * iteration // total)
|
||||||
bar = fill * filledLength + "-" * (length - filledLength)
|
bar = fill * filledLength + "-" * (length - filledLength)
|
||||||
print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=print_end)
|
print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=print_end)
|
||||||
|
@ -78,9 +76,7 @@ class Migration(migrations.Migration):
|
||||||
migrations.AddField(
|
migrations.AddField(
|
||||||
model_name="event",
|
model_name="event",
|
||||||
name="expires",
|
name="expires",
|
||||||
field=models.DateTimeField(
|
field=models.DateTimeField(default=authentik.events.models.default_event_duration),
|
||||||
default=authentik.events.models.default_event_duration
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
migrations.AddField(
|
migrations.AddField(
|
||||||
model_name="event",
|
model_name="event",
|
||||||
|
|
|
@ -15,9 +15,7 @@ class Migration(migrations.Migration):
|
||||||
migrations.AddField(
|
migrations.AddField(
|
||||||
model_name="event",
|
model_name="event",
|
||||||
name="tenant",
|
name="tenant",
|
||||||
field=models.JSONField(
|
field=models.JSONField(blank=True, default=authentik.events.models.default_tenant),
|
||||||
blank=True, default=authentik.events.models.default_tenant
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
migrations.AlterField(
|
migrations.AlterField(
|
||||||
model_name="event",
|
model_name="event",
|
||||||
|
|
|
@ -15,10 +15,7 @@ from requests import RequestException, post
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik import __version__
|
from authentik import __version__
|
||||||
from authentik.core.middleware import (
|
from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER
|
||||||
SESSION_IMPERSONATE_ORIGINAL_USER,
|
|
||||||
SESSION_IMPERSONATE_USER,
|
|
||||||
)
|
|
||||||
from authentik.core.models import ExpiringModel, Group, User
|
from authentik.core.models import ExpiringModel, Group, User
|
||||||
from authentik.events.geo import GEOIP_READER
|
from authentik.events.geo import GEOIP_READER
|
||||||
from authentik.events.utils import cleanse_dict, get_user, model_to_dict, sanitize_dict
|
from authentik.events.utils import cleanse_dict, get_user, model_to_dict, sanitize_dict
|
||||||
|
@ -159,9 +156,7 @@ class Event(ExpiringModel):
|
||||||
if hasattr(request, "user"):
|
if hasattr(request, "user"):
|
||||||
original_user = None
|
original_user = None
|
||||||
if hasattr(request, "session"):
|
if hasattr(request, "session"):
|
||||||
original_user = request.session.get(
|
original_user = request.session.get(SESSION_IMPERSONATE_ORIGINAL_USER, None)
|
||||||
SESSION_IMPERSONATE_ORIGINAL_USER, None
|
|
||||||
)
|
|
||||||
self.user = get_user(request.user, original_user)
|
self.user = get_user(request.user, original_user)
|
||||||
if user:
|
if user:
|
||||||
self.user = get_user(user)
|
self.user = get_user(user)
|
||||||
|
@ -169,9 +164,7 @@ class Event(ExpiringModel):
|
||||||
if hasattr(request, "session"):
|
if hasattr(request, "session"):
|
||||||
if SESSION_IMPERSONATE_ORIGINAL_USER in request.session:
|
if SESSION_IMPERSONATE_ORIGINAL_USER in request.session:
|
||||||
self.user = get_user(request.session[SESSION_IMPERSONATE_ORIGINAL_USER])
|
self.user = get_user(request.session[SESSION_IMPERSONATE_ORIGINAL_USER])
|
||||||
self.user["on_behalf_of"] = get_user(
|
self.user["on_behalf_of"] = get_user(request.session[SESSION_IMPERSONATE_USER])
|
||||||
request.session[SESSION_IMPERSONATE_USER]
|
|
||||||
)
|
|
||||||
# User 255.255.255.255 as fallback if IP cannot be determined
|
# User 255.255.255.255 as fallback if IP cannot be determined
|
||||||
self.client_ip = get_client_ip(request)
|
self.client_ip = get_client_ip(request)
|
||||||
# Apply GeoIP Data, when enabled
|
# Apply GeoIP Data, when enabled
|
||||||
|
@ -414,9 +407,7 @@ class NotificationRule(PolicyBindingModel):
|
||||||
severity = models.TextField(
|
severity = models.TextField(
|
||||||
choices=NotificationSeverity.choices,
|
choices=NotificationSeverity.choices,
|
||||||
default=NotificationSeverity.NOTICE,
|
default=NotificationSeverity.NOTICE,
|
||||||
help_text=_(
|
help_text=_("Controls which severity level the created notifications will have."),
|
||||||
"Controls which severity level the created notifications will have."
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
group = models.ForeignKey(
|
group = models.ForeignKey(
|
||||||
Group,
|
Group,
|
||||||
|
|
|
@ -135,9 +135,7 @@ class MonitoredTask(Task):
|
||||||
self._result = result
|
self._result = result
|
||||||
|
|
||||||
# pylint: disable=too-many-arguments
|
# pylint: disable=too-many-arguments
|
||||||
def after_return(
|
def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo):
|
||||||
self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo
|
|
||||||
):
|
|
||||||
if self._result:
|
if self._result:
|
||||||
if not self._result.uid:
|
if not self._result.uid:
|
||||||
self._result.uid = self._uid
|
self._result.uid = self._uid
|
||||||
|
@ -159,9 +157,7 @@ class MonitoredTask(Task):
|
||||||
# pylint: disable=too-many-arguments
|
# pylint: disable=too-many-arguments
|
||||||
def on_failure(self, exc, task_id, args, kwargs, einfo):
|
def on_failure(self, exc, task_id, args, kwargs, einfo):
|
||||||
if not self._result:
|
if not self._result:
|
||||||
self._result = TaskResult(
|
self._result = TaskResult(status=TaskResultStatus.ERROR, messages=[str(exc)])
|
||||||
status=TaskResultStatus.ERROR, messages=[str(exc)]
|
|
||||||
)
|
|
||||||
if not self._result.uid:
|
if not self._result.uid:
|
||||||
self._result.uid = self._uid
|
self._result.uid = self._uid
|
||||||
TaskInfo(
|
TaskInfo(
|
||||||
|
@ -179,8 +175,7 @@ class MonitoredTask(Task):
|
||||||
Event.new(
|
Event.new(
|
||||||
EventAction.SYSTEM_TASK_EXCEPTION,
|
EventAction.SYSTEM_TASK_EXCEPTION,
|
||||||
message=(
|
message=(
|
||||||
f"Task {self.__name__} encountered an error: "
|
f"Task {self.__name__} encountered an error: " "\n".join(self._result.messages)
|
||||||
"\n".join(self._result.messages)
|
|
||||||
),
|
),
|
||||||
).save()
|
).save()
|
||||||
return super().on_failure(exc, task_id, args, kwargs, einfo=einfo)
|
return super().on_failure(exc, task_id, args, kwargs, einfo=einfo)
|
||||||
|
|
|
@ -2,11 +2,7 @@
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from django.contrib.auth.signals import (
|
from django.contrib.auth.signals import user_logged_in, user_logged_out, user_login_failed
|
||||||
user_logged_in,
|
|
||||||
user_logged_out,
|
|
||||||
user_login_failed,
|
|
||||||
)
|
|
||||||
from django.db.models.signals import post_save
|
from django.db.models.signals import post_save
|
||||||
from django.dispatch import receiver
|
from django.dispatch import receiver
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
|
@ -30,9 +26,7 @@ class EventNewThread(Thread):
|
||||||
kwargs: dict[str, Any]
|
kwargs: dict[str, Any]
|
||||||
user: Optional[User] = None
|
user: Optional[User] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, action: str, request: HttpRequest, user: Optional[User] = None, **kwargs):
|
||||||
self, action: str, request: HttpRequest, user: Optional[User] = None, **kwargs
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.action = action
|
self.action = action
|
||||||
self.request = request
|
self.request = request
|
||||||
|
@ -68,9 +62,7 @@ def on_user_logged_out(sender, request: HttpRequest, user: User, **_):
|
||||||
|
|
||||||
@receiver(user_write)
|
@receiver(user_write)
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def on_user_write(
|
def on_user_write(sender, request: HttpRequest, user: User, data: dict[str, Any], **kwargs):
|
||||||
sender, request: HttpRequest, user: User, data: dict[str, Any], **kwargs
|
|
||||||
):
|
|
||||||
"""Log User write"""
|
"""Log User write"""
|
||||||
thread = EventNewThread(EventAction.USER_WRITE, request, **data)
|
thread = EventNewThread(EventAction.USER_WRITE, request, **data)
|
||||||
thread.kwargs["created"] = kwargs.get("created", False)
|
thread.kwargs["created"] = kwargs.get("created", False)
|
||||||
|
@ -80,9 +72,7 @@ def on_user_write(
|
||||||
|
|
||||||
@receiver(user_login_failed)
|
@receiver(user_login_failed)
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def on_user_login_failed(
|
def on_user_login_failed(sender, credentials: dict[str, str], request: HttpRequest, **_):
|
||||||
sender, credentials: dict[str, str], request: HttpRequest, **_
|
|
||||||
):
|
|
||||||
"""Failed Login"""
|
"""Failed Login"""
|
||||||
thread = EventNewThread(EventAction.LOGIN_FAILED, request, **credentials)
|
thread = EventNewThread(EventAction.LOGIN_FAILED, request, **credentials)
|
||||||
thread.run()
|
thread.run()
|
||||||
|
|
|
@ -22,9 +22,7 @@ LOGGER = get_logger()
|
||||||
def event_notification_handler(event_uuid: str):
|
def event_notification_handler(event_uuid: str):
|
||||||
"""Start task for each trigger definition"""
|
"""Start task for each trigger definition"""
|
||||||
for trigger in NotificationRule.objects.all():
|
for trigger in NotificationRule.objects.all():
|
||||||
event_trigger_handler.apply_async(
|
event_trigger_handler.apply_async(args=[event_uuid, trigger.name], queue="authentik_events")
|
||||||
args=[event_uuid, trigger.name], queue="authentik_events"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@CELERY_APP.task()
|
@CELERY_APP.task()
|
||||||
|
@ -43,17 +41,13 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
|
||||||
if "policy_uuid" in event.context:
|
if "policy_uuid" in event.context:
|
||||||
policy_uuid = event.context["policy_uuid"]
|
policy_uuid = event.context["policy_uuid"]
|
||||||
if PolicyBinding.objects.filter(
|
if PolicyBinding.objects.filter(
|
||||||
target__in=NotificationRule.objects.all().values_list(
|
target__in=NotificationRule.objects.all().values_list("pbm_uuid", flat=True),
|
||||||
"pbm_uuid", flat=True
|
|
||||||
),
|
|
||||||
policy=policy_uuid,
|
policy=policy_uuid,
|
||||||
).exists():
|
).exists():
|
||||||
# If policy that caused this event to be created is attached
|
# If policy that caused this event to be created is attached
|
||||||
# to *any* NotificationRule, we return early.
|
# to *any* NotificationRule, we return early.
|
||||||
# This is the most effective way to prevent infinite loops.
|
# This is the most effective way to prevent infinite loops.
|
||||||
LOGGER.debug(
|
LOGGER.debug("e(trigger): attempting to prevent infinite loop", trigger=trigger)
|
||||||
"e(trigger): attempting to prevent infinite loop", trigger=trigger
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if not trigger.group:
|
if not trigger.group:
|
||||||
|
@ -62,9 +56,7 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
|
||||||
|
|
||||||
LOGGER.debug("e(trigger): checking if trigger applies", trigger=trigger)
|
LOGGER.debug("e(trigger): checking if trigger applies", trigger=trigger)
|
||||||
try:
|
try:
|
||||||
user = (
|
user = User.objects.filter(pk=event.user.get("pk")).first() or get_anonymous_user()
|
||||||
User.objects.filter(pk=event.user.get("pk")).first() or get_anonymous_user()
|
|
||||||
)
|
|
||||||
except User.DoesNotExist:
|
except User.DoesNotExist:
|
||||||
LOGGER.warning("e(trigger): failed to get user", trigger=trigger)
|
LOGGER.warning("e(trigger): failed to get user", trigger=trigger)
|
||||||
return
|
return
|
||||||
|
@ -99,20 +91,14 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
|
||||||
retry_backoff=True,
|
retry_backoff=True,
|
||||||
base=MonitoredTask,
|
base=MonitoredTask,
|
||||||
)
|
)
|
||||||
def notification_transport(
|
def notification_transport(self: MonitoredTask, notification_pk: int, transport_pk: int):
|
||||||
self: MonitoredTask, notification_pk: int, transport_pk: int
|
|
||||||
):
|
|
||||||
"""Send notification over specified transport"""
|
"""Send notification over specified transport"""
|
||||||
self.save_on_success = False
|
self.save_on_success = False
|
||||||
try:
|
try:
|
||||||
notification: Notification = Notification.objects.filter(
|
notification: Notification = Notification.objects.filter(pk=notification_pk).first()
|
||||||
pk=notification_pk
|
|
||||||
).first()
|
|
||||||
if not notification:
|
if not notification:
|
||||||
return
|
return
|
||||||
transport: NotificationTransport = NotificationTransport.objects.get(
|
transport: NotificationTransport = NotificationTransport.objects.get(pk=transport_pk)
|
||||||
pk=transport_pk
|
|
||||||
)
|
|
||||||
transport.send(notification)
|
transport.send(notification)
|
||||||
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL))
|
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL))
|
||||||
except NotificationTransportError as exc:
|
except NotificationTransportError as exc:
|
||||||
|
|
|
@ -38,7 +38,5 @@ class TestEvents(TestCase):
|
||||||
event = Event.new("unittest", model=temp_model)
|
event = Event.new("unittest", model=temp_model)
|
||||||
event.save() # We save to ensure nothing is un-saveable
|
event.save() # We save to ensure nothing is un-saveable
|
||||||
model_content_type = ContentType.objects.get_for_model(temp_model)
|
model_content_type = ContentType.objects.get_for_model(temp_model)
|
||||||
self.assertEqual(
|
self.assertEqual(event.context.get("model").get("app"), model_content_type.app_label)
|
||||||
event.context.get("model").get("app"), model_content_type.app_label
|
|
||||||
)
|
|
||||||
self.assertEqual(event.context.get("model").get("pk"), temp_model.pk.hex)
|
self.assertEqual(event.context.get("model").get("pk"), temp_model.pk.hex)
|
||||||
|
|
|
@ -81,12 +81,8 @@ class TestEventsNotifications(TestCase):
|
||||||
|
|
||||||
execute_mock = MagicMock()
|
execute_mock = MagicMock()
|
||||||
passes = MagicMock(side_effect=PolicyException)
|
passes = MagicMock(side_effect=PolicyException)
|
||||||
with patch(
|
with patch("authentik.policies.event_matcher.models.EventMatcherPolicy.passes", passes):
|
||||||
"authentik.policies.event_matcher.models.EventMatcherPolicy.passes", passes
|
with patch("authentik.events.models.NotificationTransport.send", execute_mock):
|
||||||
):
|
|
||||||
with patch(
|
|
||||||
"authentik.events.models.NotificationTransport.send", execute_mock
|
|
||||||
):
|
|
||||||
Event.new(EventAction.CUSTOM_PREFIX).save()
|
Event.new(EventAction.CUSTOM_PREFIX).save()
|
||||||
self.assertEqual(passes.call_count, 1)
|
self.assertEqual(passes.call_count, 1)
|
||||||
|
|
||||||
|
@ -96,9 +92,7 @@ class TestEventsNotifications(TestCase):
|
||||||
self.group.users.add(user2)
|
self.group.users.add(user2)
|
||||||
self.group.save()
|
self.group.save()
|
||||||
|
|
||||||
transport = NotificationTransport.objects.create(
|
transport = NotificationTransport.objects.create(name="transport", send_once=True)
|
||||||
name="transport", send_once=True
|
|
||||||
)
|
|
||||||
NotificationRule.objects.filter(name__startswith="default").delete()
|
NotificationRule.objects.filter(name__startswith="default").delete()
|
||||||
trigger = NotificationRule.objects.create(name="trigger", group=self.group)
|
trigger = NotificationRule.objects.create(name="trigger", group=self.group)
|
||||||
trigger.transports.add(transport)
|
trigger.transports.add(transport)
|
||||||
|
|
|
@ -14,12 +14,7 @@ from rest_framework.fields import BooleanField, FileField, ReadOnlyField
|
||||||
from rest_framework.parsers import MultiPartParser
|
from rest_framework.parsers import MultiPartParser
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.serializers import (
|
from rest_framework.serializers import CharField, ModelSerializer, Serializer, SerializerMethodField
|
||||||
CharField,
|
|
||||||
ModelSerializer,
|
|
||||||
Serializer,
|
|
||||||
SerializerMethodField,
|
|
||||||
)
|
|
||||||
from rest_framework.viewsets import ModelViewSet
|
from rest_framework.viewsets import ModelViewSet
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
|
@ -152,11 +147,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@extend_schema(
|
@extend_schema(
|
||||||
request={
|
request={"multipart/form-data": inline_serializer("SetIcon", fields={"file": FileField()})},
|
||||||
"multipart/form-data": inline_serializer(
|
|
||||||
"SetIcon", fields={"file": FileField()}
|
|
||||||
)
|
|
||||||
},
|
|
||||||
responses={
|
responses={
|
||||||
204: OpenApiResponse(description="Successfully imported flow"),
|
204: OpenApiResponse(description="Successfully imported flow"),
|
||||||
400: OpenApiResponse(description="Bad request"),
|
400: OpenApiResponse(description="Bad request"),
|
||||||
|
@ -221,9 +212,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
|
||||||
.order_by("order")
|
.order_by("order")
|
||||||
):
|
):
|
||||||
for p_index, policy_binding in enumerate(
|
for p_index, policy_binding in enumerate(
|
||||||
get_objects_for_user(
|
get_objects_for_user(request.user, "authentik_policies.view_policybinding")
|
||||||
request.user, "authentik_policies.view_policybinding"
|
|
||||||
)
|
|
||||||
.filter(target=stage_binding)
|
.filter(target=stage_binding)
|
||||||
.exclude(policy__isnull=True)
|
.exclude(policy__isnull=True)
|
||||||
.order_by("order")
|
.order_by("order")
|
||||||
|
@ -256,20 +245,14 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
|
||||||
element: DiagramElement = body[index]
|
element: DiagramElement = body[index]
|
||||||
if element.type == "condition":
|
if element.type == "condition":
|
||||||
# Policy passes, link policy yes to next stage
|
# Policy passes, link policy yes to next stage
|
||||||
footer.append(
|
footer.append(f"{element.identifier}(yes, right)->{body[index + 1].identifier}")
|
||||||
f"{element.identifier}(yes, right)->{body[index + 1].identifier}"
|
|
||||||
)
|
|
||||||
# Policy doesn't pass, go to stage after next stage
|
# Policy doesn't pass, go to stage after next stage
|
||||||
no_element = body[index + 1]
|
no_element = body[index + 1]
|
||||||
if no_element.type != "end":
|
if no_element.type != "end":
|
||||||
no_element = body[index + 2]
|
no_element = body[index + 2]
|
||||||
footer.append(
|
footer.append(f"{element.identifier}(no, bottom)->{no_element.identifier}")
|
||||||
f"{element.identifier}(no, bottom)->{no_element.identifier}"
|
|
||||||
)
|
|
||||||
elif element.type == "operation":
|
elif element.type == "operation":
|
||||||
footer.append(
|
footer.append(f"{element.identifier}(bottom)->{body[index + 1].identifier}")
|
||||||
f"{element.identifier}(bottom)->{body[index + 1].identifier}"
|
|
||||||
)
|
|
||||||
diagram = "\n".join([str(x) for x in header + body + footer])
|
diagram = "\n".join([str(x) for x in header + body + footer])
|
||||||
return Response({"diagram": diagram})
|
return Response({"diagram": diagram})
|
||||||
|
|
||||||
|
|
|
@ -95,9 +95,7 @@ class Command(BaseCommand): # pragma: no cover
|
||||||
"""Output results human readable"""
|
"""Output results human readable"""
|
||||||
total_max: int = max([max(inner) for inner in values])
|
total_max: int = max([max(inner) for inner in values])
|
||||||
total_min: int = min([min(inner) for inner in values])
|
total_min: int = min([min(inner) for inner in values])
|
||||||
total_avg = sum([sum(inner) for inner in values]) / sum(
|
total_avg = sum([sum(inner) for inner in values]) / sum([len(inner) for inner in values])
|
||||||
[len(inner) for inner in values]
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Version: {__version__}")
|
print(f"Version: {__version__}")
|
||||||
print(f"Processes: {len(values)}")
|
print(f"Processes: {len(values)}")
|
||||||
|
|
|
@ -9,21 +9,15 @@ from authentik.stages.identification.models import UserFields
|
||||||
from authentik.stages.password import BACKEND_DJANGO, BACKEND_LDAP
|
from authentik.stages.password import BACKEND_DJANGO, BACKEND_LDAP
|
||||||
|
|
||||||
|
|
||||||
def create_default_authentication_flow(
|
def create_default_authentication_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
apps: Apps, schema_editor: BaseDatabaseSchemaEditor
|
|
||||||
):
|
|
||||||
Flow = apps.get_model("authentik_flows", "Flow")
|
Flow = apps.get_model("authentik_flows", "Flow")
|
||||||
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
|
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
|
||||||
PasswordStage = apps.get_model("authentik_stages_password", "PasswordStage")
|
PasswordStage = apps.get_model("authentik_stages_password", "PasswordStage")
|
||||||
UserLoginStage = apps.get_model("authentik_stages_user_login", "UserLoginStage")
|
UserLoginStage = apps.get_model("authentik_stages_user_login", "UserLoginStage")
|
||||||
IdentificationStage = apps.get_model(
|
IdentificationStage = apps.get_model("authentik_stages_identification", "IdentificationStage")
|
||||||
"authentik_stages_identification", "IdentificationStage"
|
|
||||||
)
|
|
||||||
db_alias = schema_editor.connection.alias
|
db_alias = schema_editor.connection.alias
|
||||||
|
|
||||||
identification_stage, _ = IdentificationStage.objects.using(
|
identification_stage, _ = IdentificationStage.objects.using(db_alias).update_or_create(
|
||||||
db_alias
|
|
||||||
).update_or_create(
|
|
||||||
name="default-authentication-identification",
|
name="default-authentication-identification",
|
||||||
defaults={
|
defaults={
|
||||||
"user_fields": [UserFields.E_MAIL, UserFields.USERNAME],
|
"user_fields": [UserFields.E_MAIL, UserFields.USERNAME],
|
||||||
|
@ -69,17 +63,13 @@ def create_default_authentication_flow(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_default_invalidation_flow(
|
def create_default_invalidation_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
apps: Apps, schema_editor: BaseDatabaseSchemaEditor
|
|
||||||
):
|
|
||||||
Flow = apps.get_model("authentik_flows", "Flow")
|
Flow = apps.get_model("authentik_flows", "Flow")
|
||||||
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
|
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
|
||||||
UserLogoutStage = apps.get_model("authentik_stages_user_logout", "UserLogoutStage")
|
UserLogoutStage = apps.get_model("authentik_stages_user_logout", "UserLogoutStage")
|
||||||
db_alias = schema_editor.connection.alias
|
db_alias = schema_editor.connection.alias
|
||||||
|
|
||||||
UserLogoutStage.objects.using(db_alias).update_or_create(
|
UserLogoutStage.objects.using(db_alias).update_or_create(name="default-invalidation-logout")
|
||||||
name="default-invalidation-logout"
|
|
||||||
)
|
|
||||||
|
|
||||||
flow, _ = Flow.objects.using(db_alias).update_or_create(
|
flow, _ = Flow.objects.using(db_alias).update_or_create(
|
||||||
slug="default-invalidation-flow",
|
slug="default-invalidation-flow",
|
||||||
|
|
|
@ -15,16 +15,12 @@ PROMPT_POLICY_EXPRESSION = """# Check if we've not been given a username by the
|
||||||
return 'username' not in context.get('prompt_data', {})"""
|
return 'username' not in context.get('prompt_data', {})"""
|
||||||
|
|
||||||
|
|
||||||
def create_default_source_enrollment_flow(
|
def create_default_source_enrollment_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
apps: Apps, schema_editor: BaseDatabaseSchemaEditor
|
|
||||||
):
|
|
||||||
Flow = apps.get_model("authentik_flows", "Flow")
|
Flow = apps.get_model("authentik_flows", "Flow")
|
||||||
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
|
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
|
||||||
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
|
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
|
||||||
|
|
||||||
ExpressionPolicy = apps.get_model(
|
ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy")
|
||||||
"authentik_policies_expression", "ExpressionPolicy"
|
|
||||||
)
|
|
||||||
|
|
||||||
PromptStage = apps.get_model("authentik_stages_prompt", "PromptStage")
|
PromptStage = apps.get_model("authentik_stages_prompt", "PromptStage")
|
||||||
Prompt = apps.get_model("authentik_stages_prompt", "Prompt")
|
Prompt = apps.get_model("authentik_stages_prompt", "Prompt")
|
||||||
|
@ -99,16 +95,12 @@ def create_default_source_enrollment_flow(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_default_source_authentication_flow(
|
def create_default_source_authentication_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
apps: Apps, schema_editor: BaseDatabaseSchemaEditor
|
|
||||||
):
|
|
||||||
Flow = apps.get_model("authentik_flows", "Flow")
|
Flow = apps.get_model("authentik_flows", "Flow")
|
||||||
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
|
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
|
||||||
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
|
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
|
||||||
|
|
||||||
ExpressionPolicy = apps.get_model(
|
ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy")
|
||||||
"authentik_policies_expression", "ExpressionPolicy"
|
|
||||||
)
|
|
||||||
|
|
||||||
UserLoginStage = apps.get_model("authentik_stages_user_login", "UserLoginStage")
|
UserLoginStage = apps.get_model("authentik_stages_user_login", "UserLoginStage")
|
||||||
|
|
||||||
|
|
|
@ -7,9 +7,7 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||||
from authentik.flows.models import FlowDesignation
|
from authentik.flows.models import FlowDesignation
|
||||||
|
|
||||||
|
|
||||||
def create_default_provider_authorization_flow(
|
def create_default_provider_authorization_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
apps: Apps, schema_editor: BaseDatabaseSchemaEditor
|
|
||||||
):
|
|
||||||
Flow = apps.get_model("authentik_flows", "Flow")
|
Flow = apps.get_model("authentik_flows", "Flow")
|
||||||
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
|
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
|
||||||
|
|
||||||
|
|
|
@ -32,9 +32,7 @@ def create_default_oobe_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor
|
||||||
PromptStage = apps.get_model("authentik_stages_prompt", "PromptStage")
|
PromptStage = apps.get_model("authentik_stages_prompt", "PromptStage")
|
||||||
Prompt = apps.get_model("authentik_stages_prompt", "Prompt")
|
Prompt = apps.get_model("authentik_stages_prompt", "Prompt")
|
||||||
|
|
||||||
ExpressionPolicy = apps.get_model(
|
ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy")
|
||||||
"authentik_policies_expression", "ExpressionPolicy"
|
|
||||||
)
|
|
||||||
|
|
||||||
db_alias = schema_editor.connection.alias
|
db_alias = schema_editor.connection.alias
|
||||||
|
|
||||||
|
@ -52,9 +50,7 @@ def create_default_oobe_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor
|
||||||
name="default-oobe-prefill-user",
|
name="default-oobe-prefill-user",
|
||||||
defaults={"expression": PREFILL_POLICY_EXPRESSION},
|
defaults={"expression": PREFILL_POLICY_EXPRESSION},
|
||||||
)
|
)
|
||||||
password_usable_policy, _ = ExpressionPolicy.objects.using(
|
password_usable_policy, _ = ExpressionPolicy.objects.using(db_alias).update_or_create(
|
||||||
db_alias
|
|
||||||
).update_or_create(
|
|
||||||
name="default-oobe-password-usable",
|
name="default-oobe-password-usable",
|
||||||
defaults={"expression": PW_USABLE_POLICY_EXPRESSION},
|
defaults={"expression": PW_USABLE_POLICY_EXPRESSION},
|
||||||
)
|
)
|
||||||
|
@ -83,9 +79,7 @@ def create_default_oobe_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor
|
||||||
prompt_stage, _ = PromptStage.objects.using(db_alias).update_or_create(
|
prompt_stage, _ = PromptStage.objects.using(db_alias).update_or_create(
|
||||||
name="default-oobe-password",
|
name="default-oobe-password",
|
||||||
)
|
)
|
||||||
prompt_stage.fields.set(
|
prompt_stage.fields.set([prompt_header, prompt_email, password_first, password_second])
|
||||||
[prompt_header, prompt_email, password_first, password_second]
|
|
||||||
)
|
|
||||||
prompt_stage.save()
|
prompt_stage.save()
|
||||||
|
|
||||||
user_write, _ = UserWriteStage.objects.using(db_alias).update_or_create(
|
user_write, _ = UserWriteStage.objects.using(db_alias).update_or_create(
|
||||||
|
|
|
@ -138,9 +138,7 @@ class Flow(SerializerModel, PolicyBindingModel):
|
||||||
it is returned as-is"""
|
it is returned as-is"""
|
||||||
if not self.background:
|
if not self.background:
|
||||||
return "/static/dist/assets/images/flow_background.jpg"
|
return "/static/dist/assets/images/flow_background.jpg"
|
||||||
if self.background.name.startswith("http") or self.background.name.startswith(
|
if self.background.name.startswith("http") or self.background.name.startswith("/static"):
|
||||||
"/static"
|
|
||||||
):
|
|
||||||
return self.background.name
|
return self.background.name
|
||||||
return self.background.url
|
return self.background.url
|
||||||
|
|
||||||
|
@ -165,9 +163,7 @@ class Flow(SerializerModel, PolicyBindingModel):
|
||||||
if result.passing:
|
if result.passing:
|
||||||
LOGGER.debug("with_policy: flow passing", flow=flow)
|
LOGGER.debug("with_policy: flow passing", flow=flow)
|
||||||
return flow
|
return flow
|
||||||
LOGGER.warning(
|
LOGGER.warning("with_policy: flow not passing", flow=flow, messages=result.messages)
|
||||||
"with_policy: flow not passing", flow=flow, messages=result.messages
|
|
||||||
)
|
|
||||||
LOGGER.debug("with_policy: no flow found", filters=flow_filter)
|
LOGGER.debug("with_policy: no flow found", filters=flow_filter)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
@ -78,14 +78,10 @@ class FlowPlan:
|
||||||
marker = self.markers[0]
|
marker = self.markers[0]
|
||||||
|
|
||||||
if marker.__class__ is not StageMarker:
|
if marker.__class__ is not StageMarker:
|
||||||
LOGGER.debug(
|
LOGGER.debug("f(plan_inst): stage has marker", binding=binding, marker=marker)
|
||||||
"f(plan_inst): stage has marker", binding=binding, marker=marker
|
|
||||||
)
|
|
||||||
marked_stage = marker.process(self, binding, http_request)
|
marked_stage = marker.process(self, binding, http_request)
|
||||||
if not marked_stage:
|
if not marked_stage:
|
||||||
LOGGER.debug(
|
LOGGER.debug("f(plan_inst): marker returned none, next stage", binding=binding)
|
||||||
"f(plan_inst): marker returned none, next stage", binding=binding
|
|
||||||
)
|
|
||||||
self.bindings.remove(binding)
|
self.bindings.remove(binding)
|
||||||
self.markers.remove(marker)
|
self.markers.remove(marker)
|
||||||
if not self.has_stages:
|
if not self.has_stages:
|
||||||
|
@ -193,9 +189,9 @@ class FlowPlanner:
|
||||||
if default_context:
|
if default_context:
|
||||||
plan.context = default_context
|
plan.context = default_context
|
||||||
# Check Flow policies
|
# Check Flow policies
|
||||||
for binding in FlowStageBinding.objects.filter(
|
for binding in FlowStageBinding.objects.filter(target__pk=self.flow.pk).order_by(
|
||||||
target__pk=self.flow.pk
|
"order"
|
||||||
).order_by("order"):
|
):
|
||||||
binding: FlowStageBinding
|
binding: FlowStageBinding
|
||||||
stage = binding.stage
|
stage = binding.stage
|
||||||
marker = StageMarker()
|
marker = StageMarker()
|
||||||
|
|
|
@ -26,9 +26,7 @@ def invalidate_flow_cache(sender, instance, **_):
|
||||||
LOGGER.debug("Invalidating Flow cache", flow=instance, len=total)
|
LOGGER.debug("Invalidating Flow cache", flow=instance, len=total)
|
||||||
if isinstance(instance, FlowStageBinding):
|
if isinstance(instance, FlowStageBinding):
|
||||||
total = delete_cache_prefix(f"{cache_key(instance.target)}*")
|
total = delete_cache_prefix(f"{cache_key(instance.target)}*")
|
||||||
LOGGER.debug(
|
LOGGER.debug("Invalidating Flow cache from FlowStageBinding", binding=instance, len=total)
|
||||||
"Invalidating Flow cache from FlowStageBinding", binding=instance, len=total
|
|
||||||
)
|
|
||||||
if isinstance(instance, Stage):
|
if isinstance(instance, Stage):
|
||||||
total = 0
|
total = 0
|
||||||
for binding in FlowStageBinding.objects.filter(stage=instance):
|
for binding in FlowStageBinding.objects.filter(stage=instance):
|
||||||
|
|
|
@ -42,14 +42,9 @@ class StageView(View):
|
||||||
other things besides the form display.
|
other things besides the form display.
|
||||||
|
|
||||||
If no user is pending, returns request.user"""
|
If no user is pending, returns request.user"""
|
||||||
if (
|
if PLAN_CONTEXT_PENDING_USER_IDENTIFIER in self.executor.plan.context and for_display:
|
||||||
PLAN_CONTEXT_PENDING_USER_IDENTIFIER in self.executor.plan.context
|
|
||||||
and for_display
|
|
||||||
):
|
|
||||||
return User(
|
return User(
|
||||||
username=self.executor.plan.context.get(
|
username=self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER_IDENTIFIER),
|
||||||
PLAN_CONTEXT_PENDING_USER_IDENTIFIER
|
|
||||||
),
|
|
||||||
email="",
|
email="",
|
||||||
)
|
)
|
||||||
if PLAN_CONTEXT_PENDING_USER in self.executor.plan.context:
|
if PLAN_CONTEXT_PENDING_USER in self.executor.plan.context:
|
||||||
|
|
|
@ -89,14 +89,10 @@ class TestFlowPlanner(TestCase):
|
||||||
|
|
||||||
planner = FlowPlanner(flow)
|
planner = FlowPlanner(flow)
|
||||||
planner.plan(request)
|
planner.plan(request)
|
||||||
self.assertEqual(
|
self.assertEqual(CACHE_MOCK.set.call_count, 1) # Ensure plan is written to cache
|
||||||
CACHE_MOCK.set.call_count, 1
|
|
||||||
) # Ensure plan is written to cache
|
|
||||||
planner = FlowPlanner(flow)
|
planner = FlowPlanner(flow)
|
||||||
planner.plan(request)
|
planner.plan(request)
|
||||||
self.assertEqual(
|
self.assertEqual(CACHE_MOCK.set.call_count, 1) # Ensure nothing is written to cache
|
||||||
CACHE_MOCK.set.call_count, 1
|
|
||||||
) # Ensure nothing is written to cache
|
|
||||||
self.assertEqual(CACHE_MOCK.get.call_count, 2) # Get is called twice
|
self.assertEqual(CACHE_MOCK.get.call_count, 2) # Get is called twice
|
||||||
|
|
||||||
def test_planner_default_context(self):
|
def test_planner_default_context(self):
|
||||||
|
@ -176,9 +172,7 @@ class TestFlowPlanner(TestCase):
|
||||||
request.session.save()
|
request.session.save()
|
||||||
|
|
||||||
# Here we patch the dummy policy to evaluate to true so the stage is included
|
# Here we patch the dummy policy to evaluate to true so the stage is included
|
||||||
with patch(
|
with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE):
|
||||||
"authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE
|
|
||||||
):
|
|
||||||
planner = FlowPlanner(flow)
|
planner = FlowPlanner(flow)
|
||||||
plan = planner.plan(request)
|
plan = planner.plan(request)
|
||||||
|
|
||||||
|
|
|
@ -76,9 +76,7 @@ class TestFlowTransfer(TransactionTestCase):
|
||||||
PolicyBinding.objects.create(policy=flow_policy, target=flow, order=0)
|
PolicyBinding.objects.create(policy=flow_policy, target=flow, order=0)
|
||||||
|
|
||||||
user_login = UserLoginStage.objects.create(name=stage_name)
|
user_login = UserLoginStage.objects.create(name=stage_name)
|
||||||
fsb = FlowStageBinding.objects.create(
|
fsb = FlowStageBinding.objects.create(target=flow, stage=user_login, order=0)
|
||||||
target=flow, stage=user_login, order=0
|
|
||||||
)
|
|
||||||
PolicyBinding.objects.create(policy=flow_policy, target=fsb, order=0)
|
PolicyBinding.objects.create(policy=flow_policy, target=fsb, order=0)
|
||||||
|
|
||||||
exporter = FlowExporter(flow)
|
exporter = FlowExporter(flow)
|
||||||
|
|
|
@ -11,12 +11,7 @@ from authentik.core.models import User
|
||||||
from authentik.flows.challenge import ChallengeTypes
|
from authentik.flows.challenge import ChallengeTypes
|
||||||
from authentik.flows.exceptions import FlowNonApplicableException
|
from authentik.flows.exceptions import FlowNonApplicableException
|
||||||
from authentik.flows.markers import ReevaluateMarker, StageMarker
|
from authentik.flows.markers import ReevaluateMarker, StageMarker
|
||||||
from authentik.flows.models import (
|
from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding, InvalidResponseAction
|
||||||
Flow,
|
|
||||||
FlowDesignation,
|
|
||||||
FlowStageBinding,
|
|
||||||
InvalidResponseAction,
|
|
||||||
)
|
|
||||||
from authentik.flows.planner import FlowPlan, FlowPlanner
|
from authentik.flows.planner import FlowPlan, FlowPlanner
|
||||||
from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView
|
from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView
|
||||||
from authentik.flows.views import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView
|
from authentik.flows.views import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView
|
||||||
|
@ -61,9 +56,7 @@ class TestFlowExecutor(TestCase):
|
||||||
)
|
)
|
||||||
stage = DummyStage.objects.create(name="dummy")
|
stage = DummyStage.objects.create(name="dummy")
|
||||||
binding = FlowStageBinding(target=flow, stage=stage, order=0)
|
binding = FlowStageBinding(target=flow, stage=stage, order=0)
|
||||||
plan = FlowPlan(
|
plan = FlowPlan(flow_pk=flow.pk.hex + "a", bindings=[binding], markers=[StageMarker()])
|
||||||
flow_pk=flow.pk.hex + "a", bindings=[binding], markers=[StageMarker()]
|
|
||||||
)
|
|
||||||
session = self.client.session
|
session = self.client.session
|
||||||
session[SESSION_KEY_PLAN] = plan
|
session[SESSION_KEY_PLAN] = plan
|
||||||
session.save()
|
session.save()
|
||||||
|
@ -163,9 +156,7 @@ class TestFlowExecutor(TestCase):
|
||||||
target=flow, stage=DummyStage.objects.create(name="dummy2"), order=1
|
target=flow, stage=DummyStage.objects.create(name="dummy2"), order=1
|
||||||
)
|
)
|
||||||
|
|
||||||
exec_url = reverse(
|
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
|
||||||
"authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}
|
|
||||||
)
|
|
||||||
# First Request, start planning, renders form
|
# First Request, start planning, renders form
|
||||||
response = self.client.get(exec_url)
|
response = self.client.get(exec_url)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
@ -209,13 +200,9 @@ class TestFlowExecutor(TestCase):
|
||||||
PolicyBinding.objects.create(policy=false_policy, target=binding2, order=0)
|
PolicyBinding.objects.create(policy=false_policy, target=binding2, order=0)
|
||||||
|
|
||||||
# Here we patch the dummy policy to evaluate to true so the stage is included
|
# Here we patch the dummy policy to evaluate to true so the stage is included
|
||||||
with patch(
|
with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE):
|
||||||
"authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE
|
|
||||||
):
|
|
||||||
|
|
||||||
exec_url = reverse(
|
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
|
||||||
"authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}
|
|
||||||
)
|
|
||||||
# First request, run the planner
|
# First request, run the planner
|
||||||
response = self.client.get(exec_url)
|
response = self.client.get(exec_url)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
@ -263,13 +250,9 @@ class TestFlowExecutor(TestCase):
|
||||||
PolicyBinding.objects.create(policy=false_policy, target=binding2, order=0)
|
PolicyBinding.objects.create(policy=false_policy, target=binding2, order=0)
|
||||||
|
|
||||||
# Here we patch the dummy policy to evaluate to true so the stage is included
|
# Here we patch the dummy policy to evaluate to true so the stage is included
|
||||||
with patch(
|
with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE):
|
||||||
"authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE
|
|
||||||
):
|
|
||||||
|
|
||||||
exec_url = reverse(
|
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
|
||||||
"authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}
|
|
||||||
)
|
|
||||||
# First request, run the planner
|
# First request, run the planner
|
||||||
response = self.client.get(exec_url)
|
response = self.client.get(exec_url)
|
||||||
|
|
||||||
|
@ -334,13 +317,9 @@ class TestFlowExecutor(TestCase):
|
||||||
PolicyBinding.objects.create(policy=true_policy, target=binding2, order=0)
|
PolicyBinding.objects.create(policy=true_policy, target=binding2, order=0)
|
||||||
|
|
||||||
# Here we patch the dummy policy to evaluate to true so the stage is included
|
# Here we patch the dummy policy to evaluate to true so the stage is included
|
||||||
with patch(
|
with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE):
|
||||||
"authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE
|
|
||||||
):
|
|
||||||
|
|
||||||
exec_url = reverse(
|
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
|
||||||
"authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}
|
|
||||||
)
|
|
||||||
# First request, run the planner
|
# First request, run the planner
|
||||||
response = self.client.get(exec_url)
|
response = self.client.get(exec_url)
|
||||||
|
|
||||||
|
@ -422,13 +401,9 @@ class TestFlowExecutor(TestCase):
|
||||||
PolicyBinding.objects.create(policy=false_policy, target=binding3, order=0)
|
PolicyBinding.objects.create(policy=false_policy, target=binding3, order=0)
|
||||||
|
|
||||||
# Here we patch the dummy policy to evaluate to true so the stage is included
|
# Here we patch the dummy policy to evaluate to true so the stage is included
|
||||||
with patch(
|
with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE):
|
||||||
"authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE
|
|
||||||
):
|
|
||||||
|
|
||||||
exec_url = reverse(
|
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
|
||||||
"authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}
|
|
||||||
)
|
|
||||||
# First request, run the planner
|
# First request, run the planner
|
||||||
response = self.client.get(exec_url)
|
response = self.client.get(exec_url)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
@ -511,9 +486,7 @@ class TestFlowExecutor(TestCase):
|
||||||
)
|
)
|
||||||
request.user = user
|
request.user = user
|
||||||
planner = FlowPlanner(flow)
|
planner = FlowPlanner(flow)
|
||||||
plan = planner.plan(
|
plan = planner.plan(request, default_context={PLAN_CONTEXT_PENDING_USER_IDENTIFIER: ident})
|
||||||
request, default_context={PLAN_CONTEXT_PENDING_USER_IDENTIFIER: ident}
|
|
||||||
)
|
|
||||||
|
|
||||||
executor = FlowExecutorView()
|
executor = FlowExecutorView()
|
||||||
executor.plan = plan
|
executor.plan = plan
|
||||||
|
@ -542,9 +515,7 @@ class TestFlowExecutor(TestCase):
|
||||||
evaluate_on_plan=False,
|
evaluate_on_plan=False,
|
||||||
re_evaluate_policies=True,
|
re_evaluate_policies=True,
|
||||||
)
|
)
|
||||||
PolicyBinding.objects.create(
|
PolicyBinding.objects.create(policy=reputation_policy, target=deny_binding, order=0)
|
||||||
policy=reputation_policy, target=deny_binding, order=0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stage 1 is an identification stage
|
# Stage 1 is an identification stage
|
||||||
ident_stage = IdentificationStage.objects.create(
|
ident_stage = IdentificationStage.objects.create(
|
||||||
|
@ -557,9 +528,7 @@ class TestFlowExecutor(TestCase):
|
||||||
order=1,
|
order=1,
|
||||||
invalid_response_action=InvalidResponseAction.RESTART_WITH_CONTEXT,
|
invalid_response_action=InvalidResponseAction.RESTART_WITH_CONTEXT,
|
||||||
)
|
)
|
||||||
exec_url = reverse(
|
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
|
||||||
"authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}
|
|
||||||
)
|
|
||||||
# First request, run the planner
|
# First request, run the planner
|
||||||
response = self.client.get(exec_url)
|
response = self.client.get(exec_url)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
@ -579,9 +548,7 @@ class TestFlowExecutor(TestCase):
|
||||||
"user_fields": [UserFields.E_MAIL],
|
"user_fields": [UserFields.E_MAIL],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
response = self.client.post(
|
response = self.client.post(exec_url, {"uid_field": "invalid-string"}, follow=True)
|
||||||
exec_url, {"uid_field": "invalid-string"}, follow=True
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(
|
||||||
force_str(response.content),
|
force_str(response.content),
|
||||||
|
|
|
@ -21,9 +21,7 @@ class TestHelperView(TestCase):
|
||||||
response = self.client.get(
|
response = self.client.get(
|
||||||
reverse("authentik_flows:default-invalidation"),
|
reverse("authentik_flows:default-invalidation"),
|
||||||
)
|
)
|
||||||
expected_url = reverse(
|
expected_url = reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
|
||||||
"authentik_core:if-flow", kwargs={"flow_slug": flow.slug}
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 302)
|
self.assertEqual(response.status_code, 302)
|
||||||
self.assertEqual(response.url, expected_url)
|
self.assertEqual(response.url, expected_url)
|
||||||
|
|
||||||
|
@ -40,8 +38,6 @@ class TestHelperView(TestCase):
|
||||||
response = self.client.get(
|
response = self.client.get(
|
||||||
reverse("authentik_flows:default-invalidation"),
|
reverse("authentik_flows:default-invalidation"),
|
||||||
)
|
)
|
||||||
expected_url = reverse(
|
expected_url = reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
|
||||||
"authentik_core:if-flow", kwargs={"flow_slug": flow.slug}
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 302)
|
self.assertEqual(response.status_code, 302)
|
||||||
self.assertEqual(response.url, expected_url)
|
self.assertEqual(response.url, expected_url)
|
||||||
|
|
|
@ -44,9 +44,7 @@ class FlowBundleEntry:
|
||||||
attrs: dict[str, Any]
|
attrs: dict[str, Any]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model(
|
def from_model(model: SerializerModel, *extra_identifier_names: str) -> "FlowBundleEntry":
|
||||||
model: SerializerModel, *extra_identifier_names: str
|
|
||||||
) -> "FlowBundleEntry":
|
|
||||||
"""Convert a SerializerModel instance to a Bundle Entry"""
|
"""Convert a SerializerModel instance to a Bundle Entry"""
|
||||||
identifiers = {
|
identifiers = {
|
||||||
"pk": model.pk,
|
"pk": model.pk,
|
||||||
|
|
|
@ -6,11 +6,7 @@ from uuid import UUID
|
||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
|
|
||||||
from authentik.flows.models import Flow, FlowStageBinding, Stage
|
from authentik.flows.models import Flow, FlowStageBinding, Stage
|
||||||
from authentik.flows.transfer.common import (
|
from authentik.flows.transfer.common import DataclassEncoder, FlowBundle, FlowBundleEntry
|
||||||
DataclassEncoder,
|
|
||||||
FlowBundle,
|
|
||||||
FlowBundleEntry,
|
|
||||||
)
|
|
||||||
from authentik.policies.models import Policy, PolicyBinding
|
from authentik.policies.models import Policy, PolicyBinding
|
||||||
from authentik.stages.prompt.models import PromptStage
|
from authentik.stages.prompt.models import PromptStage
|
||||||
|
|
||||||
|
@ -37,9 +33,7 @@ class FlowExporter:
|
||||||
|
|
||||||
def walk_stages(self) -> Iterator[FlowBundleEntry]:
|
def walk_stages(self) -> Iterator[FlowBundleEntry]:
|
||||||
"""Convert all stages attached to self.flow into FlowBundleEntry objects"""
|
"""Convert all stages attached to self.flow into FlowBundleEntry objects"""
|
||||||
stages = (
|
stages = Stage.objects.filter(flow=self.flow).select_related().select_subclasses()
|
||||||
Stage.objects.filter(flow=self.flow).select_related().select_subclasses()
|
|
||||||
)
|
|
||||||
for stage in stages:
|
for stage in stages:
|
||||||
if isinstance(stage, PromptStage):
|
if isinstance(stage, PromptStage):
|
||||||
pass
|
pass
|
||||||
|
@ -56,9 +50,7 @@ class FlowExporter:
|
||||||
a direct foreign key to a policy."""
|
a direct foreign key to a policy."""
|
||||||
# Special case for PromptStage as that has a direct M2M to policy, we have to ensure
|
# Special case for PromptStage as that has a direct M2M to policy, we have to ensure
|
||||||
# all policies referenced in there we also include here
|
# all policies referenced in there we also include here
|
||||||
prompt_stages = PromptStage.objects.filter(flow=self.flow).values_list(
|
prompt_stages = PromptStage.objects.filter(flow=self.flow).values_list("pk", flat=True)
|
||||||
"pk", flat=True
|
|
||||||
)
|
|
||||||
query = Q(bindings__in=self.pbm_uuids) | Q(promptstage__in=prompt_stages)
|
query = Q(bindings__in=self.pbm_uuids) | Q(promptstage__in=prompt_stages)
|
||||||
policies = Policy.objects.filter(query).select_related()
|
policies = Policy.objects.filter(query).select_related()
|
||||||
for policy in policies:
|
for policy in policies:
|
||||||
|
@ -67,9 +59,7 @@ class FlowExporter:
|
||||||
def walk_policy_bindings(self) -> Iterator[FlowBundleEntry]:
|
def walk_policy_bindings(self) -> Iterator[FlowBundleEntry]:
|
||||||
"""Walk over all policybindings relative to us. This is run at the end of the export, as
|
"""Walk over all policybindings relative to us. This is run at the end of the export, as
|
||||||
we are sure all objects exist now."""
|
we are sure all objects exist now."""
|
||||||
bindings = PolicyBinding.objects.filter(
|
bindings = PolicyBinding.objects.filter(target__in=self.pbm_uuids).select_related()
|
||||||
target__in=self.pbm_uuids
|
|
||||||
).select_related()
|
|
||||||
for binding in bindings:
|
for binding in bindings:
|
||||||
yield FlowBundleEntry.from_model(binding, "policy", "target", "order")
|
yield FlowBundleEntry.from_model(binding, "policy", "target", "order")
|
||||||
|
|
||||||
|
|
|
@ -16,11 +16,7 @@ from rest_framework.serializers import BaseSerializer, Serializer
|
||||||
from structlog.stdlib import BoundLogger, get_logger
|
from structlog.stdlib import BoundLogger, get_logger
|
||||||
|
|
||||||
from authentik.flows.models import Flow, FlowStageBinding, Stage
|
from authentik.flows.models import Flow, FlowStageBinding, Stage
|
||||||
from authentik.flows.transfer.common import (
|
from authentik.flows.transfer.common import EntryInvalidError, FlowBundle, FlowBundleEntry
|
||||||
EntryInvalidError,
|
|
||||||
FlowBundle,
|
|
||||||
FlowBundleEntry,
|
|
||||||
)
|
|
||||||
from authentik.lib.models import SerializerModel
|
from authentik.lib.models import SerializerModel
|
||||||
from authentik.policies.models import Policy, PolicyBinding
|
from authentik.policies.models import Policy, PolicyBinding
|
||||||
from authentik.stages.prompt.models import Prompt
|
from authentik.stages.prompt.models import Prompt
|
||||||
|
@ -105,9 +101,7 @@ class FlowImporter:
|
||||||
if isinstance(value, dict) and "pk" in value:
|
if isinstance(value, dict) and "pk" in value:
|
||||||
del updated_identifiers[key]
|
del updated_identifiers[key]
|
||||||
updated_identifiers[f"{key}"] = value["pk"]
|
updated_identifiers[f"{key}"] = value["pk"]
|
||||||
existing_models = model.objects.filter(
|
existing_models = model.objects.filter(self.__query_from_identifier(updated_identifiers))
|
||||||
self.__query_from_identifier(updated_identifiers)
|
|
||||||
)
|
|
||||||
|
|
||||||
serializer_kwargs = {}
|
serializer_kwargs = {}
|
||||||
if existing_models.exists():
|
if existing_models.exists():
|
||||||
|
@ -120,9 +114,7 @@ class FlowImporter:
|
||||||
)
|
)
|
||||||
serializer_kwargs["instance"] = model_instance
|
serializer_kwargs["instance"] = model_instance
|
||||||
else:
|
else:
|
||||||
self.logger.debug(
|
self.logger.debug("initialise new instance", model=model, **updated_identifiers)
|
||||||
"initialise new instance", model=model, **updated_identifiers
|
|
||||||
)
|
|
||||||
full_data = self.__update_pks_for_attrs(entry.attrs)
|
full_data = self.__update_pks_for_attrs(entry.attrs)
|
||||||
full_data.update(updated_identifiers)
|
full_data.update(updated_identifiers)
|
||||||
serializer_kwargs["data"] = full_data
|
serializer_kwargs["data"] = full_data
|
||||||
|
|
|
@ -38,13 +38,7 @@ from authentik.flows.challenge import (
|
||||||
WithUserInfoChallenge,
|
WithUserInfoChallenge,
|
||||||
)
|
)
|
||||||
from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException
|
from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException
|
||||||
from authentik.flows.models import (
|
from authentik.flows.models import ConfigurableStage, Flow, FlowDesignation, FlowStageBinding, Stage
|
||||||
ConfigurableStage,
|
|
||||||
Flow,
|
|
||||||
FlowDesignation,
|
|
||||||
FlowStageBinding,
|
|
||||||
Stage,
|
|
||||||
)
|
|
||||||
from authentik.flows.planner import (
|
from authentik.flows.planner import (
|
||||||
PLAN_CONTEXT_PENDING_USER,
|
PLAN_CONTEXT_PENDING_USER,
|
||||||
PLAN_CONTEXT_REDIRECT,
|
PLAN_CONTEXT_REDIRECT,
|
||||||
|
@ -155,9 +149,7 @@ class FlowExecutorView(APIView):
|
||||||
try:
|
try:
|
||||||
self.plan = self._initiate_plan()
|
self.plan = self._initiate_plan()
|
||||||
except FlowNonApplicableException as exc:
|
except FlowNonApplicableException as exc:
|
||||||
self._logger.warning(
|
self._logger.warning("f(exec): Flow not applicable to current user", exc=exc)
|
||||||
"f(exec): Flow not applicable to current user", exc=exc
|
|
||||||
)
|
|
||||||
return to_stage_response(self.request, self.handle_invalid_flow(exc))
|
return to_stage_response(self.request, self.handle_invalid_flow(exc))
|
||||||
except EmptyFlowException as exc:
|
except EmptyFlowException as exc:
|
||||||
self._logger.warning("f(exec): Flow is empty", exc=exc)
|
self._logger.warning("f(exec): Flow is empty", exc=exc)
|
||||||
|
@ -174,9 +166,7 @@ class FlowExecutorView(APIView):
|
||||||
# in which case we just delete the plan and invalidate everything
|
# in which case we just delete the plan and invalidate everything
|
||||||
next_binding = self.plan.next(self.request)
|
next_binding = self.plan.next(self.request)
|
||||||
except Exception as exc: # pylint: disable=broad-except
|
except Exception as exc: # pylint: disable=broad-except
|
||||||
self._logger.warning(
|
self._logger.warning("f(exec): found incompatible flow plan, invalidating run", exc=exc)
|
||||||
"f(exec): found incompatible flow plan, invalidating run", exc=exc
|
|
||||||
)
|
|
||||||
keys = cache.keys("flow_*")
|
keys = cache.keys("flow_*")
|
||||||
cache.delete_many(keys)
|
cache.delete_many(keys)
|
||||||
return self.stage_invalid()
|
return self.stage_invalid()
|
||||||
|
@ -314,9 +304,7 @@ class FlowExecutorView(APIView):
|
||||||
self.request.session[SESSION_KEY_PLAN] = plan
|
self.request.session[SESSION_KEY_PLAN] = plan
|
||||||
kwargs = self.kwargs
|
kwargs = self.kwargs
|
||||||
kwargs.update({"flow_slug": self.flow.slug})
|
kwargs.update({"flow_slug": self.flow.slug})
|
||||||
return redirect_with_qs(
|
return redirect_with_qs("authentik_api:flow-executor", self.request.GET, **kwargs)
|
||||||
"authentik_api:flow-executor", self.request.GET, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def _flow_done(self) -> HttpResponse:
|
def _flow_done(self) -> HttpResponse:
|
||||||
"""User Successfully passed all stages"""
|
"""User Successfully passed all stages"""
|
||||||
|
@ -350,9 +338,7 @@ class FlowExecutorView(APIView):
|
||||||
)
|
)
|
||||||
kwargs = self.kwargs
|
kwargs = self.kwargs
|
||||||
kwargs.update({"flow_slug": self.flow.slug})
|
kwargs.update({"flow_slug": self.flow.slug})
|
||||||
return redirect_with_qs(
|
return redirect_with_qs("authentik_api:flow-executor", self.request.GET, **kwargs)
|
||||||
"authentik_api:flow-executor", self.request.GET, **kwargs
|
|
||||||
)
|
|
||||||
# User passed all stages
|
# User passed all stages
|
||||||
self._logger.debug(
|
self._logger.debug(
|
||||||
"f(exec): User passed all stages",
|
"f(exec): User passed all stages",
|
||||||
|
@ -408,18 +394,13 @@ class FlowErrorResponse(TemplateResponse):
|
||||||
super().__init__(request=request, template="flows/error.html")
|
super().__init__(request=request, template="flows/error.html")
|
||||||
self.error = error
|
self.error = error
|
||||||
|
|
||||||
def resolve_context(
|
def resolve_context(self, context: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
|
||||||
self, context: Optional[dict[str, Any]]
|
|
||||||
) -> Optional[dict[str, Any]]:
|
|
||||||
if not context:
|
if not context:
|
||||||
context = {}
|
context = {}
|
||||||
context["error"] = self.error
|
context["error"] = self.error
|
||||||
if self._request.user and self._request.user.is_authenticated:
|
if self._request.user and self._request.user.is_authenticated:
|
||||||
if (
|
if self._request.user.is_superuser or self._request.user.group_attributes().get(
|
||||||
self._request.user.is_superuser
|
USER_ATTRIBUTE_DEBUG, False
|
||||||
or self._request.user.group_attributes().get(
|
|
||||||
USER_ATTRIBUTE_DEBUG, False
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
context["tb"] = "".join(format_tb(self.error.__traceback__))
|
context["tb"] = "".join(format_tb(self.error.__traceback__))
|
||||||
return context
|
return context
|
||||||
|
@ -464,9 +445,7 @@ class ToDefaultFlow(View):
|
||||||
flow_slug=flow.slug,
|
flow_slug=flow.slug,
|
||||||
)
|
)
|
||||||
del self.request.session[SESSION_KEY_PLAN]
|
del self.request.session[SESSION_KEY_PLAN]
|
||||||
return redirect_with_qs(
|
return redirect_with_qs("authentik_core:if-flow", request.GET, flow_slug=flow.slug)
|
||||||
"authentik_core:if-flow", request.GET, flow_slug=flow.slug
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def to_stage_response(request: HttpRequest, source: HttpResponse) -> HttpResponse:
|
def to_stage_response(request: HttpRequest, source: HttpResponse) -> HttpResponse:
|
||||||
|
|
|
@ -115,9 +115,7 @@ class ConfigLoader:
|
||||||
for key, value in os.environ.items():
|
for key, value in os.environ.items():
|
||||||
if not key.startswith(ENV_PREFIX):
|
if not key.startswith(ENV_PREFIX):
|
||||||
continue
|
continue
|
||||||
relative_key = (
|
relative_key = key.replace(f"{ENV_PREFIX}_", "", 1).replace("__", ".").lower()
|
||||||
key.replace(f"{ENV_PREFIX}_", "", 1).replace("__", ".").lower()
|
|
||||||
)
|
|
||||||
# Recursively convert path from a.b.c into outer[a][b][c]
|
# Recursively convert path from a.b.c into outer[a][b][c]
|
||||||
current_obj = outer
|
current_obj = outer
|
||||||
dot_parts = relative_key.split(".")
|
dot_parts = relative_key.split(".")
|
||||||
|
|
|
@ -37,15 +37,11 @@ class InheritanceAutoManager(InheritanceManager):
|
||||||
return super().get_queryset().select_subclasses()
|
return super().get_queryset().select_subclasses()
|
||||||
|
|
||||||
|
|
||||||
class InheritanceForwardManyToOneDescriptor(
|
class InheritanceForwardManyToOneDescriptor(models.fields.related.ForwardManyToOneDescriptor):
|
||||||
models.fields.related.ForwardManyToOneDescriptor
|
|
||||||
):
|
|
||||||
"""Forward ManyToOne Descriptor that selects subclass. Requires InheritanceAutoManager."""
|
"""Forward ManyToOne Descriptor that selects subclass. Requires InheritanceAutoManager."""
|
||||||
|
|
||||||
def get_queryset(self, **hints):
|
def get_queryset(self, **hints):
|
||||||
return self.field.remote_field.model.objects.db_manager(
|
return self.field.remote_field.model.objects.db_manager(hints=hints).select_subclasses()
|
||||||
hints=hints
|
|
||||||
).select_subclasses()
|
|
||||||
|
|
||||||
|
|
||||||
class InheritanceForeignKey(models.ForeignKey):
|
class InheritanceForeignKey(models.ForeignKey):
|
||||||
|
|
|
@ -8,11 +8,7 @@ from botocore.exceptions import BotoCoreError
|
||||||
from celery.exceptions import CeleryError
|
from celery.exceptions import CeleryError
|
||||||
from channels.middleware import BaseMiddleware
|
from channels.middleware import BaseMiddleware
|
||||||
from channels_redis.core import ChannelFull
|
from channels_redis.core import ChannelFull
|
||||||
from django.core.exceptions import (
|
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError
|
||||||
ImproperlyConfigured,
|
|
||||||
SuspiciousOperation,
|
|
||||||
ValidationError,
|
|
||||||
)
|
|
||||||
from django.db import InternalError, OperationalError, ProgrammingError
|
from django.db import InternalError, OperationalError, ProgrammingError
|
||||||
from django.http.response import Http404
|
from django.http.response import Http404
|
||||||
from django_redis.exceptions import ConnectionInterrupted
|
from django_redis.exceptions import ConnectionInterrupted
|
||||||
|
|
|
@ -26,7 +26,5 @@ class TestEvaluator(TestCase):
|
||||||
def test_is_group_member(self):
|
def test_is_group_member(self):
|
||||||
"""Test expr_is_group_member"""
|
"""Test expr_is_group_member"""
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
BaseEvaluator.expr_is_group_member(
|
BaseEvaluator.expr_is_group_member(User.objects.get(username="akadmin"), name="test")
|
||||||
User.objects.get(username="akadmin"), name="test"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,17 +1,8 @@
|
||||||
"""Test HTTP Helpers"""
|
"""Test HTTP Helpers"""
|
||||||
from django.test import RequestFactory, TestCase
|
from django.test import RequestFactory, TestCase
|
||||||
|
|
||||||
from authentik.core.models import (
|
from authentik.core.models import USER_ATTRIBUTE_CAN_OVERRIDE_IP, Token, TokenIntents, User
|
||||||
USER_ATTRIBUTE_CAN_OVERRIDE_IP,
|
from authentik.lib.utils.http import OUTPOST_REMOTE_IP_HEADER, OUTPOST_TOKEN_HEADER, get_client_ip
|
||||||
Token,
|
|
||||||
TokenIntents,
|
|
||||||
User,
|
|
||||||
)
|
|
||||||
from authentik.lib.utils.http import (
|
|
||||||
OUTPOST_REMOTE_IP_HEADER,
|
|
||||||
OUTPOST_TOKEN_HEADER,
|
|
||||||
get_client_ip,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestHTTP(TestCase):
|
class TestHTTP(TestCase):
|
||||||
|
|
|
@ -9,9 +9,7 @@ class TestSentry(TestCase):
|
||||||
|
|
||||||
def test_error_not_sent(self):
|
def test_error_not_sent(self):
|
||||||
"""Test SentryIgnoredError not sent"""
|
"""Test SentryIgnoredError not sent"""
|
||||||
self.assertIsNone(
|
self.assertIsNone(before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)}))
|
||||||
before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)})
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_error_sent(self):
|
def test_error_sent(self):
|
||||||
"""Test error sent"""
|
"""Test error sent"""
|
||||||
|
|
|
@ -29,16 +29,9 @@ def _get_outpost_override_ip(request: HttpRequest) -> Optional[str]:
|
||||||
"""Get the actual remote IP when set by an outpost. Only
|
"""Get the actual remote IP when set by an outpost. Only
|
||||||
allowed when the request is authenticated, by a user with USER_ATTRIBUTE_CAN_OVERRIDE_IP set
|
allowed when the request is authenticated, by a user with USER_ATTRIBUTE_CAN_OVERRIDE_IP set
|
||||||
to outpost"""
|
to outpost"""
|
||||||
from authentik.core.models import (
|
from authentik.core.models import USER_ATTRIBUTE_CAN_OVERRIDE_IP, Token, TokenIntents
|
||||||
USER_ATTRIBUTE_CAN_OVERRIDE_IP,
|
|
||||||
Token,
|
|
||||||
TokenIntents,
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
if OUTPOST_REMOTE_IP_HEADER not in request.META or OUTPOST_TOKEN_HEADER not in request.META:
|
||||||
OUTPOST_REMOTE_IP_HEADER not in request.META
|
|
||||||
or OUTPOST_TOKEN_HEADER not in request.META
|
|
||||||
):
|
|
||||||
return None
|
return None
|
||||||
fake_ip = request.META[OUTPOST_REMOTE_IP_HEADER]
|
fake_ip = request.META[OUTPOST_REMOTE_IP_HEADER]
|
||||||
tokens = Token.filter_not_expired(
|
tokens = Token.filter_not_expired(
|
||||||
|
|
|
@ -12,9 +12,7 @@ def managed_reconcile(self: MonitoredTask):
|
||||||
try:
|
try:
|
||||||
ObjectManager().run()
|
ObjectManager().run()
|
||||||
self.set_status(
|
self.set_status(
|
||||||
TaskResult(
|
TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated managed models."])
|
||||||
TaskResultStatus.SUCCESSFUL, ["Successfully updated managed models."]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
except DatabaseError as exc:
|
except DatabaseError as exc:
|
||||||
self.set_status(TaskResult(TaskResultStatus.WARNING, [str(exc)]))
|
self.set_status(TaskResult(TaskResultStatus.WARNING, [str(exc)]))
|
||||||
|
|
|
@ -15,12 +15,7 @@ from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import PassiveSerializer, is_dict
|
from authentik.core.api.utils import PassiveSerializer, is_dict
|
||||||
from authentik.core.models import Provider
|
from authentik.core.models import Provider
|
||||||
from authentik.outposts.api.service_connections import ServiceConnectionSerializer
|
from authentik.outposts.api.service_connections import ServiceConnectionSerializer
|
||||||
from authentik.outposts.models import (
|
from authentik.outposts.models import Outpost, OutpostConfig, OutpostType, default_outpost_config
|
||||||
Outpost,
|
|
||||||
OutpostConfig,
|
|
||||||
OutpostType,
|
|
||||||
default_outpost_config,
|
|
||||||
)
|
|
||||||
from authentik.providers.ldap.models import LDAPProvider
|
from authentik.providers.ldap.models import LDAPProvider
|
||||||
from authentik.providers.proxy.models import ProxyProvider
|
from authentik.providers.proxy.models import ProxyProvider
|
||||||
|
|
||||||
|
|
|
@ -15,11 +15,7 @@ from rest_framework.serializers import ModelSerializer
|
||||||
from rest_framework.viewsets import GenericViewSet, ModelViewSet
|
from rest_framework.viewsets import GenericViewSet, ModelViewSet
|
||||||
|
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import (
|
from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer
|
||||||
MetaNameSerializer,
|
|
||||||
PassiveSerializer,
|
|
||||||
TypeCreateSerializer,
|
|
||||||
)
|
|
||||||
from authentik.lib.utils.reflection import all_subclasses
|
from authentik.lib.utils.reflection import all_subclasses
|
||||||
from authentik.outposts.models import (
|
from authentik.outposts.models import (
|
||||||
DockerServiceConnection,
|
DockerServiceConnection,
|
||||||
|
@ -129,9 +125,7 @@ class KubernetesServiceConnectionSerializer(ServiceConnectionSerializer):
|
||||||
if kubeconfig == {}:
|
if kubeconfig == {}:
|
||||||
if not self.initial_data["local"]:
|
if not self.initial_data["local"]:
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
_(
|
_("You can only use an empty kubeconfig when connecting to a local cluster.")
|
||||||
"You can only use an empty kubeconfig when connecting to a local cluster."
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
# Empty kubeconfig is valid
|
# Empty kubeconfig is valid
|
||||||
return kubeconfig
|
return kubeconfig
|
||||||
|
|
|
@ -59,9 +59,7 @@ class OutpostConsumer(AuthJsonConsumer):
|
||||||
def connect(self):
|
def connect(self):
|
||||||
super().connect()
|
super().connect()
|
||||||
uuid = self.scope["url_route"]["kwargs"]["pk"]
|
uuid = self.scope["url_route"]["kwargs"]["pk"]
|
||||||
outpost = get_objects_for_user(
|
outpost = get_objects_for_user(self.user, "authentik_outposts.view_outpost").filter(pk=uuid)
|
||||||
self.user, "authentik_outposts.view_outpost"
|
|
||||||
).filter(pk=uuid)
|
|
||||||
if not outpost.exists():
|
if not outpost.exists():
|
||||||
raise DenyConnection()
|
raise DenyConnection()
|
||||||
self.accept()
|
self.accept()
|
||||||
|
@ -129,7 +127,5 @@ class OutpostConsumer(AuthJsonConsumer):
|
||||||
def event_update(self, event):
|
def event_update(self, event):
|
||||||
"""Event handler which is called by post_save signals, Send update instruction"""
|
"""Event handler which is called by post_save signals, Send update instruction"""
|
||||||
self.send_json(
|
self.send_json(
|
||||||
asdict(
|
asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE))
|
||||||
WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,11 +9,7 @@ from yaml import safe_dump
|
||||||
|
|
||||||
from authentik import __version__
|
from authentik import __version__
|
||||||
from authentik.outposts.controllers.base import BaseController, ControllerException
|
from authentik.outposts.controllers.base import BaseController, ControllerException
|
||||||
from authentik.outposts.models import (
|
from authentik.outposts.models import DockerServiceConnection, Outpost, ServiceConnectionInvalid
|
||||||
DockerServiceConnection,
|
|
||||||
Outpost,
|
|
||||||
ServiceConnectionInvalid,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DockerController(BaseController):
|
class DockerController(BaseController):
|
||||||
|
@ -37,9 +33,7 @@ class DockerController(BaseController):
|
||||||
def _get_env(self) -> dict[str, str]:
|
def _get_env(self) -> dict[str, str]:
|
||||||
return {
|
return {
|
||||||
"AUTHENTIK_HOST": self.outpost.config.authentik_host.lower(),
|
"AUTHENTIK_HOST": self.outpost.config.authentik_host.lower(),
|
||||||
"AUTHENTIK_INSECURE": str(
|
"AUTHENTIK_INSECURE": str(self.outpost.config.authentik_host_insecure).lower(),
|
||||||
self.outpost.config.authentik_host_insecure
|
|
||||||
).lower(),
|
|
||||||
"AUTHENTIK_TOKEN": self.outpost.token.key,
|
"AUTHENTIK_TOKEN": self.outpost.token.key,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -141,9 +135,7 @@ class DockerController(BaseController):
|
||||||
.lower()
|
.lower()
|
||||||
!= "unless-stopped"
|
!= "unless-stopped"
|
||||||
):
|
):
|
||||||
self.logger.info(
|
self.logger.info("Container has mis-matched restart policy, re-creating...")
|
||||||
"Container has mis-matched restart policy, re-creating..."
|
|
||||||
)
|
|
||||||
self.down()
|
self.down()
|
||||||
return self.up()
|
return self.up()
|
||||||
# Check that container is healthy
|
# Check that container is healthy
|
||||||
|
@ -157,9 +149,7 @@ class DockerController(BaseController):
|
||||||
if has_been_created:
|
if has_been_created:
|
||||||
# Since we've just created the container, give it some time to start.
|
# Since we've just created the container, give it some time to start.
|
||||||
# If its still not up by then, restart it
|
# If its still not up by then, restart it
|
||||||
self.logger.info(
|
self.logger.info("Container is unhealthy and new, giving it time to boot.")
|
||||||
"Container is unhealthy and new, giving it time to boot."
|
|
||||||
)
|
|
||||||
sleep(60)
|
sleep(60)
|
||||||
self.logger.info("Container is unhealthy, restarting...")
|
self.logger.info("Container is unhealthy, restarting...")
|
||||||
container.restart()
|
container.restart()
|
||||||
|
@ -198,9 +188,7 @@ class DockerController(BaseController):
|
||||||
"ports": ports,
|
"ports": ports,
|
||||||
"environment": {
|
"environment": {
|
||||||
"AUTHENTIK_HOST": self.outpost.config.authentik_host,
|
"AUTHENTIK_HOST": self.outpost.config.authentik_host,
|
||||||
"AUTHENTIK_INSECURE": str(
|
"AUTHENTIK_INSECURE": str(self.outpost.config.authentik_host_insecure),
|
||||||
self.outpost.config.authentik_host_insecure
|
|
||||||
),
|
|
||||||
"AUTHENTIK_TOKEN": self.outpost.token.key,
|
"AUTHENTIK_TOKEN": self.outpost.token.key,
|
||||||
},
|
},
|
||||||
"labels": self._get_labels(),
|
"labels": self._get_labels(),
|
||||||
|
|
|
@ -17,10 +17,7 @@ from kubernetes.client import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from authentik.outposts.controllers.base import FIELD_MANAGER
|
from authentik.outposts.controllers.base import FIELD_MANAGER
|
||||||
from authentik.outposts.controllers.k8s.base import (
|
from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate
|
||||||
KubernetesObjectReconciler,
|
|
||||||
NeedsUpdate,
|
|
||||||
)
|
|
||||||
from authentik.outposts.models import Outpost
|
from authentik.outposts.models import Outpost
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -124,9 +121,7 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]):
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete(self, reference: V1Deployment):
|
def delete(self, reference: V1Deployment):
|
||||||
return self.api.delete_namespaced_deployment(
|
return self.api.delete_namespaced_deployment(reference.metadata.name, self.namespace)
|
||||||
reference.metadata.name, self.namespace
|
|
||||||
)
|
|
||||||
|
|
||||||
def retrieve(self) -> V1Deployment:
|
def retrieve(self) -> V1Deployment:
|
||||||
return self.api.read_namespaced_deployment(self.name, self.namespace)
|
return self.api.read_namespaced_deployment(self.name, self.namespace)
|
||||||
|
|
|
@ -5,10 +5,7 @@ from typing import TYPE_CHECKING
|
||||||
from kubernetes.client import CoreV1Api, V1Secret
|
from kubernetes.client import CoreV1Api, V1Secret
|
||||||
|
|
||||||
from authentik.outposts.controllers.base import FIELD_MANAGER
|
from authentik.outposts.controllers.base import FIELD_MANAGER
|
||||||
from authentik.outposts.controllers.k8s.base import (
|
from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate
|
||||||
KubernetesObjectReconciler,
|
|
||||||
NeedsUpdate,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from authentik.outposts.controllers.kubernetes import KubernetesController
|
from authentik.outposts.controllers.kubernetes import KubernetesController
|
||||||
|
@ -38,9 +35,7 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]):
|
||||||
return V1Secret(
|
return V1Secret(
|
||||||
metadata=meta,
|
metadata=meta,
|
||||||
data={
|
data={
|
||||||
"authentik_host": b64string(
|
"authentik_host": b64string(self.controller.outpost.config.authentik_host),
|
||||||
self.controller.outpost.config.authentik_host
|
|
||||||
),
|
|
||||||
"authentik_host_insecure": b64string(
|
"authentik_host_insecure": b64string(
|
||||||
str(self.controller.outpost.config.authentik_host_insecure)
|
str(self.controller.outpost.config.authentik_host_insecure)
|
||||||
),
|
),
|
||||||
|
@ -54,9 +49,7 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]):
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete(self, reference: V1Secret):
|
def delete(self, reference: V1Secret):
|
||||||
return self.api.delete_namespaced_secret(
|
return self.api.delete_namespaced_secret(reference.metadata.name, self.namespace)
|
||||||
reference.metadata.name, self.namespace
|
|
||||||
)
|
|
||||||
|
|
||||||
def retrieve(self) -> V1Secret:
|
def retrieve(self) -> V1Secret:
|
||||||
return self.api.read_namespaced_secret(self.name, self.namespace)
|
return self.api.read_namespaced_secret(self.name, self.namespace)
|
||||||
|
|
|
@ -4,10 +4,7 @@ from typing import TYPE_CHECKING
|
||||||
from kubernetes.client import CoreV1Api, V1Service, V1ServicePort, V1ServiceSpec
|
from kubernetes.client import CoreV1Api, V1Service, V1ServicePort, V1ServiceSpec
|
||||||
|
|
||||||
from authentik.outposts.controllers.base import FIELD_MANAGER
|
from authentik.outposts.controllers.base import FIELD_MANAGER
|
||||||
from authentik.outposts.controllers.k8s.base import (
|
from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate
|
||||||
KubernetesObjectReconciler,
|
|
||||||
NeedsUpdate,
|
|
||||||
)
|
|
||||||
from authentik.outposts.controllers.k8s.deployment import DeploymentReconciler
|
from authentik.outposts.controllers.k8s.deployment import DeploymentReconciler
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -58,9 +55,7 @@ class ServiceReconciler(KubernetesObjectReconciler[V1Service]):
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete(self, reference: V1Service):
|
def delete(self, reference: V1Service):
|
||||||
return self.api.delete_namespaced_service(
|
return self.api.delete_namespaced_service(reference.metadata.name, self.namespace)
|
||||||
reference.metadata.name, self.namespace
|
|
||||||
)
|
|
||||||
|
|
||||||
def retrieve(self) -> V1Service:
|
def retrieve(self) -> V1Service:
|
||||||
return self.api.read_namespaced_service(self.name, self.namespace)
|
return self.api.read_namespaced_service(self.name, self.namespace)
|
||||||
|
|
|
@ -24,9 +24,7 @@ class KubernetesController(BaseController):
|
||||||
client: ApiClient
|
client: ApiClient
|
||||||
connection: KubernetesServiceConnection
|
connection: KubernetesServiceConnection
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, outpost: Outpost, connection: KubernetesServiceConnection) -> None:
|
||||||
self, outpost: Outpost, connection: KubernetesServiceConnection
|
|
||||||
) -> None:
|
|
||||||
super().__init__(outpost, connection)
|
super().__init__(outpost, connection)
|
||||||
self.client = connection.client()
|
self.client = connection.client()
|
||||||
self.reconcilers = {
|
self.reconcilers = {
|
||||||
|
|
|
@ -15,9 +15,7 @@ class Migration(migrations.Migration):
|
||||||
migrations.AddField(
|
migrations.AddField(
|
||||||
model_name="outpost",
|
model_name="outpost",
|
||||||
name="_config",
|
name="_config",
|
||||||
field=models.JSONField(
|
field=models.JSONField(default=authentik.outposts.models.default_outpost_config),
|
||||||
default=authentik.outposts.models.default_outpost_config
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
migrations.AddField(
|
migrations.AddField(
|
||||||
model_name="outpost",
|
model_name="outpost",
|
||||||
|
|
|
@ -10,9 +10,7 @@ def fix_missing_token_identifier(apps: Apps, schema_editor: BaseDatabaseSchemaEd
|
||||||
Token = apps.get_model("authentik_core", "Token")
|
Token = apps.get_model("authentik_core", "Token")
|
||||||
from authentik.outposts.models import Outpost
|
from authentik.outposts.models import Outpost
|
||||||
|
|
||||||
for outpost in (
|
for outpost in Outpost.objects.using(schema_editor.connection.alias).all().only("pk"):
|
||||||
Outpost.objects.using(schema_editor.connection.alias).all().only("pk")
|
|
||||||
):
|
|
||||||
user_identifier = outpost.user_identifier
|
user_identifier = outpost.user_identifier
|
||||||
users = User.objects.filter(username=user_identifier)
|
users = User.objects.filter(username=user_identifier)
|
||||||
if not users.exists():
|
if not users.exists():
|
||||||
|
|
|
@ -14,9 +14,7 @@ import authentik.lib.models
|
||||||
def migrate_to_service_connection(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
def migrate_to_service_connection(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
db_alias = schema_editor.connection.alias
|
db_alias = schema_editor.connection.alias
|
||||||
Outpost = apps.get_model("authentik_outposts", "Outpost")
|
Outpost = apps.get_model("authentik_outposts", "Outpost")
|
||||||
DockerServiceConnection = apps.get_model(
|
DockerServiceConnection = apps.get_model("authentik_outposts", "DockerServiceConnection")
|
||||||
"authentik_outposts", "DockerServiceConnection"
|
|
||||||
)
|
|
||||||
KubernetesServiceConnection = apps.get_model(
|
KubernetesServiceConnection = apps.get_model(
|
||||||
"authentik_outposts", "KubernetesServiceConnection"
|
"authentik_outposts", "KubernetesServiceConnection"
|
||||||
)
|
)
|
||||||
|
@ -25,9 +23,7 @@ def migrate_to_service_connection(apps: Apps, schema_editor: BaseDatabaseSchemaE
|
||||||
k8s = KubernetesServiceConnection.objects.filter(local=True).first()
|
k8s = KubernetesServiceConnection.objects.filter(local=True).first()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for outpost in (
|
for outpost in Outpost.objects.using(db_alias).all().exclude(deployment_type="custom"):
|
||||||
Outpost.objects.using(db_alias).all().exclude(deployment_type="custom")
|
|
||||||
):
|
|
||||||
if outpost.deployment_type == "kubernetes":
|
if outpost.deployment_type == "kubernetes":
|
||||||
outpost.service_connection = k8s
|
outpost.service_connection = k8s
|
||||||
elif outpost.deployment_type == "docker":
|
elif outpost.deployment_type == "docker":
|
||||||
|
|
|
@ -11,9 +11,7 @@ def remove_pb_prefix_users(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||||
Outpost = apps.get_model("authentik_outposts", "Outpost")
|
Outpost = apps.get_model("authentik_outposts", "Outpost")
|
||||||
|
|
||||||
for outpost in Outpost.objects.using(alias).all():
|
for outpost in Outpost.objects.using(alias).all():
|
||||||
matching = User.objects.using(alias).filter(
|
matching = User.objects.using(alias).filter(username=f"pb-outpost-{outpost.uuid.hex}")
|
||||||
username=f"pb-outpost-{outpost.uuid.hex}"
|
|
||||||
)
|
|
||||||
if matching.exists():
|
if matching.exists():
|
||||||
matching.delete()
|
matching.delete()
|
||||||
|
|
||||||
|
|
|
@ -13,8 +13,6 @@ class Migration(migrations.Migration):
|
||||||
migrations.AlterField(
|
migrations.AlterField(
|
||||||
model_name="outpost",
|
model_name="outpost",
|
||||||
name="type",
|
name="type",
|
||||||
field=models.TextField(
|
field=models.TextField(choices=[("proxy", "Proxy"), ("ldap", "Ldap")], default="proxy"),
|
||||||
choices=[("proxy", "Proxy"), ("ldap", "Ldap")], default="proxy"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -64,9 +64,7 @@ class OutpostConfig:
|
||||||
|
|
||||||
log_level: str = CONFIG.y("log_level")
|
log_level: str = CONFIG.y("log_level")
|
||||||
error_reporting_enabled: bool = CONFIG.y_bool("error_reporting.enabled")
|
error_reporting_enabled: bool = CONFIG.y_bool("error_reporting.enabled")
|
||||||
error_reporting_environment: str = CONFIG.y(
|
error_reporting_environment: str = CONFIG.y("error_reporting.environment", "customer")
|
||||||
"error_reporting.environment", "customer"
|
|
||||||
)
|
|
||||||
|
|
||||||
object_naming_template: str = field(default="ak-outpost-%(name)s")
|
object_naming_template: str = field(default="ak-outpost-%(name)s")
|
||||||
kubernetes_replicas: int = field(default=1)
|
kubernetes_replicas: int = field(default=1)
|
||||||
|
@ -264,9 +262,7 @@ class KubernetesServiceConnection(OutpostServiceConnection):
|
||||||
client = self.client()
|
client = self.client()
|
||||||
api_instance = VersionApi(client)
|
api_instance = VersionApi(client)
|
||||||
version: VersionInfo = api_instance.get_code()
|
version: VersionInfo = api_instance.get_code()
|
||||||
return OutpostServiceConnectionState(
|
return OutpostServiceConnectionState(version=version.git_version, healthy=True)
|
||||||
version=version.git_version, healthy=True
|
|
||||||
)
|
|
||||||
except (OpenApiException, HTTPError, ServiceConnectionInvalid):
|
except (OpenApiException, HTTPError, ServiceConnectionInvalid):
|
||||||
return OutpostServiceConnectionState(version="", healthy=False)
|
return OutpostServiceConnectionState(version="", healthy=False)
|
||||||
|
|
||||||
|
@ -360,8 +356,7 @@ class Outpost(ManagedModel):
|
||||||
if isinstance(model_or_perm, models.Model):
|
if isinstance(model_or_perm, models.Model):
|
||||||
model_or_perm: models.Model
|
model_or_perm: models.Model
|
||||||
code_name = (
|
code_name = (
|
||||||
f"{model_or_perm._meta.app_label}."
|
f"{model_or_perm._meta.app_label}." f"view_{model_or_perm._meta.model_name}"
|
||||||
f"view_{model_or_perm._meta.model_name}"
|
|
||||||
)
|
)
|
||||||
assign_perm(code_name, user, model_or_perm)
|
assign_perm(code_name, user, model_or_perm)
|
||||||
else:
|
else:
|
||||||
|
@ -417,9 +412,7 @@ class Outpost(ManagedModel):
|
||||||
self,
|
self,
|
||||||
"authentik_events.add_event",
|
"authentik_events.add_event",
|
||||||
]
|
]
|
||||||
for provider in (
|
for provider in Provider.objects.filter(outpost=self).select_related().select_subclasses():
|
||||||
Provider.objects.filter(outpost=self).select_related().select_subclasses()
|
|
||||||
):
|
|
||||||
if isinstance(provider, OutpostModel):
|
if isinstance(provider, OutpostModel):
|
||||||
objects.extend(provider.get_required_objects())
|
objects.extend(provider.get_required_objects())
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -9,11 +9,7 @@ from authentik.core.models import Provider
|
||||||
from authentik.crypto.models import CertificateKeyPair
|
from authentik.crypto.models import CertificateKeyPair
|
||||||
from authentik.lib.utils.reflection import class_to_path
|
from authentik.lib.utils.reflection import class_to_path
|
||||||
from authentik.outposts.models import Outpost, OutpostServiceConnection
|
from authentik.outposts.models import Outpost, OutpostServiceConnection
|
||||||
from authentik.outposts.tasks import (
|
from authentik.outposts.tasks import CACHE_KEY_OUTPOST_DOWN, outpost_controller, outpost_post_save
|
||||||
CACHE_KEY_OUTPOST_DOWN,
|
|
||||||
outpost_controller,
|
|
||||||
outpost_post_save,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
UPDATE_TRIGGERING_MODELS = (
|
UPDATE_TRIGGERING_MODELS = (
|
||||||
|
@ -37,9 +33,7 @@ def pre_save_outpost(sender, instance: Outpost, **_):
|
||||||
# Name changes the deployment name, need to recreate
|
# Name changes the deployment name, need to recreate
|
||||||
dirty += old_instance.name != instance.name
|
dirty += old_instance.name != instance.name
|
||||||
# namespace requires re-create
|
# namespace requires re-create
|
||||||
dirty += (
|
dirty += old_instance.config.kubernetes_namespace != instance.config.kubernetes_namespace
|
||||||
old_instance.config.kubernetes_namespace != instance.config.kubernetes_namespace
|
|
||||||
)
|
|
||||||
if bool(dirty):
|
if bool(dirty):
|
||||||
LOGGER.info("Outpost needs re-deployment due to changes", instance=instance)
|
LOGGER.info("Outpost needs re-deployment due to changes", instance=instance)
|
||||||
cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance)
|
cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance)
|
||||||
|
|
|
@ -62,9 +62,7 @@ def controller_for_outpost(outpost: Outpost) -> Optional[BaseController]:
|
||||||
def outpost_service_connection_state(connection_pk: Any):
|
def outpost_service_connection_state(connection_pk: Any):
|
||||||
"""Update cached state of a service connection"""
|
"""Update cached state of a service connection"""
|
||||||
connection: OutpostServiceConnection = (
|
connection: OutpostServiceConnection = (
|
||||||
OutpostServiceConnection.objects.filter(pk=connection_pk)
|
OutpostServiceConnection.objects.filter(pk=connection_pk).select_subclasses().first()
|
||||||
.select_subclasses()
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
if not connection:
|
if not connection:
|
||||||
return
|
return
|
||||||
|
@ -157,9 +155,7 @@ def outpost_post_save(model_class: str, model_pk: Any):
|
||||||
outpost_controller.delay(instance.pk)
|
outpost_controller.delay(instance.pk)
|
||||||
|
|
||||||
if isinstance(instance, (OutpostModel, Outpost)):
|
if isinstance(instance, (OutpostModel, Outpost)):
|
||||||
LOGGER.debug(
|
LOGGER.debug("triggering outpost update from outpostmodel/outpost", instance=instance)
|
||||||
"triggering outpost update from outpostmodel/outpost", instance=instance
|
|
||||||
)
|
|
||||||
outpost_send_update(instance)
|
outpost_send_update(instance)
|
||||||
|
|
||||||
if isinstance(instance, OutpostServiceConnection):
|
if isinstance(instance, OutpostServiceConnection):
|
||||||
|
@ -208,9 +204,7 @@ def _outpost_single_update(outpost: Outpost, layer=None):
|
||||||
layer = get_channel_layer()
|
layer = get_channel_layer()
|
||||||
for state in OutpostState.for_outpost(outpost):
|
for state in OutpostState.for_outpost(outpost):
|
||||||
for channel in state.channel_ids:
|
for channel in state.channel_ids:
|
||||||
LOGGER.debug(
|
LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost)
|
||||||
"sending update", channel=channel, instance=state.uid, outpost=outpost
|
|
||||||
)
|
|
||||||
async_to_sync(layer.send)(channel, {"type": "event.update"})
|
async_to_sync(layer.send)(channel, {"type": "event.update"})
|
||||||
|
|
||||||
|
|
||||||
|
@ -231,9 +225,7 @@ def outpost_local_connection():
|
||||||
if Path(kubeconfig_path).exists():
|
if Path(kubeconfig_path).exists():
|
||||||
LOGGER.debug("Detected kubeconfig")
|
LOGGER.debug("Detected kubeconfig")
|
||||||
kubeconfig_local_name = f"k8s-{gethostname()}"
|
kubeconfig_local_name = f"k8s-{gethostname()}"
|
||||||
if not KubernetesServiceConnection.objects.filter(
|
if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists():
|
||||||
name=kubeconfig_local_name
|
|
||||||
).exists():
|
|
||||||
LOGGER.debug("Creating kubeconfig Service Connection")
|
LOGGER.debug("Creating kubeconfig Service Connection")
|
||||||
with open(kubeconfig_path, "r") as _kubeconfig:
|
with open(kubeconfig_path, "r") as _kubeconfig:
|
||||||
KubernetesServiceConnection.objects.create(
|
KubernetesServiceConnection.objects.create(
|
||||||
|
|
|
@ -63,9 +63,7 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
|
||||||
provider = ProxyProvider.objects.create(
|
provider = ProxyProvider.objects.create(
|
||||||
name="test", authorization_flow=Flow.objects.first()
|
name="test", authorization_flow=Flow.objects.first()
|
||||||
)
|
)
|
||||||
invalid = OutpostSerializer(
|
invalid = OutpostSerializer(data={"name": "foo", "providers": [provider.pk], "config": {}})
|
||||||
data={"name": "foo", "providers": [provider.pk], "config": {}}
|
|
||||||
)
|
|
||||||
self.assertFalse(invalid.is_valid())
|
self.assertFalse(invalid.is_valid())
|
||||||
self.assertIn("config", invalid.errors)
|
self.assertIn("config", invalid.errors)
|
||||||
valid = OutpostSerializer(
|
valid = OutpostSerializer(
|
||||||
|
|
|
@ -2,11 +2,7 @@
|
||||||
from typing import OrderedDict
|
from typing import OrderedDict
|
||||||
|
|
||||||
from django.core.exceptions import ObjectDoesNotExist
|
from django.core.exceptions import ObjectDoesNotExist
|
||||||
from rest_framework.serializers import (
|
from rest_framework.serializers import ModelSerializer, PrimaryKeyRelatedField, ValidationError
|
||||||
ModelSerializer,
|
|
||||||
PrimaryKeyRelatedField,
|
|
||||||
ValidationError,
|
|
||||||
)
|
|
||||||
from rest_framework.viewsets import ModelViewSet
|
from rest_framework.viewsets import ModelViewSet
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue