/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/contrib/rnn/kernels/lstm_ops.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/eigen_activations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {
namespace functor {

typedef Eigen::GpuDevice GPUDevice;

namespace {

struct FloatToHalf {
  __host__ __device__ EIGEN_STRONG_INLINE Eigen::half operator()(
      const float& x) const {
    return Eigen::half_impl::float_to_half_rtne(x);
  }
};

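// strict_cast<U>(t) converts 't' to type 'U' without relying on implicit
// conversions: it is the identity when U == T, and the Eigen::half
// specialization below routes float -> half through FloatToHalf
// (round-to-nearest-even).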
template <typename U, typename T>
__host__ __device__ EIGEN_STRONG_INLINE
    typename std::enable_if<!std::is_same<T, U>::value, U>::type
    strict_cast(T t);

template <typename U, typename T>
__host__ __device__ EIGEN_STRONG_INLINE
    typename std::enable_if<std::is_same<T, U>::value, U>::type
    strict_cast(T t) {
  return t;
}

template <>
__host__ __device__ EIGEN_STRONG_INLINE Eigen::half
strict_cast<Eigen::half, float>(float t) {
  return FloatToHalf()(t);
}

}  // namespace

template <typename T>
struct TensorZero<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat t) {
    t.device(d) = t.constant(strict_cast<T>(0.f));
  }
};

template <typename T>
struct TensorUnalignedZero<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::UnalignedFlat t) {
    t.device(d) = t.constant(strict_cast<T>(0.f));
  }
};

namespace {

// Adds bias, applies non-linearities and gates.
//
// Launch with a 2D setup such that there is one thread per (example,
// activation) with 'x' governing example index and 'y' governing activation.
//
// Launch with blocks of (batch x 32)
//
// TODO(b/67600500): Try making 'use_peephole' a template parameter.
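//
// Per (example, activation) element the kernel computes (peephole terms in
// brackets apply only when 'use_peephole' is true):
//   i  = sigmoid(icfo_i + b_i [+ wci .* cs_prev])
//   ci = tanh(icfo_c + b_c)
//   f  = sigmoid(icfo_f + b_f + forget_bias [+ wcf .* cs_prev])
//   cs = clip(i .* ci + f .* cs_prev, cell_clip)   (clip only if cell_clip > 0)
//   co = tanh(cs)
//   o  = sigmoid(icfo_o + b_o [+ wco .* cs])
//   h  = o .* co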
template <typename T, bool use_peephole>
__global__ void lstm_gates(const T* icfo, const T* b, const T* cs_prev,
                           const T* wci, const T* wcf, const T* wco, T* o, T* h,
                           T* ci, T* cs, T* co, T* i, T* f,
                           const float forget_bias, const float cell_clip,
                           const int batch_size, const int cell_size) {
  const int batch_id = blockIdx.x * blockDim.x + threadIdx.x;
  const int act_id = blockIdx.y * blockDim.y + threadIdx.y;

  T forget_bias_t = strict_cast<T>(forget_bias);
  T cell_clip_t = strict_cast<T>(cell_clip);

  if (batch_id >= batch_size || act_id >= cell_size) return;

  // The following code assumes the input arrays are of the following
  // shapes and interpretations.
  //
  // 1) 'icfo' is a matrix such that,
  //
  //   cell_size  cell_size  cell_size  cell_size
  //  +----------+----------+----------+----------+
  //  |          |          |          |          |
  //  |    i     |    c     |    f     |    o     |  batch_size
  //  |          |          |          |          |
  //  +----------+----------+----------+----------+
  //
  // 'gid' is the index assigned to this thread for 'icfo' in the 'i' submatrix.
  //
  // 2) 'b' is a vector such that,
  //
  //   cell_size  cell_size  cell_size  cell_size
  //  +----------+----------+----------+----------+
  //  |    i     |    c     |    f     |    o     |  1
  //  +----------+----------+----------+----------+
  //
  // 'act_id' is the index assigned to this thread for 'b' in the 'i' subvector.
  //
  // 3) 'wc{i,f,o}' are vectors such that,
  //
  //   cell_size
  //  +----------+
  //  |    i     |  1
  //  +----------+
  //
  //  'act_id' is the index to this thread.
  //
  // 4) All other matrices have the form,
  //
  //   cell_size
  //  +----------+
  //  |          |
  //  |    i     |  batch_size
  //  |          |
  //  +----------+
  //
  // 'cid' is the index assigned to this thread.
  //
  const int gid = batch_id * cell_size * 4 + act_id;
  const int cid = batch_id * cell_size + act_id;
  Eigen::internal::scalar_logistic_op<T> sigmoid_op;
  Eigen::internal::scalar_tanh_op<T> tanh_op;
  Eigen::scalar_clip_op<T> clip_op;

  T i_local;
  if (use_peephole) {
    i_local = sigmoid_op(icfo[0 * cell_size + gid] + b[0 * cell_size + act_id] +
                         cs_prev[cid] * wci[act_id]);
  } else {
    i_local = sigmoid_op(icfo[0 * cell_size + gid] + b[0 * cell_size + act_id]);
  }
  i[cid] = i_local;

  const T ci_local =
      tanh_op(icfo[1 * cell_size + gid] + b[1 * cell_size + act_id]);
  ci[cid] = ci_local;

  T f_local;
  if (use_peephole) {
    f_local = sigmoid_op(icfo[2 * cell_size + gid] + b[2 * cell_size + act_id] +
                         forget_bias_t + cs_prev[cid] * wcf[act_id]);
  } else {
    f_local = sigmoid_op(icfo[2 * cell_size + gid] + b[2 * cell_size + act_id] +
                         forget_bias_t);
  }
  f[cid] = f_local;

  T cs_local = i_local * ci_local + f_local * cs_prev[cid];
  if (cell_clip > 0.0f) {
    cs_local = clip_op(cs_local, cell_clip_t);
  }
  cs[cid] = cs_local;

  const T co_local = tanh_op(cs_local);
  co[cid] = co_local;

  T o_local;
  if (use_peephole) {
    o_local = sigmoid_op(icfo[3 * cell_size + gid] + b[3 * cell_size + act_id] +
                         cs_local * wco[act_id]);
  } else {
    o_local = sigmoid_op(icfo[3 * cell_size + gid] + b[3 * cell_size + act_id]);
  }
  o[cid] = o_local;

  h[cid] = o_local * co_local;
}

// Concatenate 'x' and 'h' and copy their contents into 'xh'.
template <typename T>
__global__ void concat_xh(T* xh, const T* x, const T* h_prev,
                          const int batch_size, const int cell_size,
                          const int input_size) {
  // Assumes 'x', 'h', and 'xh' are of the following shape,
  //
  //   input_size  cell_size
  //  +----------+----------+
  //  |          |          |
  //  |    x     |    h     |  batch_size
  //  |          |          |
  //  +----------+----------+
  //
  const int gid = blockDim.x * blockIdx.x + threadIdx.x;
  const int width = input_size + cell_size;

  if (gid >= width * batch_size) return;

  const int output_row = gid / width;
  const int output_col = gid % width;

  if (output_col < input_size) {  // x
    xh[gid] = x[output_row * input_size + output_col];
  } else {  // h
    xh[gid] = h_prev[output_row * cell_size + output_col - input_size];
  }
}

template <typename T>
void LSTMBlockCellFpropWithCUDA(
    OpKernelContext* ctx, const GPUDevice& d, const float forget_bias,
    const float cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x,
    typename TTypes<T>::ConstMatrix cs_prev,
    typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
    typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
    typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
    typename TTypes<T>::Matrix xh, typename TTypes<T>::Matrix i,
    typename TTypes<T>::Matrix cs, typename TTypes<T>::Matrix f,
    typename TTypes<T>::Matrix o, typename TTypes<T>::Matrix ci,
    typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
    typename TTypes<T>::Matrix h, int batch_size, int cell_size,
    int input_size) {
  const cudaStream_t& cu_stream = GetCudaStream(ctx);

  // Concatenate xh = [x, h].
  //
  // Each block is assigned 128 threads. Good values are in [128, 1024] and are
  // divisible by 32 (the size of a warp). The number of blocks is such that
  // there are enough to process all the data.
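  //
  // For example, with batch_size = 16, input_size = 64 and cell_size = 128,
  // this launches Eigen::divup(16 * 192, 128) = 24 blocks of 128 threads each.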
  const int block_dim = 128;
  const int grid_dim =
      Eigen::divup(batch_size * (cell_size + input_size), block_dim);
  TF_CHECK_OK(CudaLaunchKernel(concat_xh<T>, grid_dim, block_dim, 0, cu_stream,
                               xh.data(), x.data(), h_prev.data(), batch_size,
                               cell_size, input_size));

  // icfo = xh * w
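  //
  // 'xh' is [batch_size, input_size + cell_size] and 'w' is
  // [input_size + cell_size, 4 * cell_size], so 'icfo' is
  // [batch_size, 4 * cell_size].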
  typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
  TensorBlasGemm<GPUDevice, T, true /* USE_CUBLAS */>::compute(
      ctx, d, false, false, typename gemm_compute_type<T>::type(1.f), const_xh,
      w, typename gemm_compute_type<T>::type(0.f), icfo);

  // Add bias, apply non-linearities and gating.
  //
  // Use 2D blocks. The number of threads per block is equal to x * y, where x =
  // min(batch_size, 8) and y = 32. See above for guidance on number of
  // threads.
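  //
  // For example, batch_size = 16 gives blocks of 8 x 32 = 256 threads and a
  // grid of Eigen::divup(16, 8) x Eigen::divup(cell_size, 32) blocks.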
  dim3 block_dim_2d(std::min(batch_size, 8), 32);
  dim3 grid_dim_2d(Eigen::divup(batch_size, static_cast<int>(block_dim_2d.x)),
                   Eigen::divup(cell_size, static_cast<int>(block_dim_2d.y)));

  if (use_peephole) {
    TF_CHECK_OK(CudaLaunchKernel(
        lstm_gates<T, true>, grid_dim_2d, block_dim_2d, 0, cu_stream,
        icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(),
        wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(),
        i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size));
  } else {
    TF_CHECK_OK(CudaLaunchKernel(
        lstm_gates<T, false>, grid_dim_2d, block_dim_2d, 0, cu_stream,
        icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(),
        wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(),
        i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size));
  }
}

template <typename T>
__global__ void lstm_gates_bprop(
    const T* cs_prev,  // [batch_size, cell_size]
    const T* h_prev,   // [batch_size, cell_size]
    const T* w,        // [input_size + cell_size, 4 * cell_size]
    const T* wci,      // [cell_size]
    const T* wcf,      // [cell_size]
    const T* wco,      // [cell_size]
    const T* b,        // [4 * cell_size]
    const T* i,        // [batch_size, cell_size]
    const T* cs,       // [batch_size, cell_size]
    const T* f,        // [batch_size, cell_size]
    const T* o,        // [batch_size, cell_size]
    const T* ci,       // [batch_size, cell_size]
    const T* co,       // [batch_size, cell_size]
    const T* cs_grad,  // [batch_size, cell_size]
    const T* h_grad,   // [batch_size, cell_size]
    T* do_,            // [batch_size, cell_size]
    T* dcs,            // [batch_size, cell_size]
    T* dci,            // [batch_size, cell_size]
    T* df,             // [batch_size, cell_size]
    T* di,             // [batch_size, cell_size]
    T* dicfo,          // [batch_size, 4 * cell_size]
    T* cs_prev_grad,   // [batch_size, cell_size]
    const int batch_size, const int cell_size, const bool use_peephole) {
  const int batch_id = blockIdx.x * blockDim.x + threadIdx.x;
  const int act_id = blockIdx.y * blockDim.y + threadIdx.y;

  if (batch_id >= batch_size || act_id >= cell_size) return;

  const int gid = batch_id * cell_size * 4 + act_id;
  const int cid = batch_id * cell_size + act_id;

  const T one = static_cast<T>(1.0f);

  // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
  const T o_local = o[cid];
  const T h_grad_local = h_grad[cid];
  const T co_local = co[cid];
  const T ci_local = ci[cid];
  const T do_local = o_local * (one - o_local) * h_grad_local * co_local;
  const T i_local = i[cid];
  const T f_local = f[cid];

  do_[cid] = do_local;

  // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1]
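  //
  // Since co[t] = tanh(cs[t]), tanh'(cs[t]) is evaluated as (1 - co[t]^2);
  // the dcs[t + 1] .* f[t + 1] term from the next time step arrives through
  // 'cs_grad'.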
  T dcs_local =
      (one - co_local * co_local) * h_grad_local * o_local + cs_grad[cid];
  if (use_peephole) {
    dcs_local += do_local * wco[act_id];
  }
  dcs[cid] = dcs_local;

  // dci[t] = tanh'(ci[t]) dcs[t] i[t]
  const T dci_local = (one - ci_local * ci_local) * dcs_local * i_local;
  dci[cid] = dci_local;

  // df[t] = sigm'(f[t]) dcs[t] cs[t - 1]
  const T df_local = f_local * (one - f_local) * dcs_local * cs_prev[cid];
  df[cid] = df_local;

  // di[t] = sigm'(i[t]) dcs[t] ci[t]
  const T di_local = i_local * (one - i_local) * dcs_local * ci_local;
  di[cid] = di_local;

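  // The gate gradients are packed into 'dicfo' using the same
  // [i | ci | f | o] block layout that 'icfo' uses in the forward pass.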
  dicfo[gid + 0 * cell_size] = di_local;
  dicfo[gid + 1 * cell_size] = dci_local;
  dicfo[gid + 2 * cell_size] = df_local;
  dicfo[gid + 3 * cell_size] = do_local;

  cs_prev_grad[cid] = dcs_local * f_local;
  if (use_peephole) {
    cs_prev_grad[cid] += di_local * wci[act_id] + df_local * wcf[act_id];
  }
}

template <typename T>
void LSTMBlockCellBpropWithCUDA(
    OpKernelContext* ctx, const GPUDevice& d, typename TTypes<T>::ConstMatrix x,
    typename TTypes<T>::ConstMatrix cs_prev,
    typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
    typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
    typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
    typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,
    typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,
    typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,
    typename TTypes<T>::ConstMatrix cs_grad,
    typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
    typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
    typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
    typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
    typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
    typename TTypes<T>::Vec wco_grad, const int batch_size, const int cell_size,
    const bool use_peephole) {
  const cudaStream_t& cu_stream = GetCudaStream(ctx);

  dim3 block_dim_2d(std::min(batch_size, 8), 32);
  dim3 grid_dim_2d(Eigen::divup(batch_size, static_cast<int>(block_dim_2d.x)),
                   Eigen::divup(cell_size, static_cast<int>(block_dim_2d.y)));

  TF_CHECK_OK(CudaLaunchKernel(
      lstm_gates_bprop<T>, grid_dim_2d, block_dim_2d, 0, cu_stream,
      cs_prev.data(), h_prev.data(), w.data(), wci.data(), wcf.data(),
      wco.data(), b.data(), i.data(), cs.data(), f.data(), o.data(), ci.data(),
      co.data(), cs_grad.data(), h_grad.data(), do_.data(), dcs.data(),
      dci.data(), df.data(), di.data(), dicfo.data(), cs_prev_grad.data(),
      batch_size, cell_size, use_peephole));

  if (use_peephole) {
    Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell_size});
    Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({batch_size, 1});
    cs_prev_grad.device(d) =
        cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) +
        df * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
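    // The peephole weight gradients reduce the per-element products over the
    // batch dimension (dimension 0).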
    wci_grad.device(d) = (di * cs_prev).sum(Eigen::array<int, 1>({0}));
    wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array<int, 1>({0}));
    wco_grad.device(d) = (do_ * cs).sum(Eigen::array<int, 1>({0}));
  }
}

}  // namespace

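// DEFINE_GPU_SPECS(T) explicitly instantiates the tensor helper functors for
// type T and defines the GPU specializations of LSTMBlockCellFprop and
// LSTMBlockCellBprop in terms of the CUDA implementations above.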
#define DEFINE_GPU_SPECS(T)                                                    \
  template struct TensorZero<GPUDevice, T>;                                    \
  template struct TensorUnalignedZero<GPUDevice, T>;                           \
  template struct TensorCopy<GPUDevice, T>;                                    \
  template struct TensorCopyUnaligned<GPUDevice, T>;                           \
  template struct TensorCopyToUnaligned<GPUDevice, T>;                         \
  template struct TensorAdd<GPUDevice, T>;                                     \
  template <>                                                                  \
  void LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */>::operator()(    \
      OpKernelContext* ctx, const GPUDevice& d, const float forget_bias,       \
      const float cell_clip, bool use_peephole,                                \
      typename TTypes<T>::ConstMatrix x,                                       \
      typename TTypes<T>::ConstMatrix cs_prev,                                 \
      typename TTypes<T>::ConstMatrix h_prev,                                  \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,           \
      typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,             \
      typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,              \
      typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,            \
      typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h) {         \
    LSTMBlockCellFpropWithCUDA<T>(ctx, d, forget_bias, cell_clip,              \
                                  use_peephole, x, cs_prev, h_prev, w, wci,    \
                                  wcf, wco, b, xh, i, cs, f, o, ci, co, icfo,  \
                                  h, batch_size_, cell_size_, input_size_);    \
  }                                                                            \
  template <>                                                                  \
  void LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */>::operator()(    \
      OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,             \
      typename TTypes<T>::ConstMatrix x,                                       \
      typename TTypes<T>::ConstMatrix cs_prev,                                 \
      typename TTypes<T>::ConstMatrix h_prev,                                  \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i,       \
      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,   \
      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,   \
      typename TTypes<T>::ConstMatrix co,                                      \
      typename TTypes<T>::ConstMatrix cs_grad,                                 \
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,  \
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,          \
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,            \
      typename TTypes<T>::Matrix dicfo,                                        \
      typename TTypes<T>::Matrix cs_prev_grad,                                 \
      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,      \
      typename TTypes<T>::Vec wco_grad) {                                      \
    LSTMBlockCellBpropWithCUDA<T>(                                             \
        ctx, d, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co,  \
        cs_grad, h_grad, do_, dcs, dci, df, di, dicfo, cs_prev_grad, wci_grad, \
        wcf_grad, wco_grad, batch_size_, cell_size_, use_peephole);            \
  }                                                                            \
  template struct LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */>;     \
  template struct LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */>;     \
  template struct BlockLSTMBprop<GPUDevice, T, true /* USE_CUBLAS */>;

DEFINE_GPU_SPECS(float);
DEFINE_GPU_SPECS(Eigen::half);
// DEFINE_GPU_SPECS(double);
#undef DEFINE_GPU_SPECS

}  // end namespace functor
}  // end namespace tensorflow
#endif  // GOOGLE_CUDA