Skip to content

Commit 59333c2

Browse files
authored
Merge pull request #89 from foundation-model-stack/jni/dev
skip device initialization warmup for senulator
2 parents 9779f4d + 1a3b997 commit 59333c2

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

scripts/inference.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -803,11 +803,16 @@ def infer(use_cache, do_sample, warmup):
803803
args.compile_dynamic_sendnn,
804804
**extra_generation_kwargs,
805805
)
806-
aiu_warmup_time = time.time()
807-
for sample, cache in itertools.product(do_sample, use_cache):
808-
infer(cache, sample, True)
809-
aiu_warmup_time = time.time() - aiu_warmup_time
810-
dprint(f"AIU warmup complete, took {aiu_warmup_time:.3f}s")
806+
if (
807+
args.device_type == "aiu"
808+
): # run device initialization warmup for AIU, skip for senulator
809+
aiu_warmup_time = time.time()
810+
for sample, cache in itertools.product(do_sample, use_cache):
811+
infer(cache, sample, True)
812+
aiu_warmup_time = time.time() - aiu_warmup_time
813+
dprint(
814+
f"AIU device initialization warmup complete, took {aiu_warmup_time:.3f}s"
815+
)
811816
else:
812817
for sample, cache in itertools.product(do_sample, use_cache):
813818
infer(cache, sample, True)

0 commit comments

Comments
 (0)