Merge branch 'master' into stage-challenge

# Conflicts:
#	authentik/stages/authenticator_validate/stage.py
#	authentik/stages/identification/stage.py
This commit is contained in:
Jens Langhammer 2021-02-18 14:04:35 +01:00
commit b229b2f40d
73 changed files with 216 additions and 215 deletions

View file

@ -2,7 +2,6 @@
import time
from collections import Counter
from datetime import timedelta
from typing import Dict, List
from django.db.models import Count, ExpressionWrapper, F, Model
from django.db.models.fields import DurationField
@ -19,7 +18,7 @@ from rest_framework.viewsets import ViewSet
from authentik.events.models import Event, EventAction
def get_events_per_1h(**filter_kwargs) -> List[Dict[str, int]]:
def get_events_per_1h(**filter_kwargs) -> list[dict[str, int]]:
"""Get event count by hour in the last day, fill with zeros"""
date_from = now() - timedelta(days=1)
result = (

View file

@ -1,6 +1,6 @@
"""authentik Outpost administration"""
from dataclasses import asdict
from typing import Any, Dict
from typing import Any
from django.contrib.auth.mixins import LoginRequiredMixin
from django.contrib.auth.mixins import (
@ -33,7 +33,7 @@ class OutpostCreateView(
template_name = "generic/create.html"
success_message = _("Successfully created Outpost")
def get_initial(self) -> Dict[str, Any]:
def get_initial(self) -> dict[str, Any]:
return {
"_config": asdict(
OutpostConfig(authentik_host=self.request.build_absolute_uri("/"))

View file

@ -1,5 +1,5 @@
"""authentik Policy administration"""
from typing import Any, Dict
from typing import Any
from django.contrib.auth.mixins import LoginRequiredMixin
from django.contrib.auth.mixins import (
@ -102,7 +102,7 @@ class PolicyTestView(LoginRequiredMixin, DetailView, PermissionRequiredMixin, Fo
Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first()
)
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
kwargs["policy"] = self.get_object()
return super().get_context_data(**kwargs)

View file

@ -1,5 +1,5 @@
"""authentik Tasks List"""
from typing import Any, Dict
from typing import Any
from django.views.generic.base import TemplateView
@ -12,7 +12,7 @@ class TaskListView(AdminRequiredMixin, TemplateView):
template_name = "administration/task/list.html"
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
kwargs = super().get_context_data(**kwargs)
kwargs["object_list"] = sorted(
TaskInfo.all().values(), key=lambda x: x.task_name

View file

@ -1,5 +1,5 @@
"""authentik admin util views"""
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from urllib.parse import urlparse
from django.contrib import messages
@ -40,7 +40,7 @@ class SearchListMixin(MultipleObjectMixin):
"""Accept search query using `search` querystring parameter. Requires self.search_fields,
a list of all fields to search. Can contain special lookups like __icontains"""
search_fields: List[str]
search_fields: list[str]
def get_queryset(self) -> QuerySet:
queryset = super().get_queryset()
@ -69,7 +69,7 @@ class InheritanceCreateView(CreateAssignPermView):
raise Http404 from exc
return model().form
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
kwargs = super().get_context_data(**kwargs)
form_cls = self.get_form_class()
if hasattr(form_cls, "template_name"):
@ -80,7 +80,7 @@ class InheritanceCreateView(CreateAssignPermView):
class InheritanceUpdateView(UpdateView):
"""UpdateView for objects using InheritanceManager"""
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
kwargs = super().get_context_data(**kwargs)
form_cls = self.get_form_class()
if hasattr(form_cls, "template_name"):

View file

@ -1,7 +1,7 @@
"""API Authentication"""
from base64 import b64decode
from binascii import Error
from typing import Any, Optional, Tuple, Union
from typing import Any, Optional, Union
from rest_framework.authentication import BaseAuthentication, get_authorization_header
from rest_framework.request import Request
@ -44,7 +44,7 @@ def token_from_header(raw_header: bytes) -> Optional[Token]:
class AuthentikTokenAuthentication(BaseAuthentication):
"""Token-based authentication using HTTP Basic authentication"""
def authenticate(self, request: Request) -> Union[Tuple[User, Any], None]:
def authenticate(self, request: Request) -> Union[tuple[User, Any], None]:
"""Token-based authentication using HTTP Basic authentication"""
auth = get_authorization_header(request)

View file

@ -1,7 +1,7 @@
"""authentik core models"""
from datetime import timedelta
from hashlib import sha256
from typing import Any, Dict, Optional, Type
from typing import Any, Optional, Type
from uuid import uuid4
from django.conf import settings
@ -96,7 +96,7 @@ class User(GuardianUserMixin, AbstractUser):
objects = UserManager()
def group_attributes(self) -> Dict[str, Any]:
def group_attributes(self) -> dict[str, Any]:
"""Get a dictionary containing the attributes from all groups the user belongs to,
including the users attributes"""
final_attributes = {}

View file

@ -1,5 +1,5 @@
"""authentik core user views"""
from typing import Any, Dict
from typing import Any
from django.contrib.auth.mixins import LoginRequiredMixin
from django.contrib.auth.mixins import (
@ -45,7 +45,7 @@ class UserDetailsView(SuccessMessageMixin, LoginRequiredMixin, UpdateView):
def get_object(self):
return self.request.user
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
kwargs = super().get_context_data(**kwargs)
unenrollment_flow = Flow.with_policy(
self.request, designation=FlowDesignation.UNRENOLLMENT

View file

@ -3,7 +3,7 @@ from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from traceback import format_tb
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from celery import Task
from django.core.cache import cache
@ -26,7 +26,7 @@ class TaskResult:
status: TaskResultStatus
messages: List[str] = field(default_factory=list)
messages: list[str] = field(default_factory=list)
# Optional UID used in cache for tasks that run in different instances
uid: Optional[str] = field(default=None)
@ -49,8 +49,8 @@ class TaskInfo:
task_call_module: str
task_call_func: str
task_call_args: List[Any] = field(default_factory=list)
task_call_kwargs: Dict[str, Any] = field(default_factory=dict)
task_call_args: list[Any] = field(default_factory=list)
task_call_kwargs: dict[str, Any] = field(default_factory=dict)
task_description: Optional[str] = field(default=None)
@ -60,7 +60,7 @@ class TaskInfo:
return self.task_name.split("_")
@staticmethod
def all() -> Dict[str, "TaskInfo"]:
def all() -> dict[str, "TaskInfo"]:
"""Get all TaskInfo objects"""
return cache.get_many(cache.keys("task_*"))
@ -109,7 +109,7 @@ class MonitoredTask(Task):
# pylint: disable=too-many-arguments
def after_return(
self, status, retval, task_id, args: List[Any], kwargs: Dict[str, Any], einfo
self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo
):
if not self._result.uid:
self._result.uid = self._uid

View file

@ -1,6 +1,6 @@
"""authentik events signal listener"""
from threading import Thread
from typing import Any, Dict, Optional
from typing import Any, Optional
from django.contrib.auth.signals import (
user_logged_in,
@ -27,7 +27,7 @@ class EventNewThread(Thread):
action: str
request: HttpRequest
kwargs: Dict[str, Any]
kwargs: dict[str, Any]
user: Optional[User] = None
def __init__(
@ -69,7 +69,7 @@ def on_user_logged_out(sender, request: HttpRequest, user: User, **_):
@receiver(user_write)
# pylint: disable=unused-argument
def on_user_write(
sender, request: HttpRequest, user: User, data: Dict[str, Any], **kwargs
sender, request: HttpRequest, user: User, data: dict[str, Any], **kwargs
):
"""Log User write"""
thread = EventNewThread(EventAction.USER_WRITE, request, **data)
@ -81,7 +81,7 @@ def on_user_write(
@receiver(user_login_failed)
# pylint: disable=unused-argument
def on_user_login_failed(
sender, credentials: Dict[str, str], request: HttpRequest, **_
sender, credentials: dict[str, str], request: HttpRequest, **_
):
"""Failed Login"""
thread = EventNewThread(EventAction.LOGIN_FAILED, request, **credentials)

View file

@ -1,7 +1,7 @@
"""event utilities"""
import re
from dataclasses import asdict, is_dataclass
from typing import Any, Dict, Optional
from typing import Any, Optional
from uuid import UUID
from django.contrib.auth.models import AnonymousUser
@ -20,7 +20,7 @@ from authentik.policies.types import PolicyRequest
ALLOWED_SPECIAL_KEYS = re.compile("passing", flags=re.I)
def cleanse_dict(source: Dict[Any, Any]) -> Dict[Any, Any]:
def cleanse_dict(source: dict[Any, Any]) -> dict[Any, Any]:
"""Cleanse a dictionary, recursively"""
final_dict = {}
for key, value in source.items():
@ -38,7 +38,7 @@ def cleanse_dict(source: Dict[Any, Any]) -> Dict[Any, Any]:
return final_dict
def model_to_dict(model: Model) -> Dict[str, Any]:
def model_to_dict(model: Model) -> dict[str, Any]:
"""Convert model to dict"""
name = str(model)
if hasattr(model, "name"):
@ -51,7 +51,7 @@ def model_to_dict(model: Model) -> Dict[str, Any]:
}
def get_user(user: User, original_user: Optional[User] = None) -> Dict[str, Any]:
def get_user(user: User, original_user: Optional[User] = None) -> dict[str, Any]:
"""Convert user object to dictionary, optionally including the original user"""
if isinstance(user, AnonymousUser):
user = get_anonymous_user()
@ -67,7 +67,7 @@ def get_user(user: User, original_user: Optional[User] = None) -> Dict[str, Any]
return user_data
def sanitize_dict(source: Dict[Any, Any]) -> Dict[Any, Any]:
def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]:
"""clean source of all Models that would interfere with the JSONField.
Models are replaced with a dictionary of {
app: str,

View file

@ -1,6 +1,6 @@
"""Flows Planner"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from django.core.cache import cache
from django.http import HttpRequest
@ -38,9 +38,9 @@ class FlowPlan:
flow_pk: str
stages: List[Stage] = field(default_factory=list)
context: Dict[str, Any] = field(default_factory=dict)
markers: List[StageMarker] = field(default_factory=list)
stages: list[Stage] = field(default_factory=list)
context: dict[str, Any] = field(default_factory=dict)
markers: list[StageMarker] = field(default_factory=list)
def append(self, stage: Stage, marker: Optional[StageMarker] = None):
"""Append `stage` to all stages, optionall with stage marker"""
@ -96,7 +96,7 @@ class FlowPlanner:
self._logger = get_logger().bind(flow=flow)
def plan(
self, request: HttpRequest, default_context: Optional[Dict[str, Any]] = None
self, request: HttpRequest, default_context: Optional[dict[str, Any]] = None
) -> FlowPlan:
"""Check each of the flows' policies, check policies for each stage with PolicyBinding
and return ordered list"""
@ -149,7 +149,7 @@ class FlowPlanner:
self,
user: User,
request: HttpRequest,
default_context: Optional[Dict[str, Any]],
default_context: Optional[dict[str, Any]],
) -> FlowPlan:
"""Build flow plan by checking each stage in their respective
order and checking the applied policies"""

View file

@ -1,6 +1,6 @@
"""authentik stage Base view"""
from collections import namedtuple
from typing import Any, Dict
from typing import Any
from django.http import HttpRequest
from django.http.response import HttpResponse, JsonResponse
@ -32,7 +32,7 @@ class StageView(TemplateView):
def __init__(self, executor: FlowExecutorView):
self.executor = executor
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
kwargs["title"] = self.executor.flow.title
# Either show the matched User object or show what the user entered,
# based on what the earlier stage (mostly IdentificationStage) set.

View file

@ -1,6 +1,6 @@
"""transfer common classes"""
from dataclasses import asdict, dataclass, field, is_dataclass
from typing import Any, Dict, List
from typing import Any
from uuid import UUID
from django.core.serializers.json import DjangoJSONEncoder
@ -9,7 +9,7 @@ from authentik.lib.models import SerializerModel
from authentik.lib.sentry import SentryIgnoredException
def get_attrs(obj: SerializerModel) -> Dict[str, Any]:
def get_attrs(obj: SerializerModel) -> dict[str, Any]:
"""Get object's attributes via their serializer, and covert it to a normal dict"""
data = dict(obj.serializer(obj).data)
to_remove = (
@ -33,9 +33,9 @@ def get_attrs(obj: SerializerModel) -> Dict[str, Any]:
class FlowBundleEntry:
"""Single entry of a bundle"""
identifiers: Dict[str, Any]
identifiers: dict[str, Any]
model: str
attrs: Dict[str, Any]
attrs: dict[str, Any]
@staticmethod
def from_model(
@ -61,7 +61,7 @@ class FlowBundle:
"""Dataclass used for a full export"""
version: int = field(default=1)
entries: List[FlowBundleEntry] = field(default_factory=list)
entries: list[FlowBundleEntry] = field(default_factory=list)
class DataclassEncoder(DjangoJSONEncoder):

View file

@ -1,6 +1,6 @@
"""Flow exporter"""
from json import dumps
from typing import Iterator, List
from typing import Iterator
from uuid import UUID
from django.db.models import Q
@ -22,7 +22,7 @@ class FlowExporter:
with_policies: bool
with_stage_prompts: bool
pbm_uuids: List[UUID]
pbm_uuids: list[UUID]
def __init__(self, flow: Flow):
self.flow = flow

View file

@ -2,7 +2,7 @@
from contextlib import contextmanager
from copy import deepcopy
from json import loads
from typing import Any, Dict, Type
from typing import Any, Type
from dacite import from_dict
from dacite.exceptions import DaciteError
@ -42,7 +42,7 @@ class FlowImporter:
__import: FlowBundle
__pk_map: Dict[Any, Model]
__pk_map: dict[Any, Model]
logger: BoundLogger
@ -55,7 +55,7 @@ class FlowImporter:
except DaciteError as exc:
raise EntryInvalidError from exc
def __update_pks_for_attrs(self, attrs: Dict[str, Any]) -> Dict[str, Any]:
def __update_pks_for_attrs(self, attrs: dict[str, Any]) -> dict[str, Any]:
"""Replace any value if it is a known primary key of an other object"""
def updater(value) -> Any:
@ -75,7 +75,7 @@ class FlowImporter:
attrs[key] = updater(value)
return attrs
def __query_from_identifier(self, attrs: Dict[str, Any]) -> Q:
def __query_from_identifier(self, attrs: dict[str, Any]) -> Q:
"""Generate an or'd query from all identifiers in an entry"""
# Since identifiers can also be pk-references to other objects (see FlowStageBinding)
# we have to ensure those references are also replaced

View file

@ -1,6 +1,6 @@
"""authentik multi-stage authentication engine"""
from traceback import format_tb
from typing import Any, Dict, Optional
from typing import Any, Optional
from django.contrib.auth.mixins import LoginRequiredMixin
from django.http import (
@ -225,8 +225,8 @@ class FlowErrorResponse(TemplateResponse):
self.error = error
def resolve_context(
self, context: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
self, context: Optional[dict[str, Any]]
) -> Optional[dict[str, Any]]:
if not context:
context = {}
context["error"] = self.error
@ -244,7 +244,7 @@ class FlowExecutorShellView(TemplateView):
template_name = "flows/shell.html"
def get_context_data(self, **kwargs) -> Dict[str, Any]:
def get_context_data(self, **kwargs) -> dict[str, Any]:
flow: Flow = get_object_or_404(Flow, slug=self.kwargs.get("flow_slug"))
kwargs["background_url"] = flow.background.url
kwargs["exec_url"] = reverse("authentik_api:flow-executor", kwargs=self.kwargs)

View file

@ -1,7 +1,7 @@
"""authentik expression policy evaluator"""
import re
from textwrap import indent
from typing import Any, Dict, Iterable, Optional
from typing import Any, Iterable, Optional
from django.core.exceptions import ValidationError
from requests import Session
@ -18,9 +18,9 @@ class BaseEvaluator:
"""Validate and evaluate python-based expressions"""
# Globals that can be used by function
_globals: Dict[str, Any]
_globals: dict[str, Any]
# Context passed as locals to exec()
_context: Dict[str, Any]
_context: dict[str, Any]
# Filename used for exec
_filename: str

View file

@ -1,10 +1,10 @@
"""http helpers"""
from typing import Any, Dict, Optional
from typing import Any, Optional
from django.http import HttpRequest
def _get_client_ip_from_meta(meta: Dict[str, Any]) -> Optional[str]:
def _get_client_ip_from_meta(meta: dict[str, Any]) -> Optional[str]:
"""Attempt to get the client's IP by checking common HTTP Headers.
Returns none if no IP Could be found"""
headers = (

View file

@ -1,8 +1,8 @@
"""authentik UI utils"""
from typing import Any, List
from typing import Any
def human_list(_list: List[Any]) -> str:
def human_list(_list: list[Any]) -> str:
"""Convert a list of items into 'a, b or c'"""
last_item = _list.pop()
if len(_list) < 1:

View file

@ -2,7 +2,7 @@
from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import IntEnum
from typing import Any, Dict, Optional
from typing import Any, Optional
from channels.exceptions import DenyConnection
from dacite import from_dict
@ -34,7 +34,7 @@ class WebsocketMessage:
"""Complete Websocket Message that is being sent"""
instruction: int
args: Dict[str, Any] = field(default_factory=dict)
args: dict[str, Any] = field(default_factory=dict)
class OutpostConsumer(AuthJsonConsumer):

View file

@ -1,6 +1,5 @@
"""Docker controller"""
from time import sleep
from typing import Dict, Tuple
from django.conf import settings
from docker import DockerClient
@ -33,10 +32,10 @@ class DockerController(BaseController):
except ServiceConnectionInvalid as exc:
raise ControllerException from exc
def _get_labels(self) -> Dict[str, str]:
def _get_labels(self) -> dict[str, str]:
return {}
def _get_env(self) -> Dict[str, str]:
def _get_env(self) -> dict[str, str]:
return {
"AUTHENTIK_HOST": self.outpost.config.authentik_host,
"AUTHENTIK_INSECURE": str(self.outpost.config.authentik_host_insecure),
@ -55,7 +54,7 @@ class DockerController(BaseController):
return True
return False
def _get_container(self) -> Tuple[Container, bool]:
def _get_container(self) -> tuple[Container, bool]:
container_name = f"authentik-proxy-{self.outpost.uuid.hex}"
try:
return self.client.containers.get(container_name), False

View file

@ -1,5 +1,5 @@
"""Kubernetes Deployment Reconciler"""
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from kubernetes.client import (
AppsV1Api,
@ -53,7 +53,7 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]):
):
raise NeedsUpdate()
def get_pod_meta(self) -> Dict[str, str]:
def get_pod_meta(self) -> dict[str, str]:
"""Get common object metadata"""
return {
"app.kubernetes.io/name": "authentik-outpost",

View file

@ -1,6 +1,6 @@
"""Kubernetes deployment controller"""
from io import StringIO
from typing import Dict, List, Type
from typing import Type
from kubernetes.client import OpenApiException
from kubernetes.client.api_client import ApiClient
@ -18,8 +18,8 @@ from authentik.outposts.models import KubernetesServiceConnection, Outpost
class KubernetesController(BaseController):
"""Manage deployment of outpost in kubernetes"""
reconcilers: Dict[str, Type[KubernetesObjectReconciler]]
reconcile_order: List[str]
reconcilers: dict[str, Type[KubernetesObjectReconciler]]
reconcile_order: list[str]
client: ApiClient
connection: KubernetesServiceConnection
@ -45,7 +45,7 @@ class KubernetesController(BaseController):
except OpenApiException as exc:
raise ControllerException from exc
def up_with_logs(self) -> List[str]:
def up_with_logs(self) -> list[str]:
try:
all_logs = []
for reconcile_key in self.reconcile_order:

View file

@ -1,7 +1,7 @@
"""Outpost models"""
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import Dict, Iterable, List, Optional, Type, Union
from typing import Iterable, Optional, Type, Union
from uuid import uuid4
from dacite import from_dict
@ -58,7 +58,7 @@ class OutpostConfig:
kubernetes_replicas: int = field(default=1)
kubernetes_namespace: str = field(default="default")
kubernetes_ingress_annotations: Dict[str, str] = field(default_factory=dict)
kubernetes_ingress_annotations: dict[str, str] = field(default_factory=dict)
kubernetes_ingress_secret_name: str = field(default="authentik-outpost")
@ -315,7 +315,7 @@ class Outpost(models.Model):
return f"outpost_{self.uuid.hex}_state"
@property
def state(self) -> List["OutpostState"]:
def state(self) -> list["OutpostState"]:
"""Get outpost's health status"""
return OutpostState.for_outpost(self)
@ -399,7 +399,7 @@ class OutpostState:
return parse(self.version) < OUR_VERSION
@staticmethod
def for_outpost(outpost: Outpost) -> List["OutpostState"]:
def for_outpost(outpost: Outpost) -> list["OutpostState"]:
"""Get all states for an outpost"""
keys = cache.keys(f"{outpost.state_cache_prefix}_*")
states = []

View file

@ -2,7 +2,7 @@
from enum import Enum
from multiprocessing import Pipe, current_process
from multiprocessing.connection import Connection
from typing import Iterator, List, Optional
from typing import Iterator, Optional
from django.core.cache import cache
from django.http import HttpRequest
@ -54,8 +54,8 @@ class PolicyEngine:
empty_result: bool
__pbm: PolicyBindingModel
__cached_policies: List[PolicyResult]
__processes: List[PolicyProcessInfo]
__cached_policies: list[PolicyResult]
__processes: list[PolicyProcessInfo]
__expected_result_count: int
@ -137,7 +137,7 @@ class PolicyEngine:
@property
def result(self) -> PolicyResult:
"""Get policy-checking result"""
process_results: List[PolicyResult] = [
process_results: list[PolicyResult] = [
x.result for x in self.__processes if x.result
]
all_results = list(process_results + self.__cached_policies)

View file

@ -1,6 +1,6 @@
"""authentik expression policy evaluator"""
from ipaddress import ip_address, ip_network
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from django.http import HttpRequest
from structlog.stdlib import get_logger
@ -19,7 +19,7 @@ if TYPE_CHECKING:
class PolicyEvaluator(BaseEvaluator):
"""Validate and evaluate python-based expressions"""
_messages: List[str]
_messages: list[str]
policy: Optional["ExpressionPolicy"] = None

View file

@ -1,5 +1,5 @@
"""policy http response"""
from typing import Any, Dict, Optional
from typing import Any, Optional
from django.http.request import HttpRequest
from django.template.response import TemplateResponse
@ -24,8 +24,8 @@ class AccessDeniedResponse(TemplateResponse):
self.title = _("Access denied")
def resolve_context(
self, context: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
self, context: Optional[dict[str, Any]]
) -> Optional[dict[str, Any]]:
if not context:
context = {}
context["title"] = self.title

View file

@ -1,8 +1,8 @@
"""Policy Utils"""
from typing import Any, Dict
from typing import Any
def delete_none_keys(dict_: Dict[Any, Any]) -> Dict[Any, Any]:
def delete_none_keys(dict_: dict[Any, Any]) -> dict[Any, Any]:
"""Remove any keys from `dict_` that are None."""
new_dict = {}
for key, value in dict_.items():

View file

@ -5,7 +5,7 @@ import json
import time
from dataclasses import asdict, dataclass, field
from hashlib import sha256
from typing import Any, Dict, List, Optional, Type
from typing import Any, Optional, Type
from urllib.parse import urlparse
from uuid import uuid4
@ -218,7 +218,7 @@ class OAuth2Provider(Provider):
)
def create_refresh_token(
self, user: User, scope: List[str], request: HttpRequest
self, user: User, scope: list[str], request: HttpRequest
) -> "RefreshToken":
"""Create and populate a RefreshToken object."""
token = RefreshToken(
@ -231,7 +231,7 @@ class OAuth2Provider(Provider):
token.access_token = token.create_access_token(user, request)
return token
def get_jwt_keys(self) -> List[Key]:
def get_jwt_keys(self) -> list[Key]:
"""
Takes a provider and returns the set of keys associated with it.
Returns a list of keys.
@ -299,7 +299,7 @@ class OAuth2Provider(Provider):
def __str__(self):
return f"OAuth2 Provider {self.name}"
def encode(self, payload: Dict[str, Any]) -> str:
def encode(self, payload: dict[str, Any]) -> str:
"""Represent the ID Token as a JSON Web Token (JWT)."""
keys = self.get_jwt_keys()
# If the provider does not have an RSA Key assigned, it was switched to Symmetric
@ -321,7 +321,7 @@ class BaseGrantModel(models.Model):
_scope = models.TextField(default="", verbose_name=_("Scopes"))
@property
def scope(self) -> List[str]:
def scope(self) -> list[str]:
"""Return scopes as list of strings"""
return self._scope.split()
@ -394,9 +394,9 @@ class IDToken:
nonce: Optional[str] = None
at_hash: Optional[str] = None
claims: Dict[str, Any] = field(default_factory=dict)
claims: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert dataclass to dict, and update with keys from `claims`"""
dic = asdict(self)
dic.pop("claims")

View file

@ -2,7 +2,7 @@
import re
from base64 import b64decode
from binascii import Error
from typing import List, Optional, Tuple
from typing import Optional
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.utils.cache import patch_vary_headers
@ -68,7 +68,7 @@ def extract_access_token(request: HttpRequest) -> Optional[str]:
return None
def extract_client_auth(request: HttpRequest) -> Tuple[str, str]:
def extract_client_auth(request: HttpRequest) -> tuple[str, str]:
"""
Get client credentials using HTTP Basic Authentication method.
Or try getting parameters via POST.
@ -92,7 +92,7 @@ def extract_client_auth(request: HttpRequest) -> Tuple[str, str]:
return (client_id, client_secret)
def protected_resource_view(scopes: List[str]):
def protected_resource_view(scopes: list[str]):
"""View decorator. The client accesses protected resources by presenting the
access token to the resource server.

View file

@ -1,7 +1,7 @@
"""authentik OAuth2 Authorization views"""
from dataclasses import dataclass, field
from datetime import timedelta
from typing import List, Optional, Set
from typing import Optional
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
from uuid import uuid4
@ -69,10 +69,10 @@ class OAuthAuthorizationParams:
client_id: str
redirect_uri: str
response_type: str
scope: List[str]
scope: list[str]
state: str
nonce: Optional[str]
prompt: Set[str]
prompt: set[str]
grant_type: str
provider: OAuth2Provider = field(default_factory=OAuth2Provider)

View file

@ -1,5 +1,5 @@
"""authentik OAuth2 OpenID well-known views"""
from typing import Any, Dict
from typing import Any
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.shortcuts import get_object_or_404, reverse
@ -29,7 +29,7 @@ PLAN_CONTEXT_SCOPES = "scopes"
class ProviderInfoView(View):
"""OpenID-compliant Provider Info"""
def get_info(self, provider: OAuth2Provider) -> Dict[str, Any]:
def get_info(self, provider: OAuth2Provider) -> dict[str, Any]:
"""Get dictionary for OpenID Connect information"""
scopes = list(
ScopeMapping.objects.filter(provider=provider).values_list(

View file

@ -1,5 +1,5 @@
"""authentik OAuth2 Session Views"""
from typing import Any, Dict
from typing import Any
from django.shortcuts import get_object_or_404
from django.views.generic.base import TemplateView
@ -12,7 +12,7 @@ class EndSessionView(TemplateView):
template_name = "providers/oauth2/end_session.html"
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
context = super().get_context_data(**kwargs)
context["application"] = get_object_or_404(

View file

@ -2,7 +2,7 @@
from base64 import urlsafe_b64encode
from dataclasses import InitVar, dataclass
from hashlib import sha256
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from django.http import HttpRequest, HttpResponse
from django.views import View
@ -33,7 +33,7 @@ class TokenParams:
redirect_uri: str
grant_type: str
state: str
scope: List[str]
scope: list[str]
authorization_code: Optional[AuthorizationCode] = None
refresh_token: Optional[RefreshToken] = None
@ -171,7 +171,7 @@ class TokenView(View):
except UserAuthError as error:
return TokenResponse(error.create_dict(), status=403)
def create_code_response_dic(self) -> Dict[str, Any]:
def create_code_response_dic(self) -> dict[str, Any]:
"""See https://tools.ietf.org/html/rfc6749#section-4.1"""
refresh_token = self.params.authorization_code.provider.create_refresh_token(
@ -207,7 +207,7 @@ class TokenView(View):
return response_dict
def create_refresh_response_dic(self) -> Dict[str, Any]:
def create_refresh_response_dic(self) -> dict[str, Any]:
"""See https://tools.ietf.org/html/rfc6749#section-6"""
unauthorized_scopes = set(self.params.scope) - set(

View file

@ -1,5 +1,5 @@
"""authentik OAuth2 OpenID Userinfo views"""
from typing import Any, Dict, List
from typing import Any
from django.http import HttpRequest, HttpResponse
from django.utils.translation import gettext_lazy as _
@ -22,7 +22,7 @@ class UserInfoView(View):
"""Create a dictionary with all the requested claims about the End-User.
See: http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse"""
def get_scope_descriptions(self, scopes: List[str]) -> Dict[str, str]:
def get_scope_descriptions(self, scopes: list[str]) -> dict[str, str]:
"""Get a list of all Scopes's descriptions"""
scope_descriptions = {}
for scope in ScopeMapping.objects.filter(scope_name__in=scopes).order_by(
@ -47,7 +47,7 @@ class UserInfoView(View):
scope_descriptions[scope] = github_scope_map[scope]
return scope_descriptions
def get_claims(self, token: RefreshToken) -> Dict[str, Any]:
def get_claims(self, token: RefreshToken) -> dict[str, Any]:
"""Get a dictionary of claims from scopes that the token
requires and are assigned to the provider."""

View file

@ -1,5 +1,4 @@
"""Proxy Provider Docker Contoller"""
from typing import Dict
from urllib.parse import urlparse
from authentik.outposts.controllers.base import DeploymentPort
@ -18,7 +17,7 @@ class ProxyDockerController(DockerController):
DeploymentPort(4443, "https", "tcp"),
]
def _get_labels(self) -> Dict[str, str]:
def _get_labels(self) -> dict[str, str]:
hosts = []
for proxy_provider in ProxyProvider.objects.filter(outpost__in=[self.outpost]):
proxy_provider: ProxyProvider

View file

@ -1,5 +1,5 @@
"""Kubernetes Ingress Reconciler"""
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from urllib.parse import urlparse
from kubernetes.client import (
@ -78,7 +78,7 @@ class IngressReconciler(KubernetesObjectReconciler[NetworkingV1beta1Ingress]):
if have_hosts_tls != expected_hosts_tls:
raise NeedsUpdate()
def get_ingress_annotations(self) -> Dict[str, str]:
def get_ingress_annotations(self) -> dict[str, str]:
"""Get ingress annotations"""
annotations = {
# Ensure that with multiple proxy replicas deployed, the same CSRF request

View file

@ -8,7 +8,7 @@ https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/
"""
import typing
from time import time
from typing import Any, ByteString, Dict
from typing import Any, ByteString
import django
from asgiref.compatibility import guarantee_single_callable
@ -64,7 +64,7 @@ class ASGILogger:
app: ASGIApp
scope: Scope
headers: Dict[ByteString, Any]
headers: dict[ByteString, Any]
status_code: int
start: float

View file

@ -1,5 +1,5 @@
"""authentik ldap source signals"""
from typing import Any, Dict
from typing import Any
from django.core.exceptions import ValidationError
from django.db.models.signals import post_save
@ -26,7 +26,7 @@ def sync_ldap_source_on_save(sender, instance: LDAPSource, **_):
@receiver(password_validate)
# pylint: disable=unused-argument
def ldap_password_validate(sender, password: str, plan_context: Dict[str, Any], **__):
def ldap_password_validate(sender, password: str, plan_context: dict[str, Any], **__):
"""if there's an LDAP Source with enabled password sync, check the password"""
sources = LDAPSource.objects.filter(sync_users_password=True)
if not sources.exists():

View file

@ -1,5 +1,5 @@
"""OAuth Clients"""
from typing import Any, Dict, Optional
from typing import Any, Optional
from urllib.parse import urlencode
from django.http import HttpRequest
@ -33,11 +33,11 @@ class BaseOAuthClient:
self.callback = callback
self.session.headers.update({"User-Agent": f"authentik {__version__}"})
def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]:
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
"Fetch access token from callback request."
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
def get_profile_info(self, token: Dict[str, str]) -> Optional[Dict[str, Any]]:
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
"Fetch user profile information."
try:
response = self.do_request("get", self.source.profile_url, token=token)
@ -48,7 +48,7 @@ class BaseOAuthClient:
else:
return response.json()
def get_redirect_args(self) -> Dict[str, str]:
def get_redirect_args(self) -> dict[str, str]:
"Get request parameters for redirect url."
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
@ -61,7 +61,7 @@ class BaseOAuthClient:
LOGGER.info("redirect args", **args)
return f"{self.source.authorization_url}?{params}"
def parse_raw_token(self, raw_token: str) -> Dict[str, Any]:
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
"Parse token and secret from raw token response."
raise NotImplementedError("Defined in a sub-class") # pragma: no cover

View file

@ -1,5 +1,5 @@
"""OAuth 1 Clients"""
from typing import Any, Dict, Optional
from typing import Any, Optional
from urllib.parse import parse_qsl
from requests.exceptions import RequestException
@ -20,7 +20,7 @@ class OAuthClient(BaseOAuthClient):
"Accept": "application/json",
}
def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]:
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
"Fetch access token from callback request."
raw_token = self.request.session.get(self.session_key, None)
verifier = self.request.GET.get("oauth_verifier", None)
@ -60,7 +60,7 @@ class OAuthClient(BaseOAuthClient):
else:
return response.text
def get_redirect_args(self) -> Dict[str, Any]:
def get_redirect_args(self) -> dict[str, Any]:
"Get request parameters for redirect url."
callback = self.request.build_absolute_uri(self.callback)
raw_token = self.get_request_token()
@ -71,7 +71,7 @@ class OAuthClient(BaseOAuthClient):
"oauth_callback": callback,
}
def parse_raw_token(self, raw_token: str) -> Dict[str, Any]:
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
"Parse token and secret from raw token response."
return dict(parse_qsl(raw_token))
@ -80,7 +80,7 @@ class OAuthClient(BaseOAuthClient):
resource_owner_key = None
resource_owner_secret = None
if "token" in kwargs:
user_token: Dict[str, Any] = kwargs.pop("token")
user_token: dict[str, Any] = kwargs.pop("token")
resource_owner_key = user_token["oauth_token"]
resource_owner_secret = user_token["oauth_token_secret"]

View file

@ -1,6 +1,6 @@
"""OAuth 2 Clients"""
from json import loads
from typing import Any, Dict, Optional
from typing import Any, Optional
from urllib.parse import parse_qsl
from django.utils.crypto import constant_time_compare, get_random_string
@ -38,7 +38,7 @@ class OAuth2Client(BaseOAuthClient):
"Generate state optional parameter."
return get_random_string(32)
def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]:
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
"Fetch access token from callback request."
callback = self.request.build_absolute_uri(self.callback or self.request.path)
if not self.check_application_state():
@ -69,11 +69,11 @@ class OAuth2Client(BaseOAuthClient):
else:
return response.json()
def get_redirect_args(self) -> Dict[str, str]:
def get_redirect_args(self) -> dict[str, str]:
"Get request parameters for redirect url."
callback = self.request.build_absolute_uri(self.callback)
client_id: str = self.source.consumer_key
args: Dict[str, str] = {
args: dict[str, str] = {
"client_id": client_id,
"redirect_uri": callback,
"response_type": "code",
@ -84,7 +84,7 @@ class OAuth2Client(BaseOAuthClient):
self.request.session[self.session_key] = state
return args
def parse_raw_token(self, raw_token: str) -> Dict[str, Any]:
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
"Parse token and secret from raw token response."
# Load as json first then parse as query string
try:

View file

@ -1,5 +1,5 @@
"""AzureAD OAuth2 Views"""
from typing import Any, Dict
from typing import Any
from uuid import UUID
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
@ -11,15 +11,15 @@ from authentik.sources.oauth.views.callback import OAuthCallback
class AzureADOAuthCallback(OAuthCallback):
"""AzureAD OAuth2 Callback"""
def get_user_id(self, source: OAuthSource, info: Dict[str, Any]) -> str:
def get_user_id(self, source: OAuthSource, info: dict[str, Any]) -> str:
return str(UUID(info.get("objectId")).int)
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
info: dict[str, Any],
) -> dict[str, Any]:
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
return {
"username": info.get("displayName"),

View file

@ -1,5 +1,5 @@
"""Discord OAuth Views"""
from typing import Any, Dict
from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
@ -25,8 +25,8 @@ class DiscordOAuth2Callback(OAuthCallback):
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("username"),
"email": info.get("email", None),

View file

@ -1,5 +1,5 @@
"""Facebook OAuth Views"""
from typing import Any, Dict, Optional
from typing import Any, Optional
from facebook import GraphAPI
@ -23,7 +23,7 @@ class FacebookOAuthRedirect(OAuthRedirect):
class FacebookOAuth2Client(OAuth2Client):
"""Facebook OAuth2 Client"""
def get_profile_info(self, token: Dict[str, str]) -> Optional[Dict[str, Any]]:
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
api = GraphAPI(access_token=token["access_token"])
return api.get_object("me", fields="id,name,email")
@ -38,8 +38,8 @@ class FacebookOAuth2Callback(OAuthCallback):
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("name"),
"email": info.get("email"),

View file

@ -1,5 +1,5 @@
"""GitHub OAuth Views"""
from typing import Any, Dict
from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
@ -14,8 +14,8 @@ class GitHubOAuth2Callback(OAuthCallback):
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("login"),
"email": info.get("email"),

View file

@ -1,5 +1,5 @@
"""Google OAuth Views"""
from typing import Any, Dict
from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
@ -25,8 +25,8 @@ class GoogleOAuth2Callback(OAuthCallback):
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("email"),
"email": info.get("email"),

View file

@ -1,6 +1,6 @@
"""Source type manager"""
from enum import Enum
from typing import Callable, Dict, List
from typing import Callable
from django.utils.text import slugify
from structlog.stdlib import get_logger
@ -22,8 +22,8 @@ class RequestKind(Enum):
class SourceTypeManager:
"""Manager to hold all Source types."""
__source_types: Dict[RequestKind, Dict[str, Callable]] = {}
__names: List[str] = []
__source_types: dict[RequestKind, dict[str, Callable]] = {}
__names: list[str] = []
def source(self, kind: RequestKind, name: str):
"""Class decorator to register classes inline."""

View file

@ -1,5 +1,5 @@
"""OpenID Connect OAuth Views"""
from typing import Any, Dict
from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
@ -21,15 +21,15 @@ class OpenIDConnectOAuthRedirect(OAuthRedirect):
class OpenIDConnectOAuth2Callback(OAuthCallback):
"""OpenIDConnect OAuth2 Callback"""
def get_user_id(self, source: OAuthSource, info: Dict[str, str]) -> str:
def get_user_id(self, source: OAuthSource, info: dict[str, str]) -> str:
return info.get("sub", "")
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("nickname"),
"email": info.get("email"),

View file

@ -1,5 +1,5 @@
"""Reddit OAuth Views"""
from typing import Any, Dict
from typing import Any
from requests.auth import HTTPBasicAuth
@ -40,8 +40,8 @@ class RedditOAuth2Callback(OAuthCallback):
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("name"),
"email": None,

View file

@ -1,5 +1,5 @@
"""Twitter OAuth Views"""
from typing import Any, Dict
from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
@ -14,8 +14,8 @@ class TwitterOAuthCallback(OAuthCallback):
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
info: dict[str, Any],
) -> dict[str, Any]:
return {
"username": info.get("screen_name"),
"email": info.get("email", None),

View file

@ -1,5 +1,5 @@
"""OAuth Callback Views"""
from typing import Any, Dict, Optional
from typing import Any, Optional
from django.conf import settings
from django.contrib import messages
@ -115,14 +115,14 @@ class OAuthCallback(OAuthClientMixin, View):
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
info: dict[str, Any],
) -> dict[str, Any]:
"""Create a dict of User data"""
raise NotImplementedError()
# pylint: disable=unused-argument
def get_user_id(
self, source: UserOAuthSourceConnection, info: Dict[str, Any]
self, source: UserOAuthSourceConnection, info: dict[str, Any]
) -> Optional[str]:
"""Return unique identifier from the profile info."""
if "id" in info:
@ -167,7 +167,7 @@ class OAuthCallback(OAuthClientMixin, View):
source: OAuthSource,
user: User,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
info: dict[str, Any],
) -> HttpResponse:
"Login user and redirect."
messages.success(
@ -184,7 +184,7 @@ class OAuthCallback(OAuthClientMixin, View):
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
info: dict[str, Any],
) -> HttpResponse:
"""Handler when the user was already authenticated and linked an external source
to their account."""
@ -211,7 +211,7 @@ class OAuthCallback(OAuthClientMixin, View):
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
info: dict[str, Any],
) -> HttpResponse:
"""User was not authenticated and previous request was not authenticated."""
messages.success(

View file

@ -1,5 +1,5 @@
"""OAuth Redirect Views"""
from typing import Any, Dict
from typing import Any
from django.http import Http404
from django.urls import reverse
@ -19,7 +19,7 @@ class OAuthRedirect(OAuthClientMixin, RedirectView):
params = None
# pylint: disable=unused-argument
def get_additional_parameters(self, source: OAuthSource) -> Dict[str, Any]:
def get_additional_parameters(self, source: OAuthSource) -> dict[str, Any]:
"Return additional redirect parameters for this source."
return self.params or {}

View file

@ -1,6 +1,5 @@
"""SAML AuthnRequest Processor"""
from base64 import b64encode
from typing import Dict
from urllib.parse import quote_plus
import xmlsec
@ -125,7 +124,7 @@ class RequestProcessor:
return etree.tostring(auth_n_request).decode()
def build_auth_n_detached(self) -> Dict[str, str]:
def build_auth_n_detached(self) -> dict[str, str]:
"""Get Dict AuthN Request for Redirect bindings, with detached
Signature. See https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf"""
auth_n_request = self.get_auth_n()

View file

@ -1,6 +1,6 @@
"""authentik saml source processor"""
from base64 import b64decode
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any
import xmlsec
from defusedxml.lxml import fromstring
@ -154,7 +154,7 @@ class ResponseProcessor:
raise ValueError("NameID Element not found!")
return name_id
def _get_name_id_filter(self) -> Dict[str, str]:
def _get_name_id_filter(self) -> dict[str, str]:
"""Returns the subject's NameID as a Filter for the `User`"""
name_id_el = self._get_name_id()
name_id = name_id_el.text

View file

@ -1,5 +1,5 @@
"""Static OTP Setup stage"""
from typing import Any, Dict
from typing import Any
from django.http import HttpRequest, HttpResponse
from django.views.generic import FormView
@ -21,7 +21,7 @@ class AuthenticatorStaticStageView(FormView, StageView):
form_class = SetupForm
def get_form_kwargs(self, **kwargs) -> Dict[str, Any]:
def get_form_kwargs(self, **kwargs) -> dict[str, Any]:
kwargs = super().get_form_kwargs(**kwargs)
tokens = self.request.session[SESSION_STATIC_TOKENS]
kwargs["tokens"] = tokens

View file

@ -1,5 +1,5 @@
"""TOTP Setup stage"""
from typing import Any, Dict
from typing import Any
from django.http import HttpRequest, HttpResponse
from django.utils.encoding import force_str
@ -24,7 +24,7 @@ class AuthenticatorTOTPStageView(FormView, StageView):
form_class = SetupForm
def get_form_kwargs(self, **kwargs) -> Dict[str, Any]:
def get_form_kwargs(self, **kwargs) -> dict[str, Any]:
kwargs = super().get_form_kwargs(**kwargs)
device: TOTPDevice = self.request.session[SESSION_TOTP_DEVICE]
kwargs["device"] = device

View file

@ -1,5 +1,5 @@
"""OTP Validation"""
from typing import Any, Dict
from typing import Any
from django.http import HttpRequest, HttpResponse
from django.views.generic import FormView
@ -32,6 +32,11 @@ class AuthenticatorValidateStageView(ChallengeStageView):
form_class = ValidationForm
def get_form_kwargs(self, **kwargs) -> dict[str, Any]:
kwargs = super().get_form_kwargs(**kwargs)
kwargs["user"] = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER)
return kwargs
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Check if a user is set, and check if the user has any devices
if not, we can skip this entire stage"""

View file

@ -1,5 +1,5 @@
"""authentik consent stage"""
from typing import Any, Dict, List
from typing import Any
from django.http import HttpRequest, HttpResponse
from django.utils.timezone import now
@ -19,13 +19,13 @@ class ConsentStageView(FormView, StageView):
form_class = ConsentForm
def get_context_data(self, **kwargs: Dict[str, Any]) -> Dict[str, Any]:
def get_context_data(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
kwargs = super().get_context_data(**kwargs)
kwargs["current_stage"] = self.executor.current_stage
kwargs["context"] = self.executor.plan.context
return kwargs
def get_template_names(self) -> List[str]:
def get_template_names(self) -> list[str]:
# PLAN_CONTEXT_CONSENT_TEMPLATE has to be set by a template that calls this stage
if PLAN_CONTEXT_CONSENT_TEMPLATE in self.executor.plan.context:
template_name = self.executor.plan.context[PLAN_CONTEXT_CONSENT_TEMPLATE]

View file

@ -1,5 +1,5 @@
"""authentik multi-stage authentication engine"""
from typing import Any, Dict
from typing import Any
from django.http import HttpRequest
@ -13,7 +13,7 @@ class DummyStageView(StageView):
"""Just redirect to next stage"""
return self.executor.stage_ok()
def get_context_data(self, **kwargs: Dict[str, Any]) -> Dict[str, Any]:
def get_context_data(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
kwargs = super().get_context_data(**kwargs)
kwargs["title"] = self.executor.current_stage.name
return kwargs

View file

@ -1,5 +1,5 @@
"""Identification stage logic"""
from typing import List, Optional
from typing import Optional
from django.contrib import messages
from django.db.models import Q
@ -75,7 +75,7 @@ class IdentificationStageView(ChallengeStageView):
# Check all enabled source, add them if they have a UI Login button.
args["sources"] = []
sources: List[Source] = (
sources: list[Source] = (
Source.objects.filter(enabled=True).order_by("name").select_subclasses()
)
for source in sources:

View file

@ -1,5 +1,5 @@
"""authentik password stage"""
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from django.contrib.auth import _clean_credentials
from django.contrib.auth.backends import BaseBackend
@ -24,7 +24,7 @@ SESSION_INVALID_TRIES = "user_invalid_tries"
def authenticate(
request: HttpRequest, backends: List[str], **credentials: Dict[str, Any]
request: HttpRequest, backends: list[str], **credentials: dict[str, Any]
) -> Optional[User]:
"""If the given credentials are valid, return a User object.

View file

@ -1,7 +1,7 @@
"""Prompt forms"""
from email.policy import Policy
from types import MethodType
from typing import Any, Callable, Iterator, List
from typing import Any, Callable, Iterator
from django import forms
from django.db.models.query import QuerySet
@ -52,10 +52,10 @@ class PromptAdminForm(forms.ModelForm):
class ListPolicyEngine(PolicyEngine):
"""Slightly modified policy engine, which uses a list instead of a PolicyBindingModel"""
__list: List[Policy]
__list: list[Policy]
def __init__(
self, policies: List[Policy], user: User, request: HttpRequest = None
self, policies: list[Policy], user: User, request: HttpRequest = None
) -> None:
super().__init__(PolicyBindingModel(), user, request)
self.__list = policies

View file

@ -1,5 +1,5 @@
"""authentik prompt stage signals"""
from django.core.signals import Signal
# Arguments: password: str, plan_context: Dict[str, Any]
# Arguments: password: str, plan_context: dict[str, Any]
password_validate = Signal()

View file

@ -1,5 +1,5 @@
"""authentik user_write signals"""
from django.core.signals import Signal
# Arguments: request: HttpRequest, user: User, data: Dict[str, Any], created: bool
# Arguments: request: HttpRequest, user: User, data: dict[str, Any], created: bool
user_write = Signal()

View file

@ -1,6 +1,6 @@
"""Test Enroll flow"""
from sys import platform
from typing import Any, Dict, Optional
from typing import Any, Optional
from unittest.case import skipUnless
from django.test import override_settings
@ -22,7 +22,7 @@ from tests.e2e.utils import USER, SeleniumTestCase, retry
class TestFlowsEnroll(SeleniumTestCase):
"""Test Enroll flow"""
def get_container_specs(self) -> Optional[Dict[str, Any]]:
def get_container_specs(self) -> Optional[dict[str, Any]]:
return {
"image": "mailhog/mailhog:v1.0.1",
"detach": True,

View file

@ -1,7 +1,7 @@
"""test OAuth Provider flow"""
from sys import platform
from time import sleep
from typing import Any, Dict, Optional
from typing import Any, Optional
from unittest.case import skipUnless
from docker.types import Healthcheck
@ -30,7 +30,7 @@ class TestProviderOAuth2Github(SeleniumTestCase):
self.client_secret = generate_client_secret()
super().setUp()
def get_container_specs(self) -> Optional[Dict[str, Any]]:
def get_container_specs(self) -> Optional[dict[str, Any]]:
"""Setup client grafana container which we test OAuth against"""
return {
"image": "grafana/grafana:7.1.0",

View file

@ -1,7 +1,7 @@
"""test OAuth2 OpenID Provider flow"""
from sys import platform
from time import sleep
from typing import Any, Dict, Optional
from typing import Any, Optional
from unittest.case import skipUnless
from docker.types import Healthcheck
@ -40,7 +40,7 @@ class TestProviderOAuth2OAuth(SeleniumTestCase):
self.client_secret = generate_client_secret()
super().setUp()
def get_container_specs(self) -> Optional[Dict[str, Any]]:
def get_container_specs(self) -> Optional[dict[str, Any]]:
return {
"image": "grafana/grafana:7.1.0",
"detach": True,

View file

@ -2,7 +2,7 @@
from dataclasses import asdict
from sys import platform
from time import sleep
from typing import Any, Dict, Optional
from typing import Any, Optional
from unittest.case import skipUnless
from channels.testing import ChannelsLiveServerTestCase
@ -35,7 +35,7 @@ class TestProviderProxy(SeleniumTestCase):
super().tearDown()
self.proxy_container.kill()
def get_container_specs(self) -> Optional[Dict[str, Any]]:
def get_container_specs(self) -> Optional[dict[str, Any]]:
return {
"image": "traefik/whoami:latest",
"detach": True,

View file

@ -2,7 +2,7 @@
from os.path import abspath
from sys import platform
from time import sleep
from typing import Any, Dict, Optional
from typing import Any, Optional
from unittest.case import skipUnless
from django.test import override_settings
@ -72,7 +72,7 @@ class TestSourceOAuth2(SeleniumTestCase):
with open(CONFIG_PATH, "w+") as _file:
safe_dump(config, _file)
def get_container_specs(self) -> Optional[Dict[str, Any]]:
def get_container_specs(self) -> Optional[dict[str, Any]]:
return {
"image": "quay.io/dexidp/dex:v2.24.0",
"detach": True,
@ -249,7 +249,7 @@ class TestSourceOAuth1(SeleniumTestCase):
self.source_slug = "oauth1-test"
super().setUp()
def get_container_specs(self) -> Optional[Dict[str, Any]]:
def get_container_specs(self) -> Optional[dict[str, Any]]:
return {
"image": "beryju/oauth1-test-server",
"detach": True,

View file

@ -1,7 +1,7 @@
"""test SAML Source"""
from sys import platform
from time import sleep
from typing import Any, Dict, Optional
from typing import Any, Optional
from unittest.case import skipUnless
from docker.types import Healthcheck
@ -73,7 +73,7 @@ Sm75WXsflOxuTn08LbgGc4s=
class TestSourceSAML(SeleniumTestCase):
"""test SAML Source flow"""
def get_container_specs(self) -> Optional[Dict[str, Any]]:
def get_container_specs(self) -> Optional[dict[str, Any]]:
return {
"image": "kristophjunge/test-saml-idp:1.15",
"detach": True,

View file

@ -6,7 +6,7 @@ from importlib.util import module_from_spec, spec_from_file_location
from inspect import getmembers, isfunction
from os import environ, makedirs
from time import sleep, time
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Optional
from django.apps import apps
from django.contrib.staticfiles.testing import StaticLiveServerTestCase
@ -56,7 +56,7 @@ class SeleniumTestCase(StaticLiveServerTestCase):
if specs := self.get_container_specs():
self.container = self._start_container(specs)
def _start_container(self, specs: Dict[str, Any]) -> Container:
def _start_container(self, specs: dict[str, Any]) -> Container:
client: DockerClient = from_env()
client.images.pull(specs["image"])
container = client.containers.run(**specs)
@ -70,7 +70,7 @@ class SeleniumTestCase(StaticLiveServerTestCase):
self.logger.info("Container failed healthcheck")
sleep(1)
def get_container_specs(self) -> Optional[Dict[str, Any]]:
def get_container_specs(self) -> Optional[dict[str, Any]]:
"""Optionally get container specs which will launched on setup, wait for the container to
be healthy, and deleted again on tearDown"""
return None