• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_BENCHMARK_C_API_H_
17 #define MINDSPORE_LITE_TOOLS_BENCHMARK_BENCHMARK_C_API_H_
18 
19 #include <vector>
20 #include <string>
21 #include "tools/benchmark/benchmark_base.h"
22 #include "include/c_api/model_c.h"
23 #include "include/c_api/context_c.h"
24 
25 #ifdef __cplusplus
26 extern "C" {
27 #endif
28 bool TimeBeforeCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
29                         const OH_AI_CallBackParam kernel_Info);
30 bool TimeAfterCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
31                        const OH_AI_CallBackParam kernel_Info);
32 #ifdef __cplusplus
33 }
34 #endif
35 
36 using mindspore::lite::BenchmarkBase;
37 using mindspore::lite::BenchmarkFlags;
38 
39 namespace mindspore::tools {
40 class OH_AI_API BenchmarkCApi : public BenchmarkBase {
41  public:
BenchmarkCApi(BenchmarkFlags * flags)42   explicit BenchmarkCApi(BenchmarkFlags *flags) : BenchmarkBase(flags) {}
43 
~BenchmarkCApi()44   virtual ~BenchmarkCApi() { OH_AI_ModelDestroy(&model_); }
45 
46   int RunBenchmark() override;
47 
48  protected:
49   int CompareDataGetTotalBiasAndSize(const std::string &name, OH_AI_TensorHandle tensor, float *total_bias,
50                                      int *total_size);
51   int InitContext();
52   int GenerateInputData() override;
53   int ReadInputFile() override;
54   int GetDataTypeByTensorName(const std::string &tensor_name) override;
55   int CompareOutput() override;
56 
57   int InitTimeProfilingCallbackParameter() override;
58   int InitPerfProfilingCallbackParameter() override;
59   int InitDumpTensorDataCallbackParameter() override;
60   int InitPrintTensorDataCallbackParameter() override;
61 
62   int PrintInputData();
63   int MarkPerformance();
64   int MarkAccuracy();
65 
66  private:
67   OH_AI_ModelHandle model_ = nullptr;
68   OH_AI_ContextHandle context_ = nullptr;
69   OH_AI_TensorHandleArray inputs_;
70   OH_AI_TensorHandleArray outputs_;
71 
72   OH_AI_KernelCallBack before_call_back_ = nullptr;
73   OH_AI_KernelCallBack after_call_back_ = nullptr;
74 };
75 }  // namespace mindspore::tools
76 #endif  // MINDSPORE_LITE_TOOLS_BENCHMARK_BENCHMARK_C_API_H_
77