diff --git a/authentik/core/api/applications_transactional.py b/authentik/core/api/applications_transactional.py new file mode 100644 index 000000000..31adb48a3 --- /dev/null +++ b/authentik/core/api/applications_transactional.py @@ -0,0 +1,97 @@ +from django.apps import apps +from drf_spectacular.utils import PolymorphicProxySerializer, extend_schema, extend_schema_field +from rest_framework.exceptions import ValidationError +from rest_framework.fields import ChoiceField, DictField +from rest_framework.permissions import IsAdminUser +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.views import APIView +from yaml import ScalarNode +from authentik.blueprints.v1.common import Blueprint, BlueprintEntry, BlueprintEntryDesiredState, KeyOf +from authentik.blueprints.v1.importer import Importer + +from authentik.core.api.applications import ApplicationSerializer +from authentik.core.api.utils import PassiveSerializer +from authentik.core.models import Provider +from authentik.lib.utils.reflection import all_subclasses + + +def get_provider_serializer_mapping(): + map = {} + for model in all_subclasses(Provider): + if model._meta.abstract: + continue + map[f"{model._meta.app_label}.{model._meta.model_name}"] = model().serializer + return map + + +@extend_schema_field( + PolymorphicProxySerializer( + component_name="model", + serializers=get_provider_serializer_mapping, + resource_type_field_name="provider_model", + ) +) +class TransactionProviderField(DictField): + pass + + +class TransactionApplicationSerializer(PassiveSerializer): + """Serializer for creating a provider and an application in one transaction""" + + app = ApplicationSerializer() + provider_model = ChoiceField(choices=list(get_provider_serializer_mapping().keys())) + provider = TransactionProviderField() + + _provider_model: type[Provider] = None + + def validate_provider_model(self, fq_model_name: str) -> str: + """Validate that the model exists and is a provider""" + if "." not in fq_model_name: + raise ValidationError("Invalid provider model") + try: + app, model_name = fq_model_name.split(".") + model = apps.get_model(app, model_name) + if not issubclass(model, Provider): + raise ValidationError("Invalid provider model") + self._provider_model = model + except LookupError: + raise ValidationError("Invalid provider model") + return fq_model_name + + def validate_provider(self, provider: dict) -> dict: + """Validate provider data""" + # ensure the model has been validated + self.validate_provider_model(self.initial_data["provider_model"]) + model_serializer = self._provider_model().serializer(data=provider) + model_serializer.is_valid(raise_exception=True) + return model_serializer.validated_data + + +class TransactionalApplicationView(APIView): + permission_classes = [IsAdminUser] + + @extend_schema(request=TransactionApplicationSerializer()) + def put(self, request: Request) -> Response: + data = TransactionApplicationSerializer(data=request.data) + data.is_valid(raise_exception=True) + print(data.validated_data) + + blueprint = Blueprint() + blueprint.entries.append(BlueprintEntry( + model=data.validated_data["provider_model"], + state=BlueprintEntryDesiredState.PRESENT, + identifiers={}, + id="provider", + attrs=data.validated_data["provider"], + )) + app_data = data.validated_data["app"] + app_data["provider"] = KeyOf(None, ScalarNode(value="provider")) + blueprint.entries.append(BlueprintEntry( + model="authentik_core.application", + state=BlueprintEntryDesiredState.PRESENT, + identifiers={}, + attrs=app_data, + )) + importer = Importer("", {}) + return Response(status=200) diff --git a/authentik/core/urls.py b/authentik/core/urls.py index c9aa748c5..8e16e56bd 100644 --- a/authentik/core/urls.py +++ b/authentik/core/urls.py @@ -8,6 +8,7 @@ from django.views.decorators.csrf import ensure_csrf_cookie from django.views.generic import RedirectView from authentik.core.api.applications import ApplicationViewSet +from authentik.core.api.applications_transactional import TransactionalApplicationView from authentik.core.api.authenticated_sessions import AuthenticatedSessionViewSet from authentik.core.api.devices import AdminDeviceViewSet, DeviceViewSet from authentik.core.api.groups import GroupViewSet @@ -70,6 +71,11 @@ urlpatterns = [ api_urlpatterns = [ ("core/authenticated_sessions", AuthenticatedSessionViewSet), ("core/applications", ApplicationViewSet), + path( + "core/applications/create_transactional/", + TransactionalApplicationView.as_view(), + name="core-apps-transactional", + ), ("core/groups", GroupViewSet), ("core/users", UserViewSet), ("core/tokens", TokenViewSet), diff --git a/authentik/flows/views/executor.py b/authentik/flows/views/executor.py index 1279940b6..758597e27 100644 --- a/authentik/flows/views/executor.py +++ b/authentik/flows/views/executor.py @@ -73,40 +73,24 @@ QS_QUERY = "query" def challenge_types(): - """This is a workaround for PolymorphicProxySerializer not accepting a callable for - `serializers`. This function returns a class which is an iterator, which returns the + """This function returns a class which is an iterator, which returns the subclasses of Challenge, and Challenge itself.""" - - class Inner(dict): - """dummy class with custom callback on .items()""" - - def items(self): - mapping = {} - classes = all_subclasses(Challenge) - classes.remove(WithUserInfoChallenge) - for cls in classes: - mapping[cls().fields["component"].default] = cls - return mapping.items() - - return Inner() + mapping = {} + classes = all_subclasses(Challenge) + classes.remove(WithUserInfoChallenge) + for cls in classes: + mapping[cls().fields["component"].default] = cls + return mapping def challenge_response_types(): - """This is a workaround for PolymorphicProxySerializer not accepting a callable for - `serializers`. This function returns a class which is an iterator, which returns the + """This function returns a class which is an iterator, which returns the subclasses of Challenge, and Challenge itself.""" - - class Inner(dict): - """dummy class with custom callback on .items()""" - - def items(self): - mapping = {} - classes = all_subclasses(ChallengeResponse) - for cls in classes: - mapping[cls(stage=None).fields["component"].default] = cls - return mapping.items() - - return Inner() + mapping = {} + classes = all_subclasses(ChallengeResponse) + for cls in classes: + mapping[cls(stage=None).fields["component"].default] = cls + return mapping class InvalidStageError(SentryIgnoredException): @@ -264,7 +248,7 @@ class FlowExecutorView(APIView): responses={ 200: PolymorphicProxySerializer( component_name="ChallengeTypes", - serializers=challenge_types(), + serializers=challenge_types, resource_type_field_name="component", ), }, @@ -304,13 +288,13 @@ class FlowExecutorView(APIView): responses={ 200: PolymorphicProxySerializer( component_name="ChallengeTypes", - serializers=challenge_types(), + serializers=challenge_types, resource_type_field_name="component", ), }, request=PolymorphicProxySerializer( component_name="FlowChallengeResponse", - serializers=challenge_response_types(), + serializers=challenge_response_types, resource_type_field_name="component", ), parameters=[ diff --git a/schema.yml b/schema.yml index f6320501b..6ad851120 100644 --- a/schema.yml +++ b/schema.yml @@ -3088,6 +3088,34 @@ paths: schema: $ref: '#/components/schemas/GenericError' description: '' + /core/applications/create_transactional/: + put: + operationId: core_applications_create_transactional_update + tags: + - core + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/TransactionApplicationRequest' + required: true + security: + - authentik: [] + responses: + '200': + description: No response body + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' /core/authenticated_sessions/: get: operationId: core_authenticated_sessions_list @@ -37502,6 +37530,22 @@ components: description: |- * `twilio` - Twilio * `generic` - Generic + ProviderModelEnum: + enum: + - authentik_providers_ldap.ldapprovider + - authentik_providers_oauth2.oauth2provider + - authentik_providers_proxy.proxyprovider + - authentik_providers_radius.radiusprovider + - authentik_providers_saml.samlprovider + - authentik_providers_scim.scimprovider + type: string + description: |- + * `authentik_providers_ldap.ldapprovider` - authentik_providers_ldap.ldapprovider + * `authentik_providers_oauth2.oauth2provider` - authentik_providers_oauth2.oauth2provider + * `authentik_providers_proxy.proxyprovider` - authentik_providers_proxy.proxyprovider + * `authentik_providers_radius.radiusprovider` - authentik_providers_radius.radiusprovider + * `authentik_providers_saml.samlprovider` - authentik_providers_saml.samlprovider + * `authentik_providers_scim.scimprovider` - authentik_providers_scim.scimprovider ProviderRequest: type: object description: Provider Serializer @@ -39913,6 +39957,20 @@ components: readOnly: true required: - key + TransactionApplicationRequest: + type: object + description: Serializer for creating a provider and an application in one transaction + properties: + app: + $ref: '#/components/schemas/ApplicationRequest' + provider_model: + $ref: '#/components/schemas/ProviderModelEnum' + provider: + $ref: '#/components/schemas/modelRequest' + required: + - app + - provider + - provider_model TypeCreate: type: object description: Types of an object that can be created @@ -40840,6 +40898,23 @@ components: type: integer required: - count + modelRequest: + oneOf: + - $ref: '#/components/schemas/LDAPProviderRequest' + - $ref: '#/components/schemas/OAuth2ProviderRequest' + - $ref: '#/components/schemas/ProxyProviderRequest' + - $ref: '#/components/schemas/RadiusProviderRequest' + - $ref: '#/components/schemas/SAMLProviderRequest' + - $ref: '#/components/schemas/SCIMProviderRequest' + discriminator: + propertyName: provider_model + mapping: + authentik_providers_ldap.ldapprovider: '#/components/schemas/LDAPProviderRequest' + authentik_providers_oauth2.oauth2provider: '#/components/schemas/OAuth2ProviderRequest' + authentik_providers_proxy.proxyprovider: '#/components/schemas/ProxyProviderRequest' + authentik_providers_radius.radiusprovider: '#/components/schemas/RadiusProviderRequest' + authentik_providers_saml.samlprovider: '#/components/schemas/SAMLProviderRequest' + authentik_providers_scim.scimprovider: '#/components/schemas/SCIMProviderRequest' securitySchemes: authentik: type: apiKey