/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/reverse_sequence_op.h"

#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

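// Validates the batch_dim/seq_dim attributes and the seq_lens input against
// the shape of the main input. This generic version copies seq_lens back to
// the host so that each length can be range-checked individually.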
template <typename Device, typename Tlen>
void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
  const Tensor& input = context->input(0);
  const Tensor& seq_lens = context->input(1);

  auto seq_lens_t = seq_lens.vec<Tlen>();

  std::vector<Tlen> seq_lens_vec(seq_lens_t.size());

  // Copy seq_len info down for validity checks
  context->eigen_device<Device>().memcpyDeviceToHost(
      seq_lens_vec.data(), seq_lens_t.data(),
      sizeof(Tlen) * seq_lens_t.size());

  OP_REQUIRES(context, batch_dim != seq_dim,
              errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
  OP_REQUIRES(context, seq_dim < input.dims(),
              errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
                                      seq_dim, " vs. ", input.dims(), ")"));
  OP_REQUIRES(context, batch_dim < input.dims(),
              errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
                                      batch_dim, " vs. ", input.dims(), ")"));
  OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
              errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
                                      "), ", "(", seq_lens.NumElements(),
                                      " vs. ", input.dim_size(batch_dim), ")"));

  for (size_t d = 0; d < seq_lens_vec.size(); ++d) {
    OP_REQUIRES(context, seq_lens_vec[d] >= 0,
                errors::InvalidArgument("seq_lens(", d, ") < 0"));
    OP_REQUIRES(context, seq_lens_vec[d] <= input.dim_size(seq_dim),
                errors::InvalidArgument("seq_lens(", d, ") > input.dims(",
                                        seq_dim, ")"));
  }
}

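// GPU variant: seq_lens stays in device memory, so only the shape and
// attribute checks run here; the per-element length checks are skipped to
// avoid a device-to-host copy.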
void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) {
  const Tensor& input = context->input(0);
  const Tensor& seq_lens = context->input(1);

  OP_REQUIRES(context, batch_dim != seq_dim,
              errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
  OP_REQUIRES(context, seq_dim < input.dims(),
              errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
                                      seq_dim, " vs. ", input.dims(), ")"));
  OP_REQUIRES(context, batch_dim < input.dims(),
              errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
                                      batch_dim, " vs. ", input.dims(), ")"));

  OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
              errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
                                      "), ", "(", seq_lens.NumElements(),
                                      " vs. ", input.dim_size(batch_dim), ")"));
}

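// Route the GPU instantiations of CheckErrors to the GPU-safe variant above.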
template <>
void CheckErrors<GPUDevice, int32>(OpKernelContext* context, int batch_dim,
                                   int seq_dim) {
  CheckErrorsGPU(context, batch_dim, seq_dim);
}

template <>
void CheckErrors<GPUDevice, int64>(OpKernelContext* context, int batch_dim,
                                   int seq_dim) {
  CheckErrorsGPU(context, batch_dim, seq_dim);
}

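// Reverses variable-length slices of the input along seq_dim; the slice
// length for each index along batch_dim is given by the seq_lens input.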
template <typename Device, typename T, typename Tlen>
class ReverseSequenceOp : public OpKernel {
 public:
  explicit ReverseSequenceOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_));
    OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& seq_lens = context->input(1);

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens.shape()),
                errors::InvalidArgument("seq_lens input must be 1-dim, not ",
                                        seq_lens.dims()));

    auto seq_lens_t = seq_lens.vec<Tlen>();

    CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_);
    if (!context->status().ok()) return;

    const int input_dims = input.dims();

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

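// HANDLE_DIM expands to a switch case that dispatches to the functor
// specialization matching the input's rank.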
#define HANDLE_DIM(NDIM)                                                      \
  case NDIM:                                                                  \
    functor::ReverseSequence<Device, T, Tlen, NDIM>::Compute(                 \
        context->eigen_device<Device>(), input.tensor<T, NDIM>(), batch_dim_, \
        seq_dim_, seq_lens_t, output->tensor<T, NDIM>());                     \
    break;

    switch (input_dims) {
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);

      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "ReverseSequenceOp : Unhandled input dimensions: ",
                        input_dims));
    }
  }

 private:
  int32 batch_dim_;
  int32 seq_dim_;

  TF_DISALLOW_COPY_AND_ASSIGN(ReverseSequenceOp);
};

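// Registration of the CPU implementations.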
#define REGISTER_REVERSE_SEQUENCE(type, len_type)                  \
  REGISTER_KERNEL_BUILDER(Name("ReverseSequence")                  \
                              .Device(DEVICE_CPU)                  \
                              .TypeConstraint<type>("T")           \
                              .TypeConstraint<len_type>("Tlen"),   \
                          ReverseSequenceOp<CPUDevice, type, len_type>);

#define REGISTER_REVERSE_SEQUENCE_LEN(type) \
  REGISTER_REVERSE_SEQUENCE(type, int32);   \
  REGISTER_REVERSE_SEQUENCE(type, int64);

TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_LEN);
TF_CALL_bool(REGISTER_REVERSE_SEQUENCE_LEN);

#if GOOGLE_CUDA

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, Tlen, Dims)                                \
  template <>                                                          \
  void ReverseSequence<GPUDevice, T, Tlen, Dims>::Compute(             \
      const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
      int32 batch_dim, int32 seq_dim,                                  \
      typename TTypes<Tlen>::ConstVec seq_lens,                        \
      typename TTypes<T, Dims>::Tensor output);                        \
  extern template struct ReverseSequence<GPUDevice, T, Tlen, Dims>;

#define DECLARE_GPU_SPEC_LEN(T, Dims) \
  DECLARE_GPU_SPEC(T, int32, Dims);   \
  DECLARE_GPU_SPEC(T, int64, Dims);

#define DECLARE_GPU_SPECS(T)  \
  DECLARE_GPU_SPEC_LEN(T, 2); \
  DECLARE_GPU_SPEC_LEN(T, 3); \
  DECLARE_GPU_SPEC_LEN(T, 4); \
  DECLARE_GPU_SPEC_LEN(T, 5);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_bool(DECLARE_GPU_SPECS);

}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_REVERSE_SEQUENCE_GPU(type, len_type)              \
  REGISTER_KERNEL_BUILDER(Name("ReverseSequence")                  \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<type>("T")           \
                              .TypeConstraint<len_type>("Tlen"),   \
                          ReverseSequenceOp<GPUDevice, type, len_type>);

#define REGISTER_REVERSE_SEQUENCE_GPU_LEN(type) \
  REGISTER_REVERSE_SEQUENCE_GPU(type, int32);   \
  REGISTER_REVERSE_SEQUENCE_GPU(type, int64);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_GPU_LEN);
TF_CALL_bool(REGISTER_REVERSE_SEQUENCE_GPU_LEN);

#undef REGISTER_REVERSE_SEQUENCE_GPU

#endif  // GOOGLE_CUDA

}  // namespace tensorflow