• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include "tools/benchmark/run_benchmark.h"
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include "tools/benchmark/benchmark.h"
#include "tools/benchmark/benchmark_unified_api.h"
#include "tools/benchmark/benchmark_c_api.h"
23 
24 namespace mindspore {
25 namespace lite {
RunBenchmark(int argc,const char ** argv)26 int RunBenchmark(int argc, const char **argv) {
27   BenchmarkFlags flags;
28   Option<std::string> err = flags.ParseFlags(argc, argv);
29   if (err.IsSome()) {
30     std::cerr << err.Get() << std::endl;
31     std::cerr << flags.Usage() << std::endl;
32     return RET_ERROR;
33   }
34 
35   if (flags.help) {
36     std::cerr << flags.Usage() << std::endl;
37     return RET_OK;
38   }
39 
40   auto api_type = std::getenv("MSLITE_API_TYPE");
41   if (api_type != nullptr) {
42     MS_LOG(INFO) << "MSLITE_API_TYPE = " << api_type;
43     std::cout << "MSLITE_API_TYPE = " << api_type << std::endl;
44   }
45 
46   std::unique_ptr<BenchmarkBase> benchmark;
47   if (flags.config_file_ != "" || (api_type != nullptr && std::string(api_type) == "NEW")) {
48     benchmark = std::make_unique<BenchmarkUnifiedApi>(&flags);
49   } else if (api_type == nullptr || std::string(api_type) == "OLD") {
50     benchmark = std::make_unique<Benchmark>(&flags);
51   } else if (std::string(api_type) == "C") {
52     benchmark = std::make_unique<tools::BenchmarkCApi>(&flags);
53   } else {
54     BENCHMARK_LOG_ERROR("Invalid MSLITE_API_TYPE, (OLD/NEW/C, default:OLD)");
55     return RET_ERROR;
56   }
57   if (benchmark == nullptr) {
58     BENCHMARK_LOG_ERROR("new benchmark failed ");
59     return RET_ERROR;
60   }
61 
62   auto status = benchmark->Init();
63   if (status != 0) {
64     BENCHMARK_LOG_ERROR("Benchmark init Error : " << status);
65     return RET_ERROR;
66   }
67   auto model_name = flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1);
68 
69   status = benchmark->RunBenchmark();
70   if (status != 0) {
71     BENCHMARK_LOG_ERROR("Run Benchmark " << model_name << " Failed : " << status);
72     return RET_ERROR;
73   }
74 
75   MS_LOG(INFO) << "Run Benchmark " << model_name << " Success.";
76   std::cout << "Run Benchmark " << model_name << " Success." << std::endl;
77   return RET_OK;
78 }
79 }  // namespace lite
80 }  // namespace mindspore
81