/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
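//
// ReverseSequence reverses, for each batch entry b, the first seq_lens(b)
// elements of the input along seq_dim, leaving the remainder untouched.
// Illustrative Python call (a sketch of the era's public API, not a contract
// of the exact signature):
//   tf.reverse_sequence(input, seq_lengths, seq_dim=1, batch_dim=0)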

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/reverse_sequence_op.h"

#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

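// Validates the attrs and the seq_lens input before the kernel runs. This
// generic version copies seq_lens back to the host so that each value can be
// bounds-checked against input.dim_size(seq_dim).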
template <typename Device, typename Tlen>
void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
  const Tensor& input = context->input(0);
  const Tensor& seq_lens = context->input(1);

  auto seq_lens_t = seq_lens.vec<Tlen>();

  std::vector<Tlen> seq_lens_vec(seq_lens_t.size());

  // Copy seq_len info down for validity checks
  context->eigen_device<Device>().memcpyDeviceToHost(
      seq_lens_vec.data(), seq_lens_t.data(), sizeof(Tlen) * seq_lens_t.size());

  OP_REQUIRES(context, batch_dim != seq_dim,
              errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
  OP_REQUIRES(context, seq_dim < input.dims(),
              errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
                                      seq_dim, " vs. ", input.dims(), ")"));
  OP_REQUIRES(context, batch_dim < input.dims(),
              errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
                                      batch_dim, " vs. ", input.dims(), ")"));
  OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
              errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
                                      "), ", "(", seq_lens.NumElements(),
                                      " vs. ", input.dim_size(batch_dim), ")"));

  for (size_t d = 0; d < seq_lens_vec.size(); ++d) {
    OP_REQUIRES(context, seq_lens_vec[d] >= 0,
                errors::InvalidArgument("seq_lens(", d, ") < 0"));
    OP_REQUIRES(context, seq_lens_vec[d] <= input.dim_size(seq_dim),
                errors::InvalidArgument("seq_lens(", d, ") > input.dims(",
                                        seq_dim, ")"));
  }
}

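// GPU variant: seq_lens resides in device memory, so only the shape and
// dimension checks are performed here; the per-element bounds checks above
// are skipped.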
void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) {
  const Tensor& input = context->input(0);
  const Tensor& seq_lens = context->input(1);

  OP_REQUIRES(context, batch_dim != seq_dim,
              errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
  OP_REQUIRES(context, seq_dim < input.dims(),
              errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
                                      seq_dim, " vs. ", input.dims(), ")"));
  OP_REQUIRES(context, batch_dim < input.dims(),
              errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
                                      batch_dim, " vs. ", input.dims(), ")"));

  OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
              errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
                                      "), ", "(", seq_lens.NumElements(),
                                      " vs. ", input.dim_size(batch_dim), ")"));
}

template <>
void CheckErrors<GPUDevice, int32>(OpKernelContext* context, int batch_dim,
                                   int seq_dim) {
  CheckErrorsGPU(context, batch_dim, seq_dim);
}

template <>
void CheckErrors<GPUDevice, int64>(OpKernelContext* context, int batch_dim,
                                   int seq_dim) {
  CheckErrorsGPU(context, batch_dim, seq_dim);
}

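// Op kernel that reverses variable-length slices: for each index b along
// batch_dim, the first seq_lens(b) elements along seq_dim are reversed.
// T is the element type; Tlen is the integer type of the seq_lens input.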
template <typename Device, typename T, typename Tlen>
class ReverseSequenceOp : public OpKernel {
 public:
  explicit ReverseSequenceOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_));
    OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& seq_lens = context->input(1);

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens.shape()),
                errors::InvalidArgument("seq_lens input must be 1-dim, not ",
                                        seq_lens.dims()));

    auto seq_lens_t = seq_lens.vec<Tlen>();

    CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_);
    if (!context->status().ok()) return;

    const int input_dims = input.dims();

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

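// Expands to one switch case that invokes the functor specialized for the
// static rank NDIM.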
#define HANDLE_DIM(NDIM)                                                      \
  case NDIM:                                                                  \
    functor::ReverseSequence<Device, T, Tlen, NDIM>::Compute(                 \
        context->eigen_device<Device>(), input.tensor<T, NDIM>(), batch_dim_, \
        seq_dim_, seq_lens_t, output->tensor<T, NDIM>());                     \
    break;

    switch (input_dims) {
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);

      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "ReverseSequenceOp : Unhandled input dimensions: ",
                        input_dims));
    }
#undef HANDLE_DIM
  }

 private:
  int32 batch_dim_;
  int32 seq_dim_;

  TF_DISALLOW_COPY_AND_ASSIGN(ReverseSequenceOp);
};

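// Registration of the CPU implementations.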
#define REGISTER_REVERSE_SEQUENCE(type, len_type)                \
  REGISTER_KERNEL_BUILDER(Name("ReverseSequence")                \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<len_type>("Tlen"), \
                          ReverseSequenceOp<CPUDevice, type, len_type>);

#define REGISTER_REVERSE_SEQUENCE_LEN(type) \
  REGISTER_REVERSE_SEQUENCE(type, int32);   \
  REGISTER_REVERSE_SEQUENCE(type, int64);

TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_LEN);
TF_CALL_bool(REGISTER_REVERSE_SEQUENCE_LEN);

#undef REGISTER_REVERSE_SEQUENCE_LEN
#undef REGISTER_REVERSE_SEQUENCE

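// The GPU declarations and registrations below are compiled only when CUDA
// support is enabled.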
#if GOOGLE_CUDA

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, Tlen, Dims)                                \
  template <>                                                          \
  void ReverseSequence<GPUDevice, T, Tlen, Dims>::Compute(             \
      const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
      int32 batch_dim, int32 seq_dim,                                  \
      typename TTypes<Tlen>::ConstVec seq_lens,                        \
      typename TTypes<T, Dims>::Tensor output);                        \
  extern template struct ReverseSequence<GPUDevice, T, Tlen, Dims>;

#define DECLARE_GPU_SPEC_LEN(T, Dims) \
  DECLARE_GPU_SPEC(T, int32, Dims);   \
  DECLARE_GPU_SPEC(T, int64, Dims);

#define DECLARE_GPU_SPECS(T)  \
  DECLARE_GPU_SPEC_LEN(T, 2); \
  DECLARE_GPU_SPEC_LEN(T, 3); \
  DECLARE_GPU_SPEC_LEN(T, 4); \
  DECLARE_GPU_SPEC_LEN(T, 5);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_bool(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPEC_LEN
#undef DECLARE_GPU_SPEC

}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_REVERSE_SEQUENCE_GPU(type, len_type)            \
  REGISTER_KERNEL_BUILDER(Name("ReverseSequence")                \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<len_type>("Tlen"), \
                          ReverseSequenceOp<GPUDevice, type, len_type>);

#define REGISTER_REVERSE_SEQUENCE_GPU_LEN(type) \
  REGISTER_REVERSE_SEQUENCE_GPU(type, int32);   \
  REGISTER_REVERSE_SEQUENCE_GPU(type, int64);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_GPU_LEN);
TF_CALL_bool(REGISTER_REVERSE_SEQUENCE_GPU_LEN);

#undef REGISTER_REVERSE_SEQUENCE_GPU_LEN
#undef REGISTER_REVERSE_SEQUENCE_GPU

#endif  // GOOGLE_CUDA

}  // namespace tensorflow