• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <algorithm>
#include <limits>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
20 
21 namespace tensorflow {
22 namespace {
23 
SpaceToBatch(XlaOpKernelContext * ctx,const xla::XlaOp & input,DataType input_dtype,const TensorShape & input_tensor_shape,absl::Span<const int64> block_shape,const xla::Literal & paddings)24 void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
25                   DataType input_dtype, const TensorShape& input_tensor_shape,
26                   absl::Span<const int64> block_shape,
27                   const xla::Literal& paddings) {
28   const int input_rank = input_tensor_shape.dims();
29   const absl::InlinedVector<int64, 4> input_shape =
30       input_tensor_shape.dim_sizes();
31   const int block_rank = block_shape.size();
32 
33   OP_REQUIRES(
34       ctx, input_rank >= 1 + block_rank,
35       errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
36                               " instead of ", input_rank));
37   absl::Span<const int64> remainder_shape(input_shape);
38   remainder_shape.remove_prefix(1 + block_rank);
39 
40   OP_REQUIRES(
41       ctx,
42       paddings.shape().rank() == 2 &&
43           block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) &&
44           2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1),
45       errors::InvalidArgument("paddings should have shape [", block_rank,
46                               ", 2] instead of ",
47                               xla::ShapeUtil::HumanString(paddings.shape())));
48 
49   xla::XlaBuilder* b = ctx->builder();
50 
51   // 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the
52   //  input according to `paddings` to produce `padded` of shape `padded_shape`.
53   xla::PaddingConfig padding_config;
54   std::vector<int64> padded_shape(input_shape.begin(), input_shape.end());
55   int64 block_num_elems = 1LL;
56   padding_config.add_dimensions();  // Don't pad the batch dimension.
57   for (int i = 0; i < block_rank; ++i) {
58     auto* dim = padding_config.add_dimensions();
59     int64 pad_start = paddings.Get<int64>({i, 0});
60     int64 pad_end = paddings.Get<int64>({i, 1});
61     OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0,
62                 errors::InvalidArgument("Paddings must be non-negative"));
63     dim->set_edge_padding_low(pad_start);
64     dim->set_edge_padding_high(pad_end);
65     padded_shape[1 + i] += pad_start + pad_end;
66     block_num_elems *= block_shape[i];
67   }
68   // Don't pad the remainder dimensions.
69   for (int i = 0; i < remainder_shape.size(); ++i) {
70     padding_config.add_dimensions();
71   }
72   OP_REQUIRES(ctx, block_num_elems > 0,
73               errors::InvalidArgument(
74                   "The product of the block dimensions must be positive"));
75 
76   xla::XlaOp padded =
77       xla::Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config);
78 
79   // 2. Reshape `padded` to `reshaped_padded` of shape:
80   //
81   //      [batch] +
82   //      [padded_shape[1] / block_shape[0],
83   //        block_shape[0],
84   //       ...,
85   //       padded_shape[M] / block_shape[M-1],
86   //       block_shape[M-1]] +
87   //      remaining_shape
88   const int64 batch_size = input_shape[0];
89   std::vector<int64> reshaped_padded_shape(input_rank + block_rank);
90   reshaped_padded_shape[0] = batch_size;
91   for (int i = 0; i < block_rank; ++i) {
92     OP_REQUIRES(ctx, padded_shape[1 + i] % block_shape[i] == 0,
93                 errors::InvalidArgument("padded_shape[", 1 + i,
94                                         "]=", padded_shape[1 + i],
95                                         " is not divisible by block_shape[", i,
96                                         "]=", block_shape[i]));
97 
98     reshaped_padded_shape[1 + i * 2] = padded_shape[1 + i] / block_shape[i];
99     reshaped_padded_shape[1 + i * 2 + 1] = block_shape[i];
100   }
101   std::copy(remainder_shape.begin(), remainder_shape.end(),
102             reshaped_padded_shape.begin() + 1 + 2 * block_rank);
103 
104   xla::XlaOp reshaped_padded = xla::Reshape(padded, reshaped_padded_shape);
105 
106   // 3. Permute dimensions of `reshaped_padded` to produce
107   //    `permuted_reshaped_padded` of shape:
108   //
109   //      block_shape +
110   //      [batch] +
111   //      [padded_shape[1] / block_shape[0],
112   //       ...,
113   //       padded_shape[M] / block_shape[M-1]] +
114   //      remaining_shape
115   std::vector<int64> permutation(reshaped_padded_shape.size());
116   for (int i = 0; i < block_rank; ++i) {
117     permutation[i] = 1 + 2 * i + 1;
118     permutation[block_rank + 1 + i] = 1 + 2 * i;
119   }
120   permutation[block_rank] = 0;
121   std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
122             1 + block_rank * 2);
123   xla::XlaOp permuted_reshaped_padded =
124       xla::Transpose(reshaped_padded, permutation);
125 
126   // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the
127   //    batch dimension, producing an output tensor of shape:
128   //
129   //      [batch * prod(block_shape)] +
130   //      [padded_shape[1] / block_shape[0],
131   //       ...,
132   //       padded_shape[M] / block_shape[M-1]] +
133   //      remaining_shape
134   // Determine the length of the prefix of block dims that can be combined
135   // into the batch dimension due to having no padding and block_shape=1.
136   std::vector<int64> output_shape(input_rank);
137   output_shape[0] = batch_size * block_num_elems;
138   for (int i = 0; i < block_rank; ++i) {
139     output_shape[1 + i] = padded_shape[1 + i] / block_shape[i];
140   }
141   std::copy(remainder_shape.begin(), remainder_shape.end(),
142             output_shape.begin() + 1 + block_rank);
143 
144   xla::XlaOp output = xla::Reshape(permuted_reshaped_padded, output_shape);
145   ctx->SetOutput(0, output);
146 }
147 
148 class SpaceToBatchNDOp : public XlaOpKernel {
149  public:
SpaceToBatchNDOp(OpKernelConstruction * ctx)150   explicit SpaceToBatchNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
151 
Compile(XlaOpKernelContext * ctx)152   void Compile(XlaOpKernelContext* ctx) override {
153     std::vector<int64> block_shape;
154     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));
155 
156     xla::Literal paddings;
157     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &paddings));
158 
159     SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
160                  block_shape, paddings);
161   }
162 };
163 REGISTER_XLA_OP(Name("SpaceToBatchND")
164                     .CompileTimeConstantInput("paddings")
165                     .CompileTimeConstantInput("block_shape"),
166                 SpaceToBatchNDOp);
167 
168 class SpaceToBatchOp : public XlaOpKernel {
169  public:
SpaceToBatchOp(OpKernelConstruction * ctx)170   explicit SpaceToBatchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
171     OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
172     OP_REQUIRES(
173         ctx, block_size_ > 1,
174         errors::InvalidArgument("Block size should be > 1: ", block_size_));
175   }
176 
Compile(XlaOpKernelContext * ctx)177   void Compile(XlaOpKernelContext* ctx) override {
178     xla::Literal paddings;
179     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &paddings));
180 
181     SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
182                  {block_size_, block_size_}, paddings);
183   }
184 
185  private:
186   int block_size_;
187 };
188 REGISTER_XLA_OP(Name("SpaceToBatch").CompileTimeConstantInput("paddings"),
189                 SpaceToBatchOp);
190 
191 }  // namespace
192 }  // namespace tensorflow
193