From c822e73c4d5484a03ea0c367ded40ab3f832a56d Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Sun, 9 Nov 2025 20:07:12 -0600 Subject: [PATCH 01/16] common : implement parser combinators to simplify chat parsing --- common/CMakeLists.txt | 2 + common/chat-parser-combinator.cpp | 819 ++++++++++++++++++++++++++ common/chat-parser-combinator.h | 158 +++++ tests/CMakeLists.txt | 1 + tests/test-chat-parser-combinator.cpp | 472 +++++++++++++++ 5 files changed, 1452 insertions(+) create mode 100644 common/chat-parser-combinator.cpp create mode 100644 common/chat-parser-combinator.h create mode 100644 tests/test-chat-parser-combinator.cpp diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 7086d08e5e5e9..7bdc9aab5995f 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -48,6 +48,8 @@ add_library(${TARGET} STATIC arg.cpp arg.h base64.hpp + chat-parser-combinator.cpp + chat-parser-combinator.h chat-parser.cpp chat-parser.h chat.cpp diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp new file mode 100644 index 0000000000000..f2182980b3bac --- /dev/null +++ b/common/chat-parser-combinator.cpp @@ -0,0 +1,819 @@ +#include "chat-parser-combinator.h" +#include "common.h" +#include "log.h" + +#include +#include + +class parser_base { + protected: + int id_; + + void set_id(int id) { id_ = id; } + + public: + parser_base(int id) : id_(id) {} + + virtual parser_type type() const = 0; + virtual parser_result parse(parser_context & ctx, size_t start = 0) = 0; + virtual std::string dump() const = 0; + virtual void assign_ids_internal(int& next_id) { + if (id_ == -1) { + id_ = next_id++; + } + } +}; + +class literal_parser : public parser_base { + std::string literal_; + + public: + literal_parser(const std::string & literal, int id) : parser_base(id), literal_(literal) {} + + parser_type type() const override { return PARSER_LITERAL; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + auto pos = start; + for (auto i = 0u; i < literal_.size(); ++i) { + if (pos >= ctx.input.size()) { + if (ctx.input_is_complete) { + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + if (i > 0) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); + } + return parser_result(PARSER_RESULT_FAIL, start); + } + if (ctx.input[pos] != literal_[i]) { + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + ++pos; + } + + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos)); + } + + std::string dump() const override { + return "Literal(" + literal_ + ")"; + } +}; + +class sequence_parser : public parser_base { + std::vector parsers_; + + public: + sequence_parser(std::initializer_list parsers, int id) : parser_base(id) { + for (const auto & p : parsers) { + if (p.is_sequence()) { + // Flatten sequences + for (const auto & embedded : p.to_sequence()->parsers()) { + parsers_.push_back(embedded); + } + } else { + parsers_.push_back(p); + } + } + } + + parser_type type() const override { return PARSER_SEQUENCE; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + std::unordered_map groups; + + auto pos = start; + for (const auto & p : parsers_) { + auto result = p->parse(ctx, pos); + + // Copy groups + groups.insert(result.groups.begin(), result.groups.end()); + + if (result.is_fail()) { + if (result.end >= ctx.input.size() && !ctx.input_is_complete) { + // If we fail because we don't have enough input, then return success + return parser_result(PARSER_RESULT_SUCCESS, start, result.end, groups); + } + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start, result.end, groups)); + } + + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, result.end, groups); + } + + pos = result.end; + } + + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos, groups)); + } + + std::string dump() const override { + std::vector parts; + parts.reserve(parsers_.size()); + for (const auto & p : parsers_) { + parts.push_back(p->dump()); + } + return "Sequence(" + string_join(parts, ", ") + ")"; + } + + const std::vector & parsers() const { return parsers_; } + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + for (auto & p : parsers_) { + p->assign_ids_internal(next_id); + } + } +}; + +class choice_parser : public parser_base { + std::vector parsers_; + + public: + choice_parser(std::initializer_list parsers, int id) : parser_base(id) { + for (const auto & p : parsers) { + if (p.is_choice()) { + // Flatten choices + for (const auto & embedded : p.to_choice()->parsers()) { + parsers_.push_back(embedded); + } + } else { + parsers_.push_back(p); + } + } + } + + parser_type type() const override { return PARSER_CHOICE; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + auto pos = start; + for (const auto & p : parsers_) { + auto result = p->parse(ctx, pos); + + if (result.is_success()) { + return ctx.memo.set(id_, start, result); + } + + if (result.is_need_more_input()) { + return result; + } + } + + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + + std::string dump() const override { + std::vector parts; + parts.reserve(parsers_.size()); + for (const auto & p : parsers_) { + parts.push_back(p->dump()); + } + return "Choice(" + string_join(parts, ", ") + ")"; + } + + const std::vector & parsers() const { return parsers_; } + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + for (auto & p : parsers_) { + p->assign_ids_internal(next_id); + } + } +}; + +class one_or_more_parser : public parser_base { + parser parser_; + + public: + one_or_more_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} + + parser_type type() const override { return PARSER_ONE_OR_MORE; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + std::unordered_map groups; + + // We can't return back the cached result, since there may be more + // repetitions since the last parsing attempt. Instead, resume parsing from + // the last successful repetition found. + auto pos = start; + if (cached != std::nullopt) { + pos = cached->end; + groups.insert(cached->groups.begin(), cached->groups.end()); + } + + if (pos == start) { + auto first_result = parser_->parse(ctx, pos); + if (!first_result.is_success()) { + return first_result; + } + + pos = first_result.end; + groups.insert(first_result.groups.begin(), first_result.groups.end()); + } + + for (;;) { + auto result = parser_->parse(ctx, pos); + groups.insert(result.groups.begin(), result.groups.end()); + + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); + } + + if (result.is_fail()) { + // Done with repetitions + break; + } + + if (result.end == pos) { + break; // Prevent an infinite loop + } + + pos = result.end; + } + + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos, groups)); + } + + std::string dump() const override { + return "OneOrMore(" + parser_->dump() + ")"; + } + + const parser & child() const { return parser_; } + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + parser_->assign_ids_internal(next_id); + } +}; + +class zero_or_more_parser : public parser_base { + parser parser_; + + public: + zero_or_more_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} + + parser_type type() const override { return PARSER_ZERO_OR_MORE; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + std::unordered_map groups; + + // We can't return back the cached result, since there may be more + // repetitions since the last parsing attempt. Instead, resume parsing from + // the last successful repetition found. + auto pos = start; + if (cached != std::nullopt) { + pos = cached->end; + groups.insert(cached->groups.begin(), cached->groups.end()); + } + + for (;;) { + auto result = parser_->parse(ctx, pos); + groups.insert(result.groups.begin(), result.groups.end()); + + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); + } + + if (result.is_fail()) { + // Done with repetitions (zero or more is always valid) + break; + } + + if (result.end == pos) { + break; // Prevent an infinite loop + } + + pos = result.end; + } + + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos, groups)); + } + + std::string dump() const override { + return "ZeroOrMore(" + parser_->dump() + ")"; + } + + const parser & child() const { return parser_; } + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + parser_->assign_ids_internal(next_id); + } +}; + +class optional_parser : public parser_base { + parser parser_; + + public: + optional_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} + + parser_type type() const override { return PARSER_OPTIONAL; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + auto result = parser_->parse(ctx, start); + + if (result.is_success()) { + // Matched successfully + return ctx.memo.set(id_, start, result); + } + + if (result.is_need_more_input()) { + // Propagate - need more input to determine if optional matches + return result; + } + + // No match, but optional always succeeds with zero matches + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, start)); + } + + std::string dump() const override { + return "Optional(" + parser_->dump() + ")"; + } + + const parser & child() const { return parser_; } + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + parser_->assign_ids_internal(next_id); + } +}; + +class not_parser : public parser_base { + parser parser_; + + public: + not_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} + + parser_type type() const override { return PARSER_NOT; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + auto result = parser_->parse(ctx, start); + + if (result.is_success()) { + // Fail if the underlying parser matches + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + + if (result.is_need_more_input()) { + // Propagate - need to know what child would match before negating + return result; + } + + // Child failed, so negation succeeds + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start)); + } + + std::string dump() const override { + return "Not(" + parser_->dump() + ")"; + } + + const parser & child() const { return parser_; } + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + parser_->assign_ids_internal(next_id); + } +}; + +class any_parser : public parser_base { + public: + any_parser(int id) : parser_base(id) {} + + parser_type type() const override { return PARSER_ANY; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + if (start >= ctx.input.size()) { + if (ctx.input_is_complete) { + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + return parser_result(PARSER_RESULT_FAIL, start); + } + + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, start + 1)); + } + + std::string dump() const override { + return "Any"; + } +}; + +class char_class_parser : public parser_base { + struct char_range { + int start; + int end; + + bool contains(char c) const { return (int)c >= start && int(c) <= end; } + }; + + std::string pattern_; + std::vector ranges_; + + public: + char_class_parser(const std::string & classes, int id) : parser_base(id), pattern_(classes) { + std::string content = classes; + if (content.front() == '[') { + content = content.substr(1); + } + + if (content.back() == ']') { + content.pop_back(); + } + + auto parse_char = [&](size_t pos) -> std::pair { + if (content[pos] == '\\' && pos + 1 < content.length()) { + char next = content[pos + 1]; + switch (next) { + case 'n': return {'\n', 2}; + case 't': return {'\t', 2}; + case 'r': return {'\r', 2}; + case '\\': return {'\\', 2}; + case ']': return {']', 2}; + case '-': return {'-', 2}; + case '[': return {'[', 2}; + default: return {next, 2}; // Treat as literal escaped character + } + } + return {content[pos], 1}; + }; + + size_t i = 0; + while (i < content.length()) { + auto [start, start_len] = parse_char(i); + i += start_len; + + if (i + 1 < content.length() && content[i] == '-') { + // Range detected + auto [end, end_len] = parse_char(i + 1); + ranges_.push_back(char_range{start, end}); + i += 1 + end_len; + } else { + ranges_.push_back(char_range{start, start}); + } + } + } + + parser_type type() const override { return PARSER_CHAR_CLASS; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + if (start >= ctx.input.size()) { + if (ctx.input_is_complete) { + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + return parser_result(PARSER_RESULT_FAIL, start); + } + + for (const auto & range : ranges_) { + if (range.contains(ctx.input[start])) { + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, start + 1)); + } + } + + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + + std::string dump() const override { + return "Char(" + pattern_ + ")"; + } +}; + +class group_parser : public parser_base { + std::string name_; + parser parser_; + + public: + group_parser(const std::string & name, const parser & parser, int id) : parser_base(id), name_(name), parser_(parser) {} + + parser_type type() const override { return PARSER_GROUP; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto result = parser_->parse(ctx, start); + + // Store result + result.groups[name_] = parser_match_location{result.start, result.end}; + return ctx.memo.set(id_, start, result); + } + + std::string dump() const override { + return "Group(" + name_ + ", " + parser_->dump() + ")"; + } + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + parser_->assign_ids_internal(next_id); + } +}; + +class rule_parser : public parser_base { + std::string rule_name_; + std::shared_ptr> rules_; + + public: + rule_parser(const std::string & name, std::shared_ptr> rules, int id) + : parser_base(id), rule_name_(name), rules_(std::move(rules)) {} + + parser_type type() const override { return PARSER_RULE; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + if (!rules_) { + LOG_ERR("rule_parser::parse called without rule registry\n"); + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + + auto it = rules_->find(rule_name_); + if (it == rules_->end()) { + LOG_ERR("rule_parser::parse rule '%s' not found in registry\n", rule_name_.c_str()); + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + + auto result = it->second->parse(ctx, start); + return ctx.memo.set(id_, start, result); + } + + std::string dump() const override { + return "Rule(" + rule_name_ + ")"; + } +}; + +std::optional parser_result::group(const std::string & name, std::string_view input) const { + auto it = groups.find(name); + if (it == groups.end()) { + return std::nullopt; + } + + return std::string(it->second.view(input)); +} + +parser_result parse_cache::set(int id, size_t start, parser_result result) { + if (id == -1) { + // Don't cache parsers with ID -1 (from operators and global factory functions) + return result; + } + results[parse_cache_key{id, start}] = result; + return result; +} + +std::optional parse_cache::get(int id, size_t start) { + if (id == -1) { + // Don't cache parsers with ID -1 (from operators and global factory functions) + return std::nullopt; + } + auto it = results.find(parse_cache_key{id, start}); + if (it != results.end()) { + return it->second; + } + return std::nullopt; +} + +void parse_cache::clear() { + results.clear(); +} + +parser::parser() {} + +parser::parser(std::shared_ptr parser) : ptr(std::move(parser)) {} + +parser parser::operator~() const { + return parser(std::make_shared(*this, -1)); +} + +parser parser::operator+(const parser & other) const { + return parser(std::shared_ptr(new sequence_parser({*this, other}, -1))); +} + +parser parser::operator|(const parser & other) const { + return parser(std::shared_ptr(new choice_parser({*this, other}, -1))); +} + +parser_base & parser::operator*() const { + return *ptr; +} + +parser_base * parser::operator->() const { + return ptr.get(); +} + +bool parser::is_sequence() const { + return ptr->type() == PARSER_SEQUENCE; +} + +std::shared_ptr parser::to_sequence() const { + return std::dynamic_pointer_cast(ptr); +} + +bool parser::is_choice() const { + return ptr->type() == PARSER_CHOICE; +} + +std::shared_ptr parser::to_choice() const { + return std::dynamic_pointer_cast(ptr); +} + +parser_type parser::type() const { + return ptr->type(); +} + +parser_result parser::parse(parser_context & ctx, size_t start) const { + return ptr->parse(ctx, start); +} + +std::string parser::dump() const { + return ptr->dump(); +} + +parser_builder::parser_builder() + : rules_(std::make_shared>()) + , next_id_(0) {} + +parser parser_builder::literal(const std::string & literal) { + return parser(std::make_shared(literal, next_id_++)); +} + +parser parser_builder::sequence(std::initializer_list parsers) { + return parser(std::shared_ptr(new sequence_parser(parsers, next_id_++))); +} + +parser parser_builder::choice(std::initializer_list parsers) { + return parser(std::shared_ptr(new choice_parser(parsers, next_id_++))); +} + +parser parser_builder::one_or_more(const parser & p) { + return parser(std::make_shared(p, next_id_++)); +} + +parser parser_builder::zero_or_more(const parser & p) { + return parser(std::make_shared(p, next_id_++)); +} + +parser parser_builder::optional(const parser & p) { + return parser(std::make_shared(p, next_id_++)); +} + +parser parser_builder::negate(const parser & p) { + return parser(std::make_shared(p, next_id_++)); +} + +parser parser_builder::any() { + return parser(std::make_shared(next_id_++)); +} + +parser parser_builder::char_class(const std::string & classes) { + return parser(std::make_shared(classes, next_id_++)); +} + +parser parser_builder::group(const std::string & name, const parser & p) { + return parser(std::make_shared(name, p, next_id_++)); +} + +parser parser_builder::rule(const std::string & name) { + return parser(std::make_shared(name, rules_, next_id_++)); +} + +parser parser_builder::space() { + return zero_or_more(char_class("[ \\t\\n\\r]")); +} + +parser parser_builder::add_rule(const std::string & name, const parser & p) { + (*rules_)[name] = p; + return rule(name); +} + +void parser_builder::assign_ids(parser & p) { + if (p.ptr) { + p.ptr->assign_ids_internal(next_id_); + } +} + +parser parser_builder::add_json_rule(const std::string & name) { + // Whitespace: space, tab, newline, carriage return + auto ws = zero_or_more(char_class("[ \\t\\n\\r]")); + + // Number components + auto digit = char_class("[0-9]"); + auto digit1_9 = char_class("[1-9]"); + auto digits = one_or_more(digit); + + // Integer part: 0 or non-zero digit followed by more digits + auto int_part = literal("0") | (digit1_9 + zero_or_more(digit)); + + // Optional fractional part + auto frac = literal(".") + digits; + + // Optional exponent part + auto exp = (literal("e") | literal("E")) + optional(char_class("[+\\-]")) + digits; + + // Complete number + auto number = optional(literal("-")) + int_part + optional(frac) + optional(exp); + + add_rule("json_number", number); + + // String components + auto hex = char_class("[0-9a-fA-F]"); + auto unicode_escape = literal("\\u") + hex + hex + hex + hex; + auto simple_escape = literal("\\") + char_class("[\"\\\\bfnrt/]"); + auto escape = simple_escape | unicode_escape; + + // String character: escape sequence or any char except quote and backslash + auto string_char = escape | (~char_class("[\"\\\\]") + any()); + auto string = literal("\"") + zero_or_more(string_char) + literal("\""); + + add_rule("json_string", string); + + // Literals + auto true_lit = literal("true"); + auto false_lit = literal("false"); + auto null_lit = literal("null"); + + // Value - uses forward references for recursive structures + add_rule("json_value", + rule("json_object") | + rule("json_array") | + rule("json_string") | + rule("json_number") | + true_lit | + false_lit | + null_lit + ); + + // Object: { "key": value, ... } + auto member = rule("json_string") + ws + literal(":") + ws + rule("json_value"); + auto members = member + zero_or_more(ws + literal(",") + ws + member); + + // Empty object or object with members + auto object = (literal("{") + ws + literal("}")) | + (literal("{") + ws + members + ws + literal("}")); + + add_rule("json_object", object); + + // Array: [ value, ... ] + auto elements = rule("json_value") + zero_or_more(ws + literal(",") + ws + rule("json_value")); + + // Empty array or array with elements + auto array = (literal("[") + ws + literal("]")) | + (literal("[") + ws + elements + ws + literal("]")); + + add_rule("json_array", array); + + // Register the main rule with the provided name + return add_rule(name, rule("json_value")); +} + +parser build_parser(const std::function & fn) { + parser_builder builder; + auto root = fn(builder); + builder.assign_ids(root); // Assign IDs to rules that were created with operators + return root; +} diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h new file mode 100644 index 0000000000000..72adf523c489a --- /dev/null +++ b/common/chat-parser-combinator.h @@ -0,0 +1,158 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +enum parser_type { + PARSER_LITERAL = 0, + PARSER_SEQUENCE = 1, + PARSER_CHOICE = 2, + PARSER_ZERO_OR_MORE = 3, + PARSER_ONE_OR_MORE = 4, + PARSER_NOT = 5, + PARSER_ANY = 6, + PARSER_CHAR_CLASS = 7, + PARSER_GROUP = 8, + PARSER_RULE = 9, + PARSER_OPTIONAL = 10, +}; + +enum parser_result_type { + PARSER_RESULT_FAIL = 0, + PARSER_RESULT_NEED_MORE_INPUT = 1, + PARSER_RESULT_SUCCESS = 2, +}; + +struct parse_cache_key { + int id; + size_t start; + + bool operator==(const parse_cache_key & other) const { + return id == other.id && start == other.start; + } +}; + +template <> +struct std::hash { + std::size_t operator()(const parse_cache_key & k) const { + return std::hash{}(((size_t)k.id << 32) | k.start); + } +}; + +struct parser_match_location { + size_t start; + size_t end; + + size_t length() const { return end - start; } + + std::string_view view(std::string_view sv) const { + return sv.substr(start, length()); + } +}; + +struct parser_result { + parser_result_type type = PARSER_RESULT_FAIL; + size_t start = 0; + size_t end = 0; + + std::unordered_map groups; + + parser_result() : type(PARSER_RESULT_FAIL) {} + parser_result(parser_result_type type, size_t start) : type(type), start(start), end(start) {} + parser_result(parser_result_type type, size_t start, size_t end) : type(type), start(start), end(end) {} + parser_result(parser_result_type type, size_t start, size_t end, const std::unordered_map & groups) : type(type), start(start), end(end), groups(groups) {} + + bool is_fail() const { return type == PARSER_RESULT_FAIL; } + bool is_need_more_input() const { return type == PARSER_RESULT_NEED_MORE_INPUT; } + bool is_success() const { return type == PARSER_RESULT_SUCCESS; } + + std::optional group(const std::string & name, std::string_view input) const; +}; + +class parse_cache { + std::unordered_map results; + + public: + parser_result set(int id, size_t start, parser_result result); + std::optional get(int id, size_t start); + void clear(); +}; + +class parser; + +struct parser_context { + std::string_view input; + parse_cache memo; + bool input_is_complete = true; +}; + +class parser_base; +class sequence_parser; +class choice_parser; +class parser_builder; + +class parser { + std::shared_ptr ptr; + + friend class parser_builder; + + public: + parser(); + parser(std::shared_ptr parser); + parser(const parser & other) = default; + parser & operator=(const parser & other) { + if (this != &other) { + ptr = other.ptr; + } + return *this; + } + + parser operator~() const; + parser operator+(const parser & other) const; + parser operator|(const parser & other) const; + + parser_base & operator*() const; + parser_base * operator->() const; + + bool is_sequence() const; + std::shared_ptr to_sequence() const; + + bool is_choice() const; + std::shared_ptr to_choice() const; + + parser_type type() const; + parser_result parse(parser_context & ctx, size_t start = 0) const; + std::string dump() const; +}; + +class parser_builder { + std::shared_ptr> rules_; + int next_id_; + + public: + parser_builder(); + + parser literal(const std::string & literal); + parser sequence(std::initializer_list parsers); + parser choice(std::initializer_list parsers); + parser one_or_more(const parser & p); + parser zero_or_more(const parser & p); + parser optional(const parser & p); + parser negate(const parser & p); + parser any(); + parser char_class(const std::string & classes); + parser group(const std::string & name, const parser & p); + parser rule(const std::string & name); + parser space(); + + parser add_rule(const std::string & name, const parser & p); + parser add_json_rule(const std::string & name); + + void assign_ids(parser & p); +}; + +parser build_parser(const std::function & fn); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d9cc5e933f4ce..90badf62af667 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -180,6 +180,7 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) endif() llama_build_and_test(test-chat-parser.cpp) +llama_build_and_test(test-chat-parser-combinator.cpp) llama_build_and_test(test-chat-template.cpp) llama_build_and_test(test-json-partial.cpp) llama_build_and_test(test-log.cpp) diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp new file mode 100644 index 0000000000000..55a443aed3e38 --- /dev/null +++ b/tests/test-chat-parser-combinator.cpp @@ -0,0 +1,472 @@ +#include +#include + +#include "chat-parser-combinator.h" + +template +static void assert_equals(const std::string_view label, const T & expected, const T & actual) { + if (expected != actual) { + std::cerr << label << "\n"; + std::cerr << "Expected: " << expected << "\n"; + std::cerr << "Actual: " << actual << "\n"; + std::cerr << std::flush; + throw std::runtime_error("Test failed"); + } +} + +template +static void assert_equals(const T & expected, const T & actual) { + assert_equals("", expected, actual); +} + +static void assert_equals(const char * expected, const std::string & actual) { + assert_equals(expected, actual); +} + +static void test_partial_parsing() { + { + // Test literal + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello"); + }); + + parser_context ctx; + parser_result result; + + ctx = parser_context{"hello", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + } + { + // Test char class + auto parser = build_parser([](parser_builder& p) { + return p.char_class("a-z"); + }); + + parser_context ctx; + parser_result result; + + ctx = parser_context{"a", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"A", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + + parser = build_parser([](parser_builder& p) { + return p.char_class("a-z-"); + }); + + ctx = parser_context{"f", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"-", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"A", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } + { + // Test sequences and literals + auto parser = build_parser([](parser_builder& p) { + return p.literal("") + p.literal(""); + }); + + // Partial matches + auto ctx = parser_context{"", parse_cache(), false}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"", parse_cache(), true}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + // No match, since it does not adhere to the grammar + ctx = parser_context{"I am parser", parse_cache(), false}; + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } + { + // Test choices + auto parser = build_parser([](parser_builder& p) { + return p.literal("") | p.literal(""); + }); + + // Partial matches + auto ctx = parser_context{"", parse_cache(), true}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"", parse_cache(), true}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + // No match + ctx = parser_context{"", parse_cache(), true}; + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } + { + // Test zero_or_more + auto parser = build_parser([](parser_builder& p) { + return p.zero_or_more(p.literal("ab")); + }); + + // Partial matches + auto ctx = parser_context{"a", parse_cache(), false}; + auto result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); + + ctx = parser_context{"aba", parse_cache(), false}; + result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); + + // Full match + ctx = parser_context{"ab", parse_cache(), true}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + } + { + // Test one_or_more + auto parser = build_parser([](parser_builder& p) { + return p.one_or_more(p.literal("ab")); + }); + + // Partial matches + auto ctx = parser_context{"a", parse_cache(), false}; + auto result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); + + ctx = parser_context{"aba", parse_cache(), false}; + result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); + + // Full match + ctx = parser_context{"ab", parse_cache(), true}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + // No match + ctx = parser_context{"cd", parse_cache(), true}; + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } +} + +static void test_capture_groups() { + { + auto parser = build_parser([](parser_builder& p) { + return p.literal("") + + p.group("reasoning_content", + p.zero_or_more(~p.literal("") + p.any()) + ) + + p.literal(""); + }); + + std::string input = "I have a thought"; + auto ctx = parser_context{input, parse_cache()}; + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + + auto it = result.groups.find("reasoning_content"); + assert_equals(true, it != result.groups.end()); + assert_equals("I have a thought", std::string(it->second.view(input))); + } + { + auto parser = build_parser([](parser_builder& p) { + return p.literal("") + + p.group("reasoning_content", + p.zero_or_more(~p.literal("") + p.any()) + ) + + p.literal(""); + }); + + std::string input = "I have a "; + auto ctx = parser_context{input, parse_cache(), false}; + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + + auto it = result.groups.find("reasoning_content"); + assert_equals(true, it != result.groups.end()); + assert_equals("I have a ", std::string(it->second.view(input))); + } + { + auto parser = build_parser([](parser_builder& p) { + return p.literal("") + + p.group("reasoning_content", + p.zero_or_more(~p.literal("") + p.any()) + ) + + p.literal("") + + p.group("content", p.zero_or_more(p.any())); + }); + + std::string input = "The user said hello.Hello!"; + auto ctx = parser_context{input, parse_cache(), true}; + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + + auto it = result.groups.find("reasoning_content"); + assert_equals(true, it != result.groups.end()); + assert_equals("The user said hello.", std::string(it->second.view(input))); + + it = result.groups.find("content"); + assert_equals(true, it != result.groups.end()); + assert_equals("Hello!", std::string(it->second.view(input))); + } +} + +static void test_char_class() { + { + // Test common escape sequences + auto parser = build_parser([](parser_builder& p) { + return p.char_class("[\\n\\t\\\\]"); + }); + + parser_context ctx; + parser_result result; + + ctx = parser_context{"\n", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"\t", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"\\", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{" ", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } + { + // Test escaped dash (literal dash, not a range) + auto parser = build_parser([](parser_builder& p) { + return p.char_class("[a\\-z]"); + }); + + parser_context ctx; + parser_result result; + + ctx = parser_context{"a", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"-", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"z", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Should NOT match 'b' since \- is a literal dash, not a range + ctx = parser_context{"b", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } +} + +static void test_recursive_references() { + auto value_parser = build_parser([](parser_builder& p) { + p.add_rule("number", p.one_or_more(p.char_class("0-9"))); + p.add_rule("list", p.sequence({ + p.literal("["), + p.rule("value"), + p.literal("]") + })); + return p.add_rule("value", p.rule("number") | p.rule("list")); + }); + + parser_context ctx; + parser_result result; + + // Test simple number + ctx = parser_context{"1", parse_cache(), true}; + result = value_parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Test simple list + ctx = parser_context{"[1]", parse_cache(), true}; + result = value_parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Test nested list + ctx = parser_context{"[[2]]", parse_cache(), true}; + result = value_parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Test deeply nested list + ctx = parser_context{"[[[3]]]", parse_cache(), true}; + result = value_parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Test partial match + ctx = parser_context{"[[", parse_cache(), false}; + result = value_parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Test no match + ctx = parser_context{"[a]", parse_cache(), true}; + result = value_parser.parse(ctx); + assert_equals(true, result.is_fail()); +} + +static void test_optional() { + // Test optional with a match + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + // Full match with optional part present + auto ctx = parser_context{"hello world", parse_cache()}; + auto result = parser.parse(ctx); + assert_equals(true, result.is_success()); + assert_equals((size_t)11, result.end); + + // Full match with optional part absent + ctx = parser_context{"hello", parse_cache(), true}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + assert_equals((size_t)5, result.end); + + // Partial match - waiting for more input to determine if optional matches + ctx = parser_context{"hello ", parse_cache(), false}; + result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); +} + +static void test_json_parser() { + auto json = build_parser([](parser_builder & p) { + return p.add_json_rule("json"); + }); + + // Test parsing a simple JSON object + std::string input = R"({"name": "test", "value": 42, "flag": true})"; + parser_context ctx{input, parse_cache()}; + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(input.size(), result.end); +} + +static void test_complete_example() { + auto parser = build_parser([](parser_builder & p) { + auto space = p.add_rule("space", p.space()); + + auto reasoning = p.add_rule("reasoning", + p.literal("") + space + + p.group("reasoning-content", + p.zero_or_more(~(space + p.literal("")) + p.any())) + + space + p.literal("")); + + auto content = p.add_rule("content", + p.group("content", + p.zero_or_more(~(space + p.literal("")) + p.any()))); + + auto ident_chars = p.add_rule("ident-chars", p.char_class("[a-zA-Z\\-_]")); + auto json = p.add_json_rule("json"); + + auto tool_call_name = p.add_rule("tool-call-name", + p.literal("") + space + + p.group("tool-name", p.one_or_more(~p.literal("") + ident_chars)) + + space + p.literal("")); + + auto tool_call_args = p.add_rule("tool-call-args", + p.literal("") + space + + p.group("tool-args", json) + + space + p.literal("")); + + auto tool_call = p.add_rule("tool-call", + p.literal("") + space + + tool_call_name + space + + tool_call_args + space + + p.literal("")); + + return p.add_rule("root", reasoning + p.optional(content) + p.optional(tool_call)); + }); + + // Test complete input + std::string input = R"(I need to call get_weather with city = New Yorkget_weather{"city": "New York"})"; + parser_context ctx{input, parse_cache()}; + + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(input.size(), result.end); + assert_equals(std::string("I need to call get_weather with city = New York"), *result.group("reasoning-content", ctx.input)); + assert_equals(std::string("get_weather"), *result.group("tool-name", ctx.input)); + assert_equals(std::string(R"({"city": "New York"})"), *result.group("tool-args", ctx.input)); + + // Test partial input + input = R"(I need to call get_weather )"; + ctx = parser_context{input, parse_cache(), /* .is_input_complete = */ false}; + result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(std::string("I need to call get_weather"), *result.group("reasoning-content", ctx.input)); + + input = R"(I need to call get_weatherget_weather)"; + ctx = parser_context{input, parse_cache(), /* .is_input_complete = */ false}; + result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(std::string("I need to call get_weather"), *result.group("reasoning-content", ctx.input)); + + input = R"(I need to call get_weatherget_weatherI need to call get_weatherget_weather{"cit)"; + ctx = parser_context{input, parse_cache(), /* .is_input_complete = */ false}; + result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(std::string("I need to call get_weather"), *result.group("reasoning-content", ctx.input)); + assert_equals(std::string("get_weather"), *result.group("tool-name", ctx.input)); + assert_equals(std::string(R"({"cit)"), *result.group("tool-args", ctx.input)); +} + +int main() { + test_partial_parsing(); + test_char_class(); + test_capture_groups(); + test_recursive_references(); + test_optional(); + test_json_parser(); + test_complete_example(); + std::cout << "All tests passed!\n"; + return 0; +} From e6153bb14a0728ed7fcce7dadf54c3c7f670dca0 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Sun, 9 Nov 2025 22:34:24 -0600 Subject: [PATCH 02/16] add virtual destructor to parser_base --- common/chat-parser-combinator.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index f2182980b3bac..a8d8b10fe4fcd 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -13,6 +13,7 @@ class parser_base { public: parser_base(int id) : id_(id) {} + virtual ~parser_base() = default; virtual parser_type type() const = 0; virtual parser_result parse(parser_context & ctx, size_t start = 0) = 0; From 4ced9996e65817c0da27acca05d64b3acb6a6a08 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Sun, 9 Nov 2025 22:52:20 -0600 Subject: [PATCH 03/16] fix memory leak from circular references of rules --- common/chat-parser-combinator.cpp | 53 +++++++++++++++++++++++++------ common/chat-parser-combinator.h | 2 ++ 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index a8d8b10fe4fcd..cd915afcd7b88 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -555,11 +555,11 @@ class group_parser : public parser_base { class rule_parser : public parser_base { std::string rule_name_; - std::shared_ptr> rules_; + std::weak_ptr> rules_; public: rule_parser(const std::string & name, std::shared_ptr> rules, int id) - : parser_base(id), rule_name_(name), rules_(std::move(rules)) {} + : parser_base(id), rule_name_(name), rules_(rules) {} parser_type type() const override { return PARSER_RULE; } @@ -569,13 +569,14 @@ class rule_parser : public parser_base { return *cached; } - if (!rules_) { - LOG_ERR("rule_parser::parse called without rule registry\n"); + auto rules = rules_.lock(); + if (!rules) { + LOG_ERR("rule_parser::parse called with expired rule registry\n"); return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); } - auto it = rules_->find(rule_name_); - if (it == rules_->end()) { + auto it = rules->find(rule_name_); + if (it == rules->end()) { LOG_ERR("rule_parser::parse rule '%s' not found in registry\n", rule_name_.c_str()); return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); } @@ -589,6 +590,32 @@ class rule_parser : public parser_base { } }; +class root_parser : public parser_base { + parser root_; + std::shared_ptr> rules_; + + public: + root_parser(const parser & root, std::shared_ptr> rules, int id) + : parser_base(id), root_(root), rules_(std::move(rules)) {} + + parser_type type() const override { return root_->type(); } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + return root_->parse(ctx, start); + } + + std::string dump() const override { + return root_->dump(); + } + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + root_->assign_ids_internal(next_id); + } +}; + std::optional parser_result::group(const std::string & name, std::string_view input) const { auto it = groups.find(name); if (it == groups.end()) { @@ -632,11 +659,11 @@ parser parser::operator~() const { } parser parser::operator+(const parser & other) const { - return parser(std::shared_ptr(new sequence_parser({*this, other}, -1))); + return parser(std::make_shared(std::initializer_list{*this, other}, -1)); } parser parser::operator|(const parser & other) const { - return parser(std::shared_ptr(new choice_parser({*this, other}, -1))); + return parser(std::make_shared(std::initializer_list{*this, other}, -1)); } parser_base & parser::operator*() const { @@ -684,11 +711,11 @@ parser parser_builder::literal(const std::string & literal) { } parser parser_builder::sequence(std::initializer_list parsers) { - return parser(std::shared_ptr(new sequence_parser(parsers, next_id_++))); + return parser(std::make_shared(parsers, next_id_++)); } parser parser_builder::choice(std::initializer_list parsers) { - return parser(std::shared_ptr(new choice_parser(parsers, next_id_++))); + return parser(std::make_shared(parsers, next_id_++)); } parser parser_builder::one_or_more(const parser & p) { @@ -816,5 +843,11 @@ parser build_parser(const std::function & fn) { parser_builder builder; auto root = fn(builder); builder.assign_ids(root); // Assign IDs to rules that were created with operators + + // Wrap the root parser in a root_parser to own the rules and break circular references + auto rules = builder.rules(); + if (rules && !rules->empty()) { + return parser(std::make_shared(root, rules, -1)); + } return root; } diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 72adf523c489a..edebd0bef75db 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -153,6 +153,8 @@ class parser_builder { parser add_json_rule(const std::string & name); void assign_ids(parser & p); + + std::shared_ptr> rules() const { return rules_; } }; parser build_parser(const std::function & fn); From 2a9a13de753dafa7e3caa64b9cd5dafcdf8bda49 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 03:44:21 -0600 Subject: [PATCH 04/16] implement gbnf grammar building --- common/chat-parser-combinator.cpp | 570 +++++++++++++++++++++++--- common/chat-parser-combinator.h | 14 +- tests/test-chat-parser-combinator.cpp | 276 +++++++++++-- 3 files changed, 782 insertions(+), 78 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index cd915afcd7b88..897c4f6f75b4b 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -1,10 +1,17 @@ #include "chat-parser-combinator.h" +#include "json-schema-to-grammar.h" #include "common.h" #include "log.h" +#include + #include #include +class gbnf_visitor; + +static parser json_parser(); + class parser_base { protected: int id_; @@ -18,6 +25,7 @@ class parser_base { virtual parser_type type() const = 0; virtual parser_result parse(parser_context & ctx, size_t start = 0) = 0; virtual std::string dump() const = 0; + virtual std::string accept(gbnf_visitor & visitor) const = 0; virtual void assign_ids_internal(int& next_id) { if (id_ == -1) { id_ = next_id++; @@ -28,6 +36,8 @@ class parser_base { class literal_parser : public parser_base { std::string literal_; + friend class gbnf_visitor; + public: literal_parser(const std::string & literal, int id) : parser_base(id), literal_(literal) {} @@ -62,11 +72,15 @@ class literal_parser : public parser_base { std::string dump() const override { return "Literal(" + literal_ + ")"; } + + std::string accept(gbnf_visitor & visitor) const override; }; class sequence_parser : public parser_base { std::vector parsers_; + friend class gbnf_visitor; + public: sequence_parser(std::initializer_list parsers, int id) : parser_base(id) { for (const auto & p : parsers) { @@ -125,6 +139,8 @@ class sequence_parser : public parser_base { return "Sequence(" + string_join(parts, ", ") + ")"; } + std::string accept(gbnf_visitor & visitor) const override; + const std::vector & parsers() const { return parsers_; } void assign_ids_internal(int& next_id) override { @@ -140,6 +156,8 @@ class sequence_parser : public parser_base { class choice_parser : public parser_base { std::vector parsers_; + friend class gbnf_visitor; + public: choice_parser(std::initializer_list parsers, int id) : parser_base(id) { for (const auto & p : parsers) { @@ -187,6 +205,8 @@ class choice_parser : public parser_base { return "Choice(" + string_join(parts, ", ") + ")"; } + std::string accept(gbnf_visitor & visitor) const override; + const std::vector & parsers() const { return parsers_; } void assign_ids_internal(int& next_id) override { @@ -202,6 +222,8 @@ class choice_parser : public parser_base { class one_or_more_parser : public parser_base { parser parser_; + friend class gbnf_visitor; + public: one_or_more_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} @@ -257,6 +279,8 @@ class one_or_more_parser : public parser_base { return "OneOrMore(" + parser_->dump() + ")"; } + std::string accept(gbnf_visitor & visitor) const override; + const parser & child() const { return parser_; } void assign_ids_internal(int& next_id) override { @@ -270,6 +294,8 @@ class one_or_more_parser : public parser_base { class zero_or_more_parser : public parser_base { parser parser_; + friend class gbnf_visitor; + public: zero_or_more_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} @@ -315,6 +341,8 @@ class zero_or_more_parser : public parser_base { return "ZeroOrMore(" + parser_->dump() + ")"; } + std::string accept(gbnf_visitor & visitor) const override; + const parser & child() const { return parser_; } void assign_ids_internal(int& next_id) override { @@ -328,6 +356,8 @@ class zero_or_more_parser : public parser_base { class optional_parser : public parser_base { parser parser_; + friend class gbnf_visitor; + public: optional_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} @@ -359,6 +389,8 @@ class optional_parser : public parser_base { return "Optional(" + parser_->dump() + ")"; } + std::string accept(gbnf_visitor & visitor) const override; + const parser & child() const { return parser_; } void assign_ids_internal(int& next_id) override { @@ -369,9 +401,55 @@ class optional_parser : public parser_base { } }; +class until_parser : public parser_base { + std::string delimiter_; + bool include_spaces_; + parser parser_; + + friend class gbnf_visitor; + + public: + until_parser(const std::string & delimiter, bool include_spaces, int id, parser_builder & builder) + : parser_base(id), delimiter_(delimiter), include_spaces_(include_spaces) { + if (include_spaces) { + auto ws = builder.zero_or_more(builder.char_class("[ \\t\\n\\r]")); + parser_ = builder.zero_or_more(builder.negate(ws + builder.literal(delimiter)) + builder.any()); + } else { + parser_ = builder.zero_or_more(builder.negate(builder.literal(delimiter)) + builder.any()); + } + } + + parser_type type() const override { return PARSER_UNTIL; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + auto result = parser_->parse(ctx, start); + return ctx.memo.set(id_, start, result); + } + + std::string dump() const override { + return "Until(" + delimiter_ + ")"; + } + + std::string accept(gbnf_visitor & visitor) const override; + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + parser_->assign_ids_internal(next_id); + } +}; + class not_parser : public parser_base { parser parser_; + friend class gbnf_visitor; + public: not_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} @@ -403,6 +481,8 @@ class not_parser : public parser_base { return "Not(" + parser_->dump() + ")"; } + std::string accept(gbnf_visitor & visitor) const override; + const parser & child() const { return parser_; } void assign_ids_internal(int& next_id) override { @@ -414,6 +494,8 @@ class not_parser : public parser_base { }; class any_parser : public parser_base { + friend class gbnf_visitor; + public: any_parser(int id) : parser_base(id) {} @@ -438,6 +520,42 @@ class any_parser : public parser_base { std::string dump() const override { return "Any"; } + + std::string accept(gbnf_visitor & visitor) const override; +}; + +class space_parser : public parser_base { + friend class gbnf_visitor; + + public: + space_parser(int id) : parser_base(id) {} + + parser_type type() const override { return PARSER_SPACE; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + auto pos = start; + while (pos < ctx.input.size()) { + char c = ctx.input[pos]; + if (c == ' ' || c == '\t' || c == '\n') { + ++pos; + } else { + break; + } + } + + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos)); + } + + std::string dump() const override { + return "Space"; + } + + std::string accept(gbnf_visitor & visitor) const override; }; class char_class_parser : public parser_base { @@ -450,9 +568,12 @@ class char_class_parser : public parser_base { std::string pattern_; std::vector ranges_; + bool negated_; + + friend class gbnf_visitor; public: - char_class_parser(const std::string & classes, int id) : parser_base(id), pattern_(classes) { + char_class_parser(const std::string & classes, int id) : parser_base(id), pattern_(classes), negated_(false) { std::string content = classes; if (content.front() == '[') { content = content.substr(1); @@ -462,6 +583,12 @@ class char_class_parser : public parser_base { content.pop_back(); } + // Check for negation + if (!content.empty() && content.front() == '^') { + negated_ = true; + content = content.substr(1); + } + auto parse_char = [&](size_t pos) -> std::pair { if (content[pos] == '\\' && pos + 1 < content.length()) { char next = content[pos + 1]; @@ -510,24 +637,39 @@ class char_class_parser : public parser_base { return parser_result(PARSER_RESULT_FAIL, start); } + bool matches = false; for (const auto & range : ranges_) { if (range.contains(ctx.input[start])) { - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, start + 1)); + matches = true; + break; } } + // If negated, invert the match result + if (negated_) { + matches = !matches; + } + + if (matches) { + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, start + 1)); + } + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); } std::string dump() const override { return "Char(" + pattern_ + ")"; } + + std::string accept(gbnf_visitor & visitor) const override; }; class group_parser : public parser_base { std::string name_; parser parser_; + friend class gbnf_visitor; + public: group_parser(const std::string & name, const parser & parser, int id) : parser_base(id), name_(name), parser_(parser) {} @@ -545,6 +687,8 @@ class group_parser : public parser_base { return "Group(" + name_ + ", " + parser_->dump() + ")"; } + std::string accept(gbnf_visitor & visitor) const override; + void assign_ids_internal(int& next_id) override { if (id_ == -1) { id_ = next_id++; @@ -553,10 +697,36 @@ class group_parser : public parser_base { } }; +class schema_parser : public parser_base { + parser parser_; + std::string name_; + nlohmann::ordered_json schema_; + + friend class gbnf_visitor; + + public: + schema_parser(const parser & parser, const std::string & name, const nlohmann::ordered_json & schema, int id) + : parser_base(id), parser_(parser), name_(name), schema_(schema) {} + + parser_type type() const override { return PARSER_SCHEMA; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + return parser_->parse(ctx, start); + } + + std::string dump() const override { + return "Schema(" + parser_->dump() + ", " + schema_.dump() + ")"; + } + + std::string accept(gbnf_visitor & visitor) const override; +}; + class rule_parser : public parser_base { std::string rule_name_; std::weak_ptr> rules_; + friend class gbnf_visitor; + public: rule_parser(const std::string & name, std::shared_ptr> rules, int id) : parser_base(id), rule_name_(name), rules_(rules) {} @@ -588,12 +758,16 @@ class rule_parser : public parser_base { std::string dump() const override { return "Rule(" + rule_name_ + ")"; } + + std::string accept(gbnf_visitor & visitor) const override; }; class root_parser : public parser_base { parser root_; std::shared_ptr> rules_; + friend class gbnf_visitor; + public: root_parser(const parser & root, std::shared_ptr> rules, int id) : parser_base(id), root_(root), rules_(std::move(rules)) {} @@ -608,6 +782,8 @@ class root_parser : public parser_base { return root_->dump(); } + std::string accept(gbnf_visitor & visitor) const override; + void assign_ids_internal(int& next_id) override { if (id_ == -1) { id_ = next_id++; @@ -616,6 +792,269 @@ class root_parser : public parser_base { } }; +class gbnf_visitor { + common_grammar_builder& builder_; + std::unordered_map rule_name_mapping_; + + public: + gbnf_visitor(common_grammar_builder& builder) : builder_(builder) {} + + private: + // Escape special characters for GBNF literals + static std::string escape_literal(const std::string & s) { + std::string escaped; + for (char c : s) { + switch (c) { + case '\n': escaped += "\\n"; break; + case '\t': escaped += "\\t"; break; + case '\r': escaped += "\\r"; break; + case '\\': escaped += "\\\\"; break; + case '"': escaped += "\\\""; break; + default: escaped += c; break; + } + } + return escaped; + } + + // Escape a single character for use in character classes + static std::string escape_char_class(char c) { + switch (c) { + case '\n': return "\\n"; + case '\t': return "\\t"; + case '\r': return "\\r"; + case '\\': return "\\\\"; + case ']': return "\\]"; + case '-': return "\\-"; + case '^': return "\\^"; + default: return std::string(1, c); + } + } + + // Generate pattern for until() that matches prefixes but prevents full delimiter match + // For "" generates: ( [^<] | "<" [^/] | " alternatives; + + // First alternative: match any character that's not the start of the delimiter + alternatives.push_back("[^" + escape_char_class(delimiter[0]) + "]"); + + // For each prefix, match the prefix followed by a char that's not the next delimiter char + for (size_t i = 1; i < delimiter.length(); ++i) { + std::string prefix = "\"" + escape_literal(delimiter.substr(0, i)) + "\""; + std::string next_char_negated = "[^" + escape_char_class(delimiter[i]) + "]"; + alternatives.push_back(prefix + " " + next_char_negated); + } + + // Combine alternatives with | + std::string result = "("; + for (size_t i = 0; i < alternatives.size(); ++i) { + if (i > 0) { + result += " | "; + } + result += alternatives[i]; + } + result += ")"; + + return result; + } + + // Check if expression needs parentheses + static bool needs_parens(parser_type type) { + return type == PARSER_CHOICE || type == PARSER_SEQUENCE; + } + + public: + std::string visit(const literal_parser & p) { + return "\"" + escape_literal(p.literal_) + "\""; + } + + std::string visit(const sequence_parser & p) { + std::string s; + for (size_t i = 0; i < p.parsers_.size(); ++i) { + if (i > 0) s += " "; + auto child_result = p.parsers_[i]->accept(*this); + s += child_result; + } + return s; + } + + std::string visit(const choice_parser & p) { + std::string s; + for (size_t i = 0; i < p.parsers_.size(); ++i) { + if (i > 0) { + s += " | "; + } + + auto child_type = p.parsers_[i]->type(); + auto child_result = p.parsers_[i]->accept(*this); + + // Parenthesize sequences in choices + if (child_type == PARSER_SEQUENCE) { + s += "(" + child_result + ")"; + } else { + s += child_result; + } + } + return s; + } + + std::string visit(const one_or_more_parser & p) { + auto child_type = p.parser_->type(); + auto child_result = p.parser_->accept(*this); + if (needs_parens(child_type)) { + return "(" + child_result + ")+"; + } + return child_result + "+"; + } + + std::string visit(const zero_or_more_parser & p) { + auto child_type = p.parser_->type(); + auto child_result = p.parser_->accept(*this); + if (needs_parens(child_type)) { + return "(" + child_result + ")*"; + } + return child_result + "*"; + } + + std::string visit(const optional_parser & p) { + auto child_type = p.parser_->type(); + auto child_result = p.parser_->accept(*this); + if (needs_parens(child_type)) { + return "(" + child_result + ")?"; + } + return child_result + "?"; + } + + std::string visit(const until_parser & p) { + // Generate pattern that matches prefixes but prevents full delimiter match + return generate_until_pattern(p.delimiter_) + "*"; + } + + std::string visit(const not_parser &) { + // NOT is tricky in GBNF - for now, emit error + LOG_ERR("NOT operator not directly supported in GBNF generation\n"); + return ""; // This will cause compilation errors, which is intended + } + + std::string visit(const any_parser &) { + // Match any single character + return "[\\x00-\\x{10FFFF}]"; + } + + std::string visit(const space_parser &) { + // Reference the built-in space rule + return "space"; + } + + std::string visit(const char_class_parser & p) { + // Return pattern as-is (already in GBNF format) + return p.pattern_; + } + + std::string visit(const group_parser & p) { + // Groups are transparent - just visit child + return p.parser_->accept(*this); + } + + std::string visit(const schema_parser & p) { + return builder_.add_schema(p.name_, p.schema_); + } + + std::string visit(const rule_parser & p) { + // Return canonical rule reference + auto it = rule_name_mapping_.find(p.rule_name_); + if (it != rule_name_mapping_.end()) { + return it->second; + } + // Fallback to original name if not in mapping (shouldn't happen in valid usage) + return p.rule_name_; + } + + std::string visit(const root_parser & p) { + // Generate named rules first + if (p.rules_) { + for (const auto & [name, rule] : *p.rules_) { + auto rule_body = rule->accept(*this); + auto canonical_name = builder_.add_rule(name, rule_body); + rule_name_mapping_[name] = canonical_name; + } + } + + // Return root body for composition + return p.root_->accept(*this); + } +}; + +// Implement accept() methods for all parser classes +std::string literal_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string sequence_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string choice_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string one_or_more_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string zero_or_more_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string optional_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string until_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string not_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string any_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string space_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string char_class_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string group_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string schema_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string rule_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string root_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + std::optional parser_result::group(const std::string & name, std::string_view input) const { auto it = groups.find(name); if (it == groups.end()) { @@ -666,6 +1105,11 @@ parser parser::operator|(const parser & other) const { return parser(std::make_shared(std::initializer_list{*this, other}, -1)); } +parser parser::operator<<(const parser & other) const { + auto ws = parser(std::make_shared(-1)); + return parser(std::make_shared(std::initializer_list{*this, ws, other}, -1)); +} + parser_base & parser::operator*() const { return *ptr; } @@ -702,6 +1146,16 @@ std::string parser::dump() const { return ptr->dump(); } +void parser::build_grammar(common_grammar_builder& builder) const { + gbnf_visitor visitor(builder); + auto result = ptr->accept(visitor); + // The visitor returns the GBNF string for this parser + // root_parser registers its named rules and returns its root body + if (!result.empty()) { + builder.add_rule("root", result); + } +} + parser_builder::parser_builder() : rules_(std::make_shared>()) , next_id_(0) {} @@ -751,7 +1205,15 @@ parser parser_builder::rule(const std::string & name) { } parser parser_builder::space() { - return zero_or_more(char_class("[ \\t\\n\\r]")); + return parser(std::make_shared(next_id_++)); +} + +parser parser_builder::until(const std::string & delimiter, bool include_spaces) { + return parser(std::make_shared(delimiter, include_spaces, next_id_++, *this)); +} + +parser parser_builder::schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema) { + return parser(std::make_shared(p, name, schema, next_id_++)); } parser parser_builder::add_rule(const std::string & name, const parser & p) { @@ -765,89 +1227,99 @@ void parser_builder::assign_ids(parser & p) { } } -parser parser_builder::add_json_rule(const std::string & name) { +parser build_parser(const std::function & fn) { + parser_builder builder; + auto root = fn(builder); + builder.assign_ids(root); // Assign IDs to rules that were created with operators + + // Wrap the root parser in a root_parser to own the rules and break circular references + auto rules = builder.rules(); + if (rules && !rules->empty()) { + return parser(std::make_shared(root, rules, -1)); + } + return root; +} + +static parser json_parser() { + parser_builder builder; + // Whitespace: space, tab, newline, carriage return - auto ws = zero_or_more(char_class("[ \\t\\n\\r]")); + auto ws = builder.zero_or_more(builder.char_class("[ \\t\\n\\r]")); // Number components - auto digit = char_class("[0-9]"); - auto digit1_9 = char_class("[1-9]"); - auto digits = one_or_more(digit); + auto digit = builder.char_class("[0-9]"); + auto digit1_9 = builder.char_class("[1-9]"); + auto digits = builder.one_or_more(digit); // Integer part: 0 or non-zero digit followed by more digits - auto int_part = literal("0") | (digit1_9 + zero_or_more(digit)); + auto int_part = builder.literal("0") | (digit1_9 + builder.zero_or_more(digit)); // Optional fractional part - auto frac = literal(".") + digits; + auto frac = builder.literal(".") + digits; // Optional exponent part - auto exp = (literal("e") | literal("E")) + optional(char_class("[+\\-]")) + digits; + auto exp = (builder.literal("e") | builder.literal("E")) + builder.optional(builder.char_class("[+\\-]")) + digits; // Complete number - auto number = optional(literal("-")) + int_part + optional(frac) + optional(exp); + auto number = builder.optional(builder.literal("-")) + int_part + builder.optional(frac) + builder.optional(exp); - add_rule("json_number", number); + builder.add_rule("json_number", number); // String components - auto hex = char_class("[0-9a-fA-F]"); - auto unicode_escape = literal("\\u") + hex + hex + hex + hex; - auto simple_escape = literal("\\") + char_class("[\"\\\\bfnrt/]"); + auto hex = builder.char_class("[0-9a-fA-F]"); + auto unicode_escape = builder.literal("\\u") + hex + hex + hex + hex; + auto simple_escape = builder.literal("\\") + builder.char_class("[\"\\\\bfnrt/]"); auto escape = simple_escape | unicode_escape; // String character: escape sequence or any char except quote and backslash - auto string_char = escape | (~char_class("[\"\\\\]") + any()); - auto string = literal("\"") + zero_or_more(string_char) + literal("\""); + auto string_char = escape | builder.char_class("[^\"\\\\]"); + auto string = builder.literal("\"") + builder.zero_or_more(string_char) + builder.literal("\""); - add_rule("json_string", string); + builder.add_rule("json_string", string); // Literals - auto true_lit = literal("true"); - auto false_lit = literal("false"); - auto null_lit = literal("null"); + auto true_lit = builder.literal("true"); + auto false_lit = builder.literal("false"); + auto null_lit = builder.literal("null"); // Value - uses forward references for recursive structures - add_rule("json_value", - rule("json_object") | - rule("json_array") | - rule("json_string") | - rule("json_number") | + builder.add_rule("json_value", + builder.rule("json_object") | + builder.rule("json_array") | + builder.rule("json_string") | + builder.rule("json_number") | true_lit | false_lit | null_lit ); // Object: { "key": value, ... } - auto member = rule("json_string") + ws + literal(":") + ws + rule("json_value"); - auto members = member + zero_or_more(ws + literal(",") + ws + member); + auto member = builder.rule("json_string") + ws + builder.literal(":") + ws + builder.rule("json_value"); + auto members = member + builder.zero_or_more(ws + builder.literal(",") + ws + member); // Empty object or object with members - auto object = (literal("{") + ws + literal("}")) | - (literal("{") + ws + members + ws + literal("}")); + auto object = (builder.literal("{") + ws + builder.literal("}")) | + (builder.literal("{") + ws + members + ws + builder.literal("}")); - add_rule("json_object", object); + builder.add_rule("json_object", object); // Array: [ value, ... ] - auto elements = rule("json_value") + zero_or_more(ws + literal(",") + ws + rule("json_value")); + auto elements = builder.rule("json_value") + builder.zero_or_more(ws + builder.literal(",") + ws + builder.rule("json_value")); // Empty array or array with elements - auto array = (literal("[") + ws + literal("]")) | - (literal("[") + ws + elements + ws + literal("]")); + auto array = (builder.literal("[") + ws + builder.literal("]")) | + (builder.literal("[") + ws + elements + ws + builder.literal("]")); - add_rule("json_array", array); + builder.add_rule("json_array", array); - // Register the main rule with the provided name - return add_rule(name, rule("json_value")); -} + // Get the json_value rule as the root + auto root = builder.rule("json_value"); + builder.assign_ids(root); -parser build_parser(const std::function & fn) { - parser_builder builder; - auto root = fn(builder); - builder.assign_ids(root); // Assign IDs to rules that were created with operators + // Wrap in root_parser to own the rules + return parser(std::make_shared(root, builder.rules(), -1)); +} - // Wrap the root parser in a root_parser to own the rules and break circular references - auto rules = builder.rules(); - if (rules && !rules->empty()) { - return parser(std::make_shared(root, rules, -1)); - } - return root; +parser parser_builder::json() { + return json_parser(); } diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index edebd0bef75db..1ef3996dc862a 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -1,5 +1,7 @@ #pragma once +#include + #include #include #include @@ -7,6 +9,8 @@ #include #include +struct common_grammar_builder; + enum parser_type { PARSER_LITERAL = 0, PARSER_SEQUENCE = 1, @@ -19,6 +23,9 @@ enum parser_type { PARSER_GROUP = 8, PARSER_RULE = 9, PARSER_OPTIONAL = 10, + PARSER_UNTIL = 11, + PARSER_SPACE = 12, + PARSER_SCHEMA = 13, }; enum parser_result_type { @@ -94,6 +101,7 @@ class parser_base; class sequence_parser; class choice_parser; class parser_builder; +class gbnf_visitor; class parser { std::shared_ptr ptr; @@ -114,6 +122,7 @@ class parser { parser operator~() const; parser operator+(const parser & other) const; parser operator|(const parser & other) const; + parser operator<<(const parser & other) const; parser_base & operator*() const; parser_base * operator->() const; @@ -127,6 +136,7 @@ class parser { parser_type type() const; parser_result parse(parser_context & ctx, size_t start = 0) const; std::string dump() const; + void build_grammar(common_grammar_builder& builder) const; }; class parser_builder { @@ -148,9 +158,11 @@ class parser_builder { parser group(const std::string & name, const parser & p); parser rule(const std::string & name); parser space(); + parser until(const std::string & delimiter, bool include_spaces = true); + parser json(); + parser schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema); parser add_rule(const std::string & name, const parser & p); - parser add_json_rule(const std::string & name); void assign_ids(parser & p); diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index 55a443aed3e38..1fffbfbf040a0 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -2,6 +2,9 @@ #include #include "chat-parser-combinator.h" +#include "json-schema-to-grammar.h" +#include "nlohmann/json.hpp" +#include "nlohmann/json_fwd.hpp" template static void assert_equals(const std::string_view label, const T & expected, const T & actual) { @@ -365,53 +368,110 @@ static void test_optional() { static void test_json_parser() { auto json = build_parser([](parser_builder & p) { - return p.add_json_rule("json"); + return p.json(); }); - // Test parsing a simple JSON object - std::string input = R"({"name": "test", "value": 42, "flag": true})"; - parser_context ctx{input, parse_cache()}; + { + // Test parsing a simple JSON object + std::string input = R"({"name": "test", "value": 42, "flag": true})"; + parser_context ctx{input, parse_cache()}; - auto result = json.parse(ctx); + auto result = json.parse(ctx); - assert_equals(true, result.is_success()); - assert_equals(input.size(), result.end); + assert_equals(true, result.is_success()); + assert_equals(input.size(), result.end); + } + { + // Test parsing a JSON array with mixed types + std::string input = R"([1, "hello", true, null, 3.14])"; + parser_context ctx{input, parse_cache()}; + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(input.size(), result.end); + } + { + // Test parsing nested JSON with objects and arrays + std::string input = R"({"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], "count": 2, "metadata": {"version": "1.0", "tags": ["admin", "user"]}})"; + parser_context ctx{input, parse_cache()}; + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(input.size(), result.end); + } + { + // Test partial parsing - incomplete object + std::string input = R"({"name": "test", "value": )"; + parser_context ctx{input, parse_cache(), false}; + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + } + { + // Test partial parsing - incomplete array + std::string input = R"([1, 2, 3, )"; + parser_context ctx{input, parse_cache(), false}; + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + } + { + // Test partial parsing - incomplete nested structure + std::string input = R"({"data": {"nested": )"; + parser_context ctx{input, parse_cache(), false}; + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + } } static void test_complete_example() { + // Parser for a fictitious model that outputs: + // + // + // ... reasoning content ... + // + // ... content ... + // + // tool_name + // { ... json args ... } + // + // auto parser = build_parser([](parser_builder & p) { - auto space = p.add_rule("space", p.space()); - auto reasoning = p.add_rule("reasoning", - p.literal("") + space + - p.group("reasoning-content", - p.zero_or_more(~(space + p.literal("")) + p.any())) + - space + p.literal("")); + p.literal("") + << p.group("reasoning-content", p.until("")) + << p.literal("")); auto content = p.add_rule("content", - p.group("content", - p.zero_or_more(~(space + p.literal("")) + p.any()))); + p.group("content", p.until(""))); - auto ident_chars = p.add_rule("ident-chars", p.char_class("[a-zA-Z\\-_]")); - auto json = p.add_json_rule("json"); + auto json = p.json(); auto tool_call_name = p.add_rule("tool-call-name", - p.literal("") + space + - p.group("tool-name", p.one_or_more(~p.literal("") + ident_chars)) + - space + p.literal("")); + p.literal("") + << p.group("tool-name", p.one_or_more(p.char_class("[a-zA-Z\\-_]"))) + << p.literal("")); + + auto schema = nlohmann::ordered_json::parse(R"({"type": "object"})"); auto tool_call_args = p.add_rule("tool-call-args", - p.literal("") + space + - p.group("tool-args", json) + - space + p.literal("")); + p.literal("") + << p.group("tool-args", p.schema(json, "get_weather", schema)) + << p.literal("")); auto tool_call = p.add_rule("tool-call", - p.literal("") + space + - tool_call_name + space + - tool_call_args + space + - p.literal("")); + p.literal("") + << tool_call_name + << tool_call_args + << p.literal("")); - return p.add_rule("root", reasoning + p.optional(content) + p.optional(tool_call)); + return reasoning << p.optional(content) << p.optional(tool_call); }); // Test complete input @@ -457,6 +517,165 @@ static void test_complete_example() { assert_equals(std::string("I need to call get_weather"), *result.group("reasoning-content", ctx.input)); assert_equals(std::string("get_weather"), *result.group("tool-name", ctx.input)); assert_equals(std::string(R"({"cit)"), *result.group("tool-args", ctx.input)); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + + std::cout << "Grammar:\n" << gbnf << "\n"; +} + +static void test_gbnf_generation() { + { + // Test literal + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello"); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= \"hello\"") != std::string::npos); + assert_equals(true, gbnf.find("space ::=") != std::string::npos); + } + { + // Test char class + auto parser = build_parser([](parser_builder& p) { + return p.char_class("[a-z]"); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= [a-z]") != std::string::npos); + } + { + // Test sequence + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello") + p.literal(" ") + p.literal("world"); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= \"hello\" \" \" \"world\"") != std::string::npos); + } + { + // Test choice + auto parser = build_parser([](parser_builder& p) { + return p.literal("cat") | p.literal("dog"); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= \"cat\" | \"dog\"") != std::string::npos); + } + { + // Test one_or_more + auto parser = build_parser([](parser_builder& p) { + return p.one_or_more(p.char_class("[0-9]")); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= [0-9]+") != std::string::npos); + } + { + // Test zero_or_more + auto parser = build_parser([](parser_builder& p) { + return p.zero_or_more(p.char_class("[a-z]")); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= [a-z]*") != std::string::npos); + } + { + // Test optional + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= \"hello\" \" world\"?") != std::string::npos); + } + { + // Test until + auto parser = build_parser([](parser_builder& p) { + return p.until(""); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + // Should generate pattern that prevents matching the full delimiter + assert_equals(true, gbnf.find("root ::= ([^<] | \"<\" [^/] | \"])*") != std::string::npos); + } + { + // Test groups are transparent + auto parser = build_parser([](parser_builder& p) { + return p.group("test", p.literal("hello")); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= \"hello\"") != std::string::npos); + } + { + // Test complex expression with parentheses + auto parser = build_parser([](parser_builder& p) { + return p.one_or_more(p.literal("a") | p.literal("b")); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= (\"a\" | \"b\")+") != std::string::npos); + } + { + // Test rule references + auto parser = build_parser([](parser_builder& p) { + auto digit = p.add_rule("digit", p.char_class("[0-9]")); + return p.one_or_more(digit); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + // Should have digit rule defined and referenced + assert_equals(true, gbnf.find("digit ::= [0-9]") != std::string::npos); + assert_equals(true, gbnf.find("root ::= digit+") != std::string::npos); + } + { + // Test escaping in literals + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello\nworld\t!"); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= \"hello\\nworld\\t!\"") != std::string::npos); + } + { + // Test operator<< (whitespace insertion) + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello") << p.literal("world"); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + // Should inline the whitespace pattern + assert_equals(true, gbnf.find("\"hello\"") != std::string::npos); + assert_equals(true, gbnf.find("\"world\"") != std::string::npos); + } } int main() { @@ -467,6 +686,7 @@ int main() { test_optional(); test_json_parser(); test_complete_example(); + test_gbnf_generation(); std::cout << "All tests passed!\n"; return 0; } From 228653248e1994bbe82f99c90f7b8a607a34d461 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 04:06:59 -0600 Subject: [PATCH 05/16] remove unused private variable --- common/chat-parser-combinator.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 897c4f6f75b4b..1a340f6162656 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -403,14 +403,13 @@ class optional_parser : public parser_base { class until_parser : public parser_base { std::string delimiter_; - bool include_spaces_; parser parser_; friend class gbnf_visitor; public: until_parser(const std::string & delimiter, bool include_spaces, int id, parser_builder & builder) - : parser_base(id), delimiter_(delimiter), include_spaces_(include_spaces) { + : parser_base(id), delimiter_(delimiter) { if (include_spaces) { auto ws = builder.zero_or_more(builder.char_class("[ \\t\\n\\r]")); parser_ = builder.zero_or_more(builder.negate(ws + builder.literal(delimiter)) + builder.any()); From 3e6662f66c030d2804736e515f554b4e6ed4ac11 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 20:17:09 -0600 Subject: [PATCH 06/16] create a base visitor and implement id assignment as a visitor --- common/chat-parser-combinator.cpp | 461 +++++++++++++++--------------- common/chat-parser-combinator.h | 4 +- 2 files changed, 230 insertions(+), 235 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 1a340f6162656..598fd74a93657 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -8,7 +8,7 @@ #include #include -class gbnf_visitor; +class id_assignment_visitor; static parser json_parser(); @@ -16,7 +16,7 @@ class parser_base { protected: int id_; - void set_id(int id) { id_ = id; } + friend class id_assignment_visitor; public: parser_base(int id) : id_(id) {} @@ -25,19 +25,12 @@ class parser_base { virtual parser_type type() const = 0; virtual parser_result parse(parser_context & ctx, size_t start = 0) = 0; virtual std::string dump() const = 0; - virtual std::string accept(gbnf_visitor & visitor) const = 0; - virtual void assign_ids_internal(int& next_id) { - if (id_ == -1) { - id_ = next_id++; - } - } + virtual void accept(parser_visitor & visitor) = 0; }; class literal_parser : public parser_base { std::string literal_; - friend class gbnf_visitor; - public: literal_parser(const std::string & literal, int id) : parser_base(id), literal_(literal) {} @@ -73,14 +66,14 @@ class literal_parser : public parser_base { return "Literal(" + literal_ + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; + + const std::string & literal() const { return literal_; } }; class sequence_parser : public parser_base { std::vector parsers_; - friend class gbnf_visitor; - public: sequence_parser(std::initializer_list parsers, int id) : parser_base(id) { for (const auto & p : parsers) { @@ -139,25 +132,14 @@ class sequence_parser : public parser_base { return "Sequence(" + string_join(parts, ", ") + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; const std::vector & parsers() const { return parsers_; } - - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - for (auto & p : parsers_) { - p->assign_ids_internal(next_id); - } - } }; class choice_parser : public parser_base { std::vector parsers_; - friend class gbnf_visitor; - public: choice_parser(std::initializer_list parsers, int id) : parser_base(id) { for (const auto & p : parsers) { @@ -205,25 +187,14 @@ class choice_parser : public parser_base { return "Choice(" + string_join(parts, ", ") + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; const std::vector & parsers() const { return parsers_; } - - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - for (auto & p : parsers_) { - p->assign_ids_internal(next_id); - } - } }; class one_or_more_parser : public parser_base { parser parser_; - friend class gbnf_visitor; - public: one_or_more_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} @@ -279,23 +250,14 @@ class one_or_more_parser : public parser_base { return "OneOrMore(" + parser_->dump() + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; const parser & child() const { return parser_; } - - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - parser_->assign_ids_internal(next_id); - } }; class zero_or_more_parser : public parser_base { parser parser_; - friend class gbnf_visitor; - public: zero_or_more_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} @@ -341,23 +303,14 @@ class zero_or_more_parser : public parser_base { return "ZeroOrMore(" + parser_->dump() + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; const parser & child() const { return parser_; } - - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - parser_->assign_ids_internal(next_id); - } }; class optional_parser : public parser_base { parser parser_; - friend class gbnf_visitor; - public: optional_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} @@ -389,24 +342,15 @@ class optional_parser : public parser_base { return "Optional(" + parser_->dump() + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; const parser & child() const { return parser_; } - - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - parser_->assign_ids_internal(next_id); - } }; class until_parser : public parser_base { std::string delimiter_; parser parser_; - friend class gbnf_visitor; - public: until_parser(const std::string & delimiter, bool include_spaces, int id, parser_builder & builder) : parser_base(id), delimiter_(delimiter) { @@ -434,21 +378,16 @@ class until_parser : public parser_base { return "Until(" + delimiter_ + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - parser_->assign_ids_internal(next_id); - } + const std::string & delimiter() const { return delimiter_; } + + const parser & child() const { return parser_; } }; class not_parser : public parser_base { parser parser_; - friend class gbnf_visitor; - public: not_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} @@ -480,21 +419,12 @@ class not_parser : public parser_base { return "Not(" + parser_->dump() + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; const parser & child() const { return parser_; } - - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - parser_->assign_ids_internal(next_id); - } }; class any_parser : public parser_base { - friend class gbnf_visitor; - public: any_parser(int id) : parser_base(id) {} @@ -520,12 +450,10 @@ class any_parser : public parser_base { return "Any"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; }; class space_parser : public parser_base { - friend class gbnf_visitor; - public: space_parser(int id) : parser_base(id) {} @@ -554,7 +482,7 @@ class space_parser : public parser_base { return "Space"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; }; class char_class_parser : public parser_base { @@ -569,8 +497,6 @@ class char_class_parser : public parser_base { std::vector ranges_; bool negated_; - friend class gbnf_visitor; - public: char_class_parser(const std::string & classes, int id) : parser_base(id), pattern_(classes), negated_(false) { std::string content = classes; @@ -660,15 +586,15 @@ class char_class_parser : public parser_base { return "Char(" + pattern_ + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; + + const std::string & pattern() const { return pattern_; } }; class group_parser : public parser_base { std::string name_; parser parser_; - friend class gbnf_visitor; - public: group_parser(const std::string & name, const parser & parser, int id) : parser_base(id), name_(name), parser_(parser) {} @@ -686,14 +612,9 @@ class group_parser : public parser_base { return "Group(" + name_ + ", " + parser_->dump() + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - parser_->assign_ids_internal(next_id); - } + const parser & child() const { return parser_; } }; class schema_parser : public parser_base { @@ -701,8 +622,6 @@ class schema_parser : public parser_base { std::string name_; nlohmann::ordered_json schema_; - friend class gbnf_visitor; - public: schema_parser(const parser & parser, const std::string & name, const nlohmann::ordered_json & schema, int id) : parser_base(id), parser_(parser), name_(name), schema_(schema) {} @@ -717,18 +636,22 @@ class schema_parser : public parser_base { return "Schema(" + parser_->dump() + ", " + schema_.dump() + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; + + const parser & child() const { return parser_; } + + const std::string & name() const { return name_; } + + const nlohmann::ordered_json & schema() const { return schema_; } }; class rule_parser : public parser_base { - std::string rule_name_; + std::string name_; std::weak_ptr> rules_; - friend class gbnf_visitor; - public: - rule_parser(const std::string & name, std::shared_ptr> rules, int id) - : parser_base(id), rule_name_(name), rules_(rules) {} + rule_parser(const std::string & name, const std::shared_ptr> & rules, int id) + : parser_base(id), name_(name), rules_(rules) {} parser_type type() const override { return PARSER_RULE; } @@ -744,9 +667,9 @@ class rule_parser : public parser_base { return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); } - auto it = rules->find(rule_name_); + auto it = rules->find(name_); if (it == rules->end()) { - LOG_ERR("rule_parser::parse rule '%s' not found in registry\n", rule_name_.c_str()); + LOG_ERR("rule_parser::parse rule '%s' not found in registry\n", name_.c_str()); return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); } @@ -755,17 +678,19 @@ class rule_parser : public parser_base { } std::string dump() const override { - return "Rule(" + rule_name_ + ")"; + return "Rule(" + name_ + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; + + const std::string & name() const { return name_; } }; class root_parser : public parser_base { parser root_; std::shared_ptr> rules_; - friend class gbnf_visitor; + friend class parser_visitor; public: root_parser(const parser & root, std::shared_ptr> rules, int id) @@ -781,23 +706,45 @@ class root_parser : public parser_base { return root_->dump(); } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - root_->assign_ids_internal(next_id); - } + const parser & root() const { return root_; } + + std::shared_ptr> rules() const { return rules_; } +}; + +// Base visitor class for parser tree traversal +class parser_visitor { + public: + virtual ~parser_visitor() = default; + + virtual void visit(literal_parser & p) = 0; + virtual void visit(sequence_parser & p) = 0; + virtual void visit(choice_parser & p) = 0; + virtual void visit(one_or_more_parser & p) = 0; + virtual void visit(zero_or_more_parser & p) = 0; + virtual void visit(optional_parser & p) = 0; + virtual void visit(until_parser & p) = 0; + virtual void visit(not_parser & p) = 0; + virtual void visit(any_parser & p) = 0; + virtual void visit(space_parser & p) = 0; + virtual void visit(char_class_parser & p) = 0; + virtual void visit(group_parser & p) = 0; + virtual void visit(schema_parser & p) = 0; + virtual void visit(rule_parser & p) = 0; + virtual void visit(root_parser & p) = 0; }; -class gbnf_visitor { +class gbnf_visitor : public parser_visitor { common_grammar_builder& builder_; std::unordered_map rule_name_mapping_; + std::string current_result_; public: gbnf_visitor(common_grammar_builder& builder) : builder_(builder) {} + const std::string& result() const { return current_result_; } + private: // Escape special characters for GBNF literals static std::string escape_literal(const std::string & s) { @@ -872,187 +819,235 @@ class gbnf_visitor { } public: - std::string visit(const literal_parser & p) { - return "\"" + escape_literal(p.literal_) + "\""; + void visit(literal_parser & p) override { + current_result_ = "\"" + escape_literal(p.literal()) + "\""; } - std::string visit(const sequence_parser & p) { + void visit(sequence_parser & p) override { std::string s; - for (size_t i = 0; i < p.parsers_.size(); ++i) { - if (i > 0) s += " "; - auto child_result = p.parsers_[i]->accept(*this); - s += child_result; + for (const auto & child : p.parsers()) { + if (!s.empty()) { + s += " "; + } + child->accept(*this); + s += current_result_; } - return s; + current_result_ = s; } - std::string visit(const choice_parser & p) { + void visit(choice_parser & p) override { std::string s; - for (size_t i = 0; i < p.parsers_.size(); ++i) { - if (i > 0) { + for (const auto & child : p.parsers()) { + if (!s.empty()) { s += " | "; } - auto child_type = p.parsers_[i]->type(); - auto child_result = p.parsers_[i]->accept(*this); + child->accept(*this); // Parenthesize sequences in choices - if (child_type == PARSER_SEQUENCE) { - s += "(" + child_result + ")"; + if (child->type() == PARSER_SEQUENCE) { + s += "(" + current_result_ + ")"; } else { - s += child_result; + s += current_result_; } } - return s; + current_result_ = s; } - std::string visit(const one_or_more_parser & p) { - auto child_type = p.parser_->type(); - auto child_result = p.parser_->accept(*this); - if (needs_parens(child_type)) { - return "(" + child_result + ")+"; + void visit(one_or_more_parser & p) override { + p.child()->accept(*this); + if (needs_parens(p.child()->type())) { + current_result_ = "(" + current_result_ + ")+"; + } else { + current_result_ = current_result_ + "+"; } - return child_result + "+"; } - std::string visit(const zero_or_more_parser & p) { - auto child_type = p.parser_->type(); - auto child_result = p.parser_->accept(*this); - if (needs_parens(child_type)) { - return "(" + child_result + ")*"; + void visit(zero_or_more_parser & p) override { + p.child()->accept(*this); + if (needs_parens(p.child()->type())) { + current_result_ = "(" + current_result_ + ")*"; + } else { + current_result_ = current_result_ + "*"; } - return child_result + "*"; } - std::string visit(const optional_parser & p) { - auto child_type = p.parser_->type(); - auto child_result = p.parser_->accept(*this); - if (needs_parens(child_type)) { - return "(" + child_result + ")?"; + void visit(optional_parser & p) override { + p.child()->accept(*this); + if (needs_parens(p.child()->type())) { + current_result_ = "(" + current_result_ + ")?"; + } else { + current_result_ = current_result_ + "?"; } - return child_result + "?"; } - std::string visit(const until_parser & p) { + void visit(until_parser & p) override { // Generate pattern that matches prefixes but prevents full delimiter match - return generate_until_pattern(p.delimiter_) + "*"; + current_result_ = generate_until_pattern(p.delimiter()) + "*"; } - std::string visit(const not_parser &) { + void visit(not_parser &) override { // NOT is tricky in GBNF - for now, emit error LOG_ERR("NOT operator not directly supported in GBNF generation\n"); - return ""; // This will cause compilation errors, which is intended + current_result_ = ""; } - std::string visit(const any_parser &) { + void visit(any_parser &) override { // Match any single character - return "[\\x00-\\x{10FFFF}]"; + current_result_ = "[\\x00-\\x{10FFFF}]"; } - std::string visit(const space_parser &) { + void visit(space_parser &) override { // Reference the built-in space rule - return "space"; + current_result_ = "space"; } - std::string visit(const char_class_parser & p) { + void visit(char_class_parser & p) override { // Return pattern as-is (already in GBNF format) - return p.pattern_; + current_result_ = p.pattern(); } - std::string visit(const group_parser & p) { + void visit(group_parser & p) override { // Groups are transparent - just visit child - return p.parser_->accept(*this); + p.child()->accept(*this); } - std::string visit(const schema_parser & p) { - return builder_.add_schema(p.name_, p.schema_); + void visit(schema_parser & p) override { + current_result_ = builder_.add_schema(p.name(), p.schema()); } - std::string visit(const rule_parser & p) { + void visit(rule_parser & p) override { // Return canonical rule reference - auto it = rule_name_mapping_.find(p.rule_name_); + auto it = rule_name_mapping_.find(p.name()); if (it != rule_name_mapping_.end()) { - return it->second; + current_result_ = it->second; + } else { + // Fallback to original name if not in mapping (shouldn't happen in valid usage) + current_result_ = p.name(); } - // Fallback to original name if not in mapping (shouldn't happen in valid usage) - return p.rule_name_; } - std::string visit(const root_parser & p) { + void visit(root_parser & p) override { // Generate named rules first - if (p.rules_) { - for (const auto & [name, rule] : *p.rules_) { - auto rule_body = rule->accept(*this); + auto rules = p.rules(); + if (rules) { + for (const auto & [name, rule] : *rules) { + rule->accept(*this); + auto rule_body = current_result_; auto canonical_name = builder_.add_rule(name, rule_body); rule_name_mapping_[name] = canonical_name; } } // Return root body for composition - return p.root_->accept(*this); + p.root()->accept(*this); } }; -// Implement accept() methods for all parser classes -std::string literal_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} +// ID assignment visitor for assigning unique IDs to parsers +class id_assignment_visitor : public parser_visitor { + int & next_id_; -std::string sequence_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + public: + id_assignment_visitor(int & next_id) : next_id_(next_id) {} -std::string choice_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void assign_id(parser_base & p) { + if (p.id_ == -1) { + p.id_ = next_id_++; + } + } -std::string one_or_more_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(literal_parser & p) override { + assign_id(p); + } -std::string zero_or_more_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(any_parser & p) override { + assign_id(p); + } -std::string optional_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(space_parser & p) override { + assign_id(p); + } -std::string until_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(char_class_parser & p) override { + assign_id(p); + } -std::string not_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(schema_parser & p) override { + assign_id(p); + } -std::string any_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(rule_parser & p) override { + assign_id(p); + } -std::string space_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + // Composite parsers - assign ID and traverse children + void visit(sequence_parser & p) override { + assign_id(p); + for (const auto & child : p.parsers()) { + child->accept(*this); + } + } -std::string char_class_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(choice_parser & p) override { + assign_id(p); + for (const auto & child : p.parsers()) { + child->accept(*this); + } + } -std::string group_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(one_or_more_parser & p) override { + assign_id(p); + p.child()->accept(*this); + } -std::string schema_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(zero_or_more_parser & p) override { + assign_id(p); + p.child()->accept(*this); + } -std::string rule_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(optional_parser & p) override { + assign_id(p); + p.child()->accept(*this); + } -std::string root_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(until_parser & p) override { + assign_id(p); + p.child()->accept(*this); + } + + void visit(not_parser & p) override { + assign_id(p); + p.child()->accept(*this); + } + + void visit(group_parser & p) override { + assign_id(p); + p.child()->accept(*this); + } + + void visit(root_parser & p) override { + assign_id(p); + p.root()->accept(*this); + } +}; + +// Implement accept() methods for all parser classes +void literal_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void sequence_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void choice_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void one_or_more_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void zero_or_more_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void optional_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void until_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void not_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void any_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void space_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void char_class_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void group_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void schema_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void rule_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void root_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } std::optional parser_result::group(const std::string & name, std::string_view input) const { auto it = groups.find(name); @@ -1145,11 +1140,10 @@ std::string parser::dump() const { return ptr->dump(); } -void parser::build_grammar(common_grammar_builder& builder) const { +void parser::build_grammar(common_grammar_builder& builder) { gbnf_visitor visitor(builder); - auto result = ptr->accept(visitor); - // The visitor returns the GBNF string for this parser - // root_parser registers its named rules and returns its root body + ptr->accept(visitor); + auto result = visitor.result(); if (!result.empty()) { builder.add_rule("root", result); } @@ -1222,7 +1216,8 @@ parser parser_builder::add_rule(const std::string & name, const parser & p) { void parser_builder::assign_ids(parser & p) { if (p.ptr) { - p.ptr->assign_ids_internal(next_id_); + id_assignment_visitor visitor(next_id_); + p.ptr->accept(visitor); } } diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 1ef3996dc862a..f0cb1d24ff7bb 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -101,7 +101,7 @@ class parser_base; class sequence_parser; class choice_parser; class parser_builder; -class gbnf_visitor; +class parser_visitor; class parser { std::shared_ptr ptr; @@ -136,7 +136,7 @@ class parser { parser_type type() const; parser_result parse(parser_context & ctx, size_t start = 0) const; std::string dump() const; - void build_grammar(common_grammar_builder& builder) const; + void build_grammar(common_grammar_builder& builder); }; class parser_builder { From 76cf0b5b6197d427e3c48aa4d24f549a3d3a4167 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 20:21:04 -0600 Subject: [PATCH 07/16] fix const ref for grammar builder --- common/chat-parser-combinator.cpp | 6 +-- common/chat-parser-combinator.h | 3 +- tests/test-chat-parser-combinator.cpp | 69 ++++++++++++++++----------- 3 files changed, 46 insertions(+), 32 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 598fd74a93657..aff72b67fd68d 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -736,12 +736,12 @@ class parser_visitor { }; class gbnf_visitor : public parser_visitor { - common_grammar_builder& builder_; + const common_grammar_builder & builder_; std::unordered_map rule_name_mapping_; std::string current_result_; public: - gbnf_visitor(common_grammar_builder& builder) : builder_(builder) {} + gbnf_visitor(const common_grammar_builder & builder) : builder_(builder) {} const std::string& result() const { return current_result_; } @@ -1140,7 +1140,7 @@ std::string parser::dump() const { return ptr->dump(); } -void parser::build_grammar(common_grammar_builder& builder) { +void parser::build_grammar(const common_grammar_builder & builder) { gbnf_visitor visitor(builder); ptr->accept(visitor); auto result = visitor.result(); diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index f0cb1d24ff7bb..e56a6adf24c17 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -136,7 +136,8 @@ class parser { parser_type type() const; parser_result parse(parser_context & ctx, size_t start = 0) const; std::string dump() const; - void build_grammar(common_grammar_builder& builder); + + void build_grammar(const common_grammar_builder & builder); }; class parser_builder { diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index 1fffbfbf040a0..e4f637af9f797 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -518,8 +518,8 @@ static void test_complete_example() { assert_equals(std::string("get_weather"), *result.group("tool-name", ctx.input)); assert_equals(std::string(R"({"cit)"), *result.group("tool-args", ctx.input)); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); std::cout << "Grammar:\n" << gbnf << "\n"; @@ -532,9 +532,10 @@ static void test_gbnf_generation() { return p.literal("hello"); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= \"hello\"") != std::string::npos); assert_equals(true, gbnf.find("space ::=") != std::string::npos); } @@ -544,9 +545,10 @@ static void test_gbnf_generation() { return p.char_class("[a-z]"); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= [a-z]") != std::string::npos); } { @@ -555,9 +557,10 @@ static void test_gbnf_generation() { return p.literal("hello") + p.literal(" ") + p.literal("world"); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= \"hello\" \" \" \"world\"") != std::string::npos); } { @@ -566,9 +569,10 @@ static void test_gbnf_generation() { return p.literal("cat") | p.literal("dog"); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= \"cat\" | \"dog\"") != std::string::npos); } { @@ -577,9 +581,10 @@ static void test_gbnf_generation() { return p.one_or_more(p.char_class("[0-9]")); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= [0-9]+") != std::string::npos); } { @@ -588,9 +593,10 @@ static void test_gbnf_generation() { return p.zero_or_more(p.char_class("[a-z]")); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= [a-z]*") != std::string::npos); } { @@ -599,9 +605,10 @@ static void test_gbnf_generation() { return p.literal("hello") + p.optional(p.literal(" world")); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= \"hello\" \" world\"?") != std::string::npos); } { @@ -610,9 +617,10 @@ static void test_gbnf_generation() { return p.until(""); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + // Should generate pattern that prevents matching the full delimiter assert_equals(true, gbnf.find("root ::= ([^<] | \"<\" [^/] | \"])*") != std::string::npos); } @@ -622,9 +630,10 @@ static void test_gbnf_generation() { return p.group("test", p.literal("hello")); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= \"hello\"") != std::string::npos); } { @@ -633,9 +642,10 @@ static void test_gbnf_generation() { return p.one_or_more(p.literal("a") | p.literal("b")); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= (\"a\" | \"b\")+") != std::string::npos); } { @@ -645,9 +655,10 @@ static void test_gbnf_generation() { return p.one_or_more(digit); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + // Should have digit rule defined and referenced assert_equals(true, gbnf.find("digit ::= [0-9]") != std::string::npos); assert_equals(true, gbnf.find("root ::= digit+") != std::string::npos); @@ -658,9 +669,10 @@ static void test_gbnf_generation() { return p.literal("hello\nworld\t!"); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= \"hello\\nworld\\t!\"") != std::string::npos); } { @@ -669,9 +681,10 @@ static void test_gbnf_generation() { return p.literal("hello") << p.literal("world"); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + // Should inline the whitespace pattern assert_equals(true, gbnf.find("\"hello\"") != std::string::npos); assert_equals(true, gbnf.find("\"world\"") != std::string::npos); From 9c7b3e8bcf57ea416d21a90d219c39d06e16b426 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 20:33:26 -0600 Subject: [PATCH 08/16] clean up types, friend classes, and class declarations --- common/chat-parser-combinator.cpp | 76 +++++++++++++++---------------- common/chat-parser-combinator.h | 39 ++-------------- 2 files changed, 43 insertions(+), 72 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index aff72b67fd68d..56215302df99e 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -8,7 +8,24 @@ #include #include -class id_assignment_visitor; +enum parser_type { + PARSER_LITERAL = 0, + PARSER_SEQUENCE = 1, + PARSER_CHOICE = 2, + PARSER_ZERO_OR_MORE = 3, + PARSER_ONE_OR_MORE = 4, + PARSER_NOT = 5, + PARSER_ANY = 6, + PARSER_CHAR_CLASS = 7, + PARSER_GROUP = 8, + PARSER_RULE = 9, + PARSER_OPTIONAL = 10, + PARSER_UNTIL = 11, + PARSER_SPACE = 12, + PARSER_SCHEMA = 13, +}; + +class parser_visitor; static parser json_parser(); @@ -16,12 +33,13 @@ class parser_base { protected: int id_; - friend class id_assignment_visitor; - public: parser_base(int id) : id_(id) {} virtual ~parser_base() = default; + int id() const { return id_; } + void set_id(int id) { id_ = id; } + virtual parser_type type() const = 0; virtual parser_result parse(parser_context & ctx, size_t start = 0) = 0; virtual std::string dump() const = 0; @@ -77,9 +95,10 @@ class sequence_parser : public parser_base { public: sequence_parser(std::initializer_list parsers, int id) : parser_base(id) { for (const auto & p : parsers) { - if (p.is_sequence()) { + if (p->type() == PARSER_SEQUENCE) { // Flatten sequences - for (const auto & embedded : p.to_sequence()->parsers()) { + auto seq = std::static_pointer_cast(p.ptr()); + for (const auto & embedded : seq->parsers()) { parsers_.push_back(embedded); } } else { @@ -143,9 +162,10 @@ class choice_parser : public parser_base { public: choice_parser(std::initializer_list parsers, int id) : parser_base(id) { for (const auto & p : parsers) { - if (p.is_choice()) { + if (p->type() == PARSER_CHOICE) { // Flatten choices - for (const auto & embedded : p.to_choice()->parsers()) { + auto choice = std::static_pointer_cast(p.ptr()); + for (const auto & embedded : choice->parsers()) { parsers_.push_back(embedded); } } else { @@ -952,8 +972,8 @@ class id_assignment_visitor : public parser_visitor { id_assignment_visitor(int & next_id) : next_id_(next_id) {} void assign_id(parser_base & p) { - if (p.id_ == -1) { - p.id_ = next_id_++; + if (p.id() == -1) { + p.set_id(next_id_++); } } @@ -1085,7 +1105,7 @@ void parse_cache::clear() { parser::parser() {} -parser::parser(std::shared_ptr parser) : ptr(std::move(parser)) {} +parser::parser(std::shared_ptr parser) : ptr_(std::move(parser)) {} parser parser::operator~() const { return parser(std::make_shared(*this, -1)); @@ -1105,44 +1125,24 @@ parser parser::operator<<(const parser & other) const { } parser_base & parser::operator*() const { - return *ptr; + return *ptr_; } parser_base * parser::operator->() const { - return ptr.get(); -} - -bool parser::is_sequence() const { - return ptr->type() == PARSER_SEQUENCE; -} - -std::shared_ptr parser::to_sequence() const { - return std::dynamic_pointer_cast(ptr); -} - -bool parser::is_choice() const { - return ptr->type() == PARSER_CHOICE; -} - -std::shared_ptr parser::to_choice() const { - return std::dynamic_pointer_cast(ptr); -} - -parser_type parser::type() const { - return ptr->type(); + return ptr_.get(); } parser_result parser::parse(parser_context & ctx, size_t start) const { - return ptr->parse(ctx, start); + return ptr_->parse(ctx, start); } std::string parser::dump() const { - return ptr->dump(); + return ptr_->dump(); } -void parser::build_grammar(const common_grammar_builder & builder) { +void parser::build_grammar(const common_grammar_builder & builder) const { gbnf_visitor visitor(builder); - ptr->accept(visitor); + ptr_->accept(visitor); auto result = visitor.result(); if (!result.empty()) { builder.add_rule("root", result); @@ -1215,9 +1215,9 @@ parser parser_builder::add_rule(const std::string & name, const parser & p) { } void parser_builder::assign_ids(parser & p) { - if (p.ptr) { + if (p.ptr()) { id_assignment_visitor visitor(next_id_); - p.ptr->accept(visitor); + p.ptr()->accept(visitor); } } diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index e56a6adf24c17..6c7b86d4e04e9 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -11,23 +11,6 @@ struct common_grammar_builder; -enum parser_type { - PARSER_LITERAL = 0, - PARSER_SEQUENCE = 1, - PARSER_CHOICE = 2, - PARSER_ZERO_OR_MORE = 3, - PARSER_ONE_OR_MORE = 4, - PARSER_NOT = 5, - PARSER_ANY = 6, - PARSER_CHAR_CLASS = 7, - PARSER_GROUP = 8, - PARSER_RULE = 9, - PARSER_OPTIONAL = 10, - PARSER_UNTIL = 11, - PARSER_SPACE = 12, - PARSER_SCHEMA = 13, -}; - enum parser_result_type { PARSER_RESULT_FAIL = 0, PARSER_RESULT_NEED_MORE_INPUT = 1, @@ -89,8 +72,6 @@ class parse_cache { void clear(); }; -class parser; - struct parser_context { std::string_view input; parse_cache memo; @@ -98,15 +79,9 @@ struct parser_context { }; class parser_base; -class sequence_parser; -class choice_parser; -class parser_builder; -class parser_visitor; class parser { - std::shared_ptr ptr; - - friend class parser_builder; + std::shared_ptr ptr_; public: parser(); @@ -114,7 +89,7 @@ class parser { parser(const parser & other) = default; parser & operator=(const parser & other) { if (this != &other) { - ptr = other.ptr; + ptr_ = other.ptr_; } return *this; } @@ -127,17 +102,13 @@ class parser { parser_base & operator*() const; parser_base * operator->() const; - bool is_sequence() const; - std::shared_ptr to_sequence() const; - - bool is_choice() const; - std::shared_ptr to_choice() const; + std::shared_ptr ptr() const { return ptr_; } - parser_type type() const; parser_result parse(parser_context & ctx, size_t start = 0) const; + std::string dump() const; - void build_grammar(const common_grammar_builder & builder); + void build_grammar(const common_grammar_builder & builder) const; }; class parser_builder { From f02e2b06fa0ef29aa647060ec74f7f0c6224606b Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 20:47:43 -0600 Subject: [PATCH 09/16] remove builder usage from until_parser --- common/chat-parser-combinator.cpp | 84 ++++++++++++++++--------------- common/chat-parser-combinator.h | 2 +- 2 files changed, 45 insertions(+), 41 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 56215302df99e..0081606516d40 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -367,44 +367,6 @@ class optional_parser : public parser_base { const parser & child() const { return parser_; } }; -class until_parser : public parser_base { - std::string delimiter_; - parser parser_; - - public: - until_parser(const std::string & delimiter, bool include_spaces, int id, parser_builder & builder) - : parser_base(id), delimiter_(delimiter) { - if (include_spaces) { - auto ws = builder.zero_or_more(builder.char_class("[ \\t\\n\\r]")); - parser_ = builder.zero_or_more(builder.negate(ws + builder.literal(delimiter)) + builder.any()); - } else { - parser_ = builder.zero_or_more(builder.negate(builder.literal(delimiter)) + builder.any()); - } - } - - parser_type type() const override { return PARSER_UNTIL; } - - parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - auto result = parser_->parse(ctx, start); - return ctx.memo.set(id_, start, result); - } - - std::string dump() const override { - return "Until(" + delimiter_ + ")"; - } - - void accept(parser_visitor & visitor) override; - - const std::string & delimiter() const { return delimiter_; } - - const parser & child() const { return parser_; } -}; - class not_parser : public parser_base { parser parser_; @@ -637,6 +599,48 @@ class group_parser : public parser_base { const parser & child() const { return parser_; } }; +class until_parser : public parser_base { + std::string delimiter_; + parser parser_; + + public: + until_parser(const std::string & delimiter, bool consume_spaces, int id) + : parser_base(id), delimiter_(delimiter) { + + auto delim = parser(std::make_shared(delimiter, -1)); + auto any = parser(std::make_shared(-1)); + + if (consume_spaces) { + auto ws = parser(std::make_shared(-1)); + parser_ = parser(std::make_shared(~(ws + delim) + any, -1)); + } else { + parser_ = parser(std::make_shared(~delim + any, -1)); + } + } + + parser_type type() const override { return PARSER_UNTIL; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + auto result = parser_->parse(ctx, start); + return ctx.memo.set(id_, start, result); + } + + std::string dump() const override { + return "Until(" + delimiter_ + ")"; + } + + void accept(parser_visitor & visitor) override; + + const std::string & delimiter() const { return delimiter_; } + + const parser & child() const { return parser_; } +}; + class schema_parser : public parser_base { parser parser_; std::string name_; @@ -1201,8 +1205,8 @@ parser parser_builder::space() { return parser(std::make_shared(next_id_++)); } -parser parser_builder::until(const std::string & delimiter, bool include_spaces) { - return parser(std::make_shared(delimiter, include_spaces, next_id_++, *this)); +parser parser_builder::until(const std::string & delimiter, bool consume_spaces) { + return parser(std::make_shared(delimiter, consume_spaces, next_id_++)); } parser parser_builder::schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema) { diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 6c7b86d4e04e9..e5f508d963011 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -130,7 +130,7 @@ class parser_builder { parser group(const std::string & name, const parser & p); parser rule(const std::string & name); parser space(); - parser until(const std::string & delimiter, bool include_spaces = true); + parser until(const std::string & delimiter, bool consume_spaces = true); parser json(); parser schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema); From 66cf038a37596bee9771142d9900c7aea35c0b32 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 21:08:17 -0600 Subject: [PATCH 10/16] Use a counter class to help assign rule ids --- common/chat-parser-combinator.cpp | 74 +++++++++++++++---------------- common/chat-parser-combinator.h | 10 ++++- 2 files changed, 45 insertions(+), 39 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 0081606516d40..998e5511d857a 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -27,8 +27,6 @@ enum parser_type { class parser_visitor; -static parser json_parser(); - class parser_base { protected: int id_; @@ -970,14 +968,14 @@ class gbnf_visitor : public parser_visitor { // ID assignment visitor for assigning unique IDs to parsers class id_assignment_visitor : public parser_visitor { - int & next_id_; + std::shared_ptr counter_; public: - id_assignment_visitor(int & next_id) : next_id_(next_id) {} + id_assignment_visitor(const std::shared_ptr & counter) : counter_(counter) {} void assign_id(parser_base & p) { if (p.id() == -1) { - p.set_id(next_id_++); + p.set_id(counter_->next()); } } @@ -1155,62 +1153,66 @@ void parser::build_grammar(const common_grammar_builder & builder) const { parser_builder::parser_builder() : rules_(std::make_shared>()) - , next_id_(0) {} + , counter_(std::make_shared(0)) {} + +parser_builder::parser_builder(std::shared_ptr counter) + : rules_(std::make_shared>()) + , counter_(std::move(counter)) {} parser parser_builder::literal(const std::string & literal) { - return parser(std::make_shared(literal, next_id_++)); + return parser(std::make_shared(literal, counter_->next())); } parser parser_builder::sequence(std::initializer_list parsers) { - return parser(std::make_shared(parsers, next_id_++)); + return parser(std::make_shared(parsers, counter_->next())); } parser parser_builder::choice(std::initializer_list parsers) { - return parser(std::make_shared(parsers, next_id_++)); + return parser(std::make_shared(parsers, counter_->next())); } parser parser_builder::one_or_more(const parser & p) { - return parser(std::make_shared(p, next_id_++)); + return parser(std::make_shared(p, counter_->next())); } parser parser_builder::zero_or_more(const parser & p) { - return parser(std::make_shared(p, next_id_++)); + return parser(std::make_shared(p, counter_->next())); } parser parser_builder::optional(const parser & p) { - return parser(std::make_shared(p, next_id_++)); + return parser(std::make_shared(p, counter_->next())); } parser parser_builder::negate(const parser & p) { - return parser(std::make_shared(p, next_id_++)); + return parser(std::make_shared(p, counter_->next())); } parser parser_builder::any() { - return parser(std::make_shared(next_id_++)); + return parser(std::make_shared(counter_->next())); } parser parser_builder::char_class(const std::string & classes) { - return parser(std::make_shared(classes, next_id_++)); + return parser(std::make_shared(classes, counter_->next())); } parser parser_builder::group(const std::string & name, const parser & p) { - return parser(std::make_shared(name, p, next_id_++)); + return parser(std::make_shared(name, p, counter_->next())); } parser parser_builder::rule(const std::string & name) { - return parser(std::make_shared(name, rules_, next_id_++)); + return parser(std::make_shared(name, rules_, counter_->next())); } parser parser_builder::space() { - return parser(std::make_shared(next_id_++)); + return parser(std::make_shared(counter_->next())); } parser parser_builder::until(const std::string & delimiter, bool consume_spaces) { - return parser(std::make_shared(delimiter, consume_spaces, next_id_++)); + return parser(std::make_shared(delimiter, consume_spaces, counter_->next())); } parser parser_builder::schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema) { - return parser(std::make_shared(p, name, schema, next_id_++)); + return parser(std::make_shared(p, name, schema, counter_->next())); } parser parser_builder::add_rule(const std::string & name, const parser & p) { @@ -1220,7 +1222,7 @@ parser parser_builder::add_rule(const std::string & name, const parser & p) { void parser_builder::assign_ids(parser & p) { if (p.ptr()) { - id_assignment_visitor visitor(next_id_); + id_assignment_visitor visitor(counter_); p.ptr()->accept(visitor); } } @@ -1238,8 +1240,8 @@ parser build_parser(const std::function & fn) { return root; } -static parser json_parser() { - parser_builder builder; +static parser json_parser(std::shared_ptr counter) { + parser_builder builder(std::move(counter)); // Whitespace: space, tab, newline, carriage return auto ws = builder.zero_or_more(builder.char_class("[ \\t\\n\\r]")); @@ -1280,17 +1282,6 @@ static parser json_parser() { auto false_lit = builder.literal("false"); auto null_lit = builder.literal("null"); - // Value - uses forward references for recursive structures - builder.add_rule("json_value", - builder.rule("json_object") | - builder.rule("json_array") | - builder.rule("json_string") | - builder.rule("json_number") | - true_lit | - false_lit | - null_lit - ); - // Object: { "key": value, ... } auto member = builder.rule("json_string") + ws + builder.literal(":") + ws + builder.rule("json_value"); auto members = member + builder.zero_or_more(ws + builder.literal(",") + ws + member); @@ -1310,14 +1301,21 @@ static parser json_parser() { builder.add_rule("json_array", array); - // Get the json_value rule as the root - auto root = builder.rule("json_value"); - builder.assign_ids(root); + // Value - uses forward references for recursive structures + auto root = builder.add_rule("json_value", + builder.rule("json_object") | + builder.rule("json_array") | + builder.rule("json_string") | + builder.rule("json_number") | + true_lit | + false_lit | + null_lit + ); // Wrap in root_parser to own the rules return parser(std::make_shared(root, builder.rules(), -1)); } parser parser_builder::json() { - return json_parser(); + return json_parser(counter_); } diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index e5f508d963011..56a522394bf75 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -111,12 +111,20 @@ class parser { void build_grammar(const common_grammar_builder & builder) const; }; +class parser_id_counter { + int next_id_; + public: + parser_id_counter(int start) : next_id_(start) {} + int next() { return next_id_++; } +}; + class parser_builder { std::shared_ptr> rules_; - int next_id_; + std::shared_ptr counter_; public: parser_builder(); + parser_builder(std::shared_ptr counter); parser literal(const std::string & literal); parser sequence(std::initializer_list parsers); From 2b3caefde82282c5933c9ad53014c336184cedfb Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 21:24:05 -0600 Subject: [PATCH 11/16] cache everything --- common/chat-parser-combinator.cpp | 419 ++++++++++++++---------------- common/chat-parser-combinator.h | 2 + 2 files changed, 194 insertions(+), 227 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 998e5511d857a..fb3b0fb083a0a 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -53,29 +53,26 @@ class literal_parser : public parser_base { parser_type type() const override { return PARSER_LITERAL; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - auto pos = start; - for (auto i = 0u; i < literal_.size(); ++i) { - if (pos >= ctx.input.size()) { - if (ctx.input_is_complete) { - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + return ctx.memo.cached(id_, start, [&]() { + auto pos = start; + for (auto i = 0u; i < literal_.size(); ++i) { + if (pos >= ctx.input.size()) { + if (ctx.input_is_complete) { + return parser_result(PARSER_RESULT_FAIL, start); + } + if (i > 0) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); + } + return parser_result(PARSER_RESULT_FAIL, start); } - if (i > 0) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); + if (ctx.input[pos] != literal_[i]) { + return parser_result(PARSER_RESULT_FAIL, start); } - return parser_result(PARSER_RESULT_FAIL, start); - } - if (ctx.input[pos] != literal_[i]) { - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + ++pos; } - ++pos; - } - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos)); + return parser_result(PARSER_RESULT_SUCCESS, start, pos); + }); } std::string dump() const override { @@ -108,36 +105,33 @@ class sequence_parser : public parser_base { parser_type type() const override { return PARSER_SEQUENCE; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - std::unordered_map groups; - - auto pos = start; - for (const auto & p : parsers_) { - auto result = p->parse(ctx, pos); - - // Copy groups - groups.insert(result.groups.begin(), result.groups.end()); + return ctx.memo.cached(id_, start, [&]() { + std::unordered_map groups; + + auto pos = start; + for (const auto & p : parsers_) { + auto result = p->parse(ctx, pos); + + // Copy groups + groups.insert(result.groups.begin(), result.groups.end()); + + if (result.is_fail()) { + if (result.end >= ctx.input.size() && !ctx.input_is_complete) { + // If we fail because we don't have enough input, then return success + return parser_result(PARSER_RESULT_SUCCESS, start, result.end, groups); + } + return parser_result(PARSER_RESULT_FAIL, start, result.end, groups); + } - if (result.is_fail()) { - if (result.end >= ctx.input.size() && !ctx.input_is_complete) { - // If we fail because we don't have enough input, then return success - return parser_result(PARSER_RESULT_SUCCESS, start, result.end, groups); + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, result.end, groups); } - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start, result.end, groups)); - } - if (result.is_need_more_input()) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, result.end, groups); + pos = result.end; } - pos = result.end; - } - - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos, groups)); + return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); + }); } std::string dump() const override { @@ -175,25 +169,22 @@ class choice_parser : public parser_base { parser_type type() const override { return PARSER_CHOICE; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } + return ctx.memo.cached(id_, start, [&]() { + auto pos = start; + for (const auto & p : parsers_) { + auto result = p->parse(ctx, pos); - auto pos = start; - for (const auto & p : parsers_) { - auto result = p->parse(ctx, pos); - - if (result.is_success()) { - return ctx.memo.set(id_, start, result); - } + if (result.is_success()) { + return result; + } - if (result.is_need_more_input()) { - return result; + if (result.is_need_more_input()) { + return result; + } } - } - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + return parser_result(PARSER_RESULT_FAIL, start); + }); } std::string dump() const override { @@ -219,49 +210,41 @@ class one_or_more_parser : public parser_base { parser_type type() const override { return PARSER_ONE_OR_MORE; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - std::unordered_map groups; - - // We can't return back the cached result, since there may be more - // repetitions since the last parsing attempt. Instead, resume parsing from - // the last successful repetition found. - auto pos = start; - if (cached != std::nullopt) { - pos = cached->end; - groups.insert(cached->groups.begin(), cached->groups.end()); - } + return ctx.memo.cached(id_, start, [&]() { + std::unordered_map groups; - if (pos == start) { - auto first_result = parser_->parse(ctx, pos); + // Parse at least once + auto first_result = parser_->parse(ctx, start); if (!first_result.is_success()) { return first_result; } - pos = first_result.end; + auto pos = first_result.end; groups.insert(first_result.groups.begin(), first_result.groups.end()); - } - for (;;) { - auto result = parser_->parse(ctx, pos); - groups.insert(result.groups.begin(), result.groups.end()); + // Parse zero or more additional times + for (;;) { + auto result = parser_->parse(ctx, pos); + groups.insert(result.groups.begin(), result.groups.end()); - if (result.is_need_more_input()) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); - } + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); + } - if (result.is_fail()) { - // Done with repetitions - break; - } + if (result.is_fail()) { + // Done with repetitions + break; + } - if (result.end == pos) { - break; // Prevent an infinite loop - } + if (result.end == pos) { + break; // Prevent an infinite loop + } - pos = result.end; - } + pos = result.end; + } - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos, groups)); + return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); + }); } std::string dump() const override { @@ -282,39 +265,32 @@ class zero_or_more_parser : public parser_base { parser_type type() const override { return PARSER_ZERO_OR_MORE; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - std::unordered_map groups; - - // We can't return back the cached result, since there may be more - // repetitions since the last parsing attempt. Instead, resume parsing from - // the last successful repetition found. - auto pos = start; - if (cached != std::nullopt) { - pos = cached->end; - groups.insert(cached->groups.begin(), cached->groups.end()); - } + return ctx.memo.cached(id_, start, [&]() { + std::unordered_map groups; + auto pos = start; - for (;;) { - auto result = parser_->parse(ctx, pos); - groups.insert(result.groups.begin(), result.groups.end()); + for (;;) { + auto result = parser_->parse(ctx, pos); + groups.insert(result.groups.begin(), result.groups.end()); - if (result.is_need_more_input()) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); - } + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); + } - if (result.is_fail()) { - // Done with repetitions (zero or more is always valid) - break; - } + if (result.is_fail()) { + // Done with repetitions (zero or more is always valid) + break; + } - if (result.end == pos) { - break; // Prevent an infinite loop - } + if (result.end == pos) { + break; // Prevent an infinite loop + } - pos = result.end; - } + pos = result.end; + } - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos, groups)); + return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); + }); } std::string dump() const override { @@ -335,25 +311,22 @@ class optional_parser : public parser_base { parser_type type() const override { return PARSER_OPTIONAL; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - auto result = parser_->parse(ctx, start); + return ctx.memo.cached(id_, start, [&]() { + auto result = parser_->parse(ctx, start); - if (result.is_success()) { - // Matched successfully - return ctx.memo.set(id_, start, result); - } + if (result.is_success()) { + // Matched successfully + return result; + } - if (result.is_need_more_input()) { - // Propagate - need more input to determine if optional matches - return result; - } + if (result.is_need_more_input()) { + // Propagate - need more input to determine if optional matches + return result; + } - // No match, but optional always succeeds with zero matches - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, start)); + // No match, but optional always succeeds with zero matches + return parser_result(PARSER_RESULT_SUCCESS, start, start); + }); } std::string dump() const override { @@ -374,25 +347,22 @@ class not_parser : public parser_base { parser_type type() const override { return PARSER_NOT; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - auto result = parser_->parse(ctx, start); + return ctx.memo.cached(id_, start, [&]() { + auto result = parser_->parse(ctx, start); - if (result.is_success()) { - // Fail if the underlying parser matches - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); - } + if (result.is_success()) { + // Fail if the underlying parser matches + return parser_result(PARSER_RESULT_FAIL, start); + } - if (result.is_need_more_input()) { - // Propagate - need to know what child would match before negating - return result; - } + if (result.is_need_more_input()) { + // Propagate - need to know what child would match before negating + return result; + } - // Child failed, so negation succeeds - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start)); + // Child failed, so negation succeeds + return parser_result(PARSER_RESULT_SUCCESS, start); + }); } std::string dump() const override { @@ -411,19 +381,16 @@ class any_parser : public parser_base { parser_type type() const override { return PARSER_ANY; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - if (start >= ctx.input.size()) { - if (ctx.input_is_complete) { - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + return ctx.memo.cached(id_, start, [&]() { + if (start >= ctx.input.size()) { + if (ctx.input_is_complete) { + return parser_result(PARSER_RESULT_FAIL, start); + } + return parser_result(PARSER_RESULT_FAIL, start); } - return parser_result(PARSER_RESULT_FAIL, start); - } - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, start + 1)); + return parser_result(PARSER_RESULT_SUCCESS, start, start + 1); + }); } std::string dump() const override { @@ -440,22 +407,19 @@ class space_parser : public parser_base { parser_type type() const override { return PARSER_SPACE; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - auto pos = start; - while (pos < ctx.input.size()) { - char c = ctx.input[pos]; - if (c == ' ' || c == '\t' || c == '\n') { - ++pos; - } else { - break; + return ctx.memo.cached(id_, start, [&]() { + auto pos = start; + while (pos < ctx.input.size()) { + char c = ctx.input[pos]; + if (c == ' ' || c == '\t' || c == '\n') { + ++pos; + } else { + break; + } } - } - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos)); + return parser_result(PARSER_RESULT_SUCCESS, start, pos); + }); } std::string dump() const override { @@ -530,36 +494,33 @@ class char_class_parser : public parser_base { parser_type type() const override { return PARSER_CHAR_CLASS; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - if (start >= ctx.input.size()) { - if (ctx.input_is_complete) { - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + return ctx.memo.cached(id_, start, [&]() { + if (start >= ctx.input.size()) { + if (ctx.input_is_complete) { + return parser_result(PARSER_RESULT_FAIL, start); + } + return parser_result(PARSER_RESULT_FAIL, start); } - return parser_result(PARSER_RESULT_FAIL, start); - } - bool matches = false; - for (const auto & range : ranges_) { - if (range.contains(ctx.input[start])) { - matches = true; - break; + bool matches = false; + for (const auto & range : ranges_) { + if (range.contains(ctx.input[start])) { + matches = true; + break; + } } - } - // If negated, invert the match result - if (negated_) { - matches = !matches; - } + // If negated, invert the match result + if (negated_) { + matches = !matches; + } - if (matches) { - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, start + 1)); - } + if (matches) { + return parser_result(PARSER_RESULT_SUCCESS, start, start + 1); + } - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + return parser_result(PARSER_RESULT_FAIL, start); + }); } std::string dump() const override { @@ -581,11 +542,13 @@ class group_parser : public parser_base { parser_type type() const override { return PARSER_GROUP; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto result = parser_->parse(ctx, start); + return ctx.memo.cached(id_, start, [&]() { + auto result = parser_->parse(ctx, start); - // Store result - result.groups[name_] = parser_match_location{result.start, result.end}; - return ctx.memo.set(id_, start, result); + // Store result + result.groups[name_] = parser_match_location{result.start, result.end}; + return result; + }); } std::string dump() const override { @@ -619,13 +582,9 @@ class until_parser : public parser_base { parser_type type() const override { return PARSER_UNTIL; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - auto result = parser_->parse(ctx, start); - return ctx.memo.set(id_, start, result); + return ctx.memo.cached(id_, start, [&]() { + return parser_->parse(ctx, start); + }); } std::string dump() const override { @@ -651,7 +610,9 @@ class schema_parser : public parser_base { parser_type type() const override { return PARSER_SCHEMA; } parser_result parse(parser_context & ctx, size_t start = 0) override { - return parser_->parse(ctx, start); + return ctx.memo.cached(id_, start, [&]() { + return parser_->parse(ctx, start); + }); } std::string dump() const override { @@ -678,25 +639,21 @@ class rule_parser : public parser_base { parser_type type() const override { return PARSER_RULE; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - auto rules = rules_.lock(); - if (!rules) { - LOG_ERR("rule_parser::parse called with expired rule registry\n"); - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); - } + return ctx.memo.cached(id_, start, [&]() { + auto rules = rules_.lock(); + if (!rules) { + LOG_ERR("rule_parser::parse called with expired rule registry\n"); + return parser_result(PARSER_RESULT_FAIL, start); + } - auto it = rules->find(name_); - if (it == rules->end()) { - LOG_ERR("rule_parser::parse rule '%s' not found in registry\n", name_.c_str()); - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); - } + auto it = rules->find(name_); + if (it == rules->end()) { + LOG_ERR("rule_parser::parse rule '%s' not found in registry\n", name_.c_str()); + return parser_result(PARSER_RESULT_FAIL, start); + } - auto result = it->second->parse(ctx, start); - return ctx.memo.set(id_, start, result); + return it->second->parse(ctx, start); + }); } std::string dump() const override { @@ -1105,6 +1062,14 @@ void parse_cache::clear() { results.clear(); } +parser_result parse_cache::cached(int id, size_t start, const std::function & fn) { + auto result = get(id, start); + if (result) { + return *result; + } + return set(id, start, fn()); +} + parser::parser() {} parser::parser(std::shared_ptr parser) : ptr_(std::move(parser)) {} diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 56a522394bf75..4dbbca3f6d1c5 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -70,6 +70,8 @@ class parse_cache { parser_result set(int id, size_t start, parser_result result); std::optional get(int id, size_t start); void clear(); + + parser_result cached(int id, size_t start, const std::function & fn); }; struct parser_context { From adac6bae7f8a53493a192f72cdb7302ef1cf7f62 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 21:32:29 -0600 Subject: [PATCH 12/16] add short description for each parser --- common/chat-parser-combinator.cpp | 30 +++++++++++++++++++++ common/chat-parser-combinator.h | 44 +++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index fb3b0fb083a0a..2b46a097d7a23 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -44,6 +44,8 @@ class parser_base { virtual void accept(parser_visitor & visitor) = 0; }; +// Matches an exact literal string. +// S -> "hello" class literal_parser : public parser_base { std::string literal_; @@ -84,6 +86,8 @@ class literal_parser : public parser_base { const std::string & literal() const { return literal_; } }; +// Matches a sequence of parsers in order, all must succeed. +// S -> A B C class sequence_parser : public parser_base { std::vector parsers_; @@ -148,6 +152,8 @@ class sequence_parser : public parser_base { const std::vector & parsers() const { return parsers_; } }; +// Matches the first parser that succeeds from a list of alternatives. +// S -> A | B | C class choice_parser : public parser_base { std::vector parsers_; @@ -201,6 +207,8 @@ class choice_parser : public parser_base { const std::vector & parsers() const { return parsers_; } }; +// Matches one or more repetitions of a parser. +// S -> A+ class one_or_more_parser : public parser_base { parser parser_; @@ -256,6 +264,8 @@ class one_or_more_parser : public parser_base { const parser & child() const { return parser_; } }; +// Matches zero or more repetitions of a parser, always succeeds. +// S -> A* class zero_or_more_parser : public parser_base { parser parser_; @@ -302,6 +312,8 @@ class zero_or_more_parser : public parser_base { const parser & child() const { return parser_; } }; +// Matches zero or one occurrence of a parser, always succeeds. +// S -> A? class optional_parser : public parser_base { parser parser_; @@ -338,6 +350,8 @@ class optional_parser : public parser_base { const parser & child() const { return parser_; } }; +// Negative lookahead: succeeds if child parser fails, consumes no input. +// S -> !A class not_parser : public parser_base { parser parser_; @@ -374,6 +388,8 @@ class not_parser : public parser_base { const parser & child() const { return parser_; } }; +// Matches any single character. +// S -> . class any_parser : public parser_base { public: any_parser(int id) : parser_base(id) {} @@ -400,6 +416,8 @@ class any_parser : public parser_base { void accept(parser_visitor & visitor) override; }; +// Matches zero or more whitespace characters (space, tab, newline). +// S -> [ \t\n]* class space_parser : public parser_base { public: space_parser(int id) : parser_base(id) {} @@ -429,6 +447,8 @@ class space_parser : public parser_base { void accept(parser_visitor & visitor) override; }; +// Matches a single character from a character class or range. +// S -> [a-z] or S -> [^0-9] class char_class_parser : public parser_base { struct char_range { int start; @@ -532,6 +552,8 @@ class char_class_parser : public parser_base { const std::string & pattern() const { return pattern_; } }; +// Captures the matched text from a parser and stores it with a name. +// S -> class group_parser : public parser_base { std::string name_; parser parser_; @@ -560,6 +582,8 @@ class group_parser : public parser_base { const parser & child() const { return parser_; } }; +// Matches all characters until a delimiter is found (delimiter not consumed). +// S -> (!delim .)* class until_parser : public parser_base { std::string delimiter_; parser parser_; @@ -598,6 +622,8 @@ class until_parser : public parser_base { const parser & child() const { return parser_; } }; +// Wraps a parser with JSON schema metadata for grammar generation. +// Used internally to convert JSON schemas to GBNF grammar rules. class schema_parser : public parser_base { parser parser_; std::string name_; @@ -628,6 +654,8 @@ class schema_parser : public parser_base { const nlohmann::ordered_json & schema() const { return schema_; } }; +// References a named rule for recursive or reusable grammar definitions. +// expr -> term | expr "+" term class rule_parser : public parser_base { std::string name_; std::weak_ptr> rules_; @@ -665,6 +693,8 @@ class rule_parser : public parser_base { const std::string & name() const { return name_; } }; +// Container for the root parser and all named rules in the grammar. +// Manages ownership of rule registry to enable recursive grammar definitions. class root_parser : public parser_base { parser root_; std::shared_ptr> rules_; diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 4dbbca3f6d1c5..25ce7f7c11cb0 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -128,20 +128,64 @@ class parser_builder { parser_builder(); parser_builder(std::shared_ptr counter); + // Matches an exact literal string. + // S -> "hello" parser literal(const std::string & literal); + + // Matches a sequence of parsers in order, all must succeed. + // S -> A B C parser sequence(std::initializer_list parsers); + + // Matches the first parser that succeeds from a list of alternatives. + // S -> A | B | C parser choice(std::initializer_list parsers); + + // Matches one or more repetitions of a parser. + // S -> A+ parser one_or_more(const parser & p); + + // Matches zero or more repetitions of a parser, always succeeds. + // S -> A* parser zero_or_more(const parser & p); + + // Matches zero or one occurrence of a parser, always succeeds. + // S -> A? parser optional(const parser & p); + + // Negative lookahead: succeeds if child parser fails, consumes no input. + // S -> !A parser negate(const parser & p); + + // Matches any single character. + // S -> . parser any(); + + // Matches a single character from a character class or range. + // S -> [a-z] or S -> [^0-9] parser char_class(const std::string & classes); + + // Captures the matched text from a parser and stores it with a name. + // S -> parser group(const std::string & name, const parser & p); + + // References a named rule for recursive or reusable grammar definitions. + // expr -> term | expr "+" term parser rule(const std::string & name); + + // Matches zero or more whitespace characters (space, tab, newline). + // S -> [ \t\n]* parser space(); + + // Matches all characters until a delimiter is found (delimiter not consumed). + // S -> (!delim .)* parser until(const std::string & delimiter, bool consume_spaces = true); + + // Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null. + // value -> object | array | string | number | true | false | null parser json(); + + // Wraps a parser with JSON schema metadata for grammar generation. + // Used internally to convert JSON schemas to GBNF grammar rules. parser schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema); parser add_rule(const std::string & name, const parser & p); From 0be2a93eb7ea86d3836d986007e8ba0e451dd834 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 21:49:24 -0600 Subject: [PATCH 13/16] create a type for the root parser --- common/chat-parser-combinator.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 2b46a097d7a23..d6aee652d2ea0 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -23,6 +23,7 @@ enum parser_type { PARSER_UNTIL = 11, PARSER_SPACE = 12, PARSER_SCHEMA = 13, + PARSER_ROOT = 14, }; class parser_visitor; @@ -705,7 +706,7 @@ class root_parser : public parser_base { root_parser(const parser & root, std::shared_ptr> rules, int id) : parser_base(id), root_(root), rules_(std::move(rules)) {} - parser_type type() const override { return root_->type(); } + parser_type type() const override { return PARSER_ROOT; } parser_result parse(parser_context & ctx, size_t start = 0) override { return root_->parse(ctx, start); From 31b386f6ef431220840e869da5b54c34804f058b Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 22:16:30 -0600 Subject: [PATCH 14/16] implement repetition parser --- common/chat-parser-combinator.cpp | 194 ++++++++++++++++++------------ common/chat-parser-combinator.h | 9 ++ 2 files changed, 129 insertions(+), 74 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index d6aee652d2ea0..4121bfcde1c2e 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -24,6 +24,7 @@ enum parser_type { PARSER_SPACE = 12, PARSER_SCHEMA = 13, PARSER_ROOT = 14, + PARSER_REPETITION = 15, }; class parser_visitor; @@ -208,48 +209,52 @@ class choice_parser : public parser_base { const std::vector & parsers() const { return parsers_; } }; -// Matches one or more repetitions of a parser. -// S -> A+ -class one_or_more_parser : public parser_base { +// Matches between min and max repetitions of a parser (inclusive). +// S -> A{m,n} +// Use -1 for max_count to represent unbounded repetition (equivalent to {m,}) +class repetition_parser : public parser_base { parser parser_; + int min_count_; + int max_count_; public: - one_or_more_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} + repetition_parser(const parser & parser, int min_count, int max_count, int id) + : parser_base(id), parser_(parser), min_count_(min_count), max_count_(max_count) {} - parser_type type() const override { return PARSER_ONE_OR_MORE; } + parser_type type() const override { return PARSER_REPETITION; } parser_result parse(parser_context & ctx, size_t start = 0) override { return ctx.memo.cached(id_, start, [&]() { std::unordered_map groups; + auto pos = start; + int match_count = 0; - // Parse at least once - auto first_result = parser_->parse(ctx, start); - if (!first_result.is_success()) { - return first_result; - } - - auto pos = first_result.end; - groups.insert(first_result.groups.begin(), first_result.groups.end()); - - // Parse zero or more additional times - for (;;) { + // Try to match up to max_count times (or unlimited if max_count is -1) + while (max_count_ == -1 || match_count < max_count_) { auto result = parser_->parse(ctx, pos); groups.insert(result.groups.begin(), result.groups.end()); - if (result.is_need_more_input()) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); + if (result.is_success()) { + // Prevent infinite loop on empty matches + if (result.end == pos) { + break; + } + pos = result.end; + match_count++; + continue; } - if (result.is_fail()) { - // Done with repetitions - break; + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); } - if (result.end == pos) { - break; // Prevent an infinite loop - } + // Child failed - stop trying + break; + } - pos = result.end; + // Check if we got enough matches + if (match_count < min_count_) { + return parser_result(PARSER_RESULT_FAIL, start, pos, groups); } return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); @@ -257,98 +262,106 @@ class one_or_more_parser : public parser_base { } std::string dump() const override { - return "OneOrMore(" + parser_->dump() + ")"; + if (max_count_ == -1) { + return "Repetition(" + parser_->dump() + ", " + std::to_string(min_count_) + ", unbounded)"; + } + return "Repetition(" + parser_->dump() + ", " + std::to_string(min_count_) + ", " + std::to_string(max_count_) + ")"; } void accept(parser_visitor & visitor) override; const parser & child() const { return parser_; } + + int min_count() const { return min_count_; } + + int max_count() const { return max_count_; } }; -// Matches zero or more repetitions of a parser, always succeeds. -// S -> A* -class zero_or_more_parser : public parser_base { - parser parser_; +// Matches one or more repetitions of a parser. +// S -> A+ +class one_or_more_parser : public parser_base { + parser delegate_; public: - zero_or_more_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} + one_or_more_parser(const parser & p, int id) : parser_base(id) { + delegate_ = parser(std::make_shared(p, 1, -1, id)); + } - parser_type type() const override { return PARSER_ZERO_OR_MORE; } + parser_type type() const override { return PARSER_ONE_OR_MORE; } parser_result parse(parser_context & ctx, size_t start = 0) override { - return ctx.memo.cached(id_, start, [&]() { - std::unordered_map groups; - auto pos = start; + return delegate_->parse(ctx, start); + } - for (;;) { - auto result = parser_->parse(ctx, pos); - groups.insert(result.groups.begin(), result.groups.end()); + std::string dump() const override { + auto rep = std::static_pointer_cast(delegate_.ptr()); + return "OneOrMore(" + rep->child()->dump() + ")"; + } - if (result.is_need_more_input()) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); - } + void accept(parser_visitor & visitor) override; - if (result.is_fail()) { - // Done with repetitions (zero or more is always valid) - break; - } + const parser & child() const { + auto rep = std::static_pointer_cast(delegate_.ptr()); + return rep->child(); + } +}; - if (result.end == pos) { - break; // Prevent an infinite loop - } +// Matches zero or more repetitions of a parser, always succeeds. +// S -> A* +class zero_or_more_parser : public parser_base { + parser delegate_; - pos = result.end; - } + public: + zero_or_more_parser(const parser & p, int id) : parser_base(id) { + delegate_ = parser(std::make_shared(p, 0, -1, id)); + } - return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); - }); + parser_type type() const override { return PARSER_ZERO_OR_MORE; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + return delegate_->parse(ctx, start); } std::string dump() const override { - return "ZeroOrMore(" + parser_->dump() + ")"; + auto rep = std::static_pointer_cast(delegate_.ptr()); + return "ZeroOrMore(" + rep->child()->dump() + ")"; } void accept(parser_visitor & visitor) override; - const parser & child() const { return parser_; } + const parser & child() const { + auto rep = std::static_pointer_cast(delegate_.ptr()); + return rep->child(); + } }; // Matches zero or one occurrence of a parser, always succeeds. // S -> A? class optional_parser : public parser_base { - parser parser_; + parser delegate_; public: - optional_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} + optional_parser(const parser & p, int id) : parser_base(id) { + delegate_ = parser(std::make_shared(p, 0, 1, id)); + } parser_type type() const override { return PARSER_OPTIONAL; } parser_result parse(parser_context & ctx, size_t start = 0) override { - return ctx.memo.cached(id_, start, [&]() { - auto result = parser_->parse(ctx, start); - - if (result.is_success()) { - // Matched successfully - return result; - } - - if (result.is_need_more_input()) { - // Propagate - need more input to determine if optional matches - return result; - } - - // No match, but optional always succeeds with zero matches - return parser_result(PARSER_RESULT_SUCCESS, start, start); - }); + return delegate_->parse(ctx, start); } std::string dump() const override { - return "Optional(" + parser_->dump() + ")"; + auto rep = std::static_pointer_cast(delegate_.ptr()); + return "Optional(" + rep->child()->dump() + ")"; } void accept(parser_visitor & visitor) override; - const parser & child() const { return parser_; } + const parser & child() const { + auto rep = std::static_pointer_cast(delegate_.ptr()); + return rep->child(); + } }; // Negative lookahead: succeeds if child parser fails, consumes no input. @@ -734,6 +747,7 @@ class parser_visitor { virtual void visit(one_or_more_parser & p) = 0; virtual void visit(zero_or_more_parser & p) = 0; virtual void visit(optional_parser & p) = 0; + virtual void visit(repetition_parser & p) = 0; virtual void visit(until_parser & p) = 0; virtual void visit(not_parser & p) = 0; virtual void visit(any_parser & p) = 0; @@ -891,6 +905,24 @@ class gbnf_visitor : public parser_visitor { } } + void visit(repetition_parser & p) override { + p.child()->accept(*this); + std::string child_result = current_result_; + + if (needs_parens(p.child()->type())) { + child_result = "(" + child_result + ")"; + } + + if (p.max_count() == -1) { + // Unbounded: {n,} + current_result_ = child_result + "{" + std::to_string(p.min_count()) + ",}"; + } else { + // Bounded: {n,m} + current_result_ = child_result + "{" + std::to_string(p.min_count()) + "," + + std::to_string(p.max_count()) + "}"; + } + } + void visit(until_parser & p) override { // Generate pattern that matches prefixes but prevents full delimiter match current_result_ = generate_until_pattern(p.delimiter()) + "*"; @@ -1021,6 +1053,11 @@ class id_assignment_visitor : public parser_visitor { p.child()->accept(*this); } + void visit(repetition_parser & p) override { + assign_id(p); + p.child()->accept(*this); + } + void visit(until_parser & p) override { assign_id(p); p.child()->accept(*this); @@ -1049,6 +1086,7 @@ void choice_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void one_or_more_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void zero_or_more_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void optional_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void repetition_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void until_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void not_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void any_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } @@ -1207,6 +1245,14 @@ parser parser_builder::until(const std::string & delimiter, bool consume_spaces) return parser(std::make_shared(delimiter, consume_spaces, counter_->next())); } +parser parser_builder::repeat(const parser & p, int min, int max) { + return parser(std::make_shared(p, min, max, counter_->next())); +} + +parser parser_builder::repeat(const parser & p, int n) { + return repeat(p, n, n); +} + parser parser_builder::schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema) { return parser(std::make_shared(p, name, schema, counter_->next())); } diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 25ce7f7c11cb0..b295b4b520498 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -180,6 +180,15 @@ class parser_builder { // S -> (!delim .)* parser until(const std::string & delimiter, bool consume_spaces = true); + // Matches between min and max repetitions of a parser (inclusive). + // S -> A{m,n} + // Use -1 for max to represent unbounded repetition (equivalent to {m,}) + parser repeat(const parser & p, int min, int max); + + // Matches exactly n repetitions of a parser. + // S -> A{n} + parser repeat(const parser & p, int n); + // Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null. // value -> object | array | string | number | true | false | null parser json(); From ffb7a6f77db113c16ecf3cd2de4db9fc2f5344fb Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 22:27:54 -0600 Subject: [PATCH 15/16] Make optional, one_or_more, and zero_or_more subclasses of repetition --- common/chat-parser-combinator.cpp | 86 ++++++++----------------------- 1 file changed, 22 insertions(+), 64 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 4121bfcde1c2e..23bce3b3a50d6 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -12,19 +12,19 @@ enum parser_type { PARSER_LITERAL = 0, PARSER_SEQUENCE = 1, PARSER_CHOICE = 2, - PARSER_ZERO_OR_MORE = 3, - PARSER_ONE_OR_MORE = 4, - PARSER_NOT = 5, - PARSER_ANY = 6, - PARSER_CHAR_CLASS = 7, - PARSER_GROUP = 8, - PARSER_RULE = 9, - PARSER_OPTIONAL = 10, - PARSER_UNTIL = 11, - PARSER_SPACE = 12, - PARSER_SCHEMA = 13, - PARSER_ROOT = 14, - PARSER_REPETITION = 15, + PARSER_REPETITION = 3, + PARSER_OPTIONAL = 4, + PARSER_ZERO_OR_MORE = 5, + PARSER_ONE_OR_MORE = 6, + PARSER_NOT = 7, + PARSER_ANY = 8, + PARSER_CHAR_CLASS = 9, + PARSER_GROUP = 10, + PARSER_RULE = 11, + PARSER_UNTIL = 12, + PARSER_SPACE = 13, + PARSER_SCHEMA = 14, + PARSER_ROOT = 15, }; class parser_visitor; @@ -279,89 +279,47 @@ class repetition_parser : public parser_base { // Matches one or more repetitions of a parser. // S -> A+ -class one_or_more_parser : public parser_base { - parser delegate_; - +class one_or_more_parser : public repetition_parser { public: - one_or_more_parser(const parser & p, int id) : parser_base(id) { - delegate_ = parser(std::make_shared(p, 1, -1, id)); - } + one_or_more_parser(const parser & p, int id) : repetition_parser(p, 1, -1, id) {} parser_type type() const override { return PARSER_ONE_OR_MORE; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return delegate_->parse(ctx, start); - } - std::string dump() const override { - auto rep = std::static_pointer_cast(delegate_.ptr()); - return "OneOrMore(" + rep->child()->dump() + ")"; + return "OneOrMore(" + child()->dump() + ")"; } void accept(parser_visitor & visitor) override; - - const parser & child() const { - auto rep = std::static_pointer_cast(delegate_.ptr()); - return rep->child(); - } }; // Matches zero or more repetitions of a parser, always succeeds. // S -> A* -class zero_or_more_parser : public parser_base { - parser delegate_; - +class zero_or_more_parser : public repetition_parser { public: - zero_or_more_parser(const parser & p, int id) : parser_base(id) { - delegate_ = parser(std::make_shared(p, 0, -1, id)); - } + zero_or_more_parser(const parser & p, int id) : repetition_parser(p, 0, -1, id) {} parser_type type() const override { return PARSER_ZERO_OR_MORE; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return delegate_->parse(ctx, start); - } - std::string dump() const override { - auto rep = std::static_pointer_cast(delegate_.ptr()); - return "ZeroOrMore(" + rep->child()->dump() + ")"; + return "ZeroOrMore(" + child()->dump() + ")"; } void accept(parser_visitor & visitor) override; - - const parser & child() const { - auto rep = std::static_pointer_cast(delegate_.ptr()); - return rep->child(); - } }; // Matches zero or one occurrence of a parser, always succeeds. // S -> A? -class optional_parser : public parser_base { - parser delegate_; - +class optional_parser : public repetition_parser { public: - optional_parser(const parser & p, int id) : parser_base(id) { - delegate_ = parser(std::make_shared(p, 0, 1, id)); - } + optional_parser(const parser & p, int id) : repetition_parser(p, 0, 1, id) {} parser_type type() const override { return PARSER_OPTIONAL; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return delegate_->parse(ctx, start); - } - std::string dump() const override { - auto rep = std::static_pointer_cast(delegate_.ptr()); - return "Optional(" + rep->child()->dump() + ")"; + return "Optional(" + child()->dump() + ")"; } void accept(parser_visitor & visitor) override; - - const parser & child() const { - auto rep = std::static_pointer_cast(delegate_.ptr()); - return rep->child(); - } }; // Negative lookahead: succeeds if child parser fails, consumes no input. From 085404a326000a195efb2ca550cec2dccf684273 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 22:44:55 -0600 Subject: [PATCH 16/16] improve context constructor --- common/chat-parser-combinator.h | 26 ++++++- tests/test-chat-parser-combinator.cpp | 108 +++++++++++++------------- 2 files changed, 76 insertions(+), 58 deletions(-) diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index b295b4b520498..ab839971b725c 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -52,9 +52,15 @@ struct parser_result { std::unordered_map groups; parser_result() : type(PARSER_RESULT_FAIL) {} - parser_result(parser_result_type type, size_t start) : type(type), start(start), end(start) {} - parser_result(parser_result_type type, size_t start, size_t end) : type(type), start(start), end(end) {} - parser_result(parser_result_type type, size_t start, size_t end, const std::unordered_map & groups) : type(type), start(start), end(end), groups(groups) {} + + parser_result(parser_result_type type, size_t start) + : type(type), start(start), end(start) {} + + parser_result(parser_result_type type, size_t start, size_t end) + : type(type), start(start), end(end) {} + + parser_result(parser_result_type type, size_t start, size_t end, const std::unordered_map & groups) + : type(type), start(start), end(end), groups(groups) {} bool is_fail() const { return type == PARSER_RESULT_FAIL; } bool is_need_more_input() const { return type == PARSER_RESULT_NEED_MORE_INPUT; } @@ -77,7 +83,19 @@ class parse_cache { struct parser_context { std::string_view input; parse_cache memo; - bool input_is_complete = true; + bool input_is_complete; + + parser_context() + : memo(), input_is_complete(true) {} + + parser_context(std::string_view input) + : input(input), memo(), input_is_complete(true) {} + + parser_context(std::string_view input, bool complete) + : input(input), memo(), input_is_complete(complete) {} + + parser_context(std::string_view input, parse_cache memo, bool complete = true) + : input(input), memo(std::move(memo)), input_is_complete(complete) {} }; class parser_base; diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index e4f637af9f797..83ff2ba4a67fd 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -36,7 +36,7 @@ static void test_partial_parsing() { parser_context ctx; parser_result result; - ctx = parser_context{"hello", parse_cache()}; + ctx = parser_context("hello"); result = parser.parse(ctx); assert_equals(true, result.is_success()); } @@ -49,11 +49,11 @@ static void test_partial_parsing() { parser_context ctx; parser_result result; - ctx = parser_context{"a", parse_cache()}; + ctx = parser_context("a"); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"A", parse_cache()}; + ctx = parser_context("A"); result = parser.parse(ctx); assert_equals(true, result.is_fail()); @@ -61,15 +61,15 @@ static void test_partial_parsing() { return p.char_class("a-z-"); }); - ctx = parser_context{"f", parse_cache()}; + ctx = parser_context("f"); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"-", parse_cache()}; + ctx = parser_context("-"); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"A", parse_cache()}; + ctx = parser_context("A"); result = parser.parse(ctx); assert_equals(true, result.is_fail()); } @@ -80,25 +80,25 @@ static void test_partial_parsing() { }); // Partial matches - auto ctx = parser_context{"", parse_cache(), false}; + ctx = parser_context("", false); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"", parse_cache(), true}; + ctx = parser_context("", true); result = parser.parse(ctx); assert_equals(true, result.is_success()); // No match, since it does not adhere to the grammar - ctx = parser_context{"I am parser", parse_cache(), false}; + ctx = parser_context("I am parser", false); result = parser.parse(ctx); assert_equals(true, result.is_fail()); } @@ -109,25 +109,25 @@ static void test_partial_parsing() { }); // Partial matches - auto ctx = parser_context{"", parse_cache(), true}; + ctx = parser_context("", true); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"", parse_cache(), true}; + ctx = parser_context("", true); result = parser.parse(ctx); assert_equals(true, result.is_success()); // No match - ctx = parser_context{"", parse_cache(), true}; + ctx = parser_context("", true); result = parser.parse(ctx); assert_equals(true, result.is_fail()); } @@ -138,16 +138,16 @@ static void test_partial_parsing() { }); // Partial matches - auto ctx = parser_context{"a", parse_cache(), false}; + auto ctx = parser_context("a", false); auto result = parser.parse(ctx); assert_equals(true, result.is_need_more_input()); - ctx = parser_context{"aba", parse_cache(), false}; + ctx = parser_context("aba", false); result = parser.parse(ctx); assert_equals(true, result.is_need_more_input()); // Full match - ctx = parser_context{"ab", parse_cache(), true}; + ctx = parser_context("ab", true); result = parser.parse(ctx); assert_equals(true, result.is_success()); } @@ -158,21 +158,21 @@ static void test_partial_parsing() { }); // Partial matches - auto ctx = parser_context{"a", parse_cache(), false}; + auto ctx = parser_context("a", false); auto result = parser.parse(ctx); assert_equals(true, result.is_need_more_input()); - ctx = parser_context{"aba", parse_cache(), false}; + ctx = parser_context("aba", false); result = parser.parse(ctx); assert_equals(true, result.is_need_more_input()); // Full match - ctx = parser_context{"ab", parse_cache(), true}; + ctx = parser_context("ab", true); result = parser.parse(ctx); assert_equals(true, result.is_success()); // No match - ctx = parser_context{"cd", parse_cache(), true}; + ctx = parser_context("cd", true); result = parser.parse(ctx); assert_equals(true, result.is_fail()); } @@ -189,7 +189,7 @@ static void test_capture_groups() { }); std::string input = "I have a thought"; - auto ctx = parser_context{input, parse_cache()}; + auto ctx = parser_context(input); auto result = parser.parse(ctx); assert_equals(true, result.is_success()); @@ -208,7 +208,7 @@ static void test_capture_groups() { }); std::string input = "I have a "; - auto ctx = parser_context{input, parse_cache(), false}; + auto ctx = parser_context(input, false); auto result = parser.parse(ctx); assert_equals(true, result.is_success()); @@ -228,7 +228,7 @@ static void test_capture_groups() { }); std::string input = "The user said hello.Hello!"; - auto ctx = parser_context{input, parse_cache(), true}; + auto ctx = parser_context(input, true); auto result = parser.parse(ctx); assert_equals(true, result.is_success()); @@ -253,19 +253,19 @@ static void test_char_class() { parser_context ctx; parser_result result; - ctx = parser_context{"\n", parse_cache()}; + ctx = parser_context("\n"); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"\t", parse_cache()}; + ctx = parser_context("\t"); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"\\", parse_cache()}; + ctx = parser_context("\\"); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{" ", parse_cache()}; + ctx = parser_context(" "); result = parser.parse(ctx); assert_equals(true, result.is_fail()); } @@ -278,20 +278,20 @@ static void test_char_class() { parser_context ctx; parser_result result; - ctx = parser_context{"a", parse_cache()}; + ctx = parser_context("a"); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"-", parse_cache()}; + ctx = parser_context("-"); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"z", parse_cache()}; + ctx = parser_context("z"); result = parser.parse(ctx); assert_equals(true, result.is_success()); // Should NOT match 'b' since \- is a literal dash, not a range - ctx = parser_context{"b", parse_cache()}; + ctx = parser_context("b"); result = parser.parse(ctx); assert_equals(true, result.is_fail()); } @@ -312,32 +312,32 @@ static void test_recursive_references() { parser_result result; // Test simple number - ctx = parser_context{"1", parse_cache(), true}; + ctx = parser_context("1", true); result = value_parser.parse(ctx); assert_equals(true, result.is_success()); // Test simple list - ctx = parser_context{"[1]", parse_cache(), true}; + ctx = parser_context("[1]", true); result = value_parser.parse(ctx); assert_equals(true, result.is_success()); // Test nested list - ctx = parser_context{"[[2]]", parse_cache(), true}; + ctx = parser_context("[[2]]", true); result = value_parser.parse(ctx); assert_equals(true, result.is_success()); // Test deeply nested list - ctx = parser_context{"[[[3]]]", parse_cache(), true}; + ctx = parser_context("[[[3]]]", true); result = value_parser.parse(ctx); assert_equals(true, result.is_success()); // Test partial match - ctx = parser_context{"[[", parse_cache(), false}; + ctx = parser_context("[[", false); result = value_parser.parse(ctx); assert_equals(true, result.is_success()); // Test no match - ctx = parser_context{"[a]", parse_cache(), true}; + ctx = parser_context("[a]", true); result = value_parser.parse(ctx); assert_equals(true, result.is_fail()); } @@ -349,19 +349,19 @@ static void test_optional() { }); // Full match with optional part present - auto ctx = parser_context{"hello world", parse_cache()}; + auto ctx = parser_context("hello world"); auto result = parser.parse(ctx); assert_equals(true, result.is_success()); assert_equals((size_t)11, result.end); // Full match with optional part absent - ctx = parser_context{"hello", parse_cache(), true}; + ctx = parser_context("hello", true); result = parser.parse(ctx); assert_equals(true, result.is_success()); assert_equals((size_t)5, result.end); // Partial match - waiting for more input to determine if optional matches - ctx = parser_context{"hello ", parse_cache(), false}; + ctx = parser_context("hello ", false); result = parser.parse(ctx); assert_equals(true, result.is_need_more_input()); } @@ -374,7 +374,7 @@ static void test_json_parser() { { // Test parsing a simple JSON object std::string input = R"({"name": "test", "value": 42, "flag": true})"; - parser_context ctx{input, parse_cache()}; + parser_context ctx(input); auto result = json.parse(ctx); @@ -384,7 +384,7 @@ static void test_json_parser() { { // Test parsing a JSON array with mixed types std::string input = R"([1, "hello", true, null, 3.14])"; - parser_context ctx{input, parse_cache()}; + parser_context ctx(input); auto result = json.parse(ctx); @@ -394,7 +394,7 @@ static void test_json_parser() { { // Test parsing nested JSON with objects and arrays std::string input = R"({"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], "count": 2, "metadata": {"version": "1.0", "tags": ["admin", "user"]}})"; - parser_context ctx{input, parse_cache()}; + parser_context ctx(input); auto result = json.parse(ctx); @@ -404,7 +404,7 @@ static void test_json_parser() { { // Test partial parsing - incomplete object std::string input = R"({"name": "test", "value": )"; - parser_context ctx{input, parse_cache(), false}; + parser_context ctx(input, false); auto result = json.parse(ctx); @@ -413,7 +413,7 @@ static void test_json_parser() { { // Test partial parsing - incomplete array std::string input = R"([1, 2, 3, )"; - parser_context ctx{input, parse_cache(), false}; + parser_context ctx(input, false); auto result = json.parse(ctx); @@ -422,7 +422,7 @@ static void test_json_parser() { { // Test partial parsing - incomplete nested structure std::string input = R"({"data": {"nested": )"; - parser_context ctx{input, parse_cache(), false}; + parser_context ctx(input, false); auto result = json.parse(ctx); @@ -476,7 +476,7 @@ static void test_complete_example() { // Test complete input std::string input = R"(I need to call get_weather with city = New Yorkget_weather{"city": "New York"})"; - parser_context ctx{input, parse_cache()}; + parser_context ctx(input); auto result = parser.parse(ctx); @@ -488,21 +488,21 @@ static void test_complete_example() { // Test partial input input = R"(I need to call get_weather )"; - ctx = parser_context{input, parse_cache(), /* .is_input_complete = */ false}; + ctx = parser_context(input, /* .is_input_complete = */ false); result = parser.parse(ctx); assert_equals(true, result.is_success()); assert_equals(std::string("I need to call get_weather"), *result.group("reasoning-content", ctx.input)); input = R"(I need to call get_weatherget_weather)"; - ctx = parser_context{input, parse_cache(), /* .is_input_complete = */ false}; + ctx = parser_context(input, /* .is_input_complete = */ false); result = parser.parse(ctx); assert_equals(true, result.is_success()); assert_equals(std::string("I need to call get_weather"), *result.group("reasoning-content", ctx.input)); input = R"(I need to call get_weatherget_weatherI need to call get_weatherget_weather{"cit)"; - ctx = parser_context{input, parse_cache(), /* .is_input_complete = */ false}; + ctx = parser_context(input, /* .is_input_complete = */ false); result = parser.parse(ctx); assert_equals(true, result.is_success());