• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_
18 
19 #include <string>
20 
21 #include "tensorflow/core/lib/gtl/flatset.h"
22 #include "tensorflow/core/lib/strings/str_util.h"
23 #include "tensorflow/core/util/env_var.h"
24 
25 namespace tensorflow {
26 namespace grappler {
27 
28 // Represents the four lists of ops: the allow list, infer list, deny list, and
29 // clear list. These lists determine which ops are converted to fp16/bf16
30 // (referred to as 'f16' for short) and which ops stay as fp32.
31 class AutoMixedPrecisionLists {
32  public:
~AutoMixedPrecisionLists()33   virtual ~AutoMixedPrecisionLists() {}
34 
35   // Returns the set of ops that are considered numerically-safe (for execution
36   // in f16), performance-critical, and can run in f16. These ops are always
37   // converted to f16.
38   virtual gtl::FlatSet<string> AllowList() = 0;
39   // Returns the set of ops that can run in f16 and are considered numerically-
40   // safe (for execution in f16), but which may be made unsafe by an upstream
41   // denylist op.
42   virtual gtl::FlatSet<string> InferList() = 0;
43   // Returns the set of ops that are considered numerically-dangerous (i.e.,
44   // unsafe for execution in f16) and whose effects may also be observed in
45   // downstream nodes (e.g. for f16, in Exp -> Add, the Add is unsafe due to
46   // the Exp).
47   virtual gtl::FlatSet<string> DenyList() = 0;
48   // Returns the set of ops that do not have numerically-significant effects
49   // (i.e., they are always considered safe for execution in f16 precision), and
50   // can run in f16.
51   virtual gtl::FlatSet<string> ClearList() = 0;
52 
53  protected:
54   // Adds or removes ops from list if certain environmental variables are set.
UpdateList(const string & list_name,gtl::FlatSet<string> * list)55   static void UpdateList(const string& list_name, gtl::FlatSet<string>* list) {
56     CHECK(list_name == "ALLOWLIST" || list_name == "INFERLIST" ||  // Crash OK.
57           list_name == "DENYLIST" || list_name == "CLEARLIST" ||
58           // TODO(reedwm): for bkwds compat; remove when no longer necessary:
59           list_name == "WHITELIST" || list_name == "GRAYLIST" ||
60           list_name == "BLACKLIST");
61     string add_env_var =
62         "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_ADD";
63     string remove_env_var =
64         "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_REMOVE";
65     string to_add, to_remove;
66     TF_CHECK_OK(ReadStringFromEnvVar(add_env_var, "", &to_add));
67     TF_CHECK_OK(ReadStringFromEnvVar(remove_env_var, "", &to_remove));
68     for (const auto& x : str_util::Split(to_add, ",")) {
69       list->insert(x);
70     }
71     for (const auto& x : str_util::Split(to_remove, ",")) {
72       list->erase(x);
73     }
74   }
75 
76   // Subclasses should include these on the ClearList.
AddTensorListOps(gtl::FlatSet<string> * list)77   static void AddTensorListOps(gtl::FlatSet<string>* list) {
78     // Note: if a data structure op (such as TensorListPopBack) is added here,
79     // IsTensorListReaderOp or IsTensorListWriterOp may need to be modified
80     // LINT.IfChange
81     constexpr const char* tensor_list_ops[] = {
82         "TensorListConcat",     "TensorListConcatLists",
83         "TensorListConcatV2",   "TensorListGather",
84         "TensorListGetItem",    "TensorListPopBack",
85         "TensorListPushBack",   "TensorListPushBackBatch",
86         "TensorListFromTensor", "TensorListScatter",
87         "TensorListScatterV2",  "TensorListScatterIntoExistingList",
88         "TensorListSetItem",    "TensorListSplit",
89         "TensorListStack"};
90     // LINT.ThenChange(//tensorflow/core/grappler/optimizers/auto_mixed_precision.cc)
91     for (auto op : tensor_list_ops) {
92       list->insert(op);
93     }
94   }
95 };
96 
97 class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
98  private:
IsPseudoFastMath()99   static bool IsPseudoFastMath() {
100     string optimization_level;
101     TF_CHECK_OK(
102         ReadStringFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "",
103                              &optimization_level));
104     optimization_level = str_util::Uppercase(optimization_level);
105     return optimization_level == "TENSOR_CORES_ONLY";
106   }
107 
108  public:
AutoMixedPrecisionListsCuda(int cuda_version,int cudnn_version)109   AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version)
110       : cuda_version_(cuda_version), cudnn_version_(cudnn_version) {}
111 
AllowList()112   gtl::FlatSet<string> AllowList() override {
113     auto list = gtl::FlatSet<string>{
114         "BlockLSTM",
115         "BlockLSTMV2",
116         "BlockLSTMGrad",
117         "BlockLSTMGradV2",
118         "Conv2D",
119         "Conv2DBackpropFilter",
120         "Conv2DBackpropInput",
121         "CudnnRNN",
122         "CudnnRNNBackprop",
123         "CudnnRNNBackpropV2",
124         "CudnnRNNBackpropV3",
125         "CudnnRNNV2",
126         "CudnnRNNV3",
127         "Einsum",
128         "FusedConv2DBiasActivation",
129         "GRUBlockCell",
130         "GRUBlockCellGrad",
131         "LSTMBlockCell",
132         "LSTMBlockCellGrad",
133         "MatMul",
134     };
135 #if TENSORFLOW_USE_ROCM
136     if (true) {
137 #else
138     if (cuda_version_ >= 9010) {
139       // Fp16 BatchMatMul is slow before CUDA 9.1.
140 #endif
141       list.insert("BatchMatMul");
142       list.insert("BatchMatMulV2");
143     }
144     if (cudnn_version_ >= 7602) {
145       // Fp16 3D conv is slow before CUDNN 7.6.2.
146       list.insert("Conv3D");
147       list.insert("Conv3DBackpropFilter");
148       list.insert("Conv3DBackpropFilterV2");
149       list.insert("Conv3DBackpropInput");
150       list.insert("Conv3DBackpropInputV2");
151     }
152     if (cudnn_version_ >= 8000) {
153       list.insert("DepthwiseConv2dNative");
154       list.insert("DepthwiseConv2dNativeBackpropFilter");
155       list.insert("DepthwiseConv2dNativeBackpropInput");
156     }
157     UpdateList("ALLOWLIST", &list);
158     // For backwards compatibility, keeping the original env variable here.
159     // TODO(reedwm): This should be removed if we don't have active users.
160     UpdateList("WHITELIST", &list);
161 
162     return list;
163   }
164 
165   gtl::FlatSet<string> InferList() override {
166     if (IsPseudoFastMath()) {
167       return gtl::FlatSet<string>{};
168     }
169 
170     auto list = gtl::FlatSet<string>{
171         "Add",
172         "AddN",
173         "AddV2",
174         "AvgPool",
175         "AvgPool3D",
176         "AvgPool3DGrad",
177         "AvgPoolGrad",
178         "BiasAdd",
179         "BiasAddGrad",
180         "BiasAddV1",
181         "Elu",
182         "EluGrad",
183         "Erf",
184         "Erfc",
185         "FloorDiv",
186         "FusedBatchNormV2",
187         "FusedBatchNormGradV2",
188         "FusedBatchNormV3",
189         "FusedBatchNormGradV3",
190         "_FusedBatchNormEx",
191         "Inv",
192         "LeakyRelu",
193         "LeakyReluGrad",
194         "Log",
195         "Log1p",
196         "LogSoftmax",
197         "Mul",
198         "Prod",
199         "RealDiv",
200         "Reciprocal",
201         "Selu",
202         "SeluGrad",
203         "Sigmoid",
204         "SigmoidGrad",
205         "Softmax",
206         "Softplus",
207         "SoftplusGrad",
208         "Softsign",
209         "SoftsignGrad",
210         "Sqrt",
211         "Sub",
212         "Tanh",
213         "TanhGrad",
214     };
215     UpdateList("INFERLIST", &list);
216     // For backwards compatibility, keeping the original env variable here.
217     // TODO(reedwm): This should be removed if we don't have active users.
218     UpdateList("GRAYLIST", &list);
219     return list;
220   }
221 
222   gtl::FlatSet<string> DenyList() override {
223     if (IsPseudoFastMath()) {
224       return gtl::FlatSet<string>{};
225     }
226 
227     auto list = gtl::FlatSet<string>{
228         "Exp",
229         "Expm1",
230         "L2Loss",
231         "Mean",
232         "Pow",
233         "SaveV2",
234         "SoftmaxCrossEntropyWithLogits",
235         "SparseSoftmaxCrossEntropyWithLogits",
236         "Sum",
237     };
238     UpdateList("DENYLIST", &list);
239     // For backwards compatibility, keeping the original env variable here.
240     // TODO(reedwm): This should be removed if we don't have active users.
241     UpdateList("BLACKLIST", &list);
242     return list;
243   }
244 
245   gtl::FlatSet<string> ClearList() override {
246     if (IsPseudoFastMath()) {
247       return gtl::FlatSet<string>{};
248     }
249 
250     auto list = gtl::FlatSet<string>{
251         "Abs",
252         "ArgMax",
253         "ArgMin",
254         "BatchToSpace",
255         "BatchToSpaceND",
256         "BroadcastTo",
257         "Ceil",
258         "CheckNumerics",
259         "ClipByValue",
260         "Concat",
261         "ConcatV2",
262         "DepthToSpace",
263         "DynamicPartition",
264         "DynamicStitch",
265         "Enter",
266         "EnsureShape",
267         "Equal",
268         "Exit",
269         "ExpandDims",
270         "Fill",
271         "Floor",
272         "Gather",
273         "GatherNd",
274         "GatherV2",
275         "Greater",
276         "GreaterEqual",
277         "Identity",
278         "IdentityN",
279         "IsFinite",
280         "IsInf",
281         "IsNan",
282         "Less",
283         "LessEqual",
284         "Max",
285         "MaxPool",
286         "MaxPool3D",
287         "MaxPool3DGrad",
288         "MaxPool3DGradGrad",
289         "MaxPoolGrad",
290         "MaxPoolGradGrad",
291         "MaxPoolGradGradV2",
292         "MaxPoolGradV2",
293         "MaxPoolV2",
294         "Maximum",
295         "Merge",
296         "Min",
297         "Minimum",
298         "MirrorPad",
299         "MirrorPadGrad",
300         "Neg",
301         "NextIteration",
302         "NotEqual",
303         "OneHot",
304         "OnesLike",
305         "Pack",
306         "Pad",
307         "PadV2",
308         "PreventGradient",
309         "Rank",
310         "Relu",
311         "Relu6",
312         "Relu6Grad",
313         "ReluGrad",
314         "Reshape",
315         "ResizeNearestNeighbor",
316         "ResizeNearestNeighborGrad",
317         "Reverse",
318         "ReverseSequence",
319         "ReverseV2",
320         "Round",
321         "Select",
322         "SelectV2",
323         "Shape",
324         "ShapeN",
325         "Sign",
326         "Size",
327         "Slice",
328         "Snapshot",
329         "SpaceToBatch",
330         "SpaceToBatchND",
331         "SpaceToDepth",
332         "Split",
333         "SplitV",
334         "Squeeze",
335         "StopGradient",
336         "StridedSlice",
337         "StridedSliceGrad",
338         "Switch",
339         "Tile",
340         "TopK",
341         "TopKV2",
342         "Transpose",
343         "Unpack",
344         "Where",
345         "ZerosLike",
346     };
347     AddTensorListOps(&list);
348     UpdateList("CLEARLIST", &list);
349     return list;
350   }
351 
352  private:
353   int cuda_version_;
354   int cudnn_version_;
355 };
356 
357 class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
358  public:
AutoMixedPrecisionListsMkl()359   AutoMixedPrecisionListsMkl() {}
360 
361   // Only ops which are supported by MKL in bfloat16 should be added to the
362   // allow list, infer list, or clear list.
AllowList()363   gtl::FlatSet<string> AllowList() override {
364     auto list = gtl::FlatSet<string>{"Conv2D",
365                                      "Conv2DBackpropFilter",
366                                      "Conv2DBackpropInput",
367                                      "Conv3D",
368                                      "Conv3DBackpropFilterV2",
369                                      "Conv3DBackpropInputV2",
370                                      "DepthwiseConv2dNative",
371                                      "DepthwiseConv2dNativeBackpropFilter",
372                                      "DepthwiseConv2dNativeBackpropInput",
373                                      "MatMul",
374                                      "BatchMatMul",
375                                      "BatchMatMulV2"};
376 
377     UpdateList("ALLOWLIST", &list);
378     // For backwards compatibility, keeping the original env variable here.
379     // TODO(reedwm): This should be removed if we don't have active users.
380     UpdateList("WHITELIST", &list);
381     return list;
382   }
383 
InferList()384   gtl::FlatSet<string> InferList() override {
385     auto list = gtl::FlatSet<string>{"Add",
386                                      "AddN",
387                                      "AddV2",
388                                      "AvgPool",
389                                      "AvgPool3D",
390                                      "AvgPool3DGrad",
391                                      "AvgPoolGrad",
392                                      "BiasAdd",
393                                      "BiasAddGrad",
394                                      "BiasAddV1",
395                                      "FusedBatchNormV2",
396                                      "FusedBatchNormGradV2",
397                                      "FusedBatchNormV3",
398                                      "FusedBatchNormGradV3",
399                                      "LeakyRelu",
400                                      "LeakyReluGrad",
401                                      "Mul",
402                                      "Sub",
403                                      "Elu",
404                                      "EluGrad",
405                                      "FloorDiv",
406                                      "_FusedBatchNormEx",
407                                      "Log",
408                                      "Log1p",
409                                      "LogSoftmax",
410                                      "Prod",
411                                      "RealDiv",
412                                      "Reciprocal",
413                                      "Selu",
414                                      "SeluGrad",
415                                      "Sigmoid",
416                                      "SigmoidGrad",
417                                      "Softmax",
418                                      "Softplus",
419                                      "SoftplusGrad",
420                                      "Softsign",
421                                      "SoftsignGrad",
422                                      "Sqrt",
423                                      "Tanh",
424                                      "TanhGrad"};
425     UpdateList("INFERLIST", &list);
426     // For backwards compatibility, keeping the original env variable here.
427     // TODO(reedwm): This should be removed if we don't have active users.
428     UpdateList("GRAYLIST", &list);
429     return list;
430   }
431 
DenyList()432   gtl::FlatSet<string> DenyList() override {
433     auto list = gtl::FlatSet<string>{
434         "Exp",
435         "Expm1",
436         "L2Loss",
437         "Mean",
438         "Pow",
439         "SaveV2",
440         "SoftmaxCrossEntropyWithLogits",
441         "SparseSoftmaxCrossEntropyWithLogits",
442         "Sum",
443     };
444     UpdateList("DENYLIST", &list);
445     // For backwards compatibility, keeping the original env variable here.
446     // TODO(reedwm): This should be removed if we don't have active users.
447     UpdateList("BLACKLIST", &list);
448     return list;
449   }
450 
ClearList()451   gtl::FlatSet<string> ClearList() override {
452     auto list = gtl::FlatSet<string>{
453         "Abs",
454         "ArgMax",
455         "ArgMin",
456         "BatchToSpace",
457         "BatchToSpaceND",
458         "BroadcastTo",
459         "Ceil",
460         "CheckNumerics",
461         "ClipByValue",
462         "Concat",
463         "ConcatV2",
464         "DepthToSpace",
465         "DynamicPartition",
466         "DynamicStitch",
467         "EnsureShape",
468         "Enter",
469         "Equal",
470         "Exit",
471         "ExpandDims",
472         "Fill",
473         "Floor",
474         "Gather",
475         "GatherNd",
476         "GatherV2",
477         "Greater",
478         "GreaterEqual",
479         "Identity",
480         "IsFinite",
481         "IsInf",
482         "IsNan",
483         "Less",
484         "LessEqual",
485         "Max",
486         "Maximum",
487         "MaxPool",
488         "MaxPool3D",
489         "MaxPool3DGrad",
490         "MaxPoolGrad",
491         "MaxPoolGradGrad",
492         "MaxPoolGradGradV2",
493         "MaxPoolGradV2",
494         "MaxPoolV2",
495         "Merge",
496         "Min",
497         "Minimum",
498         "MirrorPad",
499         "MirrorPadGrad",
500         "Neg",
501         "NextIteration",
502         "NotEqual",
503         "OnesLike",
504         "Pack",
505         "Pad",
506         "PadV2",
507         "PreventGradient",
508         "Rank",
509         "Relu",
510         "Relu6",
511         "Relu6Grad",
512         "ReluGrad",
513         "Reshape",
514         "ResizeNearestNeighbor",
515         "ResizeNearestNeighborGrad",
516         "Reverse",
517         "ReverseSequence",
518         "ReverseV2",
519         "Round",
520         "Select",
521         "SelectV2",
522         "Shape",
523         "ShapeN",
524         "Sign",
525         "Slice",
526         "Snapshot",
527         "SpaceToBatch",
528         "SpaceToBatchND",
529         "SpaceToDepth",
530         "Split",
531         "SplitV",
532         "Squeeze",
533         "StopGradient",
534         "StridedSlice",
535         "StridedSliceGrad",
536         "Switch",
537         "Tile",
538         "TopK",
539         "TopKV2",
540         "Transpose",
541         "Where",
542         "Unpack",
543         "ZerosLike",
544     };
545     AddTensorListOps(&list);
546     UpdateList("CLEARLIST", &list);
547     return list;
548   }
549 };
550 
551 }  // end namespace grappler
552 }  // end namespace tensorflow
553 
554 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_
555