Skip to content

Commit 253e685

Browse files
committed
Step 3
1 parent 0b001e6 commit 253e685

File tree

2 files changed

+55
-35
lines changed

2 files changed

+55
-35
lines changed

.mypy.ini

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ check_untyped_defs = True
1515
# get passing if you use a lot of untyped libraries
1616
disallow_subclassing_any = True
1717
disallow_untyped_decorators = True
18-
; disallow_any_generics = True
18+
disallow_any_generics = True
1919

2020
# These next few are various gradations of forcing use of type annotations
21-
; disallow_untyped_calls = True
21+
disallow_untyped_calls = True
2222
; disallow_incomplete_defs = True
2323
; disallow_untyped_defs = True
2424

django_fsm/__init__.py

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,30 @@
3434

3535
if TYPE_CHECKING:
3636
from collections.abc import Callable
37+
from collections.abc import Generator
3738
from collections.abc import Sequence
3839
from typing import Any
3940

40-
from django.contrib.auth.models import AbstractBaseUser
41+
from django.contrib.auth.models import PermissionsMixin as UserWithPermissions
4142
from django.utils.functional import _StrOrPromise
4243

4344
_Model = models.Model
45+
_Field = models.Field[Any, Any]
46+
CharField = models.CharField[str, str]
47+
IntegerField = models.IntegerField[int, int]
48+
ForeignKey = models.ForeignKey[Any, Any]
4449
else:
4550
_Model = object
51+
_Field = object
52+
CharField = models.CharField
53+
IntegerField = models.IntegerField
54+
ForeignKey = models.ForeignKey
4655

4756

4857
class TransitionNotAllowed(Exception):
4958
"""Raised when a transition is not allowed"""
5059

51-
def __init__(self, *args, **kwargs) -> None:
60+
def __init__(self, *args: Any, **kwargs: Any) -> None:
5261
self.object = kwargs.pop("object", None)
5362
self.method = kwargs.pop("method", None)
5463
super().__init__(*args, **kwargs)
@@ -69,12 +78,12 @@ class ConcurrentTransition(Exception):
6978
class Transition:
7079
def __init__(
7180
self,
72-
method: Callable,
81+
method: Callable[..., Any],
7382
source: str | int | Sequence[str | int] | State,
7483
target: str | int | State | None,
7584
on_error: str | int | None,
7685
conditions: list[Callable[[Any], bool]],
77-
permission: str | Callable[[models.Model, AbstractBaseUser], bool] | None,
86+
permission: str | Callable[[models.Model, UserWithPermissions], bool] | None,
7887
custom: dict[str, _StrOrPromise],
7988
) -> None:
8089
self.method = method
@@ -89,7 +98,7 @@ def __init__(
8998
def name(self) -> str:
9099
return self.method.__name__
91100

92-
def has_perm(self, instance, user) -> bool:
101+
def has_perm(self, instance, user: UserWithPermissions) -> bool:
93102
if not self.permission:
94103
return True
95104
elif callable(self.permission):
@@ -102,7 +111,7 @@ def has_perm(self, instance, user) -> bool:
102111
return False
103112

104113

105-
def get_available_FIELD_transitions(instance, field):
114+
def get_available_FIELD_transitions(instance, field: FSMFieldMixin) -> Generator[Transition, None, None]:
106115
"""
107116
List of transitions available in current model state
108117
with all conditions met
@@ -116,14 +125,16 @@ def get_available_FIELD_transitions(instance, field):
116125
yield meta.get_transition(curr_state)
117126

118127

119-
def get_all_FIELD_transitions(instance, field):
128+
def get_all_FIELD_transitions(instance, field: FSMFieldMixin) -> Generator[Transition, None, None]:
120129
"""
121130
List of all transitions available in current model state
122131
"""
123132
return field.get_all_transitions(instance.__class__)
124133

125134

126-
def get_available_user_FIELD_transitions(instance, user, field):
135+
def get_available_user_FIELD_transitions(
136+
instance, user: UserWithPermissions, field: FSMFieldMixin
137+
) -> Generator[Transition, None, None]:
127138
"""
128139
List of transitions available in current model state
129140
with all conditions met and user have rights on it
@@ -142,15 +153,24 @@ def __init__(self, field, method) -> None:
142153
self.field = field
143154
self.transitions: dict[str, Any] = {} # source -> Transition
144155

145-
def get_transition(self, source):
156+
def get_transition(self, source: str):
146157
transition = self.transitions.get(source, None)
147158
if transition is None:
148159
transition = self.transitions.get("*", None)
149160
if transition is None:
150161
transition = self.transitions.get("+", None)
151162
return transition
152163

153-
def add_transition(self, method, source, target, on_error=None, conditions=[], permission=None, custom={}) -> None:
164+
def add_transition(
165+
self,
166+
method: Callable[..., Any],
167+
source: str,
168+
target: str | int,
169+
on_error: str | int | None = None,
170+
conditions: list[Callable[[Any], bool]] = [],
171+
permission: str | Callable[[models.Model, UserWithPermissions], bool] | None = None,
172+
custom: dict[str, _StrOrPromise] = {},
173+
) -> None:
154174
if source in self.transitions:
155175
raise AssertionError(f"Duplicate transition for {source} state")
156176

@@ -192,7 +212,7 @@ def conditions_met(self, instance, state) -> bool:
192212
else:
193213
return all(map(lambda condition: condition(instance), transition.conditions))
194214

195-
def has_transition_perm(self, instance, state, user) -> bool:
215+
def has_transition_perm(self, instance, state, user: UserWithPermissions) -> bool:
196216
transition = self.get_transition(state)
197217

198218
if not transition:
@@ -235,10 +255,10 @@ def __set__(self, instance, value) -> None:
235255
self.field.set_state(instance, value)
236256

237257

238-
class FSMFieldMixin(Field):
258+
class FSMFieldMixin(_Field):
239259
descriptor_class = FSMFieldDescriptor
240260

241-
def __init__(self, *args, **kwargs) -> None:
261+
def __init__(self, *args: Any, **kwargs: Any) -> None:
242262
self.protected = kwargs.pop("protected", False)
243263
self.transitions: dict[Any, dict[str, Any]] = {} # cls -> (transitions name -> method)
244264
self.state_proxy = {} # state -> ProxyClsRef
@@ -263,15 +283,15 @@ def deconstruct(self):
263283
kwargs["protected"] = self.protected
264284
return name, path, args, kwargs
265285

266-
def get_state(self, instance):
286+
def get_state(self, instance) -> Any:
267287
# The state field may be deferred. We delegate the logic of figuring this out
268288
# and loading the deferred field on-demand to Django's built-in DeferredAttribute class.
269289
return DeferredAttribute(self).__get__(instance) # type: ignore[attr-defined]
270290

271-
def set_state(self, instance, state):
291+
def set_state(self, instance, state: str) -> None:
272292
instance.__dict__[self.name] = state
273293

274-
def set_proxy(self, instance, state):
294+
def set_proxy(self, instance, state: str) -> None:
275295
"""
276296
Change class
277297
"""
@@ -292,7 +312,7 @@ def set_proxy(self, instance, state):
292312

293313
instance.__class__ = model
294314

295-
def change_state(self, instance, method, *args, **kwargs):
315+
def change_state(self, instance, method, *args: Any, **kwargs: Any):
296316
meta = method._django_fsm
297317
method_name = method.__name__
298318
current_state = self.get_state(instance)
@@ -345,7 +365,7 @@ def change_state(self, instance, method, *args, **kwargs):
345365

346366
return result
347367

348-
def get_all_transitions(self, instance_cls):
368+
def get_all_transitions(self, instance_cls) -> Generator[Transition, None, None]:
349369
"""
350370
Returns [(source, target, name, method)] for all field transitions
351371
"""
@@ -372,7 +392,7 @@ def contribute_to_class(self, cls, name, private_only=False, **kwargs):
372392

373393
class_prepared.connect(self._collect_transitions)
374394

375-
def _collect_transitions(self, *args, **kwargs):
395+
def _collect_transitions(self, *args: Any, **kwargs: Any):
376396
sender = kwargs["sender"]
377397

378398
if not issubclass(sender, self.base_cls):
@@ -401,25 +421,25 @@ def is_field_transition_method(attr):
401421
self.transitions[sender] = sender_transitions
402422

403423

404-
class FSMField(FSMFieldMixin, models.CharField):
424+
class FSMField(FSMFieldMixin, CharField):
405425
"""
406426
State Machine support for Django model as CharField
407427
"""
408428

409-
def __init__(self, *args, **kwargs) -> None:
429+
def __init__(self, *args: Any, **kwargs: Any) -> None:
410430
kwargs.setdefault("max_length", 50)
411431
super().__init__(*args, **kwargs)
412432

413433

414-
class FSMIntegerField(FSMFieldMixin, models.IntegerField):
434+
class FSMIntegerField(FSMFieldMixin, IntegerField):
415435
"""
416436
Same as FSMField, but stores the state value in an IntegerField.
417437
"""
418438

419439
pass
420440

421441

422-
class FSMKeyField(FSMFieldMixin, models.ForeignKey):
442+
class FSMKeyField(FSMFieldMixin, ForeignKey):
423443
"""
424444
State Machine support for Django model
425445
"""
@@ -457,7 +477,7 @@ class ConcurrentTransitionMixin(_Model):
457477
state, thus practically negating their effect.
458478
"""
459479

460-
def __init__(self, *args, **kwargs) -> None:
480+
def __init__(self, *args: Any, **kwargs: Any) -> None:
461481
super().__init__(*args, **kwargs)
462482
self._update_initial_state()
463483

@@ -495,14 +515,14 @@ def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_updat
495515

496516
return updated
497517

498-
def _update_initial_state(self):
518+
def _update_initial_state(self) -> None:
499519
self.__initial_states = {field.attname: field.value_from_object(self) for field in self.state_fields}
500520

501-
def refresh_from_db(self, *args, **kwargs):
521+
def refresh_from_db(self, *args: Any, **kwargs: Any) -> None:
502522
super().refresh_from_db(*args, **kwargs)
503523
self._update_initial_state()
504524

505-
def save(self, *args, **kwargs):
525+
def save(self, *args: Any, **kwargs: Any) -> None:
506526
super().save(*args, **kwargs)
507527
self._update_initial_state()
508528

@@ -513,7 +533,7 @@ def transition(
513533
target: str | int | State | None = None,
514534
on_error: str | int | None = None,
515535
conditions: list[Callable[[Any], bool]] = [],
516-
permission: str | Callable[[models.Model, AbstractBaseUser], bool] | None = None,
536+
permission: str | Callable[[models.Model, UserWithPermissions], bool] | None = None,
517537
custom: dict[str, _StrOrPromise] = {},
518538
):
519539
"""
@@ -537,7 +557,7 @@ def inner_transition(func):
537557
func._django_fsm.add_transition(func, source, target, on_error, conditions, permission, custom)
538558

539559
@wraps(func)
540-
def _change_state(instance, *args, **kwargs):
560+
def _change_state(instance, *args: Any, **kwargs: Any):
541561
return fsm_meta.field.change_state(instance, func, *args, **kwargs)
542562

543563
if not wrapper_installed:
@@ -548,7 +568,7 @@ def _change_state(instance, *args, **kwargs):
548568
return inner_transition
549569

550570

551-
def can_proceed(bound_method, check_conditions=True) -> bool:
571+
def can_proceed(bound_method, check_conditions: bool = True) -> bool:
552572
"""
553573
Returns True if model in state allows to call bound_method
554574
@@ -565,7 +585,7 @@ def can_proceed(bound_method, check_conditions=True) -> bool:
565585
return meta.has_transition(current_state) and (not check_conditions or meta.conditions_met(self, current_state))
566586

567587

568-
def has_transition_perm(bound_method, user) -> bool:
588+
def has_transition_perm(bound_method, user: UserWithPermissions) -> bool:
569589
"""
570590
Returns True if model in state allows to call bound_method and user have rights on it
571591
"""
@@ -589,7 +609,7 @@ def get_state(self, model, transition, result, args=[], kwargs={}):
589609

590610

591611
class RETURN_VALUE(State):
592-
def __init__(self, *allowed_states) -> None:
612+
def __init__(self, *allowed_states: Sequence[str | int]) -> None:
593613
self.allowed_states = allowed_states if allowed_states else None
594614

595615
def get_state(self, model, transition, result, args=[], kwargs={}):
@@ -600,7 +620,7 @@ def get_state(self, model, transition, result, args=[], kwargs={}):
600620

601621

602622
class GET_STATE(State):
603-
def __init__(self, func, states=None) -> None:
623+
def __init__(self, func: Callable[..., str | int], states: Sequence[str | int] | None = None) -> None:
604624
self.func = func
605625
self.allowed_states = states
606626

0 commit comments

Comments
 (0)