/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/deep_conv2d.h"

#include <stdlib.h>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/winograd_transform.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

// DeepConv2D is a Conv2D implementation specialized for deep convolutions
// (i.e. a large 'in_depth' * 'out_depth' product; see the cost models below
// for details).
//
// DeepConv2D is implemented by computing the following equation:
//
//   y = C[Ad * Bg]
//
//   C: output transform matrix
//   A: input data transform matrix
//   B: filter transform matrix
//   d: vectorized data tile
//   g: vectorized filter tile
//   y: vectorized output tile
//
// The transform matrices and the input, filter, and output tile sizes are all
// specified by the DeepConv2DTransform implementation selected at the
// start of the DeepConv2D call, based on convolution parameters.

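// For example, with the F(2x2, 3x3) WinogradTransform used below, the input
// tile is 4x4, the base filter is 3x3, and the output tile is 2x2, so in
// vectorized form A is [16, 16], B is [16, 9], and C is [4, 16], and each
// 'Ad * Bg' is a 16-element element-wise product per (in_depth, out_depth)
// pair.
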
// Approximate cost models for direct and deep convolutions.
static int64 GetDeepConvCost(int input_tile_rows, int input_tile_cols,
                             int out_tile_rows, int out_tile_cols, int in_depth,
                             int out_depth, int out_rows, int out_cols) {
  // Input transform cost.
  const int64_t input_tile_spatial_size = input_tile_rows * input_tile_cols;
  const int64_t input_transform_cost =
      input_tile_spatial_size * input_tile_spatial_size * in_depth;

  // Element-wise products (each product is a MatMul across depth).
  const int64_t product_cost = input_tile_spatial_size * in_depth * out_depth;

  // Output transform cost.
  const int64_t output_tile_spatial_size = out_tile_rows * out_tile_cols;
  const int64_t output_transform_cost =
      output_tile_spatial_size * input_tile_spatial_size * out_depth;

  // Calculate number of input tiles to process.
  const int64_t row_tiles = (out_rows + out_tile_rows - 1) / out_tile_rows;
  const int64_t col_tiles = (out_cols + out_tile_cols - 1) / out_tile_cols;
  const int64_t num_tiles = row_tiles * col_tiles;

  // Return total cost.
  return num_tiles *
         (input_transform_cost + product_cost + output_transform_cost);
}

static int64 GetDirectConvCost(int filter_rows, int filter_cols, int in_depth,
                               int out_depth, int out_rows, int out_cols) {
  return filter_rows * filter_cols * in_depth * out_depth * out_rows * out_cols;
}
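
// As a worked example with the F(2x2, 3x3) transform (4x4 input tiles, 2x2
// output tiles), in_depth = out_depth = 64, and a 32x32 output:
//   deep:   256 tiles * (16*16*64 + 16*64*64 + 4*16*64) = 256 * 86016
//           = 22,020,096 flops
//   direct: 3*3*64*64*32*32 = 37,748,736 flops
// so the deep/direct ratio is ~0.58 and deep convolution wins.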

// Reads environment variable 'env_var_name'.
// Returns 'true' if environment variable is enabled, false otherwise.
static bool ReadBoolFromEnvVar(const char* env_var_name, bool default_val) {
  const char* tf_env_var_val = getenv(env_var_name);
  if (tf_env_var_val != nullptr) {
    StringPiece tf_env_var_val_str(tf_env_var_val);
    if (tf_env_var_val_str == "0") {
      return false;
    }
    return true;
  }
  return default_val;
}

// Returns true if convolution can be computed efficiently by DeepConv2D,
// returns false otherwise.
// TODO(andydavis) Add support for other filter sizes and strides.
// TODO(andydavis) Add support for autotuning.
bool CanUseDeepConv2D(int stride_rows, int stride_cols, int filter_rows,
                      int filter_cols, int in_depth, int out_depth,
                      int out_rows, int out_cols) {
  // Check if convolution parameters are supported.
  // TODO(andydavis) Add support for multiple filter sizes and strides.
  if (stride_rows > 1 || stride_cols > 1 || filter_rows != 3 ||
      filter_cols != 3) {
    return false;
  }

  // Check if deep convolution is enabled by environment variable.
  // NOTE: If this environment variable name changes, update conv_ops_test.py.
  if (!ReadBoolFromEnvVar("TF_USE_DEEP_CONV2D", false)) {
    return false;
  }

  // Check if flop cost of deep convolution is less than direct convolution.
  WinogradTransform<float> t;
  const int64_t deep_conv_cost = GetDeepConvCost(
      t.input_shape().rows, t.input_shape().cols, t.output_shape().rows,
      t.output_shape().cols, in_depth, out_depth, out_rows, out_cols);
  const int64_t direct_conv_cost = GetDirectConvCost(
      filter_rows, filter_cols, in_depth, out_depth, out_rows, out_cols);

  VLOG(2) << "CanUseDeepConv2D"
          << " deep_conv_cost: " << deep_conv_cost
          << " direct_conv_cost: " << direct_conv_cost << " deep_direct_ratio: "
          << (static_cast<float>(deep_conv_cost) /
              static_cast<float>(direct_conv_cost))
          << " use_deep_conv: " << (deep_conv_cost < direct_conv_cost);
  return deep_conv_cost < direct_conv_cost;
}
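
// Note that DeepConv2D is opt-in: it is only used when TF_USE_DEEP_CONV2D is
// set (e.g. 'export TF_USE_DEEP_CONV2D=1' in the shell), and even then only
// when the cost model above favors it.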

typedef Eigen::ThreadPoolDevice CPUDevice;

// Copies data from 'filter_in' to 'filter_buf' along 'in_depth' dimension.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_buf:
//   [base_filter_rows, base_filter_cols, in_depth]
//
template <typename T>
struct CopyFilterDepth {
  void operator()(const Conv2DArgs& args, const T* filter_in, T* filter_buf) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static constexpr int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64_t vectorized_size = args.in_depth / kPacketSize;
    const int64_t scalar_size = args.in_depth % kPacketSize;
    const int64_t input_stride = args.out_depth * kPacketSize;

    // Copy vectorized portion of depth dimension.
    for (int64_t d = 0; d < vectorized_size; ++d) {
      auto v = Eigen::internal::pgather<T, Packet>(filter_in + d * input_stride,
                                                   args.out_depth);
      Eigen::internal::pstoreu<T>(filter_buf + d * kPacketSize, v);
    }
    // Copy scalar portion of inner dimension.
    const int64_t in_scalar_base = vectorized_size * input_stride;
    const int64_t buf_scalar_base = vectorized_size * kPacketSize;
    for (int64_t d = 0; d < scalar_size; ++d) {
      filter_buf[buf_scalar_base + d] =
          filter_in[in_scalar_base + d * args.out_depth];
    }
  }
};
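
// For instance, with in_depth = 10 and a 4-wide packet type, the operator
// above issues two pgather/pstoreu pairs (each gathering four values strided
// by 'out_depth' from 'filter_in' and storing them contiguously in
// 'filter_buf'), then copies the remaining two values one at a time.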

// Computes transform of 'num_filters' from 'filter_in' starting at 'od_start'.
// Intermediate results (i.e. the output of MatMul('transform_matrix',
// 'filter_in')) are stored in 'out_buffer'. The final result is copied from
// 'out_buffer' to 'filter_out' at the coordinate stride required by the
// transformed filter data layout.
//
// filter_in:
//   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
//    in_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
// transform_matrix:
//   [tile_spatial_size, base_filter_spatial_size]
//
// out_buffer:
//   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]

template <typename T>
struct ComputeFilterRangeTransform {
  typedef typename Eigen::internal::packet_traits<T>::type Packet;
  static constexpr int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t od_start, const int64_t num_filters,
                  const int64_t shard_rows, const int64_t shard_cols,
                  const T* filter_in, const int64_t in_stride,
                  const int64_t out_stride, const T* transform_matrix,
                  T* out_buffer, T* filter_out) {
    namespace ei = Eigen::internal;

    const int64_t in_depth = args.in_depth;
    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;
    const int64_t base_filter_spatial_size =
        base_filter_rows * base_filter_cols;
    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    // Compute transform of 'num_filters' by 'transform_matrix'.
    ConstMatrixMap A(transform_matrix, tile_spatial_size,
                     base_filter_spatial_size);
    ConstMatrixMap B(filter_in, base_filter_spatial_size, in_stride);
    MatrixMap C(out_buffer, tile_spatial_size, in_stride);

    C.noalias() = A * B;

    // Copy 'out_buffer' to 'filter_out' at required filter output stride.
    const int64_t scalar_size = in_depth % kPacketSize;
    const int64_t vectorized_size = in_depth / kPacketSize;

    const int64_t shard_stride = args.in_depth;
    const int64_t out_depth_stride = shard_rows * shard_cols * shard_stride;

    for (int64_t od = 0; od < num_filters; ++od) {
      const int64_t out_depth_buf_base = od * out_depth_stride;
      const int64_t out_depth_base = (od_start + od) * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64_t s_r = 0; s_r < shard_rows; ++s_r) {
        for (int64_t s_c = 0; s_c < shard_cols; ++s_c) {
          const int64_t shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64_t i = 0; i < tile_spatial_size; ++i) {
            const int64_t in_base =
                i * in_stride + out_depth_buf_base + shard_base;
            const int64_t out_base =
                i * out_stride + out_depth_base + shard_base;
            // Copy vectorized portion of 'in_depth'.
            for (int64_t d = 0; d < vectorized_size; ++d) {
              auto v =
                  ei::ploadu<Packet>(out_buffer + in_base + d * kPacketSize);
              ei::pstoreu<T>(filter_out + out_base + d * kPacketSize, v);
            }
            // Copy scalar portion of 'in_depth'.
            const int64_t scalar_base = vectorized_size * kPacketSize;
            for (int64_t d = 0; d < scalar_size; ++d) {
              filter_out[out_base + scalar_base + d] =
                  out_buffer[in_base + scalar_base + d];
            }
          }
        }
      }
    }
  }
};

// Transforms 'num_filters' filters from 'filter_in', starting at 'od_start'.
// For each filter in 'num_filters', copies data for all filter shards from
// 'filter_in' into 'filter_buf', adding zero-padding as needed.
// Calls ComputeFilterRangeTransform to compute the filter transform of the
// data in 'filter_buf' by 'transform_matrix', storing the result in
// 'filter_out'.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
// filter_buffer:
//   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
//    in_depth]
//
// transform_matrix:
//   [tile_spatial_size, base_filter_spatial_size]
//
// out_buffer:
//   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]
//

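// For example, sharding a 5x5 filter over a 3x3 base filter gives
// residual_row = residual_col = 2, so shard_rows = shard_cols =
// 1 + (2 + 2 - 1) / 2 = 2, i.e. four overlapping 3x3 shards per filter.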
template <typename T>
struct TransformFilterRange {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t od_start, const int64_t od_limit,
                  const T* filter_in, const T* transform_matrix, T* out_buffer,
                  T* filter_buf, T* filter_out) {
    const int64_t num_filters = od_limit - od_start;
    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;
    const int64_t base_filter_spatial_size =
        base_filter_rows * base_filter_cols;

    // Compute number of filter shards.
    const int64_t residual_row =
        std::max(int64{0}, args.filter_rows - base_filter_rows);
    const int64_t shard_rows = 1 + (residual_row + 2 - 1) / 2;

    const int64_t residual_col =
        std::max(int64{0}, args.filter_cols - base_filter_cols);
    const int64_t shard_cols = 1 + (residual_col + 2 - 1) / 2;

    // Compute strides to be used for input and output IO.
    const int64_t shard_stride = args.in_depth;
    const int64_t out_depth_stride = shard_rows * shard_cols * shard_stride;
    const int64_t coord_stride = out_depth_stride * args.out_depth;
    const int64_t filter_buf_stride =
        num_filters * shard_rows * shard_cols * args.in_depth;
    const int64_t tile_stride_rows = transform->output_shape().rows;
    const int64_t tile_stride_cols = transform->output_shape().cols;

    const int64_t filter_buf_size = base_filter_spatial_size * num_filters *
                                    shard_rows * shard_cols * args.in_depth;
    memset(filter_buf, 0, sizeof(T) * filter_buf_size);

    // Copy filter range into 'filter_buf'.
    for (int64_t od = 0; od < num_filters; ++od) {
      const int64_t out_depth_base = od * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64_t s_r = 0; s_r < shard_rows; ++s_r) {
        const int64_t row_offset = s_r == 0 ? 0 : 1;

        for (int64_t s_c = 0; s_c < shard_cols; ++s_c) {
          const int64_t col_offset = s_c == 0 ? 0 : 1;
          const int64_t f_r_start = s_r * tile_stride_rows;
          const int64_t f_c_start = s_c * tile_stride_cols;

          const int64_t shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64_t b_r = row_offset; b_r < base_filter_rows; ++b_r) {
            const int64_t f_r = f_r_start + b_r;
            if (f_r >= args.filter_rows) continue;

            for (int64_t b_c = col_offset; b_c < base_filter_cols; ++b_c) {
              const int64_t f_c = f_c_start + b_c;
              if (f_c >= args.filter_cols) continue;

              const int64_t in_index =
                  args.out_depth *
                      (args.in_depth * (f_r * args.filter_cols + f_c)) +
                  (od_start + od);

              const int64_t buf_index =
                  filter_buf_stride * (b_r * base_filter_cols + b_c) +
                  out_depth_base + shard_base;

              CopyFilterDepth<T>()(args, filter_in + in_index,
                                   filter_buf + buf_index);
            }
          }
        }
      }
    }

    // Compute filter transform of data in 'filter_buf' by 'transform_matrix'.
    // Intermediate results are stored in 'out_buffer'.
    // Final results are stored in 'filter_out'.
    ComputeFilterRangeTransform<T>()(args, transform, od_start, num_filters,
                                     shard_rows, shard_cols, filter_buf,
                                     filter_buf_stride, coord_stride,
                                     transform_matrix, out_buffer, filter_out);
  }
};

// Transforms all filters from 'filter_in', storing the result in 'filter_out'.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
template <typename T>
struct TransformFilters {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t filter_shards_row,
                  const int64_t filter_shards_col, const T* filter_in,
                  T* filter_out) {
    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;

    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;
    const int64_t base_filter_spatial_size =
        base_filter_rows * base_filter_cols;

    const int64_t filter_shards_total = filter_shards_row * filter_shards_col;

    // Calculate filter transform batch based on cache/filter sizes.

    // Cache budget (based on L2 cache size = 256KB).
    // TODO(andydavis) Read cache size from system.
    const int64_t cache_size = (256LL << 10) / sizeof(T);

    // Fixed cost.
    const int64_t filter_transform_matrix_size =
        tile_spatial_size * base_filter_spatial_size;

    // Per-filter costs.
    const int64_t filter_total_size =
        base_filter_spatial_size * in_depth * filter_shards_total;

    const int64_t filter_transform_buffer_size =
        base_filter_spatial_size * filter_shards_total * in_depth;

    const int64_t filter_out_buf_size =
        tile_spatial_size * filter_shards_total * in_depth;

    // Total per-filter costs.
    const int64_t per_filter_cost =
        filter_total_size + filter_transform_buffer_size + filter_out_buf_size;

    // Remove fixed cost and divide by per-filter cost.
    const int64_t num_filters_cache =
        std::max(int64{1},
                 (cache_size - filter_transform_matrix_size) / per_filter_cost);
    const int64_t num_filters_transform =
        std::min(out_depth, num_filters_cache);
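
    // For example, for T = float (cache_size = 65536), a 16-element tile, a
    // 9-element base filter, in_depth = 64, and a single filter shard:
    //   per_filter_cost = 9*64 + 9*64 + 16*64 = 2176
    //   num_filters_cache = (65536 - 16*9) / 2176 = 30
    // so filters are transformed in batches of min(out_depth, 30).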

    // Allocate buffer for filter transform matrix:
    //   [tile_spatial_size, base_filter_spatial_size]
    Tensor filter_transform_matrix;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 TensorShape({tile_spatial_size, base_filter_spatial_size}),
                 &filter_transform_matrix));
    T* transform_matrix = filter_transform_matrix.template flat<T>().data();
    transform->GetFilterTransformMatrix(
        tile_spatial_size, base_filter_spatial_size, transform_matrix);

    auto shard = [&ctx, &args, &transform, &base_filter_rows, &base_filter_cols,
                  &num_filters_transform, &in_depth, &filter_shards_row,
                  &filter_shards_col, &tile_spatial_size, &filter_in,
                  &transform_matrix,
                  &filter_out](int64_t start, int64_t limit) {
      // Allocate buffer for pre-processed filter:
      //   [base_filter_rows, base_filter_cols, num_filters_transform,
      //    shard_rows, shard_cols, in_depth]
      //
      Tensor filter_transform_buffer;
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(
                         DataTypeToEnum<T>::value,
                         TensorShape({base_filter_rows, base_filter_cols,
                                      num_filters_transform, filter_shards_row,
                                      filter_shards_col, in_depth}),
                         &filter_transform_buffer));
      T* filter_buf = filter_transform_buffer.template flat<T>().data();

      // Allocate buffer for output filter transform matrix:
      //   [tile_spatial_size, num_filters_transform, shard_rows, shard_cols,
      //    in_depth]
      Tensor filter_output_buffer;
      OP_REQUIRES_OK(
          ctx,
          ctx->allocate_temp(
              DataTypeToEnum<T>::value,
              TensorShape({tile_spatial_size, num_filters_transform,
                           filter_shards_row, filter_shards_col, in_depth}),
              &filter_output_buffer));
      T* out_buffer = filter_output_buffer.template flat<T>().data();

      const int64_t num_filters = limit - start;
      const int64_t od_unroll = num_filters_transform;
      const int64_t od_unroll_limit = (num_filters / od_unroll) * od_unroll;

      for (int64_t od = start; od < od_unroll_limit; od += od_unroll) {
        TransformFilterRange<T>()(args, transform, od, od + od_unroll,
                                  filter_in, transform_matrix, out_buffer,
                                  filter_buf, filter_out);
      }

      if (od_unroll_limit < limit) {
        TransformFilterRange<T>()(args, transform, od_unroll_limit, limit,
                                  filter_in, transform_matrix, out_buffer,
                                  filter_buf, filter_out);
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    const int64_t shard_cost = args.filter_rows * args.filter_cols * in_depth *
                               filter_shards_total * tile_spatial_size;
    // TODO(andydavis) Resolve performance of multi-threaded filter transforms.
    Shard(1, worker_threads.workers, out_depth, shard_cost, shard);
  }
};

// Packs transformed filters stored in 'lhs_input' into 'lhs_block' in a
// gemm-kernel friendly data layout.
//
// Data layout for 'lhs_block':
//   [out_depth, shard_rows, shard_cols, in_depth]

template <typename T>
class GemmFilterPacker {
 public:
  typedef Eigen::internal::const_blas_data_mapper<T, int64, Eigen::RowMajor>
      LhsMapper;
  typedef Eigen::internal::gebp_traits<T, T> Traits;
  Eigen::internal::gemm_pack_lhs<
      T, int64, LhsMapper, Traits::mr, Traits::LhsProgress,
      typename Traits::LhsPacket4Packing, Eigen::RowMajor>
      pack_lhs;

  GemmFilterPacker(const int64_t rows, const int64_t depth, const T* lhs_input,
                   T* lhs_block)
      : rows_(rows),
        depth_(depth),
        lhs_block_(lhs_block),
        lhs_mapper_(lhs_input, depth_) {}

  void Run() { pack_lhs(lhs_block_, lhs_mapper_, depth_, rows_); }

 private:
  const int64 rows_;
  const int64 depth_;
  T* lhs_block_;
  LhsMapper lhs_mapper_;
};

// Packs transformed filters stored in 'filter_transform_data' into
// 'packed_filters' to be used by GemmState.
template <typename T>
struct PackFilters {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
                  const int64_t tile_spatial_size,
                  const int64_t filter_shards_row,
                  const int64_t filter_shards_col,
                  const T* filter_transform_data,
                  std::vector<Tensor>* packed_filters) {
    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;
    const int64_t num_filters =
        filter_shards_row * filter_shards_col * out_depth;

    auto shard = [&ctx, &packed_filters, &filter_transform_data, &in_depth,
                  &out_depth, &filter_shards_row, &filter_shards_col,
                  &num_filters](int64_t start, int64_t limit) {
      const int64_t filter_coord_stride = num_filters * in_depth;
      for (int64_t i = start; i < limit; ++i) {
        // Allocate filter buffer [out_depth, shard_rows, shard_cols, in_depth].
        OP_REQUIRES_OK(
            ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                    TensorShape({out_depth, filter_shards_row,
                                                 filter_shards_col, in_depth}),
                                    &(*packed_filters)[i]));
        T* packed_filter = (*packed_filters)[i].template flat<T>().data();
        // Pack filters.
        GemmFilterPacker<T> packer(
            num_filters, in_depth,
            filter_transform_data + i * filter_coord_stride, packed_filter);
        packer.Run();
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, tile_spatial_size,
          num_filters * in_depth, shard);
  }
};

// Computes the product of filters stored in 'lhs_block' and input tiles
// stored in 'rhs_block', storing output in 'out_buffer'.
//
// Data layout for 'lhs_block':
//   [out_depth, shard_rows, shard_cols, in_depth]
//
// Data layout for 'rhs_block':
//   [num_tiles, in_depth]
//
// Data layout for 'out_buffer':
//   [num_tiles, out_depth, shard_rows, shard_cols]

template <typename T>
class GemmState {
 public:
  typedef Eigen::internal::const_blas_data_mapper<T, int64, Eigen::ColMajor>
      RhsMapper;
  typedef Eigen::internal::blas_data_mapper<T, int64, Eigen::ColMajor>
      OutputMapper;
  typedef Eigen::internal::gebp_traits<T, T> Traits;

  Eigen::internal::gemm_pack_rhs<T, int64, RhsMapper, Traits::nr,
                                 Eigen::ColMajor>
      pack_rhs;
  Eigen::internal::gebp_kernel<T, T, int64, OutputMapper, Traits::mr,
                               Traits::nr, false, false>
      gebp;

  GemmState(const int64_t rows, const int64_t cols, const int64_t depth,
            const int64_t out_buffer_size, const T* lhs_block,
            const T* rhs_input, T* rhs_block, T* out_buffer)
      : rows_(rows),
        cols_(cols),
        depth_(depth),
        out_buffer_size_(out_buffer_size),
        lhs_block_(lhs_block),
        rhs_block_(rhs_block),
        out_buffer_(out_buffer),
        rhs_mapper_(rhs_input, depth_),
        out_mapper_(out_buffer, rows_) {}

  void PackRhs() { pack_rhs(rhs_block_, rhs_mapper_, depth_, cols_); }

  void Compute() {
    memset(out_buffer_, 0, sizeof(T) * out_buffer_size_);
    gebp(out_mapper_, lhs_block_, rhs_block_, rows_, depth_, cols_, 1.0);
  }

 private:
  const int64 rows_;
  const int64 cols_;
  const int64 depth_;
  const int64 out_buffer_size_;
  const T* lhs_block_;
  T* rhs_block_;
  T* out_buffer_;
  RhsMapper rhs_mapper_;
  OutputMapper out_mapper_;
};
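
// In the GEMM above, 'rows_' is the number of packed filters
// (out_depth * shard_rows * shard_cols), 'cols_' is 'num_tiles', and 'depth_'
// is 'in_depth': each Compute() call evaluates
//   out_buffer[rows_, cols_] = lhs_block[rows_, depth_] * rhs_block[depth_, cols_]
// for a single tile coordinate (see ComputeConv2D below).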

// Copies an input tile from 'input' into 'tile_buffer'.
//
// input:
//   [in_rows, in_cols, in_depth]
//
// tile_buffer:
//   [tile_rows, tile_cols, num_tiles, in_depth]

template <typename T>
struct CopyInputTile {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t num_tiles, const int64_t in_r_start,
                  const int64_t in_c_start, const T* input, T* tile_buffer) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t coord_stride = num_tiles * args.in_depth;

    // Calculate vectorized and scalar (residual) lengths for 'in_depth'.
    const int64_t input_vectorized_size =
        (args.in_depth / kPacketSize) * kPacketSize;
    const int64_t input_scalar_size = args.in_depth % kPacketSize;

    for (int64_t r = 0; r < tile_rows; ++r) {
      const int64_t in_r = in_r_start + r;
      if (in_r < 0 || in_r >= args.in_rows) continue;

      for (int64_t c = 0; c < tile_cols; ++c) {
        const int64_t in_c = in_c_start + c;
        if (in_c < 0 || in_c >= args.in_cols) continue;

        auto* in = input + (in_r * args.in_cols + in_c) * args.in_depth;
        auto* tile = tile_buffer + coord_stride * (r * tile_rows + c);
        // Copy vectorized portion of depth dimension.
        for (int64_t d = 0; d < input_vectorized_size; d += kPacketSize) {
          auto v = Eigen::internal::ploadu<Packet>(in + d);
          Eigen::internal::pstoreu<T>(tile, v);
          tile += kPacketSize;
        }
        // Copy scalar portion of inner dimension.
        for (int64_t d = 0; d < input_scalar_size; ++d) {
          tile[d] = in[input_vectorized_size + d];
        }
      }
    }
  }
};

// Transforms 'num_tiles' tiles from 'input' by 'transform_matrix', storing the
// final result in 'tile_transform'.
// Intermediate results are stored in 'tile_buffer'.
//
// input:
//   [in_rows, in_cols, in_depth]
// tile_buffer:
//   [tile_rows, tile_cols, num_tiles, in_depth]
// tile_transform_matrix:
//   [tile_spatial_size, tile_spatial_size]
// tile_transform:
//   [tile_rows, tile_cols, num_tiles, in_depth]

template <typename T>
struct TransformInputTiles {
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t num_tiles, const int64_t in_r_start,
                  const int64_t in_c_start, const T* input,
                  const T* transform_matrix, T* tile_buffer,
                  T* tile_transform) {
    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;
    const int64_t tile_stride_cols = transform->output_shape().cols;
    const int64_t coord_stride = num_tiles * args.in_depth;
    const int64_t num_tiles_stride = args.in_depth;

    memset(tile_buffer, 0, sizeof(T) * tile_spatial_size * coord_stride);
    const int64_t in_r = in_r_start;
    for (int64_t t = 0; t < num_tiles; ++t) {
      const int64_t num_tiles_base = t * num_tiles_stride;
      const int64_t in_c = in_c_start + t * tile_stride_cols;
      CopyInputTile<T>()(args, transform, num_tiles, in_r, in_c, input,
                         tile_buffer + num_tiles_base);
    }

    ConstMatrixMap A(transform_matrix, tile_spatial_size, tile_spatial_size);
    ConstMatrixMap B(tile_buffer, tile_spatial_size, coord_stride);
    MatrixMap C(tile_transform, tile_spatial_size, coord_stride);

    C.noalias() = A * B;
  }
};
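
// Note that the single matmul above transforms all tiles at once: A is
// [tile_spatial_size, tile_spatial_size] and each column of B holds one
// (tile, depth) coordinate, so C = A * B applies the input transform to
// num_tiles * in_depth columns in one shot.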

// Transforms output tiles from 'out_buffer' by 'out_transform_matrix', storing
// the final result in 'output' (intermediate results are stored in
// 'out_transform_buffer').
//
// out_buffer:
//   [tile_rows, tile_cols, num_tiles, out_depth, shard_rows, shard_cols]
//
// output transform buffer:
//  [out_tile_rows, out_tile_cols, num_tiles, out_depth, shard_rows, shard_cols]
//
// output:
//   [out_rows, out_cols, out_depth]
//

template <typename T>
struct TransformOutputTile {
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t num_tiles, const int64_t in_r,
                  const int64_t in_c, const int64_t filter_shards_row,
                  const int64_t filter_shards_col,
                  const T* out_transform_matrix, const T* out_buffer,
                  T* out_transform_buffer, T* output) {
    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    const int64_t out_buf_stride =
        num_tiles * args.out_depth * filter_shards_row * filter_shards_col;

    const int64_t out_tile_rows = transform->output_shape().rows;
    const int64_t out_tile_cols = transform->output_shape().cols;
    const int64_t out_tile_spatial_size = out_tile_rows * out_tile_cols;

    // Compute output transform.
    ConstMatrixMap A(out_transform_matrix, out_tile_spatial_size,
                     tile_spatial_size);
    ConstMatrixMap B(out_buffer, tile_spatial_size, out_buf_stride);
    MatrixMap C(out_transform_buffer, out_tile_spatial_size, out_buf_stride);

    C.noalias() = A * B;

    const int64_t tile_stride_rows = transform->output_shape().rows;
    const int64_t tile_stride_cols = transform->output_shape().cols;

    const int64_t out_depth_stride = filter_shards_row * filter_shards_col;
    const int64_t num_tiles_stride = args.out_depth * out_depth_stride;

    // Copy transformed output from 'out_transform_buffer' to the proper index
    // in 'output'. Note that some outputs at boundaries can be discarded.
    for (int64_t t = 0; t < num_tiles; ++t) {
      const int64_t tile_base = t * num_tiles_stride;

      for (int64_t od = 0; od < args.out_depth; ++od) {
        const int64_t out_depth_base = od * out_depth_stride;

        // TODO(andydavis) Update filter sharding scheme in the next CL.
        for (int64_t sr = 0; sr < filter_shards_row; ++sr) {
          for (int64_t sc = 0; sc < filter_shards_col; ++sc) {
            const int64_t shard_base = sr * filter_shards_col + sc;
            const int64_t out_buf_base =
                tile_base + out_depth_base + shard_base;

            // Calculate output indices and outputs to drop (if needed).
            const int64_t out_r_start =
                in_r + args.pad_rows - sr * tile_stride_rows;
            // NOTE: The tile index 't' is used in the index calculation for
            // 'out_c_start' because 'num_tiles' progresses along the column
            // dimension.
            const int64_t out_c_start = (in_c + t * tile_stride_cols) +
                                        args.pad_cols - sc * tile_stride_cols;

            if (out_r_start < 0 || out_r_start >= args.out_rows ||
                out_c_start < 0 || out_c_start >= args.out_cols) {
              continue;  // Skip un-needed outputs.
            }

            // Increment output if not first filter shard.
            const bool inc_output = (sr == 0 && sc == 0) ? false : true;

            for (int64_t ot_row = 0; ot_row < out_tile_rows; ++ot_row) {
              const int64_t out_r = out_r_start + ot_row;
              if (out_r >= args.out_rows) continue;

              for (int64_t ot_col = 0; ot_col < out_tile_cols; ++ot_col) {
                const int64_t out_c = out_c_start + ot_col;
                if (out_c >= args.out_cols) continue;

                // Calculate out tile index.
                const int64_t out_buf_index = ot_row * out_tile_cols + ot_col;
                // Read output value from buffer.
                const T out_val =
                    out_transform_buffer[out_buf_base +
                                         out_buf_index * out_buf_stride];
                // Calculate output index.
                const int64_t output_index =
                    args.out_depth * (out_r * args.out_cols + out_c) + od;
                // Update output.
                if (inc_output) {
                  output[output_index] += out_val;
                } else {
                  output[output_index] = out_val;
                }
              }
            }
          }
        }
      }
    }
  }
};
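
// With a single filter shard (sr == sc == 0), the index math above reduces to
// out_r_start = in_r + pad_rows and out_c_start = in_c + t * tile_stride_cols
// + pad_cols; since callers pass in_r = tile_r * tile_stride_rows - pad_rows,
// each output tile lands at a multiple of the output tile stride.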

template <typename T>
struct Conv2DState {
  Conv2DState(const int64_t tile_spatial_size, const int64_t filter_shards_row,
              const int64_t filter_shards_col, const T* input,
              const T* tile_transform_matrix, const T* output_transform_matrix,
              T* buffer1, T* buffer2, T* packed_tile_buffer,
              T* gemm_output_buffer)
      : tile_spatial_size(tile_spatial_size),
        filter_shards_row(filter_shards_row),
        filter_shards_col(filter_shards_col),
        input(input),
        tile_transform_matrix(tile_transform_matrix),
        output_transform_matrix(output_transform_matrix),
        buffer1(buffer1),
        buffer2(buffer2),
        packed_tile_buffer(packed_tile_buffer),
        gemm_output_buffer(gemm_output_buffer) {}

  const int64 tile_spatial_size;
  const int64 filter_shards_row;
  const int64 filter_shards_col;
  const T* input;
  const T* tile_transform_matrix;
  const T* output_transform_matrix;
  T* buffer1;
  T* buffer2;
  T* packed_tile_buffer;
  T* gemm_output_buffer;
};

// Computes Conv2D for 'num_tiles' input tiles from 'input' starting at
// (in_r, in_c), storing the results of the computation in 'output'.
// Details:
// *) Transforms 'num_tiles' input tiles into 'tile_transform_buffer'.
// *) Computes point-wise MatMuls of 'num_tiles' input tiles with all filters.
// *) Transforms output tiles, and stores result to 'output'.

// TODO(andydavis) Maybe pass Conv2DState into TransformInput/Output functions.
template <typename T>
struct ComputeConv2D {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const Conv2DState<T>& cs, const int64_t in_r,
                  const int64_t in_c, const int64_t num_tiles,
                  const std::vector<Tensor>& packed_filters, const T* input,
                  T* output) {
    // Transform input tiles.
    TransformInputTiles<T>()(args, transform, num_tiles, in_r, in_c, input,
                             cs.tile_transform_matrix, cs.buffer1, cs.buffer2);

    // Compute element-wise product (each a MatMul): input tiles X filters.
    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;
    const int64_t num_filters =
        cs.filter_shards_row * cs.filter_shards_col * out_depth;
    const int64_t tile_coord_stride = num_tiles * in_depth;
    const int64_t gemm_out_buf_size = num_tiles * num_filters;
    const int64_t gemm_out_buf_bytes = gemm_out_buf_size * sizeof(T);

    for (int64_t i = 0; i < cs.tile_spatial_size; ++i) {
      GemmState<T> gemm(num_filters, num_tiles, in_depth, gemm_out_buf_size,
                        packed_filters[i].template flat<T>().data(),
                        cs.buffer2 + i * tile_coord_stride,
                        cs.packed_tile_buffer, cs.gemm_output_buffer);
      // Pack tile buffer.
      gemm.PackRhs();
      // Compute product.
      gemm.Compute();
      // Copy to larger output buffer without alignment requirements.
      memcpy(cs.buffer1 + i * gemm_out_buf_size, cs.gemm_output_buffer,
             gemm_out_buf_bytes);
    }

    // Transform output.
    TransformOutputTile<T>()(args, transform, num_tiles, in_r, in_c,
                             cs.filter_shards_row, cs.filter_shards_col,
                             cs.output_transform_matrix, cs.buffer1, cs.buffer2,
                             output);
  }
};
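
// Buffer flow for a single ComputeConv2D call:
//   buffer1 (raw input tiles) -> buffer2 (transformed tiles)
//   -> per-coordinate GEMMs into gemm_output_buffer -> buffer1 (gemm output)
//   -> buffer2 (transformed output tiles) -> output.
// Both scratch buffers are sized below for the larger of their two roles.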

namespace functor {

// Conv2D operation specialized for deep convolutions (i.e. large
// in_depth * out_depth).
// Details:
// *) Transforms and packs filters from 'filter' in parallel.
// *) Computes Conv2D parallelized across 'batch' dimension.
//   *) Each thread loops over images in its batch shard, copying 'num_tiles'
//      input tiles into a local buffer, and computing the Conv2D output of
//      these tiles by all filters.

// TODO(andydavis) Improve the performance of boundary cases where the input
// tile extends past the limit, and wasted outputs are computed. This overhead
// is at most 2/n, where 'n' is max(out_rows, out_cols), and so is worse
// for smaller spatial sizes.
// TODO(andydavis) Improve the performance of sharded filters.
template <typename T>
struct DeepConv2D<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args, const T* input,
                  const T* filter, T* output) {
    // TODO(andydavis) Add function to select transform based on conv params.
    std::unique_ptr<DeepConv2DTransform<T>> transform(new WinogradTransform<T>);

    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;

    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    const int64_t out_tile_rows = transform->output_shape().rows;
    const int64_t out_tile_cols = transform->output_shape().cols;
    const int64_t out_tile_spatial_size = out_tile_rows * out_tile_cols;

    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;

    const int64_t filter_residual_row =
        std::max(int64{0}, args.filter_rows - base_filter_rows);
    const int64_t filter_shards_row = 1 + (filter_residual_row + 2 - 1) / 2;

    const int64_t filter_residual_col =
        std::max(int64{0}, args.filter_cols - base_filter_cols);
    const int64_t filter_shards_col = 1 + (filter_residual_col + 2 - 1) / 2;

    // Allocate buffer for transformed filters.
    Tensor filter_transform;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 TensorShape({tile_rows, tile_cols, out_depth,
                              filter_shards_row, filter_shards_col, in_depth}),
                 &filter_transform));
    T* filter_transform_data = filter_transform.template flat<T>().data();

    // Transform filters.
    TransformFilters<T>()(ctx, args, transform.get(), filter_shards_row,
                          filter_shards_col, filter, filter_transform_data);

    // Pack filters.
    std::vector<Tensor> packed_filters(tile_spatial_size);
    PackFilters<T>()(ctx, args, tile_spatial_size, filter_shards_row,
                     filter_shards_col, filter_transform_data, &packed_filters);

    // Allocate buffer for tile transform matrix.
    Tensor tile_transform_matrix_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::value,
                            TensorShape({tile_spatial_size, tile_spatial_size}),
                            &tile_transform_matrix_tensor));
    T* tile_transform_matrix =
        tile_transform_matrix_tensor.template flat<T>().data();
    transform->GetInputTransformMatrix(tile_spatial_size, tile_spatial_size,
                                       tile_transform_matrix);

    // Allocate buffer for output transform matrix.
    Tensor output_transform_matrix_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                           TensorShape({out_tile_spatial_size,
                                                        tile_spatial_size}),
                                           &output_transform_matrix_tensor));
    T* output_transform_matrix =
        output_transform_matrix_tensor.template flat<T>().data();
    transform->GetOutputTransformMatrix(
        out_tile_spatial_size, tile_spatial_size, output_transform_matrix);

    auto shard = [&ctx, &args, &transform, &packed_filters, &in_depth,
                  out_depth, out_tile_rows, out_tile_cols, filter_shards_row,
                  filter_shards_col, tile_spatial_size, &input,
                  &tile_transform_matrix, &output_transform_matrix,
                  &output](int64_t batch_start, int64_t batch_limit) {
      const int64_t row_tiles =
          (args.out_rows + out_tile_rows - 1) / out_tile_rows +
          filter_shards_row - 1;
      const int64_t col_tiles =
          (args.out_cols + out_tile_cols - 1) / out_tile_cols +
          filter_shards_col - 1;

      // Calculate number of tiles to process together.
      const int64_t filter_shard_size = filter_shards_row * filter_shards_col;
      const int64_t out_tile_spatial_size = out_tile_rows * out_tile_cols;

      // Cache budget (based on L2 cache size = 256KB).
      // TODO(andydavis) Read cache size from the system.
      const int64_t cache_size = (256LL << 10) / sizeof(T);

      // Fixed costs.
      const int64_t tile_transform_matrix_size =
          tile_spatial_size * tile_spatial_size;
      const int64_t output_transform_matrix_size =
          out_tile_spatial_size * tile_spatial_size;
      // Calculate cache reserve size.
      const int64_t filter_depth_size =
          in_depth * out_depth * filter_shard_size;
      const bool small_filter = ((filter_depth_size * 100) / cache_size) <= 25;
      const int64_t cache_reserve_size =
          small_filter ? filter_depth_size : 1024;
      // Calculate total fixed cost.
      const int64_t total_fixed_cost = tile_transform_matrix_size +
                                       output_transform_matrix_size +
                                       cache_reserve_size;

      // Per-tile costs.
      const int64_t buffer1_per_tile_size =
          tile_spatial_size * std::max(in_depth, out_depth * filter_shard_size);
      const int64_t buffer2_per_tile_size =
          std::max(tile_spatial_size * in_depth,
                   out_tile_spatial_size * out_depth * filter_shard_size);
      const int64_t packed_tile_per_tile_size = in_depth;
      const int64_t gemm_out_per_tile_size = out_depth * filter_shard_size;
      const int64_t total_per_tile_cost =
          buffer1_per_tile_size + buffer2_per_tile_size +
          packed_tile_per_tile_size + gemm_out_per_tile_size;

      const int64_t num_tiles_cache = std::max(
          int64{4}, (cache_size - total_fixed_cost) / total_per_tile_cost);
      const int64_t num_tiles = std::min(num_tiles_cache, col_tiles);

      // Allocate temporary buffer 'buffer1', which is first used for copying
      // input tiles, then re-used to buffer gemm output. Calculate the
      // required size for 'buffer1' as the max of the buffer sizes needed for
      // copying input tiles and for buffering gemm product output.
      //   buffer1: [max(buf1_tile_size, buf1_out_size)]
      const int64_t buffer1_tile_size =
          tile_spatial_size * num_tiles * in_depth;
      const int64_t buffer1_out_size =
          tile_spatial_size * num_tiles * out_depth * filter_shard_size;
      const int64_t buffer1_size =
          std::max(buffer1_tile_size, buffer1_out_size);
      Tensor buffer1_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({buffer1_size}),
                                             &buffer1_tensor));
      T* buffer1 = buffer1_tensor.template flat<T>().data();

      // Allocate temporary buffer 'buffer2', which is first used for
      // transformed input tiles, then re-used for transformed output tiles.
      // Calculate the required size for 'buffer2' as the max of the input and
      // output transform buffer sizes.
      const int64_t buffer2_tile_transform_size =
          tile_spatial_size * num_tiles * in_depth;
      const int64_t buffer2_out_transform_size =
          out_tile_spatial_size * num_tiles * out_depth * filter_shard_size;
      const int64_t buffer2_size =
          std::max(buffer2_tile_transform_size, buffer2_out_transform_size);
      Tensor buffer2_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({buffer2_size}),
                                             &buffer2_tensor));
      T* buffer2 = buffer2_tensor.template flat<T>().data();

      // Allocate temporary buffer to store packed tiles for one coordinate.
      // packed tile buffer: [num_tiles, in_depth].
      Tensor packed_tile_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({num_tiles, in_depth}),
                                             &packed_tile_tensor));
      T* packed_tile_buffer = packed_tile_tensor.template flat<T>().data();

      // Allocate temporary buffer for gemm output.
      // gemm output buffer: [num_tiles, out_depth, shard_rows, shard_cols].
      Tensor gemm_output_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({num_tiles, out_depth,
                                                          filter_shards_row,
                                                          filter_shards_col}),
                                             &gemm_output_tensor));
      T* gemm_output_buffer = gemm_output_tensor.template flat<T>().data();

      // Capture state needed for the ComputeConv2D inner loop.
      Conv2DState<T> conv_state(tile_spatial_size, filter_shards_row,
                                filter_shards_col, input, tile_transform_matrix,
                                output_transform_matrix, buffer1, buffer2,
                                packed_tile_buffer, gemm_output_buffer);

      const int64_t row_pad = args.pad_rows;
      const int64_t col_pad = args.pad_cols;
      const int64_t unroll_col_limit = (col_tiles / num_tiles) * num_tiles;

      const int64_t input_image_size = args.in_rows * args.in_cols * in_depth;
      const int64_t output_image_size =
          args.out_rows * args.out_cols * out_depth;

      const int64_t tile_stride_rows = transform->output_shape().rows;
      const int64_t tile_stride_cols = transform->output_shape().cols;

      for (int64_t b = batch_start; b < batch_limit; ++b) {
        const int64_t in_base = b * input_image_size;
        const int64_t out_base = b * output_image_size;

        for (int64_t tile_r = 0; tile_r < row_tiles; ++tile_r) {
          const int64_t in_r = tile_r * tile_stride_rows - row_pad;

          // Process unrolled tiles.
          for (int64_t tile_c = 0; tile_c < unroll_col_limit;
               tile_c += num_tiles) {
            const int64_t in_c = tile_c * tile_stride_cols - col_pad;
            ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
                               num_tiles, packed_filters, input + in_base,
                               output + out_base);
          }
          // Process remaining tiles.
          if (unroll_col_limit < col_tiles) {
            const int64_t rem_tiles = col_tiles - unroll_col_limit;
            const int64_t in_c = unroll_col_limit * tile_stride_cols - col_pad;
            ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
                               rem_tiles, packed_filters, input + in_base,
                               output + out_base);
          }
        }
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int64_t shard_cost = args.out_rows * args.out_cols * args.out_depth *
                               tile_spatial_size * args.in_depth;
    Shard(worker_threads.num_threads, worker_threads.workers, args.batch,
          shard_cost, shard);
  }
};

}  // namespace functor

template struct functor::DeepConv2D<CPUDevice, float>;

}  // namespace tensorflow