• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "include/api/model_parallel_runner.h"
17 #include "src/extendrt/cxx_api/model_pool/model_parallel_runner_impl.h"
18 #include "src/extendrt/cxx_api/model_pool/runner_config.h"
19 #include "src/common/log_adapter.h"
20 #ifdef CAPTURE_SIGNALS
21 #include "src/extendrt/signal_handler.h"
22 #endif
23 namespace mindspore {
namespace {
// Caps enforced by RunnerConfig::SetConfigInfo to keep the stored
// config map from growing without bound across repeated calls.
constexpr size_t kMaxSectionNum = 100;
constexpr size_t kMaxConfigNumPerSection = 1000;
}  // namespace
#ifdef USE_GLOG
extern "C" {
// Logging bootstrap called from ModelParallelRunner::Init before any MS_LOG
// output; presumably provided by the glog adapter — TODO confirm.
extern void mindspore_log_init();
}
#endif

// Guards lazy creation of model_parallel_runner_impl_ in both Init overloads.
std::mutex g_model_parallel_runner_mutex;
35 
RunnerConfig()36 RunnerConfig::RunnerConfig() : data_(std::make_shared<Data>()) {}
37 
~RunnerConfig()38 RunnerConfig::~RunnerConfig() {}
39 
SetWorkersNum(int32_t workers_num)40 void RunnerConfig::SetWorkersNum(int32_t workers_num) {
41   if (data_ == nullptr) {
42     MS_LOG(ERROR) << "Runner config data is nullptr.";
43     return;
44   }
45   data_->workers_num = workers_num;
46 }
47 
SetContext(const std::shared_ptr<Context> & context)48 void RunnerConfig::SetContext(const std::shared_ptr<Context> &context) {
49   if (data_ == nullptr) {
50     MS_LOG(ERROR) << "Runner config data is nullptr.";
51     return;
52   }
53   MS_CHECK_TRUE_RET_VOID(context != nullptr);
54   data_->context = context;
55 }
56 
GetWorkersNum() const57 int32_t RunnerConfig::GetWorkersNum() const {
58   if (data_ == nullptr) {
59     MS_LOG(ERROR) << "Runner config data is nullptr.";
60     return -1;
61   }
62   return data_->workers_num;
63 }
64 
GetContext() const65 std::shared_ptr<Context> RunnerConfig::GetContext() const {
66   if (data_ == nullptr) {
67     MS_LOG(ERROR) << "Runner config data is nullptr.";
68     return nullptr;
69   }
70   return data_->context;
71 }
72 
SetConfigInfo(const std::vector<char> & section,const std::map<std::vector<char>,std::vector<char>> & config)73 void RunnerConfig::SetConfigInfo(const std::vector<char> &section,
74                                  const std::map<std::vector<char>, std::vector<char>> &config) {
75   if (data_ == nullptr) {
76     MS_LOG(ERROR) << "Runner config data is nullptr.";
77     return;
78   }
79   if (data_->config_info.size() > kMaxSectionNum) {
80     MS_LOG(ERROR) << "The number of added sessions exceeds the maximum[" << kMaxSectionNum << "] limit.";
81     return;
82   }
83   if (config.size() > kMaxConfigNumPerSection) {
84     MS_LOG(ERROR) << "The number of added config exceeds the maximum[" << kMaxConfigNumPerSection << "] limit.";
85     return;
86   }
87   data_->config_info[CharToString(section)] = MapVectorCharToString(config);
88   return;
89 }
90 
SetConfigPath(const std::vector<char> & config_path)91 void RunnerConfig::SetConfigPath(const std::vector<char> &config_path) {
92   if (data_ == nullptr) {
93     MS_LOG(ERROR) << "Runner config data is nullptr.";
94     return;
95   }
96   data_->config_path = CharToString(config_path);
97   return;
98 }
99 
GetConfigPathChar() const100 std::vector<char> RunnerConfig::GetConfigPathChar() const {
101   if (data_ == nullptr) {
102     MS_LOG(ERROR) << "Runner config data is nullptr.";
103     std::vector<char> empty;
104     return empty;
105   }
106   return StringToChar(data_->config_path);
107 }
108 
GetConfigInfoChar() const109 std::map<std::vector<char>, std::map<std::vector<char>, std::vector<char>>> RunnerConfig::GetConfigInfoChar() const {
110   if (data_ == nullptr) {
111     MS_LOG(ERROR) << "Runner config data is nullptr.";
112     std::map<std::vector<char>, std::map<std::vector<char>, std::vector<char>>> empty;
113     return empty;
114   }
115   return MapMapStringToChar(data_->config_info);
116 }
117 
SetDeviceIds(const std::vector<uint32_t> & device_ids)118 void RunnerConfig::SetDeviceIds(const std::vector<uint32_t> &device_ids) {
119   if (data_ == nullptr) {
120     MS_LOG(ERROR) << "Runner config data is nullptr.";
121     return;
122   }
123   data_->device_ids = device_ids;
124 }
125 
GetDeviceIds() const126 std::vector<uint32_t> RunnerConfig::GetDeviceIds() const {
127   if (data_ == nullptr) {
128     MS_LOG(ERROR) << "Runner config data is nullptr.";
129     return {};
130   }
131   return data_->device_ids;
132 }
133 
ModelParallelRunner()134 ModelParallelRunner::ModelParallelRunner() {}
135 
~ModelParallelRunner()136 ModelParallelRunner::~ModelParallelRunner() {}
137 
Init(const std::vector<char> & model_path,const std::shared_ptr<RunnerConfig> & runner_config)138 Status ModelParallelRunner::Init(const std::vector<char> &model_path,
139                                  const std::shared_ptr<RunnerConfig> &runner_config) {
140   {
141     std::lock_guard<std::mutex> l(g_model_parallel_runner_mutex);
142 #ifdef USE_GLOG
143     mindspore::mindspore_log_init();
144 #endif
145     if (model_parallel_runner_impl_ == nullptr) {
146       model_parallel_runner_impl_ = std::make_shared<ModelParallelRunnerImpl>();
147       if (model_parallel_runner_impl_ == nullptr) {
148         MS_LOG(ERROR) << "new model pool failed, model pool is nullptr.";
149         return kLiteNullptr;
150       }
151     }
152   }
153   return model_parallel_runner_impl_->Init(CharToString(model_path), runner_config);
154 }
155 
Init(const void * model_data,size_t data_size,const std::shared_ptr<RunnerConfig> & runner_config)156 Status ModelParallelRunner::Init(const void *model_data, size_t data_size,
157                                  const std::shared_ptr<RunnerConfig> &runner_config) {
158   {
159     std::lock_guard<std::mutex> l(g_model_parallel_runner_mutex);
160 #ifdef USE_GLOG
161     mindspore::mindspore_log_init();
162 #endif
163     if (model_parallel_runner_impl_ == nullptr) {
164       model_parallel_runner_impl_ = std::make_shared<ModelParallelRunnerImpl>();
165       if (model_parallel_runner_impl_ == nullptr) {
166         MS_LOG(ERROR) << "new model pool failed, model pool is nullptr.";
167         return kLiteNullptr;
168       }
169     }
170   }
171   return model_parallel_runner_impl_->Init(model_data, data_size, runner_config);
172 }
173 
GetInputs()174 std::vector<MSTensor> ModelParallelRunner::GetInputs() {
175   if (model_parallel_runner_impl_ == nullptr) {
176     std::vector<MSTensor> empty;
177     MS_LOG(ERROR) << "Please initialize ModelParallelRunner before calling GetInput API.";
178     return empty;
179   }
180   return model_parallel_runner_impl_->GetInputs();
181 }
182 
GetOutputs()183 std::vector<MSTensor> ModelParallelRunner::GetOutputs() {
184   if (model_parallel_runner_impl_ == nullptr) {
185     std::vector<MSTensor> empty;
186     MS_LOG(ERROR) << "Please initialize ModelParallelRunner before calling GetOutputs API.";
187     return empty;
188   }
189   return model_parallel_runner_impl_->GetOutputs();
190 }
191 
Predict(const std::vector<MSTensor> & inputs,std::vector<MSTensor> * outputs,const MSKernelCallBack & before,const MSKernelCallBack & after)192 Status ModelParallelRunner::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
193                                     const MSKernelCallBack &before, const MSKernelCallBack &after) {
194   if (model_parallel_runner_impl_ == nullptr) {
195     MS_LOG(ERROR) << "ModelParallelRunner Not Initialize.";
196     return kLiteNullptr;
197   }
198   return model_parallel_runner_impl_->Predict(inputs, outputs, before, after);
199 }
200 }  // namespace mindspore
201