 import builtins
+import math
 from copy import copy as builtin_copy

 import mlx.core as mx
@@ -950,7 +951,6 @@ def quantile(x, q, axis=None, method="linear", keepdims=False):
     else:
         dtype = dtypes.result_type(x.dtype, float)
     mlx_dtype = to_mlx_dtype(dtype)
-    print("mlx_dtype", mlx_dtype)

     # problem casting mlx bfloat16 array to numpy
     if ori_dtype == "bfloat16":
@@ -1374,8 +1374,43 @@ def wrapped(*args):
     return wrapped


-def histogram(x, bins, range):
-    raise NotImplementedError("histogram not yet implemented in mlx.")
+def histogram_bin_edges(a, bins=10, range=None):
+    # Ref: jax.numpy.histogram
+    # infer range if None
+    if range is None:
+        range = (mx.min(a).item(), mx.max(a).item())
+
+    if range[0] == range[1]:
+        range = (range[0] - 0.5, range[1] + 0.5)
+
+    bin_edges = mx.linspace(range[0], range[1], bins + 1, dtype=mx.float32)
+    # due to the way mlx currently handles linspace
+    # with fp32 precision it is not always right edge inclusive
+    # manually set the right edge for now
+    bin_edges[-1] = range[-1]
+    return bin_edges
+
+
+def histogram(x, bins=10, range=None):
+    # Ref: jax.numpy.histogram
+    x = convert_to_tensor(x)
+    if range is not None:
+        if not isinstance(range, tuple) or len(range) != 2:
+            raise ValueError(
+                "Invalid value for argument `range`. Only `None` or "
+                "a tuple of the lower and upper range of bins is supported. "
+                f"Received: range={range}"
+            )
+
+    bin_edges = histogram_bin_edges(x, bins, range)
+
+    bin_idx = searchsorted(bin_edges, x, side="right")
+    bin_idx = mx.where(x == bin_edges[-1], len(bin_edges) - 1, bin_idx)
+
+    counts = mx.zeros(len(bin_edges))
+    counts = counts.at[bin_idx].add(mx.ones_like(x))
+
+    return counts[1:], bin_edges


 def unravel_index(x, shape):
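Not part of the commit: a minimal standalone sketch of the bin-assignment trick the new `histogram` uses (a right-sided search against the edges, with the leading slot dropped by `counts[1:]`), cross-checked against `numpy.histogram`. The sample values are arbitrary and the sketch assumes `mlx` and `numpy` are installed.

```python
import mlx.core as mx
import numpy as np

# Five bins over (0, 5): six edges, with each value placed by counting how
# many edges are <= that value (equivalent to a side="right" search).
x = mx.array([0.5, 2.0, 2.0, 4.5, 5.0])
edges = mx.linspace(0.0, 5.0, 6)
edges[-1] = 5.0  # force an inclusive right edge, as histogram_bin_edges does

idx = (edges[None, :] <= x[:, None]).sum(axis=1)
# values equal to the last edge belong to the last bin, not one past it
idx = mx.where(x == edges[-1], len(edges) - 1, idx)

counts = mx.zeros(len(edges))
counts = counts.at[idx].add(mx.ones_like(x))

print(counts[1:])  # expected: [1, 0, 2, 0, 2]
print(np.histogram([0.5, 2.0, 2.0, 4.5, 5.0], bins=5, range=(0.0, 5.0))[0])
```

Anything below the first edge lands in slot 0, which is exactly what the `counts[1:]` offset discards.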
@@ -1384,7 +1419,7 @@ def unravel_index(x, shape):

     if None in shape:
         raise ValueError(
-            "`shape` argument cannot contain `None`. Received: shape={shape}"
+            f"`shape` argument cannot contain `None`. Received: shape={shape}"
         )

     if x.ndim == 1:
@@ -1403,8 +1438,73 @@ def unravel_index(x, shape):
     return tuple(reversed(coords))


+def searchsorted_binary(a, b, side="left"):
+    original_shape = b.shape
+    b_flat = b.reshape(-1)
+
+    size = a.shape[0]
+    steps = math.ceil(math.log2(size))
+    indices = mx.full(b_flat.shape, vals=size // 2, dtype=mx.uint32)
+
+    comparison = (lambda x, y: x <= y) if side == "left" else (lambda x, y: x < y)
+
+    upper = size
+    lower = 0
+    for _ in range(steps):
+        comp = comparison(b_flat, a[indices])
+        new_indices = mx.where(
+            comp, (lower + indices) // 2, (indices + upper) // 2
+        )
+        lower = mx.where(comp, lower, indices)
+        upper = mx.where(comp, indices, upper)
+        indices = new_indices
+
+    result = mx.where(comparison(b_flat, a[indices]), indices, indices + 1)
+    return result.reshape(original_shape)
+
+
+def searchsorted_linear(a, b, side="left"):
+    original_shape = b.shape
+    b_flat = b.reshape(-1)
+    b_flat_broadcast = b_flat.reshape(-1, 1)
+    if side == "left":
+        result = (a[None, :] < b_flat_broadcast).sum(axis=1)
+    else:
+        result = (a[None, :] <= b_flat_broadcast).sum(axis=1)
+
+    return result.reshape(original_shape)
+
+
 def searchsorted(sorted_sequence, values, side="left"):
-    raise NotImplementedError("searchsorted not yet implemented in mlx.")
+    if side not in ("left", "right"):
+        raise ValueError(f"Invalid side `{side}`, must be `left` or `right`.")
+    sorted_sequence = convert_to_tensor(sorted_sequence)
+    values = convert_to_tensor(values)
+    if sorted_sequence.ndim != 1:
+        raise ValueError(
+            "Invalid sorted_sequence, should be 1-dimensional. "
+            f"Received sorted_sequence.shape={sorted_sequence.shape}"
+        )
+    if values.ndim == 0:
+        raise ValueError(
+            "Invalid values, should be N-dimensional. Received "
+            f"scalar array values.shape={values.shape}"
+        )
+
+    sorted_size = sorted_sequence.size
+    search_size = values.size
+
+    # TODO: swap to an mlx implementation if one exists in the future
+    # current implementation and search choice based on the discussion in:
+    # https://github.com/ml-explore/mlx/issues/1255
+    use_linear = sorted_size <= 1024 or (
+        sorted_size <= 16384 and search_size <= 256
+    )
+
+    if use_linear:
+        return searchsorted_linear(sorted_sequence, values, side=side)
+    else:
+        return searchsorted_binary(sorted_sequence, values, side=side)


 def diagflat(x, k=0):
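For context rather than as part of the change, here is a standalone sketch of the comparison-count form of searchsorted that `searchsorted_linear` relies on, checked against `np.searchsorted`; the arrays are arbitrary examples.

```python
import mlx.core as mx
import numpy as np

# Insertion points by counting, per query, how many sorted entries
# compare below it:
#   side="left"  -> count of entries strictly less than the query
#   side="right" -> count of entries less than or equal to the query
a = mx.array([1.0, 3.0, 3.0, 7.0, 9.0])
v = mx.array([0.0, 3.0, 8.0, 10.0])

left = (a[None, :] < v[:, None]).sum(axis=1)
right = (a[None, :] <= v[:, None]).sum(axis=1)

print(left)   # expected: [0, 1, 4, 5]
print(right)  # expected: [0, 3, 4, 5]
print(np.searchsorted([1, 3, 3, 7, 9], [0, 3, 8, 10], side="left"))
print(np.searchsorted([1, 3, 3, 7, 9], [0, 3, 8, 10], side="right"))
```

The linear form materializes a `(values.size, sorted_size)` comparison matrix, which is why `searchsorted` switches to the vectorized binary search above once the sizes pass the thresholds taken from the linked mlx discussion.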