• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <algorithm>
18 #include <map>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <vector>
23 
24 #include "abstract/abstract_value.h"
25 #include "abstract/dshape.h"
26 #include "abstract/ops/op_infer.h"
27 #include "abstract/ops/primitive_infer_map.h"
28 #include "abstract/utils.h"
29 #include "base/base.h"
30 #include "ir/dtype/container.h"
31 #include "ir/dtype/number.h"
32 #include "ir/dtype/type.h"
33 #include "ir/primitive.h"
34 #include "mindapi/base/shape_vector.h"
35 #include "mindapi/base/type_id.h"
36 #include "mindapi/src/helper.h"
37 #include "mindspore/core/ops/array_ops.h"
38 #include "ops/dynamic_gru_v2_grad.h"
39 #include "ops/op_name.h"
40 #include "ops/primitive_c.h"
41 #include "utils/check_convert_utils.h"
42 #include "utils/log_adapter.h"
43 #include "utils/shape_utils.h"
44 
45 namespace mindspore {
46 namespace ops {
47 namespace {
DynamicGRUV2GradCheckShapeValue(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args,const int64_t & num_proj)48 void DynamicGRUV2GradCheckShapeValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
49                                      const int64_t &num_proj) {
50   auto prim_name = primitive->name();
51   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape];
52   auto winput_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape];
53   auto whidden_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShape())[kShape];
54   auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->GetShape())[kShape];
55   auto init_h_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->GetShape())[kShape];
56   auto h_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->GetShape())[kShape];
57   auto dy_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->GetShape())[kShape];
58   auto dh_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex7]->GetShape())[kShape];
59   auto update_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex8]->GetShape())[kShape];
60   auto reset_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex9]->GetShape())[kShape];
61   auto new_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex10]->GetShape())[kShape];
62   auto hnew_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex11]->GetShape())[kShape];
63 
64   std::vector<ShapeVector> all_shapes = {x_shape,  winput_shape, whidden_shape, y_shape,     init_h_shape, h_shape,
65                                          dy_shape, dh_shape,     update_shape,  reset_shape, new_shape,    hnew_shape};
66   auto is_dynamic = std::any_of(all_shapes.begin(), all_shapes.end(), IsDynamic);
67   if (is_dynamic) {
68     return;
69   }
70 
71   int64_t num_step = x_shape[0];
72   int64_t batch_size = x_shape[1];
73   int64_t input_size = x_shape[2];
74   int64_t hidden_size = whidden_shape[0];
75 
76   auto winput_shape_ptr = input_args[kInputIndex1]->GetShape();
77   auto whidden_shape_ptr = input_args[kInputIndex2]->GetShape();
78   auto y_shape_ptr = input_args[kInputIndex3]->GetShape();
79   auto init_h_shape_ptr = input_args[kInputIndex4]->GetShape();
80   auto h_shape_ptr = input_args[kInputIndex5]->GetShape();
81   auto dy_shape_ptr = input_args[kInputIndex6]->GetShape();
82   auto dh_shape_ptr = input_args[kInputIndex7]->GetShape();
83   auto update_shape_ptr = input_args[kInputIndex8]->GetShape();
84   auto reset_shape_ptr = input_args[kInputIndex9]->GetShape();
85   auto new_shape_ptr = input_args[kInputIndex10]->GetShape();
86   auto hnew_shape_ptr = input_args[kInputIndex11]->GetShape();
87 
88   (void)CheckAndConvertUtils::CheckTensorShapeSame({{"weight input shape", winput_shape_ptr}},
89                                                    std::vector<int64_t>{input_size, 3 * hidden_size}, prim_name);
90   (void)CheckAndConvertUtils::CheckTensorShapeSame({{"weight hidden shape", whidden_shape_ptr}},
91                                                    std::vector<int64_t>{hidden_size, 3 * hidden_size}, prim_name);
92   (void)CheckAndConvertUtils::CheckTensorShapeSame({{"init h shape", init_h_shape_ptr}},
93                                                    std::vector<int64_t>{batch_size, hidden_size}, prim_name);
94   (void)CheckAndConvertUtils::CheckTensorShapeSame({{"dh shape", dh_shape_ptr}},
95                                                    std::vector<int64_t>{batch_size, hidden_size}, prim_name);
96 
97   std::vector<int64_t> valid_y_shape;
98   (void)valid_y_shape.emplace_back(num_step);
99   (void)valid_y_shape.emplace_back(batch_size);
100   const int64_t kNumZero = 0;
101   if (num_proj > kNumZero) {
102     (void)valid_y_shape.emplace_back(std::min(hidden_size, num_proj));
103   } else {
104     (void)valid_y_shape.emplace_back(hidden_size);
105   }
106   (void)CheckAndConvertUtils::CheckTensorShapeSame({{"y shape", y_shape_ptr}}, valid_y_shape, prim_name);
107 
108   std::map<std::string, BaseShapePtr> check_shapes = {
109     {"h shape", h_shape_ptr},         {"dy shape", dy_shape_ptr},   {"update shape", update_shape_ptr},
110     {"reset shape", reset_shape_ptr}, {"new shape", new_shape_ptr}, {"hnew shape", hnew_shape_ptr}};
111   std::vector<int64_t> valid_shape = {num_step, batch_size, hidden_size};
112   (void)CheckAndConvertUtils::CheckTensorShapeSame(check_shapes, valid_shape, prim_name);
113 
114   if (input_args.size() >= kInputIndex13 && input_args[kInputIndex12]->GetType()->type_id() != kMetaTypeNone) {
115     auto seq_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex12]->GetShape())[kShape];
116     auto seq_shape_ptr = input_args[kInputIndex12]->GetShape();
117     if (!IsDynamic(seq_shape)) {
118       (void)CheckAndConvertUtils::CheckTensorShapeSame({{"seq shape", seq_shape_ptr}}, std::vector<int64_t>{batch_size},
119                                                        prim_name);
120     }
121   }
122 }
123 
DynamicGRUV2GradInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)124 abstract::TupleShapePtr DynamicGRUV2GradInferShape(const PrimitivePtr &primitive,
125                                                    const std::vector<AbstractBasePtr> &input_args) {
126   MS_EXCEPTION_IF_NULL(primitive);
127   auto prim_name = primitive->name();
128   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape];
129   auto winput_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape];
130   auto whidden_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShape())[kShape];
131   auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->GetShape())[kShape];
132 
133   int64_t num_proj = 0;
134   if (primitive->HasAttr(kNumProj)) {
135     num_proj = GetValue<int64_t>(primitive->GetAttr(kNumProj));
136   }
137 
138   std::vector<ShapeVector> check_shapes = {x_shape, winput_shape, whidden_shape, y_shape};
139   auto is_dynamic_rank = std::any_of(check_shapes.begin(), check_shapes.end(), IsDynamicRank);
140 
141   const int64_t kNumTwo = 2;
142   const int64_t kNumThree = 3;
143   if (!is_dynamic_rank) {
144     (void)CheckAndConvertUtils::CheckInteger("x shape rank", SizeToLong(x_shape.size()), kEqual, kNumThree, prim_name);
145     (void)CheckAndConvertUtils::CheckInteger("weight input shape rank", SizeToLong(winput_shape.size()), kEqual,
146                                              kNumTwo, prim_name);
147     (void)CheckAndConvertUtils::CheckInteger("weight hidden shape rank", SizeToLong(whidden_shape.size()), kEqual,
148                                              kNumTwo, prim_name);
149     (void)CheckAndConvertUtils::CheckInteger("y shape rank", SizeToLong(y_shape.size()), kEqual, kNumThree, prim_name);
150   }
151   DynamicGRUV2GradCheckShapeValue(primitive, input_args, num_proj);
152 
153   int64_t num_step = -1;
154   int64_t batch_size = -1;
155   int64_t input_size = -1;
156   int64_t hidden_size = -1;
157   int64_t hidden_size_three = -1;
158   if (!(IsDynamic(x_shape) || IsDynamic(whidden_shape))) {
159     num_step = x_shape[kInputIndex0];
160     batch_size = x_shape[kInputIndex1];
161     input_size = x_shape[kInputIndex2];
162     hidden_size = whidden_shape[kInputIndex0];
163     hidden_size_three = whidden_shape[kInputIndex1];
164   }
165 
166   ShapeVector dx_shape = {num_step, batch_size, input_size};
167   ShapeVector dh_shape = {batch_size, hidden_size};
168   ShapeVector dwinput_shape = {input_size, hidden_size_three};
169   ShapeVector dwhidden_shape = {hidden_size, hidden_size_three};
170   ShapeVector db_shape = {hidden_size_three};
171 
172   auto db_shape_ptr = std::make_shared<abstract::Shape>(db_shape);
173   auto dh_shape_ptr = std::make_shared<abstract::Shape>(dh_shape);
174   auto dx_shape_ptr = std::make_shared<abstract::Shape>(dx_shape);
175   auto dwinput_shape_ptr = std::make_shared<abstract::Shape>(dwinput_shape);
176   auto dwhidden_shape_ptr = std::make_shared<abstract::Shape>(dwhidden_shape);
177 
178   return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
179     dwinput_shape_ptr, dwhidden_shape_ptr, db_shape_ptr, db_shape_ptr, dx_shape_ptr, dh_shape_ptr});
180 }
181 
DynamicGRUV2GradInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)182 TuplePtr DynamicGRUV2GradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
183   MS_EXCEPTION_IF_NULL(primitive);
184   auto prim_name = primitive->name();
185   const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
186   auto x_dtype = input_args[kInputIndex0]->GetType();
187   auto winput_dtype = input_args[kInputIndex1]->GetType();
188   auto whidden_dtype = input_args[kInputIndex2]->GetType();
189   auto y_dtype = input_args[kInputIndex3]->GetType();
190   auto init_h_dtype = input_args[kInputIndex4]->GetType();
191   auto h_dtype = input_args[kInputIndex5]->GetType();
192   auto dy_dtype = input_args[kInputIndex6]->GetType();
193   auto dh_dtype = input_args[kInputIndex7]->GetType();
194   auto update_dtype = input_args[kInputIndex8]->GetType();
195   auto reset_dtype = input_args[kInputIndex9]->GetType();
196   auto new_dtype = input_args[kInputIndex10]->GetType();
197   auto hnew_dtype = input_args[kInputIndex11]->GetType();
198 
199   std::map<std::string, TypePtr> check_types = {
200     {"y_dtype", y_dtype},           {"h_dtype", h_dtype},         {"dy_dtype", dy_dtype},   {"dh_dtype", dh_dtype},
201     {"update_dtype", update_dtype}, {"reset_dtype", reset_dtype}, {"new_dtype", new_dtype}, {"hnew_dtype", hnew_dtype}};
202   (void)CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", x_dtype, valid_types, prim_name);
203   (void)CheckAndConvertUtils::CheckTensorTypeValid("winput_dtype", winput_dtype, valid_types, prim_name);
204   (void)CheckAndConvertUtils::CheckTensorTypeValid("whidden_dtype", whidden_dtype, valid_types, prim_name);
205   (void)CheckAndConvertUtils::CheckTensorTypeValid("init_h_dtype", init_h_dtype, valid_types, prim_name);
206   (void)CheckAndConvertUtils::CheckTensorTypeSame(check_types, valid_types, prim_name);
207   if (input_args.size() >= kInputIndex13 && input_args[kInputIndex12]->GetType()->type_id() != kMetaTypeNone) {
208     auto seq_dtype = input_args[kInputIndex12]->GetType();
209     (void)CheckAndConvertUtils::CheckTensorTypeValid("seq_dtype", seq_dtype, valid_types, prim_name);
210   }
211   if (input_args.size() >= kInputIndex14 && input_args[kInputIndex13]->GetType()->type_id() != kMetaTypeNone) {
212     auto mask_dtype = input_args[kInputIndex13]->GetType();
213     (void)CheckAndConvertUtils::CheckTensorTypeValid("mask_dtype", mask_dtype, valid_types, prim_name);
214   }
215 
216   return std::make_shared<Tuple>(
217     std::vector<TypePtr>{winput_dtype, whidden_dtype, init_h_dtype, init_h_dtype, x_dtype, init_h_dtype});
218 }
219 }  // namespace
220 
DynamicGRUV2GradInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)221 AbstractBasePtr DynamicGRUV2GradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
222                                       const std::vector<AbstractBasePtr> &input_args) {
223   MS_EXCEPTION_IF_NULL(primitive);
224   auto prim_name = primitive->name();
225   const int64_t MinInputNum = 12;
226   CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, MinInputNum, prim_name);
227   auto types = DynamicGRUV2GradInferType(primitive, input_args);
228   auto shapes = DynamicGRUV2GradInferShape(primitive, input_args);
229   return abstract::MakeAbstract(shapes, types);
230 }
231 
232 MIND_API_OPERATOR_IMPL(DynamicGRUV2Grad, BaseOperator);
233 
234 // AG means auto generated
235 class MIND_API AGDynamicGRUV2GradInfer : public abstract::OpInferBase {
236  public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const237   BaseShapePtr InferShape(const PrimitivePtr &primitive,
238                           const std::vector<AbstractBasePtr> &input_args) const override {
239     return DynamicGRUV2GradInferShape(primitive, input_args);
240   }
241 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const242   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
243     return DynamicGRUV2GradInferType(primitive, input_args);
244   }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const245   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
246                                     const std::vector<AbstractBasePtr> &input_args) const override {
247     return DynamicGRUV2GradInfer(engine, primitive, input_args);
248   }
249 };
250 
251 REGISTER_PRIMITIVE_OP_INFER_IMPL(DynamicGRUV2Grad, prim::kPrimDynamicGRUV2Grad, AGDynamicGRUV2GradInfer, false);
252 }  // namespace ops
253 }  // namespace mindspore
254