/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/client.h"

#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
namespace xla {

Client::Client(ServiceInterface* stub) : stub_(stub) {}

Client::~Client() = default;

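// Transfers the data behind the given server-side handle back to the client
// as a Literal. If shape_with_layout is non-null, it specifies the layout the
// returned literal should have.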
StatusOr<Literal> Client::Transfer(const GlobalData& data,
                                   const Shape* shape_with_layout) {
  TransferToClientRequest request;
  *request.mutable_data() = data.handle();
  if (shape_with_layout != nullptr) {
    *request.mutable_shape_with_layout() = shape_with_layout->ToProto();
  }
  TransferToClientResponse response;

  VLOG(1) << "making transfer request";
  VLOG(3) << "TransferToClientRequest: {" << request.DebugString() << "}";
  Status s = stub_->TransferToClient(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  VLOG(3) << "TransferToClientResponse: {" << response.DebugString() << "}";

  if (!response.has_literal()) {
    return FailedPrecondition(
        "server provided response without a literal in "
        "TransferToClient request");
  }
  return Literal::CreateFromProto(*response.mutable_literal());
}

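// Transfers the given literal to the server and returns a handle to the
// resulting server-side data. If device_handle is non-null, the data is
// placed on that device.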
StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
    const LiteralSlice& literal, const DeviceHandle* device_handle) {
  TransferToServerRequest request;
  *request.mutable_literal() = literal.ToProto();
  if (device_handle) {
    *request.mutable_device_handle() = *device_handle;
  }
  TransferToServerResponse response;

  VLOG(1) << "making transfer to server request";
  VLOG(3) << "TransferToServerRequest: {" << request.DebugString() << "}";
  Status s = stub_->TransferToServer(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  VLOG(3) << "TransferToServerResponse: {" << response.DebugString() << "}";

  if (!response.has_data()) {
    return FailedPrecondition(
        "server provided response without a data handle in "
        "TransferToServer request");
  }

  return absl::make_unique<GlobalData>(stub_, response.data());
}

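// Transfers the given literal to the infeed of the device addressed by
// replica_id and, if non-null, device_handle.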
Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id,
                                const DeviceHandle* device_handle) {
  TransferToInfeedRequest request;
  *request.mutable_literal() = literal.ToProto();
  if (device_handle) {
    *request.mutable_device_handle() = *device_handle;
  }
  request.set_replica_id(replica_id);
  TransferToInfeedResponse response;

  VLOG(1) << "making transfer to infeed request";
  VLOG(3) << "TransferToInfeedRequest: {" << request.DebugString() << "}";
  Status s = stub_->TransferToInfeed(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  VLOG(3) << "TransferToInfeedResponse: {" << response.DebugString() << "}";
  return Status::OK();
}

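// Transfers a literal from the outfeed of the device addressed by replica_id
// and, if non-null, device_handle. If shape_with_layout is non-null, it
// specifies the layout the returned literal should have.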
StatusOr<Literal> Client::TransferFromOutfeed(
    const Shape* shape_with_layout, int64 replica_id,
    const DeviceHandle* device_handle) {
  TransferFromOutfeedRequest request;
  if (device_handle) {
    *request.mutable_device_handle() = *device_handle;
  }
  request.set_replica_id(replica_id);
  if (shape_with_layout != nullptr) {
    *request.mutable_shape_with_layout() = shape_with_layout->ToProto();
  }
  TransferFromOutfeedResponse response;

  VLOG(1) << "making transfer from outfeed request";
  VLOG(3) << "TransferFromOutfeedRequest: {" << request.DebugString() << "}";
  Status s = stub_->TransferFromOutfeed(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  VLOG(3) << "TransferFromOutfeedResponse: {" << response.DebugString() << "}";

  if (!response.has_literal()) {
    return FailedPrecondition(
        "server provided response without a literal in "
        "TransferFromOutfeed request");
  }

  return Literal::CreateFromProto(response.literal());
}

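// Issues a ResetDevice request to the service, which clears device state; see
// ServiceInterface::ResetDevice for the precise semantics.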
Status Client::ResetDevice() {
  ResetDeviceRequest request;
  ResetDeviceResponse response;

  VLOG(1) << "making reset device request";
  VLOG(3) << "ResetDeviceRequest: {" << request.DebugString() << "}";
  Status s = stub_->ResetDevice(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  VLOG(3) << "ResetDeviceResponse: {" << response.DebugString() << "}";
  return Status::OK();
}

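// Executes the computation and transfers the result back to the client as a
// Literal, honoring the output layout from execution_options if one is set.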
StatusOr<Literal> Client::ExecuteAndTransfer(
    const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
    const ExecutionOptions* execution_options,
    ExecutionProfile* execution_profile) {
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<GlobalData> data,
      Execute(computation, arguments, execution_options, execution_profile));

  absl::optional<Shape> shape_with_output_layout;
  if (execution_options && execution_options->has_shape_with_output_layout()) {
    shape_with_output_layout =
        Shape(execution_options->shape_with_output_layout());
  }
  return Transfer(*data, shape_with_output_layout.has_value()
                             ? &(*shape_with_output_layout)
                             : nullptr);
}

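// Asks the service to evaluate the computation as a constant (via a
// compute-constant-graph request) and returns the resulting literal, relaid
// out to output_layout if one is given.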
StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
                                          const Layout* output_layout) const {
  ComputeConstantGraphRequest request;
  *request.mutable_computation() = computation.proto();
  if (output_layout != nullptr) {
    *request.mutable_output_layout() = output_layout->ToProto();
  }

  ComputeConstantResponse response;

  VLOG(2) << "making compute-constant-graph request";
  Status s = stub_->ComputeConstantGraph(&request, &response);
  VLOG(2) << "done with request";

  if (!s.ok()) {
    return s;
  }

  VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";

  if (!response.has_literal()) {
    return InternalError(
        "no computed literal in the provided response in ComputeConstantGraph "
        "request");
  }
  return Literal::CreateFromProto(response.literal());
}

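// Reconstructs an XlaComputation from a serialized HloSnapshot.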
StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
  TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module());
  return XlaComputation(module.hlo().hlo_module());
}

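// Compiles the computation for the given argument shapes and returns a handle
// that can be passed to the Execute(handle, ...) overload. Compiling with
// more than one device handle is not supported.
//
// A minimal usage sketch; `client`, `computation`, `arg_shape`, and
// `arg_data` are illustrative placeholders, not names defined in this file:
//
//   TF_ASSIGN_OR_RETURN(ExecutionHandle handle,
//                       client->Compile(computation, {arg_shape},
//                                       /*execution_options=*/nullptr));
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> result,
//                       client->Execute(handle, {arg_data.get()},
//                                       /*execution_profile=*/nullptr));
//   TF_ASSIGN_OR_RETURN(Literal literal,
//                       client->Transfer(*result,
//                                        /*shape_with_layout=*/nullptr));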
StatusOr<ExecutionHandle> Client::Compile(
    const XlaComputation& computation, absl::Span<const Shape> argument_shapes,
    const ExecutionOptions* execution_options) {
  CompileRequest request;
  *request.mutable_computation() = computation.proto();

  if (execution_options == nullptr) {
    *request.mutable_execution_options() = CreateDefaultExecutionOptions();
  } else {
    *request.mutable_execution_options() = *execution_options;
  }
  if (request.execution_options().device_handles_size() > 1) {
    return InvalidArgument(
        "Compiling with multiple device handles is not supported. Use "
        "'Execute' instead.");
  }

  // The argument shapes affect how the computation is compiled.
  for (const auto& arg_shape : argument_shapes) {
    *request.add_input_shape_with_layout() = arg_shape.ToProto();
  }

  CompileResponse response;
  VLOG(1) << "making compile request: " << request.ShortDebugString();
  Status s = stub_->Compile(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  TF_RET_CHECK(response.has_handle());
  return response.handle();
}

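// Executes a previously compiled computation, identified by its handle, with
// the given arguments. If execution_profile is non-null, it is filled in with
// profile data from the execution.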
StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
    const ExecutionHandle& handle, absl::Span<GlobalData* const> arguments,
    ExecutionProfile* execution_profile) {
  ExecuteRequest request;
  *request.mutable_handle() = handle;
  for (GlobalData* argument : arguments) {
    CHECK(argument != nullptr) << "Argument pointers must not be null.";
    *request.add_arguments() = argument->handle();
  }

  ExecuteResponse response;
  VLOG(1) << "making execute request: " << request.ShortDebugString();
  Status s = stub_->Execute(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  if (execution_profile != nullptr) {
    *execution_profile = response.profile();
  }

  return absl::make_unique<GlobalData>(stub_, response.output());
}

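// Compiles and runs the computation in a single round trip. The work is
// routed through ExecuteParallel() rather than Compile() + Execute(); see the
// comment inside for why.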
StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
    const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
    const ExecutionOptions* execution_options,
    ExecutionProfile* execution_profile) {
  // Create an ExecutionOptions if necessary, or set its DeviceHandles.
  absl::optional<ExecutionOptions> options_storage;
  if (!execution_options || execution_options->device_handles().empty()) {
    if (execution_options) {
      options_storage.emplace(*execution_options);
    } else {
      options_storage.emplace(CreateDefaultExecutionOptions());
    }
    execution_options = &*options_storage;

    TF_ASSIGN_OR_RETURN(auto device_handles,
                        GetDeviceHandles(/*device_count=*/1));
    TF_RET_CHECK(!device_handles.empty());
    *options_storage->add_device_handles() = std::move(device_handles[0]);
  }

  std::vector<XlaComputationInstance> computation_instances = {
      XlaComputationInstance{
          computation,
          std::vector<GlobalData*>(arguments.begin(), arguments.end()),
          *execution_options, execution_profile}};

  // Instead of invoking Compile() and Execute(), invoke
  // Service::ExecuteParallel() to execute our one computation. Compile()
  // caches the executable forever, which isn't what we want.
  VLOG(1) << "Making ExecuteParallel request: "
          << execution_options->DebugString();
  TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances));
  VLOG(1) << "ExecuteParallel request done.";

  // The result selection is a bit hacky, but better than assuming it is
  // device 0.
  //
  // TODO(b/118493728): Allow Execute to return one result per computation.
  for (int64 i = 0; i < results.size(); i++) {
    TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i]));
    if (!ShapeUtil::IsEmptyTuple(shape)) {
      VLOG(3) << "Fetching result from device " << i << ": "
              << ShapeUtil::HumanString(shape);
      return std::move(results[i]);
    }
  }
  TF_RET_CHECK(!results.empty());
  VLOG(1) << "Defaulting to device 0 result";
  return std::move(results[0]);
}

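// Executes the given computation instances in parallel and returns one output
// handle per instance. Any non-null execution_profile in an instance is
// filled in from the corresponding response.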
StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
    absl::Span<const XlaComputationInstance> computations) {
  ExecuteGraphParallelRequest request;

  for (const XlaComputationInstance& computation : computations) {
    ExecuteGraphRequest single_request;
    *single_request.mutable_computation() = computation.computation.proto();
    for (GlobalData* argument : computation.arguments) {
      *single_request.add_arguments() = argument->handle();
    }
    *single_request.mutable_execution_options() =
        computation.execution_options;
    *request.add_requests() = single_request;
  }

  ExecuteParallelResponse response;
  VLOG(1) << "making execute-graph-parallel request: "
          << request.ShortDebugString();
  Status s = stub_->ExecuteGraphParallel(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  std::vector<std::unique_ptr<GlobalData>> outputs;
  for (size_t i = 0; i < response.responses_size(); ++i) {
    outputs.push_back(
        absl::make_unique<GlobalData>(stub_, response.responses(i).output()));
    if (i < computations.size() &&
        computations[i].execution_profile != nullptr) {
      *computations[i].execution_profile = response.responses(i).profile();
    }
  }

  return std::move(outputs);
}

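// Requests device_count device handles from the service; device_count must be
// at least 1.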
StatusOr<std::vector<DeviceHandle>> Client::GetDeviceHandles(
    int64 device_count) {
  if (device_count < 1) {
    return InvalidArgument("device_count must be greater than 0");
  }
  GetDeviceHandlesRequest request;
  request.set_device_count(device_count);

  GetDeviceHandlesResponse response;
  VLOG(1) << "making get device request: " << request.ShortDebugString();
  Status s = stub_->GetDeviceHandles(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  std::vector<DeviceHandle> device_handles;
  for (const DeviceHandle& device_handle : response.device_handles()) {
    device_handles.push_back(device_handle);
  }

  return device_handles;
}

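// Asks the service to unregister the server-side data behind the given
// handle, releasing its storage.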
Status Client::Unregister(const GlobalData& data) {
  UnregisterRequest request;
  *request.add_data() = data.handle();
  UnregisterResponse response;

  VLOG(1) << "making unregister request";
  Status s = stub_->Unregister(&request, &response);
  VLOG(1) << "done with request";

  return s;
}

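// Splits a tuple-shaped GlobalData into one GlobalData handle per tuple
// element.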
StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::DeconstructTuple(
    const GlobalData& data) {
  DeconstructTupleRequest request;
  *request.mutable_tuple_handle() = data.handle();
  DeconstructTupleResponse response;

  VLOG(1) << "making DeconstructTuple request";
  Status s = stub_->DeconstructTuple(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  std::vector<std::unique_ptr<GlobalData>> handles;
  for (auto& handle : response.element_handles()) {
    handles.push_back(absl::make_unique<GlobalData>(stub_, handle));
  }
  return std::move(handles);
}

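// Fetches statistics (for example, flop and transcendental-op counts) for the
// given computation graph.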
StatusOr<ComputationStats> Client::GetComputationStats(
    const XlaComputation& computation,
    const DebugOptions& debug_options) const {
  ComputationGraphStatsRequest request;

  // TODO(b/74197823): Find a way to avoid the copy of the hlo proto.
  *request.mutable_computation() = computation.proto();
  *request.mutable_debug_options() = debug_options;
  ComputationStatsResponse response;

  VLOG(1) << "making computation graph stats request";
  Status s = stub_->GetComputationGraphStats(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  CHECK(response.has_stats());
  return response.stats();
}

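// Returns the program shape (parameter and result shapes) of the computation.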
StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
    const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape());
  return absl::make_unique<ProgramShape>(result);
}

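// Returns the shape of the data behind the given server-side handle.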
StatusOr<Shape> Client::GetShape(const GlobalData& data) {
  GetShapeRequest request;
  *request.mutable_data() = data.handle();
  GetShapeResponse response;

  VLOG(1) << "making get shape request";
  Status s = stub_->GetShape(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  return Shape(response.shape());
}

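// Formats computation statistics and an execution profile as a human-readable
// string. Since the total flop count is divided by a time in nanoseconds, the
// quotient is numerically in gigaflops per second.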
StatusOr<string> Client::ExecutionStatsAsString(
    const XlaComputation& computation, const ExecutionProfile& profile) {
  TF_ASSIGN_OR_RETURN(
      auto computation_stats,
      GetComputationStats(computation, GetDebugOptionsFromFlags()));
  int64 total_flops =
      computation_stats.flop_count() + computation_stats.transcendental_count();
  if (profile.compute_time_ns() > 0) {
    int64 nanoseconds = profile.compute_time_ns();
    int64 cycle_count = profile.compute_cycle_count();
    // Cast to double to avoid truncating integer division; flops per
    // nanosecond equals gflop/s.
    double gflops = static_cast<double>(total_flops) / nanoseconds;
    return absl::StrCat(
        "[Execution Statistics] flop count: ", computation_stats.flop_count(),
        ", transcendental count: ", computation_stats.transcendental_count(),
        ", compute execution time: ", nanoseconds, " nsec",
        ", compute cycles: ", cycle_count, ", performance: ", gflops,
        " gflop/s");
  }
  return string("[Execution Statistics] not available.");
}

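// Creates a channel handle of the requested type. The three wrappers below
// cover the device-to-device, host-to-device, and device-to-host cases.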
StatusOr<ChannelHandle> Client::CreateChannelHandleByType(
    ChannelHandle::ChannelType type) {
  CreateChannelHandleRequest request;
  request.set_channel_type(type);
  CreateChannelHandleResponse response;

  VLOG(1) << "making create channel handle request";
  Status s = stub_->CreateChannelHandle(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  return response.channel();
}

StatusOr<ChannelHandle> Client::CreateChannelHandle() {
  return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_DEVICE);
}

StatusOr<ChannelHandle> Client::CreateHostToDeviceChannelHandle() {
  return CreateChannelHandleByType(ChannelHandle::HOST_TO_DEVICE);
}

StatusOr<ChannelHandle> Client::CreateDeviceToHostChannelHandle() {
  return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_HOST);
}

}  // namespace xla