@@ -20,6 +20,10 @@ class BuilderEmitter(BaseEmitter):
2020 Converts event into proper code.
2121 """
2222
23+ def __init__ (self , make_model_function : str = "" ):
24+ super ().__init__ ()
25+ self .make_model_function = make_model_function
26+
2327 def join (self , rows : List [str ], single_line : bool = False ) -> str :
2428 "Join the rows"
2529 assert (
@@ -29,6 +33,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str:
2933
3034 def _emit_start (self , ** kwargs : Dict [str , Any ]) -> List [str ]:
3135 self .opsets = kwargs .get ("opsets" , {})
36+ self .ir_version = kwargs .get ("ir_version" , None )
3237 return []
3338
3439 def _emit_to_onnx_model (self , ** kwargs : Dict [str , Any ]) -> List [str ]:
@@ -43,12 +48,27 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
4348 )
4449 rows = [
4550 "" ,
46- f"g = GraphBuilder({ self .opsets } )" ,
51+ (
52+ f"g = GraphBuilder({ self .opsets } , ir_version={ self .ir_version } )"
53+ if self .ir_version
54+ else f"GraphBuilder({ self .opsets } )"
55+ ),
4756 * inputs ,
4857 f"{ self .name } ({ inps } )" ,
4958 * outputs ,
5059 "model = g.to_onnx()" ,
5160 ]
61+ if self .make_model_function :
62+ rows = [
63+ "" ,
64+ "" ,
65+ f'def { self .make_model_function } () -> "ModelProto":' ,
66+ * [" " + _ for _ in rows [1 :]],
67+ " return model" ,
68+ "" ,
69+ "" ,
70+ f"model = { self .make_model_function } ()" ,
71+ ]
5272 return rows
5373
5474 def _emit_begin_graph (self , ** kwargs : Dict [str , Any ]) -> List [str ]:
@@ -79,12 +99,14 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
7999 itype = kwargs .get ("elem_type" , 0 )
80100 shape = kwargs .get ("shape" , None )
81101 if itype == 0 :
82- inp = "X"
102+ inp = name or "X"
83103 else :
84104 if shape is None :
85- inp = f'X : "{ _itype_to_string (itype )} "'
105+ inp = f'{ name } : "{ _itype_to_string (itype )} "'
86106 else :
87- inp = f'X: "{ _itype_to_string (itype )} [{ ", " .join (map (str , shape ))} ]"'
107+ inp = (
108+ f'{ name } : "{ _itype_to_string (itype )} [{ ", " .join (map (str , shape ))} ]"'
109+ )
88110 self .inputs_full .append (inp )
89111 self .inputs .append (name )
90112 self .inputs_full_ .append ((name , _itype_to_string (itype ), shape ))
0 commit comments