/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef INTEL_MKL
#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"

#define GET_FLAG(bn_flag) static_cast<int>(mkldnn::normalization_flags::bn_flag)
#define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))

using mkldnn::batch_normalization_backward;
using mkldnn::batch_normalization_forward;
using mkldnn::prop_kind;
using mkldnn::stream;

using BatchNormFwdPd = mkldnn::batch_normalization_forward::primitive_desc;
using BatchNormBwdPd = mkldnn::batch_normalization_backward::primitive_desc;

namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;

using FusedBNActivationMode = functor::FusedBatchNormActivationMode;

struct MklBatchNormFwdParams {
  memory::dims src_dims;
  int depth;
  float eps;
  bool training;
  TensorFormat data_format;
  FusedBNActivationMode activation_mode;
  memory::desc src_md;

  MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps,
                        bool training, TensorFormat data_format,
                        memory::desc src_md,
                        FusedBNActivationMode activation_mode)
      : src_dims(src_dims),
        depth(depth),
        eps(eps),
        training(training),
        data_format(data_format),
        activation_mode(activation_mode),
        src_md(src_md) {}
};

template <typename T, typename U>
class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
 public:
  explicit MklFusedBatchNormFwdPrimitive(const MklBatchNormFwdParams& fwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    if (context_.bn_fwd == nullptr) Setup(fwdParams);
  }

  ~MklFusedBatchNormFwdPrimitive() {}

  // BatchNormalization forward execute
  //   src_data:      input data buffer of src
  //   weights_data:  input data buffer of weights
  //   dst_data:      output data buffer of dst
  //   mean_data:     output data buffer of means
  //   variance_data: output data buffer of variances
  void Execute(const T* src_data, const U* weights_data, T* dst_data,
               U* mean_data, U* variance_data,
               std::shared_ptr<stream> fwd_stream, U* workspace_data) {
#ifndef ENABLE_ONEDNN_OPENMP
    // TODO: Create a common function and avoid the duplicate code
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
                                      *fwd_stream);

    if (IS_SET(use_scale_shift))
      context_.weights_mem->set_data_handle(
          static_cast<void*>(const_cast<U*>(weights_data)), *fwd_stream);

    if ((context_.pkind == prop_kind::forward_training) ||
        (IS_SET(use_global_stats))) {
      context_.mean_mem->set_data_handle(static_cast<void*>(mean_data),
                                         *fwd_stream);
      context_.variance_mem->set_data_handle(static_cast<void*>(variance_data),
                                             *fwd_stream);
    }
    if (workspace_data != nullptr) {
      context_.ws_mem->set_data_handle(workspace_data, *fwd_stream);
    }
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));

    if (IS_SET(use_scale_shift))
      context_.weights_mem->set_data_handle(
          static_cast<void*>(const_cast<U*>(weights_data)));

    if ((context_.pkind == prop_kind::forward_training) ||
        (IS_SET(use_global_stats))) {
      context_.mean_mem->set_data_handle(static_cast<void*>(mean_data));
      context_.variance_mem->set_data_handle(static_cast<void*>(variance_data));
    }
    if (workspace_data != nullptr) {
      context_.ws_mem->set_data_handle(workspace_data);
    }
#endif  // !ENABLE_ONEDNN_OPENMP

    // Execute batch-normalization forward primitives.
    execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);

    context_.src_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);

    if (IS_SET(use_scale_shift))
      context_.weights_mem->set_data_handle(DummyData);

    if ((context_.pkind == prop_kind::forward_training) ||
        (IS_SET(use_global_stats))) {
      context_.mean_mem->set_data_handle(DummyData);
      context_.variance_mem->set_data_handle(DummyData);
    }

    if (workspace_data != nullptr) {
      context_.ws_mem->set_data_handle(DummyData);
    }
  }

  memory::desc GetDstPd() const { return context_.dst_mem->get_desc(); }

  std::shared_ptr<BatchNormFwdPd> GetBatchNormFwdPd() const {
    return context_.fwd_pd;
  }

 private:
  // Primitive reuse context for BatchNorm forward op.
  struct BatchNormFwdContext {
    // Flags indicating if it is training or inference mode.
    int64 flags;

    // Algorithm kind.
    mkldnn::prop_kind pkind;

    // Inputs/outputs memory.
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> weights_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;
    std::shared_ptr<mkldnn::memory> mean_mem;
    std::shared_ptr<mkldnn::memory> variance_mem;
    std::shared_ptr<mkldnn::memory> ws_mem;

    // Forward BatchNorm primitive descriptor.
    std::shared_ptr<BatchNormFwdPd> fwd_pd;

    // BatchNorm forward primitive.
    std::shared_ptr<mkldnn::primitive> bn_fwd;
    std::vector<mkldnn::primitive> fwd_primitives;

    std::vector<std::unordered_map<int, memory>> net_args;

    BatchNormFwdContext()
        : flags(0),
          pkind(prop_kind::forward_training),
          src_mem(nullptr),
          weights_mem(nullptr),
          dst_mem(nullptr),
          mean_mem(nullptr),
          variance_mem(nullptr),
          ws_mem(nullptr),
          bn_fwd(nullptr) {}
  };

  void Setup(const MklBatchNormFwdParams& fwdParams) {
    context_.flags =
        fwdParams.training
            ? GET_FLAG(use_scale_shift)
            : (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats));
    context_.pkind = fwdParams.training ? prop_kind::forward_training
                                        : prop_kind::forward_scoring;

    if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
      context_.flags |= GET_FLAG(fuse_norm_relu);
    }
    // Memory descriptor
    auto src_md = fwdParams.src_md;
    // Create forward BatchNorm descriptor and primitive descriptor.
    auto fwd_desc = batch_normalization_forward::desc(
        context_.pkind, src_md, fwdParams.eps,
        static_cast<mkldnn::normalization_flags>(context_.flags));

    context_.fwd_pd.reset(new BatchNormFwdPd(fwd_desc, cpu_engine_));

    // Create memory primitive based on dummy data
    context_.src_mem.reset(
        new memory(context_.fwd_pd->src_desc(), cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd->dst_desc(), cpu_engine_, DummyData));

    memory::dims s_dims = {2, fwdParams.depth};
    memory::dims m_dims = {1, fwdParams.depth};
    if (IS_SET(use_scale_shift)) {
      context_.weights_mem.reset(
          new memory({{s_dims}, MklDnnType<U>(), memory::format_tag::nc},
                     cpu_engine_, DummyData));
    }

    if (fwdParams.training || (IS_SET(use_global_stats))) {
      context_.mean_mem.reset(
          new memory({{m_dims}, MklDnnType<U>(), memory::format_tag::nc},
                     cpu_engine_, DummyData));

      context_.variance_mem.reset(
          new memory({{m_dims}, MklDnnType<U>(), memory::format_tag::nc},
                     cpu_engine_, DummyData));
    }

    if (IS_SET(fuse_norm_relu)) {
      context_.ws_mem.reset(new memory(context_.fwd_pd->workspace_desc(),
                                       cpu_engine_, DummyData));
    }

    // BatchNorm forward primitive.
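    // The branches below only differ in which arguments are wired into
    // net_args: SRC and DST are always present; MEAN and VARIANCE are added
    // when statistics are produced (training) or supplied (use_global_stats);
    // WEIGHTS is added when use_scale_shift is set; WORKSPACE is added when
    // ReLU is fused (fuse_norm_relu).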
    // TODO(intel-tf): Merge all the #ifdefs and simplify code
    if (!fwdParams.training && !(IS_SET(use_global_stats))) {
      if ((IS_SET(use_scale_shift)) && mkldnn_use_scaleshift) {
        context_.net_args.push_back(
            {{MKLDNN_ARG_SRC, *context_.src_mem},
             {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
             {MKLDNN_ARG_DST, *context_.dst_mem}});
      } else {
        context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
                                     {MKLDNN_ARG_DST, *context_.dst_mem}});
      }
      context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
    } else if (IS_SET(use_global_stats)) {
      if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) {
        if (IS_SET(fuse_norm_relu)) {
          context_.net_args.push_back(
              {{MKLDNN_ARG_SRC, *context_.src_mem},
               {MKLDNN_ARG_MEAN, *context_.mean_mem},
               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
               {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
               {MKLDNN_ARG_DST, *context_.dst_mem},
               {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
        } else {
          context_.net_args.push_back(
              {{MKLDNN_ARG_SRC, *context_.src_mem},
               {MKLDNN_ARG_MEAN, *context_.mean_mem},
               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
               {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
               {MKLDNN_ARG_DST, *context_.dst_mem}});
        }
      } else {
        if (IS_SET(fuse_norm_relu)) {
          context_.net_args.push_back(
              {{MKLDNN_ARG_SRC, *context_.src_mem},
               {MKLDNN_ARG_MEAN, *context_.mean_mem},
               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
               {MKLDNN_ARG_DST, *context_.dst_mem},
               {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
        } else {
          context_.net_args.push_back(
              {{MKLDNN_ARG_SRC, *context_.src_mem},
               {MKLDNN_ARG_MEAN, *context_.mean_mem},
               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
               {MKLDNN_ARG_DST, *context_.dst_mem}});
        }
      }
      context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
    } else {
      if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) {
        if (IS_SET(fuse_norm_relu)) {
          context_.net_args.push_back(
              {{MKLDNN_ARG_SRC, *context_.src_mem},
               {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
               {MKLDNN_ARG_DST, *context_.dst_mem},
               {MKLDNN_ARG_MEAN, *context_.mean_mem},
               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
               {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
        } else {
          context_.net_args.push_back(
              {{MKLDNN_ARG_SRC, *context_.src_mem},
               {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
               {MKLDNN_ARG_DST, *context_.dst_mem},
               {MKLDNN_ARG_MEAN, *context_.mean_mem},
               {MKLDNN_ARG_VARIANCE, *context_.variance_mem}});
        }
      } else {
        if (IS_SET(fuse_norm_relu)) {
          context_.net_args.push_back(
              {{MKLDNN_ARG_SRC, *context_.src_mem},
               {MKLDNN_ARG_DST, *context_.dst_mem},
               {MKLDNN_ARG_MEAN, *context_.mean_mem},
               {MKLDNN_ARG_VARIANCE, *context_.variance_mem},
               {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
        } else {
          context_.net_args.push_back(
              {{MKLDNN_ARG_SRC, *context_.src_mem},
               {MKLDNN_ARG_DST, *context_.dst_mem},
               {MKLDNN_ARG_MEAN, *context_.mean_mem},
               {MKLDNN_ARG_VARIANCE, *context_.variance_mem}});
        }
      }
      context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
    }

    context_.fwd_primitives.push_back(*context_.bn_fwd);
  }

  struct BatchNormFwdContext context_;
};

template <typename T, typename U>
class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklFusedBatchNormFwdPrimitive<T, U>* Get(
      const MklBatchNormFwdParams& fwdParams) {
    auto bn_fwd = static_cast<MklFusedBatchNormFwdPrimitive<T, U>*>(
        MklFusedBatchNormFwdPrimitiveFactory<T, U>::GetInstance()
            .GetBatchNormFwd(fwdParams));

    if (bn_fwd == nullptr) {
      bn_fwd = new MklFusedBatchNormFwdPrimitive<T, U>(fwdParams);
      MklFusedBatchNormFwdPrimitiveFactory<T, U>::GetInstance().SetBatchNormFwd(
          fwdParams, bn_fwd);
    }
    return bn_fwd;
  }

  static MklFusedBatchNormFwdPrimitiveFactory& GetInstance() {
    static MklFusedBatchNormFwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklFusedBatchNormFwdPrimitiveFactory() {}
  ~MklFusedBatchNormFwdPrimitiveFactory() {}

  static string CreateKey(const MklBatchNormFwdParams& fwdParams) {
    string prefix = "bn_fwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(fwdParams.src_dims);
    key_creator.AddAsKey<int>(fwdParams.depth);
    key_creator.AddAsKey<float>(fwdParams.eps);
    key_creator.AddAsKey<bool>(fwdParams.training);
    key_creator.AddAsKey<TensorFormat>(fwdParams.data_format);
    key_creator.AddAsKey<FusedBNActivationMode>(fwdParams.activation_mode);
    key_creator.AddAsKey(typeid(T).name());
    key_creator.AddAsKey(typeid(U).name());
    return key_creator.GetKey();
  }

  MklPrimitive* GetBatchNormFwd(const MklBatchNormFwdParams& fwdParams) {
    string key = CreateKey(fwdParams);
    return this->GetOp(key);
  }

  void SetBatchNormFwd(const MklBatchNormFwdParams& fwdParams,
                       MklPrimitive* op) {
    string key = CreateKey(fwdParams);
    this->SetOp(key, op);
  }
};

struct MklBatchNormBwdParams {
  memory::dims src_dims;
  memory::dims diff_dst_dims;
  int depth;
  float eps;
  bool training;
  TensorFormat data_format;
  memory::desc src_md;
  memory::desc diff_dst_md;

  MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims,
                        int depth, float eps, bool training,
                        TensorFormat data_format, memory::desc src_md,
                        memory::desc diff_dst_md)
      : src_dims(src_dims),
        diff_dst_dims(diff_dst_dims),
        depth(depth),
        eps(eps),
        training(training),
        data_format(data_format),
        src_md(src_md),
        diff_dst_md(diff_dst_md) {}
};

template <typename T, typename U>
class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
 public:
  explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    if (context_.bn_bwd == nullptr) Setup(bwdParams);
  }

  ~MklFusedBatchNormBwdPrimitive() {}

  // BatchNormalization backward execute
  //   src_data:          input data buffer of src
  //   mean_data:         input data buffer of mean
  //   variance_data:     input data buffer of variance
  //   diff_dst_data:     input data buffer of diff_dst
  //   weights_data:      input data buffer of weights
  //   diff_src_data:     output data buffer of diff_src
  //   diff_weights_data: output data buffer of diff_weights
  //   res_space_data:    output data buffer of reserved_space_3.
  //                      TODO: reserved_space_3 (temp memory to hold
  //                      intermediate results) is not implemented on CPU
  //                      as of now.
  void Execute(const T* src_data, const U* mean_data, const U* variance_data,
               const T* diff_dst_data, const U* weights_data, T* diff_src_data,
               U* diff_weights_data, U* res_space_data,
               std::shared_ptr<stream> bwd_stream) {
#ifndef ENABLE_ONEDNN_OPENMP
    // TODO: Create a common function and avoid the duplicate code
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
    context_.mean_mem->set_data_handle(
        static_cast<void*>(const_cast<U*>(mean_data)), *bwd_stream);
    context_.variance_mem->set_data_handle(
        static_cast<void*>(const_cast<U*>(variance_data)), *bwd_stream);
    context_.diff_dst_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);

    if (IS_SET(use_scale_shift)) {
      context_.weights_mem->set_data_handle(
          static_cast<void*>(const_cast<U*>(weights_data)), *bwd_stream);
      context_.diff_weights_mem->set_data_handle(
          static_cast<void*>(diff_weights_data), *bwd_stream);
    }

    context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
                                           *bwd_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)));
    context_.mean_mem->set_data_handle(
        static_cast<void*>(const_cast<U*>(mean_data)));
    context_.variance_mem->set_data_handle(
        static_cast<void*>(const_cast<U*>(variance_data)));
    context_.diff_dst_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_dst_data)));

    if (IS_SET(use_scale_shift)) {
      context_.weights_mem->set_data_handle(
          static_cast<void*>(const_cast<U*>(weights_data)));
      context_.diff_weights_mem->set_data_handle(
          static_cast<void*>(diff_weights_data));
    }

    context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
#endif  // !ENABLE_ONEDNN_OPENMP
    // Execute backward batch-normalization primitives.
    DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());
    execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);

    // After execution, set data handle back to DummyData.
    context_.src_mem->set_data_handle(DummyData);
    context_.mean_mem->set_data_handle(DummyData);
    context_.variance_mem->set_data_handle(DummyData);
    context_.diff_dst_mem->set_data_handle(DummyData);
    if (IS_SET(use_scale_shift)) {
      context_.weights_mem->set_data_handle(DummyData);
      context_.diff_weights_mem->set_data_handle(DummyData);
    }
    context_.diff_src_mem->set_data_handle(DummyData);
  }

  std::shared_ptr<BatchNormBwdPd> GetBatchNormBwdPd() const {
    return context_.bwd_pd;
  }

  memory::desc GetDiffSrcPd() { return context_.diff_src_mem->get_desc(); }

 private:
  struct BatchNormBwdContext {
    // Flags to indicate whether it is training or inference.
    int64 flags;

    // Inputs/output memory.
498 std::shared_ptr<mkldnn::memory> src_mem; 499 std::shared_ptr<mkldnn::memory> mean_mem; 500 std::shared_ptr<mkldnn::memory> variance_mem; 501 std::shared_ptr<mkldnn::memory> diff_dst_mem; 502 std::shared_ptr<mkldnn::memory> weights_mem; 503 std::shared_ptr<mkldnn::memory> diff_weights_mem; 504 std::shared_ptr<mkldnn::memory> diff_src_mem; 505 506 // Backward batch-normalization primitive descriptor. 507 std::shared_ptr<BatchNormBwdPd> bwd_pd; 508 509 // Backward batch-normalization primitive. 510 std::shared_ptr<mkldnn::primitive> bn_bwd; 511 std::vector<mkldnn::primitive> bwd_primitives; 512 513 std::vector<std::unordered_map<int, memory>> net_args; 514 BatchNormBwdContexttensorflow::MklFusedBatchNormBwdPrimitive::BatchNormBwdContext515 BatchNormBwdContext() 516 : src_mem(nullptr), 517 mean_mem(nullptr), 518 variance_mem(nullptr), 519 diff_dst_mem(nullptr), 520 weights_mem(nullptr), 521 diff_weights_mem(nullptr), 522 diff_src_mem(nullptr) {} 523 }; 524 Setup(const MklBatchNormBwdParams & bwdParams)525 void Setup(const MklBatchNormBwdParams& bwdParams) { 526 context_.flags = 527 bwdParams.training 528 ? GET_FLAG(use_scale_shift) 529 : (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats)); 530 531 // Memory descriptors. 532 auto src_md = bwdParams.src_md; 533 auto diff_dst_md = bwdParams.diff_dst_md; 534 auto variance_desc = memory::desc({1, bwdParams.depth}, MklDnnType<U>(), 535 memory::format_tag::nc); 536 auto mean_desc = memory::desc({1, bwdParams.depth}, MklDnnType<U>(), 537 memory::format_tag::nc); 538 auto weights_desc = memory::desc({2, bwdParams.depth}, MklDnnType<U>(), 539 memory::format_tag::nc); 540 auto diff_weights_desc = weights_desc; 541 542 // Forward batch-normalization descriptor and primitive descriptor. 543 // Adding this back due to type difference with context.flags 544 auto bn_flags = bwdParams.training 545 ? mkldnn::normalization_flags::use_scale_shift 546 : (mkldnn::normalization_flags::use_scale_shift | 547 mkldnn::normalization_flags::use_global_stats); 548 auto fwd_desc = batch_normalization_forward::desc( 549 prop_kind::forward_training, src_md, bwdParams.eps, bn_flags); 550 auto fwd_pd = BatchNormFwdPd(fwd_desc, cpu_engine_); 551 552 // Backward batch-normalization primitive. 553 // For inference, specify use_global_stats 554 // 1. on fwd propagation, use mean and variance provided as inputs. 555 // 2. on bwd propagation, mean and variance are considered as constants. 556 // Thus, reduce the amount of MKL computation. 557 auto bwd_desc = batch_normalization_backward::desc( 558 prop_kind::backward, diff_dst_md, src_md, bwdParams.eps, bn_flags); 559 context_.bwd_pd.reset(new BatchNormBwdPd(bwd_desc, cpu_engine_, fwd_pd)); 560 561 // Create memory primitives. 
562 context_.src_mem.reset(new memory(src_md, cpu_engine_, DummyData)); 563 context_.diff_dst_mem.reset( 564 new memory(diff_dst_md, cpu_engine_, DummyData)); 565 context_.variance_mem.reset( 566 new memory(variance_desc, cpu_engine_, DummyData)); 567 context_.mean_mem.reset(new memory(mean_desc, cpu_engine_, DummyData)); 568 context_.weights_mem.reset( 569 new memory(weights_desc, cpu_engine_, DummyData)); 570 context_.diff_weights_mem.reset( 571 new memory(diff_weights_desc, cpu_engine_, DummyData)); 572 context_.diff_src_mem.reset(new memory(src_md, cpu_engine_, DummyData)); 573 574 context_.bn_bwd.reset(new batch_normalization_backward(*context_.bwd_pd)); 575 context_.net_args.push_back( 576 {{MKLDNN_ARG_SRC, *context_.src_mem}, 577 {MKLDNN_ARG_MEAN, *context_.mean_mem}, 578 {MKLDNN_ARG_VARIANCE, *context_.variance_mem}, 579 {MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem}, 580 {MKLDNN_ARG_WEIGHTS, *context_.weights_mem}, 581 {MKLDNN_ARG_DIFF_SRC, *context_.diff_src_mem}, 582 {MKLDNN_ARG_DIFF_WEIGHTS, *context_.diff_weights_mem}}); 583 context_.bwd_primitives.push_back(*context_.bn_bwd); 584 } 585 586 struct BatchNormBwdContext context_; 587 }; 588 589 template <typename T, typename U> 590 class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> { 591 public: Get(const MklBatchNormBwdParams & bwdParams)592 static MklFusedBatchNormBwdPrimitive<T, U>* Get( 593 const MklBatchNormBwdParams& bwdParams) { 594 auto bn_bwd = static_cast<MklFusedBatchNormBwdPrimitive<T, U>*>( 595 MklFusedBatchNormBwdPrimitiveFactory<T, U>::GetInstance() 596 .GetBatchNormBwd(bwdParams)); 597 if (bn_bwd == nullptr) { 598 bn_bwd = new MklFusedBatchNormBwdPrimitive<T, U>(bwdParams); 599 MklFusedBatchNormBwdPrimitiveFactory<T, U>::GetInstance().SetBatchNormBwd( 600 bwdParams, bn_bwd); 601 } 602 return bn_bwd; 603 } 604 GetInstance()605 static MklFusedBatchNormBwdPrimitiveFactory& GetInstance() { 606 static MklFusedBatchNormBwdPrimitiveFactory instance_; 607 return instance_; 608 } 609 610 private: MklFusedBatchNormBwdPrimitiveFactory()611 MklFusedBatchNormBwdPrimitiveFactory() {} ~MklFusedBatchNormBwdPrimitiveFactory()612 ~MklFusedBatchNormBwdPrimitiveFactory() {} 613 CreateKey(const MklBatchNormBwdParams & bwdParams)614 static string CreateKey(const MklBatchNormBwdParams& bwdParams) { 615 string prefix = "bn_bwd"; 616 FactoryKeyCreator key_creator; 617 key_creator.AddAsKey(prefix); 618 key_creator.AddAsKey(bwdParams.src_dims); 619 key_creator.AddAsKey(bwdParams.diff_dst_dims); 620 key_creator.AddAsKey<int>(bwdParams.depth); 621 key_creator.AddAsKey<float>(bwdParams.eps); 622 key_creator.AddAsKey<bool>(bwdParams.training); 623 key_creator.AddAsKey<TensorFormat>(bwdParams.data_format); 624 key_creator.AddAsKey(typeid(T).name()); 625 key_creator.AddAsKey(typeid(U).name()); 626 return key_creator.GetKey(); 627 } 628 GetBatchNormBwd(const MklBatchNormBwdParams & bwdParams)629 MklPrimitive* GetBatchNormBwd(const MklBatchNormBwdParams& bwdParams) { 630 string key = CreateKey(bwdParams); 631 return this->GetOp(key); 632 } 633 SetBatchNormBwd(const MklBatchNormBwdParams & bwdParams,MklPrimitive * op)634 void SetBatchNormBwd(const MklBatchNormBwdParams& bwdParams, 635 MklPrimitive* op) { 636 string key = CreateKey(bwdParams); 637 this->SetOp(key, op); 638 } 639 }; 640 641 // Adding a third parameter to the template to support FusedBatchNormV3 642 // with MKL. This is different from default where the classes are 643 // derived. Moves enabling to compile-time rather than runtime. 
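// For reference, the kernel registrations at the bottom of this file
// instantiate the template below roughly as follows (illustrative summary
// only, see the REGISTER_* macros for the authoritative list):
//   _MklFusedBatchNorm / V2 -> reserved_space = false, is_batch_norm_ex = false
//   _MklFusedBatchNormV3    -> reserved_space = true,  is_batch_norm_ex = false
//   _MklFusedBatchNormEx    -> reserved_space = true,  is_batch_norm_ex = true
//   _MklNative* variants additionally pass native_format = true.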
644 template <typename Device, typename T, typename U, bool reserved_space, 645 bool is_batch_norm_ex = false, bool native_format = false> 646 class MklFusedBatchNormOp : public OpKernel { 647 public: MklFusedBatchNormOp(OpKernelConstruction * context)648 explicit MklFusedBatchNormOp(OpKernelConstruction* context) 649 : OpKernel(context) { 650 float epsilon; 651 OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); 652 epsilon_ = epsilon; 653 float exponential_avg_factor; 654 OP_REQUIRES_OK(context, context->GetAttr("exponential_avg_factor", 655 &exponential_avg_factor)); 656 exponential_avg_factor_ = static_cast<U>(exponential_avg_factor); 657 string tensor_format; 658 OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format)); 659 OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_), 660 errors::InvalidArgument("Invalid data format")); 661 OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_)); 662 depth_ = 0; 663 mean_values_ = nullptr; 664 variance_values_ = nullptr; 665 666 if (!is_batch_norm_ex) { 667 activation_mode_ = FusedBNActivationMode::kIdentity; 668 } else { 669 int num_side_inputs; 670 OP_REQUIRES_OK(context, 671 context->GetAttr("num_side_inputs", &num_side_inputs)); 672 // Currently _MKLFusedBatchNormEx do not support "SideInput" 673 OP_REQUIRES(context, num_side_inputs == 0, 674 errors::InvalidArgument( 675 "_MKLFusedBatchNorm do not support side input now.")); 676 677 OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_)); 678 OP_REQUIRES(context, activation_mode_ == FusedBNActivationMode::kRelu, 679 errors::InvalidArgument( 680 "_MKLFusedBatchNorm only support Relu activation")); 681 } 682 } 683 Compute(OpKernelContext * context)684 void Compute(OpKernelContext* context) override { 685 try { 686 const size_t kSrcIndex = 0; // index of src input tensor 687 const size_t kScaleIndex = 1; // index of scale tensor 688 const size_t kShiftIndex = 2; // index of shift tensor 689 const size_t kMeanIndex = 3; // index of est_mean tensor 690 const size_t kVarianceIndex = 4; // index of est_variance tensor 691 692 const Tensor& src_tensor = MklGetInput(context, kSrcIndex); 693 const Tensor& scale_tensor = MklGetInput(context, kScaleIndex); 694 const Tensor& shift_tensor = MklGetInput(context, kShiftIndex); 695 const Tensor& est_mean_tensor = MklGetInput(context, kMeanIndex); 696 const Tensor& est_variance_tensor = MklGetInput(context, kVarianceIndex); 697 698 TensorShape tf_shape_src; 699 MklDnnShape dnn_shape_src; 700 GetMklShape(context, kSrcIndex, &dnn_shape_src, native_format); 701 702 if (dnn_shape_src.IsMklTensor()) { 703 tf_shape_src = dnn_shape_src.GetTfShape(); 704 OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4, 705 errors::InvalidArgument("input must be 4-dimensional", 706 src_tensor.shape().DebugString())); 707 } else { 708 tf_shape_src = src_tensor.shape(); 709 OP_REQUIRES(context, src_tensor.dims() == 4, 710 errors::InvalidArgument("input must be 4-dimensional", 711 src_tensor.shape().DebugString())); 712 } 713 OP_REQUIRES(context, scale_tensor.dims() == 1, 714 errors::InvalidArgument("scale must be 1-dimensional", 715 scale_tensor.shape().DebugString())); 716 OP_REQUIRES(context, shift_tensor.dims() == 1, 717 errors::InvalidArgument("offset must be 1-dimensional", 718 shift_tensor.shape().DebugString())); 719 OP_REQUIRES( 720 context, est_mean_tensor.dims() == 1, 721 errors::InvalidArgument("estimated_mean must be 1-dimensional", 722 est_mean_tensor.shape().DebugString())); 723 
OP_REQUIRES( 724 context, est_variance_tensor.dims() == 1, 725 errors::InvalidArgument("estimated_variance must be 1-dimensional", 726 est_variance_tensor.shape().DebugString())); 727 728 // Handle the special case: input with 0 element and 0 batch size. 729 Tensor* dst_tensor = nullptr; 730 TensorShape workspace_tf_shape; 731 if (tf_shape_src.num_elements() == 0) { 732 size_t workspace_bytes = 0; 733 workspace_tf_shape.AddDim(workspace_bytes); 734 HandleEmptyInput(context, tf_shape_src, workspace_tf_shape, 735 scale_tensor.shape(), &dst_tensor); 736 return; 737 } 738 739 if (dnn_shape_src.IsMklTensor()) 740 depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C); 741 else 742 ExtractParams(context); 743 744 // Index of output tensor(diff_src). 745 const size_t kDstIndex = 0; 746 747 // Allocate 5 output TF tensors. 748 Tensor* batch_mean_tensor = nullptr; 749 Tensor* batch_variance_tensor = nullptr; 750 Tensor* saved_mean_tensor = nullptr; 751 Tensor* saved_variance_tensor = nullptr; 752 Tensor* reserved_space_tensor = nullptr; 753 754 MklDnnData<T> src(&cpu_engine_); 755 MklDnnData<U> weights(&cpu_engine_); 756 MklDnnData<U> wksp(&cpu_engine_); 757 758 memory::format_tag dnn_fmt; 759 MklTensorFormat mkl_tensor_fmt; 760 if (dnn_shape_src.IsMklTensor()) { 761 if (dnn_shape_src.IsTensorInNCHWFormat()) { 762 dnn_fmt = memory::format_tag::nchw; 763 mkl_tensor_fmt = MklTensorFormat::FORMAT_NCHW; 764 } else { 765 dnn_fmt = memory::format_tag::nhwc; 766 mkl_tensor_fmt = MklTensorFormat::FORMAT_NHWC; 767 } 768 } else { 769 mkl_tensor_fmt = TFDataFormatToMklDnnDataFormat(tensor_format_); 770 dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_tensor_fmt); 771 } 772 773 // Set src memory descriptor. 774 memory::dims src_dims = 775 dnn_shape_src.IsMklTensor() 776 ? dnn_shape_src.GetSizesAsMklDnnDims() 777 : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); 778 779 auto src_md = dnn_shape_src.IsMklTensor() 780 ? dnn_shape_src.GetMklLayout() 781 : memory::desc(src_dims, MklDnnType<T>(), dnn_fmt); 782 783 MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_, 784 tensor_format_, src_md, activation_mode_); 785 786 // Get forward batch-normalization op from the primitive caching pool. 787 MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd = 788 MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams); 789 790 // Allocate workspace tensor 791 U* ws_data = nullptr; 792 if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) { 793 memory::desc workspace_md = 794 bn_fwd->GetBatchNormFwdPd()->workspace_desc(); 795 size_t workspace_bytes = workspace_md.get_size(); 796 workspace_tf_shape.AddDim(workspace_bytes); 797 798 AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape, 799 &batch_mean_tensor, &batch_variance_tensor, 800 &saved_mean_tensor, &saved_variance_tensor, 801 &reserved_space_tensor); 802 if (reserved_space) { 803 wksp.SetUsrMem(workspace_md, reserved_space_tensor); 804 ws_data = static_cast<U*>(wksp.GetOpMem().get_data_handle()); 805 } 806 } else { 807 // There is actually no workspace tensor out, so we make a dummy one. 
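        // (FusedBatchNormV3 still declares a reserved_space output, so a
        // zero-length tensor is allocated below to satisfy the signature.)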
808 size_t workspace_bytes = 0; 809 workspace_tf_shape.AddDim(workspace_bytes); 810 AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape, 811 &batch_mean_tensor, &batch_variance_tensor, 812 &saved_mean_tensor, &saved_variance_tensor, 813 &reserved_space_tensor); 814 } 815 816 if (is_training_) 817 SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor); 818 else 819 SetMeanVariance(est_mean_tensor, est_variance_tensor); 820 821 // MKL-DNN packs scale & shift as "weights": 822 // <scale>...<scale><shift>...<shift> 823 weights.AllocateBuffer(2 * depth_ * sizeof(U)); 824 U* weights_data = reinterpret_cast<U*>(weights.GetAllocatedBuffer()); 825 const U* scale_tf = scale_tensor.flat<U>().data(); 826 const U* shift_tf = shift_tensor.flat<U>().data(); 827 828 std::memcpy(weights_data, scale_tf, depth_ * sizeof(U)); 829 std::memcpy(weights_data + depth_, shift_tf, depth_ * sizeof(U)); 830 char* saved_mean_data_tf = 831 reinterpret_cast<char*>(saved_mean_tensor->flat<U>().data()); 832 std::memcpy(saved_mean_data_tf, reinterpret_cast<char*>(mean_values_), 833 depth_ * sizeof(U)); 834 835 char* saved_variance_data_tf = 836 reinterpret_cast<char*>(saved_variance_tensor->flat<U>().data()); 837 std::memcpy(saved_variance_data_tf, 838 reinterpret_cast<char*>(variance_values_), 839 depth_ * sizeof(U)); 840 841 // Check if reorder is needed for src. 842 const T* src_data = nullptr; 843 std::shared_ptr<BatchNormFwdPd> bn_fwd_pd = bn_fwd->GetBatchNormFwdPd(); 844 if (!native_format && src_md != bn_fwd_pd->src_desc()) { 845 src.SetUsrMem(src_md, &src_tensor); 846 src.CheckReorderToOpMem(bn_fwd_pd->src_desc(), cpu_engine_, context); 847 src_data = static_cast<T*>(src.GetOpMem().get_data_handle()); 848 } else { 849 src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data())); 850 } 851 852 // Allocate output (dst) tensor 853 MklDnnShape dnn_shape_dst; 854 TensorShape tf_shape_dst; 855 dnn_shape_dst.SetMklTensor(true); 856 auto dst_pd = bn_fwd->GetDstPd(); 857 dnn_shape_dst.SetMklLayout(&dst_pd); 858 dnn_shape_dst.SetElemType(MklDnnType<T>()); 859 auto ndims = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetDimension() 860 : src_tensor.shape().dims(); 861 dnn_shape_dst.SetTfLayout(ndims, src_dims, mkl_tensor_fmt); 862 tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T)); 863 if (native_format) { 864 tf_shape_dst = dnn_shape_dst.GetTfShape(); 865 } 866 AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst, 867 dnn_shape_dst, native_format); 868 869 U* weights_op_data = weights_data; 870 U* mean_op_data = saved_mean_tensor->flat<U>().data(); 871 U* variance_op_data = saved_variance_tensor->flat<U>().data(); 872 T* dst_data = dst_tensor->flat<T>().data(); 873 874 // Execute 875 std::shared_ptr<stream> fwd_cpu_stream; 876 MklDnnThreadPool eigen_tp(context); 877 fwd_cpu_stream.reset(CreateStream(&eigen_tp, bn_fwd->GetEngine())); 878 bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data, 879 variance_op_data, fwd_cpu_stream, ws_data); 880 float adjust_factor = 1.0; 881 if (is_training_) { 882 size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3]; 883 size_t adjust_size = (orig_size > 1) ? 
            (orig_size - 1) : 1;
        adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
      }

      auto mean_data = reinterpret_cast<U*>(saved_mean_data_tf);
      auto variance_data = reinterpret_cast<U*>(saved_variance_data_tf);
      auto batch_mean_data = batch_mean_tensor->flat<U>().data();
      auto batch_variance_data = batch_variance_tensor->flat<U>().data();
      auto est_mean_data = est_mean_tensor.flat<U>().data();
      auto est_variance_data = est_variance_tensor.flat<U>().data();
      if (is_training_) {
        if (exponential_avg_factor_ == U(1.0)) {
          for (int k = 0; k < depth_; k++) {
            batch_mean_data[k] = mean_data[k];
            batch_variance_data[k] =
                static_cast<U>(adjust_factor) * variance_data[k];
          }
        } else {
          U one_minus_factor = U(1.0) - exponential_avg_factor_;
          for (int k = 0; k < depth_; k++) {
            batch_mean_data[k] = one_minus_factor * est_mean_data[k] +
                                 exponential_avg_factor_ * mean_data[k];
            batch_variance_data[k] = one_minus_factor * est_variance_data[k] +
                                     exponential_avg_factor_ *
                                         static_cast<U>(adjust_factor) *
                                         variance_data[k];
          }
        }
      } else {
        std::memcpy(batch_mean_data, mean_data, depth_ * sizeof(U));
        std::memcpy(batch_variance_data, variance_data, depth_ * sizeof(U));
      }
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  float epsilon_;
  U exponential_avg_factor_;
  TensorFormat tensor_format_;
  bool is_training_;
  U* mean_values_;
  U* variance_values_;
  size_t depth_;  // Batch normalization is performed per channel.
933 FusedBNActivationMode activation_mode_; 934 engine cpu_engine_ = engine(engine::kind::cpu, 0); 935 ExtractParams(OpKernelContext * context)936 void ExtractParams(OpKernelContext* context) { 937 const Tensor& input = MklGetInput(context, 0); 938 depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C')); 939 } 940 SetMeanVariance(const Tensor & mean,const Tensor & variance)941 void SetMeanVariance(const Tensor& mean, const Tensor& variance) { 942 mean_values_ = reinterpret_cast<U*>(const_cast<U*>(mean.flat<U>().data())); 943 variance_values_ = 944 reinterpret_cast<U*>(const_cast<U*>(variance.flat<U>().data())); 945 } 946 HandleEmptyInput(OpKernelContext * context,TensorShape tf_shape_src,TensorShape workspace_tf_shape,TensorShape tf_shape_scale,Tensor ** dst_tensor)947 void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src, 948 TensorShape workspace_tf_shape, 949 TensorShape tf_shape_scale, Tensor** dst_tensor) { 950 DCHECK(dst_tensor); 951 952 const size_t kDstIndex = 0; 953 MklDnnShape dnn_shape_dst; 954 dnn_shape_dst.SetMklTensor(false); 955 AllocateOutputSetMklShape(context, kDstIndex, dst_tensor, tf_shape_src, 956 dnn_shape_dst, native_format); 957 DCHECK(*dst_tensor); 958 memset(const_cast<char*>((*dst_tensor)->tensor_data().data()), 0, 959 (*dst_tensor)->tensor_data().size()); 960 961 Tensor* batch_mean_tensor = nullptr; 962 Tensor* batch_variance_tensor = nullptr; 963 Tensor* saved_mean_tensor = nullptr; 964 Tensor* saved_variance_tensor = nullptr; 965 Tensor* reserved_space_tensor = nullptr; 966 AllocateTFOutputs(context, tf_shape_scale, workspace_tf_shape, 967 &batch_mean_tensor, &batch_variance_tensor, 968 &saved_mean_tensor, &saved_variance_tensor, 969 &reserved_space_tensor); 970 } 971 AllocateTFOutputs(OpKernelContext * context,TensorShape tf_shape_scale,TensorShape workspace_tf_shape,Tensor ** batch_mean_tensor,Tensor ** batch_variance_tensor,Tensor ** saved_mean_tensor,Tensor ** saved_variance_tensor,Tensor ** reserved_space_tensor)972 void AllocateTFOutputs(OpKernelContext* context, TensorShape tf_shape_scale, 973 TensorShape workspace_tf_shape, 974 Tensor** batch_mean_tensor, 975 Tensor** batch_variance_tensor, 976 Tensor** saved_mean_tensor, 977 Tensor** saved_variance_tensor, 978 Tensor** reserved_space_tensor) { 979 DCHECK(batch_mean_tensor); 980 DCHECK(batch_variance_tensor); 981 DCHECK(saved_mean_tensor); 982 DCHECK(saved_variance_tensor); 983 984 const size_t kBatchMeanIndex = 1; 985 const size_t kBatchVarianceIndex = 2; 986 const size_t kSavedMeanIndex = 3; 987 const size_t kSavedVarianceIndex = 4; 988 const size_t kReservedSpaceIndex = 5; 989 990 // Allocate batch mean output tensor. 991 MklDnnShape mkl_shape_batch_mean; 992 mkl_shape_batch_mean.SetMklTensor(false); 993 AllocateOutputSetMklShape(context, kBatchMeanIndex, batch_mean_tensor, 994 tf_shape_scale, mkl_shape_batch_mean, 995 native_format); 996 DCHECK(*batch_mean_tensor); 997 998 // Set NAN mean value in case of empty input tensor 999 int num_elements = tf_shape_scale.num_elements(); 1000 auto batch_mean_data = (*batch_mean_tensor)->flat<U>().data(); 1001 std::fill_n(batch_mean_data, num_elements, static_cast<U>(NAN)); 1002 1003 // Allocate batch variance output tensor. 
1004 MklDnnShape mkl_shape_batch_variance; 1005 mkl_shape_batch_variance.SetMklTensor(false); 1006 AllocateOutputSetMklShape(context, kBatchVarianceIndex, 1007 batch_variance_tensor, tf_shape_scale, 1008 mkl_shape_batch_variance, native_format); 1009 DCHECK(*batch_variance_tensor); 1010 1011 // Set NAN variance value in case of empty input tensor 1012 auto batch_variance_data = (*batch_variance_tensor)->flat<U>().data(); 1013 std::fill_n(batch_variance_data, num_elements, static_cast<U>(NAN)); 1014 // Mean and variance (without Bessel's correction) saved for backward 1015 // computation to serve as pre-computed mean and variance. 1016 MklDnnShape mkl_shape_saved_mean; 1017 mkl_shape_saved_mean.SetMklTensor(false); 1018 AllocateOutputSetMklShape(context, kSavedMeanIndex, saved_mean_tensor, 1019 tf_shape_scale, mkl_shape_saved_mean, 1020 native_format); 1021 DCHECK(*saved_mean_tensor); 1022 1023 // Set 0 mean value in case of empty input tensor 1024 auto saved_mean_data = (*saved_mean_tensor)->flat<U>().data(); 1025 std::fill_n(saved_mean_data, num_elements, static_cast<U>(0)); 1026 1027 MklDnnShape mkl_shape_saved_variance; 1028 mkl_shape_saved_variance.SetMklTensor(false); 1029 AllocateOutputSetMklShape(context, kSavedVarianceIndex, 1030 saved_variance_tensor, tf_shape_scale, 1031 mkl_shape_saved_variance, native_format); 1032 DCHECK(*saved_variance_tensor); 1033 1034 // Set 0 variance value in case of empty input tensor 1035 auto saved_variance_data = (*saved_variance_tensor)->flat<U>().data(); 1036 std::fill_n(saved_variance_data, num_elements, static_cast<U>(0)); 1037 1038 // Changes to support reserved_space_3 parameter in FusedBatchNormV3. 1039 if (reserved_space) { 1040 DCHECK(reserved_space_tensor != nullptr); 1041 1042 MklDnnShape mkl_shape_reserved_space; 1043 mkl_shape_reserved_space.SetMklTensor(false); 1044 AllocateOutputSetMklShape(context, kReservedSpaceIndex, 1045 reserved_space_tensor, workspace_tf_shape, 1046 mkl_shape_reserved_space, native_format); 1047 DCHECK((*reserved_space_tensor) != nullptr); 1048 } 1049 } 1050 }; 1051 1052 template <typename Device, typename T, typename U, bool reserved_space, 1053 bool native_format = false> 1054 class MklFusedBatchNormGradOp : public OpKernel { 1055 public: MklFusedBatchNormGradOp(OpKernelConstruction * context)1056 explicit MklFusedBatchNormGradOp(OpKernelConstruction* context) 1057 : OpKernel(context) { 1058 float epsilon; 1059 OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); 1060 epsilon_ = epsilon; 1061 string tensor_format; 1062 OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format)); 1063 OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_), 1064 errors::InvalidArgument("Invalid data format")); 1065 OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_)); 1066 depth_ = 0; 1067 } 1068 Compute(OpKernelContext * context)1069 void Compute(OpKernelContext* context) override { 1070 try { 1071 const size_t kDiffDstIndex = 0; // index of diff_dst tensor 1072 const size_t kSrcIndex = 1; // index of src input tensor 1073 const size_t kScaleIndex = 2; // index of scale tensor 1074 const size_t kMeanIndex = 3; // index of saved_mean tensor 1075 const size_t kVarianceIndex = 4; // index of saved_variance tensor 1076 const size_t kReservedSpaceIndex = 5; // index of reserved space 3 tensor 1077 1078 const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex); 1079 const Tensor& src_tensor = MklGetInput(context, kSrcIndex); 1080 const Tensor& scale_tensor = 
MklGetInput(context, kScaleIndex); 1081 const Tensor& saved_mean_tensor = MklGetInput(context, kMeanIndex); 1082 const Tensor& saved_variance_tensor = 1083 MklGetInput(context, kVarianceIndex); 1084 const Tensor& reserved_space_tensor = 1085 (reserved_space) ? MklGetInput(context, kReservedSpaceIndex) 1086 : Tensor(); 1087 1088 MklDnnShape dnn_shape_src, dnn_shape_diff_dst; 1089 GetMklShape(context, kSrcIndex, &dnn_shape_src, native_format); 1090 GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst, native_format); 1091 1092 TensorShape tf_shape_src, tf_shape_diff_dst; 1093 if (dnn_shape_diff_dst.IsMklTensor()) { 1094 tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape(); 1095 OP_REQUIRES( 1096 context, dnn_shape_diff_dst.GetDimension() == 4, 1097 errors::InvalidArgument("input must be 4-dimensional", 1098 diff_dst_tensor.shape().DebugString())); 1099 } else { 1100 tf_shape_diff_dst = diff_dst_tensor.shape(); 1101 OP_REQUIRES( 1102 context, diff_dst_tensor.dims() == 4, 1103 errors::InvalidArgument("input must be 4-dimensional", 1104 diff_dst_tensor.shape().DebugString())); 1105 } 1106 1107 if (dnn_shape_src.IsMklTensor()) { 1108 tf_shape_src = dnn_shape_src.GetTfShape(); 1109 OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4, 1110 errors::InvalidArgument("input must be 4-dimensional", 1111 src_tensor.shape().DebugString())); 1112 } else { 1113 tf_shape_src = src_tensor.shape(); 1114 OP_REQUIRES(context, src_tensor.dims() == 4, 1115 errors::InvalidArgument("input must be 4-dimensional", 1116 src_tensor.shape().DebugString())); 1117 } 1118 1119 OP_REQUIRES(context, scale_tensor.dims() == 1, 1120 errors::InvalidArgument("scale must be 1-dimensional", 1121 scale_tensor.shape().DebugString())); 1122 OP_REQUIRES( 1123 context, saved_mean_tensor.dims() == 1, 1124 errors::InvalidArgument("saved mean must be 1-dimensional", 1125 saved_mean_tensor.shape().DebugString())); 1126 1127 OP_REQUIRES( 1128 context, saved_variance_tensor.dims() == 1, 1129 errors::InvalidArgument("saved variance must be 1-dimensional", 1130 saved_variance_tensor.shape().DebugString())); 1131 1132 // Handle the special case: input with 0 element and 0 batch size. 1133 Tensor* diff_src_tensor = nullptr; 1134 if (tf_shape_src.num_elements() == 0 || 1135 tf_shape_diff_dst.num_elements() == 0) { 1136 HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(), 1137 &diff_src_tensor); 1138 return; 1139 } 1140 1141 if (dnn_shape_src.IsMklTensor()) { 1142 depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C); 1143 } else if (dnn_shape_diff_dst.IsMklTensor()) { 1144 depth_ = dnn_shape_diff_dst.DimSize(MklDnnDims::Dim_C); 1145 } else { 1146 ExtractParams(context); 1147 } 1148 1149 memory::format_tag dnn_fmt; 1150 MklTensorFormat mkl_tensor_fmt; 1151 if (dnn_shape_src.IsMklTensor()) { 1152 if (dnn_shape_src.IsTensorInNCHWFormat()) { 1153 dnn_fmt = memory::format_tag::nchw; 1154 mkl_tensor_fmt = MklTensorFormat::FORMAT_NCHW; 1155 } else { 1156 dnn_fmt = memory::format_tag::nhwc; 1157 mkl_tensor_fmt = MklTensorFormat::FORMAT_NHWC; 1158 } 1159 } else { 1160 mkl_tensor_fmt = TFDataFormatToMklDnnDataFormat(tensor_format_); 1161 dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_tensor_fmt); 1162 } 1163 1164 MklDnnData<T> src(&cpu_engine_); 1165 MklDnnData<T> diff_dst(&cpu_engine_); 1166 MklDnnData<U> weights(&cpu_engine_); 1167 MklDnnData<U> diff_weights(&cpu_engine_); 1168 1169 memory::dims src_dims = 1170 dnn_shape_src.IsMklTensor() 1171 ? 
dnn_shape_src.GetSizesAsMklDnnDims() 1172 : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); 1173 memory::dims diff_dst_dims = 1174 dnn_shape_diff_dst.IsMklTensor() 1175 ? dnn_shape_diff_dst.GetSizesAsMklDnnDims() 1176 : TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), 1177 tensor_format_); 1178 1179 // Set src and diff_dst primitive descriptors. 1180 memory::desc src_md = 1181 dnn_shape_src.IsMklTensor() 1182 ? dnn_shape_src.GetMklLayout() 1183 : memory::desc(src_dims, MklDnnType<T>(), dnn_fmt); 1184 memory::desc diff_dst_md = 1185 dnn_shape_diff_dst.IsMklTensor() 1186 ? dnn_shape_diff_dst.GetMklLayout() 1187 : memory::desc(diff_dst_dims, MklDnnType<T>(), dnn_fmt); 1188 1189 MklDnnData<T> reorder_src(&cpu_engine_); 1190 MklDnnData<T> reorder_diff_dst(&cpu_engine_); 1191 T* diff_dst_data = 1192 static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data())); 1193 T* src_data = 1194 static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data())); 1195 1196 if (!native_format) { 1197 // MKL-DNN requires src and diff_dst to be in same memory layout, either 1198 // blocked or native format. If these inputs are in different formats, 1199 // convert the one in native format to blocked format as MKL-DNN gives 1200 // better performance for blocked format. 1201 if (dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) { 1202 reorder_diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); 1203 reorder_diff_dst.CheckReorderToOpMem(src_md, cpu_engine_, context); 1204 diff_dst_md = src_md; 1205 diff_dst_data = 1206 static_cast<T*>(reorder_diff_dst.GetOpMem().get_data_handle()); 1207 } else if (!dnn_shape_src.IsMklTensor() && 1208 dnn_shape_diff_dst.IsMklTensor()) { 1209 reorder_src.SetUsrMem(src_md, &src_tensor); 1210 reorder_src.CheckReorderToOpMem(diff_dst_md, cpu_engine_, context); 1211 src_md = diff_dst_md; 1212 src_data = static_cast<T*>(reorder_src.GetOpMem().get_data_handle()); 1213 } 1214 } 1215 1216 // weights -- MKL DNN packs scales/ shifts as weights in order 1217 // of scale, ..., scale, shift, ...., shift 1218 weights.AllocateBuffer(2 * depth_ * sizeof(U)); 1219 U* weights_data_tf = reinterpret_cast<U*>(weights.GetAllocatedBuffer()); 1220 const U* scale_tf = scale_tensor.flat<U>().data(); 1221 for (int k = 0; k < depth_; k++) { 1222 weights_data_tf[k] = scale_tf[k]; 1223 weights_data_tf[k + depth_] = static_cast<U>(0); 1224 } 1225 1226 diff_weights.AllocateBuffer(2 * depth_ * sizeof(U)); 1227 1228 MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_, 1229 is_training_, tensor_format_, src_md, 1230 diff_dst_md); 1231 MklFusedBatchNormBwdPrimitive<T, U>* bn_bwd = 1232 MklFusedBatchNormBwdPrimitiveFactory<T, U>::Get(bwdParams); 1233 1234 // Check if diff_dst input needs to be reordered 1235 std::shared_ptr<BatchNormBwdPd> bn_bwd_pd = bn_bwd->GetBatchNormBwdPd(); 1236 if (!native_format && diff_dst_md != bn_bwd_pd->diff_dst_desc()) { 1237 diff_dst.SetUsrMem(diff_dst_md, diff_dst_data); 1238 diff_dst.CheckReorderToOpMem(bn_bwd_pd->diff_dst_desc(), cpu_engine_, 1239 context); 1240 diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle()); 1241 } 1242 1243 if (!native_format && (src_md != bn_bwd_pd->src_desc())) { 1244 src.SetUsrMem(src_md, src_data); 1245 src.CheckReorderToOpMem(bn_bwd_pd->src_desc(), cpu_engine_, context); 1246 src_data = static_cast<T*>(src.GetOpMem().get_data_handle()); 1247 } 1248 1249 // Indices of output tensors 1250 const size_t kDiffSrcIndex = 0; 1251 1252 // Allocate output tensor diff_src, always set as MKL-DNN 
layout. 1253 MklDnnShape dnn_shape_diff_src; 1254 TensorShape tf_shape_diff_src; 1255 dnn_shape_diff_src.SetMklTensor(true); 1256 auto diff_src_pd = bn_bwd->GetDiffSrcPd(); 1257 dnn_shape_diff_src.SetMklLayout(&diff_src_pd); 1258 dnn_shape_diff_src.SetElemType(MklDnnType<T>()); 1259 dnn_shape_diff_src.SetTfLayout(src_dims.size(), src_dims, mkl_tensor_fmt); 1260 dnn_shape_diff_src.SetTfDimOrder(src_dims.size(), tensor_format_); 1261 tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T)); 1262 if (native_format) { 1263 tf_shape_diff_src = dnn_shape_diff_src.GetTfShape(); 1264 } 1265 AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor, 1266 tf_shape_diff_src, dnn_shape_diff_src, 1267 native_format); 1268 1269 U* mean_data = 1270 static_cast<U*>(const_cast<U*>(saved_mean_tensor.flat<U>().data())); 1271 U* variance_data = static_cast<U*>( 1272 const_cast<U*>(saved_variance_tensor.flat<U>().data())); 1273 U* weights_data = weights_data_tf; 1274 T* diff_src_data = static_cast<T*>(diff_src_tensor->flat<T>().data()); 1275 U* diff_weights_data = static_cast<U*>(diff_weights.GetAllocatedBuffer()); 1276 1277 U* res_space_data = 1278 ((reserved_space) ? static_cast<U*>(const_cast<U*>( 1279 reserved_space_tensor.flat<U>().data())) 1280 : nullptr); 1281 1282 // Execute 1283 std::shared_ptr<stream> bwd_cpu_stream; 1284 MklDnnThreadPool eigen_tp(context); 1285 bwd_cpu_stream.reset(CreateStream(&eigen_tp, bn_bwd->GetEngine())); 1286 bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data, 1287 weights_data, diff_src_data, diff_weights_data, 1288 res_space_data, bwd_cpu_stream); 1289 // Allocate output TF tensors diff_scale and diff_shift. 1290 Tensor* diff_scale_tensor = nullptr; 1291 Tensor* diff_shift_tensor = nullptr; 1292 AllocateTFOutputs(context, scale_tensor.shape(), &diff_scale_tensor, 1293 &diff_shift_tensor); 1294 1295 // Copy data for tensors diff_scale and diff_shift. 1296 auto diff_scale_data = diff_scale_tensor->flat<U>().data(); 1297 auto diff_shift_data = diff_shift_tensor->flat<U>().data(); 1298 std::memcpy(reinterpret_cast<char*>(diff_scale_data), 1299 reinterpret_cast<char*>(diff_weights_data), 1300 depth_ * sizeof(U)); 1301 std::memcpy(reinterpret_cast<char*>(diff_shift_data), 1302 reinterpret_cast<char*>(diff_weights_data + depth_), 1303 depth_ * sizeof(U)); 1304 } catch (mkldnn::error& e) { 1305 string error_msg = "Status: " + std::to_string(e.status) + 1306 ", message: " + string(e.message) + ", in file " + 1307 string(__FILE__) + ":" + std::to_string(__LINE__); 1308 OP_REQUIRES_OK( 1309 context, 1310 errors::Aborted("Operation received an exception:", error_msg)); 1311 } 1312 } 1313 1314 private: 1315 float epsilon_; 1316 TensorFormat tensor_format_; 1317 size_t depth_; // Batch normalization is performed for per channel. 
1318 bool is_training_; 1319 engine cpu_engine_ = engine(engine::kind::cpu, 0); 1320 ExtractParams(OpKernelContext * context)1321 void ExtractParams(OpKernelContext* context) { 1322 const Tensor& input = MklGetInput(context, 0); 1323 depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C')); 1324 } 1325 HandleEmptyInput(OpKernelContext * context,TensorShape tf_shape_src,TensorShape tf_shape_scale_shift,Tensor ** diff_src_tensor)1326 void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src, 1327 TensorShape tf_shape_scale_shift, 1328 Tensor** diff_src_tensor) { 1329 const size_t kDiffSrcIndex = 0; 1330 1331 MklDnnShape dnn_shape_diff_src; 1332 dnn_shape_diff_src.SetMklTensor(false); 1333 AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor, 1334 tf_shape_src, dnn_shape_diff_src, native_format); 1335 auto diff_src_data = (*diff_src_tensor)->flat<T>().data(); 1336 std::fill_n(diff_src_data, (*diff_src_tensor)->shape().num_elements(), 1337 static_cast<T>(0)); 1338 1339 Tensor* diff_scale_tensor = nullptr; 1340 Tensor* diff_shift_tensor = nullptr; 1341 AllocateTFOutputs(context, tf_shape_scale_shift, &diff_scale_tensor, 1342 &diff_shift_tensor); 1343 } 1344 AllocateTFOutputs(OpKernelContext * context,TensorShape tf_shape_scale_shift,Tensor ** diff_scale_tensor,Tensor ** diff_shift_tensor)1345 void AllocateTFOutputs(OpKernelContext* context, 1346 TensorShape tf_shape_scale_shift, 1347 Tensor** diff_scale_tensor, 1348 Tensor** diff_shift_tensor) { 1349 DCHECK(diff_scale_tensor); 1350 DCHECK(diff_shift_tensor); 1351 1352 const size_t kDiffScaleIndex = 1; 1353 const size_t kDiffShiftIndex = 2; 1354 const size_t kP1Index = 3; 1355 const size_t kP2Index = 4; 1356 1357 // Separate out scale and shift grad and copy to individual tensors 1358 MklDnnShape mkl_shape_diff_scale; 1359 mkl_shape_diff_scale.SetMklTensor(false); 1360 AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor, 1361 tf_shape_scale_shift, mkl_shape_diff_scale, 1362 native_format); 1363 DCHECK(*diff_scale_tensor); 1364 1365 auto diff_scale_data = (*diff_scale_tensor)->flat<U>().data(); 1366 std::fill_n(diff_scale_data, (*diff_scale_tensor)->shape().num_elements(), 1367 static_cast<U>(0)); 1368 1369 MklDnnShape mkl_shape_diff_shift; 1370 mkl_shape_diff_shift.SetMklTensor(false); 1371 AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor, 1372 tf_shape_scale_shift, mkl_shape_diff_shift, 1373 native_format); 1374 DCHECK(*diff_shift_tensor); 1375 1376 auto diff_shift_data = (*diff_shift_tensor)->flat<U>().data(); 1377 std::fill_n(diff_shift_data, (*diff_shift_tensor)->shape().num_elements(), 1378 static_cast<U>(0)); 1379 1380 // Placeholders for estimated_mean and estimated_variance, which are 1381 // used for inference and thus not needed here for gradient computation. 
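    // Both placeholders are allocated as zero-filled scalars so the op still
    // produces all five outputs that FusedBatchNormGrad declares.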
1382 Tensor *p1_tensor = nullptr, *p2_tensor = nullptr; 1383 MklDnnShape mkl_shape_p; 1384 mkl_shape_p.SetMklTensor(false); 1385 AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}), 1386 mkl_shape_p, native_format); 1387 std::fill_n(p1_tensor->flat<U>().data(), p1_tensor->shape().num_elements(), 1388 static_cast<U>(0)); 1389 AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}), 1390 mkl_shape_p, native_format); 1391 std::fill_n(p2_tensor->flat<U>().data(), p2_tensor->shape().num_elements(), 1392 static_cast<U>(0)); 1393 } 1394 GetMeanVarianceDims()1395 memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); } 1396 }; 1397 1398 #define REGISTER_MKL_FUSED_BATCHNORM_CPU(T) \ 1399 REGISTER_KERNEL_BUILDER( \ 1400 Name("_MklFusedBatchNorm") \ 1401 .Device(DEVICE_CPU) \ 1402 .TypeConstraint<T>("T") \ 1403 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 1404 MklFusedBatchNormOp<CPUDevice, T, T, false, false>); \ 1405 REGISTER_KERNEL_BUILDER( \ 1406 Name("_MklNativeFusedBatchNorm") \ 1407 .Device(DEVICE_CPU) \ 1408 .TypeConstraint<T>("T") \ 1409 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 1410 MklFusedBatchNormOp<CPUDevice, T, T, false, false, true>); 1411 1412 TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_CPU); 1413 TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU); 1414 #undef REGISTER_MKL_FUSED_BATCHNORM_CPU 1415 1416 #define REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(T, U) \ 1417 REGISTER_KERNEL_BUILDER( \ 1418 Name("_MklFusedBatchNormV2") \ 1419 .Device(DEVICE_CPU) \ 1420 .TypeConstraint<T>("T") \ 1421 .TypeConstraint<U>("U") \ 1422 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 1423 MklFusedBatchNormOp<CPUDevice, T, U, false, false>); \ 1424 REGISTER_KERNEL_BUILDER( \ 1425 Name("_MklNativeFusedBatchNormV2") \ 1426 .Device(DEVICE_CPU) \ 1427 .TypeConstraint<T>("T") \ 1428 .TypeConstraint<U>("U") \ 1429 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 1430 MklFusedBatchNormOp<CPUDevice, T, U, false, false, true>); 1431 1432 REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(float, float); 1433 REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(bfloat16, float); 1434 #undef REGISTER_MKL_FUSED_BATCHNORM_V2_CPU 1435 1436 #define REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU(T) \ 1437 REGISTER_KERNEL_BUILDER( \ 1438 Name("_MklFusedBatchNormGrad") \ 1439 .Device(DEVICE_CPU) \ 1440 .TypeConstraint<T>("T") \ 1441 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 1442 MklFusedBatchNormGradOp<CPUDevice, T, T, false>); \ 1443 REGISTER_KERNEL_BUILDER( \ 1444 Name("_MklNativeFusedBatchNormGrad") \ 1445 .Device(DEVICE_CPU) \ 1446 .TypeConstraint<T>("T") \ 1447 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 1448 MklFusedBatchNormGradOp<CPUDevice, T, T, false, true>); 1449 1450 TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU); 1451 TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU); 1452 #undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU 1453 1454 #define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(T, U) \ 1455 REGISTER_KERNEL_BUILDER( \ 1456 Name("_MklFusedBatchNormGradV2") \ 1457 .Device(DEVICE_CPU) \ 1458 .TypeConstraint<T>("T") \ 1459 .TypeConstraint<U>("U") \ 1460 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 1461 MklFusedBatchNormGradOp<CPUDevice, T, U, false>); \ 1462 REGISTER_KERNEL_BUILDER( \ 1463 Name("_MklNativeFusedBatchNormGradV2") \ 1464 .Device(DEVICE_CPU) \ 1465 .TypeConstraint<T>("T") \ 1466 .TypeConstraint<U>("U") \ 1467 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 1468 MklFusedBatchNormGradOp<CPUDevice, T, U, false, true>); 1469 1470 
REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(float, float); 1471 REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(bfloat16, float); 1472 #undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU 1473 1474 // TODO: FusedBatchNormV3 has an additional output that is used to 1475 // hold intermediate results. This parameter functionality is 1476 // not implemented on CPU. 1477 #define REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(T, U) \ 1478 REGISTER_KERNEL_BUILDER( \ 1479 Name("_MklFusedBatchNormV3") \ 1480 .Device(DEVICE_CPU) \ 1481 .TypeConstraint<T>("T") \ 1482 .TypeConstraint<U>("U") \ 1483 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 1484 MklFusedBatchNormOp<CPUDevice, T, U, true, false>); \ 1485 REGISTER_KERNEL_BUILDER( \ 1486 Name("_MklFusedBatchNormEx") \ 1487 .Device(DEVICE_CPU) \ 1488 .TypeConstraint<T>("T") \ 1489 .TypeConstraint<U>("U") \ 1490 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 1491 MklFusedBatchNormOp<CPUDevice, T, U, true, true>); \ 1492 REGISTER_KERNEL_BUILDER( \ 1493 Name("_MklNativeFusedBatchNormV3") \ 1494 .Device(DEVICE_CPU) \ 1495 .TypeConstraint<T>("T") \ 1496 .TypeConstraint<U>("U") \ 1497 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 1498 MklFusedBatchNormOp<CPUDevice, T, U, true, false, true>); \ 1499 REGISTER_KERNEL_BUILDER( \ 1500 Name("_MklNativeFusedBatchNormEx") \ 1501 .Device(DEVICE_CPU) \ 1502 .TypeConstraint<T>("T") \ 1503 .TypeConstraint<U>("U") \ 1504 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 1505 MklFusedBatchNormOp<CPUDevice, T, U, true, true, true>); 1506 1507 REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(float, float); 1508 REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(bfloat16, float); 1509 #undef REGISTER_MKL_FUSED_BATCHNORM_V3_CPU 1510 1511 REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx") 1512 .Device(DEVICE_CPU) 1513 .TypeConstraint<float>("T") 1514 .TypeConstraint<float>("U"), 1515 NoOp); 1516 REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx") 1517 .Device(DEVICE_CPU) 1518 .TypeConstraint<bfloat16>("T") 1519 .TypeConstraint<float>("U"), 1520 NoOp); 1521 1522 #define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(T, U) \ 1523 REGISTER_KERNEL_BUILDER( \ 1524 Name("_MklFusedBatchNormGradV3") \ 1525 .Device(DEVICE_CPU) \ 1526 .TypeConstraint<T>("T") \ 1527 .TypeConstraint<U>("U") \ 1528 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 1529 MklFusedBatchNormGradOp<CPUDevice, T, U, true>); \ 1530 REGISTER_KERNEL_BUILDER( \ 1531 Name("_MklNativeFusedBatchNormGradV3") \ 1532 .Device(DEVICE_CPU) \ 1533 .TypeConstraint<T>("T") \ 1534 .TypeConstraint<U>("U") \ 1535 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 1536 MklFusedBatchNormGradOp<CPUDevice, T, U, true, true>); 1537 1538 REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(float, float); 1539 REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(bfloat16, float); 1540 #undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU 1541 1542 } // namespace tensorflow 1543 1544 #undef GET_FLAG 1545 #undef IS_SET 1546 1547 #endif // INTEL_MKL 1548