• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Description:
#   Wraps NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) for use with
#   TensorFlow and provides the TensorRT operators and the converter package.
#   These APIs are expected to change over time.
5
6load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
7load(
8    "//tensorflow:tensorflow.bzl",
9    "if_google",
10    "tf_copts",
11    "tf_cuda_library",
12    "tf_custom_op_library_additional_deps",
13    "tf_gen_op_wrapper_py",
14)
15
16# buildifier: disable=same-origin-load
17load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
18
19# buildifier: disable=same-origin-load
20load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
21
22# buildifier: disable=same-origin-load
23load("//tensorflow:tensorflow.bzl", "pybind_extension")
24
25# buildifier: disable=same-origin-load
26load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
27load(
28    "//tensorflow/core/platform:build_config.bzl",
29    "tf_additional_all_protos",
30    "tf_proto_library",
31)
32load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
33
34# Platform specific build config
35load(
36    "//tensorflow/core/platform:build_config_root.bzl",
37    "if_static",
38)
39
# Package defaults: targets are public unless they set their own `visibility`,
# and the layering_check / parse_headers features are turned off for every
# target declared in this file.
package(
    default_visibility = ["//visibility:public"],
    features = [
        "-layering_check",
        "-parse_headers",
    ],
    licenses = ["notice"],
)
48
# Enabled with `--define=use_efficient_nms_plugin=1`. Turns on use of NVIDIA's
# EfficientNMSPlugin (from the TensorRT OSS archive in non-Google builds).
config_setting(
    name = "use_efficient_nms_plugin",
    define_values = {"use_efficient_nms_plugin": "1"},
)

# Alias (not a config setting) that resolves the EfficientNMS plugin target for
# the current build flavor: the //third_party/tensorrt/plugin target in
# Google-internal builds, and the @tensorrt_oss_archive repository otherwise.
# Within this package it is consumed only through :efficient_nms_plugin.
alias(
    name = "efficient_nms_plugin_wrapper",
    actual = if_google(
        "//third_party/tensorrt/plugin:nvinfer_plugin_nms",
        "@tensorrt_oss_archive//:nvinfer_plugin_nms",
    ),
)

# Links the EfficientNMS plugin only when TensorRT is configured; otherwise
# this library is empty. Dependents add it under a select() on
# :use_efficient_nms_plugin.
cc_library(
    name = "efficient_nms_plugin",
    deps = if_tensorrt([
        ":efficient_nms_plugin_wrapper",
    ]),
)
70
# Stub implementations of the TensorRT entry points (stub/*.cc, which textually
# include stub/*.inc). They depend on the stream_executor DSO loader,
# presumably to locate the TensorRT shared libraries at runtime — confirm in
# the stub sources. Sources and deps are empty unless TensorRT is configured.
cc_library(
    name = "tensorrt_stub",
    srcs = if_tensorrt([
        "stub/nvinfer_plugin_stub.cc",
        "stub/nvinfer_stub.cc",
    ]),
    textual_hdrs = glob(["stub/*.inc"]),
    deps = if_tensorrt([
        "//tensorflow/core:lib",
        "//tensorflow/stream_executor/platform:dso_loader",
        "@local_config_tensorrt//:tensorrt_headers",
    ]),
)

# The TensorRT dependency used throughout this package. It resolves to the
# real @local_config_tensorrt library when TensorRT is linked statically, and
# to :tensorrt_stub in every other configuration.
alias(
    name = "tensorrt_lib",
    actual = select({
        "@local_config_tensorrt//:use_static_tensorrt": "@local_config_tensorrt//:tensorrt",
        "//conditions:default": ":tensorrt_stub",
    }),
    visibility = ["//visibility:private"],
)
93
# C++ test in tensorrt_test.cc. When :use_efficient_nms_plugin is active, it is
# compiled with TF_TRT_USE_EFFICIENT_NMS_PLUGIN=1 and also links the plugin.
tf_cuda_cc_test(
    name = "tensorrt_test_cc",
    size = "small",
    srcs = ["tensorrt_test.cc"],
    extra_copts = select({
        ":use_efficient_nms_plugin": ["-DTF_TRT_USE_EFFICIENT_NMS_PLUGIN=1"],
        "//conditions:default": [],
    }),
    tags = [
        "no_cuda_on_cpu_tap",
        "no_windows",
        "nomac",
    ],
    deps = [
        ":trt_logging",
        ":utils",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/common_runtime/gpu:gpu_init",
        "//tensorflow/core/platform:stream_executor",
    ] + if_tensorrt([":tensorrt_lib"]) + select({
        ":use_efficient_nms_plugin": [":efficient_nms_plugin"],
        "//conditions:default": [],
    }),
)
124
# C++ conversion API declared in trt_convert_api.h. It builds on saved-model
# freezing, DirectSession, and the grappler item builder / single-machine
# cluster.
cc_library(
    name = "trt_convert_api",
    srcs = ["trt_convert_api.cc"],
    hdrs = ["trt_convert_api.h"],
    copts = tf_copts(),
    deps = [
        ":trt_parameters",
        ":trt_resources",
        "//tensorflow/cc/tools:freeze_saved_model",
        "//tensorflow/core:direct_session",
        "//tensorflow/core:framework",
        "//tensorflow/core/grappler:grappler_item_builder",
        "//tensorflow/core/grappler/clusters:single_machine",
        "//tensorflow/core/platform:logging",
        "@com_google_absl//absl/strings",
    ] + if_tensorrt([":tensorrt_lib"]),
)

# Raw header of the conversion API, exposed as a filegroup for consumers that
# need the file itself rather than a cc_library.
filegroup(
    name = "headers",
    srcs = ["trt_convert_api.h"],
)
151
# Test for :trt_convert_api. It also links the TRT op kernels and the core
# kernels (array, assign, function, resource-variable ops) that the test graphs
# presumably execute.
tf_cuda_cc_test(
    name = "trt_convert_api_test",
    size = "small",
    srcs = ["trt_convert_api_test.cc"],
    tags = [
        "no_cuda_on_cpu_tap",
        "no_windows",
        "nomac",
    ],
    deps = [
        # Targets from this package.
        ":common_utils",
        ":testutils",
        ":trt_conversion",
        ":trt_convert_api",
        ":trt_logging",
        ":trt_op_kernels",
        ":trt_resources",
        ":utils",
        # C++ graph-building API.
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:resource_variable_ops",
        "//tensorflow/cc:scope",
        # Core runtime, op libraries and test support.
        "//tensorflow/core:array_ops_op_lib",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:direct_session",
        "//tensorflow/core:framework",
        "//tensorflow/core:function_ops_op_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:math_ops_op_lib",
        "//tensorflow/core:no_op_op_lib",
        "//tensorflow/core:ops",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:state_ops_op_lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        # Kernels.
        "//tensorflow/core/kernels:array",
        "//tensorflow/core/kernels:assign_op",
        "//tensorflow/core/kernels:ops_testutil",
        "//tensorflow/core/kernels:partitioned_function_ops",
        "//tensorflow/core/kernels:resource_variable_ops",
    ],
)
196
# Shared helpers under common/ (utils and datavec headers), used by most other
# targets in this package.
cc_library(
    name = "common_utils",
    srcs = ["common/utils.cc"],
    hdrs = [
        "common/datavec.h",
        "common/utils.h",
    ],
    copts = tf_copts(),
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core/platform:logging",
        "//tensorflow/core/profiler/lib:annotated_traceme",
    ] + if_tensorrt([":tensorrt_lib"]),
)
211
# Test-only helpers (utils/trt_testutils.*). The library is private to this
# package and links gtest directly.
cc_library(
    name = "testutils",
    testonly = True,
    srcs = ["utils/trt_testutils.cc"],
    hdrs = ["utils/trt_testutils.h"],
    copts = tf_copts(),
    visibility = ["//visibility:private"],
    deps = [
        ":trt_conversion",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/framework:tensor_testutil",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest",
    ] + if_tensorrt([":tensorrt_lib"]),
)
230
# Tests for the :testutils helpers themselves.
tf_cuda_cc_test(
    name = "testutils_test",
    size = "small",
    srcs = ["utils/trt_testutils_test.cc"],
    tags = [
        "no_cuda_on_cpu_tap",
        "no_windows",
        "nomac",
    ],
    deps = [
        ":testutils",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test_main",
        "//tensorflow/core/platform:protobuf",
    ] + if_tensorrt([":tensorrt_lib"]),
)
249
# Kernels from kernels/trt_engine_op.cc and kernels/get_calibration_data_op.cc.
# With alwayslink, every object file stays in the link, so static
# registrations in these sources survive even when nothing references them.
cc_library(
    name = "trt_op_kernels",
    srcs = [
        "kernels/get_calibration_data_op.cc",
        "kernels/trt_engine_op.cc",
    ],
    copts = tf_copts(),
    visibility = ["//visibility:public"],
    deps = [
        ":common_utils",
        ":trt_allocator",
        ":trt_conversion",
        ":trt_engine_utils",
        ":trt_logging",
        ":trt_plugins",
        ":trt_resources",
        ":utils",
        "//tensorflow/core:framework",
        "//tensorflow/core:gpu_headers_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core:stream_executor_headers_lib",
        "//tensorflow/core/common_runtime:core_cpu_lib_no_ops",
        "//tensorflow/core/grappler/costs:graph_properties",
        "//tensorflow/core/platform:stream_executor",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
    ] + if_tensorrt([
        ":tensorrt_lib",
        "@local_config_cuda//cuda:cuda_headers",
    ]) + tf_custom_op_library_additional_deps(),
    alwayslink = 1,
)
284
# Kernels from kernels/trt_engine_resource_ops.cc. They use the generated
# :trt_engine_instance_proto_cc. Visibility is limited to
# //tensorflow/core subpackages, and alwayslink keeps the registrations in the
# link.
cc_library(
    name = "trt_engine_resource_op_kernels",
    srcs = ["kernels/trt_engine_resource_ops.cc"],
    copts = tf_copts(),
    visibility = ["//tensorflow/core:__subpackages__"],
    deps = [
        ":trt_allocator",
        ":trt_engine_instance_proto_cc",
        ":trt_logging",
        ":trt_plugins",
        ":trt_resources",
        "//tensorflow/core:framework",
        "//tensorflow/core:gpu_headers_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:lib_proto_parsing",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
    ] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(),
    alwayslink = 1,
)
306
# Test for :trt_engine_resource_op_kernels. It is built together with the
# generated :trt_engine_resource_ops_op_lib op definitions.
tf_cuda_cc_test(
    name = "trt_engine_resource_ops_test",
    size = "small",
    srcs = ["kernels/trt_engine_resource_ops_test.cc"],
    tags = [
        "no_cuda_on_cpu_tap",
        "no_windows",
        "nomac",
    ],
    deps = [
        # Targets from this package.
        ":common_utils",
        ":testutils",
        ":trt_engine_instance_proto_cc",
        ":trt_engine_resource_op_kernels",
        ":trt_engine_resource_ops_op_lib",
        ":trt_logging",
        ":trt_resources",
        ":utils",
        # TensorFlow core and test support.
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:ops",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/framework:fake_input",
        "//tensorflow/core/kernels:ops_testutil",
        "//tensorflow/core/kernels:resource_variable_ops",
        # Abseil.
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
    ],
)
341
# Test for the kernels in :trt_op_kernels, linked with the generated op
# libraries in :trt_op_libs. CUDA headers are added only when TensorRT is
# configured.
tf_cuda_cc_test(
    name = "trt_engine_op_test",
    size = "small",
    srcs = ["kernels/trt_engine_op_test.cc"],
    tags = [
        "no_cuda_on_cpu_tap",
        "no_windows",
        "nomac",
    ],
    deps = [
        ":testutils",
        ":trt_conversion",
        ":trt_op_kernels",
        ":trt_op_libs",
        ":trt_resources",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:function_ops",
        "//tensorflow/cc:scope",
        "//tensorflow/core:core_cpu_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/framework:fake_input",
        "//tensorflow/core/kernels:array",
        "//tensorflow/core/kernels:function_ops",
        "//tensorflow/core/kernels:ops_testutil",
        "//third_party/eigen3",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest",
    ] + if_tensorrt(["@local_config_cuda//cuda:cuda_headers"]),
)
380
# Generates one `<name>_op_lib` target per entry below, each holding the op
# registrations for that name.
tf_gen_op_libs(
    op_lib_names = [
        "get_calibration_data_op",
        "trt_engine_op",
        "trt_engine_resource_ops",
    ],
)

# Bundles the TRTEngineOp and calibration-data op libraries with the engine
# utilities. :trt_engine_resource_ops_op_lib is intentionally not included;
# dependents add it separately.
cc_library(
    name = "trt_op_libs",
    deps = [
        ":get_calibration_data_op_op_lib",
        ":trt_engine_op_op_lib",
        ":trt_engine_utils",
    ],
)
397
# Engine helpers: engine utilities, execution-context header, and the
# shape optimization profile implementation.
tf_cuda_library(
    name = "trt_engine_utils",
    srcs = [
        "utils/trt_engine_utils.cc",
        "utils/trt_shape_optimization_profiles.cc",
    ],
    hdrs = [
        "utils/trt_engine_utils.h",
        "utils/trt_execution_context.h",
        "utils/trt_shape_optimization_profiles.h",
    ],
    deps = [
        ":common_utils",
        ":trt_allocator",
        ":trt_logging",
        ":trt_parameters",
        ":utils",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core:stream_executor_headers_lib",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/profiler/lib:annotated_traceme",
        "@com_google_absl//absl/strings",
    ] + if_tensorrt([":tensorrt_lib"]),
)
424
# TensorRT logger implementation (utils/trt_logger.*). It depends on
# :logger_registry.
tf_cuda_library(
    name = "trt_logging",
    srcs = ["utils/trt_logger.cc"],
    hdrs = ["utils/trt_logger.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":common_utils",
        ":logger_registry",
        ":utils",
        "//tensorflow/core:lib_proto_parsing",
        "@com_google_absl//absl/strings",
    ] + if_tensorrt([":tensorrt_lib"]),
)
438
# Generated Python wrappers for every TRT op library in this package.
tf_gen_op_wrapper_py(
    name = "trt_ops",
    deps = [
        ":trt_engine_resource_ops_op_lib",
        ":trt_op_libs",
    ],
)

# Python-side loader. It combines the generated :trt_ops wrappers with the
# :_pywrap_py_utils extension.
tf_custom_op_py_library(
    name = "trt_ops_loader",
    srcs_version = "PY3",
    deps = [
        ":_pywrap_py_utils",
        ":trt_ops",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:platform",
        "//tensorflow/python:resources",
    ],
)
459
# Conversion parameters (convert/trt_parameters.*).
tf_cuda_library(
    name = "trt_parameters",
    srcs = ["convert/trt_parameters.cc"],
    hdrs = ["convert/trt_parameters.h"],
    copts = tf_copts(),
    deps = [
        ":utils",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ] + if_tensorrt([":tensorrt_lib"]),
)
474
# INT8 calibrator and LRU cache, plus re-exported shape-profile and
# tensor-proxy headers.
# NOTE(review): utils/trt_shape_optimization_profiles.h is also listed in
# :trt_engine_utils, and utils/trt_tensor_proxy.h also in :utils. This is
# tolerated only because layering_check is disabled package-wide. Consider
# giving each header a single owning target.
tf_cuda_library(
    name = "trt_resources",
    srcs = [
        "utils/trt_int8_calibrator.cc",
        "utils/trt_lru_cache.cc",
    ],
    hdrs = [
        "utils/trt_int8_calibrator.h",
        "utils/trt_lru_cache.h",
        "utils/trt_shape_optimization_profiles.h",
        "utils/trt_tensor_proxy.h",
    ],
    deps = [
        ":common_utils",
        ":trt_allocator",
        ":trt_engine_utils",
        ":trt_logging",
        ":utils",
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core:framework_lite",
        "//tensorflow/core:gpu_runtime",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core/grappler:op_types",
    ] + if_tensorrt([":tensorrt_lib"]),
)
501
# TensorRT allocator (utils/trt_allocator.*).
tf_cuda_library(
    name = "trt_allocator",
    srcs = ["utils/trt_allocator.cc"],
    hdrs = ["utils/trt_allocator.h"],
    deps = [
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core:framework_lite",
        "//tensorflow/core:lib_proto_parsing",
    ] + if_tensorrt([":tensorrt_lib"]),
)
512
# Test for :trt_allocator. Unlike most tests here, it has no
# no_cuda_on_cpu_tap tag.
tf_cuda_cc_test(
    name = "trt_allocator_test",
    size = "small",
    srcs = ["utils/trt_allocator_test.cc"],
    tags = [
        "no_windows",
        "nomac",
    ],
    deps = [
        ":trt_allocator",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
527
# Test for the LRU cache exported by :trt_resources.
tf_cuda_cc_test(
    name = "trt_lru_cache_test",
    size = "small",
    srcs = ["utils/trt_lru_cache_test.cc"],
    tags = [
        "no_windows",
        "nomac",
    ],
    deps = [
        ":trt_resources",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
542
# Test for the shape optimization profiles. It reaches them through
# :trt_resources, which depends on :trt_engine_utils.
tf_cuda_cc_test(
    name = "trt_shape_optimization_profiles_test",
    size = "small",
    srcs = ["utils/trt_shape_optimization_profiles_test.cc"],
    tags = [
        "no_cuda_on_cpu_tap",
        "no_windows",
        "nomac",
    ],
    deps = [
        ":trt_resources",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
558
# Logger registry (convert/logger_registry.*), consumed by :trt_logging and
# :trt_conversion.
tf_cuda_library(
    name = "logger_registry",
    srcs = ["convert/logger_registry.cc"],
    hdrs = ["convert/logger_registry.h"],
    copts = tf_copts(),
    deps = [
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ] + if_tensorrt([":tensorrt_lib"]),
)
571
# Weight handling for the converter (convert/weights.*).
tf_cuda_library(
    name = "trt_weights",
    srcs = ["convert/weights.cc"],
    hdrs = ["convert/weights.h"],
    copts = tf_copts(),
    deps = [
        ":utils",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
    ] + if_tensorrt([":tensorrt_lib"]),
)
585
# Header-only library for convert/op_converter.h. It has no sources, so the
# srcs attribute is left at its empty default.
tf_cuda_library(
    name = "op_converter",
    hdrs = ["convert/op_converter.h"],
    deps = [
        ":trt_parameters",
        ":trt_weights",
    ] + if_tensorrt([":tensorrt_lib"]),
)
597
# Holds the static state of the converter registry. Do not depend on this
# target directly: use :op_converter_registry and get the registry symbols by
# linking against libtensorflow_framework.so. That shared library depends on
# this target, so users can register custom op converters without pulling
# TensorFlow into their own build system.
tf_cuda_library(
    name = "op_converter_registry_impl",
    srcs = ["convert/op_converter_registry.cc"],
    hdrs = ["convert/op_converter_registry.h"],
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        ":op_converter",
        ":utils",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ] + if_tensorrt([":tensorrt_lib"]),
)
618
# Public entry point for the converter registry (header only). In static
# builds, if_static links :op_converter_registry_impl directly. Otherwise the
# implementation is expected to come from libtensorflow_framework.so.
tf_cuda_library(
    name = "op_converter_registry",
    hdrs = ["convert/op_converter_registry.h"],
    copts = tf_copts(),
    deps = [
        ":op_converter",
        ":utils",
        "//tensorflow/core:lib",
    ] + if_static([":op_converter_registry_impl"]),
)
631
# Test for :op_converter_registry.
tf_cuda_cc_test(
    name = "op_converter_registry_test",
    size = "small",
    srcs = ["convert/op_converter_registry_test.cc"],
    tags = [
        "no_windows",
        "nomac",
    ],
    deps = [
        ":op_converter_registry",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
646
# Algorithm selector (convert/algorithm_selector.*), used by :trt_conversion.
tf_cuda_library(
    name = "algorithm_selector",
    srcs = ["convert/algorithm_selector.cc"],
    hdrs = ["convert/algorithm_selector.h"],
    deps = [
        ":common_utils",
    ] + if_tensorrt([":tensorrt_lib"]),
)
657
# Unit test for :algorithm_selector. It carries the same no_windows / nomac
# platform tags as every other TensorRT test in this package; it was the only
# test without them.
tf_cuda_cc_test(
    name = "algorithm_selector_test",
    srcs = ["convert/algorithm_selector_test.cc"],
    tags = [
        "no_windows",
        "nomac",
    ],
    deps = [
        ":algorithm_selector",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ] + if_tensorrt([":tensorrt_lib"]),
)
669
# Node-level conversion that builds TensorRT ops from TensorFlow graphs: graph
# and node converters, the per-op converters under convert/ops/, the timing
# cache, and the grappler optimization pass. When :use_efficient_nms_plugin is
# set, it compiles with TF_TRT_USE_EFFICIENT_NMS_PLUGIN=1 and links the plugin.
# alwayslink keeps every object file (and any static registrations) in the
# link.
tf_cuda_library(
    name = "trt_conversion",
    srcs = [
        "convert/convert_graph.cc",
        "convert/convert_nodes.cc",
        "convert/ops/binary_ops.cc",
        "convert/ops/data_format_vec_permute.cc",
        "convert/ops/einsum.cc",
        "convert/ops/fill_ops.cc",
        "convert/ops/like_ops.cc",
        "convert/ops/log_softmax.cc",
        "convert/ops/quantization_ops.cc",
        "convert/ops/slice_ops.cc",
        "convert/ops/tile.cc",
        "convert/ops/unary_ops.cc",
        "convert/ops/variable_ops.cc",
        "convert/timing_cache.cc",
        "convert/trt_optimization_pass.cc",
    ],
    hdrs = [
        "convert/convert_graph.h",
        "convert/convert_nodes.h",
        "convert/ops/layer_utils.h",
        "convert/ops/quantization_ops.h",
        "convert/ops/slice_ops.h",
        "convert/timing_cache.h",
        "convert/trt_optimization_pass.h",
    ],
    copts = tf_copts() + select({
        ":use_efficient_nms_plugin": ["-DTF_TRT_USE_EFFICIENT_NMS_PLUGIN=1"],
        "//conditions:default": [],
    }),
    deps = [
        # Targets from this package.
        ":algorithm_selector",
        ":common_utils",
        ":logger_registry",
        ":op_converter",
        ":op_converter_registry",
        ":segment",
        ":trt_allocator",
        ":trt_logging",
        ":trt_parameters",
        ":trt_plugins",
        ":trt_resources",
        ":trt_weights",
        ":utils",
        # TensorFlow C++ API and core.
        "//tensorflow/cc:array_ops",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_lite",
        "//tensorflow/core:gpu_runtime",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/common_runtime:core_cpu",
        # Grappler.
        "//tensorflow/core/grappler:devices",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler:op_types",
        "//tensorflow/core/grappler:utils",
        "//tensorflow/core/grappler/clusters:cluster",
        "//tensorflow/core/grappler/clusters:virtual_cluster",
        "//tensorflow/core/grappler/costs:graph_properties",
        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
        "//tensorflow/core/grappler/optimizers:meta_optimizer",
        "//tensorflow/core/grappler/utils:functions",
        # Miscellaneous.
        "//tensorflow/core/profiler/lib:annotated_traceme",
        "//tensorflow/stream_executor/lib",
        "//tensorflow/tools/graph_transforms:transform_utils",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
    ] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps() + select({
        ":use_efficient_nms_plugin": [":efficient_nms_plugin"],
        "//conditions:default": [],
    }),
    alwayslink = 1,
)
748
# Test for the graph-level conversion in :trt_conversion.
tf_cuda_cc_test(
    name = "convert_graph_test",
    size = "medium",
    srcs = ["convert/convert_graph_test.cc"],
    tags = [
        "no_cuda_on_cpu_tap",
        "no_windows",
        "nomac",
    ],
    deps = [
        ":testutils",
        ":trt_conversion",
        ":trt_op_kernels",
        ":trt_op_libs",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:ops",
        "//tensorflow/cc:scope",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_base",
        "//tensorflow/core:direct_session",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler/clusters:cluster",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest",
    ] + if_tensorrt([":tensorrt_lib"]),
)
780
# Tests for the node converters and the op-converter interface, compiled into
# a single binary.
tf_cuda_cc_test(
    name = "convert_nodes_test",
    size = "medium",
    srcs = [
        "convert/convert_nodes_test.cc",
        "convert/op_converter_test.cc",
    ],
    tags = [
        "no_cuda_on_cpu_tap",
        "no_windows",
        "nomac",
    ],
    deps = [
        ":testutils",
        ":trt_conversion",
        ":trt_engine_utils",
        ":trt_logging",
        ":trt_plugins",
        ":utils",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:cc_ops_internal",
        "//tensorflow/cc:ops",
        "//tensorflow/cc:scope",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_base",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/framework:tensor_testutil",
        "//tensorflow/core/grappler/costs:graph_properties",
        "//tensorflow/core/kernels:function_ops",
        "//tensorflow/core/kernels:identity_op",
        "//tensorflow/core/kernels:resource_variable_ops",
        "//tensorflow/core/platform:status_matchers",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest",
    ] + if_tensorrt([
        ":tensorrt_lib",
        "@local_config_cuda//cuda:cuda_headers",
    ]),
)
825
# Test for the quantize/dequantize converters in
# convert/ops/quantization_ops.cc. It links the TRT op kernels and the core
# NN/pooling kernels.
tf_cuda_cc_test(
    name = "convert_qdq_test",
    size = "medium",
    srcs = ["convert/ops/quantization_ops_test.cc"],
    tags = [
        "no_cuda_on_cpu_tap",
        "no_windows",
        "nomac",
    ],
    deps = [
        ":testutils",
        ":trt_conversion",
        ":trt_convert_api",
        ":trt_engine_utils",
        ":trt_logging",
        ":trt_op_kernels",
        ":trt_plugins",
        ":trt_resources",
        ":utils",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:cc_ops_internal",
        "//tensorflow/cc:ops",
        "//tensorflow/cc:scope",
        "//tensorflow/compiler/jit:shape_inference",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_base",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:ops",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/framework:tensor_testutil",
        "//tensorflow/core/kernels:array",
        "//tensorflow/core/kernels:function_ops",
        "//tensorflow/core/kernels:nn",
        "//tensorflow/core/kernels:ops_testutil",
        "//tensorflow/core/kernels:pooling_ops",
        "//tensorflow/core/platform:status_matchers",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest",
    ] + if_tensorrt([
        ":tensorrt_lib",
        "@local_config_cuda//cuda:cuda_headers",
    ]),
)
875
# Segmenting stage of TensorRT op creation: union-find structure
# (segment/union_find.*), used by :segment.
cc_library(
    name = "union_find",
    srcs = ["segment/union_find.cc"],
    hdrs = ["segment/union_find.h"],
    copts = tf_copts(),
    deps = [
        ":utils",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:optional",
    ],
)
893
# Graph segmentation (segment/segment.*). Built on :union_find and grappler
# graph properties, and consumed by :trt_conversion.
cc_library(
    name = "segment",
    srcs = ["segment/segment.cc"],
    hdrs = ["segment/segment.h"],
    copts = tf_copts(),
    deps = [
        # Targets from this package.
        ":common_utils",
        ":union_find",
        ":utils",
        # TensorFlow core.
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/common_runtime:core_cpu",
        "//tensorflow/core/grappler/costs:graph_properties",
        # External.
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:optional",
        "@com_google_protobuf//:protobuf_headers",
    ],
)
919
# Test for :segment.
tf_cuda_cc_test(
    name = "segment_test",
    size = "small",
    srcs = ["segment/segment_test.cc"],
    tags = [
        "no_cuda_on_cpu_tap",
        "no_windows",
        "nomac",
    ],
    deps = [
        ":segment",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:scope",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:lib",
        "//tensorflow/core:ops",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
941
# TF-TRT plugin support (plugin/trt_plugin.*).
tf_cuda_library(
    name = "trt_plugins",
    srcs = ["plugin/trt_plugin.cc"],
    hdrs = ["plugin/trt_plugin.h"],
    deps = [
        "//tensorflow/core:framework_lite",
        "//tensorflow/core:lib_proto_parsing",
    ] + if_tensorrt([":tensorrt_lib"]),
)
951
# Converter utilities and experimental-feature flags. Also re-exports
# common/utils.h and utils/trt_tensor_proxy.h, which other targets list as
# well.
cc_library(
    name = "utils",
    srcs = [
        "convert/utils.cc",
        "utils/trt_experimental_features.cc",
    ],
    hdrs = [
        "common/utils.h",
        "convert/utils.h",
        "utils/trt_experimental_features.h",
        "utils/trt_tensor_proxy.h",
    ],
    copts = tf_copts(),
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_proto_parsing",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings",
    ] + if_tensorrt([":tensorrt_lib"]),
)
974
# Proto schema in utils/trt_engine_instance.proto. Its generated C++ target,
# :trt_engine_instance_proto_cc, is used by the resource-op kernels and their
# test.
tf_proto_library(
    name = "trt_engine_instance_proto",
    srcs = ["utils/trt_engine_instance.proto"],
    cc_api_version = 2,
    protodeps = tf_additional_all_protos(),
)
981
# C++ helpers behind the Python extension (utils/py_utils.*). With statically
# linked TensorRT, the sources see TF_USE_TENSORRT_STATIC=1. All deps are
# dropped when TensorRT is not configured.
tf_cuda_library(
    name = "py_utils",
    srcs = ["utils/py_utils.cc"],
    hdrs = ["utils/py_utils.h"],
    local_defines = select({
        "@local_config_tensorrt//:use_static_tensorrt": ["TF_USE_TENSORRT_STATIC=1"],
        "//conditions:default": [],
    }),
    deps = if_tensorrt([
        ":common_utils",
        ":op_converter_registry",
        ":tensorrt_lib",
        "//tensorflow/stream_executor/platform:dso_loader",
    ]),
)
997
# pybind11 extension module wrapping :py_utils (utils/py_utils_wrapper.cc),
# linked against the TensorFlow framework library. The commented-out
# static_deps entries are kept for the cc_shared_library migration tracked in
# the TODO below.
pybind_extension(
    name = "_pywrap_py_utils",
    srcs = ["utils/py_utils_wrapper.cc"],
    link_in_framework = True,
    static_deps = [
        # TODO(b/229550590): Uncomment to use cc_shared_library instead of cc_binary.
        # "@bazel_tools//:__subpackages__",
        # "@boringssl//:__subpackages__",
        # "@com_github_googlecloudplatform_tensorflow_gcp_tools//:__subpackages__",
        # "@com_google_absl//:__subpackages__",
        # "@com_google_googleapis//:__subpackages__",
        # "@com_google_protobuf//:__subpackages__",
        # "@com_googlesource_code_re2//:__subpackages__",
        # "@curl//:__subpackages__",
        # "@double_conversion//:__subpackages__",
        # "@eigen_archive//:__subpackages__",
        # "@farmhash_archive//:__subpackages__",
        # "@fft2d//:__subpackages__",
        # "@gif//:__subpackages__",
        # "@highwayhash//:__subpackages__",
        # "@hwloc//:__subpackages__",
        # "@jsoncpp_git//:__subpackages__",
        # "@libjpeg_turbo//:__subpackages__",
        # "@libxsmm_archive//:__subpackages__",
        # "@llvm_openmp//:__subpackages__",
        # "@llvm-project//:__subpackages__",
        # "@llvm_terminfo//:__subpackages__",
        # "@llvm_zlib//:__subpackages__",
        # "@local_config_cuda//:__subpackages__",
        # "@local_config_git//:__subpackages__",
        # "@local_config_python//:__subpackages__",
        # "@local_config_rocm//:__subpackages__",
        # "@local_config_tensorrt//:__subpackages__",
        # "@local_execution_config_platform//:__subpackages__",
        # "@nsync//:__subpackages__",
        # "@platforms//:__subpackages__",
        # "@pybind11//:__subpackages__",
        # "@snappy//:__subpackages__",
        # "//:__subpackages__",
        # "@zlib//:__subpackages__",
    ],
    deps = [
        ":common_utils",
        ":py_utils",
        "//tensorflow/core/platform:env",
        "//tensorflow/core/platform:logging",
        "//tensorflow/core/platform:status",
        "//tensorflow/stream_executor",
        "@pybind11",
    ],
)
1049
1050# copybara:uncomment_begin(google-only)
1051# py_proto_library(
1052#     name = "trt_engine_instance_proto_py_pb2",
1053#     has_services = 0,
1054#     api_version = 2,
1055#     deps = [":trt_engine_instance_proto"],
1056# )
1057# copybara:uncomment_end
1058