/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/contrib/rnn/kernels/lstm_ops.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/eigen_activations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {
namespace functor {

typedef Eigen::GpuDevice GPUDevice;

namespace {

struct FloatToHalf {
  __host__ __device__ EIGEN_STRONG_INLINE Eigen::half operator()(
      const float& x) const {
    return Eigen::half_impl::float_to_half_rtne(x);
  }
};

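// strict_cast<U>(t) converts t to U only for the conversions explicitly
// allowed here: the identity case (U == T) defined below and float ->
// Eigen::half via the specialization that follows. Any other combination is
// declared but never defined, so it fails to build instead of silently
// converting.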
template <typename U, typename T>
__host__ __device__ EIGEN_STRONG_INLINE
    typename std::enable_if<!std::is_same<T, U>::value, U>::type
    strict_cast(T t);

template <typename U, typename T>
__host__ __device__ EIGEN_STRONG_INLINE
    typename std::enable_if<std::is_same<T, U>::value, U>::type
    strict_cast(T t) {
  return t;
}

template <>
__host__ __device__ EIGEN_STRONG_INLINE Eigen::half
strict_cast<Eigen::half, float>(float t) {
  return FloatToHalf()(t);
}

}  // namespace

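// GPU specializations of the tensor zero-fill functors. strict_cast keeps the
// zero constant in T itself (e.g. Eigen::half) rather than relying on an
// implicit conversion from float.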
template <typename T>
struct TensorZero<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat t) {
    t.device(d) = t.constant(strict_cast<T>(0.f));
  }
};

template <typename T>
struct TensorUnalignedZero<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::UnalignedFlat t) {
    t.device(d) = t.constant(strict_cast<T>(0.f));
  }
};

namespace {

// Adds bias, applies non-linearities and gates.
//
// Launch with a 2D setup such that there is one thread per (example,
// activation), with 'x' governing the example index and 'y' governing the
// activation index.
//
// Launch with 2D blocks of (min(batch_size, 8), 32) threads.
//
// TODO(b/67600500): Try making 'use_peephole' a template parameter.
template <typename T, bool use_peephole>
__global__ void lstm_gates(const T* icfo, const T* b, const T* cs_prev,
                           const T* wci, const T* wcf, const T* wco, T* o,
                           T* h, T* ci, T* cs, T* co, T* i, T* f,
                           const float forget_bias, const float cell_clip,
                           const int batch_size, const int cell_size) {
  const int batch_id = blockIdx.x * blockDim.x + threadIdx.x;
  const int act_id = blockIdx.y * blockDim.y + threadIdx.y;

  T forget_bias_t = strict_cast<T>(forget_bias);
  T cell_clip_t = strict_cast<T>(cell_clip);

  if (batch_id >= batch_size || act_id >= cell_size) return;

  // The following code assumes the input arrays are of the following
  // shapes and interpretations.
  //
  // 1) 'icfo' is a matrix such that,
  //
  //    cell_size  cell_size  cell_size  cell_size
  //  +----------+----------+----------+----------+
  //  |          |          |          |          |
  //  |    i     |    c     |    f     |    o     |  batch_size
  //  |          |          |          |          |
  //  +----------+----------+----------+----------+
  //
  //  'gid' is the index assigned to this thread for 'icfo' in the 'i'
  //  submatrix.
  //
  // 2) 'b' is a vector such that,
  //
  //    cell_size  cell_size  cell_size  cell_size
  //  +----------+----------+----------+----------+
  //  |    i     |    c     |    f     |    o     |  1
  //  +----------+----------+----------+----------+
  //
  //  'act_id' is the index assigned to this thread for 'b' in the 'i'
  //  subvector.
  //
  // 3) 'wc{i,f,o}' are vectors such that,
  //
  //    cell_size
  //  +----------+
  //  |    i     |  1
  //  +----------+
  //
  //  'act_id' is the index assigned to this thread.
  //
  // 4) All other matrices have the form,
  //
  //    cell_size
  //  +----------+
  //  |          |
  //  |    i     |  batch_size
  //  |          |
  //  +----------+
  //
  //  'cid' is the index assigned to this thread.
  //
  const int gid = batch_id * cell_size * 4 + act_id;
  const int cid = batch_id * cell_size + act_id;
  Eigen::internal::scalar_logistic_op<T> sigmoid_op;
  Eigen::internal::scalar_tanh_op<T> tanh_op;
  Eigen::scalar_clip_op<T> clip_op;
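  // Gate and state updates computed below (peephole terms in brackets,
  // used only when 'use_peephole' is true):
  //   i  = sigmoid(icfo_i + b_i [+ cs_prev * wci])
  //   ci = tanh(icfo_c + b_c)
  //   f  = sigmoid(icfo_f + b_f + forget_bias [+ cs_prev * wcf])
  //   cs = clip(i * ci + f * cs_prev, cell_clip)  (clip only if cell_clip > 0)
  //   co = tanh(cs)
  //   o  = sigmoid(icfo_o + b_o [+ cs * wco])
  //   h  = o * co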

  T i_local;
  if (use_peephole) {
    i_local = sigmoid_op(icfo[0 * cell_size + gid] + b[0 * cell_size + act_id] +
                         cs_prev[cid] * wci[act_id]);
  } else {
    i_local = sigmoid_op(icfo[0 * cell_size + gid] + b[0 * cell_size + act_id]);
  }
  i[cid] = i_local;

  const T ci_local =
      tanh_op(icfo[1 * cell_size + gid] + b[1 * cell_size + act_id]);
  ci[cid] = ci_local;

  T f_local;
  if (use_peephole) {
    f_local = sigmoid_op(icfo[2 * cell_size + gid] + b[2 * cell_size + act_id] +
                         forget_bias_t + cs_prev[cid] * wcf[act_id]);
  } else {
    f_local = sigmoid_op(icfo[2 * cell_size + gid] + b[2 * cell_size + act_id] +
                         forget_bias_t);
  }
  f[cid] = f_local;

  T cs_local = i_local * ci_local + f_local * cs_prev[cid];
  if (cell_clip > 0.0f) {
    cs_local = clip_op(cs_local, cell_clip_t);
  }
  cs[cid] = cs_local;

  const T co_local = tanh_op(cs_local);
  co[cid] = co_local;

  T o_local;
  if (use_peephole) {
    o_local = sigmoid_op(icfo[3 * cell_size + gid] + b[3 * cell_size + act_id] +
                         cs_local * wco[act_id]);
  } else {
    o_local = sigmoid_op(icfo[3 * cell_size + gid] + b[3 * cell_size + act_id]);
  }
  o[cid] = o_local;

  h[cid] = o_local * co_local;
}

// Concatenate 'x' and 'h' and copy their contents into 'xh'.
template <typename T>
__global__ void concat_xh(T* xh, const T* x, const T* h_prev,
                          const int batch_size, const int cell_size,
                          const int input_size) {
  // Assumes 'x', 'h', and 'xh' are of the following shape,
  //
  //    input_size  cell_size
  //  +----------+----------+
  //  |          |          |
  //  |    x     |    h     |  batch_size
  //  |          |          |
  //  +----------+----------+
  //
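  // For example (illustrative sizes only): with input_size = 128 and
  // cell_size = 256, width = 384, so gid = 500 maps to output_row = 1 and
  // output_col = 116, which is < input_size and therefore reads
  // x[1 * 128 + 116].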
  const int gid = blockDim.x * blockIdx.x + threadIdx.x;
  const int width = input_size + cell_size;

  if (gid >= width * batch_size) return;

  const int output_row = gid / width;
  const int output_col = gid % width;

  if (output_col < input_size) {  // x
    xh[gid] = x[output_row * input_size + output_col];
  } else {  // h
    xh[gid] = h_prev[output_row * cell_size + output_col - input_size];
  }
}

template <typename T>
void LSTMBlockCellFpropWithCUDA(
    OpKernelContext* ctx, const GPUDevice& d, const float forget_bias,
    const float cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x,
    typename TTypes<T>::ConstMatrix cs_prev,
    typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
    typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
    typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
    typename TTypes<T>::Matrix xh, typename TTypes<T>::Matrix i,
    typename TTypes<T>::Matrix cs, typename TTypes<T>::Matrix f,
    typename TTypes<T>::Matrix o, typename TTypes<T>::Matrix ci,
    typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
    typename TTypes<T>::Matrix h, int batch_size, int cell_size,
    int input_size) {
  const cudaStream_t& cu_stream = GetCudaStream(ctx);

  // Concatenate xh = [x, h].
  //
  // Each block is assigned 128 threads. Good values are in [128, 1024] and are
  // divisible by 32 (the size of a warp). The number of blocks is such that
  // there are enough to process all the data.
  const int block_dim = 128;
  const int grid_dim =
      Eigen::divup(batch_size * (cell_size + input_size), block_dim);
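  // For example (illustrative sizes only): batch_size = 32, input_size = 128,
  // cell_size = 256 gives 32 * 384 = 12288 elements to copy and
  // divup(12288, 128) = 96 blocks.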
  TF_CHECK_OK(CudaLaunchKernel(concat_xh<T>, grid_dim, block_dim, 0, cu_stream,
                               xh.data(), x.data(), h_prev.data(), batch_size,
                               cell_size, input_size));

  // icfo = xh * w
  typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
  TensorBlasGemm<GPUDevice, T, true /* USE_CUBLAS */>::compute(
      ctx, d, false, false, typename gemm_compute_type<T>::type(1.f), const_xh,
      w, typename gemm_compute_type<T>::type(0.f), icfo);
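  // 'icfo' now holds the pre-activation gate inputs, laid out as one
  // [batch_size, cell_size] block each for i, c, f and o (see the diagram in
  // lstm_gates above).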

  // Add bias, apply non-linearities and gating.
  //
  // Use 2D blocks. The number of threads per block is equal to x * y, where
  // x = min(batch_size, 8) and y = 32. See above for guidance on number of
  // threads.
  dim3 block_dim_2d(std::min(batch_size, 8), 32);
  dim3 grid_dim_2d(Eigen::divup(batch_size, static_cast<int>(block_dim_2d.x)),
                   Eigen::divup(cell_size, static_cast<int>(block_dim_2d.y)));
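  // For example (illustrative sizes only): batch_size = 32 and cell_size = 256
  // give blocks of (8, 32) = 256 threads and a grid of (4, 8) blocks.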

  if (use_peephole) {
    TF_CHECK_OK(CudaLaunchKernel(
        lstm_gates<T, true>, grid_dim_2d, block_dim_2d, 0, cu_stream,
        icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(),
        wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(),
        i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size));
  } else {
    TF_CHECK_OK(CudaLaunchKernel(
        lstm_gates<T, false>, grid_dim_2d, block_dim_2d, 0, cu_stream,
        icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(),
        wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(),
        i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size));
  }
}

template <typename T>
__global__ void lstm_gates_bprop(
    const T* cs_prev,      // [batch_size, cell_size]
    const T* h_prev,       // [batch_size, cell_size]
    const T* w,            // [input_size + cell_size, 4 * cell_size]
    const T* wci,          // [cell_size]
    const T* wcf,          // [cell_size]
    const T* wco,          // [cell_size]
    const T* b,            // [4 * cell_size]
    const T* i,            // [batch_size, cell_size]
    const T* cs,           // [batch_size, cell_size]
    const T* f,            // [batch_size, cell_size]
    const T* o,            // [batch_size, cell_size]
    const T* ci,           // [batch_size, cell_size]
    const T* co,           // [batch_size, cell_size]
    const T* cs_grad,      // [batch_size, cell_size]
    const T* h_grad,       // [batch_size, cell_size]
    T* do_,                // [batch_size, cell_size]
    T* dcs,                // [batch_size, cell_size]
    T* dci,                // [batch_size, cell_size]
    T* df,                 // [batch_size, cell_size]
    T* di,                 // [batch_size, cell_size]
    T* dicfo,              // [batch_size, 4 * cell_size]
    T* cs_prev_grad,       // [batch_size, cell_size]
    const int batch_size, const int cell_size, const bool use_peephole) {
  const int batch_id = blockIdx.x * blockDim.x + threadIdx.x;
  const int act_id = blockIdx.y * blockDim.y + threadIdx.y;

  if (batch_id >= batch_size || act_id >= cell_size) return;

  const int gid = batch_id * cell_size * 4 + act_id;
  const int cid = batch_id * cell_size + act_id;
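  // 'gid' and 'cid' use the same layout as in the forward kernel: 'gid'
  // indexes the [batch_size, 4 * cell_size] gate matrices and 'cid' the
  // [batch_size, cell_size] per-cell matrices.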

  const T one = static_cast<T>(1.0f);

  // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
  const T o_local = o[cid];
  const T h_grad_local = h_grad[cid];
  const T co_local = co[cid];
  const T ci_local = ci[cid];
  const T do_local = o_local * (one - o_local) * h_grad_local * co_local;
  const T i_local = i[cid];
  const T f_local = f[cid];

  do_[cid] = do_local;

  // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1]
  T dcs_local =
      (one - co_local * co_local) * h_grad_local * o_local + cs_grad[cid];
  if (use_peephole) {
    dcs_local += do_local * wco[act_id];
  }
  dcs[cid] = dcs_local;

  // dci[t] = tanh'(ci[t]) .* dcs[t] .* i[t]
  const T dci_local = (one - ci_local * ci_local) * dcs_local * i_local;
  dci[cid] = dci_local;

  // df[t] = sigm'(f[t]) .* dcs[t] .* cs[t - 1]
  const T df_local = f_local * (one - f_local) * dcs_local * cs_prev[cid];
  df[cid] = df_local;

  // di[t] = sigm'(i[t]) .* dcs[t] .* ci[t]
  const T di_local = i_local * (one - i_local) * dcs_local * ci_local;
  di[cid] = di_local;

  dicfo[gid + 0 * cell_size] = di_local;
  dicfo[gid + 1 * cell_size] = dci_local;
  dicfo[gid + 2 * cell_size] = df_local;
  dicfo[gid + 3 * cell_size] = do_local;
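  // The gate-gradient blocks above are written in the same i | c | f | o
  // column order that 'icfo' uses in the forward pass.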

  cs_prev_grad[cid] = dcs_local * f_local;
  if (use_peephole) {
    cs_prev_grad[cid] += di_local * wci[act_id] + df_local * wcf[act_id];
  }
}

template <typename T>
void LSTMBlockCellBpropWithCUDA(
    OpKernelContext* ctx, const GPUDevice& d, typename TTypes<T>::ConstMatrix x,
    typename TTypes<T>::ConstMatrix cs_prev,
    typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
    typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
    typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
    typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,
    typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,
    typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,
    typename TTypes<T>::ConstMatrix cs_grad,
    typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
    typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
    typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
    typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
    typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
    typename TTypes<T>::Vec wco_grad, const int batch_size,
    const int cell_size, const bool use_peephole) {
  const cudaStream_t& cu_stream = GetCudaStream(ctx);

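  // Same 2D launch configuration as the forward pass: one thread per
  // (example, cell) pair.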
  dim3 block_dim_2d(std::min(batch_size, 8), 32);
  dim3 grid_dim_2d(Eigen::divup(batch_size, static_cast<int>(block_dim_2d.x)),
                   Eigen::divup(cell_size, static_cast<int>(block_dim_2d.y)));

  TF_CHECK_OK(CudaLaunchKernel(
      lstm_gates_bprop<T>, grid_dim_2d, block_dim_2d, 0, cu_stream,
      cs_prev.data(), h_prev.data(), w.data(), wci.data(), wcf.data(),
      wco.data(), b.data(), i.data(), cs.data(), f.data(), o.data(), ci.data(),
      co.data(), cs_grad.data(), h_grad.data(), do_.data(), dcs.data(),
      dci.data(), df.data(), di.data(), dicfo.data(), cs_prev_grad.data(),
      batch_size, cell_size, use_peephole));

  if (use_peephole) {
    // The kernel above already folds the peephole contributions
    // (di * wci + df * wcf) into cs_prev_grad; adding them again here would
    // double-count them. Only the peephole weight gradients remain to be
    // reduced over the batch dimension.
    wci_grad.device(d) = (di * cs_prev).sum(Eigen::array<int, 1>({0}));
    wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array<int, 1>({0}));
    wco_grad.device(d) = (do_ * cs).sum(Eigen::array<int, 1>({0}));
  }
}

}  // namespace

#define DEFINE_GPU_SPECS(T)                                                    \
  template struct TensorZero<GPUDevice, T>;                                    \
  template struct TensorUnalignedZero<GPUDevice, T>;                           \
  template struct TensorCopy<GPUDevice, T>;                                    \
  template struct TensorCopyUnaligned<GPUDevice, T>;                           \
  template struct TensorCopyToUnaligned<GPUDevice, T>;                         \
  template struct TensorAdd<GPUDevice, T>;                                     \
  template <>                                                                  \
  void LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */>::operator()(    \
      OpKernelContext* ctx, const GPUDevice& d, const float forget_bias,       \
      const float cell_clip, bool use_peephole,                                \
      typename TTypes<T>::ConstMatrix x,                                       \
      typename TTypes<T>::ConstMatrix cs_prev,                                 \
      typename TTypes<T>::ConstMatrix h_prev,                                  \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,           \
      typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,             \
      typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,              \
      typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,            \
      typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h) {         \
    LSTMBlockCellFpropWithCUDA<T>(ctx, d, forget_bias, cell_clip,              \
                                  use_peephole, x, cs_prev, h_prev, w, wci,    \
                                  wcf, wco, b, xh, i, cs, f, o, ci, co, icfo,  \
                                  h, batch_size_, cell_size_, input_size_);    \
  }                                                                            \
  template <>                                                                  \
  void LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */>::operator()(    \
      OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,             \
      typename TTypes<T>::ConstMatrix x,                                       \
      typename TTypes<T>::ConstMatrix cs_prev,                                 \
      typename TTypes<T>::ConstMatrix h_prev,                                  \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i,       \
      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,   \
      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,   \
      typename TTypes<T>::ConstMatrix co,                                      \
      typename TTypes<T>::ConstMatrix cs_grad,                                 \
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,  \
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,          \
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,            \
      typename TTypes<T>::Matrix dicfo,                                        \
      typename TTypes<T>::Matrix cs_prev_grad,                                 \
      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,      \
      typename TTypes<T>::Vec wco_grad) {                                      \
    LSTMBlockCellBpropWithCUDA<T>(                                             \
        ctx, d, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co,  \
        cs_grad, h_grad, do_, dcs, dci, df, di, dicfo, cs_prev_grad, wci_grad, \
        wcf_grad, wco_grad, batch_size_, cell_size_, use_peephole);            \
  }                                                                            \
  template struct LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */>;     \
  template struct LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */>;     \
  template struct BlockLSTMBprop<GPUDevice, T, true /* USE_CUBLAS */>;

DEFINE_GPU_SPECS(float);
DEFINE_GPU_SPECS(Eigen::half);
// DEFINE_GPU_SPECS(double);
#undef DEFINE_GPU_SPECS

}  // end namespace functor
}  // end namespace tensorflow
#endif  // GOOGLE_CUDA