1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #if GOOGLE_CUDA
19 #define EIGEN_USE_GPU
20 #endif  // GOOGLE_CUDA
21 
22 #include "tensorflow/contrib/rnn/kernels/lstm_ops.h"
23 
24 #include <memory>
25 #include <vector>
26 
27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/register_types.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/tensor_types.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/macros.h"
36 
37 namespace tensorflow {
38 
39 typedef Eigen::ThreadPoolDevice CPUDevice;
40 typedef Eigen::GpuDevice GPUDevice;
41 
42 namespace functor {
43 
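// Single-time-step forward pass of the fused LSTM cell, evaluated with Eigen
// on the CPU. All four gate pre-activations are computed with one GEMM,
//   icfo = [x, h_prev] * w + b,
// and the result is sliced into the input (i), cell (ci), forget (f), and
// output (o) blocks, in that order. The cell state and output are then
//   cs = i .* ci + f .* cs_prev   (optionally clipped when cell_clip > 0)
//   h  = o .* tanh(cs)
// with optional peephole connections wci/wcf/wco added to the gate inputs.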
44 template <typename T>
45 void LSTMBlockCellFpropWithEigen(
46     const LSTMBlockCell& cell, OpKernelContext* ctx, const CPUDevice& d,
47     const float forget_bias, const float cell_clip, bool use_peephole,
48     typename TTypes<T>::ConstMatrix x, typename TTypes<T>::ConstMatrix cs_prev,
49     typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
50     typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
51     typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
52     typename TTypes<T>::Matrix xh, typename TTypes<T>::Matrix i,
53     typename TTypes<T>::Matrix cs, typename TTypes<T>::Matrix f,
54     typename TTypes<T>::Matrix o, typename TTypes<T>::Matrix ci,
55     typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
56     typename TTypes<T>::Matrix h) {
57   // Concat xh = [x, h].
58   xh.slice(cell.xh_x_offsets(), cell.xh_x_extents()).device(d) = x;
59   xh.slice(cell.xh_h_offsets(), cell.xh_h_extents()).device(d) = h_prev;
60 
61   // icfo = xh * w + b
62   typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
63   TensorBlasGemm<CPUDevice, T, false /* USE_CUBLAS */>::compute(
64       ctx, d, false, false, typename gemm_compute_type<T>::type(1.f), const_xh,
65       w, typename gemm_compute_type<T>::type(0.f), icfo);
66   Eigen::array<Eigen::DenseIndex, 2> b_shape({1, b.dimensions()[0]});
67   Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({cell.batch_size(), 1});
68   icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape);
69 
70   Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell.cell_size()});
71   Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({cell.batch_size(), 1});
72 
73   // Input gate.
74   if (use_peephole) {
75     auto i_peep = cs_prev * wci.reshape(p_shape).broadcast(p_broadcast_shape);
76     i.device(d) =
77         (icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()) + i_peep)
78             .sigmoid();
79   } else {
80     i.device(d) =
81         icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).sigmoid();
82   }
83 
84   // Cell input.
85   ci.device(d) = icfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).tanh();
86 
87   // Forget gate (w/ bias).
88   if (use_peephole) {
89     auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
90     f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) +
91                    f.constant(T(forget_bias)) + f_peep)
92                       .sigmoid();
93   } else {
94     f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) +
95                    f.constant(T(forget_bias)))
96                       .sigmoid();
97   }
98 
99   // cs = ci .* i + f .* cs_prev
100   cs.device(d) = i * ci + f * cs_prev;
101 
102   if (cell_clip > 0.0f) {
103     cs.device(d) =
104         cs.binaryExpr(cs.constant(T(cell_clip)), Eigen::scalar_clip_op<T>());
105   }
106 
107   // co = tanh(cs)
108   co.device(d) = cs.tanh();
109 
110   // Output gate.
111   if (use_peephole) {
112     auto o_peep = cs * wco.reshape(p_shape).broadcast(p_broadcast_shape);
113     o.device(d) =
114         (icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()) + o_peep)
115             .sigmoid();
116   } else {
117     o.device(d) =
118         icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).sigmoid();
119   }
120 
121   // h = o .* co
122   h.device(d) = o * co;
123 }
124 
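// Single-time-step backward pass of the fused LSTM cell. Given the saved
// forward activations and the incoming gradients cs_grad and h_grad, this
// computes the gate gradients (di, dci, df, do_), packs them into dicfo using
// the same i/c/f/o block layout as the forward pass, and produces
// cs_prev_grad plus, when peepholes are enabled, the peephole weight
// gradients wci_grad/wcf_grad/wco_grad.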
125 template <typename Device, typename T, bool USE_CUBLAS>
126 void LSTMBlockCellBpropWithEigen(
127     const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d,
128     bool use_peephole, typename TTypes<T>::ConstMatrix x,
129     typename TTypes<T>::ConstMatrix cs_prev,
130     typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
131     typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
132     typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
133     typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,
134     typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,
135     typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,
136     typename TTypes<T>::ConstMatrix cs_grad,
137     typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
138     typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
139     typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
140     typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
141     typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
142     typename TTypes<T>::Vec wco_grad) {
143   // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
144   do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co;
145 
146   // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1]
147   dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad;
148 
149   Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell.cell_size()});
150   Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({cell.batch_size(), 1});
151   if (use_peephole) {
152     dcs.device(d) =
153         dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape);
154   }
155 
156   // dci[t] = tanh'(ci[t]) dcs[t] i[t]
157   dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i;
158 
159   // df[t] = sigm'(f[t]) dcs[t] cs[t - 1]
160   df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev;
161 
162   // di[t] = sigm'(i[t]) dcs[t] ci[t]
163   di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;
164 
165   dicfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).device(d) = di;
166   dicfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).device(d) = dci;
167   dicfo.slice(cell.icfo_f_offsets(), cell.cell_extents()).device(d) = df;
168   dicfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).device(d) = do_;
169 
170   cs_prev_grad.device(d) = dcs * f;
171   if (use_peephole) {
172     cs_prev_grad.device(d) =
173         cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) +
174         df * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
175     wci_grad.device(d) = (di * cs_prev).sum(Eigen::array<int, 1>({0}));
176     wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array<int, 1>({0}));
177     wco_grad.device(d) = (do_ * cs).sum(Eigen::array<int, 1>({0}));
178   }
179 }
180 
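// Specialize the Fprop/Bprop functors for the CPU device by forwarding to the
// Eigen implementations above, then explicitly instantiate them for the
// supported element types (float and Eigen::half).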
181 #define DEFINE_CPU_SPECS(T)                                                   \
182   template <>                                                                 \
183   void LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()(  \
184       OpKernelContext* ctx, const CPUDevice& d, const float forget_bias,      \
185       const float cell_clip, bool use_peephole,                               \
186       typename TTypes<T>::ConstMatrix x,                                      \
187       typename TTypes<T>::ConstMatrix cs_prev,                                \
188       typename TTypes<T>::ConstMatrix h_prev,                                 \
189       typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,    \
190       typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,     \
191       typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,          \
192       typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,            \
193       typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,             \
194       typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,           \
195       typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h) {        \
196     LSTMBlockCellFpropWithEigen<T>(                                           \
197         *this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev,      \
198         h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, icfo, h);       \
199   }                                                                           \
200   template <>                                                                 \
201   void LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()(  \
202       OpKernelContext* ctx, const CPUDevice& d, bool use_peephole,            \
203       typename TTypes<T>::ConstMatrix x,                                      \
204       typename TTypes<T>::ConstMatrix cs_prev,                                \
205       typename TTypes<T>::ConstMatrix h_prev,                                 \
206       typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,    \
207       typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,     \
208       typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i,      \
209       typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,  \
210       typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,  \
211       typename TTypes<T>::ConstMatrix co,                                     \
212       typename TTypes<T>::ConstMatrix cs_grad,                                \
213       typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
214       typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,         \
215       typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,           \
216       typename TTypes<T>::Matrix dicfo,                                       \
217       typename TTypes<T>::Matrix cs_prev_grad,                                \
218       typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,     \
219       typename TTypes<T>::Vec wco_grad) {                                     \
220     LSTMBlockCellBpropWithEigen<CPUDevice, T, false /* USE_CUBLAS */>(        \
221         *this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \
222         i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dicfo,   \
223         cs_prev_grad, wci_grad, wcf_grad, wco_grad);                          \
224   }                                                                           \
225   template struct LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>;   \
226   template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>;
227 
228 DEFINE_CPU_SPECS(float);
229 DEFINE_CPU_SPECS(Eigen::half);
230 #undef DEFINE_CPU_SPECS
231 
232 }  // namespace functor
233 
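// Kernel for the LSTMBlockCell op: runs the fused LSTM cell for a single time
// step. It validates the input shapes, allocates (or forwards) the output
// tensors, and dispatches to the LSTMBlockCellFprop functor for the selected
// device.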
234 template <typename Device, typename T, bool USE_CUBLAS>
235 class LSTMBlockCellOp : public OpKernel {
236  public:
237   explicit LSTMBlockCellOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
238     OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
239     OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
240     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
241   }
242 
243   void Compute(OpKernelContext* ctx) override {
244     const Tensor* x_tensor = nullptr;
245     OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));
246 
247     const Tensor* cs_prev_tensor = nullptr;
248     OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
249 
250     const Tensor* h_prev_tensor = nullptr;
251     OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
252 
253     const Tensor* w_tensor = nullptr;
254     OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
255 
256     const Tensor* wci_tensor = nullptr;
257     OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
258 
259     const Tensor* wcf_tensor = nullptr;
260     OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
261 
262     const Tensor* wco_tensor = nullptr;
263     OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
264 
265     const Tensor* b_tensor = nullptr;
266     OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
267 
268     const int64 batch_size = x_tensor->dim_size(0);
269     const int64 input_size = x_tensor->dim_size(1);
270     const int64 cell_size = cs_prev_tensor->dim_size(1);
271 
272     // Sanity checks for our input shapes.
273     OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
274                 errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
275                                         cs_prev_tensor->dim_size(0), " vs. ",
276                                         batch_size));
277     OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size,
278                 errors::InvalidArgument("cs_prev.dims(1) != cell_size: ",
279                                         cs_prev_tensor->dim_size(1), " vs. ",
280                                         cell_size));
281 
282     OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
283                 errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
284                                         h_prev_tensor->dim_size(0), " vs. ",
285                                         batch_size));
286     OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
287                 errors::InvalidArgument(
288                     "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
289                     " vs. ", cell_size));
290 
291     OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
292                 errors::InvalidArgument(
293                     "w.dim_size(0) != input_size + cell_size: ",
294                     w_tensor->dim_size(0), " vs. ", input_size + cell_size));
295     OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
296                 errors::InvalidArgument(
297                     "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
298                     " vs. ", cell_size * 4));
299 
300     OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
301                 errors::InvalidArgument(
302                     "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
303                     " vs. ", cell_size * 4));
304 
305     // Allocate our output tensors.
306     Tensor* i_tensor = nullptr;
307     OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
308                             {"h_prev"}, "i",
309                             TensorShape({batch_size, cell_size}), &i_tensor));
310 
311     Tensor* cs_tensor = nullptr;
312     OP_REQUIRES_OK(
313         ctx, ctx->allocate_output("cs", TensorShape({batch_size, cell_size}),
314                                   &cs_tensor));
315 
316     Tensor* f_tensor = nullptr;
317     OP_REQUIRES_OK(
318         ctx, ctx->allocate_output("f", TensorShape({batch_size, cell_size}),
319                                   &f_tensor));
320 
321     Tensor* o_tensor = nullptr;
322     OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
323                             {"cs_prev"}, "o",
324                             TensorShape({batch_size, cell_size}), &o_tensor));
325 
326     Tensor* ci_tensor = nullptr;
327     OP_REQUIRES_OK(
328         ctx, ctx->allocate_output("ci", TensorShape({batch_size, cell_size}),
329                                   &ci_tensor));
330 
331     Tensor* co_tensor = nullptr;
332     OP_REQUIRES_OK(
333         ctx, ctx->allocate_output("co", TensorShape({batch_size, cell_size}),
334                                   &co_tensor));
335 
336     Tensor* h_tensor = nullptr;
337     OP_REQUIRES_OK(
338         ctx, ctx->allocate_output("h", TensorShape({batch_size, cell_size}),
339                                   &h_tensor));
340 
341     // Allocate our temp tensors.
342     Tensor xh_tensor;
343     OP_REQUIRES_OK(ctx, ctx->allocate_temp(
344                             DataTypeToEnum<T>::v(),
345                             TensorShape({batch_size, input_size + cell_size}),
346                             &xh_tensor));
347 
348     Tensor icfo_tensor;
349     OP_REQUIRES_OK(ctx,
350                    ctx->allocate_temp(DataTypeToEnum<T>::v(),
351                                       TensorShape({batch_size, cell_size * 4}),
352                                       &icfo_tensor));
353 
354     const Device& device = ctx->eigen_device<Device>();
355 
356     functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
357                                                        cell_size)(
358         ctx, device, forget_bias_, cell_clip_, use_peephole_,
359         x_tensor->matrix<T>(), cs_prev_tensor->matrix<T>(),
360         h_prev_tensor->matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(),
361         wcf_tensor->vec<T>(), wco_tensor->vec<T>(), b_tensor->vec<T>(),
362         xh_tensor.matrix<T>(), i_tensor->matrix<T>(), cs_tensor->matrix<T>(),
363         f_tensor->matrix<T>(), o_tensor->matrix<T>(), ci_tensor->matrix<T>(),
364         co_tensor->matrix<T>(), icfo_tensor.matrix<T>(), h_tensor->matrix<T>());
365   }
366 
367  private:
368   float forget_bias_;
369   float cell_clip_;
370   bool use_peephole_;
371 };
372 
373 #define REGISTER_KERNEL(T)                                             \
374   REGISTER_KERNEL_BUILDER(                                             \
375       Name("LSTMBlockCell").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
376       LSTMBlockCellOp<CPUDevice, T, false>);
377 REGISTER_KERNEL(float);
378 REGISTER_KERNEL(Eigen::half);
379 #undef REGISTER_KERNEL
380 
381 #if GOOGLE_CUDA
382 namespace functor {
383 #define DECLARE_GPU_SPEC(T)                                                \
384   template <>                                                              \
385   void LSTMBlockCellFprop<GPUDevice, T, true>::operator()(                 \
386       OpKernelContext* ctx, const GPUDevice& d, const float forget_bias,   \
387       const float cell_clip, bool use_peephole,                            \
388       typename TTypes<T>::ConstMatrix x,                                   \
389       typename TTypes<T>::ConstMatrix cs_prev,                             \
390       typename TTypes<T>::ConstMatrix h_prev,                              \
391       typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
392       typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,  \
393       typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,       \
394       typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,         \
395       typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,          \
396       typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,        \
397       typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h);      \
398                                                                            \
399   extern template struct LSTMBlockCellFprop<GPUDevice, T, true>;
400 
401 DECLARE_GPU_SPEC(float);
402 DECLARE_GPU_SPEC(Eigen::half);
403 #undef DECLARE_GPU_SPEC
404 }  // end namespace functor
405 
406 #define REGISTER_GPU_KERNEL(T)                                         \
407   REGISTER_KERNEL_BUILDER(                                             \
408       Name("LSTMBlockCell").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
409       LSTMBlockCellOp<GPUDevice, T, true>);
410 
411 REGISTER_GPU_KERNEL(float);
412 REGISTER_GPU_KERNEL(Eigen::half);
413 // REGISTER_GPU_KERNEL(double);
414 #undef REGISTER_GPU_KERNEL
415 #endif  // GOOGLE_CUDA
416 
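// Kernel for the LSTMBlockCellGrad op: the single-time-step backward pass.
// It checks that the saved gate activations all have shape
// [batch_size, cell_size], zero-initializes the peephole weight gradients,
// and dispatches to the LSTMBlockCellBprop functor.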
417 template <typename Device, typename T, bool USE_CUBLAS>
418 class LSTMBlockCellGradOp : public OpKernel {
419  public:
420   explicit LSTMBlockCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
421     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
422   }
423 
424   void Compute(OpKernelContext* ctx) override {
425     const Tensor* x_tensor = nullptr;
426     OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));
427 
428     const Tensor* cs_prev_tensor = nullptr;
429     OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
430 
431     const Tensor* h_prev_tensor = nullptr;
432     OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
433 
434     const Tensor* w_tensor = nullptr;
435     OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
436 
437     const Tensor* wci_tensor = nullptr;
438     OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
439 
440     const Tensor* wcf_tensor = nullptr;
441     OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
442 
443     const Tensor* wco_tensor = nullptr;
444     OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
445 
446     const Tensor* b_tensor = nullptr;
447     OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
448 
449     const Tensor* i_tensor = nullptr;
450     OP_REQUIRES_OK(ctx, ctx->input("i", &i_tensor));
451 
452     const Tensor* cs_tensor = nullptr;
453     OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_tensor));
454 
455     const Tensor* f_tensor = nullptr;
456     OP_REQUIRES_OK(ctx, ctx->input("f", &f_tensor));
457 
458     const Tensor* o_tensor = nullptr;
459     OP_REQUIRES_OK(ctx, ctx->input("o", &o_tensor));
460 
461     const Tensor* ci_tensor = nullptr;
462     OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_tensor));
463 
464     const Tensor* co_tensor = nullptr;
465     OP_REQUIRES_OK(ctx, ctx->input("co", &co_tensor));
466 
467     const Tensor* cs_grad_tensor = nullptr;
468     OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad_tensor));
469 
470     const Tensor* h_grad_tensor = nullptr;
471     OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad_tensor));
472 
473     const int64 batch_size = x_tensor->dim_size(0);
474     const int64 input_size = x_tensor->dim_size(1);
475     const int64 cell_size = cs_prev_tensor->dim_size(1);
476 
477     // Sanity checks for our input shapes.
478     OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
479                 errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
480                                         cs_prev_tensor->dim_size(0), " vs. ",
481                                         batch_size));
482     OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size,
483                 errors::InvalidArgument("cs_prev.dims(1) != cell_size: ",
484                                         cs_prev_tensor->dim_size(1), " vs. ",
485                                         cell_size));
486 
487     OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
488                 errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
489                                         h_prev_tensor->dim_size(0), " vs. ",
490                                         batch_size));
491     OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
492                 errors::InvalidArgument(
493                     "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
494                     " vs. ", cell_size));
495 
496     OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
497                 errors::InvalidArgument(
498                     "w.dim_size(0) != input_size + cell_size: ",
499                     w_tensor->dim_size(0), " vs. ", input_size + cell_size));
500     OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
501                 errors::InvalidArgument(
502                     "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
503                     " vs. ", cell_size * 4));
504 
505     OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
506                 errors::InvalidArgument(
507                     "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
508                     " vs. ", cell_size * 4));
509 
510     OP_REQUIRES(ctx, i_tensor->dim_size(0) == batch_size,
511                 errors::InvalidArgument(
512                     "i.dim_size(0) != batch_size: ", i_tensor->dim_size(0),
513                     " vs. ", batch_size));
514     OP_REQUIRES(ctx, i_tensor->dim_size(1) == cell_size,
515                 errors::InvalidArgument(
516                     "i.dim_size(1) != cell_size: ", i_tensor->dim_size(1),
517                     " vs. ", cell_size));
518 
519     OP_REQUIRES(ctx, cs_tensor->dim_size(0) == batch_size,
520                 errors::InvalidArgument(
521                     "cs.dim_size(0) != batch_size: ", cs_tensor->dim_size(0),
522                     " vs. ", batch_size));
523     OP_REQUIRES(ctx, cs_tensor->dim_size(1) == cell_size,
524                 errors::InvalidArgument(
525                     "cs.dim_size(1) != cell_size: ", cs_tensor->dim_size(1),
526                     " vs. ", cell_size));
527 
528     OP_REQUIRES(ctx, f_tensor->dim_size(0) == batch_size,
529                 errors::InvalidArgument(
530                     "f.dim_size(0) != batch_size: ", f_tensor->dim_size(0),
531                     " vs. ", batch_size));
532     OP_REQUIRES(ctx, f_tensor->dim_size(1) == cell_size,
533                 errors::InvalidArgument(
534                     "f.dim_size(1) != cell_size: ", f_tensor->dim_size(1),
535                     " vs. ", cell_size));
536 
537     OP_REQUIRES(ctx, o_tensor->dim_size(0) == batch_size,
538                 errors::InvalidArgument(
539                     "o.dim_size(0) != batch_size: ", o_tensor->dim_size(0),
540                     " vs. ", batch_size));
541     OP_REQUIRES(ctx, o_tensor->dim_size(1) == cell_size,
542                 errors::InvalidArgument(
543                     "o.dim_size(1) != cell_size: ", o_tensor->dim_size(1),
544                     " vs. ", cell_size));
545 
546     OP_REQUIRES(ctx, ci_tensor->dim_size(0) == batch_size,
547                 errors::InvalidArgument(
548                     "ci.dim_size(0) != batch_size: ", ci_tensor->dim_size(0),
549                     " vs. ", batch_size));
550     OP_REQUIRES(ctx, ci_tensor->dim_size(1) == cell_size,
551                 errors::InvalidArgument(
552                     "ci.dim_size(1) != cell_size: ", ci_tensor->dim_size(1),
553                     " vs. ", cell_size));
554 
555     OP_REQUIRES(ctx, co_tensor->dim_size(0) == batch_size,
556                 errors::InvalidArgument(
557                     "co.dim_size(0) != batch_size: ", co_tensor->dim_size(0),
558                     " vs. ", batch_size));
559     OP_REQUIRES(ctx, co_tensor->dim_size(1) == cell_size,
560                 errors::InvalidArgument(
561                     "co.dim_size(1) != cell_size: ", co_tensor->dim_size(1),
562                     " vs. ", cell_size));
563 
564     OP_REQUIRES(ctx, cs_grad_tensor->dim_size(0) == batch_size,
565                 errors::InvalidArgument(
566                     "cs_grad_tensor.dims(0) != batch_size: ",
567                     cs_grad_tensor->dim_size(0), " vs. ", batch_size));
568     OP_REQUIRES(ctx, cs_grad_tensor->dim_size(1) == cell_size,
569                 errors::InvalidArgument("cs_grad_tensor.dims(1) != cell_size: ",
570                                         cs_grad_tensor->dim_size(1), " vs. ",
571                                         cell_size));
572 
573     OP_REQUIRES(ctx, h_grad_tensor->dim_size(0) == batch_size,
574                 errors::InvalidArgument("h_grad_tensor.dims(0) != batch_size: ",
575                                         h_grad_tensor->dim_size(0), " vs. ",
576                                         batch_size));
577     OP_REQUIRES(ctx, h_grad_tensor->dim_size(1) == cell_size,
578                 errors::InvalidArgument("h_grad_tensor.dims(1) != cell_size: ",
579                                         h_grad_tensor->dim_size(1), " vs. ",
580                                         cell_size));
581 
582     // Allocate our output tensors.
583     Tensor* cs_prev_grad_tensor = nullptr;
584     OP_REQUIRES_OK(
585         ctx, ctx->forward_input_or_allocate_output(
586                  {"cs_grad"}, "cs_prev_grad",
587                  TensorShape({batch_size, cell_size}), &cs_prev_grad_tensor));
588 
589     Tensor* dicfo_tensor = nullptr;
590     OP_REQUIRES_OK(ctx, ctx->allocate_output(
591                             "dicfo", TensorShape({batch_size, cell_size * 4}),
592                             &dicfo_tensor));
593 
594     Tensor* wci_grad_tensor = nullptr;
595     OP_REQUIRES_OK(
596         ctx, ctx->forward_input_or_allocate_output(
597                  {"wci"}, "wci_grad", wci_tensor->shape(), &wci_grad_tensor));
598 
599     Tensor* wcf_grad_tensor = nullptr;
600     OP_REQUIRES_OK(
601         ctx, ctx->forward_input_or_allocate_output(
602                  {"wcf"}, "wcf_grad", wcf_tensor->shape(), &wcf_grad_tensor));
603 
604     Tensor* wco_grad_tensor = nullptr;
605     OP_REQUIRES_OK(
606         ctx, ctx->forward_input_or_allocate_output(
607                  {"wco"}, "wco_grad", wco_tensor->shape(), &wco_grad_tensor));
608 
609     // Allocate our temp tensors.
610     Tensor do_tensor;
611     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
612                                            TensorShape({batch_size, cell_size}),
613                                            &do_tensor));
614 
615     Tensor dcs_tensor;
616     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
617                                            TensorShape({batch_size, cell_size}),
618                                            &dcs_tensor));
619 
620     Tensor dci_tensor;
621     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
622                                            TensorShape({batch_size, cell_size}),
623                                            &dci_tensor));
624 
625     Tensor df_tensor;
626     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
627                                            TensorShape({batch_size, cell_size}),
628                                            &df_tensor));
629 
630     Tensor di_tensor;
631     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
632                                            TensorShape({batch_size, cell_size}),
633                                            &di_tensor));
634 
635     const Device& device = ctx->eigen_device<Device>();
636 
637     functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<T>());
638     functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
639     functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());
640 
641     functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
642                                                        cell_size)(
643         ctx, device, use_peephole_, x_tensor->matrix<T>(),
644         cs_prev_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
645         w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
646         wco_tensor->vec<T>(), b_tensor->vec<T>(), i_tensor->matrix<T>(),
647         cs_tensor->matrix<T>(), f_tensor->matrix<T>(), o_tensor->matrix<T>(),
648         ci_tensor->matrix<T>(), co_tensor->matrix<T>(),
649         cs_grad_tensor->matrix<T>(), h_grad_tensor->matrix<T>(),
650         do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
651         df_tensor.matrix<T>(), di_tensor.matrix<T>(), dicfo_tensor->matrix<T>(),
652         cs_prev_grad_tensor->matrix<T>(), wci_grad_tensor->vec<T>(),
653         wcf_grad_tensor->vec<T>(), wco_grad_tensor->vec<T>());
654   }
655 
656  protected:
657   bool use_peephole_;
658 };
659 
660 #define REGISTER_KERNEL(T)                                                 \
661   REGISTER_KERNEL_BUILDER(                                                 \
662       Name("LSTMBlockCellGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
663       LSTMBlockCellGradOp<CPUDevice, T, false>);
664 REGISTER_KERNEL(float);
665 REGISTER_KERNEL(Eigen::half);
666 #undef REGISTER_KERNEL
667 
668 #if GOOGLE_CUDA
669 namespace functor {
670 #define DECLARE_GPU_SPEC(T)                                                   \
671   template <>                                                                 \
672   void LSTMBlockCellBprop<GPUDevice, T, true>::operator()(                    \
673       OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,            \
674       typename TTypes<T>::ConstMatrix x,                                      \
675       typename TTypes<T>::ConstMatrix cs_prev,                                \
676       typename TTypes<T>::ConstMatrix h_prev,                                 \
677       typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,    \
678       typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,     \
679       typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i,      \
680       typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,  \
681       typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,  \
682       typename TTypes<T>::ConstMatrix co,                                     \
683       typename TTypes<T>::ConstMatrix cs_grad,                                \
684       typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
685       typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,         \
686       typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,           \
687       typename TTypes<T>::Matrix dicfo,                                       \
688       typename TTypes<T>::Matrix cs_prev_grad,                                \
689       typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,     \
690       typename TTypes<T>::Vec wco_grad);                                      \
691                                                                               \
692   extern template struct LSTMBlockCellBprop<GPUDevice, T,                     \
693                                             true /* USE_CUBLAS */>;
694 
695 DECLARE_GPU_SPEC(float);
696 DECLARE_GPU_SPEC(Eigen::half);
697 // DECLARE_GPU_SPEC(double);
698 #undef DECLARE_GPU_SPEC
699 }  // namespace functor
700 
701 #define REGISTER_GPU_KERNEL(T)                                             \
702   REGISTER_KERNEL_BUILDER(                                                 \
703       Name("LSTMBlockCellGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
704       LSTMBlockCellGradOp<GPUDevice, T, true>);
705 
706 REGISTER_GPU_KERNEL(float);
707 REGISTER_GPU_KERNEL(Eigen::half);
708 // REGISTER_GPU_KERNEL(double);
709 #undef REGISTER_GPU_KERNEL
710 #endif  // GOOGLE_CUDA
711 
712 namespace {
713 
714 // This helper class can be used to access timeslices of a 3D tensor. If a slice
715 // happens to be unaligned (usually because both the batch size and the number of
716 // cells are odd - this isn't common), slicing involves overhead, since the data
717 // needs to be copied. However, if all slices are aligned, no bits are copied. In
718 // the cases where copying is needed, the outputs also have to be copied back.
719 // At the end of each time step, call FinishTimeStep, which does this and also
720 // allows temporary tensors to be reused.
721 template <typename Device, typename T>
722 class SliceHelper {
723  public:
724   explicit SliceHelper(OpKernelContext* ctx)
725       : ctx_(ctx), device_(ctx_->eigen_device<Device>()) {}
726 
727   ~SliceHelper() {
728     CHECK(copy_out_.empty());
729     for (const auto& entry : pool_) {
730       CHECK(!entry.second.second);  // nothing is in use
731     }
732   }
733 
734   // Slice through an input tensor. This may copy unaligned slices, but no
735   // copying back will be done at the end.
736   const Tensor InputSlice(const Tensor& t, int pos, const string& name) {
737     Tensor res = UnalignedSlice(t, pos);
738     if (res.IsAligned()) {
739       return res;
740     } else {
741       return AlignTensor(res, name);
742     }
743   }
744 
745   // Slice through an output tensor. This may copy unaligned slices, and
746   // schedule copying back on destruction.
747   Tensor OutputSlice(Tensor* t, int pos, const string& name) {
748     Tensor res = UnalignedSlice(*t, pos);
749     if (res.IsAligned()) {
750       return res;
751     } else {
752       Tensor aligned = AlignTensor(res, name);
753       copy_out_.emplace_back(res, aligned);
754       return aligned;
755     }
756   }
757 
758   void FinishTimeStep() {
759     for (const auto& p : copy_out_) {
760       const Tensor& aligned = p.second;
761       Tensor original = p.first;
762       // Copy from aligned back to original.
763       functor::TensorCopyToUnaligned<Device, T>()(device_, aligned.flat<T>(),
764                                                   original.unaligned_flat<T>());
765     }
766     copy_out_.clear();
767     // Mark all entries as not in use.
768     for (auto& entry : pool_) {
769       entry.second.second = false;
770     }
771   }
772 
773  private:
774   // Return a slice at position 'pos'. Result may be unaligned. The resulting
775   // tensor always shares data with the source tensor.
776   Tensor UnalignedSlice(const Tensor& t, int pos) const {
777     Tensor res;
778     // CHECK should never fail here, since the number of elements must match
779     CHECK(res.CopyFrom(t.Slice(pos, pos + 1), {t.dim_size(1), t.dim_size(2)}));
780     return res;
781   }
782 
783   // Assumes input is not aligned, creates a temporary aligned tensor of the
784   // same shape and copies the original tensor's content into it.
785   Tensor AlignTensor(const Tensor& t, const string& name) {
786     VLOG(1) << "AlignTensor called for " << name << ", shape "
787             << t.shape().DebugString()
788             << ". This is unnecessary copying. Consider using shapes with even "
789             << "sizes";
790     Tensor aligned;
791     auto found = pool_.find(name);
792     if (found != pool_.end()) {  // found in pool
793       CHECK(!found->second.second) << "Tensor " << name << " is in use";
794       found->second.second = true;  // mark in use
795       aligned = found->second.first;
796       CHECK(aligned.shape().IsSameSize(t.shape()));
797       CHECK_EQ(aligned.dtype(), t.dtype());
798     } else {  // allocate a new temporary tensor
799       TF_CHECK_OK(ctx_->allocate_temp(t.dtype(), t.shape(), &aligned));
800       pool_.emplace(name, std::make_pair(aligned, true));
801     }
802     functor::TensorCopyUnaligned<Device, T>()(device_, t.unaligned_flat<T>(),
803                                               aligned.flat<T>());
804     return aligned;
805   }
806 
807   // Tensors to be copied.
808   std::vector<std::pair<Tensor, const Tensor>> copy_out_;
809   // A pool of pre-allocated temporary tensors, with an indicator for whether
810   // it's in use.
811   std::map<string, std::pair<Tensor, bool>> pool_;
812   // Op context
813   OpKernelContext* ctx_ = nullptr;
814   // Device
815   const Device& device_;
816 };
817 
818 }  // namespace
819 
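// Kernel for the BlockLSTM op: applies the fused LSTM cell across the time
// dimension of a time-major [timelen, batch_size, ...] input. Each iteration
// feeds the previous step's cs and h output slices back in as the recurrent
// state via SliceHelper, and any steps beyond seq_len_max are zero-filled.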
820 template <typename Device, typename T, bool USE_CUBLAS>
821 class BlockLSTMOp : public OpKernel {
822  public:
823   explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
824     OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
825     OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
826     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
827   }
828 
829   void Compute(OpKernelContext* ctx) override {
830     const Tensor* seq_len_max_tensor = nullptr;
831     OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));
832 
833     const Tensor* x;
834     OP_REQUIRES_OK(ctx, ctx->input("x", &x));
835     OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
836     const int64 timelen = x->dim_size(0);
837     const int64 batch_size = x->dim_size(1);
838     const int64 input_size = x->dim_size(2);
839 
840     const Tensor* cs_prev_tensor = nullptr;
841     OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
842     OP_REQUIRES(ctx, cs_prev_tensor->dims() == 2,
843                 errors::InvalidArgument("cs_prev must be 2D"));
844     OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
845                 errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
846                                         cs_prev_tensor->dim_size(0), " vs. ",
847                                         batch_size));
848     const int64 cell_size = cs_prev_tensor->dim_size(1);
849 
850     if (batch_size * input_size % 2 == 1) {
851       LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and "
852                    << "input_size are odd. You are using: batch_size="
853                    << batch_size << ", input_size=" << input_size;
854     }
855     if (batch_size * cell_size % 2 == 1) {
856       LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and "
857                    << "cell_size are odd. You are using: batch_size="
858                    << batch_size << ", cell_size=" << cell_size;
859     }
860 
861     const Tensor* h_prev_tensor = nullptr;
862     OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
863     OP_REQUIRES(ctx, h_prev_tensor->dims() == 2,
864                 errors::InvalidArgument("h_prev must be 2D"));
865     OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
866                 errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
867                                         h_prev_tensor->dim_size(0), " vs. ",
868                                         batch_size));
869     OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
870                 errors::InvalidArgument(
871                     "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
872                     " vs. ", cell_size));
873 
874     const Tensor* w_tensor = nullptr;
875     OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
876     OP_REQUIRES(ctx, w_tensor->dims() == 2,
877                 errors::InvalidArgument("w must be 2D"));
878     OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
879                 errors::InvalidArgument(
880                     "w.dim_size(0) != input_size + cell_size: ",
881                     w_tensor->dim_size(0), " vs. ", input_size + cell_size));
882     OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
883                 errors::InvalidArgument(
884                     "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
885                     " vs. ", cell_size * 4));
886 
887     const Tensor* wci_tensor = nullptr;
888     OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
889     OP_REQUIRES(ctx, wci_tensor->dims() == 1,
890                 errors::InvalidArgument("wci must be 1D"));
891     OP_REQUIRES(ctx, wci_tensor->dim_size(0) == cell_size,
892                 errors::InvalidArgument(
893                     "wci.dim_size(0) != cell_size: ", wci_tensor->dim_size(0),
894                     " vs. ", cell_size));
895 
896     const Tensor* wcf_tensor = nullptr;
897     OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
898     OP_REQUIRES(ctx, wcf_tensor->dims() == 1,
899                 errors::InvalidArgument("wcf must be 1D"));
900     OP_REQUIRES(ctx, wcf_tensor->dim_size(0) == cell_size,
901                 errors::InvalidArgument(
902                     "wcf.dim_size(0) != cell_size: ", wcf_tensor->dim_size(0),
903                     " vs. ", cell_size));
904 
905     const Tensor* wco_tensor = nullptr;
906     OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
907     OP_REQUIRES(ctx, wco_tensor->dims() == 1,
908                 errors::InvalidArgument("wco must be 1D"));
909     OP_REQUIRES(ctx, wco_tensor->dim_size(0) == cell_size,
910                 errors::InvalidArgument(
911                     "wco.dim_size(0) != cell_size: ", wco_tensor->dim_size(0),
912                     " vs. ", cell_size));
913 
914     const Tensor* b_tensor = nullptr;
915     OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
916     OP_REQUIRES(ctx, b_tensor->dims() == 1,
917                 errors::InvalidArgument("b must be 1D"));
918     OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
919                 errors::InvalidArgument(
920                     "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
921                     " vs. ", cell_size * 4));
922 
923     TensorShape batch_cell_shape({timelen, batch_size, cell_size});
924     Tensor* i_out;
925     OP_REQUIRES_OK(ctx, ctx->allocate_output("i", batch_cell_shape, &i_out));
926 
927     Tensor* cs_out;
928     OP_REQUIRES_OK(ctx, ctx->allocate_output("cs", batch_cell_shape, &cs_out));
929 
930     Tensor* f_out;
931     OP_REQUIRES_OK(ctx, ctx->allocate_output("f", batch_cell_shape, &f_out));
932 
933     Tensor* o_out;
934     OP_REQUIRES_OK(ctx, ctx->allocate_output("o", batch_cell_shape, &o_out));
935 
936     Tensor* ci_out;
937     OP_REQUIRES_OK(ctx, ctx->allocate_output("ci", batch_cell_shape, &ci_out));
938 
939     Tensor* co_out;
940     OP_REQUIRES_OK(ctx, ctx->allocate_output("co", batch_cell_shape, &co_out));
941 
942     Tensor* h_out;
943     OP_REQUIRES_OK(ctx, ctx->allocate_output("h", batch_cell_shape, &h_out));
944 
945     Tensor xh_tensor;
946     OP_REQUIRES_OK(ctx, ctx->allocate_temp(
947                             DataTypeToEnum<T>::v(),
948                             TensorShape({batch_size, input_size + cell_size}),
949                             &xh_tensor));
950 
951     Tensor icfo_tensor;
952     OP_REQUIRES_OK(ctx,
953                    ctx->allocate_temp(DataTypeToEnum<T>::v(),
954                                       TensorShape({batch_size, cell_size * 4}),
955                                       &icfo_tensor));
956 
957     const Device& device = ctx->eigen_device<Device>();
958 
959     const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
960     SliceHelper<Device, T> slicer(ctx);
961     for (int64 t = 0; t < seq_len_max; ++t) {
962       const Tensor x_tensor = slicer.InputSlice(*x, t, "x");
963       const Tensor& cs_prev_tensor2 =
964           t == 0 ? *cs_prev_tensor
965                  : slicer.OutputSlice(cs_out, t - 1, "cs_prev");
966       const Tensor& h_prev_tensor2 =
967           t == 0 ? *h_prev_tensor : slicer.OutputSlice(h_out, t - 1, "h_prev");
968 
969       Tensor i_tensor = slicer.OutputSlice(i_out, t, "i_out");
970       Tensor cs_tensor = slicer.OutputSlice(cs_out, t, "cs_out");
971       Tensor f_tensor = slicer.OutputSlice(f_out, t, "f_out");
972       Tensor o_tensor = slicer.OutputSlice(o_out, t, "o_out");
973       Tensor ci_tensor = slicer.OutputSlice(ci_out, t, "ci_out");
974       Tensor co_tensor = slicer.OutputSlice(co_out, t, "co_out");
975       Tensor h_tensor = slicer.OutputSlice(h_out, t, "h_out");
976 
977       functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
978                                                          cell_size)(
979           ctx, device, forget_bias_, cell_clip_, use_peephole_,
980           x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(),
981           h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(),
982           wci_tensor->vec<T>(), wcf_tensor->vec<T>(), wco_tensor->vec<T>(),
983           b_tensor->vec<T>(), xh_tensor.matrix<T>(), i_tensor.matrix<T>(),
984           cs_tensor.matrix<T>(), f_tensor.matrix<T>(), o_tensor.matrix<T>(),
985           ci_tensor.matrix<T>(), co_tensor.matrix<T>(), icfo_tensor.matrix<T>(),
986           h_tensor.matrix<T>());
987       slicer.FinishTimeStep();
988     }
989 
990     if (seq_len_max < timelen) {
991       Tensor cs_tensor = cs_out->Slice(seq_len_max, timelen);
992       Tensor h_tensor = h_out->Slice(seq_len_max, timelen);
993 
994       functor::TensorUnalignedZero<Device, T>()(device,
995                                                 cs_tensor.unaligned_flat<T>());
996       functor::TensorUnalignedZero<Device, T>()(device,
997                                                 h_tensor.unaligned_flat<T>());
998     }
999   }
1000 
1001  private:
1002   float forget_bias_;
1003   float cell_clip_;
1004   bool use_peephole_;
1005 };
1006 
1007 #define REGISTER_KERNEL(T)                                         \
1008   REGISTER_KERNEL_BUILDER(                                         \
1009       Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
1010       BlockLSTMOp<CPUDevice, T, false>);
1011 REGISTER_KERNEL(float);
1012 REGISTER_KERNEL(Eigen::half);
1013 #undef REGISTER_KERNEL
1014 
1015 #if GOOGLE_CUDA
1016 namespace functor {
1017 #define DECLARE_GPU_SPEC(T)                                              \
1018   template <>                                                            \
1019   void TensorZero<GPUDevice, T>::operator()(const GPUDevice& d,          \
1020                                             typename TTypes<T>::Flat t); \
1021                                                                          \
1022   extern template struct TensorZero<GPUDevice, T>;                       \
1023                                                                          \
1024   template <>                                                            \
1025   void TensorUnalignedZero<GPUDevice, T>::operator()(                    \
1026       const GPUDevice& d, typename TTypes<T>::UnalignedFlat t);          \
1027                                                                          \
1028   extern template struct TensorUnalignedZero<GPUDevice, T>;
1029 
1030 DECLARE_GPU_SPEC(float);
1031 DECLARE_GPU_SPEC(Eigen::half);
1032 // DECLARE_GPU_SPEC(double);
1033 #undef DECLARE_GPU_SPEC
1034 }  // end namespace functor
1035 
1036 #define REGISTER_GPU_KERNEL(T)                           \
1037   REGISTER_KERNEL_BUILDER(Name("BlockLSTM")              \
1038                               .Device(DEVICE_GPU)        \
1039                               .HostMemory("seq_len_max") \
1040                               .TypeConstraint<T>("T"),   \
1041                           BlockLSTMOp<GPUDevice, T, true>);
1042 
1043 REGISTER_GPU_KERNEL(float);
1044 REGISTER_GPU_KERNEL(Eigen::half);
1045 // REGISTER_GPU_KERNEL(double);
1046 #undef REGISTER_GPU_KERNEL
1047 #endif  // GOOGLE_CUDA
1048 
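// Kernel for the BlockLSTMGrad op: backpropagates through BlockLSTM, producing
// gradients for x, the initial cs/h states, the combined weight matrix w, the
// peephole weights, and the bias b.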
1049 template <typename Device, typename T, bool USE_CUBLAS>
1050 class BlockLSTMGradOp : public OpKernel {
1051  public:
1052   explicit BlockLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
1053     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
1054   }
1055 
1056   void Compute(OpKernelContext* ctx) override {
1057     const Tensor* seq_len_max_tensor = nullptr;
1058     OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));
1059 
1060     const Tensor* x;
1061     OP_REQUIRES_OK(ctx, ctx->input("x", &x));
1062     OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
1063     const int64 timelen = x->dim_size(0);
1064     const int64 batch_size = x->dim_size(1);
1065     const int64 input_size = x->dim_size(2);
1066 
1067     const Tensor* cs_prev_tensor = nullptr;
1068     OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
1069 
1070     const Tensor* h_prev_tensor = nullptr;
1071     OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
1072 
1073     const Tensor* w_tensor = nullptr;
1074     OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
1075     const int64 cell_size = w_tensor->dim_size(1) / 4;
1076     OP_REQUIRES(ctx, input_size + cell_size == w_tensor->dim_size(0),
1077                 errors::InvalidArgument(
1078                     "w matrix rows don't match: ", input_size + cell_size,
1079                     " vs. ", w_tensor->dim_size(0)));
1080 
1081     const Tensor* wci_tensor = nullptr;
1082     OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
1083 
1084     const Tensor* wcf_tensor = nullptr;
1085     OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
1086 
1087     const Tensor* wco_tensor = nullptr;
1088     OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
1089 
1090     const Tensor* b_tensor = nullptr;
1091     OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
1092     OP_REQUIRES(
1093         ctx, cell_size == b_tensor->dim_size(0) / 4,
1094         errors::InvalidArgument("w and b cell_size don't match: ", cell_size,
1095                                 " vs. ", b_tensor->dim_size(0)));
1096 
1097     const Tensor* i_out = nullptr;
1098     OP_REQUIRES_OK(ctx, ctx->input("i", &i_out));
1099 
1100     const Tensor* cs_out = nullptr;
1101     OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_out));
1102 
1103     const Tensor* f_out = nullptr;
1104     OP_REQUIRES_OK(ctx, ctx->input("f", &f_out));
1105 
1106     const Tensor* o_out = nullptr;
1107     OP_REQUIRES_OK(ctx, ctx->input("o", &o_out));
1108 
1109     const Tensor* ci_out = nullptr;
1110     OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_out));
1111 
1112     const Tensor* co_out = nullptr;
1113     OP_REQUIRES_OK(ctx, ctx->input("co", &co_out));
1114 
1115     const Tensor* h_out = nullptr;
1116     OP_REQUIRES_OK(ctx, ctx->input("h", &h_out));
1117 
1118     const Tensor* cs_grad = nullptr;
1119     OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad));
1120 
1121     const Tensor* h_grad = nullptr;
1122     OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad));
1123 
    TensorShape batch_input_shape({timelen, batch_size, input_size});
    Tensor* x_grad = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("x_grad", batch_input_shape, &x_grad));

    Tensor* cs_prev_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("cs_prev_grad", cs_prev_tensor->shape(),
                                        &cs_prev_grad_tensor));

    Tensor* h_prev_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("h_prev_grad", h_prev_tensor->shape(),
                                        &h_prev_grad_tensor));

    Tensor* w_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("w_grad", w_tensor->shape(), &w_grad_tensor));

    Tensor* wci_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("wci_grad", wci_tensor->shape(),
                                             &wci_grad_tensor));

    Tensor* wcf_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("wcf_grad", wcf_tensor->shape(),
                                             &wcf_grad_tensor));

    Tensor* wco_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("wco_grad", wco_tensor->shape(),
                                             &wco_grad_tensor));

    Tensor* b_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("b_grad", b_tensor->shape(), &b_grad_tensor));

    TensorShape batch_cell_shape({batch_size, cell_size});

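    // Per-time-step scratch tensors: the concatenated [x, h] buffer and the
    // gate/state gradients produced by the single-cell backprop.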
    Tensor xh_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &xh_tensor));

    Tensor xh_grad_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           xh_tensor.shape(), &xh_grad_tensor));

    Tensor do_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &do_tensor));

    Tensor dcs_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &dcs_tensor));

    Tensor dci_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &dci_tensor));

    Tensor df_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &df_tensor));

    Tensor di_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &di_tensor));

    Tensor dicfo_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                      TensorShape({batch_size, cell_size * 4}),
                                      &dicfo_tensor));

    Tensor cs_grad_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &cs_grad_tensor));

    Tensor h_grad_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &h_grad_tensor));

    const Device& device = ctx->eigen_device<Device>();

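    // Gradients are accumulated across time steps, so zero the running
    // accumulators before the backward loop.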
    functor::TensorZero<Device, T>()(device, cs_grad_tensor.flat<T>());
    functor::TensorZero<Device, T>()(device, cs_prev_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, h_grad_tensor.flat<T>());
    functor::TensorZero<Device, T>()(device, h_prev_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, w_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, b_grad_tensor->flat<T>());

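    // seq_len_max is read on the host (it is registered as HostMemory for the
    // GPU kernel) and bounds the backward walk over time.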
    const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
    SliceHelper<Device, T> slicer(ctx);
    for (int64 t = seq_len_max - 1; t >= 0; --t) {
      const Tensor& x_tensor = slicer.InputSlice(*x, t, "x");
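      // At t == 0 the recurrent state comes from the cs_prev/h_prev inputs;
      // otherwise it is the forward output at t - 1.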
      const Tensor& cs_prev_tensor2 =
          t == 0 ? *cs_prev_tensor
                 : slicer.InputSlice(*cs_out, t - 1, "cs_prev");
      const Tensor& h_prev_tensor2 =
          t == 0 ? *h_prev_tensor : slicer.InputSlice(*h_out, t - 1, "h_prev");
      const Tensor& i_tensor = slicer.InputSlice(*i_out, t, "i_out");
      const Tensor& cs_tensor = slicer.InputSlice(*cs_out, t, "cs_out");
      const Tensor& f_tensor = slicer.InputSlice(*f_out, t, "f_out");
      const Tensor& o_tensor = slicer.InputSlice(*o_out, t, "o_out");
      const Tensor& ci_tensor = slicer.InputSlice(*ci_out, t, "ci_out");
      const Tensor& co_tensor = slicer.InputSlice(*co_out, t, "co_out");

      // Combine previous cs grad and cs grad coming on top.
      const Tensor& const_cs_prev_grad_tensor = *cs_prev_grad_tensor;
      const Tensor const_cs_grad_slice =
          slicer.InputSlice(*cs_grad, t, "cs_grad");
      functor::TensorAdd<Device, T>()(
          device, const_cs_prev_grad_tensor.flat<T>(),
          const_cs_grad_slice.flat<T>(), cs_grad_tensor.flat<T>());

      // Combine previous h grad and h grad coming on top.
      const Tensor& const_h_prev_grad_tensor = *h_prev_grad_tensor;
      const Tensor const_h_grad_slice = slicer.InputSlice(*h_grad, t, "h_grad");
      functor::TensorAdd<Device, T>()(
          device, const_h_prev_grad_tensor.flat<T>(),
          const_h_grad_slice.flat<T>(), h_grad_tensor.flat<T>());

      const Tensor& const_cs_grad_tensor = cs_grad_tensor;
      const Tensor& const_h_grad_tensor = h_grad_tensor;

      Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad");
      functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                     cell_size)(
          ctx, device, use_peephole_, x_tensor.matrix<T>(),
          cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(),
          w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
          wco_tensor->vec<T>(), b_tensor->vec<T>(), xh_tensor.matrix<T>(),
          i_tensor.matrix<T>(), cs_tensor.matrix<T>(), f_tensor.matrix<T>(),
          o_tensor.matrix<T>(), ci_tensor.matrix<T>(), co_tensor.matrix<T>(),
          const_cs_grad_tensor.matrix<T>(), const_h_grad_tensor.matrix<T>(),
          do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
          df_tensor.matrix<T>(), di_tensor.matrix<T>(),
          dicfo_tensor.matrix<T>(), cs_prev_grad_tensor->matrix<T>(),
          h_prev_grad_tensor->matrix<T>(), xh_grad_tensor.matrix<T>(),
          x_grad_tensor.matrix<T>(), w_grad_tensor->matrix<T>(),
          wci_grad_tensor->vec<T>(), wcf_grad_tensor->vec<T>(),
          wco_grad_tensor->vec<T>(), b_grad_tensor->vec<T>());
      slicer.FinishTimeStep();
    }

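    // Time steps at or beyond seq_len_max contribute no gradient, so the tail
    // of x_grad is simply zeroed.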
    if (seq_len_max < timelen) {
      Tensor x_grad_tensor = x_grad->Slice(seq_len_max, timelen);
      functor::TensorUnalignedZero<Device, T>()(
          device, x_grad_tensor.unaligned_flat<T>());
    }
  }

 private:
  bool use_peephole_;
};

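// CPU kernel registrations for BlockLSTMGrad.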
#define REGISTER_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      BlockLSTMGradOp<CPUDevice, T, false>);
REGISTER_KERNEL(float);
REGISTER_KERNEL(Eigen::half);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
namespace functor {
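// Forward declarations of the GPU functor specializations; their definitions
// live in the separately compiled CUDA translation unit.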
#define DECLARE_GPU_SPEC(T)                                                    \
  template <>                                                                  \
  void TensorCopy<GPUDevice, T>::operator()(const GPUDevice& d,                \
                                            typename TTypes<T>::ConstFlat src, \
                                            typename TTypes<T>::Flat dst);     \
                                                                               \
  template <>                                                                  \
  void TensorCopyUnaligned<GPUDevice, T>::operator()(                          \
      const GPUDevice& d, typename TTypes<T>::UnalignedConstFlat src,          \
      typename TTypes<T>::Flat dst);                                           \
                                                                               \
  template <>                                                                  \
  void TensorCopyToUnaligned<GPUDevice, T>::operator()(                        \
      const GPUDevice& d, typename TTypes<T>::ConstFlat src,                   \
      typename TTypes<T>::UnalignedFlat dst);                                  \
                                                                               \
  template <>                                                                  \
  void TensorAdd<GPUDevice, T>::operator()(                                    \
      const GPUDevice& d, typename TTypes<T>::ConstFlat a,                     \
      typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c);            \
                                                                               \
  template <>                                                                  \
  void BlockLSTMBprop<GPUDevice, T, true>::operator()(                         \
      OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,             \
      typename TTypes<T>::ConstMatrix x,                                       \
      typename TTypes<T>::ConstMatrix cs_prev,                                 \
      typename TTypes<T>::ConstMatrix h_prev,                                  \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,           \
      typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,   \
      typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,    \
      typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,  \
      typename TTypes<T>::ConstMatrix cs_grad,                                 \
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,  \
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,          \
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,            \
      typename TTypes<T>::Matrix dicfo,                                        \
      typename TTypes<T>::Matrix cs_prev_grad,                                 \
      typename TTypes<T>::Matrix h_prev_grad,                                  \
      typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad,   \
      typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad,     \
      typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad,      \
      typename TTypes<T>::Vec b_grad);                                         \
                                                                               \
  extern template struct TensorCopy<GPUDevice, T>;                             \
  extern template struct TensorAdd<GPUDevice, T>;                              \
  extern template struct BlockLSTMBprop<GPUDevice, T, true>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
// DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

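// GPU kernel registrations for BlockLSTMGrad. seq_len_max stays in host
// memory because Compute() reads its value on the host.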
#define REGISTER_GPU_KERNEL(T)                           \
  REGISTER_KERNEL_BUILDER(Name("BlockLSTMGrad")          \
                              .Device(DEVICE_GPU)        \
                              .HostMemory("seq_len_max") \
                              .TypeConstraint<T>("T"),   \
                          BlockLSTMGradOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
REGISTER_GPU_KERNEL(Eigen::half);
// REGISTER_GPU_KERNEL(double);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow