/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Make this file empty (or nearly empty) so that it can be compiled even when
// libxsmm is not available.

#ifndef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
void dummy_xsmm_conv2d_ensure_file_is_not_empty();
#else

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/xsmm_conv2d.h"

#include <stdlib.h>
#include <cstring>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/threadpool.h"

#include "include/libxsmm_cpuid.h"
#include "include/libxsmm_malloc.h"
#include "third_party/libxsmm/src/libxsmm_main.h"  // TODO(bsteiner): API to avoid incl. header from src/
namespace tensorflow {

// Xsmm*Conv2D are wrappers for libxsmm direct convolutions.

// Returns true if the convolution can be computed efficiently by XsmmConv2D,
// and false otherwise.
bool CanUseXsmmConv2D(const libxsmm_dnn_conv_desc& desc,
                      TensorFormat data_format) {
  int VECTOR_SIZE;
  int arch = libxsmm_cpuid_x86();

  if (arch == LIBXSMM_X86_AVX512_CORE) {
    VECTOR_SIZE = 16;
  } else if (arch == LIBXSMM_X86_AVX2) {
    VECTOR_SIZE = 8;
  } else {
    VLOG(1) << "Cannot use XSMM convolutions: unsupported architecture!";
    return false;
  }

  if (data_format != FORMAT_NHWC) {
    VLOG(1) << "Cannot use XSMM convolutions: unsupported format!";
    return false;
  }
  if (desc.K % VECTOR_SIZE != 0) {
    VLOG(1) << "Cannot use XSMM convolutions: output features count not"
               " divisible by vector size!";
    return false;
  }
  VLOG(2) << "Can use XSMM convolutions.";
  return true;
}

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

static void chk_libxsmm_err(libxsmm_dnn_err_t status, string msg) {
  if (status != LIBXSMM_DNN_SUCCESS) {
    VLOG(0) << msg << " failed: " << libxsmm_dnn_get_error(status);
  }
}

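// Copies a filter from TensorFlow's RSCK layout (filter rows, filter columns,
// input channels, output channels) into libxsmm's blocked custom layout
// [ofm_block][ifm_block][R][S][ifmblock][ofmblock], zero-padding the
// remainder when C or K is not a multiple of the block size. Only the
// output-feature blocks in [start, end) are copied, so the work can be
// sharded across threads.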
LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,
                                        int S, int C, int K, int blocksifm,
                                        int blocksofm, int ifmblock,
                                        int ofmblock, int start, int end) {
  LIBXSMM_VLA_DECL(4, const float, input, rsck, S, C, K);
  LIBXSMM_VLA_DECL(6, float, output, kcrs, blocksifm, R, S, ifmblock, ofmblock);
  int r, s, k, c, v1, v2;

  for (k = start; k < end; k++) {
    for (c = 0; c < blocksifm; c++) {
      for (r = 0; r < R; r++) {
        for (s = 0; s < S; s++) {
          for (v1 = c * ifmblock; v1 < std::min(C, (c + 1) * ifmblock); v1++) {
            for (v2 = k * ofmblock; v2 < std::min(K, (k + 1) * ofmblock); v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) =
                  LIBXSMM_VLA_ACCESS(4, input, r, s, v1, v2, S, C, K);
            for (v2 = K; v2 < (k + 1) * ofmblock; v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) = 0.0f;
          }
          for (v1 = C; v1 < (c + 1) * ifmblock; v1++) {
            for (v2 = k * ofmblock; v2 < (k + 1) * ofmblock; v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) = 0.0f;
          }
        }
      }
    }
  }
}

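// Wraps a libxsmm_dnn_conv_desc so descriptors can serve as unordered_map
// keys: equality compares the shape-defining fields, and HashFunction below
// hashes the raw descriptor bytes.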
class libxsmm_dnn_conv_desc_wrap {
 public:
  const libxsmm_dnn_conv_desc d;

  libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc& d_) : d(d_) {}
  bool operator==(const libxsmm_dnn_conv_desc_wrap& w) const {
    return (d.N == w.d.N && d.C == w.d.C && d.H == w.d.H && d.W == w.d.W &&
            d.K == w.d.K && d.R == w.d.R && d.S == w.d.S && d.u == w.d.u &&
            d.v == w.d.v && d.pad_h == w.d.pad_h && d.pad_w == w.d.pad_w);
  }
};

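// Hashes the raw bytes of the wrapped descriptor (25071975 is just a fixed
// seed).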
struct HashFunction {
  std::size_t operator()(const libxsmm_dnn_conv_desc_wrap& w) const {
    return libxsmm_hash(&w.d, sizeof(w.d), 25071975);
  }
};

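// Process-wide cache of libxsmm layer handles, keyed by convolution shape.
// Creating a handle triggers libxsmm's code generation, so each distinct
// shape is compiled once and the handle reused; the destructor destroys all
// cached handles.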
class handles {
 public:
  libxsmm_dnn_layer* find(const libxsmm_dnn_conv_desc_wrap& w) {
    std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
                       HashFunction>::iterator i = libxsmm_handles.find(w);
    if (i == libxsmm_handles.end()) {
      libxsmm_dnn_err_t status;
      libxsmm_dnn_layer* libxsmm_handle =
          libxsmm_dnn_create_conv_layer(w.d, &status);
      chk_libxsmm_err(status, "Create handle");
      libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));
      return libxsmm_handle;
    } else {
      return i->second;
    }
  }
  ~handles() {
    std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
                       HashFunction>::iterator i;
    for (i = libxsmm_handles.begin(); i != libxsmm_handles.end(); i++)
      chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second),
                      "Destroy handle");
  }

 private:
  std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
                     HashFunction>
      libxsmm_handles;
};

static handles libxsmm_handles;

// #define LIBXSMM_DETAILED_TIMING

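// Generic driver shared by the forward, backward-data, and weight-update
// functors below: look up (or create) the cached libxsmm handle for this
// shape, repack or link the filter, link the NHWC tensors, bind scratch,
// execute on TensorFlow's thread pool, then release the per-call resources.
// Returns false when libxsmm reports it would fall back to its generic code
// path, so the caller can use TensorFlow's default implementation instead.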
template <typename InputPtr, typename FilterPtr, typename OutputPtr>
static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
                                   const libxsmm_dnn_conv_desc& desc,
                                   libxsmm_dnn_compute_kind kind,
                                   InputPtr input, FilterPtr filter,
                                   OutputPtr output) {
#if defined(LIBXSMM_DETAILED_TIMING)
  uint64 l_tick1;
  uint64 l_tick2;
  uint64 l_tick3;
  uint64 l_tick4;
  uint64 l_tick5;
  uint64 l_tick6;
  uint64 l_tick7;
  uint64 l_tick8;
  uint64 l_tick9;
  uint64 l_tick10;
  l_tick1 = libxsmm_timer_tick();
#endif
  // setup scoped allocator, which adopts the allocator from the context
  const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator(*ctx);
  libxsmm_dnn_err_t status;
  libxsmm_dnn_layer* libxsmm_handle;
  libxsmm_dnn_conv_desc_wrap w(desc);
  void* scratch;

  // if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD)
  libxsmm_handle = libxsmm_handles.find(w);
  // else{
  //  libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status);
  //  chk_libxsmm_err(status, "Create handle");
  //}

  status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
  if (status == LIBXSMM_DNN_WARN_FALLBACK) {
    return false;  // Use non-libxsmm code
  }
  chk_libxsmm_err(status, "Check codegen status");

  libxsmm_dnn_buffer* libxsmm_input;
  libxsmm_dnn_buffer* libxsmm_output;
  libxsmm_dnn_filter* libxsmm_filter;

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick2 = libxsmm_timer_tick();
#endif

  int ifmblock = (libxsmm_handle->ifmblock);
  int ofmblock = (libxsmm_handle->ofmblock);

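  // blocksifm/blocksofm round up, so partial blocks get zero-padded by
  // copy_RSCK_to_custom; the repacked filter lives in 2MB-aligned libxsmm
  // scratch.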
  int blocksifm =
      desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1;
  int blocksofm =
      desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1;
  float* native_filter =
      (float*)libxsmm_aligned_scratch(blocksofm * blocksifm * desc.R * desc.S *
                                          ifmblock * ofmblock * sizeof(float),
                                      2097152);

  const DeviceBase::CpuWorkerThreads* worker_threads =
      ctx->device()->tensorflow_cpu_worker_threads();

  int num_threads = worker_threads->num_threads;

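  // For forward and backward-data passes, repack the filter from RSCK into
  // libxsmm's blocked layout, splitting the output-feature blocks across the
  // worker threads.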
#if 1
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
      kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    if (blocksofm > num_threads) {
      int work = blocksofm;
      BlockingCounter count(num_threads);
      for (int i = 0; i < num_threads; ++i) {
        worker_threads->workers->Schedule([=, &count]() {
          int chunk = work / num_threads;
          int start = chunk * i;
          // The last thread takes the remainder so every block is copied even
          // when work is not a multiple of num_threads.
          int end = (i == num_threads - 1) ? work : start + chunk;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                              desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                              start, end);
          count.DecrementCount();
        });
      }
      count.Wait();
    } else {
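      // Fewer output-feature blocks than pool threads: schedule one block per
      // task (note: the inner num_threads shadows the outer one).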
      int work = blocksofm;
      int num_threads = work;

      BlockingCounter count(num_threads);
      for (int i = 0; i < num_threads; ++i) {
        worker_threads->workers->Schedule([=, &count]() {
          int start = i;
          int end = i + 1;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                              desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                              start, end);
          count.DecrementCount();
        });
      }
      count.Wait();
    }
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    // For weight update, link the TensorFlow filter tensor directly: the
    // gradient has to be returned to TensorFlow in RSCK format.
    libxsmm_filter =
        libxsmm_dnn_link_filter(libxsmm_handle, LIBXSMM_DNN_FILTER, filter,
                                LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR, &status);
    chk_libxsmm_err(status, "Link filter");
  }
#else
  memset(native_filter, 0,
         blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock *
             sizeof(float));
#endif

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick3 = libxsmm_timer_tick();
#endif

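  // Link the TensorFlow-owned NHWC tensors into the handle; the _PTR tensor
  // formats wrap the existing pointers instead of copying the data.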
  libxsmm_input =
      libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_INPUT, input,
                              LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
  chk_libxsmm_err(status, "Link input buffer");
  libxsmm_output =
      libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_OUTPUT, output,
                              LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
  chk_libxsmm_err(status, "Link output buffer");
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
      kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    libxsmm_filter = libxsmm_dnn_link_filter(
        libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter,
        LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status);
    chk_libxsmm_err(status, "Link filter");
  }
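  // Bind the tensors to the handle in the roles the compute kind expects:
  // regular tensors are inputs to the pass, gradient tensors are what the
  // pass produces or consumes.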
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
                                            LIBXSMM_DNN_REGULAR_INPUT),
                    "Bind input forward");
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
                                            LIBXSMM_DNN_REGULAR_OUTPUT),
                    "Bind output forward");
    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
                                            LIBXSMM_DNN_REGULAR_FILTER),
                    "Bind filter forward");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_input), "Zero input");

    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
                                            LIBXSMM_DNN_GRADIENT_INPUT),
                    "Bind input backward");
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
                                            LIBXSMM_DNN_GRADIENT_OUTPUT),
                    "Bind output backward");
    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
                                            LIBXSMM_DNN_REGULAR_FILTER),
                    "Bind filter backward");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    chk_libxsmm_err(libxsmm_dnn_zero_filter(libxsmm_filter), "Zero filter");

    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
                                            LIBXSMM_DNN_REGULAR_INPUT),
                    "Bind input weight update");
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
                                            LIBXSMM_DNN_GRADIENT_OUTPUT),
                    "Bind output weight update");
    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
                                            LIBXSMM_DNN_GRADIENT_FILTER),
                    "Bind filter weight update");
  } else {
    /* shouldn't happen */
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick4 = libxsmm_timer_tick();
#endif

  /* bind scratch */
  scratch = (void*)libxsmm_aligned_scratch(
      libxsmm_dnn_get_scratch_size(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL,
                                   &status),
      2097152);
  chk_libxsmm_err(status, "scratch allocation");
  chk_libxsmm_err(libxsmm_dnn_bind_scratch(
                      libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
                  "binding scratch");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick5 = libxsmm_timer_tick();
#endif

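  // The backward-data pass works on a transposed copy of the filter, which
  // libxsmm_dnn_transpose_filter appears to prepare inside the handle.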
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    libxsmm_dnn_transpose_filter(libxsmm_handle, LIBXSMM_DNN_FILTER);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick6 = libxsmm_timer_tick();
#endif

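  // Execute the layer: each pool thread calls the single-threaded entry point
  // libxsmm_dnn_execute_st with its own thread id, and the BlockingCounter
  // joins all workers before continuing.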
  BlockingCounter counter(num_threads);

  for (int i = 0; i < num_threads; ++i) {
    worker_threads->workers->Schedule([=, &counter]() {
      chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, i),
                      "Worker");
      counter.DecrementCount();
    });
  }
  counter.Wait();

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick7 = libxsmm_timer_tick();
#endif

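  // Weight update accumulates per-thread partial filter gradients; reduce
  // them into the gradient filter before handing the result back.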
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    libxsmm_dnn_reduce_wu_filters(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick8 = libxsmm_timer_tick();
#endif

  /* clean up: release the tensors from the handle and destroy the per-call
     wrappers; the cached handle itself is kept for reuse */
  chk_libxsmm_err(
      libxsmm_dnn_release_scratch(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL),
      "release scratch");
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT),
        "release input");
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT),
        "release output");
    chk_libxsmm_err(
        libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER),
        "release filter");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT),
        "release input");
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
        "release output");
    chk_libxsmm_err(
        libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER),
        "release filter");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT),
        "release input");
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
        "release output");
    chk_libxsmm_err(
        libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER),
        "release filter");
  } else {
    /* shouldn't happen */
  }
  chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input");
  chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output");
  chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick9 = libxsmm_timer_tick();
#endif

  // if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD)
  // chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
  //               "Destroy handle");

  libxsmm_free(native_filter);
  libxsmm_free(scratch);

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick10 = libxsmm_timer_tick();
  printf(
      "time for convolution (%i, %i, %i, %i, %i): %f, %f, %f, %f, %f, %f, %f, "
      "%f, %f, %f\n",
      desc.N, desc.C, desc.K, desc.R, desc.S,
      libxsmm_timer_duration(l_tick1, l_tick2),
      libxsmm_timer_duration(l_tick2, l_tick3),
      libxsmm_timer_duration(l_tick3, l_tick4),
      libxsmm_timer_duration(l_tick4, l_tick5),
      libxsmm_timer_duration(l_tick5, l_tick6),
      libxsmm_timer_duration(l_tick6, l_tick7),
      libxsmm_timer_duration(l_tick7, l_tick8),
      libxsmm_timer_duration(l_tick8, l_tick9),
      libxsmm_timer_duration(l_tick9, l_tick10),
      libxsmm_timer_duration(l_tick1, l_tick10));
#endif

  return true;  // Succeeded
}

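// Functor specializations that dispatch the three convolution flavors
// (forward, backward-data, weight update) to the generic driver above.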
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
template <typename T>
struct XsmmFwdConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  const T* input, const T* filter, T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD,
                                  input, filter, output);
  }
};
#endif

#ifdef TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS
template <typename T>
struct XsmmBkwInputConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  T* input, const T* filter, const T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD,
                                  input, filter, output);
  }
};

template <typename T>
struct XsmmBkwFilterConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  const T* input, T* filter, const T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD,
                                  input, filter, output);
  }
};
#endif

}  // namespace functor

template struct functor::XsmmFwdConv2D<CPUDevice, float>;
template struct functor::XsmmBkwInputConv2D<CPUDevice, float>;
template struct functor::XsmmBkwFilterConv2D<CPUDevice, float>;

}  // namespace tensorflow

#endif  // TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS