• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_KERNEL_INFO_SETTER_H_
18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_KERNEL_INFO_SETTER_H_
19 
20 #include <utility>
21 #include <string>
22 #include <vector>
23 #include <memory>
24 #include <map>
25 #include "ir/anf.h"
26 #include "ir/dtype.h"
27 #include "utils/utils.h"
28 #include "backend/kernel_compiler/kernel.h"
29 #include "backend/session/kernel_graph.h"
30 
31 namespace mindspore {
32 namespace device {
33 namespace gpu {
// Sentinel position value. The comment on kKernelFormatPositionMap describes it as "insert all the
// positions", used for ops whose input or output count is variable.
constexpr size_t kAllPositions = SIZE_MAX;
// NOTE(review): presumably the tensor rank required for a format transform to apply — confirm at use sites.
constexpr size_t kFormatTransformDimension = 4;
36 
37 // Map<opName, (inputFormatPosition, outputFormatPosition)>, used for getting the inserted position of format transform.
38 // If the inserted position is kAllPositions, then insert all the positions, because the input or output numbers of
39 // this op are variable.
40 static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kKernelFormatPositionMap = {
41   // Format sensitive.
42   {prim::kPrimConv2D->name(), {{0, 1}, {0}}},
43   {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {0}}},
44   {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {0}}},
45   {prim::kPrimMaxPool->name(), {{0}, {0}}},
46   {prim::kPrimMaxPoolGrad->name(), {{0, 1, 2}, {0}}},
47   {kAvgPoolOpName, {{0}, {0}}},
48   {kAvgPoolGradOpName, {{0, 1, 2}, {0}}},
49   {kBatchNorm, {{0}, {0}}},
50   {kBatchNormWithActivation, {{0}, {0}}},
51   {kBatchNormWithAddAndActivation, {{0, 5}, {0}}},
52   {kBatchNormGradOpName, {{0, 1}, {0}}},
53   {kBatchNormGradWithActivation, {{0, 1, 7}, {0}}},
54   {kBatchNormGradWithAddAndActivation, {{0, 1, 7}, {0, 3}}},
55   {kBiasAddOpName, {{0}, {0}}},
56   {prim::kPrimBiasAddGrad->name(), {{0}, {}}},
57   // Format insensitive.
58   {prim::kPrimRelu->name(), {{0}, {0}}},
59   {prim::kPrimReluGrad->name(), {{0, 1}, {0}}},
60   {prim::kPrimRelu6->name(), {{0}, {0}}},
61   {prim::kPrimRelu6Grad->name(), {{0, 1}, {0}}},
62   {kSliceOpName, {{0}, {0}}},
63   {kSliceGradOpName, {{0, 1}, {0}}},
64   {kTensorAddOpName, {{0, 1}, {0}}},
65   {prim::kPrimConcat->name(), {{kAllPositions}, {0}}},
66   {prim::kPrimAddN->name(), {{kAllPositions}, {0}}},
67   {prim::kPrimSplit->name(), {{0}, {kAllPositions}}},
68 };
69 
70 void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
71 
72 class FormatTransformChecker {
73  public:
74   void CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph);
format_transform()75   bool format_transform() const { return format_transform_; }
76 
GetInstance()77   static FormatTransformChecker &GetInstance() {
78     static FormatTransformChecker instance;
79     return instance;
80   }
81 
82  private:
83   FormatTransformChecker() = default;
84   ~FormatTransformChecker() = default;
85   FormatTransformChecker(const FormatTransformChecker &);
86   FormatTransformChecker &operator=(const FormatTransformChecker &);
87 
88   bool format_transform_{true};
89 };
90 
91 class KernelAttr {
92  public:
93   using DataType = std::pair<TypeId, std::string>;
KernelAttr()94   KernelAttr() : all_same_(false) {}
95   ~KernelAttr() = default;
96 
97   KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
98     input_type_.emplace_back(ms_type, format);
99     return *this;
100   }
101 
102   KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
103     output_type_.emplace_back(ms_type, format);
104     return *this;
105   }
106 
AddAllSameAttr(const bool & all_same)107   KernelAttr &AddAllSameAttr(const bool &all_same) {
108     all_same_ = all_same;
109     return *this;
110   }
111 
GetInputAttr(const size_t index)112   const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
GetOutputAttr(const size_t index)113   const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
GetAllSame()114   const bool &GetAllSame() const { return all_same_; }
115 
GetInputSize()116   size_t GetInputSize() const { return input_type_.size(); }
GetOutputSize()117   size_t GetOutputSize() const { return output_type_.size(); }
118 
119  private:
120   std::vector<DataType> input_type_;
121   std::vector<DataType> output_type_;
122   bool all_same_;
123 };
124 }  // namespace gpu
125 }  // namespace device
126 }  // namespace mindspore
127 
128 #endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_KERNEL_INFO_SETTER_H_
129