• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7 
8  * http://www.apache.org/licenses/LICENSE-2.0
9 
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15 */
16 
17 #include "backend/optimizer/somas/somas.h"
18 #include <algorithm>
19 #include <cstdio>
20 #include <fstream>
21 #include <iterator>
22 #include <memory>
23 #include <numeric>
24 #include <set>
25 
26 #include "backend/optimizer/somas/somas_node.h"
27 #include "backend/optimizer/somas/somas_solver_pre.h"
28 #include "backend/optimizer/somas/somas_stream.h"
29 #include "backend/optimizer/somas/somas_tensor.h"
30 #ifdef ENABLE_D
31 #include "runtime/device/ascend/ascend_stream_assign.h"
32 #endif
33 #include "backend/optimizer/common/helper.h"
34 #include "utils/ms_context.h"
35 #include "debug/common.h"
36 #ifdef ENABLE_DUMP_IR
37 #include "debug/rdr/running_data_recorder.h"
38 #endif
39 #include "common/thread_pool.h"
40 #ifndef ENABLE_SECURITY
41 #include "profiler/device/ascend/memory_profiling.h"
42 
43 using mindspore::profiler::ascend::MemoryProfiling;
44 using mindspore::profiler::ascend::NodeMemory;
45 using mindspore::profiler::ascend::TensorMemory;
46 #endif
47 namespace mindspore {
48 namespace somas {
// Numeric tuning constants.
constexpr int kGapSize = 512;
// NOTE: despite the "Seconds" suffix, LoadSomasResult passes this value to std::chrono::milliseconds.
constexpr int kRetryIntervalSeconds = 500;
constexpr size_t kRefNodeTensorNum = 2;

// Top-level keys of the cached SOMAS result json.
constexpr const char *kGraphId = "graph_id";
constexpr const char *kHashId = "hash_id";
constexpr const char *kMemOffset = "mem_offset";
constexpr const char *kNodeSize = "node_size";
constexpr const char *kTensorSize = "tensor_size";
constexpr const char *kContiguousSize = "contiguous_size";
constexpr const char *kRefNodeSize = "ref_node_size";
constexpr const char *kStreamSize = "stream_size";
constexpr const char *kStreamGroupSize = "stream_group_size";
constexpr const char *kTensors = "tensors";

// Keys of each entry stored under kTensors.
constexpr const char *kTensorId = "tensor_id";
constexpr const char *kSize = "size";
constexpr const char *kOriSize = "ori_size";
constexpr const char *kLifelongValue = "lifelong_value";
constexpr const char *kLifeStart = "life_start";
constexpr const char *kLifeEnd = "life_end";
constexpr const char *kOffset = "offset";
// Graphs with fewer tensors than this neither load nor save a cached result.
constexpr int kCachedResultThreshold = 2000;
72 
73 std::map<TensorType, std::string> tensor_type_name_map = {{kCommon, "Common"},
74                                                           {kOutputOnly, "OutputOnly"},
75                                                           {kWorkspace, "Workspace"},
76                                                           {kGetNextOutput, "GetNextOutput"},
77                                                           {kSummaryInput, "SummaryInput"},
78                                                           {kRefNodeInput, "RefNodeInput"},
79                                                           {kRefNodeOutput, "RefNodeOutput"},
80                                                           {kUnknown, "Unknown"}};
81 
82 std::map<LifeLongType, std::string> life_long_name_map = {{kLifeLongNone, "LifeLongNone"},
83                                                           {kLifeLongGraphAll, "LifeLongGraphAll"},
84                                                           {kLifeLongGraphStart, "LifeLongGraphStart"},
85                                                           {kLifeLongGraphEnd, "LifeLongGraphEnd"}};
86 
Allocate(const session::KernelGraph * graph)87 bool Somas::Allocate(const session::KernelGraph *graph) {
88   auto ret = InitSomasTensors(graph);
89   if (!ret) {
90     MS_LOG(EXCEPTION) << "Somas Initialize Failed.";
91   }
92 
93   if (tensors_list_.empty()) {
94     MS_LOG(INFO) << "No Tensor for Somas";
95     return true;
96   }
97 
98   ret = LoadSomasCache(graph);
99   if (ret) {
100     GenGraphStatisticInfo();
101     return ret;
102   }
103 
104   // Computing Conflict pairs
105   MS_LOG(INFO) << "Start Computing Conflict Pairs";
106   ComputeConflictPairs();
107   MS_LOG(INFO) << "End Computing Conflict Pairs";
108 
109   ret = Assign(graph);
110   if (!ret) {
111     MS_LOG(EXCEPTION) << "Somas Assign Failed.";
112   }
113   SaveSomasResult(graph);
114   GenGraphStatisticInfo();
115   return ret;
116 }
117 
LoadSomasCache(const session::KernelGraph * graph)118 bool Somas::LoadSomasCache(const session::KernelGraph *graph) {
119   MS_EXCEPTION_IF_NULL(graph);
120   if (tensors_list_.size() < kCachedResultThreshold) {
121     MS_LOG(DEBUG) << "Tensors size (" << tensors_list_.size() << ") less than " << kCachedResultThreshold
122                   << ", no need to load cached";
123     return false;
124   }
125 
126   bool ret = CalcSomasModelHash(graph);
127   if (ret) {
128     std::string filename = GetSaveGraphsPathName(
129       "/somas_meta/somas_graph" + std::to_string(graph->graph_id()) + "_" + hash_id_ + ".json", save_graphs_path_);
130     ret = LoadSomasResult(graph, filename);
131     if (ret) {
132       MS_LOG(INFO) << "Load Somas Cache file " << filename << " Successfully.";
133     }
134   } else {
135     MS_LOG(ERROR) << "Calculate somas's model hash id failed.";
136   }
137   return ret;
138 }
139 
CalcSomasModelHash(const session::KernelGraph * graph)140 bool Somas::CalcSomasModelHash(const session::KernelGraph *graph) {
141   MS_EXCEPTION_IF_NULL(graph);
142   auto model_str = SomasInfo(true);
143   hash_id_ = std::to_string(std::hash<std::string>()(model_str));
144   MS_LOG(INFO) << "Graph " << graph->graph_id() << "'s SOMAS Model hash id is " << hash_id_;
145   std::string filename = GetSaveGraphsPathName(
146     "/somas_meta/somas_graph" + std::to_string(graph->graph_id()) + "_" + hash_id_ + ".info", save_graphs_path_);
147   return Common::SaveStringToFile(filename, model_str);
148 }
149 
SaveSomasResult(const session::KernelGraph * graph)150 bool Somas::SaveSomasResult(const session::KernelGraph *graph) {
151   MS_EXCEPTION_IF_NULL(graph);
152   if (tensors_list_.size() < kCachedResultThreshold) {
153     MS_LOG(DEBUG) << "Tensors size (" << tensors_list_.size() << ") less than " << kCachedResultThreshold
154                   << ", no need to save result";
155     return false;
156   }
157   nlohmann::json somas_json;
158   somas_json[kGraphId] = graph->graph_id();
159   somas_json[kHashId] = hash_id_;
160   somas_json[kMemOffset] = mem_offset_;
161   somas_json[kNodeSize] = nodes_list_.size();
162   somas_json[kTensorSize] = tensors_list_.size();
163   somas_json[kContiguousSize] = contiguous_tensors_list_.size();
164   somas_json[kRefNodeSize] = ref_node_constraints_.size();
165   somas_json[kStreamSize] = streams_list_.size();
166   somas_json[kStreamGroupSize] = streams_groups_.size();
167   std::vector<nlohmann::json> tensors_json;
168   for (auto &tensor : tensors_list_) {
169     MS_EXCEPTION_IF_NULL(tensor);
170     nlohmann::json tensor_json;
171     tensor_json[kTensorId] = tensor->GetId();
172     tensor_json[kSize] = tensor->GetAlignedSize();
173     tensor_json[kOriSize] = tensor->GetOriginalSize();
174     tensor_json[kLifelongValue] = tensor->lifelong_value_;
175     tensor_json[kLifeStart] = tensor->lifetime_.start_;
176     tensor_json[kLifeEnd] = tensor->lifetime_.end_;
177     tensor_json[kOffset] = tensor->GetOffset();
178     tensors_json.emplace_back(tensor_json);
179   }
180   somas_json[kTensors] = tensors_json;
181 
182   std::string filename = GetSaveGraphsPathName(
183     "/somas_meta/somas_graph" + std::to_string(graph->graph_id()) + "_" + hash_id_ + ".json", save_graphs_path_);
184   (void)Common::SaveStringToFile(filename, somas_json.dump());
185   return true;
186 }
187 
LoadSomasResult(const session::KernelGraph * graph,const string & filename)188 bool Somas::LoadSomasResult(const session::KernelGraph *graph, const string &filename) {
189   if (filename.length() <= strlen(".json")) {
190     MS_LOG(WARNING) << "please check somas cache file path.";
191     return false;
192   }
193   std::ifstream somas_json_fs(filename);
194   if (!somas_json_fs.is_open()) {
195     MS_LOG(INFO) << "Open json file: " << filename << " error, Somas Cache Missed.";
196     return false;
197   }
198   nlohmann::json somas_json;
199   try {
200     somas_json_fs >> somas_json;
201     somas_json_fs.close();
202   } catch (std::exception &e) {
203     MS_LOG(WARNING) << "Parse json file error: " << filename << ", sleep 500ms and retry again.";
204     somas_json_fs.close();
205     std::this_thread::sleep_for(std::chrono::milliseconds(kRetryIntervalSeconds));
206     std::ifstream retry_tmp(filename);
207     if (!retry_tmp.is_open()) {
208       MS_LOG(INFO) << "Open json file: " << filename << " error, please check kernel_meta.";
209       return false;
210     }
211     retry_tmp >> somas_json;
212     retry_tmp.close();
213   }
214 
215   auto ret = VerifySomasResult(graph, somas_json);
216   if (!ret) {
217     MS_LOG(WARNING) << "Verify Somas Result Failed.";
218     return false;
219   }
220   auto mem_offset = somas_json[kMemOffset];
221   mem_offset_ = mem_offset;
222   ret = UpdateTensorsOffset(somas_json[kTensors]);
223   return ret;
224 }
225 
VerifySomasResult(const session::KernelGraph * graph,const nlohmann::json & somas_json) const226 bool Somas::VerifySomasResult(const session::KernelGraph *graph, const nlohmann::json &somas_json) const {
227   MS_EXCEPTION_IF_NULL(graph);
228   auto graph_id = somas_json[kGraphId];
229   auto hash_id = somas_json[kHashId];
230   auto node_size = somas_json[kNodeSize];
231   auto tensor_size = somas_json[kTensorSize];
232   auto contiguous_size = somas_json[kContiguousSize];
233   auto ref_node_size = somas_json[kRefNodeSize];
234   auto stream_size = somas_json[kStreamSize];
235   auto stream_group_size = somas_json[kStreamGroupSize];
236 
237   if (graph_id != graph->graph_id()) {
238     MS_LOG(WARNING) << "Mismatch graph id " << graph_id << " vs " << graph->graph_id();
239     return false;
240   }
241 
242   if (hash_id != hash_id_) {
243     MS_LOG(WARNING) << "Mismatch hash id " << hash_id << " vs " << hash_id_;
244     return false;
245   }
246 
247   if (node_size != nodes_list_.size()) {
248     MS_LOG(WARNING) << "Mismatch node size " << node_size << " vs " << nodes_list_.size();
249     return false;
250   }
251 
252   if (tensor_size != tensors_list_.size()) {
253     MS_LOG(WARNING) << "Mismatch tensor size " << tensor_size << " vs " << tensors_list_.size();
254     return false;
255   }
256 
257   if (contiguous_size != contiguous_tensors_list_.size()) {
258     MS_LOG(WARNING) << "Mismatch contiguous size " << contiguous_size << " vs " << contiguous_tensors_list_.size();
259     return false;
260   }
261 
262   if (ref_node_size != ref_node_constraints_.size()) {
263     MS_LOG(WARNING) << "Mismatch ref node size " << ref_node_size << " vs " << ref_node_constraints_.size();
264     return false;
265   }
266 
267   if (stream_size != streams_list_.size()) {
268     MS_LOG(WARNING) << "Mismatch stream size " << stream_size << " vs " << streams_list_.size();
269     return false;
270   }
271 
272   if (stream_group_size != streams_groups_.size()) {
273     MS_LOG(WARNING) << "Mismatch stream group size " << stream_group_size << " vs " << streams_groups_.size();
274     return false;
275   }
276 
277   return true;
278 }
279 
UpdateTensorsOffset(const std::vector<nlohmann::json> & tensors_json)280 bool Somas::UpdateTensorsOffset(const std::vector<nlohmann::json> &tensors_json) {
281   bool ret = true;
282   for (auto &tensor_json : tensors_json) {
283     auto tensor_id = tensor_json[kTensorId];
284     auto size = tensor_json[kSize];
285     auto ori_size = tensor_json[kOriSize];
286     auto lifelong_value = tensor_json[kLifelongValue];
287     auto life_start = tensor_json[kLifeStart];
288     auto life_end = tensor_json[kLifeEnd];
289     auto offset = tensor_json[kOffset];
290     auto iter = tensors_map_.find(tensor_id);
291     if (iter != tensors_map_.end()) {
292       MS_EXCEPTION_IF_NULL(iter->second);
293       if (size != iter->second->aligned_size_) {
294         MS_LOG(WARNING) << "Mismatch size of tensor " << tensor_id << " " << size << " vs "
295                         << iter->second->aligned_size_;
296         ret = false;
297         break;
298       }
299 
300       if (ori_size != iter->second->GetOriginalSize()) {
301         MS_LOG(WARNING) << "Mismatch original size of tensor " << tensor_id << " " << ori_size << " vs "
302                         << iter->second->GetOriginalSize();
303         ret = false;
304         break;
305       }
306 
307       if (lifelong_value != iter->second->lifelong_value_) {
308         MS_LOG(WARNING) << "Mismatch lifelong value of tensor " << tensor_id << " " << lifelong_value << " vs "
309                         << iter->second->lifelong_value_;
310         ret = false;
311         break;
312       }
313 
314       if (life_start != iter->second->lifetime_.start_) {
315         MS_LOG(WARNING) << "Mismatch life start of tensor " << tensor_id << " " << life_start << " vs "
316                         << iter->second->lifetime_.start_;
317         ret = false;
318         break;
319       }
320 
321       if (life_end != iter->second->lifetime_.end_) {
322         MS_LOG(WARNING) << "Mismatch life start of tensor " << tensor_id << " " << life_end << " vs "
323                         << iter->second->lifetime_.end_;
324         ret = false;
325         break;
326       }
327 
328       // verify pass, update memory offset
329       iter->second->offset_ = offset;
330     } else {
331       MS_LOG(WARNING) << "Can't find tensor " << tensor_id;
332       ret = false;
333       break;
334     }
335   }
336   return ret;
337 }
338 
InitSomasTensors(const session::KernelGraph * graph)339 bool Somas::InitSomasTensors(const session::KernelGraph *graph) {
340   MS_EXCEPTION_IF_NULL(graph);
341   InitBasicInfo(graph);
342   IndependentNodeOutputProcess(graph);
343 #ifndef ENABLE_SECURITY
344   SummaryInputProcess(graph);
345 #endif
346   RefNodeProcess(graph);
347   NonTaskSplitProcess(graph);
348   UnReuseNodeProcess(graph);
349   GenContiguousList(graph);
350   GetNextOutputProcess(graph);
351 
352   if (tensors_list_.empty()) {
353     MS_LOG(INFO) << "No Tensor from graph " << graph->graph_id();
354     return true;
355   }
356 
357   MS_LOG(INFO) << "Created " << streams_list_.size() << " streams (" << streams_groups_.size() << " groups), "
358                << nodes_list_.size() << " nodes, " << tensors_list_.size() << " tensors, and "
359                << contiguous_tensors_list_.size() << " contiguous lists";
360 
361 #ifdef ENABLE_DUMP_IR
362   SubModuleId module = SubModuleId::SM_OPTIMIZER;
363   std::string name = "somas_pre_processed_info." + std::to_string(graph->graph_id());
364   (void)mindspore::RDR::RecordString(module, name, SomasInfo());
365   name = "somas_offline_log." + std::to_string(graph->graph_id());
366   (void)mindspore::RDR::RecordString(module, name, Offline());
367 #endif
368 
369   if (save_graphs_) {
370     std::string file_path = GetSaveGraphsPathName(
371       "/somas_pre_processed_info_" + std::to_string(graph->graph_id()) + ".ir", save_graphs_path_);
372     DumpSomasInfoIR(file_path);
373 
374     std::string offline_file_path =
375       GetSaveGraphsPathName("/somas_offline_log_" + std::to_string(graph->graph_id()) + ".ir", save_graphs_path_);
376     DumpOfflineIR(offline_file_path);
377   }
378 
379   return true;
380 }
381 
InitSomasStreamAndNode(const session::KernelGraph * graph)382 void Somas::InitSomasStreamAndNode(const session::KernelGraph *graph) {
383   MS_EXCEPTION_IF_NULL(graph);
384   std::vector<CNodePtr> kernel_cnodes;
385   streams_list_ = {};
386   nodes_list_ = {};
387   size_t node_index = 0;
388   if (graph->subgraph_multi_call()) {
389     kernel_cnodes = graph->mem_reuse_exec_order();
390   } else {
391     kernel_cnodes = graph->execution_order();
392   }
393   for (size_t i = 0; i < kernel_cnodes.size(); i++) {
394     auto kernel = kernel_cnodes[i];
395     MS_EXCEPTION_IF_NULL(kernel);
396     SomasStreamPtr stream;
397     auto stream_id = AnfAlgo::GetStreamId(kernel);
398     auto it = find_if(streams_list_.begin(), streams_list_.end(),
399                       [stream_id](const SomasStreamPtr &s) { return s->GetId() == stream_id; });
400     if (it == streams_list_.end()) {
401       stream = std::make_shared<SomasStream>(stream_id);
402       streams_list_.push_back(stream);
403     } else {
404       stream = *it;
405     }
406 
407     // Node
408     NodeType type = kCommonNode;
409     if (AnfAlgo::IsCommunicationOp(kernel)) {
410       type = kCommunicationNode;
411     }
412     auto node = std::make_shared<SomasNode>(node_index, type, stream);
413     MS_EXCEPTION_IF_NULL(node);
414     node->scope_full_name_ = kernel->fullname_with_scope();
415     nodes_list_.push_back(node);
416     stream->nodes_.push_back(node);
417     auto key = kernel.get();
418     auto &nodes = nodes_map_[key];
419     nodes.push_back(node);
420     node_index++;
421   }
422 }
423 
InitSomasOutputAndWorkspaceTensors(const session::KernelGraph * graph)424 void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph) {
425   MS_EXCEPTION_IF_NULL(graph);
426   tensors_list_ = {};
427   size_t tensor_index = 0;
428   auto kernel_cnodes = graph->execution_order();
429   for (const auto &kernel : kernel_cnodes) {
430     auto nodes = nodes_map_[kernel.get()];
431     auto node = nodes[0];
432     MS_EXCEPTION_IF_NULL(node);
433     auto stream = node->GetStream();
434     MS_EXCEPTION_IF_NULL(stream);
435 
436     // Output Tensor
437     auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
438     MS_EXCEPTION_IF_NULL(kernel_mod);
439     auto output_sizes = kernel_mod->GetOutputSizeList();
440     auto index = 0;
441     for (const auto &size : output_sizes) {
442       auto output_tensor_index = tensor_index;
443       tensor_index++;
444       // Set all output tensor lifelong to true.
445       auto tensor = std::make_shared<SomasTensor>(output_tensor_index, node, stream, size, kLifeLongNone);
446       MS_EXCEPTION_IF_NULL(tensor);
447       tensor->lifetime_.start_ = node->GetId();
448       tensor->lifetime_.end_ = (nodes.size() > 1) ? nodes.back()->GetId() : node->GetId();
449       tensor->type_ = kOutputOnly;
450       if (AnfAlgo::OutputAddrExist(kernel, IntToSize(index))) {
451         tensor->aligned_size_ = 0;
452       }
453 
454       tensors_list_.push_back(tensor);
455       tensors_map_[output_tensor_index] = tensor;
456       stream->tensors_.push_back(tensor);
457       std::for_each(nodes.begin(), nodes.end(), [tensor](auto &node) {
458         MS_EXCEPTION_IF_NULL(node);
459         node->tensors_.insert(tensor);
460         node->output_tensors_.push_back(tensor);
461       });
462       index++;
463     }
464 
465     // WorkSpace Tensor
466     auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
467     index = 0;
468     for (const auto &size : workspace_sizes) {
469       auto workspace_tensor_index = tensor_index;
470       tensor_index++;
471       SomasTensorPtr tensor = std::make_shared<SomasTensor>(workspace_tensor_index, node, stream, size, kLifeLongNone);
472       MS_EXCEPTION_IF_NULL(tensor);
473       tensor->type_ = kWorkspace;
474       tensor->lifetime_.start_ = node->GetId();
475       tensor->lifetime_.end_ = (nodes.size() > 1) ? nodes.back()->GetId() : node->GetId();
476       if (AnfAlgo::WorkspaceAddrExist(kernel, IntToSize(index))) {
477         tensor->aligned_size_ = 0;
478       }
479       tensors_list_.push_back(tensor);
480       tensors_map_[workspace_tensor_index] = tensor;
481       stream->tensors_.push_back(tensor);
482       std::for_each(nodes.begin(), nodes.end(), [tensor](auto &node) {
483         MS_EXCEPTION_IF_NULL(node);
484         node->tensors_.insert(tensor);
485         node->workspace_tensors_.push_back(tensor);
486       });
487       index++;
488     }
489   }
490 }
491 
InitSomasInputTensors(const session::KernelGraph * graph)492 void Somas::InitSomasInputTensors(const session::KernelGraph *graph) {
493   MS_EXCEPTION_IF_NULL(graph);
494   bool is_all_nop_node = opt::IsAllNopNode(graph);
495   static const auto enable_fusion_clear = (common::GetEnv("ENV_FUSION_CLEAR") == "1");
496   auto kernel_cnodes = graph->execution_order();
497   for (const auto &kernel : kernel_cnodes) {
498     if (AnfAlgo::GetCNodeName(kernel) != kAtomicAddrCleanOpName) {
499       InitCommonNodeInputs(is_all_nop_node, kernel);
500     } else {
501       InitAtomicCleanInputs(enable_fusion_clear, kernel);
502     }
503   }
504 }
505 
InitCommonNodeInputs(bool is_all_nop_node,const CNodePtr & kernel)506 void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) {
507   auto nodes = nodes_map_[kernel.get()];
508   auto node = nodes[0];
509   MS_EXCEPTION_IF_NULL(node);
510   auto stream = node->GetStream();
511   MS_EXCEPTION_IF_NULL(stream);
512 
513   // Input Tensor
514   auto input_tensor_num = AnfAlgo::GetInputTensorNum(kernel);
515   size_t real_input_index = 0;
516   for (size_t i = 0; i < input_tensor_num; i++) {
517     auto input_node = kernel->input(i + 1);
518     MS_EXCEPTION_IF_NULL(input_node);
519     session::KernelWithIndex prenode_index;
520     if (is_all_nop_node) {
521       prenode_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
522     } else {
523       prenode_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
524     }
525     if (AnfAlgo::CheckPrimitiveType(prenode_index.first, prim::kPrimMakeTuple)) {
526       MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple";
527     }
528     MS_EXCEPTION_IF_NULL(prenode_index.first);
529     if (!AnfAlgo::IsRealCNodeKernel(prenode_index.first)) {
530       auto op_name = AnfAlgo::GetCNodeName(kernel);
531       TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel, i);
532       if ((op_name == kDynamicRNNOpName || op_name == kDynamicGRUV2OpName) && input_origin_type == kMetaTypeNone) {
533         continue;
534       }
535       auto parameter = GetSomasParameter(prenode_index.first, prenode_index.second);
536       node->input_parameters_map_[real_input_index] = parameter;
537       real_input_index++;
538       MS_LOG(DEBUG) << "Input  [" << prenode_index.first->fullname_with_scope() << "] is not a real cnode kernel.";
539       continue;
540     }
541 
542     auto iter = nodes_map_.find(prenode_index.first.get());
543     if (iter == nodes_map_.end()) {
544       MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input " << i << " ["
545                         << prenode_index.first->fullname_with_scope() << "] is not init.";
546     }
547     auto pre_somas_node = iter->second.at(0);
548     if (prenode_index.second > pre_somas_node->output_tensors_.size()) {
549       MS_LOG(EXCEPTION) << "Output index " << prenode_index.second << " exceed input node ["
550                         << prenode_index.first->fullname_with_scope() << "]'s outputs size "
551                         << pre_somas_node->output_tensors_.size();
552     }
553     auto input_somas_tensor = pre_somas_node->output_tensors_[prenode_index.second];
554     MS_EXCEPTION_IF_NULL(input_somas_tensor);
555     std::for_each(nodes.begin(), nodes.end(),
556                   [input_somas_tensor](auto &node) { node->input_tensors_.push_back(input_somas_tensor); });
557     real_input_index++;
558     if (input_somas_tensor->type_ == kOutputOnly) {
559       input_somas_tensor->type_ = kCommon;
560     }
561     input_somas_tensor->destinationStreams_.insert(stream);
562     for (auto &repeat_node : nodes) {
563       input_somas_tensor->destinations_.insert(repeat_node);
564       if (input_somas_tensor->lifetime_.end_ < repeat_node->GetId()) {
565         input_somas_tensor->lifetime_.end_ = repeat_node->GetId();
566       }
567     }
568 
569     if (node != pre_somas_node) {
570       node->ancestor_nodes_.insert(pre_somas_node);
571     }
572     auto input_tensor_stream = input_somas_tensor->GetSourceStream();
573     if (input_tensor_stream != stream) {
574       stream->ancestor_streams_.insert(input_tensor_stream);
575       input_somas_tensor->between_streams_ = true;
576     }
577   }
578 }
579 
InitAtomicCleanInputs(bool enable_fusion_clear,const CNodePtr & kernel)580 void Somas::InitAtomicCleanInputs(bool enable_fusion_clear, const CNodePtr &kernel) {
581   auto node = nodes_map_[kernel.get()].at(0);
582   MS_EXCEPTION_IF_NULL(node);
583   auto stream = node->GetStream();
584   MS_EXCEPTION_IF_NULL(stream);
585 
586   auto input_tensor_num = AnfAlgo::GetInputTensorNum(kernel);
587   for (size_t i = 0; i < input_tensor_num; i++) {
588     MS_EXCEPTION_IF_NULL(kernel->inputs()[i + 1]);
589     auto pre_node = kernel->input(i + 1)->cast<CNodePtr>();
590     auto iter = nodes_map_.find(pre_node.get());
591     if (iter == nodes_map_.end()) {
592       MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input ["
593                         << pre_node->fullname_with_scope() << "] is not init.";
594     }
595     auto pre_somas_node = iter->second.at(0);
596     MS_EXCEPTION_IF_NULL(pre_somas_node);
597     // set clean output tensors
598     if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
599       auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
600       for (auto index : clean_output_indexs) {
601         if (index > pre_somas_node->output_tensors_.size()) {
602           MS_LOG(EXCEPTION) << "Output index " << index << " exceed input node [" << pre_node->fullname_with_scope()
603                             << "]'s outputs size " << pre_somas_node->output_tensors_.size();
604         }
605         auto input_somas_tensor = pre_somas_node->output_tensors_[index];
606         MS_EXCEPTION_IF_NULL(input_somas_tensor);
607         node->input_tensors_.push_back(input_somas_tensor);
608         if (enable_fusion_clear) {
609           input_somas_tensor->lifelong_value_ = kLifeLongGraphAll;
610           MS_LOG(INFO) << "Set " << node->scope_full_name_ << "'s Input node " << pre_somas_node->scope_full_name_
611                        << " 's output" << index << " to lifelong";
612         }
613       }
614     }
615     // set clean workspace tensors
616     if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
617       auto clean_workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
618       for (const auto &index : clean_workspace_indexs) {
619         if (index > pre_somas_node->output_tensors_.size()) {
620           MS_LOG(EXCEPTION) << "Workspace index " << index << " exceed input node [" << pre_node->fullname_with_scope()
621                             << "]'s Workspace size " << pre_somas_node->workspace_tensors_.size();
622         }
623         auto input_somas_tensor = pre_somas_node->workspace_tensors_[index];
624         MS_EXCEPTION_IF_NULL(input_somas_tensor);
625         node->input_tensors_.push_back(input_somas_tensor);
626         if (enable_fusion_clear) {
627           input_somas_tensor->lifelong_value_ = kLifeLongGraphAll;
628           MS_LOG(INFO) << "Set " << node->scope_full_name_ << "'s Input node " << pre_somas_node->scope_full_name_
629                        << " 's workspace" << index << " to lifelong";
630         }
631       }
632     }
633   }
634 }
635 
CreateSomasParameter(const AnfNodePtr & node,size_t index)636 SomasParameterPtr Somas::CreateSomasParameter(const AnfNodePtr &node, size_t index) {
637   MS_EXCEPTION_IF_NULL(node);
638   auto id = parameters_list_.size();
639   auto device_addr = AnfAlgo::GetOutputAddr(node, index);
640   if (device_addr == nullptr) {
641     MS_LOG(EXCEPTION) << "Node " << node->fullname_with_scope() << " has no device address before Somas.";
642   }
643   auto param = std::make_shared<SomasParameter>(id, node->fullname_with_scope(), index, device_addr->GetPtr(),
644                                                 device_addr->GetSize());
645   parameters_list_.push_back(param);
646   return param;
647 }
648 
GetSomasParameter(const AnfNodePtr & node,size_t index)649 SomasParameterPtr Somas::GetSomasParameter(const AnfNodePtr &node, size_t index) {
650   auto key = node.get();
651   auto iter = parameters_map_.find(key);
652   if (iter != parameters_map_.end()) {
653     auto it = std::find_if(iter->second.begin(), iter->second.end(),
654                            [index](const SomasParameterPtr &param) -> bool { return index == param->output_index_; });
655     if (it != iter->second.end()) {
656       return *it;
657     } else {
658       auto new_param = CreateSomasParameter(node, index);
659       iter->second.push_back(new_param);
660       return new_param;
661     }
662   } else {
663     auto param = CreateSomasParameter(node, index);
664     parameters_map_[key].push_back(param);
665     return param;
666   }
667 }
668 
InitBasicInfo(const session::KernelGraph * graph)669 void Somas::InitBasicInfo(const session::KernelGraph *graph) {
670   MS_EXCEPTION_IF_NULL(graph);
671 #ifdef ENABLE_D
672   streams_groups_ = device::ascend::AscendStreamAssign::GetInstance().get_stream_group();
673 #endif
674   InitSomasStreamAndNode(graph);
675   InitSomasOutputAndWorkspaceTensors(graph);
676   InitSomasInputTensors(graph);
677 
678   auto context_ptr = MsContext::GetInstance();
679   MS_EXCEPTION_IF_NULL(context_ptr);
680 
681 #ifdef ENABLE_DUMP_IR
682   SubModuleId module = SubModuleId::SM_OPTIMIZER;
683   std::string name = "somas_initial_info." + std::to_string(graph->graph_id());
684   (void)mindspore::RDR::RecordString(module, name, SomasInfo());
685 #endif
686 
687   save_graphs_ = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
688   save_graphs_path_ = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
689   if (save_graphs_path_.empty()) {
690     save_graphs_path_ = ".";
691   }
692   if (save_graphs_) {
693     std::string file_path =
694       GetSaveGraphsPathName("/somas_initial_info_" + std::to_string(graph->graph_id()) + ".ir", save_graphs_path_);
695     DumpSomasInfoIR(file_path);
696   }
697 }
698 
GetNextOutputProcess(const session::KernelGraph * graph)699 void Somas::GetNextOutputProcess(const session::KernelGraph *graph) {
700   MS_EXCEPTION_IF_NULL(graph);
701   auto kernel_cnodes = graph->execution_order();
702   size_t total_size = 0;
703   for (const auto &kernel : kernel_cnodes) {
704     if (AnfAlgo::GetCNodeName(kernel) != kGetNextOpName) {
705       continue;
706     }
707     auto iter = nodes_map_.find(kernel.get());
708     if (iter != nodes_map_.end()) {
709       auto &node = iter->second.at(0);
710       MS_EXCEPTION_IF_NULL(node);
711       auto getnext_output_tensors = node->output_tensors_;
712       for (auto &tensor : getnext_output_tensors) {
713         MS_EXCEPTION_IF_NULL(tensor);
714         total_size += tensor->GetAlignedSize();
715         tensor->lifelong_value_ = kLifeLongGraphAll;
716         tensor->type_ = kGetNextOutput;
717       }
718     }
719   }
720   MS_LOG(INFO) << "Special Tensor total size: GetNext Output " << total_size;
721 }
722 
IndependentNodeOutputProcess(const session::KernelGraph * graph)723 void Somas::IndependentNodeOutputProcess(const session::KernelGraph *graph) {
724   MS_EXCEPTION_IF_NULL(graph);
725   auto kernel_cnodes = graph->execution_order();
726   size_t total_size = 0;
727   for (const auto &kernel : kernel_cnodes) {
728     bool independent = AnfAlgo::IsIndependentNode(kernel);
729     if (!independent) {
730       continue;
731     }
732     auto iter = nodes_map_.find(kernel.get());
733     if (iter != nodes_map_.end()) {
734       auto &node = iter->second.at(0);
735       MS_EXCEPTION_IF_NULL(node);
736       auto semi_reuse_output_tensors = node->output_tensors_;
737       for (auto &tensor : semi_reuse_output_tensors) {
738         MS_EXCEPTION_IF_NULL(tensor);
739         total_size += tensor->GetAlignedSize();
740         tensor->lifelong_value_ = kLifeLongGraphAll;
741       }
742     }
743   }
744 
745   MS_LOG(INFO) << "Special Tensor total size: Independent Node output " << total_size;
746 }
747 
748 #ifndef ENABLE_SECURITY
SummaryInputProcess(const session::KernelGraph * graph)749 void Somas::SummaryInputProcess(const session::KernelGraph *graph) {
750   MS_EXCEPTION_IF_NULL(graph);
751   bool summary_exist = graph->summary_node_exist();
752   if (!summary_exist) {
753     return;
754   }
755 
756   auto summary_nodes = graph->summary_nodes();
757   if (summary_nodes.empty()) {
758     return;
759   }
760 
761   size_t total_summary_size = 0;
762   for (auto &node_item : summary_nodes) {
763     auto node = node_item.second.first;
764     size_t index = IntToSize(node_item.second.second);
765     auto iter = nodes_map_.find(node.get());
766     if (iter != nodes_map_.end()) {
767       auto input_node = iter->second.at(0);
768       MS_EXCEPTION_IF_NULL(input_node);
769       if (index < input_node->output_tensors_.size()) {
770         auto tensor = input_node->output_tensors_[index];
771         MS_EXCEPTION_IF_NULL(tensor);
772         tensor->lifelong_value_ = kLifeLongGraphAll;
773         tensor->type_ = kSummaryInput;
774         total_summary_size += tensor->GetAlignedSize();
775         MS_LOG(INFO) << "Set summary node input tensor's lifelong, node: " << node->fullname_with_scope()
776                      << " index: " << index;
777       } else {
778         MS_LOG(WARNING) << "Index exceed size, node " << node->fullname_with_scope() << " index: " << index
779                         << " size: " << input_node->output_tensors_.size();
780       }
781     } else {
782       MS_LOG(WARNING) << "Can't find summary input node " << node->fullname_with_scope() << " index: " << index;
783     }
784   }
785 
786   MS_LOG(INFO) << "Special Tensor total size: SummaryNodes: " << total_summary_size;
787 }
788 #endif
789 
RefNodeProcess(const session::KernelGraph * graph)790 void Somas::RefNodeProcess(const session::KernelGraph *graph) {
791   MS_EXCEPTION_IF_NULL(graph);
792   auto kernel_cnodes = graph->execution_order();
793   size_t total_output_size = 0;
794   size_t total_input_size = 0;
795   for (const auto &kernel : kernel_cnodes) {
796     auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
797     if (kernel_mod == nullptr) {
798       MS_LOG(WARNING) << "Kernel mode is NULL Of " << kernel->fullname_with_scope();
799       continue;
800     }
801     auto output_sizes = kernel_mod->GetOutputSizeList();
802     size_t output_index = 0;
803     for (const auto &size : output_sizes) {
804       auto out_index = output_index;
805       output_index++;
806       session::AnfWithOutIndex out_pair(kernel, out_index);
807       if (graph->IsInRefOutputMap(out_pair)) {
808         auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
809         MS_EXCEPTION_IF_NULL(origin_pair.first);
810         auto &node = nodes_map_[kernel.get()].at(0);
811         MS_EXCEPTION_IF_NULL(node);
812         auto output_tensor = node->output_tensors_[out_index];
813         MS_EXCEPTION_IF_NULL(output_tensor);
814         output_tensor->type_ = kRefNodeOutput;
815         total_output_size += size;
816 
817         if (AnfAlgo::IsRealCNodeKernel(origin_pair.first)) {
818           auto ori_node = origin_pair.first->cast<CNodePtr>();
819           auto ori_index = origin_pair.second;
820           if (nodes_map_.find(ori_node.get()) == nodes_map_.end()) {
821             MS_LOG(EXCEPTION)
822               << "The ori_node is not included in nodes_map_ constructed from exec_order of graph. Info ori_node: "
823               << ori_node->DebugString();
824           }
825           auto &repeat_node = nodes_map_[ori_node.get()].at(0);
826           MS_EXCEPTION_IF_NULL(repeat_node);
827           auto input_tensor = repeat_node->output_tensors_[ori_index];
828           MS_EXCEPTION_IF_NULL(input_tensor);
829           input_tensor->type_ = kRefNodeInput;
830           total_input_size += input_tensor->aligned_size_;
831           std::vector<size_t> refnode_input_output;
832           refnode_input_output.push_back(input_tensor->GetId());
833           refnode_input_output.push_back(output_tensor->GetId());
834           ref_node_constraints_.push_back(refnode_input_output);
835           MS_LOG(INFO) << "RefNode: input " << input_tensor->GetId() << " output " << output_tensor->GetId();
836         }
837       }
838     }
839   }
840 
841   MS_LOG(INFO) << "Special Tensor total size: RefNode: input " << total_input_size << " output " << total_output_size;
842 }
843 
NonTaskSplitProcess(const session::KernelGraph * graph)844 void Somas::NonTaskSplitProcess(const session::KernelGraph *graph) {
845   MS_EXCEPTION_IF_NULL(graph);
846   auto kernel_cnodes = graph->execution_order();
847   for (const auto &kernel : kernel_cnodes) {
848     auto op_name = AnfAlgo::GetCNodeName(kernel);
849     if ((op_name == kSplitOpName || op_name == kSplitVOpName) && AnfAlgo::HasNodeAttr(kAttrNonTask, kernel)) {
850       std::vector<size_t> refnode_input_output;
851       auto node = nodes_map_[kernel.get()].at(0);
852       MS_EXCEPTION_IF_NULL(node);
853       if (node->input_tensors_.size() == 0) {
854         MS_LOG(EXCEPTION) << op_name << " has no input tensor, can not do split non_task process.";
855       }
856       auto input_tensor = node->input_tensors_[0];
857       MS_EXCEPTION_IF_NULL(input_tensor);
858       input_tensor->type_ = kRefNodeInput;
859       refnode_input_output.push_back(input_tensor->GetId());
860 
861       for (auto &output_tensor : node->output_tensors_) {
862         MS_EXCEPTION_IF_NULL(output_tensor);
863         output_tensor->type_ = kRefNodeOutput;
864         refnode_input_output.push_back(output_tensor->GetId());
865       }
866       ref_node_constraints_.push_back(refnode_input_output);
867     }
868   }
869 }
870 
UnReuseNodeProcess(const session::KernelGraph * graph)871 void Somas::UnReuseNodeProcess(const session::KernelGraph *graph) {
872   MS_EXCEPTION_IF_NULL(graph);
873   vector<string> full_name_list = {};
874   if (full_name_list.size() == 0) {
875     return;
876   }
877 
878   auto kernel_cnodes = graph->execution_order();
879   for (const auto &kernel : kernel_cnodes) {
880     MS_EXCEPTION_IF_NULL(kernel);
881     auto full_name = kernel->fullname_with_scope();
882     auto iter = std::find(full_name_list.begin(), full_name_list.end(), full_name);
883     if (iter != full_name_list.end()) {
884       MS_LOG(INFO) << "Set UnReuse Node in somas, Node:" << full_name;
885       auto key = kernel.get();
886       auto somas_node = nodes_map_[key].at(0);
887       MS_EXCEPTION_IF_NULL(somas_node);
888       // input
889       auto inputs = somas_node->input_tensors_;
890       for (auto &input : inputs) {
891         MS_EXCEPTION_IF_NULL(input);
892         input->lifelong_value_ = kLifeLongGraphAll;
893       }
894 
895       // output
896       auto outputs = somas_node->output_tensors_;
897       MS_LOG(INFO) << "Output size of " << kernel->fullname_with_scope() << " is  " << outputs.size();
898       for (auto &output : outputs) {
899         MS_EXCEPTION_IF_NULL(output);
900         output->lifelong_value_ = kLifeLongGraphAll;
901       }
902 
903       // workspace
904       auto workspaces = somas_node->workspace_tensors_;
905       for (auto &workspace : workspaces) {
906         MS_EXCEPTION_IF_NULL(workspace);
907         workspace->lifelong_value_ = kLifeLongGraphAll;
908       }
909     }
910   }
911 }
912 
GenContiguousList(const session::KernelGraph * graph)913 void Somas::GenContiguousList(const session::KernelGraph *graph) {
914   MS_EXCEPTION_IF_NULL(graph);
915   for (const auto &node : nodes_list_) {
916     MS_EXCEPTION_IF_NULL(node);
917     if (node->GetType() != kCommunicationNode) {
918       continue;
919     }
920 
921     // Contiguous input
922     if ((!node->input_tensors_.empty()) && (!node->input_tensors_[0]->contiguous_)) {
923       if (node->input_tensors_[0]->aligned_size_) {
924         node->input_tensors_[0]->aligned_size_ += kGapSize;
925       }
926       if (node->input_tensors_[node->input_tensors_.size() - 1]->aligned_size_) {
927         node->input_tensors_[node->input_tensors_.size() - 1]->aligned_size_ += kGapSize;
928       }
929       std::vector<size_t> inputs;
930       for (const auto &input_tensor : node->input_tensors_) {
931         MS_EXCEPTION_IF_NULL(input_tensor);
932         comm_input_total_size_ += input_tensor->aligned_size_;
933         input_tensor->contiguous_ = true;
934         inputs.push_back(input_tensor->GetId());
935       }
936       contiguous_tensors_list_.push_back(inputs);
937     }
938 
939     // Contiguous output
940     if ((!node->output_tensors_.empty()) && (!node->output_tensors_[0]->contiguous_)) {
941       if (node->output_tensors_[0]->aligned_size_) {
942         node->output_tensors_[0]->aligned_size_ += kGapSize;
943       }
944       if (node->output_tensors_[node->output_tensors_.size() - 1]->aligned_size_) {
945         node->output_tensors_[node->output_tensors_.size() - 1]->aligned_size_ += kGapSize;
946       }
947       std::vector<size_t> outputs;
948       for (const auto &output_tensor : node->output_tensors_) {
949         MS_EXCEPTION_IF_NULL(output_tensor);
950         comm_output_total_size_ += output_tensor->aligned_size_;
951         output_tensor->contiguous_ = true;
952         outputs.push_back(output_tensor->GetId());
953       }
954       contiguous_tensors_list_.push_back(outputs);
955     }
956   }
957 }
958 
ComputeConflictPairs()959 void Somas::ComputeConflictPairs() {
960   if (tensors_list_.empty()) {
961     MS_LOG(INFO) << "No Tensor for Conflict computing";
962     return;
963   }
964 
965   MS_LOG(INFO) << "Start Conflict Computing (Bitset Model)";
966   auto start_conflict = std::chrono::system_clock::now();
967   std::sort(nodes_list_.begin(), nodes_list_.end(), NodeSort);
968   UpdateTensorDestinations();
969 
970   MS_LOG(INFO) << "Start Bitset";
971   std::vector<DynamicBitSet> nodes_dependency;
972 
973   size_t count = nodes_list_.back()->GetId() + 1;
974   for (size_t i = 0; i < count; i++) {
975     nodes_dependency.emplace_back(count);
976   }
977 
978   MS_LOG(INFO) << "Start Path Computing";
979   // Loop to compute ancestor paths via bitset for time dependence
980   for (const auto &node : nodes_list_) {
981     for (const auto &ancestor : node->ancestor_nodes_) {
982       nodes_dependency[node->GetId()].SetBitTrue(ancestor->GetId());
983       Union(&nodes_dependency[node->GetId()], &nodes_dependency[ancestor->GetId()]);
984     }
985   }
986   MS_LOG(INFO) << "End Path Computing";
987 
988   MS_LOG(INFO) << "Start Tensor Relation Computing";
989   count = tensors_list_.back()->GetId() + 1;
990   for (size_t i = 0; i < count; i++) {
991     reuse_matrix_.emplace_back(count);
992   }
993 
994   if (tensors_list_.size() < kParallelComputeSizeThreshold) {
995     ComputeMultiTensorConflicts(tensors_list_, tensors_list_, nodes_dependency, &reuse_matrix_);
996   } else {
997     MS_LOG(INFO) << "Tensor Num " << tensors_list_.size() << " is larger than " << kParallelComputeSizeThreshold;
998     MS_LOG(INFO) << "Enter Multi-Thread Mode...";
999     size_t process_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
1000     MS_LOG(INFO) << "Threads Num is " << process_num;
1001 
1002     int64_t start_index = 0;
1003     int64_t total_size = SizeToLong(tensors_list_.size());
1004     int64_t job_size = total_size / SizeToLong(process_num);
1005     if (job_size == 0) {
1006       job_size = total_size;
1007     }
1008     std::vector<common::Task> tasks;
1009     while (start_index < total_size) {
1010       int64_t end_index = (start_index + job_size) > total_size ? total_size : start_index + job_size;
1011       auto jobs = std::vector<SomasTensorPtr>(tensors_list_.begin() + start_index, tensors_list_.begin() + end_index);
1012       auto task = [this, jobs, &nodes_dependency]() {
1013         this->ComputeMultiTensorConflicts(jobs, tensors_list_, nodes_dependency, &reuse_matrix_);
1014         return common::SUCCESS;
1015       };
1016       tasks.emplace_back(task);
1017       start_index += job_size;
1018     }
1019 
1020     common::ThreadPool::GetInstance().SyncRun(tasks);
1021   }
1022   MS_LOG(INFO) << "End Tensor Relation Computing";
1023   auto end_conflict = std::chrono::system_clock::now();
1024   MS_LOG(INFO) << "End Conflict Computing (Bitset Model)(time taken "
1025                << std::chrono::duration_cast<std::chrono::milliseconds>(end_conflict - start_conflict).count() << "ms)";
1026 }
1027 
UpdateTensorDestinations()1028 void Somas::UpdateTensorDestinations() {
1029   // Loop to add edges within each stream (node order within stream)
1030   for (const auto &stream : streams_list_) {
1031     MS_EXCEPTION_IF_NULL(stream);
1032     auto &nodes = stream->nodes_;
1033     std::sort(nodes.begin(), nodes.end(), NodeSort);
1034     for (size_t i = 1; i < nodes.size(); i++) {
1035       const auto &previous_node = nodes[i - 1];
1036       const auto &current_node = nodes[i];
1037       MS_EXCEPTION_IF_NULL(current_node);
1038       current_node->ancestor_nodes_.insert(previous_node);
1039     }
1040   }
1041 
1042   // Loop to add edges from end to beginning of next group
1043   for (const auto &group : streams_groups_) {
1044     for (size_t i = 1; i < group.size(); i++) {
1045       int64_t previous_stream = group[i - 1];
1046       int64_t current_stream = group[i];
1047 
1048       auto it =
1049         std::find_if(streams_list_.begin(), streams_list_.end(),
1050                      [previous_stream](const SomasStreamPtr &stream) { return stream->GetId() == previous_stream; });
1051       if (it == streams_list_.end()) {
1052         continue;
1053       }
1054       auto &last_node_in_prev_stream = (*it)->nodes_.back();
1055 
1056       it = std::find_if(streams_list_.begin(), streams_list_.end(),
1057                         [current_stream](const SomasStreamPtr &stream) { return stream->GetId() == current_stream; });
1058       if (it == streams_list_.end()) {
1059         continue;
1060       }
1061       auto &first_node_in_cur_stream = (*it)->nodes_.front();
1062 
1063       first_node_in_cur_stream->ancestor_nodes_.insert(last_node_in_prev_stream);
1064     }
1065   }
1066 
1067   // Loop to avoid tensors with empty destinations (add itself)
1068   for (const auto &tensor : tensors_list_) {
1069     MS_EXCEPTION_IF_NULL(tensor);
1070     if (tensor->destinations_.size() == 0) {
1071       tensor->destinations_.insert(tensor->GetSourceNode());
1072     }
1073   }
1074 
1075   // Loop to compute max destinations in each stream
1076   for (const auto &tensor : tensors_list_) {
1077     MS_EXCEPTION_IF_NULL(tensor);
1078     tensor->ComputeMaxDestinationId();
1079   }
1080 }
1081 
ComputeMultiTensorConflicts(const std::vector<SomasTensorPtr> & calc_tensors_list,const std::vector<SomasTensorPtr> & all_tensors_list,const vector<DynamicBitSet> & nodes_dependency,std::vector<DynamicBitSet> * tensor_relation) const1082 void Somas::ComputeMultiTensorConflicts(const std::vector<SomasTensorPtr> &calc_tensors_list,
1083                                         const std::vector<SomasTensorPtr> &all_tensors_list,
1084                                         const vector<DynamicBitSet> &nodes_dependency,
1085                                         std::vector<DynamicBitSet> *tensor_relation) const {
1086   auto start = std::chrono::system_clock::now();
1087   MS_LOG(INFO) << "Start Computing Conflicts Pairs, tensors list size is " << calc_tensors_list.size();
1088   for (size_t i = 0; i < calc_tensors_list.size(); i++) {
1089     auto calc_tensor = calc_tensors_list[i];
1090     MS_EXCEPTION_IF_NULL(calc_tensor);
1091     if (calc_tensor->IsLifelong() || calc_tensor->IsSemiLifelongEnd() || calc_tensor->IsRefOverlap() ||
1092         calc_tensor->GetAlignedSize() == 0) {
1093       continue;
1094     }
1095 
1096     ComputeOneTensorConflicts(calc_tensor, all_tensors_list, nodes_dependency, tensor_relation);
1097   }
1098   auto end = std::chrono::system_clock::now();
1099   MS_LOG(INFO) << "End Computing Conflicts Pairs (time taken "
1100                << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms)";
1101 }
1102 
ComputeOneTensorConflicts(const std::shared_ptr<SomasTensor> & calc_tensor,const std::vector<SomasTensorPtr> & all_tensors_list,const vector<DynamicBitSet> & nodes_dependency,std::vector<DynamicBitSet> * tensor_relation) const1103 void Somas::ComputeOneTensorConflicts(const std::shared_ptr<SomasTensor> &calc_tensor,
1104                                       const std::vector<SomasTensorPtr> &all_tensors_list,
1105                                       const vector<DynamicBitSet> &nodes_dependency,
1106                                       std::vector<DynamicBitSet> *tensor_relation) const {
1107   MS_EXCEPTION_IF_NULL(calc_tensor);
1108   for (size_t j = 0; j < all_tensors_list.size(); j++) {
1109     auto target_tensor = all_tensors_list[j];
1110     MS_EXCEPTION_IF_NULL(target_tensor);
1111     if (calc_tensor == target_tensor || target_tensor->IsLifelong() || target_tensor->IsSemiLifelongStart() ||
1112         target_tensor->IsRefOverlap() || target_tensor->GetAlignedSize() == 0) {
1113       continue;
1114     }
1115     size_t calc_src_node = calc_tensor->GetSourceNode()->GetId();
1116     size_t target_src_node = target_tensor->GetSourceNode()->GetId();
1117     if (calc_src_node == target_src_node) {
1118       continue;
1119     }
1120     if ((*tensor_relation)[calc_tensor->GetId()].IsBitTrue(target_tensor->GetId()) ||
1121         (*tensor_relation)[target_tensor->GetId()].IsBitTrue(calc_tensor->GetId())) {
1122       continue;
1123     }
1124 
1125     bool reuse = true;
1126     // check calc_tensor's all consumers is target_tensor's source node's dependency or not
1127     for (const auto &dst_map : calc_tensor->max_destinations_) {
1128       const auto &dst_node = dst_map.second;
1129       MS_EXCEPTION_IF_NULL(dst_node);
1130       if (nodes_dependency[target_src_node].IsBitTrue(dst_node->GetId()) == false) {
1131         // calc_tensor's consumer is not in target_tensor's source node's dependency, not sure this consumer is done or
1132         // not when target_tensor produced
1133         reuse = false;
1134         break;
1135       } else if (target_src_node == dst_node->GetId()) {
1136         // calc_tensor is target_tensor's source node's input, can't reuse
1137         reuse = false;
1138         break;
1139       } else {
1140         // calc_tensor's consumer is in target_tensor's source node's dependency, this consumer is done when
1141         // target_tensor produced
1142         reuse = true;
1143       }
1144     }
1145 
1146     if (reuse) {
1147       // calc_tensor and target_tensor have dependencies so they can reuse each other
1148       (*tensor_relation)[calc_tensor->GetId()].SetBitTrue(target_tensor->GetId());
1149       (*tensor_relation)[target_tensor->GetId()].SetBitTrue(calc_tensor->GetId());
1150     }
1151   }
1152 }
1153 
NodeSort(const SomasNodePtr & node1,const SomasNodePtr & node2)1154 bool Somas::NodeSort(const SomasNodePtr &node1, const SomasNodePtr &node2) { return node1->GetId() < node2->GetId(); }
1155 
Assign(const session::KernelGraph * graph)1156 bool Somas::Assign(const session::KernelGraph *graph) {
1157   if (tensors_list_.empty()) {
1158     MS_LOG(INFO) << "No Tensor for Assigner";
1159     return true;
1160   }
1161 
1162   // Ref Node Preprocessing
1163   UpdateRefTensorsConflict();
1164   std::map<size_t, size_t> contiguous_list_with_ref_index_map = GetContiguousListContainRefTensor();
1165   vector<vector<size_t>> contiguous_tensors_list_removed = contiguous_tensors_list_;
1166   std::set<vector<size_t>> contiguous_tensors_list_to_remove;
1167   for (auto ref_list_pair : contiguous_list_with_ref_index_map) {
1168     contiguous_tensors_list_to_remove.insert(contiguous_tensors_list_[ref_list_pair.second]);
1169   }
1170 
1171   // remove the contiguous list which all tensors' align size is 0
1172   for (auto contiguous_list : contiguous_tensors_list_) {
1173     bool all_outputs = true;
1174     for (auto tensor_id : contiguous_list) {
1175       auto tensor = tensors_list_[tensor_id];
1176       MS_EXCEPTION_IF_NULL(tensor);
1177       if (tensor->aligned_size_ != 0) {
1178         all_outputs = false;
1179         break;
1180       }
1181     }
1182 
1183     if (all_outputs) {
1184       contiguous_tensors_list_to_remove.insert(contiguous_list);
1185     }
1186   }
1187 
1188   for (auto contiguous_list : contiguous_tensors_list_to_remove) {
1189     auto iterator =
1190       std::find(contiguous_tensors_list_removed.begin(), contiguous_tensors_list_removed.end(), contiguous_list);
1191     if (iterator != contiguous_tensors_list_removed.end()) {
1192       contiguous_tensors_list_removed.erase(iterator);
1193     } else {
1194       MS_LOG(WARNING) << "Could not find contiguous list to remove for ref";
1195     }
1196   }
1197   MS_LOG(INFO) << "End Solving Preprocessing for Ref Node";
1198   UpdateRefOverlapTensorsConflicts();
1199 
1200 #ifdef SOMAS_DEBUG
1201   // Compute number of constraints for each tensor
1202   auto tensors_num = tensors_list_.size();
1203   for (auto tensor1 : tensors_list_) {
1204     auto ones_num = reuse_matrix_[tensor1->GetId()].CountOnesNum();
1205     tensor1->num_constraints_ = tensors_num - ones_num;
1206   }
1207 #endif
1208 
1209   // Prepare solver info
1210   MS_LOG(INFO) << "Start Loop to create solver info";
1211   for (auto tensor : tensors_list_) {
1212     MS_EXCEPTION_IF_NULL(tensor);
1213     if (tensor->GetSolverTensorDesc() != nullptr) {
1214       SomasSolverTensorDescPtr pSolverTensor = tensor->GetSolverTensorDesc();
1215       solver_tensor_desc_map_.insert(std::pair<size_t, SomasSolverTensorDescPtr>(pSolverTensor->index_, pSolverTensor));
1216     }
1217   }
1218   MS_LOG(INFO) << "End Loop to create solver info";
1219 
1220   MS_LOG(INFO) << "Start Solving";
1221   if (solver_tensor_desc_map_.empty()) {
1222     MS_LOG(INFO) << "solver_tensor_desc_list is empty.";
1223     return true;
1224   }
1225 
1226   somas_solver_ = std::make_shared<SomasSolverPre>();
1227   auto status =
1228     somas_solver_->Solving(graph, &solver_tensor_desc_map_, &reuse_matrix_, contiguous_tensors_list_removed, false);
1229   MS_LOG(INFO) << "End Solving";
1230   if (status != SUCCESS) {
1231     GenGraphStatisticInfo();
1232     MS_LOG(EXCEPTION) << "SOMAS Solving Failed.";
1233   }
1234 
1235   // Update solver_tensor_desc offset to tensors list
1236   for (const auto &tensor : tensors_list_) {
1237     MS_EXCEPTION_IF_NULL(tensor);
1238     tensor->SetOffset();
1239   }
1240 
1241   UpdateRefTensorsOffset();
1242   UpdateContiguousTensorsOffset(contiguous_list_with_ref_index_map);
1243 
1244   // Set mem_offset_ value by solver result
1245   mem_offset_ = static_cast<size_t>(somas_solver_->GetMaxOffset());
1246 
1247   return true;
1248 }
1249 
GetContiguousListContainRefTensor()1250 std::map<size_t, size_t> Somas::GetContiguousListContainRefTensor() {
1251   // key: contiguous list index with ref node input; value: contiguous list index with ref node output
1252   std::map<size_t, size_t> contiguous_list_with_ref_index_map;
1253   std::map<size_t, size_t> ref_tensors_in_contiguous_map = GetRefTensorsInContiguousList();
1254   std::map<size_t, std::map<size_t, std::set<size_t>>> contiguous_ref_list_error_check_map;
1255   for (auto ref_pair : ref_tensors_in_contiguous_map) {
1256     size_t ref_first = ref_pair.first;
1257     size_t ref_second = ref_pair.second;
1258     bool found_first = false;
1259     bool found_second = false;
1260     size_t index_first = 0;
1261     size_t index_second = 0;
1262     size_t index_in_list_first = 0;
1263     size_t index_in_list_second = 0;
1264     for (size_t index = 0; index < contiguous_tensors_list_.size() && (!found_first || !found_second); index++) {
1265       if (!found_first) {
1266         auto iterator_first =
1267           std::find(contiguous_tensors_list_[index].begin(), contiguous_tensors_list_[index].end(), ref_first);
1268         if (iterator_first != contiguous_tensors_list_[index].end()) {
1269           index_first = index;
1270           index_in_list_first = iterator_first - contiguous_tensors_list_[index].begin();
1271           found_first = true;
1272         }
1273       }
1274       if (!found_second) {
1275         auto iterator_second =
1276           std::find(contiguous_tensors_list_[index].begin(), contiguous_tensors_list_[index].end(), ref_second);
1277         if (iterator_second != contiguous_tensors_list_[index].end()) {
1278           index_second = index;
1279           index_in_list_second = iterator_second - contiguous_tensors_list_[index].begin();
1280           found_second = true;
1281         }
1282       }
1283     }
1284 
1285     if (!found_first) {
1286       MS_LOG(WARNING) << "Contiguous ref tensor " << ref_first << " not found in any contiguous list";
1287     }
1288     if (!found_second) {
1289       MS_LOG(WARNING) << "Contiguous ref tensor " << ref_second << " not found in any contiguous list";
1290     }
1291     if (contiguous_list_with_ref_index_map.find(index_first) == contiguous_list_with_ref_index_map.end() ||
1292         contiguous_list_with_ref_index_map[index_first] == index_second) {
1293       contiguous_list_with_ref_index_map[index_first] = index_second;
1294       // Checking for error cases
1295       if (index_in_list_first != index_in_list_second) {
1296         MS_LOG(WARNING) << "Inconsistency in contiguous ref: tensor " << ref_first << " in position "
1297                         << index_in_list_first << " of contiguous list " << index_first << " and tensor " << ref_second
1298                         << " in position " << index_in_list_second << " of contiguous list " << index_second;
1299       }
1300       contiguous_ref_list_error_check_map[index_first][index_second].insert(index_in_list_first);
1301     } else {
1302       MS_LOG(WARNING) << "Contiguous list " << index_first << " associated (ref node) with two other contiguous lists: "
1303                       << contiguous_list_with_ref_index_map[index_first] << " and " << index_second;
1304     }
1305   }
1306 
1307   for (auto check_list_pair : contiguous_ref_list_error_check_map) {
1308     auto first_list = check_list_pair.first;
1309     auto index_set_map = check_list_pair.second;
1310     for (auto index_set : index_set_map) {
1311       auto second_list = index_set.first;
1312       if (contiguous_tensors_list_[first_list].size() != contiguous_tensors_list_[second_list].size()) {
1313         MS_LOG(WARNING) << "Contiguous lists " << first_list << " and " << second_list
1314                         << " considered in ref do not have the same size";
1315       }
1316       for (size_t x = 0; x < contiguous_tensors_list_[second_list].size(); x++) {
1317         if (contiguous_ref_list_error_check_map[first_list][second_list].count(x) == 0) {
1318           MS_LOG(WARNING) << "Contiguous lists " << first_list << " and " << second_list
1319                           << " considered in ref: ref pair at in-lists index " << x << " has not been considered";
1320         }
1321       }
1322     }
1323   }
1324   return contiguous_list_with_ref_index_map;
1325 }
1326 
GetRefTensorsInContiguousList()1327 std::map<size_t, size_t> Somas::GetRefTensorsInContiguousList() {
1328   // key: refnode input value: refnode output
1329   std::map<size_t, size_t> ref_tensors_in_contiguous_map;
1330   for (auto ref_node_list : ref_node_constraints_) {
1331     // Count contiguous tensors in ref list
1332     auto contiguous_in_ref_list = std::count_if(ref_node_list.begin(), ref_node_list.end(),
1333                                                 [this](size_t tid) { return tensors_map_[tid]->contiguous_; });
1334     // Keep info about contiguous and check for errors
1335     if (ref_node_list.size() > kRefNodeTensorNum && contiguous_in_ref_list > 0) {
1336       MS_LOG(WARNING) << "Ref node of size greater than two with at least one contiguous tensor in";
1337     }
1338     if (ref_node_list.size() == kRefNodeTensorNum && contiguous_in_ref_list == 1) {
1339       MS_LOG(WARNING) << "Ref node of size two with only one contiguous tensor" << ref_node_list[0] << ":"
1340                       << tensors_map_[ref_node_list[0]]->contiguous_ << ", " << ref_node_list[1] << ":"
1341                       << tensors_map_[ref_node_list[1]]->contiguous_;
1342     }
1343     if (ref_node_list.size() == kRefNodeTensorNum && contiguous_in_ref_list == SizeToLong(kRefNodeTensorNum)) {
1344       ref_tensors_in_contiguous_map[ref_node_list[0]] = ref_node_list[1];
1345     }
1346   }
1347   return ref_tensors_in_contiguous_map;
1348 }
1349 
UpdateContiguousTensorsOffset(const std::map<size_t,size_t> & contiguous_ref_list_map)1350 void Somas::UpdateContiguousTensorsOffset(const std::map<size_t, size_t> &contiguous_ref_list_map) {
1351   // Handle contiguous ref node
1352   for (auto ref_list_pair : contiguous_ref_list_map) {
1353     size_t index_first = ref_list_pair.first;
1354     size_t index_second = ref_list_pair.second;
1355     for (size_t x = 0; x < contiguous_tensors_list_[index_second].size(); x++) {
1356       tensors_map_[contiguous_tensors_list_[index_second][x]]->offset_ =
1357         tensors_map_[contiguous_tensors_list_[index_first][x]]->offset_;
1358     }
1359   }
1360 
1361   // Contiguous gaps postprocessing
1362   for (auto list : contiguous_tensors_list_) {
1363     tensors_map_[list[0]]->offset_ += kGapSize;
1364   }
1365 }
1366 
UpdateRefTensorsOffset()1367 void Somas::UpdateRefTensorsOffset() {
1368   // Ref Node Postprocessing
1369   MS_LOG(INFO) << "\nStart Solving Postprocessing for Ref Node";
1370   // Set offset for rest of ref node list (ignored by solver due to ref node preprocessing)
1371   for (auto ref_node_list : ref_node_constraints_) {
1372     for (size_t i = 1; i < ref_node_list.size(); ++i) {
1373       tensors_map_[ref_node_list[i]]->offset_ = tensors_map_[ref_node_list[0]]->offset_;
1374     }
1375   }
1376 }
1377 
UpdateRefOverlapTensorsConflicts()1378 void Somas::UpdateRefOverlapTensorsConflicts() {
1379   // Ref Overlap Preprocessing
1380   MS_LOG(INFO) << "Start Solving Preprocessing for Ref Overlap";
1381   // In ConflictComputing(), by use of ref_overlap_ flag, each tensor in a ref_overlap_list has all entries 1 in
1382   // cannot_reuse_ array Here, we allow reuse only among tensors in same list
1383   for (auto ref_overlap_list : ref_overlap_constraints_) {
1384     for (size_t tid_1 : ref_overlap_list) {
1385       for (size_t tid_2 : ref_overlap_list) {
1386         reuse_matrix_[tid_1].SetBitTrue(tid_2);
1387         reuse_matrix_[tid_2].SetBitTrue(tid_1);
1388       }
1389     }
1390   }
1391   MS_LOG(INFO) << "End Solving Preprocessing for Ref Overlap";
1392 }
1393 
UpdateRefTensorsConflict()1394 void Somas::UpdateRefTensorsConflict() {
1395   // Keep all constraints for first tensor in list
1396   for (auto ref_node_list : ref_node_constraints_) {
1397     size_t tid_0 = ref_node_list[0];
1398     for (SomasTensorPtr tensor : tensors_list_) {
1399       if (reuse_matrix_[tid_0].IsBitTrue(tensor->GetId()) == false) {
1400         continue;
1401       }
1402       for (size_t tid : ref_node_list) {
1403         if (reuse_matrix_[tid].IsBitTrue(tensor->GetId()) == false) {
1404           reuse_matrix_[tid_0].SetBitFalse(tensor->GetId());
1405           reuse_matrix_[tensor->GetId()].SetBitFalse(tid_0);
1406           break;
1407         }
1408       }
1409     }
1410     // Set rest to size 0, so that solver ignores them (if not contiguous)
1411     for (size_t i = 1; i < ref_node_list.size(); ++i) {
1412       if (!tensors_map_[ref_node_list[i]]->contiguous_) {
1413         tensors_map_[ref_node_list[i]]->aligned_size_ = 0;
1414       }
1415     }
1416   }
1417 }
1418 
GetSplitName(const std::string & scope_name) const1419 std::string Somas::GetSplitName(const std::string &scope_name) const {
1420   auto index = scope_name.rfind('/');
1421   if (index == std::string::npos) {
1422     return scope_name;
1423   } else {
1424     if (index < scope_name.size() - 1) {
1425       auto split_name = scope_name.substr(index + 1);
1426       return split_name;
1427     }
1428     return scope_name;
1429   }
1430 }
1431 
SomasInfo(bool calc_hash) const1432 std::string Somas::SomasInfo(bool calc_hash) const {
1433   std::ostringstream oss;
1434   if (!calc_hash) {
1435     DumpParameters(oss);
1436   }
1437   DumpTensors(oss);
1438   DumpNodes(oss);
1439 
1440   oss << "\n\nAll Stream Groups:\n\n";
1441   for (const auto &stream_group : streams_groups_) {
1442     for (const auto &stream : stream_group) {
1443       oss << "stm" << stream << " ";
1444     }
1445     oss << "\n";
1446   }
1447 
1448   if (!ref_node_constraints_.empty()) {
1449     oss << "\n\nAll Ref Node Info:\n\n";
1450     for (const auto &ref_in_out : ref_node_constraints_) {
1451       oss << "refnode input-output:";
1452       for (const auto &item : ref_in_out) {
1453         oss << "%" << item << "T ";
1454       }
1455       oss << "\n";
1456     }
1457   }
1458   return oss.str();
1459 }
1460 
DumpNodes(std::ostringstream & oss) const1461 void Somas::DumpNodes(std::ostringstream &oss) const {
1462   oss << "\n\nAll Nodes:\n\n";
1463   for (const auto &node : nodes_list_) {
1464     MS_EXCEPTION_IF_NULL(node);
1465     auto scope_name = node->scope_full_name_;
1466     std::string split_name = GetSplitName(scope_name);
1467     oss << "$" << node->GetId() << "\t" << split_name << "\t" << static_cast<int>(node->GetType()) << "\t";
1468     auto input_num = node->input_tensors_.size() + node->input_parameters_map_.size();
1469     oss << "inputs[";
1470     size_t tensor_index = 0;
1471     for (size_t input_index = 0; input_index < input_num; input_index++) {
1472       auto iter = node->input_parameters_map_.find(input_index);
1473       if (iter != node->input_parameters_map_.end()) {
1474         oss << "%" << iter->second->id_ << "P"
1475             << ", ";
1476       } else {
1477         oss << "%" << node->input_tensors_[tensor_index]->GetId() << "T"
1478             << ", ";
1479         tensor_index++;
1480       }
1481     }
1482 
1483     oss << "]";
1484     oss << "\toutputs[";
1485     for (const auto &out : node->output_tensors_) {
1486       MS_EXCEPTION_IF_NULL(out);
1487       oss << "%" << out->GetId() << "T"
1488           << ", ";
1489     }
1490     oss << "]";
1491     oss << "\tworkspace[";
1492     for (const auto &wk : node->workspace_tensors_) {
1493       MS_EXCEPTION_IF_NULL(wk);
1494       oss << "%" << wk->GetId() << "T"
1495           << ", ";
1496     }
1497     oss << "]";
1498     oss << "\tstreamID["
1499         << "@" << node->GetStream()->GetId() << "]\n";
1500   }
1501 }
1502 
DumpTensors(std::ostringstream & oss) const1503 void Somas::DumpTensors(std::ostringstream &oss) const {
1504   oss << "\n\nAll Tensors:\n\n";
1505   oss << "index:"
1506       << "\tsize:"
1507       << "\treal_size:"
1508       << "\toffset:"
1509       << "\taddr:"
1510       << "\ttype:"
1511       << "\tlifelong:"
1512       << "\tlife_start:"
1513       << "\tlife_end:"
1514       << "\tsource node name:\n";
1515 
1516   for (const auto &tensor : tensors_list_) {
1517     MS_EXCEPTION_IF_NULL(tensor);
1518     auto scope_name = tensor->GetSourceNode()->scope_full_name_;
1519     std::string split_name = GetSplitName(scope_name);
1520     oss << "%" << tensor->GetId() << "T"
1521         << "\t"
1522         << "#" << tensor->GetAlignedSize() << "S"
1523         << "\t"
1524         << "#" << tensor->GetOriginalSize() << "S"
1525         << "\t"
1526         << "&" << tensor->GetOffset() << ""
1527         << "\t"
1528         << "&" << static_cast<void *>(tensor->GetOffset() + mem_base_addr_) << "\t"
1529         << tensor_type_name_map[tensor->type_] << "\t" << tensor->IsLifelong() << "\t" << tensor->lifetime_.start_
1530         << "\t" << tensor->lifetime_.end_ << "\t" << split_name << "\n";
1531   }
1532 }
1533 
DumpParameters(std::ostringstream & oss) const1534 void Somas::DumpParameters(std::ostringstream &oss) const {
1535   oss << "All Parameters:\n\n";
1536   oss << "index:"
1537       << "\tsize:"
1538       << "\tstart_addr:"
1539       << "\tsource node name:"
1540       << "\tnode out index:\n";
1541 
1542   for (const auto &param : parameters_list_) {
1543     MS_EXCEPTION_IF_NULL(param);
1544     oss << "%" << param->id_ << "P"
1545         << "\t"
1546         << "#" << param->size_ << "S"
1547         << "\t"
1548         << "&" << param->addr_ << "\t" << param->source_node_name_ << "\t" << param->output_index_ << "\n";
1549   }
1550 }
1551 
DumpSomasInfoIR(const string filename) const1552 void Somas::DumpSomasInfoIR(const string filename) const { (void)Common::SaveStringToFile(filename, SomasInfo()); }
1553 
Offline() const1554 std::string Somas::Offline() const {
1555   std::ostringstream oss;
1556 
1557   for (auto tensor : tensors_list_) {
1558     MS_EXCEPTION_IF_NULL(tensor);
1559     if (tensor->IsOutputOnly() || tensor->type_ == TensorType::kRefNodeOutput) {
1560       oss << "Somas EDGE ERROR src=n" << tensor->GetSourceNode()->GetId()
1561           << ", srcstm=" << tensor->GetSourceStream()->GetId() << ", dst=nc"
1562           << ", dststm=nc"
1563           << ", workspace=0, size=" << tensor->GetOriginalSize()
1564           << ", lifelong=" << static_cast<int>(tensor->lifelong_value_) << ", tid=" << tensor->GetId()
1565           << ", start=" << tensor->lifetime_.start_ << ", end=" << tensor->lifetime_.end_ << std::endl;
1566     } else {
1567       std::map<size_t, size_t> dest_infos;
1568       for (SomasNodePtr dest_node : tensor->destinations_) {
1569         dest_infos.insert(std::make_pair(dest_node->GetId(), dest_node->GetStream()->GetId()));
1570       }
1571 
1572       for (auto dest_info : dest_infos) {
1573         oss << "Somas EDGE src=n" << tensor->GetSourceNode()->GetId()
1574             << ", srcstm=" << tensor->GetSourceStream()->GetId() << ", dst=n" << dest_info.first
1575             << ", dststm=" << dest_info.second << ", workspace=" << static_cast<int>(tensor->type_ == kWorkspace)
1576             << ", size=" << tensor->GetOriginalSize() << ", lifelong=" << static_cast<int>(tensor->lifelong_value_)
1577             << ", tid=" << tensor->GetId() << ", start=" << tensor->lifetime_.start_
1578             << ", end=" << tensor->lifetime_.end_ << std::endl;
1579       }
1580     }
1581   }
1582   for (vector<size_t> tList : contiguous_tensors_list_) {
1583     oss << "Somas CONTIGUOUS";
1584     for (size_t tid : tList) {
1585       oss << " " << tid;
1586     }
1587     oss << std::endl;
1588   }
1589   for (const auto &group : streams_groups_) {
1590     oss << "Somas GROUP";
1591     for (int64_t sid : group) {
1592       oss << " " << sid;
1593     }
1594     oss << std::endl;
1595   }
1596   return oss.str();
1597 }
1598 
DumpOfflineIR(const string filename) const1599 void Somas::DumpOfflineIR(const string filename) const {
1600   MS_LOG(INFO) << "Printing somas-log-from-graph log: " << filename;
1601   (void)Common::SaveStringToFile(filename, Offline());
1602 }
1603 
SomasMemory() const1604 std::string Somas::SomasMemory() const {
1605   std::ostringstream oss;
1606 
1607   std::map<size_t, size_t> mem_map;
1608   for (auto tensor : tensors_list_) {
1609     MS_EXCEPTION_IF_NULL(tensor);
1610     mem_map[tensor->GetOffset()] = 0;
1611   }
1612 
1613   size_t num = 0;
1614   for (auto iter = mem_map.begin(); iter != mem_map.end(); ++iter, ++num) {
1615     iter->second = num;
1616   }
1617 
1618   std::map<size_t, std::map<size_t, SomasTensorPtr>> mem_list;
1619 
1620   for (const auto &output_tensor : tensors_list_) {
1621     MS_EXCEPTION_IF_NULL(output_tensor);
1622     size_t key = output_tensor->offset_;
1623     auto iter = mem_list.find(key);
1624     if (iter == mem_list.end()) {
1625       std::map<size_t, SomasTensorPtr> id_tensor_map;
1626       id_tensor_map[output_tensor->GetId()] = output_tensor;
1627       mem_list[key] = id_tensor_map;
1628     } else {
1629       iter->second[output_tensor->GetId()] = output_tensor;
1630     }
1631   }
1632 
1633   oss << "mem_id:"
1634       << "\tstart_offset:"
1635       << "\tend_offset:"
1636       << "\ttensor_id:"
1637       << "\torigin_size:"
1638       << "\talign_size:"
1639       << "\tstart_addr:"
1640       << "\tend_addr:"
1641       << "\ttype:"
1642       << "\tsrc_node:"
1643       << "\tsrc_stm_id:"
1644       << "lifetime_start\t"
1645       << "lifetime_end\n";
1646 
1647   for (const auto &mem : mem_list) {
1648     auto id_tensor_map = mem.second;
1649     for (const auto &id_tensor : id_tensor_map) {
1650       auto place_tensor = id_tensor.second;
1651       MS_EXCEPTION_IF_NULL(place_tensor);
1652       std::string scope_name;
1653       int64_t src_stm_id = 0xffff;
1654       if (place_tensor->GetSourceNode() != nullptr) {
1655         scope_name = place_tensor->GetSourceNode()->scope_full_name_;
1656         src_stm_id = place_tensor->GetSourceNode()->GetStream()->GetId();
1657       } else {
1658         scope_name = "Somas Tensor";
1659       }
1660 
1661       std::string split_name = GetSplitName(scope_name);
1662       oss << "#" << mem_map[place_tensor->GetOffset()] << "\t" << place_tensor->GetOffset() << "\t"
1663           << place_tensor->GetOffset() + place_tensor->GetAlignedSize() << "\t%" << place_tensor->GetId() << "T\t"
1664           << place_tensor->GetOriginalSize() << "\t" << place_tensor->GetAlignedSize() << "\t&"
1665           << static_cast<void *>(place_tensor->GetOffset() + mem_base_addr_) << "\t&"
1666           << static_cast<void *>(place_tensor->GetOffset() + mem_base_addr_ + place_tensor->GetAlignedSize()) << "\t"
1667           << tensor_type_name_map[place_tensor->type_] << "\t" << split_name << "\tstm" << src_stm_id << "\t"
1668           << place_tensor->lifetime_.start_ << "\t" << place_tensor->lifetime_.end_ << "\n";
1669     }
1670   }
1671   return oss.str();
1672 }
1673 
DumpSomasMemoryIR(const string & filename) const1674 void Somas::DumpSomasMemoryIR(const string &filename) const { (void)Common::SaveStringToFile(filename, SomasMemory()); }
1675 
CalcLowerBound() const1676 size_t Somas::CalcLowerBound() const {
1677   size_t max_node_id = std::accumulate(tensors_list_.begin(), tensors_list_.end(), 0, [](size_t max_id, auto tensor) {
1678     return std::max(max_id, tensor->lifetime_.end_);
1679   });
1680 
1681   std::map<size_t, size_t> lifetime_lb;
1682   for (size_t time = 0; time <= max_node_id; time++) {
1683     lifetime_lb[time] = 0;
1684   }
1685 
1686   size_t lower, upper;
1687   for (const auto &tensor : tensors_list_) {
1688     MS_EXCEPTION_IF_NULL(tensor);
1689     if (tensor->lifelong_value_ == kLifeLongGraphAll) {
1690       lower = 0;
1691       upper = max_node_id;
1692     } else {
1693       lower = tensor->lifetime_.start_;
1694       upper = tensor->lifetime_.end_;
1695     }
1696 
1697     for (size_t time = lower; time <= upper; time++) {
1698       lifetime_lb[time] += tensor->GetAlignedSize();
1699     }
1700   }
1701 
1702   size_t max_lifetime = 0;
1703   for (size_t time = 0; time <= max_node_id; time++) {
1704     if (max_lifetime < lifetime_lb[time]) {
1705       max_lifetime = lifetime_lb[time];
1706     }
1707   }
1708   return max_lifetime;
1709 }
1710 
GenGraphStatisticInfo()1711 void Somas::GenGraphStatisticInfo() {
1712   lower_bound_ = CalcLowerBound();
1713   for (const auto &tensor : tensors_list_) {
1714     MS_EXCEPTION_IF_NULL(tensor);
1715     upper_bound_ += tensor->aligned_size_;
1716     if (tensor->type_ == kWorkspace) {
1717       workspace_total_size_ += tensor->aligned_size_;
1718     }
1719     if (tensor->lifelong_value_ == kLifeLongGraphAll) {
1720       lifelong_all_total_size_ += tensor->aligned_size_;
1721     } else if (tensor->lifelong_value_ == kLifeLongGraphStart) {
1722       lifelong_start_total_size_ += tensor->aligned_size_;
1723     } else if (tensor->lifelong_value_ == kLifeLongGraphEnd) {
1724       lifelong_end_total_size_ += tensor->aligned_size_;
1725     }
1726   }
1727 
1728   const double giga = 1024. * 1024. * 1024.;
1729   MS_LOG(INFO) << "Lower Bound: " << lower_bound_ << " (" << lower_bound_ / giga
1730                << " GB), Upper Bound: " << upper_bound_ << " (" << upper_bound_ / giga << " GB)";
1731 
1732   MS_LOG(INFO) << "\nTotal Dynamic Size (Upper Bound):\t" << upper_bound_ << "\n"
1733                << "Theoretical Optimal Size (Lower Bound):\t" << lower_bound_ << "\n"
1734                << "Total Workspace Size:\t" << workspace_total_size_ << "\n"
1735                << "Total Communication Input Tensor Size:\t" << comm_input_total_size_ << "\n"
1736                << "Total Communication Output Tensor Size:\t" << comm_output_total_size_ << "\n"
1737                << "Total LifeLong All Tensor Size:\t" << lifelong_all_total_size_ << "\n"
1738                << "Total LifeLong Start Tensor Size:\t" << lifelong_start_total_size_ << "\n"
1739                << "Total LifeLong End Tensor Size:\t" << lifelong_end_total_size_ << "\n"
1740                << "Reused Size(Allocate Size):\t" << GetTotalMemSize() << "\n\n\n";
1741 }
1742 
GetNodeOutputPtr(const AnfNodePtr & node,size_t index) const1743 uint8_t *Somas::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const {
1744   MS_EXCEPTION_IF_NULL(node);
1745   auto key = node.get();
1746   auto iter = nodes_map_.find(key);
1747   uint8_t *ptr = nullptr;
1748   if (iter != nodes_map_.end()) {
1749     auto &somas_node = iter->second.at(0);
1750     MS_EXCEPTION_IF_NULL(somas_node);
1751     if (index >= somas_node->output_tensors_.size()) {
1752       MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's output size:["
1753                         << somas_node->output_tensors_.size() << "]";
1754     }
1755     auto output_tensor = somas_node->output_tensors_[index];
1756     ptr = mem_base_addr_ + output_tensor->offset_;
1757   } else {
1758     MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in nodes_map";
1759   }
1760   return ptr;
1761 }
1762 
GetNodeWorkSpacePtr(const AnfNodePtr & node,size_t index) const1763 uint8_t *Somas::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const {
1764   MS_EXCEPTION_IF_NULL(node);
1765   auto key = node.get();
1766   auto iter = nodes_map_.find(key);
1767   uint8_t *ptr = nullptr;
1768   if (iter != nodes_map_.end()) {
1769     auto &somas_node = iter->second.at(0);
1770     MS_EXCEPTION_IF_NULL(somas_node);
1771     if (index >= somas_node->workspace_tensors_.size()) {
1772       MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:["
1773                         << somas_node->workspace_tensors_.size() << "]";
1774     }
1775     auto workspace_tensor = somas_node->workspace_tensors_[index];
1776     ptr = mem_base_addr_ + workspace_tensor->offset_;
1777   }
1778   return ptr;
1779 }
1780 #ifndef ENABLE_SECURITY
ConvertToProfilingNode(uint32_t graph_id)1781 void Somas::ConvertToProfilingNode(uint32_t graph_id) {
1782 #ifdef ENABLE_D
1783   auto graph_node = MemoryProfiling::GetInstance().GetGraphMemoryNode(graph_id);
1784   if (graph_node == nullptr) {
1785     graph_node = MemoryProfiling::GetInstance().AddGraphMemoryNode(graph_id);
1786     MS_LOG(INFO) << "Add graph memory node for dynamic memory profiling, graph id is " << graph_id;
1787   }
1788 
1789   for (const auto &tensor : tensors_list_) {
1790     TensorMemory tensor_memory;
1791     tensor_memory.SetTensorId(tensor->GetId());
1792     tensor_memory.SetAlignedSize(tensor->GetAlignedSize());
1793     tensor_memory.SetType(tensor_type_name_map[tensor->type_]);
1794     tensor_memory.SetLifeStart(tensor->lifetime_.start_);
1795     tensor_memory.SetLifeEnd(tensor->lifetime_.end_);
1796     tensor_memory.SetLifeLong(life_long_name_map[tensor->lifelong_value_]);
1797     graph_node->AddTensorMemory(tensor_memory);
1798   }
1799 
1800   for (const auto &node : nodes_list_) {
1801     NodeMemory node_memory;
1802     std::string name = GetSplitName(node->scope_full_name_);
1803     node_memory.SetNodeName(name);
1804     node_memory.SetNodeId(node->GetId());
1805     for (const auto &input_tensor : node->input_tensors_) {
1806       node_memory.AddInputTensorId(input_tensor->GetId());
1807     }
1808     for (const auto &output_tensor : node->output_tensors_) {
1809       node_memory.AddOutputTensorId(output_tensor->GetId());
1810     }
1811     for (const auto &workspace_tensor : node->workspace_tensors_) {
1812       node_memory.AddWorkSpaceTensorId(workspace_tensor->GetId());
1813     }
1814     graph_node->AddNodeMemory(node_memory);
1815   }
1816 #endif
1817 }
1818 #endif
1819 }  // namespace somas
1820 }  // namespace mindspore
1821