Skip to content

Commit 3e242ce

Browse files
quic-xiyushi, sanising
authored and committed
Update example with new random sampling logic
Signed-off-by: quic-sanising <sanising@qti.qualcomm.com>
Signed-off-by: sanising <sanising@qti.qualcomm.com>
1 parent e06e175 commit 3e242ce

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

examples/on_device_sampling.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@ def main(args, **kwargs):
2828
if include_sampler is not None:
2929
return_pdfs = args.override_qaic_config.get("aic_return_pdfs", None) == "true"
3030
max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512))
31+
np.random.seed(int(args.random_number))
3132
sampling_params = {
3233
"repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
3334
"presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
@@ -36,7 +37,9 @@ def main(args, **kwargs):
3637
"top_ks": np.array(args.top_k, dtype=np.int32).repeat(bs).reshape(-1, 1),
3738
"top_ps": np.array(args.top_p, dtype=np.float32).repeat(bs).reshape(-1, 1),
3839
"min_ps": np.array(args.min_p, dtype=np.float32).repeat(bs).reshape(-1, 1),
39-
"random_numbers": np.array(args.random_number, dtype=np.float32).repeat(bs).reshape(-1, 1),
40+
"random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=max_top_k_ids), (bs, 1)).astype(
41+
np.float32
42+
),
4043
}
4144
qaic_config = {
4245
k: v
@@ -110,10 +113,10 @@ def main(args, **kwargs):
110113
--repetition-penalty 1.9 \
111114
--presence-penalty 0.8 \
112115
--temperature 0.67 \
113-
--top-k 54720 \
116+
--top-k 54 \
114117
--top-p 0.89 \
115118
--min-p 0.6 \
116-
--random-number 0.26
119+
--random-number 26
117120
118121
2. For non-continuous batching:
119122
python3.10 examples/on_device_sampling.py \
@@ -130,10 +133,10 @@ def main(args, **kwargs):
130133
--repetition-penalty 1.9 \
131134
--presence-penalty 0.8 \
132135
--temperature 0.67 \
133-
--top-k 54720 \
136+
--top-k 54 \
134137
--top-p 0.89 \
135138
--min-p 0.6 \
136-
--random-number 0.26
139+
--random-number 26
137140
"""
138141

139142
parser = argparse.ArgumentParser(description="Run QEfficient model with On Device Sampling")

0 commit comments

Comments
 (0)