• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "model_utils.h"

#include <cstdio>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>

#include <securec.h>
#include "gtest/gtest.h"
#include "include/c_api/context_c.h"
#include "include/c_api/model_c.h"
#include "include/c_api/types_c.h"
#include "include/c_api/status_c.h"
#include "include/c_api/data_type_c.h"
#include "include/c_api/tensor_c.h"
#include "include/c_api/format_c.h"
#include "common.h"
27 
// Root directory of the test resources: model files (<name>.ms), input files and reference
// output files are all resolved relative to this path.
std::string g_testResourcesDir{"/data/test/resource/"};
29 
30 // function before callback
PrintBeforeCallback(const OH_AI_TensorHandleArray inputs,const OH_AI_TensorHandleArray outputs,const OH_AI_CallBackParam kernelInfo)31 bool PrintBeforeCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
32                          const OH_AI_CallBackParam kernelInfo) {
33     std::cout << "Before forwarding " << kernelInfo.node_name << " " << kernelInfo.node_type << std::endl;
34     return true;
35 }
36 
37 // function after callback
PrintAfterCallback(const OH_AI_TensorHandleArray inputs,const OH_AI_TensorHandleArray outputs,const OH_AI_CallBackParam kernelInfo)38 bool PrintAfterCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
39                         const OH_AI_CallBackParam kernelInfo) {
40     std::cout << "After forwarding " << kernelInfo.node_name << " " << kernelInfo.node_type << std::endl;
41     return true;
42 }
43 
44 // add cpu device info
AddContextDeviceCPU(OH_AI_ContextHandle context)45 void AddContextDeviceCPU(OH_AI_ContextHandle context) {
46     OH_AI_DeviceInfoHandle cpuDeviceInfo = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
47     ASSERT_NE(cpuDeviceInfo, nullptr);
48     OH_AI_DeviceType deviceType = OH_AI_DeviceInfoGetDeviceType(cpuDeviceInfo);
49     printf("==========deviceType:%d\n", deviceType);
50     ASSERT_EQ(deviceType, OH_AI_DEVICETYPE_CPU);
51     OH_AI_ContextAddDeviceInfo(context, cpuDeviceInfo);
52 }
53 
IsNPU()54 bool IsNPU() {
55     size_t num = 0;
56     auto desc = OH_AI_GetAllNNRTDeviceDescs(&num);
57     if (desc == nullptr) {
58         return false;
59     }
60     auto name = OH_AI_GetNameFromNNRTDeviceDesc(desc);
61     const std::string npuNamePrefix = "NPU_";
62     if (strncmp(npuNamePrefix.c_str(), name, npuNamePrefix.size()) != 0) {
63         return false;
64     }
65     return true;
66 }
67 
68 // add nnrt device info
AddContextDeviceNNRT(OH_AI_ContextHandle context)69 void AddContextDeviceNNRT(OH_AI_ContextHandle context) {
70     size_t num = 0;
71     auto desc = OH_AI_GetAllNNRTDeviceDescs(&num);
72     if (desc == nullptr) {
73         return;
74     }
75 
76     std::cout << "found " << num << " nnrt devices" << std::endl;
77     auto id = OH_AI_GetDeviceIdFromNNRTDeviceDesc(desc);
78     auto name = OH_AI_GetNameFromNNRTDeviceDesc(desc);
79     auto type = OH_AI_GetTypeFromNNRTDeviceDesc(desc);
80     std::cout << "NNRT device: id = " << id << ", name: " << name << ", type:" << type << std::endl;
81 
82     OH_AI_DeviceInfoHandle nnrtDeviceInfo = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT);
83     ASSERT_NE(nnrtDeviceInfo, nullptr);
84     OH_AI_DeviceInfoSetDeviceId(nnrtDeviceInfo, id);
85     OH_AI_DestroyAllNNRTDeviceDescs(&desc);
86 
87     OH_AI_DeviceType deviceType = OH_AI_DeviceInfoGetDeviceType(nnrtDeviceInfo);
88     printf("==========deviceType:%d\n", deviceType);
89     ASSERT_EQ(deviceType, OH_AI_DEVICETYPE_NNRT);
90 
91     OH_AI_DeviceInfoSetPerformanceMode(nnrtDeviceInfo, OH_AI_PERFORMANCE_MEDIUM);
92     ASSERT_EQ(OH_AI_DeviceInfoGetPerformanceMode(nnrtDeviceInfo), OH_AI_PERFORMANCE_MEDIUM);
93     OH_AI_DeviceInfoSetPriority(nnrtDeviceInfo, OH_AI_PRIORITY_MEDIUM);
94     ASSERT_EQ(OH_AI_DeviceInfoGetPriority(nnrtDeviceInfo), OH_AI_PRIORITY_MEDIUM);
95 
96     OH_AI_ContextAddDeviceInfo(context, nnrtDeviceInfo);
97 }
98 
99 // fill data to inputs tensor
FillInputsData(OH_AI_TensorHandleArray inputs,std::string modelName,bool isTranspose)100 void FillInputsData(OH_AI_TensorHandleArray inputs, std::string modelName, bool isTranspose) {
101     for (size_t i = 0; i < inputs.handle_num; ++i) {
102         printf("==========ReadFile==========\n");
103         size_t size1;
104         size_t *ptrSize1 = &size1;
105         std::string inputDataPath = g_testResourcesDir + modelName + "_" + std::to_string(i) + ".input";
106         const char *imagePath = inputDataPath.c_str();
107         char *imageBuf = ReadFile(imagePath, ptrSize1);
108         ASSERT_NE(imageBuf, nullptr);
109         OH_AI_TensorHandle tensor = inputs.handle_list[i];
110         int64_t elementNum = OH_AI_TensorGetElementNum(tensor);
111         printf("Tensor name: %s. \n", OH_AI_TensorGetName(tensor));
112         float *inputData = reinterpret_cast<float *>(OH_AI_TensorGetMutableData(inputs.handle_list[i]));
113         ASSERT_NE(inputData, nullptr);
114         if (isTranspose) {
115             printf("==========Transpose==========\n");
116             size_t shapeNum;
117             const int64_t *shape = OH_AI_TensorGetShape(tensor, &shapeNum);
118             auto imageBufNhwc = new char[size1];
119             PackNCHWToNHWCFp32(imageBuf, imageBufNhwc, shape[0], shape[1] * shape[2], shape[3]);
120             errno_t ret = memcpy_s(inputData, size1, imageBufNhwc, size1);
121             if (ret != EOK) {
122                 printf("memcpy_s failed, ret: %d\n", ret);
123             }
124             delete[] imageBufNhwc;
125         } else {
126             errno_t ret = memcpy_s(inputData, size1, imageBuf, size1);
127             if (ret != EOK) {
128                 printf("memcpy_s failed, ret: %d\n", ret);
129             }
130         }
131         printf("input data after filling is: ");
132         for (int j = 0; j < elementNum && j <= 20; ++j) {
133             printf("%f ", inputData[j]);
134         }
135         printf("\n");
136         delete[] imageBuf;
137     }
138 }
139 
140 // compare result after predict
CompareResult(OH_AI_TensorHandleArray outputs,std::string modelName,float atol,float rtol)141 void CompareResult(OH_AI_TensorHandleArray outputs, std::string modelName, float atol, float rtol) {
142     printf("==========GetOutput==========\n");
143     for (size_t i = 0; i < outputs.handle_num; ++i) {
144         OH_AI_TensorHandle tensor = outputs.handle_list[i];
145         int64_t elementNum = OH_AI_TensorGetElementNum(tensor);
146         printf("Tensor name: %s .\n", OH_AI_TensorGetName(tensor));
147         float *outputData = reinterpret_cast<float *>(OH_AI_TensorGetMutableData(tensor));
148         printf("output data is:");
149         for (int j = 0; j < elementNum && j <= 20; ++j) {
150             printf("%f ", outputData[j]);
151         }
152         printf("\n");
153         printf("==========compFp32WithTData==========\n");
154         std::string outputFile = g_testResourcesDir + modelName + std::to_string(i) + ".output";
155         bool result = compFp32WithTData(outputData, outputFile, atol, rtol, false);
156         EXPECT_EQ(result, true);
157     }
158 }
159 
160 // model build and predict
ModelPredict(OH_AI_ModelHandle model,OH_AI_ContextHandle context,std::string modelName,OH_AI_ShapeInfo shapeInfos,bool buildByGraph,bool isTranspose,bool isCallback)161 void ModelPredict(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName,
162                   OH_AI_ShapeInfo shapeInfos, bool buildByGraph, bool isTranspose, bool isCallback) {
163     std::string modelPath = g_testResourcesDir + modelName + ".ms";
164     const char *graphPath = modelPath.c_str();
165     OH_AI_Status ret = OH_AI_STATUS_SUCCESS;
166     if (buildByGraph) {
167         printf("==========Build model by graphBuf==========\n");
168         size_t size;
169         size_t *ptrSize = &size;
170         char *graphBuf = ReadFile(graphPath, ptrSize);
171         ASSERT_NE(graphBuf, nullptr);
172         ret = OH_AI_ModelBuild(model, graphBuf, size, OH_AI_MODELTYPE_MINDIR, context);
173         delete[] graphBuf;
174     } else {
175         printf("==========Build model==========\n");
176         ret = OH_AI_ModelBuildFromFile(model, graphPath, OH_AI_MODELTYPE_MINDIR, context);
177     }
178     printf("==========build model return code:%d\n", ret);
179     ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
180     printf("==========GetInputs==========\n");
181     OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
182     ASSERT_NE(inputs.handle_list, nullptr);
183     if (shapeInfos.shape_num != 0) {
184         printf("==========Resizes==========\n");
185         OH_AI_Status resize_ret = OH_AI_ModelResize(model, inputs, &shapeInfos, inputs.handle_num);
186         printf("==========Resizes return code:%d\n", resize_ret);
187         ASSERT_EQ(resize_ret, OH_AI_STATUS_SUCCESS);
188     }
189 
190     FillInputsData(inputs, modelName, isTranspose);
191     OH_AI_TensorHandleArray outputs;
192     OH_AI_Status predictRet = OH_AI_STATUS_SUCCESS;
193     if (isCallback) {
194         printf("==========Model Predict Callback==========\n");
195         OH_AI_KernelCallBack beforeCallBack = PrintBeforeCallback;
196         OH_AI_KernelCallBack afterCallBack = PrintAfterCallback;
197         predictRet = OH_AI_ModelPredict(model, inputs, &outputs, beforeCallBack, afterCallBack);
198     } else {
199         printf("==========Model Predict==========\n");
200         predictRet = OH_AI_ModelPredict(model, inputs, &outputs, nullptr, nullptr);
201     }
202     printf("==========Model Predict End==========\n");
203     ASSERT_EQ(predictRet, OH_AI_STATUS_SUCCESS);
204     printf("=========CompareResult===========\n");
205     CompareResult(outputs, modelName);
206     printf("=========OH_AI_ModelDestroy===========\n");
207     OH_AI_ModelDestroy(&model);
208     printf("=========OH_AI_ModelDestroy End===========\n");
209 }
210 
211 // model train build and predict
ModelTrain(OH_AI_ModelHandle model,OH_AI_ContextHandle context,std::string modelName,OH_AI_ShapeInfo shapeInfos,bool buildByGraph,bool isTranspose,bool isCallback)212 void ModelTrain(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName,
213                 OH_AI_ShapeInfo shapeInfos, bool buildByGraph, bool isTranspose, bool isCallback) {
214     std::string modelPath = g_testResourcesDir + modelName + ".ms";
215     const char *graphPath = modelPath.c_str();
216     OH_AI_TrainCfgHandle trainCfg = OH_AI_TrainCfgCreate();
217     OH_AI_Status ret = OH_AI_STATUS_SUCCESS;
218     if (buildByGraph) {
219         printf("==========Build model by graphBuf==========\n");
220         size_t size;
221         size_t *ptrSize = &size;
222         char *graphBuf = ReadFile(graphPath, ptrSize);
223         ASSERT_NE(graphBuf, nullptr);
224         ret = OH_AI_TrainModelBuild(model, graphBuf, size, OH_AI_MODELTYPE_MINDIR, context, trainCfg);
225         delete[] graphBuf;
226     } else {
227         printf("==========Build model==========\n");
228         ret = OH_AI_TrainModelBuildFromFile(model, graphPath, OH_AI_MODELTYPE_MINDIR, context, trainCfg);
229     }
230     printf("==========build model return code:%d\n", ret);
231     ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
232     printf("==========GetInputs==========\n");
233     OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
234     ASSERT_NE(inputs.handle_list, nullptr);
235     if (shapeInfos.shape_num != 0) {
236         printf("==========Resizes==========\n");
237         OH_AI_Status resize_ret = OH_AI_ModelResize(model, inputs, &shapeInfos, inputs.handle_num);
238         printf("==========Resizes return code:%d\n", resize_ret);
239         ASSERT_EQ(resize_ret, OH_AI_STATUS_SUCCESS);
240     }
241     FillInputsData(inputs, modelName, isTranspose);
242     ret = OH_AI_ModelSetTrainMode(model, true);
243     ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
244     if (isCallback) {
245         printf("==========Model RunStep Callback==========\n");
246         OH_AI_KernelCallBack beforeCallBack = PrintBeforeCallback;
247         OH_AI_KernelCallBack afterCallBack = PrintAfterCallback;
248         ret = OH_AI_RunStep(model, beforeCallBack, afterCallBack);
249     } else {
250         printf("==========Model RunStep==========\n");
251         ret = OH_AI_RunStep(model, nullptr, nullptr);
252     }
253     printf("==========Model RunStep End==========\n");
254     ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
255 }
256