@@ -216,6 +216,204 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
216216 return tensor_proto ;
217217 }
218218
219+ public static TensorShape constant_value_as_shape ( Tensor tensor )
220+ {
221+ bool hasattr ( Graph property , string attr )
222+ {
223+ var t = property . GetType ( ) . GetProperties ( ) ;
224+ foreach ( System . Reflection . PropertyInfo pi in t )
225+ {
226+ if ( pi . Name == attr )
227+ return true ;
228+ }
229+ return false ;
230+ }
231+
232+ if ( tensor . GetType ( ) == typeof ( EagerTensor ) )
233+ {
234+ int [ ] dims = { } ;
235+ foreach ( int dim in tensor . numpy ( ) )
236+ if ( dim != 1 )
237+ {
238+ dims [ dims . Length ] = dim ;
239+ } else
240+ {
241+ // -1 == Unknown
242+ dims [ dims . Length ] = - 1 ;
243+ }
244+ return new TensorShape ( dims ) ;
245+ }
246+
247+ if ( tensor . TensorShape . ndim == 0 )
248+ {
249+ var value_ = constant_value ( tensor ) ;
250+ if ( value_ == null )
251+ throw new ValueError (
252+ @"Received a scalar with unknown value as shape; require a statically
253+ known scalar with value '-1' to describe an unknown shape." ) ;
254+ if ( value_ != - 1 )
255+ throw new ValueError (
256+ String . Format ( @"Received a scalar value {0} as shape; require a statically known
257+ scalar with value '-1' to describe an unknown shape." , value_ ) ) ;
258+ return tensor . TensorShape . unknown_shape ( - 1 ) ;
259+ }
260+
261+ var shape = tensor . TensorShape . with_rank ( 1 ) ;
262+ if ( shape == new TensorShape ( new int [ ] { 1 } ) )
263+ {
264+ return new TensorShape ( new int [ ] { } ) ;
265+ } else if ( tensor . op . type == "Cast" )
266+ {
267+ var pre_cast = constant_value_as_shape ( tensor . op . inputs [ 0 ] ) ;
268+ if ( pre_cast . dims == null )
269+ return pre_cast ;
270+ var cast_dtype = dtypes . as_dtype ( ( Type ) tensor . op . get_attr ( "DstT" ) ) ;
271+ if ( ! Array . Exists ( new [ ] { dtypes . int32 , dtypes . int64 } , cast_dtype_ => cast_dtype_ == cast_dtype ) )
272+ return tensor . TensorShape . unknown_shape ( shape . dims [ 0 ] ) ;
273+
274+ int [ ] x_ = { } ;
275+ foreach ( var x in pre_cast . as_list ( ) )
276+ if ( x != - 1 )
277+ x_ [ x_ . Length ] = x ;
278+ else
279+ x_ [ x_ . Length ] = - 1 ;
280+ var dest_dtype_shape_array = np . array ( x_ ) . astype ( cast_dtype . as_numpy_dtype ( ) ) ;
281+
282+ int [ ] y_ = { } ;
283+ foreach ( int y in dest_dtype_shape_array )
284+ if ( y >= 0 )
285+ y_ [ y_ . Length ] = y ;
286+ else
287+ y_ [ y_ . Length ] = - 1 ;
288+ return new TensorShape ( y_ ) ;
289+ } else if ( tensor . op . type == "Shape" )
290+ {
291+ return tensor . op . inputs [ 0 ] . shape ;
292+ } else if ( tensor . op . type == "Pack" )
293+ {
294+ var ret_ = new TensorShape ( new int [ ] { } ) ;
295+ if ( ( int ) tensor . op . get_attr ( "axis" ) != 0 )
296+ throw new ValueError ( String . Format (
297+ @"Since rank 1 inputs are expected, Pack's axis: {0} must be 0, otherwise it
298+ would not be rank 1." , tensor . op . get_attr ( "axis" ) ) ) ;
299+ foreach ( Tensor pack_input in tensor . op . inputs )
300+ {
301+ var pack_input_val = constant_value ( pack_input ) ;
302+ Dimension new_dim ;
303+ if ( pack_input_val < 0 )
304+ {
305+ new_dim = new Dimension ( - 1 ) ;
306+ } else if ( pack_input_val == null )
307+ {
308+ new_dim = new Dimension ( - 1 ) ;
309+ } else
310+ {
311+ new_dim = new Dimension ( pack_input_val ) ;
312+ }
313+ ret_ = ret_ . concatenate ( new int [ ] { new_dim } ) ;
314+ }
315+ return ret_ ;
316+ } else if ( tensor . op . type == "Concat" )
317+ {
318+ var ret_ = new TensorShape ( new int [ ] { } ) ;
319+
320+ var inputlist_ = new ArraySegment < Tensor > ( tensor . op . inputs , 1 ,
321+ tensor . op . inputs . Length - 1 ) ;
322+ foreach ( var concat_input in inputlist_ )
323+ {
324+ ret_ = ret_ . concatenate ( constant_value_as_shape ( concat_input ) ) ;
325+ }
326+ return ret_ ;
327+ } else if ( tensor . op . type == "StridedSlice" )
328+ {
329+ try
330+ {
331+ var begin = constant_value ( tensor . op . inputs [ 1 ] ) ;
332+ var end = constant_value ( tensor . op . inputs [ 2 ] ) ;
333+ var strides = constant_value ( tensor . op . inputs [ 3 ] ) ;
334+ if ( new [ ] { begin , end , strides } . All ( x => x == null ) )
335+ {
336+ begin = begin [ 0 ] ;
337+ end = end [ 0 ] ;
338+ strides = strides [ 0 ] ;
339+ var begin_mask = tensor . op . get_attr ( "begin_mask" ) ;
340+ if ( ( int ) begin_mask == 1 )
341+ {
342+ begin = null ;
343+ }
344+ var end_mask = tensor . op . get_attr ( "end_mask" ) ;
345+ if ( ( int ) end_mask == 1 )
346+ {
347+ end = null ;
348+ }
349+
350+ var ellipsis_mask = tensor . op . get_attr ( "ellipsis_mask" ) ;
351+ var new_axis_mask = tensor . op . get_attr ( "new_axis_mask" ) ;
352+ var shrink_axis_mask = tensor . op . get_attr ( "shrink_axis_mask" ) ;
353+
354+ bool valid_attributes ;
355+ if ( ! ( bool ) ellipsis_mask && ! ( bool ) new_axis_mask &&
356+ ! ( bool ) shrink_axis_mask && ! ( ( bool ) begin_mask || ( int ) begin_mask == 1 ) &&
357+ ! ( ( bool ) end_mask || ( int ) end_mask == 1 ) )
358+ {
359+ valid_attributes = true ;
360+ } else { valid_attributes = false ; }
361+ if ( valid_attributes )
362+ {
363+ // sorry for the mess here, but this hacky solution was the best way
364+ // i could come up with to implement the things done in python in c#
365+ var prev_ = constant_value_as_shape ( tensor . op . inputs [ 0 ] ) . dims ;
366+ var prev = prev_ . Skip ( begin ) . Take ( end - begin ) . ToArray ( ) ;
367+ // 100 being the comparison doesn't really matter here; it's going to break anyway
368+ for ( int iter = 0 ; iter != 100 ; iter = iter + strides )
369+ {
370+ prev [ prev . Length ] = prev_ [ iter ] ;
371+ if ( ( iter + strides ) > prev_ . Length )
372+ break ;
373+ }
374+ var ret_ = new TensorShape ( prev ) ;
375+ return ret_ ;
376+ }
377+ }
378+ } catch ( Exception ex )
379+ {
380+ if ( ex is ValueError || ex is TypeError ) { }
381+ }
382+ } else if ( tensor . op . type == "Placeholder" &&
383+ tensor . op . graph . building_function &&
384+ hasattr ( tensor . op . graph , "internal_captures" ) )
385+ {
386+ int i = 0 ;
387+ foreach ( Tensor capture in tensor . op . graph . internal_captures ( ) )
388+ {
389+ if ( capture . GetType ( ) == typeof ( Tensor ) )
390+ {
391+ var external_capture = tensor . op . graph . external_captures ( ) [ i ] ;
392+ return constant_value_as_shape ( external_capture ) ;
393+ }
394+
395+ i ++ ;
396+ }
397+ }
398+
399+ var ret = tensor . TensorShape . unknown_shape ( shape . dims [ 0 ] ) ;
400+ var value = constant_value ( tensor ) ;
401+ if ( value != null )
402+ {
403+ int [ ] d_ = { } ;
404+ foreach ( int d in value )
405+ {
406+ if ( d >= 0 )
407+ d_ [ d_ . Length ] = d ;
408+ else
409+ d_ [ d_ . Length ] = - 1 ; // None
410+ }
411+ ret = ret . merge_with ( new TensorShape ( d_ ) ) ;
412+
413+ }
414+ return ret ;
415+ }
416+
219417 public static NDArray convert_to_numpy_ndarray ( object values )
220418 {
221419 NDArray nd ;
0 commit comments