/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements convolution operations with image transformations (resize and
// mirror padding) baked into the processing, to optimize latency and memory
// usage.

#define EIGEN_USE_THREADS

#include <string>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/gemm_functors.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/util/image_resizer_state.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

// We don't want to allocate a buffer to hold all the patches if the size is
// going to be extremely large, so break it into chunks if it's bigger than
// a limit. Each chunk will be processed serially, so we can refill the
// buffer for the next chunk and reuse it, keeping maximum memory size down.
// In this case, we've picked 16 megabytes as a reasonable limit for Android and
// other platforms using Eigen, and 1MB for iOS devices, from experimentation.
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
const size_t kMaxChunkSize = (1 * 1024 * 1024);
#else
const size_t kMaxChunkSize = (16 * 1024 * 1024);
#endif
const size_t kResizeCacheSize = (8 * 1024 * 1024);

// Lookup method used when resizing.
enum SamplingMode {
  BILINEAR = 0,
  NEAREST = 1,
};

// Simple utility function used by FusedConv to multithread basic workloads. To
// use it, pass begin and end values for the full workload and a std::function
// that receives a subset of that through the begin and end values for each
// worker's task. The division of the full workload into worker tasks is handled
// by the multithreading logic. Here's an example of how to use it:
// std::vector<float> my_vector(100);
// ...
// FusedConvParallelFor(context, 0, 100,
//   [&my_vector](int64 task_begin, int64 task_end) {
//     for (int64 current = task_begin; current != task_end; ++current) {
//       my_vector[current] *= 10.0f;
//     }
// });
void FusedConvParallelFor(
    OpKernelContext* context, int64 begin, int64 end,
    const std::function<void(int64, int64)>& task_function) {
// On iOS, the thread management imposes a very big performance penalty, so
// just call the function directly with no multithreading.
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
  task_function(begin, end);
#else
  auto& worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
  thread::ThreadPool* thread_pool = worker_threads.workers;
  const int64 total_elements = end - begin;
  // This is a bit of an arbitrary number, but was found to work well for
  // typical models we've been profiling on various devices.
  const int64 element_cost = 10000000;
  thread_pool->ParallelFor(
      total_elements, element_cost,
      [begin, task_function](int64 begin_offset, int64 end_offset) {
        const int64 task_begin = begin + begin_offset;
        const int64 task_end = begin + end_offset;
        task_function(task_begin, task_end);
      });
#endif
}

// Holds the state needed for the resizing subtasks.
template <class T1>
struct ResizeTaskParameters {
  ResizeTaskParameters() : st(false, false) {}

  int cache_height;
  T1* resize_cache;
  int cache_line_width;
  int input_width;
  int input_depth;
  int top_padding;
  int pad_offset;
  int64 resized_height;
  ImageResizerState st;
  const T1* input_batch_start;
  int64 cache_start_x;
  int64 cache_end_x;
  int left_padding;
  int64 resized_width;
  int64 padded_width;
  int64 padded_height;
};

template <class T1>
struct PerCacheLineParameters {
  PerCacheLineParameters() {}
  PerCacheLineParameters(const PerCacheLineParameters<T1>& other)
      : cache_line_start(other.cache_line_start),
        input_top_row_start(other.input_top_row_start),
        input_bottom_row_start(other.input_bottom_row_start),
        y_lerp(other.y_lerp) {}

  T1* cache_line_start;
  const T1* input_top_row_start;
  const T1* input_bottom_row_start;
  T1 y_lerp;
};

// Helper class to simplify bilinear filtering
template <class T1>
struct SampleRect {
  EIGEN_ALWAYS_INLINE SampleRect(const T1* in_top_left, const T1* in_top_right,
                                 const T1* in_bottom_left,
                                 const T1* in_bottom_right)
      : top_left(in_top_left),
        top_right(in_top_right),
        bottom_left(in_bottom_left),
        bottom_right(in_bottom_right) {}

  EIGEN_ALWAYS_INLINE T1 BilinearSample(int channel, T1 x_lerp,
                                        T1 y_lerp) const {
    const T1 top =
        top_left[channel] + (top_right[channel] - top_left[channel]) * x_lerp;
    const T1 bottom = bottom_left[channel] +
                      (bottom_right[channel] - bottom_left[channel]) * x_lerp;
    return top + (bottom - top) * y_lerp;
  }

  const T1* top_left;
  const T1* top_right;
  const T1* bottom_left;
  const T1* bottom_right;
};
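
// A small worked example of BilinearSample, with values chosen purely for
// illustration: top_left = 1, top_right = 3, bottom_left = 5, bottom_right = 7
// for a channel, and x_lerp = 0.5, y_lerp = 0.25 give top = 2, bottom = 6, and
// a result of 2 + (6 - 2) * 0.25 = 3.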

// Calculates parameters which remain constant through a resize cache row.
template <class T1>
EIGEN_ALWAYS_INLINE PerCacheLineParameters<T1> CalculatePerCacheLineParameters(
    int64 cache_height, int64 cache_y, T1* resize_cache, int64 cache_line_width,
    int64 input_width, int64 input_depth, int64 top_padding, int64 pad_offset,
    int64 resized_height, const ImageResizerState& st,
    const T1* input_batch_start) {
  PerCacheLineParameters<T1> result;
  // The cache is organized so that the real y values of the resized image map
  // onto the actual cache values through a modulo scheme. This means that as we
  // progress downwards through the image, we keep reusing a small cache and so
  // keep memory usage down.
  int64 cache_index_y;
  if (cache_y < 0) {
    cache_index_y = cache_height + (cache_y % cache_height);
  } else {
    cache_index_y = cache_y % cache_height;
  }
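  // For example, with cache_height = 4, cache_y = 6 lands in cache row 2, and
  // cache_y = -3 lands in row 4 + (-3 % 4) = 1, so a small fixed window of
  // cache rows is recycled as we move down the image.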
  result.cache_line_start =
      resize_cache + (cache_index_y * cache_line_width * input_depth);
  // This part is implementing the mirror padding that happens before resizing.
  float in_y = (cache_y - top_padding);
  if (in_y < 0) {
    in_y = -(in_y + 1.0f - pad_offset);
  } else if (in_y >= resized_height) {
    in_y = (resized_height * 2.0f) - (in_y + 1.0f + pad_offset);
  }
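  // For example, with top_padding = 2, cache_y = 0 gives in_y = -2; REFLECT
  // mode (pad_offset = 1) maps it to row 2, while SYMMETRIC mode
  // (pad_offset = 0) maps it to row 1, mirroring without or with repeating the
  // edge row respectively.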
  // Here's where to do the actual resize.
  in_y *= st.height_scale;
  const int64 top_y_index = static_cast<int64>(std::floor(in_y));
  const int64 bottom_y_index =
      std::min(static_cast<int64>(std::ceil(in_y)), (st.in_height - 1));
  // Lerp is used for bilinear filtering when that's needed.
  result.y_lerp = static_cast<T1>(in_y - top_y_index);
  // Which rows of the original input image to pull the values from.
  result.input_top_row_start =
      input_batch_start + (top_y_index * input_width * input_depth);
  result.input_bottom_row_start =
      input_batch_start + (bottom_y_index * input_width * input_depth);
  return result;
}

template <class T1>
struct PerCachePixelParameters {
  PerCachePixelParameters() {}
  PerCachePixelParameters(const PerCachePixelParameters<T1>& other)
      : cache_line_pixel(other.cache_line_pixel),
        left_x_index(other.left_x_index),
        right_x_index(other.right_x_index),
        x_lerp(other.x_lerp) {}

  T1* cache_line_pixel;
  int64 left_x_index;
  int64 right_x_index;
  T1 x_lerp;
};

// Pulls out common parameters used for every resized pixel.
template <class T1>
EIGEN_ALWAYS_INLINE PerCachePixelParameters<T1>
CalculatePerCachePixelParameters(int64 cache_x, int64 cache_start_x,
                                 T1* cache_line_start, int64 input_depth,
                                 int64 left_padding, int64 pad_offset,
                                 int64 resized_width,
                                 const ImageResizerState& st) {
  PerCachePixelParameters<T1> result;
  // Figure out where we're going to store the results of our transform.
  const int cache_index_x = cache_x - cache_start_x;
  result.cache_line_pixel = cache_line_start + (cache_index_x * input_depth);
  // Implement mirror padding by flipping in_x if it's off the edge.
  float in_x = (cache_x - left_padding);
  if (in_x < 0) {
    in_x = -(in_x + 1.0f - pad_offset);
  } else if (in_x >= resized_width) {
    in_x = (resized_width * 2.0f) - (in_x + 1.0f + pad_offset);
  }
  // Resize the x parameters.
  in_x *= st.width_scale;
  // Get the x coordinates for the left and right pixels to pull from.
  result.left_x_index = static_cast<int64>(std::floor(in_x));
  result.right_x_index =
      std::min(static_cast<int64>(std::ceil(in_x)), (st.in_width - 1));
  // This x_lerp is used to blend pixels in bilinear filtering.
  result.x_lerp = static_cast<T1>(in_x - result.left_x_index);
  return result;
}

// Combines bilinear resizing and mirror padding into the im2col transformation
// stage of convolution.
template <class T1, class T2, class T3, class TGemmFunctor,
          SamplingMode SampleMode>
class FusedResizeAndPadConvFunctor {
 public:
  void operator()(OpKernelContext* context, const Tensor& input,
                  int input_batches, int resized_height, int resized_width,
                  int padded_height, int padded_width, int input_depth,
                  const T2* filter_data, int filter_height, int filter_width,
                  int filter_count, int stride_rows, int stride_cols,
                  Padding padding, T3* output_data, int output_height,
                  int output_width, const ImageResizerState& st,
                  int top_padding, int bottom_padding, int left_padding,
                  int right_padding, int pad_offset) {
    if ((input_batches <= 0) || (padded_width <= 0) || (padded_height <= 0) ||
        (input_depth <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad input dimensions: "
                   << input_batches << ", " << padded_height << ", "
                   << padded_width << ", " << input_depth;
      return;
    }
    if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
                   << filter_width << ", " << filter_height << ", "
                   << filter_count;
      return;
    }
    if ((output_width <= 0) || (output_height <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad output width or height: "
                   << output_width << ", " << output_height;
      return;
    }
    OP_REQUIRES(
        context, ((SampleMode == NEAREST) || (SampleMode == BILINEAR)),
        errors::InvalidArgument("Bad sample mode passed in", SampleMode));

    // These calculations define how the patches will be positioned within the
    // input image. The actual definitions are quite complex, and rely on the
    // previously-calculated output size.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - padded_width + 1) /
          2;
      filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
                           padded_height + 1) /
                          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - padded_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride_rows + filter_height - padded_height) /
          2;
    }
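    // As a concrete example, with SAME padding, stride 1, a 3-wide filter, and
    // padded_width = output_width = 10, this gives
    // filter_left_offset = ((10 - 1) * 1 + 3 - 10) / 2 = 1, which centers each
    // filter window on its output pixel.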

    ResizeTaskParameters<T1> task_params;
    task_params.input_depth = input_depth;
    task_params.top_padding = top_padding;
    task_params.pad_offset = pad_offset;
    task_params.resized_height = resized_height;
    task_params.st = st;
    task_params.left_padding = left_padding;
    task_params.resized_width = resized_width;
    task_params.padded_width = padded_width;
    task_params.padded_height = padded_height;

    // The im2col buffer has # of patches rows, and # of filters cols.
    // It's laid out like this, in row major order in memory:
    //        < filter value count >
    //   ^   +---------------------+
    // patch |                     |
    // count |                     |
    //   v   +---------------------+
    // Each patch row contains a filter_width x filter_height patch of the
    // input, with the depth channel as the most contiguous in memory, followed
    // by the width, then the height. This is the standard memory order in the
    // image world if it helps to visualize it.
    const int filter_value_count = filter_width * filter_height * input_depth;

    OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= kMaxChunkSize,
                errors::InvalidArgument("Im2Col patch too large for buffer"));
    const size_t patches_per_chunk =
        kMaxChunkSize / (filter_value_count * sizeof(T1));
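    // For example, a 3x3 float filter over 8 input channels gives
    // filter_value_count = 72 values (288 bytes) per patch, so with the 16MB
    // chunk limit patches_per_chunk works out to 16777216 / 288 = 58254.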
    // Because memory allocation is very expensive on mobile platforms, try to
    // allocate a persistent buffer that will be kept around between calls. We
    // use TensorFlow's resource management to ensure that the memory will be
    // released when the session is over.
    Im2ColBufferResource<T1, kMaxChunkSize>* im2col_buffer_resource;
    std::function<Status(Im2ColBufferResource<T1, kMaxChunkSize>**)> creator =
        [](Im2ColBufferResource<T1, kMaxChunkSize>** resource) {
          *resource = new Im2ColBufferResource<T1, kMaxChunkSize>();
          return Status::OK();
        };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "im2col_buffer",
                                &im2col_buffer_resource, creator));

    // Create a resize cache memory buffer that will hold the rows of
    // transformed and mirror padded input pixels, ready to be copied
    // into filter patches by im2col.
    // It's laid out like this, in row major order in memory:
    //         < cache line width >
    //   ^    +--------------------+
    // cache  |                    |
    // height |                    |
    //   v    +--------------------+
    // Each cache row contains a cache_line_width number of resized pixels,
    // each with input_depth channels. The cache height is typically less than
    // the full height the resized image would be, so it's filled up
    // incrementally as we progress downwards through the input creating im2col
    // patches.
    task_params.cache_start_x = -filter_left_offset;
    task_params.cache_end_x =
        (((output_width - 1) * stride_cols) - filter_left_offset) +
        filter_width;
    task_params.cache_line_width =
        task_params.cache_end_x - task_params.cache_start_x;
    task_params.cache_height =
        kResizeCacheSize / (task_params.cache_line_width * input_depth);
    const int needed_resize_cache_count =
        filter_height * task_params.cache_line_width * input_depth;
    OP_REQUIRES(context,
                (needed_resize_cache_count * sizeof(T1)) <= kResizeCacheSize,
                errors::InvalidArgument("Input too large for resize cache"));
    Im2ColBufferResource<T1, kResizeCacheSize>* resize_cache_resource;
    std::function<Status(Im2ColBufferResource<T1, kResizeCacheSize>**)>
        resize_creator =
            [](Im2ColBufferResource<T1, kResizeCacheSize>** resource) {
              *resource = new Im2ColBufferResource<T1, kResizeCacheSize>();
              return Status::OK();
            };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "resize_cache",
                                &resize_cache_resource, resize_creator));

    // This means that multiple ops can't be run simultaneously on different
    // threads, because we have a single shared resource. The platforms this is
    // aimed at have intra-op parallelism as their focus though, so it shouldn't
    // be an issue.
    mutex_lock lock_buffer(im2col_buffer_resource->mu);
    core::ScopedUnref unref_buffer(im2col_buffer_resource);
    T1* im2col_buffer = im2col_buffer_resource->data;

    // This buffer is used as a fairly heavy-weight cache for the resized and
    // mirrored inputs to the im2col operation. The problem is that we want to
    // keep the memory usage down by not rendering the fully resized and padded
    // input tensor to the convolution into an entire buffer. The first approach
    // to avoid this was to fold the bilinear filtering and padding spatial
    // transformations into the im2col lookup itself. This successfully reduced
    // memory usage, but because im2col can access an individual pixel for many
    // different patches, the extra overhead of doing the same bilinear lookups
    // repeatedly became too expensive.
    // The resize cache is designed to avoid this problem by keeping a
    // horizontal slice of the resized and padded input to the im2col
    // precalculated, so that repeated accesses to the same pixel from different
    // filter patches can just be copied from this cache. It's organized as a
    // horizontal slice stretching across the whole virtual image, and as high
    // as the filter window, so that all the pixels are present as the patch
    // processing moves across a row. Before a new row of patches is started,
    // any previously calculated rows that are still needed are kept, and new
    // rows are calculated as required.
    mutex_lock resize_lock_buffer(resize_cache_resource->mu);
    core::ScopedUnref unref_resized_cache(resize_cache_resource);
    task_params.resize_cache = resize_cache_resource->data;

    const T1* input_data = input.flat<T1>().data();
    const int64 input_height = input.shape().dim_sizes()[1];
    task_params.input_width = input.shape().dim_sizes()[2];

    int end_cached_lines = std::numeric_limits<int>::min();

    for (int batch = 0; batch < input_batches; ++batch) {
      task_params.input_batch_start =
          input_data +
          (batch * input_height * task_params.input_width * input_depth);
      const int in_y_end =
          ((output_height * stride_rows) - filter_top_offset) + filter_height;
      for (int out_y = 0; out_y < output_height; ++out_y) {
        const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
        const int cache_start_y = std::max(in_y_origin, end_cached_lines);
        const int cache_end_y = std::min(
            in_y_end, std::max((in_y_origin + task_params.cache_height),
                               end_cached_lines));
        if (end_cached_lines < (in_y_origin + filter_height)) {
          // This call breaks up the work required for calculating the mirror
          // padding and resizing across multiple threads.
          FusedConvParallelFor(
              context, cache_start_y, cache_end_y,
              [task_params](int64 task_cache_start_y, int64 task_cache_end_y) {
                // This is a long and confusing function, but it's been laid out
                // this way to help with performance on some intensive models.
                // What it's doing is populating a cache of the original input
                // image, after it's been bilinear resized and had its edges
                // mirrored. This allows the following im2col code to access the
                // transformed pixels from this cache, without having to
                // repeatedly apply the expensive bilinear calculations as the
                // same pixels are accessed by different patches.
                // This is most effective when the stride is small and the
                // filter size is large, since that's when pixels are reused
                // most frequently as patches overlap.
                for (int cache_y = task_cache_start_y;
                     cache_y < task_cache_end_y; ++cache_y) {
                  // We organize the cache as a series of rows, each containing
                  // all the transformed pixels for a given line in the image.
                  // This cache is big enough to hold at least a filter's height
                  // worth of rows, but typically more, limited by the size of
                  // the cache buffer.
                  // We don't allocate an entire image's worth of rows though,
                  // because we're trying to keep memory usage down, so as we
                  // progress downwards through the im2col we periodically
                  // refresh the cache so that the next lines that are needed
                  // for that operation are always present.
                  // Work out the parameters that remain constant across the
                  // row we're calculating.
                  PerCacheLineParameters<T1> line_params(
                      CalculatePerCacheLineParameters<T1>(
                          task_params.cache_height, cache_y,
                          task_params.resize_cache,
                          task_params.cache_line_width, task_params.input_width,
                          task_params.input_depth, task_params.top_padding,
                          task_params.pad_offset, task_params.resized_height,
                          task_params.st, task_params.input_batch_start));
                  // Iterate through the resize cache row we're filling in.
                  for (int cache_x = task_params.cache_start_x;
                       cache_x < task_params.cache_end_x; ++cache_x) {
                    // Figure out what we need for the cache pixel we're
                    // populating.
                    PerCachePixelParameters<T1> pixel_params(
                        CalculatePerCachePixelParameters<T1>(
                            cache_x, task_params.cache_start_x,
                            line_params.cache_line_start,
                            task_params.input_depth, task_params.left_padding,
                            task_params.pad_offset, task_params.resized_width,
                            task_params.st));
                    // If the access is off the left, right, top, or bottom of
                    // the resized image, the conv padding means we should set
                    // it to zero.
                    if ((cache_x < 0) ||
                        (cache_x >= task_params.padded_width) ||
                        (cache_y < 0) ||
                        (cache_y >= task_params.padded_height)) {
                      std::fill_n(pixel_params.cache_line_pixel,
                                  task_params.input_depth, T1(0));
                    } else {
                      // There are two different sampling strategies for
                      // resizing. When using nearest, we can just do a
                      // straight copy of the pixel closest to our sample point,
                      // but bilinear requires a more complex calculation.
                      if (SampleMode == NEAREST) {
                        const T1* input_top_left_pixel =
                            line_params.input_top_row_start +
                            (pixel_params.left_x_index *
                             task_params.input_depth);

                        std::copy_n(input_top_left_pixel,
                                    task_params.input_depth,
                                    pixel_params.cache_line_pixel);
                      } else {
                        const SampleRect<T1> rect(
                            line_params.input_top_row_start +
                                (pixel_params.left_x_index *
                                 task_params.input_depth),
                            line_params.input_top_row_start +
                                (pixel_params.right_x_index *
                                 task_params.input_depth),
                            line_params.input_bottom_row_start +
                                (pixel_params.left_x_index *
                                 task_params.input_depth),
                            line_params.input_bottom_row_start +
                                (pixel_params.right_x_index *
                                 task_params.input_depth));
                        for (int in_channel = 0;
                             in_channel < task_params.input_depth;
                             ++in_channel) {
                          pixel_params.cache_line_pixel[in_channel] =
                              rect.BilinearSample(in_channel,
                                                  pixel_params.x_lerp,
                                                  line_params.y_lerp);
                        }
                      }
                    }
                  }
                }
              });
          end_cached_lines = cache_end_y;
        }
        for (int out_x = 0; out_x < output_width; ++out_x) {
          const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
          const int patch_index = (batch * output_width * output_height) +
                                  (out_y * output_width) + out_x;
          const int patch_index_within_chunk = patch_index % patches_per_chunk;
          T1* im2col_patch_start =
              im2col_buffer + (patch_index_within_chunk * filter_value_count);
          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
            T1* im2col_row_start =
                im2col_patch_start +
                (filter_y * filter_width * task_params.input_depth);
            const int conv_in_y = in_y_origin + filter_y;
            int cache_index_y;
            if (conv_in_y < 0) {
              cache_index_y = task_params.cache_height +
                              (conv_in_y % task_params.cache_height);
            } else {
              cache_index_y = conv_in_y % task_params.cache_height;
            }
            T1* cache_line_start =
                task_params.resize_cache +
                (cache_index_y * task_params.cache_line_width *
                 task_params.input_depth);
            T1* cache_filter_row_start =
                cache_line_start + ((in_x_origin - task_params.cache_start_x) *
                                    task_params.input_depth);
            std::copy_n(cache_filter_row_start,
                        (filter_width * task_params.input_depth),
                        im2col_row_start);
          }
          const bool is_last_in_chunk =
              (patch_index_within_chunk == (patches_per_chunk - 1));
          const bool is_last_overall =
              ((batch == (input_batches - 1)) &&
               (out_y == (output_height - 1)) && (out_x == (output_width - 1)));
          if (is_last_in_chunk || is_last_overall) {
            // Now that we've assembled a set of image patches into a matrix,
            // apply a GEMM matrix multiply of the patches as rows, times the
            // filter weights in columns, to get partial results in the output
            // matrix.
            const int how_many_patches = patch_index_within_chunk + 1;
            const int m = how_many_patches;
            const int n = filter_count;
            const int k = filter_value_count;
            const int lda = filter_value_count;
            const int ldb = filter_count;
            const int ldc = filter_count;
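            // In GEMM terms the im2col chunk is an (m x k) matrix with one
            // patch per row, the filter data is a (k x n) matrix with one
            // output channel per column, and the result is written into an
            // (m x n) block of the output starting at the chunk's first patch.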
            const size_t start_patch_index =
                patch_index - (how_many_patches - 1);
            T3* chunk_output_data =
                output_data + (start_patch_index * filter_count);
            TGemmFunctor gemm_functor;
            gemm_functor(context, m, n, k, im2col_buffer, lda, filter_data, ldb,
                         chunk_output_data, ldc);
          }
        }
      }
    }
  }
};

}  // namespace

// Implements a version of convolution with bilinear resizing and mirror padding
// included.
template <class T, class TConvFunctor, bool DoResize>
class FusedResizeConv2DUsingGemmOp : public OpKernel {
 public:
  explicit FusedResizeConv2DUsingGemmOp(OpKernelConstruction* context)
      : OpKernel(context) {
    if (DoResize) {
      OP_REQUIRES_OK(context,
                     context->GetAttr("resize_align_corners", &align_corners_));
    }
    MirrorPadMode mode;
    OP_REQUIRES_OK(context, context->GetAttr("mode", &mode));

    switch (mode) {
      case MirrorPadMode::SYMMETRIC: {
        offset_ = 0;
        break;
      }
      case MirrorPadMode::REFLECT: {
        offset_ = 1;
        break;
      }
      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "mode must be either REFLECT or SYMMETRIC."));
    }
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    const int64 stride_n = GetTensorDim(strides_, FORMAT_NHWC, 'N');
    const int64 stride_c = GetTensorDim(strides_, FORMAT_NHWC, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);
    OP_REQUIRES(context, (input.shape().num_elements() > 0),
                errors::InvalidArgument("Input tensor can't be empty"));

    ImageResizerState st(false, false);
    if (DoResize) {
      st = ImageResizerState(align_corners_, false);
      st.ValidateAndCalculateOutputSize(context, input);
      if (!context->status().ok()) return;
    } else {
      // Set up the resize parameters to do no scaling at all.
      st.batch_size = input.dim_size(0);
      st.out_height = input.dim_size(1);
      st.out_width = input.dim_size(2);
      st.in_height = input.dim_size(1);
      st.in_width = input.dim_size(2);
      st.channels = input.dim_size(3);
      st.height_scale = 1.0f;
      st.width_scale = 1.0f;
    }
    TensorShape resized_shape(
        {input.dim_size(0), st.out_height, st.out_width, input.dim_size(3)});
    int paddings_index;
    int filter_index;
    if (DoResize) {
      paddings_index = 2;
      filter_index = 3;
    } else {
      paddings_index = 1;
      filter_index = 2;
    }
    const Tensor& paddings = context->input(paddings_index);

    const int dims = resized_shape.dims();
    OP_REQUIRES(
        context,
        TensorShapeUtils::IsMatrix(paddings.shape()) &&
            paddings.dim_size(1) == 2,
        errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
                                paddings.shape().DebugString()));
    OP_REQUIRES(
        context, dims == paddings.dim_size(0),
        errors::InvalidArgument(
            "The first dimension of paddings must be the rank of inputs: ",
            dims, " ", paddings.shape().DebugString(), " ",
            resized_shape.DebugString()));

    OP_REQUIRES(
        context, dims == 4,
        errors::InvalidArgument(
            "Fused mirror padding only supports four-dimensional inputs, but ",
            dims, " requested"));

    // Compute the shape of the output tensor, and allocate it.
    TensorShape padded_shape;
    TTypes<int32>::ConstMatrix paddings_matrix = paddings.matrix<int32>();
    for (int d = 0; d < dims; ++d) {
      const int32 before =
          paddings_matrix(d, 0);  // Pad before existing elements.
      const int32 after =
          paddings_matrix(d, 1);  // Pad after existing elements.
      OP_REQUIRES(context, before >= 0 && after >= 0,
                  errors::InvalidArgument(
                      "paddings must be non-negative: ", before, " ", after));
      if (offset_ == 0) {  // SYMMETRIC mode.
        OP_REQUIRES(
            context,
            before <= resized_shape.dim_size(d) &&
                after <= resized_shape.dim_size(d),
            errors::InvalidArgument("paddings must be no greater "
                                    "than the dimension size: ",
                                    before, ", ", after, " greater than ",
                                    resized_shape.dim_size(d)));
      } else if (offset_ == 1) {  // REFLECT mode.
        OP_REQUIRES(
            context,
            before < resized_shape.dim_size(d) &&
                after < resized_shape.dim_size(d),
            errors::InvalidArgument("paddings must be less than"
                                    " the dimension size: ",
                                    before, ", ", after, " not less than ",
                                    resized_shape.dim_size(d)));
      }
      padded_shape.AddDim(before + resized_shape.dim_size(d) + after);
    }
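    // For example, a dimension of size 5 accepts paddings of up to 5 in
    // SYMMETRIC mode but only up to 4 in REFLECT mode, since REFLECT does not
    // repeat the edge element.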

    OP_REQUIRES(
        context, ((paddings_matrix(0, 0) == 0) && (paddings_matrix(0, 1) == 0)),
        errors::InvalidArgument(
            "Fused mirror padding only supports spatial padding, not batches: ",
            paddings.DebugString()));
    OP_REQUIRES(
        context, ((paddings_matrix(3, 0) == 0) && (paddings_matrix(3, 1) == 0)),
        errors::InvalidArgument(
            "Fused mirror padding only supports spatial padding, not channels: ",
            paddings.DebugString()));
    const int32 top_padding = paddings_matrix(1, 0);
    const int32 bottom_padding = paddings_matrix(1, 1);
    const int32 left_padding = paddings_matrix(2, 0);
    const int32 right_padding = paddings_matrix(2, 1);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth]
    const Tensor& filter = context->input(filter_index);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, padded_shape.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        padded_shape.DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    // We only check the first three dims, since the depth is accessed as an
    // int64 below.
    for (int i = 0; i < 3; i++) {
      OP_REQUIRES(
          context,
          FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
          errors::InvalidArgument("filter too large"));
    }

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64 in_depth = padded_shape.dim_size(3);
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is out_depth.
    const int out_depth = static_cast<int>(filter.dim_size(3));

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64 padded_rows_raw = padded_shape.dim_size(1);
    OP_REQUIRES(
        context,
        FastBoundsCheck(padded_rows_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int padded_rows = static_cast<int>(padded_rows_raw);
    const int filter_rows = static_cast<int>(filter.dim_size(0));
    const int resized_rows = static_cast<int>(resized_shape.dim_size(1));

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64 padded_cols_raw = padded_shape.dim_size(2);
    OP_REQUIRES(
        context,
        FastBoundsCheck(padded_cols_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int padded_cols = static_cast<int>(padded_cols_raw);
    const int filter_cols = static_cast<int>(filter.dim_size(1));
    const int resized_cols = static_cast<int>(resized_shape.dim_size(2));

    // The first dimension for input is batch.
    const int64 batch_raw = padded_shape.dim_size(0);
    OP_REQUIRES(context,
                FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
                errors::InvalidArgument("batch is too large"));
    const int batch = static_cast<int>(batch_raw);

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, FORMAT_NHWC, 'H');
    const int stride_cols = GetTensorDim(strides_, FORMAT_NHWC, 'W');

    int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(padded_rows, filter_rows, stride_rows,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(padded_cols, filter_cols, stride_cols,
                                         padding_, &out_cols, &pad_cols));
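    // For example, with padded_rows = 10, filter_rows = 3, and stride_rows = 1,
    // SAME padding gives out_rows = 10 while VALID padding gives out_rows = 8.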
    TensorShape out_shape =
        ShapeFromFormat(FORMAT_NHWC, batch, out_rows, out_cols, out_depth);
    OP_REQUIRES(context, (out_shape.num_elements() > 0),
                errors::InvalidArgument("Output tensor can't be empty"));

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    VLOG(2) << "FusedConv2D: " << name() << ", in_depth = " << in_depth
            << ", padded_cols = " << padded_cols
            << ", resized_cols = " << resized_cols
            << ", filter_cols = " << filter_cols
            << ", padded_rows = " << padded_rows
            << ", resized_rows = " << resized_rows
            << ", filter_rows = " << filter_rows
            << ", stride_rows = " << stride_rows
            << ", stride_cols = " << stride_cols
            << ", out_depth = " << out_depth << ", DoResize=" << DoResize;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }
    TConvFunctor conv_functor;
    conv_functor(context, input, batch, resized_rows, resized_cols, padded_rows,
                 padded_cols, in_depth, filter.flat<T>().data(), filter_rows,
                 filter_cols, out_depth, stride_rows, stride_cols, padding_,
                 output->flat<T>().data(), out_rows, out_cols, st, top_padding,
                 bottom_padding, left_padding, right_padding, offset_);
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
  bool align_corners_;
  int offset_;

  TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp);
};

#define REGISTER_FUSED(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("FusedResizeAndPadConv2D")                                     \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<T>("T"),                                        \
      FusedResizeConv2DUsingGemmOp<                                       \
          T,                                                              \
          FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
                                       BILINEAR>,                         \
          true>);

TF_CALL_half(REGISTER_FUSED);
TF_CALL_float(REGISTER_FUSED);
TF_CALL_double(REGISTER_FUSED);

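// FusedPadConv2D below reuses the same functor with DoResize set to false, so
// the resize stage becomes an identity mapping and NEAREST sampling reduces to
// a straight copy of each input pixel.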
#define REGISTER_PAD_ONLY_FUSED(T)                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("FusedPadConv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
      FusedResizeConv2DUsingGemmOp<                                       \
          T,                                                              \
          FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
                                       NEAREST>,                          \
          false>);

TF_CALL_half(REGISTER_PAD_ONLY_FUSED);
TF_CALL_float(REGISTER_PAD_ONLY_FUSED);
TF_CALL_double(REGISTER_PAD_ONLY_FUSED);

}  // namespace tensorflow