/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_FACTORY_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_FACTORY_H_
#include <string>
#include <utility>
#include <vector>
#include <memory>
#include <map>
#include <set>
#include <functional>

#include "ir/anf.h"
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "utils/ms_utils.h"
#include "utils/ms_context.h"
#include "include/backend/visible.h"

namespace mindspore::opt {
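// Describes how a front-end (ME) op is adapted for a given backend: the
// backend/target op names to use, an optional pre-check callback, whether TBE
// support checking is needed, and which constant inputs are converted to
// attributes (input index -> attribute data type).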
class BACKEND_EXPORT OpAdaptationInfo {
 public:
  explicit OpAdaptationInfo(const std::string &me_op_name, std::string device_name, bool flag)
      : me_op_name_(me_op_name),
        backend_op_name_(me_op_name),
        target_op_name_(me_op_name),
        device_name_(std::move(device_name)),
        flag_(flag) {}

  OpAdaptationInfo &operator=(const OpAdaptationInfo &op_adaptation_info);
  virtual ~OpAdaptationInfo() = default;

  OpAdaptationInfo &set_backend_op_name(const std::string &default_op_name);
  OpAdaptationInfo &set_target_op_name(const std::string &target_op_name);
  OpAdaptationInfo &set_pre_check_func(std::function<bool(CNodePtr)> pre_check_func);
  OpAdaptationInfo &set_need_tbe_check_supported(bool need_tbe_check_supported);
  OpAdaptationInfo &set_input_attr_info(size_t input_index, const std::string &attr_data_type = "");
  OpAdaptationInfo &set_is_ascend_mindir();

  const std::string &me_op_name() const { return me_op_name_; }
  const std::string &backend_op_name() const { return backend_op_name_; }
  const std::string &target_op_name() const { return target_op_name_; }
  const std::function<bool(CNodePtr)> &pre_check_func() const { return pre_check_func_; }
  bool need_tbe_check_supported() const { return need_tbe_check_supported_; }
  const std::map<size_t, std::string> &input_attr_map() const { return input_attr_map_; }
  const std::string &device_name() const { return device_name_; }
  bool flag() const { return flag_; }
  bool is_ascend_mindir() const { return is_ascend_mindir_; }

 private:
  std::string me_op_name_;
  std::string backend_op_name_;
  std::string target_op_name_;
  std::function<bool(CNodePtr)> pre_check_func_{nullptr};
  bool need_tbe_check_supported_{false};
  std::map<size_t, std::string> input_attr_map_;
  std::string device_name_;
  bool flag_{false};
  bool is_ascend_mindir_{false};
};
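// Singleton registry keyed by (op_name + device_name + flag). It stores the
// registered OpAdaptationInfo objects and provides the helpers used to create
// the adapted target op and to convert constant inputs into attributes.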
class BACKEND_EXPORT OpAdaptationInfoRegister {
 public:
  static OpAdaptationInfoRegister &GetInstance();
  static void RegOpAdaptationInfo(OpAdaptationInfo *reg_info);
  [[nodiscard]] static OpAdaptationInfo *GetOpAdaptationInfo(const std::string &me_op_name,
                                                             const std::string &device_name, bool flag);
  static CNodePtr CreateTargetOp(const CNodePtr &origin_op, const OpAdaptationInfo &op_adaptation_info);
  static bool ConvertInputToAttr(const CNodePtr &origin_op, size_t i, const std::shared_ptr<AnfNode> &input_node,
                                 const std::string &attr_data_type, const std::shared_ptr<Primitive> &target_primitive);
  static void RenamePrimitiveName(const CNodePtr &origin_op, const std::string &me_op_name,
                                  const std::string &backend_op_name);

 private:
  OpAdaptationInfoRegister() = default;
  ~OpAdaptationInfoRegister() = default;
  DISABLE_COPY_AND_ASSIGN(OpAdaptationInfoRegister)

  static std::string GenerateKey(const std::string &me_op_name, const std::string &device_name, bool flag);
  // key: (op_name + device_name + flag), value: <OpAdaptationInfo *>
  static std::map<std::string, OpAdaptationInfo *> &GetOpInfoMap();
  // To improve performance, avoid generating a key for ops that were never registered.
  static std::set<std::string> &GetOpName();
};
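// RegisterHelper builds an OpAdaptationInfo (either from a variadic list of
// constant-input indices or from an existing OpAdaptationInfo) and is used by
// the REG_* macros below as a static object, so registration with
// OpAdaptationInfoRegister happens at static-initialization time.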
class BACKEND_EXPORT RegisterHelper {
 public:
  RegisterHelper(const std::string &me_op_name, const std::string &device_name, bool flag, int len, ...);
  RegisterHelper(const OpAdaptationInfo &op_adaptation_info);
  ~RegisterHelper() = default;

 private:
  std::shared_ptr<OpAdaptationInfo> op_adaptation_info_{nullptr};
};
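// REG_OP_ADAPTATION_INFO declares a file-scope RegisterHelper copy-initialized
// from an opt::OpAdaptationInfo expression (which is why the
// RegisterHelper(const OpAdaptationInfo &) constructor above is not explicit);
// since every set_* method returns OpAdaptationInfo &, call sites can chain
// setters before the terminating semicolon. RER_CONST_TO_ATTR_LIST instead
// forwards the listed constant-input indices through the variadic constructor,
// using std::make_tuple/std::tuple_size only to count the arguments.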
#define REG_OP_ADAPTATION_INFO(me_op_name, device_name, flag)                                      \
  static opt::RegisterHelper g_reg_##device_name##_##flag##_##me_op_name __attribute__((unused)) = \
    opt::OpAdaptationInfo(me_op_name, device_name, flag)

#define RER_CONST_TO_ATTR_LIST(me_op_name, device_name, flag, ...)        \
  static opt::RegisterHelper g_reg_##device_name##_##flag##_##me_op_name( \
    me_op_name, device_name, flag, std::tuple_size<decltype(std::make_tuple(__VA_ARGS__))>::value, __VA_ARGS__)
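// Usage sketch (hypothetical; the kFooOpName / kFooV2OpName / kBarOpName /
// device-name constants below are placeholders, not defined by this header):
//
//   REG_OP_ADAPTATION_INFO(kFooOpName, kAscendDevice, true)
//     .set_target_op_name(kFooV2OpName)
//     .set_input_attr_info(2, "int64");  // treat constant input 2 as an int64 attribute
//
//   RER_CONST_TO_ATTR_LIST(kBarOpName, kCPUDevice, false, 1, 2);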
}  // namespace mindspore::opt
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_FACTORY_H_