/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/optimizer/parallel/operator_info.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "tools/converter/ops/ops_def.h"
#include "tools/optimizer/parallel/split_strategy.h"
#include "ops/concat.h"
#include "ops/addn.h"
#include "utils/utils.h"
#include "base/core_ops.h"
#include "include/errorcode.h"
#include "nnacl/op_base.h"

namespace mindspore {
namespace opt {
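// True if any entry of a split strategy equals NoSplit (0), i.e. at least one device receives no slice.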
bool is_any_none(const std::vector<int64_t> &split) {
  return std::any_of(split.begin(), split.end(), [](int64_t v) { return v == static_cast<int64_t>(NoSplit); });
}

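// True if any entry of a split strategy is non-zero, i.e. at least one device receives a slice.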
bool is_any_not_none(const std::vector<int64_t> &split) {
  return std::any_of(split.begin(), split.end(), [](int64_t v) { return v != static_cast<int64_t>(NoSplit); });
}

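// Builds a placeholder AbstractTensor carrying only the operator's data type and an empty shape; the
// real shape is presumably filled in by a later inference pass.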
std::shared_ptr<abstract::AbstractTensor> OperatorInfo::CreateFakeAbstractTensor() const {
  auto type_ptr = TypeIdToType(operator_type_id_);
  std::vector<int64_t> shape_vector;
  return std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
}

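// A split is considered successful only if the resulting node exists and the number of split outputs
// matches the expected output count.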
int OperatorInfo::CheckSplitResult(const AnfNodePtr &result_anf_node, const std::vector<AnfNodePtr> &split_results,
                                   int target_output_num) {
  if ((result_anf_node == nullptr) || (split_results.size() != IntToSize(target_output_num))) {
    MS_LOG(ERROR) << name_ << " : Make split cnode failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

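// Binds this OperatorInfo to the graph and cnode it will split, and clears outputs left over from any
// previous run.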
void OperatorInfo::Init(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int32_t fmk_type) {
  func_graph_ = func_graph;
  cnode_ = cnode;
  fmk_type_ = fmk_type;
  parallel_output_nodes_.clear();
}

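// Pins every parallel branch to its target device. Each entry of parallel_output_nodes_ is assumed to
// be a TupleGetItem whose first real input (input(1)) is the compute cnode for that branch; splitting
// to NPU is rejected as unsupported.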
int OperatorInfo::SetCNodeBackend() {
  for (size_t i = 0; i < strategy_.dev_num; ++i) {
    lite::DeviceType dt_type;
    MS_CHECK_LT(i, strategy_.dev_types.size(), lite::RET_ERROR);
    std::string type = strategy_.dev_types[i];
    MS_CHECK_LT(i, parallel_output_nodes_.size(), lite::RET_ERROR);
    auto post_node = parallel_output_nodes_[i];
    MS_CHECK_TRUE_RET(post_node != nullptr, lite::RET_ERROR);
    auto post_cnode = post_node->cast<CNodePtr>();
    MS_CHECK_TRUE_RET(post_cnode != nullptr, lite::RET_ERROR);
    auto cnode = post_cnode->input(1)->cast<CNodePtr>();
    MS_CHECK_TRUE_RET(cnode != nullptr, lite::RET_ERROR);
    auto type_iter = kSupportSplitedDevices.find(type);
    if (type_iter == kSupportSplitedDevices.end()) {
      MS_LOG(ERROR) << "SetCNodeBackend: unknown device type.";
      return lite::RET_ERROR;
    }
    if (type_iter->second == lite::DeviceType::DT_NPU) {
      MS_LOG(ERROR) << "SetCNodeBackend: unsupported device type npu.";
      return lite::RET_ERROR;
    }
    dt_type = type_iter->second;
    cnode->AddAttr(mindspore::ops::kDeviceType, MakeValue(static_cast<int>(dt_type)));
  }
  return lite::RET_OK;
}

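// Sanity-checks the raw strategy: every per-tensor split vector must contain exactly dev_num entries,
// and each vector must be either all zero (no split) or all non-zero (split on every device).
// Illustrative values only, with dev_num == 2:
//   {1, 1} -> valid, the dimension is split evenly across both devices
//   {0, 0} -> valid, the dimension is not split at all
//   {1, 0} -> rejected, mixes a split entry with a NoSplit entry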
int OperatorInfo::CheckStrategyValue() {
  auto strategy_size = strategy_.strategys.size();
  for (size_t index = 0; index < strategy_size; ++index) {
    auto strategy = strategy_.strategys[index];
    for (const auto &s : strategy) {
      if (s.size() != IntToSize(strategy_.dev_num)) {
        MS_LOG(ERROR) << "Strategy split number: " << s.size()
                      << " is not equal to device number: " << strategy_.dev_num;
        return lite::RET_ERROR;
      }
      if (is_any_not_none(s) && is_any_none(s)) {
        MS_LOG(ERROR) << "Strategy split number must be all zero or all non-zero: " << s;
        return lite::RET_ERROR;
      }
    }
  }
  return lite::RET_OK;
}

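// Wraps a multi-output node in one TupleGetItem per output so later passes can index its results, and
// gives the node a tuple abstract assembled from fake (shape-less) tensors.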
int OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num,
                                                 std::vector<AnfNodePtr> *outputs) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(outputs);
  AbstractBasePtrList ptr_list;
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    MS_LOG(ERROR) << name_ << " : Failed to get CNode.";
    return lite::RET_ERROR;
  }
  for (size_t i = 0; i < output_num; ++i) {
    auto idx = NewValueNode(SizeToInt(i));
    auto index = std::make_shared<Int32Imm>(SizeToInt(i));
    MS_CHECK_TRUE_RET(index != nullptr, lite::RET_ERROR);
    auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(index);
    MS_CHECK_TRUE_RET(abstract_scalar != nullptr, lite::RET_ERROR);
    idx->set_abstract(abstract_scalar);
    auto tuple_getitem = func_graph_->NewCNode({NewValueNode(std::make_shared<lite::TupleGetItem>()), node, idx});
    if (tuple_getitem == nullptr) {
      MS_LOG(ERROR) << name_ << " : Failed to create output nodes.";
      return lite::RET_ERROR;
    }
    tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem" + std::to_string(i));
    outputs->push_back(tuple_getitem);
    auto abstract_tensor = CreateFakeAbstractTensor();
    ptr_list.push_back(abstract_tensor);
  }
  node->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
  return lite::RET_OK;
}

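// Stitches the parallel branches back together with a Concat along concat_dim. Every input node is
// expected to be a TupleGetItem, so input(1) unwraps it to the underlying compute cnode.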
AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
                                           int32_t concat_dim, size_t input_nodes_num) {
  MS_EXCEPTION_IF_NULL(orig_node);
  if (input_nodes.size() != input_nodes_num) {
    MS_LOG(ERROR) << name_ << " : Input nodes size of concat is not equal to input nodes number.";
    return nullptr;
  }
  auto concat_prim = std::make_shared<ops::Concat>();
  MS_CHECK_TRUE_RET(concat_prim != nullptr, nullptr);
  concat_prim->set_axis(concat_dim);
  auto value_node = NewValueNode(concat_prim);
  MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
  std::vector<AnfNodePtr> concat_inputs = {value_node};
  (void)std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(concat_inputs),
                       [](const AnfNodePtr &p) { return p->cast<CNodePtr>()->input(1); });
  auto concat_cnode = func_graph_->NewCNode(concat_inputs);
  if (concat_cnode == nullptr) {
    MS_LOG(ERROR) << name_ << " : Failed to create concat node.";
    return nullptr;
  }
  concat_cnode->set_fullname_with_scope("Concat_" + name_);
  concat_cnode->set_scope(orig_node->scope());
  std::vector<AnfNodePtr> outputs;
  (void)CreateMultipleOutputsOfAnfNode(concat_cnode, 1, &outputs);
  return concat_cnode;
}

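// Stitches the parallel branches back together by summing them element-wise with AddN, presumably for
// splits where each branch computes a partial result of the full output rather than a disjoint slice.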
AnfNodePtr OperatorInfo::CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
                                          size_t input_nodes_num) {
  MS_EXCEPTION_IF_NULL(orig_node);
  if (input_nodes.size() != input_nodes_num) {
    MS_LOG(ERROR) << name_ << " : Input nodes size of reduce is not equal to input nodes number.";
    return nullptr;
  }
  // add up the inputs element-wise
  auto addn_prim = std::make_shared<ops::AddN>();
  MS_CHECK_TRUE_RET(addn_prim != nullptr, nullptr);
  auto value_node = NewValueNode(addn_prim);
  MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
  std::vector<AnfNodePtr> addn_inputs = {value_node};
  (void)std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(addn_inputs),
                       [](const AnfNodePtr &p) { return p->cast<CNodePtr>()->input(1); });
  auto addn_cnode = func_graph_->NewCNode(addn_inputs);
  if (addn_cnode == nullptr) {
    MS_LOG(ERROR) << name_ << " : Failed to create addn node.";
    return nullptr;
  }
  addn_cnode->set_fullname_with_scope("AddN_" + name_);
  addn_cnode->set_scope(orig_node->scope());
  return addn_cnode;
}

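// Top-level driver: validate the strategy, build the parallel cnodes, pin each branch to its device,
// then infer the op that replaces the original node. Any failing step aborts the whole split.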
int OperatorInfo::DoSplit() {
  if (CheckStrategyValue() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy values.";
    return lite::RET_ERROR;
  }
  if (CheckStrategy(strategy_) != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": CheckStrategy failed.";
    return lite::RET_ERROR;
  }
  if (InferParallelCNodes() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": InferParallelCNodes failed.";
    return lite::RET_ERROR;
  }
  if (SetCNodeBackend() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": SetCNodeBackend failed.";
    return lite::RET_ERROR;
  }
  if (InferReplaceOp() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": InferReplaceOp failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}
}  // namespace opt
}  // namespace mindspore
214