From 14f4ad6ad7f60e199a7ced90a6266fbff8953e1d Mon Sep 17 00:00:00 2001 From: Sam Adam-Day Date: Wed, 25 Jun 2025 11:45:23 +0100 Subject: [PATCH 1/4] Started implementing Flash Attention 2 --- CHANGELOG.md | 2 + Dockerfile | 1 + README.md | 2 +- databases/language_models.csv | 128 +++++++++++----------- doc/docs/guides/installation.rst | 1 + nip/language_model_server/trainers/dpo.py | 50 ++++++++- nip/utils/language_model_database.py | 59 ++++++++-- pyproject.toml | 15 +-- uv.lock | 12 ++ 9 files changed, 176 insertions(+), 94 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d89c834..5ee4c43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,8 @@ to communicate reliably. - Ability to set vLLM max LoRA rank automatically. - Agent-level hyper-parameter to enable quantisation for self-hosted models. - Using Liger kernel in DPO training for increased speed and lower memory usage. +- Enabled Flash Attention 2 and padding-free batching when doing DPO training on + supported models. ### Changed diff --git a/Dockerfile b/Dockerfile index 67e1073..7001f41 100644 --- a/Dockerfile +++ b/Dockerfile @@ -66,6 +66,7 @@ RUN grep timm== pyproject.toml \ | tar -xzC /root/neural-interactive-proofs/vendor # Install all the required packages +RUN uv sync --locked RUN uv sync --locked --extra lm-server # The default target doesn't do much else diff --git a/README.md b/README.md index 6410240..8034043 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ for a more detailed installation guide. you can do `uv sync --no-dev`. If you want to host open-weight language models on your machine, install the - `lm-server` optional dependencies: `uv sync --extra lm-server`. + `lm-server` optional dependencies after the first sync: `uv sync --extra lm-server`. If using `pip`, create a virtual environment, activate it, then run: diff --git a/databases/language_models.csv b/databases/language_models.csv index 57f0297..f81c56b 100644 --- a/databases/language_models.csv +++ b/databases/language_models.csv @@ -1,64 +1,64 @@ -Model Series,Model Name,Developer,API URI,Self Hosted URI,Parameters (10E+9),FLOPs (10E+23),OpenRouter Input Cost,OpenRouter Output Cost,MMLU-Pro -Jamba 1.6,Mini,AI21,OpenRouter/ai21/jamba-1.6-mini,SelfHosted/ai21labs/AI21-Jamba-Mini-1.6,52,,0.2,0.4,0.37 -Jamba 1.6,Large,AI21,OpenRouter/ai21/jamba-1.6-large,SelfHosted/ai21labs/AI21-Jamba-Large-1.6,398,,2,8,0.56 -Qwen 2.5 Instruct,0.5B,Alibaba,,SelfHosted/Qwen/Qwen2.5-0.5B-Instruct,0.5,,,,0.15 -Qwen 2.5 Instruct,1.5B,Alibaba,,SelfHosted/Qwen/Qwen2.5-1.5B-Instruct,1.5,1.6632,,,0.32 -Qwen 2.5 Instruct,3B,Alibaba,,SelfHosted/Qwen/Qwen2.5-3B-Instruct,3,3.3372,,,0.44 -Qwen 2.5 Instruct,7B,Alibaba,OpenRouter/qwen/qwen-2.5-7b-instruct,SelfHosted/Qwen/Qwen2.5-7B-Instruct,7,8.2188,0.3,0.3,0.45 -Qwen 2.5 Instruct,14B,Alibaba,,SelfHosted/Qwen/Qwen2.5-14B-Instruct,14,15.876,,,0.64 -Qwen 2.5 Instruct,32B,Alibaba,OpenRouter/qwen/qwen2.5-32b-instruct,SelfHosted/Qwen/Qwen2.5-32B-Instruct,32,35.1,0.79,0.79,0.7 -Qwen 2.5 Instruct,72B,Alibaba,OpenRouter/qwen/qwen-2.5-72b-instruct,SelfHosted/Qwen/Qwen2.5-72B-Instruct,72,78,0.7,0.7,0.72 -Claude 3,Haiku,Anthropic,OpenRouter/anthropic/claude-3-haiku,,,,0.25,1.25, -Claude 3,Sonnet,Anthropic,OpenRouter/anthropic/claude-3-sonnet,,,,3,15,0.58 -Claude 3,Opus,Anthropic,OpenRouter/anthropic/claude-3-opus,,,164.0001,15,75,0.7 -Claude 3.5,Sonnet,Anthropic,OpenRouter/anthropic/claude-3.5-sonnet,,,365.0001,0.8,4,0.77 -Claude 3.7,Sonnet,Anthropic,OpenRouter/anthropic/claude-3.7-sonnet,,,335,3,15,0.8 -Claude 
4,Sonnet,Anthropic,OpenRouter/anthropic/claude-sonnet-4,,,,3,15,0.87 -Claude 4,Opus,Anthropic,OpenRouter/anthropic/claude-sonnet-4,,,,15,75,0.89 -Command,R7B,Cohere,OpenRouter/cohere/command-r7b-12-2024,SelfHosted/CohereLabs/c4ai-command-r7b-12-2024,7,,0.0375,0.15,0.29 -Command,R,Cohere,OpenRouter/cohere/command-r,SelfHosted/CohereLabs/c4ai-command-r-v01,35,,0.15,0.6,0.34 -Command,R+,Cohere,OpenRouter/cohere/command-r-plus,SelfHosted/CohereLabs/c4ai-command-r-plus,104,,3,15,0.43 -DeepSeek Coder Instruct,1.3B,DeepSeek,,SelfHosted/deepseek-ai/deepseek-coder-1.3b-instruct,1.3,,,, -DeepSeek Coder Instruct,6.7B,DeepSeek,,SelfHosted/deepseek-ai/deepseek-coder-6.7b-instruct,6.7,,,, -DeepSeek Coder Instruct,33B,DeepSeek,,SelfHosted/deepseek-ai/deepseek-coder-33b-instruct,33,,,, -DeepSeek Coder v2,Lite,DeepSeek,,SelfHosted/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,16,,,,0.42 -DeepSeek Coder v2,,DeepSeek,,SelfHosted/deepseek-ai/DeepSeek-Coder-V2-Instruct,236,,,,0.64 -DeepSeek V2.5,,DeepSeek,,SelfHosted/deepseek-ai/DeepSeek-V2.5,236,17.892,,, -DeepSeek V3,,DeepSeek,OpenRouter/deepseek/deepseek-chat-v3-0324,SelfHosted/deepseek-ai/DeepSeek-V3-0324,671,34.078,0.27,1.1,0.76 -Gemma 2,2B,Google DeepMind,,SelfHosted/google/gemma-2-2b-it,2,,,,0.51 -Gemma 2,9B,Google DeepMind,OpenRouter/google/gemma-2-9b-it,SelfHosted/google/gemma-2-9b-it,9,4.32,0.2,0.2,0.5 -Gemma 2,27B,Google DeepMind,OpenRouter/google/gemma-2-27b-it,SelfHosted/google/gemma-2-27b-it,27,21.06,0.8,0.8,0.57 -Gemma 3,1B,Google DeepMind,,SelfHosted/google/gemma-3-1b-it,1,,,,0.1 -Gemma 3,4B,Google DeepMind,OpenRouter/google/gemma-3-4b-it,SelfHosted/google/gemma-3-4b-it,4,,0.02,0.04,0.42 -Gemma 3,12B,Google DeepMind,OpenRouter/google/gemma-3-12b-it,SelfHosted/google/gemma-3-12b-it,12,,0.05,0.1,0.6 -Gemma 3,27B,Google DeepMind,OpenRouter/google/gemma-3-27b-it,SelfHosted/google/gemma-3-27b-it,27,,0.2,0.4,0.67 -Llama 3.1 Instruct,8B,Meta,OpenRouter/meta-llama/llama-3.1-8b-instruct,SelfHosted/meta-llama/Llama-3.1-8B-Instruct,8,12.24,0.02,0.05,0.44 -Llama 3.1 Instruct,70B,Meta,OpenRouter/meta-llama/llama-3.1-70b-instruct,SelfHosted/meta-llama/Llama-3.1-70B-Instruct,70,79.29,0.12,0.3,0.56 -Llama 3.1 Instruct,405B,Meta,OpenRouter/meta-llama/llama-3.1-405b-instruct,SelfHosted/meta-llama/Llama-3.1-405B-Instruct,405,380,0.8,0.8,0.73 -Llama 3.2 Instruct,1B,Meta,OpenRouter/meta-llama/llama-3.2-1b-instruct,SelfHosted/meta-llama/Llama-3.2-1B-Instruct,1,,0.01,0.01,0.12 -Llama 3.2 Instruct,3B,Meta,OpenRouter/meta-llama/llama-3.2-3b-instruct,SelfHosted/meta-llama/Llama-3.2-3B-Instruct,3,1.7334,0.015,0.025,0.22 -Llama 3.3 Instruct,70B,Meta,OpenRouter/meta-llama/llama-3.3-70b-instruct,SelfHosted/meta-llama/Llama-3.3-70B-Instruct,70,68.649768,0.12,0.3,0.66 -Llama 4,Scout,Meta,OpenRouter/meta-llama/llama-4-scout,SelfHosted/meta-llama/Llama-4-Scout-17B-16E-Instruct,109,,0.1,0.4,0.74 -Llama 4,Maverick,Meta,OpenRouter/meta-llama/llama-4-maverick,SelfHosted/meta-llama/Llama-4-Maverick-17B-128E-Instruct,400,,0.2,0.8,0.81 -Phi 3,Mini,Microsoft,OpenRouter/microsoft/phi-3-mini-128k-instruct,SelfHosted/microsoft/Phi-3-mini-128k-instruct,3.8,,0.1,0.1,0.44 -Phi 3,Small,Microsoft,,SelfHosted/microsoft/Phi-3-small-128k-instruct,7,2.1312,,, -Phi 3,Medium,Microsoft,OpenRouter/microsoft/phi-3-medium-128k-instruct,SelfHosted/microsoft/Phi-3-medium-128k-instruct,14,4.032,1,1,0.52 -Phi 4,Mini,Microsoft,,SelfHosted/microsoft/Phi-4-mini-instruct,3.8,,,,0.53 -Phi 4,,Microsoft,OpenRouter/microsoft/phi-4,SelfHosted/microsoft/phi-4,14,9.3202015,0.07,0.14,0.7 
-Codestral,,Mistral,OpenRouter/mistralai/codestral-2501,SelfHosted/mistralai/Codestral-22B-v0.1,22,,0.3,0.9,0.45 -Ministral,3B,Mistral,OpenRouter/mistralai/ministral-3b,SelfHosted/ministral/Ministral-3b-instruct,3,,0.04,0.04,0.34 -Ministral,8B,Mistral,OpenRouter/mistralai/ministral-8b,SelfHosted/mistralai/Ministral-8B-Instruct-2410,8,,0.1,0.1,0.39 -Mistral Instruct,7B,Mistral,OpenRouter/mistralai/mistral-7b-instruct,SelfHosted/mistralai/Mistral-7B-Instruct-v0.3,7,,0.03,0.055, -Mixtral Instruct (8 Experts),8x7B,Mistral,OpenRouter/mistralai/mixtral-8x7b-instruct,SelfHosted/mistralai/Mixtral-8x7B-Instruct-v0.1,56,,0.24,0.24, -Mixtral Instruct (8 Experts),8x22B,Mistral,OpenRouter/mistralai/mixtral-8x22b-instruct,SelfHosted/mistralai/Mixtral-8x22B-Instruct-v0.1,176,,0.9,0.9, -Mistral Large 2,,Mistral,OpenRouter/mistralai/mistral-large-2411,SelfHosted/mistralai/Pixtral-Large-Instruct-2411,123,213,2,6, -Mistral Small 3,,Mistral,OpenRouter/mistralai/mistral-small-24b-instruct-2501,SelfHosted/mistralai/Mistral-Small-24B-Instruct-2501,24,11.52,0.1,0.3, -GPT 4o,mini,OpenAI,OpenAI/gpt-4o-mini-2024-07-18,,,381.0001,2.5,10,0.63 -GPT 4o,,OpenAI,OpenAI/gpt-4o-2024-08-06,,,73.6001,0.150,0.6,0.75 -GPT 4.1,nano,OpenAI,OpenAI/gpt-4.1-nano-2025-04-14,,,,0.1,0.4,0.8 -GPT 4.1,mini,OpenAI,OpenAI/gpt-4.1-mini-2025-04-14,,,,0.4,1.6,0.88 -GPT 4.1,,OpenAI,OpenAI/gpt-4.1-2025-04-14,,,,2,8,0.9 -GPT o1,,OpenAI,OpenAI/o1-2024-12-17,,,,15,60,0.89 -GPT o3,,OpenAI,OpenAI/o3-2025-04-16,,,,2,8,0.85 -GPT o3,mini,OpenAI,OpenAI/o3-mini-2025-01-31,,,,1.1,4.4,0.87 -GPT o4,mini,OpenAI,OpenAI/o4-mini-2025-04-16,,,,1.1,4.4,0.83 \ No newline at end of file +Model Series,Model Name,Developer,API URI,Self Hosted URI,Parameters (10E+9),FLOPs (10E+23),OpenRouter Input Cost,OpenRouter Output Cost,MMLU-Pro,Flash Attention 2 +Jamba 1.6,Mini,AI21,OpenRouter/ai21/jamba-1.6-mini,SelfHosted/ai21labs/AI21-Jamba-Mini-1.6,52,,0.2,0.4,0.37, +Jamba 1.6,Large,AI21,OpenRouter/ai21/jamba-1.6-large,SelfHosted/ai21labs/AI21-Jamba-Large-1.6,398,,2,8,0.56, +Qwen 2.5 Instruct,0.5B,Alibaba,,SelfHosted/Qwen/Qwen2.5-0.5B-Instruct,0.5,,,,0.15,yes +Qwen 2.5 Instruct,1.5B,Alibaba,,SelfHosted/Qwen/Qwen2.5-1.5B-Instruct,1.5,1.6632,,,0.32,yes +Qwen 2.5 Instruct,3B,Alibaba,,SelfHosted/Qwen/Qwen2.5-3B-Instruct,3,3.3372,,,0.44,yes +Qwen 2.5 Instruct,7B,Alibaba,OpenRouter/qwen/qwen-2.5-7b-instruct,SelfHosted/Qwen/Qwen2.5-7B-Instruct,7,8.2188,0.3,0.3,0.45,yes +Qwen 2.5 Instruct,14B,Alibaba,,SelfHosted/Qwen/Qwen2.5-14B-Instruct,14,15.876,,,0.64,yes +Qwen 2.5 Instruct,32B,Alibaba,OpenRouter/qwen/qwen2.5-32b-instruct,SelfHosted/Qwen/Qwen2.5-32B-Instruct,32,35.1,0.79,0.79,0.7,yes +Qwen 2.5 Instruct,72B,Alibaba,OpenRouter/qwen/qwen-2.5-72b-instruct,SelfHosted/Qwen/Qwen2.5-72B-Instruct,72,78,0.7,0.7,0.72,yes +Claude 3,Haiku,Anthropic,OpenRouter/anthropic/claude-3-haiku,,,,0.25,1.25,, +Claude 3,Sonnet,Anthropic,OpenRouter/anthropic/claude-3-sonnet,,,,3,15,0.58, +Claude 3,Opus,Anthropic,OpenRouter/anthropic/claude-3-opus,,,164.0001,15,75,0.7, +Claude 3.5,Sonnet,Anthropic,OpenRouter/anthropic/claude-3.5-sonnet,,,365.0001,0.8,4,0.77, +Claude 3.7,Sonnet,Anthropic,OpenRouter/anthropic/claude-3.7-sonnet,,,335,3,15,0.8, +Claude 4,Sonnet,Anthropic,OpenRouter/anthropic/claude-sonnet-4,,,,3,15,0.87, +Claude 4,Opus,Anthropic,OpenRouter/anthropic/claude-sonnet-4,,,,15,75,0.89, +Command,R7B,Cohere,OpenRouter/cohere/command-r7b-12-2024,SelfHosted/CohereLabs/c4ai-command-r7b-12-2024,7,,0.0375,0.15,0.29, 
+Command,R,Cohere,OpenRouter/cohere/command-r,SelfHosted/CohereLabs/c4ai-command-r-v01,35,,0.15,0.6,0.34, +Command,R+,Cohere,OpenRouter/cohere/command-r-plus,SelfHosted/CohereLabs/c4ai-command-r-plus,104,,3,15,0.43, +DeepSeek Coder Instruct,1.3B,DeepSeek,,SelfHosted/deepseek-ai/deepseek-coder-1.3b-instruct,1.3,,,,, +DeepSeek Coder Instruct,6.7B,DeepSeek,,SelfHosted/deepseek-ai/deepseek-coder-6.7b-instruct,6.7,,,,, +DeepSeek Coder Instruct,33B,DeepSeek,,SelfHosted/deepseek-ai/deepseek-coder-33b-instruct,33,,,,, +DeepSeek Coder v2,Lite,DeepSeek,,SelfHosted/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,16,,,,0.42, +DeepSeek Coder v2,,DeepSeek,,SelfHosted/deepseek-ai/DeepSeek-Coder-V2-Instruct,236,,,,0.64, +DeepSeek V2.5,,DeepSeek,,SelfHosted/deepseek-ai/DeepSeek-V2.5,236,17.892,,,, +DeepSeek V3,,DeepSeek,OpenRouter/deepseek/deepseek-chat-v3-0324,SelfHosted/deepseek-ai/DeepSeek-V3-0324,671,34.078,0.27,1.1,0.76, +Gemma 2,2B,Google DeepMind,,SelfHosted/google/gemma-2-2b-it,2,,,,0.51, +Gemma 2,9B,Google DeepMind,OpenRouter/google/gemma-2-9b-it,SelfHosted/google/gemma-2-9b-it,9,4.32,0.2,0.2,0.5, +Gemma 2,27B,Google DeepMind,OpenRouter/google/gemma-2-27b-it,SelfHosted/google/gemma-2-27b-it,27,21.06,0.8,0.8,0.57, +Gemma 3,1B,Google DeepMind,,SelfHosted/google/gemma-3-1b-it,1,,,,0.1, +Gemma 3,4B,Google DeepMind,OpenRouter/google/gemma-3-4b-it,SelfHosted/google/gemma-3-4b-it,4,,0.02,0.04,0.42, +Gemma 3,12B,Google DeepMind,OpenRouter/google/gemma-3-12b-it,SelfHosted/google/gemma-3-12b-it,12,,0.05,0.1,0.6, +Gemma 3,27B,Google DeepMind,OpenRouter/google/gemma-3-27b-it,SelfHosted/google/gemma-3-27b-it,27,,0.2,0.4,0.67, +Llama 3.1 Instruct,8B,Meta,OpenRouter/meta-llama/llama-3.1-8b-instruct,SelfHosted/meta-llama/Llama-3.1-8B-Instruct,8,12.24,0.02,0.05,0.44,yes +Llama 3.1 Instruct,70B,Meta,OpenRouter/meta-llama/llama-3.1-70b-instruct,SelfHosted/meta-llama/Llama-3.1-70B-Instruct,70,79.29,0.12,0.3,0.56,yes +Llama 3.1 Instruct,405B,Meta,OpenRouter/meta-llama/llama-3.1-405b-instruct,SelfHosted/meta-llama/Llama-3.1-405B-Instruct,405,380,0.8,0.8,0.73,yes +Llama 3.2 Instruct,1B,Meta,OpenRouter/meta-llama/llama-3.2-1b-instruct,SelfHosted/meta-llama/Llama-3.2-1B-Instruct,1,,0.01,0.01,0.12,yes +Llama 3.2 Instruct,3B,Meta,OpenRouter/meta-llama/llama-3.2-3b-instruct,SelfHosted/meta-llama/Llama-3.2-3B-Instruct,3,1.7334,0.015,0.025,0.22,yes +Llama 3.3 Instruct,70B,Meta,OpenRouter/meta-llama/llama-3.3-70b-instruct,SelfHosted/meta-llama/Llama-3.3-70B-Instruct,70,68.649768,0.12,0.3,0.66,yes +Llama 4,Scout,Meta,OpenRouter/meta-llama/llama-4-scout,SelfHosted/meta-llama/Llama-4-Scout-17B-16E-Instruct,109,,0.1,0.4,0.74,yes +Llama 4,Maverick,Meta,OpenRouter/meta-llama/llama-4-maverick,SelfHosted/meta-llama/Llama-4-Maverick-17B-128E-Instruct,400,,0.2,0.8,0.81,yes +Phi 3,Mini,Microsoft,OpenRouter/microsoft/phi-3-mini-128k-instruct,SelfHosted/microsoft/Phi-3-mini-128k-instruct,3.8,,0.1,0.1,0.44, +Phi 3,Small,Microsoft,,SelfHosted/microsoft/Phi-3-small-128k-instruct,7,2.1312,,,, +Phi 3,Medium,Microsoft,OpenRouter/microsoft/phi-3-medium-128k-instruct,SelfHosted/microsoft/Phi-3-medium-128k-instruct,14,4.032,1,1,0.52, +Phi 4,Mini,Microsoft,,SelfHosted/microsoft/Phi-4-mini-instruct,3.8,,,,0.53, +Phi 4,,Microsoft,OpenRouter/microsoft/phi-4,SelfHosted/microsoft/phi-4,14,9.3202015,0.07,0.14,0.7, +Codestral,,Mistral,OpenRouter/mistralai/codestral-2501,SelfHosted/mistralai/Codestral-22B-v0.1,22,,0.3,0.9,0.45, +Ministral,3B,Mistral,OpenRouter/mistralai/ministral-3b,SelfHosted/ministral/Ministral-3b-instruct,3,,0.04,0.04,0.34, 
+Ministral,8B,Mistral,OpenRouter/mistralai/ministral-8b,SelfHosted/mistralai/Ministral-8B-Instruct-2410,8,,0.1,0.1,0.39, +Mistral Instruct,7B,Mistral,OpenRouter/mistralai/mistral-7b-instruct,SelfHosted/mistralai/Mistral-7B-Instruct-v0.3,7,,0.03,0.055,,yes +Mixtral Instruct (8 Experts),8x7B,Mistral,OpenRouter/mistralai/mixtral-8x7b-instruct,SelfHosted/mistralai/Mixtral-8x7B-Instruct-v0.1,56,,0.24,0.24,, +Mixtral Instruct (8 Experts),8x22B,Mistral,OpenRouter/mistralai/mixtral-8x22b-instruct,SelfHosted/mistralai/Mixtral-8x22B-Instruct-v0.1,176,,0.9,0.9,, +Mistral Large 2,,Mistral,OpenRouter/mistralai/mistral-large-2411,SelfHosted/mistralai/Pixtral-Large-Instruct-2411,123,213,2,6,, +Mistral Small 3,,Mistral,OpenRouter/mistralai/mistral-small-24b-instruct-2501,SelfHosted/mistralai/Mistral-Small-24B-Instruct-2501,24,11.52,0.1,0.3,, +GPT 4o,mini,OpenAI,OpenAI/gpt-4o-mini-2024-07-18,,,381.0001,2.5,10,0.63, +GPT 4o,,OpenAI,OpenAI/gpt-4o-2024-08-06,,,73.6001,0.150,0.6,0.75, +GPT 4.1,nano,OpenAI,OpenAI/gpt-4.1-nano-2025-04-14,,,,0.1,0.4,0.8, +GPT 4.1,mini,OpenAI,OpenAI/gpt-4.1-mini-2025-04-14,,,,0.4,1.6,0.88, +GPT 4.1,,OpenAI,OpenAI/gpt-4.1-2025-04-14,,,,2,8,0.9, +GPT o1,,OpenAI,OpenAI/o1-2024-12-17,,,,15,60,0.89, +GPT o3,,OpenAI,OpenAI/o3-2025-04-16,,,,2,8,0.85, +GPT o3,mini,OpenAI,OpenAI/o3-mini-2025-01-31,,,,1.1,4.4,0.87, +GPT o4,mini,OpenAI,OpenAI/o4-mini-2025-04-16,,,,1.1,4.4,0.83, \ No newline at end of file diff --git a/doc/docs/guides/installation.rst b/doc/docs/guides/installation.rst index 79e8f34..7bf7183 100644 --- a/doc/docs/guides/installation.rst +++ b/doc/docs/guides/installation.rst @@ -58,6 +58,7 @@ Installation Steps .. code-tab:: bash Hosting the LM Server + uv sync uv sync --extra lm-server .. code-tab:: bash With ``pip`` diff --git a/nip/language_model_server/trainers/dpo.py b/nip/language_model_server/trainers/dpo.py index 4d84839..e866c30 100644 --- a/nip/language_model_server/trainers/dpo.py +++ b/nip/language_model_server/trainers/dpo.py @@ -9,6 +9,8 @@ from datasets import Dataset +import torch + from trl import DPOConfig, DPOTrainer from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE @@ -22,6 +24,7 @@ from nip.utils.env import get_env_var from nip.utils.types import HuggingFaceDpoDatasetItem from nip.utils.hugging_face import is_model_peft +from nip.utils.language_model_database import LanguageModelDatabase from nip.language_model_server.types import ( LmTrainingConfig, LmLoraAdapterConfig, @@ -209,6 +212,33 @@ def train(config: LmTrainingConfig, dataset: Dataset, job_id: str, new_model_nam # The maximum length for a W&B job name is 128 characters. job_name = job_id[:127] + language_model_db = LanguageModelDatabase() + + is_peft = is_model_peft(config.model_name) + + if is_peft: + model_lora_config = LoraConfig.from_pretrained(config.model_name) + base_model_name = model_lora_config.base_model_name_or_path + else: + base_model_name = config.model_name + + lm_db_entry = language_model_db.get_by_model_provider_and_name( + "SelfHosted", base_model_name + ) + + torch_dtype = None + use_flash_attention_2 = ( + lm_db_entry.has_flash_attention_2 and config.mixed_precision in ("fp16", "bf16") + ) + if use_flash_attention_2: + logger.info(f"Using Flash Attention 2 for {base_model_name!r}.") + if config.mixed_precision == "fp16": + torch_dtype = torch.float16 + elif config.mixed_precision == "bf16": + torch_dtype = torch.bfloat16 + + # Only use padding-free batching if Flash Attention 2 is available, to avoid batch + # contamination issues. 
dpo_config = DPOConfig( **config.dpo_config.model_dump(), hub_model_id=new_model_name, @@ -220,15 +250,12 @@ def train(config: LmTrainingConfig, dataset: Dataset, job_id: str, new_model_nam per_device_train_batch_size=config.per_device_train_batch_size, use_liger_kernel=config.use_liger_kernel, seed=config.seed, + padding_free=use_flash_attention_2, ) ignore_training_lora_config = False - if not is_model_peft(config.model_name): - model = AutoModelForCausalLM.from_pretrained(config.model_name) - - else: - model_lora_config = LoraConfig.from_pretrained(config.model_name) + if is_peft: # When reusing the LoRA adapter, make sure the model's LoRA configuration is # compatible with the training configuration. @@ -251,7 +278,10 @@ def train(config: LmTrainingConfig, dataset: Dataset, job_id: str, new_model_nam ) model = AutoPeftModelForCausalLM.from_pretrained( - config.model_name, is_trainable=True + config.model_name, + is_trainable=True, + torch_dtype=torch_dtype, + use_flash_attention_2=use_flash_attention_2, ) # Sanity check: ensure that exactly the LoRA layers are trainable. @@ -275,6 +305,14 @@ def train(config: LmTrainingConfig, dataset: Dataset, job_id: str, new_model_nam # already LoRA-adapted and the trainer will train the existing adapter. ignore_training_lora_config = True + else: + + model = AutoModelForCausalLM.from_pretrained( + config.model_name, + torch_dtype=torch_dtype, + use_flash_attention_2=use_flash_attention_2, + ) + if ignore_training_lora_config or config.training_lora_config is None: training_lora_config = None else: diff --git a/nip/utils/language_model_database.py b/nip/utils/language_model_database.py index 69c8549..b623ae8 100644 --- a/nip/utils/language_model_database.py +++ b/nip/utils/language_model_database.py @@ -3,7 +3,7 @@ This database holds metadata about each model, along with how to access them. """ -from typing import Annotated, Optional +from typing import Annotated, Optional, Union, get_origin, get_args from dataclasses import dataclass import dataclasses @@ -32,6 +32,7 @@ class LanguageModelDbEntry: mmlu_pro_score: Annotated[Optional[float], "MMLU-Pro"] = None openrouter_input_cost: Annotated[Optional[float], "OpenRouter Input Cost"] = None openrouter_output_cost: Annotated[Optional[float], "OpenRouter Output Cost"] = None + flash_attention_2: Annotated[Optional[bool], "Flash Attention 2"] = None @property def provider(self) -> str: @@ -50,6 +51,11 @@ def display_name(self) -> str: return f"{self.model_series} {self.model_name}" return self.model_series + @property + def has_flash_attention_2(self) -> bool: + """Whether the model supports Flash Attention 2.""" + return bool(self.flash_attention_2) + @classmethod def from_row(cls, row: pd.Series) -> "LanguageModelDbEntry": """Create a LanguageModelDbEntry from a database Pandas Series. 
@@ -68,12 +74,18 @@ def from_row(cls, row: pd.Series) -> "LanguageModelDbEntry": arguments = {} for field in dataclasses.fields(cls): header = field.type.__metadata__[0] + field_type = field.type.__origin__ value = row[header] if pd.isna(value): - if field.type.__origin__ is str: + if field_type is str: arguments[field.name] = "" else: arguments[field.name] = None + elif get_origin(field_type) is Union and bool in get_args(field_type): + if value == "yes": + arguments[field.name] = True + else: + arguments[field.name] = False elif isinstance(value, np.floating): arguments[field.name] = float(value) else: @@ -101,28 +113,30 @@ class LanguageModelDatabase: def __init__(self): self._db = pd.read_csv(LANGUAGE_MODEL_DB_DIR) - def get_by_agent_params( - self, agent_params: PureTextAgentParameters + def get_by_model_provider_and_name( + self, model_provider: str, model_name: str ) -> LanguageModelDbEntry: - """Find a language model entry for a given set of agent hyper-parameters. + """Find a language model entry by its provider and name. Parameters ---------- - agent_params : PureTextAgentParameters - The agent hyper-parameters to search for in the database + model_provider : str + The provider of the model (e.g., "OpenAI", "OpenRouter", "SelfHosted") + model_name : str + The name of the model (e.g., "qwen/qwen2.5-32b-instruct") Returns ------- LanguageModelDbEntry - The language model entry corresponding to the hyper parameters + The language model entry corresponding to the provider and name Raises ------ LanguageModelNotFound - If no entry is found in the database for the given hyper-parameters + If no entry is found in the database for the given provider and name """ - uri = f"{agent_params.model_provider}/{agent_params.model_name}" + uri = f"{model_provider}/{model_name}" if uri in self._db["API URI"].values: entry = self._db[self._db["API URI"] == uri].iloc[0] elif uri in self._db["Self Hosted URI"].values: @@ -132,6 +146,31 @@ def get_by_agent_params( return LanguageModelDbEntry.from_row(entry) + def get_by_agent_params( + self, agent_params: PureTextAgentParameters + ) -> LanguageModelDbEntry: + """Find a language model entry for a given set of agent hyper-parameters. 
+ + Parameters + ---------- + agent_params : PureTextAgentParameters + The agent hyper-parameters to search for in the database + + Returns + ------- + LanguageModelDbEntry + The language model entry corresponding to the hyper parameters + + Raises + ------ + LanguageModelNotFound + If no entry is found in the database for the given hyper-parameters + """ + + return self.get_by_model_provider_and_name( + agent_params.model_provider, agent_params.model_name + ) + def get_by_hyper_params( self, hyper_params: HyperParameters, agent_name: str ) -> LanguageModelDbEntry: diff --git a/pyproject.toml b/pyproject.toml index c5b3fdc..739e320 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ lm-server = [ "bitsandbytes==0.46.0", "deepspeed==0.17.1", "fastapi==0.115.6", + "flash-attn==2.8.0.post2", "liger-kernel==0.5.10", "loralib==0.1.2", "pydantic-settings==2.9.1", @@ -84,19 +85,7 @@ dev = [ [tool.uv] add-bounds = "exact" - -# [[tool.uv.index]] -# name = "pytorch-cu121" -# url = "https://download.pytorch.org/whl/cu121" -# explicit = true - -# [tool.uv.sources] -# torch = [ -# { index = "pytorch-cu121", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, -# ] -# torchvision = [ -# { index = "pytorch-cu121", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, -# ] +no-build-isolation-package = ["flash-attn"] [tool.black] extend-exclude = "nip/utils/runtime_module.py|doc/extensions" diff --git a/uv.lock b/uv.lock index d1acbac..50861fd 100644 --- a/uv.lock +++ b/uv.lock @@ -972,6 +972,16 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dc/43/d5147aadaa52558e94e024811f2f9543b4bd7203b3a9659eeb5dff9c61b3/flake8-7.1.0-py2.py3-none-any.whl", hash = "sha256:2e416edcc62471a64cea09353f4e7bdba32aeb079b6e360554c659a122b1bc6a", size = 57569, upload-time = "2024-06-15T21:37:05.342Z" }, ] +[[package]] +name = "flash-attn" +version = "2.8.0.post2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "einops" }, + { name = "torch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/32/5c/c7610beeb2fc0e70d0c09a93490bb2d07fb6c8fa1f80ef9617b0cd556d76/flash_attn-2.8.0.post2.tar.gz", hash = "sha256:b51d7015eb78f7ab2c332ec7c36681d7e97834f010b52eb4db1f118351982f58", size = 7857730, upload-time = "2025-06-15T01:39:32.219Z" } + [[package]] name = "fonttools" version = "4.58.1" @@ -2435,6 +2445,7 @@ lm-server = [ { name = "bitsandbytes" }, { name = "deepspeed" }, { name = "fastapi" }, + { name = "flash-attn" }, { name = "liger-kernel" }, { name = "loralib" }, { name = "pydantic-settings" }, @@ -2476,6 +2487,7 @@ requires-dist = [ { name = "deepspeed", marker = "extra == 'lm-server'", specifier = "==0.17.1" }, { name = "einops", specifier = "==0.8.0" }, { name = "fastapi", marker = "extra == 'lm-server'", specifier = "==0.115.6" }, + { name = "flash-attn", marker = "extra == 'lm-server'", specifier = "==2.8.0.post2" }, { name = "gitpython", specifier = "==3.1.44" }, { name = "httpx", specifier = "==0.27.0" }, { name = "huggingface-hub", specifier = "==0.32.4" }, From 7eb434e2fcb5cd4301957d6e6ed47bf92baef544 Mon Sep 17 00:00:00 2001 From: Sam Adam-Day Date: Wed, 25 Jun 2025 12:35:13 +0100 Subject: [PATCH 2/4] Using Ubuntu 22.04 in Dockerfile --- CHANGELOG.md | 1 + Dockerfile | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ee4c43..7e9d334 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,7 @@ to communicate reliably. name itself. 
- Launching language model server with `uvicorn` directly (rather than using `fastapi run`), which has allowed displaying more log messages. +- The Dockerfile now uses Ubuntu 22.04. ### Fixed diff --git a/Dockerfile b/Dockerfile index 7001f41..43115b9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1 -FROM nvidia/cuda:12.0.1-devel-ubuntu20.04 AS base +FROM nvidia/cuda:12.0.1-devel-ubuntu22.04 AS base # Ports for the language model server and vLLM server ARG LM_SERVER_PORT=5000 From 49671f926fbcfd21d05881e3ac60bd7c2b95b222 Mon Sep 17 00:00:00 2001 From: Sam Adam-Day Date: Wed, 25 Jun 2025 15:51:59 +0100 Subject: [PATCH 3/4] Using `attn_implementation` over `use_flash_attention_2` --- nip/language_model_server/trainers/dpo.py | 24 +++++++++++------------ 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/nip/language_model_server/trainers/dpo.py b/nip/language_model_server/trainers/dpo.py index e866c30..700ec24 100644 --- a/nip/language_model_server/trainers/dpo.py +++ b/nip/language_model_server/trainers/dpo.py @@ -212,8 +212,6 @@ def train(config: LmTrainingConfig, dataset: Dataset, job_id: str, new_model_nam # The maximum length for a W&B job name is 128 characters. job_name = job_id[:127] - language_model_db = LanguageModelDatabase() - is_peft = is_model_peft(config.model_name) if is_peft: @@ -222,20 +220,25 @@ def train(config: LmTrainingConfig, dataset: Dataset, job_id: str, new_model_nam else: base_model_name = config.model_name + language_model_db = LanguageModelDatabase() lm_db_entry = language_model_db.get_by_model_provider_and_name( "SelfHosted", base_model_name ) - - torch_dtype = None use_flash_attention_2 = ( lm_db_entry.has_flash_attention_2 and config.mixed_precision in ("fp16", "bf16") ) + + extra_model_kwargs = {} + if use_flash_attention_2: + logger.info(f"Using Flash Attention 2 for {base_model_name!r}.") + + extra_model_kwargs["attn_implementation"] = "flash_attention_2" if config.mixed_precision == "fp16": - torch_dtype = torch.float16 + extra_model_kwargs["torch_dtype"] = torch.float16 elif config.mixed_precision == "bf16": - torch_dtype = torch.bfloat16 + extra_model_kwargs["torch_dtype"] = torch.bfloat16 # Only use padding-free batching if Flash Attention 2 is available, to avoid batch # contamination issues. @@ -278,10 +281,7 @@ def train(config: LmTrainingConfig, dataset: Dataset, job_id: str, new_model_nam ) model = AutoPeftModelForCausalLM.from_pretrained( - config.model_name, - is_trainable=True, - torch_dtype=torch_dtype, - use_flash_attention_2=use_flash_attention_2, + config.model_name, is_trainable=True, **extra_model_kwargs ) # Sanity check: ensure that exactly the LoRA layers are trainable. 
@@ -308,9 +308,7 @@ def train(config: LmTrainingConfig, dataset: Dataset, job_id: str, new_model_nam else: model = AutoModelForCausalLM.from_pretrained( - config.model_name, - torch_dtype=torch_dtype, - use_flash_attention_2=use_flash_attention_2, + config.model_name, **extra_model_kwargs ) if ignore_training_lora_config or config.training_lora_config is None: From 4fb1c6540c6dd1ee04b6aa3aa1d7b94951de7c3f Mon Sep 17 00:00:00 2001 From: Sam Adam-Day Date: Wed, 25 Jun 2025 15:58:40 +0100 Subject: [PATCH 4/4] Got GitHub workflow working with flash_attn --- .github/actions/setup-nip/action.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/actions/setup-nip/action.yaml b/.github/actions/setup-nip/action.yaml index 6f8cf9a..3044f7c 100644 --- a/.github/actions/setup-nip/action.yaml +++ b/.github/actions/setup-nip/action.yaml @@ -19,5 +19,6 @@ runs: shell: bash - name: Install dependencies with uv run: | + uv sync --locked uv sync --locked --all-extras --dev shell: bash
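A minimal sketch (not taken from the patches themselves) of the model-loading path this series introduces: Flash Attention 2 is requested only when the language-model database marks the base model as supported and training runs in fp16 or bf16, a matching torch dtype is set alongside it, and padding-free batching is gated on the same condition. The model name, the literal supports_fa2 value, and the output directory below are illustrative placeholders; in the patched trainer they come from LanguageModelDatabase and the training config.

import torch
from transformers import AutoModelForCausalLM
from trl import DPOConfig

base_model_name = "Qwen/Qwen2.5-7B-Instruct"  # illustrative; marked "yes" for Flash Attention 2 in the database
mixed_precision = "bf16"                      # would come from the training config
supports_fa2 = True                           # would come from the database lookup

use_flash_attention_2 = supports_fa2 and mixed_precision in ("fp16", "bf16")

extra_model_kwargs = {}
if use_flash_attention_2:
    # Flash Attention 2 needs a half-precision dtype, so set it together with
    # the attention implementation.
    extra_model_kwargs["attn_implementation"] = "flash_attention_2"
    extra_model_kwargs["torch_dtype"] = (
        torch.float16 if mixed_precision == "fp16" else torch.bfloat16
    )

model = AutoModelForCausalLM.from_pretrained(base_model_name, **extra_model_kwargs)

# Padding-free batching is only enabled together with Flash Attention 2, to
# avoid batch-contamination issues, mirroring the guard in the patched trainer.
dpo_config = DPOConfig(output_dir="dpo-output", padding_free=use_flash_attention_2)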