• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/grad/lstm_grad_data.h"
18 
19 #include <memory>
20 #include <set>
21 
22 #include "abstract/abstract_value.h"
23 #include "abstract/dshape.h"
24 #include "abstract/ops/op_infer.h"
25 #include "abstract/ops/primitive_infer_map.h"
26 #include "abstract/utils.h"
27 #include "base/base.h"
28 #include "ir/anf.h"
29 #include "ir/dtype/container.h"
30 #include "ir/dtype/number.h"
31 #include "ir/primitive.h"
32 #include "mindapi/base/shared_ptr.h"
33 #include "mindapi/ir/value.h"
34 #include "mindapi/src/helper.h"
35 #include "mindspore/core/ops/nn_ops.h"
36 #include "ops/op_name.h"
37 #include "ops/primitive_c.h"
38 #include "utils/check_convert_utils.h"
39 #include "utils/convert_utils_base.h"
40 #include "utils/log_adapter.h"
41 #include "utils/shape_utils.h"
42 
43 namespace mindspore {
44 namespace ops {
set_input_size(const int64_t input_size)45 void LSTMGradData::set_input_size(const int64_t input_size) {
46   (void)CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name());
47   (void)AddAttr(kInput_size, api::MakeValue(input_size));
48 }
get_input_size() const49 int64_t LSTMGradData::get_input_size() const { return GetValue<int64_t>(GetAttr(kInput_size)); }
set_hidden_size(const int64_t hidden_size)50 void LSTMGradData::set_hidden_size(const int64_t hidden_size) {
51   (void)CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name());
52   (void)AddAttr(kHidden_size, api::MakeValue(hidden_size));
53 }
get_hidden_size() const54 int64_t LSTMGradData::get_hidden_size() const { return GetValue<int64_t>(GetAttr(kHidden_size)); }
set_num_layers(const int64_t num_layers)55 void LSTMGradData::set_num_layers(const int64_t num_layers) {
56   (void)CheckAndConvertUtils::CheckInteger(kNumLayers, num_layers, kGreaterThan, 0, this->name());
57   (void)AddAttr(kNumLayers, api::MakeValue(num_layers));
58 }
get_num_layers() const59 int64_t LSTMGradData::get_num_layers() const { return GetValue<int64_t>(GetAttr(kNumLayers)); }
set_has_bias(const bool has_bias)60 void LSTMGradData::set_has_bias(const bool has_bias) { (void)AddAttr(kHasBias, api::MakeValue(has_bias)); }
get_has_bias() const61 bool LSTMGradData::get_has_bias() const {
62   auto value_ptr = this->GetAttr(kHasBias);
63   return GetValue<bool>(value_ptr);
64 }
set_dropout(const float dropout)65 void LSTMGradData::set_dropout(const float dropout) {
66   CheckAndConvertUtils::CheckInRange<float>(kDropout, dropout, kIncludeBoth, {0.0, 1.0}, this->name());
67   (void)AddAttr(kDropout, api::MakeValue(dropout));
68 }
get_dropout() const69 float LSTMGradData::get_dropout() const {
70   auto value_ptr = this->GetAttr(kDropout);
71   return GetValue<float>(value_ptr);
72 }
set_bidirectional(const bool bidirectional)73 void LSTMGradData::set_bidirectional(const bool bidirectional) {
74   (void)AddAttr(kBidirectional, api::MakeValue(bidirectional));
75 }
get_bidirectional() const76 bool LSTMGradData::get_bidirectional() const {
77   auto value_ptr = this->GetAttr(kBidirectional);
78   return GetValue<bool>(value_ptr);
79 }
set_num_directions(const int64_t num_directions)80 void LSTMGradData::set_num_directions(const int64_t num_directions) {
81   (void)AddAttr(kNumDirections, api::MakeValue(num_directions));
82 }
get_num_directions() const83 int64_t LSTMGradData::get_num_directions() const { return GetValue<int64_t>(GetAttr(kNumDirections)); }
set_zoneout_cell(float zoneout_cell)84 void LSTMGradData::set_zoneout_cell(float zoneout_cell) { (void)AddAttr(kZoneoutCell, api::MakeValue(zoneout_cell)); }
85 
get_zoneout_cell() const86 float LSTMGradData::get_zoneout_cell() const { return GetValue<float>(this->GetAttr(kZoneoutCell)); }
87 
set_zoneout_hidden(float zoneout_hidden)88 void LSTMGradData::set_zoneout_hidden(float zoneout_hidden) {
89   (void)AddAttr(kZoneoutHidden, api::MakeValue(zoneout_hidden));
90 }
91 
get_zoneout_hidden() const92 float LSTMGradData::get_zoneout_hidden() const { return GetValue<float>(this->GetAttr(kZoneoutHidden)); }
93 
set_proj_size(const int64_t proj_size)94 void LSTMGradData::set_proj_size(const int64_t proj_size) {
95   (void)CheckAndConvertUtils::CheckInteger(kProjection_size, proj_size, kGreaterThan, 0, this->name());
96   (void)AddAttr(kProjection_size, api::MakeValue(proj_size));
97 }
98 
get_proj_size() const99 int64_t LSTMGradData::get_proj_size() const { return GetValue<int64_t>(GetAttr(kProjection_size)); }
100 
Init(const int64_t input_size,const int64_t hidden_size,const int64_t num_layers,const bool has_bias,const float dropout,const bool bidirectional,const float zoneout_cell,const float zoneout_hidden,const int64_t proj_size)101 void LSTMGradData::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers,
102                         const bool has_bias, const float dropout, const bool bidirectional, const float zoneout_cell,
103                         const float zoneout_hidden, const int64_t proj_size) {
104   this->set_input_size(input_size);
105   this->set_hidden_size(hidden_size);
106   this->set_num_layers(num_layers);
107   this->set_has_bias(has_bias);
108   this->set_dropout(dropout);
109   this->set_bidirectional(bidirectional);
110   this->set_proj_size(proj_size);
111   if (bidirectional) {
112     constexpr int k2Directions = 2;
113     this->set_num_directions(k2Directions);
114   } else {
115     this->set_num_directions(1);
116   }
117   this->set_zoneout_cell(zoneout_cell);
118   this->set_zoneout_hidden(zoneout_hidden);
119 }
120 
121 namespace {
// Number of outputs produced by LSTMGradData: dx, dhy and dcy.
constexpr size_t kLstmOutputNum = 3;
LstmGradDataInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)123 abstract::TupleShapePtr LstmGradDataInferShape(const PrimitivePtr &primitive,
124                                                const std::vector<AbstractBasePtr> &input_args) {
125   MS_EXCEPTION_IF_NULL(primitive);
126   auto prim_name = primitive->name();
127   const size_t input_num = 9;
128   auto shape_ptr = std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
129   auto unknown_shapes =
130     std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>(SizeToLong(kLstmOutputNum), shape_ptr));
131   for (size_t i = 0; i < input_num; i++) {
132     auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->GetShape())[kShape];
133     if (IsDynamicRank(shape)) {
134       return unknown_shapes;
135     }
136   }
137 
138   auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShape())[kShape];
139   auto dy_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShape())[kShape];
140   auto dhy_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShape())[kShape];
141   auto dcy_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->GetShape())[kShape];
142   bool dy_is_dynamic = IsDynamic(dy_shape);
143   bool dhy_is_dynamic = IsDynamic(dhy_shape);
144   bool dcy_is_dynamic = IsDynamic(dcy_shape);
145 
146   int64_t input_size = GetValue<int64_t>(primitive->GetAttr(kInput_size));
147   int64_t hidden_size = GetValue<int64_t>(primitive->GetAttr(kHidden_size));
148   int64_t num_layers = GetValue<int64_t>(primitive->GetAttr(kNumLayers));
149   bool bidirectional = GetValue<bool>(primitive->GetAttr(kBidirectional));
150   int64_t bidirection_num = 2;
151   int64_t num_directions = bidirectional ? bidirection_num : 1;
152 
153   const size_t shape_size = 3;
154   if (!dhy_is_dynamic) {
155     (void)CheckAndConvertUtils::CheckInteger("dhy_shape.size()", SizeToLong(dhy_shape.size()), kEqual, shape_size,
156                                              prim_name);
157     (void)CheckAndConvertUtils::CheckInteger("h_shape[0]", dhy_shape[0], kEqual, num_layers * num_directions,
158                                              prim_name);
159     (void)CheckAndConvertUtils::CheckInteger("h_shape[2]", dhy_shape[kDim2], kEqual, hidden_size, prim_name);
160     if (!dcy_is_dynamic) {
161       (void)CheckAndConvertUtils::Check("dhy_shape", dhy_shape, kEqual, dcy_shape, prim_name);
162     }
163   }
164   if (!dy_is_dynamic) {
165     (void)CheckAndConvertUtils::CheckInteger("dy_shape.size()", SizeToLong(dy_shape.size()), kEqual, shape_size,
166                                              prim_name);
167     if (!dhy_is_dynamic) {
168       (void)CheckAndConvertUtils::CheckInteger("dy[1]", dy_shape[kDim1], kEqual, dhy_shape[kDim1], prim_name);
169     }
170     (void)CheckAndConvertUtils::CheckInteger("dy[2]", dy_shape[kDim2], kEqual, hidden_size * num_directions, prim_name);
171   }
172 
173   std::vector<int64_t> dx_shape = {y_shape[0], y_shape[kDim1], input_size};
174   std::vector<abstract::BaseShapePtr> output_shapes;
175   output_shapes.push_back(std::make_shared<abstract::Shape>(dx_shape));
176   output_shapes.push_back(std::make_shared<abstract::Shape>(dhy_shape));
177   output_shapes.push_back(std::make_shared<abstract::Shape>(dcy_shape));
178   return std::make_shared<abstract::TupleShape>(output_shapes);
179 }
180 
LstmGradDataInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)181 TuplePtr LstmGradDataInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
182   MS_EXCEPTION_IF_NULL(prim);
183   const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
184   auto x_dtype = input_args[0]->GetType();
185   (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, valid_types, prim->name());
186 
187   std::vector<TypePtr> type_tuple(kLstmOutputNum, x_dtype);
188   return std::make_shared<Tuple>(type_tuple);
189 }
190 }  // namespace
191 
LstmGradDataInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)192 AbstractBasePtr LstmGradDataInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
193                                   const std::vector<AbstractBasePtr> &input_args) {
194   MS_EXCEPTION_IF_NULL(primitive);
195   const int64_t kInputsNum = 9;
196   CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
197   auto infer_type = LstmGradDataInferType(primitive, input_args);
198   auto infer_shape = LstmGradDataInferShape(primitive, input_args);
199   return abstract::MakeAbstract(infer_shape, infer_type);
200 }
201 
202 MIND_API_OPERATOR_IMPL(LSTMGradData, BaseOperator);
203 
204 // AG means auto generated
205 class MIND_API AGLstmGradDataInfer : public abstract::OpInferBase {
206  public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const207   BaseShapePtr InferShape(const PrimitivePtr &primitive,
208                           const std::vector<AbstractBasePtr> &input_args) const override {
209     return LstmGradDataInferShape(primitive, input_args);
210   }
211 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const212   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
213     return LstmGradDataInferType(primitive, input_args);
214   }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const215   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
216                                     const std::vector<AbstractBasePtr> &input_args) const override {
217     return LstmGradDataInfer(engine, primitive, input_args);
218   }
219 };
220 
221 REGISTER_PRIMITIVE_OP_INFER_IMPL(LSTMGradData, prim::kPrimLstmGradData, AGLstmGradDataInfer, false);
222 }  // namespace ops
223 }  // namespace mindspore
224