@@ -406,46 +406,34 @@ def acc_ops_pad_with_slice_layer(
406406 )
407407
408408 input_shape = input_val .shape
409- pre_start = tuple (i - 1 for i in input_shape )
410409 prefix_len = len (input_shape ) - len (pad ) // 2
411- pre_shape = tuple (
412- input_shape [ i ] + ( pad [- (i - prefix_len ) * 2 - 2 ] if i >= prefix_len else 0 )
410+ start = tuple (
411+ - pad [- (i - prefix_len ) * 2 - 2 ] if i >= prefix_len else 0
413412 for i in range (0 , len (input_shape ))
414413 )
415- pre_stride = [- 1 ] * len (input_shape )
414+
415+ shape = tuple (
416+ input_shape [i ]
417+ + (
418+ pad [- (i - prefix_len ) * 2 - 1 ] + pad [- (i - prefix_len ) * 2 - 2 ]
419+ if i >= prefix_len
420+ else 0
421+ )
422+ for i in range (0 , len (input_shape ))
423+ )
424+ stride = tuple ([1 ] * len (shape ))
416425
417426 layer = network .add_slice (
418427 input_val ,
419- pre_start ,
420- pre_shape ,
421- pre_stride ,
428+ start ,
429+ shape ,
430+ stride ,
422431 )
423- layer .set_input (4 , value_const )
424- layer .mode = trt .SliceMode .FILL
425- set_layer_name (layer , target , f"pre_{ name } " )
426- half_pad_output = layer .get_output (0 )
427432
428- shape = half_pad_output .shape
429- mid_start = tuple (i - 1 for i in shape )
430- mid_stride = [- 1 ] * len (shape )
431- layer = network .add_slice (half_pad_output , mid_start , shape , mid_stride )
432433 layer .set_input (4 , value_const )
433434 layer .mode = trt .SliceMode .FILL
434- set_layer_name (layer , target , f"transpose_{ name } " )
435- transpose_output = layer .get_output (0 )
436-
437- shape = transpose_output .shape
438- post_start = tuple ([0 ] * len (shape ))
439- post_shape = tuple (
440- shape [i ] + (pad [- (i - prefix_len ) * 2 - 1 ] if i >= prefix_len else 0 )
441- for i in range (0 , len (shape ))
442- )
443- post_stride = tuple ([1 ] * len (shape ))
435+ set_layer_name (layer , target , name )
444436
445- layer = network .add_slice (transpose_output , post_start , post_shape , post_stride )
446- layer .set_input (4 , value_const )
447- layer .mode = trt .SliceMode .FILL
448- set_layer_name (layer , target , f"post_{ name } " )
449437 return layer .get_output (0 )
450438
451439
0 commit comments