/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"

#include <fstream>
#include <list>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
#include "tensorflow/core/protobuf/device_properties.pb.h"  // NOLINT
#include "tensorflow/core/protobuf/rewriter_config.pb.h"  // NOLINT
#include "tensorflow/core/util/device_name_utils.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
namespace convert {
using absl::StrAppend;
using absl::StrCat;

TrtCandidateSelector::TrtCandidateSelector(
    const grappler::GraphProperties& graph_properties,
    TrtPrecisionMode precision_mode)
    : graph_properties_(graph_properties), precision_mode_(precision_mode) {}

Status TrtCandidateSelector::IsTensorRTCandidate(const Node* node) {
  std::vector<const Edge*> input_edges;
  TF_RETURN_IF_ERROR(node->input_edges(&input_edges));
  std::vector<std::pair<const NodeDef*, int>> input_node_and_ports;
  input_node_and_ports.reserve(input_edges.size());
  for (const Edge* input_edge : input_edges) {
    input_node_and_ports.emplace_back(&input_edge->src()->def(),
                                      input_edge->src_output());
  }
  return validator_.ValidateNode(node->def(), input_node_and_ports,
                                 precision_mode_, graph_properties_);
}

namespace {

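// Builds a map from node name to Node* for all op nodes in `graph`; returns
// an AlreadyExists error if two nodes share a name.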
Status BuildNodeMap(const Graph& graph,
                    std::unordered_map<string, Node*>* node_map) {
  for (auto* node : graph.op_nodes()) {
    if (!node_map->insert({node->name(), node}).second) {
      return errors::AlreadyExists("Node name is not unique in graph: " +
                                   node->name());
    }
  }
  return Status::OK();
}

}  // namespace

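// Example invocation (a sketch only; the caller's variable names and
// parameter values below are illustrative, not prescriptive):
//
//   GraphDef trt_graph;
//   TF_RETURN_IF_ERROR(ConvertGraphDefToTensorRT(
//       frozen_graph_def, /*output_names=*/{"logits"},
//       /*max_batch_size=*/8, /*max_workspace_size_bytes=*/1 << 30,
//       &trt_graph, TrtPrecisionMode::FP16, /*minimum_segment_size=*/3,
//       /*is_dyn_op=*/false, /*max_cached_engines=*/1,
//       /*cached_engine_batches=*/{}, /*use_calibration=*/false));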
Status ConvertGraphDefToTensorRT(
    const GraphDef& graph_def, const std::vector<string>& output_names,
    size_t max_batch_size, size_t max_workspace_size_bytes,
    GraphDef* new_graph_def, TrtPrecisionMode precision_mode,
    int minimum_segment_size, bool is_dyn_op, int max_cached_engines,
    std::vector<int> cached_engine_batches, bool use_calibration) {
  // Create GrapplerItem.
  grappler::GrapplerItem item;
  item.fetch = output_names;
  item.graph = graph_def;

  // TODO(aaroey): we should use a single machine cluster like the one below,
  // but the problem is that wrap_conversion would then depend on
  // direct_session and cause double-linking problems. To fix this we need to
  // fix or get rid of the swig dependency. Here we use VirtualCluster as a
  // workaround, and we need to create a session to initialize the underlying
  // device before calling this method.
#if 0
  // Create a single machine cluster. Note that this will create a session and
  // initialize the gpu devices.
  const int num_cpu_cores = grappler::GetNumAvailableLogicalCPUCores();
  const int num_gpus = grappler::GetNumAvailableGPUs();
  VLOG(2) << "cpu_cores: " << num_cpu_cores;
  VLOG(2) << "gpus: " << num_gpus;
  const int timeout_s = 60 * 10;
  std::unique_ptr<grappler::Cluster> cluster(
      new grappler::SingleMachine(timeout_s, num_cpu_cores, num_gpus));
  // These settings are the defaults in tensorflow/python/grappler/cluster.py.
  cluster->DisableDetailedStats(true);
  cluster->AllowSoftPlacement(true);
  cluster->SetNumWarmupSteps(10);
  TF_RETURN_IF_ERROR(cluster->Provision());
#else
  // Create a virtual cluster. Grappler requires a virtual cluster with a
  // proper GPU device in order to calculate flops > 0; otherwise it fails
  // with a FATAL error in debug mode. We use the numbers from a Pascal card
  // here so that flops > 0.
  DeviceProperties device_properties;
  device_properties.set_type("GPU");
  device_properties.mutable_environment()->insert({"architecture", "6"});
  device_properties.set_num_cores(3584);
  device_properties.set_frequency(1531);
  std::unique_ptr<grappler::Cluster> cluster(
      new grappler::VirtualCluster({{"/GPU:0", device_properties}}));
#endif

  // Create RewriterConfig.
  ConfigProto config_proto;
  auto& rw_cfg =
      *config_proto.mutable_graph_options()->mutable_rewrite_options();
  // TODO(aaroey): use only constant folding and layout for the time being,
  // since the newer optimizers break the graph for TRT.
  rw_cfg.add_optimizers("constfold");
  rw_cfg.add_optimizers("layout");
  auto optimizer = rw_cfg.add_custom_optimizers();
  optimizer->set_name("TensorRTOptimizer");
  auto& parameters = *(optimizer->mutable_parameter_map());
  parameters["minimum_segment_size"].set_i(minimum_segment_size);
  parameters["max_batch_size"].set_i(max_batch_size);
  parameters["is_dynamic_op"].set_b(is_dyn_op);
  parameters["max_workspace_size_bytes"].set_i(max_workspace_size_bytes);
  TF_RETURN_IF_ERROR(TrtPrecisionModeToName(
      precision_mode, parameters["precision_mode"].mutable_s()));
  parameters["maximum_cached_engines"].set_i(max_cached_engines);
  if (!cached_engine_batches.empty()) {
    auto list = parameters["cached_engine_batches"].mutable_list();
    for (const int batch : cached_engine_batches) {
      list->add_i(batch);
    }
  }
  parameters["use_calibration"].set_b(use_calibration);

  // Run the optimizer.
  grappler::MetaOptimizer meta_opt(nullptr, config_proto);
  TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, new_graph_def));

  if (VLOG_IS_ON(5)) {
    std::fstream f;
    f.open("TRTConversionInput.pb",
           std::fstream::out | std::fstream::binary | std::fstream::trunc);
    f << new_graph_def->SerializeAsString();
    f.close();
  }
  return Status::OK();
}

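// Compares Edge pointers by edge id so that edge containers with
// nondeterministic pointer order (e.g. the in_edges/out_edges sets) can be
// sorted deterministically before iteration.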
struct EdgePtrCompare {
  bool operator()(const Edge* lhs, const Edge* rhs) const {
    return lhs->id() < rhs->id();
  }
};

// Function to get the subsegment information structure.
Status GetEngineInfo(const Graph* g,
                     const grappler::GraphProperties& graph_properties,
                     const std::set<const Node*>& segment_nodes,
                     const std::unordered_map<string, Node*>& node_map,
                     const std::vector<Node*>& reverse_topo_order,
                     EngineInfo* info) {
  std::vector<const Node*> subgraph_nodes;  // Topologically sorted nodes.
  std::set<const Node*> added_const_nodes;  // Used to prevent double insertion.
  std::set<string> segment_devices;

  // Maps from "src_node_name:port" to the unique input/output port numbers of
  // the TRT op, where src_node_name is the name of the source node of the
  // input/output edge. There can be no duplicate keys, since the source nodes
  // of input/output edges must lie in different splits of the graph.
  // TODO(aaroey): consider using node id and port instead.
  // TODO(aaroey): use topo order instead of reverting reverse topo order.
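  // For example, if edges A:0->X and A:0->Y both cross into the segment, they
  // share engine input port 0 (keyed by "A:0"); a later crossing edge B:1->Z
  // would then be assigned engine input port 1. (Node names here are
  // illustrative.)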
  std::unordered_map<string, int> input_to_engine_port, output_to_engine_port;
  for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend();
       ++it) {
    const Node* node = *it;
    if (segment_nodes.count(node) == 0) continue;
    auto node_device = node->requested_device();
    if (!node_device.empty()) {
      // If the device is CPU, treat it as if no device was assigned. Don't add
      // CPU to segment_devices, because that would cause a segfault in
      // GetDeviceAndAllocator: GetDeviceAndAllocator assumes any already-set
      // device is a GPU.
      DeviceNameUtils::ParsedName parsed_name;
      DeviceNameUtils::ParseFullName(node_device, &parsed_name);
      if (parsed_name.type == "CPU") {
        VLOG(1) << "Node " << node->name() << " was assigned to the CPU. "
                << "Attempting to place on GPU.";
      } else {
        segment_devices.insert(node_device);
      }
    } else {
      if (node->has_assigned_device_name()) {
        // It appears that nodes will not have assigned devices at this point
        // in execution.
        segment_devices.insert(node->assigned_device_name());
      } else {
        VLOG(2) << "Node " << node->name()
                << " has neither a requested nor an assigned device.";
      }
    }
    subgraph_nodes.push_back(node);

    const int node_id = node->id();
    const string& node_name = node->name();

    // Create input connections. Sort the edges first to make the order
    // deterministic, since in_edges is a set of pointers.
    std::vector<const Edge*> in_edges(node->in_edges().begin(),
                                      node->in_edges().end());
    std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare());
    for (const auto edge : in_edges) {
      auto input_node = edge->src();
      if (input_node->IsSource() || segment_nodes.count(input_node)) {
        continue;
      }
      if (edge->IsControlEdge()) {
        // Control input.
        info->connections.emplace_back(input_node->name(), input_node->id(),
                                       node_name, node_id,
                                       /*input_edge=*/true);
      } else if (input_node->type_string() == "Const") {
        // Add constant data input nodes into the segment graphdef (thus also
        // into the engine). We don't care if they have other output edges
        // going into other engines or TF nodes. Since we add a const node
        // only to the segment graphdef, not to the segment itself, it won't
        // be removed from the graph. If it doesn't have any edges, TF will
        // prune it out.
        //
        // Note that the segmenter already ensures that the constant data
        // input is valid and supported by the engine.
        if (!added_const_nodes.insert(input_node).second) {
          // Already added before.
          continue;
        }
        VLOG(1) << "Adding const node " << input_node->name();
        // Since we already add (duplicate) the const input node to the
        // segment graphdef, it is no longer a data dependency; but to keep
        // the dependency correct we still add a control dependency.
        info->connections.emplace_back(input_node->name(), input_node->id(),
                                       node_name, node_id,
                                       /*input_edge=*/true);
      } else {
        // Non-const data input.
        int port = Graph::kControlSlot - 1;
        // Use the source non-segment node name/port as the key.
        const string s = StrCat(input_node->name(), ":", edge->src_output());
        VLOG(1) << "Input edge = " << s;
        if (input_to_engine_port.count(s)) {
          port = input_to_engine_port.at(s);
        } else {
          port = input_to_engine_port.size();
          input_to_engine_port.insert({s, port});
        }
        info->connections.emplace_back(
            input_node->name(), input_node->id(), edge->src_output(),
            node_name, node_id, edge->dst_input(), /*input_edge=*/true, port);
      }
    }
    // Create output connections. Sort the edges first to make the order
    // deterministic, since out_edges is a set of pointers.
    std::vector<const Edge*> out_edges(node->out_edges().begin(),
                                       node->out_edges().end());
    std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare());
    for (const auto edge : out_edges) {
      auto output_node = edge->dst();
      if (output_node->IsSink() || segment_nodes.count(output_node)) {
        continue;
      }
      if (edge->IsControlEdge()) {
        // Control output.
        info->connections.emplace_back(output_node->name(), output_node->id(),
                                       node_name, node_id,
                                       /*input_edge=*/false);
      } else {
        // Data output.
        int port = Graph::kControlSlot - 1;
        // Use the source segment node name/port as the key.
        const string s = StrCat(node_name, ":", edge->src_output());
        VLOG(1) << "Output edge = " << s;
        if (output_to_engine_port.count(s)) {
          port = output_to_engine_port.at(s);
        } else {
          port = output_to_engine_port.size();
          output_to_engine_port.insert({s, port});
        }
        info->connections.emplace_back(
            output_node->name(), output_node->id(), edge->dst_input(),
            node_name, node_id, edge->src_output(), /*input_edge=*/false,
            port);
      }
    }
  }  // For each segment node in topological order.

  // Construct the const nodes first.
  subgraph_nodes.insert(subgraph_nodes.begin(), added_const_nodes.begin(),
                        added_const_nodes.end());
  string scope_name;
  TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef(
      g, graph_properties, subgraph_nodes, &info->connections,
      &info->segment_graph_def, &scope_name));
  info->engine_name = StrCat(scope_name, info->engine_name);
  VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name
          << "' to a GraphDef";
  // TODO(sami): This should not happen once the segmenter is updated.
  if (segment_devices.size() == 1) {
    info->device = *segment_devices.begin();
  } else if (segment_devices.size() > 1) {
    LOG(WARNING) << "Detected multiple (" << segment_devices.size()
                 << ") devices for the segment. Picking the first one to "
                 << "continue, but this shouldn't have happened.";
    info->device = *segment_devices.begin();
  } else {
    VLOG(1) << "No device is assigned to the segment. "
            << "A device will be assigned during graph execution (inference).";
  }
  return Status::OK();
}

// Helper function to update an edge connection from a removed node to the
// engine node. If an outside node is gone, it must have been absorbed into
// an engine node; find that engine node.
void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
                        const size_t my_engine_id,
                        const std::vector<Node*>& engine_nodes,
                        const bool is_input_edge, const string& node_name,
                        Node** node, int* port) {
  for (size_t t = 0; t < infos.size(); ++t) {
    if (t == my_engine_id) {
      continue;
    }
    const auto& info = infos.at(t);
    for (const auto& eng_conn : info.connections) {
      // If the connection being updated is an input connection, the source of
      // the connection must be an output connection of another engine. And
      // vice versa.
      if (is_input_edge == eng_conn.is_input_edge) continue;
      if (eng_conn.inside_node_name == node_name &&
          eng_conn.inside_port == *port) {
        *node = CHECK_NOTNULL(engine_nodes[t]);
        QCHECK_EQ(info.engine_name, (**node).name())
            << "Engine name mismatch: " << info.engine_name << " vs "
            << (**node).name();
        *port = eng_conn.port_number;
        return;
      }
    }
  }
  LOG(FATAL) << "Node " << (**node).name() << " not found in any engine.";
}

// Function to insert a TRT engine node into the graph.
// Engine nodes are created in the following way:
// 1. Each invocation of CreateTRTNode creates an engine node for infos[pos].
// 2. When an engine node is created, it is added to the graph with the
//    necessary re-wiring:
//    2.1. If the outside connected node still exists, connect the engine
//         node to it.
//    2.2. If the outside connected node is gone, it must have been absorbed
//         into another engine node (one processed before the current one).
//         Connect to that pre-existing engine node instead.
// 3. In this way, we ensure the graph is topologically sortable after each
//    invocation of CreateTRTNode().
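// Note: 'engine_nodes' must already be sized to infos.size(); the node
// created here is stored at index 'pos' so that later CreateTRTNode calls
// can re-wire their edges to it via UpdateToEngineNode.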
Status CreateTRTNode(const ConversionParams& params,
                     const std::vector<EngineInfo>& infos, int pos,
                     int max_batch_size, Graph* graph,
                     nvinfer1::IGpuAllocator* alloc,
                     std::vector<Node*>* engine_nodes) {
  const auto& info = infos.at(pos);
  std::vector<TensorShapeProto> output_shape_protos;
  std::vector<TensorShapeProto> input_shape_protos;
  std::vector<PartialTensorShape> input_shapes;
  std::vector<NodeDefBuilder::NodeOut> inputs;
  std::vector<Node*> input_nodes;
  std::vector<Node*> control_input_nodes;
  std::unordered_set<string> control_input_names;
  std::vector<DataType> out_types;

  VLOG(1) << "Processing " << info.engine_name;
  // Collect the info needed for creating the engine node in the graph.
  for (const auto& conn : info.connections) {
    // Control edges.
    if (conn.is_control_edge()) {
      // Skip control outputs for now: control output info is not needed for
      // node creation and will be processed later.
      if (!conn.is_input_edge) continue;

      // Rewire the control input if it's not found in the original graph.
      Node* input_node = graph->FindNodeId(conn.outside_id);
      int port = Graph::kControlSlot;
      if (!input_node) {
        UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
                           conn.outside_node_name, &input_node, &port);
        QCHECK_EQ(Graph::kControlSlot, port);
      }
      if (!control_input_names.insert(input_node->name()).second) {
        continue;
      }
      control_input_nodes.push_back(input_node);
      VLOG(1) << "Engine Control Input " << input_node->name() << " -> "
              << info.engine_name;
    } else {
      // Data edges.
      if (!conn.is_input_edge) {
        // Set the shapes and data types of the output edge.
        TensorShapeProto out_shape;
        // Shape of the output node inside the segment.
        conn.inside_shape.AsProto(&out_shape);
        if (output_shape_protos.size() <= conn.port_number) {
          output_shape_protos.resize(conn.port_number + 1);
          out_types.resize(conn.port_number + 1);
        }
        output_shape_protos.at(conn.port_number) = out_shape;
        out_types.at(conn.port_number) = conn.connection_type;
      } else {
        // Set the shapes and data types of the input edge.
        TensorShapeProto in_shape;
        conn.outside_shape.AsProto(&in_shape);
        if (input_shape_protos.size() <= conn.port_number) {
          input_shape_protos.resize(conn.port_number + 1);
          input_shapes.resize(conn.port_number + 1);
        }
        input_shape_protos.at(conn.port_number) = in_shape;
        input_shapes.at(conn.port_number) = conn.outside_shape;
        // The shape must be fully defined (excluding the batch dimension) in
        // static mode.
        if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
          for (int i = 1; i < conn.outside_shape.dims(); i++) {
            if (conn.outside_shape.dim_size(i) <= 0) {
              return errors::Internal(
                  "Input shapes must be fully defined when in static mode. "
                  "Please try is_dynamic_op=True (shape was ",
                  conn.outside_shape.DebugString(), ")");
            }
          }
        }

        // Rewire the data input if it's not found in the original graph.
        Node* input_node = graph->FindNodeId(conn.outside_id);
        int port = conn.outside_port;
        if (!input_node) {
          UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
                             conn.outside_node_name, &input_node, &port);
        }
        if (std::find_if(
                std::begin(inputs), std::end(inputs),
                [input_node, &port](const NodeDefBuilder::NodeOut& inp) {
                  return inp.node == input_node->name() && inp.index == port;
                }) == std::end(inputs)) {
          inputs.emplace_back(input_node->name(), port, conn.connection_type);
          input_nodes.push_back(CHECK_NOTNULL(input_node));
          VLOG(1) << "Engine Input " << input_node->name() << ":" << port
                  << " -> " << info.engine_name << ":" << inputs.size() - 1;
        }
      }
    }
  }
  // We don't support segments with no inputs. Fall back to native TF here to
  // avoid a crash later. Constant folding should've folded the ops that make
  // up these segments.
  if (inputs.empty()) {
    return errors::Internal(
        "Segment has no inputs (possible constfold failure)");
  }

  const bool calibrate_int8 =
      (info.precision_mode == TrtPrecisionMode::INT8 && info.use_calibration);
  // Build the engine and get its serialized representation.
  string segment_string;
  if (info.engine_type == EngineInfo::EngineType::TRTStatic ||
      calibrate_int8) {
    // Create a static engine for fp32/fp16 mode, and test the validity of the
    // engine for int8 calibration mode. We don't want the engine to fail at
    // calibration time, so we construct an FP32 engine here to check its
    // validity; if the engine is valid, we put the serialized graphdef into
    // the op. Otherwise we skip node creation for this engine.
    Logger trt_logger;
    TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
    // TODO(sami): What happens if the 1st dim is not the batch?
    TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
        info.segment_graph_def,
        calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode,
        max_batch_size, info.max_workspace_size_bytes, input_shapes,
        &trt_logger, alloc, /*calibrator=*/nullptr, &engine,
        info.use_calibration,
        /*convert_successfully=*/nullptr));
    TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
    segment_string = string(static_cast<const char*>(engine_data->data()),
                            engine_data->size());
    if (calibrate_int8) {
      // See the comment above about why this is not inside the 'else' branch.
      segment_string = info.segment_graph_def.SerializeAsString();
    }
  } else {
    segment_string = info.segment_graph_def.SerializeAsString();
  }

  string prec_string;
  TF_RETURN_IF_ERROR(
      TrtPrecisionModeToName(info.precision_mode, &prec_string));
  NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp");
  if (!info.device.empty()) node_builder.Device(info.device);
  if (VLOG_IS_ON(1)) {
    string ins = StrCat(info.engine_name, " inputs= ");
    for (const auto& ii : inputs) {
      StrAppend(&ins, ii.node, ":", ii.index, " ");
    }
    VLOG(1) << ins;
  }
  node_builder.Input(inputs);
  for (const string& c : control_input_names) {
    node_builder.ControlInput(c);
  }

  if (info.engine_type == EngineInfo::EngineType::TRTStatic &&
      !info.cached_engine_batches.empty()) {
    LOG(WARNING) << "Cached engine batches are ignored for static engines";
  }
  NodeDef trt_node;
  Status status =
      node_builder.Attr("input_shapes", input_shape_protos)
          .Attr("output_shapes", output_shape_protos)
          .Attr("static_engine",
                info.engine_type == EngineInfo::EngineType::TRTStatic)
          .Attr("segment_funcdef_name",
                params.use_function_backup
                    ? StrCat(info.engine_name, "_native_segment")
                    : "")
          .Attr("serialized_segment", segment_string)
          .Attr("calibration_data", "")
          .Attr("max_cached_engines_count", info.maximum_cached_engines)
          .Attr("workspace_size_bytes", info.max_workspace_size_bytes)
          .Attr("precision_mode", prec_string)
          .Attr("use_calibration", info.use_calibration)
          .Attr("OutT", out_types)
          .Finalize(&trt_node);
  if (!status.ok()) {
    LOG(ERROR) << "Node construction failed with " << status;
    return status;
  }
  VLOG(1) << "Adding TRTEngine " << info.engine_name << " to graph";

  // Up until this point, the graph is not modified. If we return !status.ok()
  // from here on, this segment will be skipped.
  // TODO(aaroey): let it return a proper error status for the following logic
  // instead of check-failing.
  Node* engine_node = graph->AddNode(trt_node, &status);
  (*engine_nodes)[pos] = engine_node;
  if (!status.ok()) {
    LOG(ERROR) << "Adding node failed " << status;
    return status;
  }
  // Add control input and input edges to the engine node.
  for (const auto in : control_input_nodes) {
    VLOG(1) << "Connecting control edge from " << in->name() << " to "
            << engine_node->name();
    graph->AddControlEdge(in, engine_node);
  }
  VLOG(1) << "input_nodes size = " << input_nodes.size();
  for (int i = 0; i < input_nodes.size(); ++i) {
    Node* n = CHECK_NOTNULL(input_nodes[i]);
    const auto& in = inputs[i];
    VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index
            << " to " << engine_node->name() << ":" << i;
    graph->AddEdge(n, in.index, engine_node, i);
  }

  // Update the inputs of the destination nodes of output edges to point to
  // the engine node.
  for (auto& conn : info.connections) {
    if (conn.is_input_edge) {
      continue;
    }
    Node* output_node = graph->FindNodeId(conn.outside_id);
    int port = conn.outside_port;
    if (!output_node) {
      UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false,
                         conn.outside_node_name, &output_node, &port);
    }
    VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number
            << " to " << output_node->name() << ":" << port;
    if (conn.is_control_edge()) {
      QCHECK_EQ(Graph::kControlSlot, port);
      graph->AddControlEdge(engine_node, output_node);
    } else {
      auto new_edge =
          graph->AddEdge(engine_node, conn.port_number, output_node, port);
      QCHECK(new_edge) << "Adding a new edge failed " << engine_node->name()
                       << ":" << conn.port_number << " -> "
                       << output_node->name() << ":" << conn.outside_port;
    }
  }
  return Status::OK();
}

// Function to construct a funcdef from the segment and add it to the graph.
Status RegisterSegmentFunctionToFunctionLibrary(Graph* graph,
                                                const GraphDef& segment,
                                                const string& engine_name) {
  Graph sgraph(graph->flib_def());
  GraphConstructorOptions gcopts;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(gcopts, segment, &sgraph));
  std::map<string, Node*> io_nodes;
  int num_inputs = 0;
  for (auto n : sgraph.op_nodes()) {
    if (str_util::StartsWith(n->name(), kInputPHName)) {
      num_inputs++;
      io_nodes.insert({n->name(), n});
    } else if (str_util::StartsWith(n->name(), kOutputPHName)) {
      io_nodes.insert({n->name(), n});
    }
  }

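  // Replace each input placeholder (kInputPHName<i>) with a function _Arg
  // node of the same type, re-point its out-edges at the _Arg node, and
  // remove the placeholder.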
  for (int i = 0; i < num_inputs; ++i) {
    auto name = StrCat(kInputPHName, i);
    auto node = io_nodes[name];
    NodeDef nd;
    NodeDefBuilder node_builder(StrCat(name, "_Arg"),
                                FunctionLibraryDefinition::kArgOp);
    VLOG(1) << "Adding " << StrCat(name, "_Arg");
    TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0))
                           .Attr("index", i)
                           .Finalize(&nd));
    Status s;
    auto node_arg = sgraph.AddNode(nd, &s);
    if (!s.ok()) {
      LOG(ERROR) << "Couldn't add _Arg node for " << name;
    }
    for (auto edge : node->out_edges()) {
      sgraph.AddEdge(node_arg, 0, edge->dst(), edge->dst_input());
      VLOG(1) << "Updating funcdef input " << node_arg->name() << ":" << 0
              << " -> " << edge->dst()->name() << ":" << edge->dst_input();
      if (!s.ok()) {
        LOG(ERROR) << "Failed to update edge from " << node_arg->name()
                   << " to " << edge->dst()->name() << ":"
                   << edge->dst_input();
      }
    }
    sgraph.RemoveNode(node);
  }

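  // Similarly, replace each output placeholder (kOutputPHName<i>) with a
  // function _Ret node fed by the placeholder's single producer.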
  for (int i = 0; i < io_nodes.size() - num_inputs; ++i) {
    auto name = StrCat(kOutputPHName, i);
    auto node = io_nodes[name];
    NodeDef nd;
    NodeDefBuilder node_builder(StrCat(name, "_Ret"),
                                FunctionLibraryDefinition::kRetOp);
    auto edge = *(node->in_edges().begin());
    NodeDefBuilder::NodeOut nout(edge->src()->name(), edge->src_output(),
                                 edge->src()->output_type(edge->src_output()));
    VLOG(1) << " input " << nout.node << ":" << nout.index
            << " dtype=" << DataTypeString(nout.data_type);
    // nvcc complains that Input(<brace-enclosed initializer list>) is
    // ambiguous, so do not use Input({nout}).
    node_builder.Input(nout);
    TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0))
                           .Attr("index", i)
                           .Finalize(&nd));
    if (VLOG_IS_ON(3)) {
      VLOG(3) << nd.DebugString();
    }
    Status s;
    auto node_ret = sgraph.AddNode(nd, &s);
    if (!s.ok()) {
      LOG(ERROR) << "Couldn't add _Ret node for " << name;
    }
    VLOG(1) << "Update edge from " << edge->src()->name() << ":"
            << edge->src_output() << " -> " << node_ret->name() << ":" << 0;
    sgraph.AddEdge(edge->src(), edge->src_output(), node_ret, 0);
    s = sgraph.UpdateEdge(edge->src(), edge->src_output(), node_ret, 0);
    if (!s.ok()) {
      LOG(ERROR) << "Failed to update edge from " << edge->src()->name()
                 << ":" << edge->src_output() << " -> " << node_ret->name()
                 << ":" << 0;
    }
    sgraph.RemoveNode(node);
  }
  FunctionDefLibrary fdeflib;
  auto native_segment = fdeflib.add_function();
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      sgraph, StrCat(engine_name, "_native_segment"), native_segment));
  // Set kIntsOnDeviceAttr to true so that all TRTEngineOp outputs are always
  // on a GPU device as expected. Otherwise, some tensors of type DT_INT32
  // would be on the host if the op generating the tensor has the host memory
  // tag set.
  (*native_segment
        ->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]
      .set_b(true);
  if (VLOG_IS_ON(7)) {
    VLOG(7) << engine_name << " Function_Def ";
    VLOG(7) << native_segment->DebugString();
  }
  VLOG(1) << "Adding funcdef to graphlib";
  TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdeflib));
  return Status::OK();
}

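// Returns the CUDA device id and allocator to use when building the engine:
// the device requested by the engine if it can be found in the cluster,
// otherwise the first initialized TF GPU. Returns (-1, nullptr) if no
// suitable device is found.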
std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
                                                 const EngineInfo& engine) {
  int cuda_device_id = -1;
  Allocator* dev_allocator = nullptr;
  if (params.cluster == nullptr || params.cluster->GetDeviceSet() == nullptr ||
      engine.device.empty()) {
    // If the device is not set, use the first found GPU device for the
    // conversion.
    for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
      TfGpuId tf_gpu_id(tf_gpu_id_value);
      PlatformGpuId platform_gpu_id;
      Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
      if (s.ok()) {
        VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
                << platform_gpu_id.value();
        cuda_device_id = platform_gpu_id.value();
        GPUOptions gpu_options;
        // If the TF-to-CUDA gpu id mapping exists, the device and the
        // corresponding allocator must have been initialized already, so the
        // GetGPUAllocator() call won't create a new allocator.
        dev_allocator = GPUProcessState::singleton()->GetGPUAllocator(
            gpu_options, tf_gpu_id, 1);
        break;
      }
      LOG(ERROR) << "TF GPU with id " << tf_gpu_id_value << " does not exist "
                 << s;
    }
    return std::make_pair(cuda_device_id, dev_allocator);
  }

  // Use the device requested by the engine.
  auto device_set = params.cluster->GetDeviceSet();
  std::vector<Device*> devices;
  DeviceNameUtils::ParsedName parsed_name;
  if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) &&
      parsed_name.has_id) {
    device_set->FindMatchingDevices(parsed_name, &devices);
  }
  if (!devices.empty()) {
    if (devices.size() > 1) {
      string msg = "Found multiple matching devices using name '";
      StrAppend(&msg, engine.device, "': ");
      for (auto d : devices) StrAppend(&msg, d->name(), ", ");
      StrAppend(&msg, ". Will get the allocator from the first one.");
      LOG(WARNING) << msg;
    }
    AllocatorAttributes alloc_attr;
    cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id;
    dev_allocator = devices[0]->GetAllocator(alloc_attr);
    VLOG(1) << "Using allocator " << dev_allocator->Name()
            << " and cuda_device_id " << cuda_device_id;
  } else {
    LOG(WARNING) << "Cluster is set but device '" << engine.device
                 << "' is not found in the cluster";
  }
  return std::make_pair(cuda_device_id, dev_allocator);
}

// Entry function from the optimization pass.
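// The conversion proceeds in three phases: (1) segment the graph into
// TensorRT-compatible subgraphs, (2) build an EngineInfo for each segment
// (optionally registering a native-TF fallback function), and (3) replace
// each segment with a TRTEngineOp node, falling back to TF on failure.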
Status ConvertAfterShapes(const ConversionParams& params) {
  // Sanity checks.
  if (params.precision_mode == TrtPrecisionMode::INT8) {
    if (params.use_calibration && !params.use_function_backup) {
      return errors::InvalidArgument(
          "Calibration requires enabling fallback to TF function execution.");
    }
  } else {
    if (params.use_calibration) {
      return errors::InvalidArgument(
          "Calibration with FP32 or FP16 is not supported.");
    }
  }

  // Convert the graphdef to a graph.
  FunctionLibraryDefinition flib(OpRegistry::Global(),
                                 params.input_graph_def->library());
  Graph graph(flib);
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                            *params.input_graph_def, &graph));

  // Segment the graph into subgraphs that can be converted to TensorRT.
  segment::SegmentOptions segment_options;
  // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
  for (auto node : *(params.output_names)) {
    segment_options.exclude_node_list.insert(node);
  }
  segment_options.minimum_segment_size = params.minimum_segment_size;
  segment::SegmentNodesVector initial_segments;
  TrtCandidateSelector candidate_selector(*params.graph_properties,
                                          params.precision_mode);
  TF_RETURN_IF_ERROR(segment::SegmentGraph(
      &graph,
      std::bind(&TrtCandidateSelector::IsTensorRTCandidate,
                &candidate_selector, std::placeholders::_1),
      // Input validation is already done by TrtCandidateSelector, so we don't
      // need to check the input edges.
      [](const Edge* edge) { return true; }, OutputEdgeValidator(),
      segment_options, &initial_segments));
  LOG(INFO) << "Number of TensorRT candidate segments: "
            << initial_segments.size();

  // Get the EngineInfo for each segment.
  std::unordered_map<string, Node*> node_map;
  TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
  float total_num_nodes_in_segments = 0.;
  std::vector<EngineInfo> engine_segments;
  engine_segments.reserve(initial_segments.size());
  std::vector<Node*> reverse_topo_order;
  GetPostOrder(graph, &reverse_topo_order);
  size_t total_engine_bytes_size = 0;
  std::vector<size_t> engine_bytes_size;
  segment::SegmentNodesVector converted_segments;
  converted_segments.reserve(initial_segments.size());
  for (size_t t = 0; t < initial_segments.size(); t++) {
    auto& curr_segment = initial_segments.at(t);
    EngineInfo curr_engine;
    curr_engine.engine_name = StrCat("TRTEngineOp_", t);
    Status status =
        GetEngineInfo(&graph, *params.graph_properties, curr_segment.first,
                      node_map, reverse_topo_order, &curr_engine);
    if (!status.ok()) {
      LOG(WARNING) << "Failed to get engine info for segment " << t << ": "
                   << status;
      continue;
    }
    curr_engine.precision_mode = params.precision_mode;
    curr_engine.engine_type = ((params.is_dyn_op || params.use_calibration)
                                   ? EngineInfo::EngineType::TRTDynamic
                                   : EngineInfo::EngineType::TRTStatic);
    curr_engine.use_calibration = params.use_calibration;
    curr_engine.cached_engine_batches = params.cached_engine_batches;
    curr_engine.maximum_cached_engines = params.max_cached_engines;
    if (params.use_function_backup) {
      status = RegisterSegmentFunctionToFunctionLibrary(
          &graph, curr_engine.segment_graph_def, curr_engine.engine_name);
      if (!status.ok()) {
        LOG(WARNING) << "Failed to register segment graphdef as a function "
                     << t << ": " << status;
        continue;
      }
    }

    engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong());
    total_engine_bytes_size += engine_bytes_size.back();
    total_num_nodes_in_segments += curr_segment.first.size();
    engine_segments.push_back(std::move(curr_engine));
    converted_segments.push_back(std::move(curr_segment));

    if (VLOG_IS_ON(8)) {
      string fname = engine_segments.back().engine_name;
      StrAppend(&fname, ".pb");
      std::fstream f;
      f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
      f << engine_segments.at(t).segment_graph_def.SerializeAsString();
      f.close();
    }
  }

  // Create a TRT node for each segment using its EngineInfo.
  int old_cuda_device = 0;
  auto err = cudaGetDevice(&old_cuda_device);
  if (err != cudaSuccess) {
    LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err);
  }
  VLOG(1) << "Current cuda device is " << old_cuda_device;
  std::vector<Node*> engine_nodes;
  engine_nodes.resize(engine_segments.size());
  for (int i = 0; i < engine_segments.size(); ++i) {
    auto& engine = engine_segments.at(i);
    // Partition the workspace size by the average of the segment's share of
    // engine graphdef bytes and its share of segment nodes.
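    // For example, a segment holding 30% of the total engine graphdef bytes
    // and 50% of the total segment nodes would receive
    // (0.3 + 0.5) / 2 = 40% of params.max_workspace_size_bytes.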
    // Note: the byte-size ratio is computed in floating point; a plain
    // size_t division would truncate to zero.
    engine.max_workspace_size_bytes =
        params.max_workspace_size_bytes *
        (engine_bytes_size.at(i) /
             static_cast<float>(total_engine_bytes_size) +
         converted_segments.at(i).first.size() /
             total_num_nodes_in_segments) /
        2.0;
    // The allocator is used to build the engine. The builder and the built
    // engine will be destroyed after we get the serialized engine string, so
    // it's fine to use a unique_ptr here.
    std::unique_ptr<TRTBaseAllocator> alloc;
    auto device_alloc = GetDeviceAndAllocator(params, engine);
    int cuda_device_id = 0;
    if (device_alloc.first >= 0) {
      cuda_device_id = device_alloc.first;
      alloc.reset(new TRTDeviceAllocator(device_alloc.second));
    } else {
      // Setting the allocator to nullptr reverts to cudaMalloc.
      LOG(WARNING) << "Can't identify the cuda device. Running on device 0 ";
    }
    cudaSetDevice(cuda_device_id);
    auto status =
        CreateTRTNode(params, engine_segments, i, params.max_batch_size,
                      &graph, alloc.get(), &engine_nodes);

    string msg = StrCat("TensorRT node ", engine.engine_name,
                        " added for segment ", i, " consisting of ",
                        converted_segments.at(i).first.size(), " nodes");
    if (status.ok()) {
      LOG(INFO) << msg << " succeeded.";
    } else {
      // The graph is not modified.
      LOG(WARNING) << msg << " failed: " << status
                   << ". Falling back to TF...";
    }
    if (VLOG_IS_ON(1)) {
      msg = "Segment consists of nodes: ";
      for (const Node* node : converted_segments.at(i).first) {
        StrAppend(&msg, node->name(), ", ");
      }
      VLOG(1) << msg;
    }

    // If status is ok, we have successfully added the node to the graph and
    // can remove the segment's ops. Otherwise the graph is not modified.
    if (status.ok()) {
      for (const Node* node : converted_segments.at(i).first) {
        graph.RemoveNode(const_cast<Node*>(node));
      }
    }
  }
  cudaSetDevice(old_cuda_device);
  graph.ToGraphDef(params.output_graph_def);
  VLOG(1) << "Returning from conversion";
  return Status::OK();
}

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA