4848 IntegerField = models .IntegerField [int , int ]
4949 ForeignKey = models .ForeignKey [Any , Any ]
5050
51+ _StateValue = str | int
5152 _Instance = models .Model # TODO: use real type
5253 _ToDo = Any # TODO: use real type
5354else :
@@ -82,10 +83,10 @@ class ConcurrentTransition(Exception):
8283class Transition :
8384 def __init__ (
8485 self ,
85- method : Callable [..., str | int | None ],
86- source : str | int | Sequence [str | int ] | State ,
87- target : str | int ,
88- on_error : str | int | None ,
86+ method : Callable [..., _StateValue | Any ],
87+ source : _StateValue | Sequence [_StateValue ] | State ,
88+ target : _StateValue ,
89+ on_error : _StateValue | None ,
8990 conditions : list [Callable [[_Instance ], bool ]],
9091 permission : str | Callable [[_Instance , UserWithPermissions ], bool ] | None ,
9192 custom : dict [str , _StrOrPromise ],
@@ -402,7 +403,7 @@ def _collect_transitions(self, *args: Any, **kwargs: Any) -> None:
402403 if not issubclass (sender , self .base_cls ):
403404 return
404405
405- def is_field_transition_method (attr ) :
406+ def is_field_transition_method (attr : _ToDo ) -> bool :
406407 return (
407408 (inspect .ismethod (attr ) or inspect .isfunction (attr ))
408409 and hasattr (attr , "_django_fsm" )
@@ -489,7 +490,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
489490 def state_fields (self ) -> Iterable [Any ]:
490491 return filter (lambda field : isinstance (field , FSMFieldMixin ), self ._meta .fields )
491492
492- def _do_update (self , base_qs , using , pk_val , values , update_fields , forced_update ):
493+ def _do_update (self , base_qs , using , pk_val , values , update_fields , forced_update ): # type: ignore[no-untyped-def]
493494 # _do_update is called once for each model class in the inheritance hierarchy.
494495 # We can only filter the base_qs on state fields (can be more than one!) present in this particular model.
495496
@@ -533,21 +534,21 @@ def save(self, *args: Any, **kwargs: Any) -> None:
533534
534535def transition (
535536 field : FSMFieldMixin ,
536- source : str | int | Sequence [str | int ] | State = "*" ,
537+ source : str | int | Sequence [str | int ] = "*" ,
537538 target : str | int | State | None = None ,
538539 on_error : str | int | None = None ,
539540 conditions : list [Callable [[Any ], bool ]] = [],
540541 permission : str | Callable [[models .Model , UserWithPermissions ], bool ] | None = None ,
541542 custom : dict [str , _StrOrPromise ] = {},
542- ) -> _ToDo :
543+ ) -> Callable [[ Any ], Any ] :
543544 """
544545 Method decorator to mark allowed transitions.
545546
546547 Set target to None if current state needs to be validated and
547548 has not changed after the function call.
548549 """
549550
550- def inner_transition (func ) :
551+ def inner_transition (func : _ToDo ) -> _ToDo :
551552 wrapper_installed , fsm_meta = True , getattr (func , "_django_fsm" , None )
552553 if not fsm_meta :
553554 wrapper_installed = False
@@ -608,15 +609,15 @@ def has_transition_perm(bound_method: _ToDo, user: UserWithPermissions) -> bool:
608609
609610
610611class State :
611- def get_state (self , model , transition , result , args = [], kwargs = {}):
612+ def get_state (self , model : _Model , transition : Transition , result : Any , args : Any = [], kwargs : Any = {}) -> _ToDo :
612613 raise NotImplementedError
613614
614615
615616class RETURN_VALUE (State ):
616617 def __init__ (self , * allowed_states : Sequence [str | int ]) -> None :
617618 self .allowed_states = allowed_states if allowed_states else None
618619
619- def get_state (self , model , transition , result , args = [], kwargs = {}):
620+ def get_state (self , model : _Model , transition : Transition , result : Any , args : Any = [], kwargs : Any = {}) -> _ToDo :
620621 if self .allowed_states is not None :
621622 if result not in self .allowed_states :
622623 raise InvalidResultState (f"{ result } is not in list of allowed states\n { self .allowed_states } " )
@@ -628,7 +629,9 @@ def __init__(self, func: Callable[..., str | int], states: Sequence[str | int] |
628629 self .func = func
629630 self .allowed_states = states
630631
631- def get_state (self , model , transition , result , args = [], kwargs = {}):
632+ def get_state (
633+ self , model : _Model , transition : Transition , result : _StateValue | Any , args : Any = [], kwargs : Any = {}
634+ ) -> _ToDo :
632635 result_state = self .func (model , * args , ** kwargs )
633636 if self .allowed_states is not None :
634637 if result_state not in self .allowed_states :
0 commit comments