/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Make this file empty (or nearly empty) so that it can be compiled even when
// libxsmm is not available.

#ifndef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
void dummy_xsmm_conv2d_ensure_file_is_not_empty();
#else

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/xsmm_conv2d.h"

#include <stdlib.h>
#include <algorithm>  // std::min, used by the filter repacking below
#include <cassert>    // assert in the compute-kind dispatch
#include <cstring>
#if defined(_OPENMP) && defined(LIBXSMM_USE_OPENMP)
#include <omp.h>
#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/threadpool.h"

#include "include/libxsmm_cpuid.h"
#include "include/libxsmm_malloc.h"
#include "src/libxsmm_main.h"  // TODO(bsteiner): API to avoid incl. header from src/

#define CHECK_LIBXSMM(CONDITION_OK, MESSAGE) \
  if (!(CONDITION_OK)) VLOG(0) << (MESSAGE)
#define CHECK_LIBXSMM_DNN(STATUS, MESSAGE)                \
  CHECK_LIBXSMM(LIBXSMM_DNN_SUCCESS == (STATUS), MESSAGE) \
      << " failed: " << libxsmm_dnn_get_error(STATUS);

namespace tensorflow {

// Xsmm*Conv2D are wrappers for libxsmm direct convolutions.

// Returns true if convolution can be computed efficiently by XsmmConv2D,
// returns false otherwise.
bool CanUseXsmmConv2D(const libxsmm_dnn_conv_desc& desc,
                      TensorFormat data_format) {
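  // VECTOR_SIZE is the SIMD width in fp32 elements: 16 lanes for AVX-512
  // (512-bit registers) and 8 lanes for AVX2 (256-bit registers).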
  int VECTOR_SIZE;
  int arch = libxsmm_cpuid_x86();

  if (arch == LIBXSMM_X86_AVX512_CORE) {
    VECTOR_SIZE = 16;
  } else if (arch == LIBXSMM_X86_AVX2) {
    VECTOR_SIZE = 8;
  } else {
    VLOG(1) << "Cannot use XSMM convolutions: unsupported architecture!";
    return false;
  }

  if (data_format != FORMAT_NHWC) {
    VLOG(1) << "Cannot use XSMM convolutions: unsupported format!";
    return false;
  }
  if (desc.K % VECTOR_SIZE != 0) {
    VLOG(1) << "Cannot use XSMM convolutions: output features count not"
               " divisible by vector size!";
    return false;
  }
  VLOG(2) << "Can use XSMM convolutions.";
  return true;
}

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

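// Repacks a filter from TensorFlow's RSCK layout (filter height, filter
// width, input channels, output channels) into libxsmm's blocked custom
// layout [blocksofm][blocksifm][R][S][ifmblock][ofmblock], zero-padding the
// channel tails when C or K is not a multiple of the block size. Only the
// output-feature blocks in [start, end) are written, so callers can split
// the copy across threads.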
LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,
                                        int S, int C, int K, int blocksifm,
                                        int blocksofm, int ifmblock,
                                        int ofmblock, int start, int end) {
  LIBXSMM_VLA_DECL(4, const float, input, rsck, S, C, K);
  LIBXSMM_VLA_DECL(6, float, output, kcrs, blocksifm, R, S, ifmblock, ofmblock);
  int r, s, k, c, v1, v2;

  for (k = start; k < end; k++) {
    for (c = 0; c < blocksifm; c++) {
      for (r = 0; r < R; r++) {
        for (s = 0; s < S; s++) {
          for (v1 = c * ifmblock; v1 < std::min(C, (c + 1) * ifmblock); v1++) {
            for (v2 = k * ofmblock; v2 < std::min(K, (k + 1) * ofmblock); v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) =
                  LIBXSMM_VLA_ACCESS(4, input, r, s, v1, v2, S, C, K);
            for (v2 = K; v2 < (k + 1) * ofmblock; v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) = 0.0f;
          }
          for (v1 = C; v1 < (c + 1) * ifmblock; v1++) {
            for (v2 = k * ofmblock; v2 < (k + 1) * ofmblock; v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) = 0.0f;
          }
        }
      }
    }
  }
}

struct libxsmm_dnn_registry_key {
  const libxsmm_dnn_conv_desc descriptor;
  libxsmm_dnn_registry_key(const libxsmm_dnn_conv_desc& desc_)
      : descriptor(desc_) {}
  bool operator==(const libxsmm_dnn_registry_key& regkey) const {
    return 0 == memcmp(&descriptor, &regkey.descriptor, sizeof(descriptor));
  }
};

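// Hashes the descriptor's raw bytes. This is consistent with operator==
// above, which also compares the raw bytes via memcmp: keys that compare
// equal necessarily hash to the same value.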
struct HashFunction {
  std::size_t operator()(const libxsmm_dnn_registry_key& regkey) const {
    return libxsmm_hash(&regkey.descriptor, sizeof(regkey.descriptor),
                        25071975);
  }
};

struct libxsmm_dnn_registry_value {
  libxsmm_dnn_tensor_datalayout* layout_input;
  libxsmm_dnn_tensor_datalayout* layout_filter;
  libxsmm_dnn_tensor_datalayout* layout_output;
  libxsmm_dnn_layer* handle;
};

typedef libxsmm_tf_allocator<libxsmm_scratch_allocator>
    libxsmm_tf_scratch_allocator;

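// Process-wide cache of libxsmm convolution handles and tensor data layouts,
// keyed by the convolution descriptor and guarded by a read/write lock.
// Creating a handle JIT-compiles convolution code, so reusing cached entries
// avoids repeated code generation for identical descriptors.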
static class libxsmm_dnn_registry_type {
 private:
  typedef std::unordered_map<libxsmm_dnn_registry_key,
                             libxsmm_dnn_registry_value, HashFunction>
      container_type;

 public:
  libxsmm_dnn_registry_type() {
    libxsmm_init(); /* must be first */
#if !defined(LIBXSMM_LOCAL_ALLOC)
    {
      libxsmm_malloc_function malloc_fn;
      libxsmm_free_function free_fn;
      malloc_fn.function = libxsmm_tf_scratch_allocator::malloc;
      free_fn.function = libxsmm_tf_scratch_allocator::free;
      libxsmm_set_scratch_allocator(0 /*context*/, malloc_fn, free_fn);
    }
#endif
    LIBXSMM_LOCK_ATTR_INIT(LIBXSMM_LOCK_RWLOCK, &attr);
    LIBXSMM_LOCK_INIT(LIBXSMM_LOCK_RWLOCK, &lock, &attr);
  }
  ~libxsmm_dnn_registry_type() {
    LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock);
    const container_type::const_iterator end = container.end();
    for (container_type::const_iterator i = container.begin(); i != end; ++i) {
      CHECK_LIBXSMM_DNN(
          libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_input),
          "destroy input layout");
      CHECK_LIBXSMM_DNN(
          libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_output),
          "destroy output layout");
      CHECK_LIBXSMM_DNN(
          libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_filter),
          "destroy filter layout");
      CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_conv_layer(i->second.handle),
                        "destroy handle");
    }
    LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock);
    LIBXSMM_LOCK_DESTROY(LIBXSMM_LOCK_RWLOCK, &lock);
    LIBXSMM_LOCK_ATTR_DESTROY(LIBXSMM_LOCK_RWLOCK, &attr);
    libxsmm_finalize();
  }
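  // Looks up (and lazily creates) the registry entry for the given
  // descriptor using double-checked locking: a read lock for the common
  // cache-hit path, then a write lock with a re-check before creating and
  // inserting a new entry, so concurrent callers cannot create duplicates.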
  libxsmm_dnn_registry_value find(const libxsmm_dnn_registry_key& regkey) {
    container_type::iterator i;
    LIBXSMM_LOCK_ACQREAD(LIBXSMM_LOCK_RWLOCK, &lock);
    i = container.find(regkey);
    LIBXSMM_LOCK_RELREAD(LIBXSMM_LOCK_RWLOCK, &lock);
    if (i == container.end()) {
      libxsmm_dnn_err_t status;
      libxsmm_dnn_registry_value regentry;

      LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock);
      i = container.find(regkey);
      if (i == container.end()) {  // re-check after lock acquisition
        regentry.handle =
            libxsmm_dnn_create_conv_layer(regkey.descriptor, &status);
        if (LIBXSMM_DNN_WARN_FALLBACK != status) {
          CHECK_LIBXSMM_DNN(status, "create handle");
        } else {  // warning
          VLOG(1) << libxsmm_dnn_get_error(status);
        }
        regentry.layout_input = libxsmm_dnn_create_tensor_datalayout(
            regentry.handle, LIBXSMM_DNN_INPUT, &status);
        CHECK_LIBXSMM_DNN(status, "create input layout");

        regentry.layout_output = libxsmm_dnn_create_tensor_datalayout(
            regentry.handle, LIBXSMM_DNN_OUTPUT, &status);
        CHECK_LIBXSMM_DNN(status, "create output layout");

        regentry.layout_filter = libxsmm_dnn_create_tensor_datalayout(
            regentry.handle, LIBXSMM_DNN_FILTER, &status);
        CHECK_LIBXSMM_DNN(status, "create filter layout");

        i = container.insert(std::make_pair(regkey, regentry)).first;
      }
      LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock);
    }
    return i->second;
  }

 private:
  container_type container;
  LIBXSMM_LOCK_ATTR_TYPE(LIBXSMM_LOCK_RWLOCK) attr;
  LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK_RWLOCK) lock;
} libxsmm_dnn_registry;

// #define LIBXSMM_DETAILED_TIMING

template <typename InputPtr, typename FilterPtr, typename OutputPtr>
static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
                                   const libxsmm_dnn_conv_desc& desc,
                                   libxsmm_dnn_compute_kind kind,
                                   InputPtr input, FilterPtr filter,
                                   OutputPtr output) {
#if defined(LIBXSMM_DETAILED_TIMING)
  libxsmm_timer_tickint l_tick1;
  libxsmm_timer_tickint l_tick2;
  libxsmm_timer_tickint l_tick3;
  libxsmm_timer_tickint l_tick4;
  libxsmm_timer_tickint l_tick5;
  libxsmm_timer_tickint l_tick6;
  libxsmm_timer_tickint l_tick7;
  libxsmm_timer_tickint l_tick8;
  libxsmm_timer_tickint l_tick9;
  libxsmm_timer_tickint l_tick10;
  l_tick1 = libxsmm_timer_tick();
#endif
#if defined(LIBXSMM_LOCAL_ALLOC)
  // Set up a scoped allocator that adopts the allocator of the current
  // context.
  const libxsmm_tf_scratch_allocator tf_allocator(*ctx);
#endif
  const libxsmm_dnn_registry_key regkey(desc);
  const libxsmm_dnn_registry_value regentry = libxsmm_dnn_registry.find(regkey);
  libxsmm_dnn_tensor *libxsmm_input, *libxsmm_output, *libxsmm_filter;
  libxsmm_dnn_err_t status;

  status = libxsmm_dnn_get_codegen_success(regentry.handle, kind);
  if (status == LIBXSMM_DNN_WARN_FALLBACK) {
    return false;  // Use non-libxsmm code.
  }
  CHECK_LIBXSMM_DNN(status, "code generation");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick2 = libxsmm_timer_tick();
#endif

  const int ifmblock = regentry.handle->ifmblock;
  const int ofmblock = regentry.handle->ofmblock;
  const int blocksifm =
      (desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1);
  const int blocksofm =
      (desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1);

  const size_t filter_size =
      blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock;
  float* const native_filter = (float*)libxsmm_aligned_scratch(
      filter_size * sizeof(float), 2097152 /*alignment*/);

  const DeviceBase::CpuWorkerThreads* const worker_threads =
      ctx->device()->tensorflow_cpu_worker_threads();
  const int num_threads = worker_threads->num_threads;

#if 1
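  // Repack the filter into the blocked layout. If there are more output
  // blocks than threads, hand each thread a contiguous chunk of blocks;
  // otherwise schedule one block per task so no thread sits idle.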
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
      kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    if (blocksofm > num_threads) {
      const int work = blocksofm;
      BlockingCounter count(num_threads);
      for (int i = 0; i < num_threads; ++i) {
        worker_threads->workers->Schedule([=, &count]() {
          // Ceil-divide the work so the trailing blocks are not dropped
          // when num_threads does not evenly divide the block count.
          const int chunk = (work + num_threads - 1) / num_threads;
          const int start = std::min(chunk * i, work);
          const int end = std::min(start + chunk, work);
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                              desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                              start, end);
          count.DecrementCount();
        });
      }
      count.Wait();
    } else {
      const int work = blocksofm;
      const int num_tasks = work;

      BlockingCounter count(num_tasks);
      for (int i = 0; i < num_tasks; ++i) {
        worker_threads->workers->Schedule([=, &count]() {
          const int start = i;
          const int end = i + 1;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                              desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                              start, end);
          count.DecrementCount();
        });
      }
      count.Wait();
    }
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    // weight update buffer must be in the right format
    // (LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR)
    libxsmm_filter =
        libxsmm_dnn_link_tensor(regentry.layout_filter, filter, &status);
    CHECK_LIBXSMM_DNN(status, "link filter with layout");
  }
#else
  memset(native_filter, 0, filter_size * sizeof(float));
#endif

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick3 = libxsmm_timer_tick();
#endif

  // LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR
  libxsmm_input =
      libxsmm_dnn_link_tensor(regentry.layout_input, input, &status);
  CHECK_LIBXSMM_DNN(status, "link input buffer with layout");

  // LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR
  libxsmm_output =
      libxsmm_dnn_link_tensor(regentry.layout_output, output, &status);
  CHECK_LIBXSMM_DNN(status, "link output buffer with layout");

  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
      kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    // LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR
    libxsmm_filter =
        libxsmm_dnn_link_tensor(regentry.layout_filter, native_filter, &status);
    CHECK_LIBXSMM_DNN(status, "link filter with layout");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
                                              LIBXSMM_DNN_REGULAR_FILTER),
                      "bind filter to handle");
  }
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
                                              LIBXSMM_DNN_REGULAR_INPUT),
                      "bind input forward");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
                                              LIBXSMM_DNN_REGULAR_FILTER),
                      "bind filter forward");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
                                              LIBXSMM_DNN_REGULAR_OUTPUT),
                      "bind output forward");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_input), "zeroing input");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
                                              LIBXSMM_DNN_GRADIENT_INPUT),
                      "bind input backward");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
                                              LIBXSMM_DNN_REGULAR_FILTER),
                      "bind filter backward");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
                                              LIBXSMM_DNN_GRADIENT_OUTPUT),
                      "bind output backward");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_filter),
                      "zeroing filter");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
                                              LIBXSMM_DNN_REGULAR_INPUT),
                      "bind input weight update");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
                                              LIBXSMM_DNN_GRADIENT_FILTER),
                      "bind filter weight update");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
                                              LIBXSMM_DNN_GRADIENT_OUTPUT),
                      "bind output weight update");
  } else {
    assert(0 /*should not happen*/);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick4 = libxsmm_timer_tick();
#endif

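  // Allocate one scratch region sized for all compute kinds and bind it to
  // the handle. The 2 MiB alignment matches the alignment used for the
  // filter scratch above (presumably chosen to line up with huge pages).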
  const size_t scratch_size = libxsmm_dnn_get_scratch_size(
      regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status);
  CHECK_LIBXSMM_DNN(status, "get scratch size");
  void* const scratch =
      libxsmm_aligned_scratch(scratch_size, 2097152 /*alignment*/);
  CHECK_LIBXSMM(0 != scratch, "scratch memory allocation");
  CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_scratch(
                        regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
                    "binding scratch");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick5 = libxsmm_timer_tick();
#endif

  if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    libxsmm_dnn_transpose_filter(regentry.handle, LIBXSMM_DNN_FILTER);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick6 = libxsmm_timer_tick();
#endif

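  // Run the convolution: libxsmm_dnn_execute_st expects every thread of the
  // team to call in with its own thread id. Without OpenMP, the team is
  // emulated on TensorFlow's worker pool and joined with a BlockingCounter;
  // with OpenMP, a parallel region provides the team directly.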
#if !defined(_OPENMP) || !defined(LIBXSMM_USE_OPENMP)
  BlockingCounter counter(num_threads);

  for (int i = 0; i < num_threads; ++i) {
    worker_threads->workers->Schedule([=, &counter]() {
      CHECK_LIBXSMM_DNN(libxsmm_dnn_execute_st(regentry.handle, kind, 0, i),
                        "worker");
      counter.DecrementCount();
    });
  }
  counter.Wait();
#else
#pragma omp parallel
  {
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_execute_st(regentry.handle, kind, 0, omp_get_thread_num()),
        "worker");
  }
#endif

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick7 = libxsmm_timer_tick();
#endif

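  // Reduce the per-thread weight-update (filter gradient) partial results
  // into the gradient filter tensor.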
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    libxsmm_dnn_reduce_wu_filters(regentry.handle, LIBXSMM_DNN_GRADIENT_FILTER);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick8 = libxsmm_timer_tick();
#endif

  /* clean up */
  CHECK_LIBXSMM_DNN(libxsmm_dnn_release_scratch(regentry.handle,
                                                LIBXSMM_DNN_COMPUTE_KIND_ALL),
                    "release scratch");
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT),
        "release input");
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_OUTPUT),
        "release output");
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER),
        "release filter");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_GRADIENT_INPUT),
        "release input");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
                                                 LIBXSMM_DNN_GRADIENT_OUTPUT),
                      "release output");
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER),
        "release filter");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT),
        "release input");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
                                                 LIBXSMM_DNN_GRADIENT_OUTPUT),
                      "release output");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
                                                 LIBXSMM_DNN_GRADIENT_FILTER),
                      "release filter");
  } else {
    /* shouldn't happen */
  }
  CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_input), "destroy input");
  CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_output),
                    "destroy output");
  CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_filter),
                    "destroy filter");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick9 = libxsmm_timer_tick();
#endif

  libxsmm_free(native_filter);
  libxsmm_free(scratch);

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick10 = libxsmm_timer_tick();
  printf(
      "time for convolution (%i, %i, %i, %i, %i): %f, %f, %f, %f, %f, %f, %f, "
      "%f, %f, %f\n",
      desc.N, desc.C, desc.K, desc.R, desc.S,
      libxsmm_timer_duration(l_tick1, l_tick2),
      libxsmm_timer_duration(l_tick2, l_tick3),
      libxsmm_timer_duration(l_tick3, l_tick4),
      libxsmm_timer_duration(l_tick4, l_tick5),
      libxsmm_timer_duration(l_tick5, l_tick6),
      libxsmm_timer_duration(l_tick6, l_tick7),
      libxsmm_timer_duration(l_tick7, l_tick8),
      libxsmm_timer_duration(l_tick8, l_tick9),
      libxsmm_timer_duration(l_tick9, l_tick10),
      libxsmm_timer_duration(l_tick1, l_tick10));
#endif

  return true;  // Succeeded
}

#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
template <typename T>
struct XsmmFwdConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  const T* input, const T* filter, T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD,
                                  input, filter, output);
  }
};
#endif

#ifdef TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS
template <typename T>
struct XsmmBkwInputConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  T* input, const T* filter, const T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD,
                                  input, filter, output);
  }
};

template <typename T>
struct XsmmBkwFilterConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  const T* input, T* filter, const T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD,
                                  input, filter, output);
  }
};
#endif

}  // namespace functor

template struct functor::XsmmFwdConv2D<CPUDevice, float>;
template struct functor::XsmmBkwInputConv2D<CPUDevice, float>;
template struct functor::XsmmBkwFilterConv2D<CPUDevice, float>;

}  // namespace tensorflow

#endif  // TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS