• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/optimizer/ascend/ir_fusion/avgpool_3d_fusion.h"
18 #include <vector>
19 #include <memory>
20 #include <string>
21 #include <algorithm>
22 #include "backend/session/anf_runtime_algorithm.h"
23 #include "backend/optimizer/common/helper.h"
24 #include "base/core_ops.h"
25 #include "utils/utils.h"
26 
27 namespace mindspore {
28 namespace opt {
29 namespace {
30 constexpr size_t kAvgPool3DInputNum = 1;
31 constexpr size_t k5DInferDims = 5;
32 constexpr int64_t kC0 = 16;
33 constexpr size_t kDHWDimNum = 3;
34 constexpr size_t kNCDHWDimNum = 5;
35 
// Returns the length of the overlap between half-open intervals
// [start_1, end_1) and [start_2, end_2); 0 when they are disjoint.
int64_t GetInterSection(int64_t start_1, int64_t end_1, int64_t start_2, int64_t end_2) {
  if (end_1 <= start_2 || start_1 >= end_2) {
    return 0;
  }
  // Overlap length = min of the ends minus max of the starts.
  return std::min(end_1, end_2) - std::max(start_1, start_2);
}
51 
GetKernelSize(const AnfNodePtr & node,int64_t * kd,int64_t * kh,int64_t * kw)52 bool GetKernelSize(const AnfNodePtr &node, int64_t *kd, int64_t *kh, int64_t *kw) {
53   MS_EXCEPTION_IF_NULL(node);
54   MS_EXCEPTION_IF_NULL(kd);
55   MS_EXCEPTION_IF_NULL(kh);
56   MS_EXCEPTION_IF_NULL(kw);
57   if (AnfAlgo::HasNodeAttr("kernel_size", node->cast<CNodePtr>())) {
58     auto kernel_size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "kernel_size");
59     if (kernel_size.size() == 1) {
60       *kd = kernel_size[kDim0];
61       *kh = kernel_size[kDim0];
62       *kw = kernel_size[kDim0];
63     } else if (kernel_size.size() == kDHWDimNum) {
64       *kd = kernel_size[kDim0];
65       *kh = kernel_size[kDim1];
66       *kw = kernel_size[kDim2];
67     } else if (kernel_size.size() == kNCDHWDimNum) {
68       // NCDHW
69       *kd = kernel_size[kDim2];
70       *kh = kernel_size[kDim3];
71       *kw = kernel_size[kDim4];
72     } else {
73       MS_LOG(EXCEPTION) << "Unknown kernel size " << kernel_size.size();
74     }
75     return true;
76   }
77   return false;
78 }
79 
GetStrideSize(const AnfNodePtr & node,int64_t * sd,int64_t * sh,int64_t * sw)80 bool GetStrideSize(const AnfNodePtr &node, int64_t *sd, int64_t *sh, int64_t *sw) {
81   MS_EXCEPTION_IF_NULL(node);
82   MS_EXCEPTION_IF_NULL(sd);
83   MS_EXCEPTION_IF_NULL(sh);
84   MS_EXCEPTION_IF_NULL(sw);
85   if (AnfAlgo::HasNodeAttr("strides", node->cast<CNodePtr>())) {
86     auto kernel_size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "strides");
87     if (kernel_size.size() == 1) {
88       *sd = kernel_size[kDim0];
89       *sh = kernel_size[kDim0];
90       *sw = kernel_size[kDim0];
91     } else if (kernel_size.size() == kDHWDimNum) {
92       *sd = kernel_size[kDim0];
93       *sh = kernel_size[kDim1];
94       *sw = kernel_size[kDim2];
95     } else if (kernel_size.size() == kNCDHWDimNum) {
96       // NCDHW
97       *sd = kernel_size[kDim2];
98       *sh = kernel_size[kDim3];
99       *sw = kernel_size[kDim4];
100     } else {
101       MS_LOG(EXCEPTION) << "Unknown strides size " << kernel_size.size();
102     }
103     return true;
104   }
105   return false;
106 }
107 
GetAttrs(const AnfNodePtr & node,std::vector<int64_t> * pad_list,bool * count_include_pad,bool * ceil_mode,int64_t * divisor_override)108 void GetAttrs(const AnfNodePtr &node, std::vector<int64_t> *pad_list, bool *count_include_pad, bool *ceil_mode,
109               int64_t *divisor_override) {
110   MS_EXCEPTION_IF_NULL(node);
111   if (!AnfAlgo::HasNodeAttr("pad_list", node->cast<CNodePtr>())) {
112     MS_LOG(EXCEPTION) << "AvgPool3D should has attr pad_list";
113   }
114   *pad_list = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "pad_list");
115   if (AnfAlgo::HasNodeAttr("count_include_pad", node->cast<CNodePtr>())) {
116     *count_include_pad = AnfAlgo::GetNodeAttr<bool>(node, "count_include_pad");
117   }
118   if (AnfAlgo::HasNodeAttr("ceil_mode", node->cast<CNodePtr>())) {
119     *ceil_mode = AnfAlgo::GetNodeAttr<bool>(node, "ceil_mode");
120   }
121   if (AnfAlgo::HasNodeAttr("divisor_override", node->cast<CNodePtr>())) {
122     *divisor_override = AnfAlgo::GetNodeAttr<int64_t>(node, "divisor_override");
123   }
124 }
125 
// True when the pooling can use the simple vector implementation: no padding
// and the kernel covers the whole H/W plane of the feature map.
bool IsVectorImpl(int64_t fh, int64_t fw, int64_t kh, int64_t kw, const std::vector<int64_t> &pad_list) {
  const bool no_padding = std::all_of(pad_list.begin(), pad_list.end(), [](int64_t item) { return item == 0; });
  return no_padding && fh == kh && fw == kw;
}
135 
// True when every entry of pad_list is zero (also true for an empty list).
bool IsZeroPads(const std::vector<int64_t> &pad_list) {
  auto first_nonzero = std::find_if(pad_list.begin(), pad_list.end(), [](int64_t item) { return item != 0; });
  return first_nonzero == pad_list.end();
}
139 
ConstructFilter(const FuncGraphPtr & func_graph,const std::vector<int64_t> & pad_list,int64_t fc,int64_t kd,int64_t kh,int64_t kw,bool ceil_mode,int64_t divisor_override)140 AnfNodePtr ConstructFilter(const FuncGraphPtr &func_graph, const std::vector<int64_t> &pad_list, int64_t fc, int64_t kd,
141                            int64_t kh, int64_t kw, bool ceil_mode, int64_t divisor_override) {
142   MS_EXCEPTION_IF_NULL(func_graph);
143   // assist tensor 1
144   int64_t c1 = (fc + kC0 - 1) / kC0;
145   std::vector<int64_t> assist_shape = {c1 * kd * kh * kw, 1, kC0, kC0};  // frac_z_3d
146   std::vector<size_t> infer_shape = {IntToSize(1), LongToSize(fc), LongToSize(kd), LongToSize(kh), LongToSize(kw)};
147   float val = 1.0 / (kd * kh * kw);
148   if (divisor_override) {
149     val = 1.0 / divisor_override;
150   } else if (!IsZeroPads(pad_list) || ceil_mode) {
151     val = 1.0;
152   }
153   // create value node
154   int64_t cnt = c1 * kd * kh * kw;
155   return ConstructFilterValueNode(func_graph, val, assist_shape, infer_shape, cnt);
156 }
157 
ConstructMultiplier(const FuncGraphPtr & func_graph,int64_t fn,int64_t fc,int64_t fd,int64_t fh,int64_t fw,int64_t dd,int64_t dh,int64_t dw,int64_t kd,int64_t kh,int64_t kw,int64_t sd,int64_t sh,int64_t sw,const std::vector<int64_t> & pad_list,bool count_include_pad)158 AnfNodePtr ConstructMultiplier(const FuncGraphPtr &func_graph, int64_t fn, int64_t fc, int64_t fd, int64_t fh,
159                                int64_t fw, int64_t dd, int64_t dh, int64_t dw, int64_t kd, int64_t kh, int64_t kw,
160                                int64_t sd, int64_t sh, int64_t sw, const std::vector<int64_t> &pad_list,
161                                bool count_include_pad) {
162   MS_EXCEPTION_IF_NULL(func_graph);
163   //  assist tensor 2
164   std::vector<int64_t> assist_shape = {fn, fc, dd, dh, dw};  // NCDHW
165   auto infer_shape = {LongToSize(fn), LongToSize(fc), LongToSize(dd), LongToSize(dh), LongToSize(dw)};
166   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat16, assist_shape);
167   MS_EXCEPTION_IF_NULL(tensor);
168   auto tensor_data = reinterpret_cast<float16 *>(tensor->data_c());
169   auto pad_d = pad_list[kDim0] + pad_list[kDim1];
170   auto pad_h = pad_list[kDim2] + pad_list[kDim3];
171   auto pad_w = pad_list[kDim4] + pad_list[kDim5];
172   auto len_d = fd + pad_d;
173   auto len_h = fh + pad_h;
174   auto len_w = fw + pad_w;
175   for (int64_t nn = 0; nn < fn; nn++) {
176     for (int64_t cc = 0; cc < fc; cc++) {
177       int64_t start_d = 0;
178       for (int64_t di = 0; di < dd; di++) {
179         auto v_kd = start_d + kd <= len_d ? kd : len_d - start_d;
180         int64_t start_h = 0;
181         for (int64_t hi = 0; hi < dh; hi++) {
182           auto v_kh = start_h + kh <= len_h ? kh : len_h - start_h;
183           int64_t start_w = 0;
184           for (int64_t wi = 0; wi < dw; wi++) {
185             auto v_kw = start_w + kw < len_w ? kw : len_w - start_w;
186             auto vaild_d = GetInterSection(start_d, start_d + kd, pad_list[kDim0], pad_list[kDim0] + fd);
187             auto vaild_h = GetInterSection(start_h, start_h + kh, pad_list[kDim2], pad_list[kDim2] + fh);
188             auto vaild_w = GetInterSection(start_w, start_w + kw, pad_list[kDim4], pad_list[kDim4] + fw);
189             auto vaild_data = vaild_d * vaild_h * vaild_w;
190             auto vaild_kernel = v_kd * v_kh * v_kw;
191             auto valid_dividend = count_include_pad ? vaild_kernel : vaild_data;
192             if (valid_dividend == 0) {
193               MS_LOG(EXCEPTION) << "Dividend 'valid_dividend' should not be 0.";
194             }
195             float val = 1.0 / valid_dividend;
196             *tensor_data = float16(val);
197             ++tensor_data;
198             start_w += sw;
199           }
200           start_h += sh;
201         }
202         start_d += sd;
203       }
204     }
205   }
206   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, assist_shape);
207   auto kernel_graph = func_graph->cast<KernelGraphPtr>();
208   MS_EXCEPTION_IF_NULL(kernel_graph);
209   auto value_node = kernel_graph->NewValueNode(x_abstract, tensor);
210   kernel_graph->AddValueNodeToGraph(value_node);
211   AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {infer_shape}, value_node.get());
212   return value_node;
213 }
214 }  // namespace
215 
ConstructFilterValueNode(const FuncGraphPtr & func_graph,float val,const std::vector<int64_t> & assist_shape,const std::vector<size_t> & infer_shape,int64_t cnt)216 AnfNodePtr ConstructFilterValueNode(const FuncGraphPtr &func_graph, float val, const std::vector<int64_t> &assist_shape,
217                                     const std::vector<size_t> &infer_shape, int64_t cnt) {
218   tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat16, assist_shape);
219   MS_EXCEPTION_IF_NULL(assist_tensor);
220   TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16);
221   tensor::DeviceInfo device_info{kOpFormat_FRACTAL_Z_3D, tensor_type, kOpFormat_FRACTAL_Z_3D};
222   assist_tensor->set_device_info(device_info);
223   auto tensor_data = reinterpret_cast<float16 *>(assist_tensor->data_c());
224   for (int64_t i = 0; i < cnt; ++i) {
225     for (int64_t j = 0; j < kC0; ++j) {
226       for (int64_t k = 0; k < kC0; ++k) {
227         float t = j == k ? val : 0;
228         *tensor_data = float16(t);
229         ++tensor_data;
230       }
231     }
232   }
233 
234   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, assist_shape);
235   auto kernel_graph = func_graph->cast<KernelGraphPtr>();
236   MS_EXCEPTION_IF_NULL(kernel_graph);
237   auto value_node = kernel_graph->NewValueNode(x_abstract, assist_tensor);
238   kernel_graph->AddValueNodeToGraph(value_node);
239   AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {infer_shape}, value_node.get());
240   return value_node;
241 }
242 
DefinePattern() const243 const BaseRef AvgPool3DFusion::DefinePattern() const {
244   VarPtr Xs = std::make_shared<SeqVar>();
245   return VectorRef({prim::kPrimAvgPool3D, Xs});
246 }
247 
Process(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const EquivPtr &) const248 const AnfNodePtr AvgPool3DFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
249                                           const EquivPtr &) const {
250   MS_EXCEPTION_IF_NULL(func_graph);
251   MS_EXCEPTION_IF_NULL(node);
252   auto avg_pool_3d_node = node->cast<CNodePtr>();
253   MS_EXCEPTION_IF_NULL(avg_pool_3d_node);
254   if (avg_pool_3d_node->size() != kAvgPool3DInputNum + 1) {
255     MS_LOG(INFO) << "The node " << avg_pool_3d_node->DebugString() << " is not equal to " << kAvgPool3DInputNum
256                  << " inputs. Can not do fusion.";
257     return nullptr;
258   }
259   auto dims_in = AnfAlgo::GetPrevNodeOutputInferShape(avg_pool_3d_node, 0);
260   auto dims_out = AnfAlgo::GetOutputInferShape(avg_pool_3d_node, 0);
261   if (dims_in.size() < k5DInferDims || dims_out.size() < k5DInferDims) {
262     MS_LOG(EXCEPTION) << "AvgPool3D's in_out infer shape dims can not be less " << k5DInferDims;
263   }
264   auto fn = SizeToLong(dims_in[kDim0]);
265   auto fc = SizeToLong(dims_in[kDim1]);
266   auto fd = SizeToLong(dims_in[kDim2]);
267   auto fh = SizeToLong(dims_in[kDim3]);
268   auto fw = SizeToLong(dims_in[kDim4]);
269   auto dout = SizeToLong(dims_out[kDim2]);
270   auto dh = SizeToLong(dims_out[kDim3]);
271   auto dw = SizeToLong(dims_out[kDim4]);
272   // kernel size
273   int64_t kd;
274   int64_t kh;
275   int64_t kw;
276   if (!GetKernelSize(avg_pool_3d_node, &kd, &kh, &kw)) {
277     MS_LOG(EXCEPTION) << "GetK kernel size failed";
278   }
279   // strides
280   int64_t sd;
281   int64_t sh;
282   int64_t sw;
283   if (!GetStrideSize(avg_pool_3d_node, &sd, &sh, &sw)) {
284     MS_LOG(EXCEPTION) << "GetK stride size failed";
285   }
286   std::vector<int64_t> pad_list;
287   bool count_include_pad = false;
288   bool ceil_mode = false;
289   int64_t divisor_override = 0;
290   GetAttrs(avg_pool_3d_node, &pad_list, &count_include_pad, &ceil_mode, &divisor_override);
291   if (IsVectorImpl(fh, fw, kh, kw, pad_list)) {
292     MS_LOG(INFO) << "No need fusion";
293     return nullptr;
294   }
295   std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAvgPool3D->name()))};
296   (void)new_inputs.insert(new_inputs.end(), avg_pool_3d_node->inputs().begin() + 1, avg_pool_3d_node->inputs().end());
297   // assist node 1
298   auto filter_node = ConstructFilter(func_graph, pad_list, fc, kd, kh, kw, ceil_mode, divisor_override);
299   new_inputs.push_back(filter_node);
300   MS_EXCEPTION_IF_NULL(filter_node);
301   // assist node 2
302   if ((!IsZeroPads(pad_list) || ceil_mode) && !divisor_override) {
303     auto multiplier = ConstructMultiplier(func_graph, fn, fc, fd, fh, fw, dout, dh, dw, kd, kh, kw, sd, sh, sw,
304                                           pad_list, count_include_pad);
305     new_inputs.push_back(multiplier);
306   }
307   auto new_3d = func_graph->NewCNode(new_inputs);
308   MS_EXCEPTION_IF_NULL(new_3d);
309   new_3d->set_scope(avg_pool_3d_node->scope());
310   new_3d->set_abstract(avg_pool_3d_node->abstract());
311   AnfAlgo::CopyNodeAttrs(avg_pool_3d_node, new_3d);
312   return new_3d;
313 }
314 }  // namespace opt
315 }  // namespace mindspore
316