• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include "runtime/device/ascend/ge_runtime/model_runner.h"

#include <utility>

#include "runtime/device/ascend/ge_runtime/runtime_model.h"
#include "runtime/device/ascend/ge_runtime/davinci_model.h"
#include "mindspore/core/utils/log_adapter.h"
21 
22 namespace mindspore::ge::model_runner {
Instance()23 ModelRunner &ModelRunner::Instance() {
24   static ModelRunner instance{};  // Guaranteed to be destroyed.
25   return instance;
26 }
27 
LoadDavinciModel(uint32_t device_id,uint64_t session_id,uint32_t model_id,const std::shared_ptr<DavinciModel> & davinci_model)28 void ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
29                                    const std::shared_ptr<DavinciModel> &davinci_model) {
30   std::shared_ptr<RuntimeModel> model = std::make_shared<RuntimeModel>();
31   model->Load(device_id, session_id, davinci_model);
32   runtime_models_[model_id] = model;
33 }
34 
DistributeTask(uint32_t model_id)35 void ModelRunner::DistributeTask(uint32_t model_id) {
36   auto model_iter = runtime_models_.find(model_id);
37   if (model_iter == runtime_models_.end()) {
38     MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
39   }
40   MS_EXCEPTION_IF_NULL(model_iter->second);
41   model_iter->second->DistributeTask();
42 }
43 
LoadModelComplete(uint32_t model_id)44 void ModelRunner::LoadModelComplete(uint32_t model_id) {
45   auto model_iter = runtime_models_.find(model_id);
46   if (model_iter == runtime_models_.end()) {
47     MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
48   }
49   MS_EXCEPTION_IF_NULL(model_iter->second);
50   model_iter->second->LoadComplete();
51 }
52 
GetTaskIdList(uint32_t model_id) const53 const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const {
54   auto model_iter = runtime_models_.find(model_id);
55   if (model_iter == runtime_models_.end()) {
56     MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
57   }
58   MS_EXCEPTION_IF_NULL(model_iter->second);
59   return model_iter->second->GetTaskIdList();
60 }
61 
GetStreamIdList(uint32_t model_id) const62 const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
63   auto model_iter = runtime_models_.find(model_id);
64   if (model_iter == runtime_models_.end()) {
65     MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
66   }
67   MS_EXCEPTION_IF_NULL(model_iter->second);
68   return model_iter->second->GetStreamIdList();
69 }
70 
GetRuntimeInfoMap(uint32_t model_id) const71 const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const {
72   auto model_iter = runtime_models_.find(model_id);
73   if (model_iter == runtime_models_.end()) {
74     MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
75   }
76   MS_EXCEPTION_IF_NULL(model_iter->second);
77   return model_iter->second->GetRuntimeInfoMap();
78 }
79 
GetModelHandle(uint32_t model_id) const80 void *ModelRunner::GetModelHandle(uint32_t model_id) const {
81   auto model_iter = runtime_models_.find(model_id);
82   if (model_iter == runtime_models_.end()) {
83     MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
84   }
85   MS_EXCEPTION_IF_NULL(model_iter->second);
86   return model_iter->second->GetModelHandle();
87 }
88 
GetModelStream(uint32_t model_id) const89 void *ModelRunner::GetModelStream(uint32_t model_id) const {
90   auto model_iter = runtime_models_.find(model_id);
91   if (model_iter == runtime_models_.end()) {
92     MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
93   }
94   MS_EXCEPTION_IF_NULL(model_iter->second);
95   return model_iter->second->GetModelStream();
96 }
97 
UnloadModel(uint32_t model_id)98 void ModelRunner::UnloadModel(uint32_t model_id) {
99   auto iter = runtime_models_.find(model_id);
100   if (iter != runtime_models_.end()) {
101     (void)runtime_models_.erase(iter);
102   }
103 }
104 
RunModel(uint32_t model_id)105 void ModelRunner::RunModel(uint32_t model_id) {
106   auto model_iter = runtime_models_.find(model_id);
107   if (model_iter == runtime_models_.end()) {
108     MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
109   }
110   MS_EXCEPTION_IF_NULL(model_iter->second);
111   model_iter->second->Run();
112 }
113 }  // namespace mindspore::ge::model_runner
114