1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "ops/view/view_strides_calculator.h"
17
18 namespace mindspore::ops {
GetInstance()19 ViewStridesCalcFactory &ViewStridesCalcFactory::GetInstance() {
20 static ViewStridesCalcFactory instance;
21 return instance;
22 }
23
// Returns true when any dimension of `shape` is negative (i.e. not statically known),
// false otherwise. An empty shape is never dynamic.
bool IsDynamic(const std::vector<int64_t> &shape) {
  const auto first_negative =
    std::find_if(shape.cbegin(), shape.cend(), [](int64_t dim_size) { return dim_size < 0; });
  return first_negative != shape.cend();
}
27
// Returns true if any element of `value` equals 0, false otherwise (including for an
// empty vector). Written with std::any_of to match IsDynamic instead of a manual index loop.
bool HasZero(const std::vector<int64_t> &value) {
  return std::any_of(value.cbegin(), value.cend(), [](int64_t element) { return element == 0; });
}
36
CheckInputsNull(const std::vector<ValuePtr> & inputs,const size_t & input_num)37 bool CheckInputsNull(const std::vector<ValuePtr> &inputs, const size_t &input_num) {
38 if (inputs.size() != input_num) {
39 MS_LOG(DEBUG) << "inputs.size() is not equal to input_num, inputs.size():" << inputs.size()
40 << " input_num:" << input_num;
41 return true;
42 }
43
44 return std::any_of(inputs.cbegin(), inputs.cend(), [](const ValuePtr &v) { return v == nullptr; });
45 }
46
// Computes row-major (C-order) strides for `shape`: the last dimension gets stride 1 and
// every other dimension gets the product of all dimension sizes to its right.
// An empty shape yields an empty stride vector; shape[0] never contributes to any stride.
std::vector<int64_t> GetOriStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> result(shape.size(), 1);
  int64_t suffix_product = 1;
  // `pos` walks from shape.size() down to 2, so shape[pos - 1] covers dims size-1 .. 1
  // and result[pos - 2] covers dims size-2 .. 0; nothing runs for ranks 0 and 1.
  for (size_t pos = shape.size(); pos > 1; --pos) {
    suffix_product *= shape[pos - 1];
    result[pos - 2] = suffix_product;
  }
  return result;
}
60
IsContiguous(const ShapeVector & shape,const std::vector<int64_t> & strides)61 bool IsContiguous(const ShapeVector &shape, const std::vector<int64_t> &strides) {
62 if (shape.size() == 0) {
63 return true;
64 }
65 if (shape.size() != strides.size()) {
66 MS_LOG(EXCEPTION) << "shape.size() != strides.size()";
67 }
68
69 int64_t z = 1;
70 for (int64_t i = SizeToLong(shape.size() - 1); i >= 0; --i) {
71 const auto &shape_i = shape[i];
72 if (shape_i != 1) {
73 if (strides[i] == z) {
74 z *= shape_i;
75 } else {
76 return false;
77 }
78 }
79 }
80
81 return true;
82 }
83
DynamicDimWrap(int64_t dim,int64_t dim_post_expr)84 int64_t DynamicDimWrap(int64_t dim, int64_t dim_post_expr) {
85 if (dim_post_expr * -1 <= dim && dim < dim_post_expr) {
86 if (dim < 0) {
87 return dim + dim_post_expr;
88 }
89 return dim;
90 }
91 MS_EXCEPTION(ValueError) << "dim value error. dim:" << dim << ", dim value should be in [" << -dim_post_expr << ", "
92 << dim_post_expr << ").";
93 }
94
GetOldTensorInfo(const tensor::BaseTensorPtr & tensor)95 OldTensorInfoPtr GetOldTensorInfo(const tensor::BaseTensorPtr &tensor) {
96 if (tensor->storage_info() == nullptr) {
97 auto old_strides = GetOriStrides(tensor->shape());
98 return std::make_shared<OldTensorInfo>(tensor->shape(), old_strides, tensor->shape(), old_strides, 0);
99 } else {
100 auto storage_info = tensor->storage_info();
101 return std::make_shared<OldTensorInfo>(storage_info->shape, storage_info->strides, storage_info->ori_shape,
102 storage_info->ori_strides, storage_info->storage_offset);
103 }
104 }
105 } // namespace mindspore::ops
106