/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"

#include <fstream>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {
namespace grappler {
namespace {

const std::pair<int, int> kMinGPUArch = {7, 0};

const char kSuffix[] = "AutoMixedPrecision";
const char kCastToFp16[] = "CastToFp16";
const char kCastToFp32[] = "CastToFp32";

// Instances of this class represent unique type attribute identifiers within a
// node. It handles regular type attributes, list type attributes (where
// type_index is set to the index in the type list), and fixed types.
struct TypeAttrId {
  static const int kSingleType = -1;

  explicit TypeAttrId(const string& _attr_name, int _type_index = kSingleType)
      : attr_name(_attr_name),
        type_index(_type_index),
        fixed_type(DT_INVALID) {}

  explicit TypeAttrId(DataType _fixed_type)
      : attr_name(), type_index(kSingleType), fixed_type(_fixed_type) {}

  bool operator==(const TypeAttrId& other) const {
    return attr_name == other.attr_name && type_index == other.type_index &&
           fixed_type == other.fixed_type;
  }

  bool operator<(const TypeAttrId& other) const {
    return std::make_tuple(attr_name, type_index, fixed_type) <
           std::make_tuple(other.attr_name, other.type_index, other.fixed_type);
  }

  template <typename H>
  friend H AbslHashValue(H h, const TypeAttrId& ta) {
    return H::combine(std::move(h), ta.attr_name, ta.type_index, ta.fixed_type);
  }

  string DebugString() const {
    if (!attr_name.empty()) {
      if (type_index == kSingleType) {
        return attr_name;
      } else {
        return strings::StrCat(attr_name, "[", type_index, "]");
      }
    } else {
      return tensorflow::DataTypeString(fixed_type);
    }
  }

  string attr_name;
  // If attr_name is a list(type), this is the index into the list. Otherwise
  // this is kSingleType.
  int type_index;
  DataType fixed_type;
};
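
// Example identifiers (illustrative; follows the three constructors above):
//   TypeAttrId("T")               // a node's single type attr "T"
//   TypeAttrId("Tcomponents", 1)  // second type in a list(type) attr
//   TypeAttrId(DT_INT32)          // a fixed (non-attribute) type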

// Returns the data type of the given type attribute, or DT_INVALID if the type
// attribute is invalid.
DataType GetDataType(const NodeDef& node, const TypeAttrId& type_attr) {
  if (type_attr.attr_name.empty()) {
    return type_attr.fixed_type;
  }
  if (!node.attr().count(type_attr.attr_name)) {
    return DT_INVALID;
  }
  const AttrValue& attr_value = node.attr().at(type_attr.attr_name);
  if (type_attr.type_index == TypeAttrId::kSingleType) {
    return attr_value.type();
  } else {
    if (type_attr.type_index < 0 ||
        type_attr.type_index >= attr_value.list().type_size()) {
      return DT_INVALID;
    }
    return attr_value.list().type(type_attr.type_index);
  }
}

// Sets the data type of the given type attribute. Returns false if the type
// attribute is invalid, otherwise true.
bool SetDataType(NodeDef* node, const TypeAttrId& type_attr, DataType type) {
  if (type_attr.attr_name.empty() || !node->attr().count(type_attr.attr_name)) {
    return false;
  }
  AttrValue& attr_value = node->mutable_attr()->at(type_attr.attr_name);
  if (type_attr.type_index == TypeAttrId::kSingleType) {
    attr_value.set_type(type);
  } else {
    if (type_attr.type_index < 0 ||
        type_attr.type_index >= attr_value.list().type_size()) {
      return false;
    }
    attr_value.mutable_list()->set_type(type_attr.type_index, type);
  }
  return true;
}
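
// Illustrative round trip with the two helpers above, assuming a node whose
// "T" attr is DT_FLOAT:
//   GetDataType(node, TypeAttrId("T"));            // returns DT_FLOAT
//   SetDataType(&node, TypeAttrId("T"), DT_HALF);  // returns true
//   GetDataType(node, TypeAttrId("T"));            // now returns DT_HALF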

// Returns the (arg_idx, type_idx) pairs that the given arg_def expands to on
// this node: one pair per type in its type_list_attr for list arguments,
// otherwise (arg_idx, -1) repeated once per element of number_attr (or just
// once if number_attr is absent).
std::vector<std::pair<int, int>> ArgDefIndexes(const NodeDef& node, int arg_idx,
                                               const OpDef::ArgDef& arg_def) {
  std::vector<std::pair<int, int>> argdef_inds;
  if (!arg_def.type_list_attr().empty()) {
    int num_types = node.attr().at(arg_def.type_list_attr()).list().type_size();
    for (int type_idx = 0; type_idx < num_types; ++type_idx) {
      argdef_inds.push_back({arg_idx, type_idx});
    }
  } else {
    int num_repeat = 1;
    if (node.attr().count(arg_def.number_attr())) {
      num_repeat = node.attr().at(arg_def.number_attr()).i();
    }
    argdef_inds.insert(argdef_inds.end(), num_repeat, {arg_idx, -1});
  }
  return argdef_inds;
}

// Returns a pair (arg_index, type_index) for each input to the node, where
// arg_index is the index of the input_arg in op_def and type_index is the index
// of the type in type_list_attr (only defined for list arguments).
std::vector<std::pair<int, int>> InputPortArgDefIndexes(const NodeDef& node,
                                                        const OpDef& op_def) {
  std::vector<std::pair<int, int>> argdef_inds;
  argdef_inds.reserve(op_def.input_arg_size());  // Final size may differ.
  for (int arg_idx = 0; arg_idx < op_def.input_arg_size(); ++arg_idx) {
    const OpDef::ArgDef& arg_def = op_def.input_arg(arg_idx);
    auto arg_results = ArgDefIndexes(node, arg_idx, arg_def);
    argdef_inds.insert(argdef_inds.end(), arg_results.begin(),
                       arg_results.end());
  }
  return argdef_inds;
}

// Returns a pair (arg_index, type_index) for each output to the node, where
// arg_index is the index of the output_arg in op_def and type_index is the
// index of the type in type_list_attr (only defined for list arguments).
std::vector<std::pair<int, int>> OutputPortArgDefIndexes(const NodeDef& node,
                                                         const OpDef& op_def) {
  std::vector<std::pair<int, int>> argdef_inds;
  argdef_inds.reserve(op_def.output_arg_size());  // Final size may differ.
  for (int arg_idx = 0; arg_idx < op_def.output_arg_size(); ++arg_idx) {
    const OpDef::ArgDef& arg_def = op_def.output_arg(arg_idx);
    auto arg_results = ArgDefIndexes(node, arg_idx, arg_def);
    argdef_inds.insert(argdef_inds.end(), arg_results.begin(),
                       arg_results.end());
  }
  return argdef_inds;
}

TypeAttrId GetTypeAttrId(const OpDef::ArgDef& arg_def, int arg_type_index) {
  if (!arg_def.type_list_attr().empty()) {
    return TypeAttrId(arg_def.type_list_attr(), arg_type_index);
  } else if (!arg_def.type_attr().empty()) {
    return TypeAttrId(arg_def.type_attr());
  } else {
    return TypeAttrId(arg_def.type());
  }
}

std::vector<int> NonControlInputs(const NodeDef& node) {
  std::vector<int> pos;
  for (int i = 0; i < node.input_size(); i++) {
    if (!IsControlInput(node.input(i))) {
      pos.push_back(i);
    }
  }
  return pos;
}

// A utility class to lookup node type attributes and type attribute <->
// input/output port mappings.
class NodeTypeAttrMap {
 public:
  NodeTypeAttrMap() {}

  explicit NodeTypeAttrMap(const GraphDef& graph) { TF_CHECK_OK(Init(graph)); }

  Status Init(const GraphDef& graph) {
    if (graph_ != nullptr) {
      return errors::InvalidArgument("NodeTypeAttrMap is already initialized.");
    }
    graph_ = &graph;
    function_library_.reset(
        new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
    for (const NodeDef& node : graph.node()) {
      TF_RETURN_IF_ERROR(AddNode(node));
    }
    return Status::OK();
  }

  bool is_initialized() const { return graph_ != nullptr; }

  // Returns the set of all type attributes in the given node.
  absl::flat_hash_set<TypeAttrId> GetTypeAttrs(const NodeDef& node) const {
    DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
    absl::flat_hash_set<TypeAttrId> type_attrs;
    const auto iter = type2io_.find(&node);
    CHECK(iter != type2io_.end());  // Crash Ok
    for (const auto& key_value : iter->second) {
      type_attrs.insert(key_value.first);
    }
    return type_attrs;
  }

  const absl::flat_hash_set<int>& GetInputPorts(
      const NodeDef& node, const TypeAttrId& type_attr) const {
    DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
    return type2io_.at(&node).at(type_attr).first;
  }

  const absl::flat_hash_set<int>& GetOutputPorts(
      const NodeDef& node, const TypeAttrId& type_attr) const {
    DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
    return type2io_.at(&node).at(type_attr).second;
  }

  TypeAttrId GetInputTypeAttr(const NodeDef& node, int port) const {
    DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
    auto type_vec = io2type_.at(&node).first;
    CHECK_GE(port, 0);                // Crash Ok
    CHECK_LT(port, type_vec.size());  // Crash Ok
    return type_vec[port];
  }

  TypeAttrId GetOutputTypeAttr(const NodeDef& node, int port) const {
    DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
    auto type_vec = io2type_.at(&node).second;
    CHECK_GE(port, 0);                // Crash Ok
    CHECK_LT(port, type_vec.size());  // Crash Ok
    return type_vec[port];
  }

 private:
  Status AddNode(const NodeDef& node) {
    const OpDef* op_def_ptr = nullptr;
    TF_RETURN_IF_ERROR(function_library_->LookUpOpDef(node.op(), &op_def_ptr));
    const OpDef& op_def = *op_def_ptr;
    auto& type2io_entry = type2io_[&node];
    auto& io2type_entry = io2type_[&node];
    auto input_arg_inds = InputPortArgDefIndexes(node, op_def);
    if (NonControlInputs(node).size() != input_arg_inds.size()) {
      return errors::InvalidArgument(
          "Expected ", node.op(), " node ", node.name(), " to have ",
          input_arg_inds.size(), " non-control input(s), but got ",
          NonControlInputs(node).size());
    }
    // Note that the mappings generated here include inputs/outputs with fixed
    // types. This makes the mappings complete (all inputs and outputs are
    // included), and allows the graph rewriter to propagate black paint
    // from/through ops with fixed types.
    io2type_entry.first.reserve(input_arg_inds.size());
    for (int i = 0; i < static_cast<int>(input_arg_inds.size()); ++i) {
      const auto& arg_inds = input_arg_inds[i];
      const OpDef::ArgDef& arg_def = op_def.input_arg(arg_inds.first);
      TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second);
      if (!type_attr.attr_name.empty() &&
          !node.attr().count(type_attr.attr_name)) {
        return errors::InvalidArgument("Type attribute ", type_attr.attr_name,
                                       " is not present in node ", node.name());
      }
      type2io_entry[type_attr].first.insert(i);
      io2type_entry.first.push_back(type_attr);
    }

    auto output_arg_inds = OutputPortArgDefIndexes(node, op_def);
    io2type_entry.second.reserve(output_arg_inds.size());
    for (int i = 0; i < static_cast<int>(output_arg_inds.size()); ++i) {
      const auto& arg_inds = output_arg_inds[i];
      const OpDef::ArgDef& arg_def = op_def.output_arg(arg_inds.first);
      TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second);
      if (!type_attr.attr_name.empty() &&
          !node.attr().count(type_attr.attr_name)) {
        return errors::InvalidArgument("Type attribute ", type_attr.attr_name,
                                       " is not present in node ", node.name());
      }
      type2io_entry[type_attr].second.insert(i);
      io2type_entry.second.push_back(type_attr);
    }

    // Also ensure that type attributes that aren't associated with any inputs
    // or outputs (e.g., StackV2's elem_type) are added to the map.
    for (const auto& attr : node.attr()) {
      const string& attr_name = attr.first;
      if (!attr_name.empty() && attr_name[0] == '_') continue;
      const AttrValue& attr_value = attr.second;
      const OpDef::AttrDef* attr_def = FindAttr(attr_name, op_def);
      if (!attr_def) {
        return errors::InvalidArgument("AttrDef not found for attribute ",
                                       attr_name, " of node ", node.name());
      }
      if (attr_def->type() == "type") {
        type2io_entry[TypeAttrId(attr_name)];
      } else if (attr_def->type() == "list(type)") {
        for (int i = 0; i < attr_value.list().type_size(); ++i) {
          type2io_entry[TypeAttrId(attr_name, i)];
        }
      }
    }
    return Status::OK();
  }

  // WARN: `graph_` must outlive this object (node pointers must remain valid).
  const GraphDef* graph_ = nullptr;  // do not own
  std::unique_ptr<FunctionLibraryDefinition> function_library_;

  typedef absl::flat_hash_set<int> IntSet;
  // Maps a type attr id -> (input port set, output port set)
  typedef absl::flat_hash_map<TypeAttrId, std::pair<IntSet, IntSet>> Type2IOMap;
  // Maps a node -> type attr mapping
  absl::flat_hash_map<const NodeDef*, Type2IOMap> type2io_;
  // Maps a port -> type attr id
  typedef std::vector<TypeAttrId> TypeAttrIdVec;
  // Maps a node -> (input port mapping, output port mapping)
  absl::flat_hash_map<const NodeDef*, std::pair<TypeAttrIdVec, TypeAttrIdVec>>
      io2type_;
};
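
// Minimal usage sketch for NodeTypeAttrMap (illustrative):
//   NodeTypeAttrMap map;
//   TF_RETURN_IF_ERROR(map.Init(graph));
//   for (const TypeAttrId& ta : map.GetTypeAttrs(node)) {
//     // Ports whose dtype is governed by this type attr:
//     const auto& in_ports = map.GetInputPorts(node, ta);
//     const auto& out_ports = map.GetOutputPorts(node, ta);
//   }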

struct NodeTypeId {
  NodeTypeId(const NodeDef* _node, const TypeAttrId& _type_attr)
      : node(_node), type_attr(_type_attr) {}

  const NodeDef* node;
  TypeAttrId type_attr;

  bool operator==(const NodeTypeId& other) const {
    return node == other.node && type_attr == other.type_attr;
  }

  template <typename H>
  friend H AbslHashValue(H h, const NodeTypeId& nt) {
    return H::combine(std::move(h), nt.node, nt.type_attr);
  }
};

struct NodeTypeIdEdge {
  NodeTypeIdEdge(const NodeTypeId& _src, const NodeTypeId& _dst)
      : src(_src), dst(_dst) {}
  NodeTypeId src;
  NodeTypeId dst;
};

// TODO(benbarsdell): Investigate whether the existing GraphTopologyView can be
// used instead of this modified version.
// This is just like GraphTopologyView but with (NodeDef, TypeAttrId) pairs as
// the vertices instead of just NodeDef.
// For example, if node A has output A:0 with TypeAttrId 'T', and node B has
// input B:0 with TypeAttrId 'U', and input B:0 connects to output A:0, there
// will be an edge from (A, T) to (B, U).
class GraphTypeTopologyView {
 public:
  GraphTypeTopologyView() = default;
  explicit GraphTypeTopologyView(bool skip_invalid_edges)
      : skip_invalid_edges_(skip_invalid_edges) {}

  // Initialize graph topology view from the graph. It's possible to pass
  // additional edges that do not exist in a graph, but must be respected when
  // computing graph topology. Example: the TensorFlow runtime allows concurrent
  // execution of dequeue/enqueue ops from the same queue resource, but we might
  // want to enforce ordering between them for the purpose of graph analysis.
  Status InitializeFromGraph(const GraphDef& graph,
                             const NodeTypeAttrMap& node_type_map,
                             absl::Span<const NodeTypeIdEdge> ephemeral_edges);
  Status InitializeFromGraph(const GraphDef& graph,
                             const NodeTypeAttrMap& node_type_map);

  bool is_initialized() const { return graph_ != nullptr; }
  int num_nodes() const { return num_nodes_; }
  const GraphDef* graph() const { return graph_; }

  // Returns true iff the node exists in the underlying graph.
  bool HasNode(absl::string_view node_name, const TypeAttrId& type_attr) const;

  // Finds a node by name or returns `nullptr` if it's not in the graph.
  const NodeTypeId* GetNode(absl::string_view node_name,
                            const TypeAttrId& type_attr) const;
  // Returns a node corresponding to the given node index.
  const NodeTypeId* GetNode(int node_idx) const;

  // Returns a node index for the given node name, if the name exists in the
  // underlying graph. Otherwise returns empty optional.
  const absl::optional<int> GetNodeIndex(absl::string_view node_name,
                                         const TypeAttrId& type_attr) const;
  // Returns a node index for the given node, if the node belongs to the
  // underlying graph. Otherwise returns empty optional.
  const absl::optional<int> GetNodeIndex(const NodeTypeId& node) const;

  // Returns all the node indexes that are in the direct fanin of the given
  // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
  const absl::InlinedVector<int, 4>& GetFanin(int node_idx) const;
  // Returns all the node indexes that are in the direct fanout of the given
  // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
  const absl::InlinedVector<int, 2>& GetFanout(int node_idx) const;

 private:
  // The key type used to uniquely identify a type attribute on a node.
  struct NodeTypeKey : public std::pair<absl::string_view, TypeAttrId> {
    typedef std::pair<absl::string_view, TypeAttrId> Base;

    // Inherit the set of constructors.
    using Base::pair;

    template <typename H>
    friend H AbslHashValue(H h, const NodeTypeKey& nt) {
      return H::combine(std::move(h), nt.first, nt.second);
    }
  };

  // If true, all invalid edges and inputs (src, dst, or input node not found
  // in the graph) will be skipped; otherwise initialization will fail with an
  // error.
  bool skip_invalid_edges_ = false;

  // WARN: `graph_` must outlive this object and graph nodes must not be
  // destructed, because node names are captured with absl::string_view.
  const GraphDef* graph_ = nullptr;  // do not own
  int num_nodes_ = 0;
  std::vector<NodeTypeId> node_type_attrs_;
  absl::flat_hash_map<absl::string_view, int> node_name_to_index_;
  absl::flat_hash_map<NodeTypeKey, int> node_type_name_to_index_;

  std::vector<absl::InlinedVector<int, 4>> fanins_;
  std::vector<absl::InlinedVector<int, 2>> fanouts_;

  // We need a valid reference to return from GetFanin/GetFanout if the
  // `node_idx` argument is outside of the [0, num_nodes_) range.
  absl::InlinedVector<int, 4> empty_fanin_;
  absl::InlinedVector<int, 2> empty_fanout_;
};

template <typename T>
inline void SortAndRemoveDuplicates(T* v) {
  std::sort(v->begin(), v->end());
  v->erase(std::unique(v->begin(), v->end()), v->end());
}

Status GraphTypeTopologyView::InitializeFromGraph(
    const GraphDef& graph, const NodeTypeAttrMap& node_type_map,
    absl::Span<const NodeTypeIdEdge> ephemeral_edges) {
  if (graph_ != nullptr) {
    return errors::InvalidArgument(
        "GraphTypeTopologyView is already initialized.");
  }

  graph_ = &graph;
  int num_nodedefs = graph.node_size();
  node_name_to_index_.rehash(num_nodedefs);

  // Build maps from name to index.
  node_type_attrs_.reserve(num_nodedefs);         // Only approximate.
  node_type_name_to_index_.rehash(num_nodedefs);  // Only approximate.
  for (int node_idx = 0; node_idx < num_nodedefs; ++node_idx) {
    const NodeDef& node = graph.node(node_idx);
    node_name_to_index_.emplace(node.name(), node_idx);

    for (const TypeAttrId& type_attr : node_type_map.GetTypeAttrs(node)) {
      int node_type_idx = node_type_attrs_.size();
      node_type_name_to_index_.emplace(NodeTypeKey(node.name(), type_attr),
                                       node_type_idx);
      node_type_attrs_.emplace_back(&node, type_attr);
    }
  }
  num_nodes_ = node_type_attrs_.size();
  fanins_.resize(num_nodes_);
  fanouts_.resize(num_nodes_);

  // 1. Add ephemeral edges to the adjacency lists.
  for (const NodeTypeIdEdge& edge : ephemeral_edges) {
    const auto src = node_name_to_index_.find(edge.src.node->name());
    const bool valid_src = src != node_name_to_index_.end();

    if (!valid_src) {
      const string error_message =
          absl::StrCat("Non-existent src node: ", edge.src.node->name());
      if (skip_invalid_edges_) {
        VLOG(0) << "Skip error: " << error_message;
      } else {
        return errors::InvalidArgument(error_message);
      }
    }

    const auto dst = node_name_to_index_.find(edge.dst.node->name());
    const bool valid_dst = dst != node_name_to_index_.end();

    if (!valid_dst) {
      const string error_message =
          absl::StrCat("Non-existent dst node: ", edge.dst.node->name());
      if (skip_invalid_edges_) {
        VLOG(0) << "Skip error: " << error_message;
      } else {
        return errors::InvalidArgument(error_message);
      }
    }

    if (valid_dst && valid_src) {
      // TODO(benbarsdell): Check for failure.
      int src_node_type_idx = node_type_name_to_index_.at(
          NodeTypeKey(edge.src.node->name(), edge.src.type_attr));
      int dst_node_type_idx = node_type_name_to_index_.at(
          NodeTypeKey(edge.dst.node->name(), edge.dst.type_attr));
      fanins_[dst_node_type_idx].push_back(src_node_type_idx);
      fanouts_[src_node_type_idx].push_back(dst_node_type_idx);
    }
  }

  // 2. Add graph edges to the adjacency lists.
  for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
    const NodeTypeId& node_type = node_type_attrs_.at(node_type_idx);
    auto input_ports =
        node_type_map.GetInputPorts(*node_type.node, node_type.type_attr);
    fanins_[node_type_idx].reserve(input_ports.size());
    for (int port : input_ports) {
      const string& input = node_type.node->input(port);
      TensorId tensor = ParseTensorName(input);
      const auto it = node_name_to_index_.find(tensor.node());
      const bool valid_input = it != node_name_to_index_.end();

      if (!valid_input) {
        const string error_message = absl::StrCat(
            "Non-existent input ", input, " in node ", node_type.node->name());
        if (skip_invalid_edges_) {
          VLOG(3) << "Skip error: " << error_message;
        } else {
          return errors::InvalidArgument(error_message);
        }
      }

      if (valid_input) {
        const int input_idx = it->second;
        const NodeDef& input_node = graph_->node(input_idx);
        TypeAttrId input_type_attr =
            node_type_map.GetOutputTypeAttr(input_node, tensor.index());
        const auto it2 = node_type_name_to_index_.find(
            NodeTypeKey(input_node.name(), input_type_attr));
        if (it2 == node_type_name_to_index_.end()) {
          if (!skip_invalid_edges_) {
            return errors::InvalidArgument("Did not find type attr ",
                                           input_type_attr.DebugString(),
                                           " in node ", input_node.name());
          }
          continue;
        }
        int input_node_type_idx = it2->second;
        fanins_[node_type_idx].push_back(input_node_type_idx);
        fanouts_[input_node_type_idx].push_back(node_type_idx);
      }
    }

    // Dedup the input list while it's still hot in cache.
    SortAndRemoveDuplicates(&fanins_[node_type_idx]);
  }

  // Dedup outputs for all the graph nodes.
  for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
    SortAndRemoveDuplicates(&fanouts_[node_type_idx]);
  }

  return Status::OK();
}

Status GraphTypeTopologyView::InitializeFromGraph(
    const GraphDef& graph, const NodeTypeAttrMap& node_type_map) {
  return InitializeFromGraph(graph, node_type_map,
                             absl::Span<const NodeTypeIdEdge>());
}

bool GraphTypeTopologyView::HasNode(absl::string_view node_name,
                                    const TypeAttrId& type_attr) const {
  DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
  NodeTypeKey key(node_name, type_attr);
  const auto it = node_type_name_to_index_.find(key);
  return it != node_type_name_to_index_.end();
}

const NodeTypeId* GraphTypeTopologyView::GetNode(
    absl::string_view node_name, const TypeAttrId& type_attr) const {
  DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
  NodeTypeKey key(node_name, type_attr);
  const auto it = node_type_name_to_index_.find(key);
  return it == node_type_name_to_index_.end()
             ? nullptr
             : &node_type_attrs_.at(it->second);
}

const NodeTypeId* GraphTypeTopologyView::GetNode(int node_idx) const {
  DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
  DCHECK(node_idx >= 0 && node_idx < num_nodes_) << "node_idx is out of range";
  return &node_type_attrs_.at(node_idx);
}

const absl::optional<int> GraphTypeTopologyView::GetNodeIndex(
    absl::string_view node_name, const TypeAttrId& type_attr) const {
  DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
  NodeTypeKey key(node_name, type_attr);
  const auto it = node_type_name_to_index_.find(key);
  DCHECK(it != node_type_name_to_index_.end())
      << "Node doesn't exist in a graph";
  return it == node_type_name_to_index_.end() ? absl::nullopt
                                              : absl::make_optional(it->second);
}

const absl::optional<int> GraphTypeTopologyView::GetNodeIndex(
    const NodeTypeId& node) const {
  return GetNodeIndex(node.node->name(), node.type_attr);
}

const absl::InlinedVector<int, 4>& GraphTypeTopologyView::GetFanin(
    int node_idx) const {
  DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
  const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
  DCHECK(is_valid_node_idx) << "node_idx is out of range";
  return is_valid_node_idx ? fanins_[node_idx] : empty_fanin_;
}

const absl::InlinedVector<int, 2>& GraphTypeTopologyView::GetFanout(
    int node_idx) const {
  DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
  const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
  DCHECK(is_valid_node_idx) << "node_idx is out of range";
  return is_valid_node_idx ? fanouts_[node_idx] : empty_fanout_;
}

enum class TypeTraversalDirection {
  kFollowInputs,
  kFollowOutputs,
  kFollowInputsAndOutputs,
};

// Encapsulate DFS callbacks that will be called during the graph traversal.
//
// If non-empty, the `pre_order` and `post_order` functors will be called on
// each reachable node (including the `from` nodes) in pre and post order. If
// loops are found, the `on_back_edge` functor will be called on the
// corresponding back edges. Moreover, the pre and post order will assume that
// these back edges will be cut.
struct DfsTypeCallbacks {
  DfsTypeCallbacks() = default;
  DfsTypeCallbacks(std::function<void(int)> pre, std::function<void(int)> post,
                   std::function<void(int, int)> back_edge)
      : pre_order(std::move(pre)),
        post_order(std::move(post)),
        on_back_edge(std::move(back_edge)) {}

  static DfsTypeCallbacks PreOrder(std::function<void(int)> pre) {
    return DfsTypeCallbacks(std::move(pre), nullptr, nullptr);
  }

  static DfsTypeCallbacks PostOrder(std::function<void(int)> post) {
    return DfsTypeCallbacks(nullptr, std::move(post), nullptr);
  }

  std::function<void(int)> pre_order;
  std::function<void(int)> post_order;
  std::function<void(int, int)> on_back_edge;
};

// Encapsulate DFS predicates for traversing the graph.
//
// The `enter` predicate decides if traversal should enter the node, and the
// `advance` predicate decides if the traversal should follow inputs/outputs
// from the node.
//
// If predicates are empty (default initialized), it's assumed that we can enter
// into any node and advance from any node respectively.
struct DfsTypePredicates {
  DfsTypePredicates() = default;
  DfsTypePredicates(std::function<bool(int)> enter,
                    std::function<bool(int)> advance)
      : enter(std::move(enter)), advance(std::move(advance)) {}

  static DfsTypePredicates Enter(std::function<bool(int)> enter) {
    return DfsTypePredicates(std::move(enter), nullptr);
  }

  static DfsTypePredicates Advance(std::function<bool(int)> advance) {
    return DfsTypePredicates(nullptr, std::move(advance));
  }

  std::function<bool(int)> enter;
  std::function<bool(int)> advance;
};

struct DfsStackElem {
  DfsStackElem(int node, bool children_visited, int src)
      : node(node), children_visited(children_visited), src(src) {}
  explicit DfsStackElem(int node) : DfsStackElem(node, false, -1) {}

  // Index of the node in the graph ∊ [0, num_nodes).
  int node;
  // `True` if visited all the input/output nodes (pushed all input/output nodes
  // to the stack).
  bool children_visited;
  // Index of the node in the graph, from which we entered the `node`.
  int src;
};

enum class NodeState { kNotVisited, kVisiting, kDone };

void DfsTypeTraversal(const GraphTypeTopologyView& graph_type_view,
                      const absl::Span<const NodeTypeId* const> from,
                      const TypeTraversalDirection direction,
                      const DfsTypePredicates& predicates,
                      const DfsTypeCallbacks& callbacks) {
  std::vector<DfsStackElem> stack;
  stack.reserve(from.size());

  for (const NodeTypeId* node : from) {
    const absl::optional<int> node_idx = graph_type_view.GetNodeIndex(*node);
    DCHECK(node_idx.has_value())
        << "Illegal start node: " << node->node->name();
    if (node_idx.has_value()) {
      stack.emplace_back(node_idx.value());
    }
  }

  absl::flat_hash_map<int, NodeState> node_state;
  while (!stack.empty()) {
    DfsStackElem w = stack.back();
    stack.pop_back();

    NodeState& state = node_state[w.node];
    if (state == NodeState::kDone) continue;

    // Skip nodes that we should not enter.
    if (predicates.enter && !predicates.enter(w.node)) {
      state = NodeState::kDone;
      continue;
    }

    // We've processed all the children of this node.
    if (w.children_visited) {
      state = NodeState::kDone;
      if (callbacks.post_order) {
        callbacks.post_order(w.node);
      }
      continue;
    }

    // Loop detected.
    if (state == NodeState::kVisiting) {
      if (callbacks.on_back_edge) {
        callbacks.on_back_edge(w.src, w.node);
      }
      continue;
    }

    state = NodeState::kVisiting;
    if (callbacks.pre_order) {
      callbacks.pre_order(w.node);
    }

    // Enqueue the node again with the children_visited flag set to true.
    stack.emplace_back(w.node, true, w.src);

    // Check if we can continue traversal from the current node.
    if (predicates.advance && !predicates.advance(w.node)) {
      continue;
    }

    // Now enqueue the fanin/fanout nodes.
    if (direction == TypeTraversalDirection::kFollowInputs ||
        direction == TypeTraversalDirection::kFollowInputsAndOutputs) {
      for (const int fanin : graph_type_view.GetFanin(w.node)) {
        stack.emplace_back(fanin, false, w.node);
      }
    }
    if (direction == TypeTraversalDirection::kFollowOutputs ||
        direction == TypeTraversalDirection::kFollowInputsAndOutputs) {
      for (const int fanout : graph_type_view.GetFanout(w.node)) {
        stack.emplace_back(fanout, false, w.node);
      }
    }
  }
}
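
// Illustrative call, assuming `graph_type_view` is initialized, `roots` holds
// NodeTypeId pointers, and `black_set` is an absl::flat_hash_set<int>: collect
// the reachable node-type indices in post-order, never entering black nodes.
//   std::vector<int> order;
//   DfsTypeTraversal(graph_type_view, roots,
//                    TypeTraversalDirection::kFollowOutputs,
//                    DfsTypePredicates::Enter(
//                        [&](int idx) { return !black_set.count(idx); }),
//                    DfsTypeCallbacks::PostOrder(
//                        [&](int idx) { order.push_back(idx); }));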

DataTypeSet AllowedDataTypes(const OpDef::AttrDef& attr_def) {
  const auto& allowed_types = attr_def.allowed_values().list().type();
  if (allowed_types.empty()) {
    return AllTypes();
  }
  uint32 dtype_mask = 0;
  for (int dtype : allowed_types) {
    dtype_mask |= 1u << dtype;
  }
  return DataTypeSet(dtype_mask);
}

DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) {
  if (t_attr_id.attr_name.empty()) {
    return ToSet(t_attr_id.fixed_type);
  }
  const OpDef::AttrDef* attr_def = FindAttr(t_attr_id.attr_name, op_def);
  CHECK(attr_def);  // Crash Ok
  return AllowedDataTypes(*attr_def);
}

NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, bool to_fp16,
                      const string& device) {
  const char* cast_string = to_fp16 ? kCastToFp16 : kCastToFp32;
  string name = strings::StrCat(src.node->name(), "-", src.port_id, "-",
                                cast_string, "-", kSuffix);
  NodeDef node;
  node.set_name(name);
  node.set_op("Cast");
  node.set_device(device);
  node.add_input(strings::StrCat(src.node->name(), ":", src.port_id));
  (*node.mutable_attr())["SrcT"].set_type(to_fp16 ? DT_FLOAT : DT_HALF);
  (*node.mutable_attr())["DstT"].set_type(to_fp16 ? DT_HALF : DT_FLOAT);
  (*node.mutable_attr())["Truncate"].set_b(false);
  return node;
}
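
// For example, a to_fp16 cast inserted after output 0 of a (hypothetical) node
// named "conv1/Conv2D" would be named, per the StrCat above:
//   conv1/Conv2D-0-CastToFp16-AutoMixedPrecision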

Status ValidateLists(const gtl::FlatSet<string>& white_list,
                     const gtl::FlatSet<string>& black_list,
                     const gtl::FlatSet<string>& gray_list,
                     const gtl::FlatSet<string>& clear_list) {
  std::vector<gtl::FlatSet<string>> lists{white_list, black_list, gray_list,
                                          clear_list};
  std::multiset<string> counts;
  for (const auto& list : lists) {
    counts.insert(list.begin(), list.end());
  }
  bool duplicates = false;
  for (const auto& s : counts) {
    if (counts.count(s) > 1) {
      duplicates = true;
      LOG(ERROR) << "Op present in multiple lists: " << s;
    }
  }
  if (duplicates) {
    return errors::InvalidArgument("Op lists have conflicting entries");
  } else {
    return Status::OK();
  }
}

bool HasInputOrOutputRefs(const NodeDef& node) {
  const OpDef* op_def;
  Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
  if (!status.ok()) {
    return true;
  }
  for (const auto& input : op_def->input_arg()) {
    if (input.is_ref()) {
      return true;
    }
  }
  for (const auto& output : op_def->output_arg()) {
    if (output.is_ref()) {
      return true;
    }
  }
  return false;
}

// See TF issue 25977 for why fp16 must never be forced on
// SoftmaxCrossEntropyWithLogits (SCEWL).
bool CanForceFP16(const NodeDef& node) {
  return node.op() != "Const" && node.op() != "SoftmaxCrossEntropyWithLogits" &&
         !IsStateful(node) && !HasInputOrOutputRefs(node);
}

int GetCudaVersion(const Cluster& cluster) {
  auto devices = cluster.GetDevices();
  for (const auto& device : devices) {
    const DeviceProperties& device_properties = device.second;
    if (device_properties.type() == "GPU") {
      const auto& device_env = device_properties.environment();
      auto it = device_env.find("cuda");
      if (it != device_env.end()) {
        string cuda_version_str = it->second;
        return std::stoi(cuda_version_str);
      }
    }
  }
  return 0;
}

int GetCudnnVersion(const Cluster& cluster) {
  auto devices = cluster.GetDevices();
  for (const auto& device : devices) {
    const DeviceProperties& device_properties = device.second;
    if (device_properties.type() == "GPU") {
      const auto& device_env = device_properties.environment();
      auto it = device_env.find("cudnn");
      if (it != device_env.end()) {
        string cudnn_version_str = it->second;
        return std::stoi(cudnn_version_str);
      }
    }
  }
  return 0;
}

class AutoMixedPrecisionImpl {
 public:
  AutoMixedPrecisionImpl(Cluster* cluster,
                         const std::unordered_set<string>& nodes_to_preserve,
                         GraphDef* graph, string id)
      : virtual_placer_(cluster->GetDevices()),
        nodes_to_preserve_(nodes_to_preserve),
        graph_(graph),
        id_(id),
        graph_view_(graph),
        cuda_version_(GetCudaVersion(*cluster)),
        cudnn_version_(GetCudnnVersion(*cluster)) {}

  Status Optimize();

 private:
  typedef absl::flat_hash_set<NodeTypeId> NodeTypeIdSet;
  // Maps data structure object ops (e.g., StackV2) to the sets of nodes that
  // write (e.g., StackPushV2) and read (e.g., StackPopV2) from them.
  typedef absl::flat_hash_map<NodeTypeId,
                              std::pair<NodeTypeIdSet, NodeTypeIdSet>>
      DataStructureOpsMap;

  Status PrintDebugLogs(bool preop, size_t timestamp);
  void LogSkippedNode(const NodeDef& node) const;
  bool MustPreserve(const NodeDef& node) const;
  bool IsOnGPU(const NodeDef& node) const;
  bool IsOnSuitableGPUArch(const NodeDef& node) const;
  bool ShouldProcess(const NodeDef& node) const;
  bool NodeHasFP16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const;
  bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const;
  void ConvertBatchNormOpsToV2();
  bool SupportsFloat16(const NodeTypeId& node_type) const;
  const NodeDef* GetTailOfChain(
      const NodeDef& node, const absl::flat_hash_set<string>& match_ops) const;
  Status AddDataStructureOpsToMap(
      const absl::flat_hash_set<string>& data_structure_ops,
      TypeAttrId data_structure_type_attr,
      const absl::flat_hash_map<string, TypeAttrId>& write_ops,
      const absl::flat_hash_map<string, TypeAttrId>& read_ops,
      DataStructureOpsMap* object_clients_map) const;
  void AddWhitelistOps(absl::flat_hash_set<int>* white_set) const;
  void PropagateBlackFwdThroughClearAndGray(
      absl::flat_hash_set<int>* black_set) const;
  void ForceColorMatchBetweenDataStructureOps(
      const DataStructureOpsMap& object_clients_map,
      absl::flat_hash_set<int>* white_set,
      absl::flat_hash_set<int>* black_set) const;
  void AddClearAndGrayToWhiteIfBetweenWhite(
      const absl::flat_hash_set<int>& black_set,
      absl::flat_hash_set<int>* white_set) const;
  void PropagateWhiteThroughClear(const absl::flat_hash_set<int>& black_set,
                                  absl::flat_hash_set<int>* white_set) const;
  Status ForceColorMatchOnRecurrentEdges(
      absl::flat_hash_set<int>* white_set) const;
  void MakeCastsWhiteIfAllOutputsWhite(
      absl::flat_hash_set<int>* white_set) const;
  Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& white_set);

  VirtualPlacer virtual_placer_;
  std::unordered_set<string> nodes_to_preserve_;
  GraphDef* graph_;
  string id_;
  MutableGraphView graph_view_;
  int cuda_version_;
  int cudnn_version_;
  NodeTypeAttrMap node_type_map_;
  GraphTypeTopologyView graph_type_view_;
  bool force_all_fp16_;
  gtl::FlatSet<string> fp16_whitelist_;
  gtl::FlatSet<string> fp16_blacklist_;
  gtl::FlatSet<string> fp16_graylist_;
  gtl::FlatSet<string> fp16_clearlist_;
  absl::flat_hash_set<const NodeDef*> should_process_nodes_;
};

bool AutoMixedPrecisionImpl::NodeHasFP16KernelForTypeAttr(
    const NodeDef& node, TypeAttrId taid) const {
  NodeDef node_copy(node);
  if (node.device().empty()) {
    string device_name = virtual_placer_.get_canonical_device_name(node);
    node_copy.set_device(device_name);
  }
  if (!SetDataType(&node_copy, taid, DataType::DT_HALF)) {
    return false;
  }
  return IsKernelRegisteredForNode(node_copy).ok();
}

Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) {
  string prepend_path;
  TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
      "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LOG_PATH", "", &prepend_path));
  if (prepend_path.empty()) return Status::OK();

  string suffix =
      strings::StrCat("_", preop ? "preop" : kSuffix, "_", id_, "_", timestamp);

  string fname =
      io::JoinPath(prepend_path, strings::StrCat("graphdef", suffix, ".pb"));
  std::fstream f;
  f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
  f << graph_->SerializeAsString();
  f.close();
  LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization")
            << " graph as binary to " << fname;

  fname = io::JoinPath(prepend_path,
                       strings::StrCat("graphdef", suffix, ".pb.txt"));
  f.open(fname.c_str(), std::fstream::out);
  f << graph_->DebugString();
  f.close();
  LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization")
            << " graph as text to " << fname;

  if (!preop) {
    fname = io::JoinPath(prepend_path,
                         strings::StrCat("paintbuckets", suffix, ".txt"));
    f.open(fname.c_str(), std::fstream::out);
    f << "WhiteList:\n";
    for (auto x :
         AutoMixedPrecisionLists::WhiteList(cuda_version_, cudnn_version_)) {
      f << x << "\n";
    }
    f << "\nBlackList:\n";
    for (auto x : AutoMixedPrecisionLists::BlackList()) {
      f << x << "\n";
    }
    f << "\nGrayList:\n";
    for (auto x : AutoMixedPrecisionLists::GrayList()) {
      f << x << "\n";
    }
    f << "\nClearList:\n";
    for (auto x : AutoMixedPrecisionLists::ClearList()) {
      f << x << "\n";
    }
    f.close();
    LOG(INFO) << "Saved paint bucket info to " << fname;
  }
  return Status::OK();
}
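
// To enable these debug dumps, point the env var at a writable directory
// before running, e.g. (directory name is illustrative):
//   export TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LOG_PATH=/tmp/amp_logs
// which, per the code above, produces graphdef*.pb, graphdef*.pb.txt, and
// paintbuckets*.txt files there.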

void AutoMixedPrecisionImpl::LogSkippedNode(const NodeDef& node) const {
  VLOG(2) << "Skipping " << node.op() << " node " << node.name()
          << " because it "
          << (MustPreserve(node)
                  ? "must be preserved"
                  : "is not on the GPU, or the GPU arch is not suitable");
}

bool AutoMixedPrecisionImpl::MustPreserve(const NodeDef& node) const {
  return nodes_to_preserve_.count(node.name());
}

bool AutoMixedPrecisionImpl::IsOnGPU(const NodeDef& node) const {
  string device_name;
  if (node.device().empty()) {
    device_name = virtual_placer_.get_canonical_device_name(node);
  } else {
    device_name = node.device();
  }
  string device;
  string not_used;
  if (DeviceNameUtils::SplitDeviceName(device_name, &not_used, &device) &&
      absl::StrContains(absl::AsciiStrToLower(device),
                        absl::AsciiStrToLower(DEVICE_GPU))) {
    return true;
  }
  return false;
}

// Returns the GPU architecture (compute capability) as a (major, minor) pair.
std::pair<int, int> GetDeviceGPUArch(
    const DeviceProperties& device_properties) {
  if (device_properties.type() != "GPU") return {0, 0};
  string arch_str = device_properties.environment().at("architecture");
  std::vector<string> split_arch_str = str_util::Split(arch_str, '.');
  if (split_arch_str.empty()) {
    return {0, 0};
  }

  int major, minor;
  if (!strings::safe_strto32(split_arch_str[0], &major)) {
    return {0, 0};
  }

  if (split_arch_str.size() > 1) {
    if (strings::safe_strto32(split_arch_str[1], &minor)) {
      return {major, minor};
    } else {
      return {0, 0};
    }
  } else {
    return {major, 0};
  }
}
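
// For example, an "architecture" value of "7.0" (Volta) parses to {7, 0},
// which satisfies kMinGPUArch, while "6.1" parses to {6, 1} and does not.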

bool AutoMixedPrecisionImpl::IsOnSuitableGPUArch(const NodeDef& node) const {
  return GetDeviceGPUArch(virtual_placer_.get_device(node)) >= kMinGPUArch;
}

bool AutoMixedPrecisionImpl::ShouldProcess(const NodeDef& node) const {
  return should_process_nodes_.count(&node);
}

bool IsFloat32(const NodeTypeId& node_type) {
  return GetDataType(*node_type.node, node_type.type_attr) ==
         DataType::DT_FLOAT;
}

bool AutoMixedPrecisionImpl::SupportsFloat16(
    const NodeTypeId& node_type) const {
  const OpDef* op_def;
  Status status =
      OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
  if (!status.ok()) return false;
  return AllowedDataTypes(*op_def, node_type.type_attr)
             .Contains(DataType::DT_HALF) &&
         NodeHasFP16KernelForTypeAttr(*node_type.node, node_type.type_attr);
}

// TODO(mconley): Make this change the node's name (to aid debugging). Need to
// make sure that doing this won't break anything.
void AutoMixedPrecisionImpl::ConvertBatchNormOpsToV2() {
  for (int node_idx = 0; node_idx < graph_->node_size(); ++node_idx) {
    NodeDef* node = graph_->mutable_node(node_idx);
    if (!ShouldProcess(*node)) continue;
    bool changed = false;
    if (node->op() == "FusedBatchNorm") {
      VLOG(2) << "Changing op of " << node->op() << " node " << node->name()
              << " to FusedBatchNormV2";
      node->set_op("FusedBatchNormV2");
      changed = true;
    } else if (node->op() == "FusedBatchNormGrad") {
      VLOG(2) << "Changing op of " << node->op() << " node " << node->name()
              << " to FusedBatchNormGradV2";
      node->set_op("FusedBatchNormGradV2");
      changed = true;
    }
    if (changed) {
      (*node->mutable_attr())["U"].set_type(DT_FLOAT);
    }
  }
}

// A helper function to decide whether to ignore the effect on performance when
// rewriting the graph. This can be useful for testing the numerical effects of
// reduced precision on systems that have poor mixed precision performance.
bool ShouldIgnorePerformance() {
  static bool is_enabled = [] {
    bool ret = false;
    TF_CHECK_OK(ReadBoolFromEnvVar(
        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE",
        /*default_val=*/false, &ret));
    return ret;
  }();
  return is_enabled;
}
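
// For example, setting
//   export TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE=1
// makes the pass rewrite nodes even on GPUs below kMinGPUArch (see the
// ShouldIgnorePerformance() || IsOnSuitableGPUArch() check in Optimize()
// below).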
1191 
Optimize()1192 Status AutoMixedPrecisionImpl::Optimize() {
1193   string optimization_level;
1194   TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
1195       "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level));
1196   optimization_level = absl::AsciiStrToUpper(optimization_level);
1197   force_all_fp16_ = optimization_level == "UNSAFE_FORCE_ALL";

  fp16_whitelist_ =
      AutoMixedPrecisionLists::WhiteList(cuda_version_, cudnn_version_);
  fp16_blacklist_ = AutoMixedPrecisionLists::BlackList();
  fp16_graylist_ = AutoMixedPrecisionLists::GrayList();
  fp16_clearlist_ = AutoMixedPrecisionLists::ClearList();
  TF_RETURN_IF_ERROR(ValidateLists(fp16_whitelist_, fp16_blacklist_,
                                   fp16_graylist_, fp16_clearlist_));

  size_t timestamp = Env::Default()->NowMicros() / 1000;
  TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp));

  VLOG(2) << "Identifying nodes that should be processed";
  for (const NodeDef& node : graph_->node()) {
    if (!MustPreserve(node) && IsOnGPU(node) &&
        (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node))) {
      should_process_nodes_.insert(&node);
    } else {
      LogSkippedNode(node);
    }
  }

  VLOG(2) << "Converting FusedBatchNorm* ops to V2";
  ConvertBatchNormOpsToV2();

  VLOG(2) << "Building node type map for graph";
  TF_RETURN_IF_ERROR(node_type_map_.Init(*graph_));

  // Note: If an op is added to this list that has a data type attribute, it
  // should also be added to the AddDataStructureOpsToMap call below (and to the
  // clearlist if it involves data flow).
  // TODO(benbarsdell): Add support for TensorListPushBackBatch and
  // TensorListConcatLists. They require special handling because they connect
  // multiple list objects together. Currently if they appear in the graph then
  // we have no choice but to disallow changing any tensor list ops, as
  // otherwise we risk breaking the graph if some are changed and some are not
  // (within a connected cluster of tensor list nodes).
  const gtl::FlatSet<string> supported_list_ops = {
      "EmptyTensorList",
      "TensorListSplit",
      "TensorListFromTensor",
      "TensorListReserve",
      "TensorListScatter",
      "TensorListScatterV2",
      "TensorListPushBack",
      "TensorListSetItem",
      "TensorListScatterIntoExistingList",
      "TensorListPopBack",
      "TensorListStack",
      "TensorListConcat",
      "TensorListConcatV2",
      "TensorListGetItem",
      "TensorListGather",
      "TensorListLength",
      "TensorListElementShape",
      "TensorListResize"};

  bool can_change_tensor_list_ops = true;
  for (const NodeDef& node : graph_->node()) {
    if (absl::StartsWith(node.op(), "TensorList") &&
        !supported_list_ops.count(node.op())) {
      LOG(WARNING) << "Unsupported " << node.op() << " node found in graph ("
                   << node.name()
                   << "), tensor list ops will not be converted.";
      can_change_tensor_list_ops = false;
      break;
    }
  }

  DataStructureOpsMap object_clients_map;
  if (can_change_tensor_list_ops) {
    VLOG(2) << "Identifying TensorList* nodes";
    TF_RETURN_IF_ERROR(AddDataStructureOpsToMap(
        {"EmptyTensorList", "TensorListSplit", "TensorListFromTensor",
         "TensorListReserve", "TensorListScatter", "TensorListScatterV2"},
        TypeAttrId("element_dtype"),
        {{"TensorListPushBack", TypeAttrId("element_dtype")},
         {"TensorListSetItem", TypeAttrId("element_dtype")},
         {"TensorListScatterIntoExistingList", TypeAttrId("element_dtype")}},
        {{"TensorListPopBack", TypeAttrId("element_dtype")},
         {"TensorListStack", TypeAttrId("element_dtype")},
         {"TensorListConcat", TypeAttrId("element_dtype")},
         {"TensorListConcatV2", TypeAttrId("element_dtype")},
         {"TensorListGetItem", TypeAttrId("element_dtype")},
         {"TensorListGather", TypeAttrId("element_dtype")}},
        &object_clients_map));
  } else {
    for (const string& list_op : supported_list_ops) {
      fp16_whitelist_.erase(list_op);
      fp16_graylist_.erase(list_op);
      fp16_clearlist_.erase(list_op);
    }
  }

  // Create ephemeral edges between writers and readers of data structure ops.
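  // For example (illustrative): if a TensorListPushBack node writes into a
  // list that a TensorListStack node later reads, an ephemeral PushBack ->
  // Stack edge is added so that the painting passes below treat the writer
  // and the reader as directly connected.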
  std::vector<NodeTypeIdEdge> ephemeral_edges;
  for (const auto& object_clients : object_clients_map) {
    const auto& client_nodes = object_clients.second;
    for (const NodeTypeId& write_node_type : client_nodes.first) {
      for (const NodeTypeId& read_node_type : client_nodes.second) {
        ephemeral_edges.emplace_back(write_node_type, read_node_type);
      }
    }
    const NodeTypeId& object_node_type = object_clients.first;
    // These object types also act as writers because they initialize the object
    // from an input tensor.
    if (object_node_type.node->op() == "TensorListSplit" ||
        object_node_type.node->op() == "TensorListFromTensor" ||
        object_node_type.node->op() == "TensorListScatter" ||
        object_node_type.node->op() == "TensorListScatterV2") {
      for (const NodeTypeId& read_node_type : client_nodes.second) {
        ephemeral_edges.emplace_back(object_node_type, read_node_type);
      }
    }
  }

  VLOG(2) << "Constructing graph type attribute topology view";
  TF_RETURN_IF_ERROR(graph_type_view_.InitializeFromGraph(
      *graph_, node_type_map_, ephemeral_edges));

  // The goal here is to change performance-critical ops to fp16, and to do so
  // with the minimal number of casts, subject to the constraint that the
  // model's convergence is not affected. This is achieved by first identifying
  // which nodes should be changed to fp16 and then inserting casts at the
  // boundaries between fp16/non-fp16 nodes.

  // The algorithm for deciding which nodes to change to fp16 is as follows:
  // 1) Add all performance-critical ops (aka "whitelist" ops) to the white_set.
  //    This is done under the assumption that whitelist ops are always
  //    numerically-safe in fp16 and that they are the most important ops for
  //    improving performance.
  // 2) Add nodes to the black_set iff they are numerically-dangerous (aka
  //    "blacklist" ops) or they are on a forward path from a blacklist node to
  //    a black/gray node (including the node at the end of the path) through
  //    non-numerically-dangerous ops (aka "graylist" and "clearlist" ops).
  //    This is done to prevent numerically-dangerous ops and their downstream
  //    effects from being changed to fp16, which would risk breaking the
  //    numerical accuracy of the model.
  // 3) For all remaining nodes that are not considered dangerous (graylist
  //    and clearlist ops), find those that are between (i.e., both upstream
  //    and downstream of) white nodes, and add them to the white_set.
  //    This is done to avoid unnecessary casts between whitelist ops.
  // 4) For all remaining clearlist nodes, add them to the white_set if they are
  //    connected to a node in the white_set via other clearlist nodes.
  //    This is done to increase the number of ops in the white_set without
  //    affecting numerical stability.

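  // Worked example (illustrative; assume for this sketch that Exp is on the
  // blacklist, Relu on the clearlist, MatMul on the whitelist, and Mean on
  // the graylist):
  //   Exp -> Relu -> MatMul -> Relu -> MatMul -> Mean
  // Pass 1 paints both MatMuls WHITE. Pass 2 paints Exp BLACK; black does not
  // spread further because the white MatMul blocks the clear/gray path onward
  // to a black/gray node. Pass 3 paints the second Relu WHITE since it lies
  // between white MatMuls. Pass 4 paints the first Relu WHITE since it is a
  // clear node adjacent to a white node and is not black. The final pass then
  // inserts a Cast-to-fp16 on Exp's output and a Cast-to-fp32 on the last
  // MatMul's output feeding Mean.
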
  absl::flat_hash_set<int> white_set;
  VLOG(2) << "Beginning pass 1 to add whitelist ops";
  AddWhitelistOps(&white_set);
  VLOG(2) << "Finished pass 1";

  if (white_set.empty()) {
    LOG(INFO) << "No whitelist ops found, nothing to do";
    return Status::OK();
  }

  absl::flat_hash_set<int> black_set;
  VLOG(2) << "Beginning pass 2 to propagate black forwards from blacklist ops "
             "through clear/graylist ops";
  PropagateBlackFwdThroughClearAndGray(&black_set);
  VLOG(2) << "Finished pass 2";

  VLOG(2) << "Forcing color match between data structure ops";
  ForceColorMatchBetweenDataStructureOps(object_clients_map, &white_set,
                                         &black_set);

  VLOG(2) << "Beginning pass 3 to set clear and gray nodes to white if they "
             "are between white ops";
  AddClearAndGrayToWhiteIfBetweenWhite(black_set, &white_set);
  VLOG(2) << "Finished pass 3";

  VLOG(2) << "Beginning pass 4 to propagate white from white nodes through "
             "clearlist ops";
  PropagateWhiteThroughClear(black_set, &white_set);
  VLOG(2) << "Finished pass 4";

  VLOG(2) << "Forcing color match between data structure ops";
  ForceColorMatchBetweenDataStructureOps(object_clients_map, &white_set,
                                         &black_set);

  VLOG(2) << "Forcing color match on loop edges";
  TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&white_set));

  VLOG(2) << "Finding existing casts that can be made white";
  MakeCastsWhiteIfAllOutputsWhite(&white_set);

  VLOG(2) << "Beginning final pass to change type attributes and insert Cast "
             "ops at paint boundaries";
  TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(white_set));
  VLOG(2) << "Finished final pass";

  TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ false, timestamp));

  return Status::OK();
}

// Finds data structure object ops (e.g., StackV2) and the sets of nodes that
// write to them (e.g., StackPushV2) and read from them (e.g., StackPopV2).
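// For example (illustrative): given the chain
//   EmptyTensorList -> TensorListPushBack -> TensorListPopBack
// linked through each node's input 0, the map entry is keyed on the
// EmptyTensorList node's element_dtype, with the PushBack recorded in the
// writer set and the PopBack in the reader set.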
Status AutoMixedPrecisionImpl::AddDataStructureOpsToMap(
    const absl::flat_hash_set<string>& data_structure_ops,
    TypeAttrId data_structure_type_attr,
    const absl::flat_hash_map<string, TypeAttrId>& write_ops,
    const absl::flat_hash_map<string, TypeAttrId>& read_ops,
    DataStructureOpsMap* object_clients_map) const {
  for (const NodeDef& node : graph_->node()) {
    const auto write_iter = write_ops.find(node.op());
    const auto read_iter = read_ops.find(node.op());
    bool is_writer = write_iter != write_ops.end();
    bool is_reader = read_iter != read_ops.end();
    if (is_writer || is_reader) {
      const NodeDef* object_node = GetTailOfChain(node, data_structure_ops);
      if (!object_node) {
        return errors::FailedPrecondition(
            "No data structure op found upstream of ", node.op(), " node ",
            node.name());
      }
      NodeTypeId object_node_type(object_node, data_structure_type_attr);
      TypeAttrId type_attr = is_writer ? write_iter->second : read_iter->second;
      NodeTypeId node_type(&node, type_attr);
      auto* value = &(*object_clients_map)[object_node_type];
      auto* node_set = is_writer ? &value->first : &value->second;
      node_set->insert(node_type);
    }
  }
  return Status::OK();
}

void AutoMixedPrecisionImpl::AddWhitelistOps(
    absl::flat_hash_set<int>* white_set) const {
  // Add whitelisted ops to white_set.
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (!ShouldProcess(*root.node)) continue;
    bool force_white = force_all_fp16_ && CanForceFP16(*root.node);
    if (fp16_whitelist_.count(root.node->op()) || force_white) {
      bool inserted = white_set->insert(root_idx).second;
      if (VLOG_IS_ON(2) && inserted) {
        VLOG(2) << "Painting type " << root.type_attr.DebugString()
                << " of node " << root.node->name() << " WHITE because its op "
                << root.node->op() << " is on the whitelist";
      }
    }
  }
}

// Adds nodes to black_set iff they are on the blacklist or they are on a
// forward path from a blacklist node to a black/gray node (including the node
// at the end of the path) through clear and gray nodes.
// E.g., black -> gray -> clear -> gray -> clear -> white -> gray
// becomes: black -> black -> black -> black -> clear -> white -> gray.
void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
    absl::flat_hash_set<int>* black_set) const {
  if (force_all_fp16_) return;

  // Find clear nodes that are upstream of black or gray.
  absl::flat_hash_set<int> upstream_of_black_or_gray_set;
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (!(fp16_blacklist_.count(root.node->op()) ||
          fp16_graylist_.count(root.node->op()))) {
      continue;
    }
    DfsTypeTraversal(graph_type_view_, {&root},
                     TypeTraversalDirection::kFollowInputs,
                     DfsTypePredicates::Enter([&](int idx) -> bool {
                       const NodeTypeId& item = *graph_type_view_.GetNode(idx);
                       return idx == root_idx ||
                              (!upstream_of_black_or_gray_set.count(idx) &&
                               fp16_clearlist_.count(item.node->op()));
                     }),
                     DfsTypeCallbacks::PreOrder([&](int idx) {
                       upstream_of_black_or_gray_set.insert(idx);
                     }));
  }

  // Propagate black forward through nodes in upstream_of_black_or_gray_set.
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (black_set->count(root_idx) || !fp16_blacklist_.count(root.node->op())) {
      continue;
    }
    DfsTypeTraversal(
        graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
        DfsTypePredicates::Enter([&](int idx) -> bool {
          return idx == root_idx || (!black_set->count(idx) &&
                                     upstream_of_black_or_gray_set.count(idx));
        }),
        DfsTypeCallbacks::PreOrder([&](int idx) {
          bool inserted = black_set->insert(idx).second;
          if (VLOG_IS_ON(2) && inserted) {
            const NodeTypeId& item = *graph_type_view_.GetNode(idx);
            VLOG(2) << "Painting type " << item.type_attr.DebugString()
                    << " of " << item.node->op() << " node "
                    << item.node->name() << " BLACK";
          }
        }));
  }
}

void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
    const absl::flat_hash_set<int>& black_set,
    absl::flat_hash_set<int>* white_set) const {
  // Find clear/graylist ops that are downstream of white ops.
  absl::flat_hash_set<int> downstream_of_white_set;
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (!ShouldProcess(*root.node) || !fp16_whitelist_.count(root.node->op())) {
      continue;
    }
    DfsTypeTraversal(
        graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
        DfsTypePredicates::Enter([&](int idx) -> bool {
          const NodeTypeId& item = *graph_type_view_.GetNode(idx);
          return idx == root_idx ||
                 (!downstream_of_white_set.count(idx) &&
                  !fp16_whitelist_.count(item.node->op()) &&
                  !black_set.count(idx) && ShouldProcess(*item.node) &&
                  // TODO(benbarsdell): Consider allowing propagation through
                  // ops that are already float16 in order to reduce the number
                  // of casts.
                  IsFloat32(item) && SupportsFloat16(item) &&
                  (fp16_clearlist_.count(item.node->op()) ||
                   fp16_graylist_.count(item.node->op())));
        }),
        DfsTypeCallbacks::PreOrder(
            [&](int idx) { downstream_of_white_set.insert(idx); }));
  }

  // Set nodes that are both downstream and upstream of white ops to white.
  absl::flat_hash_set<int> upstream_of_white_set;
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (!ShouldProcess(*root.node) || upstream_of_white_set.count(root_idx) ||
        !fp16_whitelist_.count(root.node->op())) {
      continue;
    }
    DfsTypeTraversal(
        graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
        DfsTypePredicates::Enter([&](int idx) -> bool {
          return idx == root_idx || (!upstream_of_white_set.count(idx) &&
                                     downstream_of_white_set.count(idx));
        }),
        DfsTypeCallbacks::PreOrder([&](int idx) {
          upstream_of_white_set.insert(idx);
          bool inserted = white_set->insert(idx).second;
          if (VLOG_IS_ON(2) && inserted) {
            const NodeTypeId& item = *graph_type_view_.GetNode(idx);
            VLOG(2) << "Painting type " << item.type_attr.DebugString()
                    << " of " << item.node->op() << " node "
                    << item.node->name() << " WHITE";
          }
        }));
  }
}

void AutoMixedPrecisionImpl::PropagateWhiteThroughClear(
    const absl::flat_hash_set<int>& black_set,
    absl::flat_hash_set<int>* white_set) const {
  // Propagate white from white nodes through clearlist ops.
  absl::flat_hash_set<int> clear_prop_set;
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (!ShouldProcess(*root.node) || clear_prop_set.count(root_idx) ||
        !white_set->count(root_idx)) {
      continue;
    }
    DfsTypeTraversal(
        graph_type_view_, {&root},
        TypeTraversalDirection::kFollowInputsAndOutputs,
        DfsTypePredicates::Enter([&](int idx) -> bool {
          const NodeTypeId& item = *graph_type_view_.GetNode(idx);
          return idx == root_idx ||
                 (!white_set->count(idx) && !black_set.count(idx) &&
                  ShouldProcess(*item.node) && IsFloat32(item) &&
                  SupportsFloat16(item) &&
                  (fp16_clearlist_.count(item.node->op())) &&
                  // We don't propagate (backwards) through nodes that read
                  // Variables because it can break the behavior of TensorBoard
                  // visualization and/or (in the case of Enter nodes) the model
                  // itself. This is only a problem for non-resource variables.
                  !NodeImplicitlyReadsNonResourceVariable(*item.node));
        }),
        DfsTypeCallbacks::PreOrder([&](int idx) {
          clear_prop_set.insert(idx);
          bool inserted = white_set->insert(idx).second;
          if (VLOG_IS_ON(2) && inserted) {
            const NodeTypeId& item = *graph_type_view_.GetNode(idx);
            VLOG(2) << "Painting type " << item.type_attr.DebugString()
                    << " of " << item.node->op() << " node "
                    << item.node->name() << " WHITE";
          }
        }));
  }
}

// Forces NextIteration nodes and their output Merge node(s) to have the same
// color. Specifically, it removes them all from white_set if any of the Merge
// nodes is not in white_set; otherwise it adds the NextIteration node to
// white_set.
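// For example (illustrative): in a while loop, each Merge node's output
// eventually feeds a NextIteration node whose back edge re-enters the same
// Merge. If the NextIteration were painted WHITE (fp16) while a Merge stayed
// fp32, the back edge would connect mismatched types, so the colors are
// forced to agree here.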
Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
    absl::flat_hash_set<int>* white_set) const {
  for (const NodeDef& node : graph_->node()) {
    if (node.op() == "NextIteration") {
      GraphView::OutputPort output_port(&node, 0);
      const auto& fanout = graph_view_.GetFanout(output_port);
      std::vector<int> merge_idxs;
      merge_idxs.reserve(fanout.size());
      bool any_merge_is_not_white = false;
      for (const auto& output : fanout) {
        const NodeDef& merge_node = *output.node;
        if (merge_node.op() != "Merge") {
          return errors::FailedPrecondition(
              "Expected Merge node after NextIteration, got ", merge_node.op());
        }
        const absl::optional<int> maybe_merge_idx =
            graph_type_view_.GetNodeIndex(merge_node.name(), TypeAttrId("T"));
        if (!maybe_merge_idx.has_value()) {
          return errors::Internal("Type attribute T of Merge node ",
                                  merge_node.name(),
                                  " not found in graph view");
        }
        int merge_idx = maybe_merge_idx.value();
        merge_idxs.push_back(merge_idx);
        any_merge_is_not_white =
            any_merge_is_not_white || !white_set->count(merge_idx);
      }
      const absl::optional<int> maybe_nextiter_idx =
          graph_type_view_.GetNodeIndex(node.name(), TypeAttrId("T"));
      if (!maybe_nextiter_idx.has_value()) {
        return errors::Internal("Type attribute T of NextIteration node ",
                                node.name(), " not found in graph view");
      }
      int nextiter_idx = maybe_nextiter_idx.value();
      if (any_merge_is_not_white) {
        for (int merge_idx : merge_idxs) {
          if (white_set->erase(merge_idx)) {
            VLOG(2) << "Painting type T of Merge node "
                    << graph_type_view_.GetNode(merge_idx)->node->name()
                    << " BLACK to match the color of its sibling Merge nodes "
                       "with common NextIteration node "
                    << node.name();
          }
        }
        if (white_set->erase(nextiter_idx)) {
          VLOG(2) << "Painting type T of NextIteration node " << node.name()
                  << " BLACK to match the color of its output Merge node(s)";
        }
      } else {
        if (white_set->insert(nextiter_idx).second) {
          VLOG(2) << "Painting type T of NextIteration node " << node.name()
                  << " WHITE to match the color of its output Merge node(s)";
        }
      }
    }
  }
  return Status::OK();
}

// Returns the last node in the simple chain starting at node and traversing
// backwards through the input(0) edge from each node until one with a matching
// op is found, or nullptr if no matching node is found.
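// For example (illustrative): calling this on a TensorListPushBack node whose
// input-0 chain is PushBack <- Identity <- EmptyTensorList, with match_ops
// containing "EmptyTensorList", returns the EmptyTensorList node.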
const NodeDef* AutoMixedPrecisionImpl::GetTailOfChain(
    const NodeDef& node, const absl::flat_hash_set<string>& match_ops) const {
  const NodeDef* node_ptr = &node;
  do {
    GraphView::InputPort node_input(node_ptr, 0);
    MutableGraphView::OutputPort prev_output =
        graph_view_.GetRegularFanin(node_input);
    node_ptr = prev_output.node;
  } while (node_ptr && !match_ops.count(node_ptr->op()));
  return node_ptr;
}

// Ensures that data structure nodes (e.g., StackV2) and all of their associated
// client nodes (e.g., StackPushV2 and StackPopV2) are in the same color set.
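// For example (illustrative): if one TensorListStack reader of a list has
// been painted BLACK, the list's object node and all of its other readers and
// writers are moved to BLACK as well, since they all share a single
// element_dtype for the underlying list.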
void AutoMixedPrecisionImpl::ForceColorMatchBetweenDataStructureOps(
    const DataStructureOpsMap& object_clients_map,
    absl::flat_hash_set<int>* white_set,
    absl::flat_hash_set<int>* black_set) const {
  for (const auto& object_clients : object_clients_map) {
    const NodeTypeId& object_node_type = object_clients.first;
    const auto& client_nodes = object_clients.second;
    NodeTypeIdSet all_client_nodes = client_nodes.first;
    all_client_nodes.insert(client_nodes.second.begin(),
                            client_nodes.second.end());
    // The object node may be considered a client too (e.g.,
    // TensorListFromTensor).
    all_client_nodes.insert(object_node_type);
    bool any_black = false;
    bool any_white = false;
    for (const NodeTypeId& node_type : all_client_nodes) {
      const absl::optional<int> maybe_node_idx =
          graph_type_view_.GetNodeIndex(node_type);
      DCHECK(maybe_node_idx.has_value())
          << "Type attribute " << node_type.type_attr.DebugString()
          << " of node " << node_type.node->name()
          << " not found in graph view";
      int node_idx = maybe_node_idx.value();
      if (black_set->count(node_idx)) {
        any_black = true;
        break;
      } else if (white_set->count(node_idx)) {
        any_white = true;
      }
    }
    if (any_black || any_white) {
      for (const NodeTypeId& node_type : all_client_nodes) {
        VLOG(2) << "Painting type " << node_type.type_attr.DebugString()
                << " of " << node_type.node->op() << " node "
                << node_type.node->name() << " "
                << (any_black ? "BLACK" : "WHITE")
                << " because at least one of its siblings is "
                << (any_black ? "BLACK" : "WHITE");
        const absl::optional<int> maybe_node_idx =
            graph_type_view_.GetNodeIndex(node_type);
        DCHECK(maybe_node_idx.has_value())
            << "Type attribute " << node_type.type_attr.DebugString()
            << " of node " << node_type.node->name()
            << " not found in graph view";
        int node_idx = maybe_node_idx.value();
        if (any_black) {
          white_set->erase(node_idx);
          black_set->insert(node_idx);
        } else {
          white_set->insert(node_idx);
        }
      }
    }
  }
}

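// Returns true iff the node is an Identity that directly reads a
// (non-resource) Variable or VariableV2 node, or an Enter node that
// recursively wraps such a read. For example (illustrative):
// Enter(Identity(VariableV2)) returns true, while an Identity fed by
// ReadVariableOp returns false, because resource variables are read
// explicitly rather than implicitly.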
bool AutoMixedPrecisionImpl::NodeImplicitlyReadsNonResourceVariable(
    const NodeDef& node) const {
  if (node.op() == "Identity" || node.op() == "Enter") {
    GraphView::InputPort node_input(&node, 0);
    MutableGraphView::OutputPort prev_output =
        graph_view_.GetRegularFanin(node_input);
    const NodeDef* input = prev_output.node;
    if (input && ((node.op() == "Identity" && (input->op() == "Variable" ||
                                               input->op() == "VariableV2")) ||
                  (node.op() == "Enter" &&
                   NodeImplicitlyReadsNonResourceVariable(*input)))) {
      return true;
    }
  }
  return false;
}

// This adds existing Cast nodes to white_set if all of their outputs are white,
// avoiding the need to add a new Cast node after an existing Cast.
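// For example (illustrative): a pre-existing Cast with DstT=DT_FLOAT whose
// consumers have all been painted WHITE is itself painted WHITE, so the final
// pass rewrites its DstT to DT_HALF instead of stacking a second fp16 Cast
// after it.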
void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite(
    absl::flat_hash_set<int>* white_set) const {
  int num_nodes_preop = graph_->node_size();
  for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
    NodeDef* node = graph_->mutable_node(node_idx);
    NodeTypeId node_type(node, TypeAttrId("DstT"));
    if (node->op() != "Cast" || !IsFloat32(node_type)) {
      continue;
    }
    bool all_fanouts_white = true;
    MutableGraphView::OutputPort src(node, 0);
    const auto& fanout = graph_view_.GetFanout(src);
    for (const MutableGraphView::InputPort& dst : fanout) {
      TypeAttrId dst_type_attr =
          node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
      const absl::optional<int> maybe_dst_type_idx =
          graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
      DCHECK(maybe_dst_type_idx.has_value())
          << "Type attribute " << dst_type_attr.DebugString() << " of node "
          << dst.node->name() << " not found in graph view";
      int dst_type_idx = maybe_dst_type_idx.value();
      bool dst_is_white = white_set->count(dst_type_idx);
      if (!dst_is_white) {
        all_fanouts_white = false;
        break;
      }
    }
    if (!fanout.empty() && all_fanouts_white) {
      const absl::optional<int> maybe_node_type_idx =
          graph_type_view_.GetNodeIndex(node_type);
      DCHECK(maybe_node_type_idx.has_value())
          << "Type attribute " << node_type.type_attr.DebugString()
          << " of node " << node_type.node->name()
          << " not found in graph view";
      int node_type_idx = maybe_node_type_idx.value();
      white_set->insert(node_type_idx);
    }
  }
}

// Changes all white-painted type attributes to DT_HALF, and inserts Cast nodes
// at node outputs for all edges that connect white-painted <->
// non-white-painted type attributes.
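// For example (illustrative): if a WHITE MatMul output feeds a non-white
// Mean, the MatMul's "T" attribute becomes DT_HALF and a single half-to-float
// Cast node is inserted on that output; every non-white consumer of the same
// output port is rewired to reuse that one Cast node.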
Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
    const absl::flat_hash_set<int>& white_set) {
  int num_nodes_changed = 0;
  int num_nonvar_casts_to_fp16 = 0;
  int num_nodes_preop = graph_->node_size();
  for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
    NodeDef* node = graph_->mutable_node(node_idx);
    for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(*node)) {
      const absl::optional<int> maybe_node_type_idx =
          graph_type_view_.GetNodeIndex(node->name(), type_attr);
      if (!maybe_node_type_idx.has_value()) {
        return errors::Internal("Type attribute ", type_attr.DebugString(),
                                " of ", node->op(), " node ", node->name(),
                                " not found in graph view");
      }
      int node_type_idx = maybe_node_type_idx.value();
      if (!IsFloat32(*graph_type_view_.GetNode(node_type_idx))) continue;
      bool src_is_white = white_set.count(node_type_idx);
      if (src_is_white) {
        VLOG(1) << "Changing type " << type_attr.DebugString() << " of "
                << node->op() << " node " << node->name() << " to DT_HALF";
        if (!SetDataType(node, type_attr, DT_HALF)) {
          return errors::Internal("Failed to set type attribute");
        }
        ++num_nodes_changed;
      }
      for (int output_port : node_type_map_.GetOutputPorts(*node, type_attr)) {
        MutableGraphView::OutputPort src(node, output_port);
        NodeDef* added_cast_node = nullptr;
        // Note: This is copied so that edges can be modified inside the loop.
        auto fanout = graph_view_.GetFanout(src);
        for (const MutableGraphView::InputPort& dst : fanout) {
          TypeAttrId dst_type_attr =
              node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
          const absl::optional<int> maybe_dst_type_idx =
              graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
          if (!maybe_dst_type_idx.has_value()) {
            return errors::Internal("Type attribute ",
                                    dst_type_attr.DebugString(), " of ",
                                    dst.node->op(), " node ", dst.node->name(),
                                    " not found in graph view");
          }
          int dst_type_idx = maybe_dst_type_idx.value();
          bool dst_is_white = white_set.count(dst_type_idx);
          if (src_is_white != dst_is_white) {
            if (!added_cast_node) {
              bool to_fp16 = dst_is_white;
              VLOG(1) << "Inserting cast to "
                      << (to_fp16 ? "DT_HALF" : "DT_FLOAT") << " at "
                      << src.node->op() << " " << src.node->name() << ":"
                      << src.port_id;
              added_cast_node = graph_view_.AddNode(
                  BuildCastNode(src, to_fp16, src.node->device()));
              if (to_fp16 && !IsConstant(*node) && !IsVariable(*node) &&
                  !NodeImplicitlyReadsNonResourceVariable(*node)) {
                ++num_nonvar_casts_to_fp16;
              }
            }
            TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort(
                dst.node->name(), dst.port_id, {added_cast_node->name(), 0}));
          }
        }
      }
    }
  }
  LOG(INFO) << "Converted " << num_nodes_changed << "/" << num_nodes_preop
            << " nodes to float16 precision using " << num_nonvar_casts_to_fp16
            << " cast(s) to float16 (excluding Const and Variable casts)";
  return Status::OK();
}

int GetNumGPUs(const Cluster& cluster,
               const std::pair<int, int>& min_arch = {0, 0}) {
  auto devices = cluster.GetDevices();
  int num_gpus = 0;
  for (const auto& device : devices) {
    const DeviceProperties& device_properties = device.second;
    std::pair<int, int> arch = GetDeviceGPUArch(device_properties);
    if (device_properties.type() == "GPU" && arch >= min_arch) {
      num_gpus++;
    }
  }
  return num_gpus;
}

}  // end namespace

Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item,
                                    GraphDef* output) {
  if (cluster == nullptr) {
    return errors::InvalidArgument("cluster == nullptr");
  }

  // Start by copying input graph to output.
  *output = item.graph;

  int num_gpus = ShouldIgnorePerformance() ? GetNumGPUs(*cluster)
                                           : GetNumGPUs(*cluster, kMinGPUArch);
  if (num_gpus < 1) {
    // AutoMixedPrecision is currently only tuned for GPU.
    LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name()
                 << " graph optimizer";
    return Status::OK();
  }

  // Optimize the output graph in-place.
  AutoMixedPrecisionImpl optimizer(cluster, item.NodesToPreserve(), output,
                                   item.id);
  if (item.id == "tf_graph") {
    LOG(INFO) << "Running " << name() << " graph optimizer";
  } else {
    VLOG(1) << "Running " << name() << " graph optimizer on " << item.id;
  }
  Status status = optimizer.Optimize();
  if (!status.ok()) {
    // Restore the original graph.
    *output = item.graph;
    LOG(WARNING) << name() << " graph optimizer FAILED: " << status.ToString();
  }
  return status;
}
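
// Typical usage (illustrative; the exact public API names are assumptions
// outside this file): this optimizer is enabled through the grappler rewriter
// config, e.g. via rewrite_options.auto_mixed_precision in a session's
// ConfigProto, or in Python via
// tf.train.experimental.enable_mixed_precision_graph_rewrite(opt).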

void AutoMixedPrecision::Feedback(Cluster* cluster, const GrapplerItem& item,
                                  const GraphDef& optimize_output,
                                  double result) {
  // Nothing to do for AutoMixedPrecision.
}

}  // end namespace grappler
}  // end namespace tensorflow