/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"

#include <fstream>
#include <list>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/logger_registry.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
#include "tensorflow/core/protobuf/device_properties.pb.h"  // NOLINT
#include "tensorflow/core/protobuf/rewriter_config.pb.h"  // NOLINT
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/tensorrt/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
namespace convert {

using absl::StrAppend;
using absl::StrCat;
using ::tensorflow::tensorrt::segment::ClusterProperty;
using ::tensorflow::tensorrt::segment::NodePtrCompare;
using ::tensorflow::tensorrt::segment::Segment;

namespace {

Status BuildNodeMap(const Graph& graph,
                    std::unordered_map<string, Node*>* node_map) {
  for (auto* node : graph.op_nodes()) {
    if (!node_map->insert({node->name(), node}).second) {
      return errors::AlreadyExists("Node name is not unique in graph: " +
                                   node->name());
    }
  }
  return Status::OK();
}

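// Returns the engine type to build for the given conversion parameters. A
// static engine is built once at conversion time and serialized into the
// TRTEngineOp node, while a dynamic engine is built lazily at runtime. INT8
// calibration requires running the network first, so it always takes the
// dynamic path.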
EngineInfo::EngineType GetEngineType(const ConversionParams& params) {
  return (params.is_dyn_op || params.use_calibration)
             ? EngineInfo::EngineType::TRTDynamic
             : EngineInfo::EngineType::TRTStatic;
}

// Returns true when use_implicit_batch is false or when we are building a
// dynamic engine, to allow unknown sizes for dimensions other than
// dimension 0.
bool AllowDynamicNonBatchDimension(const ConversionParams& params) {
  return !params.use_implicit_batch ||
         GetEngineType(params) == EngineInfo::EngineType::TRTDynamic;
}

struct EdgePtrCompare {
  bool operator()(const Edge* lhs, const Edge* rhs) const {
    return lhs->id() < rhs->id();
  }
};

// TODO(laigd): instead of deciding the device here, the converter should
// accept a device name as one of the conversion parameters so users can
// control which device they want to run the conversion on.
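// Scans TF GPU ids 0..99 and returns the first one that has a valid mapping
// to a platform (CUDA) device id, or (-1, -1) if none is found. The bound of
// 100 is an arbitrary upper limit on the number of visible GPUs.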
std::pair<TfGpuId, PlatformGpuId> GetFirstValidDeviceId() {
  for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
    TfGpuId tf_gpu_id(tf_gpu_id_value);
    PlatformGpuId platform_gpu_id;
    Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
    if (s.ok()) {
      VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
              << platform_gpu_id.value();
      return std::make_pair(tf_gpu_id, platform_gpu_id);
    }
  }
  LOG(ERROR) << "Could not find any TF GPUs";
  return std::make_pair(TfGpuId(-1), PlatformGpuId(-1));
}

// Returns false for const nodes (we intend to drop control edges from those).
bool ShallKeepControlEdgeFrom(const Node* input_node) {
  if (!input_node) {
    LOG(ERROR) << "Node pointer is null, this should not happen";
    return false;
  }
  return input_node->type_string() != "Const";
}

// Function to get subsegment information structure.
Status GetEngineInfo(const Graph* g,
                     const grappler::GraphProperties& graph_properties,
                     const Segment& segment,
                     const std::unordered_map<string, Node*>& node_map,
                     const std::vector<Node*>& reverse_topo_order,
                     EngineInfo* info) {
  std::vector<const Node*> subgraph_nodes;  // Topologically sorted nodes.
  std::set<const Node*> added_const_nodes;  // Used to prevent double insertion.

  const ClusterProperty& segment_property = segment.property;
  const std::set<const Node*, NodePtrCompare>& segment_nodes = segment.nodes;

  // The device assignment accumulated from the compatible device assignments
  // for the nodes in the segment.
  const DeviceNameUtils::ParsedName segment_device =
      segment_property.DeviceName();
  info->max_batch_size = segment_property.BatchSize().GetOptionalMaxBatchSize();

  // Map from src_node_name:port to the unique port number of the TRT op,
  // where src_node_name is the name of the source node of the input/output
  // edge. There must not be any duplicates, since source nodes of
  // input/output edges must be in different splits of the graph.
  // TODO(aaroey): consider using node id and port instead.
  // TODO(aaroey): use topo order instead of reversing reverse topo order.
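  // For example, if two nodes in the segment both read conv:0 from outside
  // the segment, both data edges map to the same engine input port, while a
  // previously unseen source name:port pair is assigned the next unused port
  // number.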
  std::unordered_map<string, int> input_to_engine_port, output_to_engine_port;
  for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend();
       ++it) {
    const Node* node = *it;
    if (segment_nodes.count(node) == 0) continue;
    subgraph_nodes.push_back(node);

    const int node_id = node->id();
    const string& node_name = node->name();

    // Create input connections. Sort the edges first to make the order
    // deterministic, since in_edges is a set of pointers.
    std::vector<const Edge*> in_edges(node->in_edges().begin(),
                                      node->in_edges().end());
    std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare());
    for (const auto edge : in_edges) {
      auto input_node = edge->src();
      if (input_node->IsSource() || segment_nodes.count(input_node)) {
        continue;
      }
      if (edge->IsControlEdge()) {
        if (ShallKeepControlEdgeFrom(input_node)) {
          // Non-Const control input.
          info->connections.emplace_back(input_node->name(), input_node->id(),
                                         node_name, node_id,
                                         /*input_edge=*/true);
        }
      } else if (input_node->type_string() == "Const") {
        // Add constant data input nodes into the segment graphdef (thus also
        // in the engine). We don't care if it has other output edges going
        // into other engines or TF nodes. Since we add it only to the segment
        // graphdef, not the segment itself, it won't be removed from the
        // graph. If it doesn't have any edges, TF will prune it out.
        //
        // Note that the segmenter already ensures that the constant data
        // input is valid and supported by the engine.
        if (!added_const_nodes.insert(input_node).second) {
          // Already added before.
          continue;
        }
        VLOG(1) << "Adding const node " << input_node->name();
      } else {
        // Non-const data input.
        int port = Graph::kControlSlot - 1;
        // Use the source non-segment node name/port as key.
        const string s = StrCat(input_node->name(), ":", edge->src_output());
        VLOG(1) << "Input edge = " << s;
        if (input_to_engine_port.count(s)) {
          port = input_to_engine_port.at(s);
        } else {
          port = input_to_engine_port.size();
          input_to_engine_port.insert({s, port});
        }
        info->connections.emplace_back(
            input_node->name(), input_node->id(), edge->src_output(),
            node_name, node_id, edge->dst_input(), /*input_edge=*/true, port);
      }
    }
    // Create output connections. Sort the edges first to make the order
    // deterministic, since out_edges is a set of pointers.
    std::vector<const Edge*> out_edges(node->out_edges().begin(),
                                       node->out_edges().end());
    std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare());
    for (const auto edge : out_edges) {
      auto output_node = edge->dst();
      if (output_node->IsSink() || segment_nodes.count(output_node)) {
        continue;
      }
      if (edge->IsControlEdge()) {
        // Control output.
        if (ShallKeepControlEdgeFrom(node)) {
          info->connections.emplace_back(output_node->name(),
                                         output_node->id(), node_name,
                                         node_id, /*input_edge=*/false);
        }
      } else {
        // Data output.
        int port = Graph::kControlSlot - 1;
        // Use the source segment node name/port as key.
        const string s = StrCat(node_name, ":", edge->src_output());
        VLOG(1) << "Output edge = " << s;
        if (output_to_engine_port.count(s)) {
          port = output_to_engine_port.at(s);
        } else {
          port = output_to_engine_port.size();
          output_to_engine_port.insert({s, port});
        }
        info->connections.emplace_back(
            output_node->name(), output_node->id(), edge->dst_input(),
            node_name, node_id, edge->src_output(), /*input_edge=*/false,
            port);
      }
    }
  }  // For each segment node in topological order.

  // Construct the const nodes first.
  subgraph_nodes.insert(subgraph_nodes.begin(), added_const_nodes.begin(),
                        added_const_nodes.end());
  string scope_name;
  TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef(
      g, graph_properties, subgraph_nodes, &info->connections,
      &info->segment_graph_def, &scope_name));
  info->engine_name = StrCat(scope_name, info->engine_name);
  VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name
          << "' to a GraphDef";
  if (segment_device.has_type) {
    // If the accumulated device assignment for the segment has a device type,
    // the segmenter guarantees the device type is GPU. Use the device
    // assignment in this case.
    if (segment_device.type != "GPU") {
      return errors::Internal(
          "segment device is not GPU: ",
          DeviceNameUtils::ParsedNameToString(segment_device));
    }
    info->device = DeviceNameUtils::ParsedNameToString(segment_device);
  } else {
    TfGpuId tf_gpu_id;
    PlatformGpuId platform_gpu_id;
    std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId();
    if (tf_gpu_id.value() >= 0) {
      DeviceNameUtils::ParsedName parsed_name;
      parsed_name.type = "GPU";
      parsed_name.has_type = true;
      parsed_name.id = tf_gpu_id.value();
      parsed_name.has_id = true;
      info->device = DeviceNameUtils::ParsedNameToString(parsed_name);
    } else {
      VLOG(1) << "No device is assigned to the segment. A device will be "
                 "assigned during graph execution (inference).";
    }
  }
  return Status::OK();
}

// Helper function to update an edge connection to point to an engine node
// instead of a removed node. If an outside node is gone, it must have been
// absorbed into an engine node; find that engine node.
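// The absorbing engine is found by scanning every other engine's connections
// for one whose inside endpoint matches the given (node_name, *port) in the
// opposite direction; on a match, *node and *port are redirected to that
// engine node and the corresponding engine port.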
void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
                        const size_t my_engine_id,
                        const std::vector<Node*>& engine_nodes,
                        const bool is_input_edge, const string& node_name,
                        Node** node, int* port) {
  for (size_t t = 0; t < infos.size(); ++t) {
    if (t == my_engine_id) {
      continue;
    }
    const auto& info = infos.at(t);
    for (const auto& eng_conn : info.connections) {
      // If the connection being updated is an input connection, the source of
      // the connection must be an output connection of another engine. And
      // vice versa.
      if (is_input_edge == eng_conn.is_input_edge) continue;
      if (eng_conn.inside_node_name == node_name &&
          eng_conn.inside_port == *port) {
        *node = CHECK_NOTNULL(engine_nodes[t]);
        QCHECK_EQ(info.engine_name, (**node).name())
            << "Engine name mismatch: " << info.engine_name << " vs "
            << (**node).name();
        *port = eng_conn.port_number;
        return;
      }
    }
  }
  LOG(FATAL) << "Node " << node_name << " not found in any engine.";
}

// Function to insert a TRT engine node into the graph.
// Create engine nodes in the following way:
// 1. Each invocation of CreateTRTNode creates an engine node for infos[pos].
// 2. When an engine node is created, add it into the graph with the necessary
//    re-wiring.
//    2.1. If the outside connected node still exists, connect the engine
//         node to it.
//    2.2. If the outside connected node is gone, it must have been absorbed
//         into another engine node (which was processed before the current
//         one). Connect to that pre-existing engine node instead.
// 3. In this way, we ensure the graph is topologically sortable after each
//    invocation of CreateTRTNode().
Status CreateTRTNode(const ConversionParams& params,
                     const std::vector<EngineInfo>& infos, int pos,
                     int default_max_batch_size, Graph* graph,
                     std::vector<Node*>* engine_nodes) {
  const auto& info = infos.at(pos);
  std::vector<tensorflow::TensorShapeProto> input_shape_protos;
  std::vector<PartialTensorShape> input_shapes;
  std::vector<NodeDefBuilder::NodeOut> inputs;
  std::vector<Node*> input_nodes;
  std::vector<Node*> control_input_nodes;
  std::unordered_set<string> control_input_names;
  std::vector<DataType> out_types;

  VLOG(1) << "Processing " << info.engine_name;
  // Collect the info needed for creating the engine node in the graph.
  for (const auto& conn : info.connections) {
    // Control edges.
    if (conn.is_control_edge()) {
      // Skip control outputs for now. Control output info is not needed for
      // node creation and will be processed later.
      if (!conn.is_input_edge) continue;

      // Rewire the control input if it's not found in the original graph.
      Node* input_node = graph->FindNodeId(conn.outside_id);
      int port = Graph::kControlSlot;
      if (!input_node) {
        UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
                           conn.outside_node_name, &input_node, &port);
        QCHECK_EQ(Graph::kControlSlot, port);
      }
      if (!control_input_names.insert(input_node->name()).second) {
        continue;
      }
      control_input_nodes.push_back(input_node);
      VLOG(1) << "Engine Control Input " << input_node->name() << " -> "
              << info.engine_name;
    } else {
      // Data edges.
      if (!conn.is_input_edge) {
        // Set the data types of the output edge.
        if (out_types.size() <= conn.port_number) {
          out_types.resize(conn.port_number + 1);
        }
        out_types.at(conn.port_number) = conn.connection_type;
      } else {
        // Set the shapes and data types of the input edge.
        if (input_shapes.size() <= conn.port_number) {
          input_shape_protos.resize(conn.port_number + 1);
          input_shapes.resize(conn.port_number + 1);
        }
        conn.outside_shape.AsProto(&input_shape_protos.at(conn.port_number));
        input_shapes.at(conn.port_number) = conn.outside_shape;
        // The shape must be fully defined (excluding the batch dimension) for
        // static mode.
        if (params.use_implicit_batch &&
            info.engine_type == EngineInfo::EngineType::TRTStatic) {
          for (int i = 1; i < conn.outside_shape.dims(); i++) {
            if (conn.outside_shape.dim_size(i) <= 0) {
              return errors::Internal(
                  "Not fully defined input shape when in static mode which "
                  "should have been excluded by the segmenter. ");
            }
          }
        }

        // Rewire the data input if it's not found in the original graph.
        Node* input_node = graph->FindNodeId(conn.outside_id);
        int port = conn.outside_port;
        if (!input_node) {
          UpdateToEngineNode(infos, pos, *engine_nodes,
                             /*is_input_edge=*/true, conn.outside_node_name,
                             &input_node, &port);
        }
        if (std::find_if(
                std::begin(inputs), std::end(inputs),
                [input_node, &port](const NodeDefBuilder::NodeOut& inp) {
                  return inp.node == input_node->name() && inp.index == port;
                }) == std::end(inputs)) {
          inputs.emplace_back(input_node->name(), port, conn.connection_type);
          input_nodes.push_back(CHECK_NOTNULL(input_node));
          VLOG(1) << "Engine Input " << input_node->name() << ":" << port
                  << " -> " << info.engine_name << ":" << inputs.size() - 1;
        }
      }
    }
  }
  // We don't support segments with no inputs. Fall back to native TF here to
  // avoid a crash later. Constant folding should've folded the ops that make
  // up these segments.
  if (inputs.empty()) {
    return errors::Internal(
        "Segment has no inputs (possible constfold failure)");
  }

  const bool calibrate_int8 =
      (info.precision_mode == TrtPrecisionMode::INT8 && info.use_calibration);
  // Build the engine and get its serialized representation.
  string segment_string;

  int max_batch_size = info.max_batch_size.has_value()
                           ? info.max_batch_size.value()
                           : default_max_batch_size;

  if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
    std::pair<int, Allocator*> device_allocator =
        GetDeviceAndAllocator(params, info);
    int cuda_device_id = 0;
    std::unique_ptr<TRTBaseAllocator> trt_allocator;
    if (device_allocator.first >= 0) {
      cuda_device_id = device_allocator.first;
      trt_allocator.reset(new TRTDeviceAllocator(device_allocator.second));
    } else {
      // trt_allocator stays nullptr, so cudaMalloc will be used.
      LOG_WARNING_WITH_PREFIX << "Can't identify the cuda device. Running on "
                                 "device 0 and using cudaMalloc as the "
                                 "allocator";
    }
    cudaSetDevice(cuda_device_id);

    auto trt_logger = GetLoggerRegistry()->LookUp(params.trt_logger_name);

    // Create static engines with precision_mode fp32/fp16.
    TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
    TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
        info.segment_graph_def,
        calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode,
        max_batch_size, info.max_workspace_size_bytes, input_shapes,
        trt_logger, trt_allocator.get(), /*calibrator=*/nullptr, &engine,
        info.use_calibration, params.use_implicit_batch,
        /*convert_successfully=*/nullptr,
        /*profile=*/nullptr, info.engine_name));
    TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
    segment_string = string(static_cast<const char*>(engine_data->data()),
                            engine_data->size());
  }

  string prec_string;
  TF_RETURN_IF_ERROR(TrtPrecisionModeToName(info.precision_mode, &prec_string));
  NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp");
  if (!info.device.empty()) node_builder.Device(info.device);
  if (VLOG_IS_ON(1)) {
    string ins = StrCat(info.engine_name, " inputs= ");
    for (const auto& ii : inputs) {
      StrAppend(&ins, ii.node, ":", ii.index, " ");
    }
    VLOG(1) << ins;
  }
  node_builder.Input(inputs);
  for (const string& c : control_input_names) {
    node_builder.ControlInput(c);
  }

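  // Attach everything the runtime op needs as attributes: the (possibly
  // empty) serialized engine, the name of the native segment function
  // registered for fallback execution, and the batch/workspace/precision
  // parameters used when (re)building engines at runtime.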
  NodeDef trt_node;
  NameAttrList function;
  function.set_name(StrCat(info.engine_name, "_native_segment"));
  Status status =
      node_builder.Attr("input_shapes", input_shape_protos)
          .Attr("static_engine",
                info.engine_type == EngineInfo::EngineType::TRTStatic)
          .Attr("segment_func", function)
          .Attr("serialized_segment", segment_string)
          .Attr("calibration_data", "")
          .Attr("max_cached_engines_count", info.maximum_cached_engines)
          .Attr("workspace_size_bytes", info.max_workspace_size_bytes)
          .Attr("max_batch_size", max_batch_size)
          .Attr("precision_mode", prec_string)
          .Attr("use_calibration", info.use_calibration)
          .Attr("_use_implicit_batch", params.use_implicit_batch)
          .Attr("_allow_build_at_runtime", info.allow_build_at_runtime)
          .Attr("OutT", out_types)
          .Finalize(&trt_node);
  if (!status.ok()) {
    LOG(ERROR) << "Node construction failed with " << status;
    return status;
  }
  VLOG(1) << "Adding TRTEngine " << info.engine_name << " to graph";

  // Up until this point, the graph has not been modified. If we return
  // !status.ok() from here on, this segment will be skipped.
  // TODO(aaroey): let it return a proper error status for the following logic
  // instead of check-failing.
  Node* engine_node = graph->AddNode(trt_node, &status);
  (*engine_nodes)[pos] = engine_node;
  if (!status.ok()) {
    LOG(ERROR) << "Adding node failed " << status;
    return status;
  }
  // Add control input and input edges to the engine node.
  for (const auto in : control_input_nodes) {
    VLOG(1) << "Connecting control edge from " << in->name() << " to "
            << engine_node->name();
    graph->AddControlEdge(in, engine_node);
  }
  VLOG(1) << "input_nodes size = " << input_nodes.size();
  for (int i = 0; i < input_nodes.size(); ++i) {
    Node* n = CHECK_NOTNULL(input_nodes[i]);
    const auto& in = inputs[i];
    VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index
            << " to " << engine_node->name() << ":" << i;
    graph->AddEdge(n, in.index, engine_node, i);
  }

  // Update the inputs of the destination nodes of output edges to point to
  // the engine node.
  for (auto& conn : info.connections) {
    if (conn.is_input_edge) {
      continue;
    }
    Node* output_node = graph->FindNodeId(conn.outside_id);
    int port = conn.outside_port;
    if (!output_node) {
      UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false,
                         conn.outside_node_name, &output_node, &port);
    }
    if (conn.is_control_edge()) {
      VLOG(1) << "Updating control edge from " << engine_node->name() << " to "
              << output_node->name();
      QCHECK_EQ(Graph::kControlSlot, port);
      graph->AddControlEdge(engine_node, output_node);
    } else {
      VLOG(1) << "Updating data edge from " << engine_node->name() << ":"
              << conn.port_number << " to " << output_node->name() << ":"
              << port;
      // Use UpdateEdge() to avoid adding the same edge multiple times.
      TF_CHECK_OK(
          graph->UpdateEdge(engine_node, conn.port_number, output_node, port));
    }
  }
  return Status::OK();
}

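// Returns a process-wide monotonically increasing number. It is used to
// prefix engine names so that engines created by repeated conversions in the
// same process do not collide.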
int64 GetNextGraphSequenceNumber() {
  static std::atomic<int64> graph_sequence_num;
  return graph_sequence_num++;
}

constexpr char kCastInputTypeAttrName[] = "SrcT";

// Transforms node = cast(x, fp32) where datatype(x) != fp16 to:
//   castToFp16 = cast(x, fp16)
//   node = cast(castToFp16, fp32)
//
Status MaybeRewriteCastToFp32(GraphDef* graph_def, NodeDef* node_def) {
  if (node_def->op() != "Cast") {
    return Status::OK();
  }

  DataTypeVector input_types;
  DataTypeVector output_types;
  TF_RETURN_IF_ERROR(
      graph_transforms::GetInOutTypes(*node_def, &input_types, &output_types));

  if (input_types.size() != 1 || output_types.size() != 1) {
    return errors::Internal("Bad cast operation");
  }

  if (input_types[0] == DT_HALF || output_types[0] != DT_FLOAT) {
    return Status::OK();
  }

  VLOG(2) << "Rewriting cast to FP32 " << node_def->DebugString();

  NodeDef* castToFp16 = graph_def->add_node();
  for (auto attr_value : node_def->attr()) {
    (*castToFp16->mutable_attr())[attr_value.first] = attr_value.second;
  }
  castToFp16->set_name(node_def->name() + "_split");
  castToFp16->set_op("Cast");
  castToFp16->set_device(node_def->device());
  castToFp16->add_input(node_def->input(0));
  (*castToFp16->mutable_attr())[kCastOutputTypeAttrName].set_type(DT_HALF);

  node_def->set_input(0, castToFp16->name() + ":0");
  (*node_def->mutable_attr())[kCastInputTypeAttrName].set_type(DT_HALF);

  VLOG(2) << castToFp16->DebugString();
  VLOG(2) << node_def->DebugString();

  return Status::OK();
}

}  // namespace

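// Registers the segment's GraphDef as a function named
// "<engine_name>_native_segment" in the graph's function library, so the
// TRTEngineOp can execute the original TF ops when falling back from
// TensorRT.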
Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
                                      Graph* graph,
                                      const string& engine_name) {
  Graph segment_graph(graph->flib_def());
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                            segment_graph_def,
                                            &segment_graph));
  FunctionDefLibrary library;
  auto segment_func = library.add_function();
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      segment_graph, StrCat(engine_name, "_native_segment"), segment_func));
  if (VLOG_IS_ON(7)) {
    VLOG(7) << engine_name << " Function_Def ";
    VLOG(7) << segment_func->DebugString();
  }
  VLOG(1) << "Adding funcdef " << segment_func->signature().name()
          << " to graphlib";
  TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(library));
  return Status::OK();
}

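// Returns the CUDA device id and the allocator to use when building a static
// engine: the device assigned to the engine if it can be matched in the
// cluster's device set, otherwise the first valid GPU.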
std::pair<int, Allocator*> GetDeviceAndAllocator(
    const ConversionParams& params, const EngineInfo& engine) {
  int cuda_device_id = -1;
  Allocator* dev_allocator = nullptr;
  if (params.cluster == nullptr || params.cluster->GetDeviceSet() == nullptr ||
      engine.device.empty()) {
    // If the device is not set, use the first found GPU device for the
    // conversion.
    TfGpuId tf_gpu_id;
    PlatformGpuId platform_gpu_id;
    std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId();
    cuda_device_id = platform_gpu_id.value();
    if (cuda_device_id >= 0) {
      GPUOptions gpu_options;
      // If the TF to CUDA gpu id mapping exists, the device and corresponding
      // allocator must have been initialized already, so the
      // GetGPUAllocator() call won't create a new allocator.
      dev_allocator = GPUProcessState::singleton()->GetGPUAllocator(
          gpu_options, tf_gpu_id, /*total_bytes=*/1, /*peer_gpu_ids=*/{});
    }
    return std::make_pair(cuda_device_id, dev_allocator);
  }

  // Use the device requested by the engine.
  auto device_set = params.cluster->GetDeviceSet();
  std::vector<Device*> devices;
  DeviceNameUtils::ParsedName parsed_name;
  if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) &&
      parsed_name.has_id) {
    device_set->FindMatchingDevices(parsed_name, &devices);
  }
  if (!devices.empty()) {
    if (devices.size() > 1) {
      string msg = "Found multiple matching devices using name '";
      StrAppend(&msg, engine.device, "': ");
      for (auto d : devices) StrAppend(&msg, d->name(), ", ");
      StrAppend(&msg, ". Will get the allocator from the first one.");
      LOG_WARNING_WITH_PREFIX << msg;
    }
    AllocatorAttributes alloc_attr;
    cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id;
    dev_allocator = devices[0]->GetAllocator(alloc_attr);
    VLOG(1) << "Using allocator " << dev_allocator->Name()
            << " and cuda_device_id " << cuda_device_id;
  } else {
    LOG_WARNING_WITH_PREFIX << "Cluster is set but device '" << engine.device
                            << "' is not found in the cluster";
  }
  return std::make_pair(cuda_device_id, dev_allocator);
}

// Entry function from the optimization pass.
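//
// A minimal caller-side sketch (illustrative values; only the
// ConversionParams fields consumed in this file are shown, and the
// surrounding setup is assumed):
//
//   ConversionParams params;
//   params.grappler_item = &item;           // frozen GrapplerItem with graph
//   params.output_names = &output_names;    // fetch node names
//   params.precision_mode = TrtPrecisionMode::FP16;
//   params.max_batch_size = 8;
//   params.max_workspace_size_bytes = 1 << 30;
//   params.minimum_segment_size = 3;
//   params.is_dyn_op = true;                // build engines at runtime
//   params.output_graph_def = &output_graph_def;
//   TF_RETURN_IF_ERROR(ConvertAfterShapes(params));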
Status ConvertAfterShapes(const ConversionParams& params) {
  // Sanity checks.
  if (params.precision_mode != TrtPrecisionMode::INT8 &&
      params.use_calibration) {
    return errors::InvalidArgument(
        "Calibration with FP32 or FP16 is not supported.");
  }

  // Make a copy of the input_graph_def because grappler doesn't allow changes
  // to the input_graph_def and GraphProperties only accepts GraphDef, but not
  // Graph, as inputs.
  //
  // If the overhead of copying the input_graph_def becomes a concern, we can
  // avoid the copy by (1) enhancing the GraphProperties representation to
  // allow adding shape properties for newly created graph nodes and
  // (2) rewriting the GraphDef transformation as a Graph transformation.
  GraphDef modified_graph_def = params.grappler_item->graph;
  // When precision_mode is FP16, transform cast(x, fp32) to
  // cast(cast(x, fp16), fp32). This creates a cast(fp16, fp32) that can be
  // included in the TRTEngineOp as a TensorRT Identity layer for performance:
  //  . Avoid cast(fp32, fp16) in the TRT engine implementation for fp16
  //    precision.
  //  . Changing the input to the TRTEngineOp from fp32 to fp16 may reduce the
  //    data moved from the host to the GPU.
  if (params.precision_mode == TrtPrecisionMode::FP16) {
    for (int i = 0; i < modified_graph_def.node_size(); i++) {
      NodeDef* node_def = modified_graph_def.mutable_node(i);
      TF_RETURN_IF_ERROR(MaybeRewriteCastToFp32(&modified_graph_def,
                                                node_def));
    }
  }

  // Construct a GrapplerItem using the modified graph_def and the input
  // grappler_item.
  grappler::GrapplerItem grappler_item =
      params.grappler_item->WithGraph(std::move(modified_graph_def));
  const GraphDef& graph_def = grappler_item.graph;

  grappler::GraphProperties static_graph_properties(grappler_item);
  TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));

  // Convert the GraphDef to a Graph.
  FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library());
  Graph graph(flib);
  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph(GraphConstructorOptions(), graph_def, &graph));

  // Segment the graph into subgraphs that can be converted to TensorRT.
  segment::SegmentOptions segment_options;
  // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
  for (const auto& node : *(params.output_names)) {
    segment_options.exclude_node_list.insert(node);
  }
  segment_options.minimum_segment_size = params.minimum_segment_size;
  segment_options.use_implicit_batch = params.use_implicit_batch;
  if (segment_options.use_implicit_batch)
    segment_options.maximum_batch_size = params.max_batch_size;
  segment_options.allow_dynamic_non_batch_dim =
      AllowDynamicNonBatchDimension(params);

  segment::SegmentVector initial_segments;
  TrtNodeValidator validator(static_graph_properties, params.precision_mode,
                             params.use_calibration,
                             params.use_implicit_batch);
  TF_RETURN_IF_ERROR(segment::SegmentGraph(
      &graph, &static_graph_properties,
      std::bind(&TrtNodeValidator::IsTensorRTCandidate, &validator,
                std::placeholders::_1),
      // Input validation is already done by TrtNodeValidator, so we don't
      // need to check the input edges.
      [](const Edge* edge) { return true; }, OutputEdgeValidator(),
      segment_options, &initial_segments));
  LOG(INFO) << "Number of TensorRT candidate segments: "
            << initial_segments.size();

  // Get the EngineInfo for each segment.
  std::unordered_map<string, Node*> node_map;
  TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
  std::vector<EngineInfo> engine_segments;
  engine_segments.reserve(initial_segments.size());
  std::vector<Node*> reverse_topo_order;
  GetPostOrder(graph, &reverse_topo_order);
  segment::SegmentVector converted_segments;
  converted_segments.reserve(initial_segments.size());
  string engine_name_prefix =
      StrCat("TRTEngineOp_", GetNextGraphSequenceNumber(), "_");
  for (size_t t = 0; t < initial_segments.size(); t++) {
    auto& curr_segment = initial_segments.at(t);
    EngineInfo curr_engine;
    curr_engine.engine_name = StrCat(engine_name_prefix, t);
    Status status =
        GetEngineInfo(&graph, static_graph_properties, curr_segment, node_map,
                      reverse_topo_order, &curr_engine);
    if (!status.ok()) {
      LOG_WARNING_WITH_PREFIX << "Failed to get engine info for segment " << t
                              << ": " << status;
      continue;
    }
    curr_engine.precision_mode = params.precision_mode;
    curr_engine.engine_type = GetEngineType(params);
    curr_engine.use_calibration = params.use_calibration;
    curr_engine.maximum_cached_engines = params.max_cached_engines;
    curr_engine.allow_build_at_runtime = params.allow_build_at_runtime;
    if (!curr_engine.max_batch_size.has_value()) {
      curr_engine.max_batch_size = params.max_batch_size;
    }

    status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def,
                                            &graph, curr_engine.engine_name);

    if (!status.ok()) {
      LOG_WARNING_WITH_PREFIX
          << "Failed to register segment graphdef to the library " << t
          << ": " << status;
      continue;
    }

    engine_segments.push_back(std::move(curr_engine));
    converted_segments.push_back(std::move(curr_segment));

    if (VLOG_IS_ON(8)) {
      string fname = engine_segments.back().engine_name;
      StrAppend(&fname, ".pb");
      std::fstream f;
      f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
      f << engine_segments.at(t).segment_graph_def.SerializeAsString();
      f.close();
    }
  }

  // Save the current CUDA device, since we may need to switch to another CUDA
  // device to build static engines.
  absl::optional<int> old_cuda_device = absl::nullopt;
  if (!params.is_dyn_op) {
    int cuda_device_id;
    cudaError_t cuda_error = cudaGetDevice(&cuda_device_id);
    if (cuda_error != cudaSuccess) {
      LOG_WARNING_WITH_PREFIX << "Couldn't get current device: "
                              << cudaGetErrorString(cuda_error);
    } else {
      VLOG(1) << "Current cuda device is " << cuda_device_id;
      old_cuda_device = cuda_device_id;
    }
  }

  auto restore_cuda_device = gtl::MakeCleanup([old_cuda_device] {
    if (old_cuda_device.has_value()) {
      cudaSetDevice(old_cuda_device.value());
    }
  });

  std::vector<Node*> engine_nodes;
  engine_nodes.resize(engine_segments.size());
  for (int i = 0; i < engine_segments.size(); ++i) {
    auto& engine = engine_segments.at(i);
    // TODO(b/170762693): implement the heuristic to calculate
    // max_workspace_size_bytes.
    engine.max_workspace_size_bytes = params.max_workspace_size_bytes;
    VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to "
            << engine.engine_name;
    auto status = CreateTRTNode(params, engine_segments, i,
                                params.max_batch_size, &graph, &engine_nodes);

    string msg = StrCat("segment ", i, " consisting of ",
                        converted_segments.at(i).nodes.size(), " nodes by ",
                        engine.engine_name);
    if (status.ok()) {
      LOG(INFO) << "Replaced " << msg << ".";
    } else {
      // The graph is not modified.
      LOG_WARNING_WITH_PREFIX << "Cannot replace " << msg
                              << " reason: " << status.error_message()
                              << " (keeping original segment).";
    }
    if (VLOG_IS_ON(1)) {
      msg = "Segment consists of nodes: ";
      for (const Node* node : converted_segments.at(i).nodes) {
        StrAppend(&msg, node->name(), ", ");
      }
      VLOG(1) << msg;
    }

    // If status is ok, we successfully added the node to the graph and can
    // remove the segment ops. Otherwise the graph is not modified.
    if (status.ok()) {
      for (const Node* node : converted_segments.at(i).nodes) {
        graph.RemoveNode(const_cast<Node*>(node));
      }
    }
  }
  graph.ToGraphDef(params.output_graph_def);
  VLOG(1) << "Returning from conversion";
  return Status::OK();
}

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT