/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ops/view/view_strides_calculator.h"

namespace mindspore::ops {
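// Returns the lazily constructed, process-wide singleton instance of the
// strides-calculator factory for view operations.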
ViewStridesCalcFactory &ViewStridesCalcFactory::GetInstance() {
  static ViewStridesCalcFactory instance;
  return instance;
}
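
// A shape is dynamic when any of its dimensions is negative, i.e. the
// concrete value of that dimension is not yet known.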
bool IsDynamic(const std::vector<int64_t> &shape) {
  return std::any_of(shape.begin(), shape.end(), [](int64_t value) { return value < 0; });
}
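
// Returns true if any element of `value` is zero, e.g. a shape that contains
// an empty dimension.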
bool HasZero(const std::vector<int64_t> &value) {
  for (size_t i = 0; i < value.size(); ++i) {
    if (value[i] == 0) {
      return true;
    }
  }
  return false;
}
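
// Returns true when the inputs are unusable: either the number of inputs does
// not match the expected `input_num`, or at least one input is a null ValuePtr.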
bool CheckInputsNull(const std::vector<ValuePtr> &inputs, const size_t &input_num) {
  if (inputs.size() != input_num) {
    MS_LOG(DEBUG) << "inputs.size() is not equal to input_num, inputs.size():" << inputs.size()
                  << " input_num:" << input_num;
    return true;
  }

  return std::any_of(inputs.cbegin(), inputs.cend(), [](const ValuePtr &v) { return v == nullptr; });
}
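
// Computes the strides of a contiguous (row-major) tensor with the given
// shape: each stride is the product of all dimensions to its right, so
// shape {2, 3, 4} yields strides {12, 4, 1}.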
std::vector<int64_t> GetOriStrides(const std::vector<int64_t> &shape) {
  if (shape.empty()) {
    return {};
  }

  std::vector<int64_t> ret(shape.size(), 1);
  int64_t strides = 1;
  for (size_t i = shape.size() - 1; i > 0; --i) {
    strides *= shape[i];
    ret[i - 1] = strides;
  }
  return ret;
}
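
// A tensor is contiguous when, scanning from the innermost dimension outwards,
// every dimension of size != 1 has exactly the stride a row-major layout would
// assign it; size-1 dimensions are skipped because their strides never affect
// element addressing.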
bool IsContiguous(const ShapeVector &shape, const std::vector<int64_t> &strides) {
  if (shape.empty()) {
    return true;
  }
  if (shape.size() != strides.size()) {
    MS_LOG(EXCEPTION) << "shape.size() [" << shape.size() << "] != strides.size() [" << strides.size() << "]";
  }

  int64_t z = 1;
  for (int64_t i = SizeToLong(shape.size() - 1); i >= 0; --i) {
    const auto &shape_i = shape[i];
    if (shape_i != 1) {
      if (strides[i] == z) {
        z *= shape_i;
      } else {
        return false;
      }
    }
  }

  return true;
}
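
// Normalizes a possibly negative dimension index into [0, dim_post_expr),
// e.g. dim = -1 with dim_post_expr = 4 wraps to 3. Raises ValueError when
// dim lies outside [-dim_post_expr, dim_post_expr).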
int64_t DynamicDimWrap(int64_t dim, int64_t dim_post_expr) {
  if (-dim_post_expr <= dim && dim < dim_post_expr) {
    if (dim < 0) {
      return dim + dim_post_expr;
    }
    return dim;
  }
  MS_EXCEPTION(ValueError) << "dim value error. dim:" << dim << ", dim value should be in [" << -dim_post_expr << ", "
                           << dim_post_expr << ").";
}
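
// Snapshots the view-related metadata of `tensor`. When the tensor carries no
// storage_info (it is not a view), contiguous strides are derived from its
// shape and the storage offset is 0; otherwise the fields recorded in its
// storage_info are copied through unchanged.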
OldTensorInfoPtr GetOldTensorInfo(const tensor::BaseTensorPtr &tensor) {
  if (tensor->storage_info() == nullptr) {
    auto old_strides = GetOriStrides(tensor->shape());
    return std::make_shared<OldTensorInfo>(tensor->shape(), old_strides, tensor->shape(), old_strides, 0);
  } else {
    auto storage_info = tensor->storage_info();
    return std::make_shared<OldTensorInfo>(storage_info->shape, storage_info->strides, storage_info->ori_shape,
                                           storage_info->ori_strides, storage_info->storage_offset);
  }
}
}  // namespace mindspore::ops