1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <algorithm>
18 #include <map>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <vector>
23
24 #include "abstract/abstract_value.h"
25 #include "abstract/dshape.h"
26 #include "abstract/ops/op_infer.h"
27 #include "abstract/ops/primitive_infer_map.h"
28 #include "abstract/utils.h"
29 #include "base/base.h"
30 #include "ir/dtype/container.h"
31 #include "ir/dtype/number.h"
32 #include "ir/dtype/type.h"
33 #include "ir/primitive.h"
34 #include "mindapi/base/shape_vector.h"
35 #include "mindapi/base/type_id.h"
36 #include "mindapi/src/helper.h"
37 #include "mindspore/core/ops/array_ops.h"
38 #include "ops/dynamic_gru_v2_grad.h"
39 #include "ops/op_name.h"
40 #include "ops/primitive_c.h"
41 #include "utils/check_convert_utils.h"
42 #include "utils/log_adapter.h"
43 #include "utils/shape_utils.h"
44
45 namespace mindspore {
46 namespace ops {
47 namespace {
DynamicGRUV2GradCheckShapeValue(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args,const int64_t & num_proj)48 void DynamicGRUV2GradCheckShapeValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
49 const int64_t &num_proj) {
50 auto prim_name = primitive->name();
51 auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape];
52 auto winput_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape];
53 auto whidden_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShape())[kShape];
54 auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->GetShape())[kShape];
55 auto init_h_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->GetShape())[kShape];
56 auto h_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->GetShape())[kShape];
57 auto dy_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->GetShape())[kShape];
58 auto dh_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex7]->GetShape())[kShape];
59 auto update_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex8]->GetShape())[kShape];
60 auto reset_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex9]->GetShape())[kShape];
61 auto new_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex10]->GetShape())[kShape];
62 auto hnew_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex11]->GetShape())[kShape];
63
64 std::vector<ShapeVector> all_shapes = {x_shape, winput_shape, whidden_shape, y_shape, init_h_shape, h_shape,
65 dy_shape, dh_shape, update_shape, reset_shape, new_shape, hnew_shape};
66 auto is_dynamic = std::any_of(all_shapes.begin(), all_shapes.end(), IsDynamic);
67 if (is_dynamic) {
68 return;
69 }
70
71 int64_t num_step = x_shape[0];
72 int64_t batch_size = x_shape[1];
73 int64_t input_size = x_shape[2];
74 int64_t hidden_size = whidden_shape[0];
75
76 auto winput_shape_ptr = input_args[kInputIndex1]->GetShape();
77 auto whidden_shape_ptr = input_args[kInputIndex2]->GetShape();
78 auto y_shape_ptr = input_args[kInputIndex3]->GetShape();
79 auto init_h_shape_ptr = input_args[kInputIndex4]->GetShape();
80 auto h_shape_ptr = input_args[kInputIndex5]->GetShape();
81 auto dy_shape_ptr = input_args[kInputIndex6]->GetShape();
82 auto dh_shape_ptr = input_args[kInputIndex7]->GetShape();
83 auto update_shape_ptr = input_args[kInputIndex8]->GetShape();
84 auto reset_shape_ptr = input_args[kInputIndex9]->GetShape();
85 auto new_shape_ptr = input_args[kInputIndex10]->GetShape();
86 auto hnew_shape_ptr = input_args[kInputIndex11]->GetShape();
87
88 (void)CheckAndConvertUtils::CheckTensorShapeSame({{"weight input shape", winput_shape_ptr}},
89 std::vector<int64_t>{input_size, 3 * hidden_size}, prim_name);
90 (void)CheckAndConvertUtils::CheckTensorShapeSame({{"weight hidden shape", whidden_shape_ptr}},
91 std::vector<int64_t>{hidden_size, 3 * hidden_size}, prim_name);
92 (void)CheckAndConvertUtils::CheckTensorShapeSame({{"init h shape", init_h_shape_ptr}},
93 std::vector<int64_t>{batch_size, hidden_size}, prim_name);
94 (void)CheckAndConvertUtils::CheckTensorShapeSame({{"dh shape", dh_shape_ptr}},
95 std::vector<int64_t>{batch_size, hidden_size}, prim_name);
96
97 std::vector<int64_t> valid_y_shape;
98 (void)valid_y_shape.emplace_back(num_step);
99 (void)valid_y_shape.emplace_back(batch_size);
100 const int64_t kNumZero = 0;
101 if (num_proj > kNumZero) {
102 (void)valid_y_shape.emplace_back(std::min(hidden_size, num_proj));
103 } else {
104 (void)valid_y_shape.emplace_back(hidden_size);
105 }
106 (void)CheckAndConvertUtils::CheckTensorShapeSame({{"y shape", y_shape_ptr}}, valid_y_shape, prim_name);
107
108 std::map<std::string, BaseShapePtr> check_shapes = {
109 {"h shape", h_shape_ptr}, {"dy shape", dy_shape_ptr}, {"update shape", update_shape_ptr},
110 {"reset shape", reset_shape_ptr}, {"new shape", new_shape_ptr}, {"hnew shape", hnew_shape_ptr}};
111 std::vector<int64_t> valid_shape = {num_step, batch_size, hidden_size};
112 (void)CheckAndConvertUtils::CheckTensorShapeSame(check_shapes, valid_shape, prim_name);
113
114 if (input_args.size() >= kInputIndex13 && input_args[kInputIndex12]->GetType()->type_id() != kMetaTypeNone) {
115 auto seq_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex12]->GetShape())[kShape];
116 auto seq_shape_ptr = input_args[kInputIndex12]->GetShape();
117 if (!IsDynamic(seq_shape)) {
118 (void)CheckAndConvertUtils::CheckTensorShapeSame({{"seq shape", seq_shape_ptr}}, std::vector<int64_t>{batch_size},
119 prim_name);
120 }
121 }
122 }
123
DynamicGRUV2GradInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)124 abstract::TupleShapePtr DynamicGRUV2GradInferShape(const PrimitivePtr &primitive,
125 const std::vector<AbstractBasePtr> &input_args) {
126 MS_EXCEPTION_IF_NULL(primitive);
127 auto prim_name = primitive->name();
128 auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape];
129 auto winput_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape];
130 auto whidden_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShape())[kShape];
131 auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->GetShape())[kShape];
132
133 int64_t num_proj = 0;
134 if (primitive->HasAttr(kNumProj)) {
135 num_proj = GetValue<int64_t>(primitive->GetAttr(kNumProj));
136 }
137
138 std::vector<ShapeVector> check_shapes = {x_shape, winput_shape, whidden_shape, y_shape};
139 auto is_dynamic_rank = std::any_of(check_shapes.begin(), check_shapes.end(), IsDynamicRank);
140
141 const int64_t kNumTwo = 2;
142 const int64_t kNumThree = 3;
143 if (!is_dynamic_rank) {
144 (void)CheckAndConvertUtils::CheckInteger("x shape rank", SizeToLong(x_shape.size()), kEqual, kNumThree, prim_name);
145 (void)CheckAndConvertUtils::CheckInteger("weight input shape rank", SizeToLong(winput_shape.size()), kEqual,
146 kNumTwo, prim_name);
147 (void)CheckAndConvertUtils::CheckInteger("weight hidden shape rank", SizeToLong(whidden_shape.size()), kEqual,
148 kNumTwo, prim_name);
149 (void)CheckAndConvertUtils::CheckInteger("y shape rank", SizeToLong(y_shape.size()), kEqual, kNumThree, prim_name);
150 }
151 DynamicGRUV2GradCheckShapeValue(primitive, input_args, num_proj);
152
153 int64_t num_step = -1;
154 int64_t batch_size = -1;
155 int64_t input_size = -1;
156 int64_t hidden_size = -1;
157 int64_t hidden_size_three = -1;
158 if (!(IsDynamic(x_shape) || IsDynamic(whidden_shape))) {
159 num_step = x_shape[kInputIndex0];
160 batch_size = x_shape[kInputIndex1];
161 input_size = x_shape[kInputIndex2];
162 hidden_size = whidden_shape[kInputIndex0];
163 hidden_size_three = whidden_shape[kInputIndex1];
164 }
165
166 ShapeVector dx_shape = {num_step, batch_size, input_size};
167 ShapeVector dh_shape = {batch_size, hidden_size};
168 ShapeVector dwinput_shape = {input_size, hidden_size_three};
169 ShapeVector dwhidden_shape = {hidden_size, hidden_size_three};
170 ShapeVector db_shape = {hidden_size_three};
171
172 auto db_shape_ptr = std::make_shared<abstract::Shape>(db_shape);
173 auto dh_shape_ptr = std::make_shared<abstract::Shape>(dh_shape);
174 auto dx_shape_ptr = std::make_shared<abstract::Shape>(dx_shape);
175 auto dwinput_shape_ptr = std::make_shared<abstract::Shape>(dwinput_shape);
176 auto dwhidden_shape_ptr = std::make_shared<abstract::Shape>(dwhidden_shape);
177
178 return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
179 dwinput_shape_ptr, dwhidden_shape_ptr, db_shape_ptr, db_shape_ptr, dx_shape_ptr, dh_shape_ptr});
180 }
181
DynamicGRUV2GradInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)182 TuplePtr DynamicGRUV2GradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
183 MS_EXCEPTION_IF_NULL(primitive);
184 auto prim_name = primitive->name();
185 const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
186 auto x_dtype = input_args[kInputIndex0]->GetType();
187 auto winput_dtype = input_args[kInputIndex1]->GetType();
188 auto whidden_dtype = input_args[kInputIndex2]->GetType();
189 auto y_dtype = input_args[kInputIndex3]->GetType();
190 auto init_h_dtype = input_args[kInputIndex4]->GetType();
191 auto h_dtype = input_args[kInputIndex5]->GetType();
192 auto dy_dtype = input_args[kInputIndex6]->GetType();
193 auto dh_dtype = input_args[kInputIndex7]->GetType();
194 auto update_dtype = input_args[kInputIndex8]->GetType();
195 auto reset_dtype = input_args[kInputIndex9]->GetType();
196 auto new_dtype = input_args[kInputIndex10]->GetType();
197 auto hnew_dtype = input_args[kInputIndex11]->GetType();
198
199 std::map<std::string, TypePtr> check_types = {
200 {"y_dtype", y_dtype}, {"h_dtype", h_dtype}, {"dy_dtype", dy_dtype}, {"dh_dtype", dh_dtype},
201 {"update_dtype", update_dtype}, {"reset_dtype", reset_dtype}, {"new_dtype", new_dtype}, {"hnew_dtype", hnew_dtype}};
202 (void)CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", x_dtype, valid_types, prim_name);
203 (void)CheckAndConvertUtils::CheckTensorTypeValid("winput_dtype", winput_dtype, valid_types, prim_name);
204 (void)CheckAndConvertUtils::CheckTensorTypeValid("whidden_dtype", whidden_dtype, valid_types, prim_name);
205 (void)CheckAndConvertUtils::CheckTensorTypeValid("init_h_dtype", init_h_dtype, valid_types, prim_name);
206 (void)CheckAndConvertUtils::CheckTensorTypeSame(check_types, valid_types, prim_name);
207 if (input_args.size() >= kInputIndex13 && input_args[kInputIndex12]->GetType()->type_id() != kMetaTypeNone) {
208 auto seq_dtype = input_args[kInputIndex12]->GetType();
209 (void)CheckAndConvertUtils::CheckTensorTypeValid("seq_dtype", seq_dtype, valid_types, prim_name);
210 }
211 if (input_args.size() >= kInputIndex14 && input_args[kInputIndex13]->GetType()->type_id() != kMetaTypeNone) {
212 auto mask_dtype = input_args[kInputIndex13]->GetType();
213 (void)CheckAndConvertUtils::CheckTensorTypeValid("mask_dtype", mask_dtype, valid_types, prim_name);
214 }
215
216 return std::make_shared<Tuple>(
217 std::vector<TypePtr>{winput_dtype, whidden_dtype, init_h_dtype, init_h_dtype, x_dtype, init_h_dtype});
218 }
219 } // namespace
220
DynamicGRUV2GradInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)221 AbstractBasePtr DynamicGRUV2GradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
222 const std::vector<AbstractBasePtr> &input_args) {
223 MS_EXCEPTION_IF_NULL(primitive);
224 auto prim_name = primitive->name();
225 const int64_t MinInputNum = 12;
226 CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, MinInputNum, prim_name);
227 auto types = DynamicGRUV2GradInferType(primitive, input_args);
228 auto shapes = DynamicGRUV2GradInferShape(primitive, input_args);
229 return abstract::MakeAbstract(shapes, types);
230 }
231
// Defines the DynamicGRUV2Grad operator class on top of BaseOperator.
MIND_API_OPERATOR_IMPL(DynamicGRUV2Grad, BaseOperator);
233
234 // AG means auto generated
235 class MIND_API AGDynamicGRUV2GradInfer : public abstract::OpInferBase {
236 public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const237 BaseShapePtr InferShape(const PrimitivePtr &primitive,
238 const std::vector<AbstractBasePtr> &input_args) const override {
239 return DynamicGRUV2GradInferShape(primitive, input_args);
240 }
241
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const242 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
243 return DynamicGRUV2GradInferType(primitive, input_args);
244 }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const245 AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
246 const std::vector<AbstractBasePtr> &input_args) const override {
247 return DynamicGRUV2GradInfer(engine, primitive, input_args);
248 }
249 };
250
// Registers AGDynamicGRUV2GradInfer as the infer implementation for prim::kPrimDynamicGRUV2Grad.
REGISTER_PRIMITIVE_OP_INFER_IMPL(DynamicGRUV2Grad, prim::kPrimDynamicGRUV2Grad, AGDynamicGRUV2GradInfer, false);
252 } // namespace ops
253 } // namespace mindspore
254