/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements convolution operations with image transformations (resize and
// mirror padding) baked into the processing, to optimize latency and memory
// usage.

#define EIGEN_USE_THREADS

#include <string>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/gemm_functors.h"
#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

// We don't want to allocate a buffer to hold all the patches if the size is
// going to be extremely large, so break it into chunks if it's bigger than
// a limit. Each chunk will be processed serially, so we can refill the
// buffer for the next chunk and reuse it, keeping maximum memory size down.
// In this case, we've picked 16 megabytes as a reasonable limit for Android and
// other platforms using Eigen, and 1MB for iOS devices, from experimentation.
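// As a rough illustration of the chunking arithmetic: a float 3x3x256 patch
// occupies 3 * 3 * 256 * 4 bytes = 9,216 bytes, so the 16MB buffer holds
// about 1,800 such patches per chunk, while the 1MB iOS buffer holds about
// 110.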
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
const size_t kMaxChunkSize = (1 * 1024 * 1024);
#else
const size_t kMaxChunkSize = (16 * 1024 * 1024);
#endif
const size_t kResizeCacheSize = (8 * 1024 * 1024);

// Lookup method used when resizing.
enum SamplingMode {
  BILINEAR = 0,
  NEAREST = 1,
};

// Simple utility function used by FusedConv to multithread basic workloads. To
// use it, pass begin and end values for the full workload and a std::function
// that receives a subset of that through the begin and end values for each
// worker's task. The division of the full workload into worker tasks is handled
// by the multithreading logic. Here's an example of how to use it:
// std::vector<float> my_vector(100);
// ...
// FusedConvParallelFor(context, 0, 100,
//   [&my_vector](int64 task_begin, int64 task_end) {
//     for (int64 current = task_begin; current != task_end; ++current) {
//       my_vector[current] *= 10.0f;
//     }
// });
void FusedConvParallelFor(
    OpKernelContext* context, int64 begin, int64 end,
    const std::function<void(int64, int64)>& task_function) {
  // On iOS, the thread management imposes a very big performance penalty, so
  // just call the function directly with no multithreading.
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
  task_function(begin, end);
#else
  auto& worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
  thread::ThreadPool* thread_pool = worker_threads.workers;
  const int64 total_elements = end - begin;
  // This is a bit of an arbitrary number, but was found to work well for
  // typical models we've been profiling on various devices.
  const int64 element_cost = 10000000;
  thread_pool->ParallelFor(
      total_elements, element_cost,
      [begin, task_function](int64 begin_offset, int64 end_offset) {
        const int64 task_begin = begin + begin_offset;
        const int64 task_end = begin + end_offset;
        task_function(task_begin, task_end);
      });
#endif
}

// Holds the state needed for the resizing subtasks.
template <class T1>
struct ResizeTaskParameters {
  ResizeTaskParameters() : st(false, false) {}

  int cache_height;
  T1* resize_cache;
  int cache_line_width;
  int input_width;
  int input_depth;
  int top_padding;
  int pad_offset;
  int64 resized_height;
  ImageResizerState st;
  const T1* input_batch_start;
  int64 cache_start_x;
  int64 cache_end_x;
  int left_padding;
  int64 resized_width;
  int64 padded_width;
  int64 padded_height;
};

template <class T1>
struct PerCacheLineParameters {
  PerCacheLineParameters() {}
  PerCacheLineParameters(const PerCacheLineParameters<T1>& other)
      : cache_line_start(other.cache_line_start),
        input_top_row_start(other.input_top_row_start),
        input_bottom_row_start(other.input_bottom_row_start),
        y_lerp(other.y_lerp) {}

  T1* cache_line_start;
  const T1* input_top_row_start;
  const T1* input_bottom_row_start;
  T1 y_lerp;
};

// Helper class to simplify bilinear filtering.
template <class T1>
struct SampleRect {
  EIGEN_ALWAYS_INLINE SampleRect(const T1* in_top_left, const T1* in_top_right,
                                 const T1* in_bottom_left,
                                 const T1* in_bottom_right)
      : top_left(in_top_left),
        top_right(in_top_right),
        bottom_left(in_bottom_left),
        bottom_right(in_bottom_right) {}

  EIGEN_ALWAYS_INLINE T1 BilinearSample(int channel, T1 x_lerp,
                                        T1 y_lerp) const {
    const T1 top =
        top_left[channel] + (top_right[channel] - top_left[channel]) * x_lerp;
    const T1 bottom = bottom_left[channel] +
                      (bottom_right[channel] - bottom_left[channel]) * x_lerp;
    return top + (bottom - top) * y_lerp;
  }

  const T1* top_left;
  const T1* top_right;
  const T1* bottom_left;
  const T1* bottom_right;
};

// Calculates parameters which remain constant through a resize cache row.
template <class T1>
EIGEN_ALWAYS_INLINE PerCacheLineParameters<T1> CalculatePerCacheLineParameters(
    int64 cache_height, int64 cache_y, T1* resize_cache, int64 cache_line_width,
    int64 input_width, int64 input_depth, int64 top_padding, int64 pad_offset,
    int64 resized_height, const ImageResizerState& st,
    const T1* input_batch_start) {
  PerCacheLineParameters<T1> result;
  // The cache is organized so that the real y values of the resized image map
  // onto the actual cache values through a modulo scheme. This means that as
  // we progress downwards through the image, we keep reusing a small cache and
  // so keep memory usage down.
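  // For example, with cache_height = 4, cache_y = 5 lands in cache row 1 and
  // cache_y = -1 lands in cache row 3 (the negative branch below compensates
  // for C++'s truncating modulo).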
  int64 cache_index_y;
  if (cache_y < 0) {
    cache_index_y = cache_height + (cache_y % cache_height);
  } else {
    cache_index_y = cache_y % cache_height;
  }
  result.cache_line_start =
      resize_cache + (cache_index_y * cache_line_width * input_depth);
  // This part is implementing the mirror padding that happens before resizing.
  float in_y = (cache_y - top_padding);
  if (in_y < 0) {
    in_y = -(in_y + 1.0f - pad_offset);
  } else if (in_y >= resized_height) {
    in_y = (resized_height * 2.0f) - (in_y + 1.0f + pad_offset);
  }
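  // For example, if top_padding is 2, cache_y = 0 gives in_y = -2, which the
  // reflection above maps to 2 in REFLECT mode (pad_offset = 1, edge row not
  // repeated) and to 1 in SYMMETRIC mode (pad_offset = 0, edge row repeated).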
  // Here's where we do the actual resize.
  in_y *= st.height_scale;
  const int64 top_y_index = static_cast<int64>(std::floor(in_y));
  const int64 bottom_y_index =
      std::min(static_cast<int64>(std::ceil(in_y)), (st.in_height - 1));
  // Lerp is used for bilinear filtering when that's needed.
  result.y_lerp = static_cast<T1>(in_y - top_y_index);
  // Which rows of the original input image to pull the values from.
  result.input_top_row_start =
      input_batch_start + (top_y_index * input_width * input_depth);
  result.input_bottom_row_start =
      input_batch_start + (bottom_y_index * input_width * input_depth);
  return result;
}

template <class T1>
struct PerCachePixelParameters {
  PerCachePixelParameters() {}
  PerCachePixelParameters(const PerCachePixelParameters<T1>& other)
      : cache_line_pixel(other.cache_line_pixel),
        left_x_index(other.left_x_index),
        right_x_index(other.right_x_index),
        x_lerp(other.x_lerp) {}

  T1* cache_line_pixel;
  int64 left_x_index;
  int64 right_x_index;
  T1 x_lerp;
};

// Pulls out common parameters used for every resized pixel.
template <class T1>
EIGEN_ALWAYS_INLINE PerCachePixelParameters<T1>
CalculatePerCachePixelParameters(int64 cache_x, int64 cache_start_x,
                                 T1* cache_line_start, int64 input_depth,
                                 int64 left_padding, int64 pad_offset,
                                 int64 resized_width,
                                 const ImageResizerState& st) {
  PerCachePixelParameters<T1> result;
  // Figure out where we're going to store the results of our transform.
  const int cache_index_x = cache_x - cache_start_x;
  result.cache_line_pixel = cache_line_start + (cache_index_x * input_depth);
  // Implement mirror padding by flipping in_x if it's off the edge.
  float in_x = (cache_x - left_padding);
  if (in_x < 0) {
    in_x = -(in_x + 1.0f - pad_offset);
  } else if (in_x >= resized_width) {
    in_x = (resized_width * 2.0f) - (in_x + 1.0f + pad_offset);
  }
  // Resize the x parameters.
  in_x *= st.width_scale;
  // Get the x coordinates for the left and right pixels to pull from.
  result.left_x_index = static_cast<int64>(std::floor(in_x));
  result.right_x_index =
      std::min(static_cast<int64>(std::ceil(in_x)), (st.in_width - 1));
  // This x_lerp is used to blend pixels in bilinear filtering.
  result.x_lerp = static_cast<T1>(in_x - result.left_x_index);
  return result;
}

// Combines bilinear resizing and mirror padding into the im2col transformation
// stage of convolution.
template <class T1, class T2, class T3, class TGemmFunctor,
          SamplingMode SampleMode>
class FusedResizeAndPadConvFunctor {
 public:
  void operator()(OpKernelContext* context, const Tensor& input,
                  int input_batches, int resized_height, int resized_width,
                  int padded_height, int padded_width, int input_depth,
                  const T2* filter_data, int filter_height, int filter_width,
                  int filter_count, int stride_rows, int stride_cols,
                  Padding padding, T3* output_data, int output_height,
                  int output_width, const ImageResizerState& st,
                  int top_padding, int bottom_padding, int left_padding,
                  int right_padding, int pad_offset) {
    if ((input_batches <= 0) || (padded_width <= 0) || (padded_height <= 0) ||
        (input_depth <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad input dimensions: "
                   << input_batches << ", " << padded_height << ", "
                   << padded_width << ", " << input_depth;
      return;
    }
    if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
                   << filter_width << ", " << filter_height << ", "
                   << filter_count;
      return;
    }
    if ((output_width <= 0) || (output_height <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad output width or height: "
                   << output_width << ", " << output_height;
      return;
    }
    OP_REQUIRES(
        context, ((SampleMode == NEAREST) || (SampleMode == BILINEAR)),
        errors::InvalidArgument("Bad sample mode passed in", SampleMode));

    // These calculations define how the patches will be positioned within the
    // input image. The actual definitions are quite complex, and rely on the
    // previously-calculated output size.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - padded_width + 1) /
          2;
      filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
                           padded_height + 1) /
                          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - padded_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride_rows + filter_height - padded_height) /
          2;
    }
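    // As a sanity check on these formulas: with padded_width = 10,
    // filter_width = 3 and stride_cols = 1, SAME padding gives
    // output_width = 10 and filter_left_offset = 1, while VALID padding gives
    // output_width = 8 and filter_left_offset = 0.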

    ResizeTaskParameters<T1> task_params;
    task_params.input_depth = input_depth;
    task_params.top_padding = top_padding;
    task_params.pad_offset = pad_offset;
    task_params.resized_height = resized_height;
    task_params.st = st;
    task_params.left_padding = left_padding;
    task_params.resized_width = resized_width;
    task_params.padded_width = padded_width;
    task_params.padded_height = padded_height;

    // The im2col buffer has # of patches rows, and # of filters cols.
    // It's laid out like this, in row major order in memory:
    //        < filter value count >
    //   ^   +---------------------+
    // patch |                     |
    // count |                     |
    //   v   +---------------------+
    // Each patch row contains a filter_width x filter_height patch of the
    // input, with the depth channel as the most contiguous in memory, followed
    // by the width, then the height. This is the standard memory order in the
    // image world if it helps to visualize it.
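    // For example, a 3x3 filter over 8 input channels gives a
    // filter_value_count of 3 * 3 * 8 = 72, so each im2col row is 72
    // contiguous values ordered depth-first, then filter x, then filter y.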
    const int filter_value_count = filter_width * filter_height * input_depth;

    OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= kMaxChunkSize,
                errors::InvalidArgument("Im2Col patch too large for buffer"));
    const size_t patches_per_chunk =
        kMaxChunkSize / (filter_value_count * sizeof(T1));
    // Because memory allocation is very expensive on mobile platforms, try to
    // allocate a persistent buffer that will be kept around between calls. We
    // use TensorFlow's resource management to ensure that the memory will be
    // released when the session is over.
    Im2ColBufferResource<T1, kMaxChunkSize>* im2col_buffer_resource;
    std::function<Status(Im2ColBufferResource<T1, kMaxChunkSize>**)> creator =
        [](Im2ColBufferResource<T1, kMaxChunkSize>** resource) {
          *resource = new Im2ColBufferResource<T1, kMaxChunkSize>();
          return Status::OK();
        };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "im2col_buffer",
                                &im2col_buffer_resource, creator));

    // Create a resize cache memory buffer that will hold the rows of
    // transformed and mirror padded input pixels, ready to be copied
    // into filter patches by im2col.
    // It's laid out like this, in row major order in memory:
    //         < cache line width >
    //   ^    +--------------------+
    // cache  |                    |
    // height |                    |
    //   v    +--------------------+
    // Each cache row contains a cache_line_width number of resized pixels,
    // each with input_depth channels. The cache height is typically less than
    // the full height the resized image would be, so it's filled up
    // incrementally as we progress downwards through the input creating im2col
    // patches.
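    // The cache line spans from the leftmost pixel touched by the first patch
    // in a row (-filter_left_offset) to one past the rightmost pixel touched
    // by the last patch, so any patch row can be copied straight out of the
    // cache with a single std::copy_n below.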
    task_params.cache_start_x = -filter_left_offset;
    task_params.cache_end_x =
        (((output_width - 1) * stride_cols) - filter_left_offset) +
        filter_width;
    task_params.cache_line_width =
        task_params.cache_end_x - task_params.cache_start_x;
    task_params.cache_height =
        kResizeCacheSize / (task_params.cache_line_width * input_depth);
    const int needed_resize_cache_count =
        filter_height * task_params.cache_line_width * input_depth;
    OP_REQUIRES(context,
                (needed_resize_cache_count * sizeof(T1)) <= kResizeCacheSize,
                errors::InvalidArgument("Input too large for resize cache"));
    Im2ColBufferResource<T1, kResizeCacheSize>* resize_cache_resource;
    std::function<Status(Im2ColBufferResource<T1, kResizeCacheSize>**)>
        resize_creator =
            [](Im2ColBufferResource<T1, kResizeCacheSize>** resource) {
              *resource = new Im2ColBufferResource<T1, kResizeCacheSize>();
              return Status::OK();
            };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "resize_cache",
                                &resize_cache_resource, resize_creator));

    // This means that multiple ops can't be run simultaneously on different
    // threads, because we have a single shared resource. The platforms this is
    // aimed at have intra-op parallelism as their focus though, so it shouldn't
    // be an issue.
    mutex_lock lock_buffer(im2col_buffer_resource->mu);
    core::ScopedUnref unref_buffer(im2col_buffer_resource);
    T1* im2col_buffer = im2col_buffer_resource->data;

    // This buffer is used as a fairly heavy-weight cache for the resized and
    // mirrored inputs to the im2col operation. The problem is that we want to
    // keep the memory usage down by not rendering the fully resized and padded
    // input tensor to the convolution into an entire buffer. The first approach
    // to avoid this was to fold the bilinear filtering and padding spatial
    // transformations into the im2col lookup itself. This successfully reduced
    // memory usage, but because im2col can access an individual pixel for many
    // different patches, the extra overhead of doing the same bilinear lookups
    // repeatedly became too expensive.
    // filter patches can just be copied from this cache. It's organized as a
    // horizontal slice stretching across the whole virtual image, and as high
    // as the filter window, so that all the pixels a row of patches needs are
    // present as processing moves across the image. Before a new row of
    // patches is started, any previously calculated rows that are still needed
    // are kept, and new rows are calculated as required.
    mutex_lock resize_lock_buffer(resize_cache_resource->mu);
    core::ScopedUnref unref_resized_cache(resize_cache_resource);
    task_params.resize_cache = resize_cache_resource->data;

    const T1* input_data = input.flat<T1>().data();
    const int64 input_height = input.shape().dim_sizes()[1];
    task_params.input_width = input.shape().dim_sizes()[2];

    int end_cached_lines = std::numeric_limits<int>::min();

    for (int batch = 0; batch < input_batches; ++batch) {
      task_params.input_batch_start =
          input_data +
          (batch * input_height * task_params.input_width * input_depth);
      const int in_y_end =
          ((output_height * stride_rows) - filter_top_offset) + filter_height;
      for (int out_y = 0; out_y < output_height; ++out_y) {
        const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
        const int cache_start_y = std::max(in_y_origin, end_cached_lines);
        const int cache_end_y = std::min(
            in_y_end, std::max((in_y_origin + task_params.cache_height),
                               end_cached_lines));
        if (end_cached_lines < (in_y_origin + filter_height)) {
          // This call breaks up the work required for calculating the mirror
          // padding and resizing across multiple threads.
          FusedConvParallelFor(
              context, cache_start_y, cache_end_y,
              [task_params](int64 task_cache_start_y, int64 task_cache_end_y) {
                // This is a long and confusing function, but it's been laid
                // out this way to help with performance on some intensive
                // models. What it's doing is populating a cache of the
                // original input image, after it's been bilinear resized and
                // had its edges mirrored. This allows the following im2col
                // code to access the transformed pixels from this cache,
                // without having to repeatedly apply the expensive bilinear
                // calculations as the same pixels are accessed by different
                // patches.
                // This is most effective when the stride is small and the
                // filter size is large, since that's when pixels are reused
                // most frequently as patches overlap.
                for (int cache_y = task_cache_start_y;
                     cache_y < task_cache_end_y; ++cache_y) {
                  // We organize the cache as a series of rows, each containing
                  // all the transformed pixels for a given line in the image.
                  // This cache is big enough to hold at least a filter's
                  // height worth of rows, but typically more, limited by the
                  // size of the cache buffer.
                  // We don't allocate an entire image's worth of rows though,
                  // because we're trying to keep memory usage down, so as we
                  // progress downwards through the im2col we periodically
                  // refresh the cache so that the next lines that are needed
                  // for that operation are always present.
                  // Work out the parameters that remain constant across the
                  // row we're calculating.
                  PerCacheLineParameters<T1> line_params(
                      CalculatePerCacheLineParameters<T1>(
                          task_params.cache_height, cache_y,
                          task_params.resize_cache,
                          task_params.cache_line_width,
                          task_params.input_width,
                          task_params.input_depth, task_params.top_padding,
                          task_params.pad_offset, task_params.resized_height,
                          task_params.st, task_params.input_batch_start));
                  // Iterate through the resize cache row we're filling in.
                  for (int cache_x = task_params.cache_start_x;
                       cache_x < task_params.cache_end_x; ++cache_x) {
                    // Figure out what we need for the cache pixel we're
                    // populating.
                    PerCachePixelParameters<T1> pixel_params(
                        CalculatePerCachePixelParameters<T1>(
                            cache_x, task_params.cache_start_x,
                            line_params.cache_line_start,
                            task_params.input_depth, task_params.left_padding,
                            task_params.pad_offset, task_params.resized_width,
                            task_params.st));
                    // If the access is off the left, right, top, or bottom of
                    // the resized image, the conv padding means we should set
                    // it to zero.
                    if ((cache_x < 0) ||
                        (cache_x >= task_params.padded_width) ||
                        (cache_y < 0) ||
                        (cache_y >= task_params.padded_height)) {
                      std::fill_n(pixel_params.cache_line_pixel,
                                  task_params.input_depth, T1(0));
                    } else {
                      // There are two different sampling strategies for
                      // resizing. When using nearest, we can just do a
                      // straight copy of the pixel closest to our sample
                      // point, but bilinear requires a more complex
                      // calculation.
                      if (SampleMode == NEAREST) {
                        const T1* input_top_left_pixel =
                            line_params.input_top_row_start +
                            (pixel_params.left_x_index *
                             task_params.input_depth);

                        std::copy_n(input_top_left_pixel,
                                    task_params.input_depth,
                                    pixel_params.cache_line_pixel);
                      } else {
                        const SampleRect<T1> rect(
                            line_params.input_top_row_start +
                                (pixel_params.left_x_index *
                                 task_params.input_depth),
                            line_params.input_top_row_start +
                                (pixel_params.right_x_index *
                                 task_params.input_depth),
                            line_params.input_bottom_row_start +
                                (pixel_params.left_x_index *
                                 task_params.input_depth),
                            line_params.input_bottom_row_start +
                                (pixel_params.right_x_index *
                                 task_params.input_depth));
                        for (int in_channel = 0;
                             in_channel < task_params.input_depth;
                             ++in_channel) {
                          pixel_params.cache_line_pixel[in_channel] =
                              rect.BilinearSample(in_channel,
                                                  pixel_params.x_lerp,
                                                  line_params.y_lerp);
                        }
                      }
                    }
                  }
                }
              });
          end_cached_lines = cache_end_y;
        }
        for (int out_x = 0; out_x < output_width; ++out_x) {
          const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
          const int patch_index = (batch * output_width * output_height) +
                                  (out_y * output_width) + out_x;
          const int patch_index_within_chunk = patch_index % patches_per_chunk;
          T1* im2col_patch_start =
              im2col_buffer + (patch_index_within_chunk * filter_value_count);
          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
            T1* im2col_row_start =
                im2col_patch_start +
                (filter_y * filter_width * task_params.input_depth);
            const int conv_in_y = in_y_origin + filter_y;
            int cache_index_y;
            if (conv_in_y < 0) {
              cache_index_y = task_params.cache_height +
                              (conv_in_y % task_params.cache_height);
            } else {
              cache_index_y = conv_in_y % task_params.cache_height;
            }
            T1* cache_line_start =
                task_params.resize_cache +
                (cache_index_y * task_params.cache_line_width *
                 task_params.input_depth);
            T1* cache_filter_row_start =
                cache_line_start + ((in_x_origin - task_params.cache_start_x) *
                                    task_params.input_depth);
            std::copy_n(cache_filter_row_start,
                        (filter_width * task_params.input_depth),
                        im2col_row_start);
          }
          const bool is_last_in_chunk =
              (patch_index_within_chunk == (patches_per_chunk - 1));
          const bool is_last_overall =
              ((batch == (input_batches - 1)) &&
               (out_y == (output_height - 1)) && (out_x == (output_width - 1)));
          if (is_last_in_chunk || is_last_overall) {
            // Now we've assembled a set of image patches into a matrix, apply
            // a GEMM matrix multiply of the patches as rows, times the filter
            // weights in columns, to get partial results in the output
            // matrix.
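            // Concretely, this is a [m x n] = [m x k] * [k x n] product: the
            // im2col chunk is the [how_many_patches x filter_value_count]
            // left-hand matrix, the filter data (viewed as
            // [filter_value_count x filter_count]) is the right-hand matrix,
            // and each output row holds one patch's filter_count results.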
            const int how_many_patches = patch_index_within_chunk + 1;
            const int m = how_many_patches;
            const int n = filter_count;
            const int k = filter_value_count;
            const int lda = filter_value_count;
            const int ldb = filter_count;
            const int ldc = filter_count;
            const size_t start_patch_index =
                patch_index - (how_many_patches - 1);
            T3* chunk_output_data =
                output_data + (start_patch_index * filter_count);
            TGemmFunctor gemm_functor;
            gemm_functor(context, m, n, k, im2col_buffer, lda, filter_data, ldb,
                         chunk_output_data, ldc);
          }
        }
      }
    }
  }
};

}  // namespace

// Implements a version of convolution with bilinear resizing and mirror
// padding included.
template <class T, class TConvFunctor, bool DoResize>
class FusedResizeConv2DUsingGemmOp : public OpKernel {
 public:
  explicit FusedResizeConv2DUsingGemmOp(OpKernelConstruction* context)
      : OpKernel(context) {
    if (DoResize) {
      OP_REQUIRES_OK(
          context, context->GetAttr("resize_align_corners", &align_corners_));
    }
    MirrorPadMode mode;
    OP_REQUIRES_OK(context, context->GetAttr("mode", &mode));

    switch (mode) {
      case MirrorPadMode::SYMMETRIC: {
        offset_ = 0;
        break;
      }
      case MirrorPadMode::REFLECT: {
        offset_ = 1;
        break;
      }
      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "mode must be either REFLECT or SYMMETRIC."));
    }
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    const int64 stride_n = GetTensorDim(strides_, FORMAT_NHWC, 'N');
    const int64 stride_c = GetTensorDim(strides_, FORMAT_NHWC, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);
    OP_REQUIRES(context, (input.shape().num_elements() > 0),
                errors::InvalidArgument("Input tensor can't be empty"));

    ImageResizerState st(false, false);
    if (DoResize) {
      st = ImageResizerState(align_corners_, false);
      st.ValidateAndCalculateOutputSize(context, input);
      if (!context->status().ok()) return;
    } else {
      // Set up the resize parameters to do no scaling at all.
      st.batch_size = input.dim_size(0);
      st.out_height = input.dim_size(1);
      st.out_width = input.dim_size(2);
      st.in_height = input.dim_size(1);
      st.in_width = input.dim_size(2);
      st.channels = input.dim_size(3);
      st.height_scale = 1.0f;
      st.width_scale = 1.0f;
    }
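    // In the no-resize case the 1.0f scales map every output coordinate
    // exactly onto a source pixel, so the sampling step downstream reduces to
    // a straight copy of the mirror-padded input.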
    TensorShape resized_shape(
        {input.dim_size(0), st.out_height, st.out_width, input.dim_size(3)});
    int paddings_index;
    int filter_index;
    if (DoResize) {
      paddings_index = 2;
      filter_index = 3;
    } else {
      paddings_index = 1;
      filter_index = 2;
    }
    const Tensor& paddings = context->input(paddings_index);

    const int dims = resized_shape.dims();
    OP_REQUIRES(
        context,
        TensorShapeUtils::IsMatrix(paddings.shape()) &&
            paddings.dim_size(1) == 2,
        errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
                                paddings.shape().DebugString()));
    const int fixed_dims =
        (allow_legacy_scalars() && dims == 0 && paddings.dim_size(0) == 1)
            ? 1
            : dims;
    OP_REQUIRES(
        context, fixed_dims == paddings.dim_size(0),
        errors::InvalidArgument(
            "The first dimension of paddings must be the rank of inputs: ",
            fixed_dims, " ", paddings.shape().DebugString(), " ",
            resized_shape.DebugString()));
    OP_REQUIRES(
        context, dims == paddings.dim_size(0),
        errors::InvalidArgument(
            "The first dimension of paddings must be the rank of inputs: ",
            dims, " ", paddings.shape().DebugString(), " ",
            resized_shape.DebugString()));

    OP_REQUIRES(
        context, dims == 4,
        errors::InvalidArgument(
            "Fused mirror padding only supports four-dimensional inputs, but ",
            dims, " requested"));

    // Compute the shape of the output tensor, and allocate it.
    TensorShape padded_shape;
    TTypes<int32>::ConstMatrix paddings_matrix = paddings.matrix<int32>();
    for (int d = 0; d < dims; ++d) {
      const int32 before =
          paddings_matrix(d, 0);  // Pad before existing elements.
      const int32 after =
          paddings_matrix(d, 1);  // Pad after existing elements.
      OP_REQUIRES(context, before >= 0 && after >= 0,
                  errors::InvalidArgument(
                      "paddings must be non-negative: ", before, " ", after));
      if (offset_ == 0) {  // SYMMETRIC mode.
        OP_REQUIRES(
            context,
            before <= resized_shape.dim_size(d) &&
                after <= resized_shape.dim_size(d),
            errors::InvalidArgument("paddings must be no greater "
                                    "than the dimension size: ",
                                    before, ", ", after, " greater than ",
                                    resized_shape.dim_size(d)));
      } else if (offset_ == 1) {  // REFLECT mode.
        OP_REQUIRES(
            context,
            before < resized_shape.dim_size(d) &&
                after < resized_shape.dim_size(d),
            errors::InvalidArgument("paddings must be less than"
                                    " the dimension size: ",
                                    before, ", ", after, " not less than ",
                                    resized_shape.dim_size(d)));
      }
      padded_shape.AddDim(before + resized_shape.dim_size(d) + after);
    }
    OP_REQUIRES(
        context, ((paddings_matrix(0, 0) == 0) && (paddings_matrix(0, 1) == 0)),
        errors::InvalidArgument(
            "Fused mirror padding only supports spatial padding, not batches: ",
            paddings.DebugString()));
    OP_REQUIRES(
        context, ((paddings_matrix(3, 0) == 0) && (paddings_matrix(3, 1) == 0)),
        errors::InvalidArgument(
            "Fused mirror padding only supports spatial padding, not channels: ",
            paddings.DebugString()));
    const int32 top_padding = paddings_matrix(1, 0);
    const int32 bottom_padding = paddings_matrix(1, 1);
    const int32 left_padding = paddings_matrix(2, 0);
    const int32 right_padding = paddings_matrix(2, 1);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth]
    const Tensor& filter = context->input(filter_index);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, padded_shape.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        padded_shape.DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    // We only check the first three dims, since the depth is accessed as an
    // int64 below.
    for (int i = 0; i < 3; i++) {
      OP_REQUIRES(
          context,
          FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
          errors::InvalidArgument("filter too large"));
    }

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64 in_depth = padded_shape.dim_size(3);
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is out_depth.
    const int out_depth = static_cast<int>(filter.dim_size(3));

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64 padded_rows_raw = padded_shape.dim_size(1);
    OP_REQUIRES(
        context,
        FastBoundsCheck(padded_rows_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int padded_rows = static_cast<int>(padded_rows_raw);
    const int filter_rows = static_cast<int>(filter.dim_size(0));
    const int resized_rows = static_cast<int>(resized_shape.dim_size(1));

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64 padded_cols_raw = padded_shape.dim_size(2);
    OP_REQUIRES(
        context,
        FastBoundsCheck(padded_cols_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int padded_cols = static_cast<int>(padded_cols_raw);
    const int filter_cols = static_cast<int>(filter.dim_size(1));
    const int resized_cols = static_cast<int>(resized_shape.dim_size(2));

    // The first dimension for input is batch.
    const int64 batch_raw = padded_shape.dim_size(0);
    OP_REQUIRES(context,
                FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
                errors::InvalidArgument("batch is too large"));
    const int batch = static_cast<int>(batch_raw);

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, FORMAT_NHWC, 'H');
    const int stride_cols = GetTensorDim(strides_, FORMAT_NHWC, 'W');

    int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(padded_rows, filter_rows, stride_rows,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(padded_cols, filter_cols, stride_cols,
                                         padding_, &out_cols, &pad_cols));
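    // GetWindowedOutputSize applies the usual conv arithmetic: for SAME
    // padding out = ceil(in / stride), and for VALID padding
    // out = ceil((in - filter + 1) / stride).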
    TensorShape out_shape =
        ShapeFromFormat(FORMAT_NHWC, batch, out_rows, out_cols, out_depth);
    OP_REQUIRES(context, (out_shape.num_elements() > 0),
                errors::InvalidArgument("Output tensor can't be empty"));

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    VLOG(2) << "FusedConv2D: " << name() << ", in_depth = " << in_depth
            << ", padded_cols = " << padded_cols
            << ", resized_cols = " << resized_cols
            << ", filter_cols = " << filter_cols
            << ", padded_rows = " << padded_rows
            << ", resized_rows = " << resized_rows
            << ", filter_rows = " << filter_rows
            << ", stride_rows = " << stride_rows
            << ", stride_cols = " << stride_cols
            << ", out_depth = " << out_depth << ", DoResize=" << DoResize;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }
    TConvFunctor conv_functor;
    conv_functor(context, input, batch, resized_rows, resized_cols, padded_rows,
                 padded_cols, in_depth, filter.flat<T>().data(), filter_rows,
                 filter_cols, out_depth, stride_rows, stride_cols, padding_,
                 output->flat<T>().data(), out_rows, out_cols, st, top_padding,
                 bottom_padding, left_padding, right_padding, offset_);
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
  bool align_corners_;
  int offset_;

  TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp);
};

#define REGISTER_FUSED(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("FusedResizeAndPadConv2D")                                     \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<T>("T"),                                        \
      FusedResizeConv2DUsingGemmOp<                                       \
          T,                                                              \
          FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
                                       BILINEAR>,                         \
          true>);

TF_CALL_half(REGISTER_FUSED);
TF_CALL_float(REGISTER_FUSED);
TF_CALL_double(REGISTER_FUSED);

#define REGISTER_PAD_ONLY_FUSED(T)                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("FusedPadConv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
      FusedResizeConv2DUsingGemmOp<                                       \
          T,                                                              \
          FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
                                       NEAREST>,                          \
          false>);

TF_CALL_half(REGISTER_PAD_ONLY_FUSED);
TF_CALL_float(REGISTER_PAD_ONLY_FUSED);
TF_CALL_double(REGISTER_PAD_ONLY_FUSED);

}  // namespace tensorflow