Skip to content

Commit 70acefe

Browse files
[CVS-172796] fix bfloat16 conversion when single cast node to bfloat16 (#841)
* Disable bfloat16 conversion when there is a single Cast node to bfloat16; add a unit test case * Insert a Cast(To:BFloat16) before each bfloat16 output node so users keep the original bf16 output tensors * Revert the changes that added the Cast node; add a statement to disable the bfloat16 transform for the OV CPU device * Remove the silent bfloat16 conversion * Remove bf16 testing and CPU support for OpenVINO --------- Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com>
1 parent fa68db1 commit 70acefe

File tree

5 files changed

+1
-202
lines changed

5 files changed

+1
-202
lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -389,18 +389,6 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) {
389389
return false;
390390
}
391391

392-
static bool IsModelBF16(const onnxruntime::GraphViewer& graph_viewer) {
393-
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
394-
for (std::size_t i = 0; i < node_indices.size(); i++) {
395-
gsl::not_null<const onnxruntime::Node*> node(graph_viewer.GetNode(node_indices[i]));
396-
for (auto& output : node->OutputDefs()) {
397-
if (output->ToProto().type().tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
398-
return true;
399-
}
400-
}
401-
return false;
402-
}
403-
404392
static bool Is16BitTensor(const onnxruntime::NodeArg* node_arg) {
405393
const auto* type_proto = node_arg ? node_arg->TypeAsProto() : nullptr;
406394
return type_proto && type_proto->has_tensor_type() &&
@@ -598,16 +586,6 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node,
598586
DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node);
599587
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
600588
return model_proto;
601-
} else if (IsModelBF16(subgraph)) {
602-
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP bfloat16->float16 optimization pass is enabled";
603-
std::unique_ptr<onnxruntime::Model> model;
604-
Status status = bfloat16_fix::Transform(subgraph, logger, model);
605-
auto model_proto = model->ToProto();
606-
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
607-
print_model_proto_duration();
608-
DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node);
609-
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
610-
return model_proto;
611589
} else {
612590
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP QDQ optimization pass is disabled";
613591

onnxruntime/core/providers/openvino/ov_versions/data_ops.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -561,9 +561,7 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) {
561561
}
562562

563563
auto dtype = type_proto->tensor_type().elem_type();
564-
// Enable bfloat16 -> float16 on-the-fly conversion
565-
if (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16 ||
566-
dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16 ||
564+
if (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16 ||
567565
dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16)
568566
return true;
569567
if (is_initializer) {

onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include "qdq_scales_fix.h"
55
#include "core/providers/openvino/ov_protobuf_utils.h"
66
#include "core/framework/ort_value.h"
7-
#include "core/common/float16.h"
87

98
#include <fstream>
109
#include <list>
@@ -955,60 +954,5 @@ Status Transform(const GraphViewer& src_graph_viewer,
955954
return status;
956955
}
957956
} // namespace qdq_scales_fix
958-
959-
namespace bfloat16_fix {
960-
void replace_bf16_with_fp16(qdq_scales_fix::CustomGraph& gen_graph) {
961-
for (auto& const_node : gen_graph.original_graph.Nodes()) {
962-
auto node = const_cast<ONNX_NAMESPACE::Node*>(const_node);
963-
if (node->OpType() == "Cast") {
964-
for (auto& [name, const_attribute] : node->GetAttributes()) {
965-
auto& attribute = const_cast<ONNX_NAMESPACE::AttributeProto&>(const_attribute);
966-
if (name == "to" && attribute.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT)
967-
if (attribute.i() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
968-
attribute.set_i(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
969-
}
970-
}
971-
for (auto& output : node->OutputDefs()) {
972-
auto& output_proto = const_cast<ONNX_NAMESPACE::TypeProto&>(output->ToProto().type());
973-
if (output_proto.mutable_tensor_type()->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
974-
output_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
975-
}
976-
}
977-
978-
for (auto& node : gen_graph.original_graph.Nodes()) {
979-
for (auto& input_def : node->InputDefs()) {
980-
ORT_THROW_IF_ERROR(graph_utils::ConvertInMemoryDataToInline(gen_graph.original_graph, input_def->Name()));
981-
}
982-
}
983-
984-
const auto& init_set = gen_graph.original_graph.GetAllInitializedTensors();
985-
for (auto& [key, const_tensor_proto] : init_set) {
986-
auto tensor_proto = const_cast<ONNX_NAMESPACE::TensorProto*>(const_tensor_proto);
987-
auto dt = tensor_proto->data_type();
988-
if (dt == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
989-
auto raw_data = tensor_proto->has_raw_data() ? reinterpret_cast<std::uint16_t*>(tensor_proto->mutable_raw_data()->data()) : nullptr;
990-
if (raw_data) {
991-
tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
992-
std::int64_t size = 1;
993-
for (int i = 0; i < tensor_proto->dims_size(); ++i)
994-
size *= tensor_proto->dims()[i];
995-
for (std::int64_t i = 0; i < size; ++i) {
996-
raw_data[i] = onnxruntime::MLFloat16(onnxruntime::BFloat16::FromBits(raw_data[i])).val;
997-
}
998-
}
999-
}
1000-
}
1001-
}
1002-
1003-
Status Transform(const GraphViewer& src_graph_viewer,
1004-
const logging::Logger& logger,
1005-
/*out*/ std::unique_ptr<onnxruntime::Model>& model) {
1006-
auto status = qdq_scales_fix::copy_model(src_graph_viewer, logger, model);
1007-
auto g = qdq_scales_fix::generate_graph_from_onnx(model->MainGraph());
1008-
1009-
replace_bf16_with_fp16(g);
1010-
return status;
1011-
}
1012-
} // namespace bfloat16_fix
1013957
} // namespace openvino_ep
1014958
} // namespace onnxruntime

onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,5 @@ Status Transform(const GraphViewer& src_graph,
1515
const logging::Logger& logger,
1616
/*out*/ std::unique_ptr<onnxruntime::Model>& model);
1717
}
18-
namespace bfloat16_fix {
19-
Status Transform(const GraphViewer& src_graph,
20-
const logging::Logger& logger,
21-
/*out*/ std::unique_ptr<onnxruntime::Model>& model);
22-
}
2318
} // namespace openvino_ep
2419
} // namespace onnxruntime

onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc

Lines changed: 0 additions & 116 deletions
This file was deleted.

0 commit comments

Comments
 (0)