1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/ops_info/strided_slice_info.h"
18
#include <algorithm>
#include <bitset>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
25
26 #include "frontend/parallel/device_matrix.h"
27 #include "frontend/parallel/dynamic_creator.h"
28 #include "frontend/parallel/strategy.h"
29 #include "frontend/parallel/graph_util/node_info.h"
30 #include "frontend/parallel/graph_util/graph_utils.h"
31 #include "frontend/parallel/step_parallel.h"
32 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
33 #include "pipeline/jit/ps/resource.h"
34 #include "mindspore/core/symbolic_shape/symbol.h"
35
36 namespace mindspore {
37 namespace parallel {
// 1, The mask is an int number; it needs to be converted to binary and reversed.
// (e.g. the input's dimension is 4, and mask is 2, binary is [0, 0, 1, 0], after reversing: [0, 1, 0, 0])
40 // 2, If the ith bit of `begin_mask` is set, `begin[i]` is ignored.
41 // 3, If the ith bit of `end_mask` is set, `end[i]` is ignored.
// 4, If the ith bit of `ellipsis_mask` is set, begin[i]/end[i]/strides[i] are replaced by `...`; this is not supported now.
43 // 5, If the ith bit of `new_axis_mask` is set:
44 // (e.g. input shape: (A, B, C, D), begin: (0, 0), end: (m, n), strides: (1, 1), new_axis_mask: 2)
45 // 1) The corresponding position is expanded by one dimension; (input shape:(A, 1, B, C, D))
46 // 2) Ignore the corresponding position of begin/end/strides; (begin: (0, ig), end: (m, ig), strides: (1, ig))
47 // 3) The output shape is (m, 1, B, C, D)
48 // 6, If the ith bit of `shrink_axis_mask` is set, delete that dimension.
// (e.g. input shape: (A, B, C, D), begin: (0, 0), end: (m, n), strides: (1, 1), shrink_axis_mask: 2,
// the output shape: (m, C, D))
// Notice: if the input is [[1, 2], [3, 4]] and it is fully fetched, but shrink_axis_mask is 1, the output is [1, 2];
// so if the ith bit of `shrink_axis_mask` is set, the dimension can not be split.
53 // 7, If the ith bit of `new_axis_mask` and `shrink_axis_mask` are both set, ignore the ith bit of `shrink_axis_mask`.
// 8, The sizes of begin/end/strides must be equal, but they can be smaller than the input's dimension.
// 9, The mask bits beyond the length of begin/end/strides have no effect.
GetMask(const std::string & mask_name,int64_t * mask_value)56 Status StridedSliceInfo::GetMask(const std::string &mask_name, int64_t *mask_value) {
57 if (mask_value == nullptr) {
58 return FAILED;
59 }
60
61 auto mask_opt = GetScalarValueFromInputs<int64_t>(input_value_, name_, mask_name);
62 if (!mask_opt.has_value()) {
63 MS_LOG(ERROR) << name_ << " failed to get value for " << mask_name << ".";
64 }
65 *mask_value = mask_opt.value();
66 MS_LOG(INFO) << name_ << ": The attr name: " << mask_name << ", the value is " << *mask_value;
67 return SUCCESS;
68 }
69
constexpr auto kStridedSliceMaxDims = 8;
// Expands `mask` into a bitmap of kStridedSliceMaxDims entries in little-endian order:
// element i is true iff bit i of `mask` is set. Bits at positions >= kStridedSliceMaxDims are dropped.
static std::vector<bool> Dec2Bin(int64_t mask) {
  const std::bitset<kStridedSliceMaxDims> bits(static_cast<unsigned long long>(mask));
  std::vector<bool> bitmap(kStridedSliceMaxDims, false);
  for (size_t pos = 0; pos < bits.size(); ++pos) {
    bitmap[pos] = bits.test(pos);
  }
  return bitmap;
}
78
79 // If the ith bit of `begin_mask` is set, `begin[i]` is ignored.
80 // The mask part exceeding the begin length is not effective.
ComputeBeginMask()81 void StridedSliceInfo::ComputeBeginMask() {
82 for (size_t i = 0; i < begin_mask_bitmap_.size() && i < begin_.size(); ++i) {
83 if (begin_mask_bitmap_[i]) {
84 begin_[i] = strides_[i] < 0 ? inputs_shape_[0][i] - 1 : 0;
85 }
86 }
87
88 if (begin_mask_ != 0) {
89 MS_LOG(INFO) << name_ << ": The begin is modified to " << begin_;
90 }
91 }
92
93 // If the ith bit of `end_mask` is set, `end[i]` is ignored.
94 // The mask part exceeding the end length is not effective.
ComputeEndMask()95 void StridedSliceInfo::ComputeEndMask() {
96 for (size_t j = 0; j < end_mask_bitmap_.size() && j < end_.size(); ++j) {
97 if (end_mask_bitmap_[j]) {
98 end_[j] = strides_[j] < 0 ? -1 : inputs_shape_[0][j];
99 }
100 }
101
102 if (end_mask_ != 0) {
103 MS_LOG(INFO) << name_ << ": The end is modified to " << end_;
104 }
105 }
106
// If the ith bit of `ellipsis_mask` is set, begin[i]/end[i]/strides[i] are replaced by `...`; this is not supported now.
ComputeEllipsisMask()108 void StridedSliceInfo::ComputeEllipsisMask() {
109 for (size_t k = 0; k < ellipsis_mask_bitmap_.size() && k < begin_.size(); ++k) {
110 if (ellipsis_mask_bitmap_[k]) {
111 begin_[k] = 0;
112 end_[k] = inputs_shape_[0][k];
113 strides_[k] = 1;
114 }
115 }
116 }
117
118 // If the ith bit of `new_axis_mask` is set:
119 // (e.g. input shape: (A, B, C, D), begin: (0, 0, 0, 0), end: (m, n, o, p), strides: (1, 1, 1, 1), new_axis_mask: 2)
120 // Here, the size of begin/end/strides is equal to input's dimension through ComplementBeginEndStrides()
121 // 1) The corresponding position is expanded by one dimension; (input shape:(A, 1, B, C, D))
122 // 2) Ignore the corresponding position of begin/end/strides;
123 // (begin: (0, ig, 0, 0), end: (m, ig, o, p), strides: (1, ig, 1, 1))
124 // 3) The output shape is (m, 1, o, p, D)
125 // So, use input_shape_in_process_ to generate a tmp input shape
ComputeNewAxisMask()126 void StridedSliceInfo::ComputeNewAxisMask() {
127 input_shape_in_process_ = Shape(inputs_shape_[0].size(), 0);
128 for (size_t l = 0; l < new_axis_mask_bitmap_.size() && l < begin_.size() && l < input_shape_in_process_.size(); ++l) {
129 if (new_axis_mask_bitmap_[l]) {
130 input_shape_in_process_[l] = 1;
131 begin_[l] = 0;
132 end_[l] = 1;
133 strides_[l] = 1;
134 }
135 }
136
137 size_t count = 0;
138 for (auto &ele : input_shape_in_process_) {
139 if (ele != 0) {
140 continue;
141 }
142 ele = inputs_shape_[0][count];
143 count++;
144 }
145
146 (void)input_shape_in_process_.insert(input_shape_in_process_.end(), inputs_shape_[0].begin() + count,
147 inputs_shape_[0].end());
148
149 if (new_axis_mask_ != 0) {
150 MS_LOG(INFO) << name_ << ": The begin is modified to " << begin_ << ", the end is modified to " << end_
151 << ", the strides is modified to " << strides_ << ", the input shape in process is "
152 << input_shape_in_process_;
153 }
154 }
155
156 // If the ith bit of `new_axis_mask` and `shrink_axis_mask` are both set, ignore the ith bit of `shrink_axis_mask`.
AdjustShrinkAxisMask()157 void StridedSliceInfo::AdjustShrinkAxisMask() {
158 bool flag = false;
159 for (size_t i = 0; i < new_axis_mask_bitmap_.size(); ++i) {
160 if (new_axis_mask_bitmap_[i]) {
161 shrink_axis_mask_bitmap_[i] = false;
162 flag = true;
163 }
164 }
165 if (flag) {
166 MS_LOG(INFO) << name_ << ": The shrink axis mask is modified to " << shrink_axis_mask_bitmap_;
167 }
168 }
169
ComputeFullyFetchFlag()170 void StridedSliceInfo::ComputeFullyFetchFlag() {
171 ListSymbolPtr in_symbol = nullptr;
172 ListSymbolPtr out_symbol = nullptr;
173
174 if (dynamic_shape_flag_) {
175 MS_EXCEPTION_IF_NULL(cnode_);
176 MS_EXCEPTION_IF_NULL(cnode_->input(1));
177 MS_EXCEPTION_IF_NULL(cnode_->input(1)->abstract());
178 MS_EXCEPTION_IF_NULL(cnode_->abstract());
179 in_symbol = cnode_->input(1)->abstract()->GetSymbolicShape(); // the input of stridedslice
180 out_symbol = cnode_->abstract()->GetSymbolicShape(); // the output of stridedslice
181 MS_EXCEPTION_IF_NULL(in_symbol);
182 MS_EXCEPTION_IF_NULL(out_symbol);
183 }
184 fully_fetch_flag_.clear();
185
186 for (size_t k = 0; k < begin_.size(); ++k) {
187 bool fully_fetch = false;
188 if (dynamic_shape_flag_) {
189 MS_EXCEPTION_IF_NULL(in_symbol->item(k));
190 if (in_symbol->item(k)->EqualsTo(out_symbol->item(k))) {
191 fully_fetch = true;
192 }
193 } else {
194 fully_fetch = ((begin_[k] == 0) && (end_[k] >= input_shape_in_process_[k]));
195 }
196 fully_fetch_flag_.push_back(fully_fetch);
197 }
198
199 MS_LOG(INFO) << name_ << ": the fully fetch flag is " << fully_fetch_flag_;
200 }
201
GetAttrs()202 Status StridedSliceInfo::GetAttrs() {
203 if ((GetMask(BEGIN_MASK, &begin_mask_) != SUCCESS) || (GetMask(END_MASK, &end_mask_) != SUCCESS) ||
204 (GetMask(ELLIPSIS_MASK, &ellipsis_mask_) != SUCCESS) || (GetMask(NEW_AXIS_MASK, &new_axis_mask_) != SUCCESS) ||
205 (GetMask(SHRINK_AXIS_MASK, &shrink_axis_mask_) != SUCCESS)) {
206 return FAILED;
207 }
208
209 has_mask_ =
210 (begin_mask_ != 0 || end_mask_ != 0 || ellipsis_mask_ != 0 || new_axis_mask_ != 0 || shrink_axis_mask_ != 0);
211
212 if (ellipsis_mask_ != 0) {
213 MS_LOG(ERROR) << name_ << ": It can not support ellipsis_mask now";
214 return FAILED;
215 }
216
217 // convert mask to bit map
218 begin_mask_bitmap_ = Dec2Bin(begin_mask_);
219 end_mask_bitmap_ = Dec2Bin(end_mask_);
220 ellipsis_mask_bitmap_ = Dec2Bin(ellipsis_mask_);
221 new_axis_mask_bitmap_ = Dec2Bin(new_axis_mask_);
222 shrink_axis_mask_bitmap_ = Dec2Bin(shrink_axis_mask_);
223 MS_LOG(INFO) << name_ << ": The begin mask bitmap is " << begin_mask_bitmap_;
224 MS_LOG(INFO) << name_ << ": The end mask bitmap is " << end_mask_bitmap_;
225 MS_LOG(INFO) << name_ << ": The ellipsis mask bitmap is " << ellipsis_mask_bitmap_;
226 MS_LOG(INFO) << name_ << ": The new axis mask bitmap is " << new_axis_mask_bitmap_;
227 MS_LOG(INFO) << name_ << ": The shrink axis mask bitmap is " << shrink_axis_mask_bitmap_;
228
229 // if the ith bit of `new_axis_mask` and `shrink_axis_mask` are both set, ignore the ith bit of `shrink_axis_mask`
230 AdjustShrinkAxisMask();
231
232 // get begin/end/strides, the size of begin/mask/strides must be equal, but it can smaller than input's dimension
233 if (input_value_.size() != STRIDED_SLICE_INPUTS_SIZE) {
234 MS_LOG(ERROR) << name_ << ": The size of input value must be " << STRIDED_SLICE_INPUTS_SIZE << ", but got "
235 << input_value_.size();
236 return FAILED;
237 }
238
239 std::vector<int64_t> unknow_value(inputs_shape_[0].size(), -1);
240 if (input_value_[STRIDED_SLICE_BEGIN_INDEX] != nullptr) {
241 if (TransValueSequeueToVector(input_value_[STRIDED_SLICE_BEGIN_INDEX], &begin_) != SUCCESS) {
242 MS_LOG(ERROR) << name_ << ": get begin value failed";
243 return FAILED;
244 }
245 } else {
246 begin_ = unknow_value;
247 }
248
249 if (input_value_[STRIDED_SLICE_END_INDEX] != nullptr) {
250 if (TransValueSequeueToVector(input_value_[STRIDED_SLICE_END_INDEX], &end_) != SUCCESS) {
251 MS_LOG(ERROR) << name_ << ": get end value failed";
252 return FAILED;
253 }
254 } else {
255 end_ = unknow_value;
256 }
257
258 if (input_value_[STRIDED_SLICE_STRIDES_INDEX] != nullptr) {
259 if (TransValueSequeueToVector(input_value_[STRIDED_SLICE_STRIDES_INDEX], &strides_) != SUCCESS) {
260 MS_LOG(ERROR) << name_ << ": get strides value failed";
261 return FAILED;
262 }
263 } else {
264 strides_ = unknow_value;
265 }
266
267 MS_LOG(INFO) << name_ << ": The begin is " << begin_ << ", the end is " << end_ << ", the stride is " << strides_;
268
269 // handle the masks, it will modify the begin/end/strides, the new begin/end/strides are only used for CheckStrategy()
270 ComputeBeginMask();
271 ComputeEndMask();
272 ComputeEllipsisMask();
273 ComputeNewAxisMask();
274 // no need to handle shrink axis mask
275 auto prim = GetCNodePrimitive(cnode_);
276 if (prim->HasAttr(parallel::SKIP_REDISTRIBUTION)) {
277 skip_redistribution_ = GetValue<bool>(prim->GetAttr(parallel::SKIP_REDISTRIBUTION));
278 }
279
280 ComputeFullyFetchFlag();
281 return SUCCESS;
282 }
283
CheckInputStrategy(const Shape & strategy_value)284 Status StridedSliceInfo::CheckInputStrategy(const Shape &strategy_value) {
285 // change the strategy if the new mask axis is set
286 Shape strategy_in_process = Shape(strategy_value.size(), 0);
287 for (size_t i = 0; i < new_axis_mask_bitmap_.size() && i < begin_.size() && i < strategy_value.size(); ++i) {
288 if (new_axis_mask_bitmap_[i]) {
289 strategy_in_process[i] = 1;
290 }
291 }
292
293 size_t count = 0;
294 for (auto &ele : strategy_in_process) {
295 if (ele != 0) {
296 continue;
297 }
298 ele = strategy_value[count];
299 count++;
300 }
301
302 (void)strategy_in_process.insert(strategy_in_process.end(), strategy_value.begin() + count, strategy_value.end());
303 MS_LOG(INFO) << name_ << ": The strategy in process is " << strategy_in_process;
304
305 for (size_t j = 0; j < strides_.size(); ++j) {
306 if ((strides_[j] != 1) && (strategy_in_process[j] > 1)) {
307 MS_LOG(ERROR)
308 << name_
309 << ": When a certain dimension is split, now does not support that the stride is not 1, the strides is "
310 << strides_ << ", the strategy is " << strategy_in_process << ", the index is " << j;
311 return FAILED;
312 }
313 }
314
315 for (size_t k = 0; k < begin_.size(); ++k) {
316 if (!fully_fetch_flag_[k] && (strategy_in_process[k] != 1) && !skip_redistribution_) {
317 MS_LOG(ERROR) << name_
318 << ": When a dimension is not fully fetched, the dimension can not be split now, the begin is "
319 << begin_ << ", the end is " << end_ << ", the index is " << k << ", the input shape in process is "
320 << input_shape_in_process_ << ", the strategy in process is " << strategy_in_process;
321 return FAILED;
322 }
323 }
324
325 // if the ith bit of 'shrink_axis_mask' is set, the dimension can not be split
326 for (size_t l = 0; l < strategy_in_process.size() && l < shrink_axis_mask_bitmap_.size(); ++l) {
327 if (shrink_axis_mask_bitmap_[l] && strategy_in_process[l] != 1) {
328 MS_LOG(ERROR) << name_
329 << ": When a dimension is shrunk, the dimension can not be split now, the strategy in process is "
330 << strategy_in_process << ", the shrink axis mask bitmap is " << shrink_axis_mask_bitmap_;
331 return FAILED;
332 }
333 }
334
335 return SUCCESS;
336 }
337
CheckStrategy(const StrategyPtr & strategy)338 Status StridedSliceInfo::CheckStrategy(const StrategyPtr &strategy) {
339 MS_EXCEPTION_IF_NULL(strategy);
340 Shapes valid_inputs_shape = {inputs_shape_[0]};
341 if (CheckStrategyValue(strategy, valid_inputs_shape) != SUCCESS) {
342 MS_LOG(ERROR) << name_ << ": Invalid strategy";
343 return FAILED;
344 }
345
346 std::vector<Dimensions> stra = strategy->GetInputDim();
347 if (stra.empty()) {
348 MS_LOG(ERROR) << name_ << ": The strategy is empty";
349 return FAILED;
350 }
351
352 Dimensions strategy_value = stra[0];
353 if (strategy_value.size() < strides_.size()) {
354 MS_LOG(ERROR) << name_ << ": The size of strategy must be larger or equal to the size of strides";
355 return FAILED;
356 }
357
358 if (dynamic_shape_flag_) {
359 auto shard_num = std::accumulate(strategy_value.begin(), strategy_value.end(), 1, std::multiplies<int64_t>());
360 if (shard_num == 1) {
361 return SUCCESS;
362 }
363
364 if (has_mask_) {
365 MS_LOG(ERROR) << name_ << ": it does not support dynamic shape when it has mask, the strategy is "
366 << ShapeToString(strategy_value);
367 return FAILED;
368 }
369
370 if (strides_ == Shape(inputs_shape_[0].size(), -1)) {
371 MS_LOG(ERROR) << name_ << ": it does not support dynamic shape when the strides attr is not constant";
372 return FAILED;
373 }
374 }
375
376 return CheckInputStrategy(strategy_value);
377 }
378
InferDevMatrixShape()379 Status StridedSliceInfo::InferDevMatrixShape() {
380 MS_EXCEPTION_IF_NULL(strategy_);
381 std::vector<Dimensions> stra = strategy_->GetInputDim();
382 if (stra.empty()) {
383 MS_LOG(ERROR) << name_ << "The strategy is empty";
384 return FAILED;
385 }
386
387 dev_matrix_shape_ = stra[0];
388 return SUCCESS;
389 }
390
InferTensorMap()391 Status StridedSliceInfo::InferTensorMap() {
392 TensorMap tensor_map;
393 if (inputs_shape_.empty()) {
394 MS_LOG(ERROR) << name_ << "The inputs shape is empty";
395 return FAILED;
396 }
397
398 // cannot use dev_matrix_shape_ replace inputs_shape_[0], because it may not be fully split in all devices.
399 int64_t size = SizeToLong(inputs_shape_[0].size());
400 for (int64_t i = 0; i < size; ++i) {
401 tensor_map.push_back(size - i - 1);
402 }
403
404 inputs_tensor_map_.push_back(tensor_map);
405
406 // If the ith bit of `new_axis_mask` is set, the corresponding position is expanded by one dimension, and this
407 // dimension need to insert MAP_NONE for output tensor map.
408 for (size_t j = 0; j < new_axis_mask_bitmap_.size() && j < begin_.size(); ++j) {
409 if (new_axis_mask_bitmap_[j]) {
410 (void)tensor_map.insert(tensor_map.cbegin() + j, MAP_NONE);
411 }
412 }
413
414 // If the ith bit of `shrink_axis_mask` is set, delete that dimension.
415 Shape out_tensor_map;
416 for (size_t k = 0; k < shrink_axis_mask_bitmap_.size() && k < tensor_map.size(); ++k) {
417 if (k < begin_.size() && shrink_axis_mask_bitmap_[k]) {
418 continue;
419 }
420 out_tensor_map.push_back(tensor_map[k]);
421 }
422
423 MS_LOG(INFO) << name_ << ": The output tensor map is " << out_tensor_map;
424 outputs_tensor_map_.push_back(out_tensor_map);
425 return SUCCESS;
426 }
427
ChangeCNodeBegin()428 void StridedSliceInfo::ChangeCNodeBegin() {
429 if (!skip_redistribution_) {
430 return;
431 }
432 auto shard_size = strategy_->GetInputDim()[0];
433 auto begin_new = begin_;
434 for (size_t i = 0; i < shard_size.size(); ++i) {
435 MS_EXCEPTION_IF_ZERO("shard_size", shard_size[i]);
436 begin_new[i] = begin_new[i] / shard_size[i];
437 }
438 auto begin_new_value = MakeValue(begin_new);
439 auto new_begin_value_node = std::make_shared<ValueNode>(begin_new_value);
440 auto manager = cnode_->func_graph()->manager();
441 manager->SetEdge(cnode_, STRIDE_SLICE_CNODE_BEGIN_INDEX, new_begin_value_node);
442 }
443
ChangeCNodeEnd()444 void StridedSliceInfo::ChangeCNodeEnd() {
445 if (!skip_redistribution_) {
446 return;
447 }
448 auto shard_size = strategy_->GetInputDim()[0];
449 auto end_new = end_;
450 for (size_t i = 0; i < shard_size.size(); ++i) {
451 MS_EXCEPTION_IF_ZERO("shard_size", shard_size[i]);
452 end_new[i] = end_new[i] / shard_size[i];
453 }
454 auto end_new_value = MakeValue(end_new);
455 auto new_end_value_node = std::make_shared<ValueNode>(end_new_value);
456 auto manager = cnode_->func_graph()->manager();
457 manager->SetEdge(cnode_, STRIDE_SLICE_CNODE_END_INDEX, new_end_value_node);
458 }
459
InferMirrorOps()460 Status StridedSliceInfo::InferMirrorOps() {
461 mirror_ops_.clear();
462 if (inputs_tensor_map_.empty()) {
463 MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
464 return FAILED;
465 }
466 Shape input_tensor_map = inputs_tensor_map_[0];
467 std::vector<Group> group;
468 if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) {
469 ReportError(name_ + ": Create group failed.");
470 return FAILED;
471 }
472
473 if (group.empty()) {
474 MS_LOG(INFO) << name_ << ": The mirror group is empty.";
475 return SUCCESS;
476 }
477
478 OperatorVector input_op, begin_op, end_op, strides_op;
479 input_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
480 mirror_ops_.push_back(input_op);
481 mirror_ops_.push_back(begin_op);
482 mirror_ops_.push_back(end_op);
483 mirror_ops_.push_back(strides_op);
484 OperatorVector op_helper;
485 auto prim_name = GetPrimNameFromInfoName(name_);
486 auto res_size = ops::GetOpInputsNum(prim_name) - mirror_ops_.size();
487 for (size_t i = 0; i < res_size; ++i) {
488 mirror_ops_.push_back(op_helper);
489 }
490 return SUCCESS;
491 }
492
replace_graph(const CNodePtr & cnode)493 ReplaceGraphPtr StridedSliceInfo::replace_graph(const CNodePtr &cnode) {
494 if (!skip_redistribution_) {
495 return nullptr;
496 }
497
498 bool begin_is_constant = GetValueNode(cnode->input(STRIDE_SLICE_CNODE_BEGIN_INDEX)) != nullptr;
499 bool end_is_constant = GetValueNode(cnode->input(STRIDE_SLICE_CNODE_END_INDEX)) != nullptr;
500 if (begin_is_constant && end_is_constant) {
501 ChangeCNodeBegin();
502 ChangeCNodeEnd();
503 return nullptr;
504 }
505
506 if (!begin_is_constant && !IsPrimitiveCNode(cnode->input(STRIDE_SLICE_CNODE_BEGIN_INDEX), prim::kPrimMakeTuple)) {
507 MS_LOG(EXCEPTION) << name_ << ": the begin is not constant value, and it is not make tuple";
508 }
509
510 if (!end_is_constant && !IsPrimitiveCNode(cnode->input(STRIDE_SLICE_CNODE_END_INDEX), prim::kPrimMakeTuple)) {
511 MS_LOG(EXCEPTION) << name_ << ": the end is not constant value, and it is not make tuple";
512 }
513
514 // need to handle the constant part of begin/end
515 if (!begin_is_constant) {
516 // constant element of begin div by shard size
517 ChangeMakeTupleConstant(cnode, STRIDE_SLICE_CNODE_BEGIN_INDEX);
518 } else {
519 ChangeCNodeBegin();
520 }
521
522 if (!end_is_constant) {
523 // constant element of end div by shard size
524 ChangeMakeTupleConstant(cnode, STRIDE_SLICE_CNODE_END_INDEX);
525 } else {
526 ChangeCNodeEnd();
527 }
528
529 return nullptr;
530 }
531
532 // Note: if the batch dimension is not fully fetched, the batch strategy may not work.
GenerateBatchStrategies()533 std::shared_ptr<Strategies> StridedSliceInfo::GenerateBatchStrategies() {
534 if (GetAttrs() != SUCCESS) {
535 MS_LOG(EXCEPTION) << name_ << "generate batch parallel strategies failed.";
536 }
537 split_flag_list_ = {true};
538
539 if (!fully_fetch_flag_[0]) {
540 split_flag_list_ = {false};
541 }
542 return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
543 }
544
SetCostUnderStrategy(const StrategyPtr & strategy)545 Status StridedSliceInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
546 return SetCostUnderStrategyBase(strategy);
547 }
548
GenerateOpStrategies(int64_t stage_id)549 std::vector<StrategyPtr> StridedSliceInfo::GenerateOpStrategies(int64_t stage_id) {
550 Shape input_split(inputs_shape_[0].size(), 1);
551 for (size_t i = 0; i < begin_.size(); ++i) {
552 if (!fully_fetch_flag_[i] || (strides_[i] != 1)) {
553 input_split[i] = 0;
554 }
555 }
556 Shapes splittable_inputs = {input_split};
557
558 std::vector<StrategyPtr> sp_vector;
559 if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
560 MS_LOG(EXCEPTION) << name_ << ": generate strategies failed";
561 }
562
563 return sp_vector;
564 }
565
// Registers StridedSliceInfo via the REGISTER macro (presumably from dynamic_creator.h, included above) so the
// parallel framework can create it by operator name -- TODO confirm against the macro definition.
566 REGISTER(StridedSliceInfo);
567 } // namespace parallel
568 } // namespace mindspore
569