# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper conversion to trt_graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six as _six
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver


def _to_bytes(s):
  """Encode s if it is a sequence of chars."""
  if isinstance(s, _six.text_type):
    return s.encode("utf-8", errors="surrogateescape")
  return s


def _to_string(s):
  """Decode s if it is a sequence of bytes."""
  if isinstance(s, _six.binary_type):
    return s.decode("utf-8")
  return s


class GraphConverter(object):
  """Base class for offline converters to optimize SavedModels/GraphDefs.

  A `GraphConverter` object encapsulates the environment to convert
  (optimize) a TensorFlow SavedModel or GraphDef.

  To create a custom GraphConverter:

  ```python
  class MyGraphConverter(GraphConverter):
    ...

    def get_rewriter_config(self, rewriter_config_template=None):
      my_rewriter_config = ...
      return my_rewriter_config
  ```

  Then to run the conversion without quantization calibration:

  ```python
  my_converter = MyGraphConverter(input_saved_model_dir="my_dir")
  converted_graph_def = my_converter.convert()
  my_converter.save(output_saved_model_dir)  # Optional
  ```

  To run the conversion with quantization calibration:

  ```python
  my_converter = MyGraphConverter(input_saved_model_dir="my_dir")
  my_converter.convert()

  # Run calibration 10 times.
  converted_graph_def = my_converter.calibrate(
      fetch_names=['output:0'],
      num_runs=10,
      feed_dict_fn=lambda: {'input:0': my_next_data()})

  my_converter.save(output_saved_model_dir)  # Optional
  ```
  """

  # TODO(laigd): clean up the parameters.
  def __init__(self,
               input_saved_model_dir=None,
               input_saved_model_tags=None,
               input_saved_model_signature_key=None,
               input_graph_def=None,
               nodes_blacklist=None,
               session_config=None):
    """Initialize the converter.

    Args:
      input_saved_model_dir: the directory to load the SavedModel which
        contains the input graph to transform. Used only when input_graph_def
        is None.
      input_saved_model_tags: list of tags to load the SavedModel.
      input_saved_model_signature_key: the key of the signature to optimize
        the graph for.
      input_graph_def: a GraphDef object containing a model to be transformed.
        If set to None, the graph will be read from the SavedModel loaded from
        input_saved_model_dir.
      nodes_blacklist: list of node names to prevent the converter from
        touching. Only used when input_graph_def is not None.
      session_config: the ConfigProto used to create a Session. It's also used
        as a template to create a RewriterConfig for conversion. If not
        specified, a default ConfigProto will be used.

    Raises:
      ValueError: if the combination of the parameters is invalid.
    """
    if context.executing_eagerly():
      if input_graph_def or not input_saved_model_dir:
        raise ValueError(
            "TF 2.0 only supports conversion of SavedModel, please specify "
            "input_saved_model_dir as input.")
    else:
      if input_graph_def and input_saved_model_dir:
        raise ValueError(
            "Can only specify one of input_graph_def and "
            "input_saved_model_dir")
      if not input_graph_def and not input_saved_model_dir:
        raise ValueError("Must specify one of input_graph_def and "
                         "input_saved_model_dir")

    self._input_graph_def = input_graph_def
    self._nodes_blacklist = nodes_blacklist

    self._input_saved_model_dir = input_saved_model_dir
    self._converted = False
    self._grappler_meta_graph_def = None

    self._input_saved_model_tags = (
        input_saved_model_tags or [tag_constants.SERVING])
    self._input_saved_model_signature_key = (
        input_saved_model_signature_key or
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
    self._session_config = session_config or config_pb2.ConfigProto()

    # For calibration usage.
    self._calibration_graph = None
    self._calibration_sess = None
    self._calibration_data_collected = False

  def get_rewriter_config(self, rewriter_config_template=None):
    """Returns a RewriterConfig proto for TRT transformation.

    Args:
      rewriter_config_template: a template RewriterConfig proto used to create
        a RewriterConfig for the conversion. The implementation should not
        modify the template. If None, it will use a default one.

    Returns:
      A RewriterConfig proto which will be used to run the conversion using
      Grappler.
    """
    raise NotImplementedError("get_rewriter_config")
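
  # Illustrative sketch (not part of the API; the optimizer name below is
  # hypothetical): a subclass override of get_rewriter_config() might copy the
  # template when one is given, otherwise build a minimal config, then append
  # its own custom Grappler pass:
  #
  #   def get_rewriter_config(self, rewriter_config_template=None):
  #     rewriter_config = rewriter_config_pb2.RewriterConfig()
  #     if rewriter_config_template is not None:
  #       rewriter_config.CopyFrom(rewriter_config_template)
  #     else:
  #       rewriter_config.optimizers.extend(["constfold"])
  #       rewriter_config.meta_optimizer_iterations = (
  #           rewriter_config_pb2.RewriterConfig.ONE)
  #     custom_optimizer = rewriter_config.custom_optimizers.add()
  #     custom_optimizer.name = "MyOptimizer"  # hypothetical Grappler pass
  #     return rewriter_config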

  def _run_conversion(self):
    """Run Grappler's OptimizeGraph() tool to convert the graph."""
    # Create a custom ConfigProto for Grappler.
    grappler_session_config = config_pb2.ConfigProto()
    grappler_session_config.CopyFrom(self._session_config)
    rewriter_config = None
    if (grappler_session_config.HasField("graph_options") and
        grappler_session_config.graph_options.HasField("rewrite_options")):
      rewriter_config = grappler_session_config.graph_options.rewrite_options
    custom_rewriter_config = self.get_rewriter_config(rewriter_config)
    grappler_session_config.graph_options.rewrite_options.CopyFrom(
        custom_rewriter_config)

    # Run Grappler.
    self._converted_graph_def = tf_optimizer.OptimizeGraph(
        grappler_session_config,
        self._grappler_meta_graph_def,
        graph_id=b"tf_graph")
    self._converted = True

  def _add_nodes_blacklist(self):
    if self._nodes_blacklist:
      collection_def = self._grappler_meta_graph_def.collection_def["train_op"]
      blacklist = collection_def.node_list.value
      for i in self._nodes_blacklist:
        if isinstance(i, ops.Tensor):
          blacklist.append(_to_bytes(i.name))
        else:
          blacklist.append(_to_bytes(i))

  def _convert_graph_def(self):
    """Convert the input GraphDef."""
    graph = ops.Graph()
    with graph.as_default():
      importer.import_graph_def(self._input_graph_def, name="")
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
    self._add_nodes_blacklist()

    self._run_conversion()

  def _convert_saved_model(self):
    """Convert the input SavedModel."""
    graph = ops.Graph()
    with session.Session(graph=graph, config=self._session_config) as sess:
      input_meta_graph_def = loader.load(sess, self._input_saved_model_tags,
                                         self._input_saved_model_dir)
      input_signature_def = input_meta_graph_def.signature_def[
          self._input_saved_model_signature_key]

      def _gather_names(tensor_info):
        """Get the node names from a TensorInfo."""
        return set([tensor_info[key].name.split(":")[0] for key in tensor_info])

      # Get the input and output node names from the SignatureDef.
      output_node_names = _gather_names(input_signature_def.inputs).union(
          _gather_names(input_signature_def.outputs))

      # Freeze the variables in the SavedModel graph and copy the frozen
      # graph over.
      frozen_graph_def = graph_util.convert_variables_to_constants(
          sess, sess.graph.as_graph_def(add_shapes=True),
          list(output_node_names))
      self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
      self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)

      # Copy the collections that are not variables.
      for key in input_meta_graph_def.collection_def:
        # TODO(laigd): currently we use the collection key to filter out
        # collections that depend on variable ops, but this may miss some
        # other user-defined collections. A better way would be to use
        # CollectionDef::NodeList for the filtering.
        if key not in [
            "variables", "local_variables", "model_variables",
            "trainable_variables", "train_op", "table_initializer"
        ]:
          self._grappler_meta_graph_def.collection_def[key].CopyFrom(
              input_meta_graph_def.collection_def[key])

      self._add_nodes_blacklist()

      # Copy other information.
      self._grappler_meta_graph_def.meta_info_def.CopyFrom(
          input_meta_graph_def.meta_info_def)
      self._grappler_meta_graph_def.signature_def[
          self._input_saved_model_signature_key].CopyFrom(input_signature_def)
      # TODO(laigd): maybe add back AssetFileDef.

    self._run_conversion()

  # TODO(laigd): provide a utility function to optimize a ConcreteFunction and
  # use it here (b/124792963).
  def _convert_saved_model_v2(self):
    """Convert the input SavedModel in 2.0 format."""
    self._saved_model = load.load(self._input_saved_model_dir,
                                  self._input_saved_model_tags)
    func = self._saved_model.signatures[self._input_saved_model_signature_key]
    frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    fetch_collection = meta_graph_pb2.CollectionDef()
    for array in func.inputs + func.outputs:
      fetch_collection.node_list.value.append(array.name)
    self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
        fetch_collection)

    # Run the TRT optimizer in Grappler to convert the graph.
    self._run_conversion()

    def _get_tensors(graph, tensors):
      """Look up tensors in `graph` by the names of the given `tensors`."""
      new_tensors = []
      for tensor in tensors:
        new_tensor = graph.get_tensor_by_name(tensor.name)
        new_tensor.set_shape(tensor.shape)
        new_tensors.append(new_tensor)
      return new_tensors

    # TODO(laigd): do we need to use a different name, e.g. "trt_func_graph"?
    converted_graph = func_graph.FuncGraph(func.graph.name)
    with converted_graph.as_default():
      importer.import_graph_def(self._converted_graph_def, name="")

    converted_graph.inputs = _get_tensors(converted_graph, func.graph.inputs)
    converted_graph.outputs = _get_tensors(converted_graph, func.graph.outputs)
    converted_graph.structured_outputs = func.graph.structured_outputs
    converted_graph.structured_input_signature = (
        func.graph.structured_input_signature)

    # pylint: disable=protected-access
    # TODO(laigd): should we set up the signature as well?
    self._converted_func = function.ConcreteFunction(
        converted_graph, attrs=None, signature=None)
    self._converted_func.add_to_graph()
    self._converted_func._arg_keywords = func._arg_keywords
    self._converted_func._num_positional_args = func._num_positional_args
    self._converted_func._captured_inputs = func._captured_inputs
    self._converted_func.graph.variables = func.graph.variables
    # pylint: enable=protected-access

  def convert(self):
    """Run the conversion.

    Returns:
      The converted GraphDef for TF 1.x, or the converted ConcreteFunction in
      TF 2.0+.
    """
    assert not self._converted

    if context.executing_eagerly():
      self._convert_saved_model_v2()
      return self._converted_func
    else:
      if self._input_graph_def:
        self._convert_graph_def()
      else:
        self._convert_saved_model()
      return self._converted_graph_def
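
  # Illustrative sketch (names are hypothetical): besides the SavedModel flow
  # shown in the class docstring, a converter can also take a frozen GraphDef
  # directly; nodes_blacklist keeps the output nodes out of the converted
  # segments:
  #
  #   converter = MyGraphConverter(
  #       input_graph_def=my_frozen_graph_def,
  #       nodes_blacklist=["logits", "classes"])
  #   converted_graph_def = converter.convert()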

  def calibrate(self,
                fetch_names,
                num_runs,
                feed_dict_fn=None,
                input_map_fn=None):
    """Run the calibration and return the calibrated GraphDef.

    Args:
      fetch_names: a list of output tensor names to fetch during calibration.
      num_runs: number of runs of the graph during calibration.
      feed_dict_fn: a function that returns a dictionary mapping input names
        (as strings) in the GraphDef to be calibrated to values (e.g. Python
        list, numpy arrays, etc). One and only one of `feed_dict_fn` and
        `input_map_fn` should be specified.
      input_map_fn: a function that returns a dictionary mapping input names
        (as strings) in the GraphDef to be calibrated to Tensor objects. The
        values of the named input tensors in the GraphDef to be calibrated
        will be re-mapped to the respective `Tensor` values during
        calibration. One and only one of `feed_dict_fn` and `input_map_fn`
        should be specified.

    Raises:
      ValueError: if the input combination is invalid.
      RuntimeError: if this method is called in eager mode.

    Returns:
      The GraphDef after the calibration.
    """
    assert self._converted
    assert not self._calibration_sess

    if context.executing_eagerly():
      raise RuntimeError("Calibration for TF 2.0 is not supported yet.")

    if (feed_dict_fn and input_map_fn) or (not feed_dict_fn and
                                           not input_map_fn):
      raise ValueError(
          "Should specify one and only one of feed_dict_fn and input_map_fn.")

    self._calibration_graph = ops.Graph()
    with self._calibration_graph.as_default():
      fetches = importer.import_graph_def(
          self._converted_graph_def,
          input_map=input_map_fn() if input_map_fn else None,
          return_elements=fetch_names,
          name="")
    self._calibration_sess = session.Session(
        graph=self._calibration_graph, config=self._session_config)

    for _ in range(num_runs):
      self._calibration_sess.run(
          fetches, feed_dict=feed_dict_fn() if feed_dict_fn else None)

    self.finalize_calibration()
    return self._converted_graph_def
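
  # Illustrative sketch (tensor names and helpers are hypothetical): with
  # feed_dict_fn the calibration data is fed through a feed_dict, while with
  # input_map_fn the graph inputs are re-mapped to tensors built by the
  # function itself. input_map_fn runs with the calibration graph as the
  # default graph, so ops it creates (e.g. a tf.data one-shot iterator) land
  # in that graph:
  #
  #   converter.calibrate(
  #       fetch_names=["output:0"],
  #       num_runs=10,
  #       feed_dict_fn=lambda: {"input:0": next_calibration_batch()})
  #
  #   # or:
  #   def _input_map_fn():
  #     dataset = make_calibration_dataset()  # hypothetical tf.data.Dataset
  #     return {"input:0": dataset.make_one_shot_iterator().get_next()}
  #
  #   converter.calibrate(
  #       fetch_names=["output:0"], num_runs=10, input_map_fn=_input_map_fn)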

  def finalize_calibration(self):
    """Clean up calibration resources and finalize the calibration.

    Implementations need to close self._calibration_sess before returning.
    """
    raise NotImplementedError("finalize_calibration")

  def save(self, output_saved_model_dir):
    """Save the converted graph as a SavedModel.

    Args:
      output_saved_model_dir: construct a SavedModel using the converted
        GraphDef and save it to the specified directory. This option only
        works when the input graph is loaded from a SavedModel, i.e. when
        input_saved_model_dir is specified and input_graph_def is None in
        __init__().

    Raises:
      ValueError: if the input to the converter is a GraphDef instead of a
        SavedModel.
    """
    assert self._converted

    if context.executing_eagerly():
      # Rewrite the signature map using the optimized ConcreteFunction.
      signatures = {
          key: value for key, value in self._saved_model.signatures.items()
      }
      signatures[self._input_saved_model_signature_key] = self._converted_func
      save.save(self._saved_model, output_saved_model_dir, signatures)
    else:
      if self._input_graph_def:
        raise ValueError(
            "Not able to save to a SavedModel since input is a GraphDef")

      # Write the transformed GraphDef as a SavedModel.
      saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
      with ops.Graph().as_default():
        importer.import_graph_def(self._converted_graph_def, name="")
        # We don't use any specific converter here.
        with session.Session(config=self._session_config) as sess:
          saved_model_builder.add_meta_graph_and_variables(
              sess,
              self._input_saved_model_tags,
              signature_def_map=self._grappler_meta_graph_def.signature_def)
      # Ignore other meta graphs from the input SavedModel.
      saved_model_builder.save()


class TrtPrecisionMode(object):
  FP32 = "FP32"
  FP16 = "FP16"
  INT8 = "INT8"

  @staticmethod
  def supported_precision_modes():
    return [
        TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8
    ]


# Use a large enough number as the default max_workspace_size for TRT engines,
# so it can produce reasonable performance results with the default.
DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30


class TrtGraphConverter(GraphConverter):
  """A GraphConverter for TRT transformation."""

  _TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF_TRT_Calibration"

  @classmethod
  def get_tensorrt_rewriter_config(
      cls,
      rewriter_config_template=None,
      max_batch_size=1,
      max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
      precision_mode=TrtPrecisionMode.FP32,
      minimum_segment_size=3,
      is_dynamic_op=False,
      maximum_cached_engines=1,
      cached_engine_batches=None,
      use_calibration=True,
      use_function_backup=True):
    """Returns a RewriterConfig proto for TRT transformation.

    Args:
      rewriter_config_template: a template RewriterConfig proto used to create
        a TRT-enabled RewriterConfig. If None, it will use a default one.
      max_batch_size: max size for the input batch.
      max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
        engine can use at execution time. This corresponds to the
        'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
      precision_mode: one of TrtPrecisionMode.supported_precision_modes().
      minimum_segment_size: the minimum number of nodes required for a subgraph
        to be replaced by TRTEngineOp.
      is_dynamic_op: whether to generate dynamic TRT ops which will build the
        TRT network and engine at run time.
      maximum_cached_engines: max number of cached TRT engines in dynamic TRT
        ops. If the number of cached engines is already at max but none of
        them can serve the input, the TRTEngineOp will fall back to run the TF
        function based on which the TRTEngineOp is created.
      cached_engine_batches: a list of batch sizes used to create cached
        engines, only used when is_dynamic_op is True. The length of the list
        should be <= maximum_cached_engines, and the dynamic TRT op will use
        this list to determine the batch sizes of the cached engines, instead
        of making the decision on the fly. This is useful when we know the
        most common batch size(s) the application is going to generate.
      use_calibration: this argument is ignored if precision_mode is not INT8.
        If set to True, a calibration graph will be created to calibrate the
        missing ranges. The calibration graph must be converted to an
        inference graph by running calibration with calibrate(). If set to
        False, quantization nodes will be expected for every tensor in the
        graph (excluding those which will be fused). If a range is missing, an
        error will occur. Please note that accuracy may be negatively affected
        if there is a mismatch between which tensors TRT quantizes and which
        tensors were trained with fake quantization.
      use_function_backup: if set to True, it will create a FunctionDef for
        each subgraph that is converted to a TRT op, and if TRT ops fail to
        execute at runtime, it'll invoke that function as a fallback.

    Returns:
      A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.

    Raises:
      TypeError: if any of the parameters are of unexpected type.
      ValueError: if any of the parameters are of unexpected value.
    """
    # Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't
    # complain even if it cannot find the TensorRT library.
    trt_ops.load_trt_ops()
    # pylint: disable=g-import-not-at-top,unused-import,line-too-long,unused-variable
    # Import a random symbol to trigger loading of the TRT library.
    from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
    # pylint: enable=g-import-not-at-top,unused-import,line-too-long,unused-variable

    if rewriter_config_template is not None and not isinstance(
        rewriter_config_template, rewriter_config_pb2.RewriterConfig):
      raise TypeError(
          "rewriter_config_template should be a RewriterConfig proto.")

    rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
    if rewriter_config_template is None:
      # Layout optimizer may add Const nodes followed by Reshape nodes, thus
      # we need to run constant folding again.
      rewriter_config_with_trt.optimizers.extend(
          ["constfold", "layout", "constfold"])
      rewriter_config_with_trt.meta_optimizer_iterations = (
          rewriter_config_pb2.RewriterConfig.ONE)
    else:
      rewriter_config_with_trt.CopyFrom(rewriter_config_template)

    optimizer = rewriter_config_with_trt.custom_optimizers.add()
    optimizer.name = "TensorRTOptimizer"
    optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
    optimizer.parameter_map["max_batch_size"].i = max_batch_size
    optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
    optimizer.parameter_map[
        "max_workspace_size_bytes"].i = max_workspace_size_bytes
    optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode)
    optimizer.parameter_map[
        "maximum_cached_engines"].i = maximum_cached_engines
    if cached_engine_batches:
      optimizer.parameter_map["cached_engine_batches"].list.i.extend(
          cached_engine_batches)
    optimizer.parameter_map["use_calibration"].b = use_calibration
    optimizer.parameter_map["use_function_backup"].b = use_function_backup
    return rewriter_config_with_trt
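
  # Illustrative sketch: the returned RewriterConfig can be dropped into a
  # ConfigProto's graph options so a regular Session performs the TRT
  # conversion at graph optimization time (this mirrors what
  # _run_conversion() does internally):
  #
  #   rewriter_config = TrtGraphConverter.get_tensorrt_rewriter_config(
  #       precision_mode=TrtPrecisionMode.FP16,
  #       maximum_cached_engines=4)
  #   trt_session_config = config_pb2.ConfigProto()
  #   trt_session_config.graph_options.rewrite_options.CopyFrom(
  #       rewriter_config)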

  def __init__(self,
               input_saved_model_dir=None,
               input_saved_model_tags=None,
               input_saved_model_signature_key=None,
               input_graph_def=None,
               nodes_blacklist=None,
               session_config=None,
               max_batch_size=1,
               max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
               precision_mode=TrtPrecisionMode.FP32,
               minimum_segment_size=3,
               is_dynamic_op=False,
               maximum_cached_engines=1,
               cached_engine_batches=None,
               use_calibration=True,
               use_function_backup=True):
    """Initialize the converter.

    Args:
      input_saved_model_dir: the directory to load the SavedModel which
        contains the input graph to transform. Used only when input_graph_def
        is None.
      input_saved_model_tags: list of tags to load the SavedModel.
      input_saved_model_signature_key: the key of the signature to optimize
        the graph for.
      input_graph_def: a GraphDef object containing a model to be transformed.
        If set to None, the graph will be read from the SavedModel loaded from
        input_saved_model_dir.
      nodes_blacklist: list of node names to prevent the converter from
        touching. Only used when input_graph_def is not None.
      session_config: the ConfigProto used to create a Session. It's also used
        as a template to create a TRT-enabled ConfigProto for conversion. If
        not specified, a default ConfigProto will be used.
      max_batch_size: max size for the input batch.
      max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
        engine can use at execution time. This corresponds to the
        'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
      precision_mode: one of TrtPrecisionMode.supported_precision_modes().
      minimum_segment_size: the minimum number of nodes required for a subgraph
        to be replaced by TRTEngineOp.
      is_dynamic_op: whether to generate dynamic TRT ops which will build the
        TRT network and engine at run time.
      maximum_cached_engines: max number of cached TRT engines in dynamic TRT
        ops. If the number of cached engines is already at max but none of
        them can serve the input, the TRTEngineOp will fall back to run the TF
        function based on which the TRTEngineOp is created.
      cached_engine_batches: a list of batch sizes used to create cached
        engines, only used when is_dynamic_op is True. The length of the list
        should be <= maximum_cached_engines, and the dynamic TRT op will use
        this list to determine the batch sizes of the cached engines, instead
        of making the decision on the fly. This is useful when we know the
        most common batch size(s) the application is going to generate.
      use_calibration: this argument is ignored if precision_mode is not INT8.
        If set to True, a calibration graph will be created to calibrate the
        missing ranges. The calibration graph must be converted to an
        inference graph by running calibration with calibrate(). If set to
        False, quantization nodes will be expected for every tensor in the
        graph (excluding those which will be fused). If a range is missing, an
        error will occur. Please note that accuracy may be negatively affected
        if there is a mismatch between which tensors TRT quantizes and which
        tensors were trained with fake quantization.
      use_function_backup: if set to True, it will create a FunctionDef for
        each subgraph that is converted to a TRT op, and if TRT ops fail to
        execute at runtime, it'll invoke that function as a fallback.

    Raises:
      ValueError: if the combination of the parameters is invalid.
      RuntimeError: if the TensorRT library version is incompatible.
    """
    super(TrtGraphConverter, self).__init__(
        input_saved_model_dir=input_saved_model_dir,
        input_saved_model_tags=input_saved_model_tags,
        input_saved_model_signature_key=input_saved_model_signature_key,
        input_graph_def=input_graph_def,
        nodes_blacklist=nodes_blacklist,
        session_config=session_config)

    # TODO(laigd): move all the validations below to
    # get_tensorrt_rewriter_config().

    # Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't
    # complain even if it cannot find the TensorRT library.
    trt_ops.load_trt_ops()
    # pylint: disable=g-import-not-at-top,line-too-long
    from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
    from tensorflow.python.compiler.tensorrt.wrap_conversion import get_loaded_tensorrt_version
    # pylint: enable=g-import-not-at-top,line-too-long

    # Check the compatibility of the TensorRT version.
    compiled_version = get_linked_tensorrt_version()
    loaded_version = get_loaded_tensorrt_version()
    tf_logging.info("Linked TensorRT version: %s" % str(compiled_version))
    tf_logging.info("Loaded TensorRT version: %s" % str(loaded_version))
    version_mismatch = False
    if loaded_version[0] < compiled_version[0]:
      tf_logging.error(
          "TensorRT version mismatch. Tensorflow was compiled against "
          "TensorRT %s but library loaded from environment is TensorRT %s. "
          "Please make sure that the correct version of TensorRT is "
          "available in the system and added to ldconfig or LD_LIBRARY_PATH." %
          (".".join([str(x) for x in compiled_version]),
           ".".join([str(x) for x in loaded_version])))
      raise RuntimeError("Incompatible TensorRT library version")
    for i in zip(loaded_version, compiled_version):
      if i[0] != i[1]:
        tf_logging.warn("TensorRT version mismatch. Compiled against version "
                        "%s, but loaded %s. Things may not work." %
                        (".".join([str(x) for x in compiled_version]),
                         ".".join([str(x) for x in loaded_version])))
        version_mismatch = True
        break
    if not version_mismatch:
      tf_logging.info("Running against TensorRT version %s" %
                      ".".join([str(x) for x in loaded_version]))

    # Check input arguments.
    supported_precision_modes = TrtPrecisionMode.supported_precision_modes()
    if precision_mode not in supported_precision_modes:
      raise ValueError(("precision mode '{}' is not supported. "
                        "It should be one of {}").format(
                            precision_mode, supported_precision_modes))

    if cached_engine_batches:
      if not isinstance(cached_engine_batches, list):
        raise TypeError("cached_engine_batches should be a list.")
      if len(cached_engine_batches) > maximum_cached_engines:
        raise ValueError("cached_engine_batches should not contain more than "
                         "maximum_cached_engines items.")

    self._need_calibration = (
        precision_mode == TrtPrecisionMode.INT8 and use_calibration)
    self._use_function_backup = use_function_backup

    # TODO(laigd): consider providing a mechanism to remove the fallback path
    # after calibration is done.
    if self._need_calibration and not use_function_backup:
      raise ValueError(
          "Calibration requires enabling fallback to TF function execution.")

    # TODO(laigd):
    # - Get rid of the is_dynamic_op option; it should always be True, and it
    #   should accept N shapes as input.
    # - Verify in INT8 mode that maximum_cached_engines and
    #   cached_engine_batches are set appropriately.
    # - If it fails to build the INT8 engine it should return an error.
    self._max_batch_size = max_batch_size
    self._max_workspace_size_bytes = max_workspace_size_bytes
    self._precision_mode = precision_mode
    self._minimum_segment_size = minimum_segment_size
    self._is_dynamic_op = is_dynamic_op
    self._maximum_cached_engines = maximum_cached_engines
    self._cached_engine_batches = cached_engine_batches
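
  # Illustrative sketch (model directory, tensor names and the data function
  # are hypothetical): a typical INT8 flow constructs the converter with
  # calibration enabled, converts, calibrates, then saves:
  #
  #   converter = TrtGraphConverter(
  #       input_saved_model_dir="/tmp/my_saved_model",
  #       precision_mode=TrtPrecisionMode.INT8,
  #       is_dynamic_op=True,
  #       use_calibration=True)
  #   converter.convert()
  #   converter.calibrate(
  #       fetch_names=["output:0"],
  #       num_runs=10,
  #       feed_dict_fn=lambda: {"input:0": next_calibration_batch()})
  #   converter.save("/tmp/my_trt_saved_model")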

  def get_rewriter_config(self, rewriter_config_template=None):
    return TrtGraphConverter.get_tensorrt_rewriter_config(
        rewriter_config_template,
        max_batch_size=self._max_batch_size,
        max_workspace_size_bytes=self._max_workspace_size_bytes,
        precision_mode=self._precision_mode,
        minimum_segment_size=self._minimum_segment_size,
        is_dynamic_op=self._is_dynamic_op,
        maximum_cached_engines=self._maximum_cached_engines,
        cached_engine_batches=self._cached_engine_batches,
        use_calibration=self._need_calibration,
        use_function_backup=self._use_function_backup)

  def finalize_calibration(self):
    assert self._need_calibration
    assert self._converted
    assert not self._calibration_data_collected

    # Lazily load the op, since it's not available in CPU-only builds.
    # Importing this at top would cause tests that import TF-TRT to fail when
    # they're built and run without CUDA/GPU.
    # pylint: disable=g-import-not-at-top,line-too-long
    from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import get_serialized_resource_op
    # pylint: enable=g-import-not-at-top,line-too-long

    # TODO(laigd): a better way would be to use self._calibration_sess to list
    # all the devices, add one get_serialized_resource_op for each device, and
    # fetch each such op for every resource until it's found. This can work
    # even when the device of the TRTEngineOp is empty or not fully specified.

    # Maps device name to the corresponding get_serialized_resource_op.
    device_to_get_resource_op_map = {}

    with self._calibration_graph.as_default():
      container_input = array_ops.placeholder(dtypes.string)
      resource_name_input = array_ops.placeholder(dtypes.string)

      for node in self._converted_graph_def.node:
        if node.op == "TRTEngineOp":
          # Add the get_serialized_resource_op for the device if it hasn't
          # been added before. We only add one such op for each device.
          # TODO(laigd): what if the device is empty?
          if node.device not in device_to_get_resource_op_map:
            with self._calibration_graph.device(node.device):
              serialized_resources_output = (
                  get_serialized_resource_op(container_input,
                                             resource_name_input))
            device_to_get_resource_op_map[node.device] = (
                serialized_resources_output)

          # Get the calibration resource.
          calibration_result = self._calibration_sess.run(
              device_to_get_resource_op_map[node.device],
              feed_dict={
                  container_input:
                      TrtGraphConverter
                      ._TRT_CALIBRATION_RESOURCE_CONTAINER_NAME,
                  resource_name_input:
                      node.name
              })
          node.attr["calibration_data"].s = calibration_result

    self._calibration_data_collected = True
    self._calibration_sess.close()

  def save(self, output_saved_model_dir):
    """Save the converted graph as a SavedModel."""
    if self._need_calibration:
      assert self._calibration_data_collected
    super(TrtGraphConverter, self).save(output_saved_model_dir)


def create_inference_graph(
    input_graph_def,
    outputs,
    max_batch_size=1,
    max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
    precision_mode=TrtPrecisionMode.FP32,
    minimum_segment_size=3,
    is_dynamic_op=False,
    maximum_cached_engines=1,
    cached_engine_batches=None,
    input_saved_model_dir=None,
    input_saved_model_tags=None,
    input_saved_model_signature_key=None,
    output_saved_model_dir=None,
    session_config=None):
  """Python wrapper for the TRT transformation.

  Args:
    input_graph_def: a GraphDef object containing a model to be transformed.
      If set to None, the graph will be read from the SavedModel loaded from
      input_saved_model_dir.
    outputs: list of tensors or node names for the model outputs. Only used
      when input_graph_def is not None.
    max_batch_size: max size for the input batch.
    max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
      engine can use at execution time. This corresponds to the
      'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
    precision_mode: one of TrtPrecisionMode.supported_precision_modes().
    minimum_segment_size: the minimum number of nodes required for a subgraph
      to be replaced by TRTEngineOp.
    is_dynamic_op: whether to generate dynamic TRT ops which will build the
      TRT network and engine at run time.
    maximum_cached_engines: max number of cached TRT engines in dynamic TRT
      ops. If the number of cached engines is already at max but none of them
      can serve the input, the TRTEngineOp will fall back to run the TF
      function based on which the TRTEngineOp is created.
    cached_engine_batches: a list of batch sizes used to create cached
      engines, only used when is_dynamic_op is True. The length of the list
      should be <= maximum_cached_engines, and the dynamic TRT op will use
      this list to determine the batch sizes of the cached engines, instead of
      making the decision on the fly. This is useful when we know the most
      common batch size(s) the application is going to generate.
    input_saved_model_dir: the directory to load the SavedModel which contains
      the input graph to transform. Used only when input_graph_def is None.
    input_saved_model_tags: list of tags to load the SavedModel.
    input_saved_model_signature_key: the key of the signature to optimize the
      graph for.
    output_saved_model_dir: if not None, construct a SavedModel using the
      returned GraphDef and save it to the specified directory. This option
      only works when the input graph is loaded from a SavedModel, i.e. when
      input_saved_model_dir is specified and input_graph_def is None.
    session_config: the ConfigProto used to create a Session. It's also used
      as a template to create a TRT-enabled ConfigProto for conversion. If not
      specified, a default ConfigProto will be used.

  Returns:
    A GraphDef transformed from input_graph_def (or the SavedModel graph def
    loaded from input_saved_model_dir, if input_graph_def is not present),
    where all TRT-compatible subgraphs are replaced with TRTEngineOps, and a
    TF function is added for each of the subgraphs.

    If is_dynamic_op is True, each TRTEngineOp will contain a serialized
    subgraph GraphDef, which will be converted to a TRT engine at execution
    time and the TRT engine will be cached for future usage. A new TRT engine
    will be created each time none of the cached engines match the input
    shapes. If it fails to execute the TRT engine or the number of cached
    engines reaches maximum_cached_engines, the op will fall back to call the
    corresponding TF function.

    If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT
    engine created from the corresponding subgraph. No more engines will be
    created on the fly, and the op will fall back to call the corresponding TF
    function when it fails to execute the engine.

  Raises:
    ValueError: if the combination of the parameters is invalid.
  """
  trt_converter = TrtGraphConverter(
      input_saved_model_dir=input_saved_model_dir,
      input_saved_model_tags=input_saved_model_tags,
      input_saved_model_signature_key=input_saved_model_signature_key,
      input_graph_def=input_graph_def,
      nodes_blacklist=outputs,
      session_config=session_config,
      max_batch_size=max_batch_size,
      max_workspace_size_bytes=max_workspace_size_bytes,
      precision_mode=precision_mode,
      minimum_segment_size=minimum_segment_size,
      is_dynamic_op=is_dynamic_op,
      maximum_cached_engines=maximum_cached_engines,
      cached_engine_batches=cached_engine_batches,
      use_calibration=False)
  converted_graph_def = trt_converter.convert()
  if output_saved_model_dir:
    trt_converter.save(output_saved_model_dir)
  return converted_graph_def
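

# Illustrative sketch (paths are hypothetical): a one-shot conversion of a
# SavedModel that also writes the converted model back out. input_graph_def
# and outputs are passed as None since the graph comes from the SavedModel:
#
#   converted_graph_def = create_inference_graph(
#       input_graph_def=None,
#       outputs=None,
#       input_saved_model_dir="/tmp/my_saved_model",
#       precision_mode=TrtPrecisionMode.FP16,
#       output_saved_model_dir="/tmp/my_trt_saved_model")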