/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/deep_conv2d.h"

#include <stdlib.h>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/winograd_transform.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

// DeepConv2D is a Conv2D implementation specialized for deep convolutions
// (i.e. a large 'in_depth' * 'out_depth' product; see the cost models below
// for details).
//
// DeepConv2D is implemented by computing the following equation:
//
//   y = C[Ad * Bg]
//
//   C: output transform matrix
//   A: input data transform matrix
//   B: filter transform matrix
//   d: vectorized data tile
//   g: vectorized filter tile
//   y: vectorized output tile
//
// The transform matrices and input, filter and output tile sizes are all
// specified by the DeepConv2DTransform implementation selected at the
// start of the DeepConv2D call, based on convolution parameters.
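//
// For example, with the 3x3 Winograd transform selected below (assuming
// WinogradTransform<T> implements F(2x2, 3x3)): 'd' is a 4x4 input tile,
// 'g' is a 3x3 filter tile, and 'y' is the resulting 2x2 output tile, so
// A, B and C are [16 x 16], [16 x 9] and [4 x 16] respectively.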

// Approximate cost models for direct and deep convolutions.
static int64 GetDeepConvCost(int input_tile_rows, int input_tile_cols,
                             int out_tile_rows, int out_tile_cols, int in_depth,
                             int out_depth, int out_rows, int out_cols) {
  // Input transform cost.
  const int64_t input_tile_spatial_size = input_tile_rows * input_tile_cols;
  const int64_t input_transform_cost =
      input_tile_spatial_size * input_tile_spatial_size * in_depth;

  // Element-wise products (each product is a MatMul across depth).
  const int64_t product_cost = input_tile_spatial_size * in_depth * out_depth;

  // Output transform cost.
  const int64_t output_tile_spatial_size = out_tile_rows * out_tile_cols;
  const int64_t output_transform_cost =
      output_tile_spatial_size * input_tile_spatial_size * out_depth;

  // Calculate number of input tiles to process.
  const int64_t row_tiles = (out_rows + out_tile_rows - 1) / out_tile_rows;
  const int64_t col_tiles = (out_cols + out_tile_cols - 1) / out_tile_cols;
  const int64_t num_tiles = row_tiles * col_tiles;

  // Return total cost.
  return num_tiles *
         (input_transform_cost + product_cost + output_transform_cost);
}
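
// For example, for a 3x3 filter with in_depth = out_depth = 64 and a 32x32
// output, using 4x4 input tiles and 2x2 output tiles, the model above gives
// 256 tiles * (16*16*64 + 16*64*64 + 4*16*64) = ~22.0M ops, versus
// 3*3*64*64*32*32 = ~37.7M ops for GetDirectConvCost below, so the deep
// path would be selected.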

static int64 GetDirectConvCost(int filter_rows, int filter_cols, int in_depth,
                               int out_depth, int out_rows, int out_cols) {
  return filter_rows * filter_cols * in_depth * out_depth * out_rows *
         out_cols;
}

// Reads environment variable 'env_var_name'.
// Returns 'true' if environment variable is enabled, false otherwise.
static bool ReadBoolFromEnvVar(const char* env_var_name, bool default_val) {
  const char* tf_env_var_val = getenv(env_var_name);
  if (tf_env_var_val != nullptr) {
    StringPiece tf_env_var_val_str(tf_env_var_val);
    if (tf_env_var_val_str == "0") {
      return false;
    }
    return true;
  }
  return default_val;
}
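
// Example usage: setting the environment variable TF_USE_DEEP_CONV2D=1
// before running a model enables the deep convolution path checked below;
// any value other than "0" counts as enabled, and the default is disabled.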

// Returns true if convolution can be computed efficiently by DeepConv2D,
// returns false otherwise.
// TODO(andydavis) Add support for other filter sizes and strides.
// TODO(andydavis) Add support for autotuning.
bool CanUseDeepConv2D(int stride_rows, int stride_cols, int filter_rows,
                      int filter_cols, int in_depth, int out_depth,
                      int out_rows, int out_cols) {
  // Check if convolution parameters are supported.
  // TODO(andydavis) Add support for multiple filter sizes and strides.
  if (stride_rows > 1 || stride_cols > 1 || filter_rows != 3 ||
      filter_cols != 3) {
    return false;
  }

  // Check if deep convolution is enabled by environment variable.
  // NOTE: If this environment variable name changes, update conv_ops_test.py.
  if (!ReadBoolFromEnvVar("TF_USE_DEEP_CONV2D", false)) {
    return false;
  }

  // Check if flop cost of deep convolution is less than direct convolution.
  WinogradTransform<float> t;
  const int64_t deep_conv_cost = GetDeepConvCost(
      t.input_shape().rows, t.input_shape().cols, t.output_shape().rows,
      t.output_shape().cols, in_depth, out_depth, out_rows, out_cols);
  const int64_t direct_conv_cost = GetDirectConvCost(
      filter_rows, filter_cols, in_depth, out_depth, out_rows, out_cols);

  VLOG(2) << "CanUseDeepConv2D"
          << " deep_conv_cost: " << deep_conv_cost
          << " direct_conv_cost: " << direct_conv_cost << " deep_direct_ratio: "
          << (static_cast<float>(deep_conv_cost) /
              static_cast<float>(direct_conv_cost))
          << " use_deep_conv: " << (deep_conv_cost < direct_conv_cost);
  return deep_conv_cost < direct_conv_cost;
}

typedef Eigen::ThreadPoolDevice CPUDevice;

// Copies data from 'filter_in' to 'filter_buf' along 'in_depth' dimension.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_buf:
//   [base_filter_rows, base_filter_cols, in_depth]
//
template <typename T>
struct CopyFilterDepth {
  void operator()(const Conv2DArgs& args, const T* filter_in, T* filter_buf) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static constexpr int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64_t vectorized_size = args.in_depth / kPacketSize;
    const int64_t scalar_size = args.in_depth % kPacketSize;
    const int64_t input_stride = args.out_depth * kPacketSize;

    // Copy vectorized portion of depth dimension.
    for (int64_t d = 0; d < vectorized_size; ++d) {
      auto v = Eigen::internal::pgather<T, Packet>(filter_in + d * input_stride,
                                                   args.out_depth);
      Eigen::internal::pstoreu<T>(filter_buf + d * kPacketSize, v);
    }
    // Copy scalar portion of inner dimension.
    const int64_t in_scalar_base = vectorized_size * input_stride;
    const int64_t buf_scalar_base = vectorized_size * kPacketSize;
    for (int64_t d = 0; d < scalar_size; ++d) {
      filter_buf[buf_scalar_base + d] =
          filter_in[in_scalar_base + d * args.out_depth];
    }
  }
};
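
// As a concrete illustration of the gather above (assuming kPacketSize = 4
// and out_depth = 8): the packet for d = 0 is loaded from filter_in[0],
// filter_in[8], filter_in[16] and filter_in[24], i.e. one value per
// 'in_depth' slot at a fixed output-depth coordinate, which the strided
// pgather collects into a single contiguous packet.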

// Computes transform of 'num_filters' from 'filter_in' starting at 'od_start'.
// Intermediate results (i.e. output of MatMul('transform_matrix', 'filter_in'))
// are stored in 'out_buffer'. The final result is copied from 'out_buffer' to
// 'filter_out' at the coordinate stride required by the transformed filter
// data layout.
//
// filter_in:
//   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
//    in_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
// transform_matrix:
//   [tile_spatial_size, base_filter_spatial_size]
//
// out_buffer:
//   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]

template <typename T>
struct ComputeFilterRangeTransform {
  typedef typename Eigen::internal::packet_traits<T>::type Packet;
  static constexpr int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t od_start, const int64_t num_filters,
                  const int64_t shard_rows, const int64_t shard_cols,
                  const T* filter_in, const int64_t in_stride,
                  const int64_t out_stride, const T* transform_matrix,
                  T* out_buffer, T* filter_out) {
    namespace ei = Eigen::internal;

    const int64_t in_depth = args.in_depth;
    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;
    const int64_t base_filter_spatial_size =
        base_filter_rows * base_filter_cols;
    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    // Compute transform of 'num_filters' by 'transform_matrix'.
    ConstMatrixMap A(transform_matrix, tile_spatial_size,
                     base_filter_spatial_size);
    ConstMatrixMap B(filter_in, base_filter_spatial_size, in_stride);
    MatrixMap C(out_buffer, tile_spatial_size, in_stride);

    C.noalias() = A * B;

    // Copy 'out_buffer' to 'filter_out' at required filter output stride.
    const int64_t scalar_size = in_depth % kPacketSize;
    const int64_t vectorized_size = in_depth / kPacketSize;

    const int64_t shard_stride = args.in_depth;
    const int64_t out_depth_stride = shard_rows * shard_cols * shard_stride;

    for (int64_t od = 0; od < num_filters; ++od) {
      const int64_t out_depth_buf_base = od * out_depth_stride;
      const int64_t out_depth_base = (od_start + od) * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64_t s_r = 0; s_r < shard_rows; ++s_r) {
        for (int64_t s_c = 0; s_c < shard_cols; ++s_c) {
          const int64_t shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64_t i = 0; i < tile_spatial_size; ++i) {
            const int64_t in_base =
                i * in_stride + out_depth_buf_base + shard_base;
            const int64_t out_base =
                i * out_stride + out_depth_base + shard_base;
            // Copy vectorized portion of 'in_depth'.
            for (int64_t d = 0; d < vectorized_size; ++d) {
              auto v =
                  ei::ploadu<Packet>(out_buffer + in_base + d * kPacketSize);
              ei::pstoreu<T>(filter_out + out_base + d * kPacketSize, v);
            }
            // Copy scalar portion of 'in_depth'.
            const int64_t scalar_base = vectorized_size * kPacketSize;
            for (int64_t d = 0; d < scalar_size; ++d) {
              filter_out[out_base + scalar_base + d] =
                  out_buffer[in_base + scalar_base + d];
            }
          }
        }
      }
    }
  }
};

// Transforms 'num_filters' from 'filter_in', starting at 'od_start'.
// For each filter in 'num_filters', copies data for all filter shards from
// 'filter_in' into 'filter_buf', adding zero-padding as needed.
// Calls ComputeFilterRangeTransform to compute filter transform of data
// in 'filter_buf' by 'transform_matrix', storing the result in 'filter_out'.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
// filter_buffer:
//   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
//    in_depth]
//
// transform_matrix:
//   [tile_spatial_size, base_filter_spatial_size]
//
// out_buffer:
//   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]
//

template <typename T>
struct TransformFilterRange {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t od_start, const int64_t od_limit,
                  const T* filter_in, const T* transform_matrix, T* out_buffer,
                  T* filter_buf, T* filter_out) {
    const int64_t num_filters = od_limit - od_start;
    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;
    const int64_t base_filter_spatial_size =
        base_filter_rows * base_filter_cols;

    // Compute number of filter shards.
    const int64_t residual_row =
        std::max(int64{0}, args.filter_rows - base_filter_rows);
    const int64_t shard_rows = 1 + (residual_row + 2 - 1) / 2;

    const int64_t residual_col =
        std::max(int64{0}, args.filter_cols - base_filter_cols);
    const int64_t shard_cols = 1 + (residual_col + 2 - 1) / 2;
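
    // For example (illustrative only, since CanUseDeepConv2D currently
    // admits only 3x3 filters): a 5x5 filter over a 3x3 base filter leaves
    // residual_row = residual_col = 2, so shard_rows = shard_cols =
    // 1 + ceil(2 / 2) = 2, i.e. a 2x2 grid of filter shards.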

    // Compute strides to be used for input and output IO.
    const int64_t shard_stride = args.in_depth;
    const int64_t out_depth_stride = shard_rows * shard_cols * shard_stride;
    const int64_t coord_stride = out_depth_stride * args.out_depth;
    const int64_t filter_buf_stride =
        num_filters * shard_rows * shard_cols * args.in_depth;
    const int64_t tile_stride_rows = transform->output_shape().rows;
    const int64_t tile_stride_cols = transform->output_shape().cols;

    const int64_t filter_buf_size = base_filter_spatial_size * num_filters *
                                    shard_rows * shard_cols * args.in_depth;
    memset(filter_buf, 0, sizeof(T) * filter_buf_size);

    // Copy filter range into 'filter_buf'.
    for (int64_t od = 0; od < num_filters; ++od) {
      const int64_t out_depth_base = od * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64_t s_r = 0; s_r < shard_rows; ++s_r) {
        const int64_t row_offset = s_r == 0 ? 0 : 1;

        for (int64_t s_c = 0; s_c < shard_cols; ++s_c) {
          const int64_t col_offset = s_c == 0 ? 0 : 1;
          const int64_t f_r_start = s_r * tile_stride_rows;
          const int64_t f_c_start = s_c * tile_stride_cols;

          const int64_t shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64_t b_r = row_offset; b_r < base_filter_rows; ++b_r) {
            const int64_t f_r = f_r_start + b_r;
            if (f_r >= args.filter_rows) continue;

            for (int64_t b_c = col_offset; b_c < base_filter_cols; ++b_c) {
              const int64_t f_c = f_c_start + b_c;
              if (f_c >= args.filter_cols) continue;

              const int64_t in_index =
                  args.out_depth *
                      (args.in_depth * (f_r * args.filter_cols + f_c)) +
                  (od_start + od);

              const int64_t buf_index =
                  filter_buf_stride * (b_r * base_filter_cols + b_c) +
                  out_depth_base + shard_base;

              CopyFilterDepth<T>()(args, filter_in + in_index,
                                   filter_buf + buf_index);
            }
          }
        }
      }
    }

    // Compute filter transform of data in 'filter_buf' by 'transform_matrix'.
    // Intermediate results are stored in 'out_buffer'.
    // Final results are stored in 'filter_out'.
    ComputeFilterRangeTransform<T>()(args, transform, od_start, num_filters,
                                     shard_rows, shard_cols, filter_buf,
                                     filter_buf_stride, coord_stride,
                                     transform_matrix, out_buffer, filter_out);
  }
};

// Transforms all filters from 'filter_in', storing result in 'filter_out'.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
template <typename T>
struct TransformFilters {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t filter_shards_row,
                  const int64_t filter_shards_col, const T* filter_in,
                  T* filter_out) {
    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;

    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;
    const int64_t base_filter_spatial_size =
        base_filter_rows * base_filter_cols;

    const int64_t filter_shards_total = filter_shards_row * filter_shards_col;

    // Calculate filter transform batch based on cache/filter sizes.

    // Cache budget (based on L2 cache size = 256KB).
    // TODO(andydavis) Read cache size from system.
    const int64_t cache_size = (256LL << 10) / sizeof(T);

    // Fixed cost.
    const int64_t filter_transform_matrix_size =
        tile_spatial_size * base_filter_spatial_size;

    // Per-filter costs.
    const int64_t filter_total_size =
        base_filter_spatial_size * in_depth * filter_shards_total;

    const int64_t filter_transform_buffer_size =
        base_filter_spatial_size * filter_shards_total * in_depth;

    const int64_t filter_out_buf_size =
        tile_spatial_size * filter_shards_total * in_depth;

    // Total per-filter costs.
    const int64_t per_filter_cost =
        filter_total_size + filter_transform_buffer_size + filter_out_buf_size;

    // Remove fixed cost and divide by per-filter cost.
    const int64_t num_filters_cache =
        std::max(int64{1},
                 (cache_size - filter_transform_matrix_size) / per_filter_cost);
    const int64_t num_filters_transform =
        std::min(out_depth, num_filters_cache);
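
    // As an illustrative sizing (assuming T = float, so cache_size = 65536
    // elements, a 4x4 tile, a 3x3 base filter, one shard, and in_depth = 64):
    // per_filter_cost = 9*64 + 9*64 + 16*64 = 2176 elements, so roughly
    // (65536 - 144) / 2176 = 30 filters are transformed per batch.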

    // Allocate buffer for filter transform matrix:
    //   [tile_spatial_size, base_filter_spatial_size]
    Tensor filter_transform_matrix;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 TensorShape({tile_spatial_size, base_filter_spatial_size}),
                 &filter_transform_matrix));
    T* transform_matrix = filter_transform_matrix.template flat<T>().data();
    transform->GetFilterTransformMatrix(
        tile_spatial_size, base_filter_spatial_size, transform_matrix);

    auto shard = [&ctx, &args, &transform, &base_filter_rows, &base_filter_cols,
                  &num_filters_transform, &in_depth, &filter_shards_row,
                  &filter_shards_col, &tile_spatial_size, &filter_in,
                  &transform_matrix,
                  &filter_out](int64_t start, int64_t limit) {
      // Allocate buffer for pre-processed filter:
      //   [base_filter_rows, base_filter_cols, num_filters_transform,
      //    shard_rows, shard_cols, in_depth]
      //
      Tensor filter_transform_buffer;
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(
                         DataTypeToEnum<T>::value,
                         TensorShape({base_filter_rows, base_filter_cols,
                                      num_filters_transform, filter_shards_row,
                                      filter_shards_col, in_depth}),
                         &filter_transform_buffer));
      T* filter_buf = filter_transform_buffer.template flat<T>().data();

      // Allocate buffer for output filter transform matrix:
      //   [tile_spatial_size, num_filters_transform, shard_rows, shard_cols,
      //    in_depth]
      Tensor filter_output_buffer;
      OP_REQUIRES_OK(
          ctx,
          ctx->allocate_temp(
              DataTypeToEnum<T>::value,
              TensorShape({tile_spatial_size, num_filters_transform,
                           filter_shards_row, filter_shards_col, in_depth}),
              &filter_output_buffer));
      T* out_buffer = filter_output_buffer.template flat<T>().data();

      const int64_t num_filters = limit - start;
      const int64_t od_unroll = num_filters_transform;
      const int64_t od_unroll_limit = (num_filters / od_unroll) * od_unroll;

      for (int64_t od = start; od < od_unroll_limit; od += od_unroll) {
        TransformFilterRange<T>()(args, transform, od, od + od_unroll,
                                  filter_in, transform_matrix, out_buffer,
                                  filter_buf, filter_out);
      }

      if (od_unroll_limit < limit) {
        TransformFilterRange<T>()(args, transform, od_unroll_limit, limit,
                                  filter_in, transform_matrix, out_buffer,
                                  filter_buf, filter_out);
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    const int64_t shard_cost = args.filter_rows * args.filter_cols * in_depth *
                               filter_shards_total * tile_spatial_size;
    // TODO(andydavis) Resolve performance of multi-threaded filter transforms.
    Shard(1, worker_threads.workers, out_depth, shard_cost, shard);
  }
};

// Packs transformed filters stored in 'lhs_input' into 'lhs_block' in a
// gemm-kernel friendly data layout.
//
// Data layout for 'lhs_block':
//   [out_depth, shard_rows, shard_cols, in_depth].

template <typename T>
class GemmFilterPacker {
 public:
  typedef Eigen::internal::const_blas_data_mapper<T, int64, Eigen::RowMajor>
      LhsMapper;
  typedef Eigen::internal::gebp_traits<T, T> Traits;
  Eigen::internal::gemm_pack_lhs<
      T, int64, LhsMapper, Traits::mr, Traits::LhsProgress,
      typename Traits::LhsPacket4Packing, Eigen::RowMajor>
      pack_lhs;

  GemmFilterPacker(const int64_t rows, const int64_t depth, const T* lhs_input,
                   T* lhs_block)
      : rows_(rows),
        depth_(depth),
        lhs_block_(lhs_block),
        lhs_mapper_(lhs_input, depth_) {}

  void Run() { pack_lhs(lhs_block_, lhs_mapper_, depth_, rows_); }

 private:
  const int64 rows_;
  const int64 depth_;
  T* lhs_block_;
  LhsMapper lhs_mapper_;
};

// Packs transformed filter stored in 'filter_transform_data' into
// 'packed_filters' to be used by GemmState.
template <typename T>
struct PackFilters {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
                  const int64_t tile_spatial_size,
                  const int64_t filter_shards_row,
                  const int64_t filter_shards_col,
                  const T* filter_transform_data,
                  std::vector<Tensor>* packed_filters) {
    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;
    const int64_t num_filters =
        filter_shards_row * filter_shards_col * out_depth;

    auto shard = [&ctx, &packed_filters, &filter_transform_data, &in_depth,
                  &out_depth, &filter_shards_row, &filter_shards_col,
                  &num_filters](int64_t start, int64_t limit) {
      const int64_t filter_coord_stride = num_filters * in_depth;
      for (int64_t i = start; i < limit; ++i) {
        // Allocate filter buffer [out_depth, shard_rows, shard_cols, in_depth].
        OP_REQUIRES_OK(
            ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                    TensorShape({out_depth, filter_shards_row,
                                                 filter_shards_col, in_depth}),
                                    &(*packed_filters)[i]));
        T* packed_filter = (*packed_filters)[i].template flat<T>().data();
        // Pack filters.
        GemmFilterPacker<T> packer(
            num_filters, in_depth,
            filter_transform_data + i * filter_coord_stride, packed_filter);
        packer.Run();
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, tile_spatial_size,
          num_filters * in_depth, shard);
  }
};

// Computes the product of filters stored in 'lhs_block' and input tiles
// stored in 'rhs_block', storing output in 'out_buffer'.
//
// Data layout for 'lhs_block':
//   [out_depth, shard_rows, shard_cols, in_depth].
//
// Data layout for 'rhs_block':
//   [num_tiles, in_depth]
//
// Data layout for 'out_buffer':
//   [num_tiles, out_depth, shard_rows, shard_cols]
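//
// For example, each per-coordinate product below is a [rows x depth] *
// [depth x cols] gemm with rows = out_depth * shard_rows * shard_cols,
// depth = in_depth and cols = num_tiles (see the GemmState construction in
// ComputeConv2D below).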

template <typename T>
class GemmState {
 public:
  typedef Eigen::internal::const_blas_data_mapper<T, int64, Eigen::ColMajor>
      RhsMapper;
  typedef Eigen::internal::blas_data_mapper<T, int64, Eigen::ColMajor>
      OutputMapper;
  typedef Eigen::internal::gebp_traits<T, T> Traits;

  Eigen::internal::gemm_pack_rhs<T, int64, RhsMapper, Traits::nr,
                                 Eigen::ColMajor>
      pack_rhs;
  Eigen::internal::gebp_kernel<T, T, int64, OutputMapper, Traits::mr,
                               Traits::nr, false, false>
      gebp;

  GemmState(const int64_t rows, const int64_t cols, const int64_t depth,
            const int64_t out_buffer_size, const T* lhs_block,
            const T* rhs_input, T* rhs_block, T* out_buffer)
      : rows_(rows),
        cols_(cols),
        depth_(depth),
        out_buffer_size_(out_buffer_size),
        lhs_block_(lhs_block),
        rhs_block_(rhs_block),
        out_buffer_(out_buffer),
        rhs_mapper_(rhs_input, depth_),
        out_mapper_(out_buffer, rows_) {}

  void PackRhs() { pack_rhs(rhs_block_, rhs_mapper_, depth_, cols_); }

  void Compute() {
    memset(out_buffer_, 0, sizeof(T) * out_buffer_size_);
    gebp(out_mapper_, lhs_block_, rhs_block_, rows_, depth_, cols_, 1.0);
  }

 private:
  const int64 rows_;
  const int64 cols_;
  const int64 depth_;
  const int64 out_buffer_size_;
  const T* lhs_block_;
  T* rhs_block_;
  T* out_buffer_;
  RhsMapper rhs_mapper_;
  OutputMapper out_mapper_;
};

// Copies an input tile from 'input' into 'tile_buffer'.
//
// input:
//   [in_rows, in_cols, in_depth]
//
// tile_buffer:
//   [tile_rows, tile_cols, num_tiles, in_depth]

template <typename T>
struct CopyInputTile {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t num_tiles, const int64_t in_r_start,
                  const int64_t in_c_start, const T* input, T* tile_buffer) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t coord_stride = num_tiles * args.in_depth;

    // Calculate vectorized and scalar (residual) lengths for 'in_depth'.
    const int64_t input_vectorized_size =
        (args.in_depth / kPacketSize) * kPacketSize;
    const int64_t input_scalar_size = args.in_depth % kPacketSize;

    for (int64_t r = 0; r < tile_rows; ++r) {
      const int64_t in_r = in_r_start + r;
      if (in_r < 0 || in_r >= args.in_rows) continue;

      for (int64_t c = 0; c < tile_cols; ++c) {
        const int64_t in_c = in_c_start + c;
        if (in_c < 0 || in_c >= args.in_cols) continue;

        auto* in = input + (in_r * args.in_cols + in_c) * args.in_depth;
        auto* tile = tile_buffer + coord_stride * (r * tile_rows + c);
        // Copy vectorized portion of depth dimension.
        for (int64_t d = 0; d < input_vectorized_size; d += kPacketSize) {
          auto v = Eigen::internal::ploadu<Packet>(in + d);
          Eigen::internal::pstoreu<T>(tile, v);
          tile += kPacketSize;
        }
        // Copy scalar portion of inner dimension.
        for (int64_t d = 0; d < input_scalar_size; ++d) {
          tile[d] = in[input_vectorized_size + d];
        }
      }
    }
  }
};

// Transforms 'num_tiles' tiles from 'input' by 'transform_matrix', storing the
// final result in 'tile_transform'.
// Intermediate results are stored in 'tile_buffer'.
//
// input:
//   [in_rows, in_cols, in_depth]
// tile_buffer:
//   [tile_rows, tile_cols, num_tiles, in_depth]
// tile_transform_matrix:
//   [tile_spatial_size, tile_spatial_size]
// tile_transform:
//   [tile_rows, tile_cols, num_tiles, in_depth]

template <typename T>
struct TransformInputTiles {
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t num_tiles, const int64_t in_r_start,
                  const int64_t in_c_start, const T* input,
                  const T* transform_matrix, T* tile_buffer,
                  T* tile_transform) {
    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;
    const int64_t tile_stride_cols = transform->output_shape().cols;
    const int64_t coord_stride = num_tiles * args.in_depth;
    const int64_t num_tiles_stride = args.in_depth;

    memset(tile_buffer, 0, sizeof(T) * tile_spatial_size * coord_stride);
    const int64_t in_r = in_r_start;
    for (int64_t t = 0; t < num_tiles; ++t) {
      const int64_t num_tiles_base = t * num_tiles_stride;
      const int64_t in_c = in_c_start + t * tile_stride_cols;
      CopyInputTile<T>()(args, transform, num_tiles, in_r, in_c, input,
                         tile_buffer + num_tiles_base);
    }

    ConstMatrixMap A(transform_matrix, tile_spatial_size, tile_spatial_size);
    ConstMatrixMap B(tile_buffer, tile_spatial_size, coord_stride);
    MatrixMap C(tile_transform, tile_spatial_size, coord_stride);

    C.noalias() = A * B;
  }
};

// Transforms output tiles from buffer by 'out_transform_matrix', storing
// final result in 'output' (intermediate results stored in 'out_buffer').
//
// out_buffer:
//   [tile_rows, tile_cols, num_tiles, out_depth, shard_rows, shard_cols]
//
// output transform buffer:
//   [out_tile_rows, out_tile_cols, num_tiles, out_depth, shard_rows,
//    shard_cols]
//
// output:
//   [out_rows, out_cols, out_depth]
//

template <typename T>
struct TransformOutputTile {
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t num_tiles, const int64_t in_r,
                  const int64_t in_c, const int64_t filter_shards_row,
                  const int64_t filter_shards_col,
                  const T* out_transform_matrix, const T* out_buffer,
                  T* out_transform_buffer, T* output) {
    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    const int64_t out_buf_stride =
        num_tiles * args.out_depth * filter_shards_row * filter_shards_col;

    const int64_t out_tile_rows = transform->output_shape().rows;
    const int64_t out_tile_cols = transform->output_shape().cols;
    const int64_t out_tile_spatial_size = out_tile_rows * out_tile_cols;

    // Compute output transform.
    ConstMatrixMap A(out_transform_matrix, out_tile_spatial_size,
                     tile_spatial_size);
    ConstMatrixMap B(out_buffer, tile_spatial_size, out_buf_stride);
    MatrixMap C(out_transform_buffer, out_tile_spatial_size, out_buf_stride);

    C.noalias() = A * B;

    const int64_t tile_stride_rows = transform->output_shape().rows;
    const int64_t tile_stride_cols = transform->output_shape().cols;

    const int64_t out_depth_stride = filter_shards_row * filter_shards_col;
    const int64_t num_tiles_stride = args.out_depth * out_depth_stride;

    // Copy transformed output from 'out_transform_buffer' to proper index
    // in 'output'. Note that some outputs at boundaries can be discarded.
    for (int64_t t = 0; t < num_tiles; ++t) {
      const int64_t tile_base = t * num_tiles_stride;

      for (int64_t od = 0; od < args.out_depth; ++od) {
        const int64_t out_depth_base = od * out_depth_stride;

        // TODO(andydavis) Update filter sharding scheme in the next CL.
        for (int64_t sr = 0; sr < filter_shards_row; ++sr) {
          for (int64_t sc = 0; sc < filter_shards_col; ++sc) {
            const int64_t shard_base = sr * filter_shards_col + sc;
            const int64_t out_buf_base =
                tile_base + out_depth_base + shard_base;

            // Calculate output indices and outputs to drop (if needed).
            const int64_t out_r_start =
                in_r + args.pad_rows - sr * tile_stride_rows;
            // NOTE: The tile index 't' is used in the calculation of
            // 'out_c_start' because 'num_tiles' progresses along the column
            // dimension.
            const int64_t out_c_start = (in_c + t * tile_stride_cols) +
                                        args.pad_cols - sc * tile_stride_cols;

            if (out_r_start < 0 || out_r_start >= args.out_rows ||
                out_c_start < 0 || out_c_start >= args.out_cols) {
              continue;  // Skip un-needed outputs.
            }

            // Increment output if not first filter shard.
            const bool inc_output = !(sr == 0 && sc == 0);

            for (int64_t ot_row = 0; ot_row < out_tile_rows; ++ot_row) {
              const int64_t out_r = out_r_start + ot_row;
              if (out_r >= args.out_rows) continue;

              for (int64_t ot_col = 0; ot_col < out_tile_cols; ++ot_col) {
                const int64_t out_c = out_c_start + ot_col;
                if (out_c >= args.out_cols) continue;

                // Calculate out tile index.
                const int64_t out_buf_index = ot_row * out_tile_cols + ot_col;
                // Read output value from buffer.
                const T out_val =
                    out_transform_buffer[out_buf_base +
                                         out_buf_index * out_buf_stride];
                // Calculate output index.
                const int64_t output_index =
                    args.out_depth * (out_r * args.out_cols + out_c) + od;
                // Update output.
                if (inc_output) {
                  output[output_index] += out_val;
                } else {
                  output[output_index] = out_val;
                }
              }
            }
          }
        }
      }
    }
  }
};

template <typename T>
struct Conv2DState {
  Conv2DState(const int64_t tile_spatial_size, const int64_t filter_shards_row,
              const int64_t filter_shards_col, const T* input,
              const T* tile_transform_matrix, const T* output_transform_matrix,
              T* buffer1, T* buffer2, T* packed_tile_buffer,
              T* gemm_output_buffer)
      : tile_spatial_size(tile_spatial_size),
        filter_shards_row(filter_shards_row),
        filter_shards_col(filter_shards_col),
        input(input),
        tile_transform_matrix(tile_transform_matrix),
        output_transform_matrix(output_transform_matrix),
        buffer1(buffer1),
        buffer2(buffer2),
        packed_tile_buffer(packed_tile_buffer),
        gemm_output_buffer(gemm_output_buffer) {}

  const int64 tile_spatial_size;
  const int64 filter_shards_row;
  const int64 filter_shards_col;
  const T* input;
  const T* tile_transform_matrix;
  const T* output_transform_matrix;
  T* buffer1;
  T* buffer2;
  T* packed_tile_buffer;
  T* gemm_output_buffer;
};

// Computes Conv2D for 'num_tiles' input tiles from 'input' starting at
// (in_r, in_c), storing the results of the computation in 'output'.
// Details:
// *) Transforms 'num_tiles' input tiles into 'tile_transform_buffer'.
// *) Computes point-wise MatMuls of 'num_tiles' input tiles with all filters.
// *) Transforms output tiles, and stores result to 'output'.

// TODO(andydavis) Maybe pass Conv2DState into TransformInput/Output functions.
template <typename T>
struct ComputeConv2D {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const Conv2DState<T>& cs, const int64_t in_r,
                  const int64_t in_c, const int64_t num_tiles,
                  const std::vector<Tensor>& packed_filters, const T* input,
                  T* output) {
    // Transform input tiles.
    TransformInputTiles<T>()(args, transform, num_tiles, in_r, in_c, input,
                             cs.tile_transform_matrix, cs.buffer1, cs.buffer2);

    // Compute element-wise product (each a MatMul): input tiles X filters.
    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;
    const int64_t num_filters =
        cs.filter_shards_row * cs.filter_shards_col * out_depth;
    const int64_t tile_coord_stride = num_tiles * in_depth;
    const int64_t gemm_out_buf_size = num_tiles * num_filters;
    const int64_t gemm_out_buf_bytes = gemm_out_buf_size * sizeof(T);

    for (int64_t i = 0; i < cs.tile_spatial_size; ++i) {
      GemmState<T> gemm(num_filters, num_tiles, in_depth, gemm_out_buf_size,
                        packed_filters[i].template flat<T>().data(),
                        cs.buffer2 + i * tile_coord_stride,
                        cs.packed_tile_buffer, cs.gemm_output_buffer);
      // Pack tile buffer.
      gemm.PackRhs();
      // Compute product.
      gemm.Compute();
      // Copy to larger output buffer without alignment requirements.
      memcpy(cs.buffer1 + i * gemm_out_buf_size, cs.gemm_output_buffer,
             gemm_out_buf_bytes);
    }

    // Transform output.
    TransformOutputTile<T>()(args, transform, num_tiles, in_r, in_c,
                             cs.filter_shards_row, cs.filter_shards_col,
                             cs.output_transform_matrix, cs.buffer1, cs.buffer2,
                             output);
  }
};

namespace functor {

// Conv2D operation specialized for deep convolutions (i.e. large
// in_depth * out_depth).
// Details:
// *) Transforms and packs filters from 'filter' in parallel.
// *) Computes Conv2D parallelized across 'batch' dimension.
// *) Each thread loops over images in its batch shard, copying 'num_tiles'
//    input tiles into a local buffer, and computing the Conv2D output of
//    these tiles by all filters.

// TODO(andydavis) Improve the performance of boundary cases where the input
// tile extends past the input limit and wasted outputs are computed. This
// overhead is at most 2/n, where 'n' is max(out_rows, out_cols), so it is
// worse for smaller spatial sizes.
// TODO(andydavis) Improve the performance of sharded filters.
template <typename T>
struct DeepConv2D<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args, const T* input,
                  const T* filter, T* output) {
    // TODO(andydavis) Add function to select transform based on conv params.
    std::unique_ptr<DeepConv2DTransform<T>> transform(new WinogradTransform<T>);

    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;

    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    const int64_t out_tile_rows = transform->output_shape().rows;
    const int64_t out_tile_cols = transform->output_shape().cols;
    const int64_t out_tile_spatial_size = out_tile_rows * out_tile_cols;

    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;

    const int64_t filter_residual_row =
        std::max(int64{0}, args.filter_rows - base_filter_rows);
    const int64_t filter_shards_row = 1 + (filter_residual_row + 2 - 1) / 2;

    const int64_t filter_residual_col =
        std::max(int64{0}, args.filter_cols - base_filter_cols);
    const int64_t filter_shards_col = 1 + (filter_residual_col + 2 - 1) / 2;

    // Allocate buffer for transformed filters.
    Tensor filter_transform;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 TensorShape({tile_rows, tile_cols, out_depth,
                              filter_shards_row, filter_shards_col, in_depth}),
                 &filter_transform));
    T* filter_transform_data = filter_transform.template flat<T>().data();

    // Transform filters.
    TransformFilters<T>()(ctx, args, transform.get(), filter_shards_row,
                          filter_shards_col, filter, filter_transform_data);

    // Pack filters.
    std::vector<Tensor> packed_filters(tile_spatial_size);
    PackFilters<T>()(ctx, args, tile_spatial_size, filter_shards_row,
                     filter_shards_col, filter_transform_data, &packed_filters);

    // Allocate buffer for tile transform matrix.
    Tensor tile_transform_matrix_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::value,
                            TensorShape({tile_spatial_size, tile_spatial_size}),
                            &tile_transform_matrix_tensor));
    T* tile_transform_matrix =
        tile_transform_matrix_tensor.template flat<T>().data();
    transform->GetInputTransformMatrix(tile_spatial_size, tile_spatial_size,
                                       tile_transform_matrix);

    // Allocate buffer for output transform matrix.
    Tensor output_transform_matrix_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                           TensorShape({out_tile_spatial_size,
                                                        tile_spatial_size}),
                                           &output_transform_matrix_tensor));
    T* output_transform_matrix =
        output_transform_matrix_tensor.template flat<T>().data();
    transform->GetOutputTransformMatrix(
        out_tile_spatial_size, tile_spatial_size, output_transform_matrix);

    auto shard = [&ctx, &args, &transform, &packed_filters, &in_depth,
                  out_depth, out_tile_rows, out_tile_cols, filter_shards_row,
                  filter_shards_col, tile_spatial_size, &input,
                  &tile_transform_matrix, &output_transform_matrix,
                  &output](int64_t batch_start, int64_t batch_limit) {
      const int64_t row_tiles =
          (args.out_rows + out_tile_rows - 1) / out_tile_rows +
          filter_shards_row - 1;
      const int64_t col_tiles =
          (args.out_cols + out_tile_cols - 1) / out_tile_cols +
          filter_shards_col - 1;

      // Calculate number of tiles to process together.
      const int64_t filter_shard_size = filter_shards_row * filter_shards_col;
      const int64_t out_tile_spatial_size = out_tile_rows * out_tile_cols;

      // Cache budget (based on L2 cache size = 256KB).
      // TODO(andydavis) Read cache size from the system.
      const int64_t cache_size = (256LL << 10) / sizeof(T);

      // Fixed costs.
      const int64_t tile_transform_matrix_size =
          tile_spatial_size * tile_spatial_size;
      const int64_t output_transform_matrix_size =
          out_tile_spatial_size * tile_spatial_size;
      // Calculate cache reserve size.
      const int64_t filter_depth_size =
          in_depth * out_depth * filter_shard_size;
      const bool small_filter = ((filter_depth_size * 100) / cache_size) <= 25;
      const int64_t cache_reserve_size =
          small_filter ? filter_depth_size : 1024;
      // Calculate total fixed cost.
      const int64_t total_fixed_cost = tile_transform_matrix_size +
                                       output_transform_matrix_size +
                                       cache_reserve_size;

      // Per-tile costs.
      const int64_t buffer1_per_tile_size =
          tile_spatial_size * std::max(in_depth, out_depth * filter_shard_size);
      const int64_t buffer2_per_tile_size =
          std::max(tile_spatial_size * in_depth,
                   out_tile_spatial_size * out_depth * filter_shard_size);
      const int64_t packed_tile_per_tile_size = in_depth;
      const int64_t gemm_out_per_tile_size = out_depth * filter_shard_size;
      const int64_t total_per_tile_cost =
          buffer1_per_tile_size + buffer2_per_tile_size +
          packed_tile_per_tile_size + gemm_out_per_tile_size;

      const int64_t num_tiles_cache = std::max(
          int64{4}, (cache_size - total_fixed_cost) / total_per_tile_cost);
      const int64_t num_tiles = std::min(num_tiles_cache, col_tiles);
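
      // Illustrative sizing (assuming T = float, in_depth = out_depth = 64,
      // one filter shard, 4x4 input tiles, 2x2 output tiles): the per-tile
      // cost is 16*64 + 16*64 + 64 + 64 = 2176 elements and the fixed cost
      // (including the 'small_filter' reserve of 64*64 = 4096) is 4416, so
      // num_tiles_cache = (65536 - 4416) / 2176 = 28, capped by 'col_tiles'.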

      // Allocate temporary buffer 'buffer1', which is first used for copying
      // input tiles, then re-used to buffer gemm output. Calculate the
      // required buffer size for 'buffer1', based on max buffer size required
      // between copying input tiles and buffering gemm product output.
      //   buffer1: [max(buf1_tile_size, buf1_out_size)]
      const int64_t buffer1_tile_size =
          tile_spatial_size * num_tiles * in_depth;
      const int64_t buffer1_out_size =
          tile_spatial_size * num_tiles * out_depth * filter_shard_size;
      const int64_t buffer1_size =
          std::max(buffer1_tile_size, buffer1_out_size);
      Tensor buffer1_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({buffer1_size}),
                                             &buffer1_tensor));
      T* buffer1 = buffer1_tensor.template flat<T>().data();

      // Allocate temporary buffer 'buffer2', which is first used for
      // transformed input tiles, then re-used for transformed output tiles.
      // Calculate required buffer size for 'buffer2' as max required buffer
      // between input and output transform buffer sizes.
      const int64_t buffer2_tile_transform_size =
          tile_spatial_size * num_tiles * in_depth;
      const int64_t buffer2_out_transform_size =
          out_tile_spatial_size * num_tiles * out_depth * filter_shard_size;
      const int64_t buffer2_size =
          std::max(buffer2_tile_transform_size, buffer2_out_transform_size);
      Tensor buffer2_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({buffer2_size}),
                                             &buffer2_tensor));
      T* buffer2 = buffer2_tensor.template flat<T>().data();

      // Allocate temporary buffer to store packed tiles for one coordinate.
      //   packed tile buffer: [num_tiles, in_depth].
      Tensor packed_tile_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({num_tiles, in_depth}),
                                             &packed_tile_tensor));
      T* packed_tile_buffer = packed_tile_tensor.template flat<T>().data();

      // Allocate temporary buffer for gemm output.
      //   gemm output buffer: [num_tiles, out_depth, shard_rows, shard_cols].
      Tensor gemm_output_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({num_tiles, out_depth,
                                                          filter_shards_row,
                                                          filter_shards_col}),
                                             &gemm_output_tensor));
      T* gemm_output_buffer = gemm_output_tensor.template flat<T>().data();

      // Capture state needed for ComputeConv2D inner loop.
      Conv2DState<T> conv_state(tile_spatial_size, filter_shards_row,
                                filter_shards_col, input, tile_transform_matrix,
                                output_transform_matrix, buffer1, buffer2,
                                packed_tile_buffer, gemm_output_buffer);

      const int64_t row_pad = args.pad_rows;
      const int64_t col_pad = args.pad_cols;
      const int64_t unroll_col_limit = (col_tiles / num_tiles) * num_tiles;

      const int64_t input_image_size = args.in_rows * args.in_cols * in_depth;
      const int64_t output_image_size =
          args.out_rows * args.out_cols * out_depth;

      const int64_t tile_stride_rows = transform->output_shape().rows;
      const int64_t tile_stride_cols = transform->output_shape().cols;

      for (int64_t b = batch_start; b < batch_limit; ++b) {
        const int64_t in_base = b * input_image_size;
        const int64_t out_base = b * output_image_size;

        for (int64_t tile_r = 0; tile_r < row_tiles; ++tile_r) {
          const int64_t in_r = tile_r * tile_stride_rows - row_pad;

          // Process unrolled tiles.
          for (int64_t tile_c = 0; tile_c < unroll_col_limit;
               tile_c += num_tiles) {
            const int64_t in_c = tile_c * tile_stride_cols - col_pad;
            ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
                               num_tiles, packed_filters, input + in_base,
                               output + out_base);
          }
          // Process remaining tiles.
          if (unroll_col_limit < col_tiles) {
            const int64_t rem_tiles = col_tiles - unroll_col_limit;
            const int64_t in_c = unroll_col_limit * tile_stride_cols - col_pad;
            ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
                               rem_tiles, packed_filters, input + in_base,
                               output + out_base);
          }
        }
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int64_t shard_cost = args.out_rows * args.out_cols * args.out_depth *
                               tile_spatial_size * args.in_depth;
    Shard(worker_threads.num_threads, worker_threads.workers, args.batch,
          shard_cost, shard);
  }
};

}  // namespace functor

template struct functor::DeepConv2D<CPUDevice, float>;

}  // namespace tensorflow