1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "tools/benchmark/run_benchmark.h"
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include "tools/benchmark/benchmark.h"
#include "tools/benchmark/benchmark_unified_api.h"
#include "tools/benchmark/benchmark_c_api.h"
23
24 namespace mindspore {
25 namespace lite {
RunBenchmark(int argc,const char ** argv)26 int RunBenchmark(int argc, const char **argv) {
27 BenchmarkFlags flags;
28 Option<std::string> err = flags.ParseFlags(argc, argv);
29 if (err.IsSome()) {
30 std::cerr << err.Get() << std::endl;
31 std::cerr << flags.Usage() << std::endl;
32 return RET_ERROR;
33 }
34
35 if (flags.help) {
36 std::cerr << flags.Usage() << std::endl;
37 return RET_OK;
38 }
39
40 auto api_type = std::getenv("MSLITE_API_TYPE");
41 if (api_type != nullptr) {
42 MS_LOG(INFO) << "MSLITE_API_TYPE = " << api_type;
43 std::cout << "MSLITE_API_TYPE = " << api_type << std::endl;
44 }
45
46 std::unique_ptr<BenchmarkBase> benchmark;
47 if (flags.config_file_ != "" || (api_type != nullptr && std::string(api_type) == "NEW")) {
48 benchmark = std::make_unique<BenchmarkUnifiedApi>(&flags);
49 } else if (api_type == nullptr || std::string(api_type) == "OLD") {
50 benchmark = std::make_unique<Benchmark>(&flags);
51 } else if (std::string(api_type) == "C") {
52 benchmark = std::make_unique<tools::BenchmarkCApi>(&flags);
53 } else {
54 BENCHMARK_LOG_ERROR("Invalid MSLITE_API_TYPE, (OLD/NEW/C, default:OLD)");
55 return RET_ERROR;
56 }
57 if (benchmark == nullptr) {
58 BENCHMARK_LOG_ERROR("new benchmark failed ");
59 return RET_ERROR;
60 }
61
62 auto status = benchmark->Init();
63 if (status != 0) {
64 BENCHMARK_LOG_ERROR("Benchmark init Error : " << status);
65 return RET_ERROR;
66 }
67 auto model_name = flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1);
68
69 status = benchmark->RunBenchmark();
70 if (status != 0) {
71 BENCHMARK_LOG_ERROR("Run Benchmark " << model_name << " Failed : " << status);
72 return RET_ERROR;
73 }
74
75 MS_LOG(INFO) << "Run Benchmark " << model_name << " Success.";
76 std::cout << "Run Benchmark " << model_name << " Success." << std::endl;
77 return RET_OK;
78 }
79 } // namespace lite
80 } // namespace mindspore
81