• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2025 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "AudioFile.h"
#include "base64.h"
#include "napi/native_api.h"
#include "utils.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <hilog/log.h>
#include <iostream>
#include <librosa/librosa.h>
#include <limits>
#include <mindspore/context.h>
#include <mindspore/model.h>
#include <mindspore/status.h>
#include <mindspore/tensor.h>
#include <mindspore/types.h>
#include <numeric>
#include <rawfile/raw_file_manager.h>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
35 
// Logging shorthands. Each one forwards its format string and arguments to OH_LOG_Print under the
// "[MSLiteNapi]" tag at the severity named by the macro, and discards the call's return value.
#define MSLITE_LOG_AT(level, ...) \
    (static_cast<void>(OH_LOG_Print(LOG_APP, level, LOG_DOMAIN, "[MSLiteNapi]", __VA_ARGS__)))
#define LOGI(...) MSLITE_LOG_AT(LOG_INFO, __VA_ARGS__)
#define LOGD(...) MSLITE_LOG_AT(LOG_DEBUG, __VA_ARGS__)
#define LOGW(...) MSLITE_LOG_AT(LOG_WARN, __VA_ARGS__)
#define LOGE(...) MSLITE_LOG_AT(LOG_ERROR, __VA_ARGS__)
40 
// Value written into a logit or mask slot so that the slot can never win an arg-max.
constexpr float NEG_INF = -std::numeric_limits<float>::infinity();

// Token ids used by the decoding code in this file.
constexpr int WHISPER_EOT = 50257;
constexpr int WHISPER_SOT = 50258;
constexpr int WHISPER_TRANSLATE = 50358;
constexpr int WHISPER_TRANSCRIBE = 50359;
constexpr int WHISPER_NO_SPEECH = 50362;
constexpr int WHISPER_NO_TIMESTAMPS = 50363;
constexpr int WHISPER_BLANK = 220;

// Decoder dimensions.
constexpr int WHISPER_N_TEXT_CTX = 448;
constexpr int WHISPER_N_TEXT_STATE = 384; // for tiny

// Sample rate the audio is resampled to before feature extraction.
constexpr int WHISPER_SAMPLE_RATE = 16000;
// Limit used when dumping output tensor values to the log.
constexpr int K_NUM_PRINT_OF_OUT_DATA = 20;

// A (data pointer, size in bytes) pair describing a raw buffer; the pair itself carries no ownership.
using BinBuffer = std::pair<void *, size_t>;
55 
FillInputTensor(OH_AI_TensorHandle input,const BinBuffer & bin)56 int FillInputTensor(OH_AI_TensorHandle input, const BinBuffer &bin)
57 {
58     if (OH_AI_TensorGetDataSize(input) != bin.second) {
59         return OH_AI_STATUS_LITE_INPUT_PARAM_INVALID;
60     }
61     char *data = (char *)OH_AI_TensorGetMutableData(input);
62     memcpy(data, (const char *)bin.first, OH_AI_TensorGetDataSize(input));
63     return OH_AI_STATUS_SUCCESS;
64 }
65 
ReadTokens(NativeResourceManager * nativeResourceManager,const std::string & modelName)66 BinBuffer ReadTokens(NativeResourceManager *nativeResourceManager, const std::string &modelName) {
67     auto rawFile = OH_ResourceManager_OpenRawFile(nativeResourceManager, modelName.c_str());
68     if (rawFile == nullptr) {
69         LOGE("MS_LITE_ERR: Open model file failed");
70     }
71     long fileSize = OH_ResourceManager_GetRawFileSize(rawFile);
72     if (fileSize <= 0) {
73         LOGE("MS_LITE_ERR: FileSize not correct");
74     }
75     void *buffer = malloc(fileSize);
76     if (buffer == nullptr) {
77         LOGE("MS_LITE_ERR: OH_ResourceManager_ReadRawFile failed");
78     }
79     int ret = OH_ResourceManager_ReadRawFile(rawFile, buffer, fileSize);
80     if (ret == 0) {
81         LOGE("MS_LITE_LOG: OH_ResourceManager_ReadRawFile failed");
82         OH_ResourceManager_CloseRawFile(rawFile);
83     }
84     OH_ResourceManager_CloseRawFile(rawFile);
85     BinBuffer res(buffer, fileSize);
86     return res;
87 }
88 
ReadBinFile(NativeResourceManager * nativeResourceManager,const std::string & modelName)89 BinBuffer ReadBinFile(NativeResourceManager *nativeResourceManager, const std::string &modelName)
90 {
91     auto rawFile = OH_ResourceManager_OpenRawFile(nativeResourceManager, modelName.c_str());
92     if (rawFile == nullptr) {
93         LOGE("MS_LITE_ERR: Open model file failed");
94         return BinBuffer(nullptr, 0);
95     }
96     long fileSize = OH_ResourceManager_GetRawFileSize(rawFile);
97     if (fileSize <= 0) {
98         LOGE("MS_LITE_ERR: FileSize not correct");
99         return BinBuffer(nullptr, 0);
100     }
101     void *buffer = malloc(fileSize);
102     if (buffer == nullptr) {
103         LOGE("MS_LITE_ERR: OH_ResourceManager_ReadRawFile failed");
104         return BinBuffer(nullptr, 0);
105     }
106     int ret = OH_ResourceManager_ReadRawFile(rawFile, buffer, fileSize);
107     if (ret == 0) {
108         LOGE("MS_LITE_LOG: OH_ResourceManager_ReadRawFile failed");
109         OH_ResourceManager_CloseRawFile(rawFile);
110         return BinBuffer(nullptr, 0);
111     }
112     OH_ResourceManager_CloseRawFile(rawFile);
113     return BinBuffer(buffer, fileSize);
114 }
115 
// Frees the malloc'ed buffer that *buffer points to and resets *buffer to nullptr.
// A null `buffer` argument is ignored; an already-null *buffer is harmless because free(nullptr) is a no-op.
void DestroyModelBuffer(void **buffer)
{
    if (buffer != nullptr) {
        free(*buffer);
        *buffer = nullptr;
    }
}
124 
CreateMSLiteModel(BinBuffer & bin)125 OH_AI_ModelHandle CreateMSLiteModel(BinBuffer &bin)
126 {
127     // Set executing context for model.
128     auto context = OH_AI_ContextCreate();
129     if (context == nullptr) {
130         DestroyModelBuffer(&bin.first);
131         LOGE("MS_LITE_ERR: Create MSLite context failed.\n");
132         return nullptr;
133     }
134     auto cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
135     OH_AI_DeviceInfoSetEnableFP16(cpu_device_info, false);
136     OH_AI_ContextAddDeviceInfo(context, cpu_device_info);
137 
138     // Create model
139     auto model = OH_AI_ModelCreate();
140     if (model == nullptr) {
141         DestroyModelBuffer(&bin.first);
142         LOGE("MS_LITE_ERR: Allocate MSLite Model failed.\n");
143         return nullptr;
144     }
145 
146     // Build model object
147     auto build_ret = OH_AI_ModelBuild(model, bin.first, bin.second, OH_AI_MODELTYPE_MINDIR, context);
148     DestroyModelBuffer(&bin.first);
149     if (build_ret != OH_AI_STATUS_SUCCESS) {
150         OH_AI_ModelDestroy(&model);
151         LOGE("MS_LITE_ERR: Build MSLite model failed.\n");
152         return nullptr;
153     }
154     LOGI("MS_LITE_LOG: Build MSLite model success.\n");
155     return model;
156 }
157 
158 template <class T>
PrintBinAs(const BinBuffer & bin,const std::string & name="Vector",const size_t n_visible=10)159 void PrintBinAs(const BinBuffer &bin, const std::string &name = "Vector", const size_t n_visible = 10) {
160     size_t n_elem = bin.second / sizeof(T);
161     std::stringstream ss;
162     const T *data = reinterpret_cast<T *>(bin.first);
163     for (size_t i = 0; i < bin.second / sizeof(T) && i < n_visible; i++) {
164         ss << data[i] << " ";
165     }
166     LOGD("MS_LITE_LOG: bin name: %{public}s, n_elem: %{public}zu, data: [%{public}s]", name.c_str(), n_elem,
167          ss.str().c_str());
168     }
169 
// Writes the floats in `data` to `filename` as raw bytes (no header, native float layout), replacing any
// existing file content, then prints a confirmation line to stdout.
//
// Throws std::runtime_error when the file cannot be opened for writing.
void SaveToBinaryFile(const std::vector<float>& data, const std::string& filename) {
    std::ofstream sink(filename, std::ios::binary);
    if (sink.fail()) {
        throw std::runtime_error("无法打开文件进行写入!");
    }

    const char *rawBytes = reinterpret_cast<const char*>(data.data());
    const auto byteCount = static_cast<std::streamsize>(data.size() * sizeof(float));
    sink.write(rawBytes, byteCount);
    sink.close();

    std::cout << "数据已成功保存为二进制文件: " << filename << std::endl;
}
182 
RunMSLiteModel(OH_AI_ModelHandle model,std::vector<BinBuffer> inputBins)183 int RunMSLiteModel(OH_AI_ModelHandle model, std::vector<BinBuffer> inputBins)
184 {
185     // Set input data for model.
186     auto inputs = OH_AI_ModelGetInputs(model);
187     for(int i = 0; i < inputBins.size(); i++)
188     {
189         auto ret = FillInputTensor(inputs.handle_list[i], inputBins[i]);
190         if (ret != OH_AI_STATUS_SUCCESS) {
191             LOGE("MS_LITE_ERR: set input %{public}d error.\n", i);
192             return OH_AI_STATUS_LITE_ERROR;
193         }
194     }
195 
196     // Get model output.
197     auto outputs = OH_AI_ModelGetOutputs(model);
198 
199     // Predict model.
200     auto predict_ret = OH_AI_ModelPredict(model, inputs, &outputs, nullptr, nullptr);
201     if (predict_ret != OH_AI_STATUS_SUCCESS) {
202         OH_AI_ModelDestroy(&model);
203         LOGE("MS_LITE_ERR: MSLite Predict error.\n");
204         return OH_AI_STATUS_LITE_ERROR;
205     }
206     LOGD("MS_LITE_LOG: Run MSLite model Predict success.\n");
207 
208     // Print output tensor data.
209     LOGD("MS_LITE_LOG: Get model outputs:\n");
210     for (size_t i = 0; i < outputs.handle_num; i++) {
211         auto tensor = outputs.handle_list[i];
212         LOGD("MS_LITE_LOG: - Tensor %{public}d name is: %{public}s.\n", static_cast<int>(i),
213              OH_AI_TensorGetName(tensor));
214         LOGD("MS_LITE_LOG: - Tensor %{public}d size is: %{public}d.\n", static_cast<int>(i),
215              (int)OH_AI_TensorGetDataSize(tensor));
216         LOGD("MS_LITE_LOG: - Tensor data is:\n");
217         auto out_data = reinterpret_cast<const float *>(OH_AI_TensorGetData(tensor));
218         std::stringstream outStr;
219         for (int i = 0; (i < OH_AI_TensorGetElementNum(tensor)) && (i <= K_NUM_PRINT_OF_OUT_DATA); i++) {
220             outStr << out_data[i] << " ";
221         }
222         LOGD("MS_LITE_LOG: %{public}s", outStr.str().c_str());
223     }
224     return OH_AI_STATUS_SUCCESS;
225 }
226 
// Reinterprets the bit pattern of every int in `vec` as a float (no numeric conversion takes place).
//
// The bytes are copied with memcpy because reading int storage through a float pointer violates the
// strict-aliasing rules and is undefined behaviour.
//
// Returns a vector of the same length whose i-th element has the same bytes as vec[i].
std::vector<float> ConvertIntVectorToFloat(const std::vector<int>& vec) {
    static_assert(sizeof(float) == sizeof(int), "bitwise reinterpretation needs int and float of equal size");
    std::vector<float> floatVec(vec.size());
    if (!vec.empty()) {
        std::memcpy(floatVec.data(), vec.data(), vec.size() * sizeof(float));
    }
    return floatVec;
}
235 
GetMSOutput(OH_AI_TensorHandle output)236 BinBuffer GetMSOutput(OH_AI_TensorHandle output) {
237     float *outputData = reinterpret_cast<float *>(OH_AI_TensorGetMutableData(output));
238     size_t size = OH_AI_TensorGetDataSize(output);
239     return {outputData, size};
240 }
241 
GetVecOutput(OH_AI_TensorHandle output)242 std::vector<float> GetVecOutput(OH_AI_TensorHandle output){
243     float *outputData = reinterpret_cast<float *>(OH_AI_TensorGetMutableData(output));
244     size_t len = OH_AI_TensorGetElementNum(output);
245     std::vector<float> res(outputData, outputData + len);
246     return res;
247 }
248 
SupressTokens(BinBuffer & logits,bool is_initial)249 void SupressTokens(BinBuffer &logits, bool is_initial) {
250     auto logits_data = static_cast<float *>(logits.first);
251     if (is_initial) {
252         // 假设这两个值在 logits 中的索引位置
253         logits_data[WHISPER_EOT] = NEG_INF;
254         logits_data[WHISPER_BLANK] = NEG_INF;
255     }
256 
257     // 其他令牌的抑制
258     logits_data[WHISPER_NO_TIMESTAMPS] = NEG_INF;
259     logits_data[WHISPER_SOT] = NEG_INF;
260     logits_data[WHISPER_NO_SPEECH] = NEG_INF;
261     logits_data[WHISPER_TRANSLATE] = NEG_INF;
262 }
263 
264 template <class T>
CompareVectorHelper(const T * data_a,const T * data_b,const std::string & label,size_t n,float rtol=1e-3,float atol=5e-3)265 void CompareVectorHelper(const T *data_a, const T *data_b, const std::string &label, size_t n, float rtol = 1e-3,
266                          float atol = 5e-3) {
267     LOGD("MS_LITE_LOG: ==== 精度校验 ====");
268     LOGD("MS_LITE_LOG: 比较 %{public}s", label.c_str());
269 
270     bool all_close = true;
271     float max_diff = 0.0f;
272 
273     for (size_t i = 0; i < n; i++) {
274         const float diff = std::abs(data_a[i] - data_b[i]);
275         max_diff = std::max(max_diff, diff);
276 
277         // 精度容差校验
278         if (diff > (atol + rtol * std::abs(data_b[i]))) {
279             all_close = false;
280         }
281     }
282 
283     LOGD("MS_LITE_LOG: 最大差值: %{public}.6f", max_diff);
284     LOGD("MS_LITE_LOG: all_close = %{public}d", all_close);
285 
286     if (!all_close) {
287         LOGD("MS_LITE_LOG: --- 数据不匹配详情 ---");
288 
289         // 仅输出前5个差异元素
290         constexpr int MAX_SHOW = 30;
291         int show_count = 0;
292         for (size_t i = 0; i < n && show_count < MAX_SHOW; ++i) {
293             float diff = data_a[i] - data_b[i];
294             if (std::abs(diff) > (atol + rtol * std::abs(data_b[i])) && data_a[i] != 0.0f) {
295                 LOGD("MS_LITE_LOG: 索引[%{public}zu]: %{public}.6f vs %{public}.6f (Δ=%{public}.6f)", i, data_a[i],
296                      data_b[i], diff);
297                 ++show_count;
298             }
299         }
300         LOGD("MS_LITE_LOG: === 数据不匹配,校验结束 ===");
301     }
302     return;
303 }
304 
CompareFloatVector(const BinBuffer & a,const BinBuffer & b,const std::string & label,float rtol=1e-3,float atol=5e-3)305 void CompareFloatVector(const BinBuffer &a, const BinBuffer &b, const std::string &label, float rtol = 1e-3,
306                         float atol = 5e-3) {
307     // 检查数据尺寸
308     assert(a.second == b.second);
309     const float *data_a = (const float *)a.first;
310     const float *data_b = (const float *)b.first;
311     CompareVectorHelper<float>(data_a, data_b, label, b.second / sizeof(float), rtol, atol);
312 }
313 
CompareFloatVector(const std::vector<float> & fp_a,const BinBuffer & b,const std::string & label,float rtol=1e-3,float atol=5e-3)314 void CompareFloatVector(const std::vector<float> &fp_a, const BinBuffer &b, const std::string &label, float rtol = 1e-3,
315                         float atol = 5e-3) {
316     // 检查数据尺寸
317     assert(fp_a.size() * sizeof(float) == b.second);
318 
319     const float *data_a = (const float *)fp_a.data();
320     const float *data_b = (const float *)b.first;
321 
322     CompareVectorHelper<float>(data_a, data_b, label, b.second / sizeof(float), rtol, atol);
323 }
324 
CompareIntVector(const BinBuffer & a,const BinBuffer & b,const std::string & label,float rtol=1e-3,float atol=5e-3)325 void CompareIntVector(const BinBuffer &a, const BinBuffer &b, const std::string &label, float rtol = 1e-3,
326                       float atol = 5e-3) {
327     // 检查数据尺寸
328     assert(a.second == b.second);
329 
330     const int *data_a = (const int *)a.first;
331     const int *data_b = (const int *)b.first;
332 
333     CompareVectorHelper<int>(data_a, data_b, label, b.second / sizeof(float), rtol, atol);
334 }
335 
LoopPredict(const OH_AI_ModelHandle model,const BinBuffer & n_layer_cross_k,const BinBuffer & n_layer_cross_v,const BinBuffer & logits_init,BinBuffer & out_n_layer_self_k_cache,BinBuffer & out_n_layer_self_v_cache,const BinBuffer & data_embedding,const int loop,const int offset_init)336 std::vector<int> LoopPredict(const OH_AI_ModelHandle model, const BinBuffer &n_layer_cross_k,
337                              const BinBuffer &n_layer_cross_v, const BinBuffer &logits_init,
338                              BinBuffer &out_n_layer_self_k_cache, BinBuffer &out_n_layer_self_v_cache,
339                              const BinBuffer &data_embedding, const int loop, const int offset_init) {
340     // logits
341     BinBuffer logits{nullptr, 51865 * sizeof(float)};
342     logits.first = malloc(logits.second);
343     if (!logits.first) {
344         LOGE("MS_LITE_LOG: Fail to malloc!\n");
345     }
346     void *logits_init_src = static_cast<char *>(logits_init.first) + 51865 * 3 * sizeof(float);
347     memcpy(logits.first, logits_init_src, logits.second);
348     SupressTokens(logits, true);
349 
350     std::vector<int> output_token;
351     float *logits_data = static_cast<float *>(logits.first);
352     int max_token_id = 0;
353     float max_token = logits_data[0];
354     for (int i = 0; i < logits.second / sizeof(float); i++) {
355         if (logits_data[i] > max_token) {
356             max_token_id = i;
357             max_token = logits_data[i];
358         }
359     }
360 
361     int offset = offset_init;
362     BinBuffer slice{nullptr, 0};
363     slice.second = WHISPER_N_TEXT_STATE * sizeof(float);
364     slice.first = malloc(slice.second);
365     if (!slice.first) {
366         LOGE("MS_LITE_LOG: Fail to malloc!\n");
367     }
368 
369     auto out_n_layer_self_k_cache_new = out_n_layer_self_k_cache;
370     auto out_n_layer_self_v_cache_new = out_n_layer_self_v_cache;
371 
372     for (size_t i = 0; i < loop; i++) {
373         if (max_token_id == WHISPER_EOT) {
374             break;
375         }
376         output_token.push_back(max_token_id);
377         std::vector<float> mask(WHISPER_N_TEXT_CTX, 0.0f);
378         for (size_t i = 0; i < WHISPER_N_TEXT_CTX - offset - 1; ++i) {
379             mask[i] = NEG_INF;
380         }
381         BinBuffer tokens{&max_token_id, sizeof(int)};
382 
383         void *data_embedding_src =
384             static_cast<char *>(data_embedding.first) + offset * WHISPER_N_TEXT_STATE * sizeof(float);
385         memcpy(slice.first, data_embedding_src, slice.second);
386         // out_n_layer_self_k_cache
387         // out_n_layer_self_v_cache
388         // n_layer_cross_k
389         // n_layer_cross_v
390         // slice
391         // token
392         BinBuffer mask_bin(mask.data(), mask.size() * sizeof(float));
393         int ret = RunMSLiteModel(model, {tokens, out_n_layer_self_k_cache_new, out_n_layer_self_v_cache_new,
394                                          n_layer_cross_k, n_layer_cross_v, slice, mask_bin});
395 
396         auto outputs = OH_AI_ModelGetOutputs(model);
397         logits = GetMSOutput(outputs.handle_list[0]);
398         out_n_layer_self_k_cache_new = GetMSOutput(outputs.handle_list[1]);
399         out_n_layer_self_v_cache_new = GetMSOutput(outputs.handle_list[2]);
400         offset++;
401         SupressTokens(logits, false);
402         logits_data = static_cast<float *>(logits.first);
403         max_token = logits_data[0];
404 
405         for (int j = 0; j < logits.second / sizeof(float); j++) {
406             if (logits_data[j] > max_token) {
407                 max_token_id = j;
408                 max_token = logits_data[j];
409             }
410         }
411         LOGI("MS_LITE_LOG: run decoder loop %{public}d ok!\n token = %{public}d", i, max_token_id);
412     }
413 
414     return output_token;
415 }
416 
ProcessDataLines(const BinBuffer token_txt)417 std::vector<std::string> ProcessDataLines(const BinBuffer token_txt) {
418     void *data_ptr = token_txt.first;
419     size_t data_size = token_txt.second;
420     std::vector<std::string> tokens;
421 
422     const char *char_data = static_cast<const char *>(data_ptr);
423     std::stringstream ss(std::string(char_data, char_data + data_size));
424     std::string line;
425     while (std::getline(ss, line)) {
426         size_t space_pos = line.find(' ');
427         tokens.push_back(line.substr(0, space_pos));
428     }
429     return tokens;
430 }
431 
RunDemo(napi_env env,napi_callback_info info)432 static napi_value RunDemo(napi_env env, napi_callback_info info)
433 {
434     // run demo
435     napi_value error_ret;
436     napi_create_int32(env, -1, &error_ret);
437     size_t argc = 1;
438     napi_value argv[1] = {nullptr};
439     napi_get_cb_info(env, info, &argc, argv, nullptr, nullptr);
440     auto resourcesManager = OH_ResourceManager_InitNativeResourceManager(env, argv[0]);
441 
442     // preprocess
443     AudioFile<float> audioFile;
444     std::string filePath = "zh.wav";
445     auto audioBin = ReadBinFile(resourcesManager, filePath);
446     if (audioBin.first == nullptr) {
447         LOGI("MS_LITE_LOG: Fail to read  %{public}s!", filePath.c_str());
448     }
449     size_t dataSize = audioBin.second;
450     uint8_t *dataBuffer = (uint8_t *)audioBin.first;
451     bool ok = audioFile.loadFromMemory(std::vector<uint8_t>(dataBuffer, dataBuffer + dataSize));
452     if (!ok) {
453         LOGI("MS_LITE_LOG: Fail to read  %{public}s!", filePath.c_str());
454     }
455     std::vector<float> data(audioFile.samples[0]);
456     ResampleAudio(data, audioFile.getSampleRate(), WHISPER_SAMPLE_RATE, 1, SRC_SINC_BEST_QUALITY);
457     std::vector<float> audio(data);
458 
459     int padding = 480000;
460     int sr = 16000;
461     int n_fft = 480;
462     int n_hop = 160;
463     int n_mel = 80;
464     int fmin = 0; // Minimum frequency, default value is 0.0 Hz
465     int fmax =
466         sr /
467         2.0; // Maximum frequency, default value is half of the sampling rate (sr / 2.0), i.e., the Nyquist frequency.
468     audio.insert(audio.end(), padding, 0.0f);
469     std::vector<std::vector<float>> mels_T =
470         librosa::Feature::melspectrogram(audio, sr, n_fft, n_hop, "hann", true, "reflect", 2.f, n_mel, fmin, fmax);
471     std::cout << "mels:   " << std::endl;
472 
473     std::vector<std::vector<float>> mels = TransposeMel(mels_T);
474     ProcessMelSpectrogram(mels);
475 
476     std::vector<float> inputMels(mels.size() * mels[0].size(), 0);
477     for (int i = 0; i < mels.size(); i++) {
478         std::copy(mels[i].begin(), mels[i].end(), inputMels.begin() + i * mels[0].size());
479     }
480 
481     BinBuffer inputMelsBin(inputMels.data(), inputMels.size() * sizeof(float));
482 
483     // --- encoder ---
484     auto encoderBin = ReadBinFile(resourcesManager, "tiny-encoder.ms");
485     if (encoderBin.first == nullptr) {
486         return error_ret;
487     }
488 
489     auto encoder = CreateMSLiteModel(encoderBin);
490 
491     int ret = RunMSLiteModel(encoder, {inputMelsBin});
492     if (ret != OH_AI_STATUS_SUCCESS) {
493         OH_AI_ModelDestroy(&encoder);
494         return error_ret;
495     }
496     LOGI("run encoder ok!\n");
497 
498     auto outputs = OH_AI_ModelGetOutputs(encoder);
499     auto n_layer_cross_k = GetMSOutput(outputs.handle_list[0]);
500     auto n_layer_cross_v = GetMSOutput(outputs.handle_list[1]);
501 
502     // --- decoder_main ---
503     std::vector<int> SOT_SEQUENCE = {WHISPER_SOT,
504                                      WHISPER_SOT + 1 + 1, // wait to modify
505                                      WHISPER_TRANSCRIBE, WHISPER_NO_TIMESTAMPS};
506     BinBuffer sotSequence(SOT_SEQUENCE.data(), SOT_SEQUENCE.size() * sizeof(int));
507 
508     const std::string decoder_main_path = "tiny-decoder-main.ms";
509     auto decoderMainBin = ReadBinFile(resourcesManager, decoder_main_path);
510     if (decoderMainBin.first == nullptr) {
511         return error_ret;
512     }
513     auto decoder_main = CreateMSLiteModel(decoderMainBin);
514     int ret2 = RunMSLiteModel(decoder_main, {sotSequence, n_layer_cross_k, n_layer_cross_v});
515 
516     if (ret2 != OH_AI_STATUS_SUCCESS) {
517         OH_AI_ModelDestroy(&decoder_main);
518         return error_ret;
519     }
520     LOGI("run decoder_main ok!\n");
521 
522     auto decoderMainOut = OH_AI_ModelGetOutputs(decoder_main);
523     auto logitsBin = GetMSOutput(decoderMainOut.handle_list[0]);
524     auto out_n_layer_self_k_cache_Bin = GetMSOutput(decoderMainOut.handle_list[1]);
525     auto out_n_layer_self_v_cache_Bin = GetMSOutput(decoderMainOut.handle_list[2]);
526 
527     // --- decoder_loop ---
528     const std::string modelName3 = "tiny-decoder-loop.ms";
529     auto modelBuffer3 = ReadBinFile(resourcesManager, modelName3);
530     auto decoder_loop = CreateMSLiteModel(modelBuffer3);
531 
532     const std::string dataName_embedding = "tiny-positional_embedding.bin"; // read input data
533     auto data_embedding = ReadBinFile(resourcesManager, dataName_embedding);
534     if (data_embedding.first == nullptr) {
535         return error_ret;
536     }
537 
538     int loop_times = WHISPER_N_TEXT_CTX - SOT_SEQUENCE.size();
539     int offset_init = SOT_SEQUENCE.size();
540     auto output_tokens =
541         LoopPredict(decoder_loop, n_layer_cross_k, n_layer_cross_v, logitsBin, out_n_layer_self_k_cache_Bin,
542                     out_n_layer_self_v_cache_Bin, data_embedding, loop_times, offset_init);
543 
544     std::vector<std::string> token_tables = ProcessDataLines(ReadTokens(resourcesManager, "tiny-tokens.txt"));
545     std::string result;
546     for (const auto i : output_tokens) {
547         char str[1024];
548         base64_decode((const uint8 *)token_tables[i].c_str(), (uint32)token_tables[i].size(), str);
549         result += str;
550     }
551     LOGI("MS_LITE_LOG: result is -> %{public}s", result.c_str());
552 
553     OH_AI_ModelDestroy(&encoder);
554     OH_AI_ModelDestroy(&decoder_main);
555     OH_AI_ModelDestroy(&decoder_loop);
556 
557     napi_value out_data;
558     napi_create_string_utf8(env, result.c_str(), result.length(), &out_data);
559     return out_data;
560 }
561 
562 EXTERN_C_START
Init(napi_env env,napi_value exports)563 static napi_value Init(napi_env env, napi_value exports)
564 {
565     napi_property_descriptor desc[] = {{"runDemo", nullptr, RunDemo, nullptr, nullptr, nullptr, napi_default, nullptr}};
566     napi_define_properties(env, exports, sizeof(desc) / sizeof(desc[0]), desc);
567     return exports;
568 }
569 EXTERN_C_END
570 
571 static napi_module demoModule = {
572     .nm_version = 1,
573     .nm_flags = 0,
574     .nm_filename = nullptr,
575     .nm_register_func = Init,
576     .nm_modname = "entry",
577     .nm_priv = ((void *)0),
578     .reserved = {0},
579 };
580 
RegisterEntryModule(void)581 extern "C" __attribute__((constructor)) void RegisterEntryModule(void) { napi_module_register(&demoModule); }
582