5656CHUNK_SIZES = [1 , 2 , 4 , 8 , 16 , 32 , 64 , 128 , 256 , 512 , 1024 , 2048 ]
5757LAZY_MODE = int (os .environ .get ("PT_HPU_LAZY_MODE" , 1 ))
5858BATCH_SIZE_EXPONENT_BASE = int (os .environ .get ("BATCH_SIZE_EXPONENT_BASE" , 2 ))
59+ SEQ_LEN_EXPONENT_BASE = int (os .environ .get ("SEQ_LEN_EXPONENT_BASE" , 2 ))
5960MAX_BATCH_SIZE = (
6061 int (os .environ .get ("MAX_BATCH_SIZE" ))
6162 if os .environ .get ("MAX_BATCH_SIZE" ) is not None
@@ -71,8 +72,21 @@ def torch_compile_for_eager(func):
7172 )
7273
7374
74- def round_up_seq (number , k ):
75- return (number + k - 1 ) // k * k
def round_up_seq(number, k, base):
    """Round *number* up to the nearest bucket of the form ``k * base**n``.

    Buckets are ``k, k*base, k*base**2, ...`` (``n >= 0``). The result is the
    smallest bucket that is greater than or equal to ``number``.

    The search uses exact multiplication instead of
    ``math.ceil(math.log(number / k, base))``: the floating-point logarithm
    can land slightly above an exact integer exponent and push the result one
    whole bucket too high, and it yields a negative exponent (a bucket smaller
    than ``k``) whenever ``number < k``.

    Args:
        number: Value to round up.
        k: Smallest bucket size; must be positive.
        base: Growth factor between consecutive buckets; must be > 1.

    Returns:
        The bucket size as an ``int``. Values of ``number`` at or below ``k``
        (including non-positive ones) map to ``k``.

    Raises:
        ValueError: If ``k <= 0`` or ``base <= 1``; the bucket sequence would
            never reach ``number`` in that case.
    """
    if k <= 0:
        raise ValueError(f"k must be positive, got {k}")
    if base <= 1:
        raise ValueError(f"base must be greater than 1, got {base}")

    bucket = k
    # Grow the bucket geometrically until it can hold `number`.
    while bucket < number:
        bucket *= base
    return int(bucket)
78+
79+
def iterate_powers_of_base(max_value, start, base):
    """List the geometric sequence ``start, start*base, ...`` below *max_value*.

    Args:
        max_value: Exclusive upper bound of the sequence.
        start: First element of the sequence.
        base: Multiplier applied between consecutive elements.

    Returns:
        A list of every ``start * base**n`` (``n >= 0``) that is strictly
        less than ``max_value``, in increasing order. The list is empty when
        ``max_value == start``.

    Raises:
        AssertionError: If ``max_value < start``.
        ValueError: If the sequence is non-empty and ``start <= 0`` or
            ``base <= 1``; such a sequence does not grow towards
            ``max_value``, so the loop would not be guaranteed to terminate.
    """
    assert (
        max_value >= start
    ), f"max_value {max_value} must be greater than or equal to start {start}"

    if max_value == start:
        # Nothing lies strictly below the bound; no need to validate the step.
        return []
    if start <= 0 or base <= 1:
        raise ValueError(
            f"start must be positive and base greater than 1, "
            f"got start={start}, base={base}"
        )

    result = []
    current = start
    while current < max_value:
        result.append(current)
        current *= base
    return result
7690
7791
7892def round_up_batch (number ):
@@ -575,7 +589,9 @@ def from_pb(
575589 assert (
576590 PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
577591 ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
578- rounded_seq_len = round_up_seq (input_len + 1 , PAD_SEQUENCE_TO_MULTIPLE_OF )
592+ rounded_seq_len = round_up_seq (
593+ input_len + 1 , PAD_SEQUENCE_TO_MULTIPLE_OF , SEQ_LEN_EXPONENT_BASE
594+ )
579595 if rounded_seq_len <= max_input_length :
580596 bucket_size = rounded_seq_len - 1
581597 else :
@@ -1345,14 +1361,9 @@ def warmup(
13451361 max_exp + 1 ,
13461362 )
13471363 ]
1348- prefill_seqlen_list = [
1349- seq
1350- for seq in range (
1351- PAD_SEQUENCE_TO_MULTIPLE_OF ,
1352- max_input_tokens ,
1353- PAD_SEQUENCE_TO_MULTIPLE_OF ,
1354- )
1355- ]
1364+ prefill_seqlen_list = iterate_powers_of_base (
1365+ max_input_tokens , PAD_SEQUENCE_TO_MULTIPLE_OF , SEQ_LEN_EXPONENT_BASE
1366+ )
13561367 prefill_seqlen_list .append (max_input_tokens )
13571368 prefill_batch_size_list .sort (reverse = True )
13581369 prefill_seqlen_list .sort (reverse = True )
0 commit comments