@@ -998,6 +998,49 @@ def array(x, dtype=None):
998998 return convert_to_tensor (x , dtype = dtype )
999999
10001000
1001+ def view (x , dtype = None ):
1002+ from keras .src import backend
1003+
1004+ x = convert_to_tensor (x )
1005+ old_dtype = tf .as_dtype (backend .standardize_dtype (x .dtype ))
1006+ new_dtype = tf .as_dtype (
1007+ backend .standardize_dtype (dtype if dtype else x .dtype )
1008+ )
1009+
1010+ old_itemsize = old_dtype .size
1011+ new_itemsize = new_dtype .size
1012+
1013+ if list (x .shape )[- 1 ] * old_itemsize % new_itemsize != 0 :
1014+ raise ValueError (
1015+ f"Cannot view array of shape { x .shape } and dtype { old_dtype } "
1016+ f"as dtype { new_dtype } because the total number of bytes "
1017+ f"is not divisible by the new itemsize."
1018+ )
1019+
1020+ if old_itemsize == new_itemsize :
1021+ return tf .bitcast (x , type = new_dtype )
1022+ elif old_itemsize > new_itemsize :
1023+ ratio = old_itemsize // new_itemsize
1024+ new_shape = list (shape_op (x ))
1025+ new_shape [- 1 ] *= ratio
1026+ flat_tensor = tf .reshape (x , [- 1 ])
1027+ cast_tensor = tf .bitcast (flat_tensor , type = new_dtype )
1028+ return tf .reshape (cast_tensor , new_shape )
1029+ else :
1030+ old_shape = list (shape_op (x ))
1031+ last_dim_size = old_shape [- 1 ]
1032+ ratio = new_itemsize // old_itemsize
1033+ if isinstance (last_dim_size , int ) and last_dim_size % ratio != 0 :
1034+ raise ValueError (
1035+ f"Cannot view dtype. Last dimension size ({ last_dim_size } ) "
1036+ f"must be divisible by the ratio of new/old item sizes "
1037+ f"({ ratio } )."
1038+ )
1039+ intermediate_shape = old_shape [:- 1 ] + [last_dim_size // ratio , ratio ]
1040+ reshaped_tensor = tf .reshape (x , intermediate_shape )
1041+ return tf .bitcast (reshaped_tensor , new_dtype )
1042+
1043+
10011044def average (x , axis = None , weights = None ):
10021045 x = convert_to_tensor (x )
10031046
@@ -2258,7 +2301,7 @@ def _get_indices(method):
22582301 return gathered_y
22592302 perm = collections .deque (range (ndims ))
22602303 perm .rotate (shift_value_static )
2261- return tf .transpose (a = gathered_y , perm = perm )
2304+ return tf .transpose (a = gathered_y , perm = list ( perm ) )
22622305
22632306
22642307def quantile (x , q , axis = None , method = "linear" , keepdims = False ):
0 commit comments