• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"

#include <iostream>
#include <string>
#include <utility>

#include "utils/ms_utils.h"
#include "runtime/device/kernel_info.h"
#include "runtime/device/gpu/cuda_common.h"
#include "backend/kernel_compiler/common_utils.h"

27 namespace mindspore {
28 namespace kernel {
GetInstance()29 GpuKernelFactory &GpuKernelFactory::GetInstance() {
30   static GpuKernelFactory instance;
31   return instance;
32 }
33 
Register(const std::string & kernel_name,const KernelAttr & kernel_attr,GpuKernelCreater && creator)34 void GpuKernelFactory::Register(const std::string &kernel_name, const KernelAttr &kernel_attr,
35                                 GpuKernelCreater &&creator) {
36   map_kernel_name_to_creater_[kernel_name].emplace_back(kernel_attr, creator);
37 }
38 
CheckIOParam(const std::string & kernel_name,const KernelBuildInfo * kernel_info,std::vector<std::pair<KernelAttr,GpuKernelCreater>> * iter_second,size_t attr_index)39 bool GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info,
40                                     std::vector<std::pair<KernelAttr, GpuKernelCreater>> *iter_second,
41                                     size_t attr_index) {
42   if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) {
43     if (!iter_second->at(attr_index).first.GetAllSame()) {
44       return false;
45     }
46   }
47   if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) {
48     if (!iter_second->at(attr_index).first.GetAllSame()) {
49       return false;
50     }
51   }
52   return true;
53 }
54 
SupportedTypeList(const std::string & kernel_name)55 std::string GpuKernelFactory::SupportedTypeList(const std::string &kernel_name) {
56   std::string type_lists = "";
57   auto iter = map_kernel_name_to_creater_.find(kernel_name);
58   if (map_kernel_name_to_creater_.end() == iter) {
59     return type_lists;
60   }
61   for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) {
62     std::string type_list = "in[";
63     auto attr = (iter->second)[attr_index].first;
64     for (size_t input_index = 0; input_index < attr.GetInputSize(); ++input_index) {
65       type_list = type_list + TypeId2String(attr.GetInputAttr(input_index).first) +
66                   ((input_index == (attr.GetInputSize() - 1)) ? "" : " ");
67     }
68     type_list = type_list + "], out[";
69     for (size_t input_index = 0; input_index < attr.GetOutputSize(); ++input_index) {
70       type_list = type_list + TypeId2String(attr.GetOutputAttr(input_index).first) +
71                   ((input_index == (attr.GetOutputSize() - 1)) ? "" : " ");
72     }
73     type_lists = type_lists + type_list + "]; ";
74   }
75   return type_lists;
76 }
77 
ReducePrecision(const std::string & kernel_name,std::shared_ptr<mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder)78 bool GpuKernelFactory::ReducePrecision(
79   const std::string &kernel_name, std::shared_ptr<mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder) {
80   MS_EXCEPTION_IF_NULL(builder);
81   auto kernel_info = builder->Build();
82   MS_EXCEPTION_IF_NULL(kernel_info);
83   auto iter = map_kernel_name_to_creater_.find(kernel_name);
84   if (map_kernel_name_to_creater_.end() == iter) {
85     MS_LOG(INFO) << "Not registered GPU kernel: op[" << kernel_name << "]!";
86     return false;
87   }
88   reduce_flag_.first.clear();
89   for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) {
90     auto attr_size = (&(iter->second))->at(attr_index).first.GetInputSize();
91     for (size_t input_index = 0; input_index < kernel_info->GetInputNum(); input_index++) {
92       if (kernel_info->GetInputDeviceType(input_index) == kNumberTypeInt64 &&
93           (iter->second)[attr_index].first.GetInputAttr(input_index % attr_size).first == kNumberTypeInt32) {
94         builder->SetInputDeviceType(kNumberTypeInt32, input_index);
95         reduce_flag_.first.push_back(input_index);
96         MS_LOG(WARNING) << "Kernel [" << kernel_name << "] does not support int64, cast input " << input_index
97                         << " to int32.";
98       }
99     }
100     for (size_t output_index = 0; output_index < kernel_info->GetOutputNum(); output_index++) {
101       if (kernel_info->GetOutputDeviceType(output_index) == kNumberTypeInt64 &&
102           (iter->second)[attr_index].first.GetOutputAttr(output_index % attr_size).first == kNumberTypeInt32) {
103         builder->SetOutputDeviceType(kNumberTypeInt32, output_index);
104         MS_LOG(WARNING) << "Kernel [" << kernel_name << "] does not support int64, cast output " << output_index
105                         << " to int32.";
106       }
107     }
108   }
109   return GpuKernelFactory::SearchRegistered(kernel_name, builder->Build());
110 }
111 
CheckSM(const KernelBuildInfo * kernel_info,const size_t & input_index)112 void GpuKernelFactory::CheckSM(const KernelBuildInfo *kernel_info, const size_t &input_index) {
113   const int major_sm = GET_MAJOR_SM;
114   const bool check_sm = mindspore::device::gpu::CudaCommon::GetInstance().check_sm();
115   if (check_sm && major_sm < RECOMMEND_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) {
116     if (major_sm < MINIUM_SM) {
117       MS_LOG(EXCEPTION) << "Half precision ops can be used on Devices which computing capacity is >= " << MINIUM_SM
118                         << ", but the current device's computing capacity is " << major_sm;
119     }
120     MS_LOG(WARNING) << "It is recommended to use devices with a computing capacity >= " << RECOMMEND_SM
121                     << ", but the current device's computing capacity is " << major_sm;
122     mindspore::device::gpu::CudaCommon::GetInstance().set_check_sm(false);
123   }
124 }
125 
GpuKernelAttrCheck(const std::string & kernel_name,const KernelBuildInfo * kernel_info)126 std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string &kernel_name,
127                                                              const KernelBuildInfo *kernel_info) {
128   auto iter = map_kernel_name_to_creater_.find(kernel_name);
129   if (map_kernel_name_to_creater_.end() == iter) {
130     MS_LOG(INFO) << "Not registered GPU kernel: op[" << kernel_name << "]!";
131     return std::make_pair(false, 0);
132   }
133   if ((iter->second).size() == 1 && (iter->second)[0].first.GetInputSize() == 0) {
134     return std::make_pair(true, 0);
135   }
136 
137   for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) {
138     if (!CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index)) {
139       continue;
140     }
141     bool flag = true;
142     auto attr_size = (&(iter->second))->at(attr_index).first.GetInputSize();
143     if (kernel_info->GetInputNum() > 0) {
144       MS_EXCEPTION_IF_ZERO("attr size", attr_size);
145     }
146     // data type matching check of all input parameters of kernel
147     for (size_t input_index = 0; input_index < kernel_info->GetInputNum(); input_index++) {
148       GpuKernelFactory::CheckSM(kernel_info, input_index);
149       if (kernel_info->GetInputDeviceType(input_index) !=
150           (iter->second)[attr_index].first.GetInputAttr(input_index % attr_size).first) {
151         flag = false;
152         break;
153       }
154     }
155     if (!flag) {
156       continue;
157     }
158     attr_size = (&(iter->second))->at(attr_index).first.GetOutputSize();
159     if (kernel_info->GetOutputNum() > 0) {
160       MS_EXCEPTION_IF_ZERO("attr size", attr_size);
161     }
162     // data type matching check of all output parameters of kernel
163     for (size_t output_index = 0; output_index < kernel_info->GetOutputNum(); output_index++) {
164       if (kernel_info->GetOutputDeviceType(output_index) !=
165           (iter->second)[attr_index].first.GetOutputAttr(output_index % attr_size).first) {
166         flag = false;
167         break;
168       }
169     }
170     // finish data type matching check and return a pair maintain the whether matching is success,
171     // if first is true, second is index of matching KernelAttr and creator pair in vector;
172     if (flag) {
173       size_t match_index = attr_index;
174       return std::make_pair(true, match_index);
175     }
176   }
177   return std::make_pair(false, 0);
178 }
179 
Create(const std::string & kernel_name,const CNodePtr & apply_kernel)180 GpuKernel *GpuKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) {
181   auto kernel_info = dynamic_cast<device::KernelInfo *>(apply_kernel->kernel_info());
182   MS_EXCEPTION_IF_NULL(kernel_info);
183   const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info();
184   MS_EXCEPTION_IF_NULL(kernel_build_Info);
185   std::pair<bool, size_t> ret_pair = GpuKernelAttrCheck(kernel_name, kernel_build_Info);
186   if (ret_pair.first) {
187     return (map_kernel_name_to_creater_.find(kernel_name)->second)[ret_pair.second].second();
188   }
189   return nullptr;
190 }
191 
SearchRegistered(const std::string & kernel_name,const KernelBuildInfoPtr & kernel_build_info)192 bool GpuKernelFactory::SearchRegistered(const std::string &kernel_name, const KernelBuildInfoPtr &kernel_build_info) {
193   std::pair<bool, size_t> ret_pair = GpuKernelAttrCheck(kernel_name, kernel_build_info.get());
194   return ret_pair.first;
195 }
196 }  // namespace kernel
197 }  // namespace mindspore
198