|
| 1 | +using System; |
| 2 | +using System.Collections.Generic; |
| 3 | +using System.Linq; |
| 4 | +using System.Text; |
| 5 | + |
| 6 | +namespace Tensorflow.Operations |
| 7 | +{ |
| 8 | + public class _WithSpaceToBatch |
| 9 | + { |
| 10 | + private _NonAtrousConvolution call; |
| 11 | + |
| 12 | + public _WithSpaceToBatch(TensorShape input_shape, |
| 13 | + int[] dilation_rate, |
| 14 | + string padding, |
| 15 | + Func<int, string, _NonAtrousConvolution> build_op, |
| 16 | + TensorShape filter_shape = null, |
| 17 | + int[] spatial_dims = null, |
| 18 | + string data_format = null) |
| 19 | + { |
| 20 | + var dilation_rate_tensor = ops.convert_to_tensor(dilation_rate, TF_DataType.TF_INT32, name: "dilation_rate"); |
| 21 | + var rate_shape = dilation_rate_tensor.getShape(); |
| 22 | + var num_spatial_dims = rate_shape.Dimensions[0]; |
| 23 | + int starting_spatial_dim = -1; |
| 24 | + if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC")) |
| 25 | + starting_spatial_dim = 2; |
| 26 | + else |
| 27 | + starting_spatial_dim = 1; |
| 28 | + |
| 29 | + if (spatial_dims == null) |
| 30 | + throw new NotImplementedException("_WithSpaceToBatch spatial_dims"); |
| 31 | + |
| 32 | + var orig_spatial_dims = spatial_dims; |
| 33 | + spatial_dims = spatial_dims.OrderBy(x => x).ToArray(); |
| 34 | + if (!Enumerable.SequenceEqual(spatial_dims, orig_spatial_dims) || spatial_dims.Any(x => x < 1)) |
| 35 | + throw new ValueError("spatial_dims must be a montonically increasing sequence of positive integers"); |
| 36 | + |
| 37 | + int expected_input_rank = -1; |
| 38 | + if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC")) |
| 39 | + expected_input_rank = spatial_dims.Last(); |
| 40 | + else |
| 41 | + expected_input_rank = spatial_dims.Last() + 1; |
| 42 | + |
| 43 | + var const_rate = tensor_util.constant_value(dilation_rate_tensor); |
| 44 | + var rate_or_const_rate = dilation_rate; |
| 45 | + if(!(const_rate is null)) |
| 46 | + { |
| 47 | + if (const_rate.Data<int>().Count(x => x == 1) == const_rate.size) |
| 48 | + { |
| 49 | + call = build_op(num_spatial_dims, padding); |
| 50 | + return; |
| 51 | + } |
| 52 | + } |
| 53 | + } |
| 54 | + } |
| 55 | +} |
0 commit comments