From 075ae7a5e044a59848d4508afcc01d46c989f642 Mon Sep 17 00:00:00 2001 From: zanderjiang Date: Fri, 24 Oct 2025 17:04:57 -0400 Subject: [PATCH] update kernel generator example to be more clear --- .../kernel_generator_example.py | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/examples/kernel_generator/kernel_generator_example.py b/examples/kernel_generator/kernel_generator_example.py index d3371803..d495330e 100644 --- a/examples/kernel_generator/kernel_generator_example.py +++ b/examples/kernel_generator/kernel_generator_example.py @@ -19,21 +19,29 @@ def main(): """ Generate optimized solutions for all definitions in the traceset. """ - model_name = "gpt-5-2025-08-07" # Choose model here - language = "triton" - target_gpu = "B200" + # TODO: select model, language, target gpu, definition + model_name = "gpt-5-2025-08-07" # Choose author-model + language = "triton" # Target solution language + target_gpu = "B200" # Choose solution target GPU + definition = "" # Leave empty to generate solutions for all definitions # TODO: adjust local path to traceset - traceset_path = "/home/akj2/flashinfer-trace" + traceset_path = "/path/to/flashinfer-trace" print(f"Loading TraceSet from: {traceset_path}") traceset = TraceSet.from_path(traceset_path) - # all_definitions = list(traceset.definitions.keys()) - # Filter for rmsnorm definitions only - all_definitions = [name for name in traceset.definitions.keys() if "rmsnorm" in name.lower()] + all_definitions = list(traceset.definitions.keys()) - print(f"All definitions found: {len(all_definitions)}") + if definition: + if definition in all_definitions: + all_definitions = [definition] + print(f"Generating solution {definition}") + else: + print(f"Definition '{definition}' not found in traceset") + return + + print(f"Found {len(all_definitions)} definitions to generate solutions") api_key = os.getenv("LLM_API_KEY") base_url = os.getenv("BASE_URL")