From e534cce350ef68d1e1068c1dfa14861374107302 Mon Sep 17 00:00:00 2001
From: chmjkb
Date: Fri, 3 Oct 2025 14:38:43 +0200
Subject: [PATCH 1/3] feat: remove stft calculation within the encoder

---
 .../common/rnexecutorch/models/BaseModel.cpp  |  2 +-
 .../common/rnexecutorch/models/BaseModel.h    |  2 +-
 .../models/speech_to_text/asr/ASR.cpp         | 36 +++++++++----------
 .../models/speech_to_text/asr/ASR.h           |  9 +++--
 4 files changed, 23 insertions(+), 26 deletions(-)

diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp
index 098bd487f..26fe781bd 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp
@@ -55,7 +55,7 @@ std::vector<int32_t> BaseModel::getInputShape(std::string method_name,
 }

 std::vector<std::vector<int32_t>>
-BaseModel::getAllInputShapes(std::string methodName) {
+BaseModel::getAllInputShapes(std::string methodName) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot get all input shapes");
   }
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h
index 983dc9b74..4720d646c 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h
@@ -23,7 +23,7 @@ class BaseModel {
   void unload() noexcept;
   std::vector<int32_t> getInputShape(std::string method_name, int32_t index);
   std::vector<std::vector<int32_t>>
-  getAllInputShapes(std::string methodName = "forward");
+  getAllInputShapes(std::string methodName = "forward") const;
   std::vector
   forwardJS(std::vector tensorViewVec);
   Result<std::vector<EValue>> forward(const EValue &input_value) const;
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp
index d0f965cb3..7c39e4020 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp
@@ -4,8 +4,8 @@
 #include "ASR.h"
 #include "executorch/extension/tensor/tensor_ptr.h"
 #include "rnexecutorch/data_processing/Numerical.h"
-#include "rnexecutorch/data_processing/dsp.h"
 #include "rnexecutorch/data_processing/gzip.h"
+#include

 namespace rnexecutorch::models::speech_to_text::asr {

@@ -37,8 +37,7 @@ ASR::getInitialSequence(const DecodingOptions &options) const {
   return seq;
 }

-GenerationResult ASR::generate(std::span<const float> waveform,
-                               float temperature,
+GenerationResult ASR::generate(std::span<float> waveform, float temperature,
                                const DecodingOptions &options) const {
   std::vector<float> encoderOutput = this->encode(waveform);

@@ -94,7 +93,7 @@ float ASR::getCompressionRatio(const std::string &text) const {
 }

 std::vector<int32_t>
-ASR::generateWithFallback(std::span<const float> waveform,
+ASR::generateWithFallback(std::span<float> waveform,
                           const DecodingOptions &options) const {
   std::vector<float> temperatures = {0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f};
   std::vector<int32_t> bestTokens;
@@ -209,7 +208,7 @@ ASR::estimateWordLevelTimestampsLinear(std::span<int32_t> tokens,
   return wordObjs;
 }

-std::vector ASR::transcribe(std::span<const float> waveform,
+std::vector ASR::transcribe(std::span<float> waveform,
                             const DecodingOptions &options) const {
   int32_t seek = 0;
   std::vector results;

@@ -218,7 +217,7 @@ std::vector ASR::transcribe(std::span<const float> waveform,
     int32_t start = seek * ASR::kSamplingRate;
     const auto end = std::min(
         (seek + ASR::kChunkSize) * ASR::kSamplingRate, waveform.size());
-    std::span<const float> chunk = waveform.subspan(start, end - start);
+    auto chunk = waveform.subspan(start, end - start);

     if (std::cmp_less(chunk.size(), ASR::kMinChunkSamples)) {
       break;
@@ -246,19 +245,12 @@ std::vector ASR::transcribe(std::span<const float> waveform,
   return results;
 }

-std::vector<float> ASR::encode(std::span<const float> waveform) const {
-  constexpr int32_t fftWindowSize = 512;
-  constexpr int32_t stftHopLength = 160;
-  constexpr int32_t innerDim = 256;
-
-  std::vector<float> preprocessedData =
-      dsp::stftFromWaveform(waveform, fftWindowSize, stftHopLength);
-  const auto numFrames =
-      static_cast<int32_t>(preprocessedData.size()) / innerDim;
-  std::vector<int32_t> inputShape = {numFrames, innerDim};
+std::vector<float> ASR::encode(std::span<float> waveform) const {
+  auto inputShape = {static_cast<int32_t>(waveform.size())};

   const auto modelInputTensor = executorch::extension::make_tensor_ptr(
-      std::move(inputShape), std::move(preprocessedData));
+      std::move(inputShape), waveform.data(),
+      executorch::runtime::etensor::ScalarType::Float);

   const auto encoderResult = this->encoder->forward(modelInputTensor);
   if (!encoderResult.ok()) {
@@ -268,7 +260,7 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
   }

   const auto decoderOutputTensor = encoderResult.get().at(0).toTensor();
-  const int32_t outputNumel = decoderOutputTensor.numel();
+  const auto outputNumel = decoderOutputTensor.numel();
   const float *const dataPtr = decoderOutputTensor.const_data_ptr<float>();

   return {dataPtr, dataPtr + outputNumel};
@@ -277,12 +269,18 @@ std::vector<float> ASR::decode(std::span<int32_t> tokens,
                                std::span<float> encoderOutput) const {
   std::vector<int32_t> tokenShape = {1, static_cast<int32_t>(tokens.size())};
+  auto tokensLong = std::vector<int64_t>(tokens.begin(), tokens.end());
+
   auto tokenTensor = executorch::extension::make_tensor_ptr(
-      std::move(tokenShape), tokens.data(), ScalarType::Int);
+      tokenShape, tokensLong.data(), ScalarType::Long);

   const auto encoderOutputSize = static_cast<int32_t>(encoderOutput.size());
   std::vector<int32_t> encShape = {1, ASR::kNumFrames,
                                    encoderOutputSize / ASR::kNumFrames};
+  log(LOG_LEVEL::Debug, encShape);
+  log(LOG_LEVEL::Debug, tokenShape);
+  log(LOG_LEVEL::Debug, this->encoder->getAllInputShapes());
+  log(LOG_LEVEL::Debug, this->decoder->getAllInputShapes());

   auto encoderTensor = executorch::extension::make_tensor_ptr(
       std::move(encShape), encoderOutput.data(), ScalarType::Float);
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h
index 20180ebe4..a0ea7e181 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h
@@ -14,9 +14,9 @@ class ASR {
       const models::BaseModel *decoder, const TokenizerModule *tokenizer);
   std::vector
-  transcribe(std::span<const float> waveform,
+  transcribe(std::span<float> waveform,
              const types::DecodingOptions &options) const;
-  std::vector<float> encode(std::span<const float> waveform) const;
+  std::vector<float> encode(std::span<float> waveform) const;
   std::vector<float> decode(std::span<int32_t> tokens,
                             std::span<float> encoderOutput) const;
@@ -44,11 +44,10 @@ class ASR {
   std::vector<int32_t>
   getInitialSequence(const types::DecodingOptions &options) const;

-  types::GenerationResult generate(std::span<const float> waveform,
-                                   float temperature,
+  types::GenerationResult generate(std::span<float> waveform, float temperature,
                                    const types::DecodingOptions &options) const;
   std::vector<int32_t>
-  generateWithFallback(std::span<const float> waveform,
+  generateWithFallback(std::span<float> waveform,
                        const types::DecodingOptions &options) const;
   std::vector
   calculateWordLevelTimestamps(std::span<int32_t> tokens,

From 5106782ffe70c86dbe25f8a0c0f76cfa73040829 Mon Sep 17 00:00:00 2001
From: chmjkb
Date: Fri, 3 Oct 2025 14:39:55 +0200
Subject: [PATCH 2/3] chore: remove logs

---
 .../common/rnexecutorch/models/speech_to_text/asr/ASR.cpp | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp
index 7c39e4020..bf8f9fb86 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp
@@ -5,7 +5,6 @@
 #include "executorch/extension/tensor/tensor_ptr.h"
 #include "rnexecutorch/data_processing/Numerical.h"
 #include "rnexecutorch/data_processing/gzip.h"
-#include

 namespace rnexecutorch::models::speech_to_text::asr {

@@ -277,10 +276,6 @@ std::vector<float> ASR::decode(std::span<int32_t> tokens,
   const auto encoderOutputSize = static_cast<int32_t>(encoderOutput.size());
   std::vector<int32_t> encShape = {1, ASR::kNumFrames,
                                    encoderOutputSize / ASR::kNumFrames};
-  log(LOG_LEVEL::Debug, encShape);
-  log(LOG_LEVEL::Debug, tokenShape);
-  log(LOG_LEVEL::Debug, this->encoder->getAllInputShapes());
-  log(LOG_LEVEL::Debug, this->decoder->getAllInputShapes());

   auto encoderTensor = executorch::extension::make_tensor_ptr(
       std::move(encShape), encoderOutput.data(), ScalarType::Float);

From 747f6278de10b397bb0ef20f160b0df8970b6444 Mon Sep 17 00:00:00 2001
From: chmjkb
Date: Fri, 3 Oct 2025 14:44:57 +0200
Subject: [PATCH 3/3] fix: mark some methods const in the BaseModel

---
 .../common/rnexecutorch/models/BaseModel.cpp | 10 +++++-----
 .../common/rnexecutorch/models/BaseModel.h   | 15 +++++++++------
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp
index 26fe781bd..79b109387 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp
@@ -29,7 +29,7 @@ BaseModel::BaseModel(const std::string &modelSource,
 }

 std::vector<int32_t> BaseModel::getInputShape(std::string method_name,
-                                              int32_t index) {
+                                              int32_t index) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot get input shape");
   }
@@ -87,7 +87,7 @@ BaseModel::getAllInputShapes(std::string methodName) const {
 /// to JS. It is not meant to be used within C++. If you want to call forward
 /// from C++ on a BaseModel, please use BaseModel::forward.
 std::vector
-BaseModel::forwardJS(std::vector tensorViewVec) {
+BaseModel::forwardJS(std::vector tensorViewVec) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot perform forward pass");
   }
@@ -135,7 +135,7 @@ BaseModel::forwardJS(std::vector tensorViewVec) {
 }

 Result<executorch::runtime::MethodMeta>
-BaseModel::getMethodMeta(const std::string &methodName) {
+BaseModel::getMethodMeta(const std::string &methodName) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot get method meta!");
   }
@@ -160,7 +160,7 @@ BaseModel::forward(const std::vector<EValue> &input_evalues) const {

 Result<std::vector<EValue>>
 BaseModel::execute(const std::string &methodName,
-                   const std::vector<EValue> &input_value) {
+                   const std::vector<EValue> &input_value) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded, cannot run execute.");
   }
@@ -174,7 +174,7 @@ std::size_t BaseModel::getMemoryLowerBound() const noexcept {
 void BaseModel::unload() noexcept { module_.reset(nullptr); }

 std::vector<int32_t>
-BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) {
+BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) const {
   auto sizes = tensor.sizes();
   return std::vector<int32_t>(sizes.begin(), sizes.end());
 }
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h
index 4720d646c..b944c590a 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h
@@ -21,18 +21,20 @@ class BaseModel {
             std::shared_ptr callInvoker);
   std::size_t getMemoryLowerBound() const noexcept;
   void unload() noexcept;
-  std::vector<int32_t> getInputShape(std::string method_name, int32_t index);
+  std::vector<int32_t> getInputShape(std::string method_name,
+                                     int32_t index) const;
   std::vector<std::vector<int32_t>>
   getAllInputShapes(std::string methodName = "forward") const;
   std::vector
-  forwardJS(std::vector tensorViewVec);
+  forwardJS(std::vector tensorViewVec) const;
   Result<std::vector<EValue>> forward(const EValue &input_value) const;
   Result<std::vector<EValue>>
   forward(const std::vector<EValue> &input_value) const;
-  Result<std::vector<EValue>> execute(const std::string &methodName,
-                                      const std::vector<EValue> &input_value);
+  Result<std::vector<EValue>>
+  execute(const std::string &methodName,
+          const std::vector<EValue> &input_value) const;
   Result<executorch::runtime::MethodMeta>
-  getMethodMeta(const std::string &methodName);
+  getMethodMeta(const std::string &methodName) const;

 protected:
   // If possible, models should not use the JS runtime to keep JSI internals
@@ -44,7 +46,8 @@ class BaseModel {

 private:
   std::size_t memorySizeLowerBound{0};
-  std::vector<int32_t> getTensorShape(const executorch::aten::Tensor &tensor);
+  std::vector<int32_t>
+  getTensorShape(const executorch::aten::Tensor &tensor) const;
 };
 } // namespace models
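
Editor's note (not part of the patch series): with the STFT removed from the C++ side, ASR::encode now wraps the caller's raw float buffer in a non-owning tensor and hands it straight to the encoder. The sketch below is a minimal, standalone illustration of the make_tensor_ptr overload the patch relies on (sizes + raw pointer + scalar type); the kSamplingRate constant and the synthetic sine-wave buffer are assumptions made for the example, not values taken from the library. Because this overload takes a mutable void*, the buffer has to be non-const and must outlive any forward() call that uses the tensor, which is presumably why the std::span<const float> parameters became std::span<float> in patch 1.

// Standalone sketch: wrap an existing audio buffer in a non-owning ExecuTorch
// tensor, mirroring what the new ASR::encode does before calling the encoder.
#include <cmath>
#include <cstdint>
#include <vector>

#include "executorch/extension/tensor/tensor_ptr.h"

int main() {
  constexpr int32_t kSamplingRate = 16000; // assumed value for the example
  // One second of a 440 Hz sine wave standing in for real audio samples.
  std::vector<float> waveform(kSamplingRate);
  for (int32_t i = 0; i < kSamplingRate; ++i) {
    waveform[i] = std::sin(2.0f * 3.1415926f * 440.0f *
                           static_cast<float>(i) / kSamplingRate);
  }

  // Shape is just {numSamples}; the tensor references waveform.data() and
  // neither copies nor owns it, so `waveform` must stay alive (and writable)
  // for as long as the tensor is used, e.g. across an encoder forward() call.
  std::vector<executorch::aten::SizesType> sizes = {
      static_cast<executorch::aten::SizesType>(waveform.size())};
  auto inputTensor = executorch::extension::make_tensor_ptr(
      std::move(sizes), waveform.data(),
      executorch::runtime::etensor::ScalarType::Float);

  return inputTensor->numel() == kSamplingRate ? 0 : 1;
}

The encoder-output tensor built in ASR::decode follows the same non-owning pattern (encoderOutput.data() with ScalarType::Float); the token tensor differs only in that the patch first copies the int32 tokens into an int64 vector so it can be passed as ScalarType::Long.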