44from .base_emitter import BaseEmitter
55
66_types = {
7+ TensorProto .DOUBLE : "DOUBLE" ,
78 TensorProto .FLOAT : "FLOAT" ,
89 TensorProto .FLOAT16 : "FLOAT16" ,
910 TensorProto .INT64 : "INT64" ,
1011 TensorProto .INT32 : "INT32" ,
12+ TensorProto .INT16 : "INT16" ,
13+ TensorProto .UINT64 : "UINT64" ,
14+ TensorProto .UINT32 : "UINT32" ,
15+ TensorProto .UINT16 : "UINT16" ,
16+ TensorProto .STRING : "STRING" ,
17+ TensorProto .BOOL : "BOOL" ,
1118}
1219
1320
@@ -20,6 +27,10 @@ class BuilderEmitter(BaseEmitter):
2027 Converts event into proper code.
2128 """
2229
30+ def __init__ (self , make_model_function : str = "" ):
31+ super ().__init__ ()
32+ self .make_model_function = make_model_function
33+
2334 def join (self , rows : List [str ], single_line : bool = False ) -> str :
2435 "Join the rows"
2536 assert (
@@ -29,6 +40,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str:
2940
3041 def _emit_start (self , ** kwargs : Dict [str , Any ]) -> List [str ]:
3142 self .opsets = kwargs .get ("opsets" , {})
43+ self .ir_version = kwargs .get ("ir_version" , None )
3244 return []
3345
3446 def _emit_to_onnx_model (self , ** kwargs : Dict [str , Any ]) -> List [str ]:
@@ -43,12 +55,27 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
4355 )
4456 rows = [
4557 "" ,
46- f"g = GraphBuilder({ self .opsets } )" ,
58+ (
59+ f"g = GraphBuilder({ self .opsets } , ir_version={ self .ir_version } )"
60+ if self .ir_version
61+ else f"GraphBuilder({ self .opsets } )"
62+ ),
4763 * inputs ,
4864 f"{ self .name } ({ inps } )" ,
4965 * outputs ,
5066 "model = g.to_onnx()" ,
5167 ]
68+ if self .make_model_function :
69+ rows = [
70+ "" ,
71+ "" ,
72+ f'def { self .make_model_function } () -> "ModelProto":' ,
73+ * [" " + _ for _ in rows [1 :]],
74+ " return model" ,
75+ "" ,
76+ "" ,
77+ f"model = { self .make_model_function } ()" ,
78+ ]
5279 return rows
5380
5481 def _emit_begin_graph (self , ** kwargs : Dict [str , Any ]) -> List [str ]:
@@ -78,13 +105,16 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
78105 name = kwargs ["name" ]
79106 itype = kwargs .get ("elem_type" , 0 )
80107 shape = kwargs .get ("shape" , None )
108+ name = self ._clean_result_name (name )
81109 if itype == 0 :
82- inp = "X"
110+ inp = name or "X"
83111 else :
84112 if shape is None :
85- inp = f'X : "{ _itype_to_string (itype )} "'
113+ inp = f'{ name } : "{ _itype_to_string (itype )} "'
86114 else :
87- inp = f'X: "{ _itype_to_string (itype )} [{ ", " .join (map (str , shape ))} ]"'
115+ inp = (
116+ f'{ name } : "{ _itype_to_string (itype )} [{ ", " .join (map (str , shape ))} ]"'
117+ )
88118 self .inputs_full .append (inp )
89119 self .inputs .append (name )
90120 self .inputs_full_ .append ((name , _itype_to_string (itype ), shape ))
@@ -113,6 +143,7 @@ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
113143
114144 def _emit_output (self , ** kwargs : Dict [str , Any ]) -> List [str ]:
115145 name = kwargs ["name" ]
146+ name = self ._clean_result_name (name )
116147 itype = kwargs .get ("elem_type" , 0 )
117148 shape = kwargs .get ("shape" , None )
118149 self .outputs .append (name )
@@ -126,6 +157,8 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
126157 if kwargs .get ("domain" , "" ) != "" :
127158 domain = kwargs ["domain" ]
128159 op_type = f"{ domain } .{ op_type } "
160+ else :
161+ domain = ""
129162 atts = kwargs .get ("atts" , {})
130163 args = []
131164 for k , v in atts .items ():
@@ -134,11 +167,22 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
134167 raise NotImplementedError ("Graph attribute not supported yet." )
135168 args .append (f"{ k } ={ vatt } " )
136169
137- outs = ", " .join (outputs )
138- inps = ", " .join (inputs )
170+ outs = ", " .join (map (self ._clean_result_name , outputs ))
171+ inps = ", " .join (map (self ._clean_result_name , inputs ))
172+ op_type = self ._emit_node_type (op_type , domain )
173+ sdomain = "" if not domain else f", domain={ domain !r} "
139174 if args :
140175 sargs = ", " .join (args )
141- row = f" { outs } = op.{ op_type } ({ inps } , { sargs } )"
176+ if inps :
177+ row = f" { outs } = op.{ op_type } ({ inps } , { sargs } { sdomain } )"
178+ else :
179+ row = f" { outs } = op.{ op_type } ({ sargs } { sdomain } )"
142180 else :
143- row = f" { outs } = op.{ op_type } ({ inps } )"
181+ row = f" { outs } = op.{ op_type } ({ inps } { sdomain } )"
144182 return [row ]
183+
184+ def _clean_result_name (self , name ):
185+ return name
186+
187+ def _emit_node_type (self , op_type , domain ):
188+ return op_type
0 commit comments