1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_KERNEL_INFO_SETTER_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_KERNEL_INFO_SETTER_H_ 19 20 #include <utility> 21 #include <string> 22 #include <vector> 23 #include <memory> 24 #include <map> 25 #include "ir/anf.h" 26 #include "ops/conv_pool_ops.h" 27 #include "ops/nn_optimizer_ops.h" 28 #include "ops/nn_ops.h" 29 #include "ops/array_ops.h" 30 #include "ops/math_op_name.h" 31 #include "ir/dtype.h" 32 #include "include/common/utils/utils.h" 33 #include "kernel/kernel.h" 34 #include "kernel/kernel_build_info.h" 35 #include "kernel/graph_kernel_info.h" 36 #include "include/backend/kernel_graph.h" 37 #include "kernel/common_utils.h" 38 #include "utils/ms_context.h" 39 #include "include/backend/visible.h" 40 41 namespace mindspore { 42 namespace device { 43 namespace gpu { 44 const size_t kAllPositions = SIZE_MAX; 45 const size_t kFormatTransformDimension = 4; 46 47 // Map<opName, (inputFormatPosition, outputFormatPosition)>, used for getting the inserted position of format transform. 48 // If the inserted position is kAllPositions, then insert all the positions, because the input or output numbers of 49 // this op are variable. 50 static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kKernelFormatPositionMap = { 51 // Format sensitive. 52 {prim::kPrimConv2D->name(), {{0, 1}, {0}}}, 53 {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {0}}}, 54 {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {0}}}, 55 {prim::kPrimMaxPool->name(), {{0}, {0}}}, 56 {prim::kPrimMaxPoolGrad->name(), {{0, 1, 2}, {0}}}, 57 {kAvgPoolOpName, {{0}, {0}}}, 58 {kAvgPoolGradOpName, {{0, 1, 2}, {0}}}, 59 {kBatchNormOpName, {{0}, {0}}}, 60 {kBatchNormWithActivationOpName, {{0}, {0}}}, 61 {kBatchNormWithAddAndActivationOpName, {{0, 5}, {0}}}, 62 {kBatchNormGradOpName, {{0, 1}, {0}}}, 63 {kBatchNormGradWithActivationOpName, {{0, 1, 7}, {0}}}, 64 {kBatchNormGradWithAddAndActivationOpName, {{0, 1, 7}, {0, 3}}}, 65 {kBiasAddOpName, {{0}, {0}}}, 66 {prim::kPrimBiasAddGrad->name(), {{0}, {}}}, 67 // Format insensitive. 68 {prim::kPrimReLU->name(), {{0}, {0}}}, 69 {prim::kPrimReluGrad->name(), {{0, 1}, {0}}}, 70 {prim::kPrimReLU6->name(), {{0}, {0}}}, 71 {prim::kPrimReLU6Grad->name(), {{0, 1}, {0}}}, 72 {kSliceOpName, {{0}, {0}}}, 73 {kSliceGradOpName, {{0, 1}, {0}}}, 74 {kTensorAddOpName, {{0, 1}, {0}}}, 75 {prim::kPrimConcat->name(), {{kAllPositions}, {0}}}, 76 {prim::kPrimAddN->name(), {{kAllPositions}, {0}}}, 77 {prim::kPrimSplit->name(), {{0}, {kAllPositions}}}, 78 }; 79 80 std::pair<std::string, ExceptionType> SetKernelInfoWithMsg(const CNodePtr &kernel_node, 81 KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); 82 83 class FormatTransformChecker { 84 public: 85 void CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph); format_transform()86 bool format_transform() const { return format_transform_; } 87 GetInstance()88 static FormatTransformChecker &GetInstance() { 89 static FormatTransformChecker instance; 90 return instance; 91 } 92 93 private: 94 FormatTransformChecker() = default; 95 ~FormatTransformChecker() = default; 96 FormatTransformChecker(const FormatTransformChecker &); 97 FormatTransformChecker &operator=(const FormatTransformChecker &); 98 99 bool format_transform_{true}; 100 }; 101 102 class GPU_EXPORT GPUGraphKernelInfo : public GraphKernelInfo { 103 public: 104 GPUGraphKernelInfo() = default; 105 virtual ~GPUGraphKernelInfo() = default; 106 void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) override; 107 }; 108 109 REG_GRAPH_KERNEL_INFO(kGPUDevice, GPUGraphKernelInfo); 110 } // namespace gpu 111 } // namespace device 112 } // namespace mindspore 113 114 #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_KERNEL_INFO_SETTER_H_ 115