# Description: Operations defined for XRT

load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

# All macros from tensorflow.bzl are loaded in a single statement; loading the
# same file twice is flagged by buildifier (same-origin-load).
load(
    "//tensorflow:tensorflow.bzl",
    "tf_custom_op_py_library",
    "tf_gen_op_libs",
    "tf_gen_op_wrapper_py",
)
load(
    "//tensorflow/core/platform:build_config.bzl",
    "tf_proto_library",
)
load(
    "@local_config_cuda//cuda:build_defs.bzl",
    "if_cuda",
)

package(
    default_visibility = [
        "//learning/brain:__subpackages__",
        "//tensorflow/compiler/xrt:__subpackages__",
    ],
    licenses = ["notice"],
)

# Protocol buffer definitions for XRT (xrt.proto). The proto depends on the
# tf2xla host-compute metadata, XLA data/XLA, and HLO protos. The ":xrt_proto_cc"
# target referenced by xrt_utils below is presumably generated by this macro.
tf_proto_library(
    name = "xrt_proto",
    srcs = ["xrt.proto"],
    cc_api_version = 2,
    protodeps = [
        "//tensorflow/compiler/tf2xla:host_compute_metadata_proto",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla:xla_proto",
        "//tensorflow/compiler/xla/service:hlo_proto",
    ],
    visibility = ["//visibility:public"],
)

# TPU-specific device helpers (xrt_tpu_device.{h,cc}), kept separate from
# xrt_utils so that only TPU users pull in the TPU configuration/node-context deps.
cc_library(
    name = "xrt_tpu_utils",
    srcs = [
        "xrt_tpu_device.cc",
    ],
    hdrs = [
        "xrt_tpu_device.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/jit:xla_device",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/tpu:tpu_configuration",
        "//tensorflow/stream_executor/tpu:tpu_node_context",
    ],
)

# Core XRT runtime utilities: compilation cache, device access, memory manager,
# metrics, state, and miscellaneous helpers (see srcs/hdrs).
cc_library(
    name = "xrt_utils",
    srcs = [
        "xrt_compilation_cache.cc",
        "xrt_device.cc",
        "xrt_memory_manager.cc",
        "xrt_metrics.cc",
        "xrt_state.cc",
        "xrt_util.cc",
    ],
    hdrs = [
        "xrt_compilation_cache.h",
        "xrt_device.h",
        "xrt_memory_manager.h",
        "xrt_metrics.h",
        "xrt_refptr.h",
        "xrt_state.h",
        "xrt_util.h",
    ],
    # Defines GOOGLE_CUDA=1 through if_cuda (presumably only in CUDA-enabled
    # configurations — see @local_config_cuda//cuda:build_defs.bzl).
    copts = if_cuda(["-DGOOGLE_CUDA=1"]),
    visibility = ["//visibility:public"],
    deps = [
        ":xrt_proto_cc",
        "//tensorflow/compiler/jit:xla_device",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:debug_options_flags",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla:xla_proto_cc",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/service:backend",
        "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:shaped_buffer",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core/common_runtime/gpu:gpu_runtime",
        "//tensorflow/core/platform:regexp",
        "//tensorflow/core/profiler/lib:traceme",
        "//tensorflow/stream_executor",
        "//tensorflow/stream_executor:device_memory_allocator",
        "//tensorflow/stream_executor:tf_allocator_adapter",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:node_hash_map",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/synchronization",
    ],
)

# Op-definition libraries. Each entry yields a "<name>_op_lib" target; those
# targets are consumed by xrt_ops_wrapper_py and xrt_server below.
tf_gen_op_libs(
    op_lib_names = [
        "xrt_compile_ops",
        "xrt_state_ops",
        "xrt_execute_op",
    ],
    deps = [
        "//tensorflow/compiler/jit:flags",
        "//tensorflow/core:lib",
    ],
)

# Generated Python wrappers for the XRT ops, emitted as xrt_ops.py.
tf_gen_op_wrapper_py(
    name = "xrt_ops_wrapper_py",
    out = "xrt_ops.py",
    deps = [
        ":xrt_compile_ops_op_lib",
        ":xrt_execute_op_op_lib",
        ":xrt_state_ops_op_lib",
    ],
)

# Python library that bundles the generated wrappers with the XRT kernels.
tf_custom_op_py_library(
    name = "xrt_ops",
    kernels = ["//tensorflow/compiler/xrt/kernels:xrt_ops"],
    visibility = ["//visibility:public"],
    deps = [
        ":xrt_ops_wrapper_py",
    ],
)

# Aggregate target with no sources of its own: depending on it links in the
# XRT op libraries together with their kernels.
cc_library(
    name = "xrt_server",
    visibility = ["//visibility:public"],
    deps = [
        ":xrt_compile_ops_op_lib",
        ":xrt_execute_op_op_lib",
        ":xrt_state_ops_op_lib",
        "//tensorflow/compiler/xrt/kernels:xrt_ops",
    ],
)

# copybara:uncomment_begin(google-only)
# py_proto_library(
#     name = "xrt_proto_py_pb2",
#     api_version = 2,
#     visibility = ["//visibility:public"],
#     deps = [":xrt_proto"],
# )
# copybara:uncomment_end