/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/image_resize_ops.h"

#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {
namespace {

// We implement bilinear interpolation by upsampling followed by convolution.
// The basic idea is as follows. To scale from NxN to RxR:
//
// 1. S := (N - 1) / gcd(N-1, R-1)
// 2. k := (R - 1) / gcd(N-1, R-1)
// 3. Convolution((2k-1)x(2k-1), stride=S, lhs_dilation=k, padding=k-1)
//
// For example, to scale from 7x7 -> 15x15:
//
// 1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3
// 2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7
// 3. Convolution(13x13, stride=3, lhs_dilation=7, padding=6)
//
// The 7x7 -> 15x15 case is much too large to write out in full as an
// example. The smallest interesting example is 3x3 -> 4x4.
//
// S := 2
// k := 3
//
// 00 03 06    00 00 00 00 00 00 00 00 00 00 00    00 02 04 06
// 09 12 15 -> 00 00 00 00 00 00 00 00 00 00 00 -> 06 08 10 12
// 18 21 24    00 00 00 00 00 03 00 00 06 00 00    12 14 16 18
//             00 00 00 00 00 00 00 00 00 00 00    18 20 22 24
//             00 00 00 00 00 00 00 00 00 00 00
//             00 00 09 00 00 12 00 00 15 00 00
//             00 00 00 00 00 00 00 00 00 00 00
//             00 00 00 00 00 00 00 00 00 00 00
//             00 00 18 00 00 21 00 00 24 00 00
//             00 00 00 00 00 00 00 00 00 00 00
//             00 00 00 00 00 00 00 00 00 00 00
//
// with the following convolutional kernel, with stride [2, 2]:
//        1 2 3 2 1
//        2 4 6 4 2
// 1/9 *  3 6 9 6 3
//        2 4 6 4 2
//        1 2 3 2 1
// Note that the convolution kernel matrix is separable, so we can instead
// apply two consecutive 1D kernels of size 2k-1, one along each axis.
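//
// For example (illustrative): the 5x5 kernel above is the outer product of
// 1/3 * [1 2 3 2 1] with itself; its centre entry is (3/3) * (3/3) = 9/9 and
// its corner entries are (1/3) * (1/3) = 1/9. Convolving once with the 2D
// kernel is therefore equivalent to convolving with the 1D kernel along one
// axis and then along the other.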

// Computes the size of the convolutional kernel and stride to use when
// resizing from in_size to out_size.
struct ResizeConvolutionDims {
  // Size of the kernel to use.
  std::vector<int64> kernel_size;  // k

  // Stride of the convolution to use.
  std::vector<int64> stride;  // S
};
ResizeConvolutionDims ComputeResizeConvolutionParameters(
    absl::Span<const int64> in_size, absl::Span<const int64> out_size,
    bool align_corners) {
  CHECK_EQ(in_size.size(), out_size.size());
  int num_spatial_dims = in_size.size();
  ResizeConvolutionDims dims;
  dims.kernel_size.resize(num_spatial_dims);
  dims.stride.resize(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    if (in_size[i] == 1) {
      // We must handle input size 1 specially because XLA convolution does
      // not allow stride 0.
      dims.stride[i] = dims.kernel_size[i] = 1;
    } else if (out_size[i] == 1) {
      // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first
      // entry before resizing.
      dims.stride[i] = dims.kernel_size[i] = 1;
    } else {
      // The scaling factor changes depending on the alignment of corners.
      const int64 in_size_factor = align_corners ? in_size[i] - 1 : in_size[i];
      const int64 out_size_factor =
          align_corners ? out_size[i] - 1 : out_size[i];

      int64 gcd = MathUtil::GCD(static_cast<uint64>(in_size_factor),
                                static_cast<uint64>(out_size_factor));
      dims.stride[i] = in_size_factor / gcd;
      dims.kernel_size[i] = out_size_factor / gcd;
    }
  }
  return dims;
}
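
// Worked example (illustrative): resizing 3x3 -> 4x4.
//   align_corners=true : factors are (3-1)=2 and (4-1)=3, gcd=1, so
//                        stride S = 2 and kernel size k = 3, matching the
//                        example at the top of this file.
//   align_corners=false: factors are 3 and 4, gcd=1, so stride S = 3 and
//                        kernel size k = 4.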

// The upper padding of the input needed by ConvGeneralDilated calls is
// determined by solving two related relationships (assuming no rhs dilation,
// i.e. rhs_dilation == 1):
// 1. dilated_input_dim = lower_padding + upper_padding
//                        + lhs_dilation * (in_size - 1) + 1
// 2. dilated_input_dim = (2 * dims.kernel_size - 1)
//                        + dims.stride * (out_size - 1)
int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size,
                            int64 stride) {
  int64 padding = (2 * kernel_size - 1) + (out_size - 1) * stride -
                  (kernel_size - 1) - 1 - (kernel_size * (in_size - 1));

  return padding;
}
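
// Worked example (illustrative): for the 3x3 -> 4x4 align_corners resize
// above (in_size = 3, out_size = 4, kernel_size = 3, stride = 2):
//   padding = (2*3 - 1) + (4 - 1)*2 - (3 - 1) - 1 - 3*(3 - 1)
//           = 5 + 6 - 2 - 1 - 6 = 2,
// which matches the two trailing rows/columns of zeros in the 11x11
// dilated-and-padded input shown at the top of this file.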

// Form a 2D convolution kernel like:
//        1 2 3 2 1
//        2 4 6 4 2
// 1/9 *  3 6 9 6 3
//        2 4 6 4 2
//        1 2 3 2 1
// by multiplying two 1D kernels of the form:
// 1/3 * [1 2 3 2 1]
// If the 2D kernel would be very large, the 1D kernel can instead be applied
// once along each dimension; since the 2D kernel is separable, this produces
// the same result at a lower computational cost.
xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder,
                                xla::PrimitiveType type, int64 n) {
  std::vector<float> kernel(n * 2 - 1);
  for (int64 i = 0; i < n; ++i) {
    float v = (i + 1.0f) / n;
    kernel[i] = v;
    kernel[n * 2 - 2 - i] = v;
  }
  return xla::ConvertElementType(xla::ConstantR1<float>(builder, kernel), type);
}
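
// Worked example (illustrative): for n = 3 the loop above fills
//   kernel = {1/3, 2/3, 1, 2/3, 1/3} = 1/3 * [1 2 3 2 1],
// the triangular 1D bilinear kernel shown in the comment above.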

// Unlike the bilinear kernel, which is triangular, the nearest-neighbor
// kernel is a square. For example, a 1D kernel with n=3 would look like
// [0 1 1 1 0]
// and n=4 would look like
// [0 0 1 1 1 1 0].
// Note that in the second case, the kernel is not symmetric and we default
// to the right (because an existing non-TPU kernel for nearest-neighbor
// resize already chose to default to the right, so we want to be consistent).
xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder,
                                       xla::PrimitiveType type, int64 n) {
  std::vector<float> kernel(n * 2 - 1, 0.0f);
  std::fill(&kernel[n / 2], &kernel[(3 * n) / 2], 1.0f);

  return xla::ConvertElementType(xla::ConstantR1<float>(builder, kernel), type);
}
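
// Worked example (illustrative): for n = 4 the kernel has length 2*4 - 1 = 7
// and std::fill sets indices [n/2, 3n/2) = [2, 6), giving
//   kernel = [0 0 1 1 1 1 0],
// i.e. the right-leaning kernel described in the comment above.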

// Kernels with more than 16 total spatial elements are considered expensive,
// and the kernel should instead be applied to each dimension independently.
const int64 kMax2DKernelSize = 16;

xla::XlaOp MakeGeneralResizeKernel(xla::XlaBuilder* builder,
                                   xla::PrimitiveType type,
                                   absl::Span<const int64> kernel_size,
                                   int64 channels, bool is_kernel_bilinear) {
  auto make_kernel_func =
      is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel;

  std::vector<int64> depthwise_kernel_sizes = {
      (2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1};
  auto depthwise_kernel =
      xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[1]),
                          depthwise_kernel_sizes, /*broadcast_dimensions=*/{1});

  return xla::Mul(depthwise_kernel,
                  make_kernel_func(builder, type, kernel_size[0]),
                  /*broadcast_dimensions=*/{0});
}

xla::XlaOp MakeGeneralResizeKernelInDim(xla::XlaBuilder* builder,
                                        xla::PrimitiveType type,
                                        absl::Span<const int64> kernel_size,
                                        int64 channels, int64 dim,
                                        bool is_kernel_bilinear) {
  auto make_kernel_func =
      is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel;

  std::vector<int64> depthwise_kernel_sizes = {
      dim == 0 ? (2 * kernel_size[0] - 1) : 1,
      dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1};
  return xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[dim]),
                             depthwise_kernel_sizes,
                             /*broadcast_dimensions=*/{dim});
}

xla::XlaOp BroadcastSpatialDimensions(xla::XlaBuilder* builder,
                                      const xla::XlaOp& input,
                                      int32 spatial_dimensions_offset,
                                      absl::Span<const int64> in_size,
                                      absl::Span<const int64> out_size) {
  // Add broadcasts to handle expanding from a size == 1 dimension to a
  // size > 1 dimension.
  auto broadcast_shape_or_status = builder->GetShape(input);
  if (!broadcast_shape_or_status.ok()) {
    return builder->ReportError(broadcast_shape_or_status.status());
  }
  xla::Shape broadcast_shape = broadcast_shape_or_status.ValueOrDie();
  for (int32 i = 0; i < in_size.size(); ++i) {
    if (in_size[i] == 1 && out_size[i] > 1) {
      broadcast_shape.set_dimensions(spatial_dimensions_offset + i,
                                     out_size[i]);
    }
  }
  return xla::BroadcastInDim(input, broadcast_shape.dimensions(),
                             /*broadcast_dimensions=*/{0, 1, 2, 3});
}

xla::XlaOp ResizeUsingDilationAndConvolution(
    xla::XlaBuilder* builder, const xla::XlaOp& input, xla::PrimitiveType type,
    const int num_spatial_dims, absl::Span<const int64> in_size,
    absl::Span<const int64> out_size, const int64 channels,
    const bool align_corners, bool is_kernel_bilinear) {
  // Picture for a 1x3 to 1x4 bilinear resize:
  // stride = 2, kernel size = 3
  // Input:
  //   3 6 9
  // Input with dilation and padding:
  //   0 0 3 0 0 6 0 0 9 0 0
  // Convolution kernel:
  //   1/3 * [1 2 3 2 1]
  // Output:
  //   3 5 7 9
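  //
  // Worked check (illustrative): sliding the kernel over the dilated input
  // with stride 2 gives
  //   output[0] = (1*0 + 2*0 + 3*3 + 2*0 + 1*0) / 3 = 3
  //   output[1] = (1*3 + 2*0 + 3*0 + 2*6 + 1*0) / 3 = 5
  //   output[2] = (1*0 + 2*6 + 3*0 + 2*0 + 1*9) / 3 = 7
  //   output[3] = (1*0 + 2*0 + 3*9 + 2*0 + 1*0) / 3 = 9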
  xla::ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(0);
  dimension_numbers.set_output_batch_dimension(0);
  dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
  dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimension_numbers.add_input_spatial_dimensions(1 + i);
    dimension_numbers.add_output_spatial_dimensions(1 + i);
    dimension_numbers.add_kernel_spatial_dimensions(i);
  }
  dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
  dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);

  ResizeConvolutionDims dims =
      ComputeResizeConvolutionParameters(in_size, out_size, align_corners);

  if (dims.kernel_size[0] * dims.kernel_size[1] >
      kMax2DKernelSize * kMax2DKernelSize) {
    BroadcastOptimizationRemark(
        XlaOptimizationRemark::SLOW_IMAGE_RESIZE_DIMENSIONS,
        absl::StrFormat("%dx%d", dims.kernel_size[0], dims.kernel_size[1]))
        .IgnoreError();
  }

  xla::XlaOp output;

  // Concatenation and padding below currently assume num_spatial_dims is 2 to
  // prevent needless code complexity.
  CHECK_EQ(num_spatial_dims, 2)
      << "ResizeUsingDilationAndConvolution pads only 2 dimensions currently.";
  std::vector<int64> upper_padding(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    upper_padding[i] = dims.kernel_size[i] - 1;
  }
  xla::XlaOp input_data = input;

  if (!align_corners) {
    // When TensorFlow does not align corners, the resize indexing can access
    // beyond the upper bound and is instead clamped to prevent out-of-bounds
    // reads. This is conceptually the same as extending the edges of the
    // input. We emulate this by copying the last row/column of the input.
    // Calculate what padding would be needed, then determine how far to
    // extend the border before lhs dilation.
    std::vector<int64> num_extended(num_spatial_dims);
    upper_padding[0] = CalculateUpperPadding(
        in_size[0], out_size[0], dims.kernel_size[0], dims.stride[0]);
    upper_padding[1] = CalculateUpperPadding(
        in_size[1], out_size[1], dims.kernel_size[1], dims.stride[1]);
    num_extended[0] = upper_padding[0] / (dims.kernel_size[0]);
    num_extended[1] = upper_padding[1] / (dims.kernel_size[1]);

    const int64 batch_dim_size =
        builder->GetShape(input).ValueOrDie().dimensions(0);
    if (num_extended[0] > 0) {
      auto slice = xla::Slice(
          input_data, {0, in_size[0] - 1, 0, 0},
          {batch_dim_size, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
      for (int i = 0; i < num_extended[0]; i++) {
        input_data = xla::ConcatInDim(builder, {input_data, slice}, 1);
      }
    }

    if (num_extended[1] > 0) {
      auto slice = xla::Slice(
          input_data, {0, 0, in_size[1] - 1, 0},
          {batch_dim_size, in_size[0] + num_extended[0], in_size[1], channels},
          {1, 1, 1, 1});
      for (int i = 0; i < num_extended[1]; i++) {
        input_data = xla::ConcatInDim(builder, {input_data, slice}, 2);
      }
    }

    // Setting in_size to (in_size + num_extended) due to the above Slice and
    // ConcatInDim. Recalculate needed padding after the above Slice/Concat.
    upper_padding[0] =
        CalculateUpperPadding(in_size[0] + num_extended[0], out_size[0],
                              dims.kernel_size[0], dims.stride[0]);
    upper_padding[1] =
        CalculateUpperPadding(in_size[1] + num_extended[1], out_size[1],
                              dims.kernel_size[1], dims.stride[1]);
  }
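
  // Worked example (illustrative): resizing 3 -> 4 with align_corners=false
  // gives kernel_size = 4 and stride = 3. The initial upper padding is
  //   CalculateUpperPadding(3, 4, 4, 3) = 7 + 9 - 3 - 1 - 8 = 4,
  // so num_extended = 4 / 4 = 1 and one copy of the last row/column is
  // appended. Recomputing with the extended input gives
  //   CalculateUpperPadding(4, 4, 4, 3) = 7 + 9 - 3 - 1 - 12 = 0.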

  // Split convolutions into independent dimensions if they would be a very
  // large kernel or if one or more of the dimensions are already equal.
  bool decompose_resize =
      in_size[0] == out_size[0] || in_size[1] == out_size[1] ||
      dims.kernel_size[0] * dims.kernel_size[1] >= kMax2DKernelSize;
  if (!decompose_resize) {
    xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size,
                                                channels, is_kernel_bilinear);
    output =
        xla::ConvGeneralDilated(input_data, kernel, dims.stride,
                                /*padding=*/
                                {{dims.kernel_size[0] - 1, upper_padding[0]},
                                 {dims.kernel_size[1] - 1, upper_padding[1]}},
                                /*lhs_dilation=*/dims.kernel_size,
                                /*rhs_dilation=*/{1, 1}, dimension_numbers,
                                /*feature_group_count=*/channels);
  } else {
    output = input_data;
    if (in_size[0] != out_size[0]) {
      xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim(
          builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear);
      output = xla::ConvGeneralDilated(
          output, kernel0, {dims.stride[0], 1},
          /*padding=*/
          {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}},
          /*lhs_dilation=*/{dims.kernel_size[0], 1},
          /*rhs_dilation=*/{1, 1}, dimension_numbers,
          /*feature_group_count=*/channels);
    }

    if (in_size[1] != out_size[1]) {
      xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim(
          builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear);
      output = xla::ConvGeneralDilated(
          output, kernel1, {1, dims.stride[1]},
          /*padding=*/
          {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}},
          /*lhs_dilation=*/{1, dims.kernel_size[1]},
          /*rhs_dilation=*/{1, 1}, dimension_numbers,
          /*feature_group_count=*/channels);
    }
  }

  // Add broadcasts to handle expanding from a size == 1 dimension to a
  // size > 1 dimension.
  return BroadcastSpatialDimensions(
      builder, output, /*spatial_dimensions_offset=*/1, in_size, out_size);
}

xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(
    xla::XlaBuilder* builder, const xla::XlaOp& grad, xla::PrimitiveType type,
    const int num_spatial_dims, absl::Span<const int64> in_size,
    absl::Span<const int64> grad_size, const int64 channels,
    const bool align_corners, bool is_kernel_bilinear) {
  ResizeConvolutionDims dims =
      ComputeResizeConvolutionParameters(in_size, grad_size, align_corners);

  // To form the backward convolution, we keep the kernel unchanged (it is
  // already symmetric) and swap the roles of strides and LHS dilation.
  xla::ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(0);
  dimension_numbers.set_output_batch_dimension(0);
  dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
  dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimension_numbers.add_input_spatial_dimensions(i + 1);
    dimension_numbers.add_output_spatial_dimensions(i + 1);
    dimension_numbers.add_kernel_spatial_dimensions(i);
  }
  dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
  dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
  xla::XlaOp output;
  if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
    xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size,
                                                channels, is_kernel_bilinear);

    // Broadcast the input kernel where the forward op expanded from a
    // size == 1 dimension to a size > 1 dimension. This has the effect of
    // summing the gradient contributions in that dimension.
    kernel = BroadcastSpatialDimensions(
        builder, kernel, /*spatial_dimensions_offset=*/0, in_size, grad_size);

    output = xla::ConvGeneralDilated(
        grad, kernel, /*window_strides=*/dims.kernel_size,
        /*padding=*/
        {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
         {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
        /*lhs_dilation=*/dims.stride,
        /*rhs_dilation=*/{1, 1}, dimension_numbers,
        /*feature_group_count=*/channels);
  } else {
    xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim(
        builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear);
    xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim(
        builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear);

    // Broadcast the input kernel where the forward op expanded from a
    // size == 1 dimension to a size > 1 dimension. This has the effect of
    // summing the gradient contributions in that dimension.
    if (in_size[0] == 1 && grad_size[0] > 1) {
      kernel0 = BroadcastSpatialDimensions(builder, kernel0,
                                           /*spatial_dimensions_offset=*/0, {1},
                                           {grad_size[0]});
    }
    if (in_size[1] == 1 && grad_size[1] > 1) {
      kernel1 = BroadcastSpatialDimensions(builder, kernel1,
                                           /*spatial_dimensions_offset=*/0,
                                           in_size, grad_size);
    }

    output = xla::ConvGeneralDilated(
        grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1},
        /*padding=*/
        {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
        /*lhs_dilation=*/{dims.stride[0], 1},
        /*rhs_dilation=*/{1, 1}, dimension_numbers,
        /*feature_group_count=*/channels);

    output = xla::ConvGeneralDilated(
        output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]},
        /*padding=*/
        {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
        /*lhs_dilation=*/{1, dims.stride[1]},
        /*rhs_dilation=*/{1, 1}, dimension_numbers,
        /*feature_group_count=*/channels);
  }

  // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i.
  // Opposite of the slice performed by the forward op.
  xla::PaddingConfig padding = xla::MakeNoPaddingConfig(4);
  bool pad_output = false;
  for (int i = 0; i < num_spatial_dims; ++i) {
    if (in_size[i] > 1 && grad_size[i] == 1) {
      pad_output = true;
      padding.mutable_dimensions(1 + i)->set_edge_padding_high(in_size[i] - 1);
    }
  }
  if (pad_output) {
    output = xla::Pad(output, xla::Zero(builder, type), padding);
  }
  return output;
}

void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_,
                    bool is_kernel_bilinear) {
  xla::XlaBuilder* b = ctx->builder();

  TensorShape input_shape = ctx->InputShape(0);
  OP_REQUIRES(ctx, input_shape.dims() == 4,
              errors::InvalidArgument("input must be 4-dimensional",
                                      input_shape.DebugString()));
  // First dimension always assumed to be batch.
  const int64 batch = input_shape.dim_size(0);
  std::vector<int64> in_size = {input_shape.dim_size(1),
                                input_shape.dim_size(2)};
  // Last/4th dimension always assumed to be num channels.
  const int64 channels = input_shape.dim_size(3);
  OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0,
              errors::InvalidArgument("input size must be positive, got [",
                                      in_size[0], ",", in_size[1], "]"));

  std::vector<int64> out_size;
  OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size));
  OP_REQUIRES(ctx, out_size.size() == 2,
              errors::InvalidArgument("output size must be length 2, got ",
                                      out_size.size()));
  OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0,
              errors::InvalidArgument("output size must be positive, got [",
                                      out_size[0], ",", out_size[1], "]"));

  const int num_spatial_dims = 2;

  xla::XlaOp input = ctx->Input(0);
  xla::PrimitiveType input_type = ctx->input_xla_type(0);

  // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in
  // dimension i.
  bool slice_input = false;
  for (int i = 0; i < num_spatial_dims; ++i) {
    if (in_size[i] > 1 && out_size[i] == 1) {
      // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first
      // entry before resizing.
      slice_input = true;
      in_size[i] = 1;
    }
  }
  if (slice_input) {
    input = xla::Slice(input, {0, 0, 0, 0},
                       {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
  }

  // The output is always float when 'is_kernel_bilinear' is true. Integer
  // inputs are also converted to float, because XLA integer convolution on
  // cuDNN is either not supported or not allowed directly.
  xla::PrimitiveType original_input_type = input_type;
  if (is_kernel_bilinear || (xla::primitive_util::IsIntegralType(input_type))) {
    input = xla::ConvertElementType(input, xla::F32);
    input_type = xla::F32;
  }

  for (int dim = 0; dim < in_size.size(); ++dim) {
    // If the pairwise_distance function more accurately estimated performance,
    // this threshold could be reduced.
    constexpr int64 kSmallDimThreshold = 1 << 10;
    if (in_size[dim] > out_size[dim] || out_size[dim] < kSmallDimThreshold) {
      std::vector<int64> next_size = in_size;
      next_size[dim] = out_size[dim];
      input = ResizeUsingDilationAndConvolution(
          b, input, input_type, num_spatial_dims, in_size, next_size, channels,
          align_corners_, is_kernel_bilinear);
      in_size[dim] = next_size[dim];
    }
  }

  // This lambda approximates the cost of a bilinear resize from src_size to
  // dst_size. The accuracy is okay, but empirically, the algorithm makes some
  // suboptimal choices. A better cost model would improve performance.
  auto pairwise_distance = [align_corners_](int64 src_size, int64 dst_size) {
    auto params = ComputeResizeConvolutionParameters({src_size}, {dst_size},
                                                     align_corners_);
    return params.stride[0];
  };
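
  // Worked example (illustrative): with align_corners=false,
  // pairwise_distance(4, 10) = 4 / gcd(4, 10) = 2, i.e. the stride of a
  // direct 4 -> 10 resize, while pairwise_distance(4, 8) = 1. The loop below
  // uses these costs to pick, per dimension, the cheapest chain of
  // intermediate sizes from in_size[dim] up to out_size[dim], and applies one
  // ResizeUsingDilationAndConvolution call per step of that chain.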

  for (int dim = 0; dim < in_size.size(); ++dim) {
    std::vector<int64> distances(out_size[dim] + 1);
    std::vector<int64> next_step(out_size[dim] + 1);
    for (int64 i = distances.size() - 2; i >= in_size[dim]; --i) {
      distances[i] = INT64_MAX;
      for (int64 j = i + 1; j < distances.size(); ++j) {
        int64 distance = pairwise_distance(i, j) + distances[j];
        if (distance < distances[i]) {
          distances[i] = distance;
          next_step[i] = j;
        }
      }
    }

    while (in_size[dim] != out_size[dim]) {
      auto next_size = in_size;
      next_size[dim] = next_step[in_size[dim]];
      input = ResizeUsingDilationAndConvolution(
          b, input, input_type, num_spatial_dims, in_size, next_size, channels,
          align_corners_, is_kernel_bilinear);
      in_size[dim] = next_size[dim];
    }
  }

  // Bilinear always outputs float, but nearest neighbor keeps the original
  // type.
  if (!is_kernel_bilinear && original_input_type != input_type) {
    input = xla::ConvertElementType(input, original_input_type);
  }
  ctx->SetOutput(0, input);
}
}  // namespace

ResizeNearestNeighborOp::ResizeNearestNeighborOp(OpKernelConstruction* ctx)
    : XlaOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
  OP_REQUIRES(
      ctx, align_corners_ == true,
      errors::Unimplemented("ResizeNearestNeighbor with align_corners=False "
                            "is not yet implemented"));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
  OP_REQUIRES(ctx, half_pixel_centers_ == false,
              errors::Unimplemented(
                  "ResizeNearestNeighbor with half_pixel_centers=True is "
                  "not yet implemented"));
}

void ResizeNearestNeighborOp::Compile(XlaOpKernelContext* ctx) {
  GeneralCompile(ctx, align_corners_, is_kernel_bilinear_);
}

REGISTER_XLA_OP(Name("ResizeNearestNeighbor").CompileTimeConstantInput("size"),
                ResizeNearestNeighborOp);

ResizeBilinearOp::ResizeBilinearOp(OpKernelConstruction* ctx)
    : XlaOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
  OP_REQUIRES(
      ctx, half_pixel_centers_ == false,
      errors::Unimplemented("ResizeBilinear with half_pixel_centers=True is "
                            "not yet implemented"));
}

void ResizeBilinearOp::Compile(XlaOpKernelContext* ctx) {
  GeneralCompile(ctx, align_corners_, is_kernel_bilinear_);
}

REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstantInput("size"),
                ResizeBilinearOp);

ResizeBilinearGradOp::ResizeBilinearGradOp(OpKernelConstruction* ctx)
    : XlaOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
  OP_REQUIRES(
      ctx, align_corners_ == true,
      errors::Unimplemented("ResizeBilinearGrad with align_corners=False is "
                            "not yet implemented"));
  OP_REQUIRES(ctx, half_pixel_centers_ == false,
              errors::Unimplemented(
                  "ResizeBilinearGrad with half_pixel_centers=True is "
                  "not yet implemented"));

  DataType output_dtype;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype));
  OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_dtype, &output_type_));
}

void ResizeBilinearGradOp::Compile(XlaOpKernelContext* ctx) {
  xla::XlaBuilder* b = ctx->builder();

  TensorShape input_shape = ctx->InputShape(1);
  OP_REQUIRES(ctx, input_shape.dims() == 4,
              errors::InvalidArgument("input must be 4-dimensional",
                                      input_shape.DebugString()));
  const int64 batch = input_shape.dim_size(0);
  std::vector<int64> in_size = {input_shape.dim_size(1),
                                input_shape.dim_size(2)};
  const int64 channels = input_shape.dim_size(3);
  OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0,
              errors::InvalidArgument("input size must be positive, got [",
                                      in_size[0], ",", in_size[1], "]"));

  TensorShape grad_shape = ctx->InputShape(0);
  OP_REQUIRES(ctx, grad_shape.dims() == 4,
              errors::InvalidArgument("gradient must be 4-dimensional",
                                      grad_shape.DebugString()));
  const int64 grad_batch = grad_shape.dim_size(0);
  const std::vector<int64> grad_size = {grad_shape.dim_size(1),
                                        grad_shape.dim_size(2)};
  const int64 grad_channels = grad_shape.dim_size(3);
  OP_REQUIRES(ctx, batch == grad_batch,
              errors::InvalidArgument(
                  "activations and gradients must have the same batch size (",
                  batch, " vs. ", grad_batch, ")"));
  OP_REQUIRES(ctx, grad_size[0] > 0 && grad_size[1] > 0,
              errors::InvalidArgument("gradient size must be positive, got [",
                                      grad_size[0], ",", grad_size[1], "]"));
  OP_REQUIRES(
      ctx, channels == grad_channels,
      errors::InvalidArgument(
          "activations and gradients must have the same number of channels (",
          channels, " vs. ", grad_channels, ")"));

  const int num_spatial_dims = 2;

  xla::XlaOp grad = ctx->Input(0);

  xla::XlaOp output = grad;
  while (in_size != grad_size) {
    if (in_size[0] != 1 && in_size[1] != 1) {
      std::vector<float> k = {
          (static_cast<float>(grad_size[0]) - 1) / ((in_size[0] - 1) * 2),
          (static_cast<float>(grad_size[1]) - 1) / ((in_size[1] - 1) * 2)};
      if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) &&
          k[0] > 1 && k[1] > 1) {
        std::vector<int64> next_grad_size = {(in_size[0] - 1) * 2 + 1,
                                             (in_size[1] - 1) * 2 + 1};
        output = ResizeUsingDilationAndConvolutionGradOp(
            b, grad, xla::F32, num_spatial_dims, in_size, next_grad_size,
            channels, align_corners_, true);
        grad = output;
        in_size = next_grad_size;
      } else {
        output = ResizeUsingDilationAndConvolutionGradOp(
            b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels,
            align_corners_, true);
        in_size = grad_size;
      }
    } else {
      output = ResizeUsingDilationAndConvolutionGradOp(
          b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels,
          align_corners_, true);
      in_size = grad_size;
    }
  }
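
  // Worked example (illustrative): with in_size = {3, 3} and
  // grad_size = {9, 9}, k = {2, 2} is an integral factor greater than 1, so
  // the gradient is reduced in stages, 9x9 -> 5x5 -> 3x3, with one
  // ResizeUsingDilationAndConvolutionGradOp call per stage. With
  // grad_size = {8, 8}, k = {1.75, 1.75} is not integral and a single direct
  // call is used instead.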

  output = xla::ConvertElementType(output, output_type_);
  ctx->SetOutput(0, output);
}

REGISTER_XLA_OP(Name("ResizeBilinearGrad"), ResizeBilinearGradOp);

}  // namespace tensorflow