Skip to content

Commit 22415f7

Browse files
committed
Be more careful with conversions for fp8 ops.
Summary: There were some issues when passing keras symbolic inputs to some of the fp8 functions. This diff fixes these issues. Test Plan: CI Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, alfiee Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, alfiee Maniphest Tasks: T65638 Differential Revision: https://phabricator.sourcevertex.net/D78367
1 parent e82030f commit 22415f7

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

tensorflow/python/ipu/ops/f8_ops.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tensorflow.python.ops.math_ops import cast
2727
from tensorflow.python.ops.nn_ops import _get_sequence
2828
import tensorflow as tf
29+
import numpy as np
2930

3031

3132
class Format(IntEnum):
@@ -60,19 +61,18 @@ def __init__(self, data, metadata):
6061
metadata: The metadata for this quarter tensor, should be the output
6162
of a call to `create_metadata`.
6263
"""
63-
self.data = self.maybe_get_tf_variable(data)
64+
self.data = self._maybe_get_tf_variable(data)
6465
self.metadata = metadata
6566

66-
def maybe_get_tf_variable(self, data):
67-
result = data
68-
if not isinstance(data, tf.Variable) and not isinstance(data, tf.Tensor):
69-
result = tf.Variable(data)
70-
if result.dtype != "uint8":
67+
def _maybe_get_tf_variable(self, data):
68+
if isinstance(data, np.ndarray):
69+
data = tf.Variable(data)
70+
if data.dtype != "uint8":
7171
raise TypeError(
7272
"Trying to set/update QuarterTensor data with a tensor of type "
73-
f"{result.dtype}, but only uint8 are supported. Check that data "
73+
f"{data.dtype}, but only uint8 are supported. Check that data "
7474
"is a value returned by `convert_to_f8`")
75-
return result
75+
return data
7676

7777
def numpy(self):
7878
"""Returns a numpy representation of the tensor
@@ -86,7 +86,7 @@ def assign(self, new_values, **kwargs):
8686
new_values: An array of format [data, metadata] that
8787
should be the output of `QuarterTensor.numpy`.
8888
"""
89-
self.data = self.maybe_get_tf_variable(new_values[0])
89+
self.data = self._maybe_get_tf_variable(new_values[0])
9090
self.metadata = new_values[1]
9191

9292
def __eq__(self, other):
@@ -127,7 +127,8 @@ def convert_to_f8(values, metadata, name=None):
127127
if not (values.dtype is dtypes.float16):
128128
values = cast(values, dtypes.float16)
129129

130-
metadata = ops.convert_to_tensor(metadata, dtype=dtypes.uint8)
130+
if isinstance(metadata, np.ndarray) or isinstance(metadata, int):
131+
metadata = ops.convert_to_tensor(metadata, dtype=dtypes.uint8)
131132
v, m = gen_popops_ops.ipu_convert_to_f8(values, metadata, name=name)
132133
return QuarterTensor(v, m)
133134

@@ -146,7 +147,8 @@ def convert_from_f8(packed_input, dtype=dtypes.half, name=None):
146147
Tensor with type dtype with unpacked f8 values.
147148
"""
148149
values, metadata = packed_input
149-
metadata = ops.convert_to_tensor(metadata, dtype=dtypes.uint8)
150+
if isinstance(metadata, np.ndarray) or isinstance(metadata, int):
151+
metadata = ops.convert_to_tensor(metadata, dtype=dtypes.uint8)
150152
values = gen_popops_ops.ipu_convert_from_f8(values, metadata, name=name)
151153
if values.dtype != dtype:
152154
values = cast(values, dtype=dtype)

0 commit comments

Comments
 (0)