root: reformat to 100 line width

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-08-03 17:45:16 +02:00
parent b87903a209
commit 77ed25ae34
272 changed files with 825 additions and 2590 deletions

View file

@ -23,9 +23,7 @@ def get_events_per_1h(**filter_kwargs) -> list[dict[str, int]]:
date_from = now() - timedelta(days=1)
result = (
Event.objects.filter(created__gte=date_from, **filter_kwargs)
.annotate(
age=ExpressionWrapper(now() - F("created"), output_field=DurationField())
)
.annotate(age=ExpressionWrapper(now() - F("created"), output_field=DurationField()))
.annotate(age_hours=ExtractHour("age"))
.values("age_hours")
.annotate(count=Count("pk"))
@ -37,8 +35,7 @@ def get_events_per_1h(**filter_kwargs) -> list[dict[str, int]]:
for hour in range(0, -24, -1):
results.append(
{
"x_cord": time.mktime((_now + timedelta(hours=hour)).timetuple())
* 1000,
"x_cord": time.mktime((_now + timedelta(hours=hour)).timetuple()) * 1000,
"y_cord": data[hour * -1],
}
)

View file

@ -61,9 +61,7 @@ class SystemSerializer(PassiveSerializer):
return {
"python_version": python_version,
"gunicorn_version": ".".join(str(x) for x in gunicorn_version),
"environment": "kubernetes"
if SERVICE_HOST_ENV_NAME in os.environ
else "compose",
"environment": "kubernetes" if SERVICE_HOST_ENV_NAME in os.environ else "compose",
"architecture": platform.machine(),
"platform": platform.platform(),
"uname": " ".join(platform.uname()),

View file

@ -92,10 +92,7 @@ class TaskViewSet(ViewSet):
task_func.delay(*task.task_call_args, **task.task_call_kwargs)
messages.success(
self.request,
_(
"Successfully re-scheduled Task %(name)s!"
% {"name": task.task_name}
),
_("Successfully re-scheduled Task %(name)s!" % {"name": task.task_name}),
)
return Response(status=204)
except ImportError: # pragma: no cover

View file

@ -41,9 +41,7 @@ class VersionSerializer(PassiveSerializer):
def get_outdated(self, instance) -> bool:
"""Check if we're running the latest version"""
return parse(self.get_version_current(instance)) < parse(
self.get_version_latest(instance)
)
return parse(self.get_version_current(instance)) < parse(self.get_version_latest(instance))
class VersionView(APIView):

View file

@ -17,9 +17,7 @@ class WorkerView(APIView):
permission_classes = [IsAdminUser]
@extend_schema(
responses=inline_serializer("Workers", fields={"count": IntegerField()})
)
@extend_schema(responses=inline_serializer("Workers", fields={"count": IntegerField()}))
def get(self, request: Request) -> Response:
"""Get currently connected worker count."""
count = len(CELERY_APP.control.ping(timeout=0.5))

View file

@ -37,18 +37,14 @@ def _set_prom_info():
def update_latest_version(self: MonitoredTask):
"""Update latest version info"""
try:
response = get(
"https://api.github.com/repos/goauthentik/authentik/releases/latest"
)
response = get("https://api.github.com/repos/goauthentik/authentik/releases/latest")
response.raise_for_status()
data = response.json()
tag_name = data.get("tag_name")
upstream_version = tag_name.split("/")[1]
cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT)
self.set_status(
TaskResult(
TaskResultStatus.SUCCESSFUL, ["Successfully updated latest Version"]
)
TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated latest Version"])
)
_set_prom_info()
# Check if upstream version is newer than what we're running,

View file

@ -27,9 +27,7 @@ class TestAdminAPI(TestCase):
response = self.client.get(reverse("authentik_api:admin_system_tasks-list"))
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertTrue(
any(task["task_name"] == "clean_expired_models" for task in body)
)
self.assertTrue(any(task["task_name"] == "clean_expired_models" for task in body))
def test_tasks_single(self):
"""Test Task API (read single)"""
@ -45,9 +43,7 @@ class TestAdminAPI(TestCase):
self.assertEqual(body["status"], TaskResultStatus.SUCCESSFUL.name)
self.assertEqual(body["task_name"], "clean_expired_models")
response = self.client.get(
reverse(
"authentik_api:admin_system_tasks-detail", kwargs={"pk": "qwerqwer"}
)
reverse("authentik_api:admin_system_tasks-detail", kwargs={"pk": "qwerqwer"})
)
self.assertEqual(response.status_code, 404)

View file

@ -7,9 +7,7 @@ from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
def permission_required(
perm: Optional[str] = None, other_perms: Optional[list[str]] = None
):
def permission_required(perm: Optional[str] = None, other_perms: Optional[list[str]] = None):
"""Check permissions for a single custom action"""
def wrapper_outter(func: Callable):

View file

@ -63,9 +63,7 @@ def postprocess_schema_responses(result, generator, **kwargs): # noqa: W0613
method["responses"].setdefault("400", validation_error.ref)
method["responses"].setdefault("403", generic_error.ref)
result["components"] = generator.registry.build(
spectacular_settings.APPEND_COMPONENTS
)
result["components"] = generator.registry.build(spectacular_settings.APPEND_COMPONENTS)
# This is a workaround for authentik/stages/prompt/stage.py
# since the serializer PromptChallengeResponse

View file

@ -16,17 +16,13 @@ class TestAPIAuth(TestCase):
def test_valid_basic(self):
"""Test valid token"""
token = Token.objects.create(
intent=TokenIntents.INTENT_API, user=get_anonymous_user()
)
token = Token.objects.create(intent=TokenIntents.INTENT_API, user=get_anonymous_user())
auth = b64encode(f":{token.key}".encode()).decode()
self.assertEqual(bearer_auth(f"Basic {auth}".encode()), token.user)
def test_valid_bearer(self):
"""Test valid token"""
token = Token.objects.create(
intent=TokenIntents.INTENT_API, user=get_anonymous_user()
)
token = Token.objects.create(intent=TokenIntents.INTENT_API, user=get_anonymous_user())
self.assertEqual(bearer_auth(f"Bearer {token.key}".encode()), token.user)
def test_invalid_type(self):

View file

@ -52,20 +52,12 @@ from authentik.policies.reputation.api import (
from authentik.providers.ldap.api import LDAPOutpostConfigViewSet, LDAPProviderViewSet
from authentik.providers.oauth2.api.provider import OAuth2ProviderViewSet
from authentik.providers.oauth2.api.scope import ScopeMappingViewSet
from authentik.providers.oauth2.api.tokens import (
AuthorizationCodeViewSet,
RefreshTokenViewSet,
)
from authentik.providers.proxy.api import (
ProxyOutpostConfigViewSet,
ProxyProviderViewSet,
)
from authentik.providers.oauth2.api.tokens import AuthorizationCodeViewSet, RefreshTokenViewSet
from authentik.providers.proxy.api import ProxyOutpostConfigViewSet, ProxyProviderViewSet
from authentik.providers.saml.api import SAMLPropertyMappingViewSet, SAMLProviderViewSet
from authentik.sources.ldap.api import LDAPPropertyMappingViewSet, LDAPSourceViewSet
from authentik.sources.oauth.api.source import OAuthSourceViewSet
from authentik.sources.oauth.api.source_connection import (
UserOAuthSourceConnectionViewSet,
)
from authentik.sources.oauth.api.source_connection import UserOAuthSourceConnectionViewSet
from authentik.sources.plex.api import PlexSourceViewSet
from authentik.sources.saml.api import SAMLSourceViewSet
from authentik.stages.authenticator_duo.api import (
@ -83,9 +75,7 @@ from authentik.stages.authenticator_totp.api import (
TOTPAdminDeviceViewSet,
TOTPDeviceViewSet,
)
from authentik.stages.authenticator_validate.api import (
AuthenticatorValidateStageViewSet,
)
from authentik.stages.authenticator_validate.api import AuthenticatorValidateStageViewSet
from authentik.stages.authenticator_webauthn.api import (
AuthenticateWebAuthnStageViewSet,
WebAuthnAdminDeviceViewSet,
@ -122,9 +112,7 @@ router.register("core/tenants", TenantViewSet)
router.register("outposts/instances", OutpostViewSet)
router.register("outposts/service_connections/all", ServiceConnectionViewSet)
router.register("outposts/service_connections/docker", DockerServiceConnectionViewSet)
router.register(
"outposts/service_connections/kubernetes", KubernetesServiceConnectionViewSet
)
router.register("outposts/service_connections/kubernetes", KubernetesServiceConnectionViewSet)
router.register("outposts/proxy", ProxyOutpostConfigViewSet)
router.register("outposts/ldap", LDAPOutpostConfigViewSet)
@ -184,9 +172,7 @@ router.register(
StaticAdminDeviceViewSet,
basename="admin-staticdevice",
)
router.register(
"authenticators/admin/totp", TOTPAdminDeviceViewSet, basename="admin-totpdevice"
)
router.register("authenticators/admin/totp", TOTPAdminDeviceViewSet, basename="admin-totpdevice")
router.register(
"authenticators/admin/webauthn",
WebAuthnAdminDeviceViewSet,

View file

@ -147,9 +147,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
"""Custom list method that checks Policy based access instead of guardian"""
should_cache = request.GET.get("search", "") == ""
superuser_full_list = (
str(request.GET.get("superuser_full_list", "false")).lower() == "true"
)
superuser_full_list = str(request.GET.get("superuser_full_list", "false")).lower() == "true"
if superuser_full_list and request.user.is_superuser:
return super().list(request)
@ -240,9 +238,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
app.save()
return Response({})
@permission_required(
"authentik_core.view_application", ["authentik_events.view_event"]
)
@permission_required("authentik_core.view_application", ["authentik_events.view_event"])
@extend_schema(responses={200: CoordinateSerializer(many=True)})
@action(detail=True, pagination_class=None, filter_backends=[])
# pylint: disable=unused-argument

View file

@ -68,9 +68,7 @@ class AuthenticatedSessionSerializer(ModelSerializer):
"""Get parsed user agent"""
return user_agent_parser.Parse(instance.last_user_agent)
def get_geo_ip(
self, instance: AuthenticatedSession
) -> Optional[GeoIPDict]: # pragma: no cover
def get_geo_ip(self, instance: AuthenticatedSession) -> Optional[GeoIPDict]: # pragma: no cover
"""Get parsed user agent"""
return GEOIP_READER.city_dict(instance.last_ip)

View file

@ -15,11 +15,7 @@ from rest_framework.viewsets import GenericViewSet
from authentik.api.decorators import permission_required
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import (
MetaNameSerializer,
PassiveSerializer,
TypeCreateSerializer,
)
from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer
from authentik.core.expression import PropertyMappingEvaluator
from authentik.core.models import PropertyMapping
from authentik.lib.utils.reflection import all_subclasses
@ -141,9 +137,7 @@ class PropertyMappingViewSet(
self.request,
**test_params.validated_data.get("context", {}),
)
response_data["result"] = dumps(
result, indent=(4 if format_result else None)
)
response_data["result"] = dumps(result, indent=(4 if format_result else None))
except Exception as exc: # pylint: disable=broad-except
response_data["result"] = str(exc)
response_data["successful"] = False

View file

@ -93,9 +93,7 @@ class SourceViewSet(
@action(detail=False, pagination_class=None, filter_backends=[])
def user_settings(self, request: Request) -> Response:
"""Get all sources the user can configure"""
_all_sources: Iterable[Source] = Source.objects.filter(
enabled=True
).select_subclasses()
_all_sources: Iterable[Source] = Source.objects.filter(enabled=True).select_subclasses()
matching_sources: list[UserSettingSerializer] = []
for source in _all_sources:
user_settings = source.ui_user_settings

View file

@ -70,9 +70,7 @@ class TokenViewSet(UsedByMixin, ModelViewSet):
serializer.save(
user=self.request.user,
intent=TokenIntents.INTENT_API,
expiring=self.request.user.attributes.get(
USER_ATTRIBUTE_TOKEN_EXPIRING, True
),
expiring=self.request.user.attributes.get(USER_ATTRIBUTE_TOKEN_EXPIRING, True),
)
@permission_required("authentik_core.view_token_key")
@ -89,7 +87,5 @@ class TokenViewSet(UsedByMixin, ModelViewSet):
token: Token = self.get_object()
if token.is_expired:
raise Http404
Event.new(EventAction.SECRET_VIEW, secret=token).from_http( # noqa # nosec
request
)
Event.new(EventAction.SECRET_VIEW, secret=token).from_http(request) # noqa # nosec
return Response(TokenViewSerializer({"key": token.key}).data)

View file

@ -79,9 +79,7 @@ class UsedByMixin:
).all():
# Only merge shadows on first object
if first_object:
shadows += getattr(
manager.model._meta, "authentik_used_by_shadows", []
)
shadows += getattr(manager.model._meta, "authentik_used_by_shadows", [])
first_object = False
serializer = UsedBySerializer(
data={

View file

@ -26,10 +26,7 @@ from authentik.api.decorators import permission_required
from authentik.core.api.groups import GroupSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import LinkSerializer, PassiveSerializer, is_dict
from authentik.core.middleware import (
SESSION_IMPERSONATE_ORIGINAL_USER,
SESSION_IMPERSONATE_USER,
)
from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER
from authentik.core.models import Token, TokenIntents, User
from authentik.events.models import EventAction
from authentik.tenants.models import Tenant
@ -87,17 +84,13 @@ class UserMetricsSerializer(PassiveSerializer):
def get_logins_failed_per_1h(self, _):
"""Get failed logins per hour for the last 24 hours"""
user = self.context["user"]
return get_events_per_1h(
action=EventAction.LOGIN_FAILED, context__username=user.username
)
return get_events_per_1h(action=EventAction.LOGIN_FAILED, context__username=user.username)
@extend_schema_field(CoordinateSerializer(many=True))
def get_authorizations_per_1h(self, _):
"""Get failed logins per hour for the last 24 hours"""
user = self.context["user"]
return get_events_per_1h(
action=EventAction.AUTHORIZE_APPLICATION, user__pk=user.pk
)
return get_events_per_1h(action=EventAction.AUTHORIZE_APPLICATION, user__pk=user.pk)
class UsersFilter(FilterSet):
@ -154,9 +147,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
# pylint: disable=invalid-name
def me(self, request: Request) -> Response:
"""Get information about current user"""
serializer = SessionUserSerializer(
data={"user": UserSerializer(request.user).data}
)
serializer = SessionUserSerializer(data={"user": UserSerializer(request.user).data})
if SESSION_IMPERSONATE_USER in request._request.session:
serializer.initial_data["original"] = UserSerializer(
request._request.session[SESSION_IMPERSONATE_ORIGINAL_USER]

View file

@ -3,20 +3,14 @@ from typing import Any
from django.db.models import Model
from rest_framework.fields import CharField, IntegerField
from rest_framework.serializers import (
Serializer,
SerializerMethodField,
ValidationError,
)
from rest_framework.serializers import Serializer, SerializerMethodField, ValidationError
def is_dict(value: Any):
"""Ensure a value is a dictionary, useful for JSONFields"""
if isinstance(value, dict):
return
raise ValidationError(
"Value must be a dictionary, and not have any duplicate keys."
)
raise ValidationError("Value must be a dictionary, and not have any duplicate keys.")
class PassiveSerializer(Serializer):
@ -25,9 +19,7 @@ class PassiveSerializer(Serializer):
def create(self, validated_data: dict) -> Model: # pragma: no cover
return Model()
def update(
self, instance: Model, validated_data: dict
) -> Model: # pragma: no cover
def update(self, instance: Model, validated_data: dict) -> Model: # pragma: no cover
return Model()
class Meta:

View file

@ -38,9 +38,7 @@ class Migration(migrations.Migration):
("password", models.CharField(max_length=128, verbose_name="password")),
(
"last_login",
models.DateTimeField(
blank=True, null=True, verbose_name="last login"
),
models.DateTimeField(blank=True, null=True, verbose_name="last login"),
),
(
"is_superuser",
@ -53,35 +51,25 @@ class Migration(migrations.Migration):
(
"username",
models.CharField(
error_messages={
"unique": "A user with that username already exists."
},
error_messages={"unique": "A user with that username already exists."},
help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.",
max_length=150,
unique=True,
validators=[
django.contrib.auth.validators.UnicodeUsernameValidator()
],
validators=[django.contrib.auth.validators.UnicodeUsernameValidator()],
verbose_name="username",
),
),
(
"first_name",
models.CharField(
blank=True, max_length=30, verbose_name="first name"
),
models.CharField(blank=True, max_length=30, verbose_name="first name"),
),
(
"last_name",
models.CharField(
blank=True, max_length=150, verbose_name="last name"
),
models.CharField(blank=True, max_length=150, verbose_name="last name"),
),
(
"email",
models.EmailField(
blank=True, max_length=254, verbose_name="email address"
),
models.EmailField(blank=True, max_length=254, verbose_name="email address"),
),
(
"is_staff",
@ -217,9 +205,7 @@ class Migration(migrations.Migration):
),
(
"expires",
models.DateTimeField(
default=authentik.core.models.default_token_duration
),
models.DateTimeField(default=authentik.core.models.default_token_duration),
),
("expiring", models.BooleanField(default=True)),
("description", models.TextField(blank=True, default="")),
@ -306,9 +292,7 @@ class Migration(migrations.Migration):
("name", models.TextField(help_text="Application's display Name.")),
(
"slug",
models.SlugField(
help_text="Internal application name, used in URLs."
),
models.SlugField(help_text="Internal application name, used in URLs."),
),
("skip_authorization", models.BooleanField(default=False)),
("meta_launch_url", models.URLField(blank=True, default="")),

View file

@ -17,9 +17,7 @@ def create_default_user(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
username="akadmin", email="root@localhost", name="authentik Default Admin"
)
if "TF_BUILD" in environ or "AK_ADMIN_PASS" in environ or settings.TEST:
akadmin.set_password(
environ.get("AK_ADMIN_PASS", "akadmin"), signal=False
) # noqa # nosec
akadmin.set_password(environ.get("AK_ADMIN_PASS", "akadmin"), signal=False) # noqa # nosec
else:
akadmin.set_unusable_password()
akadmin.save()

View file

@ -13,8 +13,6 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="source",
name="slug",
field=models.SlugField(
help_text="Internal source name, used in URLs.", unique=True
),
field=models.SlugField(help_text="Internal source name, used in URLs.", unique=True),
),
]

View file

@ -13,8 +13,6 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="user",
name="first_name",
field=models.CharField(
blank=True, max_length=150, verbose_name="first name"
),
field=models.CharField(blank=True, max_length=150, verbose_name="first name"),
),
]

View file

@ -40,9 +40,7 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="user",
name="pb_groups",
field=models.ManyToManyField(
related_name="users", to="authentik_core.Group"
),
field=models.ManyToManyField(related_name="users", to="authentik_core.Group"),
),
migrations.AddField(
model_name="group",

View file

@ -42,9 +42,7 @@ class Migration(migrations.Migration):
),
migrations.AddIndex(
model_name="token",
index=models.Index(
fields=["identifier"], name="authentik_co_identif_1a34a8_idx"
),
index=models.Index(fields=["identifier"], name="authentik_co_identif_1a34a8_idx"),
),
migrations.RunPython(set_default_token_key),
]

View file

@ -17,8 +17,6 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name="application",
name="meta_icon",
field=models.FileField(
blank=True, default="", upload_to="application-icons/"
),
field=models.FileField(blank=True, default="", upload_to="application-icons/"),
),
]

View file

@ -25,9 +25,7 @@ class Migration(migrations.Migration):
),
migrations.AddIndex(
model_name="token",
index=models.Index(
fields=["identifier"], name="authentik_c_identif_d9d032_idx"
),
index=models.Index(fields=["identifier"], name="authentik_c_identif_d9d032_idx"),
),
migrations.AddIndex(
model_name="token",

View file

@ -32,16 +32,12 @@ class Migration(migrations.Migration):
fields=[
(
"expires",
models.DateTimeField(
default=authentik.core.models.default_token_duration
),
models.DateTimeField(default=authentik.core.models.default_token_duration),
),
("expiring", models.BooleanField(default=True)),
(
"uuid",
models.UUIDField(
default=uuid.uuid4, primary_key=True, serialize=False
),
models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False),
),
("session_key", models.CharField(max_length=40)),
("last_ip", models.TextField()),

View file

@ -13,8 +13,6 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="application",
name="meta_icon",
field=models.FileField(
default=None, null=True, upload_to="application-icons/"
),
field=models.FileField(default=None, null=True, upload_to="application-icons/"),
),
]

View file

@ -154,9 +154,7 @@ class User(GuardianUserMixin, AbstractUser):
("s", "158"),
("r", "g"),
]
gravatar_url = (
f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters, doseq=True)}"
)
gravatar_url = f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters, doseq=True)}"
return escape(gravatar_url)
return mode % {
"username": self.username,
@ -186,9 +184,7 @@ class Provider(SerializerModel):
related_name="provider_authorization",
)
property_mappings = models.ManyToManyField(
"PropertyMapping", default=None, blank=True
)
property_mappings = models.ManyToManyField("PropertyMapping", default=None, blank=True)
objects = InheritanceManager()
@ -218,9 +214,7 @@ class Application(PolicyBindingModel):
add custom fields and other properties"""
name = models.TextField(help_text=_("Application's display Name."))
slug = models.SlugField(
help_text=_("Internal application name, used in URLs."), unique=True
)
slug = models.SlugField(help_text=_("Internal application name, used in URLs."), unique=True)
provider = models.OneToOneField(
"Provider", null=True, blank=True, default=None, on_delete=models.SET_DEFAULT
)
@ -244,9 +238,7 @@ class Application(PolicyBindingModel):
it is returned as-is"""
if not self.meta_icon:
return None
if self.meta_icon.name.startswith("http") or self.meta_icon.name.startswith(
"/static"
):
if self.meta_icon.name.startswith("http") or self.meta_icon.name.startswith("/static"):
return self.meta_icon.name
return self.meta_icon.url
@ -301,14 +293,10 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
"""Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server"""
name = models.TextField(help_text=_("Source's display Name."))
slug = models.SlugField(
help_text=_("Internal source name, used in URLs."), unique=True
)
slug = models.SlugField(help_text=_("Internal source name, used in URLs."), unique=True)
enabled = models.BooleanField(default=True)
property_mappings = models.ManyToManyField(
"PropertyMapping", default=None, blank=True
)
property_mappings = models.ManyToManyField("PropertyMapping", default=None, blank=True)
authentication_flow = models.ForeignKey(
Flow,
@ -481,9 +469,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
"""Get serializer for this model"""
raise NotImplementedError
def evaluate(
self, user: Optional[User], request: Optional[HttpRequest], **kwargs
) -> Any:
def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any:
"""Evaluate `self.expression` using `**kwargs` as Context."""
from authentik.core.expression import PropertyMappingEvaluator
@ -522,9 +508,7 @@ class AuthenticatedSession(ExpiringModel):
last_used = models.DateTimeField(auto_now=True)
@staticmethod
def from_request(
request: HttpRequest, user: User
) -> Optional["AuthenticatedSession"]:
def from_request(request: HttpRequest, user: User) -> Optional["AuthenticatedSession"]:
"""Create a new session from a http request"""
if not hasattr(request, "session") or not request.session.session_key:
return None

View file

@ -14,9 +14,7 @@ from prometheus_client import Gauge
# Arguments: user: User, password: str
password_changed = Signal()
GAUGE_MODELS = Gauge(
"authentik_models", "Count of various objects", ["model_name", "app"]
)
GAUGE_MODELS = Gauge("authentik_models", "Count of various objects", ["model_name", "app"])
if TYPE_CHECKING:
from authentik.core.models import AuthenticatedSession, User
@ -60,15 +58,11 @@ def user_logged_out_session(sender, request: HttpRequest, user: "User", **_):
"""Delete AuthenticatedSession if it exists"""
from authentik.core.models import AuthenticatedSession
AuthenticatedSession.objects.filter(
session_key=request.session.session_key
).delete()
AuthenticatedSession.objects.filter(session_key=request.session.session_key).delete()
@receiver(pre_delete)
def authenticated_session_delete(
sender: Type[Model], instance: "AuthenticatedSession", **_
):
def authenticated_session_delete(sender: Type[Model], instance: "AuthenticatedSession", **_):
"""Delete session when authenticated session is deleted"""
from authentik.core.models import AuthenticatedSession

View file

@ -11,16 +11,8 @@ from django.urls import reverse
from django.utils.translation import gettext as _
from structlog.stdlib import get_logger
from authentik.core.models import (
Source,
SourceUserMatchingModes,
User,
UserSourceConnection,
)
from authentik.core.sources.stage import (
PLAN_CONTEXT_SOURCES_CONNECTION,
PostUserEnrollmentStage,
)
from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection
from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostUserEnrollmentStage
from authentik.events.models import Event, EventAction
from authentik.flows.models import Flow, Stage, in_memory_stage
from authentik.flows.planner import (
@ -76,9 +68,7 @@ class SourceFlowManager:
# pylint: disable=too-many-return-statements
def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]:
"""decide which action should be taken"""
new_connection = self.connection_type(
source=self.source, identifier=self.identifier
)
new_connection = self.connection_type(source=self.source, identifier=self.identifier)
# When request is authenticated, always link
if self.request.user.is_authenticated:
new_connection.user = self.request.user
@ -113,9 +103,7 @@ class SourceFlowManager:
SourceUserMatchingModes.USERNAME_DENY,
]:
if not self.enroll_info.get("username", None):
self._logger.warning(
"Refusing to use none username", source=self.source
)
self._logger.warning("Refusing to use none username", source=self.source)
return Action.DENY, None
query = Q(username__exact=self.enroll_info.get("username", None))
self._logger.debug("trying to link with existing user", query=query)
@ -229,10 +217,7 @@ class SourceFlowManager:
"""Login user and redirect."""
messages.success(
self.request,
_(
"Successfully authenticated with %(source)s!"
% {"source": self.source.name}
),
_("Successfully authenticated with %(source)s!" % {"source": self.source.name}),
)
flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user}
return self._handle_login_flow(self.source.authentication_flow, **flow_kwargs)
@ -270,10 +255,7 @@ class SourceFlowManager:
"""User was not authenticated and previous request was not authenticated."""
messages.success(
self.request,
_(
"Successfully authenticated with %(source)s!"
% {"source": self.source.name}
),
_("Successfully authenticated with %(source)s!" % {"source": self.source.name}),
)
# We run the Flow planner here so we can pass the Pending user in the context

View file

@ -27,9 +27,7 @@ def clean_expired_models(self: MonitoredTask):
for cls in ExpiringModel.__subclasses__():
cls: ExpiringModel
objects = (
cls.objects.all()
.exclude(expiring=False)
.exclude(expiring=True, expires__gt=now())
cls.objects.all().exclude(expiring=False).exclude(expiring=True, expires__gt=now())
)
for obj in objects:
obj.expire_action()

View file

@ -17,9 +17,7 @@ class TestApplicationsAPI(APITestCase):
self.denied = Application.objects.create(name="denied", slug="denied")
PolicyBinding.objects.create(
target=self.denied,
policy=DummyPolicy.objects.create(
name="deny", result=False, wait_min=1, wait_max=2
),
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
order=0,
)
@ -33,9 +31,7 @@ class TestApplicationsAPI(APITestCase):
)
)
self.assertEqual(response.status_code, 200)
self.assertJSONEqual(
force_str(response.content), {"messages": [], "passing": True}
)
self.assertJSONEqual(force_str(response.content), {"messages": [], "passing": True})
response = self.client.get(
reverse(
"authentik_api:application-check-access",
@ -43,9 +39,7 @@ class TestApplicationsAPI(APITestCase):
)
)
self.assertEqual(response.status_code, 200)
self.assertJSONEqual(
force_str(response.content), {"messages": ["dummy"], "passing": False}
)
self.assertJSONEqual(force_str(response.content), {"messages": ["dummy"], "passing": False})
def test_list(self):
"""Test list operation without superuser_full_list"""

View file

@ -46,9 +46,7 @@ class TestImpersonation(TestCase):
self.client.force_login(self.other_user)
self.client.get(
reverse(
"authentik_core:impersonate-init", kwargs={"user_id": self.akadmin.pk}
)
reverse("authentik_core:impersonate-init", kwargs={"user_id": self.akadmin.pk})
)
response = self.client.get(reverse("authentik_api:user-me"))

View file

@ -22,9 +22,7 @@ class TestModels(TestCase):
def test_token_expire_no_expire(self):
"""Test token expiring with "expiring" set"""
token = Token.objects.create(
expires=now(), user=get_anonymous_user(), expiring=False
)
token = Token.objects.create(expires=now(), user=get_anonymous_user(), expiring=False)
sleep(0.5)
self.assertFalse(token.is_expired)

View file

@ -16,9 +16,7 @@ class TestPropertyMappings(TestCase):
def test_expression(self):
"""Test expression"""
mapping = PropertyMapping.objects.create(
name="test", expression="return 'test'"
)
mapping = PropertyMapping.objects.create(name="test", expression="return 'test'")
self.assertEqual(mapping.evaluate(None, None), "test")
def test_expression_syntax(self):

View file

@ -23,9 +23,7 @@ class TestPropertyMappingAPI(APITestCase):
def test_test_call(self):
"""Test PropertMappings's test endpoint"""
response = self.client.post(
reverse(
"authentik_api:propertymapping-test", kwargs={"pk": self.mapping.pk}
),
reverse("authentik_api:propertymapping-test", kwargs={"pk": self.mapping.pk}),
data={
"user": self.user.pk,
},

View file

@ -4,12 +4,7 @@ from django.utils.timezone import now
from guardian.shortcuts import get_anonymous_user
from rest_framework.test import APITestCase
from authentik.core.models import (
USER_ATTRIBUTE_TOKEN_EXPIRING,
Token,
TokenIntents,
User,
)
from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User
from authentik.core.tasks import clean_expired_models

View file

@ -5,10 +5,7 @@ from django.shortcuts import get_object_or_404, redirect
from django.views import View
from structlog.stdlib import get_logger
from authentik.core.middleware import (
SESSION_IMPERSONATE_ORIGINAL_USER,
SESSION_IMPERSONATE_USER,
)
from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER
from authentik.core.models import User
from authentik.events.models import Event, EventAction
@ -21,9 +18,7 @@ class ImpersonateInitView(View):
def get(self, request: HttpRequest, user_id: int) -> HttpResponse:
"""Impersonation handler, checks permissions"""
if not request.user.has_perm("impersonate"):
LOGGER.debug(
"User attempted to impersonate without permissions", user=request.user
)
LOGGER.debug("User attempted to impersonate without permissions", user=request.user)
return HttpResponse("Unauthorized", status=401)
user_to_be = get_object_or_404(User, pk=user_id)

View file

@ -14,9 +14,7 @@ class EndSessionView(TemplateView, PolicyAccessView):
template_name = "if/end_session.html"
def resolve_provider_application(self):
self.application = get_object_or_404(
Application, slug=self.kwargs["application_slug"]
)
self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"])
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
context = super().get_context_data(**kwargs)

View file

@ -10,12 +10,7 @@ from django_filters.filters import BooleanFilter
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
from rest_framework.decorators import action
from rest_framework.fields import (
CharField,
DateTimeField,
IntegerField,
SerializerMethodField,
)
from rest_framework.fields import CharField, DateTimeField, IntegerField, SerializerMethodField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer, ValidationError
@ -86,9 +81,7 @@ class CertificateKeyPairSerializer(ModelSerializer):
backend=default_backend(),
)
except (ValueError, TypeError):
raise ValidationError(
"Unable to load private key (possibly encrypted?)."
)
raise ValidationError("Unable to load private key (possibly encrypted?).")
return value
class Meta:
@ -123,9 +116,7 @@ class CertificateGenerationSerializer(PassiveSerializer):
"""Certificate generation parameters"""
common_name = CharField()
subject_alt_name = CharField(
required=False, allow_blank=True, label=_("Subject-alt name")
)
subject_alt_name = CharField(required=False, allow_blank=True, label=_("Subject-alt name"))
validity_days = IntegerField(initial=365)
@ -170,9 +161,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
builder = CertificateBuilder()
builder.common_name = data.validated_data["common_name"]
builder.build(
subject_alt_names=data.validated_data.get("subject_alt_name", "").split(
","
),
subject_alt_names=data.validated_data.get("subject_alt_name", "").split(","),
validity_days=int(data.validated_data["validity_days"]),
)
instance = builder.save()
@ -208,9 +197,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
"Content-Disposition"
] = f'attachment; filename="{certificate.name}_certificate.pem"'
return response
return Response(
CertificateDataSerializer({"data": certificate.certificate_data}).data
)
return Response(CertificateDataSerializer({"data": certificate.certificate_data}).data)
@extend_schema(
parameters=[
@ -234,9 +221,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
).from_http(request)
if "download" in request._request.GET:
# Mime type from https://pki-tutorial.readthedocs.io/en/latest/mime.html
response = HttpResponse(
certificate.key_data, content_type="application/x-pem-file"
)
response = HttpResponse(certificate.key_data, content_type="application/x-pem-file")
response[
"Content-Disposition"
] = f'attachment; filename="{certificate.name}_private_key.pem"'

View file

@ -46,9 +46,7 @@ class CertificateBuilder:
public_exponent=65537, key_size=2048, backend=default_backend()
)
self.__public_key = self.__private_key.public_key()
alt_names: list[x509.GeneralName] = [
x509.DNSName(x) for x in subject_alt_names or []
]
alt_names: list[x509.GeneralName] = [x509.DNSName(x) for x in subject_alt_names or []]
self.__builder = (
x509.CertificateBuilder()
.subject_name(
@ -59,9 +57,7 @@ class CertificateBuilder:
self.common_name,
),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "authentik"),
x509.NameAttribute(
NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed"
),
x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed"),
]
)
)
@ -77,9 +73,7 @@ class CertificateBuilder:
)
.add_extension(x509.SubjectAlternativeName(alt_names), critical=True)
.not_valid_before(datetime.datetime.today() - one_day)
.not_valid_after(
datetime.datetime.today() + datetime.timedelta(days=validity_days)
)
.not_valid_after(datetime.datetime.today() + datetime.timedelta(days=validity_days))
.serial_number(int(uuid.uuid4()))
.public_key(self.__public_key)
)

View file

@ -57,9 +57,7 @@ class CertificateKeyPair(CreatedUpdatedModel):
if not self._private_key and self._private_key != "":
try:
self._private_key = load_pem_private_key(
str.encode(
"\n".join([x.strip() for x in self.key_data.split("\n")])
),
str.encode("\n".join([x.strip() for x in self.key_data.split("\n")])),
password=None,
backend=default_backend(),
)
@ -70,25 +68,19 @@ class CertificateKeyPair(CreatedUpdatedModel):
@property
def fingerprint_sha256(self) -> str:
"""Get SHA256 Fingerprint of certificate_data"""
return hexlify(self.certificate.fingerprint(hashes.SHA256()), ":").decode(
"utf-8"
)
return hexlify(self.certificate.fingerprint(hashes.SHA256()), ":").decode("utf-8")
@property
def fingerprint_sha1(self) -> str:
"""Get SHA1 Fingerprint of certificate_data"""
return hexlify(
self.certificate.fingerprint(hashes.SHA1()), ":" # nosec
).decode("utf-8")
return hexlify(self.certificate.fingerprint(hashes.SHA1()), ":").decode("utf-8") # nosec
@property
def kid(self):
"""Get Key ID used for JWKS"""
return "{0}".format(
md5(self.key_data.encode("utf-8")).hexdigest() # nosec
if self.key_data
else ""
)
md5(self.key_data.encode("utf-8")).hexdigest() if self.key_data else ""
) # nosec
def __str__(self) -> str:
return f"Certificate-Key Pair {self.name}"

View file

@ -143,7 +143,5 @@ class EventViewSet(ModelViewSet):
"""Get all actions"""
data = []
for value, name in EventAction.choices:
data.append(
{"name": name, "description": "", "component": value, "model_name": ""}
)
data.append({"name": name, "description": "", "component": value, "model_name": ""})
return Response(TypeCreateSerializer(data, many=True).data)

View file

@ -29,12 +29,8 @@ class AuditMiddleware:
def __call__(self, request: HttpRequest) -> HttpResponse:
# Connect signal for automatic logging
if hasattr(request, "user") and getattr(
request.user, "is_authenticated", False
):
post_save_handler = partial(
self.post_save_handler, user=request.user, request=request
)
if hasattr(request, "user") and getattr(request.user, "is_authenticated", False):
post_save_handler = partial(self.post_save_handler, user=request.user, request=request)
pre_delete_handler = partial(
self.pre_delete_handler, user=request.user, request=request
)
@ -94,13 +90,9 @@ class AuditMiddleware:
@staticmethod
# pylint: disable=unused-argument
def pre_delete_handler(
user: User, request: HttpRequest, sender, instance: Model, **_
):
def pre_delete_handler(user: User, request: HttpRequest, sender, instance: Model, **_):
"""Signal handler for all object's pre_delete"""
if isinstance(
instance, (Event, Notification, UserObjectPermission)
): # pragma: no cover
if isinstance(instance, (Event, Notification, UserObjectPermission)): # pragma: no cover
return
EventNewThread(

View file

@ -14,9 +14,7 @@ def convert_user_to_json(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
event.delete()
# Because event objects cannot be updated, we have to re-create them
event.pk = None
event.user_json = (
authentik.events.models.get_user(event.user) if event.user else {}
)
event.user_json = authentik.events.models.get_user(event.user) if event.user else {}
event._state.adding = True
event.save()
@ -58,7 +56,5 @@ class Migration(migrations.Migration):
model_name="event",
name="user",
),
migrations.RenameField(
model_name="event", old_name="user_json", new_name="user"
),
migrations.RenameField(model_name="event", old_name="user_json", new_name="user"),
]

View file

@ -11,16 +11,12 @@ def notify_configuration_error(apps: Apps, schema_editor: BaseDatabaseSchemaEdit
db_alias = schema_editor.connection.alias
Group = apps.get_model("authentik_core", "Group")
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
EventMatcherPolicy = apps.get_model(
"authentik_policies_event_matcher", "EventMatcherPolicy"
)
EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy")
NotificationRule = apps.get_model("authentik_events", "NotificationRule")
NotificationTransport = apps.get_model("authentik_events", "NotificationTransport")
admin_group = (
Group.objects.using(db_alias)
.filter(name="authentik Admins", is_superuser=True)
.first()
Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first()
)
policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create(
@ -32,9 +28,7 @@ def notify_configuration_error(apps: Apps, schema_editor: BaseDatabaseSchemaEdit
defaults={"group": admin_group, "severity": NotificationSeverity.ALERT},
)
trigger.transports.set(
NotificationTransport.objects.using(db_alias).filter(
name="default-email-transport"
)
NotificationTransport.objects.using(db_alias).filter(name="default-email-transport")
)
trigger.save()
PolicyBinding.objects.using(db_alias).update_or_create(
@ -50,16 +44,12 @@ def notify_update(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
db_alias = schema_editor.connection.alias
Group = apps.get_model("authentik_core", "Group")
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
EventMatcherPolicy = apps.get_model(
"authentik_policies_event_matcher", "EventMatcherPolicy"
)
EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy")
NotificationRule = apps.get_model("authentik_events", "NotificationRule")
NotificationTransport = apps.get_model("authentik_events", "NotificationTransport")
admin_group = (
Group.objects.using(db_alias)
.filter(name="authentik Admins", is_superuser=True)
.first()
Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first()
)
policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create(
@ -71,9 +61,7 @@ def notify_update(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
defaults={"group": admin_group, "severity": NotificationSeverity.ALERT},
)
trigger.transports.set(
NotificationTransport.objects.using(db_alias).filter(
name="default-email-transport"
)
NotificationTransport.objects.using(db_alias).filter(name="default-email-transport")
)
trigger.save()
PolicyBinding.objects.using(db_alias).update_or_create(
@ -89,16 +77,12 @@ def notify_exception(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
db_alias = schema_editor.connection.alias
Group = apps.get_model("authentik_core", "Group")
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
EventMatcherPolicy = apps.get_model(
"authentik_policies_event_matcher", "EventMatcherPolicy"
)
EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy")
NotificationRule = apps.get_model("authentik_events", "NotificationRule")
NotificationTransport = apps.get_model("authentik_events", "NotificationTransport")
admin_group = (
Group.objects.using(db_alias)
.filter(name="authentik Admins", is_superuser=True)
.first()
Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first()
)
policy_policy_exc, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create(
@ -114,9 +98,7 @@ def notify_exception(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
defaults={"group": admin_group, "severity": NotificationSeverity.ALERT},
)
trigger.transports.set(
NotificationTransport.objects.using(db_alias).filter(
name="default-email-transport"
)
NotificationTransport.objects.using(db_alias).filter(name="default-email-transport")
)
trigger.save()
PolicyBinding.objects.using(db_alias).update_or_create(

View file

@ -38,9 +38,7 @@ def progress_bar(
def print_progress_bar(iteration):
"""Progress Bar Printing Function"""
percent = ("{0:." + str(decimals) + "f}").format(
100 * (iteration / float(total))
)
percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
filledLength = int(length * iteration // total)
bar = fill * filledLength + "-" * (length - filledLength)
print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=print_end)
@ -78,9 +76,7 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name="event",
name="expires",
field=models.DateTimeField(
default=authentik.events.models.default_event_duration
),
field=models.DateTimeField(default=authentik.events.models.default_event_duration),
),
migrations.AddField(
model_name="event",

View file

@ -15,9 +15,7 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name="event",
name="tenant",
field=models.JSONField(
blank=True, default=authentik.events.models.default_tenant
),
field=models.JSONField(blank=True, default=authentik.events.models.default_tenant),
),
migrations.AlterField(
model_name="event",

View file

@ -15,10 +15,7 @@ from requests import RequestException, post
from structlog.stdlib import get_logger
from authentik import __version__
from authentik.core.middleware import (
SESSION_IMPERSONATE_ORIGINAL_USER,
SESSION_IMPERSONATE_USER,
)
from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER
from authentik.core.models import ExpiringModel, Group, User
from authentik.events.geo import GEOIP_READER
from authentik.events.utils import cleanse_dict, get_user, model_to_dict, sanitize_dict
@ -159,9 +156,7 @@ class Event(ExpiringModel):
if hasattr(request, "user"):
original_user = None
if hasattr(request, "session"):
original_user = request.session.get(
SESSION_IMPERSONATE_ORIGINAL_USER, None
)
original_user = request.session.get(SESSION_IMPERSONATE_ORIGINAL_USER, None)
self.user = get_user(request.user, original_user)
if user:
self.user = get_user(user)
@ -169,9 +164,7 @@ class Event(ExpiringModel):
if hasattr(request, "session"):
if SESSION_IMPERSONATE_ORIGINAL_USER in request.session:
self.user = get_user(request.session[SESSION_IMPERSONATE_ORIGINAL_USER])
self.user["on_behalf_of"] = get_user(
request.session[SESSION_IMPERSONATE_USER]
)
self.user["on_behalf_of"] = get_user(request.session[SESSION_IMPERSONATE_USER])
# User 255.255.255.255 as fallback if IP cannot be determined
self.client_ip = get_client_ip(request)
# Apply GeoIP Data, when enabled
@ -414,9 +407,7 @@ class NotificationRule(PolicyBindingModel):
severity = models.TextField(
choices=NotificationSeverity.choices,
default=NotificationSeverity.NOTICE,
help_text=_(
"Controls which severity level the created notifications will have."
),
help_text=_("Controls which severity level the created notifications will have."),
)
group = models.ForeignKey(
Group,

View file

@ -135,9 +135,7 @@ class MonitoredTask(Task):
self._result = result
# pylint: disable=too-many-arguments
def after_return(
self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo
):
def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo):
if self._result:
if not self._result.uid:
self._result.uid = self._uid
@ -159,9 +157,7 @@ class MonitoredTask(Task):
# pylint: disable=too-many-arguments
def on_failure(self, exc, task_id, args, kwargs, einfo):
if not self._result:
self._result = TaskResult(
status=TaskResultStatus.ERROR, messages=[str(exc)]
)
self._result = TaskResult(status=TaskResultStatus.ERROR, messages=[str(exc)])
if not self._result.uid:
self._result.uid = self._uid
TaskInfo(
@ -179,8 +175,7 @@ class MonitoredTask(Task):
Event.new(
EventAction.SYSTEM_TASK_EXCEPTION,
message=(
f"Task {self.__name__} encountered an error: "
"\n".join(self._result.messages)
f"Task {self.__name__} encountered an error: " "\n".join(self._result.messages)
),
).save()
return super().on_failure(exc, task_id, args, kwargs, einfo=einfo)

View file

@ -2,11 +2,7 @@
from threading import Thread
from typing import Any, Optional
from django.contrib.auth.signals import (
user_logged_in,
user_logged_out,
user_login_failed,
)
from django.contrib.auth.signals import user_logged_in, user_logged_out, user_login_failed
from django.db.models.signals import post_save
from django.dispatch import receiver
from django.http import HttpRequest
@ -30,9 +26,7 @@ class EventNewThread(Thread):
kwargs: dict[str, Any]
user: Optional[User] = None
def __init__(
self, action: str, request: HttpRequest, user: Optional[User] = None, **kwargs
):
def __init__(self, action: str, request: HttpRequest, user: Optional[User] = None, **kwargs):
super().__init__()
self.action = action
self.request = request
@ -68,9 +62,7 @@ def on_user_logged_out(sender, request: HttpRequest, user: User, **_):
@receiver(user_write)
# pylint: disable=unused-argument
def on_user_write(
sender, request: HttpRequest, user: User, data: dict[str, Any], **kwargs
):
def on_user_write(sender, request: HttpRequest, user: User, data: dict[str, Any], **kwargs):
"""Log User write"""
thread = EventNewThread(EventAction.USER_WRITE, request, **data)
thread.kwargs["created"] = kwargs.get("created", False)
@ -80,9 +72,7 @@ def on_user_write(
@receiver(user_login_failed)
# pylint: disable=unused-argument
def on_user_login_failed(
sender, credentials: dict[str, str], request: HttpRequest, **_
):
def on_user_login_failed(sender, credentials: dict[str, str], request: HttpRequest, **_):
"""Failed Login"""
thread = EventNewThread(EventAction.LOGIN_FAILED, request, **credentials)
thread.run()

View file

@ -22,9 +22,7 @@ LOGGER = get_logger()
def event_notification_handler(event_uuid: str):
"""Start task for each trigger definition"""
for trigger in NotificationRule.objects.all():
event_trigger_handler.apply_async(
args=[event_uuid, trigger.name], queue="authentik_events"
)
event_trigger_handler.apply_async(args=[event_uuid, trigger.name], queue="authentik_events")
@CELERY_APP.task()
@ -43,17 +41,13 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
if "policy_uuid" in event.context:
policy_uuid = event.context["policy_uuid"]
if PolicyBinding.objects.filter(
target__in=NotificationRule.objects.all().values_list(
"pbm_uuid", flat=True
),
target__in=NotificationRule.objects.all().values_list("pbm_uuid", flat=True),
policy=policy_uuid,
).exists():
# If policy that caused this event to be created is attached
# to *any* NotificationRule, we return early.
# This is the most effective way to prevent infinite loops.
LOGGER.debug(
"e(trigger): attempting to prevent infinite loop", trigger=trigger
)
LOGGER.debug("e(trigger): attempting to prevent infinite loop", trigger=trigger)
return
if not trigger.group:
@ -62,9 +56,7 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
LOGGER.debug("e(trigger): checking if trigger applies", trigger=trigger)
try:
user = (
User.objects.filter(pk=event.user.get("pk")).first() or get_anonymous_user()
)
user = User.objects.filter(pk=event.user.get("pk")).first() or get_anonymous_user()
except User.DoesNotExist:
LOGGER.warning("e(trigger): failed to get user", trigger=trigger)
return
@ -99,20 +91,14 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
retry_backoff=True,
base=MonitoredTask,
)
def notification_transport(
self: MonitoredTask, notification_pk: int, transport_pk: int
):
def notification_transport(self: MonitoredTask, notification_pk: int, transport_pk: int):
"""Send notification over specified transport"""
self.save_on_success = False
try:
notification: Notification = Notification.objects.filter(
pk=notification_pk
).first()
notification: Notification = Notification.objects.filter(pk=notification_pk).first()
if not notification:
return
transport: NotificationTransport = NotificationTransport.objects.get(
pk=transport_pk
)
transport: NotificationTransport = NotificationTransport.objects.get(pk=transport_pk)
transport.send(notification)
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL))
except NotificationTransportError as exc:

View file

@ -38,7 +38,5 @@ class TestEvents(TestCase):
event = Event.new("unittest", model=temp_model)
event.save() # We save to ensure nothing is un-saveable
model_content_type = ContentType.objects.get_for_model(temp_model)
self.assertEqual(
event.context.get("model").get("app"), model_content_type.app_label
)
self.assertEqual(event.context.get("model").get("app"), model_content_type.app_label)
self.assertEqual(event.context.get("model").get("pk"), temp_model.pk.hex)

View file

@ -81,12 +81,8 @@ class TestEventsNotifications(TestCase):
execute_mock = MagicMock()
passes = MagicMock(side_effect=PolicyException)
with patch(
"authentik.policies.event_matcher.models.EventMatcherPolicy.passes", passes
):
with patch(
"authentik.events.models.NotificationTransport.send", execute_mock
):
with patch("authentik.policies.event_matcher.models.EventMatcherPolicy.passes", passes):
with patch("authentik.events.models.NotificationTransport.send", execute_mock):
Event.new(EventAction.CUSTOM_PREFIX).save()
self.assertEqual(passes.call_count, 1)
@ -96,9 +92,7 @@ class TestEventsNotifications(TestCase):
self.group.users.add(user2)
self.group.save()
transport = NotificationTransport.objects.create(
name="transport", send_once=True
)
transport = NotificationTransport.objects.create(name="transport", send_once=True)
NotificationRule.objects.filter(name__startswith="default").delete()
trigger = NotificationRule.objects.create(name="trigger", group=self.group)
trigger.transports.add(transport)

View file

@ -14,12 +14,7 @@ from rest_framework.fields import BooleanField, FileField, ReadOnlyField
from rest_framework.parsers import MultiPartParser
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import (
CharField,
ModelSerializer,
Serializer,
SerializerMethodField,
)
from rest_framework.serializers import CharField, ModelSerializer, Serializer, SerializerMethodField
from rest_framework.viewsets import ModelViewSet
from structlog.stdlib import get_logger
@ -152,11 +147,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
],
)
@extend_schema(
request={
"multipart/form-data": inline_serializer(
"SetIcon", fields={"file": FileField()}
)
},
request={"multipart/form-data": inline_serializer("SetIcon", fields={"file": FileField()})},
responses={
204: OpenApiResponse(description="Successfully imported flow"),
400: OpenApiResponse(description="Bad request"),
@ -221,9 +212,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
.order_by("order")
):
for p_index, policy_binding in enumerate(
get_objects_for_user(
request.user, "authentik_policies.view_policybinding"
)
get_objects_for_user(request.user, "authentik_policies.view_policybinding")
.filter(target=stage_binding)
.exclude(policy__isnull=True)
.order_by("order")
@ -256,20 +245,14 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
element: DiagramElement = body[index]
if element.type == "condition":
# Policy passes, link policy yes to next stage
footer.append(
f"{element.identifier}(yes, right)->{body[index + 1].identifier}"
)
footer.append(f"{element.identifier}(yes, right)->{body[index + 1].identifier}")
# Policy doesn't pass, go to stage after next stage
no_element = body[index + 1]
if no_element.type != "end":
no_element = body[index + 2]
footer.append(
f"{element.identifier}(no, bottom)->{no_element.identifier}"
)
footer.append(f"{element.identifier}(no, bottom)->{no_element.identifier}")
elif element.type == "operation":
footer.append(
f"{element.identifier}(bottom)->{body[index + 1].identifier}"
)
footer.append(f"{element.identifier}(bottom)->{body[index + 1].identifier}")
diagram = "\n".join([str(x) for x in header + body + footer])
return Response({"diagram": diagram})

View file

@ -95,9 +95,7 @@ class Command(BaseCommand): # pragma: no cover
"""Output results human readable"""
total_max: int = max([max(inner) for inner in values])
total_min: int = min([min(inner) for inner in values])
total_avg = sum([sum(inner) for inner in values]) / sum(
[len(inner) for inner in values]
)
total_avg = sum([sum(inner) for inner in values]) / sum([len(inner) for inner in values])
print(f"Version: {__version__}")
print(f"Processes: {len(values)}")

View file

@ -9,21 +9,15 @@ from authentik.stages.identification.models import UserFields
from authentik.stages.password import BACKEND_DJANGO, BACKEND_LDAP
def create_default_authentication_flow(
apps: Apps, schema_editor: BaseDatabaseSchemaEditor
):
def create_default_authentication_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
Flow = apps.get_model("authentik_flows", "Flow")
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
PasswordStage = apps.get_model("authentik_stages_password", "PasswordStage")
UserLoginStage = apps.get_model("authentik_stages_user_login", "UserLoginStage")
IdentificationStage = apps.get_model(
"authentik_stages_identification", "IdentificationStage"
)
IdentificationStage = apps.get_model("authentik_stages_identification", "IdentificationStage")
db_alias = schema_editor.connection.alias
identification_stage, _ = IdentificationStage.objects.using(
db_alias
).update_or_create(
identification_stage, _ = IdentificationStage.objects.using(db_alias).update_or_create(
name="default-authentication-identification",
defaults={
"user_fields": [UserFields.E_MAIL, UserFields.USERNAME],
@ -69,17 +63,13 @@ def create_default_authentication_flow(
)
def create_default_invalidation_flow(
apps: Apps, schema_editor: BaseDatabaseSchemaEditor
):
def create_default_invalidation_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
Flow = apps.get_model("authentik_flows", "Flow")
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
UserLogoutStage = apps.get_model("authentik_stages_user_logout", "UserLogoutStage")
db_alias = schema_editor.connection.alias
UserLogoutStage.objects.using(db_alias).update_or_create(
name="default-invalidation-logout"
)
UserLogoutStage.objects.using(db_alias).update_or_create(name="default-invalidation-logout")
flow, _ = Flow.objects.using(db_alias).update_or_create(
slug="default-invalidation-flow",

View file

@ -15,16 +15,12 @@ PROMPT_POLICY_EXPRESSION = """# Check if we've not been given a username by the
return 'username' not in context.get('prompt_data', {})"""
def create_default_source_enrollment_flow(
apps: Apps, schema_editor: BaseDatabaseSchemaEditor
):
def create_default_source_enrollment_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
Flow = apps.get_model("authentik_flows", "Flow")
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
ExpressionPolicy = apps.get_model(
"authentik_policies_expression", "ExpressionPolicy"
)
ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy")
PromptStage = apps.get_model("authentik_stages_prompt", "PromptStage")
Prompt = apps.get_model("authentik_stages_prompt", "Prompt")
@ -99,16 +95,12 @@ def create_default_source_enrollment_flow(
)
def create_default_source_authentication_flow(
apps: Apps, schema_editor: BaseDatabaseSchemaEditor
):
def create_default_source_authentication_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
Flow = apps.get_model("authentik_flows", "Flow")
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
ExpressionPolicy = apps.get_model(
"authentik_policies_expression", "ExpressionPolicy"
)
ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy")
UserLoginStage = apps.get_model("authentik_stages_user_login", "UserLoginStage")

View file

@ -7,9 +7,7 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from authentik.flows.models import FlowDesignation
def create_default_provider_authorization_flow(
apps: Apps, schema_editor: BaseDatabaseSchemaEditor
):
def create_default_provider_authorization_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
Flow = apps.get_model("authentik_flows", "Flow")
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")

View file

@ -32,9 +32,7 @@ def create_default_oobe_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor
PromptStage = apps.get_model("authentik_stages_prompt", "PromptStage")
Prompt = apps.get_model("authentik_stages_prompt", "Prompt")
ExpressionPolicy = apps.get_model(
"authentik_policies_expression", "ExpressionPolicy"
)
ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy")
db_alias = schema_editor.connection.alias
@ -52,9 +50,7 @@ def create_default_oobe_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor
name="default-oobe-prefill-user",
defaults={"expression": PREFILL_POLICY_EXPRESSION},
)
password_usable_policy, _ = ExpressionPolicy.objects.using(
db_alias
).update_or_create(
password_usable_policy, _ = ExpressionPolicy.objects.using(db_alias).update_or_create(
name="default-oobe-password-usable",
defaults={"expression": PW_USABLE_POLICY_EXPRESSION},
)
@ -83,9 +79,7 @@ def create_default_oobe_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor
prompt_stage, _ = PromptStage.objects.using(db_alias).update_or_create(
name="default-oobe-password",
)
prompt_stage.fields.set(
[prompt_header, prompt_email, password_first, password_second]
)
prompt_stage.fields.set([prompt_header, prompt_email, password_first, password_second])
prompt_stage.save()
user_write, _ = UserWriteStage.objects.using(db_alias).update_or_create(

View file

@ -138,9 +138,7 @@ class Flow(SerializerModel, PolicyBindingModel):
it is returned as-is"""
if not self.background:
return "/static/dist/assets/images/flow_background.jpg"
if self.background.name.startswith("http") or self.background.name.startswith(
"/static"
):
if self.background.name.startswith("http") or self.background.name.startswith("/static"):
return self.background.name
return self.background.url
@ -165,9 +163,7 @@ class Flow(SerializerModel, PolicyBindingModel):
if result.passing:
LOGGER.debug("with_policy: flow passing", flow=flow)
return flow
LOGGER.warning(
"with_policy: flow not passing", flow=flow, messages=result.messages
)
LOGGER.warning("with_policy: flow not passing", flow=flow, messages=result.messages)
LOGGER.debug("with_policy: no flow found", filters=flow_filter)
return None

View file

@ -78,14 +78,10 @@ class FlowPlan:
marker = self.markers[0]
if marker.__class__ is not StageMarker:
LOGGER.debug(
"f(plan_inst): stage has marker", binding=binding, marker=marker
)
LOGGER.debug("f(plan_inst): stage has marker", binding=binding, marker=marker)
marked_stage = marker.process(self, binding, http_request)
if not marked_stage:
LOGGER.debug(
"f(plan_inst): marker returned none, next stage", binding=binding
)
LOGGER.debug("f(plan_inst): marker returned none, next stage", binding=binding)
self.bindings.remove(binding)
self.markers.remove(marker)
if not self.has_stages:
@ -193,9 +189,9 @@ class FlowPlanner:
if default_context:
plan.context = default_context
# Check Flow policies
for binding in FlowStageBinding.objects.filter(
target__pk=self.flow.pk
).order_by("order"):
for binding in FlowStageBinding.objects.filter(target__pk=self.flow.pk).order_by(
"order"
):
binding: FlowStageBinding
stage = binding.stage
marker = StageMarker()

View file

@ -26,9 +26,7 @@ def invalidate_flow_cache(sender, instance, **_):
LOGGER.debug("Invalidating Flow cache", flow=instance, len=total)
if isinstance(instance, FlowStageBinding):
total = delete_cache_prefix(f"{cache_key(instance.target)}*")
LOGGER.debug(
"Invalidating Flow cache from FlowStageBinding", binding=instance, len=total
)
LOGGER.debug("Invalidating Flow cache from FlowStageBinding", binding=instance, len=total)
if isinstance(instance, Stage):
total = 0
for binding in FlowStageBinding.objects.filter(stage=instance):

View file

@ -42,14 +42,9 @@ class StageView(View):
other things besides the form display.
If no user is pending, returns request.user"""
if (
PLAN_CONTEXT_PENDING_USER_IDENTIFIER in self.executor.plan.context
and for_display
):
if PLAN_CONTEXT_PENDING_USER_IDENTIFIER in self.executor.plan.context and for_display:
return User(
username=self.executor.plan.context.get(
PLAN_CONTEXT_PENDING_USER_IDENTIFIER
),
username=self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER_IDENTIFIER),
email="",
)
if PLAN_CONTEXT_PENDING_USER in self.executor.plan.context:

View file

@ -89,14 +89,10 @@ class TestFlowPlanner(TestCase):
planner = FlowPlanner(flow)
planner.plan(request)
self.assertEqual(
CACHE_MOCK.set.call_count, 1
) # Ensure plan is written to cache
self.assertEqual(CACHE_MOCK.set.call_count, 1) # Ensure plan is written to cache
planner = FlowPlanner(flow)
planner.plan(request)
self.assertEqual(
CACHE_MOCK.set.call_count, 1
) # Ensure nothing is written to cache
self.assertEqual(CACHE_MOCK.set.call_count, 1) # Ensure nothing is written to cache
self.assertEqual(CACHE_MOCK.get.call_count, 2) # Get is called twice
def test_planner_default_context(self):
@ -176,9 +172,7 @@ class TestFlowPlanner(TestCase):
request.session.save()
# Here we patch the dummy policy to evaluate to true so the stage is included
with patch(
"authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE
):
with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE):
planner = FlowPlanner(flow)
plan = planner.plan(request)

View file

@ -76,9 +76,7 @@ class TestFlowTransfer(TransactionTestCase):
PolicyBinding.objects.create(policy=flow_policy, target=flow, order=0)
user_login = UserLoginStage.objects.create(name=stage_name)
fsb = FlowStageBinding.objects.create(
target=flow, stage=user_login, order=0
)
fsb = FlowStageBinding.objects.create(target=flow, stage=user_login, order=0)
PolicyBinding.objects.create(policy=flow_policy, target=fsb, order=0)
exporter = FlowExporter(flow)

View file

@ -11,12 +11,7 @@ from authentik.core.models import User
from authentik.flows.challenge import ChallengeTypes
from authentik.flows.exceptions import FlowNonApplicableException
from authentik.flows.markers import ReevaluateMarker, StageMarker
from authentik.flows.models import (
Flow,
FlowDesignation,
FlowStageBinding,
InvalidResponseAction,
)
from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding, InvalidResponseAction
from authentik.flows.planner import FlowPlan, FlowPlanner
from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView
from authentik.flows.views import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView
@ -61,9 +56,7 @@ class TestFlowExecutor(TestCase):
)
stage = DummyStage.objects.create(name="dummy")
binding = FlowStageBinding(target=flow, stage=stage, order=0)
plan = FlowPlan(
flow_pk=flow.pk.hex + "a", bindings=[binding], markers=[StageMarker()]
)
plan = FlowPlan(flow_pk=flow.pk.hex + "a", bindings=[binding], markers=[StageMarker()])
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session.save()
@ -163,9 +156,7 @@ class TestFlowExecutor(TestCase):
target=flow, stage=DummyStage.objects.create(name="dummy2"), order=1
)
exec_url = reverse(
"authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}
)
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
# First Request, start planning, renders form
response = self.client.get(exec_url)
self.assertEqual(response.status_code, 200)
@ -209,13 +200,9 @@ class TestFlowExecutor(TestCase):
PolicyBinding.objects.create(policy=false_policy, target=binding2, order=0)
# Here we patch the dummy policy to evaluate to true so the stage is included
with patch(
"authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE
):
with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE):
exec_url = reverse(
"authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}
)
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
# First request, run the planner
response = self.client.get(exec_url)
self.assertEqual(response.status_code, 200)
@ -263,13 +250,9 @@ class TestFlowExecutor(TestCase):
PolicyBinding.objects.create(policy=false_policy, target=binding2, order=0)
# Here we patch the dummy policy to evaluate to true so the stage is included
with patch(
"authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE
):
with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE):
exec_url = reverse(
"authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}
)
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
# First request, run the planner
response = self.client.get(exec_url)
@ -334,13 +317,9 @@ class TestFlowExecutor(TestCase):
PolicyBinding.objects.create(policy=true_policy, target=binding2, order=0)
# Here we patch the dummy policy to evaluate to true so the stage is included
with patch(
"authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE
):
with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE):
exec_url = reverse(
"authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}
)
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
# First request, run the planner
response = self.client.get(exec_url)
@ -422,13 +401,9 @@ class TestFlowExecutor(TestCase):
PolicyBinding.objects.create(policy=false_policy, target=binding3, order=0)
# Here we patch the dummy policy to evaluate to true so the stage is included
with patch(
"authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE
):
with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE):
exec_url = reverse(
"authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}
)
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
# First request, run the planner
response = self.client.get(exec_url)
self.assertEqual(response.status_code, 200)
@ -511,9 +486,7 @@ class TestFlowExecutor(TestCase):
)
request.user = user
planner = FlowPlanner(flow)
plan = planner.plan(
request, default_context={PLAN_CONTEXT_PENDING_USER_IDENTIFIER: ident}
)
plan = planner.plan(request, default_context={PLAN_CONTEXT_PENDING_USER_IDENTIFIER: ident})
executor = FlowExecutorView()
executor.plan = plan
@ -542,9 +515,7 @@ class TestFlowExecutor(TestCase):
evaluate_on_plan=False,
re_evaluate_policies=True,
)
PolicyBinding.objects.create(
policy=reputation_policy, target=deny_binding, order=0
)
PolicyBinding.objects.create(policy=reputation_policy, target=deny_binding, order=0)
# Stage 1 is an identification stage
ident_stage = IdentificationStage.objects.create(
@ -557,9 +528,7 @@ class TestFlowExecutor(TestCase):
order=1,
invalid_response_action=InvalidResponseAction.RESTART_WITH_CONTEXT,
)
exec_url = reverse(
"authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}
)
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
# First request, run the planner
response = self.client.get(exec_url)
self.assertEqual(response.status_code, 200)
@ -579,9 +548,7 @@ class TestFlowExecutor(TestCase):
"user_fields": [UserFields.E_MAIL],
},
)
response = self.client.post(
exec_url, {"uid_field": "invalid-string"}, follow=True
)
response = self.client.post(exec_url, {"uid_field": "invalid-string"}, follow=True)
self.assertEqual(response.status_code, 200)
self.assertJSONEqual(
force_str(response.content),

View file

@ -21,9 +21,7 @@ class TestHelperView(TestCase):
response = self.client.get(
reverse("authentik_flows:default-invalidation"),
)
expected_url = reverse(
"authentik_core:if-flow", kwargs={"flow_slug": flow.slug}
)
expected_url = reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, expected_url)
@ -40,8 +38,6 @@ class TestHelperView(TestCase):
response = self.client.get(
reverse("authentik_flows:default-invalidation"),
)
expected_url = reverse(
"authentik_core:if-flow", kwargs={"flow_slug": flow.slug}
)
expected_url = reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, expected_url)

View file

@ -44,9 +44,7 @@ class FlowBundleEntry:
attrs: dict[str, Any]
@staticmethod
def from_model(
model: SerializerModel, *extra_identifier_names: str
) -> "FlowBundleEntry":
def from_model(model: SerializerModel, *extra_identifier_names: str) -> "FlowBundleEntry":
"""Convert a SerializerModel instance to a Bundle Entry"""
identifiers = {
"pk": model.pk,

View file

@ -6,11 +6,7 @@ from uuid import UUID
from django.db.models import Q
from authentik.flows.models import Flow, FlowStageBinding, Stage
from authentik.flows.transfer.common import (
DataclassEncoder,
FlowBundle,
FlowBundleEntry,
)
from authentik.flows.transfer.common import DataclassEncoder, FlowBundle, FlowBundleEntry
from authentik.policies.models import Policy, PolicyBinding
from authentik.stages.prompt.models import PromptStage
@ -37,9 +33,7 @@ class FlowExporter:
def walk_stages(self) -> Iterator[FlowBundleEntry]:
"""Convert all stages attached to self.flow into FlowBundleEntry objects"""
stages = (
Stage.objects.filter(flow=self.flow).select_related().select_subclasses()
)
stages = Stage.objects.filter(flow=self.flow).select_related().select_subclasses()
for stage in stages:
if isinstance(stage, PromptStage):
pass
@ -56,9 +50,7 @@ class FlowExporter:
a direct foreign key to a policy."""
# Special case for PromptStage as that has a direct M2M to policy, we have to ensure
# all policies referenced in there we also include here
prompt_stages = PromptStage.objects.filter(flow=self.flow).values_list(
"pk", flat=True
)
prompt_stages = PromptStage.objects.filter(flow=self.flow).values_list("pk", flat=True)
query = Q(bindings__in=self.pbm_uuids) | Q(promptstage__in=prompt_stages)
policies = Policy.objects.filter(query).select_related()
for policy in policies:
@ -67,9 +59,7 @@ class FlowExporter:
def walk_policy_bindings(self) -> Iterator[FlowBundleEntry]:
"""Walk over all policybindings relative to us. This is run at the end of the export, as
we are sure all objects exist now."""
bindings = PolicyBinding.objects.filter(
target__in=self.pbm_uuids
).select_related()
bindings = PolicyBinding.objects.filter(target__in=self.pbm_uuids).select_related()
for binding in bindings:
yield FlowBundleEntry.from_model(binding, "policy", "target", "order")

View file

@ -16,11 +16,7 @@ from rest_framework.serializers import BaseSerializer, Serializer
from structlog.stdlib import BoundLogger, get_logger
from authentik.flows.models import Flow, FlowStageBinding, Stage
from authentik.flows.transfer.common import (
EntryInvalidError,
FlowBundle,
FlowBundleEntry,
)
from authentik.flows.transfer.common import EntryInvalidError, FlowBundle, FlowBundleEntry
from authentik.lib.models import SerializerModel
from authentik.policies.models import Policy, PolicyBinding
from authentik.stages.prompt.models import Prompt
@ -105,9 +101,7 @@ class FlowImporter:
if isinstance(value, dict) and "pk" in value:
del updated_identifiers[key]
updated_identifiers[f"{key}"] = value["pk"]
existing_models = model.objects.filter(
self.__query_from_identifier(updated_identifiers)
)
existing_models = model.objects.filter(self.__query_from_identifier(updated_identifiers))
serializer_kwargs = {}
if existing_models.exists():
@ -120,9 +114,7 @@ class FlowImporter:
)
serializer_kwargs["instance"] = model_instance
else:
self.logger.debug(
"initialise new instance", model=model, **updated_identifiers
)
self.logger.debug("initialise new instance", model=model, **updated_identifiers)
full_data = self.__update_pks_for_attrs(entry.attrs)
full_data.update(updated_identifiers)
serializer_kwargs["data"] = full_data

View file

@ -38,13 +38,7 @@ from authentik.flows.challenge import (
WithUserInfoChallenge,
)
from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException
from authentik.flows.models import (
ConfigurableStage,
Flow,
FlowDesignation,
FlowStageBinding,
Stage,
)
from authentik.flows.models import ConfigurableStage, Flow, FlowDesignation, FlowStageBinding, Stage
from authentik.flows.planner import (
PLAN_CONTEXT_PENDING_USER,
PLAN_CONTEXT_REDIRECT,
@ -155,9 +149,7 @@ class FlowExecutorView(APIView):
try:
self.plan = self._initiate_plan()
except FlowNonApplicableException as exc:
self._logger.warning(
"f(exec): Flow not applicable to current user", exc=exc
)
self._logger.warning("f(exec): Flow not applicable to current user", exc=exc)
return to_stage_response(self.request, self.handle_invalid_flow(exc))
except EmptyFlowException as exc:
self._logger.warning("f(exec): Flow is empty", exc=exc)
@ -174,9 +166,7 @@ class FlowExecutorView(APIView):
# in which case we just delete the plan and invalidate everything
next_binding = self.plan.next(self.request)
except Exception as exc: # pylint: disable=broad-except
self._logger.warning(
"f(exec): found incompatible flow plan, invalidating run", exc=exc
)
self._logger.warning("f(exec): found incompatible flow plan, invalidating run", exc=exc)
keys = cache.keys("flow_*")
cache.delete_many(keys)
return self.stage_invalid()
@ -314,9 +304,7 @@ class FlowExecutorView(APIView):
self.request.session[SESSION_KEY_PLAN] = plan
kwargs = self.kwargs
kwargs.update({"flow_slug": self.flow.slug})
return redirect_with_qs(
"authentik_api:flow-executor", self.request.GET, **kwargs
)
return redirect_with_qs("authentik_api:flow-executor", self.request.GET, **kwargs)
def _flow_done(self) -> HttpResponse:
"""User Successfully passed all stages"""
@ -350,9 +338,7 @@ class FlowExecutorView(APIView):
)
kwargs = self.kwargs
kwargs.update({"flow_slug": self.flow.slug})
return redirect_with_qs(
"authentik_api:flow-executor", self.request.GET, **kwargs
)
return redirect_with_qs("authentik_api:flow-executor", self.request.GET, **kwargs)
# User passed all stages
self._logger.debug(
"f(exec): User passed all stages",
@ -408,18 +394,13 @@ class FlowErrorResponse(TemplateResponse):
super().__init__(request=request, template="flows/error.html")
self.error = error
def resolve_context(
self, context: Optional[dict[str, Any]]
) -> Optional[dict[str, Any]]:
def resolve_context(self, context: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
if not context:
context = {}
context["error"] = self.error
if self._request.user and self._request.user.is_authenticated:
if (
self._request.user.is_superuser
or self._request.user.group_attributes().get(
USER_ATTRIBUTE_DEBUG, False
)
if self._request.user.is_superuser or self._request.user.group_attributes().get(
USER_ATTRIBUTE_DEBUG, False
):
context["tb"] = "".join(format_tb(self.error.__traceback__))
return context
@ -464,9 +445,7 @@ class ToDefaultFlow(View):
flow_slug=flow.slug,
)
del self.request.session[SESSION_KEY_PLAN]
return redirect_with_qs(
"authentik_core:if-flow", request.GET, flow_slug=flow.slug
)
return redirect_with_qs("authentik_core:if-flow", request.GET, flow_slug=flow.slug)
def to_stage_response(request: HttpRequest, source: HttpResponse) -> HttpResponse:

View file

@ -115,9 +115,7 @@ class ConfigLoader:
for key, value in os.environ.items():
if not key.startswith(ENV_PREFIX):
continue
relative_key = (
key.replace(f"{ENV_PREFIX}_", "", 1).replace("__", ".").lower()
)
relative_key = key.replace(f"{ENV_PREFIX}_", "", 1).replace("__", ".").lower()
# Recursively convert path from a.b.c into outer[a][b][c]
current_obj = outer
dot_parts = relative_key.split(".")

View file

@ -37,15 +37,11 @@ class InheritanceAutoManager(InheritanceManager):
return super().get_queryset().select_subclasses()
class InheritanceForwardManyToOneDescriptor(
models.fields.related.ForwardManyToOneDescriptor
):
class InheritanceForwardManyToOneDescriptor(models.fields.related.ForwardManyToOneDescriptor):
"""Forward ManyToOne Descriptor that selects subclass. Requires InheritanceAutoManager."""
def get_queryset(self, **hints):
return self.field.remote_field.model.objects.db_manager(
hints=hints
).select_subclasses()
return self.field.remote_field.model.objects.db_manager(hints=hints).select_subclasses()
class InheritanceForeignKey(models.ForeignKey):

View file

@ -8,11 +8,7 @@ from botocore.exceptions import BotoCoreError
from celery.exceptions import CeleryError
from channels.middleware import BaseMiddleware
from channels_redis.core import ChannelFull
from django.core.exceptions import (
ImproperlyConfigured,
SuspiciousOperation,
ValidationError,
)
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError
from django.db import InternalError, OperationalError, ProgrammingError
from django.http.response import Http404
from django_redis.exceptions import ConnectionInterrupted

View file

@ -26,7 +26,5 @@ class TestEvaluator(TestCase):
def test_is_group_member(self):
"""Test expr_is_group_member"""
self.assertFalse(
BaseEvaluator.expr_is_group_member(
User.objects.get(username="akadmin"), name="test"
)
BaseEvaluator.expr_is_group_member(User.objects.get(username="akadmin"), name="test")
)

View file

@ -1,17 +1,8 @@
"""Test HTTP Helpers"""
from django.test import RequestFactory, TestCase
from authentik.core.models import (
USER_ATTRIBUTE_CAN_OVERRIDE_IP,
Token,
TokenIntents,
User,
)
from authentik.lib.utils.http import (
OUTPOST_REMOTE_IP_HEADER,
OUTPOST_TOKEN_HEADER,
get_client_ip,
)
from authentik.core.models import USER_ATTRIBUTE_CAN_OVERRIDE_IP, Token, TokenIntents, User
from authentik.lib.utils.http import OUTPOST_REMOTE_IP_HEADER, OUTPOST_TOKEN_HEADER, get_client_ip
class TestHTTP(TestCase):

View file

@ -9,9 +9,7 @@ class TestSentry(TestCase):
def test_error_not_sent(self):
"""Test SentryIgnoredError not sent"""
self.assertIsNone(
before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)})
)
self.assertIsNone(before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)}))
def test_error_sent(self):
"""Test error sent"""

View file

@ -29,16 +29,9 @@ def _get_outpost_override_ip(request: HttpRequest) -> Optional[str]:
"""Get the actual remote IP when set by an outpost. Only
allowed when the request is authenticated, by a user with USER_ATTRIBUTE_CAN_OVERRIDE_IP set
to outpost"""
from authentik.core.models import (
USER_ATTRIBUTE_CAN_OVERRIDE_IP,
Token,
TokenIntents,
)
from authentik.core.models import USER_ATTRIBUTE_CAN_OVERRIDE_IP, Token, TokenIntents
if (
OUTPOST_REMOTE_IP_HEADER not in request.META
or OUTPOST_TOKEN_HEADER not in request.META
):
if OUTPOST_REMOTE_IP_HEADER not in request.META or OUTPOST_TOKEN_HEADER not in request.META:
return None
fake_ip = request.META[OUTPOST_REMOTE_IP_HEADER]
tokens = Token.filter_not_expired(

View file

@ -12,9 +12,7 @@ def managed_reconcile(self: MonitoredTask):
try:
ObjectManager().run()
self.set_status(
TaskResult(
TaskResultStatus.SUCCESSFUL, ["Successfully updated managed models."]
)
TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated managed models."])
)
except DatabaseError as exc:
self.set_status(TaskResult(TaskResultStatus.WARNING, [str(exc)]))

View file

@ -15,12 +15,7 @@ from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer, is_dict
from authentik.core.models import Provider
from authentik.outposts.api.service_connections import ServiceConnectionSerializer
from authentik.outposts.models import (
Outpost,
OutpostConfig,
OutpostType,
default_outpost_config,
)
from authentik.outposts.models import Outpost, OutpostConfig, OutpostType, default_outpost_config
from authentik.providers.ldap.models import LDAPProvider
from authentik.providers.proxy.models import ProxyProvider

View file

@ -15,11 +15,7 @@ from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import GenericViewSet, ModelViewSet
from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import (
MetaNameSerializer,
PassiveSerializer,
TypeCreateSerializer,
)
from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer
from authentik.lib.utils.reflection import all_subclasses
from authentik.outposts.models import (
DockerServiceConnection,
@ -129,9 +125,7 @@ class KubernetesServiceConnectionSerializer(ServiceConnectionSerializer):
if kubeconfig == {}:
if not self.initial_data["local"]:
raise serializers.ValidationError(
_(
"You can only use an empty kubeconfig when connecting to a local cluster."
)
_("You can only use an empty kubeconfig when connecting to a local cluster.")
)
# Empty kubeconfig is valid
return kubeconfig

View file

@ -59,9 +59,7 @@ class OutpostConsumer(AuthJsonConsumer):
def connect(self):
super().connect()
uuid = self.scope["url_route"]["kwargs"]["pk"]
outpost = get_objects_for_user(
self.user, "authentik_outposts.view_outpost"
).filter(pk=uuid)
outpost = get_objects_for_user(self.user, "authentik_outposts.view_outpost").filter(pk=uuid)
if not outpost.exists():
raise DenyConnection()
self.accept()
@ -129,7 +127,5 @@ class OutpostConsumer(AuthJsonConsumer):
def event_update(self, event):
"""Event handler which is called by post_save signals, Send update instruction"""
self.send_json(
asdict(
WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE)
)
asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE))
)

View file

@ -9,11 +9,7 @@ from yaml import safe_dump
from authentik import __version__
from authentik.outposts.controllers.base import BaseController, ControllerException
from authentik.outposts.models import (
DockerServiceConnection,
Outpost,
ServiceConnectionInvalid,
)
from authentik.outposts.models import DockerServiceConnection, Outpost, ServiceConnectionInvalid
class DockerController(BaseController):
@ -37,9 +33,7 @@ class DockerController(BaseController):
def _get_env(self) -> dict[str, str]:
return {
"AUTHENTIK_HOST": self.outpost.config.authentik_host.lower(),
"AUTHENTIK_INSECURE": str(
self.outpost.config.authentik_host_insecure
).lower(),
"AUTHENTIK_INSECURE": str(self.outpost.config.authentik_host_insecure).lower(),
"AUTHENTIK_TOKEN": self.outpost.token.key,
}
@ -141,9 +135,7 @@ class DockerController(BaseController):
.lower()
!= "unless-stopped"
):
self.logger.info(
"Container has mis-matched restart policy, re-creating..."
)
self.logger.info("Container has mis-matched restart policy, re-creating...")
self.down()
return self.up()
# Check that container is healthy
@ -157,9 +149,7 @@ class DockerController(BaseController):
if has_been_created:
# Since we've just created the container, give it some time to start.
# If its still not up by then, restart it
self.logger.info(
"Container is unhealthy and new, giving it time to boot."
)
self.logger.info("Container is unhealthy and new, giving it time to boot.")
sleep(60)
self.logger.info("Container is unhealthy, restarting...")
container.restart()
@ -198,9 +188,7 @@ class DockerController(BaseController):
"ports": ports,
"environment": {
"AUTHENTIK_HOST": self.outpost.config.authentik_host,
"AUTHENTIK_INSECURE": str(
self.outpost.config.authentik_host_insecure
),
"AUTHENTIK_INSECURE": str(self.outpost.config.authentik_host_insecure),
"AUTHENTIK_TOKEN": self.outpost.token.key,
},
"labels": self._get_labels(),

View file

@ -17,10 +17,7 @@ from kubernetes.client import (
)
from authentik.outposts.controllers.base import FIELD_MANAGER
from authentik.outposts.controllers.k8s.base import (
KubernetesObjectReconciler,
NeedsUpdate,
)
from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate
from authentik.outposts.models import Outpost
if TYPE_CHECKING:
@ -124,9 +121,7 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]):
)
def delete(self, reference: V1Deployment):
return self.api.delete_namespaced_deployment(
reference.metadata.name, self.namespace
)
return self.api.delete_namespaced_deployment(reference.metadata.name, self.namespace)
def retrieve(self) -> V1Deployment:
return self.api.read_namespaced_deployment(self.name, self.namespace)

View file

@ -5,10 +5,7 @@ from typing import TYPE_CHECKING
from kubernetes.client import CoreV1Api, V1Secret
from authentik.outposts.controllers.base import FIELD_MANAGER
from authentik.outposts.controllers.k8s.base import (
KubernetesObjectReconciler,
NeedsUpdate,
)
from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate
if TYPE_CHECKING:
from authentik.outposts.controllers.kubernetes import KubernetesController
@ -38,9 +35,7 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]):
return V1Secret(
metadata=meta,
data={
"authentik_host": b64string(
self.controller.outpost.config.authentik_host
),
"authentik_host": b64string(self.controller.outpost.config.authentik_host),
"authentik_host_insecure": b64string(
str(self.controller.outpost.config.authentik_host_insecure)
),
@ -54,9 +49,7 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]):
)
def delete(self, reference: V1Secret):
return self.api.delete_namespaced_secret(
reference.metadata.name, self.namespace
)
return self.api.delete_namespaced_secret(reference.metadata.name, self.namespace)
def retrieve(self) -> V1Secret:
return self.api.read_namespaced_secret(self.name, self.namespace)

View file

@ -4,10 +4,7 @@ from typing import TYPE_CHECKING
from kubernetes.client import CoreV1Api, V1Service, V1ServicePort, V1ServiceSpec
from authentik.outposts.controllers.base import FIELD_MANAGER
from authentik.outposts.controllers.k8s.base import (
KubernetesObjectReconciler,
NeedsUpdate,
)
from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate
from authentik.outposts.controllers.k8s.deployment import DeploymentReconciler
if TYPE_CHECKING:
@ -58,9 +55,7 @@ class ServiceReconciler(KubernetesObjectReconciler[V1Service]):
)
def delete(self, reference: V1Service):
return self.api.delete_namespaced_service(
reference.metadata.name, self.namespace
)
return self.api.delete_namespaced_service(reference.metadata.name, self.namespace)
def retrieve(self) -> V1Service:
return self.api.read_namespaced_service(self.name, self.namespace)

View file

@ -24,9 +24,7 @@ class KubernetesController(BaseController):
client: ApiClient
connection: KubernetesServiceConnection
def __init__(
self, outpost: Outpost, connection: KubernetesServiceConnection
) -> None:
def __init__(self, outpost: Outpost, connection: KubernetesServiceConnection) -> None:
super().__init__(outpost, connection)
self.client = connection.client()
self.reconcilers = {

View file

@ -15,9 +15,7 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name="outpost",
name="_config",
field=models.JSONField(
default=authentik.outposts.models.default_outpost_config
),
field=models.JSONField(default=authentik.outposts.models.default_outpost_config),
),
migrations.AddField(
model_name="outpost",

View file

@ -10,9 +10,7 @@ def fix_missing_token_identifier(apps: Apps, schema_editor: BaseDatabaseSchemaEd
Token = apps.get_model("authentik_core", "Token")
from authentik.outposts.models import Outpost
for outpost in (
Outpost.objects.using(schema_editor.connection.alias).all().only("pk")
):
for outpost in Outpost.objects.using(schema_editor.connection.alias).all().only("pk"):
user_identifier = outpost.user_identifier
users = User.objects.filter(username=user_identifier)
if not users.exists():

View file

@ -14,9 +14,7 @@ import authentik.lib.models
def migrate_to_service_connection(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
db_alias = schema_editor.connection.alias
Outpost = apps.get_model("authentik_outposts", "Outpost")
DockerServiceConnection = apps.get_model(
"authentik_outposts", "DockerServiceConnection"
)
DockerServiceConnection = apps.get_model("authentik_outposts", "DockerServiceConnection")
KubernetesServiceConnection = apps.get_model(
"authentik_outposts", "KubernetesServiceConnection"
)
@ -25,9 +23,7 @@ def migrate_to_service_connection(apps: Apps, schema_editor: BaseDatabaseSchemaE
k8s = KubernetesServiceConnection.objects.filter(local=True).first()
try:
for outpost in (
Outpost.objects.using(db_alias).all().exclude(deployment_type="custom")
):
for outpost in Outpost.objects.using(db_alias).all().exclude(deployment_type="custom"):
if outpost.deployment_type == "kubernetes":
outpost.service_connection = k8s
elif outpost.deployment_type == "docker":

View file

@ -11,9 +11,7 @@ def remove_pb_prefix_users(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
Outpost = apps.get_model("authentik_outposts", "Outpost")
for outpost in Outpost.objects.using(alias).all():
matching = User.objects.using(alias).filter(
username=f"pb-outpost-{outpost.uuid.hex}"
)
matching = User.objects.using(alias).filter(username=f"pb-outpost-{outpost.uuid.hex}")
if matching.exists():
matching.delete()

View file

@ -13,8 +13,6 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="outpost",
name="type",
field=models.TextField(
choices=[("proxy", "Proxy"), ("ldap", "Ldap")], default="proxy"
),
field=models.TextField(choices=[("proxy", "Proxy"), ("ldap", "Ldap")], default="proxy"),
),
]

View file

@ -64,9 +64,7 @@ class OutpostConfig:
log_level: str = CONFIG.y("log_level")
error_reporting_enabled: bool = CONFIG.y_bool("error_reporting.enabled")
error_reporting_environment: str = CONFIG.y(
"error_reporting.environment", "customer"
)
error_reporting_environment: str = CONFIG.y("error_reporting.environment", "customer")
object_naming_template: str = field(default="ak-outpost-%(name)s")
kubernetes_replicas: int = field(default=1)
@ -264,9 +262,7 @@ class KubernetesServiceConnection(OutpostServiceConnection):
client = self.client()
api_instance = VersionApi(client)
version: VersionInfo = api_instance.get_code()
return OutpostServiceConnectionState(
version=version.git_version, healthy=True
)
return OutpostServiceConnectionState(version=version.git_version, healthy=True)
except (OpenApiException, HTTPError, ServiceConnectionInvalid):
return OutpostServiceConnectionState(version="", healthy=False)
@ -360,8 +356,7 @@ class Outpost(ManagedModel):
if isinstance(model_or_perm, models.Model):
model_or_perm: models.Model
code_name = (
f"{model_or_perm._meta.app_label}."
f"view_{model_or_perm._meta.model_name}"
f"{model_or_perm._meta.app_label}." f"view_{model_or_perm._meta.model_name}"
)
assign_perm(code_name, user, model_or_perm)
else:
@ -417,9 +412,7 @@ class Outpost(ManagedModel):
self,
"authentik_events.add_event",
]
for provider in (
Provider.objects.filter(outpost=self).select_related().select_subclasses()
):
for provider in Provider.objects.filter(outpost=self).select_related().select_subclasses():
if isinstance(provider, OutpostModel):
objects.extend(provider.get_required_objects())
else:

View file

@ -9,11 +9,7 @@ from authentik.core.models import Provider
from authentik.crypto.models import CertificateKeyPair
from authentik.lib.utils.reflection import class_to_path
from authentik.outposts.models import Outpost, OutpostServiceConnection
from authentik.outposts.tasks import (
CACHE_KEY_OUTPOST_DOWN,
outpost_controller,
outpost_post_save,
)
from authentik.outposts.tasks import CACHE_KEY_OUTPOST_DOWN, outpost_controller, outpost_post_save
LOGGER = get_logger()
UPDATE_TRIGGERING_MODELS = (
@ -37,9 +33,7 @@ def pre_save_outpost(sender, instance: Outpost, **_):
# Name changes the deployment name, need to recreate
dirty += old_instance.name != instance.name
# namespace requires re-create
dirty += (
old_instance.config.kubernetes_namespace != instance.config.kubernetes_namespace
)
dirty += old_instance.config.kubernetes_namespace != instance.config.kubernetes_namespace
if bool(dirty):
LOGGER.info("Outpost needs re-deployment due to changes", instance=instance)
cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance)

View file

@ -62,9 +62,7 @@ def controller_for_outpost(outpost: Outpost) -> Optional[BaseController]:
def outpost_service_connection_state(connection_pk: Any):
"""Update cached state of a service connection"""
connection: OutpostServiceConnection = (
OutpostServiceConnection.objects.filter(pk=connection_pk)
.select_subclasses()
.first()
OutpostServiceConnection.objects.filter(pk=connection_pk).select_subclasses().first()
)
if not connection:
return
@ -157,9 +155,7 @@ def outpost_post_save(model_class: str, model_pk: Any):
outpost_controller.delay(instance.pk)
if isinstance(instance, (OutpostModel, Outpost)):
LOGGER.debug(
"triggering outpost update from outpostmodel/outpost", instance=instance
)
LOGGER.debug("triggering outpost update from outpostmodel/outpost", instance=instance)
outpost_send_update(instance)
if isinstance(instance, OutpostServiceConnection):
@ -208,9 +204,7 @@ def _outpost_single_update(outpost: Outpost, layer=None):
layer = get_channel_layer()
for state in OutpostState.for_outpost(outpost):
for channel in state.channel_ids:
LOGGER.debug(
"sending update", channel=channel, instance=state.uid, outpost=outpost
)
LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost)
async_to_sync(layer.send)(channel, {"type": "event.update"})
@ -231,9 +225,7 @@ def outpost_local_connection():
if Path(kubeconfig_path).exists():
LOGGER.debug("Detected kubeconfig")
kubeconfig_local_name = f"k8s-{gethostname()}"
if not KubernetesServiceConnection.objects.filter(
name=kubeconfig_local_name
).exists():
if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists():
LOGGER.debug("Creating kubeconfig Service Connection")
with open(kubeconfig_path, "r") as _kubeconfig:
KubernetesServiceConnection.objects.create(

View file

@ -63,9 +63,7 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
provider = ProxyProvider.objects.create(
name="test", authorization_flow=Flow.objects.first()
)
invalid = OutpostSerializer(
data={"name": "foo", "providers": [provider.pk], "config": {}}
)
invalid = OutpostSerializer(data={"name": "foo", "providers": [provider.pk], "config": {}})
self.assertFalse(invalid.is_valid())
self.assertIn("config", invalid.errors)
valid = OutpostSerializer(

View file

@ -2,11 +2,7 @@
from typing import OrderedDict
from django.core.exceptions import ObjectDoesNotExist
from rest_framework.serializers import (
ModelSerializer,
PrimaryKeyRelatedField,
ValidationError,
)
from rest_framework.serializers import ModelSerializer, PrimaryKeyRelatedField, ValidationError
from rest_framework.viewsets import ModelViewSet
from structlog.stdlib import get_logger

Some files were not shown because too many files have changed in this diff Show more