1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "tools/optimizer/fisson/fisson_util.h"
17 #include <unordered_set>
18 #include <memory>
19 #include "base/core_ops.h"
20 #include "src/common/utils.h"
21 #include "ops/split_with_overlap.h"
22 #include "tools/common/node_util.h"
23 #include "ops/concat.h"
24 #include "tools/optimizer/parallel/spliter.h"
25 #include "tools/optimizer/parallel/split_strategy.h"
26 #include "nnacl/op_base.h"
27 #include "src/common/log_util.h"
28
29 using mindspore::converter::FmkType;
30 namespace mindspore {
31 namespace opt {
GetSplitPadList(const std::shared_ptr<ops::Conv2DFusion> & ori_conv_prim,int64_t input_h,int64_t input_w)32 std::vector<int64_t> GetSplitPadList(const std::shared_ptr<ops::Conv2DFusion> &ori_conv_prim, int64_t input_h,
33 int64_t input_w) {
34 if (ori_conv_prim == nullptr) {
35 MS_LOG(DEBUG) << "input Conv2DFusion is nullptr";
36 return {};
37 }
38 if (ori_conv_prim->get_pad_mode() != SAME) {
39 return ori_conv_prim->get_pad_list();
40 }
41 if (ori_conv_prim->get_stride().size() < kIndexW || ori_conv_prim->get_kernel_size().size() < kIndexW ||
42 ori_conv_prim->get_dilation().size() < kIndexW) {
43 MS_LOG(ERROR) << "Index out of range";
44 return {};
45 }
46 int64_t output_h = static_cast<int64_t>(
47 std::ceil(static_cast<float>(input_h) / static_cast<float>(ori_conv_prim->get_stride().at(kIndexH))));
48 int64_t output_w = static_cast<int64_t>(
49 std::ceil(static_cast<float>(input_w) / static_cast<float>(ori_conv_prim->get_stride().at(kIndexW))));
50
51 auto kernel_h = ori_conv_prim->get_kernel_size().at(kIndexH);
52 auto dilation_h = ori_conv_prim->get_dilation().at(kIndexH);
53 auto kernel_w = ori_conv_prim->get_kernel_size().at(kIndexW);
54 auto dilation_w = ori_conv_prim->get_dilation().at(kIndexW);
55 if (INT_MUL_OVERFLOW_THRESHOLD((kernel_h - 1), dilation_h, INT64_MAX) ||
56 INT_MUL_OVERFLOW_THRESHOLD((kernel_w - 1), dilation_w, INT64_MAX)) {
57 MS_LOG(ERROR) << "int mul overflow";
58 return {};
59 }
60 std::vector<int64_t> new_pad_list;
61 int64_t pad_up = 0, pad_down = 0, pad_left = 0, pad_right = 0;
62 int64_t pad_h_all =
63 (output_h - 1) * ori_conv_prim->get_stride().at(kIndexH) + (kernel_h - 1) * dilation_h + 1 - input_h;
64 int64_t pad_w_all =
65 (output_w - 1) * ori_conv_prim->get_stride().at(kIndexW) + (kernel_w - 1) * dilation_w + 1 - input_w;
66 // only check pad_up and pad_down is positive
67 // if compute overflowed, we will get abnormal it in infer_shape
68 if (pad_h_all >= 0) {
69 pad_up = pad_h_all / 2;
70 pad_down = pad_h_all - pad_up;
71 }
72 new_pad_list.push_back(pad_up);
73 new_pad_list.push_back(pad_down);
74 if (pad_w_all >= 0) {
75 pad_left = pad_w_all / 2;
76 pad_right = pad_w_all - pad_left;
77 }
78 new_pad_list.push_back(pad_left);
79 new_pad_list.push_back(pad_right);
80 return new_pad_list;
81 }
82
83 namespace {
CalSplitOutputShape(int64_t splited_axis_value,const SplitInfo * split_info,std::vector<int64_t> * split_axis_out_shape,std::vector<int64_t> * split_axis_reduce_out_shape)84 bool CalSplitOutputShape(int64_t splited_axis_value, const SplitInfo *split_info,
85 std::vector<int64_t> *split_axis_out_shape,
86 std::vector<int64_t> *split_axis_reduce_out_shape) {
87 MS_ASSERT(split_info != nullptr && split_axis_out_shape != nullptr && split_axis_reduce_out_shape != nullptr);
88 // ori ratio
89 int64_t split_num = split_info->out_num;
90 int64_t split_len = 0;
91 for (int64_t i = 0; i < split_num; i++) {
92 split_len += split_info->size_splits[i];
93 }
94 if (split_len > splited_axis_value) {
95 return false;
96 }
97 // out-shape after splited
98 int64_t tmp_value = 0;
99 MS_CHECK_TRUE_MSG(split_num > 0, false, "out_num of split_info should greater than zero");
100 MS_CHECK_TRUE_MSG(split_len > 0, false, "split_len should greater than zero");
101 for (int64_t i = 0; i < split_num - 1; i++) {
102 if (INT_MUL_OVERFLOW_THRESHOLD(split_info->size_splits[i], splited_axis_value, INT64_MAX)) {
103 MS_LOG(ERROR) << "int mul overflow";
104 return false;
105 }
106 int64_t tmp = UP_DIV(split_info->size_splits[i] * splited_axis_value, split_len);
107 tmp_value += tmp;
108 split_axis_out_shape->push_back(tmp);
109 split_axis_reduce_out_shape->push_back(tmp_value);
110 }
111 split_axis_out_shape->push_back(splited_axis_value - tmp_value);
112 split_axis_reduce_out_shape->push_back(splited_axis_value);
113 return true;
114 }
115
CalSplitInShape(const std::vector<std::vector<ShapeVector>> & node_in_out_shapes,const SplitInfo * split_info,const std::shared_ptr<ops::Conv2DFusion> & ori_conv_prim,int64_t index_node,std::vector<std::vector<int64_t>> * split_axis_inputs_shape,std::vector<std::vector<int64_t>> * split_axis_reduce_inputs_shape)116 bool CalSplitInShape(const std::vector<std::vector<ShapeVector>> &node_in_out_shapes, const SplitInfo *split_info,
117 const std::shared_ptr<ops::Conv2DFusion> &ori_conv_prim, int64_t index_node,
118 std::vector<std::vector<int64_t>> *split_axis_inputs_shape,
119 std::vector<std::vector<int64_t>> *split_axis_reduce_inputs_shape) {
120 MS_ASSERT(split_info != nullptr && ori_conv_prim != nullptr && split_axis_inputs_shape != nullptr &&
121 split_axis_reduce_inputs_shape != nullptr);
122 MS_ASSERT(node_in_out_shapes.size() > index_node);
123 auto in_out_shape = node_in_out_shapes.at(index_node);
124 MS_ASSERT(!in_out_shape.empty());
125 auto in_shape = in_out_shape.front();
126 if (in_shape.size() < kAxisW) {
127 MS_LOG(DEBUG) << "out of in_shape range";
128 return false;
129 }
130 int64_t input_h = in_shape.at(kAxisH);
131 int64_t input_w = in_shape.at(kAxisW);
132 auto new_pad_list = GetSplitPadList(ori_conv_prim, input_h, input_w);
133 ori_conv_prim->set_pad_list(new_pad_list);
134 int64_t split_num = split_info->out_num;
135 int64_t tmp = 0;
136 std::vector<int64_t> split_axis_shape;
137 std::vector<int64_t> split_axis_reduce_shape;
138 // iter splited_num
139 for (int64_t index = 0; index < split_num; index++) {
140 // shape
141 auto stride_h = ori_conv_prim->get_stride()[kIndexH];
142 auto split_axis_dim = (*split_axis_inputs_shape)[index_node][index] - 1;
143 if (INT_MUL_OVERFLOW_THRESHOLD(stride_h, split_axis_dim, INT64_MAX)) {
144 MS_LOG(ERROR) << "int mul overflow";
145 return false;
146 }
147 if (split_info->axis == CuttingStragedy::CUT_H) { // H
148 if (index == 0) {
149 tmp =
150 stride_h * split_axis_dim - ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH];
151 } else if (index == split_num - 1) {
152 tmp = stride_h * split_axis_dim - ori_conv_prim->get_pad_list()[kPadDown] +
153 ori_conv_prim->get_kernel_size()[kIndexH];
154 } else {
155 tmp = stride_h * split_axis_dim - 0 + ori_conv_prim->get_kernel_size()[kIndexH];
156 }
157 }
158 split_axis_shape.push_back(tmp);
159
160 // reduce shape
161 auto split_axis_reduce_dim = (*split_axis_reduce_inputs_shape)[index_node][index] - 1;
162 if (split_info->axis == CuttingStragedy::CUT_H) { // H
163 if (index == split_num - 1) {
164 tmp = stride_h * split_axis_reduce_dim - ori_conv_prim->get_pad_list()[kPadDown] -
165 ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH];
166 } else {
167 tmp = stride_h * split_axis_reduce_dim - ori_conv_prim->get_pad_list()[kPadUp] +
168 ori_conv_prim->get_kernel_size()[kIndexH];
169 }
170 }
171 split_axis_reduce_shape.push_back(tmp);
172 }
173 split_axis_inputs_shape->push_back(split_axis_shape);
174 split_axis_reduce_inputs_shape->push_back(split_axis_reduce_shape);
175 return true;
176 }
177 } // namespace
178
IsConv2D(const AnfNodePtr & node)179 bool IsConv2D(const AnfNodePtr &node) {
180 return (CheckPrimitiveType(node, prim::kPrimConv2D) || CheckPrimitiveType(node, prim::kPrimConv2DFusion));
181 }
182
CopyConvPrim(const std::shared_ptr<ops::Conv2DFusion> & ori_conv_prim)183 std::shared_ptr<ops::Conv2DFusion> CopyConvPrim(const std::shared_ptr<ops::Conv2DFusion> &ori_conv_prim) {
184 MS_CHECK_TRUE_MSG(ori_conv_prim != nullptr, nullptr, "input Conv2DFusion is nullptr");
185 auto new_prim = std::make_shared<ops::Conv2DFusion>();
186 MS_CHECK_TRUE_MSG(new_prim != nullptr, nullptr, "create Conv2DFusion return nullptr");
187 new_prim->set_pad(ori_conv_prim->get_pad());
188 new_prim->set_in_channel(ori_conv_prim->get_in_channel());
189 new_prim->set_out_channel(ori_conv_prim->get_out_channel());
190 new_prim->set_dilation(ori_conv_prim->get_dilation());
191 new_prim->set_format(ori_conv_prim->get_format());
192 new_prim->set_group(ori_conv_prim->get_group());
193 new_prim->set_kernel_size(ori_conv_prim->get_kernel_size());
194 if (ori_conv_prim->get_pad_mode() == SAME) {
195 new_prim->set_pad_mode(PAD);
196 } else {
197 new_prim->set_pad_mode(ori_conv_prim->get_pad_mode());
198 }
199
200 new_prim->set_stride(ori_conv_prim->get_stride());
201 new_prim->set_activation_type(ori_conv_prim->get_activation_type());
202 new_prim->set_pad_list(ori_conv_prim->get_pad_list());
203 auto is_depth_value = ori_conv_prim->GetAttr(ops::kIsDepthWise);
204 if (is_depth_value != nullptr) {
205 bool is_depth_wise = GetValue<bool>(is_depth_value);
206 new_prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(is_depth_wise));
207 }
208 return new_prim;
209 }
210
UpdateSplitInfo(const FuncGraphPtr & func_graph,const std::vector<AnfNodePtr> & conv_nodes,SplitInfo * split_info)211 bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_nodes, SplitInfo *split_info) {
212 MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr");
213 MS_CHECK_TRUE_MSG(split_info != nullptr, false, "input SplitInfo is nullptr");
214 if (split_info->axis != CuttingStragedy::CUT_H) {
215 return false;
216 }
217 auto splited_axis = split_info->axis;
218 // need to check
219 if (split_info->fmk_type == FmkType::kFmkTypeCaffe ||
220 split_info->fmk_type == FmkType::kFmkTypeOnnx) { // NHWC -> NCHW
221 splited_axis += 1;
222 }
223
224 size_t node_size = conv_nodes.size();
225 size_t index_node = 0;
226 std::vector<std::vector<ShapeVector>> node_in_out_shapes;
227 while (index_node < node_size) {
228 // [conv3, conv2, conv1] conv1->conv2->conv3
229 auto out_node_name = conv_nodes[index_node]->fullname_with_scope();
230 auto output_shapes = Spliter::GetInstance()->graph_node_output_shapes()[out_node_name];
231 auto input_shapes = Spliter::GetInstance()->graph_node_input_shapes()[out_node_name];
232 // 0-> in-shape 1->out-shape
233 // only one in and one output
234 MS_ASSERT(!input_shapes.empty() && !output_shapes.empty());
235 node_in_out_shapes.push_back({input_shapes.front(), output_shapes.front()});
236 index_node++;
237 }
238 if (node_in_out_shapes.empty() || node_in_out_shapes.size() < (node_size - 1) || node_in_out_shapes[0].size() <= 1 ||
239 node_in_out_shapes[0][1].size() <= static_cast<size_t>(splited_axis) ||
240 node_in_out_shapes[node_size - 1].empty() ||
241 node_in_out_shapes[node_size - 1][0].size() <= static_cast<size_t>(splited_axis)) {
242 MS_LOG(ERROR) << "out of node_in_out_shapes range";
243 return false;
244 }
245 int64_t splited_axis_value = node_in_out_shapes[0][1][splited_axis];
246 int64_t final_split_axis_value = node_in_out_shapes[node_size - 1][0][splited_axis];
247 split_info->ori_split_axis_value = final_split_axis_value;
248 size_t split_num = split_info->size_splits.size();
249 std::vector<int64_t> split_axis_out_shape;
250 std::vector<int64_t> split_axis_reduce_out_shape;
251 if (!CalSplitOutputShape(splited_axis_value, split_info, &split_axis_out_shape, &split_axis_reduce_out_shape)) {
252 return false;
253 }
254 // infer in-shape after splited
255 std::vector<std::vector<int64_t>> split_axis_inputs_shape{split_axis_out_shape};
256 std::vector<std::vector<int64_t>> split_axis_reduce_inputs_shape{split_axis_reduce_out_shape};
257 index_node = 0;
258 // iter node
259 while (index_node < node_size) {
260 auto conv_cnode = conv_nodes[index_node]->cast<CNodePtr>();
261 MS_ASSERT(conv_cnode != nullptr);
262 auto ori_conv_prim = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(conv_cnode->input(kAnfPrimitiveIndex));
263 MS_CHECK_TRUE_RET(ori_conv_prim != nullptr, false);
264 if (!CalSplitInShape(node_in_out_shapes, split_info, ori_conv_prim, index_node, &split_axis_inputs_shape,
265 &split_axis_reduce_inputs_shape)) {
266 MS_LOG(ERROR) << "CalSplitInShape failed";
267 return false;
268 }
269 index_node++;
270 }
271
272 // update ratio
273 split_info->size_splits.clear();
274 split_info->extend_top.clear();
275 split_info->extend_bottom.clear();
276
277 int32_t top = 0;
278 int32_t bottom = 0;
279 split_info->size_splits.push_back(split_axis_inputs_shape[node_size][0]);
280 split_info->extend_top.push_back(top);
281 split_info->extend_bottom.push_back(bottom);
282
283 for (size_t i = 1; i < split_num; i++) {
284 auto begin = split_axis_reduce_inputs_shape[node_size][i] - split_axis_inputs_shape[node_size][i] + 1;
285 top = split_axis_reduce_inputs_shape[node_size][i - 1] - begin + 1;
286 auto value = split_axis_inputs_shape[node_size][i] - top;
287 split_info->size_splits.push_back(value);
288 split_info->extend_top.push_back(top);
289 split_info->extend_bottom.push_back(bottom);
290 }
291 return true;
292 }
293
GetMultipleOutputsOfAnfNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,size_t output_num,std::vector<AnfNodePtr> * outputs)294 bool GetMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
295 std::vector<AnfNodePtr> *outputs) {
296 MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr");
297 MS_CHECK_TRUE_MSG(node != nullptr, false, "input AnfNodePtr is nullptr");
298 MS_CHECK_TRUE_MSG(outputs != nullptr, false, "input std::vector<AnfNodePtr> is nullptr");
299 auto cnode = node->cast<CNodePtr>();
300 MS_CHECK_TRUE_MSG(cnode != nullptr, false, "create CNode return nullptr");
301 for (size_t i = 0; i < output_num; i++) {
302 auto index = NewValueNode(SizeToInt(i));
303 MS_CHECK_TRUE_MSG(index != nullptr, false, "create ValueNode return nullptr");
304 size_t temp = SizeToInt(i);
305 auto imm = std::make_shared<Int32Imm>(temp);
306 MS_CHECK_TRUE_MSG(imm != nullptr, false, "create Int32Imm return nullptr");
307 auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
308 MS_CHECK_TRUE_MSG(abstract_scalar != nullptr, false, "create AbstractScalar return nullptr");
309 index->set_abstract(abstract_scalar);
310 auto tuple_getitem_primitive = NewValueNode(prim::kPrimTupleGetItem);
311 MS_CHECK_TRUE_MSG(tuple_getitem_primitive != nullptr, false, "create PrimTupleGetItem return nullptr");
312 auto tuple_getitem = func_graph->NewCNode({tuple_getitem_primitive, node, index});
313 MS_CHECK_TRUE_MSG(tuple_getitem != nullptr, false, "create CNode return nullptr");
314 tuple_getitem->set_fullname_with_scope(cnode->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i + 1));
315 outputs->push_back(tuple_getitem);
316 }
317 return true;
318 }
319
// Creates a Concat node named "<node_name>_Concat" that joins conv_outputs along split_info->axis.
//
// The number of conv_outputs has to equal split_info->out_num. The new node takes over the scope of
// conv_cnode. Returns the Concat CNode, or nullptr on invalid arguments or when a node cannot be created.
AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_cnode,
                                 const std::vector<AnfNodePtr> &conv_outputs, SplitInfo *split_info,
                                 const std::string &node_name) {
  MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "input FuncGraphPtr is nullptr");
  MS_CHECK_TRUE_MSG(conv_cnode != nullptr, nullptr, "input AnfNodePtr is nullptr");
  MS_CHECK_TRUE_MSG(split_info != nullptr, nullptr, "input SplitInfo is nullptr");

  if (static_cast<int32_t>(conv_outputs.size()) != static_cast<int32_t>(split_info->out_num)) {
    MS_LOG(ERROR) << "Conv outputs has wrong input size";
    return nullptr;
  }

  auto concat_prim = std::make_shared<ops::Concat>();
  MS_CHECK_TRUE_MSG(concat_prim != nullptr, nullptr, "create ops::Concat return nullptr");
  concat_prim->set_axis(split_info->axis);

  // Inputs of the Concat: its primitive followed by every conv output, in order.
  auto concat_value_node = NewValueNode(concat_prim);
  MS_CHECK_TRUE_MSG(concat_value_node != nullptr, nullptr, "create concate_primitive return nullptr");
  std::vector<AnfNodePtr> concat_inputs = {concat_value_node};
  for (const auto &conv_output : conv_outputs) {
    concat_inputs.push_back(conv_output);
  }

  auto concat_cnode = func_graph->NewCNode(concat_inputs);
  MS_CHECK_TRUE_MSG(concat_cnode != nullptr, nullptr, "create concate_cnode return nullptr");
  concat_cnode->set_fullname_with_scope(node_name + "_Concat");
  concat_cnode->set_scope(conv_cnode->scope());

  // A single TupleGetItem is still created on top of the Concat; only the Concat itself is returned.
  std::vector<AnfNodePtr> get_item_nodes;
  if (!GetMultipleOutputsOfAnfNode(func_graph, concat_cnode, 1, &get_item_nodes)) {
    MS_LOG(ERROR) << "GetMultipleOutputsOfAnfNode failed";
    return nullptr;
  }
  return concat_cnode;
}
357
CreateOutputsOfSplitWithOverlap(const FuncGraphPtr & func_graph,const AnfNodePtr & conv_node,std::vector<AnfNodePtr> * split_outputs,SplitInfo * split_info,const std::string & node_name)358 bool CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNodePtr &conv_node,
359 std::vector<AnfNodePtr> *split_outputs, SplitInfo *split_info,
360 const std::string &node_name) {
361 MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "input FuncGraphPtr is nullptr");
362 MS_CHECK_TRUE_MSG(conv_node != nullptr, false, "input conv_node is nullptr");
363 MS_CHECK_TRUE_MSG(split_outputs != nullptr, false, "input split_outputs is nullptr");
364 MS_CHECK_TRUE_MSG(split_info != nullptr, false, "input split_info is nullptr");
365 // attr of split
366 auto split_prim = std::make_shared<ops::SplitWithOverlap>();
367 MS_CHECK_TRUE_MSG(split_prim != nullptr, false, "create ops::SplitWithOverlap return nullptr");
368 split_prim->set_split_dim(split_info->axis);
369 split_prim->set_number_split(split_info->out_num);
370 split_prim->set_ratio(split_info->size_splits);
371 split_prim->set_extend_top(split_info->extend_top);
372 split_prim->set_extend_bottom(split_info->extend_bottom);
373 auto conv_cnode = conv_node->cast<CNodePtr>();
374
375 // the inputs of split is from the inputs of conv
376 auto split_primitive = NewValueNode(split_prim);
377 MS_CHECK_TRUE_MSG(split_primitive != nullptr, false, "create split_primitive return nullptr");
378 std::vector<AnfNodePtr> split_inputs = {split_primitive};
379
380 // this conv only has one input, which has been ensured before
381 split_inputs.push_back(conv_cnode->input(1));
382
383 auto split_cnode = func_graph->NewCNode(split_inputs);
384 MS_CHECK_TRUE_MSG(split_cnode != nullptr, false, "create split_cnode return nullptr");
385
386 split_cnode->set_fullname_with_scope(node_name + "_Split");
387 // create outputs op split
388 if (!GetMultipleOutputsOfAnfNode(func_graph, split_cnode, split_info->out_num, split_outputs)) {
389 MS_LOG(ERROR) << "GetMultipleOutputsOfAnfNode failed";
390 return false;
391 }
392
393 AbstractBasePtrList ptr_list;
394 for (int64_t i = 0; i < split_info->out_num; i++) {
395 auto node = (*split_outputs)[i];
396 // set date_type same with weight
397 auto type_id = static_cast<TypeId>(kNumberTypeFloat32);
398 auto type_ptr = TypeIdToType(type_id);
399 std::vector<int64_t> shape_vector;
400 auto value_node = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
401 MS_CHECK_TRUE_MSG(value_node != nullptr, false, "create abstract::AbstractTensor return nullptr");
402 ptr_list.push_back(value_node);
403 }
404 split_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
405 return true;
406 }
407
UpdateRatioWithPadStride(int64_t * ratio,size_t ratio_len,size_t split_size,int split_dim_size,int pad,int stride)408 bool UpdateRatioWithPadStride(int64_t *ratio, size_t ratio_len, size_t split_size, int split_dim_size, int pad,
409 int stride) {
410 MS_CHECK_TRUE_MSG(ratio != nullptr, false, "input ratio is nullptr");
411 int total_block_count = 0;
412 for (size_t i = 0; i < split_size; i++) {
413 total_block_count += ratio[i];
414 }
415 if (ratio_len < split_size) {
416 MS_LOG(ERROR) << "out of ratio range";
417 return false;
418 }
419
420 std::vector<int64_t> new_ratio(split_size);
421 int visited_block = 0;
422 for (size_t i = 0; i < split_size - 1; i++) {
423 visited_block += ratio[i];
424 if (INT_MUL_OVERFLOW_THRESHOLD(split_dim_size, visited_block, INT64_MAX)) {
425 MS_LOG(ERROR) << "int mul overflow";
426 return false;
427 }
428 int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
429 new_ratio[i + 1] = cur_border;
430 }
431
432 for (size_t i = 0; i < split_size; i++) {
433 ratio[i] = new_ratio[i];
434 }
435 return true;
436 }
437 } // namespace opt
438 } // namespace mindspore
439