/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "frontend/expander/bprop/grad_ops/common_utils.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <numeric>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/ms_context.h"
#include "ops/op_utils.h"

namespace mindspore::expander::bprop {
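// Returns a zero gradient for every operator input. The builder's input list also carries
// the forward output and dout at the tail, so those last two entries are skipped.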
NodePtrList ReturnZeros(BpropBuilder *ib) {
  const auto &inputs = ib->GetInputs();
  if (inputs.size() <= kDim2) {
    MS_LOG(EXCEPTION) << "Bprop's inputs size should be greater than 2 (includes out and dout), but got "
                      << inputs.size();
  }
  auto output_num = inputs.size() - kDim2;
  NodePtrList outputs(output_num);
  for (size_t i = 0; i < output_num; ++i) {
    outputs[i] = ib->OutZeros(inputs[i]);
  }
  return outputs;
}

namespace {
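// Analyses two (possibly dynamic) shapes that were broadcast together. For each input it
// reports whether the reduction axes must be computed at runtime (need_shapecalc) and, where
// they can be determined statically, the axes that have to be summed over to undo the broadcast.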
std::pair<std::vector<bool>, std::vector<std::vector<int64_t>>> DynBroadcastGradientArgs(
  const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape) {
  auto x_size = x_shape.size();
  auto y_size = y_shape.size();
  ShapeVector shape[kDim2] = {x_shape, y_shape};
  auto n = std::max(x_size, y_size);
  std::vector<bool> need_shapecalc = {false, false};
  std::vector<std::vector<int64_t>> reduce_axis(kDim2);
  if (IsDynamicRank(shape[0]) || IsDynamicRank(shape[1])) {
    return {{true, true}, reduce_axis};
  }
  for (size_t i = n; i >= 1; i--) {
    int64_t dim_value[2] = {x_size < i ? 1 : shape[0][x_size - i], y_size < i ? 1 : shape[1][y_size - i]};
    const int64_t reduce_idx = SizeToLong(n - i);
    if (dim_value[1] == dim_value[0]) {
      if (dim_value[0] == -1) {
        need_shapecalc[0] = need_shapecalc[1] = true;
        break;
      }
    } else if (dim_value[1] > 0 && dim_value[0] > 0) {
      for (size_t j = 0; j < kDim2; j++) {
        if (dim_value[j] == 1) {
          (void)reduce_axis[j].emplace_back(reduce_idx);
        }
      }
    } else {
      for (size_t j = 0; j < kDim2; j++) {
        if (dim_value[j] == -1) {
          if (dim_value[j ^ 1] == 1) {
            (void)reduce_axis[j ^ 1].emplace_back(reduce_idx);
          } else {
            need_shapecalc[j] = true;
            if (need_shapecalc[j ^ 1] == need_shapecalc[j]) {
              break;
            }
            (void)reduce_axis[j].emplace_back(reduce_idx);
          }
        }
      }
    }
  }
  return {need_shapecalc, reduce_axis};
}

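// Gradient reduction for binary ops whose operands have dynamic shapes. Axes that can be
// determined statically are reduced with SumExt; otherwise BroadcastGradientArgs is emitted
// and the reduction is done with ReduceSum at runtime.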
NodePtrList DynBinopGradCommon(BpropBuilder *ib, const NodePtr &x, const NodePtr &y, const NodePtr &dx,
                               const NodePtr &dy, size_t shift = 0UL) {
  NodePtr inputs[] = {x, y};
  NodePtrList reduce = {dx, dy};
  ShapeVector shape[] = {ib->GetShape(inputs[0]), ib->GetShape(inputs[1])};
  auto [need_shapecalc, reduce_axis] = DynBroadcastGradientArgs(shape[0], shape[1]);
  NodePtrList broadcast_axes;
  if (need_shapecalc[0] || need_shapecalc[1]) {
    broadcast_axes = ib->BroadcastGradientArgs(inputs[0], inputs[1], shift);
  }
  for (size_t i = 0; i < kDim2; i++) {
    auto dout_shape = ib->GetShape(reduce[i]);
    if (!need_shapecalc[i] && IsDynamicRank(dout_shape)) {
      MS_LOG(WARNING) << "The dynamic shape inference of " << reduce[i]->ToString() << " is overly generalized.";
    }
    if (!need_shapecalc[i] && !IsDynamicRank(dout_shape)) {
      if (!reduce_axis[i].empty()) {
        reduce[i] = ib->SumExt(reduce[i], ib->Value<ShapeVector>(reduce_axis[i]),
                               ib->Value<bool>(dout_shape.size() == shape[i].size()));
      }
      if (ib->GetRank(reduce[i]) != shape[i].size()) {
        reduce[i] = ib->Reshape(reduce[i], ib->Shape(inputs[i]));
      }
    } else {
      bool keep_dims = (!IsDynamicRank(shape[0]) && !IsDynamicRank(shape[1]) && shape[i].size() >= shape[i ^ 1].size());
      reduce[i] = ib->ReduceSum(reduce[i], broadcast_axes[i], keep_dims, true);
      reduce[i] = ib->Reshape(reduce[i], ib->Shape(inputs[i]));
    }
  }
  return reduce;
}

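// Picks the promoted dtype of a binary op by comparing type priorities; the complex table
// is used when either operand has a complex type.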
TypeId GetOutputDtype(TypeId t1, TypeId t2, bool use_complex = false) {
  static std::unordered_map<TypeId, int> complex_priority_map{
    {kNumberTypeFloat32, 0}, {kNumberTypeFloat64, 1}, {kNumberTypeComplex64, 2}, {kNumberTypeComplex128, 4}};
  static std::unordered_map<TypeId, int> type_priority_map{
    {kNumberTypeBool, 0},     {kNumberTypeUInt8, 1},   {kNumberTypeInt8, 2},     {kNumberTypeUInt16, 3},
    {kNumberTypeInt16, 4},    {kNumberTypeUInt32, 5},  {kNumberTypeInt32, 6},    {kNumberTypeUInt64, 7},
    {kNumberTypeInt64, 8},    {kNumberTypeFloat16, 9}, {kNumberTypeFloat32, 10}, {kNumberTypeFloat64, 11},
    {kNumberTypeBFloat16, 12}};
  int priority_1 = 0;
  int priority_2 = 0;
  if (use_complex) {
    if (complex_priority_map.find(t1) == complex_priority_map.end() ||
        complex_priority_map.find(t2) == complex_priority_map.end()) {
      MS_EXCEPTION(ValueError) << "Complex binary op type promotion not supported for " << TypeIdToString(t1) << " and "
                               << TypeIdToString(t2);
    }
    priority_1 = complex_priority_map[t1];
    priority_2 = complex_priority_map[t2];
  } else {
    if (type_priority_map.find(t1) == type_priority_map.end() ||
        type_priority_map.find(t2) == type_priority_map.end()) {
      MS_EXCEPTION(ValueError) << "Binary op type promotion not supported for " << TypeIdToString(t1) << " and "
                               << TypeIdToString(t2);
    }
    priority_1 = type_priority_map[t1];
    priority_2 = type_priority_map[t2];
  }
  return (priority_1 > priority_2 ? t1 : t2);
}
}  // namespace

int64_t NormalizeAxis(int64_t axis, size_t rank) {
  auto rank_i = SizeToLong(rank);
  if (axis < -rank_i || axis >= rank_i) {
    MS_EXCEPTION(ValueError) << "For rank " << rank << ", the axis must be in range [" << -rank_i << ", " << rank_i
                             << "), but got " << axis;
  }
  return (axis < 0) ? (axis + rank_i) : axis;
}

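// Builds a permutation that moves the reduced axes to the front and returns it together with
// the packed shape {product of reduced dims, product of remaining dims}.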
std::pair<ShapeVector, ShapeVector> SplitShapeIndex(const ShapeVector &input_shape, const ShapeVector &axis) {
  auto rank = SizeToLong(input_shape.size());
  if (rank == 0) {
    return {};
  }
  std::vector<bool> reduction_indices_map(input_shape.size());
  ShapeVector perm;
  int64_t reduced_num = 1;
  int64_t other_num = 1;
  for (auto i : axis) {
    if (i < 0) {
      i += rank;
    }
    reduction_indices_map[i] = true;
    reduced_num *= input_shape[LongToSize(i)];
    (void)perm.emplace_back(i);
  }
  for (int64_t i = 0; i < rank; i++) {
    if (!reduction_indices_map[i]) {
      other_num *= input_shape[LongToSize(i)];
      (void)perm.emplace_back(i);
    }
  }
  ShapeVector pack_shape{reduced_num, other_num};
  return std::make_pair(pack_shape, perm);
}

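// Element-wise division of two shape tuples; every element of x must be divisible by the
// corresponding element of y.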
std::vector<int64_t> TupleDiv(const std::vector<int64_t> &x, const std::vector<int64_t> &y) {
  std::vector<int64_t> out;
  if (x.size() != y.size()) {
    MS_LOG(EXCEPTION) << "The size of inputs of TupleDiv must be the same, but the size of divisor tuple is "
                      << y.size() << ", the size of dividend tuple is " << x.size() << ".";
  }
  for (size_t i = 0; i < y.size(); i++) {
    if (y[i] == 0) {
      MS_LOG(EXCEPTION) << "The divisor value should not be 0!";
    }
    if ((x[i] % y[i]) != 0) {
      MS_LOG(EXCEPTION) << "The inputs of TupleDiv should be divisible, but they are not divisible now, "
                        << "the dividend is " << x[i] << ", the divisor is " << y[i] << ".";
    }
    out.push_back(x[i] / y[i]);
  }
  return out;
}

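// Shape produced by reducing x over `axis` with keep_dims semantics (reduced dims become 1).
// An empty axis means reduce over all dims, unless skip_mode keeps the shape unchanged.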
std::vector<int64_t> ReduceShape(const std::vector<int64_t> &x, const std::vector<int64_t> &axis, bool skip_mode) {
  if (x.empty()) {
    return {};
  }
  if (axis.empty()) {
    if (skip_mode) {
      return x;
    }
    return std::vector<int64_t>(x.size(), 1LL);
  }
  int64_t x_rank = SizeToLong(x.size());
  std::vector<int64_t> out(x);
  for (auto i : axis) {
    if (i >= x_rank || i < (-x_rank)) {
      MS_LOG(EXCEPTION) << "axis should be in range [" << (-x_rank) << ", " << x_rank << ").";
    }
    if (i < 0) {
      i += x_rank;
    }
    out[i] = 1LL;
  }
  return out;
}

int64_t GetIntValue(const NodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto value = node->BuildValue();
  if (value->isa<tensor::BaseTensor>()) {
    auto t_vec = CheckAndConvertUtils::CheckTensorIntValue("tensor", value, "bprop");
    MS_EXCEPTION_IF_CHECK_FAIL(t_vec.size() >= kIndex1, "Get single tensor value failed");
    return t_vec[kIndex0];
  }
  return AnfUtils::GetIntValue(value);
}

std::vector<int64_t> GetIntList(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<tensor::BaseTensor>()) {
    auto tensor = value->cast<tensor::BaseTensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    tensor->data_sync();
    return CheckAndConvertUtils::CheckTensorIntValue("tensor", value, "bprop");
  } else {
    return CheckAndConvertUtils::CheckIntOrTupleInt("value", value, "bprop");
  }
}

std::vector<int64_t> GetIntList(const NodePtr &node) {
  auto value = node->BuildValue();
  MS_EXCEPTION_IF_NULL(value);
  return GetIntList(value);
}

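// Reduces one operand's gradient of a broadcast binary op when the shapes are static enough to
// infer the reduction axes; sets *is_dynamic_shape when the dynamic path has to be taken instead.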
NodePtr StaticBinopGradCommon(BpropBuilder *ib, const NodePtr &dx, const ShapeArray &shape,
                              const ShapeArray &broadcast_shape, size_t shift, size_t index, bool *is_dynamic_shape) {
  NodePtr reduce_dx = dx;
  auto shape_dynamic_dims = std::count_if(shape[index].begin(), shape[index].end(), [](int64_t x) { return x <= -1; });
  if (broadcast_shape[kIndex0].empty() || broadcast_shape[kIndex1].empty()) {
    if (broadcast_shape[index].empty()) {
      if (shift) {
        std::vector<int64_t> axis(broadcast_shape[index ^ 1].size());
        std::iota(axis.begin(), axis.end(), 0LL);
        reduce_dx = ib->SumExt(reduce_dx, ib->Value<ShapeVector>(axis), ib->Value(false));
      } else {
        reduce_dx = ib->SumExt(reduce_dx, ib->EmitValue(kNone), ib->Value(false));
      }
    }
  } else if (!IsDynamic(broadcast_shape[0]) && !IsDynamic(broadcast_shape[1]) && shape_dynamic_dims <= 1) {
    std::vector<std::vector<int64_t>> bc_axis = BroadcastGradientArgsInferValue(broadcast_shape[0], broadcast_shape[1]);
    if (!bc_axis[index].empty()) {
      reduce_dx = ib->SumExt(reduce_dx, ib->Value<ShapeVector>(bc_axis[index]),
                             ib->Value<bool>(ib->GetRank(reduce_dx) == shape[index].size()));
    }
    reduce_dx = ib->Reshape(reduce_dx, shape[index]);
  } else {
    *is_dynamic_shape = true;
  }
  return reduce_dx;
}

NodePtr MatMulExtBroadCastGradPart(BpropBuilder *ib, const NodePtr &dx, const ShapeArray &shape,
                                   const ShapeArray &broadcast_shape, size_t ignore_offset, size_t index) {
  NodePtr reduce_dx = dx;
  std::vector<std::vector<int64_t>> bc_axis =
    BroadcastGradientArgsInferValue(broadcast_shape[0], broadcast_shape[1], ignore_offset);
  if (!bc_axis[index].empty()) {
    reduce_dx = ib->ReduceSum(reduce_dx, bc_axis[index], ib->GetRank(reduce_dx) == shape[index].size());
  }
  if (ib->GetRank(reduce_dx) != shape[index].size()) {
    reduce_dx = ib->Reshape(reduce_dx, shape[index]);
  }
  return reduce_dx;
}

NodePtrList BinopGradCommon(BpropBuilder *ib, const NodePtr &x, const NodePtr &y, const NodePtr &dx, const NodePtr &dy,
                            size_t shift) {
  // Common grad definition for binary operations with shift.
  // The function is usually used in backprop op to reduce additional dimensions
  // created by broadcasting.
  NodePtrList inputs{x, y};
  ShapeArray shape{ib->GetShape(inputs[kIndex0]), ib->GetShape(inputs[kIndex1])};
  NodePtrList reduce = {dx, dy};
  if (IsDynamicRank(shape[kIndex0]) || IsDynamicRank(shape[kIndex1])) {
    return DynBinopGradCommon(ib, x, y, dx, dy, shift);
  }
  if (shape[kIndex0].size() <= shift && shape[kIndex0].size() == shape[kIndex1].size()) {
    return reduce;
  }
  ShapeArray broadcast_shape(kDim2);
  for (size_t i = 0; i < kDim2; i++) {
    broadcast_shape[i] = ShapeVector(shape[i].begin(), shape[i].end() - shift);
  }
  bool is_x_shape_dynamic = false;
  bool is_y_shape_dynamic = false;
  if (dx != nullptr) {
    reduce[kIndex0] =
      StaticBinopGradCommon(ib, reduce[kIndex0], shape, broadcast_shape, shift, kIndex0, &is_x_shape_dynamic);
  }
  if (dy != nullptr) {
    reduce[kIndex1] =
      StaticBinopGradCommon(ib, reduce[kIndex1], shape, broadcast_shape, shift, kIndex1, &is_y_shape_dynamic);
  }
  if (is_x_shape_dynamic || is_y_shape_dynamic) {
    return DynBinopGradCommon(ib, x, y, dx, dy, shift);
  }
  return reduce;
}

NodePtrList MatMulExtBroadCastGrad(BpropBuilder *ib, const NodePtr &x, const NodePtr &y, const NodePtr &dx,
                                   const NodePtr &dy, size_t ignore_offset) {
  NodePtrList inputs{x, y};
  ShapeArray shape{ib->GetShape(inputs[kIndex0]), ib->GetShape(inputs[kIndex1])};
  NodePtrList reduce = {dx, dy};
  ShapeArray broadcast_shape(kDim2);
  broadcast_shape[0] = shape[0];
  broadcast_shape[1] = shape[1];

  if (dx != nullptr) {
    reduce[kIndex0] = MatMulExtBroadCastGradPart(ib, reduce[kIndex0], shape, broadcast_shape, ignore_offset, kIndex0);
  }
  if (dy != nullptr) {
    reduce[kIndex1] = MatMulExtBroadCastGradPart(ib, reduce[kIndex1], shape, broadcast_shape, ignore_offset, kIndex1);
  }
  return reduce;
}

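// Python-style integer range: start, start + step, ... strictly before stop; empty when the
// step cannot reach stop.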
std::vector<int64_t> Range(int64_t start, int64_t stop, int64_t step) {
  if (step == 0) {
    MS_EXCEPTION(ValueError) << "For Range, step should not be 0";
  }
  auto size = stop - start;
  if (size * step <= 0) {
    return {};
  }
  if (size % step == 0) {
    size = size / step;
  } else {
    size = size / step + 1;
  }
  std::vector<int64_t> range(LongToSize(size));
  for (size_t i = 0; i < range.size(); i++, start += step) {
    range[i] = start;
  }
  return range;
}

std::vector<int64_t> Range(int64_t stop) { return Range(0, stop); }

std::vector<int64_t> GetTransposeAxis(const std::vector<int64_t> &x_shape, int64_t axis) {
  std::vector<int64_t> reverse_axis;
  if (x_shape.empty()) {
    return reverse_axis;
  }
  auto rk = static_cast<int64_t>(x_shape.size());
  if (axis < 0) {
    axis += rk;
  }
  reverse_axis.reserve(x_shape.size());
  for (int64_t i = 0; i < rk; ++i) {
    (void)reverse_axis.emplace_back(i);
  }
  reverse_axis[LongToSize(axis)] = rk - 1;
  reverse_axis[LongToSize(rk - 1)] = axis;
  return reverse_axis;
}

int64_t CheckRange(int64_t idx, int64_t dim_size) {
  if (idx < -dim_size || idx >= dim_size) {
    MS_EXCEPTION(IndexError) << "index {" << idx << "} is out of bounds for dimension with size {" << dim_size << "}";
  }
  return idx < 0 ? (idx + dim_size) : idx;
}

NodePtr GetEps(BpropBuilder *ib, const TypePtr &type) {
  constexpr auto epsilon = 0.000977;
  switch (type->type_id()) {
    case kNumberTypeFloat16:
      return ib->Tensor(epsilon, type);
    case kNumberTypeFloat32:
      return ib->Tensor(std::numeric_limits<float>::epsilon(), type);
    case kNumberTypeFloat64:
      return ib->Tensor(std::numeric_limits<double>::epsilon(), type);
    default:
      return ib->Tensor(0, type);
  }
}

std::vector<int64_t> GenerateInverseIndex(const std::vector<int64_t> &x_shp, int64_t axis_v, int64_t batch_dims) {
  int64_t x_rank = static_cast<int64_t>(x_shp.size());
  auto index = Range(x_rank);
  if (axis_v < 0) {
    axis_v += x_rank;
  }
  std::vector<int64_t> perm;
  auto start1 = x_rank <= 1 ? index.end() : index.begin() + batch_dims + 1;
  auto end1 = axis_v + 1 >= x_rank ? index.end() : index.begin() + axis_v + 1;
  auto start2 = axis_v + 1 >= x_rank ? index.end() : index.begin() + axis_v + 1;
  (void)std::copy(index.begin(), index.begin() + batch_dims, std::back_inserter(perm));
  (void)std::copy(start1, end1, std::back_inserter(perm));
  perm.push_back(batch_dims);
  (void)std::copy(start2, index.end(), std::back_inserter(perm));
  return perm;
}

std::vector<int64_t> GenerateShapeIndex(const std::vector<int64_t> &out_shp, const std::vector<int64_t> &ind_shp,
                                        int64_t axis_v, int64_t batch_dims) {
  int64_t out_rank = static_cast<int64_t>(out_shp.size());
  int64_t ind_rank = static_cast<int64_t>(ind_shp.size());
  if (axis_v < 0) {
    axis_v += out_rank - ind_rank + 1;
  }
  auto perm_part1 = Range(axis_v, axis_v + ind_rank - batch_dims);
  auto index = Range(out_rank);
  std::vector<int64_t> perm;
  auto end = axis_v >= out_rank ? out_rank - 1 : axis_v;
  auto start =
    (axis_v + ind_rank - batch_dims) >= out_rank ? index.end() : (index.begin() + axis_v + ind_rank - batch_dims);
  (void)std::copy(index.begin(), index.begin() + batch_dims, std::back_inserter(perm));
  (void)std::copy(perm_part1.begin(), perm_part1.end(), std::back_inserter(perm));
  (void)std::copy(index.begin() + batch_dims, index.begin() + end, std::back_inserter(perm));
  (void)std::copy(start, index.end(), std::back_inserter(perm));
  return perm;
}

std::vector<int64_t> RegenerateOutputShape(const std::vector<int64_t> &x_shp, const std::vector<int64_t> &ind_shp,
                                           int64_t axis_v, int64_t batch_dims) {
  int64_t rank = static_cast<int64_t>(x_shp.size());
  if (axis_v < 0) {
    axis_v += rank;
  }
  std::vector<int64_t> out_shp;
  auto end = axis_v >= rank ? rank - 1 : axis_v;
  auto start = axis_v + 1 >= rank ? x_shp.end() : x_shp.begin() + axis_v + 1;
  (void)std::copy(x_shp.begin(), x_shp.begin() + end, std::back_inserter(out_shp));
  (void)std::copy(ind_shp.begin() + batch_dims, ind_shp.end(), std::back_inserter(out_shp));
  (void)std::copy(start, x_shp.end(), std::back_inserter(out_shp));
  return out_shp;
}

std::vector<int64_t> InvertPermutation(const std::vector<int64_t> &perm) {
  std::vector<int64_t> check_perm(perm);
  std::vector<int64_t> res(perm);
  if (res.empty()) {
    return res;
  }
  std::sort(check_perm.begin(), check_perm.end());
  int64_t perm_size = static_cast<int64_t>(check_perm.size());
  for (int64_t i = 0; i < perm_size; i++) {
    auto idx = LongToSize(i);
    if (check_perm[idx] != i) {
      MS_LOG(EXCEPTION) << "For InvertPermutation, the input_x should be '[0-" << (perm_size - 1) << "]', but got "
                        << check_perm;
    }
    res[LongToSize(perm[idx])] = i;
  }
  return res;
}

std::vector<int64_t> GetTransposition(int64_t axis, int64_t rank) {
  if (axis < 0) {
    axis += rank;
  }
  auto trans = Range(axis);
  auto after_axis = Range(axis + 1, rank - 1);
  trans.push_back(rank - 1);
  (void)trans.insert(trans.end(), after_axis.begin(), after_axis.end());
  trans.push_back(axis);
  return trans;
}

class ReduceShapeShapeCalc : public ShapeCalcFunctor {
 public:
  // cppcheck-suppress unknownMacro
  DECLARE_SHAPE_CALC("ShapeCalc_ReduceShape", ReduceShapeShapeCalc)
  explicit ReduceShapeShapeCalc(bool skip_mode = false)
      : ShapeCalcFunctor("ShapeCalc_ReduceShape"), skip_mode_(skip_mode) {}
  ValuePtr ToValue() const override { return MakeValue(skip_mode_); }
  void FromValue(const ValuePtr &value) override { skip_mode_ = GetValue<bool>(value); }
  ShapeArray Calc(const ShapeArray &inputs) const override {
    auto x_shape = inputs.at(0);
    auto axis_value = inputs.at(1);
    auto r_shape = ReduceShape(x_shape, axis_value, skip_mode_);
    auto scaling = TupleDiv(x_shape, r_shape);
    return {r_shape, scaling};
  }
  std::vector<int64_t> Infer(const ShapeArray &inputs, const HashSet<size_t> &) const override {
    int64_t x_rank = IsDynamicRank(inputs.at(0)) ? -1 : static_cast<int64_t>(inputs.at(0).size());
    return {x_rank, x_rank};
  }

 protected:
  bool skip_mode_ = false;
};
REG_FUNCTOR("ShapeCalc_ReduceShape", ReduceShapeShapeCalc);

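// Gradient of a sum reduction: restores the reduced dims of dout (when keep_dims is false) and
// broadcasts it back to the shape of x, via Tile when the scaling is constant or x is dynamic,
// BroadcastTo otherwise.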
NodePtr SumGrad(Emitter *ib, const NodePtr &x, const NodePtr &axis, const NodePtr &dout, bool keep_dims,
                bool skip_mode) {
  auto grad = dout;
  auto calc_res = ib->ShapeCalc(std::make_shared<ReduceShapeShapeCalc>(skip_mode), {x, axis}, {1});
  if (!keep_dims) {
    grad = ib->Reshape(grad, calc_res[0]);
  }
  auto tile_scaling = calc_res[1];
  if (tile_scaling->input_type() == InputType::kConstant || IsDynamic(x->shape())) {
    return ib->Tile(grad, tile_scaling);
  }
  return ib->BroadcastTo(grad, x);
}

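// Gradient of a min/max reduction: routes the incoming gradient to the positions that equal the
// reduction result, divided evenly among ties.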
NodePtr MinOrMaxGrad(BpropBuilder *ib, const NodePtr &x, const NodePtr &axis, const NodePtr &keep_dims,
                     const NodePtr &out, const NodePtr &dout) {
  auto y = out;
  auto grad = dout;
  auto keepdims = GetValue<bool>(keep_dims->BuildValue());
  if (!keepdims) {
    auto output_shape_kept_dims = ib->ShapeCalc(std::make_shared<ReduceShapeShapeCalc>(), {x, axis}, {1})[0];
    y = ib->Reshape(out, output_shape_kept_dims);
    grad = ib->Reshape(dout, output_shape_kept_dims);
  }
  auto indicators = ib->Cast(ib->Equal(y, x), ib->GetDtype(grad));
  auto num_selected = ib->ReduceSum(indicators, axis, true, false);
  return indicators / num_selected * grad;
}

inline NodePtr TensorScatterElementsZeroDim(Emitter *ib, const NodePtr &input, const ValuePtr &dim,
                                            const NodePtr &index, const NodePtr &src,
                                            const std::string &reduce_string) {
  // TensorScatterElements op: a zero-dim input needs to be expanded to one dim first.
  auto input_expand = ib->ExpandDims(input, -1);
  auto index_expand = ib->ExpandDims(index, -1);
  auto src_expand = ib->ExpandDims(src, -1);
  auto out = ib->Emit("TensorScatterElements", {input_expand, index_expand, src_expand},
                      {{"reduction", MakeValue<string>(reduce_string)}, {"axis", dim}});
  // Recover the one-dim result back to zero dim.
  return ib->Squeeze(out, MakeValue(ShapeVector{0}));
}

inline NodePtr TensorScatterElements(Emitter *ib, const NodePtr &input, const ValuePtr &dim, const NodePtr &index,
                                     const NodePtr &src, const std::string &reduce_string) {
  return ib->Emit("TensorScatterElements", {input, index, src},
                  {{"reduction", MakeValue<string>(reduce_string)}, {"axis", dim}});
}

NodePtr Scatter_(BpropBuilder *ib, const NodePtr &input, const NodePtr &dim, const NodePtr &index, const NodePtr &src,
                 const std::string &reduce_string) {
  auto dim_val = dim->BuildValue();
  if (!ops::IsValueKnown(dim_val)) {
    MS_EXCEPTION(ValueError) << "For `TensorScatterElements` op, the `axis` must currently be a constant!";
  }
  auto input_shape = ib->GetShape(input);
  if (input_shape.size() == 0) {
    return TensorScatterElementsZeroDim(ib, input, dim_val, index, src, reduce_string);
  } else if (IsDynamicRank(input_shape)) {
    auto rank = ib->Emit("Rank", {input});
    auto is_zero_dim_cond = ib->Emit("scalar_eq", {rank, ib->Value<int64_t>(0)});
    auto scatter_zero_dim_impl = [&input, &dim_val, &index, &src, &reduce_string](Emitter *e) -> NodePtrList {
      return {TensorScatterElementsZeroDim(e, input, dim_val, index, src, reduce_string)};
    };
    auto scatter_impl = [&input, &dim_val, &index, &src, &reduce_string](Emitter *e) -> NodePtrList {
      return {TensorScatterElements(e, input, dim_val, index, src, reduce_string)};
    };
    return ib->Conditional(is_zero_dim_cond, scatter_zero_dim_impl, scatter_impl);
  }
  return TensorScatterElements(ib, input, dim_val, index, src, reduce_string);
}

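// Gradient of argmin/argmax-with-value ops: scatters the value gradient of the (index, value)
// output pair back onto a zero tensor shaped like x at the selected indices.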
NodePtr ArgminOrArgmaxGrad(BpropBuilder *ib, const NodePtr &x, const NodePtr &axis, const NodePtr &keep_dims,
                           const NodePtr &out, const NodePtr &dout, const bool is_max) {
  auto keep_dims_value = keep_dims->BuildValue();
  NodePtr dout_value = ib->TupleGetItem(dout, 1);
  NodePtr indices = ib->TupleGetItem(out, 0);
  auto input_shape = ib->GetShape(x);
  if (ops::IsValueKnown(keep_dims_value) && !IsDynamicRank(input_shape)) {
    auto is_zero_dim = input_shape.size() == 0;
    auto keep_dims_bool = GetValue<bool>(keep_dims_value);
    indices = (keep_dims_bool || is_zero_dim) ? indices : ib->Emit("ExpandDims", {indices, axis});
    dout_value = (keep_dims_bool || is_zero_dim) ? dout_value : ib->Emit("ExpandDims", {dout_value, axis});
  } else {
    auto rank = ib->Emit("Rank", {x});
    auto rank_is_zero = ib->Emit("scalar_eq", {rank, ib->Value<int64_t>(0)});
    auto cond = ib->LogicalOr(ib->ScalarToTensor(keep_dims, kBool), ib->ScalarToTensor(rank_is_zero, kBool));
    auto indices_expand = [&indices, &axis](Emitter *e) -> NodePtrList {
      return {e->Emit("ExpandDims", {indices, axis})};
    };
    auto indices_ori = [&indices](Emitter *e) -> NodePtrList { return {indices}; };
    indices = ib->Conditional(cond, indices_ori, indices_expand);
    auto dout_expand = [&dout_value, &axis](Emitter *e) -> NodePtrList {
      return {e->Emit("ExpandDims", {dout_value, axis})};
    };
    auto dout_ori = [&dout_value](Emitter *e) -> NodePtrList { return {dout_value}; };
    dout_value = ib->Conditional(cond, dout_ori, dout_expand);
  }
  NodePtr dx_zeros = ib->Zeros(x);
  auto dx = Scatter_(ib, dx_zeros, axis, indices, dout_value, "none");
  return dx;
}

TypeId PromoteBinaryDtype(TypeId t1, TypeId t2) {
  if (t1 == t2) {
    return t1;
  }
  static std::unordered_set<TypeId> complex_types{kNumberTypeComplex64, kNumberTypeComplex128};
  return GetOutputDtype(
    t1, t2, (complex_types.find(t1) != complex_types.end() || complex_types.find(t2) != complex_types.end()));
}

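// Log-gamma of x evaluated with the Lanczos approximation; inputs below 0.5 go through the
// reflection formula, and non-positive integers produce infinity.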
NodePtr LGamma(BpropBuilder *ib, const NodePtr &x) {
  auto k_lanczos_gamma = 7;
  auto k_base_lanczos_coeff = 0.9999999999998099;
  double k_lanczos_coefficients[8] = {676.520368121885098567009190444019, -1259.13921672240287047156078755283,
                                      771.3234287776530788486528258894,   -176.61502916214059906584551354,
                                      12.507343278686904814458936853,     -0.13857109526572011689554707,
                                      9.984369578019570859563e-6,         1.50563273514931155834e-7};
  auto input_dtype = ib->GetDtype(x);
  auto one_half = ib->Tensor(0.5, input_dtype);
  auto one = ib->Tensor(1, input_dtype);
  auto zero = ib->Tensor(0, input_dtype);
  auto log_sqrt_two_pi = ib->Tensor((log_2 + log_pi) / 2, input_dtype);
  auto lanczos_gamma_plus_one_half = k_lanczos_gamma + 0.5;
  auto log_lanczos_gamma_plus_one_half = log(lanczos_gamma_plus_one_half);
  auto inf = std::numeric_limits<double>::infinity();
  auto infinity = ib->Fill(inf, ib->Shape(x), input_dtype->type_id());
  auto need_to_reflect = ib->Less(x, one_half);
  auto neg_input = ib->Neg(x);
  auto z = ib->Select(need_to_reflect, neg_input, ib->Sub(x, one));
  auto CalculateReflectedX = [&ib, &z, &k_base_lanczos_coeff, &k_lanczos_coefficients]() -> NodePtr {
    auto z_dtype = ib->GetDtype(z);
    NodePtr reflex_x = ib->Tensor(k_base_lanczos_coeff, z_dtype);
    for (int i = 0; i < 8; ++i) {
      auto btmp = ib->Add(z, ib->Tensor(i, z_dtype));
      btmp = ib->Add(btmp, (ib->Tensor(1, z_dtype)));
      auto product = ib->RealDiv((ib->Tensor(k_lanczos_coefficients[i], z_dtype)), btmp);
      reflex_x = ib->Add(product, reflex_x);
    }
    return reflex_x;
  };
  auto reflex_x = CalculateReflectedX();
  auto lanczos_tensor = ib->Tensor(lanczos_gamma_plus_one_half, input_dtype);
  auto log_lanczos_tensor = ib->Tensor(log_lanczos_gamma_plus_one_half, input_dtype);
  auto t = ib->Add(z, lanczos_tensor);
  auto log_t = ib->Add((ib->Emit("Log1p", {ib->RealDiv(z, lanczos_tensor)})), log_lanczos_tensor);
  auto log_y = ib->Add(
    (ib->Add((ib->Log(reflex_x)), (ib->Mul((ib->Sub((ib->Add(z, one_half)), (ib->RealDiv(t, log_t)))), log_t)))),
    log_sqrt_two_pi);
  auto abs_input = ib->Emit("Abs", {x});
  auto abs_frac_input = ib->Sub(abs_input, (ib->Emit("Floor", {abs_input})));
  auto new_x = ib->Select(ib->LessEqual(x, zero), ib->Select(ib->Equal(abs_frac_input, zero), infinity, x), x);
  auto reduced_frac_input =
    ib->Select(ib->Greater(abs_frac_input, one_half), ib->Sub(one, abs_frac_input), abs_frac_input);
  auto reflection_denom =
    ib->Log(ib->Emit("Sin", {ib->Mul(ib->Tensor(pi, ib->GetDtype(reduced_frac_input)), reduced_frac_input)}));
  auto reflection =
    ib->Select(ib->Emit("IsFinite", {reflection_denom}),
               ib->Add((ib->Sub((ib->Neg(reflection_denom)), log_y)), ib->Tensor(log_pi, ib->GetDtype(log_y))),
               ib->Neg(reflection_denom));
  auto result = ib->Select(need_to_reflect, reflection, log_y);
  return ib->Select(ib->Emit("IsFinite", {new_x}), result, infinity);
}

bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_types) {
  return std::any_of(template_types.begin(), template_types.end(), [&check_type](const TypePtr &accept) -> bool {
    return IsIdentidityOrSubclass(check_type, accept);
  });
}

ShapeVector PoolToNHWC(const ShapeVector &v) {
  ShapeVector new_v(v);
  new_v[kIndex1] = v[kIndex2];
  new_v[kIndex2] = v[kIndex3];
  new_v[kIndex3] = v[kIndex1];
  return new_v;
}

ShapeVector ConvToNHWC(const ShapeVector &v) {
  ShapeVector new_v(v);
  new_v[kIndex0] = v[kIndex1];
  new_v[kIndex1] = v[kIndex2];
  new_v[kIndex2] = v[kIndex3];
  new_v[kIndex3] = 1;
  return new_v;
}

ShapeVector GetShapeByRange(const ShapeVector &v, int64_t begin, int64_t end) {
  // Get range [begin, end) in v.
  auto rank = SizeToLong(v.size());
  auto real_begin = std::min((begin < 0) ? (rank + begin) : begin, rank);
  auto real_end = std::min((end < 0) ? (rank + end) : end, rank);
  ShapeVector res(v.begin() + real_begin, v.begin() + real_end);
  return res;
}

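// Swaps the last two dimensions of x. For dynamic-rank inputs the permutation is built at
// runtime from StridedSlice pieces; otherwise a static permutation is used.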
NodePtr MatrixTranspose(BpropBuilder *ib, const NodePtr &x) {
  auto shape = ib->GetShape(x);
  if (IsDynamicRank(shape)) {
    auto dim = ib->Emit("Rank", {x});
    auto perm = ib->Range(dim);
    auto stridedslice_helper = [&perm, &ib](int64_t begin, int64_t end, int64_t step, int64_t end_mask = 0) {
      return ib->Emit("StridedSlice",
                      {perm, ib->Value<ShapeVector>(ShapeVector{begin}), ib->Value<ShapeVector>(ShapeVector{end}),
                       ib->Value<ShapeVector>(ShapeVector{step}), ib->Value<int64_t>(0LL), ib->Value<int64_t>(end_mask),
                       ib->Value<int64_t>(0LL), ib->Value<int64_t>(0LL), ib->Value<int64_t>(0LL)});
    };
    auto part_1 = stridedslice_helper(0, -2, 1);
    auto part_2 = stridedslice_helper(-1, 0, 1, 1);
    auto part_3 = stridedslice_helper(-2, -1, 1);
    perm = ib->Concat({part_1, part_2, part_3}, -1);
    return ib->Transpose(x, ib->TensorToTuple(perm));
  }
  auto dim = shape.size();
  if (dim < kDim2) {
    MS_LOG_EXCEPTION << "For MatrixTranspose, input's ndim " << dim << " is less than 2, which is invalid";
  }
  std::vector<int64_t> perm(dim);
  for (size_t i = 0; i < dim; i++) {
    perm[i] = static_cast<int64_t>(i);
  }
  std::swap(perm[dim - kIndex2], perm[dim - kIndex1]);
  return ib->Transpose(x, perm);
}

NodePtr MatrixTransposeExt(BpropBuilder *ib, const NodePtr &x) {
  auto shape = ib->GetShape(x);
  if (IsDynamicRank(shape)) {
    auto dim = ib->Emit("Rank", {x});
    auto perm = ib->Range(dim);
    auto stridedslice_helper = [&perm, &ib](int64_t begin, int64_t end, int64_t step, int64_t end_mask = 0) {
      return ib->Emit("StridedSlice",
                      {perm, ib->Value<ShapeVector>(ShapeVector{begin}), ib->Value<ShapeVector>(ShapeVector{end}),
                       ib->Value<ShapeVector>(ShapeVector{step}), ib->Value<int64_t>(0LL), ib->Value<int64_t>(end_mask),
                       ib->Value<int64_t>(0LL), ib->Value<int64_t>(0LL), ib->Value<int64_t>(0LL)});
    };
    auto part_1 = stridedslice_helper(0, -2, 1);
    auto part_2 = stridedslice_helper(-1, 0, 1, 1);
    auto part_3 = stridedslice_helper(-2, -1, 1);
    perm = ib->Concat({part_1, part_2, part_3}, -1);
    return ib->Transpose(x, ib->TensorToTuple(perm));
  }
  auto dim = shape.size();
  if (dim < kDim2) {
    return x;
  }
  std::vector<int64_t> perm(dim);
  for (size_t i = 0; i < dim; i++) {
    perm[i] = static_cast<int64_t>(i);
  }
  std::swap(perm[dim - kIndex2], perm[dim - kIndex1]);
  return ib->Transpose(x, perm);
}

NodePtr Adjoint(BpropBuilder *ib, const NodePtr &x) { return MatrixTranspose(ib, ib->Conj(x)); }
}  // namespace mindspore::expander::bprop