1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "tools/optimizer/fisson/fisson_util.h"
17 #include <unordered_set>
18 #include <memory>
19 #include "mindspore/core/ops/sequence_ops.h"
20 #include "mindspore/core/ops/conv_pool_ops.h"
21 #include "mindspore/core/ops/lite_ops.h"
22 #include "src/common/utils.h"
23 #include "ops/split_with_overlap.h"
24 #include "tools/common/node_util.h"
25 #include "ops/auto_generate/gen_lite_ops.h"
26 #include "ops/make_tuple.h"
27 #include "tools/optimizer/parallel/spliter.h"
28 #include "tools/optimizer/parallel/split_strategy.h"
29 #include "nnacl/op_base.h"
30 #include "src/common/log_util.h"
31 #include "ops/op_utils.h"
32 #include "include/registry/converter_context.h"
33
34 using mindspore::converter::FmkType;
35 namespace mindspore {
36 namespace opt {
GetSplitPadList(const api::SharedPtr<ops::Conv2DFusion> & ori_conv_prim,int64_t input_h,int64_t input_w)37 std::vector<int64_t> GetSplitPadList(const api::SharedPtr<ops::Conv2DFusion> &ori_conv_prim, int64_t input_h,
38 int64_t input_w) {
39 if (ori_conv_prim == nullptr) {
40 MS_LOG(DEBUG) << "input Conv2DFusion is nullptr";
41 return {};
42 }
43 if (ori_conv_prim->get_pad_mode() != SAME) {
44 return ori_conv_prim->get_pad_list();
45 }
46 if (ori_conv_prim->get_stride().size() < kIndexW || ori_conv_prim->get_kernel_size().size() < kIndexW ||
47 ori_conv_prim->get_dilation().size() < kIndexW) {
48 MS_LOG(ERROR) << "Index out of range";
49 return {};
50 }
51 int64_t output_h = static_cast<int64_t>(
52 std::ceil(static_cast<float>(input_h) / static_cast<float>(ori_conv_prim->get_stride().at(kIndexH))));
53 int64_t output_w = static_cast<int64_t>(
54 std::ceil(static_cast<float>(input_w) / static_cast<float>(ori_conv_prim->get_stride().at(kIndexW))));
55
56 auto kernel_h = ori_conv_prim->get_kernel_size().at(kIndexH);
57 auto dilation_h = ori_conv_prim->get_dilation().at(kIndexH);
58 auto kernel_w = ori_conv_prim->get_kernel_size().at(kIndexW);
59 auto dilation_w = ori_conv_prim->get_dilation().at(kIndexW);
60 if (INT_MUL_OVERFLOW_THRESHOLD((kernel_h - 1), dilation_h, INT64_MAX) ||
61 INT_MUL_OVERFLOW_THRESHOLD((kernel_w - 1), dilation_w, INT64_MAX)) {
62 MS_LOG(ERROR) << "int mul overflow";
63 return {};
64 }
65 std::vector<int64_t> new_pad_list;
66 int64_t pad_up = 0;
67 int64_t pad_down = 0;
68 int64_t pad_left = 0;
69 int64_t pad_right = 0;
70 int64_t pad_h_all =
71 (output_h - 1) * ori_conv_prim->get_stride().at(kIndexH) + (kernel_h - 1) * dilation_h + 1 - input_h;
72 int64_t pad_w_all =
73 (output_w - 1) * ori_conv_prim->get_stride().at(kIndexW) + (kernel_w - 1) * dilation_w + 1 - input_w;
74 // only check pad_up and pad_down is positive
75 // if compute overflowed, we will get abnormal it in infer_shape
76 if (pad_h_all >= 0) {
77 pad_up = pad_h_all / 2;
78 pad_down = pad_h_all - pad_up;
79 }
80 new_pad_list.push_back(pad_up);
81 new_pad_list.push_back(pad_down);
82 if (pad_w_all >= 0) {
83 pad_left = pad_w_all / 2;
84 pad_right = pad_w_all - pad_left;
85 }
86 new_pad_list.push_back(pad_left);
87 new_pad_list.push_back(pad_right);
88 return new_pad_list;
89 }
90
91 namespace {
CalSplitOutputShape(int64_t splited_axis_value,const SplitInfo * split_info,std::vector<int64_t> * split_axis_out_shape,std::vector<int64_t> * split_axis_reduce_out_shape)92 bool CalSplitOutputShape(int64_t splited_axis_value, const SplitInfo *split_info,
93 std::vector<int64_t> *split_axis_out_shape,
94 std::vector<int64_t> *split_axis_reduce_out_shape) {
95 MS_ASSERT(split_info != nullptr && split_axis_out_shape != nullptr && split_axis_reduce_out_shape != nullptr);
96 // ori ratio
97 int64_t split_num = split_info->out_num;
98 int64_t split_len = 0;
99 for (int64_t i = 0; i < split_num; i++) {
100 split_len += split_info->size_splits[i];
101 }
102 if (split_len > splited_axis_value) {
103 return false;
104 }
105 // out-shape after splited
106 int64_t tmp_value = 0;
107 MS_CHECK_TRUE_MSG(split_num > 0, false, "out_num of split_info should be greater than zero");
108 MS_CHECK_TRUE_MSG(split_len > 0, false, "split_len should be greater than zero");
109 for (int64_t i = 0; i < split_num - 1; i++) {
110 if (INT_MUL_OVERFLOW_THRESHOLD(split_info->size_splits[i], splited_axis_value, INT64_MAX)) {
111 MS_LOG(ERROR) << "int mul overflow";
112 return false;
113 }
114 int64_t tmp = UP_DIV(split_info->size_splits[i] * splited_axis_value, split_len);
115 tmp_value += tmp;
116 split_axis_out_shape->push_back(tmp);
117 split_axis_reduce_out_shape->push_back(tmp_value);
118 }
119 split_axis_out_shape->push_back(splited_axis_value - tmp_value);
120 split_axis_reduce_out_shape->push_back(splited_axis_value);
121 return true;
122 }
123
CalSplitInShape(const std::vector<std::vector<ShapeVector>> & node_in_out_shapes,const SplitInfo * split_info,const api::SharedPtr<ops::Conv2DFusion> & ori_conv_prim,size_t index_node,std::vector<std::vector<int64_t>> * split_axis_inputs_shape,std::vector<std::vector<int64_t>> * split_axis_reduce_inputs_shape)124 bool CalSplitInShape(const std::vector<std::vector<ShapeVector>> &node_in_out_shapes, const SplitInfo *split_info,
125 const api::SharedPtr<ops::Conv2DFusion> &ori_conv_prim, size_t index_node,
126 std::vector<std::vector<int64_t>> *split_axis_inputs_shape,
127 std::vector<std::vector<int64_t>> *split_axis_reduce_inputs_shape) {
128 MS_ASSERT(split_info != nullptr && ori_conv_prim != nullptr && split_axis_inputs_shape != nullptr &&
129 split_axis_reduce_inputs_shape != nullptr);
130 MS_ASSERT(node_in_out_shapes.size() > index_node);
131 auto in_out_shape = node_in_out_shapes.at(index_node);
132 MS_ASSERT(!in_out_shape.empty());
133 auto in_shape = in_out_shape.front();
134 if (in_shape.size() < kAxisW) {
135 MS_LOG(DEBUG) << "out of in_shape range";
136 return false;
137 }
138 int64_t input_h = in_shape.at(kAxisH);
139 int64_t input_w = in_shape.at(kAxisW);
140 auto new_pad_list = GetSplitPadList(ori_conv_prim, input_h, input_w);
141 ori_conv_prim->set_pad_list(new_pad_list);
142 int64_t split_num = split_info->out_num;
143 int64_t tmp = 0;
144 std::vector<int64_t> split_axis_shape;
145 std::vector<int64_t> split_axis_reduce_shape;
146 // iter splited_num
147 for (int64_t index = 0; index < split_num; index++) {
148 // shape
149 auto stride_h = ori_conv_prim->get_stride()[kIndexH];
150 auto split_axis_dim = (*split_axis_inputs_shape)[index_node][index] - 1;
151 if (INT_MUL_OVERFLOW_THRESHOLD(stride_h, split_axis_dim, INT64_MAX)) {
152 MS_LOG(ERROR) << "int mul overflow";
153 return false;
154 }
155 if (split_info->axis == CuttingStragedy::CUT_H) { // H
156 if (index == 0) {
157 tmp =
158 stride_h * split_axis_dim - ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH];
159 } else if (index == split_num - 1) {
160 tmp = stride_h * split_axis_dim - ori_conv_prim->get_pad_list()[kPadDown] +
161 ori_conv_prim->get_kernel_size()[kIndexH];
162 } else {
163 tmp = stride_h * split_axis_dim + ori_conv_prim->get_kernel_size()[kIndexH];
164 }
165 }
166 split_axis_shape.push_back(tmp);
167
168 // reduce shape
169 auto split_axis_reduce_dim = (*split_axis_reduce_inputs_shape)[index_node][index] - 1;
170 if (split_info->axis == CuttingStragedy::CUT_H) { // H
171 if (index == split_num - 1) {
172 tmp = stride_h * split_axis_reduce_dim - ori_conv_prim->get_pad_list()[kPadDown] -
173 ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH];
174 } else {
175 tmp = stride_h * split_axis_reduce_dim - ori_conv_prim->get_pad_list()[kPadUp] +
176 ori_conv_prim->get_kernel_size()[kIndexH];
177 }
178 }
179 split_axis_reduce_shape.push_back(tmp);
180 }
181 split_axis_inputs_shape->push_back(split_axis_shape);
182 split_axis_reduce_inputs_shape->push_back(split_axis_reduce_shape);
183 return true;
184 }
185 } // namespace
186
IsConv2D(const AnfNodePtr & node)187 bool IsConv2D(const AnfNodePtr &node) {
188 return (CheckPrimitiveType(node, prim::kPrimConv2D) || CheckPrimitiveType(node, prim::kPrimConv2DFusion));
189 }
190
// Builds a fresh Conv2DFusion primitive that mirrors the attributes of `ori_conv_prim`.
// A SAME pad mode on the source is stored as PAD on the copy; every other pad mode is kept as is.
// The kIsDepthWise attribute is copied only when the source has it.
// Returns nullptr when the source is null or when the copy (or its inner primitive) cannot be created.
api::SharedPtr<ops::Conv2DFusion> CopyConvPrim(const api::SharedPtr<ops::Conv2DFusion> &ori_conv_prim) {
  MS_CHECK_TRUE_MSG(ori_conv_prim != nullptr, nullptr, "input Conv2DFusion is nullptr");
  auto copied_prim = api::MakeShared<ops::Conv2DFusion>();
  MS_CHECK_TRUE_MSG(copied_prim != nullptr, nullptr, "create Conv2DFusion return nullptr");
  auto copied_prim_c = copied_prim->GetPrim();
  MS_CHECK_TRUE_MSG(copied_prim_c != nullptr, nullptr, "create primic return nullptr");

  copied_prim->set_pad(ori_conv_prim->get_pad());
  copied_prim->set_in_channel(ori_conv_prim->get_in_channel());
  copied_prim->set_out_channel(ori_conv_prim->get_out_channel());
  copied_prim->set_dilation(ori_conv_prim->get_dilation());
  copied_prim->set_format(ori_conv_prim->get_format());
  copied_prim->set_group(ori_conv_prim->get_group());
  copied_prim->set_kernel_size(ori_conv_prim->get_kernel_size());

  // SAME is replaced by PAD on the copy, whose pad list is set explicitly below.
  const auto src_pad_mode = ori_conv_prim->get_pad_mode();
  copied_prim->set_pad_mode(src_pad_mode == SAME ? PAD : src_pad_mode);

  copied_prim->set_stride(ori_conv_prim->get_stride());
  copied_prim->set_activation_type(ori_conv_prim->get_activation_type());
  copied_prim->set_pad_list(ori_conv_prim->get_pad_list());

  const auto depth_wise_attr = ori_conv_prim->GetAttr(ops::kIsDepthWise);
  if (depth_wise_attr == nullptr) {
    return copied_prim;
  }
  const bool depth_wise = GetValue<bool>(depth_wise_attr);
  (void)copied_prim_c->AddAttr(ops::kIsDepthWise, MakeValue<bool>(depth_wise));
  return copied_prim;
}
220
UpdateSplitInfo(const FuncGraphPtr & func_graph,const std::vector<AnfNodePtr> & conv_nodes,SplitInfo * split_info)221 bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_nodes, SplitInfo *split_info) {
222 MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr");
223 MS_CHECK_TRUE_MSG(split_info != nullptr, false, "input SplitInfo is nullptr");
224 MS_CHECK_TRUE_MSG(conv_nodes.size() >= 1, false, "conv_nodes is empty");
225 if (split_info->axis != CuttingStragedy::CUT_H) {
226 return false;
227 }
228 auto splited_axis = split_info->axis;
229 // need to check
230 if (split_info->fmk_type == FmkType::kFmkTypeCaffe ||
231 split_info->fmk_type == FmkType::kFmkTypeOnnx) { // NHWC -> NCHW
232 splited_axis += 1;
233 }
234
235 size_t node_size = conv_nodes.size();
236 size_t index_node = 0;
237 std::vector<std::vector<ShapeVector>> node_in_out_shapes;
238 while (index_node < node_size) {
239 // [conv3, conv2, conv1] conv1->conv2->conv3
240 auto out_node_name = conv_nodes[index_node]->fullname_with_scope();
241 auto output_shapes = Spliter::GetInstance()->graph_node_output_shapes()[out_node_name];
242 auto input_shapes = Spliter::GetInstance()->graph_node_input_shapes()[out_node_name];
243 // 0-> in-shape 1->out-shape
244 // only one in and one output
245 MS_ASSERT(!input_shapes.empty() && !output_shapes.empty());
246 std::vector<ShapeVector> shape_vec = {input_shapes.front(), output_shapes.front()};
247 (void)node_in_out_shapes.emplace_back(shape_vec);
248 index_node++;
249 }
250 if (node_in_out_shapes.empty() || node_in_out_shapes.size() < (node_size - 1) || node_in_out_shapes[0].size() <= 1 ||
251 node_in_out_shapes[0][1].size() <= static_cast<size_t>(splited_axis) ||
252 node_in_out_shapes[node_size - 1].empty() ||
253 node_in_out_shapes[node_size - 1][0].size() <= static_cast<size_t>(splited_axis)) {
254 MS_LOG(ERROR) << "out of node_in_out_shapes range";
255 return false;
256 }
257 int64_t splited_axis_value = node_in_out_shapes[0][1][splited_axis];
258 int64_t final_split_axis_value = node_in_out_shapes[node_size - 1][0][splited_axis];
259 split_info->ori_split_axis_value = final_split_axis_value;
260 size_t split_num = split_info->size_splits.size();
261 std::vector<int64_t> split_axis_out_shape;
262 std::vector<int64_t> split_axis_reduce_out_shape;
263 if (!CalSplitOutputShape(splited_axis_value, split_info, &split_axis_out_shape, &split_axis_reduce_out_shape)) {
264 return false;
265 }
266 // infer in-shape after splited
267 std::vector<std::vector<int64_t>> split_axis_inputs_shape{split_axis_out_shape};
268 std::vector<std::vector<int64_t>> split_axis_reduce_inputs_shape{split_axis_reduce_out_shape};
269 index_node = 0;
270 // iter node
271 while (index_node < node_size) {
272 auto conv_cnode = conv_nodes[index_node]->cast<CNodePtr>();
273 MS_ASSERT(conv_cnode != nullptr);
274 auto ori_conv_prim = ops::GetOperator<ops::Conv2DFusion>(conv_cnode->input(kAnfPrimitiveIndex));
275 MS_CHECK_TRUE_RET(ori_conv_prim != nullptr, false);
276 if (!CalSplitInShape(node_in_out_shapes, split_info, ori_conv_prim, index_node, &split_axis_inputs_shape,
277 &split_axis_reduce_inputs_shape)) {
278 MS_LOG(ERROR) << "CalSplitInShape failed";
279 return false;
280 }
281 index_node++;
282 }
283
284 // update ratio
285 split_info->size_splits.clear();
286 split_info->extend_top.clear();
287 split_info->extend_bottom.clear();
288
289 int64_t top = 0;
290 int32_t bottom = 0;
291 split_info->size_splits.push_back(split_axis_inputs_shape[node_size][0]);
292 split_info->extend_top.push_back(top);
293 split_info->extend_bottom.push_back(bottom);
294
295 for (size_t i = 1; i < split_num; i++) {
296 auto begin = split_axis_reduce_inputs_shape[node_size][i] - split_axis_inputs_shape[node_size][i] + 1;
297 top = split_axis_reduce_inputs_shape[node_size][i - 1] - begin + 1;
298 auto value = split_axis_inputs_shape[node_size][i] - top;
299 split_info->size_splits.push_back(value);
300 split_info->extend_top.push_back(top);
301 split_info->extend_bottom.push_back(bottom);
302 }
303 return true;
304 }
305
GetMultipleOutputsOfAnfNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,size_t output_num,std::vector<AnfNodePtr> * outputs)306 bool GetMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
307 std::vector<AnfNodePtr> *outputs) {
308 MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr");
309 MS_CHECK_TRUE_MSG(node != nullptr, false, "input AnfNodePtr is nullptr");
310 MS_CHECK_TRUE_MSG(outputs != nullptr, false, "input std::vector<AnfNodePtr> is nullptr");
311 auto cnode = node->cast<CNodePtr>();
312 MS_CHECK_TRUE_MSG(cnode != nullptr, false, "create CNode return nullptr");
313 for (size_t i = 0; i < output_num; i++) {
314 auto index = NewValueNode(SizeToLong(i));
315 MS_CHECK_TRUE_MSG(index != nullptr, false, "create ValueNode return nullptr");
316 auto temp = SizeToLong(i);
317 auto imm = std::make_shared<Int64Imm>(temp);
318 MS_CHECK_TRUE_MSG(imm != nullptr, false, "create Int64Imm return nullptr");
319 auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
320 MS_CHECK_TRUE_MSG(abstract_scalar != nullptr, false, "create AbstractScalar return nullptr");
321 index->set_abstract(abstract_scalar);
322 auto tuple_getitem_primitive = NewValueNode(prim::kPrimTupleGetItem);
323 MS_CHECK_TRUE_MSG(tuple_getitem_primitive != nullptr, false, "create PrimTupleGetItem return nullptr");
324 auto tuple_getitem = func_graph->NewCNode({tuple_getitem_primitive, node, index});
325 MS_CHECK_TRUE_MSG(tuple_getitem != nullptr, false, "create CNode return nullptr");
326 tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i + 1));
327 outputs->push_back(tuple_getitem);
328 }
329 return true;
330 }
331
// Joins the outputs of the split convolutions again: a MakeTuple node gathers `conv_outputs`, and a
// Concat node along `split_info.axis` consumes that tuple. The nodes are named "<node_name>_MakeTuple"
// and "<node_name>_Concat" and take the scope of `conv_cnode`.
// Returns the Concat node, or nullptr when an argument is null, when the number of `conv_outputs`
// differs from `split_info.out_num`, or when a node cannot be created.
AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode,
                                 const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
                                 const std::string &node_name) {
  MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "input FuncGraphPtr is nullptr");
  MS_CHECK_TRUE_MSG(conv_cnode != nullptr, nullptr, "input AnfNodePtr is nullptr");

  const auto conv_output_count = static_cast<int64_t>(conv_outputs.size());
  if (conv_output_count != split_info.out_num) {
    MS_LOG(ERROR) << "Conv outputs has wrong input size";
    return nullptr;
  }

  // MakeTuple: primitive first, then every conv output as an input.
  auto tuple_op = std::make_shared<ops::MakeTuple>();
  MS_CHECK_TRUE_MSG(tuple_op != nullptr, nullptr, "create ops::MakeTuple return nullptr");
  auto tuple_prim_c = tuple_op->GetPrim();
  MS_CHECK_TRUE_MSG(tuple_prim_c != nullptr, nullptr, "create ops::make_tuple_prim_c return nullptr");
  auto tuple_value_node = NewValueNode(tuple_prim_c);
  MS_CHECK_TRUE_MSG(tuple_value_node != nullptr, nullptr, "create make_tuple_primitive return nullptr");

  std::vector<AnfNodePtr> tuple_inputs = {tuple_value_node};
  (void)tuple_inputs.insert(tuple_inputs.end(), conv_outputs.begin(), conv_outputs.end());

  auto tuple_cnode = func_graph->NewCNode(tuple_inputs);
  MS_CHECK_TRUE_MSG(tuple_cnode != nullptr, nullptr, "create make_tuple_cnode return nullptr");
  tuple_cnode->set_fullname_with_scope(node_name + "_MakeTuple");
  tuple_cnode->set_scope(conv_cnode->scope());

  // Concat: primitive first, then the tuple built above.
  auto concat_op = std::make_shared<ops::Concat>();
  MS_CHECK_TRUE_MSG(concat_op != nullptr, nullptr, "create ops::Concat return nullptr");
  auto concat_op_c = concat_op->GetPrim();
  MS_CHECK_TRUE_MSG(concat_op_c != nullptr, nullptr, "create ops::concat_prim_c return nullptr");
  concat_op->set_axis(split_info.axis);
  auto concat_value_node = NewValueNode(concat_op_c);
  MS_CHECK_TRUE_MSG(concat_value_node != nullptr, nullptr, "create concat_primitive return nullptr");

  auto concat_cnode = func_graph->NewCNode({concat_value_node, tuple_cnode});
  MS_CHECK_TRUE_MSG(concat_cnode != nullptr, nullptr, "create concat_cnode return nullptr");
  concat_cnode->set_fullname_with_scope(node_name + "_Concat");
  concat_cnode->set_scope(conv_cnode->scope());

  // A single TupleGetItem is still created on the Concat node; its result is not returned.
  std::vector<AnfNodePtr> concat_items;
  const bool items_created = GetMultipleOutputsOfAnfNode(func_graph, concat_cnode, 1, &concat_items);
  if (!items_created) {
    MS_LOG(ERROR) << "GetMultipleOutputsOfAnfNode failed";
    return nullptr;
  }
  return concat_cnode;
}
384
CreateOutputsOfSplitWithOverlap(const FuncGraphPtr & func_graph,const AnfNodePtr & conv_node,std::vector<AnfNodePtr> * split_outputs,const SplitInfo & split_info,const std::string & node_name)385 bool CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_node,
386 std::vector<AnfNodePtr> *split_outputs, const SplitInfo &split_info,
387 const std::string &node_name) {
388 MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr");
389 MS_CHECK_TRUE_MSG(conv_node != nullptr, false, "input conv_node is nullptr");
390 MS_CHECK_TRUE_MSG(split_outputs != nullptr, false, "input split_outputs is nullptr");
391 // attr of split
392 auto split_prim = std::make_shared<ops::SplitWithOverlap>();
393 MS_CHECK_TRUE_MSG(split_prim != nullptr, false, "create ops::SplitWithOverlap return nullptr");
394 auto split_prim_c = split_prim->GetPrim();
395 MS_CHECK_TRUE_MSG(split_prim != nullptr, false, "create ops::split_prim_c return nullptr");
396 split_prim->set_split_dim(split_info.axis);
397 split_prim->set_number_split(split_info.out_num);
398 split_prim->set_ratio(split_info.size_splits);
399 split_prim->set_extend_top(split_info.extend_top);
400 split_prim->set_extend_bottom(split_info.extend_bottom);
401 auto conv_cnode = conv_node->cast<CNodePtr>();
402
403 // the inputs of split is from the inputs of conv
404 auto split_primitive = NewValueNode(split_prim_c);
405 MS_CHECK_TRUE_MSG(split_primitive != nullptr, false, "create split_primitive return nullptr");
406 std::vector<AnfNodePtr> split_inputs = {split_primitive};
407
408 // this conv only has one input, which has been ensured before
409 split_inputs.push_back(conv_cnode->input(1));
410
411 auto split_cnode = func_graph->NewCNode(split_inputs);
412 MS_CHECK_TRUE_MSG(split_cnode != nullptr, false, "create split_cnode return nullptr");
413
414 split_cnode->set_fullname_with_scope(node_name + "_Split");
415 if (split_info.out_num < 0) {
416 MS_LOG(ERROR) << "out_num should greater then zero";
417 return false;
418 }
419 // create outputs op split
420 if (!GetMultipleOutputsOfAnfNode(func_graph, split_cnode, split_info.out_num, split_outputs)) {
421 MS_LOG(ERROR) << "GetMultipleOutputsOfAnfNode failed";
422 return false;
423 }
424
425 AbstractBasePtrList ptr_list;
426 for (int64_t i = 0; i < split_info.out_num; i++) {
427 // set date_type same with weight
428 auto type_id = static_cast<TypeId>(kNumberTypeFloat32);
429 auto type_ptr = TypeIdToType(type_id);
430 std::vector<int64_t> shape_vector;
431 auto value_node = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
432 MS_CHECK_TRUE_MSG(value_node != nullptr, false, "create abstract::AbstractTensor return nullptr");
433 ptr_list.push_back(value_node);
434 }
435 split_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
436 return true;
437 }
438
UpdateRatioWithPadStride(int64_t * ratio,size_t ratio_len,size_t split_size,int split_dim_size)439 bool UpdateRatioWithPadStride(int64_t *ratio, size_t ratio_len, size_t split_size, int split_dim_size) {
440 MS_CHECK_TRUE_MSG(ratio != nullptr, false, "input ratio is nullptr");
441 MS_CHECK_TRUE_MSG(split_size > 0, false, "split_size is zero");
442 int64_t total_block_count = 0;
443 for (size_t i = 0; i < split_size; i++) {
444 total_block_count += ratio[i];
445 }
446 if (ratio_len < split_size) {
447 MS_LOG(ERROR) << "out of ratio range";
448 return false;
449 }
450 if (total_block_count < 0) {
451 MS_LOG(ERROR) << "divide by zero";
452 return false;
453 }
454
455 std::vector<int64_t> new_ratio(split_size);
456 int64_t visited_block = 0;
457 for (size_t i = 0; i < split_size - 1; i++) {
458 visited_block += ratio[i];
459 if (INT_MUL_OVERFLOW_THRESHOLD(split_dim_size, visited_block, INT64_MAX)) {
460 MS_LOG(ERROR) << "int mul overflow";
461 return false;
462 }
463 int64_t cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
464 new_ratio[i + 1] = cur_border;
465 }
466
467 for (size_t i = 0; i < split_size; i++) {
468 ratio[i] = new_ratio[i];
469 }
470 return true;
471 }
472 } // namespace opt
473 } // namespace mindspore
474