1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/tf2xla/kernels/image_resize_ops.h"
16 
17 #include "absl/strings/str_format.h"
18 #include "absl/types/span.h"
19 #include "tensorflow/compiler/jit/xla_activity.pb.h"
20 #include "tensorflow/compiler/jit/xla_activity_listener.h"
21 #include "tensorflow/compiler/tf2xla/shape_util.h"
22 #include "tensorflow/compiler/tf2xla/type_util.h"
23 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
24 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
26 #include "tensorflow/compiler/xla/array4d.h"
27 #include "tensorflow/compiler/xla/client/lib/constants.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/core/framework/kernel_def_builder.h"
30 #include "tensorflow/core/framework/register_types.h"
31 #include "tensorflow/core/lib/math/math_util.h"
32 
33 namespace tensorflow {
34 namespace {
35 
36 // We implement bilinear interpolation by upsampling followed by convolution.
37 // The basic idea is as follows. To scale from NxN to RxR:
38 //
39 //    1. S := (N - 1) /  gcd(N-1, R-1)
40 //    2. k := (R - 1) /  gcd(N-1, R-1)
41 //    3. Convolution((2k-1)x(2k-1), stride=S, lhs_dilation=k, padding=k-1)
42 //
43 // For example, to scale from 7x7 -> 15x15:
44 //
45 //    1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3
46 //    2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7
47 //    3. Convolution(13x13, stride=3, lhs_dilation=7, padding=6)
48 //
49 //
50 // The 7x7 -> 15x15 case is much too large to write out in full as an
51 // example. The smallest interesting example is 3x3 -> 4x4.
52 //
53 // S := 2
54 // k := 3
55 //
56 // 00 03 06    00 00 00 00 00 00 00 00 00 00 00      00 02 04 06
57 // 09 12 15 -> 00 00 00 00 00 00 00 00 00 00 00   -> 06 08 10 12
58 // 18 21 24    00 00 00 00 00 03 00 00 06 00 00      12 14 16 18
59 //             00 00 00 00 00 00 00 00 00 00 00      18 20 22 24
60 //             00 00 00 00 00 00 00 00 00 00 00
61 //             00 00 09 00 00 12 00 00 15 00 00
62 //             00 00 00 00 00 00 00 00 00 00 00
63 //             00 00 00 00 00 00 00 00 00 00 00
64 //             00 00 18 00 00 21 00 00 24 00 00
65 //             00 00 00 00 00 00 00 00 00 00 00
66 //             00 00 00 00 00 00 00 00 00 00 00
67 //
68 // with the following convolutional kernel, with stride [2, 2]:
69 //       1 2 3 2 1
70 //       2 4 6 4 2
71 // 1/9 * 3 6 9 6 3
72 //       2 4 6 4 2
73 //       1 2 3 2 1
74 // Note that the convolution kernel matrix is separable, so we can instead
75 // apply two consecutive 1D kernels of dimension 2k-1, one along each axis.
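//
// As a quick sanity check of the 3x3 -> 4x4 example above (illustrative only,
// values read off the picture): the second entry of the first output row (02)
// comes from the 5x5 window covering rows 0-4 and columns 2-6 of the dilated,
// padded input. The only nonzero contributions are 00 (kernel weight 3/9) and
// 03 (kernel weight 6/9), so the result is 00 * 3/9 + 03 * 6/9 = 02, i.e.
// bilinear interpolation two thirds of the way from 00 to 03.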
76 
77 // Computes the size of the convolutional kernel and stride to use when resizing
78 // from in_size to out_size.
79 struct ResizeConvolutionDims {
80   // Size of the kernel to use.
81   std::vector<int64> kernel_size;  // k
82 
83   // Stride of the convolution to use.
84   std::vector<int64> stride;  // S
85 };
86 ResizeConvolutionDims ComputeResizeConvolutionParameters(
87     absl::Span<const int64> in_size, absl::Span<const int64> out_size,
88     bool align_corners) {
89   CHECK_EQ(in_size.size(), out_size.size());
90   int num_spatial_dims = in_size.size();
91   ResizeConvolutionDims dims;
92   dims.kernel_size.resize(num_spatial_dims);
93   dims.stride.resize(num_spatial_dims);
94   for (int i = 0; i < num_spatial_dims; ++i) {
95     if (in_size[i] == 1) {
96       // We must handle input size 1 specially because XLA convolution does
97       // not allow stride 0.
98       dims.stride[i] = dims.kernel_size[i] = 1;
99     } else if (out_size[i] == 1) {
100       // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first
101       // entry before resizing.
102       dims.stride[i] = dims.kernel_size[i] = 1;
103     } else {
104       // The scaling factor changes depending on the alignment of corners.
105       const int64_t in_size_factor =
106           align_corners ? in_size[i] - 1 : in_size[i];
107       const int64_t out_size_factor =
108           align_corners ? out_size[i] - 1 : out_size[i];
109 
110       int64_t gcd = MathUtil::GCD(static_cast<uint64>(in_size_factor),
111                                   static_cast<uint64>(out_size_factor));
112       dims.stride[i] = in_size_factor / gcd;
113       dims.kernel_size[i] = out_size_factor / gcd;
114     }
115   }
116   return dims;
117 }
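// A small worked example of the function above (illustrative only): resizing
// 4x4 -> 7x7 with align_corners=true gives gcd(4 - 1, 7 - 1) = 3, so
// stride = 3 / 3 = 1 and kernel_size = 6 / 3 = 2 in each spatial dimension.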
118 
119 // The upper padding of the input needed by ConvGeneralDilated calls is
120 // determined by solving two related relationships (assuming rhs_dilation == 1):
121 // 1. dilated_input_dim = lower_padding + upper_padding
122 //                        + lhs_dilation * (in_size - 1) + 1
123 // 2. dilated_input_dim = (2 * dims.kernel_size - 1)
124 //                        + dims.stride * (out_size - 1)
125 int64 CalculateUpperPadding(int64_t in_size, int64_t out_size,
126                             int64_t kernel_size, int64_t stride) {
127   int64_t padding = (2 * kernel_size - 1) + (out_size - 1) * stride -
128                     (kernel_size - 1) - 1 - (kernel_size * (in_size - 1));
129 
130   return padding;
131 }
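// A quick check of the formula above (illustrative only): for in_size = 2 and
// out_size = 4 with align_corners = false, kernel_size = 2 and stride = 1, so
// padding = 3 + 3 - 1 - 1 - 2 = 2. Both relationships then give a dilated
// input dimension of 6: 1 + 2 + 2 * 1 + 1 = 6 and (2 * 2 - 1) + 1 * 3 = 6.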
132 
133 // Form a 2D convolution kernel like:
134 //       1 2 3 2 1
135 //       2 4 6 4 2
136 // 1/9 * 3 6 9 6 3
137 //       2 4 6 4 2
138 //       1 2 3 2 1
139 // by multiplying two 1D kernels of the form:
140 // 1/3 * [1 2 3 2 1]
141 // If the 2D kernel would be very large, the 1D kernel can instead be applied
142 // once along each axis; because the kernel is separable and symmetric, this
143 // reduces the computational cost without changing the result.
144 xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder,
145                                 xla::PrimitiveType type, int64_t n) {
146   std::vector<float> kernel(n * 2 - 1);
147   for (int64_t i = 0; i < n; ++i) {
148     float v = (i + 1.0f) / n;
149     kernel[i] = v;
150     kernel[n * 2 - 2 - i] = v;
151   }
152   return xla::ConvertElementType(xla::ConstantR1<float>(builder, kernel), type);
153 }
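// For example, n = 3 produces the kernel [1/3, 2/3, 1, 2/3, 1/3], i.e.
// 1/3 * [1 2 3 2 1] as in the comment at the top of this file.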
154 
155 // Unlike the bilinear kernel, which is triangular, the nearest neighbor
156 // kernel is a square. For example, a 1D kernel with n=3 would look like
157 // [0 1 1 1 0]
158 // and n=4 would look like
159 // [0 0 1 1 1 1 0].
160 // Note that in the second case, the kernel is not symmetric and we default
161 // to the right (because an existing non-TPU kernel
162 // for nearest neighbor resize already chose to default to the right,
163 // so we want to be consistent).
164 xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder,
165                                        xla::PrimitiveType type, int64_t n) {
166   std::vector<float> kernel(n * 2 - 1, 0.0f);
167   std::fill(&kernel[n / 2], &kernel[(3 * n) / 2], 1.0f);
168 
169   return xla::ConvertElementType(xla::ConstantR1<float>(builder, kernel), type);
170 }
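// For instance, n = 4 fills indices [n / 2, 3 * n / 2) = [2, 6) of a length-7
// kernel, yielding [0 0 1 1 1 1 0] as described above.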
171 
172 // Kernels whose kernel_size product reaches this threshold are considered
173 // compute-intensive, and the kernel is applied to each dimension independently.
174 const int64_t kMax2DKernelSize = 16;
175 
176 xla::XlaOp MakeGeneralResizeKernel(xla::XlaBuilder* builder,
177                                    xla::PrimitiveType type,
178                                    absl::Span<const int64> kernel_size,
179                                    int64_t channels, bool is_kernel_bilinear) {
180   auto make_kernel_func =
181       is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel;
182 
183   std::vector<int64> depthwise_kernel_sizes = {
184       (2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1};
185   auto depthwise_kernel =
186       xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[1]),
187                           depthwise_kernel_sizes, /*broadcast_dimensions=*/{1});
188 
189   return xla::Mul(depthwise_kernel,
190                   make_kernel_func(builder, type, kernel_size[0]),
191                   /*broadcast_dimensions=*/{0});
192 }
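// The kernel built above has shape
// [2 * kernel_size[0] - 1, 2 * kernel_size[1] - 1, channels, 1]: the outer
// product of the two 1D kernels reconstructs the separable 2D kernel, and the
// trailing {channels, 1} dimensions are the kernel output- and input-feature
// dimensions used by the depthwise (feature_group_count == channels)
// convolutions issued below.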
193 
194 xla::XlaOp MakeGeneralResizeKernelInDim(xla::XlaBuilder* builder,
195                                         xla::PrimitiveType type,
196                                         absl::Span<const int64> kernel_size,
197                                         int64_t channels, int64_t dim,
198                                         bool is_kernel_bilinear) {
199   auto make_kernel_func =
200       is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel;
201 
202   std::vector<int64> depthwise_kernel_sizes = {
203       dim == 0 ? (2 * kernel_size[0] - 1) : 1,
204       dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1};
205   return xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[dim]),
206                              depthwise_kernel_sizes,
207                              /*broadcast_dimensions=*/{dim});
208 }
209 
210 xla::XlaOp BroadcastSpatialDimensions(xla::XlaBuilder* builder,
211                                       const xla::XlaOp& input,
212                                       int32_t spatial_dimensions_offset,
213                                       absl::Span<const int64> in_size,
214                                       absl::Span<const int64> out_size) {
215   // Add broadcasts to handle expanding from a size == 1 dimension to a
216   // size > 1 dimension.
217   auto broadcast_shape_or_status = builder->GetShape(input);
218   if (!broadcast_shape_or_status.ok()) {
219     return builder->ReportError(broadcast_shape_or_status.status());
220   }
221   xla::Shape broadcast_shape = broadcast_shape_or_status.ValueOrDie();
222   for (int32_t i = 0; i < in_size.size(); ++i) {
223     if (in_size[i] == 1 && out_size[i] > 1) {
224       broadcast_shape.set_dimensions(spatial_dimensions_offset + i,
225                                      out_size[i]);
226     }
227   }
228   return xla::BroadcastInDim(input, broadcast_shape.dimensions(),
229                              /*broadcast_dimensions=*/{0, 1, 2, 3});
230 }
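// For example, with spatial_dimensions_offset = 1, an input of shape
// {B, 1, W, C} with in_size = {1, W} and out_size = {H, W} (H > 1) is
// broadcast to {B, H, W, C}; spatial dimensions that already match are left
// untouched.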
231 
232 xla::XlaOp ResizeUsingDilationAndConvolution(
233     xla::XlaBuilder* builder, const xla::XlaOp& input, xla::PrimitiveType type,
234     const int num_spatial_dims, absl::Span<const int64> in_size,
235     absl::Span<const int64> out_size, const int64_t channels,
236     const bool align_corners, bool is_kernel_bilinear) {
237   // Picture for a 1x3 to 1x4 bilinear resize:
238   // stride = 2, kernel size = 3
239   // Input:
240   // 3 6 9
241   // Input with dilation and padding:
242   // 0 0 3 0 0 6 0 0 9 0 0
243   // Convolution kernel:
244   // 1/3 * [1 2 3 2 1]
245   // Output:
246   // 3 5 7 9
247   xla::ConvolutionDimensionNumbers dimension_numbers;
248   dimension_numbers.set_input_batch_dimension(0);
249   dimension_numbers.set_output_batch_dimension(0);
250   dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
251   dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
252   for (int i = 0; i < num_spatial_dims; ++i) {
253     dimension_numbers.add_input_spatial_dimensions(1 + i);
254     dimension_numbers.add_output_spatial_dimensions(1 + i);
255     dimension_numbers.add_kernel_spatial_dimensions(i);
256   }
257   dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
258   dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
259 
260   ResizeConvolutionDims dims =
261       ComputeResizeConvolutionParameters(in_size, out_size, align_corners);
262 
263   if (dims.kernel_size[0] * dims.kernel_size[1] >
264       kMax2DKernelSize * kMax2DKernelSize) {
265     BroadcastOptimizationRemark(
266         XlaOptimizationRemark::SLOW_IMAGE_RESIZE_DIMENSIONS,
267         absl::StrFormat("%dx%d", dims.kernel_size[0], dims.kernel_size[1]))
268         .IgnoreError();
269   }
270 
271   xla::XlaOp output;
272 
273   // Concatenation and padding below currently assume num_spatial_dims is 2 to
274   // prevent needless code complexity.
275   CHECK_EQ(num_spatial_dims, 2)
276       << "ResizeUsingDilationAndConvolution pads only 2 dimensions currently.";
277   std::vector<int64> upper_padding(num_spatial_dims);
278   for (int i = 0; i < num_spatial_dims; ++i) {
279     upper_padding[i] = dims.kernel_size[i] - 1;
280   }
281   xla::XlaOp input_data = input;
282 
283   if (!align_corners) {
284     // When TensorFlow's align_corners is false, the resize indexing can access
285     // beyond the upper bound and is instead clamped to prevent out of bounds
286     // reads. This is conceptually the same as extending the edges of the input.
287     // We emulate this by copying the last row/column of the input.
288     // Calculate what padding would be needed then determine how far to extend
289     // the border before lhs dilation.
290     std::vector<int64> num_extended(num_spatial_dims);
291     upper_padding[0] = CalculateUpperPadding(
292         in_size[0], out_size[0], dims.kernel_size[0], dims.stride[0]);
293     upper_padding[1] = CalculateUpperPadding(
294         in_size[1], out_size[1], dims.kernel_size[1], dims.stride[1]);
295     num_extended[0] = upper_padding[0] / (dims.kernel_size[0]);
296     num_extended[1] = upper_padding[1] / (dims.kernel_size[1]);
297 
298     const int64_t batch_dim_size =
299         builder->GetShape(input).ValueOrDie().dimensions(0);
300     if (num_extended[0] > 0) {
301       auto slice = xla::Slice(
302           input_data, {0, in_size[0] - 1, 0, 0},
303           {batch_dim_size, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
304       for (int i = 0; i < num_extended[0]; i++) {
305         input_data = xla::ConcatInDim(builder, {input_data, slice}, 1);
306       }
307     }
308 
309     if (num_extended[1] > 0) {
310       auto slice = xla::Slice(
311           input_data, {0, 0, in_size[1] - 1, 0},
312           {batch_dim_size, in_size[0] + num_extended[0], in_size[1], channels},
313           {1, 1, 1, 1});
314       for (int i = 0; i < num_extended[1]; i++) {
315         input_data = xla::ConcatInDim(builder, {input_data, slice}, 2);
316       }
317     }
318 
319     // The effective in_size is now (in_size + num_extended) because of the
320     // Slice/ConcatInDim above, so recalculate the padding that is needed.
321     upper_padding[0] =
322         CalculateUpperPadding(in_size[0] + num_extended[0], out_size[0],
323                               dims.kernel_size[0], dims.stride[0]);
324     upper_padding[1] =
325         CalculateUpperPadding(in_size[1] + num_extended[1], out_size[1],
326                               dims.kernel_size[1], dims.stride[1]);
327   }
328 
329   // Split convolutions into independent dimensions if they would be a very
330   // large kernel or if one or more of the dimensions are already equal.
331   bool decompose_resize =
332       in_size[0] == out_size[0] || in_size[1] == out_size[1] ||
333       dims.kernel_size[0] * dims.kernel_size[1] >= kMax2DKernelSize;
334   if (!decompose_resize) {
335     xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size,
336                                                 channels, is_kernel_bilinear);
337     output =
338         xla::ConvGeneralDilated(input_data, kernel, dims.stride,
339                                 /*padding=*/
340                                 {{dims.kernel_size[0] - 1, upper_padding[0]},
341                                  {dims.kernel_size[1] - 1, upper_padding[1]}},
342                                 /*lhs_dilation=*/dims.kernel_size,
343                                 /*rhs_dilation=*/{1, 1}, dimension_numbers,
344                                 /*feature_group_count=*/channels);
345   } else {
346     output = input_data;
347     if (in_size[0] != out_size[0]) {
348       xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim(
349           builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear);
350       output = xla::ConvGeneralDilated(
351           output, kernel0, {dims.stride[0], 1},
352           /*padding=*/
353           {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}},
354           /*lhs_dilation=*/{dims.kernel_size[0], 1},
355           /*rhs_dilation=*/{1, 1}, dimension_numbers,
356           /*feature_group_count=*/channels);
357     }
358 
359     if (in_size[1] != out_size[1]) {
360       xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim(
361           builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear);
362       output = xla::ConvGeneralDilated(
363           output, kernel1, {1, dims.stride[1]},
364           /*padding=*/
365           {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}},
366           /*lhs_dilation=*/{1, dims.kernel_size[1]},
367           /*rhs_dilation=*/{1, 1}, dimension_numbers,
368           /*feature_group_count=*/channels);
369     }
370   }
371 
372   // Add broadcasts to handle expanding from a size == 1 dimension to a
373   // size > 1 dimension.
374   return BroadcastSpatialDimensions(
375       builder, output, /*spatial_dimensions_offset=*/1, in_size, out_size);
376 }
377 
378 xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(
379     xla::XlaBuilder* builder, const xla::XlaOp& grad, xla::PrimitiveType type,
380     const int num_spatial_dims, absl::Span<const int64> in_size,
381     absl::Span<const int64> grad_size, const int64_t channels,
382     const bool align_corners, bool is_kernel_bilinear) {
383   ResizeConvolutionDims dims =
384       ComputeResizeConvolutionParameters(in_size, grad_size, align_corners);
385 
386   // To form the backward convolution, we keep the kernel unchanged (it is
387   // already symmetric) and swap the roles of strides and LHS dilation.
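  //
  // As an illustrative check, consider the 3x3 -> 4x4 forward example from
  // the top of this file (stride 2, lhs_dilation 3, 5x5 kernel). The backward
  // convolution below uses window_strides = 3, lhs_dilation = 2 and padding 2
  // on both sides, so the dilated 4x4 gradient becomes 11x11 and
  // (11 - 5) / 3 + 1 = 3 rows and columns are produced, matching the 3x3
  // input.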
388   xla::ConvolutionDimensionNumbers dimension_numbers;
389   dimension_numbers.set_input_batch_dimension(0);
390   dimension_numbers.set_output_batch_dimension(0);
391   dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
392   dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
393   for (int i = 0; i < num_spatial_dims; ++i) {
394     dimension_numbers.add_input_spatial_dimensions(i + 1);
395     dimension_numbers.add_output_spatial_dimensions(i + 1);
396     dimension_numbers.add_kernel_spatial_dimensions(i);
397   }
398   dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
399   dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
400   xla::XlaOp output;
401   if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
402     xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size,
403                                                 channels, is_kernel_bilinear);
404 
405     // Broadcast the input kernel where the forward op expanded from a size == 1
406     // dimension to a size > 1 dimension. This has the effect of summing the
407     // gradient contributions in that dimension.
408     kernel = BroadcastSpatialDimensions(
409         builder, kernel, /*spatial_dimensions_offset=*/0, in_size, grad_size);
410 
411     output = xla::ConvGeneralDilated(
412         grad, kernel, /*window_strides=*/dims.kernel_size,
413         /*padding=*/
414         {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
415          {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
416         /*lhs_dilation=*/dims.stride,
417         /*rhs_dilation=*/{1, 1}, dimension_numbers,
418         /*feature_group_count=*/channels);
419   } else {
420     xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim(
421         builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear);
422     xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim(
423         builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear);
424 
425     // Broadcast the input kernel where the forward op expanded from a
426     // size == 1 dimension to a size > 1 dimension. This has the effect of
427     // summing the gradient contributions in that dimension.
428     if (in_size[0] == 1 && grad_size[0] > 1) {
429       kernel0 = BroadcastSpatialDimensions(builder, kernel0,
430                                            /*spatial_dimensions_offset=*/0, {1},
431                                            {grad_size[0]});
432     }
433     if (in_size[1] == 1 && grad_size[1] > 1) {
434       kernel1 = BroadcastSpatialDimensions(builder, kernel1,
435                                            /*spatial_dimensions_offset=*/0,
436                                            in_size, grad_size);
437     }
438 
439     output = xla::ConvGeneralDilated(
440         grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1},
441         /*padding=*/
442         {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
443         /*lhs_dilation=*/{dims.stride[0], 1},
444         /*rhs_dilation=*/{1, 1}, dimension_numbers,
445         /*feature_group_count=*/channels);
446 
447     output = xla::ConvGeneralDilated(
448         output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]},
449         /*padding=*/
450         {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
451         /*lhs_dilation=*/{1, dims.stride[1]},
452         /*rhs_dilation=*/{1, 1}, dimension_numbers,
453         /*feature_group_count=*/channels);
454   }
455 
456   // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i.
457   // Opposite of the slice performed by the forward op.
458   xla::PaddingConfig padding = xla::MakeNoPaddingConfig(4);
459   bool pad_output = false;
460   for (int i = 0; i < num_spatial_dims; ++i) {
461     if (in_size[i] > 1 && grad_size[i] == 1) {
462       pad_output = true;
463       padding.mutable_dimensions(1 + i)->set_edge_padding_high(in_size[i] - 1);
464     }
465   }
466   if (pad_output) {
467     output = xla::Pad(output, xla::Zero(builder, type), padding);
468   }
469   return output;
470 }
471 
472 void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_,
473                     bool is_kernel_bilinear) {
474   xla::XlaBuilder* b = ctx->builder();
475 
476   TensorShape input_shape = ctx->InputShape(0);
477   OP_REQUIRES(ctx, input_shape.dims() == 4,
478               errors::InvalidArgument("input must be 4-dimensional",
479                                       input_shape.DebugString()));
480   // First dimension always assumed to be batch
481   const int64_t batch = input_shape.dim_size(0);
482   std::vector<int64> in_size = {input_shape.dim_size(1),
483                                 input_shape.dim_size(2)};
484   // Last/4th dimension always assumed to be num channels
485   const int64_t channels = input_shape.dim_size(3);
486   OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0,
487               errors::InvalidArgument("input size must be positive, got [",
488                                       in_size[0], ",", in_size[1], "]"));
489 
490   std::vector<int64> out_size;
491   OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size));
492   OP_REQUIRES(ctx, out_size.size() == 2,
493               errors::InvalidArgument("output size must be length 2, got ",
494                                       out_size.size()));
495   OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0,
496               errors::InvalidArgument("output size must be positive, got [",
497                                       out_size[0], ",", out_size[1], "]"));
498 
499   const int num_spatial_dims = 2;
500 
501   xla::XlaOp input = ctx->Input(0);
502   xla::PrimitiveType input_type = ctx->input_xla_type(0);
503 
504   // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in
505   // dimension i.
506   bool slice_input = false;
507   for (int i = 0; i < num_spatial_dims; ++i) {
508     if (in_size[i] > 1 && out_size[i] == 1) {
509       // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first
510       // entry before resizing.
511       slice_input = true;
512       in_size[i] = 1;
513     }
514   }
515   if (slice_input) {
516     input = xla::Slice(input, {0, 0, 0, 0},
517                        {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
518   }
519 
520   // The output is always type float if 'is_kernel_bilinear' is true.
521   // Integer inputs are also converted to float, because XLA
522   // integer convolution (e.g. via CuDNN on GPU) is either not supported or
523   // not allowed directly.
524   xla::PrimitiveType original_input_type = input_type;
525   if (is_kernel_bilinear || (xla::primitive_util::IsIntegralType(input_type))) {
526     input = xla::ConvertElementType(input, xla::F32);
527     input_type = xla::F32;
528   }
529 
530   for (int dim = 0; dim < in_size.size(); ++dim) {
531     // If the pairwise_distance function more accurately estimated performance,
532     // this threshold could be reduced.
533     constexpr int64_t kSmallDimThreshold = 1 << 10;
534     if (in_size[dim] > out_size[dim] || out_size[dim] < kSmallDimThreshold) {
535       std::vector<int64> next_size = in_size;
536       next_size[dim] = out_size[dim];
537       input = ResizeUsingDilationAndConvolution(
538           b, input, input_type, num_spatial_dims, in_size, next_size, channels,
539           align_corners_, is_kernel_bilinear);
540       in_size[dim] = next_size[dim];
541     }
542   }
543 
544   // This function approximates the cost of a bilinear resize from a src_size to
545   // a dst_size. The accuracy is okay, but empirically, the algorithm makes some
546   // suboptimal choices. A better cost model would improve performance.
547   auto pairwise_distance = [align_corners_](int64_t src_size,
548                                             int64_t dst_size) {
549     auto params = ComputeResizeConvolutionParameters({src_size}, {dst_size},
550                                                      align_corners_);
551     return params.stride[0];
552   };
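  // For instance, with align_corners = true, pairwise_distance(3, 9) is
  // (3 - 1) / gcd(3 - 1, 9 - 1) = 1 and pairwise_distance(3, 1000) is
  // (3 - 1) / gcd(2, 999) = 2; the dynamic program below uses these per-step
  // costs to choose a cheap sequence of intermediate sizes.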
553 
554   for (int dim = 0; dim < in_size.size(); ++dim) {
555     std::vector<int64> distances(out_size[dim] + 1);
556     std::vector<int64> next_step(out_size[dim] + 1);
557     for (int64_t i = distances.size() - 2; i >= in_size[dim]; --i) {
558       distances[i] = INT64_MAX;
559       for (int64_t j = i + 1; j < distances.size(); ++j) {
560         int64_t distance = pairwise_distance(i, j) + distances[j];
561         if (distance < distances[i]) {
562           distances[i] = distance;
563           next_step[i] = j;
564         }
565       }
566     }
567 
568     while (in_size[dim] != out_size[dim]) {
569       auto next_size = in_size;
570       next_size[dim] = next_step[in_size[dim]];
571       input = ResizeUsingDilationAndConvolution(
572           b, input, input_type, num_spatial_dims, in_size, next_size, channels,
573           align_corners_, is_kernel_bilinear);
574       in_size[dim] = next_size[dim];
575     }
576   }
577 
578   // Bilinear always outputs float, but nearest neighbor keeps the original type
579   if (!is_kernel_bilinear && original_input_type != input_type) {
580     input = xla::ConvertElementType(input, original_input_type);
581   }
582   ctx->SetOutput(0, input);
583 }
584 }  // namespace
585 
586 ResizeNearestNeighborOp::ResizeNearestNeighborOp(OpKernelConstruction* ctx)
587     : XlaOpKernel(ctx) {
588   OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
589   OP_REQUIRES(
590       ctx, align_corners_ == true,
591       errors::Unimplemented("ResizeNearestNeighbor with align_corners=False "
592                             "is not yet implemented"));
593   OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
594   OP_REQUIRES(ctx, half_pixel_centers_ == false,
595               errors::Unimplemented(
596                   "ResizeNearestNeighbor with half_pixel_centers=True is "
597                   "not yet implemented"));
598 }
599 
600 void ResizeNearestNeighborOp::Compile(XlaOpKernelContext* ctx) {
601   GeneralCompile(ctx, align_corners_, is_kernel_bilinear_);
602 }
603 
604 REGISTER_XLA_OP(Name("ResizeNearestNeighbor").CompileTimeConstantInput("size"),
605                 ResizeNearestNeighborOp);
606 
607 ResizeBilinearOp::ResizeBilinearOp(OpKernelConstruction* ctx)
608     : XlaOpKernel(ctx) {
609   OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
610   OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
611   OP_REQUIRES(
612       ctx, half_pixel_centers_ == false,
613       errors::Unimplemented("ResizeBilinear with half_pixel_centers=True is "
614                             "not yet implemented"));
615 }
616 
617 void ResizeBilinearOp::Compile(XlaOpKernelContext* ctx) {
618   GeneralCompile(ctx, align_corners_, is_kernel_bilinear_);
619 }
620 
621 REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstantInput("size"),
622                 ResizeBilinearOp);
623 
624 ResizeBilinearGradOp::ResizeBilinearGradOp(OpKernelConstruction* ctx)
625     : XlaOpKernel(ctx) {
626   OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
627   OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
628   OP_REQUIRES(
629       ctx, align_corners_ == true,
630       errors::Unimplemented("ResizeBilinearGrad with align_corners=False is "
631                             "not yet implemented"));
632   OP_REQUIRES(ctx, half_pixel_centers_ == false,
633               errors::Unimplemented(
634                   "ResizeBilinearGrad with half_pixel_centers=True is "
635                   "not yet implemented"));
636 
637   DataType output_dtype;
638   OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype));
639   OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_dtype, &output_type_));
640 }
641 
642 void ResizeBilinearGradOp::Compile(XlaOpKernelContext* ctx) {
643   xla::XlaBuilder* b = ctx->builder();
644 
645   TensorShape input_shape = ctx->InputShape(1);
646   OP_REQUIRES(ctx, input_shape.dims() == 4,
647               errors::InvalidArgument("input must be 4-dimensional",
648                                       input_shape.DebugString()));
649   const int64_t batch = input_shape.dim_size(0);
650   std::vector<int64> in_size = {input_shape.dim_size(1),
651                                 input_shape.dim_size(2)};
652   const int64_t channels = input_shape.dim_size(3);
653   OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0,
654               errors::InvalidArgument("input size must be positive, got [",
655                                       in_size[0], ",", in_size[1], "]"));
656 
657   TensorShape grad_shape = ctx->InputShape(0);
658   OP_REQUIRES(ctx, grad_shape.dims() == 4,
659               errors::InvalidArgument("gradient must be 4-dimensional",
660                                       grad_shape.DebugString()));
661   const int64_t grad_batch = grad_shape.dim_size(0);
662   const std::vector<int64> grad_size = {grad_shape.dim_size(1),
663                                         grad_shape.dim_size(2)};
664   const int64_t grad_channels = grad_shape.dim_size(3);
665   OP_REQUIRES(ctx, batch == grad_batch,
666               errors::InvalidArgument(
667                   "activations and gradients must have the same batch size (",
668                   batch, " vs. ", grad_batch, ")"));
669   OP_REQUIRES(ctx, grad_size[0] > 0 && grad_size[1] > 0,
670               errors::InvalidArgument("gradient size must be positive, got [",
671                                       grad_size[0], ",", grad_size[1], "]"));
672   OP_REQUIRES(
673       ctx, channels == grad_channels,
674       errors::InvalidArgument(
675           "activations and gradients must have the same number of channels (",
676           channels, " vs. ", grad_channels, ")"));
677 
678   const int num_spatial_dims = 2;
679 
680   xla::XlaOp grad = ctx->Input(0);
681 
682   xla::XlaOp output = grad;
683   while (in_size != grad_size) {
684     if (in_size[0] != 1 && in_size[1] != 1) {
685       std::vector<float> k = {
686           (static_cast<float>(grad_size[0]) - 1) / ((in_size[0] - 1) * 2),
687           (static_cast<float>(grad_size[1]) - 1) / ((in_size[1] - 1) * 2)};
688       if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) &&
689           k[0] > 1 && k[1] > 1) {
690         std::vector<int64> next_grad_size = {(in_size[0] - 1) * 2 + 1,
691                                              (in_size[1] - 1) * 2 + 1};
692         output = ResizeUsingDilationAndConvolutionGradOp(
693             b, grad, xla::F32, num_spatial_dims, in_size, next_grad_size,
694             channels, align_corners_, true);
695         grad = output;
696         in_size = next_grad_size;
697       } else {
698         output = ResizeUsingDilationAndConvolutionGradOp(
699             b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels,
700             align_corners_, true);
701         in_size = grad_size;
702       }
703     } else {
704       output = ResizeUsingDilationAndConvolutionGradOp(
705           b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels,
706           align_corners_, true);
707       in_size = grad_size;
708     }
709   }
710 
711   output = xla::ConvertElementType(output, output_type_);
712   ctx->SetOutput(0, output);
713 }
714 
715 REGISTER_XLA_OP(Name("ResizeBilinearGrad"), ResizeBilinearGradOp);
716 
717 }  // namespace tensorflow
718