1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
17
18 #include <fstream>
19 #include <memory>
20
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/op.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/grappler/clusters/cluster.h"
29 #include "tensorflow/core/grappler/costs/virtual_placer.h"
30 #include "tensorflow/core/grappler/devices.h"
31 #include "tensorflow/core/grappler/grappler_item.h"
32 #include "tensorflow/core/grappler/mutable_graph_view.h"
33 #include "tensorflow/core/grappler/op_types.h"
34 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h"
35 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
36 #include "tensorflow/core/grappler/utils.h"
37 #include "tensorflow/core/lib/io/path.h"
38 #include "tensorflow/core/lib/strings/numbers.h"
39 #include "tensorflow/core/lib/strings/str_util.h"
40 #include "tensorflow/core/lib/strings/strcat.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/util/env_var.h"
43
44 namespace tensorflow {
45 namespace grappler {
46 namespace {
47
48 #if GOOGLE_CUDA
49 const std::pair<int, int> kMinGPUArch = {7, 0};
50 #else
51 const std::pair<int, int> kMinGPUArch = {0, 0};
52 #endif
53
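// Name components used for the Cast nodes inserted by this pass and for the
// debug log file names (see BuildCastNode and PrintDebugLogs below).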
54 const char kSuffix[] = "AutoMixedPrecision";
55 const char kCastToFp16[] = "CastToFp16";
56 const char kCastToBf16[] = "CastToBf16";
57 const char kCastToFp32[] = "CastToFp32";
58
59 #if GOOGLE_CUDA
60 // Returns the GPU architecture (compute capability) as a (major, minor) pair.
61 std::pair<int, int> GetDeviceGPUArch(
62 const DeviceProperties& device_properties) {
63 if (device_properties.type() != "GPU") return {0, 0};
64 string arch_str = device_properties.environment().at("architecture");
65 std::vector<string> split_arch_str = str_util::Split(arch_str, '.');
66 if (split_arch_str.empty()) {
67 return {0, 0};
68 }
69
70 int major, minor;
71 if (!strings::safe_strto32(split_arch_str[0], &major)) {
72 return {0, 0};
73 }
74
75 if (split_arch_str.size() > 1) {
76 if (strings::safe_strto32(split_arch_str[1], &minor)) {
77 return {major, minor};
78 } else {
79 return {0, 0};
80 }
81 } else {
82 return {major, 0};
83 }
84 }
85 #endif
86
87 // Returns true if the device has fast FP16 support.
88 // For CUDA, returns true if the GPU architecture (compute capability) is at
89 // least kMinGPUArch. For AMD (ROCm), returns true if the detected GPU's gfx
90 // arch string is in the FP16-capable list. Returns false otherwise.
91 bool HasFastFP16Support(const DeviceProperties& props) {
92 #if GOOGLE_CUDA
93 return GetDeviceGPUArch(props) >= kMinGPUArch;
94 #elif TENSORFLOW_USE_ROCM
95 absl::flat_hash_set<std::string> FP16SupportedDevices = {{"gfx906"},
96 {"gfx908"}};
97 std::string gcnArchName = props.environment().at("architecture");
98 std::vector<std::string> gpu_arch = absl::StrSplit(gcnArchName, ":");
99 return !gpu_arch.empty() && FP16SupportedDevices.contains(gpu_arch[0]);
100 #endif
101 return false;
102 }
103
104 // Instances of this struct represent unique type attribute identifiers within
105 // a node. They handle regular type attributes, list type attributes (where
106 // type_index is set to the index into the type list), and fixed types.
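// Illustrative examples (attr names are hypothetical): an op with a single
// type attr "T" is identified by TypeAttrId("T"); a list(type) attr
// "Tcomponents" yields TypeAttrId("Tcomponents", i) for each list element i;
// an argument with a hard-coded type is identified by, e.g., TypeAttrId(DT_INT32).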
107 struct TypeAttrId {
108 static constexpr int kSingleType = -1;
109
110 explicit TypeAttrId(const string& _attr_name, int _type_index = kSingleType)
111 : attr_name(_attr_name),
112 type_index(_type_index),
113 fixed_type(DT_INVALID) {}
114
115 explicit TypeAttrId(DataType _fixed_type)
116 : attr_name(), type_index(kSingleType), fixed_type(_fixed_type) {}
117
118 bool operator==(const TypeAttrId& other) const {
119 return attr_name == other.attr_name && type_index == other.type_index &&
120 fixed_type == other.fixed_type;
121 }
122
123 bool operator<(const TypeAttrId& other) const {
124 return std::make_tuple(attr_name, type_index, fixed_type) <
125 std::make_tuple(other.attr_name, other.type_index, other.fixed_type);
126 }
127
128 template <typename H>
129 friend H AbslHashValue(H h, const TypeAttrId& ta) {
130 return H::combine(std::move(h), ta.attr_name, ta.type_index, ta.fixed_type);
131 }
132
133 string DebugString() const {
134 if (!attr_name.empty()) {
135 if (type_index == kSingleType) {
136 return attr_name;
137 } else {
138 return strings::StrCat(attr_name, "[", type_index, "]");
139 }
140 } else {
141 return tensorflow::DataTypeString(fixed_type);
142 }
143 }
144
145 string attr_name;
146 // If attr_name is a list(type), this is the index into the list. Otherwise
147 // this is kSingleType.
148 int type_index;
149 DataType fixed_type;
150 };
151
152 // Returns the data type of the given type attribute, or DT_INVALID if the type
153 // attribute is invalid.
154 DataType GetDataType(const NodeDef& node, const TypeAttrId& type_attr) {
155 if (type_attr.attr_name.empty()) {
156 return type_attr.fixed_type;
157 }
158 if (!node.attr().count(type_attr.attr_name)) {
159 return DT_INVALID;
160 }
161 const AttrValue& attr_value = node.attr().at(type_attr.attr_name);
162 if (type_attr.type_index == TypeAttrId::kSingleType) {
163 return attr_value.type();
164 } else {
165 if (type_attr.type_index < 0 ||
166 type_attr.type_index >= attr_value.list().type_size()) {
167 return DT_INVALID;
168 }
169 return attr_value.list().type(type_attr.type_index);
170 }
171 }
172
173 // Sets the data type of the given type attribute. Returns false if the type
174 // attribute is invalid, otherwise true.
175 bool SetDataType(NodeDef* node, const TypeAttrId& type_attr, DataType type) {
176 if (type_attr.attr_name.empty() || !node->attr().count(type_attr.attr_name)) {
177 return false;
178 }
179 AttrValue& attr_value = node->mutable_attr()->at(type_attr.attr_name);
180 if (type_attr.type_index == TypeAttrId::kSingleType) {
181 attr_value.set_type(type);
182 } else {
183 if (type_attr.type_index < 0 ||
184 type_attr.type_index >= attr_value.list().type_size()) {
185 return false;
186 }
187 attr_value.mutable_list()->set_type(type_attr.type_index, type);
188 }
189 return true;
190 }
191
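// Returns one (arg_idx, type_idx) entry per tensor covered by the given
// ArgDef: for a type_list_attr argument there is one entry per element of the
// type list (type_idx is the index into that list); otherwise the entry is
// repeated number_attr times (or once) with type_idx = -1.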
192 std::vector<std::pair<int, int>> ArgDefIndexes(const NodeDef& node, int arg_idx,
193 const OpDef::ArgDef& arg_def) {
194 std::vector<std::pair<int, int>> argdef_inds;
195 if (!arg_def.type_list_attr().empty()) {
196 int num_types = node.attr().at(arg_def.type_list_attr()).list().type_size();
197 for (int type_idx = 0; type_idx < num_types; ++type_idx) {
198 argdef_inds.push_back({arg_idx, type_idx});
199 }
200 } else {
201 int num_repeat = 1;
202 if (node.attr().count(arg_def.number_attr())) {
203 num_repeat = node.attr().at(arg_def.number_attr()).i();
204 }
205 argdef_inds.insert(argdef_inds.end(), num_repeat, {arg_idx, -1});
206 }
207 return argdef_inds;
208 }
209
210 // Returns a pair (arg_index, type_index) for each input to the node, where
211 // arg_index is the index of the input_arg in op_def and type_index is the index
212 // of the type in type_list_attr (only defined for list arguments).
213 std::vector<std::pair<int, int>> InputPortArgDefIndexes(const NodeDef& node,
214 const OpDef& op_def) {
215 std::vector<std::pair<int, int>> argdef_inds;
216 argdef_inds.reserve(op_def.input_arg_size()); // Final size may differ.
217 for (int arg_idx = 0; arg_idx < op_def.input_arg_size(); ++arg_idx) {
218 const OpDef::ArgDef& arg_def = op_def.input_arg(arg_idx);
219 auto arg_results = ArgDefIndexes(node, arg_idx, arg_def);
220 argdef_inds.insert(argdef_inds.end(), arg_results.begin(),
221 arg_results.end());
222 }
223 return argdef_inds;
224 }
225
226 // Returns a pair (arg_index, type_index) for each output to the node, where
227 // arg_index is the index of the output_arg in op_def and type_index is the
228 // index of the type in type_list_attr (only defined for list arguments).
229 std::vector<std::pair<int, int>> OutputPortArgDefIndexes(const NodeDef& node,
230 const OpDef& op_def) {
231 std::vector<std::pair<int, int>> argdef_inds;
232 argdef_inds.reserve(op_def.output_arg_size()); // Final size may differ.
233 for (int arg_idx = 0; arg_idx < op_def.output_arg_size(); ++arg_idx) {
234 const OpDef::ArgDef& arg_def = op_def.output_arg(arg_idx);
235 auto arg_results = ArgDefIndexes(node, arg_idx, arg_def);
236 argdef_inds.insert(argdef_inds.end(), arg_results.begin(),
237 arg_results.end());
238 }
239 return argdef_inds;
240 }
241
242 TypeAttrId GetTypeAttrId(const OpDef::ArgDef& arg_def, int arg_type_index) {
243 if (!arg_def.type_list_attr().empty()) {
244 return TypeAttrId(arg_def.type_list_attr(), arg_type_index);
245 } else if (!arg_def.type_attr().empty()) {
246 return TypeAttrId(arg_def.type_attr());
247 } else {
248 return TypeAttrId(arg_def.type());
249 }
250 }
251
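// Returns the indices of all non-control inputs of the given node.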
252 std::vector<int> NonControlInputs(const NodeDef& node) {
253 std::vector<int> pos;
254 for (int i = 0; i < node.input_size(); i++) {
255 if (!IsControlInput(node.input(i))) {
256 pos.push_back(i);
257 }
258 }
259 return pos;
260 }
261
262 // A utility class to lookup node type attributes and type attribute <->
263 // input/output port mappings.
264 class NodeTypeAttrMap {
265 public:
266 NodeTypeAttrMap() {}
267
268 explicit NodeTypeAttrMap(const GraphDef& graph) { TF_CHECK_OK(Init(graph)); }
269
270 Status Init(const GraphDef& graph) {
271 if (graph_ != nullptr) {
272 return errors::InvalidArgument("NodeTypeAttrMap is already initialized.");
273 }
274 graph_ = &graph;
275 function_library_.reset(
276 new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
277 for (const NodeDef& node : graph.node()) {
278 TF_RETURN_IF_ERROR(AddNode(node));
279 }
280 return Status::OK();
281 }
282
283 bool is_initialized() const { return graph_ != nullptr; }
284
285 // Returns the set of all type attributes in the given node.
286 absl::flat_hash_set<TypeAttrId> GetTypeAttrs(const NodeDef& node) const {
287 DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
288 absl::flat_hash_set<TypeAttrId> type_attrs;
289 const auto iter = type2io_.find(&node);
290 CHECK(iter != type2io_.end()); // Crash Ok
291 for (const auto& key_value : iter->second) {
292 type_attrs.insert(key_value.first);
293 }
294 return type_attrs;
295 }
296
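// Returns the input ports of `node` whose type is governed by `type_attr`.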
297 const absl::flat_hash_set<int>& GetInputPorts(
298 const NodeDef& node, const TypeAttrId& type_attr) const {
299 DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
300 return type2io_.at(&node).at(type_attr).first;
301 }
302
303 const absl::flat_hash_set<int>& GetOutputPorts(
304 const NodeDef& node, const TypeAttrId& type_attr) const {
305 DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
306 return type2io_.at(&node).at(type_attr).second;
307 }
308
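// Returns the type attribute that governs the given input port of `node`.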
309 TypeAttrId GetInputTypeAttr(const NodeDef& node, int port) const {
310 DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
311 auto type_vec = io2type_.at(&node).first;
312 CHECK_GE(port, 0); // Crash Ok
313 CHECK_LT(port, type_vec.size()); // Crash Ok
314 return type_vec[port];
315 }
316
317 TypeAttrId GetOutputTypeAttr(const NodeDef& node, int port) const {
318 DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
319 auto type_vec = io2type_.at(&node).second;
320 CHECK_GE(port, 0); // Crash Ok
321 CHECK_LT(port, type_vec.size()); // Crash Ok
322 return type_vec[port];
323 }
324
325 private:
326 Status AddNode(const NodeDef& node) {
327 const OpDef* op_def_ptr = nullptr;
328 TF_RETURN_IF_ERROR(function_library_->LookUpOpDef(node.op(), &op_def_ptr));
329 const OpDef& op_def = *op_def_ptr;
330 auto& type2io_entry = type2io_[&node];
331 auto& io2type_entry = io2type_[&node];
332 auto input_arg_inds = InputPortArgDefIndexes(node, op_def);
333 if (NonControlInputs(node).size() != input_arg_inds.size()) {
334 return errors::InvalidArgument(
335 "Expected ", node.op(), " node ", node.name(), " to have ",
336 input_arg_inds.size(), " non-control input(s), but got ",
337 NonControlInputs(node).size());
338 }
339 // Note that the mappings generated here include inputs/outputs with fixed
340 // types. This makes the mappings complete (all inputs and outputs are
341 // included), and allows the graph rewriter to propagate deny paint
342 // from/through ops with fixed types.
343 io2type_entry.first.reserve(input_arg_inds.size());
344 for (int i = 0; i < static_cast<int>(input_arg_inds.size()); ++i) {
345 const auto& arg_inds = input_arg_inds[i];
346 const OpDef::ArgDef& arg_def = op_def.input_arg(arg_inds.first);
347 TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second);
348 if (!type_attr.attr_name.empty() &&
349 !node.attr().count(type_attr.attr_name)) {
350 return errors::InvalidArgument("Type attribute ", type_attr.attr_name,
351 " is not present in node ", node.name());
352 }
353 type2io_entry[type_attr].first.insert(i);
354 io2type_entry.first.push_back(type_attr);
355 }
356
357 auto output_arg_inds = OutputPortArgDefIndexes(node, op_def);
358 io2type_entry.second.reserve(output_arg_inds.size());
359 for (int i = 0; i < static_cast<int>(output_arg_inds.size()); ++i) {
360 const auto& arg_inds = output_arg_inds[i];
361 const OpDef::ArgDef& arg_def = op_def.output_arg(arg_inds.first);
362 TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second);
363 if (!type_attr.attr_name.empty() &&
364 !node.attr().count(type_attr.attr_name)) {
365 return errors::InvalidArgument("Type attribute ", type_attr.attr_name,
366 " is not present in node ", node.name());
367 }
368 type2io_entry[type_attr].second.insert(i);
369 io2type_entry.second.push_back(type_attr);
370 }
371
372 // Also ensure that type attributes that aren't associated with any inputs
373 // or outputs (e.g., StackV2's elem_type) are added to the map.
374 for (const auto& attr : node.attr()) {
375 const string& attr_name = attr.first;
376 if (!attr_name.empty() && attr_name[0] == '_') continue;
377 const AttrValue& attr_value = attr.second;
378 const OpDef::AttrDef* attr_def = FindAttr(attr_name, op_def);
379 if (!attr_def) {
380 return errors::InvalidArgument("AttrDef not found for attribute ",
381 attr_name, " of node ", node.name());
382 }
383 if (attr_def->type() == "type") {
384 type2io_entry[TypeAttrId(attr_name)];
385 } else if (attr_def->type() == "list(type)") {
386 for (int i = 0; i < attr_value.list().type_size(); ++i) {
387 type2io_entry[TypeAttrId(attr_name, i)];
388 }
389 }
390 }
391 return Status::OK();
392 }
393
394 // WARN: `graph_` must outlive this object (node pointers must remain valid).
395 const GraphDef* graph_ = nullptr; // do not own
396 std::unique_ptr<FunctionLibraryDefinition> function_library_;
397
398 typedef absl::flat_hash_set<int> IntSet;
399 // Maps a type attr id -> (input port set, output port set)
400 typedef absl::flat_hash_map<TypeAttrId, std::pair<IntSet, IntSet>> Type2IOMap;
401 // Maps a node -> type attr mapping
402 absl::flat_hash_map<const NodeDef*, Type2IOMap> type2io_;
403 // Maps a port -> type attr id
404 typedef std::vector<TypeAttrId> TypeAttrIdVec;
405 // Maps a node -> (input port mapping, output port mapping)
406 absl::flat_hash_map<const NodeDef*, std::pair<TypeAttrIdVec, TypeAttrIdVec>>
407 io2type_;
408 };
409
410 struct NodeTypeId {
411 NodeTypeId(const NodeDef* _node, const TypeAttrId& _type_attr)
412 : node(_node), type_attr(_type_attr) {}
413
414 const NodeDef* node;
415 TypeAttrId type_attr;
416
417 bool operator==(const NodeTypeId& other) const {
418 return node == other.node && type_attr == other.type_attr;
419 }
420
421 template <typename H>
422 friend H AbslHashValue(H h, const NodeTypeId& nt) {
423 return H::combine(std::move(h), nt.node, nt.type_attr);
424 }
425 };
426
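// A directed edge between two (node, type attribute) vertices of the
// type-attr graph defined below.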
427 struct NodeTypeIdEdge {
428 NodeTypeIdEdge(const NodeTypeId& _src, const NodeTypeId& _dst)
429 : src(_src), dst(_dst) {}
430 NodeTypeId src;
431 NodeTypeId dst;
432 };
433
434 // TODO(benbarsdell): Investigate whether the existing GraphTopologyView can be
435 // used instead of this modified version.
436 // This is just like GraphTopologyView but with (NodeDef, TypeAttrId) pairs as
437 // the vertices instead of just NodeDef.
438 // For example, if node A has output A:0 with TypeAttrId 'T', and node B has
439 // input B:0 with TypeAttrId 'U', and input B:0 connects to output A:0, there
440 // will be an edge from (A, T) to (B, U).
441 class GraphTypeTopologyView {
442 public:
443 GraphTypeTopologyView() = default;
444 explicit GraphTypeTopologyView(bool skip_invalid_edges)
445 : skip_invalid_edges_(skip_invalid_edges) {}
446
447 // Initializes the graph topology view from the graph. Additional edges that
448 // do not exist in the graph but must be respected when computing the topology
449 // can be added afterwards via AddEphemeralEdges. Example: the TensorFlow
450 // runtime allows concurrent execution of dequeue/enqueue ops on the same queue
451 // resource, but we might want to enforce ordering between them for analysis.
452 Status InitializeFromGraph(const GraphDef& graph,
453 const NodeTypeAttrMap& node_type_map);
454
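// Adds extra edges that are not present in the underlying graph but should be
// taken into account when computing the topology (e.g., implicit data edges
// through TensorList handles; see FindTensorListImplicitFloat32Edges below).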
455 Status AddEphemeralEdges(absl::Span<const NodeTypeIdEdge> ephemeral_edges);
456
457 bool is_initialized() const { return graph_ != nullptr; }
458 int num_nodes() const { return num_nodes_; }
459 const GraphDef* graph() const { return graph_; }
460
461 // Returns true iff the node exists in the underlying graph.
462 bool HasNode(absl::string_view node_name, const TypeAttrId& type_attr) const;
463
464 // Finds a node by name or returns `nullptr` if it's not in the graph.
465 const NodeTypeId* GetNode(absl::string_view node_name,
466 const TypeAttrId& type_attr) const;
467 // Returns a node corresponding to the given node index.
468 const NodeTypeId* GetNode(int node_idx) const;
469
470 // Returns a node index for the given node name, if the name exists in the
471 // underlying graph. Otherwise returns empty optional.
472 const absl::optional<int> GetNodeIndex(absl::string_view node_name,
473 const TypeAttrId& type_attr) const;
474 // Returns a node index for the given node, if the node belongs to the
475 // underlying graph. Otherwise returns empty optional.
476 const absl::optional<int> GetNodeIndex(const NodeTypeId& node) const;
477
478 // Returns all the node indexes that are in the direct fanin of the given
479 // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
480 const absl::InlinedVector<int, 4>& GetFanin(int node_idx) const;
481 // Returns all the node indexes that are in the direct fanout of the given
482 // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
483 const absl::InlinedVector<int, 2>& GetFanout(int node_idx) const;
484
485 private:
486 // The key type used to uniquely identify a type attribute on a node.
487 struct NodeTypeKey : public std::pair<absl::string_view, TypeAttrId> {
488 typedef std::pair<absl::string_view, TypeAttrId> Base;
489
490 // Inherit the set of constructors.
491 using Base::pair;
492
493 template <typename H>
494 friend H AbslHashValue(H h, const NodeTypeKey& nt) {
495 return H::combine(std::move(h), nt.first, nt.second);
496 }
497 };
498
499 // If true, all invalid edges and inputs (src, dst, or input node not found
500 // in the graph) will be skipped; otherwise initialization fails with an error.
501 bool skip_invalid_edges_ = false;
502
503 // WARN: `graph_` must outlive this object and graph nodes must not be
504 // destructed, because node names are captured with absl::string_view.
505 const GraphDef* graph_ = nullptr; // do not own
506 int num_nodes_ = 0;
507 std::vector<NodeTypeId> node_type_attrs_;
508 absl::flat_hash_map<absl::string_view, int> node_name_to_index_;
509 absl::flat_hash_map<NodeTypeKey, int> node_type_name_to_index_;
510
511 std::vector<absl::InlinedVector<int, 4>> fanins_;
512 std::vector<absl::InlinedVector<int, 2>> fanouts_;
513
514 // We need a valid reference to return from GetFanin/GetFanout if the
515 // `node_idx` argument is outside of the [0, num_nodes_) range.
516 absl::InlinedVector<int, 4> empty_fanin_;
517 absl::InlinedVector<int, 2> empty_fanout_;
518 };
519
520 template <typename T>
521 inline void SortAndRemoveDuplicates(T* v) {
522 std::sort(v->begin(), v->end());
523 v->erase(std::unique(v->begin(), v->end()), v->end());
524 }
525
526 Status GraphTypeTopologyView::InitializeFromGraph(
527 const GraphDef& graph, const NodeTypeAttrMap& node_type_map) {
528 if (graph_ != nullptr) {
529 return errors::InvalidArgument(
530 "GraphTypeTopologyView is already initialized.");
531 }
532
533 graph_ = &graph;
534 int num_nodedefs = graph.node_size();
535 node_name_to_index_.rehash(num_nodedefs);
536
537 // Build maps from name to index.
538 node_type_attrs_.reserve(num_nodedefs); // Only approximate.
539 node_type_name_to_index_.rehash(num_nodedefs); // Only approximate.
540 for (int node_idx = 0; node_idx < num_nodedefs; ++node_idx) {
541 const NodeDef& node = graph.node(node_idx);
542 node_name_to_index_.emplace(node.name(), node_idx);
543
544 for (const TypeAttrId& type_attr : node_type_map.GetTypeAttrs(node)) {
545 int node_type_idx = node_type_attrs_.size();
546 node_type_name_to_index_.emplace(NodeTypeKey(node.name(), type_attr),
547 node_type_idx);
548 node_type_attrs_.emplace_back(&node, type_attr);
549 }
550 }
551 num_nodes_ = node_type_attrs_.size();
552 fanins_.resize(num_nodes_);
553 fanouts_.resize(num_nodes_);
554
555 // Add graph edges to the adjacency lists.
556 for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
557 const NodeTypeId& node_type = node_type_attrs_.at(node_type_idx);
558 auto input_ports =
559 node_type_map.GetInputPorts(*node_type.node, node_type.type_attr);
560 fanins_[node_type_idx].reserve(input_ports.size());
561 for (int port : input_ports) {
562 const string& input = node_type.node->input(port);
563 TensorId tensor = ParseTensorName(input);
564 const auto it = node_name_to_index_.find(tensor.node());
565 const bool valid_input = it != node_name_to_index_.end();
566
567 if (!valid_input) {
568 const string error_message = absl::StrCat(
569 "Non-existent input ", input, " in node ", node_type.node->name());
570 if (skip_invalid_edges_) {
571 VLOG(3) << "Skip error: " << error_message;
572 } else {
573 return errors::InvalidArgument(error_message);
574 }
575 }
576
577 if (valid_input) {
578 const int input_idx = it->second;
579 const NodeDef& input_node = graph_->node(input_idx);
580 TypeAttrId input_type_attr =
581 node_type_map.GetOutputTypeAttr(input_node, tensor.index());
582 const auto it2 = node_type_name_to_index_.find(
583 NodeTypeKey(input_node.name(), input_type_attr));
584 if (it2 == node_type_name_to_index_.end()) {
585 if (!skip_invalid_edges_) {
586 return errors::InvalidArgument("Did not find type attr ",
587 input_type_attr.DebugString(),
588 " in node ", input_node.name());
589 }
590 continue;
591 }
592 int input_node_type_idx = it2->second;
593 fanins_[node_type_idx].push_back(input_node_type_idx);
594 fanouts_[input_node_type_idx].push_back(node_type_idx);
595 }
596 }
597
598 // Dedup the input list while it's still hot in cache.
599 SortAndRemoveDuplicates(&fanins_[node_type_idx]);
600 }
601
602 // Dedup outputs for all the graph nodes.
603 for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
604 SortAndRemoveDuplicates(&fanouts_[node_type_idx]);
605 }
606
607 return Status::OK();
608 }
609
610 Status GraphTypeTopologyView::AddEphemeralEdges(
611 absl::Span<const NodeTypeIdEdge> ephemeral_edges) {
612 // Add ephemeral edges to the adjacency lists.
613 for (const NodeTypeIdEdge& edge : ephemeral_edges) {
614 const auto src = node_name_to_index_.find(edge.src.node->name());
615 const bool valid_src = src != node_name_to_index_.end();
616
617 if (!valid_src) {
618 const string error_message =
619 absl::StrCat("Non-existent src node: ", edge.src.node->name());
620 if (skip_invalid_edges_) {
621 VLOG(0) << "Skip error: " << error_message;
622 } else {
623 return errors::InvalidArgument(error_message);
624 }
625 }
626
627 const auto dst = node_name_to_index_.find(edge.dst.node->name());
628 const bool valid_dst = dst != node_name_to_index_.end();
629
630 if (!valid_dst) {
631 const string error_message =
632 absl::StrCat("Non-existent dst node: ", edge.dst.node->name());
633 if (skip_invalid_edges_) {
634 VLOG(0) << "Skip error: " << error_message;
635 } else {
636 return errors::InvalidArgument(error_message);
637 }
638 }
639
640 if (valid_dst && valid_src) {
641 // TODO(benbarsdell): Check for failure.
642 int src_node_type_idx = node_type_name_to_index_.at(
643 NodeTypeKey(edge.src.node->name(), edge.src.type_attr));
644 int dst_node_type_idx = node_type_name_to_index_.at(
645 NodeTypeKey(edge.dst.node->name(), edge.dst.type_attr));
646 fanins_[dst_node_type_idx].push_back(src_node_type_idx);
647 fanouts_[src_node_type_idx].push_back(dst_node_type_idx);
648 }
649 }
650
651 // Dedup inputs and outputs for all the graph nodes.
652 for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
653 SortAndRemoveDuplicates(&fanins_[node_type_idx]);
654 SortAndRemoveDuplicates(&fanouts_[node_type_idx]);
655 }
656
657 return Status::OK();
658 }
659
660 bool GraphTypeTopologyView::HasNode(absl::string_view node_name,
661 const TypeAttrId& type_attr) const {
662 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
663 NodeTypeKey key(node_name, type_attr);
664 const auto it = node_type_name_to_index_.find(key);
665 return it != node_type_name_to_index_.end();
666 }
667
668 const NodeTypeId* GraphTypeTopologyView::GetNode(
669 absl::string_view node_name, const TypeAttrId& type_attr) const {
670 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
671 NodeTypeKey key(node_name, type_attr);
672 const auto it = node_type_name_to_index_.find(key);
673 return it == node_type_name_to_index_.end()
674 ? nullptr
675 : &node_type_attrs_.at(it->second);
676 }
677
678 const NodeTypeId* GraphTypeTopologyView::GetNode(int node_idx) const {
679 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
680 DCHECK(node_idx >= 0 && node_idx < num_nodes_) << "node_idx is out of range";
681 return &node_type_attrs_.at(node_idx);
682 }
683
684 const absl::optional<int> GraphTypeTopologyView::GetNodeIndex(
685 absl::string_view node_name, const TypeAttrId& type_attr) const {
686 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
687 NodeTypeKey key(node_name, type_attr);
688 const auto it = node_type_name_to_index_.find(key);
689 DCHECK(it != node_type_name_to_index_.end())
690 << "Node doesn't exist in a graph";
691 return it == node_type_name_to_index_.end() ? absl::nullopt
692 : absl::make_optional(it->second);
693 }
694
695 const absl::optional<int> GraphTypeTopologyView::GetNodeIndex(
696 const NodeTypeId& node) const {
697 return GetNodeIndex(node.node->name(), node.type_attr);
698 }
699
700 const absl::InlinedVector<int, 4>& GraphTypeTopologyView::GetFanin(
701 int node_idx) const {
702 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
703 const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
704 DCHECK(is_valid_node_idx) << "node_idx is out of range";
705 return is_valid_node_idx ? fanins_[node_idx] : empty_fanin_;
706 }
707
708 const absl::InlinedVector<int, 2>& GraphTypeTopologyView::GetFanout(
709 int node_idx) const {
710 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
711 const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
712 DCHECK(is_valid_node_idx) << "node_idx is out of range";
713 return is_valid_node_idx ? fanouts_[node_idx] : empty_fanout_;
714 }
715
716 enum class TypeTraversalDirection {
717 kFollowInputs,
718 kFollowOutputs,
719 kFollowInputsAndOutputs,
720 };
721
722 // Encapsulate DFS callbacks that will be called during the graph traversal.
723 //
724 // If non-empty, the `pre_order` and `post_order` functors will be called on
725 // each reachable node (including the `from` nodes) in pre and post order. If
726 // loops are found, the `on_back_edge` functor will be called on the
727 // corresponding back edges. Moreover, the pre and post order will assume that
728 // these back edges will be cut.
729 struct DfsTypeCallbacks {
730 DfsTypeCallbacks() = default;
731 DfsTypeCallbacks(std::function<void(int)> pre, std::function<void(int)> post,
732 std::function<void(int, int)> back_edge)
733 : pre_order(std::move(pre)),
734 post_order(std::move(post)),
735 on_back_edge(std::move(back_edge)) {}
736
737 static DfsTypeCallbacks PreOrder(std::function<void(int)> pre) {
738 return DfsTypeCallbacks(std::move(pre), nullptr, nullptr);
739 }
740
741 static DfsTypeCallbacks PostOrder(std::function<void(int)> post) {
742 return DfsTypeCallbacks(nullptr, std::move(post), nullptr);
743 }
744
745 std::function<void(int)> pre_order;
746 std::function<void(int)> post_order;
747 std::function<void(int, int)> on_back_edge;
748 };
749
750 // Encapsulate DFS predicates for traversing the graph.
751 //
752 // The `enter` predicate decides if traversal should enter the node, and the
753 // `advance` predicate decides if the traversal should follow inputs/outputs
754 // from the node.
755 //
756 // If predicates are empty (default initialized), it's assumed that we can enter
757 // into any node and advance from any node respectively.
758 struct DfsTypePredicates {
759 DfsTypePredicates() = default;
760 DfsTypePredicates(std::function<bool(int)> enter,
761 std::function<bool(int)> advance)
762 : enter(std::move(enter)), advance(std::move(advance)) {}
763
764 static DfsTypePredicates Enter(std::function<bool(int)> enter) {
765 return DfsTypePredicates(std::move(enter), nullptr);
766 }
767
768 static DfsTypePredicates Advance(std::function<bool(int)> advance) {
769 return DfsTypePredicates(nullptr, std::move(advance));
770 }
771
772 std::function<bool(int)> enter;
773 std::function<bool(int)> advance;
774 };
775
776 struct DfsStackElem {
777 DfsStackElem(int node, bool children_visited, int src)
778 : node(node), children_visited(children_visited), src(src) {}
779 explicit DfsStackElem(int node) : DfsStackElem(node, false, -1) {}
780
781 // Index of the node in the graph ∊ [0, num_nodes).
782 int node;
783 // True if all the input/output nodes have already been visited (i.e., pushed
784 // onto the stack).
785 bool children_visited;
786 // Index of the node in the graph, from which we entered the `node`.
787 int src;
788 };
789
790 enum class NodeState { kNotVisited, kVisiting, kDone };
791
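// Iterative DFS over the type-attr graph starting from the `from` vertices,
// following fanins and/or fanouts according to `direction`. The `predicates`
// control which vertices are entered and advanced from, and the `callbacks`
// are invoked in pre/post order; detected back edges (loops) are reported via
// on_back_edge.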
792 void DfsTypeTraversal(const GraphTypeTopologyView& graph_type_view,
793 const absl::Span<const NodeTypeId* const> from,
794 const TypeTraversalDirection direction,
795 const DfsTypePredicates& predicates,
796 const DfsTypeCallbacks& callbacks) {
797 std::vector<DfsStackElem> stack;
798 stack.reserve(from.size());
799
800 for (const NodeTypeId* node : from) {
801 const absl::optional<int> node_idx = graph_type_view.GetNodeIndex(*node);
802 DCHECK(node_idx.has_value())
803 << "Illegal start node: " << node->node->name();
804 if (node_idx.has_value()) {
805 stack.emplace_back(node_idx.value());
806 }
807 }
808
809 absl::flat_hash_map<int, NodeState> node_state;
810 while (!stack.empty()) {
811 DfsStackElem w = stack.back();
812 stack.pop_back();
813
814 NodeState& state = node_state[w.node];
815 if (state == NodeState::kDone) continue;
816
817 // Skip nodes that we should not enter.
818 if (predicates.enter && !predicates.enter(w.node)) {
819 state = NodeState::kDone;
820 continue;
821 }
822
823 // We've processed all the children of this node.
824 if (w.children_visited) {
825 state = NodeState::kDone;
826 if (callbacks.post_order) {
827 callbacks.post_order(w.node);
828 }
829 continue;
830 }
831
832 // Loop detected.
833 if (state == NodeState::kVisiting) {
834 if (callbacks.on_back_edge) {
835 callbacks.on_back_edge(w.src, w.node);
836 }
837 continue;
838 }
839
840 state = NodeState::kVisiting;
841 if (callbacks.pre_order) {
842 callbacks.pre_order(w.node);
843 }
844
845 // Enqueue the node again with the children_visited flag set to true.
846 stack.emplace_back(w.node, true, w.src);
847
848 // Check if we can continue traversal from the current node.
849 if (predicates.advance && !predicates.advance(w.node)) {
850 continue;
851 }
852
853 // Now enqueue the fanin/fanout nodes.
854 if (direction == TypeTraversalDirection::kFollowInputs ||
855 direction == TypeTraversalDirection::kFollowInputsAndOutputs) {
856 for (const int fanin : graph_type_view.GetFanin(w.node)) {
857 stack.emplace_back(fanin, false, w.node);
858 }
859 }
860 if (direction == TypeTraversalDirection::kFollowOutputs ||
861 direction == TypeTraversalDirection::kFollowInputsAndOutputs) {
862 for (const int fanout : graph_type_view.GetFanout(w.node)) {
863 stack.emplace_back(fanout, false, w.node);
864 }
865 }
866 }
867 }
868
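// Returns the set of data types allowed by the given attr definition, or all
// types if the attr definition does not restrict its allowed values.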
869 DataTypeSet AllowedDataTypes(const OpDef::AttrDef& attr_def) {
870 const auto& allowed_types = attr_def.allowed_values().list().type();
871 if (allowed_types.empty()) {
872 return AllTypes();
873 }
874 uint32 dtype_mask = 0;
875 for (int dtype : allowed_types) {
876 dtype_mask |= 1u << dtype;
877 }
878 return DataTypeSet(dtype_mask);
879 }
880
881 DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) {
882 if (t_attr_id.attr_name.empty()) {
883 return ToSet(t_attr_id.fixed_type);
884 }
885 const OpDef::AttrDef* attr_def = FindAttr(t_attr_id.attr_name, op_def);
886 CHECK(attr_def); // Crash Ok
887 return AllowedDataTypes(*attr_def);
888 }
889
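// Checks that no op appears in more than one of the allow/deny/infer/clear
// lists.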
890 Status ValidateLists(const gtl::FlatSet<string>& allow_list,
891 const gtl::FlatSet<string>& deny_list,
892 const gtl::FlatSet<string>& infer_list,
893 const gtl::FlatSet<string>& clear_list) {
894 std::vector<gtl::FlatSet<string>> lists{allow_list, deny_list, infer_list,
895 clear_list};
896 std::multiset<string> counts;
897 for (const auto& list : lists) {
898 counts.insert(list.begin(), list.end());
899 }
900 bool duplicates = false;
901 for (const auto& s : counts) {
902 if (counts.count(s) > 1) {
903 duplicates = true;
904 LOG(ERROR) << "Op present in multiple lists: " << s;
905 }
906 }
907 if (duplicates) {
908 return errors::InvalidArgument("Op lists have conflicting entries");
909 } else {
910 return Status::OK();
911 }
912 }
913
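// Returns true if the node's op has any reference-type inputs or outputs.
// Also returns true (conservatively) if the op is not found in the registry.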
914 bool HasInputOrOutputRefs(const NodeDef& node) {
915 const OpDef* op_def;
916 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
917 if (!status.ok()) {
918 return true;
919 }
920 for (const auto& input : op_def->input_arg()) {
921 if (input.is_ref()) {
922 return true;
923 }
924 }
925 for (const auto& output : op_def->output_arg()) {
926 if (output.is_ref()) {
927 return true;
928 }
929 }
930 return false;
931 }
932
933 // See TF issue 25977 for why FP16 is not forced on SoftmaxCrossEntropyWithLogits.
934 bool CanForceFP16(const NodeDef& node) {
935 return node.op() != "Const" && node.op() != "SoftmaxCrossEntropyWithLogits" &&
936 !IsStateful(node) && !HasInputOrOutputRefs(node);
937 }
938
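// GetCudaVersion/GetCudnnVersion return the CUDA/cuDNN version reported by the
// first GPU device in the cluster that provides one, or 0 otherwise.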
939 int GetCudaVersion(const Cluster& cluster) {
940 auto devices = cluster.GetDevices();
941 for (const auto& device : devices) {
942 const DeviceProperties& device_properties = device.second;
943 if (device_properties.type() == "GPU") {
944 const auto& device_env = device_properties.environment();
945 auto it = device_env.find("cuda");
946 if (it != device_env.end()) {
947 string cuda_version_str = it->second;
948 return std::stoi(cuda_version_str);
949 }
950 }
951 }
952 return 0;
953 }
954
955 int GetCudnnVersion(const Cluster& cluster) {
956 auto devices = cluster.GetDevices();
957 for (const auto& device : devices) {
958 const DeviceProperties& device_properties = device.second;
959 if (device_properties.type() == "GPU") {
960 const auto& device_env = device_properties.environment();
961 auto it = device_env.find("cudnn");
962 if (it != device_env.end()) {
963 string cudnn_version_str = it->second;
964 return std::stoi(cudnn_version_str);
965 }
966 }
967 }
968 return 0;
969 }
970
971 class AutoMixedPrecisionImpl {
972 public:
973 AutoMixedPrecisionImpl(Cluster* cluster,
974 const std::unordered_set<string>& nodes_to_preserve,
975 GraphDef* graph, string id,
976 AutoMixedPrecisionMode mode)
977 : virtual_placer_(cluster->GetDevices()),
978 nodes_to_preserve_(nodes_to_preserve),
979 graph_(graph),
980 function_library_(OpRegistry::Global(), graph->library()),
981 id_(id),
982 graph_view_(graph),
983 cuda_version_(GetCudaVersion(*cluster)),
984 cudnn_version_(GetCudnnVersion(*cluster)),
985 mode_(mode),
986 target_dtype_(mode_ == AutoMixedPrecisionMode::CUDA ? DT_HALF
987 : DT_BFLOAT16) {}
988
989 Status Optimize();
990
991 private:
992 typedef absl::flat_hash_set<NodeTypeId> NodeTypeIdSet;
993
994 std::unique_ptr<AutoMixedPrecisionLists> get_mixed_precision_lists() const {
995 switch (mode_) {
996 case AutoMixedPrecisionMode::CUDA:
997 return std::make_unique<AutoMixedPrecisionListsCuda>(cuda_version_,
998 cudnn_version_);
999 case AutoMixedPrecisionMode::MKL:
1000 return std::make_unique<AutoMixedPrecisionListsMkl>();
1001 }
1002 }
1003 Status PrintDebugLogs(bool preop, size_t timestamp);
1004 void LogSkippedNode(const NodeDef& node) const;
1005 bool MustPreserve(const NodeDef& node) const;
1006 bool IsOnDevice(const NodeDef& node, const string& device_type) const;
1007 bool IsOnSuitableGPUArch(const NodeDef& node) const;
1008 bool ShouldProcess(const NodeDef& node) const;
1009 bool NodeHasF16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const;
1010 bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const;
1011 void ConvertBatchNormOpsToV2();
1012 bool SupportsF16(const NodeTypeId& node_type) const;
1013 bool SupportsF16DataType(const NodeTypeId& node_type) const;
1014 bool IsQuantized(const NodeTypeId& node_type) const;
1015 const NodeTypeId* GetTensorListFloat32NodeTypeId(const NodeDef& node) const;
1016 bool IsSourceOrSinkOp(const string& op) const;
1017 void FindFloat32TensorListOpClustersAndDenylistUnsafe(
1018 std::vector<absl::flat_hash_set<const NodeDef*>>* clusters,
1019 absl::flat_hash_set<int>* deny_set) const;
1020 void FindTensorListImplicitFloat32Edges(
1021 const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1022 std::vector<NodeTypeIdEdge>* implicit_data_edges) const;
1023 void AddAllowlistOps(absl::flat_hash_set<int>* allow_set) const;
1024 void RemoveAllowsetWithFp32(absl::flat_hash_set<int>* allow_set) const;
1025 void PropagateDenyFwdThroughClearAndInfer(
1026 absl::flat_hash_set<int>* deny_set) const;
1027 void ForceColorMatchBetweenTensorListOps(
1028 const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1029 absl::flat_hash_set<int>* allow_set,
1030 absl::flat_hash_set<int>* deny_set) const;
1031 void AddClearAndInferToAllowIfBetweenAllow(
1032 const absl::flat_hash_set<int>& deny_set,
1033 absl::flat_hash_set<int>* allow_set) const;
1034 void PropagateAllowThroughClear(const absl::flat_hash_set<int>& deny_set,
1035 absl::flat_hash_set<int>* allow_set) const;
1036 Status ForceColorMatchOnRecurrentEdges(
1037 absl::flat_hash_set<int>* allow_set) const;
1038 void MakeCastsAllowIfAllOutputsAllow(
1039 absl::flat_hash_set<int>* allow_set) const;
1040 NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, bool to_f16,
1041 const string& device) const;
1042 Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& allow_set);
1043
1044 VirtualPlacer virtual_placer_;
1045 std::unordered_set<string> nodes_to_preserve_;
1046 GraphDef* graph_;
1047 FunctionLibraryDefinition function_library_;
1048 string id_;
1049 MutableGraphView graph_view_;
1050 int cuda_version_;
1051 int cudnn_version_;
1052 NodeTypeAttrMap node_type_map_;
1053 GraphTypeTopologyView graph_type_view_;
1054 bool force_all_fp16_;
1055 AutoMixedPrecisionMode mode_;
1056 gtl::FlatSet<string> f16_allowlist_;
1057 gtl::FlatSet<string> f16_denylist_;
1058 gtl::FlatSet<string> f16_inferlist_;
1059 gtl::FlatSet<string> f16_clearlist_;
1060 absl::flat_hash_set<const NodeDef*> should_process_nodes_;
1061 DataType target_dtype_; // Either DT_HALF or DT_BFLOAT16
1062 };
1063
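// Builds a Cast node on `device` that converts the tensor produced at `src`
// between DT_FLOAT and the target f16 type (DT_HALF or DT_BFLOAT16). The node
// name encodes the source node, output port, and cast direction.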
1064 NodeDef AutoMixedPrecisionImpl::BuildCastNode(
1065 const MutableGraphView::OutputPort& src, bool to_f16,
1066 const string& device) const {
1067 DataType src_type = to_f16 ? DT_FLOAT : target_dtype_;
1068 DataType dst_type = to_f16 ? target_dtype_ : DT_FLOAT;
1069 const char* cast_string = !to_f16 ? kCastToFp32
1070 : target_dtype_ == DT_HALF ? kCastToFp16
1071 : kCastToBf16;
1072 string name = strings::StrCat(src.node->name(), "-", src.port_id, "-",
1073 cast_string, "-", kSuffix);
1074 NodeDef node;
1075 node.set_name(name);
1076 node.set_op("Cast");
1077 node.set_device(device);
1078 node.add_input(strings::StrCat(src.node->name(), ":", src.port_id));
1079 (*node.mutable_attr())["SrcT"].set_type(src_type);
1080 (*node.mutable_attr())["DstT"].set_type(dst_type);
1081 (*node.mutable_attr())["Truncate"].set_b(false);
1082 return node;
1083 }
1084
1085 bool AutoMixedPrecisionImpl::NodeHasF16KernelForTypeAttr(
1086 const NodeDef& node, TypeAttrId taid) const {
1087 NodeDef node_copy(node);
1088 if (node.device().empty()) {
1089 string device_name = virtual_placer_.get_canonical_device_name(node);
1090 node_copy.set_device(device_name);
1091 }
1092 if (!SetDataType(&node_copy, taid, target_dtype_)) {
1093 return false;
1094 }
1095 return IsKernelRegisteredForNode(node_copy).ok();
1096 }
1097
1098 Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) {
1099 string prepend_path;
1100 TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
1101 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LOG_PATH", "", &prepend_path));
1102 if (prepend_path.empty()) return Status::OK();
1103
1104 string suffix =
1105 strings::StrCat("_", preop ? "preop" : kSuffix, "_", id_, "_", timestamp);
1106
1107 string fname =
1108 io::JoinPath(prepend_path, strings::StrCat("graphdef", suffix, ".pb"));
1109 std::fstream f;
1110 f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
1111 f << graph_->SerializeAsString();
1112 f.close();
1113 LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization")
1114 << " graph as binary to " << fname;
1115
1116 fname = io::JoinPath(prepend_path,
1117 strings::StrCat("graphdef", suffix, ".pb.txt"));
1118 f.open(fname.c_str(), std::fstream::out);
1119 f << graph_->DebugString();
1120 f.close();
1121 LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization")
1122 << " graph as text to " << fname;
1123
1124 if (!preop) {
1125 fname = io::JoinPath(prepend_path,
1126 strings::StrCat("paintbuckets", suffix, ".txt"));
1127 f.open(fname.c_str(), std::fstream::out);
1128 std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
1129 get_mixed_precision_lists();
1130 f << "AllowList:\n";
1131 for (const auto& x : mp_lists->AllowList()) {
1132 f << x << "\n";
1133 }
1134 f << "\nDenyList:\n";
1135 for (const auto& x : mp_lists->DenyList()) {
1136 f << x << "\n";
1137 }
1138 f << "\nInferList:\n";
1139 for (const auto& x : mp_lists->InferList()) {
1140 f << x << "\n";
1141 }
1142 f << "\nClearList:\n";
1143 for (const auto& x : mp_lists->ClearList()) {
1144 f << x << "\n";
1145 }
1146 f.close();
1147 LOG(INFO) << "Saved paint bucket info to " << fname;
1148 }
1149 return Status::OK();
1150 }
1151
1152 void AutoMixedPrecisionImpl::LogSkippedNode(const NodeDef& node) const {
1153 VLOG(2) << "Skipping " << node.op() << " node " << node.name()
1154 << " because it "
1155 << (MustPreserve(node)
1156 ? "must be preserved"
1157 : "is not on the GPU, or the GPU arch is not suitable");
1158 }
1159
1160 bool AutoMixedPrecisionImpl::MustPreserve(const NodeDef& node) const {
1161 return nodes_to_preserve_.count(node.name());
1162 }
1163
1164 bool AutoMixedPrecisionImpl::IsOnDevice(const NodeDef& node,
1165 const string& device_type) const {
1166 string device_name;
1167 if (node.device().empty()) {
1168 device_name = virtual_placer_.get_canonical_device_name(node);
1169 } else {
1170 device_name = node.device();
1171 }
1172 string device;
1173 string not_used;
1174 if (DeviceNameUtils::SplitDeviceName(device_name, &not_used, &device) &&
1175 absl::StrContains(absl::AsciiStrToLower(device),
1176 absl::AsciiStrToLower(device_type))) {
1177 return true;
1178 }
1179 return false;
1180 }
1181
1182 bool AutoMixedPrecisionImpl::IsOnSuitableGPUArch(const NodeDef& node) const {
1183 return HasFastFP16Support(virtual_placer_.get_device(node));
1184 }
1185
1186 bool AutoMixedPrecisionImpl::ShouldProcess(const NodeDef& node) const {
1187 return should_process_nodes_.count(&node);
1188 }
1189
1190 bool IsFloat32(const NodeTypeId& node_type) {
1191 return GetDataType(*node_type.node, node_type.type_attr) ==
1192 DataType::DT_FLOAT;
1193 }
1194
1195 bool IsTensorListOp(const string& op) {
1196 return op.find("TensorList") != string::npos;
1197 }
1198
1199 bool IsTensorListReaderOp(const string& op) {
1200 static const gtl::FlatSet<string> tensor_list_reader_ops = {
1201 "TensorListConcat", "TensorListConcatV2", "TensorListGather",
1202 "TensorListGetItem", "TensorListPopBack", "TensorListStack"};
1203 return tensor_list_reader_ops.count(op);
1204 }
1205
1206 bool IsTensorListWriterOp(const string& op) {
1207 static const gtl::FlatSet<string> tensor_list_writer_ops = {
1208 "TensorListFromTensor", "TensorListPushBack",
1209 "TensorListPushBackBatch", "TensorListScatter",
1210 "TensorListScatterV2", "TensorListScatterIntoExistingList",
1211 "TensorListSetItem", "TensorListSplit"};
1212 return tensor_list_writer_ops.count(op);
1213 }
1214
1215 bool AutoMixedPrecisionImpl::SupportsF16(const NodeTypeId& node_type) const {
1216 const OpDef* op_def;
1217 Status status =
1218 OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
1219 if (!status.ok()) return false;
1220 return AllowedDataTypes(*op_def, node_type.type_attr)
1221 .Contains(target_dtype_) &&
1222 NodeHasF16KernelForTypeAttr(*node_type.node, node_type.type_attr);
1223 }
1224
1225 bool AutoMixedPrecisionImpl::SupportsF16DataType(
1226 const NodeTypeId& node_type) const {
1227 const OpDef* op_def;
1228 Status status =
1229 OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
1230 if (!status.ok()) return false;
1231 return AllowedDataTypes(*op_def, node_type.type_attr).Contains(target_dtype_);
1232 }
1233
1234 bool AutoMixedPrecisionImpl::IsQuantized(const NodeTypeId& node_type) const {
1235 for (const TypeAttrId& type_attr :
1236 node_type_map_.GetTypeAttrs(*node_type.node)) {
1237 if (DataTypeIsQuantized(GetDataType(*node_type.node, type_attr))) {
1238 return true;
1239 }
1240 }
1241 return false;
1242 }
1243
1244 // TODO(mconley): Make this change the node's name (to aid debugging). Need to
1245 // make sure that doing this won't break anything.
1246 void AutoMixedPrecisionImpl::ConvertBatchNormOpsToV2() {
1247 for (int node_idx = 0; node_idx < graph_->node_size(); ++node_idx) {
1248 NodeDef* node = graph_->mutable_node(node_idx);
1249 if (!ShouldProcess(*node)) continue;
1250 bool changed = false;
1251 if (node->op() == "FusedBatchNorm") {
1252 VLOG(2) << "Changing op of " << node->op() << " node " << node->name()
1253 << " to FusedBatchNormV2";
1254 node->set_op("FusedBatchNormV2");
1255 changed = true;
1256 } else if (node->op() == "FusedBatchNormGrad") {
1257 VLOG(2) << "Changing op of " << node->op() << " node " << node->name()
1258 << " to FusedBatchNormGradV2";
1259 node->set_op("FusedBatchNormGradV2");
1260 changed = true;
1261 }
1262 if (changed) {
1263 (*node->mutable_attr())["U"].set_type(DT_FLOAT);
1264 }
1265 }
1266 }
1267
1268 // A helper function to decide whether to ignore the effect on performance when
1269 // rewriting the graph. This can be useful for testing the numerical effects of
1270 // reduced precision on systems that have poor mixed precision performance.
1271 bool ShouldIgnorePerformance() {
1272 static bool is_enabled = [] {
1273 bool ret = false;
1274 TF_CHECK_OK(ReadBoolFromEnvVar(
1275 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE",
1276 /*default_val=*/false, &ret));
1277 return ret;
1278 }();
1279 return is_enabled;
1280 }
1281
1282 Status AutoMixedPrecisionImpl::Optimize() {
1283 string optimization_level;
1284 TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
1285 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level));
1286 optimization_level = absl::AsciiStrToUpper(optimization_level);
1287 force_all_fp16_ = optimization_level == "UNSAFE_FORCE_ALL";
1288 if (force_all_fp16_ && mode_ == AutoMixedPrecisionMode::MKL) {
1289 // Many ops do not support bfloat16 on the CPU, so we disallow forcing the
1290 // graph to bfloat16.
1291 return errors::InvalidArgument(
1292 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL cannot be set to "
1293 "UNSAFE_FORCE_ALL when MKL is used");
1294 }

  std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
      get_mixed_precision_lists();
  f16_allowlist_ = mp_lists->AllowList();
  f16_denylist_ = mp_lists->DenyList();
  f16_inferlist_ = mp_lists->InferList();
  f16_clearlist_ = mp_lists->ClearList();
  TF_RETURN_IF_ERROR(ValidateLists(f16_allowlist_, f16_denylist_,
                                   f16_inferlist_, f16_clearlist_));

  size_t timestamp = Env::Default()->NowMicros() / 1000;
  TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp));

  VLOG(2) << "Identifying nodes that should be processed";
  for (const NodeDef& node : graph_->node()) {
    bool should_process;
    switch (mode_) {
      case AutoMixedPrecisionMode::CUDA:
        should_process =
            !MustPreserve(node) && IsOnDevice(node, DEVICE_GPU) &&
            (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node));
        break;
      case AutoMixedPrecisionMode::MKL:
        should_process = !MustPreserve(node) && IsOnDevice(node, DEVICE_CPU);
        break;
    }
    if (should_process) {
      should_process_nodes_.insert(&node);
    } else {
      LogSkippedNode(node);
    }
  }

  VLOG(2) << "Converting FusedBatchNorm* ops to V2";
  ConvertBatchNormOpsToV2();

  VLOG(2) << "Building node type map for graph";
  TF_RETURN_IF_ERROR(node_type_map_.Init(*graph_));

  VLOG(2) << "Constructing graph type attribute topology view";
  TF_RETURN_IF_ERROR(
      graph_type_view_.InitializeFromGraph(*graph_, node_type_map_));

  absl::flat_hash_set<int> deny_set;

  std::vector<absl::flat_hash_set<const NodeDef*>> tensor_list_clusters;
  FindFloat32TensorListOpClustersAndDenylistUnsafe(&tensor_list_clusters,
                                                   &deny_set);
  std::vector<NodeTypeIdEdge> ephemeral_edges;
  for (const auto& cluster : tensor_list_clusters) {
    VLOG(1) << "Found safe Tensor List cluster of size " << cluster.size();
    for (const NodeDef* node : cluster) {
      VLOG(2) << " Cluster member: " << node->op() << " node " << node->name();
    }
    FindTensorListImplicitFloat32Edges(cluster, &ephemeral_edges);
  }
  TF_RETURN_IF_ERROR(graph_type_view_.AddEphemeralEdges(ephemeral_edges));

  // The goal here is to change performance-critical ops to fp16 or bf16, and to
  // do so with the minimal number of casts, subject to the constraint that the
  // model's convergence is not affected. This is achieved by first identifying
  // which nodes should be changed to f16 and then inserting casts at the
  // boundaries between f16/non-f16 nodes.

  // The algorithm for deciding which nodes to change to f16 is as follows:
  // 1) Add all performance-critical ops (aka "allowlist" ops) to the allow_set.
  //    This is done under the assumption that allowlist ops are always
  //    numerically-safe in f16 and that they are the most important ops for
  //    improving performance.
  // 2) Add nodes to the deny_set iff they are numerically-dangerous (aka
  //    "denylist" ops) or they are on a forward path from a denylist node to
  //    a deny/infer node (including the node at the end of the path) through
  //    non-numerically-dangerous ops (aka "inferlist" and "clearlist" ops).
  //    This is done to prevent numerically-dangerous ops and their downstream
  //    effects from being changed to f16, which would risk breaking the
  //    numerical accuracy of the model.
  // 3) For all remaining nodes that are not considered dangerous (inferlist
  //    and clearlist ops), find those that are between (i.e., both upstream
  //    and downstream of) allow nodes, and add them to the allow_set.
  //    This is done to avoid unnecessary casts between allowlist ops.
  // 4) For all remaining clearlist nodes, add them to the allow_set if they are
  //    connected to a node in the allow_set via other clearlist nodes.
  //    This is done to increase the number of ops in the allow_set without
  //    affecting numerical stability.
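  //
  // As a worked example (assuming MatMul is on the allowlist, Relu on the
  // clearlist, Exp on the denylist, and Add on the inferlist), the chain
  //   MatMul -> Relu -> MatMul -> Exp -> Add
  // is painted as follows: both MatMuls become ALLOW in step 1, Exp and Add
  // become DENY in step 2, and the Relu between the two MatMuls becomes ALLOW
  // in step 3, so only a single f16->f32 Cast is needed after the second
  // MatMul.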

  absl::flat_hash_set<int> allow_set;
  VLOG(2) << "Beginning pass 1 to add allowlist ops";
  AddAllowlistOps(&allow_set);
  VLOG(2) << "Finished pass 1";

  if (allow_set.empty()) {
    LOG(INFO) << "No allowlist ops found, nothing to do";
    return Status::OK();
  }

  VLOG(2) << "Beginning pass 2 to propagate deny forwards from denylist ops "
             "through clear/inferlist ops";
  PropagateDenyFwdThroughClearAndInfer(&deny_set);
  VLOG(2) << "Finished pass 2";

  VLOG(2) << "Forcing color match between data structure ops";
  for (const auto& cluster : tensor_list_clusters) {
    ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
  }

  VLOG(2) << "Beginning pass 3 to set clear and infer nodes to allow if they "
             "are between allow ops";
  AddClearAndInferToAllowIfBetweenAllow(deny_set, &allow_set);
  VLOG(2) << "Finished pass 3";

  VLOG(2) << "Beginning pass 4 to propagate allow from allow nodes through "
             "clearlist ops";
  PropagateAllowThroughClear(deny_set, &allow_set);
  VLOG(2) << "Finished pass 4";
  VLOG(2) << "Beginning pass 5 to remove nodes that cannot be changed to F16 "
             "from the allow set";
  RemoveAllowsetWithFp32(&allow_set);
  VLOG(2) << "Finished pass 5";

  VLOG(2) << "Forcing color match between data structure ops";
  for (const auto& cluster : tensor_list_clusters) {
    ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
  }

  VLOG(2) << "Forcing color match on loop edges";
  TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&allow_set));

  VLOG(2) << "Finding existing casts that can be made allow";
  MakeCastsAllowIfAllOutputsAllow(&allow_set);

  VLOG(2) << "Beginning final pass to change type attributes and insert Cast "
             "ops at paint boundaries";
  TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(allow_set));
  VLOG(2) << "Finished final pass";

  TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ false, timestamp));

  return Status::OK();
}

// If node is a Tensor List op with a float32 data type attribute then this
// returns a pointer to the NodeTypeId representing that type attribute. In
// all other cases this returns nullptr.
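// For example, for a typical Tensor List op such as TensorListPushBack this is
// expected to be the non-fixed "element_dtype" type attribute when it resolves
// to DT_FLOAT.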
const NodeTypeId* AutoMixedPrecisionImpl::GetTensorListFloat32NodeTypeId(
    const NodeDef& node) const {
  if (!IsTensorListOp(node.op())) return nullptr;
  for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(node)) {
    const NodeTypeId* node_type =
        graph_type_view_.GetNode(node.name(), type_attr);
    // This assumes that the float32 data type on a Tensor List op is always a
    // non-fixed type attribute containing a single type, and that this type
    // attribute represents the dtype of the values in the list.
    // TODO(benbarsdell): A new Tensor List op could theoretically break these
    // assumptions.
    if (node_type && node_type->type_attr.fixed_type == DT_INVALID &&
        node_type->type_attr.type_index == TypeAttrId::kSingleType &&
        IsFloat32(*node_type)) {
      return node_type;
    }
  }
  return nullptr;
}

bool AutoMixedPrecisionImpl::IsSourceOrSinkOp(const string& op) const {
  const gtl::FlatSet<string> source_and_sink_ops = {
      "_Arg",
      "_Retval",
      "OptionalFromValue",
      "OptionalGetValue",
      "PartitionedCall",
      "Placeholder",
      "StatefulPartitionedCall",
  };
  return source_and_sink_ops.count(op) || function_library_.Find(op);
}

// Finds all clusters of float32 Tensor List nodes that are connected via their
// handle edges. Unsafe clusters (those with unprocessable nodes, or with edges
// that cross untraversable boundaries via _Arg, _Retval, PartitionedCall etc.
// nodes) are added to deny_set. The caller should paint all nodes in a cluster
// the same color, as they may all refer to the same Tensor List.
void AutoMixedPrecisionImpl::FindFloat32TensorListOpClustersAndDenylistUnsafe(
    std::vector<absl::flat_hash_set<const NodeDef*>>* tensor_list_clusters,
    absl::flat_hash_set<int>* deny_set) const {
  absl::flat_hash_set<const NodeDef*> tensor_list_prop_set;
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (!ShouldProcess(*root.node) ||
        root.type_attr.fixed_type != DataType::DT_VARIANT ||
        !GetTensorListFloat32NodeTypeId(*root.node) ||
        tensor_list_prop_set.count(root.node)) {
      continue;
    }
    const NodeTypeId* root_fp32 = GetTensorListFloat32NodeTypeId(*root.node);
    const absl::optional<int> maybe_root_fp32_idx =
        graph_type_view_.GetNodeIndex(*root_fp32);
    DCHECK(maybe_root_fp32_idx.has_value())
        << "Type attribute " << root_fp32->type_attr.DebugString()
        << " of node " << root.node->name() << " not found in graph view";
    int root_fp32_idx = maybe_root_fp32_idx.value();
    // Traverse Tensor List handle edges (DT_VARIANT) to find cluster of all
    // connected Tensor List nodes.
    absl::flat_hash_set<const NodeDef*> cluster({root.node});
    DfsTypeTraversal(graph_type_view_, {&root},
                     TypeTraversalDirection::kFollowInputsAndOutputs,
                     DfsTypePredicates::Enter([&](int idx) -> bool {
                       const NodeTypeId& item = *graph_type_view_.GetNode(idx);
                       return !tensor_list_prop_set.count(item.node);
                     }),
                     DfsTypeCallbacks::PreOrder([&](int idx) {
                       const NodeTypeId& item = *graph_type_view_.GetNode(idx);
                       const NodeDef* node = item.node;
                       if (GetTensorListFloat32NodeTypeId(*node)) {
                         cluster.insert(node);
                         if (!ShouldProcess(*node)) {
                           // The cluster contains an un-processable node.
                           deny_set->insert(root_fp32_idx);
                         }
                         // TODO(benbarsdell): In a theoretical pathological
                         // case of a Tensor List of Tensor List handles, the
                         // Tensor List itself would need to be treated as a
                         // sink.
                       } else if (IsSourceOrSinkOp(node->op())) {
                         // The cluster crosses an untraversable boundary.
                         deny_set->insert(root_fp32_idx);
                       }
                     }));
    tensor_list_clusters->push_back(cluster);
  }
}

// Finds all writer -> reader pairs in the given set that are connected via
// their handles, and adds corresponding float32 edges to *implicit_fp32_edges.
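// For example, if a TensorListPushBack (writer) and a TensorListStack (reader)
// are connected through the list's DT_VARIANT handle, an ephemeral float32
// edge is added from the writer's element type attribute to the reader's, so
// that the deny/allow propagation passes see the implicit data path between
// them.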
void AutoMixedPrecisionImpl::FindTensorListImplicitFloat32Edges(
    const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
    std::vector<NodeTypeIdEdge>* implicit_fp32_edges) const {
  for (const NodeDef* root_node : tensor_list_nodes) {
    if (!IsTensorListReaderOp(root_node->op())) continue;
    NodeTypeId root(root_node, TypeAttrId(DataType::DT_VARIANT));
    const NodeTypeId* root_fp32 = GetTensorListFloat32NodeTypeId(*root.node);
    CHECK(root_fp32) << "No float32 type attribute found for "  // Crash OK
                     << root.node->op() << " node " << root.node->name();
    // Search backwards through handle edges (DT_VARIANT) for all writer ops,
    // adding direct implicit edges between them and the reader.
    DfsTypeTraversal(
        graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
        DfsTypePredicates::Enter([&](int idx) -> bool {
          const NodeTypeId& item = *graph_type_view_.GetNode(idx);
          return ShouldProcess(*item.node);
        }),
        DfsTypeCallbacks::PreOrder([&](int idx) {
          const NodeTypeId& item = *graph_type_view_.GetNode(idx);
          if (IsTensorListWriterOp(item.node->op())) {
            const NodeTypeId* item_fp32 =
                GetTensorListFloat32NodeTypeId(*item.node);
            CHECK(item_fp32)  // Crash OK
                << "No float32 type attribute found for " << item.node->op()
                << " node " << item.node->name();
            VLOG(2) << "Adding ephemeral float32 edge from "
                    << item_fp32->node->op() << " node "
                    << item_fp32->node->name() << " to "
                    << root_fp32->node->op() << " node "
                    << root_fp32->node->name();
            implicit_fp32_edges->emplace_back(*item_fp32, *root_fp32);
          }
        }));
  }
}

void AutoMixedPrecisionImpl::AddAllowlistOps(
    absl::flat_hash_set<int>* allow_set) const {
  // Add allowlisted ops to allow_set.
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (!ShouldProcess(*root.node)) continue;
    bool force_allow = force_all_fp16_ && CanForceFP16(*root.node);
    if (f16_allowlist_.count(root.node->op()) || force_allow) {
      bool inserted = allow_set->insert(root_idx).second;
      if (VLOG_IS_ON(2) && inserted) {
        VLOG(2) << "Painting type " << root.type_attr.DebugString()
                << " of node " << root.node->name() << " ALLOW because its op "
                << root.node->op() << " is on the allowlist";
      }
    }
  }
}

// Adds nodes to deny_set iff they are on the denylist or they are on a
// forward path from a denylist node to a deny/infer node (including the node
// at the end of the path) through clear and infer nodes.
// E.g., deny -> infer -> clear -> infer -> clear -> allow -> infer
// becomes: deny -> deny -> deny -> deny -> clear -> allow -> infer.
void AutoMixedPrecisionImpl::PropagateDenyFwdThroughClearAndInfer(
    absl::flat_hash_set<int>* deny_set) const {
  if (force_all_fp16_) return;

  // Find clear nodes that are upstream of deny or infer.
  absl::flat_hash_set<int> upstream_of_deny_or_infer_set;
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (!(f16_denylist_.count(root.node->op()) ||
          f16_inferlist_.count(root.node->op()))) {
      continue;
    }
    DfsTypeTraversal(graph_type_view_, {&root},
                     TypeTraversalDirection::kFollowInputs,
                     DfsTypePredicates::Enter([&](int idx) -> bool {
                       const NodeTypeId& item = *graph_type_view_.GetNode(idx);
                       return idx == root_idx ||
                              (!upstream_of_deny_or_infer_set.count(idx) &&
                               f16_clearlist_.count(item.node->op()));
                     }),
                     DfsTypeCallbacks::PreOrder([&](int idx) {
                       upstream_of_deny_or_infer_set.insert(idx);
                     }));
  }

  // Propagate deny forward through nodes in upstream_of_deny_or_infer_set.
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (deny_set->count(root_idx) || !f16_denylist_.count(root.node->op())) {
      continue;
    }
    DfsTypeTraversal(
        graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
        DfsTypePredicates::Enter([&](int idx) -> bool {
          return idx == root_idx || (!deny_set->count(idx) &&
                                     upstream_of_deny_or_infer_set.count(idx));
        }),
        DfsTypeCallbacks::PreOrder([&](int idx) {
          bool inserted = deny_set->insert(idx).second;
          if (VLOG_IS_ON(2) && inserted) {
            const NodeTypeId& item = *graph_type_view_.GetNode(idx);
            VLOG(2) << "Painting type " << item.type_attr.DebugString()
                    << " of " << item.node->op() << " node "
                    << item.node->name() << " DENY";
          }
        }));
  }
}

void AutoMixedPrecisionImpl::AddClearAndInferToAllowIfBetweenAllow(
    const absl::flat_hash_set<int>& deny_set,
    absl::flat_hash_set<int>* allow_set) const {
  // Find clear/inferlist ops that are downstream of allow ops.
  absl::flat_hash_set<int> downstream_of_allow_set;
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (!ShouldProcess(*root.node) || !f16_allowlist_.count(root.node->op())) {
      continue;
    }
    DfsTypeTraversal(
        graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
        DfsTypePredicates::Enter([&](int idx) -> bool {
          const NodeTypeId& item = *graph_type_view_.GetNode(idx);
          return idx == root_idx ||
                 (!downstream_of_allow_set.count(idx) &&
                  !f16_allowlist_.count(item.node->op()) &&
                  !deny_set.count(idx) && ShouldProcess(*item.node) &&
                  // TODO(benbarsdell): Consider allowing propagation through
                  // ops that are already float16 in order to reduce the number
                  // of casts.
                  IsFloat32(item) && SupportsF16(item) &&
                  (f16_clearlist_.count(item.node->op()) ||
                   f16_inferlist_.count(item.node->op())));
        }),
        DfsTypeCallbacks::PreOrder(
            [&](int idx) { downstream_of_allow_set.insert(idx); }));
  }

  // Set nodes that are both downstream and upstream of allow ops to allow.
  absl::flat_hash_set<int> upstream_of_allow_set;
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (!ShouldProcess(*root.node) || upstream_of_allow_set.count(root_idx) ||
        !f16_allowlist_.count(root.node->op())) {
      continue;
    }
    DfsTypeTraversal(
        graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
        DfsTypePredicates::Enter([&](int idx) -> bool {
          return idx == root_idx || (!upstream_of_allow_set.count(idx) &&
                                     downstream_of_allow_set.count(idx));
        }),
        DfsTypeCallbacks::PreOrder([&](int idx) {
          upstream_of_allow_set.insert(idx);
          bool inserted = allow_set->insert(idx).second;
          if (VLOG_IS_ON(2) && inserted) {
            const NodeTypeId& item = *graph_type_view_.GetNode(idx);
            VLOG(2) << "Painting type " << item.type_attr.DebugString()
                    << " of " << item.node->op() << " node "
                    << item.node->name() << " ALLOW";
          }
        }));
  }
}

void AutoMixedPrecisionImpl::PropagateAllowThroughClear(
    const absl::flat_hash_set<int>& deny_set,
    absl::flat_hash_set<int>* allow_set) const {
  // Propagate allow from allow nodes through clearlist ops.
  absl::flat_hash_set<int> clear_prop_set;
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (!ShouldProcess(*root.node) || clear_prop_set.count(root_idx) ||
        !allow_set->count(root_idx)) {
      continue;
    }
    DfsTypeTraversal(
        graph_type_view_, {&root},
        TypeTraversalDirection::kFollowInputsAndOutputs,
        DfsTypePredicates::Enter([&](int idx) -> bool {
          const NodeTypeId& item = *graph_type_view_.GetNode(idx);
          return idx == root_idx ||
                 (!allow_set->count(idx) && !deny_set.count(idx) &&
                  ShouldProcess(*item.node) && IsFloat32(item) &&
                  SupportsF16(item) &&
                  (f16_clearlist_.count(item.node->op())) &&
                  // We don't propagate (backwards) through nodes that read
                  // Variables because it can break the behavior of TensorBoard
                  // visualization and/or (in the case of Enter nodes) the model
                  // itself. This is only a problem for non-resource variables.
                  !NodeImplicitlyReadsNonResourceVariable(*item.node));
        }),
        DfsTypeCallbacks::PreOrder([&](int idx) {
          clear_prop_set.insert(idx);
          bool inserted = allow_set->insert(idx).second;
          if (VLOG_IS_ON(2) && inserted) {
            const NodeTypeId& item = *graph_type_view_.GetNode(idx);
            VLOG(2) << "Painting type " << item.type_attr.DebugString()
                    << " of " << item.node->op() << " node "
                    << item.node->name() << " ALLOW";
          }
        }));
  }
}

// Removes nodes from allow_set if they have a type attribute that cannot be
// converted to F16. For example, FusedBatchNormV2/FusedBatchNormV3 have a type
// attribute 'U' that only supports float, so such nodes are removed from
// allow_set. Quantized ops are also not converted to FP16.
void AutoMixedPrecisionImpl::RemoveAllowsetWithFp32(
    absl::flat_hash_set<int>* allow_set) const {
  for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
    const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
    if (f16_allowlist_.count(root.node->op()) && allow_set->count(root_idx) &&
        (!SupportsF16DataType(root) || IsQuantized(root))) {
      auto erased = allow_set->erase(root_idx);
      if (VLOG_IS_ON(2) && erased) {
        VLOG(2) << "Unpainting type " << root.type_attr.DebugString()
                << " of node " << root.node->name() << " ALLOW because its op "
                << root.node->op() << " does not support the F16 data type";
      }
    }
  }
}

// Forces NextIteration nodes and their output Merge node(s) to have the same
// color. Specifically, it removes them all from allow_set if any of the Merge
// nodes is not in allow_set, otherwise it adds the NextIteration node to
// allow_set.
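// This keeps the loop-carried back edge between NextIteration and Merge
// type-consistent: if, for example, a Merge stayed float32 while its
// NextIteration input were painted f16 (or vice versa), the back edge would
// connect mismatched types, so both ends are forced to the same color instead.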
Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
    absl::flat_hash_set<int>* allow_set) const {
  for (const NodeDef& node : graph_->node()) {
    if (node.op() == "NextIteration") {
      GraphView::OutputPort output_port(&node, 0);
      const auto& fanout = graph_view_.GetFanout(output_port);
      std::vector<int> merge_idxs;
      merge_idxs.reserve(fanout.size());
      bool any_merge_is_not_allow = false;
      for (const auto& output : fanout) {
        const NodeDef& merge_node = *output.node;
        if (merge_node.op() != "Merge") {
          return errors::FailedPrecondition(
              "Expected Merge node after NextIteration, got ", merge_node.op());
        }
        const absl::optional<int> maybe_merge_idx =
            graph_type_view_.GetNodeIndex(merge_node.name(), TypeAttrId("T"));
        if (!maybe_merge_idx.has_value()) {
          return errors::Internal("Type attribute T of Merge node ",
                                  merge_node.name(),
                                  " not found in graph view");
        }
        int merge_idx = maybe_merge_idx.value();
        merge_idxs.push_back(merge_idx);
        any_merge_is_not_allow =
            any_merge_is_not_allow || !allow_set->count(merge_idx);
      }
      const absl::optional<int> maybe_nextiter_idx =
          graph_type_view_.GetNodeIndex(node.name(), TypeAttrId("T"));
      if (!maybe_nextiter_idx.has_value()) {
        return errors::Internal("Type attribute T of NextIteration node ",
                                node.name(), " not found in graph view");
      }
      int nextiter_idx = maybe_nextiter_idx.value();
      if (any_merge_is_not_allow) {
        for (int merge_idx : merge_idxs) {
          if (allow_set->erase(merge_idx)) {
            VLOG(2) << "Painting type T of Merge node "
                    << graph_type_view_.GetNode(merge_idx)->node->name()
                    << " DENY to match the color of its sibling Merge nodes "
                       "with common NextIteration node "
                    << node.name();
          }
        }
        if (allow_set->erase(nextiter_idx)) {
          VLOG(2) << "Painting type T of NextIteration node " << node.name()
                  << " DENY to match the color of its output Merge node(s)";
        }
      } else {
        if (allow_set->insert(nextiter_idx).second) {
          VLOG(2) << "Painting type T of NextIteration node " << node.name()
                  << " ALLOW to match the color of its output Merge node(s)";
        }
      }
    }
  }
  return Status::OK();
}

// Forces all of the given Tensor List nodes into the same color set.
void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps(
    const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
    absl::flat_hash_set<int>* allow_set,
    absl::flat_hash_set<int>* deny_set) const {
  bool any_deny = false;
  bool any_allow = false;
  std::vector<int> node_type_idxs;
  node_type_idxs.reserve(tensor_list_nodes.size());
  for (const NodeDef* node : tensor_list_nodes) {
    const NodeTypeId& node_type = *GetTensorListFloat32NodeTypeId(*node);
    const absl::optional<int> maybe_node_type_idx =
        graph_type_view_.GetNodeIndex(node_type);
    DCHECK(maybe_node_type_idx.has_value())
        << "Type attribute " << node_type.type_attr.DebugString() << " of node "
        << node->name() << " not found in graph view";
    node_type_idxs.push_back(maybe_node_type_idx.value());
  }
  for (int node_type_idx : node_type_idxs) {
    if (deny_set->count(node_type_idx)) {
      any_deny = true;
      break;
    } else if (allow_set->count(node_type_idx)) {
      any_allow = true;
    }
  }
  if (!any_deny && !any_allow) return;
  for (int node_type_idx : node_type_idxs) {
    const NodeTypeId& node_type = *graph_type_view_.GetNode(node_type_idx);
    VLOG(2) << "Painting type " << node_type.type_attr.DebugString() << " of "
            << node_type.node->op() << " node " << node_type.node->name() << " "
            << (any_deny ? "DENY" : "ALLOW")
            << " because at least one of its siblings is "
            << (any_deny ? "DENY" : "ALLOW");
    if (any_deny) {
      allow_set->erase(node_type_idx);
      deny_set->insert(node_type_idx);
    } else {
      allow_set->insert(node_type_idx);
    }
  }
}

bool AutoMixedPrecisionImpl::NodeImplicitlyReadsNonResourceVariable(
    const NodeDef& node) const {
  if (node.op() == "Identity" || node.op() == "Enter") {
    GraphView::InputPort node_input(&node, 0);
    MutableGraphView::OutputPort prev_output =
        graph_view_.GetRegularFanin(node_input);
    const NodeDef* input = prev_output.node;
    if (input && ((node.op() == "Identity" && (input->op() == "Variable" ||
                                               input->op() == "VariableV2")) ||
                  (node.op() == "Enter" &&
                   NodeImplicitlyReadsNonResourceVariable(*input)))) {
      return true;
    }
  }
  return false;
}

// This adds existing Cast nodes to allow_set if all of their outputs are allow,
// avoiding the need to add a new Cast node after an existing Cast.
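// For example, an existing Cast whose DstT is float32 and whose consumers are
// all painted ALLOW gets its DstT painted ALLOW as well, so the final pass
// rewrites the existing Cast to produce f16 directly instead of inserting an
// extra f32->f16 Cast after it.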
void AutoMixedPrecisionImpl::MakeCastsAllowIfAllOutputsAllow(
    absl::flat_hash_set<int>* allow_set) const {
  int num_nodes_preop = graph_->node_size();
  for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
    NodeDef* node = graph_->mutable_node(node_idx);
    NodeTypeId node_type(node, TypeAttrId("DstT"));
    if (node->op() != "Cast" || !IsFloat32(node_type)) {
      continue;
    }
    bool all_fanouts_allow = true;
    MutableGraphView::OutputPort src(node, 0);
    const auto& fanout = graph_view_.GetFanout(src);
    for (const MutableGraphView::InputPort& dst : fanout) {
      TypeAttrId dst_type_attr =
          node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
      const absl::optional<int> maybe_dst_type_idx =
          graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
      DCHECK(maybe_dst_type_idx.has_value())
          << "Type attribute " << dst_type_attr.DebugString() << " of node "
          << dst.node->name() << " not found in graph view";
      int dst_type_idx = maybe_dst_type_idx.value();
      bool dst_is_allow = allow_set->count(dst_type_idx);
      if (!dst_is_allow) {
        all_fanouts_allow = false;
        break;
      }
    }
    if (!fanout.empty() && all_fanouts_allow) {
      const absl::optional<int> maybe_node_type_idx =
          graph_type_view_.GetNodeIndex(node_type);
      DCHECK(maybe_node_type_idx.has_value())
          << "Type attribute " << node_type.type_attr.DebugString()
          << " of node " << node_type.node->name()
          << " not found in graph view";
      int node_type_idx = maybe_node_type_idx.value();
      allow_set->insert(node_type_idx);
    }
  }
}

// Changes all allow-painted type attributes to DT_HALF or DT_BFLOAT16, and
// inserts Cast nodes at node outputs for all edges that connect
// allow-painted <-> non-allow-painted type attributes.
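// For example, if an allow-painted MatMul output feeds a non-allow consumer, a
// single Cast back to DT_FLOAT is added at that output port and the consumer's
// fanin is rewired to read from the Cast; at most one Cast is created per
// output port and is shared by all such consumers.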
Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
    const absl::flat_hash_set<int>& allow_set) {
  int num_nodes_changed = 0;
  int num_nonvar_casts_to_f16 = 0;
  int num_nodes_preop = graph_->node_size();
  for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
    NodeDef* node = graph_->mutable_node(node_idx);
    for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(*node)) {
      const absl::optional<int> maybe_node_type_idx =
          graph_type_view_.GetNodeIndex(node->name(), type_attr);
      if (!maybe_node_type_idx.has_value()) {
        return errors::Internal("Type attribute ", type_attr.DebugString(),
                                " of ", node->op(), " node ", node->name(),
                                " not found in graph view");
      }
      int node_type_idx = maybe_node_type_idx.value();
      if (!IsFloat32(*graph_type_view_.GetNode(node_type_idx))) continue;
      bool src_is_allow = allow_set.count(node_type_idx);
      if (src_is_allow) {
        VLOG(1) << "Changing type " << type_attr.DebugString() << " of "
                << node->op() << " node " << node->name() << " to "
                << DataTypeString(target_dtype_);
        if (!SetDataType(node, type_attr, target_dtype_)) {
          return errors::Internal("Failed to set type attribute");
        }
        ++num_nodes_changed;
      }
      for (int output_port : node_type_map_.GetOutputPorts(*node, type_attr)) {
        MutableGraphView::OutputPort src(node, output_port);
        NodeDef* added_cast_node = nullptr;
        // Note: This is copied so that edges can be modified inside the loop.
        auto fanout = graph_view_.GetFanout(src);
        for (const MutableGraphView::InputPort& dst : fanout) {
          TypeAttrId dst_type_attr =
              node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
          const absl::optional<int> maybe_dst_type_idx =
              graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
          if (!maybe_dst_type_idx.has_value()) {
            return errors::Internal("Type attribute ",
                                    dst_type_attr.DebugString(), " of ",
                                    dst.node->op(), " node ", dst.node->name(),
                                    " not found in graph view");
          }
          int dst_type_idx = maybe_dst_type_idx.value();
          bool dst_is_allow = allow_set.count(dst_type_idx);
          if (src_is_allow != dst_is_allow) {
            if (!added_cast_node) {
              bool to_f16 = dst_is_allow;
              VLOG(1) << "Inserting cast to "
                      << (to_f16 ? DataTypeString(target_dtype_) : "DT_FLOAT")
                      << " at " << src.node->op() << " " << src.node->name()
                      << ":" << src.port_id;
              added_cast_node = graph_view_.AddNode(
                  BuildCastNode(src, to_f16, src.node->device()));
              if (to_f16 && !IsConstant(*node) && !IsVariable(*node) &&
                  !NodeImplicitlyReadsNonResourceVariable(*node)) {
                ++num_nonvar_casts_to_f16;
              }
            }
            TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort(
                dst.node->name(), dst.port_id, {added_cast_node->name(), 0}));
          }
        }
      }
    }
  }
  // Use Python type names (e.g. float16) instead of C++ type names (e.g. half)
  // since many Python users will see this message.
  const char* type_str = target_dtype_ == DT_HALF ? "float16" : "bfloat16";
  LOG(INFO) << "Converted " << num_nodes_changed << "/" << num_nodes_preop
            << " nodes to " << type_str << " precision using "
            << num_nonvar_casts_to_f16 << " cast(s) to " << type_str
            << " (excluding Const and Variable casts)";
  return Status::OK();
}

int GetNumGPUs(const Cluster& cluster) {
  auto devices = cluster.GetDevices();
  int num_gpus = 0;
  for (const auto& device : devices) {
    const DeviceProperties& device_properties = device.second;
    if (device_properties.type() == "GPU" &&
        (ShouldIgnorePerformance() || HasFastFP16Support(device_properties))) {
      num_gpus++;
    }
  }
  return num_gpus;
}

}  // end namespace

Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item,
                                    GraphDef* output) {
  if (cluster == nullptr) {
    return errors::InvalidArgument("cluster == nullptr");
  }

#if !defined(INTEL_MKL)
  if (mode_ == AutoMixedPrecisionMode::MKL) {
    return errors::Unimplemented(
        "The auto_mixed_precision_mkl optimizer cannot be used since "
        "this build of TensorFlow is not compiled with MKL support for "
        "bfloat16. "
        "For information on MKL builds, see: "
        "https://software.intel.com/en-us/articles/intel-optimization-for-"
        "tensorflow-installation-guide");
  }
#endif  // INTEL_MKL

  // Start by copying input graph to output.
  *output = item.graph;

  int num_gpus = GetNumGPUs(*cluster);
  if (num_gpus < 1 && mode_ == AutoMixedPrecisionMode::CUDA) {
    // AutoMixedPrecision is currently only tuned for GPU.
    LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name()
                 << " graph optimizer";
    return Status::OK();
  }

  // Optimize the output graph in-place.
  AutoMixedPrecisionImpl optimizer(cluster, item.NodesToPreserve(), output,
                                   item.id, mode_);
  if (item.id == "tf_graph") {
    LOG(INFO) << "Running " << name() << " graph optimizer";
  } else {
    VLOG(1) << "Running " << name() << " graph optimizer on " << item.id;
  }
  Status status = optimizer.Optimize();
  if (!status.ok()) {
    // Restore the original graph.
    *output = item.graph;
    LOG(WARNING) << name() << " graph optimizer FAILED: " << status.ToString();
  }
  return status;
}

}  // end namespace grappler
}  // end namespace tensorflow