/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#include <memory>
#include <string>
#include <utility>

#include "tensorflow/core/kernels/spacetobatch_functor.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
static void BatchToSpaceOpCompute(OpKernelContext* context,
                                  const Tensor& orig_input_tensor,
                                  const Tensor& orig_block_shape,
                                  const Tensor& orig_crops) {
  const int input_dims = orig_input_tensor.dims();
  OP_REQUIRES(
      context, TensorShapeUtils::IsVector(orig_block_shape.shape()),
      errors::InvalidArgument("block_shape rank should be 1 instead of ",
                              orig_block_shape.dims()));

  const int block_dims = orig_block_shape.dim_size(0);
  OP_REQUIRES(
      context, orig_input_tensor.dims() >= 1 + block_dims,
      errors::InvalidArgument("input rank should be >= ", 1 + block_dims,
                              " instead of ", orig_input_tensor.dims()));

  OP_REQUIRES(context,
              TensorShapeUtils::IsMatrix(orig_crops.shape()) &&
                  block_dims == orig_crops.dim_size(0) &&
                  2 == orig_crops.dim_size(1),
              errors::InvalidArgument("crops should have shape [", block_dims,
                                      ", 2] instead of ",
                                      orig_crops.shape().DebugString()));
  // To avoid out-of-bounds access in the case that the block_shape and/or
  // crops tensors are concurrently modified, we must copy the values.
  gtl::InlinedVector<int64, 4> block_shape;
  gtl::InlinedVector<int64, 8> crops;
  internal::spacetobatch::SubtleMustCopyFlat(orig_block_shape, &block_shape);
  internal::spacetobatch::SubtleMustCopyFlat(orig_crops, &crops);

  // Determine the length of the prefix of block dims that can be combined
  // into the batch dimension due to having no padding and block_shape=1.
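  // For example, with block_shape = [1, 2, 2] and zero crops in the first
  // block dimension, that leading dimension is a no-op and can be folded
  // into the batch dimension, reducing the rank handled by the functor.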
  int removed_prefix_block_dims = 0;
  for (; removed_prefix_block_dims < block_dims; ++removed_prefix_block_dims) {
    const int dim = removed_prefix_block_dims;
    if (crops[2 * dim] != 0 || crops[2 * dim + 1] != 0 ||
        block_shape[dim] != 1) {
      break;
    }
  }

  // Determine the length of the suffix of block dims that can be combined
  // into the depth dimension due to having no padding and block_shape=1.
  int removed_suffix_block_dims = 0;
  for (; removed_suffix_block_dims < block_dims - removed_prefix_block_dims;
       ++removed_suffix_block_dims) {
    const int dim = block_dims - 1 - removed_suffix_block_dims;
    if (crops[2 * dim] != 0 || crops[2 * dim + 1] != 0 ||
        block_shape[dim] != 1) {
      break;
    }
  }

  // Compute the product of the block_shape values.
  int64 block_shape_product = 1;
  for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
    block_shape_product *= block_shape[block_dim];
  }
  OP_REQUIRES(
      context, block_shape_product > 0,
      errors::InvalidArgument("Product of block sizes must be positive, got ",
                              block_shape_product));

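  // Each output batch element is reassembled from block_shape_product input
  // batch elements, so the input batch dimension must divide evenly.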
  const int64 orig_input_batch_size = orig_input_tensor.dim_size(0);
  OP_REQUIRES(
      context, orig_input_batch_size % block_shape_product == 0,
      errors::InvalidArgument("Input batch dimension (", orig_input_batch_size,
                              ") is not divisible by product of block sizes (",
                              block_shape_product, ")"));

  const int internal_block_dims =
      block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
  OP_REQUIRES(context, internal_block_dims <= kMaxSpaceToBatchBlockDims,
              errors::InvalidArgument(
                  "Number of non-combined block dimensions (",
                  internal_block_dims, ") must not exceed the maximum of ",
                  kMaxSpaceToBatchBlockDims));

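  // If every block dimension was folded away (block_shape all ones and no
  // crops), the op is an identity and the input can be forwarded as-is.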
  if (internal_block_dims == 0) {
    context->set_output(0, orig_input_tensor);
    return;
  }

  // For the purpose of computing the result, the input will be treated as
  // having this shape, of rank 2 + internal_block_dims.
  TensorShape internal_input_shape;

  // For the purpose of computing the result, the output will be treated as
  // having this shape, of rank 2 + internal_block_dims.
  TensorShape internal_output_shape;

  // The actual output shape exposed to callers.
  TensorShape external_output_shape;
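  // Note that removed prefix/suffix block dims remain separate dimensions in
  // external_output_shape, while the internal shapes fold them into the
  // batch and depth dimensions, respectively.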

  external_output_shape.AddDim(orig_input_batch_size / block_shape_product);

  int64 input_batch_size = orig_input_batch_size;
  for (int block_dim = 0; block_dim < removed_prefix_block_dims; ++block_dim) {
    const int64 size = orig_input_tensor.dim_size(block_dim + 1);
    input_batch_size *= size;
    external_output_shape.AddDim(size);
  }
  internal_input_shape.AddDim(input_batch_size);
  internal_output_shape.AddDim(input_batch_size / block_shape_product);

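  // For each non-combined block dimension, the output spatial size expands by
  // the block factor and is then trimmed by the crops:
  //   cropped_size = input_size * block_shape_value - crop_start - crop_end.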
  for (int block_dim = removed_prefix_block_dims;
       block_dim < block_dims - removed_suffix_block_dims; ++block_dim) {
    const int64 crop_start = crops[2 * block_dim],
                crop_end = crops[2 * block_dim + 1];
    OP_REQUIRES(context, crop_start >= 0 && crop_end >= 0,
                errors::InvalidArgument("Crops must be non-negative"));
    const int64 input_size = orig_input_tensor.dim_size(block_dim + 1);
    const int64 block_shape_value = block_shape[block_dim];
    const int64 cropped_size =
        input_size * block_shape_value - crop_start - crop_end;
    OP_REQUIRES(context, cropped_size >= 0,
                errors::InvalidArgument("cropped_shape[", block_dim, "]=",
                                        cropped_size, " must be non-negative"));
    internal_input_shape.AddDim(input_size);
    internal_output_shape.AddDim(cropped_size);
    external_output_shape.AddDim(cropped_size);
  }

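  // All remaining trailing dimensions (including any removed suffix block
  // dims) are collapsed into a single depth dimension for the internal
  // computation, but kept as separate dimensions in the external output.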
  int64 depth = 1;
  for (int dim = block_dims - removed_suffix_block_dims + 1; dim < input_dims;
       ++dim) {
    const int64 size = orig_input_tensor.dim_size(dim);
    external_output_shape.AddDim(size);
    depth *= size;
  }
  internal_input_shape.AddDim(depth);
  internal_output_shape.AddDim(depth);

  // Allocate output tensor.
  Tensor* output_tensor = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(0, external_output_shape,
                                                   &output_tensor));

  const int64* internal_crops = &crops[2 * removed_prefix_block_dims];
  const int64* internal_block_shape = &block_shape[removed_prefix_block_dims];

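  // Dispatch to a functor specialization with the number of non-combined
  // block dims fixed at compile time. The shared space-to-batch functor runs
  // in its batch-to-space mode here (the trailing `true` template argument).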
  switch (internal_block_dims) {
#define TF_BATCHTOSPACE_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS)                   \
  case NUM_BLOCK_DIMS: {                                                  \
    OP_REQUIRES_OK(                                                       \
        context,                                                          \
        (functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, true>()( \
            context->eigen_device<Device>(),                              \
            output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>(                 \
                internal_output_shape.dim_sizes()),                       \
            internal_block_shape, internal_crops,                         \
            orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>(              \
                internal_input_shape.dim_sizes()))));                     \
  } break;                                                                \
  /**/
    TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(TF_BATCHTOSPACE_BLOCK_DIMS_CASE)
#undef TF_BATCHTOSPACE_BLOCK_DIMS_CASE
  }
}

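// BatchToSpaceND kernel: rearranges elements from the batch dimension into
// spatial blocks, the inverse of SpaceToBatchND. For example, a [4, 1, 1, 1]
// input with block_shape = [2, 2] and zero crops yields a [1, 2, 2, 1]
// output.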
template <typename Device, typename T>
class BatchToSpaceNDOp : public OpKernel {
 public:
  explicit BatchToSpaceNDOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& orig_input_tensor = context->input(0);
    const Tensor& orig_block_shape = context->input(1);
    const Tensor& orig_crops = context->input(2);
    BatchToSpaceOpCompute<Device, T>(context, orig_input_tensor,
                                     orig_block_shape, orig_crops);
  }
};

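// Legacy 4-D BatchToSpace kernel: takes a scalar block_size attribute rather
// than a block_shape input, and forwards to the N-D implementation with
// block_shape = [block_size, block_size].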
template <typename Device, typename T>
class BatchToSpaceOp : public OpKernel {
 public:
  explicit BatchToSpaceOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        context, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
    // We don't use context->allocate_persistent because the allocation must
    // happen on the CPU regardless of Device.
    block_shape_ = Tensor(tensorflow::DT_INT64, TensorShape({2}));
    auto block_shape_vec = block_shape_.vec<int64>();
    block_shape_vec(0) = block_size_;
    block_shape_vec(1) = block_size_;
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& in0 = context->input(0);
    const Tensor& in1 = context->input(1);
    const int dims = in0.dims();

    // Check on the input dimensions first.
    // The input is presumed to be [batch, height, width, depth]
    static const int kRequiredDims = 4;
    OP_REQUIRES(context, kRequiredDims == dims,
                errors::InvalidArgument("Input rank should be: ", kRequiredDims,
                                        " instead of: ", dims));
    BatchToSpaceOpCompute<Device, T>(context, in0, block_shape_, in1);
  }

 private:
  int block_size_;
  Tensor block_shape_;
};

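// Kernel registrations. block_shape and crops are pinned to host memory
// because their values are read on the CPU to compute the output shape.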
#define REGISTER(T)                                        \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND")           \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("block_shape")   \
                              .HostMemory("crops"),        \
                          BatchToSpaceNDOp<CPUDevice, T>); \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpace")             \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("crops"),        \
                          BatchToSpaceOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER);
#undef REGISTER

#if GOOGLE_CUDA
#define REGISTER(T)                                        \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND")           \
                              .Device(DEVICE_GPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("block_shape")   \
                              .HostMemory("crops"),        \
                          BatchToSpaceNDOp<GPUDevice, T>); \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpace")             \
                              .Device(DEVICE_GPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("crops"),        \
                          BatchToSpaceOp<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER);
#undef REGISTER
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow