/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"

#include <fstream>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {
namespace grappler {
namespace {

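// The minimum GPU compute capability on which fp16 conversion is enabled;
// 7.0 corresponds to NVIDIA Volta.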
const std::pair<int, int> kMinGPUArch = {7, 0};

const char kSuffix[] = "AutoMixedPrecision";
const char kCastToFp16[] = "CastToFp16";
const char kCastToFp32[] = "CastToFp32";

// Instances of this class represent unique type attribute identifiers within a
// node. It handles regular type attributes, list type attributes (where
// type_index is set to the index in the type list), and fixed types.
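// For example, TypeAttrId("T") identifies the single type attribute "T",
// TypeAttrId("T", 2) identifies the third entry of a list(type) attribute
// named "T", and TypeAttrId(DT_INT32) represents a fixed (non-attribute)
// int32 type.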
struct TypeAttrId {
  static const int kSingleType = -1;

  explicit TypeAttrId(const string& _attr_name, int _type_index = kSingleType)
      : attr_name(_attr_name),
        type_index(_type_index),
        fixed_type(DT_INVALID) {}

  explicit TypeAttrId(DataType _fixed_type)
      : attr_name(), type_index(kSingleType), fixed_type(_fixed_type) {}

  bool operator==(const TypeAttrId& other) const {
    return attr_name == other.attr_name && type_index == other.type_index &&
           fixed_type == other.fixed_type;
  }

  bool operator<(const TypeAttrId& other) const {
    return std::make_tuple(attr_name, type_index, fixed_type) <
           std::make_tuple(other.attr_name, other.type_index, other.fixed_type);
  }

  template <typename H>
  friend H AbslHashValue(H h, const TypeAttrId& ta) {
    return H::combine(std::move(h), ta.attr_name, ta.type_index, ta.fixed_type);
  }

  string DebugString() const {
    if (!attr_name.empty()) {
      if (type_index == kSingleType) {
        return attr_name;
      } else {
        return strings::StrCat(attr_name, "[", type_index, "]");
      }
    } else {
      return tensorflow::DataTypeString(fixed_type);
    }
  }

  string attr_name;
  // If attr_name is a list(type), this is the index into the list. Otherwise
  // this is kSingleType.
  int type_index;
  DataType fixed_type;
};

// Returns the data type of the given type attribute, or DT_INVALID if the type
// attribute is invalid.
DataType GetDataType(const NodeDef& node, const TypeAttrId& type_attr) {
  if (type_attr.attr_name.empty()) {
    return type_attr.fixed_type;
  }
  if (!node.attr().count(type_attr.attr_name)) {
    return DT_INVALID;
  }
  const AttrValue& attr_value = node.attr().at(type_attr.attr_name);
  if (type_attr.type_index == TypeAttrId::kSingleType) {
    return attr_value.type();
  } else {
    if (type_attr.type_index < 0 ||
        type_attr.type_index >= attr_value.list().type_size()) {
      return DT_INVALID;
    }
    return attr_value.list().type(type_attr.type_index);
  }
}

// Sets the data type of the given type attribute. Returns false if the type
// attribute is invalid, otherwise true.
bool SetDataType(NodeDef* node, const TypeAttrId& type_attr, DataType type) {
  if (type_attr.attr_name.empty() || !node->attr().count(type_attr.attr_name)) {
    return false;
  }
  AttrValue& attr_value = node->mutable_attr()->at(type_attr.attr_name);
  if (type_attr.type_index == TypeAttrId::kSingleType) {
    attr_value.set_type(type);
  } else {
    if (type_attr.type_index < 0 ||
        type_attr.type_index >= attr_value.list().type_size()) {
      return false;
    }
    attr_value.mutable_list()->set_type(type_attr.type_index, type);
  }
  return true;
}

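// Returns a (arg_idx, type_idx) pair for each tensor covered by this ArgDef:
// one entry per element of the type list for type_list_attr arguments, N
// entries (each with type_idx = -1) for arguments repeated via number_attr,
// and a single {arg_idx, -1} entry otherwise.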
std::vector<std::pair<int, int>> ArgDefIndexes(const NodeDef& node, int arg_idx,
                                               const OpDef::ArgDef& arg_def) {
  std::vector<std::pair<int, int>> argdef_inds;
  if (!arg_def.type_list_attr().empty()) {
    int num_types = node.attr().at(arg_def.type_list_attr()).list().type_size();
    for (int type_idx = 0; type_idx < num_types; ++type_idx) {
      argdef_inds.push_back({arg_idx, type_idx});
    }
  } else {
    int num_repeat = 1;
    if (node.attr().count(arg_def.number_attr())) {
      num_repeat = node.attr().at(arg_def.number_attr()).i();
    }
    argdef_inds.insert(argdef_inds.end(), num_repeat, {arg_idx, -1});
  }
  return argdef_inds;
}

// Returns a pair (arg_index, type_index) for each input to the node, where
// arg_index is the index of the input_arg in op_def and type_index is the
// index of the type in type_list_attr (only defined for list arguments).
std::vector<std::pair<int, int>> InputPortArgDefIndexes(const NodeDef& node,
                                                        const OpDef& op_def) {
  std::vector<std::pair<int, int>> argdef_inds;
  argdef_inds.reserve(op_def.input_arg_size());  // Final size may differ.
  for (int arg_idx = 0; arg_idx < op_def.input_arg_size(); ++arg_idx) {
    const OpDef::ArgDef& arg_def = op_def.input_arg(arg_idx);
    auto arg_results = ArgDefIndexes(node, arg_idx, arg_def);
    argdef_inds.insert(argdef_inds.end(), arg_results.begin(),
                       arg_results.end());
  }
  return argdef_inds;
}

// Returns a pair (arg_index, type_index) for each output of the node, where
// arg_index is the index of the output_arg in op_def and type_index is the
// index of the type in type_list_attr (only defined for list arguments).
std::vector<std::pair<int, int>> OutputPortArgDefIndexes(const NodeDef& node,
                                                         const OpDef& op_def) {
  std::vector<std::pair<int, int>> argdef_inds;
  argdef_inds.reserve(op_def.output_arg_size());  // Final size may differ.
  for (int arg_idx = 0; arg_idx < op_def.output_arg_size(); ++arg_idx) {
    const OpDef::ArgDef& arg_def = op_def.output_arg(arg_idx);
    auto arg_results = ArgDefIndexes(node, arg_idx, arg_def);
    argdef_inds.insert(argdef_inds.end(), arg_results.begin(),
                       arg_results.end());
  }
  return argdef_inds;
}

TypeAttrId GetTypeAttrId(const OpDef::ArgDef& arg_def, int arg_type_index) {
  if (!arg_def.type_list_attr().empty()) {
    return TypeAttrId(arg_def.type_list_attr(), arg_type_index);
  } else if (!arg_def.type_attr().empty()) {
    return TypeAttrId(arg_def.type_attr());
  } else {
    return TypeAttrId(arg_def.type());
  }
}

std::vector<int> NonControlInputs(const NodeDef& node) {
  std::vector<int> pos;
  for (int i = 0; i < node.input_size(); i++) {
    if (!IsControlInput(node.input(i))) {
      pos.push_back(i);
    }
  }
  return pos;
}

// A utility class to lookup node type attributes and type attribute <->
// input/output port mappings.
class NodeTypeAttrMap {
 public:
  NodeTypeAttrMap() {}

  explicit NodeTypeAttrMap(const GraphDef& graph) { TF_CHECK_OK(Init(graph)); }

  Status Init(const GraphDef& graph) {
    if (graph_ != nullptr) {
      return errors::InvalidArgument("NodeTypeAttrMap is already initialized.");
    }
    graph_ = &graph;
    function_library_.reset(
        new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
    for (const NodeDef& node : graph.node()) {
      TF_RETURN_IF_ERROR(AddNode(node));
    }
    return Status::OK();
  }

  bool is_initialized() const { return graph_ != nullptr; }

  // Returns the set of all type attributes in the given node.
  absl::flat_hash_set<TypeAttrId> GetTypeAttrs(const NodeDef& node) const {
    DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
    absl::flat_hash_set<TypeAttrId> type_attrs;
    const auto iter = type2io_.find(&node);
    CHECK(iter != type2io_.end());  // Crash Ok
    for (const auto& key_value : iter->second) {
      type_attrs.insert(key_value.first);
    }
    return type_attrs;
  }

  const absl::flat_hash_set<int>& GetInputPorts(
      const NodeDef& node, const TypeAttrId& type_attr) const {
    DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
    return type2io_.at(&node).at(type_attr).first;
  }

  const absl::flat_hash_set<int>& GetOutputPorts(
      const NodeDef& node, const TypeAttrId& type_attr) const {
    DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
    return type2io_.at(&node).at(type_attr).second;
  }

  TypeAttrId GetInputTypeAttr(const NodeDef& node, int port) const {
    DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
    auto type_vec = io2type_.at(&node).first;
    CHECK_GE(port, 0);                // Crash Ok
    CHECK_LT(port, type_vec.size());  // Crash Ok
    return type_vec[port];
  }

  TypeAttrId GetOutputTypeAttr(const NodeDef& node, int port) const {
    DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
    auto type_vec = io2type_.at(&node).second;
    CHECK_GE(port, 0);                // Crash Ok
    CHECK_LT(port, type_vec.size());  // Crash Ok
    return type_vec[port];
  }

 private:
  Status AddNode(const NodeDef& node) {
    const OpDef* op_def_ptr = nullptr;
    TF_RETURN_IF_ERROR(function_library_->LookUpOpDef(node.op(), &op_def_ptr));
    const OpDef& op_def = *op_def_ptr;
    auto& type2io_entry = type2io_[&node];
    auto& io2type_entry = io2type_[&node];
    auto input_arg_inds = InputPortArgDefIndexes(node, op_def);
    if (NonControlInputs(node).size() != input_arg_inds.size()) {
      return errors::InvalidArgument(
          "Expected ", node.op(), " node ", node.name(), " to have ",
          input_arg_inds.size(), " non-control input(s), but got ",
          node.input_size());
    }
    // Note that the mappings generated here include inputs/outputs with fixed
    // types. This makes the mappings complete (all inputs and outputs are
    // included), and allows the graph rewriter to propagate black paint
    // from/through ops with fixed types.
    io2type_entry.first.reserve(input_arg_inds.size());
    for (int i = 0; i < static_cast<int>(input_arg_inds.size()); ++i) {
      const auto& arg_inds = input_arg_inds[i];
      const OpDef::ArgDef& arg_def = op_def.input_arg(arg_inds.first);
      TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second);
      if (!type_attr.attr_name.empty() &&
          !node.attr().count(type_attr.attr_name)) {
        return errors::InvalidArgument("Type attribute ", type_attr.attr_name,
                                       " is not present in node ", node.name());
      }
      type2io_entry[type_attr].first.insert(i);
      io2type_entry.first.push_back(type_attr);
    }

    auto output_arg_inds = OutputPortArgDefIndexes(node, op_def);
    io2type_entry.second.reserve(output_arg_inds.size());
    for (int i = 0; i < static_cast<int>(output_arg_inds.size()); ++i) {
      const auto& arg_inds = output_arg_inds[i];
      const OpDef::ArgDef& arg_def = op_def.output_arg(arg_inds.first);
      TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second);
      if (!type_attr.attr_name.empty() &&
          !node.attr().count(type_attr.attr_name)) {
        return errors::InvalidArgument("Type attribute ", type_attr.attr_name,
                                       " is not present in node ", node.name());
      }
      type2io_entry[type_attr].second.insert(i);
      io2type_entry.second.push_back(type_attr);
    }

    // Also ensure that type attributes that aren't associated with any inputs
    // or outputs (e.g., StackV2's elem_type) are added to the map.
    for (const auto& attr : node.attr()) {
      const string& attr_name = attr.first;
      if (!attr_name.empty() && attr_name[0] == '_') continue;
      const AttrValue& attr_value = attr.second;
      const OpDef::AttrDef* attr_def = FindAttr(attr_name, op_def);
      if (!attr_def) {
        return errors::InvalidArgument("AttrDef not found for attribute ",
                                       attr_name, " of node ", node.name());
      }
      if (attr_def->type() == "type") {
        type2io_entry[TypeAttrId(attr_name)];
      } else if (attr_def->type() == "list(type)") {
        for (int i = 0; i < attr_value.list().type_size(); ++i) {
          type2io_entry[TypeAttrId(attr_name, i)];
        }
      }
    }
    return Status::OK();
  }

  // WARN: `graph_` must outlive this object (node pointers must remain valid).
  const GraphDef* graph_ = nullptr;  // do not own
  std::unique_ptr<FunctionLibraryDefinition> function_library_;

  typedef absl::flat_hash_set<int> IntSet;
  // Maps a type attr id -> (input port set, output port set)
  typedef absl::flat_hash_map<TypeAttrId, std::pair<IntSet, IntSet>> Type2IOMap;
  // Maps a node -> type attr mapping
  absl::flat_hash_map<const NodeDef*, Type2IOMap> type2io_;
  // Maps a port -> type attr id
  typedef std::vector<TypeAttrId> TypeAttrIdVec;
  // Maps a node -> (input port mapping, output port mapping)
  absl::flat_hash_map<const NodeDef*, std::pair<TypeAttrIdVec, TypeAttrIdVec>>
      io2type_;
};

struct NodeTypeId {
  NodeTypeId(const NodeDef* _node, const TypeAttrId& _type_attr)
      : node(_node), type_attr(_type_attr) {}

  const NodeDef* node;
  TypeAttrId type_attr;

  bool operator==(const NodeTypeId& other) const {
    return node == other.node && type_attr == other.type_attr;
  }

  template <typename H>
  friend H AbslHashValue(H h, const NodeTypeId& nt) {
    return H::combine(std::move(h), nt.node, nt.type_attr);
  }
};

struct NodeTypeIdEdge {
  NodeTypeIdEdge(const NodeTypeId& _src, const NodeTypeId& _dst)
      : src(_src), dst(_dst) {}
  NodeTypeId src;
  NodeTypeId dst;
};

// TODO(benbarsdell): Investigate whether the existing GraphTopologyView can be
// used instead of this modified version.
// This is just like GraphTopologyView but with (NodeDef, TypeAttrId) pairs as
// the vertices instead of just NodeDef.
// For example, if node A has output A:0 with TypeAttrId 'T', and node B has
// input B:0 with TypeAttrId 'U', and input B:0 connects to output A:0, there
// will be an edge from (A, T) to (B, U).
class GraphTypeTopologyView {
 public:
  GraphTypeTopologyView() = default;
  explicit GraphTypeTopologyView(bool skip_invalid_edges)
      : skip_invalid_edges_(skip_invalid_edges) {}

  // Initialize graph topology view from the graph. It's possible to pass
  // additional edges that do not exist in a graph, but must be respected when
  // computing graph topology. Example: Tensorflow runtime allows concurrent
  // execution of dequeue/enqueue ops from the same queue resource, but we might
  // want to enforce ordering between them for the purpose of graph analysis.
  Status InitializeFromGraph(const GraphDef& graph,
                             const NodeTypeAttrMap& node_type_map,
                             absl::Span<const NodeTypeIdEdge> ephemeral_edges);
  Status InitializeFromGraph(const GraphDef& graph,
                             const NodeTypeAttrMap& node_type_map);

  bool is_initialized() const { return graph_ != nullptr; }
  int num_nodes() const { return num_nodes_; }
  const GraphDef* graph() const { return graph_; }

  // Returns true iff the node exists in the underlying graph.
  bool HasNode(absl::string_view node_name, const TypeAttrId& type_attr) const;

  // Finds a node by name or returns `nullptr` if it's not in the graph.
  const NodeTypeId* GetNode(absl::string_view node_name,
                            const TypeAttrId& type_attr) const;
  // Returns a node corresponding to the given node index.
  const NodeTypeId* GetNode(int node_idx) const;

  // Returns a node index for the given node name, if the name exists in the
  // underlying graph. Otherwise returns empty optional.
  const absl::optional<int> GetNodeIndex(absl::string_view node_name,
                                         const TypeAttrId& type_attr) const;
  // Returns a node index for the given node, if the node belongs to the
  // underlying graph. Otherwise returns empty optional.
  const absl::optional<int> GetNodeIndex(const NodeTypeId& node) const;

  // Returns all the node indexes that are in the direct fanin of the given
  // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
  const absl::InlinedVector<int, 4>& GetFanin(int node_idx) const;
  // Returns all the node indexes that are in the direct fanout of the given
  // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
  const absl::InlinedVector<int, 2>& GetFanout(int node_idx) const;

 private:
  // The key type used to uniquely identify a type attribute on a node.
  struct NodeTypeKey : public std::pair<absl::string_view, TypeAttrId> {
    typedef std::pair<absl::string_view, TypeAttrId> Base;

    // Inherit the set of constructors.
    using Base::pair;

    template <typename H>
    friend H AbslHashValue(H h, const NodeTypeKey& nt) {
      return H::combine(std::move(h), nt.first, nt.second);
    }
  };

  // If true, all invalid edges and inputs (src, dst, or input node not found
  // in the graph) will be skipped; otherwise initialization will fail with an
  // error.
  bool skip_invalid_edges_ = false;

  // WARN: `graph_` must outlive this object and graph nodes must not be
  // destructed, because node names are captured with absl::string_view.
  const GraphDef* graph_ = nullptr;  // do not own
  int num_nodes_ = 0;
  std::vector<NodeTypeId> node_type_attrs_;
  absl::flat_hash_map<absl::string_view, int> node_name_to_index_;
  absl::flat_hash_map<NodeTypeKey, int> node_type_name_to_index_;

  std::vector<absl::InlinedVector<int, 4>> fanins_;
  std::vector<absl::InlinedVector<int, 2>> fanouts_;

  // We need a valid reference to return from GetFanin/GetFanout if the
  // `node_idx` argument is outside of the [0, num_nodes_) range.
  absl::InlinedVector<int, 4> empty_fanin_;
  absl::InlinedVector<int, 2> empty_fanout_;
};

template <typename T>
inline void SortAndRemoveDuplicates(T* v) {
  std::sort(v->begin(), v->end());
  v->erase(std::unique(v->begin(), v->end()), v->end());
}

Status GraphTypeTopologyView::InitializeFromGraph(
    const GraphDef& graph, const NodeTypeAttrMap& node_type_map,
    absl::Span<const NodeTypeIdEdge> ephemeral_edges) {
  if (graph_ != nullptr) {
    return errors::InvalidArgument(
        "GraphTypeTopologyView is already initialized.");
  }

  graph_ = &graph;
  int num_nodedefs = graph.node_size();
  node_name_to_index_.rehash(num_nodedefs);

  // Build maps from name to index.
  node_type_attrs_.reserve(num_nodedefs);         // Only approximate.
  node_type_name_to_index_.rehash(num_nodedefs);  // Only approximate.
  for (int node_idx = 0; node_idx < num_nodedefs; ++node_idx) {
    const NodeDef& node = graph.node(node_idx);
    node_name_to_index_.emplace(node.name(), node_idx);

    for (const TypeAttrId& type_attr : node_type_map.GetTypeAttrs(node)) {
      int node_type_idx = node_type_attrs_.size();
      node_type_name_to_index_.emplace(NodeTypeKey(node.name(), type_attr),
                                       node_type_idx);
      node_type_attrs_.emplace_back(&node, type_attr);
    }
  }
  num_nodes_ = node_type_attrs_.size();
  fanins_.resize(num_nodes_);
  fanouts_.resize(num_nodes_);

  // 1. Add ephemeral edges to the adjacency lists.
  for (const NodeTypeIdEdge& edge : ephemeral_edges) {
    const auto src = node_name_to_index_.find(edge.src.node->name());
    const bool valid_src = src != node_name_to_index_.end();

    if (!valid_src) {
      const string error_message =
          absl::StrCat("Non-existent src node: ", edge.src.node->name());
      if (skip_invalid_edges_) {
        VLOG(0) << "Skip error: " << error_message;
      } else {
        return errors::InvalidArgument(error_message);
      }
    }

    const auto dst = node_name_to_index_.find(edge.dst.node->name());
    const bool valid_dst = dst != node_name_to_index_.end();

    if (!valid_dst) {
      const string error_message =
          absl::StrCat("Non-existent dst node: ", edge.dst.node->name());
      if (skip_invalid_edges_) {
        VLOG(0) << "Skip error: " << error_message;
      } else {
        return errors::InvalidArgument(error_message);
      }
    }

    if (valid_dst && valid_src) {
      // TODO(benbarsdell): Check for failure.
      int src_node_type_idx = node_type_name_to_index_.at(
          NodeTypeKey(edge.src.node->name(), edge.src.type_attr));
      int dst_node_type_idx = node_type_name_to_index_.at(
          NodeTypeKey(edge.dst.node->name(), edge.dst.type_attr));
      fanins_[dst_node_type_idx].push_back(src_node_type_idx);
      fanouts_[src_node_type_idx].push_back(dst_node_type_idx);
    }
  }

  // 2. Add graph edges to the adjacency lists.
  for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
    const NodeTypeId& node_type = node_type_attrs_.at(node_type_idx);
    auto input_ports =
        node_type_map.GetInputPorts(*node_type.node, node_type.type_attr);
    fanins_[node_type_idx].reserve(input_ports.size());
    for (int port : input_ports) {
      const string& input = node_type.node->input(port);
      TensorId tensor = ParseTensorName(input);
      const auto it = node_name_to_index_.find(tensor.node());
      const bool valid_input = it != node_name_to_index_.end();

      if (!valid_input) {
        const string error_message = absl::StrCat(
            "Non-existent input ", input, " in node ", node_type.node->name());
        if (skip_invalid_edges_) {
          VLOG(3) << "Skip error: " << error_message;
        } else {
          return errors::InvalidArgument(error_message);
        }
      }

      if (valid_input) {
        const int input_idx = it->second;
        const NodeDef& input_node = graph_->node(input_idx);
        TypeAttrId input_type_attr =
            node_type_map.GetOutputTypeAttr(input_node, tensor.index());
        const auto it2 = node_type_name_to_index_.find(
            NodeTypeKey(input_node.name(), input_type_attr));
        if (it2 == node_type_name_to_index_.end()) {
          if (!skip_invalid_edges_) {
            return errors::InvalidArgument("Did not find type attr ",
                                           input_type_attr.DebugString(),
                                           " in node ", input_node.name());
          }
          continue;
        }
        int input_node_type_idx = it2->second;
        fanins_[node_type_idx].push_back(input_node_type_idx);
        fanouts_[input_node_type_idx].push_back(node_type_idx);
      }
    }

    // Dedup the input list while it's still hot in cache.
    SortAndRemoveDuplicates(&fanins_[node_type_idx]);
  }

  // Dedup outputs for all the graph nodes.
  for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
    SortAndRemoveDuplicates(&fanouts_[node_type_idx]);
  }

  return Status::OK();
}

Status GraphTypeTopologyView::InitializeFromGraph(
    const GraphDef& graph, const NodeTypeAttrMap& node_type_map) {
  return InitializeFromGraph(graph, node_type_map,
                             absl::Span<const NodeTypeIdEdge>());
}

bool GraphTypeTopologyView::HasNode(absl::string_view node_name,
                                    const TypeAttrId& type_attr) const {
  DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
  NodeTypeKey key(node_name, type_attr);
  const auto it = node_type_name_to_index_.find(key);
  return it != node_type_name_to_index_.end();
}

const NodeTypeId* GraphTypeTopologyView::GetNode(
    absl::string_view node_name, const TypeAttrId& type_attr) const {
  DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
  NodeTypeKey key(node_name, type_attr);
  const auto it = node_type_name_to_index_.find(key);
  return it == node_type_name_to_index_.end()
             ? nullptr
             : &node_type_attrs_.at(it->second);
}

const NodeTypeId* GraphTypeTopologyView::GetNode(int node_idx) const {
  DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
  DCHECK(node_idx >= 0 && node_idx < num_nodes_) << "node_idx is out of range";
  return &node_type_attrs_.at(node_idx);
}

const absl::optional<int> GraphTypeTopologyView::GetNodeIndex(
    absl::string_view node_name, const TypeAttrId& type_attr) const {
  DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
  NodeTypeKey key(node_name, type_attr);
  const auto it = node_type_name_to_index_.find(key);
  DCHECK(it != node_type_name_to_index_.end())
      << "Node doesn't exist in a graph";
  return it == node_type_name_to_index_.end() ? absl::nullopt
                                              : absl::make_optional(it->second);
}

const absl::optional<int> GraphTypeTopologyView::GetNodeIndex(
    const NodeTypeId& node) const {
  return GetNodeIndex(node.node->name(), node.type_attr);
}

const absl::InlinedVector<int, 4>& GraphTypeTopologyView::GetFanin(
    int node_idx) const {
  DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
  const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
  DCHECK(is_valid_node_idx) << "node_idx is out of range";
  return is_valid_node_idx ? fanins_[node_idx] : empty_fanin_;
}

const absl::InlinedVector<int, 2>& GraphTypeTopologyView::GetFanout(
    int node_idx) const {
  DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
  const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
  DCHECK(is_valid_node_idx) << "node_idx is out of range";
  return is_valid_node_idx ? fanouts_[node_idx] : empty_fanout_;
}

enum class TypeTraversalDirection {
  kFollowInputs,
  kFollowOutputs,
  kFollowInputsAndOutputs,
};

// Encapsulate DFS callbacks that will be called during the graph traversal.
//
// If non-empty, the `pre_order` and `post_order` functors will be called on
// each reachable node (including the `from` nodes) in pre and post order. If
// loops are found, the `on_back_edge` functor will be called on the
// corresponding back edges. Moreover, the pre and post order will assume that
// these back edges will be cut.
struct DfsTypeCallbacks {
  DfsTypeCallbacks() = default;
  DfsTypeCallbacks(std::function<void(int)> pre, std::function<void(int)> post,
                   std::function<void(int, int)> back_edge)
      : pre_order(std::move(pre)),
        post_order(std::move(post)),
        on_back_edge(std::move(back_edge)) {}

  static DfsTypeCallbacks PreOrder(std::function<void(int)> pre) {
    return DfsTypeCallbacks(std::move(pre), nullptr, nullptr);
  }

  static DfsTypeCallbacks PostOrder(std::function<void(int)> post) {
    return DfsTypeCallbacks(nullptr, std::move(post), nullptr);
  }

  std::function<void(int)> pre_order;
  std::function<void(int)> post_order;
  std::function<void(int, int)> on_back_edge;
};

// Encapsulate DFS predicates for traversing the graph.
//
// The `enter` predicate decides if traversal should enter the node, and the
// `advance` predicate decides if the traversal should follow inputs/outputs
// from the node.
//
// If predicates are empty (default initialized), it's assumed that we can enter
// into any node and advance from any node respectively.
struct DfsTypePredicates {
  DfsTypePredicates() = default;
  DfsTypePredicates(std::function<bool(int)> enter,
                    std::function<bool(int)> advance)
      : enter(std::move(enter)), advance(std::move(advance)) {}

  static DfsTypePredicates Enter(std::function<bool(int)> enter) {
    return DfsTypePredicates(std::move(enter), nullptr);
  }

  static DfsTypePredicates Advance(std::function<bool(int)> advance) {
    return DfsTypePredicates(nullptr, std::move(advance));
  }

  std::function<bool(int)> enter;
  std::function<bool(int)> advance;
};

struct DfsStackElem {
  DfsStackElem(int node, bool children_visited, int src)
      : node(node), children_visited(children_visited), src(src) {}
  explicit DfsStackElem(int node) : DfsStackElem(node, false, -1) {}

  // Index of the node in the graph ∊ [0, num_nodes).
  int node;
  // `True` if all the input/output nodes have been visited (pushed onto the
  // stack).
  bool children_visited;
  // Index of the node in the graph, from which we entered the `node`.
  int src;
};

enum class NodeState { kNotVisited, kVisiting, kDone };

void DfsTypeTraversal(const GraphTypeTopologyView& graph_type_view,
                      const absl::Span<const NodeTypeId* const> from,
                      const TypeTraversalDirection direction,
                      const DfsTypePredicates& predicates,
                      const DfsTypeCallbacks& callbacks) {
  std::vector<DfsStackElem> stack;
  stack.reserve(from.size());

  for (const NodeTypeId* node : from) {
    const absl::optional<int> node_idx = graph_type_view.GetNodeIndex(*node);
    DCHECK(node_idx.has_value())
        << "Illegal start node: " << node->node->name();
    if (node_idx.has_value()) {
      stack.emplace_back(node_idx.value());
    }
  }

  absl::flat_hash_map<int, NodeState> node_state;
  while (!stack.empty()) {
    DfsStackElem w = stack.back();
    stack.pop_back();

    NodeState& state = node_state[w.node];
    if (state == NodeState::kDone) continue;

    // Skip nodes that we should not enter.
    if (predicates.enter && !predicates.enter(w.node)) {
      state = NodeState::kDone;
      continue;
    }

    // We've processed all the children of this node.
    if (w.children_visited) {
      state = NodeState::kDone;
      if (callbacks.post_order) {
        callbacks.post_order(w.node);
      }
      continue;
    }

    // Loop detected.
    if (state == NodeState::kVisiting) {
      if (callbacks.on_back_edge) {
        callbacks.on_back_edge(w.src, w.node);
      }
      continue;
    }

    state = NodeState::kVisiting;
    if (callbacks.pre_order) {
      callbacks.pre_order(w.node);
    }

    // Enqueue the node again with the children_visited flag set to true.
    stack.emplace_back(w.node, true, w.src);

    // Check if we can continue traversal from the current node.
    if (predicates.advance && !predicates.advance(w.node)) {
      continue;
    }

    // Now enqueue the fanin/fanout nodes.
    if (direction == TypeTraversalDirection::kFollowInputs ||
        direction == TypeTraversalDirection::kFollowInputsAndOutputs) {
      for (const int fanin : graph_type_view.GetFanin(w.node)) {
        stack.emplace_back(fanin, false, w.node);
      }
    }
    if (direction == TypeTraversalDirection::kFollowOutputs ||
        direction == TypeTraversalDirection::kFollowInputsAndOutputs) {
      for (const int fanout : graph_type_view.GetFanout(w.node)) {
        stack.emplace_back(fanout, false, w.node);
      }
    }
  }
}

DataTypeSet AllowedDataTypes(const OpDef::AttrDef& attr_def) {
  const auto& allowed_types = attr_def.allowed_values().list().type();
  if (allowed_types.empty()) {
    return AllTypes();
  }
  uint32 dtype_mask = 0;
  for (int dtype : allowed_types) {
    dtype_mask |= 1u << dtype;
  }
  return DataTypeSet(dtype_mask);
}

DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) {
  if (t_attr_id.attr_name.empty()) {
    return ToSet(t_attr_id.fixed_type);
  }
  const OpDef::AttrDef* attr_def = FindAttr(t_attr_id.attr_name, op_def);
  CHECK(attr_def);  // Crash Ok
  return AllowedDataTypes(*attr_def);
}

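// Builds a Cast node that converts the tensor produced by `src` to fp16 (when
// `to_fp16` is true) or back to fp32, placed on `device`. The node is named
// "<src node>-<port>-CastToFp16/CastToFp32-AutoMixedPrecision".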
NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, bool to_fp16,
                      const string& device) {
  const char* cast_string = to_fp16 ? kCastToFp16 : kCastToFp32;
  string name = strings::StrCat(src.node->name(), "-", src.port_id, "-",
                                cast_string, "-", kSuffix);
  NodeDef node;
  node.set_name(name);
  node.set_op("Cast");
  node.set_device(device);
  node.add_input(strings::StrCat(src.node->name(), ":", src.port_id));
  (*node.mutable_attr())["SrcT"].set_type(to_fp16 ? DT_FLOAT : DT_HALF);
  (*node.mutable_attr())["DstT"].set_type(to_fp16 ? DT_HALF : DT_FLOAT);
  (*node.mutable_attr())["Truncate"].set_b(false);
  return node;
}

Status ValidateLists(const gtl::FlatSet<string>& white_list,
                     const gtl::FlatSet<string>& black_list,
                     const gtl::FlatSet<string>& gray_list,
                     const gtl::FlatSet<string>& clear_list) {
  std::vector<gtl::FlatSet<string>> lists{white_list, black_list, gray_list,
                                          clear_list};
  std::multiset<string> counts;
  for (auto list : lists) {
    counts.insert(list.begin(), list.end());
  }
  bool duplicates = false;
  for (auto s : counts) {
    if (counts.count(s) > 1) {
      duplicates = true;
      LOG(ERROR) << "Op present in multiple lists: " << s;
    }
  }
  if (duplicates) {
    return errors::InvalidArgument("Op lists have conflicting entries");
  } else {
    return Status::OK();
  }
}

bool HasInputOrOutputRefs(const NodeDef& node) {
  const OpDef* op_def;
  Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
  if (!status.ok()) {
    return true;
  }
  for (const auto& input : op_def->input_arg()) {
    if (input.is_ref()) {
      return true;
    }
  }
  for (const auto& output : op_def->output_arg()) {
    if (output.is_ref()) {
      return true;
    }
  }
  return false;
}

// See TF issue 25977 for no-FP16 on SCEWL
bool CanForceFP16(const NodeDef& node) {
  return node.op() != "Const" && node.op() != "SoftmaxCrossEntropyWithLogits" &&
         !IsStateful(node) && !HasInputOrOutputRefs(node);
}

int GetCudaVersion(const Cluster& cluster) {
  auto devices = cluster.GetDevices();
  for (const auto& device : devices) {
    const DeviceProperties& device_properties = device.second;
    if (device_properties.type() == "GPU") {
      const auto& device_env = device_properties.environment();
      auto it = device_env.find("cuda");
      if (it != device_env.end()) {
        string cuda_version_str = it->second;
        return std::stoi(cuda_version_str);
      }
    }
  }
  return 0;
}

int GetCudnnVersion(const Cluster& cluster) {
  auto devices = cluster.GetDevices();
  for (const auto& device : devices) {
    const DeviceProperties& device_properties = device.second;
    if (device_properties.type() == "GPU") {
      const auto& device_env = device_properties.environment();
      auto it = device_env.find("cudnn");
      if (it != device_env.end()) {
        string cudnn_version_str = it->second;
        return std::stoi(cudnn_version_str);
      }
    }
  }
  return 0;
}

class AutoMixedPrecisionImpl {
 public:
  AutoMixedPrecisionImpl(Cluster* cluster,
                         const std::unordered_set<string>& nodes_to_preserve,
                         GraphDef* graph, string id)
      : virtual_placer_(cluster->GetDevices()),
        nodes_to_preserve_(nodes_to_preserve),
        graph_(graph),
        id_(id),
        graph_view_(graph),
        cuda_version_(GetCudaVersion(*cluster)),
        cudnn_version_(GetCudnnVersion(*cluster)) {}

  Status Optimize();

 private:
  typedef absl::flat_hash_set<NodeTypeId> NodeTypeIdSet;
  // Maps data structure object ops (e.g., StackV2) to the sets of nodes that
  // write (e.g., StackPushV2) and read (e.g., StackPopV2) from them.
  typedef absl::flat_hash_map<NodeTypeId,
                              std::pair<NodeTypeIdSet, NodeTypeIdSet>>
      DataStructureOpsMap;

  Status PrintDebugLogs(bool preop, size_t timestamp);
  void LogSkippedNode(const NodeDef& node) const;
  bool MustPreserve(const NodeDef& node) const;
  bool IsOnGPU(const NodeDef& node) const;
  bool IsOnSuitableGPUArch(const NodeDef& node) const;
  bool ShouldProcess(const NodeDef& node) const;
  bool NodeHasFP16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const;
  bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const;
  void ConvertBatchNormOpsToV2();
  bool SupportsFloat16(const NodeTypeId& node_type) const;
  const NodeDef* GetTailOfChain(
      const NodeDef& node, const absl::flat_hash_set<string>& match_ops) const;
  Status AddDataStructureOpsToMap(
      const absl::flat_hash_set<string>& data_structure_ops,
      TypeAttrId data_structure_type_attr,
      const absl::flat_hash_map<string, TypeAttrId>& write_ops,
      const absl::flat_hash_map<string, TypeAttrId>& read_ops,
      DataStructureOpsMap* object_clients_map) const;
  void AddWhitelistOps(absl::flat_hash_set<int>* white_set) const;
  void PropagateBlackFwdThroughClearAndGray(
      absl::flat_hash_set<int>* black_set) const;
  void ForceColorMatchBetweenDataStructureOps(
      const DataStructureOpsMap& object_clients_map,
      absl::flat_hash_set<int>* white_set,
      absl::flat_hash_set<int>* black_set) const;
  void AddClearAndGrayToWhiteIfBetweenWhite(
      const absl::flat_hash_set<int>& black_set,
      absl::flat_hash_set<int>* white_set) const;
  void PropagateWhiteThroughClear(const absl::flat_hash_set<int>& black_set,
                                  absl::flat_hash_set<int>* white_set) const;
  Status ForceColorMatchOnRecurrentEdges(
      absl::flat_hash_set<int>* white_set) const;
  void MakeCastsWhiteIfAllOutputsWhite(
      absl::flat_hash_set<int>* white_set) const;
  Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& white_set);

  VirtualPlacer virtual_placer_;
  std::unordered_set<string> nodes_to_preserve_;
  GraphDef* graph_;
  string id_;
  MutableGraphView graph_view_;
  int cuda_version_;
  int cudnn_version_;
  NodeTypeAttrMap node_type_map_;
  GraphTypeTopologyView graph_type_view_;
  bool force_all_fp16_;
  gtl::FlatSet<string> fp16_whitelist_;
  gtl::FlatSet<string> fp16_blacklist_;
  gtl::FlatSet<string> fp16_graylist_;
  gtl::FlatSet<string> fp16_clearlist_;
  absl::flat_hash_set<const NodeDef*> should_process_nodes_;
};

bool AutoMixedPrecisionImpl::NodeHasFP16KernelForTypeAttr(
    const NodeDef& node, TypeAttrId taid) const {
  NodeDef node_copy(node);
  if (node.device().empty()) {
    string device_name = virtual_placer_.get_canonical_device_name(node);
    node_copy.set_device(device_name);
  }
  if (!SetDataType(&node_copy, taid, DataType::DT_HALF)) {
    return false;
  }
  return IsKernelRegisteredForNode(node_copy).ok();
}

Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) {
  string prepend_path;
  TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
      "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LOG_PATH", "", &prepend_path));
  if (prepend_path.empty()) return Status::OK();

  string suffix =
      strings::StrCat("_", preop ? "preop" : kSuffix, "_", id_, "_", timestamp);

  string fname =
      io::JoinPath(prepend_path, strings::StrCat("graphdef", suffix, ".pb"));
  std::fstream f;
  f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
  f << graph_->SerializeAsString();
  f.close();
  LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization")
            << " graph as binary to " << fname;

  fname = io::JoinPath(prepend_path,
                       strings::StrCat("graphdef", suffix, ".pb.txt"));
  f.open(fname.c_str(), std::fstream::out);
  f << graph_->DebugString();
  f.close();
  LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization")
            << " graph as text to " << fname;

  if (!preop) {
    fname = io::JoinPath(prepend_path,
                         strings::StrCat("paintbuckets", suffix, ".txt"));
    f.open(fname.c_str(), std::fstream::out);
    f << "WhiteList:\n";
    for (auto x :
         AutoMixedPrecisionLists::WhiteList(cuda_version_, cudnn_version_)) {
      f << x << "\n";
    }
    f << "\nBlackList:\n";
    for (auto x : AutoMixedPrecisionLists::BlackList()) {
      f << x << "\n";
    }
    f << "\nGrayList:\n";
    for (auto x : AutoMixedPrecisionLists::GrayList()) {
      f << x << "\n";
    }
    f << "\nClearList:\n";
    for (auto x : AutoMixedPrecisionLists::ClearList()) {
      f << x << "\n";
    }
    f.close();
    LOG(INFO) << "Saved paint bucket info to " << fname;
  }
  return Status::OK();
}

void AutoMixedPrecisionImpl::LogSkippedNode(const NodeDef& node) const {
  VLOG(2) << "Skipping " << node.op() << " node " << node.name()
          << " because it "
          << (MustPreserve(node)
                  ? "must be preserved"
                  : "is not on the GPU, or the GPU arch is not suitable");
}

bool AutoMixedPrecisionImpl::MustPreserve(const NodeDef& node) const {
  return nodes_to_preserve_.count(node.name());
}

bool AutoMixedPrecisionImpl::IsOnGPU(const NodeDef& node) const {
  string device_name;
  if (node.device().empty()) {
    device_name = virtual_placer_.get_canonical_device_name(node);
  } else {
    device_name = node.device();
  }
  string device;
  string not_used;
  if (DeviceNameUtils::SplitDeviceName(device_name, &not_used, &device) &&
      absl::StrContains(absl::AsciiStrToLower(device),
                        absl::AsciiStrToLower(DEVICE_GPU))) {
    return true;
  }
  return false;
}

// Returns the GPU architecture (compute capability) as a (major, minor) pair.
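// For example, an "architecture" value of "7.0" yields {7, 0} and "6" yields
// {6, 0}; non-GPU devices and unparsable strings yield {0, 0}.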
std::pair<int, int> GetDeviceGPUArch(
    const DeviceProperties& device_properties) {
  if (device_properties.type() != "GPU") return {0, 0};
  string arch_str = device_properties.environment().at("architecture");
  std::vector<string> split_arch_str = str_util::Split(arch_str, '.');
  if (split_arch_str.empty()) {
    return {0, 0};
  }

  int major, minor;
  if (!strings::safe_strto32(split_arch_str[0], &major)) {
    return {0, 0};
  }

  if (split_arch_str.size() > 1) {
    if (strings::safe_strto32(split_arch_str[1], &minor)) {
      return {major, minor};
    } else {
      return {0, 0};
    }
  } else {
    return {major, 0};
  }
}

bool AutoMixedPrecisionImpl::IsOnSuitableGPUArch(const NodeDef& node) const {
  return GetDeviceGPUArch(virtual_placer_.get_device(node)) >= kMinGPUArch;
}

bool AutoMixedPrecisionImpl::ShouldProcess(const NodeDef& node) const {
  return should_process_nodes_.count(&node);
}

bool IsFloat32(const NodeTypeId& node_type) {
  return GetDataType(*node_type.node, node_type.type_attr) ==
         DataType::DT_FLOAT;
}

bool AutoMixedPrecisionImpl::SupportsFloat16(
    const NodeTypeId& node_type) const {
  const OpDef* op_def;
  Status status =
      OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
  if (!status.ok()) return false;
  return AllowedDataTypes(*op_def, node_type.type_attr)
             .Contains(DataType::DT_HALF) &&
         NodeHasFP16KernelForTypeAttr(*node_type.node, node_type.type_attr);
}

// TODO(mconley): Make this change the node's name (to aid debugging). Need to
// make sure that doing this won't break anything.
void AutoMixedPrecisionImpl::ConvertBatchNormOpsToV2() {
  for (int node_idx = 0; node_idx < graph_->node_size(); ++node_idx) {
    NodeDef* node = graph_->mutable_node(node_idx);
    if (!ShouldProcess(*node)) continue;
    bool changed = false;
    if (node->op() == "FusedBatchNorm") {
      VLOG(2) << "Changing op of " << node->op() << " node " << node->name()
              << " to FusedBatchNormV2";
      node->set_op("FusedBatchNormV2");
      changed = true;
    } else if (node->op() == "FusedBatchNormGrad") {
      VLOG(2) << "Changing op of " << node->op() << " node " << node->name()
              << " to FusedBatchNormGradV2";
      node->set_op("FusedBatchNormGradV2");
      changed = true;
    }
    if (changed) {
      (*node->mutable_attr())["U"].set_type(DT_FLOAT);
    }
  }
}

// A helper function to decide whether to ignore the effect on performance when
// rewriting the graph. This can be useful for testing the numerical effects of
// reduced precision on systems that have poor mixed precision performance.
bool ShouldIgnorePerformance() {
  static bool is_enabled = [] {
    bool ret = false;
    TF_CHECK_OK(ReadBoolFromEnvVar(
        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE",
        /*default_val=*/false, &ret));
    return ret;
  }();
  return is_enabled;
}

Status AutoMixedPrecisionImpl::Optimize() {
  string optimization_level;
  TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
      "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level));
  optimization_level = absl::AsciiStrToUpper(optimization_level);
  force_all_fp16_ = optimization_level == "UNSAFE_FORCE_ALL";

  fp16_whitelist_ =
      AutoMixedPrecisionLists::WhiteList(cuda_version_, cudnn_version_);
  fp16_blacklist_ = AutoMixedPrecisionLists::BlackList();
  fp16_graylist_ = AutoMixedPrecisionLists::GrayList();
  fp16_clearlist_ = AutoMixedPrecisionLists::ClearList();
  TF_RETURN_IF_ERROR(ValidateLists(fp16_whitelist_, fp16_blacklist_,
                                   fp16_graylist_, fp16_clearlist_));

  size_t timestamp = Env::Default()->NowMicros() / 1000;
  TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp));

  VLOG(2) << "Identifying nodes that should be processed";
  for (const NodeDef& node : graph_->node()) {
    if (!MustPreserve(node) && IsOnGPU(node) &&
        (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node))) {
      should_process_nodes_.insert(&node);
    } else {
      LogSkippedNode(node);
    }
  }

  VLOG(2) << "Converting FusedBatchNorm* ops to V2";
  ConvertBatchNormOpsToV2();

  VLOG(2) << "Building node type map for graph";
  TF_RETURN_IF_ERROR(node_type_map_.Init(*graph_));

  // Note: If an op is added to this list that has a data type attribute, it
  // should also be added to the AddDataStructureOpsToMap call below (and to the
  // clearlist if it involves data flow).
  // TODO(benbarsdell): Add support for TensorListPushBackBatch and
  // TensorListConcatLists. They require special handling because they connect
  // multiple list objects together. Currently if they appear in the graph then
  // we have no choice but to disallow changing any tensor list ops, as
  // otherwise we risk breaking the graph if some are changed and some are not
  // (within a connected cluster of tensor list nodes).
  const gtl::FlatSet<string> supported_list_ops = {
      "EmptyTensorList",
      "TensorListSplit",
      "TensorListFromTensor",
      "TensorListReserve",
      "TensorListScatter",
      "TensorListScatterV2",
      "TensorListPushBack",
      "TensorListSetItem",
      "TensorListScatterIntoExistingList",
      "TensorListPopBack",
      "TensorListStack",
      "TensorListConcat",
      "TensorListConcatV2",
      "TensorListGetItem",
      "TensorListGather",
      "TensorListLength",
      "TensorListElementShape",
      "TensorListResize"};

  bool can_change_tensor_list_ops = true;
  for (const NodeDef& node : graph_->node()) {
    if (absl::StartsWith(node.op(), "TensorList") &&
        !supported_list_ops.count(node.op())) {
      LOG(WARNING) << "Unsupported " << node.op() << " node found in graph ("
                   << node.name()
                   << "), tensor list ops will not be converted.";
      can_change_tensor_list_ops = false;
      break;
    }
  }

  DataStructureOpsMap object_clients_map;
  if (can_change_tensor_list_ops) {
    VLOG(2) << "Identifying TensorList* nodes";
    TF_RETURN_IF_ERROR(AddDataStructureOpsToMap(
        {"EmptyTensorList", "TensorListSplit", "TensorListFromTensor",
         "TensorListReserve", "TensorListScatter", "TensorListScatterV2"},
        TypeAttrId("element_dtype"),
        {{"TensorListPushBack", TypeAttrId("element_dtype")},
         {"TensorListSetItem", TypeAttrId("element_dtype")},
         {"TensorListScatterIntoExistingList", TypeAttrId("element_dtype")}},
        {{"TensorListPopBack", TypeAttrId("element_dtype")},
         {"TensorListStack", TypeAttrId("element_dtype")},
         {"TensorListConcat", TypeAttrId("element_dtype")},
         {"TensorListConcatV2", TypeAttrId("element_dtype")},
         {"TensorListGetItem", TypeAttrId("element_dtype")},
         {"TensorListGather", TypeAttrId("element_dtype")}},
        &object_clients_map));
  } else {
    for (const string& list_op : supported_list_ops) {
      fp16_whitelist_.erase(list_op);
      fp16_graylist_.erase(list_op);
      fp16_clearlist_.erase(list_op);
    }
  }

  // Create ephemeral edges between writers and readers of data structure ops.
  std::vector<NodeTypeIdEdge> ephemeral_edges;
  for (const auto& object_clients : object_clients_map) {
    const auto& client_nodes = object_clients.second;
    for (const NodeTypeId& write_node_type : client_nodes.first) {
      for (const NodeTypeId& read_node_type : client_nodes.second) {
        ephemeral_edges.emplace_back(write_node_type, read_node_type);
      }
    }
    const NodeTypeId& object_node_type = object_clients.first;
    // These object types also act as writers because they initialize the
    // object from an input tensor.
    if (object_node_type.node->op() == "TensorListSplit" ||
        object_node_type.node->op() == "TensorListFromTensor" ||
        object_node_type.node->op() == "TensorListScatter" ||
        object_node_type.node->op() == "TensorListScatterV2") {
      for (const NodeTypeId& read_node_type : client_nodes.second) {
        ephemeral_edges.emplace_back(object_node_type, read_node_type);
      }
    }
  }

  VLOG(2) << "Constructing graph type attribute topology view";
  TF_RETURN_IF_ERROR(graph_type_view_.InitializeFromGraph(
      *graph_, node_type_map_, ephemeral_edges));

  // The goal here is to change performance-critical ops to fp16, and to do so
  // with the minimal number of casts, subject to the constraint that the
  // model's convergence is not affected. This is achieved by first identifying
  // which nodes should be changed to fp16 and then inserting casts at the
  // boundaries between fp16/non-fp16 nodes.

  // The algorithm for deciding which nodes to change to fp16 is as follows:
  // 1) Add all performance-critical ops (aka "whitelist" ops) to the white_set.
  //    This is done under the assumption that whitelist ops are always
  //    numerically-safe in fp16 and that they are the most important ops for
  //    improving performance.
  // 2) Add nodes to the black_set iff they are numerically-dangerous (aka
  //    "blacklist" ops) or they are on a forward path from a blacklist node to
  //    a black/gray node (including the node at the end of the path) through
  //    non-numerically-dangerous ops (aka "greylist" and "clearlist" ops).
  //    This is done to prevent numerically-dangerous ops and their downstream
  //    effects from being changed to fp16, which would risk breaking the
  //    numerical accuracy of the model.
  // 3) For all remaining nodes that are not considered dangerous (greylist
  //    and clearlist ops), find those that are between (i.e., both upstream
  //    and downstream of) white nodes, and add them to the white_set.
  //    This is done to avoid unnecessary casts between whitelist ops.
  // 4) For all remaining clearlist nodes, add them to the white_set if they are
  //    connected to a node in the white_set via other clearlist nodes.
  //    This is done to increase the number of ops in the white_set without
  //    affecting numerical stability.
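  // For example, in a chain MatMul (whitelist) -> Relu (clearlist) -> MatMul
  // (whitelist), step 3 adds the Relu node to white_set so that both MatMuls
  // and the Relu run in fp16 with no intermediate casts.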

  absl::flat_hash_set<int> white_set;
  VLOG(2) << "Beginning pass 1 to add whitelist ops";
  AddWhitelistOps(&white_set);
  VLOG(2) << "Finished pass 1";

  if (white_set.empty()) {
    LOG(INFO) << "No whitelist ops found, nothing to do";
    return Status::OK();
  }

  absl::flat_hash_set<int> black_set;
  VLOG(2) << "Beginning pass 2 to propagate black forwards from blacklist ops "
             "through clear/graylist ops";
  PropagateBlackFwdThroughClearAndGray(&black_set);
  VLOG(2) << "Finished pass 2";

  VLOG(2) << "Forcing color match between data structure ops";
  ForceColorMatchBetweenDataStructureOps(object_clients_map, &white_set,
                                         &black_set);

  VLOG(2) << "Beginning pass 3 to set clear and gray nodes to white if they "
             "are between white ops";
  AddClearAndGrayToWhiteIfBetweenWhite(black_set, &white_set);
  VLOG(2) << "Finished pass 3";

  VLOG(2) << "Beginning pass 4 to propagate white from white nodes through "
             "clearlist ops";
  PropagateWhiteThroughClear(black_set, &white_set);
  VLOG(2) << "Finished pass 4";

  VLOG(2) << "Forcing color match between data structure ops";
  ForceColorMatchBetweenDataStructureOps(object_clients_map, &white_set,
                                         &black_set);

  VLOG(2) << "Forcing color match on loop edges";
  TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&white_set));

  VLOG(2) << "Finding existing casts that can be made white";
  MakeCastsWhiteIfAllOutputsWhite(&white_set);

  VLOG(2) << "Beginning final pass to change type attributes and insert Cast "
             "ops at paint boundaries";
  TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(white_set));
  VLOG(2) << "Finished final pass";

  TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ false, timestamp));

  return Status::OK();
}

// Finds data structure object ops (e.g., StackV2) and the sets of nodes that
// write (e.g., StackPushV2) and read (e.g., StackPopV2) from them.
Status AutoMixedPrecisionImpl::AddDataStructureOpsToMap(
    const absl::flat_hash_set<string>& data_structure_ops,
    TypeAttrId data_structure_type_attr,
    const absl::flat_hash_map<string, TypeAttrId>& write_ops,
    const absl::flat_hash_map<string, TypeAttrId>& read_ops,
    DataStructureOpsMap* object_clients_map) const {
  for (const NodeDef& node : graph_->node()) {
    const auto write_iter = write_ops.find(node.op());
    const auto read_iter = read_ops.find(node.op());
    bool is_writer = write_iter != write_ops.end();
    bool is_reader = read_iter != read_ops.end();
    if (is_writer || is_reader) {
      const NodeDef* object_node = GetTailOfChain(node, data_structure_ops);
      if (!object_node) {
        return errors::FailedPrecondition(
            "No data structure op found upstream of ", node.op(), " node ",
            node.name());
      }
      NodeTypeId object_node_type(object_node, data_structure_type_attr);
      TypeAttrId type_attr = is_writer ? write_iter->second : read_iter->second;
      NodeTypeId node_type(&node, type_attr);
      auto* value = &(*object_clients_map)[object_node_type];
      auto* node_set = is_writer ? &value->first : &value->second;
      node_set->insert(node_type);
    }
  }
  return Status::OK();
}
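
// For example (op names illustrative), given a StackV2 node whose handle feeds
// both a StackPushV2 and a StackPopV2 node, the resulting map entry is keyed
// by the StackV2 node (paired with the data structure's type attribute) and
// contains {StackPushV2} as the writer set and {StackPopV2} as the reader set.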

void AutoMixedPrecisionImpl::AddWhitelistOps(
    absl::flat_hash_set<int>* white_set) const {
  // Add whitelisted ops to white_set.
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (!ShouldProcess(*root.node)) continue;
    bool force_white = force_all_fp16_ && CanForceFP16(*root.node);
    if (fp16_whitelist_.count(root.node->op()) || force_white) {
      bool inserted = white_set->insert(root_idx).second;
      if (VLOG_IS_ON(2) && inserted) {
        VLOG(2) << "Painting type " << root.type_attr.DebugString()
                << " of node " << root.node->name() << " WHITE because its op "
                << root.node->op() << " is on the whitelist";
      }
    }
  }
}

// Adds nodes to black_set iff they are on the blacklist or they are on a
// forward path from a blacklist node to a black/gray node (including the node
// at the end of the path) through clear and gray nodes.
// E.g., black -> gray -> clear -> gray -> clear -> white -> gray
// becomes: black -> black -> black -> black -> clear -> white -> gray.
void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
    absl::flat_hash_set<int>* black_set) const {
  if (force_all_fp16_) return;

  // Find clear nodes that are upstream of black or gray.
  absl::flat_hash_set<int> upstream_of_black_or_gray_set;
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (!(fp16_blacklist_.count(root.node->op()) ||
          fp16_graylist_.count(root.node->op()))) {
      continue;
    }
    DfsTypeTraversal(graph_type_view_, {&root},
                     TypeTraversalDirection::kFollowInputs,
                     DfsTypePredicates::Enter([&](int idx) -> bool {
                       const NodeTypeId& item = *graph_type_view_.GetNode(idx);
                       return idx == root_idx ||
                              (!upstream_of_black_or_gray_set.count(idx) &&
                               fp16_clearlist_.count(item.node->op()));
                     }),
                     DfsTypeCallbacks::PreOrder([&](int idx) {
                       upstream_of_black_or_gray_set.insert(idx);
                     }));
  }

  // Propagate black forward through nodes in upstream_of_black_or_gray_set.
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (black_set->count(root_idx) || !fp16_blacklist_.count(root.node->op())) {
      continue;
    }
    DfsTypeTraversal(
        graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
        DfsTypePredicates::Enter([&](int idx) -> bool {
          return idx == root_idx || (!black_set->count(idx) &&
                                     upstream_of_black_or_gray_set.count(idx));
        }),
        DfsTypeCallbacks::PreOrder([&](int idx) {
          bool inserted = black_set->insert(idx).second;
          if (VLOG_IS_ON(2) && inserted) {
            const NodeTypeId& item = *graph_type_view_.GetNode(idx);
            VLOG(2) << "Painting type " << item.type_attr.DebugString()
                    << " of " << item.node->op() << " node "
                    << item.node->name() << " BLACK";
          }
        }));
  }
}
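
// Implementation note: the first loop walks backwards from every black/gray
// node to collect the clearlist nodes from which such a node is reachable
// (upstream_of_black_or_gray_set); the second loop then walks forwards from
// each blacklist node, but only through that set. Restricting the forward walk
// this way ensures black paint only spreads along clear/gray paths that
// actually terminate at a black or gray node, as described in the comment
// above.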

void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
    const absl::flat_hash_set<int>& black_set,
    absl::flat_hash_set<int>* white_set) const {
  // Find clear/graylist ops that are downstream of white ops.
  absl::flat_hash_set<int> downstream_of_white_set;
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (!ShouldProcess(*root.node) || !fp16_whitelist_.count(root.node->op())) {
      continue;
    }
    DfsTypeTraversal(
        graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
        DfsTypePredicates::Enter([&](int idx) -> bool {
          const NodeTypeId& item = *graph_type_view_.GetNode(idx);
          return idx == root_idx ||
                 (!downstream_of_white_set.count(idx) &&
                  !fp16_whitelist_.count(item.node->op()) &&
                  !black_set.count(idx) && ShouldProcess(*item.node) &&
                  // TODO(benbarsdell): Consider allowing propagation through
                  // ops that are already float16 in order to reduce the number
                  // of casts.
                  IsFloat32(item) && SupportsFloat16(item) &&
                  (fp16_clearlist_.count(item.node->op()) ||
                   fp16_graylist_.count(item.node->op())));
        }),
        DfsTypeCallbacks::PreOrder(
            [&](int idx) { downstream_of_white_set.insert(idx); }));
  }

  // Set nodes that are both downstream and upstream of white ops to white.
  absl::flat_hash_set<int> upstream_of_white_set;
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (!ShouldProcess(*root.node) || upstream_of_white_set.count(root_idx) ||
        !fp16_whitelist_.count(root.node->op())) {
      continue;
    }
    DfsTypeTraversal(
        graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
        DfsTypePredicates::Enter([&](int idx) -> bool {
          return idx == root_idx || (!upstream_of_white_set.count(idx) &&
                                     downstream_of_white_set.count(idx));
        }),
        DfsTypeCallbacks::PreOrder([&](int idx) {
          upstream_of_white_set.insert(idx);
          bool inserted = white_set->insert(idx).second;
          if (VLOG_IS_ON(2) && inserted) {
            const NodeTypeId& item = *graph_type_view_.GetNode(idx);
            VLOG(2) << "Painting type " << item.type_attr.DebugString()
                    << " of " << item.node->op() << " node "
                    << item.node->name() << " WHITE";
          }
        }));
  }
}
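
// For example (list membership illustrative): in MatMul -> BiasAdd -> MatMul,
// with BiasAdd on the graylist and both MatMuls already painted white, BiasAdd
// is both downstream and upstream of white nodes and so is painted white here,
// which saves a CastToFp32/CastToFp16 pair around the BiasAdd.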

void AutoMixedPrecisionImpl::PropagateWhiteThroughClear(
    const absl::flat_hash_set<int>& black_set,
    absl::flat_hash_set<int>* white_set) const {
  // Propagate white from white nodes through clearlist ops.
  absl::flat_hash_set<int> clear_prop_set;
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (!ShouldProcess(*root.node) || clear_prop_set.count(root_idx) ||
        !white_set->count(root_idx)) {
      continue;
    }
    DfsTypeTraversal(
        graph_type_view_, {&root},
        TypeTraversalDirection::kFollowInputsAndOutputs,
        DfsTypePredicates::Enter([&](int idx) -> bool {
          const NodeTypeId& item = *graph_type_view_.GetNode(idx);
          return idx == root_idx ||
                 (!white_set->count(idx) && !black_set.count(idx) &&
                  ShouldProcess(*item.node) && IsFloat32(item) &&
                  SupportsFloat16(item) &&
                  (fp16_clearlist_.count(item.node->op())) &&
                  // We don't propagate (backwards) through nodes that read
                  // Variables because it can break the behavior of TensorBoard
                  // visualization and/or (in the case of Enter nodes) the
                  // model itself. This is only a problem for non-resource
                  // variables.
                  !NodeImplicitlyReadsNonResourceVariable(*item.node));
        }),
        DfsTypeCallbacks::PreOrder([&](int idx) {
          clear_prop_set.insert(idx);
          bool inserted = white_set->insert(idx).second;
          if (VLOG_IS_ON(2) && inserted) {
            const NodeTypeId& item = *graph_type_view_.GetNode(idx);
            VLOG(2) << "Painting type " << item.type_attr.DebugString()
                    << " of " << item.node->op() << " node "
                    << item.node->name() << " WHITE";
          }
        }));
  }
}
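
// For example (assuming Relu and MaxPool are on the clearlist): in
// Conv2D -> Relu -> MaxPool with the Conv2D painted white and neither
// successor painted black, both Relu and MaxPool are painted white here, so
// the eventual cast back to fp32 (if any) is pushed past the cheap
// elementwise/pooling ops instead of landing directly after the Conv2D.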

// Forces NextIteration nodes and their output Merge node(s) to have the same
// color. Specifically, it removes them all from white_set if any of the Merge
// nodes is not in white_set, otherwise it adds the NextIteration node to
// white_set.
Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
    absl::flat_hash_set<int>* white_set) const {
  for (const NodeDef& node : graph_->node()) {
    if (node.op() == "NextIteration") {
      GraphView::OutputPort output_port(&node, 0);
      const auto& fanout = graph_view_.GetFanout(output_port);
      std::vector<int> merge_idxs;
      merge_idxs.reserve(fanout.size());
      bool any_merge_is_not_white = false;
      for (const auto& output : fanout) {
        const NodeDef& merge_node = *output.node;
        if (merge_node.op() != "Merge") {
          return errors::FailedPrecondition(
              "Expected Merge node after NextIteration, got ", merge_node.op());
        }
        const absl::optional<int> maybe_merge_idx =
            graph_type_view_.GetNodeIndex(merge_node.name(), TypeAttrId("T"));
        if (!maybe_merge_idx.has_value()) {
          return errors::Internal("Type attribute T of Merge node ",
                                  merge_node.name(),
                                  " not found in graph view");
        }
        int merge_idx = maybe_merge_idx.value();
        merge_idxs.push_back(merge_idx);
        any_merge_is_not_white =
            any_merge_is_not_white || !white_set->count(merge_idx);
      }
      const absl::optional<int> maybe_nextiter_idx =
          graph_type_view_.GetNodeIndex(node.name(), TypeAttrId("T"));
      if (!maybe_nextiter_idx.has_value()) {
        return errors::Internal("Type attribute T of NextIteration node ",
                                node.name(), " not found in graph view");
      }
      int nextiter_idx = maybe_nextiter_idx.value();
      if (any_merge_is_not_white) {
        for (int merge_idx : merge_idxs) {
          if (white_set->erase(merge_idx)) {
            VLOG(2) << "Painting type T of Merge node "
                    << graph_type_view_.GetNode(merge_idx)->node->name()
                    << " BLACK to match the color of its sibling Merge nodes "
                       "with common NextIteration node "
                    << node.name();
          }
        }
        if (white_set->erase(nextiter_idx)) {
          VLOG(2) << "Painting type T of NextIteration node " << node.name()
                  << " BLACK to match the color of its output Merge node(s)";
        }
      } else {
        if (white_set->insert(nextiter_idx).second) {
          VLOG(2) << "Painting type T of NextIteration node " << node.name()
                  << " WHITE to match the color of its output Merge node(s)";
        }
      }
    }
  }
  return Status::OK();
}
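
// Note: a NextIteration node and its consuming Merge node(s) form the back
// edge of a while-loop. Keeping both ends of the back edge in the same color
// set means the loop-carried value keeps a single dtype around the cycle,
// rather than being cast back and forth on every iteration.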

// Returns the last node in the simple chain starting at node and traversing
// backwards through the input(0) edge from each node until one with a matching
// op is found, or nullptr if no matching node is found.
const NodeDef* AutoMixedPrecisionImpl::GetTailOfChain(
    const NodeDef& node, const absl::flat_hash_set<string>& match_ops) const {
  const NodeDef* node_ptr = &node;
  do {
    GraphView::InputPort node_input(node_ptr, 0);
    MutableGraphView::OutputPort prev_output =
        graph_view_.GetRegularFanin(node_input);
    node_ptr = prev_output.node;
  } while (node_ptr && !match_ops.count(node_ptr->op()));
  return node_ptr;
}
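
// For example, a StackPushV2 node typically receives the stack handle via
// input 0, possibly through intervening nodes such as Enter inside a loop:
//   StackV2 -> Enter -> StackPushV2
// Starting from the StackPushV2 node and following input(0) backwards, this
// returns the StackV2 node (or nullptr if no op in match_ops is found).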

// Ensures that data structure nodes (e.g., StackV2) and all of their
// associated client nodes (e.g., StackPushV2 and StackPopV2) are in the same
// color set.
void AutoMixedPrecisionImpl::ForceColorMatchBetweenDataStructureOps(
    const DataStructureOpsMap& object_clients_map,
    absl::flat_hash_set<int>* white_set,
    absl::flat_hash_set<int>* black_set) const {
  for (const auto& object_clients : object_clients_map) {
    const NodeTypeId& object_node_type = object_clients.first;
    const auto& client_nodes = object_clients.second;
    NodeTypeIdSet all_client_nodes = client_nodes.first;
    all_client_nodes.insert(client_nodes.second.begin(),
                            client_nodes.second.end());
    // The object node may be considered a client too (e.g.,
    // TensorListFromTensor).
    all_client_nodes.insert(object_node_type);
    bool any_black = false;
    bool any_white = false;
    for (const NodeTypeId& node_type : all_client_nodes) {
      const absl::optional<int> maybe_node_idx =
          graph_type_view_.GetNodeIndex(node_type);
      DCHECK(maybe_node_idx.has_value())
          << "Type attribute " << node_type.type_attr.DebugString()
          << " of node " << node_type.node->name()
          << " not found in graph view";
      int node_idx = maybe_node_idx.value();
      if (black_set->count(node_idx)) {
        any_black = true;
        break;
      } else if (white_set->count(node_idx)) {
        any_white = true;
      }
    }
    if (any_black || any_white) {
      for (const NodeTypeId& node_type : all_client_nodes) {
        VLOG(2) << "Painting type " << node_type.type_attr.DebugString()
                << " of " << node_type.node->op() << " node "
                << node_type.node->name() << " "
                << (any_black ? "BLACK" : "WHITE")
                << " because at least one of its siblings is "
                << (any_black ? "BLACK" : "WHITE");
        const absl::optional<int> maybe_node_idx =
            graph_type_view_.GetNodeIndex(node_type);
        DCHECK(maybe_node_idx.has_value())
            << "Type attribute " << node_type.type_attr.DebugString()
            << " of node " << node_type.node->name()
            << " not found in graph view";
        int node_idx = maybe_node_idx.value();
        if (any_black) {
          white_set->erase(node_idx);
          black_set->insert(node_idx);
        } else {
          white_set->insert(node_idx);
        }
      }
    }
  }
}

bool AutoMixedPrecisionImpl::NodeImplicitlyReadsNonResourceVariable(
    const NodeDef& node) const {
  if (node.op() == "Identity" || node.op() == "Enter") {
    GraphView::InputPort node_input(&node, 0);
    MutableGraphView::OutputPort prev_output =
        graph_view_.GetRegularFanin(node_input);
    const NodeDef* input = prev_output.node;
    if (input && ((node.op() == "Identity" && (input->op() == "Variable" ||
                                               input->op() == "VariableV2")) ||
                  (node.op() == "Enter" &&
                   NodeImplicitlyReadsNonResourceVariable(*input)))) {
      return true;
    }
  }
  return false;
}
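
// In TF1-style reference-variable graphs, an Identity whose input is a
// Variable/VariableV2 node acts as the variable's read op, and an Enter node
// forwarding such a read plays the same role inside a loop. Resource variables
// are deliberately not matched here; as noted in PropagateWhiteThroughClear,
// only non-resource variables are affected.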

// This adds existing Cast nodes to white_set if all of their outputs are
// white, avoiding the need to add a new Cast node after an existing Cast.
void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite(
    absl::flat_hash_set<int>* white_set) const {
  int num_nodes_preop = graph_->node_size();
  for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
    NodeDef* node = graph_->mutable_node(node_idx);
    NodeTypeId node_type(node, TypeAttrId("DstT"));
    if (node->op() != "Cast" || !IsFloat32(node_type)) {
      continue;
    }
    bool all_fanouts_white = true;
    MutableGraphView::OutputPort src(node, 0);
    const auto& fanout = graph_view_.GetFanout(src);
    for (const MutableGraphView::InputPort& dst : fanout) {
      TypeAttrId dst_type_attr =
          node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
      const absl::optional<int> maybe_dst_type_idx =
          graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
      DCHECK(maybe_dst_type_idx.has_value())
          << "Type attribute " << dst_type_attr.DebugString() << " of node "
          << dst.node->name() << " not found in graph view";
      int dst_type_idx = maybe_dst_type_idx.value();
      bool dst_is_white = white_set->count(dst_type_idx);
      if (!dst_is_white) {
        all_fanouts_white = false;
        break;
      }
    }
    if (!fanout.empty() && all_fanouts_white) {
      const absl::optional<int> maybe_node_type_idx =
          graph_type_view_.GetNodeIndex(node_type);
      DCHECK(maybe_node_type_idx.has_value())
          << "Type attribute " << node_type.type_attr.DebugString()
          << " of node " << node_type.node->name()
          << " not found in graph view";
      int node_type_idx = maybe_node_type_idx.value();
      white_set->insert(node_type_idx);
    }
  }
}
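
// For example, if the graph already contains Cast(fp16 -> fp32) -> MatMul and
// the MatMul is white, painting the Cast's DstT white lets the final pass
// rewrite DstT to DT_HALF, so the existing Cast feeds the fp16 MatMul directly
// instead of a new CastToFp16 node being inserted after it.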

// Changes all white-painted type attributes to DT_HALF, and inserts Cast nodes
// at node outputs for all edges that connect white-painted <->
// non-white-painted type attributes.
Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
    const absl::flat_hash_set<int>& white_set) {
  int num_nodes_changed = 0;
  int num_nonvar_casts_to_fp16 = 0;
  int num_nodes_preop = graph_->node_size();
  for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
    NodeDef* node = graph_->mutable_node(node_idx);
    for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(*node)) {
      const absl::optional<int> maybe_node_type_idx =
          graph_type_view_.GetNodeIndex(node->name(), type_attr);
      if (!maybe_node_type_idx.has_value()) {
        return errors::Internal("Type attribute ", type_attr.DebugString(),
                                " of ", node->op(), " node ", node->name(),
                                " not found in graph view");
      }
      int node_type_idx = maybe_node_type_idx.value();
      if (!IsFloat32(*graph_type_view_.GetNode(node_type_idx))) continue;
      bool src_is_white = white_set.count(node_type_idx);
      if (src_is_white) {
        VLOG(1) << "Changing type " << type_attr.DebugString() << " of "
                << node->op() << " node " << node->name() << " to DT_HALF";
        if (!SetDataType(node, type_attr, DT_HALF)) {
          return errors::Internal("Failed to set type attribute");
        }
        ++num_nodes_changed;
      }
      for (int output_port : node_type_map_.GetOutputPorts(*node, type_attr)) {
        MutableGraphView::OutputPort src(node, output_port);
        NodeDef* added_cast_node = nullptr;
        // Note: This is copied so that edges can be modified inside the loop.
        auto fanout = graph_view_.GetFanout(src);
        for (const MutableGraphView::InputPort& dst : fanout) {
          TypeAttrId dst_type_attr =
              node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
          const absl::optional<int> maybe_dst_type_idx =
              graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
          if (!maybe_dst_type_idx.has_value()) {
            return errors::Internal("Type attribute ",
                                    dst_type_attr.DebugString(), " of ",
                                    dst.node->op(), " node ", dst.node->name(),
                                    " not found in graph view");
          }
          int dst_type_idx = maybe_dst_type_idx.value();
          bool dst_is_white = white_set.count(dst_type_idx);
          if (src_is_white != dst_is_white) {
            if (!added_cast_node) {
              bool to_fp16 = dst_is_white;
              VLOG(1) << "Inserting cast to "
                      << (to_fp16 ? "DT_HALF" : "DT_FLOAT") << " at "
                      << src.node->op() << " " << src.node->name() << ":"
                      << src.port_id;
              added_cast_node = graph_view_.AddNode(
                  BuildCastNode(src, to_fp16, src.node->device()));
              if (to_fp16 && !IsConstant(*node) && !IsVariable(*node) &&
                  !NodeImplicitlyReadsNonResourceVariable(*node)) {
                ++num_nonvar_casts_to_fp16;
              }
            }
            TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort(
                dst.node->name(), dst.port_id, {added_cast_node->name(), 0}));
          }
        }
      }
    }
  }
  LOG(INFO) << "Converted " << num_nodes_changed << "/" << num_nodes_preop
            << " nodes to float16 precision using " << num_nonvar_casts_to_fp16
            << " cast(s) to float16 (excluding Const and Variable casts)";
  return Status::OK();
}
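
// For example, if a white MatMul (now DT_HALF) feeds both a white Relu and a
// non-white Exp, a single Cast back to DT_FLOAT is created at the MatMul's
// output and only the Exp fanin is rewired to it; the Relu keeps consuming the
// fp16 tensor directly. At most one Cast node is added per output port of each
// type attribute, since added_cast_node is reused for every mismatched fanin
// of that port.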

int GetNumGPUs(const Cluster& cluster,
               const std::pair<int, int>& min_arch = {0, 0}) {
  auto devices = cluster.GetDevices();
  int num_gpus = 0;
  for (const auto& device : devices) {
    const DeviceProperties& device_properties = device.second;
    std::pair<int, int> arch = GetDeviceGPUArch(device_properties);
    if (device_properties.type() == "GPU" && arch >= min_arch) {
      num_gpus++;
    }
  }
  return num_gpus;
}

}  // end namespace

Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item,
                                    GraphDef* output) {
  if (cluster == nullptr) {
    return errors::InvalidArgument("cluster == nullptr");
  }

  // Start by copying input graph to output.
  *output = item.graph;

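  // Only GPUs whose architecture is at least kMinGPUArch ({7, 0}, i.e. CUDA
  // compute capability 7.0 / Volta, the first generation with fp16 Tensor
  // Cores) are counted, unless ShouldIgnorePerformance() is true.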
  int num_gpus = ShouldIgnorePerformance() ? GetNumGPUs(*cluster)
                                           : GetNumGPUs(*cluster, kMinGPUArch);
  if (num_gpus < 1) {
    // AutoMixedPrecision is currently only tuned for GPU.
    LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name()
                 << " graph optimizer";
    return Status::OK();
  }

  // Optimize the output graph in-place.
  AutoMixedPrecisionImpl optimizer(cluster, item.NodesToPreserve(), output,
                                   item.id);
  if (item.id == "tf_graph") {
    LOG(INFO) << "Running " << name() << " graph optimizer";
  } else {
    VLOG(1) << "Running " << name() << " graph optimizer on " << item.id;
  }
  Status status = optimizer.Optimize();
  if (!status.ok()) {
    // Restore the original graph.
    *output = item.graph;
    LOG(WARNING) << name() << " graph optimizer FAILED: " << status.ToString();
  }
  return status;
}

void AutoMixedPrecision::Feedback(Cluster* cluster, const GrapplerItem& item,
                                  const GraphDef& optimize_output,
                                  double result) {
  // Nothing to do for AutoMixedPrecision.
}

}  // end namespace grappler
}  // end namespace tensorflow