1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_ 18 #define MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_ 19 20 #include <functional> 21 #include <unordered_map> 22 #include <vector> 23 #include <utility> 24 #include <string> 25 #include <memory> 26 #include <NvInfer.h> 27 #include "base/base.h" 28 #include "ir/anf.h" 29 30 namespace mindspore { 31 namespace opt { 32 class LayerInput; 33 class TrtConverterContext; 34 using ConvertResult = std::pair<bool, std::vector<nvinfer1::ITensor *>>; 35 using ConvertFunc = std::function<ConvertResult(AnfNodePtr, std::shared_ptr<TrtConverterContext>)>; 36 37 class TrtOpFactory { 38 public: GetInstance()39 static TrtOpFactory &GetInstance() { 40 static TrtOpFactory instance; 41 return instance; 42 } 43 Register(const std::string & op_name,const ConvertFunc & func)44 void Register(const std::string &op_name, const ConvertFunc &func) { 45 if (op_convert_map_.count(op_name)) { 46 MS_LOG(EXCEPTION) << "Operator: " << op_name << " re-registered."; 47 } 48 op_convert_map_.insert(std::make_pair(op_name, func)); 49 } 50 GetConvertFunc(const std::string & op_name)51 ConvertFunc GetConvertFunc(const std::string &op_name) const { 52 auto iter = op_convert_map_.find(op_name); 53 if (iter == op_convert_map_.end()) { 54 MS_LOG(WARNING) << "Operator: " << op_name << " not support."; 55 return nullptr; 56 } 57 return iter->second; 58 } 59 60 private: 61 TrtOpFactory() = default; 62 ~TrtOpFactory() = default; 63 DISABLE_COPY_AND_ASSIGN(TrtOpFactory) 64 65 std::unordered_map<std::string, ConvertFunc> op_convert_map_; 66 }; 67 68 class TrtOpRegister { 69 public: TrtOpRegister(const std::string & op_name,ConvertFunc func)70 TrtOpRegister(const std::string &op_name, ConvertFunc func) { TrtOpFactory::GetInstance().Register(op_name, func); } 71 }; 72 } // namespace opt 73 } // namespace mindspore 74 #endif // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_ 75