• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <map>
22 #include <algorithm>
23 #include <vector>
24 #include "backend/session/anf_runtime_algorithm.h"
25 
26 namespace mindspore {
27 namespace kernel {
28 namespace tbe {
29 namespace {
// Fixed extents/divisors used by the range transfer functions in this file.
constexpr int64_t k16{16};
constexpr int64_t k4{4};
// Number of trailing axes that FracNZRange rewrites.
constexpr int kDims2{2};
// Axis positions inside a 4-axis (N, C, H, W) range; kNchwDims is the axis count.
enum k2Axis : int {
  kN = 0,
  kC,
  kH,
  kW,
  kNchwDims
};
// Axis positions inside a 5-axis (N, C, D, H, W) range; kNcdhwDims is the axis count.
enum k3Axis : int {
  N_ncdhw = 0,
  C_ncdhw,
  D_ncdhw,
  H_ncdhw,
  W_ncdhw,
  kNcdhwDims
};
// Pads a range with at most five axes to a five-axis (N, C, D, H, W) range.
//
// Every axis that is not supplied by ori_range is filled with the pair {1, 1}.
//  - A range with fewer than five axes is placed starting at the C axis
//    (input axis i becomes output axis i + 1); N stays {1, 1}.
//  - An empty range yields five {1, 1} axes. Returning the empty input instead
//    (as was done previously) would let the 5-axis transfer functions index
//    past the end of the range.
//  - A range that already has five axes is copied unchanged, matching what
//    PaddingRangeTo4D does for a full 4-axis input.
// Raises through MS_LOG(EXCEPTION) when ori_range has more than five axes.
RangePair PaddingRangeTo5D(const RangePair &ori_range) {
  RangePair dst_range(kNcdhwDims, std::pair<int64_t, int64_t>(1, 1));
  const size_t ori_size = ori_range.size();
  const size_t full_size = static_cast<size_t>(kNcdhwDims);
  if (ori_size > full_size) {
    MS_LOG(EXCEPTION) << "Unexpected shape size = " << ori_size;
  } else {
    // Shorter inputs skip the N axis; a full-size input maps one-to-one.
    const size_t offset = (ori_size == full_size) ? 0 : static_cast<size_t>(C_ncdhw);
    for (size_t i = 0; i < ori_size; ++i) {
      dst_range[i + offset] = ori_range[i];
    }
  }
  return dst_range;
}
63 
// Pads a range with at most four axes to a four-axis (N, C, H, W) range.
// Axes not supplied by ori_range are {1, 1}. A range with fewer than four axes
// is placed starting at the C axis; a four-axis range is copied unchanged.
// Raises through MS_LOG(EXCEPTION) when ori_range has more than four axes.
RangePair PaddingRangeTo4D(const RangePair &ori_range) {
  RangePair padded(kNchwDims, std::pair<int64_t, int64_t>(1, 1));
  const size_t rank = ori_range.size();
  const size_t full_rank = static_cast<size_t>(kNchwDims);
  if (rank > full_rank) {
    MS_LOG(EXCEPTION) << "Unexpected range size: " << ori_range.size();
  } else {
    // Input axis i lands on output axis i + 1 unless the input is already full.
    const size_t shift = (rank == full_rank) ? 0 : static_cast<size_t>(kC);
    for (size_t axis = 0; axis < rank; ++axis) {
      padded[axis + shift] = ori_range[axis];
    }
  }
  return padded;
}
89 
// Identity transfer: the range is returned unchanged.
RangePair NchwRange(const RangePair &range) {
  return range;
}
91 
// Reorders a 4-axis (N, C, H, W) range into (N, H, W, C) order.
// The caller must supply at least four axes; no size check is done here.
RangePair NhwcRange(const RangePair &range) {
  RangePair nhwc = {range[kN], range[kH], range[kW], range[kC]};
  return nhwc;
}
100 
// Reorders a 4-axis (N, C, H, W) range into (H, W, C, N) order.
// The caller must supply at least four axes; no size check is done here.
RangePair HwchRange(const RangePair &range) {
  RangePair hwcn = {range[kH], range[kW], range[kC], range[kN]};
  return hwcn;
}
109 
// Converts a 4-axis (N, C, H, W) range into the 5-axis order (N, C1, H, W, C0),
// where C1 = ceil(C / 16) for both bounds and C0 is fixed at {16, 16}.
RangePair Nc1hwc0Range(const RangePair &range) {
  const auto ceil_div16 = [](int64_t extent) { return (extent + k16 - 1) / k16; };
  const auto &channel = range[kC];
  const std::pair<int64_t, int64_t> c1(ceil_div16(channel.first), ceil_div16(channel.second));
  const std::pair<int64_t, int64_t> c0(k16, k16);
  RangePair result = {range[kN], c1, range[kH], range[kW], c0};
  return result;
}
121 
// Converts a 4-axis (N, C, H, W) range into the 5-axis order (N, C1, H, W, C0)
// with C1 fixed at {1, 1} and C0 fixed at {4, 4}; the input C axis is not read.
RangePair Nc1hwc04Range(const RangePair &range) {
  const std::pair<int64_t, int64_t> c1(1, 1);
  const std::pair<int64_t, int64_t> c0(k4, k4);
  RangePair result = {range[kN], c1, range[kH], range[kW], c0};
  return result;
}
133 
// Converts a range with at least two axes into the FracNZ range.
// All axes except the last two are kept as they are. The last two axes are
// replaced by four axes: blocks(last), blocks(second-to-last), {16, 16}, {16, 16},
// where blocks(x) = (x - 1) / 16 + 1 is applied to both bounds.
// Raises through MS_LOG(EXCEPTION) when fewer than two axes are given.
RangePair FracNZRange(const RangePair &range) {
  const size_t rank = range.size();
  RangePair dst_range;
  if (rank >= static_cast<size_t>(kDims2)) {
    // Leading axes pass through untouched.
    for (size_t axis = 0; axis + kDims2 < rank; ++axis) {
      dst_range.push_back(range[axis]);
    }
  } else {
    MS_LOG(EXCEPTION) << "Format FracNZ can not support range size: " << range.size();
  }
  const auto blocks = [](int64_t extent) { return (extent - 1) / k16 + 1; };
  const auto &last = range[rank - 1];
  const auto &second_last = range[rank - kDims2];
  const std::pair<int64_t, int64_t> w1(blocks(last.first), blocks(last.second));
  const std::pair<int64_t, int64_t> h1(blocks(second_last.first), blocks(second_last.second));
  const std::pair<int64_t, int64_t> c0(k16, k16);
  dst_range.push_back(w1);
  dst_range.push_back(h1);
  dst_range.push_back(c0);
  dst_range.push_back(c0);
  return dst_range;
}
152 
// Converts a 4-axis (N, C, H, W) range into the 4-axis FracZ range
// (H * W * C16 / 16, N16 / 16, {16, 16}, {16, 16}), where C16 and N16 are the
// C and N bounds rounded up to a multiple of 16.
RangePair FracZRange(const RangePair &range) {
  const auto round_up16 = [](int64_t extent) { return ((extent + k16 - 1) / k16) * k16; };
  const auto &n = range[kN];
  const auto &c = range[kC];
  const auto &h = range[kH];
  const auto &w = range[kW];
  const std::pair<int64_t, int64_t> n_aligned(round_up16(n.first), round_up16(n.second));
  const std::pair<int64_t, int64_t> c_aligned(round_up16(c.first), round_up16(c.second));
  const std::pair<int64_t, int64_t> r0(h.first * w.first * c_aligned.first / k16,
                                       h.second * w.second * c_aligned.second / k16);
  const std::pair<int64_t, int64_t> r1(n_aligned.first / k16, n_aligned.second / k16);
  const std::pair<int64_t, int64_t> c0(k16, k16);
  RangePair result = {r0, r1, c0, c0};
  return result;
}
169 
// Converts a 4-axis (N, C, H, W) range into the 4-axis FracZ-C04 range
// (ceil(4 * H * W / 16), ceil(N / 16), {16, 16}, {16, 16}).
// The input C axis is not read.
RangePair FracZC04Range(const RangePair &range) {
  const auto ceil_div16 = [](int64_t extent) { return (extent + k16 - 1) / k16; };
  const auto &h = range[kH];
  const auto &w = range[kW];
  const auto &n = range[kN];
  const std::pair<int64_t, int64_t> first_dim(ceil_div16(k4 * h.first * w.first),
                                              ceil_div16(k4 * h.second * w.second));
  const std::pair<int64_t, int64_t> no(ceil_div16(n.first), ceil_div16(n.second));
  const std::pair<int64_t, int64_t> c16(k16, k16);
  RangePair result = {first_dim, no, c16, c16};
  return result;
}
183 
// Converts a range whose first two axes are read as (N, C) into the 4-axis
// FracZN-LSTM range:
//   h      = N / 4
//   i      = C - h
//   first  = ceil(i / 16) + ceil(h / 16)
//   second = 4 * ceil(h / 16)
//   result = (first, second, {16, 16}, {16, 16})
// Fix: the two trailing axes were previously emitted as {4, 4} because the
// local named c16 was initialised from k4; they are now {16, 16}, as in
// FracZC04Range.
// NOTE(review): confirm against the device-shape computation for this format.
RangePair FracZNLSTMCRange(const RangePair &range) {
  const auto ceil_div16 = [](int64_t extent) { return (extent + k16 - 1) / k16; };
  const std::pair<int64_t, int64_t> c0 = {k4, k4};
  const std::pair<int64_t, int64_t> c16 = {k16, k16};
  const std::pair<int64_t, int64_t> h = {range[kN].first / c0.first, range[kN].second / c0.second};
  const std::pair<int64_t, int64_t> i = {range[kC].first - h.first, range[kC].second - h.second};
  const std::pair<int64_t, int64_t> first_dim = {ceil_div16(i.first) + ceil_div16(h.first),
                                                 ceil_div16(i.second) + ceil_div16(h.second)};
  const std::pair<int64_t, int64_t> second = {c0.first * ceil_div16(h.first), c0.second * ceil_div16(h.second)};
  RangePair dst_range;
  dst_range.push_back(first_dim);
  dst_range.push_back(second);
  dst_range.push_back(c16);
  dst_range.push_back(c16);
  return dst_range;
}
200 
// Converts a 4-axis (N, C, H, W) range into the 6-axis order
// (C1, H, W, N, {16, 16}, {16, 16}), where C1 = (C - 1) / 16 + 1 for both bounds.
RangePair C1hwncoc0Range(const RangePair &range) {
  const auto &channel = range[kC];
  const std::pair<int64_t, int64_t> r1((channel.first - 1) / k16 + 1, (channel.second - 1) / k16 + 1);
  const std::pair<int64_t, int64_t> c0(k16, k16);
  RangePair result = {r1, range[kH], range[kW], range[kN], c0, c0};
  return result;
}
213 
// Identity transfer: the range is returned unchanged.
RangePair NcdhwRange(const RangePair &range) {
  return range;
}
215 
// Reorders a 5-axis (N, C, D, H, W) range into (N, D, H, W, C) order.
// The caller must supply at least five axes; no size check is done here.
// Fix: the permuted range is now returned. Previously the permuted copy was
// built and then discarded, and the unpermuted input was returned.
RangePair NdhwcRange(const RangePair &range) {
  RangePair dst_range;
  dst_range.push_back(range[N_ncdhw]);
  dst_range.push_back(range[D_ncdhw]);
  dst_range.push_back(range[H_ncdhw]);
  dst_range.push_back(range[W_ncdhw]);
  dst_range.push_back(range[C_ncdhw]);
  return dst_range;
}
225 
// Converts a 5-axis (N, C, D, H, W) range into the 6-axis order
// (N, D, C1, H, W, C0), where C1 = ceil(C / 16) for both bounds and C0 is {16, 16}.
RangePair Ndc1hwc0Range(const RangePair &range) {
  const auto ceil_div16 = [](int64_t extent) { return (extent + k16 - 1) / k16; };
  const auto &channel = range[C_ncdhw];
  const std::pair<int64_t, int64_t> c1(ceil_div16(channel.first), ceil_div16(channel.second));
  const std::pair<int64_t, int64_t> c0(k16, k16);
  RangePair result = {range[N_ncdhw], range[D_ncdhw], c1, range[H_ncdhw], range[W_ncdhw], c0};
  return result;
}
239 
// Converts a 5-axis (N, C, D, H, W) range into the 4-axis FracZ-3D range
// (D * C1 * H * W, N1, {16, 16}, {16, 16}), where C1 = ceil(C / 16) and
// N1 = ceil(N / 16) for both bounds.
// Fix: the third axis is now the fixed {16, 16} block. Previously C1 was
// emitted there, which made this the only FracZ-style transfer in the file
// whose trailing axes were not both the constant block (compare FracZRange).
// NOTE(review): confirm against the device-shape computation for this format.
RangePair FracZ3DRange(const RangePair &range) {
  const auto ceil_div16 = [](int64_t extent) { return (extent + k16 - 1) / k16; };
  const std::pair<int64_t, int64_t> c0 = {k16, k16};
  const std::pair<int64_t, int64_t> c1 = {ceil_div16(range[C_ncdhw].first), ceil_div16(range[C_ncdhw].second)};
  const std::pair<int64_t, int64_t> n1 = {ceil_div16(range[N_ncdhw].first), ceil_div16(range[N_ncdhw].second)};
  const int64_t r1_0 = range[D_ncdhw].first * c1.first * range[H_ncdhw].first * range[W_ncdhw].first;
  const int64_t r1_1 = range[D_ncdhw].second * c1.second * range[H_ncdhw].second * range[W_ncdhw].second;
  const std::pair<int64_t, int64_t> r1 = {r1_0, r1_1};
  RangePair dst_range;
  dst_range.push_back(r1);
  dst_range.push_back(n1);
  dst_range.push_back(c0);
  dst_range.push_back(c0);
  return dst_range;
}
256 
// Translates ori_range into the range that corresponds to the given format.
//
//  - kOpFormat_ND / kOpFormat_DEFAULT: ori_range is returned unchanged.
//  - kOpFormat_FRACTAL_ZN_LSTM and kOpFormat_FRAC_NZ are handled directly,
//    without any padding.
//  - For every other format the range is first padded to four axes (or to five
//    axes when the format is in k3DFormatSet) if it is shorter than that, and
//    then handed to the transfer function registered for the format.
//  - An unregistered format logs a warning and returns ori_range unchanged.
RangePair DynamicShapeRangeTrans(const RangePair &ori_range, const std::string &format) {
  // All transfers are plain functions, so a function pointer is sufficient.
  using RangeTransfer = RangePair (*)(const RangePair &);
  // Built once on first use instead of on every call.
  static const std::map<std::string, RangeTransfer> format_range_map{
    {kOpFormat_NCHW, NchwRange},
    {kOpFormat_NHWC, NhwcRange},
    {kOpFormat_HWCN, HwchRange},
    {kOpFormat_NC1HWC0, Nc1hwc0Range},
    {kOpFormat_NC1HWC0_C04, Nc1hwc04Range},
    {kOpFormat_FRAC_Z, FracZRange},
    {kOpFormat_FRACTAL_Z_C04, FracZC04Range},
    {kOpFormat_C1HWNCoC0, C1hwncoc0Range},
    {kOpFormat_NCDHW, NcdhwRange},
    {kOpFormat_NDHWC, NdhwcRange},
    {kOpFormat_NDC1HWC0, Ndc1hwc0Range},
    {kOpFormat_FRACTAL_Z_3D, FracZ3DRange},
  };

  if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
    return ori_range;
  }
  if (format == kOpFormat_FRACTAL_ZN_LSTM) {
    return FracZNLSTMCRange(ori_range);
  }
  if (format == kOpFormat_FRAC_NZ) {
    return FracNZRange(ori_range);
  }
  const bool is_3d_format = k3DFormatSet.find(format) != k3DFormatSet.end();
  auto temp_range = ori_range;
  if (!is_3d_format && ori_range.size() < kNchwDims) {
    MS_LOG(DEBUG) << "A special format:" << format << " with a range size less than 4, so padding the range firstly";
    temp_range = PaddingRangeTo4D(ori_range);
  } else if (is_3d_format && ori_range.size() < kNcdhwDims) {
    MS_LOG(DEBUG) << "A special format:" << format << " with a range size less than 5, so padding the range firstly";
    temp_range = PaddingRangeTo5D(ori_range);
  }
  const auto iter = format_range_map.find(format);
  if (iter == format_range_map.end()) {
    MS_LOG(WARNING) << "Can not find a supported format: " << format << ", using default range";
    return ori_range;
  }
  return iter->second(temp_range);
}
299 }  // namespace
300 
IsDynamicShapeNode(const CNodePtr & cnode)301 bool TbeDynamicShapeUtil::IsDynamicShapeNode(const CNodePtr &cnode) {
302   MS_EXCEPTION_IF_NULL(cnode);
303   auto input_num = AnfAlgo ::GetInputTensorNum(cnode);
304   for (size_t i = 0; i < input_num; ++i) {
305     auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, i);
306     if (std::any_of(input_shape.begin(), input_shape.end(), [](const size_t &dim) { return dim < 0; })) {
307       MS_LOG(INFO) << "Node(" << cnode->fullname_with_scope() << ") is dynamic shape node.";
308       return true;
309     }
310   }
311   auto output_num = AnfAlgo ::GetOutputTensorNum(cnode);
312   for (size_t i = 0; i < output_num; ++i) {
313     auto output_shape = AnfAlgo::GetOutputInferShape(cnode, i);
314     if (std::any_of(output_shape.begin(), output_shape.end(), [](const size_t &dim) { return dim < 0; })) {
315       MS_LOG(INFO) << "Node(" << cnode->fullname_with_scope() << ") is dynamic shape node.";
316       return true;
317     }
318   }
319   return false;
320 }
321 
IsDynamicShapeNode(const AnfNodePtr & anf_node)322 bool TbeDynamicShapeUtil::IsDynamicShapeNode(const AnfNodePtr &anf_node) {
323   MS_EXCEPTION_IF_NULL(anf_node);
324   if (anf_node->isa<CNode>()) {
325     auto cnode = anf_node->cast<CNodePtr>();
326     MS_EXCEPTION_IF_NULL(cnode);
327     return IsDynamicShapeNode(cnode);
328   }
329   return false;
330 }
331 
SetDynamicShapeAttr(const CNodePtr & cnode)332 void TbeDynamicShapeUtil::SetDynamicShapeAttr(const CNodePtr &cnode) {
333   MS_EXCEPTION_IF_NULL(cnode);
334   auto is_dyanmic_shape = IsDynamicShapeNode(cnode);
335   AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(is_dyanmic_shape), cnode);
336 }
337 
GetDynamicShapeAttr(const AnfNodePtr & anf_node)338 bool TbeDynamicShapeUtil::GetDynamicShapeAttr(const AnfNodePtr &anf_node) {
339   MS_EXCEPTION_IF_NULL(anf_node);
340   if (anf_node->isa<CNode>()) {
341     auto cnode = anf_node->cast<CNodePtr>();
342     MS_EXCEPTION_IF_NULL(cnode);
343     return GetDynamicShapeAttr(cnode);
344   }
345   return false;
346 }
347 
GetDynamicShapeAttr(const CNodePtr & cnode)348 bool TbeDynamicShapeUtil::GetDynamicShapeAttr(const CNodePtr &cnode) {
349   MS_EXCEPTION_IF_NULL(cnode);
350   auto is_dynamic_shape = AnfAlgo::HasNodeAttr(kAttrIsDynamicShape, cnode);
351   if (!is_dynamic_shape) {
352     return false;
353   }
354   is_dynamic_shape = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrIsDynamicShape);
355   return is_dynamic_shape;
356 }
357 
FindOp(const std::string & op_name,const AnfNodePtr & anf_node)358 std::shared_ptr<OpInfo> TbeDynamicShapeUtil::FindOp(const std::string &op_name, const AnfNodePtr &anf_node) {
359   MS_EXCEPTION_IF_NULL(anf_node);
360   if (anf_node->isa<CNode>()) {
361     auto cnode = anf_node->cast<CNodePtr>();
362     MS_EXCEPTION_IF_NULL(cnode);
363     return FindOp(op_name, cnode);
364   }
365   return nullptr;
366 }
367 
FindOp(const std::string & op_name,const CNodePtr & cnode)368 std::shared_ptr<OpInfo> TbeDynamicShapeUtil::FindOp(const std::string &op_name, const CNodePtr &cnode) {
369   MS_EXCEPTION_IF_NULL(cnode);
370   auto is_dynamic_shape = GetDynamicShapeAttr(cnode);
371   return mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE, is_dynamic_shape);
372 }
373 
GetInputDynamicRange(const AnfNodePtr & anf_node,size_t index,const std::string & def_format)374 RangePair TbeDynamicShapeUtil::GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index,
375                                                     const std::string &def_format) {
376   MS_EXCEPTION_IF_NULL(anf_node);
377   auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
378   MS_EXCEPTION_IF_NULL(kernel_info);
379   auto format =
380     kernel_info->select_kernel_build_info() == nullptr ? def_format : AnfAlgo::GetInputFormat(anf_node, index);
381   auto input_range_min = AnfAlgo::GetInputMinShape(anf_node, index);
382   auto input_range_max = AnfAlgo::GetInputMaxShape(anf_node, index);
383   if (input_range_min.size() != input_range_max.size()) {
384     MS_EXCEPTION(ArgumentError) << "Input range size is not equal, min size: " << input_range_min.size()
385                                 << "max size: " << input_range_max.size();
386   }
387   if (input_range_min.empty() && input_range_max.empty()) {
388     RangePair ret = {{1, 1}};
389     return DynamicShapeRangeTrans(ret, format);
390   }
391   RangePair ret;
392   for (size_t i = 0; i < input_range_min.size(); ++i) {
393     ret.emplace_back(input_range_min[i], input_range_max[i]);
394   }
395   return DynamicShapeRangeTrans(ret, format);
396 }
397 
GetOutputDynamicRange(const AnfNodePtr & anf_node,size_t index,const std::string & def_format)398 RangePair TbeDynamicShapeUtil::GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index,
399                                                      const std::string &def_format) {
400   MS_EXCEPTION_IF_NULL(anf_node);
401   auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
402   MS_EXCEPTION_IF_NULL(kernel_info);
403   auto format =
404     kernel_info->select_kernel_build_info() == nullptr ? def_format : AnfAlgo::GetOutputFormat(anf_node, index);
405   auto output_range_min = AnfAlgo::GetOutputMinShape(anf_node, index);
406   auto output_range_max = AnfAlgo::GetOutputMaxShape(anf_node, index);
407   if (output_range_min.size() != output_range_max.size()) {
408     MS_EXCEPTION(ArgumentError) << "Onput range size is not equal, min size: " << output_range_min.size()
409                                 << "max size: " << output_range_max.size();
410   }
411   if (output_range_max.empty() && output_range_min.empty()) {
412     RangePair ret = {{1, 1}};
413     return DynamicShapeRangeTrans(ret, format);
414   }
415   RangePair ret;
416   for (size_t i = 0; i < output_range_min.size(); ++i) {
417     ret.emplace_back(output_range_min[i], output_range_max[i]);
418   }
419   return DynamicShapeRangeTrans(ret, format);
420 }
421 }  // namespace tbe
422 }  // namespace kernel
423 }  // namespace mindspore
424