1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_ 18 19 #include <string> 20 21 #include "tensorflow/core/lib/gtl/flatset.h" 22 #include "tensorflow/core/lib/strings/str_util.h" 23 #include "tensorflow/core/util/env_var.h" 24 25 namespace tensorflow { 26 namespace grappler { 27 28 // Represents the four lists of ops: the allow list, infer list, deny list, and 29 // clear list. These lists determine which ops are converted to fp16/bf16 30 // (referred to as 'f16' for short) and which ops stay as fp32. 31 class AutoMixedPrecisionLists { 32 public: ~AutoMixedPrecisionLists()33 virtual ~AutoMixedPrecisionLists() {} 34 35 // Returns the set of ops that are considered numerically-safe (for execution 36 // in f16), performance-critical, and can run in f16. These ops are always 37 // converted to f16. 38 virtual gtl::FlatSet<string> AllowList() = 0; 39 // Returns the set of ops that can run in f16 and are considered numerically- 40 // safe (for execution in f16), but which may be made unsafe by an upstream 41 // denylist op. 42 virtual gtl::FlatSet<string> InferList() = 0; 43 // Returns the set of ops that are considered numerically-dangerous (i.e., 44 // unsafe for execution in f16) and whose effects may also be observed in 45 // downstream nodes (e.g. for f16, in Exp -> Add, the Add is unsafe due to 46 // the Exp). 47 virtual gtl::FlatSet<string> DenyList() = 0; 48 // Returns the set of ops that do not have numerically-significant effects 49 // (i.e., they are always considered safe for execution in f16 precision), and 50 // can run in f16. 51 virtual gtl::FlatSet<string> ClearList() = 0; 52 53 protected: 54 // Adds or removes ops from list if certain environmental variables are set. UpdateList(const string & list_name,gtl::FlatSet<string> * list)55 static void UpdateList(const string& list_name, gtl::FlatSet<string>* list) { 56 CHECK(list_name == "ALLOWLIST" || list_name == "INFERLIST" || // Crash OK. 57 list_name == "DENYLIST" || list_name == "CLEARLIST" || 58 // TODO(reedwm): for bkwds compat; remove when no longer necessary: 59 list_name == "WHITELIST" || list_name == "GRAYLIST" || 60 list_name == "BLACKLIST"); 61 string add_env_var = 62 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_ADD"; 63 string remove_env_var = 64 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_REMOVE"; 65 string to_add, to_remove; 66 TF_CHECK_OK(ReadStringFromEnvVar(add_env_var, "", &to_add)); 67 TF_CHECK_OK(ReadStringFromEnvVar(remove_env_var, "", &to_remove)); 68 for (const auto& x : str_util::Split(to_add, ",")) { 69 list->insert(x); 70 } 71 for (const auto& x : str_util::Split(to_remove, ",")) { 72 list->erase(x); 73 } 74 } 75 76 // Subclasses should include these on the ClearList. AddTensorListOps(gtl::FlatSet<string> * list)77 static void AddTensorListOps(gtl::FlatSet<string>* list) { 78 // Note: if a data structure op (such as TensorListPopBack) is added here, 79 // IsTensorListReaderOp or IsTensorListWriterOp may need to be modified 80 // LINT.IfChange 81 constexpr const char* tensor_list_ops[] = { 82 "TensorListConcat", "TensorListConcatLists", 83 "TensorListConcatV2", "TensorListGather", 84 "TensorListGetItem", "TensorListPopBack", 85 "TensorListPushBack", "TensorListPushBackBatch", 86 "TensorListFromTensor", "TensorListScatter", 87 "TensorListScatterV2", "TensorListScatterIntoExistingList", 88 "TensorListSetItem", "TensorListSplit", 89 "TensorListStack"}; 90 // LINT.ThenChange(//tensorflow/core/grappler/optimizers/auto_mixed_precision.cc) 91 for (auto op : tensor_list_ops) { 92 list->insert(op); 93 } 94 } 95 }; 96 97 class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { 98 private: IsPseudoFastMath()99 static bool IsPseudoFastMath() { 100 string optimization_level; 101 TF_CHECK_OK( 102 ReadStringFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", 103 &optimization_level)); 104 optimization_level = str_util::Uppercase(optimization_level); 105 return optimization_level == "TENSOR_CORES_ONLY"; 106 } 107 108 public: AutoMixedPrecisionListsCuda(int cuda_version,int cudnn_version)109 AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version) 110 : cuda_version_(cuda_version), cudnn_version_(cudnn_version) {} 111 AllowList()112 gtl::FlatSet<string> AllowList() override { 113 auto list = gtl::FlatSet<string>{ 114 "BlockLSTM", 115 "BlockLSTMV2", 116 "BlockLSTMGrad", 117 "BlockLSTMGradV2", 118 "Conv2D", 119 "Conv2DBackpropFilter", 120 "Conv2DBackpropInput", 121 "CudnnRNN", 122 "CudnnRNNBackprop", 123 "CudnnRNNBackpropV2", 124 "CudnnRNNBackpropV3", 125 "CudnnRNNV2", 126 "CudnnRNNV3", 127 "Einsum", 128 "FusedConv2DBiasActivation", 129 "GRUBlockCell", 130 "GRUBlockCellGrad", 131 "LSTMBlockCell", 132 "LSTMBlockCellGrad", 133 "MatMul", 134 }; 135 #if TENSORFLOW_USE_ROCM 136 if (true) { 137 #else 138 if (cuda_version_ >= 9010) { 139 // Fp16 BatchMatMul is slow before CUDA 9.1. 140 #endif 141 list.insert("BatchMatMul"); 142 list.insert("BatchMatMulV2"); 143 } 144 if (cudnn_version_ >= 7602) { 145 // Fp16 3D conv is slow before CUDNN 7.6.2. 146 list.insert("Conv3D"); 147 list.insert("Conv3DBackpropFilter"); 148 list.insert("Conv3DBackpropFilterV2"); 149 list.insert("Conv3DBackpropInput"); 150 list.insert("Conv3DBackpropInputV2"); 151 } 152 if (cudnn_version_ >= 8000) { 153 list.insert("DepthwiseConv2dNative"); 154 list.insert("DepthwiseConv2dNativeBackpropFilter"); 155 list.insert("DepthwiseConv2dNativeBackpropInput"); 156 } 157 UpdateList("ALLOWLIST", &list); 158 // For backwards compatibility, keeping the original env variable here. 159 // TODO(reedwm): This should be removed if we don't have active users. 160 UpdateList("WHITELIST", &list); 161 162 return list; 163 } 164 165 gtl::FlatSet<string> InferList() override { 166 if (IsPseudoFastMath()) { 167 return gtl::FlatSet<string>{}; 168 } 169 170 auto list = gtl::FlatSet<string>{ 171 "Add", 172 "AddN", 173 "AddV2", 174 "AvgPool", 175 "AvgPool3D", 176 "AvgPool3DGrad", 177 "AvgPoolGrad", 178 "BiasAdd", 179 "BiasAddGrad", 180 "BiasAddV1", 181 "Elu", 182 "EluGrad", 183 "Erf", 184 "Erfc", 185 "FloorDiv", 186 "FusedBatchNormV2", 187 "FusedBatchNormGradV2", 188 "FusedBatchNormV3", 189 "FusedBatchNormGradV3", 190 "_FusedBatchNormEx", 191 "Inv", 192 "LeakyRelu", 193 "LeakyReluGrad", 194 "Log", 195 "Log1p", 196 "LogSoftmax", 197 "Mul", 198 "Prod", 199 "RealDiv", 200 "Reciprocal", 201 "Selu", 202 "SeluGrad", 203 "Sigmoid", 204 "SigmoidGrad", 205 "Softmax", 206 "Softplus", 207 "SoftplusGrad", 208 "Softsign", 209 "SoftsignGrad", 210 "Sqrt", 211 "Sub", 212 "Tanh", 213 "TanhGrad", 214 }; 215 UpdateList("INFERLIST", &list); 216 // For backwards compatibility, keeping the original env variable here. 217 // TODO(reedwm): This should be removed if we don't have active users. 218 UpdateList("GRAYLIST", &list); 219 return list; 220 } 221 222 gtl::FlatSet<string> DenyList() override { 223 if (IsPseudoFastMath()) { 224 return gtl::FlatSet<string>{}; 225 } 226 227 auto list = gtl::FlatSet<string>{ 228 "Exp", 229 "Expm1", 230 "L2Loss", 231 "Mean", 232 "Pow", 233 "SaveV2", 234 "SoftmaxCrossEntropyWithLogits", 235 "SparseSoftmaxCrossEntropyWithLogits", 236 "Sum", 237 }; 238 UpdateList("DENYLIST", &list); 239 // For backwards compatibility, keeping the original env variable here. 240 // TODO(reedwm): This should be removed if we don't have active users. 241 UpdateList("BLACKLIST", &list); 242 return list; 243 } 244 245 gtl::FlatSet<string> ClearList() override { 246 if (IsPseudoFastMath()) { 247 return gtl::FlatSet<string>{}; 248 } 249 250 auto list = gtl::FlatSet<string>{ 251 "Abs", 252 "ArgMax", 253 "ArgMin", 254 "BatchToSpace", 255 "BatchToSpaceND", 256 "BroadcastTo", 257 "Ceil", 258 "CheckNumerics", 259 "ClipByValue", 260 "Concat", 261 "ConcatV2", 262 "DepthToSpace", 263 "DynamicPartition", 264 "DynamicStitch", 265 "Enter", 266 "EnsureShape", 267 "Equal", 268 "Exit", 269 "ExpandDims", 270 "Fill", 271 "Floor", 272 "Gather", 273 "GatherNd", 274 "GatherV2", 275 "Greater", 276 "GreaterEqual", 277 "Identity", 278 "IdentityN", 279 "IsFinite", 280 "IsInf", 281 "IsNan", 282 "Less", 283 "LessEqual", 284 "Max", 285 "MaxPool", 286 "MaxPool3D", 287 "MaxPool3DGrad", 288 "MaxPool3DGradGrad", 289 "MaxPoolGrad", 290 "MaxPoolGradGrad", 291 "MaxPoolGradGradV2", 292 "MaxPoolGradV2", 293 "MaxPoolV2", 294 "Maximum", 295 "Merge", 296 "Min", 297 "Minimum", 298 "MirrorPad", 299 "MirrorPadGrad", 300 "Neg", 301 "NextIteration", 302 "NotEqual", 303 "OneHot", 304 "OnesLike", 305 "Pack", 306 "Pad", 307 "PadV2", 308 "PreventGradient", 309 "Rank", 310 "Relu", 311 "Relu6", 312 "Relu6Grad", 313 "ReluGrad", 314 "Reshape", 315 "ResizeNearestNeighbor", 316 "ResizeNearestNeighborGrad", 317 "Reverse", 318 "ReverseSequence", 319 "ReverseV2", 320 "Round", 321 "Select", 322 "SelectV2", 323 "Shape", 324 "ShapeN", 325 "Sign", 326 "Size", 327 "Slice", 328 "Snapshot", 329 "SpaceToBatch", 330 "SpaceToBatchND", 331 "SpaceToDepth", 332 "Split", 333 "SplitV", 334 "Squeeze", 335 "StopGradient", 336 "StridedSlice", 337 "StridedSliceGrad", 338 "Switch", 339 "Tile", 340 "TopK", 341 "TopKV2", 342 "Transpose", 343 "Unpack", 344 "Where", 345 "ZerosLike", 346 }; 347 AddTensorListOps(&list); 348 UpdateList("CLEARLIST", &list); 349 return list; 350 } 351 352 private: 353 int cuda_version_; 354 int cudnn_version_; 355 }; 356 357 class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { 358 public: AutoMixedPrecisionListsMkl()359 AutoMixedPrecisionListsMkl() {} 360 361 // Only ops which are supported by MKL in bfloat16 should be added to the 362 // allow list, infer list, or clear list. AllowList()363 gtl::FlatSet<string> AllowList() override { 364 auto list = gtl::FlatSet<string>{"Conv2D", 365 "Conv2DBackpropFilter", 366 "Conv2DBackpropInput", 367 "Conv3D", 368 "Conv3DBackpropFilterV2", 369 "Conv3DBackpropInputV2", 370 "DepthwiseConv2dNative", 371 "DepthwiseConv2dNativeBackpropFilter", 372 "DepthwiseConv2dNativeBackpropInput", 373 "MatMul", 374 "BatchMatMul", 375 "BatchMatMulV2"}; 376 377 UpdateList("ALLOWLIST", &list); 378 // For backwards compatibility, keeping the original env variable here. 379 // TODO(reedwm): This should be removed if we don't have active users. 380 UpdateList("WHITELIST", &list); 381 return list; 382 } 383 InferList()384 gtl::FlatSet<string> InferList() override { 385 auto list = gtl::FlatSet<string>{"Add", 386 "AddN", 387 "AddV2", 388 "AvgPool", 389 "AvgPool3D", 390 "AvgPool3DGrad", 391 "AvgPoolGrad", 392 "BiasAdd", 393 "BiasAddGrad", 394 "BiasAddV1", 395 "FusedBatchNormV2", 396 "FusedBatchNormGradV2", 397 "FusedBatchNormV3", 398 "FusedBatchNormGradV3", 399 "LeakyRelu", 400 "LeakyReluGrad", 401 "Mul", 402 "Sub", 403 "Elu", 404 "EluGrad", 405 "FloorDiv", 406 "_FusedBatchNormEx", 407 "Log", 408 "Log1p", 409 "LogSoftmax", 410 "Prod", 411 "RealDiv", 412 "Reciprocal", 413 "Selu", 414 "SeluGrad", 415 "Sigmoid", 416 "SigmoidGrad", 417 "Softmax", 418 "Softplus", 419 "SoftplusGrad", 420 "Softsign", 421 "SoftsignGrad", 422 "Sqrt", 423 "Tanh", 424 "TanhGrad"}; 425 UpdateList("INFERLIST", &list); 426 // For backwards compatibility, keeping the original env variable here. 427 // TODO(reedwm): This should be removed if we don't have active users. 428 UpdateList("GRAYLIST", &list); 429 return list; 430 } 431 DenyList()432 gtl::FlatSet<string> DenyList() override { 433 auto list = gtl::FlatSet<string>{ 434 "Exp", 435 "Expm1", 436 "L2Loss", 437 "Mean", 438 "Pow", 439 "SaveV2", 440 "SoftmaxCrossEntropyWithLogits", 441 "SparseSoftmaxCrossEntropyWithLogits", 442 "Sum", 443 }; 444 UpdateList("DENYLIST", &list); 445 // For backwards compatibility, keeping the original env variable here. 446 // TODO(reedwm): This should be removed if we don't have active users. 447 UpdateList("BLACKLIST", &list); 448 return list; 449 } 450 ClearList()451 gtl::FlatSet<string> ClearList() override { 452 auto list = gtl::FlatSet<string>{ 453 "Abs", 454 "ArgMax", 455 "ArgMin", 456 "BatchToSpace", 457 "BatchToSpaceND", 458 "BroadcastTo", 459 "Ceil", 460 "CheckNumerics", 461 "ClipByValue", 462 "Concat", 463 "ConcatV2", 464 "DepthToSpace", 465 "DynamicPartition", 466 "DynamicStitch", 467 "EnsureShape", 468 "Enter", 469 "Equal", 470 "Exit", 471 "ExpandDims", 472 "Fill", 473 "Floor", 474 "Gather", 475 "GatherNd", 476 "GatherV2", 477 "Greater", 478 "GreaterEqual", 479 "Identity", 480 "IsFinite", 481 "IsInf", 482 "IsNan", 483 "Less", 484 "LessEqual", 485 "Max", 486 "Maximum", 487 "MaxPool", 488 "MaxPool3D", 489 "MaxPool3DGrad", 490 "MaxPoolGrad", 491 "MaxPoolGradGrad", 492 "MaxPoolGradGradV2", 493 "MaxPoolGradV2", 494 "MaxPoolV2", 495 "Merge", 496 "Min", 497 "Minimum", 498 "MirrorPad", 499 "MirrorPadGrad", 500 "Neg", 501 "NextIteration", 502 "NotEqual", 503 "OnesLike", 504 "Pack", 505 "Pad", 506 "PadV2", 507 "PreventGradient", 508 "Rank", 509 "Relu", 510 "Relu6", 511 "Relu6Grad", 512 "ReluGrad", 513 "Reshape", 514 "ResizeNearestNeighbor", 515 "ResizeNearestNeighborGrad", 516 "Reverse", 517 "ReverseSequence", 518 "ReverseV2", 519 "Round", 520 "Select", 521 "SelectV2", 522 "Shape", 523 "ShapeN", 524 "Sign", 525 "Slice", 526 "Snapshot", 527 "SpaceToBatch", 528 "SpaceToBatchND", 529 "SpaceToDepth", 530 "Split", 531 "SplitV", 532 "Squeeze", 533 "StopGradient", 534 "StridedSlice", 535 "StridedSliceGrad", 536 "Switch", 537 "Tile", 538 "TopK", 539 "TopKV2", 540 "Transpose", 541 "Where", 542 "Unpack", 543 "ZerosLike", 544 }; 545 AddTensorListOps(&list); 546 UpdateList("CLEARLIST", &list); 547 return list; 548 } 549 }; 550 551 } // end namespace grappler 552 } // end namespace tensorflow 553 554 #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_ 555