1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/device/ascend/kernel_build_ascend.h"
18
19 #include <vector>
20 #include <string>
21 #include <memory>
22 #include <set>
23 #include <map>
24 #include "runtime/device/ascend/kernel_select_ascend.h"
25 #include "runtime/device/kernel_info.h"
26 #include "backend/kernel_compiler/kernel.h"
27 #include "backend/kernel_compiler/tbe/ascend_kernel_compile.h"
28 #include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h"
29 #include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h"
30 #include "backend/kernel_compiler/aicpu/aicpu_kernel_build.h"
31 #include "backend/kernel_compiler/host/host_kernel_build.h"
32 #include "backend/kernel_compiler/hccl/hccl_kernel_build.h"
33 #include "backend/kernel_compiler/rts/rt_kernel_build.h"
34 #include "backend/kernel_compiler/tbe/tbe_utils.h"
35 #include "backend/kernel_compiler/common_utils.h"
36 #include "frontend/operator/ops.h"
37 #include "backend/session/anf_runtime_algorithm.h"
38
39 namespace mindspore {
40 namespace device {
41 namespace ascend {
42 using mindspore::kernel::tbe::TbeUtils;
43 using std::make_shared;
44 constexpr size_t kMaxAttrMemListSize = 192;
45
SerialCompileImpl(const AnfNodePtr & anf_node)46 static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) {
47 kernel::KernelModPtr kernel_mod_ptr = nullptr;
48 KernelType kernel_type = AnfAlgo::GetKernelType(anf_node);
49 switch (kernel_type) {
50 case KernelType::AICPU_KERNEL: {
51 kernel_mod_ptr = kernel::AicpuOpBuild(anf_node);
52 break;
53 }
54 case KernelType::HOST_KERNEL: {
55 kernel_mod_ptr = kernel::HostOpBuild(anf_node);
56 break;
57 }
58 case KernelType::RT_KERNEL: {
59 kernel_mod_ptr = kernel::RtOpBuild(anf_node);
60 break;
61 }
62 case KernelType::HCCL_KERNEL: {
63 kernel_mod_ptr = kernel::HcclOpBuild(anf_node);
64 break;
65 }
66 default: {
67 MS_EXCEPTION_IF_NULL(anf_node);
68 MS_LOG(EXCEPTION) << "node [" << anf_node->DebugString() << "] Unsupported kernel_type:" << kernel_type;
69 }
70 }
71 return kernel_mod_ptr;
72 }
73
KernelBuildParallelCompile(const std::vector<CNodePtr> & kernels)74 static bool KernelBuildParallelCompile(const std::vector<CNodePtr> &kernels) {
75 std::vector<AnfNodePtr> tbe_nodes;
76 std::vector<AnfNodePtr> akg_nodes;
77 std::vector<AnfNodePtr> other_nodes;
78 for (const auto &anf_node : kernels) {
79 MS_EXCEPTION_IF_NULL(anf_node);
80 if (!AnfAlgo::IsRealKernel(anf_node)) {
81 continue;
82 }
83 KernelType kernel_type = AnfAlgo::GetKernelType(anf_node);
84 switch (kernel_type) {
85 case KernelType::TBE_KERNEL: {
86 if (AnfAlgo::GetKernelMod(anf_node) == nullptr) {
87 tbe_nodes.push_back(anf_node);
88 }
89 break;
90 }
91 case KernelType::AKG_KERNEL: {
92 akg_nodes.push_back(anf_node);
93 break;
94 }
95 default: {
96 other_nodes.push_back(anf_node);
97 break;
98 }
99 }
100 }
101 bool tbe_ret = true;
102 bool akg_ret = true;
103 auto bin_map = kernel::tbe::KernelMeta::GetInstance();
104 MS_EXCEPTION_IF_NULL(bin_map);
105 if (!tbe_nodes.empty()) {
106 std::string old_build = common::GetEnv("MS_OLD_BUILD_PROCESS");
107 if (!old_build.empty()) {
108 tbe_ret = kernel::TbeOpParallelBuild(tbe_nodes);
109 } else {
110 auto &build_manager = kernel::ascend::AscendKernelCompileManager::GetInstance();
111 tbe_ret = build_manager.AscendSingleOpCompile(tbe_nodes);
112 build_manager.ResetOldTask();
113 }
114 auto config_path = TbeUtils::GetOpDebugPath();
115 std::string dir = config_path + "kernel_meta/";
116 (void)bin_map->ReadIndex(dir);
117 }
118 if (!akg_nodes.empty()) {
119 kernel::AkgAscendKernelBuilder akg_ascend_kernel_builder;
120 akg_ret = akg_ascend_kernel_builder.AkgKernelParallelBuild(akg_nodes);
121 (void)bin_map->ReadIndex(kernel::kCceKernelMeta);
122 }
123 for (const auto &anf_node : other_nodes) {
124 kernel::KernelModPtr kernel_mod_ptr = SerialCompileImpl(anf_node);
125 MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
126 AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
127 }
128 return tbe_ret && akg_ret;
129 }
130
CalCleanZerosSize(const CNodePtr & pre_node)131 static std::vector<size_t> CalCleanZerosSize(const CNodePtr &pre_node) {
132 MS_EXCEPTION_IF_NULL(pre_node);
133 auto kernel_mod = AnfAlgo::GetKernelMod(pre_node);
134 MS_EXCEPTION_IF_NULL(kernel_mod);
135 std::vector<size_t> clean_size_list;
136 constexpr size_t kAlignBytes = 32 - 1;
137 // clean output
138 if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
139 auto output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
140 auto output_men_size = kernel_mod->GetOutputSizeList();
141 for (auto index : output_indexs) {
142 auto clean_item = (output_men_size.at(index) + kMemAlignSize + kAlignBytes) / kMemAlignSize * kMemAlignSize;
143 clean_size_list.emplace_back(clean_item);
144 }
145 }
146 // clean workspace
147 if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
148 auto workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
149 auto workspace_men_sizes = kernel_mod->GetWorkspaceSizeList();
150 for (const auto &index : workspace_indexs) {
151 auto clean_item = (workspace_men_sizes.at(index) + kMemAlignSize + kAlignBytes) / kMemAlignSize * kMemAlignSize;
152 clean_size_list.emplace_back(clean_item);
153 }
154 }
155 MS_LOG(INFO) << "clear output size:" << clean_size_list.size() << ",pre_node:" << pre_node->fullname_with_scope();
156 return clean_size_list;
157 }
158
AddTbeClearZeroNode(mindspore::session::KernelGraph * const kernel_graph,const mindspore::CNodePtr & pre_node,std::vector<mindspore::CNodePtr> * new_nodes)159 static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph,
160 const mindspore::CNodePtr &pre_node, std::vector<mindspore::CNodePtr> *new_nodes) {
161 MS_EXCEPTION_IF_NULL(kernel_graph);
162 MS_EXCEPTION_IF_NULL(pre_node);
163 MS_EXCEPTION_IF_NULL(new_nodes);
164 auto clear_zero_prim = std::make_shared<Primitive>(kAtomicAddrCleanOpName);
165 MS_EXCEPTION_IF_NULL(clear_zero_prim);
166 auto new_value_node = NewValueNode(clear_zero_prim);
167 MS_EXCEPTION_IF_NULL(new_value_node);
168 std::vector<AnfNodePtr> inputs = {new_value_node};
169 inputs.push_back(pre_node);
170 CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
171 MS_EXCEPTION_IF_NULL(clear_zero);
172 AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
173 MS_EXCEPTION_IF_NULL(abstract);
174 clear_zero->set_abstract(abstract);
175 auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
176 MS_EXCEPTION_IF_NULL(builder);
177 builder->SetKernelType(KernelType::TBE_KERNEL);
178 AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get());
179 auto clean_size = CalCleanZerosSize(pre_node);
180 AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size), clear_zero);
181 AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clear_zero.get());
182 new_nodes->push_back(clear_zero);
183 }
184
AddFusionTbeClearZeroNode(mindspore::session::KernelGraph * const kernel_graph,const mindspore::CNodePtr & first_clear_node,const std::vector<AnfNodePtr> & fusion_clear_inputs,const std::vector<size_t> & clean_size_list,std::vector<mindspore::CNodePtr> * new_nodes)185 static void AddFusionTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph,
186 const mindspore::CNodePtr &first_clear_node,
187 const std::vector<AnfNodePtr> &fusion_clear_inputs,
188 const std::vector<size_t> &clean_size_list,
189 std::vector<mindspore::CNodePtr> *new_nodes) {
190 MS_EXCEPTION_IF_NULL(first_clear_node);
191 auto clear_zero_prim = std::make_shared<Primitive>(kAtomicAddrCleanOpName);
192 MS_EXCEPTION_IF_NULL(clear_zero_prim);
193 auto new_value_node = NewValueNode(clear_zero_prim);
194 MS_EXCEPTION_IF_NULL(new_value_node);
195 std::vector<AnfNodePtr> inputs = {new_value_node};
196 inputs.insert(inputs.end(), fusion_clear_inputs.begin(), fusion_clear_inputs.end());
197 CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
198 MS_EXCEPTION_IF_NULL(clear_zero);
199 AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
200 MS_EXCEPTION_IF_NULL(abstract);
201 clear_zero->set_abstract(abstract);
202 auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
203 MS_EXCEPTION_IF_NULL(builder);
204 builder->SetKernelType(KernelType::TBE_KERNEL);
205 AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get());
206 AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size_list), clear_zero);
207 AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(first_clear_node.get()), clear_zero.get());
208 auto it = std::find(new_nodes->begin(), new_nodes->end(), first_clear_node);
209 if (it != new_nodes->end()) {
210 new_nodes->insert(it, clear_zero);
211 } else {
212 new_nodes->insert(new_nodes->begin(), clear_zero);
213 }
214 }
215
// Decide whether kernel_node needs an atomic-addr-clean: GenParameters()
// yields one flag per input/output/workspace slot, where 1 means "this
// buffer must be zeroed before the kernel runs". As a side effect, the
// flagged output/workspace indexes are stored on the node as
// kAttrAtomicOutputIndexs / kAttrAtomicWorkspaceIndexs.
static bool IsAtomicNode(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto parameters_indexs = kernel_mod->GenParameters();
  if (parameters_indexs.empty()) {
    return false;
  }
  if (AnfAlgo::IsDynamicShape(kernel_node)) {
    // Dynamic-shape kernels carry one extra flag; drop it from whichever end
    // it occupies so the remaining flags line up with input/output/workspace.
    // NOTE(review): assumes the extra flag is at the front only when set —
    // confirm against GenParameters' contract.
    if (parameters_indexs.at(0) == 1) {
      (void)parameters_indexs.erase(parameters_indexs.begin());
    } else {
      parameters_indexs.pop_back();
    }
  }
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size();
  size_t param_num = parameters_indexs.size();
  size_t total_num = input_num + output_num + workspace_num;
  size_t pad_index = param_num;

  // Pad missing trailing flags with 0 (no clean needed for those slots).
  for (; pad_index < total_num; ++pad_index) {
    parameters_indexs.emplace_back(0);
  }

  // Inputs may never be cleaned; a set flag on an input slot is a hard error.
  for (size_t j = 0; j < input_num; ++j) {
    if (parameters_indexs.at(j) == 1) {
      MS_LOG(EXCEPTION) << "Atomic addr clean doesn't support clean input address, input index: " << j;
    }
  }

  // NOTE(review): the padding loop above already guarantees
  // parameters_indexs.size() >= total_num, so this branch looks unreachable;
  // kept as a defensive check.
  if (parameters_indexs.size() < total_num) {
    MS_LOG(EXCEPTION) << "Parameters indexes size: " << parameters_indexs.size()
                      << " less than total num: " << total_num;
  }
  // process output
  // Seed from any indexes already recorded on the node (e.g. added earlier
  // for communication-op inputs).
  std::vector<size_t> output_indexs = {};
  if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, kernel_node)) {
    output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(kernel_node, kAttrAtomicOutputIndexs);
  }

  // Output slots follow the input slots in the flag list.
  for (size_t i = 0; i < output_num; ++i) {
    auto param_output = parameters_indexs.at(input_num + i);
    if (param_output == 1) {
      output_indexs.emplace_back(i);
      MS_LOG(INFO) << "Atomic clear output index: " << i;
    }
  }

  if (!output_indexs.empty()) {
    // Deduplicate (std::set also sorts) before writing the attr back.
    std::set<size_t> s(output_indexs.begin(), output_indexs.end());
    output_indexs.assign(s.begin(), s.end());
    AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(output_indexs), kernel_node);
  }
  // process workspace
  // Workspace slots follow input + output slots in the flag list.
  std::vector<size_t> workspace_indexs = {};
  for (size_t k = 0; k < workspace_num; ++k) {
    auto param_workspace = parameters_indexs.at(input_num + output_num + k);
    if (param_workspace == 1) {
      workspace_indexs.emplace_back(k);
      MS_LOG(INFO) << "Atomic clear workspace index: " << k;
    }
  }
  if (!workspace_indexs.empty()) {
    AnfAlgo::SetNodeAttr(kAttrAtomicWorkspaceIndexs, MakeValue(workspace_indexs), kernel_node);
  }
  return !(workspace_indexs.empty() && output_indexs.empty());
}
285
KernelBuild(const std::vector<CNodePtr> & kernels)286 bool KernelBuild(const std::vector<CNodePtr> &kernels) {
287 TbeUtils::LoadCache();
288 return device::ascend::KernelBuildParallelCompile(kernels);
289 }
290
GetCommunicationOpInputInfo(const mindspore::session::KernelGraph * kernel_graph)291 std::map<AnfNodePtr, std::vector<size_t>> GetCommunicationOpInputInfo(
292 const mindspore::session::KernelGraph *kernel_graph) {
293 MS_EXCEPTION_IF_NULL(kernel_graph);
294 std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map;
295 for (auto &kernel : kernel_graph->execution_order()) {
296 MS_EXCEPTION_IF_NULL(kernel);
297 auto input_num = AnfAlgo::GetInputTensorNum(kernel);
298 if (mindspore::session::AnfRuntimeAlgorithm::IsCommunicationOp(kernel)) {
299 for (size_t i = 0; i < input_num; i++) {
300 auto input_node = kernel->input(i + 1);
301 auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
302 MS_EXCEPTION_IF_NULL(kernel_input.first);
303 if (!kernel_input.first->isa<CNode>()) {
304 continue;
305 }
306 auto cnode = kernel_input.first->cast<CNodePtr>();
307 MS_EXCEPTION_IF_NULL(cnode);
308 if (AnfAlgo::IsCommunicationOp(cnode) || AnfAlgo::IsIndependentNode(cnode) ||
309 AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) {
310 // no need to add atomic for communication/independent/getnext op 's output
311 MS_LOG(INFO) << "No need to add atomic clean for op " << kernel_input.first->fullname_with_scope()
312 << "'s output";
313 continue;
314 }
315 MS_LOG(INFO) << "Add atomic clean for single communication op input, comm:" << kernel->fullname_with_scope()
316 << " input_node: " << kernel_input.first->fullname_with_scope()
317 << " index: " << kernel_input.second;
318 auto iter = comm_input_info_map.find(kernel_input.first);
319 if (iter != comm_input_info_map.end()) {
320 iter->second.push_back(kernel_input.second);
321 } else {
322 std::vector<size_t> indexes = {kernel_input.second};
323 comm_input_info_map[kernel_input.first] = indexes;
324 }
325 }
326 }
327 }
328
329 // remove duplicate index
330 for (auto &info : comm_input_info_map) {
331 std::set<size_t> s(info.second.begin(), info.second.end());
332 info.second.assign(s.begin(), s.end());
333 }
334
335 return comm_input_info_map;
336 }
337
IsNeedClearZeroNodeFusion(const size_t clean_total_num,const mindspore::CNodePtr & first_node,const mindspore::CNodePtr & current_node)338 bool IsNeedClearZeroNodeFusion(const size_t clean_total_num, const mindspore::CNodePtr &first_node,
339 const mindspore::CNodePtr ¤t_node) {
340 if (first_node == nullptr || current_node == nullptr) {
341 return false;
342 }
343 auto first_graph_id = AnfAlgo::GetGraphId(first_node.get());
344 auto current_graph_id = AnfAlgo::GetGraphId(current_node.get());
345 if (clean_total_num >= kMaxAttrMemListSize || first_graph_id != current_graph_id) {
346 return true;
347 }
348 return false;
349 }
350
// Rebuild the graph's execution order, batching atomic-clean requests from
// consecutive TBE/communication-input nodes into fused AtomicAddrClean nodes.
// A batch is flushed when it would exceed kMaxAttrMemListSize entries or
// cross a graph boundary (see IsNeedClearZeroNodeFusion). MaxPoolGrad on AKG
// gets its own dedicated ClearZero node instead.
static void TbeClearZeroNodeFusion(mindspore::session::KernelGraph *const kernel_graph) {
  std::vector<CNodePtr> new_nodes;
  std::vector<size_t> clean_size_list;
  std::vector<AnfNodePtr> fusion_clear_inputs;
  CNodePtr first_node = nullptr;
  std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph);
  for (const auto &anf_node : kernel_graph->execution_order()) {
    std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
    bool is_comm_input = false;
    // set communication input output index attr
    if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) {
      auto indexes = comm_input_info_map[anf_node];
      AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node);
      is_comm_input = true;
    }

    if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
        AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
      // Special case: build a standalone ClearZero node for AKG MaxPoolGrad
      // and select its kernel info directly (not part of the fused batch).
      auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
      MS_EXCEPTION_IF_NULL(clear_zero_prim);
      auto new_value_node = NewValueNode(clear_zero_prim);
      MS_EXCEPTION_IF_NULL(new_value_node);
      std::vector<AnfNodePtr> inputs = {new_value_node};
      inputs.push_back(anf_node);
      CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
      MS_EXCEPTION_IF_NULL(clear_zero);
      auto kernel_info = std::make_shared<device::KernelInfo>();
      MS_EXCEPTION_IF_NULL(kernel_info);
      clear_zero->set_kernel_info(kernel_info);
      AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
      MS_EXCEPTION_IF_NULL(abstract);
      AnfAlgo::SetNodeAttr("input_names", MakeValue(std::vector<std::string>({"x"})), clear_zero);
      SelectKernelInfo(clear_zero);
      // set the distinction label of clear same with anf
      AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get());
      new_nodes.push_back(clear_zero);
    } else if (is_comm_input ||
               (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL && IsAtomicNode(anf_node))) {
      // Node needs atomic clean: add its sizes to the pending fused batch.
      auto clean_sizes = CalCleanZerosSize(anf_node);
      if (!clean_sizes.empty()) {
        auto clean_total_num = clean_size_list.size() + clean_sizes.size();
        if (IsNeedClearZeroNodeFusion(clean_total_num, first_node, anf_node)) {
          // create clean node
          // Flush: emit the fused clean for the accumulated batch, then reset.
          AddFusionTbeClearZeroNode(kernel_graph, first_node, fusion_clear_inputs, clean_size_list, &new_nodes);
          clean_size_list.clear();
          fusion_clear_inputs.clear();
          first_node = nullptr;
        }
        if (fusion_clear_inputs.empty()) {
          // This node starts a new batch; the fused clean is later inserted
          // just before it in the execution order.
          first_node = anf_node;
        }
        clean_size_list.insert(clean_size_list.end(), clean_sizes.begin(), clean_sizes.end());
        fusion_clear_inputs.emplace_back(anf_node);
        MS_LOG(DEBUG) << "fusion_clear_inputs size: " << fusion_clear_inputs.size()
                      << ", clean_size_list: " << clean_size_list.size();
      }
    }
    new_nodes.emplace_back(anf_node);
  }

  // Flush the final pending batch, if any.
  if (!fusion_clear_inputs.empty() && !clean_size_list.empty()) {
    // create clean node
    AddFusionTbeClearZeroNode(kernel_graph, first_node, fusion_clear_inputs, clean_size_list, &new_nodes);
  }
  kernel_graph->set_execution_order(new_nodes);
}
417
// Pre-build graph rewrite: insert atomic-clean nodes into the execution
// order. With ENV_FUSION_CLEAR=1 (and a static-shape graph) the fused path
// is used; otherwise one AtomicAddrClean node is inserted per node that
// needs cleaning, plus a dedicated ClearZero node for AKG MaxPoolGrad.
void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // Read the env switch once per process.
  static const auto enable_fusion_clear = (common::GetEnv("ENV_FUSION_CLEAR") == "1");
  bool is_dynamic_graph = kernel_graph->is_dynamic_shape();
  if (!is_dynamic_graph && enable_fusion_clear) {
    TbeClearZeroNodeFusion(kernel_graph);
  } else {
    std::vector<CNodePtr> new_nodes;
    std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph);
    for (const auto &anf_node : kernel_graph->execution_order()) {
      std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
      bool is_comm_input = false;
      // Nodes feeding communication ops get their consumed output indexes
      // recorded so the clean covers them.
      if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) {
        auto indexes = comm_input_info_map[anf_node];
        AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node);
        is_comm_input = true;
      }

      if (is_comm_input) {
        // Clean node is appended before anf_node in the rebuilt order.
        AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes);
      } else if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
                 AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
        // Special case: standalone ClearZero node for AKG MaxPoolGrad, with
        // its kernel info selected directly.
        auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
        MS_EXCEPTION_IF_NULL(clear_zero_prim);
        auto new_value_node = NewValueNode(clear_zero_prim);
        MS_EXCEPTION_IF_NULL(new_value_node);
        std::vector<AnfNodePtr> inputs = {new_value_node};
        inputs.push_back(anf_node);
        CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
        MS_EXCEPTION_IF_NULL(clear_zero);
        auto kernel_info = std::make_shared<device::KernelInfo>();
        MS_EXCEPTION_IF_NULL(kernel_info);
        clear_zero->set_kernel_info(kernel_info);
        AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
        MS_EXCEPTION_IF_NULL(abstract);
        AnfAlgo::SetNodeAttr("input_names", MakeValue(std::vector<std::string>({"x"})), clear_zero);
        SelectKernelInfo(clear_zero);
        // set the distinction label of clear same with anf
        AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get());
        new_nodes.push_back(clear_zero);
      } else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL) {
        // TBE kernels get a clean node only if their parameter flags mark
        // outputs/workspaces for cleaning (side effect: sets the attrs).
        if (IsAtomicNode(anf_node)) {
          AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes);
        }
      }
      new_nodes.push_back(anf_node);
    }
    kernel_graph->set_execution_order(new_nodes);
  }
}
468 } // namespace ascend
469 } // namespace device
470 } // namespace mindspore
471