1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
17 
18 #include <fstream>
19 #include <memory>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/op.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/grappler/clusters/cluster.h"
29 #include "tensorflow/core/grappler/costs/virtual_placer.h"
30 #include "tensorflow/core/grappler/devices.h"
31 #include "tensorflow/core/grappler/grappler_item.h"
32 #include "tensorflow/core/grappler/mutable_graph_view.h"
33 #include "tensorflow/core/grappler/op_types.h"
34 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h"
35 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
36 #include "tensorflow/core/grappler/utils.h"
37 #include "tensorflow/core/lib/io/path.h"
38 #include "tensorflow/core/lib/strings/numbers.h"
39 #include "tensorflow/core/lib/strings/str_util.h"
40 #include "tensorflow/core/lib/strings/strcat.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/util/env_var.h"
43 
44 namespace tensorflow {
45 namespace grappler {
46 namespace {
47 
48 #if GOOGLE_CUDA
49 const std::pair<int, int> kMinGPUArch = {7, 0};
50 #else
51 const std::pair<int, int> kMinGPUArch = {0, 0};
52 #endif
53 
54 const char kSuffix[] = "AutoMixedPrecision";
55 const char kCastToFp16[] = "CastToFp16";
56 const char kCastToBf16[] = "CastToBf16";
57 const char kCastToFp32[] = "CastToFp32";
58 
59 // Instances of this class represent unique type attribute identifiers within a
60 // node. It handles regular type attributes, list type attributes (where
61 // type_index is set to the index in the type list), and fixed types.
62 struct TypeAttrId {
63   static constexpr int kSingleType = -1;
64 
65   explicit TypeAttrId(const string& _attr_name, int _type_index = kSingleType)
66       : attr_name(_attr_name),
67         type_index(_type_index),
68         fixed_type(DT_INVALID) {}
69 
70   explicit TypeAttrId(DataType _fixed_type)
71       : attr_name(), type_index(kSingleType), fixed_type(_fixed_type) {}
72 
73   bool operator==(const TypeAttrId& other) const {
74     return attr_name == other.attr_name && type_index == other.type_index &&
75            fixed_type == other.fixed_type;
76   }
77 
78   bool operator<(const TypeAttrId& other) const {
79     return std::make_tuple(attr_name, type_index, fixed_type) <
80            std::make_tuple(other.attr_name, other.type_index, other.fixed_type);
81   }
82 
83   template <typename H>
84   friend H AbslHashValue(H h, const TypeAttrId& ta) {
85     return H::combine(std::move(h), ta.attr_name, ta.type_index, ta.fixed_type);
86   }
87 
88   string DebugString() const {
89     if (!attr_name.empty()) {
90       if (type_index == kSingleType) {
91         return attr_name;
92       } else {
93         return strings::StrCat(attr_name, "[", type_index, "]");
94       }
95     } else {
96       return tensorflow::DataTypeString(fixed_type);
97     }
98   }
99 
100   string attr_name;
101   // If attr_name is a list(type), this is the index into the list. Otherwise
102   // this is kSingleType.
103   int type_index;
104   DataType fixed_type;
105 };
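// Illustrative sketch (hypothetical values, added for exposition): a node with
// a single type attribute "T" is identified by TypeAttrId("T"); element 2 of a
// list(type) attribute "Tcomponents" by TypeAttrId("Tcomponents", 2); and a
// fixed (non-attribute) type by TypeAttrId(DT_INT32). Their DebugString()
// values would be "T", "Tcomponents[2]", and "int32" respectively.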
106 
107 // Returns the data type of the given type attribute, or DT_INVALID if the type
108 // attribute is invalid.
109 DataType GetDataType(const NodeDef& node, const TypeAttrId& type_attr) {
110   if (type_attr.attr_name.empty()) {
111     return type_attr.fixed_type;
112   }
113   if (!node.attr().count(type_attr.attr_name)) {
114     return DT_INVALID;
115   }
116   const AttrValue& attr_value = node.attr().at(type_attr.attr_name);
117   if (type_attr.type_index == TypeAttrId::kSingleType) {
118     return attr_value.type();
119   } else {
120     if (type_attr.type_index < 0 ||
121         type_attr.type_index >= attr_value.list().type_size()) {
122       return DT_INVALID;
123     }
124     return attr_value.list().type(type_attr.type_index);
125   }
126 }
127 
128 // Sets the data type of the given type attribute. Returns false if the type
129 // attribute is invalid, otherwise true.
130 bool SetDataType(NodeDef* node, const TypeAttrId& type_attr, DataType type) {
131   if (type_attr.attr_name.empty() || !node->attr().count(type_attr.attr_name)) {
132     return false;
133   }
134   AttrValue& attr_value = node->mutable_attr()->at(type_attr.attr_name);
135   if (type_attr.type_index == TypeAttrId::kSingleType) {
136     attr_value.set_type(type);
137   } else {
138     if (type_attr.type_index < 0 ||
139         type_attr.type_index >= attr_value.list().type_size()) {
140       return false;
141     }
142     attr_value.mutable_list()->set_type(type_attr.type_index, type);
143   }
144   return true;
145 }
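// Illustrative sketch (hypothetical node, added for exposition): assuming
// `node` has a single type attribute "T" currently set to DT_FLOAT,
//   GetDataType(node, TypeAttrId("T"));            // returns DT_FLOAT
//   SetDataType(&node, TypeAttrId("T"), DT_HALF);  // returns true, "T" -> DT_HALF
//   GetDataType(node, TypeAttrId("U"));            // returns DT_INVALID (no "U" attr)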
146 
147 std::vector<std::pair<int, int>> ArgDefIndexes(const NodeDef& node, int arg_idx,
148                                                const OpDef::ArgDef& arg_def) {
149   std::vector<std::pair<int, int>> argdef_inds;
150   if (!arg_def.type_list_attr().empty()) {
151     int num_types = node.attr().at(arg_def.type_list_attr()).list().type_size();
152     for (int type_idx = 0; type_idx < num_types; ++type_idx) {
153       argdef_inds.push_back({arg_idx, type_idx});
154     }
155   } else {
156     int num_repeat = 1;
157     if (node.attr().count(arg_def.number_attr())) {
158       num_repeat = node.attr().at(arg_def.number_attr()).i();
159     }
160     argdef_inds.insert(argdef_inds.end(), num_repeat, {arg_idx, -1});
161   }
162   return argdef_inds;
163 }
164 
165 // Returns a pair (arg_index, type_index) for each input to the node, where
166 // arg_index is the index of the input_arg in op_def and type_index is the index
167 // of the type in type_list_attr (only defined for list arguments).
168 std::vector<std::pair<int, int>> InputPortArgDefIndexes(const NodeDef& node,
169                                                         const OpDef& op_def) {
170   std::vector<std::pair<int, int>> argdef_inds;
171   argdef_inds.reserve(op_def.input_arg_size());  // Final size may differ.
172   for (int arg_idx = 0; arg_idx < op_def.input_arg_size(); ++arg_idx) {
173     const OpDef::ArgDef& arg_def = op_def.input_arg(arg_idx);
174     auto arg_results = ArgDefIndexes(node, arg_idx, arg_def);
175     argdef_inds.insert(argdef_inds.end(), arg_results.begin(),
176                        arg_results.end());
177   }
178   return argdef_inds;
179 }
180 
181 // Returns a pair (arg_index, type_index) for each output to the node, where
182 // arg_index is the index of the output_arg in op_def and type_index is the
183 // index of the type in type_list_attr (only defined for list arguments).
184 std::vector<std::pair<int, int>> OutputPortArgDefIndexes(const NodeDef& node,
185                                                          const OpDef& op_def) {
186   std::vector<std::pair<int, int>> argdef_inds;
187   argdef_inds.reserve(op_def.output_arg_size());  // Final size may differ.
188   for (int arg_idx = 0; arg_idx < op_def.output_arg_size(); ++arg_idx) {
189     const OpDef::ArgDef& arg_def = op_def.output_arg(arg_idx);
190     auto arg_results = ArgDefIndexes(node, arg_idx, arg_def);
191     argdef_inds.insert(argdef_inds.end(), arg_results.begin(),
192                        arg_results.end());
193   }
194   return argdef_inds;
195 }
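// Illustrative sketch (hypothetical op, added for exposition): for an
// input_arg declared as "values: N * T" with the node's "N" attr equal to 3,
// ArgDefIndexes returns three copies of (arg_idx, -1); for an input_arg whose
// type_list_attr currently holds two types, it returns (arg_idx, 0) and
// (arg_idx, 1).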
196 
197 TypeAttrId GetTypeAttrId(const OpDef::ArgDef& arg_def, int arg_type_index) {
198   if (!arg_def.type_list_attr().empty()) {
199     return TypeAttrId(arg_def.type_list_attr(), arg_type_index);
200   } else if (!arg_def.type_attr().empty()) {
201     return TypeAttrId(arg_def.type_attr());
202   } else {
203     return TypeAttrId(arg_def.type());
204   }
205 }
206 
207 std::vector<int> NonControlInputs(const NodeDef& node) {
208   std::vector<int> pos;
209   for (int i = 0; i < node.input_size(); i++) {
210     if (!IsControlInput(node.input(i))) {
211       pos.push_back(i);
212     }
213   }
214   return pos;
215 }
216 
217 // A utility class to look up node type attributes and type attribute <->
218 // input/output port mappings.
219 class NodeTypeAttrMap {
220  public:
221   NodeTypeAttrMap() {}
222 
223   explicit NodeTypeAttrMap(const GraphDef& graph) { TF_CHECK_OK(Init(graph)); }
224 
225   Status Init(const GraphDef& graph) {
226     if (graph_ != nullptr) {
227       return errors::InvalidArgument("NodeTypeAttrMap is already initialized.");
228     }
229     graph_ = &graph;
230     function_library_.reset(
231         new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
232     for (const NodeDef& node : graph.node()) {
233       TF_RETURN_IF_ERROR(AddNode(node));
234     }
235     return Status::OK();
236   }
237 
238   bool is_initialized() const { return graph_ != nullptr; }
239 
240   // Returns the set of all type attributes in the given node.
241   absl::flat_hash_set<TypeAttrId> GetTypeAttrs(const NodeDef& node) const {
242     DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
243     absl::flat_hash_set<TypeAttrId> type_attrs;
244     const auto iter = type2io_.find(&node);
245     CHECK(iter != type2io_.end());  // Crash Ok
246     for (const auto& key_value : iter->second) {
247       type_attrs.insert(key_value.first);
248     }
249     return type_attrs;
250   }
251 
252   const absl::flat_hash_set<int>& GetInputPorts(
253       const NodeDef& node, const TypeAttrId& type_attr) const {
254     DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
255     return type2io_.at(&node).at(type_attr).first;
256   }
257 
258   const absl::flat_hash_set<int>& GetOutputPorts(
259       const NodeDef& node, const TypeAttrId& type_attr) const {
260     DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
261     return type2io_.at(&node).at(type_attr).second;
262   }
263 
264   TypeAttrId GetInputTypeAttr(const NodeDef& node, int port) const {
265     DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
266     auto type_vec = io2type_.at(&node).first;
267     CHECK_GE(port, 0);                // Crash Ok
268     CHECK_LT(port, type_vec.size());  // Crash Ok
269     return type_vec[port];
270   }
271 
272   TypeAttrId GetOutputTypeAttr(const NodeDef& node, int port) const {
273     DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
274     auto type_vec = io2type_.at(&node).second;
275     CHECK_GE(port, 0);                // Crash Ok
276     CHECK_LT(port, type_vec.size());  // Crash Ok
277     return type_vec[port];
278   }
279 
280  private:
281   Status AddNode(const NodeDef& node) {
282     const OpDef* op_def_ptr = nullptr;
283     TF_RETURN_IF_ERROR(function_library_->LookUpOpDef(node.op(), &op_def_ptr));
284     const OpDef& op_def = *op_def_ptr;
285     auto& type2io_entry = type2io_[&node];
286     auto& io2type_entry = io2type_[&node];
287     auto input_arg_inds = InputPortArgDefIndexes(node, op_def);
288     if (NonControlInputs(node).size() != input_arg_inds.size()) {
289       return errors::InvalidArgument(
290           "Expected ", node.op(), " node ", node.name(), " to have ",
291           input_arg_inds.size(), " non-control input(s), but got ",
292           node.input_size());
293     }
294     // Note that the mappings generated here include inputs/outputs with fixed
295     // types. This makes the mappings complete (all inputs and outputs are
296     // included), and allows the graph rewriter to propagate deny paint
297     // from/through ops with fixed types.
298     io2type_entry.first.reserve(input_arg_inds.size());
299     for (int i = 0; i < static_cast<int>(input_arg_inds.size()); ++i) {
300       const auto& arg_inds = input_arg_inds[i];
301       const OpDef::ArgDef& arg_def = op_def.input_arg(arg_inds.first);
302       TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second);
303       if (!type_attr.attr_name.empty() &&
304           !node.attr().count(type_attr.attr_name)) {
305         return errors::InvalidArgument("Type attribute ", type_attr.attr_name,
306                                        " is not present in node ", node.name());
307       }
308       type2io_entry[type_attr].first.insert(i);
309       io2type_entry.first.push_back(type_attr);
310     }
311 
312     auto output_arg_inds = OutputPortArgDefIndexes(node, op_def);
313     io2type_entry.second.reserve(output_arg_inds.size());
314     for (int i = 0; i < static_cast<int>(output_arg_inds.size()); ++i) {
315       const auto& arg_inds = output_arg_inds[i];
316       const OpDef::ArgDef& arg_def = op_def.output_arg(arg_inds.first);
317       TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second);
318       if (!type_attr.attr_name.empty() &&
319           !node.attr().count(type_attr.attr_name)) {
320         return errors::InvalidArgument("Type attribute ", type_attr.attr_name,
321                                        " is not present in node ", node.name());
322       }
323       type2io_entry[type_attr].second.insert(i);
324       io2type_entry.second.push_back(type_attr);
325     }
326 
327     // Also ensure that type attributes that aren't associated with any inputs
328     // or outputs (e.g., StackV2's elem_type) are added to the map.
329     for (const auto& attr : node.attr()) {
330       const string& attr_name = attr.first;
331       if (!attr_name.empty() && attr_name[0] == '_') continue;
332       const AttrValue& attr_value = attr.second;
333       const OpDef::AttrDef* attr_def = FindAttr(attr_name, op_def);
334       if (!attr_def) {
335         return errors::InvalidArgument("AttrDef not found for attribute ",
336                                        attr_name, " of node ", node.name());
337       }
338       if (attr_def->type() == "type") {
339         type2io_entry[TypeAttrId(attr_name)];
340       } else if (attr_def->type() == "list(type)") {
341         for (int i = 0; i < attr_value.list().type_size(); ++i) {
342           type2io_entry[TypeAttrId(attr_name, i)];
343         }
344       }
345     }
346     return Status::OK();
347   }
348 
349   // WARN: `graph_` must outlive this object (node pointers must remain valid).
350   const GraphDef* graph_ = nullptr;  // do not own
351   std::unique_ptr<FunctionLibraryDefinition> function_library_;
352 
353   typedef absl::flat_hash_set<int> IntSet;
354   // Maps a type attr id -> (input port set, output port set)
355   typedef absl::flat_hash_map<TypeAttrId, std::pair<IntSet, IntSet>> Type2IOMap;
356   // Maps a node -> type attr mapping
357   absl::flat_hash_map<const NodeDef*, Type2IOMap> type2io_;
358   // Maps a port -> type attr id
359   typedef std::vector<TypeAttrId> TypeAttrIdVec;
360   // Maps a node -> (input port mapping, output port mapping)
361   absl::flat_hash_map<const NodeDef*, std::pair<TypeAttrIdVec, TypeAttrIdVec>>
362       io2type_;
363 };
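// Illustrative usage sketch (added for exposition; `graph` and `node` are
// hypothetical):
//   NodeTypeAttrMap node_type_map;
//   TF_RETURN_IF_ERROR(node_type_map.Init(graph));
//   for (const TypeAttrId& type_attr : node_type_map.GetTypeAttrs(node)) {
//     // Ports whose data type is controlled by this type attribute.
//     const auto& in_ports = node_type_map.GetInputPorts(node, type_attr);
//     const auto& out_ports = node_type_map.GetOutputPorts(node, type_attr);
//   }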
364 
365 struct NodeTypeId {
366   NodeTypeId(const NodeDef* _node, const TypeAttrId& _type_attr)
367       : node(_node), type_attr(_type_attr) {}
368 
369   const NodeDef* node;
370   TypeAttrId type_attr;
371 
372   bool operator==(const NodeTypeId& other) const {
373     return node == other.node && type_attr == other.type_attr;
374   }
375 
376   template <typename H>
377   friend H AbslHashValue(H h, const NodeTypeId& nt) {
378     return H::combine(std::move(h), nt.node, nt.type_attr);
379   }
380 };
381 
382 struct NodeTypeIdEdge {
383   NodeTypeIdEdge(const NodeTypeId& _src, const NodeTypeId& _dst)
384       : src(_src), dst(_dst) {}
385   NodeTypeId src;
386   NodeTypeId dst;
387 };
388 
389 // TODO(benbarsdell): Investigate whether the existing GraphTopologyView can be
390 // used instead of this modified version.
391 // This is just like GraphTopologyView but with (NodeDef, TypeAttrId) pairs as
392 // the vertices instead of just NodeDef.
393 // For example, if node A has output A:0 with TypeAttrId 'T', and node B has
394 // input B:0 with TypeAttrId 'U', and input B:0 connects to output A:0, there
395 // will be an edge from (A, T) to (B, U).
396 class GraphTypeTopologyView {
397  public:
398   GraphTypeTopologyView() = default;
399   explicit GraphTypeTopologyView(bool skip_invalid_edges)
400       : skip_invalid_edges_(skip_invalid_edges) {}
401 
402   // Initialize graph topology view from the graph. It's possible to pass
403   // additional edges that do not exist in a graph, but must be respected when
404   // computing graph topology. Example: the TensorFlow runtime allows concurrent
405   // execution of dequeue/enqueue ops from the same queue resource, but we might
406   // want to enforce ordering between them for the purpose of graph analysis.
407   Status InitializeFromGraph(const GraphDef& graph,
408                              const NodeTypeAttrMap& node_type_map);
409 
410   Status AddEphemeralEdges(absl::Span<const NodeTypeIdEdge> ephemeral_edges);
411 
412   bool is_initialized() const { return graph_ != nullptr; }
413   int num_nodes() const { return num_nodes_; }
414   const GraphDef* graph() const { return graph_; }
415 
416   // Returns true iff the node exists in the underlying graph.
417   bool HasNode(absl::string_view node_name, const TypeAttrId& type_attr) const;
418 
419   // Finds a node by name or returns `nullptr` if it's not in the graph.
420   const NodeTypeId* GetNode(absl::string_view node_name,
421                             const TypeAttrId& type_attr) const;
422   // Returns a node corresponding to the given node index.
423   const NodeTypeId* GetNode(int node_idx) const;
424 
425   // Returns a node index for the given node name, if the name exists in the
426   // underlying graph. Otherwise returns empty optional.
427   const absl::optional<int> GetNodeIndex(absl::string_view node_name,
428                                          const TypeAttrId& type_attr) const;
429   // Returns a node index for the given node, if the node belongs to the
430   // underlying graph. Otherwise returns empty optional.
431   const absl::optional<int> GetNodeIndex(const NodeTypeId& node) const;
432 
433   // Returns all the node indexes that are in the direct fanin of the given
434   // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
435   const absl::InlinedVector<int, 4>& GetFanin(int node_idx) const;
436   // Returns all the node indexes that are in the direct fanout of the given
437   // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
438   const absl::InlinedVector<int, 2>& GetFanout(int node_idx) const;
439 
440  private:
441   // The key type used to uniquely identify a type attribute on a node.
442   struct NodeTypeKey : public std::pair<absl::string_view, TypeAttrId> {
443     typedef std::pair<absl::string_view, TypeAttrId> Base;
444 
445     // Inherit the set of constructors.
446     using Base::pair;
447 
448     template <typename H>
449     friend H AbslHashValue(H h, const NodeTypeKey& nt) {
450       return H::combine(std::move(h), nt.first, nt.second);
451     }
452   };
453 
454   // If true, all invalid edges and inputs (src, dst, or input node not found in
455   // the graph) will be skipped; otherwise initialization will fail with an error.
456   bool skip_invalid_edges_ = false;
457 
458   // WARN: `graph_` must outlive this object and graph nodes must not be
459   // destroyed, because node names are captured with absl::string_view.
460   const GraphDef* graph_ = nullptr;  // do not own
461   int num_nodes_ = 0;
462   std::vector<NodeTypeId> node_type_attrs_;
463   absl::flat_hash_map<absl::string_view, int> node_name_to_index_;
464   absl::flat_hash_map<NodeTypeKey, int> node_type_name_to_index_;
465 
466   std::vector<absl::InlinedVector<int, 4>> fanins_;
467   std::vector<absl::InlinedVector<int, 2>> fanouts_;
468 
469   // We need a valid reference to return from GetFanin/GetFanout if the
470   // `node_idx` argument is outside of the [0, num_nodes_) range.
471   absl::InlinedVector<int, 4> empty_fanin_;
472   absl::InlinedVector<int, 2> empty_fanout_;
473 };
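// Illustrative usage sketch (added for exposition; `graph`, `node_type_map`,
// "A", and "T" are hypothetical):
//   GraphTypeTopologyView graph_type_view;
//   TF_RETURN_IF_ERROR(
//       graph_type_view.InitializeFromGraph(graph, node_type_map));
//   const absl::optional<int> idx =
//       graph_type_view.GetNodeIndex("A", TypeAttrId("T"));
//   if (idx.has_value()) {
//     for (int fanout_idx : graph_type_view.GetFanout(idx.value())) {
//       const NodeTypeId* dst = graph_type_view.GetNode(fanout_idx);
//       // dst is a (node, type attribute) vertex downstream of (A, T).
//     }
//   }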
474 
475 template <typename T>
476 inline void SortAndRemoveDuplicates(T* v) {
477   std::sort(v->begin(), v->end());
478   v->erase(std::unique(v->begin(), v->end()), v->end());
479 }
480 
481 Status GraphTypeTopologyView::InitializeFromGraph(
482     const GraphDef& graph, const NodeTypeAttrMap& node_type_map) {
483   if (graph_ != nullptr) {
484     return errors::InvalidArgument(
485         "GraphTypeTopologyView is already initialized.");
486   }
487 
488   graph_ = &graph;
489   int num_nodedefs = graph.node_size();
490   node_name_to_index_.rehash(num_nodedefs);
491 
492   // Build maps from name to index.
493   node_type_attrs_.reserve(num_nodedefs);         // Only approximate.
494   node_type_name_to_index_.rehash(num_nodedefs);  // Only approximate.
495   for (int node_idx = 0; node_idx < num_nodedefs; ++node_idx) {
496     const NodeDef& node = graph.node(node_idx);
497     node_name_to_index_.emplace(node.name(), node_idx);
498 
499     for (const TypeAttrId& type_attr : node_type_map.GetTypeAttrs(node)) {
500       int node_type_idx = node_type_attrs_.size();
501       node_type_name_to_index_.emplace(NodeTypeKey(node.name(), type_attr),
502                                        node_type_idx);
503       node_type_attrs_.emplace_back(&node, type_attr);
504     }
505   }
506   num_nodes_ = node_type_attrs_.size();
507   fanins_.resize(num_nodes_);
508   fanouts_.resize(num_nodes_);
509 
510   // Add graph edges to the adjacency lists.
511   for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
512     const NodeTypeId& node_type = node_type_attrs_.at(node_type_idx);
513     auto input_ports =
514         node_type_map.GetInputPorts(*node_type.node, node_type.type_attr);
515     fanins_[node_type_idx].reserve(input_ports.size());
516     for (int port : input_ports) {
517       const string& input = node_type.node->input(port);
518       TensorId tensor = ParseTensorName(input);
519       const auto it = node_name_to_index_.find(tensor.node());
520       const bool valid_input = it != node_name_to_index_.end();
521 
522       if (!valid_input) {
523         const string error_message = absl::StrCat(
524             "Non-existent input ", input, " in node ", node_type.node->name());
525         if (skip_invalid_edges_) {
526           VLOG(3) << "Skip error: " << error_message;
527         } else {
528           return errors::InvalidArgument(error_message);
529         }
530       }
531 
532       if (valid_input) {
533         const int input_idx = it->second;
534         const NodeDef& input_node = graph_->node(input_idx);
535         TypeAttrId input_type_attr =
536             node_type_map.GetOutputTypeAttr(input_node, tensor.index());
537         const auto it2 = node_type_name_to_index_.find(
538             NodeTypeKey(input_node.name(), input_type_attr));
539         if (it2 == node_type_name_to_index_.end()) {
540           if (!skip_invalid_edges_) {
541             return errors::InvalidArgument("Did not find type attr ",
542                                            input_type_attr.DebugString(),
543                                            " in node ", input_node.name());
544           }
545           continue;
546         }
547         int input_node_type_idx = it2->second;
548         fanins_[node_type_idx].push_back(input_node_type_idx);
549         fanouts_[input_node_type_idx].push_back(node_type_idx);
550       }
551     }
552 
553     // Dedup the input list while it's still hot in cache.
554     SortAndRemoveDuplicates(&fanins_[node_type_idx]);
555   }
556 
557   // Dedup outputs for all the graph nodes.
558   for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
559     SortAndRemoveDuplicates(&fanouts_[node_type_idx]);
560   }
561 
562   return Status::OK();
563 }
564 
565 Status GraphTypeTopologyView::AddEphemeralEdges(
566     absl::Span<const NodeTypeIdEdge> ephemeral_edges) {
567   // Add ephemeral edges to the adjacency lists.
568   for (const NodeTypeIdEdge& edge : ephemeral_edges) {
569     const auto src = node_name_to_index_.find(edge.src.node->name());
570     const bool valid_src = src != node_name_to_index_.end();
571 
572     if (!valid_src) {
573       const string error_message =
574           absl::StrCat("Non-existent src node: ", edge.src.node->name());
575       if (skip_invalid_edges_) {
576         VLOG(0) << "Skip error: " << error_message;
577       } else {
578         return errors::InvalidArgument(error_message);
579       }
580     }
581 
582     const auto dst = node_name_to_index_.find(edge.dst.node->name());
583     const bool valid_dst = dst != node_name_to_index_.end();
584 
585     if (!valid_dst) {
586       const string error_message =
587           absl::StrCat("Non-existent dst node: ", edge.dst.node->name());
588       if (skip_invalid_edges_) {
589         VLOG(0) << "Skip error: " << error_message;
590       } else {
591         return errors::InvalidArgument(error_message);
592       }
593     }
594 
595     if (valid_dst && valid_src) {
596       // TODO(benbarsdell): Check for failure.
597       int src_node_type_idx = node_type_name_to_index_.at(
598           NodeTypeKey(edge.src.node->name(), edge.src.type_attr));
599       int dst_node_type_idx = node_type_name_to_index_.at(
600           NodeTypeKey(edge.dst.node->name(), edge.dst.type_attr));
601       fanins_[dst_node_type_idx].push_back(src_node_type_idx);
602       fanouts_[src_node_type_idx].push_back(dst_node_type_idx);
603     }
604   }
605 
606   // Dedup inputs and outputs for all the graph nodes.
607   for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
608     SortAndRemoveDuplicates(&fanins_[node_type_idx]);
609     SortAndRemoveDuplicates(&fanouts_[node_type_idx]);
610   }
611 
612   return Status::OK();
613 }
614 
615 bool GraphTypeTopologyView::HasNode(absl::string_view node_name,
616                                     const TypeAttrId& type_attr) const {
617   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
618   NodeTypeKey key(node_name, type_attr);
619   const auto it = node_type_name_to_index_.find(key);
620   return it != node_type_name_to_index_.end();
621 }
622 
623 const NodeTypeId* GraphTypeTopologyView::GetNode(
624     absl::string_view node_name, const TypeAttrId& type_attr) const {
625   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
626   NodeTypeKey key(node_name, type_attr);
627   const auto it = node_type_name_to_index_.find(key);
628   return it == node_type_name_to_index_.end()
629              ? nullptr
630              : &node_type_attrs_.at(it->second);
631 }
632 
633 const NodeTypeId* GraphTypeTopologyView::GetNode(int node_idx) const {
634   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
635   DCHECK(node_idx >= 0 && node_idx < num_nodes_) << "node_idx is out of range";
636   return &node_type_attrs_.at(node_idx);
637 }
638 
639 const absl::optional<int> GraphTypeTopologyView::GetNodeIndex(
640     absl::string_view node_name, const TypeAttrId& type_attr) const {
641   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
642   NodeTypeKey key(node_name, type_attr);
643   const auto it = node_type_name_to_index_.find(key);
644   DCHECK(it != node_type_name_to_index_.end())
645       << "Node doesn't exist in a graph";
646   return it == node_type_name_to_index_.end() ? absl::nullopt
647                                               : absl::make_optional(it->second);
648 }
649 
650 const absl::optional<int> GraphTypeTopologyView::GetNodeIndex(
651     const NodeTypeId& node) const {
652   return GetNodeIndex(node.node->name(), node.type_attr);
653 }
654 
655 const absl::InlinedVector<int, 4>& GraphTypeTopologyView::GetFanin(
656     int node_idx) const {
657   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
658   const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
659   DCHECK(is_valid_node_idx) << "node_idx is out of range";
660   return is_valid_node_idx ? fanins_[node_idx] : empty_fanin_;
661 }
662 
663 const absl::InlinedVector<int, 2>& GraphTypeTopologyView::GetFanout(
664     int node_idx) const {
665   DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
666   const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
667   DCHECK(is_valid_node_idx) << "node_idx is out of range";
668   return is_valid_node_idx ? fanouts_[node_idx] : empty_fanout_;
669 }
670 
671 enum class TypeTraversalDirection {
672   kFollowInputs,
673   kFollowOutputs,
674   kFollowInputsAndOutputs,
675 };
676 
677 // Encapsulate DFS callbacks that will be called during the graph traversal.
678 //
679 // If non-empty, the `pre_order` and `post_order` functors will be called on
680 // each reachable node (including the `from` nodes) in pre and post order. If
681 // loops are found, the `on_back_edge` functor will be called on the
682 // corresponding back edges. Moreover, the pre and post order will assume that
683 // these back edges will be cut.
684 struct DfsTypeCallbacks {
685   DfsTypeCallbacks() = default;
686   DfsTypeCallbacks(std::function<void(int)> pre, std::function<void(int)> post,
687                    std::function<void(int, int)> back_edge)
688       : pre_order(std::move(pre)),
689         post_order(std::move(post)),
690         on_back_edge(std::move(back_edge)) {}
691 
692   static DfsTypeCallbacks PreOrder(std::function<void(int)> pre) {
693     return DfsTypeCallbacks(std::move(pre), nullptr, nullptr);
694   }
695 
696   static DfsTypeCallbacks PostOrder(std::function<void(int)> post) {
697     return DfsTypeCallbacks(nullptr, std::move(post), nullptr);
698   }
699 
700   std::function<void(int)> pre_order;
701   std::function<void(int)> post_order;
702   std::function<void(int, int)> on_back_edge;
703 };
704 
705 // Encapsulate DFS predicates for traversing the graph.
706 //
707 // The `enter` predicate decides if traversal should enter the node, and the
708 // `advance` predicate decides if the traversal should follow inputs/outputs
709 // from the node.
710 //
711 // If predicates are empty (default initialized), it's assumed that we can enter
712 // into any node and advance from any node respectively.
713 struct DfsTypePredicates {
714   DfsTypePredicates() = default;
715   DfsTypePredicates(std::function<bool(int)> enter,
716                     std::function<bool(int)> advance)
717       : enter(std::move(enter)), advance(std::move(advance)) {}
718 
719   static DfsTypePredicates Enter(std::function<bool(int)> enter) {
720     return DfsTypePredicates(std::move(enter), nullptr);
721   }
722 
723   static DfsTypePredicates Advance(std::function<bool(int)> advance) {
724     return DfsTypePredicates(nullptr, std::move(advance));
725   }
726 
727   std::function<bool(int)> enter;
728   std::function<bool(int)> advance;
729 };
730 
731 struct DfsStackElem {
732   DfsStackElem(int node, bool children_visited, int src)
733       : node(node), children_visited(children_visited), src(src) {}
734   explicit DfsStackElem(int node) : DfsStackElem(node, false, -1) {}
735 
736   // Index of the node in the graph ∊ [0, num_nodes).
737   int node;
738   // True if all of the node's input/output nodes have already been visited
739   // (i.e., pushed onto the stack).
740   bool children_visited;
741   // Index of the node in the graph, from which we entered the `node`.
742   int src;
743 };
744 
745 enum class NodeState { kNotVisited, kVisiting, kDone };
746 
747 void DfsTypeTraversal(const GraphTypeTopologyView& graph_type_view,
748                       const absl::Span<const NodeTypeId* const> from,
749                       const TypeTraversalDirection direction,
750                       const DfsTypePredicates& predicates,
751                       const DfsTypeCallbacks& callbacks) {
752   std::vector<DfsStackElem> stack;
753   stack.reserve(from.size());
754 
755   for (const NodeTypeId* node : from) {
756     const absl::optional<int> node_idx = graph_type_view.GetNodeIndex(*node);
757     DCHECK(node_idx.has_value())
758         << "Illegal start node: " << node->node->name();
759     if (node_idx.has_value()) {
760       stack.emplace_back(node_idx.value());
761     }
762   }
763 
764   absl::flat_hash_map<int, NodeState> node_state;
765   while (!stack.empty()) {
766     DfsStackElem w = stack.back();
767     stack.pop_back();
768 
769     NodeState& state = node_state[w.node];
770     if (state == NodeState::kDone) continue;
771 
772     // Skip nodes that we should not enter.
773     if (predicates.enter && !predicates.enter(w.node)) {
774       state = NodeState::kDone;
775       continue;
776     }
777 
778     // We've processed all the children of this node.
779     if (w.children_visited) {
780       state = NodeState::kDone;
781       if (callbacks.post_order) {
782         callbacks.post_order(w.node);
783       }
784       continue;
785     }
786 
787     // Loop detected.
788     if (state == NodeState::kVisiting) {
789       if (callbacks.on_back_edge) {
790         callbacks.on_back_edge(w.src, w.node);
791       }
792       continue;
793     }
794 
795     state = NodeState::kVisiting;
796     if (callbacks.pre_order) {
797       callbacks.pre_order(w.node);
798     }
799 
800     // Enqueue the node again with the children_visited flag set to true.
801     stack.emplace_back(w.node, true, w.src);
802 
803     // Check if we can continue traversal from the current node.
804     if (predicates.advance && !predicates.advance(w.node)) {
805       continue;
806     }
807 
808     // Now enqueue the fanin/fanout nodes.
809     if (direction == TypeTraversalDirection::kFollowInputs ||
810         direction == TypeTraversalDirection::kFollowInputsAndOutputs) {
811       for (const int fanin : graph_type_view.GetFanin(w.node)) {
812         stack.emplace_back(fanin, false, w.node);
813       }
814     }
815     if (direction == TypeTraversalDirection::kFollowOutputs ||
816         direction == TypeTraversalDirection::kFollowInputsAndOutputs) {
817       for (const int fanout : graph_type_view.GetFanout(w.node)) {
818         stack.emplace_back(fanout, false, w.node);
819       }
820     }
821   }
822 }
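// Illustrative usage sketch (added for exposition; `roots` and `deny_set` are
// hypothetical): collect every node/type vertex reachable downstream of
// `roots` without entering vertices already painted deny.
//   absl::flat_hash_set<int> reachable;
//   DfsTypeTraversal(graph_type_view, roots,
//                    TypeTraversalDirection::kFollowOutputs,
//                    DfsTypePredicates::Enter(
//                        [&](int idx) { return !deny_set.count(idx); }),
//                    DfsTypeCallbacks::PreOrder(
//                        [&](int idx) { reachable.insert(idx); }));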
823 
824 DataTypeSet AllowedDataTypes(const OpDef::AttrDef& attr_def) {
825   const auto& allowed_types = attr_def.allowed_values().list().type();
826   if (allowed_types.empty()) {
827     return AllTypes();
828   }
829   uint32 dtype_mask = 0;
830   for (int dtype : allowed_types) {
831     dtype_mask |= 1u << dtype;
832   }
833   return DataTypeSet(dtype_mask);
834 }
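// Illustrative sketch (added for exposition): if attr_def's allowed_values
// lists {DT_FLOAT, DT_HALF}, the returned set's bitmask is
// (1u << DT_FLOAT) | (1u << DT_HALF); an empty allowed_values list means the
// attribute accepts every type.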
835 
836 DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) {
837   if (t_attr_id.attr_name.empty()) {
838     return ToSet(t_attr_id.fixed_type);
839   }
840   const OpDef::AttrDef* attr_def = FindAttr(t_attr_id.attr_name, op_def);
841   CHECK(attr_def);  // Crash Ok
842   return AllowedDataTypes(*attr_def);
843 }
844 
845 Status ValidateLists(const gtl::FlatSet<string>& allow_list,
846                      const gtl::FlatSet<string>& deny_list,
847                      const gtl::FlatSet<string>& infer_list,
848                      const gtl::FlatSet<string>& clear_list) {
849   std::vector<gtl::FlatSet<string>> lists{allow_list, deny_list, infer_list,
850                                           clear_list};
851   std::multiset<string> counts;
852   for (const auto& list : lists) {
853     counts.insert(list.begin(), list.end());
854   }
855   bool duplicates = false;
856   for (const auto& s : counts) {
857     if (counts.count(s) > 1) {
858       duplicates = true;
859       LOG(ERROR) << "Op present in multiple lists: " << s;
860     }
861   }
862   if (duplicates) {
863     return errors::InvalidArgument("Op lists have conflicting entries");
864   } else {
865     return Status::OK();
866   }
867 }
868 
869 bool HasInputOrOutputRefs(const NodeDef& node) {
870   const OpDef* op_def;
871   Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
872   if (!status.ok()) {
873     return true;
874   }
875   for (const auto& input : op_def->input_arg()) {
876     if (input.is_ref()) {
877       return true;
878     }
879   }
880   for (const auto& output : op_def->output_arg()) {
881     if (output.is_ref()) {
882       return true;
883     }
884   }
885   return false;
886 }
887 
888 // See TF issue 25977 for no-FP16 on SCEWL
889 bool CanForceFP16(const NodeDef& node) {
890   return node.op() != "Const" && node.op() != "SoftmaxCrossEntropyWithLogits" &&
891          !IsStateful(node) && !HasInputOrOutputRefs(node);
892 }
893 
894 int GetCudaVersion(const Cluster& cluster) {
895   auto devices = cluster.GetDevices();
896   for (const auto& device : devices) {
897     const DeviceProperties& device_properties = device.second;
898     if (device_properties.type() == "GPU") {
899       const auto& device_env = device_properties.environment();
900       auto it = device_env.find("cuda");
901       if (it != device_env.end()) {
902         string cuda_version_str = it->second;
903         return std::stoi(cuda_version_str);
904       }
905     }
906   }
907   return 0;
908 }
909 
910 int GetCudnnVersion(const Cluster& cluster) {
911   auto devices = cluster.GetDevices();
912   for (const auto& device : devices) {
913     const DeviceProperties& device_properties = device.second;
914     if (device_properties.type() == "GPU") {
915       const auto& device_env = device_properties.environment();
916       auto it = device_env.find("cudnn");
917       if (it != device_env.end()) {
918         string cudnn_version_str = it->second;
919         return std::stoi(cudnn_version_str);
920       }
921     }
922   }
923   return 0;
924 }
925 
926 class AutoMixedPrecisionImpl {
927  public:
928   AutoMixedPrecisionImpl(Cluster* cluster,
929                          const std::unordered_set<string>& nodes_to_preserve,
930                          GraphDef* graph, string id,
931                          AutoMixedPrecisionMode mode)
932       : virtual_placer_(cluster->GetDevices()),
933         nodes_to_preserve_(nodes_to_preserve),
934         graph_(graph),
935         function_library_(OpRegistry::Global(), graph->library()),
936         id_(id),
937         graph_view_(graph),
938         cuda_version_(GetCudaVersion(*cluster)),
939         cudnn_version_(GetCudnnVersion(*cluster)),
940         mode_(mode),
941         target_dtype_(mode_ == AutoMixedPrecisionMode::CUDA ? DT_HALF
942                                                             : DT_BFLOAT16) {}
943 
944   Status Optimize();
945 
946  private:
947   typedef absl::flat_hash_set<NodeTypeId> NodeTypeIdSet;
948 
949   std::unique_ptr<AutoMixedPrecisionLists> get_mixed_precision_lists() const {
950     switch (mode_) {
951       case AutoMixedPrecisionMode::CUDA:
952         return std::make_unique<AutoMixedPrecisionListsCuda>(cuda_version_,
953                                                              cudnn_version_);
954       case AutoMixedPrecisionMode::MKL:
955         return std::make_unique<AutoMixedPrecisionListsMkl>();
956     }
957   }
958   Status PrintDebugLogs(bool preop, size_t timestamp);
959   void LogSkippedNode(const NodeDef& node) const;
960   bool MustPreserve(const NodeDef& node) const;
961   bool IsOnDevice(const NodeDef& node, const string& device_type) const;
962   bool IsOnSuitableGPUArch(const NodeDef& node) const;
963   bool ShouldProcess(const NodeDef& node) const;
964   bool NodeHasF16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const;
965   bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const;
966   void ConvertBatchNormOpsToV2();
967   bool SupportsF16(const NodeTypeId& node_type) const;
968   bool SupportsF16DataType(const NodeTypeId& node_type) const;
969   const NodeTypeId* GetTensorListFloat32NodeTypeId(const NodeDef& node) const;
970   bool IsSourceOrSinkOp(const string& op) const;
971   void FindFloat32TensorListOpClustersAndDenylistUnsafe(
972       std::vector<absl::flat_hash_set<const NodeDef*>>* clusters,
973       absl::flat_hash_set<int>* deny_set) const;
974   void FindTensorListImplicitFloat32Edges(
975       const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
976       std::vector<NodeTypeIdEdge>* implicit_data_edges) const;
977   void AddAllowlistOps(absl::flat_hash_set<int>* allow_set) const;
978   void RemoveAllowsetWithFp32(absl::flat_hash_set<int>* allow_set) const;
979   void PropagateDenyFwdThroughClearAndInfer(
980       absl::flat_hash_set<int>* deny_set) const;
981   void ForceColorMatchBetweenTensorListOps(
982       const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
983       absl::flat_hash_set<int>* allow_set,
984       absl::flat_hash_set<int>* deny_set) const;
985   void AddClearAndInferToAllowIfBetweenAllow(
986       const absl::flat_hash_set<int>& deny_set,
987       absl::flat_hash_set<int>* allow_set) const;
988   void PropagateAllowThroughClear(const absl::flat_hash_set<int>& deny_set,
989                                   absl::flat_hash_set<int>* allow_set) const;
990   Status ForceColorMatchOnRecurrentEdges(
991       absl::flat_hash_set<int>* allow_set) const;
992   void MakeCastsAllowIfAllOutputsAllow(
993       absl::flat_hash_set<int>* allow_set) const;
994   NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, bool to_f16,
995                         const string& device) const;
996   Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& allow_set);
997 
998   VirtualPlacer virtual_placer_;
999   std::unordered_set<string> nodes_to_preserve_;
1000   GraphDef* graph_;
1001   FunctionLibraryDefinition function_library_;
1002   string id_;
1003   MutableGraphView graph_view_;
1004   int cuda_version_;
1005   int cudnn_version_;
1006   NodeTypeAttrMap node_type_map_;
1007   GraphTypeTopologyView graph_type_view_;
1008   bool force_all_fp16_;
1009   AutoMixedPrecisionMode mode_;
1010   gtl::FlatSet<string> f16_allowlist_;
1011   gtl::FlatSet<string> f16_denylist_;
1012   gtl::FlatSet<string> f16_inferlist_;
1013   gtl::FlatSet<string> f16_clearlist_;
1014   absl::flat_hash_set<const NodeDef*> should_process_nodes_;
1015   DataType target_dtype_;  // Either DT_HALF or DT_BFLOAT16
1016 };
1017 
1018 NodeDef AutoMixedPrecisionImpl::BuildCastNode(
1019     const MutableGraphView::OutputPort& src, bool to_f16,
1020     const string& device) const {
1021   DataType src_type = to_f16 ? DT_FLOAT : target_dtype_;
1022   DataType dst_type = to_f16 ? target_dtype_ : DT_FLOAT;
1023   const char* cast_string =
1024       !to_f16 ? kCastToFp32
1025               : target_dtype_ == DT_HALF ? kCastToFp16 : kCastToBf16;
1026   string name = strings::StrCat(src.node->name(), "-", src.port_id, "-",
1027                                 cast_string, "-", kSuffix);
1028   NodeDef node;
1029   node.set_name(name);
1030   node.set_op("Cast");
1031   node.set_device(device);
1032   node.add_input(strings::StrCat(src.node->name(), ":", src.port_id));
1033   (*node.mutable_attr())["SrcT"].set_type(src_type);
1034   (*node.mutable_attr())["DstT"].set_type(dst_type);
1035   (*node.mutable_attr())["Truncate"].set_b(false);
1036   return node;
1037 }
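// Illustrative sketch (hypothetical source node, added for exposition): for a
// src of conv1:0, to_f16 = true, and target_dtype_ == DT_HALF, the generated
// node is named "conv1-0-CastToFp16-AutoMixedPrecision", has op "Cast", input
// "conv1:0", SrcT = DT_FLOAT, and DstT = DT_HALF.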
1038 
1039 bool AutoMixedPrecisionImpl::NodeHasF16KernelForTypeAttr(
1040     const NodeDef& node, TypeAttrId taid) const {
1041   NodeDef node_copy(node);
1042   if (node.device().empty()) {
1043     string device_name = virtual_placer_.get_canonical_device_name(node);
1044     node_copy.set_device(device_name);
1045   }
1046   if (!SetDataType(&node_copy, taid, target_dtype_)) {
1047     return false;
1048   }
1049   return IsKernelRegisteredForNode(node_copy).ok();
1050 }
1051 
1052 Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) {
1053   string prepend_path;
1054   TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
1055       "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LOG_PATH", "", &prepend_path));
1056   if (prepend_path.empty()) return Status::OK();
1057 
1058   string suffix =
1059       strings::StrCat("_", preop ? "preop" : kSuffix, "_", id_, "_", timestamp);
1060 
1061   string fname =
1062       io::JoinPath(prepend_path, strings::StrCat("graphdef", suffix, ".pb"));
1063   std::fstream f;
1064   f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
1065   f << graph_->SerializeAsString();
1066   f.close();
1067   LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization")
1068             << " graph as binary to " << fname;
1069 
1070   fname = io::JoinPath(prepend_path,
1071                        strings::StrCat("graphdef", suffix, ".pb.txt"));
1072   f.open(fname.c_str(), std::fstream::out);
1073   f << graph_->DebugString();
1074   f.close();
1075   LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization")
1076             << " graph as text to " << fname;
1077 
1078   if (!preop) {
1079     fname = io::JoinPath(prepend_path,
1080                          strings::StrCat("paintbuckets", suffix, ".txt"));
1081     f.open(fname.c_str(), std::fstream::out);
1082     std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
1083         get_mixed_precision_lists();
1084     f << "AllowList:\n";
1085     for (const auto& x : mp_lists->AllowList()) {
1086       f << x << "\n";
1087     }
1088     f << "\nDenyList:\n";
1089     for (const auto& x : mp_lists->DenyList()) {
1090       f << x << "\n";
1091     }
1092     f << "\nInferList:\n";
1093     for (const auto& x : mp_lists->InferList()) {
1094       f << x << "\n";
1095     }
1096     f << "\nClearList:\n";
1097     for (const auto& x : mp_lists->ClearList()) {
1098       f << x << "\n";
1099     }
1100     f.close();
1101     LOG(INFO) << "Saved paint bucket info to " << fname;
1102   }
1103   return Status::OK();
1104 }
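// Illustrative sketch (hypothetical values, added for exposition): with
// TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LOG_PATH=/tmp/amp and id_ == "1", the
// pre-optimization call writes /tmp/amp/graphdef_preop_1_<timestamp>.pb and
// .pb.txt, and the post-optimization call writes
// graphdef_AutoMixedPrecision_1_<timestamp>.pb, .pb.txt, and
// paintbuckets_AutoMixedPrecision_1_<timestamp>.txt.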
1105 
1106 void AutoMixedPrecisionImpl::LogSkippedNode(const NodeDef& node) const {
1107   VLOG(2) << "Skipping " << node.op() << " node " << node.name()
1108           << " because it "
1109           << (MustPreserve(node)
1110                   ? "must be preserved"
1111                   : "is not on the GPU, or the GPU arch is not suitable");
1112 }
1113 
1114 bool AutoMixedPrecisionImpl::MustPreserve(const NodeDef& node) const {
1115   return nodes_to_preserve_.count(node.name());
1116 }
1117 
1118 bool AutoMixedPrecisionImpl::IsOnDevice(const NodeDef& node,
1119                                         const string& device_type) const {
1120   string device_name;
1121   if (node.device().empty()) {
1122     device_name = virtual_placer_.get_canonical_device_name(node);
1123   } else {
1124     device_name = node.device();
1125   }
1126   string device;
1127   string not_used;
1128   if (DeviceNameUtils::SplitDeviceName(device_name, &not_used, &device) &&
1129       absl::StrContains(absl::AsciiStrToLower(device),
1130                         absl::AsciiStrToLower(device_type))) {
1131     return true;
1132   }
1133   return false;
1134 }
1135 
1136 // Returns the GPU architecture (compute capability) as a (major, minor) pair.
1137 std::pair<int, int> GetDeviceGPUArch(
1138     const DeviceProperties& device_properties) {
1139   if (device_properties.type() != "GPU") return {0, 0};
1140   string arch_str = device_properties.environment().at("architecture");
1141   std::vector<string> split_arch_str = str_util::Split(arch_str, '.');
1142   if (split_arch_str.empty()) {
1143     return {0, 0};
1144   }
1145 
1146   int major, minor;
1147   if (!strings::safe_strto32(split_arch_str[0], &major)) {
1148     return {0, 0};
1149   }
1150 
1151   if (split_arch_str.size() > 1) {
1152     if (strings::safe_strto32(split_arch_str[1], &minor)) {
1153       return {major, minor};
1154     } else {
1155       return {0, 0};
1156     }
1157   } else {
1158     return {major, 0};
1159   }
1160 }
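// Illustrative sketch (added for exposition): an "architecture" value of "7.0"
// yields {7, 0}, "6.1" yields {6, 1}, and "8" yields {8, 0}; a non-GPU device
// or an unparsable major version yields {0, 0}.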
1161 
1162 bool AutoMixedPrecisionImpl::IsOnSuitableGPUArch(const NodeDef& node) const {
1163   return GetDeviceGPUArch(virtual_placer_.get_device(node)) >= kMinGPUArch;
1164 }
1165 
1166 bool AutoMixedPrecisionImpl::ShouldProcess(const NodeDef& node) const {
1167   return should_process_nodes_.count(&node);
1168 }
1169 
1170 bool IsFloat32(const NodeTypeId& node_type) {
1171   return GetDataType(*node_type.node, node_type.type_attr) ==
1172          DataType::DT_FLOAT;
1173 }
1174 
1175 bool IsTensorListOp(const string& op) {
1176   return op.find("TensorList") != string::npos;
1177 }
1178 
1179 bool IsTensorListReaderOp(const string& op) {
1180   static const gtl::FlatSet<string> tensor_list_reader_ops = {
1181       "TensorListConcat",  "TensorListConcatV2", "TensorListGather",
1182       "TensorListGetItem", "TensorListPopBack",  "TensorListStack"};
1183   return tensor_list_reader_ops.count(op);
1184 }
1185 
1186 bool IsTensorListWriterOp(const string& op) {
1187   static const gtl::FlatSet<string> tensor_list_writer_ops = {
1188       "TensorListFromTensor",    "TensorListPushBack",
1189       "TensorListPushBackBatch", "TensorListScatter",
1190       "TensorListScatterV2",     "TensorListScatterIntoExistingList",
1191       "TensorListSetItem",       "TensorListSplit"};
1192   return tensor_list_writer_ops.count(op);
1193 }
1194 
1195 bool AutoMixedPrecisionImpl::SupportsF16(const NodeTypeId& node_type) const {
1196   const OpDef* op_def;
1197   Status status =
1198       OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
1199   if (!status.ok()) return false;
1200   return AllowedDataTypes(*op_def, node_type.type_attr)
1201              .Contains(target_dtype_) &&
1202          NodeHasF16KernelForTypeAttr(*node_type.node, node_type.type_attr);
1203 }
1204 
1205 bool AutoMixedPrecisionImpl::SupportsF16DataType(
1206     const NodeTypeId& node_type) const {
1207   const OpDef* op_def;
1208   Status status =
1209       OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
1210   if (!status.ok()) return false;
1211   return AllowedDataTypes(*op_def, node_type.type_attr).Contains(target_dtype_);
1212 }
1213 
1214 // TODO(mconley): Make this change the node's name (to aid debugging). Need to
1215 // make sure that doing this won't break anything.
1216 void AutoMixedPrecisionImpl::ConvertBatchNormOpsToV2() {
1217   for (int node_idx = 0; node_idx < graph_->node_size(); ++node_idx) {
1218     NodeDef* node = graph_->mutable_node(node_idx);
1219     if (!ShouldProcess(*node)) continue;
1220     bool changed = false;
1221     if (node->op() == "FusedBatchNorm") {
1222       VLOG(2) << "Changing op of " << node->op() << " node " << node->name()
1223               << " to FusedBatchNormV2";
1224       node->set_op("FusedBatchNormV2");
1225       changed = true;
1226     } else if (node->op() == "FusedBatchNormGrad") {
1227       VLOG(2) << "Changing op of " << node->op() << " node " << node->name()
1228               << " to FusedBatchNormGradV2";
1229       node->set_op("FusedBatchNormGradV2");
1230       changed = true;
1231     }
1232     if (changed) {
1233       (*node->mutable_attr())["U"].set_type(DT_FLOAT);
1234     }
1235   }
1236 }
1237 
1238 // A helper function to decide whether to ignore the effect on performance when
1239 // rewriting the graph. This can be useful for testing the numerical effects of
1240 // reduced precision on systems that have poor mixed precision performance.
1241 bool ShouldIgnorePerformance() {
1242   static bool is_enabled = [] {
1243     bool ret = false;
1244     TF_CHECK_OK(ReadBoolFromEnvVar(
1245         "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE",
1246         /*default_val=*/false, &ret));
1247     return ret;
1248   }();
1249   return is_enabled;
1250 }
1251 
1252 Status AutoMixedPrecisionImpl::Optimize() {
1253   string optimization_level;
1254   TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
1255       "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level));
1256   optimization_level = absl::AsciiStrToUpper(optimization_level);
1257   force_all_fp16_ = optimization_level == "UNSAFE_FORCE_ALL";
1258   if (force_all_fp16_ && mode_ == AutoMixedPrecisionMode::MKL) {
1259     // Many ops do not support bfloat16 on the CPU, so we disallow forcing
1260     // everything to bfloat16.
1261     return errors::InvalidArgument(
1262         "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL cannot be set to "
1263         "UNSAFE_FORCE_ALL when MKL is used");
1264   }
1265 
1266   std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
1267       get_mixed_precision_lists();
1268   f16_allowlist_ = mp_lists->AllowList();
1269   f16_denylist_ = mp_lists->DenyList();
1270   f16_inferlist_ = mp_lists->InferList();
1271   f16_clearlist_ = mp_lists->ClearList();
1272   TF_RETURN_IF_ERROR(ValidateLists(f16_allowlist_, f16_denylist_,
1273                                    f16_inferlist_, f16_clearlist_));
1274 
1275   size_t timestamp = Env::Default()->NowMicros() / 1000;
1276   TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp));
1277 
1278   VLOG(2) << "Identifying nodes that should be processed";
1279   for (const NodeDef& node : graph_->node()) {
1280     bool should_process;
1281     switch (mode_) {
1282       case AutoMixedPrecisionMode::CUDA:
1283         should_process =
1284             !MustPreserve(node) && IsOnDevice(node, DEVICE_GPU) &&
1285             (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node));
1286         break;
1287       case AutoMixedPrecisionMode::MKL:
1288         should_process = !MustPreserve(node) && IsOnDevice(node, DEVICE_CPU);
1289         break;
1290     }
1291     if (should_process) {
1292       should_process_nodes_.insert(&node);
1293     } else {
1294       LogSkippedNode(node);
1295     }
1296   }
1297 
1298   VLOG(2) << "Converting FusedBatchNorm* ops to V2";
1299   ConvertBatchNormOpsToV2();
1300 
1301   VLOG(2) << "Building node type map for graph";
1302   TF_RETURN_IF_ERROR(node_type_map_.Init(*graph_));
1303 
1304   VLOG(2) << "Constructing graph type attribute topology view";
1305   TF_RETURN_IF_ERROR(
1306       graph_type_view_.InitializeFromGraph(*graph_, node_type_map_));
1307 
1308   absl::flat_hash_set<int> deny_set;
1309 
1310   std::vector<absl::flat_hash_set<const NodeDef*>> tensor_list_clusters;
1311   FindFloat32TensorListOpClustersAndDenylistUnsafe(&tensor_list_clusters,
1312                                                    &deny_set);
1313   std::vector<NodeTypeIdEdge> ephemeral_edges;
1314   for (const auto& cluster : tensor_list_clusters) {
1315     VLOG(1) << "Found safe Tensor List cluster of size " << cluster.size();
1316     for (const NodeDef* node : cluster) {
1317       VLOG(2) << "  Cluster member: " << node->op() << " node " << node->name();
1318     }
1319     FindTensorListImplicitFloat32Edges(cluster, &ephemeral_edges);
1320   }
1321   TF_RETURN_IF_ERROR(graph_type_view_.AddEphemeralEdges(ephemeral_edges));
1322 
1323   // The goal here is to change performance-critical ops to fp16 or bf16, and to
1324   // do so with the minimal number of casts, subject to the constraint that the
1325   // model's convergence is not affected. This is achieved by first identifying
1326   // which nodes should be changed to f16 and then inserting casts at the
1327   // boundaries between f16/non-f16 nodes.
1328 
1329   // The algorithm for deciding which nodes to change to f16 is as follows:
1330   // 1) Add all performance-critical ops (aka "allowlist" ops) to the allow_set.
1331   //    This is done under the assumption that allowlist ops are always
1332   //    numerically-safe in f16 and that they are the most important ops for
1333   //    improving performance.
1334   // 2) Add nodes to the deny_set iff they are numerically-dangerous (aka
1335   //    "denylist" ops) or they are on a forward path from a denylist node to
1336   //    a deny/infer node (including the node at the end of the path) through
1337   //    non-numerically-dangerous ops (aka "inferlist" and "clearlist" ops).
1338   //    This is done to prevent numerically-dangerous ops and their downstream
1339   //    effects from being changed to f16, which would risk breaking the
1340   //    numerical accuracy of the model.
1341   // 3) For all remaining nodes that are not considered dangerous (inferlist
1342   //    and clearlist ops), find those that are between (i.e., both upstream
1343   //    and downstream of) allow nodes, and add them to the allow_set.
1344   //    This is done to avoid unnecessary casts between allowlist ops.
1345   // 4) For all remaining clearlist nodes, add them to the allow_set if they are
1346   //    connected to a node in the allow_set via other clearlist nodes.
1347   //    This is done to increase the number of ops in the allow_set without
1348   //    affecting numerical stability.
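  //
  // As an illustrative (hypothetical) example, consider a chain
  //   Conv2D(allow) -> Relu(clear) -> Exp(deny) -> Mul(infer) -> MatMul(allow)
  // where the list memberships shown in parentheses are assumed for the sake
  // of the example. Pass 1 adds Conv2D and MatMul to the allow_set; pass 2
  // paints Exp DENY and propagates DENY forward onto Mul; pass 4 then adds
  // Relu to the allow_set because it is a clearlist op adjacent to Conv2D and
  // not denied. Casts are later inserted only at the resulting boundaries,
  // i.e. on the Relu->Exp and Mul->MatMul edges.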
1349 
1350   absl::flat_hash_set<int> allow_set;
1351   VLOG(2) << "Beginning pass 1 to add allowlist ops";
1352   AddAllowlistOps(&allow_set);
1353   VLOG(2) << "Finished pass 1";
1354 
1355   if (allow_set.empty()) {
1356     LOG(INFO) << "No allowlist ops found, nothing to do";
1357     return Status::OK();
1358   }
1359 
1360   VLOG(2) << "Beginning pass 2 to propagate deny forwards from denylist ops "
1361              "through clear/inferlist ops";
1362   PropagateDenyFwdThroughClearAndInfer(&deny_set);
1363   VLOG(2) << "Finished pass 2";
1364 
1365   VLOG(2) << "Forcing color match between data structure ops";
1366   for (const auto& cluster : tensor_list_clusters) {
1367     ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
1368   }
1369 
1370   VLOG(2) << "Beginning pass 3 to set clear and infer nodes to allow if they "
1371              "are between allow ops";
1372   AddClearAndInferToAllowIfBetweenAllow(deny_set, &allow_set);
1373   VLOG(2) << "Finished pass 3";
1374 
1375   VLOG(2) << "Beginning pass 4 to propagate allow from allow nodes through "
1376              "clearlist ops";
1377   PropagateAllowThroughClear(deny_set, &allow_set);
1378   VLOG(2) << "Finished pass 4";
1379 
1380   VLOG(2) << "Beginning pass 5 to remove some nodes which could not be changed "
1381              "to F16 "
1382              "from the allow set";
1383   RemoveAllowsetWithFp32(&allow_set);
1384   VLOG(2) << "Finished pass 5";
1385 
1386   VLOG(2) << "Forcing color match between data structure ops";
1387   for (const auto& cluster : tensor_list_clusters) {
1388     ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
1389   }
1390 
1391   VLOG(2) << "Forcing color match on loop edges";
1392   TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&allow_set));
1393 
1394   VLOG(2) << "Finding existing casts that can be made allow";
1395   MakeCastsAllowIfAllOutputsAllow(&allow_set);
1396 
1397   VLOG(2) << "Beginning final pass to change type attributes and insert Cast "
1398              "ops at paint boundaries";
1399   TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(allow_set));
1400   VLOG(2) << "Finished final pass";
1401 
1402   TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ false, timestamp));
1403 
1404   return Status::OK();
1405 }
1406 
1407 // If node is a Tensor List op with a float32 data type attribute then this
1408 // returns a pointer to the NodeTypeId representing that type attribute. In
1409 // all other cases this returns nullptr.
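// For example, for a TensorListGather node whose element_dtype attribute is
// DT_FLOAT, this returns the NodeTypeId for that element_dtype attribute,
// whereas a Tensor List node with DT_INT32 elements (or any non-Tensor-List
// op) yields nullptr. The attribute name element_dtype is illustrative; the
// code relies on the type attribute being a single, non-fixed float32
// attribute rather than on its name.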
1410 const NodeTypeId* AutoMixedPrecisionImpl::GetTensorListFloat32NodeTypeId(
1411     const NodeDef& node) const {
1412   if (!IsTensorListOp(node.op())) return nullptr;
1413   for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(node)) {
1414     const NodeTypeId* node_type =
1415         graph_type_view_.GetNode(node.name(), type_attr);
1416     // This assumes that the float32 data type on a Tensor List op is always a
1417     // non-fixed type attribute containing a single type, and that this type
1418     // attribute represents the dtype of the values in the list.
1419     // TODO(benbarsdell): A new Tensor List op could theoretically break these
1420     // assumptions.
1421     if (node_type && node_type->type_attr.fixed_type == DT_INVALID &&
1422         node_type->type_attr.type_index == TypeAttrId::kSingleType &&
1423         IsFloat32(*node_type)) {
1424       return node_type;
1425     }
1426   }
1427   return nullptr;
1428 }
1429 
1430 bool AutoMixedPrecisionImpl::IsSourceOrSinkOp(const string& op) const {
1431   const gtl::FlatSet<string> source_and_sink_ops = {
1432       "_Arg",
1433       "_Retval",
1434       "OptionalFromValue",
1435       "OptionalGetValue",
1436       "PartitionedCall",
1437       "Placeholder",
1438       "StatefulPartitionedCall",
1439   };
1440   return source_and_sink_ops.count(op) || function_library_.Find(op);
1441 }
1442 
1443 // Finds all clusters of float32 Tensor List nodes that are connected via their
1444 // handle edges. Unsafe clusters (those with unprocessable nodes, or with edges
1445 // that cross untraversable boundaries via _Arg, _Ret, PartitionedCall etc.
1446 // nodes) are added to deny_set. The caller should paint all nodes in a cluster
1447 // the same color, as they may all refer to the same Tensor List.
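// As an illustrative (hypothetical) example, a chain
//   TensorListReserve -> TensorListSetItem -> TensorListGetItem
// connected by DT_VARIANT handle edges forms a single cluster; if the handle
// also passed through a PartitionedCall or _Retval node, the cluster would be
// treated as unsafe and denied.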
1448 void AutoMixedPrecisionImpl::FindFloat32TensorListOpClustersAndDenylistUnsafe(
1449     std::vector<absl::flat_hash_set<const NodeDef*>>* tensor_list_clusters,
1450     absl::flat_hash_set<int>* deny_set) const {
1451   absl::flat_hash_set<const NodeDef*> tensor_list_prop_set;
1452   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1453     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1454     if (!ShouldProcess(*root.node) ||
1455         root.type_attr.fixed_type != DataType::DT_VARIANT ||
1456         !GetTensorListFloat32NodeTypeId(*root.node) ||
1457         tensor_list_prop_set.count(root.node)) {
1458       continue;
1459     }
1460     const NodeTypeId* root_fp32 = GetTensorListFloat32NodeTypeId(*root.node);
1461     const absl::optional<int> maybe_root_fp32_idx =
1462         graph_type_view_.GetNodeIndex(*root_fp32);
1463     DCHECK(maybe_root_fp32_idx.has_value())
1464         << "Type attribute " << root_fp32->type_attr.DebugString()
1465         << " of node " << root.node->name() << " not found in graph view";
1466     int root_fp32_idx = maybe_root_fp32_idx.value();
1467     // Traverse Tensor List handle edges (DT_VARIANT) to find cluster of all
1468     // connected Tensor List nodes.
1469     absl::flat_hash_set<const NodeDef*> cluster({root.node});
1470     DfsTypeTraversal(graph_type_view_, {&root},
1471                      TypeTraversalDirection::kFollowInputsAndOutputs,
1472                      DfsTypePredicates::Enter([&](int idx) -> bool {
1473                        const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1474                        return !tensor_list_prop_set.count(item.node);
1475                      }),
1476                      DfsTypeCallbacks::PreOrder([&](int idx) {
1477                        const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1478                        const NodeDef* node = item.node;
1479                        if (GetTensorListFloat32NodeTypeId(*node)) {
1480                          cluster.insert(node);
1481                          if (!ShouldProcess(*node)) {
1482                            // The cluster contains an un-processable node.
1483                            deny_set->insert(root_fp32_idx);
1484                          }
1485                          // TODO(benbarsdell): In a theoretical pathological
1486                          // case of a Tensor List of Tensor List handles, the
1487                          // Tensor List itself would need to be treated as a
1488                          // sink.
1489                        } else if (IsSourceOrSinkOp(node->op())) {
1490                          // The cluster crosses an untraversable boundary.
1491                          deny_set->insert(root_fp32_idx);
1492                        }
1493                      }));
1494     tensor_list_clusters->push_back(cluster);
1495   }
1496 }
1497 
1498 // Finds all writer -> reader pairs in the given set that are connected via
1499 // their handles, and adds corresponding float32 edges to *implicit_fp32_edges.
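// For example, if a TensorListPushBack (writer) and a TensorListStack (reader)
// operate on the same list handle, an ephemeral float32 edge is added from the
// writer's element type attribute to the reader's, so that later painting
// passes treat the values flowing through the list consistently.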
1500 void AutoMixedPrecisionImpl::FindTensorListImplicitFloat32Edges(
1501     const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1502     std::vector<NodeTypeIdEdge>* implicit_fp32_edges) const {
1503   for (const NodeDef* root_node : tensor_list_nodes) {
1504     if (!IsTensorListReaderOp(root_node->op())) continue;
1505     NodeTypeId root(root_node, TypeAttrId(DataType::DT_VARIANT));
1506     const NodeTypeId* root_fp32 = GetTensorListFloat32NodeTypeId(*root.node);
1507     CHECK(root_fp32) << "No float32 type attribute found for "  // Crash OK
1508                      << root.node->op() << " node " << root.node->name();
1509     // Search backwards through handle edges (DT_VARIANT) for all writer ops,
1510     // adding direct implicit edges between them and the reader.
1511     DfsTypeTraversal(
1512         graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
1513         DfsTypePredicates::Enter([&](int idx) -> bool {
1514           const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1515           return ShouldProcess(*item.node);
1516         }),
1517         DfsTypeCallbacks::PreOrder([&](int idx) {
1518           const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1519           if (IsTensorListWriterOp(item.node->op())) {
1520             const NodeTypeId* item_fp32 =
1521                 GetTensorListFloat32NodeTypeId(*item.node);
1522             CHECK(item_fp32)  // Crash OK
1523                 << "No float32 type attribute found for " << item.node->op()
1524                 << " node " << item.node->name();
1525             VLOG(2) << "Adding ephemeral float32 edge from "
1526                     << item_fp32->node->op() << " node "
1527                     << item_fp32->node->name() << " to "
1528                     << root_fp32->node->op() << " node "
1529                     << root_fp32->node->name();
1530             implicit_fp32_edges->emplace_back(*item_fp32, *root_fp32);
1531           }
1532         }));
1533   }
1534 }
1535 
1536 void AutoMixedPrecisionImpl::AddAllowlistOps(
1537     absl::flat_hash_set<int>* allow_set) const {
1538   // Add allowlisted ops to allow_set.
1539   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1540     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1541     if (!ShouldProcess(*root.node)) continue;
1542     bool force_allow = force_all_fp16_ && CanForceFP16(*root.node);
1543     if (f16_allowlist_.count(root.node->op()) || force_allow) {
1544       bool inserted = allow_set->insert(root_idx).second;
1545       if (VLOG_IS_ON(2) && inserted) {
1546         VLOG(2) << "Painting type " << root.type_attr.DebugString()
1547                 << " of node " << root.node->name() << " ALLOW because its op "
1548                 << root.node->op() << " is on the allowlist";
1549       }
1550     }
1551   }
1552 }
1553 
1554 // Adds nodes to deny_set iff they are on the denylist or they are on a
1555 // forward path from a denylist node to a deny/infer node (including the node
1556 // at the end of the path) through clear and infer nodes.
1557 // E.g., deny -> infer -> clear -> infer -> clear -> allow -> infer
1558 // becomes: deny -> deny -> deny -> deny -> clear -> allow -> infer.
1559 void AutoMixedPrecisionImpl::PropagateDenyFwdThroughClearAndInfer(
1560     absl::flat_hash_set<int>* deny_set) const {
1561   if (force_all_fp16_) return;
1562 
1563   // Find clear nodes that are upstream of deny or infer.
1564   absl::flat_hash_set<int> upstream_of_deny_or_infer_set;
1565   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1566     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1567     if (!(f16_denylist_.count(root.node->op()) ||
1568           f16_inferlist_.count(root.node->op()))) {
1569       continue;
1570     }
1571     DfsTypeTraversal(graph_type_view_, {&root},
1572                      TypeTraversalDirection::kFollowInputs,
1573                      DfsTypePredicates::Enter([&](int idx) -> bool {
1574                        const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1575                        return idx == root_idx ||
1576                               (!upstream_of_deny_or_infer_set.count(idx) &&
1577                                f16_clearlist_.count(item.node->op()));
1578                      }),
1579                      DfsTypeCallbacks::PreOrder([&](int idx) {
1580                        upstream_of_deny_or_infer_set.insert(idx);
1581                      }));
1582   }
1583 
1584   // Propagate deny forward through nodes in upstream_of_deny_or_infer_set.
1585   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1586     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1587     if (deny_set->count(root_idx) || !f16_denylist_.count(root.node->op())) {
1588       continue;
1589     }
1590     DfsTypeTraversal(
1591         graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
1592         DfsTypePredicates::Enter([&](int idx) -> bool {
1593           return idx == root_idx || (!deny_set->count(idx) &&
1594                                      upstream_of_deny_or_infer_set.count(idx));
1595         }),
1596         DfsTypeCallbacks::PreOrder([&](int idx) {
1597           bool inserted = deny_set->insert(idx).second;
1598           if (VLOG_IS_ON(2) && inserted) {
1599             const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1600             VLOG(2) << "Painting type " << item.type_attr.DebugString()
1601                     << " of " << item.node->op() << " node "
1602                     << item.node->name() << " DENY";
1603           }
1604         }));
1605   }
1606 }
1607 
1608 void AutoMixedPrecisionImpl::AddClearAndInferToAllowIfBetweenAllow(
1609     const absl::flat_hash_set<int>& deny_set,
1610     absl::flat_hash_set<int>* allow_set) const {
1611   // Find clear/inferlist ops that are downstream of allow ops.
1612   absl::flat_hash_set<int> downstream_of_allow_set;
1613   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1614     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1615     if (!ShouldProcess(*root.node) || !f16_allowlist_.count(root.node->op())) {
1616       continue;
1617     }
1618     DfsTypeTraversal(
1619         graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
1620         DfsTypePredicates::Enter([&](int idx) -> bool {
1621           const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1622           return idx == root_idx ||
1623                  (!downstream_of_allow_set.count(idx) &&
1624                   !f16_allowlist_.count(item.node->op()) &&
1625                   !deny_set.count(idx) && ShouldProcess(*item.node) &&
1626                   // TODO(benbarsdell): Consider allowing propagation through
1627                   // ops that are already float16 in order to reduce the number
1628                   // of casts.
1629                   IsFloat32(item) && SupportsF16(item) &&
1630                   (f16_clearlist_.count(item.node->op()) ||
1631                    f16_inferlist_.count(item.node->op())));
1632         }),
1633         DfsTypeCallbacks::PreOrder(
1634             [&](int idx) { downstream_of_allow_set.insert(idx); }));
1635   }
1636 
1637   // Set nodes that are both downstream and upstream of allow ops to allow.
1638   absl::flat_hash_set<int> upstream_of_allow_set;
1639   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1640     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1641     if (!ShouldProcess(*root.node) || upstream_of_allow_set.count(root_idx) ||
1642         !f16_allowlist_.count(root.node->op())) {
1643       continue;
1644     }
1645     DfsTypeTraversal(
1646         graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
1647         DfsTypePredicates::Enter([&](int idx) -> bool {
1648           return idx == root_idx || (!upstream_of_allow_set.count(idx) &&
1649                                      downstream_of_allow_set.count(idx));
1650         }),
1651         DfsTypeCallbacks::PreOrder([&](int idx) {
1652           upstream_of_allow_set.insert(idx);
1653           bool inserted = allow_set->insert(idx).second;
1654           if (VLOG_IS_ON(2) && inserted) {
1655             const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1656             VLOG(2) << "Painting type " << item.type_attr.DebugString()
1657                     << " of " << item.node->op() << " node "
1658                     << item.node->name() << " ALLOW";
1659           }
1660         }));
1661   }
1662 }
1663 
1664 void AutoMixedPrecisionImpl::PropagateAllowThroughClear(
1665     const absl::flat_hash_set<int>& deny_set,
1666     absl::flat_hash_set<int>* allow_set) const {
1667   // Propagate allow from allow nodes through clearlist ops.
1668   absl::flat_hash_set<int> clear_prop_set;
1669   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1670     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1671     if (!ShouldProcess(*root.node) || clear_prop_set.count(root_idx) ||
1672         !allow_set->count(root_idx)) {
1673       continue;
1674     }
1675     DfsTypeTraversal(
1676         graph_type_view_, {&root},
1677         TypeTraversalDirection::kFollowInputsAndOutputs,
1678         DfsTypePredicates::Enter([&](int idx) -> bool {
1679           const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1680           return idx == root_idx ||
1681                  (!allow_set->count(idx) && !deny_set.count(idx) &&
1682                   ShouldProcess(*item.node) && IsFloat32(item) &&
1683                   SupportsF16(item) &&
1684                   (f16_clearlist_.count(item.node->op())) &&
1685                   // We don't propagate (backwards) through nodes that read
1686                   // Variables because it can break the behavior of TensorBoard
1687                   // visualization and/or (in the case of Enter nodes) the model
1688                   // itself. This is only a problem for non-resource variables.
1689                   !NodeImplicitlyReadsNonResourceVariable(*item.node));
1690         }),
1691         DfsTypeCallbacks::PreOrder([&](int idx) {
1692           clear_prop_set.insert(idx);
1693           bool inserted = allow_set->insert(idx).second;
1694           if (VLOG_IS_ON(2) && inserted) {
1695             const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1696             VLOG(2) << "Painting type " << item.type_attr.DebugString()
1697                     << " of " << item.node->op() << " node "
1698                     << item.node->name() << " ALLOW";
1699           }
1700         }));
1701   }
1702 }
1703 
1704 // Removes nodes from allow_set if they have a type_attr that cannot be
1705 // converted to F16: e.g., FusedBatchNormV2/FusedBatchNormV3 have a type_attr
1706 // 'U' that only supports float, so such nodes must stay in float32.
1707 void AutoMixedPrecisionImpl::RemoveAllowsetWithFp32(
1708     absl::flat_hash_set<int>* allow_set) const {
1709   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1710     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1711     if (f16_allowlist_.count(root.node->op()) && allow_set->count(root_idx) &&
1712         !SupportsF16DataType(root)) {
1713       auto erased = allow_set->erase(root_idx);
1714       if (VLOG_IS_ON(2) && erased) {
1715         VLOG(2) << "Unpainting type " << root.type_attr.DebugString()
1716                 << " of node " << root.node->name() << " ALLOW because its op "
1717                 << root.node->op() << " does not support the F16 data type";
1718       }
1719     }
1720   }
1721 }
1722 
1723 // Forces NextIteration nodes and their output Merge node(s) to have the same
1724 // color. Specifically, it removes them all from allow_set if any of the Merge
1725 // nodes is not in allow_set, otherwise it adds the NextIteration node to
1726 // allow_set.
1727 Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
1728     absl::flat_hash_set<int>* allow_set) const {
1729   for (const NodeDef& node : graph_->node()) {
1730     if (node.op() == "NextIteration") {
1731       GraphView::OutputPort output_port(&node, 0);
1732       const auto& fanout = graph_view_.GetFanout(output_port);
1733       std::vector<int> merge_idxs;
1734       merge_idxs.reserve(fanout.size());
1735       bool any_merge_is_not_allow = false;
1736       for (const auto& output : fanout) {
1737         const NodeDef& merge_node = *output.node;
1738         if (merge_node.op() != "Merge") {
1739           return errors::FailedPrecondition(
1740               "Expected Merge node after NextIteration, got ", merge_node.op());
1741         }
1742         const absl::optional<int> maybe_merge_idx =
1743             graph_type_view_.GetNodeIndex(merge_node.name(), TypeAttrId("T"));
1744         if (!maybe_merge_idx.has_value()) {
1745           return errors::Internal("Type attribute T of Merge node ",
1746                                   merge_node.name(),
1747                                   " not found in graph view");
1748         }
1749         int merge_idx = maybe_merge_idx.value();
1750         merge_idxs.push_back(merge_idx);
1751         any_merge_is_not_allow =
1752             any_merge_is_not_allow || !allow_set->count(merge_idx);
1753       }
1754       const absl::optional<int> maybe_nextiter_idx =
1755           graph_type_view_.GetNodeIndex(node.name(), TypeAttrId("T"));
1756       if (!maybe_nextiter_idx.has_value()) {
1757         return errors::Internal("Type attribute T of NextIteration node ",
1758                                 node.name(), " not found in graph view");
1759       }
1760       int nextiter_idx = maybe_nextiter_idx.value();
1761       if (any_merge_is_not_allow) {
1762         for (int merge_idx : merge_idxs) {
1763           if (allow_set->erase(merge_idx)) {
1764             VLOG(2) << "Painting type T of Merge node "
1765                     << graph_type_view_.GetNode(merge_idx)->node->name()
1766                     << " DENY to match the color of its sibling Merge nodes "
1767                        "with common NextIteration node "
1768                     << node.name();
1769           }
1770         }
1771         if (allow_set->erase(nextiter_idx)) {
1772           VLOG(2) << "Painting type T of NextIteration node " << node.name()
1773                   << " DENY to match the color of its output Merge node(s)";
1774         }
1775       } else {
1776         if (allow_set->insert(nextiter_idx).second) {
1777           VLOG(2) << "Painting type T of NextIteration node " << node.name()
1778                   << " ALLOW to match the color of its output Merge node(s)";
1779         }
1780       }
1781     }
1782   }
1783   return Status::OK();
1784 }
1785 
1786 // Forces all of the given Tensor List nodes into the same color set.
1787 void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps(
1788     const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1789     absl::flat_hash_set<int>* allow_set,
1790     absl::flat_hash_set<int>* deny_set) const {
1791   bool any_deny = false;
1792   bool any_allow = false;
1793   std::vector<int> node_type_idxs;
1794   node_type_idxs.reserve(tensor_list_nodes.size());
1795   for (const NodeDef* node : tensor_list_nodes) {
1796     const NodeTypeId& node_type = *GetTensorListFloat32NodeTypeId(*node);
1797     const absl::optional<int> maybe_node_type_idx =
1798         graph_type_view_.GetNodeIndex(node_type);
1799     DCHECK(maybe_node_type_idx.has_value())
1800         << "Type attribute " << node_type.type_attr.DebugString() << " of node "
1801         << node->name() << " not found in graph view";
1802     node_type_idxs.push_back(maybe_node_type_idx.value());
1803   }
1804   for (int node_type_idx : node_type_idxs) {
1805     if (deny_set->count(node_type_idx)) {
1806       any_deny = true;
1807       break;
1808     } else if (allow_set->count(node_type_idx)) {
1809       any_allow = true;
1810     }
1811   }
1812   if (!any_deny && !any_allow) return;
1813   for (int node_type_idx : node_type_idxs) {
1814     const NodeTypeId& node_type = *graph_type_view_.GetNode(node_type_idx);
1815     VLOG(2) << "Painting type " << node_type.type_attr.DebugString() << " of "
1816             << node_type.node->op() << " node " << node_type.node->name() << " "
1817             << (any_deny ? "DENY" : "ALLOW")
1818             << " because at least one of its siblings is "
1819             << (any_deny ? "DENY" : "ALLOW");
1820     if (any_deny) {
1821       allow_set->erase(node_type_idx);
1822       deny_set->insert(node_type_idx);
1823     } else {
1824       allow_set->insert(node_type_idx);
1825     }
1826   }
1827 }
1828 
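// Returns true if the node is an Identity op that reads a non-resource
// Variable/VariableV2, or an Enter op whose input (transitively) does so. For
// example, in a (hypothetical) chain VariableV2 -> Identity -> Enter, this
// returns true for both the Identity and the Enter node.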
1829 bool AutoMixedPrecisionImpl::NodeImplicitlyReadsNonResourceVariable(
1830     const NodeDef& node) const {
1831   if (node.op() == "Identity" || node.op() == "Enter") {
1832     GraphView::InputPort node_input(&node, 0);
1833     MutableGraphView::OutputPort prev_output =
1834         graph_view_.GetRegularFanin(node_input);
1835     const NodeDef* input = prev_output.node;
1836     if (input && ((node.op() == "Identity" && (input->op() == "Variable" ||
1837                                                input->op() == "VariableV2")) ||
1838                   (node.op() == "Enter" &&
1839                    NodeImplicitlyReadsNonResourceVariable(*input)))) {
1840       return true;
1841     }
1842   }
1843   return false;
1844 }
1845 
1846 // This adds existing Cast nodes to allow_set if all of their outputs are allow,
1847 // avoiding the need to add a new Cast node after an existing Cast.
1848 void AutoMixedPrecisionImpl::MakeCastsAllowIfAllOutputsAllow(
1849     absl::flat_hash_set<int>* allow_set) const {
1850   int num_nodes_preop = graph_->node_size();
1851   for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
1852     NodeDef* node = graph_->mutable_node(node_idx);
1853     NodeTypeId node_type(node, TypeAttrId("DstT"));
1854     if (node->op() != "Cast" || !IsFloat32(node_type)) {
1855       continue;
1856     }
1857     bool all_fanouts_allow = true;
1858     MutableGraphView::OutputPort src(node, 0);
1859     const auto& fanout = graph_view_.GetFanout(src);
1860     for (const MutableGraphView::InputPort& dst : fanout) {
1861       TypeAttrId dst_type_attr =
1862           node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
1863       const absl::optional<int> maybe_dst_type_idx =
1864           graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
1865       DCHECK(maybe_dst_type_idx.has_value())
1866           << "Type attribute " << dst_type_attr.DebugString() << " of node "
1867           << dst.node->name() << " not found in graph view";
1868       int dst_type_idx = maybe_dst_type_idx.value();
1869       bool dst_is_allow = allow_set->count(dst_type_idx);
1870       if (!dst_is_allow) {
1871         all_fanouts_allow = false;
1872         break;
1873       }
1874     }
1875     if (!fanout.empty() && all_fanouts_allow) {
1876       const absl::optional<int> maybe_node_type_idx =
1877           graph_type_view_.GetNodeIndex(node_type);
1878       DCHECK(maybe_node_type_idx.has_value())
1879           << "Type attribute " << node_type.type_attr.DebugString()
1880           << " of node " << node_type.node->name()
1881           << " not found in graph view";
1882       int node_type_idx = maybe_node_type_idx.value();
1883       allow_set->insert(node_type_idx);
1884     }
1885   }
1886 }
1887 
1888 // Changes all allow-painted type attributes to DT_HALF or DT_BFLOAT16, and
1889 // inserts Cast nodes at node outputs for all edges that connect
1890 // allow-painted <-> non-allow-painted type attributes.
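// As a rough example (op names illustrative): if a MatMul output painted ALLOW
// feeds a Softmax left in float32, the MatMul's type attribute is changed to
// the target f16 type and a Cast back to DT_FLOAT is inserted on that edge, so
// the Softmax still receives float32 input; conversely, a float32 -> f16 cast
// is inserted where a non-allow output feeds an allow-painted input.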
1891 Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
1892     const absl::flat_hash_set<int>& allow_set) {
1893   int num_nodes_changed = 0;
1894   int num_nonvar_casts_to_f16 = 0;
1895   int num_nodes_preop = graph_->node_size();
1896   for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
1897     NodeDef* node = graph_->mutable_node(node_idx);
1898     for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(*node)) {
1899       const absl::optional<int> maybe_node_type_idx =
1900           graph_type_view_.GetNodeIndex(node->name(), type_attr);
1901       if (!maybe_node_type_idx.has_value()) {
1902         return errors::Internal("Type attribute ", type_attr.DebugString(),
1903                                 " of ", node->op(), " node ", node->name(),
1904                                 " not found in graph view");
1905       }
1906       int node_type_idx = maybe_node_type_idx.value();
1907       if (!IsFloat32(*graph_type_view_.GetNode(node_type_idx))) continue;
1908       bool src_is_allow = allow_set.count(node_type_idx);
1909       if (src_is_allow) {
1910         VLOG(1) << "Changing type " << type_attr.DebugString() << " of "
1911                 << node->op() << " node " << node->name() << " to "
1912                 << DataTypeString(target_dtype_);
1913         if (!SetDataType(node, type_attr, target_dtype_)) {
1914           return errors::Internal("Failed to set type attribute");
1915         }
1916         ++num_nodes_changed;
1917       }
1918       for (int output_port : node_type_map_.GetOutputPorts(*node, type_attr)) {
1919         MutableGraphView::OutputPort src(node, output_port);
1920         NodeDef* added_cast_node = nullptr;
1921         // Note: This is copied so that edges can be modified inside the loop.
1922         auto fanout = graph_view_.GetFanout(src);
1923         for (const MutableGraphView::InputPort& dst : fanout) {
1924           TypeAttrId dst_type_attr =
1925               node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
1926           const absl::optional<int> maybe_dst_type_idx =
1927               graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
1928           if (!maybe_dst_type_idx.has_value()) {
1929             return errors::Internal("Type attribute ",
1930                                     dst_type_attr.DebugString(), " of ",
1931                                     dst.node->op(), " node ", dst.node->name(),
1932                                     " not found in graph view");
1933           }
1934           int dst_type_idx = maybe_dst_type_idx.value();
1935           bool dst_is_allow = allow_set.count(dst_type_idx);
1936           if (src_is_allow != dst_is_allow) {
1937             if (!added_cast_node) {
1938               bool to_f16 = dst_is_allow;
1939               VLOG(1) << "Inserting cast to "
1940                       << (to_f16 ? DataTypeString(target_dtype_) : "DT_FLOAT")
1941                       << " at " << src.node->op() << " " << src.node->name()
1942                       << ":" << src.port_id;
1943               added_cast_node = graph_view_.AddNode(
1944                   BuildCastNode(src, to_f16, src.node->device()));
1945               if (to_f16 && !IsConstant(*node) && !IsVariable(*node) &&
1946                   !NodeImplicitlyReadsNonResourceVariable(*node)) {
1947                 ++num_nonvar_casts_to_f16;
1948               }
1949             }
1950             TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort(
1951                 dst.node->name(), dst.port_id, {added_cast_node->name(), 0}));
1952           }
1953         }
1954       }
1955     }
1956   }
1957   // Use Python type names (e.g. float16) instead of C++ type names (e.g. half)
1958   // since many Python users will see this message.
1959   const char* type_str = target_dtype_ == DT_HALF ? "float16" : "bfloat16";
1960   LOG(INFO) << "Converted " << num_nodes_changed << "/" << num_nodes_preop
1961             << " nodes to " << type_str << " precision using "
1962             << num_nonvar_casts_to_f16 << " cast(s) to " << type_str
1963             << " (excluding Const and Variable casts)";
1964   return Status::OK();
1965 }
1966 
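// Returns the number of GPUs in the cluster whose architecture is at least
// min_arch. For example, when called with a min_arch of {7, 0} (Volta), a GPU
// reporting compute capability 8.0 is counted while one reporting 6.1 is not.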
1967 int GetNumGPUs(const Cluster& cluster,
1968                const std::pair<int, int>& min_arch = {0, 0}) {
1969   auto devices = cluster.GetDevices();
1970   int num_gpus = 0;
1971   for (const auto& device : devices) {
1972     const DeviceProperties& device_properties = device.second;
1973     std::pair<int, int> arch = GetDeviceGPUArch(device_properties);
1974     if (device_properties.type() == "GPU" && arch >= min_arch) {
1975       num_gpus++;
1976     }
1977   }
1978   return num_gpus;
1979 }
1980 
1981 }  // end namespace
1982 
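// Grappler entry point for the auto mixed precision pass. For reference
// (assuming a TF 2.x build), this rewrite is typically enabled from Python via
//   tf.config.optimizer.set_experimental_options({'auto_mixed_precision': True})
// or via tf.train.experimental.enable_mixed_precision_graph_rewrite(), rather
// than by invoking this optimizer directly.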
1983 Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item,
1984                                     GraphDef* output) {
1985   if (cluster == nullptr) {
1986     return errors::InvalidArgument("cluster == nullptr");
1987   }
1988 
1989 #if !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16)
1990   if (mode_ == AutoMixedPrecisionMode::MKL) {
1991     return errors::Unimplemented(
1992         "The auto_mixed_precision_mkl optimizer cannot be used since "
1993         "this build of TensorFlow is not compiled with MKL support for "
1994         "bfloat16. "
1995         "For information on MKL builds, see: "
1996         "https://software.intel.com/en-us/articles/intel-optimization-for-"
1997         "tensorflow-installation-guide");
1998   }
1999 #endif
2000 
2001   // Start by copying input graph to output.
2002   *output = item.graph;
2003 
2004   int num_gpus = ShouldIgnorePerformance() ? GetNumGPUs(*cluster)
2005                                            : GetNumGPUs(*cluster, kMinGPUArch);
2006   if (num_gpus < 1 && mode_ == AutoMixedPrecisionMode::CUDA) {
2007     // AutoMixedPrecision is currently only tuned for GPU.
2008     LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name()
2009                  << " graph optimizer";
2010     return Status::OK();
2011   }
2012 
2013   // Optimize the output graph in-place.
2014   AutoMixedPrecisionImpl optimizer(cluster, item.NodesToPreserve(), output,
2015                                    item.id, mode_);
2016   if (item.id == "tf_graph") {
2017     LOG(INFO) << "Running " << name() << " graph optimizer";
2018   } else {
2019     VLOG(1) << "Running " << name() << " graph optimizer on " << item.id;
2020   }
2021   Status status = optimizer.Optimize();
2022   if (!status.ok()) {
2023     // Restore the original graph.
2024     *output = item.graph;
2025     LOG(WARNING) << name() << " graph optimizer FAILED: " << status.ToString();
2026   }
2027   return status;
2028 }
2029 
2030 void AutoMixedPrecision::Feedback(Cluster* cluster, const GrapplerItem& item,
2031                                   const GraphDef& optimize_output,
2032                                   double result) {
2033   // Nothing to do for AutoMixedPrecision.
2034 }
2035 
2036 }  // end namespace grappler
2037 }  // end namespace tensorflow
2038