Lines Matching defs:CellParams
175 struct CellParams : public CellParamsBase { struct
176 CellParams( in CellParams() function
184 const Tensor& w_ih;
185 const Tensor& w_hh;
186 const Tensor& b_ih_; /* optional */
187 const Tensor& b_hh_; /* optional */
188 const Tensor& w_hr; /* only defined for LSTMs with projections */
190 Tensor matmul_ih(const Tensor& input) const override { in matmul_ih()
193 Tensor matmul_hh(const Tensor& h) const override { in matmul_hh()
196 Tensor matmul_hr(const Tensor& h) const override { in matmul_hr()
202 Tensor linear_ih(const Tensor& input) const override { in linear_ih()
205 Tensor linear_hh(const Tensor& h) const override { in linear_hh()
208 const Tensor& b_ih() const override { in b_ih()
211 const Tensor& b_hh() const override { in b_hh()
214 CellParamsSerializationType __getstate__() const override { in __getstate__()
217 static c10::intrusive_ptr<CellParamsBase> __setstate__( in __setstate__()
1376 ONE_HIDDEN_RNN(gru, GRUCell<CellParams>) in ONE_HIDDEN_RNN()