1 /** 2 * Copyright 2020-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_STRIDED_SLICE_INFO_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_STRIDED_SLICE_INFO_H_ 19 20 #include <string> 21 22 #include <memory> 23 #include <vector> 24 25 #include "utils/hash_map.h" 26 #include "ir/value.h" 27 #include "frontend/parallel/auto_parallel/operator_costmodel.h" 28 #include "frontend/parallel/ops_info/operator_info.h" 29 #include "frontend/parallel/strategy.h" 30 31 namespace mindspore { 32 namespace parallel { 33 constexpr size_t STRIDE_SLICE_CNODE_BEGIN_INDEX = 2; 34 constexpr size_t STRIDE_SLICE_CNODE_END_INDEX = 3; 35 class StridedSliceInfo : public OperatorInfo { 36 public: StridedSliceInfo(const std::string & operator_name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)37 StridedSliceInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, 38 const PrimitiveAttrs &attrs) 39 : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<StridedSliceCost>()) {} 40 ~StridedSliceInfo() override = default; 41 42 std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; 43 Status SetCostUnderStrategy(const StrategyPtr &strategy) override; 44 std::shared_ptr<Strategies> GenerateBatchStrategies() override; 45 ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; begin()46 std::vector<int64_t> begin() const { return begin_; } strides()47 std::vector<int64_t> strides() const { return strides_; } new_axis_mask_bitmap()48 std::vector<bool> new_axis_mask_bitmap() const { return new_axis_mask_bitmap_; } fully_fetch_flag()49 std::vector<bool> fully_fetch_flag() const { return fully_fetch_flag_; } skip_redistribution()50 bool skip_redistribution() const { return skip_redistribution_; } 51 Status GetAttrs() override; 52 53 protected: 54 Status CheckStrategy(const StrategyPtr &strategy) override; 55 Status InferMirrorOps() override; InferForwardCommunication()56 Status InferForwardCommunication() override { return SUCCESS; } 57 Status InferDevMatrixShape() override; 58 Status InferTensorMap() override; 59 Status GetMask(const std::string &mask_name, int64_t *mask_value); 60 void ChangeCNodeBegin(); 61 void ChangeCNodeEnd(); 62 63 private: 64 std::vector<int64_t> begin_; 65 std::vector<int64_t> end_; 66 std::vector<int64_t> strides_; 67 int64_t begin_mask_ = 0; 68 int64_t end_mask_ = 0; 69 int64_t ellipsis_mask_ = 0; 70 int64_t new_axis_mask_ = 0; 71 int64_t shrink_axis_mask_ = 0; 72 bool has_mask_ = false; 73 std::vector<bool> fully_fetch_flag_; 74 bool skip_redistribution_ = false; 75 std::vector<bool> begin_mask_bitmap_; 76 std::vector<bool> end_mask_bitmap_; 77 std::vector<bool> ellipsis_mask_bitmap_; 78 std::vector<bool> new_axis_mask_bitmap_; 79 std::vector<bool> shrink_axis_mask_bitmap_; 80 Shape input_shape_in_process_; 81 void ComputeBeginMask(); 82 void ComputeEndMask(); 83 void ComputeEllipsisMask(); 84 void ComputeNewAxisMask(); 85 void ComputeFullyFetchFlag(); 86 void AdjustShrinkAxisMask(); 87 Status CheckInputStrategy(const Shape &strategy); 88 }; 89 90 using StridedSliceInfoPtr = std::shared_ptr<StridedSliceInfo>; 91 } // namespace parallel 92 } // namespace mindspore 93 94 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_STRIDED_SLICE_INFO_H_ 95