1 change: 1 addition & 0 deletions .github/actions/setup-nip/action.yaml
@@ -19,5 +19,6 @@ runs:
shell: bash
- name: Install dependencies with uv
run: |
uv sync --locked
uv sync --locked --all-extras --dev
shell: bash
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -38,6 +38,8 @@ to communicate reliably.
- Ability to set vLLM max LoRA rank automatically.
- Agent-level hyper-parameter to enable quantisation for self-hosted models.
- Using Liger kernel in DPO training for increased speed and lower memory usage.
- Enabled Flash Attention 2 and padding-free batching when doing DPO training on
supported models.


### Changed
@@ -48,6 +50,7 @@ to communicate reliably.
name itself.
- Launching language model server with `uvicorn` directly (rather than using `fastapi
run`), which has allowed displaying more log messages.
- The Dockerfile now uses Ubuntu 22.04.


### Fixed
3 changes: 2 additions & 1 deletion Dockerfile
@@ -1,5 +1,5 @@
# syntax=docker/dockerfile:1
FROM nvidia/cuda:12.0.1-devel-ubuntu20.04 AS base
FROM nvidia/cuda:12.0.1-devel-ubuntu22.04 AS base

# Ports for the language model server and vLLM server
ARG LM_SERVER_PORT=5000
@@ -66,6 +66,7 @@ RUN grep timm== pyproject.toml \
| tar -xzC /root/neural-interactive-proofs/vendor

# Install all the required packages
RUN uv sync --locked
RUN uv sync --locked --extra lm-server

# The default target doesn't do much else
2 changes: 1 addition & 1 deletion README.md
@@ -54,7 +54,7 @@ for a more detailed installation guide.
you can do `uv sync --no-dev`.

If you want to host open-weight language models on your machine, install the
`lm-server` optional dependencies: `uv sync --extra lm-server`.
`lm-server` optional dependencies after the first sync: `uv sync --extra lm-server`.

If using `pip`, create a virtual environment, activate it, then run:

128 changes: 64 additions & 64 deletions databases/language_models.csv
@@ -1,64 +1,64 @@
Model Series,Model Name,Developer,API URI,Self Hosted URI,Parameters (10E+9),FLOPs (10E+23),OpenRouter Input Cost,OpenRouter Output Cost,MMLU-Pro
Jamba 1.6,Mini,AI21,OpenRouter/ai21/jamba-1.6-mini,SelfHosted/ai21labs/AI21-Jamba-Mini-1.6,52,,0.2,0.4,0.37
Jamba 1.6,Large,AI21,OpenRouter/ai21/jamba-1.6-large,SelfHosted/ai21labs/AI21-Jamba-Large-1.6,398,,2,8,0.56
Qwen 2.5 Instruct,0.5B,Alibaba,,SelfHosted/Qwen/Qwen2.5-0.5B-Instruct,0.5,,,,0.15
Qwen 2.5 Instruct,1.5B,Alibaba,,SelfHosted/Qwen/Qwen2.5-1.5B-Instruct,1.5,1.6632,,,0.32
Qwen 2.5 Instruct,3B,Alibaba,,SelfHosted/Qwen/Qwen2.5-3B-Instruct,3,3.3372,,,0.44
Qwen 2.5 Instruct,7B,Alibaba,OpenRouter/qwen/qwen-2.5-7b-instruct,SelfHosted/Qwen/Qwen2.5-7B-Instruct,7,8.2188,0.3,0.3,0.45
Qwen 2.5 Instruct,14B,Alibaba,,SelfHosted/Qwen/Qwen2.5-14B-Instruct,14,15.876,,,0.64
Qwen 2.5 Instruct,32B,Alibaba,OpenRouter/qwen/qwen2.5-32b-instruct,SelfHosted/Qwen/Qwen2.5-32B-Instruct,32,35.1,0.79,0.79,0.7
Qwen 2.5 Instruct,72B,Alibaba,OpenRouter/qwen/qwen-2.5-72b-instruct,SelfHosted/Qwen/Qwen2.5-72B-Instruct,72,78,0.7,0.7,0.72
Claude 3,Haiku,Anthropic,OpenRouter/anthropic/claude-3-haiku,,,,0.25,1.25,
Claude 3,Sonnet,Anthropic,OpenRouter/anthropic/claude-3-sonnet,,,,3,15,0.58
Claude 3,Opus,Anthropic,OpenRouter/anthropic/claude-3-opus,,,164.0001,15,75,0.7
Claude 3.5,Sonnet,Anthropic,OpenRouter/anthropic/claude-3.5-sonnet,,,365.0001,0.8,4,0.77
Claude 3.7,Sonnet,Anthropic,OpenRouter/anthropic/claude-3.7-sonnet,,,335,3,15,0.8
Claude 4,Sonnet,Anthropic,OpenRouter/anthropic/claude-sonnet-4,,,,3,15,0.87
Claude 4,Opus,Anthropic,OpenRouter/anthropic/claude-sonnet-4,,,,15,75,0.89
Command,R7B,Cohere,OpenRouter/cohere/command-r7b-12-2024,SelfHosted/CohereLabs/c4ai-command-r7b-12-2024,7,,0.0375,0.15,0.29
Command,R,Cohere,OpenRouter/cohere/command-r,SelfHosted/CohereLabs/c4ai-command-r-v01,35,,0.15,0.6,0.34
Command,R+,Cohere,OpenRouter/cohere/command-r-plus,SelfHosted/CohereLabs/c4ai-command-r-plus,104,,3,15,0.43
DeepSeek Coder Instruct,1.3B,DeepSeek,,SelfHosted/deepseek-ai/deepseek-coder-1.3b-instruct,1.3,,,,
DeepSeek Coder Instruct,6.7B,DeepSeek,,SelfHosted/deepseek-ai/deepseek-coder-6.7b-instruct,6.7,,,,
DeepSeek Coder Instruct,33B,DeepSeek,,SelfHosted/deepseek-ai/deepseek-coder-33b-instruct,33,,,,
DeepSeek Coder v2,Lite,DeepSeek,,SelfHosted/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,16,,,,0.42
DeepSeek Coder v2,,DeepSeek,,SelfHosted/deepseek-ai/DeepSeek-Coder-V2-Instruct,236,,,,0.64
DeepSeek V2.5,,DeepSeek,,SelfHosted/deepseek-ai/DeepSeek-V2.5,236,17.892,,,
DeepSeek V3,,DeepSeek,OpenRouter/deepseek/deepseek-chat-v3-0324,SelfHosted/deepseek-ai/DeepSeek-V3-0324,671,34.078,0.27,1.1,0.76
Gemma 2,2B,Google DeepMind,,SelfHosted/google/gemma-2-2b-it,2,,,,0.51
Gemma 2,9B,Google DeepMind,OpenRouter/google/gemma-2-9b-it,SelfHosted/google/gemma-2-9b-it,9,4.32,0.2,0.2,0.5
Gemma 2,27B,Google DeepMind,OpenRouter/google/gemma-2-27b-it,SelfHosted/google/gemma-2-27b-it,27,21.06,0.8,0.8,0.57
Gemma 3,1B,Google DeepMind,,SelfHosted/google/gemma-3-1b-it,1,,,,0.1
Gemma 3,4B,Google DeepMind,OpenRouter/google/gemma-3-4b-it,SelfHosted/google/gemma-3-4b-it,4,,0.02,0.04,0.42
Gemma 3,12B,Google DeepMind,OpenRouter/google/gemma-3-12b-it,SelfHosted/google/gemma-3-12b-it,12,,0.05,0.1,0.6
Gemma 3,27B,Google DeepMind,OpenRouter/google/gemma-3-27b-it,SelfHosted/google/gemma-3-27b-it,27,,0.2,0.4,0.67
Llama 3.1 Instruct,8B,Meta,OpenRouter/meta-llama/llama-3.1-8b-instruct,SelfHosted/meta-llama/Llama-3.1-8B-Instruct,8,12.24,0.02,0.05,0.44
Llama 3.1 Instruct,70B,Meta,OpenRouter/meta-llama/llama-3.1-70b-instruct,SelfHosted/meta-llama/Llama-3.1-70B-Instruct,70,79.29,0.12,0.3,0.56
Llama 3.1 Instruct,405B,Meta,OpenRouter/meta-llama/llama-3.1-405b-instruct,SelfHosted/meta-llama/Llama-3.1-405B-Instruct,405,380,0.8,0.8,0.73
Llama 3.2 Instruct,1B,Meta,OpenRouter/meta-llama/llama-3.2-1b-instruct,SelfHosted/meta-llama/Llama-3.2-1B-Instruct,1,,0.01,0.01,0.12
Llama 3.2 Instruct,3B,Meta,OpenRouter/meta-llama/llama-3.2-3b-instruct,SelfHosted/meta-llama/Llama-3.2-3B-Instruct,3,1.7334,0.015,0.025,0.22
Llama 3.3 Instruct,70B,Meta,OpenRouter/meta-llama/llama-3.3-70b-instruct,SelfHosted/meta-llama/Llama-3.3-70B-Instruct,70,68.649768,0.12,0.3,0.66
Llama 4,Scout,Meta,OpenRouter/meta-llama/llama-4-scout,SelfHosted/meta-llama/Llama-4-Scout-17B-16E-Instruct,109,,0.1,0.4,0.74
Llama 4,Maverick,Meta,OpenRouter/meta-llama/llama-4-maverick,SelfHosted/meta-llama/Llama-4-Maverick-17B-128E-Instruct,400,,0.2,0.8,0.81
Phi 3,Mini,Microsoft,OpenRouter/microsoft/phi-3-mini-128k-instruct,SelfHosted/microsoft/Phi-3-mini-128k-instruct,3.8,,0.1,0.1,0.44
Phi 3,Small,Microsoft,,SelfHosted/microsoft/Phi-3-small-128k-instruct,7,2.1312,,,
Phi 3,Medium,Microsoft,OpenRouter/microsoft/phi-3-medium-128k-instruct,SelfHosted/microsoft/Phi-3-medium-128k-instruct,14,4.032,1,1,0.52
Phi 4,Mini,Microsoft,,SelfHosted/microsoft/Phi-4-mini-instruct,3.8,,,,0.53
Phi 4,,Microsoft,OpenRouter/microsoft/phi-4,SelfHosted/microsoft/phi-4,14,9.3202015,0.07,0.14,0.7
Codestral,,Mistral,OpenRouter/mistralai/codestral-2501,SelfHosted/mistralai/Codestral-22B-v0.1,22,,0.3,0.9,0.45
Ministral,3B,Mistral,OpenRouter/mistralai/ministral-3b,SelfHosted/ministral/Ministral-3b-instruct,3,,0.04,0.04,0.34
Ministral,8B,Mistral,OpenRouter/mistralai/ministral-8b,SelfHosted/mistralai/Ministral-8B-Instruct-2410,8,,0.1,0.1,0.39
Mistral Instruct,7B,Mistral,OpenRouter/mistralai/mistral-7b-instruct,SelfHosted/mistralai/Mistral-7B-Instruct-v0.3,7,,0.03,0.055,
Mixtral Instruct (8 Experts),8x7B,Mistral,OpenRouter/mistralai/mixtral-8x7b-instruct,SelfHosted/mistralai/Mixtral-8x7B-Instruct-v0.1,56,,0.24,0.24,
Mixtral Instruct (8 Experts),8x22B,Mistral,OpenRouter/mistralai/mixtral-8x22b-instruct,SelfHosted/mistralai/Mixtral-8x22B-Instruct-v0.1,176,,0.9,0.9,
Mistral Large 2,,Mistral,OpenRouter/mistralai/mistral-large-2411,SelfHosted/mistralai/Pixtral-Large-Instruct-2411,123,213,2,6,
Mistral Small 3,,Mistral,OpenRouter/mistralai/mistral-small-24b-instruct-2501,SelfHosted/mistralai/Mistral-Small-24B-Instruct-2501,24,11.52,0.1,0.3,
GPT 4o,mini,OpenAI,OpenAI/gpt-4o-mini-2024-07-18,,,381.0001,2.5,10,0.63
GPT 4o,,OpenAI,OpenAI/gpt-4o-2024-08-06,,,73.6001,0.150,0.6,0.75
GPT 4.1,nano,OpenAI,OpenAI/gpt-4.1-nano-2025-04-14,,,,0.1,0.4,0.8
GPT 4.1,mini,OpenAI,OpenAI/gpt-4.1-mini-2025-04-14,,,,0.4,1.6,0.88
GPT 4.1,,OpenAI,OpenAI/gpt-4.1-2025-04-14,,,,2,8,0.9
GPT o1,,OpenAI,OpenAI/o1-2024-12-17,,,,15,60,0.89
GPT o3,,OpenAI,OpenAI/o3-2025-04-16,,,,2,8,0.85
GPT o3,mini,OpenAI,OpenAI/o3-mini-2025-01-31,,,,1.1,4.4,0.87
GPT o4,mini,OpenAI,OpenAI/o4-mini-2025-04-16,,,,1.1,4.4,0.83
Model Series,Model Name,Developer,API URI,Self Hosted URI,Parameters (10E+9),FLOPs (10E+23),OpenRouter Input Cost,OpenRouter Output Cost,MMLU-Pro,Flash Attention 2
Jamba 1.6,Mini,AI21,OpenRouter/ai21/jamba-1.6-mini,SelfHosted/ai21labs/AI21-Jamba-Mini-1.6,52,,0.2,0.4,0.37,
Jamba 1.6,Large,AI21,OpenRouter/ai21/jamba-1.6-large,SelfHosted/ai21labs/AI21-Jamba-Large-1.6,398,,2,8,0.56,
Qwen 2.5 Instruct,0.5B,Alibaba,,SelfHosted/Qwen/Qwen2.5-0.5B-Instruct,0.5,,,,0.15,yes
Qwen 2.5 Instruct,1.5B,Alibaba,,SelfHosted/Qwen/Qwen2.5-1.5B-Instruct,1.5,1.6632,,,0.32,yes
Qwen 2.5 Instruct,3B,Alibaba,,SelfHosted/Qwen/Qwen2.5-3B-Instruct,3,3.3372,,,0.44,yes
Qwen 2.5 Instruct,7B,Alibaba,OpenRouter/qwen/qwen-2.5-7b-instruct,SelfHosted/Qwen/Qwen2.5-7B-Instruct,7,8.2188,0.3,0.3,0.45,yes
Qwen 2.5 Instruct,14B,Alibaba,,SelfHosted/Qwen/Qwen2.5-14B-Instruct,14,15.876,,,0.64,yes
Qwen 2.5 Instruct,32B,Alibaba,OpenRouter/qwen/qwen2.5-32b-instruct,SelfHosted/Qwen/Qwen2.5-32B-Instruct,32,35.1,0.79,0.79,0.7,yes
Qwen 2.5 Instruct,72B,Alibaba,OpenRouter/qwen/qwen-2.5-72b-instruct,SelfHosted/Qwen/Qwen2.5-72B-Instruct,72,78,0.7,0.7,0.72,yes
Claude 3,Haiku,Anthropic,OpenRouter/anthropic/claude-3-haiku,,,,0.25,1.25,,
Claude 3,Sonnet,Anthropic,OpenRouter/anthropic/claude-3-sonnet,,,,3,15,0.58,
Claude 3,Opus,Anthropic,OpenRouter/anthropic/claude-3-opus,,,164.0001,15,75,0.7,
Claude 3.5,Sonnet,Anthropic,OpenRouter/anthropic/claude-3.5-sonnet,,,365.0001,0.8,4,0.77,
Claude 3.7,Sonnet,Anthropic,OpenRouter/anthropic/claude-3.7-sonnet,,,335,3,15,0.8,
Claude 4,Sonnet,Anthropic,OpenRouter/anthropic/claude-sonnet-4,,,,3,15,0.87,
Claude 4,Opus,Anthropic,OpenRouter/anthropic/claude-sonnet-4,,,,15,75,0.89,
Command,R7B,Cohere,OpenRouter/cohere/command-r7b-12-2024,SelfHosted/CohereLabs/c4ai-command-r7b-12-2024,7,,0.0375,0.15,0.29,
Command,R,Cohere,OpenRouter/cohere/command-r,SelfHosted/CohereLabs/c4ai-command-r-v01,35,,0.15,0.6,0.34,
Command,R+,Cohere,OpenRouter/cohere/command-r-plus,SelfHosted/CohereLabs/c4ai-command-r-plus,104,,3,15,0.43,
DeepSeek Coder Instruct,1.3B,DeepSeek,,SelfHosted/deepseek-ai/deepseek-coder-1.3b-instruct,1.3,,,,,
DeepSeek Coder Instruct,6.7B,DeepSeek,,SelfHosted/deepseek-ai/deepseek-coder-6.7b-instruct,6.7,,,,,
DeepSeek Coder Instruct,33B,DeepSeek,,SelfHosted/deepseek-ai/deepseek-coder-33b-instruct,33,,,,,
DeepSeek Coder v2,Lite,DeepSeek,,SelfHosted/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,16,,,,0.42,
DeepSeek Coder v2,,DeepSeek,,SelfHosted/deepseek-ai/DeepSeek-Coder-V2-Instruct,236,,,,0.64,
DeepSeek V2.5,,DeepSeek,,SelfHosted/deepseek-ai/DeepSeek-V2.5,236,17.892,,,,
DeepSeek V3,,DeepSeek,OpenRouter/deepseek/deepseek-chat-v3-0324,SelfHosted/deepseek-ai/DeepSeek-V3-0324,671,34.078,0.27,1.1,0.76,
Gemma 2,2B,Google DeepMind,,SelfHosted/google/gemma-2-2b-it,2,,,,0.51,
Gemma 2,9B,Google DeepMind,OpenRouter/google/gemma-2-9b-it,SelfHosted/google/gemma-2-9b-it,9,4.32,0.2,0.2,0.5,
Gemma 2,27B,Google DeepMind,OpenRouter/google/gemma-2-27b-it,SelfHosted/google/gemma-2-27b-it,27,21.06,0.8,0.8,0.57,
Gemma 3,1B,Google DeepMind,,SelfHosted/google/gemma-3-1b-it,1,,,,0.1,
Gemma 3,4B,Google DeepMind,OpenRouter/google/gemma-3-4b-it,SelfHosted/google/gemma-3-4b-it,4,,0.02,0.04,0.42,
Gemma 3,12B,Google DeepMind,OpenRouter/google/gemma-3-12b-it,SelfHosted/google/gemma-3-12b-it,12,,0.05,0.1,0.6,
Gemma 3,27B,Google DeepMind,OpenRouter/google/gemma-3-27b-it,SelfHosted/google/gemma-3-27b-it,27,,0.2,0.4,0.67,
Llama 3.1 Instruct,8B,Meta,OpenRouter/meta-llama/llama-3.1-8b-instruct,SelfHosted/meta-llama/Llama-3.1-8B-Instruct,8,12.24,0.02,0.05,0.44,yes
Llama 3.1 Instruct,70B,Meta,OpenRouter/meta-llama/llama-3.1-70b-instruct,SelfHosted/meta-llama/Llama-3.1-70B-Instruct,70,79.29,0.12,0.3,0.56,yes
Llama 3.1 Instruct,405B,Meta,OpenRouter/meta-llama/llama-3.1-405b-instruct,SelfHosted/meta-llama/Llama-3.1-405B-Instruct,405,380,0.8,0.8,0.73,yes
Llama 3.2 Instruct,1B,Meta,OpenRouter/meta-llama/llama-3.2-1b-instruct,SelfHosted/meta-llama/Llama-3.2-1B-Instruct,1,,0.01,0.01,0.12,yes
Llama 3.2 Instruct,3B,Meta,OpenRouter/meta-llama/llama-3.2-3b-instruct,SelfHosted/meta-llama/Llama-3.2-3B-Instruct,3,1.7334,0.015,0.025,0.22,yes
Llama 3.3 Instruct,70B,Meta,OpenRouter/meta-llama/llama-3.3-70b-instruct,SelfHosted/meta-llama/Llama-3.3-70B-Instruct,70,68.649768,0.12,0.3,0.66,yes
Llama 4,Scout,Meta,OpenRouter/meta-llama/llama-4-scout,SelfHosted/meta-llama/Llama-4-Scout-17B-16E-Instruct,109,,0.1,0.4,0.74,yes
Llama 4,Maverick,Meta,OpenRouter/meta-llama/llama-4-maverick,SelfHosted/meta-llama/Llama-4-Maverick-17B-128E-Instruct,400,,0.2,0.8,0.81,yes
Phi 3,Mini,Microsoft,OpenRouter/microsoft/phi-3-mini-128k-instruct,SelfHosted/microsoft/Phi-3-mini-128k-instruct,3.8,,0.1,0.1,0.44,
Phi 3,Small,Microsoft,,SelfHosted/microsoft/Phi-3-small-128k-instruct,7,2.1312,,,,
Phi 3,Medium,Microsoft,OpenRouter/microsoft/phi-3-medium-128k-instruct,SelfHosted/microsoft/Phi-3-medium-128k-instruct,14,4.032,1,1,0.52,
Phi 4,Mini,Microsoft,,SelfHosted/microsoft/Phi-4-mini-instruct,3.8,,,,0.53,
Phi 4,,Microsoft,OpenRouter/microsoft/phi-4,SelfHosted/microsoft/phi-4,14,9.3202015,0.07,0.14,0.7,
Codestral,,Mistral,OpenRouter/mistralai/codestral-2501,SelfHosted/mistralai/Codestral-22B-v0.1,22,,0.3,0.9,0.45,
Ministral,3B,Mistral,OpenRouter/mistralai/ministral-3b,SelfHosted/ministral/Ministral-3b-instruct,3,,0.04,0.04,0.34,
Ministral,8B,Mistral,OpenRouter/mistralai/ministral-8b,SelfHosted/mistralai/Ministral-8B-Instruct-2410,8,,0.1,0.1,0.39,
Mistral Instruct,7B,Mistral,OpenRouter/mistralai/mistral-7b-instruct,SelfHosted/mistralai/Mistral-7B-Instruct-v0.3,7,,0.03,0.055,,yes
Mixtral Instruct (8 Experts),8x7B,Mistral,OpenRouter/mistralai/mixtral-8x7b-instruct,SelfHosted/mistralai/Mixtral-8x7B-Instruct-v0.1,56,,0.24,0.24,,
Mixtral Instruct (8 Experts),8x22B,Mistral,OpenRouter/mistralai/mixtral-8x22b-instruct,SelfHosted/mistralai/Mixtral-8x22B-Instruct-v0.1,176,,0.9,0.9,,
Mistral Large 2,,Mistral,OpenRouter/mistralai/mistral-large-2411,SelfHosted/mistralai/Pixtral-Large-Instruct-2411,123,213,2,6,,
Mistral Small 3,,Mistral,OpenRouter/mistralai/mistral-small-24b-instruct-2501,SelfHosted/mistralai/Mistral-Small-24B-Instruct-2501,24,11.52,0.1,0.3,,
GPT 4o,mini,OpenAI,OpenAI/gpt-4o-mini-2024-07-18,,,381.0001,2.5,10,0.63,
GPT 4o,,OpenAI,OpenAI/gpt-4o-2024-08-06,,,73.6001,0.150,0.6,0.75,
GPT 4.1,nano,OpenAI,OpenAI/gpt-4.1-nano-2025-04-14,,,,0.1,0.4,0.8,
GPT 4.1,mini,OpenAI,OpenAI/gpt-4.1-mini-2025-04-14,,,,0.4,1.6,0.88,
GPT 4.1,,OpenAI,OpenAI/gpt-4.1-2025-04-14,,,,2,8,0.9,
GPT o1,,OpenAI,OpenAI/o1-2024-12-17,,,,15,60,0.89,
GPT o3,,OpenAI,OpenAI/o3-2025-04-16,,,,2,8,0.85,
GPT o3,mini,OpenAI,OpenAI/o3-mini-2025-01-31,,,,1.1,4.4,0.87,
GPT o4,mini,OpenAI,OpenAI/o4-mini-2025-04-16,,,,1.1,4.4,0.83,
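Note: the new `Flash Attention 2` column is what the DPO trainer consults (via `LanguageModelDatabase` in `nip/language_model_server/trainers/dpo.py` below). As a minimal, hypothetical sketch of reading the flag straight from the CSV — the file path and column names come from this diff, but the helper function below is illustrative only and is not the project's actual `LanguageModelDatabase` implementation:

```python
import csv
from pathlib import Path


def supports_flash_attention_2(
    self_hosted_uri: str,
    csv_path: str = "databases/language_models.csv",
) -> bool:
    """Check whether a self-hosted model is marked "yes" in the new column."""
    with Path(csv_path).open(newline="") as csv_file:
        for row in csv.DictReader(csv_file):
            if row["Self Hosted URI"] == self_hosted_uri:
                return row["Flash Attention 2"].strip().lower() == "yes"
    return False


# Example (Qwen2.5-7B-Instruct is marked "yes" in the table above):
print(supports_flash_attention_2("SelfHosted/Qwen/Qwen2.5-7B-Instruct"))
```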
1 change: 1 addition & 0 deletions doc/docs/guides/installation.rst
@@ -58,6 +58,7 @@ Installation Steps

.. code-tab:: bash Hosting the LM Server

uv sync
uv sync --extra lm-server

.. code-tab:: bash With ``pip``
48 changes: 42 additions & 6 deletions nip/language_model_server/trainers/dpo.py
@@ -9,6 +9,8 @@

from datasets import Dataset

import torch

from trl import DPOConfig, DPOTrainer
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

@@ -22,6 +24,7 @@
from nip.utils.env import get_env_var
from nip.utils.types import HuggingFaceDpoDatasetItem
from nip.utils.hugging_face import is_model_peft
from nip.utils.language_model_database import LanguageModelDatabase
from nip.language_model_server.types import (
LmTrainingConfig,
LmLoraAdapterConfig,
@@ -209,6 +212,36 @@ def train(config: LmTrainingConfig, dataset: Dataset, job_id: str, new_model_nam
# The maximum length for a W&B job name is 128 characters.
job_name = job_id[:127]

is_peft = is_model_peft(config.model_name)

if is_peft:
model_lora_config = LoraConfig.from_pretrained(config.model_name)
base_model_name = model_lora_config.base_model_name_or_path
else:
base_model_name = config.model_name

language_model_db = LanguageModelDatabase()
lm_db_entry = language_model_db.get_by_model_provider_and_name(
"SelfHosted", base_model_name
)
use_flash_attention_2 = (
lm_db_entry.has_flash_attention_2 and config.mixed_precision in ("fp16", "bf16")
)

extra_model_kwargs = {}

if use_flash_attention_2:

logger.info(f"Using Flash Attention 2 for {base_model_name!r}.")

extra_model_kwargs["attn_implementation"] = "flash_attention_2"
if config.mixed_precision == "fp16":
extra_model_kwargs["torch_dtype"] = torch.float16
elif config.mixed_precision == "bf16":
extra_model_kwargs["torch_dtype"] = torch.bfloat16

# Only use padding-free batching if Flash Attention 2 is available, to avoid batch
# contamination issues.
dpo_config = DPOConfig(
**config.dpo_config.model_dump(),
hub_model_id=new_model_name,
@@ -220,15 +253,12 @@ def train(config: LmTrainingConfig, dataset: Dataset, job_id: str, new_model_nam
per_device_train_batch_size=config.per_device_train_batch_size,
use_liger_kernel=config.use_liger_kernel,
seed=config.seed,
padding_free=use_flash_attention_2,
)

ignore_training_lora_config = False

if not is_model_peft(config.model_name):
model = AutoModelForCausalLM.from_pretrained(config.model_name)

else:
model_lora_config = LoraConfig.from_pretrained(config.model_name)
if is_peft:

# When reusing the LoRA adapter, make sure the model's LoRA configuration is
# compatible with the training configuration.
@@ -251,7 +281,7 @@ def train(config: LmTrainingConfig, dataset: Dataset, job_id: str, new_model_nam
)

model = AutoPeftModelForCausalLM.from_pretrained(
config.model_name, is_trainable=True
config.model_name, is_trainable=True, **extra_model_kwargs
)

# Sanity check: ensure that exactly the LoRA layers are trainable.
@@ -275,6 +305,12 @@ def train(config: LmTrainingConfig, dataset: Dataset, job_id: str, new_model_nam
# already LoRA-adapted and the trainer will train the existing adapter.
ignore_training_lora_config = True

else:

model = AutoModelForCausalLM.from_pretrained(
config.model_name, **extra_model_kwargs
)

if ignore_training_lora_config or config.training_lora_config is None:
training_lora_config = None
else:
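For context, here is a condensed, self-contained sketch of the pattern this diff wires together: database-driven Flash Attention 2 selection, the matching torch dtype, and padding-free batching in `DPOConfig`. The model name, dataset and output directory below are placeholders, the real `train` function passes many more options, and this assumes `flash-attn` is installed, a CUDA GPU is available, and a recent TRL release where `DPOConfig` accepts `padding_free` and `DPOTrainer` takes `processing_class`:

```python
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

# Placeholder model; in the PR the model comes from the training config, and
# the Flash Attention 2 flag comes from the language model database combined
# with the fp16/bf16 mixed-precision check.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
use_flash_attention_2 = True

model_kwargs = {}
if use_flash_attention_2:
    # Flash Attention 2 requires half precision, hence the explicit dtype.
    model_kwargs["attn_implementation"] = "flash_attention_2"
    model_kwargs["torch_dtype"] = torch.bfloat16

model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Padding-free batching is only enabled when Flash Attention 2 is, to avoid
# cross-sequence contamination within a packed batch.
dpo_config = DPOConfig(
    output_dir="dpo-output",
    bf16=True,
    padding_free=use_flash_attention_2,
)

train_dataset = Dataset.from_dict(
    {
        "prompt": ["What is 2 + 2?"],
        "chosen": ["4"],
        "rejected": ["5"],
    }
)

trainer = DPOTrainer(
    model=model,
    args=dpo_config,
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
trainer.train()
```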