/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/strided_slice_info.h"

#include <bitset>
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include <functional>
#include <numeric>  // std::accumulate is used in CheckStrategy

#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/dynamic_creator.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/graph_utils.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/ps/resource.h"
#include "mindspore/core/symbolic_shape/symbol.h"

namespace mindspore {
namespace parallel {
// 1, Each mask is an integer; it needs to be converted to binary and reversed.
//    (e.g. the input's dimension is 4 and the mask is 2, the binary is [0, 0, 1, 0], after reversing: [0, 1, 0, 0])
// 2, If the ith bit of `begin_mask` is set, `begin[i]` is ignored.
// 3, If the ith bit of `end_mask` is set, `end[i]` is ignored.
// 4, If the ith bit of `ellipsis_mask` is set, begin[i]/end[i]/strides[i] are replaced by `...`; it is not supported now.
// 5, If the ith bit of `new_axis_mask` is set:
//    (e.g. input shape: (A, B, C, D), begin: (0, 0), end: (m, n), strides: (1, 1), new_axis_mask: 2)
//    1) The corresponding position is expanded by one dimension; (input shape: (A, 1, B, C, D))
//    2) Ignore the corresponding position of begin/end/strides; (begin: (0, ig), end: (m, ig), strides: (1, ig))
//    3) The output shape is (m, 1, B, C, D)
// 6, If the ith bit of `shrink_axis_mask` is set, delete that dimension.
//    (e.g. input shape: (A, B, C, D), begin: (0, 0), end: (m, n), strides: (1, 1), shrink_axis_mask: 2,
//     the output shape is (m, C, D))
//    notice: if the input is [[1, 2], [3, 4]] and fully fetched, but shrink_axis_mask is 1, then the output is [1, 2],
//            so if the ith bit of `shrink_axis_mask` is set, the dimension can not be split
// 7, If the ith bit of `new_axis_mask` and `shrink_axis_mask` are both set, ignore the ith bit of `shrink_axis_mask`.
// 8, The sizes of begin/end/strides must be equal, but they can be smaller than the input's dimension.
// 9, The mask part exceeding the begin/end/strides length is not effective.
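// Fetch a scalar mask attribute (begin_mask/end_mask/ellipsis_mask/new_axis_mask/shrink_axis_mask) from
// input_value_; return FAILED if it is absent.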
Status StridedSliceInfo::GetMask(const std::string &mask_name, int64_t *mask_value) {
  if (mask_value == nullptr) {
    return FAILED;
  }

  auto mask_opt = GetScalarValueFromInputs<int64_t>(input_value_, name_, mask_name);
  if (!mask_opt.has_value()) {
    MS_LOG(ERROR) << name_ << " failed to get value for " << mask_name << ".";
    return FAILED;  // do not dereference an empty optional below
  }
  *mask_value = mask_opt.value();
  MS_LOG(INFO) << name_ << ": The attr name: " << mask_name << ", the value is " << *mask_value;
  return SUCCESS;
}

constexpr auto kStridedSliceMaxDims = 8;
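// Convert a mask integer to a reversed bitmap of kStridedSliceMaxDims bits,
// e.g. mask = 2 -> "00000010" -> reversed -> [false, true, false, false, false, false, false, false].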
static std::vector<bool> Dec2Bin(int64_t mask) {
  auto mask_str = std::bitset<kStridedSliceMaxDims>(mask).to_string();
  std::vector<bool> result;
  (void)std::transform(mask_str.rbegin(), mask_str.rend(), std::back_inserter(result),
                       [](const char &c) { return c == '1'; });
  return result;
}

// If the ith bit of `begin_mask` is set, `begin[i]` is ignored.
// The mask part exceeding the begin length is not effective.
void StridedSliceInfo::ComputeBeginMask() {
  for (size_t i = 0; i < begin_mask_bitmap_.size() && i < begin_.size(); ++i) {
    if (begin_mask_bitmap_[i]) {
      begin_[i] = strides_[i] < 0 ? inputs_shape_[0][i] - 1 : 0;
    }
  }

  if (begin_mask_ != 0) {
    MS_LOG(INFO) << name_ << ": The begin is modified to " << begin_;
  }
}

// If the ith bit of `end_mask` is set, `end[i]` is ignored.
// The mask part exceeding the end length is not effective.
void StridedSliceInfo::ComputeEndMask() {
  for (size_t j = 0; j < end_mask_bitmap_.size() && j < end_.size(); ++j) {
    if (end_mask_bitmap_[j]) {
      end_[j] = strides_[j] < 0 ? -1 : inputs_shape_[0][j];
    }
  }

  if (end_mask_ != 0) {
    MS_LOG(INFO) << name_ << ": The end is modified to " << end_;
  }
}

// If the ith bit of `ellipsis_mask` is set, begin[i]/end[i]/strides[i] are replaced by `...`; it is not supported now.
void StridedSliceInfo::ComputeEllipsisMask() {
  for (size_t k = 0; k < ellipsis_mask_bitmap_.size() && k < begin_.size(); ++k) {
    if (ellipsis_mask_bitmap_[k]) {
      begin_[k] = 0;
      end_[k] = inputs_shape_[0][k];
      strides_[k] = 1;
    }
  }
}

// If the ith bit of `new_axis_mask` is set:
//    (e.g. input shape: (A, B, C, D), begin: (0, 0, 0, 0), end: (m, n, o, p), strides: (1, 1, 1, 1), new_axis_mask: 2)
//    Here, the size of begin/end/strides has been made equal to the input's dimension by ComplementBeginEndStrides()
//    1) The corresponding position is expanded by one dimension; (input shape: (A, 1, B, C, D))
//    2) Ignore the corresponding position of begin/end/strides;
//       (begin: (0, ig, 0, 0), end: (m, ig, o, p), strides: (1, ig, 1, 1))
//    3) The output shape is (m, 1, o, p, D)
// So use input_shape_in_process_ to generate a temporary input shape.
void StridedSliceInfo::ComputeNewAxisMask() {
  input_shape_in_process_ = Shape(inputs_shape_[0].size(), 0);
  for (size_t l = 0; l < new_axis_mask_bitmap_.size() && l < begin_.size() && l < input_shape_in_process_.size(); ++l) {
    if (new_axis_mask_bitmap_[l]) {
      input_shape_in_process_[l] = 1;
      begin_[l] = 0;
      end_[l] = 1;
      strides_[l] = 1;
    }
  }

  size_t count = 0;
  for (auto &ele : input_shape_in_process_) {
    if (ele != 0) {
      continue;
    }
    ele = inputs_shape_[0][count];
    count++;
  }

  (void)input_shape_in_process_.insert(input_shape_in_process_.end(), inputs_shape_[0].begin() + count,
                                       inputs_shape_[0].end());

  if (new_axis_mask_ != 0) {
    MS_LOG(INFO) << name_ << ": The begin is modified to " << begin_ << ", the end is modified to " << end_
                 << ", the strides is modified to " << strides_ << ", the input shape in process is "
                 << input_shape_in_process_;
  }
}

// If the ith bit of `new_axis_mask` and `shrink_axis_mask` are both set, ignore the ith bit of `shrink_axis_mask`.
void StridedSliceInfo::AdjustShrinkAxisMask() {
  bool flag = false;
  for (size_t i = 0; i < new_axis_mask_bitmap_.size(); ++i) {
    if (new_axis_mask_bitmap_[i]) {
      shrink_axis_mask_bitmap_[i] = false;
      flag = true;
    }
  }
  if (flag) {
    MS_LOG(INFO) << name_ << ": The shrink axis mask is modified to " << shrink_axis_mask_bitmap_;
  }
}

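// For each position in begin, record whether the corresponding dimension is fully fetched.
// In static shape, fully fetched means begin == 0 and end >= the dimension size; in dynamic shape, the input and
// output symbolic shapes of that dimension must be equal.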
void StridedSliceInfo::ComputeFullyFetchFlag() {
  ListSymbolPtr in_symbol = nullptr;
  ListSymbolPtr out_symbol = nullptr;

  if (dynamic_shape_flag_) {
    MS_EXCEPTION_IF_NULL(cnode_);
    MS_EXCEPTION_IF_NULL(cnode_->input(1));
    MS_EXCEPTION_IF_NULL(cnode_->input(1)->abstract());
    MS_EXCEPTION_IF_NULL(cnode_->abstract());
    in_symbol = cnode_->input(1)->abstract()->GetSymbolicShape();  // the input of stridedslice
    out_symbol = cnode_->abstract()->GetSymbolicShape();           // the output of stridedslice
    MS_EXCEPTION_IF_NULL(in_symbol);
    MS_EXCEPTION_IF_NULL(out_symbol);
  }
  fully_fetch_flag_.clear();

  for (size_t k = 0; k < begin_.size(); ++k) {
    bool fully_fetch = false;
    if (dynamic_shape_flag_) {
      MS_EXCEPTION_IF_NULL(in_symbol->item(k));
      if (in_symbol->item(k)->EqualsTo(out_symbol->item(k))) {
        fully_fetch = true;
      }
    } else {
      fully_fetch = ((begin_[k] == 0) && (end_[k] >= input_shape_in_process_[k]));
    }
    fully_fetch_flag_.push_back(fully_fetch);
  }

  MS_LOG(INFO) << name_ << ": the fully fetch flag is " << fully_fetch_flag_;
}

Status StridedSliceInfo::GetAttrs() {
  if ((GetMask(BEGIN_MASK, &begin_mask_) != SUCCESS) || (GetMask(END_MASK, &end_mask_) != SUCCESS) ||
      (GetMask(ELLIPSIS_MASK, &ellipsis_mask_) != SUCCESS) || (GetMask(NEW_AXIS_MASK, &new_axis_mask_) != SUCCESS) ||
      (GetMask(SHRINK_AXIS_MASK, &shrink_axis_mask_) != SUCCESS)) {
    return FAILED;
  }

  has_mask_ =
    (begin_mask_ != 0 || end_mask_ != 0 || ellipsis_mask_ != 0 || new_axis_mask_ != 0 || shrink_axis_mask_ != 0);

  if (ellipsis_mask_ != 0) {
    MS_LOG(ERROR) << name_ << ": It can not support ellipsis_mask now";
    return FAILED;
  }

  // convert mask to bit map
  begin_mask_bitmap_ = Dec2Bin(begin_mask_);
  end_mask_bitmap_ = Dec2Bin(end_mask_);
  ellipsis_mask_bitmap_ = Dec2Bin(ellipsis_mask_);
  new_axis_mask_bitmap_ = Dec2Bin(new_axis_mask_);
  shrink_axis_mask_bitmap_ = Dec2Bin(shrink_axis_mask_);
  MS_LOG(INFO) << name_ << ": The begin mask bitmap is " << begin_mask_bitmap_;
  MS_LOG(INFO) << name_ << ": The end mask bitmap is " << end_mask_bitmap_;
  MS_LOG(INFO) << name_ << ": The ellipsis mask bitmap is " << ellipsis_mask_bitmap_;
  MS_LOG(INFO) << name_ << ": The new axis mask bitmap is " << new_axis_mask_bitmap_;
  MS_LOG(INFO) << name_ << ": The shrink axis mask bitmap is " << shrink_axis_mask_bitmap_;

  // if the ith bit of `new_axis_mask` and `shrink_axis_mask` are both set, ignore the ith bit of `shrink_axis_mask`
  AdjustShrinkAxisMask();

  // get begin/end/strides; the sizes of begin/end/strides must be equal, but they can be smaller than the input's dimension
  if (input_value_.size() != STRIDED_SLICE_INPUTS_SIZE) {
    MS_LOG(ERROR) << name_ << ": The size of input value must be " << STRIDED_SLICE_INPUTS_SIZE << ", but got "
                  << input_value_.size();
    return FAILED;
  }

  std::vector<int64_t> unknown_value(inputs_shape_[0].size(), -1);
  if (input_value_[STRIDED_SLICE_BEGIN_INDEX] != nullptr) {
    if (TransValueSequeueToVector(input_value_[STRIDED_SLICE_BEGIN_INDEX], &begin_) != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": get begin value failed";
      return FAILED;
    }
  } else {
    begin_ = unknown_value;
  }

  if (input_value_[STRIDED_SLICE_END_INDEX] != nullptr) {
    if (TransValueSequeueToVector(input_value_[STRIDED_SLICE_END_INDEX], &end_) != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": get end value failed";
      return FAILED;
    }
  } else {
    end_ = unknown_value;
  }

  if (input_value_[STRIDED_SLICE_STRIDES_INDEX] != nullptr) {
    if (TransValueSequeueToVector(input_value_[STRIDED_SLICE_STRIDES_INDEX], &strides_) != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": get strides value failed";
      return FAILED;
    }
  } else {
    strides_ = unknown_value;
  }

  MS_LOG(INFO) << name_ << ": The begin is " << begin_ << ", the end is " << end_ << ", the stride is " << strides_;

  // handle the masks; they modify begin/end/strides, and the new begin/end/strides are only used for CheckStrategy()
  ComputeBeginMask();
  ComputeEndMask();
  ComputeEllipsisMask();
  ComputeNewAxisMask();
  // no need to handle shrink axis mask
  auto prim = GetCNodePrimitive(cnode_);
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->HasAttr(parallel::SKIP_REDISTRIBUTION)) {
    skip_redistribution_ = GetValue<bool>(prim->GetAttr(parallel::SKIP_REDISTRIBUTION));
  }

  ComputeFullyFetchFlag();
  return SUCCESS;
}

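// Check the strategy of the input tensor: first expand the strategy with 1 at every new-axis position, then reject
// splitting a dimension whose stride is not 1, a dimension that is not fully fetched (unless skip_redistribution_ is
// set), or a dimension that will be shrunk.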
Status StridedSliceInfo::CheckInputStrategy(const Shape &strategy_value) {
  // change the strategy if the new axis mask is set
  Shape strategy_in_process = Shape(strategy_value.size(), 0);
  for (size_t i = 0; i < new_axis_mask_bitmap_.size() && i < begin_.size() && i < strategy_value.size(); ++i) {
    if (new_axis_mask_bitmap_[i]) {
      strategy_in_process[i] = 1;
    }
  }

  size_t count = 0;
  for (auto &ele : strategy_in_process) {
    if (ele != 0) {
      continue;
    }
    ele = strategy_value[count];
    count++;
  }

  (void)strategy_in_process.insert(strategy_in_process.end(), strategy_value.begin() + count, strategy_value.end());
  MS_LOG(INFO) << name_ << ": The strategy in process is " << strategy_in_process;

  for (size_t j = 0; j < strides_.size(); ++j) {
    if ((strides_[j] != 1) && (strategy_in_process[j] > 1)) {
      MS_LOG(ERROR)
        << name_
        << ": When a certain dimension is split, now does not support that the stride is not 1, the strides is "
        << strides_ << ", the strategy is " << strategy_in_process << ", the index is " << j;
      return FAILED;
    }
  }

  for (size_t k = 0; k < begin_.size(); ++k) {
    if (!fully_fetch_flag_[k] && (strategy_in_process[k] != 1) && !skip_redistribution_) {
      MS_LOG(ERROR) << name_
                    << ": When a dimension is not fully fetched, the dimension can not be split now, the begin is "
                    << begin_ << ", the end is " << end_ << ", the index is " << k << ", the input shape in process is "
                    << input_shape_in_process_ << ", the strategy in process is " << strategy_in_process;
      return FAILED;
    }
  }

  // if the ith bit of 'shrink_axis_mask' is set, the dimension can not be split
  for (size_t l = 0; l < strategy_in_process.size() && l < shrink_axis_mask_bitmap_.size(); ++l) {
    if (shrink_axis_mask_bitmap_[l] && strategy_in_process[l] != 1) {
      MS_LOG(ERROR) << name_
                    << ": When a dimension is shrunk, the dimension can not be split now, the strategy in process is "
                    << strategy_in_process << ", the shrink axis mask bitmap is " << shrink_axis_mask_bitmap_;
      return FAILED;
    }
  }

  return SUCCESS;
}

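// Validate a user strategy. Besides the generic checks, dynamic shape only allows either no splitting at all, or
// splitting without any mask and with constant strides.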
Status StridedSliceInfo::CheckStrategy(const StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);
  Shapes valid_inputs_shape = {inputs_shape_[0]};
  if (CheckStrategyValue(strategy, valid_inputs_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy";
    return FAILED;
  }

  std::vector<Dimensions> stra = strategy->GetInputDim();
  if (stra.empty()) {
    MS_LOG(ERROR) << name_ << ": The strategy is empty";
    return FAILED;
  }

  Dimensions strategy_value = stra[0];
  if (strategy_value.size() < strides_.size()) {
    MS_LOG(ERROR) << name_ << ": The size of strategy must be larger or equal to the size of strides";
    return FAILED;
  }

  if (dynamic_shape_flag_) {
    auto shard_num = std::accumulate(strategy_value.begin(), strategy_value.end(), 1, std::multiplies<int64_t>());
    if (shard_num == 1) {
      return SUCCESS;
    }

    if (has_mask_) {
      MS_LOG(ERROR) << name_ << ": it does not support dynamic shape when it has mask, the strategy is "
                    << ShapeToString(strategy_value);
      return FAILED;
    }

    if (strides_ == Shape(inputs_shape_[0].size(), -1)) {
      MS_LOG(ERROR) << name_ << ": it does not support dynamic shape when the strides attr is not constant";
      return FAILED;
    }
  }

  return CheckInputStrategy(strategy_value);
}

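// The device matrix is simply the strategy of the first input.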
Status StridedSliceInfo::InferDevMatrixShape() {
  MS_EXCEPTION_IF_NULL(strategy_);
  std::vector<Dimensions> stra = strategy_->GetInputDim();
  if (stra.empty()) {
    MS_LOG(ERROR) << name_ << ": The strategy is empty";
    return FAILED;
  }

  dev_matrix_shape_ = stra[0];
  return SUCCESS;
}

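// The input tensor map maps the ith input dimension to device-matrix dimension (size - 1 - i). For the output tensor
// map, insert MAP_NONE at every new-axis position and drop every shrunk position.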
Status StridedSliceInfo::InferTensorMap() {
  TensorMap tensor_map;
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }

  // cannot use dev_matrix_shape_ to replace inputs_shape_[0], because it may not be fully split on all devices.
  int64_t size = SizeToLong(inputs_shape_[0].size());
  for (int64_t i = 0; i < size; ++i) {
    tensor_map.push_back(size - i - 1);
  }

  inputs_tensor_map_.push_back(tensor_map);

  // If the ith bit of `new_axis_mask` is set, the corresponding position is expanded by one dimension, and this
  // dimension needs MAP_NONE inserted into the output tensor map.
  for (size_t j = 0; j < new_axis_mask_bitmap_.size() && j < begin_.size(); ++j) {
    if (new_axis_mask_bitmap_[j]) {
      (void)tensor_map.insert(tensor_map.cbegin() + j, MAP_NONE);
    }
  }

  // If the ith bit of `shrink_axis_mask` is set, delete that dimension.
  Shape out_tensor_map;
  for (size_t k = 0; k < shrink_axis_mask_bitmap_.size() && k < tensor_map.size(); ++k) {
    if (k < begin_.size() && shrink_axis_mask_bitmap_[k]) {
      continue;
    }
    out_tensor_map.push_back(tensor_map[k]);
  }

  MS_LOG(INFO) << name_ << ": The output tensor map is " << out_tensor_map;
  outputs_tensor_map_.push_back(out_tensor_map);
  return SUCCESS;
}

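// When skip_redistribution_ is set, replace the begin input of the cnode with begin / shard_size (element-wise).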
void StridedSliceInfo::ChangeCNodeBegin() {
  if (!skip_redistribution_) {
    return;
  }
  auto shard_size = strategy_->GetInputDim()[0];
  auto begin_new = begin_;
  for (size_t i = 0; i < shard_size.size(); ++i) {
    MS_EXCEPTION_IF_ZERO("shard_size", shard_size[i]);
    begin_new[i] = begin_new[i] / shard_size[i];
  }
  auto begin_new_value = MakeValue(begin_new);
  auto new_begin_value_node = std::make_shared<ValueNode>(begin_new_value);
  auto manager = cnode_->func_graph()->manager();
  manager->SetEdge(cnode_, STRIDE_SLICE_CNODE_BEGIN_INDEX, new_begin_value_node);
}

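// When skip_redistribution_ is set, replace the end input of the cnode with end / shard_size (element-wise).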
void StridedSliceInfo::ChangeCNodeEnd() {
  if (!skip_redistribution_) {
    return;
  }
  auto shard_size = strategy_->GetInputDim()[0];
  auto end_new = end_;
  for (size_t i = 0; i < shard_size.size(); ++i) {
    MS_EXCEPTION_IF_ZERO("shard_size", shard_size[i]);
    end_new[i] = end_new[i] / shard_size[i];
  }
  auto end_new_value = MakeValue(end_new);
  auto new_end_value_node = std::make_shared<ValueNode>(end_new_value);
  auto manager = cnode_->func_graph()->manager();
  manager->SetEdge(cnode_, STRIDE_SLICE_CNODE_END_INDEX, new_end_value_node);
}

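// Create the mirror operator for the first input (the sliced tensor); the begin/end/strides inputs and any remaining
// inputs get empty operator vectors.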
Status StridedSliceInfo::InferMirrorOps() {
  mirror_ops_.clear();
  if (inputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
    return FAILED;
  }
  Shape input_tensor_map = inputs_tensor_map_[0];
  std::vector<Group> group;
  if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) {
    ReportError(name_ + ": Create group failed.");
    return FAILED;
  }

  if (group.empty()) {
    MS_LOG(INFO) << name_ << ": The mirror group is empty.";
    return SUCCESS;
  }

  OperatorVector input_op, begin_op, end_op, strides_op;
  input_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
  mirror_ops_.push_back(input_op);
  mirror_ops_.push_back(begin_op);
  mirror_ops_.push_back(end_op);
  mirror_ops_.push_back(strides_op);
  OperatorVector op_helper;
  auto prim_name = GetPrimNameFromInfoName(name_);
  auto res_size = ops::GetOpInputsNum(prim_name) - mirror_ops_.size();
  for (size_t i = 0; i < res_size; ++i) {
    mirror_ops_.push_back(op_helper);
  }
  return SUCCESS;
}

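// Only used when skip_redistribution_ is set: divide the begin/end inputs of the cnode by the shard size. Constant
// begin/end values are rewritten directly; MakeTuple inputs have their constant elements rewritten.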
ReplaceGraphPtr StridedSliceInfo::replace_graph(const CNodePtr &cnode) {
  if (!skip_redistribution_) {
    return nullptr;
  }

  bool begin_is_constant = GetValueNode(cnode->input(STRIDE_SLICE_CNODE_BEGIN_INDEX)) != nullptr;
  bool end_is_constant = GetValueNode(cnode->input(STRIDE_SLICE_CNODE_END_INDEX)) != nullptr;
  if (begin_is_constant && end_is_constant) {
    ChangeCNodeBegin();
    ChangeCNodeEnd();
    return nullptr;
  }

  if (!begin_is_constant && !IsPrimitiveCNode(cnode->input(STRIDE_SLICE_CNODE_BEGIN_INDEX), prim::kPrimMakeTuple)) {
    MS_LOG(EXCEPTION) << name_ << ": the begin is not a constant value, and it is not a make tuple";
  }

  if (!end_is_constant && !IsPrimitiveCNode(cnode->input(STRIDE_SLICE_CNODE_END_INDEX), prim::kPrimMakeTuple)) {
    MS_LOG(EXCEPTION) << name_ << ": the end is not a constant value, and it is not a make tuple";
  }

  // need to handle the constant part of begin/end
  if (!begin_is_constant) {
    // the constant elements of begin are divided by the shard size
    ChangeMakeTupleConstant(cnode, STRIDE_SLICE_CNODE_BEGIN_INDEX);
  } else {
    ChangeCNodeBegin();
  }

  if (!end_is_constant) {
    // the constant elements of end are divided by the shard size
    ChangeMakeTupleConstant(cnode, STRIDE_SLICE_CNODE_END_INDEX);
  } else {
    ChangeCNodeEnd();
  }

  return nullptr;
}

// Note: if the batch dimension is not fully fetched, the batch strategy may not work.
std::shared_ptr<Strategies> StridedSliceInfo::GenerateBatchStrategies() {
  if (GetAttrs() != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": generate batch parallel strategies failed";
  }
  split_flag_list_ = {true};

  if (!fully_fetch_flag_[0]) {
    split_flag_list_ = {false};
  }
  return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
}

Status StridedSliceInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
  return SetCostUnderStrategyBase(strategy);
}

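// Only dimensions that are fully fetched and have stride 1 are splittable when generating candidate strategies.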
std::vector<StrategyPtr> StridedSliceInfo::GenerateOpStrategies(int64_t stage_id) {
  Shape input_split(inputs_shape_[0].size(), 1);
  for (size_t i = 0; i < begin_.size(); ++i) {
    if (!fully_fetch_flag_[i] || (strides_[i] != 1)) {
      input_split[i] = 0;
    }
  }
  Shapes splittable_inputs = {input_split};

  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": generate strategies failed";
  }

  return sp_vector;
}

REGISTER(StridedSliceInfo);
}  // namespace parallel
}  // namespace mindspore