@@ -1,3 +1,11 @@
+import contextlib
+
+import mlx.core as mx
+
+from keras.src import tree
+from keras.src.backend.common import stateless_scope
+
+
 def rnn(
     step_function,
     inputs,
@@ -11,7 +19,228 @@ def rnn(
     zero_output_for_mask=False,
     return_all_outputs=True,
 ):
-    raise NotImplementedError("rnn not yet implemented in mlx")
+    def swap_batch_timestep(input_t):
+        # Swap the batch and timestep dim for the incoming tensor.
+        axes = list(range(len(input_t.shape)))
+        axes[0], axes[1] = 1, 0
+        return mx.transpose(input_t, axes)
+
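+    # Keras defaults to batch-major (batch, time, ...) inputs; move the time
+    # axis to the front so the loop/scan below can iterate over axis 0.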
+    if not time_major:
+        inputs = tree.map_structure(swap_batch_timestep, inputs)
+
+    flattened_inputs = tree.flatten(inputs)
+    time_steps = flattened_inputs[0].shape[0]
+
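+    # Normalize the mask: boolean dtype, a trailing broadcast axis, and
+    # time-major layout to match the inputs.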
+    if mask is not None:
+        if mask.dtype != mx.bool_:
+            mask = mask.astype(mx.bool_)
+        if len(mask.shape) == 2:
+            mask = mx.expand_dims(mask, axis=-1)
+        if not time_major:
+            mask = swap_batch_timestep(mask)
+
+    if constants is None:
+        constants = []
+
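+    # Broadcast a per-step mask up to the rank and shape of the tensor it
+    # gates, leaving the first `fixed_dim` (batch) axes untouched.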
+    def _expand_mask(mask_t, input_t, fixed_dim=1):
+        if tree.is_nested(mask_t):
+            raise ValueError(
+                f"mask_t is expected to be a tensor, but got {mask_t}"
+            )
+        if tree.is_nested(input_t):
+            raise ValueError(
+                f"input_t is expected to be a tensor, but got {input_t}"
+            )
+        rank_diff = len(input_t.shape) - len(mask_t.shape)
+        for _ in range(rank_diff):
+            mask_t = mx.expand_dims(mask_t, axis=-1)
+        multiples = [1] * fixed_dim + list(input_t.shape[fixed_dim:])
+        return mx.tile(mask_t, multiples)
+
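+    # Unrolled path: run the step function in an explicit Python loop, one
+    # iteration per timestep.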
+    if unroll:
+        if not time_steps:
+            raise ValueError("Unrolling requires a fixed number of timesteps.")
+        states = tuple(initial_states)
+        successive_states = []
+        successive_outputs = []
+
+        # Process the input tensors. The input tensor needs to be split on
+        # the time_step dim and reversed if go_backwards is True. In the
+        # case of nested input, the input is flattened and then transformed
+        # individually. The result of this will be a tuple of lists; each
+        # item in the tuple is a list of tensors with shape (batch, feature).
+        def _process_single_input_t(input_t):
+            input_t = unstack(input_t)  # unstack for time_step dim
+            if go_backwards:
+                input_t.reverse()
+            return input_t
+
+        if tree.is_nested(inputs):
+            processed_input = tree.map_structure(
+                _process_single_input_t, inputs
+            )
+        else:
+            processed_input = (_process_single_input_t(inputs),)
+
+        def _get_input_tensor(time):
+            inp = [t_[time] for t_ in processed_input]
+            return tree.pack_sequence_as(inputs, inp)
+
+        if mask is not None:
+            mask_list = unstack(mask)
+            if go_backwards:
+                mask_list.reverse()
+
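+            # Step through time, gating each new output and state on the
+            # mask: masked entries keep the values from the previous step.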
+            for i in range(time_steps):
+                inp = _get_input_tensor(i)
+                mask_t = mask_list[i]
+                output, new_states = step_function(
+                    inp, tuple(states) + tuple(constants)
+                )
+                tiled_mask_t = _expand_mask(mask_t, output)
+
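+                # There is no previous output at the first step, so masked
+                # entries fall back to zeros there.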
+                if not successive_outputs:
+                    prev_output = mx.zeros_like(output)
+                else:
+                    prev_output = successive_outputs[-1]
+
+                output = mx.where(tiled_mask_t, output, prev_output)
+
+                flat_states = tree.flatten(states)
+                flat_new_states = tree.flatten(new_states)
+                tiled_mask_t = tuple(
+                    _expand_mask(mask_t, s) for s in flat_states
+                )
+                flat_final_states = tuple(
+                    mx.where(m, s, ps)
+                    for m, s, ps in zip(
+                        tiled_mask_t, flat_new_states, flat_states
+                    )
+                )
+                states = tree.pack_sequence_as(states, flat_final_states)
+
+                if return_all_outputs:
+                    successive_outputs.append(output)
+                    successive_states.append(states)
+                else:
+                    successive_outputs = [output]
+                    successive_states = [states]
+            last_output = successive_outputs[-1]
+            new_states = successive_states[-1]
+            outputs = mx.stack(successive_outputs)
+
+        else:  # mask is None
+            for i in range(time_steps):
+                inp = _get_input_tensor(i)
+                output, states = step_function(
+                    inp, tuple(states) + tuple(constants)
+                )
+                if return_all_outputs:
+                    successive_outputs.append(output)
+                    successive_states.append(states)
+                else:
+                    successive_outputs = [output]
+                    successive_states = [states]
+            last_output = successive_outputs[-1]
+            new_states = successive_states[-1]
+            outputs = mx.stack(successive_outputs)
+
+    else:  # Unroll == False
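+        # Non-unrolled path: express a single timestep as a step closure and
+        # fold it over the time axis with mlx_scan (defined below).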
+        if mask is not None:
+
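+            # One scan step: run step_function, then keep the new output and
+            # states only for rows whose mask is True, carrying the previous
+            # values elsewhere.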
+            def _step(states, current_input):
+                current_input, current_mask = current_input
+                is_masked = mx.all(
+                    mx.logical_not(current_mask), axis=-1, keepdims=True
+                )
+
+                output_t, new_states = step_function(current_input, states)
+
+                if zero_output_for_mask:
+                    masked_outs = mx.where(
+                        is_masked, mx.zeros_like(output_t), output_t
+                    )
+                else:
+                    # Assume the first state is the previous output.
+                    output_tm1 = states[0]
+                    masked_outs = mx.where(is_masked, output_tm1, output_t)
+
+                new_states = [
+                    mx.where(is_masked, s, ns)
+                    for s, ns in zip(states, new_states)
+                ]
+                return (new_states, masked_outs)
+
+            scan_xs = (inputs, mask)
+
+        else:
+
+            def _step(states, current_input):
+                output_t, new_states = step_function(current_input, states)
+                return new_states, output_t
+
+            scan_xs = inputs
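+
+        # Run the scan inside a StatelessScope (or reuse the parent one) so
+        # that variable updates made by step_function are captured by the
+        # scope rather than applied to the variables directly.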
+        if stateless_scope.in_stateless_scope():
+            # Reuse the existing parent stateless scope.
+            scope = contextlib.nullcontext()
+        else:
+            scope = stateless_scope.StatelessScope()
+        with scope:
+            new_states, outputs = mlx_scan(
+                f=_step,
+                init=initial_states,
+                xs=scan_xs,
+                reverse=go_backwards,
+                mask=mask,
+            )
+
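+        # mlx_scan returns outputs in forward time order; flip them so that
+        # outputs[-1] is the last step actually processed when going
+        # backwards.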
+        if go_backwards:
+            outputs = reverse_sequence(outputs)
+
+        last_output = outputs[-1]
+
+    if not time_major:
+        outputs = tree.map_structure(swap_batch_timestep, outputs)
+
+    return last_output, outputs, new_states
+
+
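+# Flip a tensor along its leading (time) axis.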
+def reverse_sequence(xs):
+    indices = mx.arange(xs.shape[0] - 1, -1, -1)
+    return mx.take(xs, indices, axis=0)
+
+
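+# Split a tensor into a Python list of per-index slices along `axis`.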
+def unstack(x, axis=0):
+    return [mx.take(x, i, axis=axis) for i in range(x.shape[axis])]
+
+
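+# Minimal scan: fold `f` over the leading axis of `xs`, collecting one
+# output per step. With reverse=True the steps are consumed last-to-first,
+# but the stacked outputs are flipped back to original time order on return.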
+def mlx_scan(f, init, xs, reverse=False, mask=None):
+    states = init
+    outputs = []
+
+    if mask is not None:
+        x, mask = xs
+        if reverse:
+            x = reverse_sequence(x)
+            mask = reverse_sequence(mask)
+
+        for each_x, each_mask in zip(x, mask):
+            states, output = f(states, (each_x, each_mask))
+            outputs.append(output)
+    else:
+        if reverse:
+            xs = reverse_sequence(xs)
+
+        for x in xs:
+            states, output = f(states, x)
+            outputs.append(output)
+
+    outputs = mx.array(outputs)
+
+    if reverse:
+        outputs = reverse_sequence(outputs)
+
+    return states, outputs
 
 
 def cudnn_ok(*args, **kwargs):