44
55from collections .abc import Iterable
66from functools import reduce
7- from typing import Callable , TypeVar , Union , cast , overload
7+ from typing import TYPE_CHECKING , Callable , TypeVar , Union , cast , overload
8+
9+ if TYPE_CHECKING :
10+ import torch
11+
12+ from vllm .multimodal .inputs import BatchedTensorInputs
813
914_T = TypeVar ("_T" )
1015_U = TypeVar ("_U" )
1722]
1823"""A nested JSON structure where the leaves need not be JSON-serializable."""
1924
25+ _JSONTree = Union [
26+ dict [str , "JSONTree[_T]" ],
27+ list ["JSONTree[_T]" ],
28+ tuple ["JSONTree[_T]" , ...],
29+ dict [str , _T ],
30+ list [_T ],
31+ tuple [_T , ...],
32+ _T ,
33+ ]
34+ """
35+ Same as `JSONTree` but with additional `Union` members to satisfy overloads.
36+ """
37+
2038
2139def json_iter_leaves (value : JSONTree [_T ]) -> Iterable [_T ]:
2240 """Iterate through each leaf in a nested JSON structure."""
@@ -30,6 +48,14 @@ def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
3048 yield value
3149
3250
51+ @overload
52+ def json_map_leaves (
53+ func : Callable [["torch.Tensor" ], "torch.Tensor" ],
54+ value : "BatchedTensorInputs" ,
55+ ) -> "BatchedTensorInputs" :
56+ ...
57+
58+
3359@overload
3460def json_map_leaves (
3561 func : Callable [[_T ], _U ],
@@ -64,11 +90,14 @@ def json_map_leaves(
6490
6591def json_map_leaves (
6692 func : Callable [[_T ], _U ],
67- value : Union [dict [ str , _T ], list [ _T ], tuple [ _T , ...], JSONTree [_T ]],
68- ) -> Union [dict [ str , _U ], list [ _U ], tuple [ _U , ...], JSONTree [_U ]]:
93+ value : Union ["BatchedTensorInputs" , _JSONTree [_T ]],
94+ ) -> Union ["BatchedTensorInputs" , _JSONTree [_U ]]:
6995 """Apply a function to each leaf in a nested JSON structure."""
7096 if isinstance (value , dict ):
71- return {k : json_map_leaves (func , v ) for k , v in value .items ()}
97+ return {
98+ k : json_map_leaves (func , v ) # type: ignore[arg-type]
99+ for k , v in value .items ()
100+ }
72101 elif isinstance (value , list ):
73102 return [json_map_leaves (func , v ) for v in value ]
74103 elif isinstance (value , tuple ):
@@ -125,7 +154,7 @@ def json_reduce_leaves(
125154
126155def json_reduce_leaves (
127156 func : Callable [..., Union [_T , _U ]],
128- value : Union [ dict [ str , _T ], list [ _T ], tuple [ _T , ...], JSONTree [ _T ] ],
157+ value : _JSONTree [ _T ],
129158 initial : _U = cast (_U , ...), # noqa: B008
130159 / ,
131160) -> Union [_T , _U ]:
0 commit comments