4949 IntegerField = models .IntegerField [int , int ]
5050 ForeignKey = models .ForeignKey [Any , Any ]
5151
52+ _StateValue = str | int
5253 _Instance = models .Model # TODO: use real type
5354 _ToDo = Any # TODO: use real type
5455else :
@@ -83,10 +84,10 @@ class ConcurrentTransition(Exception):
8384class Transition :
8485 def __init__ (
8586 self ,
86- method : Callable [..., str | int | None ],
87- source : str | int | Sequence [str | int ] | State ,
88- target : str | int ,
89- on_error : str | int | None ,
87+ method : Callable [..., _StateValue | Any ],
88+ source : _StateValue | Sequence [_StateValue ] | State ,
89+ target : _StateValue ,
90+ on_error : _StateValue | None ,
9091 conditions : list [Callable [[_Instance ], bool ]],
9192 permission : str | Callable [[_Instance , UserWithPermissions ], bool ] | None ,
9293 custom : dict [str , _StrOrPromise ],
@@ -414,7 +415,7 @@ def _collect_transitions(self, *args: Any, **kwargs: Any) -> None:
414415 if not issubclass (sender , self .base_cls ):
415416 return
416417
417- def is_field_transition_method (attr ) :
418+ def is_field_transition_method (attr : _ToDo ) -> bool :
418419 return (
419420 (inspect .ismethod (attr ) or inspect .isfunction (attr ))
420421 and hasattr (attr , "_django_fsm" )
@@ -528,7 +529,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
528529 def state_fields (self ) -> Iterable [Any ]:
529530 return filter (lambda field : isinstance (field , FSMFieldMixin ), self ._meta .fields )
530531
531- def _do_update (self , base_qs , using , pk_val , values , update_fields , forced_update ):
532+ def _do_update (self , base_qs , using , pk_val , values , update_fields , forced_update ): # type: ignore[no-untyped-def]
532533 # _do_update is called once for each model class in the inheritance hierarchy.
533534 # We can only filter the base_qs on state fields (can be more than one!) present in this particular model.
534535
@@ -572,21 +573,21 @@ def save(self, *args: Any, **kwargs: Any) -> None:
572573
573574def transition (
574575 field : FSMFieldMixin ,
575- source : str | int | Sequence [str | int ] | State = "*" ,
576+ source : str | int | Sequence [str | int ] = "*" ,
576577 target : str | int | State | None = None ,
577578 on_error : str | int | None = None ,
578579 conditions : list [Callable [[Any ], bool ]] = [],
579580 permission : str | Callable [[models .Model , UserWithPermissions ], bool ] | None = None ,
580581 custom : dict [str , _StrOrPromise ] = {},
581- ) -> _ToDo :
582+ ) -> Callable [[ Any ], Any ] :
582583 """
583584 Method decorator to mark allowed transitions.
584585
585586 Set target to None if current state needs to be validated and
586587 has not changed after the function call.
587588 """
588589
589- def inner_transition (func ) :
590+ def inner_transition (func : _ToDo ) -> _ToDo :
590591 wrapper_installed , fsm_meta = True , getattr (func , "_django_fsm" , None )
591592 if not fsm_meta :
592593 wrapper_installed = False
@@ -647,15 +648,15 @@ def has_transition_perm(bound_method: _ToDo, user: UserWithPermissions) -> bool:
647648
648649
649650class State :
650- def get_state (self , model , transition , result , args = [], kwargs = {}):
651+ def get_state (self , model : _Model , transition : Transition , result : Any , args : Any = [], kwargs : Any = {}) -> _ToDo :
651652 raise NotImplementedError
652653
653654
654655class RETURN_VALUE (State ):
655656 def __init__ (self , * allowed_states : Sequence [str | int ]) -> None :
656657 self .allowed_states = allowed_states if allowed_states else None
657658
658- def get_state (self , model , transition , result , args = [], kwargs = {}):
659+ def get_state (self , model : _Model , transition : Transition , result : Any , args : Any = [], kwargs : Any = {}) -> _ToDo :
659660 if self .allowed_states is not None :
660661 if result not in self .allowed_states :
661662 raise InvalidResultState (f"{ result } is not in list of allowed states\n { self .allowed_states } " )
@@ -667,7 +668,9 @@ def __init__(self, func: Callable[..., str | int], states: Sequence[str | int] |
667668 self .func = func
668669 self .allowed_states = states
669670
670- def get_state (self , model , transition , result , args = [], kwargs = {}):
671+ def get_state (
672+ self , model : _Model , transition : Transition , result : _StateValue | Any , args : Any = [], kwargs : Any = {}
673+ ) -> _ToDo :
671674 result_state = self .func (model , * args , ** kwargs )
672675 if self .allowed_states is not None :
673676 if result_state not in self .allowed_states :
0 commit comments