diff --git a/generate.py b/generate.py index bb20d6c6..8b291a78 100644 --- a/generate.py +++ b/generate.py @@ -300,7 +300,7 @@ def main( decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) # Uncomment to squeeze more perf out of prefill - if args.compile_prefill: + if compile_prefill: prefill = torch.compile(prefill, fullgraph=True, dynamic=True)