1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_CXX_API_FACTORY_H 17 #define MINDSPORE_CCSRC_CXX_API_FACTORY_H 18 #include <functional> 19 #include <map> 20 #include <string> 21 #include <vector> 22 #include <memory> 23 #include <utility> 24 #include "utils/utils.h" 25 26 namespace mindspore { 27 inline std::string g_device_target = "Default"; 28 29 template <class T> 30 class Factory { 31 using U = std::function<std::shared_ptr<T>()>; 32 33 public: 34 Factory(const Factory &) = delete; 35 void operator=(const Factory &) = delete; 36 Instance()37 static Factory &Instance() { 38 static Factory instance; 39 return instance; 40 } 41 Register(const std::string & device_name,U && creator)42 void Register(const std::string &device_name, U &&creator) { 43 if (creators_.find(device_name) == creators_.end()) { 44 (void)creators_.emplace(device_name, creator); 45 } 46 } 47 CheckModelSupport(const std::string & device_name)48 bool CheckModelSupport(const std::string &device_name) { 49 return std::any_of(creators_.begin(), creators_.end(), 50 [&device_name](const std::pair<std::string, U> &item) { return item.first == device_name; }); 51 } 52 Create(const std::string & device_name)53 std::shared_ptr<T> Create(const std::string &device_name) { 54 auto iter = creators_.find(device_name); 55 if (creators_.end() != iter) { 56 MS_EXCEPTION_IF_NULL(iter->second); 57 return (iter->second)(); 58 } 59 60 MS_LOG(ERROR) << "Unsupported device target " << device_name; 61 return nullptr; 62 } 63 64 private: 65 Factory() = default; 66 ~Factory() = default; 67 std::map<std::string, U> creators_; 68 }; 69 70 template <class T> 71 class Registrar { 72 using U = std::function<std::shared_ptr<T>()>; 73 74 public: Registrar(const std::string & device_name,U creator)75 Registrar(const std::string &device_name, U creator) { 76 Factory<T>::Instance().Register(device_name, std::move(creator)); 77 } 78 ~Registrar() = default; 79 }; 80 81 #define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS) \ 82 static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \ 83 #DEVICE_NAME, []() { return std::make_shared<DERIVE_CLASS>(); }); 84 } // namespace mindspore 85 #endif // MINDSPORE_CCSRC_CXX_API_FACTORY_H 86