• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_FRONTEND_EXPANDER_BPROP_GRAD_OPS_COMMON_UTILS_H_
17 #define MINDSPORE_CCSRC_FRONTEND_EXPANDER_BPROP_GRAD_OPS_COMMON_UTILS_H_
18 
19 #include <cmath>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include "mindspore/core/ops/dynamic_broadcast_gradient_args.h"
26 #include "frontend/expander/bprop/bprop_irbuilder.h"
27 #include "include/common/expander/core/node.h"
28 
29 namespace mindspore::expander::bprop {
// Named index constants: iN == N for N in [0, 10].
// They are `inline constexpr` (C++17, which this header already requires for the `inline` variables
// below) so that every translation unit shares a single definition. A plain namespace-scope `constexpr`
// variable has internal linkage; odr-using it (e.g. binding it to a `const size_t &` parameter) inside an
// inline function defined in a header would give that function a different definition per translation
// unit, which violates the one-definition rule.
inline constexpr size_t i0 = 0;
inline constexpr size_t i1 = 1;
inline constexpr size_t i2 = 2;
inline constexpr size_t i3 = 3;
inline constexpr size_t i4 = 4;
inline constexpr size_t i5 = 5;
inline constexpr size_t i6 = 6;
inline constexpr size_t i7 = 7;
inline constexpr size_t i8 = 8;
inline constexpr size_t i9 = 9;
inline constexpr size_t i10 = 10;

// Mathematical constants. std::acos / std::log are not constexpr in C++17, so these are initialized
// dynamically at startup; do not read them from another translation unit's static initializers.
inline const auto pi = std::acos(-1.0);   // pi, as a double
inline const auto log_2 = std::log(2.0);  // natural logarithm of 2
// Natural logarithm of pi. `pi` is defined before this variable in the same header, so (partially-ordered
// initialization of inline variables) it is always initialized first.
inline const auto log_pi = std::log(pi);
44 
45 using mindspore::ops::BroadcastGradientArgsInferValue;
46 
47 NodePtrList ReturnZeros(BpropBuilder *ib);
48 // normalize the axis to [0, rank)
49 int64_t NormalizeAxis(int64_t axis, size_t rank);
50 
51 std::vector<int64_t> GetTransposeAxis(const std::vector<int64_t> &x_shape, int64_t axis);
52 
53 std::vector<int64_t> TupleDiv(const std::vector<int64_t> &x, const std::vector<int64_t> &y);
54 
55 std::vector<int64_t> ReduceShape(const std::vector<int64_t> &x, const std::vector<int64_t> &axis,
56                                  bool skip_mode = false);
57 
58 int64_t CheckRange(int64_t idx, int64_t dim_size);
59 
60 NodePtrList BinopGradCommon(BpropBuilder *ib, const NodePtr &x, const NodePtr &y, const NodePtr &dx, const NodePtr &dy,
61                             size_t shift = 0UL);
62 NodePtrList MatMulExtBroadCastGrad(BpropBuilder *ib, const NodePtr &x, const NodePtr &y, const NodePtr &dx,
63                                    const NodePtr &dy, size_t ignore_offset = 0UL);
64 
65 std::vector<int64_t> Range(int64_t start, int64_t stop, int64_t step = 1);
66 std::vector<int64_t> Range(int64_t stop);
67 
/// @brief Concatenates two vectors.
/// @param m Elements placed first in the result.
/// @param n Elements appended after those of @p m.
/// @return A new vector containing the elements of @p m followed by those of @p n. Neither input is modified,
///         so `v + v` is safe.
template <typename T>
std::vector<T> operator+(std::vector<T> const &m, std::vector<T> const &n) {
  std::vector<T> v;
  // Reserve the final size up front so neither insert below has to reallocate.
  v.reserve(m.size() + n.size());
  (void)v.insert(v.end(), m.begin(), m.end());
  (void)v.insert(v.end(), n.begin(), n.end());
  return v;
}
76 
77 int64_t GetIntValue(const NodePtr &node);
78 std::vector<int64_t> GetIntList(const ValuePtr &value);
79 std::vector<int64_t> GetIntList(const NodePtr &node);
80 
81 NodePtr GetEps(BpropBuilder *ib, const TypePtr &type);
82 std::vector<int64_t> GenerateInverseIndex(const std::vector<int64_t> &x_shp, int64_t axis_v, int64_t batch_dims = 0);
83 std::vector<int64_t> GenerateShapeIndex(const std::vector<int64_t> &out_shp, const std::vector<int64_t> &ind_shp,
84                                         int64_t axis_v, int64_t batch_dims = 0);
85 std::vector<int64_t> RegenerateOutputShape(const std::vector<int64_t> &x_shp, const std::vector<int64_t> &ind_shp,
86                                            int64_t axis_v, int64_t batch_dims = 0);
87 std::vector<int64_t> InvertPermutation(const std::vector<int64_t> &perm);
88 std::vector<int64_t> GetTransposition(int64_t axis, int64_t rank);
89 
90 NodePtr SumGrad(Emitter *ib, const NodePtr &x, const NodePtr &axis, const NodePtr &dout, bool keep_dims = false,
91                 bool skip_mode = false);
92 NodePtr MinOrMaxGrad(BpropBuilder *ib, const NodePtr &x, const NodePtr &axis, const NodePtr &keep_dims,
93                      const NodePtr &out, const NodePtr &dout);
94 std::pair<ShapeVector, ShapeVector> SplitShapeIndex(const ShapeVector &input_shape, const ShapeVector &axis);
95 NodePtr ArgminOrArgmaxGrad(BpropBuilder *ib, const NodePtr &x, const NodePtr &axis, const NodePtr &keep_dims,
96                            const NodePtr &out, const NodePtr &dout, const bool is_max);
97 TypeId PromoteBinaryDtype(TypeId t1, TypeId t2);
98 NodePtr LGamma(BpropBuilder *ib, const NodePtr &x);
99 bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_types);
100 ShapeVector PoolToNHWC(const ShapeVector &v);
101 ShapeVector ConvToNHWC(const ShapeVector &v);
102 ShapeVector GetShapeByRange(const ShapeVector &v, int64_t begin = 0, int64_t end = -1);
103 NodePtr MatrixTranspose(BpropBuilder *ib, const NodePtr &x);
104 NodePtr MatrixTransposeExt(BpropBuilder *ib, const NodePtr &x);
105 NodePtr Adjoint(BpropBuilder *ib, const NodePtr &x);
106 }  // namespace mindspore::expander::bprop
107 #endif  // MINDSPORE_CCSRC_FRONTEND_EXPANDER_BPROP_GRAD_OPS_COMMON_UTILS_H_
108