• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
17 
18 #include <fstream>
19 #include <list>
20 #include <map>
21 #include <set>
22 #include <unordered_map>
23 #include <unordered_set>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/strings/str_cat.h"
28 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
29 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
30 #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
31 #include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
32 #include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
33 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
34 #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
35 #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
36 #include "tensorflow/core/framework/function.h"
37 #include "tensorflow/core/framework/graph_to_functiondef.h"
38 #include "tensorflow/core/framework/node_def_builder.h"
39 #include "tensorflow/core/graph/algorithm.h"
40 #include "tensorflow/core/graph/graph.h"
41 #include "tensorflow/core/graph/graph_constructor.h"
42 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
43 #include "tensorflow/core/grappler/costs/graph_properties.h"
44 #include "tensorflow/core/grappler/devices.h"
45 #include "tensorflow/core/grappler/grappler_item.h"
46 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
47 #include "tensorflow/core/grappler/utils.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/core/status.h"
50 #include "tensorflow/core/lib/strings/numbers.h"
51 #include "tensorflow/core/platform/logging.h"
52 #include "tensorflow/core/platform/types.h"
53 #include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
54 #include "tensorflow/core/protobuf/device_properties.pb.h"  // NOLINT
55 #include "tensorflow/core/protobuf/rewriter_config.pb.h"  // NOLINT
56 #include "tensorflow/core/util/device_name_utils.h"
57 
58 #if GOOGLE_CUDA
59 #if GOOGLE_TENSORRT
60 #include "cuda/include/cuda_runtime_api.h"
61 #include "tensorrt/include/NvInfer.h"
62 namespace tensorflow {
63 namespace tensorrt {
64 namespace convert {
65 using absl::StrAppend;
66 using absl::StrCat;
67 
TrtCandidateSelector(const grappler::GraphProperties & graph_properties,TrtPrecisionMode precision_mode)68 TrtCandidateSelector::TrtCandidateSelector(
69     const grappler::GraphProperties& graph_properties,
70     TrtPrecisionMode precision_mode)
71     : graph_properties_(graph_properties), precision_mode_(precision_mode) {}
72 
IsTensorRTCandidate(const Node * node)73 Status TrtCandidateSelector::IsTensorRTCandidate(const Node* node) {
74   std::vector<const Edge*> input_edges;
75   TF_RETURN_IF_ERROR(node->input_edges(&input_edges));
76   std::vector<std::pair<const NodeDef*, int>> input_node_and_ports;
77   input_node_and_ports.reserve(input_edges.size());
78   for (const Edge* input_edge : input_edges) {
79     input_node_and_ports.emplace_back(&input_edge->src()->def(),
80                                       input_edge->src_output());
81   }
82   return validator_.ValidateNode(node->def(), input_node_and_ports,
83                                  precision_mode_, graph_properties_);
84 }
85 
86 namespace {
87 
BuildNodeMap(const Graph & graph,std::unordered_map<string,Node * > * node_map)88 Status BuildNodeMap(const Graph& graph,
89                     std::unordered_map<string, Node*>* node_map) {
90   for (auto* node : graph.op_nodes()) {
91     if (!node_map->insert({node->name(), node}).second) {
92       return errors::AlreadyExists("Node name is not unique in graph: " +
93                                    node->name());
94     }
95   }
96   return Status::OK();
97 }
98 
99 }  // namespace
100 
ConvertGraphDefToTensorRT(const GraphDef & graph_def,const std::vector<string> & output_names,size_t max_batch_size,size_t max_workspace_size_bytes,GraphDef * new_graph_def,TrtPrecisionMode precision_mode,int minimum_segment_size,bool is_dyn_op,int max_cached_engines,std::vector<int> cached_engine_batches,bool use_calibration)101 Status ConvertGraphDefToTensorRT(
102     const GraphDef& graph_def, const std::vector<string>& output_names,
103     size_t max_batch_size, size_t max_workspace_size_bytes,
104     GraphDef* new_graph_def, TrtPrecisionMode precision_mode,
105     int minimum_segment_size, bool is_dyn_op, int max_cached_engines,
106     std::vector<int> cached_engine_batches, bool use_calibration) {
107   // Create GrapplerItem.
108   grappler::GrapplerItem item;
109   item.fetch = output_names;
110   item.graph = graph_def;
111 
112 // TODO(aaroey): we should have used single machine cluster like the
113 // following, but the problem is then wrap_conversion will depend on
114 // direct_session and cause double linking problems. To fix this we need to
115 // fix or get rid of the swig dependency. Here we use VirtualCluster
116 // as a work around, and we need to create a session to initialize the
117 // underlying device before calling this method.
118 #if 0
119   // Create single machine cluster. Note that this will create a session and
120   // initialize the gpu devices.
121   const int num_cpu_cores =
122       grappler::GetNumAvailableLogicalCPUCores();
123   const int num_gpus = grappler::GetNumAvailableGPUs();
124   VLOG(2) << "cpu_cores: " << num_cpu_cores;
125   VLOG(2) << "gpus: " << num_gpus;
126   const int timeout_s = 60 * 10;
127   std::unique_ptr<grappler::Cluster> cluster(
128       new grappler::SingleMachine(
129           timeout_s, num_cpu_cores, num_gpus));
130   // These settings are the defaults in tensorflow/python/grappler/cluster.py.
131   cluster->DisableDetailedStats(true);
132   cluster->AllowSoftPlacement(true);
133   cluster->SetNumWarmupSteps(10);
134   TF_RETURN_IF_ERROR(cluster->Provision());
135 #else
136   // Create virtual cluster. Grappler requires a virtual cluster with a proper
137   // GPU device in order to calculate flops>0 or fails with FATAL in dbg mode.
138   // We add numbers from a Pascal card here to have flops>0.
139   DeviceProperties device_properties;
140   device_properties.set_type("GPU");
141   device_properties.mutable_environment()->insert({"architecture", "6"});
142   device_properties.set_num_cores(3584);
143   device_properties.set_frequency(1531);
144   std::unique_ptr<grappler::Cluster> cluster(
145       new grappler::VirtualCluster({{"/GPU:0", device_properties}}));
146 #endif
147 
148   // Create RewriterConfig.
149   ConfigProto config_proto;
150   auto& rw_cfg =
151       *config_proto.mutable_graph_options()->mutable_rewrite_options();
152   // TODO(aaroey): use only const folding and layout for the time being since
153   // new optimizers break the graph for trt.
154   rw_cfg.add_optimizers("constfold");
155   rw_cfg.add_optimizers("layout");
156   auto optimizer = rw_cfg.add_custom_optimizers();
157   optimizer->set_name("TensorRTOptimizer");
158   auto& parameters = *(optimizer->mutable_parameter_map());
159   parameters["minimum_segment_size"].set_i(minimum_segment_size);
160   parameters["max_batch_size"].set_i(max_batch_size);
161   parameters["is_dynamic_op"].set_b(is_dyn_op);
162   parameters["max_workspace_size_bytes"].set_i(max_workspace_size_bytes);
163   TF_RETURN_IF_ERROR(TrtPrecisionModeToName(
164       precision_mode, parameters["precision_mode"].mutable_s()));
165   parameters["maximum_cached_engines"].set_i(max_cached_engines);
166   if (!cached_engine_batches.empty()) {
167     auto list = parameters["cached_engine_batches"].mutable_list();
168     for (const int batch : cached_engine_batches) {
169       list->add_i(batch);
170     }
171   }
172   parameters["use_calibration"].set_b(use_calibration);
173 
174   // Run optimizer.
175   grappler::MetaOptimizer meta_opt(nullptr, config_proto);
176   TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, new_graph_def));
177 
178   if (VLOG_IS_ON(5)) {
179     std::fstream f;
180     f.open("TRTConversionInput.pb",
181            std::fstream::out | std::fstream::binary | std::fstream::trunc);
182     f << new_graph_def->SerializeAsString();
183     f.close();
184   }
185   return Status::OK();
186 }
187 
188 struct EdgePtrCompare {
operator ()tensorflow::tensorrt::convert::EdgePtrCompare189   bool operator()(const Edge* lhs, const Edge* rhs) const {
190     return lhs->id() < rhs->id();
191   }
192 };
193 
194 // Function to get subsegment information structure.
GetEngineInfo(const Graph * g,const grappler::GraphProperties & graph_properties,const std::set<const Node * > & segment_nodes,const std::unordered_map<string,Node * > & node_map,const std::vector<Node * > & reverse_topo_order,EngineInfo * info)195 Status GetEngineInfo(const Graph* g,
196                      const grappler::GraphProperties& graph_properties,
197                      const std::set<const Node*>& segment_nodes,
198                      const std::unordered_map<string, Node*>& node_map,
199                      const std::vector<Node*>& reverse_topo_order,
200                      EngineInfo* info) {
201   std::vector<const Node*> subgraph_nodes;  // Topologically sorted nodes.
202   std::set<const Node*> added_const_nodes;  // Used to prevent double insertion.
203   std::set<string> segment_devices;
204 
205   // Map from src_node_name+port to the unique port numbers of the TRT op, where
206   // the src_node_name is the name of the source node of the input/output
207   // edge, thus there must not be any duplicates since source nodes of
208   // input/output edges must be in different split of the graph.
209   // TODO(aaroey): consider using node id and port instead.
210   // TODO(aaroey): using topo order instead of reverting reverse topo order.
211   std::unordered_map<string, int> input_to_engine_port, output_to_engine_port;
212   for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend();
213        ++it) {
214     const Node* node = *it;
215     if (segment_nodes.count(node) == 0) continue;
216     auto node_device = node->requested_device();
217     if (!node_device.empty()) {
218       // If device is CPU, treat as if no device was assigned. Don't add CPU to
219       // segment_device because that would cause a segfault in
220       // GetDeviceAndAllocator. This is because GetDeviceAndAllocator assumes
221       // any already set device is a GPU.
222       DeviceNameUtils::ParsedName parsed_name;
223       DeviceNameUtils::ParseFullName(node_device, &parsed_name);
224       if (parsed_name.type == "CPU") {
225         VLOG(1) << "Node " << node->name() << " was assigned to the CPU. "
226                 << "Attempting to place on GPU.";
227       } else {
228         segment_devices.insert(node_device);
229       }
230     } else {
231       if (node->has_assigned_device_name()) {
232         // It appears that nodes will not have assigned devices at this point in
233         // execution.
234         segment_devices.insert(node->assigned_device_name());
235       } else {
236         VLOG(2) << "Node " << node->name()
237                 << " neither have requested device nor assigned device";
238       }
239     }
240     subgraph_nodes.push_back(node);
241 
242     const int node_id = node->id();
243     const string& node_name = node->name();
244 
245     // Create input connections. Sort edges first to make determnistic since
246     // in_edges is a set of pointers.
247     std::vector<const Edge*> in_edges(node->in_edges().begin(),
248                                       node->in_edges().end());
249     std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare());
250     for (const auto edge : in_edges) {
251       auto input_node = edge->src();
252       if (input_node->IsSource() || segment_nodes.count(input_node)) {
253         continue;
254       }
255       if (edge->IsControlEdge()) {
256         // Control input.
257         info->connections.emplace_back(input_node->name(), input_node->id(),
258                                        node_name, node_id,
259                                        /*input_edge=*/true);
260       } else if (input_node->type_string() == "Const") {
261         // Add constant data input nodes into the segment graphdef (thus also in
262         // the engine). We don't care if it has other output edges going into
263         // other engines or TF nodes. Since we add it only to the segment
264         // graphdef, not the segment itself, it won't be removed from the graph.
265         // If it doesn't have any edges, TF will prune it out.
266         //
267         // Note that the segmenter already ensure that the constant data input
268         // is valid and suppported by the engine.
269         if (!added_const_nodes.insert(input_node).second) {
270           // Already added before.
271           continue;
272         }
273         VLOG(1) << "Adding const node " << input_node->name();
274         // Since we already add (duplicate) the const input node to the segment
275         // graphdef, it's now not a data dependency any more, but to make the
276         // dependency correct we still add a control dependency.
277         info->connections.emplace_back(input_node->name(), input_node->id(),
278                                        node_name, node_id,
279                                        /*input_edge=*/true);
280       } else {
281         // Non-const data input.
282         int port = Graph::kControlSlot - 1;
283         // Use the source non-segment node name/port as key.
284         const string s = StrCat(input_node->name(), ":", edge->src_output());
285         VLOG(1) << "Input edge = " << s;
286         if (input_to_engine_port.count(s)) {
287           port = input_to_engine_port.at(s);
288         } else {
289           port = input_to_engine_port.size();
290           input_to_engine_port.insert({s, port});
291         }
292         info->connections.emplace_back(
293             input_node->name(), input_node->id(), edge->src_output(), node_name,
294             node_id, edge->dst_input(), /*input_edge=*/true, port);
295       }
296     }
297     // Create output connections. Sort edges first to make determnistic since
298     // out_edges is a set of pointers.
299     std::vector<const Edge*> out_edges(node->out_edges().begin(),
300                                        node->out_edges().end());
301     std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare());
302     for (const auto edge : out_edges) {
303       auto output_node = edge->dst();
304       if (output_node->IsSink() || segment_nodes.count(output_node)) {
305         continue;
306       }
307       if (edge->IsControlEdge()) {
308         // Control output.
309         info->connections.emplace_back(output_node->name(), output_node->id(),
310                                        node_name, node_id,
311                                        /*input_edge=*/false);
312       } else {
313         // Data output.
314         int port = Graph::kControlSlot - 1;
315         // Use the source segment node name/port as key.
316         const string s = StrCat(node_name, ":", edge->src_output());
317         VLOG(1) << "Output edge = " << s;
318         if (output_to_engine_port.count(s)) {
319           port = output_to_engine_port.at(s);
320         } else {
321           port = output_to_engine_port.size();
322           output_to_engine_port.insert({s, port});
323         }
324         info->connections.emplace_back(
325             output_node->name(), output_node->id(), edge->dst_input(),
326             node_name, node_id, edge->src_output(), /*input_edge=*/false, port);
327       }
328     }
329   }  // For each segment node in topological order.
330 
331   // Construct the const nodes first.
332   subgraph_nodes.insert(subgraph_nodes.begin(), added_const_nodes.begin(),
333                         added_const_nodes.end());
334   string scope_name;
335   TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef(
336       g, graph_properties, subgraph_nodes, &info->connections,
337       &info->segment_graph_def, &scope_name));
338   info->engine_name = StrCat(scope_name, info->engine_name);
339   VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name
340           << "' to a GraphDef";
341   // TODO(sami): This should not happen once segmenter is updated.
342   if (segment_devices.size() == 1) {
343     info->device = *segment_devices.begin();
344   } else if (segment_devices.size() > 1) {
345     LOG(WARNING) << "Detected multiple(" << segment_devices.size()
346                  << ") devices for the segment. Picking first one to continue "
347                  << "but this shouldn't have happened";
348     info->device = *segment_devices.begin();
349   } else {
350     VLOG(1) << "No device is assigned to the segment. "
351             << "A device will be assigned during graph execution (inference).";
352   }
353   return Status::OK();
354 }
355 
356 // Helper function to update edge connection from the removed node to the
357 // engine node. If an outside node is gone, it must have been absorbed into
358 // an engine node. Find the engine node.
UpdateToEngineNode(const std::vector<EngineInfo> & infos,const size_t my_engine_id,const std::vector<Node * > & engine_nodes,const bool is_input_edge,const string & node_name,Node ** node,int * port)359 void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
360                         const size_t my_engine_id,
361                         const std::vector<Node*>& engine_nodes,
362                         const bool is_input_edge, const string& node_name,
363                         Node** node, int* port) {
364   for (size_t t = 0; t < infos.size(); ++t) {
365     if (t == my_engine_id) {
366       continue;
367     }
368     const auto& info = infos.at(t);
369     for (const auto& eng_conn : info.connections) {
370       // If the connection being updated is an input connection, the source of
371       // the connection must be an output connection of another engine. And vise
372       // versa.
373       if (is_input_edge == eng_conn.is_input_edge) continue;
374       if (eng_conn.inside_node_name == node_name &&
375           eng_conn.inside_port == *port) {
376         *node = CHECK_NOTNULL(engine_nodes[t]);
377         QCHECK_EQ(info.engine_name, (**node).name())
378             << "Engine name mismatch: " << info.engine_name << " vs "
379             << (**node).name();
380         *port = eng_conn.port_number;
381         return;
382       }
383     }
384   }
385   LOG(FATAL) << "Node " << (**node).name() << " not found in any engine.";
386 }
387 
388 // Function to insert a TRT engine node into the graph.
389 // Create engine nodes in the following way:
390 // 1. Each invocation of CreateTRTNode creates an engine node for infos[pos]
391 // 2. When an engine node is created, add it into the graph with necessary
392 //    re-wiring.
393 //    2.1. If the outside connected node is existing, connect the engine
394 //         node to it.
395 //    2.2. If the outside connected node is gone, it must have been absorted
396 //         into another engine node (which was processed before the processing
397 //         one). Connect to the pre-existing engine node instead.
398 // 3. In this way, we ensure the graph is topologically sort-able after each
399 //    invocation of CreateTRTNode().
CreateTRTNode(const ConversionParams & params,const std::vector<EngineInfo> & infos,int pos,int max_batch_size,Graph * graph,nvinfer1::IGpuAllocator * alloc,std::vector<Node * > * engine_nodes)400 Status CreateTRTNode(const ConversionParams& params,
401                      const std::vector<EngineInfo>& infos, int pos,
402                      int max_batch_size, Graph* graph,
403                      nvinfer1::IGpuAllocator* alloc,
404                      std::vector<Node*>* engine_nodes) {
405   const auto& info = infos.at(pos);
406   std::vector<TensorShapeProto> output_shape_protos;
407   std::vector<TensorShapeProto> input_shape_protos;
408   std::vector<PartialTensorShape> input_shapes;
409   std::vector<NodeDefBuilder::NodeOut> inputs;
410   std::vector<Node*> input_nodes;
411   std::vector<Node*> control_input_nodes;
412   std::unordered_set<string> control_input_names;
413   std::vector<DataType> out_types;
414 
415   VLOG(1) << "Processing " << info.engine_name;
416   // Collect needed info for creating the engine node in the graph
417   for (const auto& conn : info.connections) {
418     // Control edges
419     if (conn.is_control_edge()) {
420       // Skip control outputs for now. control output info are not needed for
421       // node creation and will be processed later.
422       if (!conn.is_input_edge) continue;
423 
424       // Rewrire control input if it's not found in original graph.
425       Node* input_node = graph->FindNodeId(conn.outside_id);
426       int port = Graph::kControlSlot;
427       if (!input_node) {
428         UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
429                            conn.outside_node_name, &input_node, &port);
430         QCHECK_EQ(Graph::kControlSlot, port);
431       }
432       if (!control_input_names.insert(input_node->name()).second) {
433         continue;
434       }
435       control_input_nodes.push_back(input_node);
436       VLOG(1) << "Engine Control Input " << input_node->name() << " -> "
437               << info.engine_name;
438     } else {
439       // Data edges
440       if (!conn.is_input_edge) {
441         // Set the shapes and data types of output edge.
442         TensorShapeProto out_shape;
443         // shape of the output node inside segment
444         conn.inside_shape.AsProto(&out_shape);
445         if (output_shape_protos.size() <= conn.port_number) {
446           output_shape_protos.resize(conn.port_number + 1);
447           out_types.resize(conn.port_number + 1);
448         }
449         output_shape_protos.at(conn.port_number) = out_shape;
450         out_types.at(conn.port_number) = conn.connection_type;
451       } else {
452         // Set the shapes and data types of input edge.
453         TensorShapeProto in_shape;
454         conn.outside_shape.AsProto(&in_shape);
455         if (input_shape_protos.size() <= conn.port_number) {
456           input_shape_protos.resize(conn.port_number + 1);
457           input_shapes.resize(conn.port_number + 1);
458         }
459         input_shape_protos.at(conn.port_number) = in_shape;
460         input_shapes.at(conn.port_number) = conn.outside_shape;
461         // Shape must be fully defined (excluding batch dimension) for static
462         // mode.
463         if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
464           for (int i = 1; i < conn.outside_shape.dims(); i++) {
465             if (conn.outside_shape.dim_size(i) <= 0) {
466               return errors::Internal(
467                   "Input shapes must be fully defined when in static mode. "
468                   "Please try is_dynamic_op=True (shape was ",
469                   conn.outside_shape.DebugString(), ")");
470             }
471           }
472         }
473 
474         // Rewrire data input if it's not found in original graph.
475         Node* input_node = graph->FindNodeId(conn.outside_id);
476         int port = conn.outside_port;
477         if (!input_node) {
478           UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
479                              conn.outside_node_name, &input_node, &port);
480         }
481         if (std::find_if(
482                 std::begin(inputs), std::end(inputs),
483                 [input_node, &port](const NodeDefBuilder::NodeOut& inp) {
484                   return inp.node == input_node->name() && inp.index == port;
485                 }) == std::end(inputs)) {
486           inputs.emplace_back(input_node->name(), port, conn.connection_type);
487           input_nodes.push_back(CHECK_NOTNULL(input_node));
488           VLOG(1) << "Engine Input " << input_node->name() << ":" << port
489                   << " -> " << info.engine_name << ":" << inputs.size() - 1;
490         }
491       }
492     }
493   }
494   // We don't support segments with no inputs. Fall back to native TF here to
495   // avoid crash later. Constant folding should've folded the ops that make up
496   // these segments.
497   if (inputs.empty()) {
498     return errors::Internal(
499         "Segment has no inputs (possible constfold failure)");
500   }
501 
502   const bool calibrate_int8 =
503       (info.precision_mode == TrtPrecisionMode::INT8 && info.use_calibration);
504   // Build the engine and get its serialized representation.
505   string segment_string;
506   if (info.engine_type == EngineInfo::EngineType::TRTStatic || calibrate_int8) {
507     // Create static engine for fp32/fp16 mode, and test validity of the engine
508     // for int8 calibration mode. We don't want engine to fail at the
509     // calibration time. So we are constructing a FP32 engine here to check its
510     // validity, and if it is a valid engine then we put the serialized graphdef
511     // to the op. Otherwise we skip node creation for this engine.
512     Logger trt_logger;
513     TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
514     // TODO(sami): What happens if 1st dim is not batch?
515     TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
516         info.segment_graph_def,
517         calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode,
518         max_batch_size, info.max_workspace_size_bytes, input_shapes,
519         &trt_logger, alloc, /*calibrator=*/nullptr, &engine,
520         info.use_calibration,
521         /*convert_successfully=*/nullptr));
522     TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
523     segment_string = string(static_cast<const char*>(engine_data->data()),
524                             engine_data->size());
525     if (calibrate_int8) {
526       // See above comment about why not putting this inside the 'else' branch.
527       segment_string = info.segment_graph_def.SerializeAsString();
528     }
529   } else {
530     segment_string = info.segment_graph_def.SerializeAsString();
531   }
532 
533   string prec_string;
534   TF_RETURN_IF_ERROR(TrtPrecisionModeToName(info.precision_mode, &prec_string));
535   NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp");
536   if (!info.device.empty()) node_builder.Device(info.device);
537   if (VLOG_IS_ON(1)) {
538     string ins = StrCat(info.engine_name, " inputs= ");
539     for (const auto& ii : inputs) {
540       StrAppend(&ins, ii.node, ":", ii.index, " ");
541     }
542     VLOG(1) << ins;
543   }
544   node_builder.Input(inputs);
545   for (const string& c : control_input_names) {
546     node_builder.ControlInput(c);
547   }
548 
549   if (info.engine_type == EngineInfo::EngineType::TRTStatic &&
550       !info.cached_engine_batches.empty()) {
551     LOG(WARNING) << "Cached engine batches are ignored for static engines";
552   }
553   NodeDef trt_node;
554   Status status =
555       node_builder.Attr("input_shapes", input_shape_protos)
556           .Attr("output_shapes", output_shape_protos)
557           .Attr("static_engine",
558                 info.engine_type == EngineInfo::EngineType::TRTStatic)
559           .Attr("segment_funcdef_name",
560                 params.use_function_backup
561                     ? StrCat(info.engine_name, "_native_segment")
562                     : "")
563           .Attr("serialized_segment", segment_string)
564           .Attr("calibration_data", "")
565           .Attr("max_cached_engines_count", info.maximum_cached_engines)
566           .Attr("workspace_size_bytes", info.max_workspace_size_bytes)
567           .Attr("precision_mode", prec_string)
568           .Attr("use_calibration", info.use_calibration)
569           .Attr("OutT", out_types)
570           .Finalize(&trt_node);
571   if (!status.ok()) {
572     LOG(ERROR) << "Node construction failed with" << status;
573     return status;
574   }
575   VLOG(1) << "Adding TRTEngine " << info.engine_name << " to graph";
576 
577   // Up until this point, graph is not modified. If we return !status.ok() from
578   // here, this segment will be skipped
579   // TODO(aaroey): let it return proper error status for the following logic
580   // instead of checking fail.
581   Node* engine_node = graph->AddNode(trt_node, &status);
582   (*engine_nodes)[pos] = engine_node;
583   if (!status.ok()) {
584     LOG(ERROR) << "Adding node failed " << status;
585     return status;
586   }
587   // Add control input and input edges to the engine node.
588   for (const auto in : control_input_nodes) {
589     VLOG(1) << "Connecting control edge from " << in->name() << " to "
590             << engine_node->name();
591     graph->AddControlEdge(in, engine_node);
592   }
593   VLOG(1) << "input_nodes size = " << input_nodes.size();
594   for (int i = 0; i < input_nodes.size(); ++i) {
595     Node* n = CHECK_NOTNULL(input_nodes[i]);
596     const auto& in = inputs[i];
597     VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index
598             << " to " << engine_node->name() << ":" << i;
599     graph->AddEdge(n, in.index, engine_node, i);
600   }
601 
602   // Updates the inputs of output edges destination nodes, and point them to the
603   // engine node.
604   for (auto& conn : info.connections) {
605     if (conn.is_input_edge) {
606       continue;
607     }
608     Node* output_node = graph->FindNodeId(conn.outside_id);
609     int port = conn.outside_port;
610     if (!output_node) {
611       UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false,
612                          conn.outside_node_name, &output_node, &port);
613     }
614     VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number
615             << " to " << output_node->name() << ":" << port;
616     if (conn.is_control_edge()) {
617       QCHECK_EQ(Graph::kControlSlot, port);
618       graph->AddControlEdge(engine_node, output_node);
619     } else {
620       auto new_edge =
621           graph->AddEdge(engine_node, conn.port_number, output_node, port);
622       QCHECK(new_edge) << "Adding a new edge failed " << engine_node->name()
623                        << ":" << conn.port_number << " -> "
624                        << output_node->name() << ":" << conn.outside_port;
625     }
626   }
627   return Status::OK();
628 }
629 
630 // Function to construct a funcdef from the segment and add it to the graph.
RegisterSegmentFunctionToFunctionLibrary(Graph * graph,const GraphDef & segment,const string & engine_name)631 Status RegisterSegmentFunctionToFunctionLibrary(Graph* graph,
632                                                 const GraphDef& segment,
633                                                 const string& engine_name) {
634   Graph sgraph(graph->flib_def());
635   GraphConstructorOptions gcopts;
636   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(gcopts, segment, &sgraph));
637   std::map<string, Node*> io_nodes;
638   int num_inputs = 0;
639   for (auto n : sgraph.op_nodes()) {
640     if (str_util::StartsWith(n->name(), kInputPHName)) {
641       num_inputs++;
642       io_nodes.insert({n->name(), n});
643     } else if (str_util::StartsWith(n->name(), kOutputPHName)) {
644       io_nodes.insert({n->name(), n});
645     }
646   }
647 
648   for (int i = 0; i < num_inputs; ++i) {
649     auto name = StrCat(kInputPHName, i);
650     auto node = io_nodes[name];
651     NodeDef nd;
652     NodeDefBuilder node_builder(StrCat(name, "_Arg"),
653                                 FunctionLibraryDefinition::kArgOp);
654     VLOG(1) << "Adding " << StrCat(name, "_Arg");
655     TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0))
656                            .Attr("index", i)
657                            .Finalize(&nd));
658     Status s;
659     auto node_arg = sgraph.AddNode(nd, &s);
660     if (!s.ok()) {
661       LOG(ERROR) << "Couldn't add _Arg node for " << name;
662     }
663     for (auto edge : node->out_edges()) {
664       sgraph.AddEdge(node_arg, 0, edge->dst(), edge->dst_input());
665       VLOG(1) << "Updating funcdef input " << node_arg->name() << ":" << 0
666               << " - > " << edge->dst()->name() << ":" << edge->dst_input();
667       if (!s.ok()) {
668         LOG(ERROR) << "Failed to update edge from " << node_arg->name()
669                    << " to " << edge->dst()->name() << ":" << edge->dst_input();
670       }
671     }
672     sgraph.RemoveNode(node);
673   }
674 
675   for (int i = 0; i < io_nodes.size() - num_inputs; ++i) {
676     auto name = StrCat(kOutputPHName, i);
677     auto node = io_nodes[name];
678     NodeDef nd;
679     NodeDefBuilder node_builder(StrCat(name, "_Ret"),
680                                 FunctionLibraryDefinition::kRetOp);
681     auto edge = *(node->in_edges().begin());
682     NodeDefBuilder::NodeOut nout(edge->src()->name(), edge->src_output(),
683                                  edge->src()->output_type(edge->src_output()));
684     VLOG(1) << " input " << nout.node << ":" << nout.index
685             << " dtype=" << DataTypeString(nout.data_type);
686     // nvcc complains that Input(<brace-enclosed initializer list>) is
687     // ambiguous, so do not use Input({nout}).
688     node_builder.Input(nout);
689     TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0))
690                            .Attr("index", i)
691                            .Finalize(&nd));
692     if (VLOG_IS_ON(3)) {
693       VLOG(3) << nd.DebugString();
694     }
695     Status s;
696     auto node_ret = sgraph.AddNode(nd, &s);
697     if (!s.ok()) {
698       LOG(ERROR) << "Couldn't add _Ret node for " << name;
699     }
700     VLOG(1) << "Update edge from " << edge->src()->name() << ":"
701             << edge->src_output() << " - > " << node_ret->name() << ":" << 0;
702     sgraph.AddEdge(edge->src(), edge->src_output(), node_ret, 0);
703     s = sgraph.UpdateEdge(edge->src(), edge->src_output(), node_ret, 0);
704     if (!s.ok()) {
705       LOG(ERROR) << "Failed to update edge from " << edge->src()->name() << ":"
706                  << edge->src_output() << " - > " << node_ret->name() << ":"
707                  << 0;
708     }
709     sgraph.RemoveNode(node);
710   }
711   FunctionDefLibrary fdeflib;
712   auto native_segment = fdeflib.add_function();
713   TF_RETURN_IF_ERROR(GraphToFunctionDef(
714       sgraph, StrCat(engine_name, "_native_segment"), native_segment));
715   // Set kIntsonDeviceAttr to true so that all TRTEngineOp outputs are always on
716   // a GPU device as expected. Otherwise, some of the tensors of type DT_INT32
717   // would be on host if the op generating the tensor has host memory tag set.
718   (*native_segment
719         ->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]
720       .set_b(true);
721   if (VLOG_IS_ON(7)) {
722     VLOG(7) << engine_name << " Function_Def ";
723     VLOG(7) << native_segment->DebugString();
724   }
725   VLOG(1) << "Adding funcdef to graphlib";
726   TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdeflib));
727   return Status::OK();
728 }
729 
GetDeviceAndAllocator(const ConversionParams & params,const EngineInfo & engine)730 std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
731                                                  const EngineInfo& engine) {
732   int cuda_device_id = -1;
733   Allocator* dev_allocator = nullptr;
734   if (params.cluster == nullptr || params.cluster->GetDeviceSet() == nullptr ||
735       engine.device.empty()) {
736     // If device is not set, use the first found GPU device for the conversion.
737     for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
738       TfGpuId tf_gpu_id(tf_gpu_id_value);
739       PlatformGpuId platform_gpu_id;
740       Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
741       if (s.ok()) {
742         VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
743                 << platform_gpu_id.value();
744         cuda_device_id = platform_gpu_id.value();
745         GPUOptions gpu_options;
746         // If the TF to Cuda gpu id mapping exist, the device and corresponding
747         // allocator must have been initialized already, so the
748         // GetGPUAllocator() call won't create a new allocator.
749         dev_allocator = GPUProcessState::singleton()->GetGPUAllocator(
750             gpu_options, tf_gpu_id, 1);
751         break;
752       }
753       LOG(ERROR) << "TF GPU with id " << tf_gpu_id_value << " does not exist "
754                  << s;
755     }
756     return std::make_pair(cuda_device_id, dev_allocator);
757   }
758 
759   // Use the device requested by the engine.
760   auto device_set = params.cluster->GetDeviceSet();
761   std::vector<Device*> devices;
762   DeviceNameUtils::ParsedName parsed_name;
763   if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) &&
764       parsed_name.has_id) {
765     device_set->FindMatchingDevices(parsed_name, &devices);
766   }
767   if (!devices.empty()) {
768     if (devices.size() > 1) {
769       string msg = "Found multiple matching devices using name '";
770       StrAppend(&msg, engine.device, "': ");
771       for (auto d : devices) StrAppend(&msg, d->name(), ", ");
772       StrAppend(&msg, ". Will get the allocator from first one.");
773       LOG(WARNING) << msg;
774     }
775     AllocatorAttributes alloc_attr;
776     cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id;
777     dev_allocator = devices[0]->GetAllocator(alloc_attr);
778     VLOG(1) << "Using allocator " << dev_allocator->Name()
779             << " and cuda_device_id " << cuda_device_id;
780   } else {
781     LOG(WARNING) << "Cluster is set but device '" << engine.device
782                  << "' is not found in the cluster";
783   }
784   return std::make_pair(cuda_device_id, dev_allocator);
785 }
786 
787 // Entry function from optimization pass.
ConvertAfterShapes(const ConversionParams & params)788 Status ConvertAfterShapes(const ConversionParams& params) {
789   // Sanity checks.
790   if (params.precision_mode == TrtPrecisionMode::INT8) {
791     if (params.use_calibration && !params.use_function_backup) {
792       return errors::InvalidArgument(
793           "Calibration requires enabling fallback to TF function execution.");
794     }
795   } else {
796     if (params.use_calibration) {
797       return errors::InvalidArgument(
798           "Calibration with FP32 or FP16 is not supported.");
799     }
800   }
801 
802   // Convert graphdef to graph.
803   FunctionLibraryDefinition flib(OpRegistry::Global(),
804                                  params.input_graph_def->library());
805   Graph graph(flib);
806   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
807                                             *params.input_graph_def, &graph));
808 
809   // Segment the graph into subgraphs that can be converted to TensorRT
810   segment::SegmentOptions segment_options;
811   // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
812   for (auto node : *(params.output_names)) {
813     segment_options.exclude_node_list.insert(node);
814   }
815   segment_options.minimum_segment_size = params.minimum_segment_size;
816   segment::SegmentNodesVector initial_segments;
817   TrtCandidateSelector candidate_selector(*params.graph_properties,
818                                           params.precision_mode);
819   TF_RETURN_IF_ERROR(segment::SegmentGraph(
820       &graph,
821       std::bind(&TrtCandidateSelector::IsTensorRTCandidate, &candidate_selector,
822                 std::placeholders::_1),
823       // Input validation is already done by TrtCandidateSelector, so we don't
824       // need to check the input edges.
825       [](const Edge* edge) { return true; }, OutputEdgeValidator(),
826       segment_options, &initial_segments));
827   LOG(INFO) << "Number of TensorRT candidate segments: "
828             << initial_segments.size();
829 
830   // Get the EngineInfo for each segment.
831   std::unordered_map<string, Node*> node_map;
832   TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
833   float total_num_nodes_in_segments = 0.;
834   std::vector<EngineInfo> engine_segments;
835   engine_segments.reserve(initial_segments.size());
836   std::vector<Node*> reverse_topo_order;
837   GetPostOrder(graph, &reverse_topo_order);
838   size_t total_engine_bytes_size = 0;
839   std::vector<size_t> engine_bytes_size;
840   segment::SegmentNodesVector converted_segments;
841   converted_segments.reserve(initial_segments.size());
842   for (size_t t = 0; t < initial_segments.size(); t++) {
843     auto& curr_segment = initial_segments.at(t);
844     EngineInfo curr_engine;
845     curr_engine.engine_name = StrCat("TRTEngineOp_", t);
846     Status status =
847         GetEngineInfo(&graph, *params.graph_properties, curr_segment.first,
848                       node_map, reverse_topo_order, &curr_engine);
849     if (!status.ok()) {
850       LOG(WARNING) << "Failed to get engine info for segment " << t << ": "
851                    << status;
852       continue;
853     }
854     curr_engine.precision_mode = params.precision_mode;
855     curr_engine.engine_type = ((params.is_dyn_op || params.use_calibration)
856                                    ? EngineInfo::EngineType::TRTDynamic
857                                    : EngineInfo::EngineType::TRTStatic);
858     curr_engine.use_calibration = params.use_calibration;
859     curr_engine.cached_engine_batches = params.cached_engine_batches;
860     curr_engine.maximum_cached_engines = params.max_cached_engines;
861     if (params.use_function_backup) {
862       status = RegisterSegmentFunctionToFunctionLibrary(
863           &graph, curr_engine.segment_graph_def, curr_engine.engine_name);
864       if (!status.ok()) {
865         LOG(WARNING) << "Failed to register segment graphdef as a function "
866                      << t << ": " << status;
867         continue;
868       }
869     }
870 
871     engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong());
872     total_engine_bytes_size += engine_bytes_size.back();
873     total_num_nodes_in_segments += curr_segment.first.size();
874     engine_segments.push_back(std::move(curr_engine));
875     converted_segments.push_back(std::move(curr_segment));
876 
877     if (VLOG_IS_ON(8)) {
878       string fname = engine_segments.back().engine_name;
879       StrAppend(&fname, ".pb");
880       std::fstream f;
881       f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
882       f << engine_segments.at(t).segment_graph_def.SerializeAsString();
883       f.close();
884     }
885   }
886 
887   // Create a TRT node for each segment using its EngineInfo.
888   int old_cuda_device = 0;
889   auto err = cudaGetDevice(&old_cuda_device);
890   if (err != cudaSuccess) {
891     LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err);
892   }
893   VLOG(1) << "Current cuda device is " << old_cuda_device;
894   std::vector<Node*> engine_nodes;
895   engine_nodes.resize(engine_segments.size());
896   for (int i = 0; i < engine_segments.size(); ++i) {
897     auto& engine = engine_segments.at(i);
898     // Partition the workspace size by the average of node ratio and segment
899     // graphdef size
900     engine.max_workspace_size_bytes =
901         params.max_workspace_size_bytes *
902         (engine_bytes_size.at(i) / total_engine_bytes_size +
903          converted_segments.at(i).first.size() / total_num_nodes_in_segments) /
904         2.0;
905     // The allocator is used to build the engine. The build and the built engine
906     // will be destroyed after we get the serialized engine string, so it's fine
907     // to use unique_ptr here.
908     std::unique_ptr<TRTBaseAllocator> alloc;
909     auto device_alloc = GetDeviceAndAllocator(params, engine);
910     int cuda_device_id = 0;
911     if (device_alloc.first >= 0) {
912       cuda_device_id = device_alloc.first;
913       alloc.reset(new TRTDeviceAllocator(device_alloc.second));
914     } else {
915       // Setting allocator as nullptr should get revert to the cudamalloc
916       LOG(WARNING) << "Can't identify the cuda device. Running on device 0 ";
917     }
918     cudaSetDevice(cuda_device_id);
919     auto status =
920         CreateTRTNode(params, engine_segments, i, params.max_batch_size, &graph,
921                       alloc.get(), &engine_nodes);
922 
923     string msg = StrCat("TensorRT node ", engine.engine_name,
924                         " added for segment ", i, " consisting of ",
925                         converted_segments.at(i).first.size(), " nodes");
926     if (status.ok()) {
927       LOG(INFO) << msg << " succeeded.";
928     } else {
929       // Graph is not modified.
930       LOG(WARNING) << msg << " failed: " << status << ". Fallback to TF...";
931     }
932     if (VLOG_IS_ON(1)) {
933       msg = "Segment consists of nodes: ";
934       for (const Node* node : converted_segments.at(i).first) {
935         StrAppend(&msg, node->name(), ", ");
936       }
937       VLOG(1) << msg;
938     }
939 
940     // If status is ok, we successfully added the node to the graph and can
941     // remove segment ops. Otherwise graph is not modified.
942     if (status.ok()) {
943       for (const Node* node : converted_segments.at(i).first) {
944         graph.RemoveNode(const_cast<Node*>(node));
945       }
946     }
947   }
948   cudaSetDevice(old_cuda_device);
949   graph.ToGraphDef(params.output_graph_def);
950   VLOG(1) << "Returning from conversion";
951   return Status::OK();
952 }
953 
954 }  // namespace convert
955 }  // namespace tensorrt
956 }  // namespace tensorflow
957 
958 #endif  // GOOGLE_TENSORRT
959 #endif  // GOOGLE_CUDA
960