diff --git a/configs/ui/fuser_ui.yaml b/configs/ui/fuser_ui.yaml new file mode 100644 index 0000000..1ef85f9 --- /dev/null +++ b/configs/ui/fuser_ui.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - kernel_agent + +port: 8086 diff --git a/configs/ui/kernel_agent.yaml b/configs/ui/kernel_agent.yaml new file mode 100644 index 0000000..07fe782 --- /dev/null +++ b/configs/ui/kernel_agent.yaml @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +port: 8085 +host: localhost diff --git a/configs/ui/pipeline_ui.yaml b/configs/ui/pipeline_ui.yaml new file mode 100644 index 0000000..dd373dc --- /dev/null +++ b/configs/ui/pipeline_ui.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - kernel_agent + +port: 8087 diff --git a/pyproject.toml b/pyproject.toml index 37462ee..c402edd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,8 @@ classifiers = [ ] dependencies = [ + "hydra-core", + "omegaconf", "openai", "anthropic", "jinja2", diff --git a/scripts/fuser_ui.py b/scripts/fuser_ui.py index c93c2dd..992a443 100644 --- a/scripts/fuser_ui.py +++ b/scripts/fuser_ui.py @@ -16,8 +16,8 @@ from __future__ import annotations -import argparse import ast +import logging import os import tarfile import time @@ -27,6 +27,9 @@ from pathlib import Path from typing import List, Optional, Tuple +import hydra +from omegaconf import DictConfig + import gradio as gr from dotenv import load_dotenv @@ -41,6 +44,10 @@ from Fuser.orchestrator import Orchestrator from Fuser.paths import ensure_abs_regular_file, make_run_dirs, PathSafetyError +# Turn off noisy HTTPX logging +httpx_logger = logging.getLogger("httpx") +httpx_logger.setLevel(logging.WARNING) + @dataclass class RunArtifacts: @@ -865,14 +872,17 @@ def generate( return app -def main() -> None: - parser = argparse.ArgumentParser(description="FuserAgent UI") - parser.add_argument("--port", type=int, default=8086) - parser.add_argument("--host", type=str, default="localhost") - args = parser.parse_args() - +@hydra.main( + version_base=None, + config_path=str(Path(__file__).resolve().parent.parent / "configs/ui"), + config_name="fuser_ui", +) +def main(cfg: DictConfig) -> None: app = build_interface() + port = cfg.port + host = cfg.host + print("🚀 Starting FuserAgent UI...") meta_keyfile = Path("/var/facebook/x509_identities/server.pem") @@ -880,13 +890,13 @@ def main() -> None: if is_meta_devserver: server_name = os.uname()[1] - print(f"🌐 Meta devserver detected. Visit https://{server_name}:{args.port}/") + print(f"🌐 Meta devserver detected. Visit https://{server_name}:{port}/") print("💡 Ensure you're on the Meta VPN.") app.launch( share=False, show_error=True, server_name=server_name, - server_port=args.port, + server_port=port, ssl_keyfile=str(meta_keyfile), ssl_certfile=str(meta_keyfile), ssl_verify=False, @@ -894,12 +904,12 @@ def main() -> None: inbrowser=False, ) else: - print(f"🌐 Visit http://{args.host}:{args.port}/") + print(f"🌐 Visit http://{host}:{port}/") app.launch( share=False, show_error=True, - server_name=args.host, - server_port=args.port, + server_name=host, + server_port=port, show_api=False, inbrowser=True, ) diff --git a/scripts/pipeline_ui.py b/scripts/pipeline_ui.py index 92ff9b8..2b00d9e 100644 --- a/scripts/pipeline_ui.py +++ b/scripts/pipeline_ui.py @@ -16,7 +16,7 @@ from __future__ import annotations -import argparse +import logging import os import sys import time @@ -26,6 +26,9 @@ from pathlib import Path from typing import List, Optional, Tuple +import hydra +from omegaconf import DictConfig + import gradio as gr from dotenv import load_dotenv @@ -37,6 +40,10 @@ MODEL_NAME_TO_CONFIG, ) +# Turn off noisy HTTPX logging +httpx_logger = logging.getLogger("httpx") +httpx_logger.setLevel(logging.WARNING) + def _list_kernelbench_problems(base: Path) -> List[Tuple[str, str]]: """Return list of (label, absolute_path) pairs for KernelBench problems.""" @@ -685,15 +692,18 @@ def on_run( return app -def main() -> None: - parser = argparse.ArgumentParser(description="Pipeline UI") - parser.add_argument("--port", type=int, default=8087) - parser.add_argument("--host", type=str, default="localhost") - args = parser.parse_args() - +@hydra.main( + version_base=None, + config_path=str(Path(__file__).resolve().parent.parent / "configs/ui"), + config_name="pipeline_ui", +) +def main(cfg: DictConfig) -> None: load_dotenv() app = build_interface() + port = cfg.port + host = cfg.host + print("🚀 Starting Pipeline UI...") # Mirror fuser_ui devserver behavior for Meta VPN environments @@ -702,13 +712,13 @@ def main() -> None: if is_meta_devserver: server_name = os.uname()[1] - print(f"🌐 Meta devserver detected. Visit https://{server_name}:{args.port}/") + print(f"🌐 Meta devserver detected. Visit https://{server_name}:{port}/") print("💡 Ensure you're on the Meta VPN.") app.launch( share=False, show_error=True, server_name=server_name, - server_port=args.port, + server_port=port, ssl_keyfile=str(meta_keyfile), ssl_certfile=str(meta_keyfile), ssl_verify=False, @@ -716,12 +726,12 @@ def main() -> None: inbrowser=False, ) else: - print(f"🌐 Visit http://{args.host}:{args.port}/") + print(f"🌐 Visit http://{host}:{port}/") app.launch( share=False, show_error=True, - server_name=args.host, - server_port=args.port, + server_name=host, + server_port=port, show_api=False, inbrowser=True, ) diff --git a/scripts/triton_ui.py b/scripts/triton_ui.py index 48dd804..0dee1c0 100644 --- a/scripts/triton_ui.py +++ b/scripts/triton_ui.py @@ -15,22 +15,28 @@ """Gradio UI for Triton Kernel Agent.""" -import argparse +import logging import os import time import traceback from pathlib import Path from typing import Any, Dict, Optional, Tuple +import hydra +from omegaconf import DictConfig + import gradio as gr from dotenv import load_dotenv - from triton_kernel_agent import TritonKernelAgent from triton_kernel_agent.providers.models import AVAILABLE_MODELS from triton_kernel_agent.providers.openai_provider import OpenAIProvider from triton_kernel_agent.providers.anthropic_provider import AnthropicProvider +# Turn off noisy HTTPX logging +httpx_logger = logging.getLogger("httpx") +httpx_logger.setLevel(logging.WARNING) + KERNELBENCH_BASE_PATH = ( Path(__file__).resolve().parent / "external" / "KernelBench" / "KernelBench" @@ -247,7 +253,7 @@ def _format_error_logs(self, result: Dict[str, Any], generation_time: float) -> """Format error logs for display""" logs = f"""## Generation Failed -**⏱️ Time:** {generation_time:.2f} seconds +**⏱️ Time:** {generation_time:.2f} seconds **❌ Error:** {result["message"]} **📁 Session:** `{os.path.basename(result["session_dir"])}` @@ -384,9 +390,9 @@ def _create_app() -> gr.Blocks: gr.Markdown( """ # 🚀 Triton Kernel Agent - + **AI-Powered GPU Kernel Generation** - + Generate optimized OpenAI Triton kernels from high-level descriptions. """ ) @@ -630,13 +636,13 @@ def handle_problem_select(evt: gr.SelectData): gr.Markdown( """ --- - + **💡 Tips:** - Be specific about input/output shapes and data types - - Include PyTorch equivalent code for reference + - Include PyTorch equivalent code for reference - Check the logs for detailed generation information - - **🔧 Configuration:** + + **🔧 Configuration:** - Provide your OpenAI or Anthropic API key above (not saved; session-only) - Or set the appropriate env var in `.env` (OPENAI_API_KEY or ANTHROPIC_API_KEY) - The key is only used for this session and automatically cleared @@ -646,14 +652,15 @@ def handle_problem_select(evt: gr.SelectData): return app -def main(): +@hydra.main( + version_base=None, + config_path=str(Path(__file__).resolve().parent.parent / "configs/ui"), + config_name="kernel_agent", +) +def main(cfg: DictConfig): """Create and launch the Gradio interface""" - parser = argparse.ArgumentParser(description="Triton Kernel Agent UI") - parser.add_argument("--port", type=int, default=8085, help="Port to run the UI on") - parser.add_argument("--host", type=str, default="localhost", help="Host to bind to") - args = parser.parse_args() - app = _create_app() + port = cfg.port # Check if running on Meta devserver (has Meta SSL certs) meta_keyfile = "/var/facebook/x509_identities/server.pem" @@ -665,14 +672,14 @@ def main(): if is_meta_devserver: # Meta devserver configuration server_name = os.uname()[1] # Get devserver hostname - print(f"🌐 Opening on Meta devserver: https://{server_name}:{args.port}/") + print(f"🌐 Opening on Meta devserver: https://{server_name}:{port}/") print("💡 Make sure you're connected to Meta VPN to access the demo") app.launch( share=False, show_error=True, server_name=server_name, - server_port=args.port, + server_port=port, ssl_keyfile=meta_keyfile, ssl_certfile=meta_keyfile, ssl_verify=False, @@ -681,16 +688,18 @@ def main(): ) else: # Local development configuration - print(f"🌐 Opening locally: http://{args.host}:{args.port}/") + host = cfg.host + + print(f"🌐 Opening locally: http://{host}:{port}/") print( - f"🚨 IMPORTANT: If Chrome shows blank page, try Safari: open -a Safari http://{args.host}:{args.port}/ 🚨" + f"🚨 IMPORTANT: If Chrome shows blank page, try Safari: open -a Safari http://{host}:{port}/ 🚨" ) app.launch( share=False, show_error=True, - server_name=args.host, - server_port=args.port, + server_name=host, + server_port=port, show_api=False, inbrowser=True, # Auto-open browser for local development )