Skip to content

Commit bf1c139

Browse files
authored
[Optimizer] Fix reinterpretation of strings in _get_numpy_value (#2514)
Signed-off-by: Christoph Berganski <christoph.berganski@gmail.com>
1 parent 6bf856e commit bf1c139

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,9 +278,18 @@ def _get_numpy_value(
278278
if size_limit is not None and const_value.size > size_limit:
279279
return None
280280
try:
281-
# Reinterpret the array with `.view()` because some implementations of
282-
# ir.TensorProtocol (e.g. PyTorch<=2.7) do not use ml_dtypes for bfloat16 etc.
283-
array = const_value.numpy().view(const_value.dtype.numpy())
281+
# Turn the constant value into a numpy array representation with the
282+
# specifics of this conversion handled by the tensor type
283+
array = const_value.numpy()
284+
# Can/should not reinterpret strings via .view, resulting in
285+
# "TypeError: Cannot change data-type for array of references."
286+
# There is also no reason to reinterpret strings, this is only
287+
# relevant for some arithmetic types
288+
if const_value.dtype != ir.DataType.STRING:
289+
# Reinterpret the array with `.view()` because some
290+
# implementations of ir.TensorProtocol (e.g. PyTorch<=2.7) do
291+
# not use ml_dtypes for bfloat16 etc.
292+
array = array.view(const_value.dtype.numpy())
284293
except FileNotFoundError:
285294
# External data is not available.
286295
logger.warning(

0 commit comments

Comments
 (0)