@@ -26,6 +26,8 @@ class ActivationOpBuilder : public BaseOpBuilder {
2626 const logging::Logger& logger) const override ;
2727
2828 int GetMinSupportedOpSet (const Node& node) const override ;
29+
30+ bool SupportsMLProgram () const override { return true ; }
2931};
3032
3133void ActivationOpBuilder::AddInitializersToSkip (ModelBuilder& model_builder, const Node& node) const {
@@ -74,33 +76,61 @@ Status AddPReluWeight(ModelBuilder& model_builder, const Node& node,
7476Status ActivationOpBuilder::AddToModelBuilderImpl (ModelBuilder& model_builder,
7577 const Node& node,
7678 const logging::Logger& logger) const {
77- std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer (node);
78-
7979 const auto & op_type (node.OpType ());
80- if (op_type == " Sigmoid" ) {
81- layer->mutable_activation ()->mutable_sigmoid ();
82- } else if (op_type == " Tanh" ) {
83- layer->mutable_activation ()->mutable_tanh ();
84- } else if (op_type == " Relu" ) {
85- layer->mutable_activation ()->mutable_relu ();
86- } else if (op_type == " PRelu" ) {
87- auto * prelu = layer->mutable_activation ()->mutable_prelu ();
88- ORT_RETURN_IF_ERROR (AddPReluWeight (model_builder, node, logger, *prelu));
89- } else if (op_type == " LeakyRelu" ) {
90- NodeAttrHelper helper (node);
91- const auto alpha = helper.Get (" alpha" , 0 .01f );
92-
93- auto * leaky_relu = layer->mutable_activation ()->mutable_leakyrelu ();
94- leaky_relu->set_alpha (alpha);
95- } else {
96- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
97- " ActivationOpBuilder::AddToModelBuilderImpl, unknown op: " , op_type);
98- }
9980
100- *layer->mutable_input ()->Add () = node.InputDefs ()[0 ]->Name ();
101- *layer->mutable_output ()->Add () = node.OutputDefs ()[0 ]->Name ();
81+ #if defined(COREML_ENABLE_MLPROGRAM)
82+ if (model_builder.CreateMLProgram ()) {
83+ using namespace CoreML ::Specification::MILSpec;
84+ // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.activation
85+ std::string_view coreml_op_type;
86+ if (op_type == " Sigmoid" ) {
87+ coreml_op_type = " sigmoid" ;
88+ } else if (op_type == " Tanh" ) {
89+ coreml_op_type = " tanh" ;
90+ } else if (op_type == " Relu" ) {
91+ coreml_op_type = " relu" ;
92+ } else {
93+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
94+ " ActivationOpBuilder::AddToModelBuilderImpl, unknown op: " , op_type);
95+ }
96+
97+ std::unique_ptr<Operation> op = model_builder.CreateOperation (node, coreml_op_type);
98+ AddOperationInput (*op, " x" , node.InputDefs ()[0 ]->Name ());
99+ AddOperationOutput (*op, *node.OutputDefs ()[0 ]);
100+
101+ model_builder.AddOperation (std::move (op));
102+
103+ } else
104+ #endif // (COREML_ENABLE_MLPROGRAM)
105+ {
106+ std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer (node);
107+
108+ if (op_type == " Sigmoid" ) {
109+ layer->mutable_activation ()->mutable_sigmoid ();
110+ } else if (op_type == " Tanh" ) {
111+ layer->mutable_activation ()->mutable_tanh ();
112+ } else if (op_type == " Relu" ) {
113+ layer->mutable_activation ()->mutable_relu ();
114+ } else if (op_type == " PRelu" ) {
115+ auto * prelu = layer->mutable_activation ()->mutable_prelu ();
116+ ORT_RETURN_IF_ERROR (AddPReluWeight (model_builder, node, logger, *prelu));
117+ } else if (op_type == " LeakyRelu" ) {
118+ NodeAttrHelper helper (node);
119+ const auto alpha = helper.Get (" alpha" , 0 .01f );
120+
121+ auto * leaky_relu = layer->mutable_activation ()->mutable_leakyrelu ();
122+ leaky_relu->set_alpha (alpha);
123+ } else {
124+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
125+ " ActivationOpBuilder::AddToModelBuilderImpl, unknown op: " , op_type);
126+ }
127+
128+ *layer->mutable_input ()->Add () = node.InputDefs ()[0 ]->Name ();
129+ *layer->mutable_output ()->Add () = node.OutputDefs ()[0 ]->Name ();
130+
131+ model_builder.AddLayer (std::move (layer));
132+ }
102133
103- model_builder.AddLayer (std::move (layer));
104134 return Status::OK ();
105135}
106136
@@ -165,9 +195,20 @@ bool IsPReluOpSupported(const Node& node, const OpBuilderInputParams& input_para
165195bool ActivationOpBuilder::IsOpSupportedImpl (const Node& node, const OpBuilderInputParams& input_params,
166196 const logging::Logger& logger) const {
167197 const auto & op_type = node.OpType ();
168- if (op_type == " PRelu" ) {
169- return IsPReluOpSupported (node, input_params, logger);
198+
199+ #if defined(COREML_ENABLE_MLPROGRAM)
200+ if (input_params.create_mlprogram ) {
201+ if (op_type == " PRelu" || op_type == " LeakyRelu" ) {
202+ return false ;
203+ }
204+ } else
205+ #endif // (COREML_ENABLE_MLPROGRAM)
206+ {
207+ if (op_type == " PRelu" ) {
208+ return IsPReluOpSupported (node, input_params, logger);
209+ }
170210 }
211+
171212 return true ;
172213}
173214
0 commit comments