@@ -6,7 +6,7 @@ class OpsVars:
66 Operators taking multiple inputs.
77 """
88
9- def BitShift (self , direction : str = b "" ) -> "Var" :
9+ def BitShift (self , direction : str = "" ) -> "Var" :
1010 return self .make_node ("BitShift" , * self .vars_ , direction = direction )
1111
1212 def CenterCropPad (self , axes : Optional [List [int ]] = None ) -> "Var" :
@@ -42,7 +42,7 @@ def Concat(self, axis: int = 0) -> "Var":
4242
4343 def Conv (
4444 self ,
45- auto_pad : str = b "NOTSET" ,
45+ auto_pad : str = "NOTSET" ,
4646 dilations : Optional [List [int ]] = None ,
4747 group : int = 1 ,
4848 kernel_shape : Optional [List [int ]] = None ,
@@ -66,7 +66,7 @@ def Conv(
6666
6767 def ConvInteger (
6868 self ,
69- auto_pad : str = b "NOTSET" ,
69+ auto_pad : str = "NOTSET" ,
7070 dilations : Optional [List [int ]] = None ,
7171 group : int = 1 ,
7272 kernel_shape : Optional [List [int ]] = None ,
@@ -90,7 +90,7 @@ def ConvInteger(
9090
9191 def ConvTranspose (
9292 self ,
93- auto_pad : str = b "NOTSET" ,
93+ auto_pad : str = "NOTSET" ,
9494 dilations : Optional [List [int ]] = None ,
9595 group : int = 1 ,
9696 kernel_shape : Optional [List [int ]] = None ,
@@ -155,7 +155,7 @@ def DeformConv(
155155 def DequantizeLinear (self , axis : int = 1 ) -> "Var" :
156156 return self .make_node ("DequantizeLinear" , * self .vars_ , axis = axis )
157157
158- def Einsum (self , equation : str = b "" ) -> "Var" :
158+ def Einsum (self , equation : str = "" ) -> "Var" :
159159 return self .make_node ("Einsum" , * self .vars_ , equation = equation )
160160
161161 def Gather (self , axis : int = 0 ) -> "Var" :
@@ -174,8 +174,8 @@ def Gemm(
174174 def GridSample (
175175 self ,
176176 align_corners : int = 0 ,
177- mode : str = b "bilinear" ,
178- padding_mode : str = b "zeros" ,
177+ mode : str = "bilinear" ,
178+ padding_mode : str = "zeros" ,
179179 ) -> "Var" :
180180 return self .make_node (
181181 "GridSample" ,
@@ -240,7 +240,7 @@ def Mod(self, fmod: int = 0) -> "Var":
240240 return self .make_node ("Mod" , * self .vars_ , fmod = fmod )
241241
242242 def NegativeLogLikelihoodLoss (
243- self , ignore_index : int = 0 , reduction : str = b "mean"
243+ self , ignore_index : int = 0 , reduction : str = "mean"
244244 ) -> "Var" :
245245 return self .make_node (
246246 "NegativeLogLikelihoodLoss" ,
@@ -257,12 +257,12 @@ def NonMaxSuppression(self, center_point_box: int = 0) -> "Var":
257257 def OneHot (self , axis : int = - 1 ) -> "Var" :
258258 return self .make_node ("OneHot" , * self .vars_ , axis = axis )
259259
260- def Pad (self , mode : str = b "constant" ) -> "Var" :
260+ def Pad (self , mode : str = "constant" ) -> "Var" :
261261 return self .make_node ("Pad" , * self .vars_ , mode = mode )
262262
263263 def QLinearConv (
264264 self ,
265- auto_pad : str = b "NOTSET" ,
265+ auto_pad : str = "NOTSET" ,
266266 dilations : Optional [List [int ]] = None ,
267267 group : int = 1 ,
268268 kernel_shape : Optional [List [int ]] = None ,
@@ -431,13 +431,13 @@ def Resize(
431431 self ,
432432 antialias : int = 0 ,
433433 axes : Optional [List [int ]] = None ,
434- coordinate_transformation_mode : str = b "half_pixel" ,
434+ coordinate_transformation_mode : str = "half_pixel" ,
435435 cubic_coeff_a : float = - 0.75 ,
436436 exclude_outside : int = 0 ,
437437 extrapolation_value : float = 0.0 ,
438- keep_aspect_ratio_policy : str = b "stretch" ,
439- mode : str = b "nearest" ,
440- nearest_mode : str = b "round_prefer_floor" ,
438+ keep_aspect_ratio_policy : str = "stretch" ,
439+ mode : str = "nearest" ,
440+ nearest_mode : str = "round_prefer_floor" ,
441441 ) -> "Var" :
442442 axes = axes or []
443443 return self .make_node (
@@ -456,8 +456,8 @@ def Resize(
456456
457457 def RoiAlign (
458458 self ,
459- coordinate_transformation_mode : str = b "half_pixel" ,
460- mode : str = b "avg" ,
459+ coordinate_transformation_mode : str = "half_pixel" ,
460+ mode : str = "avg" ,
461461 output_height : int = 1 ,
462462 output_width : int = 1 ,
463463 sampling_ratio : int = 0 ,
@@ -480,12 +480,12 @@ def STFT(self, onesided: int = 1) -> "Var":
480480 def Scatter (self , axis : int = 0 ) -> "Var" :
481481 return self .make_node ("Scatter" , * self .vars_ , axis = axis )
482482
483- def ScatterElements (self , axis : int = 0 , reduction : str = b "none" ) -> "Var" :
483+ def ScatterElements (self , axis : int = 0 , reduction : str = "none" ) -> "Var" :
484484 return self .make_node (
485485 "ScatterElements" , * self .vars_ , axis = axis , reduction = reduction
486486 )
487487
488- def ScatterND (self , reduction : str = b "none" ) -> "Var" :
488+ def ScatterND (self , reduction : str = "none" ) -> "Var" :
489489 return self .make_node ("ScatterND" , * self .vars_ , reduction = reduction )
490490
491491 def Slice (
@@ -498,13 +498,18 @@ def Slice(
498498
499499 def TopK (self , axis : int = - 1 , largest : int = 1 , sorted : int = 1 ) -> "Vars" :
500500 return self .make_node (
501- "TopK" , * self .vars_ , axis = axis , largest = largest , sorted = sorted
501+ "TopK" ,
502+ * self .vars_ ,
503+ axis = axis ,
504+ largest = largest ,
505+ sorted = sorted ,
506+ n_outputs = 2 ,
502507 )
503508
504509 def Trilu (self , upper : int = 1 ) -> "Var" :
505510 return self .make_node ("Trilu" , * self .vars_ , upper = upper )
506511
507- def Upsample (self , mode : str = b "nearest" ) -> "Var" :
512+ def Upsample (self , mode : str = "nearest" ) -> "Var" :
508513 return self .make_node ("Upsample" , * self .vars_ , mode = mode )
509514
510515 def Where (
0 commit comments