/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/ascend/optimizer/mindir/reduce_axis_update.h"
#include <vector>
#include <memory>
#include <set>
#include <string>
#include "mindspore/core/ops/math_ops.h"
#include "include/common/utils/anfalgo.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kReduceInputNum = 2;
constexpr size_t kXInputIndex = 1;
constexpr size_t kAxisInputIndex = 2;
constexpr auto r_reduce = "r_reduce";
constexpr auto m_reduce = "m_reduce";
constexpr auto kXs = "Xs";
constexpr auto kV = "V";
constexpr auto v_axis = "axis";

tensor::TensorPtr CreateTensor(const std::vector<int64_t> &values) {
  auto type_ptr = kInt64;
  auto data_length = sizeof(int64_t);
  std::vector<int64_t> tensor_shape = {SizeToLong(values.size())};
  MS_EXCEPTION_IF_NULL(type_ptr);
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_ptr->type_id(), tensor_shape);
  MS_EXCEPTION_IF_NULL(tensor);
  tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr};
  tensor->set_device_info(device_info);
  auto data_ptr = tensor->data_c();
  MS_EXCEPTION_IF_NULL(data_ptr);
  auto buffer_size = values.size() * data_length;
  if (buffer_size != 0) {
    auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), values.data(), buffer_size);
    if (ret_code != EOK) {
      MS_LOG(EXCEPTION) << "Failed to copy data into tensor, memcpy_s error code: " << ret_code;
    }
  }
  return tensor;
}
}  // namespace

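// Returns true if the node is one of the reduce primitives handled by this pass.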
bool ReduceAxisUpdate::IsReduce(const BaseRef &ref) {
  if (utils::isa<AnfNodePtr>(ref)) {
    AnfNodePtr node = utils::cast<AnfNodePtr>(ref);
    MS_EXCEPTION_IF_NULL(node);
    if (IsPrimitive(node, prim::kPrimReduceMin) || IsPrimitive(node, prim::kPrimReduceMax) ||
        IsPrimitive(node, prim::kPrimReduceMean) || IsPrimitive(node, prim::kPrimReduceSum) ||
        IsPrimitive(node, prim::kPrimReduceProd) || IsPrimitive(node, prim::kPrimReduceAll) ||
        IsPrimitive(node, prim::kPrimReduceAny) || IsPrimitive(node, prim::kPrimMeanExt) ||
        IsPrimitive(node, prim::kPrimSumExt) || IsPrimitive(node, prim::kPrimProdExt)) {
      return true;
    }
  }

  return false;
}

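// Returns true if the constant axis is an empty tuple, list, or tensor.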
bool ReduceAxisUpdate::IsAxisEmpty(const ValueNodePtr &axis_node) const {
  MS_EXCEPTION_IF_NULL(axis_node);
  const ValuePtr &value = axis_node->value();
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<ValueTuple>()) {
    auto tuple = value->cast<ValueTuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple);
    return tuple->size() == 0;
  } else if (value->isa<ValueList>()) {
    auto list = value->cast<ValueListPtr>();
    MS_EXCEPTION_IF_NULL(list);
    return list->size() == 0;
  } else if (value->isa<tensor::Tensor>()) {
    auto tensor = value->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    return tensor->DataSize() == 0;
  }

  return false;
}

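// Returns true if the constant axis is None and the reduce op supports a None axis.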
bool ReduceAxisUpdate::IsAxisNone(const AnfNodePtr &cnode, const ValueNodePtr &axis_node) const {
  static std::set<std::string> op_name_support_none = {prim::kPrimMeanExt->name(), prim::kPrimSumExt->name(),
                                                       prim::kPrimProdExt->name(), prim::kPrimReduceAll->name()};
  auto cnode_name = common::AnfAlgo::GetCNodeName(cnode);
  if (op_name_support_none.find(cnode_name) == op_name_support_none.end()) {
    return false;
  }
  return axis_node->value()->isa<None>();
}

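// Returns true if the data input is a 0-dimensional (scalar) tensor.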
bool ReduceAxisUpdate::IsInputScalar(const AnfNodePtr &x_node) const {
  MS_EXCEPTION_IF_NULL(x_node);
  auto x_shape_ptr = x_node->Shape();
  MS_EXCEPTION_IF_NULL(x_shape_ptr);
  ShapeVector x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
  return x_shape.empty();
}

namespace {
constexpr size_t kAxisIndex = 2;
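// Returns true if the axis input of the given cnode is a constant empty sequence or an empty tensor.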
bool IsAxisEmptySequence(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  const auto &cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->size() <= kAxisIndex) {
    return false;
  }
  const auto &axis_node = cnode->input(kAxisIndex);

  if (axis_node == nullptr || (!axis_node->isa<ValueNode>())) {
    return false;
  }
  const auto &value_node = axis_node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  const auto &value = value_node->value();
  if (value == nullptr) {
    return false;
  }
  if (value->isa<ValueSequence>()) {
    const auto &value_sequence = value->cast<ValueSequencePtr>();
    MS_EXCEPTION_IF_NULL(value_sequence);
    return value_sequence->size() == 0;
  } else if (value->isa<tensor::Tensor>()) {
    const auto &tensor = value->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    const auto &shapes = tensor->shape();
    return shapes.size() == 1 && shapes[0] == 0;
  }
  return false;
}
}  // namespace

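// Checks whether the matched reduce node should be rewritten: the input shape must be fully known, and the
// constant axis must be empty or None (or the data input must be a scalar).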
bool ReduceAxisUpdate::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &graph, const AnfNodePtr &node) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_LOG(INFO) << "Reduce node is " << node->DebugString() << ".";

  // In control flow, an empty axis tuple may be marked as dynamic length and thus be treated as a dynamic shape,
  // even though no real dynamic shape is involved, so the empty-tuple case is excluded here.
  if (common::AnfAlgo::IsNodeOutputDynamicShape(node) && (!IsAxisEmptySequence(node))) {
    MS_LOG(INFO) << "The dimension of " << node->DebugString() << " is unknown.";
    return false;
  }

  // If the input is dynamic rank, expanding the axis would produce a wrong result.
  if (IsDynamicRank(common::AnfAlgo::GetPrevNodeOutputInferShape(node, 0))) {
    MS_LOG(INFO) << "The input rank of " << node->DebugString() << " is unknown.";
    return false;
  }

  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  auto input_num = common::AnfAlgo::GetInputNum(cnode);
  if (input_num < kReduceInputNum) {
    MS_LOG(EXCEPTION) << "The input tensor size[" << input_num << "] of node ["
                      << cnode->DebugString() + "] is less than " << kReduceInputNum
                      << trace::DumpSourceLines(cnode);
  }

  const auto &inputs = cnode->inputs();
  const AnfNodePtr &input_x = inputs.at(kXInputIndex);
  const AnfNodePtr &input_axis = inputs.at(kAxisInputIndex);
  MS_EXCEPTION_IF_NULL(input_x);
  MS_EXCEPTION_IF_NULL(input_axis);
  MS_LOG(INFO) << "X input is " << input_x->DebugString() << ".";
  MS_LOG(INFO) << "Axis input is " << input_axis->DebugString() << ".";

  auto axis_value_node = input_axis->cast<ValueNodePtr>();
  if (axis_value_node == nullptr ||
      (!(IsAxisEmpty(axis_value_node) || IsAxisNone(cnode, axis_value_node)) && !IsInputScalar(input_x))) {
    MS_LOG(INFO) << "Axis input of node " << node->fullname_with_scope()
                 << " is not a value node, or the axis is neither empty nor None and the input is not a scalar.";
    return false;
  } else {
    MS_LOG(INFO) << "Axis of node " << node->fullname_with_scope() << " is empty or None, or the input is a scalar.";
  }
  return true;
}

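// Builds a value node holding the full axis list [0, ..., rank - 1] of the reduce's data input.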
AnfNodePtr BuildAxis(const PatternMap &m) {
  auto node = m.Get(m_reduce);
  MS_EXCEPTION_IF_NULL(node);
  ShapeVector x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
  size_t x_dim_len = x_shape.size();
  MS_LOG(INFO) << "Input x dim len: " << x_dim_len;
  std::vector<int64_t> axis = {0};
  for (size_t i = 1; i < x_dim_len; ++i) {
    (void)axis.emplace_back(SizeToLong(i));
    MS_LOG(INFO) << "x dim: " << x_shape[i];
  }
  ValuePtr new_value = MakeValue(CreateTensor(axis));
  MS_EXCEPTION_IF_NULL(new_value);
  auto new_axis_node = std::make_shared<ValueNode>(new_value);
  MS_EXCEPTION_IF_NULL(new_axis_node);
  new_axis_node->set_abstract(new_value->ToAbstract());

  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  new_axis_node->set_kernel_info(kernel_info);
  std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
    std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(builder);
  kernel_info->set_select_kernel_build_info(builder->Build());
  MS_EXCEPTION_IF_NULL(kernel_info->select_kernel_build_info());
  kernel_info->GetMutableSelectKernelBuildInfo()->SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
  kernel_info->GetMutableSelectKernelBuildInfo()->SetOutputsFormat({kOpFormat_DEFAULT});
  kernel_info->GetMutableSelectKernelBuildInfo()->SetOutputsDeviceType({TypeId::kNumberTypeInt64});
  auto func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  kernel_graph->AddValueNodeToGraph(new_axis_node);
  return new_axis_node;
}

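// Rewrites the matched reduce node in place by replacing its axis input with the newly built full-axis value node.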
AnfNodePtr BuildReduce(const PatternMap &m, const AnfNodePtr &) {
  auto anf = m.Get(m_reduce);
  MS_EXCEPTION_IF_NULL(anf);
  auto graph = anf->func_graph();
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto cnode = anf->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  manager->SetEdge(cnode, kAxisInputIndex, m.Get(v_axis));
  return cnode;
}

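// Source pattern: any supported reduce primitive with arbitrary inputs.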
void ReduceAxisUpdate::DefineSrcPattern(SrcPattern *src_pattern) {
  (void)(*src_pattern).AddVar(kV, IsReduce).AddSeqVar(kXs).AddCNode(m_reduce, {kV, kXs});
}

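// Destination pattern: the same reduce node with its axis input replaced by the full-axis value node.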
void ReduceAxisUpdate::DefineDstPattern(DstPattern *dst_pattern) {
  auto reduce_input = Unpacking(kXs);
  reduce_input.at(kAxisInputIndex - 1) = v_axis;
  (void)(*dst_pattern).AddValueNode(v_axis, BuildAxis).AddCNode(r_reduce, {kV, reduce_input}, BuildReduce);
}
}  // namespace opt
}  // namespace mindspore