• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "tools/optimizer/fisson/fisson_util.h"
17 #include <unordered_set>
18 #include <memory>
19 #include "base/core_ops.h"
20 #include "src/common/utils.h"
21 #include "ops/split_with_overlap.h"
22 #include "tools/common/node_util.h"
23 #include "ops/concat.h"
24 #include "tools/optimizer/parallel/spliter.h"
25 #include "tools/optimizer/parallel/split_strategy.h"
26 #include "nnacl/op_base.h"
27 #include "src/common/log_util.h"
28 
29 using mindspore::converter::FmkType;
30 namespace mindspore {
31 namespace opt {
GetSplitPadList(const std::shared_ptr<ops::Conv2DFusion> & ori_conv_prim,int64_t input_h,int64_t input_w)32 std::vector<int64_t> GetSplitPadList(const std::shared_ptr<ops::Conv2DFusion> &ori_conv_prim, int64_t input_h,
33                                      int64_t input_w) {
34   if (ori_conv_prim == nullptr) {
35     MS_LOG(DEBUG) << "input Conv2DFusion is nullptr";
36     return {};
37   }
38   if (ori_conv_prim->get_pad_mode() != SAME) {
39     return ori_conv_prim->get_pad_list();
40   }
41   if (ori_conv_prim->get_stride().size() < kIndexW || ori_conv_prim->get_kernel_size().size() < kIndexW ||
42       ori_conv_prim->get_dilation().size() < kIndexW) {
43     MS_LOG(ERROR) << "Index out of range";
44     return {};
45   }
46   int64_t output_h = static_cast<int64_t>(
47     std::ceil(static_cast<float>(input_h) / static_cast<float>(ori_conv_prim->get_stride().at(kIndexH))));
48   int64_t output_w = static_cast<int64_t>(
49     std::ceil(static_cast<float>(input_w) / static_cast<float>(ori_conv_prim->get_stride().at(kIndexW))));
50 
51   auto kernel_h = ori_conv_prim->get_kernel_size().at(kIndexH);
52   auto dilation_h = ori_conv_prim->get_dilation().at(kIndexH);
53   auto kernel_w = ori_conv_prim->get_kernel_size().at(kIndexW);
54   auto dilation_w = ori_conv_prim->get_dilation().at(kIndexW);
55   if (INT_MUL_OVERFLOW_THRESHOLD((kernel_h - 1), dilation_h, INT64_MAX) ||
56       INT_MUL_OVERFLOW_THRESHOLD((kernel_w - 1), dilation_w, INT64_MAX)) {
57     MS_LOG(ERROR) << "int mul overflow";
58     return {};
59   }
60   std::vector<int64_t> new_pad_list;
61   int64_t pad_up = 0, pad_down = 0, pad_left = 0, pad_right = 0;
62   int64_t pad_h_all =
63     (output_h - 1) * ori_conv_prim->get_stride().at(kIndexH) + (kernel_h - 1) * dilation_h + 1 - input_h;
64   int64_t pad_w_all =
65     (output_w - 1) * ori_conv_prim->get_stride().at(kIndexW) + (kernel_w - 1) * dilation_w + 1 - input_w;
66   // only check pad_up and pad_down is positive
67   // if compute overflowed, we will get abnormal it in infer_shape
68   if (pad_h_all >= 0) {
69     pad_up = pad_h_all / 2;
70     pad_down = pad_h_all - pad_up;
71   }
72   new_pad_list.push_back(pad_up);
73   new_pad_list.push_back(pad_down);
74   if (pad_w_all >= 0) {
75     pad_left = pad_w_all / 2;
76     pad_right = pad_w_all - pad_left;
77   }
78   new_pad_list.push_back(pad_left);
79   new_pad_list.push_back(pad_right);
80   return new_pad_list;
81 }
82 
83 namespace {
CalSplitOutputShape(int64_t splited_axis_value,const SplitInfo * split_info,std::vector<int64_t> * split_axis_out_shape,std::vector<int64_t> * split_axis_reduce_out_shape)84 bool CalSplitOutputShape(int64_t splited_axis_value, const SplitInfo *split_info,
85                          std::vector<int64_t> *split_axis_out_shape,
86                          std::vector<int64_t> *split_axis_reduce_out_shape) {
87   MS_ASSERT(split_info != nullptr && split_axis_out_shape != nullptr && split_axis_reduce_out_shape != nullptr);
88   // ori ratio
89   int64_t split_num = split_info->out_num;
90   int64_t split_len = 0;
91   for (int64_t i = 0; i < split_num; i++) {
92     split_len += split_info->size_splits[i];
93   }
94   if (split_len > splited_axis_value) {
95     return false;
96   }
97   // out-shape after splited
98   int64_t tmp_value = 0;
99   MS_CHECK_TRUE_MSG(split_num > 0, false, "out_num of split_info should greater than zero");
100   MS_CHECK_TRUE_MSG(split_len > 0, false, "split_len should greater than zero");
101   for (int64_t i = 0; i < split_num - 1; i++) {
102     if (INT_MUL_OVERFLOW_THRESHOLD(split_info->size_splits[i], splited_axis_value, INT64_MAX)) {
103       MS_LOG(ERROR) << "int mul overflow";
104       return false;
105     }
106     int64_t tmp = UP_DIV(split_info->size_splits[i] * splited_axis_value, split_len);
107     tmp_value += tmp;
108     split_axis_out_shape->push_back(tmp);
109     split_axis_reduce_out_shape->push_back(tmp_value);
110   }
111   split_axis_out_shape->push_back(splited_axis_value - tmp_value);
112   split_axis_reduce_out_shape->push_back(splited_axis_value);
113   return true;
114 }
115 
CalSplitInShape(const std::vector<std::vector<ShapeVector>> & node_in_out_shapes,const SplitInfo * split_info,const std::shared_ptr<ops::Conv2DFusion> & ori_conv_prim,int64_t index_node,std::vector<std::vector<int64_t>> * split_axis_inputs_shape,std::vector<std::vector<int64_t>> * split_axis_reduce_inputs_shape)116 bool CalSplitInShape(const std::vector<std::vector<ShapeVector>> &node_in_out_shapes, const SplitInfo *split_info,
117                      const std::shared_ptr<ops::Conv2DFusion> &ori_conv_prim, int64_t index_node,
118                      std::vector<std::vector<int64_t>> *split_axis_inputs_shape,
119                      std::vector<std::vector<int64_t>> *split_axis_reduce_inputs_shape) {
120   MS_ASSERT(split_info != nullptr && ori_conv_prim != nullptr && split_axis_inputs_shape != nullptr &&
121             split_axis_reduce_inputs_shape != nullptr);
122   MS_ASSERT(node_in_out_shapes.size() > index_node);
123   auto in_out_shape = node_in_out_shapes.at(index_node);
124   MS_ASSERT(!in_out_shape.empty());
125   auto in_shape = in_out_shape.front();
126   if (in_shape.size() < kAxisW) {
127     MS_LOG(DEBUG) << "out of in_shape range";
128     return false;
129   }
130   int64_t input_h = in_shape.at(kAxisH);
131   int64_t input_w = in_shape.at(kAxisW);
132   auto new_pad_list = GetSplitPadList(ori_conv_prim, input_h, input_w);
133   ori_conv_prim->set_pad_list(new_pad_list);
134   int64_t split_num = split_info->out_num;
135   int64_t tmp = 0;
136   std::vector<int64_t> split_axis_shape;
137   std::vector<int64_t> split_axis_reduce_shape;
138   // iter splited_num
139   for (int64_t index = 0; index < split_num; index++) {
140     // shape
141     auto stride_h = ori_conv_prim->get_stride()[kIndexH];
142     auto split_axis_dim = (*split_axis_inputs_shape)[index_node][index] - 1;
143     if (INT_MUL_OVERFLOW_THRESHOLD(stride_h, split_axis_dim, INT64_MAX)) {
144       MS_LOG(ERROR) << "int mul overflow";
145       return false;
146     }
147     if (split_info->axis == CuttingStragedy::CUT_H) {  // H
148       if (index == 0) {
149         tmp =
150           stride_h * split_axis_dim - ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH];
151       } else if (index == split_num - 1) {
152         tmp = stride_h * split_axis_dim - ori_conv_prim->get_pad_list()[kPadDown] +
153               ori_conv_prim->get_kernel_size()[kIndexH];
154       } else {
155         tmp = stride_h * split_axis_dim - 0 + ori_conv_prim->get_kernel_size()[kIndexH];
156       }
157     }
158     split_axis_shape.push_back(tmp);
159 
160     // reduce shape
161     auto split_axis_reduce_dim = (*split_axis_reduce_inputs_shape)[index_node][index] - 1;
162     if (split_info->axis == CuttingStragedy::CUT_H) {  // H
163       if (index == split_num - 1) {
164         tmp = stride_h * split_axis_reduce_dim - ori_conv_prim->get_pad_list()[kPadDown] -
165               ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH];
166       } else {
167         tmp = stride_h * split_axis_reduce_dim - ori_conv_prim->get_pad_list()[kPadUp] +
168               ori_conv_prim->get_kernel_size()[kIndexH];
169       }
170     }
171     split_axis_reduce_shape.push_back(tmp);
172   }
173   split_axis_inputs_shape->push_back(split_axis_shape);
174   split_axis_reduce_inputs_shape->push_back(split_axis_reduce_shape);
175   return true;
176 }
177 }  // namespace
178 
IsConv2D(const AnfNodePtr & node)179 bool IsConv2D(const AnfNodePtr &node) {
180   return (CheckPrimitiveType(node, prim::kPrimConv2D) || CheckPrimitiveType(node, prim::kPrimConv2DFusion));
181 }
182 
CopyConvPrim(const std::shared_ptr<ops::Conv2DFusion> & ori_conv_prim)183 std::shared_ptr<ops::Conv2DFusion> CopyConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_conv_prim) {
184   MS_CHECK_TRUE_MSG(ori_conv_prim != nullptr, nullptr, "input Conv2DFusion is nullptr");
185   auto new_prim = std::make_shared<ops::Conv2DFusion>();
186   MS_CHECK_TRUE_MSG(new_prim != nullptr, nullptr, "create Conv2DFusion return nullptr");
187   new_prim->set_pad(ori_conv_prim->get_pad());
188   new_prim->set_in_channel(ori_conv_prim->get_in_channel());
189   new_prim->set_out_channel(ori_conv_prim->get_out_channel());
190   new_prim->set_dilation(ori_conv_prim->get_dilation());
191   new_prim->set_format(ori_conv_prim->get_format());
192   new_prim->set_group(ori_conv_prim->get_group());
193   new_prim->set_kernel_size(ori_conv_prim->get_kernel_size());
194   if (ori_conv_prim->get_pad_mode() == SAME) {
195     new_prim->set_pad_mode(PAD);
196   } else {
197     new_prim->set_pad_mode(ori_conv_prim->get_pad_mode());
198   }
199 
200   new_prim->set_stride(ori_conv_prim->get_stride());
201   new_prim->set_activation_type(ori_conv_prim->get_activation_type());
202   new_prim->set_pad_list(ori_conv_prim->get_pad_list());
203   auto is_depth_value = ori_conv_prim->GetAttr(ops::kIsDepthWise);
204   if (is_depth_value != nullptr) {
205     bool is_depth_wise = GetValue<bool>(is_depth_value);
206     new_prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(is_depth_wise));
207   }
208   return new_prim;
209 }
210 
UpdateSplitInfo(const FuncGraphPtr & func_graph,const std::vector<AnfNodePtr> & conv_nodes,SplitInfo * split_info)211 bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_nodes, SplitInfo *split_info) {
212   MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr");
213   MS_CHECK_TRUE_MSG(split_info != nullptr, false, "input SplitInfo is nullptr");
214   if (split_info->axis != CuttingStragedy::CUT_H) {
215     return false;
216   }
217   auto splited_axis = split_info->axis;
218   // need to check
219   if (split_info->fmk_type == FmkType::kFmkTypeCaffe ||
220       split_info->fmk_type == FmkType::kFmkTypeOnnx) {  // NHWC -> NCHW
221     splited_axis += 1;
222   }
223 
224   size_t node_size = conv_nodes.size();
225   size_t index_node = 0;
226   std::vector<std::vector<ShapeVector>> node_in_out_shapes;
227   while (index_node < node_size) {
228     // [conv3, conv2, conv1] conv1->conv2->conv3
229     auto out_node_name = conv_nodes[index_node]->fullname_with_scope();
230     auto output_shapes = Spliter::GetInstance()->graph_node_output_shapes()[out_node_name];
231     auto input_shapes = Spliter::GetInstance()->graph_node_input_shapes()[out_node_name];
232     // 0-> in-shape 1->out-shape
233     // only one in and one output
234     MS_ASSERT(!input_shapes.empty() && !output_shapes.empty());
235     node_in_out_shapes.push_back({input_shapes.front(), output_shapes.front()});
236     index_node++;
237   }
238   if (node_in_out_shapes.empty() || node_in_out_shapes.size() < (node_size - 1) || node_in_out_shapes[0].size() <= 1 ||
239       node_in_out_shapes[0][1].size() <= static_cast<size_t>(splited_axis) ||
240       node_in_out_shapes[node_size - 1].empty() ||
241       node_in_out_shapes[node_size - 1][0].size() <= static_cast<size_t>(splited_axis)) {
242     MS_LOG(ERROR) << "out of node_in_out_shapes range";
243     return false;
244   }
245   int64_t splited_axis_value = node_in_out_shapes[0][1][splited_axis];
246   int64_t final_split_axis_value = node_in_out_shapes[node_size - 1][0][splited_axis];
247   split_info->ori_split_axis_value = final_split_axis_value;
248   size_t split_num = split_info->size_splits.size();
249   std::vector<int64_t> split_axis_out_shape;
250   std::vector<int64_t> split_axis_reduce_out_shape;
251   if (!CalSplitOutputShape(splited_axis_value, split_info, &split_axis_out_shape, &split_axis_reduce_out_shape)) {
252     return false;
253   }
254   // infer in-shape after splited
255   std::vector<std::vector<int64_t>> split_axis_inputs_shape{split_axis_out_shape};
256   std::vector<std::vector<int64_t>> split_axis_reduce_inputs_shape{split_axis_reduce_out_shape};
257   index_node = 0;
258   // iter node
259   while (index_node < node_size) {
260     auto conv_cnode = conv_nodes[index_node]->cast<CNodePtr>();
261     MS_ASSERT(conv_cnode != nullptr);
262     auto ori_conv_prim = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(conv_cnode->input(kAnfPrimitiveIndex));
263     MS_CHECK_TRUE_RET(ori_conv_prim != nullptr, false);
264     if (!CalSplitInShape(node_in_out_shapes, split_info, ori_conv_prim, index_node, &split_axis_inputs_shape,
265                          &split_axis_reduce_inputs_shape)) {
266       MS_LOG(ERROR) << "CalSplitInShape failed";
267       return false;
268     }
269     index_node++;
270   }
271 
272   // update ratio
273   split_info->size_splits.clear();
274   split_info->extend_top.clear();
275   split_info->extend_bottom.clear();
276 
277   int32_t top = 0;
278   int32_t bottom = 0;
279   split_info->size_splits.push_back(split_axis_inputs_shape[node_size][0]);
280   split_info->extend_top.push_back(top);
281   split_info->extend_bottom.push_back(bottom);
282 
283   for (size_t i = 1; i < split_num; i++) {
284     auto begin = split_axis_reduce_inputs_shape[node_size][i] - split_axis_inputs_shape[node_size][i] + 1;
285     top = split_axis_reduce_inputs_shape[node_size][i - 1] - begin + 1;
286     auto value = split_axis_inputs_shape[node_size][i] - top;
287     split_info->size_splits.push_back(value);
288     split_info->extend_top.push_back(top);
289     split_info->extend_bottom.push_back(bottom);
290   }
291   return true;
292 }
293 
GetMultipleOutputsOfAnfNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,size_t output_num,std::vector<AnfNodePtr> * outputs)294 bool GetMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
295                                  std::vector<AnfNodePtr> *outputs) {
296   MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr");
297   MS_CHECK_TRUE_MSG(node != nullptr, false, "input AnfNodePtr is nullptr");
298   MS_CHECK_TRUE_MSG(outputs != nullptr, false, "input std::vector<AnfNodePtr> is nullptr");
299   auto cnode = node->cast<CNodePtr>();
300   MS_CHECK_TRUE_MSG(cnode != nullptr, false, "create CNode return nullptr");
301   for (size_t i = 0; i < output_num; i++) {
302     auto index = NewValueNode(SizeToInt(i));
303     MS_CHECK_TRUE_MSG(index != nullptr, false, "create ValueNode return nullptr");
304     size_t temp = SizeToInt(i);
305     auto imm = std::make_shared<Int32Imm>(temp);
306     MS_CHECK_TRUE_MSG(imm != nullptr, false, "create Int32Imm return nullptr");
307     auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
308     MS_CHECK_TRUE_MSG(abstract_scalar != nullptr, false, "create AbstractScalar return nullptr");
309     index->set_abstract(abstract_scalar);
310     auto tuple_getitem_primitive = NewValueNode(prim::kPrimTupleGetItem);
311     MS_CHECK_TRUE_MSG(tuple_getitem_primitive != nullptr, false, "create PrimTupleGetItem return nullptr");
312     auto tuple_getitem = func_graph->NewCNode({tuple_getitem_primitive, node, index});
313     MS_CHECK_TRUE_MSG(tuple_getitem != nullptr, false, "create CNode return nullptr");
314     tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i + 1));
315     outputs->push_back(tuple_getitem);
316   }
317   return true;
318 }
319 
// Builds a Concat CNode along split_info->axis whose inputs are `conv_outputs` (in order).
// The node is named "<node_name>_Concat" and takes the scope of `conv_cnode`.
// Returns the Concat CNode, or nullptr on null input, an output count different from
// split_info->out_num, or allocation failure.
// Note: the previous version also created a TupleGetItem on the (single-output) Concat and discarded it;
// that unattached node had no effect and is no longer built.
AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode,
                                 const std::vector<AnfNodePtr> &conv_outputs, SplitInfo *split_info,
                                 const std::string &node_name) {
  MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "input FuncGraphPtr is nullptr");
  MS_CHECK_TRUE_MSG(conv_cnode != nullptr, nullptr, "input AnfNodePtr is nullptr");
  MS_CHECK_TRUE_MSG(split_info != nullptr, nullptr, "input SplitInfo is nullptr");

  // Compare in 64 bits: narrowing size() or out_num to int32_t could make a mismatch look equal.
  if (static_cast<int64_t>(conv_outputs.size()) != static_cast<int64_t>(split_info->out_num)) {
    MS_LOG(ERROR) << "Conv outputs has wrong input size";
    return nullptr;
  }

  auto concat_prim = std::make_shared<ops::Concat>();
  MS_CHECK_TRUE_MSG(concat_prim != nullptr, nullptr, "create ops::Concat return nullptr");
  concat_prim->set_axis(split_info->axis);

  // the inputs of concate are from the outputs of conv
  auto concate_primitive = NewValueNode(concat_prim);
  MS_CHECK_TRUE_MSG(concate_primitive != nullptr, nullptr, "create concate_primitive return nullptr");
  std::vector<AnfNodePtr> concate_inputs;
  concate_inputs.reserve(conv_outputs.size() + 1);
  concate_inputs.push_back(concate_primitive);
  concate_inputs.insert(concate_inputs.end(), conv_outputs.begin(), conv_outputs.end());

  auto concate_cnode = func_graph->NewCNode(concate_inputs);
  MS_CHECK_TRUE_MSG(concate_cnode != nullptr, nullptr, "create concate_cnode return nullptr");
  concate_cnode->set_fullname_with_scope(node_name + "_Concat");
  concate_cnode->set_scope(conv_cnode->scope());
  return concate_cnode;
}
357 
CreateOutputsOfSplitWithOverlap(const FuncGraphPtr & func_graph,const AnfNodePtr & conv_node,std::vector<AnfNodePtr> * split_outputs,SplitInfo * split_info,const std::string & node_name)358 bool CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_node,
359                                      std::vector<AnfNodePtr> *split_outputs, SplitInfo *split_info,
360                                      const std::string &node_name) {
361   MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr");
362   MS_CHECK_TRUE_MSG(conv_node != nullptr, false, "input conv_node is nullptr");
363   MS_CHECK_TRUE_MSG(split_outputs != nullptr, false, "input split_outputs is nullptr");
364   MS_CHECK_TRUE_MSG(split_info != nullptr, false, "input split_info is nullptr");
365   // attr of split
366   auto split_prim = std::make_shared<ops::SplitWithOverlap>();
367   MS_CHECK_TRUE_MSG(split_prim != nullptr, false, "create ops::SplitWithOverlap return nullptr");
368   split_prim->set_split_dim(split_info->axis);
369   split_prim->set_number_split(split_info->out_num);
370   split_prim->set_ratio(split_info->size_splits);
371   split_prim->set_extend_top(split_info->extend_top);
372   split_prim->set_extend_bottom(split_info->extend_bottom);
373   auto conv_cnode = conv_node->cast<CNodePtr>();
374 
375   // the inputs of split is from the inputs of conv
376   auto split_primitive = NewValueNode(split_prim);
377   MS_CHECK_TRUE_MSG(split_primitive != nullptr, false, "create split_primitive return nullptr");
378   std::vector<AnfNodePtr> split_inputs = {split_primitive};
379 
380   // this conv only has one input, which has been ensured before
381   split_inputs.push_back(conv_cnode->input(1));
382 
383   auto split_cnode = func_graph->NewCNode(split_inputs);
384   MS_CHECK_TRUE_MSG(split_cnode != nullptr, false, "create split_cnode return nullptr");
385 
386   split_cnode->set_fullname_with_scope(node_name + "_Split");
387   // create outputs op split
388   if (!GetMultipleOutputsOfAnfNode(func_graph, split_cnode, split_info->out_num, split_outputs)) {
389     MS_LOG(ERROR) << "GetMultipleOutputsOfAnfNode failed";
390     return false;
391   }
392 
393   AbstractBasePtrList ptr_list;
394   for (int64_t i = 0; i < split_info->out_num; i++) {
395     auto node = (*split_outputs)[i];
396     // set date_type same with weight
397     auto type_id = static_cast<TypeId>(kNumberTypeFloat32);
398     auto type_ptr = TypeIdToType(type_id);
399     std::vector<int64_t> shape_vector;
400     auto value_node = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
401     MS_CHECK_TRUE_MSG(value_node != nullptr, false, "create abstract::AbstractTensor return nullptr");
402     ptr_list.push_back(value_node);
403   }
404   split_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
405   return true;
406 }
407 
UpdateRatioWithPadStride(int64_t * ratio,size_t ratio_len,size_t split_size,int split_dim_size,int pad,int stride)408 bool UpdateRatioWithPadStride(int64_t *ratio, size_t ratio_len, size_t split_size, int split_dim_size, int pad,
409                               int stride) {
410   MS_CHECK_TRUE_MSG(ratio != nullptr, false, "input ratio is nullptr");
411   int total_block_count = 0;
412   for (size_t i = 0; i < split_size; i++) {
413     total_block_count += ratio[i];
414   }
415   if (ratio_len < split_size) {
416     MS_LOG(ERROR) << "out of ratio range";
417     return false;
418   }
419 
420   std::vector<int64_t> new_ratio(split_size);
421   int visited_block = 0;
422   for (size_t i = 0; i < split_size - 1; i++) {
423     visited_block += ratio[i];
424     if (INT_MUL_OVERFLOW_THRESHOLD(split_dim_size, visited_block, INT64_MAX)) {
425       MS_LOG(ERROR) << "int mul overflow";
426       return false;
427     }
428     int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
429     new_ratio[i + 1] = cur_border;
430   }
431 
432   for (size_t i = 0; i < split_size; i++) {
433     ratio[i] = new_ratio[i];
434   }
435   return true;
436 }
437 }  // namespace opt
438 }  // namespace mindspore
439