/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"

#include <fstream>
#include <list>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/logger_registry.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
#include "tensorflow/core/protobuf/device_properties.pb.h"  // NOLINT
#include "tensorflow/core/protobuf/rewriter_config.pb.h"  // NOLINT
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/tensorrt/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
namespace convert {

using absl::StrAppend;
using absl::StrCat;
using ::tensorflow::tensorrt::segment::ClusterProperty;
using ::tensorflow::tensorrt::segment::NodePtrCompare;
using ::tensorflow::tensorrt::segment::Segment;

namespace {

Status BuildNodeMap(const Graph& graph,
                    std::unordered_map<string, Node*>* node_map) {
  for (auto* node : graph.op_nodes()) {
    if (!node_map->insert({node->name(), node}).second) {
      return errors::AlreadyExists("Node name is not unique in graph: " +
                                   node->name());
    }
  }
  return Status::OK();
}

EngineInfo::EngineType GetEngineType(const ConversionParams& params) {
  return (params.is_dyn_op || params.use_calibration)
             ? EngineInfo::EngineType::TRTDynamic
             : EngineInfo::EngineType::TRTStatic;
}

// Returns true when use_implicit_batch is false or when we are building a
// dynamic engine, to allow unknown sizes for dimensions other than dimension 0
// (the batch dimension).
bool AllowDynamicNonBatchDimension(const ConversionParams& params) {
  return !params.use_implicit_batch ||
         GetEngineType(params) == EngineInfo::EngineType::TRTDynamic;
}
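// For example (hypothetical shapes): with use_implicit_batch=true and a static
// engine, an input of shape [8, -1, 128] would be rejected because its second
// dimension is unknown; with a dynamic engine or explicit batch mode, the same
// shape is acceptable.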

struct EdgePtrCompare {
  bool operator()(const Edge* lhs, const Edge* rhs) const {
    return lhs->id() < rhs->id();
  }
};

// TODO(laigd): instead of deciding the device here, the converter should
// accept a device name as one of the conversion parameters so users can
// control which device they want to run the conversion on.
std::pair<TfDeviceId, PlatformDeviceId> GetFirstValidDeviceId() {
  for (int tf_device_id_value = 0; tf_device_id_value < 100;
       ++tf_device_id_value) {
    TfDeviceId tf_device_id(tf_device_id_value);
    PlatformDeviceId platform_device_id;
    Status s =
        GpuIdManager::TfToPlatformDeviceId(tf_device_id, &platform_device_id);
    if (s.ok()) {
      VLOG(1) << "Found TF GPU " << tf_device_id.value() << " at cuda device "
              << platform_device_id.value();
      return std::make_pair(tf_device_id, platform_device_id);
    }
  }
  LOG(ERROR) << "Could not find any TF GPUs";
  return std::make_pair(TfDeviceId(-1), PlatformDeviceId(-1));
}

// Returns false for const nodes (we intend to drop control edges from those).
bool ShallKeepControlEdgeFrom(const Node* input_node) {
  if (!input_node) {
    LOG(ERROR) << "Node pointer is null, this should not happen";
    return false;
  }
  return input_node->type_string() != "Const";
}

// Collects the information of a segment into an EngineInfo structure.
Status GetEngineInfo(const Graph* g,
                     const grappler::GraphProperties& graph_properties,
                     const Segment& segment,
                     const std::unordered_map<string, Node*>& node_map,
                     const std::vector<Node*>& reverse_topo_order,
                     EngineInfo* info) {
  std::vector<const Node*> subgraph_nodes;  // Topologically sorted nodes.
  std::set<const Node*> added_const_nodes;  // Used to prevent double insertion.

  const ClusterProperty& segment_property = segment.property;
  const std::set<const Node*, NodePtrCompare>& segment_nodes = segment.nodes;

  // The device assignment accumulated from the compatible device assignments
  // for the nodes in the segment.
  const DeviceNameUtils::ParsedName segment_device =
      segment_property.DeviceName();
  info->max_batch_size = segment_property.BatchSize().GetOptionalMaxBatchSize();

  // Map from src_node_name+port to the unique port numbers of the TRT op,
  // where src_node_name is the name of the source node of the input/output
  // edge. There must not be any duplicates, since source nodes of input/output
  // edges must be in different splits of the graph.
  // TODO(aaroey): consider using node id and port instead.
  // TODO(aaroey): use topo order instead of reversing the reverse topo order.
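  // For example (hypothetical edges): if the outside node "A" feeds two
  // segment nodes via A:0 -> n1 and A:0 -> n2, the key "A:0" is assigned a
  // single engine input port (say 0), and both connections record port 0.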
  std::unordered_map<string, int> input_to_engine_port, output_to_engine_port;
  for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend();
       ++it) {
    const Node* node = *it;
    if (segment_nodes.count(node) == 0) continue;
    subgraph_nodes.push_back(node);

    const int node_id = node->id();
    const string& node_name = node->name();

    // Create input connections. Sort the edges first to make the order
    // deterministic, since in_edges is a set of pointers.
    std::vector<const Edge*> in_edges(node->in_edges().begin(),
                                      node->in_edges().end());
    std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare());
    for (const auto edge : in_edges) {
      auto input_node = edge->src();
      if (input_node->IsSource() || segment_nodes.count(input_node)) {
        continue;
      }
      if (edge->IsControlEdge()) {
        if (ShallKeepControlEdgeFrom(input_node)) {
          // Non-Const control input.
          info->connections.emplace_back(input_node->name(), input_node->id(),
                                         node_name, node_id,
                                         /*input_edge=*/true);
        }
      } else if (input_node->type_string() == "Const") {
        // Add constant data input nodes into the segment graphdef (thus also
        // in the engine). We don't care if it has other output edges going
        // into other engines or TF nodes. Since we add it only to the segment
        // graphdef, not the segment itself, it won't be removed from the
        // graph. If it doesn't have any edges, TF will prune it out.
        //
        // Note that the segmenter already ensures that the constant data input
        // is valid and supported by the engine.
        if (!added_const_nodes.insert(input_node).second) {
          // Already added before.
          continue;
        }
        VLOG(1) << "Adding const node " << input_node->name();
      } else {
        // Non-const data input.
        int port = Graph::kControlSlot - 1;
        // Use the source non-segment node name/port as the key.
        const string s = StrCat(input_node->name(), ":", edge->src_output());
        VLOG(1) << "Input edge = " << s;
        if (input_to_engine_port.count(s)) {
          port = input_to_engine_port.at(s);
        } else {
          port = input_to_engine_port.size();
          input_to_engine_port.insert({s, port});
        }
        info->connections.emplace_back(
            input_node->name(), input_node->id(), edge->src_output(),
            node_name, node_id, edge->dst_input(), /*input_edge=*/true, port);
      }
    }
    // Create output connections. Sort the edges first to make the order
    // deterministic, since out_edges is a set of pointers.
    std::vector<const Edge*> out_edges(node->out_edges().begin(),
                                       node->out_edges().end());
    std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare());
    for (const auto edge : out_edges) {
      auto output_node = edge->dst();
      if (output_node->IsSink() || segment_nodes.count(output_node)) {
        continue;
      }
      if (edge->IsControlEdge()) {
        // Control output.
        if (ShallKeepControlEdgeFrom(node)) {
          info->connections.emplace_back(output_node->name(),
                                         output_node->id(), node_name, node_id,
                                         /*input_edge=*/false);
        }
      } else {
        // Data output.
        int port = Graph::kControlSlot - 1;
        // Use the source segment node name/port as the key.
        const string s = StrCat(node_name, ":", edge->src_output());
        VLOG(1) << "Output edge = " << s;
        if (output_to_engine_port.count(s)) {
          port = output_to_engine_port.at(s);
        } else {
          port = output_to_engine_port.size();
          output_to_engine_port.insert({s, port});
        }
        info->connections.emplace_back(
            output_node->name(), output_node->id(), edge->dst_input(),
            node_name, node_id, edge->src_output(), /*input_edge=*/false, port);
      }
    }
  }  // For each segment node in topological order.

  // Construct the const nodes first.
  subgraph_nodes.insert(subgraph_nodes.begin(), added_const_nodes.begin(),
                        added_const_nodes.end());
  TF_RETURN_IF_ERROR(
      ConvertSegmentToGraphDef(g, graph_properties, subgraph_nodes, info));
  VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name
          << "' to a GraphDef; has_int32_input = " << info->has_int32_input;
  if (segment_device.has_type) {
    // If the accumulated device assignment for the segment has a device type,
    // the segmenter guarantees that the device type is GPU. Use the device
    // assignment in this case.
    if (segment_device.type != "GPU") {
      return errors::Internal(
          "segment device is not GPU: ",
          DeviceNameUtils::ParsedNameToString(segment_device));
    }
    info->device = DeviceNameUtils::ParsedNameToString(segment_device);
  } else {
    TfDeviceId tf_device_id;
    PlatformDeviceId platform_device_id;
    std::tie(tf_device_id, platform_device_id) = GetFirstValidDeviceId();
    if (tf_device_id.value() >= 0) {
      DeviceNameUtils::ParsedName parsed_name;
      parsed_name.type = "GPU";
      parsed_name.has_type = true;
      parsed_name.id = tf_device_id.value();
      parsed_name.has_id = true;
      info->device = DeviceNameUtils::ParsedNameToString(parsed_name);
    } else {
      VLOG(1) << "No device is assigned to the segment. A device will be "
                 "assigned during graph execution (inference).";
    }
  }
  return Status::OK();
}

// Helper function to update an edge connection from a removed node to the
// engine node that absorbed it. If an outside node is gone, it must have been
// absorbed into an engine node; find that engine node.
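// For example (hypothetical names): if node "relu1" was absorbed into the
// engine node "TRTEngineOp_0_1" and exposed there as output port 2, a lookup
// of ("relu1", port 0) matches the connection whose inside node is "relu1"
// and inside_port is 0, so *node is redirected to the engine node and *port
// becomes 2 (that connection's port_number).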
void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
                        const size_t my_engine_id,
                        const std::vector<Node*>& engine_nodes,
                        const bool is_input_edge, const string& node_name,
                        Node** node, int* port) {
  for (size_t t = 0; t < infos.size(); ++t) {
    if (t == my_engine_id) {
      continue;
    }
    const auto& info = infos.at(t);
    for (const auto& eng_conn : info.connections) {
      // If the connection being updated is an input connection, the source of
      // the connection must be an output connection of another engine, and
      // vice versa.
      if (is_input_edge == eng_conn.is_input_edge) continue;
      if (eng_conn.inside_node_name == node_name &&
          eng_conn.inside_port == *port) {
        *node = CHECK_NOTNULL(engine_nodes[t]);
        QCHECK_EQ(info.engine_name, (**node).name())
            << "Engine name mismatch: " << info.engine_name << " vs "
            << (**node).name();
        *port = eng_conn.port_number;
        return;
      }
    }
  }
  LOG(FATAL) << "Node " << node_name << " not found in any engine.";
}

// Function to insert a TRT engine node into the graph.
// Create engine nodes in the following way:
// 1. Each invocation of CreateTRTNode creates an engine node for infos[pos].
// 2. When an engine node is created, add it into the graph with the necessary
//    re-wiring.
//    2.1. If the outside connected node still exists, connect the engine
//         node to it.
//    2.2. If the outside connected node is gone, it must have been absorbed
//         into another engine node (which was processed before the current
//         one). Connect to the pre-existing engine node instead.
// 3. In this way, we ensure the graph is topologically sortable after each
//    invocation of CreateTRTNode().
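// For example (hypothetical segments): if segment 0 absorbed node "conv1" and
// a node in segment 1 read "conv1:0", then by the time CreateTRTNode runs for
// segment 1, "conv1" is already gone from the graph; the input is rewired via
// UpdateToEngineNode() to the corresponding output port of segment 0's engine
// node.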
Status CreateTRTNode(const ConversionParams& params,
                     const std::vector<EngineInfo>& infos, int pos,
                     int default_max_batch_size, Graph* graph,
                     std::vector<Node*>* engine_nodes) {
  const auto& info = infos.at(pos);
  std::vector<tensorflow::TensorShapeProto> input_shape_protos;
  std::vector<PartialTensorShape> input_shapes;
  std::vector<NodeDefBuilder::NodeOut> inputs;
  std::vector<Node*> input_nodes;
  std::vector<Node*> control_input_nodes;
  std::unordered_set<string> control_input_names;
  std::vector<DataType> out_types;

  VLOG(1) << "Processing " << info.engine_name;
  // Collect the info needed for creating the engine node in the graph.
  for (const auto& conn : info.connections) {
    // Control edges
    if (conn.is_control_edge()) {
      // Skip control outputs for now: control output info is not needed for
      // node creation and will be processed later.
      if (!conn.is_input_edge) continue;

      // Rewire the control input if it's not found in the original graph.
      Node* input_node = graph->FindNodeId(conn.outside_id);
      int port = Graph::kControlSlot;
      if (!input_node) {
        UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
                           conn.outside_node_name, &input_node, &port);
        QCHECK_EQ(Graph::kControlSlot, port);
      }
      if (!control_input_names.insert(input_node->name()).second) {
        continue;
      }
      control_input_nodes.push_back(input_node);
      VLOG(1) << "Engine Control Input " << input_node->name() << " -> "
              << info.engine_name;
    } else {
      // Data edges
      if (!conn.is_input_edge) {
        // Set the data types of the output edge.
        if (out_types.size() <= conn.port_number) {
          out_types.resize(conn.port_number + 1);
        }
        out_types.at(conn.port_number) = conn.connection_type;
      } else {
        // Set the shapes and data types of the input edge.
        if (input_shapes.size() <= conn.port_number) {
          input_shape_protos.resize(conn.port_number + 1);
          input_shapes.resize(conn.port_number + 1);
        }
        conn.outside_shape.AsProto(&input_shape_protos.at(conn.port_number));
        input_shapes.at(conn.port_number) = conn.outside_shape;
        // The shape must be fully defined (excluding the batch dimension) for
        // static mode.
        if (params.use_implicit_batch &&
            info.engine_type == EngineInfo::EngineType::TRTStatic) {
          for (int i = 1; i < conn.outside_shape.dims(); i++) {
            if (conn.outside_shape.dim_size(i) <= 0) {
              return errors::Internal(
                  "Input shape is not fully defined in static mode; this "
                  "should have been excluded by the segmenter.");
            }
          }
        }

        // Rewire the data input if it's not found in the original graph.
        Node* input_node = graph->FindNodeId(conn.outside_id);
        int port = conn.outside_port;
        if (!input_node) {
          UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
                             conn.outside_node_name, &input_node, &port);
        }
        if (std::find_if(
                std::begin(inputs), std::end(inputs),
                [input_node, &port](const NodeDefBuilder::NodeOut& inp) {
                  return inp.node == input_node->name() && inp.index == port;
                }) == std::end(inputs)) {
          inputs.emplace_back(input_node->name(), port, conn.connection_type);
          input_nodes.push_back(CHECK_NOTNULL(input_node));
          VLOG(1) << "Engine Input " << input_node->name() << ":" << port
                  << " -> " << info.engine_name << ":" << inputs.size() - 1;
        }
      }
    }
  }
  // We don't support segments with no inputs. Fall back to native TF here to
  // avoid a crash later. Constant folding should've folded the ops that make
  // up these segments.
  if (inputs.empty()) {
    return errors::Internal(
        "Segment has no inputs (possible constfold failure)");
  }

  const bool calibrate_int8 =
      (info.precision_mode == TrtPrecisionMode::INT8 && info.use_calibration);
  // Build the engine and get its serialized representation.
  string segment_string;

  int max_batch_size = info.max_batch_size.has_value()
                           ? info.max_batch_size.value()
                           : default_max_batch_size;

  if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
    std::pair<int, Allocator*> device_allocator =
        GetDeviceAndAllocator(params, info);
    int cuda_device_id = 0;
    std::unique_ptr<TRTBaseAllocator> trt_allocator;
    if (device_allocator.first >= 0) {
      cuda_device_id = device_allocator.first;
      trt_allocator.reset(new TRTDeviceAllocator(device_allocator.second));
    } else {
      // The value in trt_allocator is a nullptr and cudaMalloc will be used.
      LOG_WARNING_WITH_PREFIX << "Can't identify the cuda device. Running on "
                                 "device 0 and using cudaMalloc as the "
                                 "allocator";
    }
    cudaSetDevice(cuda_device_id);

    auto trt_logger = GetLoggerRegistry()->LookUp(params.trt_logger_name);

    // Create static engines with precision_mode fp32/fp16.
    TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
    TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
        info.segment_graph_def,
        calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode,
        max_batch_size, info.max_workspace_size_bytes, input_shapes, trt_logger,
        trt_allocator.get(), /*calibrator=*/nullptr, &engine,
        info.use_calibration, params.use_implicit_batch,
        /*convert_successfully=*/nullptr,
        /*profile=*/nullptr, info.engine_name));
    TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
    segment_string = string(static_cast<const char*>(engine_data->data()),
                            engine_data->size());
  }

  string prec_string;
  TF_RETURN_IF_ERROR(TrtPrecisionModeToName(info.precision_mode, &prec_string));
  NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp");
  if (!info.device.empty()) node_builder.Device(info.device);
  if (VLOG_IS_ON(1)) {
    string ins = StrCat(info.engine_name, " inputs= ");
    for (const auto& ii : inputs) {
      StrAppend(&ins, ii.node, ":", ii.index, " ");
    }
    VLOG(1) << ins;
  }
  node_builder.Input(inputs);
  for (const string& c : control_input_names) {
    node_builder.ControlInput(c);
  }

  NodeDef trt_node;
  NameAttrList function;
  function.set_name(StrCat(info.engine_name, "_native_segment"));
  node_builder.Attr("input_shapes", input_shape_protos)
      .Attr("static_engine",
            info.engine_type == EngineInfo::EngineType::TRTStatic)
      .Attr("segment_func", function)
      .Attr("serialized_segment", segment_string)
      .Attr("calibration_data", "")
      .Attr("max_cached_engines_count", info.maximum_cached_engines)
      .Attr("workspace_size_bytes", info.max_workspace_size_bytes)
      .Attr("max_batch_size", max_batch_size)
      .Attr("precision_mode", prec_string)
      .Attr("use_calibration", info.use_calibration)
      .Attr("_use_implicit_batch", params.use_implicit_batch)
      .Attr("_allow_build_at_runtime", info.allow_build_at_runtime)
      .Attr("OutT", out_types);

  if (!params.use_implicit_batch) {
    node_builder.Attr("profile_strategy",
                      ProfileStrategyToName(params.profile_strategy));
  }

  Status status = node_builder.Finalize(&trt_node);
  if (!status.ok()) {
    LOG(ERROR) << "Node construction failed with " << status;
    return status;
  }
  VLOG(1) << "Adding TRTEngine " << info.engine_name << " to graph";

  // Up until this point, the graph is not modified. If we return !status.ok()
  // from here, this segment will be skipped.
  // TODO(aaroey): let it return a proper error status for the following logic
  // instead of check-failing.
  Node* engine_node = graph->AddNode(trt_node, &status);
  (*engine_nodes)[pos] = engine_node;
  if (!status.ok()) {
    LOG(ERROR) << "Adding node failed " << status;
    return status;
  }
  // Add control input and input edges to the engine node.
  for (const auto in : control_input_nodes) {
    VLOG(1) << "Connecting control edge from " << in->name() << " to "
            << engine_node->name();
    graph->AddControlEdge(in, engine_node);
  }
  VLOG(1) << "input_nodes size = " << input_nodes.size();
  for (int i = 0; i < input_nodes.size(); ++i) {
    Node* n = CHECK_NOTNULL(input_nodes[i]);
    const auto& in = inputs[i];
    VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index
            << " to " << engine_node->name() << ":" << i;
    graph->AddEdge(n, in.index, engine_node, i);
  }

  // Update the inputs of the destination nodes of output edges to point to the
  // engine node.
  for (auto& conn : info.connections) {
    if (conn.is_input_edge) {
      continue;
    }
    Node* output_node = graph->FindNodeId(conn.outside_id);
    int port = conn.outside_port;
    if (!output_node) {
      UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false,
                         conn.outside_node_name, &output_node, &port);
    }
    if (conn.is_control_edge()) {
      VLOG(1) << "Updating control edge from " << engine_node->name() << " to "
              << output_node->name();
      QCHECK_EQ(Graph::kControlSlot, port);
      graph->AddControlEdge(engine_node, output_node);
    } else {
      VLOG(1) << "Updating data edge from " << engine_node->name() << ":"
              << conn.port_number << " to " << output_node->name() << ":"
              << port;
      // Use UpdateEdge() to avoid adding the same edge multiple times.
      TF_CHECK_OK(
          graph->UpdateEdge(engine_node, conn.port_number, output_node, port));
    }
  }
  return Status::OK();
}

int64 GetNextGraphSequenceNumber() {
  static std::atomic<int64> graph_sequence_num;
  return graph_sequence_num++;
}

constexpr char kCastInputTypeAttrName[] = "SrcT";

// Transforms node = cast(x, fp32) where datatype(x) != fp16 to:
//   castToFp16 = cast(x, fp16)
//   node = cast(castToFp16, fp32)
//
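// For example (hypothetical NodeDefs): for an input x of type DT_INT32,
//   node = Cast(x, SrcT=DT_INT32, DstT=DT_FLOAT)
// is rewritten to
//   node_split = Cast(x, SrcT=DT_INT32, DstT=DT_HALF)
//   node       = Cast(node_split:0, SrcT=DT_HALF, DstT=DT_FLOAT)
// where "node_split" stands for the original node name with a "_split" suffix.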
Status MaybeRewriteCastToFp32(GraphDef* graph_def, NodeDef* node_def) {
  if (node_def->op() != "Cast") {
    return Status::OK();
  }

  DataTypeVector input_types;
  DataTypeVector output_types;
  TF_RETURN_IF_ERROR(
      graph_transforms::GetInOutTypes(*node_def, &input_types, &output_types));

  if (input_types.size() != 1 || output_types.size() != 1) {
    return errors::Internal("Bad cast operation");
  }

  if (input_types[0] == DT_HALF || output_types[0] != DT_FLOAT) {
    return Status::OK();
  }

  VLOG(2) << "Rewriting cast to FP32 " << node_def->DebugString();

  NodeDef* castToFp16 = graph_def->add_node();
  for (auto attr_value : node_def->attr()) {
    (*castToFp16->mutable_attr())[attr_value.first] = attr_value.second;
  }
  castToFp16->set_name(node_def->name() + "_split");
  castToFp16->set_op("Cast");
  castToFp16->set_device(node_def->device());
  castToFp16->add_input(node_def->input(0));
  (*castToFp16->mutable_attr())[kCastOutputTypeAttrName].set_type(DT_HALF);

  node_def->set_input(0, castToFp16->name() + ":0");
  (*node_def->mutable_attr())[kCastInputTypeAttrName].set_type(DT_HALF);

  VLOG(2) << castToFp16->DebugString();
  VLOG(2) << node_def->DebugString();

  return Status::OK();
}

}  // namespace

Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
                                      Graph* graph, const string& engine_name,
                                      bool has_int32_input) {
  Graph segment_graph(graph->flib_def());
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                            segment_graph_def, &segment_graph));
  FunctionDefLibrary library;
  auto segment_func = library.add_function();
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      segment_graph, StrCat(engine_name, "_native_segment"), segment_func));
  if (has_int32_input) {
    // Setting this attribute value informs TensorFlow that the int32 _Arg node
    // inputs are on device memory during native segment execution. We only set
    // the attribute value when there is an int32 input because the attribute
    // is not compatible with is_multi_device_function.
    SetAttrValue(
        true,
        &(*segment_func
               ->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]);
  }
  if (VLOG_IS_ON(7)) {
    VLOG(7) << engine_name << " Function_Def ";
    VLOG(7) << segment_func->DebugString();
  }
  VLOG(1) << "Adding funcdef " << segment_func->signature().name()
          << " to graphlib";
  TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(library));
  return Status::OK();
}

std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
                                                 const EngineInfo& engine) {
  int cuda_device_id = -1;
  Allocator* dev_allocator = nullptr;
  if (params.cluster == nullptr || params.cluster->GetDeviceSet() == nullptr ||
      engine.device.empty()) {
    // If the device is not set, use the first found GPU device for the
    // conversion.
    TfDeviceId tf_device_id;
    PlatformDeviceId platform_device_id;
    std::tie(tf_device_id, platform_device_id) = GetFirstValidDeviceId();
    cuda_device_id = platform_device_id.value();
    if (cuda_device_id >= 0) {
      GPUOptions gpu_options;
      // If the TF to Cuda gpu id mapping exists, the device and corresponding
      // allocator must have been initialized already, so the
      // GetGPUAllocator() call won't create a new allocator.
      dev_allocator = GPUProcessState::singleton()->GetGPUAllocator(
          gpu_options, tf_device_id, /*total_bytes=*/1, /*peer_gpu_ids=*/{});
    }
    return std::make_pair(cuda_device_id, dev_allocator);
  }

  // Use the device requested by the engine.
  auto device_set = params.cluster->GetDeviceSet();
  std::vector<Device*> devices;
  DeviceNameUtils::ParsedName parsed_name;
  if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) &&
      parsed_name.has_id) {
    device_set->FindMatchingDevices(parsed_name, &devices);
  }
  if (!devices.empty()) {
    if (devices.size() > 1) {
      string msg = "Found multiple matching devices using name '";
      StrAppend(&msg, engine.device, "': ");
      for (auto d : devices) StrAppend(&msg, d->name(), ", ");
      StrAppend(&msg, ". Will get the allocator from the first one.");
      LOG_WARNING_WITH_PREFIX << msg;
    }
    AllocatorAttributes alloc_attr;
    cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id;
    dev_allocator = devices[0]->GetAllocator(alloc_attr);
    VLOG(1) << "Using allocator " << dev_allocator->Name()
            << " and cuda_device_id " << cuda_device_id;
  } else {
    LOG_WARNING_WITH_PREFIX << "Cluster is set but device '" << engine.device
                            << "' is not found in the cluster";
  }
  return std::make_pair(cuda_device_id, dev_allocator);
}

// Entry function from the optimization pass.
Status ConvertAfterShapes(const ConversionParams& params) {
  // Sanity checks.
  if (params.precision_mode != TrtPrecisionMode::INT8 &&
      params.use_calibration) {
    return errors::InvalidArgument(
        "Calibration with FP32 or FP16 is not supported.");
  }

  // Make a copy of the input_graph_def because grappler doesn't allow changes
  // to the input_graph_def and GraphProperties only accepts GraphDef, but not
  // Graph, as input.
  //
  // If the overhead of copying the input_graph_def becomes a concern, we can
  // avoid the copy by (1) enhancing the GraphProperties representation to
  // allow adding shape properties for newly created graph nodes and (2)
  // rewriting the GraphDef transformation as a Graph transformation.
  GraphDef modified_graph_def = params.grappler_item->graph;
  // When precision_mode is FP16, transform cast(x, fp32) to
  // cast(cast(x, fp16), fp32). This creates a cast(fp16, fp32) node that can
  // be included in the TRTEngineOp as a TensorRT Identity layer for
  // performance:
  //  . Avoid cast(fp32, fp16) in the TRT engine implementation for fp16
  //    precision.
  //  . Changing the input to the TRTEngineOp from fp32 to fp16 may reduce data
  //    movement from the host to the GPU.
  if (params.precision_mode == TrtPrecisionMode::FP16) {
    for (int i = 0; i < modified_graph_def.node_size(); i++) {
      NodeDef* node_def = modified_graph_def.mutable_node(i);
      TF_RETURN_IF_ERROR(MaybeRewriteCastToFp32(&modified_graph_def, node_def));
    }
  }

  // Construct a GrapplerItem using the modified graph_def and the input
  // grappler_item.
  grappler::GrapplerItem grappler_item =
      params.grappler_item->WithGraph(std::move(modified_graph_def));
  const GraphDef& graph_def = grappler_item.graph;

  grappler::GraphProperties static_graph_properties(grappler_item);
  TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));

  // Convert the graphdef to a graph.
  FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library());
  Graph graph(flib);
  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph(GraphConstructorOptions(), graph_def, &graph));

  // Segment the graph into subgraphs that can be converted to TensorRT.
  segment::SegmentOptions segment_options;
  // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
  for (const auto& node : *(params.output_names)) {
    segment_options.exclude_node_list.insert(node);
  }
  segment_options.minimum_segment_size = params.minimum_segment_size;
  segment_options.use_implicit_batch = params.use_implicit_batch;
  if (segment_options.use_implicit_batch)
    segment_options.maximum_batch_size = params.max_batch_size;
  segment_options.allow_dynamic_non_batch_dim =
      AllowDynamicNonBatchDimension(params);

  segment::SegmentVector initial_segments;
  TrtNodeValidator validator(static_graph_properties, params.precision_mode,
                             params.use_calibration, params.use_implicit_batch);
  TF_RETURN_IF_ERROR(segment::SegmentGraph(
      &graph, &static_graph_properties,
      std::bind(&TrtNodeValidator::IsTensorRTCandidate, &validator,
                std::placeholders::_1),
      // Input validation is already done by TrtNodeValidator, so we don't
      // need to check the input edges.
      [](const Edge* edge) { return true; }, OutputEdgeValidator(),
      segment_options, &initial_segments));
  LOG(INFO) << "Number of TensorRT candidate segments: "
            << initial_segments.size();

  // Get the EngineInfo for each segment.
  std::unordered_map<string, Node*> node_map;
  TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
  std::vector<EngineInfo> engine_segments;
  engine_segments.reserve(initial_segments.size());
  std::vector<Node*> reverse_topo_order;
  GetPostOrder(graph, &reverse_topo_order);
  segment::SegmentVector converted_segments;
  converted_segments.reserve(initial_segments.size());
  string engine_name_prefix =
      StrCat("TRTEngineOp_", GetNextGraphSequenceNumber(), "_");
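  // For example, the engine built for the 3rd segment of the 1st converted
  // graph in this process is named "TRTEngineOp_0_2".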
  for (size_t t = 0; t < initial_segments.size(); t++) {
    auto& curr_segment = initial_segments.at(t);
    EngineInfo curr_engine;
    curr_engine.engine_name = StrCat(engine_name_prefix, t);

    bool int8_no_calib = (params.use_calibration == false &&
                          params.precision_mode == TrtPrecisionMode::INT8);
    bool has_qdq = false;
    if (int8_no_calib) {
      has_qdq = absl::c_any_of(reverse_topo_order, IsQuantizeAndDequantizeOp);
    }

    Status status = GetEngineInfo(&graph, static_graph_properties, curr_segment,
                                  node_map, reverse_topo_order, &curr_engine);
    if (!status.ok()) {
      LOG_WARNING_WITH_PREFIX << "Failed to get engine info for segment " << t
                              << ": " << status;
      continue;
    }

    curr_engine.engine_type = GetEngineType(params);
    curr_engine.use_calibration = params.use_calibration;
    // Building cuda engines for INT8 without calibration and without dynamic
    // range info causes TRT failures. Avoid this situation by setting the
    // precision to FP16.
    if (int8_no_calib && !has_qdq) {
      VLOG(1) << "Set engine precision to FP16 due to missing QDQ ops";
      curr_engine.precision_mode = TrtPrecisionMode::FP16;
    } else {
      curr_engine.precision_mode = params.precision_mode;
    }
    curr_engine.maximum_cached_engines = params.max_cached_engines;
    curr_engine.allow_build_at_runtime = params.allow_build_at_runtime;
    if (!curr_engine.max_batch_size.has_value()) {
      curr_engine.max_batch_size = params.max_batch_size;
    }

    status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def,
                                            &graph, curr_engine.engine_name,
                                            curr_engine.has_int32_input);

    if (!status.ok()) {
      LOG_WARNING_WITH_PREFIX
          << "Failed to register the segment graphdef to the library " << t
          << ": " << status;
      continue;
    }

    engine_segments.push_back(std::move(curr_engine));
    converted_segments.push_back(std::move(curr_segment));

    if (VLOG_IS_ON(8)) {
      string fname = engine_segments.back().engine_name;
      StrAppend(&fname, ".pb");
      std::fstream f;
      f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
      f << engine_segments.at(t).segment_graph_def.SerializeAsString();
      f.close();
    }
  }

  // Save the cuda device since we may need to switch to another cuda device to
  // build static engines.
  absl::optional<int> old_cuda_device = absl::nullopt;
  if (!params.is_dyn_op) {
    int cuda_device_id;
    cudaError_t cuda_error = cudaGetDevice(&cuda_device_id);
    if (cuda_error != cudaSuccess) {
      LOG_WARNING_WITH_PREFIX << "Couldn't get current device: "
                              << cudaGetErrorString(cuda_error);
    } else {
      VLOG(1) << "Current cuda device is " << cuda_device_id;
      old_cuda_device = cuda_device_id;
    }
  }

  auto restore_cuda_device = gtl::MakeCleanup([old_cuda_device] {
    if (old_cuda_device.has_value()) {
      cudaSetDevice(old_cuda_device.value());
    }
  });

  std::vector<Node*> engine_nodes;
  engine_nodes.resize(engine_segments.size());
  for (int i = 0; i < engine_segments.size(); ++i) {
    auto& engine = engine_segments.at(i);
    // TODO(b/170762693): implement the heuristic to calculate
    // max_workspace_size_bytes.
    engine.max_workspace_size_bytes = params.max_workspace_size_bytes;
    VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to "
            << engine.engine_name;
    auto status = CreateTRTNode(params, engine_segments, i,
                                params.max_batch_size, &graph, &engine_nodes);

    string msg = StrCat("segment ", i, " consisting of ",
                        converted_segments.at(i).nodes.size(), " nodes by ",
                        engine.engine_name);
    if (status.ok()) {
      LOG(INFO) << "Replaced " << msg << ".";
    } else {
      // Graph is not modified.
      LOG_WARNING_WITH_PREFIX << "Cannot replace " << msg
                              << " reason: " << status.error_message()
                              << " (keeping original segment).";
    }
    if (VLOG_IS_ON(1)) {
      msg = "Segment consists of nodes: ";
      for (const Node* node : converted_segments.at(i).nodes) {
        StrAppend(&msg, node->name(), ", ");
      }
      VLOG(1) << msg;
    }

    // If status is ok, we successfully added the node to the graph and can
    // remove the segment ops. Otherwise the graph is not modified.
    if (status.ok()) {
      for (const Node* node : converted_segments.at(i).nodes) {
        graph.RemoveNode(const_cast<Node*>(node));
      }
    }
  }
  graph.ToGraphDef(params.output_graph_def);
  VLOG(1) << "Returning from conversion";
  return Status::OK();
}

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT