• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/tf2xla/kernels/image_resize_ops.h"
16 
17 #include "absl/strings/str_format.h"
18 #include "absl/types/span.h"
19 #include "tensorflow/compiler/jit/xla_activity.pb.h"
20 #include "tensorflow/compiler/jit/xla_activity_listener.h"
21 #include "tensorflow/compiler/tf2xla/shape_util.h"
22 #include "tensorflow/compiler/tf2xla/type_util.h"
23 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
24 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
26 #include "tensorflow/compiler/xla/array4d.h"
27 #include "tensorflow/compiler/xla/client/lib/constants.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/core/framework/kernel_def_builder.h"
30 #include "tensorflow/core/framework/register_types.h"
31 #include "tensorflow/core/lib/math/math_util.h"
32 
33 namespace tensorflow {
34 namespace {
35 
// We implement bilinear interpolation by upsampling followed by convolution.
// The basic idea is as follows. To scale from NxN to RxR (shown for
// align_corners=True; without it, N and R take the place of N-1 and R-1 when
// computing S and k):
//
//    1. S := (N - 1) /  gcd(N-1, R-1)
//    2. k := (R - 1) /  gcd(N-1, R-1)
//    3. Convolution((2k-1)x(2k-1), stride=S, lhs_dilation=k, padding=k-1)
//
// For example, to scale from 7x7 -> 15x15:
//
//    1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3
//    2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7
//    3. Convolution(13x13, stride=3, lhs_dilation=7, padding=6)
48 //
49 //
50 // The 7x7 -> 15x15 case is much too large to write out in full as an
51 // example. The smallest interesting example is 3x3 -> 4x4.
52 //
53 // S := 2
54 // k := 3
55 //
56 // 00 03 06    00 00 00 00 00 00 00 00 00 00 00      00 02 04 06
57 // 09 12 15 -> 00 00 00 00 00 00 00 00 00 00 00   -> 06 08 10 12
58 // 18 21 24    00 00 00 00 00 03 00 00 06 00 00      12 14 16 18
59 //             00 00 00 00 00 00 00 00 00 00 00      18 20 22 24
60 //             00 00 00 00 00 00 00 00 00 00 00
61 //             00 00 09 00 00 12 00 00 15 00 00
62 //             00 00 00 00 00 00 00 00 00 00 00
63 //             00 00 00 00 00 00 00 00 00 00 00
64 //             00 00 18 00 00 21 00 00 24 00 00
65 //             00 00 00 00 00 00 00 00 00 00 00
66 //             00 00 00 00 00 00 00 00 00 00 00
67 //
68 // with the following convolutional kernel, with stride [2, 2]:
69 //       1 2 3 2 1
70 //       2 4 6 4 2
71 // 1/9 * 3 6 9 6 3
72 //       2 4 6 4 2
73 //       1 2 3 2 1
// Note that the convolution kernel matrix is separable, so we can instead
// apply two consecutive 1D kernels of length 2k-1, one along each axis.
76 
77 // Computes the size of the convolutional kernel and stride to use when resizing
78 // from in_size to out_size.
79 struct ResizeConvolutionDims {
80   // Size of the kernel to use.
81   std::vector<int64> kernel_size;  // k
82 
83   // Stride of the convolution to use.
84   std::vector<int64> stride;  // S
85 };
ComputeResizeConvolutionParameters(absl::Span<const int64> in_size,absl::Span<const int64> out_size,bool align_corners)86 ResizeConvolutionDims ComputeResizeConvolutionParameters(
87     absl::Span<const int64> in_size, absl::Span<const int64> out_size,
88     bool align_corners) {
89   CHECK_EQ(in_size.size(), out_size.size());
90   int num_spatial_dims = in_size.size();
91   ResizeConvolutionDims dims;
92   dims.kernel_size.resize(num_spatial_dims);
93   dims.stride.resize(num_spatial_dims);
94   for (int i = 0; i < num_spatial_dims; ++i) {
95     if (in_size[i] == 1) {
96       // We must handle input size 1 specially because XLA convolution does
97       // not allow stride 0.
98       dims.stride[i] = dims.kernel_size[i] = 1;
99     } else if (out_size[i] == 1) {
100       // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first
101       // entry before resizing.
102       dims.stride[i] = dims.kernel_size[i] = 1;
103     } else {
104       // The scaling factor changes depending on the alignment of corners.
105       const int64 in_size_factor = align_corners ? in_size[i] - 1 : in_size[i];
106       const int64 out_size_factor =
107           align_corners ? out_size[i] - 1 : out_size[i];
108 
109       int64 gcd = MathUtil::GCD(static_cast<uint64>(in_size_factor),
110                                 static_cast<uint64>(out_size_factor));
111       dims.stride[i] = in_size_factor / gcd;
112       dims.kernel_size[i] = out_size_factor / gcd;
113     }
114   }
115   return dims;
116 }
117 
118 // The upper padding of the input needed by ConvGeneralDilated calls is
119 // determined by solving two related relationships (assuming rhs_dilation == 0):
120 // 1. dilated_input_dim = lower_padding + upper_padding
121 //                        + lhs_dilation * (in_size - 1) + 1
122 // 2. dilated_input_dim = (2 * dims.kernel-size - 1)
123 //                        + dims.stride * (out_size - 1)
CalculateUpperPadding(int64 in_size,int64 out_size,int64 kernel_size,int64 stride)124 int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size,
125                             int64 stride) {
126   int64 padding = (2 * kernel_size - 1) + (out_size - 1) * stride -
127                   (kernel_size - 1) - 1 - (kernel_size * (in_size - 1));
128 
129   return padding;
130 }
131 
132 // Form a 2D convolution kernel like:
133 //       1 2 3 2 1
134 //       2 4 6 4 2
135 // 1/9 * 3 6 9 6 3
136 //       2 4 6 4 2
137 //       1 2 3 2 1
138 // by multiplying two 1D kernels of the form:
139 // 1/3 * [1 2 3 2 1]
140 // If the 2D kernel would be very large, the 1D kernel can be applied once in
141 // each dimension due to the symmetry of the kernel along all axis to reduce the
142 // computational intensity.
MakeBilinear1DKernel(xla::XlaBuilder * builder,xla::PrimitiveType type,int64 n)143 xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder,
144                                 xla::PrimitiveType type, int64 n) {
145   std::vector<float> kernel(n * 2 - 1);
146   for (int64 i = 0; i < n; ++i) {
147     float v = (i + 1.0f) / n;
148     kernel[i] = v;
149     kernel[n * 2 - 2 - i] = v;
150   }
151   return xla::ConvertElementType(xla::ConstantR1<float>(builder, kernel), type);
152 }
153 
154 // Unlike the bilinear kernel, which is triangular, the nearest neighbor
155 // kernel is a square. For example, a 1D kernel with n=3 would look like
156 // [0 1 1 1 0]
157 // and n=4 would look like
158 // [0 0 1 1 1 1 0].
159 // Note that in the second case, the kernel is not symmetric and we default
160 // to the right (because an existing non TPU kernel
161 // for nearest neighbor resize already chose to default to the right,
162 // so we want to be consistent).
MakeNearestNeighbor1DKernel(xla::XlaBuilder * builder,xla::PrimitiveType type,int64 n)163 xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder,
164                                        xla::PrimitiveType type, int64 n) {
165   std::vector<float> kernel(n * 2 - 1, 0.0f);
166   std::fill(&kernel[n / 2], &kernel[(3 * n) / 2], 1.0f);
167 
168   return xla::ConvertElementType(xla::ConstantR1<float>(builder, kernel), type);
169 }
170 
171 // Kernels with more than 16 spatial elements are considered intense and the
172 // kernel should be applied to each dimension independently.
173 const int64 kMax2DKernelSize = 16;
174 
MakeGeneralResizeKernel(xla::XlaBuilder * builder,xla::PrimitiveType type,absl::Span<const int64> kernel_size,int64 channels,bool is_kernel_bilinear)175 xla::XlaOp MakeGeneralResizeKernel(xla::XlaBuilder* builder,
176                                    xla::PrimitiveType type,
177                                    absl::Span<const int64> kernel_size,
178                                    int64 channels, bool is_kernel_bilinear) {
179   auto make_kernel_func =
180       is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel;
181 
182   std::vector<int64> depthwise_kernel_sizes = {
183       (2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1};
184   auto depthwise_kernel =
185       xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[1]),
186                           depthwise_kernel_sizes, /*broadcast_dimensions=*/{1});
187 
188   return xla::Mul(depthwise_kernel,
189                   make_kernel_func(builder, type, kernel_size[0]),
190                   /*broadcast_dimensions=*/{0});
191 }
192 
MakeGeneralResizeKernelInDim(xla::XlaBuilder * builder,xla::PrimitiveType type,absl::Span<const int64> kernel_size,int64 channels,int64 dim,bool is_kernel_bilinear)193 xla::XlaOp MakeGeneralResizeKernelInDim(xla::XlaBuilder* builder,
194                                         xla::PrimitiveType type,
195                                         absl::Span<const int64> kernel_size,
196                                         int64 channels, int64 dim,
197                                         bool is_kernel_bilinear) {
198   auto make_kernel_func =
199       is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel;
200 
201   std::vector<int64> depthwise_kernel_sizes = {
202       dim == 0 ? (2 * kernel_size[0] - 1) : 1,
203       dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1};
204   return xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[dim]),
205                              depthwise_kernel_sizes,
206                              /*broadcast_dimensions=*/{dim});
207 }
208 
BroadcastSpatialDimensions(xla::XlaBuilder * builder,const xla::XlaOp & input,int32 spatial_dimensions_offset,absl::Span<const int64> in_size,absl::Span<const int64> out_size)209 xla::XlaOp BroadcastSpatialDimensions(xla::XlaBuilder* builder,
210                                       const xla::XlaOp& input,
211                                       int32 spatial_dimensions_offset,
212                                       absl::Span<const int64> in_size,
213                                       absl::Span<const int64> out_size) {
214   // Add broadcasts to handle expanding from a size == 1 dimension to a
215   // size > 1 dimension.
216   auto broadcast_shape_or_status = builder->GetShape(input);
217   if (!broadcast_shape_or_status.ok()) {
218     return builder->ReportError(broadcast_shape_or_status.status());
219   }
220   xla::Shape broadcast_shape = broadcast_shape_or_status.ValueOrDie();
221   for (int32 i = 0; i < in_size.size(); ++i) {
222     if (in_size[i] == 1 && out_size[i] > 1) {
223       broadcast_shape.set_dimensions(spatial_dimensions_offset + i,
224                                      out_size[i]);
225     }
226   }
227   return xla::BroadcastInDim(input, broadcast_shape.dimensions(),
228                              /*broadcast_dimensions=*/{0, 1, 2, 3});
229 }
230 
ResizeUsingDilationAndConvolution(xla::XlaBuilder * builder,const xla::XlaOp & input,xla::PrimitiveType type,const int num_spatial_dims,absl::Span<const int64> in_size,absl::Span<const int64> out_size,const int64 channels,const bool align_corners,bool is_kernel_bilinear)231 xla::XlaOp ResizeUsingDilationAndConvolution(
232     xla::XlaBuilder* builder, const xla::XlaOp& input, xla::PrimitiveType type,
233     const int num_spatial_dims, absl::Span<const int64> in_size,
234     absl::Span<const int64> out_size, const int64 channels,
235     const bool align_corners, bool is_kernel_bilinear) {
236   // Picture for a 1x3 to 1x4 bilinear resize:
237   // stride = 2, kernel size = 3
238   // Input:
239   // 3 6 9
240   // Input with dilation and padding:
241   // 0 0 3 0 0 6 0 0 9 0 0
242   // Convolution kernel:
243   // 1/3 * [1 2 3 2 1]
244   // Output:
245   // 3 5 7 9
246   xla::ConvolutionDimensionNumbers dimension_numbers;
247   dimension_numbers.set_input_batch_dimension(0);
248   dimension_numbers.set_output_batch_dimension(0);
249   dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
250   dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
251   for (int i = 0; i < num_spatial_dims; ++i) {
252     dimension_numbers.add_input_spatial_dimensions(1 + i);
253     dimension_numbers.add_output_spatial_dimensions(1 + i);
254     dimension_numbers.add_kernel_spatial_dimensions(i);
255   }
256   dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
257   dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
258 
259   ResizeConvolutionDims dims =
260       ComputeResizeConvolutionParameters(in_size, out_size, align_corners);
261 
262   if (dims.kernel_size[0] * dims.kernel_size[1] >
263       kMax2DKernelSize * kMax2DKernelSize) {
264     BroadcastOptimizationRemark(
265         XlaOptimizationRemark::SLOW_IMAGE_RESIZE_DIMENSIONS,
266         absl::StrFormat("%dx%d", dims.kernel_size[0], dims.kernel_size[1]))
267         .IgnoreError();
268   }
269 
270   xla::XlaOp output;
271 
272   // Concatenation and padding below currently assumes num_spatial_dims is 2 to
273   // prevent needless code complexity.
274   CHECK_EQ(num_spatial_dims, 2)
275       << "ResizeUsingDilationAndConvolution pads only 2 dimensions currently.";
276   std::vector<int64> upper_padding(num_spatial_dims);
277   for (int i = 0; i < num_spatial_dims; ++i) {
278     upper_padding[i] = dims.kernel_size[i] - 1;
279   }
280   xla::XlaOp input_data = input;
281 
282   if (!align_corners) {
283     // When Tensorflow does not align_corners, the resize indexing can access
284     // beyond the upper bound and is instead clamped to prevent out of bounds
285     // reads. This is conceptually the same as extending the edges of the input.
286     // We emulate this by copying the last row/column of the input.
287     // Calculate what padding would be needed then determine how far to extend
288     // the border before lhs dilation.
289     std::vector<int64> num_extended(num_spatial_dims);
290     upper_padding[0] = CalculateUpperPadding(
291         in_size[0], out_size[0], dims.kernel_size[0], dims.stride[0]);
292     upper_padding[1] = CalculateUpperPadding(
293         in_size[1], out_size[1], dims.kernel_size[1], dims.stride[1]);
294     num_extended[0] = upper_padding[0] / (dims.kernel_size[0]);
295     num_extended[1] = upper_padding[1] / (dims.kernel_size[1]);
296 
297     const int64 batch_dim_size =
298         builder->GetShape(input).ValueOrDie().dimensions(0);
299     if (num_extended[0] > 0) {
300       auto slice = xla::Slice(
301           input_data, {0, in_size[0] - 1, 0, 0},
302           {batch_dim_size, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
303       for (int i = 0; i < num_extended[0]; i++) {
304         input_data = xla::ConcatInDim(builder, {input_data, slice}, 1);
305       }
306     }
307 
308     if (num_extended[1] > 0) {
309       auto slice = xla::Slice(
310           input_data, {0, 0, in_size[1] - 1, 0},
311           {batch_dim_size, in_size[0] + num_extended[0], in_size[1], channels},
312           {1, 1, 1, 1});
313       for (int i = 0; i < num_extended[1]; i++) {
314         input_data = xla::ConcatInDim(builder, {input_data, slice}, 2);
315       }
316     }
317 
318     // Setting in_size to (in_size + num_extended) due to the above Slice and
319     // ConcatInDim. Recalculate needed padding after the above Slice/Concat.
320     upper_padding[0] =
321         CalculateUpperPadding(in_size[0] + num_extended[0], out_size[0],
322                               dims.kernel_size[0], dims.stride[0]);
323     upper_padding[1] =
324         CalculateUpperPadding(in_size[1] + num_extended[1], out_size[1],
325                               dims.kernel_size[1], dims.stride[1]);
326   }
327 
328   // Split convolutions into independent dimensions if they would be a very
329   // large kernel or if one or more of the dimensions are already equal.
330   bool decompose_resize =
331       in_size[0] == out_size[0] || in_size[1] == out_size[1] ||
332       dims.kernel_size[0] * dims.kernel_size[1] >= kMax2DKernelSize;
333   if (!decompose_resize) {
334     xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size,
335                                                 channels, is_kernel_bilinear);
336     output =
337         xla::ConvGeneralDilated(input_data, kernel, dims.stride,
338                                 /*padding=*/
339                                 {{dims.kernel_size[0] - 1, upper_padding[0]},
340                                  {dims.kernel_size[1] - 1, upper_padding[1]}},
341                                 /*lhs_dilation=*/dims.kernel_size,
342                                 /*rhs_dilation=*/{1, 1}, dimension_numbers,
343                                 /*feature_group_count=*/channels);
344   } else {
345     output = input_data;
346     if (in_size[0] != out_size[0]) {
347       xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim(
348           builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear);
349       output = xla::ConvGeneralDilated(
350           output, kernel0, {dims.stride[0], 1},
351           /*padding=*/
352           {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}},
353           /*lhs_dilation=*/{dims.kernel_size[0], 1},
354           /*rhs_dilation=*/{1, 1}, dimension_numbers,
355           /*feature_group_count=*/channels);
356     }
357 
358     if (in_size[1] != out_size[1]) {
359       xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim(
360           builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear);
361       output = xla::ConvGeneralDilated(
362           output, kernel1, {1, dims.stride[1]},
363           /*padding=*/
364           {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}},
365           /*lhs_dilation=*/{1, dims.kernel_size[1]},
366           /*rhs_dilation=*/{1, 1}, dimension_numbers,
367           /*feature_group_count=*/channels);
368     }
369   }
370 
371   // Add broadcasts to handle expanding from a size == 1 dimension to a
372   // size > 1 dimension.
373   return BroadcastSpatialDimensions(
374       builder, output, /*spatial_dimensions_offset=*/1, in_size, out_size);
375 }
376 
ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder * builder,const xla::XlaOp & grad,xla::PrimitiveType type,const int num_spatial_dims,absl::Span<const int64> in_size,absl::Span<const int64> grad_size,const int64 channels,const bool align_corners,bool is_kernel_bilinear)377 xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(
378     xla::XlaBuilder* builder, const xla::XlaOp& grad, xla::PrimitiveType type,
379     const int num_spatial_dims, absl::Span<const int64> in_size,
380     absl::Span<const int64> grad_size, const int64 channels,
381     const bool align_corners, bool is_kernel_bilinear) {
382   ResizeConvolutionDims dims =
383       ComputeResizeConvolutionParameters(in_size, grad_size, align_corners);
384 
385   // To form the backward convolution, we keep the kernel unchanged (it is
386   // already symmetric) and swap the roles of strides and LHS dilation.
387   xla::ConvolutionDimensionNumbers dimension_numbers;
388   dimension_numbers.set_input_batch_dimension(0);
389   dimension_numbers.set_output_batch_dimension(0);
390   dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
391   dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
392   for (int i = 0; i < num_spatial_dims; ++i) {
393     dimension_numbers.add_input_spatial_dimensions(i + 1);
394     dimension_numbers.add_output_spatial_dimensions(i + 1);
395     dimension_numbers.add_kernel_spatial_dimensions(i);
396   }
397   dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
398   dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
399   xla::XlaOp output;
400   if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
401     xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size,
402                                                 channels, is_kernel_bilinear);
403 
404     // Broadcast the input kernel where the forward op expanded from a size == 1
405     // dimension to a size > 1 dimension. This has the effect of summing the
406     // gradient contributions in that dimension.
407     kernel = BroadcastSpatialDimensions(
408         builder, kernel, /*spatial_dimensions_offset=*/0, in_size, grad_size);
409 
410     output = xla::ConvGeneralDilated(
411         grad, kernel, /*window_strides=*/dims.kernel_size,
412         /*padding=*/
413         {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
414          {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
415         /*lhs_dilation=*/dims.stride,
416         /*rhs_dilation=*/{1, 1}, dimension_numbers,
417         /*feature_group_count=*/channels);
418   } else {
419     xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim(
420         builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear);
421     xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim(
422         builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear);
423 
424     // Broadcast the input kernel where the forward op expanded from a
425     // size == 1 dimension to a size > 1 dimension. This has the effect of
426     // summing the gradient contributions in that dimension.
427     if (in_size[0] == 1 && grad_size[0] > 1) {
428       kernel0 = BroadcastSpatialDimensions(builder, kernel0,
429                                            /*spatial_dimensions_offset=*/0, {1},
430                                            {grad_size[0]});
431     }
432     if (in_size[1] == 1 && grad_size[1] > 1) {
433       kernel1 = BroadcastSpatialDimensions(builder, kernel0,
434                                            /*spatial_dimensions_offset=*/0,
435                                            in_size, grad_size);
436     }
437 
438     output = xla::ConvGeneralDilated(
439         grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1},
440         /*padding=*/
441         {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
442         /*lhs_dilation=*/{dims.stride[0], 1},
443         /*rhs_dilation=*/{1, 1}, dimension_numbers,
444         /*feature_group_count=*/channels);
445 
446     output = xla::ConvGeneralDilated(
447         output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]},
448         /*padding=*/
449         {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
450         /*lhs_dilation=*/{1, dims.stride[1]},
451         /*rhs_dilation=*/{1, 1}, dimension_numbers,
452         /*feature_group_count=*/channels);
453   }
454 
455   // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i.
456   // Opposite of the slice performed by the forward op.
457   xla::PaddingConfig padding = xla::MakeNoPaddingConfig(4);
458   bool pad_output = false;
459   for (int i = 0; i < num_spatial_dims; ++i) {
460     if (in_size[i] > 1 && grad_size[i] == 1) {
461       pad_output = true;
462       padding.mutable_dimensions(1 + i)->set_edge_padding_high(in_size[i] - 1);
463     }
464   }
465   if (pad_output) {
466     output = xla::Pad(output, xla::Zero(builder, type), padding);
467   }
468   return output;
469 }
470 
GeneralCompile(XlaOpKernelContext * ctx,bool align_corners_,bool is_kernel_bilinear)471 void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_,
472                     bool is_kernel_bilinear) {
473   xla::XlaBuilder* b = ctx->builder();
474 
475   TensorShape input_shape = ctx->InputShape(0);
476   OP_REQUIRES(ctx, input_shape.dims() == 4,
477               errors::InvalidArgument("input must be 4-dimensional",
478                                       input_shape.DebugString()));
479   // First dimension always assumed to be batch
480   const int64 batch = input_shape.dim_size(0);
481   std::vector<int64> in_size = {input_shape.dim_size(1),
482                                 input_shape.dim_size(2)};
483   // Last/4th dimension always assumed to be num channels
484   const int64 channels = input_shape.dim_size(3);
485   OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0,
486               errors::InvalidArgument("input size must be positive, got [",
487                                       in_size[0], ",", in_size[1], "]"));
488 
489   std::vector<int64> out_size;
490   OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size));
491   OP_REQUIRES(ctx, out_size.size() == 2,
492               errors::InvalidArgument("output size must be length 2, got ",
493                                       out_size.size()));
494   OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0,
495               errors::InvalidArgument("output size must be positive, got [",
496                                       out_size[0], ",", out_size[1], "]"));
497 
498   const int num_spatial_dims = 2;
499 
500   xla::XlaOp input = ctx->Input(0);
501   xla::PrimitiveType input_type = ctx->input_xla_type(0);
502 
503   // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in
504   // dimension i.
505   bool slice_input = false;
506   for (int i = 0; i < num_spatial_dims; ++i) {
507     if (in_size[i] > 1 && out_size[i] == 1) {
508       // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first
509       // entry before resizing.
510       slice_input = true;
511       in_size[i] = 1;
512     }
513   }
514   if (slice_input) {
515     input = xla::Slice(input, {0, 0, 0, 0},
516                        {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
517   }
518 
519   // Output is always type float if 'is_kernel_bilinear' is true.
520   // GPU with integer input also uses float, because XLA
521   // integer convolution on CuDNN is either not supported or not allowed
522   // directly.
523   xla::PrimitiveType original_input_type = input_type;
524   if (is_kernel_bilinear || (xla::primitive_util::IsIntegralType(input_type))) {
525     input = xla::ConvertElementType(input, xla::F32);
526     input_type = xla::F32;
527   }
528 
529   for (int dim = 0; dim < in_size.size(); ++dim) {
530     // If the pairwise_distance function more accurately estimated performance,
531     // this threshold could be reduced.
532     constexpr int64 kSmallDimThreshold = 1 << 10;
533     if (in_size[dim] > out_size[dim] || out_size[dim] < kSmallDimThreshold) {
534       std::vector<int64> next_size = in_size;
535       next_size[dim] = out_size[dim];
536       input = ResizeUsingDilationAndConvolution(
537           b, input, input_type, num_spatial_dims, in_size, next_size, channels,
538           align_corners_, is_kernel_bilinear);
539       in_size[dim] = next_size[dim];
540     }
541   }
542 
543   // This function approximates the cost of a bilinear resize from a src_size to
544   // a dst_size. The accuracy is okay, but empirically, the algorithm makes some
545   // suboptimal choices. A better cost model would improve performance.
546   auto pairwise_distance = [align_corners_](int64 src_size, int64 dst_size) {
547     auto params = ComputeResizeConvolutionParameters({src_size}, {dst_size},
548                                                      align_corners_);
549     return params.stride[0];
550   };
551 
552   for (int dim = 0; dim < in_size.size(); ++dim) {
553     std::vector<int64> distances(out_size[dim] + 1);
554     std::vector<int64> next_step(out_size[dim] + 1);
555     for (int64 i = distances.size() - 2; i >= in_size[dim]; --i) {
556       distances[i] = INT64_MAX;
557       for (int64 j = i + 1; j < distances.size(); ++j) {
558         int64 distance = pairwise_distance(i, j) + distances[j];
559         if (distance < distances[i]) {
560           distances[i] = distance;
561           next_step[i] = j;
562         }
563       }
564     }
565 
566     while (in_size[dim] != out_size[dim]) {
567       auto next_size = in_size;
568       next_size[dim] = next_step[in_size[dim]];
569       input = ResizeUsingDilationAndConvolution(
570           b, input, input_type, num_spatial_dims, in_size, next_size, channels,
571           align_corners_, is_kernel_bilinear);
572       in_size[dim] = next_size[dim];
573     }
574   }
575 
576   // Bilinear always outputs float, but nearest neighbor keeps the original type
577   if (!is_kernel_bilinear && original_input_type != input_type) {
578     input = xla::ConvertElementType(input, original_input_type);
579   }
580   ctx->SetOutput(0, input);
581 }
582 }  // namespace
583 
ResizeNearestNeighborOp(OpKernelConstruction * ctx)584 ResizeNearestNeighborOp::ResizeNearestNeighborOp(OpKernelConstruction* ctx)
585     : XlaOpKernel(ctx) {
586   OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
587   OP_REQUIRES(
588       ctx, align_corners_ == true,
589       errors::Unimplemented("ResizeNearestNeighbor with align_corners=False "
590                             "is not yet implemented"));
591   OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
592   OP_REQUIRES(ctx, half_pixel_centers_ == false,
593               errors::Unimplemented(
594                   "ResizeNearestNeighbor with half_pixel_centers=True is "
595                   "not yet implemented"));
596 }
597 
Compile(XlaOpKernelContext * ctx)598 void ResizeNearestNeighborOp::Compile(XlaOpKernelContext* ctx) {
599   GeneralCompile(ctx, align_corners_, is_kernel_bilinear_);
600 }
601 
602 REGISTER_XLA_OP(Name("ResizeNearestNeighbor").CompileTimeConstantInput("size"),
603                 ResizeNearestNeighborOp);
604 
ResizeBilinearOp(OpKernelConstruction * ctx)605 ResizeBilinearOp::ResizeBilinearOp(OpKernelConstruction* ctx)
606     : XlaOpKernel(ctx) {
607   OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
608   OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
609   OP_REQUIRES(
610       ctx, half_pixel_centers_ == false,
611       errors::Unimplemented("ResizeBilinear with half_pixel_centers=True is "
612                             "not yet implemented"));
613 }
614 
Compile(XlaOpKernelContext * ctx)615 void ResizeBilinearOp::Compile(XlaOpKernelContext* ctx) {
616   GeneralCompile(ctx, align_corners_, is_kernel_bilinear_);
617 }
618 
619 REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstantInput("size"),
620                 ResizeBilinearOp);
621 
ResizeBilinearGradOp(OpKernelConstruction * ctx)622 ResizeBilinearGradOp::ResizeBilinearGradOp(OpKernelConstruction* ctx)
623     : XlaOpKernel(ctx) {
624   OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
625   OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
626   OP_REQUIRES(
627       ctx, align_corners_ == true,
628       errors::Unimplemented("ResizeBilinearGrad with align_corners=False is "
629                             "not yet implemented"));
630   OP_REQUIRES(ctx, half_pixel_centers_ == false,
631               errors::Unimplemented(
632                   "ResizeBilinearGrad with half_pixel_centers=True is "
633                   "not yet implemented"));
634 
635   DataType output_dtype;
636   OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype));
637   OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_dtype, &output_type_));
638 }
639 
Compile(XlaOpKernelContext * ctx)640 void ResizeBilinearGradOp::Compile(XlaOpKernelContext* ctx) {
641   xla::XlaBuilder* b = ctx->builder();
642 
643   TensorShape input_shape = ctx->InputShape(1);
644   OP_REQUIRES(ctx, input_shape.dims() == 4,
645               errors::InvalidArgument("input must be 4-dimensional",
646                                       input_shape.DebugString()));
647   const int64 batch = input_shape.dim_size(0);
648   std::vector<int64> in_size = {input_shape.dim_size(1),
649                                 input_shape.dim_size(2)};
650   const int64 channels = input_shape.dim_size(3);
651   OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0,
652               errors::InvalidArgument("input size must be positive, got [",
653                                       in_size[0], ",", in_size[1], "]"));
654 
655   TensorShape grad_shape = ctx->InputShape(0);
656   OP_REQUIRES(ctx, grad_shape.dims() == 4,
657               errors::InvalidArgument("gradient must be 4-dimensional",
658                                       grad_shape.DebugString()));
659   const int64 grad_batch = grad_shape.dim_size(0);
660   const std::vector<int64> grad_size = {grad_shape.dim_size(1),
661                                         grad_shape.dim_size(2)};
662   const int64 grad_channels = grad_shape.dim_size(3);
663   OP_REQUIRES(ctx, batch == grad_batch,
664               errors::InvalidArgument(
665                   "activations and gradients must have the same batch size (",
666                   batch, " vs. ", grad_batch, ")"));
667   OP_REQUIRES(ctx, grad_size[0] > 0 && grad_size[1] > 0,
668               errors::InvalidArgument("gradient size must be positive, got [",
669                                       grad_size[0], ",", grad_size[1], "]"));
670   OP_REQUIRES(
671       ctx, channels == grad_channels,
672       errors::InvalidArgument(
673           "activations and gradients must have the same number of channels (",
674           channels, " vs. ", grad_channels, ")"));
675 
676   const int num_spatial_dims = 2;
677 
678   xla::XlaOp grad = ctx->Input(0);
679 
680   xla::XlaOp output = grad;
681   while (in_size != grad_size) {
682     if (in_size[0] != 1 && in_size[1] != 1) {
683       std::vector<float> k = {
684           (static_cast<float>(grad_size[0]) - 1) / ((in_size[0] - 1) * 2),
685           (static_cast<float>(grad_size[1]) - 1) / ((in_size[1] - 1) * 2)};
686       if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) &&
687           k[0] > 1 && k[1] > 1) {
688         std::vector<int64> next_grad_size = {(in_size[0] - 1) * 2 + 1,
689                                              (in_size[1] - 1) * 2 + 1};
690         output = ResizeUsingDilationAndConvolutionGradOp(
691             b, grad, xla::F32, num_spatial_dims, in_size, next_grad_size,
692             channels, align_corners_, true);
693         grad = output;
694         in_size = next_grad_size;
695       } else {
696         output = ResizeUsingDilationAndConvolutionGradOp(
697             b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels,
698             align_corners_, true);
699         in_size = grad_size;
700       }
701     } else {
702       output = ResizeUsingDilationAndConvolutionGradOp(
703           b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels,
704           align_corners_, true);
705       in_size = grad_size;
706     }
707   }
708 
709   output = xla::ConvertElementType(output, output_type_);
710   ctx->SetOutput(0, output);
711 }
712 
713 REGISTER_XLA_OP(Name("ResizeBilinearGrad"), ResizeBilinearGradOp);
714 
715 }  // namespace tensorflow
716