• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_
19 
20 #include <functional>
21 #include <unordered_map>
22 #include <vector>
23 #include <utility>
24 #include <string>
25 #include <memory>
26 #include <NvInfer.h>
27 #include "base/base.h"
28 #include "ir/anf.h"
29 
30 namespace mindspore {
31 namespace opt {
32 class LayerInput;
33 class TrtConverterContext;
34 using ConvertResult = std::pair<bool, std::vector<nvinfer1::ITensor *>>;
35 using ConvertFunc = std::function<ConvertResult(AnfNodePtr, std::shared_ptr<TrtConverterContext>)>;
36 
37 class TrtOpFactory {
38  public:
GetInstance()39   static TrtOpFactory &GetInstance() {
40     static TrtOpFactory instance;
41     return instance;
42   }
43 
Register(const std::string & op_name,const ConvertFunc & func)44   void Register(const std::string &op_name, const ConvertFunc &func) {
45     if (op_convert_map_.count(op_name)) {
46       MS_LOG(EXCEPTION) << "Operator: " << op_name << " re-registered.";
47     }
48     op_convert_map_.insert(std::make_pair(op_name, func));
49   }
50 
GetConvertFunc(const std::string & op_name)51   ConvertFunc GetConvertFunc(const std::string &op_name) const {
52     auto iter = op_convert_map_.find(op_name);
53     if (iter == op_convert_map_.end()) {
54       MS_LOG(WARNING) << "Operator: " << op_name << " not support.";
55       return nullptr;
56     }
57     return iter->second;
58   }
59 
60  private:
61   TrtOpFactory() = default;
62   ~TrtOpFactory() = default;
63   DISABLE_COPY_AND_ASSIGN(TrtOpFactory)
64 
65   std::unordered_map<std::string, ConvertFunc> op_convert_map_;
66 };
67 
68 class TrtOpRegister {
69  public:
TrtOpRegister(const std::string & op_name,ConvertFunc func)70   TrtOpRegister(const std::string &op_name, ConvertFunc func) { TrtOpFactory::GetInstance().Register(op_name, func); }
71 };
72 }  // namespace opt
73 }  // namespace mindspore
74 #endif  // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_
75