1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Implements convolution operations with image transformations (resize and
17 // mirror padding) baked into the processing, to optimize latency and memory
18 // usage.
19 
20 #define EIGEN_USE_THREADS
21 
22 #include <string>
23 #include <vector>
24 #include "tensorflow/core/framework/bounds_check.h"
25 #include "tensorflow/core/framework/common_shape_fns.h"
26 #include "tensorflow/core/framework/numeric_op.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/register_types.h"
29 #include "tensorflow/core/framework/resource_mgr.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/tensor_slice.h"
33 #include "tensorflow/core/kernels/conv_2d.h"
34 #include "tensorflow/core/kernels/conv_ops.h"
35 #include "tensorflow/core/kernels/gemm_functors.h"
36 #include "tensorflow/core/kernels/image_resizer_state.h"
37 #include "tensorflow/core/kernels/ops_util.h"
38 #include "tensorflow/core/lib/core/threadpool.h"
39 #include "tensorflow/core/util/mirror_pad_mode.h"
40 #include "tensorflow/core/util/padding.h"
41 #include "tensorflow/core/util/tensor_format.h"
42 
43 namespace tensorflow {
44 namespace {
45 
46 // We don't want to allocate a buffer to hold all the patches if the size is
47 // going to be extremely large, so break it into chunks if it's bigger than
48 // a limit. Each chunk will be processed serially, so we can refill the
49 // buffer for the next chunk and reuse it, keeping maximum memory size down.
50 // In this case, we've picked 16 megabytes as a reasonable limit for Android and
51 // other platforms using Eigen, and 1MB for iOS devices, from experimentation.
52 #if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
53 const size_t kMaxChunkSize = (1 * 1024 * 1024);
54 #else
55 const size_t kMaxChunkSize = (16 * 1024 * 1024);
56 #endif
57 const size_t kResizeCacheSize = (8 * 1024 * 1024);
58 
59 // Lookup method used when resizing.
60 enum SamplingMode {
61   BILINEAR = 0,
62   NEAREST = 1,
63 };
64 
65 // Simple utility function used by FusedConv to multithread basic workloads. To
66 // use it, pass begin and end values for the full workload and a std::function
67 // that receives a subset of that through the begin and end values for each
68 // worker's task. The division of the full workload into worker tasks is handled
69 // by the multithreading logic. Here's an example of how to use it:
70 // std::vector<float> my_vector(100);
71 // ...
72 // FusedConvParallelFor(context, 0, 100,
73 //   [&my_vector](int64 task_begin, int64 task_end) {
74 //     for (int64 current = task_begin; current != task_end; ++current) {
75 //       my_vector[current] *= 10.0f;
76 //     }
77 // });
78 void FusedConvParallelFor(
79     OpKernelContext* context, int64 begin, int64 end,
80     const std::function<void(int64, int64)>& task_function) {
81 // On iOS, the thread management imposes a very big performance penalty, so
82 // just call the function directly with no multithreading.
83 #if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
84   task_function(begin, end);
85 #else
86   auto& worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
87   thread::ThreadPool* thread_pool = worker_threads.workers;
88   const int64 total_elements = end - begin;
89   // This is a bit of an arbitrary number, but was found to work well for
90   // typical models we've been profiling on various devices.
91   const int64 element_cost = 10000000;
92   thread_pool->ParallelFor(
93       total_elements, element_cost,
94       [begin, task_function](int64 begin_offset, int64 end_offset) {
95         const int64 task_begin = begin + begin_offset;
96         const int64 task_end = begin + end_offset;
97         task_function(task_begin, task_end);
98       });
99 #endif
100 }
101 
102 // Holds the state needed for the resizing subtasks.
103 template <class T1>
104 struct ResizeTaskParameters {
105   ResizeTaskParameters() : st(false, false) {}
106 
107   int cache_height;
108   T1* resize_cache;
109   int cache_line_width;
110   int input_width;
111   int input_depth;
112   int top_padding;
113   int pad_offset;
114   int64 resized_height;
115   ImageResizerState st;
116   const T1* input_batch_start;
117   int64 cache_start_x;
118   int64 cache_end_x;
119   int left_padding;
120   int64 resized_width;
121   int64 padded_width;
122   int64 padded_height;
123 };
124 
125 template <class T1>
126 struct PerCacheLineParameters {
127   PerCacheLineParameters() {}
128   PerCacheLineParameters(const PerCacheLineParameters<T1>& other)
129       : cache_line_start(other.cache_line_start),
130         input_top_row_start(other.input_top_row_start),
131         input_bottom_row_start(other.input_bottom_row_start),
132         y_lerp(other.y_lerp) {}
133 
134   T1* cache_line_start;
135   const T1* input_top_row_start;
136   const T1* input_bottom_row_start;
137   T1 y_lerp;
138 };
139 
140 // Helper class to simplify bilinear filtering
141 template <class T1>
142 struct SampleRect {
143   EIGEN_ALWAYS_INLINE SampleRect(const T1* in_top_left, const T1* in_top_right,
144                                  const T1* in_bottom_left,
145                                  const T1* in_bottom_right)
146       : top_left(in_top_left),
147         top_right(in_top_right),
148         bottom_left(in_bottom_left),
149         bottom_right(in_bottom_right) {}
150 
151   EIGEN_ALWAYS_INLINE T1 BilinearSample(int channel, T1 x_lerp,
152                                         T1 y_lerp) const {
153     const T1 top =
154         top_left[channel] + (top_right[channel] - top_left[channel]) * x_lerp;
155     const T1 bottom = bottom_left[channel] +
156                       (bottom_right[channel] - bottom_left[channel]) * x_lerp;
157     return top + (bottom - top) * y_lerp;
158   }
159 
160   const T1* top_left;
161   const T1* top_right;
162   const T1* bottom_left;
163   const T1* bottom_right;
164 };
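// A quick worked example of BilinearSample (illustrative values, not taken
// from the kernel): if for some channel top_left = 1, top_right = 3,
// bottom_left = 5 and bottom_right = 7, with x_lerp = y_lerp = 0.5, the top
// row blends to 1 + (3 - 1) * 0.5 = 2, the bottom row to 5 + (7 - 5) * 0.5 = 6,
// and the returned sample is 2 + (6 - 2) * 0.5 = 4.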
165 
166 // Calculates parameters which remain constant through a resize cache row.
167 template <class T1>
168 EIGEN_ALWAYS_INLINE PerCacheLineParameters<T1> CalculatePerCacheLineParameters(
169     int64 cache_height, int64 cache_y, T1* resize_cache, int64 cache_line_width,
170     int64 input_width, int64 input_depth, int64 top_padding, int64 pad_offset,
171     int64 resized_height, const ImageResizerState& st,
172     const T1* input_batch_start) {
173   PerCacheLineParameters<T1> result;
174   // The cache is organized so that the real y values of the resized image map
175   // onto the actual cache values through a modulo scheme. This means that as we
176   // progress downwards through the image, we keep reusing a small cache and so
177   // keep memory usage down.
178   int64 cache_index_y;
179   if (cache_y < 0) {
180     cache_index_y = cache_height + (cache_y % cache_height);
181   } else {
182     cache_index_y = cache_y % cache_height;
183   }
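  // For example (illustrative numbers only): with cache_height = 8, cache_y =
  // 19 maps to cache row 19 % 8 = 3, while cache_y = -3 maps to row
  // 8 + (-3 % 8) = 5, because C++ integer division truncates toward zero and
  // so -3 % 8 evaluates to -3.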
184   result.cache_line_start =
185       resize_cache + (cache_index_y * cache_line_width * input_depth);
186   // This part is implementing the mirror padding that happens before resizing.
187   float in_y = (cache_y - top_padding);
188   if (in_y < 0) {
189     in_y = -(in_y + 1.0f - pad_offset);
190   } else if (in_y >= resized_height) {
191     in_y = (resized_height * 2.0f) - (in_y + 1.0f + pad_offset);
192   }
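  // A worked example of the reflection (made-up values): with resized_height =
  // 10 and REFLECT mode (pad_offset = 1), a request for row -2 becomes
  // -(-2 + 1 - 1) = 2 and row 11 becomes 20 - (11 + 1 + 1) = 7. In SYMMETRIC
  // mode (pad_offset = 0) row -1 maps back to row 0, i.e. the edge row itself
  // is included in the padding.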
193   // Here's where we do the actual resize.
194   in_y *= st.height_scale;
195   const int64 top_y_index = static_cast<int64>(std::floor(in_y));
196   const int64 bottom_y_index =
197       std::min(static_cast<int64>(std::ceil(in_y)), (st.in_height - 1));
198   // Lerp is used for bilinear filtering when that's needed.
199   result.y_lerp = static_cast<T1>(in_y - top_y_index);
200   // Which rows of the original input image to pull the values from.
201   result.input_top_row_start =
202       input_batch_start + (top_y_index * input_width * input_depth);
203   result.input_bottom_row_start =
204       input_batch_start + (bottom_y_index * input_width * input_depth);
205   return result;
206 }
207 
208 template <class T1>
209 struct PerCachePixelParameters {
210   PerCachePixelParameters() {}
211   PerCachePixelParameters(const PerCachePixelParameters<T1>& other)
212       : cache_line_pixel(other.cache_line_pixel),
213         left_x_index(other.left_x_index),
214         right_x_index(other.right_x_index),
215         x_lerp(other.x_lerp) {}
216 
217   T1* cache_line_pixel;
218   int64 left_x_index;
219   int64 right_x_index;
220   T1 x_lerp;
221 };
222 
223 // Pulls out common parameters used for every resized pixel.
224 template <class T1>
225 EIGEN_ALWAYS_INLINE PerCachePixelParameters<T1>
226 CalculatePerCachePixelParameters(int64 cache_x, int64 cache_start_x,
227                                  T1* cache_line_start, int64 input_depth,
228                                  int64 left_padding, int64 pad_offset,
229                                  int64 resized_width,
230                                  const ImageResizerState& st) {
231   PerCachePixelParameters<T1> result;
232   // Figure out where we're going to store the results of our transform.
233   const int cache_index_x = cache_x - cache_start_x;
234   result.cache_line_pixel = cache_line_start + (cache_index_x * input_depth);
235   // Implement mirror padding by flipping in_x if it's off the edge.
236   float in_x = (cache_x - left_padding);
237   if (in_x < 0) {
238     in_x = -(in_x + 1.0f - pad_offset);
239   } else if (in_x >= resized_width) {
240     in_x = (resized_width * 2.0f) - (in_x + 1.0f + pad_offset);
241   }
242   // Resize the x parameters.
243   in_x *= st.width_scale;
244   // Get the x coordinates for the left and right pixels to pull from.
245   result.left_x_index = static_cast<int64>(std::floor(in_x));
246   result.right_x_index =
247       std::min(static_cast<int64>(std::ceil(in_x)), (st.in_width - 1));
248   // This x_lerp is used to blend pixels in bilinear filtering.
249   result.x_lerp = static_cast<T1>(in_x - result.left_x_index);
250   return result;
251 }
252 
253 // Combines bilinear resizing and mirror padding into the im2col transformation
254 // stage of convolution.
255 template <class T1, class T2, class T3, class TGemmFunctor,
256           SamplingMode SampleMode>
257 class FusedResizeAndPadConvFunctor {
258  public:
259   void operator()(OpKernelContext* context, const Tensor& input,
260                   int input_batches, int resized_height, int resized_width,
261                   int padded_height, int padded_width, int input_depth,
262                   const T2* filter_data, int filter_height, int filter_width,
263                   int filter_count, int stride_rows, int stride_cols,
264                   Padding padding, T3* output_data, int output_height,
265                   int output_width, const ImageResizerState& st,
266                   int top_padding, int bottom_padding, int left_padding,
267                   int right_padding, int pad_offset) {
268     if ((input_batches <= 0) || (padded_width <= 0) || (padded_height <= 0) ||
269         (input_depth <= 0)) {
270       LOG(WARNING) << "Conv2D was called with bad input dimensions: "
271                    << input_batches << ", " << padded_height << ", "
272                    << padded_width << ", " << input_depth;
273       return;
274     }
275     if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
276       LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
277                    << filter_width << ", " << filter_height << ", "
278                    << filter_count;
279       return;
280     }
281     if ((output_width <= 0) || (output_height <= 0)) {
282       LOG(WARNING) << "Conv2D was called with bad output width or height: "
283                    << output_width << ", " << output_height;
284       return;
285     }
286     OP_REQUIRES(
287         context, ((SampleMode == NEAREST) || (SampleMode == BILINEAR)),
288         errors::InvalidArgument("Bad sample mode passed in", SampleMode));
289 
290     // These calculations define how the patches will be positioned within the
291     // input image. The actual definitions are quite complex, and rely on the
292     // previously-calculated output size.
293     int filter_left_offset;
294     int filter_top_offset;
295     if (padding == VALID) {
296       filter_left_offset =
297           ((output_width - 1) * stride_cols + filter_width - padded_width + 1) /
298           2;
299       filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
300                            padded_height + 1) /
301                           2;
302     } else {
303       filter_left_offset =
304           ((output_width - 1) * stride_cols + filter_width - padded_width) / 2;
305       filter_top_offset =
306           ((output_height - 1) * stride_rows + filter_height - padded_height) /
307           2;
308     }
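    // A sanity check with small, made-up numbers: padded_width = 10,
    // filter_width = 3, stride_cols = 1. SAME padding gives output_width = 10
    // and filter_left_offset = ((10 - 1) * 1 + 3 - 10) / 2 = 1, while VALID
    // gives output_width = 8 and an offset of ((8 - 1) * 1 + 3 - 10 + 1) / 2 =
    // 0, so patches start at the first fully valid position.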
309 
310     ResizeTaskParameters<T1> task_params;
311     task_params.input_depth = input_depth;
312     task_params.top_padding = top_padding;
313     task_params.pad_offset = pad_offset;
314     task_params.resized_height = resized_height;
315     task_params.st = st;
316     task_params.left_padding = left_padding;
317     task_params.resized_width = resized_width;
318     task_params.padded_width = padded_width;
319     task_params.padded_height = padded_height;
320 
321     // The im2col buffer has # of patches rows, and # of filters cols.
322     // It's laid out like this, in row major order in memory:
323     //        < filter value count >
324     //   ^   +---------------------+
325     // patch |                     |
326     // count |                     |
327     //   v   +---------------------+
328     // Each patch row contains a filter_width x filter_height patch of the
329     // input, with the depth channel as the most contiguous in memory, followed
330     // by the width, then the height. This is the standard memory order in the
331     // image world if it helps to visualize it.
332     const int filter_value_count = filter_width * filter_height * input_depth;
333 
334     OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= kMaxChunkSize,
335                 errors::InvalidArgument("Im2Col patch too large for buffer"));
336     const size_t patches_per_chunk =
337         kMaxChunkSize / (filter_value_count * sizeof(T1));
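    // As a rough illustration (hypothetical sizes): a 3x3 filter over an
    // 8-channel float input gives filter_value_count = 3 * 3 * 8 = 72, so each
    // patch row occupies 288 bytes and a 16MB chunk holds roughly 58,000
    // patches before the GEMM must be run and the buffer reused.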
338     // Because memory allocation is very expensive on mobile platforms, try to
339     // allocate a persistent buffer that will be kept around between calls. We
340     // use TensorFlow's resource management to ensure that the memory will be
341     // released when the session is over.
342     Im2ColBufferResource<T1, kMaxChunkSize>* im2col_buffer_resource;
343     std::function<Status(Im2ColBufferResource<T1, kMaxChunkSize>**)> creator =
344         [](Im2ColBufferResource<T1, kMaxChunkSize>** resource) {
345           *resource = new Im2ColBufferResource<T1, kMaxChunkSize>();
346           return Status::OK();
347         };
348     OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
349                                 "Conv2d", "im2col_buffer",
350                                 &im2col_buffer_resource, creator));
351 
352     // Create a resize cache memory buffer that will hold the rows of
353     // transformed and mirror padded input pixels, ready to be copied
354     // into filter patches by im2col.
355     // It's laid out like this, in row major order in memory:
356     //         < cache line width >
357     //   ^    +--------------------+
358     // cache  |                    |
359     // height |                    |
360     //   v    +--------------------+
361     // Each cache row contains a cache_line_width number of resized pixels,
362     // each with input_depth channels. The cache height is typically less than
363     // the full height the resized image would be, so it's filled up
364     // incrementally as we progress downwards through the input creating im2col
365     // patches.
366     task_params.cache_start_x = -filter_left_offset;
367     task_params.cache_end_x =
368         (((output_width - 1) * stride_cols) - filter_left_offset) +
369         filter_width;
370     task_params.cache_line_width =
371         task_params.cache_end_x - task_params.cache_start_x;
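    // For a concrete (made-up) configuration: output_width = 112, stride_cols =
    // 2, filter_width = 3 and filter_left_offset = 0 give a cache covering
    // x = 0 up to ((111 * 2) - 0) + 3 = 225, so each cache row holds
    // cache_line_width = 225 resized pixels.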
372     task_params.cache_height =
373         kResizeCacheSize / (task_params.cache_line_width * input_depth);
374     const int needed_resize_cache_count =
375         filter_height * task_params.cache_line_width * input_depth;
376     OP_REQUIRES(context,
377                 (needed_resize_cache_count * sizeof(T1)) <= kResizeCacheSize,
378                 errors::InvalidArgument("Input too large for resize cache"));
379     Im2ColBufferResource<T1, kResizeCacheSize>* resize_cache_resource;
380     std::function<Status(Im2ColBufferResource<T1, kResizeCacheSize>**)>
381         resize_creator =
382             [](Im2ColBufferResource<T1, kResizeCacheSize>** resource) {
383               *resource = new Im2ColBufferResource<T1, kResizeCacheSize>();
384               return Status::OK();
385             };
386     OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
387                                 "Conv2d", "resize_cache",
388                                 &resize_cache_resource, resize_creator));
389 
390     // This means that multiple ops can't be run simultaneously on different
391     // threads, because we have a single shared resource. The platforms this is
392     // aimed at have intra-op parallelism as their focus though, so it shouldn't
393     // be an issue.
394     mutex_lock lock_buffer(im2col_buffer_resource->mu);
395     core::ScopedUnref unref_buffer(im2col_buffer_resource);
396     T1* im2col_buffer = im2col_buffer_resource->data;
397 
398     // This buffer is used as a fairly heavy-weight cache for the resized and
399     // mirrored inputs to the im2col operation. The problem is that we want to
400     // keep the memory usage down by not rendering the fully resized and padded
401     // input tensor to the convolution into an entire buffer. The first approach
402     // to avoid this was to fold the bilinear filtering and padding spatial
403     // transformations into the im2col lookup itself. This successfully reduced
404     // memory usage, but because im2col can access an individual pixel for many
405     // different patches, the extra overhead of doing the same bilinear lookups
406     // repeatedly became too expensive.
407     // The resize cache is designed to avoid this problem by keeping a
408     // horizontal slice of the resized and padded input to the im2col
409     // precalculated, so that repeated accesses to the same pixel from different
410     // filter patches can just be copied from this cache. It's organized as a
411     // horizontal slice stretching across the whole virtual image, and as high
412     // as the filter window, so that all the pixels a row of patches needs are
413     // present as processing moves across the image. Before a new row of
414     // patches is started, any previously calculated rows that are still needed
415     // are kept, and new rows are calculated as required.
416     mutex_lock resize_lock_buffer(resize_cache_resource->mu);
417     core::ScopedUnref unref_resized_cache(resize_cache_resource);
418     task_params.resize_cache = resize_cache_resource->data;
419 
420     const T1* input_data = input.flat<T1>().data();
421     const int64 input_height = input.shape().dim_sizes()[1];
422     task_params.input_width = input.shape().dim_sizes()[2];
423 
424     int end_cached_lines = std::numeric_limits<int>::min();
425 
426     for (int batch = 0; batch < input_batches; ++batch) {
427       task_params.input_batch_start =
428           input_data +
429           (batch * input_height * task_params.input_width * input_depth);
430       const int in_y_end =
431           ((output_height * stride_rows) - filter_top_offset) + filter_height;
432       for (int out_y = 0; out_y < output_height; ++out_y) {
433         const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
434         const int cache_start_y = std::max(in_y_origin, end_cached_lines);
435         const int cache_end_y = std::min(
436             in_y_end, std::max((in_y_origin + task_params.cache_height),
437                                end_cached_lines));
438         if (end_cached_lines < (in_y_origin + filter_height)) {
439           // This call breaks up the work required for calculating the mirror
440           // padding and resizing across multiple threads.
441           FusedConvParallelFor(
442               context, cache_start_y, cache_end_y,
443               [task_params](int64 task_cache_start_y, int64 task_cache_end_y) {
444                 // This is a long and confusing function, but it's been laid out
445                 // this way to help with performance on some intensive models.
446                 // What it's doing is populating a cache of the original input
447                 // image, after it's been bilinear resized and had its edges
448                 // mirrored. This allows the following im2col code to access the
449                 // transformed pixels from this cache, without having to
450                 // repeatedly apply the expensive bilinear calculations as the
451                 // same pixels are accessed by different patches.
452                 // This is most effective when the stride is small and the
453                 // filter size is large, since that's when pixels are reused
454                 // most frequently as patches overlap.
455                 for (int cache_y = task_cache_start_y;
456                      cache_y < task_cache_end_y; ++cache_y) {
457                   // We organize the cache as a series of rows, each containing
458                   // all the transformed pixels for a given line in the image.
459                   // This cache is big enough to hold at least a filter's height
460                   // worth of rows, but typically more, limited by the size of
461                   // the cache buffer.
462                   // We don't allocate an entire image's worth of rows though,
463                   // because we're trying to keep memory usage down, so as we
464                   // progress downwards through the im2col we periodically
465                   // refresh the cache so that the next lines that are needed
466                   // for that operation are always present.
467                   // Work out the parameters that remain constant across the
468                   // row we're calculating.
469                   PerCacheLineParameters<T1> line_params(
470                       CalculatePerCacheLineParameters<T1>(
471                           task_params.cache_height, cache_y,
472                           task_params.resize_cache,
473                           task_params.cache_line_width, task_params.input_width,
474                           task_params.input_depth, task_params.top_padding,
475                           task_params.pad_offset, task_params.resized_height,
476                           task_params.st, task_params.input_batch_start));
477                   // Iterate through the resize cache row we're filling in.
478                   for (int cache_x = task_params.cache_start_x;
479                        cache_x < task_params.cache_end_x; ++cache_x) {
480                     // Figure out what we need for the cache pixel we're
481                     // populating.
482                     PerCachePixelParameters<T1> pixel_params(
483                         CalculatePerCachePixelParameters<T1>(
484                             cache_x, task_params.cache_start_x,
485                             line_params.cache_line_start,
486                             task_params.input_depth, task_params.left_padding,
487                             task_params.pad_offset, task_params.resized_width,
488                             task_params.st));
489                     // If the access is off the left, right, top, or bottom of
490                     // the resized image, the conv padding means we should set
491                     // it to zero.
492                     if ((cache_x < 0) ||
493                         (cache_x >= task_params.padded_width) ||
494                         (cache_y < 0) ||
495                         (cache_y >= task_params.padded_height)) {
496                       std::fill_n(pixel_params.cache_line_pixel,
497                                   task_params.input_depth, T1(0));
498                     } else {
499                       // There are two different sampling strategies for
500                       // resizing. When using nearest, we can just do a
501                       // straight copy of the pixel closest to our sample point,
502                       // but bilinear requires a more complex calculation.
503                       if (SampleMode == NEAREST) {
504                         const T1* input_top_left_pixel =
505                             line_params.input_top_row_start +
506                             (pixel_params.left_x_index *
507                              task_params.input_depth);
508 
509                         std::copy_n(input_top_left_pixel,
510                                     task_params.input_depth,
511                                     pixel_params.cache_line_pixel);
512                       } else {
513                         const SampleRect<T1> rect(
514                             line_params.input_top_row_start +
515                                 (pixel_params.left_x_index *
516                                  task_params.input_depth),
517                             line_params.input_top_row_start +
518                                 (pixel_params.right_x_index *
519                                  task_params.input_depth),
520                             line_params.input_bottom_row_start +
521                                 (pixel_params.left_x_index *
522                                  task_params.input_depth),
523                             line_params.input_bottom_row_start +
524                                 (pixel_params.right_x_index *
525                                  task_params.input_depth));
526                         for (int in_channel = 0;
527                              in_channel < task_params.input_depth;
528                              ++in_channel) {
529                           pixel_params.cache_line_pixel[in_channel] =
530                               rect.BilinearSample(in_channel,
531                                                   pixel_params.x_lerp,
532                                                   line_params.y_lerp);
533                         }
534                       }
535                     }
536                   }
537                 }
538               });
539           end_cached_lines = cache_end_y;
540         }
541         for (int out_x = 0; out_x < output_width; ++out_x) {
542           const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
543           const int patch_index = (batch * output_width * output_height) +
544                                   (out_y * output_width) + out_x;
545           const int patch_index_within_chunk = patch_index % patches_per_chunk;
546           T1* im2col_patch_start =
547               im2col_buffer + (patch_index_within_chunk * filter_value_count);
548           for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
549             T1* im2col_row_start =
550                 im2col_patch_start +
551                 (filter_y * filter_width * task_params.input_depth);
552             const int conv_in_y = in_y_origin + filter_y;
553             int cache_index_y;
554             if (conv_in_y < 0) {
555               cache_index_y = task_params.cache_height +
556                               (conv_in_y % task_params.cache_height);
557             } else {
558               cache_index_y = conv_in_y % task_params.cache_height;
559             }
560             T1* cache_line_start =
561                 task_params.resize_cache +
562                 (cache_index_y * task_params.cache_line_width *
563                  task_params.input_depth);
564             T1* cache_filter_row_start =
565                 cache_line_start + ((in_x_origin - task_params.cache_start_x) *
566                                     task_params.input_depth);
567             std::copy_n(cache_filter_row_start,
568                         (filter_width * task_params.input_depth),
569                         im2col_row_start);
570           }
571           const bool is_last_in_chunk =
572               (patch_index_within_chunk == (patches_per_chunk - 1));
573           const bool is_last_overall =
574               ((batch == (input_batches - 1)) &&
575                (out_y == (output_height - 1)) && (out_x == (output_width - 1)));
576           if (is_last_in_chunk || is_last_overall) {
577             // Now we've assembled a set of image patches into a matrix, apply
578             // a GEMM matrix multiply of the patches as rows, times the filter
579             // weights in columns, to get partial results in the output
580             // matrix.
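            // For instance (illustrative shapes): 64 patches of a 3x3x8 filter
            // with 16 output channels gives m = 64, k = 3 * 3 * 8 = 72 and
            // n = 16, i.e. a 64x72 patch matrix times a 72x16 weight matrix,
            // producing a 64x16 block of the output.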
581             const int how_many_patches = patch_index_within_chunk + 1;
582             const int m = how_many_patches;
583             const int n = filter_count;
584             const int k = filter_value_count;
585             const int lda = filter_value_count;
586             const int ldb = filter_count;
587             const int ldc = filter_count;
588             const size_t start_patch_index =
589                 patch_index - (how_many_patches - 1);
590             T3* chunk_output_data =
591                 output_data + (start_patch_index * filter_count);
592             TGemmFunctor gemm_functor;
593             gemm_functor(context, m, n, k, im2col_buffer, lda, filter_data, ldb,
594                          chunk_output_data, ldc);
595           }
596         }
597       }
598     }
599   }
600 };
601 
602 }  // namespace
603 
604 // Implements a version of convolution with bilinear resizing and mirror padding
605 // included.
606 template <class T, class TConvFunctor, bool DoResize>
607 class FusedResizeConv2DUsingGemmOp : public OpKernel {
608  public:
609   explicit FusedResizeConv2DUsingGemmOp(OpKernelConstruction* context)
610       : OpKernel(context) {
611     if (DoResize) {
612       OP_REQUIRES_OK(context,
613                      context->GetAttr("resize_align_corners", &align_corners_));
614     }
615     MirrorPadMode mode;
616     OP_REQUIRES_OK(context, context->GetAttr("mode", &mode));
617 
618     switch (mode) {
619       case MirrorPadMode::SYMMETRIC: {
620         offset_ = 0;
621         break;
622       }
623       case MirrorPadMode::REFLECT: {
624         offset_ = 1;
625         break;
626       }
627       default:
628         OP_REQUIRES(context, false,
629                     errors::InvalidArgument(
630                         "mode must be either REFLECT or SYMMETRIC."));
631     }
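    // To make the two modes concrete (hypothetical 1-D example): padding the
    // row [a b c] by two on the left gives [b a | a b c] in SYMMETRIC mode
    // (offset_ = 0, the edge value is included) and [c b | a b c] in REFLECT
    // mode (offset_ = 1, the edge value is excluded).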
632     OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
633     OP_REQUIRES(context, strides_.size() == 4,
634                 errors::InvalidArgument("Sliding window strides field must "
635                                         "specify 4 dimensions"));
636     const int64 stride_n = GetTensorDim(strides_, FORMAT_NHWC, 'N');
637     const int64 stride_c = GetTensorDim(strides_, FORMAT_NHWC, 'C');
638     OP_REQUIRES(
639         context, stride_n == 1 && stride_c == 1,
640         errors::InvalidArgument("Current implementation does not yet support "
641                                 "strides in the batch and depth dimensions."));
642     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
643   }
644 
645   void Compute(OpKernelContext* context) override {
646     // Input tensor is of the following dimensions:
647     // [ batch, in_rows, in_cols, in_depth ]
648     const Tensor& input = context->input(0);
649     OP_REQUIRES(context, (input.shape().num_elements() > 0),
650                 errors::InvalidArgument("Input tensor can't be empty"));
651 
652     ImageResizerState st(false, false);
653     if (DoResize) {
654       st = ImageResizerState(align_corners_, false);
655       st.ValidateAndCalculateOutputSize(context, input);
656       if (!context->status().ok()) return;
657     } else {
658       // Set up the resize parameters to do no scaling at all.
659       st.batch_size = input.dim_size(0);
660       st.out_height = input.dim_size(1);
661       st.out_width = input.dim_size(2);
662       st.in_height = input.dim_size(1);
663       st.in_width = input.dim_size(2);
664       st.channels = input.dim_size(3);
665       st.height_scale = 1.0f;
666       st.width_scale = 1.0f;
667     }
668     TensorShape resized_shape(
669         {input.dim_size(0), st.out_height, st.out_width, input.dim_size(3)});
670     int paddings_index;
671     int filter_index;
672     if (DoResize) {
673       paddings_index = 2;
674       filter_index = 3;
675     } else {
676       paddings_index = 1;
677       filter_index = 2;
678     }
679     const Tensor& paddings = context->input(paddings_index);
680 
681     const int dims = resized_shape.dims();
682     OP_REQUIRES(
683         context,
684         TensorShapeUtils::IsMatrix(paddings.shape()) &&
685             paddings.dim_size(1) == 2,
686         errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
687                                 paddings.shape().DebugString()));
688     const int fixed_dims =
689         (allow_legacy_scalars() && dims == 0 && paddings.dim_size(0) == 1)
690             ? 1
691             : dims;
692     OP_REQUIRES(
693         context, fixed_dims == paddings.dim_size(0),
694         errors::InvalidArgument(
695             "The first dimension of paddings must be the rank of inputs: ",
696             fixed_dims, " ", paddings.shape().DebugString(), " ",
697             resized_shape.DebugString()));
698     OP_REQUIRES(
699         context, dims == paddings.dim_size(0),
700         errors::InvalidArgument(
701             "The first dimension of paddings must be the rank of inputs: ",
702             dims, " ", paddings.shape().DebugString(), " ",
703             resized_shape.DebugString()));
704 
705     OP_REQUIRES(
706         context, dims == 4,
707         errors::InvalidArgument(
708             "Fused mirror padding only supports four-dimensional inputs, but ",
709             dims, " requested"));
710 
711     // Compute the shape of the output tensor, and allocate it.
712     TensorShape padded_shape;
713     TTypes<int32>::ConstMatrix paddings_matrix = paddings.matrix<int32>();
714     for (int d = 0; d < dims; ++d) {
715       const int32 before =
716           paddings_matrix(d, 0);  // Pad before existing elements.
717       const int32 after =
718           paddings_matrix(d, 1);  // Pad after existing elements.
719       OP_REQUIRES(context, before >= 0 && after >= 0,
720                   errors::InvalidArgument(
721                       "paddings must be non-negative: ", before, " ", after));
722       if (offset_ == 0) {  // SYMMETRIC mode.
723         OP_REQUIRES(
724             context,
725             before <= resized_shape.dim_size(d) &&
726                 after <= resized_shape.dim_size(d),
727             errors::InvalidArgument("paddings must be no greater "
728                                     "than the dimension size: ",
729                                     before, ", ", after, " greater than ",
730                                     resized_shape.dim_size(d)));
731       } else if (offset_ == 1) {  // REFLECT mode.
732         OP_REQUIRES(
733             context,
734             before < resized_shape.dim_size(d) &&
735                 after < resized_shape.dim_size(d),
736             errors::InvalidArgument("paddings must be less than"
737                                     " the dimension size: ",
738                                     before, ", ", after, " not less than ",
739                                     resized_shape.dim_size(d)));
740       }
741       padded_shape.AddDim(before + resized_shape.dim_size(d) + after);
742     }
743 
744     OP_REQUIRES(
745         context, ((paddings_matrix(0, 0) == 0) && (paddings_matrix(0, 1) == 0)),
746         errors::InvalidArgument(
747             "Fused mirror padding only supports spatial padding, not batches: ",
748             paddings.DebugString()));
749     OP_REQUIRES(
750         context, ((paddings_matrix(3, 0) == 0) && (paddings_matrix(3, 1) == 0)),
751         errors::InvalidArgument(
752             "Fused mirror padding only supports spatial padding, not channels: ",
753             paddings.DebugString()));
754     const int32 top_padding = paddings_matrix(1, 0);
755     const int32 bottom_padding = paddings_matrix(1, 1);
756     const int32 left_padding = paddings_matrix(2, 0);
757     const int32 right_padding = paddings_matrix(2, 1);
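    // For example (illustrative values): paddings = [[0, 0], [2, 2], [3, 3],
    // [0, 0]] pads the height by 2 rows on the top and bottom and the width by
    // 3 columns on the left and right, leaving batch and channels untouched.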
758 
759     // Input filter is of the following dimensions:
760     // [ filter_rows, filter_cols, in_depth, out_depth]
761     const Tensor& filter = context->input(filter_index);
762 
763     // For 2D convolution, there should be 4 dimensions.
764     OP_REQUIRES(context, padded_shape.dims() == 4,
765                 errors::InvalidArgument("input must be 4-dimensional",
766                                         padded_shape.DebugString()));
767     OP_REQUIRES(context, filter.dims() == 4,
768                 errors::InvalidArgument("filter must be 4-dimensional: ",
769                                         filter.shape().DebugString()));
770 
771     // We only check the first three dims, since the depth is accessed as an
772     // int64 below.
773     for (int i = 0; i < 3; i++) {
774       OP_REQUIRES(
775           context,
776           FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
777           errors::InvalidArgument("filter too large"));
778     }
779 
780     // The last dimension for input is in_depth. It must be the same as the
781     // filter's in_depth.
782     const int64 in_depth = padded_shape.dim_size(3);
783     OP_REQUIRES(context, in_depth == filter.dim_size(2),
784                 errors::InvalidArgument(
785                     "input and filter must have the same depth: ", in_depth,
786                     " vs ", filter.dim_size(2)));
787 
788     // The last dimension for filter is out_depth.
789     const int out_depth = static_cast<int>(filter.dim_size(3));
790 
791     // The second dimension for input is rows/height.
792     // The first dimension for filter is rows/height.
793     const int64 padded_rows_raw = padded_shape.dim_size(1);
794     OP_REQUIRES(
795         context,
796         FastBoundsCheck(padded_rows_raw, std::numeric_limits<int>::max()),
797         errors::InvalidArgument("Input rows too large"));
798     const int padded_rows = static_cast<int>(padded_rows_raw);
799     const int filter_rows = static_cast<int>(filter.dim_size(0));
800     const int resized_rows = static_cast<int>(resized_shape.dim_size(1));
801 
802     // The third dimension for input is columns/width.
803     // The second dimension for filter is columns/width.
804     const int64 padded_cols_raw = padded_shape.dim_size(2);
805     OP_REQUIRES(
806         context,
807         FastBoundsCheck(padded_cols_raw, std::numeric_limits<int>::max()),
808         errors::InvalidArgument("Input cols too large"));
809     const int padded_cols = static_cast<int>(padded_cols_raw);
810     const int filter_cols = static_cast<int>(filter.dim_size(1));
811     const int resized_cols = static_cast<int>(resized_shape.dim_size(2));
812 
813     // The first dimension for input is batch.
814     const int64 batch_raw = padded_shape.dim_size(0);
815     OP_REQUIRES(context,
816                 FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
817                 errors::InvalidArgument("batch is too large"));
818     const int batch = static_cast<int>(batch_raw);
819 
820     // For now we take the stride from the second and third dimensions only (we
821     // do not support striding on the batch or depth dimension).
822     const int stride_rows = GetTensorDim(strides_, FORMAT_NHWC, 'H');
823     const int stride_cols = GetTensorDim(strides_, FORMAT_NHWC, 'W');
824 
825     int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
826     OP_REQUIRES_OK(context,
827                    GetWindowedOutputSize(padded_rows, filter_rows, stride_rows,
828                                          padding_, &out_rows, &pad_rows));
829     OP_REQUIRES_OK(context,
830                    GetWindowedOutputSize(padded_cols, filter_cols, stride_cols,
831                                          padding_, &out_cols, &pad_cols));
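    // A quick worked example (made-up sizes): padded_rows = 224, filter_rows =
    // 3 and stride_rows = 2 give out_rows = ceil(224 / 2) = 112 under SAME
    // padding, or ceil((224 - 3 + 1) / 2) = 111 under VALID padding.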
832     TensorShape out_shape =
833         ShapeFromFormat(FORMAT_NHWC, batch, out_rows, out_cols, out_depth);
834     OP_REQUIRES(context, (out_shape.num_elements() > 0),
835                 errors::InvalidArgument("Output tensor can't be empty"));
836 
837     // Output tensor is of the following dimensions:
838     // [ in_batch, out_rows, out_cols, out_depth ]
839     Tensor* output = nullptr;
840     OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
841 
842     VLOG(2) << "FusedConv2D: " << name() << ", in_depth = " << in_depth
843             << ", padded_cols = " << padded_cols
844             << ", resized_cols = " << resized_cols
845             << ", filter_cols = " << filter_cols
846             << ", padded_rows = " << padded_rows
847             << ", resized_rows = " << resized_rows
848             << ", filter_rows = " << filter_rows
849             << ", stride_rows = " << stride_rows
850             << ", stride_cols = " << stride_cols
851             << ", out_depth = " << out_depth << ", DoResize=" << DoResize;
852 
853     // If there is nothing to compute, return.
854     if (out_shape.num_elements() == 0) {
855       return;
856     }
857     TConvFunctor conv_functor;
858     conv_functor(context, input, batch, resized_rows, resized_cols, padded_rows,
859                  padded_cols, in_depth, filter.flat<T>().data(), filter_rows,
860                  filter_cols, out_depth, stride_rows, stride_cols, padding_,
861                  output->flat<T>().data(), out_rows, out_cols, st, top_padding,
862                  bottom_padding, left_padding, right_padding, offset_);
863   }
864 
865  private:
866   std::vector<int32> strides_;
867   Padding padding_;
868   bool align_corners_;
869   int offset_;
870 
871   TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp);
872 };
873 
874 #define REGISTER_FUSED(T)                                                 \
875   REGISTER_KERNEL_BUILDER(                                                \
876       Name("FusedResizeAndPadConv2D")                                     \
877           .Device(DEVICE_CPU)                                             \
878           .TypeConstraint<T>("T"),                                        \
879       FusedResizeConv2DUsingGemmOp<                                       \
880           T,                                                              \
881           FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
882                                        BILINEAR>,                         \
883           true>);
884 
885 TF_CALL_half(REGISTER_FUSED);
886 TF_CALL_float(REGISTER_FUSED);
887 TF_CALL_double(REGISTER_FUSED);
888 
889 #define REGISTER_PAD_ONLY_FUSED(T)                                        \
890   REGISTER_KERNEL_BUILDER(                                                \
891       Name("FusedPadConv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
892       FusedResizeConv2DUsingGemmOp<                                       \
893           T,                                                              \
894           FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
895                                        NEAREST>,                          \
896           false>);
897 
898 TF_CALL_half(REGISTER_PAD_ONLY_FUSED);
899 TF_CALL_float(REGISTER_PAD_ONLY_FUSED);
900 TF_CALL_double(REGISTER_PAD_ONLY_FUSED);
901 
902 }  // namespace tensorflow
903