• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/parallel/operator_info.h"
19 #include <algorithm>
20 #include "tools/optimizer/parallel/split_strategy.h"
21 #include "ops/auto_generate/gen_lite_ops.h"
22 #include "ops/tuple_get_item.h"
23 #include "include/common/utils/utils.h"
24 #include "include/errorcode.h"
25 #include "nnacl/op_base.h"
26 #include "ops/op_utils.h"
27 #include "src/common/log_util.h"
28 
29 namespace mindspore {
30 namespace opt {
is_any_none(const std::vector<int64_t> & split)31 bool is_any_none(const std::vector<int64_t> &split) {
32   return std::any_of(split.begin(), split.end(), [](int64_t v) { return v == static_cast<int64_t>(NoSplit); });
33 }
34 
is_any_not_none(const std::vector<int64_t> & split)35 bool is_any_not_none(const std::vector<int64_t> &split) {
36   return std::any_of(split.begin(), split.end(), [](int64_t v) { return v != static_cast<int64_t>(NoSplit); });
37 }
38 
CreateFakeAbstractTensor() const39 std::shared_ptr<abstract::AbstractTensor> OperatorInfo::CreateFakeAbstractTensor() const {
40   auto type_ptr = TypeIdToType(operator_type_id_);
41   std::vector<int64_t> shape_vector;
42   return std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
43 }
44 
CheckSplitResult(const AnfNodePtr & result_anf_node,const std::vector<AnfNodePtr> & split_results,int target_output_num)45 int OperatorInfo::CheckSplitResult(const AnfNodePtr &result_anf_node, const std::vector<AnfNodePtr> &split_results,
46                                    int target_output_num) {
47   if ((result_anf_node == nullptr) || (split_results.size() != IntToSize(target_output_num))) {
48     MS_LOG(ERROR) << name_ << " : Make split cnode failed.";
49     return lite::RET_ERROR;
50   }
51   return lite::RET_OK;
52 }
53 
Init(const FuncGraphPtr & func_graph,const CNodePtr & cnode,int32_t fmk_type)54 void OperatorInfo::Init(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int32_t fmk_type) {
55   func_graph_ = func_graph;
56   cnode_ = cnode;
57   fmk_type_ = fmk_type;
58   parallel_output_nodes_.clear();
59 }
60 
SetCNodeBackend()61 int OperatorInfo::SetCNodeBackend() {
62   for (size_t i = 0; i < strategy_.dev_num; ++i) {
63     lite::DeviceType dt_type;
64     MS_CHECK_LT(i, strategy_.dev_types.size(), lite::RET_ERROR);
65     std::string type = strategy_.dev_types[i];
66     MS_CHECK_LT(i, parallel_output_nodes_.size(), lite::RET_ERROR);
67     auto post_node = parallel_output_nodes_[i];
68     MS_CHECK_TRUE_RET(post_node != nullptr, lite::RET_ERROR);
69     auto post_cnode = post_node->cast<CNodePtr>();
70     MS_CHECK_TRUE_RET(post_cnode != nullptr, lite::RET_ERROR);
71     auto cnode = post_cnode->input(1)->cast<CNodePtr>();
72     MS_CHECK_TRUE_RET(cnode != nullptr, lite::RET_ERROR);
73     auto type_iter = kSupportSplitedDevices.find(type);
74     if (type_iter == kSupportSplitedDevices.end()) {
75       MS_LOG(ERROR) << "SetCnodeBackend: unknown device type.";
76       return lite::RET_ERROR;
77     }
78     if (type_iter->second == lite::DeviceType::DT_NPU) {
79       MS_LOG(ERROR) << "SetCnodeBackend: unsupported device type npu.";
80       return lite::RET_ERROR;
81     }
82     dt_type = type_iter->second;
83     cnode->AddAttr(mindspore::ops::kDeviceType, MakeValue(static_cast<int>(dt_type)));
84   }
85   return lite::RET_OK;
86 }
87 
CheckStrategyValue()88 int OperatorInfo::CheckStrategyValue() {
89   auto strategy_size = strategy_.strategys.size();
90   for (size_t index = 0; index < strategy_size; ++index) {
91     auto strategy = strategy_.strategys[index];
92     for (const auto &s : strategy) {
93       if (s.size() != IntToSize(strategy_.dev_num)) {
94         MS_LOG(ERROR) << "Strategy split number:" << s.size()
95                       << " is not equal to device number: " << strategy_.dev_num;
96         return lite::RET_ERROR;
97       }
98       if (is_any_not_none(s) && is_any_none(s)) {
99         MS_LOG(ERROR) << "Strategy split number must be all zero or all non-zero: " << s;
100         return lite::RET_ERROR;
101       }
102     }
103   }
104   return lite::RET_OK;
105 }
106 
CreateMultipleOutputsOfAnfNode(const AnfNodePtr & node,size_t output_num,std::vector<AnfNodePtr> * outputs)107 int OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num,
108                                                  std::vector<AnfNodePtr> *outputs) {
109   MS_ASSERT(node != nullptr && outputs != nullptr);
110   AbstractBasePtrList ptr_list;
111   auto cnode = node->cast<CNodePtr>();
112   if (cnode == nullptr) {
113     MS_LOG(ERROR) << name_ << " : Failed to get CNode.";
114     return lite::RET_ERROR;
115   }
116   for (size_t i = 0; i < output_num; ++i) {
117     auto idx = NewValueNode(SizeToLong(i));
118     auto index = std::make_shared<Int64Imm>(SizeToLong(i));
119     MS_CHECK_TRUE_RET(index != nullptr, lite::RET_ERROR);
120     auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(index);
121     MS_CHECK_TRUE_RET(abstract_scalar != nullptr, lite::RET_ERROR);
122     idx->set_abstract(abstract_scalar);
123     auto tuple_node = std::make_shared<ops::TupleGetItem>();
124     if (tuple_node == nullptr) {
125       MS_LOG(ERROR) << name_ << " : Failed to create tuple node.";
126       return lite::RET_ERROR;
127     }
128     auto tuple_prim_c = tuple_node->GetPrim();
129     if (tuple_prim_c == nullptr) {
130       MS_LOG(ERROR) << name_ << " : Failed to create tuple node primitive.";
131       return lite::RET_ERROR;
132     }
133     auto tuple_getitem = func_graph_->NewCNode({NewValueNode(tuple_prim_c), node, idx});
134     if (tuple_getitem == nullptr) {
135       MS_LOG(ERROR) << name_ << " : Failed to create output nodes.";
136       return lite::RET_ERROR;
137     }
138     tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem" + std::to_string(i));
139     outputs->push_back(tuple_getitem);
140     auto abstract_tensor = CreateFakeAbstractTensor();
141     ptr_list.push_back(abstract_tensor);
142   }
143   node->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
144   return lite::RET_OK;
145 }
146 
CreateConcateNode(const CNodePtr & orig_node,const std::vector<AnfNodePtr> & input_nodes,int32_t concat_dim,size_t input_nodes_num)147 AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
148                                            int32_t concat_dim, size_t input_nodes_num) {
149   MS_ASSERT(orig_node != nullptr);
150   if (input_nodes.size() != input_nodes_num) {
151     MS_LOG(ERROR) << name_ << " : Input nodes size of concat is not equal to input nodes number.";
152     return nullptr;
153   }
154   auto concat_prim = std::make_shared<ops::Concat>();
155   MS_CHECK_TRUE_RET(concat_prim != nullptr, nullptr);
156   auto concat_prim_c = concat_prim->GetPrim();
157   MS_CHECK_TRUE_RET(concat_prim_c != nullptr, nullptr);
158   concat_prim->set_axis(concat_dim);
159   auto value_node = NewValueNode(concat_prim_c);
160   MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
161   std::vector<AnfNodePtr> concat_inputs = {value_node};
162   (void)std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(concat_inputs),
163                        [](const AnfNodePtr &p) { return p->cast<CNodePtr>()->input(1); });
164   auto concat_cnode = func_graph_->NewCNode(concat_inputs);
165   if (concat_cnode == nullptr) {
166     MS_LOG(ERROR) << name_ << " : Failed to create concat node.";
167     return nullptr;
168   }
169   concat_cnode->set_fullname_with_scope("Concat_" + name_);
170   concat_cnode->set_scope(orig_node->scope());
171   std::vector<AnfNodePtr> outputs;
172   (void)CreateMultipleOutputsOfAnfNode(concat_cnode, 1, &outputs);
173   return concat_cnode;
174 }
175 
CreateReduceNode(const CNodePtr & orig_node,const std::vector<AnfNodePtr> & input_nodes,size_t input_nodes_num)176 AnfNodePtr OperatorInfo::CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
177                                           size_t input_nodes_num) {
178   MS_ASSERT(orig_node != nullptr);
179   if (input_nodes.size() != input_nodes_num) {
180     MS_LOG(ERROR) << name_ << " : Input nodes size of reduce is not equal to input nodes number.";
181     return nullptr;
182   }
183   // addup inputs element-wise
184   auto addn_prim = std::make_shared<ops::AddN>();
185   MS_CHECK_TRUE_RET(addn_prim != nullptr, nullptr);
186   auto addn_prim_c = addn_prim->GetPrim();
187   MS_CHECK_TRUE_RET(addn_prim_c != nullptr, nullptr);
188   auto value_node = NewValueNode(addn_prim_c);
189   MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
190   std::vector<AnfNodePtr> addn_inputs = {value_node};
191   (void)std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(addn_inputs),
192                        [](const AnfNodePtr &p) { return p->cast<CNodePtr>()->input(1); });
193   auto addn_cnode = func_graph_->NewCNode(addn_inputs);
194   if (addn_cnode == nullptr) {
195     MS_LOG(ERROR) << name_ << " : Failed to create concat node.";
196     return nullptr;
197   }
198   addn_cnode->set_fullname_with_scope("AddN_" + name_);
199   addn_cnode->set_scope(orig_node->scope());
200   return addn_cnode;
201 }
202 
DoSplit()203 int OperatorInfo::DoSplit() {
204   if (CheckStrategyValue() != lite::RET_OK) {
205     MS_LOG(ERROR) << name_ << ": Invalid strategy values.";
206     return lite::RET_ERROR;
207   }
208   if (CheckStrategy(strategy_) != lite::RET_OK) {
209     MS_LOG(ERROR) << name_ << ": Check strategys failed.";
210     return lite::RET_ERROR;
211   }
212   if (InferParallelCNodes() != lite::RET_OK) {
213     MS_LOG(ERROR) << name_ << ": InferParallelCNodes failed.";
214     return lite::RET_ERROR;
215   }
216   if (SetCNodeBackend() != lite::RET_OK) {
217     MS_LOG(ERROR) << name_ << ": SetCnodeBackend failed.";
218     return lite::RET_ERROR;
219   }
220   if (InferReplaceOp() != lite::RET_OK) {
221     MS_LOG(ERROR) << name_ << ": InferForwardOps failed.";
222     return lite::RET_ERROR;
223   }
224   return lite::RET_OK;
225 }
226 }  // namespace opt
227 }  // namespace mindspore
228