1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Implements convolution operations with image transformations (resize and
17 // mirror padding) baked into the processing, to optimize latency and memory
18 // usage.
19
20 #define EIGEN_USE_THREADS
21
22 #include <string>
23 #include <vector>
24
25 #include "tensorflow/core/framework/bounds_check.h"
26 #include "tensorflow/core/framework/kernel_shape_util.h"
27 #include "tensorflow/core/framework/numeric_op.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/register_types.h"
30 #include "tensorflow/core/framework/resource_mgr.h"
31 #include "tensorflow/core/framework/tensor.h"
32 #include "tensorflow/core/framework/tensor_shape.h"
33 #include "tensorflow/core/framework/tensor_slice.h"
34 #include "tensorflow/core/kernels/conv_2d.h"
35 #include "tensorflow/core/kernels/conv_ops.h"
36 #include "tensorflow/core/kernels/gemm_functors.h"
37 #include "tensorflow/core/kernels/ops_util.h"
38 #include "tensorflow/core/lib/core/threadpool.h"
39 #include "tensorflow/core/util/image_resizer_state.h"
40 #include "tensorflow/core/util/mirror_pad_mode.h"
41 #include "tensorflow/core/util/padding.h"
42 #include "tensorflow/core/util/tensor_format.h"
43
44 namespace tensorflow {
45 namespace {
46
// Upper bound for the im2col patch buffer. Materializing every patch at once
// can be enormous, so patches are gathered into a buffer no larger than this
// and processed one chunk at a time, refilling and reusing the same buffer so
// peak memory stays bounded. The limits were picked experimentally: 16MB for
// Android and other Eigen-based platforms, 1MB for iOS devices.
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
constexpr size_t kMaxChunkSize = size_t{1} << 20;  // 1 * 1024 * 1024
#else
constexpr size_t kMaxChunkSize = size_t{16} << 20;  // 16 * 1024 * 1024
#endif
// Size limit for the cache of resized, mirror-padded input rows.
constexpr size_t kResizeCacheSize = size_t{8} << 20;  // 8 * 1024 * 1024
59
// How source pixels are looked up when resizing:
//   BILINEAR - blend the four source pixels surrounding the sample point.
//   NEAREST  - copy a single source pixel (the floor of the sample point).
enum SamplingMode { BILINEAR = 0, NEAREST = 1 };
65
66 // Simple utility function used by FusedConv to multithread basic workloads. To
67 // use it, pass begin and end values for the full workload and a std::function
68 // that receives a subset of that through the begin and end values for each
69 // worker's task. The division of the full workload into worker tasks is handled
70 // by the multithreading logic. Here's an example of how to use it:
71 // std::vector<float> my_vector(100);
72 // ...
73 // FusedConvParallelFor(context, 0, 100,
74 // [&my_vector](int64 task_begin, int64 task_end) {
75 // for (int64 current = task_begin; current != task_end; ++current) {
76 // my_vector[current] *= 10.0f;
77 // }
78 // });
FusedConvParallelFor(OpKernelContext * context,int64 begin,int64 end,const std::function<void (int64,int64)> & task_function)79 void FusedConvParallelFor(
80 OpKernelContext* context, int64 begin, int64 end,
81 const std::function<void(int64, int64)>& task_function) {
82 // On iOS, the thread management imposes a very big performance penalty, so
83 // just call the function directly with no multithreading.
84 #if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
85 task_function(begin, end);
86 #else
87 auto& worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
88 thread::ThreadPool* thread_pool = worker_threads.workers;
89 const int64 total_elements = end - begin;
90 // This is a bit of an arbitrary number, but was found to work well for
91 // typical models we've been profiling on various devices.
92 const int64 element_cost = 10000000;
93 thread_pool->ParallelFor(
94 total_elements, element_cost,
95 [begin, task_function](int64 begin_offset, int64 end_offset) {
96 const int64 task_begin = begin + begin_offset;
97 const int64 task_end = begin + end_offset;
98 task_function(task_begin, task_end);
99 });
100 #endif
101 }
102
103 // Holds the state needed for the resizing subtasks.
104 template <class T1>
105 struct ResizeTaskParameters {
ResizeTaskParameterstensorflow::__anond23836c50111::ResizeTaskParameters106 ResizeTaskParameters() : st(false, false) {}
107
108 int cache_height;
109 T1* resize_cache;
110 int cache_line_width;
111 int input_width;
112 int input_depth;
113 int top_padding;
114 int pad_offset;
115 int64 resized_height;
116 ImageResizerState st;
117 const T1* input_batch_start;
118 int64 cache_start_x;
119 int64 cache_end_x;
120 int left_padding;
121 int64 resized_width;
122 int64 padded_width;
123 int64 padded_height;
124 };
125
// Values that stay constant across one row of the resize cache: where the row
// is stored, which two source rows feed it, and the vertical blend weight.
//
// This is a plain aggregate. The compiler-generated copy/move operations
// perform the memberwise copy, which keeps the type trivially copyable
// (whenever T1 is) and cheap to return by value. Members are zero-initialized
// so a default-constructed instance never holds indeterminate values.
template <class T1>
struct PerCacheLineParameters {
  // Start of this row inside the resize cache.
  T1* cache_line_start = nullptr;
  // Start of the upper source row used for sampling.
  const T1* input_top_row_start = nullptr;
  // Start of the lower source row used for sampling.
  const T1* input_bottom_row_start = nullptr;
  // Vertical blend weight between the upper and lower source rows.
  T1 y_lerp = T1(0);
};
140
141 // Helper class to simplify bilinear filtering
142 template <class T1>
143 struct SampleRect {
SampleRecttensorflow::__anond23836c50111::SampleRect144 EIGEN_ALWAYS_INLINE SampleRect(const T1* in_top_left, const T1* in_top_right,
145 const T1* in_bottom_left,
146 const T1* in_bottom_right)
147 : top_left(in_top_left),
148 top_right(in_top_right),
149 bottom_left(in_bottom_left),
150 bottom_right(in_bottom_right) {}
151
BilinearSampletensorflow::__anond23836c50111::SampleRect152 EIGEN_ALWAYS_INLINE T1 BilinearSample(int channel, T1 x_lerp,
153 T1 y_lerp) const {
154 const T1 top =
155 top_left[channel] + (top_right[channel] - top_left[channel]) * x_lerp;
156 const T1 bottom = bottom_left[channel] +
157 (bottom_right[channel] - bottom_left[channel]) * x_lerp;
158 return top + (bottom - top) * y_lerp;
159 }
160
161 const T1* top_left;
162 const T1* top_right;
163 const T1* bottom_left;
164 const T1* bottom_right;
165 };
166
167 // Calculates parameters which remain constant through a resize cache row.
168 template <class T1>
CalculatePerCacheLineParameters(int64 cache_height,int64 cache_y,T1 * resize_cache,int64 cache_line_width,int64 input_width,int64 input_depth,int64 top_padding,int64 pad_offset,int64 resized_height,const ImageResizerState & st,const T1 * input_batch_start)169 EIGEN_ALWAYS_INLINE PerCacheLineParameters<T1> CalculatePerCacheLineParameters(
170 int64 cache_height, int64 cache_y, T1* resize_cache, int64 cache_line_width,
171 int64 input_width, int64 input_depth, int64 top_padding, int64 pad_offset,
172 int64 resized_height, const ImageResizerState& st,
173 const T1* input_batch_start) {
174 PerCacheLineParameters<T1> result;
175 // The cache is organized so that the real y values of the resized image map
176 // onto the actual cache values through a modulo scheme. This means that as we
177 // progress downwards through the image, we keep reusing a small cache and so
178 // keep memory usage down.
179 int64 cache_index_y;
180 if (cache_y < 0) {
181 cache_index_y = cache_height + (cache_y % cache_height);
182 } else {
183 cache_index_y = cache_y % cache_height;
184 }
185 result.cache_line_start =
186 resize_cache + (cache_index_y * cache_line_width * input_depth);
187 // This part is implementing the mirror padding that happens before resizing.
188 float in_y = (cache_y - top_padding);
189 if (in_y < 0) {
190 in_y = -(in_y + 1.0f - pad_offset);
191 } else if (in_y >= resized_height) {
192 in_y = (resized_height * 2.0f) - (in_y + 1.0f + pad_offset);
193 }
194 // Here's where to do the actual resize.
195 in_y *= st.height_scale;
196 const int64 top_y_index = static_cast<int64>(std::floor(in_y));
197 const int64 bottom_y_index =
198 std::min(static_cast<int64>(std::ceil(in_y)), (st.in_height - 1));
199 // Lerp is used for bilinear filtering when that's needed.
200 result.y_lerp = static_cast<T1>(in_y - top_y_index);
201 // Which rows of the original input image to pull the values from.
202 result.input_top_row_start =
203 input_batch_start + (top_y_index * input_width * input_depth);
204 result.input_bottom_row_start =
205 input_batch_start + (bottom_y_index * input_width * input_depth);
206 return result;
207 }
208
209 template <class T1>
210 struct PerCachePixelParameters {
PerCachePixelParameterstensorflow::__anond23836c50111::PerCachePixelParameters211 PerCachePixelParameters() {}
PerCachePixelParameterstensorflow::__anond23836c50111::PerCachePixelParameters212 PerCachePixelParameters(const PerCachePixelParameters<T1>& other)
213 : cache_line_pixel(other.cache_line_pixel),
214 left_x_index(other.left_x_index),
215 right_x_index(other.right_x_index),
216 x_lerp(other.x_lerp) {}
217
218 T1* cache_line_pixel;
219 int64 left_x_index;
220 int64 right_x_index;
221 T1 x_lerp;
222 };
223
224 // Pulls out common parameters used for every resized pixel.
225 template <class T1>
226 EIGEN_ALWAYS_INLINE PerCachePixelParameters<T1>
CalculatePerCachePixelParameters(int64 cache_x,int64 cache_start_x,T1 * cache_line_start,int64 input_depth,int64 left_padding,int64 pad_offset,int64 resized_width,const ImageResizerState & st)227 CalculatePerCachePixelParameters(int64 cache_x, int64 cache_start_x,
228 T1* cache_line_start, int64 input_depth,
229 int64 left_padding, int64 pad_offset,
230 int64 resized_width,
231 const ImageResizerState& st) {
232 PerCachePixelParameters<T1> result;
233 // Figure out where we're going to store the results of our transform.
234 const int cache_index_x = cache_x - cache_start_x;
235 result.cache_line_pixel = cache_line_start + (cache_index_x * input_depth);
236 // Implement mirror padding by flipping in_x if it's off the edge.
237 float in_x = (cache_x - left_padding);
238 if (in_x < 0) {
239 in_x = -(in_x + 1.0f - pad_offset);
240 } else if (in_x >= resized_width) {
241 in_x = (resized_width * 2.0f) - (in_x + 1.0f + pad_offset);
242 }
243 // Resize the x parameters.
244 in_x *= st.width_scale;
245 // Get the x coordinates for the left and right pixels to pull from.
246 result.left_x_index = static_cast<int64>(std::floor(in_x));
247 result.right_x_index =
248 std::min(static_cast<int64>(std::ceil(in_x)), (st.in_width - 1));
249 // This x_lerp is used to blend pixels in bilinear filtering.
250 result.x_lerp = static_cast<T1>(in_x - result.left_x_index);
251 return result;
252 }
253
254 // Combines bilinear resizing and mirror padding into the im2col transformation
255 // stage of convolution.
256 template <class T1, class T2, class T3, class TGemmFunctor,
257 SamplingMode SampleMode>
258 class FusedResizeAndPadConvFunctor {
259 public:
operator ()(OpKernelContext * context,const Tensor & input,int input_batches,int resized_height,int resized_width,int padded_height,int padded_width,int input_depth,const T2 * filter_data,int filter_height,int filter_width,int filter_count,int stride_rows,int stride_cols,Padding padding,T3 * output_data,int output_height,int output_width,const ImageResizerState & st,int top_padding,int bottom_padding,int left_padding,int right_padding,int pad_offset)260 void operator()(OpKernelContext* context, const Tensor& input,
261 int input_batches, int resized_height, int resized_width,
262 int padded_height, int padded_width, int input_depth,
263 const T2* filter_data, int filter_height, int filter_width,
264 int filter_count, int stride_rows, int stride_cols,
265 Padding padding, T3* output_data, int output_height,
266 int output_width, const ImageResizerState& st,
267 int top_padding, int bottom_padding, int left_padding,
268 int right_padding, int pad_offset) {
269 if ((input_batches <= 0) || (padded_width <= 0) || (padded_height <= 0) ||
270 (input_depth <= 0)) {
271 LOG(WARNING) << "Conv2D was called with bad input dimensions: "
272 << input_batches << ", " << padded_height << ", "
273 << padded_width << ", " << input_depth;
274 return;
275 }
276 if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
277 LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
278 << filter_width << ", " << filter_height << ", "
279 << filter_count;
280 return;
281 }
282 if ((output_width <= 0) || (output_height <= 0)) {
283 LOG(WARNING) << "Conv2D was called with bad output width or height: "
284 << output_width << ", " << output_height;
285 return;
286 }
287 OP_REQUIRES(
288 context, ((SampleMode == NEAREST) || (SampleMode == BILINEAR)),
289 errors::InvalidArgument("Bad sample mode passed in", SampleMode));
290
291 // These calculations define how the patches will be positioned within the
292 // input image. The actual definitions are quite complex, and rely on the
293 // previously-calculated output size.
294 int filter_left_offset;
295 int filter_top_offset;
296 if (padding == VALID) {
297 filter_left_offset =
298 ((output_width - 1) * stride_cols + filter_width - padded_width + 1) /
299 2;
300 filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
301 padded_height + 1) /
302 2;
303 } else {
304 filter_left_offset =
305 ((output_width - 1) * stride_cols + filter_width - padded_width) / 2;
306 filter_top_offset =
307 ((output_height - 1) * stride_rows + filter_height - padded_height) /
308 2;
309 }
310
311 ResizeTaskParameters<T1> task_params;
312 task_params.input_depth = input_depth;
313 task_params.top_padding = top_padding;
314 task_params.pad_offset = pad_offset;
315 task_params.resized_height = resized_height;
316 task_params.st = st;
317 task_params.left_padding = left_padding;
318 task_params.resized_width = resized_width;
319 task_params.padded_width = padded_width;
320 task_params.padded_height = padded_height;
321
322 // The im2col buffer has # of patches rows, and # of filters cols.
323 // It's laid out like this, in row major order in memory:
324 // < filter value count >
325 // ^ +---------------------+
326 // patch | |
327 // count | |
328 // v +---------------------+
329 // Each patch row contains a filter_width x filter_height patch of the
330 // input, with the depth channel as the most contiguous in memory, followed
331 // by the width, then the height. This is the standard memory order in the
332 // image world if it helps to visualize it.
333 const int filter_value_count = filter_width * filter_height * input_depth;
334
335 OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= kMaxChunkSize,
336 errors::InvalidArgument("Im2Col patch too large for buffer"));
337 const size_t patches_per_chunk =
338 kMaxChunkSize / (filter_value_count * sizeof(T1));
339 // Because memory allocation is very expensive on mobile platforms, try to
340 // allocate a persistent buffer that will be kept around between calls. We
341 // use TensorFlow's resource management to ensure that the memory will be
342 // released when the session is over.
343 Im2ColBufferResource<T1, kMaxChunkSize>* im2col_buffer_resource;
344 std::function<Status(Im2ColBufferResource<T1, kMaxChunkSize>**)> creator =
345 [](Im2ColBufferResource<T1, kMaxChunkSize>** resource) {
346 *resource = new Im2ColBufferResource<T1, kMaxChunkSize>();
347 return Status::OK();
348 };
349 OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
350 "Conv2d", "im2col_buffer",
351 &im2col_buffer_resource, creator));
352
353 // Create a resize cache memory buffer that will hold the rows of
354 // transformed and mirror padded input pixels, ready to be copied
355 // into filter patches by im2col.
356 // It's laid out like this, in row major order in memory:
357 // < cache line width >
358 // ^ +--------------------+
359 // cache | |
360 // height | |
361 // v +--------------------+
362 // Each cache row contains a cache_line_width number of resized pixels,
363 // each with input_depth channels. The cache height is typically less than
364 // the full height the resized image would be, so it's filled up
365 // incrementally as we progress downwards through the input creating im2col
366 // patches.
367 task_params.cache_start_x = -filter_left_offset;
368 task_params.cache_end_x =
369 (((output_width - 1) * stride_cols) - filter_left_offset) +
370 filter_width;
371 task_params.cache_line_width =
372 task_params.cache_end_x - task_params.cache_start_x;
373 task_params.cache_height =
374 kResizeCacheSize / (task_params.cache_line_width * input_depth);
375 const int needed_resize_cache_count =
376 filter_height * task_params.cache_line_width * input_depth;
377 OP_REQUIRES(context,
378 (needed_resize_cache_count * sizeof(T1)) <= kResizeCacheSize,
379 errors::InvalidArgument("Input too large for resize cache"));
380 Im2ColBufferResource<T1, kResizeCacheSize>* resize_cache_resource;
381 std::function<Status(Im2ColBufferResource<T1, kResizeCacheSize>**)>
382 resize_creator =
383 [](Im2ColBufferResource<T1, kResizeCacheSize>** resource) {
384 *resource = new Im2ColBufferResource<T1, kResizeCacheSize>();
385 return Status::OK();
386 };
387 OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
388 "Conv2d", "resize_cache",
389 &resize_cache_resource, resize_creator));
390
391 // This means that multiple ops can't be run simultaneously on different
392 // threads, because we have a single shared resource. The platforms this is
393 // aimed at have intra-op parallelism as their focus though, so it shouldn't
394 // be an issue.
395 mutex_lock lock_buffer(im2col_buffer_resource->mu);
396 core::ScopedUnref unref_buffer(im2col_buffer_resource);
397 T1* im2col_buffer = im2col_buffer_resource->data;
398
399 // This buffer is used as a fairly heavy-weight cache for the resized and
400 // mirrored inputs to the im2col operation. The problem is that we want to
401 // keep the memory usage down by not rendering the fully resized and padded
402 // input tensor to the convolution into an entire buffer. The first approach
403 // to avoid this was to fold the bilinear filtering and padding spatial
404 // transformations into the im2col lookup itself. This successfully reduced
405 // memory usage, but because im2col can access an individual pixel for many
406 // different patches, the extra overhead of doing the same bilinear lookups
407 // repeatedly became too expensive.
408 // The resize cache is designed to avoid this problem by keeping a
409 // horizontal slice of the resized and padded input to the im2col
410 // precalculated, so that repeated accesses to the same pixel from different
411 // filter patches can just be copied from this cache. It's organized as a
412 // horizontal slice stretching across the whole virtual image, and as high
413 // as the filter window, so that as the patch processing moves across all
414 // the pixels are present, and before a new row of patches is started any
415 // previously calculated rows that are needed are maintained, with new rows
416 // calculated as required.
417 mutex_lock resize_lock_buffer(resize_cache_resource->mu);
418 core::ScopedUnref unref_resized_cache(resize_cache_resource);
419 task_params.resize_cache = resize_cache_resource->data;
420
421 const T1* input_data = input.flat<T1>().data();
422 const int64 input_height = input.shape().dim_sizes()[1];
423 task_params.input_width = input.shape().dim_sizes()[2];
424
425 int end_cached_lines = std::numeric_limits<int>::min();
426
427 for (int batch = 0; batch < input_batches; ++batch) {
428 task_params.input_batch_start =
429 input_data +
430 (batch * input_height * task_params.input_width * input_depth);
431 const int in_y_end =
432 ((output_height * stride_rows) - filter_top_offset) + filter_height;
433 for (int out_y = 0; out_y < output_height; ++out_y) {
434 const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
435 const int cache_start_y = std::max(in_y_origin, end_cached_lines);
436 const int cache_end_y = std::min(
437 in_y_end, std::max((in_y_origin + task_params.cache_height),
438 end_cached_lines));
439 if (end_cached_lines < (in_y_origin + filter_height)) {
440 // This call breaks up the work required for calculating the mirror
441 // padding and resizing across multiple threads.
442 FusedConvParallelFor(
443 context, cache_start_y, cache_end_y,
444 [task_params](int64 task_cache_start_y, int64 task_cache_end_y) {
445 // This is a long and confusing function, but it's been laid out
446 // this way to help with performance on some intensive models.
447 // What it's doing is populating a cache of the original input
448 // image, after it's been bilinear resized and had its edges
449 // mirrored. This allows the following im2col code to access the
450 // transformed pixels from this cache, without having to
451 // repeatedly apply the expensive bilinear calculations as the
452 // same pixels are accessed by different patches.
453 // This is most effective when the stride is small and the
454 // filter size is large, since that's when pixels are reused
455 // most frequently as patches overlap.
456 for (int cache_y = task_cache_start_y;
457 cache_y < task_cache_end_y; ++cache_y) {
458 // We organize the cache as a series of rows, each containing
459 // all the transformed pixels for a given line in the image.
460 // This cache is big enough to hold at least a filter's height
461 // worth of rows, but typically more, limited by the size of
462 // the cache buffer.
463 // We don't allocate an entire image's worth of rows though,
464 // because we're trying to keep memory usage down, so as we
465 // progress downwards through the im2col we periodically
466 // refresh the cache so that the next lines that are needed
467 // for that operation are always present.
468 // Work out the parameters that remain constant across the
469 // row we're calculating.
470 PerCacheLineParameters<T1> line_params(
471 CalculatePerCacheLineParameters<T1>(
472 task_params.cache_height, cache_y,
473 task_params.resize_cache,
474 task_params.cache_line_width, task_params.input_width,
475 task_params.input_depth, task_params.top_padding,
476 task_params.pad_offset, task_params.resized_height,
477 task_params.st, task_params.input_batch_start));
478 // Iterate through the resize cache row we're filling in.
479 for (int cache_x = task_params.cache_start_x;
480 cache_x < task_params.cache_end_x; ++cache_x) {
481 // Figure out what we need for the cache pixel we're
482 // populating.
483 PerCachePixelParameters<T1> pixel_params(
484 CalculatePerCachePixelParameters<T1>(
485 cache_x, task_params.cache_start_x,
486 line_params.cache_line_start,
487 task_params.input_depth, task_params.left_padding,
488 task_params.pad_offset, task_params.resized_width,
489 task_params.st));
490 // If the access is off the left, right, top, or bottom of
491 // the resized image, the conv padding means we should set
492 // it to zero.
493 if ((cache_x < 0) ||
494 (cache_x >= task_params.padded_width) ||
495 (cache_y < 0) ||
496 (cache_y >= task_params.padded_height)) {
497 std::fill_n(pixel_params.cache_line_pixel,
498 task_params.input_depth, T1(0));
499 } else {
500 // There are two different sampling strategies for
501 // resizing. When using nearest, we can just do a
502 // straight copy of the pixel closest to our sample point,
503 // but bilinear requires a more complex calculation.
504 if (SampleMode == NEAREST) {
505 const T1* input_top_left_pixel =
506 line_params.input_top_row_start +
507 (pixel_params.left_x_index *
508 task_params.input_depth);
509
510 std::copy_n(input_top_left_pixel,
511 task_params.input_depth,
512 pixel_params.cache_line_pixel);
513 } else {
514 const SampleRect<T1> rect(
515 line_params.input_top_row_start +
516 (pixel_params.left_x_index *
517 task_params.input_depth),
518 line_params.input_top_row_start +
519 (pixel_params.right_x_index *
520 task_params.input_depth),
521 line_params.input_bottom_row_start +
522 (pixel_params.left_x_index *
523 task_params.input_depth),
524 line_params.input_bottom_row_start +
525 (pixel_params.right_x_index *
526 task_params.input_depth));
527 for (int in_channel = 0;
528 in_channel < task_params.input_depth;
529 ++in_channel) {
530 pixel_params.cache_line_pixel[in_channel] =
531 rect.BilinearSample(in_channel,
532 pixel_params.x_lerp,
533 line_params.y_lerp);
534 }
535 }
536 }
537 }
538 }
539 });
540 end_cached_lines = cache_end_y;
541 }
542 for (int out_x = 0; out_x < output_width; ++out_x) {
543 const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
544 const int patch_index = (batch * output_width * output_height) +
545 (out_y * output_width) + out_x;
546 const int patch_index_within_chunk = patch_index % patches_per_chunk;
547 T1* im2col_patch_start =
548 im2col_buffer + (patch_index_within_chunk * filter_value_count);
549 for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
550 T1* im2col_row_start =
551 im2col_patch_start +
552 (filter_y * filter_width * task_params.input_depth);
553 const int conv_in_y = in_y_origin + filter_y;
554 int cache_index_y;
555 if (conv_in_y < 0) {
556 cache_index_y = task_params.cache_height +
557 (conv_in_y % task_params.cache_height);
558 } else {
559 cache_index_y = conv_in_y % task_params.cache_height;
560 }
561 T1* cache_line_start =
562 task_params.resize_cache +
563 (cache_index_y * task_params.cache_line_width *
564 task_params.input_depth);
565 T1* cache_filter_row_start =
566 cache_line_start + ((in_x_origin - task_params.cache_start_x) *
567 task_params.input_depth);
568 std::copy_n(cache_filter_row_start,
569 (filter_width * task_params.input_depth),
570 im2col_row_start);
571 }
572 const bool is_last_in_chunk =
573 (patch_index_within_chunk == (patches_per_chunk - 1));
574 const bool is_last_overall =
575 ((batch == (input_batches - 1)) &&
576 (out_y == (output_height - 1)) && (out_x == (output_width - 1)));
577 if (is_last_in_chunk || is_last_overall) {
578 // Now we've assembled a set of image patches into a matrix, apply
579 // a GEMM matrix multiply of the patches as rows, times the filter
580 // weights in columns, to get partial results in the output
581 // matrix.
582 const int how_many_patches = patch_index_within_chunk + 1;
583 const int m = how_many_patches;
584 const int n = filter_count;
585 const int k = filter_value_count;
586 const int lda = filter_value_count;
587 const int ldb = filter_count;
588 const int ldc = filter_count;
589 const size_t start_patch_index =
590 patch_index - (how_many_patches - 1);
591 T3* chunk_output_data =
592 output_data + (start_patch_index * filter_count);
593 TGemmFunctor gemm_functor;
594 gemm_functor(context, m, n, k, im2col_buffer, lda, filter_data, ldb,
595 chunk_output_data, ldc);
596 }
597 }
598 }
599 }
600 }
601 };
602
603 } // namespace
604
605 // Implements a version of convolution with bilinear resizing and mirror padding
606 // included.
607 template <class T, class TConvFunctor, bool DoResize>
608 class FusedResizeConv2DUsingGemmOp : public OpKernel {
609 public:
FusedResizeConv2DUsingGemmOp(OpKernelConstruction * context)610 explicit FusedResizeConv2DUsingGemmOp(OpKernelConstruction* context)
611 : OpKernel(context) {
612 if (DoResize) {
613 OP_REQUIRES_OK(context,
614 context->GetAttr("resize_align_corners", &align_corners_));
615 }
616 MirrorPadMode mode;
617 OP_REQUIRES_OK(context, context->GetAttr("mode", &mode));
618
619 switch (mode) {
620 case MirrorPadMode::SYMMETRIC: {
621 offset_ = 0;
622 break;
623 }
624 case MirrorPadMode::REFLECT: {
625 offset_ = 1;
626 break;
627 }
628 default:
629 OP_REQUIRES(context, false,
630 errors::InvalidArgument(
631 "mode must be either REFLECT or SYMMETRIC."));
632 }
633 OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
634 OP_REQUIRES(context, strides_.size() == 4,
635 errors::InvalidArgument("Sliding window strides field must "
636 "specify 4 dimensions"));
637 const int64 stride_n = GetTensorDim(strides_, FORMAT_NHWC, 'N');
638 const int64 stride_c = GetTensorDim(strides_, FORMAT_NHWC, 'C');
639 OP_REQUIRES(
640 context, stride_n == 1 && stride_c == 1,
641 errors::InvalidArgument("Current implementation does not yet support "
642 "strides in the batch and depth dimensions."));
643 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
644 }
645
Compute(OpKernelContext * context)646 void Compute(OpKernelContext* context) override {
647 // Input tensor is of the following dimensions:
648 // [ batch, in_rows, in_cols, in_depth ]
649 const Tensor& input = context->input(0);
650 OP_REQUIRES(context, (input.shape().num_elements() > 0),
651 errors::InvalidArgument("Input tensor can't be empty"));
652
653 ImageResizerState st(false, false);
654 if (DoResize) {
655 st = ImageResizerState(align_corners_, false);
656 st.ValidateAndCalculateOutputSize(context, input);
657 if (!context->status().ok()) return;
658 } else {
659 // Set up the resize parameters to do no scaling at all.
660 st.batch_size = input.dim_size(0);
661 st.out_height = input.dim_size(1);
662 st.out_width = input.dim_size(2);
663 st.in_height = input.dim_size(1);
664 st.in_width = input.dim_size(2);
665 st.channels = input.dim_size(3);
666 st.height_scale = 1.0f;
667 st.width_scale = 1.0f;
668 }
669 TensorShape resized_shape(
670 {input.dim_size(0), st.out_height, st.out_width, input.dim_size(3)});
671 int paddings_index;
672 int filter_index;
673 if (DoResize) {
674 paddings_index = 2;
675 filter_index = 3;
676 } else {
677 paddings_index = 1;
678 filter_index = 2;
679 }
680 const Tensor& paddings = context->input(paddings_index);
681
682 const int dims = resized_shape.dims();
683 OP_REQUIRES(
684 context,
685 TensorShapeUtils::IsMatrix(paddings.shape()) &&
686 paddings.dim_size(1) == 2,
687 errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
688 paddings.shape().DebugString()));
689 OP_REQUIRES(
690 context, dims == paddings.dim_size(0),
691 errors::InvalidArgument(
692 "The first dimension of paddings must be the rank of inputs: ",
693 dims, " ", paddings.shape().DebugString(), " ",
694 resized_shape.DebugString()));
695 OP_REQUIRES(
696 context, dims == paddings.dim_size(0),
697 errors::InvalidArgument(
698 "The first dimension of paddings must be the rank of inputs: ",
699 dims, " ", paddings.shape().DebugString(), " ",
700 resized_shape.DebugString()));
701
702 OP_REQUIRES(
703 context, dims == 4,
704 errors::InvalidArgument(
705 "Fused mirror padding only supports four-dimensional inputs, but ",
706 dims, " requested"));
707
708 // Compute the shape of the output tensor, and allocate it.
709 TensorShape padded_shape;
710 TTypes<int32>::ConstMatrix paddings_matrix = paddings.matrix<int32>();
711 for (int d = 0; d < dims; ++d) {
712 const int32 before =
713 paddings_matrix(d, 0); // Pad before existing elements.
714 const int32 after =
715 paddings_matrix(d, 1); // Pad after existing elements.
716 OP_REQUIRES(context, before >= 0 && after >= 0,
717 errors::InvalidArgument(
718 "paddings must be non-negative: ", before, " ", after));
719 if (offset_ == 0) { // SYMMETRIC mode.
720 OP_REQUIRES(
721 context,
722 before <= resized_shape.dim_size(d) &&
723 after <= resized_shape.dim_size(d),
724 errors::InvalidArgument("paddings must be no greater "
725 "than the dimension size: ",
726 before, ", ", after, " greater than ",
727 resized_shape.dim_size(d)));
728 } else if (offset_ == 1) { // REFLECT mode.
729 OP_REQUIRES(
730 context,
731 before < resized_shape.dim_size(d) &&
732 after < resized_shape.dim_size(d),
733 errors::InvalidArgument("paddings must be less than"
734 " the dimension size: ",
735 before, ", ", after, " not less than ",
736 resized_shape.dim_size(d)));
737 }
738 padded_shape.AddDim(before + resized_shape.dim_size(d) + after);
739 }
740
741 OP_REQUIRES(
742 context, ((paddings_matrix(0, 0) == 0) && (paddings_matrix(0, 1) == 0)),
743 errors::InvalidArgument(
744 "Fused mirror padding only support spatial padding, not batches: ",
745 paddings.DebugString()));
746 OP_REQUIRES(
747 context, ((paddings_matrix(3, 0) == 0) && (paddings_matrix(3, 1) == 0)),
748 errors::InvalidArgument(
749 "Fused mirror padding only support spatial padding, not channels: ",
750 paddings.DebugString()));
751 const int32 top_padding = paddings_matrix(1, 0);
752 const int32 bottom_padding = paddings_matrix(1, 1);
753 const int32 left_padding = paddings_matrix(2, 0);
754 const int32 right_padding = paddings_matrix(2, 1);
755
756 // Input filter is of the following dimensions:
757 // [ filter_rows, filter_cols, in_depth, out_depth]
758 const Tensor& filter = context->input(filter_index);
759
760 // For 2D convolution, there should be 4 dimensions.
761 OP_REQUIRES(context, padded_shape.dims() == 4,
762 errors::InvalidArgument("input must be 4-dimensional",
763 padded_shape.DebugString()));
764 OP_REQUIRES(context, filter.dims() == 4,
765 errors::InvalidArgument("filter must be 4-dimensional: ",
766 filter.shape().DebugString()));
767
768 // We only check the first three dims, since the depth is accessed as an
769 // int64 below.
770 for (int i = 0; i < 3; i++) {
771 OP_REQUIRES(
772 context,
773 FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
774 errors::InvalidArgument("filter too large"));
775 }
776
777 // The last dimension for input is in_depth. It must be the same as the
778 // filter's in_depth.
779 const int64 in_depth = padded_shape.dim_size(3);
780 OP_REQUIRES(context, in_depth == filter.dim_size(2),
781 errors::InvalidArgument(
782 "input and filter must have the same depth: ", in_depth,
783 " vs ", filter.dim_size(2)));
784
785 // The last dimension for filter is out_depth.
786 const int out_depth = static_cast<int>(filter.dim_size(3));
787
788 // The second dimension for input is rows/height.
789 // The first dimension for filter is rows/height.
790 const int64 padded_rows_raw = padded_shape.dim_size(1);
791 OP_REQUIRES(
792 context,
793 FastBoundsCheck(padded_rows_raw, std::numeric_limits<int>::max()),
794 errors::InvalidArgument("Input rows too large"));
795 const int padded_rows = static_cast<int>(padded_rows_raw);
796 const int filter_rows = static_cast<int>(filter.dim_size(0));
797 const int resized_rows = static_cast<int>(resized_shape.dim_size(1));
798
799 // The third dimension for input is columns/width.
800 // The second dimension for filter is columns/width.
801 const int64 padded_cols_raw = padded_shape.dim_size(2);
802 OP_REQUIRES(
803 context,
804 FastBoundsCheck(padded_cols_raw, std::numeric_limits<int>::max()),
805 errors::InvalidArgument("Input cols too large"));
806 const int padded_cols = static_cast<int>(padded_cols_raw);
807 const int filter_cols = static_cast<int>(filter.dim_size(1));
808 const int resized_cols = static_cast<int>(resized_shape.dim_size(2));
809
810 // The first dimension for input is batch.
811 const int64 batch_raw = padded_shape.dim_size(0);
812 OP_REQUIRES(context,
813 FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
814 errors::InvalidArgument("batch is too large"));
815 const int batch = static_cast<int>(batch_raw);
816
817 // For now we take the stride from the second and third dimensions only (we
818 // do not support striding on the batch or depth dimension).
819 const int stride_rows = GetTensorDim(strides_, FORMAT_NHWC, 'H');
820 const int stride_cols = GetTensorDim(strides_, FORMAT_NHWC, 'W');
821
822 int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
823 OP_REQUIRES_OK(context,
824 GetWindowedOutputSize(padded_rows, filter_rows, stride_rows,
825 padding_, &out_rows, &pad_rows));
826 OP_REQUIRES_OK(context,
827 GetWindowedOutputSize(padded_cols, filter_cols, stride_cols,
828 padding_, &out_cols, &pad_cols));
829 TensorShape out_shape =
830 ShapeFromFormat(FORMAT_NHWC, batch, out_rows, out_cols, out_depth);
831 OP_REQUIRES(context, (out_shape.num_elements() > 0),
832 errors::InvalidArgument("Output tensor can't be empty"));
833
834 // Output tensor is of the following dimensions:
835 // [ in_batch, out_rows, out_cols, out_depth ]
836 Tensor* output = nullptr;
837 OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
838
839 VLOG(2) << "FusedConv2D: " << name() << ", in_depth = " << in_depth
840 << ", padded_cols = " << padded_cols
841 << ", resized_cols = " << resized_cols
842 << ", filter_cols = " << filter_cols
843 << ", padded_rows = " << padded_rows
844 << ", resized_rows = " << resized_rows
845 << ", filter_rows = " << filter_rows
846 << ", stride_rows = " << stride_rows
847 << ", stride_cols = " << stride_cols
848 << ", out_depth = " << out_depth << ", DoResize=" << DoResize;
849
850 // If there is nothing to compute, return.
851 if (out_shape.num_elements() == 0) {
852 return;
853 }
854 TConvFunctor conv_functor;
855 conv_functor(context, input, batch, resized_rows, resized_cols, padded_rows,
856 padded_cols, in_depth, filter.flat<T>().data(), filter_rows,
857 filter_cols, out_depth, stride_rows, stride_cols, padding_,
858 output->flat<T>().data(), out_rows, out_cols, st, top_padding,
859 bottom_padding, left_padding, right_padding, offset_);
860 }
861
862 private:
863 std::vector<int32> strides_;
864 Padding padding_;
865 bool align_corners_;
866 int offset_;
867
868 TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp);
869 };
870
871 #define REGISTER_FUSED(T) \
872 REGISTER_KERNEL_BUILDER( \
873 Name("FusedResizeAndPadConv2D") \
874 .Device(DEVICE_CPU) \
875 .TypeConstraint<T>("T"), \
876 FusedResizeConv2DUsingGemmOp< \
877 T, \
878 FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
879 BILINEAR>, \
880 true>);
881
882 TF_CALL_half(REGISTER_FUSED);
883 TF_CALL_float(REGISTER_FUSED);
884 TF_CALL_double(REGISTER_FUSED);
885
886 #define REGISTER_PAD_ONLY_FUSED(T) \
887 REGISTER_KERNEL_BUILDER( \
888 Name("FusedPadConv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
889 FusedResizeConv2DUsingGemmOp< \
890 T, \
891 FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
892 NEAREST>, \
893 false>);
894
895 TF_CALL_half(REGISTER_PAD_ONLY_FUSED);
896 TF_CALL_float(REGISTER_PAD_ONLY_FUSED);
897 TF_CALL_double(REGISTER_PAD_ONLY_FUSED);
898
899 } // namespace tensorflow
900