/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/mem_reuse/mem_reuse.h"
#include <algorithm>
#include <memory>
#include <utility>
#include "utils/context/graph_kernel_flags.h"
#include "backend/optimizer/mem_reuse/mem_reuse_checker.h"
#include "backend/optimizer/common/helper.h"

namespace mindspore {
namespace memreuse {
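// Build a KernelRefCount record for every output of every kernel in execution order, tagging each output as a
// communication, ref-node, or common output. Only the sizes are recorded; no device memory is allocated here.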
bool MemReuseUtil::InitDynamicOutputKernelRef() {
  int index = util_index_;
  auto kernel_cnodes = graph_->execution_order();
  if (kernel_cnodes.empty()) {
    return true;
  }
  int kernel_out_ref_num = 0;
  for (auto &kernel_cnode : kernel_cnodes) {
#ifdef MEM_REUSE_DEBUG
    MemReuseChecker::GetInstance().CheckSignalOps(kernel_cnode);
#endif
    if (kernel_cnode == nullptr) {
      return false;
    }
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel_cnode);
    if (kernel_mod == nullptr) {
      return false;
    }
    auto key = kernel_cnode.get();
    // Create new output records for every apply kernel.
    auto iter = kernel_output_refs_.find(key);
    if (iter == kernel_output_refs_.end()) {
      auto output_sizes = kernel_mod->GetOutputSizeList();
      KernelRefCountPtrList kernel_refs;
      bool is_comm_op = AnfAlgo::IsCommunicationOp(kernel_cnode);
      size_t output_index = 0;
      for (auto size : output_sizes) {
        total_dy_size_ += size;
        // Do not call MallocDynamicMem here; just record the size.
        KernelRefCountPtr kernel_ref = std::make_shared<KernelRefCount>();
        index++;
        auto curr_stream_id = AnfAlgo::GetStreamId(kernel_cnode);
        kernel_ref->stream_id_ = curr_stream_id;
        kernel_ref->SetKernelRefCountInfo(index, size, kDynamicRefCount);
        if (is_comm_op) {
          kernel_ref->type_ = kCommReuse;
        } else {
          session::AnfWithOutIndex out_pair(kernel_cnode, output_index);
          if (graph_->IsInRefOutputMap(out_pair)) {
            kernel_ref->type_ = kRefNodeOutput;
            auto origin_pair = graph_->GetRefCorrespondOutput(out_pair);
            MS_EXCEPTION_IF_NULL(origin_pair.first);
            MS_LOG(INFO) << "REF origin op is " << origin_pair.first->fullname_with_scope() << ", output index is "
                         << origin_pair.second << ", cur op is " << kernel_cnode->fullname_with_scope()
                         << ", out index is " << output_index;
            if (origin_pair.first->isa<CNode>()) {
              auto cnode = origin_pair.first->cast<CNodePtr>();
              auto ref_ptr = GetRef(cnode, origin_pair.second);
              if (ref_ptr != nullptr) {
                ref_ptr->type_ = kRefNodeInput;
              }
            }
          } else {
            kernel_ref->type_ = kCommon;
          }
        }
        kernel_refs.push_back(kernel_ref);
        kernel_out_ref_num++;
        total_refs_list_.push_back(kernel_ref);
        output_index++;
      }
      if (!kernel_refs.empty()) {
        kernel_output_refs_[key] = kernel_refs;
      }
    }
  }
  return true;
}

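// Build a KernelRefCount record for every workspace of every kernel in execution order and accumulate the total
// workspace size.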
bool MemReuseUtil::InitDynamicWorkspaceKernelRef() {
  int wk_index = util_index_;
  auto kernel_cnodes = graph_->execution_order();
  if (kernel_cnodes.empty()) {
    return true;
  }
  for (auto &kernel_cnode : kernel_cnodes) {
    if (kernel_cnode == nullptr) {
      return false;
    }
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel_cnode);
    if (kernel_mod == nullptr) {
      return false;
    }
    auto key = kernel_cnode.get();
    auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
    KernelRefCountPtrList workspace_kernel_refs;
    for (auto size : workspace_sizes) {
      total_workspace_size_ += size;
      ++wk_index;
      KernelRefCountPtr workspace_ref = std::make_shared<KernelRefCount>();
      workspace_ref->SetKernelRefCountInfo(wk_index, size, kDynamicRefCount);
      workspace_kernel_refs.push_back(workspace_ref);
      // Keep the overall workspace reference list.
      total_wk_ref_list_.push_back(workspace_ref);
    }
    if (!workspace_kernel_refs.empty()) {
      // Index every kernel's workspace references by its key.
      kernel_workspace_refs_[key] = workspace_kernel_refs;
    }
  }
  return true;
}

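// Bind the kernel graph and build the dynamic output and workspace reference records.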
bool MemReuseUtil::InitDynamicKernelRef(const KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  graph_ = graph;
  is_all_nop_node_ = opt::IsAllNopNode(graph);
  if (!InitDynamicOutputKernelRef()) {
    MS_LOG(INFO) << "InitDynamicOutputKernelRef failed";
    return false;
  }
  if (!InitDynamicWorkspaceKernelRef()) {
    MS_LOG(INFO) << "InitDynamicWorkspaceKernelRef failed";
    return false;
  }
  return true;
}

// Set the longest workspace list and the largest workspace sizes: the reused workspace list is as long as the
// longest per-kernel workspace list and is filled with the largest sizes seen across all kernels.
void MemReuseUtil::SetWorkSpaceList() {
  int max_list_size = 0;
  std::vector<size_t> total_sizes;
  std::vector<size_t> max_list;
  auto kernel_cnodes = graph_->execution_order();
  for (auto &kernel_cnode : kernel_cnodes) {
    MS_EXCEPTION_IF_NULL(kernel_cnode);
    auto cnode_key = kernel_cnode.get();
    auto cnode_iter = kernel_workspace_refs_.find(cnode_key);
    if (cnode_iter != kernel_workspace_refs_.end()) {
      auto kernel_refs = cnode_iter->second;
      std::vector<size_t> current_list;
      for (size_t i = 0; i < kernel_refs.size(); ++i) {
        auto size = kernel_refs[i]->size_;
        current_list.push_back(size);
      }
      if (max_list_size < SizeToInt(current_list.size())) {
        max_list_size = SizeToInt(current_list.size());
      }
      (void)std::copy(current_list.begin(), current_list.end(), std::back_inserter(total_sizes));
    }
  }
  sort(total_sizes.rbegin(), total_sizes.rend());
  max_list.resize(IntToSize(max_list_size));
  if (SizeToInt(total_sizes.size()) < max_list_size) {
    MS_LOG(EXCEPTION) << "The number of workspace sizes is less than the required max list size";
  }
  max_list.assign(total_sizes.begin(), total_sizes.begin() + max_list_size);
  for (auto &ma : max_list) {
    total_reuseworkspace_size_ += ma;
  }
  max_workspace_size_ = max_list_size;
  max_workspace_list_ = max_list;
}

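// Record the kernel's dynamic-ref-count inputs in the KernelDef's input map; a communication op's inputs are
// marked reusable only when the op has a single input.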
void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(kernel_def_ptr);
  auto key = kernel.get();
  bool is_comm_op = AnfAlgo::IsCommunicationOp(kernel);
  size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel);
  for (size_t i = 0; i < input_tensor_num; ++i) {
    auto ref_ptr = GetKernelInputRef(kernel, i);
    if (ref_ptr != nullptr) {
      if (is_comm_op) {
        if (input_tensor_num == 1) {
          ref_ptr->type_ = kCommReuse;
        } else {
          ref_ptr->type_ = kCommNotReuse;
        }
      }

      if (ref_ptr->reftype() == kStaticRefCount) {
        continue;
      } else if (ref_ptr->reftype() == kDynamicRefCount) {
        auto iter = kernel_def_ptr->inputs_.find(key);
        if (iter == kernel_def_ptr->inputs_.end()) {
          kernel_def_ptr->inputs_[key].push_back(ref_ptr);
        } else {
          iter->second.push_back(ref_ptr);
        }
      }
    }
  }
}

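// Record every output KernelRefCount of the kernel in the KernelDef's output map.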
void MemReuseUtil::SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(kernel_def_ptr);
  auto key = kernel.get();
  auto iter = kernel_def_ptr->outputs_.find(key);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  for (size_t k = 0; k < kernel_mod->GetOutputSizeList().size(); ++k) {
    KernelRefCountPtr kernel_ref = kernel_output_refs_[key][k];
    if (iter == kernel_def_ptr->outputs_.end()) {
      kernel_def_ptr->outputs_[key].push_back(kernel_ref);
    } else {
      iter->second.push_back(kernel_ref);
    }
  }
}

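// Record every workspace KernelRefCount of the kernel in the KernelDef's workspace map.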
void MemReuseUtil::SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(kernel_def_ptr);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto key = kernel.get();
  for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
    if (kernel_workspace_refs_.find(key) != kernel_workspace_refs_.end()) {
      auto wk_refs = kernel_workspace_refs_[key];
      if (i < wk_refs.size()) {
        auto wk_ref = wk_refs[i];
        kernel_def_ptr->wk_space_[key].push_back(wk_ref);
      } else {
        MS_LOG(EXCEPTION) << "Current index: " << i << " is larger than the wk_refs size " << wk_refs.size();
      }
    } else {
      MS_LOG(EXCEPTION) << "kernel_workspace_refs_ is not initialized";
    }
  }
}

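// Return the KernelRefCount of the node's output at output_idx; non-CNodes and monad nodes yield nullptr.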
KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, size_t output_idx) {
  if (node == nullptr) {
    MS_LOG(EXCEPTION) << "The node pointer is a nullptr.";
  }
  // Get ref count for cnode, except monad cnode.
  if (node->isa<CNode>() && !HasAbstractMonad(node)) {
    auto ak_node = node->cast<CNodePtr>();
    auto key = ak_node.get();
    MemReuseChecker::GetInstance().CheckOutRef(kernel_output_refs_, ak_node, output_idx);
    return kernel_output_refs_[key][output_idx];
  }
  return nullptr;
}

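// Resolve the real kernel that produces the kernel's input at input_idx and return that output's KernelRefCount.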
KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t input_idx) {
  if (input_idx >= AnfAlgo::GetInputTensorNum(kernel)) {
    MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number "
                      << AnfAlgo::GetInputTensorNum(kernel);
  }
  auto input_node = kernel->input(input_idx + 1);
  // The graph may consist entirely of nop nodes that were not removed, so nop nodes cannot be skipped
  // unconditionally here.
  session::KernelWithIndex kernel_input;
  if (is_all_nop_node_) {
    // The graph does not remove the nop node.
    kernel_input = VisitKernelWithReturnType(input_node, 0, false);
  } else {
    // The graph removes the nop node.
    kernel_input = VisitKernelWithReturnType(input_node, 0, true);
  }
  if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) {
    MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple";
  }
  auto result = GetRef(kernel_input.first, kernel_input.second);
  return result;
}

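// Build a KernelDef for every kernel in execution order, fill its input/output/workspace maps, and index it by
// the CNode address.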
void MemReuseUtil::SetKernelDefMap() {
  auto kernel_cnodes = graph_->execution_order();
  for (auto &kernel : kernel_cnodes) {
    KernelDefPtr kernel_def_ptr = std::make_shared<KernelDef>();
    kernel_def_ptr->set_kernel_name(AnfAlgo::GetCNodeName(kernel));
    kernel_def_ptr->set_scope_full_name(kernel->fullname_with_scope());
    kernel_def_ptr->set_stream_id(AnfAlgo::GetStreamId(kernel));
    SetInputMap(kernel, kernel_def_ptr.get());
    SetOutputMap(kernel, kernel_def_ptr.get());
    SetWkMap(kernel, kernel_def_ptr.get());
    auto key = kernel.get();
    kernel_def_ptr->set_input_refs(kernel_def_ptr->inputs_[key]);
    kernel_def_ptr->set_output_refs(kernel_def_ptr->outputs_[key]);
    if (AnfAlgo::IsCommunicationOp(kernel)) {
      kernel_def_ptr->type_ = kCommunicationNode;
    } else {
      kernel_def_ptr->type_ = kCommonNode;
    }
    kernel_def_ptr_list_.push_back(kernel_def_ptr);
    kernel_map_[key] = kernel_def_ptr;
  }
  SetKernelDefInputs();
}

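// Link every KernelDef to the KernelDefs of the kernels that produce its inputs.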
void MemReuseUtil::SetKernelDefInputs() {
  for (const auto &kernel : graph_->execution_order()) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto key = kernel.get();
    // Find the kernel_def according to the cnode address.
    auto iter = kernel_map_.find(key);
    if (iter == kernel_map_.end()) {
      MS_LOG(EXCEPTION) << "kernel [" << kernel->fullname_with_scope() << "] has not been initialized.";
    }
    auto kernel_def = iter->second;
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
    for (size_t i = 0; i < input_num; ++i) {
      auto ref_ptr = GetKernelInputRef(kernel, i);
      if (ref_ptr != nullptr) {
        // Set the inputs of this kernel_def.
        auto input_node = AnfAlgo::GetInputNode(kernel, i);
        // The graph may consist entirely of nop nodes that were not removed, so nop nodes cannot be skipped
        // unconditionally here.
        session::KernelWithIndex input;
        if (is_all_nop_node_) {
          // The graph does not remove the nop node.
          input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
        } else {
          // The graph removes the nop node.
          input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
        }
        if (IsPrimitive(input.first, prim::kPrimMakeTuple)) {
          MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple";
        }
        auto input_key = (input.first).get();
        auto input_iter = kernel_map_.find(input_key);
        if (input_iter == kernel_map_.end()) {
          MS_LOG(EXCEPTION) << "kernel [" << (input.first)->fullname_with_scope() << "] has not been initialized.";
        }
        kernel_def->InsertInputKernel(input_iter->second);
      }
    }
  }
}

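// Increment each tensor's ref count once for every kernel input that consumes it.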
void MemReuseUtil::SetReuseRefCount() {
  auto kernels = graph_->execution_order();
  for (auto &kernel : kernels) {
    auto key = kernel.get();
    for (auto &def : kernel_def_ptr_list_) {
      auto iter = def->inputs_.find(key);
      if (iter != def->inputs_.end()) {
        for (auto &input : iter->second) {
          input->ref_count_++;
          input->ref_count_dynamic_use_++;
        }
      }
    }
  }
}

#ifndef ENABLE_SECURITY
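// Pin the outputs feeding summary nodes by raising their ref counts to kMaxRefCount so their memory is never reused.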
void MemReuseUtil::SetSummaryNodesRefCount() {
  bool summary_exist = graph_->summary_node_exist();
  if (!summary_exist) {
    return;
  }

  auto summary_nodes = graph_->summary_nodes();
  if (summary_nodes.empty()) {
    return;
  }

  size_t total_summary_size = 0;
  for (auto &node_item : summary_nodes) {
    auto node = node_item.second.first;
    size_t index = IntToSize(node_item.second.second);
    if (kernel_output_refs_.find(node.get()) != kernel_output_refs_.end()) {
      KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index];
      kernel_ref->ref_count_ = kMaxRefCount;
      kernel_ref->ref_count_dynamic_use_ = kMaxRefCount;
      kernel_ref->type_ = kSummary;
      total_summary_size += kernel_ref->size_;
      MS_LOG(INFO) << "Set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index;
    } else {
      MS_LOG(INFO) << "Can't find summary node's kernel_def " << node->fullname_with_scope() << " index: " << index;
    }
  }
#ifdef MEM_REUSE_DEBUG
  MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph_);
#endif
  MS_LOG(INFO) << "Special Tensor total size: SummaryNodes: " << total_summary_size;
}
#endif

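// Pin the inputs of ref nodes so the memory they alias is never reused.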
void MemReuseUtil::SetRefNodesInputRefCount() {
  size_t total_size = 0;
  for (auto iter : kernel_output_refs_) {
    for (auto &ref_count : iter.second) {
      MS_EXCEPTION_IF_NULL(ref_count);
      if (ref_count->type_ == kRefNodeInput) {
        ref_count->ref_count_ = kMaxRefCount;
        total_size += ref_count->size_;
      }
    }
  }

  MS_LOG(INFO) << "Special Tensor total size: RefNodeInput: " << total_size;
#ifdef MEM_REUSE_DEBUG
  MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph_);
#endif
}

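// Pin the graph outputs by raising their ref counts to kMaxRefCount so their memory is never reused.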
void MemReuseUtil::SetGraphOutputRefCount() {
  auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem});
  for (const auto &node : nodes) {
    session::KernelWithIndex kernel_input;
    if (is_all_nop_node_) {
      // The graph does not remove the nop node.
      kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, false);
    } else {
      // The graph removes the nop node.
      kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
    }
    MS_EXCEPTION_IF_NULL(kernel_input.first);
    if (!kernel_input.first->isa<CNode>() || !AnfAlgo::IsRealKernel(kernel_input.first)) {
      continue;
    }
    auto ak_node = kernel_input.first->cast<CNodePtr>();
    auto key = ak_node.get();
    auto iter = kernel_output_refs_.find(key);
    if ((iter != kernel_output_refs_.end()) && (kernel_input.second < iter->second.size())) {
      auto kernel_ref_count_ptr = kernel_output_refs_[key][kernel_input.second];
      MS_EXCEPTION_IF_NULL(kernel_ref_count_ptr);
      kernel_ref_count_ptr->ref_count_ = kMaxRefCount;
      kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount;
    }
  }
#ifdef MEM_REUSE_DEBUG
  MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph_);
#endif
}

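// Reset each output's dynamic-use counter back to its static ref count.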
void MemReuseUtil::ResetDynamicUsedRefCount() {
  for (auto iter = kernel_output_refs_.begin(); iter != kernel_output_refs_.end(); ++iter) {
    for (auto &ref_count : iter->second) {
      MS_EXCEPTION_IF_NULL(ref_count);
      ref_count->ref_count_dynamic_use_ = ref_count->ref_count_;
    }
  }
}

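// Top-level entry: run all of the reference-record and ref-count passes for the graph in one call.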
void MemReuseUtil::SetAllInfo(const KernelGraph *graph) {
  if (!InitDynamicKernelRef(graph)) {
    MS_LOG(EXCEPTION) << "Init ReuseAssignDynamicMemory failed";
  }
  SetKernelDefMap();
  SetReuseRefCount();
#ifndef ENABLE_SECURITY
  SetSummaryNodesRefCount();
#endif
  SetRefNodesInputRefCount();
  SetWorkSpaceList();
#ifdef MEM_REUSE_DEBUG
  MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph);
#endif

  enable_visit_kernel_cache_ = context::GraphKernelFlags::GetInstance().IsEnableGraphKernel();
}

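// Translate the reuse offset of the node's output at the given index into an absolute device address.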
uint8_t *MemReuseUtil::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const {
  auto key = node.get();
  auto iter = kernel_output_refs_.find(key);
  uint8_t *ptr = nullptr;
  if (iter != kernel_output_refs_.end()) {
    if (index >= iter->second.size()) {
      MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than the node's output size:[" << iter->second.size()
                        << "]";
    }
    auto output_ref = iter->second[index];
    ptr = mem_base_ + output_ref->offset_;
  } else {
    MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] does not exist in kernel_output_refs_";
  }
  return ptr;
}

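// Translate the reuse offset of the node's workspace at the given index into an absolute device address;
// return nullptr if the node has no workspace records.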
uint8_t *MemReuseUtil::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const {
  auto key = node.get();
  auto iter = kernel_workspace_refs_.find(key);
  uint8_t *ptr = nullptr;
  if (iter != kernel_workspace_refs_.end()) {
    if (index >= iter->second.size()) {
      MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than its workspace size:[" << iter->second.size() << "]";
    }
    auto wk_ref = iter->second[index];
    ptr = mem_base_ + wk_ref->offset_;
  }
  return ptr;
}

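// A cached wrapper of AnfAlgo::VisitKernelWithReturnType for index-0 queries; the cache is enabled only when
// graph kernel fusion is on.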
session::KernelWithIndex MemReuseUtil::VisitKernelWithReturnType(const AnfNodePtr &node, size_t i,
                                                                 bool visit_nop_node) {
  if (!enable_visit_kernel_cache_ || i != 0) {
    return AnfAlgo::VisitKernelWithReturnType(node, i, visit_nop_node);
  }

  auto &cache =
    visit_nop_node ? visit_kernel_with_return_type_in0pos_cache_ : visit_kernel_with_return_type_in0pos_skip_nop_cache_;
  std::unordered_map<AnfNodePtr, session::KernelWithIndex>::iterator tag_iter;
  if (auto iter = cache.find(node); iter == cache.end()) {
    auto tmp_item = std::pair<AnfNodePtr, session::KernelWithIndex>{
      node, AnfAlgo::VisitKernelWithReturnType(node, i, visit_nop_node)};
    tag_iter = cache.emplace(tmp_item).first;
  } else {
    tag_iter = iter;
  }
  return tag_iter->second;
}
}  // namespace memreuse
}  // namespace mindspore