Merge branch 'dev' into web/sidebar-with-live-content-3
* dev: (131 commits) web: Replace calls to `rootInterface()?.tenant?` with a contextual `this.tenant` object (#7778) web: abstract `rootInterface()?.config?.capabilities.includes()` into `.can()` (#7737) web: update some locale details (#8090) web: bump the eslint group in /web with 2 updates (#8082) web: bump rollup from 4.9.2 to 4.9.4 in /web (#8083) core: bump github.com/redis/go-redis/v9 from 9.3.1 to 9.4.0 (#8085) web: bump the eslint group in /tests/wdio with 2 updates (#8086) website: bump @types/react from 18.2.46 to 18.2.47 in /website (#8088) stages/user_login: only set last_ip in session if a binding is given (#8074) providers/oauth2: fix missing nonce in token endpoint not being saved (#8073) core: bump goauthentik.io/api/v3 from 3.2023105.3 to 3.2023105.5 (#8066) providers/oauth2: fix missing nonce in id_token (#8072) rbac: fix error when looking up permissions for now uninstalled apps (#8068) web/flows: fix device picker incorrect foreground color (#8067) translate: Updates for file web/xliff/en.xlf in zh_CN (#8061) translate: Updates for file web/xliff/en.xlf in zh-Hans (#8062) website: bump postcss from 8.4.32 to 8.4.33 in /website (#8063) web: bump the sentry group in /web with 2 updates (#8064) core: bump golang.org/x/sync from 0.5.0 to 0.6.0 (#8065) website/docs: add link to our example flows (#8052) ...
This commit is contained in:
commit
9768684c3c
|
@ -1,5 +1,5 @@
|
||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 2023.10.4
|
current_version = 2023.10.5
|
||||||
tag = True
|
tag = True
|
||||||
commit = True
|
commit = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||||
|
|
|
@ -9,3 +9,4 @@ blueprints/local
|
||||||
.git
|
.git
|
||||||
!gen-ts-api/node_modules
|
!gen-ts-api/node_modules
|
||||||
!gen-ts-api/dist/**
|
!gen-ts-api/dist/**
|
||||||
|
!gen-go-api/
|
||||||
|
|
1
.github/codespell-words.txt
vendored
1
.github/codespell-words.txt
vendored
|
@ -2,3 +2,4 @@ keypair
|
||||||
keypairs
|
keypairs
|
||||||
hass
|
hass
|
||||||
warmup
|
warmup
|
||||||
|
ontext
|
||||||
|
|
44
.github/workflows/ci-main.yml
vendored
44
.github/workflows/ci-main.yml
vendored
|
@ -61,10 +61,6 @@ jobs:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
- name: Setup authentik env
|
|
||||||
uses: ./.github/actions/setup
|
|
||||||
with:
|
|
||||||
postgresql_version: ${{ matrix.psql }}
|
|
||||||
- name: checkout stable
|
- name: checkout stable
|
||||||
run: |
|
run: |
|
||||||
# Delete all poetry envs
|
# Delete all poetry envs
|
||||||
|
@ -76,7 +72,7 @@ jobs:
|
||||||
git checkout version/$(python -c "from authentik import __version__; print(__version__)")
|
git checkout version/$(python -c "from authentik import __version__; print(__version__)")
|
||||||
rm -rf .github/ scripts/
|
rm -rf .github/ scripts/
|
||||||
mv ../.github ../scripts .
|
mv ../.github ../scripts .
|
||||||
- name: Setup authentik env (ensure stable deps are installed)
|
- name: Setup authentik env (stable)
|
||||||
uses: ./.github/actions/setup
|
uses: ./.github/actions/setup
|
||||||
with:
|
with:
|
||||||
postgresql_version: ${{ matrix.psql }}
|
postgresql_version: ${{ matrix.psql }}
|
||||||
|
@ -90,15 +86,20 @@ jobs:
|
||||||
git clean -d -fx .
|
git clean -d -fx .
|
||||||
git checkout $GITHUB_SHA
|
git checkout $GITHUB_SHA
|
||||||
# Delete previous poetry env
|
# Delete previous poetry env
|
||||||
rm -rf $(poetry env info --path)
|
rm -rf /home/runner/.cache/pypoetry/virtualenvs/*
|
||||||
- name: Setup authentik env (ensure latest deps are installed)
|
- name: Setup authentik env (ensure latest deps are installed)
|
||||||
uses: ./.github/actions/setup
|
uses: ./.github/actions/setup
|
||||||
with:
|
with:
|
||||||
postgresql_version: ${{ matrix.psql }}
|
postgresql_version: ${{ matrix.psql }}
|
||||||
- name: migrate to latest
|
- name: migrate to latest
|
||||||
run: |
|
run: |
|
||||||
poetry install
|
|
||||||
poetry run python -m lifecycle.migrate
|
poetry run python -m lifecycle.migrate
|
||||||
|
- name: run tests
|
||||||
|
env:
|
||||||
|
# Test in the main database that we just migrated from the previous stable version
|
||||||
|
AUTHENTIK_POSTGRESQL__TEST__NAME: authentik
|
||||||
|
run: |
|
||||||
|
poetry run make test
|
||||||
test-unittest:
|
test-unittest:
|
||||||
name: test-unittest - PostgreSQL ${{ matrix.psql }}
|
name: test-unittest - PostgreSQL ${{ matrix.psql }}
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
@ -248,12 +249,6 @@ jobs:
|
||||||
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
|
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
|
||||||
cache-from: type=gha
|
cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max
|
cache-to: type=gha,mode=max
|
||||||
- name: Comment on PR
|
|
||||||
if: github.event_name == 'pull_request'
|
|
||||||
continue-on-error: true
|
|
||||||
uses: ./.github/actions/comment-pr-instructions
|
|
||||||
with:
|
|
||||||
tag: gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}
|
|
||||||
build-arm64:
|
build-arm64:
|
||||||
needs: ci-core-mark
|
needs: ci-core-mark
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
@ -302,3 +297,26 @@ jobs:
|
||||||
platforms: linux/arm64
|
platforms: linux/arm64
|
||||||
cache-from: type=gha
|
cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max
|
cache-to: type=gha,mode=max
|
||||||
|
pr-comment:
|
||||||
|
needs:
|
||||||
|
- build
|
||||||
|
- build-arm64
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
if: ${{ github.event_name == 'pull_request' }}
|
||||||
|
permissions:
|
||||||
|
# Needed to write comments on PRs
|
||||||
|
pull-requests: write
|
||||||
|
timeout-minutes: 120
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ github.event.pull_request.head.sha }}
|
||||||
|
- name: prepare variables
|
||||||
|
uses: ./.github/actions/docker-push-variables
|
||||||
|
id: ev
|
||||||
|
env:
|
||||||
|
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||||
|
- name: Comment on PR
|
||||||
|
uses: ./.github/actions/comment-pr-instructions
|
||||||
|
with:
|
||||||
|
tag: gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}
|
||||||
|
|
2
.github/workflows/ci-outpost.yml
vendored
2
.github/workflows/ci-outpost.yml
vendored
|
@ -65,6 +65,7 @@ jobs:
|
||||||
- proxy
|
- proxy
|
||||||
- ldap
|
- ldap
|
||||||
- radius
|
- radius
|
||||||
|
- rac
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
permissions:
|
permissions:
|
||||||
# Needed to upload contianer images to ghcr.io
|
# Needed to upload contianer images to ghcr.io
|
||||||
|
@ -119,6 +120,7 @@ jobs:
|
||||||
- proxy
|
- proxy
|
||||||
- ldap
|
- ldap
|
||||||
- radius
|
- radius
|
||||||
|
- rac
|
||||||
goos: [linux]
|
goos: [linux]
|
||||||
goarch: [amd64, arm64]
|
goarch: [amd64, arm64]
|
||||||
steps:
|
steps:
|
||||||
|
|
6
.github/workflows/codeql-analysis.yml
vendored
6
.github/workflows/codeql-analysis.yml
vendored
|
@ -27,10 +27,10 @@ jobs:
|
||||||
- name: Setup authentik env
|
- name: Setup authentik env
|
||||||
uses: ./.github/actions/setup
|
uses: ./.github/actions/setup
|
||||||
- name: Initialize CodeQL
|
- name: Initialize CodeQL
|
||||||
uses: github/codeql-action/init@v2
|
uses: github/codeql-action/init@v3
|
||||||
with:
|
with:
|
||||||
languages: ${{ matrix.language }}
|
languages: ${{ matrix.language }}
|
||||||
- name: Autobuild
|
- name: Autobuild
|
||||||
uses: github/codeql-action/autobuild@v2
|
uses: github/codeql-action/autobuild@v3
|
||||||
- name: Perform CodeQL Analysis
|
- name: Perform CodeQL Analysis
|
||||||
uses: github/codeql-action/analyze@v2
|
uses: github/codeql-action/analyze@v3
|
||||||
|
|
1
.github/workflows/release-publish.yml
vendored
1
.github/workflows/release-publish.yml
vendored
|
@ -65,6 +65,7 @@ jobs:
|
||||||
- proxy
|
- proxy
|
||||||
- ldap
|
- ldap
|
||||||
- radius
|
- radius
|
||||||
|
- rac
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
|
|
|
@ -71,7 +71,7 @@ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
|
||||||
# Stage 4: MaxMind GeoIP
|
# Stage 4: MaxMind GeoIP
|
||||||
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v6.0 as geoip
|
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v6.0 as geoip
|
||||||
|
|
||||||
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City"
|
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN"
|
||||||
ENV GEOIPUPDATE_VERBOSE="true"
|
ENV GEOIPUPDATE_VERBOSE="true"
|
||||||
ENV GEOIPUPDATE_ACCOUNT_ID_FILE="/run/secrets/GEOIPUPDATE_ACCOUNT_ID"
|
ENV GEOIPUPDATE_ACCOUNT_ID_FILE="/run/secrets/GEOIPUPDATE_ACCOUNT_ID"
|
||||||
ENV GEOIPUPDATE_LICENSE_KEY_FILE="/run/secrets/GEOIPUPDATE_LICENSE_KEY"
|
ENV GEOIPUPDATE_LICENSE_KEY_FILE="/run/secrets/GEOIPUPDATE_LICENSE_KEY"
|
||||||
|
|
7
Makefile
7
Makefile
|
@ -58,7 +58,7 @@ test: ## Run the server tests and produce a coverage report (locally)
|
||||||
lint-fix: ## Lint and automatically fix errors in the python source code. Reports spelling errors.
|
lint-fix: ## Lint and automatically fix errors in the python source code. Reports spelling errors.
|
||||||
isort $(PY_SOURCES)
|
isort $(PY_SOURCES)
|
||||||
black $(PY_SOURCES)
|
black $(PY_SOURCES)
|
||||||
ruff $(PY_SOURCES)
|
ruff --fix $(PY_SOURCES)
|
||||||
codespell -w $(CODESPELL_ARGS)
|
codespell -w $(CODESPELL_ARGS)
|
||||||
|
|
||||||
lint: ## Lint the python and golang sources
|
lint: ## Lint the python and golang sources
|
||||||
|
@ -115,8 +115,9 @@ gen-diff: ## (Release) generate the changelog diff between the current schema a
|
||||||
npx prettier --write diff.md
|
npx prettier --write diff.md
|
||||||
|
|
||||||
gen-clean:
|
gen-clean:
|
||||||
rm -rf web/api/src/
|
rm -rf gen-go-api/
|
||||||
rm -rf api/
|
rm -rf gen-ts-api/
|
||||||
|
rm -rf web/node_modules/@goauthentik/api/
|
||||||
|
|
||||||
gen-client-ts: ## Build and install the authentik API for Typescript into the authentik UI Application
|
gen-client-ts: ## Build and install the authentik API for Typescript into the authentik UI Application
|
||||||
docker run \
|
docker run \
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
from os import environ
|
from os import environ
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
__version__ = "2023.10.4"
|
__version__ = "2023.10.5"
|
||||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,8 @@ from authentik.blueprints.tests import reconcile_app
|
||||||
from authentik.core.models import Token, TokenIntents, User, UserTypes
|
from authentik.core.models import Token, TokenIntents, User, UserTypes
|
||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
|
from authentik.outposts.apps import MANAGED_OUTPOST
|
||||||
|
from authentik.outposts.models import Outpost
|
||||||
from authentik.providers.oauth2.constants import SCOPE_AUTHENTIK_API
|
from authentik.providers.oauth2.constants import SCOPE_AUTHENTIK_API
|
||||||
from authentik.providers.oauth2.models import AccessToken, OAuth2Provider
|
from authentik.providers.oauth2.models import AccessToken, OAuth2Provider
|
||||||
|
|
||||||
|
@ -49,8 +51,12 @@ class TestAPIAuth(TestCase):
|
||||||
with self.assertRaises(AuthenticationFailed):
|
with self.assertRaises(AuthenticationFailed):
|
||||||
bearer_auth(f"Bearer {token.key}".encode())
|
bearer_auth(f"Bearer {token.key}".encode())
|
||||||
|
|
||||||
def test_managed_outpost(self):
|
@reconcile_app("authentik_outposts")
|
||||||
|
def test_managed_outpost_fail(self):
|
||||||
"""Test managed outpost"""
|
"""Test managed outpost"""
|
||||||
|
outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
|
||||||
|
outpost.user.delete()
|
||||||
|
outpost.delete()
|
||||||
with self.assertRaises(AuthenticationFailed):
|
with self.assertRaises(AuthenticationFailed):
|
||||||
bearer_auth(f"Bearer {settings.SECRET_KEY}".encode())
|
bearer_auth(f"Bearer {settings.SECRET_KEY}".encode())
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ from rest_framework.response import Response
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
from authentik.core.api.utils import PassiveSerializer
|
from authentik.core.api.utils import PassiveSerializer
|
||||||
from authentik.events.geo import GEOIP_READER
|
from authentik.events.context_processors.base import get_context_processors
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
|
|
||||||
capabilities = Signal()
|
capabilities = Signal()
|
||||||
|
@ -30,6 +30,7 @@ class Capabilities(models.TextChoices):
|
||||||
|
|
||||||
CAN_SAVE_MEDIA = "can_save_media"
|
CAN_SAVE_MEDIA = "can_save_media"
|
||||||
CAN_GEO_IP = "can_geo_ip"
|
CAN_GEO_IP = "can_geo_ip"
|
||||||
|
CAN_ASN = "can_asn"
|
||||||
CAN_IMPERSONATE = "can_impersonate"
|
CAN_IMPERSONATE = "can_impersonate"
|
||||||
CAN_DEBUG = "can_debug"
|
CAN_DEBUG = "can_debug"
|
||||||
IS_ENTERPRISE = "is_enterprise"
|
IS_ENTERPRISE = "is_enterprise"
|
||||||
|
@ -68,8 +69,9 @@ class ConfigView(APIView):
|
||||||
deb_test = settings.DEBUG or settings.TEST
|
deb_test = settings.DEBUG or settings.TEST
|
||||||
if Path(settings.MEDIA_ROOT).is_mount() or deb_test:
|
if Path(settings.MEDIA_ROOT).is_mount() or deb_test:
|
||||||
caps.append(Capabilities.CAN_SAVE_MEDIA)
|
caps.append(Capabilities.CAN_SAVE_MEDIA)
|
||||||
if GEOIP_READER.enabled:
|
for processor in get_context_processors():
|
||||||
caps.append(Capabilities.CAN_GEO_IP)
|
if cap := processor.capability():
|
||||||
|
caps.append(cap)
|
||||||
if CONFIG.get_bool("impersonation"):
|
if CONFIG.get_bool("impersonation"):
|
||||||
caps.append(Capabilities.CAN_IMPERSONATE)
|
caps.append(Capabilities.CAN_IMPERSONATE)
|
||||||
if settings.DEBUG: # pragma: no cover
|
if settings.DEBUG: # pragma: no cover
|
||||||
|
|
|
@ -3,7 +3,7 @@ from django.utils.translation import gettext_lazy as _
|
||||||
from drf_spectacular.utils import extend_schema, inline_serializer
|
from drf_spectacular.utils import extend_schema, inline_serializer
|
||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
from rest_framework.exceptions import ValidationError
|
from rest_framework.exceptions import ValidationError
|
||||||
from rest_framework.fields import CharField, DateTimeField, JSONField
|
from rest_framework.fields import CharField, DateTimeField
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.serializers import ListSerializer, ModelSerializer
|
from rest_framework.serializers import ListSerializer, ModelSerializer
|
||||||
|
@ -15,7 +15,7 @@ from authentik.blueprints.v1.importer import Importer
|
||||||
from authentik.blueprints.v1.oci import OCI_PREFIX
|
from authentik.blueprints.v1.oci import OCI_PREFIX
|
||||||
from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_find_dict
|
from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_find_dict
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import PassiveSerializer
|
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||||
|
|
||||||
|
|
||||||
class ManagedSerializer:
|
class ManagedSerializer:
|
||||||
|
@ -28,7 +28,7 @@ class MetadataSerializer(PassiveSerializer):
|
||||||
"""Serializer for blueprint metadata"""
|
"""Serializer for blueprint metadata"""
|
||||||
|
|
||||||
name = CharField()
|
name = CharField()
|
||||||
labels = JSONField()
|
labels = JSONDictField()
|
||||||
|
|
||||||
|
|
||||||
class BlueprintInstanceSerializer(ModelSerializer):
|
class BlueprintInstanceSerializer(ModelSerializer):
|
||||||
|
|
|
@ -40,7 +40,7 @@ class ManagedAppConfig(AppConfig):
|
||||||
meth()
|
meth()
|
||||||
self._logger.debug("Successfully reconciled", name=name)
|
self._logger.debug("Successfully reconciled", name=name)
|
||||||
except (DatabaseError, ProgrammingError, InternalError) as exc:
|
except (DatabaseError, ProgrammingError, InternalError) as exc:
|
||||||
self._logger.debug("Failed to run reconcile", name=name, exc=exc)
|
self._logger.warning("Failed to run reconcile", name=name, exc=exc)
|
||||||
|
|
||||||
|
|
||||||
class AuthentikBlueprintsConfig(ManagedAppConfig):
|
class AuthentikBlueprintsConfig(ManagedAppConfig):
|
||||||
|
|
|
@ -2,11 +2,11 @@
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from rest_framework.exceptions import ValidationError
|
from rest_framework.exceptions import ValidationError
|
||||||
from rest_framework.fields import BooleanField, JSONField
|
from rest_framework.fields import BooleanField
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.blueprints.v1.meta.registry import BaseMetaModel, MetaResult, registry
|
from authentik.blueprints.v1.meta.registry import BaseMetaModel, MetaResult, registry
|
||||||
from authentik.core.api.utils import PassiveSerializer, is_dict
|
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from authentik.blueprints.models import BlueprintInstance
|
from authentik.blueprints.models import BlueprintInstance
|
||||||
|
@ -17,7 +17,7 @@ LOGGER = get_logger()
|
||||||
class ApplyBlueprintMetaSerializer(PassiveSerializer):
|
class ApplyBlueprintMetaSerializer(PassiveSerializer):
|
||||||
"""Serializer for meta apply blueprint model"""
|
"""Serializer for meta apply blueprint model"""
|
||||||
|
|
||||||
identifiers = JSONField(validators=[is_dict])
|
identifiers = JSONDictField()
|
||||||
required = BooleanField(default=True)
|
required = BooleanField(default=True)
|
||||||
|
|
||||||
# We cannot override `instance` as that will confuse rest_framework
|
# We cannot override `instance` as that will confuse rest_framework
|
||||||
|
|
|
@ -14,7 +14,8 @@ from ua_parser import user_agent_parser
|
||||||
from authentik.api.authorization import OwnerSuperuserPermissions
|
from authentik.api.authorization import OwnerSuperuserPermissions
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.models import AuthenticatedSession
|
from authentik.core.models import AuthenticatedSession
|
||||||
from authentik.events.geo import GEOIP_READER, GeoIPDict
|
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR, ASNDict
|
||||||
|
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR, GeoIPDict
|
||||||
|
|
||||||
|
|
||||||
class UserAgentDeviceDict(TypedDict):
|
class UserAgentDeviceDict(TypedDict):
|
||||||
|
@ -59,6 +60,7 @@ class AuthenticatedSessionSerializer(ModelSerializer):
|
||||||
current = SerializerMethodField()
|
current = SerializerMethodField()
|
||||||
user_agent = SerializerMethodField()
|
user_agent = SerializerMethodField()
|
||||||
geo_ip = SerializerMethodField()
|
geo_ip = SerializerMethodField()
|
||||||
|
asn = SerializerMethodField()
|
||||||
|
|
||||||
def get_current(self, instance: AuthenticatedSession) -> bool:
|
def get_current(self, instance: AuthenticatedSession) -> bool:
|
||||||
"""Check if session is currently active session"""
|
"""Check if session is currently active session"""
|
||||||
|
@ -70,8 +72,12 @@ class AuthenticatedSessionSerializer(ModelSerializer):
|
||||||
return user_agent_parser.Parse(instance.last_user_agent)
|
return user_agent_parser.Parse(instance.last_user_agent)
|
||||||
|
|
||||||
def get_geo_ip(self, instance: AuthenticatedSession) -> Optional[GeoIPDict]: # pragma: no cover
|
def get_geo_ip(self, instance: AuthenticatedSession) -> Optional[GeoIPDict]: # pragma: no cover
|
||||||
"""Get parsed user agent"""
|
"""Get GeoIP Data"""
|
||||||
return GEOIP_READER.city_dict(instance.last_ip)
|
return GEOIP_CONTEXT_PROCESSOR.city_dict(instance.last_ip)
|
||||||
|
|
||||||
|
def get_asn(self, instance: AuthenticatedSession) -> Optional[ASNDict]: # pragma: no cover
|
||||||
|
"""Get ASN Data"""
|
||||||
|
return ASN_CONTEXT_PROCESSOR.asn_dict(instance.last_ip)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = AuthenticatedSession
|
model = AuthenticatedSession
|
||||||
|
@ -80,6 +86,7 @@ class AuthenticatedSessionSerializer(ModelSerializer):
|
||||||
"current",
|
"current",
|
||||||
"user_agent",
|
"user_agent",
|
||||||
"geo_ip",
|
"geo_ip",
|
||||||
|
"asn",
|
||||||
"user",
|
"user",
|
||||||
"last_ip",
|
"last_ip",
|
||||||
"last_user_agent",
|
"last_user_agent",
|
||||||
|
|
|
@ -8,7 +8,7 @@ from django_filters.filterset import FilterSet
|
||||||
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
from drf_spectacular.utils import OpenApiResponse, extend_schema
|
||||||
from guardian.shortcuts import get_objects_for_user
|
from guardian.shortcuts import get_objects_for_user
|
||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
from rest_framework.fields import CharField, IntegerField, JSONField
|
from rest_framework.fields import CharField, IntegerField
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.serializers import ListSerializer, ModelSerializer, ValidationError
|
from rest_framework.serializers import ListSerializer, ModelSerializer, ValidationError
|
||||||
|
@ -16,7 +16,7 @@ from rest_framework.viewsets import ModelViewSet
|
||||||
|
|
||||||
from authentik.api.decorators import permission_required
|
from authentik.api.decorators import permission_required
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import PassiveSerializer, is_dict
|
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||||
from authentik.core.models import Group, User
|
from authentik.core.models import Group, User
|
||||||
from authentik.rbac.api.roles import RoleSerializer
|
from authentik.rbac.api.roles import RoleSerializer
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ from authentik.rbac.api.roles import RoleSerializer
|
||||||
class GroupMemberSerializer(ModelSerializer):
|
class GroupMemberSerializer(ModelSerializer):
|
||||||
"""Stripped down user serializer to show relevant users for groups"""
|
"""Stripped down user serializer to show relevant users for groups"""
|
||||||
|
|
||||||
attributes = JSONField(validators=[is_dict], required=False)
|
attributes = JSONDictField(required=False)
|
||||||
uid = CharField(read_only=True)
|
uid = CharField(read_only=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
@ -44,7 +44,7 @@ class GroupMemberSerializer(ModelSerializer):
|
||||||
class GroupSerializer(ModelSerializer):
|
class GroupSerializer(ModelSerializer):
|
||||||
"""Group Serializer"""
|
"""Group Serializer"""
|
||||||
|
|
||||||
attributes = JSONField(validators=[is_dict], required=False)
|
attributes = JSONDictField(required=False)
|
||||||
users_obj = ListSerializer(
|
users_obj = ListSerializer(
|
||||||
child=GroupMemberSerializer(), read_only=True, source="users", required=False
|
child=GroupMemberSerializer(), read_only=True, source="users", required=False
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,6 +19,7 @@ from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer
|
from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer
|
||||||
from authentik.core.expression.evaluator import PropertyMappingEvaluator
|
from authentik.core.expression.evaluator import PropertyMappingEvaluator
|
||||||
from authentik.core.models import PropertyMapping
|
from authentik.core.models import PropertyMapping
|
||||||
|
from authentik.enterprise.apps import EnterpriseConfig
|
||||||
from authentik.events.utils import sanitize_item
|
from authentik.events.utils import sanitize_item
|
||||||
from authentik.lib.utils.reflection import all_subclasses
|
from authentik.lib.utils.reflection import all_subclasses
|
||||||
from authentik.policies.api.exec import PolicyTestSerializer
|
from authentik.policies.api.exec import PolicyTestSerializer
|
||||||
|
@ -95,6 +96,7 @@ class PropertyMappingViewSet(
|
||||||
"description": subclass.__doc__,
|
"description": subclass.__doc__,
|
||||||
"component": subclass().component,
|
"component": subclass().component,
|
||||||
"model_name": subclass._meta.model_name,
|
"model_name": subclass._meta.model_name,
|
||||||
|
"requires_enterprise": isinstance(subclass._meta.app_config, EnterpriseConfig),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return Response(TypeCreateSerializer(data, many=True).data)
|
return Response(TypeCreateSerializer(data, many=True).data)
|
||||||
|
|
|
@ -16,6 +16,7 @@ from rest_framework.viewsets import GenericViewSet
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import MetaNameSerializer, TypeCreateSerializer
|
from authentik.core.api.utils import MetaNameSerializer, TypeCreateSerializer
|
||||||
from authentik.core.models import Provider
|
from authentik.core.models import Provider
|
||||||
|
from authentik.enterprise.apps import EnterpriseConfig
|
||||||
from authentik.lib.utils.reflection import all_subclasses
|
from authentik.lib.utils.reflection import all_subclasses
|
||||||
|
|
||||||
|
|
||||||
|
@ -113,6 +114,7 @@ class ProviderViewSet(
|
||||||
"description": subclass.__doc__,
|
"description": subclass.__doc__,
|
||||||
"component": subclass().component,
|
"component": subclass().component,
|
||||||
"model_name": subclass._meta.model_name,
|
"model_name": subclass._meta.model_name,
|
||||||
|
"requires_enterprise": isinstance(subclass._meta.app_config, EnterpriseConfig),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
data.append(
|
data.append(
|
||||||
|
|
|
@ -32,13 +32,7 @@ from drf_spectacular.utils import (
|
||||||
)
|
)
|
||||||
from guardian.shortcuts import get_anonymous_user, get_objects_for_user
|
from guardian.shortcuts import get_anonymous_user, get_objects_for_user
|
||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
from rest_framework.fields import (
|
from rest_framework.fields import CharField, IntegerField, ListField, SerializerMethodField
|
||||||
CharField,
|
|
||||||
IntegerField,
|
|
||||||
JSONField,
|
|
||||||
ListField,
|
|
||||||
SerializerMethodField,
|
|
||||||
)
|
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.serializers import (
|
from rest_framework.serializers import (
|
||||||
|
@ -57,7 +51,7 @@ from authentik.admin.api.metrics import CoordinateSerializer
|
||||||
from authentik.api.decorators import permission_required
|
from authentik.api.decorators import permission_required
|
||||||
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
|
from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import LinkSerializer, PassiveSerializer, is_dict
|
from authentik.core.api.utils import JSONDictField, LinkSerializer, PassiveSerializer
|
||||||
from authentik.core.middleware import (
|
from authentik.core.middleware import (
|
||||||
SESSION_KEY_IMPERSONATE_ORIGINAL_USER,
|
SESSION_KEY_IMPERSONATE_ORIGINAL_USER,
|
||||||
SESSION_KEY_IMPERSONATE_USER,
|
SESSION_KEY_IMPERSONATE_USER,
|
||||||
|
@ -89,7 +83,7 @@ LOGGER = get_logger()
|
||||||
class UserGroupSerializer(ModelSerializer):
|
class UserGroupSerializer(ModelSerializer):
|
||||||
"""Simplified Group Serializer for user's groups"""
|
"""Simplified Group Serializer for user's groups"""
|
||||||
|
|
||||||
attributes = JSONField(required=False)
|
attributes = JSONDictField(required=False)
|
||||||
parent_name = CharField(source="parent.name", read_only=True)
|
parent_name = CharField(source="parent.name", read_only=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
@ -110,7 +104,7 @@ class UserSerializer(ModelSerializer):
|
||||||
|
|
||||||
is_superuser = BooleanField(read_only=True)
|
is_superuser = BooleanField(read_only=True)
|
||||||
avatar = CharField(read_only=True)
|
avatar = CharField(read_only=True)
|
||||||
attributes = JSONField(validators=[is_dict], required=False)
|
attributes = JSONDictField(required=False)
|
||||||
groups = PrimaryKeyRelatedField(
|
groups = PrimaryKeyRelatedField(
|
||||||
allow_empty=True, many=True, source="ak_groups", queryset=Group.objects.all(), default=list
|
allow_empty=True, many=True, source="ak_groups", queryset=Group.objects.all(), default=list
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,7 +2,10 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from django.db.models import Model
|
from django.db.models import Model
|
||||||
from rest_framework.fields import CharField, IntegerField, JSONField
|
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
|
||||||
|
from drf_spectacular.plumbing import build_basic_type
|
||||||
|
from drf_spectacular.types import OpenApiTypes
|
||||||
|
from rest_framework.fields import BooleanField, CharField, IntegerField, JSONField
|
||||||
from rest_framework.serializers import Serializer, SerializerMethodField, ValidationError
|
from rest_framework.serializers import Serializer, SerializerMethodField, ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,6 +16,21 @@ def is_dict(value: Any):
|
||||||
raise ValidationError("Value must be a dictionary, and not have any duplicate keys.")
|
raise ValidationError("Value must be a dictionary, and not have any duplicate keys.")
|
||||||
|
|
||||||
|
|
||||||
|
class JSONDictField(JSONField):
|
||||||
|
"""JSON Field which only allows dictionaries"""
|
||||||
|
|
||||||
|
default_validators = [is_dict]
|
||||||
|
|
||||||
|
|
||||||
|
class JSONExtension(OpenApiSerializerFieldExtension):
|
||||||
|
"""Generate API Schema for JSON fields as"""
|
||||||
|
|
||||||
|
target_class = "authentik.core.api.utils.JSONDictField"
|
||||||
|
|
||||||
|
def map_serializer_field(self, auto_schema, direction):
|
||||||
|
return build_basic_type(OpenApiTypes.OBJECT)
|
||||||
|
|
||||||
|
|
||||||
class PassiveSerializer(Serializer):
|
class PassiveSerializer(Serializer):
|
||||||
"""Base serializer class which doesn't implement create/update methods"""
|
"""Base serializer class which doesn't implement create/update methods"""
|
||||||
|
|
||||||
|
@ -26,7 +44,7 @@ class PassiveSerializer(Serializer):
|
||||||
class PropertyMappingPreviewSerializer(PassiveSerializer):
|
class PropertyMappingPreviewSerializer(PassiveSerializer):
|
||||||
"""Preview how the current user is mapped via the property mappings selected in a provider"""
|
"""Preview how the current user is mapped via the property mappings selected in a provider"""
|
||||||
|
|
||||||
preview = JSONField(read_only=True)
|
preview = JSONDictField(read_only=True)
|
||||||
|
|
||||||
|
|
||||||
class MetaNameSerializer(PassiveSerializer):
|
class MetaNameSerializer(PassiveSerializer):
|
||||||
|
@ -56,6 +74,7 @@ class TypeCreateSerializer(PassiveSerializer):
|
||||||
description = CharField(required=True)
|
description = CharField(required=True)
|
||||||
component = CharField(required=True)
|
component = CharField(required=True)
|
||||||
model_name = CharField(required=True)
|
model_name = CharField(required=True)
|
||||||
|
requires_enterprise = BooleanField(default=False)
|
||||||
|
|
||||||
|
|
||||||
class CacheSerializer(PassiveSerializer):
|
class CacheSerializer(PassiveSerializer):
|
||||||
|
|
|
@ -1,22 +1,29 @@
|
||||||
"""Channels base classes"""
|
"""Channels base classes"""
|
||||||
|
from channels.db import database_sync_to_async
|
||||||
from channels.exceptions import DenyConnection
|
from channels.exceptions import DenyConnection
|
||||||
from channels.generic.websocket import JsonWebsocketConsumer
|
|
||||||
from rest_framework.exceptions import AuthenticationFailed
|
from rest_framework.exceptions import AuthenticationFailed
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.api.authentication import bearer_auth
|
from authentik.api.authentication import bearer_auth
|
||||||
from authentik.core.models import User
|
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class AuthJsonConsumer(JsonWebsocketConsumer):
|
class TokenOutpostMiddleware:
|
||||||
"""Authorize a client with a token"""
|
"""Authorize a client with a token"""
|
||||||
|
|
||||||
user: User
|
def __init__(self, inner):
|
||||||
|
self.inner = inner
|
||||||
|
|
||||||
def connect(self):
|
async def __call__(self, scope, receive, send):
|
||||||
headers = dict(self.scope["headers"])
|
scope = dict(scope)
|
||||||
|
await self.auth(scope)
|
||||||
|
return await self.inner(scope, receive, send)
|
||||||
|
|
||||||
|
@database_sync_to_async
|
||||||
|
def auth(self, scope):
|
||||||
|
"""Authenticate request from header"""
|
||||||
|
headers = dict(scope["headers"])
|
||||||
if b"authorization" not in headers:
|
if b"authorization" not in headers:
|
||||||
LOGGER.warning("WS Request without authorization header")
|
LOGGER.warning("WS Request without authorization header")
|
||||||
raise DenyConnection()
|
raise DenyConnection()
|
||||||
|
@ -32,4 +39,4 @@ class AuthJsonConsumer(JsonWebsocketConsumer):
|
||||||
LOGGER.warning("Failed to authenticate", exc=exc)
|
LOGGER.warning("Failed to authenticate", exc=exc)
|
||||||
raise DenyConnection()
|
raise DenyConnection()
|
||||||
|
|
||||||
self.user = user
|
scope["user"] = user
|
||||||
|
|
|
@ -44,6 +44,7 @@ class PropertyMappingEvaluator(BaseEvaluator):
|
||||||
if request:
|
if request:
|
||||||
req.http_request = request
|
req.http_request = request
|
||||||
self._context["request"] = req
|
self._context["request"] = req
|
||||||
|
req.context.update(**kwargs)
|
||||||
self._context.update(**kwargs)
|
self._context.update(**kwargs)
|
||||||
self.dry_run = dry_run
|
self.dry_run = dry_run
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,6 @@ from authentik.lib.models import (
|
||||||
DomainlessFormattedURLValidator,
|
DomainlessFormattedURLValidator,
|
||||||
SerializerModel,
|
SerializerModel,
|
||||||
)
|
)
|
||||||
from authentik.lib.utils.http import get_client_ip
|
|
||||||
from authentik.policies.models import PolicyBindingModel
|
from authentik.policies.models import PolicyBindingModel
|
||||||
from authentik.root.install_id import get_install_id
|
from authentik.root.install_id import get_install_id
|
||||||
|
|
||||||
|
@ -748,12 +747,14 @@ class AuthenticatedSession(ExpiringModel):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_request(request: HttpRequest, user: User) -> Optional["AuthenticatedSession"]:
|
def from_request(request: HttpRequest, user: User) -> Optional["AuthenticatedSession"]:
|
||||||
"""Create a new session from a http request"""
|
"""Create a new session from a http request"""
|
||||||
|
from authentik.root.middleware import ClientIPMiddleware
|
||||||
|
|
||||||
if not hasattr(request, "session") or not request.session.session_key:
|
if not hasattr(request, "session") or not request.session.session_key:
|
||||||
return None
|
return None
|
||||||
return AuthenticatedSession(
|
return AuthenticatedSession(
|
||||||
session_key=request.session.session_key,
|
session_key=request.session.session_key,
|
||||||
user=user,
|
user=user,
|
||||||
last_ip=get_client_ip(request),
|
last_ip=ClientIPMiddleware.get_client_ip(request),
|
||||||
last_user_agent=request.META.get("HTTP_USER_AGENT", ""),
|
last_user_agent=request.META.get("HTTP_USER_AGENT", ""),
|
||||||
expires=request.session.get_expiry_date(),
|
expires=request.session.get_expiry_date(),
|
||||||
)
|
)
|
||||||
|
|
|
@ -44,28 +44,14 @@
|
||||||
|
|
||||||
{% block body %}
|
{% block body %}
|
||||||
<div class="pf-c-background-image">
|
<div class="pf-c-background-image">
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" class="pf-c-background-image__filter" width="0" height="0">
|
|
||||||
<filter id="image_overlay">
|
|
||||||
<feColorMatrix in="SourceGraphic" type="matrix" values="1.3 0 0 0 0 0 1.3 0 0 0 0 0 1.3 0 0 0 0 0 1 0" />
|
|
||||||
<feComponentTransfer color-interpolation-filters="sRGB" result="duotone">
|
|
||||||
<feFuncR type="table" tableValues="0.086274509803922 0.43921568627451"></feFuncR>
|
|
||||||
<feFuncG type="table" tableValues="0.086274509803922 0.43921568627451"></feFuncG>
|
|
||||||
<feFuncB type="table" tableValues="0.086274509803922 0.43921568627451"></feFuncB>
|
|
||||||
<feFuncA type="table" tableValues="0 1"></feFuncA>
|
|
||||||
</feComponentTransfer>
|
|
||||||
</filter>
|
|
||||||
</svg>
|
|
||||||
</div>
|
</div>
|
||||||
<ak-message-container></ak-message-container>
|
<ak-message-container></ak-message-container>
|
||||||
<div class="pf-c-login">
|
<div class="pf-c-login stacked">
|
||||||
<div class="ak-login-container">
|
<div class="ak-login-container">
|
||||||
<header class="pf-c-login__header">
|
<main class="pf-c-login__main">
|
||||||
<div class="pf-c-brand ak-brand">
|
<div class="pf-c-login__main-header pf-c-brand ak-brand">
|
||||||
<img src="{{ tenant.branding_logo }}" alt="authentik Logo" />
|
<img src="{{ tenant.branding_logo }}" alt="authentik Logo" />
|
||||||
</div>
|
</div>
|
||||||
</header>
|
|
||||||
{% block main_container %}
|
|
||||||
<main class="pf-c-login__main">
|
|
||||||
<header class="pf-c-login__main-header">
|
<header class="pf-c-login__main-header">
|
||||||
<h1 class="pf-c-title pf-m-3xl">
|
<h1 class="pf-c-title pf-m-3xl">
|
||||||
{% block card_title %}
|
{% block card_title %}
|
||||||
|
@ -77,7 +63,6 @@
|
||||||
{% endblock %}
|
{% endblock %}
|
||||||
</div>
|
</div>
|
||||||
</main>
|
</main>
|
||||||
{% endblock %}
|
|
||||||
<footer class="pf-c-login__footer">
|
<footer class="pf-c-login__footer">
|
||||||
<ul class="pf-c-list pf-m-inline">
|
<ul class="pf-c-list pf-m-inline">
|
||||||
{% for link in footer_links %}
|
{% for link in footer_links %}
|
||||||
|
|
|
@ -22,6 +22,7 @@ class InterfaceView(TemplateView):
|
||||||
kwargs["version_family"] = f"{LOCAL_VERSION.major}.{LOCAL_VERSION.minor}"
|
kwargs["version_family"] = f"{LOCAL_VERSION.major}.{LOCAL_VERSION.minor}"
|
||||||
kwargs["version_subdomain"] = f"version-{LOCAL_VERSION.major}-{LOCAL_VERSION.minor}"
|
kwargs["version_subdomain"] = f"version-{LOCAL_VERSION.major}-{LOCAL_VERSION.minor}"
|
||||||
kwargs["build"] = get_build_hash()
|
kwargs["build"] = get_build_hash()
|
||||||
|
kwargs["url_kwargs"] = self.kwargs
|
||||||
return super().get_context_data(**kwargs)
|
return super().get_context_data(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,9 +2,11 @@
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
|
from django.utils.translation import gettext as _
|
||||||
from drf_spectacular.types import OpenApiTypes
|
from drf_spectacular.types import OpenApiTypes
|
||||||
from drf_spectacular.utils import extend_schema, inline_serializer
|
from drf_spectacular.utils import extend_schema, inline_serializer
|
||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
|
from rest_framework.exceptions import ValidationError
|
||||||
from rest_framework.fields import BooleanField, CharField, DateTimeField, IntegerField
|
from rest_framework.fields import BooleanField, CharField, DateTimeField, IntegerField
|
||||||
from rest_framework.permissions import IsAuthenticated
|
from rest_framework.permissions import IsAuthenticated
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
|
@ -20,6 +22,18 @@ from authentik.enterprise.models import License, LicenseKey
|
||||||
from authentik.root.install_id import get_install_id
|
from authentik.root.install_id import get_install_id
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseRequiredMixin:
|
||||||
|
"""Mixin to validate that a valid enterprise license
|
||||||
|
exists before allowing to safe the object"""
|
||||||
|
|
||||||
|
def validate(self, attrs: dict) -> dict:
|
||||||
|
"""Check that a valid license exists"""
|
||||||
|
total = LicenseKey.get_total()
|
||||||
|
if not total.is_valid():
|
||||||
|
raise ValidationError(_("Enterprise is required to create/update this object."))
|
||||||
|
return super().validate(attrs)
|
||||||
|
|
||||||
|
|
||||||
class LicenseSerializer(ModelSerializer):
|
class LicenseSerializer(ModelSerializer):
|
||||||
"""License Serializer"""
|
"""License Serializer"""
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,11 @@
|
||||||
from authentik.blueprints.apps import ManagedAppConfig
|
from authentik.blueprints.apps import ManagedAppConfig
|
||||||
|
|
||||||
|
|
||||||
class AuthentikEnterpriseConfig(ManagedAppConfig):
|
class EnterpriseConfig(ManagedAppConfig):
|
||||||
|
"""Base app config for all enterprise apps"""
|
||||||
|
|
||||||
|
|
||||||
|
class AuthentikEnterpriseConfig(EnterpriseConfig):
|
||||||
"""Enterprise app config"""
|
"""Enterprise app config"""
|
||||||
|
|
||||||
name = "authentik.enterprise"
|
name = "authentik.enterprise"
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
"""Enterprise license policies"""
|
"""Enterprise license policies"""
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
from authentik.core.models import User, UserTypes
|
from authentik.core.models import User, UserTypes
|
||||||
from authentik.enterprise.models import LicenseKey
|
from authentik.enterprise.models import LicenseKey
|
||||||
from authentik.policies.types import PolicyRequest, PolicyResult
|
from authentik.policies.types import PolicyRequest, PolicyResult
|
||||||
|
@ -13,10 +15,10 @@ class EnterprisePolicyAccessView(PolicyAccessView):
|
||||||
def check_license(self):
|
def check_license(self):
|
||||||
"""Check license"""
|
"""Check license"""
|
||||||
if not LicenseKey.get_total().is_valid():
|
if not LicenseKey.get_total().is_valid():
|
||||||
return False
|
return PolicyResult(False, _("Enterprise required to access this feature."))
|
||||||
if self.request.user.type != UserTypes.INTERNAL:
|
if self.request.user.type != UserTypes.INTERNAL:
|
||||||
return False
|
return PolicyResult(False, _("Feature only accessible for internal users."))
|
||||||
return True
|
return PolicyResult(True)
|
||||||
|
|
||||||
def user_has_access(self, user: Optional[User] = None) -> PolicyResult:
|
def user_has_access(self, user: Optional[User] = None) -> PolicyResult:
|
||||||
user = user or self.request.user
|
user = user or self.request.user
|
||||||
|
@ -24,7 +26,7 @@ class EnterprisePolicyAccessView(PolicyAccessView):
|
||||||
request.http_request = self.request
|
request.http_request = self.request
|
||||||
result = super().user_has_access(user)
|
result = super().user_has_access(user)
|
||||||
enterprise_result = self.check_license()
|
enterprise_result = self.check_license()
|
||||||
if not enterprise_result:
|
if not enterprise_result.passing:
|
||||||
return enterprise_result
|
return enterprise_result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
0
authentik/enterprise/providers/__init__.py
Normal file
0
authentik/enterprise/providers/__init__.py
Normal file
0
authentik/enterprise/providers/rac/__init__.py
Normal file
0
authentik/enterprise/providers/rac/__init__.py
Normal file
0
authentik/enterprise/providers/rac/api/__init__.py
Normal file
0
authentik/enterprise/providers/rac/api/__init__.py
Normal file
135
authentik/enterprise/providers/rac/api/endpoints.py
Normal file
135
authentik/enterprise/providers/rac/api/endpoints.py
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
"""RAC Provider API Views"""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from django.core.cache import cache
|
||||||
|
from django.db.models import QuerySet
|
||||||
|
from django.urls import reverse
|
||||||
|
from drf_spectacular.types import OpenApiTypes
|
||||||
|
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||||
|
from rest_framework.fields import SerializerMethodField
|
||||||
|
from rest_framework.request import Request
|
||||||
|
from rest_framework.response import Response
|
||||||
|
from rest_framework.serializers import ModelSerializer
|
||||||
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
|
from authentik.core.models import Provider
|
||||||
|
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||||
|
from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer
|
||||||
|
from authentik.enterprise.providers.rac.models import Endpoint
|
||||||
|
from authentik.policies.engine import PolicyEngine
|
||||||
|
from authentik.rbac.filters import ObjectFilter
|
||||||
|
|
||||||
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def user_endpoint_cache_key(user_pk: str) -> str:
|
||||||
|
"""Cache key where endpoint list for user is saved"""
|
||||||
|
return f"goauthentik.io/providers/rac/endpoint_access/{user_pk}"
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointSerializer(EnterpriseRequiredMixin, ModelSerializer):
|
||||||
|
"""Endpoint Serializer"""
|
||||||
|
|
||||||
|
provider_obj = RACProviderSerializer(source="provider", read_only=True)
|
||||||
|
launch_url = SerializerMethodField()
|
||||||
|
|
||||||
|
def get_launch_url(self, endpoint: Endpoint) -> Optional[str]:
|
||||||
|
"""Build actual launch URL (the provider itself does not have one, just
|
||||||
|
individual endpoints)"""
|
||||||
|
try:
|
||||||
|
# pylint: disable=no-member
|
||||||
|
return reverse(
|
||||||
|
"authentik_providers_rac:start",
|
||||||
|
kwargs={"app": endpoint.provider.application.slug, "endpoint": endpoint.pk},
|
||||||
|
)
|
||||||
|
except Provider.application.RelatedObjectDoesNotExist:
|
||||||
|
return None
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = Endpoint
|
||||||
|
fields = [
|
||||||
|
"pk",
|
||||||
|
"name",
|
||||||
|
"provider",
|
||||||
|
"provider_obj",
|
||||||
|
"protocol",
|
||||||
|
"host",
|
||||||
|
"settings",
|
||||||
|
"property_mappings",
|
||||||
|
"auth_mode",
|
||||||
|
"launch_url",
|
||||||
|
"maximum_connections",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointViewSet(UsedByMixin, ModelViewSet):
|
||||||
|
"""Endpoint Viewset"""
|
||||||
|
|
||||||
|
queryset = Endpoint.objects.all()
|
||||||
|
serializer_class = EndpointSerializer
|
||||||
|
filterset_fields = ["name", "provider"]
|
||||||
|
search_fields = ["name", "protocol"]
|
||||||
|
ordering = ["name", "protocol"]
|
||||||
|
|
||||||
|
def _filter_queryset_for_list(self, queryset: QuerySet) -> QuerySet:
|
||||||
|
"""Custom filter_queryset method which ignores guardian, but still supports sorting"""
|
||||||
|
for backend in list(self.filter_backends):
|
||||||
|
if backend == ObjectFilter:
|
||||||
|
continue
|
||||||
|
queryset = backend().filter_queryset(self.request, queryset, self)
|
||||||
|
return queryset
|
||||||
|
|
||||||
|
def _get_allowed_endpoints(self, queryset: QuerySet) -> list[Endpoint]:
|
||||||
|
endpoints = []
|
||||||
|
for endpoint in queryset:
|
||||||
|
engine = PolicyEngine(endpoint, self.request.user, self.request)
|
||||||
|
engine.build()
|
||||||
|
if engine.passing:
|
||||||
|
endpoints.append(endpoint)
|
||||||
|
return endpoints
|
||||||
|
|
||||||
|
@extend_schema(
|
||||||
|
parameters=[
|
||||||
|
OpenApiParameter(
|
||||||
|
"search",
|
||||||
|
OpenApiTypes.STR,
|
||||||
|
),
|
||||||
|
OpenApiParameter(
|
||||||
|
name="superuser_full_list",
|
||||||
|
location=OpenApiParameter.QUERY,
|
||||||
|
type=OpenApiTypes.BOOL,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
responses={
|
||||||
|
200: EndpointSerializer(many=True),
|
||||||
|
400: OpenApiResponse(description="Bad request"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def list(self, request: Request, *args, **kwargs) -> Response:
|
||||||
|
"""List accessible endpoints"""
|
||||||
|
should_cache = request.GET.get("search", "") == ""
|
||||||
|
|
||||||
|
superuser_full_list = str(request.GET.get("superuser_full_list", "false")).lower() == "true"
|
||||||
|
if superuser_full_list and request.user.is_superuser:
|
||||||
|
return super().list(request)
|
||||||
|
|
||||||
|
queryset = self._filter_queryset_for_list(self.get_queryset())
|
||||||
|
self.paginate_queryset(queryset)
|
||||||
|
|
||||||
|
allowed_endpoints = []
|
||||||
|
if not should_cache:
|
||||||
|
allowed_endpoints = self._get_allowed_endpoints(queryset)
|
||||||
|
if should_cache:
|
||||||
|
allowed_endpoints = cache.get(user_endpoint_cache_key(self.request.user.pk))
|
||||||
|
if not allowed_endpoints:
|
||||||
|
LOGGER.debug("Caching allowed endpoint list")
|
||||||
|
allowed_endpoints = self._get_allowed_endpoints(queryset)
|
||||||
|
cache.set(
|
||||||
|
user_endpoint_cache_key(self.request.user.pk),
|
||||||
|
allowed_endpoints,
|
||||||
|
timeout=86400,
|
||||||
|
)
|
||||||
|
serializer = self.get_serializer(allowed_endpoints, many=True)
|
||||||
|
return self.get_paginated_response(serializer.data)
|
36
authentik/enterprise/providers/rac/api/property_mappings.py
Normal file
36
authentik/enterprise/providers/rac/api/property_mappings.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
"""RAC Provider API Views"""
|
||||||
|
from rest_framework.fields import CharField
|
||||||
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
|
||||||
|
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||||
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
|
from authentik.core.api.utils import JSONDictField
|
||||||
|
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||||
|
from authentik.enterprise.providers.rac.models import RACPropertyMapping
|
||||||
|
|
||||||
|
|
||||||
|
class RACPropertyMappingSerializer(EnterpriseRequiredMixin, PropertyMappingSerializer):
|
||||||
|
"""RACPropertyMapping Serializer"""
|
||||||
|
|
||||||
|
static_settings = JSONDictField()
|
||||||
|
expression = CharField(allow_blank=True, required=False)
|
||||||
|
|
||||||
|
def validate_expression(self, expression: str) -> str:
|
||||||
|
"""Test Syntax"""
|
||||||
|
if expression == "":
|
||||||
|
return expression
|
||||||
|
return super().validate_expression(expression)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = RACPropertyMapping
|
||||||
|
fields = PropertyMappingSerializer.Meta.fields + ["static_settings"]
|
||||||
|
|
||||||
|
|
||||||
|
class RACPropertyMappingViewSet(UsedByMixin, ModelViewSet):
|
||||||
|
"""RACPropertyMapping Viewset"""
|
||||||
|
|
||||||
|
queryset = RACPropertyMapping.objects.all()
|
||||||
|
serializer_class = RACPropertyMappingSerializer
|
||||||
|
search_fields = ["name"]
|
||||||
|
ordering = ["name"]
|
||||||
|
filterset_fields = ["name", "managed"]
|
32
authentik/enterprise/providers/rac/api/providers.py
Normal file
32
authentik/enterprise/providers/rac/api/providers.py
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
"""RAC Provider API Views"""
|
||||||
|
from rest_framework.fields import CharField, ListField
|
||||||
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
|
||||||
|
from authentik.core.api.providers import ProviderSerializer
|
||||||
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
|
from authentik.enterprise.api import EnterpriseRequiredMixin
|
||||||
|
from authentik.enterprise.providers.rac.models import RACProvider
|
||||||
|
|
||||||
|
|
||||||
|
class RACProviderSerializer(EnterpriseRequiredMixin, ProviderSerializer):
|
||||||
|
"""RACProvider Serializer"""
|
||||||
|
|
||||||
|
outpost_set = ListField(child=CharField(), read_only=True, source="outpost_set.all")
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = RACProvider
|
||||||
|
fields = ProviderSerializer.Meta.fields + ["settings", "outpost_set", "connection_expiry"]
|
||||||
|
extra_kwargs = ProviderSerializer.Meta.extra_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
class RACProviderViewSet(UsedByMixin, ModelViewSet):
|
||||||
|
"""RACProvider Viewset"""
|
||||||
|
|
||||||
|
queryset = RACProvider.objects.all()
|
||||||
|
serializer_class = RACProviderSerializer
|
||||||
|
filterset_fields = {
|
||||||
|
"application": ["isnull"],
|
||||||
|
"name": ["iexact"],
|
||||||
|
}
|
||||||
|
search_fields = ["name"]
|
||||||
|
ordering = ["name"]
|
17
authentik/enterprise/providers/rac/apps.py
Normal file
17
authentik/enterprise/providers/rac/apps.py
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
"""RAC app config"""
|
||||||
|
from authentik.enterprise.apps import EnterpriseConfig
|
||||||
|
|
||||||
|
|
||||||
|
class AuthentikEnterpriseProviderRAC(EnterpriseConfig):
|
||||||
|
"""authentik enterprise rac app config"""
|
||||||
|
|
||||||
|
name = "authentik.enterprise.providers.rac"
|
||||||
|
label = "authentik_providers_rac"
|
||||||
|
verbose_name = "authentik Enterprise.Providers.RAC"
|
||||||
|
default = True
|
||||||
|
mountpoint = ""
|
||||||
|
ws_mountpoint = "authentik.enterprise.providers.rac.urls"
|
||||||
|
|
||||||
|
def reconcile_load_rac_signals(self):
|
||||||
|
"""Load rac signals"""
|
||||||
|
self.import_module("authentik.enterprise.providers.rac.signals")
|
163
authentik/enterprise/providers/rac/consumer_client.py
Normal file
163
authentik/enterprise/providers/rac/consumer_client.py
Normal file
|
@ -0,0 +1,163 @@
|
||||||
|
"""RAC Client consumer"""
|
||||||
|
from asgiref.sync import async_to_sync
|
||||||
|
from channels.db import database_sync_to_async
|
||||||
|
from channels.exceptions import ChannelFull, DenyConnection
|
||||||
|
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||||
|
from django.http.request import QueryDict
|
||||||
|
from structlog.stdlib import BoundLogger, get_logger
|
||||||
|
|
||||||
|
from authentik.enterprise.providers.rac.models import ConnectionToken, RACProvider
|
||||||
|
from authentik.outposts.consumer import OUTPOST_GROUP_INSTANCE
|
||||||
|
from authentik.outposts.models import Outpost, OutpostState, OutpostType
|
||||||
|
|
||||||
|
# Global broadcast group, which messages are sent to when the outpost connects back
|
||||||
|
# to authentik for a specific connection
|
||||||
|
# The `RACClientConsumer` consumer adds itself to this group on connection,
|
||||||
|
# and removes itself once it has been assigned a specific outpost channel
|
||||||
|
RAC_CLIENT_GROUP = "group_enterprise_rac_client"
|
||||||
|
# A group for all connections in a given authentik session ID
|
||||||
|
# A disconnect message is sent to this group when the session expires/is deleted
|
||||||
|
RAC_CLIENT_GROUP_SESSION = "group_enterprise_rac_client_%(session)s"
|
||||||
|
# A group for all connections with a specific token, which in almost all cases
|
||||||
|
# is just one connection, however this is used to disconnect the connection
|
||||||
|
# when the token is deleted
|
||||||
|
RAC_CLIENT_GROUP_TOKEN = "group_enterprise_rac_token_%(token)s" # nosec
|
||||||
|
|
||||||
|
# Step 1: Client connects to this websocket endpoint
|
||||||
|
# Step 2: We prepare all the connection args for Guac
|
||||||
|
# Step 3: Send a websocket message to a single outpost that has this provider assigned
|
||||||
|
# (Currently sending to all of them)
|
||||||
|
# (Should probably do different load balancing algorithms)
|
||||||
|
# Step 4: Outpost creates a websocket connection back to authentik
|
||||||
|
# with /ws/outpost_rac/<our_channel_id>/
|
||||||
|
# Step 5: This consumer transfers data between the two channels
|
||||||
|
|
||||||
|
|
||||||
|
class RACClientConsumer(AsyncWebsocketConsumer):
|
||||||
|
"""RAC client consumer the browser connects to"""
|
||||||
|
|
||||||
|
dest_channel_id: str = ""
|
||||||
|
provider: RACProvider
|
||||||
|
token: ConnectionToken
|
||||||
|
logger: BoundLogger
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
await self.accept("guacamole")
|
||||||
|
await self.channel_layer.group_add(RAC_CLIENT_GROUP, self.channel_name)
|
||||||
|
await self.channel_layer.group_add(
|
||||||
|
RAC_CLIENT_GROUP_SESSION % {"session": self.scope["session"].session_key},
|
||||||
|
self.channel_name,
|
||||||
|
)
|
||||||
|
await self.init_outpost_connection()
|
||||||
|
|
||||||
|
async def disconnect(self, code):
|
||||||
|
self.logger.debug("Disconnecting")
|
||||||
|
# Tell the outpost we're disconnecting
|
||||||
|
await self.channel_layer.send(
|
||||||
|
self.dest_channel_id,
|
||||||
|
{
|
||||||
|
"type": "event.disconnect",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@database_sync_to_async
|
||||||
|
def init_outpost_connection(self):
|
||||||
|
"""Initialize guac connection settings"""
|
||||||
|
self.token = ConnectionToken.filter_not_expired(
|
||||||
|
token=self.scope["url_route"]["kwargs"]["token"]
|
||||||
|
).first()
|
||||||
|
if not self.token:
|
||||||
|
raise DenyConnection()
|
||||||
|
self.provider = self.token.provider
|
||||||
|
params = self.token.get_settings()
|
||||||
|
self.logger = get_logger().bind(
|
||||||
|
endpoint=self.token.endpoint.name, user=self.scope["user"].username
|
||||||
|
)
|
||||||
|
msg = {
|
||||||
|
"type": "event.provider.specific",
|
||||||
|
"sub_type": "init_connection",
|
||||||
|
"dest_channel_id": self.channel_name,
|
||||||
|
"params": params,
|
||||||
|
"protocol": self.token.endpoint.protocol,
|
||||||
|
}
|
||||||
|
query = QueryDict(self.scope["query_string"].decode())
|
||||||
|
for key in ["screen_width", "screen_height", "screen_dpi", "audio"]:
|
||||||
|
value = query.get(key, None)
|
||||||
|
if not value:
|
||||||
|
continue
|
||||||
|
msg[key] = str(value)
|
||||||
|
outposts = Outpost.objects.filter(
|
||||||
|
type=OutpostType.RAC,
|
||||||
|
providers__in=[self.provider],
|
||||||
|
)
|
||||||
|
if not outposts.exists():
|
||||||
|
self.logger.warning("Provider has no outpost")
|
||||||
|
raise DenyConnection()
|
||||||
|
for outpost in outposts:
|
||||||
|
# Sort all states for the outpost by connection count
|
||||||
|
states = sorted(
|
||||||
|
OutpostState.for_outpost(outpost),
|
||||||
|
key=lambda state: int(state.args.get("active_connections", 0)),
|
||||||
|
)
|
||||||
|
if len(states) < 1:
|
||||||
|
continue
|
||||||
|
self.logger.debug("Sending out connection broadcast")
|
||||||
|
async_to_sync(self.channel_layer.group_send)(
|
||||||
|
OUTPOST_GROUP_INSTANCE % {"outpost_pk": str(outpost.pk), "instance": states[0].uid},
|
||||||
|
msg,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def receive(self, text_data=None, bytes_data=None):
|
||||||
|
"""Mirror data received from client to the dest_channel_id
|
||||||
|
which is the channel talking to guacd"""
|
||||||
|
if self.dest_channel_id == "":
|
||||||
|
return
|
||||||
|
if self.token.is_expired:
|
||||||
|
await self.event_disconnect({"reason": "token_expiry"})
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await self.channel_layer.send(
|
||||||
|
self.dest_channel_id,
|
||||||
|
{
|
||||||
|
"type": "event.send",
|
||||||
|
"text_data": text_data,
|
||||||
|
"bytes_data": bytes_data,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except ChannelFull:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def event_outpost_connected(self, event: dict):
|
||||||
|
"""Handle event broadcasted from outpost consumer, and check if they
|
||||||
|
created a connection for us"""
|
||||||
|
outpost_channel = event.get("outpost_channel")
|
||||||
|
if event.get("client_channel") != self.channel_name:
|
||||||
|
return
|
||||||
|
if self.dest_channel_id != "":
|
||||||
|
# We've already selected an outpost channel, so tell the other channel to disconnect
|
||||||
|
# This should never happen since we remove ourselves from the broadcast group
|
||||||
|
await self.channel_layer.send(
|
||||||
|
outpost_channel,
|
||||||
|
{
|
||||||
|
"type": "event.disconnect",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return
|
||||||
|
self.logger.debug("Connected to a single outpost instance")
|
||||||
|
self.dest_channel_id = outpost_channel
|
||||||
|
# Since we have a specific outpost channel now, we can remove
|
||||||
|
# ourselves from the global broadcast group
|
||||||
|
await self.channel_layer.group_discard(RAC_CLIENT_GROUP, self.channel_name)
|
||||||
|
|
||||||
|
async def event_send(self, event: dict):
|
||||||
|
"""Handler called by outpost websocket that sends data to this specific
|
||||||
|
client connection"""
|
||||||
|
if self.token.is_expired:
|
||||||
|
await self.event_disconnect({"reason": "token_expiry"})
|
||||||
|
return
|
||||||
|
await self.send(text_data=event.get("text_data"), bytes_data=event.get("bytes_data"))
|
||||||
|
|
||||||
|
async def event_disconnect(self, event: dict):
|
||||||
|
"""Disconnect when the session ends"""
|
||||||
|
self.logger.info("Disconnecting RAC connection", reason=event.get("reason"))
|
||||||
|
await self.close()
|
48
authentik/enterprise/providers/rac/consumer_outpost.py
Normal file
48
authentik/enterprise/providers/rac/consumer_outpost.py
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
"""RAC consumer"""
|
||||||
|
from channels.exceptions import ChannelFull
|
||||||
|
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||||
|
|
||||||
|
from authentik.enterprise.providers.rac.consumer_client import RAC_CLIENT_GROUP
|
||||||
|
|
||||||
|
|
||||||
|
class RACOutpostConsumer(AsyncWebsocketConsumer):
|
||||||
|
"""Consumer the outpost connects to, to send specific data back to a client connection"""
|
||||||
|
|
||||||
|
dest_channel_id: str
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
self.dest_channel_id = self.scope["url_route"]["kwargs"]["channel"]
|
||||||
|
await self.accept()
|
||||||
|
await self.channel_layer.group_send(
|
||||||
|
RAC_CLIENT_GROUP,
|
||||||
|
{
|
||||||
|
"type": "event.outpost.connected",
|
||||||
|
"outpost_channel": self.channel_name,
|
||||||
|
"client_channel": self.dest_channel_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def receive(self, text_data=None, bytes_data=None):
|
||||||
|
"""Mirror data received from guacd running in the outpost
|
||||||
|
to the dest_channel_id which is the channel talking to the browser"""
|
||||||
|
try:
|
||||||
|
await self.channel_layer.send(
|
||||||
|
self.dest_channel_id,
|
||||||
|
{
|
||||||
|
"type": "event.send",
|
||||||
|
"text_data": text_data,
|
||||||
|
"bytes_data": bytes_data,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except ChannelFull:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def event_send(self, event: dict):
|
||||||
|
"""Handler called by client websocket that sends data to this specific
|
||||||
|
outpost connection"""
|
||||||
|
await self.send(text_data=event.get("text_data"), bytes_data=event.get("bytes_data"))
|
||||||
|
|
||||||
|
async def event_disconnect(self, event: dict):
|
||||||
|
"""Tell outpost we're about to disconnect"""
|
||||||
|
await self.send(text_data="0.authentik.disconnect")
|
||||||
|
await self.close()
|
11
authentik/enterprise/providers/rac/controllers/docker.py
Normal file
11
authentik/enterprise/providers/rac/controllers/docker.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
"""RAC Provider Docker Controller"""
|
||||||
|
from authentik.outposts.controllers.docker import DockerController
|
||||||
|
from authentik.outposts.models import DockerServiceConnection, Outpost
|
||||||
|
|
||||||
|
|
||||||
|
class RACDockerController(DockerController):
|
||||||
|
"""RAC Provider Docker Controller"""
|
||||||
|
|
||||||
|
def __init__(self, outpost: Outpost, connection: DockerServiceConnection):
|
||||||
|
super().__init__(outpost, connection)
|
||||||
|
self.deployment_ports = []
|
13
authentik/enterprise/providers/rac/controllers/kubernetes.py
Normal file
13
authentik/enterprise/providers/rac/controllers/kubernetes.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
"""RAC Provider Kubernetes Controller"""
|
||||||
|
from authentik.outposts.controllers.k8s.service import ServiceReconciler
|
||||||
|
from authentik.outposts.controllers.kubernetes import KubernetesController
|
||||||
|
from authentik.outposts.models import KubernetesServiceConnection, Outpost
|
||||||
|
|
||||||
|
|
||||||
|
class RACKubernetesController(KubernetesController):
|
||||||
|
"""RAC Provider Kubernetes Controller"""
|
||||||
|
|
||||||
|
def __init__(self, outpost: Outpost, connection: KubernetesServiceConnection):
|
||||||
|
super().__init__(outpost, connection)
|
||||||
|
self.deployment_ports = []
|
||||||
|
del self.reconcilers[ServiceReconciler.reconciler_name()]
|
164
authentik/enterprise/providers/rac/migrations/0001_initial.py
Normal file
164
authentik/enterprise/providers/rac/migrations/0001_initial.py
Normal file
|
@ -0,0 +1,164 @@
|
||||||
|
# Generated by Django 4.2.8 on 2023-12-29 15:58
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import django.db.models.deletion
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
import authentik.core.models
|
||||||
|
import authentik.lib.utils.time
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
initial = True
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("authentik_policies", "0011_policybinding_failure_result_and_more"),
|
||||||
|
("authentik_core", "0032_group_roles"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="RACPropertyMapping",
|
||||||
|
fields=[
|
||||||
|
(
|
||||||
|
"propertymapping_ptr",
|
||||||
|
models.OneToOneField(
|
||||||
|
auto_created=True,
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
parent_link=True,
|
||||||
|
primary_key=True,
|
||||||
|
serialize=False,
|
||||||
|
to="authentik_core.propertymapping",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("static_settings", models.JSONField(default=dict)),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"verbose_name": "RAC Property Mapping",
|
||||||
|
"verbose_name_plural": "RAC Property Mappings",
|
||||||
|
},
|
||||||
|
bases=("authentik_core.propertymapping",),
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="RACProvider",
|
||||||
|
fields=[
|
||||||
|
(
|
||||||
|
"provider_ptr",
|
||||||
|
models.OneToOneField(
|
||||||
|
auto_created=True,
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
parent_link=True,
|
||||||
|
primary_key=True,
|
||||||
|
serialize=False,
|
||||||
|
to="authentik_core.provider",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("settings", models.JSONField(default=dict)),
|
||||||
|
(
|
||||||
|
"auth_mode",
|
||||||
|
models.TextField(
|
||||||
|
choices=[("static", "Static"), ("prompt", "Prompt")], default="prompt"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"connection_expiry",
|
||||||
|
models.TextField(
|
||||||
|
default="hours=8",
|
||||||
|
help_text="Determines how long a session lasts. Default of 0 means that the sessions lasts until the browser is closed. (Format: hours=-1;minutes=-2;seconds=-3)",
|
||||||
|
validators=[authentik.lib.utils.time.timedelta_string_validator],
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"verbose_name": "RAC Provider",
|
||||||
|
"verbose_name_plural": "RAC Providers",
|
||||||
|
},
|
||||||
|
bases=("authentik_core.provider",),
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="Endpoint",
|
||||||
|
fields=[
|
||||||
|
(
|
||||||
|
"policybindingmodel_ptr",
|
||||||
|
models.OneToOneField(
|
||||||
|
auto_created=True,
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
parent_link=True,
|
||||||
|
primary_key=True,
|
||||||
|
serialize=False,
|
||||||
|
to="authentik_policies.policybindingmodel",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("name", models.TextField()),
|
||||||
|
("host", models.TextField()),
|
||||||
|
(
|
||||||
|
"protocol",
|
||||||
|
models.TextField(choices=[("rdp", "Rdp"), ("vnc", "Vnc"), ("ssh", "Ssh")]),
|
||||||
|
),
|
||||||
|
("settings", models.JSONField(default=dict)),
|
||||||
|
(
|
||||||
|
"auth_mode",
|
||||||
|
models.TextField(choices=[("static", "Static"), ("prompt", "Prompt")]),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"property_mappings",
|
||||||
|
models.ManyToManyField(
|
||||||
|
blank=True, default=None, to="authentik_core.propertymapping"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"provider",
|
||||||
|
models.ForeignKey(
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
to="authentik_providers_rac.racprovider",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"verbose_name": "RAC Endpoint",
|
||||||
|
"verbose_name_plural": "RAC Endpoints",
|
||||||
|
},
|
||||||
|
bases=("authentik_policies.policybindingmodel", models.Model),
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="ConnectionToken",
|
||||||
|
fields=[
|
||||||
|
(
|
||||||
|
"expires",
|
||||||
|
models.DateTimeField(default=authentik.core.models.default_token_duration),
|
||||||
|
),
|
||||||
|
("expiring", models.BooleanField(default=True)),
|
||||||
|
(
|
||||||
|
"connection_token_uuid",
|
||||||
|
models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False),
|
||||||
|
),
|
||||||
|
("token", models.TextField(default=authentik.core.models.default_token_key)),
|
||||||
|
("settings", models.JSONField(default=dict)),
|
||||||
|
(
|
||||||
|
"endpoint",
|
||||||
|
models.ForeignKey(
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
to="authentik_providers_rac.endpoint",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"provider",
|
||||||
|
models.ForeignKey(
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
to="authentik_providers_rac.racprovider",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"session",
|
||||||
|
models.ForeignKey(
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
to="authentik_core.authenticatedsession",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
|
@ -0,0 +1,17 @@
|
||||||
|
# Generated by Django 5.0 on 2024-01-03 23:44
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("authentik_providers_rac", "0001_initial"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="endpoint",
|
||||||
|
name="maximum_connections",
|
||||||
|
field=models.IntegerField(default=1),
|
||||||
|
),
|
||||||
|
]
|
192
authentik/enterprise/providers/rac/models.py
Normal file
192
authentik/enterprise/providers/rac/models.py
Normal file
|
@ -0,0 +1,192 @@
|
||||||
|
"""RAC Models"""
|
||||||
|
from typing import Optional
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from deepmerge import always_merger
|
||||||
|
from django.db import models
|
||||||
|
from django.db.models import QuerySet
|
||||||
|
from django.utils.translation import gettext as _
|
||||||
|
from rest_framework.serializers import Serializer
|
||||||
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
|
from authentik.core.exceptions import PropertyMappingExpressionException
|
||||||
|
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, default_token_key
|
||||||
|
from authentik.events.models import Event, EventAction
|
||||||
|
from authentik.lib.models import SerializerModel
|
||||||
|
from authentik.lib.utils.time import timedelta_string_validator
|
||||||
|
from authentik.policies.models import PolicyBindingModel
|
||||||
|
|
||||||
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class Protocols(models.TextChoices):
|
||||||
|
"""Supported protocols"""
|
||||||
|
|
||||||
|
RDP = "rdp"
|
||||||
|
VNC = "vnc"
|
||||||
|
SSH = "ssh"
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationMode(models.TextChoices):
|
||||||
|
"""Authentication modes"""
|
||||||
|
|
||||||
|
STATIC = "static"
|
||||||
|
PROMPT = "prompt"
|
||||||
|
|
||||||
|
|
||||||
|
class RACProvider(Provider):
|
||||||
|
"""Remotely access computers/servers via RDP/SSH/VNC."""
|
||||||
|
|
||||||
|
settings = models.JSONField(default=dict)
|
||||||
|
auth_mode = models.TextField(
|
||||||
|
choices=AuthenticationMode.choices, default=AuthenticationMode.PROMPT
|
||||||
|
)
|
||||||
|
connection_expiry = models.TextField(
|
||||||
|
default="hours=8",
|
||||||
|
validators=[timedelta_string_validator],
|
||||||
|
help_text=_(
|
||||||
|
"Determines how long a session lasts. Default of 0 means "
|
||||||
|
"that the sessions lasts until the browser is closed. "
|
||||||
|
"(Format: hours=-1;minutes=-2;seconds=-3)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def launch_url(self) -> Optional[str]:
|
||||||
|
"""URL to this provider and initiate authorization for the user.
|
||||||
|
Can return None for providers that are not URL-based"""
|
||||||
|
return "goauthentik.io://providers/rac/launch"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def component(self) -> str:
|
||||||
|
return "ak-provider-rac-form"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def serializer(self) -> type[Serializer]:
|
||||||
|
from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer
|
||||||
|
|
||||||
|
return RACProviderSerializer
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
verbose_name = _("RAC Provider")
|
||||||
|
verbose_name_plural = _("RAC Providers")
|
||||||
|
|
||||||
|
|
||||||
|
class Endpoint(SerializerModel, PolicyBindingModel):
|
||||||
|
"""Remote-accessible endpoint"""
|
||||||
|
|
||||||
|
name = models.TextField()
|
||||||
|
host = models.TextField()
|
||||||
|
protocol = models.TextField(choices=Protocols.choices)
|
||||||
|
settings = models.JSONField(default=dict)
|
||||||
|
auth_mode = models.TextField(choices=AuthenticationMode.choices)
|
||||||
|
provider = models.ForeignKey("RACProvider", on_delete=models.CASCADE)
|
||||||
|
maximum_connections = models.IntegerField(default=1)
|
||||||
|
|
||||||
|
property_mappings = models.ManyToManyField(
|
||||||
|
"authentik_core.PropertyMapping", default=None, blank=True
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def serializer(self) -> type[Serializer]:
|
||||||
|
from authentik.enterprise.providers.rac.api.endpoints import EndpointSerializer
|
||||||
|
|
||||||
|
return EndpointSerializer
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"RAC Endpoint {self.name}"
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
verbose_name = _("RAC Endpoint")
|
||||||
|
verbose_name_plural = _("RAC Endpoints")
|
||||||
|
|
||||||
|
|
||||||
|
class RACPropertyMapping(PropertyMapping):
|
||||||
|
"""Configure settings for remote access endpoints."""
|
||||||
|
|
||||||
|
static_settings = models.JSONField(default=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def component(self) -> str:
|
||||||
|
return "ak-property-mapping-rac-form"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def serializer(self) -> type[Serializer]:
|
||||||
|
from authentik.enterprise.providers.rac.api.property_mappings import (
|
||||||
|
RACPropertyMappingSerializer,
|
||||||
|
)
|
||||||
|
|
||||||
|
return RACPropertyMappingSerializer
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
verbose_name = _("RAC Property Mapping")
|
||||||
|
verbose_name_plural = _("RAC Property Mappings")
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionToken(ExpiringModel):
|
||||||
|
"""Token for a single connection to a specified endpoint"""
|
||||||
|
|
||||||
|
connection_token_uuid = models.UUIDField(default=uuid4, primary_key=True)
|
||||||
|
provider = models.ForeignKey(RACProvider, on_delete=models.CASCADE)
|
||||||
|
endpoint = models.ForeignKey(Endpoint, on_delete=models.CASCADE)
|
||||||
|
token = models.TextField(default=default_token_key)
|
||||||
|
settings = models.JSONField(default=dict)
|
||||||
|
session = models.ForeignKey("authentik_core.AuthenticatedSession", on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
def get_settings(self) -> dict:
|
||||||
|
"""Get settings"""
|
||||||
|
default_settings = {}
|
||||||
|
if ":" in self.endpoint.host:
|
||||||
|
host, _, port = self.endpoint.host.partition(":")
|
||||||
|
default_settings["hostname"] = host
|
||||||
|
default_settings["port"] = str(port)
|
||||||
|
else:
|
||||||
|
default_settings["hostname"] = self.endpoint.host
|
||||||
|
default_settings["client-name"] = "authentik"
|
||||||
|
# default_settings["enable-drive"] = "true"
|
||||||
|
# default_settings["drive-name"] = "authentik"
|
||||||
|
settings = {}
|
||||||
|
always_merger.merge(settings, default_settings)
|
||||||
|
always_merger.merge(settings, self.endpoint.provider.settings)
|
||||||
|
always_merger.merge(settings, self.endpoint.settings)
|
||||||
|
always_merger.merge(settings, self.settings)
|
||||||
|
|
||||||
|
def mapping_evaluator(mappings: QuerySet):
|
||||||
|
for mapping in mappings:
|
||||||
|
mapping: RACPropertyMapping
|
||||||
|
if len(mapping.static_settings) > 0:
|
||||||
|
always_merger.merge(settings, mapping.static_settings)
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
mapping_settings = mapping.evaluate(
|
||||||
|
self.session.user, None, endpoint=self.endpoint, provider=self.provider
|
||||||
|
)
|
||||||
|
always_merger.merge(settings, mapping_settings)
|
||||||
|
except PropertyMappingExpressionException as exc:
|
||||||
|
Event.new(
|
||||||
|
EventAction.CONFIGURATION_ERROR,
|
||||||
|
message=f"Failed to evaluate property-mapping: '{mapping.name}'",
|
||||||
|
provider=self.provider,
|
||||||
|
mapping=mapping,
|
||||||
|
).set_user(self.session.user).save()
|
||||||
|
LOGGER.warning("Failed to evaluate property mapping", exc=exc)
|
||||||
|
|
||||||
|
mapping_evaluator(
|
||||||
|
RACPropertyMapping.objects.filter(provider__in=[self.provider]).order_by("name")
|
||||||
|
)
|
||||||
|
mapping_evaluator(
|
||||||
|
RACPropertyMapping.objects.filter(endpoint__in=[self.endpoint]).order_by("name")
|
||||||
|
)
|
||||||
|
|
||||||
|
settings["drive-path"] = f"/tmp/connection/{self.token}" # nosec
|
||||||
|
settings["create-drive-path"] = "true"
|
||||||
|
# Ensure all values of the settings dict are strings
|
||||||
|
for key, value in settings.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
continue
|
||||||
|
# Special case for bools
|
||||||
|
if isinstance(value, bool):
|
||||||
|
settings[key] = str(value).lower()
|
||||||
|
continue
|
||||||
|
settings[key] = str(value)
|
||||||
|
return settings
|
54
authentik/enterprise/providers/rac/signals.py
Normal file
54
authentik/enterprise/providers/rac/signals.py
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
"""RAC Signals"""
|
||||||
|
from asgiref.sync import async_to_sync
|
||||||
|
from channels.layers import get_channel_layer
|
||||||
|
from django.contrib.auth.signals import user_logged_out
|
||||||
|
from django.core.cache import cache
|
||||||
|
from django.db.models import Model
|
||||||
|
from django.db.models.signals import post_save, pre_delete
|
||||||
|
from django.dispatch import receiver
|
||||||
|
from django.http import HttpRequest
|
||||||
|
|
||||||
|
from authentik.core.models import User
|
||||||
|
from authentik.enterprise.providers.rac.api.endpoints import user_endpoint_cache_key
|
||||||
|
from authentik.enterprise.providers.rac.consumer_client import (
|
||||||
|
RAC_CLIENT_GROUP_SESSION,
|
||||||
|
RAC_CLIENT_GROUP_TOKEN,
|
||||||
|
)
|
||||||
|
from authentik.enterprise.providers.rac.models import ConnectionToken, Endpoint
|
||||||
|
|
||||||
|
|
||||||
|
@receiver(user_logged_out)
|
||||||
|
def user_logged_out_session(sender, request: HttpRequest, user: User, **_):
|
||||||
|
"""Disconnect any open RAC connections"""
|
||||||
|
layer = get_channel_layer()
|
||||||
|
async_to_sync(layer.group_send)(
|
||||||
|
RAC_CLIENT_GROUP_SESSION
|
||||||
|
% {
|
||||||
|
"session": request.session.session_key,
|
||||||
|
},
|
||||||
|
{"type": "event.disconnect", "reason": "session_logout"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@receiver(pre_delete, sender=ConnectionToken)
|
||||||
|
def pre_delete_connection_token_disconnect(sender, instance: ConnectionToken, **_):
|
||||||
|
"""Disconnect session when connection token is deleted"""
|
||||||
|
layer = get_channel_layer()
|
||||||
|
async_to_sync(layer.group_send)(
|
||||||
|
RAC_CLIENT_GROUP_TOKEN
|
||||||
|
% {
|
||||||
|
"token": instance.token,
|
||||||
|
},
|
||||||
|
{"type": "event.disconnect", "reason": "token_delete"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@receiver(post_save, sender=Endpoint)
|
||||||
|
def post_save_application(sender: type[Model], instance, created: bool, **_):
|
||||||
|
"""Clear user's application cache upon application creation"""
|
||||||
|
if not created: # pragma: no cover
|
||||||
|
return
|
||||||
|
|
||||||
|
# Delete user endpoint cache
|
||||||
|
keys = cache.keys(user_endpoint_cache_key("*"))
|
||||||
|
cache.delete_many(keys)
|
18
authentik/enterprise/providers/rac/templates/if/rac.html
Normal file
18
authentik/enterprise/providers/rac/templates/if/rac.html
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
{% extends "base/skeleton.html" %}
|
||||||
|
|
||||||
|
{% load static %}
|
||||||
|
|
||||||
|
{% block head %}
|
||||||
|
<script src="{% static 'dist/enterprise/rac/index.js' %}?version={{ version }}" type="module"></script>
|
||||||
|
<meta name="theme-color" content="#18191a" media="(prefers-color-scheme: dark)">
|
||||||
|
<meta name="theme-color" content="#ffffff" media="(prefers-color-scheme: light)">
|
||||||
|
<link rel="icon" href="{{ tenant.branding_favicon }}">
|
||||||
|
<link rel="shortcut icon" href="{{ tenant.branding_favicon }}">
|
||||||
|
{% include "base/header_js.html" %}
|
||||||
|
{% endblock %}
|
||||||
|
|
||||||
|
{% block body %}
|
||||||
|
<ak-rac token="{{ url_kwargs.token }}" endpointName="{{ token.endpoint.name }}">
|
||||||
|
<ak-loading></ak-loading>
|
||||||
|
</ak-rac>
|
||||||
|
{% endblock %}
|
171
authentik/enterprise/providers/rac/tests/test_endpoints_api.py
Normal file
171
authentik/enterprise/providers/rac/tests/test_endpoints_api.py
Normal file
|
@ -0,0 +1,171 @@
|
||||||
|
"""Test Endpoints API"""
|
||||||
|
|
||||||
|
from django.urls import reverse
|
||||||
|
from rest_framework.test import APITestCase
|
||||||
|
|
||||||
|
from authentik.core.models import Application
|
||||||
|
from authentik.core.tests.utils import create_test_admin_user
|
||||||
|
from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider
|
||||||
|
from authentik.lib.generators import generate_id
|
||||||
|
from authentik.policies.dummy.models import DummyPolicy
|
||||||
|
from authentik.policies.models import PolicyBinding
|
||||||
|
|
||||||
|
|
||||||
|
class TestEndpointsAPI(APITestCase):
|
||||||
|
"""Test endpoints API"""
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.user = create_test_admin_user()
|
||||||
|
self.provider = RACProvider.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
)
|
||||||
|
self.app = Application.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
slug=generate_id(),
|
||||||
|
provider=self.provider,
|
||||||
|
)
|
||||||
|
self.allowed = Endpoint.objects.create(
|
||||||
|
name=f"a-{generate_id()}",
|
||||||
|
host=generate_id(),
|
||||||
|
protocol=Protocols.RDP,
|
||||||
|
provider=self.provider,
|
||||||
|
)
|
||||||
|
self.denied = Endpoint.objects.create(
|
||||||
|
name=f"b-{generate_id()}",
|
||||||
|
host=generate_id(),
|
||||||
|
protocol=Protocols.RDP,
|
||||||
|
provider=self.provider,
|
||||||
|
)
|
||||||
|
PolicyBinding.objects.create(
|
||||||
|
target=self.denied,
|
||||||
|
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
|
||||||
|
order=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_list(self):
|
||||||
|
"""Test list operation without superuser_full_list"""
|
||||||
|
self.client.force_login(self.user)
|
||||||
|
response = self.client.get(reverse("authentik_api:endpoint-list"))
|
||||||
|
self.assertJSONEqual(
|
||||||
|
response.content.decode(),
|
||||||
|
{
|
||||||
|
"pagination": {
|
||||||
|
"next": 0,
|
||||||
|
"previous": 0,
|
||||||
|
"count": 2,
|
||||||
|
"current": 1,
|
||||||
|
"total_pages": 1,
|
||||||
|
"start_index": 1,
|
||||||
|
"end_index": 2,
|
||||||
|
},
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"pk": str(self.allowed.pk),
|
||||||
|
"name": self.allowed.name,
|
||||||
|
"provider": self.provider.pk,
|
||||||
|
"provider_obj": {
|
||||||
|
"pk": self.provider.pk,
|
||||||
|
"name": self.provider.name,
|
||||||
|
"authentication_flow": None,
|
||||||
|
"authorization_flow": None,
|
||||||
|
"property_mappings": [],
|
||||||
|
"connection_expiry": "hours=8",
|
||||||
|
"component": "ak-provider-rac-form",
|
||||||
|
"assigned_application_slug": self.app.slug,
|
||||||
|
"assigned_application_name": self.app.name,
|
||||||
|
"verbose_name": "RAC Provider",
|
||||||
|
"verbose_name_plural": "RAC Providers",
|
||||||
|
"meta_model_name": "authentik_providers_rac.racprovider",
|
||||||
|
"settings": {},
|
||||||
|
"outpost_set": [],
|
||||||
|
},
|
||||||
|
"protocol": "rdp",
|
||||||
|
"host": self.allowed.host,
|
||||||
|
"maximum_connections": 1,
|
||||||
|
"settings": {},
|
||||||
|
"property_mappings": [],
|
||||||
|
"auth_mode": "",
|
||||||
|
"launch_url": f"/application/rac/{self.app.slug}/{str(self.allowed.pk)}/",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_list_superuser_full_list(self):
|
||||||
|
"""Test list operation with superuser_full_list"""
|
||||||
|
self.client.force_login(self.user)
|
||||||
|
response = self.client.get(
|
||||||
|
reverse("authentik_api:endpoint-list") + "?superuser_full_list=true"
|
||||||
|
)
|
||||||
|
self.assertJSONEqual(
|
||||||
|
response.content.decode(),
|
||||||
|
{
|
||||||
|
"pagination": {
|
||||||
|
"next": 0,
|
||||||
|
"previous": 0,
|
||||||
|
"count": 2,
|
||||||
|
"current": 1,
|
||||||
|
"total_pages": 1,
|
||||||
|
"start_index": 1,
|
||||||
|
"end_index": 2,
|
||||||
|
},
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"pk": str(self.allowed.pk),
|
||||||
|
"name": self.allowed.name,
|
||||||
|
"provider": self.provider.pk,
|
||||||
|
"provider_obj": {
|
||||||
|
"pk": self.provider.pk,
|
||||||
|
"name": self.provider.name,
|
||||||
|
"authentication_flow": None,
|
||||||
|
"authorization_flow": None,
|
||||||
|
"property_mappings": [],
|
||||||
|
"component": "ak-provider-rac-form",
|
||||||
|
"assigned_application_slug": self.app.slug,
|
||||||
|
"assigned_application_name": self.app.name,
|
||||||
|
"connection_expiry": "hours=8",
|
||||||
|
"verbose_name": "RAC Provider",
|
||||||
|
"verbose_name_plural": "RAC Providers",
|
||||||
|
"meta_model_name": "authentik_providers_rac.racprovider",
|
||||||
|
"settings": {},
|
||||||
|
"outpost_set": [],
|
||||||
|
},
|
||||||
|
"protocol": "rdp",
|
||||||
|
"host": self.allowed.host,
|
||||||
|
"maximum_connections": 1,
|
||||||
|
"settings": {},
|
||||||
|
"property_mappings": [],
|
||||||
|
"auth_mode": "",
|
||||||
|
"launch_url": f"/application/rac/{self.app.slug}/{str(self.allowed.pk)}/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"pk": str(self.denied.pk),
|
||||||
|
"name": self.denied.name,
|
||||||
|
"provider": self.provider.pk,
|
||||||
|
"provider_obj": {
|
||||||
|
"pk": self.provider.pk,
|
||||||
|
"name": self.provider.name,
|
||||||
|
"authentication_flow": None,
|
||||||
|
"authorization_flow": None,
|
||||||
|
"property_mappings": [],
|
||||||
|
"component": "ak-provider-rac-form",
|
||||||
|
"assigned_application_slug": self.app.slug,
|
||||||
|
"assigned_application_name": self.app.name,
|
||||||
|
"connection_expiry": "hours=8",
|
||||||
|
"verbose_name": "RAC Provider",
|
||||||
|
"verbose_name_plural": "RAC Providers",
|
||||||
|
"meta_model_name": "authentik_providers_rac.racprovider",
|
||||||
|
"settings": {},
|
||||||
|
"outpost_set": [],
|
||||||
|
},
|
||||||
|
"protocol": "rdp",
|
||||||
|
"host": self.denied.host,
|
||||||
|
"maximum_connections": 1,
|
||||||
|
"settings": {},
|
||||||
|
"property_mappings": [],
|
||||||
|
"auth_mode": "",
|
||||||
|
"launch_url": f"/application/rac/{self.app.slug}/{str(self.denied.pk)}/",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
144
authentik/enterprise/providers/rac/tests/test_models.py
Normal file
144
authentik/enterprise/providers/rac/tests/test_models.py
Normal file
|
@ -0,0 +1,144 @@
|
||||||
|
"""Test RAC Models"""
|
||||||
|
from django.test import TransactionTestCase
|
||||||
|
|
||||||
|
from authentik.core.models import Application, AuthenticatedSession
|
||||||
|
from authentik.core.tests.utils import create_test_admin_user
|
||||||
|
from authentik.enterprise.providers.rac.models import (
|
||||||
|
ConnectionToken,
|
||||||
|
Endpoint,
|
||||||
|
Protocols,
|
||||||
|
RACPropertyMapping,
|
||||||
|
RACProvider,
|
||||||
|
)
|
||||||
|
from authentik.lib.generators import generate_id
|
||||||
|
|
||||||
|
|
||||||
|
class TestModels(TransactionTestCase):
|
||||||
|
"""Test RAC Models"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.user = create_test_admin_user()
|
||||||
|
self.provider = RACProvider.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
)
|
||||||
|
self.app = Application.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
slug=generate_id(),
|
||||||
|
provider=self.provider,
|
||||||
|
)
|
||||||
|
self.endpoint = Endpoint.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
host=f"{generate_id()}:1324",
|
||||||
|
protocol=Protocols.RDP,
|
||||||
|
provider=self.provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_settings_merge(self):
|
||||||
|
"""Test settings merge"""
|
||||||
|
token = ConnectionToken.objects.create(
|
||||||
|
provider=self.provider,
|
||||||
|
endpoint=self.endpoint,
|
||||||
|
session=AuthenticatedSession.objects.create(
|
||||||
|
user=self.user,
|
||||||
|
session_key=generate_id(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
path = f"/tmp/connection/{token.token}" # nosec
|
||||||
|
self.assertEqual(
|
||||||
|
token.get_settings(),
|
||||||
|
{
|
||||||
|
"hostname": self.endpoint.host.split(":")[0],
|
||||||
|
"port": "1324",
|
||||||
|
"client-name": "authentik",
|
||||||
|
"drive-path": path,
|
||||||
|
"create-drive-path": "true",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Set settings in provider
|
||||||
|
self.provider.settings = {"level": "provider"}
|
||||||
|
self.provider.save()
|
||||||
|
self.assertEqual(
|
||||||
|
token.get_settings(),
|
||||||
|
{
|
||||||
|
"hostname": self.endpoint.host.split(":")[0],
|
||||||
|
"port": "1324",
|
||||||
|
"client-name": "authentik",
|
||||||
|
"drive-path": path,
|
||||||
|
"create-drive-path": "true",
|
||||||
|
"level": "provider",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Set settings in endpoint
|
||||||
|
self.endpoint.settings = {
|
||||||
|
"level": "endpoint",
|
||||||
|
}
|
||||||
|
self.endpoint.save()
|
||||||
|
self.assertEqual(
|
||||||
|
token.get_settings(),
|
||||||
|
{
|
||||||
|
"hostname": self.endpoint.host.split(":")[0],
|
||||||
|
"port": "1324",
|
||||||
|
"client-name": "authentik",
|
||||||
|
"drive-path": path,
|
||||||
|
"create-drive-path": "true",
|
||||||
|
"level": "endpoint",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Set settings in token
|
||||||
|
token.settings = {
|
||||||
|
"level": "token",
|
||||||
|
}
|
||||||
|
token.save()
|
||||||
|
self.assertEqual(
|
||||||
|
token.get_settings(),
|
||||||
|
{
|
||||||
|
"hostname": self.endpoint.host.split(":")[0],
|
||||||
|
"port": "1324",
|
||||||
|
"client-name": "authentik",
|
||||||
|
"drive-path": path,
|
||||||
|
"create-drive-path": "true",
|
||||||
|
"level": "token",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Set settings in property mapping (provider)
|
||||||
|
mapping = RACPropertyMapping.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
expression="""return {
|
||||||
|
"level": "property_mapping_provider"
|
||||||
|
}""",
|
||||||
|
)
|
||||||
|
self.provider.property_mappings.add(mapping)
|
||||||
|
self.assertEqual(
|
||||||
|
token.get_settings(),
|
||||||
|
{
|
||||||
|
"hostname": self.endpoint.host.split(":")[0],
|
||||||
|
"port": "1324",
|
||||||
|
"client-name": "authentik",
|
||||||
|
"drive-path": path,
|
||||||
|
"create-drive-path": "true",
|
||||||
|
"level": "property_mapping_provider",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Set settings in property mapping (endpoint)
|
||||||
|
mapping = RACPropertyMapping.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
static_settings={
|
||||||
|
"level": "property_mapping_endpoint",
|
||||||
|
"foo": True,
|
||||||
|
"bar": 6,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.endpoint.property_mappings.add(mapping)
|
||||||
|
self.assertEqual(
|
||||||
|
token.get_settings(),
|
||||||
|
{
|
||||||
|
"hostname": self.endpoint.host.split(":")[0],
|
||||||
|
"port": "1324",
|
||||||
|
"client-name": "authentik",
|
||||||
|
"drive-path": path,
|
||||||
|
"create-drive-path": "true",
|
||||||
|
"level": "property_mapping_endpoint",
|
||||||
|
"foo": "true",
|
||||||
|
"bar": "6",
|
||||||
|
},
|
||||||
|
)
|
132
authentik/enterprise/providers/rac/tests/test_views.py
Normal file
132
authentik/enterprise/providers/rac/tests/test_views.py
Normal file
|
@ -0,0 +1,132 @@
|
||||||
|
"""RAC Views tests"""
|
||||||
|
from datetime import timedelta
|
||||||
|
from json import loads
|
||||||
|
from time import mktime
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from django.urls import reverse
|
||||||
|
from django.utils.timezone import now
|
||||||
|
from rest_framework.test import APITestCase
|
||||||
|
|
||||||
|
from authentik.core.models import Application
|
||||||
|
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||||
|
from authentik.enterprise.models import License, LicenseKey
|
||||||
|
from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider
|
||||||
|
from authentik.lib.generators import generate_id
|
||||||
|
from authentik.policies.denied import AccessDeniedResponse
|
||||||
|
from authentik.policies.dummy.models import DummyPolicy
|
||||||
|
from authentik.policies.models import PolicyBinding
|
||||||
|
|
||||||
|
|
||||||
|
class TestRACViews(APITestCase):
|
||||||
|
"""RAC Views tests"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.user = create_test_admin_user()
|
||||||
|
self.flow = create_test_flow()
|
||||||
|
self.provider = RACProvider.objects.create(name=generate_id(), authorization_flow=self.flow)
|
||||||
|
self.app = Application.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
slug=generate_id(),
|
||||||
|
provider=self.provider,
|
||||||
|
)
|
||||||
|
self.endpoint = Endpoint.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
host=f"{generate_id()}:1324",
|
||||||
|
protocol=Protocols.RDP,
|
||||||
|
provider=self.provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"authentik.enterprise.models.LicenseKey.validate",
|
||||||
|
MagicMock(
|
||||||
|
return_value=LicenseKey(
|
||||||
|
aud="",
|
||||||
|
exp=int(mktime((now() + timedelta(days=3000)).timetuple())),
|
||||||
|
name=generate_id(),
|
||||||
|
internal_users=100,
|
||||||
|
external_users=100,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_no_policy(self):
|
||||||
|
"""Test request"""
|
||||||
|
License.objects.create(key=generate_id())
|
||||||
|
self.client.force_login(self.user)
|
||||||
|
response = self.client.get(
|
||||||
|
reverse(
|
||||||
|
"authentik_providers_rac:start",
|
||||||
|
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
flow_response = self.client.get(
|
||||||
|
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
|
||||||
|
)
|
||||||
|
body = loads(flow_response.content)
|
||||||
|
next_url = body["to"]
|
||||||
|
final_response = self.client.get(next_url)
|
||||||
|
self.assertEqual(final_response.status_code, 200)
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"authentik.enterprise.models.LicenseKey.validate",
|
||||||
|
MagicMock(
|
||||||
|
return_value=LicenseKey(
|
||||||
|
aud="",
|
||||||
|
exp=int(mktime((now() + timedelta(days=3000)).timetuple())),
|
||||||
|
name=generate_id(),
|
||||||
|
internal_users=100,
|
||||||
|
external_users=100,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_app_deny(self):
|
||||||
|
"""Test request (deny on app level)"""
|
||||||
|
PolicyBinding.objects.create(
|
||||||
|
target=self.app,
|
||||||
|
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
|
||||||
|
order=0,
|
||||||
|
)
|
||||||
|
License.objects.create(key=generate_id())
|
||||||
|
self.client.force_login(self.user)
|
||||||
|
response = self.client.get(
|
||||||
|
reverse(
|
||||||
|
"authentik_providers_rac:start",
|
||||||
|
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertIsInstance(response, AccessDeniedResponse)
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"authentik.enterprise.models.LicenseKey.validate",
|
||||||
|
MagicMock(
|
||||||
|
return_value=LicenseKey(
|
||||||
|
aud="",
|
||||||
|
exp=int(mktime((now() + timedelta(days=3000)).timetuple())),
|
||||||
|
name=generate_id(),
|
||||||
|
internal_users=100,
|
||||||
|
external_users=100,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_endpoint_deny(self):
|
||||||
|
"""Test request (deny on endpoint level)"""
|
||||||
|
PolicyBinding.objects.create(
|
||||||
|
target=self.endpoint,
|
||||||
|
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
|
||||||
|
order=0,
|
||||||
|
)
|
||||||
|
License.objects.create(key=generate_id())
|
||||||
|
self.client.force_login(self.user)
|
||||||
|
response = self.client.get(
|
||||||
|
reverse(
|
||||||
|
"authentik_providers_rac:start",
|
||||||
|
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
flow_response = self.client.get(
|
||||||
|
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
|
||||||
|
)
|
||||||
|
body = loads(flow_response.content)
|
||||||
|
self.assertEqual(body["component"], "ak-stage-access-denied")
|
47
authentik/enterprise/providers/rac/urls.py
Normal file
47
authentik/enterprise/providers/rac/urls.py
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
"""rac urls"""
|
||||||
|
from channels.auth import AuthMiddleware
|
||||||
|
from channels.sessions import CookieMiddleware
|
||||||
|
from django.urls import path
|
||||||
|
from django.views.decorators.csrf import ensure_csrf_cookie
|
||||||
|
|
||||||
|
from authentik.core.channels import TokenOutpostMiddleware
|
||||||
|
from authentik.enterprise.providers.rac.api.endpoints import EndpointViewSet
|
||||||
|
from authentik.enterprise.providers.rac.api.property_mappings import RACPropertyMappingViewSet
|
||||||
|
from authentik.enterprise.providers.rac.api.providers import RACProviderViewSet
|
||||||
|
from authentik.enterprise.providers.rac.consumer_client import RACClientConsumer
|
||||||
|
from authentik.enterprise.providers.rac.consumer_outpost import RACOutpostConsumer
|
||||||
|
from authentik.enterprise.providers.rac.views import RACInterface, RACStartView
|
||||||
|
from authentik.root.asgi_middleware import SessionMiddleware
|
||||||
|
from authentik.root.middleware import ChannelsLoggingMiddleware
|
||||||
|
|
||||||
|
urlpatterns = [
|
||||||
|
path(
|
||||||
|
"application/rac/<slug:app>/<uuid:endpoint>/",
|
||||||
|
ensure_csrf_cookie(RACStartView.as_view()),
|
||||||
|
name="start",
|
||||||
|
),
|
||||||
|
path(
|
||||||
|
"if/rac/<str:token>/",
|
||||||
|
ensure_csrf_cookie(RACInterface.as_view()),
|
||||||
|
name="if-rac",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
websocket_urlpatterns = [
|
||||||
|
path(
|
||||||
|
"ws/rac/<str:token>/",
|
||||||
|
ChannelsLoggingMiddleware(
|
||||||
|
CookieMiddleware(SessionMiddleware(AuthMiddleware(RACClientConsumer.as_asgi())))
|
||||||
|
),
|
||||||
|
),
|
||||||
|
path(
|
||||||
|
"ws/outpost_rac/<str:channel>/",
|
||||||
|
ChannelsLoggingMiddleware(TokenOutpostMiddleware(RACOutpostConsumer.as_asgi())),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
api_urlpatterns = [
|
||||||
|
("providers/rac", RACProviderViewSet),
|
||||||
|
("propertymappings/rac", RACPropertyMappingViewSet),
|
||||||
|
("rac/endpoints", EndpointViewSet),
|
||||||
|
]
|
140
authentik/enterprise/providers/rac/views.py
Normal file
140
authentik/enterprise/providers/rac/views.py
Normal file
|
@ -0,0 +1,140 @@
|
||||||
|
"""RAC Views"""
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from django.http import Http404, HttpRequest, HttpResponse
|
||||||
|
from django.shortcuts import get_object_or_404, redirect
|
||||||
|
from django.urls import reverse
|
||||||
|
from django.utils.timezone import now
|
||||||
|
from django.utils.translation import gettext as _
|
||||||
|
|
||||||
|
from authentik.core.models import Application, AuthenticatedSession
|
||||||
|
from authentik.core.views.interface import InterfaceView
|
||||||
|
from authentik.enterprise.policy import EnterprisePolicyAccessView
|
||||||
|
from authentik.enterprise.providers.rac.models import ConnectionToken, Endpoint, RACProvider
|
||||||
|
from authentik.events.models import Event, EventAction
|
||||||
|
from authentik.flows.challenge import RedirectChallenge
|
||||||
|
from authentik.flows.exceptions import FlowNonApplicableException
|
||||||
|
from authentik.flows.models import in_memory_stage
|
||||||
|
from authentik.flows.planner import FlowPlanner
|
||||||
|
from authentik.flows.stage import RedirectStage
|
||||||
|
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||||
|
from authentik.lib.utils.time import timedelta_from_string
|
||||||
|
from authentik.lib.utils.urls import redirect_with_qs
|
||||||
|
from authentik.policies.engine import PolicyEngine
|
||||||
|
|
||||||
|
|
||||||
|
class RACStartView(EnterprisePolicyAccessView):
|
||||||
|
"""Start a RAC connection by checking access and creating a connection token"""
|
||||||
|
|
||||||
|
endpoint: Endpoint
|
||||||
|
|
||||||
|
def resolve_provider_application(self):
|
||||||
|
self.application = get_object_or_404(Application, slug=self.kwargs["app"])
|
||||||
|
# Endpoint permissions are validated in the RACFinalStage below
|
||||||
|
self.endpoint = get_object_or_404(Endpoint, pk=self.kwargs["endpoint"])
|
||||||
|
self.provider = RACProvider.objects.get(application=self.application)
|
||||||
|
|
||||||
|
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||||
|
"""Start flow planner for RAC provider"""
|
||||||
|
planner = FlowPlanner(self.provider.authorization_flow)
|
||||||
|
planner.allow_empty_flows = True
|
||||||
|
try:
|
||||||
|
plan = planner.plan(self.request)
|
||||||
|
except FlowNonApplicableException:
|
||||||
|
raise Http404
|
||||||
|
plan.insert_stage(
|
||||||
|
in_memory_stage(
|
||||||
|
RACFinalStage,
|
||||||
|
application=self.application,
|
||||||
|
endpoint=self.endpoint,
|
||||||
|
provider=self.provider,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
request.session[SESSION_KEY_PLAN] = plan
|
||||||
|
return redirect_with_qs(
|
||||||
|
"authentik_core:if-flow",
|
||||||
|
request.GET,
|
||||||
|
flow_slug=self.provider.authorization_flow.slug,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RACInterface(InterfaceView):
|
||||||
|
"""Start RAC connection"""
|
||||||
|
|
||||||
|
template_name = "if/rac.html"
|
||||||
|
token: ConnectionToken
|
||||||
|
|
||||||
|
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
|
||||||
|
# Early sanity check to ensure token still exists
|
||||||
|
token = ConnectionToken.filter_not_expired(token=self.kwargs["token"]).first()
|
||||||
|
if not token:
|
||||||
|
return redirect("authentik_core:if-user")
|
||||||
|
self.token = token
|
||||||
|
return super().dispatch(request, *args, **kwargs)
|
||||||
|
|
||||||
|
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
|
||||||
|
kwargs["token"] = self.token
|
||||||
|
return super().get_context_data(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class RACFinalStage(RedirectStage):
|
||||||
|
"""RAC Connection final stage, set the connection token in the stage"""
|
||||||
|
|
||||||
|
endpoint: Endpoint
|
||||||
|
provider: RACProvider
|
||||||
|
application: Application
|
||||||
|
|
||||||
|
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
|
||||||
|
self.endpoint = self.executor.current_stage.endpoint
|
||||||
|
self.provider = self.executor.current_stage.provider
|
||||||
|
self.application = self.executor.current_stage.application
|
||||||
|
# Check policies bound to endpoint directly
|
||||||
|
engine = PolicyEngine(self.endpoint, self.request.user, self.request)
|
||||||
|
engine.use_cache = False
|
||||||
|
engine.build()
|
||||||
|
passing = engine.result
|
||||||
|
if not passing.passing:
|
||||||
|
return self.executor.stage_invalid(", ".join(passing.messages))
|
||||||
|
# Check if we're already at the maximum connection limit
|
||||||
|
all_tokens = ConnectionToken.filter_not_expired(
|
||||||
|
endpoint=self.endpoint,
|
||||||
|
).exclude(endpoint__maximum_connections__lte=-1)
|
||||||
|
if all_tokens.count() >= self.endpoint.maximum_connections:
|
||||||
|
msg = [_("Maximum connection limit reached.")]
|
||||||
|
# Check if any other tokens exist for the current user, and inform them
|
||||||
|
# they are already connected
|
||||||
|
if all_tokens.filter(session__user=self.request.user).exists():
|
||||||
|
msg.append(_("(You are already connected in another tab/window)"))
|
||||||
|
return self.executor.stage_invalid(" ".join(msg))
|
||||||
|
return super().dispatch(request, *args, **kwargs)
|
||||||
|
|
||||||
|
def get_challenge(self, *args, **kwargs) -> RedirectChallenge:
|
||||||
|
token = ConnectionToken.objects.create(
|
||||||
|
provider=self.provider,
|
||||||
|
endpoint=self.endpoint,
|
||||||
|
settings=self.executor.plan.context.get("connection_settings", {}),
|
||||||
|
session=AuthenticatedSession.objects.filter(
|
||||||
|
session_key=self.request.session.session_key
|
||||||
|
).first(),
|
||||||
|
expires=now() + timedelta_from_string(self.provider.connection_expiry),
|
||||||
|
expiring=True,
|
||||||
|
)
|
||||||
|
Event.new(
|
||||||
|
EventAction.AUTHORIZE_APPLICATION,
|
||||||
|
authorized_application=self.application,
|
||||||
|
flow=self.executor.plan.flow_pk,
|
||||||
|
endpoint=self.endpoint.name,
|
||||||
|
).from_http(self.request)
|
||||||
|
setattr(
|
||||||
|
self.executor.current_stage,
|
||||||
|
"destination",
|
||||||
|
self.request.build_absolute_uri(
|
||||||
|
reverse(
|
||||||
|
"authentik_providers_rac:if-rac",
|
||||||
|
kwargs={
|
||||||
|
"token": str(token.token),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return super().get_challenge(*args, **kwargs)
|
|
@ -10,3 +10,7 @@ CELERY_BEAT_SCHEDULE = {
|
||||||
"options": {"queue": "authentik_scheduled"},
|
"options": {"queue": "authentik_scheduled"},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
INSTALLED_APPS = [
|
||||||
|
"authentik.enterprise.providers.rac",
|
||||||
|
]
|
||||||
|
|
|
@ -6,6 +6,7 @@ import django_filters
|
||||||
from django.db.models.aggregates import Count
|
from django.db.models.aggregates import Count
|
||||||
from django.db.models.fields.json import KeyTextTransform, KeyTransform
|
from django.db.models.fields.json import KeyTextTransform, KeyTransform
|
||||||
from django.db.models.functions import ExtractDay, ExtractHour
|
from django.db.models.functions import ExtractDay, ExtractHour
|
||||||
|
from django.db.models.query_utils import Q
|
||||||
from drf_spectacular.types import OpenApiTypes
|
from drf_spectacular.types import OpenApiTypes
|
||||||
from drf_spectacular.utils import OpenApiParameter, extend_schema
|
from drf_spectacular.utils import OpenApiParameter, extend_schema
|
||||||
from guardian.shortcuts import get_objects_for_user
|
from guardian.shortcuts import get_objects_for_user
|
||||||
|
@ -87,7 +88,12 @@ class EventsFilter(django_filters.FilterSet):
|
||||||
we need to remove the dashes that a client may send. We can't use a
|
we need to remove the dashes that a client may send. We can't use a
|
||||||
UUIDField for this, as some models might not have a UUID PK"""
|
UUIDField for this, as some models might not have a UUID PK"""
|
||||||
value = str(value).replace("-", "")
|
value = str(value).replace("-", "")
|
||||||
return queryset.filter(context__model__pk=value)
|
query = Q(context__model__pk=value)
|
||||||
|
try:
|
||||||
|
query |= Q(context__model__pk=int(value))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return queryset.filter(query)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = Event
|
model = Event
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
from prometheus_client import Gauge
|
from prometheus_client import Gauge
|
||||||
|
|
||||||
from authentik.blueprints.apps import ManagedAppConfig
|
from authentik.blueprints.apps import ManagedAppConfig
|
||||||
|
from authentik.lib.config import CONFIG, ENV_PREFIX
|
||||||
|
|
||||||
GAUGE_TASKS = Gauge(
|
GAUGE_TASKS = Gauge(
|
||||||
"authentik_system_tasks",
|
"authentik_system_tasks",
|
||||||
|
@ -21,3 +22,24 @@ class AuthentikEventsConfig(ManagedAppConfig):
|
||||||
def reconcile_load_events_signals(self):
|
def reconcile_load_events_signals(self):
|
||||||
"""Load events signals"""
|
"""Load events signals"""
|
||||||
self.import_module("authentik.events.signals")
|
self.import_module("authentik.events.signals")
|
||||||
|
|
||||||
|
def reconcile_check_deprecations(self):
|
||||||
|
"""Check for config deprecations"""
|
||||||
|
from authentik.events.models import Event, EventAction
|
||||||
|
|
||||||
|
for key_replace, msg in CONFIG.deprecations.items():
|
||||||
|
key, replace = key_replace
|
||||||
|
key_env = f"{ENV_PREFIX}_{key.replace('.', '__')}".upper()
|
||||||
|
replace_env = f"{ENV_PREFIX}_{replace.replace('.', '__')}".upper()
|
||||||
|
if Event.objects.filter(
|
||||||
|
action=EventAction.CONFIGURATION_ERROR, context__deprecated_option=key
|
||||||
|
).exists():
|
||||||
|
continue
|
||||||
|
Event.new(
|
||||||
|
EventAction.CONFIGURATION_ERROR,
|
||||||
|
deprecated_option=key,
|
||||||
|
deprecated_env=key_env,
|
||||||
|
replacement_option=replace,
|
||||||
|
replacement_env=replace_env,
|
||||||
|
message=msg,
|
||||||
|
).save()
|
||||||
|
|
0
authentik/events/context_processors/__init__.py
Normal file
0
authentik/events/context_processors/__init__.py
Normal file
81
authentik/events/context_processors/asn.py
Normal file
81
authentik/events/context_processors/asn.py
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
"""ASN Enricher"""
|
||||||
|
from typing import TYPE_CHECKING, Optional, TypedDict
|
||||||
|
|
||||||
|
from django.http import HttpRequest
|
||||||
|
from geoip2.errors import GeoIP2Error
|
||||||
|
from geoip2.models import ASN
|
||||||
|
from sentry_sdk import Hub
|
||||||
|
|
||||||
|
from authentik.events.context_processors.mmdb import MMDBContextProcessor
|
||||||
|
from authentik.lib.config import CONFIG
|
||||||
|
from authentik.root.middleware import ClientIPMiddleware
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from authentik.api.v3.config import Capabilities
|
||||||
|
from authentik.events.models import Event
|
||||||
|
|
||||||
|
|
||||||
|
class ASNDict(TypedDict):
|
||||||
|
"""ASN Details"""
|
||||||
|
|
||||||
|
asn: int
|
||||||
|
as_org: str | None
|
||||||
|
network: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class ASNContextProcessor(MMDBContextProcessor):
|
||||||
|
"""ASN Database reader wrapper"""
|
||||||
|
|
||||||
|
def capability(self) -> Optional["Capabilities"]:
|
||||||
|
from authentik.api.v3.config import Capabilities
|
||||||
|
|
||||||
|
return Capabilities.CAN_ASN
|
||||||
|
|
||||||
|
def path(self) -> str | None:
|
||||||
|
return CONFIG.get("events.context_processors.asn")
|
||||||
|
|
||||||
|
def enrich_event(self, event: "Event"):
|
||||||
|
asn = self.asn_dict(event.client_ip)
|
||||||
|
if not asn:
|
||||||
|
return
|
||||||
|
event.context["asn"] = asn
|
||||||
|
|
||||||
|
def enrich_context(self, request: HttpRequest) -> dict:
|
||||||
|
return {
|
||||||
|
"asn": self.asn_dict(ClientIPMiddleware.get_client_ip(request)),
|
||||||
|
}
|
||||||
|
|
||||||
|
def asn(self, ip_address: str) -> Optional[ASN]:
|
||||||
|
"""Wrapper for Reader.asn"""
|
||||||
|
with Hub.current.start_span(
|
||||||
|
op="authentik.events.asn.asn",
|
||||||
|
description=ip_address,
|
||||||
|
):
|
||||||
|
if not self.configured():
|
||||||
|
return None
|
||||||
|
self.check_expired()
|
||||||
|
try:
|
||||||
|
return self.reader.asn(ip_address)
|
||||||
|
except (GeoIP2Error, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def asn_to_dict(self, asn: ASN | None) -> ASNDict:
|
||||||
|
"""Convert ASN to dict"""
|
||||||
|
if not asn:
|
||||||
|
return {}
|
||||||
|
asn_dict: ASNDict = {
|
||||||
|
"asn": asn.autonomous_system_number,
|
||||||
|
"as_org": asn.autonomous_system_organization,
|
||||||
|
"network": str(asn.network) if asn.network else None,
|
||||||
|
}
|
||||||
|
return asn_dict
|
||||||
|
|
||||||
|
def asn_dict(self, ip_address: str) -> Optional[ASNDict]:
|
||||||
|
"""Wrapper for self.asn that returns a dict"""
|
||||||
|
asn = self.asn(ip_address)
|
||||||
|
if not asn:
|
||||||
|
return None
|
||||||
|
return self.asn_to_dict(asn)
|
||||||
|
|
||||||
|
|
||||||
|
ASN_CONTEXT_PROCESSOR = ASNContextProcessor()
|
43
authentik/events/context_processors/base.py
Normal file
43
authentik/events/context_processors/base.py
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
"""Base event enricher"""
|
||||||
|
from functools import cache
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
from django.http import HttpRequest
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from authentik.api.v3.config import Capabilities
|
||||||
|
from authentik.events.models import Event
|
||||||
|
|
||||||
|
|
||||||
|
class EventContextProcessor:
|
||||||
|
"""Base event enricher"""
|
||||||
|
|
||||||
|
def capability(self) -> Optional["Capabilities"]:
|
||||||
|
"""Return the capability this context processor provides"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def configured(self) -> bool:
|
||||||
|
"""Return true if this context processor is configured"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
def enrich_event(self, event: "Event"):
|
||||||
|
"""Modify event"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def enrich_context(self, request: HttpRequest) -> dict:
|
||||||
|
"""Modify context"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def get_context_processors() -> list[EventContextProcessor]:
|
||||||
|
"""Get a list of all configured context processors"""
|
||||||
|
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
|
||||||
|
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
|
||||||
|
|
||||||
|
processors_types = [ASN_CONTEXT_PROCESSOR, GEOIP_CONTEXT_PROCESSOR]
|
||||||
|
processors = []
|
||||||
|
for _type in processors_types:
|
||||||
|
if _type.configured():
|
||||||
|
processors.append(_type)
|
||||||
|
return processors
|
86
authentik/events/context_processors/geoip.py
Normal file
86
authentik/events/context_processors/geoip.py
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
"""events GeoIP Reader"""
|
||||||
|
from typing import TYPE_CHECKING, Optional, TypedDict
|
||||||
|
|
||||||
|
from django.http import HttpRequest
|
||||||
|
from geoip2.errors import GeoIP2Error
|
||||||
|
from geoip2.models import City
|
||||||
|
from sentry_sdk.hub import Hub
|
||||||
|
|
||||||
|
from authentik.events.context_processors.mmdb import MMDBContextProcessor
|
||||||
|
from authentik.lib.config import CONFIG
|
||||||
|
from authentik.root.middleware import ClientIPMiddleware
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from authentik.api.v3.config import Capabilities
|
||||||
|
from authentik.events.models import Event
|
||||||
|
|
||||||
|
|
||||||
|
class GeoIPDict(TypedDict):
|
||||||
|
"""GeoIP Details"""
|
||||||
|
|
||||||
|
continent: str
|
||||||
|
country: str
|
||||||
|
lat: float
|
||||||
|
long: float
|
||||||
|
city: str
|
||||||
|
|
||||||
|
|
||||||
|
class GeoIPContextProcessor(MMDBContextProcessor):
|
||||||
|
"""Slim wrapper around GeoIP API"""
|
||||||
|
|
||||||
|
def capability(self) -> Optional["Capabilities"]:
|
||||||
|
from authentik.api.v3.config import Capabilities
|
||||||
|
|
||||||
|
return Capabilities.CAN_GEO_IP
|
||||||
|
|
||||||
|
def path(self) -> str | None:
|
||||||
|
return CONFIG.get("events.context_processors.geoip")
|
||||||
|
|
||||||
|
def enrich_event(self, event: "Event"):
|
||||||
|
city = self.city_dict(event.client_ip)
|
||||||
|
if not city:
|
||||||
|
return
|
||||||
|
event.context["geo"] = city
|
||||||
|
|
||||||
|
def enrich_context(self, request: HttpRequest) -> dict:
|
||||||
|
# Different key `geoip` vs `geo` for legacy reasons
|
||||||
|
return {"geoip": self.city(ClientIPMiddleware.get_client_ip(request))}
|
||||||
|
|
||||||
|
def city(self, ip_address: str) -> Optional[City]:
|
||||||
|
"""Wrapper for Reader.city"""
|
||||||
|
with Hub.current.start_span(
|
||||||
|
op="authentik.events.geo.city",
|
||||||
|
description=ip_address,
|
||||||
|
):
|
||||||
|
if not self.configured():
|
||||||
|
return None
|
||||||
|
self.check_expired()
|
||||||
|
try:
|
||||||
|
return self.reader.city(ip_address)
|
||||||
|
except (GeoIP2Error, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def city_to_dict(self, city: City | None) -> GeoIPDict:
|
||||||
|
"""Convert City to dict"""
|
||||||
|
if not city:
|
||||||
|
return {}
|
||||||
|
city_dict: GeoIPDict = {
|
||||||
|
"continent": city.continent.code,
|
||||||
|
"country": city.country.iso_code,
|
||||||
|
"lat": city.location.latitude,
|
||||||
|
"long": city.location.longitude,
|
||||||
|
"city": "",
|
||||||
|
}
|
||||||
|
if city.city.name:
|
||||||
|
city_dict["city"] = city.city.name
|
||||||
|
return city_dict
|
||||||
|
|
||||||
|
def city_dict(self, ip_address: str) -> Optional[GeoIPDict]:
|
||||||
|
"""Wrapper for self.city that returns a dict"""
|
||||||
|
city = self.city(ip_address)
|
||||||
|
if not city:
|
||||||
|
return None
|
||||||
|
return self.city_to_dict(city)
|
||||||
|
|
||||||
|
|
||||||
|
GEOIP_CONTEXT_PROCESSOR = GeoIPContextProcessor()
|
53
authentik/events/context_processors/mmdb.py
Normal file
53
authentik/events/context_processors/mmdb.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
"""Common logic for reading MMDB files"""
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from geoip2.database import Reader
|
||||||
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
|
from authentik.events.context_processors.base import EventContextProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class MMDBContextProcessor(EventContextProcessor):
|
||||||
|
"""Common logic for reading MaxMind DB files, including re-loading if the file has changed"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.reader: Optional[Reader] = None
|
||||||
|
self._last_mtime: float = 0.0
|
||||||
|
self.logger = get_logger()
|
||||||
|
self.open()
|
||||||
|
|
||||||
|
def path(self) -> str | None:
|
||||||
|
"""Get the path to the MMDB file to load"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def open(self):
|
||||||
|
"""Get GeoIP Reader, if configured, otherwise none"""
|
||||||
|
path = self.path()
|
||||||
|
if path == "" or not path:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
self.reader = Reader(path)
|
||||||
|
self._last_mtime = Path(path).stat().st_mtime
|
||||||
|
self.logger.info("Loaded MMDB database", last_write=self._last_mtime, file=path)
|
||||||
|
except OSError as exc:
|
||||||
|
self.logger.warning("Failed to load MMDB database", path=path, exc=exc)
|
||||||
|
|
||||||
|
def check_expired(self):
|
||||||
|
"""Check if the modification date of the MMDB database has
|
||||||
|
changed, and reload it if so"""
|
||||||
|
path = self.path()
|
||||||
|
if path == "" or not path:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
mtime = Path(path).stat().st_mtime
|
||||||
|
diff = self._last_mtime < mtime
|
||||||
|
if diff > 0:
|
||||||
|
self.logger.info("Found new MMDB Database, reopening", diff=diff, path=path)
|
||||||
|
self.open()
|
||||||
|
except OSError as exc:
|
||||||
|
self.logger.warning("Failed to check MMDB age", exc=exc)
|
||||||
|
|
||||||
|
def configured(self) -> bool:
|
||||||
|
"""Return true if this context processor is configured"""
|
||||||
|
return bool(self.reader)
|
|
@ -1,100 +0,0 @@
|
||||||
"""events GeoIP Reader"""
|
|
||||||
from os import stat
|
|
||||||
from typing import Optional, TypedDict
|
|
||||||
|
|
||||||
from geoip2.database import Reader
|
|
||||||
from geoip2.errors import GeoIP2Error
|
|
||||||
from geoip2.models import City
|
|
||||||
from sentry_sdk.hub import Hub
|
|
||||||
from structlog.stdlib import get_logger
|
|
||||||
|
|
||||||
from authentik.lib.config import CONFIG
|
|
||||||
|
|
||||||
LOGGER = get_logger()
|
|
||||||
|
|
||||||
|
|
||||||
class GeoIPDict(TypedDict):
|
|
||||||
"""GeoIP Details"""
|
|
||||||
|
|
||||||
continent: str
|
|
||||||
country: str
|
|
||||||
lat: float
|
|
||||||
long: float
|
|
||||||
city: str
|
|
||||||
|
|
||||||
|
|
||||||
class GeoIPReader:
|
|
||||||
"""Slim wrapper around GeoIP API"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.__reader: Optional[Reader] = None
|
|
||||||
self.__last_mtime: float = 0.0
|
|
||||||
self.__open()
|
|
||||||
|
|
||||||
def __open(self):
|
|
||||||
"""Get GeoIP Reader, if configured, otherwise none"""
|
|
||||||
path = CONFIG.get("geoip")
|
|
||||||
if path == "" or not path:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
self.__reader = Reader(path)
|
|
||||||
self.__last_mtime = stat(path).st_mtime
|
|
||||||
LOGGER.info("Loaded GeoIP database", last_write=self.__last_mtime)
|
|
||||||
except OSError as exc:
|
|
||||||
LOGGER.warning("Failed to load GeoIP database", exc=exc)
|
|
||||||
|
|
||||||
def __check_expired(self):
|
|
||||||
"""Check if the modification date of the GeoIP database has
|
|
||||||
changed, and reload it if so"""
|
|
||||||
path = CONFIG.get("geoip")
|
|
||||||
try:
|
|
||||||
mtime = stat(path).st_mtime
|
|
||||||
diff = self.__last_mtime < mtime
|
|
||||||
if diff > 0:
|
|
||||||
LOGGER.info("Found new GeoIP Database, reopening", diff=diff)
|
|
||||||
self.__open()
|
|
||||||
except OSError as exc:
|
|
||||||
LOGGER.warning("Failed to check GeoIP age", exc=exc)
|
|
||||||
return
|
|
||||||
|
|
||||||
@property
|
|
||||||
def enabled(self) -> bool:
|
|
||||||
"""Check if GeoIP is enabled"""
|
|
||||||
return bool(self.__reader)
|
|
||||||
|
|
||||||
def city(self, ip_address: str) -> Optional[City]:
|
|
||||||
"""Wrapper for Reader.city"""
|
|
||||||
with Hub.current.start_span(
|
|
||||||
op="authentik.events.geo.city",
|
|
||||||
description=ip_address,
|
|
||||||
):
|
|
||||||
if not self.enabled:
|
|
||||||
return None
|
|
||||||
self.__check_expired()
|
|
||||||
try:
|
|
||||||
return self.__reader.city(ip_address)
|
|
||||||
except (GeoIP2Error, ValueError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def city_to_dict(self, city: City) -> GeoIPDict:
|
|
||||||
"""Convert City to dict"""
|
|
||||||
city_dict: GeoIPDict = {
|
|
||||||
"continent": city.continent.code,
|
|
||||||
"country": city.country.iso_code,
|
|
||||||
"lat": city.location.latitude,
|
|
||||||
"long": city.location.longitude,
|
|
||||||
"city": "",
|
|
||||||
}
|
|
||||||
if city.city.name:
|
|
||||||
city_dict["city"] = city.city.name
|
|
||||||
return city_dict
|
|
||||||
|
|
||||||
def city_dict(self, ip_address: str) -> Optional[GeoIPDict]:
|
|
||||||
"""Wrapper for self.city that returns a dict"""
|
|
||||||
city = self.city(ip_address)
|
|
||||||
if not city:
|
|
||||||
return None
|
|
||||||
return self.city_to_dict(city)
|
|
||||||
|
|
||||||
|
|
||||||
GEOIP_READER = GeoIPReader()
|
|
|
@ -20,6 +20,7 @@ from authentik.core.models import (
|
||||||
User,
|
User,
|
||||||
UserSourceConnection,
|
UserSourceConnection,
|
||||||
)
|
)
|
||||||
|
from authentik.enterprise.providers.rac.models import ConnectionToken
|
||||||
from authentik.events.models import Event, EventAction, Notification
|
from authentik.events.models import Event, EventAction, Notification
|
||||||
from authentik.events.utils import model_to_dict
|
from authentik.events.utils import model_to_dict
|
||||||
from authentik.flows.models import FlowToken, Stage
|
from authentik.flows.models import FlowToken, Stage
|
||||||
|
@ -54,6 +55,7 @@ IGNORED_MODELS = (
|
||||||
SCIMUser,
|
SCIMUser,
|
||||||
SCIMGroup,
|
SCIMGroup,
|
||||||
Reputation,
|
Reputation,
|
||||||
|
ConnectionToken,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ from authentik.core.middleware import (
|
||||||
SESSION_KEY_IMPERSONATE_USER,
|
SESSION_KEY_IMPERSONATE_USER,
|
||||||
)
|
)
|
||||||
from authentik.core.models import ExpiringModel, Group, PropertyMapping, User
|
from authentik.core.models import ExpiringModel, Group, PropertyMapping, User
|
||||||
from authentik.events.geo import GEOIP_READER
|
from authentik.events.context_processors.base import get_context_processors
|
||||||
from authentik.events.utils import (
|
from authentik.events.utils import (
|
||||||
cleanse_dict,
|
cleanse_dict,
|
||||||
get_user,
|
get_user,
|
||||||
|
@ -36,9 +36,10 @@ from authentik.events.utils import (
|
||||||
)
|
)
|
||||||
from authentik.lib.models import DomainlessURLValidator, SerializerModel
|
from authentik.lib.models import DomainlessURLValidator, SerializerModel
|
||||||
from authentik.lib.sentry import SentryIgnoredException
|
from authentik.lib.sentry import SentryIgnoredException
|
||||||
from authentik.lib.utils.http import get_client_ip, get_http_session
|
from authentik.lib.utils.http import get_http_session
|
||||||
from authentik.lib.utils.time import timedelta_from_string
|
from authentik.lib.utils.time import timedelta_from_string
|
||||||
from authentik.policies.models import PolicyBindingModel
|
from authentik.policies.models import PolicyBindingModel
|
||||||
|
from authentik.root.middleware import ClientIPMiddleware
|
||||||
from authentik.stages.email.utils import TemplateEmailMessage
|
from authentik.stages.email.utils import TemplateEmailMessage
|
||||||
from authentik.tenants.models import Tenant
|
from authentik.tenants.models import Tenant
|
||||||
from authentik.tenants.utils import DEFAULT_TENANT
|
from authentik.tenants.utils import DEFAULT_TENANT
|
||||||
|
@ -244,22 +245,16 @@ class Event(SerializerModel, ExpiringModel):
|
||||||
self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER])
|
self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER])
|
||||||
self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER])
|
self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER])
|
||||||
# User 255.255.255.255 as fallback if IP cannot be determined
|
# User 255.255.255.255 as fallback if IP cannot be determined
|
||||||
self.client_ip = get_client_ip(request)
|
self.client_ip = ClientIPMiddleware.get_client_ip(request)
|
||||||
# Apply GeoIP Data, when enabled
|
# Enrich event data
|
||||||
self.with_geoip()
|
for processor in get_context_processors():
|
||||||
|
processor.enrich_event(self)
|
||||||
# If there's no app set, we get it from the requests too
|
# If there's no app set, we get it from the requests too
|
||||||
if not self.app:
|
if not self.app:
|
||||||
self.app = Event._get_app_from_request(request)
|
self.app = Event._get_app_from_request(request)
|
||||||
self.save()
|
self.save()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_geoip(self): # pragma: no cover
|
|
||||||
"""Apply GeoIP Data, when enabled"""
|
|
||||||
city = GEOIP_READER.city_dict(self.client_ip)
|
|
||||||
if not city:
|
|
||||||
return
|
|
||||||
self.context["geo"] = city
|
|
||||||
|
|
||||||
def save(self, *args, **kwargs):
|
def save(self, *args, **kwargs):
|
||||||
if self._state.adding:
|
if self._state.adding:
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
|
@ -466,7 +461,7 @@ class NotificationTransport(SerializerModel):
|
||||||
}
|
}
|
||||||
mail = TemplateEmailMessage(
|
mail = TemplateEmailMessage(
|
||||||
subject=subject_prefix + context["title"],
|
subject=subject_prefix + context["title"],
|
||||||
to=[notification.user.email],
|
to=[f"{notification.user.name} <{notification.user.email}>"],
|
||||||
language=notification.user.locale(),
|
language=notification.user.locale(),
|
||||||
template_name="email/event_notification.html",
|
template_name="email/event_notification.html",
|
||||||
template_context=context,
|
template_context=context,
|
||||||
|
|
|
@ -45,9 +45,14 @@ def get_login_event(request: HttpRequest) -> Optional[Event]:
|
||||||
|
|
||||||
|
|
||||||
@receiver(user_logged_out)
|
@receiver(user_logged_out)
|
||||||
def on_user_logged_out(sender, request: HttpRequest, user: User, **_):
|
def on_user_logged_out(sender, request: HttpRequest, user: User, **kwargs):
|
||||||
"""Log successfully logout"""
|
"""Log successfully logout"""
|
||||||
Event.new(EventAction.LOGOUT).from_http(request, user=user)
|
# Check if this even comes from the user_login stage's middleware, which will set an extra
|
||||||
|
# argument
|
||||||
|
event = Event.new(EventAction.LOGOUT)
|
||||||
|
if "event_extra" in kwargs:
|
||||||
|
event.context.update(kwargs["event_extra"])
|
||||||
|
event.from_http(request, user=user)
|
||||||
|
|
||||||
|
|
||||||
@receiver(user_write)
|
@receiver(user_write)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Event API tests"""
|
"""Event API tests"""
|
||||||
|
from json import loads
|
||||||
|
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from rest_framework.test import APITestCase
|
from rest_framework.test import APITestCase
|
||||||
|
@ -11,6 +12,9 @@ from authentik.events.models import (
|
||||||
NotificationSeverity,
|
NotificationSeverity,
|
||||||
TransportMode,
|
TransportMode,
|
||||||
)
|
)
|
||||||
|
from authentik.events.utils import model_to_dict
|
||||||
|
from authentik.lib.generators import generate_id
|
||||||
|
from authentik.providers.oauth2.models import OAuth2Provider
|
||||||
|
|
||||||
|
|
||||||
class TestEventsAPI(APITestCase):
|
class TestEventsAPI(APITestCase):
|
||||||
|
@ -20,6 +24,25 @@ class TestEventsAPI(APITestCase):
|
||||||
self.user = create_test_admin_user()
|
self.user = create_test_admin_user()
|
||||||
self.client.force_login(self.user)
|
self.client.force_login(self.user)
|
||||||
|
|
||||||
|
def test_filter_model_pk_int(self):
|
||||||
|
"""Test event list with context_model_pk and integer PKs"""
|
||||||
|
provider = OAuth2Provider.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
)
|
||||||
|
event = Event.new(EventAction.MODEL_CREATED, model=model_to_dict(provider))
|
||||||
|
event.save()
|
||||||
|
response = self.client.get(
|
||||||
|
reverse("authentik_api:event-list"),
|
||||||
|
data={
|
||||||
|
"context_model_pk": provider.pk,
|
||||||
|
"context_model_app": "authentik_providers_oauth2",
|
||||||
|
"context_model_name": "oauth2provider",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
body = loads(response.content)
|
||||||
|
self.assertEqual(body["pagination"]["count"], 1)
|
||||||
|
|
||||||
def test_top_n(self):
|
def test_top_n(self):
|
||||||
"""Test top_per_user"""
|
"""Test top_per_user"""
|
||||||
event = Event.new(EventAction.AUTHORIZE_APPLICATION)
|
event = Event.new(EventAction.AUTHORIZE_APPLICATION)
|
||||||
|
|
24
authentik/events/tests/test_enrich_asn.py
Normal file
24
authentik/events/tests/test_enrich_asn.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
"""Test ASN Wrapper"""
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
from authentik.events.context_processors.asn import ASNContextProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class TestASN(TestCase):
|
||||||
|
"""Test ASN Wrapper"""
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.reader = ASNContextProcessor()
|
||||||
|
|
||||||
|
def test_simple(self):
|
||||||
|
"""Test simple asn wrapper"""
|
||||||
|
# IPs from
|
||||||
|
# https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-ASN-Test.json
|
||||||
|
self.assertEqual(
|
||||||
|
self.reader.asn_dict("1.0.0.1"),
|
||||||
|
{
|
||||||
|
"asn": 15169,
|
||||||
|
"as_org": "Google Inc.",
|
||||||
|
"network": "1.0.0.0/24",
|
||||||
|
},
|
||||||
|
)
|
|
@ -1,14 +1,14 @@
|
||||||
"""Test GeoIP Wrapper"""
|
"""Test GeoIP Wrapper"""
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from authentik.events.geo import GeoIPReader
|
from authentik.events.context_processors.geoip import GeoIPContextProcessor
|
||||||
|
|
||||||
|
|
||||||
class TestGeoIP(TestCase):
|
class TestGeoIP(TestCase):
|
||||||
"""Test GeoIP Wrapper"""
|
"""Test GeoIP Wrapper"""
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.reader = GeoIPReader()
|
self.reader = GeoIPContextProcessor()
|
||||||
|
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
"""Test simple city wrapper"""
|
"""Test simple city wrapper"""
|
|
@ -17,12 +17,13 @@ from django.db.models.base import Model
|
||||||
from django.http.request import HttpRequest
|
from django.http.request import HttpRequest
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.views.debug import SafeExceptionReporterFilter
|
from django.views.debug import SafeExceptionReporterFilter
|
||||||
from geoip2.models import City
|
from geoip2.models import ASN, City
|
||||||
from guardian.utils import get_anonymous_user
|
from guardian.utils import get_anonymous_user
|
||||||
|
|
||||||
from authentik.blueprints.v1.common import YAMLTag
|
from authentik.blueprints.v1.common import YAMLTag
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
from authentik.events.geo import GEOIP_READER
|
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
|
||||||
|
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
|
||||||
from authentik.policies.types import PolicyRequest
|
from authentik.policies.types import PolicyRequest
|
||||||
|
|
||||||
# Special keys which are *not* cleaned, even when the default filter
|
# Special keys which are *not* cleaned, even when the default filter
|
||||||
|
@ -123,7 +124,9 @@ def sanitize_item(value: Any) -> Any:
|
||||||
if isinstance(value, (HttpRequest, WSGIRequest)):
|
if isinstance(value, (HttpRequest, WSGIRequest)):
|
||||||
return ...
|
return ...
|
||||||
if isinstance(value, City):
|
if isinstance(value, City):
|
||||||
return GEOIP_READER.city_to_dict(value)
|
return GEOIP_CONTEXT_PROCESSOR.city_to_dict(value)
|
||||||
|
if isinstance(value, ASN):
|
||||||
|
return ASN_CONTEXT_PROCESSOR.asn_to_dict(value)
|
||||||
if isinstance(value, Path):
|
if isinstance(value, Path):
|
||||||
return str(value)
|
return str(value)
|
||||||
if isinstance(value, Exception):
|
if isinstance(value, Exception):
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# Generated by Django 4.2.6 on 2023-10-28 14:24
|
# Generated by Django 4.2.6 on 2023-10-28 14:24
|
||||||
|
|
||||||
from django.apps.registry import Apps
|
from django.apps.registry import Apps
|
||||||
from django.db import migrations
|
from django.db import migrations, models
|
||||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,4 +31,19 @@ class Migration(migrations.Migration):
|
||||||
|
|
||||||
operations = [
|
operations = [
|
||||||
migrations.RunPython(set_oobe_flow_authentication),
|
migrations.RunPython(set_oobe_flow_authentication),
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="flow",
|
||||||
|
name="authentication",
|
||||||
|
field=models.TextField(
|
||||||
|
choices=[
|
||||||
|
("none", "None"),
|
||||||
|
("require_authenticated", "Require Authenticated"),
|
||||||
|
("require_unauthenticated", "Require Unauthenticated"),
|
||||||
|
("require_superuser", "Require Superuser"),
|
||||||
|
("require_outpost", "Require Outpost"),
|
||||||
|
],
|
||||||
|
default="none",
|
||||||
|
help_text="Required level of authentication and authorization to access a flow.",
|
||||||
|
),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -31,6 +31,7 @@ class FlowAuthenticationRequirement(models.TextChoices):
|
||||||
REQUIRE_AUTHENTICATED = "require_authenticated"
|
REQUIRE_AUTHENTICATED = "require_authenticated"
|
||||||
REQUIRE_UNAUTHENTICATED = "require_unauthenticated"
|
REQUIRE_UNAUTHENTICATED = "require_unauthenticated"
|
||||||
REQUIRE_SUPERUSER = "require_superuser"
|
REQUIRE_SUPERUSER = "require_superuser"
|
||||||
|
REQUIRE_OUTPOST = "require_outpost"
|
||||||
|
|
||||||
|
|
||||||
class NotConfiguredAction(models.TextChoices):
|
class NotConfiguredAction(models.TextChoices):
|
||||||
|
|
|
@ -23,6 +23,7 @@ from authentik.flows.models import (
|
||||||
)
|
)
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.policies.engine import PolicyEngine
|
from authentik.policies.engine import PolicyEngine
|
||||||
|
from authentik.root.middleware import ClientIPMiddleware
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
PLAN_CONTEXT_PENDING_USER = "pending_user"
|
PLAN_CONTEXT_PENDING_USER = "pending_user"
|
||||||
|
@ -141,6 +142,10 @@ class FlowPlanner:
|
||||||
and not request.user.is_superuser
|
and not request.user.is_superuser
|
||||||
):
|
):
|
||||||
raise FlowNonApplicableException()
|
raise FlowNonApplicableException()
|
||||||
|
if self.flow.authentication == FlowAuthenticationRequirement.REQUIRE_OUTPOST:
|
||||||
|
outpost_user = ClientIPMiddleware.get_outpost_user(request)
|
||||||
|
if not outpost_user:
|
||||||
|
raise FlowNonApplicableException()
|
||||||
|
|
||||||
def plan(
|
def plan(
|
||||||
self, request: HttpRequest, default_context: Optional[dict[str, Any]] = None
|
self, request: HttpRequest, default_context: Optional[dict[str, Any]] = None
|
||||||
|
|
|
@ -8,6 +8,7 @@ from django.test import RequestFactory, TestCase
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from guardian.shortcuts import get_anonymous_user
|
from guardian.shortcuts import get_anonymous_user
|
||||||
|
|
||||||
|
from authentik.blueprints.tests import reconcile_app
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||||
from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException
|
from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException
|
||||||
|
@ -15,9 +16,12 @@ from authentik.flows.markers import ReevaluateMarker, StageMarker
|
||||||
from authentik.flows.models import FlowAuthenticationRequirement, FlowDesignation, FlowStageBinding
|
from authentik.flows.models import FlowAuthenticationRequirement, FlowDesignation, FlowStageBinding
|
||||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner, cache_key
|
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner, cache_key
|
||||||
from authentik.lib.tests.utils import dummy_get_response
|
from authentik.lib.tests.utils import dummy_get_response
|
||||||
|
from authentik.outposts.apps import MANAGED_OUTPOST
|
||||||
|
from authentik.outposts.models import Outpost
|
||||||
from authentik.policies.dummy.models import DummyPolicy
|
from authentik.policies.dummy.models import DummyPolicy
|
||||||
from authentik.policies.models import PolicyBinding
|
from authentik.policies.models import PolicyBinding
|
||||||
from authentik.policies.types import PolicyResult
|
from authentik.policies.types import PolicyResult
|
||||||
|
from authentik.root.middleware import ClientIPMiddleware
|
||||||
from authentik.stages.dummy.models import DummyStage
|
from authentik.stages.dummy.models import DummyStage
|
||||||
|
|
||||||
POLICY_RETURN_FALSE = PropertyMock(return_value=PolicyResult(False))
|
POLICY_RETURN_FALSE = PropertyMock(return_value=PolicyResult(False))
|
||||||
|
@ -68,6 +72,34 @@ class TestFlowPlanner(TestCase):
|
||||||
planner.allow_empty_flows = True
|
planner.allow_empty_flows = True
|
||||||
planner.plan(request)
|
planner.plan(request)
|
||||||
|
|
||||||
|
@reconcile_app("authentik_outposts")
|
||||||
|
def test_authentication_outpost(self):
|
||||||
|
"""Test flow authentication (outpost)"""
|
||||||
|
flow = create_test_flow()
|
||||||
|
flow.authentication = FlowAuthenticationRequirement.REQUIRE_OUTPOST
|
||||||
|
request = self.request_factory.get(
|
||||||
|
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||||
|
)
|
||||||
|
request.user = AnonymousUser()
|
||||||
|
with self.assertRaises(FlowNonApplicableException):
|
||||||
|
planner = FlowPlanner(flow)
|
||||||
|
planner.allow_empty_flows = True
|
||||||
|
planner.plan(request)
|
||||||
|
|
||||||
|
outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
|
||||||
|
request = self.request_factory.get(
|
||||||
|
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||||
|
HTTP_X_AUTHENTIK_OUTPOST_TOKEN=outpost.token.key,
|
||||||
|
HTTP_X_AUTHENTIK_REMOTE_IP="1.2.3.4",
|
||||||
|
)
|
||||||
|
request.user = AnonymousUser()
|
||||||
|
middleware = ClientIPMiddleware(dummy_get_response)
|
||||||
|
middleware(request)
|
||||||
|
|
||||||
|
planner = FlowPlanner(flow)
|
||||||
|
planner.allow_empty_flows = True
|
||||||
|
planner.plan(request)
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"authentik.policies.engine.PolicyEngine.result",
|
"authentik.policies.engine.PolicyEngine.result",
|
||||||
POLICY_RETURN_FALSE,
|
POLICY_RETURN_FALSE,
|
||||||
|
|
|
@ -35,6 +35,7 @@ REDIS_ENV_KEYS = [
|
||||||
]
|
]
|
||||||
|
|
||||||
DEPRECATIONS = {
|
DEPRECATIONS = {
|
||||||
|
"geoip": "events.context_processors.geoip",
|
||||||
"redis.broker_url": "broker.url",
|
"redis.broker_url": "broker.url",
|
||||||
"redis.broker_transport_options": "broker.transport_options",
|
"redis.broker_transport_options": "broker.transport_options",
|
||||||
"redis.cache_timeout": "cache.timeout",
|
"redis.cache_timeout": "cache.timeout",
|
||||||
|
@ -112,6 +113,8 @@ class ConfigLoader:
|
||||||
|
|
||||||
A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host"""
|
A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host"""
|
||||||
|
|
||||||
|
deprecations: dict[tuple[str, str], str] = {}
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.__config = {}
|
self.__config = {}
|
||||||
|
@ -138,9 +141,9 @@ class ConfigLoader:
|
||||||
self.update_from_file(env_file)
|
self.update_from_file(env_file)
|
||||||
self.update_from_env()
|
self.update_from_env()
|
||||||
self.update(self.__config, kwargs)
|
self.update(self.__config, kwargs)
|
||||||
self.check_deprecations()
|
self.deprecations = self.check_deprecations()
|
||||||
|
|
||||||
def check_deprecations(self):
|
def check_deprecations(self) -> dict[str, str]:
|
||||||
"""Warn if any deprecated configuration options are used"""
|
"""Warn if any deprecated configuration options are used"""
|
||||||
|
|
||||||
def _pop_deprecated_key(current_obj, dot_parts, index):
|
def _pop_deprecated_key(current_obj, dot_parts, index):
|
||||||
|
@ -153,25 +156,23 @@ class ConfigLoader:
|
||||||
current_obj.pop(dot_part)
|
current_obj.pop(dot_part)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
deprecation_replacements = {}
|
||||||
for deprecation, replacement in DEPRECATIONS.items():
|
for deprecation, replacement in DEPRECATIONS.items():
|
||||||
if self.get(deprecation, default=UNSET) is not UNSET:
|
if self.get(deprecation, default=UNSET) is UNSET:
|
||||||
message = (
|
continue
|
||||||
f"'{deprecation}' has been deprecated in favor of '{replacement}'! "
|
message = (
|
||||||
+ "Please update your configuration."
|
f"'{deprecation}' has been deprecated in favor of '{replacement}'! "
|
||||||
)
|
+ "Please update your configuration."
|
||||||
self.log(
|
)
|
||||||
"warning",
|
self.log(
|
||||||
message,
|
"warning",
|
||||||
)
|
message,
|
||||||
try:
|
)
|
||||||
from authentik.events.models import Event, EventAction
|
deprecation_replacements[(deprecation, replacement)] = message
|
||||||
|
|
||||||
Event.new(EventAction.CONFIGURATION_ERROR, message=message).save()
|
deprecated_attr = _pop_deprecated_key(self.__config, deprecation.split("."), 0)
|
||||||
except ImportError:
|
self.set(replacement, deprecated_attr)
|
||||||
continue
|
return deprecation_replacements
|
||||||
|
|
||||||
deprecated_attr = _pop_deprecated_key(self.__config, deprecation.split("."), 0)
|
|
||||||
self.set(replacement, deprecated_attr.value)
|
|
||||||
|
|
||||||
def log(self, level: str, message: str, **kwargs):
|
def log(self, level: str, message: str, **kwargs):
|
||||||
"""Custom Log method, we want to ensure ConfigLoader always logs JSON even when
|
"""Custom Log method, we want to ensure ConfigLoader always logs JSON even when
|
||||||
|
@ -317,7 +318,9 @@ class ConfigLoader:
|
||||||
|
|
||||||
def set(self, path: str, value: Any, sep="."):
|
def set(self, path: str, value: Any, sep="."):
|
||||||
"""Set value using same syntax as get()"""
|
"""Set value using same syntax as get()"""
|
||||||
set_path_in_dict(self.raw, path, Attr(value), sep=sep)
|
if not isinstance(value, Attr):
|
||||||
|
value = Attr(value)
|
||||||
|
set_path_in_dict(self.raw, path, value, sep=sep)
|
||||||
|
|
||||||
|
|
||||||
CONFIG = ConfigLoader()
|
CONFIG = ConfigLoader()
|
||||||
|
|
|
@ -8,6 +8,8 @@ postgresql:
|
||||||
password: "env://POSTGRES_PASSWORD"
|
password: "env://POSTGRES_PASSWORD"
|
||||||
use_pgbouncer: false
|
use_pgbouncer: false
|
||||||
use_pgpool: false
|
use_pgpool: false
|
||||||
|
test:
|
||||||
|
name: test_authentik
|
||||||
|
|
||||||
listen:
|
listen:
|
||||||
listen_http: 0.0.0.0:9000
|
listen_http: 0.0.0.0:9000
|
||||||
|
@ -106,7 +108,10 @@ cookie_domain: null
|
||||||
disable_update_check: false
|
disable_update_check: false
|
||||||
disable_startup_analytics: false
|
disable_startup_analytics: false
|
||||||
avatars: env://AUTHENTIK_AUTHENTIK__AVATARS?gravatar,initials
|
avatars: env://AUTHENTIK_AUTHENTIK__AVATARS?gravatar,initials
|
||||||
geoip: "/geoip/GeoLite2-City.mmdb"
|
events:
|
||||||
|
context_processors:
|
||||||
|
geoip: "/geoip/GeoLite2-City.mmdb"
|
||||||
|
asn: "/geoip/GeoLite2-ASN.mmdb"
|
||||||
|
|
||||||
footer_links: []
|
footer_links: []
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,8 @@ from django.test import RequestFactory, TestCase
|
||||||
|
|
||||||
from authentik.core.models import Token, TokenIntents, UserTypes
|
from authentik.core.models import Token, TokenIntents, UserTypes
|
||||||
from authentik.core.tests.utils import create_test_admin_user
|
from authentik.core.tests.utils import create_test_admin_user
|
||||||
from authentik.lib.utils.http import OUTPOST_REMOTE_IP_HEADER, OUTPOST_TOKEN_HEADER, get_client_ip
|
|
||||||
from authentik.lib.views import bad_request_message
|
from authentik.lib.views import bad_request_message
|
||||||
|
from authentik.root.middleware import ClientIPMiddleware
|
||||||
|
|
||||||
|
|
||||||
class TestHTTP(TestCase):
|
class TestHTTP(TestCase):
|
||||||
|
@ -22,12 +22,12 @@ class TestHTTP(TestCase):
|
||||||
def test_normal(self):
|
def test_normal(self):
|
||||||
"""Test normal request"""
|
"""Test normal request"""
|
||||||
request = self.factory.get("/")
|
request = self.factory.get("/")
|
||||||
self.assertEqual(get_client_ip(request), "127.0.0.1")
|
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
|
||||||
|
|
||||||
def test_forward_for(self):
|
def test_forward_for(self):
|
||||||
"""Test x-forwarded-for request"""
|
"""Test x-forwarded-for request"""
|
||||||
request = self.factory.get("/", HTTP_X_FORWARDED_FOR="127.0.0.2")
|
request = self.factory.get("/", HTTP_X_FORWARDED_FOR="127.0.0.2")
|
||||||
self.assertEqual(get_client_ip(request), "127.0.0.2")
|
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.2")
|
||||||
|
|
||||||
def test_fake_outpost(self):
|
def test_fake_outpost(self):
|
||||||
"""Test faked IP which is overridden by an outpost"""
|
"""Test faked IP which is overridden by an outpost"""
|
||||||
|
@ -38,28 +38,28 @@ class TestHTTP(TestCase):
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
**{
|
**{
|
||||||
OUTPOST_REMOTE_IP_HEADER: "1.2.3.4",
|
ClientIPMiddleware.outpost_remote_ip_header: "1.2.3.4",
|
||||||
OUTPOST_TOKEN_HEADER: "abc",
|
ClientIPMiddleware.outpost_token_header: "abc",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertEqual(get_client_ip(request), "127.0.0.1")
|
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
|
||||||
# Invalid, user doesn't have permissions
|
# Invalid, user doesn't have permissions
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
**{
|
**{
|
||||||
OUTPOST_REMOTE_IP_HEADER: "1.2.3.4",
|
ClientIPMiddleware.outpost_remote_ip_header: "1.2.3.4",
|
||||||
OUTPOST_TOKEN_HEADER: token.key,
|
ClientIPMiddleware.outpost_token_header: token.key,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertEqual(get_client_ip(request), "127.0.0.1")
|
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "127.0.0.1")
|
||||||
# Valid
|
# Valid
|
||||||
self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
|
self.user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
|
||||||
self.user.save()
|
self.user.save()
|
||||||
request = self.factory.get(
|
request = self.factory.get(
|
||||||
"/",
|
"/",
|
||||||
**{
|
**{
|
||||||
OUTPOST_REMOTE_IP_HEADER: "1.2.3.4",
|
ClientIPMiddleware.outpost_remote_ip_header: "1.2.3.4",
|
||||||
OUTPOST_TOKEN_HEADER: token.key,
|
ClientIPMiddleware.outpost_token_header: token.key,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertEqual(get_client_ip(request), "1.2.3.4")
|
self.assertEqual(ClientIPMiddleware.get_client_ip(request), "1.2.3.4")
|
||||||
|
|
|
@ -1,89 +1,39 @@
|
||||||
"""http helpers"""
|
"""http helpers"""
|
||||||
from typing import Any, Optional
|
from uuid import uuid4
|
||||||
|
|
||||||
from django.http import HttpRequest
|
from django.conf import settings
|
||||||
from requests.sessions import Session
|
from requests.sessions import PreparedRequest, Session
|
||||||
from sentry_sdk.hub import Hub
|
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik import get_full_version
|
from authentik import get_full_version
|
||||||
|
|
||||||
OUTPOST_REMOTE_IP_HEADER = "HTTP_X_AUTHENTIK_REMOTE_IP"
|
|
||||||
OUTPOST_TOKEN_HEADER = "HTTP_X_AUTHENTIK_OUTPOST_TOKEN" # nosec
|
|
||||||
DEFAULT_IP = "255.255.255.255"
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
|
||||||
def _get_client_ip_from_meta(meta: dict[str, Any]) -> str:
|
|
||||||
"""Attempt to get the client's IP by checking common HTTP Headers.
|
|
||||||
Returns none if no IP Could be found
|
|
||||||
|
|
||||||
No additional validation is done here as requests are expected to only arrive here
|
|
||||||
via the go proxy, which deals with validating these headers for us"""
|
|
||||||
headers = (
|
|
||||||
"HTTP_X_FORWARDED_FOR",
|
|
||||||
"REMOTE_ADDR",
|
|
||||||
)
|
|
||||||
for _header in headers:
|
|
||||||
if _header in meta:
|
|
||||||
ips: list[str] = meta.get(_header).split(",")
|
|
||||||
return ips[0].strip()
|
|
||||||
return DEFAULT_IP
|
|
||||||
|
|
||||||
|
|
||||||
def _get_outpost_override_ip(request: HttpRequest) -> Optional[str]:
|
|
||||||
"""Get the actual remote IP when set by an outpost. Only
|
|
||||||
allowed when the request is authenticated, by an outpost internal service account"""
|
|
||||||
from authentik.core.models import Token, TokenIntents, UserTypes
|
|
||||||
|
|
||||||
if OUTPOST_REMOTE_IP_HEADER not in request.META or OUTPOST_TOKEN_HEADER not in request.META:
|
|
||||||
return None
|
|
||||||
fake_ip = request.META[OUTPOST_REMOTE_IP_HEADER]
|
|
||||||
token = (
|
|
||||||
Token.filter_not_expired(
|
|
||||||
key=request.META.get(OUTPOST_TOKEN_HEADER), intent=TokenIntents.INTENT_API
|
|
||||||
)
|
|
||||||
.select_related("user")
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not token:
|
|
||||||
LOGGER.warning("Attempted remote-ip override without token", fake_ip=fake_ip)
|
|
||||||
return None
|
|
||||||
user = token.user
|
|
||||||
if user.type != UserTypes.INTERNAL_SERVICE_ACCOUNT:
|
|
||||||
LOGGER.warning(
|
|
||||||
"Remote-IP override: user doesn't have permission",
|
|
||||||
user=user,
|
|
||||||
fake_ip=fake_ip,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
# Update sentry scope to include correct IP
|
|
||||||
user = Hub.current.scope._user
|
|
||||||
if not user:
|
|
||||||
user = {}
|
|
||||||
user["ip_address"] = fake_ip
|
|
||||||
Hub.current.scope.set_user(user)
|
|
||||||
return fake_ip
|
|
||||||
|
|
||||||
|
|
||||||
def get_client_ip(request: Optional[HttpRequest]) -> str:
|
|
||||||
"""Attempt to get the client's IP by checking common HTTP Headers.
|
|
||||||
Returns none if no IP Could be found"""
|
|
||||||
if not request:
|
|
||||||
return DEFAULT_IP
|
|
||||||
override = _get_outpost_override_ip(request)
|
|
||||||
if override:
|
|
||||||
return override
|
|
||||||
return _get_client_ip_from_meta(request.META)
|
|
||||||
|
|
||||||
|
|
||||||
def authentik_user_agent() -> str:
|
def authentik_user_agent() -> str:
|
||||||
"""Get a common user agent"""
|
"""Get a common user agent"""
|
||||||
return f"authentik@{get_full_version()}"
|
return f"authentik@{get_full_version()}"
|
||||||
|
|
||||||
|
|
||||||
|
class DebugSession(Session):
|
||||||
|
"""requests session which logs http requests and responses"""
|
||||||
|
|
||||||
|
def send(self, req: PreparedRequest, *args, **kwargs):
|
||||||
|
request_id = str(uuid4())
|
||||||
|
LOGGER.debug("HTTP request sent", uid=request_id, path=req.path_url, headers=req.headers)
|
||||||
|
resp = super().send(req, *args, **kwargs)
|
||||||
|
LOGGER.debug(
|
||||||
|
"HTTP response received",
|
||||||
|
uid=request_id,
|
||||||
|
status=resp.status_code,
|
||||||
|
body=resp.text,
|
||||||
|
headers=resp.headers,
|
||||||
|
)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
def get_http_session() -> Session:
|
def get_http_session() -> Session:
|
||||||
"""Get a requests session with common headers"""
|
"""Get a requests session with common headers"""
|
||||||
session = Session()
|
session = DebugSession() if settings.DEBUG else Session()
|
||||||
session.headers["User-Agent"] = authentik_user_agent()
|
session.headers["User-Agent"] = authentik_user_agent()
|
||||||
return session
|
return session
|
||||||
|
|
|
@ -9,16 +9,17 @@ from rest_framework.fields import BooleanField, CharField, DateTimeField
|
||||||
from rest_framework.relations import PrimaryKeyRelatedField
|
from rest_framework.relations import PrimaryKeyRelatedField
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.serializers import JSONField, ModelSerializer, ValidationError
|
from rest_framework.serializers import ModelSerializer, ValidationError
|
||||||
from rest_framework.viewsets import ModelViewSet
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
|
||||||
from authentik import get_build_hash
|
from authentik import get_build_hash
|
||||||
from authentik.core.api.providers import ProviderSerializer
|
from authentik.core.api.providers import ProviderSerializer
|
||||||
from authentik.core.api.used_by import UsedByMixin
|
from authentik.core.api.used_by import UsedByMixin
|
||||||
from authentik.core.api.utils import PassiveSerializer, is_dict
|
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||||
from authentik.core.models import Provider
|
from authentik.core.models import Provider
|
||||||
|
from authentik.enterprise.providers.rac.models import RACProvider
|
||||||
from authentik.outposts.api.service_connections import ServiceConnectionSerializer
|
from authentik.outposts.api.service_connections import ServiceConnectionSerializer
|
||||||
from authentik.outposts.apps import MANAGED_OUTPOST
|
from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME
|
||||||
from authentik.outposts.models import (
|
from authentik.outposts.models import (
|
||||||
Outpost,
|
Outpost,
|
||||||
OutpostConfig,
|
OutpostConfig,
|
||||||
|
@ -34,7 +35,7 @@ from authentik.providers.radius.models import RadiusProvider
|
||||||
class OutpostSerializer(ModelSerializer):
|
class OutpostSerializer(ModelSerializer):
|
||||||
"""Outpost Serializer"""
|
"""Outpost Serializer"""
|
||||||
|
|
||||||
config = JSONField(validators=[is_dict], source="_config")
|
config = JSONDictField(source="_config")
|
||||||
# Need to set allow_empty=True for the embedded outpost with no providers
|
# Need to set allow_empty=True for the embedded outpost with no providers
|
||||||
# is checked for other providers in the API Viewset
|
# is checked for other providers in the API Viewset
|
||||||
providers = PrimaryKeyRelatedField(
|
providers = PrimaryKeyRelatedField(
|
||||||
|
@ -47,12 +48,23 @@ class OutpostSerializer(ModelSerializer):
|
||||||
source="service_connection", read_only=True
|
source="service_connection", read_only=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def validate_name(self, name: str) -> str:
|
||||||
|
"""Validate name (especially for embedded outpost)"""
|
||||||
|
if not self.instance:
|
||||||
|
return name
|
||||||
|
if self.instance.managed == MANAGED_OUTPOST and name != MANAGED_OUTPOST_NAME:
|
||||||
|
raise ValidationError("Embedded outpost's name cannot be changed")
|
||||||
|
if self.instance.name == MANAGED_OUTPOST_NAME:
|
||||||
|
self.instance.managed = MANAGED_OUTPOST
|
||||||
|
return name
|
||||||
|
|
||||||
def validate_providers(self, providers: list[Provider]) -> list[Provider]:
|
def validate_providers(self, providers: list[Provider]) -> list[Provider]:
|
||||||
"""Check that all providers match the type of the outpost"""
|
"""Check that all providers match the type of the outpost"""
|
||||||
type_map = {
|
type_map = {
|
||||||
OutpostType.LDAP: LDAPProvider,
|
OutpostType.LDAP: LDAPProvider,
|
||||||
OutpostType.PROXY: ProxyProvider,
|
OutpostType.PROXY: ProxyProvider,
|
||||||
OutpostType.RADIUS: RadiusProvider,
|
OutpostType.RADIUS: RadiusProvider,
|
||||||
|
OutpostType.RAC: RACProvider,
|
||||||
None: Provider,
|
None: Provider,
|
||||||
}
|
}
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
|
@ -95,7 +107,7 @@ class OutpostSerializer(ModelSerializer):
|
||||||
class OutpostDefaultConfigSerializer(PassiveSerializer):
|
class OutpostDefaultConfigSerializer(PassiveSerializer):
|
||||||
"""Global default outpost config"""
|
"""Global default outpost config"""
|
||||||
|
|
||||||
config = JSONField(read_only=True)
|
config = JSONDictField(read_only=True)
|
||||||
|
|
||||||
|
|
||||||
class OutpostHealthSerializer(PassiveSerializer):
|
class OutpostHealthSerializer(PassiveSerializer):
|
||||||
|
|
|
@ -15,6 +15,7 @@ GAUGE_OUTPOSTS_LAST_UPDATE = Gauge(
|
||||||
["outpost", "uid", "version"],
|
["outpost", "uid", "version"],
|
||||||
)
|
)
|
||||||
MANAGED_OUTPOST = "goauthentik.io/outposts/embedded"
|
MANAGED_OUTPOST = "goauthentik.io/outposts/embedded"
|
||||||
|
MANAGED_OUTPOST_NAME = "authentik Embedded Outpost"
|
||||||
|
|
||||||
|
|
||||||
class AuthentikOutpostConfig(ManagedAppConfig):
|
class AuthentikOutpostConfig(ManagedAppConfig):
|
||||||
|
@ -35,14 +36,17 @@ class AuthentikOutpostConfig(ManagedAppConfig):
|
||||||
DockerServiceConnection,
|
DockerServiceConnection,
|
||||||
KubernetesServiceConnection,
|
KubernetesServiceConnection,
|
||||||
Outpost,
|
Outpost,
|
||||||
OutpostConfig,
|
|
||||||
OutpostType,
|
OutpostType,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if outpost := Outpost.objects.filter(name=MANAGED_OUTPOST_NAME, managed="").first():
|
||||||
|
outpost.managed = MANAGED_OUTPOST
|
||||||
|
outpost.save()
|
||||||
|
return
|
||||||
outpost, updated = Outpost.objects.update_or_create(
|
outpost, updated = Outpost.objects.update_or_create(
|
||||||
defaults={
|
defaults={
|
||||||
"name": "authentik Embedded Outpost",
|
|
||||||
"type": OutpostType.PROXY,
|
"type": OutpostType.PROXY,
|
||||||
|
"name": MANAGED_OUTPOST_NAME,
|
||||||
},
|
},
|
||||||
managed=MANAGED_OUTPOST,
|
managed=MANAGED_OUTPOST,
|
||||||
)
|
)
|
||||||
|
@ -51,10 +55,4 @@ class AuthentikOutpostConfig(ManagedAppConfig):
|
||||||
outpost.service_connection = KubernetesServiceConnection.objects.first()
|
outpost.service_connection = KubernetesServiceConnection.objects.first()
|
||||||
elif DockerServiceConnection.objects.exists():
|
elif DockerServiceConnection.objects.exists():
|
||||||
outpost.service_connection = DockerServiceConnection.objects.first()
|
outpost.service_connection = DockerServiceConnection.objects.first()
|
||||||
outpost.config = OutpostConfig(
|
|
||||||
kubernetes_disabled_components=[
|
|
||||||
"deployment",
|
|
||||||
"secret",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
outpost.save()
|
outpost.save()
|
||||||
|
|
|
@ -6,16 +6,18 @@ from typing import Any, Optional
|
||||||
|
|
||||||
from asgiref.sync import async_to_sync
|
from asgiref.sync import async_to_sync
|
||||||
from channels.exceptions import DenyConnection
|
from channels.exceptions import DenyConnection
|
||||||
|
from channels.generic.websocket import JsonWebsocketConsumer
|
||||||
from dacite.core import from_dict
|
from dacite.core import from_dict
|
||||||
from dacite.data import Data
|
from dacite.data import Data
|
||||||
|
from django.http.request import QueryDict
|
||||||
from guardian.shortcuts import get_objects_for_user
|
from guardian.shortcuts import get_objects_for_user
|
||||||
from structlog.stdlib import BoundLogger, get_logger
|
from structlog.stdlib import BoundLogger, get_logger
|
||||||
|
|
||||||
from authentik.core.channels import AuthJsonConsumer
|
|
||||||
from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
|
from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
|
||||||
from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState
|
from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState
|
||||||
|
|
||||||
OUTPOST_GROUP = "group_outpost_%(outpost_pk)s"
|
OUTPOST_GROUP = "group_outpost_%(outpost_pk)s"
|
||||||
|
OUTPOST_GROUP_INSTANCE = "group_outpost_%(outpost_pk)s_%(instance)s"
|
||||||
|
|
||||||
|
|
||||||
class WebsocketMessageInstruction(IntEnum):
|
class WebsocketMessageInstruction(IntEnum):
|
||||||
|
@ -42,25 +44,23 @@ class WebsocketMessage:
|
||||||
args: dict[str, Any] = field(default_factory=dict)
|
args: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class OutpostConsumer(AuthJsonConsumer):
|
class OutpostConsumer(JsonWebsocketConsumer):
|
||||||
"""Handler for Outposts that connect over websockets for health checks and live updates"""
|
"""Handler for Outposts that connect over websockets for health checks and live updates"""
|
||||||
|
|
||||||
outpost: Optional[Outpost] = None
|
outpost: Optional[Outpost] = None
|
||||||
logger: BoundLogger
|
logger: BoundLogger
|
||||||
|
|
||||||
last_uid: Optional[str] = None
|
instance_uid: Optional[str] = None
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.logger = get_logger()
|
self.logger = get_logger()
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
super().connect()
|
|
||||||
uuid = self.scope["url_route"]["kwargs"]["pk"]
|
uuid = self.scope["url_route"]["kwargs"]["pk"]
|
||||||
|
user = self.scope["user"]
|
||||||
outpost = (
|
outpost = (
|
||||||
get_objects_for_user(self.user, "authentik_outposts.view_outpost")
|
get_objects_for_user(user, "authentik_outposts.view_outpost").filter(pk=uuid).first()
|
||||||
.filter(pk=uuid)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
if not outpost:
|
if not outpost:
|
||||||
raise DenyConnection()
|
raise DenyConnection()
|
||||||
|
@ -71,13 +71,19 @@ class OutpostConsumer(AuthJsonConsumer):
|
||||||
self.logger.warning("runtime error during accept", exc=exc)
|
self.logger.warning("runtime error during accept", exc=exc)
|
||||||
raise DenyConnection()
|
raise DenyConnection()
|
||||||
self.outpost = outpost
|
self.outpost = outpost
|
||||||
self.last_uid = self.channel_name
|
query = QueryDict(self.scope["query_string"].decode())
|
||||||
|
self.instance_uid = query.get("instance_uuid", self.channel_name)
|
||||||
async_to_sync(self.channel_layer.group_add)(
|
async_to_sync(self.channel_layer.group_add)(
|
||||||
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
|
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
|
||||||
)
|
)
|
||||||
|
async_to_sync(self.channel_layer.group_add)(
|
||||||
|
OUTPOST_GROUP_INSTANCE
|
||||||
|
% {"outpost_pk": str(self.outpost.pk), "instance": self.instance_uid},
|
||||||
|
self.channel_name,
|
||||||
|
)
|
||||||
GAUGE_OUTPOSTS_CONNECTED.labels(
|
GAUGE_OUTPOSTS_CONNECTED.labels(
|
||||||
outpost=self.outpost.name,
|
outpost=self.outpost.name,
|
||||||
uid=self.last_uid,
|
uid=self.instance_uid,
|
||||||
expected=self.outpost.config.kubernetes_replicas,
|
expected=self.outpost.config.kubernetes_replicas,
|
||||||
).inc()
|
).inc()
|
||||||
|
|
||||||
|
@ -86,34 +92,37 @@ class OutpostConsumer(AuthJsonConsumer):
|
||||||
async_to_sync(self.channel_layer.group_discard)(
|
async_to_sync(self.channel_layer.group_discard)(
|
||||||
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
|
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
|
||||||
)
|
)
|
||||||
if self.outpost and self.last_uid:
|
if self.instance_uid:
|
||||||
|
async_to_sync(self.channel_layer.group_discard)(
|
||||||
|
OUTPOST_GROUP_INSTANCE
|
||||||
|
% {"outpost_pk": str(self.outpost.pk), "instance": self.instance_uid},
|
||||||
|
self.channel_name,
|
||||||
|
)
|
||||||
|
if self.outpost and self.instance_uid:
|
||||||
GAUGE_OUTPOSTS_CONNECTED.labels(
|
GAUGE_OUTPOSTS_CONNECTED.labels(
|
||||||
outpost=self.outpost.name,
|
outpost=self.outpost.name,
|
||||||
uid=self.last_uid,
|
uid=self.instance_uid,
|
||||||
expected=self.outpost.config.kubernetes_replicas,
|
expected=self.outpost.config.kubernetes_replicas,
|
||||||
).dec()
|
).dec()
|
||||||
|
|
||||||
def receive_json(self, content: Data, **kwargs):
|
def receive_json(self, content: Data, **kwargs):
|
||||||
msg = from_dict(WebsocketMessage, content)
|
msg = from_dict(WebsocketMessage, content)
|
||||||
uid = msg.args.get("uuid", self.channel_name)
|
|
||||||
self.last_uid = uid
|
|
||||||
|
|
||||||
if not self.outpost:
|
if not self.outpost:
|
||||||
raise DenyConnection()
|
raise DenyConnection()
|
||||||
|
|
||||||
state = OutpostState.for_instance_uid(self.outpost, uid)
|
state = OutpostState.for_instance_uid(self.outpost, self.instance_uid)
|
||||||
state.last_seen = datetime.now()
|
state.last_seen = datetime.now()
|
||||||
state.hostname = msg.args.pop("hostname", "")
|
state.hostname = msg.args.pop("hostname", "")
|
||||||
|
|
||||||
if msg.instruction == WebsocketMessageInstruction.HELLO:
|
if msg.instruction == WebsocketMessageInstruction.HELLO:
|
||||||
state.version = msg.args.pop("version", None)
|
state.version = msg.args.pop("version", None)
|
||||||
state.build_hash = msg.args.pop("buildHash", "")
|
state.build_hash = msg.args.pop("buildHash", "")
|
||||||
state.args = msg.args
|
state.args.update(msg.args)
|
||||||
elif msg.instruction == WebsocketMessageInstruction.ACK:
|
elif msg.instruction == WebsocketMessageInstruction.ACK:
|
||||||
return
|
return
|
||||||
GAUGE_OUTPOSTS_LAST_UPDATE.labels(
|
GAUGE_OUTPOSTS_LAST_UPDATE.labels(
|
||||||
outpost=self.outpost.name,
|
outpost=self.outpost.name,
|
||||||
uid=self.last_uid or "",
|
uid=self.instance_uid or "",
|
||||||
version=state.version or "",
|
version=state.version or "",
|
||||||
).set_to_current_time()
|
).set_to_current_time()
|
||||||
state.save(timeout=OUTPOST_HELLO_INTERVAL * 1.5)
|
state.save(timeout=OUTPOST_HELLO_INTERVAL * 1.5)
|
||||||
|
|
|
@ -43,6 +43,10 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]):
|
||||||
self.api = AppsV1Api(controller.client)
|
self.api = AppsV1Api(controller.client)
|
||||||
self.outpost = self.controller.outpost
|
self.outpost = self.controller.outpost
|
||||||
|
|
||||||
|
@property
|
||||||
|
def noop(self) -> bool:
|
||||||
|
return self.is_embedded
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def reconciler_name() -> str:
|
def reconciler_name() -> str:
|
||||||
return "deployment"
|
return "deployment"
|
||||||
|
|
|
@ -24,6 +24,10 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]):
|
||||||
super().__init__(controller)
|
super().__init__(controller)
|
||||||
self.api = CoreV1Api(controller.client)
|
self.api = CoreV1Api(controller.client)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def noop(self) -> bool:
|
||||||
|
return self.is_embedded
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def reconciler_name() -> str:
|
def reconciler_name() -> str:
|
||||||
return "secret"
|
return "secret"
|
||||||
|
|
|
@ -77,7 +77,10 @@ class PrometheusServiceMonitorReconciler(KubernetesObjectReconciler[PrometheusSe
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def noop(self) -> bool:
|
def noop(self) -> bool:
|
||||||
return (not self._crd_exists()) or (self.is_embedded)
|
if not self._crd_exists():
|
||||||
|
self.logger.debug("CRD doesn't exist")
|
||||||
|
return True
|
||||||
|
return self.is_embedded
|
||||||
|
|
||||||
def _crd_exists(self) -> bool:
|
def _crd_exists(self) -> bool:
|
||||||
"""Check if the Prometheus ServiceMonitor exists"""
|
"""Check if the Prometheus ServiceMonitor exists"""
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
"""k8s utils"""
|
"""k8s utils"""
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from kubernetes.client.models.v1_container_port import V1ContainerPort
|
from kubernetes.client.models.v1_container_port import V1ContainerPort
|
||||||
from kubernetes.client.models.v1_service_port import V1ServicePort
|
from kubernetes.client.models.v1_service_port import V1ServicePort
|
||||||
|
@ -37,9 +38,12 @@ def compare_port(
|
||||||
|
|
||||||
|
|
||||||
def compare_ports(
|
def compare_ports(
|
||||||
current: list[V1ServicePort | V1ContainerPort], reference: list[V1ServicePort | V1ContainerPort]
|
current: Optional[list[V1ServicePort | V1ContainerPort]],
|
||||||
|
reference: Optional[list[V1ServicePort | V1ContainerPort]],
|
||||||
):
|
):
|
||||||
"""Compare ports of a list"""
|
"""Compare ports of a list"""
|
||||||
|
if not current or not reference:
|
||||||
|
raise NeedsRecreate()
|
||||||
if len(current) != len(reference):
|
if len(current) != len(reference):
|
||||||
raise NeedsRecreate()
|
raise NeedsRecreate()
|
||||||
for port in reference:
|
for port in reference:
|
||||||
|
|
|
@ -81,7 +81,10 @@ class KubernetesController(BaseController):
|
||||||
def up(self):
|
def up(self):
|
||||||
try:
|
try:
|
||||||
for reconcile_key in self.reconcile_order:
|
for reconcile_key in self.reconcile_order:
|
||||||
reconciler = self.reconcilers[reconcile_key](self)
|
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||||
|
if not reconciler_cls:
|
||||||
|
continue
|
||||||
|
reconciler = reconciler_cls(self)
|
||||||
reconciler.up()
|
reconciler.up()
|
||||||
|
|
||||||
except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc:
|
except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc:
|
||||||
|
@ -95,7 +98,10 @@ class KubernetesController(BaseController):
|
||||||
all_logs += [f"{reconcile_key.title()}: Disabled"]
|
all_logs += [f"{reconcile_key.title()}: Disabled"]
|
||||||
continue
|
continue
|
||||||
with capture_logs() as logs:
|
with capture_logs() as logs:
|
||||||
reconciler = self.reconcilers[reconcile_key](self)
|
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||||
|
if not reconciler_cls:
|
||||||
|
continue
|
||||||
|
reconciler = reconciler_cls(self)
|
||||||
reconciler.up()
|
reconciler.up()
|
||||||
all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs]
|
all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs]
|
||||||
return all_logs
|
return all_logs
|
||||||
|
@ -105,7 +111,10 @@ class KubernetesController(BaseController):
|
||||||
def down(self):
|
def down(self):
|
||||||
try:
|
try:
|
||||||
for reconcile_key in self.reconcile_order:
|
for reconcile_key in self.reconcile_order:
|
||||||
reconciler = self.reconcilers[reconcile_key](self)
|
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||||
|
if not reconciler_cls:
|
||||||
|
continue
|
||||||
|
reconciler = reconciler_cls(self)
|
||||||
self.logger.debug("Tearing down object", name=reconcile_key)
|
self.logger.debug("Tearing down object", name=reconcile_key)
|
||||||
reconciler.down()
|
reconciler.down()
|
||||||
|
|
||||||
|
@ -120,7 +129,10 @@ class KubernetesController(BaseController):
|
||||||
all_logs += [f"{reconcile_key.title()}: Disabled"]
|
all_logs += [f"{reconcile_key.title()}: Disabled"]
|
||||||
continue
|
continue
|
||||||
with capture_logs() as logs:
|
with capture_logs() as logs:
|
||||||
reconciler = self.reconcilers[reconcile_key](self)
|
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||||
|
if not reconciler_cls:
|
||||||
|
continue
|
||||||
|
reconciler = reconciler_cls(self)
|
||||||
reconciler.down()
|
reconciler.down()
|
||||||
all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs]
|
all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs]
|
||||||
return all_logs
|
return all_logs
|
||||||
|
@ -130,7 +142,10 @@ class KubernetesController(BaseController):
|
||||||
def get_static_deployment(self) -> str:
|
def get_static_deployment(self) -> str:
|
||||||
documents = []
|
documents = []
|
||||||
for reconcile_key in self.reconcile_order:
|
for reconcile_key in self.reconcile_order:
|
||||||
reconciler = self.reconcilers[reconcile_key](self)
|
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||||
|
if not reconciler_cls:
|
||||||
|
continue
|
||||||
|
reconciler = reconciler_cls(self)
|
||||||
if reconciler.noop:
|
if reconciler.noop:
|
||||||
continue
|
continue
|
||||||
documents.append(reconciler.get_reference_object().to_dict())
|
documents.append(reconciler.get_reference_object().to_dict())
|
||||||
|
|
25
authentik/outposts/migrations/0021_alter_outpost_type.py
Normal file
25
authentik/outposts/migrations/0021_alter_outpost_type.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
# Generated by Django 4.2.6 on 2023-10-14 19:23
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("authentik_outposts", "0020_alter_outpost_type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="outpost",
|
||||||
|
name="type",
|
||||||
|
field=models.TextField(
|
||||||
|
choices=[
|
||||||
|
("proxy", "Proxy"),
|
||||||
|
("ldap", "Ldap"),
|
||||||
|
("radius", "Radius"),
|
||||||
|
("rac", "Rac"),
|
||||||
|
],
|
||||||
|
default="proxy",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
|
@ -90,11 +90,12 @@ class OutpostModel(Model):
|
||||||
|
|
||||||
|
|
||||||
class OutpostType(models.TextChoices):
|
class OutpostType(models.TextChoices):
|
||||||
"""Outpost types, currently only the reverse proxy is available"""
|
"""Outpost types"""
|
||||||
|
|
||||||
PROXY = "proxy"
|
PROXY = "proxy"
|
||||||
LDAP = "ldap"
|
LDAP = "ldap"
|
||||||
RADIUS = "radius"
|
RADIUS = "radius"
|
||||||
|
RAC = "rac"
|
||||||
|
|
||||||
|
|
||||||
def default_outpost_config(host: Optional[str] = None):
|
def default_outpost_config(host: Optional[str] = None):
|
||||||
|
@ -459,7 +460,7 @@ class OutpostState:
|
||||||
def for_instance_uid(outpost: Outpost, uid: str) -> "OutpostState":
|
def for_instance_uid(outpost: Outpost, uid: str) -> "OutpostState":
|
||||||
"""Get state for a single instance"""
|
"""Get state for a single instance"""
|
||||||
key = f"{outpost.state_cache_prefix}/{uid}"
|
key = f"{outpost.state_cache_prefix}/{uid}"
|
||||||
default_data = {"uid": uid, "channel_ids": []}
|
default_data = {"uid": uid}
|
||||||
data = cache.get(key, default_data)
|
data = cache.get(key, default_data)
|
||||||
if isinstance(data, str):
|
if isinstance(data, str):
|
||||||
cache.delete(key)
|
cache.delete(key)
|
||||||
|
|
|
@ -17,6 +17,8 @@ from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
|
|
||||||
|
from authentik.enterprise.providers.rac.controllers.docker import RACDockerController
|
||||||
|
from authentik.enterprise.providers.rac.controllers.kubernetes import RACKubernetesController
|
||||||
from authentik.events.monitored_tasks import (
|
from authentik.events.monitored_tasks import (
|
||||||
MonitoredTask,
|
MonitoredTask,
|
||||||
TaskResult,
|
TaskResult,
|
||||||
|
@ -71,6 +73,11 @@ def controller_for_outpost(outpost: Outpost) -> Optional[type[BaseController]]:
|
||||||
return RadiusDockerController
|
return RadiusDockerController
|
||||||
if isinstance(service_connection, KubernetesServiceConnection):
|
if isinstance(service_connection, KubernetesServiceConnection):
|
||||||
return RadiusKubernetesController
|
return RadiusKubernetesController
|
||||||
|
if outpost.type == OutpostType.RAC:
|
||||||
|
if isinstance(service_connection, DockerServiceConnection):
|
||||||
|
return RACDockerController
|
||||||
|
if isinstance(service_connection, KubernetesServiceConnection):
|
||||||
|
return RACKubernetesController
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,11 +2,13 @@
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from rest_framework.test import APITestCase
|
from rest_framework.test import APITestCase
|
||||||
|
|
||||||
|
from authentik.blueprints.tests import reconcile_app
|
||||||
from authentik.core.models import PropertyMapping
|
from authentik.core.models import PropertyMapping
|
||||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.outposts.api.outposts import OutpostSerializer
|
from authentik.outposts.api.outposts import OutpostSerializer
|
||||||
from authentik.outposts.models import OutpostType, default_outpost_config
|
from authentik.outposts.apps import MANAGED_OUTPOST
|
||||||
|
from authentik.outposts.models import Outpost, OutpostType, default_outpost_config
|
||||||
from authentik.providers.ldap.models import LDAPProvider
|
from authentik.providers.ldap.models import LDAPProvider
|
||||||
from authentik.providers.proxy.models import ProxyProvider
|
from authentik.providers.proxy.models import ProxyProvider
|
||||||
|
|
||||||
|
@ -22,7 +24,36 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
|
||||||
self.user = create_test_admin_user()
|
self.user = create_test_admin_user()
|
||||||
self.client.force_login(self.user)
|
self.client.force_login(self.user)
|
||||||
|
|
||||||
def test_outpost_validaton(self):
|
@reconcile_app("authentik_outposts")
|
||||||
|
def test_managed_name_change(self):
|
||||||
|
"""Test name change for embedded outpost"""
|
||||||
|
embedded_outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
|
||||||
|
self.assertIsNotNone(embedded_outpost)
|
||||||
|
response = self.client.patch(
|
||||||
|
reverse("authentik_api:outpost-detail", kwargs={"pk": embedded_outpost.pk}),
|
||||||
|
{"name": "foo"},
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 400)
|
||||||
|
self.assertJSONEqual(
|
||||||
|
response.content, {"name": ["Embedded outpost's name cannot be changed"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
@reconcile_app("authentik_outposts")
|
||||||
|
def test_managed_without_managed(self):
|
||||||
|
"""Test name change for embedded outpost"""
|
||||||
|
embedded_outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
|
||||||
|
self.assertIsNotNone(embedded_outpost)
|
||||||
|
embedded_outpost.managed = ""
|
||||||
|
embedded_outpost.save()
|
||||||
|
response = self.client.patch(
|
||||||
|
reverse("authentik_api:outpost-detail", kwargs={"pk": embedded_outpost.pk}),
|
||||||
|
{"name": "foo"},
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
embedded_outpost.refresh_from_db()
|
||||||
|
self.assertEqual(embedded_outpost.managed, MANAGED_OUTPOST)
|
||||||
|
|
||||||
|
def test_outpost_validation(self):
|
||||||
"""Test Outpost validation"""
|
"""Test Outpost validation"""
|
||||||
valid = OutpostSerializer(
|
valid = OutpostSerializer(
|
||||||
data={
|
data={
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Websocket tests"""
|
"""Websocket tests"""
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
|
|
||||||
|
from channels.exceptions import DenyConnection
|
||||||
from channels.routing import URLRouter
|
from channels.routing import URLRouter
|
||||||
from channels.testing import WebsocketCommunicator
|
from channels.testing import WebsocketCommunicator
|
||||||
from django.test import TransactionTestCase
|
from django.test import TransactionTestCase
|
||||||
|
@ -35,8 +36,9 @@ class TestOutpostWS(TransactionTestCase):
|
||||||
communicator = WebsocketCommunicator(
|
communicator = WebsocketCommunicator(
|
||||||
URLRouter(websocket.websocket_urlpatterns), f"/ws/outpost/{self.outpost.pk}/"
|
URLRouter(websocket.websocket_urlpatterns), f"/ws/outpost/{self.outpost.pk}/"
|
||||||
)
|
)
|
||||||
connected, _ = await communicator.connect()
|
with self.assertRaises(DenyConnection):
|
||||||
self.assertFalse(connected)
|
connected, _ = await communicator.connect()
|
||||||
|
self.assertFalse(connected)
|
||||||
|
|
||||||
async def test_auth_valid(self):
|
async def test_auth_valid(self):
|
||||||
"""Test auth with token"""
|
"""Test auth with token"""
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Outpost Websocket URLS"""
|
"""Outpost Websocket URLS"""
|
||||||
from django.urls import path
|
from django.urls import path
|
||||||
|
|
||||||
|
from authentik.core.channels import TokenOutpostMiddleware
|
||||||
from authentik.outposts.api.outposts import OutpostViewSet
|
from authentik.outposts.api.outposts import OutpostViewSet
|
||||||
from authentik.outposts.api.service_connections import (
|
from authentik.outposts.api.service_connections import (
|
||||||
DockerServiceConnectionViewSet,
|
DockerServiceConnectionViewSet,
|
||||||
|
@ -11,7 +12,10 @@ from authentik.outposts.consumer import OutpostConsumer
|
||||||
from authentik.root.middleware import ChannelsLoggingMiddleware
|
from authentik.root.middleware import ChannelsLoggingMiddleware
|
||||||
|
|
||||||
websocket_urlpatterns = [
|
websocket_urlpatterns = [
|
||||||
path("ws/outpost/<uuid:pk>/", ChannelsLoggingMiddleware(OutpostConsumer.as_asgi())),
|
path(
|
||||||
|
"ws/outpost/<uuid:pk>/",
|
||||||
|
ChannelsLoggingMiddleware(TokenOutpostMiddleware(OutpostConsumer.as_asgi())),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
api_urlpatterns = [
|
api_urlpatterns = [
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
"""Serializer for policy execution"""
|
"""Serializer for policy execution"""
|
||||||
from rest_framework.fields import BooleanField, CharField, DictField, JSONField, ListField
|
from rest_framework.fields import BooleanField, CharField, DictField, ListField
|
||||||
from rest_framework.relations import PrimaryKeyRelatedField
|
from rest_framework.relations import PrimaryKeyRelatedField
|
||||||
|
|
||||||
from authentik.core.api.utils import PassiveSerializer, is_dict
|
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ class PolicyTestSerializer(PassiveSerializer):
|
||||||
"""Test policy execution for a user with context"""
|
"""Test policy execution for a user with context"""
|
||||||
|
|
||||||
user = PrimaryKeyRelatedField(queryset=User.objects.all())
|
user = PrimaryKeyRelatedField(queryset=User.objects.all())
|
||||||
context = JSONField(required=False, validators=[is_dict])
|
context = JSONDictField(required=False)
|
||||||
|
|
||||||
|
|
||||||
class PolicyTestResultSerializer(PassiveSerializer):
|
class PolicyTestResultSerializer(PassiveSerializer):
|
||||||
|
|
|
@ -7,9 +7,9 @@ from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.flows.planner import PLAN_CONTEXT_SSO
|
from authentik.flows.planner import PLAN_CONTEXT_SSO
|
||||||
from authentik.lib.expression.evaluator import BaseEvaluator
|
from authentik.lib.expression.evaluator import BaseEvaluator
|
||||||
from authentik.lib.utils.http import get_client_ip
|
|
||||||
from authentik.policies.exceptions import PolicyException
|
from authentik.policies.exceptions import PolicyException
|
||||||
from authentik.policies.types import PolicyRequest, PolicyResult
|
from authentik.policies.types import PolicyRequest, PolicyResult
|
||||||
|
from authentik.root.middleware import ClientIPMiddleware
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -49,7 +49,7 @@ class PolicyEvaluator(BaseEvaluator):
|
||||||
"""Update context based on http request"""
|
"""Update context based on http request"""
|
||||||
# update website/docs/expressions/_objects.md
|
# update website/docs/expressions/_objects.md
|
||||||
# update website/docs/expressions/_functions.md
|
# update website/docs/expressions/_functions.md
|
||||||
self._context["ak_client_ip"] = ip_address(get_client_ip(request))
|
self._context["ak_client_ip"] = ip_address(ClientIPMiddleware.get_client_ip(request))
|
||||||
self._context["http_request"] = request
|
self._context["http_request"] = request
|
||||||
|
|
||||||
def handle_error(self, exc: Exception, expression_source: str):
|
def handle_error(self, exc: Exception, expression_source: str):
|
||||||
|
|
|
@ -47,6 +47,7 @@ class ReputationSerializer(ModelSerializer):
|
||||||
"identifier",
|
"identifier",
|
||||||
"ip",
|
"ip",
|
||||||
"ip_geo_data",
|
"ip_geo_data",
|
||||||
|
"ip_asn_data",
|
||||||
"score",
|
"score",
|
||||||
"updated",
|
"updated",
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
# Generated by Django 4.2.7 on 2023-12-05 22:20
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("authentik_policies_reputation", "0005_reputation_expires_reputation_expiring"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="reputation",
|
||||||
|
name="ip_asn_data",
|
||||||
|
field=models.JSONField(default=dict),
|
||||||
|
),
|
||||||
|
]
|
|
@ -13,9 +13,9 @@ from structlog import get_logger
|
||||||
from authentik.core.models import ExpiringModel
|
from authentik.core.models import ExpiringModel
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.lib.models import SerializerModel
|
from authentik.lib.models import SerializerModel
|
||||||
from authentik.lib.utils.http import get_client_ip
|
|
||||||
from authentik.policies.models import Policy
|
from authentik.policies.models import Policy
|
||||||
from authentik.policies.types import PolicyRequest, PolicyResult
|
from authentik.policies.types import PolicyRequest, PolicyResult
|
||||||
|
from authentik.root.middleware import ClientIPMiddleware
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/"
|
CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/"
|
||||||
|
@ -44,7 +44,7 @@ class ReputationPolicy(Policy):
|
||||||
return "ak-policy-reputation-form"
|
return "ak-policy-reputation-form"
|
||||||
|
|
||||||
def passes(self, request: PolicyRequest) -> PolicyResult:
|
def passes(self, request: PolicyRequest) -> PolicyResult:
|
||||||
remote_ip = get_client_ip(request.http_request)
|
remote_ip = ClientIPMiddleware.get_client_ip(request.http_request)
|
||||||
query = Q()
|
query = Q()
|
||||||
if self.check_ip:
|
if self.check_ip:
|
||||||
query |= Q(ip=remote_ip)
|
query |= Q(ip=remote_ip)
|
||||||
|
@ -76,6 +76,7 @@ class Reputation(ExpiringModel, SerializerModel):
|
||||||
identifier = models.TextField()
|
identifier = models.TextField()
|
||||||
ip = models.GenericIPAddressField()
|
ip = models.GenericIPAddressField()
|
||||||
ip_geo_data = models.JSONField(default=dict)
|
ip_geo_data = models.JSONField(default=dict)
|
||||||
|
ip_asn_data = models.JSONField(default=dict)
|
||||||
score = models.BigIntegerField(default=0)
|
score = models.BigIntegerField(default=0)
|
||||||
|
|
||||||
expires = models.DateTimeField(default=reputation_expiry)
|
expires = models.DateTimeField(default=reputation_expiry)
|
||||||
|
|
|
@ -7,9 +7,9 @@ from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.core.signals import login_failed
|
from authentik.core.signals import login_failed
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.lib.utils.http import get_client_ip
|
|
||||||
from authentik.policies.reputation.models import CACHE_KEY_PREFIX
|
from authentik.policies.reputation.models import CACHE_KEY_PREFIX
|
||||||
from authentik.policies.reputation.tasks import save_reputation
|
from authentik.policies.reputation.tasks import save_reputation
|
||||||
|
from authentik.root.middleware import ClientIPMiddleware
|
||||||
from authentik.stages.identification.signals import identification_failed
|
from authentik.stages.identification.signals import identification_failed
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
@ -18,7 +18,7 @@ CACHE_TIMEOUT = CONFIG.get_int("cache.timeout_reputation")
|
||||||
|
|
||||||
def update_score(request: HttpRequest, identifier: str, amount: int):
|
def update_score(request: HttpRequest, identifier: str, amount: int):
|
||||||
"""Update score for IP and User"""
|
"""Update score for IP and User"""
|
||||||
remote_ip = get_client_ip(request)
|
remote_ip = ClientIPMiddleware.get_client_ip(request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# We only update the cache here, as its faster than writing to the DB
|
# We only update the cache here, as its faster than writing to the DB
|
||||||
|
|
|
@ -2,7 +2,8 @@
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.events.geo import GEOIP_READER
|
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
|
||||||
|
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
|
||||||
from authentik.events.monitored_tasks import (
|
from authentik.events.monitored_tasks import (
|
||||||
MonitoredTask,
|
MonitoredTask,
|
||||||
TaskResult,
|
TaskResult,
|
||||||
|
@ -26,7 +27,8 @@ def save_reputation(self: MonitoredTask):
|
||||||
ip=score["ip"],
|
ip=score["ip"],
|
||||||
identifier=score["identifier"],
|
identifier=score["identifier"],
|
||||||
)
|
)
|
||||||
rep.ip_geo_data = GEOIP_READER.city_dict(score["ip"]) or {}
|
rep.ip_geo_data = GEOIP_CONTEXT_PROCESSOR.city_dict(score["ip"]) or {}
|
||||||
|
rep.ip_asn_data = ASN_CONTEXT_PROCESSOR.asn_dict(score["ip"]) or {}
|
||||||
rep.score = score["score"]
|
rep.score = score["score"]
|
||||||
objects_to_update.append(rep)
|
objects_to_update.append(rep)
|
||||||
Reputation.objects.bulk_update(objects_to_update, ["score", "ip_geo_data"])
|
Reputation.objects.bulk_update(objects_to_update, ["score", "ip_geo_data"])
|
||||||
|
|
|
@ -8,8 +8,7 @@ from django.db.models import Model
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.events.geo import GEOIP_READER
|
from authentik.events.context_processors.base import get_context_processors
|
||||||
from authentik.lib.utils.http import get_client_ip
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from authentik.core.models import User
|
from authentik.core.models import User
|
||||||
|
@ -39,12 +38,8 @@ class PolicyRequest:
|
||||||
def set_http_request(self, request: HttpRequest): # pragma: no cover
|
def set_http_request(self, request: HttpRequest): # pragma: no cover
|
||||||
"""Load data from HTTP request, including geoip when enabled"""
|
"""Load data from HTTP request, including geoip when enabled"""
|
||||||
self.http_request = request
|
self.http_request = request
|
||||||
if not GEOIP_READER.enabled:
|
for processor in get_context_processors():
|
||||||
return
|
self.context.update(processor.enrich_context(request))
|
||||||
client_ip = get_client_ip(request)
|
|
||||||
if not client_ip:
|
|
||||||
return
|
|
||||||
self.context["geoip"] = GEOIP_READER.city(client_ip)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def should_cache(self) -> bool:
|
def should_cache(self) -> bool:
|
||||||
|
|
|
@ -7,8 +7,8 @@ GRANT_TYPE_CLIENT_CREDENTIALS = "client_credentials"
|
||||||
GRANT_TYPE_PASSWORD = "password" # nosec
|
GRANT_TYPE_PASSWORD = "password" # nosec
|
||||||
GRANT_TYPE_DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code"
|
GRANT_TYPE_DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code"
|
||||||
|
|
||||||
CLIENT_ASSERTION_TYPE = "client_assertion_type"
|
|
||||||
CLIENT_ASSERTION = "client_assertion"
|
CLIENT_ASSERTION = "client_assertion"
|
||||||
|
CLIENT_ASSERTION_TYPE = "client_assertion_type"
|
||||||
CLIENT_ASSERTION_TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
|
CLIENT_ASSERTION_TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
|
||||||
|
|
||||||
PROMPT_NONE = "none"
|
PROMPT_NONE = "none"
|
||||||
|
@ -18,9 +18,9 @@ PROMPT_LOGIN = "login"
|
||||||
SCOPE_OPENID = "openid"
|
SCOPE_OPENID = "openid"
|
||||||
SCOPE_OPENID_PROFILE = "profile"
|
SCOPE_OPENID_PROFILE = "profile"
|
||||||
SCOPE_OPENID_EMAIL = "email"
|
SCOPE_OPENID_EMAIL = "email"
|
||||||
|
SCOPE_OFFLINE_ACCESS = "offline_access"
|
||||||
|
|
||||||
# https://www.iana.org/assignments/oauth-parameters/\
|
# https://www.iana.org/assignments/oauth-parameters/auth-parameters.xhtml#pkce-code-challenge-method
|
||||||
# oauth-parameters.xhtml#pkce-code-challenge-method
|
|
||||||
PKCE_METHOD_PLAIN = "plain"
|
PKCE_METHOD_PLAIN = "plain"
|
||||||
PKCE_METHOD_S256 = "S256"
|
PKCE_METHOD_S256 = "S256"
|
||||||
|
|
||||||
|
@ -36,6 +36,12 @@ SCOPE_GITHUB_USER_READ = "read:user"
|
||||||
SCOPE_GITHUB_USER_EMAIL = "user:email"
|
SCOPE_GITHUB_USER_EMAIL = "user:email"
|
||||||
# Read info about teams
|
# Read info about teams
|
||||||
SCOPE_GITHUB_ORG_READ = "read:org"
|
SCOPE_GITHUB_ORG_READ = "read:org"
|
||||||
|
SCOPE_GITHUB = {
|
||||||
|
SCOPE_GITHUB_USER,
|
||||||
|
SCOPE_GITHUB_USER_READ,
|
||||||
|
SCOPE_GITHUB_USER_EMAIL,
|
||||||
|
SCOPE_GITHUB_ORG_READ,
|
||||||
|
}
|
||||||
|
|
||||||
ACR_AUTHENTIK_DEFAULT = "goauthentik.io/providers/oauth2/default"
|
ACR_AUTHENTIK_DEFAULT = "goauthentik.io/providers/oauth2/default"
|
||||||
|
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Reference in a new issue