/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements convolution operations with image transformations (resize and
// mirror padding) baked into the processing, to optimize latency and memory
// usage.

#define EIGEN_USE_THREADS

#include <string>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/gemm_functors.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/util/image_resizer_state.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

// We don't want to allocate a buffer to hold all the patches if the size is
// going to be extremely large, so break it into chunks if it's bigger than
// a limit. Each chunk will be processed serially, so we can refill the
// buffer for the next chunk and reuse it, keeping maximum memory size down.
// In this case, we've picked 16 megabytes as a reasonable limit for Android and
// other platforms using Eigen, and 1MB for iOS devices, from experimentation.
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
const size_t kMaxChunkSize = (1 * 1024 * 1024);
#else
const size_t kMaxChunkSize = (16 * 1024 * 1024);
#endif
const size_t kResizeCacheSize = (8 * 1024 * 1024);
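
// As a worked example of the chunking arithmetic used below: with float data
// and a 3x3 filter over 256 input channels, one im2col patch occupies
// 3 * 3 * 256 * sizeof(float) = 9,216 bytes, so a 16MB chunk holds
// 16,777,216 / 9,216 = 1,820 patches before the buffer is refilled and reused.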

// Lookup method used when resizing.
enum SamplingMode {
  BILINEAR = 0,
  NEAREST = 1,
};

// Simple utility function used by FusedConv to multithread basic workloads. To
// use it, pass begin and end values for the full workload and a std::function
// that receives a subset of that through the begin and end values for each
// worker's task. The division of the full workload into worker tasks is handled
// by the multithreading logic. Here's an example of how to use it:
// std::vector<float> my_vector(100);
// ...
// FusedConvParallelFor(context, 0, 100,
//   [&my_vector](int64_t task_begin, int64_t task_end) {
//     for (int64_t current = task_begin; current != task_end; ++current) {
//       my_vector[current] *= 10.0f;
//     }
// });
void FusedConvParallelFor(
    OpKernelContext* context, int64_t begin, int64_t end,
    const std::function<void(int64_t, int64_t)>& task_function) {
// On iOS, the thread management imposes a very big performance penalty, so
// just call the function directly with no multithreading.
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
  task_function(begin, end);
#else
  auto& worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
  thread::ThreadPool* thread_pool = worker_threads.workers;
  const int64_t total_elements = end - begin;
  // This is a bit of an arbitrary number, but was found to work well for
  // typical models we've been profiling on various devices.
  const int64_t element_cost = 10000000;
  thread_pool->ParallelFor(
      total_elements, element_cost,
      [begin, task_function](int64_t begin_offset, int64_t end_offset) {
        const int64_t task_begin = begin + begin_offset;
        const int64_t task_end = begin + end_offset;
        task_function(task_begin, task_end);
      });
#endif
}

// Holds the state needed for the resizing subtasks.
template <class T1>
struct ResizeTaskParameters {
  ResizeTaskParameters() : st(false, false) {}

  int cache_height;
  T1* resize_cache;
  int cache_line_width;
  int input_width;
  int input_depth;
  int top_padding;
  int pad_offset;
  int64_t resized_height;
  ImageResizerState st;
  const T1* input_batch_start;
  int64_t cache_start_x;
  int64_t cache_end_x;
  int left_padding;
  int64_t resized_width;
  int64_t padded_width;
  int64_t padded_height;
};

template <class T1>
struct PerCacheLineParameters {
  PerCacheLineParameters() {}
  PerCacheLineParameters(const PerCacheLineParameters<T1>& other)
      : cache_line_start(other.cache_line_start),
        input_top_row_start(other.input_top_row_start),
        input_bottom_row_start(other.input_bottom_row_start),
        y_lerp(other.y_lerp) {}

  T1* cache_line_start;
  const T1* input_top_row_start;
  const T1* input_bottom_row_start;
  T1 y_lerp;
};

// Helper class to simplify bilinear filtering
template <class T1>
struct SampleRect {
  EIGEN_ALWAYS_INLINE SampleRect(const T1* in_top_left, const T1* in_top_right,
                                 const T1* in_bottom_left,
                                 const T1* in_bottom_right)
      : top_left(in_top_left),
        top_right(in_top_right),
        bottom_left(in_bottom_left),
        bottom_right(in_bottom_right) {}

  EIGEN_ALWAYS_INLINE T1 BilinearSample(int channel, T1 x_lerp,
                                        T1 y_lerp) const {
    const T1 top =
        top_left[channel] + (top_right[channel] - top_left[channel]) * x_lerp;
    const T1 bottom = bottom_left[channel] +
                      (bottom_right[channel] - bottom_left[channel]) * x_lerp;
    return top + (bottom - top) * y_lerp;
  }

  const T1* top_left;
  const T1* top_right;
  const T1* bottom_left;
  const T1* bottom_right;
};
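
// As a worked example of BilinearSample: for a single channel with corner
// values top_left = 10, top_right = 20, bottom_left = 30, bottom_right = 40,
// sampling at x_lerp = 0.5 and y_lerp = 0.25 gives top = 15, bottom = 35,
// and a result of 15 + (35 - 15) * 0.25 = 20.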

// Calculates parameters which remain constant through a resize cache row.
template <class T1>
EIGEN_ALWAYS_INLINE PerCacheLineParameters<T1> CalculatePerCacheLineParameters(
    int64_t cache_height, int64_t cache_y, T1* resize_cache,
    int64_t cache_line_width, int64_t input_width, int64_t input_depth,
    int64_t top_padding, int64_t pad_offset, int64_t resized_height,
    const ImageResizerState& st, const T1* input_batch_start) {
  PerCacheLineParameters<T1> result;
  // The cache is organized so that the real y values of the resized image map
  // onto the actual cache values through a modulo scheme. This means that as we
  // progress downwards through the image, we keep reusing a small cache and so
  // keep memory usage down.
  int64_t cache_index_y;
  if (cache_y < 0) {
    cache_index_y = cache_height + (cache_y % cache_height);
  } else {
    cache_index_y = cache_y % cache_height;
  }
  result.cache_line_start =
      resize_cache + (cache_index_y * cache_line_width * input_depth);
  // This part is implementing the mirror padding that happens before resizing.
  float in_y = (cache_y - top_padding);
  if (in_y < 0) {
    in_y = -(in_y + 1.0f - pad_offset);
  } else if (in_y >= resized_height) {
    in_y = (resized_height * 2.0f) - (in_y + 1.0f + pad_offset);
  }
  // Here's where to do the actual resize.
  in_y *= st.height_scale;
  const int64_t top_y_index = static_cast<int64_t>(std::floor(in_y));
  const int64_t bottom_y_index =
      std::min(static_cast<int64_t>(std::ceil(in_y)), (st.in_height - 1));
  // Lerp is used for bilinear filtering when that's needed.
  result.y_lerp = static_cast<T1>(in_y - top_y_index);
  // Which rows of the original input image to pull the values from.
  result.input_top_row_start =
      input_batch_start + (top_y_index * input_width * input_depth);
  result.input_bottom_row_start =
      input_batch_start + (bottom_y_index * input_width * input_depth);
  return result;
}
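
// To illustrate the modulo scheme above, with cache_height = 4:
//   cache_y:       -2 -1  0  1  2  3  4  5
//   cache_index_y:  2  3  0  1  2  3  0  1
// so virtual row 4 reuses the cache line that previously held row 0, which is
// safe because im2col has already consumed that row by the time it's
// overwritten.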

template <class T1>
struct PerCachePixelParameters {
  PerCachePixelParameters() {}
  PerCachePixelParameters(const PerCachePixelParameters<T1>& other)
      : cache_line_pixel(other.cache_line_pixel),
        left_x_index(other.left_x_index),
        right_x_index(other.right_x_index),
        x_lerp(other.x_lerp) {}

  T1* cache_line_pixel;
  int64_t left_x_index;
  int64_t right_x_index;
  T1 x_lerp;
};

// Pulls out common parameters used for every resized pixel.
template <class T1>
EIGEN_ALWAYS_INLINE PerCachePixelParameters<T1>
CalculatePerCachePixelParameters(int64_t cache_x, int64_t cache_start_x,
                                 T1* cache_line_start, int64_t input_depth,
                                 int64_t left_padding, int64_t pad_offset,
                                 int64_t resized_width,
                                 const ImageResizerState& st) {
  PerCachePixelParameters<T1> result;
  // Figure out where we're going to store the results of our transform.
  const int cache_index_x = cache_x - cache_start_x;
  result.cache_line_pixel = cache_line_start + (cache_index_x * input_depth);
  // Implement mirror padding by flipping in_x if it's off the edge.
  float in_x = (cache_x - left_padding);
  if (in_x < 0) {
    in_x = -(in_x + 1.0f - pad_offset);
  } else if (in_x >= resized_width) {
    in_x = (resized_width * 2.0f) - (in_x + 1.0f + pad_offset);
  }
  // Resize the x parameters.
  in_x *= st.width_scale;
  // Get the x coordinates for the left and right pixels to pull from.
  result.left_x_index = static_cast<int64_t>(std::floor(in_x));
  result.right_x_index =
      std::min(static_cast<int64_t>(std::ceil(in_x)), (st.in_width - 1));
  // This x_lerp is used to blend pixels in bilinear filtering.
  result.x_lerp = static_cast<T1>(in_x - result.left_x_index);
  return result;
}
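
// To illustrate the mirror-padding flip above: for an out-of-range in_x of -1,
// pad_offset = 1 (REFLECT mode) maps it to -(-1 + 1 - 1) = 1, reflecting past
// the edge pixel, while pad_offset = 0 (SYMMETRIC mode) maps it to
// -(-1 + 1) = 0, repeating the edge pixel.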

// Combines bilinear resizing and mirror padding into the im2col transformation
// stage of convolution.
template <class T1, class T2, class T3, class TGemmFunctor,
          SamplingMode SampleMode>
class FusedResizeAndPadConvFunctor {
 public:
  void operator()(OpKernelContext* context, const Tensor& input,
                  int input_batches, int resized_height, int resized_width,
                  int padded_height, int padded_width, int input_depth,
                  const T2* filter_data, int filter_height, int filter_width,
                  int filter_count, int stride_rows, int stride_cols,
                  Padding padding, T3* output_data, int output_height,
                  int output_width, const ImageResizerState& st,
                  int top_padding, int bottom_padding, int left_padding,
                  int right_padding, int pad_offset) {
    if ((input_batches <= 0) || (padded_width <= 0) || (padded_height <= 0) ||
        (input_depth <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad input dimensions: "
                   << input_batches << ", " << padded_height << ", "
                   << padded_width << ", " << input_depth;
      return;
    }
    if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
                   << filter_width << ", " << filter_height << ", "
                   << filter_count;
      return;
    }
    if ((output_width <= 0) || (output_height <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad output width or height: "
                   << output_width << ", " << output_height;
      return;
    }
    OP_REQUIRES(
        context, ((SampleMode == NEAREST) || (SampleMode == BILINEAR)),
        errors::InvalidArgument("Bad sample mode passed in: ", SampleMode));

    // These calculations define how the patches will be positioned within the
    // input image. The actual definitions are quite complex, and rely on the
    // previously-calculated output size.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - padded_width + 1) /
          2;
      filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
                           padded_height + 1) /
                          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - padded_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride_rows + filter_height - padded_height) /
          2;
    }
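
    // As a worked example: under SAME padding with padded_width = 10,
    // filter_width = 3, and stride_cols = 1, output_width is 10, so
    // filter_left_offset = ((10 - 1) * 1 + 3 - 10) / 2 = 1, centering each
    // three-pixel-wide patch on its output column.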

    ResizeTaskParameters<T1> task_params;
    task_params.input_depth = input_depth;
    task_params.top_padding = top_padding;
    task_params.pad_offset = pad_offset;
    task_params.resized_height = resized_height;
    task_params.st = st;
    task_params.left_padding = left_padding;
    task_params.resized_width = resized_width;
    task_params.padded_width = padded_width;
    task_params.padded_height = padded_height;

    // The im2col buffer has # of patches rows, and # of filters cols.
    // It's laid out like this, in row major order in memory:
    //        < filter value count >
    //   ^   +---------------------+
    // patch |                     |
    // count |                     |
    //   v   +---------------------+
    // Each patch row contains a filter_width x filter_height patch of the
    // input, with the depth channel as the most contiguous in memory, followed
    // by the width, then the height. This is the standard memory order in the
    // image world if it helps to visualize it.
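    // For instance, a 3x3 filter over 8 input channels gives patch rows of
    // 3 * 3 * 8 = 72 values each: 8 channel values for the leftmost pixel of
    // the top filter row, then 8 for the next pixel across, and so on row by
    // row.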
    const int filter_value_count = filter_width * filter_height * input_depth;

    OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= kMaxChunkSize,
                errors::InvalidArgument("Im2Col patch too large for buffer"));
    const size_t patches_per_chunk =
        kMaxChunkSize / (filter_value_count * sizeof(T1));
    // Because memory allocation is very expensive on mobile platforms, try to
    // allocate a persistent buffer that will be kept around between calls. We
    // use TensorFlow's resource management to ensure that the memory will be
    // released when the session is over.
    Im2ColBufferResource<T1, kMaxChunkSize>* im2col_buffer_resource;
    std::function<Status(Im2ColBufferResource<T1, kMaxChunkSize>**)> creator =
        [](Im2ColBufferResource<T1, kMaxChunkSize>** resource) {
          *resource = new Im2ColBufferResource<T1, kMaxChunkSize>();
          return OkStatus();
        };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "im2col_buffer",
                                &im2col_buffer_resource, creator));

    // Create a resize cache memory buffer that will hold the rows of
    // transformed and mirror padded input pixels, ready to be copied
    // into filter patches by im2col.
    // It's laid out like this, in row major order in memory:
    //         < cache line width >
    //   ^    +--------------------+
    // cache  |                    |
    // height |                    |
    //   v    +--------------------+
    // Each cache row contains a cache_line_width number of resized pixels,
    // each with input_depth channels. The cache height is typically less than
    // the full height the resized image would be, so it's filled up
    // incrementally as we progress downwards through the input creating im2col
    // patches.
    task_params.cache_start_x = -filter_left_offset;
    task_params.cache_end_x =
        (((output_width - 1) * stride_cols) - filter_left_offset) +
        filter_width;
    task_params.cache_line_width =
        task_params.cache_end_x - task_params.cache_start_x;
    task_params.cache_height =
        kResizeCacheSize / (task_params.cache_line_width * input_depth);
    const int needed_resize_cache_count =
        filter_height * task_params.cache_line_width * input_depth;
    OP_REQUIRES(context,
                (needed_resize_cache_count * sizeof(T1)) <= kResizeCacheSize,
                errors::InvalidArgument("Input too large for resize cache"));
    Im2ColBufferResource<T1, kResizeCacheSize>* resize_cache_resource;
    std::function<Status(Im2ColBufferResource<T1, kResizeCacheSize>**)>
        resize_creator =
            [](Im2ColBufferResource<T1, kResizeCacheSize>** resource) {
              *resource = new Im2ColBufferResource<T1, kResizeCacheSize>();
              return OkStatus();
            };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "resize_cache",
                                &resize_cache_resource, resize_creator));

    // This means that multiple ops can't be run simultaneously on different
    // threads, because we have a single shared resource. The platforms this is
    // aimed at have intra-op parallelism as their focus though, so it shouldn't
    // be an issue.
    mutex_lock lock_buffer(im2col_buffer_resource->mu);
    core::ScopedUnref unref_buffer(im2col_buffer_resource);
    T1* im2col_buffer = im2col_buffer_resource->data;

    // This buffer is used as a fairly heavy-weight cache for the resized and
    // mirrored inputs to the im2col operation. The problem is that we want to
    // keep the memory usage down by not rendering the fully resized and padded
    // input tensor to the convolution into an entire buffer. The first approach
    // to avoid this was to fold the bilinear filtering and padding spatial
    // transformations into the im2col lookup itself. This successfully reduced
    // memory usage, but because im2col can access an individual pixel for many
    // different patches, the extra overhead of doing the same bilinear lookups
    // repeatedly became too expensive.
    //
    // The resize cache is designed to avoid this problem by keeping a
    // horizontal slice of the resized and padded input to the im2col
    // precalculated, so that repeated accesses to the same pixel from different
    // filter patches can just be copied from this cache. It's organized as a
    // horizontal slice stretching across the whole virtual image, as high as
    // the filter window, so that as patch processing moves across a row all the
    // pixels it needs are present. Before a new row of patches is started, any
    // previously calculated rows that are still needed are kept, and new rows
    // are calculated as required.
    mutex_lock resize_lock_buffer(resize_cache_resource->mu);
    core::ScopedUnref unref_resized_cache(resize_cache_resource);
    task_params.resize_cache = resize_cache_resource->data;

    const T1* input_data = input.flat<T1>().data();
    const int64_t input_height = input.shape().dim_sizes()[1];
    task_params.input_width = input.shape().dim_sizes()[2];

    int end_cached_lines = std::numeric_limits<int>::min();

    for (int batch = 0; batch < input_batches; ++batch) {
      task_params.input_batch_start =
          input_data +
          (batch * input_height * task_params.input_width * input_depth);
      const int in_y_end =
          ((output_height * stride_rows) - filter_top_offset) + filter_height;
      for (int out_y = 0; out_y < output_height; ++out_y) {
        const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
        const int cache_start_y = std::max(in_y_origin, end_cached_lines);
        const int cache_end_y = std::min(
            in_y_end, std::max((in_y_origin + task_params.cache_height),
                               end_cached_lines));
        if (end_cached_lines < (in_y_origin + filter_height)) {
          // This call breaks up the work required for calculating the mirror
          // padding and resizing across multiple threads.
          FusedConvParallelFor(
              context, cache_start_y, cache_end_y,
              [task_params](int64_t task_cache_start_y,
                            int64_t task_cache_end_y) {
                // This is a long and confusing function, but it's been laid out
                // this way to help with performance on some intensive models.
                // What it's doing is populating a cache of the original input
                // image, after it's been bilinear resized and had its edges
                // mirrored. This allows the following im2col code to access the
                // transformed pixels from this cache, without having to
                // repeatedly apply the expensive bilinear calculations as the
                // same pixels are accessed by different patches.
                // This is most effective when the stride is small and the
                // filter size is large, since that's when pixels are reused
                // most frequently as patches overlap.
                for (int cache_y = task_cache_start_y;
                     cache_y < task_cache_end_y; ++cache_y) {
                  // We organize the cache as a series of rows, each containing
                  // all the transformed pixels for a given line in the image.
                  // This cache is big enough to hold at least a filter's height
                  // worth of rows, but typically more, limited by the size of
                  // the cache buffer.
                  // We don't allocate an entire image's worth of rows though,
                  // because we're trying to keep memory usage down, so as we
                  // progress downwards through the im2col we periodically
                  // refresh the cache so that the next lines that are needed
                  // for that operation are always present.
                  // Work out the parameters that remain constant across the
                  // row we're calculating.
                  PerCacheLineParameters<T1> line_params(
                      CalculatePerCacheLineParameters<T1>(
                          task_params.cache_height, cache_y,
                          task_params.resize_cache,
                          task_params.cache_line_width, task_params.input_width,
                          task_params.input_depth, task_params.top_padding,
                          task_params.pad_offset, task_params.resized_height,
                          task_params.st, task_params.input_batch_start));
                  // Iterate through the resize cache row we're filling in.
                  for (int cache_x = task_params.cache_start_x;
                       cache_x < task_params.cache_end_x; ++cache_x) {
                    // Figure out what we need for the cache pixel we're
                    // populating.
                    PerCachePixelParameters<T1> pixel_params(
                        CalculatePerCachePixelParameters<T1>(
                            cache_x, task_params.cache_start_x,
                            line_params.cache_line_start,
                            task_params.input_depth, task_params.left_padding,
                            task_params.pad_offset, task_params.resized_width,
                            task_params.st));
                    // If the access is off the left, right, top, or bottom of
                    // the resized image, the conv padding means we should set
                    // it to zero.
                    if ((cache_x < 0) ||
                        (cache_x >= task_params.padded_width) ||
                        (cache_y < 0) ||
                        (cache_y >= task_params.padded_height)) {
                      std::fill_n(pixel_params.cache_line_pixel,
                                  task_params.input_depth, T1(0));
                    } else {
                      // There are two different sampling strategies for
                      // resizing. When using nearest, we can just do a
                      // straight copy of the pixel closest to our sample point,
                      // but bilinear requires a more complex calculation.
                      if (SampleMode == NEAREST) {
                        const T1* input_top_left_pixel =
                            line_params.input_top_row_start +
                            (pixel_params.left_x_index *
                             task_params.input_depth);

                        std::copy_n(input_top_left_pixel,
                                    task_params.input_depth,
                                    pixel_params.cache_line_pixel);
                      } else {
                        const SampleRect<T1> rect(
                            line_params.input_top_row_start +
                                (pixel_params.left_x_index *
                                 task_params.input_depth),
                            line_params.input_top_row_start +
                                (pixel_params.right_x_index *
                                 task_params.input_depth),
                            line_params.input_bottom_row_start +
                                (pixel_params.left_x_index *
                                 task_params.input_depth),
                            line_params.input_bottom_row_start +
                                (pixel_params.right_x_index *
                                 task_params.input_depth));
                        for (int in_channel = 0;
                             in_channel < task_params.input_depth;
                             ++in_channel) {
                          pixel_params.cache_line_pixel[in_channel] =
                              rect.BilinearSample(in_channel,
                                                  pixel_params.x_lerp,
                                                  line_params.y_lerp);
                        }
                      }
                    }
                  }
                }
              });
          end_cached_lines = cache_end_y;
        }
        for (int out_x = 0; out_x < output_width; ++out_x) {
          const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
          const int patch_index = (batch * output_width * output_height) +
                                  (out_y * output_width) + out_x;
          const int patch_index_within_chunk = patch_index % patches_per_chunk;
          T1* im2col_patch_start =
              im2col_buffer + (patch_index_within_chunk * filter_value_count);
          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
            T1* im2col_row_start =
                im2col_patch_start +
                (filter_y * filter_width * task_params.input_depth);
            const int conv_in_y = in_y_origin + filter_y;
            int cache_index_y;
            if (conv_in_y < 0) {
              cache_index_y = task_params.cache_height +
                              (conv_in_y % task_params.cache_height);
            } else {
              cache_index_y = conv_in_y % task_params.cache_height;
            }
            T1* cache_line_start =
                task_params.resize_cache +
                (cache_index_y * task_params.cache_line_width *
                 task_params.input_depth);
            T1* cache_filter_row_start =
                cache_line_start + ((in_x_origin - task_params.cache_start_x) *
                                    task_params.input_depth);
            std::copy_n(cache_filter_row_start,
                        (filter_width * task_params.input_depth),
                        im2col_row_start);
          }
          const bool is_last_in_chunk =
              (patch_index_within_chunk == (patches_per_chunk - 1));
          const bool is_last_overall =
              ((batch == (input_batches - 1)) &&
               (out_y == (output_height - 1)) && (out_x == (output_width - 1)));
          if (is_last_in_chunk || is_last_overall) {
            // Now we've assembled a set of image patches into a matrix, apply
            // a GEMM matrix multiply of the patches as rows, times the filter
            // weights in columns, to get partial results in the output
            // matrix.
            const int how_many_patches = patch_index_within_chunk + 1;
            const int m = how_many_patches;
            const int n = filter_count;
            const int k = filter_value_count;
            const int lda = filter_value_count;
            const int ldb = filter_count;
            const int ldc = filter_count;
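            // For instance, with 100 patches in this chunk, a 3x3 filter over
            // 8 channels (filter_value_count = 72), and filter_count = 16,
            // this computes a [100 x 72] * [72 x 16] product that fills a
            // [100 x 16] slice of the output: one row per patch, one column
            // per output channel.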
            const size_t start_patch_index =
                patch_index - (how_many_patches - 1);
            T3* chunk_output_data =
                output_data + (start_patch_index * filter_count);
            TGemmFunctor gemm_functor;
            gemm_functor(context, m, n, k, im2col_buffer, lda, filter_data, ldb,
                         chunk_output_data, ldc);
          }
        }
      }
    }
  }
};

}  // namespace

// Implements a version of convolution with bilinear resizing and mirror padding
// included.
template <class T, class TConvFunctor, bool DoResize>
class FusedResizeConv2DUsingGemmOp : public OpKernel {
 public:
  explicit FusedResizeConv2DUsingGemmOp(OpKernelConstruction* context)
      : OpKernel(context) {
    if (DoResize) {
      OP_REQUIRES_OK(context,
                     context->GetAttr("resize_align_corners", &align_corners_));
    }
    MirrorPadMode mode;
    OP_REQUIRES_OK(context, context->GetAttr("mode", &mode));

    switch (mode) {
      case MirrorPadMode::SYMMETRIC: {
        offset_ = 0;
        break;
      }
      case MirrorPadMode::REFLECT: {
        offset_ = 1;
        break;
      }
      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "mode must be either REFLECT or SYMMETRIC."));
    }
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    const int64_t stride_n = GetTensorDim(strides_, FORMAT_NHWC, 'N');
    const int64_t stride_c = GetTensorDim(strides_, FORMAT_NHWC, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);
    OP_REQUIRES(context, (input.shape().num_elements() > 0),
                errors::InvalidArgument("Input tensor can't be empty"));

    ImageResizerState st(false, false);
    if (DoResize) {
      st = ImageResizerState(align_corners_, false);
      st.ValidateAndCalculateOutputSize(context);
      if (!context->status().ok()) return;
    } else {
      // Set up the resize parameters to do no scaling at all.
      st.batch_size = input.dim_size(0);
      st.out_height = input.dim_size(1);
      st.out_width = input.dim_size(2);
      st.in_height = input.dim_size(1);
      st.in_width = input.dim_size(2);
      st.channels = input.dim_size(3);
      st.height_scale = 1.0f;
      st.width_scale = 1.0f;
    }
    TensorShape resized_shape(
        {input.dim_size(0), st.out_height, st.out_width, input.dim_size(3)});
    int paddings_index;
    int filter_index;
    if (DoResize) {
      paddings_index = 2;
      filter_index = 3;
    } else {
      paddings_index = 1;
      filter_index = 2;
    }
    const Tensor& paddings = context->input(paddings_index);

    const int dims = resized_shape.dims();
    OP_REQUIRES(
        context,
        TensorShapeUtils::IsMatrix(paddings.shape()) &&
            paddings.dim_size(1) == 2,
        errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
                                paddings.shape().DebugString()));
    OP_REQUIRES(
        context, dims == paddings.dim_size(0),
        errors::InvalidArgument(
            "The first dimension of paddings must be the rank of inputs: ",
            dims, " ", paddings.shape().DebugString(), " ",
            resized_shape.DebugString()));

    OP_REQUIRES(
        context, dims == 4,
        errors::InvalidArgument(
            "Fused mirror padding only supports four-dimensional inputs, but ",
            dims, " requested"));

    // Compute the shape of the output tensor, and allocate it.
    TensorShape padded_shape;
    TTypes<int32>::ConstMatrix paddings_matrix = paddings.matrix<int32>();
    for (int d = 0; d < dims; ++d) {
      const int32_t before =
          paddings_matrix(d, 0);  // Pad before existing elements.
      const int32_t after =
          paddings_matrix(d, 1);  // Pad after existing elements.
      OP_REQUIRES(context, before >= 0 && after >= 0,
                  errors::InvalidArgument(
                      "paddings must be non-negative: ", before, " ", after));
      if (offset_ == 0) {  // SYMMETRIC mode.
        OP_REQUIRES(
            context,
            before <= resized_shape.dim_size(d) &&
                after <= resized_shape.dim_size(d),
            errors::InvalidArgument("paddings must be no greater "
                                    "than the dimension size: ",
                                    before, ", ", after, " greater than ",
                                    resized_shape.dim_size(d)));
      } else if (offset_ == 1) {  // REFLECT mode.
        OP_REQUIRES(
            context,
            before < resized_shape.dim_size(d) &&
                after < resized_shape.dim_size(d),
            errors::InvalidArgument("paddings must be less than"
                                    " the dimension size: ",
                                    before, ", ", after, " not less than ",
                                    resized_shape.dim_size(d)));
      }
      padded_shape.AddDim(before + resized_shape.dim_size(d) + after);
    }
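
    // For example, a resized_shape of [1, 224, 224, 3] with paddings of
    // [[0, 0], [2, 2], [2, 2], [0, 0]] produces a padded_shape of
    // [1, 228, 228, 3].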

    OP_REQUIRES(
        context, ((paddings_matrix(0, 0) == 0) && (paddings_matrix(0, 1) == 0)),
        errors::InvalidArgument(
            "Fused mirror padding only supports spatial padding, not batches: ",
            paddings.DebugString()));
    OP_REQUIRES(
        context, ((paddings_matrix(3, 0) == 0) && (paddings_matrix(3, 1) == 0)),
        errors::InvalidArgument(
            "Fused mirror padding only supports spatial padding, not channels: ",
            paddings.DebugString()));
    const int32_t top_padding = paddings_matrix(1, 0);
    const int32_t bottom_padding = paddings_matrix(1, 1);
    const int32_t left_padding = paddings_matrix(2, 0);
    const int32_t right_padding = paddings_matrix(2, 1);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth ]
    const Tensor& filter = context->input(filter_index);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, padded_shape.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional: ",
                                        padded_shape.DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    // We only check the first three dims, since the depth is accessed as an
    // int64 below.
    for (int i = 0; i < 3; i++) {
      OP_REQUIRES(
          context,
          FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
          errors::InvalidArgument("filter too large"));
    }

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64_t in_depth = padded_shape.dim_size(3);
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is out_depth.
    const int out_depth = static_cast<int>(filter.dim_size(3));

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64_t padded_rows_raw = padded_shape.dim_size(1);
    OP_REQUIRES(
        context,
        FastBoundsCheck(padded_rows_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int padded_rows = static_cast<int>(padded_rows_raw);
    const int filter_rows = static_cast<int>(filter.dim_size(0));
    const int resized_rows = static_cast<int>(resized_shape.dim_size(1));

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64_t padded_cols_raw = padded_shape.dim_size(2);
    OP_REQUIRES(
        context,
        FastBoundsCheck(padded_cols_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int padded_cols = static_cast<int>(padded_cols_raw);
    const int filter_cols = static_cast<int>(filter.dim_size(1));
    const int resized_cols = static_cast<int>(resized_shape.dim_size(2));

    // The first dimension for input is batch.
    const int64_t batch_raw = padded_shape.dim_size(0);
    OP_REQUIRES(context,
                FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
                errors::InvalidArgument("batch is too large"));
    const int batch = static_cast<int>(batch_raw);

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, FORMAT_NHWC, 'H');
    const int stride_cols = GetTensorDim(strides_, FORMAT_NHWC, 'W');

    int64_t out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(padded_rows, filter_rows, stride_rows,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(padded_cols, filter_cols, stride_cols,
                                         padding_, &out_cols, &pad_cols));
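    // For instance, with padded_rows = 228, filter_rows = 3, and
    // stride_rows = 1, VALID padding yields out_rows = 226, while SAME
    // padding yields out_rows = 228.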
    TensorShape out_shape =
        ShapeFromFormat(FORMAT_NHWC, batch, out_rows, out_cols, out_depth);
    OP_REQUIRES(context, (out_shape.num_elements() > 0),
                errors::InvalidArgument("Output tensor can't be empty"));

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    VLOG(2) << "FusedConv2D: " << name() << ", in_depth = " << in_depth
            << ", padded_cols = " << padded_cols
            << ", resized_cols = " << resized_cols
            << ", filter_cols = " << filter_cols
            << ", padded_rows = " << padded_rows
            << ", resized_rows = " << resized_rows
            << ", filter_rows = " << filter_rows
            << ", stride_rows = " << stride_rows
            << ", stride_cols = " << stride_cols
            << ", out_depth = " << out_depth << ", DoResize=" << DoResize;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }
    TConvFunctor conv_functor;
    conv_functor(context, input, batch, resized_rows, resized_cols, padded_rows,
                 padded_cols, in_depth, filter.flat<T>().data(), filter_rows,
                 filter_cols, out_depth, stride_rows, stride_cols, padding_,
                 output->flat<T>().data(), out_rows, out_cols, st, top_padding,
                 bottom_padding, left_padding, right_padding, offset_);
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
  bool align_corners_;
  int offset_;

  TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp);
};

#define REGISTER_FUSED(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("FusedResizeAndPadConv2D")                                     \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<T>("T"),                                        \
      FusedResizeConv2DUsingGemmOp<                                       \
          T,                                                              \
          FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
                                       BILINEAR>,                         \
          true>);

TF_CALL_half(REGISTER_FUSED);
TF_CALL_float(REGISTER_FUSED);
TF_CALL_double(REGISTER_FUSED);

#define REGISTER_PAD_ONLY_FUSED(T)                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("FusedPadConv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
      FusedResizeConv2DUsingGemmOp<                                       \
          T,                                                              \
          FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
                                       NEAREST>,                          \
          false>);

TF_CALL_half(REGISTER_PAD_ONLY_FUSED);
TF_CALL_float(REGISTER_PAD_ONLY_FUSED);
TF_CALL_double(REGISTER_PAD_ONLY_FUSED);

}  // namespace tensorflow