1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_TOOLS_OPERATOR_INFO_REGISTER_H 18 #define MINDSPORE_LITE_TOOLS_OPERATOR_INFO_REGISTER_H 19 20 #include <map> 21 #include <vector> 22 #include <string> 23 #include <memory> 24 #include "tools/optimizer/parallel/operator_info.h" 25 26 namespace mindspore { 27 namespace opt { 28 using OperatorInfoCreatorFunc = 29 std::function<std::unique_ptr<opt::OperatorInfo>(const std::string &name, const SplitStrategy &strategy)>; 30 31 class SplitOpKey { 32 public: 33 SplitOpKey() = delete; 34 SplitOpKey(int op_type,TypeId data_type,bool is_depth_wise)35 SplitOpKey(int op_type, TypeId data_type, bool is_depth_wise) 36 : op_type_(op_type), data_type_(data_type), is_depth_wise_(is_depth_wise) {} 37 38 bool operator<(const SplitOpKey &key) const; 39 40 std::string ToString() const; 41 42 ~SplitOpKey() = default; 43 44 private: 45 int op_type_{schema::PrimitiveType_NONE}; 46 TypeId data_type_{kTypeUnknown}; 47 // Conv && DepthwiseCon has same schema_id, so need this flags 48 bool is_depth_wise_{false}; 49 }; 50 51 class OperatorInfoFactory { 52 public: 53 static OperatorInfoFactory *GeInstance(); 54 55 OperatorInfoFactory(const OperatorInfoFactory &) = delete; 56 57 OperatorInfoFactory &operator=(const OperatorInfoFactory &) = delete; 58 59 void RegisterOperatorInfo(schema::PrimitiveType operator_type, TypeId type_id, bool is_depth_wise, 60 const OperatorInfoCreatorFunc &creator_func); 61 62 OperatorInfoCreatorFunc FindOperatorInfo(const SplitOpKey &split_op_key); 63 64 private: 65 OperatorInfoFactory() = default; 66 67 virtual ~OperatorInfoFactory() = default; 68 69 private: 70 // key: op_type -->data_type-->-->is_depth_wise-->name 71 std::map<SplitOpKey, OperatorInfoCreatorFunc> operator_info_map_; 72 }; 73 74 class OperatorInfoRegister { 75 public: 76 OperatorInfoRegister() = delete; 77 78 OperatorInfoRegister(schema::PrimitiveType operator_type, TypeId type_id, bool is_depth_wise, 79 const OperatorInfoCreatorFunc &creator_func); 80 81 ~OperatorInfoRegister() = default; 82 }; 83 84 #define OPERATOR_INFO_REGISTER(operator_type, type_id, is_depth_wise, creator_func) \ 85 static OperatorInfoRegister g_name##operator_type##type_id##is_depth_wise##Creator(operator_type, type_id, \ 86 is_depth_wise, creator_func); 87 } // namespace opt 88 } // namespace mindspore 89 90 #endif // MINDSPORE_LITE_TOOLS_OPERATOR_INFO_REGISTER_H 91