/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/image_ops.cc
#define EIGEN_USE_THREADS

#include <math.h>

#include <algorithm>
#include <array>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/image_resizer_state.h"

namespace tensorflow {
namespace {

static const int64_t kTableSize = (1 << 10);

const float* InitCoeffsTable(const double a) {
  // Allocate and initialize coefficients table using Bicubic
  // convolution algorithm.
  // https://en.wikipedia.org/wiki/Bicubic_interpolation
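  // For the convolution kernel with free parameter a, the weights are
  //   W(x) = (a + 2)|x|^3 - (a + 3)|x|^2 + 1      for |x| <= 1
  //   W(x) = a|x|^3 - 5a|x|^2 + 8a|x| - 4a        for 1 < |x| < 2
  // Entry [i * 2] stores W(x) for x = i / kTableSize (the two inner taps),
  // and entry [i * 2 + 1] stores W(x + 1) (the two outer taps).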
  float* coeffs_table = new float[(kTableSize + 1) * 2];
  for (int i = 0; i <= kTableSize; ++i) {
    float x = i * 1.0 / kTableSize;
    coeffs_table[i * 2] = ((a + 2) * x - (a + 3)) * x * x + 1;
    x += 1.0;
    coeffs_table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
  }

  return coeffs_table;
}

const float* GetCoeffsTable(const bool use_keys_cubic) {
  // Static so that we initialize it on first use
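  // (Function-local statics are initialized in a thread-safe manner in C++11,
  // so concurrent kernel invocations share a single table, which lives for
  // the lifetime of the process.)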
  if (use_keys_cubic) {
    // http://ieeexplore.ieee.org/document/1163711/
    // R. G. Keys. Cubic convolution interpolation for digital image
    // processing. IEEE Transactions on Acoustics, Speech, and Signal
    // Processing, 29(6):1153–1160, 1981.
    static const float* coeffs_table = InitCoeffsTable(-0.5f);
    return coeffs_table;
  } else {
    static const float* coeffs_table = InitCoeffsTable(-0.75f);
    return coeffs_table;
  }
}

// Clamps val to the valid index range [0, limit).
inline int64_t Bound(int64_t val, int64_t limit) {
  return std::min(limit - 1, std::max(int64_t{0}, val));
}

struct WeightsAndIndices {
  float weight_0;
  float weight_1;
  float weight_2;
  float weight_3;
  int64_t index_0;
  int64_t index_1;
  int64_t index_2;
  int64_t index_3;

  int advance;  // Number of cached values to copy from the previous point.
};

template <typename Scaler, bool use_keys_cubic>
inline void GetWeightsAndIndices(const float scale, const int64_t out_loc,
                                 const int64_t limit, WeightsAndIndices* out) {
  const Scaler scaler;
  const float in_loc_f = scaler(out_loc, scale);
  const int64_t in_loc = std::floor(in_loc_f);
  const float delta = in_loc_f - in_loc;
  const int64_t offset = lrintf(delta * kTableSize);
  const float* coeffs_table = GetCoeffsTable(use_keys_cubic);
  if (use_keys_cubic) {
    // The legacy code placed more weight on the edge pixels, since bounding
    // the set of inputs to sample could cause an edge pixel to be repeated.
    // Here we change the behavior at borders to match that used by the
    // scale_and_translate_op, where sampling locations outside the image have
    // their weight set to 0, and the weights are renormalized so that their
    // sum is 1.0.
    out->index_0 = Bound(in_loc - 1, limit);
    out->weight_0 =
        (out->index_0 == in_loc - 1 ? coeffs_table[offset * 2 + 1] : 0.0f);
    out->index_1 = Bound(in_loc, limit);
    out->weight_1 = (out->index_1 == in_loc ? coeffs_table[offset * 2] : 0.0f);
    out->index_2 = Bound(in_loc + 1, limit);
    out->weight_2 =
        (out->index_2 == in_loc + 1 ? coeffs_table[(kTableSize - offset) * 2]
                                    : 0.0f);
    out->index_3 = Bound(in_loc + 2, limit);
    out->weight_3 = (out->index_3 == in_loc + 2
                         ? coeffs_table[(kTableSize - offset) * 2 + 1]
                         : 0.0f);

    const float weight_sum =
        out->weight_0 + out->weight_1 + out->weight_2 + out->weight_3;
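    // If every tap was clamped away from its nominal location (all sampling
    // locations fell outside the image), the sum can be zero; guard against
    // dividing by a vanishingly small value.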
    if (std::abs(weight_sum) >= 1000.0f * std::numeric_limits<float>::min()) {
      const float one_over_weight_sum = 1.0f / weight_sum;
      out->weight_0 *= one_over_weight_sum;
      out->weight_1 *= one_over_weight_sum;
      out->weight_2 *= one_over_weight_sum;
      out->weight_3 *= one_over_weight_sum;
    }
  } else {
    out->weight_0 = coeffs_table[offset * 2 + 1];
    out->weight_1 = coeffs_table[offset * 2];
    out->weight_2 = coeffs_table[(kTableSize - offset) * 2];
    out->weight_3 = coeffs_table[(kTableSize - offset) * 2 + 1];
    out->index_0 = Bound(in_loc - 1, limit);
    out->index_1 = Bound(in_loc, limit);
    out->index_2 = Bound(in_loc + 1, limit);
    out->index_3 = Bound(in_loc + 2, limit);
  }
}

template <typename T>
inline float Interpolate1D(const float weight_0, const float weight_1,
                           const float weight_2, const float weight_3,
                           const T value_0, const T value_1, const T value_2,
                           const T value_3) {
  return static_cast<float>(value_0) * weight_0 +
         static_cast<float>(value_1) * weight_1 +
         static_cast<float>(value_2) * weight_2 +
         static_cast<float>(value_3) * weight_3;
}

// Compute the 1D interpolation for a given X index using the y_weights
static float Compute(float values_[4], const float xw_0, const float xw_1,
                     const float xw_2, const float xw_3) {
  return Interpolate1D(xw_0, xw_1, xw_2, xw_3, values_[0], values_[1],
                       values_[2], values_[3]);
}

// In order to compute a single output value, we look at a 4x4 patch in the
// source image. As we iterate increasing X across the image, the new 4x4 patch
// often overlaps with the previous 4x4 patch we just looked at.
//
// This class helps compute the number of values to copy from the previous
// point's values.
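//
// For example, if the previous x indices were {3, 4, 5, 6} and the new ones
// are {5, 6, 7, 8}, Advance() returns 2: the caller copies the last two
// cached values into slots 0 and 1 and only recomputes slots 2 and 3.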
class CachedInterpolationCalculator {
 public:
  CachedInterpolationCalculator() : indexes_{-1, -1, -1, -1} {}

  // Advances iteration. Returns the number of values that should be copied
  // from the current point to the next point. The copying should always be
  // done by copying the last <retval> values from the old point to the first
  // <retval> values of the new point.
  inline int Advance(const int64_t x_0, const int64_t x_1, const int64_t x_2,
                     const int64_t x_3) {
    // We use 2 hands and walk through, copying from one to another where
    // we already have values.
    // Invariant, new_indices_hand <= cached_values_hand
    const std::array<int64_t, 4> new_x_indices{{x_0, x_1, x_2, x_3}};
    int cached_values_hand = 0;
    int new_indices_hand = 0;
    while (cached_values_hand < 4) {
      if (indexes_[cached_values_hand] == new_x_indices[new_indices_hand]) {
        if (new_indices_hand < cached_values_hand) {
          indexes_[new_indices_hand] = indexes_[cached_values_hand];
        }
        cached_values_hand++;
        new_indices_hand++;
      } else {
        cached_values_hand++;
      }
    }
    switch (new_indices_hand) {
      case 0:
        indexes_[0] = x_0;
        TF_FALLTHROUGH_INTENDED;
      case 1:
        indexes_[1] = x_1;
        TF_FALLTHROUGH_INTENDED;
      case 2:
        indexes_[2] = x_2;
        TF_FALLTHROUGH_INTENDED;
      case 3:
        indexes_[3] = x_3;
        break;
    }
    return new_indices_hand;
  }

 private:
  int64_t indexes_[4];
};

static void ComputeXWeightsAndIndices(const ImageResizerState& resizer_state,
                                      const bool half_pixel_centers,
                                      std::vector<WeightsAndIndices>* x_wais) {
  CachedInterpolationCalculator calc;
  if (half_pixel_centers) {
    for (int64_t x = 0; x < resizer_state.out_width; ++x) {
      GetWeightsAndIndices<HalfPixelScaler, true>(
          resizer_state.width_scale, x, resizer_state.in_width, &(*x_wais)[x]);
      auto& x_wai = (*x_wais)[x];
      x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
                                   x_wai.index_3);
    }
  } else {
    for (int64_t x = 0; x < resizer_state.out_width; ++x) {
      GetWeightsAndIndices<LegacyScaler, false>(
          resizer_state.width_scale, x, resizer_state.in_width, &(*x_wais)[x]);
      auto& x_wai = (*x_wais)[x];
      x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
                                   x_wai.index_3);
    }
  }
  // Scale the values so they can be used as offsets into buffers.
  for (int x = 0; x < resizer_state.out_width; ++x) {
    (*x_wais)[x].index_0 *= resizer_state.channels;
    (*x_wais)[x].index_1 *= resizer_state.channels;
    (*x_wais)[x].index_2 *= resizer_state.channels;
    (*x_wais)[x].index_3 *= resizer_state.channels;
  }
}

static void ComputeGradientXWeightsAndIndices(
    const ImageResizerGradientState& resizer_state,
    const bool half_pixel_centers, std::vector<WeightsAndIndices>* x_wais) {
  CachedInterpolationCalculator calc;
  if (half_pixel_centers) {
    for (int64_t x = 0; x < resizer_state.resized_width; ++x) {
      GetWeightsAndIndices<HalfPixelScaler, true>(resizer_state.width_scale, x,
                                                  resizer_state.original_width,
                                                  &(*x_wais)[x]);
      auto& x_wai = (*x_wais)[x];
      x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
                                   x_wai.index_3);
    }
  } else {
    for (int64_t x = 0; x < resizer_state.resized_width; ++x) {
      GetWeightsAndIndices<LegacyScaler, false>(resizer_state.width_scale, x,
                                                resizer_state.original_width,
                                                &(*x_wais)[x]);
      auto& x_wai = (*x_wais)[x];
      x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
                                   x_wai.index_3);
    }
  }
  // Do not scale, as we will be using these directly as tensor indices on the
  // gradient pass.
}

template <typename T>
static EIGEN_ALWAYS_INLINE float ComputeYInterpolation(
    int which, int channel_num, const WeightsAndIndices& y_wai,
    const T* y_ptr_0, const T* y_ptr_1, const T* y_ptr_2, const T* y_ptr_3,
    const WeightsAndIndices& x_wai) {
  int x_index;
  switch (which) {
    case 0:
      x_index = x_wai.index_0;
      break;
    case 1:
      x_index = x_wai.index_1;
      break;
    case 2:
      x_index = x_wai.index_2;
      break;
    default:
      x_index = x_wai.index_3;
      break;
  }
  const int64_t pt_index = x_index + channel_num;
  return Interpolate1D<T>(y_wai.weight_0, y_wai.weight_1, y_wai.weight_2,
                          y_wai.weight_3, y_ptr_0[pt_index], y_ptr_1[pt_index],
                          y_ptr_2[pt_index], y_ptr_3[pt_index]);
}

template <typename T>
inline void interpolate_with_caching(
    const typename TTypes<T, 4>::ConstTensor& input_data,
    const ImageResizerState& resizer_state, const bool half_pixel_centers,
    typename TTypes<float, 4>::Tensor output_data) {
  std::vector<WeightsAndIndices> x_wais(resizer_state.out_width);
  ComputeXWeightsAndIndices(resizer_state, half_pixel_centers, &x_wais);

  const auto num_channels = resizer_state.channels;
  const int64_t in_row_width = resizer_state.in_width * num_channels;
  const int64_t in_batch_width = resizer_state.in_height * in_row_width;

  const T* input_b_ptr = input_data.data();
  float* output_y_ptr = output_data.data();
  std::vector<float> cached_value(num_channels == 3 ? 0 : 4 * num_channels, 0);
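  // The generic path below caches 4 row-interpolated values per channel in
  // cached_value; the 3-channel path instead uses fixed-size stack arrays,
  // which is why the vector above is left empty in that case.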

  for (int64_t b = 0; b < resizer_state.batch_size;
       ++b, input_b_ptr += in_batch_width) {
    for (int64_t y = 0; y < resizer_state.out_height;
         ++y, output_y_ptr += resizer_state.out_width * num_channels) {
      WeightsAndIndices y_wai;
      if (half_pixel_centers) {
        GetWeightsAndIndices<HalfPixelScaler, true>(
            resizer_state.height_scale, y, resizer_state.in_height, &y_wai);
      } else {
        GetWeightsAndIndices<LegacyScaler, false>(
            resizer_state.height_scale, y, resizer_state.in_height, &y_wai);
      }
      // Make pointers represent offsets of data in input_b_ptr.
      const T* y_ptr_0 = input_b_ptr + y_wai.index_0 * in_row_width;
      const T* y_ptr_1 = input_b_ptr + y_wai.index_1 * in_row_width;
      const T* y_ptr_2 = input_b_ptr + y_wai.index_2 * in_row_width;
      const T* y_ptr_3 = input_b_ptr + y_wai.index_3 * in_row_width;

      if (num_channels == 3) {
        // Manually unroll case of 3 channels.
        float cached_value_0[4] = {0};
        float cached_value_1[4] = {0};
        float cached_value_2[4] = {0};
        for (int64_t x = 0; x < resizer_state.out_width; ++x) {
          const WeightsAndIndices& x_wai = x_wais[x];
          // Shift values in cached_value_* to fill first 'advance' values.
          switch (x_wai.advance) {
            case 3:
              cached_value_0[0] = cached_value_0[1];
              cached_value_0[1] = cached_value_0[2];
              cached_value_0[2] = cached_value_0[3];
              cached_value_1[0] = cached_value_1[1];
              cached_value_1[1] = cached_value_1[2];
              cached_value_1[2] = cached_value_1[3];
              cached_value_2[0] = cached_value_2[1];
              cached_value_2[1] = cached_value_2[2];
              cached_value_2[2] = cached_value_2[3];
              break;
            case 2:
              cached_value_0[0] = cached_value_0[2];
              cached_value_0[1] = cached_value_0[3];
              cached_value_1[0] = cached_value_1[2];
              cached_value_1[1] = cached_value_1[3];
              cached_value_2[0] = cached_value_2[2];
              cached_value_2[1] = cached_value_2[3];
              break;
            case 1: {
              cached_value_0[0] = cached_value_0[3];
              cached_value_1[0] = cached_value_1[3];
              cached_value_2[0] = cached_value_2[3];
              break;
            }
          }

          // Set the remaining '4-advance' values by computing.
          switch (x_wai.advance) {
            case 0:
              cached_value_0[0] = ComputeYInterpolation(
                  0, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_1[0] = ComputeYInterpolation(
                  0, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_2[0] = ComputeYInterpolation(
                  0, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              TF_FALLTHROUGH_INTENDED;
            case 1:
              cached_value_0[1] = ComputeYInterpolation(
                  1, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_1[1] = ComputeYInterpolation(
                  1, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_2[1] = ComputeYInterpolation(
                  1, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              TF_FALLTHROUGH_INTENDED;
            case 2:
              cached_value_0[2] = ComputeYInterpolation(
                  2, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_1[2] = ComputeYInterpolation(
                  2, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_2[2] = ComputeYInterpolation(
                  2, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              TF_FALLTHROUGH_INTENDED;
            case 3:
              cached_value_0[3] = ComputeYInterpolation(
                  3, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_1[3] = ComputeYInterpolation(
                  3, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              cached_value_2[3] = ComputeYInterpolation(
                  3, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              break;
          }
          output_y_ptr[x * num_channels + 0] =
              Compute(cached_value_0, x_wai.weight_0, x_wai.weight_1,
                      x_wai.weight_2, x_wai.weight_3);
          output_y_ptr[x * num_channels + 1] =
              Compute(cached_value_1, x_wai.weight_0, x_wai.weight_1,
                      x_wai.weight_2, x_wai.weight_3);
          output_y_ptr[x * num_channels + 2] =
              Compute(cached_value_2, x_wai.weight_0, x_wai.weight_1,
                      x_wai.weight_2, x_wai.weight_3);
        }
      } else {
        for (int64_t x = 0; x < resizer_state.out_width; ++x) {
          const WeightsAndIndices& x_wai = x_wais[x];
          // Shift values in cached_value to fill first 'advance' values.
          switch (x_wai.advance) {
            case 3:
              for (int64_t c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 0] = cached_value[4 * c + 1];
                cached_value[4 * c + 1] = cached_value[4 * c + 2];
                cached_value[4 * c + 2] = cached_value[4 * c + 3];
              }
              break;
            case 2:
              for (int64_t c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 0] = cached_value[4 * c + 2];
                cached_value[4 * c + 1] = cached_value[4 * c + 3];
              }
              break;
            case 1: {
              for (int64_t c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 0] = cached_value[4 * c + 3];
              }
              break;
            }
          }

          // Set the remaining '4-advance' values by computing.
          switch (x_wai.advance) {
            case 0:
              for (int64_t c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 0] = ComputeYInterpolation(
                    0, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              }
              TF_FALLTHROUGH_INTENDED;
            case 1:
              for (int64_t c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 1] = ComputeYInterpolation(
                    1, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              }
              TF_FALLTHROUGH_INTENDED;
            case 2:
              for (int64_t c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 2] = ComputeYInterpolation(
                    2, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              }
              TF_FALLTHROUGH_INTENDED;
            case 3:
              for (int64_t c = 0; c < num_channels; ++c) {
                cached_value[4 * c + 3] = ComputeYInterpolation(
                    3, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
              }
              break;
          }
          for (int64_t c = 0; c < num_channels; ++c) {
            output_y_ptr[x * num_channels + c] =
                Compute(&cached_value[4 * c], x_wai.weight_0, x_wai.weight_1,
                        x_wai.weight_2, x_wai.weight_3);
          }
        }
      }
    }
  }
}

template <typename T>
inline void ResizeBicubicGrad(typename TTypes<float, 4>::ConstTensor input_grad,
                              const ImageResizerGradientState& resizer_state,
                              const bool half_pixel_centers,
                              typename TTypes<T, 4>::Tensor output_grad) {
  // This function computes gradients for the ResizeBicubic op by iterating
  // over the input_grad Tensor and using WeightsAndIndices to appropriately
  // update the output gradient.
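  // Each element of input_grad is scattered into the 4x4 neighborhood of
  // output_grad that produced the corresponding resized pixel; the
  // contribution is separable, i.e. weight_y * weight_x.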
  const float height_scale = resizer_state.height_scale;
  const int64_t original_height = resizer_state.original_height;
  const int channels = resizer_state.channels;
  const int64_t resized_width = resizer_state.resized_width;
  const int64_t resized_height = resizer_state.resized_height;

  output_grad.setZero();

  std::vector<WeightsAndIndices> x_wais(resizer_state.resized_width);
  ComputeGradientXWeightsAndIndices(resizer_state, half_pixel_centers, &x_wais);
  for (int64_t b = 0; b < resizer_state.batch_size; ++b) {
    for (int64_t y = 0; y < resized_height; ++y) {
      WeightsAndIndices y_wai;
      if (half_pixel_centers) {
        GetWeightsAndIndices<HalfPixelScaler, true>(height_scale, y,
                                                    original_height, &y_wai);
      } else {
        GetWeightsAndIndices<LegacyScaler, false>(height_scale, y,
                                                  original_height, &y_wai);
      }
      for (int64_t x = 0; x < resized_width; ++x) {
        const WeightsAndIndices& x_wai = x_wais[x];
        for (int64_t c = 0; c < channels; ++c) {
          T curr_input_grad = input_grad(b, y, x, c);
          // row 0 of 0, 1, 2, 3
          output_grad(b, y_wai.index_0, x_wai.index_0, c) +=
              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_0);
          output_grad(b, y_wai.index_0, x_wai.index_1, c) +=
              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_1);
          output_grad(b, y_wai.index_0, x_wai.index_2, c) +=
              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_2);
          output_grad(b, y_wai.index_0, x_wai.index_3, c) +=
              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_3);
          // row 1 of 0, 1, 2, 3
          output_grad(b, y_wai.index_1, x_wai.index_0, c) +=
              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_0);
          output_grad(b, y_wai.index_1, x_wai.index_1, c) +=
              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_1);
          output_grad(b, y_wai.index_1, x_wai.index_2, c) +=
              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_2);
          output_grad(b, y_wai.index_1, x_wai.index_3, c) +=
              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_3);
          // row 2 of 0, 1, 2, 3
          output_grad(b, y_wai.index_2, x_wai.index_0, c) +=
              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_0);
          output_grad(b, y_wai.index_2, x_wai.index_1, c) +=
              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_1);
          output_grad(b, y_wai.index_2, x_wai.index_2, c) +=
              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_2);
          output_grad(b, y_wai.index_2, x_wai.index_3, c) +=
              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_3);
          // row 3 of 0, 1, 2, 3
          output_grad(b, y_wai.index_3, x_wai.index_0, c) +=
              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_0);
          output_grad(b, y_wai.index_3, x_wai.index_1, c) +=
              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_1);
          output_grad(b, y_wai.index_3, x_wai.index_2, c) +=
              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_2);
          output_grad(b, y_wai.index_3, x_wai.index_3, c) +=
              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_3);
        }
      }
    }
  }
}

}  // namespace

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename T>
class ResizeBicubicOp : public OpKernel {
 public:
  explicit ResizeBicubicOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
    OP_REQUIRES_OK(
        context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
  }

  void Compute(OpKernelContext* context) override {
    ImageResizerState st(align_corners_, half_pixel_centers_);
    st.ValidateAndCreateOutput(context);

    if (!context->status().ok()) return;

    typename TTypes<T, 4>::ConstTensor input_data(
        context->input(0).tensor<T, 4>());
    TTypes<float, 4>::Tensor output_data = st.output->tensor<float, 4>();

    interpolate_with_caching<T>(input_data, st, half_pixel_centers_,
                                output_data);
  }

 private:
  bool align_corners_;
  bool half_pixel_centers_;
};

template <typename Device, typename T>
class ResizeBicubicOpGrad : public OpKernel {
 public:
  explicit ResizeBicubicOpGrad(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
    OP_REQUIRES_OK(
        context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
  }

  void Compute(OpKernelContext* context) override {
    // Validate input.
    ImageResizerGradientState st(align_corners_, half_pixel_centers_);
    st.ValidateAndCreateOutput(context);

    if (!context->status().ok()) return;

    // First argument is gradient with respect to resized image.
    TTypes<float, 4>::ConstTensor input_grad =
        context->input(0).tensor<float, 4>();

    typename TTypes<T, 4>::Tensor output_grad(st.output->tensor<T, 4>());

    ResizeBicubicGrad<T>(input_grad, st, half_pixel_centers_, output_grad);
  }

 private:
  bool align_corners_;
  bool half_pixel_centers_;
};

#define REGISTER_KERNEL(T)                            \
  REGISTER_KERNEL_BUILDER(Name("ResizeBicubic")       \
                              .Device(DEVICE_CPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("size"),    \
                          ResizeBicubicOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL

#define REGISTER_GRAD_KERNEL(T)                                            \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("ResizeBicubicGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      ResizeBicubicOpGrad<CPUDevice, T>);

TF_CALL_float(REGISTER_GRAD_KERNEL);
TF_CALL_double(REGISTER_GRAD_KERNEL);

#undef REGISTER_GRAD_KERNEL

}  // namespace tensorflow