Skip to content

Commit 54d05df

Browse files
committed
fix type annotations for hooks
1 parent 3c9c062 commit 54d05df

File tree

2 files changed

+32
-16
lines changed

2 files changed

+32
-16
lines changed

idom/core/events.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,8 @@ def _get_handler_index(self, function: Callable[..., Any]) -> Optional[int]:
222222
return None
223223
else:
224224
for i, h in enumerate(self._handlers):
225-
if h.__wrapped__ == function:
225+
# The `coroutine()` decorator adds a `__wrapped__` attribute
226+
if h.__wrapped__ == function: # type: ignore
226227
return i
227228
else:
228229
return None

idom/core/hooks.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
List,
1414
overload,
1515
)
16+
from typing_extensions import Protocol
1617

1718
from loguru import logger
1819

@@ -89,7 +90,7 @@ def dispatch(
8990

9091
@overload
9192
def use_effect(
92-
function: None, args: Optional[Sequence[Any]]
93+
function: None = None, args: Optional[Sequence[Any]] = None
9394
) -> Callable[[_EffectApplyFunc], None]:
9495
...
9596

@@ -118,7 +119,7 @@ def use_effect(
118119
hook = current_hook()
119120
memoize = use_memo(args=args)
120121

121-
def setup(function: _EffectApplyFunc) -> None:
122+
def add_effect(function: _EffectApplyFunc) -> None:
122123
def effect() -> None:
123124
clean = function()
124125
if clean is not None:
@@ -127,9 +128,10 @@ def effect() -> None:
127128
return memoize(lambda: hook.add_effect("did_render", effect))
128129

129130
if function is not None:
130-
return setup(function)
131+
add_effect(function)
132+
return None
131133
else:
132-
return setup
134+
return add_effect
133135

134136

135137
_ActionType = TypeVar("_ActionType")
@@ -157,7 +159,7 @@ def use_reducer(
157159

158160
def _create_dispatcher(
159161
reducer: Callable[[_StateType, _ActionType], _StateType],
160-
set_state: Callable[[_StateType], None],
162+
set_state: Callable[[Callable[[_StateType], _StateType]], None],
161163
) -> Callable[[_ActionType], None]:
162164
def dispatch(action: _ActionType) -> None:
163165
set_state(lambda last_state: reducer(last_state, action))
@@ -170,20 +172,22 @@ def dispatch(action: _ActionType) -> None:
170172

171173
@overload
172174
def use_callback(
173-
function: None, args: Optional[Sequence[Any]]
174-
) -> Callable[[_CallbackFunc], None]:
175+
function: None = None, args: Optional[Sequence[Any]] = None
176+
) -> Callable[[_CallbackFunc], _CallbackFunc]:
175177
...
176178

177179

178180
@overload
179-
def use_callback(function: _CallbackFunc, args: Optional[Sequence[Any]]) -> None:
181+
def use_callback(
182+
function: _CallbackFunc, args: Optional[Sequence[Any]]
183+
) -> _CallbackFunc:
180184
...
181185

182186

183187
def use_callback(
184188
function: Optional[_CallbackFunc] = None,
185189
args: Optional[Sequence[Any]] = None,
186-
) -> Optional[Callable[[_CallbackFunc], None]]:
190+
) -> Union[_CallbackFunc, Callable[[_CallbackFunc], _CallbackFunc]]:
187191
"""See the full :ref:`use_callback` docs for details
188192
189193
Parameters:
@@ -204,10 +208,17 @@ def setup(function: _CallbackFunc) -> _CallbackFunc:
204208
return setup
205209

206210

211+
class _LambdaCaller(Protocol):
212+
"""MyPy doesn't know how to deal with TypeVars only used in function return"""
213+
214+
def __call__(self, func: Callable[[], _StateType]) -> _StateType:
215+
...
216+
217+
207218
@overload
208219
def use_memo(
209-
function: None, args: Optional[Sequence[Any]]
210-
) -> Callable[[Callable[[], _StateType]], _StateType]:
220+
function: None = None, args: Optional[Sequence[Any]] = None
221+
) -> _LambdaCaller:
211222
...
212223

213224

@@ -231,22 +242,26 @@ def use_memo(
231242
Returns:
232243
The current state
233244
"""
234-
memo = _use_const(_Memo)
245+
memo: _Memo[_StateType] = _use_const(_Memo)
235246

236247
if memo.empty():
237248
# we need to initialize on the first run
238249
changed = True
239250
memo.args = () if args is None else args
251+
elif args is None:
252+
changed = True
240253
elif (
241-
args is None
242-
or len(memo.args) != len(args)
254+
len(memo.args) != len(args)
255+
# if args are same length check identity for each item
243256
or any(current is not new for current, new in zip(memo.args, args))
244257
):
245258
memo.args = args
246259
changed = True
247260
else:
248261
changed = False
249262

263+
setup: Callable[[Callable[[], _StateType]], _StateType]
264+
250265
if changed:
251266

252267
def setup(function: Callable[[], _StateType]) -> _StateType:
@@ -270,7 +285,7 @@ class _Memo(Generic[_StateType]):
270285
__slots__ = "value", "args"
271286

272287
value: _StateType
273-
args: Tuple[Any, ...]
288+
args: Sequence[Any]
274289

275290
def empty(self) -> bool:
276291
try:

0 commit comments

Comments
 (0)