• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# 使用MindSpore Lite进行端侧训练 (C/C++)
2
3<!--Kit: MindSpore Lite Kit-->
4<!--Subsystem: AI-->
5<!--Owner: @zhuguodong8-->
6<!--Designer: @zhuguodong8; @jjfeing-->
7<!--Tester: @principal87-->
8<!--Adviser: @ge-yafang-->
9
10## 场景介绍
11
12MindSpore Lite是一款AI引擎,它提供了面向不同硬件设备AI模型推理的功能,目前已经在图像分类、目标识别、人脸识别、文字识别等应用中广泛使用,同时支持在端侧设备上进行部署训练,让模型在实际业务场景中自适应用户的行为。
13
14本文介绍使用MindSpore Lite端侧AI引擎进行模型训练的通用开发流程。
15
16
17## 接口说明
18此处给出使用MindSpore Lite进行模型训练相关的部分接口,具体请见下方表格。更多接口及详细内容,请见[MindSpore](../../reference/apis-mindspore-lite-kit/capi-mindspore.md)。
19
20| 接口名称        | 描述        |
21| ------------------ | ----------------- |
22|OH_AI_ContextHandle OH_AI_ContextCreate()|创建一个上下文的对象。注意:此接口需跟OH_AI_ContextDestroy配套使用。|
23|OH_AI_DeviceInfoHandle OH_AI_DeviceInfoCreate(OH_AI_DeviceType device_type)|创建一个运行时设备信息对象。|
24|void OH_AI_ContextDestroy(OH_AI_ContextHandle *context)|释放上下文对象。|
25|void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHandle device_info)|添加运行时设备信息。|
26|OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate()|创建训练配置对象指针。|
27|void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg)|销毁训练配置对象指针。|
28|OH_AI_ModelHandle OH_AI_ModelCreate()|创建一个模型对象。|
29|OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context, const OH_AI_TrainCfgHandle train_cfg)|通过模型文件加载并编译MindSpore训练模型。|
30|OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after)|单步训练模型。|
31|OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train)|设置训练模式。|
32|OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file, OH_AI_QuantizationType quantization_type, bool export_inference_only, char **output_tensor_name, size_t num)|导出训练后的ms模型。|
33|void OH_AI_ModelDestroy(OH_AI_ModelHandle *model)|释放一个模型对象。|
34
35
36## 开发步骤
37使用MindSpore Lite进行模型训练的开发流程如下图所示。
38
39**图 1** 使用MindSpore Lite进行模型训练的开发流程
40![how-to-use-train](figures/train_sequence_unify_api.png)
41
42进入主要流程之前需要先引用相关的头文件,并编写函数生成随机的输入,具体如下:
43
44```c
45#include <stdlib.h>
46#include <stdio.h>
47#include <string.h>
48#include "mindspore/model.h"
49
50int GenerateInputDataWithRandom(OH_AI_TensorHandleArray inputs) {
51  for (size_t i = 0; i < inputs.handle_num; ++i) {
52    float *input_data = (float *)OH_AI_TensorGetMutableData(inputs.handle_list[i]);
53    if (input_data == NULL) {
54      printf("OH_AI_TensorGetMutableData failed.\n");
55      return  OH_AI_STATUS_LITE_ERROR;
56    }
57    int64_t num = OH_AI_TensorGetElementNum(inputs.handle_list[i]);
58    const int divisor = 10;
59    for (size_t j = 0; j < num; j++) {
60      input_data[j] = (float)(rand() % divisor) / divisor;  // 0--0.9f
61    }
62  }
63  return OH_AI_STATUS_SUCCESS;
64}
65```
66
67然后进入主要的开发步骤,包括模型的准备、读取、编译、训练、模型导出和释放,具体开发过程及细节请见下文的开发步骤及示例。
68
691. 模型准备。
70
71    准备的模型格式为`.ms`,本文以lenet_train.ms为例(此模型是提前准备的`ms`模型,本文相关效果仅以此模型文件为例)。开发者请自行准备所需的模型,可以按如下步骤操作:
72
73    - 首先基于MindSpore架构使用Python创建网络模型,并导出为`.mindir`文件,详细指南参考[这里](https://www.mindspore.cn/tutorials/zh-CN/r2.1/beginner/quick_start.html)。
74    - 然后将`.mindir`模型文件转换成`.ms`文件,转换操作步骤可以参考[训练模型转换](https://www.mindspore.cn/lite/docs/zh-CN/r2.1/use/converter_train.html),`.ms`文件可以导入端侧设备并基于MindSpore端侧框架进行训练。
75
762. 创建上下文,设置设备类型、训练配置等参数。
77
78    ```c
79    // Create and init context, add CPU device info
80    OH_AI_ContextHandle context = OH_AI_ContextCreate();
81    if (context == NULL) {
82        printf("OH_AI_ContextCreate failed.\n");
83        return OH_AI_STATUS_LITE_ERROR;
84    }
85
86    OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
87    if (cpu_device_info == NULL) {
88        printf("OH_AI_DeviceInfoCreate failed.\n");
89        OH_AI_ContextDestroy(&context);
90        return OH_AI_STATUS_LITE_ERROR;
91    }
92    OH_AI_ContextAddDeviceInfo(context, cpu_device_info);
93
94    // Create trainCfg
95    OH_AI_TrainCfgHandle trainCfg = OH_AI_TrainCfgCreate();
96    if (trainCfg == NULL) {
97        printf("OH_AI_TrainCfgCreate failed.\n");
98        OH_AI_ContextDestroy(&context);
99        return OH_AI_STATUS_LITE_ERROR;
100    }
101    ```
102
1033. 创建、加载与编译模型。
104
105    调用OH_AI_TrainModelBuildFromFile加载并编译模型。
106
107    ```c
108    // Create model
109    OH_AI_ModelHandle model = OH_AI_ModelCreate();
110    if (model == NULL) {
111        printf("OH_AI_ModelCreate failed.\n");
112        OH_AI_TrainCfgDestroy(&trainCfg);
113        OH_AI_ContextDestroy(&context);
114        return OH_AI_STATUS_LITE_ERROR;
115    }
116
117    // Build model
118    int ret = OH_AI_TrainModelBuildFromFile(model, model_file, OH_AI_MODELTYPE_MINDIR, context, trainCfg);
119    if (ret != OH_AI_STATUS_SUCCESS) {
120        printf("OH_AI_TrainModelBuildFromFile failed, ret: %d.\n", ret);
121        OH_AI_ModelDestroy(&model);
122        OH_AI_ContextDestroy(&context);
123        return ret;
124    }
125    ```
126
1274. 输入数据。
128
129    模型执行之前需要向输入的张量中填充数据。本例使用随机的数据对模型进行填充。
130
131    ```c
132    // Get Inputs
133    OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
134    if (inputs.handle_list == NULL) {
135        printf("OH_AI_ModelGetInputs failed, ret: %d.\n", ret);
136        OH_AI_ModelDestroy(&model);
137        OH_AI_ContextDestroy(&context);
138        return ret;
139    }
140
141    // Generate random data as input data.
142    ret = GenerateInputDataWithRandom(inputs);
143    if (ret != OH_AI_STATUS_SUCCESS) {
144        printf("GenerateInputDataWithRandom failed, ret: %d.\n", ret);
145        OH_AI_ModelDestroy(&model);
146        OH_AI_ContextDestroy(&context);
147        return ret;
148    }
149    ```
150
1515. 执行训练。
152
153    使用OH_AI_ModelSetTrainMode接口设置训练模式,使用OH_AI_RunStep接口进行模型训练。
154
155    ```c
156    // Set Train Mode
157    ret = OH_AI_ModelSetTrainMode(model, true);
158    if (ret != OH_AI_STATUS_SUCCESS) {
159        printf("OH_AI_ModelSetTrainMode failed, ret: %d.\n", ret);
160        OH_AI_ModelDestroy(&model);
161        OH_AI_ContextDestroy(&context);
162        return ret;
163    }
164
165    // Model Train Step
166    ret = OH_AI_RunStep(model, NULL, NULL);
167    if (ret != OH_AI_STATUS_SUCCESS) {
168        printf("OH_AI_RunStep failed, ret: %d.\n", ret);
169        OH_AI_ModelDestroy(&model);
170        OH_AI_ContextDestroy(&context);
171        return ret;
172    }
173    printf("Train Step Success.\n");
174    ```
175
1766. 导出训练后模型。
177
178    使用OH_AI_ExportModel接口导出训练后模型。
179
180    ```c
181    // Export Train Model
182    ret = OH_AI_ExportModel(model, OH_AI_MODELTYPE_MINDIR, export_train_model, OH_AI_NO_QUANT, false, NULL, 0);
183    if (ret != OH_AI_STATUS_SUCCESS) {
184        printf("OH_AI_ExportModel train failed, ret: %d.\n", ret);
185        OH_AI_ModelDestroy(&model);
186        OH_AI_ContextDestroy(&context);
187        return ret;
188    }
189    printf("Export Train Model Success.\n");
190
191    // Export Inference Model
192    ret = OH_AI_ExportModel(model, OH_AI_MODELTYPE_MINDIR, export_infer_model, OH_AI_NO_QUANT, true, NULL, 0);
193    if (ret != OH_AI_STATUS_SUCCESS) {
194        printf("OH_AI_ExportModel inference failed, ret: %d.\n", ret);
195        OH_AI_ModelDestroy(&model);
196        OH_AI_ContextDestroy(&context);
197        return ret;
198    }
199    printf("Export Inference Model Success.\n");
200    ```
201
2027. 释放模型。
203
204    不再使用MindSpore Lite推理框架时,需要释放已经创建的模型。
205
206    ```c
207    // Delete model and context.
208    OH_AI_ModelDestroy(&model);
209    OH_AI_ContextDestroy(&context);
210    ```
211
212
213## 调测验证
214
2151. 编写CMakeLists.txt。
216    ```cmake
217    cmake_minimum_required(VERSION 3.14)
218    project(TrainDemo)
219
220    add_executable(train_demo main.c)
221
222    target_link_libraries(
223            train_demo
224            mindspore_lite_ndk
225    )
226    ```
227
228   - 使用ohos-sdk交叉编译,需要对CMake设置native工具链路径,即:`-DCMAKE_TOOLCHAIN_FILE="/xxx/native/build/cmake/ohos.toolchain.cmake"`。
229
230   - 编译命令如下,其中OHOS_NDK需要设置为native工具链路径:
231      ```shell
232        mkdir -p build
233
234        cd ./build || exit
235        OHOS_NDK=""
236        cmake -G "Unix Makefiles" \
237              -S ../ \
238              -DCMAKE_TOOLCHAIN_FILE="$OHOS_NDK/build/cmake/ohos.toolchain.cmake" \
239              -DOHOS_ARCH=arm64-v8a \
240              -DCMAKE_BUILD_TYPE=Release
241
242        make
243      ```
244
2452. 运行编译的可执行程序。
246
247    - 使用hdc连接设备,并将train_demo和lenet_train.ms推送到设备中的相同目录。
248    - 使用hdc shell进入设备,并进入train_demo所在的目录执行如下命令,即可得到结果。
249
250    ```shell
251    ./train_demo ./lenet_train.ms export_train_model export_infer_model
252    ```
253
254    得到如下输出:
255
256    ```shell
257    Train Step Success.
258    Export Train Model Success.
259    Export Inference Model Success.
260    Tensor name: Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/BiasAdd-op121, tensor size is 80, elements num: 20.
261    output data is:
262    0.000265 0.000231 0.000254 0.000269 0.000238 0.000228
263    ```
264
265    在train_demo所在目录可以看到导出的两个模型文件:export_train_model.ms和export_infer_model.ms。
266
267
268## 完整示例
269
270```c
271#include <stdlib.h>
272#include <stdio.h>
273#include <string.h>
274#include "mindspore/model.h"
275
276int GenerateInputDataWithRandom(OH_AI_TensorHandleArray inputs) {
277  for (size_t i = 0; i < inputs.handle_num; ++i) {
278    float *input_data = (float *)OH_AI_TensorGetMutableData(inputs.handle_list[i]);
279    if (input_data == NULL) {
280      printf("OH_AI_TensorGetMutableData failed.\n");
281      return  OH_AI_STATUS_LITE_ERROR;
282    }
283    int64_t num = OH_AI_TensorGetElementNum(inputs.handle_list[i]);
284    const int divisor = 10;
285    for (size_t j = 0; j < num; j++) {
286      input_data[j] = (float)(rand() % divisor) / divisor;  // 0--0.9f
287    }
288  }
289  return OH_AI_STATUS_SUCCESS;
290}
291
292int ModelPredict(char* model_file) {
293  // Create and init context, add CPU device info
294  OH_AI_ContextHandle context = OH_AI_ContextCreate();
295  if (context == NULL) {
296    printf("OH_AI_ContextCreate failed.\n");
297    return OH_AI_STATUS_LITE_ERROR;
298  }
299
300  OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
301  if (cpu_device_info == NULL) {
302    printf("OH_AI_DeviceInfoCreate failed.\n");
303    OH_AI_ContextDestroy(&context);
304    return OH_AI_STATUS_LITE_ERROR;
305  }
306  OH_AI_ContextAddDeviceInfo(context, cpu_device_info);
307
308  // Create model
309  OH_AI_ModelHandle model = OH_AI_ModelCreate();
310  if (model == NULL) {
311    printf("OH_AI_ModelCreate failed.\n");
312    OH_AI_ContextDestroy(&context);
313    return OH_AI_STATUS_LITE_ERROR;
314  }
315
316  // Build model
317  int ret = OH_AI_ModelBuildFromFile(model, model_file, OH_AI_MODELTYPE_MINDIR, context);
318  if (ret != OH_AI_STATUS_SUCCESS) {
319    printf("OH_AI_ModelBuildFromFile failed, ret: %d.\n", ret);
320    OH_AI_ModelDestroy(&model);
321    OH_AI_ContextDestroy(&context);
322    return ret;
323  }
324
325  // Get Inputs
326  OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
327  if (inputs.handle_list == NULL) {
328    printf("OH_AI_ModelGetInputs failed, ret: %d.\n", ret);
329    OH_AI_ModelDestroy(&model);
330    OH_AI_ContextDestroy(&context);
331    return ret;
332  }
333
334  // Generate random data as input data.
335  ret = GenerateInputDataWithRandom(inputs);
336  if (ret != OH_AI_STATUS_SUCCESS) {
337    printf("GenerateInputDataWithRandom failed, ret: %d.\n", ret);
338    OH_AI_ModelDestroy(&model);
339    OH_AI_ContextDestroy(&context);
340    return ret;
341  }
342
343  // Model Predict
344  OH_AI_TensorHandleArray outputs;
345  ret = OH_AI_ModelPredict(model, inputs, &outputs, NULL, NULL);
346  if (ret != OH_AI_STATUS_SUCCESS) {
347    printf("MSModelPredict failed, ret: %d.\n", ret);
348    OH_AI_ModelDestroy(&model);
349    OH_AI_ContextDestroy(&context);
350    return ret;
351  }
352
353  // Print Output Tensor Data.
354  for (size_t i = 0; i < outputs.handle_num; ++i) {
355    OH_AI_TensorHandle tensor = outputs.handle_list[i];
356    int64_t element_num = OH_AI_TensorGetElementNum(tensor);
357    printf("Tensor name: %s, tensor size is %ld ,elements num: %ld.\n", OH_AI_TensorGetName(tensor),
358           OH_AI_TensorGetDataSize(tensor), element_num);
359    const float *data = (const float *)OH_AI_TensorGetData(tensor);
360    printf("output data is:\n");
361    const int max_print_num = 50;
362    for (int j = 0; j < element_num && j <= max_print_num; ++j) {
363      printf("%f ", data[j]);
364    }
365    printf("\n");
366  }
367
368  OH_AI_ModelDestroy(&model);
369  OH_AI_ContextDestroy(&context);
370  return OH_AI_STATUS_SUCCESS;
371}
372
373int TrainDemo(int argc, const char **argv) {
374  if (argc < 4) {
375    printf("Model file must be provided.\n");
376    printf("Export Train Model path must be provided.\n");
377    printf("Export Inference Model path must be provided.\n");
378    return OH_AI_STATUS_LITE_ERROR;
379  }
380  const char *model_file = argv[1];
381  const char *export_train_model = argv[2];
382  const char *export_infer_model = argv[3];
383
384  // Create and init context, add CPU device info
385  OH_AI_ContextHandle context = OH_AI_ContextCreate();
386  if (context == NULL) {
387    printf("OH_AI_ContextCreate failed.\n");
388    return OH_AI_STATUS_LITE_ERROR;
389  }
390
391  OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
392  if (cpu_device_info == NULL) {
393    printf("OH_AI_DeviceInfoCreate failed.\n");
394    OH_AI_ContextDestroy(&context);
395    return OH_AI_STATUS_LITE_ERROR;
396  }
397  OH_AI_ContextAddDeviceInfo(context, cpu_device_info);
398
399  // Create trainCfg
400  OH_AI_TrainCfgHandle trainCfg = OH_AI_TrainCfgCreate();
401  if (trainCfg == NULL) {
402    printf("OH_AI_TrainCfgCreate failed.\n");
403    OH_AI_ContextDestroy(&context);
404    return OH_AI_STATUS_LITE_ERROR;
405  }
406
407  // Create model
408  OH_AI_ModelHandle model = OH_AI_ModelCreate();
409  if (model == NULL) {
410    printf("OH_AI_ModelCreate failed.\n");
411    OH_AI_TrainCfgDestroy(&trainCfg);
412    OH_AI_ContextDestroy(&context);
413    return OH_AI_STATUS_LITE_ERROR;
414  }
415
416  // Build model
417  int ret = OH_AI_TrainModelBuildFromFile(model, model_file, OH_AI_MODELTYPE_MINDIR, context, trainCfg);
418  if (ret != OH_AI_STATUS_SUCCESS) {
419    printf("OH_AI_TrainModelBuildFromFile failed, ret: %d.\n", ret);
420    OH_AI_ModelDestroy(&model);
421    OH_AI_ContextDestroy(&context);
422    return ret;
423  }
424
425  // Get Inputs
426  OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
427  if (inputs.handle_list == NULL) {
428    printf("OH_AI_ModelGetInputs failed, ret: %d.\n", ret);
429    OH_AI_ModelDestroy(&model);
430    OH_AI_ContextDestroy(&context);
431    return ret;
432  }
433
434  // Generate random data as input data.
435  ret = GenerateInputDataWithRandom(inputs);
436  if (ret != OH_AI_STATUS_SUCCESS) {
437    printf("GenerateInputDataWithRandom failed, ret: %d.\n", ret);
438    OH_AI_ModelDestroy(&model);
439    OH_AI_ContextDestroy(&context);
440    return ret;
441  }
442
443  // Set Train Mode
444  ret = OH_AI_ModelSetTrainMode(model, true);
445  if (ret != OH_AI_STATUS_SUCCESS) {
446    printf("OH_AI_ModelSetTrainMode failed, ret: %d.\n", ret);
447    OH_AI_ModelDestroy(&model);
448    OH_AI_ContextDestroy(&context);
449    return ret;
450  }
451
452  // Model Train Step
453  ret = OH_AI_RunStep(model, NULL, NULL);
454  if (ret != OH_AI_STATUS_SUCCESS) {
455    printf("OH_AI_RunStep failed, ret: %d.\n", ret);
456    OH_AI_ModelDestroy(&model);
457    OH_AI_ContextDestroy(&context);
458    return ret;
459  }
460  printf("Train Step Success.\n");
461
462  // Export Train Model
463  ret = OH_AI_ExportModel(model, OH_AI_MODELTYPE_MINDIR, export_train_model, OH_AI_NO_QUANT, false, NULL, 0);
464  if (ret != OH_AI_STATUS_SUCCESS) {
465    printf("OH_AI_ExportModel train failed, ret: %d.\n", ret);
466    OH_AI_ModelDestroy(&model);
467    OH_AI_ContextDestroy(&context);
468    return ret;
469  }
470  printf("Export Train Model Success.\n");
471
472  // Export Inference Model
473  ret = OH_AI_ExportModel(model, OH_AI_MODELTYPE_MINDIR, export_infer_model, OH_AI_NO_QUANT, true, NULL, 0);
474  if (ret != OH_AI_STATUS_SUCCESS) {
475    printf("OH_AI_ExportModel inference failed, ret: %d.\n", ret);
476    OH_AI_ModelDestroy(&model);
477    OH_AI_ContextDestroy(&context);
478    return ret;
479  }
480  printf("Export Inference Model Success.\n");
481
482  // Delete model and context.
483  OH_AI_ModelDestroy(&model);
484  OH_AI_ContextDestroy(&context);
485
486  // Use The Exported Model to predict
487  char *exported_model = strcat(export_infer_model, ".ms");
488  ret = ModelPredict(exported_model);
489  if (ret != OH_AI_STATUS_SUCCESS) {
490    printf("Exported Model to predict failed, ret: %d.\n", ret);
491    return ret;
492  }
493  return OH_AI_STATUS_SUCCESS;
494}
495
496int main(int argc, const char **argv) { return TrainDemo(argc, argv); }
497
498```