@@ -1379,7 +1379,28 @@ def histogram(x, bins, range):
13791379
13801380
13811381def unravel_index (x , shape ):
1382- raise NotImplementedError ("unravel_index not yet implemented in mlx." )
1382+ x = convert_to_tensor (x )
1383+ input_dtype = x .dtype
1384+
1385+ if None in shape :
1386+ raise ValueError (
1387+ "`shape` argument cannot contain `None`. Received: shape={shape}"
1388+ )
1389+
1390+ if x .ndim == 1 :
1391+ coords = []
1392+ for dim in reversed (shape ):
1393+ coords .append ((x % dim ).astype (input_dtype ))
1394+ x = x // dim
1395+ return tuple (reversed (coords ))
1396+
1397+ x_shape = x .shape
1398+ coords = []
1399+ for dim in shape :
1400+ coords .append (mx .reshape ((x % dim ).astype (input_dtype ), x_shape ))
1401+ x = x // dim
1402+
1403+ return tuple (reversed (coords ))
13831404
13841405
13851406def searchsorted (sorted_sequence , values , side = "left" ):
@@ -1391,4 +1412,46 @@ def diagflat(x, k=0):
13911412
13921413
13931414def rot90 (array , k = 1 , axes = (0 , 1 )):
1394- raise NotImplementedError ("rot90 not yet implemented in mlx." )
1415+ array = convert_to_tensor (array )
1416+
1417+ if array .ndim < 2 :
1418+ raise ValueError (
1419+ f"Input array must have at least 2 dimensions. "
1420+ f"Received: array.ndim={ array .ndim } "
1421+ )
1422+ if len (axes ) != 2 or axes [0 ] == axes [1 ]:
1423+ raise ValueError (
1424+ f"Invalid axes: { axes } . Axes must be a tuple of "
1425+ "two different dimensions."
1426+ )
1427+
1428+ array_axes = list (range (array .ndim ))
1429+ # Swap axes
1430+ array_axes [axes [0 ]], array_axes [axes [1 ]] = (
1431+ array_axes [axes [1 ]],
1432+ array_axes [axes [0 ]],
1433+ )
1434+
1435+ if k < 0 :
1436+ axes = (axes [1 ], axes [0 ])
1437+ k *= - 1
1438+
1439+ k = k % 4
1440+
1441+ if k > 0 :
1442+ slices = [builtins .slice (None ) for _ in range (array .ndim )]
1443+ if k == 2 :
1444+ # 180 deg rotation => reverse elements along both axes
1445+ slices [axes [0 ]] = builtins .slice (None , None , - 1 )
1446+ slices [axes [1 ]] = builtins .slice (None , None , - 1 )
1447+ else :
1448+ # 90 or 270 deg rotation => transpose and reverse along one axis
1449+ array = mx .transpose (array , axes = array_axes )
1450+ if k == 1 :
1451+ slices [axes [0 ]] = builtins .slice (None , None , - 1 )
1452+ else :
1453+ slices [axes [1 ]] = builtins .slice (None , None , - 1 )
1454+
1455+ array = array [tuple (slices )]
1456+
1457+ return array
0 commit comments