1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifdef INTEL_MKL
16 #include "mkldnn.hpp"
17 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/framework/register_types.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_types.h"
22 #include "tensorflow/core/kernels/fused_batch_norm_op.h"
23 #include "tensorflow/core/kernels/no_op.h"
24 #include "tensorflow/core/util/mkl_util.h"
25 #include "tensorflow/core/util/tensor_format.h"
26 
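// GET_FLAG converts an mkldnn::normalization_flags enumerator to its integer
// value; IS_SET tests whether that flag bit is set in the cached primitive
// context (context_.flags) of the enclosing class.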
27 #define GET_FLAG(bn_flag) static_cast<int>(mkldnn::normalization_flags::bn_flag)
28 #define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))
29 
30 using mkldnn::batch_normalization_backward;
31 using mkldnn::batch_normalization_forward;
32 using mkldnn::prop_kind;
33 using mkldnn::stream;
34 
35 using BatchNormFwdPd = mkldnn::batch_normalization_forward::primitive_desc;
36 using BatchNormBwdPd = mkldnn::batch_normalization_backward::primitive_desc;
37 
38 namespace tensorflow {
39 using CPUDevice = Eigen::ThreadPoolDevice;
40 
41 using FusedBNActivationMode = functor::FusedBatchNormActivationMode;
42 
43 struct MklBatchNormFwdParams {
44   memory::dims src_dims;
45   int depth;
46   float eps;
47   bool training;
48   TensorFormat data_format;
49   FusedBNActivationMode activation_mode;
50   memory::desc src_md;
51 
52   MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps,
53                         bool training, TensorFormat data_format,
54                         memory::desc src_md,
55                         FusedBNActivationMode activation_mode)
56       : src_dims(src_dims),
57         depth(depth),
58         eps(eps),
59         training(training),
60         data_format(data_format),
61         activation_mode(activation_mode),
62         src_md(src_md) {}
63 };
64 
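// Wraps a oneDNN (MKL-DNN) batch-normalization forward primitive together
// with the memory objects it reads and writes, so the primitive can be cached
// and reused across invocations with the same parameters. T is the
// input/output data type; U is the type of scale/shift/mean/variance.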
65 template <typename T, typename U>
66 class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
67  public:
68   explicit MklFusedBatchNormFwdPrimitive(const MklBatchNormFwdParams& fwdParams)
69       : MklPrimitive(engine(engine::kind::cpu, 0)) {
70     if (context_.bn_fwd == nullptr) Setup(fwdParams);
71   }
72 
73   ~MklFusedBatchNormFwdPrimitive() {}
74 
75   // BatchNormalization forward execute
76   //   src_data:     input data buffer of src
77   //   weights_data: input data buffer of weights
78   //   dst_data:     output data buffer of dst
79   //   mean_data:     output data buffer of means
80   //   variance_data: output data buffer of variances
81   void Execute(const T* src_data, const U* weights_data, T* dst_data,
82                U* mean_data, U* variance_data,
83                std::shared_ptr<stream> fwd_stream, U* workspace_data) {
84 #ifndef ENABLE_ONEDNN_OPENMP
85     // TODO: Create a common function and avoid the duplicate code
86     context_.src_mem->set_data_handle(
87         static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
88     context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
89                                       *fwd_stream);
90 
91     if (IS_SET(use_scale_shift))
92       context_.weights_mem->set_data_handle(
93           static_cast<void*>(const_cast<U*>(weights_data)), *fwd_stream);
94 
95     if ((context_.pkind == prop_kind::forward_training) ||
96         (IS_SET(use_global_stats))) {
97       context_.mean_mem->set_data_handle(static_cast<void*>(mean_data),
98                                          *fwd_stream);
99       context_.variance_mem->set_data_handle(static_cast<void*>(variance_data),
100                                              *fwd_stream);
101     }
102     if (workspace_data != nullptr) {
103       context_.ws_mem->set_data_handle(workspace_data, *fwd_stream);
104     }
105 #else
106     context_.src_mem->set_data_handle(
107         static_cast<void*>(const_cast<T*>(src_data)));
108     context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
109 
110     if (IS_SET(use_scale_shift))
111       context_.weights_mem->set_data_handle(
112           static_cast<void*>(const_cast<U*>(weights_data)));
113 
114     if ((context_.pkind == prop_kind::forward_training) ||
115         (IS_SET(use_global_stats))) {
116       context_.mean_mem->set_data_handle(static_cast<void*>(mean_data));
117       context_.variance_mem->set_data_handle(static_cast<void*>(variance_data));
118     }
119     if (workspace_data != nullptr) {
120       context_.ws_mem->set_data_handle(workspace_data);
121     }
122 #endif  // !ENABLE_ONEDNN_OPENMP
123 
124     // Execute batch-normalization forward primitives.
125     execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
126 
127     context_.src_mem->set_data_handle(DummyData);
128     context_.dst_mem->set_data_handle(DummyData);
129 
130     if (IS_SET(use_scale_shift))
131       context_.weights_mem->set_data_handle(DummyData);
132 
133     if ((context_.pkind == prop_kind::forward_training) ||
134         (IS_SET(use_global_stats))) {
135       context_.mean_mem->set_data_handle(DummyData);
136       context_.variance_mem->set_data_handle(DummyData);
137     }
138 
139     if (workspace_data != nullptr) {
140       context_.ws_mem->set_data_handle(DummyData);
141     }
142   }
143 
144   memory::desc GetDstPd() const { return context_.dst_mem->get_desc(); }
145 
146   std::shared_ptr<BatchNormFwdPd> GetBatchNormFwdPd() const {
147     return context_.fwd_pd;
148   }
149 
150  private:
151   // Primitive reuse context for BatchNorm forward op.
152   struct BatchNormFwdContext {
153     // Flags indicating whether it is training or inference mode.
154     int64 flags;
155 
156     // Algorithm kind.
157     mkldnn::prop_kind pkind;
158 
159     // Inputs/outputs memory.
160     std::shared_ptr<mkldnn::memory> src_mem;
161     std::shared_ptr<mkldnn::memory> weights_mem;
162     std::shared_ptr<mkldnn::memory> dst_mem;
163     std::shared_ptr<mkldnn::memory> mean_mem;
164     std::shared_ptr<mkldnn::memory> variance_mem;
165     std::shared_ptr<mkldnn::memory> ws_mem;
166 
167     // Forward BatchNorm primitive descriptor.
168     std::shared_ptr<BatchNormFwdPd> fwd_pd;
169 
170     // BatchNorm forward primitive.
171     std::shared_ptr<mkldnn::primitive> bn_fwd;
172     std::vector<mkldnn::primitive> fwd_primitives;
173 
174     std::vector<std::unordered_map<int, memory>> net_args;
175 
176     BatchNormFwdContext()
177         : flags(0),
178           pkind(prop_kind::forward_training),
179           src_mem(nullptr),
180           weights_mem(nullptr),
181           dst_mem(nullptr),
182           mean_mem(nullptr),
183           variance_mem(nullptr),
184           ws_mem(nullptr),
185           bn_fwd(nullptr) {}
186   };
187 
188   void Setup(const MklBatchNormFwdParams& fwdParams) {
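    // In training mode only use_scale_shift is set; in inference mode
    // use_global_stats is also set so the primitive consumes the estimated
    // mean and variance instead of computing batch statistics.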
189     context_.flags =
190         fwdParams.training
191             ? GET_FLAG(use_scale_shift)
192             : (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats));
193     context_.pkind = fwdParams.training ? prop_kind::forward_training
194                                         : prop_kind::forward_scoring;
195 
196     if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
197       context_.flags |= GET_FLAG(fuse_norm_relu);
198     }
199     // Memory descriptor
200     auto src_md = fwdParams.src_md;
201     // Create forward BatchNorm descriptor and primitive descriptor.
202     auto fwd_desc = batch_normalization_forward::desc(
203         context_.pkind, src_md, fwdParams.eps,
204         static_cast<mkldnn::normalization_flags>(context_.flags));
205 
206     context_.fwd_pd.reset(new BatchNormFwdPd(fwd_desc, cpu_engine_));
207 
208     // Create memory primitive based on dummy data
209     context_.src_mem.reset(
210         new memory(context_.fwd_pd->src_desc(), cpu_engine_, DummyData));
211     context_.dst_mem.reset(
212         new memory(context_.fwd_pd->dst_desc(), cpu_engine_, DummyData));
213 
214     memory::dims s_dims = {2, fwdParams.depth};
215     memory::dims m_dims = {1, fwdParams.depth};
216     if (IS_SET(use_scale_shift)) {
217       context_.weights_mem.reset(
218           new memory({{s_dims}, MklDnnType<U>(), memory::format_tag::nc},
219                      cpu_engine_, DummyData));
220     }
221 
222     if (fwdParams.training || (IS_SET(use_global_stats))) {
223       context_.mean_mem.reset(
224           new memory({{m_dims}, MklDnnType<U>(), memory::format_tag::nc},
225                      cpu_engine_, DummyData));
226 
227       context_.variance_mem.reset(
228           new memory({{m_dims}, MklDnnType<U>(), memory::format_tag::nc},
229                      cpu_engine_, DummyData));
230     }
231 
232     if (IS_SET(fuse_norm_relu)) {
233       context_.ws_mem.reset(new memory(context_.fwd_pd->workspace_desc(),
234                                        cpu_engine_, DummyData));
235     }
236 
237     // BatchNorm forward primitive.
238     // TODO(intel-tf): Merge all the #ifdefs and simplify code
239     if (!fwdParams.training && !(IS_SET(use_global_stats))) {
240       if ((IS_SET(use_scale_shift)) && mkldnn_use_scaleshift) {
241         context_.net_args.push_back(
242             {{MKLDNN_ARG_SRC, *context_.src_mem},
243              {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
244              {MKLDNN_ARG_DST, *context_.dst_mem}});
245       } else {
246         context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
247                                      {MKLDNN_ARG_DST, *context_.dst_mem}});
248       }
249       context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
250     } else if (IS_SET(use_global_stats)) {
251       if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) {
252         if (IS_SET(fuse_norm_relu)) {
253           context_.net_args.push_back(
254               {{MKLDNN_ARG_SRC, *context_.src_mem},
255                {MKLDNN_ARG_MEAN, *context_.mean_mem},
256                {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
257                {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
258                {MKLDNN_ARG_DST, *context_.dst_mem},
259                {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
260         } else {
261           context_.net_args.push_back(
262               {{MKLDNN_ARG_SRC, *context_.src_mem},
263                {MKLDNN_ARG_MEAN, *context_.mean_mem},
264                {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
265                {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
266                {MKLDNN_ARG_DST, *context_.dst_mem}});
267         }
268       } else {
269         if (IS_SET(fuse_norm_relu)) {
270           context_.net_args.push_back(
271               {{MKLDNN_ARG_SRC, *context_.src_mem},
272                {MKLDNN_ARG_MEAN, *context_.mean_mem},
273                {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
274                {MKLDNN_ARG_DST, *context_.dst_mem},
275                {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
276         } else {
277           context_.net_args.push_back(
278               {{MKLDNN_ARG_SRC, *context_.src_mem},
279                {MKLDNN_ARG_MEAN, *context_.mean_mem},
280                {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
281                {MKLDNN_ARG_DST, *context_.dst_mem}});
282         }
283       }
284       context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
285     } else {
286       if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) {
287         if (IS_SET(fuse_norm_relu)) {
288           context_.net_args.push_back(
289               {{MKLDNN_ARG_SRC, *context_.src_mem},
290                {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
291                {MKLDNN_ARG_DST, *context_.dst_mem},
292                {MKLDNN_ARG_MEAN, *context_.mean_mem},
293                {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
294                {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
295         } else {
296           context_.net_args.push_back(
297               {{MKLDNN_ARG_SRC, *context_.src_mem},
298                {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
299                {MKLDNN_ARG_DST, *context_.dst_mem},
300                {MKLDNN_ARG_MEAN, *context_.mean_mem},
301                {MKLDNN_ARG_VARIANCE, *context_.variance_mem}});
302         }
303       } else {
304         if (IS_SET(fuse_norm_relu)) {
305           context_.net_args.push_back(
306               {{MKLDNN_ARG_SRC, *context_.src_mem},
307                {MKLDNN_ARG_DST, *context_.dst_mem},
308                {MKLDNN_ARG_MEAN, *context_.mean_mem},
309                {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
310                {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
311         } else {
312           context_.net_args.push_back(
313               {{MKLDNN_ARG_SRC, *context_.src_mem},
314                {MKLDNN_ARG_DST, *context_.dst_mem},
315                {MKLDNN_ARG_MEAN, *context_.mean_mem},
316                {MKLDNN_ARG_VARIANCE, *context_.variance_mem}});
317         }
318       }
319       context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
320     }
321 
322     context_.fwd_primitives.push_back(*context_.bn_fwd);
323   }
324 
325   struct BatchNormFwdContext context_;
326 };
327 
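// Factory that caches MklFusedBatchNormFwdPrimitive instances in a
// process-wide pool, keyed by the fields of MklBatchNormFwdParams plus the
// data types, so identical configurations reuse the same oneDNN primitive.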
328 template <typename T, typename U>
329 class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
330  public:
331   static MklFusedBatchNormFwdPrimitive<T, U>* Get(
332       const MklBatchNormFwdParams& fwdParams) {
333     auto bn_fwd = static_cast<MklFusedBatchNormFwdPrimitive<T, U>*>(
334         MklFusedBatchNormFwdPrimitiveFactory<T, U>::GetInstance()
335             .GetBatchNormFwd(fwdParams));
336 
337     if (bn_fwd == nullptr) {
338       bn_fwd = new MklFusedBatchNormFwdPrimitive<T, U>(fwdParams);
339       MklFusedBatchNormFwdPrimitiveFactory<T, U>::GetInstance().SetBatchNormFwd(
340           fwdParams, bn_fwd);
341     }
342     return bn_fwd;
343   }
344 
345   static MklFusedBatchNormFwdPrimitiveFactory& GetInstance() {
346     static MklFusedBatchNormFwdPrimitiveFactory instance_;
347     return instance_;
348   }
349 
350  private:
351   MklFusedBatchNormFwdPrimitiveFactory() {}
352   ~MklFusedBatchNormFwdPrimitiveFactory() {}
353 
354   static string CreateKey(const MklBatchNormFwdParams& fwdParams) {
355     string prefix = "bn_fwd";
356     FactoryKeyCreator key_creator;
357     key_creator.AddAsKey(prefix);
358     key_creator.AddAsKey(fwdParams.src_dims);
359     key_creator.AddAsKey<int>(fwdParams.depth);
360     key_creator.AddAsKey<float>(fwdParams.eps);
361     key_creator.AddAsKey<bool>(fwdParams.training);
362     key_creator.AddAsKey<TensorFormat>(fwdParams.data_format);
363     key_creator.AddAsKey<FusedBNActivationMode>(fwdParams.activation_mode);
364     key_creator.AddAsKey(typeid(T).name());
365     key_creator.AddAsKey(typeid(U).name());
366     return key_creator.GetKey();
367   }
368 
369   MklPrimitive* GetBatchNormFwd(const MklBatchNormFwdParams& fwdParams) {
370     string key = CreateKey(fwdParams);
371     return this->GetOp(key);
372   }
373 
374   void SetBatchNormFwd(const MklBatchNormFwdParams& fwdParams,
375                        MklPrimitive* op) {
376     string key = CreateKey(fwdParams);
377     this->SetOp(key, op);
378   }
379 };
380 
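// Example (sketch, not part of the original kernel): how a caller might fetch
// and run the cached forward primitive. Buffer names below are placeholders.
//
//   MklBatchNormFwdParams params(src_dims, depth, epsilon, /*training=*/true,
//                                FORMAT_NHWC, src_md,
//                                FusedBNActivationMode::kIdentity);
//   auto* bn_fwd = MklFusedBatchNormFwdPrimitiveFactory<float, float>::Get(params);
//   bn_fwd->Execute(src_data, weights_data, dst_data, mean_data, variance_data,
//                   fwd_stream, /*workspace_data=*/nullptr);
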
381 struct MklBatchNormBwdParams {
382   memory::dims src_dims;
383   memory::dims diff_dst_dims;
384   int depth;
385   float eps;
386   bool training;
387   TensorFormat data_format;
388   memory::desc src_md;
389   memory::desc diff_dst_md;
390 
391   MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims,
392                         int depth, float eps, bool training,
393                         TensorFormat data_format, memory::desc src_md,
394                         memory::desc diff_dst_md)
395       : src_dims(src_dims),
396         diff_dst_dims(diff_dst_dims),
397         depth(depth),
398         eps(eps),
399         training(training),
400         data_format(data_format),
401         src_md(src_md),
402         diff_dst_md(diff_dst_md) {}
403 };
404 
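// Wraps the oneDNN batch-normalization backward primitive and its memory
// objects; computes diff_src and the packed diff_weights (diff_scale followed
// by diff_shift) from src, diff_dst, and the saved mean/variance.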
405 template <typename T, typename U>
406 class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
407  public:
408   explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams)
409       : MklPrimitive(engine(engine::kind::cpu, 0)) {
410     if (context_.bn_bwd == nullptr) Setup(bwdParams);
411   }
412 
413   ~MklFusedBatchNormBwdPrimitive() {}
414 
415   // BatchNormalization backward execute
416   //   src_data:       input data buffer of src
417   //   mean_data:      input data buffer of mean
418   //   variance_data:  input data buffer of variance
419   //   diff_dst_data:  input data buffer of diff_dst
420   //   weights_data:   input data buffer of weights
421   //   diff_src_data:      output data buffer of diff_src
422   //   diff_weights_data:  output data buffer of diff_weights
423   //   res_space_data:     output data buffer of reserved_space_3.
424   //                       TODO: reserved_space_3 (temporary memory to hold
425   //                          intermediate results) is not yet implemented
426   //                          on the CPU.
427   void Execute(const T* src_data, const U* mean_data, const U* variance_data,
428                const T* diff_dst_data, const U* weights_data, T* diff_src_data,
429                U* diff_weights_data, U* res_space_data,
430                std::shared_ptr<stream> bwd_stream) {
431 #ifndef ENABLE_ONEDNN_OPENMP
432     // TODO: Create a common function and avoid the duplicate code
433     context_.src_mem->set_data_handle(
434         static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
435     context_.mean_mem->set_data_handle(
436         static_cast<void*>(const_cast<U*>(mean_data)), *bwd_stream);
437     context_.variance_mem->set_data_handle(
438         static_cast<void*>(const_cast<U*>(variance_data)), *bwd_stream);
439     context_.diff_dst_mem->set_data_handle(
440         static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);
441 
442     if (IS_SET(use_scale_shift)) {
443       context_.weights_mem->set_data_handle(
444           static_cast<void*>(const_cast<U*>(weights_data)), *bwd_stream);
445       context_.diff_weights_mem->set_data_handle(
446           static_cast<void*>(diff_weights_data), *bwd_stream);
447     }
448 
449     context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
450                                            *bwd_stream);
451 #else
452     context_.src_mem->set_data_handle(
453         static_cast<void*>(const_cast<T*>(src_data)));
454     context_.mean_mem->set_data_handle(
455         static_cast<void*>(const_cast<U*>(mean_data)));
456     context_.variance_mem->set_data_handle(
457         static_cast<void*>(const_cast<U*>(variance_data)));
458     context_.diff_dst_mem->set_data_handle(
459         static_cast<void*>(const_cast<T*>(diff_dst_data)));
460 
461     if (IS_SET(use_scale_shift)) {
462       context_.weights_mem->set_data_handle(
463           static_cast<void*>(const_cast<U*>(weights_data)));
464       context_.diff_weights_mem->set_data_handle(
465           static_cast<void*>(diff_weights_data));
466     }
467 
468     context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
469 #endif  // !ENABLE_ONEDNN_OPENMP
470     // Execute backward batch-normalization primitives.
471     DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());
472     execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);
473 
474     // After execution, set data handle back to DummyData.
475     context_.src_mem->set_data_handle(DummyData);
476     context_.mean_mem->set_data_handle(DummyData);
477     context_.variance_mem->set_data_handle(DummyData);
478     context_.diff_dst_mem->set_data_handle(DummyData);
479     if (IS_SET(use_scale_shift)) {
480       context_.weights_mem->set_data_handle(DummyData);
481       context_.diff_weights_mem->set_data_handle(DummyData);
482     }
483     context_.diff_src_mem->set_data_handle(DummyData);
484   }
485 
486   std::shared_ptr<BatchNormBwdPd> GetBatchNormBwdPd() const {
487     return context_.bwd_pd;
488   }
489 
490   memory::desc GetDiffSrcPd() { return context_.diff_src_mem->get_desc(); }
491 
492  private:
493   struct BatchNormBwdContext {
494     // Flags to indicate whether it is training or inference.
495     int64 flags;
496 
497     // Inputs/output memory.
498     std::shared_ptr<mkldnn::memory> src_mem;
499     std::shared_ptr<mkldnn::memory> mean_mem;
500     std::shared_ptr<mkldnn::memory> variance_mem;
501     std::shared_ptr<mkldnn::memory> diff_dst_mem;
502     std::shared_ptr<mkldnn::memory> weights_mem;
503     std::shared_ptr<mkldnn::memory> diff_weights_mem;
504     std::shared_ptr<mkldnn::memory> diff_src_mem;
505 
506     // Backward batch-normalization primitive descriptor.
507     std::shared_ptr<BatchNormBwdPd> bwd_pd;
508 
509     // Backward batch-normalization primitive.
510     std::shared_ptr<mkldnn::primitive> bn_bwd;
511     std::vector<mkldnn::primitive> bwd_primitives;
512 
513     std::vector<std::unordered_map<int, memory>> net_args;
514 
515     BatchNormBwdContext()
516         : src_mem(nullptr),
517           mean_mem(nullptr),
518           variance_mem(nullptr),
519           diff_dst_mem(nullptr),
520           weights_mem(nullptr),
521           diff_weights_mem(nullptr),
522           diff_src_mem(nullptr) {}
523   };
524 
525   void Setup(const MklBatchNormBwdParams& bwdParams) {
526     context_.flags =
527         bwdParams.training
528             ? GET_FLAG(use_scale_shift)
529             : (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats));
530 
531     // Memory descriptors.
532     auto src_md = bwdParams.src_md;
533     auto diff_dst_md = bwdParams.diff_dst_md;
534     auto variance_desc = memory::desc({1, bwdParams.depth}, MklDnnType<U>(),
535                                       memory::format_tag::nc);
536     auto mean_desc = memory::desc({1, bwdParams.depth}, MklDnnType<U>(),
537                                   memory::format_tag::nc);
538     auto weights_desc = memory::desc({2, bwdParams.depth}, MklDnnType<U>(),
539                                      memory::format_tag::nc);
540     auto diff_weights_desc = weights_desc;
541 
542     // Forward batch-normalization descriptor and primitive descriptor.
543     // Adding this back due to the type difference with context_.flags.
544     auto bn_flags = bwdParams.training
545                         ? mkldnn::normalization_flags::use_scale_shift
546                         : (mkldnn::normalization_flags::use_scale_shift |
547                            mkldnn::normalization_flags::use_global_stats);
548     auto fwd_desc = batch_normalization_forward::desc(
549         prop_kind::forward_training, src_md, bwdParams.eps, bn_flags);
550     auto fwd_pd = BatchNormFwdPd(fwd_desc, cpu_engine_);
551 
552     // Backward batch-normalization primitive.
553     // For inference, specify use_global_stats:
554     //   1. on fwd propagation, use the mean and variance provided as inputs;
555     //   2. on bwd propagation, mean and variance are treated as constants,
556     //      which reduces the amount of MKL computation.
557     auto bwd_desc = batch_normalization_backward::desc(
558         prop_kind::backward, diff_dst_md, src_md, bwdParams.eps, bn_flags);
559     context_.bwd_pd.reset(new BatchNormBwdPd(bwd_desc, cpu_engine_, fwd_pd));
560 
561     // Create memory primitives.
562     context_.src_mem.reset(new memory(src_md, cpu_engine_, DummyData));
563     context_.diff_dst_mem.reset(
564         new memory(diff_dst_md, cpu_engine_, DummyData));
565     context_.variance_mem.reset(
566         new memory(variance_desc, cpu_engine_, DummyData));
567     context_.mean_mem.reset(new memory(mean_desc, cpu_engine_, DummyData));
568     context_.weights_mem.reset(
569         new memory(weights_desc, cpu_engine_, DummyData));
570     context_.diff_weights_mem.reset(
571         new memory(diff_weights_desc, cpu_engine_, DummyData));
572     context_.diff_src_mem.reset(new memory(src_md, cpu_engine_, DummyData));
573 
574     context_.bn_bwd.reset(new batch_normalization_backward(*context_.bwd_pd));
575     context_.net_args.push_back(
576         {{MKLDNN_ARG_SRC, *context_.src_mem},
577          {MKLDNN_ARG_MEAN, *context_.mean_mem},
578          {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
579          {MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem},
580          {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
581          {MKLDNN_ARG_DIFF_SRC, *context_.diff_src_mem},
582          {MKLDNN_ARG_DIFF_WEIGHTS, *context_.diff_weights_mem}});
583     context_.bwd_primitives.push_back(*context_.bn_bwd);
584   }
585 
586   struct BatchNormBwdContext context_;
587 };
588 
589 template <typename T, typename U>
590 class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
591  public:
592   static MklFusedBatchNormBwdPrimitive<T, U>* Get(
593       const MklBatchNormBwdParams& bwdParams) {
594     auto bn_bwd = static_cast<MklFusedBatchNormBwdPrimitive<T, U>*>(
595         MklFusedBatchNormBwdPrimitiveFactory<T, U>::GetInstance()
596             .GetBatchNormBwd(bwdParams));
597     if (bn_bwd == nullptr) {
598       bn_bwd = new MklFusedBatchNormBwdPrimitive<T, U>(bwdParams);
599       MklFusedBatchNormBwdPrimitiveFactory<T, U>::GetInstance().SetBatchNormBwd(
600           bwdParams, bn_bwd);
601     }
602     return bn_bwd;
603   }
604 
605   static MklFusedBatchNormBwdPrimitiveFactory& GetInstance() {
606     static MklFusedBatchNormBwdPrimitiveFactory instance_;
607     return instance_;
608   }
609 
610  private:
611   MklFusedBatchNormBwdPrimitiveFactory() {}
612   ~MklFusedBatchNormBwdPrimitiveFactory() {}
613 
614   static string CreateKey(const MklBatchNormBwdParams& bwdParams) {
615     string prefix = "bn_bwd";
616     FactoryKeyCreator key_creator;
617     key_creator.AddAsKey(prefix);
618     key_creator.AddAsKey(bwdParams.src_dims);
619     key_creator.AddAsKey(bwdParams.diff_dst_dims);
620     key_creator.AddAsKey<int>(bwdParams.depth);
621     key_creator.AddAsKey<float>(bwdParams.eps);
622     key_creator.AddAsKey<bool>(bwdParams.training);
623     key_creator.AddAsKey<TensorFormat>(bwdParams.data_format);
624     key_creator.AddAsKey(typeid(T).name());
625     key_creator.AddAsKey(typeid(U).name());
626     return key_creator.GetKey();
627   }
628 
629   MklPrimitive* GetBatchNormBwd(const MklBatchNormBwdParams& bwdParams) {
630     string key = CreateKey(bwdParams);
631     return this->GetOp(key);
632   }
633 
634   void SetBatchNormBwd(const MklBatchNormBwdParams& bwdParams,
635                        MklPrimitive* op) {
636     string key = CreateKey(bwdParams);
637     this->SetOp(key, op);
638   }
639 };
640 
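// Example (sketch, not part of the original kernel): fetching and running the
// cached backward primitive; buffer names are placeholders.
//
//   MklBatchNormBwdParams params(src_dims, diff_dst_dims, depth, epsilon,
//                                /*training=*/true, FORMAT_NHWC, src_md,
//                                diff_dst_md);
//   auto* bn_bwd = MklFusedBatchNormBwdPrimitiveFactory<float, float>::Get(params);
//   bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data,
//                   weights_data, diff_src_data, diff_weights_data,
//                   /*res_space_data=*/nullptr, bwd_stream);
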
641 //  Adding a third parameter to the template to support FusedBatchNormV3
642 //  with MKL. This differs from the default implementation, where the classes
643 //  are derived; it moves the decision to compile time rather than runtime.
644 template <typename Device, typename T, typename U, bool reserved_space,
645           bool is_batch_norm_ex = false, bool native_format = false>
646 class MklFusedBatchNormOp : public OpKernel {
647  public:
648   explicit MklFusedBatchNormOp(OpKernelConstruction* context)
649       : OpKernel(context) {
650     float epsilon;
651     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
652     epsilon_ = epsilon;
653     float exponential_avg_factor;
654     OP_REQUIRES_OK(context, context->GetAttr("exponential_avg_factor",
655                                              &exponential_avg_factor));
656     exponential_avg_factor_ = static_cast<U>(exponential_avg_factor);
657     string tensor_format;
658     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
659     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
660                 errors::InvalidArgument("Invalid data format"));
661     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
662     depth_ = 0;
663     mean_values_ = nullptr;
664     variance_values_ = nullptr;
665 
666     if (!is_batch_norm_ex) {
667       activation_mode_ = FusedBNActivationMode::kIdentity;
668     } else {
669       int num_side_inputs;
670       OP_REQUIRES_OK(context,
671                      context->GetAttr("num_side_inputs", &num_side_inputs));
672       // Currently _MKLFusedBatchNormEx does not support "SideInput".
673       OP_REQUIRES(context, num_side_inputs == 0,
674                   errors::InvalidArgument(
675                       "_MKLFusedBatchNorm does not support side input currently."));
676 
677       OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_));
678       OP_REQUIRES(context, activation_mode_ == FusedBNActivationMode::kRelu,
679                   errors::InvalidArgument(
680                       "_MKLFusedBatchNorm only supports Relu activation"));
681     }
682   }
683 
684   void Compute(OpKernelContext* context) override {
685     try {
686       const size_t kSrcIndex = 0;       // index of src input tensor
687       const size_t kScaleIndex = 1;     // index of scale tensor
688       const size_t kShiftIndex = 2;     // index of shift tensor
689       const size_t kMeanIndex = 3;      // index of est_mean tensor
690       const size_t kVarianceIndex = 4;  // index of est_variance tensor
691 
692       const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
693       const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
694       const Tensor& shift_tensor = MklGetInput(context, kShiftIndex);
695       const Tensor& est_mean_tensor = MklGetInput(context, kMeanIndex);
696       const Tensor& est_variance_tensor = MklGetInput(context, kVarianceIndex);
697 
698       TensorShape tf_shape_src;
699       MklDnnShape dnn_shape_src;
700       GetMklShape(context, kSrcIndex, &dnn_shape_src, native_format);
701 
702       if (dnn_shape_src.IsMklTensor()) {
703         tf_shape_src = dnn_shape_src.GetTfShape();
704         OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
705                     errors::InvalidArgument("input must be 4-dimensional",
706                                             src_tensor.shape().DebugString()));
707       } else {
708         tf_shape_src = src_tensor.shape();
709         OP_REQUIRES(context, src_tensor.dims() == 4,
710                     errors::InvalidArgument("input must be 4-dimensional",
711                                             src_tensor.shape().DebugString()));
712       }
713       OP_REQUIRES(context, scale_tensor.dims() == 1,
714                   errors::InvalidArgument("scale must be 1-dimensional",
715                                           scale_tensor.shape().DebugString()));
716       OP_REQUIRES(context, shift_tensor.dims() == 1,
717                   errors::InvalidArgument("offset must be 1-dimensional",
718                                           shift_tensor.shape().DebugString()));
719       OP_REQUIRES(
720           context, est_mean_tensor.dims() == 1,
721           errors::InvalidArgument("estimated_mean must be 1-dimensional",
722                                   est_mean_tensor.shape().DebugString()));
723       OP_REQUIRES(
724           context, est_variance_tensor.dims() == 1,
725           errors::InvalidArgument("estimated_variance must be 1-dimensional",
726                                   est_variance_tensor.shape().DebugString()));
727 
728       // Handle the special case: input with 0 element and 0 batch size.
729       Tensor* dst_tensor = nullptr;
730       TensorShape workspace_tf_shape;
731       if (tf_shape_src.num_elements() == 0) {
732         size_t workspace_bytes = 0;
733         workspace_tf_shape.AddDim(workspace_bytes);
734         HandleEmptyInput(context, tf_shape_src, workspace_tf_shape,
735                          scale_tensor.shape(), &dst_tensor);
736         return;
737       }
738 
739       if (dnn_shape_src.IsMklTensor())
740         depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
741       else
742         ExtractParams(context);
743 
744       // Index of output tensor (dst).
745       const size_t kDstIndex = 0;
746 
747       // Allocate 5 output TF tensors.
748       Tensor* batch_mean_tensor = nullptr;
749       Tensor* batch_variance_tensor = nullptr;
750       Tensor* saved_mean_tensor = nullptr;
751       Tensor* saved_variance_tensor = nullptr;
752       Tensor* reserved_space_tensor = nullptr;
753 
754       MklDnnData<T> src(&cpu_engine_);
755       MklDnnData<U> weights(&cpu_engine_);
756       MklDnnData<U> wksp(&cpu_engine_);
757 
758       memory::format_tag dnn_fmt;
759       MklTensorFormat mkl_tensor_fmt;
760       if (dnn_shape_src.IsMklTensor()) {
761         if (dnn_shape_src.IsTensorInNCHWFormat()) {
762           dnn_fmt = memory::format_tag::nchw;
763           mkl_tensor_fmt = MklTensorFormat::FORMAT_NCHW;
764         } else {
765           dnn_fmt = memory::format_tag::nhwc;
766           mkl_tensor_fmt = MklTensorFormat::FORMAT_NHWC;
767         }
768       } else {
769         mkl_tensor_fmt = TFDataFormatToMklDnnDataFormat(tensor_format_);
770         dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_tensor_fmt);
771       }
772 
773       // Set src memory descriptor.
774       memory::dims src_dims =
775           dnn_shape_src.IsMklTensor()
776               ? dnn_shape_src.GetSizesAsMklDnnDims()
777               : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
778 
779       auto src_md = dnn_shape_src.IsMklTensor()
780                         ? dnn_shape_src.GetMklLayout()
781                         : memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);
782 
783       MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_,
784                                       tensor_format_, src_md, activation_mode_);
785 
786       // Get forward batch-normalization op from the primitive caching pool.
787       MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd =
788           MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams);
789 
790       // Allocate workspace tensor
791       U* ws_data = nullptr;
792       if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
793         memory::desc workspace_md =
794             bn_fwd->GetBatchNormFwdPd()->workspace_desc();
795         size_t workspace_bytes = workspace_md.get_size();
796         workspace_tf_shape.AddDim(workspace_bytes);
797 
798         AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape,
799                           &batch_mean_tensor, &batch_variance_tensor,
800                           &saved_mean_tensor, &saved_variance_tensor,
801                           &reserved_space_tensor);
802         if (reserved_space) {
803           wksp.SetUsrMem(workspace_md, reserved_space_tensor);
804           ws_data = static_cast<U*>(wksp.GetOpMem().get_data_handle());
805         }
806       } else {
807         // There is actually no workspace tensor out, so we make a dummy one.
808         size_t workspace_bytes = 0;
809         workspace_tf_shape.AddDim(workspace_bytes);
810         AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape,
811                           &batch_mean_tensor, &batch_variance_tensor,
812                           &saved_mean_tensor, &saved_variance_tensor,
813                           &reserved_space_tensor);
814       }
815 
816       if (is_training_)
817         SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor);
818       else
819         SetMeanVariance(est_mean_tensor, est_variance_tensor);
820 
821       // MKL-DNN packs scale & shift as "weights":
822       // <scale>...<scale><shift>...<shift>
823       weights.AllocateBuffer(2 * depth_ * sizeof(U));
824       U* weights_data = reinterpret_cast<U*>(weights.GetAllocatedBuffer());
825       const U* scale_tf = scale_tensor.flat<U>().data();
826       const U* shift_tf = shift_tensor.flat<U>().data();
827 
828       std::memcpy(weights_data, scale_tf, depth_ * sizeof(U));
829       std::memcpy(weights_data + depth_, shift_tf, depth_ * sizeof(U));
830       char* saved_mean_data_tf =
831           reinterpret_cast<char*>(saved_mean_tensor->flat<U>().data());
832       std::memcpy(saved_mean_data_tf, reinterpret_cast<char*>(mean_values_),
833                   depth_ * sizeof(U));
834 
835       char* saved_variance_data_tf =
836           reinterpret_cast<char*>(saved_variance_tensor->flat<U>().data());
837       std::memcpy(saved_variance_data_tf,
838                   reinterpret_cast<char*>(variance_values_),
839                   depth_ * sizeof(U));
840 
841       // Check if reorder is needed for src.
842       const T* src_data = nullptr;
843       std::shared_ptr<BatchNormFwdPd> bn_fwd_pd = bn_fwd->GetBatchNormFwdPd();
844       if (!native_format && src_md != bn_fwd_pd->src_desc()) {
845         src.SetUsrMem(src_md, &src_tensor);
846         src.CheckReorderToOpMem(bn_fwd_pd->src_desc(), cpu_engine_, context);
847         src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
848       } else {
849         src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
850       }
851 
852       // Allocate output (dst) tensor
853       MklDnnShape dnn_shape_dst;
854       TensorShape tf_shape_dst;
855       dnn_shape_dst.SetMklTensor(true);
856       auto dst_pd = bn_fwd->GetDstPd();
857       dnn_shape_dst.SetMklLayout(&dst_pd);
858       dnn_shape_dst.SetElemType(MklDnnType<T>());
859       auto ndims = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetDimension()
860                                                : src_tensor.shape().dims();
861       dnn_shape_dst.SetTfLayout(ndims, src_dims, mkl_tensor_fmt);
862       tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
863       if (native_format) {
864         tf_shape_dst = dnn_shape_dst.GetTfShape();
865       }
866       AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst,
867                                 dnn_shape_dst, native_format);
868 
869       U* weights_op_data = weights_data;
870       U* mean_op_data = saved_mean_tensor->flat<U>().data();
871       U* variance_op_data = saved_variance_tensor->flat<U>().data();
872       T* dst_data = dst_tensor->flat<T>().data();
873 
874       // Execute
875       std::shared_ptr<stream> fwd_cpu_stream;
876       MklDnnThreadPool eigen_tp(context);
877       fwd_cpu_stream.reset(CreateStream(&eigen_tp, bn_fwd->GetEngine()));
878       bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data,
879                       variance_op_data, fwd_cpu_stream, ws_data);
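      // oneDNN returns the biased (population) variance; rescale it by
      // N / (N - 1), where N = batch * height * width, so the reported batch
      // variance matches TensorFlow's unbiased estimate.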
880       float adjust_factor = 1.0;
881       if (is_training_) {
882         size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3];
883         size_t adjust_size = (orig_size > 1) ? (orig_size - 1) : 1;
884         adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
885       }
886 
887       auto mean_data = reinterpret_cast<U*>(saved_mean_data_tf);
888       auto variance_data = reinterpret_cast<U*>(saved_variance_data_tf);
889       auto batch_mean_data = batch_mean_tensor->flat<U>().data();
890       auto batch_variance_data = batch_variance_tensor->flat<U>().data();
891       auto est_mean_data = est_mean_tensor.flat<U>().data();
892       auto est_variance_data = est_variance_tensor.flat<U>().data();
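      // In training mode, batch_mean/batch_variance hold the running
      // statistics: with exponential_avg_factor == 1 they are just the batch
      // statistics; otherwise new = (1 - factor) * old_estimate + factor * batch.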
893       if (is_training_) {
894         if (exponential_avg_factor_ == U(1.0)) {
895           for (int k = 0; k < depth_; k++) {
896             batch_mean_data[k] = mean_data[k];
897             batch_variance_data[k] =
898                 static_cast<U>(adjust_factor) * variance_data[k];
899           }
900         } else {
901           U one_minus_factor = U(1.0) - exponential_avg_factor_;
902           for (int k = 0; k < depth_; k++) {
903             batch_mean_data[k] = one_minus_factor * est_mean_data[k] +
904                                  exponential_avg_factor_ * mean_data[k];
905             batch_variance_data[k] = one_minus_factor * est_variance_data[k] +
906                                      exponential_avg_factor_ *
907                                          static_cast<U>(adjust_factor) *
908                                          variance_data[k];
909           }
910         }
911       } else {
912         std::memcpy(batch_mean_data, mean_data, depth_ * sizeof(U));
913         std::memcpy(batch_variance_data, variance_data, depth_ * sizeof(U));
914       }
915     } catch (mkldnn::error& e) {
916       string error_msg = "Status: " + std::to_string(e.status) +
917                          ", message: " + string(e.message) + ", in file " +
918                          string(__FILE__) + ":" + std::to_string(__LINE__);
919       OP_REQUIRES_OK(
920           context,
921           errors::Aborted("Operation received an exception:", error_msg));
922     }
923   }
924 
925  private:
926   float epsilon_;
927   U exponential_avg_factor_;
928   TensorFormat tensor_format_;
929   bool is_training_;
930   U* mean_values_;
931   U* variance_values_;
932   size_t depth_;  // Batch normalization is performed per channel.
933   FusedBNActivationMode activation_mode_;
934   engine cpu_engine_ = engine(engine::kind::cpu, 0);
935 
936   void ExtractParams(OpKernelContext* context) {
937     const Tensor& input = MklGetInput(context, 0);
938     depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
939   }
940 
941   void SetMeanVariance(const Tensor& mean, const Tensor& variance) {
942     mean_values_ = reinterpret_cast<U*>(const_cast<U*>(mean.flat<U>().data()));
943     variance_values_ =
944         reinterpret_cast<U*>(const_cast<U*>(variance.flat<U>().data()));
945   }
946 
947   void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
948                         TensorShape workspace_tf_shape,
949                         TensorShape tf_shape_scale, Tensor** dst_tensor) {
950     DCHECK(dst_tensor);
951 
952     const size_t kDstIndex = 0;
953     MklDnnShape dnn_shape_dst;
954     dnn_shape_dst.SetMklTensor(false);
955     AllocateOutputSetMklShape(context, kDstIndex, dst_tensor, tf_shape_src,
956                               dnn_shape_dst, native_format);
957     DCHECK(*dst_tensor);
958     memset(const_cast<char*>((*dst_tensor)->tensor_data().data()), 0,
959            (*dst_tensor)->tensor_data().size());
960 
961     Tensor* batch_mean_tensor = nullptr;
962     Tensor* batch_variance_tensor = nullptr;
963     Tensor* saved_mean_tensor = nullptr;
964     Tensor* saved_variance_tensor = nullptr;
965     Tensor* reserved_space_tensor = nullptr;
966     AllocateTFOutputs(context, tf_shape_scale, workspace_tf_shape,
967                       &batch_mean_tensor, &batch_variance_tensor,
968                       &saved_mean_tensor, &saved_variance_tensor,
969                       &reserved_space_tensor);
970   }
971 
972   void AllocateTFOutputs(OpKernelContext* context, TensorShape tf_shape_scale,
973                          TensorShape workspace_tf_shape,
974                          Tensor** batch_mean_tensor,
975                          Tensor** batch_variance_tensor,
976                          Tensor** saved_mean_tensor,
977                          Tensor** saved_variance_tensor,
978                          Tensor** reserved_space_tensor) {
979     DCHECK(batch_mean_tensor);
980     DCHECK(batch_variance_tensor);
981     DCHECK(saved_mean_tensor);
982     DCHECK(saved_variance_tensor);
983 
984     const size_t kBatchMeanIndex = 1;
985     const size_t kBatchVarianceIndex = 2;
986     const size_t kSavedMeanIndex = 3;
987     const size_t kSavedVarianceIndex = 4;
988     const size_t kReservedSpaceIndex = 5;
989 
990     // Allocate batch mean output tensor.
991     MklDnnShape mkl_shape_batch_mean;
992     mkl_shape_batch_mean.SetMklTensor(false);
993     AllocateOutputSetMklShape(context, kBatchMeanIndex, batch_mean_tensor,
994                               tf_shape_scale, mkl_shape_batch_mean,
995                               native_format);
996     DCHECK(*batch_mean_tensor);
997 
998     // Set NAN mean value in case of empty input tensor
999     int num_elements = tf_shape_scale.num_elements();
1000     auto batch_mean_data = (*batch_mean_tensor)->flat<U>().data();
1001     std::fill_n(batch_mean_data, num_elements, static_cast<U>(NAN));
1002 
1003     // Allocate batch variance output tensor.
1004     MklDnnShape mkl_shape_batch_variance;
1005     mkl_shape_batch_variance.SetMklTensor(false);
1006     AllocateOutputSetMklShape(context, kBatchVarianceIndex,
1007                               batch_variance_tensor, tf_shape_scale,
1008                               mkl_shape_batch_variance, native_format);
1009     DCHECK(*batch_variance_tensor);
1010 
1011     // Set NAN variance value in case of empty input tensor
1012     auto batch_variance_data = (*batch_variance_tensor)->flat<U>().data();
1013     std::fill_n(batch_variance_data, num_elements, static_cast<U>(NAN));
1014     // Mean and variance (without Bessel's correction) saved for backward
1015     // computation to serve as pre-computed mean and variance.
1016     MklDnnShape mkl_shape_saved_mean;
1017     mkl_shape_saved_mean.SetMklTensor(false);
1018     AllocateOutputSetMklShape(context, kSavedMeanIndex, saved_mean_tensor,
1019                               tf_shape_scale, mkl_shape_saved_mean,
1020                               native_format);
1021     DCHECK(*saved_mean_tensor);
1022 
1023     // Set 0 mean value in case of empty input tensor
1024     auto saved_mean_data = (*saved_mean_tensor)->flat<U>().data();
1025     std::fill_n(saved_mean_data, num_elements, static_cast<U>(0));
1026 
1027     MklDnnShape mkl_shape_saved_variance;
1028     mkl_shape_saved_variance.SetMklTensor(false);
1029     AllocateOutputSetMklShape(context, kSavedVarianceIndex,
1030                               saved_variance_tensor, tf_shape_scale,
1031                               mkl_shape_saved_variance, native_format);
1032     DCHECK(*saved_variance_tensor);
1033 
1034     // Set 0 variance value in case of empty input tensor
1035     auto saved_variance_data = (*saved_variance_tensor)->flat<U>().data();
1036     std::fill_n(saved_variance_data, num_elements, static_cast<U>(0));
1037 
1038     // Changes to support reserved_space_3 parameter in FusedBatchNormV3.
1039     if (reserved_space) {
1040       DCHECK(reserved_space_tensor != nullptr);
1041 
1042       MklDnnShape mkl_shape_reserved_space;
1043       mkl_shape_reserved_space.SetMklTensor(false);
1044       AllocateOutputSetMklShape(context, kReservedSpaceIndex,
1045                                 reserved_space_tensor, workspace_tf_shape,
1046                                 mkl_shape_reserved_space, native_format);
1047       DCHECK((*reserved_space_tensor) != nullptr);
1048     }
1049   }
1050 };
1051 
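// Gradient kernel for fused batch normalization: given diff_dst, src, scale,
// and the saved mean/variance from the forward pass, it produces diff_src and
// the gradients for scale and offset via the cached oneDNN backward primitive.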
1052 template <typename Device, typename T, typename U, bool reserved_space,
1053           bool native_format = false>
1054 class MklFusedBatchNormGradOp : public OpKernel {
1055  public:
1056   explicit MklFusedBatchNormGradOp(OpKernelConstruction* context)
1057       : OpKernel(context) {
1058     float epsilon;
1059     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
1060     epsilon_ = epsilon;
1061     string tensor_format;
1062     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
1063     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
1064                 errors::InvalidArgument("Invalid data format"));
1065     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
1066     depth_ = 0;
1067   }
1068 
1069   void Compute(OpKernelContext* context) override {
1070     try {
1071       const size_t kDiffDstIndex = 0;        // index of diff_dst tensor
1072       const size_t kSrcIndex = 1;            // index of src input tensor
1073       const size_t kScaleIndex = 2;          // index of scale tensor
1074       const size_t kMeanIndex = 3;           // index of saved_mean tensor
1075       const size_t kVarianceIndex = 4;       // index of saved_variance tensor
1076       const size_t kReservedSpaceIndex = 5;  // index of reserved space 3 tensor
1077 
1078       const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex);
1079       const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
1080       const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
1081       const Tensor& saved_mean_tensor = MklGetInput(context, kMeanIndex);
1082       const Tensor& saved_variance_tensor =
1083           MklGetInput(context, kVarianceIndex);
1084       const Tensor& reserved_space_tensor =
1085           (reserved_space) ? MklGetInput(context, kReservedSpaceIndex)
1086                            : Tensor();
1087 
1088       MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
1089       GetMklShape(context, kSrcIndex, &dnn_shape_src, native_format);
1090       GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst, native_format);
1091 
1092       TensorShape tf_shape_src, tf_shape_diff_dst;
1093       if (dnn_shape_diff_dst.IsMklTensor()) {
1094         tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape();
1095         OP_REQUIRES(
1096             context, dnn_shape_diff_dst.GetDimension() == 4,
1097             errors::InvalidArgument("input must be 4-dimensional",
1098                                     diff_dst_tensor.shape().DebugString()));
1099       } else {
1100         tf_shape_diff_dst = diff_dst_tensor.shape();
1101         OP_REQUIRES(
1102             context, diff_dst_tensor.dims() == 4,
1103             errors::InvalidArgument("input must be 4-dimensional",
1104                                     diff_dst_tensor.shape().DebugString()));
1105       }
1106 
1107       if (dnn_shape_src.IsMklTensor()) {
1108         tf_shape_src = dnn_shape_src.GetTfShape();
1109         OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
1110                     errors::InvalidArgument("input must be 4-dimensional",
1111                                             src_tensor.shape().DebugString()));
1112       } else {
1113         tf_shape_src = src_tensor.shape();
1114         OP_REQUIRES(context, src_tensor.dims() == 4,
1115                     errors::InvalidArgument("input must be 4-dimensional",
1116                                             src_tensor.shape().DebugString()));
1117       }
1118 
1119       OP_REQUIRES(context, scale_tensor.dims() == 1,
1120                   errors::InvalidArgument("scale must be 1-dimensional",
1121                                           scale_tensor.shape().DebugString()));
1122       OP_REQUIRES(
1123           context, saved_mean_tensor.dims() == 1,
1124           errors::InvalidArgument("saved mean must be 1-dimensional",
1125                                   saved_mean_tensor.shape().DebugString()));
1126 
1127       OP_REQUIRES(
1128           context, saved_variance_tensor.dims() == 1,
1129           errors::InvalidArgument("saved variance must be 1-dimensional",
1130                                   saved_variance_tensor.shape().DebugString()));
1131 
1132       // Handle the special case: input with 0 element and 0 batch size.
1133       Tensor* diff_src_tensor = nullptr;
1134       if (tf_shape_src.num_elements() == 0 ||
1135           tf_shape_diff_dst.num_elements() == 0) {
1136         HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(),
1137                          &diff_src_tensor);
1138         return;
1139       }
1140 
1141       if (dnn_shape_src.IsMklTensor()) {
1142         depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
1143       } else if (dnn_shape_diff_dst.IsMklTensor()) {
1144         depth_ = dnn_shape_diff_dst.DimSize(MklDnnDims::Dim_C);
1145       } else {
1146         ExtractParams(context);
1147       }
1148 
1149       memory::format_tag dnn_fmt;
1150       MklTensorFormat mkl_tensor_fmt;
1151       if (dnn_shape_src.IsMklTensor()) {
1152         if (dnn_shape_src.IsTensorInNCHWFormat()) {
1153           dnn_fmt = memory::format_tag::nchw;
1154           mkl_tensor_fmt = MklTensorFormat::FORMAT_NCHW;
1155         } else {
1156           dnn_fmt = memory::format_tag::nhwc;
1157           mkl_tensor_fmt = MklTensorFormat::FORMAT_NHWC;
1158         }
1159       } else {
1160         mkl_tensor_fmt = TFDataFormatToMklDnnDataFormat(tensor_format_);
1161         dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_tensor_fmt);
1162       }
1163 
      MklDnnData<T> src(&cpu_engine_);
      MklDnnData<T> diff_dst(&cpu_engine_);
      MklDnnData<U> weights(&cpu_engine_);
      MklDnnData<U> diff_weights(&cpu_engine_);

      memory::dims src_dims =
          dnn_shape_src.IsMklTensor()
              ? dnn_shape_src.GetSizesAsMklDnnDims()
              : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
      memory::dims diff_dst_dims =
          dnn_shape_diff_dst.IsMklTensor()
              ? dnn_shape_diff_dst.GetSizesAsMklDnnDims()
              : TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
                                          tensor_format_);

      // Set src and diff_dst primitive descriptors.
      memory::desc src_md =
          dnn_shape_src.IsMklTensor()
              ? dnn_shape_src.GetMklLayout()
              : memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);
      memory::desc diff_dst_md =
          dnn_shape_diff_dst.IsMklTensor()
              ? dnn_shape_diff_dst.GetMklLayout()
              : memory::desc(diff_dst_dims, MklDnnType<T>(), dnn_fmt);

      MklDnnData<T> reorder_src(&cpu_engine_);
      MklDnnData<T> reorder_diff_dst(&cpu_engine_);
      T* diff_dst_data =
          static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
      T* src_data =
          static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));

      if (!native_format) {
        // MKL-DNN requires src and diff_dst to be in the same memory layout,
        // either blocked or native format. If these inputs are in different
        // formats, convert the one in native format to blocked format, since
        // MKL-DNN gives better performance for the blocked format.
        if (dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
          reorder_diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
          reorder_diff_dst.CheckReorderToOpMem(src_md, cpu_engine_, context);
          diff_dst_md = src_md;
          diff_dst_data =
              static_cast<T*>(reorder_diff_dst.GetOpMem().get_data_handle());
        } else if (!dnn_shape_src.IsMklTensor() &&
                   dnn_shape_diff_dst.IsMklTensor()) {
          reorder_src.SetUsrMem(src_md, &src_tensor);
          reorder_src.CheckReorderToOpMem(diff_dst_md, cpu_engine_, context);
          src_md = diff_dst_md;
          src_data = static_cast<T*>(reorder_src.GetOpMem().get_data_handle());
        }
      }

      // Weights: MKL-DNN packs scale and shift into a single buffer in the
      // order scale[0..depth-1] followed by shift[0..depth-1].
      weights.AllocateBuffer(2 * depth_ * sizeof(U));
      U* weights_data_tf = reinterpret_cast<U*>(weights.GetAllocatedBuffer());
      const U* scale_tf = scale_tensor.flat<U>().data();
      for (int k = 0; k < depth_; k++) {
        weights_data_tf[k] = scale_tf[k];
        weights_data_tf[k + depth_] = static_cast<U>(0);
      }

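      // diff_weights receives the computed gradients in the same packed
      // layout: diff_scale first, then diff_shift.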
      diff_weights.AllocateBuffer(2 * depth_ * sizeof(U));

      MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_,
                                      is_training_, tensor_format_, src_md,
                                      diff_dst_md);
      MklFusedBatchNormBwdPrimitive<T, U>* bn_bwd =
          MklFusedBatchNormBwdPrimitiveFactory<T, U>::Get(bwdParams);

      // Check if diff_dst input needs to be reordered
      std::shared_ptr<BatchNormBwdPd> bn_bwd_pd = bn_bwd->GetBatchNormBwdPd();
      if (!native_format && diff_dst_md != bn_bwd_pd->diff_dst_desc()) {
        diff_dst.SetUsrMem(diff_dst_md, diff_dst_data);
        diff_dst.CheckReorderToOpMem(bn_bwd_pd->diff_dst_desc(), cpu_engine_,
                                     context);
        diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
      }

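      // Likewise, reorder src if its layout differs from the one expected by
      // the backward primitive.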
      if (!native_format && (src_md != bn_bwd_pd->src_desc())) {
        src.SetUsrMem(src_md, src_data);
        src.CheckReorderToOpMem(bn_bwd_pd->src_desc(), cpu_engine_, context);
        src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
      }

      // Indices of output tensors
      const size_t kDiffSrcIndex = 0;

      // Allocate output tensor diff_src, always set as MKL-DNN layout.
      MklDnnShape dnn_shape_diff_src;
      TensorShape tf_shape_diff_src;
      dnn_shape_diff_src.SetMklTensor(true);
      auto diff_src_pd = bn_bwd->GetDiffSrcPd();
      dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
      dnn_shape_diff_src.SetElemType(MklDnnType<T>());
      dnn_shape_diff_src.SetTfLayout(src_dims.size(), src_dims, mkl_tensor_fmt);
      dnn_shape_diff_src.SetTfDimOrder(src_dims.size(), tensor_format_);
      tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
      if (native_format) {
        tf_shape_diff_src = dnn_shape_diff_src.GetTfShape();
      }
      AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor,
                                tf_shape_diff_src, dnn_shape_diff_src,
                                native_format);

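      // Collect raw data pointers for the primitive execution.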
      U* mean_data =
          static_cast<U*>(const_cast<U*>(saved_mean_tensor.flat<U>().data()));
      U* variance_data = static_cast<U*>(
          const_cast<U*>(saved_variance_tensor.flat<U>().data()));
      U* weights_data = weights_data_tf;
      T* diff_src_data = static_cast<T*>(diff_src_tensor->flat<T>().data());
      U* diff_weights_data = static_cast<U*>(diff_weights.GetAllocatedBuffer());

      U* res_space_data =
          ((reserved_space) ? static_cast<U*>(const_cast<U*>(
                                  reserved_space_tensor.flat<U>().data()))
                            : nullptr);

      // Execute the backward primitive on a CPU stream that uses the
      // context's thread pool.
      std::shared_ptr<stream> bwd_cpu_stream;
      MklDnnThreadPool eigen_tp(context);
      bwd_cpu_stream.reset(CreateStream(&eigen_tp, bn_bwd->GetEngine()));
      bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data,
                      weights_data, diff_src_data, diff_weights_data,
                      res_space_data, bwd_cpu_stream);
      // Allocate output TF tensors diff_scale and diff_shift.
      Tensor* diff_scale_tensor = nullptr;
      Tensor* diff_shift_tensor = nullptr;
      AllocateTFOutputs(context, scale_tensor.shape(), &diff_scale_tensor,
                        &diff_shift_tensor);

      // Copy data for tensors diff_scale and diff_shift.
      auto diff_scale_data = diff_scale_tensor->flat<U>().data();
      auto diff_shift_data = diff_shift_tensor->flat<U>().data();
      std::memcpy(reinterpret_cast<char*>(diff_scale_data),
                  reinterpret_cast<char*>(diff_weights_data),
                  depth_ * sizeof(U));
      std::memcpy(reinterpret_cast<char*>(diff_shift_data),
                  reinterpret_cast<char*>(diff_weights_data + depth_),
                  depth_ * sizeof(U));
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  float epsilon_;
  TensorFormat tensor_format_;
  size_t depth_;  // Batch normalization is performed per channel.
  bool is_training_;
  engine cpu_engine_ = engine(engine::kind::cpu, 0);

  void ExtractParams(OpKernelContext* context) {
    const Tensor& input = MklGetInput(context, 0);
    depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
  }

  void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
                        TensorShape tf_shape_scale_shift,
                        Tensor** diff_src_tensor) {
    const size_t kDiffSrcIndex = 0;

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor,
                              tf_shape_src, dnn_shape_diff_src, native_format);
    auto diff_src_data = (*diff_src_tensor)->flat<T>().data();
    std::fill_n(diff_src_data, (*diff_src_tensor)->shape().num_elements(),
                static_cast<T>(0));

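    // For an empty input, the scale and shift gradients are all zeros as
    // well; AllocateTFOutputs zero-fills them.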
    Tensor* diff_scale_tensor = nullptr;
    Tensor* diff_shift_tensor = nullptr;
    AllocateTFOutputs(context, tf_shape_scale_shift, &diff_scale_tensor,
                      &diff_shift_tensor);
  }

  void AllocateTFOutputs(OpKernelContext* context,
                         TensorShape tf_shape_scale_shift,
                         Tensor** diff_scale_tensor,
                         Tensor** diff_shift_tensor) {
    DCHECK(diff_scale_tensor);
    DCHECK(diff_shift_tensor);

    const size_t kDiffScaleIndex = 1;
    const size_t kDiffShiftIndex = 2;
    const size_t kP1Index = 3;
    const size_t kP2Index = 4;

    // Separate out the scale and shift gradients and copy them to
    // individual tensors.
    MklDnnShape mkl_shape_diff_scale;
    mkl_shape_diff_scale.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor,
                              tf_shape_scale_shift, mkl_shape_diff_scale,
                              native_format);
    DCHECK(*diff_scale_tensor);

    auto diff_scale_data = (*diff_scale_tensor)->flat<U>().data();
    std::fill_n(diff_scale_data, (*diff_scale_tensor)->shape().num_elements(),
                static_cast<U>(0));

    MklDnnShape mkl_shape_diff_shift;
    mkl_shape_diff_shift.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor,
                              tf_shape_scale_shift, mkl_shape_diff_shift,
                              native_format);
    DCHECK(*diff_shift_tensor);

    auto diff_shift_data = (*diff_shift_tensor)->flat<U>().data();
    std::fill_n(diff_shift_data, (*diff_shift_tensor)->shape().num_elements(),
                static_cast<U>(0));

    // Placeholders for estimated_mean and estimated_variance, which are
    // used for inference and thus not needed here for gradient computation.
    Tensor *p1_tensor = nullptr, *p2_tensor = nullptr;
    MklDnnShape mkl_shape_p;
    mkl_shape_p.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}),
                              mkl_shape_p, native_format);
    std::fill_n(p1_tensor->flat<U>().data(), p1_tensor->shape().num_elements(),
                static_cast<U>(0));
    AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}),
                              mkl_shape_p, native_format);
    std::fill_n(p2_tensor->flat<U>().data(), p2_tensor->shape().num_elements(),
                static_cast<U>(0));
  }

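  // Dims used for the per-channel mean and variance MKL-DNN buffers.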
  memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); }
};

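// Kernel registrations. Each fused batch-norm op is registered twice: once
// with the MKL layout-dependent label (the classic MKL layout path, where
// MklDnnShape metadata accompanies the tensors) and once with the MKL
// name-change label used by the native-format "_MklNative*" variants.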
#define REGISTER_MKL_FUSED_BATCHNORM_CPU(T)                    \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklFusedBatchNorm")                               \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklFusedBatchNormOp<CPUDevice, T, T, false, false>);     \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklNativeFusedBatchNorm")                         \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),      \
      MklFusedBatchNormOp<CPUDevice, T, T, false, false, true>);

TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_CPU);
TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU);
#undef REGISTER_MKL_FUSED_BATCHNORM_CPU

#define REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(T, U)              \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklFusedBatchNormV2")                             \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .TypeConstraint<U>("U")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklFusedBatchNormOp<CPUDevice, T, U, false, false>);     \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklNativeFusedBatchNormV2")                       \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .TypeConstraint<U>("U")                              \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),      \
      MklFusedBatchNormOp<CPUDevice, T, U, false, false, true>);

REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_V2_CPU

#define REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU(T)               \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklFusedBatchNormGrad")                           \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklFusedBatchNormGradOp<CPUDevice, T, T, false>);        \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklNativeFusedBatchNormGrad")                     \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),      \
      MklFusedBatchNormGradOp<CPUDevice, T, T, false, true>);

TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU);
TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU);
#undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU

#define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(T, U)         \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklFusedBatchNormGradV2")                         \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .TypeConstraint<U>("U")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklFusedBatchNormGradOp<CPUDevice, T, U, false>);        \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklNativeFusedBatchNormGradV2")                   \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .TypeConstraint<U>("U")                              \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),      \
      MklFusedBatchNormGradOp<CPUDevice, T, U, false, true>);

REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU

// TODO: FusedBatchNormV3 has an additional output that is used to hold
//       intermediate results. That functionality is not implemented on CPU.
#define REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(T, U)               \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklFusedBatchNormV3")                              \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),  \
      MklFusedBatchNormOp<CPUDevice, T, U, true, false>);       \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklFusedBatchNormEx")                              \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),  \
      MklFusedBatchNormOp<CPUDevice, T, U, true, true>);        \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklNativeFusedBatchNormV3")                        \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),       \
      MklFusedBatchNormOp<CPUDevice, T, U, true, false, true>); \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklNativeFusedBatchNormEx")                        \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),       \
      MklFusedBatchNormOp<CPUDevice, T, U, true, true, true>);

REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_V3_CPU

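// Register the plain _FusedBatchNormEx op on CPU as a NoOp placeholder; the
// MKL-labeled kernels above provide the actual CPU implementation.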
REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        NoOp);
REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<bfloat16>("T")
                            .TypeConstraint<float>("U"),
                        NoOp);

#define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(T, U)         \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklFusedBatchNormGradV3")                         \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .TypeConstraint<U>("U")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklFusedBatchNormGradOp<CPUDevice, T, U, true>);         \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklNativeFusedBatchNormGradV3")                   \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .TypeConstraint<U>("U")                              \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),      \
      MklFusedBatchNormGradOp<CPUDevice, T, U, true, true>);

REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU

}  // namespace tensorflow

#undef GET_FLAG
#undef IS_SET

#endif  // INTEL_MKL