• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_CALLBACK_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_CALLBACK_H_
18 #include <string>
19 #include <memory>
20 #include <vector>
21 #include <functional>
22 
23 #include "ir/anf.h"
24 #include "ir/dtype/type_id.h"
25 #include "utils/shape_utils.h"
26 #include "backend/common/graph_kernel/model/node.h"
27 #include "include/backend/visible.h"
28 
29 namespace mindspore::graphkernel {
30 class Callback;
31 using CallbackPtr = std::shared_ptr<Callback>;
32 class BACKEND_EXPORT Callback {
33  public:
34   virtual ~Callback() = default;
Instance()35   static CallbackPtr Instance() { return instance_; }
36 
37   /**
38    * @brief Get the real input shape of the `node`.
39    *
40    * @param node the AnfNodePtr
41    * @param i    the input index, start from 0
42    */
43   virtual ShapeVector GetInputShape(const AnfNodePtr &node, size_t i) = 0;
44 
45   /**
46    * @brief Get the real output shape of the `node`.
47    *
48    * @param node the AnfNodePtr
49    * @param i    the output index, start from 0
50    */
51   virtual ShapeVector GetOutputShape(const AnfNodePtr &node, size_t i) = 0;
52 
53   /**
54    * @brief Get the inferred input shape of the `node`.
55    *
56    * @param node the AnfNodePtr
57    * @param i    the input index, start from 0
58    */
59   virtual ShapeVector GetInputInferShape(const AnfNodePtr &node, size_t i) = 0;
60 
61   /**
62    * @brief Get the inferred output shape of the `node`.
63    *
64    * @param node the AnfNodePtr
65    * @param i    the output index, start from 0
66    */
67   virtual ShapeVector GetOutputInferShape(const AnfNodePtr &node, size_t i) = 0;
68 
69   /**
70    * @brief Get the real input data type of the `node`.
71    *
72    * @param node the AnfNodePtr
73    * @param i    the input index, start from 0
74    */
75   virtual TypeId GetInputType(const AnfNodePtr &node, size_t i) = 0;
76 
77   /**
78    * @brief Get the real output data type of the `node`.
79    *
80    * @param node the AnfNodePtr
81    * @param i    the output index, start from 0
82    */
83   virtual TypeId GetOutputType(const AnfNodePtr &node, size_t i) = 0;
84 
85   /**
86    * @brief Get the inferred input data type of the `node`.
87    *
88    * @param node the AnfNodePtr
89    * @param i    the input index, start from 0
90    */
91   virtual TypeId GetInputInferType(const AnfNodePtr &node, size_t i) = 0;
92 
93   /**
94    * @brief Get the inferred output data type of the `node`.
95    *
96    * @param node the AnfNodePtr
97    * @param i    the output index, start from 0
98    */
99   virtual TypeId GetOutputInferType(const AnfNodePtr &node, size_t i) = 0;
100 
101   /**
102    * @brief Get the input data format of the `node`.
103    *
104    * @param node the AnfNodePtr
105    * @param i    the input index, start from 0
106    */
107   virtual std::string GetInputFormat(const AnfNodePtr &node, size_t i) = 0;
108 
109   /**
110    * @brief Get the output data format of the `node`.
111    *
112    * @param node the AnfNodePtr
113    * @param i    the output index, start from 0
114    */
115   virtual std::string GetOutputFormat(const AnfNodePtr &node, size_t i) = 0;
116 
117   /**
118    * @brief Get the processor of the `node`.
119    *
120    * @param node the AnfNodePtr
121    */
122   virtual std::string GetProcessor(const AnfNodePtr &node) = 0;
123 
124   /**
125    * @brief Get the backend target from context.
126    *
127    * @param detail if false(default), only "Ascend/GPU/CPU" is returned. otherwise target like "Ascend910" is returned.
128    */
129   std::string GetTargetFromContext(bool detail = false) { return GetTargetFromContextImpl(detail); }
130 
131   /**
132    * @brief Set KernelInfo for a GraphKernel node, the info is extract from its inputs/outputs.
133    *
134    * @param[in] node the GraphKernel CNode.
135    */
136   virtual void SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) = 0;
137 
138   /**
139    * @brief Set KernelInfo for a basic node.
140    *
141    * @param node         the AnfNodePtr
142    * @param outputs_info the output info list
143    */
144   virtual void SetBasicNodeKernelInfo(const AnfNodePtr &node, const std::vector<inner::NodeBase> &outputs_info) = 0;
145 
146   /**
147    * @brief Set empty KernelInfo.
148    *
149    * @param node the AnfNodePtr
150    */
151   virtual void SetEmptyKernelInfo(const AnfNodePtr &node) = 0;
152 
153   /**
154    * @brief Reset KernelInfo on different platforms.
155    *
156    * @param node the AnfNodePtr
157    */
158   virtual void ResetKernelInfo(const AnfNodePtr &node) = 0;
159 
160   /**
161    * @brief Reset KernelInfo input msg for convert attr and input.
162    *
163    * @param node the AnfNodePtr
164    * @param overwrite if true, override all inputs kernel info, if false, use the original kernel info saved in node
165    */
166   virtual void ResetKernelInfoInputs(const AnfNodePtr &node, const std::vector<size_t> &indices) = 0;
167 
168   /**
169    * @brief The Callback implementation use nodes' device info.
170    */
IsUseDeviceInfo()171   virtual bool IsUseDeviceInfo() { return true; }
172 
RegImpl(const CallbackPtr & cb)173   static void RegImpl(const CallbackPtr &cb) { instance_ = cb; }
174 
175  private:
176   // to avoid the default argument in virtual function.
177   virtual std::string GetTargetFromContextImpl(bool detail) = 0;
178 
179   friend class CallbackImplRegister;
180 #ifndef _MSC_VER
181   BACKEND_EXPORT inline static CallbackPtr instance_{nullptr};
182 #else
183   inline static CallbackPtr instance_{nullptr};
184 #endif
185 };
186 
187 class CallbackImplRegister {
188  public:
CallbackImplRegister(const std::function<CallbackPtr ()> & fn)189   explicit CallbackImplRegister(const std::function<CallbackPtr()> &fn) noexcept { Callback::RegImpl(fn()); }
190   ~CallbackImplRegister() = default;
191 
192  protected:
193   // for pclint-plus
194   bool rev_{false};
195 };
196 
/**
 * @brief Register `cls` (a class derived from Callback) as the global Callback implementation.
 *
 * Expands to the definition of a const CallbackImplRegister object named `callback` in the scope where the
 * macro is used; its constructor creates `cls` with std::make_shared and installs it via Callback::RegImpl.
 * Because it defines a named object, expand it at most once per scope.
 */
#define GRAPH_KERNEL_CALLBACK_REGISTER(cls)                  \
  const CallbackImplRegister callback([]() noexcept {        \
    std::shared_ptr<Callback> impl = std::make_shared<cls>(); \
    return impl;                                             \
  })
200 }  // namespace mindspore::graphkernel
201 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_CALLBACK_H_
202