1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/optimizer/ascend/ir_fusion/avgpool_3d_fusion.h"
18 #include <vector>
19 #include <memory>
20 #include <string>
21 #include <algorithm>
22 #include "backend/session/anf_runtime_algorithm.h"
23 #include "backend/optimizer/common/helper.h"
24 #include "base/core_ops.h"
25 #include "utils/utils.h"
26
27 namespace mindspore {
28 namespace opt {
29 namespace {
30 constexpr size_t kAvgPool3DInputNum = 1;
31 constexpr size_t k5DInferDims = 5;
32 constexpr int64_t kC0 = 16;
33 constexpr size_t kDHWDimNum = 3;
34 constexpr size_t kNCDHWDimNum = 5;
35
// Returns the length of the overlap of the half-open intervals
// [start_1, end_1) and [start_2, end_2); 0 when they are disjoint.
int64_t GetInterSection(int64_t start_1, int64_t end_1, int64_t start_2, int64_t end_2) {
  if (end_1 <= start_2 || start_1 >= end_2) {
    return 0;
  }
  // Overlap is bounded by the later start and the earlier end.
  return std::min(end_1, end_2) - std::max(start_1, start_2);
}
51
GetKernelSize(const AnfNodePtr & node,int64_t * kd,int64_t * kh,int64_t * kw)52 bool GetKernelSize(const AnfNodePtr &node, int64_t *kd, int64_t *kh, int64_t *kw) {
53 MS_EXCEPTION_IF_NULL(node);
54 MS_EXCEPTION_IF_NULL(kd);
55 MS_EXCEPTION_IF_NULL(kh);
56 MS_EXCEPTION_IF_NULL(kw);
57 if (AnfAlgo::HasNodeAttr("kernel_size", node->cast<CNodePtr>())) {
58 auto kernel_size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "kernel_size");
59 if (kernel_size.size() == 1) {
60 *kd = kernel_size[kDim0];
61 *kh = kernel_size[kDim0];
62 *kw = kernel_size[kDim0];
63 } else if (kernel_size.size() == kDHWDimNum) {
64 *kd = kernel_size[kDim0];
65 *kh = kernel_size[kDim1];
66 *kw = kernel_size[kDim2];
67 } else if (kernel_size.size() == kNCDHWDimNum) {
68 // NCDHW
69 *kd = kernel_size[kDim2];
70 *kh = kernel_size[kDim3];
71 *kw = kernel_size[kDim4];
72 } else {
73 MS_LOG(EXCEPTION) << "Unknown kernel size " << kernel_size.size();
74 }
75 return true;
76 }
77 return false;
78 }
79
GetStrideSize(const AnfNodePtr & node,int64_t * sd,int64_t * sh,int64_t * sw)80 bool GetStrideSize(const AnfNodePtr &node, int64_t *sd, int64_t *sh, int64_t *sw) {
81 MS_EXCEPTION_IF_NULL(node);
82 MS_EXCEPTION_IF_NULL(sd);
83 MS_EXCEPTION_IF_NULL(sh);
84 MS_EXCEPTION_IF_NULL(sw);
85 if (AnfAlgo::HasNodeAttr("strides", node->cast<CNodePtr>())) {
86 auto kernel_size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "strides");
87 if (kernel_size.size() == 1) {
88 *sd = kernel_size[kDim0];
89 *sh = kernel_size[kDim0];
90 *sw = kernel_size[kDim0];
91 } else if (kernel_size.size() == kDHWDimNum) {
92 *sd = kernel_size[kDim0];
93 *sh = kernel_size[kDim1];
94 *sw = kernel_size[kDim2];
95 } else if (kernel_size.size() == kNCDHWDimNum) {
96 // NCDHW
97 *sd = kernel_size[kDim2];
98 *sh = kernel_size[kDim3];
99 *sw = kernel_size[kDim4];
100 } else {
101 MS_LOG(EXCEPTION) << "Unknown strides size " << kernel_size.size();
102 }
103 return true;
104 }
105 return false;
106 }
107
GetAttrs(const AnfNodePtr & node,std::vector<int64_t> * pad_list,bool * count_include_pad,bool * ceil_mode,int64_t * divisor_override)108 void GetAttrs(const AnfNodePtr &node, std::vector<int64_t> *pad_list, bool *count_include_pad, bool *ceil_mode,
109 int64_t *divisor_override) {
110 MS_EXCEPTION_IF_NULL(node);
111 if (!AnfAlgo::HasNodeAttr("pad_list", node->cast<CNodePtr>())) {
112 MS_LOG(EXCEPTION) << "AvgPool3D should has attr pad_list";
113 }
114 *pad_list = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "pad_list");
115 if (AnfAlgo::HasNodeAttr("count_include_pad", node->cast<CNodePtr>())) {
116 *count_include_pad = AnfAlgo::GetNodeAttr<bool>(node, "count_include_pad");
117 }
118 if (AnfAlgo::HasNodeAttr("ceil_mode", node->cast<CNodePtr>())) {
119 *ceil_mode = AnfAlgo::GetNodeAttr<bool>(node, "ceil_mode");
120 }
121 if (AnfAlgo::HasNodeAttr("divisor_override", node->cast<CNodePtr>())) {
122 *divisor_override = AnfAlgo::GetNodeAttr<int64_t>(node, "divisor_override");
123 }
124 }
125
// The vector-engine implementation applies only when there is no padding and
// the kernel covers the whole H/W plane (global pooling over H and W).
bool IsVectorImpl(int64_t fh, int64_t fw, int64_t kh, int64_t kw, const std::vector<int64_t> &pad_list) {
  const bool no_pad = std::all_of(pad_list.begin(), pad_list.end(), [](int64_t pad) { return pad == 0; });
  return no_pad && fh == kh && fw == kw;
}
135
// True when every entry of pad_list is zero (i.e. no padding anywhere).
bool IsZeroPads(const std::vector<int64_t> &pad_list) {
  return !std::any_of(pad_list.begin(), pad_list.end(), [](int64_t pad) { return pad != 0; });
}
139
ConstructFilter(const FuncGraphPtr & func_graph,const std::vector<int64_t> & pad_list,int64_t fc,int64_t kd,int64_t kh,int64_t kw,bool ceil_mode,int64_t divisor_override)140 AnfNodePtr ConstructFilter(const FuncGraphPtr &func_graph, const std::vector<int64_t> &pad_list, int64_t fc, int64_t kd,
141 int64_t kh, int64_t kw, bool ceil_mode, int64_t divisor_override) {
142 MS_EXCEPTION_IF_NULL(func_graph);
143 // assist tensor 1
144 int64_t c1 = (fc + kC0 - 1) / kC0;
145 std::vector<int64_t> assist_shape = {c1 * kd * kh * kw, 1, kC0, kC0}; // frac_z_3d
146 std::vector<size_t> infer_shape = {IntToSize(1), LongToSize(fc), LongToSize(kd), LongToSize(kh), LongToSize(kw)};
147 float val = 1.0 / (kd * kh * kw);
148 if (divisor_override) {
149 val = 1.0 / divisor_override;
150 } else if (!IsZeroPads(pad_list) || ceil_mode) {
151 val = 1.0;
152 }
153 // create value node
154 int64_t cnt = c1 * kd * kh * kw;
155 return ConstructFilterValueNode(func_graph, val, assist_shape, infer_shape, cnt);
156 }
157
ConstructMultiplier(const FuncGraphPtr & func_graph,int64_t fn,int64_t fc,int64_t fd,int64_t fh,int64_t fw,int64_t dd,int64_t dh,int64_t dw,int64_t kd,int64_t kh,int64_t kw,int64_t sd,int64_t sh,int64_t sw,const std::vector<int64_t> & pad_list,bool count_include_pad)158 AnfNodePtr ConstructMultiplier(const FuncGraphPtr &func_graph, int64_t fn, int64_t fc, int64_t fd, int64_t fh,
159 int64_t fw, int64_t dd, int64_t dh, int64_t dw, int64_t kd, int64_t kh, int64_t kw,
160 int64_t sd, int64_t sh, int64_t sw, const std::vector<int64_t> &pad_list,
161 bool count_include_pad) {
162 MS_EXCEPTION_IF_NULL(func_graph);
163 // assist tensor 2
164 std::vector<int64_t> assist_shape = {fn, fc, dd, dh, dw}; // NCDHW
165 auto infer_shape = {LongToSize(fn), LongToSize(fc), LongToSize(dd), LongToSize(dh), LongToSize(dw)};
166 tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat16, assist_shape);
167 MS_EXCEPTION_IF_NULL(tensor);
168 auto tensor_data = reinterpret_cast<float16 *>(tensor->data_c());
169 auto pad_d = pad_list[kDim0] + pad_list[kDim1];
170 auto pad_h = pad_list[kDim2] + pad_list[kDim3];
171 auto pad_w = pad_list[kDim4] + pad_list[kDim5];
172 auto len_d = fd + pad_d;
173 auto len_h = fh + pad_h;
174 auto len_w = fw + pad_w;
175 for (int64_t nn = 0; nn < fn; nn++) {
176 for (int64_t cc = 0; cc < fc; cc++) {
177 int64_t start_d = 0;
178 for (int64_t di = 0; di < dd; di++) {
179 auto v_kd = start_d + kd <= len_d ? kd : len_d - start_d;
180 int64_t start_h = 0;
181 for (int64_t hi = 0; hi < dh; hi++) {
182 auto v_kh = start_h + kh <= len_h ? kh : len_h - start_h;
183 int64_t start_w = 0;
184 for (int64_t wi = 0; wi < dw; wi++) {
185 auto v_kw = start_w + kw < len_w ? kw : len_w - start_w;
186 auto vaild_d = GetInterSection(start_d, start_d + kd, pad_list[kDim0], pad_list[kDim0] + fd);
187 auto vaild_h = GetInterSection(start_h, start_h + kh, pad_list[kDim2], pad_list[kDim2] + fh);
188 auto vaild_w = GetInterSection(start_w, start_w + kw, pad_list[kDim4], pad_list[kDim4] + fw);
189 auto vaild_data = vaild_d * vaild_h * vaild_w;
190 auto vaild_kernel = v_kd * v_kh * v_kw;
191 auto valid_dividend = count_include_pad ? vaild_kernel : vaild_data;
192 if (valid_dividend == 0) {
193 MS_LOG(EXCEPTION) << "Dividend 'valid_dividend' should not be 0.";
194 }
195 float val = 1.0 / valid_dividend;
196 *tensor_data = float16(val);
197 ++tensor_data;
198 start_w += sw;
199 }
200 start_h += sh;
201 }
202 start_d += sd;
203 }
204 }
205 }
206 auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, assist_shape);
207 auto kernel_graph = func_graph->cast<KernelGraphPtr>();
208 MS_EXCEPTION_IF_NULL(kernel_graph);
209 auto value_node = kernel_graph->NewValueNode(x_abstract, tensor);
210 kernel_graph->AddValueNodeToGraph(value_node);
211 AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {infer_shape}, value_node.get());
212 return value_node;
213 }
214 } // namespace
215
ConstructFilterValueNode(const FuncGraphPtr & func_graph,float val,const std::vector<int64_t> & assist_shape,const std::vector<size_t> & infer_shape,int64_t cnt)216 AnfNodePtr ConstructFilterValueNode(const FuncGraphPtr &func_graph, float val, const std::vector<int64_t> &assist_shape,
217 const std::vector<size_t> &infer_shape, int64_t cnt) {
218 tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat16, assist_shape);
219 MS_EXCEPTION_IF_NULL(assist_tensor);
220 TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16);
221 tensor::DeviceInfo device_info{kOpFormat_FRACTAL_Z_3D, tensor_type, kOpFormat_FRACTAL_Z_3D};
222 assist_tensor->set_device_info(device_info);
223 auto tensor_data = reinterpret_cast<float16 *>(assist_tensor->data_c());
224 for (int64_t i = 0; i < cnt; ++i) {
225 for (int64_t j = 0; j < kC0; ++j) {
226 for (int64_t k = 0; k < kC0; ++k) {
227 float t = j == k ? val : 0;
228 *tensor_data = float16(t);
229 ++tensor_data;
230 }
231 }
232 }
233
234 auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, assist_shape);
235 auto kernel_graph = func_graph->cast<KernelGraphPtr>();
236 MS_EXCEPTION_IF_NULL(kernel_graph);
237 auto value_node = kernel_graph->NewValueNode(x_abstract, assist_tensor);
238 kernel_graph->AddValueNodeToGraph(value_node);
239 AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {infer_shape}, value_node.get());
240 return value_node;
241 }
242
DefinePattern() const243 const BaseRef AvgPool3DFusion::DefinePattern() const {
244 VarPtr Xs = std::make_shared<SeqVar>();
245 return VectorRef({prim::kPrimAvgPool3D, Xs});
246 }
247
// Rewrites AvgPool3D into a fused form by appending assist inputs:
//  - assist 1 (always): a FRACTAL_Z_3D filter that scales window sums;
//  - assist 2 (only when padding/ceil_mode matters and no divisor_override):
//    a per-output-element multiplier holding each window's 1/divisor.
// Returns nullptr (no fusion) when the node shape/attrs already suit the
// vector implementation or the input count is unexpected.
const AnfNodePtr AvgPool3DFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto avg_pool_3d_node = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(avg_pool_3d_node);
  if (avg_pool_3d_node->size() != kAvgPool3DInputNum + 1) {
    MS_LOG(INFO) << "The node " << avg_pool_3d_node->DebugString() << " is not equal to " << kAvgPool3DInputNum
                 << " inputs. Can not do fusion.";
    return nullptr;
  }
  auto dims_in = AnfAlgo::GetPrevNodeOutputInferShape(avg_pool_3d_node, 0);
  auto dims_out = AnfAlgo::GetOutputInferShape(avg_pool_3d_node, 0);
  if (dims_in.size() < k5DInferDims || dims_out.size() < k5DInferDims) {
    MS_LOG(EXCEPTION) << "AvgPool3D's in_out infer shape dims can not be less " << k5DInferDims;
  }
  // NCDHW input dims and DHW output dims.
  auto fn = SizeToLong(dims_in[kDim0]);
  auto fc = SizeToLong(dims_in[kDim1]);
  auto fd = SizeToLong(dims_in[kDim2]);
  auto fh = SizeToLong(dims_in[kDim3]);
  auto fw = SizeToLong(dims_in[kDim4]);
  auto dout = SizeToLong(dims_out[kDim2]);
  auto dh = SizeToLong(dims_out[kDim3]);
  auto dw = SizeToLong(dims_out[kDim4]);
  // kernel size
  int64_t kd;
  int64_t kh;
  int64_t kw;
  if (!GetKernelSize(avg_pool_3d_node, &kd, &kh, &kw)) {
    MS_LOG(EXCEPTION) << "GetK kernel size failed";
  }
  // strides
  int64_t sd;
  int64_t sh;
  int64_t sw;
  if (!GetStrideSize(avg_pool_3d_node, &sd, &sh, &sw)) {
    MS_LOG(EXCEPTION) << "GetK stride size failed";
  }
  std::vector<int64_t> pad_list;
  bool count_include_pad = false;
  bool ceil_mode = false;
  int64_t divisor_override = 0;
  GetAttrs(avg_pool_3d_node, &pad_list, &count_include_pad, &ceil_mode, &divisor_override);
  if (IsVectorImpl(fh, fw, kh, kw, pad_list)) {
    MS_LOG(INFO) << "No need fusion";
    return nullptr;
  }
  std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAvgPool3D->name()))};
  (void)new_inputs.insert(new_inputs.end(), avg_pool_3d_node->inputs().begin() + 1, avg_pool_3d_node->inputs().end());
  // assist node 1
  auto filter_node = ConstructFilter(func_graph, pad_list, fc, kd, kh, kw, ceil_mode, divisor_override);
  // Fixed ordering: validate the node BEFORE it is appended/used (the null
  // check previously ran after push_back).
  MS_EXCEPTION_IF_NULL(filter_node);
  new_inputs.push_back(filter_node);
  // assist node 2: only needed when padding/ceil_mode changes per-window
  // divisors and no explicit divisor_override is set.
  if ((!IsZeroPads(pad_list) || ceil_mode) && !divisor_override) {
    auto multiplier = ConstructMultiplier(func_graph, fn, fc, fd, fh, fw, dout, dh, dw, kd, kh, kw, sd, sh, sw,
                                          pad_list, count_include_pad);
    new_inputs.push_back(multiplier);
  }
  auto new_3d = func_graph->NewCNode(new_inputs);
  MS_EXCEPTION_IF_NULL(new_3d);
  // Preserve scope, abstract and attrs so downstream passes see an
  // equivalent AvgPool3D node with the extra assist inputs.
  new_3d->set_scope(avg_pool_3d_node->scope());
  new_3d->set_abstract(avg_pool_3d_node->abstract());
  AnfAlgo::CopyNodeAttrs(avg_pool_3d_node, new_3d);
  return new_3d;
}
314 } // namespace opt
315 } // namespace mindspore
316