• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "tools/optimizer/fisson/fisson_util.h"
17 #include <unordered_set>
18 #include <memory>
19 #include "mindspore/core/ops/sequence_ops.h"
20 #include "mindspore/core/ops/conv_pool_ops.h"
21 #include "mindspore/core/ops/lite_ops.h"
22 #include "src/common/utils.h"
23 #include "ops/split_with_overlap.h"
24 #include "tools/common/node_util.h"
25 #include "ops/auto_generate/gen_lite_ops.h"
26 #include "ops/make_tuple.h"
27 #include "tools/optimizer/parallel/spliter.h"
28 #include "tools/optimizer/parallel/split_strategy.h"
29 #include "nnacl/op_base.h"
30 #include "src/common/log_util.h"
31 #include "ops/op_utils.h"
32 #include "include/registry/converter_context.h"
33 
34 using mindspore::converter::FmkType;
35 namespace mindspore {
36 namespace opt {
GetSplitPadList(const api::SharedPtr<ops::Conv2DFusion> & ori_conv_prim,int64_t input_h,int64_t input_w)37 std::vector<int64_t> GetSplitPadList(const api::SharedPtr<ops::Conv2DFusion> &ori_conv_prim, int64_t input_h,
38                                      int64_t input_w) {
39   if (ori_conv_prim == nullptr) {
40     MS_LOG(DEBUG) << "input Conv2DFusion is nullptr";
41     return {};
42   }
43   if (ori_conv_prim->get_pad_mode() != SAME) {
44     return ori_conv_prim->get_pad_list();
45   }
46   if (ori_conv_prim->get_stride().size() < kIndexW || ori_conv_prim->get_kernel_size().size() < kIndexW ||
47       ori_conv_prim->get_dilation().size() < kIndexW) {
48     MS_LOG(ERROR) << "Index out of range";
49     return {};
50   }
51   int64_t output_h = static_cast<int64_t>(
52     std::ceil(static_cast<float>(input_h) / static_cast<float>(ori_conv_prim->get_stride().at(kIndexH))));
53   int64_t output_w = static_cast<int64_t>(
54     std::ceil(static_cast<float>(input_w) / static_cast<float>(ori_conv_prim->get_stride().at(kIndexW))));
55 
56   auto kernel_h = ori_conv_prim->get_kernel_size().at(kIndexH);
57   auto dilation_h = ori_conv_prim->get_dilation().at(kIndexH);
58   auto kernel_w = ori_conv_prim->get_kernel_size().at(kIndexW);
59   auto dilation_w = ori_conv_prim->get_dilation().at(kIndexW);
60   if (INT_MUL_OVERFLOW_THRESHOLD((kernel_h - 1), dilation_h, INT64_MAX) ||
61       INT_MUL_OVERFLOW_THRESHOLD((kernel_w - 1), dilation_w, INT64_MAX)) {
62     MS_LOG(ERROR) << "int mul overflow";
63     return {};
64   }
65   std::vector<int64_t> new_pad_list;
66   int64_t pad_up = 0;
67   int64_t pad_down = 0;
68   int64_t pad_left = 0;
69   int64_t pad_right = 0;
70   int64_t pad_h_all =
71     (output_h - 1) * ori_conv_prim->get_stride().at(kIndexH) + (kernel_h - 1) * dilation_h + 1 - input_h;
72   int64_t pad_w_all =
73     (output_w - 1) * ori_conv_prim->get_stride().at(kIndexW) + (kernel_w - 1) * dilation_w + 1 - input_w;
74   // only check pad_up and pad_down is positive
75   // if compute overflowed, we will get abnormal it in infer_shape
76   if (pad_h_all >= 0) {
77     pad_up = pad_h_all / 2;
78     pad_down = pad_h_all - pad_up;
79   }
80   new_pad_list.push_back(pad_up);
81   new_pad_list.push_back(pad_down);
82   if (pad_w_all >= 0) {
83     pad_left = pad_w_all / 2;
84     pad_right = pad_w_all - pad_left;
85   }
86   new_pad_list.push_back(pad_left);
87   new_pad_list.push_back(pad_right);
88   return new_pad_list;
89 }
90 
91 namespace {
CalSplitOutputShape(int64_t splited_axis_value,const SplitInfo * split_info,std::vector<int64_t> * split_axis_out_shape,std::vector<int64_t> * split_axis_reduce_out_shape)92 bool CalSplitOutputShape(int64_t splited_axis_value, const SplitInfo *split_info,
93                          std::vector<int64_t> *split_axis_out_shape,
94                          std::vector<int64_t> *split_axis_reduce_out_shape) {
95   MS_ASSERT(split_info != nullptr && split_axis_out_shape != nullptr && split_axis_reduce_out_shape != nullptr);
96   // ori ratio
97   int64_t split_num = split_info->out_num;
98   int64_t split_len = 0;
99   for (int64_t i = 0; i < split_num; i++) {
100     split_len += split_info->size_splits[i];
101   }
102   if (split_len > splited_axis_value) {
103     return false;
104   }
105   // out-shape after splited
106   int64_t tmp_value = 0;
107   MS_CHECK_TRUE_MSG(split_num > 0, false, "out_num of split_info should be greater than zero");
108   MS_CHECK_TRUE_MSG(split_len > 0, false, "split_len should be greater than zero");
109   for (int64_t i = 0; i < split_num - 1; i++) {
110     if (INT_MUL_OVERFLOW_THRESHOLD(split_info->size_splits[i], splited_axis_value, INT64_MAX)) {
111       MS_LOG(ERROR) << "int mul overflow";
112       return false;
113     }
114     int64_t tmp = UP_DIV(split_info->size_splits[i] * splited_axis_value, split_len);
115     tmp_value += tmp;
116     split_axis_out_shape->push_back(tmp);
117     split_axis_reduce_out_shape->push_back(tmp_value);
118   }
119   split_axis_out_shape->push_back(splited_axis_value - tmp_value);
120   split_axis_reduce_out_shape->push_back(splited_axis_value);
121   return true;
122 }
123 
CalSplitInShape(const std::vector<std::vector<ShapeVector>> & node_in_out_shapes,const SplitInfo * split_info,const api::SharedPtr<ops::Conv2DFusion> & ori_conv_prim,size_t index_node,std::vector<std::vector<int64_t>> * split_axis_inputs_shape,std::vector<std::vector<int64_t>> * split_axis_reduce_inputs_shape)124 bool CalSplitInShape(const std::vector<std::vector<ShapeVector>> &node_in_out_shapes, const SplitInfo *split_info,
125                      const api::SharedPtr<ops::Conv2DFusion> &ori_conv_prim, size_t index_node,
126                      std::vector<std::vector<int64_t>> *split_axis_inputs_shape,
127                      std::vector<std::vector<int64_t>> *split_axis_reduce_inputs_shape) {
128   MS_ASSERT(split_info != nullptr && ori_conv_prim != nullptr && split_axis_inputs_shape != nullptr &&
129             split_axis_reduce_inputs_shape != nullptr);
130   MS_ASSERT(node_in_out_shapes.size() > index_node);
131   auto in_out_shape = node_in_out_shapes.at(index_node);
132   MS_ASSERT(!in_out_shape.empty());
133   auto in_shape = in_out_shape.front();
134   if (in_shape.size() < kAxisW) {
135     MS_LOG(DEBUG) << "out of in_shape range";
136     return false;
137   }
138   int64_t input_h = in_shape.at(kAxisH);
139   int64_t input_w = in_shape.at(kAxisW);
140   auto new_pad_list = GetSplitPadList(ori_conv_prim, input_h, input_w);
141   ori_conv_prim->set_pad_list(new_pad_list);
142   int64_t split_num = split_info->out_num;
143   int64_t tmp = 0;
144   std::vector<int64_t> split_axis_shape;
145   std::vector<int64_t> split_axis_reduce_shape;
146   // iter splited_num
147   for (int64_t index = 0; index < split_num; index++) {
148     // shape
149     auto stride_h = ori_conv_prim->get_stride()[kIndexH];
150     auto split_axis_dim = (*split_axis_inputs_shape)[index_node][index] - 1;
151     if (INT_MUL_OVERFLOW_THRESHOLD(stride_h, split_axis_dim, INT64_MAX)) {
152       MS_LOG(ERROR) << "int mul overflow";
153       return false;
154     }
155     if (split_info->axis == CuttingStragedy::CUT_H) {  // H
156       if (index == 0) {
157         tmp =
158           stride_h * split_axis_dim - ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH];
159       } else if (index == split_num - 1) {
160         tmp = stride_h * split_axis_dim - ori_conv_prim->get_pad_list()[kPadDown] +
161               ori_conv_prim->get_kernel_size()[kIndexH];
162       } else {
163         tmp = stride_h * split_axis_dim + ori_conv_prim->get_kernel_size()[kIndexH];
164       }
165     }
166     split_axis_shape.push_back(tmp);
167 
168     // reduce shape
169     auto split_axis_reduce_dim = (*split_axis_reduce_inputs_shape)[index_node][index] - 1;
170     if (split_info->axis == CuttingStragedy::CUT_H) {  // H
171       if (index == split_num - 1) {
172         tmp = stride_h * split_axis_reduce_dim - ori_conv_prim->get_pad_list()[kPadDown] -
173               ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH];
174       } else {
175         tmp = stride_h * split_axis_reduce_dim - ori_conv_prim->get_pad_list()[kPadUp] +
176               ori_conv_prim->get_kernel_size()[kIndexH];
177       }
178     }
179     split_axis_reduce_shape.push_back(tmp);
180   }
181   split_axis_inputs_shape->push_back(split_axis_shape);
182   split_axis_reduce_inputs_shape->push_back(split_axis_reduce_shape);
183   return true;
184 }
185 }  // namespace
186 
IsConv2D(const AnfNodePtr & node)187 bool IsConv2D(const AnfNodePtr &node) {
188   return (CheckPrimitiveType(node, prim::kPrimConv2D) || CheckPrimitiveType(node, prim::kPrimConv2DFusion));
189 }
190 
// Creates a new Conv2DFusion primitive carrying the same attributes as ori_conv_prim.
// A SAME pad_mode on the source becomes PAD on the copy; pad_list is copied verbatim.
// The is-depth-wise attribute is copied only when the source has it.
// Returns nullptr if the input is null or the new primitive cannot be created.
api::SharedPtr<ops::Conv2DFusion> CopyConvPrim(const api::SharedPtr<ops::Conv2DFusion> &ori_conv_prim) {
  MS_CHECK_TRUE_MSG(ori_conv_prim != nullptr, nullptr, "input Conv2DFusion is nullptr");
  auto copied_conv = api::MakeShared<ops::Conv2DFusion>();
  MS_CHECK_TRUE_MSG(copied_conv != nullptr, nullptr, "create Conv2DFusion return nullptr");
  auto copied_prim_c = copied_conv->GetPrim();
  MS_CHECK_TRUE_MSG(copied_prim_c != nullptr, nullptr, "create primic return nullptr");

  copied_conv->set_pad(ori_conv_prim->get_pad());
  copied_conv->set_in_channel(ori_conv_prim->get_in_channel());
  copied_conv->set_out_channel(ori_conv_prim->get_out_channel());
  copied_conv->set_dilation(ori_conv_prim->get_dilation());
  copied_conv->set_format(ori_conv_prim->get_format());
  copied_conv->set_group(ori_conv_prim->get_group());
  copied_conv->set_kernel_size(ori_conv_prim->get_kernel_size());
  const auto src_pad_mode = ori_conv_prim->get_pad_mode();
  copied_conv->set_pad_mode(src_pad_mode == SAME ? PAD : src_pad_mode);
  copied_conv->set_stride(ori_conv_prim->get_stride());
  copied_conv->set_activation_type(ori_conv_prim->get_activation_type());
  copied_conv->set_pad_list(ori_conv_prim->get_pad_list());

  auto depth_wise_attr = ori_conv_prim->GetAttr(ops::kIsDepthWise);
  if (depth_wise_attr == nullptr) {
    return copied_conv;
  }
  (void)copied_prim_c->AddAttr(ops::kIsDepthWise, MakeValue<bool>(GetValue<bool>(depth_wise_attr)));
  return copied_conv;
}
220 
UpdateSplitInfo(const FuncGraphPtr & func_graph,const std::vector<AnfNodePtr> & conv_nodes,SplitInfo * split_info)221 bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_nodes, SplitInfo *split_info) {
222   MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr");
223   MS_CHECK_TRUE_MSG(split_info != nullptr, false, "input SplitInfo is nullptr");
224   MS_CHECK_TRUE_MSG(conv_nodes.size() >= 1, false, "conv_nodes is empty");
225   if (split_info->axis != CuttingStragedy::CUT_H) {
226     return false;
227   }
228   auto splited_axis = split_info->axis;
229   // need to check
230   if (split_info->fmk_type == FmkType::kFmkTypeCaffe ||
231       split_info->fmk_type == FmkType::kFmkTypeOnnx) {  // NHWC -> NCHW
232     splited_axis += 1;
233   }
234 
235   size_t node_size = conv_nodes.size();
236   size_t index_node = 0;
237   std::vector<std::vector<ShapeVector>> node_in_out_shapes;
238   while (index_node < node_size) {
239     // [conv3, conv2, conv1] conv1->conv2->conv3
240     auto out_node_name = conv_nodes[index_node]->fullname_with_scope();
241     auto output_shapes = Spliter::GetInstance()->graph_node_output_shapes()[out_node_name];
242     auto input_shapes = Spliter::GetInstance()->graph_node_input_shapes()[out_node_name];
243     // 0-> in-shape 1->out-shape
244     // only one in and one output
245     MS_ASSERT(!input_shapes.empty() && !output_shapes.empty());
246     std::vector<ShapeVector> shape_vec = {input_shapes.front(), output_shapes.front()};
247     (void)node_in_out_shapes.emplace_back(shape_vec);
248     index_node++;
249   }
250   if (node_in_out_shapes.empty() || node_in_out_shapes.size() < (node_size - 1) || node_in_out_shapes[0].size() <= 1 ||
251       node_in_out_shapes[0][1].size() <= static_cast<size_t>(splited_axis) ||
252       node_in_out_shapes[node_size - 1].empty() ||
253       node_in_out_shapes[node_size - 1][0].size() <= static_cast<size_t>(splited_axis)) {
254     MS_LOG(ERROR) << "out of node_in_out_shapes range";
255     return false;
256   }
257   int64_t splited_axis_value = node_in_out_shapes[0][1][splited_axis];
258   int64_t final_split_axis_value = node_in_out_shapes[node_size - 1][0][splited_axis];
259   split_info->ori_split_axis_value = final_split_axis_value;
260   size_t split_num = split_info->size_splits.size();
261   std::vector<int64_t> split_axis_out_shape;
262   std::vector<int64_t> split_axis_reduce_out_shape;
263   if (!CalSplitOutputShape(splited_axis_value, split_info, &split_axis_out_shape, &split_axis_reduce_out_shape)) {
264     return false;
265   }
266   // infer in-shape after splited
267   std::vector<std::vector<int64_t>> split_axis_inputs_shape{split_axis_out_shape};
268   std::vector<std::vector<int64_t>> split_axis_reduce_inputs_shape{split_axis_reduce_out_shape};
269   index_node = 0;
270   // iter node
271   while (index_node < node_size) {
272     auto conv_cnode = conv_nodes[index_node]->cast<CNodePtr>();
273     MS_ASSERT(conv_cnode != nullptr);
274     auto ori_conv_prim = ops::GetOperator<ops::Conv2DFusion>(conv_cnode->input(kAnfPrimitiveIndex));
275     MS_CHECK_TRUE_RET(ori_conv_prim != nullptr, false);
276     if (!CalSplitInShape(node_in_out_shapes, split_info, ori_conv_prim, index_node, &split_axis_inputs_shape,
277                          &split_axis_reduce_inputs_shape)) {
278       MS_LOG(ERROR) << "CalSplitInShape failed";
279       return false;
280     }
281     index_node++;
282   }
283 
284   // update ratio
285   split_info->size_splits.clear();
286   split_info->extend_top.clear();
287   split_info->extend_bottom.clear();
288 
289   int64_t top = 0;
290   int32_t bottom = 0;
291   split_info->size_splits.push_back(split_axis_inputs_shape[node_size][0]);
292   split_info->extend_top.push_back(top);
293   split_info->extend_bottom.push_back(bottom);
294 
295   for (size_t i = 1; i < split_num; i++) {
296     auto begin = split_axis_reduce_inputs_shape[node_size][i] - split_axis_inputs_shape[node_size][i] + 1;
297     top = split_axis_reduce_inputs_shape[node_size][i - 1] - begin + 1;
298     auto value = split_axis_inputs_shape[node_size][i] - top;
299     split_info->size_splits.push_back(value);
300     split_info->extend_top.push_back(top);
301     split_info->extend_bottom.push_back(bottom);
302   }
303   return true;
304 }
305 
GetMultipleOutputsOfAnfNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,size_t output_num,std::vector<AnfNodePtr> * outputs)306 bool GetMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
307                                  std::vector<AnfNodePtr> *outputs) {
308   MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr");
309   MS_CHECK_TRUE_MSG(node != nullptr, false, "input AnfNodePtr is nullptr");
310   MS_CHECK_TRUE_MSG(outputs != nullptr, false, "input std::vector<AnfNodePtr> is nullptr");
311   auto cnode = node->cast<CNodePtr>();
312   MS_CHECK_TRUE_MSG(cnode != nullptr, false, "create CNode return nullptr");
313   for (size_t i = 0; i < output_num; i++) {
314     auto index = NewValueNode(SizeToLong(i));
315     MS_CHECK_TRUE_MSG(index != nullptr, false, "create ValueNode return nullptr");
316     auto temp = SizeToLong(i);
317     auto imm = std::make_shared<Int64Imm>(temp);
318     MS_CHECK_TRUE_MSG(imm != nullptr, false, "create Int64Imm return nullptr");
319     auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
320     MS_CHECK_TRUE_MSG(abstract_scalar != nullptr, false, "create AbstractScalar return nullptr");
321     index->set_abstract(abstract_scalar);
322     auto tuple_getitem_primitive = NewValueNode(prim::kPrimTupleGetItem);
323     MS_CHECK_TRUE_MSG(tuple_getitem_primitive != nullptr, false, "create PrimTupleGetItem return nullptr");
324     auto tuple_getitem = func_graph->NewCNode({tuple_getitem_primitive, node, index});
325     MS_CHECK_TRUE_MSG(tuple_getitem != nullptr, false, "create CNode return nullptr");
326     tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i + 1));
327     outputs->push_back(tuple_getitem);
328   }
329   return true;
330 }
331 
CreateOutputsOfConcat(const FuncGraphPtr & func_graph,const AnfNodePtr & conv_cnode,const std::vector<AnfNodePtr> & conv_outputs,const SplitInfo & split_info,const std::string & node_name)332 AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode,
333                                  const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
334                                  const std::string &node_name) {
335   MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "input FuncGraphPtr is nullptr");
336   MS_CHECK_TRUE_MSG(conv_cnode != nullptr, nullptr, "input AnfNodePtr is nullptr");
337 
338   auto nodes_num = static_cast<int64_t>(conv_outputs.size());
339   if (nodes_num != split_info.out_num) {
340     MS_LOG(ERROR) << "Conv outputs has wrong input size";
341     return nullptr;
342   }
343 
344   auto make_tuple_prim = std::make_shared<ops::MakeTuple>();
345   MS_CHECK_TRUE_MSG(make_tuple_prim != nullptr, nullptr, "create ops::MakeTuple return nullptr");
346   auto make_tuple_prim_c = make_tuple_prim->GetPrim();
347   MS_CHECK_TRUE_MSG(make_tuple_prim_c != nullptr, nullptr, "create ops::make_tuple_prim_c return nullptr");
348 
349   // the inputs of make_tuple are from the outputs of conv
350   auto make_tuple_primitive = NewValueNode(make_tuple_prim_c);
351   MS_CHECK_TRUE_MSG(make_tuple_primitive != nullptr, nullptr, "create make_tuple_primitive return nullptr");
352   std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitive};
353   for (size_t i = 0; i < static_cast<size_t>(nodes_num); i++) {
354     make_tuple_inputs.push_back(conv_outputs[i]);
355   }
356 
357   auto make_tuple_cnode = func_graph->NewCNode(make_tuple_inputs);
358   MS_CHECK_TRUE_MSG(make_tuple_cnode != nullptr, nullptr, "create make_tuple_cnode return nullptr");
359   make_tuple_cnode->set_fullname_with_scope(node_name + "_MakeTuple");
360   make_tuple_cnode->set_scope(conv_cnode->scope());
361 
362   auto concat_prim = std::make_shared<ops::Concat>();
363   MS_CHECK_TRUE_MSG(concat_prim != nullptr, nullptr, "create ops::Concat return nullptr");
364   auto concat_prim_c = concat_prim->GetPrim();
365   MS_CHECK_TRUE_MSG(concat_prim_c != nullptr, nullptr, "create ops::concat_prim_c return nullptr");
366   concat_prim->set_axis(split_info.axis);
367 
368   // the inputs of concat are from the outputs of conv
369   auto concat_primitive = NewValueNode(concat_prim_c);
370   MS_CHECK_TRUE_MSG(concat_primitive != nullptr, nullptr, "create concat_primitive return nullptr");
371   std::vector<AnfNodePtr> concat_inputs = {concat_primitive, make_tuple_cnode};
372 
373   auto concat_cnode = func_graph->NewCNode(concat_inputs);
374   MS_CHECK_TRUE_MSG(concat_cnode != nullptr, nullptr, "create concat_cnode return nullptr");
375   concat_cnode->set_fullname_with_scope(node_name + "_Concat");
376   concat_cnode->set_scope(conv_cnode->scope());
377   std::vector<AnfNodePtr> outputs;
378   if (!GetMultipleOutputsOfAnfNode(func_graph, concat_cnode, 1, &outputs)) {
379     MS_LOG(ERROR) << "GetMultipleOutputsOfAnfNode failed";
380     return nullptr;
381   }
382   return concat_cnode;
383 }
384 
CreateOutputsOfSplitWithOverlap(const FuncGraphPtr & func_graph,const AnfNodePtr & conv_node,std::vector<AnfNodePtr> * split_outputs,const SplitInfo & split_info,const std::string & node_name)385 bool CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_node,
386                                      std::vector<AnfNodePtr> *split_outputs, const SplitInfo &split_info,
387                                      const std::string &node_name) {
388   MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr");
389   MS_CHECK_TRUE_MSG(conv_node != nullptr, false, "input conv_node is nullptr");
390   MS_CHECK_TRUE_MSG(split_outputs != nullptr, false, "input split_outputs is nullptr");
391   // attr of split
392   auto split_prim = std::make_shared<ops::SplitWithOverlap>();
393   MS_CHECK_TRUE_MSG(split_prim != nullptr, false, "create ops::SplitWithOverlap return nullptr");
394   auto split_prim_c = split_prim->GetPrim();
395   MS_CHECK_TRUE_MSG(split_prim != nullptr, false, "create ops::split_prim_c return nullptr");
396   split_prim->set_split_dim(split_info.axis);
397   split_prim->set_number_split(split_info.out_num);
398   split_prim->set_ratio(split_info.size_splits);
399   split_prim->set_extend_top(split_info.extend_top);
400   split_prim->set_extend_bottom(split_info.extend_bottom);
401   auto conv_cnode = conv_node->cast<CNodePtr>();
402 
403   // the inputs of split is from the inputs of conv
404   auto split_primitive = NewValueNode(split_prim_c);
405   MS_CHECK_TRUE_MSG(split_primitive != nullptr, false, "create split_primitive return nullptr");
406   std::vector<AnfNodePtr> split_inputs = {split_primitive};
407 
408   // this conv only has one input, which has been ensured before
409   split_inputs.push_back(conv_cnode->input(1));
410 
411   auto split_cnode = func_graph->NewCNode(split_inputs);
412   MS_CHECK_TRUE_MSG(split_cnode != nullptr, false, "create split_cnode return nullptr");
413 
414   split_cnode->set_fullname_with_scope(node_name + "_Split");
415   if (split_info.out_num < 0) {
416     MS_LOG(ERROR) << "out_num should greater then zero";
417     return false;
418   }
419   // create outputs op split
420   if (!GetMultipleOutputsOfAnfNode(func_graph, split_cnode, split_info.out_num, split_outputs)) {
421     MS_LOG(ERROR) << "GetMultipleOutputsOfAnfNode failed";
422     return false;
423   }
424 
425   AbstractBasePtrList ptr_list;
426   for (int64_t i = 0; i < split_info.out_num; i++) {
427     // set date_type same with weight
428     auto type_id = static_cast<TypeId>(kNumberTypeFloat32);
429     auto type_ptr = TypeIdToType(type_id);
430     std::vector<int64_t> shape_vector;
431     auto value_node = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
432     MS_CHECK_TRUE_MSG(value_node != nullptr, false, "create abstract::AbstractTensor return nullptr");
433     ptr_list.push_back(value_node);
434   }
435   split_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
436   return true;
437 }
438 
UpdateRatioWithPadStride(int64_t * ratio,size_t ratio_len,size_t split_size,int split_dim_size)439 bool UpdateRatioWithPadStride(int64_t *ratio, size_t ratio_len, size_t split_size, int split_dim_size) {
440   MS_CHECK_TRUE_MSG(ratio != nullptr, false, "input ratio is nullptr");
441   MS_CHECK_TRUE_MSG(split_size > 0, false, "split_size is zero");
442   int64_t total_block_count = 0;
443   for (size_t i = 0; i < split_size; i++) {
444     total_block_count += ratio[i];
445   }
446   if (ratio_len < split_size) {
447     MS_LOG(ERROR) << "out of ratio range";
448     return false;
449   }
450   if (total_block_count < 0) {
451     MS_LOG(ERROR) << "divide by zero";
452     return false;
453   }
454 
455   std::vector<int64_t> new_ratio(split_size);
456   int64_t visited_block = 0;
457   for (size_t i = 0; i < split_size - 1; i++) {
458     visited_block += ratio[i];
459     if (INT_MUL_OVERFLOW_THRESHOLD(split_dim_size, visited_block, INT64_MAX)) {
460       MS_LOG(ERROR) << "int mul overflow";
461       return false;
462     }
463     int64_t cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
464     new_ratio[i + 1] = cur_border;
465   }
466 
467   for (size_t i = 0; i < split_size; i++) {
468     ratio[i] = new_ratio[i];
469   }
470   return true;
471 }
472 }  // namespace opt
473 }  // namespace mindspore
474