1 /** 2 * Copyright 2019-2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_FACTORY_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_FACTORY_H_ 19 20 #include <functional> 21 #include <map> 22 #include <memory> 23 #include <unordered_map> 24 #include "runtime/device/ascend/ge_runtime/task_info.h" 25 #include "mindspore/core/utils/log_adapter.h" 26 27 namespace mindspore::ge::model_runner { 28 class Task; 29 class ModelContext; 30 using TASK_CREATOR_FUN = std::function<std::shared_ptr<Task>(const ModelContext &, std::shared_ptr<TaskInfo>)>; 31 32 class TaskFactory { 33 private: TaskFactory()34 TaskFactory() {} ~TaskFactory()35 ~TaskFactory() {} RegisterCreator(const TaskInfoType & type,const TASK_CREATOR_FUN & func)36 void RegisterCreator(const TaskInfoType &type, const TASK_CREATOR_FUN &func) { 37 if (creator_map_.find(type) != creator_map_.end()) { 38 MS_LOG(WARNING) << "Creator type " << type << " already exist."; 39 } 40 creator_map_[type] = func; 41 } 42 43 std::map<TaskInfoType, TASK_CREATOR_FUN> creator_map_; 44 45 public: GetInstance()46 static TaskFactory &GetInstance() { 47 static TaskFactory instance; 48 return instance; 49 } 50 Create(const ModelContext & model_context,const std::shared_ptr<TaskInfo> & task_info)51 std::shared_ptr<Task> Create(const ModelContext &model_context, const std::shared_ptr<TaskInfo> &task_info) const { 52 if (task_info == nullptr) { 53 MS_LOG(ERROR) << "task_info is null."; 54 return nullptr; 55 } 56 57 auto iter = creator_map_.find(task_info->type()); 58 if (iter == creator_map_.end()) { 59 MS_LOG(ERROR) << "Unknown task type " << task_info->type(); 60 return nullptr; 61 } 62 return iter->second(model_context, task_info); 63 } 64 65 class Register { 66 public: Register(const TaskInfoType & type,const TASK_CREATOR_FUN & func)67 Register(const TaskInfoType &type, const TASK_CREATOR_FUN &func) { 68 MS_LOG(DEBUG) << "register type " << type; 69 TaskFactory::GetInstance().RegisterCreator(type, func); 70 } 71 ~Register()72 ~Register() {} 73 }; 74 }; 75 76 #define REGISTER_TASK(type, task_clazz, task_info_clazz) \ 77 TaskFactory::Register g_##task_clazz##_register( \ 78 type, [](const ModelContext &model_context, const std::shared_ptr<TaskInfo> &task_info) -> std::shared_ptr<Task> { \ 79 std::shared_ptr<task_info_clazz> concrete_task_info = std::static_pointer_cast<task_info_clazz>(task_info); \ 80 return std::make_shared<task_clazz>(model_context, concrete_task_info); \ 81 }); 82 83 } // namespace mindspore::ge::model_runner 84 #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_FACTORY_H_ 85