@@ -308,3 +308,57 @@ def stop_gradient(variable):
def unstack(x, num=None, axis=0):
    """Split ``x`` along ``axis`` and squeeze that axis out of every piece.

    Args:
        x: Array-like object exposing ``shape``, ``split`` and, on the
            pieces returned by ``split``, ``squeeze``.
        num: Value forwarded as the first argument of ``x.split``. When it
            is falsy (``None`` or ``0``), ``x.shape[axis]`` is used instead.
        axis: Axis to split along; the same axis is squeezed from each piece.

    Returns:
        A list with one squeezed piece per element returned by ``x.split``.
    """
    count = num
    if not count:
        # No explicit count given: fall back to the size of the split axis.
        count = x.shape[axis]
    return [piece.squeeze(axis) for piece in x.split(count, axis=axis)]
311+
312+
def reverse_sequence(xs):
    """Return ``xs`` with its entries along axis 0 in reverse order.

    Args:
        xs: Array whose first axis is to be reversed.

    Returns:
        The result of gathering ``xs`` along axis 0 with descending indices.
    """
    last = xs.shape[0] - 1
    # Descending index vector: last, last - 1, ..., 0.
    descending = mx.arange(last, -1, -1)
    return mx.take(xs, descending, axis=0)
316+
317+
def scan(f, init, xs, reverse=False, mask=None):
    """Run ``f`` step by step over ``xs`` while threading a state through.

    Args:
        f: Step function called as ``f(state, step_input)``. If it returns a
            tuple, that tuple is unpacked as ``(new_state, output)``; any
            other return value is taken as the new state.
        init: Initial state handed to the first call of ``f``.
        xs: The sequence to iterate over. A tuple is iterated in lockstep
            (``zip(*xs)``); anything else is iterated directly. When
            ``mask`` is not ``None``, ``xs`` is unpacked as an
            ``(inputs, mask)`` pair instead.
        reverse: If true, the inputs are reversed along axis 0 before the
            loop and the stacked outputs are reversed again afterwards.
        mask: Only checked against ``None`` to select the masked branch; the
            mask that is iterated comes from ``xs``.
            NOTE(review): the passed value itself is never read - confirm
            with callers that this is the intended contract.

    Returns:
        ``(final_state, outputs)``. ``outputs`` is ``None`` when no step
        produced a non-``None`` output; otherwise it is the per-step outputs
        stacked with ``mx.stack`` (a tuple of stacked arrays when each step
        output is itself a tuple).
    """
    # Build the per-step iterator for the three supported input layouts.
    if mask is not None:
        inputs, step_mask = xs
        if reverse:
            inputs = reverse_sequence(inputs)
            step_mask = reverse_sequence(step_mask)
        steps = zip(inputs, step_mask)
    elif isinstance(xs, tuple):
        if reverse:
            xs = tuple(reverse_sequence(seq) for seq in xs)
        steps = zip(*xs)
    else:
        steps = reverse_sequence(xs) if reverse else xs

    carry = init
    collected = []
    for step in steps:
        step_result = f(carry, step)
        if not isinstance(step_result, tuple):
            # Bare return value: it is the new state, nothing is emitted.
            carry = step_result
            continue
        carry, emitted = step_result
        if emitted is not None:
            collected.append(emitted)

    if not collected:
        return carry, None

    if isinstance(collected[0], tuple):
        # Several outputs per step: stack each output position separately.
        stacked = tuple(
            mx.stack([item[i] for item in collected])
            for i in range(len(collected[0]))
        )
        if reverse:
            stacked = tuple(reverse_sequence(part) for part in stacked)
    else:
        # One output per step.
        stacked = mx.stack(collected)
        if reverse:
            # Undo the input reversal so outputs line up with the original order.
            stacked = reverse_sequence(stacked)

    return carry, stacked
0 commit comments