1 /** 2 * Copyright 2021-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_CALLBACK_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_CALLBACK_H_ 18 #include <string> 19 #include <memory> 20 #include <vector> 21 #include <functional> 22 23 #include "ir/anf.h" 24 #include "ir/dtype/type_id.h" 25 #include "utils/shape_utils.h" 26 #include "backend/common/graph_kernel/model/node.h" 27 #include "include/backend/visible.h" 28 29 namespace mindspore::graphkernel { 30 class Callback; 31 using CallbackPtr = std::shared_ptr<Callback>; 32 class BACKEND_EXPORT Callback { 33 public: 34 virtual ~Callback() = default; Instance()35 static CallbackPtr Instance() { return instance_; } 36 37 /** 38 * @brief Get the real input shape of the `node`. 39 * 40 * @param node the AnfNodePtr 41 * @param i the input index, start from 0 42 */ 43 virtual ShapeVector GetInputShape(const AnfNodePtr &node, size_t i) = 0; 44 45 /** 46 * @brief Get the real output shape of the `node`. 47 * 48 * @param node the AnfNodePtr 49 * @param i the output index, start from 0 50 */ 51 virtual ShapeVector GetOutputShape(const AnfNodePtr &node, size_t i) = 0; 52 53 /** 54 * @brief Get the inferred input shape of the `node`. 55 * 56 * @param node the AnfNodePtr 57 * @param i the input index, start from 0 58 */ 59 virtual ShapeVector GetInputInferShape(const AnfNodePtr &node, size_t i) = 0; 60 61 /** 62 * @brief Get the inferred output shape of the `node`. 63 * 64 * @param node the AnfNodePtr 65 * @param i the output index, start from 0 66 */ 67 virtual ShapeVector GetOutputInferShape(const AnfNodePtr &node, size_t i) = 0; 68 69 /** 70 * @brief Get the real input data type of the `node`. 71 * 72 * @param node the AnfNodePtr 73 * @param i the input index, start from 0 74 */ 75 virtual TypeId GetInputType(const AnfNodePtr &node, size_t i) = 0; 76 77 /** 78 * @brief Get the real output data type of the `node`. 79 * 80 * @param node the AnfNodePtr 81 * @param i the output index, start from 0 82 */ 83 virtual TypeId GetOutputType(const AnfNodePtr &node, size_t i) = 0; 84 85 /** 86 * @brief Get the inferred input data type of the `node`. 87 * 88 * @param node the AnfNodePtr 89 * @param i the input index, start from 0 90 */ 91 virtual TypeId GetInputInferType(const AnfNodePtr &node, size_t i) = 0; 92 93 /** 94 * @brief Get the inferred output data type of the `node`. 95 * 96 * @param node the AnfNodePtr 97 * @param i the output index, start from 0 98 */ 99 virtual TypeId GetOutputInferType(const AnfNodePtr &node, size_t i) = 0; 100 101 /** 102 * @brief Get the input data format of the `node`. 103 * 104 * @param node the AnfNodePtr 105 * @param i the input index, start from 0 106 */ 107 virtual std::string GetInputFormat(const AnfNodePtr &node, size_t i) = 0; 108 109 /** 110 * @brief Get the output data format of the `node`. 111 * 112 * @param node the AnfNodePtr 113 * @param i the output index, start from 0 114 */ 115 virtual std::string GetOutputFormat(const AnfNodePtr &node, size_t i) = 0; 116 117 /** 118 * @brief Get the processor of the `node`. 119 * 120 * @param node the AnfNodePtr 121 */ 122 virtual std::string GetProcessor(const AnfNodePtr &node) = 0; 123 124 /** 125 * @brief Get the backend target from context. 126 * 127 * @param detail if false(default), only "Ascend/GPU/CPU" is returned. otherwise target like "Ascend910" is returned. 128 */ 129 std::string GetTargetFromContext(bool detail = false) { return GetTargetFromContextImpl(detail); } 130 131 /** 132 * @brief Set KernelInfo for a GraphKernel node, the info is extract from its inputs/outputs. 133 * 134 * @param[in] node the GraphKernel CNode. 135 */ 136 virtual void SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) = 0; 137 138 /** 139 * @brief Set KernelInfo for a basic node. 140 * 141 * @param node the AnfNodePtr 142 * @param outputs_info the output info list 143 */ 144 virtual void SetBasicNodeKernelInfo(const AnfNodePtr &node, const std::vector<inner::NodeBase> &outputs_info) = 0; 145 146 /** 147 * @brief Set empty KernelInfo. 148 * 149 * @param node the AnfNodePtr 150 */ 151 virtual void SetEmptyKernelInfo(const AnfNodePtr &node) = 0; 152 153 /** 154 * @brief Reset KernelInfo on different platforms. 155 * 156 * @param node the AnfNodePtr 157 */ 158 virtual void ResetKernelInfo(const AnfNodePtr &node) = 0; 159 160 /** 161 * @brief Reset KernelInfo input msg for convert attr and input. 162 * 163 * @param node the AnfNodePtr 164 * @param overwrite if true, override all inputs kernel info, if false, use the original kernel info saved in node 165 */ 166 virtual void ResetKernelInfoInputs(const AnfNodePtr &node, const std::vector<size_t> &indices) = 0; 167 168 /** 169 * @brief The Callback implementation use nodes' device info. 170 */ IsUseDeviceInfo()171 virtual bool IsUseDeviceInfo() { return true; } 172 RegImpl(const CallbackPtr & cb)173 static void RegImpl(const CallbackPtr &cb) { instance_ = cb; } 174 175 private: 176 // to avoid the default argument in virtual function. 177 virtual std::string GetTargetFromContextImpl(bool detail) = 0; 178 179 friend class CallbackImplRegister; 180 #ifndef _MSC_VER 181 BACKEND_EXPORT inline static CallbackPtr instance_{nullptr}; 182 #else 183 inline static CallbackPtr instance_{nullptr}; 184 #endif 185 }; 186 187 class CallbackImplRegister { 188 public: CallbackImplRegister(const std::function<CallbackPtr ()> & fn)189 explicit CallbackImplRegister(const std::function<CallbackPtr()> &fn) noexcept { Callback::RegImpl(fn()); } 190 ~CallbackImplRegister() = default; 191 192 protected: 193 // for pclint-plus 194 bool rev_{false}; 195 }; 196 197 #define GRAPH_KERNEL_CALLBACK_REGISTER(cls) \ 198 const CallbackImplRegister callback( \ 199 []() noexcept { return std::static_pointer_cast<Callback>(std::make_shared<cls>()); }) 200 } // namespace mindspore::graphkernel 201 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_CALLBACK_H_ 202