1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <algorithm>
#include <limits>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
20
21 namespace tensorflow {
22 namespace {
23
BatchToSpace(XlaOpKernelContext * ctx,const xla::XlaOp & input,DataType input_dtype,const TensorShape & input_tensor_shape,absl::Span<const int64> block_shape,const xla::Literal & crops)24 void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
25 DataType input_dtype, const TensorShape& input_tensor_shape,
26 absl::Span<const int64> block_shape,
27 const xla::Literal& crops) {
28 const int input_rank = input_tensor_shape.dims();
29 const absl::InlinedVector<int64, 4> input_shape =
30 input_tensor_shape.dim_sizes();
31 const int block_rank = block_shape.size();
32
33 OP_REQUIRES(
34 ctx, input_rank >= 1 + block_rank,
35 errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
36 " instead of ", input_rank));
37 absl::Span<const int64> remainder_shape(input_shape);
38 remainder_shape.remove_prefix(1 + block_rank);
39
40 OP_REQUIRES(
41 ctx,
42 crops.shape().rank() == 2 &&
43 block_rank == xla::ShapeUtil::GetDimension(crops.shape(), 0) &&
44 2 == xla::ShapeUtil::GetDimension(crops.shape(), 1),
45 errors::InvalidArgument("crops should have shape [", block_rank,
46 ", 2] instead of ",
47 xla::ShapeUtil::HumanString(crops.shape())));
48
49 const int64 batch_size = input_shape[0];
50
51 // Compute the product of the block_shape values.
52 int64 block_num_elems = 1;
53 for (int i = 0; i < block_rank; ++i) {
54 block_num_elems *= block_shape[i];
55 }
56 OP_REQUIRES(ctx, block_num_elems > 0,
57 errors::InvalidArgument(
58 "The product of the block dimensions must be positive"));
59
60 // 1. Reshape `input` to `reshaped` of shape:
61 // [block_shape[0], ..., block_shape[M-1],
62 // batch / prod(block_shape),
63 // input_shape[1], ..., input_shape[N-1]]
64
65 OP_REQUIRES(
66 ctx, batch_size % block_num_elems == 0,
67 errors::InvalidArgument("Input batch dimension (", batch_size,
68 ") is not divisible by product of block sizes (",
69 block_num_elems, ")"));
70 std::vector<int64> reshaped_shape(input_rank + block_rank);
71 std::copy(block_shape.begin(), block_shape.end(), reshaped_shape.begin());
72 reshaped_shape[block_rank] = batch_size / block_num_elems;
73 std::copy(input_shape.begin() + 1, input_shape.end(),
74 reshaped_shape.begin() + block_rank + 1);
75 xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape);
76
77 // 2. Permute dimensions of `reshaped` to produce `permuted` of shape
78 // [batch / prod(block_shape),
79 //
80 // input_shape[1], block_shape[0],
81 // ...,
82 // input_shape[M], block_shape[M-1],
83 //
84 // input_shape[M+1], ..., input_shape[N-1]]
85 std::vector<int64> permutation(reshaped_shape.size());
86 permutation[0] = block_rank;
87 for (int i = 0; i < block_rank; ++i) {
88 permutation[1 + 2 * i] = block_rank + 1 + i;
89 permutation[1 + 2 * i + 1] = i;
90 }
91 std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
92 1 + block_rank * 2);
93 xla::XlaOp permuted = xla::Transpose(reshaped, permutation);
94
95 // 3. Reshape `permuted` to produce `reshaped_permuted` of shape
96 // [batch / prod(block_shape),
97 //
98 // input_shape[1] * block_shape[0],
99 // ...,
100 // input_shape[M] * block_shape[M-1],
101 //
102 // input_shape[M+1],
103 // ...,
104 // input_shape[N-1]]
105 std::vector<int64> reshaped_permuted_shape(input_rank);
106 reshaped_permuted_shape[0] = batch_size / block_num_elems;
107 for (int i = 0; i < block_rank; ++i) {
108 reshaped_permuted_shape[1 + i] = block_shape[i] * input_shape[1 + i];
109 }
110 std::copy(remainder_shape.begin(), remainder_shape.end(),
111 reshaped_permuted_shape.begin() + 1 + block_rank);
112
113 xla::XlaOp reshaped_permuted =
114 xla::Reshape(permuted, reshaped_permuted_shape);
115
116 // 4. Crop the start and end of dimensions `[1, ..., M]` of
117 // `reshaped_permuted` according to `crops` to produce the output of shape:
118 // [batch / prod(block_shape),
119 //
120 // input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
121 // ...,
122 // input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
123 //
124 // input_shape[M+1], ..., input_shape[N-1]]
125 std::vector<int64> start_indices(input_rank, 0);
126 std::vector<int64> end_indices = reshaped_permuted_shape;
127 std::vector<int64> strides(input_rank, 1);
128 for (int i = 0; i < block_rank; ++i) {
129 int64 crop_start = crops.Get<int64>({i, 0});
130 int64 crop_end = crops.Get<int64>({i, 1});
131 OP_REQUIRES(ctx, crop_start >= 0 && crop_end >= 0,
132 errors::InvalidArgument("Crops must be non-negative"));
133 start_indices[1 + i] = crop_start;
134 end_indices[1 + i] -= crop_end;
135 OP_REQUIRES(
136 ctx, start_indices[1 + i] <= end_indices[1 + i],
137 errors::InvalidArgument(
138 "Cropped size must be non-negative: start: ", crop_start,
139 " end: ", crop_end, " size ", reshaped_permuted_shape[1 + i]));
140 }
141 xla::XlaOp output =
142 xla::Slice(reshaped_permuted, start_indices, end_indices, strides);
143 ctx->SetOutput(0, output);
144 }
145
146 class BatchToSpaceNDOp : public XlaOpKernel {
147 public:
BatchToSpaceNDOp(OpKernelConstruction * ctx)148 explicit BatchToSpaceNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
149
Compile(XlaOpKernelContext * ctx)150 void Compile(XlaOpKernelContext* ctx) override {
151 std::vector<int64> block_shape;
152 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));
153
154 xla::Literal crops;
155 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &crops));
156
157 BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
158 block_shape, crops);
159 }
160 };
161 REGISTER_XLA_OP(Name("BatchToSpaceND")
162 .CompileTimeConstantInput("block_shape")
163 .CompileTimeConstantInput("crops"),
164 BatchToSpaceNDOp);
165
166 class BatchToSpaceOp : public XlaOpKernel {
167 public:
BatchToSpaceOp(OpKernelConstruction * ctx)168 explicit BatchToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
169 OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
170 OP_REQUIRES(
171 ctx, block_size_ > 1,
172 errors::InvalidArgument("Block size should be > 1: ", block_size_));
173 }
174
Compile(XlaOpKernelContext * ctx)175 void Compile(XlaOpKernelContext* ctx) override {
176 xla::Literal crops;
177 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &crops));
178
179 BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
180 {block_size_, block_size_}, crops);
181 }
182
183 private:
184 int block_size_;
185 };
186 REGISTER_XLA_OP(Name("BatchToSpace").CompileTimeConstantInput("crops"),
187 BatchToSpaceOp);
188
189 } // namespace
190 } // namespace tensorflow
191