/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for compiling XLA computations and managing handles that refer to
// them.

#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
#include "tensorflow/compiler/xrt/xrt_device.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/compiler/xrt/xrt_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace {

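// Populates an xla::DeviceAssignment from the xrt::DeviceAssignment proto,
// checking that it specifies num_cores_per_replica computations with
// num_replicas replica devices each, and that every device is addressed by a
// 4-entry mesh coordinate whose first three entries are 0; the last entry
// becomes the device id for that (replica, computation) pair.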
Status GenerateXlaDeviceAssignment(
    const xrt::DeviceAssignment& xrt_device_assignment, int num_replicas,
    int num_cores_per_replica, xla::DeviceAssignment* device_assignment) {
  if (num_cores_per_replica !=
      xrt_device_assignment.computation_devices_size()) {
    return errors::InvalidArgument(
        "Device assignment does not have the correct number of "
        "computation_devices: num_cores_per_replica=",
        num_cores_per_replica, " computation_devices=",
        xrt_device_assignment.computation_devices_size());
  }
  for (int64_t c = 0; c < xrt_device_assignment.computation_devices_size();
       ++c) {
    const auto& computation_devices =
        xrt_device_assignment.computation_devices(c);
    if (num_replicas != computation_devices.replica_devices_size()) {
      return errors::InvalidArgument(
          "Device assignment does not have the correct number of "
          "replica_device_ids: num_replicas=",
          num_replicas,
          " replica_devices=", computation_devices.replica_devices_size());
    }
    for (int64_t r = 0; r < computation_devices.replica_devices_size(); ++r) {
      const auto& coords = computation_devices.replica_devices(r);
      if (coords.value_size() != 4) {
        return errors::InvalidArgument(
            "Device assignment mesh coordinates must have 4 entries, got ",
            coords.value_size());
      }
      for (int n = 0; n < 3; ++n) {
        if (coords.value(n) != 0) {
          return errors::InvalidArgument("Mesh coordinate at index ", n,
                                         " must be 0, got ", coords.value(n));
        }
      }
      (*device_assignment)(r, c) = coords.value(3);
    }
  }
  return OkStatus();
}

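// Op that compiles an xrt::XLAComputation proto into an xla::LocalExecutable
// held in the process-wide compilation cache. It returns an int64 handle to
// the cache entry and the serialized program shape of the compiled
// computation.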
class XRTCompileOp : public OpKernel {
 public:
  explicit XRTCompileOp(OpKernelConstruction* ctx);
  ~XRTCompileOp() override;
  XRTCompileOp(const XRTCompileOp&) = delete;
  XRTCompileOp& operator=(const XRTCompileOp&) = delete;

  void Compute(OpKernelContext* ctx) override;

 private:
  Status Compile(OpKernelContext* ctx,
                 const xrt::XLAComputation& computation_proto,
                 std::unique_ptr<xla::LocalExecutable>* program);
};

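// Computes the compilation cache key for a computation: the Fingerprint64 of
// its deterministic proto serialization, rendered as a decimal string.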
Status CompilationCacheKey(const xrt::XLAComputation& computation,
                           string* key) {
  const size_t size = computation.ByteSizeLong();
  auto serialized = absl::make_unique<char[]>(size);
  TF_RET_CHECK(
      SerializeToBufferDeterministic(computation, serialized.get(), size));
  uint64 fingerprint = Fingerprint64(absl::string_view(serialized.get(), size));
  *key = absl::StrCat(fingerprint);
  return OkStatus();
}

XRTCompileOp::XRTCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

Status XRTCompileOp::Compile(OpKernelContext* ctx,
                             const xrt::XLAComputation& computation_proto,
                             std::unique_ptr<xla::LocalExecutable>* program) {
  const xrt::XLAComputationConfig& config = computation_proto.config();
  // Sanity checks for options not yet supported.
  int num_cores_per_replica = std::max<int>(config.num_cores_per_replica(), 1);
  TF_RET_CHECK(num_cores_per_replica == 1);
  TF_RET_CHECK(config.per_core_program_shape_size() == 0);

  // The default config value is 0; treat it as 1 for convenience.
  int num_replicas = config.num_replicas() ? config.num_replicas() : 1;

  // We are guaranteed that the underlying device object won't be deleted out
  // from under us, while the ScopedRef is live.
  class XRTGenericDeviceAccessor::ScopedRef device_ref;
  TF_RETURN_IF_ERROR(XRTGenericDeviceAccessor::InitScopedRef(ctx, &device_ref));

  xla::LocalClient* client = device_ref.client();

  // There is officially no way to use XLA in a client/server architecture
  // where client and server are built from different revisions, because the
  // XLA team does not want to give any guarantees about the stability of the
  // HLO proto. For Cloud TPU this is fine because server and client versions
  // can be assumed to be synced to the same version. For general use, the
  // mechanism here (using a snapshot from XlaComputation) works as well as the
  // "official" XLA client/server design, which serializes the same proto
  // between client and server, so in reality it is probably fine.
  TF_ASSIGN_OR_RETURN(xla::XlaComputation computation,
                      client->LoadSnapshot(computation_proto.hlo_snapshot()));

  std::vector<xla::Shape> argument_layouts(
      config.program_shape().parameters_size());
  std::vector<const xla::Shape*> argument_layout_ptrs(
      config.program_shape().parameters_size());
  for (int i = 0; i < config.program_shape().parameters_size(); ++i) {
    argument_layouts[i] = xla::Shape(config.program_shape().parameters(i));
    argument_layout_ptrs[i] = &argument_layouts[i];
  }
  xla::ExecutableBuildOptions build_options;
  build_options.set_device_ordinal(device_ref.device_ordinal());
  build_options.set_num_replicas(num_replicas);
  build_options.set_result_layout(xla::Shape(config.program_shape().result()));
  build_options.set_device_allocator(device_ref.allocator());
  if (config.has_debug_options()) {
    *build_options.mutable_debug_options() =
        BuildXlaDebugOptions(config.debug_options());
  }
  if (config.has_device_assignment()) {
    xla::DeviceAssignment device_assignment(num_replicas,
                                            num_cores_per_replica);
    TF_RETURN_IF_ERROR(
        GenerateXlaDeviceAssignment(config.device_assignment(), num_replicas,
                                    num_cores_per_replica, &device_assignment));
    build_options.set_device_assignment(device_assignment);
  }

  VLOG(1) << "Building executable";
  TF_ASSIGN_OR_RETURN(
      auto executables,
      client->Compile(computation, argument_layout_ptrs, build_options));
  TF_RET_CHECK(executables.size() == 1);
  *program = std::move(executables[0]);
  return OkStatus();
}

void XRTCompileOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "XRTCompileOp::Compute";
  auto timed = monitoring::MakeTimed(xrt_metrics::GetCompileCell());

  ResourceMgr* rm;
  OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm));

  const Tensor& computation_input = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(computation_input.shape()),
              errors::Internal("computation input should be a string scalar"));

  xrt::XLAComputation computation_proto;
  OP_REQUIRES(ctx,
              ParseFromTString(computation_input.scalar<tstring>()(),
                               &computation_proto),
              errors::InvalidArgument(
                  "Unable to parse computation input to XLAComputation"));

  string key;
  OP_REQUIRES_OK(ctx, CompilationCacheKey(computation_proto, &key));

  // Process-wide cache of XLA executables.
  auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
      ctx, /*max_number_of_entries=*/0);
  OP_REQUIRES_OK(ctx, cache_or.status());
  auto cache = std::move(cache_or).value();

  int64_t uid;
  OP_REQUIRES_OK(
      ctx, cache->CompileIfKeyAbsent(
               key, &uid, [&](std::unique_ptr<xla::LocalExecutable>* program) {
                 VLOG(1) << "Compiling XLA executable";
                 return Compile(ctx, computation_proto, program);
               }));
  std::unique_ptr<XRTCompilationCacheEntryRef> entry;
  OP_REQUIRES_OK(ctx, cache->Lookup(uid, &entry));

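  // Output 0 is the int64 handle identifying the cached executable; output 1
  // is the serialized xla::ProgramShapeProto of its entry computation layout.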
  Tensor handle_output(DT_INT64, TensorShape({}));
  handle_output.scalar<int64_t>()() = uid;
  ctx->set_output(0, handle_output);

  xla::LocalExecutable* executable = entry->get().get_executable();
  xla::ProgramShapeProto program_shape = executable->executable()
                                             ->module()
                                             .config()
                                             .entry_computation_layout()
                                             .ComputeProgramShape()
                                             .ToProto();
  Tensor program_shape_output(DT_STRING, TensorShape({1}));
  program_shape_output.vec<tstring>()(0) = program_shape.SerializeAsString();
  ctx->set_output(1, program_shape_output);
}

XRTCompileOp::~XRTCompileOp() = default;

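// Op that releases the compilation cache references identified by the int64
// handles in its input tensor.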
class XRTReleaseCompilationRefOp : public OpKernel {
 public:
  explicit XRTReleaseCompilationRefOp(OpKernelConstruction* ctx);
  ~XRTReleaseCompilationRefOp() override;
  XRTReleaseCompilationRefOp(const XRTReleaseCompilationRefOp&) = delete;
  XRTReleaseCompilationRefOp& operator=(const XRTReleaseCompilationRefOp&) =
      delete;

  void Compute(OpKernelContext* ctx) override;
};

XRTReleaseCompilationRefOp::XRTReleaseCompilationRefOp(
    OpKernelConstruction* ctx)
    : OpKernel(ctx) {}

XRTReleaseCompilationRefOp::~XRTReleaseCompilationRefOp() = default;

void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "XRTReleaseCompilationRefOp::Compute";
  auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseCompilationCell());

  // Process-wide cache of XLA executables.
  auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
      ctx, /*max_number_of_entries=*/0);
  OP_REQUIRES_OK(ctx, cache_or.status());
  auto cache = std::move(cache_or).value();

  const Tensor& keys_tensor = ctx->input(0);
  auto flat_keys = keys_tensor.flat<int64_t>();
  for (int64_t i = 0; i < flat_keys.size(); ++i) {
    int64_t key = flat_keys(i);
    OP_REQUIRES_OK(ctx, cache->Release(key));
    VLOG(2) << "Released computation handle " << key;
  }
}

}  // namespace

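// Register the compile and release ops for the XLA_CPU and XLA_GPU devices,
// keeping the computation and handle tensors in host memory.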
REGISTER_KERNEL_BUILDER(Name("XRTCompile")
                            .Device(DEVICE_XLA_CPU)
                            .HostMemory("computation")
                            .HostMemory("handle"),
                        XRTCompileOp);
REGISTER_KERNEL_BUILDER(Name("XRTCompile")
                            .Device(DEVICE_XLA_GPU)
                            .HostMemory("computation")
                            .HostMemory("handle"),
                        XRTCompileOp);

REGISTER_KERNEL_BUILDER(Name("XRTReleaseCompilationHandle")
                            .Device(DEVICE_XLA_CPU)
                            .HostMemory("handle"),
                        XRTReleaseCompilationRefOp);
REGISTER_KERNEL_BUILDER(Name("XRTReleaseCompilationHandle")
                            .Device(DEVICE_XLA_GPU)
                            .HostMemory("handle"),
                        XRTReleaseCompilationRefOp);

}  // namespace tensorflow