/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Make this file empty (or nearly empty) so that it can be compiled even when
// libxsmm is not available.

#ifndef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
void dummy_xsmm_conv2d_ensure_file_is_not_empty();
#else

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/xsmm_conv2d.h"

#include <stdlib.h>

#include <algorithm>
#include <cstring>
#include <unordered_map>
#include <utility>

#if defined(_OPENMP) && defined(LIBXSMM_USE_OPENMP)
#include <omp.h>
#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/threadpool.h"

#include "include/libxsmm_cpuid.h"
#include "include/libxsmm_malloc.h"
#include "src/libxsmm_main.h"  // TODO(bsteiner): API to avoid incl. header from src/

#define CHECK_LIBXSMM(CONDITION_OK, MESSAGE) \
  if (!(CONDITION_OK)) VLOG(0) << (MESSAGE)
#define CHECK_LIBXSMM_DNN(STATUS, MESSAGE) \
  CHECK_LIBXSMM(LIBXSMM_DNN_SUCCESS == (STATUS), MESSAGE) \
      << " failed: " << libxsmm_dnn_get_error(STATUS);

namespace tensorflow {

// Xsmm*Conv2D are wrappers for libxsmm direct convolutions.

// Returns true if the convolution can be computed efficiently by XsmmConv2D;
// otherwise returns false.
bool CanUseXsmmConv2D(const libxsmm_dnn_conv_desc& desc,
                      TensorFormat data_format) {
  int VECTOR_SIZE;
  int arch = libxsmm_cpuid_x86();

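  // Pick the SIMD width (fp32 lanes per register): 16 for AVX-512, 8 for
  // AVX2. Other architectures fall back to the non-libxsmm path, and the
  // output-feature count must be a multiple of this width (checked below).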
  if (arch == LIBXSMM_X86_AVX512_CORE) {
    VECTOR_SIZE = 16;
  } else if (arch == LIBXSMM_X86_AVX2) {
    VECTOR_SIZE = 8;
  } else {
    VLOG(1) << "Cannot use XSMM convolutions: unsupported architecture!";
    return false;
  }

  if (data_format != FORMAT_NHWC) {
    VLOG(1) << "Cannot use XSMM convolutions: unsupported format!";
    return false;
  }
  if (desc.K % VECTOR_SIZE != 0) {
    VLOG(1) << "Cannot use XSMM convolutions: output features count not"
               " divisible by vector size!";
    return false;
  }
  VLOG(2) << "Can use XSMM convolutions.";
  return true;
}

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

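// Repacks a filter from TensorFlow's RSCK layout (rows, columns, input
// channels, output channels) into libxsmm's blocked custom layout
// [blocksofm][blocksifm][R][S][ifmblock][ofmblock], zero-filling the channel
// remainder when C or K is not a multiple of the block size. Only the
// output-feature blocks in [start, end) are written, so the copy can be
// sharded across threads.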
LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,
                                        int S, int C, int K, int blocksifm,
                                        int blocksofm, int ifmblock,
                                        int ofmblock, int start, int end) {
  LIBXSMM_VLA_DECL(4, const float, input, rsck, S, C, K);
  LIBXSMM_VLA_DECL(6, float, output, kcrs, blocksifm, R, S, ifmblock, ofmblock);
  int r, s, k, c, v1, v2;

  for (k = start; k < end; k++) {
    for (c = 0; c < blocksifm; c++) {
      for (r = 0; r < R; r++) {
        for (s = 0; s < S; s++) {
          for (v1 = c * ifmblock; v1 < std::min(C, (c + 1) * ifmblock); v1++) {
            for (v2 = k * ofmblock; v2 < std::min(K, (k + 1) * ofmblock); v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) =
                  LIBXSMM_VLA_ACCESS(4, input, r, s, v1, v2, S, C, K);
            for (v2 = K; v2 < (k + 1) * ofmblock; v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) = 0.0f;
          }
          for (v1 = C; v1 < (c + 1) * ifmblock; v1++) {
            for (v2 = k * ofmblock; v2 < (k + 1) * ofmblock; v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) = 0.0f;
          }
        }
      }
    }
  }
}

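// Key for the handle registry below: two convolutions share a registry entry
// exactly when their libxsmm descriptors are bytewise identical.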
struct libxsmm_dnn_registry_key {
  const libxsmm_dnn_conv_desc descriptor;
  libxsmm_dnn_registry_key(const libxsmm_dnn_conv_desc& desc_)
      : descriptor(desc_) {}
  bool operator==(const libxsmm_dnn_registry_key& regkey) const {
    return 0 == memcmp(&descriptor, &regkey.descriptor, sizeof(descriptor));
  }
};

struct HashFunction {
  std::size_t operator()(const libxsmm_dnn_registry_key& regkey) const {
    return libxsmm_hash(&regkey.descriptor, sizeof(regkey.descriptor),
                        25071975);
  }
};

struct libxsmm_dnn_registry_value {
  libxsmm_dnn_tensor_datalayout* layout_input;
  libxsmm_dnn_tensor_datalayout* layout_filter;
  libxsmm_dnn_tensor_datalayout* layout_output;
  libxsmm_dnn_layer* handle;
};

typedef libxsmm_tf_allocator<libxsmm_scratch_allocator>
    libxsmm_tf_scratch_allocator;

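// Process-wide registry that caches one libxsmm handle (plus its tensor
// layouts) per convolution descriptor, so repeated invocations of the same
// convolution reuse the generated code. Lookups take a read lock; on a miss
// the exclusive lock is taken and the lookup is re-checked before creating
// the entry. The destructor releases all cached handles and finalizes
// libxsmm.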
static class libxsmm_dnn_registry_type {
 private:
  typedef std::unordered_map<libxsmm_dnn_registry_key,
                             libxsmm_dnn_registry_value, HashFunction>
      container_type;

 public:
  libxsmm_dnn_registry_type() {
    libxsmm_init(); /* must be first */
#if !defined(LIBXSMM_LOCAL_ALLOC)
    {
      libxsmm_malloc_function malloc_fn;
      libxsmm_free_function free_fn;
      malloc_fn.function = libxsmm_tf_scratch_allocator::malloc;
      free_fn.function = libxsmm_tf_scratch_allocator::free;
      libxsmm_set_scratch_allocator(0 /*context*/, malloc_fn, free_fn);
    }
#endif
    LIBXSMM_LOCK_ATTR_INIT(LIBXSMM_LOCK_RWLOCK, &attr);
    LIBXSMM_LOCK_INIT(LIBXSMM_LOCK_RWLOCK, &lock, &attr);
  }
  ~libxsmm_dnn_registry_type() {
    LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock);
    const container_type::const_iterator end = container.end();
    for (container_type::const_iterator i = container.begin(); i != end; ++i) {
      CHECK_LIBXSMM_DNN(
          libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_input),
          "destroy input layout");
      CHECK_LIBXSMM_DNN(
          libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_output),
          "destroy output layout");
      CHECK_LIBXSMM_DNN(
          libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_filter),
          "destroy filter layout");
      CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_conv_layer(i->second.handle),
                        "destroy handle");
    }
    LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock);
    LIBXSMM_LOCK_DESTROY(LIBXSMM_LOCK_RWLOCK, &lock);
    LIBXSMM_LOCK_ATTR_DESTROY(LIBXSMM_LOCK_RWLOCK, &attr);
    libxsmm_finalize();
  }
  libxsmm_dnn_registry_value find(const libxsmm_dnn_registry_key& regkey) {
    container_type::iterator i;
    LIBXSMM_LOCK_ACQREAD(LIBXSMM_LOCK_RWLOCK, &lock);
    i = container.find(regkey);
    LIBXSMM_LOCK_RELREAD(LIBXSMM_LOCK_RWLOCK, &lock);
    if (i == container.end()) {
      libxsmm_dnn_err_t status;
      libxsmm_dnn_registry_value regentry;

      LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock);
      i = container.find(regkey);
      if (i == container.end()) {  // re-check after lock acquisition
        regentry.handle =
            libxsmm_dnn_create_conv_layer(regkey.descriptor, &status);
        if (LIBXSMM_DNN_WARN_FALLBACK != status) {
          CHECK_LIBXSMM_DNN(status, "create handle");
        } else {  // warning
          VLOG(1) << libxsmm_dnn_get_error(status);
        }
        regentry.layout_input = libxsmm_dnn_create_tensor_datalayout(
            regentry.handle, LIBXSMM_DNN_INPUT, &status);
        CHECK_LIBXSMM_DNN(status, "create input layout");

        regentry.layout_output = libxsmm_dnn_create_tensor_datalayout(
            regentry.handle, LIBXSMM_DNN_OUTPUT, &status);
        CHECK_LIBXSMM_DNN(status, "create output layout");

        regentry.layout_filter = libxsmm_dnn_create_tensor_datalayout(
            regentry.handle, LIBXSMM_DNN_FILTER, &status);
        CHECK_LIBXSMM_DNN(status, "create filter layout");

        i = container.insert(std::make_pair(regkey, regentry)).first;
      }
      LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock);
    }
    return i->second;
  }

 private:
  container_type container;
  LIBXSMM_LOCK_ATTR_TYPE(LIBXSMM_LOCK_RWLOCK) attr;
  LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK_RWLOCK) lock;
} libxsmm_dnn_registry;

// #define LIBXSMM_DETAILED_TIMING

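// Runs one convolution (forward, backprop-to-input, or weight update) through
// libxsmm: fetches the cached handle for `desc`, repacks the filter when
// needed, links and binds the TensorFlow buffers to libxsmm tensors, executes
// the handle across all worker threads, and releases the per-call resources.
// Returns false when libxsmm reports a fallback, so the caller can use the
// default (non-libxsmm) implementation instead.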
template <typename InputPtr, typename FilterPtr, typename OutputPtr>
static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
                                   const libxsmm_dnn_conv_desc& desc,
                                   libxsmm_dnn_compute_kind kind,
                                   InputPtr input, FilterPtr filter,
                                   OutputPtr output) {
#if defined(LIBXSMM_DETAILED_TIMING)
  libxsmm_timer_tickint l_tick1;
  libxsmm_timer_tickint l_tick2;
  libxsmm_timer_tickint l_tick3;
  libxsmm_timer_tickint l_tick4;
  libxsmm_timer_tickint l_tick5;
  libxsmm_timer_tickint l_tick6;
  libxsmm_timer_tickint l_tick7;
  libxsmm_timer_tickint l_tick8;
  libxsmm_timer_tickint l_tick9;
  libxsmm_timer_tickint l_tick10;
  l_tick1 = libxsmm_timer_tick();
#endif
#if defined(LIBXSMM_LOCAL_ALLOC)
  // setup scoped allocator, which adopts the allocator of the current context
  const libxsmm_tf_scratch_allocator tf_allocator(*ctx);
#endif
  const libxsmm_dnn_registry_key regkey(desc);
  const libxsmm_dnn_registry_value regentry = libxsmm_dnn_registry.find(regkey);
  libxsmm_dnn_tensor *libxsmm_input, *libxsmm_output, *libxsmm_filter;
  libxsmm_dnn_err_t status;

  status = libxsmm_dnn_get_codegen_success(regentry.handle, kind);
  if (status == LIBXSMM_DNN_WARN_FALLBACK) {
    return false;  // Use non-libxsmm code
  }
  CHECK_LIBXSMM_DNN(status, "code generation");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick2 = libxsmm_timer_tick();
#endif

  const int ifmblock = regentry.handle->ifmblock;
  const int ofmblock = regentry.handle->ofmblock;

  const int blocksifm =
      (desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1);
  const int blocksofm =
      (desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1);

  const size_t filter_size =
      blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock;
  float* const native_filter = (float*)libxsmm_aligned_scratch(
      filter_size * sizeof(float), 2097152 /*alignment*/);

  const DeviceBase::CpuWorkerThreads* const worker_threads =
      ctx->device()->tensorflow_cpu_worker_threads();
  const int num_threads = worker_threads->num_threads;

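  // For the forward and backward-data passes, repack the RSCK filter into
  // libxsmm's custom layout in native_filter, sharding the copy across the
  // worker threads (one output-feature block per task when there are more
  // threads than blocks). The weight-update pass links the user's filter
  // buffer directly instead; see the RSCK_PTR note below.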
#if 1
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
      kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    if (blocksofm > num_threads) {
      const int work = blocksofm;
      BlockingCounter count(num_threads);
      for (int i = 0; i < num_threads; ++i) {
        worker_threads->workers->Schedule([=, &count]() {
          const int start = work / num_threads * i;
          const int end = (start + work / num_threads) > work
                              ? work
                              : start + work / num_threads;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                              desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                              start, end);
          count.DecrementCount();
        });
      }
      count.Wait();
    } else {
      const int work = blocksofm;
      const int num_tasks = work;

      BlockingCounter count(num_tasks);
      for (int i = 0; i < num_tasks; ++i) {
        worker_threads->workers->Schedule([=, &count]() {
          const int start = i;
          const int end = i + 1;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                              desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                              start, end);
          count.DecrementCount();
        });
      }
      count.Wait();
    }
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    // weight update buffer must be in the right format
    // (LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR)
    libxsmm_filter =
        libxsmm_dnn_link_tensor(regentry.layout_filter, filter, &status);
    CHECK_LIBXSMM_DNN(status, "link filter with layout");
  }
#else
  memset(native_filter, 0, filter_size * sizeof(float));
#endif

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick3 = libxsmm_timer_tick();
#endif

  // LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR
  libxsmm_input =
      libxsmm_dnn_link_tensor(regentry.layout_input, input, &status);
  CHECK_LIBXSMM_DNN(status, "link input buffer with layout");

  // LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR
  libxsmm_output =
      libxsmm_dnn_link_tensor(regentry.layout_output, output, &status);
  CHECK_LIBXSMM_DNN(status, "link output buffer with layout");

  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
      kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    // LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR
    libxsmm_filter =
        libxsmm_dnn_link_tensor(regentry.layout_filter, native_filter, &status);
    CHECK_LIBXSMM_DNN(status, "link filter with layout");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
                                              LIBXSMM_DNN_REGULAR_FILTER),
                      "bind filter to handle");
  }
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
                                              LIBXSMM_DNN_REGULAR_INPUT),
                      "bind input forward");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
                                              LIBXSMM_DNN_REGULAR_FILTER),
                      "bind filter forward");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
                                              LIBXSMM_DNN_REGULAR_OUTPUT),
                      "bind output forward");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_input), "zeroing input");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
                                              LIBXSMM_DNN_GRADIENT_INPUT),
                      "bind input backward");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
                                              LIBXSMM_DNN_REGULAR_FILTER),
                      "bind filter backward");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
                                              LIBXSMM_DNN_GRADIENT_OUTPUT),
                      "bind output backward");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_filter),
                      "zeroing filter");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
                                              LIBXSMM_DNN_REGULAR_INPUT),
                      "bind input weight update");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
                                              LIBXSMM_DNN_GRADIENT_FILTER),
                      "bind filter weight update");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
                                              LIBXSMM_DNN_GRADIENT_OUTPUT),
                      "bind output weight update");
  } else {
    assert(0 /*should not happen*/);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick4 = libxsmm_timer_tick();
#endif

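  // Allocate and bind the scratch buffer libxsmm needs for this handle; it is
  // sized for all compute kinds and freed again at the end of this call.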
  const size_t scratch_size = libxsmm_dnn_get_scratch_size(
      regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status);
  CHECK_LIBXSMM_DNN(status, "get scratch size");
  void* const scratch =
      libxsmm_aligned_scratch(scratch_size, 2097152 /*alignment*/);
  CHECK_LIBXSMM(0 != scratch, "scratch memory allocation");
  CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_scratch(
                        regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
                    "binding scratch");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick5 = libxsmm_timer_tick();
#endif

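  // For the backward-data pass, have libxsmm prepare its transposed copy of
  // the filter before execution.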
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    libxsmm_dnn_transpose_filter(regentry.handle, LIBXSMM_DNN_FILTER);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick6 = libxsmm_timer_tick();
#endif

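  // Execute the convolution: one libxsmm_dnn_execute_st call per thread, each
  // passing its thread id so libxsmm can partition the work internally. The
  // threads come from TensorFlow's worker pool, or from OpenMP when libxsmm
  // was built with LIBXSMM_USE_OPENMP.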
#if !defined(_OPENMP) || !defined(LIBXSMM_USE_OPENMP)
  BlockingCounter counter(num_threads);

  for (int i = 0; i < num_threads; ++i) {
    worker_threads->workers->Schedule([=, &counter]() {
      CHECK_LIBXSMM_DNN(libxsmm_dnn_execute_st(regentry.handle, kind, 0, i),
                        "worker");
      counter.DecrementCount();
    });
  }
  counter.Wait();
#else
#pragma omp parallel
  {
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_execute_st(regentry.handle, kind, 0, omp_get_thread_num()),
        "worker");
  }
#endif

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick7 = libxsmm_timer_tick();
#endif

  if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    libxsmm_dnn_reduce_wu_filters(regentry.handle, LIBXSMM_DNN_GRADIENT_FILTER);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick8 = libxsmm_timer_tick();
#endif

  /* clean up */
  CHECK_LIBXSMM_DNN(libxsmm_dnn_release_scratch(regentry.handle,
                                                LIBXSMM_DNN_COMPUTE_KIND_ALL),
                    "release scratch");
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT),
        "release input");
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_OUTPUT),
        "release output");
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER),
        "release filter");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_GRADIENT_INPUT),
        "release input");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
                                                 LIBXSMM_DNN_GRADIENT_OUTPUT),
                      "release output");
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER),
        "release filter");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT),
        "release input");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
                                                 LIBXSMM_DNN_GRADIENT_OUTPUT),
                      "release output");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
                                                 LIBXSMM_DNN_GRADIENT_FILTER),
                      "release filter");
  } else {
    /* shouldn't happen */
  }
  CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_input), "destroy input");
  CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_output),
                    "destroy output");
  CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_filter),
                    "destroy filter");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick9 = libxsmm_timer_tick();
#endif

  libxsmm_free(native_filter);
  libxsmm_free(scratch);

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick10 = libxsmm_timer_tick();
  printf(
      "time for convolution (%i, %i, %i, %i, %i): %f, %f, %f, %f, %f, %f, %f, "
      "%f, %f, %f\n",
      desc.N, desc.C, desc.K, desc.R, desc.S,
      libxsmm_timer_duration(l_tick1, l_tick2),
      libxsmm_timer_duration(l_tick2, l_tick3),
      libxsmm_timer_duration(l_tick3, l_tick4),
      libxsmm_timer_duration(l_tick4, l_tick5),
      libxsmm_timer_duration(l_tick5, l_tick6),
      libxsmm_timer_duration(l_tick6, l_tick7),
      libxsmm_timer_duration(l_tick7, l_tick8),
      libxsmm_timer_duration(l_tick8, l_tick9),
      libxsmm_timer_duration(l_tick9, l_tick10),
      libxsmm_timer_duration(l_tick1, l_tick10));
#endif

  return true;  // Succeeded
}

#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
template <typename T>
struct XsmmFwdConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  const T* input, const T* filter, T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD,
                                  input, filter, output);
  }
};
#endif

#ifdef TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS
template <typename T>
struct XsmmBkwInputConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  T* input, const T* filter, const T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD,
                                  input, filter, output);
  }
};

template <typename T>
struct XsmmBkwFilterConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  const T* input, T* filter, const T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD,
                                  input, filter, output);
  }
};
#endif

}  // namespace functor

template struct functor::XsmmFwdConv2D<CPUDevice, float>;
template struct functor::XsmmBkwInputConv2D<CPUDevice, float>;
template struct functor::XsmmBkwFilterConv2D<CPUDevice, float>;

}  // namespace tensorflow

#endif  // TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS