/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/optimizer/parallel/operator_info.h"
#include <algorithm>
#include "tools/optimizer/parallel/split_strategy.h"
#include "ops/auto_generate/gen_lite_ops.h"
#include "ops/tuple_get_item.h"
#include "include/common/utils/utils.h"
#include "include/errorcode.h"
#include "nnacl/op_base.h"
#include "ops/op_utils.h"
#include "src/common/log_util.h"

namespace mindspore {
namespace opt {
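// Returns true if any element of |split| equals NoSplit (i.e. a zero split ratio).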
bool is_any_none(const std::vector<int64_t> &split) {
  return std::any_of(split.begin(), split.end(), [](int64_t v) { return v == static_cast<int64_t>(NoSplit); });
}

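// Returns true if any element of |split| differs from NoSplit (i.e. carries a real split ratio).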
bool is_any_not_none(const std::vector<int64_t> &split) {
  return std::any_of(split.begin(), split.end(), [](int64_t v) { return v != static_cast<int64_t>(NoSplit); });
}

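// Builds a placeholder AbstractTensor carrying only the operator's data type and an
// empty shape; the real shape is presumably filled in by a later shape-inference pass.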
std::shared_ptr<abstract::AbstractTensor> OperatorInfo::CreateFakeAbstractTensor() const {
  auto type_ptr = TypeIdToType(operator_type_id_);
  std::vector<int64_t> shape_vector;
  return std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
}

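// Verifies that a split produced a valid anf node and exactly |target_output_num| outputs.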
int OperatorInfo::CheckSplitResult(const AnfNodePtr &result_anf_node, const std::vector<AnfNodePtr> &split_results,
                                   int target_output_num) {
  if ((result_anf_node == nullptr) || (split_results.size() != IntToSize(target_output_num))) {
    MS_LOG(ERROR) << name_ << " : Make split cnode failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

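// Binds this operator to its graph, cnode and framework type, and resets any
// previously inferred parallel output nodes.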
void OperatorInfo::Init(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int32_t fmk_type) {
  func_graph_ = func_graph;
  cnode_ = cnode;
  fmk_type_ = fmk_type;
  parallel_output_nodes_.clear();
}

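// Tags each parallel branch with the device type named in the split strategy.
// NPU is rejected because splitting onto NPU is not supported by this pass.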
int OperatorInfo::SetCNodeBackend() {
  for (size_t i = 0; i < strategy_.dev_num; ++i) {
    lite::DeviceType dt_type;
    MS_CHECK_LT(i, strategy_.dev_types.size(), lite::RET_ERROR);
    std::string type = strategy_.dev_types[i];
    MS_CHECK_LT(i, parallel_output_nodes_.size(), lite::RET_ERROR);
    auto post_node = parallel_output_nodes_[i];
    MS_CHECK_TRUE_RET(post_node != nullptr, lite::RET_ERROR);
    auto post_cnode = post_node->cast<CNodePtr>();
    MS_CHECK_TRUE_RET(post_cnode != nullptr, lite::RET_ERROR);
    auto cnode = post_cnode->input(1)->cast<CNodePtr>();
    MS_CHECK_TRUE_RET(cnode != nullptr, lite::RET_ERROR);
    auto type_iter = kSupportSplitedDevices.find(type);
    if (type_iter == kSupportSplitedDevices.end()) {
      MS_LOG(ERROR) << "SetCNodeBackend: unknown device type: " << type;
      return lite::RET_ERROR;
    }
    if (type_iter->second == lite::DeviceType::DT_NPU) {
      MS_LOG(ERROR) << "SetCNodeBackend: unsupported device type npu.";
      return lite::RET_ERROR;
    }
    dt_type = type_iter->second;
    cnode->AddAttr(mindspore::ops::kDeviceType, MakeValue(static_cast<int>(dt_type)));
  }
  return lite::RET_OK;
}

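// Validates the split strategy: every axis strategy must list one ratio per
// device, and the ratios must be either all zero (no split) or all non-zero.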
int OperatorInfo::CheckStrategyValue() {
  auto strategy_size = strategy_.strategys.size();
  for (size_t index = 0; index < strategy_size; ++index) {
    auto strategy = strategy_.strategys[index];
    for (const auto &s : strategy) {
      if (s.size() != IntToSize(strategy_.dev_num)) {
        MS_LOG(ERROR) << "Strategy split number: " << s.size()
                      << " is not equal to device number: " << strategy_.dev_num;
        return lite::RET_ERROR;
      }
      if (is_any_not_none(s) && is_any_none(s)) {
        MS_LOG(ERROR) << "Strategy split values must be all zero or all non-zero: " << s;
        return lite::RET_ERROR;
      }
    }
  }
  return lite::RET_OK;
}

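// Wraps |node| with |output_num| TupleGetItem nodes, one per output, and gives
// |node| a tuple abstract built from placeholder tensors (see
// CreateFakeAbstractTensor); the getitem nodes are returned through |outputs|.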
int OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num,
                                                 std::vector<AnfNodePtr> *outputs) {
  MS_ASSERT(node != nullptr && outputs != nullptr);
  AbstractBasePtrList ptr_list;
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    MS_LOG(ERROR) << name_ << " : Failed to get CNode.";
    return lite::RET_ERROR;
  }
  for (size_t i = 0; i < output_num; ++i) {
    auto idx = NewValueNode(SizeToLong(i));
    auto index = std::make_shared<Int64Imm>(SizeToLong(i));
    MS_CHECK_TRUE_RET(index != nullptr, lite::RET_ERROR);
    auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(index);
    MS_CHECK_TRUE_RET(abstract_scalar != nullptr, lite::RET_ERROR);
    idx->set_abstract(abstract_scalar);
    auto tuple_node = std::make_shared<ops::TupleGetItem>();
    if (tuple_node == nullptr) {
      MS_LOG(ERROR) << name_ << " : Failed to create tuple node.";
      return lite::RET_ERROR;
    }
    auto tuple_prim_c = tuple_node->GetPrim();
    if (tuple_prim_c == nullptr) {
      MS_LOG(ERROR) << name_ << " : Failed to create tuple node primitive.";
      return lite::RET_ERROR;
    }
    auto tuple_getitem = func_graph_->NewCNode({NewValueNode(tuple_prim_c), node, idx});
    if (tuple_getitem == nullptr) {
      MS_LOG(ERROR) << name_ << " : Failed to create output nodes.";
      return lite::RET_ERROR;
    }
    tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem" + std::to_string(i));
    outputs->push_back(tuple_getitem);
    auto abstract_tensor = CreateFakeAbstractTensor();
    ptr_list.push_back(abstract_tensor);
  }
  node->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
  return lite::RET_OK;
}

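// Joins the parallel branches back together with a Concat along |concat_dim|.
// Each entry of |input_nodes| is expected to be a CNode (typically a
// TupleGetItem) whose first input is the branch output to concatenate.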
AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
                                           int32_t concat_dim, size_t input_nodes_num) {
  MS_ASSERT(orig_node != nullptr);
  if (input_nodes.size() != input_nodes_num) {
    MS_LOG(ERROR) << name_ << " : Input node count for concat does not match expected number.";
    return nullptr;
  }
  auto concat_prim = std::make_shared<ops::Concat>();
  MS_CHECK_TRUE_RET(concat_prim != nullptr, nullptr);
  auto concat_prim_c = concat_prim->GetPrim();
  MS_CHECK_TRUE_RET(concat_prim_c != nullptr, nullptr);
  concat_prim->set_axis(concat_dim);
  auto value_node = NewValueNode(concat_prim_c);
  MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
  std::vector<AnfNodePtr> concat_inputs = {value_node};
  (void)std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(concat_inputs),
                       [](const AnfNodePtr &p) { return p->cast<CNodePtr>()->input(1); });
  auto concat_cnode = func_graph_->NewCNode(concat_inputs);
  if (concat_cnode == nullptr) {
    MS_LOG(ERROR) << name_ << " : Failed to create concat node.";
    return nullptr;
  }
  concat_cnode->set_fullname_with_scope("Concat_" + name_);
  concat_cnode->set_scope(orig_node->scope());
  std::vector<AnfNodePtr> outputs;
  (void)CreateMultipleOutputsOfAnfNode(concat_cnode, 1, &outputs);
  return concat_cnode;
}

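// Joins the parallel branches by summing them element-wise with AddN; this is
// presumably the merge path for splits whose branches produce partial sums
// rather than disjoint slices.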
AnfNodePtr OperatorInfo::CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
                                          size_t input_nodes_num) {
  MS_ASSERT(orig_node != nullptr);
  if (input_nodes.size() != input_nodes_num) {
    MS_LOG(ERROR) << name_ << " : Input node count for reduce does not match expected number.";
    return nullptr;
  }
  // add up the inputs element-wise
  auto addn_prim = std::make_shared<ops::AddN>();
  MS_CHECK_TRUE_RET(addn_prim != nullptr, nullptr);
  auto addn_prim_c = addn_prim->GetPrim();
  MS_CHECK_TRUE_RET(addn_prim_c != nullptr, nullptr);
  auto value_node = NewValueNode(addn_prim_c);
  MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
  std::vector<AnfNodePtr> addn_inputs = {value_node};
  (void)std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(addn_inputs),
                       [](const AnfNodePtr &p) { return p->cast<CNodePtr>()->input(1); });
  auto addn_cnode = func_graph_->NewCNode(addn_inputs);
  if (addn_cnode == nullptr) {
    MS_LOG(ERROR) << name_ << " : Failed to create addn node.";
    return nullptr;
  }
  addn_cnode->set_fullname_with_scope("AddN_" + name_);
  addn_cnode->set_scope(orig_node->scope());
  return addn_cnode;
}

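// Drives the whole split: validate the strategy, build the parallel cnodes,
// pin each branch to its device, and finally wire in the op that replaces the
// original node.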
int OperatorInfo::DoSplit() {
  if (CheckStrategyValue() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy values.";
    return lite::RET_ERROR;
  }
  if (CheckStrategy(strategy_) != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": Check strategy failed.";
    return lite::RET_ERROR;
  }
  if (InferParallelCNodes() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": InferParallelCNodes failed.";
    return lite::RET_ERROR;
  }
  if (SetCNodeBackend() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": SetCNodeBackend failed.";
    return lite::RET_ERROR;
  }
  if (InferReplaceOp() != lite::RET_OK) {
    MS_LOG(ERROR) << name_ << ": InferReplaceOp failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}
}  // namespace opt
}  // namespace mindspore