/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
#define TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/rnn/kernels/blas_gemm.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/eigen_activations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
class OpKernelContext;

namespace functor {

template <typename Device, typename T>
struct TensorZero {
  void operator()(const Device& d, typename TTypes<T>::Flat t) {
    t.device(d) = t.constant(T(0));
  }
};

template <typename Device, typename T>
struct TensorUnalignedZero {
  void operator()(const Device& d, typename TTypes<T>::UnalignedFlat t) {
    t.device(d) = t.constant(T(0));
  }
};

template <typename Device, typename T>
struct TensorCopy {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat src,
                  typename TTypes<T>::Flat dst) {
    dst.device(d) = src;
  }
};

template <typename Device, typename T>
struct TensorCopyUnaligned {
  void operator()(const Device& d, typename TTypes<T>::UnalignedConstFlat src,
                  typename TTypes<T>::Flat dst) {
    dst.device(d) = src;
  }
};

template <typename Device, typename T>
struct TensorCopyToUnaligned {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat src,
                  typename TTypes<T>::UnalignedFlat dst) {
    dst.device(d) = src;
  }
};

template <typename Device, typename T>
struct TensorAdd {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat a,
                  typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c) {
    c.device(d) = a + b;
  }
};

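// Zeros out the rows of `m` for batch entries whose sequence has already
// ended: mask[b] = (time_idx < seq_len[b]), reshaped to [batch_size, 1],
// broadcast across the units dimension, and multiplied into m.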
template <typename Device, typename T>
struct TensorZeroPadding {
  void operator()(const Device& d, const int64 time_idx,
                  typename TTypes<int64>::ConstVec seq_len,
                  typename TTypes<T>::Vec mask, typename TTypes<T>::Matrix m) {
    // mask is shape [batch_size].
    mask.device(d) = seq_len.constant(time_idx) < seq_len;

    // m_shape is [batch_size, 1].
    Eigen::array<Eigen::DenseIndex, 2> m_shape({m.dimensions()[0], 1});
    // broadcast_shape is [1, units].
    Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({1, m.dimensions()[1]});

    // m is shape [batch_size, units].
    m.device(d) = m * mask.reshape(m_shape).broadcast(broadcast_shape);
  }
};

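// Shared sizes and slicing helpers for the block-LSTM functors below. The
// fused gate tensor `icfo` is [batch_size, 4 * cell_size] with the gates
// packed in [i, c, f, o] order, each block cell_size wide; the concatenated
// input `xh` is [batch_size, input_size + cell_size], i.e. [x, h_prev].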
struct LSTMBlockCell {
  LSTMBlockCell(const int batch_size, const int input_size, const int cell_size)
      : batch_size_(batch_size),
        input_size_(input_size),
        cell_size_(cell_size) {}

  int batch_size() const { return batch_size_; }

  int input_size() const { return input_size_; }

  int cell_size() const { return cell_size_; }

  inline Eigen::array<Eigen::DenseIndex, 2> icfo_i_offsets() const {
    return {0, 0};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> icfo_c_offsets() const {
    return {0, cell_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> icfo_f_offsets() const {
    return {0, cell_size_ * 2};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> icfo_o_offsets() const {
    return {0, cell_size_ * 3};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> cell_extents() const {
    return {batch_size_, cell_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> xh_x_offsets() const {
    return {0, 0};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> xh_x_extents() const {
    return {batch_size_, input_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> xh_h_offsets() const {
    return {0, input_size_};
  }

  inline Eigen::array<Eigen::DenseIndex, 2> xh_h_extents() const {
    return {batch_size_, cell_size_};
  }

 protected:
  const int batch_size_;
  const int input_size_;
  const int cell_size_;
};

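// Example use of the slicing helpers (a sketch; the real call sites are the
// functors below): given the fused tensors described above,
//   LSTMBlockCell cell(batch_size, input_size, cell_size);
//   auto f_block = icfo.slice(cell.icfo_f_offsets(), cell.cell_extents());
//   auto h_block = xh.slice(cell.xh_h_offsets(), cell.xh_h_extents());
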
// See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
// GPUDevice implementation.
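// A sketch of the forward computation (see the .cc/.cu.cc files above for the
// authoritative implementations):
//   xh = [x, h_prev]
//   [i, c, f, o] = xh * w + b          // packed as the fused `icfo` tensor
//   i  = sigmoid(i + cs_prev .* wci)   // peephole terms only if use_peephole
//   f  = sigmoid(f + forget_bias + cs_prev .* wcf)
//   ci = tanh(c)
//   cs = i .* ci + f .* cs_prev        // optionally clipped by cell_clip
//   co = tanh(cs)
//   o  = sigmoid(o + cs .* wco)
//   h  = o .* co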
template <typename Device, typename T, bool USE_CUBLAS>
struct LSTMBlockCellFprop : public LSTMBlockCell {
  LSTMBlockCellFprop(const int batch_size, const int input_size,
                     const int cell_size)
      : LSTMBlockCell(batch_size, input_size, cell_size) {}

  void operator()(OpKernelContext* ctx, const Device& d,
                  const float forget_bias, const float cell_clip,
                  bool use_peephole, typename TTypes<T>::ConstMatrix x,
                  typename TTypes<T>::ConstMatrix cs_prev,
                  typename TTypes<T>::ConstMatrix h_prev,
                  typename TTypes<T>::ConstMatrix w,
                  typename TTypes<T>::ConstVec wci,
                  typename TTypes<T>::ConstVec wcf,
                  typename TTypes<T>::ConstVec wco,
                  typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,
                  typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,
                  typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,
                  typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,
                  typename TTypes<T>::Matrix icfo,
                  typename TTypes<T>::Matrix h);
};

// See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
// GPUDevice implementation.
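// Single-timestep backward pass: produces the per-gate pre-activation
// gradients (di, dci, df, do_, packed into dicfo), the gradient w.r.t. the
// previous cell state (cs_prev_grad), and the peephole weight gradients.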
template <typename Device, typename T, bool USE_CUBLAS>
struct LSTMBlockCellBprop : public LSTMBlockCell {
  LSTMBlockCellBprop(const int batch_size, const int input_size,
                     const int cell_size)
      : LSTMBlockCell(batch_size, input_size, cell_size) {}

  void operator()(
      OpKernelContext* ctx, const Device& d, bool use_peephole,
      typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix cs_prev,
      typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
      typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
      typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
      typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,
      typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,
      typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,
      typename TTypes<T>::ConstMatrix cs_grad,
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
      typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
      typename TTypes<T>::Vec wco_grad);
};

template <typename Device, typename T, bool USE_CUBLAS>
struct BlockLSTMBprop : public LSTMBlockCell {
  BlockLSTMBprop(const int batch_size, const int input_size,
                 const int cell_size)
      : LSTMBlockCell(batch_size, input_size, cell_size) {}

  void operator()(
      OpKernelContext* ctx, const Device& d, bool use_peephole,
      typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix cs_prev,
      typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
      typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
      typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
      typename TTypes<T>::Matrix xh, typename TTypes<T>::ConstMatrix i,
      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,
      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,
      typename TTypes<T>::ConstMatrix co,
      typename TTypes<T>::ConstMatrix cs_grad,
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
      typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
      typename TTypes<T>::Matrix h_prev_grad,
      typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad,
      typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad,
      typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad,
      typename TTypes<T>::Vec b_grad) {
    // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
    do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co;

    // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1]
    dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad;

    Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell_size_});
    Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({batch_size_, 1});
    if (use_peephole) {
      dcs.device(d) =
          dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape);
    }

    // dci[t] = tanh'(ci[t]) dcs[t] i[t]
    dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i;

    // df[t] = sigm'(f[t]) dcs[t] cs[t - 1]
    df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev;

    // di[t] = sigm'(i[t]) dcs[t] ci[t]
    di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;

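    // Pack the per-gate gradients into the fused dicfo tensor in
    // [i, c, f, o] order.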
    dicfo.slice(icfo_i_offsets(), cell_extents()).device(d) = di;
    dicfo.slice(icfo_c_offsets(), cell_extents()).device(d) = dci;
    dicfo.slice(icfo_f_offsets(), cell_extents()).device(d) = df;
    dicfo.slice(icfo_o_offsets(), cell_extents()).device(d) = do_;

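    // cs_prev_grad[t] = dcs[t] .* f[t], plus the peephole terms below.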
    cs_prev_grad.device(d) = dcs * f;
    if (use_peephole) {
      cs_prev_grad.device(d) =
          cs_prev_grad +
          di * wci.reshape(p_shape).broadcast(p_broadcast_shape) +
          df * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
    }

    // xh_grad.
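    // xh_grad = dicfo * w^T (GEMM with the second operand transposed).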
    typename TTypes<T>::ConstMatrix const_dicfo(dicfo.data(),
                                                dicfo.dimensions());
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, false, true, 1.f, const_dicfo, w, 0.f, xh_grad);

    // xh.
    xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x;
    xh.slice(xh_h_offsets(), xh_h_extents()).device(d) = h_prev;
    typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());

    // x_grad.
    x_grad.device(d) = xh_grad.slice(xh_x_offsets(), xh_x_extents());
    h_prev_grad.device(d) = xh_grad.slice(xh_h_offsets(), xh_h_extents());

    // w_grad.
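    // w_grad += xh^T * dicfo (GEMM with the first operand transposed;
    // beta = 1 accumulates into the existing gradient).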
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, true, false, 1.f, const_xh, const_dicfo, 1.f, w_grad);

    // b_grad.
    b_grad.device(d) += dicfo.sum(Eigen::array<int, 1>({0}));

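    // Peephole weight gradients, accumulated over the batch dimension.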
    if (use_peephole) {
      wci_grad.device(d) += (di * cs_prev).sum(Eigen::array<int, 1>({0}));
      wcf_grad.device(d) += (df * cs_prev).sum(Eigen::array<int, 1>({0}));
      wco_grad.device(d) += (do_ * cs).sum(Eigen::array<int, 1>({0}));
    }
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_