1 /*
2 * Copyright (c) 2025 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "AudioFile.h"
#include "base64.h"
#include "napi/native_api.h"
#include "utils.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <hilog/log.h>
#include <iostream>
#include <librosa/librosa.h>
#include <limits>
#include <mindspore/context.h>
#include <mindspore/model.h>
#include <mindspore/status.h>
#include <mindspore/tensor.h>
#include <mindspore/types.h>
#include <numeric>
#include <rawfile/raw_file_manager.h>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
35
// Thin wrappers around OH_LOG_Print, one per severity, all tagged "[MSLiteNapi]".
#define LOGI(...) ((void)OH_LOG_Print(LOG_APP, LOG_INFO, LOG_DOMAIN, "[MSLiteNapi]", __VA_ARGS__))
#define LOGD(...) ((void)OH_LOG_Print(LOG_APP, LOG_DEBUG, LOG_DOMAIN, "[MSLiteNapi]", __VA_ARGS__))
#define LOGW(...) ((void)OH_LOG_Print(LOG_APP, LOG_WARN, LOG_DOMAIN, "[MSLiteNapi]", __VA_ARGS__))
#define LOGE(...) ((void)OH_LOG_Print(LOG_APP, LOG_ERROR, LOG_DOMAIN, "[MSLiteNapi]", __VA_ARGS__))

// Value written into a logit to suppress a token and into masked positions of the attention mask.
constexpr float NEG_INF = -std::numeric_limits<float>::infinity();

// Token ids used by the decoding code below.
// NOTE(review): values presumably match the vocabulary of the exported Whisper model -- confirm.
constexpr int WHISPER_SOT = 50258;
constexpr int WHISPER_TRANSCRIBE = 50359;
constexpr int WHISPER_TRANSLATE = 50358;
constexpr int WHISPER_NO_TIMESTAMPS = 50363;
constexpr int WHISPER_EOT = 50257;
constexpr int WHISPER_BLANK = 220;
constexpr int WHISPER_NO_SPEECH = 50362;

// Model dimensions used to size the mask and the positional-embedding slice.
constexpr int WHISPER_N_TEXT_CTX = 448;
constexpr int WHISPER_N_TEXT_STATE = 384; // for tiny

// Sample rate the input audio is resampled to.
constexpr int WHISPER_SAMPLE_RATE = 16000;
// Limit on the number of output values printed per tensor in debug logs.
constexpr int K_NUM_PRINT_OF_OUT_DATA = 20;

// A raw memory region: (pointer to first byte, size in bytes). Ownership depends on the producer.
using BinBuffer = std::pair<void *, size_t>;
55
FillInputTensor(OH_AI_TensorHandle input,const BinBuffer & bin)56 int FillInputTensor(OH_AI_TensorHandle input, const BinBuffer &bin)
57 {
58 if (OH_AI_TensorGetDataSize(input) != bin.second) {
59 return OH_AI_STATUS_LITE_INPUT_PARAM_INVALID;
60 }
61 char *data = (char *)OH_AI_TensorGetMutableData(input);
62 memcpy(data, (const char *)bin.first, OH_AI_TensorGetDataSize(input));
63 return OH_AI_STATUS_SUCCESS;
64 }
65
ReadTokens(NativeResourceManager * nativeResourceManager,const std::string & modelName)66 BinBuffer ReadTokens(NativeResourceManager *nativeResourceManager, const std::string &modelName) {
67 auto rawFile = OH_ResourceManager_OpenRawFile(nativeResourceManager, modelName.c_str());
68 if (rawFile == nullptr) {
69 LOGE("MS_LITE_ERR: Open model file failed");
70 }
71 long fileSize = OH_ResourceManager_GetRawFileSize(rawFile);
72 if (fileSize <= 0) {
73 LOGE("MS_LITE_ERR: FileSize not correct");
74 }
75 void *buffer = malloc(fileSize);
76 if (buffer == nullptr) {
77 LOGE("MS_LITE_ERR: OH_ResourceManager_ReadRawFile failed");
78 }
79 int ret = OH_ResourceManager_ReadRawFile(rawFile, buffer, fileSize);
80 if (ret == 0) {
81 LOGE("MS_LITE_LOG: OH_ResourceManager_ReadRawFile failed");
82 OH_ResourceManager_CloseRawFile(rawFile);
83 }
84 OH_ResourceManager_CloseRawFile(rawFile);
85 BinBuffer res(buffer, fileSize);
86 return res;
87 }
88
ReadBinFile(NativeResourceManager * nativeResourceManager,const std::string & modelName)89 BinBuffer ReadBinFile(NativeResourceManager *nativeResourceManager, const std::string &modelName)
90 {
91 auto rawFile = OH_ResourceManager_OpenRawFile(nativeResourceManager, modelName.c_str());
92 if (rawFile == nullptr) {
93 LOGE("MS_LITE_ERR: Open model file failed");
94 return BinBuffer(nullptr, 0);
95 }
96 long fileSize = OH_ResourceManager_GetRawFileSize(rawFile);
97 if (fileSize <= 0) {
98 LOGE("MS_LITE_ERR: FileSize not correct");
99 return BinBuffer(nullptr, 0);
100 }
101 void *buffer = malloc(fileSize);
102 if (buffer == nullptr) {
103 LOGE("MS_LITE_ERR: OH_ResourceManager_ReadRawFile failed");
104 return BinBuffer(nullptr, 0);
105 }
106 int ret = OH_ResourceManager_ReadRawFile(rawFile, buffer, fileSize);
107 if (ret == 0) {
108 LOGE("MS_LITE_LOG: OH_ResourceManager_ReadRawFile failed");
109 OH_ResourceManager_CloseRawFile(rawFile);
110 return BinBuffer(nullptr, 0);
111 }
112 OH_ResourceManager_CloseRawFile(rawFile);
113 return BinBuffer(buffer, fileSize);
114 }
115
/**
 * Frees the malloc'd buffer that `*buffer` points to and resets `*buffer` to nullptr.
 * A null `buffer` argument is ignored; calling it again on an already reset pointer is harmless.
 */
void DestroyModelBuffer(void **buffer)
{
    if (buffer != nullptr) {
        void *owned = *buffer;
        *buffer = nullptr;
        free(owned);
    }
}
124
CreateMSLiteModel(BinBuffer & bin)125 OH_AI_ModelHandle CreateMSLiteModel(BinBuffer &bin)
126 {
127 // Set executing context for model.
128 auto context = OH_AI_ContextCreate();
129 if (context == nullptr) {
130 DestroyModelBuffer(&bin.first);
131 LOGE("MS_LITE_ERR: Create MSLite context failed.\n");
132 return nullptr;
133 }
134 auto cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
135 OH_AI_DeviceInfoSetEnableFP16(cpu_device_info, false);
136 OH_AI_ContextAddDeviceInfo(context, cpu_device_info);
137
138 // Create model
139 auto model = OH_AI_ModelCreate();
140 if (model == nullptr) {
141 DestroyModelBuffer(&bin.first);
142 LOGE("MS_LITE_ERR: Allocate MSLite Model failed.\n");
143 return nullptr;
144 }
145
146 // Build model object
147 auto build_ret = OH_AI_ModelBuild(model, bin.first, bin.second, OH_AI_MODELTYPE_MINDIR, context);
148 DestroyModelBuffer(&bin.first);
149 if (build_ret != OH_AI_STATUS_SUCCESS) {
150 OH_AI_ModelDestroy(&model);
151 LOGE("MS_LITE_ERR: Build MSLite model failed.\n");
152 return nullptr;
153 }
154 LOGI("MS_LITE_LOG: Build MSLite model success.\n");
155 return model;
156 }
157
158 template <class T>
PrintBinAs(const BinBuffer & bin,const std::string & name="Vector",const size_t n_visible=10)159 void PrintBinAs(const BinBuffer &bin, const std::string &name = "Vector", const size_t n_visible = 10) {
160 size_t n_elem = bin.second / sizeof(T);
161 std::stringstream ss;
162 const T *data = reinterpret_cast<T *>(bin.first);
163 for (size_t i = 0; i < bin.second / sizeof(T) && i < n_visible; i++) {
164 ss << data[i] << " ";
165 }
166 LOGD("MS_LITE_LOG: bin name: %{public}s, n_elem: %{public}zu, data: [%{public}s]", name.c_str(), n_elem,
167 ss.str().c_str());
168 }
169
/**
 * Writes the raw bytes of a float vector to a file (truncating any existing content).
 *
 * @param data     Values to write; they are stored back to back in native float layout.
 * @param filename Path of the file to create.
 * @throws std::runtime_error when the file cannot be opened or the data cannot be written completely.
 */
void SaveToBinaryFile(const std::vector<float>& data, const std::string& filename) {
    // Open the file for binary output.
    std::ofstream outFile(filename, std::ios::binary);
    if (!outFile) {
        throw std::runtime_error("无法打开文件进行写入!");
    }
    // Write the payload.
    outFile.write(reinterpret_cast<const char*>(data.data()),
                  static_cast<std::streamsize>(data.size() * sizeof(float)));
    // Close (and thereby flush) before checking the stream state, so that late write errors
    // are reported instead of being followed by a success message.
    outFile.close();
    if (!outFile) {
        throw std::runtime_error("Failed to write data to file: " + filename);
    }
    std::cout << "数据已成功保存为二进制文件: " << filename << std::endl;
}
182
RunMSLiteModel(OH_AI_ModelHandle model,std::vector<BinBuffer> inputBins)183 int RunMSLiteModel(OH_AI_ModelHandle model, std::vector<BinBuffer> inputBins)
184 {
185 // Set input data for model.
186 auto inputs = OH_AI_ModelGetInputs(model);
187 for(int i = 0; i < inputBins.size(); i++)
188 {
189 auto ret = FillInputTensor(inputs.handle_list[i], inputBins[i]);
190 if (ret != OH_AI_STATUS_SUCCESS) {
191 LOGE("MS_LITE_ERR: set input %{public}d error.\n", i);
192 return OH_AI_STATUS_LITE_ERROR;
193 }
194 }
195
196 // Get model output.
197 auto outputs = OH_AI_ModelGetOutputs(model);
198
199 // Predict model.
200 auto predict_ret = OH_AI_ModelPredict(model, inputs, &outputs, nullptr, nullptr);
201 if (predict_ret != OH_AI_STATUS_SUCCESS) {
202 OH_AI_ModelDestroy(&model);
203 LOGE("MS_LITE_ERR: MSLite Predict error.\n");
204 return OH_AI_STATUS_LITE_ERROR;
205 }
206 LOGD("MS_LITE_LOG: Run MSLite model Predict success.\n");
207
208 // Print output tensor data.
209 LOGD("MS_LITE_LOG: Get model outputs:\n");
210 for (size_t i = 0; i < outputs.handle_num; i++) {
211 auto tensor = outputs.handle_list[i];
212 LOGD("MS_LITE_LOG: - Tensor %{public}d name is: %{public}s.\n", static_cast<int>(i),
213 OH_AI_TensorGetName(tensor));
214 LOGD("MS_LITE_LOG: - Tensor %{public}d size is: %{public}d.\n", static_cast<int>(i),
215 (int)OH_AI_TensorGetDataSize(tensor));
216 LOGD("MS_LITE_LOG: - Tensor data is:\n");
217 auto out_data = reinterpret_cast<const float *>(OH_AI_TensorGetData(tensor));
218 std::stringstream outStr;
219 for (int i = 0; (i < OH_AI_TensorGetElementNum(tensor)) && (i <= K_NUM_PRINT_OF_OUT_DATA); i++) {
220 outStr << out_data[i] << " ";
221 }
222 LOGD("MS_LITE_LOG: %{public}s", outStr.str().c_str());
223 }
224 return OH_AI_STATUS_SUCCESS;
225 }
226
/**
 * Reinterprets the bit pattern of every int as a float (element i of the result has the same
 * bytes as vec[i]); this is NOT a numeric int-to-float conversion.
 *
 * NOTE(review): the name suggests a numeric conversion, but the original implementation read the
 * ints through a float pointer. That bit-reinterpreting behaviour is preserved; only the
 * aliasing-violating pointer cast is replaced by memcpy.
 *
 * @param vec Source values.
 * @return Vector of the same length holding the reinterpreted values.
 */
std::vector<float> ConvertIntVectorToFloat(const std::vector<int>& vec) {
    static_assert(sizeof(float) == sizeof(int), "bitwise reinterpretation needs equally sized types");
    std::vector<float> floatVec(vec.size());
    if (!vec.empty()) {
        std::memcpy(floatVec.data(), vec.data(), vec.size() * sizeof(float));
    }
    return floatVec;
}
235
GetMSOutput(OH_AI_TensorHandle output)236 BinBuffer GetMSOutput(OH_AI_TensorHandle output) {
237 float *outputData = reinterpret_cast<float *>(OH_AI_TensorGetMutableData(output));
238 size_t size = OH_AI_TensorGetDataSize(output);
239 return {outputData, size};
240 }
241
GetVecOutput(OH_AI_TensorHandle output)242 std::vector<float> GetVecOutput(OH_AI_TensorHandle output){
243 float *outputData = reinterpret_cast<float *>(OH_AI_TensorGetMutableData(output));
244 size_t len = OH_AI_TensorGetElementNum(output);
245 std::vector<float> res(outputData, outputData + len);
246 return res;
247 }
248
SupressTokens(BinBuffer & logits,bool is_initial)249 void SupressTokens(BinBuffer &logits, bool is_initial) {
250 auto logits_data = static_cast<float *>(logits.first);
251 if (is_initial) {
252 // 假设这两个值在 logits 中的索引位置
253 logits_data[WHISPER_EOT] = NEG_INF;
254 logits_data[WHISPER_BLANK] = NEG_INF;
255 }
256
257 // 其他令牌的抑制
258 logits_data[WHISPER_NO_TIMESTAMPS] = NEG_INF;
259 logits_data[WHISPER_SOT] = NEG_INF;
260 logits_data[WHISPER_NO_SPEECH] = NEG_INF;
261 logits_data[WHISPER_TRANSLATE] = NEG_INF;
262 }
263
264 template <class T>
CompareVectorHelper(const T * data_a,const T * data_b,const std::string & label,size_t n,float rtol=1e-3,float atol=5e-3)265 void CompareVectorHelper(const T *data_a, const T *data_b, const std::string &label, size_t n, float rtol = 1e-3,
266 float atol = 5e-3) {
267 LOGD("MS_LITE_LOG: ==== 精度校验 ====");
268 LOGD("MS_LITE_LOG: 比较 %{public}s", label.c_str());
269
270 bool all_close = true;
271 float max_diff = 0.0f;
272
273 for (size_t i = 0; i < n; i++) {
274 const float diff = std::abs(data_a[i] - data_b[i]);
275 max_diff = std::max(max_diff, diff);
276
277 // 精度容差校验
278 if (diff > (atol + rtol * std::abs(data_b[i]))) {
279 all_close = false;
280 }
281 }
282
283 LOGD("MS_LITE_LOG: 最大差值: %{public}.6f", max_diff);
284 LOGD("MS_LITE_LOG: all_close = %{public}d", all_close);
285
286 if (!all_close) {
287 LOGD("MS_LITE_LOG: --- 数据不匹配详情 ---");
288
289 // 仅输出前5个差异元素
290 constexpr int MAX_SHOW = 30;
291 int show_count = 0;
292 for (size_t i = 0; i < n && show_count < MAX_SHOW; ++i) {
293 float diff = data_a[i] - data_b[i];
294 if (std::abs(diff) > (atol + rtol * std::abs(data_b[i])) && data_a[i] != 0.0f) {
295 LOGD("MS_LITE_LOG: 索引[%{public}zu]: %{public}.6f vs %{public}.6f (Δ=%{public}.6f)", i, data_a[i],
296 data_b[i], diff);
297 ++show_count;
298 }
299 }
300 LOGD("MS_LITE_LOG: === 数据不匹配,校验结束 ===");
301 }
302 return;
303 }
304
CompareFloatVector(const BinBuffer & a,const BinBuffer & b,const std::string & label,float rtol=1e-3,float atol=5e-3)305 void CompareFloatVector(const BinBuffer &a, const BinBuffer &b, const std::string &label, float rtol = 1e-3,
306 float atol = 5e-3) {
307 // 检查数据尺寸
308 assert(a.second == b.second);
309 const float *data_a = (const float *)a.first;
310 const float *data_b = (const float *)b.first;
311 CompareVectorHelper<float>(data_a, data_b, label, b.second / sizeof(float), rtol, atol);
312 }
313
CompareFloatVector(const std::vector<float> & fp_a,const BinBuffer & b,const std::string & label,float rtol=1e-3,float atol=5e-3)314 void CompareFloatVector(const std::vector<float> &fp_a, const BinBuffer &b, const std::string &label, float rtol = 1e-3,
315 float atol = 5e-3) {
316 // 检查数据尺寸
317 assert(fp_a.size() * sizeof(float) == b.second);
318
319 const float *data_a = (const float *)fp_a.data();
320 const float *data_b = (const float *)b.first;
321
322 CompareVectorHelper<float>(data_a, data_b, label, b.second / sizeof(float), rtol, atol);
323 }
324
CompareIntVector(const BinBuffer & a,const BinBuffer & b,const std::string & label,float rtol=1e-3,float atol=5e-3)325 void CompareIntVector(const BinBuffer &a, const BinBuffer &b, const std::string &label, float rtol = 1e-3,
326 float atol = 5e-3) {
327 // 检查数据尺寸
328 assert(a.second == b.second);
329
330 const int *data_a = (const int *)a.first;
331 const int *data_b = (const int *)b.first;
332
333 CompareVectorHelper<int>(data_a, data_b, label, b.second / sizeof(float), rtol, atol);
334 }
335
LoopPredict(const OH_AI_ModelHandle model,const BinBuffer & n_layer_cross_k,const BinBuffer & n_layer_cross_v,const BinBuffer & logits_init,BinBuffer & out_n_layer_self_k_cache,BinBuffer & out_n_layer_self_v_cache,const BinBuffer & data_embedding,const int loop,const int offset_init)336 std::vector<int> LoopPredict(const OH_AI_ModelHandle model, const BinBuffer &n_layer_cross_k,
337 const BinBuffer &n_layer_cross_v, const BinBuffer &logits_init,
338 BinBuffer &out_n_layer_self_k_cache, BinBuffer &out_n_layer_self_v_cache,
339 const BinBuffer &data_embedding, const int loop, const int offset_init) {
340 // logits
341 BinBuffer logits{nullptr, 51865 * sizeof(float)};
342 logits.first = malloc(logits.second);
343 if (!logits.first) {
344 LOGE("MS_LITE_LOG: Fail to malloc!\n");
345 }
346 void *logits_init_src = static_cast<char *>(logits_init.first) + 51865 * 3 * sizeof(float);
347 memcpy(logits.first, logits_init_src, logits.second);
348 SupressTokens(logits, true);
349
350 std::vector<int> output_token;
351 float *logits_data = static_cast<float *>(logits.first);
352 int max_token_id = 0;
353 float max_token = logits_data[0];
354 for (int i = 0; i < logits.second / sizeof(float); i++) {
355 if (logits_data[i] > max_token) {
356 max_token_id = i;
357 max_token = logits_data[i];
358 }
359 }
360
361 int offset = offset_init;
362 BinBuffer slice{nullptr, 0};
363 slice.second = WHISPER_N_TEXT_STATE * sizeof(float);
364 slice.first = malloc(slice.second);
365 if (!slice.first) {
366 LOGE("MS_LITE_LOG: Fail to malloc!\n");
367 }
368
369 auto out_n_layer_self_k_cache_new = out_n_layer_self_k_cache;
370 auto out_n_layer_self_v_cache_new = out_n_layer_self_v_cache;
371
372 for (size_t i = 0; i < loop; i++) {
373 if (max_token_id == WHISPER_EOT) {
374 break;
375 }
376 output_token.push_back(max_token_id);
377 std::vector<float> mask(WHISPER_N_TEXT_CTX, 0.0f);
378 for (size_t i = 0; i < WHISPER_N_TEXT_CTX - offset - 1; ++i) {
379 mask[i] = NEG_INF;
380 }
381 BinBuffer tokens{&max_token_id, sizeof(int)};
382
383 void *data_embedding_src =
384 static_cast<char *>(data_embedding.first) + offset * WHISPER_N_TEXT_STATE * sizeof(float);
385 memcpy(slice.first, data_embedding_src, slice.second);
386 // out_n_layer_self_k_cache
387 // out_n_layer_self_v_cache
388 // n_layer_cross_k
389 // n_layer_cross_v
390 // slice
391 // token
392 BinBuffer mask_bin(mask.data(), mask.size() * sizeof(float));
393 int ret = RunMSLiteModel(model, {tokens, out_n_layer_self_k_cache_new, out_n_layer_self_v_cache_new,
394 n_layer_cross_k, n_layer_cross_v, slice, mask_bin});
395
396 auto outputs = OH_AI_ModelGetOutputs(model);
397 logits = GetMSOutput(outputs.handle_list[0]);
398 out_n_layer_self_k_cache_new = GetMSOutput(outputs.handle_list[1]);
399 out_n_layer_self_v_cache_new = GetMSOutput(outputs.handle_list[2]);
400 offset++;
401 SupressTokens(logits, false);
402 logits_data = static_cast<float *>(logits.first);
403 max_token = logits_data[0];
404
405 for (int j = 0; j < logits.second / sizeof(float); j++) {
406 if (logits_data[j] > max_token) {
407 max_token_id = j;
408 max_token = logits_data[j];
409 }
410 }
411 LOGI("MS_LITE_LOG: run decoder loop %{public}d ok!\n token = %{public}d", i, max_token_id);
412 }
413
414 return output_token;
415 }
416
ProcessDataLines(const BinBuffer token_txt)417 std::vector<std::string> ProcessDataLines(const BinBuffer token_txt) {
418 void *data_ptr = token_txt.first;
419 size_t data_size = token_txt.second;
420 std::vector<std::string> tokens;
421
422 const char *char_data = static_cast<const char *>(data_ptr);
423 std::stringstream ss(std::string(char_data, char_data + data_size));
424 std::string line;
425 while (std::getline(ss, line)) {
426 size_t space_pos = line.find(' ');
427 tokens.push_back(line.substr(0, space_pos));
428 }
429 return tokens;
430 }
431
RunDemo(napi_env env,napi_callback_info info)432 static napi_value RunDemo(napi_env env, napi_callback_info info)
433 {
434 // run demo
435 napi_value error_ret;
436 napi_create_int32(env, -1, &error_ret);
437 size_t argc = 1;
438 napi_value argv[1] = {nullptr};
439 napi_get_cb_info(env, info, &argc, argv, nullptr, nullptr);
440 auto resourcesManager = OH_ResourceManager_InitNativeResourceManager(env, argv[0]);
441
442 // preprocess
443 AudioFile<float> audioFile;
444 std::string filePath = "zh.wav";
445 auto audioBin = ReadBinFile(resourcesManager, filePath);
446 if (audioBin.first == nullptr) {
447 LOGI("MS_LITE_LOG: Fail to read %{public}s!", filePath.c_str());
448 }
449 size_t dataSize = audioBin.second;
450 uint8_t *dataBuffer = (uint8_t *)audioBin.first;
451 bool ok = audioFile.loadFromMemory(std::vector<uint8_t>(dataBuffer, dataBuffer + dataSize));
452 if (!ok) {
453 LOGI("MS_LITE_LOG: Fail to read %{public}s!", filePath.c_str());
454 }
455 std::vector<float> data(audioFile.samples[0]);
456 ResampleAudio(data, audioFile.getSampleRate(), WHISPER_SAMPLE_RATE, 1, SRC_SINC_BEST_QUALITY);
457 std::vector<float> audio(data);
458
459 int padding = 480000;
460 int sr = 16000;
461 int n_fft = 480;
462 int n_hop = 160;
463 int n_mel = 80;
464 int fmin = 0; // Minimum frequency, default value is 0.0 Hz
465 int fmax =
466 sr /
467 2.0; // Maximum frequency, default value is half of the sampling rate (sr / 2.0), i.e., the Nyquist frequency.
468 audio.insert(audio.end(), padding, 0.0f);
469 std::vector<std::vector<float>> mels_T =
470 librosa::Feature::melspectrogram(audio, sr, n_fft, n_hop, "hann", true, "reflect", 2.f, n_mel, fmin, fmax);
471 std::cout << "mels: " << std::endl;
472
473 std::vector<std::vector<float>> mels = TransposeMel(mels_T);
474 ProcessMelSpectrogram(mels);
475
476 std::vector<float> inputMels(mels.size() * mels[0].size(), 0);
477 for (int i = 0; i < mels.size(); i++) {
478 std::copy(mels[i].begin(), mels[i].end(), inputMels.begin() + i * mels[0].size());
479 }
480
481 BinBuffer inputMelsBin(inputMels.data(), inputMels.size() * sizeof(float));
482
483 // --- encoder ---
484 auto encoderBin = ReadBinFile(resourcesManager, "tiny-encoder.ms");
485 if (encoderBin.first == nullptr) {
486 return error_ret;
487 }
488
489 auto encoder = CreateMSLiteModel(encoderBin);
490
491 int ret = RunMSLiteModel(encoder, {inputMelsBin});
492 if (ret != OH_AI_STATUS_SUCCESS) {
493 OH_AI_ModelDestroy(&encoder);
494 return error_ret;
495 }
496 LOGI("run encoder ok!\n");
497
498 auto outputs = OH_AI_ModelGetOutputs(encoder);
499 auto n_layer_cross_k = GetMSOutput(outputs.handle_list[0]);
500 auto n_layer_cross_v = GetMSOutput(outputs.handle_list[1]);
501
502 // --- decoder_main ---
503 std::vector<int> SOT_SEQUENCE = {WHISPER_SOT,
504 WHISPER_SOT + 1 + 1, // wait to modify
505 WHISPER_TRANSCRIBE, WHISPER_NO_TIMESTAMPS};
506 BinBuffer sotSequence(SOT_SEQUENCE.data(), SOT_SEQUENCE.size() * sizeof(int));
507
508 const std::string decoder_main_path = "tiny-decoder-main.ms";
509 auto decoderMainBin = ReadBinFile(resourcesManager, decoder_main_path);
510 if (decoderMainBin.first == nullptr) {
511 return error_ret;
512 }
513 auto decoder_main = CreateMSLiteModel(decoderMainBin);
514 int ret2 = RunMSLiteModel(decoder_main, {sotSequence, n_layer_cross_k, n_layer_cross_v});
515
516 if (ret2 != OH_AI_STATUS_SUCCESS) {
517 OH_AI_ModelDestroy(&decoder_main);
518 return error_ret;
519 }
520 LOGI("run decoder_main ok!\n");
521
522 auto decoderMainOut = OH_AI_ModelGetOutputs(decoder_main);
523 auto logitsBin = GetMSOutput(decoderMainOut.handle_list[0]);
524 auto out_n_layer_self_k_cache_Bin = GetMSOutput(decoderMainOut.handle_list[1]);
525 auto out_n_layer_self_v_cache_Bin = GetMSOutput(decoderMainOut.handle_list[2]);
526
527 // --- decoder_loop ---
528 const std::string modelName3 = "tiny-decoder-loop.ms";
529 auto modelBuffer3 = ReadBinFile(resourcesManager, modelName3);
530 auto decoder_loop = CreateMSLiteModel(modelBuffer3);
531
532 const std::string dataName_embedding = "tiny-positional_embedding.bin"; // read input data
533 auto data_embedding = ReadBinFile(resourcesManager, dataName_embedding);
534 if (data_embedding.first == nullptr) {
535 return error_ret;
536 }
537
538 int loop_times = WHISPER_N_TEXT_CTX - SOT_SEQUENCE.size();
539 int offset_init = SOT_SEQUENCE.size();
540 auto output_tokens =
541 LoopPredict(decoder_loop, n_layer_cross_k, n_layer_cross_v, logitsBin, out_n_layer_self_k_cache_Bin,
542 out_n_layer_self_v_cache_Bin, data_embedding, loop_times, offset_init);
543
544 std::vector<std::string> token_tables = ProcessDataLines(ReadTokens(resourcesManager, "tiny-tokens.txt"));
545 std::string result;
546 for (const auto i : output_tokens) {
547 char str[1024];
548 base64_decode((const uint8 *)token_tables[i].c_str(), (uint32)token_tables[i].size(), str);
549 result += str;
550 }
551 LOGI("MS_LITE_LOG: result is -> %{public}s", result.c_str());
552
553 OH_AI_ModelDestroy(&encoder);
554 OH_AI_ModelDestroy(&decoder_main);
555 OH_AI_ModelDestroy(&decoder_loop);
556
557 napi_value out_data;
558 napi_create_string_utf8(env, result.c_str(), result.length(), &out_data);
559 return out_data;
560 }
561
562 EXTERN_C_START
Init(napi_env env,napi_value exports)563 static napi_value Init(napi_env env, napi_value exports)
564 {
565 napi_property_descriptor desc[] = {{"runDemo", nullptr, RunDemo, nullptr, nullptr, nullptr, napi_default, nullptr}};
566 napi_define_properties(env, exports, sizeof(desc) / sizeof(desc[0]), desc);
567 return exports;
568 }
569 EXTERN_C_END
570
571 static napi_module demoModule = {
572 .nm_version = 1,
573 .nm_flags = 0,
574 .nm_filename = nullptr,
575 .nm_register_func = Init,
576 .nm_modname = "entry",
577 .nm_priv = ((void *)0),
578 .reserved = {0},
579 };
580
RegisterEntryModule(void)581 extern "C" __attribute__((constructor)) void RegisterEntryModule(void) { napi_module_register(&demoModule); }
582