/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/contrib/rnn/kernels/gru_ops.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

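// Fused kernel computing a single time step of a GRU cell. The Device and
// USE_CUBLAS template parameters choose between the Eigen-based CPU path and
// the cuBLAS-backed GPU path for the matrix multiplications.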
template <typename Device, typename T, bool USE_CUBLAS>
class GRUCellBlockOp : public OpKernel {
 public:
  explicit GRUCellBlockOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  // TODO(gitegaurav) Replace the input checks with some smarter function.
  void Compute(OpKernelContext* ctx) override {
    // Grab the input tensors.
    const Tensor* x_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));

    const Tensor* w_ru_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w_ru", &w_ru_tensor));

    const Tensor* w_c_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w_c", &w_c_tensor));

    const Tensor* b_ru_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b_ru", &b_ru_tensor));

    const Tensor* b_c_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b_c", &b_c_tensor));

    const int64 batch_size = x_tensor->dim_size(0);
    const int64 input_size = x_tensor->dim_size(1);
    const int64 cell_size = h_prev_tensor->dim_size(1);

    // Sanity checks for input shapes.

    // Shape of 'h_prev' must be [batch_size, cell_size]
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
                                        h_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
                    " vs. ", cell_size));

    // Shape of 'w_ru' must be [input_size+cell_size, 2*cell_size]
    OP_REQUIRES(ctx, w_ru_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w_ru.dim_size(0) != input_size + cell_size: ",
                    w_ru_tensor->dim_size(0), " vs. ", input_size + cell_size));

    OP_REQUIRES(ctx, w_ru_tensor->dim_size(1) == cell_size * 2,
                errors::InvalidArgument("w_ru.dim_size(1) != cell_size * 2: ",
                                        w_ru_tensor->dim_size(1), " vs. ",
                                        cell_size * 2));

    // Shape of 'w_c' must be [input_size+cell_size, cell_size]
    OP_REQUIRES(ctx, w_c_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w_c.dim_size(0) != input_size + cell_size: ",
                    w_c_tensor->dim_size(0), " vs. ", input_size + cell_size));

    OP_REQUIRES(ctx, w_c_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "w_c.dim_size(1) != cell_size: ", w_c_tensor->dim_size(1),
                    " vs. ", cell_size));

    // Shape of 'b_ru' must be [2*cell_size]
    OP_REQUIRES(ctx, b_ru_tensor->dim_size(0) == cell_size * 2,
                errors::InvalidArgument("b_ru.dim_size(0) != cell_size * 2: ",
                                        b_ru_tensor->dim_size(0), " vs. ",
                                        cell_size * 2));

    OP_REQUIRES(ctx, b_ru_tensor->dims() == 1,
                errors::InvalidArgument("Rank of b_ru must be 1: ",
                                        b_ru_tensor->dims(), " vs. 1"));
    // Shape of 'b_c' must be [cell_size]
    OP_REQUIRES(ctx, b_c_tensor->dim_size(0) == cell_size,
                errors::InvalidArgument(
                    "b_c.dim_size(0) != cell_size: ", b_c_tensor->dim_size(0),
                    " vs. ", cell_size));
    OP_REQUIRES(ctx, b_c_tensor->dims() == 1,
                errors::InvalidArgument("Rank of b_c must be 1: ",
                                        b_c_tensor->dims(), " vs. 1"));

    // Create output tensors.
    Tensor* r_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("r", TensorShape({batch_size, cell_size}),
                                  &r_tensor));

    Tensor* u_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("u", TensorShape({batch_size, cell_size}),
                                  &u_tensor));

    Tensor* c_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("c", TensorShape({batch_size, cell_size}),
                                  &c_tensor));

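    // The output h may reuse h_prev's buffer: the input is forwarded when
    // possible and a fresh tensor is allocated otherwise.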
    Tensor* h_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"h_prev"}, "h",
                            TensorShape({batch_size, cell_size}), &h_tensor));

    // Allocate temp tensors.
    Tensor x_h_prev_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &x_h_prev_tensor));

    Tensor x_h_prevr_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &x_h_prevr_tensor));

    Tensor r_u_bar_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                      TensorShape({batch_size, 2 * cell_size}),
                                      &r_u_bar_tensor));
    const Device& device = ctx->eigen_device<Device>();

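    // Fused GRU forward pass ([a, b] denotes concatenation along the feature
    // axis, .* is elementwise multiplication):
    //   x_h_prev = [x, h_prev]
    //   [r_bar, u_bar] = x_h_prev * w_ru + b_ru
    //   r = sigmoid(r_bar),  u = sigmoid(u_bar)
    //   x_h_prevr = [x, h_prev .* r]
    //   c = tanh(x_h_prevr * w_c + b_c)
    //   h = u .* h_prev + (1 - u) .* c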
    functor::GRUBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                      cell_size)(
        ctx, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
        w_ru_tensor->matrix<T>(), w_c_tensor->matrix<T>(),
        b_ru_tensor->vec<T>(), b_c_tensor->vec<T>(), r_u_bar_tensor.matrix<T>(),
        r_tensor->matrix<T>(), u_tensor->matrix<T>(), c_tensor->matrix<T>(),
        h_tensor->matrix<T>(), x_h_prev_tensor.matrix<T>(),
        x_h_prevr_tensor.matrix<T>());
  }
};

// Register the Block GRU cell kernel for CPU.
#define REGISTER_KERNEL(T)                                            \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("GRUBlockCell").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      GRUCellBlockOp<CPUDevice, T, false>);

REGISTER_KERNEL(float);
#undef REGISTER_KERNEL

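// Computes the gradients of a single GRU time step with respect to its
// inputs, given the activations r, u, and c cached by the forward pass and
// the incoming gradient d_h.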
template <typename Device, typename T, bool USE_CUBLAS>
class GRUBlockCellGradOp : public OpKernel {
 public:
  explicit GRUBlockCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    // Grab the input tensors.
    const Tensor* x_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));

    const Tensor* w_ru_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w_ru", &w_ru_tensor));

    const Tensor* w_c_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w_c", &w_c_tensor));

    const Tensor* b_ru_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b_ru", &b_ru_tensor));

    const Tensor* b_c_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b_c", &b_c_tensor));

    const Tensor* r_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("r", &r_tensor));

    const Tensor* u_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("u", &u_tensor));

    const Tensor* c_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("c", &c_tensor));

    const Tensor* d_h_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("d_h", &d_h_tensor));

    const int64 batch_size = x_tensor->dim_size(0);
    const int64 input_size = x_tensor->dim_size(1);
    const int64 cell_size = h_prev_tensor->dim_size(1);

    // Sanity checks for input shapes.

    // Shape of 'h_prev' must be [batch_size, cell_size]
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
                                        h_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
                    " vs. ", cell_size));

    // Shape of 'w_ru' must be [input_size+cell_size, 2*cell_size]
    OP_REQUIRES(ctx, w_ru_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w_ru.dim_size(0) != input_size + cell_size: ",
                    w_ru_tensor->dim_size(0), " vs. ", input_size + cell_size));

    OP_REQUIRES(ctx, w_ru_tensor->dim_size(1) == cell_size * 2,
                errors::InvalidArgument("w_ru.dim_size(1) != cell_size * 2: ",
                                        w_ru_tensor->dim_size(1), " vs. ",
                                        cell_size * 2));

    // Shape of 'w_c' must be [input_size+cell_size, cell_size]
    OP_REQUIRES(ctx, w_c_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w_c.dim_size(0) != input_size + cell_size: ",
                    w_c_tensor->dim_size(0), " vs. ", input_size + cell_size));

    OP_REQUIRES(ctx, w_c_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "w_c.dim_size(1) != cell_size: ", w_c_tensor->dim_size(1),
                    " vs. ", cell_size));

    // Shape of 'b_ru' must be [2*cell_size]
    OP_REQUIRES(ctx, b_ru_tensor->dim_size(0) == cell_size * 2,
                errors::InvalidArgument("b_ru.dim_size(0) != cell_size * 2: ",
                                        b_ru_tensor->dim_size(0), " vs. ",
                                        cell_size * 2));

    OP_REQUIRES(ctx, b_ru_tensor->dims() == 1,
                errors::InvalidArgument("Rank of b_ru must be 1: ",
                                        b_ru_tensor->dims(), " vs. 1"));

    // Shape of 'b_c' must be [cell_size]
    OP_REQUIRES(ctx, b_c_tensor->dim_size(0) == cell_size,
                errors::InvalidArgument(
                    "b_c.dim_size(0) != cell_size: ", b_c_tensor->dim_size(0),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, b_c_tensor->dims() == 1,
                errors::InvalidArgument("Rank of b_c must be 1: ",
                                        b_c_tensor->dims(), " vs. 1"));

    // Shape of 'r' must be [batch_size, cell_size]
    OP_REQUIRES(ctx, r_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "r.dims(0) != batch_size: ", r_tensor->dim_size(0), " vs. ",
                    batch_size));
    OP_REQUIRES(ctx, r_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "r.dims(1) != cell_size: ", r_tensor->dim_size(1), " vs. ",
                    cell_size));

    // Shape of 'u' must be [batch_size, cell_size]
    OP_REQUIRES(ctx, u_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "u.dims(0) != batch_size: ", u_tensor->dim_size(0), " vs. ",
                    batch_size));
    OP_REQUIRES(ctx, u_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "u.dims(1) != cell_size: ", u_tensor->dim_size(1), " vs. ",
                    cell_size));

    // Shape of 'c' must be [batch_size, cell_size]
    OP_REQUIRES(ctx, c_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "c.dims(0) != batch_size: ", c_tensor->dim_size(0), " vs. ",
                    batch_size));
    OP_REQUIRES(ctx, c_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "c.dims(1) != cell_size: ", c_tensor->dim_size(1), " vs. ",
                    cell_size));

    // Shape of 'd_h' must be [batch_size, cell_size]
    OP_REQUIRES(ctx, d_h_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "d_h.dims(0) != batch_size: ", d_h_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, d_h_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "d_h.dims(1) != cell_size: ", d_h_tensor->dim_size(1),
                    " vs. ", cell_size));

    // Create output tensors.
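    // d_x and d_h_prev may reuse the buffers of the inputs x and h_prev when
    // forward_input_or_allocate_output can forward them.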
    Tensor* d_x_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"x"}, "d_x", TensorShape({batch_size, input_size}),
                            &d_x_tensor));

    Tensor* d_h_prev_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"h_prev"}, "d_h_prev", TensorShape({batch_size, cell_size}),
                 &d_h_prev_tensor));

    Tensor* d_c_bar_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(
                            "d_c_bar", TensorShape({batch_size, cell_size}),
                            &d_c_bar_tensor));

    Tensor* d_r_bar_u_bar_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("d_r_bar_u_bar",
                                  TensorShape({batch_size, 2 * cell_size}),
                                  &d_r_bar_u_bar_tensor));

    // Allocate temp tensors.
    Tensor d_r_bar_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &d_r_bar_tensor));

    Tensor d_u_bar_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &d_u_bar_tensor));

    Tensor d_h_prevr_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &d_h_prevr_tensor));

    Tensor d_x_component_1_h_prev_component_1;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &d_x_component_1_h_prev_component_1));

    Tensor d_x_component_2_h_prevr;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &d_x_component_2_h_prevr));

    const Device& device = ctx->eigen_device<Device>();

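    // Fused GRU backward pass (standard GRU gradient algebra for the forward
    // equations above). The pre-activation gradients d_r_bar_u_bar and d_c_bar
    // are emitted as outputs so the weight and bias gradients can be formed
    // outside this kernel:
    //   d_c_bar = d_h .* (1 - u) .* (1 - c^2)
    //   d_u_bar = d_h .* (h_prev - c) .* u .* (1 - u)
    //   [d_x_component_1, d_h_prev_component_1] = d_r_bar_u_bar * w_ru^T
    //   [d_x_component_2, d_h_prevr] = d_c_bar * w_c^T
    //   d_r_bar = (d_h_prevr .* h_prev) .* r .* (1 - r)
    //   d_x = d_x_component_1 + d_x_component_2
    //   d_h_prev = d_h_prev_component_1 + d_h_prevr .* r + d_h .* u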
    functor::GRUBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                      cell_size)(
        ctx, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
        w_ru_tensor->matrix<T>(), w_c_tensor->matrix<T>(),
        b_ru_tensor->vec<T>(), b_c_tensor->vec<T>(), r_tensor->matrix<T>(),
        u_tensor->matrix<T>(), c_tensor->matrix<T>(), d_h_tensor->matrix<T>(),
        d_x_tensor->matrix<T>(), d_h_prev_tensor->matrix<T>(),
        d_c_bar_tensor->matrix<T>(), d_r_bar_u_bar_tensor->matrix<T>(),
        d_r_bar_tensor.matrix<T>(), d_u_bar_tensor.matrix<T>(),
        d_h_prevr_tensor.matrix<T>(),
        d_x_component_1_h_prev_component_1.matrix<T>(),
        d_x_component_2_h_prevr.matrix<T>());
  }
};

// Register the gradient kernel for CPU.
#define REGISTER_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("GRUBlockCellGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      GRUBlockCellGradOp<CPUDevice, T, false>);

REGISTER_KERNEL(float);
#undef REGISTER_KERNEL

// GPU support.
#if GOOGLE_CUDA
#define EIGEN_USE_GPU

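// The specializations declared below are instantiated in the corresponding
// CUDA translation unit compiled by nvcc; the extern template declarations
// let this file link against those definitions without recompiling the
// functor bodies here.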
// Forward declare the GPU Fprop functor.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                   \
  template <>                                                                 \
  void GRUBlockCellFprop<GPUDevice, T, true>::operator()(                     \
      OpKernelContext* ctx, const GPUDevice& d,                               \
      typename TTypes<T>::ConstMatrix x,                                      \
      typename TTypes<T>::ConstMatrix h_prev,                                 \
      typename TTypes<T>::ConstMatrix w_ru,                                   \
      typename TTypes<T>::ConstMatrix w_c, typename TTypes<T>::ConstVec b_ru, \
      typename TTypes<T>::ConstVec b_c, typename TTypes<T>::Matrix r_u_bar,   \
      typename TTypes<T>::Matrix r, typename TTypes<T>::Matrix u,             \
      typename TTypes<T>::Matrix c, typename TTypes<T>::Matrix h,             \
      typename TTypes<T>::Matrix x_h_prev,                                    \
      typename TTypes<T>::Matrix x_h_prevr);                                  \
  extern template struct GRUBlockCellFprop<GPUDevice, T, true>;

DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

// Register the Block GRU cell kernel for GPU.
#define REGISTER_GPU_KERNEL(T)                                        \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("GRUBlockCell").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      GRUCellBlockOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
#undef REGISTER_GPU_KERNEL

// Forward declare the GPU Bprop functor.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                    \
  template <>                                                                  \
  void GRUBlockCellBprop<GPUDevice, T, true>::operator()(                      \
      OpKernelContext* ctx, const GPUDevice& d,                                \
      typename TTypes<T>::ConstMatrix x, typename TTypes<T>::ConstMatrix h,    \
      typename TTypes<T>::ConstMatrix w_ru,                                    \
      typename TTypes<T>::ConstMatrix w_c, typename TTypes<T>::ConstVec b_ru,  \
      typename TTypes<T>::ConstVec b_c, typename TTypes<T>::ConstMatrix r,     \
      typename TTypes<T>::ConstMatrix u, typename TTypes<T>::ConstMatrix c,    \
      typename TTypes<T>::ConstMatrix d_h, typename TTypes<T>::Matrix d_x,     \
      typename TTypes<T>::Matrix d_h_prev, typename TTypes<T>::Matrix d_c_bar, \
      typename TTypes<T>::Matrix d_r_bar_u_bar,                                \
      typename TTypes<T>::Matrix d_r_bar, typename TTypes<T>::Matrix d_u_bar,  \
      typename TTypes<T>::Matrix d_h_prevr,                                    \
      typename TTypes<T>::Matrix d_x_comp1_h_prev_comp1,                       \
      typename TTypes<T>::Matrix d_x_comp2_and_h_prevr);                       \
  extern template struct GRUBlockCellBprop<GPUDevice, T, true>;

DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

// Register the gradient kernel for GPU.
#define REGISTER_GPU_KERNEL(T)                                            \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("GRUBlockCellGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      GRUBlockCellGradOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow