@@ -2091,3 +2091,93 @@ def version_11(cls, ctx, node, **kwargs):
20912091 ctx .replace_all_inputs (node .output [3 ], sum_max_neg )
20922092
20932093 ctx .remove_node (node .name )
2094+
2095+
2096+ @tf_op ("ExtractImagePatches" )
2097+ class ExtractImagePatches :
2098+ @classmethod
2099+ def version_9 (cls , ctx , node , ** kwargs ):
2100+ input_shape = ctx .get_shape (node .input [0 ])
2101+ output_shape = node .output_shapes [0 ]
2102+
2103+ sizes = node .get_attr_value ("ksizes" )
2104+ strides = node .get_attr_value ("strides" )
2105+ rates = node .get_attr_value ("rates" )
2106+ padding = node .get_attr_str ("padding" )
2107+
2108+ # This implementation of ExtractImagePatches does not generalize
2109+ # to outputs that are empty. For example:
2110+ #
2111+ # tf.image.extract_patches(
2112+ # np.random.rand(1, 1, 1, 1), sizes=[1, 2, 2, 1], strides=[1, 1, 1, 1],
2113+ # rates=[1, 1, 1, 1], padding="VALID"
2114+ # )
2115+ #
2116+ # succeeds with the output of:
2117+ #
2118+ # <tf.Tensor: shape=(1, 0, 0, 4), dtype=float64, numpy=array([], shape=(1, 0, 0, 4), dtype=float64)>
2119+ #
2120+ # whereas attempting the same here results in an "Invalid input shape" error for the "Conv" node.
2121+ utils .make_sure (0 not in output_shape , "Empty ExtractImagePatches output is unsupported." )
2122+ [_ , size_rows , size_cols , _ ] = sizes
2123+
2124+ # Transform input into [N * C, H, W, 1].
2125+ transformed_input = ctx .make_node ("Reshape" , inputs = [
2126+ ctx .make_node ("Transpose" , inputs = node .input , attr = dict (perm = [0 , 3 , 1 , 2 ])).output [0 ],
2127+ ctx .make_const (utils .make_name ("new_shape" ), np .int64 ([
2128+ input_shape [0 ] * input_shape [3 ],
2129+ input_shape [1 ],
2130+ input_shape [2 ],
2131+ 1 ,
2132+ ])).output [0 ],
2133+ ])
2134+
2135+ # Create identity kernel.
2136+ k = size_rows * size_cols
2137+ identity_kernel = ctx .make_node ("Reshape" , inputs = [
2138+ ctx .make_node ("EyeLike" , inputs = [
2139+ ctx .make_node ("ConstantOfShape" , inputs = [
2140+ ctx .make_const (utils .make_name ("eye_size" ), np .array ([k , k ], dtype = np .int64 )).output [0 ],
2141+ ]).output [0 ],
2142+ ]).output [0 ],
2143+ ctx .make_const (utils .make_name ("new_shape" ), np .array ([
2144+ size_rows ,
2145+ size_cols ,
2146+ 1 ,
2147+ k ,
2148+ ], dtype = np .int64 )).output [0 ],
2149+ ])
2150+
2151+ # Construct placeholder convolution node and transform into [N * C, K, ?H, ?W].
2152+ convolution = ctx .make_node ("Conv" , inputs = [transformed_input .output [0 ], identity_kernel .output [0 ]],
2153+ shapes = [[input_shape [0 ] * input_shape [3 ], output_shape [1 ], output_shape [2 ], k ]],
2154+ attr = dict (strides = strides , dilations = rates , padding = padding , data_format = "NHWC" ),
2155+ dtypes = node .output_dtypes )
2156+
2157+ # Transform into [N, ?H, ?W, C * K].
2158+ output_node = ctx .make_node ("Reshape" , inputs = [
2159+ ctx .make_node ("Transpose" , inputs = [
2160+ ctx .make_node ("Reshape" , inputs = [
2161+ convolution .output [0 ],
2162+ ctx .make_const (utils .make_name ("new_shape" ), np .array ([
2163+ input_shape [0 ],
2164+ input_shape [3 ],
2165+ output_shape [1 ],
2166+ output_shape [2 ],
2167+ k ,
2168+ ], dtype = np .int64 )).output [0 ],
2169+ ]).output [0 ],
2170+ ], attr = dict (perm = [0 , 2 , 3 , 4 , 1 ])).output [0 ],
2171+ ctx .make_const (utils .make_name ("new_shape" ), np .array (output_shape , dtype = np .int64 )).output [0 ],
2172+ ])
2173+
2174+ # Replace original node.
2175+ ctx .replace_all_inputs (node .output [0 ], output_node .output [0 ])
2176+ ctx .remove_node (node .name )
2177+
2178+ # Transform convolution node.
2179+ kernel_shape = conv_kernel_shape (ctx , convolution , 1 )
2180+ strides = conv_dims_attr (convolution , "strides" )
2181+ dilations = conv_dims_attr (convolution , "dilations" )
2182+ add_padding (ctx , convolution , kernel_shape , strides , dilations )
2183+ conv_convert_inputs (ctx , convolution , with_kernel = True )
0 commit comments