1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
17 
18 #include <fstream>
19 #include <memory>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/op.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/grappler/clusters/cluster.h"
29 #include "tensorflow/core/grappler/costs/virtual_placer.h"
30 #include "tensorflow/core/grappler/devices.h"
31 #include "tensorflow/core/grappler/grappler_item.h"
32 #include "tensorflow/core/grappler/mutable_graph_view.h"
33 #include "tensorflow/core/grappler/op_types.h"
34 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h"
35 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
36 #include "tensorflow/core/grappler/utils.h"
37 #include "tensorflow/core/lib/io/path.h"
38 #include "tensorflow/core/lib/strings/numbers.h"
39 #include "tensorflow/core/lib/strings/str_util.h"
40 #include "tensorflow/core/lib/strings/strcat.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/util/env_var.h"
43 
44 namespace tensorflow {
45 namespace grappler {
46 namespace {
47 
48 #if GOOGLE_CUDA
49 const std::pair<int, int> kMinGPUArch = {7, 0};
50 #else
51 const std::pair<int, int> kMinGPUArch = {0, 0};
52 #endif
53 
54 const char kSuffix[] = "AutoMixedPrecision";
55 const char kCastToFp16[] = "CastToFp16";
56 const char kCastToBf16[] = "CastToBf16";
57 const char kCastToFp32[] = "CastToFp32";
58 
59 #if GOOGLE_CUDA
60 // Returns the GPU architecture (compute capability) as a (major, minor) pair.
61 std::pair<int, int> GetDeviceGPUArch(
62     const DeviceProperties& device_properties) {
63   if (device_properties.type() != "GPU") return {0, 0};
64   string arch_str = device_properties.environment().at("architecture");
65   std::vector<string> split_arch_str = str_util::Split(arch_str, '.');
66   if (split_arch_str.empty()) {
67     return {0, 0};
68   }
69 
70   int major, minor;
71   if (!strings::safe_strto32(split_arch_str[0], &major)) {
72     return {0, 0};
73   }
74 
75   if (split_arch_str.size() > 1) {
76     if (strings::safe_strto32(split_arch_str[1], &minor)) {
77       return {major, minor};
78     } else {
79       return {0, 0};
80     }
81   } else {
82     return {major, 0};
83   }
84 }
85 #endif
86 
87 // Returns true if the device has fast FP16 support.
88 // For CUDA, returns true if the GPU architecture (compute capability) is at
89 // least kMinGPUArch. For AMD/ROCm, returns true if the detected GPU's gfx arch
90 // string is in the list of FP16-capable devices. Returns false otherwise.
91 bool HasFastFP16Support(const DeviceProperties& props) {
92 #if GOOGLE_CUDA
93   return GetDeviceGPUArch(props) >= kMinGPUArch;
94 #elif TENSORFLOW_USE_ROCM
95   absl::flat_hash_set<std::string> FP16SupportedDevices = {{"gfx906"},
96                                                            {"gfx908"}};
97   std::string gcnArchName = props.environment().at("architecture");
98   std::vector<std::string> gpu_arch = absl::StrSplit(gcnArchName, ":");
99   return !gpu_arch.empty() && FP16SupportedDevices.contains(gpu_arch[0]);
100 #endif
101   return false;
102 }
103 
104 // Instances of this class represent unique type attribute identifiers within a
105 // node. It handles regular type attributes, list type attributes (where
106 // type_index is set to the index in the type list), and fixed types.
107 struct TypeAttrId {
108   static constexpr int kSingleType = -1;
109 
110   explicit TypeAttrId(const string& _attr_name, int _type_index = kSingleType)
111       : attr_name(_attr_name),
112         type_index(_type_index),
113         fixed_type(DT_INVALID) {}
114 
115   explicit TypeAttrId(DataType _fixed_type)
116       : attr_name(), type_index(kSingleType), fixed_type(_fixed_type) {}
117 
118   bool operator==(const TypeAttrId& other) const {
119     return attr_name == other.attr_name && type_index == other.type_index &&
120            fixed_type == other.fixed_type;
121   }
122 
123   bool operator<(const TypeAttrId& other) const {
124     return std::make_tuple(attr_name, type_index, fixed_type) <
125            std::make_tuple(other.attr_name, other.type_index, other.fixed_type);
126   }
127 
128   template <typename H>
129   friend H AbslHashValue(H h, const TypeAttrId& ta) {
130     return H::combine(std::move(h), ta.attr_name, ta.type_index, ta.fixed_type);
131   }
132 
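  // Illustrative examples of the string form produced by DebugString() below
  // (attr names are hypothetical): TypeAttrId("T") -> "T",
  // TypeAttrId("Targs", 2) -> "Targs[2]", TypeAttrId(DT_FLOAT) -> "float".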
133   string DebugString() const {
134     if (!attr_name.empty()) {
135       if (type_index == kSingleType) {
136         return attr_name;
137       } else {
138         return strings::StrCat(attr_name, "[", type_index, "]");
139       }
140     } else {
141       return tensorflow::DataTypeString(fixed_type);
142     }
143   }
144 
145   string attr_name;
146   // If attr_name is a list(type), this is the index into the list. Otherwise
147   // this is kSingleType.
148   int type_index;
149   DataType fixed_type;
150 };
151 
152 // Returns the data type of the given type attribute, or DT_INVALID if the type
153 // attribute is invalid.
154 DataType GetDataType(const NodeDef& node, const TypeAttrId& type_attr) {
155   if (type_attr.attr_name.empty()) {
156     return type_attr.fixed_type;
157   }
158   if (!node.attr().count(type_attr.attr_name)) {
159     return DT_INVALID;
160   }
161   const AttrValue& attr_value = node.attr().at(type_attr.attr_name);
162   if (type_attr.type_index == TypeAttrId::kSingleType) {
163     return attr_value.type();
164   } else {
165     if (type_attr.type_index < 0 ||
166         type_attr.type_index >= attr_value.list().type_size()) {
167       return DT_INVALID;
168     }
169     return attr_value.list().type(type_attr.type_index);
170   }
171 }
172 
173 // Sets the data type of the given type attribute. Returns false if the type
174 // attribute is invalid, otherwise true.
175 bool SetDataType(NodeDef* node, const TypeAttrId& type_attr, DataType type) {
176   if (type_attr.attr_name.empty() || !node->attr().count(type_attr.attr_name)) {
177     return false;
178   }
179   AttrValue& attr_value = node->mutable_attr()->at(type_attr.attr_name);
180   if (type_attr.type_index == TypeAttrId::kSingleType) {
181     attr_value.set_type(type);
182   } else {
183     if (type_attr.type_index < 0 ||
184         type_attr.type_index >= attr_value.list().type_size()) {
185       return false;
186     }
187     attr_value.mutable_list()->set_type(type_attr.type_index, type);
188   }
189   return true;
190 }
191 
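// Expands a single ArgDef into its (arg_index, type_index) pairs: one entry per
// element for list arguments (type_list_attr / number_attr), or a single
// {arg_idx, -1} entry otherwise.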
192 std::vector<std::pair<int, int>> ArgDefIndexes(const NodeDef& node, int arg_idx,
193                                                const OpDef::ArgDef& arg_def) {
194   std::vector<std::pair<int, int>> argdef_inds;
195   if (!arg_def.type_list_attr().empty()) {
196     int num_types = node.attr().at(arg_def.type_list_attr()).list().type_size();
197     for (int type_idx = 0; type_idx < num_types; ++type_idx) {
198       argdef_inds.push_back({arg_idx, type_idx});
199     }
200   } else {
201     int num_repeat = 1;
202     if (node.attr().count(arg_def.number_attr())) {
203       num_repeat = node.attr().at(arg_def.number_attr()).i();
204     }
205     argdef_inds.insert(argdef_inds.end(), num_repeat, {arg_idx, -1});
206   }
207   return argdef_inds;
208 }
209 
210 // Returns a pair (arg_index, type_index) for each input to the node, where
211 // arg_index is the index of the input_arg in op_def and type_index is the index
212 // of the type in type_list_attr (only defined for list arguments).
213 std::vector<std::pair<int, int>> InputPortArgDefIndexes(const NodeDef& node,
214                                                         const OpDef& op_def) {
215   std::vector<std::pair<int, int>> argdef_inds;
216   argdef_inds.reserve(op_def.input_arg_size());  // Final size may differ.
217   for (int arg_idx = 0; arg_idx < op_def.input_arg_size(); ++arg_idx) {
218     const OpDef::ArgDef& arg_def = op_def.input_arg(arg_idx);
219     auto arg_results = ArgDefIndexes(node, arg_idx, arg_def);
220     argdef_inds.insert(argdef_inds.end(), arg_results.begin(),
221                        arg_results.end());
222   }
223   return argdef_inds;
224 }
225 
226 // Returns a pair (arg_index, type_index) for each output of the node, where
227 // arg_index is the index of the output_arg in op_def and type_index is the
228 // index of the type in type_list_attr (only defined for list arguments).
229 std::vector<std::pair<int, int>> OutputPortArgDefIndexes(const NodeDef& node,
230                                                          const OpDef& op_def) {
231   std::vector<std::pair<int, int>> argdef_inds;
232   argdef_inds.reserve(op_def.output_arg_size());  // Final size may differ.
233   for (int arg_idx = 0; arg_idx < op_def.output_arg_size(); ++arg_idx) {
234     const OpDef::ArgDef& arg_def = op_def.output_arg(arg_idx);
235     auto arg_results = ArgDefIndexes(node, arg_idx, arg_def);
236     argdef_inds.insert(argdef_inds.end(), arg_results.begin(),
237                        arg_results.end());
238   }
239   return argdef_inds;
240 }
241 
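// Maps an ArgDef to the TypeAttrId that determines its type: a list type
// attribute (with an index into the list), a single type attribute, or a fixed
// type.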
242 TypeAttrId GetTypeAttrId(const OpDef::ArgDef& arg_def, int arg_type_index) {
243   if (!arg_def.type_list_attr().empty()) {
244     return TypeAttrId(arg_def.type_list_attr(), arg_type_index);
245   } else if (!arg_def.type_attr().empty()) {
246     return TypeAttrId(arg_def.type_attr());
247   } else {
248     return TypeAttrId(arg_def.type());
249   }
250 }
251 
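// Returns the positions of the node's non-control inputs.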
252 std::vector<int> NonControlInputs(const NodeDef& node) {
253   std::vector<int> pos;
254   for (int i = 0; i < node.input_size(); i++) {
255     if (!IsControlInput(node.input(i))) {
256       pos.push_back(i);
257     }
258   }
259   return pos;
260 }
261 
262 // A utility class to lookup node type attributes and type attribute <->
263 // input/output port mappings.
264 class NodeTypeAttrMap {
265  public:
266   NodeTypeAttrMap() {}
267 
268   explicit NodeTypeAttrMap(const GraphDef& graph) { TF_CHECK_OK(Init(graph)); }
269 
270   Status Init(const GraphDef& graph) {
271     if (graph_ != nullptr) {
272       return errors::InvalidArgument("NodeTypeAttrMap is already initialized.");
273     }
274     graph_ = &graph;
275     function_library_.reset(
276         new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
277     for (const NodeDef& node : graph.node()) {
278       TF_RETURN_IF_ERROR(AddNode(node));
279     }
280     return Status::OK();
281   }
282 
283   bool is_initialized() const { return graph_ != nullptr; }
284 
285   // Returns the set of all type attributes in the given node.
286   absl::flat_hash_set<TypeAttrId> GetTypeAttrs(const NodeDef& node) const {
287     DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
288     absl::flat_hash_set<TypeAttrId> type_attrs;
289     const auto iter = type2io_.find(&node);
290     CHECK(iter != type2io_.end());  // Crash Ok
291     for (const auto& key_value : iter->second) {
292       type_attrs.insert(key_value.first);
293     }
294     return type_attrs;
295   }
296 
297   const absl::flat_hash_set<int>& GetInputPorts(
298       const NodeDef& node, const TypeAttrId& type_attr) const {
299     DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
300     return type2io_.at(&node).at(type_attr).first;
301   }
302 
303   const absl::flat_hash_set<int>& GetOutputPorts(
304       const NodeDef& node, const TypeAttrId& type_attr) const {
305     DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
306     return type2io_.at(&node).at(type_attr).second;
307   }
308 
309   TypeAttrId GetInputTypeAttr(const NodeDef& node, int port) const {
310     DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
311     auto type_vec = io2type_.at(&node).first;
312     CHECK_GE(port, 0);                // Crash Ok
313     CHECK_LT(port, type_vec.size());  // Crash Ok
314     return type_vec[port];
315   }
316 
317   TypeAttrId GetOutputTypeAttr(const NodeDef& node, int port) const {
318     DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
319     auto type_vec = io2type_.at(&node).second;
320     CHECK_GE(port, 0);                // Crash Ok
321     CHECK_LT(port, type_vec.size());  // Crash Ok
322     return type_vec[port];
323   }
324 
325  private:
326   Status AddNode(const NodeDef& node) {
327     const OpDef* op_def_ptr = nullptr;
328     TF_RETURN_IF_ERROR(function_library_->LookUpOpDef(node.op(), &op_def_ptr));
329     const OpDef& op_def = *op_def_ptr;
330     auto& type2io_entry = type2io_[&node];
331     auto& io2type_entry = io2type_[&node];
332     auto input_arg_inds = InputPortArgDefIndexes(node, op_def);
333     if (NonControlInputs(node).size() != input_arg_inds.size()) {
334       return errors::InvalidArgument(
335           "Expected ", node.op(), " node ", node.name(), " to have ",
336           input_arg_inds.size(), " non-control input(s), but got ",
337           node.input_size());
338     }
339     // Note that the mappings generated here include inputs/outputs with fixed
340     // types. This makes the mappings complete (all inputs and outputs are
341     // included), and allows the graph rewriter to propagate deny paint
342     // from/through ops with fixed types.
343     io2type_entry.first.reserve(input_arg_inds.size());
344     for (int i = 0; i < static_cast<int>(input_arg_inds.size()); ++i) {
345       const auto& arg_inds = input_arg_inds[i];
346       const OpDef::ArgDef& arg_def = op_def.input_arg(arg_inds.first);
347       TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second);
348       if (!type_attr.attr_name.empty() &&
349           !node.attr().count(type_attr.attr_name)) {
350         return errors::InvalidArgument("Type attribute ", type_attr.attr_name,
351                                        " is not present in node ", node.name());
352       }
353       type2io_entry[type_attr].first.insert(i);
354       io2type_entry.first.push_back(type_attr);
355     }
356 
357     auto output_arg_inds = OutputPortArgDefIndexes(node, op_def);
358     io2type_entry.second.reserve(output_arg_inds.size());
359     for (int i = 0; i < static_cast<int>(output_arg_inds.size()); ++i) {
360       const auto& arg_inds = output_arg_inds[i];
361       const OpDef::ArgDef& arg_def = op_def.output_arg(arg_inds.first);
362       TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second);
363       if (!type_attr.attr_name.empty() &&
364           !node.attr().count(type_attr.attr_name)) {
365         return errors::InvalidArgument("Type attribute ", type_attr.attr_name,
366                                        " is not present in node ", node.name());
367       }
368       type2io_entry[type_attr].second.insert(i);
369       io2type_entry.second.push_back(type_attr);
370     }
371 
372     // Also ensure that type attributes that aren't associated with any inputs
373     // or outputs (e.g., StackV2's elem_type) are added to the map.
374     for (const auto& attr : node.attr()) {
375       const string& attr_name = attr.first;
376       if (!attr_name.empty() && attr_name[0] == '_') continue;
377       const AttrValue& attr_value = attr.second;
378       const OpDef::AttrDef* attr_def = FindAttr(attr_name, op_def);
379       if (!attr_def) {
380         return errors::InvalidArgument("AttrDef not found for attribute ",
381                                        attr_name, " of node ", node.name());
382       }
383       if (attr_def->type() == "type") {
384         type2io_entry[TypeAttrId(attr_name)];
385       } else if (attr_def->type() == "list(type)") {
386         for (int i = 0; i < attr_value.list().type_size(); ++i) {
387           type2io_entry[TypeAttrId(attr_name, i)];
388         }
389       }
390     }
391     return Status::OK();
392   }
393 
394   // WARN: `graph_` must outlive this object (node pointers must remain valid).
395   const GraphDef* graph_ = nullptr;  // do not own
396   std::unique_ptr<FunctionLibraryDefinition> function_library_;
397 
398   typedef absl::flat_hash_set<int> IntSet;
399   // Maps a type attr id -> (input port set, output port set)
400   typedef absl::flat_hash_map<TypeAttrId, std::pair<IntSet, IntSet>> Type2IOMap;
401   // Maps a node -> type attr mapping
402   absl::flat_hash_map<const NodeDef*, Type2IOMap> type2io_;
403   // Maps a port -> type attr id
404   typedef std::vector<TypeAttrId> TypeAttrIdVec;
405   // Maps a node -> (input port mapping, output port mapping)
406   absl::flat_hash_map<const NodeDef*, std::pair<TypeAttrIdVec, TypeAttrIdVec>>
407       io2type_;
408 };
409 
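// Identifies a specific type attribute on a specific node.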
410 struct NodeTypeId {
411   NodeTypeId(const NodeDef* _node, const TypeAttrId& _type_attr)
412       : node(_node), type_attr(_type_attr) {}
413 
414   const NodeDef* node;
415   TypeAttrId type_attr;
416 
417   bool operator==(const NodeTypeId& other) const {
418     return node == other.node && type_attr == other.type_attr;
419   }
420 
421   template <typename H>
422   friend H AbslHashValue(H h, const NodeTypeId& nt) {
423     return H::combine(std::move(h), nt.node, nt.type_attr);
424   }
425 };
426 
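// A directed edge between two (node, type attribute) vertices.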
427 struct NodeTypeIdEdge {
428   NodeTypeIdEdge(const NodeTypeId& _src, const NodeTypeId& _dst)
429       : src(_src), dst(_dst) {}
430   NodeTypeId src;
431   NodeTypeId dst;
432 };
433 
434 // TODO(benbarsdell): Investigate whether the existing GraphTopologyView can be
435 // used instead of this modified version.
436 // This is just like GraphTopologyView but with (NodeDef, TypeAttrId) pairs as
437 // the vertices instead of just NodeDef.
438 // For example, if node A has output A:0 with TypeAttrId 'T', and node B has
439 // input B:0 with TypeAttrId 'U', and input B:0 connects to output A:0, there
440 // will be an edge from (A, T) to (B, U).
441 class GraphTypeTopologyView {
442  public:
443   GraphTypeTopologyView() = default;
444   explicit GraphTypeTopologyView(bool skip_invalid_edges)
445       : skip_invalid_edges_(skip_invalid_edges) {}
446 
447   // Initialize graph topology view from the graph. It's possible to pass
448   // additional edges that do not exist in a graph, but must be respected when
449   // computing graph topology. Example: Tensorflow runtime allows concurrent
450   // execution of dequeue/enqueue ops from the same queue resource, but we might
451   // want to enforce ordering between them for the purpose of graph analysis.
452   Status InitializeFromGraph(const GraphDef& graph,
453                              const NodeTypeAttrMap& node_type_map);
454 
455   Status AddEphemeralEdges(absl::Span<const NodeTypeIdEdge> ephemeral_edges);
456 
457   bool is_initialized() const { return graph_ != nullptr; }
458   int num_nodes() const { return num_nodes_; }
459   const GraphDef* graph() const { return graph_; }
460 
461   // Returns true iff the node exists in the underlying graph.
462   bool HasNode(absl::string_view node_name, const TypeAttrId& type_attr) const;
463 
464   // Finds a node by name or returns `nullptr` if it's not in the graph.
465   const NodeTypeId* GetNode(absl::string_view node_name,
466                             const TypeAttrId& type_attr) const;
467   // Returns a node corresponding to the given node index.
468   const NodeTypeId* GetNode(int node_idx) const;
469 
470   // Returns a node index for the given node name, if the name exists in the
471   // underlying graph. Otherwise returns empty optional.
472   const absl::optional<int> GetNodeIndex(absl::string_view node_name,
473                                          const TypeAttrId& type_attr) const;
474   // Returns a node index for the given node, if the node belongs to the
475   // underlying graph. Otherwise returns empty optional.
476   const absl::optional<int> GetNodeIndex(const NodeTypeId& node) const;
477 
478   // Returns all the node indexes that are in the direct fanin of the given
479   // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
480   const absl::InlinedVector<int, 4>& GetFanin(int node_idx) const;
481   // Returns all the node indexes that are in the direct fanout of the given
482   // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
483   const absl::InlinedVector<int, 2>& GetFanout(int node_idx) const;
484 
485  private:
486   // The key type used to uniquely identify a type attribute on a node.
487   struct NodeTypeKey : public std::pair<absl::string_view, TypeAttrId> {
488     typedef std::pair<absl::string_view, TypeAttrId> Base;
489 
490     // Inherit the set of constructors.
491     using Base::pair;
492 
493     template <typename H>
494     friend H AbslHashValue(H h, const NodeTypeKey& nt) {
495       return H::combine(std::move(h), nt.first, nt.second);
496     }
497   };
498 
499   // If true, all invalid edges and inputs (src, dst or input node not found in
500   // the graph) will be skipped; otherwise initialization will fail with an error.
501   bool skip_invalid_edges_ = false;
502 
503   // WARN: `graph_` must outlive this object and graph nodes must not be
504   // destroyed, because node names are captured as absl::string_view.
505   const GraphDef* graph_ = nullptr;  // do not own
506   int num_nodes_ = 0;
507   std::vector<NodeTypeId> node_type_attrs_;
508   absl::flat_hash_map<absl::string_view, int> node_name_to_index_;
509   absl::flat_hash_map<NodeTypeKey, int> node_type_name_to_index_;
510 
511   std::vector<absl::InlinedVector<int, 4>> fanins_;
512   std::vector<absl::InlinedVector<int, 2>> fanouts_;
513 
514   // We need a valid reference to return from GetFanin/GetFanout if the
515   // `node_idx` argument is outside of the [0, num_nodes_) range.
516   absl::InlinedVector<int, 4> empty_fanin_;
517   absl::InlinedVector<int, 2> empty_fanout_;
518 };
519 
520 template <typename T>
521 inline void SortAndRemoveDuplicates(T* v) {
522   std::sort(v->begin(), v->end());
523   v->erase(std::unique(v->begin(), v->end()), v->end());
524 }
525 
526 Status GraphTypeTopologyView::InitializeFromGraph(
527     const GraphDef& graph, const NodeTypeAttrMap& node_type_map) {
528   if (graph_ != nullptr) {
529     return errors::InvalidArgument(
530         "GraphTypeTopologyView is already initialized.");
531   }
532 
533   graph_ = &graph;
534   int num_nodedefs = graph.node_size();
535   node_name_to_index_.rehash(num_nodedefs);
536 
537   // Build maps from name to index.
538   node_type_attrs_.reserve(num_nodedefs);         // Only approximate.
539   node_type_name_to_index_.rehash(num_nodedefs);  // Only approximate.
540   for (int node_idx = 0; node_idx < num_nodedefs; ++node_idx) {
541     const NodeDef& node = graph.node(node_idx);
542     node_name_to_index_.emplace(node.name(), node_idx);
543 
544     for (const TypeAttrId& type_attr : node_type_map.GetTypeAttrs(node)) {
545       int node_type_idx = node_type_attrs_.size();
546       node_type_name_to_index_.emplace(NodeTypeKey(node.name(), type_attr),
547                                        node_type_idx);
548       node_type_attrs_.emplace_back(&node, type_attr);
549     }
550   }
551   num_nodes_ = node_type_attrs_.size();
552   fanins_.resize(num_nodes_);
553   fanouts_.resize(num_nodes_);
554 
555   // Add graph edges to the adjacency lists.
556   for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
557     const NodeTypeId& node_type = node_type_attrs_.at(node_type_idx);
558     auto input_ports =
559         node_type_map.GetInputPorts(*node_type.node, node_type.type_attr);
560     fanins_[node_type_idx].reserve(input_ports.size());
561     for (int port : input_ports) {
562       const string& input = node_type.node->input(port);
563       TensorId tensor = ParseTensorName(input);
564       const auto it = node_name_to_index_.find(tensor.node());
565       const bool valid_input = it != node_name_to_index_.end();
566 
567       if (!valid_input) {
568         const string error_message = absl::StrCat(
569             "Non-existent input ", input, " in node ", node_type.node->name());
570         if (skip_invalid_edges_) {
571           VLOG(3) << "Skip error: " << error_message;
572         } else {
573           return errors::InvalidArgument(error_message);
574         }
575       }
576 
577       if (valid_input) {
578         const int input_idx = it->second;
579         const NodeDef& input_node = graph_->node(input_idx);
580         TypeAttrId input_type_attr =
581             node_type_map.GetOutputTypeAttr(input_node, tensor.index());
582         const auto it2 = node_type_name_to_index_.find(
583             NodeTypeKey(input_node.name(), input_type_attr));
584         if (it2 == node_type_name_to_index_.end()) {
585           if (!skip_invalid_edges_) {
586             return errors::InvalidArgument("Did not find type attr ",
587                                            input_type_attr.DebugString(),
588                                            " in node ", input_node.name());
589           }
590           continue;
591         }
592         int input_node_type_idx = it2->second;
593         fanins_[node_type_idx].push_back(input_node_type_idx);
594         fanouts_[input_node_type_idx].push_back(node_type_idx);
595       }
596     }
597 
598     // Dedup the input list while it's still hot in cache.
599     SortAndRemoveDuplicates(&fanins_[node_type_idx]);
600   }
601 
602   // Dedup outputs for all the graph nodes.
603   for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
604     SortAndRemoveDuplicates(&fanouts_[node_type_idx]);
605   }
606 
607   return Status::OK();
608 }
609 
610 Status GraphTypeTopologyView::AddEphemeralEdges(
611     absl::Span<const NodeTypeIdEdge> ephemeral_edges) {
612   // Add ephemeral edges to the adjacency lists.
613   for (const NodeTypeIdEdge& edge : ephemeral_edges) {
614     const auto src = node_name_to_index_.find(edge.src.node->name());
615     const bool valid_src = src != node_name_to_index_.end();
616 
617     if (!valid_src) {
618       const string error_message =
619           absl::StrCat("Non-existent src node: ", edge.src.node->name());
620       if (skip_invalid_edges_) {
621         VLOG(0) << "Skip error: " << error_message;
622       } else {
623         return errors::InvalidArgument(error_message);
624       }
625     }
626 
627     const auto dst = node_name_to_index_.find(edge.dst.node->name());
628     const bool valid_dst = dst != node_name_to_index_.end();
629 
630     if (!valid_dst) {
631       const string error_message =
632           absl::StrCat("Non-existent dst node: ", edge.dst.node->name());
633       if (skip_invalid_edges_) {
634         VLOG(0) << "Skip error: " << error_message;
635       } else {
636         return errors::InvalidArgument(error_message);
637       }
638     }
639 
640     if (valid_dst && valid_src) {
641       // TODO(benbarsdell): Check for failure.
642       int src_node_type_idx = node_type_name_to_index_.at(
643           NodeTypeKey(edge.src.node->name(), edge.src.type_attr));
644       int dst_node_type_idx = node_type_name_to_index_.at(
645           NodeTypeKey(edge.dst.node->name(), edge.dst.type_attr));
646       fanins_[dst_node_type_idx].push_back(src_node_type_idx);
647       fanouts_[src_node_type_idx].push_back(dst_node_type_idx);
648     }
649   }
650 
651   // Dedup inputs and outputs for all the graph nodes.
652   for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
653     SortAndRemoveDuplicates(&fanins_[node_type_idx]);
654     SortAndRemoveDuplicates(&fanouts_[node_type_idx]);
655   }
656 
657   return Status::OK();
658 }
659 
660 bool GraphTypeTopologyView::HasNode(absl::string_view node_name,
661                                     const TypeAttrId& type_attr) const {
662   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
663   NodeTypeKey key(node_name, type_attr);
664   const auto it = node_type_name_to_index_.find(key);
665   return it != node_type_name_to_index_.end();
666 }
667 
668 const NodeTypeId* GraphTypeTopologyView::GetNode(
669     absl::string_view node_name, const TypeAttrId& type_attr) const {
670   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
671   NodeTypeKey key(node_name, type_attr);
672   const auto it = node_type_name_to_index_.find(key);
673   return it == node_type_name_to_index_.end()
674              ? nullptr
675              : &node_type_attrs_.at(it->second);
676 }
677 
678 const NodeTypeId* GraphTypeTopologyView::GetNode(int node_idx) const {
679   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
680   DCHECK(node_idx >= 0 && node_idx < num_nodes_) << "node_idx is out of range";
681   return &node_type_attrs_.at(node_idx);
682 }
683 
684 const absl::optional<int> GraphTypeTopologyView::GetNodeIndex(
685     absl::string_view node_name, const TypeAttrId& type_attr) const {
686   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
687   NodeTypeKey key(node_name, type_attr);
688   const auto it = node_type_name_to_index_.find(key);
689   DCHECK(it != node_type_name_to_index_.end())
690       << "Node doesn't exist in a graph";
691   return it == node_type_name_to_index_.end() ? absl::nullopt
692                                               : absl::make_optional(it->second);
693 }
694 
695 const absl::optional<int> GraphTypeTopologyView::GetNodeIndex(
696     const NodeTypeId& node) const {
697   return GetNodeIndex(node.node->name(), node.type_attr);
698 }
699 
700 const absl::InlinedVector<int, 4>& GraphTypeTopologyView::GetFanin(
701     int node_idx) const {
702   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
703   const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
704   DCHECK(is_valid_node_idx) << "node_idx is out of range";
705   return is_valid_node_idx ? fanins_[node_idx] : empty_fanin_;
706 }
707 
708 const absl::InlinedVector<int, 2>& GraphTypeTopologyView::GetFanout(
709     int node_idx) const {
710   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
711   const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
712   DCHECK(is_valid_node_idx) << "node_idx is out of range";
713   return is_valid_node_idx ? fanouts_[node_idx] : empty_fanout_;
714 }
715 
716 enum class TypeTraversalDirection {
717   kFollowInputs,
718   kFollowOutputs,
719   kFollowInputsAndOutputs,
720 };
721 
722 // Encapsulate DFS callbacks that will be called during the graph traversal.
723 //
724 // If non-empty, the `pre_order` and `post_order` functors will be called on
725 // each reachable node (including the `from` nodes) in pre and post order. If
726 // loops are found, the `on_back_edge` functor will be called on the
727 // corresponding back edges. Moreover, the pre and post order will assume that
728 // these back edges will be cut.
729 struct DfsTypeCallbacks {
730   DfsTypeCallbacks() = default;
731   DfsTypeCallbacks(std::function<void(int)> pre, std::function<void(int)> post,
732                    std::function<void(int, int)> back_edge)
733       : pre_order(std::move(pre)),
734         post_order(std::move(post)),
735         on_back_edge(std::move(back_edge)) {}
736 
737   static DfsTypeCallbacks PreOrder(std::function<void(int)> pre) {
738     return DfsTypeCallbacks(std::move(pre), nullptr, nullptr);
739   }
740 
741   static DfsTypeCallbacks PostOrder(std::function<void(int)> post) {
742     return DfsTypeCallbacks(nullptr, std::move(post), nullptr);
743   }
744 
745   std::function<void(int)> pre_order;
746   std::function<void(int)> post_order;
747   std::function<void(int, int)> on_back_edge;
748 };
749 
750 // Encapsulate DFS predicates for traversing the graph.
751 //
752 // The `enter` predicate decides if traversal should enter the node, and the
753 // `advance` predicate decides if the traversal should follow inputs/outputs
754 // from the node.
755 //
756 // If predicates are empty (default initialized), it's assumed that we can enter
757 // into any node and advance from any node respectively.
758 struct DfsTypePredicates {
759   DfsTypePredicates() = default;
760   DfsTypePredicates(std::function<bool(int)> enter,
761                     std::function<bool(int)> advance)
762       : enter(std::move(enter)), advance(std::move(advance)) {}
763 
764   static DfsTypePredicates Enter(std::function<bool(int)> enter) {
765     return DfsTypePredicates(std::move(enter), nullptr);
766   }
767 
768   static DfsTypePredicates Advance(std::function<bool(int)> advance) {
769     return DfsTypePredicates(nullptr, std::move(advance));
770   }
771 
772   std::function<bool(int)> enter;
773   std::function<bool(int)> advance;
774 };
775 
776 struct DfsStackElem {
777   DfsStackElem(int node, bool children_visited, int src)
778       : node(node), children_visited(children_visited), src(src) {}
779   explicit DfsStackElem(int node) : DfsStackElem(node, false, -1) {}
780 
781   // Index of the node in the graph ∊ [0, num_nodes).
782   int node;
783   // `True` if all the input/output nodes have been visited (i.e., pushed onto
784   // the stack).
785   bool children_visited;
786   // Index of the node in the graph, from which we entered the `node`.
787   int src;
788 };
789 
790 enum class NodeState { kNotVisited, kVisiting, kDone };
791 
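// Iterative depth-first traversal over the type-attribute graph. Starting from
// the `from` vertices, it follows fanins and/or fanouts according to
// `direction`, gated by `predicates` and reporting visits via `callbacks`.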
792 void DfsTypeTraversal(const GraphTypeTopologyView& graph_type_view,
793                       const absl::Span<const NodeTypeId* const> from,
794                       const TypeTraversalDirection direction,
795                       const DfsTypePredicates& predicates,
796                       const DfsTypeCallbacks& callbacks) {
797   std::vector<DfsStackElem> stack;
798   stack.reserve(from.size());
799 
800   for (const NodeTypeId* node : from) {
801     const absl::optional<int> node_idx = graph_type_view.GetNodeIndex(*node);
802     DCHECK(node_idx.has_value())
803         << "Illegal start node: " << node->node->name();
804     if (node_idx.has_value()) {
805       stack.emplace_back(node_idx.value());
806     }
807   }
808 
809   absl::flat_hash_map<int, NodeState> node_state;
810   while (!stack.empty()) {
811     DfsStackElem w = stack.back();
812     stack.pop_back();
813 
814     NodeState& state = node_state[w.node];
815     if (state == NodeState::kDone) continue;
816 
817     // Skip nodes that we should not enter.
818     if (predicates.enter && !predicates.enter(w.node)) {
819       state = NodeState::kDone;
820       continue;
821     }
822 
823     // We've processed all the children of this node.
824     if (w.children_visited) {
825       state = NodeState::kDone;
826       if (callbacks.post_order) {
827         callbacks.post_order(w.node);
828       }
829       continue;
830     }
831 
832     // Loop detected.
833     if (state == NodeState::kVisiting) {
834       if (callbacks.on_back_edge) {
835         callbacks.on_back_edge(w.src, w.node);
836       }
837       continue;
838     }
839 
840     state = NodeState::kVisiting;
841     if (callbacks.pre_order) {
842       callbacks.pre_order(w.node);
843     }
844 
845     // Enqueue the node again with the children_visited flag set to true.
846     stack.emplace_back(w.node, true, w.src);
847 
848     // Check if we can continue traversal from the current node.
849     if (predicates.advance && !predicates.advance(w.node)) {
850       continue;
851     }
852 
853     // Now enqueue the fanin/fanout nodes.
854     if (direction == TypeTraversalDirection::kFollowInputs ||
855         direction == TypeTraversalDirection::kFollowInputsAndOutputs) {
856       for (const int fanin : graph_type_view.GetFanin(w.node)) {
857         stack.emplace_back(fanin, false, w.node);
858       }
859     }
860     if (direction == TypeTraversalDirection::kFollowOutputs ||
861         direction == TypeTraversalDirection::kFollowInputsAndOutputs) {
862       for (const int fanout : graph_type_view.GetFanout(w.node)) {
863         stack.emplace_back(fanout, false, w.node);
864       }
865     }
866   }
867 }
868 
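// Returns the set of data types permitted by the attr's allowed_values
// constraint, or all types if no constraint is specified.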
869 DataTypeSet AllowedDataTypes(const OpDef::AttrDef& attr_def) {
870   const auto& allowed_types = attr_def.allowed_values().list().type();
871   if (allowed_types.empty()) {
872     return AllTypes();
873   }
874   uint32 dtype_mask = 0;
875   for (int dtype : allowed_types) {
876     dtype_mask |= 1u << dtype;
877   }
878   return DataTypeSet(dtype_mask);
879 }
880 
881 DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) {
882   if (t_attr_id.attr_name.empty()) {
883     return ToSet(t_attr_id.fixed_type);
884   }
885   const OpDef::AttrDef* attr_def = FindAttr(t_attr_id.attr_name, op_def);
886   CHECK(attr_def);  // Crash Ok
887   return AllowedDataTypes(*attr_def);
888 }
889 
890 Status ValidateLists(const gtl::FlatSet<string>& allow_list,
891                      const gtl::FlatSet<string>& deny_list,
892                      const gtl::FlatSet<string>& infer_list,
893                      const gtl::FlatSet<string>& clear_list) {
894   std::vector<gtl::FlatSet<string>> lists{allow_list, deny_list, infer_list,
895                                           clear_list};
896   std::multiset<string> counts;
897   for (const auto& list : lists) {
898     counts.insert(list.begin(), list.end());
899   }
900   bool duplicates = false;
901   for (const auto& s : counts) {
902     if (counts.count(s) > 1) {
903       duplicates = true;
904       LOG(ERROR) << "Op present in multiple lists: " << s;
905     }
906   }
907   if (duplicates) {
908     return errors::InvalidArgument("Op lists have conflicting entries");
909   } else {
910     return Status::OK();
911   }
912 }
913 
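// Returns true if the op has any ref-typed inputs or outputs, or if its OpDef
// cannot be looked up (treated conservatively as if it had refs).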
914 bool HasInputOrOutputRefs(const NodeDef& node) {
915   const OpDef* op_def;
916   Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
917   if (!status.ok()) {
918     return true;
919   }
920   for (const auto& input : op_def->input_arg()) {
921     if (input.is_ref()) {
922       return true;
923     }
924   }
925   for (const auto& output : op_def->output_arg()) {
926     if (output.is_ref()) {
927       return true;
928     }
929   }
930   return false;
931 }
932 
933 // See TF issue 25977: do not force FP16 for SoftmaxCrossEntropyWithLogits (SCEWL).
934 bool CanForceFP16(const NodeDef& node) {
935   return node.op() != "Const" && node.op() != "SoftmaxCrossEntropyWithLogits" &&
936          !IsStateful(node) && !HasInputOrOutputRefs(node);
937 }
938 
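// Returns the CUDA version reported by the first GPU device in the cluster, or
// 0 if no GPU with a "cuda" environment entry is found.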
939 int GetCudaVersion(const Cluster& cluster) {
940   auto devices = cluster.GetDevices();
941   for (const auto& device : devices) {
942     const DeviceProperties& device_properties = device.second;
943     if (device_properties.type() == "GPU") {
944       const auto& device_env = device_properties.environment();
945       auto it = device_env.find("cuda");
946       if (it != device_env.end()) {
947         string cuda_version_str = it->second;
948         return std::stoi(cuda_version_str);
949       }
950     }
951   }
952   return 0;
953 }
954 
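// Returns the cuDNN version reported by the first GPU device in the cluster, or
// 0 if no GPU with a "cudnn" environment entry is found.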
955 int GetCudnnVersion(const Cluster& cluster) {
956   auto devices = cluster.GetDevices();
957   for (const auto& device : devices) {
958     const DeviceProperties& device_properties = device.second;
959     if (device_properties.type() == "GPU") {
960       const auto& device_env = device_properties.environment();
961       auto it = device_env.find("cudnn");
962       if (it != device_env.end()) {
963         string cudnn_version_str = it->second;
964         return std::stoi(cudnn_version_str);
965       }
966     }
967   }
968   return 0;
969 }
970 
971 class AutoMixedPrecisionImpl {
972  public:
973   AutoMixedPrecisionImpl(Cluster* cluster,
974                          const std::unordered_set<string>& nodes_to_preserve,
975                          GraphDef* graph, string id,
976                          AutoMixedPrecisionMode mode)
977       : virtual_placer_(cluster->GetDevices()),
978         nodes_to_preserve_(nodes_to_preserve),
979         graph_(graph),
980         function_library_(OpRegistry::Global(), graph->library()),
981         id_(id),
982         graph_view_(graph),
983         cuda_version_(GetCudaVersion(*cluster)),
984         cudnn_version_(GetCudnnVersion(*cluster)),
985         mode_(mode),
986         target_dtype_(mode_ == AutoMixedPrecisionMode::CUDA ? DT_HALF
987                                                             : DT_BFLOAT16) {}
988 
989   Status Optimize();
990 
991  private:
992   typedef absl::flat_hash_set<NodeTypeId> NodeTypeIdSet;
993 
994   std::unique_ptr<AutoMixedPrecisionLists> get_mixed_precision_lists() const {
995     switch (mode_) {
996       case AutoMixedPrecisionMode::CUDA:
997         return std::make_unique<AutoMixedPrecisionListsCuda>(cuda_version_,
998                                                              cudnn_version_);
999       case AutoMixedPrecisionMode::MKL:
1000         return std::make_unique<AutoMixedPrecisionListsMkl>();
1001     }
1002   }
1003   Status PrintDebugLogs(bool preop, size_t timestamp);
1004   void LogSkippedNode(const NodeDef& node) const;
1005   bool MustPreserve(const NodeDef& node) const;
1006   bool IsOnDevice(const NodeDef& node, const string& device_type) const;
1007   bool IsOnSuitableGPUArch(const NodeDef& node) const;
1008   bool ShouldProcess(const NodeDef& node) const;
1009   bool NodeHasF16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const;
1010   bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const;
1011   void ConvertBatchNormOpsToV2();
1012   bool SupportsF16(const NodeTypeId& node_type) const;
1013   bool SupportsF16DataType(const NodeTypeId& node_type) const;
1014   bool IsQuantized(const NodeTypeId& node_type) const;
1015   const NodeTypeId* GetTensorListFloat32NodeTypeId(const NodeDef& node) const;
1016   bool IsSourceOrSinkOp(const string& op) const;
1017   void FindFloat32TensorListOpClustersAndDenylistUnsafe(
1018       std::vector<absl::flat_hash_set<const NodeDef*>>* clusters,
1019       absl::flat_hash_set<int>* deny_set) const;
1020   void FindTensorListImplicitFloat32Edges(
1021       const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1022       std::vector<NodeTypeIdEdge>* implicit_data_edges) const;
1023   void AddAllowlistOps(absl::flat_hash_set<int>* allow_set) const;
1024   void RemoveAllowsetWithFp32(absl::flat_hash_set<int>* allow_set) const;
1025   void PropagateDenyFwdThroughClearAndInfer(
1026       absl::flat_hash_set<int>* deny_set) const;
1027   void ForceColorMatchBetweenTensorListOps(
1028       const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1029       absl::flat_hash_set<int>* allow_set,
1030       absl::flat_hash_set<int>* deny_set) const;
1031   void AddClearAndInferToAllowIfBetweenAllow(
1032       const absl::flat_hash_set<int>& deny_set,
1033       absl::flat_hash_set<int>* allow_set) const;
1034   void PropagateAllowThroughClear(const absl::flat_hash_set<int>& deny_set,
1035                                   absl::flat_hash_set<int>* allow_set) const;
1036   Status ForceColorMatchOnRecurrentEdges(
1037       absl::flat_hash_set<int>* allow_set) const;
1038   void MakeCastsAllowIfAllOutputsAllow(
1039       absl::flat_hash_set<int>* allow_set) const;
1040   NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, bool to_f16,
1041                         const string& device) const;
1042   Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& allow_set);
1043 
1044   VirtualPlacer virtual_placer_;
1045   std::unordered_set<string> nodes_to_preserve_;
1046   GraphDef* graph_;
1047   FunctionLibraryDefinition function_library_;
1048   string id_;
1049   MutableGraphView graph_view_;
1050   int cuda_version_;
1051   int cudnn_version_;
1052   NodeTypeAttrMap node_type_map_;
1053   GraphTypeTopologyView graph_type_view_;
1054   bool force_all_fp16_;
1055   AutoMixedPrecisionMode mode_;
1056   gtl::FlatSet<string> f16_allowlist_;
1057   gtl::FlatSet<string> f16_denylist_;
1058   gtl::FlatSet<string> f16_inferlist_;
1059   gtl::FlatSet<string> f16_clearlist_;
1060   absl::flat_hash_set<const NodeDef*> should_process_nodes_;
1061   DataType target_dtype_;  // Either DT_HALF or DT_BFLOAT16
1062 };
1063 
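// Builds a Cast node converting the given source output between DT_FLOAT and
// the target f16 type (direction controlled by `to_f16`). The generated node is
// named "<src_node>-<port>-<CastToFp16|CastToBf16|CastToFp32>-AutoMixedPrecision".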
1064 NodeDef AutoMixedPrecisionImpl::BuildCastNode(
1065     const MutableGraphView::OutputPort& src, bool to_f16,
1066     const string& device) const {
1067   DataType src_type = to_f16 ? DT_FLOAT : target_dtype_;
1068   DataType dst_type = to_f16 ? target_dtype_ : DT_FLOAT;
1069   const char* cast_string = !to_f16                    ? kCastToFp32
1070                             : target_dtype_ == DT_HALF ? kCastToFp16
1071                                                        : kCastToBf16;
1072   string name = strings::StrCat(src.node->name(), "-", src.port_id, "-",
1073                                 cast_string, "-", kSuffix);
1074   NodeDef node;
1075   node.set_name(name);
1076   node.set_op("Cast");
1077   node.set_device(device);
1078   node.add_input(strings::StrCat(src.node->name(), ":", src.port_id));
1079   (*node.mutable_attr())["SrcT"].set_type(src_type);
1080   (*node.mutable_attr())["DstT"].set_type(dst_type);
1081   (*node.mutable_attr())["Truncate"].set_b(false);
1082   return node;
1083 }
1084 
1085 bool AutoMixedPrecisionImpl::NodeHasF16KernelForTypeAttr(
1086     const NodeDef& node, TypeAttrId taid) const {
1087   NodeDef node_copy(node);
1088   if (node.device().empty()) {
1089     string device_name = virtual_placer_.get_canonical_device_name(node);
1090     node_copy.set_device(device_name);
1091   }
1092   if (!SetDataType(&node_copy, taid, target_dtype_)) {
1093     return false;
1094   }
1095   return IsKernelRegisteredForNode(node_copy).ok();
1096 }
1097 
1098 Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) {
1099   string prepend_path;
1100   TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
1101       "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LOG_PATH", "", &prepend_path));
1102   if (prepend_path.empty()) return Status::OK();
1103 
1104   string suffix =
1105       strings::StrCat("_", preop ? "preop" : kSuffix, "_", id_, "_", timestamp);
1106 
1107   string fname =
1108       io::JoinPath(prepend_path, strings::StrCat("graphdef", suffix, ".pb"));
1109   std::fstream f;
1110   f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
1111   f << graph_->SerializeAsString();
1112   f.close();
1113   LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization")
1114             << " graph as binary to " << fname;
1115 
1116   fname = io::JoinPath(prepend_path,
1117                        strings::StrCat("graphdef", suffix, ".pb.txt"));
1118   f.open(fname.c_str(), std::fstream::out);
1119   f << graph_->DebugString();
1120   f.close();
1121   LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization")
1122             << " graph as text to " << fname;
1123 
1124   if (!preop) {
1125     fname = io::JoinPath(prepend_path,
1126                          strings::StrCat("paintbuckets", suffix, ".txt"));
1127     f.open(fname.c_str(), std::fstream::out);
1128     std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
1129         get_mixed_precision_lists();
1130     f << "AllowList:\n";
1131     for (const auto& x : mp_lists->AllowList()) {
1132       f << x << "\n";
1133     }
1134     f << "\nDenyList:\n";
1135     for (const auto& x : mp_lists->DenyList()) {
1136       f << x << "\n";
1137     }
1138     f << "\nInferList:\n";
1139     for (const auto& x : mp_lists->InferList()) {
1140       f << x << "\n";
1141     }
1142     f << "\nClearList:\n";
1143     for (const auto& x : mp_lists->ClearList()) {
1144       f << x << "\n";
1145     }
1146     f.close();
1147     LOG(INFO) << "Saved paint bucket info to " << fname;
1148   }
1149   return Status::OK();
1150 }
1151 
1152 void AutoMixedPrecisionImpl::LogSkippedNode(const NodeDef& node) const {
1153   VLOG(2) << "Skipping " << node.op() << " node " << node.name()
1154           << " because it "
1155           << (MustPreserve(node)
1156                   ? "must be preserved"
1157                   : "is not on the GPU, or the GPU arch is not suitable");
1158 }
1159 
1160 bool AutoMixedPrecisionImpl::MustPreserve(const NodeDef& node) const {
1161   return nodes_to_preserve_.count(node.name());
1162 }
1163 
1164 bool AutoMixedPrecisionImpl::IsOnDevice(const NodeDef& node,
1165                                         const string& device_type) const {
1166   string device_name;
1167   if (node.device().empty()) {
1168     device_name = virtual_placer_.get_canonical_device_name(node);
1169   } else {
1170     device_name = node.device();
1171   }
1172   string device;
1173   string not_used;
1174   if (DeviceNameUtils::SplitDeviceName(device_name, &not_used, &device) &&
1175       absl::StrContains(absl::AsciiStrToLower(device),
1176                         absl::AsciiStrToLower(device_type))) {
1177     return true;
1178   }
1179   return false;
1180 }
1181 
1182 bool AutoMixedPrecisionImpl::IsOnSuitableGPUArch(const NodeDef& node) const {
1183   return HasFastFP16Support(virtual_placer_.get_device(node));
1184 }
1185 
1186 bool AutoMixedPrecisionImpl::ShouldProcess(const NodeDef& node) const {
1187   return should_process_nodes_.count(&node);
1188 }
1189 
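// Returns true if the given type attribute on the node currently resolves to
// DT_FLOAT.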
1190 bool IsFloat32(const NodeTypeId& node_type) {
1191   return GetDataType(*node_type.node, node_type.type_attr) ==
1192          DataType::DT_FLOAT;
1193 }
1194 
1195 bool IsTensorListOp(const string& op) {
1196   return op.find("TensorList") != string::npos;
1197 }
1198 
1199 bool IsTensorListReaderOp(const string& op) {
1200   static const gtl::FlatSet<string> tensor_list_reader_ops = {
1201       "TensorListConcat",  "TensorListConcatV2", "TensorListGather",
1202       "TensorListGetItem", "TensorListPopBack",  "TensorListStack"};
1203   return tensor_list_reader_ops.count(op);
1204 }
1205 
1206 bool IsTensorListWriterOp(const string& op) {
1207   static const gtl::FlatSet<string> tensor_list_writer_ops = {
1208       "TensorListFromTensor",    "TensorListPushBack",
1209       "TensorListPushBackBatch", "TensorListScatter",
1210       "TensorListScatterV2",     "TensorListScatterIntoExistingList",
1211       "TensorListSetItem",       "TensorListSplit"};
1212   return tensor_list_writer_ops.count(op);
1213 }
1214 
1215 bool AutoMixedPrecisionImpl::SupportsF16(const NodeTypeId& node_type) const {
1216   const OpDef* op_def;
1217   Status status =
1218       OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
1219   if (!status.ok()) return false;
1220   return AllowedDataTypes(*op_def, node_type.type_attr)
1221              .Contains(target_dtype_) &&
1222          NodeHasF16KernelForTypeAttr(*node_type.node, node_type.type_attr);
1223 }
1224 
1225 bool AutoMixedPrecisionImpl::SupportsF16DataType(
1226     const NodeTypeId& node_type) const {
1227   const OpDef* op_def;
1228   Status status =
1229       OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
1230   if (!status.ok()) return false;
1231   return AllowedDataTypes(*op_def, node_type.type_attr).Contains(target_dtype_);
1232 }
1233 
1234 bool AutoMixedPrecisionImpl::IsQuantized(const NodeTypeId& node_type) const {
1235   for (const TypeAttrId& type_attr :
1236        node_type_map_.GetTypeAttrs(*node_type.node)) {
1237     if (DataTypeIsQuantized(GetDataType(*node_type.node, type_attr))) {
1238       return true;
1239     }
1240   }
1241   return false;
1242 }
1243 
1244 // TODO(mconley): Make this change the node's name (to aid debugging). Need to
1245 // make sure that doing this won't break anything.
1246 void AutoMixedPrecisionImpl::ConvertBatchNormOpsToV2() {
1247   for (int node_idx = 0; node_idx < graph_->node_size(); ++node_idx) {
1248     NodeDef* node = graph_->mutable_node(node_idx);
1249     if (!ShouldProcess(*node)) continue;
1250     bool changed = false;
1251     if (node->op() == "FusedBatchNorm") {
1252       VLOG(2) << "Changing op of " << node->op() << " node " << node->name()
1253               << " to FusedBatchNormV2";
1254       node->set_op("FusedBatchNormV2");
1255       changed = true;
1256     } else if (node->op() == "FusedBatchNormGrad") {
1257       VLOG(2) << "Changing op of " << node->op() << " node " << node->name()
1258               << " to FusedBatchNormGradV2";
1259       node->set_op("FusedBatchNormGradV2");
1260       changed = true;
1261     }
1262     if (changed) {
1263       (*node->mutable_attr())["U"].set_type(DT_FLOAT);
1264     }
1265   }
1266 }
1267 
1268 // A helper function to decide whether to ignore the effect on performance when
1269 // rewriting the graph. This can be useful for testing the numerical effects of
1270 // reduced precision on systems that have poor mixed precision performance.
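// For example (illustrative usage only, the script name is hypothetical), one
// might launch a training job with the variable set in order to force the
// rewrite on hardware without fast fp16 support:
//   TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE=1 python train.py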
1271 bool ShouldIgnorePerformance() {
1272   static bool is_enabled = [] {
1273     bool ret = false;
1274     TF_CHECK_OK(ReadBoolFromEnvVar(
1275         "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE",
1276         /*default_val=*/false, &ret));
1277     return ret;
1278   }();
1279   return is_enabled;
1280 }
1281 
1282 Status AutoMixedPrecisionImpl::Optimize() {
1283   string optimization_level;
1284   TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
1285       "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level));
1286   optimization_level = absl::AsciiStrToUpper(optimization_level);
1287   force_all_fp16_ = optimization_level == "UNSAFE_FORCE_ALL";
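  // For example (illustrative usage), setting the environment variable
  //   TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL=UNSAFE_FORCE_ALL
  // requests that every supported node be forced to f16 regardless of
  // numerical safety; this is rejected below for the MKL (bfloat16) mode.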
1288   if (force_all_fp16_ && mode_ == AutoMixedPrecisionMode::MKL) {
1289     // Many ops do not support bfloat16 on the CPU, so we disallow forcing
1290     // them to bfloat16.
1291     return errors::InvalidArgument(
1292         "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL cannot be set to "
1293         "UNSAFE_FORCE_ALL when MKL is used");
1294   }
1295 
1296   std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
1297       get_mixed_precision_lists();
1298   f16_allowlist_ = mp_lists->AllowList();
1299   f16_denylist_ = mp_lists->DenyList();
1300   f16_inferlist_ = mp_lists->InferList();
1301   f16_clearlist_ = mp_lists->ClearList();
1302   TF_RETURN_IF_ERROR(ValidateLists(f16_allowlist_, f16_denylist_,
1303                                    f16_inferlist_, f16_clearlist_));
1304 
1305   size_t timestamp = Env::Default()->NowMicros() / 1000;
1306   TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp));
1307 
1308   VLOG(2) << "Identifying nodes that should be processed";
1309   for (const NodeDef& node : graph_->node()) {
1310     bool should_process;
1311     switch (mode_) {
1312       case AutoMixedPrecisionMode::CUDA:
1313         should_process =
1314             !MustPreserve(node) && IsOnDevice(node, DEVICE_GPU) &&
1315             (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node));
1316         break;
1317       case AutoMixedPrecisionMode::MKL:
1318         should_process = !MustPreserve(node) && IsOnDevice(node, DEVICE_CPU);
1319         break;
1320     }
1321     if (should_process) {
1322       should_process_nodes_.insert(&node);
1323     } else {
1324       LogSkippedNode(node);
1325     }
1326   }
1327 
1328   VLOG(2) << "Converting FusedBatchNorm* ops to V2";
1329   ConvertBatchNormOpsToV2();
1330 
1331   VLOG(2) << "Building node type map for graph";
1332   TF_RETURN_IF_ERROR(node_type_map_.Init(*graph_));
1333 
1334   VLOG(2) << "Constructing graph type attribute topology view";
1335   TF_RETURN_IF_ERROR(
1336       graph_type_view_.InitializeFromGraph(*graph_, node_type_map_));
1337 
1338   absl::flat_hash_set<int> deny_set;
1339 
1340   std::vector<absl::flat_hash_set<const NodeDef*>> tensor_list_clusters;
1341   FindFloat32TensorListOpClustersAndDenylistUnsafe(&tensor_list_clusters,
1342                                                    &deny_set);
1343   std::vector<NodeTypeIdEdge> ephemeral_edges;
1344   for (const auto& cluster : tensor_list_clusters) {
1345     VLOG(1) << "Found safe Tensor List cluster of size " << cluster.size();
1346     for (const NodeDef* node : cluster) {
1347       VLOG(2) << "  Cluster member: " << node->op() << " node " << node->name();
1348     }
1349     FindTensorListImplicitFloat32Edges(cluster, &ephemeral_edges);
1350   }
1351   TF_RETURN_IF_ERROR(graph_type_view_.AddEphemeralEdges(ephemeral_edges));
1352 
1353   // The goal here is to change performance-critical ops to fp16 or bf16, and to
1354   // do so with the minimal number of casts, subject to the constraint that the
1355   // model's convergence is not affected. This is achieved by first identifying
1356   // which nodes should be changed to f16 and then inserting casts at the
1357   // boundaries between f16/non-f16 nodes.
1358 
1359   // The algorithm for deciding which nodes to change to f16 is as follows:
1360   // 1) Add all performance-critical ops (aka "allowlist" ops) to the allow_set.
1361   //    This is done under the assumption that allowlist ops are always
1362   //    numerically-safe in f16 and that they are the most important ops for
1363   //    improving performance.
1364   // 2) Add nodes to the deny_set iff they are numerically-dangerous (aka
1365   //    "denylist" ops) or they are on a forward path from a denylist node to
1366   //    a deny/infer node (including the node at the end of the path) through
1367   //    non-numerically-dangerous ops (aka "inferlist" and "clearlist" ops).
1368   //    This is done to prevent numerically-dangerous ops and their downstream
1369   //    effects from being changed to f16, which would risk breaking the
1370   //    numerical accuracy of the model.
1371   // 3) For all remaining nodes that are not considered dangerous (inferlist
1372   //    and clearlist ops), find those that are between (i.e., both upstream
1373   //    and downstream of) allow nodes, and add them to the allow_set.
1374   //    This is done to avoid unnecessary casts between allowlist ops.
1375   // 4) For all remaining clearlist nodes, add them to the allow_set if they are
1376   //    connected to a node in the allow_set via other clearlist nodes.
1377   //    This is done to increase the number of ops in the allow_set without
1378   //    affecting numerical stability.
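  // As a hypothetical illustration of the four passes, consider the chain
  //   MatMul -> Relu -> MatMul -> Exp -> Mean
  // where MatMul is on the allowlist, Relu on the clearlist, Exp on the
  // denylist, and Mean on the inferlist. Pass 1 paints both MatMuls ALLOW,
  // pass 2 paints Exp and the downstream Mean DENY, pass 3 paints the Relu
  // between the two MatMuls ALLOW, and the final pass inserts a Cast back to
  // fp32 between the second MatMul and Exp (plus whatever casts are needed at
  // the chain's fp32 inputs).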
1379 
1380   absl::flat_hash_set<int> allow_set;
1381   VLOG(2) << "Beginning pass 1 to add allowlist ops";
1382   AddAllowlistOps(&allow_set);
1383   VLOG(2) << "Finished pass 1";
1384 
1385   if (allow_set.empty()) {
1386     LOG(INFO) << "No allowlist ops found, nothing to do";
1387     return Status::OK();
1388   }
1389 
1390   VLOG(2) << "Beginning pass 2 to propagate deny forwards from denylist ops "
1391              "through clear/inferlist ops";
1392   PropagateDenyFwdThroughClearAndInfer(&deny_set);
1393   VLOG(2) << "Finished pass 2";
1394 
1395   VLOG(2) << "Forcing color match between data structure ops";
1396   for (const auto& cluster : tensor_list_clusters) {
1397     ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
1398   }
1399 
1400   VLOG(2) << "Beginning pass 3 to set clear and infer nodes to allow if they "
1401              "are between allow ops";
1402   AddClearAndInferToAllowIfBetweenAllow(deny_set, &allow_set);
1403   VLOG(2) << "Finished pass 3";
1404 
1405   VLOG(2) << "Beginning pass 4 to propagate allow from allow nodes through "
1406              "clearlist ops";
1407   PropagateAllowThroughClear(deny_set, &allow_set);
1408   VLOG(2) << "Finished pass 4";
1409 
1410   VLOG(2) << "Beginning pass 5 to remove nodes that could not be changed "
1411              "to F16 "
1412              "from the allow set";
1413   RemoveAllowsetWithFp32(&allow_set);
1414   VLOG(2) << "Finished pass 5";
1415 
1416   VLOG(2) << "Forcing color match between data structure ops";
1417   for (const auto& cluster : tensor_list_clusters) {
1418     ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
1419   }
1420 
1421   VLOG(2) << "Forcing color match on loop edges";
1422   TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&allow_set));
1423 
1424   VLOG(2) << "Finding existing casts that can be made allow";
1425   MakeCastsAllowIfAllOutputsAllow(&allow_set);
1426 
1427   VLOG(2) << "Beginning final pass to change type attributes and insert Cast "
1428              "ops at paint boundaries";
1429   TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(allow_set));
1430   VLOG(2) << "Finished final pass";
1431 
1432   TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ false, timestamp));
1433 
1434   return Status::OK();
1435 }
1436 
1437 // If node is a Tensor List op with a float32 data type attribute then this
1438 // returns a pointer to the NodeTypeId representing that type attribute. In
1439 // all other cases this returns nullptr.
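// For example (illustrative), a TensorListStack node whose 'element_dtype'
// attribute is DT_FLOAT yields the NodeTypeId for that attribute, while the
// same op with DT_INT32 elements (or any non-TensorList op) yields nullptr.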
1440 const NodeTypeId* AutoMixedPrecisionImpl::GetTensorListFloat32NodeTypeId(
1441     const NodeDef& node) const {
1442   if (!IsTensorListOp(node.op())) return nullptr;
1443   for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(node)) {
1444     const NodeTypeId* node_type =
1445         graph_type_view_.GetNode(node.name(), type_attr);
1446     // This assumes that the float32 data type on a Tensor List op is always a
1447     // non-fixed type attribute containing a single type, and that this type
1448     // attribute represents the dtype of the values in the list.
1449     // TODO(benbarsdell): A new Tensor List op could theoretically break these
1450     // assumptions.
1451     if (node_type && node_type->type_attr.fixed_type == DT_INVALID &&
1452         node_type->type_attr.type_index == TypeAttrId::kSingleType &&
1453         IsFloat32(*node_type)) {
1454       return node_type;
1455     }
1456   }
1457   return nullptr;
1458 }
1459 
1460 bool AutoMixedPrecisionImpl::IsSourceOrSinkOp(const string& op) const {
1461   const gtl::FlatSet<string> source_and_sink_ops = {
1462       "_Arg",
1463       "_Retval",
1464       "OptionalFromValue",
1465       "OptionalGetValue",
1466       "PartitionedCall",
1467       "Placeholder",
1468       "StatefulPartitionedCall",
1469   };
1470   return source_and_sink_ops.count(op) || function_library_.Find(op);
1471 }
1472 
1473 // Finds all clusters of float32 Tensor List nodes that are connected via their
1474 // handle edges. Unsafe clusters (those with unprocessable nodes, or with edges
1475 // that cross untraversable boundaries via _Arg, _Ret, PartitionedCall etc.
1476 // nodes) are added to deny_set. The caller should paint all nodes in a cluster
1477 // the same color, as they may all refer to the same Tensor List.
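// For instance (hypothetical), TensorListReserve -> TensorListSetItem ->
// TensorListGetItem nodes linked by their DT_VARIANT handle edges form one
// cluster; if that handle also flows into a PartitionedCall, the cluster is
// treated as unsafe and denied because the boundary cannot be traversed.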
1478 void AutoMixedPrecisionImpl::FindFloat32TensorListOpClustersAndDenylistUnsafe(
1479     std::vector<absl::flat_hash_set<const NodeDef*>>* tensor_list_clusters,
1480     absl::flat_hash_set<int>* deny_set) const {
1481   absl::flat_hash_set<const NodeDef*> tensor_list_prop_set;
1482   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1483     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1484     if (!ShouldProcess(*root.node) ||
1485         root.type_attr.fixed_type != DataType::DT_VARIANT ||
1486         !GetTensorListFloat32NodeTypeId(*root.node) ||
1487         tensor_list_prop_set.count(root.node)) {
1488       continue;
1489     }
1490     const NodeTypeId* root_fp32 = GetTensorListFloat32NodeTypeId(*root.node);
1491     const absl::optional<int> maybe_root_fp32_idx =
1492         graph_type_view_.GetNodeIndex(*root_fp32);
1493     DCHECK(maybe_root_fp32_idx.has_value())
1494         << "Type attribute " << root_fp32->type_attr.DebugString()
1495         << " of node " << root.node->name() << " not found in graph view";
1496     int root_fp32_idx = maybe_root_fp32_idx.value();
1497     // Traverse Tensor List handle edges (DT_VARIANT) to find cluster of all
1498     // connected Tensor List nodes.
1499     absl::flat_hash_set<const NodeDef*> cluster({root.node});
1500     DfsTypeTraversal(graph_type_view_, {&root},
1501                      TypeTraversalDirection::kFollowInputsAndOutputs,
1502                      DfsTypePredicates::Enter([&](int idx) -> bool {
1503                        const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1504                        return !tensor_list_prop_set.count(item.node);
1505                      }),
1506                      DfsTypeCallbacks::PreOrder([&](int idx) {
1507                        const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1508                        const NodeDef* node = item.node;
1509                        if (GetTensorListFloat32NodeTypeId(*node)) {
1510                          cluster.insert(node);
1511                          if (!ShouldProcess(*node)) {
1512                            // The cluster contains an un-processable node.
1513                            deny_set->insert(root_fp32_idx);
1514                          }
1515                          // TODO(benbarsdell): In a theoretical pathological
1516                          // case of a Tensor List of Tensor List handles, the
1517                          // Tensor List itself would need to be treated as a
1518                          // sink.
1519                        } else if (IsSourceOrSinkOp(node->op())) {
1520                          // The cluster crosses an untraversable boundary.
1521                          deny_set->insert(root_fp32_idx);
1522                        }
1523                      }));
1524     tensor_list_clusters->push_back(cluster);
1525   }
1526 }
1527 
1528 // Finds all writer -> reader pairs in the given set that are connected via
1529 // their handles, and adds corresponding float32 edges to *implicit_fp32_edges.
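// For example (illustrative), if a TensorListPushBack writes to a list whose
// handle later reaches a TensorListStack reader, an ephemeral float32 edge
// PushBack -> Stack is recorded so that the painting passes treat the two ops
// as directly connected even though the real edge carries a DT_VARIANT handle.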
1530 void AutoMixedPrecisionImpl::FindTensorListImplicitFloat32Edges(
1531     const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1532     std::vector<NodeTypeIdEdge>* implicit_fp32_edges) const {
1533   for (const NodeDef* root_node : tensor_list_nodes) {
1534     if (!IsTensorListReaderOp(root_node->op())) continue;
1535     NodeTypeId root(root_node, TypeAttrId(DataType::DT_VARIANT));
1536     const NodeTypeId* root_fp32 = GetTensorListFloat32NodeTypeId(*root.node);
1537     CHECK(root_fp32) << "No float32 type attribute found for "  // Crash OK
1538                      << root.node->op() << " node " << root.node->name();
1539     // Search backwards through handle edges (DT_VARIANT) for all writer ops,
1540     // adding direct implicit edges between them and the reader.
1541     DfsTypeTraversal(
1542         graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
1543         DfsTypePredicates::Enter([&](int idx) -> bool {
1544           const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1545           return ShouldProcess(*item.node);
1546         }),
1547         DfsTypeCallbacks::PreOrder([&](int idx) {
1548           const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1549           if (IsTensorListWriterOp(item.node->op())) {
1550             const NodeTypeId* item_fp32 =
1551                 GetTensorListFloat32NodeTypeId(*item.node);
1552             CHECK(item_fp32)  // Crash OK
1553                 << "No float32 type attribute found for " << item.node->op()
1554                 << " node " << item.node->name();
1555             VLOG(2) << "Adding ephemeral float32 edge from "
1556                     << item_fp32->node->op() << " node "
1557                     << item_fp32->node->name() << " to "
1558                     << root_fp32->node->op() << " node "
1559                     << root_fp32->node->name();
1560             implicit_fp32_edges->emplace_back(*item_fp32, *root_fp32);
1561           }
1562         }));
1563   }
1564 }
1565 
1566 void AutoMixedPrecisionImpl::AddAllowlistOps(
1567     absl::flat_hash_set<int>* allow_set) const {
1568   // Add allowlisted ops to allow_set.
1569   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1570     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1571     if (!ShouldProcess(*root.node)) continue;
1572     bool force_allow = force_all_fp16_ && CanForceFP16(*root.node);
1573     if (f16_allowlist_.count(root.node->op()) || force_allow) {
1574       bool inserted = allow_set->insert(root_idx).second;
1575       if (VLOG_IS_ON(2) && inserted) {
1576         VLOG(2) << "Painting type " << root.type_attr.DebugString()
1577                 << " of node " << root.node->name() << " ALLOW because its op "
1578                 << root.node->op() << " is on the allowlist";
1579       }
1580     }
1581   }
1582 }
1583 
1584 // Adds nodes to deny_set iff they are on the denylist or they are on a
1585 // forward path from a denylist node to a deny/infer node (including the node
1586 // at the end of the path) through clear and infer nodes.
1587 // E.g., deny -> infer -> clear -> infer -> clear -> allow -> infer
1588 // becomes: deny -> deny -> deny -> deny -> clear -> allow -> infer.
1589 void AutoMixedPrecisionImpl::PropagateDenyFwdThroughClearAndInfer(
1590     absl::flat_hash_set<int>* deny_set) const {
1591   if (force_all_fp16_) return;
1592 
1593   // Find clear nodes that are upstream of deny or infer.
1594   absl::flat_hash_set<int> upstream_of_deny_or_infer_set;
1595   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1596     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1597     if (!(f16_denylist_.count(root.node->op()) ||
1598           f16_inferlist_.count(root.node->op()))) {
1599       continue;
1600     }
1601     DfsTypeTraversal(graph_type_view_, {&root},
1602                      TypeTraversalDirection::kFollowInputs,
1603                      DfsTypePredicates::Enter([&](int idx) -> bool {
1604                        const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1605                        return idx == root_idx ||
1606                               (!upstream_of_deny_or_infer_set.count(idx) &&
1607                                f16_clearlist_.count(item.node->op()));
1608                      }),
1609                      DfsTypeCallbacks::PreOrder([&](int idx) {
1610                        upstream_of_deny_or_infer_set.insert(idx);
1611                      }));
1612   }
1613 
1614   // Propagate deny forward through nodes in upstream_of_deny_or_infer_set.
1615   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1616     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1617     if (deny_set->count(root_idx) || !f16_denylist_.count(root.node->op())) {
1618       continue;
1619     }
1620     DfsTypeTraversal(
1621         graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
1622         DfsTypePredicates::Enter([&](int idx) -> bool {
1623           return idx == root_idx || (!deny_set->count(idx) &&
1624                                      upstream_of_deny_or_infer_set.count(idx));
1625         }),
1626         DfsTypeCallbacks::PreOrder([&](int idx) {
1627           bool inserted = deny_set->insert(idx).second;
1628           if (VLOG_IS_ON(2) && inserted) {
1629             const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1630             VLOG(2) << "Painting type " << item.type_attr.DebugString()
1631                     << " of " << item.node->op() << " node "
1632                     << item.node->name() << " DENY";
1633           }
1634         }));
1635   }
1636 }
1637 
1638 void AutoMixedPrecisionImpl::AddClearAndInferToAllowIfBetweenAllow(
1639     const absl::flat_hash_set<int>& deny_set,
1640     absl::flat_hash_set<int>* allow_set) const {
1641   // Find clear/inferlist ops that are downstream of allow ops.
1642   absl::flat_hash_set<int> downstream_of_allow_set;
1643   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1644     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1645     if (!ShouldProcess(*root.node) || !f16_allowlist_.count(root.node->op())) {
1646       continue;
1647     }
1648     DfsTypeTraversal(
1649         graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
1650         DfsTypePredicates::Enter([&](int idx) -> bool {
1651           const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1652           return idx == root_idx ||
1653                  (!downstream_of_allow_set.count(idx) &&
1654                   !f16_allowlist_.count(item.node->op()) &&
1655                   !deny_set.count(idx) && ShouldProcess(*item.node) &&
1656                   // TODO(benbarsdell): Consider allowing propagation through
1657                   // ops that are already float16 in order to reduce the number
1658                   // of casts.
1659                   IsFloat32(item) && SupportsF16(item) &&
1660                   (f16_clearlist_.count(item.node->op()) ||
1661                    f16_inferlist_.count(item.node->op())));
1662         }),
1663         DfsTypeCallbacks::PreOrder(
1664             [&](int idx) { downstream_of_allow_set.insert(idx); }));
1665   }
1666 
1667   // Set nodes that are both downstream and upstream of allow ops to allow.
1668   absl::flat_hash_set<int> upstream_of_allow_set;
1669   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1670     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1671     if (!ShouldProcess(*root.node) || upstream_of_allow_set.count(root_idx) ||
1672         !f16_allowlist_.count(root.node->op())) {
1673       continue;
1674     }
1675     DfsTypeTraversal(
1676         graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
1677         DfsTypePredicates::Enter([&](int idx) -> bool {
1678           return idx == root_idx || (!upstream_of_allow_set.count(idx) &&
1679                                      downstream_of_allow_set.count(idx));
1680         }),
1681         DfsTypeCallbacks::PreOrder([&](int idx) {
1682           upstream_of_allow_set.insert(idx);
1683           bool inserted = allow_set->insert(idx).second;
1684           if (VLOG_IS_ON(2) && inserted) {
1685             const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1686             VLOG(2) << "Painting type " << item.type_attr.DebugString()
1687                     << " of " << item.node->op() << " node "
1688                     << item.node->name() << " ALLOW";
1689           }
1690         }));
1691   }
1692 }
1693 
1694 void AutoMixedPrecisionImpl::PropagateAllowThroughClear(
1695     const absl::flat_hash_set<int>& deny_set,
1696     absl::flat_hash_set<int>* allow_set) const {
1697   // Propagate allow from allow nodes through clearlist ops.
1698   absl::flat_hash_set<int> clear_prop_set;
1699   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1700     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1701     if (!ShouldProcess(*root.node) || clear_prop_set.count(root_idx) ||
1702         !allow_set->count(root_idx)) {
1703       continue;
1704     }
1705     DfsTypeTraversal(
1706         graph_type_view_, {&root},
1707         TypeTraversalDirection::kFollowInputsAndOutputs,
1708         DfsTypePredicates::Enter([&](int idx) -> bool {
1709           const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1710           return idx == root_idx ||
1711                  (!allow_set->count(idx) && !deny_set.count(idx) &&
1712                   ShouldProcess(*item.node) && IsFloat32(item) &&
1713                   SupportsF16(item) &&
1714                   (f16_clearlist_.count(item.node->op())) &&
1715                   // We don't propagate (backwards) through nodes that read
1716                   // Variables because it can break the behavior of TensorBoard
1717                   // visualization and/or (in the case of Enter nodes) the model
1718                   // itself. This is only a problem for non-resource variables.
1719                   !NodeImplicitlyReadsNonResourceVariable(*item.node));
1720         }),
1721         DfsTypeCallbacks::PreOrder([&](int idx) {
1722           clear_prop_set.insert(idx);
1723           bool inserted = allow_set->insert(idx).second;
1724           if (VLOG_IS_ON(2) && inserted) {
1725             const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1726             VLOG(2) << "Painting type " << item.type_attr.DebugString()
1727                     << " of " << item.node->op() << " node "
1728                     << item.node->name() << " ALLOW";
1729           }
1730         }));
1731   }
1732 }
1733 
1734 // Removes a node from allow_set if it has a type_attr that cannot be
1735 // converted to F16. For example, FusedBatchNormV2/FusedBatchNormV3 have a
1736 // type_attr 'U' that only supports float, so such nodes are removed from
1737 // allow_set. Quantized ops are also not converted to FP16.
1738 void AutoMixedPrecisionImpl::RemoveAllowsetWithFp32(
1739     absl::flat_hash_set<int>* allow_set) const {
1740   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1741     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1742     if (f16_allowlist_.count(root.node->op()) && allow_set->count(root_idx) &&
1743         (!SupportsF16DataType(root) || IsQuantized(root))) {
1744       auto erased = allow_set->erase(root_idx);
1745       if (VLOG_IS_ON(2) && erased) {
1746         VLOG(2) << "Unpainting type " << root.type_attr.DebugString()
1747                 << " of node " << root.node->name() << " ALLOW because its op "
1748                 << root.node->op() << " does not support the F16 data type";
1749       }
1750     }
1751   }
1752 }
1753 
1754 // Forces NextIteration nodes and their output Merge node(s) to have the same
1755 // color. Specifically, it removes them all from allow_set if any of the Merge
1756 // nodes is not in allow_set, otherwise it adds the NextIteration node to
1757 // allow_set.
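// For example (hypothetical while loop), if a NextIteration node feeds two
// Merge nodes and only one of them is in allow_set, all three nodes are left
// out of allow_set so that no Cast is required on the loop's back edge.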
1758 Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
1759     absl::flat_hash_set<int>* allow_set) const {
1760   for (const NodeDef& node : graph_->node()) {
1761     if (node.op() == "NextIteration") {
1762       GraphView::OutputPort output_port(&node, 0);
1763       const auto& fanout = graph_view_.GetFanout(output_port);
1764       std::vector<int> merge_idxs;
1765       merge_idxs.reserve(fanout.size());
1766       bool any_merge_is_not_allow = false;
1767       for (const auto& output : fanout) {
1768         const NodeDef& merge_node = *output.node;
1769         if (merge_node.op() != "Merge") {
1770           return errors::FailedPrecondition(
1771               "Expected Merge node after NextIteration, got ", merge_node.op());
1772         }
1773         const absl::optional<int> maybe_merge_idx =
1774             graph_type_view_.GetNodeIndex(merge_node.name(), TypeAttrId("T"));
1775         if (!maybe_merge_idx.has_value()) {
1776           return errors::Internal("Type attribute T of Merge node ",
1777                                   merge_node.name(),
1778                                   " not found in graph view");
1779         }
1780         int merge_idx = maybe_merge_idx.value();
1781         merge_idxs.push_back(merge_idx);
1782         any_merge_is_not_allow =
1783             any_merge_is_not_allow || !allow_set->count(merge_idx);
1784       }
1785       const absl::optional<int> maybe_nextiter_idx =
1786           graph_type_view_.GetNodeIndex(node.name(), TypeAttrId("T"));
1787       if (!maybe_nextiter_idx.has_value()) {
1788         return errors::Internal("Type attribute T of NextIteration node ",
1789                                 node.name(), " not found in graph view");
1790       }
1791       int nextiter_idx = maybe_nextiter_idx.value();
1792       if (any_merge_is_not_allow) {
1793         for (int merge_idx : merge_idxs) {
1794           if (allow_set->erase(merge_idx)) {
1795             VLOG(2) << "Painting type T of Merge node "
1796                     << graph_type_view_.GetNode(merge_idx)->node->name()
1797                     << " DENY to match the color of its sibling Merge nodes "
1798                        "with common NextIteration node "
1799                     << node.name();
1800           }
1801         }
1802         if (allow_set->erase(nextiter_idx)) {
1803           VLOG(2) << "Painting type T of NextIteration node " << node.name()
1804                   << " DENY to match the color of its output Merge node(s)";
1805         }
1806       } else {
1807         if (allow_set->insert(nextiter_idx).second) {
1808           VLOG(2) << "Painting type T of NextIteration node " << node.name()
1809                   << " ALLOW to match the color of its output Merge node(s)";
1810         }
1811       }
1812     }
1813   }
1814   return Status::OK();
1815 }
1816 
1817 // Forces all of the given Tensor List nodes into the same color set.
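// For example (illustrative), if any TensorListSetItem in a cluster was
// painted DENY, every other op touching that list is repainted DENY as well,
// so the element dtype of the list stays consistent across all of its readers
// and writers.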
1818 void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps(
1819     const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1820     absl::flat_hash_set<int>* allow_set,
1821     absl::flat_hash_set<int>* deny_set) const {
1822   bool any_deny = false;
1823   bool any_allow = false;
1824   std::vector<int> node_type_idxs;
1825   node_type_idxs.reserve(tensor_list_nodes.size());
1826   for (const NodeDef* node : tensor_list_nodes) {
1827     const NodeTypeId& node_type = *GetTensorListFloat32NodeTypeId(*node);
1828     const absl::optional<int> maybe_node_type_idx =
1829         graph_type_view_.GetNodeIndex(node_type);
1830     DCHECK(maybe_node_type_idx.has_value())
1831         << "Type attribute " << node_type.type_attr.DebugString() << " of node "
1832         << node->name() << " not found in graph view";
1833     node_type_idxs.push_back(maybe_node_type_idx.value());
1834   }
1835   for (int node_type_idx : node_type_idxs) {
1836     if (deny_set->count(node_type_idx)) {
1837       any_deny = true;
1838       break;
1839     } else if (allow_set->count(node_type_idx)) {
1840       any_allow = true;
1841     }
1842   }
1843   if (!any_deny && !any_allow) return;
1844   for (int node_type_idx : node_type_idxs) {
1845     const NodeTypeId& node_type = *graph_type_view_.GetNode(node_type_idx);
1846     VLOG(2) << "Painting type " << node_type.type_attr.DebugString() << " of "
1847             << node_type.node->op() << " node " << node_type.node->name() << " "
1848             << (any_deny ? "DENY" : "ALLOW")
1849             << " because at least one of its siblings is "
1850             << (any_deny ? "DENY" : "ALLOW");
1851     if (any_deny) {
1852       allow_set->erase(node_type_idx);
1853       deny_set->insert(node_type_idx);
1854     } else {
1855       allow_set->insert(node_type_idx);
1856     }
1857   }
1858 }
1859 
1860 bool AutoMixedPrecisionImpl::NodeImplicitlyReadsNonResourceVariable(
1861     const NodeDef& node) const {
1862   if (node.op() == "Identity" || node.op() == "Enter") {
1863     GraphView::InputPort node_input(&node, 0);
1864     MutableGraphView::OutputPort prev_output =
1865         graph_view_.GetRegularFanin(node_input);
1866     const NodeDef* input = prev_output.node;
1867     if (input && ((node.op() == "Identity" && (input->op() == "Variable" ||
1868                                                input->op() == "VariableV2")) ||
1869                   (node.op() == "Enter" &&
1870                    NodeImplicitlyReadsNonResourceVariable(*input)))) {
1871       return true;
1872     }
1873   }
1874   return false;
1875 }
1876 
1877 // This adds existing Cast nodes to allow_set if all of their outputs are allow,
1878 // avoiding the need to add a new Cast node after an existing Cast.
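// For example (illustrative), a pre-existing Cast with DstT=float whose
// consumers are all painted ALLOW is itself painted ALLOW, so the final pass
// simply rewrites its DstT to f16 instead of appending a second Cast node.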
1879 void AutoMixedPrecisionImpl::MakeCastsAllowIfAllOutputsAllow(
1880     absl::flat_hash_set<int>* allow_set) const {
1881   int num_nodes_preop = graph_->node_size();
1882   for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
1883     NodeDef* node = graph_->mutable_node(node_idx);
1884     NodeTypeId node_type(node, TypeAttrId("DstT"));
1885     if (node->op() != "Cast" || !IsFloat32(node_type)) {
1886       continue;
1887     }
1888     bool all_fanouts_allow = true;
1889     MutableGraphView::OutputPort src(node, 0);
1890     const auto& fanout = graph_view_.GetFanout(src);
1891     for (const MutableGraphView::InputPort& dst : fanout) {
1892       TypeAttrId dst_type_attr =
1893           node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
1894       const absl::optional<int> maybe_dst_type_idx =
1895           graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
1896       DCHECK(maybe_dst_type_idx.has_value())
1897           << "Type attribute " << dst_type_attr.DebugString() << " of node "
1898           << dst.node->name() << " not found in graph view";
1899       int dst_type_idx = maybe_dst_type_idx.value();
1900       bool dst_is_allow = allow_set->count(dst_type_idx);
1901       if (!dst_is_allow) {
1902         all_fanouts_allow = false;
1903         break;
1904       }
1905     }
1906     if (!fanout.empty() && all_fanouts_allow) {
1907       const absl::optional<int> maybe_node_type_idx =
1908           graph_type_view_.GetNodeIndex(node_type);
1909       DCHECK(maybe_node_type_idx.has_value())
1910           << "Type attribute " << node_type.type_attr.DebugString()
1911           << " of node " << node_type.node->name()
1912           << " not found in graph view";
1913       int node_type_idx = maybe_node_type_idx.value();
1914       allow_set->insert(node_type_idx);
1915     }
1916   }
1917 }
1918 
1919 // Changes all allow-painted type attributes to DT_HALF or DT_BFLOAT16, and
1920 // inserts Cast nodes at node outputs for all edges that connect
1921 // allow-painted <-> non-allow-painted type attributes.
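// For example (illustrative), an allow-painted MatMul that feeds a
// non-allow-painted Exp ends up as
//   MatMul(f16) -> Cast(f16 -> fp32) -> Exp(fp32)
// while a non-allow node feeding an allow node gets a Cast to f16 instead.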
1922 Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
1923     const absl::flat_hash_set<int>& allow_set) {
1924   int num_nodes_changed = 0;
1925   int num_nonvar_casts_to_f16 = 0;
1926   int num_nodes_preop = graph_->node_size();
1927   for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
1928     NodeDef* node = graph_->mutable_node(node_idx);
1929     for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(*node)) {
1930       const absl::optional<int> maybe_node_type_idx =
1931           graph_type_view_.GetNodeIndex(node->name(), type_attr);
1932       if (!maybe_node_type_idx.has_value()) {
1933         return errors::Internal("Type attribute ", type_attr.DebugString(),
1934                                 " of ", node->op(), " node ", node->name(),
1935                                 " not found in graph view");
1936       }
1937       int node_type_idx = maybe_node_type_idx.value();
1938       if (!IsFloat32(*graph_type_view_.GetNode(node_type_idx))) continue;
1939       bool src_is_allow = allow_set.count(node_type_idx);
1940       if (src_is_allow) {
1941         VLOG(1) << "Changing type " << type_attr.DebugString() << " of "
1942                 << node->op() << " node " << node->name() << " to "
1943                 << DataTypeString(target_dtype_);
1944         if (!SetDataType(node, type_attr, target_dtype_)) {
1945           return errors::Internal("Failed to set type attribute");
1946         }
1947         ++num_nodes_changed;
1948       }
1949       for (int output_port : node_type_map_.GetOutputPorts(*node, type_attr)) {
1950         MutableGraphView::OutputPort src(node, output_port);
1951         NodeDef* added_cast_node = nullptr;
1952         // Note: This is copied so that edges can be modified inside the loop.
1953         auto fanout = graph_view_.GetFanout(src);
1954         for (const MutableGraphView::InputPort& dst : fanout) {
1955           TypeAttrId dst_type_attr =
1956               node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
1957           const absl::optional<int> maybe_dst_type_idx =
1958               graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
1959           if (!maybe_dst_type_idx.has_value()) {
1960             return errors::Internal("Type attribute ",
1961                                     dst_type_attr.DebugString(), " of ",
1962                                     dst.node->op(), " node ", dst.node->name(),
1963                                     " not found in graph view");
1964           }
1965           int dst_type_idx = maybe_dst_type_idx.value();
1966           bool dst_is_allow = allow_set.count(dst_type_idx);
1967           if (src_is_allow != dst_is_allow) {
1968             if (!added_cast_node) {
1969               bool to_f16 = dst_is_allow;
1970               VLOG(1) << "Inserting cast to "
1971                       << (to_f16 ? DataTypeString(target_dtype_) : "DT_FLOAT")
1972                       << " at " << src.node->op() << " " << src.node->name()
1973                       << ":" << src.port_id;
1974               added_cast_node = graph_view_.AddNode(
1975                   BuildCastNode(src, to_f16, src.node->device()));
1976               if (to_f16 && !IsConstant(*node) && !IsVariable(*node) &&
1977                   !NodeImplicitlyReadsNonResourceVariable(*node)) {
1978                 ++num_nonvar_casts_to_f16;
1979               }
1980             }
1981             TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort(
1982                 dst.node->name(), dst.port_id, {added_cast_node->name(), 0}));
1983           }
1984         }
1985       }
1986     }
1987   }
1988   // Use Python type names (e.g. float16) instead of C++ type names (e.g. half)
1989   // since many Python users will see this message.
1990   const char* type_str = target_dtype_ == DT_HALF ? "float16" : "bfloat16";
1991   LOG(INFO) << "Converted " << num_nodes_changed << "/" << num_nodes_preop
1992             << " nodes to " << type_str << " precision using "
1993             << num_nonvar_casts_to_f16 << " cast(s) to " << type_str
1994             << " (excluding Const and Variable casts)";
1995   return Status::OK();
1996 }
1997 
1998 int GetNumGPUs(const Cluster& cluster) {
1999   auto devices = cluster.GetDevices();
2000   int num_gpus = 0;
2001   for (const auto& device : devices) {
2002     const DeviceProperties& device_properties = device.second;
2003     if (device_properties.type() == "GPU" &&
2004         (ShouldIgnorePerformance() || HasFastFP16Support(device_properties))) {
2005       num_gpus++;
2006     }
2007   }
2008   return num_gpus;
2009 }
2010 
2011 }  // end namespace
2012 
2013 Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item,
2014                                     GraphDef* output) {
2015   if (cluster == nullptr) {
2016     return errors::InvalidArgument("cluster == nullptr");
2017   }
2018 
2019 #if !defined(INTEL_MKL)
2020   if (mode_ == AutoMixedPrecisionMode::MKL) {
2021     return errors::Unimplemented(
2022         "The auto_mixed_precision_mkl optimizer cannot be used since "
2023         "this build of TensorFlow is not compiled with MKL support for "
2024         "bfloat16. "
2025         "For information on MKL builds, see: "
2026         "https://software.intel.com/en-us/articles/intel-optimization-for-"
2027         "tensorflow-installation-guide");
2028   }
2029 #endif  // INTEL_MKL
2030 
2031   // Start by copying input graph to output.
2032   *output = item.graph;
2033 
2034   int num_gpus = GetNumGPUs(*cluster);
2035   if (num_gpus < 1 && mode_ == AutoMixedPrecisionMode::CUDA) {
2036     // AutoMixedPrecision is currently only tuned for GPU.
2037     LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name()
2038                  << " graph optimizer";
2039     return Status::OK();
2040   }
2041 
2042   // Optimize the output graph in-place.
2043   AutoMixedPrecisionImpl optimizer(cluster, item.NodesToPreserve(), output,
2044                                    item.id, mode_);
2045   if (item.id == "tf_graph") {
2046     LOG(INFO) << "Running " << name() << " graph optimizer";
2047   } else {
2048     VLOG(1) << "Running " << name() << " graph optimizer on " << item.id;
2049   }
2050   Status status = optimizer.Optimize();
2051   if (!status.ok()) {
2052     // Restore the original graph.
2053     *output = item.graph;
2054     LOG(WARNING) << name() << " graph optimizer FAILED: " << status.ToString();
2055   }
2056   return status;
2057 }
2058 
2059 }  // end namespace grappler
2060 }  // end namespace tensorflow
2061