# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper conversion to trt_graph."""
import collections
import os
import re
from packaging import version
from tensorflow.compiler.tf2tensorrt import _pywrap_py_utils
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import dtypes


def disable_non_trt_optimizers_in_rewriter_config(rewriter_config):
"""Modifies rewriter_config to disable all non-TRT optimizations."""
off = rewriter_config_pb2.RewriterConfig.OFF
rewriter_config.arithmetic_optimization = off
rewriter_config.auto_mixed_precision = off
rewriter_config.auto_parallel.enable = False
rewriter_config.constant_folding = off
rewriter_config.debug_stripper = off
rewriter_config.dependency_optimization = off
# This one needs to be ON to allow TF-TRT
rewriter_config.disable_meta_optimizer = False
rewriter_config.disable_model_pruning = True
rewriter_config.function_optimization = off
rewriter_config.implementation_selector = off
rewriter_config.layout_optimizer = off
rewriter_config.loop_optimization = off
rewriter_config.memory_optimization = (
rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
rewriter_config.min_graph_nodes = -1
rewriter_config.pin_to_host_optimization = off
rewriter_config.remapping = off
rewriter_config.scoped_allocator_optimization = off
rewriter_config.shape_optimization = off
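

# Editor-added illustrative sketch, not part of the original module: it shows
# the typical way the helper above is applied, i.e. to the RewriterConfig
# embedded in a ConfigProto's graph options. The function name is hypothetical.
def _example_config_with_only_trt_optimizers():
  # pylint: disable=g-import-not-at-top
  from tensorflow.core.protobuf import config_pb2

  config = config_pb2.ConfigProto()
  # `graph_options.rewrite_options` is the RewriterConfig consumed by Grappler.
  disable_non_trt_optimizers_in_rewriter_config(
      config.graph_options.rewrite_options)
  return config

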
def version_tuple_to_string(ver_tuple):
  """Converts a version tuple, e.g. (8, 6, 1), to a string, e.g. "8.6.1"."""
assert isinstance(ver_tuple, tuple)
assert len(ver_tuple) == 3
ver_tuple = [str(x) for x in ver_tuple]
return ".".join(ver_tuple)
def _is_tensorrt_version_greater_equal(trt_ver, target_ver):
  """Returns True if version tuple `trt_ver` >= `target_ver`."""
trt_ver = version.Version(version_tuple_to_string(trt_ver))
target_ver = version.Version(version_tuple_to_string(target_ver))
return trt_ver >= target_ver
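

# Editor-added example (illustrative):
# _is_tensorrt_version_greater_equal((8, 6, 1), (8, 5, 0)) returns True.

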
def is_linked_tensorrt_version_greater_equal(major, minor=0, patch=0):
  """Returns True if the linked TensorRT version is >= `major.minor.patch`."""
ver = _pywrap_py_utils.get_linked_tensorrt_version()
return _is_tensorrt_version_greater_equal(ver, (major, minor, patch))


def is_loaded_tensorrt_version_greater_equal(major, minor=0, patch=0):
  """Returns True if the loaded TensorRT version is >= `major.minor.patch`."""
ver = _pywrap_py_utils.get_loaded_tensorrt_version()
return _is_tensorrt_version_greater_equal(ver, (major, minor, patch))
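

# Editor-added illustrative sketch, not part of the original module: a caller
# that needs a TensorRT 8+ only code path might guard it with the two checks
# above. The function name below is hypothetical.
def _example_trt8_available():
  return (is_linked_tensorrt_version_greater_equal(8) and
          is_loaded_tensorrt_version_greater_equal(8))

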
def is_experimental_feature_activated(feature_name):
"""Determines if a TF-TRT experimental feature is enabled.
This helper function checks if an experimental feature was enabled using
the environment variable `TF_TRT_EXPERIMENTAL_FEATURES=feature_1,feature_2`.
Args:
feature_name: Name of the feature being tested for activation.
"""
return (feature_name
in os.environ.get("TF_TRT_EXPERIMENTAL_FEATURES",
default="").split(","))
def _convert_dtype_id_to_str(dtype):
"""Helper function to convert a dtype id to a corresponding string name."""
if isinstance(dtype, int):
return dtypes._TYPE_TO_STRING[dtype]
else:
return [dtypes._TYPE_TO_STRING[d] for d in dtype]
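

# Editor-added example (illustrative): passing the proto enum value of a dtype
# (e.g. types_pb2.DT_FLOAT from tensorflow.core.framework.types_pb2) returns
# "float32"; passing a list of enum values returns the list of names.

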
def get_node_compute_dtype(node):
"""Returns the compute DType of a GraphDef Node."""
  # Note: order matters. By default, a TF node stores its compute dtype under
  # the `T` attr; TRTEngineOp ("precision_mode"), Cast ("DstT") and Placeholder
  # ("dtype") nodes are the exceptions.
for type_key in [
"precision_mode", # TRTEngineOp
"DstT", # Cast Nodes
"dtype", # Placeholder
"T", # Everything Else
]:
try:
precision_val = node.attr[type_key]
if type_key == "precision_mode":
precision_val = precision_val.s.decode("utf-8")
if precision_val == "":
continue
if precision_val == "FP32":
return "float32"
elif precision_val == "FP16":
return "float16"
elif precision_val == "INT8":
return "int8"
else:
return "unknown"
else:
return _convert_dtype_id_to_str(precision_val.type)
    except Exception:
continue
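

# Editor-added illustrative sketch, not part of the original module: tallies
# the compute dtypes found in a GraphDef. The function name is hypothetical.
def _example_count_compute_dtypes(graphdef):
  dtype_count = collections.defaultdict(int)
  for node in graphdef.node:
    dtype_count[get_node_compute_dtype(node)] += 1
  return dtype_count

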
def get_node_io_shapes(node, key):
"""Returns the input/output shapes of a GraphDef Node."""
out_shape = []
for shape in node.attr[key].list.shape:
out_shape.append([dim.size for dim in shape.dim])
return out_shape
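

# Editor-added example (illustrative, assuming the node carries a
# list-of-shapes attr such as a TRTEngineOp's "input_shapes"):
# get_node_io_shapes(node, "input_shapes") returns one list of dimension sizes
# per shape, e.g. [[-1, 224, 224, 3]] for a single dynamic-batch input.

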
def get_trtengineop_io_dtypes(node, key):
"""Returns the input/output dtypes of a TRTEngineOp."""
return _convert_dtype_id_to_str(node.attr[key].list.type)


def get_trtengineop_io_nodes_count(node, key):
"""Returns the number of input/output nodes of a TRTEngineOp."""
return len(node.attr[key].list.type)
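

# Editor-added illustrative sketch, not part of the original module: assuming a
# TRTEngineOp node stores its input/output dtypes under the "InT" and "OutT"
# attrs, the two helpers above can be combined as follows. The function name is
# hypothetical.
def _example_describe_trtengineop_io(node):
  return {
      "inputs": (get_trtengineop_io_nodes_count(node, "InT"),
                 get_trtengineop_io_dtypes(node, "InT")),
      "outputs": (get_trtengineop_io_nodes_count(node, "OutT"),
                  get_trtengineop_io_dtypes(node, "OutT")),
  }

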
def get_trtengineop_node_op_count(graphdef, node_name):
"""Counts the number of nodes and OP types of a given TRTEngineOp."""
  node_count = 0  # Remains 0 if no matching native segment function is found.
  ops_in_engine = collections.defaultdict(int)
for func in graphdef.library.function:
if f"{node_name}_native_segment" == func.signature.name:
node_count = len(func.node_def)
for node in func.node_def:
ops_in_engine[node.op] += 1
break
return node_count, ops_in_engine
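

# Editor-added example (illustrative): for an engine node named
# "TRTEngineOp_000" (hypothetical name), the call
#   get_trtengineop_node_op_count(graphdef, "TRTEngineOp_000")
# looks up the "TRTEngineOp_000_native_segment" function in the GraphDef
# library and returns its total node count plus a per-op-type histogram.

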
class DTypeIndex(dict):
"""Helper class to create an index of dtypes with incremental values."""
def get_dtype_index(self, dtype):
if dtype not in self:
self[dtype] = len(self) + 1
return self[dtype]
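

# Editor-added example (illustrative): indices are assigned in first-seen order
# starting at 1, matching graphviz's 1-based "set312" color scheme used below.
#
#   index = DTypeIndex()
#   index.get_dtype_index("float32")  # -> 1
#   index.get_dtype_index("float16")  # -> 2
#   index.get_dtype_index("float32")  # -> 1 (already indexed)

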
def draw_graphdef_as_graphviz(graphdef, dot_output_filename):
"""Exports a GraphDef to GraphViz format.
- Step 1: Drawing Each Node of the compute GraphDef.
- Step 2: Create nodes for each collected dtype in the graph.
- Step 3: Creating invisible links to align properly the legend.
Each node consequently mentions:
- Op Type
- Compute Dtype
- Compute Device
"""
dtype_index = DTypeIndex()
with open(dot_output_filename, "w") as f:
print("digraph tftrt_converted_graph {", file=f)
print(" graph [fontsize=10 fontname=\"Verdana\"];", file=f)
# ColorScheme Documentation: https://graphviz.org/doc/info/colors.html
print(
" node [style=filled height=0.55 colorscheme=set312 shape=box];",
file=f)
# Step 1: Parsing the graph and drawing OPs one by one.
print("\n subgraph tensorflow_graph {", file=f)
print(" node [width=1.35];", file=f)
nodes_with_no_inputs = []
for node in graphdef.node:
output_name = node.name
node_precision = get_node_compute_dtype(node)
color_idx = dtype_index.get_dtype_index(node_precision)
device_key = node.device.split("/")[-1]
if not device_key:
device_key = "device:Unspecified"
if node.op == "TRTEngineOp":
node_count, _ = get_trtengineop_node_op_count(graphdef, output_name)
node_label = f"{output_name} [{node_count}]"
else:
node_label = f"{node.op}"
      # Note: the double space before <br/> is necessary for formatting.
      node_label = f"{node_label}  <br/>{device_key}"
print(
f" \"{output_name}\" [label=<{node_label}> "
f"fillcolor={color_idx}];",
file=f)
      if node.input:
for input_full_name in node.input:
parts = input_full_name.split(":")
input_name = re.sub(r"^\^", "", parts[0])
print(f" \"{input_name}\" -> \"{output_name}\";", file=f)
else:
nodes_with_no_inputs.append(output_name)
print(" }", file=f)
    # Step 2: Create a legend node for each dtype collected in Step 1.
print("\n subgraph cluster_legend {", file=f)
print(" label=\"Compute Dtype Legend\";", file=f)
print(" margin=\"30\";", file=f)
print(" node [width=2];", file=f)
for dtype, color_idx in dtype_index.items():
print(
f" {dtype} [fillcolor={color_idx} label=<{dtype}>];",
file=f)
print(" }", file=f)
    # Step 3: Align the legend with the graph using invisible edges.
print("\n edge[style=\"invisible\", dir=\"none\"];", file=f)
for dtype in dtype_index.keys():
for node_name in nodes_with_no_inputs:
print(f" \"{dtype}\" -> \"{node_name}\"", file=f)
print("}", file=f)
print("\n===================================================================")
print(f"Graph Visualization Exported to: `{dot_output_filename}`.")
print("We recommend using https://edotor.net/ to visualize the .dot file.")
print("You can also use `graphviz` utility to convert them to PNG format:")
print(" - `sudo apt install -y graphviz`")
print(" - `dot -Tpng .dot -o .png`")
print("===================================================================\n")