• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
17 
18 #include <fstream>
19 #include <list>
20 #include <map>
21 #include <set>
22 #include <unordered_map>
23 #include <unordered_set>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/strings/str_cat.h"
28 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
29 #include "tensorflow/compiler/tf2tensorrt/convert/logger_registry.h"
30 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
31 #include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
32 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
33 #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
34 #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
35 #include "tensorflow/core/framework/function.h"
36 #include "tensorflow/core/framework/graph_to_functiondef.h"
37 #include "tensorflow/core/framework/node_def_builder.h"
38 #include "tensorflow/core/graph/algorithm.h"
39 #include "tensorflow/core/graph/graph.h"
40 #include "tensorflow/core/graph/graph_constructor.h"
41 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
42 #include "tensorflow/core/grappler/costs/graph_properties.h"
43 #include "tensorflow/core/grappler/devices.h"
44 #include "tensorflow/core/grappler/grappler_item.h"
45 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
46 #include "tensorflow/core/grappler/utils.h"
47 #include "tensorflow/core/lib/core/errors.h"
48 #include "tensorflow/core/lib/core/status.h"
49 #include "tensorflow/core/lib/strings/numbers.h"
50 #include "tensorflow/core/platform/logging.h"
51 #include "tensorflow/core/platform/types.h"
52 #include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
53 #include "tensorflow/core/protobuf/device_properties.pb.h"  // NOLINT
54 #include "tensorflow/core/protobuf/rewriter_config.pb.h"  // NOLINT
55 #include "tensorflow/core/util/device_name_utils.h"
56 
57 #if GOOGLE_CUDA
58 #if GOOGLE_TENSORRT
59 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
60 #include "third_party/tensorrt/NvInfer.h"
61 namespace tensorflow {
62 namespace tensorrt {
63 namespace convert {
64 using absl::StrAppend;
65 using absl::StrCat;
66 
67 namespace {
68 
BuildNodeMap(const Graph & graph,std::unordered_map<string,Node * > * node_map)69 Status BuildNodeMap(const Graph& graph,
70                     std::unordered_map<string, Node*>* node_map) {
71   for (auto* node : graph.op_nodes()) {
72     if (!node_map->insert({node->name(), node}).second) {
73       return errors::AlreadyExists("Node name is not unique in graph: " +
74                                    node->name());
75     }
76   }
77   return Status::OK();
78 }
79 
80 }  // namespace
81 
82 struct EdgePtrCompare {
operator ()tensorflow::tensorrt::convert::EdgePtrCompare83   bool operator()(const Edge* lhs, const Edge* rhs) const {
84     return lhs->id() < rhs->id();
85   }
86 };
87 
88 // TODO(laigd): instead of deciding the device here, the converter should accept
89 // a device name as one of the conversion parameter so users can control on
90 // which device they want to run the conversion.
GetFirstValidDeviceId()91 std::pair<TfGpuId, PlatformGpuId> GetFirstValidDeviceId() {
92   for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
93     TfGpuId tf_gpu_id(tf_gpu_id_value);
94     PlatformGpuId platform_gpu_id;
95     Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
96     if (s.ok()) {
97       VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
98               << platform_gpu_id.value();
99       return std::make_pair(tf_gpu_id, platform_gpu_id);
100     }
101   }
102   LOG(ERROR) << "Could not find any TF GPUs";
103   return std::make_pair(TfGpuId(-1), PlatformGpuId(-1));
104 }
105 
// Function to get subsegment information structure.
// Fills `info` for the segment `segment_nodes`: the segment's GraphDef, the
// input/output connections crossing the segment boundary, the engine name,
// and (when determinable from the nodes' devices) the device for the engine.
Status GetEngineInfo(const Graph* g,
                     const grappler::GraphProperties& graph_properties,
                     const std::set<const Node*>& segment_nodes,
                     const std::unordered_map<string, Node*>& node_map,
                     const std::vector<Node*>& reverse_topo_order,
                     EngineInfo* info) {
  std::vector<const Node*> subgraph_nodes;  // Topologically sorted nodes.
  std::set<const Node*> added_const_nodes;  // Used to prevent double insertion.
  std::set<string> segment_devices;

  // Map from src_node_name+port to the unique port numbers of the TRT op, where
  // the src_node_name is the name of the source node of the input/output
  // edge, thus there must not be any duplicates since source nodes of
  // input/output edges must be in different split of the graph.
  // TODO(aaroey): consider using node id and port instead.
  // TODO(aaroey): using topo order instead of reverting reverse topo order.
  std::unordered_map<string, int> input_to_engine_port, output_to_engine_port;
  // Walk the reverse-topological order backwards, i.e. in topological order.
  for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend();
       ++it) {
    const Node* node = *it;
    if (segment_nodes.count(node) == 0) continue;

    // Collect the node's device (requested device takes precedence over the
    // assigned one) so a single device can be picked for the whole segment.
    std::string device_name;
    if (!node->requested_device().empty()) {
      device_name = node->requested_device();
    } else if (node->has_assigned_device_name()) {
      // It appears that nodes will not have assigned devices at this point in
      // execution.
      device_name = node->assigned_device_name();
    } else {
      VLOG(2) << "Node " << node->name()
              << " neither have requested device nor assigned device";
    }

    if (!device_name.empty()) {
      // If device is set, it means device placement may have been done before,
      // so we need to assign a device for the TRTEngineOp if the assigned
      // device is a GPU device.
      DeviceNameUtils::ParsedName parsed_name;
      const bool parse_succeeded =
          DeviceNameUtils::ParseFullName(device_name, &parsed_name);
      if (!parse_succeeded) {
        VLOG(1) << "Failed to parse "
                << (node->requested_device().empty() ? "assigned" : "requested")
                << " device " << device_name << " of node " << node->name();
      } else if (parsed_name.type != "GPU") {
        VLOG(1) << "Node " << node->name()
                << " was assigned to a non-GPU device " << device_name;
      } else {
        segment_devices.insert(device_name);
      }
    }
    subgraph_nodes.push_back(node);

    const int node_id = node->id();
    const string& node_name = node->name();

    // Create input connections. Sort edges first to make deterministic since
    // in_edges is a set of pointers.
    std::vector<const Edge*> in_edges(node->in_edges().begin(),
                                      node->in_edges().end());
    std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare());
    for (const auto edge : in_edges) {
      auto input_node = edge->src();
      // Edges from inside the segment (or from the synthetic source node)
      // don't cross the engine boundary; no connection is recorded for them.
      if (input_node->IsSource() || segment_nodes.count(input_node)) {
        continue;
      }
      if (edge->IsControlEdge()) {
        if (input_node->type_string() != "Const") {
          // Non-Const control input.
          info->connections.emplace_back(input_node->name(), input_node->id(),
                                         node_name, node_id,
                                         /*input_edge=*/true);
        }
      } else if (input_node->type_string() == "Const") {
        // Add constant data input nodes into the segment graphdef (thus also in
        // the engine). We don't care if it has other output edges going into
        // other engines or TF nodes. Since we add it only to the segment
        // graphdef, not the segment itself, it won't be removed from the graph.
        // If it doesn't have any edges, TF will prune it out.
        //
        // Note that the segmenter already ensure that the constant data input
        // is valid and supported by the engine.
        if (!added_const_nodes.insert(input_node).second) {
          // Already added before.
          continue;
        }
        VLOG(1) << "Adding const node " << input_node->name();
      } else {
        // Non-const data input.
        int port = Graph::kControlSlot - 1;
        // Use the source non-segment node name/port as key.
        const string s = StrCat(input_node->name(), ":", edge->src_output());
        VLOG(1) << "Input edge = " << s;
        // Reuse the engine input port if this source tensor already feeds the
        // engine; otherwise allocate the next free port number.
        if (input_to_engine_port.count(s)) {
          port = input_to_engine_port.at(s);
        } else {
          port = input_to_engine_port.size();
          input_to_engine_port.insert({s, port});
        }
        info->connections.emplace_back(
            input_node->name(), input_node->id(), edge->src_output(), node_name,
            node_id, edge->dst_input(), /*input_edge=*/true, port);
      }
    }
    // Create output connections. Sort edges first to make deterministic since
    // out_edges is a set of pointers.
    std::vector<const Edge*> out_edges(node->out_edges().begin(),
                                       node->out_edges().end());
    std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare());
    for (const auto edge : out_edges) {
      auto output_node = edge->dst();
      // Edges into the segment (or to the synthetic sink node) don't cross
      // the engine boundary.
      if (output_node->IsSink() || segment_nodes.count(output_node)) {
        continue;
      }
      if (edge->IsControlEdge()) {
        // Control output.
        info->connections.emplace_back(output_node->name(), output_node->id(),
                                       node_name, node_id,
                                       /*input_edge=*/false);
      } else {
        // Data output.
        int port = Graph::kControlSlot - 1;
        // Use the source segment node name/port as key.
        const string s = StrCat(node_name, ":", edge->src_output());
        VLOG(1) << "Output edge = " << s;
        // Reuse the engine output port if this segment tensor already has
        // one; otherwise allocate the next free port number.
        if (output_to_engine_port.count(s)) {
          port = output_to_engine_port.at(s);
        } else {
          port = output_to_engine_port.size();
          output_to_engine_port.insert({s, port});
        }
        info->connections.emplace_back(
            output_node->name(), output_node->id(), edge->dst_input(),
            node_name, node_id, edge->src_output(), /*input_edge=*/false, port);
      }
    }
  }  // For each segment node in topological order.

  // Construct the const nodes first.
  subgraph_nodes.insert(subgraph_nodes.begin(), added_const_nodes.begin(),
                        added_const_nodes.end());
  string scope_name;
  TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef(
      g, graph_properties, subgraph_nodes, &info->connections,
      &info->segment_graph_def, &scope_name));
  info->engine_name = StrCat(scope_name, info->engine_name);
  VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name
          << "' to a GraphDef";
  // Device selection: a unique segment device wins; with multiple devices we
  // warn and take the first; with none we fall back to the first valid GPU
  // (if any), else leave the device unset for placement at execution time.
  if (segment_devices.size() == 1) {
    info->device = *segment_devices.begin();
  } else if (segment_devices.size() > 1) {
    LOG(WARNING) << "Detected multiple (" << segment_devices.size()
                 << ") devices for the segment. Picking first one to continue.";
    info->device = *segment_devices.begin();
  } else {
    TfGpuId tf_gpu_id;
    PlatformGpuId platform_gpu_id;
    std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId();
    if (tf_gpu_id.value() >= 0) {
      DeviceNameUtils::ParsedName parsed_name;
      parsed_name.type = "GPU";
      parsed_name.has_type = true;
      parsed_name.id = tf_gpu_id.value();
      parsed_name.has_id = true;
      info->device = DeviceNameUtils::ParsedNameToString(parsed_name);
    } else {
      VLOG(1) << "No device is assigned to the segment. A device will be "
                 "assigned during graph execution (inference).";
    }
  }
  return Status::OK();
}
280 
281 // Helper function to update edge connection from the removed node to the
282 // engine node. If an outside node is gone, it must have been absorbed into
283 // an engine node. Find the engine node.
UpdateToEngineNode(const std::vector<EngineInfo> & infos,const size_t my_engine_id,const std::vector<Node * > & engine_nodes,const bool is_input_edge,const string & node_name,Node ** node,int * port)284 void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
285                         const size_t my_engine_id,
286                         const std::vector<Node*>& engine_nodes,
287                         const bool is_input_edge, const string& node_name,
288                         Node** node, int* port) {
289   for (size_t t = 0; t < infos.size(); ++t) {
290     if (t == my_engine_id) {
291       continue;
292     }
293     const auto& info = infos.at(t);
294     for (const auto& eng_conn : info.connections) {
295       // If the connection being updated is an input connection, the source of
296       // the connection must be an output connection of another engine. And vise
297       // versa.
298       if (is_input_edge == eng_conn.is_input_edge) continue;
299       if (eng_conn.inside_node_name == node_name &&
300           eng_conn.inside_port == *port) {
301         *node = CHECK_NOTNULL(engine_nodes[t]);
302         QCHECK_EQ(info.engine_name, (**node).name())
303             << "Engine name mismatch: " << info.engine_name << " vs "
304             << (**node).name();
305         *port = eng_conn.port_number;
306         return;
307       }
308     }
309   }
310   LOG(FATAL) << "Node " << node_name << " not found in any engine.";
311 }
312 
// Function to insert a TRT engine node into the graph.
// Create engine nodes in the following way:
// 1. Each invocation of CreateTRTNode creates an engine node for infos[pos]
// 2. When an engine node is created, add it into the graph with necessary
//    re-wiring.
//    2.1. If the outside connected node is existing, connect the engine
//         node to it.
//    2.2. If the outside connected node is gone, it must have been absorbed
//         into another engine node (which was processed before the processing
//         one). Connect to the pre-existing engine node instead.
// 3. In this way, we ensure the graph is topologically sort-able after each
//    invocation of CreateTRTNode().
Status CreateTRTNode(const ConversionParams& params,
                     const std::vector<EngineInfo>& infos, int pos,
                     int max_batch_size, Graph* graph,
                     nvinfer1::IGpuAllocator* alloc,
                     std::vector<Node*>* engine_nodes) {
  const auto& info = infos.at(pos);
  std::vector<PartialTensorShape> input_shapes;  // Indexed by engine input port.
  std::vector<NodeDefBuilder::NodeOut> inputs;   // Data inputs of the new node.
  std::vector<Node*> input_nodes;                // Parallel to `inputs`.
  std::vector<Node*> control_input_nodes;
  std::unordered_set<string> control_input_names;  // Dedup for control inputs.
  std::vector<DataType> out_types;               // Indexed by engine output port.

  VLOG(1) << "Processing " << info.engine_name;
  // Collect needed info for creating the engine node in the graph
  for (const auto& conn : info.connections) {
    // Control edges
    if (conn.is_control_edge()) {
      // Skip control outputs for now. control output info are not needed for
      // node creation and will be processed later.
      if (!conn.is_input_edge) continue;

      // Rewire control input if it's not found in original graph.
      Node* input_node = graph->FindNodeId(conn.outside_id);
      int port = Graph::kControlSlot;
      if (!input_node) {
        // The outside node was absorbed into an earlier engine; redirect to
        // that engine node instead.
        UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
                           conn.outside_node_name, &input_node, &port);
        QCHECK_EQ(Graph::kControlSlot, port);
      }
      if (!control_input_names.insert(input_node->name()).second) {
        continue;
      }
      control_input_nodes.push_back(input_node);
      VLOG(1) << "Engine Control Input " << input_node->name() << " -> "
              << info.engine_name;
    } else {
      // Data edges
      if (!conn.is_input_edge) {
        // Set the data types of output edge.
        if (out_types.size() <= conn.port_number) {
          out_types.resize(conn.port_number + 1);
        }
        out_types.at(conn.port_number) = conn.connection_type;
      } else {
        // Set the shapes and data types of input edge.
        if (input_shapes.size() <= conn.port_number) {
          input_shapes.resize(conn.port_number + 1);
        }
        input_shapes.at(conn.port_number) = conn.outside_shape;
        // Shape must be fully defined (excluding batch dimension) for static
        // mode.
        if (params.use_implicit_batch &&
            info.engine_type == EngineInfo::EngineType::TRTStatic) {
          // Dimension 0 is the (implicit) batch dimension; only dims 1..n-1
          // must be fully defined.
          for (int i = 1; i < conn.outside_shape.dims(); i++) {
            if (conn.outside_shape.dim_size(i) <= 0) {
              return errors::Internal(
                  "Input shapes must be fully defined when in static mode. "
                  "Please try is_dynamic_op=True (shape was ",
                  conn.outside_shape.DebugString(), ")");
            }
          }
        }

        // Rewire data input if it's not found in original graph.
        Node* input_node = graph->FindNodeId(conn.outside_id);
        int port = conn.outside_port;
        if (!input_node) {
          UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
                             conn.outside_node_name, &input_node, &port);
        }
        // Only record each distinct (node, port) source tensor once.
        if (std::find_if(
                std::begin(inputs), std::end(inputs),
                [input_node, &port](const NodeDefBuilder::NodeOut& inp) {
                  return inp.node == input_node->name() && inp.index == port;
                }) == std::end(inputs)) {
          inputs.emplace_back(input_node->name(), port, conn.connection_type);
          input_nodes.push_back(CHECK_NOTNULL(input_node));
          VLOG(1) << "Engine Input " << input_node->name() << ":" << port
                  << " -> " << info.engine_name << ":" << inputs.size() - 1;
        }
      }
    }
  }
  // We don't support segments with no inputs. Fall back to native TF here to
  // avoid crash later. Constant folding should've folded the ops that make up
  // these segments.
  if (inputs.empty()) {
    return errors::Internal(
        "Segment has no inputs (possible constfold failure)");
  }

  const bool calibrate_int8 =
      (info.precision_mode == TrtPrecisionMode::INT8 && info.use_calibration);
  // Build the engine and get its serialized representation.
  // For dynamic engines the string stays empty and the engine is built lazily
  // at runtime.
  string segment_string;
  if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
    auto trt_logger = GetLoggerRegistry()->LookUp(params.trt_logger_name);
    // Create static engine for fp32/fp16 mode.
    TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
    // TODO(sami): What happens if 1st dim is not batch?
    TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
        info.segment_graph_def,
        calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode,
        max_batch_size, info.max_workspace_size_bytes, input_shapes, trt_logger,
        alloc, /*calibrator=*/nullptr, &engine, info.use_calibration,
        params.use_implicit_batch, /*convert_successfully=*/nullptr));
    TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
    segment_string = string(static_cast<const char*>(engine_data->data()),
                            engine_data->size());
  }

  string prec_string;
  TF_RETURN_IF_ERROR(TrtPrecisionModeToName(info.precision_mode, &prec_string));
  NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp");
  if (!info.device.empty()) node_builder.Device(info.device);
  if (VLOG_IS_ON(1)) {
    string ins = StrCat(info.engine_name, " inputs= ");
    for (const auto& ii : inputs) {
      StrAppend(&ins, ii.node, ":", ii.index, " ");
    }
    VLOG(1) << ins;
  }
  node_builder.Input(inputs);
  for (const string& c : control_input_names) {
    node_builder.ControlInput(c);
  }

  NodeDef trt_node;
  // `function` names the native TF fallback segment registered by
  // RegisterGraphToFunctionLibrary.
  NameAttrList function;
  function.set_name(StrCat(info.engine_name, "_native_segment"));
  Status status =
      node_builder
          .Attr("static_engine",
                info.engine_type == EngineInfo::EngineType::TRTStatic)
          .Attr("segment_func", function)
          .Attr("serialized_segment", segment_string)
          .Attr("calibration_data", "")
          .Attr("max_cached_engines_count", info.maximum_cached_engines)
          .Attr("workspace_size_bytes", info.max_workspace_size_bytes)
          .Attr("precision_mode", prec_string)
          .Attr("use_calibration", info.use_calibration)
          .Attr("_use_implicit_batch", params.use_implicit_batch)
          .Attr("OutT", out_types)
          .Finalize(&trt_node);
  if (!status.ok()) {
    LOG(ERROR) << "Node construction failed with" << status;
    return status;
  }
  VLOG(1) << "Adding TRTEngine " << info.engine_name << " to graph";

  // Up until this point, graph is not modified. If we return !status.ok() from
  // here, this segment will be skipped
  // TODO(aaroey): let it return proper error status for the following logic
  // instead of checking fail.
  Node* engine_node = graph->AddNode(trt_node, &status);
  (*engine_nodes)[pos] = engine_node;
  if (!status.ok()) {
    LOG(ERROR) << "Adding node failed " << status;
    return status;
  }
  // Add control input and input edges to the engine node.
  for (const auto in : control_input_nodes) {
    VLOG(1) << "Connecting control edge from " << in->name() << " to "
            << engine_node->name();
    graph->AddControlEdge(in, engine_node);
  }
  VLOG(1) << "input_nodes size = " << input_nodes.size();
  for (int i = 0; i < input_nodes.size(); ++i) {
    Node* n = CHECK_NOTNULL(input_nodes[i]);
    const auto& in = inputs[i];
    VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index
            << " to " << engine_node->name() << ":" << i;
    graph->AddEdge(n, in.index, engine_node, i);
  }

  // Updates the inputs of output edges destination nodes, and point them to the
  // engine node.
  for (auto& conn : info.connections) {
    if (conn.is_input_edge) {
      continue;
    }
    Node* output_node = graph->FindNodeId(conn.outside_id);
    int port = conn.outside_port;
    if (!output_node) {
      // Destination was absorbed into another engine; redirect to it.
      UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false,
                         conn.outside_node_name, &output_node, &port);
    }
    if (conn.is_control_edge()) {
      VLOG(1) << "Updating control edge from " << engine_node->name() << " to "
              << output_node->name();
      QCHECK_EQ(Graph::kControlSlot, port);
      graph->AddControlEdge(engine_node, output_node);
    } else {
      VLOG(1) << "Updating data edge from " << engine_node->name() << ":"
              << conn.port_number << " to " << output_node->name() << ":"
              << port;
      // Use UpdateEdge() to avoid adding the same edge multiple times.
      TF_CHECK_OK(
          graph->UpdateEdge(engine_node, conn.port_number, output_node, port));
    }
  }
  return Status::OK();
}
529 
RegisterGraphToFunctionLibrary(const GraphDef & segment_graph_def,Graph * graph,const string & engine_name)530 Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
531                                       Graph* graph, const string& engine_name) {
532   Graph segment_graph(graph->flib_def());
533   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
534                                             segment_graph_def, &segment_graph));
535   FunctionDefLibrary library;
536   auto segment_func = library.add_function();
537   TF_RETURN_IF_ERROR(GraphToFunctionDef(
538       segment_graph, StrCat(engine_name, "_native_segment"), segment_func));
539   // Set kIntsonDeviceAttr to true so that all TRTEngineOp outputs are always on
540   // a GPU device as expected. Otherwise, some of the tensors of type DT_INT32
541   // would be on host if the op generating the tensor has host memory tag set.
542   (*segment_func->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]
543       .set_b(true);
544   if (VLOG_IS_ON(7)) {
545     VLOG(7) << engine_name << " Function_Def ";
546     VLOG(7) << segment_func->DebugString();
547   }
548   VLOG(1) << "Adding funcdef " << segment_func->signature().name()
549           << " to graphlib";
550   TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(library));
551   return Status::OK();
552 }
553 
// Returns the CUDA device id and allocator to use when building the engine.
// Falls back to the first valid GPU when no cluster/device-set is available
// or `engine.device` is unset; otherwise resolves `engine.device` against the
// cluster's device set. Returns {-1, nullptr} when nothing suitable is found.
std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
                                                 const EngineInfo& engine) {
  int cuda_device_id = -1;
  Allocator* dev_allocator = nullptr;
  if (params.cluster == nullptr || params.cluster->GetDeviceSet() == nullptr ||
      engine.device.empty()) {
    // If device is not set, use the first found GPU device for the conversion.
    TfGpuId tf_gpu_id;
    PlatformGpuId platform_gpu_id;
    std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId();
    cuda_device_id = platform_gpu_id.value();
    if (cuda_device_id >= 0) {
      GPUOptions gpu_options;
      // If the TF to Cuda gpu id mapping exist, the device and corresponding
      // allocator must have been initialized already, so the
      // GetGPUAllocator() call won't create a new allocator.
      dev_allocator = GPUProcessState::singleton()->GetGPUAllocator(
          gpu_options, tf_gpu_id, 1);
    }
    return std::make_pair(cuda_device_id, dev_allocator);
  }

  // Use the device requested by the engine.
  auto device_set = params.cluster->GetDeviceSet();
  std::vector<Device*> devices;
  DeviceNameUtils::ParsedName parsed_name;
  // Only look up fully-identified device names (parse succeeded and an id is
  // present); otherwise `devices` stays empty and we warn below.
  if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) &&
      parsed_name.has_id) {
    device_set->FindMatchingDevices(parsed_name, &devices);
  }
  if (!devices.empty()) {
    if (devices.size() > 1) {
      string msg = "Found multiple matching devices using name '";
      StrAppend(&msg, engine.device, "': ");
      for (auto d : devices) StrAppend(&msg, d->name(), ", ");
      StrAppend(&msg, ". Will get the allocator from first one.");
      LOG(WARNING) << msg;
    }
    AllocatorAttributes alloc_attr;
    cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id;
    dev_allocator = devices[0]->GetAllocator(alloc_attr);
    VLOG(1) << "Using allocator " << dev_allocator->Name()
            << " and cuda_device_id " << cuda_device_id;
  } else {
    LOG(WARNING) << "Cluster is set but device '" << engine.device
                 << "' is not found in the cluster";
  }
  return std::make_pair(cuda_device_id, dev_allocator);
}
603 
604 // Entry function from optimization pass.
ConvertAfterShapes(const ConversionParams & params)605 Status ConvertAfterShapes(const ConversionParams& params) {
606   // Sanity checks.
607   if (params.precision_mode != TrtPrecisionMode::INT8 &&
608       params.use_calibration) {
609     return errors::InvalidArgument(
610         "Calibration with FP32 or FP16 is not supported.");
611   }
612 
613   // Convert graphdef to graph.
614   FunctionLibraryDefinition flib(OpRegistry::Global(),
615                                  params.input_graph_def->library());
616   Graph graph(flib);
617   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
618                                             *params.input_graph_def, &graph));
619 
620   // Segment the graph into subgraphs that can be converted to TensorRT
621   segment::SegmentOptions segment_options;
622   // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
623   for (auto node : *(params.output_names)) {
624     segment_options.exclude_node_list.insert(node);
625   }
626   segment_options.minimum_segment_size = params.minimum_segment_size;
627   segment::SegmentNodesVector initial_segments;
628   TrtNodeValidator validator(*params.graph_properties, params.precision_mode,
629                              params.use_calibration, params.use_implicit_batch);
630   TF_RETURN_IF_ERROR(segment::SegmentGraph(
631       &graph,
632       std::bind(&TrtNodeValidator::IsTensorRTCandidate, &validator,
633                 std::placeholders::_1),
634       // Input validation is already done by TrtNodeValidator, so we don't
635       // need to check the input edges.
636       [](const Edge* edge) { return true; }, OutputEdgeValidator(),
637       segment_options, &initial_segments));
638   LOG(INFO) << "Number of TensorRT candidate segments: "
639             << initial_segments.size();
640 
641   // Get the EngineInfo for each segment.
642   std::unordered_map<string, Node*> node_map;
643   TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
644   float total_num_nodes_in_segments = 0.;
645   std::vector<EngineInfo> engine_segments;
646   engine_segments.reserve(initial_segments.size());
647   std::vector<Node*> reverse_topo_order;
648   GetPostOrder(graph, &reverse_topo_order);
649   size_t total_engine_bytes_size = 0;
650   std::vector<size_t> engine_bytes_size;
651   segment::SegmentNodesVector converted_segments;
652   converted_segments.reserve(initial_segments.size());
653   for (size_t t = 0; t < initial_segments.size(); t++) {
654     auto& curr_segment = initial_segments.at(t);
655     EngineInfo curr_engine;
656     curr_engine.engine_name = StrCat("TRTEngineOp_", t);
657     Status status =
658         GetEngineInfo(&graph, *params.graph_properties, curr_segment, node_map,
659                       reverse_topo_order, &curr_engine);
660     if (!status.ok()) {
661       LOG(WARNING) << "Failed to get engine info for segment " << t << ": "
662                    << status;
663       continue;
664     }
665     curr_engine.precision_mode = params.precision_mode;
666     curr_engine.engine_type = ((params.is_dyn_op || params.use_calibration)
667                                    ? EngineInfo::EngineType::TRTDynamic
668                                    : EngineInfo::EngineType::TRTStatic);
669     curr_engine.use_calibration = params.use_calibration;
670     curr_engine.maximum_cached_engines = params.max_cached_engines;
671 
672     status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def,
673                                             &graph, curr_engine.engine_name);
674 
675     if (!status.ok()) {
676       LOG(WARNING) << "Failed to register segment graphdef to the library " << t
677                    << ": " << status;
678       continue;
679     }
680 
681     engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong());
682     total_engine_bytes_size += engine_bytes_size.back();
683     total_num_nodes_in_segments += curr_segment.size();
684     engine_segments.push_back(std::move(curr_engine));
685     converted_segments.push_back(std::move(curr_segment));
686 
687     if (VLOG_IS_ON(8)) {
688       string fname = engine_segments.back().engine_name;
689       StrAppend(&fname, ".pb");
690       std::fstream f;
691       f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
692       f << engine_segments.at(t).segment_graph_def.SerializeAsString();
693       f.close();
694     }
695   }
696 
697   // Create a TRT node for each segment using its EngineInfo.
698   int old_cuda_device = 0;
699   auto err = cudaGetDevice(&old_cuda_device);
700   if (err != cudaSuccess) {
701     LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err);
702   }
703   VLOG(1) << "Current cuda device is " << old_cuda_device;
704   std::vector<Node*> engine_nodes;
705   engine_nodes.resize(engine_segments.size());
706   for (int i = 0; i < engine_segments.size(); ++i) {
707     auto& engine = engine_segments.at(i);
708     // Partition the workspace size by the average of node ratio and segment
709     // graphdef size
710     engine.max_workspace_size_bytes =
711         params.max_workspace_size_bytes *
712         (engine_bytes_size.at(i) / total_engine_bytes_size +
713          converted_segments.at(i).size() / total_num_nodes_in_segments) /
714         2.0;
715     VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to "
716             << engine.engine_name;
717     // The allocator is used to build the engine. The build and the built engine
718     // will be destroyed after we get the serialized engine string, so it's fine
719     // to use unique_ptr here.
720     std::unique_ptr<TRTBaseAllocator> alloc;
721     auto device_alloc = GetDeviceAndAllocator(params, engine);
722     int cuda_device_id = 0;
723     if (device_alloc.first >= 0) {
724       cuda_device_id = device_alloc.first;
725       alloc.reset(new TRTDeviceAllocator(device_alloc.second));
726     } else {
727       // Setting allocator as nullptr should get revert to the cudamalloc
728       LOG(WARNING) << "Can't identify the cuda device. Running on device 0 ";
729     }
730     cudaSetDevice(cuda_device_id);
731     auto status =
732         CreateTRTNode(params, engine_segments, i, params.max_batch_size, &graph,
733                       alloc.get(), &engine_nodes);
734 
735     string msg = StrCat("segment ", i, " consisting of ",
736                         converted_segments.at(i).size(), " nodes by ",
737                         engine.engine_name);
738     if (status.ok()) {
739       LOG(INFO) << "Replaced " << msg << ".";
740     } else {
741       // Graph is not modified.
742       LOG(WARNING) << "Cannot replace " << msg
743                    << " (keeping original segment).";
744     }
745     if (VLOG_IS_ON(1)) {
746       msg = "Segment consists of nodes: ";
747       for (const Node* node : converted_segments.at(i)) {
748         StrAppend(&msg, node->name(), ", ");
749       }
750       VLOG(1) << msg;
751     }
752 
753     // If status is ok, we successfully added the node to the graph and can
754     // remove segment ops. Otherwise graph is not modified.
755     if (status.ok()) {
756       for (const Node* node : converted_segments.at(i)) {
757         graph.RemoveNode(const_cast<Node*>(node));
758       }
759     }
760   }
761   cudaSetDevice(old_cuda_device);
762   graph.ToGraphDef(params.output_graph_def);
763   VLOG(1) << "Returning from conversion";
764   return Status::OK();
765 }
766 
767 }  // namespace convert
768 }  // namespace tensorrt
769 }  // namespace tensorflow
770 
771 #endif  // GOOGLE_TENSORRT
772 #endif  // GOOGLE_CUDA
773