• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CORE_OPS_VIEW_VIEWSTRIDESCALCULATOR_H_
17 #define MINDSPORE_CORE_OPS_VIEW_VIEWSTRIDESCALCULATOR_H_
18 
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "ir/tensor.h"
#include "utils/hash_map.h"
#include "ir/value.h"
#include "ops/op_name.h"
#include "ir/primitive.h"
29 
30 namespace mindspore {
31 namespace ops {
32 using TensorStorageInfoPtrList = std::vector<TensorStorageInfoPtr>;
33 // unsupported will return {}
34 using StridesCalcFunc = std::function<TensorStorageInfoPtrList(const PrimitivePtr &, const std::vector<ValuePtr> &)>;
35 using StridesVecotr = std::vector<int64_t>;
36 std::vector<int64_t> GetOriStrides(const std::vector<int64_t> &shape);
37 bool IsContiguous(const ShapeVector &shape, const std::vector<int64_t> &strides);
38 int64_t DynamicDimWrap(int64_t dim, int64_t dim_post_expr);
39 bool IsDynamic(const std::vector<int64_t> &shape);
40 bool HasZero(const std::vector<int64_t> &value);
41 bool CheckInputsNull(const std::vector<ValuePtr> &inputs, const size_t &input_num);
42 
/**
 * @brief Plain record of a tensor's layout information.
 *
 * Stores the `old_*` shape/strides/offset and the `ori_*` shape/strides exactly as passed to the
 * constructor (each vector argument is taken by value and moved into place).
 * NOTE(review): `ori_*` presumably describes the tensor's original/underlying storage and `old_*`
 * its current view — confirm against GetOldTensorInfo's implementation.
 */
struct OldTensorInfo {
  OldTensorInfo(std::vector<int64_t> shape_in, std::vector<int64_t> strides_in, std::vector<int64_t> ori_shape_in,
                std::vector<int64_t> ori_strides_in, size_t offset_in)
      : old_shape{std::move(shape_in)},
        old_strides{std::move(strides_in)},
        ori_shape{std::move(ori_shape_in)},
        ori_strides{std::move(ori_strides_in)},
        old_offset{offset_in} {}

  std::vector<int64_t> old_shape;
  std::vector<int64_t> old_strides;
  std::vector<int64_t> ori_shape;
  std::vector<int64_t> ori_strides;
  size_t old_offset{0};
};
/// Shared-ownership handle to an OldTensorInfo.
using OldTensorInfoPtr = std::shared_ptr<OldTensorInfo>;
58 
59 OldTensorInfoPtr GetOldTensorInfo(const tensor::BaseTensorPtr &tensor);
60 
61 class MIND_API ViewStridesCalcFactory {
62  public:
63   static ViewStridesCalcFactory &GetInstance();
64   ViewStridesCalcFactory() = default;
65   ~ViewStridesCalcFactory() = default;
AddStridesCalcFunc(const std::string & op_name,const StridesCalcFunc & func)66   void AddStridesCalcFunc(const std::string &op_name, const StridesCalcFunc &func) {
67     strides_calc_map_[op_name] = func;
68   }
69 
AddTupleOutStridesCalcFunc(const std::string & op_name,const StridesCalcFunc & func)70   void AddTupleOutStridesCalcFunc(const std::string &op_name, const StridesCalcFunc &func) {
71     tuple_out_strides_calc_map_[op_name] = func;
72   }
73 
GetStridesCalcFunc(const std::string & op_name)74   std::pair<std::optional<StridesCalcFunc>, bool> GetStridesCalcFunc(const std::string &op_name) {
75     const auto &iter = strides_calc_map_.find(op_name);
76     if (iter != strides_calc_map_.end()) {
77       return std::make_pair(iter->second, false);
78     }
79 
80     const auto &tuple_iter = tuple_out_strides_calc_map_.find(op_name);
81     if (tuple_iter != tuple_out_strides_calc_map_.end()) {
82       return std::make_pair(tuple_iter->second, true);
83     }
84 
85     return std::make_pair(std::nullopt, false);
86   }
87 
88  private:
89   mindspore::HashMap<std::string, StridesCalcFunc> strides_calc_map_;
90   mindspore::HashMap<std::string, StridesCalcFunc> tuple_out_strides_calc_map_;
91 };
92 
93 class ViewStridesCalcRegistrar {
94  public:
95   ViewStridesCalcRegistrar(const std::string &op_name, const StridesCalcFunc &func, bool is_tuple = false) {
96     if (is_tuple) {
97       ViewStridesCalcFactory::GetInstance().AddTupleOutStridesCalcFunc(op_name, func);
98     } else {
99       ViewStridesCalcFactory::GetInstance().AddStridesCalcFunc(op_name, func);
100     }
101   }
102 
103   ~ViewStridesCalcRegistrar() = default;
104 };
105 
106 #define REG_VIEW_STRIDES_CALC_FUN(op_name, func) \
107   static ViewStridesCalcRegistrar g_##op_name##StridesCalcReg(#op_name, func);
108 
109 #define REG_TUPLE_OUT_VIEW_STRIDES_CALC_FUN(op_name, func) \
110   static ViewStridesCalcRegistrar g_##op_name##StridesCalcReg(#op_name, func, true);
111 }  // namespace ops
112 }  // namespace mindspore
#endif  // MINDSPORE_CORE_OPS_VIEW_VIEWSTRIDESCALCULATOR_H_
114