• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
#include "tools/optimizer/fusion/strided_slice_checker.h"
#include <algorithm>
#include <climits>
#include <cstdint>
#include <cstring>
#include <functional>
#include <numeric>
#include <vector>
#include "tools/optimizer/common/gllo_utils.h"
#include "ops/op_name.h"
23 
24 namespace mindspore {
25 namespace opt {
CheckCommonInfo(const CNodePtr & strided_slice)26 bool StridedSliceChecker::CheckCommonInfo(const CNodePtr &strided_slice) {
27   if (strided_slice == nullptr || strided_slice->size() > kInputSizeFive) {
28     return false;
29   }
30   if (IsMarkedTrainOp(strided_slice)) {
31     return false;
32   }
33   auto prim = GetCNodePrimitive(strided_slice);
34   MS_CHECK_TRUE_RET(prim != nullptr, false);
35   if (IsQuantParameterNode(prim)) {
36     return false;
37   }
38   auto ellipsis_mask =
39     prim->GetAttr(ops::kEllipsisMask) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kEllipsisMask)) : 0;
40   auto new_axis_mask =
41     prim->GetAttr(ops::kNewAxisMask) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kNewAxisMask)) : 0;
42   auto shrink_axis_mask =
43     prim->GetAttr(ops::kShrinkAxisMask) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kShrinkAxisMask)) : 0;
44   if (ellipsis_mask != 0 || new_axis_mask != 0 || shrink_axis_mask != 0) {
45     return false;
46   }
47 
48   if (!CheckStepIsOne(strided_slice)) {
49     return false;
50   }
51   return true;
52 }
53 
GetBegin(const CNodePtr & strided_slice,std::vector<int> * begin)54 int StridedSliceChecker::GetBegin(const CNodePtr &strided_slice, std::vector<int> *begin) {
55   if (strided_slice == nullptr || begin == nullptr) {
56     MS_LOG(ERROR) << "exist in-parameter is a nullptr.";
57     return lite::RET_NULL_PTR;
58   }
59   auto prim = GetCNodePrimitive(strided_slice);
60   MS_CHECK_TRUE_MSG(prim != nullptr, false, "Strided_slice's prim is a nullptr.");
61   auto begin_mask = prim->GetAttr(ops::kBeginMask) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kBeginMask)) : 0;
62   lite::DataInfo data;
63   auto ret = GetConstTensor(strided_slice, ops::kInputIndex2, &data);
64   if (ret == lite::RET_NOT_SUPPORT) {
65     return ret;
66   }
67   if (ret != lite::RET_OK) {
68     MS_LOG(ERROR) << "Get Strided_slice's begin failed, node name is " << strided_slice->fullname_with_scope();
69     return lite::RET_ERROR;
70   }
71   auto num = std::accumulate(data.shape_.begin(), data.shape_.end(), 1, std::multiplies<>());
72   for (int i = 0; i < num; ++i) {
73     bool begin_ineffective = (begin_mask & (1 << i));
74     int cur_begin = begin_ineffective ? 0 : static_cast<int *>(data.data_ptr_)[i];
75     begin->push_back(cur_begin);
76   }
77   return lite::RET_OK;
78 }
79 
GetEnd(const CNodePtr & strided_slice,std::vector<int> * end)80 int StridedSliceChecker::GetEnd(const CNodePtr &strided_slice, std::vector<int> *end) {
81   if (strided_slice == nullptr || end == nullptr) {
82     MS_LOG(ERROR) << "exist in-parameter is a nullptr.";
83     return lite::RET_NULL_PTR;
84   }
85   auto prim = GetCNodePrimitive(strided_slice);
86   MS_CHECK_TRUE_MSG(prim != nullptr, false, "Strided_slice's prim is a nullptr.");
87   auto end_mask = prim->GetAttr(ops::kEndMask) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kEndMask)) : 0;
88   lite::DataInfo data;
89   auto ret = GetConstTensor(strided_slice, ops::kInputIndex3, &data);
90   if (ret == lite::RET_NOT_SUPPORT) {
91     return ret;
92   }
93   if (ret != lite::RET_OK) {
94     MS_LOG(ERROR) << "Get Strided_slice's end failed, node name is " << strided_slice->fullname_with_scope();
95     return lite::RET_ERROR;
96   }
97   auto num = std::accumulate(data.shape_.begin(), data.shape_.end(), 1, std::multiplies<>());
98   for (int i = 0; i < num; ++i) {
99     bool end_ineffective = (end_mask & (1 << i));
100     int cur_end = end_ineffective ? INT_MAX : static_cast<int *>(data.data_ptr_)[i];
101     end->push_back(cur_end);
102   }
103   return lite::RET_OK;
104 }
105 
CheckStepIsOne(const CNodePtr & strided_slice)106 bool StridedSliceChecker::CheckStepIsOne(const CNodePtr &strided_slice) {
107   if (strided_slice == nullptr) {
108     return false;
109   }
110   if (strided_slice->size() < kInputSizeFive) {
111     return true;
112   }
113   lite::DataInfo data;
114   auto status = GetConstTensor(strided_slice, ops::kInputIndex4, &data);
115   if (status != lite::RET_OK) {
116     return false;
117   }
118   auto num = std::accumulate(data.shape_.begin(), data.shape_.end(), 1, std::multiplies<>());
119   std::vector<int> temp(num, 1);
120   return memcmp(data.data_ptr_, temp.data(), temp.size() * sizeof(int)) == 0;
121 }
122 
GetConstTensor(const CNodePtr & strided_slice,size_t index,lite::DataInfo * data_info)123 int StridedSliceChecker::GetConstTensor(const CNodePtr &strided_slice, size_t index, lite::DataInfo *data_info) {
124   if (strided_slice == nullptr || data_info == nullptr) {
125     MS_LOG(ERROR) << "exist in-parameter is a nullptr.";
126     return lite::RET_NULL_PTR;
127   }
128   if (index >= strided_slice->size() || strided_slice->input(index) == nullptr) {
129     MS_LOG(ERROR) << "Strided_slice input is invalid, node is " << strided_slice->fullname_with_scope();
130     return lite::RET_ERROR;
131   }
132   if (utils::isa<CNode>(strided_slice->input(index))) {
133     MS_LOG(DEBUG) << "Strided_slice " << index << " input is not a constant, node is "
134                   << strided_slice->fullname_with_scope();
135     return lite::RET_NOT_SUPPORT;
136   }
137   if (lite::FetchConstData(strided_slice, index, converter::kFmkTypeMs, data_info, true) != lite::RET_OK) {
138     MS_LOG(ERROR) << "Get Strided_slice " << index << "-input failed, node is " << strided_slice->fullname_with_scope();
139     return lite::RET_ERROR;
140   }
141   data_info->data_ptr_ = data_info->data_.data();
142   if (data_info->data_ptr_ == nullptr ||
143       (data_info->data_type_ != kNumberTypeInt && data_info->data_type_ != kNumberTypeInt32)) {
144     MS_LOG(ERROR) << "Get Strided_slice's constant failed, node name is " << strided_slice->fullname_with_scope();
145     return lite::RET_ERROR;
146   }
147   return lite::RET_OK;
148 }
149 }  // namespace opt
150 }  // namespace mindspore
151