• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Description:
#   Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with TensorFlow
#   and provide TensorRT operators and a converter package.
#   These APIs are expected to change over time.
5
6load("//tensorflow:tensorflow.bzl", "cuda_py_test")
7load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
8
# Package-wide defaults: every target is publicly visible, and the package
# is licensed under the "notice" category.
package(
    licenses = ["notice"],  # Apache 2.0
    default_visibility = ["//visibility:public"],
)
13
# Make the license text and the files matching test/testdata/* referenceable
# by label from other packages.
exports_files(["LICENSE"] + glob(["test/testdata/*"]))
19
# Python package entry point (__init__.py). It depends on both the converter
# library and the integration-test base library.
py_library(
    name = "init_py",
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":tf_trt_integration_test_base",
        ":trt_convert_py",
    ],
)
29
# The TF-TRT converter (trt_convert.py). It depends on:
#   - the TensorRT op loader and Python utility wrapper from
#     //tensorflow/compiler/tf2tensorrt;
#   - graph/session/optimizer utilities;
#   - eager function support;
#   - SavedModel read/write libraries.
py_library(
    name = "trt_convert_py",
    srcs = ["trt_convert.py"],
    srcs_version = "PY2AND3",
    deps = [
        # TensorRT integration from the compiler side.
        "//tensorflow/compiler/tf2tensorrt:trt_ops_loader",
        "//tensorflow/compiler/tf2tensorrt:wrap_py_utils",
        # Core Python graph/session utilities.
        "//tensorflow/python:convert_to_constants",
        "//tensorflow/python:func_graph",
        "//tensorflow/python:graph_util",
        "//tensorflow/python:session",
        "//tensorflow/python:tf_optimizer",
        # Eager-mode support.
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:function",
        # SavedModel I/O.
        "//tensorflow/python/saved_model:builder",
        "//tensorflow/python/saved_model:load",
        "//tensorflow/python/saved_model:loader",
        "//tensorflow/python/saved_model:save",
        "//tensorflow/python/saved_model:signature_constants",
        "//tensorflow/python/saved_model:tag_constants",
    ],
)
52
# trt_convert_windows.py. Its only dependency is //tensorflow/python:util.
# NOTE(review): presumably a Windows stand-in for :trt_convert_py (which
# needs TensorRT); confirm against the source file.
py_library(
    name = "trt_convert_windows",
    srcs = ["trt_convert_windows.py"],
    srcs_version = "PY2AND3",
    deps = ["//tensorflow/python:util"],
)
61
# Shared base library for the TF-TRT integration tests below. It builds on
# :trt_convert_py plus the TF test libraries and SavedModel utilities.
# NOTE(review): unlike the other py_library targets in this file, no
# srcs_version is set here; confirm whether "PY2AND3" should apply.
py_library(
    name = "tf_trt_integration_test_base",
    srcs = ["test/tf_trt_integration_test_base.py"],
    deps = [
        ":trt_convert_py",
        # Test harness.
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
        # SavedModel construction and loading.
        "//tensorflow/python/saved_model:builder",
        "//tensorflow/python/saved_model:loader",
        "//tensorflow/python/saved_model:signature_constants",
        "//tensorflow/python/saved_model:signature_def_utils",
        "//tensorflow/python/saved_model:tag_constants",
        "//tensorflow/python/saved_model:utils",
        "//tensorflow/python/tools:saved_model_utils",
    ],
)
78
# Unit test for the converter (PY3 only). It reads a checked-in TF-TRT 2.0
# SavedModel from test/testdata/tftrt_2.0_saved_model.
cuda_py_test(
    name = "trt_convert_test",
    srcs = ["trt_convert_test.py"],
    # Each entry is prefixed with the SavedModel directory; the order matches
    # the file list below.
    data = [
        "test/testdata/tftrt_2.0_saved_model/" + relpath
        for relpath in [
            "saved_model.pb",
            "variables/variables.data-00000-of-00002",
            "variables/variables.data-00001-of-00002",
            "variables/variables.index",
        ]
    ],
    python_version = "PY3",
    tags = [
        "no_cuda_on_cpu_tap",
        "no_pip",
        "no_rocm",
        "no_windows",
        "nomac",
    ],
    deps = [
        ":trt_convert_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:graph_util",
        "//tensorflow/python/saved_model:builder",
        "//tensorflow/python/saved_model:loader",
        "//tensorflow/python/saved_model:signature_constants",
        "//tensorflow/python/saved_model:signature_def_utils",
        "//tensorflow/python/saved_model:tag_constants",
        "//tensorflow/python/saved_model:utils",
        "//tensorflow/python/tools:freeze_graph_lib",
        "//tensorflow/python/tools:saved_model_utils",
        "@absl_py//absl/testing:parameterized",
    ],
)
112
# TF-TRT integration tests. There is one test per source file listed in
# srcs, and all of them share :tf_trt_integration_test_base (PY3 only).
cuda_py_tests(
    name = "tf_trt_integration_test",
    srcs = [
        "test/base_test.py",
        "test/batch_matmul_test.py",
        "test/binary_tensor_weight_broadcast_test.py",
        "test/combined_nms_test.py",
        "test/const_broadcast_test.py",
        "test/conv2d_test.py",
        "test/dynamic_input_shapes_test.py",
        "test/identity_output_test.py",
        "test/int32_test.py",
        "test/lru_cache_test.py",
        "test/memory_alignment_test.py",
        "test/multi_connection_neighbor_engine_test.py",
        "test/neighboring_engine_test.py",
        "test/quantization_test.py",
        "test/rank_two_test.py",
        "test/reshape_transpose_test.py",
        "test/topk_test.py",
        "test/unary_test.py",
        "test/vgg_block_nchw_test.py",
        "test/vgg_block_test.py",
    ],
    python_version = "PY3",
    tags = [
        "no_cuda_on_cpu_tap",
        "no_rocm",
        "no_windows",
        "nomac",
    ],
    deps = [
        ":tf_trt_integration_test_base",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
    ],
)
150
# Bias-add/matmul and concatenation integration tests. They use the same
# deps as tf_trt_integration_test, but the tag set differs: these are
# tagged "notap" (b/140261407).
cuda_py_tests(
    name = "concatenation_test",
    srcs = [
        "test/biasadd_matmul_test.py",
        "test/concatenation_test.py",
    ],
    python_version = "PY3",
    tags = [
        "no_rocm",
        "no_windows",
        "nomac",
        "notap",  # b/140261407
    ],
    deps = [
        ":tf_trt_integration_test_base",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
    ],
)
170
# INT8 quantization test on an MNIST model, restored from the checked-in
# checkpoint under test/testdata/mnist (PY3 only).
cuda_py_test(
    name = "quantization_mnist_test",
    srcs = ["test/quantization_mnist_test.py"],
    data = [
        "test/testdata/mnist/checkpoint",
        "test/testdata/mnist/model.ckpt-46900.data-00000-of-00001",
        "test/testdata/mnist/model.ckpt-46900.index",
    ],
    python_version = "PY3",
    tags = [
        "no_cuda_on_cpu_tap",
        "no_oss",  # TODO(b/125290478): allow running in at least some OSS configurations.
        "no_pip",
        "no_rocm",
        # Kept for anything keyed on this spelling; "notap" below is the
        # spelling used elsewhere in this file (see concatenation_test).
        "no_tap",  # It is not able to download the mnist data.
        "no_windows",
        "nomac",
        "notap",  # It is not able to download the mnist data.
    ],
    deps = [
        ":tf_trt_integration_test_base",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python/estimator",
        "//tensorflow/python/keras",
    ],
)
197