/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/allreduce_fusion/allreduce_node.h"
#include <queue>
#include "frontend/parallel/tensor_layout/tensor_layout.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
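// Record next_node as a direct successor of this allreduce node.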
Status AllreduceNode::AddNext(const AllreduceNodePtr &next_node) {
  if (next_node == nullptr) {
    MS_LOG(ERROR) << "next_node is nullptr!";
    return FAILED;
  }
  next_.emplace_back(next_node);
  return SUCCESS;
}

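// Record prev_node as a direct predecessor. The dependent feature size contributed by prev_node
// (its own depend_feat_size plus the edge distance) is added to this node and then propagated to
// all transitive successors via breadth-first traversal; *max tracks the largest value seen.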
Status AllreduceNode::AddPrev(const AllreduceNodePtr &prev_node, double dist, double *max) {
  if (prev_node == nullptr) {
    MS_LOG(ERROR) << "prev_node is nullptr!";
    return FAILED;
  }
  if (dist <= 0) {
    MS_LOG(ERROR) << "dist must be positive! dist: " << dist;
    return FAILED;
  }
  prev_.emplace_back(prev_node);
  double add_dist = prev_node->depend_feat_size() + dist;
  depend_feat_size_ += add_dist;
  if (depend_feat_size_ > *max) {
    *max = depend_feat_size_;
  }
  std::queue<AllreduceNodePtr> next_queue;
  for (auto &next : next_) {
    next_queue.push(next);
  }
  while (!next_queue.empty()) {
    auto ele = next_queue.front();
    ele->AddDependFeatSize(add_dist);
    if (ele->depend_feat_size() > *max) {
      *max = ele->depend_feat_size();
    }
    for (auto &next : ele->next()) {
      next_queue.push(next);
    }
    next_queue.pop();
  }
  return SUCCESS;
}

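// Bind this allreduce node to its corresponding CNode in the graph.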
Status AllreduceNode::Init(const CNodePtr &cnode_ptr) {
  if (cnode_ptr == nullptr) {
    MS_LOG(ERROR) << "cnode_ptr is nullptr!";
    return FAILED;
  }
  cnode_ptr_ = cnode_ptr;
  return SUCCESS;
}

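// Register a Parameter handled by this allreduce. Its size is taken from the parameter's
// TensorLayout slice shape and accumulated into curr_para_size_; duplicates are ignored.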
Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) {
  if (node_ptr == nullptr) {
    MS_LOG(ERROR) << "node_ptr is nullptr!";
    return FAILED;
  }
  if (!node_ptr->isa<Parameter>()) {
    MS_LOG(ERROR) << "node_ptr is not a ParameterPtr!";
    return FAILED;
  }
  auto para_ptr = node_ptr->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(para_ptr);
  auto layout_ptr = para_ptr->user_data<TensorLayout>();
  if (layout_ptr == nullptr) {
    MS_LOG(ERROR) << "layout_ptr is nullptr!";
    return FAILED;
  }
  auto emplace_return = paras_.emplace(node_ptr);
  if (emplace_return.second) {
    double para_size = static_cast<double>(layout_ptr->slice_shape().size());
    curr_para_size_ += para_size;
    para_size_map_[node_ptr] = para_size;
  } else {
    MS_LOG(INFO) << "node already exists!";
  }
  return SUCCESS;
}

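// Unregister a previously added Parameter and subtract its recorded size from curr_para_size_.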
Status AllreduceNode::RemovePara(const AnfNodePtr &node_ptr) {
  if (node_ptr == nullptr) {
    MS_LOG(ERROR) << "node_ptr is nullptr!";
    return FAILED;
  }
  auto erase_num = paras_.erase(node_ptr);
  if (erase_num == 0) {
    MS_LOG(ERROR) << "para not found!";
    return FAILED;
  }
  curr_para_size_ -= para_size_map_[node_ptr];
  return SUCCESS;
}

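// Log the node's CNode, its registered parameters with their sizes, and the accumulated counters.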
void AllreduceNode::ToString() const {
  MS_LOG(INFO) << "cnode: " << cnode_ptr_->DebugString() << " para size: " << paras_.size();
  for (auto &para : paras_) {
    MS_LOG(INFO) << "para name: " << para->fullname_with_scope() << " size: " << para_size_map_.at(para);
  }
  MS_LOG(INFO) << "depend_feat_size: " << depend_feat_size_ << " curr_para_size: " << curr_para_size_;
}
}  // namespace parallel
}  // namespace mindspore