flows: fix re-imports of entries with identical PK re-creating objects
closes #2941 Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
d25a051eae
commit
56babb2649
|
@ -13,6 +13,25 @@ from authentik.policies.models import PolicyBinding
|
||||||
from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
|
from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
|
||||||
from authentik.stages.user_login.models import UserLoginStage
|
from authentik.stages.user_login.models import UserLoginStage
|
||||||
|
|
||||||
|
STATIC_PROMPT_EXPORT = """{
|
||||||
|
"version": 1,
|
||||||
|
"entries": [
|
||||||
|
{
|
||||||
|
"identifiers": {
|
||||||
|
"pk": "cb954fd4-65a5-4ad9-b1ee-180ee9559cf4"
|
||||||
|
},
|
||||||
|
"model": "authentik_stages_prompt.prompt",
|
||||||
|
"attrs": {
|
||||||
|
"field_key": "username",
|
||||||
|
"label": "Username",
|
||||||
|
"type": "username",
|
||||||
|
"required": true,
|
||||||
|
"placeholder": "Username",
|
||||||
|
"order": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}"""
|
||||||
|
|
||||||
class TestFlowTransfer(TransactionTestCase):
|
class TestFlowTransfer(TransactionTestCase):
|
||||||
"""Test flow transfer"""
|
"""Test flow transfer"""
|
||||||
|
@ -58,6 +77,19 @@ class TestFlowTransfer(TransactionTestCase):
|
||||||
|
|
||||||
self.assertTrue(Flow.objects.filter(slug=flow_slug).exists())
|
self.assertTrue(Flow.objects.filter(slug=flow_slug).exists())
|
||||||
|
|
||||||
|
def test_export_validate_import_re_import(self):
|
||||||
|
"""Test export and import it twice"""
|
||||||
|
importer = FlowImporter(STATIC_PROMPT_EXPORT)
|
||||||
|
self.assertTrue(importer.validate())
|
||||||
|
self.assertTrue(importer.apply())
|
||||||
|
|
||||||
|
self.assertEqual(Prompt.objects.filter(field_key="username").count(), 1)
|
||||||
|
|
||||||
|
importer = FlowImporter(STATIC_PROMPT_EXPORT)
|
||||||
|
self.assertTrue(importer.apply())
|
||||||
|
|
||||||
|
self.assertEqual(Prompt.objects.filter(field_key="username").count(), 1)
|
||||||
|
|
||||||
def test_export_validate_import_policies(self):
|
def test_export_validate_import_policies(self):
|
||||||
"""Test export and validate it"""
|
"""Test export and validate it"""
|
||||||
flow_slug = generate_id()
|
flow_slug = generate_id()
|
||||||
|
|
|
@ -115,6 +115,11 @@ class FlowImporter:
|
||||||
serializer_kwargs["instance"] = model_instance
|
serializer_kwargs["instance"] = model_instance
|
||||||
else:
|
else:
|
||||||
self.logger.debug("initialise new instance", model=model, **updated_identifiers)
|
self.logger.debug("initialise new instance", model=model, **updated_identifiers)
|
||||||
|
model_instance = model()
|
||||||
|
# pk needs to be set on the model instance otherwise a new one will be generated
|
||||||
|
if "pk" in updated_identifiers:
|
||||||
|
model_instance.pk = updated_identifiers["pk"]
|
||||||
|
serializer_kwargs["instance"] = model_instance
|
||||||
full_data = self.__update_pks_for_attrs(entry.attrs)
|
full_data = self.__update_pks_for_attrs(entry.attrs)
|
||||||
full_data.update(updated_identifiers)
|
full_data.update(updated_identifiers)
|
||||||
serializer_kwargs["data"] = full_data
|
serializer_kwargs["data"] = full_data
|
||||||
|
@ -167,7 +172,7 @@ class FlowImporter:
|
||||||
def validate(self) -> bool:
|
def validate(self) -> bool:
|
||||||
"""Validate loaded flow export, ensure all models are allowed
|
"""Validate loaded flow export, ensure all models are allowed
|
||||||
and serializers have no errors"""
|
and serializers have no errors"""
|
||||||
self.logger.debug("Starting flow import validaton")
|
self.logger.debug("Starting flow import validation")
|
||||||
if self.__import.version != 1:
|
if self.__import.version != 1:
|
||||||
self.logger.warning("Invalid bundle version")
|
self.logger.warning("Invalid bundle version")
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -168,11 +168,6 @@ class AuthenticatorValidateStageView(ChallengeStageView):
|
||||||
continue
|
continue
|
||||||
# check if device has been used within threshold and skip this stage if so
|
# check if device has been used within threshold and skip this stage if so
|
||||||
if threshold.total_seconds() > 0:
|
if threshold.total_seconds() > 0:
|
||||||
print("yeet")
|
|
||||||
print(get_device_last_usage(device))
|
|
||||||
print(_now - get_device_last_usage(device))
|
|
||||||
print(threshold)
|
|
||||||
print(_now - get_device_last_usage(device) <= threshold)
|
|
||||||
if _now - get_device_last_usage(device) <= threshold:
|
if _now - get_device_last_usage(device) <= threshold:
|
||||||
LOGGER.info("Device has been used within threshold", device=device)
|
LOGGER.info("Device has been used within threshold", device=device)
|
||||||
raise FlowSkipStageException()
|
raise FlowSkipStageException()
|
||||||
|
|
Reference in a new issue