• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/device/gpu/kernel_info_setter.h"
18 #include <algorithm>
19 #include <memory>
20 #include <string>
21 #include "backend/kernel_compiler/common_utils.h"
22 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
23 #include "backend/kernel_compiler/kernel.h"
24 #include "backend/kernel_compiler/kernel_build_info.h"
25 #include "backend/kernel_compiler/oplib/opinfo.h"
26 #include "backend/kernel_compiler/oplib/oplib.h"
27 #include "backend/session/anf_runtime_algorithm.h"
28 #include "runtime/device/gpu/cuda_common.h"
29 #include "utils/ms_context.h"
30 #include "utils/ms_utils.h"
31 #include "utils/utils.h"
32 
33 namespace mindspore {
34 namespace device {
35 namespace gpu {
36 using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
37 using mindspore::kernel::KernelBuildInfo;
38 namespace {
CheckKernelInfo(const std::shared_ptr<KernelBuildInfo> & alternative_kernel_info,const std::shared_ptr<KernelBuildInfo> & selected_kernel_info)39 bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info,
40                      const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) {
41   MS_EXCEPTION_IF_NULL(selected_kernel_info);
42   MS_EXCEPTION_IF_NULL(alternative_kernel_info);
43   size_t selected_input_num = selected_kernel_info->GetInputNum();
44   size_t alternative_input_num = alternative_kernel_info->GetInputNum();
45   if (selected_input_num != alternative_input_num) {
46     return false;
47   }
48   for (size_t i = 0; i < selected_input_num; i++) {
49     if (selected_kernel_info->GetInputFormat(i) != alternative_kernel_info->GetInputFormat(i)) {
50       return false;
51     }
52     if (selected_kernel_info->GetInputDeviceType(i) != alternative_kernel_info->GetInputDeviceType(i)) {
53       return false;
54     }
55   }
56 
57   size_t selected_output_num = selected_kernel_info->GetOutputNum();
58   size_t alternative_output_num = alternative_kernel_info->GetOutputNum();
59   if (selected_output_num != alternative_output_num) {
60     return false;
61   }
62   for (size_t i = 0; i < selected_output_num; i++) {
63     if (selected_kernel_info->GetOutputFormat(i) != alternative_kernel_info->GetOutputFormat(i)) {
64       return false;
65     }
66     if (selected_kernel_info->GetOutputDeviceType(i) != alternative_kernel_info->GetOutputDeviceType(i)) {
67       return false;
68     }
69   }
70   return true;
71 }
72 
SupportedTypeList(const CNodePtr & kernel_node)73 std::string SupportedTypeList(const CNodePtr &kernel_node) {
74   std::string supported_type_lists =
75     kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node));
76   if (!supported_type_lists.empty()) {
77     return supported_type_lists;
78   }
79   std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
80   std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
81   auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kAKG);
82   if (op_info_ptr == nullptr) {
83     MS_LOG(EXCEPTION) << "Unsupported op [" << op_name << "] on GPU";
84   }
85   (void)ParseMetadata(kernel_node, op_info_ptr, kernel::Processor::CUDA, &kernel_info_list);
86   for (size_t i = 0; i < kernel_info_list.size(); i++) {
87     auto supported_akg_type = kernel_info_list[i]->GetAllInputDeviceTypes();
88     auto supported_akg_type_out = kernel_info_list[i]->GetAllOutputDeviceTypes();
89     std::string supported_akg_type_list = "in[";
90     for (auto type : supported_akg_type) {
91       supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type);
92     }
93     supported_type_lists = supported_type_lists + supported_akg_type_list + "], out[";
94     supported_akg_type_list.clear();
95     for (auto type : supported_akg_type_out) {
96       supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type);
97     }
98     supported_type_lists = supported_type_lists + supported_akg_type_list + "]; ";
99   }
100   return supported_type_lists;
101 }
102 
SelectAkgKernel(const CNodePtr & kernel_node,const std::shared_ptr<KernelBuildInfo> & selected_kernel_info)103 bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) {
104   MS_EXCEPTION_IF_NULL(kernel_node);
105   MS_EXCEPTION_IF_NULL(selected_kernel_info);
106   std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
107   auto func_call = kernel_node->input(0);
108   if (auto pre = GetCNodePrimitive(kernel_node)) {
109     if (pre->GetAttr("akg")) {
110       return true;
111     }
112   }
113   if (AnfAlgo::IsNodeInGraphKernel(kernel_node)) {
114     // The op_info in OpLib is only used for basic ops,
115     // we don't care it in GraphKernel.
116     return true;
117   }
118 
119   std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
120 
121   auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kAKG);
122   if (op_info_ptr == nullptr) {
123     MS_LOG(ERROR) << "Not find op[" << op_name << "] in akg";
124     return false;
125   }
126   if (!ParseMetadata(kernel_node, op_info_ptr, kernel::Processor::CUDA, &kernel_info_list)) {
127     MS_LOG(EXCEPTION) << "Parsed metadata of op[" << op_name << "] failed.";
128   }
129   if (kernel_info_list.empty()) {
130     MS_LOG(EXCEPTION) << "Akg dose not has metadata of op[" << op_name << "].";
131   }
132 
133   bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
134                            [&](const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info) {
135                              return CheckKernelInfo(alternative_kernel_info, selected_kernel_info);
136                            });
137   if (!match) {
138     MS_LOG(ERROR) << "Not find op[" << op_name << "] which both match data type and format in akg";
139     return false;
140   }
141   return true;
142 }
143 
// Propagates the selected kernel's input formats/types onto the Parameter
// nodes feeding this kernel, so weights and inputs carry device info that is
// consistent with the kernel chosen for kernel_node.
// Side effect: clears the GPU kernel factory's reduce-precision flag list.
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  for (size_t input_index = 0; input_index < input_num; ++input_index) {
    // input(0) is the primitive, so tensor inputs start at index 1.
    auto input_kernel_node = kernel_node->input(input_index + 1);
    MS_EXCEPTION_IF_NULL(input_kernel_node);
    auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
    MS_EXCEPTION_IF_NULL(input_with_index.first);
    auto real_input_node = input_with_index.first;
    // Only Parameter nodes get their build info assigned here; real kernels
    // receive theirs when they are selected themselves.
    if (!real_input_node->isa<Parameter>()) {
      continue;
    }
    std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
      std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();

    auto param = real_input_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    // Non-weight parameters: default format, inferred data type.
    if (!AnfAlgo::IsParameterWeight(param)) {
      std::vector<std::string> output_format = {kOpFormat_DEFAULT};
      builder->SetOutputsFormat(output_format);
      std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
      builder->SetOutputsDeviceType(output_type);
      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
      continue;
    }
    // Weight parameters: only (re)assign when no device type has been set yet,
    // or always for ApplyMomentum (its weights must track the kernel's types).
    if ((AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) ||
        (AnfAlgo::GetCNodeName(kernel_node) == "ApplyMomentum")) {
      std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
      builder->SetOutputsFormat(output_format);
      // If this input index was marked by the factory's reduce-precision pass,
      // use the reduced type instead of the kernel's declared input type.
      auto reduce_flag = kernel::GpuKernelFactory::GetInstance().reduce_flag_;
      std::vector<TypeId> output_type;
      if (std::find(reduce_flag.first.begin(), reduce_flag.first.end(), input_index) != reduce_flag.first.end()) {
        output_type = {reduce_flag.second};
      } else {
        output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
      }
      builder->SetOutputsDeviceType(output_type);
      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
    }
  }
  // Reset the reduce-precision markers so they do not leak into the next node.
  kernel::GpuKernelFactory::GetInstance().reduce_flag_.first.clear();
}
186 
TransformFormatPosition(std::vector<size_t> * format_position,size_t position_num)187 void TransformFormatPosition(std::vector<size_t> *format_position, size_t position_num) {
188   MS_EXCEPTION_IF_NULL(format_position);
189   if (format_position->size() == 0) {
190     return;
191   }
192 
193   // If the inserted position is kAllPositions, then insert all the positions.
194   if ((*format_position)[0] == kAllPositions) {
195     format_position->clear();
196     for (size_t index = 0; index < position_num; index++) {
197       format_position->push_back(index);
198     }
199   }
200 }
201 
IsNeedProcessFormatInfo(const CNodePtr & kernel_node,const std::vector<TypeId> & inputs_type)202 bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) {
203   auto ms_context = MsContext::GetInstance();
204   MS_EXCEPTION_IF_NULL(ms_context);
205   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
206     return false;
207   }
208   if (!FormatTransformChecker::GetInstance().format_transform()) {
209     return false;
210   }
211   if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) {
212     return false;
213   }
214   auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
215   auto iter = kKernelFormatPositionMap.find(kernel_name);
216   if (iter == kKernelFormatPositionMap.end()) {
217     return false;
218   }
219   if (inputs_type.size() == 0) {
220     return false;
221   }
222 
223   auto inputs_format_position = iter->second.first;
224   size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
225   TransformFormatPosition(&inputs_format_position, input_num);
226   for (const auto &input_format_position : inputs_format_position) {
227     auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_format_position);
228     // Only support the transformer between NCHW and NHWC, so need the shape is 4 dimension.
229     if (input_shape.size() != kFormatTransformDimension) {
230       return false;
231     }
232   }
233 
234   auto outputs_format_position = iter->second.second;
235   size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
236   TransformFormatPosition(&outputs_format_position, output_num);
237   for (const auto &output_format_position : outputs_format_position) {
238     auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, output_format_position);
239     // Only support the transformer between NCHW and NHWC, so need the shape is 4 dimension.
240     if (output_shape.size() != kFormatTransformDimension) {
241       return false;
242     }
243   }
244   return true;
245 }
246 
UpdateKernelFormatInfo(const CNodePtr & kernel_node,const std::vector<TypeId> & inputs_type,std::vector<std::string> * inputs_format,std::vector<std::string> * outputs_format,std::string * origin_data_format)247 void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type,
248                             std::vector<std::string> *inputs_format, std::vector<std::string> *outputs_format,
249                             std::string *origin_data_format) {
250   MS_EXCEPTION_IF_NULL(kernel_node);
251   MS_EXCEPTION_IF_NULL(inputs_format);
252   MS_EXCEPTION_IF_NULL(outputs_format);
253   auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
254   auto iter = kKernelFormatPositionMap.find(kernel_name);
255   if (iter == kKernelFormatPositionMap.end()) {
256     return;
257   }
258   auto cal_format = (inputs_type[0] == kNumberTypeFloat16) ? kOpFormat_NHWC : kOpFormat_NCHW;
259   MS_LOG(DEBUG) << "Kernel node: " << kernel_node->fullname_with_scope() << ", format: " << cal_format;
260   auto inputs_format_position = iter->second.first;
261   size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
262   TransformFormatPosition(&inputs_format_position, input_num);
263   for (const auto &input_format_position : inputs_format_position) {
264     if (input_format_position >= inputs_format->size()) {
265       MS_LOG(EXCEPTION) << "The position [" << input_format_position << "] is out of range of the input size ["
266                         << inputs_format->size() << "] #kernel_node [" << kernel_node->fullname_with_scope() << "]";
267     }
268     (*inputs_format)[input_format_position] = cal_format;
269   }
270 
271   auto outputs_format_position = iter->second.second;
272   size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
273   TransformFormatPosition(&outputs_format_position, output_num);
274   for (const auto &output_format_position : outputs_format_position) {
275     if (output_format_position >= outputs_format->size()) {
276       MS_LOG(EXCEPTION) << "The position [" << output_format_position << "] is out of range of the output size ["
277                         << outputs_format->size() << "] #kernel_node [" << kernel_node->fullname_with_scope() << "]";
278     }
279     (*outputs_format)[output_format_position] = cal_format;
280   }
281   auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
282   MS_EXCEPTION_IF_NULL(prim);
283   if (prim->HasAttr("format")) {
284     *origin_data_format = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "format");
285   }
286 }
287 
SetGraphKernelInfo(const CNodePtr & kernel_node,const FuncGraphPtr & func_graph)288 void SetGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) {
289   std::vector<AnfNodePtr> node_list, input_list, output_list;
290   kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
291 
292   std::vector<std::string> graph_input_format;
293   std::vector<TypeId> graph_input_type;
294   // set graph kernel inputs kernel info.
295   for (size_t i = 0; i < input_list.size(); ++i) {
296     kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
297     std::vector<std::string> outputs_format = {kOpFormat_DEFAULT};
298     std::vector<TypeId> outputs_device_type = {AnfAlgo::GetOutputInferDataType(input_list[i], 0)};
299     graph_input_format.push_back(kOpFormat_DEFAULT);
300     graph_input_type.push_back(AnfAlgo::GetOutputInferDataType(input_list[i], 0));
301     builder.SetOutputsFormat(outputs_format);
302     builder.SetOutputsDeviceType(outputs_device_type);
303     AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
304   }
305 
306   // set graph kernel innner nodes kernel info.
307   for (size_t i = 0; i < node_list.size(); ++i) {
308     const auto &anf_node = node_list[i];
309     MS_EXCEPTION_IF_NULL(anf_node);
310     auto cnode = anf_node->cast<CNodePtr>();
311     cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
312     SetKernelInfo(cnode, KernelType::AKG_KERNEL);
313   }
314 
315   // set graph kernel node kernel info.
316   auto mng = func_graph->manager();
317   if (mng == nullptr) {
318     mng = Manage(func_graph, true);
319     func_graph->set_manager(mng);
320   }
321   auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
322   std::vector<std::string> graph_output_format;
323   std::vector<TypeId> graph_output_type;
324   for (size_t i = 0; i < output_index.size(); ++i) {
325     auto const &output = output_index[i];
326     graph_output_format.push_back(AnfAlgo::GetOutputFormat(output.first, output.second));
327     graph_output_type.push_back(AnfAlgo::GetOutputDeviceDataType(output.first, output.second));
328   }
329 
330   kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
331   graph_info_builder.SetInputsFormat(graph_input_format);
332   graph_info_builder.SetInputsDeviceType(graph_input_type);
333   graph_info_builder.SetOutputsFormat(graph_output_format);
334   graph_info_builder.SetOutputsDeviceType(graph_output_type);
335   graph_info_builder.SetProcessor(kernel::Processor::CUDA);
336   graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
337   graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
338   auto graph_selected_info = graph_info_builder.Build();
339   MS_EXCEPTION_IF_NULL(graph_selected_info);
340   AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());
341   SetTensorDeviceInfo(*graph_selected_info, kernel_node);
342 }
343 
PrintUnsupportedTypeException(const CNodePtr & kernel_node,const std::vector<TypeId> & inputs_type,const std::vector<TypeId> & outputs_type)344 void PrintUnsupportedTypeException(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type,
345                                    const std::vector<TypeId> &outputs_type) {
346   auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
347   std::string build_type = "in [";
348   std::for_each(std::begin(inputs_type), std::end(inputs_type),
349                 [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
350   build_type += "] out [";
351   std::for_each(std::begin(outputs_type), std::end(outputs_type),
352                 [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
353   build_type += "]";
354   auto supported_type_lists = SupportedTypeList(kernel_node);
355   MS_EXCEPTION(TypeError) << "Select GPU kernel op[" << kernel_name
356                           << "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists
357                           << ", but get " << build_type;
358 }
359 }  // namespace
360 
CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> & kernel_graph)361 void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
362   MS_EXCEPTION_IF_NULL(kernel_graph);
363   // TensorCore can be used only in Volta or newer devices.
364   const int marjor_sm = GET_MAJOR_SM;
365   if (marjor_sm < RECOMMEND_SM) {
366     format_transform_ = false;
367     return;
368   }
369   auto kernels = kernel_graph->execution_order();
370   size_t conv_cnt = 0;
371   size_t bn_cnt = 0;
372   for (const auto &kernel : kernels) {
373     auto kernel_name = AnfAlgo::GetCNodeName(kernel);
374     if (kernel_name == prim::kPrimLayerNorm->name()) {
375       format_transform_ = false;
376       return;
377     }
378     auto value = AnfAlgo::GetCNodePrimitive(kernel);
379     if (value != nullptr && value->GetAttr("format") != nullptr &&
380         GetValue<std::string>(value->GetAttr("format")) == kOpFormat_NHWC) {
381       format_transform_ = false;
382       return;
383     }
384     if (kernel_name == prim::kPrimConv2D->name()) {
385       conv_cnt++;
386     }
387     if (kernel_name == prim::kPrimBatchNorm->name()) {
388       bn_cnt++;
389     }
390   }
391   if (conv_cnt + bn_cnt > 1) {
392     format_transform_ = true;
393     return;
394   }
395   format_transform_ = false;
396 }
397 
// Selects and attaches kernel build info (formats, device types, kernel type,
// processor) for a node. Graph kernels are delegated to SetGraphKernelInfo.
// Selection order for UNKNOWN_KERNEL_TYPE: exact native GPU kernel ->
// reduced-precision native kernel -> AKG kernel. Throws (via
// PrintUnsupportedTypeException) when nothing matches, except for control ops
// executed in the backend, which are allowed through unresolved.
void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  if (AnfAlgo::IsGraphKernel(kernel_node)) {
    auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(kernel_node);
    MS_EXCEPTION_IF_NULL(func_graph);
    SetGraphKernelInfo(kernel_node, func_graph);
    return;
  }
  // Start from default format with inferred data types for all inputs/outputs.
  std::vector<std::string> inputs_format;
  std::vector<TypeId> inputs_type;
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  for (size_t input_index = 0; input_index < input_num; ++input_index) {
    inputs_format.emplace_back(kOpFormat_DEFAULT);
    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
  }
  std::vector<std::string> outputs_format;
  std::vector<TypeId> outputs_type;
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  for (size_t output_index = 0; output_index < output_num; ++output_index) {
    outputs_format.emplace_back(kOpFormat_DEFAULT);
    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
  }
  std::string origin_data_format = kOpFormat_DEFAULT;
  // Optionally rewrite formats to NHWC/NCHW for kernels that benefit from it.
  if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
    UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format);
  }
  auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
  builder->SetOriginDataFormat(origin_data_format);
  builder->SetInputsFormat(inputs_format);
  builder->SetInputsDeviceType(inputs_type);
  builder->SetOutputsFormat(outputs_format);
  builder->SetOutputsDeviceType(outputs_type);
  bool result = false;
  if (kernel_type == UNKNOWN_KERNEL_TYPE) {
    // 1) exact match among registered native GPU kernels.
    result =
      kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build());
    if (!result) {
      // 2) retry with reduced precision (may mutate the builder's types).
      result = kernel::GpuKernelFactory::GetInstance().ReducePrecision(AnfAlgo::GetCNodeName(kernel_node), builder);
    }
    if (!result && (!AnfAlgo::IsControlOpExecInBackend(kernel_node))) {
      // 3) fall back to AKG; note kernel_type is updated for the final build.
      result = SelectAkgKernel(kernel_node, builder->Build());
      kernel_type = AKG_KERNEL;
    }
  } else if (kernel_type == AKG_KERNEL) {
    result = SelectAkgKernel(kernel_node, builder->Build());
  }
  if (!result && (!AnfAlgo::IsControlOpExecInBackend(kernel_node))) {
    PrintUnsupportedTypeException(kernel_node, inputs_type, outputs_type);
    return;
  }
  builder->SetKernelType(kernel_type);
  builder->SetProcessor(kernel::Processor::CUDA);
  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
  // Propagate the selected formats/types to the parameter inputs.
  SetTensorDeviceInfo(*(builder->Build()), kernel_node);
}
453 }  // namespace gpu
454 }  // namespace device
455 }  // namespace mindspore
456