# Description:
#   Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with TensorFlow
#   and provide the TensorRT operators and converter package.
#   APIs are meant to change over time.

# Both test macros come from the same .bzl file, so they share one load().
load("//tensorflow:tensorflow.bzl", "cuda_py_test", "cuda_py_tests")

package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],  # Apache 2.0
)

exports_files(["LICENSE"])

# Make every file directly under test/testdata/ referenceable from other packages.
exports_files(glob([
    "test/testdata/*",
]))

# Package __init__; pulls in the converter and the integration-test base library.
py_library(
    name = "init_py",
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":tf_trt_integration_test_base",
        ":trt_convert_py",
    ],
)

# The converter module (trt_convert.py) and the TensorFlow targets it depends on.
py_library(
    name = "trt_convert_py",
    srcs = ["trt_convert.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/compiler/tf2tensorrt:trt_ops_loader",
        "//tensorflow/compiler/tf2tensorrt:wrap_py_utils",
        "//tensorflow/python:convert_to_constants",
        "//tensorflow/python:func_graph",
        "//tensorflow/python:graph_util",
        "//tensorflow/python:session",
        "//tensorflow/python:tf_optimizer",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:function",
        "//tensorflow/python/saved_model:builder",
        "//tensorflow/python/saved_model:load",
        "//tensorflow/python/saved_model:loader",
        "//tensorflow/python/saved_model:save",
        "//tensorflow/python/saved_model:signature_constants",
        "//tensorflow/python/saved_model:tag_constants",
    ],
)

# Windows-specific module (trt_convert_windows.py); depends only on
# //tensorflow/python:util.
py_library(
    name = "trt_convert_windows",
    srcs = ["trt_convert_windows.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:util",
    ],
)

# Shared base library used by the integration test targets below.
py_library(
    name = "tf_trt_integration_test_base",
    srcs = ["test/tf_trt_integration_test_base.py"],
    # Stated explicitly to match the other py_library targets in this file.
    srcs_version = "PY2AND3",
    deps = [
        ":trt_convert_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python/saved_model:builder",
        "//tensorflow/python/saved_model:loader",
        "//tensorflow/python/saved_model:signature_constants",
        "//tensorflow/python/saved_model:signature_def_utils",
        "//tensorflow/python/saved_model:tag_constants",
        "//tensorflow/python/saved_model:utils",
        "//tensorflow/python/tools:saved_model_utils",
    ],
)

# Test for trt_convert.py; ships the tftrt_2.0_saved_model files as runtime data.
cuda_py_test(
    name = "trt_convert_test",
    srcs = ["trt_convert_test.py"],
    data = [
        "test/testdata/tftrt_2.0_saved_model/saved_model.pb",
        "test/testdata/tftrt_2.0_saved_model/variables/variables.data-00000-of-00002",
        "test/testdata/tftrt_2.0_saved_model/variables/variables.data-00001-of-00002",
        "test/testdata/tftrt_2.0_saved_model/variables/variables.index",
    ],
    python_version = "PY3",
    tags = [
        "no_cuda_on_cpu_tap",
        "no_pip",
        "no_rocm",
        "no_windows",
        "nomac",
    ],
    deps = [
        ":trt_convert_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:graph_util",
        "//tensorflow/python/saved_model:builder",
        "//tensorflow/python/saved_model:loader",
        "//tensorflow/python/saved_model:signature_constants",
        "//tensorflow/python/saved_model:signature_def_utils",
        "//tensorflow/python/saved_model:tag_constants",
        "//tensorflow/python/saved_model:utils",
        "//tensorflow/python/tools:freeze_graph_lib",
        "//tensorflow/python/tools:saved_model_utils",
        "@absl_py//absl/testing:parameterized",
    ],
)

# Integration tests built on :tf_trt_integration_test_base.
cuda_py_tests(
    name = "tf_trt_integration_test",
    srcs = [
        "test/base_test.py",
        "test/batch_matmul_test.py",
        "test/binary_tensor_weight_broadcast_test.py",
        "test/combined_nms_test.py",
        "test/const_broadcast_test.py",
        "test/conv2d_test.py",
        "test/dynamic_input_shapes_test.py",
        "test/identity_output_test.py",
        "test/int32_test.py",
        "test/lru_cache_test.py",
        "test/memory_alignment_test.py",
        "test/multi_connection_neighbor_engine_test.py",
        "test/neighboring_engine_test.py",
        "test/quantization_test.py",
        "test/rank_two_test.py",
        "test/reshape_transpose_test.py",
        "test/topk_test.py",
        "test/unary_test.py",
        "test/vgg_block_nchw_test.py",
        "test/vgg_block_test.py",
    ],
    python_version = "PY3",
    tags = [
        "no_cuda_on_cpu_tap",
        "no_rocm",
        "no_windows",
        "nomac",
    ],
    deps = [
        ":tf_trt_integration_test_base",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
    ],
)

# Kept as a separate target from tf_trt_integration_test; it carries the
# additional "notap" tag (see the referenced bug).
cuda_py_tests(
    name = "concatenation_test",
    srcs = [
        "test/biasadd_matmul_test.py",
        "test/concatenation_test.py",
    ],
    python_version = "PY3",
    tags = [
        "no_rocm",
        "no_windows",
        "nomac",
        "notap",  # b/140261407
    ],
    deps = [
        ":tf_trt_integration_test_base",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
    ],
)

# Quantization test driven by the MNIST checkpoint files listed in `data`.
cuda_py_test(
    name = "quantization_mnist_test",
    srcs = ["test/quantization_mnist_test.py"],
    data = [
        "test/testdata/mnist/checkpoint",
        "test/testdata/mnist/model.ckpt-46900.data-00000-of-00001",
        "test/testdata/mnist/model.ckpt-46900.index",
    ],
    python_version = "PY3",
    tags = [
        "no_cuda_on_cpu_tap",
        "no_oss",  # TODO(b/125290478): allow running in at least some OSS configurations.
        "no_pip",
        "no_rocm",
        "no_tap",  # It is not able to download the mnist data.
        "no_windows",
        "nomac",
    ],
    deps = [
        ":tf_trt_integration_test_base",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python/estimator",
        "//tensorflow/python/keras",
    ],
)