@@ -28,6 +28,7 @@ def main(args, **kwargs):
2828 if include_sampler is not None :
2929 return_pdfs = args .override_qaic_config .get ("aic_return_pdfs" , None ) == "true"
3030 max_top_k_ids = int (args .override_qaic_config .get ("max_top_k_ids" , 512 ))
31+ np .random .seed (int (args .random_number ))
3132 sampling_params = {
3233 "repetition_penalties" : np .array (args .repetition_penalty , dtype = np .float32 ).repeat (bs ).reshape (- 1 , 1 ),
3334 "presence_penalties" : np .array (args .presence_penalty , dtype = np .float32 ).repeat (bs ).reshape (- 1 , 1 ),
@@ -36,7 +37,9 @@ def main(args, **kwargs):
3637 "top_ks" : np .array (args .top_k , dtype = np .int32 ).repeat (bs ).reshape (- 1 , 1 ),
3738 "top_ps" : np .array (args .top_p , dtype = np .float32 ).repeat (bs ).reshape (- 1 , 1 ),
3839 "min_ps" : np .array (args .min_p , dtype = np .float32 ).repeat (bs ).reshape (- 1 , 1 ),
39- "random_numbers" : np .array (args .random_number , dtype = np .float32 ).repeat (bs ).reshape (- 1 , 1 ),
40+ "random_numbers" : np .tile (np .random .uniform (low = 0.0 , high = 1.0 , size = max_top_k_ids ), (bs , 1 )).astype (
41+ np .float32
42+ ),
4043 }
4144 qaic_config = {
4245 k : v
@@ -110,10 +113,10 @@ def main(args, **kwargs):
110113 --repetition-penalty 1.9 \
111114 --presence-penalty 0.8 \
112115 --temperature 0.67 \
113- --top-k 54720 \
116+ --top-k 54 \
114117 --top-p 0.89 \
115118 --min-p 0.6 \
116- --random-number 0.26
119+ --random-number 26
117120
118121 2. For non-continuous batching:
119122 python3.10 examples/on_device_sampling.py \
@@ -130,10 +133,10 @@ def main(args, **kwargs):
130133 --repetition-penalty 1.9 \
131134 --presence-penalty 0.8 \
132135 --temperature 0.67 \
133- --top-k 54720 \
136+ --top-k 54 \
134137 --top-p 0.89 \
135138 --min-p 0.6 \
136- --random-number 0.26
139+ --random-number 26
137140 """
138141
139142 parser = argparse .ArgumentParser (description = "Run QEfficient model with On Device Sampling" )
0 commit comments