1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define EIGEN_USE_THREADS
17
18 #if GOOGLE_CUDA
19 #define EIGEN_USE_GPU
20 #endif // GOOGLE_CUDA
21
22 #include "tensorflow/contrib/rnn/kernels/lstm_ops.h"
23
24 #include <memory>
25 #include <vector>
26
27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/register_types.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/tensor_types.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/macros.h"
36
37 namespace tensorflow {
38
39 typedef Eigen::ThreadPoolDevice CPUDevice;
40 typedef Eigen::GpuDevice GPUDevice;
41
42 namespace functor {
43
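// For reference, LSTMBlockCellFpropWithEigen below implements the block LSTM
// cell forward pass (in the notation of the inline comments, with .* denoting
// elementwise products; the peephole terms apply only when use_peephole is
// set):
//   xh = [x, h_prev]
//   [i', ci', f', o'] = xh * w + b
//   i  = sigmoid(i'  + wci .* cs_prev)
//   ci = tanh(ci')
//   f  = sigmoid(f'  + forget_bias + wcf .* cs_prev)
//   cs = i .* ci + f .* cs_prev     (clipped with cell_clip when cell_clip > 0)
//   co = tanh(cs)
//   o  = sigmoid(o'  + wco .* cs)
//   h  = o .* co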
44 template <typename T>
45 void LSTMBlockCellFpropWithEigen(
46 const LSTMBlockCell& cell, OpKernelContext* ctx, const CPUDevice& d,
47 const float forget_bias, const float cell_clip, bool use_peephole,
48 typename TTypes<T>::ConstMatrix x, typename TTypes<T>::ConstMatrix cs_prev,
49 typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
50 typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
51 typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
52 typename TTypes<T>::Matrix xh, typename TTypes<T>::Matrix i,
53 typename TTypes<T>::Matrix cs, typename TTypes<T>::Matrix f,
54 typename TTypes<T>::Matrix o, typename TTypes<T>::Matrix ci,
55 typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
56 typename TTypes<T>::Matrix h) {
57 // Concat xh = [x, h].
58 xh.slice(cell.xh_x_offsets(), cell.xh_x_extents()).device(d) = x;
59 xh.slice(cell.xh_h_offsets(), cell.xh_h_extents()).device(d) = h_prev;
60
61 // states1 = xh * w + b
62 typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
63 TensorBlasGemm<CPUDevice, T, false /* USE_CUBLAS */>::compute(
64 ctx, d, false, false, typename gemm_compute_type<T>::type(1.f), const_xh,
65 w, typename gemm_compute_type<T>::type(0.f), icfo);
66 Eigen::array<Eigen::DenseIndex, 2> b_shape({1, b.dimensions()[0]});
67 Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({cell.batch_size(), 1});
68 icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape);
69
70 Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell.cell_size()});
71 Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({cell.batch_size(), 1});
72
73 // Input gate.
74 if (use_peephole) {
75 auto i_peep = cs_prev * wci.reshape(p_shape).broadcast(p_broadcast_shape);
76 i.device(d) =
77 (icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()) + i_peep)
78 .sigmoid();
79 } else {
80 i.device(d) =
81 icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).sigmoid();
82 }
83
84 // Cell input.
85 ci.device(d) = icfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).tanh();
86
87 // Forget gate (w/ bias).
88 if (use_peephole) {
89 auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
90 f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) +
91 f.constant(T(forget_bias)) + f_peep)
92 .sigmoid();
93 } else {
94 f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) +
95 f.constant(T(forget_bias)))
96 .sigmoid();
97 }
98
99 // cs = ci .* i + f .* cs_prev
100 cs.device(d) = i * ci + f * cs_prev;
101
102 if (cell_clip > 0.0f) {
103 cs.device(d) =
104 cs.binaryExpr(cs.constant(T(cell_clip)), Eigen::scalar_clip_op<T>());
105 }
106
107 // co = tanh(cs)
108 co.device(d) = cs.tanh();
109
110 // Output gate.
111 if (use_peephole) {
112 auto o_peep = cs * wco.reshape(p_shape).broadcast(p_broadcast_shape);
113 o.device(d) =
114 (icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()) + o_peep)
115 .sigmoid();
116 } else {
117 o.device(d) =
118 icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).sigmoid();
119 }
120
121 // h = o .* co
122 h.device(d) = o * co;
123 }
124
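// For reference, LSTMBlockCellBpropWithEigen below computes the cell-level
// gradients (peephole terms again only when use_peephole is set):
//   do  = o .* (1 - o) .* h_grad .* co
//   dcs = (1 - co .* co) .* h_grad .* o + cs_grad      [+ do .* wco]
//   dci = (1 - ci .* ci) .* dcs .* i
//   df  = f .* (1 - f) .* dcs .* cs_prev
//   di  = i .* (1 - i) .* dcs .* ci
//   dicfo        = [di, dci, df, do]
//   cs_prev_grad = dcs .* f             [+ di .* wci + df .* wcf]
//   wci_grad = sum over batch of di .* cs_prev, and similarly
//   wcf_grad from df .* cs_prev and wco_grad from do .* cs.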
125 template <typename Device, typename T, bool USE_CUBLAS>
126 void LSTMBlockCellBpropWithEigen(
127 const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d,
128 bool use_peephole, typename TTypes<T>::ConstMatrix x,
129 typename TTypes<T>::ConstMatrix cs_prev,
130 typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
131 typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
132 typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
133 typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,
134 typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,
135 typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,
136 typename TTypes<T>::ConstMatrix cs_grad,
137 typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
138 typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
139 typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
140 typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
141 typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
142 typename TTypes<T>::Vec wco_grad) {
143 // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
144 do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co;
145
146 // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1]
147 dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad;
148
149 Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell.cell_size()});
150 Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({cell.batch_size(), 1});
151 if (use_peephole) {
152 dcs.device(d) =
153 dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape);
154 }
155
156 // dci[t] = tanh'(ci[t]) dcs[t] i[t]
157 dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i;
158
159 // df[t] = sigm'(f[t]) dcs[t] cs[t - 1]
160 df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev;
161
162 // di[t] = sigm'(i[t]) dcs[t] ci[t]
163 di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;
164
165 dicfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).device(d) = di;
166 dicfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).device(d) = dci;
167 dicfo.slice(cell.icfo_f_offsets(), cell.cell_extents()).device(d) = df;
168 dicfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).device(d) = do_;
169
170 cs_prev_grad.device(d) = dcs * f;
171 if (use_peephole) {
172 cs_prev_grad.device(d) =
173 cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) +
174 df * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
175 wci_grad.device(d) = (di * cs_prev).sum(Eigen::array<int, 1>({0}));
176 wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array<int, 1>({0}));
177 wco_grad.device(d) = (do_ * cs).sum(Eigen::array<int, 1>({0}));
178 }
179 }
180
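// Specialize the LSTMBlockCellFprop/LSTMBlockCellBprop functors (declared in
// lstm_ops.h) for the CPU device by forwarding to the Eigen implementations
// above, and explicitly instantiate them for the supported types.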
181 #define DEFINE_CPU_SPECS(T) \
182 template <> \
183 void LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()( \
184 OpKernelContext* ctx, const CPUDevice& d, const float forget_bias, \
185 const float cell_clip, bool use_peephole, \
186 typename TTypes<T>::ConstMatrix x, \
187 typename TTypes<T>::ConstMatrix cs_prev, \
188 typename TTypes<T>::ConstMatrix h_prev, \
189 typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
190 typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
191 typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
192 typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
193 typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
194 typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
195 typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h) { \
196 LSTMBlockCellFpropWithEigen<T>( \
197 *this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, \
198 h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, icfo, h); \
199 } \
200 template <> \
201 void LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()( \
202 OpKernelContext* ctx, const CPUDevice& d, bool use_peephole, \
203 typename TTypes<T>::ConstMatrix x, \
204 typename TTypes<T>::ConstMatrix cs_prev, \
205 typename TTypes<T>::ConstMatrix h_prev, \
206 typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
207 typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
208 typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \
209 typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \
210 typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \
211 typename TTypes<T>::ConstMatrix co, \
212 typename TTypes<T>::ConstMatrix cs_grad, \
213 typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
214 typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
215 typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
216 typename TTypes<T>::Matrix dicfo, \
217 typename TTypes<T>::Matrix cs_prev_grad, \
218 typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
219 typename TTypes<T>::Vec wco_grad) { \
220 LSTMBlockCellBpropWithEigen<CPUDevice, T, false /* USE_CUBLAS */>( \
221 *this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \
222 i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dicfo, \
223 cs_prev_grad, wci_grad, wcf_grad, wco_grad); \
224 } \
225 template struct LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>; \
226 template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>;
227
228 DEFINE_CPU_SPECS(float);
229 DEFINE_CPU_SPECS(Eigen::half);
230 #undef DEFINE_CPU_SPECS
231
232 } // namespace functor
233
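// Kernel computing a single LSTM cell step. Expected shapes (checked below):
// x is [batch_size, input_size], cs_prev and h_prev are
// [batch_size, cell_size], w is [input_size + cell_size, cell_size * 4],
// b is [cell_size * 4], and every per-gate output is [batch_size, cell_size].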
234 template <typename Device, typename T, bool USE_CUBLAS>
235 class LSTMBlockCellOp : public OpKernel {
236 public:
237 explicit LSTMBlockCellOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
238 OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
239 OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
240 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
241 }
242
243 void Compute(OpKernelContext* ctx) override {
244 const Tensor* x_tensor = nullptr;
245 OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));
246
247 const Tensor* cs_prev_tensor = nullptr;
248 OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
249
250 const Tensor* h_prev_tensor = nullptr;
251 OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
252
253 const Tensor* w_tensor = nullptr;
254 OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
255
256 const Tensor* wci_tensor = nullptr;
257 OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
258
259 const Tensor* wcf_tensor = nullptr;
260 OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
261
262 const Tensor* wco_tensor = nullptr;
263 OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
264
265 const Tensor* b_tensor = nullptr;
266 OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
267
268 const int64 batch_size = x_tensor->dim_size(0);
269 const int64 input_size = x_tensor->dim_size(1);
270 const int64 cell_size = cs_prev_tensor->dim_size(1);
271
272 // Sanity checks for our input shapes.
273 OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
274 errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
275 cs_prev_tensor->dim_size(0), " vs. ",
276 batch_size));
277 OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size,
278 errors::InvalidArgument("cs_prev.dims(1) != cell_size: ",
279 cs_prev_tensor->dim_size(1), " vs. ",
280 cell_size));
281
282 OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
283 errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
284 h_prev_tensor->dim_size(0), " vs. ",
285 batch_size));
286 OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
287 errors::InvalidArgument(
288 "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
289 " vs. ", cell_size));
290
291 OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
292 errors::InvalidArgument(
293 "w.dim_size(0) != input_size + cell_size: ",
294 w_tensor->dim_size(0), " vs. ", input_size + cell_size));
295 OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
296 errors::InvalidArgument(
297 "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
298 " vs. ", cell_size * 4));
299
300 OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
301 errors::InvalidArgument(
302 "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
303 " vs. ", cell_size * 4));
304
305 // Allocate our output tensors.
306 Tensor* i_tensor = nullptr;
307 OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
308 {"h_prev"}, "i",
309 TensorShape({batch_size, cell_size}), &i_tensor));
310
311 Tensor* cs_tensor = nullptr;
312 OP_REQUIRES_OK(
313 ctx, ctx->allocate_output("cs", TensorShape({batch_size, cell_size}),
314 &cs_tensor));
315
316 Tensor* f_tensor = nullptr;
317 OP_REQUIRES_OK(
318 ctx, ctx->allocate_output("f", TensorShape({batch_size, cell_size}),
319 &f_tensor));
320
321 Tensor* o_tensor = nullptr;
322 OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
323 {"cs_prev"}, "o",
324 TensorShape({batch_size, cell_size}), &o_tensor));
325
326 Tensor* ci_tensor = nullptr;
327 OP_REQUIRES_OK(
328 ctx, ctx->allocate_output("ci", TensorShape({batch_size, cell_size}),
329 &ci_tensor));
330
331 Tensor* co_tensor = nullptr;
332 OP_REQUIRES_OK(
333 ctx, ctx->allocate_output("co", TensorShape({batch_size, cell_size}),
334 &co_tensor));
335
336 Tensor* h_tensor = nullptr;
337 OP_REQUIRES_OK(
338 ctx, ctx->allocate_output("h", TensorShape({batch_size, cell_size}),
339 &h_tensor));
340
341 // Allocate our temp tensors.
342 Tensor xh_tensor;
343 OP_REQUIRES_OK(ctx, ctx->allocate_temp(
344 DataTypeToEnum<T>::v(),
345 TensorShape({batch_size, input_size + cell_size}),
346 &xh_tensor));
347
348 Tensor icfo_tensor;
349 OP_REQUIRES_OK(ctx,
350 ctx->allocate_temp(DataTypeToEnum<T>::v(),
351 TensorShape({batch_size, cell_size * 4}),
352 &icfo_tensor));
353
354 const Device& device = ctx->eigen_device<Device>();
355
356 functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
357 cell_size)(
358 ctx, device, forget_bias_, cell_clip_, use_peephole_,
359 x_tensor->matrix<T>(), cs_prev_tensor->matrix<T>(),
360 h_prev_tensor->matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(),
361 wcf_tensor->vec<T>(), wco_tensor->vec<T>(), b_tensor->vec<T>(),
362 xh_tensor.matrix<T>(), i_tensor->matrix<T>(), cs_tensor->matrix<T>(),
363 f_tensor->matrix<T>(), o_tensor->matrix<T>(), ci_tensor->matrix<T>(),
364 co_tensor->matrix<T>(), icfo_tensor.matrix<T>(), h_tensor->matrix<T>());
365 }
366
367 private:
368 float forget_bias_;
369 float cell_clip_;
370 bool use_peephole_;
371 };
372
373 #define REGISTER_KERNEL(T) \
374 REGISTER_KERNEL_BUILDER( \
375 Name("LSTMBlockCell").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
376 LSTMBlockCellOp<CPUDevice, T, false>);
377 REGISTER_KERNEL(float);
378 REGISTER_KERNEL(Eigen::half);
379 #undef REGISTER_KERNEL
380
381 #if GOOGLE_CUDA
382 namespace functor {
383 #define DECLARE_GPU_SPEC(T) \
384 template <> \
385 void LSTMBlockCellFprop<GPUDevice, T, true>::operator()( \
386 OpKernelContext* ctx, const GPUDevice& d, const float forget_bias, \
387 const float cell_clip, bool use_peephole, \
388 typename TTypes<T>::ConstMatrix x, \
389 typename TTypes<T>::ConstMatrix cs_prev, \
390 typename TTypes<T>::ConstMatrix h_prev, \
391 typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
392 typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
393 typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
394 typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
395 typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
396 typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
397 typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h); \
398 \
399 extern template struct LSTMBlockCellFprop<GPUDevice, T, true>;
400
401 DECLARE_GPU_SPEC(float);
402 DECLARE_GPU_SPEC(Eigen::half);
403 #undef DECLARE_GPU_SPEC
404 } // end namespace functor
405
406 #define REGISTER_GPU_KERNEL(T) \
407 REGISTER_KERNEL_BUILDER( \
408 Name("LSTMBlockCell").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
409 LSTMBlockCellOp<GPUDevice, T, true>);
410
411 REGISTER_GPU_KERNEL(float);
412 REGISTER_GPU_KERNEL(Eigen::half);
413 // REGISTER_GPU_KERNEL(double);
414 #undef REGISTER_GPU_KERNEL
415 #endif // GOOGLE_CUDA
416
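// Gradient kernel for a single LSTMBlockCell step. Given the saved forward
// activations (i, cs, f, o, ci, co) and the incoming cs_grad/h_grad, it
// produces cs_prev_grad, the packed gate gradients dicfo
// ([batch_size, cell_size * 4]), and the peephole weight gradients
// wci_grad/wcf_grad/wco_grad (zero-initialized; filled only with
// use_peephole).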
417 template <typename Device, typename T, bool USE_CUBLAS>
418 class LSTMBlockCellGradOp : public OpKernel {
419 public:
420 explicit LSTMBlockCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
421 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
422 }
423
424 void Compute(OpKernelContext* ctx) override {
425 const Tensor* x_tensor = nullptr;
426 OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));
427
428 const Tensor* cs_prev_tensor = nullptr;
429 OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
430
431 const Tensor* h_prev_tensor = nullptr;
432 OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
433
434 const Tensor* w_tensor = nullptr;
435 OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
436
437 const Tensor* wci_tensor = nullptr;
438 OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
439
440 const Tensor* wcf_tensor = nullptr;
441 OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
442
443 const Tensor* wco_tensor = nullptr;
444 OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
445
446 const Tensor* b_tensor = nullptr;
447 OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
448
449 const Tensor* i_tensor = nullptr;
450 OP_REQUIRES_OK(ctx, ctx->input("i", &i_tensor));
451
452 const Tensor* cs_tensor = nullptr;
453 OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_tensor));
454
455 const Tensor* f_tensor = nullptr;
456 OP_REQUIRES_OK(ctx, ctx->input("f", &f_tensor));
457
458 const Tensor* o_tensor = nullptr;
459 OP_REQUIRES_OK(ctx, ctx->input("o", &o_tensor));
460
461 const Tensor* ci_tensor = nullptr;
462 OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_tensor));
463
464 const Tensor* co_tensor = nullptr;
465 OP_REQUIRES_OK(ctx, ctx->input("co", &co_tensor));
466
467 const Tensor* cs_grad_tensor = nullptr;
468 OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad_tensor));
469
470 const Tensor* h_grad_tensor = nullptr;
471 OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad_tensor));
472
473 const int64 batch_size = x_tensor->dim_size(0);
474 const int64 input_size = x_tensor->dim_size(1);
475 const int64 cell_size = cs_prev_tensor->dim_size(1);
476
477 // Sanity checks for our input shapes.
478 OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
479 errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
480 cs_prev_tensor->dim_size(0), " vs. ",
481 batch_size));
482 OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size,
483 errors::InvalidArgument("cs_prev.dims(1) != cell_size: ",
484 cs_prev_tensor->dim_size(1), " vs. ",
485 cell_size));
486
487 OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
488 errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
489 h_prev_tensor->dim_size(0), " vs. ",
490 batch_size));
491 OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
492 errors::InvalidArgument(
493 "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
494 " vs. ", cell_size));
495
496 OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
497 errors::InvalidArgument(
498 "w.dim_size(0) != input_size + cell_size: ",
499 w_tensor->dim_size(0), " vs. ", input_size + cell_size));
500 OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
501 errors::InvalidArgument(
502 "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
503 " vs. ", cell_size * 4));
504
505 OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
506 errors::InvalidArgument(
507 "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
508 " vs. ", cell_size * 4));
509
510 OP_REQUIRES(ctx, i_tensor->dim_size(0) == batch_size,
511 errors::InvalidArgument(
512 "i.dim_size(0) != batch_size: ", i_tensor->dim_size(0),
513 " vs. ", batch_size));
514 OP_REQUIRES(ctx, i_tensor->dim_size(1) == cell_size,
515 errors::InvalidArgument(
516 "i.dim_size(1) != cell_size: ", i_tensor->dim_size(1),
517 " vs. ", cell_size));
518
519 OP_REQUIRES(ctx, cs_tensor->dim_size(0) == batch_size,
520 errors::InvalidArgument(
521 "cs.dim_size(0) != batch_size: ", cs_tensor->dim_size(0),
522 " vs. ", batch_size));
523 OP_REQUIRES(ctx, cs_tensor->dim_size(1) == cell_size,
524 errors::InvalidArgument(
525 "cs.dim_size(1) != cell_size: ", cs_tensor->dim_size(1),
526 " vs. ", cell_size));
527
528 OP_REQUIRES(ctx, f_tensor->dim_size(0) == batch_size,
529 errors::InvalidArgument(
530 "f.dim_size(0) != batch_size: ", f_tensor->dim_size(0),
531 " vs. ", batch_size));
532 OP_REQUIRES(ctx, f_tensor->dim_size(1) == cell_size,
533 errors::InvalidArgument(
534 "i.dim_size(1) != cell_size: ", f_tensor->dim_size(1),
535 " vs. ", cell_size));
536
537 OP_REQUIRES(ctx, o_tensor->dim_size(0) == batch_size,
538 errors::InvalidArgument(
539 "o.dim_size(0) != batch_size: ", o_tensor->dim_size(0),
540 " vs. ", batch_size));
541 OP_REQUIRES(ctx, o_tensor->dim_size(1) == cell_size,
542 errors::InvalidArgument(
543 "o.dim_size(1) != cell_size: ", o_tensor->dim_size(1),
544 " vs. ", cell_size));
545
546 OP_REQUIRES(ctx, ci_tensor->dim_size(0) == batch_size,
547 errors::InvalidArgument(
548 "ci.dim_size(0) != batch_size: ", ci_tensor->dim_size(0),
549 " vs. ", batch_size));
550 OP_REQUIRES(ctx, ci_tensor->dim_size(1) == cell_size,
551 errors::InvalidArgument(
552 "ci.dim_size(1) != cell_size: ", ci_tensor->dim_size(1),
553 " vs. ", cell_size));
554
555 OP_REQUIRES(ctx, co_tensor->dim_size(0) == batch_size,
556 errors::InvalidArgument(
557 "co.dim_size(0) != batch_size: ", co_tensor->dim_size(0),
558 " vs. ", batch_size));
559 OP_REQUIRES(ctx, co_tensor->dim_size(1) == cell_size,
560 errors::InvalidArgument(
561 "co.dim_size(1) != cell_size: ", co_tensor->dim_size(1),
562 " vs. ", cell_size));
563
564 OP_REQUIRES(ctx, cs_grad_tensor->dim_size(0) == batch_size,
565 errors::InvalidArgument(
566 "cs_grad_tensor.dims(0) != batch_size: ",
567 cs_grad_tensor->dim_size(0), " vs. ", batch_size));
568 OP_REQUIRES(ctx, cs_grad_tensor->dim_size(1) == cell_size,
569 errors::InvalidArgument("cs_grad_tensor.dims(1) != cell_size: ",
570 cs_grad_tensor->dim_size(1), " vs. ",
571 cell_size));
572
573 OP_REQUIRES(ctx, h_grad_tensor->dim_size(0) == batch_size,
574 errors::InvalidArgument("h_grad_tensor.dims(0) != batch_size: ",
575 h_grad_tensor->dim_size(0), " vs. ",
576 batch_size));
577 OP_REQUIRES(ctx, h_grad_tensor->dim_size(1) == cell_size,
578 errors::InvalidArgument("h_grad_tensor.dims(1) != cell_size: ",
579 h_grad_tensor->dim_size(1), " vs. ",
580 cell_size));
581
582 // Allocate our output tensors.
583 Tensor* cs_prev_grad_tensor = nullptr;
584 OP_REQUIRES_OK(
585 ctx, ctx->forward_input_or_allocate_output(
586 {"cs_grad"}, "cs_prev_grad",
587 TensorShape({batch_size, cell_size}), &cs_prev_grad_tensor));
588
589 Tensor* dicfo_tensor = nullptr;
590 OP_REQUIRES_OK(ctx, ctx->allocate_output(
591 "dicfo", TensorShape({batch_size, cell_size * 4}),
592 &dicfo_tensor));
593
594 Tensor* wci_grad_tensor = nullptr;
595 OP_REQUIRES_OK(
596 ctx, ctx->forward_input_or_allocate_output(
597 {"wci"}, "wci_grad", wci_tensor->shape(), &wci_grad_tensor));
598
599 Tensor* wcf_grad_tensor = nullptr;
600 OP_REQUIRES_OK(
601 ctx, ctx->forward_input_or_allocate_output(
602 {"wcf"}, "wcf_grad", wcf_tensor->shape(), &wcf_grad_tensor));
603
604 Tensor* wco_grad_tensor = nullptr;
605 OP_REQUIRES_OK(
606 ctx, ctx->forward_input_or_allocate_output(
607 {"wco"}, "wco_grad", wco_tensor->shape(), &wco_grad_tensor));
608
609 // Allocate our temp tensors.
610 Tensor do_tensor;
611 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
612 TensorShape({batch_size, cell_size}),
613 &do_tensor));
614
615 Tensor dcs_tensor;
616 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
617 TensorShape({batch_size, cell_size}),
618 &dcs_tensor));
619
620 Tensor dci_tensor;
621 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
622 TensorShape({batch_size, cell_size}),
623 &dci_tensor));
624
625 Tensor df_tensor;
626 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
627 TensorShape({batch_size, cell_size}),
628 &df_tensor));
629
630 Tensor di_tensor;
631 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
632 TensorShape({batch_size, cell_size}),
633 &di_tensor));
634
635 const Device& device = ctx->eigen_device<Device>();
636
637 functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<T>());
638 functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
639 functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());
640
641 functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
642 cell_size)(
643 ctx, device, use_peephole_, x_tensor->matrix<T>(),
644 cs_prev_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
645 w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
646 wco_tensor->vec<T>(), b_tensor->vec<T>(), i_tensor->matrix<T>(),
647 cs_tensor->matrix<T>(), f_tensor->matrix<T>(), o_tensor->matrix<T>(),
648 ci_tensor->matrix<T>(), co_tensor->matrix<T>(),
649 cs_grad_tensor->matrix<T>(), h_grad_tensor->matrix<T>(),
650 do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
651 df_tensor.matrix<T>(), di_tensor.matrix<T>(), dicfo_tensor->matrix<T>(),
652 cs_prev_grad_tensor->matrix<T>(), wci_grad_tensor->vec<T>(),
653 wcf_grad_tensor->vec<T>(), wco_grad_tensor->vec<T>());
654 }
655
656 protected:
657 bool use_peephole_;
658 };
659
660 #define REGISTER_KERNEL(T) \
661 REGISTER_KERNEL_BUILDER( \
662 Name("LSTMBlockCellGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
663 LSTMBlockCellGradOp<CPUDevice, T, false>);
664 REGISTER_KERNEL(float);
665 REGISTER_KERNEL(Eigen::half);
666 #undef REGISTER_KERNEL
667
668 #if GOOGLE_CUDA
669 namespace functor {
670 #define DECLARE_GPU_SPEC(T) \
671 template <> \
672 void LSTMBlockCellBprop<GPUDevice, T, true>::operator()( \
673 OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
674 typename TTypes<T>::ConstMatrix x, \
675 typename TTypes<T>::ConstMatrix cs_prev, \
676 typename TTypes<T>::ConstMatrix h_prev, \
677 typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
678 typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
679 typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \
680 typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \
681 typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \
682 typename TTypes<T>::ConstMatrix co, \
683 typename TTypes<T>::ConstMatrix cs_grad, \
684 typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
685 typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
686 typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
687 typename TTypes<T>::Matrix dicfo, \
688 typename TTypes<T>::Matrix cs_prev_grad, \
689 typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
690 typename TTypes<T>::Vec wco_grad); \
691 \
692 extern template struct LSTMBlockCellBprop<GPUDevice, T, \
693 true /* USE_CUBLAS */>;
694
695 DECLARE_GPU_SPEC(float);
696 DECLARE_GPU_SPEC(Eigen::half);
697 // DECLARE_GPU_SPEC(double);
698 #undef DECLARE_GPU_SPEC
699 } // namespace functor
700
701 #define REGISTER_GPU_KERNEL(T) \
702 REGISTER_KERNEL_BUILDER( \
703 Name("LSTMBlockCellGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
704 LSTMBlockCellGradOp<GPUDevice, T, true>);
705
706 REGISTER_GPU_KERNEL(float);
707 REGISTER_GPU_KERNEL(Eigen::half);
708 // REGISTER_GPU_KERNEL(double);
709 #undef REGISTER_GPU_KERNEL
710 #endif // GOOGLE_CUDA
711
712 namespace {
713
714 // This helper class can be used to access timeslices of a 3D tensor. If a
715 // slice happens to be unaligned (usually because both the batch size and the
716 // number of cells are odd; this isn't common), this involves overhead, since
717 // the data needs to be copied. However, if all slices are aligned, no bits are
718 // copied. In the cases where copying is needed, the outputs also have to be
719 // copied back. At the end of each time step, call FinishTimeStep, which does
720 // this and also allows temporary tensors to be reused.
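// Typical usage, as in BlockLSTMOp::Compute below:
//   SliceHelper<Device, T> slicer(ctx);
//   for (int64 t = 0; t < seq_len_max; ++t) {
//     const Tensor x_t = slicer.InputSlice(*x, t, "x");
//     Tensor h_t = slicer.OutputSlice(h_out, t, "h_out");
//     ...                       // run the cell on the slices
//     slicer.FinishTimeStep();  // copy unaligned outputs back, recycle temps
//   }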
721 template <typename Device, typename T>
722 class SliceHelper {
723 public:
724 explicit SliceHelper(OpKernelContext* ctx)
725 : ctx_(ctx), device_(ctx_->eigen_device<Device>()) {}
726
727 ~SliceHelper() {
728 CHECK(copy_out_.empty());
729 for (const auto& entry : pool_) {
730 CHECK(!entry.second.second); // nothing is in use
731 }
732 }
733
734 // Slice through an input tensor. This may copy unaligned slices, but no
735 // copying back will be done at the end.
736 const Tensor InputSlice(const Tensor& t, int pos, const string& name) {
737 Tensor res = UnalignedSlice(t, pos);
738 if (res.IsAligned()) {
739 return res;
740 } else {
741 return AlignTensor(res, name);
742 }
743 }
744
745 // Slice through an output tensor. This may copy unaligned slices, and
746 // schedule copying back on destruction.
747 Tensor OutputSlice(Tensor* t, int pos, const string& name) {
748 Tensor res = UnalignedSlice(*t, pos);
749 if (res.IsAligned()) {
750 return res;
751 } else {
752 Tensor aligned = AlignTensor(res, name);
753 copy_out_.emplace_back(res, aligned);
754 return aligned;
755 }
756 }
757
758 void FinishTimeStep() {
759 for (const auto& p : copy_out_) {
760 const Tensor& aligned = p.second;
761 Tensor original = p.first;
762 // Copy from aligned back to original.
763 functor::TensorCopyToUnaligned<Device, T>()(device_, aligned.flat<T>(),
764 original.unaligned_flat<T>());
765 }
766 copy_out_.clear();
767 // Mark all entries as not in use.
768 for (auto& entry : pool_) {
769 entry.second.second = false;
770 }
771 }
772
773 private:
774 // Return a slice at position 'pos'. Result may be unaligned. The resulting
775 // tensor always shares data with the source tensor.
776 Tensor UnalignedSlice(const Tensor& t, int pos) const {
777 Tensor res;
778 // CHECK should never fail here, since the number of elements must match
779 CHECK(res.CopyFrom(t.Slice(pos, pos + 1), {t.dim_size(1), t.dim_size(2)}));
780 return res;
781 }
782
783 // Assumes input is not aligned, creates a temporary aligned tensor of the
784 // same shape and copies the original tensor's content into it.
785 Tensor AlignTensor(const Tensor& t, const string& name) {
786 VLOG(1) << "AlignTensor called for " << name << ", shape "
787 << t.shape().DebugString()
788 << ". This is unnecessary copying. Consider using shapes with even "
789 << "sizes";
790 Tensor aligned;
791 auto found = pool_.find(name);
792 if (found != pool_.end()) { // found in pool
793 CHECK(!found->second.second) << "Tensor " << name << " is in use";
794 found->second.second = true; // mark in use
795 aligned = found->second.first;
796 CHECK(aligned.shape().IsSameSize(t.shape()));
797 CHECK_EQ(aligned.dtype(), t.dtype());
798 } else { // allocate a new temporary tensor
799 TF_CHECK_OK(ctx_->allocate_temp(t.dtype(), t.shape(), &aligned));
800 pool_.emplace(name, std::make_pair(aligned, true));
801 }
802 functor::TensorCopyUnaligned<Device, T>()(device_, t.unaligned_flat<T>(),
803 aligned.flat<T>());
804 return aligned;
805 }
806
807 // Tensors to be copied.
808 std::vector<std::pair<Tensor, const Tensor>> copy_out_;
809 // A pool of pre-allocated temporary tensors, with an indicator for whether
810 // it's in use.
811 std::map<string, std::pair<Tensor, bool>> pool_;
812 // Op context
813 OpKernelContext* ctx_ = nullptr;
814 // Device
815 const Device& device_;
816 };
817
818 } // namespace
819
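// Runs the LSTM block cell over an entire sequence. x is
// [timelen, batch_size, input_size] and each per-gate output is
// [timelen, batch_size, cell_size]. The cell is applied for
// t = 0 .. seq_len_max - 1, feeding cs and h from step t - 1 back in through
// SliceHelper; any remaining time steps of cs and h are zeroed.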
820 template <typename Device, typename T, bool USE_CUBLAS>
821 class BlockLSTMOp : public OpKernel {
822 public:
823 explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
824 OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
825 OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
826 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
827 }
828
829 void Compute(OpKernelContext* ctx) override {
830 const Tensor* seq_len_max_tensor = nullptr;
831 OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));
832
833 const Tensor* x;
834 OP_REQUIRES_OK(ctx, ctx->input("x", &x));
835 OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
836 const int64 timelen = x->dim_size(0);
837 const int64 batch_size = x->dim_size(1);
838 const int64 input_size = x->dim_size(2);
839
840 const Tensor* cs_prev_tensor = nullptr;
841 OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
842 OP_REQUIRES(ctx, cs_prev_tensor->dims() == 2,
843 errors::InvalidArgument("cs_prev must be 2D"));
844 OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
845 errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
846 cs_prev_tensor->dim_size(0), " vs. ",
847 batch_size));
848 const int64 cell_size = cs_prev_tensor->dim_size(1);
849
850 if (batch_size * input_size % 2 == 1) {
851 LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and "
852 << "input_size are odd. You are using: batch_size="
853 << batch_size << ", input_size=" << input_size;
854 }
855 if (batch_size * cell_size % 2 == 1) {
856 LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and "
857 << "cell_size are odd. You are using: batch_size="
858 << batch_size << ", cell_size=" << cell_size;
859 }
860
861 const Tensor* h_prev_tensor = nullptr;
862 OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
863 OP_REQUIRES(ctx, h_prev_tensor->dims() == 2,
864 errors::InvalidArgument("h_prev must be 2D"));
865 OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
866 errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
867 h_prev_tensor->dim_size(0), " vs. ",
868 batch_size));
869 OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
870 errors::InvalidArgument(
871 "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
872 " vs. ", cell_size));
873
874 const Tensor* w_tensor = nullptr;
875 OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
876 OP_REQUIRES(ctx, w_tensor->dims() == 2,
877 errors::InvalidArgument("w must be 2D"));
878 OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
879 errors::InvalidArgument(
880 "w.dim_size(0) != input_size + cell_size: ",
881 w_tensor->dim_size(0), " vs. ", input_size + cell_size));
882 OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
883 errors::InvalidArgument(
884 "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
885 " vs. ", cell_size * 4));
886
887 const Tensor* wci_tensor = nullptr;
888 OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
889 OP_REQUIRES(ctx, wci_tensor->dims() == 1,
890 errors::InvalidArgument("wci must be 1D"));
891 OP_REQUIRES(ctx, wci_tensor->dim_size(0) == cell_size,
892 errors::InvalidArgument(
893 "wci.dim_size(0) != cell_size: ", wci_tensor->dim_size(0),
894 " vs. ", cell_size));
895
896 const Tensor* wcf_tensor = nullptr;
897 OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
898 OP_REQUIRES(ctx, wcf_tensor->dims() == 1,
899 errors::InvalidArgument("wcf must be 1D"));
900 OP_REQUIRES(ctx, wcf_tensor->dim_size(0) == cell_size,
901 errors::InvalidArgument(
902 "wcf.dim_size(0) != cell_size: ", wcf_tensor->dim_size(0),
903 " vs. ", cell_size));
904
905 const Tensor* wco_tensor = nullptr;
906 OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
907 OP_REQUIRES(ctx, wco_tensor->dims() == 1,
908 errors::InvalidArgument("wco must be 1D"));
909 OP_REQUIRES(ctx, wco_tensor->dim_size(0) == cell_size,
910 errors::InvalidArgument(
911 "wco.dim_size(0) != cell_size: ", wco_tensor->dim_size(0),
912 " vs. ", cell_size));
913
914 const Tensor* b_tensor = nullptr;
915 OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
916 OP_REQUIRES(ctx, b_tensor->dims() == 1,
917 errors::InvalidArgument("b must be 1D"));
918 OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
919 errors::InvalidArgument(
920 "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
921 " vs. ", cell_size * 4));
922
923 TensorShape batch_cell_shape({timelen, batch_size, cell_size});
924 Tensor* i_out;
925 OP_REQUIRES_OK(ctx, ctx->allocate_output("i", batch_cell_shape, &i_out));
926
927 Tensor* cs_out;
928 OP_REQUIRES_OK(ctx, ctx->allocate_output("cs", batch_cell_shape, &cs_out));
929
930 Tensor* f_out;
931 OP_REQUIRES_OK(ctx, ctx->allocate_output("f", batch_cell_shape, &f_out));
932
933 Tensor* o_out;
934 OP_REQUIRES_OK(ctx, ctx->allocate_output("o", batch_cell_shape, &o_out));
935
936 Tensor* ci_out;
937 OP_REQUIRES_OK(ctx, ctx->allocate_output("ci", batch_cell_shape, &ci_out));
938
939 Tensor* co_out;
940 OP_REQUIRES_OK(ctx, ctx->allocate_output("co", batch_cell_shape, &co_out));
941
942 Tensor* h_out;
943 OP_REQUIRES_OK(ctx, ctx->allocate_output("h", batch_cell_shape, &h_out));
944
945 Tensor xh_tensor;
946 OP_REQUIRES_OK(ctx, ctx->allocate_temp(
947 DataTypeToEnum<T>::v(),
948 TensorShape({batch_size, input_size + cell_size}),
949 &xh_tensor));
950
951 Tensor icfo_tensor;
952 OP_REQUIRES_OK(ctx,
953 ctx->allocate_temp(DataTypeToEnum<T>::v(),
954 TensorShape({batch_size, cell_size * 4}),
955 &icfo_tensor));
956
957 const Device& device = ctx->eigen_device<Device>();
958
959 const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
960 SliceHelper<Device, T> slicer(ctx);
961 for (int64 t = 0; t < seq_len_max; ++t) {
962 const Tensor x_tensor = slicer.InputSlice(*x, t, "x");
963 const Tensor& cs_prev_tensor2 =
964 t == 0 ? *cs_prev_tensor
965 : slicer.OutputSlice(cs_out, t - 1, "cs_prev");
966 const Tensor& h_prev_tensor2 =
967 t == 0 ? *h_prev_tensor : slicer.OutputSlice(h_out, t - 1, "h_prev");
968
969 Tensor i_tensor = slicer.OutputSlice(i_out, t, "i_out");
970 Tensor cs_tensor = slicer.OutputSlice(cs_out, t, "cs_out");
971 Tensor f_tensor = slicer.OutputSlice(f_out, t, "f_out");
972 Tensor o_tensor = slicer.OutputSlice(o_out, t, "o_out");
973 Tensor ci_tensor = slicer.OutputSlice(ci_out, t, "ci_out");
974 Tensor co_tensor = slicer.OutputSlice(co_out, t, "co_out");
975 Tensor h_tensor = slicer.OutputSlice(h_out, t, "h_out");
976
977 functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
978 cell_size)(
979 ctx, device, forget_bias_, cell_clip_, use_peephole_,
980 x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(),
981 h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(),
982 wci_tensor->vec<T>(), wcf_tensor->vec<T>(), wco_tensor->vec<T>(),
983 b_tensor->vec<T>(), xh_tensor.matrix<T>(), i_tensor.matrix<T>(),
984 cs_tensor.matrix<T>(), f_tensor.matrix<T>(), o_tensor.matrix<T>(),
985 ci_tensor.matrix<T>(), co_tensor.matrix<T>(), icfo_tensor.matrix<T>(),
986 h_tensor.matrix<T>());
987 slicer.FinishTimeStep();
988 }
989
990 if (seq_len_max < timelen) {
991 Tensor cs_tensor = cs_out->Slice(seq_len_max, timelen);
992 Tensor h_tensor = h_out->Slice(seq_len_max, timelen);
993
994 functor::TensorUnalignedZero<Device, T>()(device,
995 cs_tensor.unaligned_flat<T>());
996 functor::TensorUnalignedZero<Device, T>()(device,
997 h_tensor.unaligned_flat<T>());
998 }
999 }
1000
1001 private:
1002 float forget_bias_;
1003 float cell_clip_;
1004 bool use_peephole_;
1005 };
1006
1007 #define REGISTER_KERNEL(T) \
1008 REGISTER_KERNEL_BUILDER( \
1009 Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
1010 BlockLSTMOp<CPUDevice, T, false>);
1011 REGISTER_KERNEL(float);
1012 REGISTER_KERNEL(Eigen::half);
1013 #undef REGISTER_KERNEL
1014
1015 #if GOOGLE_CUDA
1016 namespace functor {
1017 #define DECLARE_GPU_SPEC(T) \
1018 template <> \
1019 void TensorZero<GPUDevice, T>::operator()(const GPUDevice& d, \
1020 typename TTypes<T>::Flat t); \
1021 \
1022 extern template struct TensorZero<GPUDevice, T>; \
1023 \
1024 template <> \
1025 void TensorUnalignedZero<GPUDevice, T>::operator()( \
1026 const GPUDevice& d, typename TTypes<T>::UnalignedFlat t); \
1027 \
1028 extern template struct TensorUnalignedZero<GPUDevice, T>;
1029
1030 DECLARE_GPU_SPEC(float);
1031 DECLARE_GPU_SPEC(Eigen::half);
1032 // DECLARE_GPU_SPEC(double);
1033 #undef DECLARE_GPU_SPEC
1034 } // end namespace functor
1035
1036 #define REGISTER_GPU_KERNEL(T) \
1037 REGISTER_KERNEL_BUILDER(Name("BlockLSTM") \
1038 .Device(DEVICE_GPU) \
1039 .HostMemory("seq_len_max") \
1040 .TypeConstraint<T>("T"), \
1041 BlockLSTMOp<GPUDevice, T, true>);
1042
1043 REGISTER_GPU_KERNEL(float);
1044 REGISTER_GPU_KERNEL(Eigen::half);
1045 // REGISTER_GPU_KERNEL(double);
1046 #undef REGISTER_GPU_KERNEL
1047 #endif // GOOGLE_CUDA
1048
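// Backpropagation through the whole sequence: after zero-initializing every
// gradient output, the loop walks t = seq_len_max - 1 down to 0, adds the
// incoming cs_grad/h_grad slices to the gradients carried over from step
// t + 1, and invokes BlockLSTMBprop, which accumulates the w, b, and peephole
// weight gradients across time steps. x_grad for any unused time steps is
// zeroed at the end.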
1049 template <typename Device, typename T, bool USE_CUBLAS>
1050 class BlockLSTMGradOp : public OpKernel {
1051 public:
1052 explicit BlockLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
1053 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
1054 }
1055
1056 void Compute(OpKernelContext* ctx) override {
1057 const Tensor* seq_len_max_tensor = nullptr;
1058 OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));
1059
1060 const Tensor* x;
1061 OP_REQUIRES_OK(ctx, ctx->input("x", &x));
1062 OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
1063 const int64 timelen = x->dim_size(0);
1064 const int64 batch_size = x->dim_size(1);
1065 const int64 input_size = x->dim_size(2);
1066
1067 const Tensor* cs_prev_tensor = nullptr;
1068 OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
1069
1070 const Tensor* h_prev_tensor = nullptr;
1071 OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
1072
1073 const Tensor* w_tensor = nullptr;
1074 OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
1075 const int64 cell_size = w_tensor->dim_size(1) / 4;
1076 OP_REQUIRES(ctx, input_size + cell_size == w_tensor->dim_size(0),
1077 errors::InvalidArgument(
1078 "w matrix rows don't match: ", input_size + cell_size,
1079 " vs. ", w_tensor->dim_size(0)));
1080
1081 const Tensor* wci_tensor = nullptr;
1082 OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
1083
1084 const Tensor* wcf_tensor = nullptr;
1085 OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
1086
1087 const Tensor* wco_tensor = nullptr;
1088 OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
1089
1090 const Tensor* b_tensor = nullptr;
1091 OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
1092 OP_REQUIRES(
1093 ctx, cell_size == b_tensor->dim_size(0) / 4,
1094 errors::InvalidArgument("w and b cell_size don't match: ", cell_size,
1095 " vs. ", b_tensor->dim_size(0)));
1096
1097 const Tensor* i_out = nullptr;
1098 OP_REQUIRES_OK(ctx, ctx->input("i", &i_out));
1099
1100 const Tensor* cs_out = nullptr;
1101 OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_out));
1102
1103 const Tensor* f_out = nullptr;
1104 OP_REQUIRES_OK(ctx, ctx->input("f", &f_out));
1105
1106 const Tensor* o_out = nullptr;
1107 OP_REQUIRES_OK(ctx, ctx->input("o", &o_out));
1108
1109 const Tensor* ci_out = nullptr;
1110 OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_out));
1111
1112 const Tensor* co_out = nullptr;
1113 OP_REQUIRES_OK(ctx, ctx->input("co", &co_out));
1114
1115 const Tensor* h_out = nullptr;
1116 OP_REQUIRES_OK(ctx, ctx->input("h", &h_out));
1117
1118 const Tensor* cs_grad = nullptr;
1119 OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad));
1120
1121 const Tensor* h_grad = nullptr;
1122 OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad));
1123
1124 TensorShape batch_input_shape({timelen, batch_size, input_size});
1125 Tensor* x_grad;
1126 OP_REQUIRES_OK(ctx,
1127 ctx->allocate_output("x_grad", batch_input_shape, &x_grad));
1128
1129 Tensor* cs_prev_grad_tensor = nullptr;
1130 OP_REQUIRES_OK(ctx,
1131 ctx->allocate_output("cs_prev_grad", cs_prev_tensor->shape(),
1132 &cs_prev_grad_tensor));
1133
1134 Tensor* h_prev_grad_tensor = nullptr;
1135 OP_REQUIRES_OK(ctx,
1136 ctx->allocate_output("h_prev_grad", h_prev_tensor->shape(),
1137 &h_prev_grad_tensor));
1138
1139 Tensor* w_grad_tensor = nullptr;
1140 OP_REQUIRES_OK(
1141 ctx, ctx->allocate_output("w_grad", w_tensor->shape(), &w_grad_tensor));
1142
1143 Tensor* wci_grad_tensor = nullptr;
1144 OP_REQUIRES_OK(ctx, ctx->allocate_output("wci_grad", wci_tensor->shape(),
1145 &wci_grad_tensor));
1146
1147 Tensor* wcf_grad_tensor = nullptr;
1148 OP_REQUIRES_OK(ctx, ctx->allocate_output("wcf_grad", wcf_tensor->shape(),
1149 &wcf_grad_tensor));
1150
1151 Tensor* wco_grad_tensor = nullptr;
1152 OP_REQUIRES_OK(ctx, ctx->allocate_output("wco_grad", wco_tensor->shape(),
1153 &wco_grad_tensor));
1154
1155 Tensor* b_grad_tensor = nullptr;
1156 OP_REQUIRES_OK(
1157 ctx, ctx->allocate_output("b_grad", b_tensor->shape(), &b_grad_tensor));
1158
1159 TensorShape batch_cell_shape({batch_size, cell_size});
1160
1161 Tensor xh_tensor;
1162 OP_REQUIRES_OK(ctx, ctx->allocate_temp(
1163 DataTypeToEnum<T>::v(),
1164 TensorShape({batch_size, input_size + cell_size}),
1165 &xh_tensor));
1166
1167 Tensor xh_grad_tensor;
1168 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
1169 xh_tensor.shape(), &xh_grad_tensor));
1170
1171 Tensor do_tensor;
1172 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
1173 batch_cell_shape, &do_tensor));
1174
1175 Tensor dcs_tensor;
1176 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
1177 batch_cell_shape, &dcs_tensor));
1178
1179 Tensor dci_tensor;
1180 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
1181 batch_cell_shape, &dci_tensor));
1182
1183 Tensor df_tensor;
1184 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
1185 batch_cell_shape, &df_tensor));
1186
1187 Tensor di_tensor;
1188 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
1189 batch_cell_shape, &di_tensor));
1190
1191 Tensor dicfo_tensor;
1192 OP_REQUIRES_OK(ctx,
1193 ctx->allocate_temp(DataTypeToEnum<T>::v(),
1194 TensorShape({batch_size, cell_size * 4}),
1195 &dicfo_tensor));
1196
1197 Tensor cs_grad_tensor;
1198 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
1199 batch_cell_shape, &cs_grad_tensor));
1200
1201 Tensor h_grad_tensor;
1202 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
1203 batch_cell_shape, &h_grad_tensor));
1204
1205 const Device& device = ctx->eigen_device<Device>();
1206
1207 functor::TensorZero<Device, T>()(device, cs_grad_tensor.flat<T>());
1208 functor::TensorZero<Device, T>()(device, cs_prev_grad_tensor->flat<T>());
1209 functor::TensorZero<Device, T>()(device, h_grad_tensor.flat<T>());
1210 functor::TensorZero<Device, T>()(device, h_prev_grad_tensor->flat<T>());
1211 functor::TensorZero<Device, T>()(device, w_grad_tensor->flat<T>());
1212 functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<T>());
1213 functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
1214 functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());
1215 functor::TensorZero<Device, T>()(device, b_grad_tensor->flat<T>());
1216
1217 const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
1218 SliceHelper<Device, T> slicer(ctx);
1219 for (int64 t = seq_len_max - 1; t >= 0; --t) {
1220 const Tensor& x_tensor = slicer.InputSlice(*x, t, "x");
1221 const Tensor& cs_prev_tensor2 =
1222 t == 0 ? *cs_prev_tensor
1223 : slicer.InputSlice(*cs_out, t - 1, "cs_prev");
1224 const Tensor& h_prev_tensor2 =
1225 t == 0 ? *h_prev_tensor : slicer.InputSlice(*h_out, t - 1, "h_prev");
1226 const Tensor& i_tensor = slicer.InputSlice(*i_out, t, "i_out");
1227 const Tensor& cs_tensor = slicer.InputSlice(*cs_out, t, "cs_out");
1228 const Tensor& f_tensor = slicer.InputSlice(*f_out, t, "f_out");
1229 const Tensor& o_tensor = slicer.InputSlice(*o_out, t, "o_out");
1230 const Tensor& ci_tensor = slicer.InputSlice(*ci_out, t, "ci_out");
1231 const Tensor& co_tensor = slicer.InputSlice(*co_out, t, "co_out");
1232
1233 // Grab previous CS grad.
1234 const Tensor& const_cs_prev_grad_tensor = *cs_prev_grad_tensor;
1235 const Tensor const_cs_grad_slice =
1236 slicer.InputSlice(*cs_grad, t, "cs_grad");
1237 functor::TensorAdd<Device, T>()(
1238 device, const_cs_prev_grad_tensor.flat<T>(),
1239 const_cs_grad_slice.flat<T>(), cs_grad_tensor.flat<T>());
1240
1241 // Combine previous h grad and h grad coming on top.
1242 const Tensor& const_h_prev_grad_tensor = *h_prev_grad_tensor;
1243 const Tensor const_h_grad_slice = slicer.InputSlice(*h_grad, t, "h_grad");
1244 functor::TensorAdd<Device, T>()(
1245 device, const_h_prev_grad_tensor.flat<T>(),
1246 const_h_grad_slice.flat<T>(), h_grad_tensor.flat<T>());
1247
1248 const Tensor& const_cs_grad_tensor = cs_grad_tensor;
1249 const Tensor& const_h_grad_tensor = h_grad_tensor;
1250
1251 Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad");
1252 functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
1253 cell_size)(
1254 ctx, device, use_peephole_, x_tensor.matrix<T>(),
1255 cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(),
1256 w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
1257 wco_tensor->vec<T>(), b_tensor->vec<T>(), xh_tensor.matrix<T>(),
1258 i_tensor.matrix<T>(), cs_tensor.matrix<T>(), f_tensor.matrix<T>(),
1259 o_tensor.matrix<T>(), ci_tensor.matrix<T>(), co_tensor.matrix<T>(),
1260 const_cs_grad_tensor.matrix<T>(), const_h_grad_tensor.matrix<T>(),
1261 do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
1262 df_tensor.matrix<T>(), di_tensor.matrix<T>(),
1263 dicfo_tensor.matrix<T>(), cs_prev_grad_tensor->matrix<T>(),
1264 h_prev_grad_tensor->matrix<T>(), xh_grad_tensor.matrix<T>(),
1265 x_grad_tensor.matrix<T>(), w_grad_tensor->matrix<T>(),
1266 wci_grad_tensor->vec<T>(), wcf_grad_tensor->vec<T>(),
1267 wco_grad_tensor->vec<T>(), b_grad_tensor->vec<T>());
1268 slicer.FinishTimeStep();
1269 }
1270
1271 if (seq_len_max < timelen) {
1272 Tensor x_grad_tensor = x_grad->Slice(seq_len_max, timelen);
1273 functor::TensorUnalignedZero<Device, T>()(
1274 device, x_grad_tensor.unaligned_flat<T>());
1275 }
1276 }
1277
1278 private:
1279 bool use_peephole_;
1280 };
1281
1282 #define REGISTER_KERNEL(T) \
1283 REGISTER_KERNEL_BUILDER( \
1284 Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
1285 BlockLSTMGradOp<CPUDevice, T, false>);
1286 REGISTER_KERNEL(float);
1287 REGISTER_KERNEL(Eigen::half);
1288 #undef REGISTER_KERNEL
1289
1290 #if GOOGLE_CUDA
1291 namespace functor {
1292 #define DECLARE_GPU_SPEC(T) \
1293 template <> \
1294 void TensorCopy<GPUDevice, T>::operator()(const GPUDevice& d, \
1295 typename TTypes<T>::ConstFlat src, \
1296 typename TTypes<T>::Flat dst); \
1297 \
1298 template <> \
1299 void TensorCopyUnaligned<GPUDevice, T>::operator()( \
1300 const GPUDevice& d, typename TTypes<T>::UnalignedConstFlat src, \
1301 typename TTypes<T>::Flat dst); \
1302 \
1303 template <> \
1304 void TensorCopyToUnaligned<GPUDevice, T>::operator()( \
1305 const GPUDevice& d, typename TTypes<T>::ConstFlat src, \
1306 typename TTypes<T>::UnalignedFlat dst); \
1307 \
1308 template <> \
1309 void TensorAdd<GPUDevice, T>::operator()( \
1310 const GPUDevice& d, typename TTypes<T>::ConstFlat a, \
1311 typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c); \
1312 \
1313 template <> \
1314 void BlockLSTMBprop<GPUDevice, T, true>::operator()( \
1315 OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
1316 typename TTypes<T>::ConstMatrix x, \
1317 typename TTypes<T>::ConstMatrix cs_prev, \
1318 typename TTypes<T>::ConstMatrix h_prev, \
1319 typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
1320 typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
1321 typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
1322 typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs, \
1323 typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o, \
1324 typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co, \
1325 typename TTypes<T>::ConstMatrix cs_grad, \
1326 typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
1327 typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
1328 typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
1329 typename TTypes<T>::Matrix dicfo, \
1330 typename TTypes<T>::Matrix cs_prev_grad, \
1331 typename TTypes<T>::Matrix h_prev_grad, \
1332 typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad, \
1333 typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad, \
1334 typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad, \
1335 typename TTypes<T>::Vec b_grad); \
1336 \
1337 extern template struct TensorCopy<GPUDevice, T>; \
1338 extern template struct TensorAdd<GPUDevice, T>; \
1339 extern template struct BlockLSTMBprop<GPUDevice, T, true>;
1340
1341 DECLARE_GPU_SPEC(float);
1342 DECLARE_GPU_SPEC(Eigen::half);
1343 // DECLARE_GPU_SPEC(double);
1344 #undef DECLARE_GPU_SPEC
1345 } // end namespace functor
1346
1347 #define REGISTER_GPU_KERNEL(T) \
1348 REGISTER_KERNEL_BUILDER(Name("BlockLSTMGrad") \
1349 .Device(DEVICE_GPU) \
1350 .HostMemory("seq_len_max") \
1351 .TypeConstraint<T>("T"), \
1352 BlockLSTMGradOp<GPUDevice, T, true>);
1353
1354 REGISTER_GPU_KERNEL(float);
1355 REGISTER_GPU_KERNEL(Eigen::half);
1356 // REGISTER_GPU_KERNEL(double);
1357 #undef REGISTER_GPU_KERNEL
1358 #endif // GOOGLE_CUDA
1359
1360 } // end namespace tensorflow
1361