@@ -2361,17 +2361,17 @@ def acc_ops_slice_tensor(
 
     ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
     dim = get_positive_dim(cast(int, kwargs["dim"]), ranks)
-
+    dynamic_shape = has_dynamic_shape(input_val.shape)
     if network.has_implicit_batch_dimension:
         if dim == 0:
             raise RuntimeError(
                 f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
             )
         dim = dim - 1
     else:
-        raise RuntimeError(
-            "We don't support slice_tensor with explicit batch dimension yet!"
-        )
+        if dynamic_shape:
+            # Check whether the slice target dim is a dynamic shape dim
+            assert input_val.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
 
     start_int = cast(int, kwargs["start"])
     stop_int = cast(int, kwargs["stop"])
@@ -2383,7 +2383,18 @@ def acc_ops_slice_tensor(
     output_shape = list(input_val.shape)
     output_shape[dim] = (stop_int - start_int) // step_int
 
-    layer = network.add_slice(input_val, start=start, shape=output_shape, stride=stride)
+    if dynamic_shape:
+        output_shape = get_shape_with_dynamic_shape(
+            network, output_shape, input_val, target, name
+        )
+    layer = network.add_slice(
+        input_val,
+        start=start,
+        shape=[] if dynamic_shape else output_shape,
+        stride=stride,
+    )
+    if dynamic_shape:
+        layer.set_input(2, output_shape)
     set_layer_name(layer, target, name)
     return layer.get_output(0)
 
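The slice_tensor change relies on ISliceLayer's optional shape input: with dynamic dimensions the output shape can't be baked in at build time, so the converter passes shape=[] and wires a runtime shape tensor into input index 2. Below is a minimal, self-contained sketch of that pattern, assuming TensorRT 8.x's explicit-batch Python API; it imitates what the PR's get_shape_with_dynamic_shape helper produces, but it is an illustration, not the converter's code.

import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

# Input whose batch dim is unknown (-1) at build time.
x = network.add_input("x", trt.float32, (-1, 8))

# Build the runtime output shape [N, 4] for the slice x[:, 0:4]:
# take shape(x) = [N, 8], keep element 0, concat the static part [4].
in_shape = network.add_shape(x).get_output(0)
batch = network.add_slice(in_shape, start=(0,), shape=(1,), stride=(1,)).get_output(0)
rest = network.add_constant(
    (1,), trt.Weights(np.array([4], dtype=np.int32))
).get_output(0)
out_shape = network.add_concatenation([batch, rest]).get_output(0)

# As in the converter: pass an empty shape at build time, then wire the
# runtime shape tensor into the slice layer's third input (index 2).
layer = network.add_slice(x, start=(0, 0), shape=(), stride=(1, 1))
layer.set_input(2, out_shape)
network.mark_output(layer.get_output(0))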
@@ -2584,11 +2595,14 @@ def acc_ops_split(
     )
 
     dim = cast(int, kwargs["dim"])
+    dynamic_shape = has_dynamic_shape(input_val.shape)
     if network.has_implicit_batch_dimension:
         assert dim != 0, "Can't split on batch dim when it's implicit!"
         dim -= 1
     else:
-        raise RuntimeError("We don't support split with explicit batch dimension yet!")
+        if dynamic_shape:
+            # Check whether the split target dim is a dynamic shape dim
+            assert input_val.shape[dim] != -1, "Can't split on dynamic shape dimension!"
 
     split_size = cast(int, kwargs["split_size"])
     start = [0] * len(input_val.shape)
@@ -2607,7 +2621,15 @@ def acc_ops_split(
         shape = list(input_val.shape)
         shape[dim] = min(split_size, cast(int, max_offset - offset))
         start[dim] = offset
-        layer = network.add_slice(input_val, start=start, shape=shape, stride=stride)
+        if dynamic_shape:
+            shape = get_shape_with_dynamic_shape(
+                network, shape, input_val, target, f"{name}_shape_{i}"
+            )
+        layer = network.add_slice(
+            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
+        )
+        if dynamic_shape:
+            layer.set_input(2, shape)
         offset += split_size
         set_layer_name(layer, target, f"{name}_{i}")
         output.append(layer.get_output(0))
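For reference, the split loop above advances offset in split_size strides and clamps the final chunk via min(split_size, max_offset - offset). A tiny pure-Python sketch of just that offset arithmetic (illustrative only, not converter code):

def split_offsets(dim_size: int, split_size: int):
    """Yield (start, length) per chunk, mirroring the converter's loop:
    shape[dim] = min(split_size, max_offset - offset); start[dim] = offset."""
    offset = 0
    while offset < dim_size:
        yield offset, min(split_size, dim_size - offset)
        offset += split_size

# A dim of size 10 with split_size=3 -> [(0, 3), (3, 3), (6, 3), (9, 1)]
print(list(split_offsets(10, 3)))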
@@ -2761,7 +2783,7 @@ def acc_ops_getitem(
         slices = (slices,)
 
     dynamic_shape = get_dynamic_dims(input_val.shape)
-    if dynamic_shape:
+    if len(dynamic_shape) > 0:
         for i, s in zip(input_val.shape, slices):
             assert i > 0 or (
                 s in [slice(None, None, None), slice(0, None, None), Ellipsis]
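The getitem guard switches from a truthiness test to an explicit length check because get_dynamic_dims returns a list of dimension indices. A hedged sketch of the helper's assumed behavior (a hypothetical re-implementation for illustration, not the fx2trt source):

from typing import List, Sequence

def get_dynamic_dims(shape: Sequence[int]) -> List[int]:
    """Assumed behavior: indices of dims that are unknown (-1) at build time."""
    return [i for i, s in enumerate(shape) if s == -1]

# Only trivial slices (full range or Ellipsis) are allowed on such dims,
# which is what the assert in acc_ops_getitem enforces.
assert get_dynamic_dims((-1, 8, -1)) == [0, 2]
assert get_dynamic_dims((4, 8)) == []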