• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_SPACETOBATCH_FUNCTOR_H_
17 #define TENSORFLOW_CORE_KERNELS_SPACETOBATCH_FUNCTOR_H_
18 
19 #include <type_traits>
20 
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/bounds_check.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor_types.h"
25 #include "tensorflow/core/platform/types.h"
26 
27 namespace tensorflow {
28 
// Upper bound on the number of non-collapsible blocked dimensions that the
// {SpaceToBatch,BatchToSpace}ND operations support.  This value is mirrored by
// the explicit 1..N enumeration in TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS
// below; raising the limit means updating both places together.
constexpr int kMaxSpaceToBatchBlockDims{4};
34 
// Invokes MACRO once for every supported number of blocked dimensions,
// forwarding any extra arguments.  Expands to:
//   MACRO(1, ## __VA_ARGS__)
//   ...
//   MACRO(kMaxSpaceToBatchBlockDims, ## __VA_ARGS__)
//
// The numbers are spelled out literally (1 through 4), so the last entry must
// be kept in sync with kMaxSpaceToBatchBlockDims by hand.
//
// Note: each number is followed by an empty `/**/` comment.  The preprocessor
// replaces a comment with a single space, so there is always whitespace
// between the number and the comma; that separation is necessary for proper
// GCC comma handling with `, ##__VA_ARGS__` (the comma is dropped when no
// extra arguments are passed):
// https://gcc.gnu.org/onlinedocs/cpp/Variadic-Macros.html
// The lone `/**/` on the final line gives the last line-continuation backslash
// something harmless to continue onto, so every MACRO line can end in `\`.
#define TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(MACRO, ...) \
  MACRO(1 /**/, ##__VA_ARGS__)                              \
  MACRO(2 /**/, ##__VA_ARGS__)                              \
  MACRO(3 /**/, ##__VA_ARGS__)                              \
  MACRO(4 /**/, ##__VA_ARGS__)                              \
  /**/
48 
49 namespace internal {
50 namespace spacetobatch {
51 
52 template <typename InputType, typename OutputType>
SubtleMustCopyFlatHelper(const Tensor & t,OutputType * output)53 void SubtleMustCopyFlatHelper(const Tensor& t, OutputType* output) {
54   const int64 num_elements = t.shape().num_elements();
55   output->resize(num_elements);
56   auto eigen_vec = t.flat<InputType>();
57   for (int64 i = 0; i < num_elements; ++i) {
58     (*output)[i] = SubtleMustCopy(eigen_vec(i));
59   }
60 }
61 
62 // Copies flat contents of `t` to std::vector-like `*output`, which is resized
63 // as needed.  `OutputType` may be either `std::vector<int64>` or
64 // `gtl::InlinedVector<int64>`.
65 //
66 // Precondition: t.dtype() must be either DT_INT32 or DT_INT64.
67 template <typename OutputType>
SubtleMustCopyFlat(const Tensor & t,OutputType * output)68 void SubtleMustCopyFlat(const Tensor& t, OutputType* output) {
69   if (t.dtype() == DT_INT32) {
70     SubtleMustCopyFlatHelper<int32, OutputType>(t, output);
71   } else {
72     SubtleMustCopyFlatHelper<int64, OutputType>(t, output);
73   }
74 }
75 
76 }  // namespace spacetobatch
77 }  // namespace internal
78 
79 namespace functor {
80 
81 // Functor used by {SpaceToBatch,BatchToSpace}{ND,}Op to do the conversion.
82 //
83 // If B2S is false, then this performs the space-to-batch conversion.  If B2S is
84 // true, then this performs the inverse batch-to-space conversion.
85 template <typename Device, typename T, int NUM_BLOCK_DIMS, bool B2S = false>
86 struct SpaceToBatchFunctor {
87   using InputT = typename std::conditional<B2S, T, const T>::type;
88   using OutputT = typename std::conditional<B2S, const T, T>::type;
89   // Implements the space to batch conversion.
90   //
91   // space_tensor: input tensor of space-to-batch operation.  If B2S = false,
92   //     then this is the input to the conversion.  If B2S = true, then this
93   //     is the output of the conversion.
94   // block_size: array of shape [NUM_BLOCK_DIMS] specifying the block sizes for
95   //     dimensions 1 through NUM_BLOCK_DIMS.
96   // paddings: row-major array of shape [NUM_BLOCK_DIMS, 2] specifying the
97   //     start and end padding for dimensions 1 through NUM_BLOCK_DIMS.
98   // batch_tensor: output tensor of the space-to-batch operation.  If
99   //     B2S = false, then this is the output of the conversion.  If B2S = true,
100   //     then this is the input to the conversion.
101   //
102   // The caller must ensure that the dimensions of the tensors are correct.
103   Status operator()(
104       const Device& d,
105       typename TTypes<InputT, NUM_BLOCK_DIMS + 2>::Tensor space_tensor,
106       const int64 block_shape[NUM_BLOCK_DIMS],
107       const int64 paddings[NUM_BLOCK_DIMS * 2],
108       typename TTypes<OutputT, NUM_BLOCK_DIMS + 2>::Tensor batch_tensor);
109 };
110 
111 }  // namespace functor
112 }  // namespace tensorflow
113 
114 #endif  // TENSORFLOW_CORE_KERNELS_SPACETOBATCH_FUNCTOR_H_
115