From c8e0337b65e55d23aafd608c6b5e6d6de0e27009 Mon Sep 17 00:00:00 2001 From: Panchajanya1999 Date: Sat, 20 Jul 2024 11:56:24 +0530 Subject: [PATCH 1/2] create_completion: Add support for Groq API Signed-off-by: Panchajanya1999 --- create_completion.py | 50 ++++++++++++++++++++++++++++++++++++++------ zsh_codex.plugin.zsh | 2 +- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/create_completion.py b/create_completion.py index 4e9ba2f..70cedce 100755 --- a/create_completion.py +++ b/create_completion.py @@ -5,7 +5,7 @@ import configparser import argparse -# Conditionally import OpenAI and Google Generative AI +# Conditionally import OpenAI, Google Generative AI, and Groq try: from openai import OpenAI except ImportError: @@ -16,10 +16,16 @@ except ImportError: genai = None +try: + from groq import Groq +except ImportError: + Groq = None + # Get config dir from environment or default to ~/.config CONFIG_DIR = os.getenv('XDG_CONFIG_HOME', os.path.expanduser('~/.config')) OPENAI_API_KEYS_LOCATION = os.path.join(CONFIG_DIR, 'openaiapirc') GEMINI_API_KEYS_LOCATION = os.path.join(CONFIG_DIR, 'geminiapirc') +GROQ_API_KEYS_LOCATION = os.path.join(CONFIG_DIR, 'groqapirc') def create_template_ini_file(api_type): """ @@ -29,10 +35,14 @@ def create_template_ini_file(api_type): file_path = OPENAI_API_KEYS_LOCATION content = '[openai]\nsecret_key=\n' url = 'https://platform.openai.com/api-keys' - else: # gemini + elif api_type == 'gemini': file_path = GEMINI_API_KEYS_LOCATION content = '[gemini]\napi_key=\n' url = 'Google AI Studio' + else: # groq + file_path = GROQ_API_KEYS_LOCATION + content = '[groq]\napi_key=\n' + url = 'Groq API dashboard' if not os.path.isfile(file_path): with open(file_path, 'w') as f: @@ -60,12 +70,18 @@ def initialize_api(api_type): ) api_config.setdefault("model", "gpt-3.5-turbo-0613") return client, api_config - else: # gemini + elif api_type == 'gemini': config.read(GEMINI_API_KEYS_LOCATION) api_config = {k: v.strip("\"'") for k, v in config["gemini"].items()} genai.configure(api_key=api_config["api_key"]) api_config.setdefault("model", "gemini-1.5-pro-latest") return genai, api_config + else: # groq + config.read(GROQ_API_KEYS_LOCATION) + api_config = {k: v.strip("\"'") for k, v in config["groq"].items()} + client = Groq(api_key=api_config["api_key"]) + api_config.setdefault("model", "llama3-8b-8192") + return client, api_config def get_completion(api_type, client, config, full_command): if api_type == 'openai': @@ -84,16 +100,35 @@ def get_completion(api_type, client, config, full_command): temperature=float(config.get("temperature", 1.0)) ) return response.choices[0].message.content - else: # gemini + elif api_type == 'gemini': model = client.GenerativeModel(config["model"]) chat = model.start_chat(history=[]) prompt = "You are a zsh shell expert, please help me complete the following command. Only output the completed command, no need for any other explanation. Do not put the completed command in a code block.\n\n" + full_command response = chat.send_message(prompt) return response.text + else: # groq + response = client.chat.completions.create( + model=config["model"], + messages=[ + { + "role": 'system', + "content": "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, no need to include any other explanation. Do not put completed command in a code block.", + }, + { + "role": 'user', + "content": full_command, + } + ], + temperature=float(config.get("temperature", 0.5)), + max_tokens=int(config.get("max_tokens", 1024)), + top_p=float(config.get("top_p", 0.65)), + stream=False, + ) + return response.choices[0].message.content def main(): parser = argparse.ArgumentParser(description="Generate command completions using AI.") - parser.add_argument('--api', choices=['openai', 'gemini'], default='openai', help="Choose the API to use (default: openai)") + parser.add_argument('--api', choices=['openai', 'gemini', 'groq'], default='openai', help="Choose the API to use (default: openai)") parser.add_argument('cursor_position', type=int, help="Cursor position in the input buffer") args = parser.parse_args() @@ -103,6 +138,9 @@ def main(): elif args.api == 'gemini' and genai is None: print("Google Generative AI library is not installed. Please install it using 'pip install google-generativeai'") sys.exit(1) + elif args.api == 'groq' and Groq is None: + print("Groq library is not installed. Please install it using 'pip install groq'") + sys.exit(1) client, config = initialize_api(args.api) @@ -135,4 +173,4 @@ def main(): sys.stdout.write(completion) if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/zsh_codex.plugin.zsh b/zsh_codex.plugin.zsh index 3699067..18a7500 100644 --- a/zsh_codex.plugin.zsh +++ b/zsh_codex.plugin.zsh @@ -2,7 +2,7 @@ # This ZSH plugin reads the text from the current buffer # and uses a Python script to complete the text. -api="openai" +api="${ZSH_CODEX_AI_SERVICE:-groq}" # Default to OpenAI if not set create_completion() { # Get the text typed until now. From e6fc6d400393f71f79c5cf548aaf37d45008a6cd Mon Sep 17 00:00:00 2001 From: Panchajanya1999 Date: Sun, 8 Sep 2024 11:07:08 +0530 Subject: [PATCH 2/2] create_completion: Optimize code structure This is a structural change, no features are added or deleted. Signed-off-by: Panchajanya1999 --- create_completion.py | 100 ++++++++++++++----------------------------- 1 file changed, 32 insertions(+), 68 deletions(-) diff --git a/create_completion.py b/create_completion.py index 70cedce..de207e5 100755 --- a/create_completion.py +++ b/create_completion.py @@ -5,7 +5,7 @@ import configparser import argparse -# Conditionally import OpenAI, Google Generative AI, and Groq +# Check for required libraries try: from openai import OpenAI except ImportError: @@ -21,81 +21,61 @@ except ImportError: Groq = None -# Get config dir from environment or default to ~/.config CONFIG_DIR = os.getenv('XDG_CONFIG_HOME', os.path.expanduser('~/.config')) OPENAI_API_KEYS_LOCATION = os.path.join(CONFIG_DIR, 'openaiapirc') GEMINI_API_KEYS_LOCATION = os.path.join(CONFIG_DIR, 'geminiapirc') GROQ_API_KEYS_LOCATION = os.path.join(CONFIG_DIR, 'groqapirc') def create_template_ini_file(api_type): - """ - If the ini file does not exist create it and add the api_key placeholder - """ - if api_type == 'openai': - file_path = OPENAI_API_KEYS_LOCATION - content = '[openai]\nsecret_key=\n' - url = 'https://platform.openai.com/api-keys' - elif api_type == 'gemini': - file_path = GEMINI_API_KEYS_LOCATION - content = '[gemini]\napi_key=\n' - url = 'Google AI Studio' - else: # groq - file_path = GROQ_API_KEYS_LOCATION - content = '[groq]\napi_key=\n' - url = 'Groq API dashboard' - + file_info = { + 'openai': (OPENAI_API_KEYS_LOCATION, '[openai]\nsecret_key=\n', 'https://platform.openai.com/api-keys'), + 'gemini': (GEMINI_API_KEYS_LOCATION, '[gemini]\napi_key=\n', 'Google AI Studio'), + 'groq': (GROQ_API_KEYS_LOCATION, '[groq]\napi_key=\n', 'Groq API dashboard') + } + + file_path, content, url = file_info[api_type] + if not os.path.isfile(file_path): with open(file_path, 'w') as f: f.write(content) - + print(f'{api_type.capitalize()} API config file created at {file_path}') print('Please edit it and add your API key') print(f'If you do not yet have an API key, you can get it from: {url}') sys.exit(1) def initialize_api(api_type): - """ - Initialize the specified API - """ create_template_ini_file(api_type) config = configparser.ConfigParser() + config.read(os.path.join(CONFIG_DIR, f'{api_type}apirc')) + api_config = {k: v.strip("\"'") for k, v in config[api_type].items()} if api_type == 'openai': - config.read(OPENAI_API_KEYS_LOCATION) - api_config = {k: v.strip("\"'") for k, v in config["openai"].items()} client = OpenAI( api_key=api_config["secret_key"], base_url=api_config.get("base_url", "https://api.openai.com/v1"), organization=api_config.get("organization") ) - api_config.setdefault("model", "gpt-3.5-turbo-0613") - return client, api_config + api_config["model"] = api_config.get("model", "gpt-3.5-turbo-0613") elif api_type == 'gemini': - config.read(GEMINI_API_KEYS_LOCATION) - api_config = {k: v.strip("\"'") for k, v in config["gemini"].items()} genai.configure(api_key=api_config["api_key"]) - api_config.setdefault("model", "gemini-1.5-pro-latest") - return genai, api_config + client = genai + api_config["model"] = api_config.get("model", "gemini-1.5-pro-latest") else: # groq - config.read(GROQ_API_KEYS_LOCATION) - api_config = {k: v.strip("\"'") for k, v in config["groq"].items()} client = Groq(api_key=api_config["api_key"]) - api_config.setdefault("model", "llama3-8b-8192") - return client, api_config + api_config["model"] = api_config.get("model", "llama3-8b-8192") + + return client, api_config def get_completion(api_type, client, config, full_command): + system_message = "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, no need to include any other explanation. Do not put completed command in a code block." + if api_type == 'openai': response = client.chat.completions.create( model=config["model"], messages=[ - { - "role": 'system', - "content": "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, no need to include any other explanation. Do not put completed command in a code block.", - }, - { - "role": 'user', - "content": full_command, - } + {"role": 'system', "content": system_message}, + {"role": 'user', "content": full_command}, ], temperature=float(config.get("temperature", 1.0)) ) @@ -103,21 +83,15 @@ def get_completion(api_type, client, config, full_command): elif api_type == 'gemini': model = client.GenerativeModel(config["model"]) chat = model.start_chat(history=[]) - prompt = "You are a zsh shell expert, please help me complete the following command. Only output the completed command, no need for any other explanation. Do not put the completed command in a code block.\n\n" + full_command + prompt = f"{system_message}\n\n{full_command}" response = chat.send_message(prompt) return response.text else: # groq response = client.chat.completions.create( model=config["model"], messages=[ - { - "role": 'system', - "content": "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, no need to include any other explanation. Do not put completed command in a code block.", - }, - { - "role": 'user', - "content": full_command, - } + {"role": 'system', "content": system_message}, + {"role": 'user', "content": full_command}, ], temperature=float(config.get("temperature", 0.5)), max_tokens=int(config.get("max_tokens", 1024)), @@ -132,24 +106,18 @@ def main(): parser.add_argument('cursor_position', type=int, help="Cursor position in the input buffer") args = parser.parse_args() - if args.api == 'openai' and OpenAI is None: - print("OpenAI library is not installed. Please install it using 'pip install openai'") - sys.exit(1) - elif args.api == 'gemini' and genai is None: - print("Google Generative AI library is not installed. Please install it using 'pip install google-generativeai'") - sys.exit(1) - elif args.api == 'groq' and Groq is None: - print("Groq library is not installed. Please install it using 'pip install groq'") + api_libs = {'openai': OpenAI, 'gemini': genai, 'groq': Groq} + if api_libs[args.api] is None: + print(f"{args.api.capitalize()} library is not installed. Please install it using 'pip install {args.api}'") sys.exit(1) client, config = initialize_api(args.api) - # Read the input prompt from stdin. buffer = sys.stdin.read() zsh_prefix = '#!/bin/zsh\n\n' buffer_prefix = buffer[:args.cursor_position] buffer_suffix = buffer[args.cursor_position:] - full_command = zsh_prefix + buffer_prefix + buffer_suffix + full_command = f"{zsh_prefix}{buffer_prefix}{buffer_suffix}" completion = get_completion(args.api, client, config, full_command) @@ -157,20 +125,16 @@ def main(): completion = completion[len(zsh_prefix):] line_prefix = buffer_prefix.rsplit("\n", 1)[-1] - # Handle all the different ways the command can be returned for prefix in [buffer_prefix, line_prefix]: if completion.startswith(prefix): completion = completion[len(prefix):] break - if buffer_suffix and completion.endswith(buffer_suffix): - completion = completion[:-len(buffer_suffix)] - - completion = completion.strip("\n") + completion = completion.rstrip(buffer_suffix).strip("\n") if line_prefix.strip().startswith("#"): - completion = "\n" + completion + completion = f"\n{completion}" sys.stdout.write(completion) if __name__ == "__main__": - main() \ No newline at end of file + main()