/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/deep_conv2d.h"

#include <stdlib.h>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/winograd_transform.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

// DeepConv2D is a Conv2D implementation specialized for deep convolutions
// (i.e. a large 'in_depth' * 'out_depth' product; see the cost models below
// for details).
//
// DeepConv2D is implemented by computing the following equation:
//
//   y = C[Ad * Bg]
//
//   C: output transform matrix
//   A: input data transform matrix
//   B: filter transform matrix
//   d: vectorized data tile
//   g: vectorized filter tile
//   y: vectorized output tile
//
// The transform matrices and input, filter and output tile sizes are all
// specified by the DeepConv2DTransform implementation selected at the
// start of the DeepConv2D call, based on convolution parameters.
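//
// For reference (assuming the F(2x2, 3x3) Winograd transform selected in
// DeepConv2D below): the base filter is 3x3, the input tile is 4x4 and the
// output tile is 2x2, so each 2x2 block of outputs is computed from a 4x4
// input patch and a transformed 3x3 filter.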

// Approximate cost models for direct and deep convolutions.
static int64 GetDeepConvCost(int input_tile_rows, int input_tile_cols,
                             int out_tile_rows, int out_tile_cols, int in_depth,
                             int out_depth, int out_rows, int out_cols) {
  // Input transform cost.
  const int64 input_tile_spatial_size = input_tile_rows * input_tile_cols;
  const int64 input_transform_cost =
      input_tile_spatial_size * input_tile_spatial_size * in_depth;

  // Element-wise products (each product is a MatMul across depth).
  const int64 product_cost = input_tile_spatial_size * in_depth * out_depth;

  // Output transform cost.
  const int64 output_tile_spatial_size = out_tile_rows * out_tile_cols;
  const int64 output_transform_cost =
      output_tile_spatial_size * input_tile_spatial_size * out_depth;

  // Calculate number of input tiles to process.
  const int64 row_tiles = (out_rows + out_tile_rows - 1) / out_tile_rows;
  const int64 col_tiles = (out_cols + out_tile_cols - 1) / out_tile_cols;
  const int64 num_tiles = row_tiles * col_tiles;

  // Return total cost.
  return num_tiles *
         (input_transform_cost + product_cost + output_transform_cost);
}

static int64 GetDirectConvCost(int filter_rows, int filter_cols, int in_depth,
                               int out_depth, int out_rows, int out_cols) {
  return filter_rows * filter_cols * in_depth * out_depth * out_rows * out_cols;
}
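
// As a rough worked example of the two cost models above (assuming the 4x4
// input tile and 2x2 output tile of the Winograd transform used below): for a
// 3x3 convolution with in_depth = out_depth = 64 and a 32x32 output,
//   deep:   256 tiles * (16*16*64 + 16*64*64 + 4*16*64)  ~= 22.0M flops
//   direct: 3*3*64*64*32*32                              ~= 37.7M flops
// so the deep convolution path is cheaper under this model.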

// Reads environment variable 'env_var_name'.
// Returns 'true' if environment variable is enabled, false otherwise.
static bool ReadBoolFromEnvVar(const char* env_var_name, bool default_val) {
  const char* tf_env_var_val = getenv(env_var_name);
  if (tf_env_var_val != nullptr) {
    StringPiece tf_env_var_val_str(tf_env_var_val);
    if (tf_env_var_val_str == "0") {
      return false;
    }
    return true;
  }
  return default_val;
}
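
// For example, running with TF_USE_DEEP_CONV2D=1 in the environment enables
// the deep convolution path checked in CanUseDeepConv2D below.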

// Returns true if convolution can be computed efficiently by DeepConv2D,
// returns false otherwise.
// TODO(andydavis) Add support for other filter sizes and strides.
// TODO(andydavis) Add support for autotuning.
bool CanUseDeepConv2D(int stride_rows, int stride_cols, int filter_rows,
                      int filter_cols, int in_depth, int out_depth,
                      int out_rows, int out_cols) {
  // Check if convolution parameters are supported.
  // TODO(andydavis) Add support for multiple filter sizes and strides.
  if (stride_rows > 1 || stride_cols > 1 || filter_rows != 3 ||
      filter_cols != 3) {
    return false;
  }

  // Check if deep convolution is enabled by environment variable.
  // NOTE: If this environment variable name changes, update conv_ops_test.py.
  if (!ReadBoolFromEnvVar("TF_USE_DEEP_CONV2D", false)) {
    return false;
  }

  // Check if flop cost of deep convolution is less than direct convolution.
  WinogradTransform<float> t;
  const int64 deep_conv_cost = GetDeepConvCost(
      t.input_shape().rows, t.input_shape().cols, t.output_shape().rows,
      t.output_shape().cols, in_depth, out_depth, out_rows, out_cols);
  const int64 direct_conv_cost = GetDirectConvCost(
      filter_rows, filter_cols, in_depth, out_depth, out_rows, out_cols);

  VLOG(2) << "CanUseDeepConv2D"
          << " deep_conv_cost: " << deep_conv_cost
          << " direct_conv_cost: " << direct_conv_cost << " deep_direct_ratio: "
          << (static_cast<float>(deep_conv_cost) /
              static_cast<float>(direct_conv_cost))
          << " use_deep_conv: " << (deep_conv_cost < direct_conv_cost);
  return deep_conv_cost < direct_conv_cost;
}

typedef Eigen::ThreadPoolDevice CPUDevice;

// Copies data from 'filter_in' to 'filter_buf' along 'in_depth' dimension.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_buf:
//   [base_filter_rows, base_filter_cols, in_depth]
//
template <typename T>
struct CopyFilterDepth {
  void operator()(const Conv2DArgs& args, const T* filter_in, T* filter_buf) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static constexpr int64 kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64 vectorized_size = args.in_depth / kPacketSize;
    const int64 scalar_size = args.in_depth % kPacketSize;
    const int64 input_stride = args.out_depth * kPacketSize;

    // Copy vectorized portion of depth dimension.
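    // (Each pgather below reads 'kPacketSize' values of the 'in_depth'
    // dimension that are 'out_depth'-strided in 'filter_in' and stores them
    // contiguously in 'filter_buf'.)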
    for (int64 d = 0; d < vectorized_size; ++d) {
      auto v = Eigen::internal::pgather<T, Packet>(filter_in + d * input_stride,
                                                   args.out_depth);
      Eigen::internal::pstoreu<T>(filter_buf + d * kPacketSize, v);
    }
    // Copy scalar portion of inner dimension.
    const int64 in_scalar_base = vectorized_size * input_stride;
    const int64 buf_scalar_base = vectorized_size * kPacketSize;
    for (int64 d = 0; d < scalar_size; ++d) {
      filter_buf[buf_scalar_base + d] =
          filter_in[in_scalar_base + d * args.out_depth];
    }
  }
};

// Computes transform of 'num_filters' from 'filter_in' starting at 'od_start'.
// Intermediate results (i.e. output of MatMul('transform_matrix', 'filter_in'))
// are stored in 'out_buffer'. The final result is copied from 'out_buffer' to
// 'filter_out' at the coordinate stride required by the transformed filter
// data layout.
//
// filter_in:
//   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
//    in_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
// transform_matrix:
//   [tile_spatial_size, base_filter_spatial_size]
//
// out_buffer:
//   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]

template <typename T>
struct ComputeFilterRangeTransform {
  typedef typename Eigen::internal::packet_traits<T>::type Packet;
  static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));

  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform, const int64 od_start,
                  const int64 num_filters, const int64 shard_rows,
                  const int64 shard_cols, const T* filter_in,
                  const int64 in_stride, const int64 out_stride,
                  const T* transform_matrix, T* out_buffer, T* filter_out) {
    namespace ei = Eigen::internal;

    const int64 in_depth = args.in_depth;
    const int64 base_filter_rows = transform->filter_shape().rows;
    const int64 base_filter_cols = transform->filter_shape().cols;
    const int64 base_filter_spatial_size = base_filter_rows * base_filter_cols;
    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 tile_spatial_size = tile_rows * tile_cols;

    // Compute transform of 'num_filters' by 'transform_matrix'.
    ConstMatrixMap A(transform_matrix, tile_spatial_size,
                     base_filter_spatial_size);
    ConstMatrixMap B(filter_in, base_filter_spatial_size, in_stride);
    MatrixMap C(out_buffer, tile_spatial_size, in_stride);

    C.noalias() = A * B;

    // Copy 'out_buffer' to 'filter_out' at required filter output stride.
    const int64 scalar_size = in_depth % kPacketSize;
    const int64 vectorized_size = in_depth / kPacketSize;

    const int64 shard_stride = args.in_depth;
    const int64 out_depth_stride = shard_rows * shard_cols * shard_stride;

    for (int64 od = 0; od < num_filters; ++od) {
      const int64 out_depth_buf_base = od * out_depth_stride;
      const int64 out_depth_base = (od_start + od) * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64 s_r = 0; s_r < shard_rows; ++s_r) {
        for (int64 s_c = 0; s_c < shard_cols; ++s_c) {
          const int64 shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64 i = 0; i < tile_spatial_size; ++i) {
            const int64 in_base =
                i * in_stride + out_depth_buf_base + shard_base;
            const int64 out_base = i * out_stride + out_depth_base + shard_base;
            // Copy vectorized portion of 'in_depth'.
            for (int64 d = 0; d < vectorized_size; ++d) {
              auto v =
                  ei::ploadu<Packet>(out_buffer + in_base + d * kPacketSize);
              ei::pstoreu<T>(filter_out + out_base + d * kPacketSize, v);
            }
            // Transform scalar portion of 'in_depth'.
            const int64 scalar_base = vectorized_size * kPacketSize;
            for (int64 d = 0; d < scalar_size; ++d) {
              filter_out[out_base + scalar_base + d] =
                  out_buffer[in_base + scalar_base + d];
            }
          }
        }
      }
    }
  }
};

// Transforms 'num_filters' from 'filter_in', starting at 'od_start'.
// For each filter in 'num_filters', copies data for all filter shards from
// 'filter_in' into 'filter_buf', adding zero-padding as needed.
// Calls ComputeFilterRangeTransform to compute filter transform of data
// in 'filter_buf' by 'transform_matrix', storing the result in 'filter_out'.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
// filter_buffer:
//   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
//    in_depth]
//
// transform_matrix:
//   [tile_spatial_size, base_filter_spatial_size]
//
// out_buffer:
//   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]
//

template <typename T>
struct TransformFilterRange {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform, const int64 od_start,
                  const int64 od_limit, const T* filter_in,
                  const T* transform_matrix, T* out_buffer, T* filter_buf,
                  T* filter_out) {
    const int64 num_filters = od_limit - od_start;
    const int64 base_filter_rows = transform->filter_shape().rows;
    const int64 base_filter_cols = transform->filter_shape().cols;
    const int64 base_filter_spatial_size = base_filter_rows * base_filter_cols;

    // Compute number of filter shards.
    const int64 residual_row =
        std::max(int64{0}, args.filter_rows - base_filter_rows);
    const int64 shard_rows = 1 + (residual_row + 2 - 1) / 2;

    const int64 residual_col =
        std::max(int64{0}, args.filter_cols - base_filter_cols);
    const int64 shard_cols = 1 + (residual_col + 2 - 1) / 2;
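    // For example, with the Winograd transform used in DeepConv2D below
    // (3x3 base filter, stride-2 shards), a 5x5 filter yields
    // shard_rows == shard_cols == 2 (four base-sized shards). Shards after the
    // first skip their first row/column (see 'row_offset'/'col_offset' below),
    // so each filter tap is copied into 'filter_buf' exactly once.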

    // Compute strides to be used for input and output IO.
    const int64 shard_stride = args.in_depth;
    const int64 out_depth_stride = shard_rows * shard_cols * shard_stride;
    const int64 coord_stride = out_depth_stride * args.out_depth;
    const int64 filter_buf_stride =
        num_filters * shard_rows * shard_cols * args.in_depth;
    const int64 tile_stride_rows = transform->output_shape().rows;
    const int64 tile_stride_cols = transform->output_shape().cols;

    const int64 filter_buf_size = base_filter_spatial_size * num_filters *
                                  shard_rows * shard_cols * args.in_depth;
    memset(filter_buf, 0, sizeof(T) * filter_buf_size);

    // Copy filter range into 'filter_buf'.
    for (int64 od = 0; od < num_filters; ++od) {
      const int64 out_depth_base = od * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64 s_r = 0; s_r < shard_rows; ++s_r) {
        const int64 row_offset = s_r == 0 ? 0 : 1;

        for (int64 s_c = 0; s_c < shard_cols; ++s_c) {
          const int64 col_offset = s_c == 0 ? 0 : 1;
          const int64 f_r_start = s_r * tile_stride_rows;
          const int64 f_c_start = s_c * tile_stride_cols;

          const int64 shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64 b_r = row_offset; b_r < base_filter_rows; ++b_r) {
            const int64 f_r = f_r_start + b_r;
            if (f_r >= args.filter_rows) continue;

            for (int64 b_c = col_offset; b_c < base_filter_cols; ++b_c) {
              const int64 f_c = f_c_start + b_c;
              if (f_c >= args.filter_cols) continue;

              const int64 in_index =
                  args.out_depth *
                      (args.in_depth * (f_r * args.filter_cols + f_c)) +
                  (od_start + od);

              const int64 buf_index =
                  filter_buf_stride * (b_r * base_filter_cols + b_c) +
                  out_depth_base + shard_base;

              CopyFilterDepth<T>()(args, filter_in + in_index,
                                   filter_buf + buf_index);
            }
          }
        }
      }
    }

    // Compute filter transform of data in 'filter_buf' by 'transform_matrix'.
    // Intermediate results are stored in 'out_buffer'.
    // Final results are stored in 'filter_out'.
    ComputeFilterRangeTransform<T>()(args, transform, od_start, num_filters,
                                     shard_rows, shard_cols, filter_buf,
                                     filter_buf_stride, coord_stride,
                                     transform_matrix, out_buffer, filter_out);
  }
};

// Transforms all filters from 'filter_in', storing result in 'filter_out'.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
template <typename T>
struct TransformFilters {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64 filter_shards_row, const int64 filter_shards_col,
                  const T* filter_in, T* filter_out) {
    const int64 in_depth = args.in_depth;
    const int64 out_depth = args.out_depth;

    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 tile_spatial_size = tile_rows * tile_cols;

    const int64 base_filter_rows = transform->filter_shape().rows;
    const int64 base_filter_cols = transform->filter_shape().cols;
    const int64 base_filter_spatial_size = base_filter_rows * base_filter_cols;

    const int64 filter_shards_total = filter_shards_row * filter_shards_col;

    // Calculate filter transform batch based on cache/filter sizes.

    // Cache budget (based on L2 cache size = 256KB).
    // TODO(andydavis) Read cache size from system.
    const int64 cache_size = (256LL << 10) / sizeof(T);

    // Fixed cost.
    const int64 filter_transform_matrix_size =
        tile_spatial_size * base_filter_spatial_size;

    // Per-filter costs.
    const int64 filter_total_size =
        base_filter_spatial_size * in_depth * filter_shards_total;

    const int64 filter_transform_buffer_size =
        base_filter_spatial_size * filter_shards_total * in_depth;

    const int64 filter_out_buf_size =
        tile_spatial_size * filter_shards_total * in_depth;

    // Total per-filter costs.
    const int64 per_filter_cost =
        filter_total_size + filter_transform_buffer_size + filter_out_buf_size;

    // Remove fixed cost and divide by per-filter cost.
    const int64 num_filters_cache =
        std::max(int64{1},
                 (cache_size - filter_transform_matrix_size) / per_filter_cost);
    const int64 num_filters_transform = std::min(out_depth, num_filters_cache);

    // Allocate buffer for filter transform matrix:
    //   [tile_spatial_size, base_filter_spatial_size]
    Tensor filter_transform_matrix;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 TensorShape({tile_spatial_size, base_filter_spatial_size}),
                 &filter_transform_matrix));
    T* transform_matrix = filter_transform_matrix.template flat<T>().data();
    transform->GetFilterTransformMatrix(
        tile_spatial_size, base_filter_spatial_size, transform_matrix);

    auto shard = [&ctx, &args, &transform, &base_filter_rows, &base_filter_cols,
                  &num_filters_transform, &in_depth, &filter_shards_row,
                  &filter_shards_col, &tile_spatial_size, &filter_in,
                  &transform_matrix, &filter_out](int64 start, int64 limit) {
      // Allocate buffer for pre-processed filter:
      //   [base_filter_rows, base_filter_cols, num_filters_transform, in_depth]
      //
      Tensor filter_transform_buffer;
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(
                         DataTypeToEnum<T>::value,
                         TensorShape({base_filter_rows, base_filter_cols,
                                      num_filters_transform, filter_shards_row,
                                      filter_shards_col, in_depth}),
                         &filter_transform_buffer));
      T* filter_buf = filter_transform_buffer.template flat<T>().data();

      // Allocate buffer for output filter transform matrix:
      //   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
      Tensor filter_output_buffer;
      OP_REQUIRES_OK(
          ctx,
          ctx->allocate_temp(
              DataTypeToEnum<T>::value,
              TensorShape({tile_spatial_size, num_filters_transform,
                           filter_shards_row, filter_shards_col, in_depth}),
              &filter_output_buffer));
      T* out_buffer = filter_output_buffer.template flat<T>().data();

      const int64 num_filters = limit - start;
      const int64 od_unroll = num_filters_transform;
      const int64 od_unroll_limit = (num_filters / od_unroll) * od_unroll;

      for (int64 od = start; od < od_unroll_limit; od += od_unroll) {
        TransformFilterRange<T>()(args, transform, od, od + od_unroll,
                                  filter_in, transform_matrix, out_buffer,
                                  filter_buf, filter_out);
      }

      if (od_unroll_limit < limit) {
        TransformFilterRange<T>()(args, transform, od_unroll_limit, limit,
                                  filter_in, transform_matrix, out_buffer,
                                  filter_buf, filter_out);
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    const int64 shard_cost = args.filter_rows * args.filter_cols * in_depth *
                             filter_shards_total * tile_spatial_size;
    // TODO(andydavis) Resolve performance of multi-threaded filter transforms.
    Shard(1, worker_threads.workers, out_depth, shard_cost, shard);
  }
};

// Packs transformed filters stored in 'lhs_input' into 'lhs_block' in a
// gemm-kernel friendly data layout.
//
// Data layout for 'lhs_block':
//   [out_depth, shard_rows, shard_cols, in_depth].

template <typename T>
class GemmFilterPacker {
 public:
  typedef Eigen::internal::const_blas_data_mapper<T, int64, Eigen::RowMajor>
      LhsMapper;
  typedef Eigen::internal::gebp_traits<T, T> Traits;
  Eigen::internal::gemm_pack_lhs<
      T, int64, LhsMapper, Traits::mr, Traits::LhsProgress,
      typename Traits::LhsPacket4Packing, Eigen::RowMajor>
      pack_lhs;

  GemmFilterPacker(const int64 rows, const int64 depth, const T* lhs_input,
                   T* lhs_block)
      : rows_(rows),
        depth_(depth),
        lhs_block_(lhs_block),
        lhs_mapper_(lhs_input, depth_) {}

  void Run() { pack_lhs(lhs_block_, lhs_mapper_, depth_, rows_); }

 private:
  const int64 rows_;
  const int64 depth_;
  T* lhs_block_;
  LhsMapper lhs_mapper_;
};

// Packs transformed filter stored in 'filter_transform_data' into
// 'packed_filters' to be used by GemmState.
template <typename T>
struct PackFilters {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
                  const int64 tile_spatial_size, const int64 filter_shards_row,
                  const int64 filter_shards_col, const T* filter_transform_data,
                  std::vector<Tensor>* packed_filters) {
    const int64 in_depth = args.in_depth;
    const int64 out_depth = args.out_depth;
    const int64 num_filters = filter_shards_row * filter_shards_col * out_depth;

    auto shard = [&ctx, &packed_filters, &filter_transform_data, &in_depth,
                  &out_depth, &filter_shards_row, &filter_shards_col,
                  &num_filters](int64 start, int64 limit) {
      const int64 filter_coord_stride = num_filters * in_depth;
      for (int64 i = start; i < limit; ++i) {
        // Allocate filter buffer [out_depth, shard_rows, shard_cols, in_depth].
        OP_REQUIRES_OK(
            ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                    TensorShape({out_depth, filter_shards_row,
                                                 filter_shards_col, in_depth}),
                                    &(*packed_filters)[i]));
        T* packed_filter = (*packed_filters)[i].template flat<T>().data();
        // Pack filters.
        GemmFilterPacker<T> packer(
            num_filters, in_depth,
            filter_transform_data + i * filter_coord_stride, packed_filter);
        packer.Run();
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, tile_spatial_size,
          num_filters * in_depth, shard);
  }
};

// Computes the product of filters stored in 'lhs_block' and input tiles
// stored in 'rhs_block', storing output in 'out_buffer'.
//
// Data layout for 'lhs_block':
//   [out_depth, shard_rows, shard_cols, in_depth].
//
// Data layout for 'rhs_block':
//   [num_tiles, in_depth]
//
// Data layout for 'out_buffer':
//   [num_tiles, out_depth, shard_rows, shard_cols]

template <typename T>
class GemmState {
 public:
  typedef Eigen::internal::const_blas_data_mapper<T, int64, Eigen::ColMajor>
      RhsMapper;
  typedef Eigen::internal::blas_data_mapper<T, int64, Eigen::ColMajor>
      OutputMapper;
  typedef Eigen::internal::gebp_traits<T, T> Traits;

  Eigen::internal::gemm_pack_rhs<T, int64, RhsMapper, Traits::nr,
                                 Eigen::ColMajor>
      pack_rhs;
  Eigen::internal::gebp_kernel<T, T, int64, OutputMapper, Traits::mr,
                               Traits::nr, false, false>
      gebp;

  GemmState(const int64 rows, const int64 cols, const int64 depth,
            const int64 out_buffer_size, const T* lhs_block, const T* rhs_input,
            T* rhs_block, T* out_buffer)
      : rows_(rows),
        cols_(cols),
        depth_(depth),
        out_buffer_size_(out_buffer_size),
        lhs_block_(lhs_block),
        rhs_block_(rhs_block),
        out_buffer_(out_buffer),
        rhs_mapper_(rhs_input, depth_),
        out_mapper_(out_buffer, rows_) {}

  void PackRhs() { pack_rhs(rhs_block_, rhs_mapper_, depth_, cols_); }

  void Compute() {
    memset(out_buffer_, 0, sizeof(T) * out_buffer_size_);
    gebp(out_mapper_, lhs_block_, rhs_block_, rows_, depth_, cols_, 1.0);
  }

 private:
  const int64 rows_;
  const int64 cols_;
  const int64 depth_;
  const int64 out_buffer_size_;
  const T* lhs_block_;
  T* rhs_block_;
  T* out_buffer_;
  RhsMapper rhs_mapper_;
  OutputMapper out_mapper_;
};

// Copies an input tile from 'input' into 'tile_buffer'.
//
// input:
//   [in_rows, in_cols, in_depth]
//
// tile_buffer:
//   [tile_rows, tile_cols, num_tiles, in_depth]

template <typename T>
struct CopyInputTile {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64 num_tiles, const int64 in_r_start,
                  const int64 in_c_start, const T* input, T* tile_buffer) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 coord_stride = num_tiles * args.in_depth;

    // Calculate vectorized and scalar (residual) lengths for 'in_depth'.
    const int64 input_vectorized_size =
        (args.in_depth / kPacketSize) * kPacketSize;
    const int64 input_scalar_size = args.in_depth % kPacketSize;

    for (int64 r = 0; r < tile_rows; ++r) {
      const int64 in_r = in_r_start + r;
      if (in_r < 0 || in_r >= args.in_rows) continue;

      for (int64 c = 0; c < tile_cols; ++c) {
        const int64 in_c = in_c_start + c;
        if (in_c < 0 || in_c >= args.in_cols) continue;

        auto* in = input + (in_r * args.in_cols + in_c) * args.in_depth;
        auto* tile = tile_buffer + coord_stride * (r * tile_rows + c);
        // Copy vectorized portion of depth dimension.
        for (int64 d = 0; d < input_vectorized_size; d += kPacketSize) {
          auto v = Eigen::internal::ploadu<Packet>(in + d);
          Eigen::internal::pstoreu<T>(tile, v);
          tile += kPacketSize;
        }
        // Copy scalar portion of inner dimension.
        for (int64 d = 0; d < input_scalar_size; ++d) {
          tile[d] = in[input_vectorized_size + d];
        }
      }
    }
  }
};

// Transforms 'num_tiles' tiles from 'input' by 'transform_matrix', storing the
// final result in 'tile_transform'.
// Intermediate results are stored in 'tile_buffer'.
//
// input:
//   [in_rows, in_cols, in_depth]
// tile_buffer:
//   [tile_rows, tile_cols, num_tiles, in_depth]
// tile_transform_matrix:
//   [tile_spatial_size, tile_spatial_size]
// tile_transform:
//   [tile_rows, tile_cols, num_tiles, in_depth]

template <typename T>
struct TransformInputTiles {
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64 num_tiles, const int64 in_r_start,
                  const int64 in_c_start, const T* input,
                  const T* transform_matrix, T* tile_buffer,
                  T* tile_transform) {
    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 tile_spatial_size = tile_rows * tile_cols;
    const int64 tile_stride_cols = transform->output_shape().cols;
    const int64 coord_stride = num_tiles * args.in_depth;
    const int64 num_tiles_stride = args.in_depth;

    memset(tile_buffer, 0, sizeof(T) * tile_spatial_size * coord_stride);
    const int64 in_r = in_r_start;
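    // The 'num_tiles' tiles copied here are adjacent along the column
    // dimension: 'in_r' is fixed while 'in_c' advances by 'tile_stride_cols'
    // per tile.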
    for (int64 t = 0; t < num_tiles; ++t) {
      const int64 num_tiles_base = t * num_tiles_stride;
      const int64 in_c = in_c_start + t * tile_stride_cols;
      CopyInputTile<T>()(args, transform, num_tiles, in_r, in_c, input,
                         tile_buffer + num_tiles_base);
    }

    ConstMatrixMap A(transform_matrix, tile_spatial_size, tile_spatial_size);
    ConstMatrixMap B(tile_buffer, tile_spatial_size, coord_stride);
    MatrixMap C(tile_transform, tile_spatial_size, coord_stride);

    C.noalias() = A * B;
  }
};

// Transforms output tiles from buffer by 'out_transform_matrix', storing
// final result in 'output' (intermediate results stored in 'out_buffer').
//
// out_buffer:
//   [tile_rows, tile_cols, num_tiles, out_depth, shard_rows, shard_cols]
//
// output transform buffer:
//   [out_tile_rows, out_tile_cols, num_tiles, out_depth, shard_rows, shard_cols]
//
// output:
//   [out_rows, out_cols, out_depth]
//

template <typename T>
struct TransformOutputTile {
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64 num_tiles, const int64 in_r, const int64 in_c,
                  const int64 filter_shards_row, const int64 filter_shards_col,
                  const T* out_transform_matrix, const T* out_buffer,
                  T* out_transform_buffer, T* output) {
    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 tile_spatial_size = tile_rows * tile_cols;

    const int64 out_buf_stride =
        num_tiles * args.out_depth * filter_shards_row * filter_shards_col;

    const int64 out_tile_rows = transform->output_shape().rows;
    const int64 out_tile_cols = transform->output_shape().cols;
    const int64 out_tile_spatial_size = out_tile_rows * out_tile_cols;

    // Compute output transform.
    ConstMatrixMap A(out_transform_matrix, out_tile_spatial_size,
                     tile_spatial_size);
    ConstMatrixMap B(out_buffer, tile_spatial_size, out_buf_stride);
    MatrixMap C(out_transform_buffer, out_tile_spatial_size, out_buf_stride);

    C.noalias() = A * B;

    const int64 tile_stride_rows = transform->output_shape().rows;
    const int64 tile_stride_cols = transform->output_shape().cols;

    const int64 out_depth_stride = filter_shards_row * filter_shards_col;
    const int64 num_tiles_stride = args.out_depth * out_depth_stride;

    // Copy transformed output from 'out_transform_buffer' to proper index
    // in 'output'. Note that some outputs at boundaries can be discarded.
    for (int64 t = 0; t < num_tiles; ++t) {
      const int64 tile_base = t * num_tiles_stride;

      for (int64 od = 0; od < args.out_depth; ++od) {
        const int64 out_depth_base = od * out_depth_stride;

        // TODO(andydavis) Update filter sharding scheme in the next CL.
        for (int64 sr = 0; sr < filter_shards_row; ++sr) {
          for (int64 sc = 0; sc < filter_shards_col; ++sc) {
            const int64 shard_base = sr * filter_shards_col + sc;
            const int64 out_buf_base = tile_base + out_depth_base + shard_base;

            // Calculate output indices and outputs to drop (if needed).
            const int64 out_r_start =
                in_r + args.pad_rows - sr * tile_stride_rows;
            // NOTE: The index 't' is used in the calculation of 'out_c_start'
            // below because 'num_tiles' progresses along the column dimension.
            const int64 out_c_start = (in_c + t * tile_stride_cols) +
                                      args.pad_cols - sc * tile_stride_cols;

            if (out_r_start < 0 || out_r_start >= args.out_rows ||
                out_c_start < 0 || out_c_start >= args.out_cols) {
              continue;  // Skip un-needed outputs.
            }

            // Increment output if not first filter shard.
            const bool inc_output = (sr == 0 && sc == 0) ? false : true;

            for (int64 ot_row = 0; ot_row < out_tile_rows; ++ot_row) {
              const int64 out_r = out_r_start + ot_row;
              if (out_r >= args.out_rows) continue;

              for (int64 ot_col = 0; ot_col < out_tile_cols; ++ot_col) {
                const int64 out_c = out_c_start + ot_col;
                if (out_c >= args.out_cols) continue;

                // Calculate output tile index.
                const int64 out_buf_index = ot_row * out_tile_cols + ot_col;
                // Read output value from buffer.
                const T out_val =
                    out_transform_buffer[out_buf_base +
                                         out_buf_index * out_buf_stride];
                // Calculate output index.
                const int64 output_index =
                    args.out_depth * (out_r * args.out_cols + out_c) + od;
                // Update output.
                if (inc_output) {
                  output[output_index] += out_val;
                } else {
                  output[output_index] = out_val;
                }
              }
            }
          }
        }
      }
    }
  }
};

template <typename T>
struct Conv2DState {
  Conv2DState(const int64 tile_spatial_size, const int64 filter_shards_row,
              const int64 filter_shards_col, const T* input,
              const T* tile_transform_matrix, const T* output_transform_matrix,
              T* buffer1, T* buffer2, T* packed_tile_buffer,
              T* gemm_output_buffer)
      : tile_spatial_size(tile_spatial_size),
        filter_shards_row(filter_shards_row),
        filter_shards_col(filter_shards_col),
        input(input),
        tile_transform_matrix(tile_transform_matrix),
        output_transform_matrix(output_transform_matrix),
        buffer1(buffer1),
        buffer2(buffer2),
        packed_tile_buffer(packed_tile_buffer),
        gemm_output_buffer(gemm_output_buffer) {}

  const int64 tile_spatial_size;
  const int64 filter_shards_row;
  const int64 filter_shards_col;
  const T* input;
  const T* tile_transform_matrix;
  const T* output_transform_matrix;
  T* buffer1;
  T* buffer2;
  T* packed_tile_buffer;
  T* gemm_output_buffer;
};

// Computes Conv2D for 'num_tiles' input tiles from 'input' starting at
// (in_r, in_c), storing the results of the computation in 'output'.
// Details:
//  *) Transforms 'num_tiles' input tiles into 'tile_transform_buffer'.
//  *) Computes point-wise MatMuls of 'num_tiles' input tiles with all filters.
//  *) Transforms output tiles, and stores result to 'output'.

// TODO(andydavis) Maybe pass Conv2DState into TransformInput/Output functions.
template <typename T>
struct ComputeConv2D {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const Conv2DState<T>& cs, const int64 in_r, const int64 in_c,
                  const int64 num_tiles,
                  const std::vector<Tensor>& packed_filters, const T* input,
                  T* output) {
    // Transform input tiles.
    TransformInputTiles<T>()(args, transform, num_tiles, in_r, in_c, input,
                             cs.tile_transform_matrix, cs.buffer1, cs.buffer2);

    // Compute element-wise product (each a MatMul): input tiles X filters.
    const int64 in_depth = args.in_depth;
    const int64 out_depth = args.out_depth;
    const int64 num_filters =
        cs.filter_shards_row * cs.filter_shards_col * out_depth;
    const int64 tile_coord_stride = num_tiles * in_depth;
    const int64 gemm_out_buf_size = num_tiles * num_filters;
    const int64 gemm_out_buf_bytes = gemm_out_buf_size * sizeof(T);

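    // Each tile coordinate 'i' below computes one GEMM:
    //   [num_filters, in_depth] x [in_depth, num_tiles] -> [num_filters, num_tiles].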
    for (int64 i = 0; i < cs.tile_spatial_size; ++i) {
      GemmState<T> gemm(num_filters, num_tiles, in_depth, gemm_out_buf_size,
                        packed_filters[i].template flat<T>().data(),
                        cs.buffer2 + i * tile_coord_stride,
                        cs.packed_tile_buffer, cs.gemm_output_buffer);
      // Pack tile buffer.
      gemm.PackRhs();
      // Compute product.
      gemm.Compute();
      // Copy to larger output buffer without alignment requirements.
      memcpy(cs.buffer1 + i * gemm_out_buf_size, cs.gemm_output_buffer,
             gemm_out_buf_bytes);
    }

    // Transform output.
    TransformOutputTile<T>()(args, transform, num_tiles, in_r, in_c,
                             cs.filter_shards_row, cs.filter_shards_col,
                             cs.output_transform_matrix, cs.buffer1, cs.buffer2,
                             output);
  }
};

namespace functor {

// Conv2D operation specialized for deep convolutions (i.e. large
// in_depth * out_depth).
// Details:
//  *) Transforms and packs filters from 'filter' in parallel.
//  *) Computes Conv2D parallelized across 'batch' dimension.
//  *) Each thread loops over images in its batch shard, copying 'num_tiles'
//     input tiles into a local buffer, and computing the Conv2D output of
//     these tiles by all filters.

// TODO(andydavis) Improve the performance of boundary cases where the input
// tile extends past the limit, and wasted outputs are computed. This overhead
// is at most 2/n, where 'n' is the max(out_rows, out_cols), and so is worse
// for smaller spatial sizes.
// TODO(andydavis) Improve the performance of sharded filters.
template <typename T>
struct DeepConv2D<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args, const T* input,
                  const T* filter, T* output) {
    // TODO(andydavis) Add function to select transform based on conv params.
    std::unique_ptr<DeepConv2DTransform<T>> transform(new WinogradTransform<T>);

    const int64 in_depth = args.in_depth;
    const int64 out_depth = args.out_depth;

    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 tile_spatial_size = tile_rows * tile_cols;

    const int64 out_tile_rows = transform->output_shape().rows;
    const int64 out_tile_cols = transform->output_shape().cols;
    const int64 out_tile_spatial_size = out_tile_rows * out_tile_cols;

    const int64 base_filter_rows = transform->filter_shape().rows;

    const int64 filter_residual_row =
        std::max(int64{0}, args.filter_rows - base_filter_rows);
    const int64 filter_shards_row = 1 + (filter_residual_row + 2 - 1) / 2;

    const int64 filter_residual_col =
        std::max(int64{0}, args.filter_cols - base_filter_rows);
    const int64 filter_shards_col = 1 + (filter_residual_col + 2 - 1) / 2;

    // Allocate buffer for transformed filters.
    Tensor filter_transform;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 TensorShape({tile_rows, tile_cols, out_depth,
                              filter_shards_row, filter_shards_col, in_depth}),
                 &filter_transform));
    T* filter_transform_data = filter_transform.template flat<T>().data();

    // Transform filters.
    TransformFilters<T>()(ctx, args, transform.get(), filter_shards_row,
                          filter_shards_col, filter, filter_transform_data);

    // Pack filters.
    std::vector<Tensor> packed_filters(tile_spatial_size);
    PackFilters<T>()(ctx, args, tile_spatial_size, filter_shards_row,
                     filter_shards_col, filter_transform_data, &packed_filters);

    // Allocate buffer for tile transform matrix.
    Tensor tile_transform_matrix_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::value,
                            TensorShape({tile_spatial_size, tile_spatial_size}),
                            &tile_transform_matrix_tensor));
    T* tile_transform_matrix =
        tile_transform_matrix_tensor.template flat<T>().data();
    transform->GetInputTransformMatrix(tile_spatial_size, tile_spatial_size,
                                       tile_transform_matrix);

    // Allocate buffer for output transform matrix.
    Tensor output_transform_matrix_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                           TensorShape({out_tile_spatial_size,
                                                        tile_spatial_size}),
                                           &output_transform_matrix_tensor));
    T* output_transform_matrix =
        output_transform_matrix_tensor.template flat<T>().data();
    transform->GetOutputTransformMatrix(
        out_tile_spatial_size, tile_spatial_size, output_transform_matrix);

    auto shard = [&ctx, &args, &transform, &packed_filters, &in_depth,
                  out_depth, out_tile_rows, out_tile_cols, filter_shards_row,
                  filter_shards_col, tile_spatial_size, &input,
                  &tile_transform_matrix, &output_transform_matrix,
                  &output](int64 batch_start, int64 batch_limit) {
      const int64 row_tiles =
          (args.out_rows + out_tile_rows - 1) / out_tile_rows +
          filter_shards_row - 1;
      const int64 col_tiles =
          (args.out_cols + out_tile_cols - 1) / out_tile_cols +
          filter_shards_col - 1;

      // Calculate number of tiles to process together.
      const int64 filter_shard_size = filter_shards_row * filter_shards_col;
      const int64 out_tile_spatial_size = out_tile_rows * out_tile_cols;

      // Cache budget (based on L2 cache size = 256KB).
      // TODO(andydavis) Read cache size from the system.
      const int64 cache_size = (256LL << 10) / sizeof(T);

      // Fixed costs.
      const int64 tile_transform_matrix_size =
          tile_spatial_size * tile_spatial_size;
      const int64 output_transform_matrix_size =
          out_tile_spatial_size * tile_spatial_size;
      // Calculate cache reserve size.
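      // (The filter data is treated as 'small' if it occupies at most ~25% of
      // the cache budget; in that case the whole filter is reserved in cache,
      // otherwise only a small fixed amount is.)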
      const int64 filter_depth_size = in_depth * out_depth * filter_shard_size;
      const bool small_filter = ((filter_depth_size * 100) / cache_size) <= 25;
      const int64 cache_reserve_size = small_filter ? filter_depth_size : 1024;
      // Calculate total fixed cost.
      const int64 total_fixed_cost = tile_transform_matrix_size +
                                     output_transform_matrix_size +
                                     cache_reserve_size;

      // Per-tile costs.
      const int64 buffer1_per_tile_size =
          tile_spatial_size * std::max(in_depth, out_depth * filter_shard_size);
      const int64 buffer2_per_tile_size =
          std::max(tile_spatial_size * in_depth,
                   out_tile_spatial_size * out_depth * filter_shard_size);
      const int64 packed_tile_per_tile_size = in_depth;
      const int64 gemm_out_per_tile_size = out_depth * filter_shard_size;
      const int64 total_per_tile_cost =
          buffer1_per_tile_size + buffer2_per_tile_size +
          packed_tile_per_tile_size + gemm_out_per_tile_size;

      const int64 num_tiles_cache = std::max(
          int64{4}, (cache_size - total_fixed_cost) / total_per_tile_cost);
      const int64 num_tiles = std::min(num_tiles_cache, col_tiles);

      // Allocate temporary buffer 'buffer1', which is first used for copying
      // input tiles, then re-used to buffer gemm output. Calculate the
      // required buffer size for 'buffer1', based on max buffer size required
      // between copying input tiles and buffering gemm product output.
      //   buffer1: [max(buf1_tile_size, buf1_out_size)]
      const int64 buffer1_tile_size = tile_spatial_size * num_tiles * in_depth;
      const int64 buffer1_out_size =
          tile_spatial_size * num_tiles * out_depth * filter_shard_size;
      const int64 buffer1_size = std::max(buffer1_tile_size, buffer1_out_size);
      Tensor buffer1_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({buffer1_size}),
                                             &buffer1_tensor));
      T* buffer1 = buffer1_tensor.template flat<T>().data();

      // Allocate temporary buffer 'buffer2', which is first used for
      // transformed input tiles, then re-used for transformed output tiles.
      // Calculate required buffer size for 'buffer2' as max required buffer
      // between input and output transform buffer sizes.
      const int64 buffer2_tile_transform_size =
          tile_spatial_size * num_tiles * in_depth;
      const int64 buffer2_out_transform_size =
          out_tile_spatial_size * num_tiles * out_depth * filter_shard_size;
      const int64 buffer2_size =
          std::max(buffer2_tile_transform_size, buffer2_out_transform_size);
      Tensor buffer2_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({buffer2_size}),
                                             &buffer2_tensor));
      T* buffer2 = buffer2_tensor.template flat<T>().data();

      // Allocate temporary buffer to store packed tiles for one coordinate.
      //   packed tile buffer: [num_tiles, in_depth].
      Tensor packed_tile_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({num_tiles, in_depth}),
                                             &packed_tile_tensor));
      T* packed_tile_buffer = packed_tile_tensor.template flat<T>().data();

      // Allocate temporary buffer for gemm output.
      //   gemm output buffer: [num_tiles, out_depth, shard_rows, shard_cols].
      Tensor gemm_output_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({num_tiles, out_depth,
                                                          filter_shards_row,
                                                          filter_shards_col}),
                                             &gemm_output_tensor));
      T* gemm_output_buffer = gemm_output_tensor.template flat<T>().data();

      // Capture state needed for ComputeConv2D inner loop.
      Conv2DState<T> conv_state(tile_spatial_size, filter_shards_row,
                                filter_shards_col, input, tile_transform_matrix,
                                output_transform_matrix, buffer1, buffer2,
                                packed_tile_buffer, gemm_output_buffer);

      const int64 row_pad = args.pad_rows;
      const int64 col_pad = args.pad_cols;
      const int64 unroll_col_limit = (col_tiles / num_tiles) * num_tiles;

      const int64 input_image_size = args.in_rows * args.in_cols * in_depth;
      const int64 output_image_size = args.out_rows * args.out_cols * out_depth;

      const int64 tile_stride_rows = transform->output_shape().rows;
      const int64 tile_stride_cols = transform->output_shape().cols;

      for (int64 b = batch_start; b < batch_limit; ++b) {
        const int64 in_base = b * input_image_size;
        const int64 out_base = b * output_image_size;

        for (int64 tile_r = 0; tile_r < row_tiles; ++tile_r) {
          const int64 in_r = tile_r * tile_stride_rows - row_pad;

          // Process unrolled tiles.
          for (int64 tile_c = 0; tile_c < unroll_col_limit;
               tile_c += num_tiles) {
            const int64 in_c = tile_c * tile_stride_cols - col_pad;
            ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
                               num_tiles, packed_filters, input + in_base,
                               output + out_base);
          }
          // Process remaining tiles.
          if (unroll_col_limit < col_tiles) {
            const int64 rem_tiles = col_tiles - unroll_col_limit;
            const int64 in_c = unroll_col_limit * tile_stride_cols - col_pad;
            ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
                               rem_tiles, packed_filters, input + in_base,
                               output + out_base);
          }
        }
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int64 shard_cost = args.out_rows * args.out_cols * args.out_depth *
                             tile_spatial_size * args.in_depth;
    Shard(worker_threads.num_threads, worker_threads.workers, args.batch,
          shard_cost, shard);
  }
};

}  // namespace functor

template struct functor::DeepConv2D<CPUDevice, float>;

}  // namespace tensorflow