diff --git a/django_fsm/__init__.py b/django_fsm/__init__.py index 77c3234..cad413b 100644 --- a/django_fsm/__init__.py +++ b/django_fsm/__init__.py @@ -431,19 +431,16 @@ def is_fsm_and_protected(f): return {f.attname for f in protected_fields} def refresh_from_db(self, *args, **kwargs): - fields = kwargs.pop("fields", None) + protected_fields = self._get_protected_fsm_fields() - # Use provided fields, if not set then reload all non-deferred fields.0 - if not fields: - deferred_fields = self.get_deferred_fields() - protected_fields = self._get_protected_fsm_fields() - skipped_fields = deferred_fields.union(protected_fields) + for f in protected_fields: + self._meta.get_field(f).protected = False - fields = [f.attname for f in self._meta.concrete_fields if f.attname not in skipped_fields] - - kwargs["fields"] = fields super().refresh_from_db(*args, **kwargs) + for f in protected_fields: + self._meta.get_field(f).protected = True + class ConcurrentTransitionMixin: """ diff --git a/tests/testapp/tests/test_protected_fields.py b/tests/testapp/tests/test_protected_fields.py index ce98f88..dcdb5b1 100644 --- a/tests/testapp/tests/test_protected_fields.py +++ b/tests/testapp/tests/test_protected_fields.py @@ -26,11 +26,8 @@ def test_no_direct_access(self): instance = RefreshableProtectedAccessModel() assert instance.status == "new" - def try_change(): - instance.status = "change" - with pytest.raises(AttributeError): - try_change() + instance.status = "change" instance.publish() instance.save() @@ -38,6 +35,24 @@ def try_change(): def test_refresh_from_db(self): instance = RefreshableModel() + assert instance.status == "new" + instance.save() + + instance.refresh_from_db() + assert instance.status == "new" + + def test_concurrent_refresh_from_db(self): + instance = RefreshableModel() + assert instance.status == "new" instance.save() + # NOTE: This simulates a concurrent update scenario + concurrent_instance = RefreshableModel.objects.get(pk=instance.pk) + assert concurrent_instance.status == instance.status == "new" + concurrent_instance.publish() + assert concurrent_instance.status == "published" + concurrent_instance.save() + + assert instance.status == "new" instance.refresh_from_db() + assert instance.status == "published"