/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_KERNEL_INFO_SETTER_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_KERNEL_INFO_SETTER_H_

#include <utility>
#include <string>
#include <vector>
#include <memory>
#include <map>
#include "ir/anf.h"
#include "ir/dtype.h"
#include "utils/utils.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/session/kernel_graph.h"

namespace mindspore {
namespace device {
namespace gpu {
const size_t kAllPositions = SIZE_MAX;
const size_t kFormatTransformDimension = 4;

// Map<opName, (inputFormatPositions, outputFormatPositions)>, used to look up the positions at which a format
// transform should be inserted. If a position is kAllPositions, the transform is inserted at every position,
// because the number of inputs or outputs of the op is variable.
static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kKernelFormatPositionMap = {
  // Format sensitive.
  {prim::kPrimConv2D->name(), {{0, 1}, {0}}},
  {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {0}}},
  {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {0}}},
  {prim::kPrimMaxPool->name(), {{0}, {0}}},
  {prim::kPrimMaxPoolGrad->name(), {{0, 1, 2}, {0}}},
  {kAvgPoolOpName, {{0}, {0}}},
  {kAvgPoolGradOpName, {{0, 1, 2}, {0}}},
  {kBatchNorm, {{0}, {0}}},
  {kBatchNormWithActivation, {{0}, {0}}},
  {kBatchNormWithAddAndActivation, {{0, 5}, {0}}},
  {kBatchNormGradOpName, {{0, 1}, {0}}},
  {kBatchNormGradWithActivation, {{0, 1, 7}, {0}}},
  {kBatchNormGradWithAddAndActivation, {{0, 1, 7}, {0, 3}}},
  {kBiasAddOpName, {{0}, {0}}},
  {prim::kPrimBiasAddGrad->name(), {{0}, {}}},
  // Format insensitive.
  {prim::kPrimRelu->name(), {{0}, {0}}},
  {prim::kPrimReluGrad->name(), {{0, 1}, {0}}},
  {prim::kPrimRelu6->name(), {{0}, {0}}},
  {prim::kPrimRelu6Grad->name(), {{0, 1}, {0}}},
  {kSliceOpName, {{0}, {0}}},
  {kSliceGradOpName, {{0, 1}, {0}}},
  {kTensorAddOpName, {{0, 1}, {0}}},
  {prim::kPrimConcat->name(), {{kAllPositions}, {0}}},
  {prim::kPrimAddN->name(), {{kAllPositions}, {0}}},
  {prim::kPrimSplit->name(), {{0}, {kAllPositions}}},
};

// Selects and sets the kernel build info (kernel type, data types and formats) for the given kernel node.
void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);

// Singleton that decides whether format transformation is enabled for a kernel graph.
class FormatTransformChecker {
 public:
  void CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph);
  bool format_transform() const { return format_transform_; }

  static FormatTransformChecker &GetInstance() {
    static FormatTransformChecker instance;
    return instance;
  }

 private:
  FormatTransformChecker() = default;
  ~FormatTransformChecker() = default;
  // Declared but not defined: copying the singleton is disallowed.
  FormatTransformChecker(const FormatTransformChecker &);
  FormatTransformChecker &operator=(const FormatTransformChecker &);

  bool format_transform_{true};
};

// Describes the signature a kernel supports: the data type and format of each input and output.
class KernelAttr {
 public:
  using DataType = std::pair<TypeId, std::string>;
  KernelAttr() : all_same_(false) {}
  ~KernelAttr() = default;

  KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
    input_type_.emplace_back(ms_type, format);
    return *this;
  }

  KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
    output_type_.emplace_back(ms_type, format);
    return *this;
  }

  // Marks that all inputs and outputs share the same attribute.
  KernelAttr &AddAllSameAttr(const bool &all_same) {
    all_same_ = all_same;
    return *this;
  }

  const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
  const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
  const bool &GetAllSame() const { return all_same_; }

  size_t GetInputSize() const { return input_type_.size(); }
  size_t GetOutputSize() const { return output_type_.size(); }

 private:
  std::vector<DataType> input_type_;
  std::vector<DataType> output_type_;
  bool all_same_;
};
}  // namespace gpu
}  // namespace device
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_KERNEL_INFO_SETTER_H_
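// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of this header): building a
// KernelAttr that advertises two float32 inputs and one float32 output in
// the default format, then querying it. It uses only the API declared above;
// kNumberTypeFloat32 is the TypeId enumerator pulled in via "ir/dtype.h".
// How the attribute is matched against a node during kernel selection is
// outside this header.
//
//   KernelAttr attr;
//   attr.AddInputAttr(kNumberTypeFloat32)      // input 0: float32, kOpFormat_DEFAULT
//       .AddInputAttr(kNumberTypeFloat32)      // input 1: float32, kOpFormat_DEFAULT
//       .AddOutputAttr(kNumberTypeFloat32);    // output 0: float32, kOpFormat_DEFAULT
//
//   size_t num_inputs = attr.GetInputSize();          // 2
//   KernelAttr::DataType in0 = attr.GetInputAttr(0);  // {kNumberTypeFloat32, kOpFormat_DEFAULT}
// ---------------------------------------------------------------------------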