• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_OPS_LSTM_H_
18 #define MINDSPORE_CORE_OPS_LSTM_H_
19 
20 #include <map>
21 #include <vector>
22 #include <string>
23 #include <memory>
24 #include <algorithm>
25 #include "ops/op_utils.h"
26 #include "ops/primitive_c.h"
27 #include "abstract/primitive_infer_map.h"
28 #include "abstract/abstract_value.h"
29 #include "utils/check_convert_utils.h"
30 
31 namespace mindspore {
32 namespace ops {
33 constexpr auto kNameLSTM = "LSTM";
34 /// \brief Performs the Long Short-Term Memory (LSTM) on the input.
35 /// Refer to Python API @ref mindspore.ops.LSTM for more details.
36 class MS_CORE_API LSTM : public PrimitiveC {
37  public:
38   /// \brief Constructor.
LSTM()39   LSTM() : PrimitiveC(kNameLSTM) {}
40   /// \brief Destructor.
41   ~LSTM() = default;
42   MS_DECLARE_PARENT(LSTM, PrimitiveC);
43   /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.LSTM for the inputs.
44   void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias,
45             const float dropout, const bool bidirectional = false, const float zoneout_cell = 0.0f,
46             const float zoneout_hidden = 0.0f);
47   /// \brief Set input_size.
48   void set_input_size(const int64_t input_size);
49   /// \brief Get input_size.
50   ///
51   /// \return input_size.
52   int64_t get_input_size() const;
53   /// \brief Set hidden_size.
54   void set_hidden_size(const int64_t hidden_size);
55   /// \brief Get hidden_size.
56   ///
57   /// \return hidden_size.
58   int64_t get_hidden_size() const;
59   /// \brief Set num_layers.
60   void set_num_layers(const int64_t num_layers);
61   /// \brief Get num_layers.
62   ///
63   /// \return num_layers.
64   int64_t get_num_layers() const;
65   /// \brief Set has_bias.
66   void set_has_bias(const bool has_bias);
67   /// \brief Get has_bias.
68   ///
69   /// \return has_bias.
70   bool get_has_bias() const;
71   /// \brief Set dropout.
72   void set_dropout(const float dropout);
73   /// \brief Get dropout.
74   ///
75   /// \return dropout.
76   float get_dropout() const;
77   /// \brief Set bidirectional.
78   void set_bidirectional(const bool bidirectional);
79   /// \brief Get bidirectional.
80   ///
81   /// \return bidirectional.
82   bool get_bidirectional() const;
83   /// \brief Set num_directions.
84   void set_num_directions(const int64_t num_directions);
85   /// \brief Get num_directions.
86   ///
87   /// \return num_directions.
88   int64_t get_num_directions() const;
89   /// \brief Set zoneout_cell.
90   void set_zoneout_cell(float zoneout_cell);
91   /// \brief Get zoneout_cell.
92   ///
93   /// \return zoneout_cell.
94   float get_zoneout_cell() const;
95   /// \brief Set zoneout_hidden.
96   void set_zoneout_hidden(float zoneout_hidden);
97   /// \brief Get zoneout_hidden.
98   ///
99   /// \return zoneout_hidden.
100   float get_zoneout_hidden() const;
101   /// \brief Get good_ld.
102   ///
103   /// \return good_ld.
104   int64_t get_good_ld(const int64_t dim, const int64_t type_size);
105 };
106 AbstractBasePtr LstmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
107                           const std::vector<AbstractBasePtr> &input_args);
108 using PrimLstmPtr = std::shared_ptr<LSTM>;
109 }  // namespace ops
110 }  // namespace mindspore
111 
112 #endif  // MINDSPORE_CORE_OPS_LSTM_H_
113