1+ using System ;
2+ using System . Collections . Generic ;
3+ using System . Text ;
4+ using System ;
5+ using Tensorflow . Keras . ArgsDefinition ;
6+ using Tensorflow . Keras . Saving ;
7+ using Tensorflow . Common . Types ;
8+ using Tensorflow . Keras . Utils ;
9+ using Tensorflow . Operations ;
10+ using Newtonsoft . Json ;
11+ using System . Security . Cryptography ;
12+
13+ namespace Tensorflow . Keras . Layers
14+ {
15+ public class DepthwiseConv2DArgs : Conv2DArgs
16+ {
17+ /// <summary>
18+ /// depth_multiplier: The number of depthwise convolution output channels for
19+ /// each input channel.The total number of depthwise convolution output
20+ /// channels will be equal to `filters_in* depth_multiplier`.
21+ /// </summary>
22+ [ JsonProperty ( "depth_multiplier" ) ]
23+ public int DepthMultiplier { get ; set ; } = 1 ;
24+
25+ [ JsonProperty ( "depthwise_initializer" ) ]
26+ public IInitializer DepthwiseInitializer { get ; set ; }
27+ }
28+
29+ public class DepthwiseConv2D : Conv2D
30+ {
31+ /// <summary>
32+ /// depth_multiplier: The number of depthwise convolution output channels for
33+ /// each input channel.The total number of depthwise convolution output
34+ /// channels will be equal to `filters_in* depth_multiplier`.
35+ /// </summary>
36+ int DepthMultiplier = 1 ;
37+
38+ IInitializer DepthwiseInitializer ;
39+
40+ int [ ] strides ;
41+
42+ int [ ] dilation_rate ;
43+
44+ string getDataFormat ( )
45+ {
46+ return data_format == "channels_first" ? "NCHW" : "NHWC" ;
47+ }
48+
49+ static int _id = 1 ;
50+
51+ public DepthwiseConv2D ( DepthwiseConv2DArgs args ) : base ( args )
52+ {
53+ args . Padding = args . Padding . ToUpper ( ) ;
54+
55+ if ( string . IsNullOrEmpty ( args . Name ) )
56+ name = "DepthwiseConv2D_" + _id ;
57+
58+ this . DepthMultiplier = args . DepthMultiplier ;
59+ this . DepthwiseInitializer = args . DepthwiseInitializer ;
60+
61+ }
62+
63+ public override void build ( KerasShapesWrapper input_shape )
64+ {
65+ //base.build(input_shape);
66+
67+ var shape = input_shape . ToSingleShape ( ) ;
68+
69+ int channel_axis = data_format == "channels_first" ? 1 : - 1 ;
70+ var input_channel = channel_axis < 0 ?
71+ shape . dims [ shape . ndim + channel_axis ] :
72+ shape . dims [ channel_axis ] ;
73+
74+ var arg = args as DepthwiseConv2DArgs ;
75+
76+ if ( arg . Strides . ndim != shape . ndim )
77+ {
78+ if ( arg . Strides . ndim == 2 )
79+ {
80+ this . strides = new int [ ] { 1 , ( int ) arg . Strides [ 0 ] , ( int ) arg . Strides [ 1 ] , 1 } ;
81+ }
82+ else
83+ {
84+ this . strides = conv_utils . normalize_tuple ( new int [ ] { ( int ) arg . Strides [ 0 ] } , shape . ndim , "strides" ) ;
85+ }
86+ }
87+ else
88+ {
89+ this . strides = arg . Strides . dims . Select ( o=> ( int ) ( o ) ) . ToArray ( ) ;
90+ }
91+
92+ if ( arg . DilationRate . ndim != shape . ndim )
93+ {
94+ this . dilation_rate = conv_utils . normalize_tuple ( new int [ ] { ( int ) arg . DilationRate [ 0 ] } , shape . ndim , "dilation_rate" ) ;
95+ }
96+
97+ long channel_data = data_format == "channels_first" ? shape [ 0 ] : shape [ shape . Length - 1 ] ;
98+
99+ var depthwise_kernel_shape = this . kernel_size . dims . concat ( new long [ ] {
100+ channel_data ,
101+ this . DepthMultiplier
102+ } ) ;
103+
104+ this . kernel = this . add_weight (
105+ shape : depthwise_kernel_shape ,
106+ initializer : this . DepthwiseInitializer != null ? this . DepthwiseInitializer : this . kernel_initializer ,
107+ name : "depthwise_kernel" ,
108+ trainable : true ,
109+ dtype : DType ,
110+ regularizer : this . kernel_regularizer
111+ ) ;
112+
113+ var axes = new Dictionary < int , int > ( ) ;
114+ axes . Add ( - 1 , ( int ) input_channel ) ;
115+ inputSpec = new InputSpec ( min_ndim : rank + 2 , axes : axes ) ;
116+
117+
118+ if ( use_bias )
119+ {
120+ bias = add_weight ( name : "bias" ,
121+ shape : ( ( int ) channel_data ) ,
122+ initializer : bias_initializer ,
123+ trainable : true ,
124+ dtype : DType ) ;
125+ }
126+
127+ built = true ;
128+ _buildInputShape = input_shape ;
129+ }
130+
131+ protected override Tensors Call ( Tensors inputs , Tensors state = null ,
132+ bool ? training = false , IOptionalArgs ? optional_args = null )
133+ {
134+ Tensor outputs = null ;
135+
136+ outputs = gen_nn_ops . depthwise_conv2d_native (
137+ inputs ,
138+ filter : this . kernel . AsTensor ( ) ,
139+ strides : this . strides ,
140+ padding : this . padding ,
141+ dilations : this . dilation_rate ,
142+ data_format : this . getDataFormat ( ) ,
143+ name : name
144+ ) ;
145+
146+ if ( use_bias )
147+ {
148+ if ( data_format == "channels_first" )
149+ {
150+ throw new NotImplementedException ( "call channels_first" ) ;
151+ }
152+ else
153+ {
154+ outputs = gen_nn_ops . bias_add ( outputs , ops . convert_to_tensor ( bias ) ,
155+ data_format : this . getDataFormat ( ) , name : name ) ;
156+ }
157+ }
158+
159+ if ( activation != null )
160+ outputs = activation . Apply ( outputs ) ;
161+
162+
163+ return outputs ;
164+ }
165+
166+ }
167+ }
0 commit comments