1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
#include "tools/optimizer/fusion/strided_slice_checker.h"
#include <algorithm>
#include <climits>
#include <cstdint>
#include <cstring>
#include <functional>
#include <numeric>
#include <vector>
#include "tools/optimizer/common/gllo_utils.h"
#include "ops/op_name.h"
23
24 namespace mindspore {
25 namespace opt {
CheckCommonInfo(const CNodePtr & strided_slice)26 bool StridedSliceChecker::CheckCommonInfo(const CNodePtr &strided_slice) {
27 if (strided_slice == nullptr || strided_slice->size() > kInputSizeFive) {
28 return false;
29 }
30 if (IsMarkedTrainOp(strided_slice)) {
31 return false;
32 }
33 auto prim = GetCNodePrimitive(strided_slice);
34 MS_CHECK_TRUE_RET(prim != nullptr, false);
35 if (IsQuantParameterNode(prim)) {
36 return false;
37 }
38 auto ellipsis_mask =
39 prim->GetAttr(ops::kEllipsisMask) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kEllipsisMask)) : 0;
40 auto new_axis_mask =
41 prim->GetAttr(ops::kNewAxisMask) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kNewAxisMask)) : 0;
42 auto shrink_axis_mask =
43 prim->GetAttr(ops::kShrinkAxisMask) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kShrinkAxisMask)) : 0;
44 if (ellipsis_mask != 0 || new_axis_mask != 0 || shrink_axis_mask != 0) {
45 return false;
46 }
47
48 if (!CheckStepIsOne(strided_slice)) {
49 return false;
50 }
51 return true;
52 }
53
GetBegin(const CNodePtr & strided_slice,std::vector<int> * begin)54 int StridedSliceChecker::GetBegin(const CNodePtr &strided_slice, std::vector<int> *begin) {
55 if (strided_slice == nullptr || begin == nullptr) {
56 MS_LOG(ERROR) << "exist in-parameter is a nullptr.";
57 return lite::RET_NULL_PTR;
58 }
59 auto prim = GetCNodePrimitive(strided_slice);
60 MS_CHECK_TRUE_MSG(prim != nullptr, false, "Strided_slice's prim is a nullptr.");
61 auto begin_mask = prim->GetAttr(ops::kBeginMask) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kBeginMask)) : 0;
62 lite::DataInfo data;
63 auto ret = GetConstTensor(strided_slice, ops::kInputIndex2, &data);
64 if (ret == lite::RET_NOT_SUPPORT) {
65 return ret;
66 }
67 if (ret != lite::RET_OK) {
68 MS_LOG(ERROR) << "Get Strided_slice's begin failed, node name is " << strided_slice->fullname_with_scope();
69 return lite::RET_ERROR;
70 }
71 auto num = std::accumulate(data.shape_.begin(), data.shape_.end(), 1, std::multiplies<>());
72 for (int i = 0; i < num; ++i) {
73 bool begin_ineffective = (begin_mask & (1 << i));
74 int cur_begin = begin_ineffective ? 0 : static_cast<int *>(data.data_ptr_)[i];
75 begin->push_back(cur_begin);
76 }
77 return lite::RET_OK;
78 }
79
GetEnd(const CNodePtr & strided_slice,std::vector<int> * end)80 int StridedSliceChecker::GetEnd(const CNodePtr &strided_slice, std::vector<int> *end) {
81 if (strided_slice == nullptr || end == nullptr) {
82 MS_LOG(ERROR) << "exist in-parameter is a nullptr.";
83 return lite::RET_NULL_PTR;
84 }
85 auto prim = GetCNodePrimitive(strided_slice);
86 MS_CHECK_TRUE_MSG(prim != nullptr, false, "Strided_slice's prim is a nullptr.");
87 auto end_mask = prim->GetAttr(ops::kEndMask) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kEndMask)) : 0;
88 lite::DataInfo data;
89 auto ret = GetConstTensor(strided_slice, ops::kInputIndex3, &data);
90 if (ret == lite::RET_NOT_SUPPORT) {
91 return ret;
92 }
93 if (ret != lite::RET_OK) {
94 MS_LOG(ERROR) << "Get Strided_slice's end failed, node name is " << strided_slice->fullname_with_scope();
95 return lite::RET_ERROR;
96 }
97 auto num = std::accumulate(data.shape_.begin(), data.shape_.end(), 1, std::multiplies<>());
98 for (int i = 0; i < num; ++i) {
99 bool end_ineffective = (end_mask & (1 << i));
100 int cur_end = end_ineffective ? INT_MAX : static_cast<int *>(data.data_ptr_)[i];
101 end->push_back(cur_end);
102 }
103 return lite::RET_OK;
104 }
105
CheckStepIsOne(const CNodePtr & strided_slice)106 bool StridedSliceChecker::CheckStepIsOne(const CNodePtr &strided_slice) {
107 if (strided_slice == nullptr) {
108 return false;
109 }
110 if (strided_slice->size() < kInputSizeFive) {
111 return true;
112 }
113 lite::DataInfo data;
114 auto status = GetConstTensor(strided_slice, ops::kInputIndex4, &data);
115 if (status != lite::RET_OK) {
116 return false;
117 }
118 auto num = std::accumulate(data.shape_.begin(), data.shape_.end(), 1, std::multiplies<>());
119 std::vector<int> temp(num, 1);
120 return memcmp(data.data_ptr_, temp.data(), temp.size() * sizeof(int)) == 0;
121 }
122
GetConstTensor(const CNodePtr & strided_slice,size_t index,lite::DataInfo * data_info)123 int StridedSliceChecker::GetConstTensor(const CNodePtr &strided_slice, size_t index, lite::DataInfo *data_info) {
124 if (strided_slice == nullptr || data_info == nullptr) {
125 MS_LOG(ERROR) << "exist in-parameter is a nullptr.";
126 return lite::RET_NULL_PTR;
127 }
128 if (index >= strided_slice->size() || strided_slice->input(index) == nullptr) {
129 MS_LOG(ERROR) << "Strided_slice input is invalid, node is " << strided_slice->fullname_with_scope();
130 return lite::RET_ERROR;
131 }
132 if (utils::isa<CNode>(strided_slice->input(index))) {
133 MS_LOG(DEBUG) << "Strided_slice " << index << " input is not a constant, node is "
134 << strided_slice->fullname_with_scope();
135 return lite::RET_NOT_SUPPORT;
136 }
137 if (lite::FetchConstData(strided_slice, index, converter::kFmkTypeMs, data_info, true) != lite::RET_OK) {
138 MS_LOG(ERROR) << "Get Strided_slice " << index << "-input failed, node is " << strided_slice->fullname_with_scope();
139 return lite::RET_ERROR;
140 }
141 data_info->data_ptr_ = data_info->data_.data();
142 if (data_info->data_ptr_ == nullptr ||
143 (data_info->data_type_ != kNumberTypeInt && data_info->data_type_ != kNumberTypeInt32)) {
144 MS_LOG(ERROR) << "Get Strided_slice's constant failed, node name is " << strided_slice->fullname_with_scope();
145 return lite::RET_ERROR;
146 }
147 return lite::RET_OK;
148 }
149 } // namespace opt
150 } // namespace mindspore
151