/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#include <memory>
#include <string>
#include <utility>

#include "tensorflow/core/kernels/spacetobatch_functor.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

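// Shared implementation of BatchToSpace and BatchToSpaceND. Conceptually, the
// batch dimension is split back into spatial blocks: the output batch size is
// input_batch / prod(block_shape), and each cropped spatial dimension has
// size input_size * block_shape[i] - crops[i, 0] - crops[i, 1]. For example,
// an input of shape [4, 1, 1, 1] with block_shape = [2, 2] and zero crops
// produces an output of shape [1, 2, 2, 1].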
template <typename Device, typename T>
static void BatchToSpaceOpCompute(OpKernelContext* context,
                                  const Tensor& orig_input_tensor,
                                  const Tensor& orig_block_shape,
                                  const Tensor& orig_crops) {
  const int input_dims = orig_input_tensor.dims();
  OP_REQUIRES(
      context, TensorShapeUtils::IsVector(orig_block_shape.shape()),
      errors::InvalidArgument("block_shape rank should be 1 instead of ",
                              orig_block_shape.dims()));

  const int block_dims = orig_block_shape.dim_size(0);
  OP_REQUIRES(
      context, orig_input_tensor.dims() >= 1 + block_dims,
      errors::InvalidArgument("input rank should be >= ", 1 + block_dims,
                              " instead of ", orig_input_tensor.dims()));

  OP_REQUIRES(context,
              TensorShapeUtils::IsMatrix(orig_crops.shape()) &&
                  block_dims == orig_crops.dim_size(0) &&
                  2 == orig_crops.dim_size(1),
              errors::InvalidArgument("crops should have shape [", block_dims,
                                      ", 2] instead of ",
                                      orig_crops.shape().DebugString()));
  // To avoid out-of-bounds access in the case that the block_shape and/or
  // crops tensors are concurrently modified, we must copy the values.
  gtl::InlinedVector<int64, 4> block_shape;
  gtl::InlinedVector<int64, 8> crops;
  internal::spacetobatch::SubtleMustCopyFlat(orig_block_shape, &block_shape);
  internal::spacetobatch::SubtleMustCopyFlat(orig_crops, &crops);

  // Determine the length of the prefix of block dims that can be combined
  // into the batch dimension due to having no padding and block_shape=1.
  int removed_prefix_block_dims = 0;
  for (; removed_prefix_block_dims < block_dims; ++removed_prefix_block_dims) {
    const int dim = removed_prefix_block_dims;
    if (crops[2 * dim] != 0 || crops[2 * dim + 1] != 0 ||
        block_shape[dim] != 1) {
      break;
    }
  }

  // Determine the length of the suffix of block dims that can be combined
  // into the depth dimension due to having no padding and block_shape=1.
  int removed_suffix_block_dims = 0;
  for (; removed_suffix_block_dims < block_dims - removed_prefix_block_dims;
       ++removed_suffix_block_dims) {
    const int dim = block_dims - 1 - removed_suffix_block_dims;
    if (crops[2 * dim] != 0 || crops[2 * dim + 1] != 0 ||
        block_shape[dim] != 1) {
      break;
    }
  }

  // Compute the product of the block_shape values.
  int64 block_shape_product = 1;
  for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
    block_shape_product *= block_shape[block_dim];
  }
  OP_REQUIRES(
      context, block_shape_product > 0,
      errors::InvalidArgument("Product of block sizes must be positive, got ",
                              block_shape_product));

  const int64 orig_input_batch_size = orig_input_tensor.dim_size(0);
  OP_REQUIRES(
      context, orig_input_batch_size % block_shape_product == 0,
      errors::InvalidArgument("Input batch dimension (", orig_input_batch_size,
                              ") is not divisible by product of block sizes (",
                              block_shape_product, ")"));

  const int internal_block_dims =
      block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
  OP_REQUIRES(context, internal_block_dims <= kMaxSpaceToBatchBlockDims,
              errors::InvalidArgument(
                  "Number of non-combined block dimensions is ",
                  internal_block_dims, " but must not exceed ",
                  kMaxSpaceToBatchBlockDims));

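  // If every block dimension was combined away (block_shape of 1 and no
  // cropping anywhere), the op reduces to an identity, so simply forward the
  // input tensor as the output.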
  if (internal_block_dims == 0) {
    context->set_output(0, orig_input_tensor);
    return;
  }

  // For the purpose of computing the result, the input will be treated as
  // having this shape, of rank 2 + internal_block_dims.
  TensorShape internal_input_shape;

  // For the purpose of computing the result, the output will be treated as
  // having this shape, of rank 2 + internal_block_dims.
  TensorShape internal_output_shape;

  // The actual output shape exposed to callers.
  TensorShape external_output_shape;

  external_output_shape.AddDim(orig_input_batch_size / block_shape_product);

  int64 input_batch_size = orig_input_batch_size;
  for (int block_dim = 0; block_dim < removed_prefix_block_dims; ++block_dim) {
    const int64 size = orig_input_tensor.dim_size(block_dim + 1);
    input_batch_size *= size;
    external_output_shape.AddDim(size);
  }
  internal_input_shape.AddDim(input_batch_size);
  internal_output_shape.AddDim(input_batch_size / block_shape_product);

  for (int block_dim = removed_prefix_block_dims;
       block_dim < block_dims - removed_suffix_block_dims; ++block_dim) {
    const int64 crop_start = crops[2 * block_dim],
                crop_end = crops[2 * block_dim + 1];
    OP_REQUIRES(context, crop_start >= 0 && crop_end >= 0,
                errors::InvalidArgument("Crops must be non-negative"));
    const int64 input_size = orig_input_tensor.dim_size(block_dim + 1);
    const int64 block_shape_value = block_shape[block_dim];
    const int64 cropped_size =
        input_size * block_shape_value - crop_start - crop_end;
    OP_REQUIRES(context, cropped_size >= 0,
                errors::InvalidArgument("cropped_shape[", block_dim, "]=",
                                        cropped_size, " must be non-negative"));
    internal_input_shape.AddDim(input_size);
    internal_output_shape.AddDim(cropped_size);
    external_output_shape.AddDim(cropped_size);
  }

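  // Fold the removed suffix block dimensions and all remaining trailing
  // dimensions into a single depth dimension; none of them is affected by the
  // batch-to-space rearrangement.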
  int64 depth = 1;
  for (int dim = block_dims - removed_suffix_block_dims + 1; dim < input_dims;
       ++dim) {
    const int64 size = orig_input_tensor.dim_size(dim);
    external_output_shape.AddDim(size);
    depth *= size;
  }
  internal_input_shape.AddDim(depth);
  internal_output_shape.AddDim(depth);

  // Allocate output tensor.
  Tensor* output_tensor = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(0, external_output_shape,
                                                   &output_tensor));

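  // The functor only sees the non-combined block dimensions, so point it at
  // the corresponding slices of the copied crops and block_shape values.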
  const int64* internal_crops = &crops[2 * removed_prefix_block_dims];
  const int64* internal_block_shape = &block_shape[removed_prefix_block_dims];

  switch (internal_block_dims) {
#define TF_BATCHTOSPACE_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS)                   \
  case NUM_BLOCK_DIMS: {                                                  \
    OP_REQUIRES_OK(                                                       \
        context,                                                          \
        (functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, true>()( \
            context->eigen_device<Device>(),                              \
            output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>(                 \
                internal_output_shape.dim_sizes()),                       \
            internal_block_shape, internal_crops,                         \
            orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>(              \
                internal_input_shape.dim_sizes()))));                     \
  } break;                                                                \
    /**/
    TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(TF_BATCHTOSPACE_BLOCK_DIMS_CASE)
#undef TF_BATCHTOSPACE_BLOCK_DIMS_CASE
  }
}

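// BatchToSpaceND: block_shape and crops arrive as input tensors, so an
// arbitrary number of block dimensions is supported.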
template <typename Device, typename T>
class BatchToSpaceNDOp : public OpKernel {
 public:
  explicit BatchToSpaceNDOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& orig_input_tensor = context->input(0);
    const Tensor& orig_block_shape = context->input(1);
    const Tensor& orig_crops = context->input(2);
    BatchToSpaceOpCompute<Device, T>(context, orig_input_tensor,
                                     orig_block_shape, orig_crops);
  }
};

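// Legacy BatchToSpace: takes a 4-D NHWC input and a scalar block_size
// attribute, which is expanded here into an equivalent [2]-element block_shape
// tensor so the shared implementation above can be reused.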
template <typename Device, typename T>
class BatchToSpaceOp : public OpKernel {
 public:
  explicit BatchToSpaceOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        context, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
    // We don't use context->allocate_persistent because the allocation must
    // happen on the CPU regardless of Device.
    block_shape_ = Tensor(tensorflow::DT_INT64, TensorShape({2}));
    auto block_shape_vec = block_shape_.vec<int64>();
    block_shape_vec(0) = block_size_;
    block_shape_vec(1) = block_size_;
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& in0 = context->input(0);
    const Tensor& in1 = context->input(1);
    const int dims = in0.dims();

    // Check on the input dimensions first.
    // The input is presumed to be [batch, height, width, depth]
    static const int kRequiredDims = 4;
    OP_REQUIRES(context, kRequiredDims == dims,
                errors::InvalidArgument("Input rank should be: ", kRequiredDims,
                                        " instead of: ", dims));
    BatchToSpaceOpCompute<Device, T>(context, in0, block_shape_, in1);
  }

 private:
  int block_size_;
  Tensor block_shape_;
};

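// Register kernels for all real number types. The block_shape and crops
// inputs are pinned to host memory because their values are read on the host
// to compute the output shape.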
#define REGISTER(T)                                        \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND")           \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("block_shape")   \
                              .HostMemory("crops"),        \
                          BatchToSpaceNDOp<CPUDevice, T>); \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpace")             \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("crops"),        \
                          BatchToSpaceOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER);
#undef REGISTER

#if GOOGLE_CUDA
#define REGISTER(T)                                        \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND")           \
                              .Device(DEVICE_GPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("block_shape")   \
                              .HostMemory("crops"),        \
                          BatchToSpaceNDOp<GPUDevice, T>); \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpace")             \
                              .Device(DEVICE_GPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("crops"),        \
                          BatchToSpaceOp<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER);
#undef REGISTER
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow