1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_FACTORY_H_ 17 #define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_FACTORY_H_ 18 19 #include <functional> 20 #include <string> 21 #include <memory> 22 23 #include "utils/hash_map.h" 24 #include "runtime/hardware/device_context.h" 25 #include "src/extendrt/delegate_graph_executor.h" 26 #include "include/api/context.h" 27 #include "src/common/config_infos.h" 28 29 namespace mindspore { 30 using mindspore::device::GraphExecutor; 31 // (zhaizhiqiang): Wrap graph executor as delegate. 32 // typedef std::shared_ptr<GraphSinkDelegate> (*DelegateCreator)(const std::shared_ptr<Context> &); 33 template <typename T> 34 using DelegateCreator = std::function<T(const std::shared_ptr<Context> &, const ConfigInfos &)>; 35 36 template <typename T> 37 class MS_API DelegateRegistry { 38 public: 39 DelegateRegistry() = default; 40 virtual ~DelegateRegistry() = default; 41 GetInstance()42 static DelegateRegistry<T> &GetInstance() { 43 static DelegateRegistry<T> instance; 44 return instance; 45 } 46 RegDelegate(const mindspore::DeviceType & device_type,const std::string & provider,DelegateCreator<T> * creator)47 void RegDelegate(const mindspore::DeviceType &device_type, const std::string &provider, DelegateCreator<T> *creator) { 48 auto it = creator_map_.find(device_type); 49 if (it == creator_map_.end()) { 50 HashMap<std::string, DelegateCreator<T> *> map; 51 map[provider] = creator; 52 creator_map_[device_type] = map; 53 return; 54 } 55 it->second[provider] = creator; 56 } UnRegDelegate(const mindspore::DeviceType & device_type,const std::string & provider)57 void UnRegDelegate(const mindspore::DeviceType &device_type, const std::string &provider) { 58 auto it = creator_map_.find(device_type); 59 if (it != creator_map_.end()) { 60 creator_map_.erase(it); 61 } 62 } GetDelegate(const mindspore::DeviceType & device_type,const std::string & provider,const std::shared_ptr<Context> & ctx,const ConfigInfos & config_infos)63 T GetDelegate(const mindspore::DeviceType &device_type, const std::string &provider, 64 const std::shared_ptr<Context> &ctx, const ConfigInfos &config_infos) { 65 // find common delegate 66 auto it = creator_map_.find(device_type); 67 if (it == creator_map_.end()) { 68 return nullptr; 69 } 70 auto creator_it = it->second.find(provider); 71 if (creator_it == it->second.end()) { 72 return nullptr; 73 } 74 return (*(creator_it->second))(ctx, config_infos); 75 } 76 77 private: 78 mindspore::HashMap<DeviceType, mindspore::HashMap<std::string, DelegateCreator<T> *>> creator_map_; 79 }; 80 81 template <typename T> 82 class DelegateRegistrar { 83 public: DelegateRegistrar(const mindspore::DeviceType & device_type,const std::string & provider,DelegateCreator<T> * creator)84 DelegateRegistrar(const mindspore::DeviceType &device_type, const std::string &provider, 85 DelegateCreator<T> *creator) { 86 DelegateRegistry<T>::GetInstance().RegDelegate(device_type, provider, creator); 87 } 88 ~DelegateRegistrar() = default; 89 }; 90 91 #define REG_DELEGATE(device_type, provider, creator) \ 92 using t = decltype(creator(std::declval<const std::shared_ptr<Context> &>(), std::declval<const ConfigInfos &>())); \ 93 static DelegateCreator<t> func = [](const std::shared_ptr<Context> &context, const ConfigInfos &config_infos) { \ 94 return creator(context, config_infos); \ 95 }; \ 96 static DelegateRegistrar<t> g_##device_type##provider##Delegate(device_type, provider, &func); 97 } // namespace mindspore 98 99 #endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_FACTORY_H_ 100