1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/grad/lstm_grad_data.h"
18
19 #include <memory>
20 #include <set>
21
22 #include "abstract/abstract_value.h"
23 #include "abstract/dshape.h"
24 #include "abstract/ops/op_infer.h"
25 #include "abstract/ops/primitive_infer_map.h"
26 #include "abstract/utils.h"
27 #include "base/base.h"
28 #include "ir/anf.h"
29 #include "ir/dtype/container.h"
30 #include "ir/dtype/number.h"
31 #include "ir/primitive.h"
32 #include "mindapi/base/shared_ptr.h"
33 #include "mindapi/ir/value.h"
34 #include "mindapi/src/helper.h"
35 #include "mindspore/core/ops/nn_ops.h"
36 #include "ops/op_name.h"
37 #include "ops/primitive_c.h"
38 #include "utils/check_convert_utils.h"
39 #include "utils/convert_utils_base.h"
40 #include "utils/log_adapter.h"
41 #include "utils/shape_utils.h"
42
43 namespace mindspore {
44 namespace ops {
set_input_size(const int64_t input_size)45 void LSTMGradData::set_input_size(const int64_t input_size) {
46 (void)CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name());
47 (void)AddAttr(kInput_size, api::MakeValue(input_size));
48 }
get_input_size() const49 int64_t LSTMGradData::get_input_size() const { return GetValue<int64_t>(GetAttr(kInput_size)); }
set_hidden_size(const int64_t hidden_size)50 void LSTMGradData::set_hidden_size(const int64_t hidden_size) {
51 (void)CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name());
52 (void)AddAttr(kHidden_size, api::MakeValue(hidden_size));
53 }
get_hidden_size() const54 int64_t LSTMGradData::get_hidden_size() const { return GetValue<int64_t>(GetAttr(kHidden_size)); }
set_num_layers(const int64_t num_layers)55 void LSTMGradData::set_num_layers(const int64_t num_layers) {
56 (void)CheckAndConvertUtils::CheckInteger(kNumLayers, num_layers, kGreaterThan, 0, this->name());
57 (void)AddAttr(kNumLayers, api::MakeValue(num_layers));
58 }
get_num_layers() const59 int64_t LSTMGradData::get_num_layers() const { return GetValue<int64_t>(GetAttr(kNumLayers)); }
set_has_bias(const bool has_bias)60 void LSTMGradData::set_has_bias(const bool has_bias) { (void)AddAttr(kHasBias, api::MakeValue(has_bias)); }
get_has_bias() const61 bool LSTMGradData::get_has_bias() const {
62 auto value_ptr = this->GetAttr(kHasBias);
63 return GetValue<bool>(value_ptr);
64 }
set_dropout(const float dropout)65 void LSTMGradData::set_dropout(const float dropout) {
66 CheckAndConvertUtils::CheckInRange<float>(kDropout, dropout, kIncludeBoth, {0.0, 1.0}, this->name());
67 (void)AddAttr(kDropout, api::MakeValue(dropout));
68 }
get_dropout() const69 float LSTMGradData::get_dropout() const {
70 auto value_ptr = this->GetAttr(kDropout);
71 return GetValue<float>(value_ptr);
72 }
set_bidirectional(const bool bidirectional)73 void LSTMGradData::set_bidirectional(const bool bidirectional) {
74 (void)AddAttr(kBidirectional, api::MakeValue(bidirectional));
75 }
get_bidirectional() const76 bool LSTMGradData::get_bidirectional() const {
77 auto value_ptr = this->GetAttr(kBidirectional);
78 return GetValue<bool>(value_ptr);
79 }
set_num_directions(const int64_t num_directions)80 void LSTMGradData::set_num_directions(const int64_t num_directions) {
81 (void)AddAttr(kNumDirections, api::MakeValue(num_directions));
82 }
get_num_directions() const83 int64_t LSTMGradData::get_num_directions() const { return GetValue<int64_t>(GetAttr(kNumDirections)); }
set_zoneout_cell(float zoneout_cell)84 void LSTMGradData::set_zoneout_cell(float zoneout_cell) { (void)AddAttr(kZoneoutCell, api::MakeValue(zoneout_cell)); }
85
get_zoneout_cell() const86 float LSTMGradData::get_zoneout_cell() const { return GetValue<float>(this->GetAttr(kZoneoutCell)); }
87
set_zoneout_hidden(float zoneout_hidden)88 void LSTMGradData::set_zoneout_hidden(float zoneout_hidden) {
89 (void)AddAttr(kZoneoutHidden, api::MakeValue(zoneout_hidden));
90 }
91
get_zoneout_hidden() const92 float LSTMGradData::get_zoneout_hidden() const { return GetValue<float>(this->GetAttr(kZoneoutHidden)); }
93
set_proj_size(const int64_t proj_size)94 void LSTMGradData::set_proj_size(const int64_t proj_size) {
95 (void)CheckAndConvertUtils::CheckInteger(kProjection_size, proj_size, kGreaterThan, 0, this->name());
96 (void)AddAttr(kProjection_size, api::MakeValue(proj_size));
97 }
98
get_proj_size() const99 int64_t LSTMGradData::get_proj_size() const { return GetValue<int64_t>(GetAttr(kProjection_size)); }
100
Init(const int64_t input_size,const int64_t hidden_size,const int64_t num_layers,const bool has_bias,const float dropout,const bool bidirectional,const float zoneout_cell,const float zoneout_hidden,const int64_t proj_size)101 void LSTMGradData::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers,
102 const bool has_bias, const float dropout, const bool bidirectional, const float zoneout_cell,
103 const float zoneout_hidden, const int64_t proj_size) {
104 this->set_input_size(input_size);
105 this->set_hidden_size(hidden_size);
106 this->set_num_layers(num_layers);
107 this->set_has_bias(has_bias);
108 this->set_dropout(dropout);
109 this->set_bidirectional(bidirectional);
110 this->set_proj_size(proj_size);
111 if (bidirectional) {
112 constexpr int k2Directions = 2;
113 this->set_num_directions(k2Directions);
114 } else {
115 this->set_num_directions(1);
116 }
117 this->set_zoneout_cell(zoneout_cell);
118 this->set_zoneout_hidden(zoneout_hidden);
119 }
120
121 namespace {
// Number of outputs produced by LSTMGradData inference: dx, dhy and dcy.
constexpr size_t kLstmOutputNum = 3;
LstmGradDataInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)123 abstract::TupleShapePtr LstmGradDataInferShape(const PrimitivePtr &primitive,
124 const std::vector<AbstractBasePtr> &input_args) {
125 MS_EXCEPTION_IF_NULL(primitive);
126 auto prim_name = primitive->name();
127 const size_t input_num = 9;
128 auto shape_ptr = std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
129 auto unknown_shapes =
130 std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>(SizeToLong(kLstmOutputNum), shape_ptr));
131 for (size_t i = 0; i < input_num; i++) {
132 auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->GetShape())[kShape];
133 if (IsDynamicRank(shape)) {
134 return unknown_shapes;
135 }
136 }
137
138 auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape];
139 auto dy_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape];
140 auto dhy_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShape())[kShape];
141 auto dcy_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->GetShape())[kShape];
142 bool dy_is_dynamic = IsDynamic(dy_shape);
143 bool dhy_is_dynamic = IsDynamic(dhy_shape);
144 bool dcy_is_dynamic = IsDynamic(dcy_shape);
145
146 int64_t input_size = GetValue<int64_t>(primitive->GetAttr(kInput_size));
147 int64_t hidden_size = GetValue<int64_t>(primitive->GetAttr(kHidden_size));
148 int64_t num_layers = GetValue<int64_t>(primitive->GetAttr(kNumLayers));
149 bool bidirectional = GetValue<bool>(primitive->GetAttr(kBidirectional));
150 int64_t bidirection_num = 2;
151 int64_t num_directions = bidirectional ? bidirection_num : 1;
152
153 const size_t shape_size = 3;
154 if (!dhy_is_dynamic) {
155 (void)CheckAndConvertUtils::CheckInteger("dhy_shape.size()", SizeToLong(dhy_shape.size()), kEqual, shape_size,
156 prim_name);
157 (void)CheckAndConvertUtils::CheckInteger("h_shape[0]", dhy_shape[0], kEqual, num_layers * num_directions,
158 prim_name);
159 (void)CheckAndConvertUtils::CheckInteger("h_shape[2]", dhy_shape[kDim2], kEqual, hidden_size, prim_name);
160 if (!dcy_is_dynamic) {
161 (void)CheckAndConvertUtils::Check("dhy_shape", dhy_shape, kEqual, dcy_shape, prim_name);
162 }
163 }
164 if (!dy_is_dynamic) {
165 (void)CheckAndConvertUtils::CheckInteger("dy_shape.size()", SizeToLong(dy_shape.size()), kEqual, shape_size,
166 prim_name);
167 if (!dhy_is_dynamic) {
168 (void)CheckAndConvertUtils::CheckInteger("dy[1]", dy_shape[kDim1], kEqual, dhy_shape[kDim1], prim_name);
169 }
170 (void)CheckAndConvertUtils::CheckInteger("dy[2]", dy_shape[kDim2], kEqual, hidden_size * num_directions, prim_name);
171 }
172
173 std::vector<int64_t> dx_shape = {y_shape[0], y_shape[kDim1], input_size};
174 std::vector<abstract::BaseShapePtr> output_shapes;
175 output_shapes.push_back(std::make_shared<abstract::Shape>(dx_shape));
176 output_shapes.push_back(std::make_shared<abstract::Shape>(dhy_shape));
177 output_shapes.push_back(std::make_shared<abstract::Shape>(dcy_shape));
178 return std::make_shared<abstract::TupleShape>(output_shapes);
179 }
180
LstmGradDataInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)181 TuplePtr LstmGradDataInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
182 MS_EXCEPTION_IF_NULL(prim);
183 const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
184 auto x_dtype = input_args[0]->GetType();
185 (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, valid_types, prim->name());
186
187 std::vector<TypePtr> type_tuple(kLstmOutputNum, x_dtype);
188 return std::make_shared<Tuple>(type_tuple);
189 }
190 } // namespace
191
LstmGradDataInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)192 AbstractBasePtr LstmGradDataInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
193 const std::vector<AbstractBasePtr> &input_args) {
194 MS_EXCEPTION_IF_NULL(primitive);
195 const int64_t kInputsNum = 9;
196 CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
197 auto infer_type = LstmGradDataInferType(primitive, input_args);
198 auto infer_shape = LstmGradDataInferShape(primitive, input_args);
199 return abstract::MakeAbstract(infer_shape, infer_type);
200 }
201
// Operator boilerplate for LSTMGradData, declared with BaseOperator as its base class.
// NOTE(review): the exact expansion is defined by the MIND_API_OPERATOR_IMPL macro in mindapi — check there.
MIND_API_OPERATOR_IMPL(LSTMGradData, BaseOperator);
203
204 // AG means auto generated
205 class MIND_API AGLstmGradDataInfer : public abstract::OpInferBase {
206 public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const207 BaseShapePtr InferShape(const PrimitivePtr &primitive,
208 const std::vector<AbstractBasePtr> &input_args) const override {
209 return LstmGradDataInferShape(primitive, input_args);
210 }
211
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const212 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
213 return LstmGradDataInferType(primitive, input_args);
214 }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const215 AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
216 const std::vector<AbstractBasePtr> &input_args) const override {
217 return LstmGradDataInfer(engine, primitive, input_args);
218 }
219 };
220
// Registers AGLstmGradDataInfer as the infer implementation for prim::kPrimLstmGradData.
// NOTE(review): the meaning of the trailing `false` flag is defined by the registration macro — confirm there.
REGISTER_PRIMITIVE_OP_INFER_IMPL(LSTMGradData, prim::kPrimLstmGradData, AGLstmGradDataInfer, false);
222 } // namespace ops
223 } // namespace mindspore
224