@@ -1644,6 +1644,27 @@ def version_11(cls, ctx, node, **kwargs):
16441644
16451645@tf_op ("MatrixBandPart" )
16461646class MatrixBandPart :
1647+ @classmethod
1648+ def _apply_mask_and_transform (cls , ctx , node , mask ):
1649+ shapes = node .output_shapes
1650+ dtypes = node .output_dtypes
1651+ dtype = ctx .get_dtype (node .input [0 ])
1652+ data = node .input [0 ]
1653+ if dtype == TensorProto .BOOL :
1654+ # bool is not supported for 'Mul', so convert mask and input supported dtype
1655+ mask = ctx .make_node ("Cast" , inputs = mask .output , attr = {'to' : TensorProto .FLOAT }).output [0 ]
1656+ data = ctx .make_node ("Cast" , [data ], attr = {'to' : TensorProto .FLOAT }).output [0 ]
1657+ result = ctx .make_node (op_type = "Mul" , inputs = [mask , data ], shapes = shapes , dtypes = [TensorProto .FLOAT ])
1658+ ctx .remove_node (node .name )
1659+ ctx .make_node ("Cast" , inputs = result .output , attr = {'to' : dtype },
1660+ name = node .name , outputs = node .output , dtypes = dtypes )
1661+ else :
1662+ mask = ctx .make_node (op_type = "Cast" , inputs = mask .output , attr = {"to" : dtype }).output [0 ]
1663+ ctx .remove_node (node .name )
1664+ ctx .make_node (op_type = "Mul" , inputs = [mask , data ],
1665+ name = node .name , outputs = node .output , shapes = shapes ,
1666+ dtypes = dtypes )
1667+
16471668 @classmethod
16481669 def version_7 (cls , ctx , node , ** kwargs ):
16491670 # T output = MatrixBandPart(T input, int num_lower, int num_upper)
@@ -1714,14 +1735,7 @@ def version_7(cls, ctx, node, **kwargs):
17141735 mask_matrix = ctx .make_node (op_type = "Transpose" , inputs = cast1 .output )
17151736 else :
17161737 mask_matrix = squeeze
1717- cast2 = ctx .make_node (op_type = "Cast" , inputs = mask_matrix .output ,
1718- attr = {"to" : ctx .get_dtype (node .input [0 ])})
1719- shapes = node .output_shapes
1720- dtypes = node .output_dtypes
1721- ctx .remove_node (node .name )
1722- ctx .make_node (op_type = "Mul" , inputs = [cast2 .output [0 ], node .input [0 ]],
1723- name = node .name , outputs = node .output , shapes = shapes ,
1724- dtypes = dtypes )
1738+ cls ._apply_mask_and_transform (ctx , node , mask_matrix )
17251739
17261740 @classmethod
17271741 def version_11 (cls , ctx , node , ** kwargs ):
@@ -1739,17 +1753,12 @@ def version_11(cls, ctx, node, **kwargs):
17391753 {'data' : whole_shape , 'starts' : [- 2 ], 'ends' : [int_max_val ], 'axes' : [0 ]})
17401754 if num_lower_const == 0 and num_upper_const == 0 :
17411755 if rank == 2 :
1742- identity_node = ctx .make_node ("EyeLike" , [data ]). output [ 0 ]
1756+ identity_node = ctx .make_node ("EyeLike" , [data ])
17431757 else :
17441758 zero_tensor = helper .make_tensor ("value" , dtype , dims = [1 ], vals = [0 ])
17451759 const_of_shape = ctx .make_node ("ConstantOfShape" , [shape ], attr = {'value' : zero_tensor }).output [0 ]
1746- identity_node = ctx .make_node ("EyeLike" , [const_of_shape ]).output [0 ]
1747- shapes = node .output_shapes
1748- dtypes = node .output_dtypes
1749- ctx .remove_node (node .name )
1750- ctx .make_node (op_type = "Mul" , inputs = [identity_node , data ],
1751- name = node .name , outputs = node .output , shapes = shapes ,
1752- dtypes = dtypes )
1760+ identity_node = ctx .make_node ("EyeLike" , [const_of_shape ])
1761+ cls ._apply_mask_and_transform (ctx , node , identity_node )
17531762 return
17541763 zero_const = ctx .make_const (utils .make_name ("zero" ), np .array (0 , np .int64 )).output [0 ]
17551764 one_const = ctx .make_const (utils .make_name ("one" ), np .array (1 , np .int64 )).output [0 ]
@@ -1771,14 +1780,14 @@ def version_11(cls, ctx, node, **kwargs):
17711780 if ctx .get_dtype (num_upper ) != TensorProto .INT64 :
17721781 num_upper = ctx .make_node ("Cast" , [num_upper ], attr = {'to' : TensorProto .INT64 }).output [0 ]
17731782 greater = ctx .make_node ("Greater" , [idx_diff , num_upper ]).output [0 ]
1774- less_or_equal = ctx .make_node ("Not" , [greater ]). output [ 0 ]
1783+ less_or_equal = ctx .make_node ("Not" , [greater ])
17751784 conditions .append (less_or_equal )
17761785 if num_lower_const is None or num_lower_const >= 0 :
17771786 if ctx .get_dtype (num_lower ) != TensorProto .INT64 :
17781787 num_lower = ctx .make_node ("Cast" , [num_lower ], attr = {'to' : TensorProto .INT64 }).output [0 ]
17791788 num_lower_neg = ctx .make_node ("Neg" , [num_lower ]).output [0 ]
17801789 greater = ctx .make_node ("Greater" , [num_lower_neg , idx_diff ]).output [0 ]
1781- less_or_equal = ctx .make_node ("Not" , [greater ]). output [ 0 ]
1790+ less_or_equal = ctx .make_node ("Not" , [greater ])
17821791 conditions .append (less_or_equal )
17831792 if len (conditions ) == 0 :
17841793 node .type = "Identity"
@@ -1787,14 +1796,8 @@ def version_11(cls, ctx, node, **kwargs):
17871796 if len (conditions ) == 1 :
17881797 cond = conditions [0 ]
17891798 if len (conditions ) == 2 :
1790- cond = ctx .make_node ("And" , conditions ).output [0 ]
1791- mask = ctx .make_node ("Cast" , [cond ], attr = {'to' : ctx .get_dtype (data )}).output [0 ]
1792- shapes = node .output_shapes
1793- dtypes = node .output_dtypes
1794- ctx .remove_node (node .name )
1795- ctx .make_node (op_type = "Mul" , inputs = [mask , data ],
1796- name = node .name , outputs = node .output , shapes = shapes ,
1797- dtypes = dtypes )
1799+ cond = ctx .make_node ("And" , inputs = [c .output [0 ] for c in conditions ])
1800+ cls ._apply_mask_and_transform (ctx , node , cond )
17981801
17991802
18001803def _make_softmax_cross_entropy_with_logits (ctx , label , logit , tf_ori_node ):
0 commit comments