1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CORE_OPS_VIEW_VIEWSTRIDESCALCULATOR_H_ 17 #define MINDSPORE_CORE_OPS_VIEW_VIEWSTRIDESCALCULATOR_H_ 18 19 #include <string> 20 #include <memory> 21 #include <vector> 22 #include <optional> 23 #include <utility> 24 #include "ir/tensor.h" 25 #include "utils/hash_map.h" 26 #include "ir/value.h" 27 #include "ops/op_name.h" 28 #include "ir/primitive.h" 29 30 namespace mindspore { 31 namespace ops { 32 using TensorStorageInfoPtrList = std::vector<TensorStorageInfoPtr>; 33 // unsupported will return {} 34 using StridesCalcFunc = std::function<TensorStorageInfoPtrList(const PrimitivePtr &, const std::vector<ValuePtr> &)>; 35 using StridesVecotr = std::vector<int64_t>; 36 std::vector<int64_t> GetOriStrides(const std::vector<int64_t> &shape); 37 bool IsContiguous(const ShapeVector &shape, const std::vector<int64_t> &strides); 38 int64_t DynamicDimWrap(int64_t dim, int64_t dim_post_expr); 39 bool IsDynamic(const std::vector<int64_t> &shape); 40 bool HasZero(const std::vector<int64_t> &value); 41 bool CheckInputsNull(const std::vector<ValuePtr> &inputs, const size_t &input_num); 42 43 struct OldTensorInfo { OldTensorInfoOldTensorInfo44 OldTensorInfo(std::vector<int64_t> old_shape, std::vector<int64_t> old_strides, std::vector<int64_t> ori_shape, 45 std::vector<int64_t> ori_strides, size_t old_offset) 46 : old_shape(std::move(old_shape)), 47 old_strides(std::move(old_strides)), 48 ori_shape(std::move(ori_shape)), 49 ori_strides(std::move(ori_strides)), 50 old_offset(old_offset) {} 51 std::vector<int64_t> old_shape; 52 std::vector<int64_t> old_strides; 53 std::vector<int64_t> ori_shape; 54 std::vector<int64_t> ori_strides; 55 size_t old_offset; 56 }; 57 using OldTensorInfoPtr = std::shared_ptr<OldTensorInfo>; 58 59 OldTensorInfoPtr GetOldTensorInfo(const tensor::BaseTensorPtr &tensor); 60 61 class MIND_API ViewStridesCalcFactory { 62 public: 63 static ViewStridesCalcFactory &GetInstance(); 64 ViewStridesCalcFactory() = default; 65 ~ViewStridesCalcFactory() = default; AddStridesCalcFunc(const std::string & op_name,const StridesCalcFunc & func)66 void AddStridesCalcFunc(const std::string &op_name, const StridesCalcFunc &func) { 67 strides_calc_map_[op_name] = func; 68 } 69 AddTupleOutStridesCalcFunc(const std::string & op_name,const StridesCalcFunc & func)70 void AddTupleOutStridesCalcFunc(const std::string &op_name, const StridesCalcFunc &func) { 71 tuple_out_strides_calc_map_[op_name] = func; 72 } 73 GetStridesCalcFunc(const std::string & op_name)74 std::pair<std::optional<StridesCalcFunc>, bool> GetStridesCalcFunc(const std::string &op_name) { 75 const auto &iter = strides_calc_map_.find(op_name); 76 if (iter != strides_calc_map_.end()) { 77 return std::make_pair(iter->second, false); 78 } 79 80 const auto &tuple_iter = tuple_out_strides_calc_map_.find(op_name); 81 if (tuple_iter != tuple_out_strides_calc_map_.end()) { 82 return std::make_pair(tuple_iter->second, true); 83 } 84 85 return std::make_pair(std::nullopt, false); 86 } 87 88 private: 89 mindspore::HashMap<std::string, StridesCalcFunc> strides_calc_map_; 90 mindspore::HashMap<std::string, StridesCalcFunc> tuple_out_strides_calc_map_; 91 }; 92 93 class ViewStridesCalcRegistrar { 94 public: 95 ViewStridesCalcRegistrar(const std::string &op_name, const StridesCalcFunc &func, bool is_tuple = false) { 96 if (is_tuple) { 97 ViewStridesCalcFactory::GetInstance().AddTupleOutStridesCalcFunc(op_name, func); 98 } else { 99 ViewStridesCalcFactory::GetInstance().AddStridesCalcFunc(op_name, func); 100 } 101 } 102 103 ~ViewStridesCalcRegistrar() = default; 104 }; 105 106 #define REG_VIEW_STRIDES_CALC_FUN(op_name, func) \ 107 static ViewStridesCalcRegistrar g_##op_name##StridesCalcReg(#op_name, func); 108 109 #define REG_TUPLE_OUT_VIEW_STRIDES_CALC_FUN(op_name, func) \ 110 static ViewStridesCalcRegistrar g_##op_name##StridesCalcReg(#op_name, func, true); 111 } // namespace ops 112 } // namespace mindspore 113 #endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_VIEWSTRIDESCALCULATOR_H_ 114