/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/optimizer/parallel/operator_info.h"
#include <algorithm>
#include "tools/converter/ops/ops_def.h"
#include "tools/optimizer/parallel/split_strategy.h"
#include "ops/concat.h"
#include "ops/addn.h"
#include "utils/utils.h"
#include "base/core_ops.h"
#include "include/errorcode.h"
#include "nnacl/op_base.h"

namespace mindspore {
namespace opt {
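// Returns true if any element of the split vector equals NoSplit.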
bool is_any_none(const std::vector<int64_t> &split) {
  return std::any_of(split.begin(), split.end(), [](int64_t v) { return v == static_cast<int64_t>(NoSplit); });
}

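// Returns true if any element of the split vector differs from NoSplit. Note
// that a mixed vector satisfies both predicates; CheckStrategyValue below uses
// that to reject partially specified strategies.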
bool is_any_not_none(const std::vector<int64_t> &split) {
  return std::any_of(split.begin(), split.end(), [](int64_t v) { return v != static_cast<int64_t>(NoSplit); });
}

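// Builds a placeholder abstract: a tensor of the operator's data type with an
// empty shape vector. CreateMultipleOutputsOfAnfNode below uses these to give
// freshly split nodes a well-formed output abstract.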
std::shared_ptr<abstract::AbstractTensor> OperatorInfo::CreateFakeAbstractTensor() const {
  auto type_ptr = TypeIdToType(operator_type_id_);
  std::vector<int64_t> shape_vector;
  return std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
}

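// Validates a freshly built split cnode: the node must exist and the number of
// split results must match the expected output count.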
int OperatorInfo::CheckSplitResult(const AnfNodePtr &result_anf_node, const std::vector<AnfNodePtr> &split_results,
                                   int target_output_num) {
  if ((result_anf_node == nullptr) || (split_results.size() != IntToSize(target_output_num))) {
    MS_LOG(ERROR) << name_ << " : Make split cnode failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

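// Binds this OperatorInfo to the graph and cnode it is about to split, records
// the framework type, and clears any outputs left over from a previous run.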
void OperatorInfo::Init(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int32_t fmk_type) {
  func_graph_ = func_graph;
  cnode_ = cnode;
  fmk_type_ = fmk_type;
  parallel_output_nodes_.clear();
}

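// Tags each parallel branch with the device type requested by the split
// strategy. Each entry of parallel_output_nodes_ is expected to be a getter
// (e.g. TupleGetItem) whose first input is the branch cnode; the device type
// is attached to that cnode as the kDeviceType attribute. Unknown device
// types are rejected, and NPU is explicitly unsupported here.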
int OperatorInfo::SetCNodeBackend() {
  for (size_t i = 0; i < strategy_.dev_num; ++i) {
    lite::DeviceType dt_type;
    MS_CHECK_LT(i, strategy_.dev_types.size(), lite::RET_ERROR);
    std::string type = strategy_.dev_types[i];
    MS_CHECK_LT(i, parallel_output_nodes_.size(), lite::RET_ERROR);
    auto post_node = parallel_output_nodes_[i];
    MS_CHECK_TRUE_RET(post_node != nullptr, lite::RET_ERROR);
    auto post_cnode = post_node->cast<CNodePtr>();
    MS_CHECK_TRUE_RET(post_cnode != nullptr, lite::RET_ERROR);
    auto cnode = post_cnode->input(1)->cast<CNodePtr>();
    MS_CHECK_TRUE_RET(cnode != nullptr, lite::RET_ERROR);
    auto type_iter = kSupportSplitedDevices.find(type);
    if (type_iter == kSupportSplitedDevices.end()) {
      MS_LOG(ERROR) << "SetCNodeBackend: unknown device type.";
      return lite::RET_ERROR;
    }
    if (type_iter->second == lite::DeviceType::DT_NPU) {
      MS_LOG(ERROR) << "SetCNodeBackend: unsupported device type npu.";
      return lite::RET_ERROR;
    }
    dt_type = type_iter->second;
    cnode->AddAttr(mindspore::ops::kDeviceType, MakeValue(static_cast<int>(dt_type)));
  }
  return lite::RET_OK;
}

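// Generic sanity checks on the split strategy: every per-input split vector
// must have exactly dev_num entries, and those entries must be all NoSplit or
// all non-NoSplit; a mixed vector is ambiguous and rejected.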
int OperatorInfo::CheckStrategyValue() {
  auto strategy_size = strategy_.strategys.size();
  for (size_t index = 0; index < strategy_size; ++index) {
    auto strategy = strategy_.strategys[index];
    for (const auto &s : strategy) {
      if (s.size() != IntToSize(strategy_.dev_num)) {
        MS_LOG(ERROR) << "Strategy split number:" << s.size()
                      << " is not equal to device number: " << strategy_.dev_num;
        return lite::RET_ERROR;
      }
      if (is_any_not_none(s) && is_any_none(s)) {
        MS_LOG(ERROR) << "Strategy split number must be all zero or all non-zero: " << s;
        return lite::RET_ERROR;
      }
    }
  }
  return lite::RET_OK;
}

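// Exposes output_num outputs of `node` by wrapping each one in a TupleGetItem
// cnode. The getters are appended to *outputs, and `node` receives an
// AbstractTuple of placeholder tensor abstracts (CreateFakeAbstractTensor) so
// downstream passes see the right output arity.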
int OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num,
                                                 std::vector<AnfNodePtr> *outputs) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(outputs);
  AbstractBasePtrList ptr_list;
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    MS_LOG(ERROR) << name_ << " : Failed to get CNode.";
    return lite::RET_ERROR;
  }
  for (size_t i = 0; i < output_num; ++i) {
    auto idx = NewValueNode(SizeToInt(i));
    auto index = std::make_shared<Int32Imm>(SizeToInt(i));
    MS_CHECK_TRUE_RET(index != nullptr, lite::RET_ERROR);
    auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(index);
    MS_CHECK_TRUE_RET(abstract_scalar != nullptr, lite::RET_ERROR);
    idx->set_abstract(abstract_scalar);
    auto tuple_getitem = func_graph_->NewCNode({NewValueNode(std::make_shared<lite::TupleGetItem>()), node, idx});
    if (tuple_getitem == nullptr) {
      MS_LOG(ERROR) << name_ << " : Failed to create output nodes.";
      return lite::RET_ERROR;
    }
    tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem" + std::to_string(i));
    outputs->push_back(tuple_getitem);
    auto abstract_tensor = CreateFakeAbstractTensor();
    ptr_list.push_back(abstract_tensor);
  }
  node->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
  return lite::RET_OK;
}

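// Stitches the parallel branches back together with a Concat along
// concat_dim. Each entry of input_nodes is expected to be a getter whose
// first input is the branch cnode, which becomes a Concat operand; a single
// TupleGetItem is then hung off the Concat so callers can consume its output
// uniformly.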
AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
                                           int32_t concat_dim, size_t input_nodes_num) {
  MS_EXCEPTION_IF_NULL(orig_node);
  if (input_nodes.size() != input_nodes_num) {
    MS_LOG(ERROR) << name_ << " : Input nodes size of concat is not equal to input nodes number.";
    return nullptr;
  }
  auto concat_prim = std::make_shared<ops::Concat>();
  MS_CHECK_TRUE_RET(concat_prim != nullptr, nullptr);
  concat_prim->set_axis(concat_dim);
  auto value_node = NewValueNode(concat_prim);
  MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
  std::vector<AnfNodePtr> concat_inputs = {value_node};
  (void)std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(concat_inputs),
                       [](const AnfNodePtr &p) { return p->cast<CNodePtr>()->input(1); });
  auto concat_cnode = func_graph_->NewCNode(concat_inputs);
  if (concat_cnode == nullptr) {
    MS_LOG(ERROR) << name_ << " : Failed to create concat node.";
    return nullptr;
  }
  concat_cnode->set_fullname_with_scope("Concat_" + name_);
  concat_cnode->set_scope(orig_node->scope());
  std::vector<AnfNodePtr> outputs;
  (void)CreateMultipleOutputsOfAnfNode(concat_cnode, 1, &outputs);
  return concat_cnode;
}

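// Merges the parallel branches by summing them element-wise with AddN,
// intended for splits where each branch yields a partial sum (for example, a
// convolution split along its input channels). The operands are gathered the
// same way as in CreateConcateNode.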
AnfNodePtr OperatorInfo::CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
                                          size_t input_nodes_num) {
  MS_EXCEPTION_IF_NULL(orig_node);
  if (input_nodes.size() != input_nodes_num) {
    MS_LOG(ERROR) << name_ << " : Input nodes size of reduce is not equal to input nodes number.";
    return nullptr;
  }
  // Add up the inputs element-wise.
  auto addn_prim = std::make_shared<ops::AddN>();
  MS_CHECK_TRUE_RET(addn_prim != nullptr, nullptr);
  auto value_node = NewValueNode(addn_prim);
  MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
  std::vector<AnfNodePtr> addn_inputs = {value_node};
  (void)std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(addn_inputs),
                       [](const AnfNodePtr &p) { return p->cast<CNodePtr>()->input(1); });
  auto addn_cnode = func_graph_->NewCNode(addn_inputs);
  if (addn_cnode == nullptr) {
    MS_LOG(ERROR) << name_ << " : Failed to create addn node.";
    return nullptr;
  }
  addn_cnode->set_fullname_with_scope("AddN_" + name_);
  addn_cnode->set_scope(orig_node->scope());
  return addn_cnode;
}

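// Drives the whole transformation: run the generic strategy checks, then the
// operator-specific CheckStrategy, build the parallel cnodes, tag each branch
// with its backend, and finally let InferReplaceOp stitch the results back
// into the graph.
//
// A minimal usage sketch; `SomeOperatorInfo` and its constructor arguments are
// hypothetical, for illustration only:
//
//   auto op_info = std::make_shared<SomeOperatorInfo>(/* name, strategy */);
//   op_info->Init(func_graph, cnode, fmk_type);
//   if (op_info->DoSplit() != lite::RET_OK) {
//     // keep the original, unsplit operator
//   }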
int OperatorInfo::DoSplit() {
  if (CheckStrategyValue() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy values.";
    return lite::RET_ERROR;
  }
  if (CheckStrategy(strategy_) != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": Check strategy failed.";
    return lite::RET_ERROR;
  }
  if (InferParallelCNodes() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": InferParallelCNodes failed.";
    return lite::RET_ERROR;
  }
  if (SetCNodeBackend() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": SetCNodeBackend failed.";
    return lite::RET_ERROR;
  }
  if (InferReplaceOp() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": InferReplaceOp failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}
}  // namespace opt
}  // namespace mindspore