/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>  // std::copy
#include <numeric>    // std::iota
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace tensorflow {
namespace {

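// Lowers SpaceToBatch/SpaceToBatchND to XLA as a pad + reshape + transpose +
// reshape sequence; the numbered steps below mirror the op's documented
// semantics. `block_shape` has length M, and `paddings` is an [M, 2] integer
// literal of (pad_start, pad_end) pairs, one per block dimension.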
void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
                  DataType input_dtype, const TensorShape& input_tensor_shape,
                  absl::Span<const int64> block_shape,
                  const xla::Literal& paddings) {
  const int input_rank = input_tensor_shape.dims();
  const absl::InlinedVector<int64, 4> input_shape =
      input_tensor_shape.dim_sizes();
  const int block_rank = block_shape.size();

  OP_REQUIRES(
      ctx, input_rank >= 1 + block_rank,
      errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
                              " instead of ", input_rank));
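  // `remainder_shape` holds the trailing dimensions after the batch and block
  // dimensions (e.g., channels); these pass through the op unchanged.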
  absl::Span<const int64> remainder_shape(input_shape);
  remainder_shape.remove_prefix(1 + block_rank);

  OP_REQUIRES(
      ctx,
      paddings.shape().rank() == 2 &&
          block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) &&
          2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1),
      errors::InvalidArgument("paddings should have shape [", block_rank,
                              ", 2] instead of ",
                              xla::ShapeUtil::HumanString(paddings.shape())));

  xla::XlaBuilder* b = ctx->builder();

  // 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the
  // input according to `paddings` to produce `padded` of shape `padded_shape`.
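  // Worked example used in the step comments below: an input of shape
  // [1, 2, 2, 1] with block_shape = [2, 2] and paddings = [[0, 0], [0, 0]]
  // leaves `padded` at shape [1, 2, 2, 1]; with paddings = [[1, 1], [0, 0]]
  // it would become [1, 4, 2, 1].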
  xla::PaddingConfig padding_config;
  std::vector<int64> padded_shape(input_shape.begin(), input_shape.end());
  int64 block_num_elems = 1LL;
  padding_config.add_dimensions();  // Don't pad the batch dimension.
  for (int i = 0; i < block_rank; ++i) {
    auto* dim = padding_config.add_dimensions();
    int64 pad_start = paddings.Get<int64>({i, 0});
    int64 pad_end = paddings.Get<int64>({i, 1});
    OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0,
                errors::InvalidArgument("Paddings must be non-negative"));
    dim->set_edge_padding_low(pad_start);
    dim->set_edge_padding_high(pad_end);
    padded_shape[1 + i] += pad_start + pad_end;
    block_num_elems *= block_shape[i];
  }
  // Don't pad the remainder dimensions.
  for (int i = 0; i < remainder_shape.size(); ++i) {
    padding_config.add_dimensions();
  }
  OP_REQUIRES(ctx, block_num_elems > 0,
              errors::InvalidArgument(
                  "The product of the block dimensions must be positive"));

  xla::XlaOp padded =
      xla::Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config);

  // 2. Reshape `padded` to `reshaped_padded` of shape:
  //
  //      [batch] +
  //      [padded_shape[1] / block_shape[0],
  //       block_shape[0],
  //       ...,
  //       padded_shape[M] / block_shape[M-1],
  //       block_shape[M-1]] +
  //      remaining_shape
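  // Continuing the example: [1, 2, 2, 1] with block_shape = [2, 2] is
  // reshaped to [1, 1, 2, 1, 2, 1].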
  const int64 batch_size = input_shape[0];
  std::vector<int64> reshaped_padded_shape(input_rank + block_rank);
  reshaped_padded_shape[0] = batch_size;
  for (int i = 0; i < block_rank; ++i) {
    OP_REQUIRES(ctx, padded_shape[1 + i] % block_shape[i] == 0,
                errors::InvalidArgument("padded_shape[", 1 + i,
                                        "]=", padded_shape[1 + i],
                                        " is not divisible by block_shape[", i,
                                        "]=", block_shape[i]));

    reshaped_padded_shape[1 + i * 2] = padded_shape[1 + i] / block_shape[i];
    reshaped_padded_shape[1 + i * 2 + 1] = block_shape[i];
  }
  std::copy(remainder_shape.begin(), remainder_shape.end(),
            reshaped_padded_shape.begin() + 1 + 2 * block_rank);

  xla::XlaOp reshaped_padded = xla::Reshape(padded, reshaped_padded_shape);

  // 3. Permute dimensions of `reshaped_padded` to produce
  //    `permuted_reshaped_padded` of shape:
  //
  //      block_shape +
  //      [batch] +
  //      [padded_shape[1] / block_shape[0],
  //       ...,
  //       padded_shape[M] / block_shape[M-1]] +
  //      remaining_shape
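  // Continuing the example: with block_rank = 2 the permutation computed
  // below is [2, 4, 0, 1, 3, 5], taking [1, 1, 2, 1, 2, 1] to
  // [2, 2, 1, 1, 1, 1].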
  std::vector<int64> permutation(reshaped_padded_shape.size());
  for (int i = 0; i < block_rank; ++i) {
    permutation[i] = 1 + 2 * i + 1;
    permutation[block_rank + 1 + i] = 1 + 2 * i;
  }
  permutation[block_rank] = 0;
  std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
            1 + block_rank * 2);
  xla::XlaOp permuted_reshaped_padded =
      xla::Transpose(reshaped_padded, permutation);

  // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the
  //    batch dimension, producing an output tensor of shape:
  //
  //      [batch * prod(block_shape)] +
  //      [padded_shape[1] / block_shape[0],
  //       ...,
  //       padded_shape[M] / block_shape[M-1]] +
  //      remaining_shape
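  // Concluding the example: [2, 2, 1, 1, 1, 1] flattens to an output of shape
  // [4, 1, 1, 1], i.e. each 2x2 spatial block becomes its own batch entry.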
  std::vector<int64> output_shape(input_rank);
  output_shape[0] = batch_size * block_num_elems;
  for (int i = 0; i < block_rank; ++i) {
    output_shape[1 + i] = padded_shape[1 + i] / block_shape[i];
  }
  std::copy(remainder_shape.begin(), remainder_shape.end(),
            output_shape.begin() + 1 + block_rank);

  xla::XlaOp output = xla::Reshape(permuted_reshaped_padded, output_shape);
  ctx->SetOutput(0, output);
}

class SpaceToBatchNDOp : public XlaOpKernel {
 public:
  explicit SpaceToBatchNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
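    // Both block_shape and paddings are declared CompileTimeConstantInput in
    // the registration below, so they can be read here as compile-time
    // constants to compute static shapes.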
    std::vector<int64> block_shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));

    xla::Literal paddings;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &paddings));

    SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
                 block_shape, paddings);
  }
};
REGISTER_XLA_OP(Name("SpaceToBatchND")
                    .CompileTimeConstantInput("paddings")
                    .CompileTimeConstantInput("block_shape"),
                SpaceToBatchNDOp);

class SpaceToBatchOp : public XlaOpKernel {
 public:
  explicit SpaceToBatchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        ctx, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
  }

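  // The legacy 2D op is lowered through the same ND helper, with the scalar
  // block_size attribute expanded to block_shape = {block_size, block_size}.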
  void Compile(XlaOpKernelContext* ctx) override {
    xla::Literal paddings;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &paddings));

    SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
                 {block_size_, block_size_}, paddings);
  }

 private:
  int block_size_;
};
REGISTER_XLA_OP(Name("SpaceToBatch").CompileTimeConstantInput("paddings"),
                SpaceToBatchOp);

}  // namespace
}  // namespace tensorflow