• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
17 
18 #include <fstream>
19 #include <list>
20 #include <map>
21 #include <set>
22 #include <unordered_map>
23 #include <unordered_set>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/strings/str_cat.h"
28 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
29 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
30 #include "tensorflow/compiler/tf2tensorrt/convert/logger_registry.h"
31 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
32 #include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
33 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
34 #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
35 #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
36 #include "tensorflow/core/common_runtime/graph_constructor.h"
37 #include "tensorflow/core/framework/function.h"
38 #include "tensorflow/core/framework/graph_to_functiondef.h"
39 #include "tensorflow/core/framework/node_def_builder.h"
40 #include "tensorflow/core/graph/algorithm.h"
41 #include "tensorflow/core/graph/graph.h"
42 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
43 #include "tensorflow/core/grappler/costs/graph_properties.h"
44 #include "tensorflow/core/grappler/devices.h"
45 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
46 #include "tensorflow/core/grappler/utils.h"
47 #include "tensorflow/core/lib/core/errors.h"
48 #include "tensorflow/core/lib/gtl/cleanup.h"
49 #include "tensorflow/core/lib/strings/numbers.h"
50 #include "tensorflow/core/platform/logging.h"
51 #include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
52 #include "tensorflow/core/protobuf/device_properties.pb.h"  // NOLINT
53 #include "tensorflow/core/protobuf/rewriter_config.pb.h"  // NOLINT
54 #include "tensorflow/core/util/device_name_utils.h"
55 #include "tensorflow/tools/graph_transforms/transform_utils.h"
56 
57 #if GOOGLE_CUDA && GOOGLE_TENSORRT
58 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
59 #include "third_party/tensorrt/NvInfer.h"
60 namespace tensorflow {
61 namespace tensorrt {
62 namespace convert {
63 
64 using absl::StrAppend;
65 using absl::StrCat;
66 using ::tensorflow::tensorrt::segment::ClusterProperty;
67 using ::tensorflow::tensorrt::segment::NodePtrCompare;
68 using ::tensorflow::tensorrt::segment::Segment;
69 
70 namespace {
71 
BuildNodeMap(const Graph & graph,std::unordered_map<string,Node * > * node_map)72 Status BuildNodeMap(const Graph& graph,
73                     std::unordered_map<string, Node*>* node_map) {
74   for (auto* node : graph.op_nodes()) {
75     if (!node_map->insert({node->name(), node}).second) {
76       return errors::AlreadyExists("Node name is not unique in graph: " +
77                                    node->name());
78     }
79   }
80   return Status::OK();
81 }
82 
GetEngineType(const ConversionParams & params)83 EngineInfo::EngineType GetEngineType(const ConversionParams& params) {
84   return (params.is_dyn_op || params.use_calibration)
85              ? EngineInfo::EngineType::TRTDynamic
86              : EngineInfo::EngineType::TRTStatic;
87 }
88 
89 // Returns true when use_implicit_batch is false or when we are building dynamic
90 // engine, to allow unknown size for dimensions rather than dimension 0.
AllowDynamicNonBatchDimension(const ConversionParams & params)91 bool AllowDynamicNonBatchDimension(const ConversionParams& params) {
92   return !params.use_implicit_batch ||
93          GetEngineType(params) == EngineInfo::EngineType::TRTDynamic;
94 }
95 
96 struct EdgePtrCompare {
operator ()tensorflow::tensorrt::convert::__anon288eb7380111::EdgePtrCompare97   bool operator()(const Edge* lhs, const Edge* rhs) const {
98     return lhs->id() < rhs->id();
99   }
100 };
101 
102 // TODO(laigd): instead of deciding the device here, the converter should accept
103 // a device name as one of the conversion parameter so users can control on
104 // which device they want to run the conversion.
GetFirstValidDeviceId()105 std::pair<TfGpuId, PlatformGpuId> GetFirstValidDeviceId() {
106   for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
107     TfGpuId tf_gpu_id(tf_gpu_id_value);
108     PlatformGpuId platform_gpu_id;
109     Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
110     if (s.ok()) {
111       VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
112               << platform_gpu_id.value();
113       return std::make_pair(tf_gpu_id, platform_gpu_id);
114     }
115   }
116   LOG(ERROR) << "Could not find any TF GPUs";
117   return std::make_pair(TfGpuId(-1), PlatformGpuId(-1));
118 }
119 
120 // Returns false for const nodes (we intend to drop control edges from those).
ShallKeepControlEdgeFrom(const Node * input_node)121 bool ShallKeepControlEdgeFrom(const Node* input_node) {
122   if (!input_node) {
123     LOG(ERROR) << "Node pointer is null, this should not happen";
124     return false;
125   }
126   return input_node->type_string() != "Const";
127 }
128 
129 // Function to get subsegment information structure.
GetEngineInfo(const Graph * g,const grappler::GraphProperties & graph_properties,const Segment & segment,const std::unordered_map<string,Node * > & node_map,const std::vector<Node * > & reverse_topo_order,EngineInfo * info)130 Status GetEngineInfo(const Graph* g,
131                      const grappler::GraphProperties& graph_properties,
132                      const Segment& segment,
133                      const std::unordered_map<string, Node*>& node_map,
134                      const std::vector<Node*>& reverse_topo_order,
135                      EngineInfo* info) {
136   std::vector<const Node*> subgraph_nodes;  // Topologically sorted nodes.
137   std::set<const Node*> added_const_nodes;  // Used to prevent double insertion.
138 
139   const ClusterProperty& segment_property = segment.property;
140   const std::set<const Node*, NodePtrCompare>& segment_nodes = segment.nodes;
141 
142   // The device assignment accumulated from the compatible device assignments
143   // for the nodes in the segment.
144   const DeviceNameUtils::ParsedName segment_device =
145       segment_property.DeviceName();
146   info->max_batch_size = segment_property.BatchSize().GetOptionalMaxBatchSize();
147 
148   // Map from src_node_name+port to the unique port numbers of the TRT op, where
149   // the src_node_name is the name of the source node of the input/output
150   // edge, thus there must not be any duplicates since source nodes of
151   // input/output edges must be in different split of the graph.
152   // TODO(aaroey): consider using node id and port instead.
153   // TODO(aaroey): using topo order instead of reverting reverse topo order.
154   std::unordered_map<string, int> input_to_engine_port, output_to_engine_port;
155   for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend();
156        ++it) {
157     const Node* node = *it;
158     if (segment_nodes.count(node) == 0) continue;
159     subgraph_nodes.push_back(node);
160 
161     const int node_id = node->id();
162     const string& node_name = node->name();
163 
164     // Create input connections. Sort edges first to make deterministic since
165     // in_edges is a set of pointers.
166     std::vector<const Edge*> in_edges(node->in_edges().begin(),
167                                       node->in_edges().end());
168     std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare());
169     for (const auto edge : in_edges) {
170       auto input_node = edge->src();
171       if (input_node->IsSource() || segment_nodes.count(input_node)) {
172         continue;
173       }
174       if (edge->IsControlEdge()) {
175         if (ShallKeepControlEdgeFrom(input_node)) {
176           // Non-Const control input.
177           info->connections.emplace_back(input_node->name(), input_node->id(),
178                                          node_name, node_id,
179                                          /*input_edge=*/true);
180         }
181       } else if (input_node->type_string() == "Const") {
182         // Add constant data input nodes into the segment graphdef (thus also in
183         // the engine). We don't care if it has other output edges going into
184         // other engines or TF nodes. Since we add it only to the segment
185         // graphdef, not the segment itself, it won't be removed from the graph.
186         // If it doesn't have any edges, TF will prune it out.
187         //
188         // Note that the segmenter already ensure that the constant data input
189         // is valid and supported by the engine.
190         if (!added_const_nodes.insert(input_node).second) {
191           // Already added before.
192           continue;
193         }
194         VLOG(1) << "Adding const node " << input_node->name();
195       } else {
196         // Non-const data input.
197         int port = Graph::kControlSlot - 1;
198         // Use the source non-segment node name/port as key.
199         const string s = StrCat(input_node->name(), ":", edge->src_output());
200         VLOG(1) << "Input edge = " << s;
201         if (input_to_engine_port.count(s)) {
202           port = input_to_engine_port.at(s);
203         } else {
204           port = input_to_engine_port.size();
205           input_to_engine_port.insert({s, port});
206         }
207         info->connections.emplace_back(
208             input_node->name(), input_node->id(), edge->src_output(), node_name,
209             node_id, edge->dst_input(), /*input_edge=*/true, port);
210       }
211     }
212     // Create output connections. Sort edges first to make deterministic since
213     // out_edges is a set of pointers.
214     std::vector<const Edge*> out_edges(node->out_edges().begin(),
215                                        node->out_edges().end());
216     std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare());
217     for (const auto edge : out_edges) {
218       auto output_node = edge->dst();
219       if (output_node->IsSink() || segment_nodes.count(output_node)) {
220         continue;
221       }
222       if (edge->IsControlEdge()) {
223         // Control output.
224         if (ShallKeepControlEdgeFrom(node)) {
225           info->connections.emplace_back(output_node->name(), output_node->id(),
226                                          node_name, node_id,
227                                          /*input_edge=*/false);
228         }
229       } else {
230         // Data output.
231         int port = Graph::kControlSlot - 1;
232         // Use the source segment node name/port as key.
233         const string s = StrCat(node_name, ":", edge->src_output());
234         VLOG(1) << "Output edge = " << s;
235         if (output_to_engine_port.count(s)) {
236           port = output_to_engine_port.at(s);
237         } else {
238           port = output_to_engine_port.size();
239           output_to_engine_port.insert({s, port});
240         }
241         info->connections.emplace_back(
242             output_node->name(), output_node->id(), edge->dst_input(),
243             node_name, node_id, edge->src_output(), /*input_edge=*/false, port);
244       }
245     }
246   }  // For each segment node in topological order.
247 
248   // Construct the const nodes first.
249   subgraph_nodes.insert(subgraph_nodes.begin(), added_const_nodes.begin(),
250                         added_const_nodes.end());
251   string scope_name;
252   TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef(
253       g, graph_properties, subgraph_nodes, &info->connections,
254       &info->segment_graph_def, &scope_name));
255   info->engine_name = StrCat(scope_name, info->engine_name);
256   VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name
257           << "' to a GraphDef";
258   if (segment_device.has_type) {
259     // If the accumulated device assignment for the segment has a device type,
260     // the segmenter guarantees the device type is GPU. Use the device
261     // assignment in this case.
262     if (segment_device.type != "GPU") {
263       return errors::Internal(
264           "segment device is not GPU: ",
265           DeviceNameUtils::ParsedNameToString(segment_device));
266     }
267     info->device = DeviceNameUtils::ParsedNameToString(segment_device);
268   } else {
269     TfGpuId tf_gpu_id;
270     PlatformGpuId platform_gpu_id;
271     std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId();
272     if (tf_gpu_id.value() >= 0) {
273       DeviceNameUtils::ParsedName parsed_name;
274       parsed_name.type = "GPU";
275       parsed_name.has_type = true;
276       parsed_name.id = tf_gpu_id.value();
277       parsed_name.has_id = true;
278       info->device = DeviceNameUtils::ParsedNameToString(parsed_name);
279     } else {
280       VLOG(1) << "No device is assigned to the segment. A device will be "
281                  "assigned during graph execution (inference).";
282     }
283   }
284   return Status::OK();
285 }
286 
287 // Helper function to update edge connection from the removed node to the
288 // engine node. If an outside node is gone, it must have been absorbed into
289 // an engine node. Find the engine node.
UpdateToEngineNode(const std::vector<EngineInfo> & infos,const size_t my_engine_id,const std::vector<Node * > & engine_nodes,const bool is_input_edge,const string & node_name,Node ** node,int * port)290 void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
291                         const size_t my_engine_id,
292                         const std::vector<Node*>& engine_nodes,
293                         const bool is_input_edge, const string& node_name,
294                         Node** node, int* port) {
295   for (size_t t = 0; t < infos.size(); ++t) {
296     if (t == my_engine_id) {
297       continue;
298     }
299     const auto& info = infos.at(t);
300     for (const auto& eng_conn : info.connections) {
301       // If the connection being updated is an input connection, the source of
302       // the connection must be an output connection of another engine. And vise
303       // versa.
304       if (is_input_edge == eng_conn.is_input_edge) continue;
305       if (eng_conn.inside_node_name == node_name &&
306           eng_conn.inside_port == *port) {
307         *node = CHECK_NOTNULL(engine_nodes[t]);
308         QCHECK_EQ(info.engine_name, (**node).name())
309             << "Engine name mismatch: " << info.engine_name << " vs "
310             << (**node).name();
311         *port = eng_conn.port_number;
312         return;
313       }
314     }
315   }
316   LOG(FATAL) << "Node " << node_name << " not found in any engine.";
317 }
318 
319 // Function to insert a TRT engine node into the graph.
320 // Create engine nodes in the following way:
321 // 1. Each invocation of CreateTRTNode creates an engine node for infos[pos]
322 // 2. When an engine node is created, add it into the graph with necessary
323 //    re-wiring.
324 //    2.1. If the outside connected node is existing, connect the engine
325 //         node to it.
326 //    2.2. If the outside connected node is gone, it must have been absorted
327 //         into another engine node (which was processed before the processing
328 //         one). Connect to the pre-existing engine node instead.
329 // 3. In this way, we ensure the graph is topologically sort-able after each
330 //    invocation of CreateTRTNode().
CreateTRTNode(const ConversionParams & params,const std::vector<EngineInfo> & infos,int pos,int default_max_batch_size,Graph * graph,std::vector<Node * > * engine_nodes)331 Status CreateTRTNode(const ConversionParams& params,
332                      const std::vector<EngineInfo>& infos, int pos,
333                      int default_max_batch_size, Graph* graph,
334                      std::vector<Node*>* engine_nodes) {
335   const auto& info = infos.at(pos);
336   std::vector<tensorflow::TensorShapeProto> input_shape_protos;
337   std::vector<PartialTensorShape> input_shapes;
338   std::vector<NodeDefBuilder::NodeOut> inputs;
339   std::vector<Node*> input_nodes;
340   std::vector<Node*> control_input_nodes;
341   std::unordered_set<string> control_input_names;
342   std::vector<DataType> out_types;
343 
344   VLOG(1) << "Processing " << info.engine_name;
345   // Collect needed info for creating the engine node in the graph
346   for (const auto& conn : info.connections) {
347     // Control edges
348     if (conn.is_control_edge()) {
349       // Skip control outputs for now. control output info are not needed for
350       // node creation and will be processed later.
351       if (!conn.is_input_edge) continue;
352 
353       // Rewrire control input if it's not found in original graph.
354       Node* input_node = graph->FindNodeId(conn.outside_id);
355       int port = Graph::kControlSlot;
356       if (!input_node) {
357         UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
358                            conn.outside_node_name, &input_node, &port);
359         QCHECK_EQ(Graph::kControlSlot, port);
360       }
361       if (!control_input_names.insert(input_node->name()).second) {
362         continue;
363       }
364       control_input_nodes.push_back(input_node);
365       VLOG(1) << "Engine Control Input " << input_node->name() << " -> "
366               << info.engine_name;
367     } else {
368       // Data edges
369       if (!conn.is_input_edge) {
370         // Set the data types of output edge.
371         if (out_types.size() <= conn.port_number) {
372           out_types.resize(conn.port_number + 1);
373         }
374         out_types.at(conn.port_number) = conn.connection_type;
375       } else {
376         // Set the shapes and data types of input edge.
377         if (input_shapes.size() <= conn.port_number) {
378           input_shape_protos.resize(conn.port_number + 1);
379           input_shapes.resize(conn.port_number + 1);
380         }
381         conn.outside_shape.AsProto(&input_shape_protos.at(conn.port_number));
382         input_shapes.at(conn.port_number) = conn.outside_shape;
383         // Shape must be fully defined (excluding batch dimension) for static
384         // mode.
385         if (params.use_implicit_batch &&
386             info.engine_type == EngineInfo::EngineType::TRTStatic) {
387           for (int i = 1; i < conn.outside_shape.dims(); i++) {
388             if (conn.outside_shape.dim_size(i) <= 0) {
389               return errors::Internal(
390                   "Not fully defined input shape when in static mode which "
391                   "should have been excluded by the segmenter. ");
392             }
393           }
394         }
395 
396         // Rewrire data input if it's not found in original graph.
397         Node* input_node = graph->FindNodeId(conn.outside_id);
398         int port = conn.outside_port;
399         if (!input_node) {
400           UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
401                              conn.outside_node_name, &input_node, &port);
402         }
403         if (std::find_if(
404                 std::begin(inputs), std::end(inputs),
405                 [input_node, &port](const NodeDefBuilder::NodeOut& inp) {
406                   return inp.node == input_node->name() && inp.index == port;
407                 }) == std::end(inputs)) {
408           inputs.emplace_back(input_node->name(), port, conn.connection_type);
409           input_nodes.push_back(CHECK_NOTNULL(input_node));
410           VLOG(1) << "Engine Input " << input_node->name() << ":" << port
411                   << " -> " << info.engine_name << ":" << inputs.size() - 1;
412         }
413       }
414     }
415   }
416   // We don't support segments with no inputs. Fall back to native TF here to
417   // avoid crash later. Constant folding should've folded the ops that make up
418   // these segments.
419   if (inputs.empty()) {
420     return errors::Internal(
421         "Segment has no inputs (possible constfold failure)");
422   }
423 
424   const bool calibrate_int8 =
425       (info.precision_mode == TrtPrecisionMode::INT8 && info.use_calibration);
426   // Build the engine and get its serialized representation.
427   string segment_string;
428 
429   int max_batch_size = info.max_batch_size.has_value()
430                            ? info.max_batch_size.value()
431                            : default_max_batch_size;
432 
433   if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
434     std::pair<int, Allocator*> device_allocator =
435         GetDeviceAndAllocator(params, info);
436     int cuda_device_id = 0;
437     std::unique_ptr<TRTBaseAllocator> trt_allocator;
438     if (device_allocator.first >= 0) {
439       cuda_device_id = device_allocator.first;
440       trt_allocator.reset(new TRTDeviceAllocator(device_allocator.second));
441     } else {
442       // The value in trt_allocator is a nullptr and cudamalloc will be used.
443       LOG_WARNING_WITH_PREFIX << "Can't identify the cuda device. Running on "
444                                  "device 0 and use cudamalloc as an allocator";
445     }
446     cudaSetDevice(cuda_device_id);
447 
448     auto trt_logger = GetLoggerRegistry()->LookUp(params.trt_logger_name);
449 
450     // Create static engines with precision_mode fp32/fp16.
451     TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
452     TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
453         info.segment_graph_def,
454         calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode,
455         max_batch_size, info.max_workspace_size_bytes, input_shapes, trt_logger,
456         trt_allocator.get(), /*calibrator=*/nullptr, &engine,
457         info.use_calibration, params.use_implicit_batch,
458         /*convert_successfully=*/nullptr,
459         /*profile=*/nullptr, info.engine_name));
460     TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
461     segment_string = string(static_cast<const char*>(engine_data->data()),
462                             engine_data->size());
463   }
464 
465   string prec_string;
466   TF_RETURN_IF_ERROR(TrtPrecisionModeToName(info.precision_mode, &prec_string));
467   NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp");
468   if (!info.device.empty()) node_builder.Device(info.device);
469   if (VLOG_IS_ON(1)) {
470     string ins = StrCat(info.engine_name, " inputs= ");
471     for (const auto& ii : inputs) {
472       StrAppend(&ins, ii.node, ":", ii.index, " ");
473     }
474     VLOG(1) << ins;
475   }
476   node_builder.Input(inputs);
477   for (const string& c : control_input_names) {
478     node_builder.ControlInput(c);
479   }
480 
481   NodeDef trt_node;
482   NameAttrList function;
483   function.set_name(StrCat(info.engine_name, "_native_segment"));
484   Status status =
485       node_builder.Attr("input_shapes", input_shape_protos)
486           .Attr("static_engine",
487                 info.engine_type == EngineInfo::EngineType::TRTStatic)
488           .Attr("segment_func", function)
489           .Attr("serialized_segment", segment_string)
490           .Attr("calibration_data", "")
491           .Attr("max_cached_engines_count", info.maximum_cached_engines)
492           .Attr("workspace_size_bytes", info.max_workspace_size_bytes)
493           .Attr("max_batch_size", max_batch_size)
494           .Attr("precision_mode", prec_string)
495           .Attr("use_calibration", info.use_calibration)
496           .Attr("_use_implicit_batch", params.use_implicit_batch)
497           .Attr("_allow_build_at_runtime", info.allow_build_at_runtime)
498           .Attr("OutT", out_types)
499           .Finalize(&trt_node);
500   if (!status.ok()) {
501     LOG(ERROR) << "Node construction failed with" << status;
502     return status;
503   }
504   VLOG(1) << "Adding TRTEngine " << info.engine_name << " to graph";
505 
506   // Up until this point, graph is not modified. If we return !status.ok() from
507   // here, this segment will be skipped
508   // TODO(aaroey): let it return proper error status for the following logic
509   // instead of checking fail.
510   Node* engine_node = graph->AddNode(trt_node, &status);
511   (*engine_nodes)[pos] = engine_node;
512   if (!status.ok()) {
513     LOG(ERROR) << "Adding node failed " << status;
514     return status;
515   }
516   // Add control input and input edges to the engine node.
517   for (const auto in : control_input_nodes) {
518     VLOG(1) << "Connecting control edge from " << in->name() << " to "
519             << engine_node->name();
520     graph->AddControlEdge(in, engine_node);
521   }
522   VLOG(1) << "input_nodes size = " << input_nodes.size();
523   for (int i = 0; i < input_nodes.size(); ++i) {
524     Node* n = CHECK_NOTNULL(input_nodes[i]);
525     const auto& in = inputs[i];
526     VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index
527             << " to " << engine_node->name() << ":" << i;
528     graph->AddEdge(n, in.index, engine_node, i);
529   }
530 
531   // Updates the inputs of output edges destination nodes, and point them to the
532   // engine node.
533   for (auto& conn : info.connections) {
534     if (conn.is_input_edge) {
535       continue;
536     }
537     Node* output_node = graph->FindNodeId(conn.outside_id);
538     int port = conn.outside_port;
539     if (!output_node) {
540       UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false,
541                          conn.outside_node_name, &output_node, &port);
542     }
543     if (conn.is_control_edge()) {
544       VLOG(1) << "Updating control edge from " << engine_node->name() << " to "
545               << output_node->name();
546       QCHECK_EQ(Graph::kControlSlot, port);
547       graph->AddControlEdge(engine_node, output_node);
548     } else {
549       VLOG(1) << "Updating data edge from " << engine_node->name() << ":"
550               << conn.port_number << " to " << output_node->name() << ":"
551               << port;
552       // Use UpdateEdge() to avoid adding the same edge multiple times.
553       TF_CHECK_OK(
554           graph->UpdateEdge(engine_node, conn.port_number, output_node, port));
555     }
556   }
557   return Status::OK();
558 }
559 
GetNextGraphSequenceNumber()560 int64 GetNextGraphSequenceNumber() {
561   static std::atomic<int64> graph_sequence_num;
562   return graph_sequence_num++;
563 }
564 
565 constexpr char kCastInputTypeAttrName[] = "SrcT";
566 
567 // Transforms node = cast(x, fp32) where datatype(x) != fp16 to:
568 //   castToFp16 = cast(x, fp16)
569 //   node = cast(castToFp16, fp32)
570 //
MaybeRewriteCastToFp32(GraphDef * graph_def,NodeDef * node_def)571 Status MaybeRewriteCastToFp32(GraphDef* graph_def, NodeDef* node_def) {
572   if (node_def->op() != "Cast") {
573     return Status::OK();
574   }
575 
576   DataTypeVector input_types;
577   DataTypeVector output_types;
578   TF_RETURN_IF_ERROR(
579       graph_transforms::GetInOutTypes(*node_def, &input_types, &output_types));
580 
581   if (input_types.size() != 1 || output_types.size() != 1) {
582     return errors::Internal("Bad cast operation");
583   }
584 
585   if (input_types[0] == DT_HALF || output_types[0] != DT_FLOAT) {
586     return Status::OK();
587   }
588 
589   VLOG(2) << "Rewriting cast to FP32 " << node_def->DebugString();
590 
591   NodeDef* castToFp16 = graph_def->add_node();
592   for (auto attr_value : node_def->attr()) {
593     (*castToFp16->mutable_attr())[attr_value.first] = attr_value.second;
594   }
595   castToFp16->set_name(node_def->name() + "_split");
596   castToFp16->set_op("Cast");
597   castToFp16->set_device(node_def->device());
598   castToFp16->add_input(node_def->input(0));
599   (*castToFp16->mutable_attr())[kCastOutputTypeAttrName].set_type(DT_HALF);
600 
601   node_def->set_input(0, castToFp16->name() + ":0");
602   (*node_def->mutable_attr())[kCastInputTypeAttrName].set_type(DT_HALF);
603 
604   VLOG(2) << castToFp16->DebugString();
605   VLOG(2) << node_def->DebugString();
606 
607   return Status::OK();
608 }
609 
610 }  // namespace
611 
RegisterGraphToFunctionLibrary(const GraphDef & segment_graph_def,Graph * graph,const string & engine_name)612 Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
613                                       Graph* graph, const string& engine_name) {
614   Graph segment_graph(graph->flib_def());
615   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
616                                             segment_graph_def, &segment_graph));
617   FunctionDefLibrary library;
618   auto segment_func = library.add_function();
619   TF_RETURN_IF_ERROR(GraphToFunctionDef(
620       segment_graph, StrCat(engine_name, "_native_segment"), segment_func));
621   if (VLOG_IS_ON(7)) {
622     VLOG(7) << engine_name << " Function_Def ";
623     VLOG(7) << segment_func->DebugString();
624   }
625   VLOG(1) << "Adding funcdef " << segment_func->signature().name()
626           << " to graphlib";
627   TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(library));
628   return Status::OK();
629 }
630 
GetDeviceAndAllocator(const ConversionParams & params,const EngineInfo & engine)631 std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
632                                                  const EngineInfo& engine) {
633   int cuda_device_id = -1;
634   Allocator* dev_allocator = nullptr;
635   if (params.cluster == nullptr || params.cluster->GetDeviceSet() == nullptr ||
636       engine.device.empty()) {
637     // If device is not set, use the first found GPU device for the conversion.
638     TfGpuId tf_gpu_id;
639     PlatformGpuId platform_gpu_id;
640     std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId();
641     cuda_device_id = platform_gpu_id.value();
642     if (cuda_device_id >= 0) {
643       GPUOptions gpu_options;
644       // If the TF to Cuda gpu id mapping exist, the device and corresponding
645       // allocator must have been initialized already, so the
646       // GetGPUAllocator() call won't create a new allocator.
647       dev_allocator = GPUProcessState::singleton()->GetGPUAllocator(
648           gpu_options, tf_gpu_id, /*total_bytes=*/1, /*peer_gpu_ids=*/{});
649     }
650     return std::make_pair(cuda_device_id, dev_allocator);
651   }
652 
653   // Use the device requested by the engine.
654   auto device_set = params.cluster->GetDeviceSet();
655   std::vector<Device*> devices;
656   DeviceNameUtils::ParsedName parsed_name;
657   if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) &&
658       parsed_name.has_id) {
659     device_set->FindMatchingDevices(parsed_name, &devices);
660   }
661   if (!devices.empty()) {
662     if (devices.size() > 1) {
663       string msg = "Found multiple matching devices using name '";
664       StrAppend(&msg, engine.device, "': ");
665       for (auto d : devices) StrAppend(&msg, d->name(), ", ");
666       StrAppend(&msg, ". Will get the allocator from first one.");
667       LOG_WARNING_WITH_PREFIX << msg;
668     }
669     AllocatorAttributes alloc_attr;
670     cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id;
671     dev_allocator = devices[0]->GetAllocator(alloc_attr);
672     VLOG(1) << "Using allocator " << dev_allocator->Name()
673             << " and cuda_device_id " << cuda_device_id;
674   } else {
675     LOG_WARNING_WITH_PREFIX << "Cluster is set but device '" << engine.device
676                             << "' is not found in the cluster";
677   }
678   return std::make_pair(cuda_device_id, dev_allocator);
679 }
680 
681 // Entry function from optimization pass.
ConvertAfterShapes(const ConversionParams & params)682 Status ConvertAfterShapes(const ConversionParams& params) {
683   // Sanity checks.
684   if (params.precision_mode != TrtPrecisionMode::INT8 &&
685       params.use_calibration) {
686     return errors::InvalidArgument(
687         "Calibration with FP32 or FP16 is not supported.");
688   }
689 
690   // Make a copy of the input_graph_def because grappler doesn't allow changes
691   // to the input_graph_def and GraphProperties only accepts GraphDef, but not
692   // Graph, as inputs.
693   //
694   // If the overhead of copying the input_graph_def becomes a concern, we can
695   // avoid the copy by (1) enhancing the GraphPropertiers representation to
696   // allow adding shape properties for newly created graph nodes and (2) rewrite
697   // the GraphDef transformation to Graph transformation.
698   GraphDef modified_graph_def = params.grappler_item->graph;
699   // When precision_mode is FP16, transform cast(x, fp32) to
700   // cast(cast(x, fp16), fp32). This creates cast(fp16, f32) that can be
701   // included in the TRTEngineOp as an TensorRT Identity layer for performance:
702   //  . Avoid cast(fp32, fp16) in the TRT engine implementation for fp16
703   //    precision.
704   //  . Changing the input to the TRTEngine from fp32 to fp16 may reduce data
705   //    moving from the host to the GPU.
706   if (params.precision_mode == TrtPrecisionMode::FP16) {
707     for (int i = 0; i < modified_graph_def.node_size(); i++) {
708       NodeDef* node_def = modified_graph_def.mutable_node(i);
709       TF_RETURN_IF_ERROR(MaybeRewriteCastToFp32(&modified_graph_def, node_def));
710     }
711   }
712 
713   // Construct a GrapplerItem using the modified graph_def and the input
714   // grappler_item.
715   grappler::GrapplerItem grappler_item =
716       params.grappler_item->WithGraph(std::move(modified_graph_def));
717   const GraphDef& graph_def = grappler_item.graph;
718 
719   grappler::GraphProperties static_graph_properties(grappler_item);
720   TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));
721 
722   // Convert graphdef to graph.
723   FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library());
724   Graph graph(flib);
725   TF_RETURN_IF_ERROR(
726       ConvertGraphDefToGraph(GraphConstructorOptions(), graph_def, &graph));
727 
728   // Segment the graph into subgraphs that can be converted to TensorRT
729   segment::SegmentOptions segment_options;
730   // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
731   for (const auto& node : *(params.output_names)) {
732     segment_options.exclude_node_list.insert(node);
733   }
734   segment_options.minimum_segment_size = params.minimum_segment_size;
735   segment_options.use_implicit_batch = params.use_implicit_batch;
736   if (segment_options.use_implicit_batch)
737     segment_options.maximum_batch_size = params.max_batch_size;
738   segment_options.allow_dynamic_non_batch_dim =
739       AllowDynamicNonBatchDimension(params);
740 
741   segment::SegmentVector initial_segments;
742   TrtNodeValidator validator(static_graph_properties, params.precision_mode,
743                              params.use_calibration, params.use_implicit_batch);
744   TF_RETURN_IF_ERROR(segment::SegmentGraph(
745       &graph, &static_graph_properties,
746       std::bind(&TrtNodeValidator::IsTensorRTCandidate, &validator,
747                 std::placeholders::_1),
748       // Input validation is already done by TrtNodeValidator, so we don't
749       // need to check the input edges.
750       [](const Edge* edge) { return true; }, OutputEdgeValidator(),
751       segment_options, &initial_segments));
752   LOG(INFO) << "Number of TensorRT candidate segments: "
753             << initial_segments.size();
754 
755   // Get the EngineInfo for each segment.
756   std::unordered_map<string, Node*> node_map;
757   TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
758   std::vector<EngineInfo> engine_segments;
759   engine_segments.reserve(initial_segments.size());
760   std::vector<Node*> reverse_topo_order;
761   GetPostOrder(graph, &reverse_topo_order);
762   segment::SegmentVector converted_segments;
763   converted_segments.reserve(initial_segments.size());
764   string engine_name_prefix =
765       StrCat("TRTEngineOp_", GetNextGraphSequenceNumber(), "_");
766   for (size_t t = 0; t < initial_segments.size(); t++) {
767     auto& curr_segment = initial_segments.at(t);
768     EngineInfo curr_engine;
769     curr_engine.engine_name = StrCat(engine_name_prefix, t);
770     Status status = GetEngineInfo(&graph, static_graph_properties, curr_segment,
771                                   node_map, reverse_topo_order, &curr_engine);
772     if (!status.ok()) {
773       LOG_WARNING_WITH_PREFIX << "Failed to get engine info for segment " << t
774                               << ": " << status;
775       continue;
776     }
777     curr_engine.precision_mode = params.precision_mode;
778     curr_engine.engine_type = GetEngineType(params);
779     curr_engine.use_calibration = params.use_calibration;
780     curr_engine.maximum_cached_engines = params.max_cached_engines;
781     curr_engine.allow_build_at_runtime = params.allow_build_at_runtime;
782     if (!curr_engine.max_batch_size.has_value()) {
783       curr_engine.max_batch_size = params.max_batch_size;
784     }
785 
786     status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def,
787                                             &graph, curr_engine.engine_name);
788 
789     if (!status.ok()) {
790       LOG_WARNING_WITH_PREFIX
791           << "Failed to register segment graphdef to the library " << t << ": "
792           << status;
793       continue;
794     }
795 
796     engine_segments.push_back(std::move(curr_engine));
797     converted_segments.push_back(std::move(curr_segment));
798 
799     if (VLOG_IS_ON(8)) {
800       string fname = engine_segments.back().engine_name;
801       StrAppend(&fname, ".pb");
802       std::fstream f;
803       f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
804       f << engine_segments.at(t).segment_graph_def.SerializeAsString();
805       f.close();
806     }
807   }
808 
809   // Save the cuda device if we may need to switch to another cuda device to
810   // build static engines.
811   absl::optional<int> old_cuda_device = absl::nullopt;
812   if (!params.is_dyn_op) {
813     int cuda_device_id;
814     cudaError_t cuda_error = cudaGetDevice(&cuda_device_id);
815     if (cuda_error != cudaSuccess) {
816       LOG_WARNING_WITH_PREFIX << "Couldn't get current device: "
817                               << cudaGetErrorString(cuda_error);
818     } else {
819       VLOG(1) << "Current cuda device is " << cuda_device_id;
820       old_cuda_device = cuda_device_id;
821     }
822   }
823 
824   auto restore_cuda_device = gtl::MakeCleanup([old_cuda_device] {
825     if (old_cuda_device.has_value()) {
826       cudaSetDevice(old_cuda_device.value());
827     }
828   });
829 
830   std::vector<Node*> engine_nodes;
831   engine_nodes.resize(engine_segments.size());
832   for (int i = 0; i < engine_segments.size(); ++i) {
833     auto& engine = engine_segments.at(i);
834     // TODO(b/170762693): implement the heuristic to calculate
835     // max_workspace_size_bytes.
836     engine.max_workspace_size_bytes = params.max_workspace_size_bytes;
837     VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to "
838             << engine.engine_name;
839     auto status = CreateTRTNode(params, engine_segments, i,
840                                 params.max_batch_size, &graph, &engine_nodes);
841 
842     string msg = StrCat("segment ", i, " consisting of ",
843                         converted_segments.at(i).nodes.size(), " nodes by ",
844                         engine.engine_name);
845     if (status.ok()) {
846       LOG(INFO) << "Replaced " << msg << ".";
847     } else {
848       // Graph is not modified.
849       LOG_WARNING_WITH_PREFIX << "Cannot replace " << msg
850                               << " reason: " << status.error_message()
851                               << " (keeping original segment).";
852     }
853     if (VLOG_IS_ON(1)) {
854       msg = "Segment consists of nodes: ";
855       for (const Node* node : converted_segments.at(i).nodes) {
856         StrAppend(&msg, node->name(), ", ");
857       }
858       VLOG(1) << msg;
859     }
860 
861     // If status is ok, we successfully added the node to the graph and can
862     // remove segment ops. Otherwise graph is not modified.
863     if (status.ok()) {
864       for (const Node* node : converted_segments.at(i).nodes) {
865         graph.RemoveNode(const_cast<Node*>(node));
866       }
867     }
868   }
869   graph.ToGraphDef(params.output_graph_def);
870   VLOG(1) << "Returning from conversion";
871   return Status::OK();
872 }
873 
874 }  // namespace convert
875 }  // namespace tensorrt
876 }  // namespace tensorflow
877 
878 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
879