|
8 | 8 | #include "common.h" |
9 | 9 | #include "log.h" |
10 | 10 | #include "llama.h" |
| 11 | +#include "chat.h" |
| 12 | +#include <nlohmann/json.hpp> |
11 | 13 |
|
12 | 14 | #include <algorithm> |
13 | 15 | #include <cinttypes> |
@@ -1616,3 +1618,290 @@ float lr_opt::get_lr(float epoch) const { |
1616 | 1618 | LOG_INF("epoch %.2g lr=%.2g\n", epoch, r); |
1617 | 1619 | return r; |
1618 | 1620 | } |
| 1621 | + |
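| | +// Build a masked SFT dataset from JSONL chat data: each conversation is rendered via the |
| | +// given chat template file (when it loads and parses) or a plain ChatML fallback, tokenized, |
| | +// padded to n_ctx, and only assistant tokens are marked as trainable in the loss mask. |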
| 1622 | +ggml_opt_dataset_t common_opt_sft_dataset_init( |
| 1623 | + struct llama_context * ctx, |
| 1624 | + const std::string & json_content, |
| 1625 | + int64_t stride, |
| 1626 | + const std::string & chat_template_path) { |
| 1627 | + using json = nlohmann::json; |
| 1628 | + |
| 1629 | + const llama_vocab * vocab = llama_model_get_vocab(llama_get_model(ctx)); |
| 1630 | + common_chat_templates_ptr chat_templates; |
| 1631 | + std::string chat_template_source; |
| 1632 | + if (!chat_template_path.empty()) { |
| 1633 | + std::ifstream tmpl_file(chat_template_path); |
| 1634 | + if (!tmpl_file.is_open()) { |
| 1635 | + LOG_WRN("Failed to open chat template file: %s\n", chat_template_path.c_str()); |
| 1636 | + } else { |
| 1637 | + chat_template_source.assign(std::istreambuf_iterator<char>(tmpl_file), std::istreambuf_iterator<char>()); |
| 1638 | + tmpl_file.close(); |
| 1639 | + try { |
| 1640 | + chat_templates = common_chat_templates_init(llama_get_model(ctx), chat_template_source); |
| 1641 | + } catch (const std::exception & e) { |
| 1642 | + LOG_WRN("Failed to parse chat template '%s': %s\n", chat_template_path.c_str(), e.what()); |
| 1643 | + } |
| 1644 | + } |
| 1645 | + } |
| 1646 | + |
| 1647 | + std::vector<json> conversations; |
| 1648 | + std::istringstream content_stream(json_content); |
| 1649 | + |
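| | + // Expected input: one JSON object per line (JSONL), e.g. |
| | + //   {"messages": [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]} |
| | + // Lines that are empty or start with '#' are skipped. |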
| 1650 | + std::string line; |
| 1651 | + while (std::getline(content_stream, line)) { |
| 1652 | + if (line.empty() || line[0] == '#') continue; |
| 1653 | + try { |
| 1654 | + json conv = json::parse(line); |
| 1655 | + if (conv.contains("messages") && conv["messages"].is_array()) { |
| 1656 | + conversations.push_back(conv); |
| 1657 | + } |
| 1658 | + } catch (const json::exception & e) { |
| 1659 | + LOG_WRN("Failed to parse JSON line: %s\n", e.what()); |
| 1660 | + } |
| 1661 | + } |
| 1662 | + |
| 1663 | + if (conversations.empty()) { |
| 1664 | + LOG_ERR("No valid conversations found\n"); |
| 1665 | + return nullptr; |
| 1666 | + } |
| 1667 | + LOG_INF("Loaded %zu conversations\n", conversations.size()); |
| 1668 | + |
| 1669 | + const int64_t ne_datapoint = llama_n_ctx(ctx); |
| 1670 | + if (stride <= 0) stride = ne_datapoint; |
| 1671 | + if (stride > ne_datapoint) stride = ne_datapoint; |
| 1672 | + |
| 1673 | + std::vector<std::vector<llama_token>> all_tokenized_data; |
| 1674 | + std::vector<std::vector<int32_t>> all_assistant_masks; |
| 1675 | + |
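| | + // Map a character offset in `render` to a token offset by re-tokenizing the prefix. |
| | + // Assumes the prefix tokenization matches the leading tokens of the full render, which |
| | + // holds here because span boundaries fall directly after ChatML special tokens. |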
| 1676 | + auto token_count_prefix = [&](const std::string & render, size_t char_count) -> size_t { |
| 1677 | + std::string prefix = render.substr(0, char_count); |
| 1678 | + auto t = common_tokenize(ctx, prefix, /*add_special=*/false, /*parse_special=*/true); |
| 1679 | + return t.size(); |
| 1680 | + }; |
| 1681 | + |
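| | + // ChatML markers used both for the fallback rendering and for locating assistant spans |
| | + // in the rendered text; a template that emits no such markers yields no trainable spans, |
| | + // so its conversations are skipped further below. |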
| 1682 | + const std::string START_TAG = "<|im_start|>"; |
| 1683 | + const std::string START_SYS = "<|im_start|>system\n"; |
| 1684 | + const std::string START_USR = "<|im_start|>user\n"; |
| 1685 | + const std::string START_AST = "<|im_start|>assistant\n"; |
| 1686 | + const std::string END_TAG = "<|im_end|>"; |
| 1687 | + const std::string NL = "\n"; |
| 1688 | + |
| 1689 | + for (size_t i = 0; i < conversations.size(); ++i) { |
| 1690 | + const auto & messages = conversations[i]["messages"]; |
| 1691 | + if (!messages.is_array() || messages.empty()) continue; |
| 1692 | + |
| 1693 | + std::string render; |
| 1694 | + |
| 1695 | + if (chat_templates) { |
| 1696 | + std::vector<common_chat_msg> chat_msgs; |
| 1697 | + chat_msgs.reserve(messages.size()); |
| 1698 | + for (const auto & msg : messages) { |
| 1699 | + if (!msg.contains("role") || !msg.contains("content")) { |
| 1700 | + continue; |
| 1701 | + } |
| 1702 | + common_chat_msg chat_msg; |
| 1703 | + chat_msg.role = msg["role"].get<std::string>(); |
| 1704 | + chat_msg.content = msg["content"].get<std::string>(); |
| 1705 | + chat_msgs.push_back(std::move(chat_msg)); |
| 1706 | + } |
| 1707 | + |
| 1708 | + if (!chat_msgs.empty()) { |
| 1709 | + common_chat_templates_inputs inputs; |
| 1710 | + inputs.messages = std::move(chat_msgs); |
| 1711 | + inputs.add_generation_prompt = false; |
| 1712 | + inputs.use_jinja = true; |
| 1713 | + try { |
| 1714 | + render = common_chat_templates_apply(chat_templates.get(), inputs).prompt; |
| 1715 | + |
| 1716 | + size_t last_im_end = render.rfind(END_TAG); |
| 1717 | + if (last_im_end != std::string::npos) { |
| 1718 | + size_t end_pos = last_im_end + END_TAG.size(); |
| 1719 | + // If anything other than trailing whitespace/newlines follows the final <|im_end|>, truncate right after the tag |
| 1720 | + while (end_pos < render.size() && (render[end_pos] == '\n' || render[end_pos] == '\r' || render[end_pos] == ' ')) { |
| 1721 | + end_pos++; |
| 1722 | + } |
| 1723 | + if (end_pos < render.size()) { |
| 1724 | + render = render.substr(0, last_im_end + END_TAG.size()); // keep only up to the final <|im_end|> |
| 1725 | + } |
| 1726 | + } |
| 1727 | + } catch (const std::exception & e) { |
| 1728 | + LOG_WRN("Chat template rendering failed for conversation %zu: %s. Falling back to default ChatML rendering.\n", |
| 1729 | + i, e.what()); |
| 1730 | + } |
| 1731 | + } |
| 1732 | + } |
| 1733 | + |
| 1734 | + if (render.empty()) { |
| 1735 | + render.reserve(4096); |
| 1736 | + for (const auto & msg : messages) { |
| 1737 | + if (!msg.contains("role") || !msg.contains("content")) continue; |
| 1738 | + const std::string role = msg["role"].get<std::string>(); |
| 1739 | + const std::string content = msg["content"].get<std::string>(); |
| 1740 | + |
| 1741 | + if (role == "system") { |
| 1742 | + render += START_SYS; render += content; render += END_TAG + NL; |
| 1743 | + } else if (role == "user") { |
| 1744 | + render += START_USR; render += content; render += END_TAG + NL; |
| 1745 | + } else if (role == "assistant") { |
| 1746 | + render += START_AST; render += content; render += END_TAG + NL; |
| 1747 | + } |
| 1748 | + } |
| 1749 | + } |
| 1750 | + |
| 1751 | + if (render.empty()) { |
| 1752 | + continue; |
| 1753 | + } |
| 1754 | + |
| 1755 | + struct Span { size_t lo, hi; }; |
| 1756 | + std::vector<Span> assistant_spans; |
| 1757 | + |
| 1758 | + { |
| 1759 | + size_t from = 0; |
| 1760 | + while (true) { |
| 1761 | + size_t open = render.find(START_AST, from); |
| 1762 | + if (open == std::string::npos) break; |
| 1763 | + |
| 1764 | + // Span covers the role header ("assistant\n"), the message content, and the closing <|im_end|> tag |
| 1765 | + size_t lo = open + START_TAG.size(); |
| 1769 | + |
| 1770 | + size_t close = render.find(END_TAG, open + START_AST.size()); |
| 1771 | + if (close == std::string::npos) { |
| 1772 | + assistant_spans.push_back({lo, render.size()}); |
| 1773 | + break; |
| 1774 | + } |
| 1775 | + |
| 1776 | + size_t hi = close + END_TAG.size(); |
| 1777 | + |
| 1778 | + assistant_spans.push_back({lo, std::min(hi, render.size())}); |
| 1779 | + |
| 1780 | + from = hi; |
| 1786 | + } |
| 1787 | + } |
| 1788 | + |
| 1789 | + if (assistant_spans.empty()) { |
| 1790 | + LOG_WRN("Conversation %zu has no assistant spans\n", i); |
| 1791 | + continue; |
| 1792 | + } |
| 1793 | + |
| 1794 | + auto tokens_full = common_tokenize(ctx, render, /*add_special=*/false, /*parse_special=*/true); |
| 1795 | + if (tokens_full.empty()) continue; |
| 1796 | + |
| 1797 | + std::vector<int32_t> assistant_mask(tokens_full.size(), 0); |
| 1798 | + size_t assistant_token_count = 0; |
| 1799 | + |
| 1800 | + for (const auto & sp : assistant_spans) { |
| 1801 | + size_t t_lo = token_count_prefix(render, sp.lo); |
| 1802 | + size_t t_hi = token_count_prefix(render, sp.hi); |
| 1803 | + if (t_lo > tokens_full.size()) t_lo = tokens_full.size(); |
| 1804 | + if (t_hi > tokens_full.size()) t_hi = tokens_full.size(); |
| 1805 | + |
| 1807 | + for (size_t t = t_lo; t < t_hi; ++t) { |
| 1808 | + assistant_mask[t] = 1; |
| 1809 | + ++assistant_token_count; |
| 1810 | + } |
| 1811 | + } |
| 1812 | + |
| 1813 | + if (assistant_token_count == 0) { |
| 1814 | + LOG_WRN("Conversation %zu has zero assistant tokens after masking\n", i); |
| 1815 | + continue; |
| 1816 | + } |
| 1817 | + |
| 1818 | + all_tokenized_data.push_back(tokens_full); |
| 1819 | + all_assistant_masks.push_back(assistant_mask); |
| 1820 | + } |
| 1821 | + |
| 1822 | + if (all_tokenized_data.empty()) { |
| 1823 | + LOG_ERR("No valid training samples generated after processing %zu conversations\n", conversations.size()); |
| 1824 | + return nullptr; |
| 1825 | + } |
| 1826 | + |
| 1827 | + std::vector<std::vector<llama_token>> final_samples; |
| 1828 | + std::vector<std::vector<int32_t>> final_masks; |
| 1829 | + |
| 1830 | + llama_token pad_token = llama_vocab_pad(vocab); |
| 1831 | + if (pad_token == LLAMA_TOKEN_NULL) { |
| 1832 | + pad_token = llama_vocab_eos(vocab); |
| 1833 | + } |
| 1834 | + |
| 1835 | + for (size_t i = 0; i < all_tokenized_data.size(); ++i) { |
| 1836 | + const auto& conv_tokens = all_tokenized_data[i]; |
| 1837 | + const auto& conv_mask = all_assistant_masks[i]; |
| 1838 | + |
| 1839 | + if ((int64_t)conv_tokens.size() > ne_datapoint) { |
| 1840 | + LOG_WRN("Skipping conversation %zu: too long (%zu tokens > %lld)\n", i, conv_tokens.size(), (long long)ne_datapoint); |
| 1841 | + continue; |
| 1842 | + } |
| 1843 | + |
| 1844 | + size_t conv_assistant_tokens = 0; |
| 1845 | + for (int32_t mask_val : conv_mask) { |
| 1846 | + if (mask_val == 1) conv_assistant_tokens++; |
| 1847 | + } |
| 1848 | + |
| 1849 | + if (conv_assistant_tokens == 0) { |
| 1850 | + LOG_WRN("Skipping conversation %zu: no assistant tokens\n", i); |
| 1851 | + continue; |
| 1852 | + } |
| 1853 | + |
| 1854 | + std::vector<llama_token> sample_tokens = conv_tokens; |
| 1855 | + std::vector<int32_t> sample_mask = conv_mask; |
| 1856 | + |
| 1857 | + sample_tokens.resize(ne_datapoint, pad_token); |
| 1858 | + sample_mask.resize(ne_datapoint, 0); // Padding tokens are not trained on |
| 1859 | + |
| 1860 | + final_samples.push_back(sample_tokens); |
| 1861 | + final_masks.push_back(sample_mask); |
| 1862 | + } |
| 1863 | + |
| 1864 | + all_tokenized_data = std::move(final_samples); |
| 1865 | + all_assistant_masks = std::move(final_masks); |
| 1866 | + |
| 1867 | + const int64_t ndata = all_tokenized_data.size(); |
| 1868 | + |
| 1869 | + ggml_opt_dataset_t result = ggml_opt_dataset_init_with_masks( |
| 1870 | + GGML_TYPE_I32, GGML_TYPE_I32, GGML_TYPE_I32, |
| 1871 | + /*ne_datapoint=*/ne_datapoint, /*ne_label=*/ne_datapoint, /*ne_mask=*/ne_datapoint, |
| 1872 | + /*ndata=*/ndata, /*ndata_shard=*/1); |
| 1873 | + |
| 1874 | + if (result == nullptr) { |
| 1875 | + return nullptr; |
| 1876 | + } |
| 1877 | + |
| 1878 | + int32_t * data = (int32_t *) ggml_opt_dataset_data(result)->data; |
| 1879 | + int32_t * labels = (int32_t *) ggml_opt_dataset_labels(result)->data; |
| 1880 | + int32_t * masks = (int32_t *) ggml_opt_dataset_masks(result)->data; |
| 1881 | + |
| 1882 | + for (int64_t idata = 0; idata < ndata; ++idata) { |
| 1883 | + const auto & sample_tokens = all_tokenized_data[idata]; |
| 1884 | + const auto & sample_mask = all_assistant_masks[idata]; |
| 1885 | + |
| 1886 | + // inputs |
| 1887 | + for (int64_t i = 0; i < ne_datapoint; ++i) { |
| 1888 | + data[idata * ne_datapoint + i] = sample_tokens[i]; |
| 1889 | + } |
| 1890 | + |
| 1891 | + // labels: Set actual next tokens for ALL positions (masked cross-entropy needs real tokens) |
| 1892 | + for (int64_t i = 0; i < ne_datapoint - 1; ++i) { |
| 1893 | + // Always set the actual next token - masking is handled separately |
| 1894 | + labels[idata * ne_datapoint + i] = sample_tokens[i + 1]; |
| 1895 | + } |
| 1896 | + labels[idata * ne_datapoint + (ne_datapoint - 1)] = sample_tokens[ne_datapoint - 1]; // last token predicts itself (will be masked) |
| 1897 | + |
| 1898 | + // masks: indicate which preds should be trained on (shifted by 1 from sample_mask) |
| 1899 | + // Since we predict token[i+1] from token[i], we train when token[i+1] is assistant |
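| | + // e.g. tokens = [u1, u2, a1, a2] with sample_mask = [0, 0, 1, 1] yields masks = [0, 1, 1, 0]: |
| | + //      positions 1 and 2 are trained because they predict the assistant tokens a1 and a2 |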
| 1900 | + for (int64_t i = 0; i < ne_datapoint - 1; ++i) { |
| 1901 | + masks[idata * ne_datapoint + i] = sample_mask[i + 1] == 1 ? 1 : 0; |
| 1902 | + } |
| 1903 | + masks[idata * ne_datapoint + (ne_datapoint - 1)] = 0; |
| 1904 | + } |
| 1905 | + |
| 1906 | + return result; |
| 1907 | +} |
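
For reference, a minimal sketch of how the new helper might be driven from a finetune tool, assuming the current llama.cpp model/context API; the model path, dataset path and n_ctx are placeholders, and the optimizer wiring (e.g. llama_opt_epoch) is outside this diff:

    #include "common.h"
    #include "llama.h"

    #include <fstream>
    #include <sstream>

    int main() {
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path

        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx = 1024; // conversations longer than n_ctx are skipped, shorter ones are padded
        llama_context * ctx = llama_init_from_model(model, cparams);

        // read the whole JSONL file; the helper parses it line by line
        std::ifstream file("train.jsonl"); // placeholder path
        std::stringstream buffer;
        buffer << file.rdbuf();

        // stride <= 0 falls back to n_ctx; an empty template path uses the plain ChatML fallback
        ggml_opt_dataset_t dataset = common_opt_sft_dataset_init(ctx, buffer.str(), /*stride=*/0, /*chat_template_path=*/"");
        if (dataset != nullptr) {
            // ... hand the dataset to the optimizer here ...
            ggml_opt_dataset_free(dataset);
        }

        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 0;
    }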