@@ -106,6 +106,7 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
106106 raise NotImplementedError (f"Unexpected dtype={ sdtype } ." )
107107 else :
108108 sdtype = f"np.{ sdtype } "
109+
109110 return [
110111 "initializers.append(" ,
111112 f" { fra } (" ,
@@ -209,3 +210,57 @@ def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]:
209210 ")" ,
210211 ]
211212 return lines
213+
214+
215+ class InnerEmitterShortInitializer (InnerEmitter ):
216+ """
217+ Converts event into proper code.
218+ Initializer are replaced by random values if too big.
219+ """
220+
221+ def _emit_initializer (self , ** kwargs : Dict [str , Any ]) -> List [str ]:
222+ name = kwargs ["name" ]
223+ value = kwargs ["value" ]
224+ repl = {"bool" : "bool_" , "object" : "object_" , "str" : "str_" }
225+ fra = "from_array"
226+ sdtype = repl .get (str (value .dtype ), str (value .dtype ))
227+ if sdtype .startswith ("(" ):
228+ from onnx .reference .custom_element_types import float8e4m3fn
229+
230+ if sdtype == str (float8e4m3fn ):
231+ sdtype = "float8e4m3fn"
232+ fra = "from_array_extended"
233+ else :
234+ raise NotImplementedError (f"Unexpected dtype={ sdtype } ." )
235+ else :
236+ sdtype = f"np.{ sdtype } "
237+ if value .size <= 16 :
238+ return [
239+ "initializers.append(" ,
240+ f" { fra } (" ,
241+ f" np.array({ value .tolist ()} , dtype={ sdtype } )," ,
242+ f" name={ name !r} " ,
243+ " )" ,
244+ ")" ,
245+ ]
246+ if "int" in sdtype :
247+ return [
248+ f"value = np.random.randint(0, 10, size={ value .shape } )"
249+ f".astype({ sdtype } )" ,
250+ "initializers.append(" ,
251+ f" { fra } (" ,
252+ f" np.array(value, dtype={ sdtype } )," ,
253+ f" name={ name !r} " ,
254+ " )" ,
255+ ")" ,
256+ ]
257+ return [
258+ f"value = np.random.randn({ ', ' .join (map (str ,value .shape ))} )"
259+ f".astype({ sdtype } )" ,
260+ "initializers.append(" ,
261+ f" { fra } (" ,
262+ f" np.array(value, dtype={ sdtype } )," ,
263+ f" name={ name !r} " ,
264+ " )" ,
265+ ")" ,
266+ ]
0 commit comments