Skip to content

Commit 80f4d4e

Browse files
authored
Merge branch 'keras-team:master' into master
2 parents 1fa75a0 + 960133e commit 80f4d4e

File tree

15 files changed

+266
-13
lines changed

15 files changed

+266
-13
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@
296296
from keras.src.ops.numpy import var as var
297297
from keras.src.ops.numpy import vdot as vdot
298298
from keras.src.ops.numpy import vectorize as vectorize
299+
from keras.src.ops.numpy import view as view
299300
from keras.src.ops.numpy import vstack as vstack
300301
from keras.src.ops.numpy import where as where
301302
from keras.src.ops.numpy import zeros as zeros

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@
182182
from keras.src.ops.numpy import var as var
183183
from keras.src.ops.numpy import vdot as vdot
184184
from keras.src.ops.numpy import vectorize as vectorize
185+
from keras.src.ops.numpy import view as view
185186
from keras.src.ops.numpy import vstack as vstack
186187
from keras.src.ops.numpy import where as where
187188
from keras.src.ops.numpy import zeros as zeros

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@
296296
from keras.src.ops.numpy import var as var
297297
from keras.src.ops.numpy import vdot as vdot
298298
from keras.src.ops.numpy import vectorize as vectorize
299+
from keras.src.ops.numpy import view as view
299300
from keras.src.ops.numpy import vstack as vstack
300301
from keras.src.ops.numpy import where as where
301302
from keras.src.ops.numpy import zeros as zeros

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@
182182
from keras.src.ops.numpy import var as var
183183
from keras.src.ops.numpy import vdot as vdot
184184
from keras.src.ops.numpy import vectorize as vectorize
185+
from keras.src.ops.numpy import view as view
185186
from keras.src.ops.numpy import vstack as vstack
186187
from keras.src.ops.numpy import where as where
187188
from keras.src.ops.numpy import zeros as zeros

keras/src/backend/common/variables.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def assign(self, value):
282282
"The shape of the target variable and "
283283
"the shape of the target value in "
284284
"`variable.assign(value)` must match. "
285-
f"variable.shape={self.value.shape}, "
285+
f"variable.shape={self.shape}, "
286286
f"Received: value.shape={value.shape}. "
287287
f"Target variable: {self}"
288288
)
@@ -399,7 +399,11 @@ def constraint(self, value):
399399
def __repr__(self):
400400
value = None
401401
if hasattr(self, "_value") and self._value is not None:
402-
value = backend.core.convert_to_numpy(self._value)
402+
try:
403+
value = backend.core.convert_to_numpy(self._value)
404+
except:
405+
# In some cases the conversion to numpy can fail.
406+
pass
403407
value_str = f", value={value}" if value is not None else ""
404408
return (
405409
f"<Variable path={self.path}, shape={self.shape}, "

keras/src/backend/jax/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,11 @@ def array(x, dtype=None):
446446
return jnp.array(x, dtype=dtype)
447447

448448

449+
def view(x, dtype=None):
450+
x = convert_to_tensor(x)
451+
return x.view(dtype=dtype)
452+
453+
449454
def average(x, axis=None, weights=None):
450455
x = convert_to_tensor(x)
451456
dtypes_to_resolve = [x.dtype, float]

keras/src/backend/numpy/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,11 @@ def array(x, dtype=None):
294294
return convert_to_tensor(x, dtype=dtype)
295295

296296

297+
def view(x, dtype=None):
298+
x = convert_to_tensor(x)
299+
return x.view(dtype=dtype)
300+
301+
297302
def average(x, axis=None, weights=None):
298303
axis = standardize_axis_for_numpy(axis)
299304
x = convert_to_tensor(x)

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ NumpyDtypeTest::test_trunc
5151
NumpyDtypeTest::test_unravel
5252
NumpyDtypeTest::test_var
5353
NumpyDtypeTest::test_vdot
54+
NumpyDtypeTest::test_view
5455
NumpyDtypeTest::test_vstack
5556
HistogramTest
5657
NumpyOneInputOpsCorrectnessTest::test_angle
@@ -102,6 +103,7 @@ NumpyOneInputOpsCorrectnessTest::test_unravel_index
102103
NumpyOneInputOpsCorrectnessTest::test_var
103104
NumpyOneInputOpsCorrectnessTest::test_vectorize
104105
NumpyOneInputOpsCorrectnessTest::test_vstack
106+
NumpyOneInputOpsCorrectnessTest::test_view
105107
NumpyTwoInputOpsCorrectnessTest::test_bitwise_and
106108
NumpyTwoInputOpsCorrectnessTest::test_bitwise_left_shift
107109
NumpyTwoInputOpsCorrectnessTest::test_bitwise_or
@@ -131,10 +133,12 @@ NumpyOneInputOpsDynamicShapeTest::test_hanning
131133
NumpyOneInputOpsDynamicShapeTest::test_isposinf
132134
NumpyOneInputOpsDynamicShapeTest::test_isreal
133135
NumpyOneInputOpsDynamicShapeTest::test_kaiser
136+
NumpyOneInputOpsDynamicShapeTest::test_view
134137
NumpyOneInputOpsStaticShapeTest::test_angle
135138
NumpyOneInputOpsStaticShapeTest::test_cbrt
136139
NumpyOneInputOpsStaticShapeTest::test_isposinf
137140
NumpyOneInputOpsStaticShapeTest::test_isreal
141+
NumpyOneInputOpsStaticShapeTest::test_view
138142
NumpyTwoInputOpsDynamicShapeTest::test_gcd
139143
NumpyTwoInputOpsDynamicShapeTest::test_heaviside
140144
NumpyTwoInputOpsDynamicShapeTest::test_hypot

keras/src/backend/openvino/numpy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,10 @@ def array(x, dtype=None):
508508
return np.array(x)
509509

510510

511+
def view(x, dtype=None):
512+
raise NotImplementedError("`view` is not supported with openvino backend")
513+
514+
511515
def average(x, axis=None, weights=None):
512516
x = get_ov_output(x)
513517
if weights is not None:

keras/src/backend/tensorflow/numpy.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,49 @@ def array(x, dtype=None):
998998
return convert_to_tensor(x, dtype=dtype)
999999

10001000

1001+
def view(x, dtype=None):
1002+
from keras.src import backend
1003+
1004+
x = convert_to_tensor(x)
1005+
old_dtype = tf.as_dtype(backend.standardize_dtype(x.dtype))
1006+
new_dtype = tf.as_dtype(
1007+
backend.standardize_dtype(dtype if dtype else x.dtype)
1008+
)
1009+
1010+
old_itemsize = old_dtype.size
1011+
new_itemsize = new_dtype.size
1012+
1013+
if list(x.shape)[-1] * old_itemsize % new_itemsize != 0:
1014+
raise ValueError(
1015+
f"Cannot view array of shape {x.shape} and dtype {old_dtype} "
1016+
f"as dtype {new_dtype} because the total number of bytes "
1017+
f"is not divisible by the new itemsize."
1018+
)
1019+
1020+
if old_itemsize == new_itemsize:
1021+
return tf.bitcast(x, type=new_dtype)
1022+
elif old_itemsize > new_itemsize:
1023+
ratio = old_itemsize // new_itemsize
1024+
new_shape = list(shape_op(x))
1025+
new_shape[-1] *= ratio
1026+
flat_tensor = tf.reshape(x, [-1])
1027+
cast_tensor = tf.bitcast(flat_tensor, type=new_dtype)
1028+
return tf.reshape(cast_tensor, new_shape)
1029+
else:
1030+
old_shape = list(shape_op(x))
1031+
last_dim_size = old_shape[-1]
1032+
ratio = new_itemsize // old_itemsize
1033+
if isinstance(last_dim_size, int) and last_dim_size % ratio != 0:
1034+
raise ValueError(
1035+
f"Cannot view dtype. Last dimension size ({last_dim_size}) "
1036+
f"must be divisible by the ratio of new/old item sizes "
1037+
f"({ratio})."
1038+
)
1039+
intermediate_shape = old_shape[:-1] + [last_dim_size // ratio, ratio]
1040+
reshaped_tensor = tf.reshape(x, intermediate_shape)
1041+
return tf.bitcast(reshaped_tensor, new_dtype)
1042+
1043+
10011044
def average(x, axis=None, weights=None):
10021045
x = convert_to_tensor(x)
10031046

@@ -2258,7 +2301,7 @@ def _get_indices(method):
22582301
return gathered_y
22592302
perm = collections.deque(range(ndims))
22602303
perm.rotate(shift_value_static)
2261-
return tf.transpose(a=gathered_y, perm=perm)
2304+
return tf.transpose(a=gathered_y, perm=list(perm))
22622305

22632306

22642307
def quantile(x, q, axis=None, method="linear", keepdims=False):

0 commit comments

Comments
 (0)