Commit e9825e6

Merge pull request #46 from makaveli10/lora-instruct-ft
Add Instruction Fine-tuning Support for LoRA with Assistant-Only Loss
2 parents f9e7293 + 661890c commit e9825e6

19 files changed: +1632 −83 lines
common/common.cpp

Lines changed: 289 additions & 0 deletions
@@ -8,6 +8,8 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "chat.h"
+#include <nlohmann/json.hpp>
 
 #include <algorithm>
 #include <cinttypes>
@@ -1616,3 +1618,290 @@ float lr_opt::get_lr(float epoch) const {
     LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
     return r;
 }
+
+ggml_opt_dataset_t common_opt_sft_dataset_init(
+        struct llama_context * ctx,
+        const std::string & json_content,
+        int64_t stride,
+        const std::string & chat_template_path) {
+    using json = nlohmann::json;
+
+    const llama_vocab * vocab = llama_model_get_vocab(llama_get_model(ctx));
+    common_chat_templates_ptr chat_templates;
+    std::string chat_template_source;
+    if (!chat_template_path.empty()) {
+        std::ifstream tmpl_file(chat_template_path);
+        if (!tmpl_file.is_open()) {
+            LOG_ERR("Warning: Failed to open chat template file: %s\n", chat_template_path.c_str());
+        } else {
+            chat_template_source.assign(std::istreambuf_iterator<char>(tmpl_file), std::istreambuf_iterator<char>());
+            tmpl_file.close();
+            try {
+                chat_templates = common_chat_templates_init(llama_get_model(ctx), chat_template_source);
+            } catch (const std::exception & e) {
+                LOG_ERR("Warning: Failed to parse chat template '%s': %s\n", chat_template_path.c_str(), e.what());
+            }
+        }
+    }
+
+    std::vector<json> conversations;
+    std::istringstream content_stream(json_content);
+
+    std::string line;
+    while (std::getline(content_stream, line)) {
+        if (line.empty() || line[0] == '#') continue;
+        try {
+            json conv = json::parse(line);
+            if (conv.contains("messages") && conv["messages"].is_array()) {
+                conversations.push_back(conv);
+            }
+        } catch (const json::exception & e) {
+            LOG_DBG("Warning: Failed to parse JSON line: %s\n", e.what());
+        }
+    }
+
+    if (conversations.empty()) {
+        LOG_ERR("Error: No valid conversations found\n");
+        return nullptr;
+    }
+    LOG_INF("Loaded %zu conversations\n", conversations.size());
+
+    const int64_t ne_datapoint = llama_n_ctx(ctx);
+    if (stride <= 0) stride = ne_datapoint;
+    if (stride > ne_datapoint) stride = ne_datapoint;
+
+    std::vector<std::vector<llama_token>> all_tokenized_data;
+    std::vector<std::vector<int32_t>> all_assistant_masks;
+
+    auto token_count_prefix = [&](const std::string & render, size_t char_count) -> size_t {
+        std::string prefix = render.substr(0, char_count);
+        auto t = common_tokenize(ctx, prefix, /*add_special=*/false, /*parse_special=*/true);
+        return t.size();
+    };
+
+    const std::string START_TAG = "<|im_start|>";
+    const std::string START_SYS = "<|im_start|>system\n";
+    const std::string START_USR = "<|im_start|>user\n";
+    const std::string START_AST = "<|im_start|>assistant\n";
+    const std::string END_TAG = "<|im_end|>";
+    const std::string NL = "\n";
+
+    for (size_t i = 0; i < conversations.size(); ++i) {
+        const auto & messages = conversations[i]["messages"];
+        if (!messages.is_array() || messages.empty()) continue;
+
+        std::string render;
+
+        if (chat_templates) {
+            std::vector<common_chat_msg> chat_msgs;
+            chat_msgs.reserve(messages.size());
+            for (const auto & msg : messages) {
+                if (!msg.contains("role") || !msg.contains("content")) {
+                    continue;
+                }
+                common_chat_msg chat_msg;
+                chat_msg.role = msg["role"].get<std::string>();
+                chat_msg.content = msg["content"].get<std::string>();
+                chat_msgs.push_back(std::move(chat_msg));
+            }
+
+            if (!chat_msgs.empty()) {
+                common_chat_templates_inputs inputs;
+                inputs.messages = std::move(chat_msgs);
+                inputs.add_generation_prompt = false;
+                inputs.use_jinja = true;
+                try {
+                    render = common_chat_templates_apply(chat_templates.get(), inputs).prompt;
+
+                    size_t last_im_end = render.rfind("<|im_end|>");
+                    if (last_im_end != std::string::npos) {
+                        size_t end_pos = last_im_end + 10; // length of "<|im_end|>"
+                        // Remove any trailing whitespace/newlines after the final <|im_end|>
+                        while (end_pos < render.size() && (render[end_pos] == '\n' || render[end_pos] == '\r' || render[end_pos] == ' ')) {
+                            end_pos++;
+                        }
+                        if (end_pos < render.size()) {
+                            render = render.substr(0, last_im_end + 10); // keep only up to the final <|im_end|>
+                        }
+                    }
+                } catch (const std::exception & e) {
+                    LOG_WRN("Warning: chat template rendering failed for conversation %zu: %s. Falling back to default ChatML rendering.\n",
+                            i, e.what());
+                }
+            }
+        }
+
+        if (render.empty()) {
+            render.reserve(4096);
+            for (const auto & msg : messages) {
+                if (!msg.contains("role") || !msg.contains("content")) continue;
+                const std::string role = msg["role"].get<std::string>();
+                const std::string content = msg["content"].get<std::string>();
+
+                if (role == "system") {
+                    render += START_SYS; render += content; render += END_TAG + NL;
+                } else if (role == "user") {
+                    render += START_USR; render += content; render += END_TAG + NL;
+                } else if (role == "assistant") {
+                    render += START_AST; render += content; render += END_TAG + NL;
+                }
+            }
+        }
+
+        if (render.empty()) {
+            continue;
+        }
+
+        struct Span { size_t lo, hi; };
+        std::vector<Span> assistant_spans;
+
+        {
+            size_t from = 0;
+            while (true) {
+                size_t open = render.find(START_AST, from);
+                if (open == std::string::npos) break;
+
+                // Include the role token ("assistant") and everything through the closing tag/newlines
+                size_t lo = open + START_TAG.size();
+                if (lo > render.size()) {
+                    lo = render.size();
+                }
+
+                size_t close = render.find(END_TAG, open + START_AST.size());
+                if (close == std::string::npos) {
+                    assistant_spans.push_back({lo, render.size()});
+                    break;
+                }
+
+                size_t hi = close + END_TAG.size();
+                if (hi <= lo) {
+                    lo = open;
+                    hi = close + END_TAG.size();
+                }
+
+                assistant_spans.push_back({lo, std::min(hi, render.size())});
+
+                size_t next_from = hi;
+                from = next_from;
+            }
+        }
+
+        if (assistant_spans.empty()) {
+            LOG_WRN("Conversation %zu has no assistant spans\n", i);
+            continue;
+        }
+
+        auto tokens_full = common_tokenize(ctx, render, /*add_special=*/false, /*parse_special=*/true);
+        if (tokens_full.empty()) continue;
+
+        std::vector<int32_t> assistant_mask(tokens_full.size(), 0);
+        size_t assistant_token_count = 0;
+
+        for (const auto & sp : assistant_spans) {
+            size_t t_lo = token_count_prefix(render, sp.lo);
+            size_t t_hi = token_count_prefix(render, sp.hi);
+            if (t_lo > tokens_full.size()) t_lo = tokens_full.size();
+            if (t_hi > tokens_full.size()) t_hi = tokens_full.size();
+
+            for (size_t t = t_lo; t < t_hi; ++t) {
+                assistant_mask[t] = 1;
+                ++assistant_token_count;
+            }
+        }
+
+        if (assistant_token_count == 0) {
+            LOG_WRN("Warning: Conversation %zu has zero assistant tokens after masking\n", i);
+            continue;
+        }
+
+        all_tokenized_data.push_back(tokens_full);
+        all_assistant_masks.push_back(assistant_mask);
+    }
+
+    if (all_tokenized_data.empty()) {
+        LOG_ERR("ERROR: No valid training samples generated after processing %zu conversations\n", conversations.size());
+        return nullptr;
+    }
+
+    std::vector<std::vector<llama_token>> final_samples;
+    std::vector<std::vector<int32_t>> final_masks;
+
+    llama_token pad_token = llama_vocab_pad(vocab);
+    if (pad_token == LLAMA_TOKEN_NULL) {
+        pad_token = llama_vocab_eos(vocab);
+    }
+
+    for (size_t i = 0; i < all_tokenized_data.size(); ++i) {
+        const auto & conv_tokens = all_tokenized_data[i];
+        const auto & conv_mask = all_assistant_masks[i];
+
+        if ((int64_t) conv_tokens.size() > ne_datapoint) {
+            LOG_WRN("Skipping conversation %zu: too long (%zu tokens > %lld)\n", i, conv_tokens.size(), (long long) ne_datapoint);
+            continue;
+        }
+
+        size_t conv_assistant_tokens = 0;
+        for (int32_t mask_val : conv_mask) {
+            if (mask_val == 1) conv_assistant_tokens++;
+        }
+
+        if (conv_assistant_tokens == 0) {
+            LOG_WRN("Skipping conversation %zu: no assistant tokens\n", i);
+            continue;
+        }
+
+        std::vector<llama_token> sample_tokens = conv_tokens;
+        std::vector<int32_t> sample_mask = conv_mask;
+
+        sample_tokens.resize(ne_datapoint, pad_token);
+        sample_mask.resize(ne_datapoint, 0); // padding tokens are not trained on
+
+        final_samples.push_back(sample_tokens);
+        final_masks.push_back(sample_mask);
+    }
+
+    all_tokenized_data = std::move(final_samples);
+    all_assistant_masks = std::move(final_masks);
+
+    const int64_t ndata = all_tokenized_data.size();
+
+    ggml_opt_dataset_t result = ggml_opt_dataset_init_with_masks(
+        GGML_TYPE_I32, GGML_TYPE_I32, GGML_TYPE_I32,
+        /*ne_datapoint=*/ne_datapoint, /*ne_label=*/ne_datapoint, /*ne_mask=*/ne_datapoint,
+        /*ndata=*/ndata, /*ndata_shard=*/1);
+
+    if (result == nullptr) {
+        return nullptr;
+    }
+
+    int32_t * data   = (int32_t *) ggml_opt_dataset_data(result)->data;
+    int32_t * labels = (int32_t *) ggml_opt_dataset_labels(result)->data;
+    int32_t * masks  = (int32_t *) ggml_opt_dataset_masks(result)->data;
+
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        const auto & sample_tokens = all_tokenized_data[idata];
+        const auto & sample_mask   = all_assistant_masks[idata];
+
+        // inputs
+        for (int64_t i = 0; i < ne_datapoint; ++i) {
+            data[idata * ne_datapoint + i] = sample_tokens[i];
+        }
+
+        // labels: set the actual next token for ALL positions (masked cross-entropy needs real tokens)
+        for (int64_t i = 0; i < ne_datapoint - 1; ++i) {
+            // always set the actual next token - masking is handled separately
+            labels[idata * ne_datapoint + i] = sample_tokens[i + 1];
+        }
+        labels[idata * ne_datapoint + (ne_datapoint - 1)] = sample_tokens[ne_datapoint - 1]; // last token predicts itself (will be masked)
+
+        // masks: indicate which predictions should be trained on (shifted by 1 from sample_mask)
+        // since we predict token[i+1] from token[i], we train only where token[i+1] is an assistant token
+        for (int64_t i = 0; i < ne_datapoint - 1; ++i) {
+            masks[idata * ne_datapoint + i] = (i + 1 < ne_datapoint && sample_mask[i + 1] == 1) ? 1 : 0;
+        }
+        masks[idata * ne_datapoint + (ne_datapoint - 1)] = 0;
+    }
+
+    return result;
+}
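To make the label/mask construction at the end of this function concrete, here is a small standalone sketch (illustrative only, not part of this commit) that applies the same shift-by-one rule to a toy four-token sample:

```cpp
#include <cstdio>
#include <vector>

// Toy illustration of the label/mask shift used above: position i predicts
// token i+1, so a prediction is trained on only when token i+1 is an
// assistant token. The last position is always masked out.
int main() {
    std::vector<int> tokens         = {11, 22, 33, 44};  // hypothetical token ids
    std::vector<int> assistant_mask = { 0,  0,  1,  1};  // 1 = assistant token

    const size_t n = tokens.size();
    std::vector<int> labels(n), masks(n);
    for (size_t i = 0; i + 1 < n; ++i) {
        labels[i] = tokens[i + 1];
        masks[i]  = assistant_mask[i + 1];
    }
    labels[n - 1] = tokens[n - 1];  // last token "predicts itself" but is masked
    masks[n - 1]  = 0;

    for (size_t i = 0; i < n; ++i) {
        printf("pos %zu: input=%d label=%d trained=%d\n", i, tokens[i], labels[i], masks[i]);
    }
    return 0;
}
```

Only the positions whose next token is an assistant token end up with mask 1 (here positions 1 and 2, predicting tokens 33 and 44); everything else, including the final position, contributes nothing to the loss.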

common/common.h

Lines changed: 5 additions & 0 deletions
@@ -745,3 +745,8 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std
 
 // "adamw" or "sgd" (case insensitive)
 enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
+ggml_opt_dataset_t common_opt_sft_dataset_init(
+        struct llama_context * ctx,
+        const std::string & json_content,
+        int64_t stride,
+        const std::string & chat_template_path = "");

examples/training/README.md

Lines changed: 13 additions & 2 deletions
@@ -27,11 +27,11 @@ the base model frozen, making it memory-efficient.
 
 ```sh
 # Create new LoRA adapter with default settings (rank=8, alpha=16, attention modules)
-./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512
+./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512 -fa off
 
 # Custom LoRA parameters(creates new lora adapter and trains it from scratch)
 ./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512 \
-    --lora-rank 16 --lora-alpha 32 --lora-modules "attn_q,attn_k,attn_v,attn_o"
+    --lora-rank 16 --lora-alpha 32 --lora-modules "attn_q,attn_k,attn_v,attn_o" -fa off
 
 # Fine-tune existing LoRA adapter
 ./build/bin/llama-finetune-lora -m base_model.gguf -f dataset.txt --lora existing_adapter.gguf \
@@ -44,8 +44,17 @@ the base model frozen, making it memory-efficient.
 # Resume training from checkpoint
 ./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512 \
     --resume-from "./lora_checkpoints/checkpoint_step_00000150/"
+    --output-adapter improved_adapter.gguf -ngl 999 -c 512 -b 512 -ub 512 -fa off
+
+# Supervised fine-tuning (SFT) with assistant-only loss
+./build/bin/llama-finetune-lora -m model.gguf -f dataset.jsonl -ngl 999 -c 512 -b 512 -ub 512 \
+    --lora-modules "attn_q,attn_k,attn_v,attn_o" --assistant-loss-only -fa off
 ```
 
+### SFT (Instruction Fine-Tuning) with Assistant-Only Loss
+- Masks the system and user tokens and computes the loss only on assistant tokens
+- Requires the dataset in Hugging Face-style chat JSONL format, with `role` and `content` fields for each message (see the example below)
+- Optionally accepts a Jinja chat template via `--chat-template chat-ml-template.jinja`
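For illustration, a hypothetical dataset line in the expected layout, one conversation per line whose `messages` array carries `role`/`content` pairs (the contents below are made up):

```json
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What does LoRA stand for?"}, {"role": "assistant", "content": "LoRA stands for Low-Rank Adaptation."}]}
```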

 ### Parameters

@@ -60,6 +69,8 @@ the base model frozen, making it memory-efficient.
 - Available: `attn_q`, `attn_k`, `attn_v`, `attn_o`, `ffn_gate`, `ffn_up`, `ffn_down`, `embed`, `output`, `all`
 - Default: `attn_q,attn_k,attn_v,attn_o` (attention modules)
 - `--output-adapter PATH` - Output adapter filename (default: auto-generated)
+- `--assistant-loss-only` - Train only on assistant tokens
+- `--chat-template` - Jinja chat template used to render conversations into ChatML format when training on assistant tokens only
 
 #### Checkpointing
 - `--checkpoint-save-steps N` - Save checkpoint every N training steps (default: 100)
