/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/image_resize_ops.h"

#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {
namespace {

// We implement bilinear interpolation by upsampling followed by convolution.
// The basic idea is as follows. To scale from NxN to RxR:
//
//    1. S := (N - 1) / gcd(N-1, R-1)
//    2. k := (R - 1) / gcd(N-1, R-1)
//    3. Convolution((2k-1)x(2k-1), stride=S, lhs_dilation=k, padding=k-1)
//
// For example, to scale from 7x7 -> 15x15:
//
//    1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3
//    2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7
//    3. Convolution(13x13, stride=3, lhs_dilation=7, padding=6)
//
//
// The 7x7 -> 15x15 case is much too large to write out in full as an
// example. The smallest interesting example is 3x3 -> 4x4.
//
//    S := 2
//    k := 3
//
//    00 03 06    00 00 00 00 00 00 00 00 00 00 00    00 02 04 06
//    09 12 15 -> 00 00 00 00 00 00 00 00 00 00 00 -> 06 08 10 12
//    18 21 24    00 00 00 00 00 03 00 00 06 00 00    12 14 16 18
//                00 00 00 00 00 00 00 00 00 00 00    18 20 22 24
//                00 00 00 00 00 00 00 00 00 00 00
//                00 00 09 00 00 12 00 00 15 00 00
//                00 00 00 00 00 00 00 00 00 00 00
//                00 00 00 00 00 00 00 00 00 00 00
//                00 00 18 00 00 21 00 00 24 00 00
//                00 00 00 00 00 00 00 00 00 00 00
//                00 00 00 00 00 00 00 00 00 00 00
//
// with the following convolutional kernel, with stride [2, 2]:
//         1 2 3 2 1
//         2 4 6 4 2
// 1/9  *  3 6 9 6 3
//         2 4 6 4 2
//         1 2 3 2 1
// Note that the convolution kernel matrix is separable, so we can instead
// apply two consecutive 1D kernels of dimension 2k-1, one along each axis.

// Computes the size of the convolutional kernel and stride to use when resizing
// from in_size to out_size.
struct ResizeConvolutionDims {
  // Size of the kernel to use.
  std::vector<int64> kernel_size;  // k

  // Stride of the convolution to use.
  std::vector<int64> stride;  // S
};
ResizeConvolutionDims ComputeResizeConvolutionParameters(
    absl::Span<const int64> in_size, absl::Span<const int64> out_size,
    bool align_corners) {
  CHECK_EQ(in_size.size(), out_size.size());
  int num_spatial_dims = in_size.size();
  ResizeConvolutionDims dims;
  dims.kernel_size.resize(num_spatial_dims);
  dims.stride.resize(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    if (in_size[i] == 1) {
      // We must handle input size 1 specially because XLA convolution does
      // not allow stride 0.
      dims.stride[i] = dims.kernel_size[i] = 1;
    } else if (out_size[i] == 1) {
      // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first
      // entry before resizing.
      dims.stride[i] = dims.kernel_size[i] = 1;
    } else {
      // The scaling factor changes depending on the alignment of corners.
      const int64_t in_size_factor =
          align_corners ? in_size[i] - 1 : in_size[i];
      const int64_t out_size_factor =
          align_corners ? out_size[i] - 1 : out_size[i];

      int64_t gcd = MathUtil::GCD(static_cast<uint64>(in_size_factor),
                                  static_cast<uint64>(out_size_factor));
      dims.stride[i] = in_size_factor / gcd;
      dims.kernel_size[i] = out_size_factor / gcd;
    }
  }
  return dims;
}
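
// Illustrative check: for the 3x3 -> 4x4 example in the header comment with
// align_corners=true, gcd(3 - 1, 4 - 1) = 1, so stride = {2, 2} and
// kernel_size = {3, 3}, matching S = 2 and k = 3 above.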

// The upper padding of the input needed by ConvGeneralDilated calls is
// determined by solving two related relationships (assuming rhs_dilation == 1,
// i.e. no kernel dilation):
// 1. dilated_input_dim = lower_padding + upper_padding
//                        + lhs_dilation * (in_size - 1) + 1
// 2. dilated_input_dim = (2 * dims.kernel_size - 1)
//                        + dims.stride * (out_size - 1)
int64 CalculateUpperPadding(int64_t in_size, int64_t out_size,
                            int64_t kernel_size, int64_t stride) {
  int64_t padding = (2 * kernel_size - 1) + (out_size - 1) * stride -
                    (kernel_size - 1) - 1 - (kernel_size * (in_size - 1));

  return padding;
}
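
// Worked example: for the 3x3 -> 4x4 case above (kernel_size = 3, stride = 2),
// upper padding = (2*3 - 1) + (4 - 1)*2 - (3 - 1) - 1 - 3*(3 - 1) = 2, which
// matches the padding = k - 1 rule from the header comment.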

// Form a 2D convolution kernel like:
//         1 2 3 2 1
//         2 4 6 4 2
// 1/9  *  3 6 9 6 3
//         2 4 6 4 2
//         1 2 3 2 1
// by multiplying two 1D kernels of the form:
//    1/3 * [1 2 3 2 1]
// If the 2D kernel would be very large, the 1D kernel can instead be applied
// once along each dimension; this is equivalent because the kernel is
// separable, and it reduces the computational cost.
xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder,
                                xla::PrimitiveType type, int64_t n) {
  std::vector<float> kernel(n * 2 - 1);
  for (int64_t i = 0; i < n; ++i) {
    float v = (i + 1.0f) / n;
    kernel[i] = v;
    kernel[n * 2 - 2 - i] = v;
  }
  return xla::ConvertElementType(xla::ConstantR1<float>(builder, kernel), type);
}
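
// For example, n = 3 yields the length-5 kernel
//   [1/3, 2/3, 1, 2/3, 1/3] = 1/3 * [1 2 3 2 1],
// the 1D factor of the 5x5 bilinear kernel shown at the top of this file.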

// Unlike the bilinear kernel, which is triangular, the nearest-neighbor
// kernel is a square. For example, a 1D kernel with n=3 would look like
//   [0 1 1 1 0]
// and n=4 would look like
//   [0 0 1 1 1 1 0].
// Note that in the second case, the kernel is not symmetric and we default
// to the right (because an existing non-TPU kernel for nearest-neighbor
// resize already chose to default to the right, so we want to be consistent).
xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder,
                                       xla::PrimitiveType type, int64_t n) {
  std::vector<float> kernel(n * 2 - 1, 0.0f);
  std::fill(&kernel[n / 2], &kernel[(3 * n) / 2], 1.0f);

  return xla::ConvertElementType(xla::ConstantR1<float>(builder, kernel), type);
}
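
// Sanity check against the comment above: for n = 4 the kernel has length
// 2*4 - 1 = 7 and indices [n/2, 3n/2) = [2, 6) are set to one, giving
// [0 0 1 1 1 1 0] with the extra one to the right of center.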

// Kernels with 16 or more spatial elements are considered compute-intensive,
// and such kernels should be applied to each dimension independently.
const int64_t kMax2DKernelSize = 16;

xla::XlaOp MakeGeneralResizeKernel(xla::XlaBuilder* builder,
                                   xla::PrimitiveType type,
                                   absl::Span<const int64> kernel_size,
                                   int64_t channels, bool is_kernel_bilinear) {
  auto make_kernel_func =
      is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel;

  std::vector<int64> depthwise_kernel_sizes = {
      (2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1};
  auto depthwise_kernel =
      xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[1]),
                          depthwise_kernel_sizes, /*broadcast_dimensions=*/{1});

  return xla::Mul(depthwise_kernel,
                  make_kernel_func(builder, type, kernel_size[0]),
                  /*broadcast_dimensions=*/{0});
}
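
// The result is a depthwise kernel of shape
// [2*kernel_size[0] - 1, 2*kernel_size[1] - 1, channels, 1]: the outer product
// of the two 1D kernels, broadcast across channels. For a bilinear resize with
// kernel_size = {3, 3} this reproduces the 5x5 kernel (scaled by 1/9) shown in
// the comment at the top of this file.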

xla::XlaOp MakeGeneralResizeKernelInDim(xla::XlaBuilder* builder,
                                        xla::PrimitiveType type,
                                        absl::Span<const int64> kernel_size,
                                        int64_t channels, int64_t dim,
                                        bool is_kernel_bilinear) {
  auto make_kernel_func =
      is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel;

  std::vector<int64> depthwise_kernel_sizes = {
      dim == 0 ? (2 * kernel_size[0] - 1) : 1,
      dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1};
  return xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[dim]),
                             depthwise_kernel_sizes,
                             /*broadcast_dimensions=*/{dim});
}
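
// The per-dimension kernel has size 1 in the other spatial dimension, so
// convolving with the dim=0 kernel and then the dim=1 kernel is equivalent to
// one convolution with the separable 2D kernel built above.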

xla::XlaOp BroadcastSpatialDimensions(xla::XlaBuilder* builder,
                                      const xla::XlaOp& input,
                                      int32_t spatial_dimensions_offset,
                                      absl::Span<const int64> in_size,
                                      absl::Span<const int64> out_size) {
  // Add broadcasts to handle expanding from a size == 1 dimension to a
  // size > 1 dimension.
  auto broadcast_shape_or_status = builder->GetShape(input);
  if (!broadcast_shape_or_status.ok()) {
    return builder->ReportError(broadcast_shape_or_status.status());
  }
  xla::Shape broadcast_shape = broadcast_shape_or_status.ValueOrDie();
  for (int32_t i = 0; i < in_size.size(); ++i) {
    if (in_size[i] == 1 && out_size[i] > 1) {
      broadcast_shape.set_dimensions(spatial_dimensions_offset + i,
                                     out_size[i]);
    }
  }
  return xla::BroadcastInDim(input, broadcast_shape.dimensions(),
                             /*broadcast_dimensions=*/{0, 1, 2, 3});
}

xla::XlaOp ResizeUsingDilationAndConvolution(
    xla::XlaBuilder* builder, const xla::XlaOp& input, xla::PrimitiveType type,
    const int num_spatial_dims, absl::Span<const int64> in_size,
    absl::Span<const int64> out_size, const int64_t channels,
    const bool align_corners, bool is_kernel_bilinear) {
  // Picture for a 1x3 to 1x4 bilinear resize:
  // stride = 2, kernel size = 3
  // Input:
  //   3 6 9
  // Input with dilation and padding:
  //   0 0 3 0 0 6 0 0 9 0 0
  // Convolution kernel:
  //   1/3 * [1 2 3 2 1]
  // Output:
  //   3 5 7 9
  xla::ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(0);
  dimension_numbers.set_output_batch_dimension(0);
  dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
  dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimension_numbers.add_input_spatial_dimensions(1 + i);
    dimension_numbers.add_output_spatial_dimensions(1 + i);
    dimension_numbers.add_kernel_spatial_dimensions(i);
  }
  dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
  dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);

  ResizeConvolutionDims dims =
      ComputeResizeConvolutionParameters(in_size, out_size, align_corners);

  if (dims.kernel_size[0] * dims.kernel_size[1] >
      kMax2DKernelSize * kMax2DKernelSize) {
    BroadcastOptimizationRemark(
        XlaOptimizationRemark::SLOW_IMAGE_RESIZE_DIMENSIONS,
        absl::StrFormat("%dx%d", dims.kernel_size[0], dims.kernel_size[1]))
        .IgnoreError();
  }

  xla::XlaOp output;

  // The concatenation and padding below currently assume num_spatial_dims is 2
  // to prevent needless code complexity.
  CHECK_EQ(num_spatial_dims, 2)
      << "ResizeUsingDilationAndConvolution pads only 2 dimensions currently.";
  std::vector<int64> upper_padding(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    upper_padding[i] = dims.kernel_size[i] - 1;
  }
  xla::XlaOp input_data = input;

  if (!align_corners) {
    // When TensorFlow does not align corners, the resize indexing can access
    // beyond the upper bound of the input; such reads are clamped rather than
    // allowed to go out of bounds. This is conceptually the same as extending
    // the edges of the input, which we emulate by copying the last row/column
    // of the input. Calculate what padding would be needed, then determine how
    // far to extend the border before lhs dilation.
    std::vector<int64> num_extended(num_spatial_dims);
    upper_padding[0] = CalculateUpperPadding(
        in_size[0], out_size[0], dims.kernel_size[0], dims.stride[0]);
    upper_padding[1] = CalculateUpperPadding(
        in_size[1], out_size[1], dims.kernel_size[1], dims.stride[1]);
    num_extended[0] = upper_padding[0] / (dims.kernel_size[0]);
    num_extended[1] = upper_padding[1] / (dims.kernel_size[1]);

    const int64_t batch_dim_size =
        builder->GetShape(input).ValueOrDie().dimensions(0);
    if (num_extended[0] > 0) {
      auto slice = xla::Slice(
          input_data, {0, in_size[0] - 1, 0, 0},
          {batch_dim_size, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
      for (int i = 0; i < num_extended[0]; i++) {
        input_data = xla::ConcatInDim(builder, {input_data, slice}, 1);
      }
    }

    if (num_extended[1] > 0) {
      auto slice = xla::Slice(
          input_data, {0, 0, in_size[1] - 1, 0},
          {batch_dim_size, in_size[0] + num_extended[0], in_size[1], channels},
          {1, 1, 1, 1});
      for (int i = 0; i < num_extended[1]; i++) {
        input_data = xla::ConcatInDim(builder, {input_data, slice}, 2);
      }
    }

    // The effective input size is now (in_size + num_extended) due to the
    // Slice and ConcatInDim above, so recalculate the needed padding.
    upper_padding[0] =
        CalculateUpperPadding(in_size[0] + num_extended[0], out_size[0],
                              dims.kernel_size[0], dims.stride[0]);
    upper_padding[1] =
        CalculateUpperPadding(in_size[1] + num_extended[1], out_size[1],
                              dims.kernel_size[1], dims.stride[1]);
  }

  // Split the convolution into independent dimensions if it would form a very
  // large kernel or if one or more of the dimensions are already equal.
  bool decompose_resize =
      in_size[0] == out_size[0] || in_size[1] == out_size[1] ||
      dims.kernel_size[0] * dims.kernel_size[1] >= kMax2DKernelSize;
  if (!decompose_resize) {
    xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size,
                                                channels, is_kernel_bilinear);
    output =
        xla::ConvGeneralDilated(input_data, kernel, dims.stride,
                                /*padding=*/
                                {{dims.kernel_size[0] - 1, upper_padding[0]},
                                 {dims.kernel_size[1] - 1, upper_padding[1]}},
                                /*lhs_dilation=*/dims.kernel_size,
                                /*rhs_dilation=*/{1, 1}, dimension_numbers,
                                /*feature_group_count=*/channels);
  } else {
    output = input_data;
    if (in_size[0] != out_size[0]) {
      xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim(
          builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear);
      output = xla::ConvGeneralDilated(
          output, kernel0, {dims.stride[0], 1},
          /*padding=*/
          {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}},
          /*lhs_dilation=*/{dims.kernel_size[0], 1},
          /*rhs_dilation=*/{1, 1}, dimension_numbers,
          /*feature_group_count=*/channels);
    }

    if (in_size[1] != out_size[1]) {
      xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim(
          builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear);
      output = xla::ConvGeneralDilated(
          output, kernel1, {1, dims.stride[1]},
          /*padding=*/
          {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}},
          /*lhs_dilation=*/{1, dims.kernel_size[1]},
          /*rhs_dilation=*/{1, 1}, dimension_numbers,
          /*feature_group_count=*/channels);
    }
  }

  // Add broadcasts to handle expanding from a size == 1 dimension to a
  // size > 1 dimension.
  return BroadcastSpatialDimensions(
      builder, output, /*spatial_dimensions_offset=*/1, in_size, out_size);
}

xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(
    xla::XlaBuilder* builder, const xla::XlaOp& grad, xla::PrimitiveType type,
    const int num_spatial_dims, absl::Span<const int64> in_size,
    absl::Span<const int64> grad_size, const int64_t channels,
    const bool align_corners, bool is_kernel_bilinear) {
  ResizeConvolutionDims dims =
      ComputeResizeConvolutionParameters(in_size, grad_size, align_corners);

  // To form the backward convolution, we keep the kernel unchanged (it is
  // already symmetric) and swap the roles of strides and LHS dilation.
  xla::ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(0);
  dimension_numbers.set_output_batch_dimension(0);
  dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
  dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimension_numbers.add_input_spatial_dimensions(i + 1);
    dimension_numbers.add_output_spatial_dimensions(i + 1);
    dimension_numbers.add_kernel_spatial_dimensions(i);
  }
  dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
  dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
  xla::XlaOp output;
  if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
    xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size,
                                                channels, is_kernel_bilinear);

    // Broadcast the input kernel where the forward op expanded from a
    // size == 1 dimension to a size > 1 dimension. This has the effect of
    // summing the gradient contributions in that dimension.
    kernel = BroadcastSpatialDimensions(
        builder, kernel, /*spatial_dimensions_offset=*/0, in_size, grad_size);

    output = xla::ConvGeneralDilated(
        grad, kernel, /*window_strides=*/dims.kernel_size,
        /*padding=*/
        {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
         {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
        /*lhs_dilation=*/dims.stride,
        /*rhs_dilation=*/{1, 1}, dimension_numbers,
        /*feature_group_count=*/channels);
  } else {
    xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim(
        builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear);
    xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim(
        builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear);

    // Broadcast the input kernel where the forward op expanded from a
    // size == 1 dimension to a size > 1 dimension. This has the effect of
    // summing the gradient contributions in that dimension.
    if (in_size[0] == 1 && grad_size[0] > 1) {
      kernel0 = BroadcastSpatialDimensions(builder, kernel0,
                                           /*spatial_dimensions_offset=*/0, {1},
                                           {grad_size[0]});
    }
    if (in_size[1] == 1 && grad_size[1] > 1) {
      kernel1 = BroadcastSpatialDimensions(builder, kernel1,
                                           /*spatial_dimensions_offset=*/0,
                                           in_size, grad_size);
    }

    output = xla::ConvGeneralDilated(
        grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1},
        /*padding=*/
        {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
        /*lhs_dilation=*/{dims.stride[0], 1},
        /*rhs_dilation=*/{1, 1}, dimension_numbers,
        /*feature_group_count=*/channels);

    output = xla::ConvGeneralDilated(
        output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]},
        /*padding=*/
        {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
        /*lhs_dilation=*/{1, dims.stride[1]},
        /*rhs_dilation=*/{1, 1}, dimension_numbers,
        /*feature_group_count=*/channels);
  }

  // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i.
  // This is the opposite of the slice performed by the forward op.
  xla::PaddingConfig padding = xla::MakeNoPaddingConfig(4);
  bool pad_output = false;
  for (int i = 0; i < num_spatial_dims; ++i) {
    if (in_size[i] > 1 && grad_size[i] == 1) {
      pad_output = true;
      padding.mutable_dimensions(1 + i)->set_edge_padding_high(in_size[i] - 1);
    }
  }
  if (pad_output) {
    output = xla::Pad(output, xla::Zero(builder, type), padding);
  }
  return output;
}

void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_,
                    bool is_kernel_bilinear) {
  xla::XlaBuilder* b = ctx->builder();

  TensorShape input_shape = ctx->InputShape(0);
  OP_REQUIRES(ctx, input_shape.dims() == 4,
              errors::InvalidArgument("input must be 4-dimensional",
                                      input_shape.DebugString()));
  // First dimension is always assumed to be batch.
  const int64_t batch = input_shape.dim_size(0);
  std::vector<int64> in_size = {input_shape.dim_size(1),
                                input_shape.dim_size(2)};
  // Last (4th) dimension is always assumed to be the number of channels.
  const int64_t channels = input_shape.dim_size(3);
  OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0,
              errors::InvalidArgument("input size must be positive, got [",
                                      in_size[0], ",", in_size[1], "]"));

  std::vector<int64> out_size;
  OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size));
  OP_REQUIRES(ctx, out_size.size() == 2,
              errors::InvalidArgument("output size must be length 2, got ",
                                      out_size.size()));
  OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0,
              errors::InvalidArgument("output size must be positive, got [",
                                      out_size[0], ",", out_size[1], "]"));

  const int num_spatial_dims = 2;

  xla::XlaOp input = ctx->Input(0);
  xla::PrimitiveType input_type = ctx->input_xla_type(0);

  // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in
  // dimension i.
  bool slice_input = false;
  for (int i = 0; i < num_spatial_dims; ++i) {
    if (in_size[i] > 1 && out_size[i] == 1) {
      // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first
      // entry before resizing.
      slice_input = true;
      in_size[i] = 1;
    }
  }
  if (slice_input) {
    input = xla::Slice(input, {0, 0, 0, 0},
                       {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
  }

  // The output is always float if 'is_kernel_bilinear' is true. GPU with
  // integer input also uses float, because XLA integer convolution on cuDNN
  // is either not supported or not allowed directly.
  xla::PrimitiveType original_input_type = input_type;
  if (is_kernel_bilinear || (xla::primitive_util::IsIntegralType(input_type))) {
    input = xla::ConvertElementType(input, xla::F32);
    input_type = xla::F32;
  }

  for (int dim = 0; dim < in_size.size(); ++dim) {
    // If the pairwise_distance function more accurately estimated performance,
    // this threshold could be reduced.
    constexpr int64_t kSmallDimThreshold = 1 << 10;
    if (in_size[dim] > out_size[dim] || out_size[dim] < kSmallDimThreshold) {
      std::vector<int64> next_size = in_size;
      next_size[dim] = out_size[dim];
      input = ResizeUsingDilationAndConvolution(
          b, input, input_type, num_spatial_dims, in_size, next_size, channels,
          align_corners_, is_kernel_bilinear);
      in_size[dim] = next_size[dim];
    }
  }

  // This function approximates the cost of a bilinear resize from a src_size
  // to a dst_size. The accuracy is okay, but empirically, the algorithm makes
  // some suboptimal choices. A better cost model would improve performance.
  auto pairwise_distance = [align_corners_](int64_t src_size,
                                            int64_t dst_size) {
    auto params = ComputeResizeConvolutionParameters({src_size}, {dst_size},
                                                     align_corners_);
    return params.stride[0];
  };

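  // Dynamic programming over intermediate sizes: distances[i] estimates the
  // cost (sum of per-step strides) of growing this dimension from size i up to
  // out_size[dim] through any chain of intermediate resizes, and next_step[i]
  // records the first intermediate size on the best chain. The while loop
  // below then replays that chain one resize at a time.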
  for (int dim = 0; dim < in_size.size(); ++dim) {
    std::vector<int64> distances(out_size[dim] + 1);
    std::vector<int64> next_step(out_size[dim] + 1);
    for (int64_t i = distances.size() - 2; i >= in_size[dim]; --i) {
      distances[i] = INT64_MAX;
      for (int64_t j = i + 1; j < distances.size(); ++j) {
        int64_t distance = pairwise_distance(i, j) + distances[j];
        if (distance < distances[i]) {
          distances[i] = distance;
          next_step[i] = j;
        }
      }
    }

    while (in_size[dim] != out_size[dim]) {
      auto next_size = in_size;
      next_size[dim] = next_step[in_size[dim]];
      input = ResizeUsingDilationAndConvolution(
          b, input, input_type, num_spatial_dims, in_size, next_size, channels,
          align_corners_, is_kernel_bilinear);
      in_size[dim] = next_size[dim];
    }
  }

  // Bilinear always outputs float, but nearest neighbor keeps the original
  // type.
  if (!is_kernel_bilinear && original_input_type != input_type) {
    input = xla::ConvertElementType(input, original_input_type);
  }
  ctx->SetOutput(0, input);
}
}  // namespace

ResizeNearestNeighborOp::ResizeNearestNeighborOp(OpKernelConstruction* ctx)
    : XlaOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
  OP_REQUIRES(
      ctx, align_corners_ == true,
      errors::Unimplemented("ResizeNearestNeighbor with align_corners=False "
                            "is not yet implemented"));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
  OP_REQUIRES(ctx, half_pixel_centers_ == false,
              errors::Unimplemented(
                  "ResizeNearestNeighbor with half_pixel_centers=True is "
                  "not yet implemented"));
}

void ResizeNearestNeighborOp::Compile(XlaOpKernelContext* ctx) {
  GeneralCompile(ctx, align_corners_, is_kernel_bilinear_);
}

REGISTER_XLA_OP(Name("ResizeNearestNeighbor").CompileTimeConstantInput("size"),
                ResizeNearestNeighborOp);

ResizeBilinearOp::ResizeBilinearOp(OpKernelConstruction* ctx)
    : XlaOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
  OP_REQUIRES(
      ctx, half_pixel_centers_ == false,
      errors::Unimplemented("ResizeBilinear with half_pixel_centers=True is "
                            "not yet implemented"));
}

void ResizeBilinearOp::Compile(XlaOpKernelContext* ctx) {
  GeneralCompile(ctx, align_corners_, is_kernel_bilinear_);
}

REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstantInput("size"),
                ResizeBilinearOp);

ResizeBilinearGradOp::ResizeBilinearGradOp(OpKernelConstruction* ctx)
    : XlaOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
  OP_REQUIRES(
      ctx, align_corners_ == true,
      errors::Unimplemented("ResizeBilinearGrad with align_corners=False is "
                            "not yet implemented"));
  OP_REQUIRES(ctx, half_pixel_centers_ == false,
              errors::Unimplemented(
                  "ResizeBilinearGrad with half_pixel_centers=True is "
                  "not yet implemented"));

  DataType output_dtype;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype));
  OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_dtype, &output_type_));
}

void ResizeBilinearGradOp::Compile(XlaOpKernelContext* ctx) {
  xla::XlaBuilder* b = ctx->builder();

  TensorShape input_shape = ctx->InputShape(1);
  OP_REQUIRES(ctx, input_shape.dims() == 4,
              errors::InvalidArgument("input must be 4-dimensional",
                                      input_shape.DebugString()));
  const int64_t batch = input_shape.dim_size(0);
  std::vector<int64> in_size = {input_shape.dim_size(1),
                                input_shape.dim_size(2)};
  const int64_t channels = input_shape.dim_size(3);
  OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0,
              errors::InvalidArgument("input size must be positive, got [",
                                      in_size[0], ",", in_size[1], "]"));

  TensorShape grad_shape = ctx->InputShape(0);
  OP_REQUIRES(ctx, grad_shape.dims() == 4,
              errors::InvalidArgument("gradient must be 4-dimensional",
                                      grad_shape.DebugString()));
  const int64_t grad_batch = grad_shape.dim_size(0);
  const std::vector<int64> grad_size = {grad_shape.dim_size(1),
                                        grad_shape.dim_size(2)};
  const int64_t grad_channels = grad_shape.dim_size(3);
  OP_REQUIRES(ctx, batch == grad_batch,
              errors::InvalidArgument(
                  "activations and gradients must have the same batch size (",
                  batch, " vs. ", grad_batch, ")"));
  OP_REQUIRES(ctx, grad_size[0] > 0 && grad_size[1] > 0,
              errors::InvalidArgument("gradient size must be positive, got [",
                                      grad_size[0], ",", grad_size[1], "]"));
  OP_REQUIRES(
      ctx, channels == grad_channels,
      errors::InvalidArgument(
          "activations and gradients must have the same number of channels (",
          channels, " vs. ", grad_channels, ")"));

  const int num_spatial_dims = 2;

  xla::XlaOp grad = ctx->Input(0);

  xla::XlaOp output = grad;
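  // The loop below walks in_size up to grad_size. If neither input dimension
  // is 1 and, for both dimensions, (grad_size - 1) is a multiple of
  // 2 * (in_size - 1) with a quotient greater than 1, it first takes an
  // intermediate doubling step to (in_size - 1) * 2 + 1; this keeps the kernel
  // of each backward convolution small. Otherwise it resizes directly to
  // grad_size.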
  while (in_size != grad_size) {
    if (in_size[0] != 1 && in_size[1] != 1) {
      std::vector<float> k = {
          (static_cast<float>(grad_size[0]) - 1) / ((in_size[0] - 1) * 2),
          (static_cast<float>(grad_size[1]) - 1) / ((in_size[1] - 1) * 2)};
      if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) &&
          k[0] > 1 && k[1] > 1) {
        std::vector<int64> next_grad_size = {(in_size[0] - 1) * 2 + 1,
                                             (in_size[1] - 1) * 2 + 1};
        output = ResizeUsingDilationAndConvolutionGradOp(
            b, grad, xla::F32, num_spatial_dims, in_size, next_grad_size,
            channels, align_corners_, true);
        grad = output;
        in_size = next_grad_size;
      } else {
        output = ResizeUsingDilationAndConvolutionGradOp(
            b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels,
            align_corners_, true);
        in_size = grad_size;
      }
    } else {
      output = ResizeUsingDilationAndConvolutionGradOp(
          b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels,
          align_corners_, true);
      in_size = grad_size;
    }
  }

  output = xla::ConvertElementType(output, output_type_);
  ctx->SetOutput(0, output);
}

REGISTER_XLA_OP(Name("ResizeBilinearGrad"), ResizeBilinearGradOp);

}  // namespace tensorflow