• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "plugin/device/ascend/optimizer/mindir/reduce_axis_update.h"
18 #include <vector>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include "mindspore/core/ops/math_ops.h"
23 #include "include/common/utils/anfalgo.h"
24 
25 namespace mindspore {
26 namespace opt {
27 namespace {
28 constexpr size_t kReduceInputNum = 2;
29 constexpr size_t kXInputIndex = 1;
30 constexpr size_t kAxisInputIndex = 2;
31 constexpr auto r_reduce = "r_reduce";
32 constexpr auto m_reduce = "m_reduce";
33 constexpr auto kXs = "Xs";
34 constexpr auto kV = "V";
35 constexpr auto v_axis = "axis";
36 
CreateTensor(const std::vector<int64_t> & values)37 tensor::TensorPtr CreateTensor(const std::vector<int64_t> &values) {
38   auto type_ptr = kInt64;
39   auto data_length = sizeof(int64_t);
40   std::vector<int64_t> tensor_shape = {SizeToLong(values.size())};
41   MS_EXCEPTION_IF_NULL(type_ptr);
42   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_ptr->type_id(), tensor_shape);
43   MS_EXCEPTION_IF_NULL(tensor);
44   tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr};
45   tensor->set_device_info(device_info);
46   auto data_ptr = tensor->data_c();
47   MS_EXCEPTION_IF_NULL(data_ptr);
48   auto buffer_size = values.size() * data_length;
49   if (buffer_size != 0) {
50     auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), values.data(), buffer_size);
51     if (ret_code != EOK) {
52       MS_LOG(EXCEPTION) << "Failed to copy data into tensor, memcpy_s errorno: " << ret_code;
53     }
54   }
55   return tensor;
56 }
57 }  // namespace
58 
IsReduce(const BaseRef & ref)59 bool ReduceAxisUpdate::IsReduce(const BaseRef &ref) {
60   if (utils::isa<AnfNodePtr>(ref)) {
61     AnfNodePtr node = utils::cast<AnfNodePtr>(ref);
62     MS_EXCEPTION_IF_NULL(node);
63     if (IsPrimitive(node, prim::kPrimReduceMin) || IsPrimitive(node, prim::kPrimReduceMax) ||
64         IsPrimitive(node, prim::kPrimReduceMean) || IsPrimitive(node, prim::kPrimReduceSum) ||
65         IsPrimitive(node, prim::kPrimReduceProd) || IsPrimitive(node, prim::kPrimReduceAll) ||
66         IsPrimitive(node, prim::kPrimReduceAny) || IsPrimitive(node, prim::kPrimMeanExt) ||
67         IsPrimitive(node, prim::kPrimSumExt) || IsPrimitive(node, prim::kPrimProdExt)) {
68       return true;
69     }
70   }
71 
72   return false;
73 }
74 
IsAxisEmpty(const ValueNodePtr & axis_node) const75 bool ReduceAxisUpdate::IsAxisEmpty(const ValueNodePtr &axis_node) const {
76   MS_EXCEPTION_IF_NULL(axis_node);
77   const ValuePtr &value = axis_node->value();
78   MS_EXCEPTION_IF_NULL(value);
79   if (value->isa<ValueTuple>()) {
80     auto tuple = value->cast<ValueTuplePtr>();
81     MS_EXCEPTION_IF_NULL(tuple);
82     return tuple->size() == 0;
83   } else if (value->isa<ValueList>()) {
84     auto list = value->cast<ValueListPtr>();
85     MS_EXCEPTION_IF_NULL(list);
86     return list->size() == 0;
87   } else if (value->isa<tensor::Tensor>()) {
88     auto tensor = value->cast<tensor::TensorPtr>();
89     MS_EXCEPTION_IF_NULL(tensor);
90     return tensor->DataSize() == 0;
91   }
92 
93   return false;
94 }
95 
IsAxisNone(const AnfNodePtr & cnode,const ValueNodePtr & axis_node) const96 bool ReduceAxisUpdate::IsAxisNone(const AnfNodePtr &cnode, const ValueNodePtr &axis_node) const {
97   static std::set<std::string> op_name_support_none = {prim::kPrimMeanExt->name(), prim::kPrimSumExt->name(),
98                                                        prim::kPrimProdExt->name(), prim::kPrimReduceAll->name()};
99   auto cnode_name = common::AnfAlgo::GetCNodeName(cnode);
100   if (op_name_support_none.find(cnode_name) == op_name_support_none.end()) {
101     return false;
102   }
103   return axis_node->value()->isa<None>();
104 }
105 
IsInputScalar(const AnfNodePtr & x_node) const106 bool ReduceAxisUpdate::IsInputScalar(const AnfNodePtr &x_node) const {
107   MS_EXCEPTION_IF_NULL(x_node);
108   auto x_shape_ptr = x_node->Shape();
109   MS_EXCEPTION_IF_NULL(x_shape_ptr);
110   ShapeVector x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
111   return x_shape.empty();
112 }
113 
114 namespace {
115 constexpr size_t kAxisIndex = 2;
IsAxisEmptySequence(const AnfNodePtr & node)116 bool IsAxisEmptySequence(const AnfNodePtr &node) {
117   MS_EXCEPTION_IF_NULL(node);
118   if (!node->isa<CNode>()) {
119     return false;
120   }
121   const auto &cnode = node->cast<CNodePtr>();
122   MS_EXCEPTION_IF_NULL(cnode);
123   if (cnode->size() <= kAxisIndex) {
124     return false;
125   }
126   const auto &axis_node = cnode->input(kAxisIndex);
127 
128   if (axis_node == nullptr || (!axis_node->isa<ValueNode>())) {
129     return false;
130   }
131   const auto &value_node = axis_node->cast<ValueNodePtr>();
132   MS_EXCEPTION_IF_NULL(value_node);
133   const auto &value = value_node->value();
134   if (value == nullptr) {
135     return false;
136   }
137   if (value->isa<ValueSequence>()) {
138     const auto &value_sequence = value->cast<ValueSequencePtr>();
139     MS_EXCEPTION_IF_NULL(value_sequence);
140     return value_sequence->size() == 0;
141   } else if (value->isa<tensor::Tensor>()) {
142     const auto &tensor = value->cast<tensor::TensorPtr>();
143     MS_EXCEPTION_IF_NULL(tensor);
144     const auto &shapes = tensor->shape();
145     return shapes.size() == 1 && shapes[0] == 0;
146   }
147   return false;
148 }
149 }  // namespace
150 
CheckMatchedDAG(const PatternMap &,const FuncGraphPtr & graph,const AnfNodePtr & node) const151 bool ReduceAxisUpdate::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &graph, const AnfNodePtr &node) const {
152   MS_EXCEPTION_IF_NULL(node);
153   MS_LOG(INFO) << "Reduce node is " << node->DebugString() << ".";
154 
155   // In control flow, empty tuples are sometimes set to dynamic len which are considered dynamic shapes, but they
156   // are not actually needed, so empty tuple scenarios are excluded here.
157   if (common::AnfAlgo::IsNodeOutputDynamicShape(node) && (!IsAxisEmptySequence(node))) {
158     MS_LOG(INFO) << "The dimension of " << node->DebugString() << " is unknown.";
159     return false;
160   }
161 
162   // If input is dynamic rank, expand axis will get wrong result.
163   if (IsDynamicRank(common::AnfAlgo::GetPrevNodeOutputInferShape(node, 0))) {
164     MS_LOG(INFO) << "The input rank of dimension of " << node->DebugString() << " is unknown.";
165     return false;
166   }
167 
168   auto cnode = node->cast<CNodePtr>();
169   MS_EXCEPTION_IF_NULL(cnode);
170 
171   auto input_num = common::AnfAlgo::GetInputNum(cnode);
172   if (input_num < kReduceInputNum) {
173     MS_LOG(EXCEPTION) << "The input tensor size[" << input_num << "] of node ["
174                       << cnode->DebugString() + "] is not equal to " << kReduceInputNum
175                       << trace::DumpSourceLines(cnode);
176   }
177 
178   const auto &inputs = cnode->inputs();
179   const AnfNodePtr &input_x = inputs.at(kXInputIndex);
180   const AnfNodePtr &input_axis = inputs.at(kAxisInputIndex);
181   MS_EXCEPTION_IF_NULL(input_x);
182   MS_EXCEPTION_IF_NULL(input_axis);
183   MS_LOG(INFO) << "X input is " << input_x->DebugString() << ".";
184   MS_LOG(INFO) << "Axis input is " << input_axis->DebugString() << ".";
185 
186   auto axis_value_node = input_axis->cast<ValueNodePtr>();
187   if (axis_value_node == nullptr ||
188       (!(IsAxisEmpty(axis_value_node) || IsAxisNone(cnode, axis_value_node)) && !IsInputScalar(input_x))) {
189     MS_LOG(INFO) << "Axis input of node " << node->fullname_with_scope()
190                  << " is not value node or axis is not empty or none.";
191     return false;
192   } else {
193     MS_LOG(INFO) << "Axis of node " << node->fullname_with_scope() << " is empty.";
194   }
195   return true;
196 }
197 
BuildAxis(const PatternMap & m)198 AnfNodePtr BuildAxis(const PatternMap &m) {
199   auto node = m.Get(m_reduce);
200   MS_EXCEPTION_IF_NULL(node);
201   ShapeVector x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
202   size_t x_dim_len = x_shape.size();
203   MS_LOG(INFO) << "Input x dim len: " << x_dim_len;
204   std::vector<int64_t> axis = {0};
205   for (size_t i = 1; i < x_dim_len; ++i) {
206     (void)axis.emplace_back(SizeToLong(i));
207     MS_LOG(INFO) << "x dim: " << x_shape[i];
208   }
209   ValuePtr new_value = MakeValue(CreateTensor(axis));
210   MS_EXCEPTION_IF_NULL(new_value);
211   auto new_axis_node = std::make_shared<ValueNode>(new_value);
212   MS_EXCEPTION_IF_NULL(new_axis_node);
213   new_axis_node->set_abstract(new_value->ToAbstract());
214 
215   auto kernel_info = std::make_shared<device::KernelInfo>();
216   MS_EXCEPTION_IF_NULL(kernel_info);
217   new_axis_node->set_kernel_info(kernel_info);
218   std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
219     std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
220   MS_EXCEPTION_IF_NULL(builder);
221   kernel_info->set_select_kernel_build_info(builder->Build());
222   MS_EXCEPTION_IF_NULL(kernel_info->select_kernel_build_info());
223   kernel_info->GetMutableSelectKernelBuildInfo()->SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
224   kernel_info->GetMutableSelectKernelBuildInfo()->SetOutputsFormat({kOpFormat_DEFAULT});
225   kernel_info->GetMutableSelectKernelBuildInfo()->SetOutputsDeviceType({TypeId::kNumberTypeInt64});
226   auto func_graph = node->func_graph();
227   MS_EXCEPTION_IF_NULL(func_graph);
228   auto kernel_graph = func_graph->cast<KernelGraphPtr>();
229   MS_EXCEPTION_IF_NULL(kernel_graph);
230   kernel_graph->AddValueNodeToGraph(new_axis_node);
231   return new_axis_node;
232 }
233 
/**
 * Destination-pattern builder for `r_reduce`: replaces the axis input of the
 * matched reduce CNode with the `v_axis` node through the graph manager, then
 * returns that same CNode.
 */
AnfNodePtr BuildReduce(const PatternMap &m, const AnfNodePtr &) {
  const auto matched = m.Get(m_reduce);
  MS_EXCEPTION_IF_NULL(matched);
  const auto owner = matched->func_graph();
  MS_EXCEPTION_IF_NULL(owner);
  const auto mng = owner->manager();
  MS_EXCEPTION_IF_NULL(mng);
  const auto reduce_cnode = matched->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(reduce_cnode);
  mng->SetEdge(reduce_cnode, kAxisInputIndex, m.Get(v_axis));
  return reduce_cnode;
}
246 
DefineSrcPattern(SrcPattern * src_pattern)247 void ReduceAxisUpdate::DefineSrcPattern(SrcPattern *src_pattern) {
248   (void)(*src_pattern).AddVar(kV, IsReduce).AddSeqVar(kXs).AddCNode(m_reduce, {kV, kXs});
249 }
250 
DefineDstPattern(DstPattern * dst_pattern)251 void ReduceAxisUpdate::DefineDstPattern(DstPattern *dst_pattern) {
252   auto reduce_input = Unpacking(kXs);
253   reduce_input.at(kAxisInputIndex - 1) = v_axis;
254   (void)(*dst_pattern).AddValueNode(v_axis, BuildAxis).AddCNode(r_reduce, {kV, reduce_input}, BuildReduce);
255 }
256 }  // namespace opt
257 }  // namespace mindspore
258