/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"

#include <fstream>
#include <list>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/logger_registry.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
#include "tensorflow/core/protobuf/device_properties.pb.h"  // NOLINT
#include "tensorflow/core/protobuf/rewriter_config.pb.h"  // NOLINT
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/tensorrt/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
namespace convert {

using absl::StrAppend;
using absl::StrCat;
using ::tensorflow::tensorrt::segment::ClusterProperty;
using ::tensorflow::tensorrt::segment::NodePtrCompare;
using ::tensorflow::tensorrt::segment::Segment;

namespace {

Status BuildNodeMap(const Graph& graph,
                    std::unordered_map<string, Node*>* node_map) {
  for (auto* node : graph.op_nodes()) {
    if (!node_map->insert({node->name(), node}).second) {
      return errors::AlreadyExists("Node name is not unique in graph: " +
                                   node->name());
    }
  }
  return Status::OK();
}

EngineInfo::EngineType GetEngineType(const ConversionParams& params) {
  return (params.is_dyn_op || params.use_calibration)
             ? EngineInfo::EngineType::TRTDynamic
             : EngineInfo::EngineType::TRTStatic;
}

// Returns true when use_implicit_batch is false or when we are building a
// dynamic engine, to allow unknown sizes for dimensions other than the batch
// dimension.
bool AllowDynamicNonBatchDimension(const ConversionParams& params) {
  return !params.use_implicit_batch ||
         GetEngineType(params) == EngineInfo::EngineType::TRTDynamic;
}

struct EdgePtrCompare {
  bool operator()(const Edge* lhs, const Edge* rhs) const {
    return lhs->id() < rhs->id();
  }
};

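// Note: sorting edges by id (rather than iterating the raw pointer sets)
// keeps the connection order deterministic across runs; see the uses in
// GetEngineInfo() below.
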
// TODO(laigd): instead of deciding the device here, the converter should
// accept a device name as one of the conversion parameters so users can
// control which device they want to run the conversion on.
std::pair<TfDeviceId, PlatformDeviceId> GetFirstValidDeviceId() {
  for (int tf_device_id_value = 0; tf_device_id_value < 100;
       ++tf_device_id_value) {
    TfDeviceId tf_device_id(tf_device_id_value);
    PlatformDeviceId platform_device_id;
    Status s =
        GpuIdManager::TfToPlatformDeviceId(tf_device_id, &platform_device_id);
    if (s.ok()) {
      VLOG(1) << "Found TF GPU " << tf_device_id.value() << " at cuda device "
              << platform_device_id.value();
      return std::make_pair(tf_device_id, platform_device_id);
    }
  }
  LOG(ERROR) << "Could not find any TF GPUs";
  return std::make_pair(TfDeviceId(-1), PlatformDeviceId(-1));
}

// Returns false for const nodes (we intend to drop control edges from those).
bool ShallKeepControlEdgeFrom(const Node* input_node) {
  if (!input_node) {
    LOG(ERROR) << "Node pointer is null, this should not happen";
    return false;
  }
  return input_node->type_string() != "Const";
}

// Fills in the EngineInfo structure for the given segment.
Status GetEngineInfo(const Graph* g,
                     const grappler::GraphProperties& graph_properties,
                     const Segment& segment,
                     const std::unordered_map<string, Node*>& node_map,
                     const std::vector<Node*>& reverse_topo_order,
                     EngineInfo* info) {
  std::vector<const Node*> subgraph_nodes;  // Topologically sorted nodes.
  std::set<const Node*> added_const_nodes;  // Used to prevent double insertion.

  const ClusterProperty& segment_property = segment.property;
  const std::set<const Node*, NodePtrCompare>& segment_nodes = segment.nodes;

  // The device assignment accumulated from the compatible device assignments
  // for the nodes in the segment.
  const DeviceNameUtils::ParsedName segment_device =
      segment_property.DeviceName();
  info->max_batch_size = segment_property.BatchSize().GetOptionalMaxBatchSize();

  // Map from src_node_name+port to the unique port numbers of the TRT op,
  // where the src_node_name is the name of the source node of the input/output
  // edge. There must not be any duplicates, since source nodes of input/output
  // edges must be in different splits of the graph.
  // TODO(aaroey): consider using node id and port instead.
  // TODO(aaroey): use topo order instead of reverting reverse topo order.
  std::unordered_map<string, int> input_to_engine_port, output_to_engine_port;
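  // Illustrative example: if an outside node "conv0" feeds the segment from
  // its output port 2, the key "conv0:2" gets the next free engine input port
  // (0 for the first input seen), and any later edge from the same "conv0:2"
  // reuses that port.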
  for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend();
       ++it) {
    const Node* node = *it;
    if (segment_nodes.count(node) == 0) continue;
    subgraph_nodes.push_back(node);

    const int node_id = node->id();
    const string& node_name = node->name();

    // Create input connections. Sort the edges first to make the order
    // deterministic, since in_edges is a set of pointers.
    std::vector<const Edge*> in_edges(node->in_edges().begin(),
                                      node->in_edges().end());
    std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare());
    for (const auto edge : in_edges) {
      auto input_node = edge->src();
      if (input_node->IsSource() || segment_nodes.count(input_node)) {
        continue;
      }
      if (edge->IsControlEdge()) {
        if (ShallKeepControlEdgeFrom(input_node)) {
          // Non-Const control input.
          info->connections.emplace_back(input_node->name(), input_node->id(),
                                         node_name, node_id,
                                         /*input_edge=*/true);
        }
      } else if (input_node->type_string() == "Const") {
        // Add constant data input nodes into the segment graphdef (thus also
        // in the engine). We don't care if it has other output edges going
        // into other engines or TF nodes. Since we add it only to the segment
        // graphdef, not the segment itself, it won't be removed from the
        // graph. If it doesn't have any edges, TF will prune it out.
        //
        // Note that the segmenter already ensures that the constant data
        // input is valid and supported by the engine.
        if (!added_const_nodes.insert(input_node).second) {
          // Already added before.
          continue;
        }
        VLOG(1) << "Adding const node " << input_node->name();
      } else {
        // Non-const data input.
        int port = Graph::kControlSlot - 1;
        // Use the source non-segment node name/port as key.
        const string s = StrCat(input_node->name(), ":", edge->src_output());
        VLOG(1) << "Input edge = " << s;
        if (input_to_engine_port.count(s)) {
          port = input_to_engine_port.at(s);
        } else {
          port = input_to_engine_port.size();
          input_to_engine_port.insert({s, port});
        }
        info->connections.emplace_back(
            input_node->name(), input_node->id(), edge->src_output(), node_name,
            node_id, edge->dst_input(), /*input_edge=*/true, port);
      }
    }
    // Create output connections. Sort the edges first to make the order
    // deterministic, since out_edges is a set of pointers.
    std::vector<const Edge*> out_edges(node->out_edges().begin(),
                                       node->out_edges().end());
    std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare());
    for (const auto edge : out_edges) {
      auto output_node = edge->dst();
      if (output_node->IsSink() || segment_nodes.count(output_node)) {
        continue;
      }
      if (edge->IsControlEdge()) {
        // Control output.
        if (ShallKeepControlEdgeFrom(node)) {
          info->connections.emplace_back(output_node->name(), output_node->id(),
                                         node_name, node_id,
                                         /*input_edge=*/false);
        }
      } else {
        // Data output.
        int port = Graph::kControlSlot - 1;
        // Use the source segment node name/port as key.
        const string s = StrCat(node_name, ":", edge->src_output());
        VLOG(1) << "Output edge = " << s;
        if (output_to_engine_port.count(s)) {
          port = output_to_engine_port.at(s);
        } else {
          port = output_to_engine_port.size();
          output_to_engine_port.insert({s, port});
        }
        info->connections.emplace_back(
            output_node->name(), output_node->id(), edge->dst_input(),
            node_name, node_id, edge->src_output(), /*input_edge=*/false, port);
      }
    }
  }  // For each segment node in topological order.

  // Construct the const nodes first.
  subgraph_nodes.insert(subgraph_nodes.begin(), added_const_nodes.begin(),
                        added_const_nodes.end());
  TF_RETURN_IF_ERROR(
      ConvertSegmentToGraphDef(g, graph_properties, subgraph_nodes, info));
  VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name
          << "' to a GraphDef " << info->has_int32_input;
  if (segment_device.has_type) {
    // If the accumulated device assignment for the segment has a device type,
    // the segmenter guarantees the device type is GPU. Use the device
    // assignment in this case.
    if (segment_device.type != "GPU") {
      return errors::Internal(
          "segment device is not GPU: ",
          DeviceNameUtils::ParsedNameToString(segment_device));
    }
    info->device = DeviceNameUtils::ParsedNameToString(segment_device);
  } else {
    TfDeviceId tf_device_id;
    PlatformDeviceId platform_device_id;
    std::tie(tf_device_id, platform_device_id) = GetFirstValidDeviceId();
    if (tf_device_id.value() >= 0) {
      DeviceNameUtils::ParsedName parsed_name;
      parsed_name.type = "GPU";
      parsed_name.has_type = true;
      parsed_name.id = tf_device_id.value();
      parsed_name.has_id = true;
      info->device = DeviceNameUtils::ParsedNameToString(parsed_name);
    } else {
      VLOG(1) << "No device is assigned to the segment. A device will be "
                 "assigned during graph execution (inference).";
    }
  }
  return Status::OK();
}

// Helper function to update an edge connection from a removed node to the
// engine node that absorbed it. If an outside node is gone, it must have been
// absorbed into an engine node. Find that engine node.
void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
                        const size_t my_engine_id,
                        const std::vector<Node*>& engine_nodes,
                        const bool is_input_edge, const string& node_name,
                        Node** node, int* port) {
  for (size_t t = 0; t < infos.size(); ++t) {
    if (t == my_engine_id) {
      continue;
    }
    const auto& info = infos.at(t);
    for (const auto& eng_conn : info.connections) {
      // If the connection being updated is an input connection, the source of
      // the connection must be an output connection of another engine. And
      // vice versa.
      if (is_input_edge == eng_conn.is_input_edge) continue;
      if (eng_conn.inside_node_name == node_name &&
          eng_conn.inside_port == *port) {
        *node = CHECK_NOTNULL(engine_nodes[t]);
        QCHECK_EQ(info.engine_name, (**node).name())
            << "Engine name mismatch: " << info.engine_name << " vs "
            << (**node).name();
        *port = eng_conn.port_number;
        return;
      }
    }
  }
  LOG(FATAL) << "Node " << node_name << " not found in any engine.";
}
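
// Illustrative example: if node "A" was absorbed into engine t, whose output
// connection records inside node "A" port 0 as engine port 3, then an input
// edge of the current engine that still references "A":0 is redirected to
// engine t's node at port 3.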

// Function to insert a TRT engine node into the graph.
// Create engine nodes in the following way:
// 1. Each invocation of CreateTRTNode creates an engine node for infos[pos].
// 2. When an engine node is created, add it into the graph with the necessary
//    re-wiring.
//    2.1. If the outside connected node still exists, connect the engine
//         node to it.
//    2.2. If the outside connected node is gone, it must have been absorbed
//         into another engine node (which was processed before the current
//         one). Connect to the pre-existing engine node instead.
// 3. In this way, we ensure the graph is topologically sortable after each
//    invocation of CreateTRTNode().
Status CreateTRTNode(const ConversionParams& params,
                     const std::vector<EngineInfo>& infos, int pos,
                     int default_max_batch_size, Graph* graph,
                     std::vector<Node*>* engine_nodes) {
  const auto& info = infos.at(pos);
  std::vector<tensorflow::TensorShapeProto> input_shape_protos;
  std::vector<PartialTensorShape> input_shapes;
  std::vector<NodeDefBuilder::NodeOut> inputs;
  std::vector<Node*> input_nodes;
  std::vector<Node*> control_input_nodes;
  std::unordered_set<string> control_input_names;
  std::vector<DataType> out_types;

  VLOG(1) << "Processing " << info.engine_name;
  // Collect the info needed for creating the engine node in the graph.
  for (const auto& conn : info.connections) {
    // Control edges
    if (conn.is_control_edge()) {
      // Skip control outputs for now. Control output info is not needed for
      // node creation and will be processed later.
      if (!conn.is_input_edge) continue;

      // Rewire the control input if it's not found in the original graph.
      Node* input_node = graph->FindNodeId(conn.outside_id);
      int port = Graph::kControlSlot;
      if (!input_node) {
        UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
                           conn.outside_node_name, &input_node, &port);
        QCHECK_EQ(Graph::kControlSlot, port);
      }
      if (!control_input_names.insert(input_node->name()).second) {
        continue;
      }
      control_input_nodes.push_back(input_node);
      VLOG(1) << "Engine Control Input " << input_node->name() << " -> "
              << info.engine_name;
    } else {
      // Data edges
      if (!conn.is_input_edge) {
        // Set the data types of the output edge.
        if (out_types.size() <= conn.port_number) {
          out_types.resize(conn.port_number + 1);
        }
        out_types.at(conn.port_number) = conn.connection_type;
      } else {
        // Set the shapes and data types of the input edge.
        if (input_shapes.size() <= conn.port_number) {
          input_shape_protos.resize(conn.port_number + 1);
          input_shapes.resize(conn.port_number + 1);
        }
        conn.outside_shape.AsProto(&input_shape_protos.at(conn.port_number));
        input_shapes.at(conn.port_number) = conn.outside_shape;
        // The shape must be fully defined (excluding the batch dimension) for
        // static mode.
        if (params.use_implicit_batch &&
            info.engine_type == EngineInfo::EngineType::TRTStatic) {
          for (int i = 1; i < conn.outside_shape.dims(); i++) {
            if (conn.outside_shape.dim_size(i) <= 0) {
              return errors::Internal(
                  "Not fully defined input shape when in static mode which "
                  "should have been excluded by the segmenter. ");
            }
          }
        }

        // Rewire the data input if it's not found in the original graph.
        Node* input_node = graph->FindNodeId(conn.outside_id);
        int port = conn.outside_port;
        if (!input_node) {
          UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
                             conn.outside_node_name, &input_node, &port);
        }
        if (std::find_if(
                std::begin(inputs), std::end(inputs),
                [input_node, &port](const NodeDefBuilder::NodeOut& inp) {
                  return inp.node == input_node->name() && inp.index == port;
                }) == std::end(inputs)) {
          inputs.emplace_back(input_node->name(), port, conn.connection_type);
          input_nodes.push_back(CHECK_NOTNULL(input_node));
          VLOG(1) << "Engine Input " << input_node->name() << ":" << port
                  << " -> " << info.engine_name << ":" << inputs.size() - 1;
        }
      }
    }
  }
  // We don't support segments with no inputs. Fall back to native TF here to
  // avoid a crash later. Constant folding should've folded the ops that make
  // up these segments.
  if (inputs.empty()) {
    return errors::Internal(
        "Segment has no inputs (possible constfold failure)");
  }

  const bool calibrate_int8 =
      (info.precision_mode == TrtPrecisionMode::INT8 && info.use_calibration);
  // Build the engine and get its serialized representation.
  string segment_string;

  int max_batch_size = info.max_batch_size.has_value()
                           ? info.max_batch_size.value()
                           : default_max_batch_size;

  if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
    std::pair<int, Allocator*> device_allocator =
        GetDeviceAndAllocator(params, info);
    int cuda_device_id = 0;
    std::unique_ptr<TRTBaseAllocator> trt_allocator;
    if (device_allocator.first >= 0) {
      cuda_device_id = device_allocator.first;
      trt_allocator.reset(new TRTDeviceAllocator(device_allocator.second));
    } else {
      // The value in trt_allocator is a nullptr and cudamalloc will be used.
      LOG_WARNING_WITH_PREFIX << "Can't identify the cuda device. Running on "
                                 "device 0 and using cudamalloc as an "
                                 "allocator";
    }
    cudaSetDevice(cuda_device_id);

    auto trt_logger = GetLoggerRegistry()->LookUp(params.trt_logger_name);

    // Create static engines with precision_mode fp32/fp16.
    TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
    TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
        info.segment_graph_def,
        calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode,
        max_batch_size, info.max_workspace_size_bytes, input_shapes, trt_logger,
        trt_allocator.get(), /*calibrator=*/nullptr, &engine,
        info.use_calibration, params.use_implicit_batch,
        /*convert_successfully=*/nullptr,
        /*profile=*/nullptr, info.engine_name));
    TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
    segment_string = string(static_cast<const char*>(engine_data->data()),
                            engine_data->size());
  }

  string prec_string;
  TF_RETURN_IF_ERROR(TrtPrecisionModeToName(info.precision_mode, &prec_string));
  NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp");
  if (!info.device.empty()) node_builder.Device(info.device);
  if (VLOG_IS_ON(1)) {
    string ins = StrCat(info.engine_name, " inputs= ");
    for (const auto& ii : inputs) {
      StrAppend(&ins, ii.node, ":", ii.index, " ");
    }
    VLOG(1) << ins;
  }
  node_builder.Input(inputs);
  for (const string& c : control_input_names) {
    node_builder.ControlInput(c);
  }

  NodeDef trt_node;
  NameAttrList function;
  function.set_name(StrCat(info.engine_name, "_native_segment"));
  node_builder.Attr("input_shapes", input_shape_protos)
      .Attr("static_engine",
            info.engine_type == EngineInfo::EngineType::TRTStatic)
      .Attr("segment_func", function)
      .Attr("serialized_segment", segment_string)
      .Attr("calibration_data", "")
      .Attr("max_cached_engines_count", info.maximum_cached_engines)
      .Attr("workspace_size_bytes", info.max_workspace_size_bytes)
      .Attr("max_batch_size", max_batch_size)
      .Attr("precision_mode", prec_string)
      .Attr("use_calibration", info.use_calibration)
      .Attr("_use_implicit_batch", params.use_implicit_batch)
      .Attr("_allow_build_at_runtime", info.allow_build_at_runtime)
      .Attr("OutT", out_types);

  if (!params.use_implicit_batch) {
    node_builder.Attr("profile_strategy",
                      ProfileStrategyToName(params.profile_strategy));
  }

  Status status = node_builder.Finalize(&trt_node);
  if (!status.ok()) {
    LOG(ERROR) << "Node construction failed with " << status;
    return status;
  }
  VLOG(1) << "Adding TRTEngine " << info.engine_name << " to graph";

  // Up until this point, the graph is not modified. If we return !status.ok()
  // from here, this segment will be skipped.
  // TODO(aaroey): let it return a proper error status for the following logic
  // instead of checking fail.
  Node* engine_node = graph->AddNode(trt_node, &status);
  (*engine_nodes)[pos] = engine_node;
  if (!status.ok()) {
    LOG(ERROR) << "Adding node failed " << status;
    return status;
  }
  // Add control input and input edges to the engine node.
  for (const auto in : control_input_nodes) {
    VLOG(1) << "Connecting control edge from " << in->name() << " to "
            << engine_node->name();
    graph->AddControlEdge(in, engine_node);
  }
  VLOG(1) << "input_nodes size = " << input_nodes.size();
  for (int i = 0; i < input_nodes.size(); ++i) {
    Node* n = CHECK_NOTNULL(input_nodes[i]);
    const auto& in = inputs[i];
    VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index
            << " to " << engine_node->name() << ":" << i;
    graph->AddEdge(n, in.index, engine_node, i);
  }

  // Update the inputs of the destination nodes of the output edges, pointing
  // them to the engine node.
  for (auto& conn : info.connections) {
    if (conn.is_input_edge) {
      continue;
    }
    Node* output_node = graph->FindNodeId(conn.outside_id);
    int port = conn.outside_port;
    if (!output_node) {
      UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false,
                         conn.outside_node_name, &output_node, &port);
    }
    if (conn.is_control_edge()) {
      VLOG(1) << "Updating control edge from " << engine_node->name() << " to "
              << output_node->name();
      QCHECK_EQ(Graph::kControlSlot, port);
      graph->AddControlEdge(engine_node, output_node);
    } else {
      VLOG(1) << "Updating data edge from " << engine_node->name() << ":"
              << conn.port_number << " to " << output_node->name() << ":"
              << port;
      // Use UpdateEdge() to avoid adding the same edge multiple times.
      TF_CHECK_OK(
          graph->UpdateEdge(engine_node, conn.port_number, output_node, port));
    }
  }
  return Status::OK();
}

int64 GetNextGraphSequenceNumber() {
  static std::atomic<int64> graph_sequence_num;
  return graph_sequence_num++;
}
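
// The sequence number above feeds the engine name prefix built in
// ConvertAfterShapes() below; engines are named, e.g., "TRTEngineOp_0_0",
// "TRTEngineOp_0_1", ... for the first converted graph.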

constexpr char kCastInputTypeAttrName[] = "SrcT";

// Transforms node = cast(x, fp32) where datatype(x) != fp16 to:
//   castToFp16 = cast(x, fp16)
//   node = cast(castToFp16, fp32)
//
Status MaybeRewriteCastToFp32(GraphDef* graph_def, NodeDef* node_def) {
  if (node_def->op() != "Cast") {
    return Status::OK();
  }

  DataTypeVector input_types;
  DataTypeVector output_types;
  TF_RETURN_IF_ERROR(
      graph_transforms::GetInOutTypes(*node_def, &input_types, &output_types));

  if (input_types.size() != 1 || output_types.size() != 1) {
    return errors::Internal("Bad cast operation");
  }

  if (input_types[0] == DT_HALF || output_types[0] != DT_FLOAT) {
    return Status::OK();
  }

  VLOG(2) << "Rewriting cast to FP32 " << node_def->DebugString();

  NodeDef* castToFp16 = graph_def->add_node();
  for (auto attr_value : node_def->attr()) {
    (*castToFp16->mutable_attr())[attr_value.first] = attr_value.second;
  }
  castToFp16->set_name(node_def->name() + "_split");
  castToFp16->set_op("Cast");
  castToFp16->set_device(node_def->device());
  castToFp16->add_input(node_def->input(0));
  (*castToFp16->mutable_attr())[kCastOutputTypeAttrName].set_type(DT_HALF);

  node_def->set_input(0, castToFp16->name() + ":0");
  (*node_def->mutable_attr())[kCastInputTypeAttrName].set_type(DT_HALF);

  VLOG(2) << castToFp16->DebugString();
  VLOG(2) << node_def->DebugString();

  return Status::OK();
}
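
// Illustrative example of the rewrite above:
//   Cast{SrcT=DT_INT32, DstT=DT_FLOAT}(x)
// becomes
//   Cast{SrcT=DT_HALF, DstT=DT_FLOAT}(Cast{SrcT=DT_INT32, DstT=DT_HALF}(x)),
// where the inner cast is the newly added "<name>_split" node.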

}  // namespace

Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
                                      Graph* graph, const string& engine_name,
                                      bool has_int32_input) {
  Graph segment_graph(graph->flib_def());
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                            segment_graph_def, &segment_graph));
  FunctionDefLibrary library;
  auto segment_func = library.add_function();
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      segment_graph, StrCat(engine_name, "_native_segment"), segment_func));
  if (has_int32_input) {
    // Setting this attribute value informs TensorFlow that the int32 _Arg node
    // inputs are on device memory during native segment execution. We only set
    // the attribute value when there is any int32 input because the attribute
    // is not compatible with is_multi_device_function.
    SetAttrValue(
        true,
        &(*segment_func
               ->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]);
  }
  if (VLOG_IS_ON(7)) {
    VLOG(7) << engine_name << " Function_Def ";
    VLOG(7) << segment_func->DebugString();
  }
  VLOG(1) << "Adding funcdef " << segment_func->signature().name()
          << " to graphlib";
  TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(library));
  return Status::OK();
}

std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
                                                 const EngineInfo& engine) {
  int cuda_device_id = -1;
  Allocator* dev_allocator = nullptr;
  if (params.cluster == nullptr || params.cluster->GetDeviceSet() == nullptr ||
      engine.device.empty()) {
    // If device is not set, use the first found GPU device for the conversion.
    TfDeviceId tf_device_id;
    PlatformDeviceId platform_device_id;
    std::tie(tf_device_id, platform_device_id) = GetFirstValidDeviceId();
    cuda_device_id = platform_device_id.value();
    if (cuda_device_id >= 0) {
      GPUOptions gpu_options;
      // If the TF to Cuda gpu id mapping exists, the device and corresponding
      // allocator must have been initialized already, so the
      // GetGPUAllocator() call won't create a new allocator.
      dev_allocator = GPUProcessState::singleton()->GetGPUAllocator(
          gpu_options, tf_device_id, /*total_bytes=*/1, /*peer_gpu_ids=*/{});
    }
    return std::make_pair(cuda_device_id, dev_allocator);
  }

  // Use the device requested by the engine.
  auto device_set = params.cluster->GetDeviceSet();
  std::vector<Device*> devices;
  DeviceNameUtils::ParsedName parsed_name;
  if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) &&
      parsed_name.has_id) {
    device_set->FindMatchingDevices(parsed_name, &devices);
  }
  if (!devices.empty()) {
    if (devices.size() > 1) {
      string msg = "Found multiple matching devices using name '";
      StrAppend(&msg, engine.device, "': ");
      for (auto d : devices) StrAppend(&msg, d->name(), ", ");
      StrAppend(&msg, ". Will get the allocator from the first one.");
      LOG_WARNING_WITH_PREFIX << msg;
    }
    AllocatorAttributes alloc_attr;
    cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id;
    dev_allocator = devices[0]->GetAllocator(alloc_attr);
    VLOG(1) << "Using allocator " << dev_allocator->Name()
            << " and cuda_device_id " << cuda_device_id;
  } else {
    LOG_WARNING_WITH_PREFIX << "Cluster is set but device '" << engine.device
                            << "' is not found in the cluster";
  }
  return std::make_pair(cuda_device_id, dev_allocator);
}

// Entry function from the optimization pass.
Status ConvertAfterShapes(const ConversionParams& params) {
  // Sanity checks.
  if (params.precision_mode != TrtPrecisionMode::INT8 &&
      params.use_calibration) {
    return errors::InvalidArgument(
        "Calibration with FP32 or FP16 is not supported.");
  }

  // Make a copy of the input_graph_def because grappler doesn't allow changes
  // to the input_graph_def and GraphProperties only accepts GraphDef, but not
  // Graph, as inputs.
  //
  // If the overhead of copying the input_graph_def becomes a concern, we can
  // avoid the copy by (1) enhancing the GraphProperties representation to
  // allow adding shape properties for newly created graph nodes and (2)
  // rewriting the GraphDef transformation as a Graph transformation.
  GraphDef modified_graph_def = params.grappler_item->graph;
  // When precision_mode is FP16, transform cast(x, fp32) to
  // cast(cast(x, fp16), fp32). This creates a cast(fp16, fp32) that can be
  // included in the TRTEngineOp as a TensorRT Identity layer for performance:
  //  . Avoid cast(fp32, fp16) in the TRT engine implementation for fp16
  //    precision.
  //  . Changing the input to the TRTEngine from fp32 to fp16 may reduce data
  //    movement from the host to the GPU.
  if (params.precision_mode == TrtPrecisionMode::FP16) {
    for (int i = 0; i < modified_graph_def.node_size(); i++) {
      NodeDef* node_def = modified_graph_def.mutable_node(i);
      TF_RETURN_IF_ERROR(MaybeRewriteCastToFp32(&modified_graph_def, node_def));
    }
  }

  // Construct a GrapplerItem using the modified graph_def and the input
  // grappler_item.
  grappler::GrapplerItem grappler_item =
      params.grappler_item->WithGraph(std::move(modified_graph_def));
  const GraphDef& graph_def = grappler_item.graph;

  grappler::GraphProperties static_graph_properties(grappler_item);
  TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));

  // Convert the graphdef to a graph.
  FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library());
  Graph graph(flib);
  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph(GraphConstructorOptions(), graph_def, &graph));

  // Segment the graph into subgraphs that can be converted to TensorRT.
  segment::SegmentOptions segment_options;
  // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
  for (const auto& node : *(params.output_names)) {
    segment_options.exclude_node_list.insert(node);
  }
  segment_options.minimum_segment_size = params.minimum_segment_size;
  segment_options.use_implicit_batch = params.use_implicit_batch;
  if (segment_options.use_implicit_batch)
    segment_options.maximum_batch_size = params.max_batch_size;
  segment_options.allow_dynamic_non_batch_dim =
      AllowDynamicNonBatchDimension(params);

  segment::SegmentVector initial_segments;
  TrtNodeValidator validator(static_graph_properties, params.precision_mode,
                             params.use_calibration, params.use_implicit_batch);
  TF_RETURN_IF_ERROR(segment::SegmentGraph(
      &graph, &static_graph_properties,
      std::bind(&TrtNodeValidator::IsTensorRTCandidate, &validator,
                std::placeholders::_1),
      // Input validation is already done by TrtNodeValidator, so we don't
      // need to check the input edges.
      [](const Edge* edge) { return true; }, OutputEdgeValidator(),
      segment_options, &initial_segments));
  LOG(INFO) << "Number of TensorRT candidate segments: "
            << initial_segments.size();

  // Get the EngineInfo for each segment.
  std::unordered_map<string, Node*> node_map;
  TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
  std::vector<EngineInfo> engine_segments;
  engine_segments.reserve(initial_segments.size());
  std::vector<Node*> reverse_topo_order;
  GetPostOrder(graph, &reverse_topo_order);
  segment::SegmentVector converted_segments;
  converted_segments.reserve(initial_segments.size());
  string engine_name_prefix =
      StrCat("TRTEngineOp_", GetNextGraphSequenceNumber(), "_");
  for (size_t t = 0; t < initial_segments.size(); t++) {
    auto& curr_segment = initial_segments.at(t);
    EngineInfo curr_engine;
    curr_engine.engine_name = StrCat(engine_name_prefix, t);

    bool int8_no_calib = (params.use_calibration == false &&
                          params.precision_mode == TrtPrecisionMode::INT8);
    bool has_qdq = false;
    if (int8_no_calib) {
      has_qdq = absl::c_any_of(reverse_topo_order, IsQuantizeAndDequantizeOp);
    }

    Status status = GetEngineInfo(&graph, static_graph_properties, curr_segment,
                                  node_map, reverse_topo_order, &curr_engine);
    if (!status.ok()) {
      LOG_WARNING_WITH_PREFIX << "Failed to get engine info for segment " << t
                              << ": " << status;
      continue;
    }

    curr_engine.engine_type = GetEngineType(params);
    curr_engine.use_calibration = params.use_calibration;
    // Building cuda engines for INT8 without calibration and without dynamic
    // range info causes TRT failures. Avoid this situation by setting the
    // precision to FP16.
    if (int8_no_calib && !has_qdq) {
      VLOG(1) << "Set engine precision to FP16 due to missing QDQ OP";
      curr_engine.precision_mode = TrtPrecisionMode::FP16;
    } else {
      curr_engine.precision_mode = params.precision_mode;
    }
    curr_engine.maximum_cached_engines = params.max_cached_engines;
    curr_engine.allow_build_at_runtime = params.allow_build_at_runtime;
    if (!curr_engine.max_batch_size.has_value()) {
      curr_engine.max_batch_size = params.max_batch_size;
    }

    status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def,
                                            &graph, curr_engine.engine_name,
                                            curr_engine.has_int32_input);

    if (!status.ok()) {
      LOG_WARNING_WITH_PREFIX
          << "Failed to register segment graphdef to the library " << t << ": "
          << status;
      continue;
    }

    engine_segments.push_back(std::move(curr_engine));
    converted_segments.push_back(std::move(curr_segment));

    if (VLOG_IS_ON(8)) {
      string fname = engine_segments.back().engine_name;
      StrAppend(&fname, ".pb");
      std::fstream f;
      f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
      f << engine_segments.at(t).segment_graph_def.SerializeAsString();
      f.close();
    }
  }

  // Save the cuda device since we may need to switch to another cuda device
  // to build static engines.
  absl::optional<int> old_cuda_device = absl::nullopt;
  if (!params.is_dyn_op) {
    int cuda_device_id;
    cudaError_t cuda_error = cudaGetDevice(&cuda_device_id);
    if (cuda_error != cudaSuccess) {
      LOG_WARNING_WITH_PREFIX << "Couldn't get current device: "
                              << cudaGetErrorString(cuda_error);
    } else {
      VLOG(1) << "Current cuda device is " << cuda_device_id;
      old_cuda_device = cuda_device_id;
    }
  }

  auto restore_cuda_device = gtl::MakeCleanup([old_cuda_device] {
    if (old_cuda_device.has_value()) {
      cudaSetDevice(old_cuda_device.value());
    }
  });

  std::vector<Node*> engine_nodes;
  engine_nodes.resize(engine_segments.size());
  for (int i = 0; i < engine_segments.size(); ++i) {
    auto& engine = engine_segments.at(i);
    // TODO(b/170762693): implement the heuristic to calculate
    // max_workspace_size_bytes.
    engine.max_workspace_size_bytes = params.max_workspace_size_bytes;
    VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to "
            << engine.engine_name;
    auto status = CreateTRTNode(params, engine_segments, i,
                                params.max_batch_size, &graph, &engine_nodes);

    string msg = StrCat("segment ", i, " consisting of ",
                        converted_segments.at(i).nodes.size(), " nodes by ",
                        engine.engine_name);
    if (status.ok()) {
      LOG(INFO) << "Replaced " << msg << ".";
    } else {
      // Graph is not modified.
      LOG_WARNING_WITH_PREFIX << "Cannot replace " << msg
                              << " reason: " << status.error_message()
                              << " (keeping original segment).";
    }
    if (VLOG_IS_ON(1)) {
      msg = "Segment consists of nodes: ";
      for (const Node* node : converted_segments.at(i).nodes) {
        StrAppend(&msg, node->name(), ", ");
      }
      VLOG(1) << msg;
    }

    // If status is ok, we successfully added the node to the graph and can
    // remove segment ops. Otherwise the graph is not modified.
    if (status.ok()) {
      for (const Node* node : converted_segments.at(i).nodes) {
        graph.RemoveNode(const_cast<Node*>(node));
      }
    }
  }
  graph.ToGraphDef(params.output_graph_def);
  VLOG(1) << "Returning from conversion";
  return Status::OK();
}

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT