• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Description:
#   Build targets for XRT: its protocol buffers, C++ utility libraries,
#   op definitions, and the generated Python op wrappers.
2
3load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
4load(
5    "//tensorflow:tensorflow.bzl",
6    "tf_custom_op_py_library",
7    "tf_gen_op_wrapper_py",
8)
9load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
10load(
11    "//tensorflow/core/platform:build_config.bzl",
12    "tf_proto_library",
13)
14load(
15    "@local_config_cuda//cuda:build_defs.bzl",
16    "if_cuda",
17)
18
# Rules in this package that do not set `visibility` themselves are visible
# only to //learning/brain and the XRT subpackages.
package(
    default_visibility = [
        "//learning/brain:__subpackages__",
        "//tensorflow/compiler/xrt:__subpackages__",
    ],
    licenses = ["notice"],
)
26
# Protocol buffer messages declared in xrt.proto. C++ code in this package
# depends on the generated library as `:xrt_proto_cc`.
tf_proto_library(
    name = "xrt_proto",
    srcs = ["xrt.proto"],
    cc_api_version = 2,
    # Proto targets that xrt.proto depends on.
    protodeps = [
        "//tensorflow/compiler/tf2xla:host_compute_metadata_proto",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla:xla_proto",
        "//tensorflow/compiler/xla/service:hlo_proto",
    ],
    visibility = ["//visibility:public"],
)
39
# Device helpers declared in xrt_tpu_device.h. Built against the XLA device,
# XLA local client, TPU configuration and TPU node-context libraries.
cc_library(
    name = "xrt_tpu_utils",
    srcs = ["xrt_tpu_device.cc"],
    hdrs = ["xrt_tpu_device.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/jit:xla_device",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/tpu:tpu_configuration",
        "//tensorflow/stream_executor/tpu:tpu_node_context",
    ],
)
58
# Shared XRT C++ support code: compilation cache, device, memory manager,
# metrics, state, ref-counted pointer and general utility headers/sources.
cc_library(
    name = "xrt_utils",
    srcs = [
        "xrt_compilation_cache.cc",
        "xrt_device.cc",
        "xrt_memory_manager.cc",
        "xrt_metrics.cc",
        "xrt_state.cc",
        "xrt_util.cc",
    ],
    hdrs = [
        "xrt_compilation_cache.h",
        "xrt_device.h",
        "xrt_memory_manager.h",
        "xrt_metrics.h",
        "xrt_refptr.h",
        "xrt_state.h",
        "xrt_util.h",
    ],
    # GOOGLE_CUDA is defined only when the build is configured with CUDA.
    copts = if_cuda(["-DGOOGLE_CUDA=1"]),
    visibility = ["//visibility:public"],
    # Order preserved as-is: cc_library deps order can influence link order.
    deps = [
        ":xrt_proto_cc",
        "//tensorflow/compiler/jit:xla_device",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:debug_options_flags",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla:xla_proto_cc",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/service:backend",
        "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:shaped_buffer",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core/common_runtime/gpu:gpu_runtime",
        "//tensorflow/core/platform:regexp",
        "//tensorflow/core/profiler/lib:traceme",
        "//tensorflow/stream_executor",
        "//tensorflow/stream_executor:device_memory_allocator",
        "//tensorflow/stream_executor:tf_allocator_adapter",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:node_hash_map",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/synchronization",
    ],
)
113
# Op-definition libraries. Each entry yields a `<name>_op_lib` target (for
# example `:xrt_compile_ops_op_lib`), which the Python wrapper and the
# `xrt_server` aggregate below depend on.
# Names are kept alphabetical so they line up with those consumer lists.
tf_gen_op_libs(
    op_lib_names = [
        "xrt_compile_ops",
        "xrt_execute_op",
        "xrt_state_ops",
    ],
    deps = [
        "//tensorflow/compiler/jit:flags",
        "//tensorflow/core:lib",
    ],
)
125
# Generates the Python op wrapper module xrt_ops.py from the XRT op
# definitions in the three op libraries listed in deps.
tf_gen_op_wrapper_py(
    name = "xrt_ops_wrapper_py",
    out = "xrt_ops.py",
    deps = [
        ":xrt_compile_ops_op_lib",
        ":xrt_execute_op_op_lib",
        ":xrt_state_ops_op_lib",
    ],
)
135
# Public Python library: the generated op wrappers together with the XRT
# kernel implementations from //tensorflow/compiler/xrt/kernels.
tf_custom_op_py_library(
    name = "xrt_ops",
    kernels = ["//tensorflow/compiler/xrt/kernels:xrt_ops"],
    visibility = ["//visibility:public"],
    deps = [":xrt_ops_wrapper_py"],
)
144
# Aggregate C++ target with no sources of its own: linking it pulls in the
# XRT op libraries and the XRT kernels in one dependency.
cc_library(
    name = "xrt_server",
    visibility = ["//visibility:public"],
    deps = [
        ":xrt_compile_ops_op_lib",
        ":xrt_execute_op_op_lib",
        ":xrt_state_ops_op_lib",
        "//tensorflow/compiler/xrt/kernels:xrt_ops",
    ],
)
155
156# copybara:uncomment_begin(google-only)
157# py_proto_library(
158#     name = "xrt_proto_py_pb2",
159#     api_version = 2,
160#     visibility = ["//visibility:public"],
161#     deps = [":xrt_proto"],
162# )
163# copybara:uncomment_end
164