1 /**
2 * Copyright 2022-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "frontend/expander/bprop/grad_ops/common_utils.h"
17
18 #include <algorithm>
19 #include <limits>
20 #include <memory>
21 #include <set>
22 #include <unordered_map>
23 #include <unordered_set>
24 #include <utility>
25 #include <vector>
26 #include "utils/anf_utils.h"
27 #include "utils/check_convert_utils.h"
28 #include "utils/ms_context.h"
29 #include "ops/op_utils.h"
30
31 namespace mindspore::expander::bprop {
ReturnZeros(BpropBuilder * ib)32 NodePtrList ReturnZeros(BpropBuilder *ib) {
33 const auto &inputs = ib->GetInputs();
34 if (inputs.size() <= kDim2) {
35 MS_LOG(EXCEPTION) << "Bprop's inputs size should be greater than 2 (includes out and dout), but got "
36 << inputs.size();
37 }
38 auto output_num = inputs.size() - kDim2;
39 NodePtrList outputs(output_num);
40 for (size_t i = 0; i < output_num; ++i) {
41 outputs[i] = ib->OutZeros(inputs[i]);
42 }
43 return outputs;
44 }
45
46 namespace {
// For a binary op whose operands broadcast together, statically determine (as
// far as the shapes allow) which axes of each operand's gradient must be
// summed away. Shapes may contain -1 for unknown dims or be dynamic-rank.
// Returns a pair:
//   first  - per-operand flag; true means the reduce axes could not be fully
//            determined here and must be computed at runtime.
//   second - per-operand axis lists, indexed in the broadcast result's rank.
std::pair<std::vector<bool>, std::vector<std::vector<int64_t>>> DynBroadcastGradientArgs(
  const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape) {
  auto x_size = x_shape.size();
  auto y_size = y_shape.size();
  ShapeVector shape[kDim2] = {x_shape, y_shape};
  auto n = std::max(x_size, y_size);
  std::vector<bool> need_shapecalc = {false, false};
  std::vector<std::vector<int64_t>> reduce_axis(kDim2);
  // Unknown rank: nothing can be decided statically for either side.
  if (IsDynamicRank(shape[0]) || IsDynamicRank(shape[1])) {
    return {{true, true}, reduce_axis};
  }
  // Walk dims from the innermost (i == 1) to the outermost (i == n), aligning
  // the two shapes on their trailing dimensions.
  for (size_t i = n; i >= 1; i--) {
    // Leading dims that a shorter shape lacks broadcast as 1.
    int64_t dim_value[2] = {x_size < i ? 1 : shape[0][x_size - i], y_size < i ? 1 : shape[1][y_size - i]};
    const int64_t reduce_idx = SizeToLong(n - i);
    if (dim_value[1] == dim_value[0]) {
      if (dim_value[0] == -1) {
        // Both dims unknown: impossible to tell which side broadcast.
        need_shapecalc[0] = need_shapecalc[1] = true;
        break;
      }
    } else if (dim_value[1] > 0 && dim_value[0] > 0) {
      // Both dims known and different: the side that is 1 was broadcast, so
      // its gradient is reduced over this output axis.
      for (size_t j = 0; j < kDim2; j++) {
        if (dim_value[j] == 1) {
          (void)reduce_axis[j].emplace_back(reduce_idx);
        }
      }
    } else {
      // Exactly one dim is unknown (-1).
      for (size_t j = 0; j < kDim2; j++) {
        if (dim_value[j] == -1) {
          if (dim_value[j ^ 1] == 1) {
            // The known side is 1: that side definitely broadcast here.
            (void)reduce_axis[j ^ 1].emplace_back(reduce_idx);
          } else {
            // Unknown vs. known(>1): this side's axes need a runtime calc.
            need_shapecalc[j] = true;
            if (need_shapecalc[j ^ 1] == need_shapecalc[j]) {
              break;
            }
            (void)reduce_axis[j].emplace_back(reduce_idx);
          }
        }
      }
    }
  }
  return {need_shapecalc, reduce_axis};
}
90
DynBinopGradCommon(BpropBuilder * ib,const NodePtr & x,const NodePtr & y,const NodePtr & dx,const NodePtr & dy,size_t shift=0UL)91 NodePtrList DynBinopGradCommon(BpropBuilder *ib, const NodePtr &x, const NodePtr &y, const NodePtr &dx,
92 const NodePtr &dy, size_t shift = 0UL) {
93 NodePtr inputs[] = {x, y};
94 NodePtrList reduce = {dx, dy};
95 ShapeVector shape[] = {ib->GetShape(inputs[0]), ib->GetShape(inputs[1])};
96 auto [need_shapecalc, reduce_axis] = DynBroadcastGradientArgs(shape[0], shape[1]);
97 NodePtrList broadcast_axes;
98 if (need_shapecalc[0] || need_shapecalc[1]) {
99 broadcast_axes = ib->BroadcastGradientArgs(inputs[0], inputs[1], shift);
100 }
101 for (size_t i = 0; i < kDim2; i++) {
102 auto dout_shape = ib->GetShape(reduce[i]);
103 if (!need_shapecalc[i] && IsDynamicRank(dout_shape)) {
104 MS_LOG(WARNING) << "The dynamic shape inference of" << reduce[i]->ToString() << " is overly generalized.";
105 }
106 if (!need_shapecalc[i] && !IsDynamicRank(dout_shape)) {
107 if (!reduce_axis[i].empty()) {
108 reduce[i] = ib->SumExt(reduce[i], ib->Value<ShapeVector>(reduce_axis[i]),
109 ib->Value<bool>(dout_shape.size() == shape[i].size()));
110 }
111 if (ib->GetRank(reduce[i]) != shape[i].size()) {
112 reduce[i] = ib->Reshape(reduce[i], ib->Shape(inputs[i]));
113 }
114 } else {
115 bool keep_dims = (!IsDynamicRank(shape[0]) && !IsDynamicRank(shape[1]) && shape[i].size() >= shape[i ^ 1].size());
116 reduce[i] = ib->ReduceSum(reduce[i], broadcast_axes[i], keep_dims, true);
117 reduce[i] = ib->Reshape(reduce[i], ib->Shape(inputs[i]));
118 }
119 }
120 return reduce;
121 }
122
GetOutputDtype(TypeId t1,TypeId t2,bool use_complex=false)123 TypeId GetOutputDtype(TypeId t1, TypeId t2, bool use_complex = false) {
124 static std::unordered_map<TypeId, int> complex_priority_map{
125 {kNumberTypeFloat32, 0}, {kNumberTypeFloat32, 1}, {kNumberTypeComplex64, 2}, {kNumberTypeComplex128, 4}};
126 static std::unordered_map<TypeId, int> type_priority_map{
127 {kNumberTypeBool, 0}, {kNumberTypeUInt8, 1}, {kNumberTypeInt8, 2}, {kNumberTypeUInt16, 3},
128 {kNumberTypeInt16, 4}, {kNumberTypeUInt32, 5}, {kNumberTypeInt32, 6}, {kNumberTypeUInt64, 7},
129 {kNumberTypeInt64, 8}, {kNumberTypeFloat16, 9}, {kNumberTypeFloat32, 10}, {kNumberTypeFloat64, 11},
130 {kNumberTypeBFloat16, 12}};
131 int priority_1 = 0;
132 int priority_2 = 0;
133 if (use_complex) {
134 if (complex_priority_map.find(t1) == complex_priority_map.end() ||
135 complex_priority_map.find(t2) == complex_priority_map.end()) {
136 MS_EXCEPTION(ValueError) << "Complex binary op type promotion not supported for " << TypeIdToString(t1) << " and "
137 << TypeIdToString(t2);
138 }
139 priority_1 = complex_priority_map[t1];
140 priority_2 = complex_priority_map[t2];
141 } else {
142 if (type_priority_map.find(t1) == type_priority_map.end() ||
143 type_priority_map.find(t2) == type_priority_map.end()) {
144 MS_EXCEPTION(ValueError) << "Binary op type promotion not supported for " << TypeIdToString(t1) << " and "
145 << TypeIdToString(t2);
146 }
147 priority_1 = type_priority_map[t1];
148 priority_2 = type_priority_map[t2];
149 }
150 return (priority_1 > priority_2 ? t1 : t2);
151 }
152 } // namespace
153
NormalizeAxis(int64_t axis,size_t rank)154 int64_t NormalizeAxis(int64_t axis, size_t rank) {
155 auto rank_i = SizeToLong(rank);
156 if (axis < -rank_i || axis >= rank_i) {
157 MS_EXCEPTION(ValueError) << "For rank " << rank << ", the axis must be in range [" << -rank_i << ", " << rank_i
158 << "), but got " << axis;
159 }
160 return (axis < 0) ? (axis + rank_i) : axis;
161 }
162
SplitShapeIndex(const ShapeVector & input_shape,const ShapeVector & axis)163 std::pair<ShapeVector, ShapeVector> SplitShapeIndex(const ShapeVector &input_shape, const ShapeVector &axis) {
164 auto rank = SizeToLong(input_shape.size());
165 if (rank == 0) {
166 return {};
167 }
168 std::vector<bool> reduction_indices_map(input_shape.size());
169 ShapeVector perm;
170 int64_t reduced_num = 1;
171 int64_t other_num = 1;
172 for (auto i : axis) {
173 if (i < 0) {
174 i += rank;
175 }
176 reduction_indices_map[i] = True;
177 reduced_num *= input_shape[LongToSize(i)];
178 (void)perm.emplace_back(i);
179 }
180 for (int64_t i = 0; i < rank; i++) {
181 if (!reduction_indices_map[i]) {
182 other_num *= input_shape[LongToSize(i)];
183 (void)perm.emplace_back(i);
184 }
185 }
186 ShapeVector pack_shape{reduced_num, other_num};
187 return std::make_pair(pack_shape, perm);
188 }
189
TupleDiv(const std::vector<int64_t> & x,const std::vector<int64_t> & y)190 std::vector<int64_t> TupleDiv(const std::vector<int64_t> &x, const std::vector<int64_t> &y) {
191 std::vector<int64_t> out;
192 if (x.size() != y.size()) {
193 MS_LOG(EXCEPTION) << "The size of inputs of TupleDiv must be the same, but the size of divisor tuple is"
194 << " " << y.size() << ", the size of dividend tuple is " << x.size() << ".";
195 }
196 for (size_t i = 0; i < y.size(); i++) {
197 if (y[i] == 0) {
198 MS_LOG(EXCEPTION) << "The divisor value should not be 0!";
199 }
200 if ((x[i] % y[i]) != 0) {
201 MS_LOG(EXCEPTION) << "The inputs of TupleDiv should be divisible, but they are not divisible now, "
202 << "the dividend is " << x[i] << ", the divisor is " << y[i] << ".";
203 }
204 out.push_back(x[i] / y[i]);
205 }
206 return out;
207 }
208
ReduceShape(const std::vector<int64_t> & x,const std::vector<int64_t> & axis,bool skip_mode)209 std::vector<int64_t> ReduceShape(const std::vector<int64_t> &x, const std::vector<int64_t> &axis, bool skip_mode) {
210 if (x.empty()) {
211 return {};
212 }
213 if (axis.empty()) {
214 if (skip_mode) {
215 return x;
216 }
217 return std::vector<int64_t>(x.size(), 1LL);
218 }
219 int64_t x_rank = SizeToLong(x.size());
220 std::vector<int64_t> out(x);
221 for (auto i : axis) {
222 if (i >= x_rank || i < (-x_rank)) {
223 MS_LOG(EXCEPTION) << "axis should be in range [" << (-x_rank) << ", " << x_rank << ").";
224 }
225 if (i < 0) {
226 i += x_rank;
227 }
228 out[i] = 1LL;
229 }
230 return out;
231 }
232
// Extracts a single int64 from a node's constant value.
// Accepts either a scalar integer value or a tensor holding at least one
// element; for tensors the first element is returned.
int64_t GetIntValue(const NodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto value = node->BuildValue();
  if (value->isa<tensor::BaseTensor>()) {
    auto t_vec = CheckAndConvertUtils::CheckTensorIntValue("tensor", value, "bprop");
    MS_EXCEPTION_IF_CHECK_FAIL(t_vec.size() >= kIndex1, "Get single tensor value failed");
    return t_vec[kIndex0];
  }
  return AnfUtils::GetIntValue(value);
}
243
// Converts a constant value into a vector of int64.
// Tensor values are synced to host before reading; non-tensor values must be
// an int or a tuple of ints.
std::vector<int64_t> GetIntList(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<tensor::BaseTensor>()) {
    auto tensor = value->cast<tensor::BaseTensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    // Make sure device-side data is available on host before conversion.
    tensor->data_sync();
    return CheckAndConvertUtils::CheckTensorIntValue("tensor", value, "bprop");
  } else {
    return CheckAndConvertUtils::CheckIntOrTupleInt("value", value, "bprop");
  }
}
255
GetIntList(const NodePtr & node)256 std::vector<int64_t> GetIntList(const NodePtr &node) {
257 auto value = node->BuildValue();
258 MS_EXCEPTION_IF_NULL(value);
259 return GetIntList(value);
260 }
261
// Static-shape half of BinopGradCommon: reduces the gradient `dx` of operand
// `index` according to statically known broadcast shapes. `shift` trailing
// dims are excluded from broadcasting. When the shapes are too dynamic to
// resolve here, *is_dynamic_shape is set and dx is returned unchanged so the
// caller can fall back to the dynamic path.
NodePtr StaticBinopGradCommon(BpropBuilder *ib, const NodePtr &dx, const ShapeArray &shape,
                              const ShapeArray &broadcast_shape, size_t shift, size_t index, bool *is_dynamic_shape) {
  NodePtr reduce_dx = dx;
  // Count of unknown (-1) dims in this operand's shape.
  auto shape_dynamic_dims = std::count_if(shape[index].begin(), shape[index].end(), [](int64_t x) { return x <= -1; });
  if (broadcast_shape[kIndex0].empty() || broadcast_shape[kIndex1].empty()) {
    // One operand is a scalar after stripping the shifted trailing dims.
    if (broadcast_shape[index].empty()) {
      if (shift) {
        // Sum over all broadcast dims of the other operand, leaving the
        // shifted trailing dims intact.
        std::vector<int64_t> axis(broadcast_shape[index ^ 1].size());
        std::iota(axis.begin(), axis.end(), 0LL);
        reduce_dx = ib->SumExt(reduce_dx, ib->Value<ShapeVector>(axis), ib->Value(false));
      } else {
        // Full reduction down to a scalar.
        reduce_dx = ib->SumExt(reduce_dx, ib->EmitValue(kNone), ib->Value(false));
      }
    }
  } else if (!IsDynamic(broadcast_shape[0]) && !IsDynamic(broadcast_shape[1]) && shape_dynamic_dims <= 1) {
    // Fully static broadcast shapes: reduce axes computed at compile time.
    std::vector<std::vector<int64_t>> bc_axis = BroadcastGradientArgsInferValue(broadcast_shape[0], broadcast_shape[1]);
    if (!bc_axis[index].empty()) {
      reduce_dx = ib->SumExt(reduce_dx, ib->Value<ShapeVector>(bc_axis[index]),
                             ib->Value<bool>(ib->GetRank(reduce_dx) == shape[index].size()));
    }
    reduce_dx = ib->Reshape(reduce_dx, shape[index]);
  } else {
    // Too dynamic to handle statically; signal the caller.
    *is_dynamic_shape = true;
  }
  return reduce_dx;
}
288
MatMulExtBroadCastGradPart(BpropBuilder * ib,const NodePtr & dx,const ShapeArray & shape,const ShapeArray & broadcast_shape,size_t ignore_offset,size_t index)289 NodePtr MatMulExtBroadCastGradPart(BpropBuilder *ib, const NodePtr &dx, const ShapeArray &shape,
290 const ShapeArray &broadcast_shape, size_t ignore_offset, size_t index) {
291 NodePtr reduce_dx = dx;
292 std::vector<std::vector<int64_t>> bc_axis =
293 BroadcastGradientArgsInferValue(broadcast_shape[0], broadcast_shape[1], ignore_offset);
294 if (!bc_axis[index].empty()) {
295 reduce_dx = ib->ReduceSum(reduce_dx, bc_axis[index], ib->GetRank(reduce_dx) == shape[index].size());
296 }
297 if (ib->GetRank(reduce_dx) != shape[index].size()) {
298 reduce_dx = ib->Reshape(reduce_dx, shape[index]);
299 }
300 return reduce_dx;
301 }
302
// Common grad definition for binary operations with shift.
// The function is usually used in backprop op to reduce additional dimensions
// created by broadcasting: dx/dy are summed over the broadcast axes and
// reshaped back to the operand shapes. `shift` excludes that many trailing
// dims from broadcasting. Falls back to DynBinopGradCommon whenever ranks or
// shapes are unknown at compile time.
NodePtrList BinopGradCommon(BpropBuilder *ib, const NodePtr &x, const NodePtr &y, const NodePtr &dx, const NodePtr &dy,
                            size_t shift) {
  NodePtrList inputs{x, y};
  ShapeArray shape{ib->GetShape(inputs[kIndex0]), ib->GetShape(inputs[kIndex1])};
  NodePtrList reduce = {dx, dy};
  if (IsDynamicRank(shape[kIndex0]) || IsDynamicRank(shape[kIndex1])) {
    return DynBinopGradCommon(ib, x, y, dx, dy, shift);
  }
  if (shape[kIndex0].size() <= shift && shape[kIndex0].size() == shape[kIndex1].size()) {
    // Nothing outside the shifted dims: no broadcasting happened.
    return reduce;
  }
  // Broadcast shapes exclude the `shift` trailing dims.
  ShapeArray broadcast_shape(kDim2);
  for (size_t i = 0; i < kDim2; i++) {
    broadcast_shape[i] = ShapeVector(shape[i].begin(), shape[i].end() - shift);
  }
  bool is_x_shape_dynamic = false;
  bool is_y_shape_dynamic = false;
  if (dx != nullptr) {
    reduce[kIndex0] =
      StaticBinopGradCommon(ib, reduce[kIndex0], shape, broadcast_shape, shift, kIndex0, &is_x_shape_dynamic);
  }
  if (dy != nullptr) {
    reduce[kIndex1] =
      StaticBinopGradCommon(ib, reduce[kIndex1], shape, broadcast_shape, shift, kIndex1, &is_y_shape_dynamic);
  }
  if (is_x_shape_dynamic || is_y_shape_dynamic) {
    // Static reduction failed for at least one operand; redo dynamically.
    return DynBinopGradCommon(ib, x, y, dx, dy, shift);
  }
  return reduce;
}
336
MatMulExtBroadCastGrad(BpropBuilder * ib,const NodePtr & x,const NodePtr & y,const NodePtr & dx,const NodePtr & dy,size_t ignore_offset)337 NodePtrList MatMulExtBroadCastGrad(BpropBuilder *ib, const NodePtr &x, const NodePtr &y, const NodePtr &dx,
338 const NodePtr &dy, size_t ignore_offset) {
339 NodePtrList inputs{x, y};
340 ShapeArray shape{ib->GetShape(inputs[kIndex0]), ib->GetShape(inputs[kIndex1])};
341 NodePtrList reduce = {dx, dy};
342 ShapeArray broadcast_shape(kDim2);
343 broadcast_shape[0] = shape[0];
344 broadcast_shape[1] = shape[1];
345
346 if (dx != nullptr) {
347 reduce[kIndex0] = MatMulExtBroadCastGradPart(ib, reduce[kIndex0], shape, broadcast_shape, ignore_offset, kIndex0);
348 }
349 if (dy != nullptr) {
350 reduce[kIndex1] = MatMulExtBroadCastGradPart(ib, reduce[kIndex1], shape, broadcast_shape, ignore_offset, kIndex1);
351 }
352 return reduce;
353 }
354
Range(int64_t start,int64_t stop,int64_t step)355 std::vector<int64_t> Range(int64_t start, int64_t stop, int64_t step) {
356 if (step == 0) {
357 MS_EXCEPTION(ValueError) << "For Range, step should not be 0";
358 }
359 auto size = stop - start;
360 if (size * step <= 0) {
361 return {};
362 }
363 if (size % step == 0) {
364 size = size / step;
365 } else {
366 size = size / step + 1;
367 }
368 std::vector<int64_t> range(LongToSize(size));
369 for (size_t i = 0; i < range.size(); i++, start += step) {
370 range[i] = start;
371 }
372 return range;
373 }
374
// Convenience overload: the sequence [0, stop).
std::vector<int64_t> Range(int64_t stop) { return Range(0, stop); }
376
GetTransposeAxis(const std::vector<int64_t> & x_shape,int64_t axis)377 std::vector<int64_t> GetTransposeAxis(const std::vector<int64_t> &x_shape, int64_t axis) {
378 std::vector<int64_t> reverse_axis;
379 if (x_shape.empty()) {
380 return reverse_axis;
381 }
382 auto rk = static_cast<int64_t>(x_shape.size());
383 if (axis < 0) {
384 axis += rk;
385 }
386 reverse_axis.reserve(x_shape.size());
387 for (int64_t i = 0; i < rk; ++i) {
388 (void)reverse_axis.emplace_back(i);
389 }
390 reverse_axis[LongToSize(axis)] = rk - 1;
391 reverse_axis[LongToSize(rk - 1)] = axis;
392 return reverse_axis;
393 }
394
CheckRange(int64_t idx,int64_t dim_size)395 int64_t CheckRange(int64_t idx, int64_t dim_size) {
396 if (idx < -dim_size || idx >= dim_size) {
397 MS_EXCEPTION(IndexError) << "index {" << idx << "} is out of bounds for dimension with size {" << dim_size << "}";
398 }
399 return idx < 0 ? (idx + dim_size) : idx;
400 }
401
GetEps(BpropBuilder * ib,const TypePtr & type)402 NodePtr GetEps(BpropBuilder *ib, const TypePtr &type) {
403 constexpr auto epsilon = 0.000977;
404 switch (type->type_id()) {
405 case kNumberTypeFloat16:
406 return ib->Tensor(epsilon, type);
407 case kNumberTypeFloat32:
408 return ib->Tensor(std::numeric_limits<float>::epsilon(), type);
409 case kNumberTypeFloat64:
410 return ib->Tensor(std::numeric_limits<double>::epsilon(), type);
411 default:
412 return ib->Tensor(0, type);
413 }
414 }
415
// Builds the permutation that moves the dimension currently at position
// `batch_dims` back to position `axis_v`, keeping the leading batch dims in
// place. Used to undo the axis motion performed in Gather-style gradients.
std::vector<int64_t> GenerateInverseIndex(const std::vector<int64_t> &x_shp, int64_t axis_v, int64_t batch_dims) {
  int64_t x_rank = static_cast<int64_t>(x_shp.size());
  auto index = Range(x_rank);
  if (axis_v < 0) {
    axis_v += x_rank;
  }
  std::vector<int64_t> perm;
  // Guard iterators so ranks <= 1 or axis at the end degenerate to empty spans.
  auto start1 = x_rank <= 1 ? index.end() : index.begin() + batch_dims + 1;
  auto end1 = axis_v + 1 >= x_rank ? index.end() : index.begin() + axis_v + 1;
  auto start2 = axis_v + 1 >= x_rank ? index.end() : index.begin() + axis_v + 1;
  // perm = [0..batch_dims) ++ (batch_dims..axis_v] ++ [batch_dims] ++ (axis_v..rank)
  (void)std::copy(index.begin(), index.begin() + batch_dims, std::back_inserter(perm));
  (void)std::copy(start1, end1, std::back_inserter(perm));
  perm.push_back(batch_dims);
  (void)std::copy(start2, index.end(), std::back_inserter(perm));
  return perm;
}
432
// Builds the permutation that moves the indices dimensions of a Gather output
// (rank `ind_rank`, minus batch dims) to position `axis_v`, keeping batch
// dims first. Inverse companion of GenerateInverseIndex.
std::vector<int64_t> GenerateShapeIndex(const std::vector<int64_t> &out_shp, const std::vector<int64_t> &ind_shp,
                                        int64_t axis_v, int64_t batch_dims) {
  int64_t out_rank = static_cast<int64_t>(out_shp.size());
  int64_t ind_rank = static_cast<int64_t>(ind_shp.size());
  if (axis_v < 0) {
    // Negative axis is relative to the original input rank
    // (= out_rank - ind_rank + 1).
    axis_v += out_rank - ind_rank + 1;
  }
  // The block of indices dims that currently lives at position axis_v.
  auto perm_part1 = Range(axis_v, axis_v + ind_rank - batch_dims);
  auto index = Range(out_rank);
  std::vector<int64_t> perm;
  auto end = axis_v >= out_rank ? out_rank - 1 : axis_v;
  auto start =
    (axis_v + ind_rank - batch_dims) >= out_rank ? index.end() : (index.begin() + axis_v + ind_rank - batch_dims);
  // perm = batch dims ++ indices-dims block ++ dims before axis ++ dims after.
  (void)std::copy(index.begin(), index.begin() + batch_dims, std::back_inserter(perm));
  (void)std::copy(perm_part1.begin(), perm_part1.end(), std::back_inserter(perm));
  (void)std::copy(index.begin() + batch_dims, index.begin() + end, std::back_inserter(perm));
  (void)std::copy(start, index.end(), std::back_inserter(perm));
  return perm;
}
452
// Rebuilds a Gather-style output shape: the input shape with the dimension at
// `axis_v` replaced by the indices shape (minus its leading batch dims).
std::vector<int64_t> RegenerateOutputShape(const std::vector<int64_t> &x_shp, const std::vector<int64_t> &ind_shp,
                                           int64_t axis_v, int64_t batch_dims) {
  int64_t rank = static_cast<int64_t>(x_shp.size());
  if (axis_v < 0) {
    axis_v += rank;
  }
  std::vector<int64_t> out_shp;
  // Clamp so an axis at/after the end degenerates safely.
  auto end = axis_v >= rank ? rank - 1 : axis_v;
  auto start = axis_v + 1 >= rank ? x_shp.end() : x_shp.begin() + axis_v + 1;
  // out = x[:axis] ++ ind[batch_dims:] ++ x[axis+1:]
  (void)std::copy(x_shp.begin(), x_shp.begin() + end, std::back_inserter(out_shp));
  (void)std::copy(ind_shp.begin() + batch_dims, ind_shp.end(), std::back_inserter(out_shp));
  (void)std::copy(start, x_shp.end(), std::back_inserter(out_shp));
  return out_shp;
}
467
InvertPermutation(const std::vector<int64_t> & perm)468 std::vector<int64_t> InvertPermutation(const std::vector<int64_t> &perm) {
469 std::vector<int64_t> check_perm(perm);
470 std::vector<int64_t> res(perm);
471 if (res.empty()) {
472 return res;
473 }
474 std::sort(check_perm.begin(), check_perm.end());
475 int64_t perm_size = static_cast<int64_t>(check_perm.size());
476 for (int64_t i = 0; i < perm_size; i++) {
477 auto idx = LongToSize(i);
478 if (check_perm[idx] != i) {
479 MS_LOG(EXCEPTION) << "For InvertPermutation, the input_x should be '[0-" << (perm_size - 1) << "]', but got "
480 << check_perm;
481 }
482 res[LongToSize(perm[idx])] = i;
483 }
484 return res;
485 }
486
GetTransposition(int64_t axis,int64_t rank)487 std::vector<int64_t> GetTransposition(int64_t axis, int64_t rank) {
488 if (axis < 0) {
489 axis += rank;
490 }
491 auto trans = Range(axis);
492 auto after_axis = Range(axis + 1, rank - 1);
493 trans.push_back(rank - 1);
494 (void)trans.insert(trans.end(), after_axis.begin(), after_axis.end());
495 trans.push_back(axis);
496 return trans;
497 }
498
499 class ReduceShapeShapeCalc : public ShapeCalcFunctor {
500 public:
501 // cppcheck-suppress unknownMacro
502 DECLARE_SHAPE_CALC("ShapeCalc_ReduceShape", ReduceShapeShapeCalc)
ReduceShapeShapeCalc(bool skip_mode)503 explicit ReduceShapeShapeCalc(bool skip_mode) : ShapeCalcFunctor("ShapeCalc_ReduceShape"), skip_mode_(skip_mode) {}
ToValue() const504 ValuePtr ToValue() const override { return MakeValue(skip_mode_); }
FromValue(const ValuePtr & value)505 void FromValue(const ValuePtr &value) override { skip_mode_ = GetValue<int64_t>(value); }
Calc(const ShapeArray & inputs) const506 ShapeArray Calc(const ShapeArray &inputs) const override {
507 auto x_shape = inputs.at(0);
508 auto axis_value = inputs.at(1);
509 auto r_shape = ReduceShape(x_shape, axis_value, skip_mode_);
510 auto scaling = TupleDiv(x_shape, r_shape);
511 return {r_shape, scaling};
512 }
Infer(const ShapeArray & inputs,const HashSet<size_t> &) const513 std::vector<int64_t> Infer(const ShapeArray &inputs, const HashSet<size_t> &) const override {
514 int64_t x_rank = IsDynamicRank(inputs.at(0)) ? -1 : static_cast<int64_t>(inputs.at(0).size());
515 return {x_rank, x_rank};
516 }
517
518 protected:
519 bool skip_mode_ = false;
520 };
521 REG_FUNCTOR("ShapeCalc_ReduceShape", ReduceShapeShapeCalc);
522
// Gradient of a sum reduction: reshape `dout` to the kept-dims shape (when
// keep_dims was false) and expand it back to x's shape, via Tile when the
// multiples are constant or x is dynamic, else via BroadcastTo.
NodePtr SumGrad(Emitter *ib, const NodePtr &x, const NodePtr &axis, const NodePtr &dout, bool keep_dims,
                bool skip_mode) {
  auto grad = dout;
  // calc_res[0]: x's shape with reduced axes set to 1; calc_res[1]: the tile
  // multiples needed to restore x's shape.
  auto calc_res = ib->ShapeCalc(std::make_shared<ReduceShapeShapeCalc>(skip_mode), {x, axis}, {1});
  if (!keep_dims) {
    grad = ib->Reshape(grad, calc_res[0]);
  }
  auto tile_scaling = calc_res[1];
  if (tile_scaling->input_type() == InputType::kConstant || IsDynamic(x->shape())) {
    return ib->Tile(grad, tile_scaling);
  }
  return ib->BroadcastTo(grad, x);
}
536
MinOrMaxGrad(BpropBuilder * ib,const NodePtr & x,const NodePtr & axis,const NodePtr & keep_dims,const NodePtr & out,const NodePtr & dout)537 NodePtr MinOrMaxGrad(BpropBuilder *ib, const NodePtr &x, const NodePtr &axis, const NodePtr &keep_dims,
538 const NodePtr &out, const NodePtr &dout) {
539 auto y = out;
540 auto grad = dout;
541 auto keepdims = GetValue<bool>(keep_dims->BuildValue());
542 if (!keepdims) {
543 auto output_shape_kept_dims = ib->ShapeCalc(std::make_shared<ReduceShapeShapeCalc>(), {x, axis}, {1})[0];
544 y = ib->Reshape(out, output_shape_kept_dims);
545 grad = ib->Reshape(dout, output_shape_kept_dims);
546 }
547 auto indicators = ib->Cast(ib->Equal(y, x), ib->GetDtype(grad));
548 auto num_selected = ib->ReduceSum(indicators, axis, true, false);
549 return indicators / num_selected * grad;
550 }
551
TensorScatterElementsZeroDim(Emitter * ib,const NodePtr & input,const ValuePtr & dim,const NodePtr & index,const NodePtr & src,const std::string & reduce_string)552 inline NodePtr TensorScatterElementsZeroDim(Emitter *ib, const NodePtr &input, const ValuePtr &dim,
553 const NodePtr &index, const NodePtr &src,
554 const std::string &reduce_string) {
555 // TensorScatterElements op: ZeroDim need to expand to OneDim
556 auto input_expand = ib->ExpandDims(input, -1);
557 auto index_expand = ib->ExpandDims(index, -1);
558 auto src_expand = ib->ExpandDims(src, -1);
559 auto out = ib->Emit("TensorScatterElements", {input_expand, index_expand, src_expand},
560 {{"reduction", MakeValue<string>(reduce_string)}, {"axis", dim}});
561 // recover OneDim To ZeroDim
562 return ib->Squeeze(out, MakeValue(ShapeVector{0}));
563 }
564
// Thin wrapper emitting the TensorScatterElements op with the given reduction
// mode and (constant) axis attribute.
inline NodePtr TensorScatterElements(Emitter *ib, const NodePtr &input, const ValuePtr &dim, const NodePtr &index,
                                     const NodePtr &src, const std::string &reduce_string) {
  return ib->Emit("TensorScatterElements", {input, index, src},
                  {{"reduction", MakeValue<string>(reduce_string)}, {"axis", dim}});
}
570
// Dispatches a scatter with the given reduction mode to TensorScatterElements,
// handling 0-D inputs (unsupported by the op) via expand/squeeze. The axis
// `dim` must be a compile-time constant.
NodePtr Scatter_(BpropBuilder *ib, const NodePtr &input, const NodePtr &dim, const NodePtr &index, const NodePtr &src,
                 const std::string &reduce_string) {
  auto dim_val = dim->BuildValue();
  if (!ops::IsValueKnown(dim_val)) {
    MS_EXCEPTION(ValueError) << "For `TensorScatterElements` op, the `axis` must currently be a constant!";
  }
  auto input_shape = ib->GetShape(input);
  if (input_shape.size() == 0) {
    return TensorScatterElementsZeroDim(ib, input, dim_val, index, src, reduce_string);
  } else if (IsDynamicRank(input_shape)) {
    // Rank unknown at compile time: branch at runtime on rank == 0.
    auto rank = ib->Emit("Rank", {input});
    auto is_zero_dim_cond = ib->Emit("scalar_eq", {rank, ib->Value<int64_t>(0)});
    auto scatter_zero_dim_impl = [&input, &dim_val, &index, &src, &reduce_string](Emitter *e) -> NodePtrList {
      return {TensorScatterElementsZeroDim(e, input, dim_val, index, src, reduce_string)};
    };
    auto scatter_impl = [&input, &dim_val, &index, &src, &reduce_string](Emitter *e) -> NodePtrList {
      return {TensorScatterElements(e, input, dim_val, index, src, reduce_string)};
    };
    return ib->Conditional(is_zero_dim_cond, scatter_zero_dim_impl, scatter_impl);
  }
  return TensorScatterElements(ib, input, dim_val, index, src, reduce_string);
}
593
// Gradient of argmin/argmax-with-value ops: scatters the value part of `dout`
// back onto the winning indices of `x`. `out` and `dout` are (indices, value)
// tuples. NOTE: `is_max` is not referenced in this implementation — min and
// max currently share one formula.
NodePtr ArgminOrArgmaxGrad(BpropBuilder *ib, const NodePtr &x, const NodePtr &axis, const NodePtr &keep_dims,
                           const NodePtr &out, const NodePtr &dout, const bool is_max) {
  auto keep_dims_value = keep_dims->BuildValue();
  NodePtr dout_value = ib->TupleGetItem(dout, 1);
  NodePtr indices = ib->TupleGetItem(out, 0);
  auto input_shape = ib->GetShape(x);
  if (ops::IsValueKnown(keep_dims_value) && !IsDynamicRank(input_shape)) {
    // Static path: re-insert the reduced dim unless it was kept (or x is 0-D).
    auto is_zero_dim = input_shape.size() == 0;
    auto keep_dims_bool = GetValue<bool>(keep_dims_value);
    indices = (keep_dims_bool || is_zero_dim) ? indices : ib->Emit("ExpandDims", {indices, axis});
    dout_value = (keep_dims_bool || is_zero_dim) ? dout_value : ib->Emit("ExpandDims", {dout_value, axis});
  } else {
    // Dynamic path: decide at runtime whether the expand is needed.
    auto rank = ib->Emit("Rank", {x});
    auto rank_is_zero = ib->Emit("scalar_eq", {rank, ib->Value<int64_t>(0)});
    auto cond = ib->LogicalOr(ib->ScalarToTensor(keep_dims, kBool), ib->ScalarToTensor(rank_is_zero, kBool));
    auto indices_expand = [&indices, &axis](Emitter *e) -> NodePtrList {
      return {e->Emit("ExpandDims", {indices, axis})};
    };
    auto indices_ori = [&indices](Emitter *e) -> NodePtrList { return {indices}; };
    indices = ib->Conditional(cond, indices_ori, indices_expand);
    auto dout_expand = [&dout_value, &axis](Emitter *e) -> NodePtrList {
      return {e->Emit("ExpandDims", {dout_value, axis})};
    };
    auto dout_ori = [&dout_value](Emitter *e) -> NodePtrList { return {dout_value}; };
    dout_value = ib->Conditional(cond, dout_ori, dout_expand);
  }
  // Scatter the gradient values onto a zero tensor at the winning positions.
  NodePtr dx_zeros = ib->Zeros(x);
  auto dx = Scatter_(ib, dx_zeros, axis, indices, dout_value, "none");
  return dx;
}
624
PromoteBinaryDtype(TypeId t1,TypeId t2)625 TypeId PromoteBinaryDtype(TypeId t1, TypeId t2) {
626 if (t1 == t2) {
627 return t1;
628 }
629 static std::unordered_set<TypeId> complex_types{kNumberTypeComplex64, kNumberTypeComplex128};
630 return GetOutputDtype(
631 t1, t2, (complex_types.find(t1) != complex_types.end() || complex_types.find(t2) != complex_types.end()));
632 }
633
// Emits log(|Gamma(x)|) using the Lanczos approximation (g = 7, 8 coefficients)
// for x >= 0.5 and the reflection formula
//   lgamma(x) = log(pi / |sin(pi * x)|) - lgamma(1 - x)
// for x < 0.5. Non-positive integers map to +inf.
NodePtr LGamma(BpropBuilder *ib, const NodePtr &x) {
  auto k_lanczos_gamma = 7;
  auto k_base_lanczos_coeff = 0.9999999999998099;
  double k_lanczos_coefficients[8] = {676.520368121885098567009190444019, -1259.13921672240287047156078755283,
                                      771.3234287776530788486528258894, -176.61502916214059906584551354,
                                      12.507343278686904814458936853, -0.13857109526572011689554707,
                                      9.984369578019570859563e-6, 1.50563273514931155834e-7};
  auto input_dtype = ib->GetDtype(x);
  auto one_half = ib->Tensor(0.5, input_dtype);
  auto one = ib->Tensor(1, input_dtype);
  auto zero = ib->Tensor(0, input_dtype);
  auto log_sqrt_two_pi = ib->Tensor((log_2 + log_pi) / 2, input_dtype);
  auto lanczos_gamma_plus_one_half = k_lanczos_gamma + 0.5;
  auto log_lanczos_gamma_plus_one_half = log(lanczos_gamma_plus_one_half);
  auto inf = std::numeric_limits<double>::infinity();
  auto infinity = ib->Fill(inf, ib->Shape(x), input_dtype->type_id());
  // x < 0.5 triggers the reflection formula.
  auto need_to_reflect = ib->Less(x, one_half);
  auto neg_input = ib->Neg(x);
  // z = -x on the reflected branch, x - 1 otherwise (Lanczos expects z >= 0).
  auto z = ib->Select(need_to_reflect, neg_input, ib->Sub(x, one));
  // Lanczos partial-fraction series: A(z) = c0 + sum_i c_i / (z + i + 1).
  auto CalculateReflectedX = [&ib, &z, &k_base_lanczos_coeff, &k_lanczos_coefficients]() -> NodePtr {
    auto z_dtype = ib->GetDtype(z);
    NodePtr reflex_x = ib->Tensor(k_base_lanczos_coeff, z_dtype);
    for (int i = 0; i < 8; ++i) {
      auto btmp = ib->Add(z, ib->Tensor(i, z_dtype));
      btmp = ib->Add(btmp, (ib->Tensor(1, z_dtype)));
      auto product = ib->RealDiv((ib->Tensor(k_lanczos_coefficients[i], z_dtype)), btmp);
      reflex_x = ib->Add(product, reflex_x);
    }
    return reflex_x;
  };
  auto reflex_x = CalculateReflectedX();
  auto lanczos_tensor = ib->Tensor(lanczos_gamma_plus_one_half, input_dtype);
  auto log_lanczos_tensor = ib->Tensor(log_lanczos_gamma_plus_one_half, input_dtype);
  // t = z + g + 0.5; log_t computed via log1p for accuracy at small z.
  auto t = ib->Add(z, lanczos_tensor);
  auto log_t = ib->Add((ib->Emit("Log1p", {ib->RealDiv(z, lanczos_tensor)})), log_lanczos_tensor);
  // log_y = log(A(z)) + (z + 0.5 - t/log_t) * log_t + log(sqrt(2*pi)).
  auto log_y = ib->Add(
    (ib->Add((ib->Log(reflex_x)), (ib->Mul((ib->Sub((ib->Add(z, one_half)), (ib->RealDiv(t, log_t)))), log_t)))),
    log_sqrt_two_pi);
  // Detect non-positive integers: |x| has zero fractional part and x <= 0.
  auto abs_input = ib->Emit("Abs", {x});
  auto abs_frac_input = ib->Sub(abs_input, (ib->Emit("Floor", {abs_input})));
  auto new_x = ib->Select(ib->LessEqual(x, zero), ib->Select(ib->Equal(abs_frac_input, zero), infinity, x), x);
  // Fold the fractional part into [0, 0.5] so sin(pi*frac) stays accurate.
  auto reduced_frac_input =
    ib->Select(ib->Greater(abs_frac_input, one_half), ib->Sub(one, abs_frac_input), abs_frac_input);
  auto reflection_denom =
    ib->Log(ib->Emit("Sin", {ib->Mul(ib->Tensor(pi, ib->GetDtype(reduced_frac_input)), reduced_frac_input)}));
  // Reflection: log(pi) - log|sin(pi*x)| - lgamma(1 - x); if the denominator
  // is not finite (x at a pole), propagate its negation instead.
  auto reflection =
    ib->Select(ib->Emit("IsFinite", {reflection_denom}),
               ib->Add((ib->Sub((ib->Neg(reflection_denom)), log_y)), ib->Tensor(log_pi, ib->GetDtype(log_y))),
               ib->Neg(reflection_denom));
  auto result = ib->Select(need_to_reflect, reflection, log_y);
  // Poles (non-positive integers) yield +inf.
  return ib->Select(ib->Emit("IsFinite", {new_x}), result, infinity);
}
686
CheckType(const TypePtr & check_type,const std::set<TypePtr> & template_types)687 bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_types) {
688 return std::any_of(template_types.begin(), template_types.end(), [&check_type](const TypePtr &accept) -> bool {
689 return IsIdentidityOrSubclass(check_type, accept);
690 });
691 }
692
// Reorders the channel-related entries of a pooling attribute vector from
// NCHW order to NHWC order: positions 1..3 become (H, W, C).
ShapeVector PoolToNHWC(const ShapeVector &v) {
  ShapeVector transposed(v);
  transposed[kIndex1] = v[kIndex2];
  transposed[kIndex2] = v[kIndex3];
  transposed[kIndex3] = v[kIndex1];
  return transposed;
}
700
// Rewrites a convolution attribute vector for NHWC layout: entries shift one
// position forward and the last entry is forced to 1.
ShapeVector ConvToNHWC(const ShapeVector &v) {
  ShapeVector rewritten(v);
  rewritten[kIndex0] = v[kIndex1];
  rewritten[kIndex1] = v[kIndex2];
  rewritten[kIndex2] = v[kIndex3];
  rewritten[kIndex3] = 1;
  return rewritten;
}
709
// Returns the sub-shape v[begin:end) with Python-style negative indexing.
// Indices are now clamped into [0, rank] — previously a begin/end below -rank
// produced an iterator before v.begin() (undefined behavior), as did
// end < begin. Valid inputs behave exactly as before.
ShapeVector GetShapeByRange(const ShapeVector &v, int64_t begin, int64_t end) {
  const auto rank = SizeToLong(v.size());
  auto real_begin = std::clamp((begin < 0) ? (rank + begin) : begin, static_cast<int64_t>(0), rank);
  auto real_end = std::clamp((end < 0) ? (rank + end) : end, static_cast<int64_t>(0), rank);
  // Guarantee a valid (possibly empty) iterator range.
  real_end = std::max(real_begin, real_end);
  return ShapeVector(v.begin() + real_begin, v.begin() + real_end);
}
718
MatrixTranspose(BpropBuilder * ib,const NodePtr & x)719 NodePtr MatrixTranspose(BpropBuilder *ib, const NodePtr &x) {
720 auto shape = ib->GetShape(x);
721 if (IsDynamicRank(shape)) {
722 auto dim = ib->Emit("Rank", {x});
723 auto perm = ib->Range(dim);
724 auto stridedslice_helper = [&perm, &ib](int64_t begin, int64_t end, int64_t step, int64_t end_mask = 0) {
725 return ib->Emit("StridedSlice",
726 {perm, ib->Value<ShapeVector>(ShapeVector{begin}), ib->Value<ShapeVector>(ShapeVector{end}),
727 ib->Value<ShapeVector>(ShapeVector{step}), ib->Value<int64_t>(0LL), ib->Value<int64_t>(end_mask),
728 ib->Value<int64_t>(0LL), ib->Value<int64_t>(0LL), ib->Value<int64_t>(0LL)});
729 };
730 auto part_1 = stridedslice_helper(0, -2, 1);
731 auto part_2 = stridedslice_helper(-1, 0, 1, 1);
732 auto part_3 = stridedslice_helper(-2, -1, 1);
733 perm = ib->Concat({part_1, part_2, part_3}, -1);
734 return ib->Transpose(x, ib->TensorToTuple(perm));
735 }
736 auto dim = shape.size();
737 if (dim < kDim2) {
738 MS_LOG_EXCEPTION << "For MatrixTranspose, input's ndim " << dim << " is less or equal to 2, which is invalid";
739 }
740 std::vector<int64_t> perm(dim);
741 for (size_t i = 0; i < dim; i++) {
742 perm[i] = static_cast<int64_t>(i);
743 }
744 std::swap(perm[dim - kIndex2], perm[dim - kIndex1]);
745 return ib->Transpose(x, perm);
746 }
747
// Like MatrixTranspose, but tolerant: inputs with rank < 2 are returned
// unchanged instead of raising. Dynamic-rank inputs build the permutation
// [0..n-3, n-1, n-2] at runtime via StridedSlice + Concat.
NodePtr MatrixTransposeExt(BpropBuilder *ib, const NodePtr &x) {
  auto shape = ib->GetShape(x);
  if (IsDynamicRank(shape)) {
    auto dim = ib->Emit("Rank", {x});
    auto perm = ib->Range(dim);
    // Slices `perm` as perm[begin:end:step]; end_mask=1 makes the end
    // unbounded (needed to grab the final element with a [-1:] slice).
    auto stridedslice_helper = [&perm, &ib](int64_t begin, int64_t end, int64_t step, int64_t end_mask = 0) {
      return ib->Emit("StridedSlice",
                      {perm, ib->Value<ShapeVector>(ShapeVector{begin}), ib->Value<ShapeVector>(ShapeVector{end}),
                       ib->Value<ShapeVector>(ShapeVector{step}), ib->Value<int64_t>(0LL), ib->Value<int64_t>(end_mask),
                       ib->Value<int64_t>(0LL), ib->Value<int64_t>(0LL), ib->Value<int64_t>(0LL)});
    };
    auto part_1 = stridedslice_helper(0, -2, 1);       // [0 .. n-3]
    auto part_2 = stridedslice_helper(-1, 0, 1, 1);    // [n-1]
    auto part_3 = stridedslice_helper(-2, -1, 1);      // [n-2]
    perm = ib->Concat({part_1, part_2, part_3}, -1);
    return ib->Transpose(x, ib->TensorToTuple(perm));
  }
  auto dim = shape.size();
  if (dim < kDim2) {
    // Nothing to transpose for scalars and vectors.
    return x;
  }
  // Identity permutation with the last two axes swapped.
  std::vector<int64_t> perm(dim);
  for (size_t i = 0; i < dim; i++) {
    perm[i] = static_cast<int64_t>(i);
  }
  std::swap(perm[dim - kIndex2], perm[dim - kIndex1]);
  return ib->Transpose(x, perm);
}
776
Adjoint(BpropBuilder * ib,const NodePtr & x)777 NodePtr Adjoint(BpropBuilder *ib, const NodePtr &x) { return MatrixTranspose(ib, ib->Conj(x)); }
778 } // namespace mindspore::expander::bprop
779