1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <map>
22 #include <algorithm>
23 #include <vector>
24 #include "backend/session/anf_runtime_algorithm.h"
25
26 namespace mindspore {
27 namespace kernel {
28 namespace tbe {
29 namespace {
// Block sizes used by the range transforms below when a dimension is split into fixed-size blocks.
constexpr int64_t k16{16};
constexpr int64_t k4{4};
// Number of trailing entries FracNZRange reads from its input.
constexpr int kDims2{2};
// Index of each entry in a four-entry N/C/H/W range; kNchwDims is the entry count.
enum k2Axis : int {
  kN = 0,
  kC = 1,
  kH = 2,
  kW = 3,
  kNchwDims = 4
};
// Index of each entry in a five-entry N/C/D/H/W range; kNcdhwDims is the entry count.
enum k3Axis : int {
  N_ncdhw = 0,
  C_ncdhw = 1,
  D_ncdhw = 2,
  H_ncdhw = 3,
  W_ncdhw = 4,
  kNcdhwDims = 5
};
// Pads a range to the five-entry N/C/D/H/W layout.
//   - Fewer than five entries: the input is placed starting at the C slot and every
//     remaining slot (including N) keeps the fixed range (1, 1).
//   - Exactly five entries: the input is copied unchanged.
//   - More than five entries: an exception is raised through MS_LOG(EXCEPTION).
// An empty input yields five (1, 1) entries, matching PaddingRangeTo4D; returning the
// empty input would make the 5D format transforms index past the end of the range.
RangePair PaddingRangeTo5D(const RangePair &ori_range) {
  RangePair dst_range(kNcdhwDims, std::pair<int64_t, int64_t>(1, 1));
  const size_t rank = ori_range.size();
  const size_t full_rank = static_cast<size_t>(kNcdhwDims);
  if (rank == full_rank) {
    (void)std::copy(ori_range.begin(), ori_range.end(), dst_range.begin());
  } else if (rank < full_rank) {
    // Shift by one so the leading N slot stays (1, 1); a zero-length copy is a no-op.
    (void)std::copy(ori_range.begin(), ori_range.end(), dst_range.begin() + 1);
  } else {
    MS_LOG(EXCEPTION) << "Unexpected shape size = " << ori_range.size();
  }
  return dst_range;
}
63
// Pads a range to the four-entry N/C/H/W layout.
// Inputs with fewer than four entries are placed starting at the C slot, leaving N and any
// unfilled trailing slots at (1, 1); a four-entry input is copied as is; anything longer is
// rejected through MS_LOG(EXCEPTION).
RangePair PaddingRangeTo4D(const RangePair &ori_range) {
  RangePair padded(kNchwDims, std::pair<int64_t, int64_t>(1, 1));
  const size_t rank = ori_range.size();
  const size_t full_rank = static_cast<size_t>(kNchwDims);
  if (rank == full_rank) {
    (void)std::copy(ori_range.begin(), ori_range.end(), padded.begin());
  } else if (rank < full_rank) {
    // Offset of one keeps the N slot at (1, 1); an empty input copies nothing.
    (void)std::copy(ori_range.begin(), ori_range.end(), padded.begin() + 1);
  } else {
    MS_LOG(EXCEPTION) << "Unexpected range size: " << ori_range.size();
  }
  return padded;
}
89
// Identity transform: the input is already indexed as N/C/H/W, so a copy is returned.
RangePair NchwRange(const RangePair &range) {
  RangePair same_range(range);
  return same_range;
}
91
// Reorders a range indexed as N/C/H/W into N, H, W, C order.
// The input must have at least four entries; no bounds check is performed here.
RangePair NhwcRange(const RangePair &range) {
  return RangePair{range[kN], range[kH], range[kW], range[kC]};
}
100
// Reorders a range indexed as N/C/H/W into H, W, C, N order.
// The input must have at least four entries; no bounds check is performed here.
RangePair HwchRange(const RangePair &range) {
  return RangePair{range[kH], range[kW], range[kC], range[kN]};
}
109
// Converts a range indexed as N/C/H/W into five entries: N, C1, H, W, C0.
// C1 is the C bound divided by 16 rounded up (applied to min and max separately);
// C0 is the fixed range (16, 16).
RangePair Nc1hwc0Range(const RangePair &range) {
  const auto &channel = range[kC];
  const std::pair<int64_t, int64_t> c1_block((channel.first + k16 - 1) / k16, (channel.second + k16 - 1) / k16);
  const std::pair<int64_t, int64_t> c0_block(k16, k16);
  return RangePair{range[kN], c1_block, range[kH], range[kW], c0_block};
}
121
// Converts a range indexed as N/C/H/W into five entries: N, C1, H, W, C0.
// Here C1 is always the fixed range (1, 1) and C0 is the fixed range (4, 4); the input
// C entry is not read.
RangePair Nc1hwc04Range(const RangePair &range) {
  const std::pair<int64_t, int64_t> c1_block(1, 1);
  const std::pair<int64_t, int64_t> c0_block(k4, k4);
  return RangePair{range[kN], c1_block, range[kH], range[kW], c0_block};
}
133
// Converts a range with at least two entries into the FracNZ layout.
// All entries except the last two are kept; the last two (called H and W below) are replaced
// by four entries: W1, H1, (16, 16), (16, 16), where X1 = (X - 1) / 16 + 1 for each bound.
// Raises through MS_LOG(EXCEPTION) when the input has fewer than two entries.
RangePair FracNZRange(const RangePair &range) {
  const size_t rank = range.size();
  if (rank < kDims2) {
    MS_LOG(EXCEPTION) << "Format FracNZ can not support range size: " << range.size();
  }
  // Leading (batch-like) entries pass through untouched.
  RangePair result(range.begin(), range.end() - kDims2);
  const auto &height = range[rank - kDims2];
  const auto &width = range[rank - 1];
  result.emplace_back((width.first - 1) / k16 + 1, (width.second - 1) / k16 + 1);
  result.emplace_back((height.first - 1) / k16 + 1, (height.second - 1) / k16 + 1);
  result.emplace_back(k16, k16);
  result.emplace_back(k16, k16);
  return result;
}
152
// Converts a range indexed as N/C/H/W into four entries:
//   [H * W * align16(C) / 16, align16(N) / 16, (16, 16), (16, 16)]
// where align16 rounds a bound up to the next multiple of 16. Min and max bounds are
// transformed independently.
RangePair FracZRange(const RangePair &range) {
  const auto align16 = [](int64_t bound) { return ((bound + k16 - 1) / k16) * k16; };
  const std::pair<int64_t, int64_t> n_aligned(align16(range[kN].first), align16(range[kN].second));
  const std::pair<int64_t, int64_t> c_aligned(align16(range[kC].first), align16(range[kC].second));

  RangePair result;
  result.emplace_back(range[kH].first * range[kW].first * c_aligned.first / k16,
                      range[kH].second * range[kW].second * c_aligned.second / k16);
  result.emplace_back(n_aligned.first / k16, n_aligned.second / k16);
  result.emplace_back(k16, k16);
  result.emplace_back(k16, k16);
  return result;
}
169
// Converts a range indexed as N/C/H/W into four entries:
//   [ceil(4 * H * W / 16), ceil(N / 16), (16, 16), (16, 16)]
// Min and max bounds are transformed independently; the input C entry is not read.
RangePair FracZC04Range(const RangePair &range) {
  const auto &batch = range[kN];
  const auto &height = range[kH];
  const auto &width = range[kW];

  RangePair result;
  result.emplace_back((k4 * height.first * width.first + k16 - 1) / k16,
                      (k4 * height.second * width.second + k16 - 1) / k16);
  result.emplace_back((batch.first + k16 - 1) / k16, (batch.second + k16 - 1) / k16);
  result.emplace_back(k16, k16);
  result.emplace_back(k16, k16);
  return result;
}
183
// Converts a two-entry range into the four-entry FracZN-LSTM layout:
//   h = range[0] / 4, i = range[1] - h
//   [ceil(i / 16) + ceil(h / 16), 4 * ceil(h / 16), (16, 16), (16, 16)]
// Min and max bounds are transformed independently.
// NOTE(review): h and i presumably are the LSTM hidden size and input size - confirm.
// Raises through MS_LOG(EXCEPTION) when the input has fewer than two entries, because the
// caller passes the range here without padding it first.
RangePair FracZNLSTMCRange(const RangePair &range) {
  if (range.size() < kDims2) {
    MS_LOG(EXCEPTION) << "Format FracZNLSTM can not support range size: " << range.size();
  }
  RangePair dst_range;
  const std::pair<int64_t, int64_t> c0 = {k4, k4};
  // The two trailing entries are 16-wide blocks, like the other fractal layouts in this file.
  const std::pair<int64_t, int64_t> c16 = {k16, k16};
  const std::pair<int64_t, int64_t> h = {range[kN].first / c0.first, range[kN].second / c0.second};
  const std::pair<int64_t, int64_t> i = {range[kC].first - h.first, range[kC].second - h.second};
  const std::pair<int64_t, int64_t> first_dim = {(i.first + k16 - 1) / k16 + (h.first + k16 - 1) / k16,
                                                 (i.second + k16 - 1) / k16 + (h.second + k16 - 1) / k16};
  const std::pair<int64_t, int64_t> second = {c0.first * ((h.first + k16 - 1) / k16),
                                              c0.second * ((h.second + k16 - 1) / k16)};
  dst_range.push_back(first_dim);
  dst_range.push_back(second);
  dst_range.push_back(c16);
  dst_range.push_back(c16);
  return dst_range;
}
200
// Converts a range indexed as N/C/H/W into six entries: C1, H, W, N, (16, 16), (16, 16),
// where C1 = (C - 1) / 16 + 1 for each bound.
RangePair C1hwncoc0Range(const RangePair &range) {
  const auto &channel = range[kC];
  const std::pair<int64_t, int64_t> c1_block((channel.first - 1) / k16 + 1, (channel.second - 1) / k16 + 1);
  const std::pair<int64_t, int64_t> block16(k16, k16);
  return RangePair{c1_block, range[kH], range[kW], range[kN], block16, block16};
}
213
// Identity transform: the input is already indexed as N/C/D/H/W, so a copy is returned.
RangePair NcdhwRange(const RangePair &range) {
  RangePair same_range(range);
  return same_range;
}
215
// Reorders a range indexed as N/C/D/H/W into N, D, H, W, C order.
// The input must have at least five entries; no bounds check is performed here.
RangePair NdhwcRange(const RangePair &range) {
  RangePair dst_range;
  dst_range.push_back(range[N_ncdhw]);
  dst_range.push_back(range[D_ncdhw]);
  dst_range.push_back(range[H_ncdhw]);
  dst_range.push_back(range[W_ncdhw]);
  dst_range.push_back(range[C_ncdhw]);
  // Return the permuted range; returning the input would silently leave it in NCDHW order.
  return dst_range;
}
225
// Converts a range indexed as N/C/D/H/W into six entries: N, D, C1, H, W, C0.
// C1 is the C bound divided by 16 rounded up (min and max separately); C0 is (16, 16).
RangePair Ndc1hwc0Range(const RangePair &range) {
  const auto &channel = range[C_ncdhw];
  const std::pair<int64_t, int64_t> c1_block((channel.first + k16 - 1) / k16, (channel.second + k16 - 1) / k16);
  const std::pair<int64_t, int64_t> c0_block(k16, k16);
  return RangePair{range[N_ncdhw], range[D_ncdhw], c1_block, range[H_ncdhw], range[W_ncdhw], c0_block};
}
239
// Converts a range indexed as N/C/D/H/W into four entries:
//   [D * C1 * H * W, N1, (16, 16), (16, 16)]
// where C1 = ceil(C / 16) and N1 = ceil(N / 16). Min and max bounds are transformed
// independently. C1 is already folded into the first entry, so both trailing entries are
// the fixed 16-wide blocks (the same tail FracZRange produces).
RangePair FracZ3DRange(const RangePair &range) {
  RangePair dst_range;
  const std::pair<int64_t, int64_t> c0 = {k16, k16};
  const std::pair<int64_t, int64_t> c1 = {(range[C_ncdhw].first + k16 - 1) / k16,
                                          (range[C_ncdhw].second + k16 - 1) / k16};
  const std::pair<int64_t, int64_t> n1 = {(range[N_ncdhw].first + k16 - 1) / k16,
                                          (range[N_ncdhw].second + k16 - 1) / k16};
  const int64_t r1_0 = range[D_ncdhw].first * c1.first * range[H_ncdhw].first * range[W_ncdhw].first;
  const int64_t r1_1 = range[D_ncdhw].second * c1.second * range[H_ncdhw].second * range[W_ncdhw].second;
  const std::pair<int64_t, int64_t> r1 = {r1_0, r1_1};
  dst_range.push_back(r1);
  dst_range.push_back(n1);
  dst_range.push_back(c0);
  dst_range.push_back(c0);
  return dst_range;
}
256
// Converts a (min, max) range list into the layout named by `format`.
//   - ND / default formats: the range is returned unchanged.
//   - FRACTAL_ZN_LSTM and FRAC_NZ: handled by their dedicated transforms, without padding.
//   - Other formats: the range is first padded to four entries (or five for formats listed
//     in k3DFormatSet) when it is shorter than that, then dispatched through the table.
//   - Formats missing from the table: a warning is logged and the original, unpadded range
//     is returned.
RangePair DynamicShapeRangeTrans(const RangePair &ori_range, const std::string &format) {
  using RangeTransfer = RangePair (*)(const RangePair &);
  // Built once on first use; the table is immutable, so rebuilding it per call is wasted work.
  static const std::map<std::string, RangeTransfer> format_range_map{
    {kOpFormat_NCHW, NchwRange},
    {kOpFormat_NHWC, NhwcRange},
    {kOpFormat_HWCN, HwchRange},
    {kOpFormat_NC1HWC0, Nc1hwc0Range},
    {kOpFormat_NC1HWC0_C04, Nc1hwc04Range},
    {kOpFormat_FRAC_Z, FracZRange},
    {kOpFormat_FRACTAL_Z_C04, FracZC04Range},
    {kOpFormat_C1HWNCoC0, C1hwncoc0Range},
    {kOpFormat_NCDHW, NcdhwRange},
    {kOpFormat_NDHWC, NdhwcRange},
    {kOpFormat_NDC1HWC0, Ndc1hwc0Range},
    {kOpFormat_FRACTAL_Z_3D, FracZ3DRange},
  };

  if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
    return ori_range;
  }
  if (format == kOpFormat_FRACTAL_ZN_LSTM) {
    return FracZNLSTMCRange(ori_range);
  }
  if (format == kOpFormat_FRAC_NZ) {
    return FracNZRange(ori_range);
  }

  const bool is_3d_format = k3DFormatSet.find(format) != k3DFormatSet.end();
  auto temp_range = ori_range;
  if (!is_3d_format && ori_range.size() < kNchwDims) {
    MS_LOG(DEBUG) << "A special format:" << format << " with a range size less than 4, so padding the range firstly";
    temp_range = PaddingRangeTo4D(ori_range);
  }
  if (is_3d_format && ori_range.size() < kNcdhwDims) {
    MS_LOG(DEBUG) << "A special format:" << format << " with a range size less than 5, so padding the range firstly";
    temp_range = PaddingRangeTo5D(ori_range);
  }

  const auto iter = format_range_map.find(format);
  if (iter == format_range_map.end()) {
    MS_LOG(WARNING) << "Can not find a supported format: " << format << ", using default range";
    return ori_range;
  }
  return iter->second(temp_range);
}
299 } // namespace
300
IsDynamicShapeNode(const CNodePtr & cnode)301 bool TbeDynamicShapeUtil::IsDynamicShapeNode(const CNodePtr &cnode) {
302 MS_EXCEPTION_IF_NULL(cnode);
303 auto input_num = AnfAlgo ::GetInputTensorNum(cnode);
304 for (size_t i = 0; i < input_num; ++i) {
305 auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, i);
306 if (std::any_of(input_shape.begin(), input_shape.end(), [](const size_t &dim) { return dim < 0; })) {
307 MS_LOG(INFO) << "Node(" << cnode->fullname_with_scope() << ") is dynamic shape node.";
308 return true;
309 }
310 }
311 auto output_num = AnfAlgo ::GetOutputTensorNum(cnode);
312 for (size_t i = 0; i < output_num; ++i) {
313 auto output_shape = AnfAlgo::GetOutputInferShape(cnode, i);
314 if (std::any_of(output_shape.begin(), output_shape.end(), [](const size_t &dim) { return dim < 0; })) {
315 MS_LOG(INFO) << "Node(" << cnode->fullname_with_scope() << ") is dynamic shape node.";
316 return true;
317 }
318 }
319 return false;
320 }
321
IsDynamicShapeNode(const AnfNodePtr & anf_node)322 bool TbeDynamicShapeUtil::IsDynamicShapeNode(const AnfNodePtr &anf_node) {
323 MS_EXCEPTION_IF_NULL(anf_node);
324 if (anf_node->isa<CNode>()) {
325 auto cnode = anf_node->cast<CNodePtr>();
326 MS_EXCEPTION_IF_NULL(cnode);
327 return IsDynamicShapeNode(cnode);
328 }
329 return false;
330 }
331
SetDynamicShapeAttr(const CNodePtr & cnode)332 void TbeDynamicShapeUtil::SetDynamicShapeAttr(const CNodePtr &cnode) {
333 MS_EXCEPTION_IF_NULL(cnode);
334 auto is_dyanmic_shape = IsDynamicShapeNode(cnode);
335 AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(is_dyanmic_shape), cnode);
336 }
337
GetDynamicShapeAttr(const AnfNodePtr & anf_node)338 bool TbeDynamicShapeUtil::GetDynamicShapeAttr(const AnfNodePtr &anf_node) {
339 MS_EXCEPTION_IF_NULL(anf_node);
340 if (anf_node->isa<CNode>()) {
341 auto cnode = anf_node->cast<CNodePtr>();
342 MS_EXCEPTION_IF_NULL(cnode);
343 return GetDynamicShapeAttr(cnode);
344 }
345 return false;
346 }
347
GetDynamicShapeAttr(const CNodePtr & cnode)348 bool TbeDynamicShapeUtil::GetDynamicShapeAttr(const CNodePtr &cnode) {
349 MS_EXCEPTION_IF_NULL(cnode);
350 auto is_dynamic_shape = AnfAlgo::HasNodeAttr(kAttrIsDynamicShape, cnode);
351 if (!is_dynamic_shape) {
352 return false;
353 }
354 is_dynamic_shape = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrIsDynamicShape);
355 return is_dynamic_shape;
356 }
357
FindOp(const std::string & op_name,const AnfNodePtr & anf_node)358 std::shared_ptr<OpInfo> TbeDynamicShapeUtil::FindOp(const std::string &op_name, const AnfNodePtr &anf_node) {
359 MS_EXCEPTION_IF_NULL(anf_node);
360 if (anf_node->isa<CNode>()) {
361 auto cnode = anf_node->cast<CNodePtr>();
362 MS_EXCEPTION_IF_NULL(cnode);
363 return FindOp(op_name, cnode);
364 }
365 return nullptr;
366 }
367
FindOp(const std::string & op_name,const CNodePtr & cnode)368 std::shared_ptr<OpInfo> TbeDynamicShapeUtil::FindOp(const std::string &op_name, const CNodePtr &cnode) {
369 MS_EXCEPTION_IF_NULL(cnode);
370 auto is_dynamic_shape = GetDynamicShapeAttr(cnode);
371 return mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE, is_dynamic_shape);
372 }
373
GetInputDynamicRange(const AnfNodePtr & anf_node,size_t index,const std::string & def_format)374 RangePair TbeDynamicShapeUtil::GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index,
375 const std::string &def_format) {
376 MS_EXCEPTION_IF_NULL(anf_node);
377 auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
378 MS_EXCEPTION_IF_NULL(kernel_info);
379 auto format =
380 kernel_info->select_kernel_build_info() == nullptr ? def_format : AnfAlgo::GetInputFormat(anf_node, index);
381 auto input_range_min = AnfAlgo::GetInputMinShape(anf_node, index);
382 auto input_range_max = AnfAlgo::GetInputMaxShape(anf_node, index);
383 if (input_range_min.size() != input_range_max.size()) {
384 MS_EXCEPTION(ArgumentError) << "Input range size is not equal, min size: " << input_range_min.size()
385 << "max size: " << input_range_max.size();
386 }
387 if (input_range_min.empty() && input_range_max.empty()) {
388 RangePair ret = {{1, 1}};
389 return DynamicShapeRangeTrans(ret, format);
390 }
391 RangePair ret;
392 for (size_t i = 0; i < input_range_min.size(); ++i) {
393 ret.emplace_back(input_range_min[i], input_range_max[i]);
394 }
395 return DynamicShapeRangeTrans(ret, format);
396 }
397
GetOutputDynamicRange(const AnfNodePtr & anf_node,size_t index,const std::string & def_format)398 RangePair TbeDynamicShapeUtil::GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index,
399 const std::string &def_format) {
400 MS_EXCEPTION_IF_NULL(anf_node);
401 auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
402 MS_EXCEPTION_IF_NULL(kernel_info);
403 auto format =
404 kernel_info->select_kernel_build_info() == nullptr ? def_format : AnfAlgo::GetOutputFormat(anf_node, index);
405 auto output_range_min = AnfAlgo::GetOutputMinShape(anf_node, index);
406 auto output_range_max = AnfAlgo::GetOutputMaxShape(anf_node, index);
407 if (output_range_min.size() != output_range_max.size()) {
408 MS_EXCEPTION(ArgumentError) << "Onput range size is not equal, min size: " << output_range_min.size()
409 << "max size: " << output_range_max.size();
410 }
411 if (output_range_max.empty() && output_range_min.empty()) {
412 RangePair ret = {{1, 1}};
413 return DynamicShapeRangeTrans(ret, format);
414 }
415 RangePair ret;
416 for (size_t i = 0; i < output_range_min.size(); ++i) {
417 ret.emplace_back(output_range_min[i], output_range_max[i]);
418 }
419 return DynamicShapeRangeTrans(ret, format);
420 }
421 } // namespace tbe
422 } // namespace kernel
423 } // namespace mindspore
424