|
1 | 1 | #include "mlxsharp/api.h" |
| 2 | +#include "mlxsharp/llm_model_runner.h" |
2 | 3 |
|
3 | 4 | #include <algorithm> |
4 | 5 | #include <atomic> |
@@ -43,6 +44,7 @@ struct mlxsharp_session { |
43 | 44 | std::string chat_model; |
44 | 45 | std::string embedding_model; |
45 | 46 | std::string image_model; |
| 47 | + std::unique_ptr<mlxsharp::llm::ModelRunner> model_runner; |
46 | 48 |
|
47 | 49 | mlxsharp_session(mlxsharp_context_t* ctx, std::string chat, std::string embed, std::string image) |
48 | 50 | : context(ctx), |
@@ -563,6 +565,104 @@ void mlxsharp_free_buffer(unsigned char* data) { |
563 | 565 | std::free(data); |
564 | 566 | } |
565 | 567 |
|
| 568 | +int mlxsharp_session_load_model( |
| 569 | + void* session_ptr, |
| 570 | + const char* model_directory, |
| 571 | + const char* tokenizer_path) { |
| 572 | + if (session_ptr == nullptr) { |
| 573 | + return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Session pointer is null."); |
| 574 | + } |
| 575 | + |
| 576 | + if (model_directory == nullptr || tokenizer_path == nullptr) { |
| 577 | + return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Model directory or tokenizer path is null."); |
| 578 | + } |
| 579 | + |
| 580 | + auto* session = static_cast<mlxsharp_session_t*>(session_ptr); |
| 581 | + |
| 582 | + return invoke([&]() -> int { |
| 583 | + auto model = mlxsharp::llm::ModelRunner::Create(model_directory, tokenizer_path); |
| 584 | + session->model_runner = std::move(model); |
| 585 | + return MLXSHARP_STATUS_SUCCESS; |
| 586 | + }); |
| 587 | +} |
| 588 | + |
| 589 | +int mlxsharp_session_generate_tokens( |
| 590 | + void* session_ptr, |
| 591 | + const int32_t* prompt_tokens, |
| 592 | + size_t prompt_token_count, |
| 593 | + const mlxsharp_generation_options* options, |
| 594 | + mlxsharp_token_buffer* output_tokens, |
| 595 | + mlx_usage* usage) { |
| 596 | + if (session_ptr == nullptr) { |
| 597 | + return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Session pointer is null."); |
| 598 | + } |
| 599 | + |
| 600 | + if (output_tokens == nullptr) { |
| 601 | + return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, kNullOutParameter); |
| 602 | + } |
| 603 | + |
| 604 | + output_tokens->tokens = nullptr; |
| 605 | + output_tokens->length = 0; |
| 606 | + |
| 607 | + auto* session = static_cast<mlxsharp_session_t*>(session_ptr); |
| 608 | + |
| 609 | + if (session->model_runner == nullptr) { |
| 610 | + return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Model is not loaded. Call mlxsharp_session_load_model first."); |
| 611 | + } |
| 612 | + |
| 613 | + if (prompt_token_count > 0 && prompt_tokens == nullptr) { |
| 614 | + return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Prompt tokens pointer is null."); |
| 615 | + } |
| 616 | + |
| 617 | + if (options == nullptr) { |
| 618 | + return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Generation options pointer is null."); |
| 619 | + } |
| 620 | + |
| 621 | + return invoke([&]() -> int { |
| 622 | + std::vector<int32_t> prompt; |
| 623 | + prompt.reserve(prompt_token_count); |
| 624 | + for (size_t i = 0; i < prompt_token_count; ++i) { |
| 625 | + prompt.push_back(prompt_tokens[i]); |
| 626 | + } |
| 627 | + |
| 628 | + mlxsharp::llm::GenerationOptions native_options{ |
| 629 | + options->max_tokens, |
| 630 | + options->temperature, |
| 631 | + options->top_p, |
| 632 | + options->top_k, |
| 633 | + }; |
| 634 | + |
| 635 | + auto generated = session->model_runner->Generate(prompt, native_options); |
| 636 | + output_tokens->length = generated.size(); |
| 637 | + |
| 638 | + if (generated.empty()) { |
| 639 | + assign_usage(usage, static_cast<int>(prompt_token_count), 0); |
| 640 | + return MLXSHARP_STATUS_SUCCESS; |
| 641 | + } |
| 642 | + |
| 643 | + auto* buffer = static_cast<int32_t*>(std::malloc(generated.size() * sizeof(int32_t))); |
| 644 | + if (buffer == nullptr) { |
| 645 | + return set_error(MLXSHARP_STATUS_OUT_OF_MEMORY, "Failed to allocate output token buffer."); |
| 646 | + } |
| 647 | + |
| 648 | + std::memcpy(buffer, generated.data(), generated.size() * sizeof(int32_t)); |
| 649 | + output_tokens->tokens = buffer; |
| 650 | + |
| 651 | + assign_usage(usage, static_cast<int>(prompt_token_count), static_cast<int>(generated.size())); |
| 652 | + return MLXSHARP_STATUS_SUCCESS; |
| 653 | + }); |
| 654 | +} |
| 655 | + |
| 656 | +void mlxsharp_release_tokens(mlxsharp_token_buffer* buffer) { |
| 657 | + if (buffer == nullptr || buffer->tokens == nullptr) { |
| 658 | + return; |
| 659 | + } |
| 660 | + |
| 661 | + std::free(buffer->tokens); |
| 662 | + buffer->tokens = nullptr; |
| 663 | + buffer->length = 0; |
| 664 | +} |
| 665 | + |
566 | 666 | void mlxsharp_release_session(void* session_ptr) { |
567 | 667 | if (session_ptr == nullptr) { |
568 | 668 | return; |
|
0 commit comments