• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/image_ops.cc
17 #define EIGEN_USE_THREADS
18 
19 #include <math.h>
20 
21 #include <algorithm>
22 #include <array>
23 
24 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/register_types.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor_shape.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/util/image_resizer_state.h"
33 
34 namespace tensorflow {
35 namespace {
36 
// Number of sub-pixel steps the unit interval is divided into. The coefficient
// table holds kTableSize + 1 rows so that both end points are present.
static constexpr int64_t kTableSize = int64_t{1} << 10;

// Builds a lookup table of bicubic convolution weights for kernel parameter
// `a` (https://en.wikipedia.org/wiki/Bicubic_interpolation).
//
// Row i corresponds to the distance x = i / kTableSize and stores two floats:
//   [2 * i]     the kernel evaluated at x      (polynomial used for x <= 1)
//   [2 * i + 1] the kernel evaluated at x + 1  (polynomial used for 1 < x < 2)
//
// The array is allocated with new[] and is not released here; ownership passes
// to the caller.
const float* InitCoeffsTable(const double a) {
  float* const table = new float[(kTableSize + 1) * 2];
  float* slot = table;
  for (int i = 0; i <= kTableSize; ++i) {
    const float inner_x = i * 1.0 / kTableSize;
    const float outer_x = inner_x + 1.0;
    *slot++ = ((a + 2) * inner_x - (a + 3)) * inner_x * inner_x + 1;
    *slot++ = ((a * outer_x - 5 * a) * outer_x + 8 * a) * outer_x - 4 * a;
  }
  return table;
}
53 
GetCoeffsTable(const bool use_keys_cubic)54 const float* GetCoeffsTable(const bool use_keys_cubic) {
55   // Static so that we initialize it on first use
56   if (use_keys_cubic) {
57     // http://ieeexplore.ieee.org/document/1163711/
58     // R. G. Keys. Cubic convolution interpolation for digital image
59     // processing. IEEE Transactions on Acoustics, Speech, and Signal
60     // Processing, 29(6):1153–1160, 1981.
61     static const float* coeffs_table = InitCoeffsTable(-0.5f);
62     return coeffs_table;
63   } else {
64     static const float* coeffs_table = InitCoeffsTable(-0.75f);
65     return coeffs_table;
66   }
67 }
68 
// Clamps `val` into the index range [0, limit - 1].
//
// The upper bound is applied last, so for limit <= 0 the result is limit - 1
// (a negative value).
// NOTE(review): callers are presumably expected to pass limit >= 1 — confirm.
inline int64_t Bound(int64_t val, int64_t limit) {
  return std::min(limit - 1, std::max(int64_t{0}, val));
}
72 
// The four taps of a 1-D cubic interpolation: weight_k is applied to the
// sample at index_k. All members start at zero so a default-constructed
// instance never carries indeterminate values.
struct WeightsAndIndices {
  float weight_0 = 0.0f;
  float weight_1 = 0.0f;
  float weight_2 = 0.0f;
  float weight_3 = 0.0f;
  int64_t index_0 = 0;
  int64_t index_1 = 0;
  int64_t index_2 = 0;
  int64_t index_3 = 0;

  // Number of previously computed taps that can be reused for this position
  // (the value returned by CachedInterpolationCalculator::Advance).
  int advance = 0;
};
85 
86 template <typename Scaler, bool use_keys_cubic>
GetWeightsAndIndices(const float scale,const int64_t out_loc,const int64_t limit,WeightsAndIndices * out)87 inline void GetWeightsAndIndices(const float scale, const int64_t out_loc,
88                                  const int64_t limit, WeightsAndIndices* out) {
89   const Scaler scaler;
90   const float in_loc_f = scaler(out_loc, scale);
91   const int64_t in_loc = std::floor(in_loc_f);
92   const float delta = in_loc_f - in_loc;
93   const int64_t offset = lrintf(delta * kTableSize);
94   const float* coeffs_table = GetCoeffsTable(use_keys_cubic);
95   if (use_keys_cubic) {
96     // The legacy code placed more weight on the edge pixels, since bounding
97     // the set of inputs to sample could cause an edge pixel to be repeated.
98     // Here we change the behavior at borders to match that used by the
99     // scale_and_translate_op, where sampling locations outside the image have
100     // their weight set to 0, and the weights are renormalized so that their sum
101     // is 1.0.
102     out->index_0 = Bound(in_loc - 1, limit);
103     out->weight_0 =
104         (out->index_0 == in_loc - 1 ? coeffs_table[offset * 2 + 1] : 0.0f);
105     out->index_1 = Bound(in_loc, limit);
106     out->weight_1 = (out->index_1 == in_loc ? coeffs_table[offset * 2] : 0.0f);
107     out->index_2 = Bound(in_loc + 1, limit);
108     out->weight_2 =
109         (out->index_2 == in_loc + 1 ? coeffs_table[(kTableSize - offset) * 2]
110                                     : 0.0f);
111     out->index_3 = Bound(in_loc + 2, limit);
112     out->weight_3 = (out->index_3 == in_loc + 2
113                          ? coeffs_table[(kTableSize - offset) * 2 + 1]
114                          : 0.0f);
115 
116     const float weight_sum =
117         out->weight_0 + out->weight_1 + out->weight_2 + out->weight_3;
118     if (std::abs(weight_sum) >= 1000.0f * std::numeric_limits<float>::min()) {
119       const float one_over_weight_sum = 1.0f / weight_sum;
120       out->weight_0 *= one_over_weight_sum;
121       out->weight_1 *= one_over_weight_sum;
122       out->weight_2 *= one_over_weight_sum;
123       out->weight_3 *= one_over_weight_sum;
124     }
125   } else {
126     out->weight_0 = coeffs_table[offset * 2 + 1];
127     out->weight_1 = coeffs_table[offset * 2];
128     out->weight_2 = coeffs_table[(kTableSize - offset) * 2];
129     out->weight_3 = coeffs_table[(kTableSize - offset) * 2 + 1];
130     out->index_0 = Bound(in_loc - 1, limit);
131     out->index_1 = Bound(in_loc, limit);
132     out->index_2 = Bound(in_loc + 1, limit);
133     out->index_3 = Bound(in_loc + 2, limit);
134   }
135 }
136 
// Returns the weighted sum of four samples, value_k * weight_k for k = 0..3.
// Each sample is converted to float before it is multiplied, and the products
// are added left to right.
template <typename T>
inline float Interpolate1D(const float weight_0, const float weight_1,
                           const float weight_2, const float weight_3,
                           const T value_0, const T value_1, const T value_2,
                           const T value_3) {
  const float sample_0 = static_cast<float>(value_0);
  const float sample_1 = static_cast<float>(value_1);
  const float sample_2 = static_cast<float>(value_2);
  const float sample_3 = static_cast<float>(value_3);
  return sample_0 * weight_0 + sample_1 * weight_1 + sample_2 * weight_2 +
         sample_3 * weight_3;
}
147 
148 // Compute the 1D interpolation for a given X index using the y_weights
Compute(float values_[4],const float xw_0,const float xw_1,const float xw_2,const float xw_3)149 static float Compute(float values_[4], const float xw_0, const float xw_1,
150                      const float xw_2, const float xw_3) {
151   return Interpolate1D(xw_0, xw_1, xw_2, xw_3, values_[0], values_[1],
152                        values_[2], values_[3]);
153 }
154 
// In order to compute a single output value, we look at a 4x4 patch in the
// source image. As we iterate increasing X across the image, the new 4x4 patch
// often overlaps with the previous 4x4 patch we just looked at.
//
// This class tracks the four column indices used for the previous point and
// reports how many of the previously computed values can be reused for the
// next point.
class CachedInterpolationCalculator {
 public:
  // Starts with four -1 entries as the remembered indices.
  CachedInterpolationCalculator() : indexes_{-1, -1, -1, -1} {}

  // Advances iteration to the point whose column indices are x_0..x_3.
  // Returns the number of values (0..4) that should be copied from the
  // current point to the next point. The copying should always be done by
  // copying the last <retval> values from the old point to the first <retval>
  // values of the new point.
  inline int Advance(const int64_t x_0, const int64_t x_1, const int64_t x_2,
                     const int64_t x_3) {
    // We use 2 hands and walk through, copying from one to another where
    // we already have values.
    // Invariant, new_indices_hand <= cached_values_hand
    const std::array<int64_t, 4> new_x_indices{{x_0, x_1, x_2, x_3}};
    int cached_values_hand = 0;
    int new_indices_hand = 0;
    while (cached_values_hand < 4) {
      if (indexes_[cached_values_hand] == new_x_indices[new_indices_hand]) {
        if (new_indices_hand < cached_values_hand) {
          indexes_[new_indices_hand] = indexes_[cached_values_hand];
        }
        ++new_indices_hand;
      }
      ++cached_values_hand;
    }
    // Every slot that was not matched against the cache takes the new index.
    for (int slot = new_indices_hand; slot < 4; ++slot) {
      indexes_[slot] = new_x_indices[slot];
    }
    return new_indices_hand;
  }

 private:
  // Column indices of the most recently visited point.
  int64_t indexes_[4];
};
208 
ComputeXWeightsAndIndices(const ImageResizerState & resizer_state,const bool half_pixel_centers,std::vector<WeightsAndIndices> * x_wais)209 static void ComputeXWeightsAndIndices(const ImageResizerState& resizer_state,
210                                       const bool half_pixel_centers,
211                                       std::vector<WeightsAndIndices>* x_wais) {
212   CachedInterpolationCalculator calc;
213   if (half_pixel_centers) {
214     for (int64_t x = 0; x < resizer_state.out_width; ++x) {
215       GetWeightsAndIndices<HalfPixelScaler, true>(
216           resizer_state.width_scale, x, resizer_state.in_width, &(*x_wais)[x]);
217       auto& x_wai = (*x_wais)[x];
218       x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
219                                    x_wai.index_3);
220     }
221   } else {
222     for (int64_t x = 0; x < resizer_state.out_width; ++x) {
223       GetWeightsAndIndices<LegacyScaler, false>(
224           resizer_state.width_scale, x, resizer_state.in_width, &(*x_wais)[x]);
225       auto& x_wai = (*x_wais)[x];
226       x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
227                                    x_wai.index_3);
228     }
229   }
230   // Scale the values so they can be used as offsets into buffers.
231   for (int x = 0; x < resizer_state.out_width; ++x) {
232     (*x_wais)[x].index_0 *= resizer_state.channels;
233     (*x_wais)[x].index_1 *= resizer_state.channels;
234     (*x_wais)[x].index_2 *= resizer_state.channels;
235     (*x_wais)[x].index_3 *= resizer_state.channels;
236   }
237 }
238 
ComputeGradientXWeightsAndIndices(const ImageResizerGradientState & resizer_state,const bool half_pixel_centers,std::vector<WeightsAndIndices> * x_wais)239 static void ComputeGradientXWeightsAndIndices(
240     const ImageResizerGradientState& resizer_state,
241     const bool half_pixel_centers, std::vector<WeightsAndIndices>* x_wais) {
242   CachedInterpolationCalculator calc;
243   if (half_pixel_centers) {
244     for (int64_t x = 0; x < resizer_state.resized_width; ++x) {
245       GetWeightsAndIndices<HalfPixelScaler, true>(resizer_state.width_scale, x,
246                                                   resizer_state.original_width,
247                                                   &(*x_wais)[x]);
248       auto& x_wai = (*x_wais)[x];
249       x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
250                                    x_wai.index_3);
251     }
252 
253   } else {
254     for (int64_t x = 0; x < resizer_state.resized_width; ++x) {
255       GetWeightsAndIndices<LegacyScaler, false>(resizer_state.width_scale, x,
256                                                 resizer_state.original_width,
257                                                 &(*x_wais)[x]);
258       auto& x_wai = (*x_wais)[x];
259       x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
260                                    x_wai.index_3);
261     }
262   }
263   // Do not scale, as we will be using these directly as tensor indices on the
264   // gradient pass.
265 }
266 
267 template <typename T>
ComputeYInterpolation(int which,int channel_num,const WeightsAndIndices & y_wai,const T * y_ptr_0,const T * y_ptr_1,const T * y_ptr_2,const T * y_ptr_3,const WeightsAndIndices & x_wai)268 static EIGEN_ALWAYS_INLINE float ComputeYInterpolation(
269     int which, int channel_num, const WeightsAndIndices& y_wai,
270     const T* y_ptr_0, const T* y_ptr_1, const T* y_ptr_2, const T* y_ptr_3,
271     const WeightsAndIndices& x_wai) {
272   int x_index;
273   switch (which) {
274     case 0:
275       x_index = x_wai.index_0;
276       break;
277     case 1:
278       x_index = x_wai.index_1;
279       break;
280     case 2:
281       x_index = x_wai.index_2;
282       break;
283     default:
284       x_index = x_wai.index_3;
285       break;
286   }
287   const int64_t pt_index = x_index + channel_num;
288   return Interpolate1D<T>(y_wai.weight_0, y_wai.weight_1, y_wai.weight_2,
289                           y_wai.weight_3, y_ptr_0[pt_index], y_ptr_1[pt_index],
290                           y_ptr_2[pt_index], y_ptr_3[pt_index]);
291 }
292 
293 template <typename T>
interpolate_with_caching(const typename TTypes<T,4>::ConstTensor & input_data,const ImageResizerState & resizer_state,const bool half_pixel_centers,typename TTypes<float,4>::Tensor output_data)294 inline void interpolate_with_caching(
295     const typename TTypes<T, 4>::ConstTensor& input_data,
296     const ImageResizerState& resizer_state, const bool half_pixel_centers,
297     typename TTypes<float, 4>::Tensor output_data) {
298   std::vector<WeightsAndIndices> x_wais(resizer_state.out_width);
299   ComputeXWeightsAndIndices(resizer_state, half_pixel_centers, &x_wais);
300 
301   const auto num_channels = resizer_state.channels;
302   const int64_t in_row_width = resizer_state.in_width * num_channels;
303   const int64_t in_batch_width = resizer_state.in_height * in_row_width;
304 
305   const T* input_b_ptr = input_data.data();
306   float* output_y_ptr = output_data.data();
307   std::vector<float> cached_value(num_channels == 3 ? 0 : 4 * num_channels, 0);
308 
309   for (int64_t b = 0; b < resizer_state.batch_size;
310        ++b, input_b_ptr += in_batch_width) {
311     for (int64_t y = 0; y < resizer_state.out_height;
312          ++y, output_y_ptr += resizer_state.out_width * num_channels) {
313       WeightsAndIndices y_wai;
314       if (half_pixel_centers) {
315         GetWeightsAndIndices<HalfPixelScaler, true>(
316             resizer_state.height_scale, y, resizer_state.in_height, &y_wai);
317       } else {
318         GetWeightsAndIndices<LegacyScaler, false>(
319             resizer_state.height_scale, y, resizer_state.in_height, &y_wai);
320       }
321       // Make pointers represent offsets of data in input_b_ptr.
322       const T* y_ptr_0 = input_b_ptr + y_wai.index_0 * in_row_width;
323       const T* y_ptr_1 = input_b_ptr + y_wai.index_1 * in_row_width;
324       const T* y_ptr_2 = input_b_ptr + y_wai.index_2 * in_row_width;
325       const T* y_ptr_3 = input_b_ptr + y_wai.index_3 * in_row_width;
326 
327       if (num_channels == 3) {
328         // Manually unroll case of 3 channels.
329         float cached_value_0[4] = {0};
330         float cached_value_1[4] = {0};
331         float cached_value_2[4] = {0};
332         for (int64_t x = 0; x < resizer_state.out_width; ++x) {
333           const WeightsAndIndices& x_wai = x_wais[x];
334           // Shift values in cached_value_* to fill first 'advance' values.
335           switch (x_wai.advance) {
336             case 3:
337               cached_value_0[0] = cached_value_0[1];
338               cached_value_0[1] = cached_value_0[2];
339               cached_value_0[2] = cached_value_0[3];
340               cached_value_1[0] = cached_value_1[1];
341               cached_value_1[1] = cached_value_1[2];
342               cached_value_1[2] = cached_value_1[3];
343               cached_value_2[0] = cached_value_2[1];
344               cached_value_2[1] = cached_value_2[2];
345               cached_value_2[2] = cached_value_2[3];
346               break;
347             case 2:
348               cached_value_0[0] = cached_value_0[2];
349               cached_value_0[1] = cached_value_0[3];
350               cached_value_1[0] = cached_value_1[2];
351               cached_value_1[1] = cached_value_1[3];
352               cached_value_2[0] = cached_value_2[2];
353               cached_value_2[1] = cached_value_2[3];
354               break;
355             case 1: {
356               cached_value_0[0] = cached_value_0[3];
357               cached_value_1[0] = cached_value_1[3];
358               cached_value_2[0] = cached_value_2[3];
359               break;
360             }
361           }
362 
363           // Set the remaining '4-advance' values by computing.
364           switch (x_wai.advance) {
365             case 0:
366               cached_value_0[0] = ComputeYInterpolation(
367                   0, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
368               cached_value_1[0] = ComputeYInterpolation(
369                   0, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
370               cached_value_2[0] = ComputeYInterpolation(
371                   0, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
372               TF_FALLTHROUGH_INTENDED;
373             case 1:
374               cached_value_0[1] = ComputeYInterpolation(
375                   1, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
376               cached_value_1[1] = ComputeYInterpolation(
377                   1, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
378               cached_value_2[1] = ComputeYInterpolation(
379                   1, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
380               TF_FALLTHROUGH_INTENDED;
381             case 2:
382               cached_value_0[2] = ComputeYInterpolation(
383                   2, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
384               cached_value_1[2] = ComputeYInterpolation(
385                   2, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
386               cached_value_2[2] = ComputeYInterpolation(
387                   2, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
388               TF_FALLTHROUGH_INTENDED;
389             case 3:
390               cached_value_0[3] = ComputeYInterpolation(
391                   3, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
392               cached_value_1[3] = ComputeYInterpolation(
393                   3, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
394               cached_value_2[3] = ComputeYInterpolation(
395                   3, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
396               break;
397           }
398           output_y_ptr[x * num_channels + 0] =
399               Compute(cached_value_0, x_wai.weight_0, x_wai.weight_1,
400                       x_wai.weight_2, x_wai.weight_3);
401           output_y_ptr[x * num_channels + 1] =
402               Compute(cached_value_1, x_wai.weight_0, x_wai.weight_1,
403                       x_wai.weight_2, x_wai.weight_3);
404           output_y_ptr[x * num_channels + 2] =
405               Compute(cached_value_2, x_wai.weight_0, x_wai.weight_1,
406                       x_wai.weight_2, x_wai.weight_3);
407         }
408       } else {
409         for (int64_t x = 0; x < resizer_state.out_width; ++x) {
410           const WeightsAndIndices& x_wai = x_wais[x];
411           // Shift values in cached_value to fill first 'advance' values.
412           switch (x_wai.advance) {
413             case 3:
414               for (int64_t c = 0; c < num_channels; ++c) {
415                 cached_value[4 * c + 0] = cached_value[4 * c + 1];
416                 cached_value[4 * c + 1] = cached_value[4 * c + 2];
417                 cached_value[4 * c + 2] = cached_value[4 * c + 3];
418               }
419               break;
420             case 2:
421               for (int64_t c = 0; c < num_channels; ++c) {
422                 cached_value[4 * c + 0] = cached_value[4 * c + 2];
423                 cached_value[4 * c + 1] = cached_value[4 * c + 3];
424               }
425               break;
426             case 1: {
427               for (int64_t c = 0; c < num_channels; ++c) {
428                 cached_value[4 * c + 0] = cached_value[4 * c + 3];
429               }
430               break;
431             }
432           }
433 
434           // Set the remaining '4-advance' values by computing.
435           switch (x_wai.advance) {
436             case 0:
437               for (int64_t c = 0; c < num_channels; ++c) {
438                 cached_value[4 * c + 0] = ComputeYInterpolation(
439                     0, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
440               }
441               TF_FALLTHROUGH_INTENDED;
442             case 1:
443               for (int64_t c = 0; c < num_channels; ++c) {
444                 cached_value[4 * c + 1] = ComputeYInterpolation(
445                     1, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
446               }
447               TF_FALLTHROUGH_INTENDED;
448             case 2:
449               for (int64_t c = 0; c < num_channels; ++c) {
450                 cached_value[4 * c + 2] = ComputeYInterpolation(
451                     2, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
452               }
453               TF_FALLTHROUGH_INTENDED;
454             case 3:
455               for (int64_t c = 0; c < num_channels; ++c) {
456                 cached_value[4 * c + 3] = ComputeYInterpolation(
457                     3, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
458               }
459               break;
460           }
461           for (int64_t c = 0; c < num_channels; ++c) {
462             output_y_ptr[x * num_channels + c] =
463                 Compute(&cached_value[4 * c], x_wai.weight_0, x_wai.weight_1,
464                         x_wai.weight_2, x_wai.weight_3);
465           }
466         }
467       }
468     }
469   }
470 }
471 
472 template <typename T>
ResizeBicubicGrad(typename TTypes<float,4>::ConstTensor input_grad,const ImageResizerGradientState & resizer_state,const bool half_pixel_centers,typename TTypes<T,4>::Tensor output_grad)473 inline void ResizeBicubicGrad(typename TTypes<float, 4>::ConstTensor input_grad,
474                               const ImageResizerGradientState& resizer_state,
475                               const bool half_pixel_centers,
476                               typename TTypes<T, 4>::Tensor output_grad) {
477   // This function computes gradients for the ResizeBicubic op by iterating over
478   // the input_grad Tensor and using WeightsAndIndices to appropriately update
479   // the output gradient.
480   const float height_scale = resizer_state.height_scale;
481   const int64_t original_height = resizer_state.original_height;
482   const int channels = resizer_state.channels;
483   const int64_t resized_width = resizer_state.resized_width;
484   const int64_t resized_height = resizer_state.resized_height;
485 
486   output_grad.setZero();
487 
488   std::vector<WeightsAndIndices> x_wais(resizer_state.resized_width);
489   ComputeGradientXWeightsAndIndices(resizer_state, half_pixel_centers, &x_wais);
490   for (int64_t b = 0; b < resizer_state.batch_size; ++b) {
491     for (int64_t y = 0; y < resized_height; ++y) {
492       WeightsAndIndices y_wai;
493       if (half_pixel_centers) {
494         GetWeightsAndIndices<HalfPixelScaler, true>(height_scale, y,
495                                                     original_height, &y_wai);
496       } else {
497         GetWeightsAndIndices<LegacyScaler, false>(height_scale, y,
498                                                   original_height, &y_wai);
499       }
500       for (int64_t x = 0; x < resized_width; ++x) {
501         const WeightsAndIndices& x_wai = x_wais[x];
502         for (int64_t c = 0; c < channels; ++c) {
503           T curr_input_grad = input_grad(b, y, x, c);
504           // row 0 of 0, 1, 2, 3
505           output_grad(b, y_wai.index_0, x_wai.index_0, c) +=
506               T(curr_input_grad * y_wai.weight_0 * x_wai.weight_0);
507           output_grad(b, y_wai.index_0, x_wai.index_1, c) +=
508               T(curr_input_grad * y_wai.weight_0 * x_wai.weight_1);
509           output_grad(b, y_wai.index_0, x_wai.index_2, c) +=
510               T(curr_input_grad * y_wai.weight_0 * x_wai.weight_2);
511           output_grad(b, y_wai.index_0, x_wai.index_3, c) +=
512               T(curr_input_grad * y_wai.weight_0 * x_wai.weight_3);
513           // row 1 of 0, 1, 2, 3
514           output_grad(b, y_wai.index_1, x_wai.index_0, c) +=
515               T(curr_input_grad * y_wai.weight_1 * x_wai.weight_0);
516           output_grad(b, y_wai.index_1, x_wai.index_1, c) +=
517               T(curr_input_grad * y_wai.weight_1 * x_wai.weight_1);
518           output_grad(b, y_wai.index_1, x_wai.index_2, c) +=
519               T(curr_input_grad * y_wai.weight_1 * x_wai.weight_2);
520           output_grad(b, y_wai.index_1, x_wai.index_3, c) +=
521               T(curr_input_grad * y_wai.weight_1 * x_wai.weight_3);
522           // row 2 of 0, 1, 2, 3
523           output_grad(b, y_wai.index_2, x_wai.index_0, c) +=
524               T(curr_input_grad * y_wai.weight_2 * x_wai.weight_0);
525           output_grad(b, y_wai.index_2, x_wai.index_1, c) +=
526               T(curr_input_grad * y_wai.weight_2 * x_wai.weight_1);
527           output_grad(b, y_wai.index_2, x_wai.index_2, c) +=
528               T(curr_input_grad * y_wai.weight_2 * x_wai.weight_2);
529           output_grad(b, y_wai.index_2, x_wai.index_3, c) +=
530               T(curr_input_grad * y_wai.weight_2 * x_wai.weight_3);
531           // row 3 of 0, 1, 2, 3
532           output_grad(b, y_wai.index_3, x_wai.index_0, c) +=
533               T(curr_input_grad * y_wai.weight_3 * x_wai.weight_0);
534           output_grad(b, y_wai.index_3, x_wai.index_1, c) +=
535               T(curr_input_grad * y_wai.weight_3 * x_wai.weight_1);
536           output_grad(b, y_wai.index_3, x_wai.index_2, c) +=
537               T(curr_input_grad * y_wai.weight_3 * x_wai.weight_2);
538           output_grad(b, y_wai.index_3, x_wai.index_3, c) +=
539               T(curr_input_grad * y_wai.weight_3 * x_wai.weight_3);
540         }
541       }
542     }
543   }
544 }
545 
546 }  // namespace
547 
548 typedef Eigen::ThreadPoolDevice CPUDevice;
549 
550 template <typename Device, typename T>
551 class ResizeBicubicOp : public OpKernel {
552  public:
ResizeBicubicOp(OpKernelConstruction * context)553   explicit ResizeBicubicOp(OpKernelConstruction* context) : OpKernel(context) {
554     OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
555     OP_REQUIRES_OK(
556         context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
557   }
558 
Compute(OpKernelContext * context)559   void Compute(OpKernelContext* context) override {
560     ImageResizerState st(align_corners_, half_pixel_centers_);
561     st.ValidateAndCreateOutput(context);
562 
563     if (!context->status().ok()) return;
564 
565     typename TTypes<T, 4>::ConstTensor input_data(
566         context->input(0).tensor<T, 4>());
567     TTypes<float, 4>::Tensor output_data = st.output->tensor<float, 4>();
568 
569     interpolate_with_caching<T>(input_data, st, half_pixel_centers_,
570                                 output_data);
571   }
572 
573  private:
574   bool align_corners_;
575   bool half_pixel_centers_;
576 };
577 
578 template <typename Device, typename T>
579 class ResizeBicubicOpGrad : public OpKernel {
580  public:
ResizeBicubicOpGrad(OpKernelConstruction * context)581   explicit ResizeBicubicOpGrad(OpKernelConstruction* context)
582       : OpKernel(context) {
583     OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
584     OP_REQUIRES_OK(
585         context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
586   }
587 
Compute(OpKernelContext * context)588   void Compute(OpKernelContext* context) override {
589     // Validate input.
590     ImageResizerGradientState st(align_corners_, half_pixel_centers_);
591     st.ValidateAndCreateOutput(context);
592 
593     if (!context->status().ok()) return;
594 
595     // First argument is gradient with respect to resized image.
596     TTypes<float, 4>::ConstTensor input_grad =
597         context->input(0).tensor<float, 4>();
598 
599     typename TTypes<T, 4>::Tensor output_grad(st.output->tensor<T, 4>());
600 
601     ResizeBicubicGrad<T>(input_grad, st, half_pixel_centers_, output_grad);
602   }
603 
604  private:
605   bool align_corners_;
606   bool half_pixel_centers_;
607 };
608 
609 #define REGISTER_KERNEL(T)                            \
610   REGISTER_KERNEL_BUILDER(Name("ResizeBicubic")       \
611                               .Device(DEVICE_CPU)     \
612                               .TypeConstraint<T>("T") \
613                               .HostMemory("size"),    \
614                           ResizeBicubicOp<CPUDevice, T>);
615 
616 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
617 
618 #undef REGISTER_KERNEL
619 
620 #define REGISTER_GRAD_KERNEL(T)                                            \
621   REGISTER_KERNEL_BUILDER(                                                 \
622       Name("ResizeBicubicGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
623       ResizeBicubicOpGrad<CPUDevice, T>);
624 
625 TF_CALL_float(REGISTER_GRAD_KERNEL);
626 TF_CALL_double(REGISTER_GRAD_KERNEL);
627 
628 #undef REGISTER_GRAD_KERNEL
629 
630 }  // namespace tensorflow
631