1 // Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // ==============================================================================
15
16 #include <dlfcn.h>
17
18 #include "absl/strings/str_format.h"
19 #include "absl/time/time.h"
20 #include "tensorflow/compiler/xla/python/tpu_driver/client/libtpu.h"
21 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
22 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
23 #include "tensorflow/compiler/xla/statusor.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26
27 namespace tpu_driver {
28 namespace {
29
30 // Enable the macro by default in the Google internal environment where the
31 // libtpu.so is linked in statically.
32 #ifdef PLATFORM_GOOGLE
33 #define TPU_SHARED_LIBRARY_COMPILE_LINK 1
34 #endif
35
CreateXlaStatus(::TpuStatus * status)36 xla::Status CreateXlaStatus(::TpuStatus* status) {
37 if (status->code == tensorflow::error::OK) {
38 return xla::Status::OK();
39 } else {
40 return xla::Status(tensorflow::error::Code(status->code),
41 absl::StrFormat("%s", status->msg));
42 }
43 }
44
// Prefix of `TpuDriverConfig::worker()` values handled by this driver. The
// part of the worker string that follows this prefix is interpreted by
// RegisterDirectTpuDriver.
constexpr char kDirectProtocol[] = "direct://";
46
GetTpuAllocationShape(const xla::ShapeProto & shape)47 ::TpuAllocationShape GetTpuAllocationShape(const xla::ShapeProto& shape) {
48 ::TpuAllocationShape shape_;
49 shape_.size = shape.ByteSizeLong();
50 shape_.bytes = malloc(shape_.size);
51 if (!shape.SerializeToArray(shape_.bytes, shape_.size)) {
52 LOG(ERROR) << "Unable to serialize shape to array.";
53 free(shape_.bytes);
54 shape_.size = 0;
55 shape_.bytes = nullptr;
56 }
57 return shape_;
58 }
59
60 class DirectTpuDriver;
61
62 class DirectEvent : public Event {
63 public:
DirectEvent(::TpuDriverFn * driver_fn,::TpuEvent * event)64 explicit DirectEvent(::TpuDriverFn* driver_fn, ::TpuEvent* event)
65 : driver_fn_(driver_fn), event_(event) {}
66
~DirectEvent()67 ~DirectEvent() override { driver_fn_->TpuDriver_FreeEvent(event_); }
68
Await()69 xla::Status Await() override {
70 auto tpu_status = driver_fn_->TpuDriver_EventAwait(event_, -1);
71 auto ret = CreateXlaStatus(tpu_status);
72 driver_fn_->TpuDriver_FreeStatus(tpu_status);
73 return ret;
74 }
75
AwaitWithTimeout(absl::Duration duration)76 absl::optional<xla::Status> AwaitWithTimeout(
77 absl::Duration duration) override {
78 auto tpu_status_or = driver_fn_->TpuDriver_EventAwait(
79 event_, absl::ToInt64Microseconds(duration));
80 if (tpu_status_or == nullptr) {
81 return absl::nullopt;
82 } else {
83 auto ret = CreateXlaStatus(tpu_status_or);
84 driver_fn_->TpuDriver_FreeStatus(tpu_status_or);
85 return ret;
86 }
87 }
88
AddCallback(std::function<void (xla::Status)> callback)89 void AddCallback(std::function<void(xla::Status)> callback) override {
90 // We have to create a new copy of the fn on the heap to make it persist.
91 std::function<void(xla::Status)>* callback_addr =
92 new std::function<void(xla::Status)>(callback);
93
94 // Using the callback_addr instead of capturing because C++11 lambdas with
95 // variable captures cannot be converted to C function pointers.
96 driver_fn_->TpuDriver_EventAddCallback(
97 event_,
98 [](struct TpuStatus* status, void* additional_info) {
99 auto callback_addr =
100 static_cast<std::function<void(xla::Status)>*>(additional_info);
101 auto xla_status = CreateXlaStatus(status);
102 (*callback_addr)(xla_status);
103 delete callback_addr;
104 },
105 callback_addr);
106 }
107
108 private:
109 ::TpuDriverFn* driver_fn_;
110 ::TpuEvent* event_;
111
112 friend DirectTpuDriver;
113 };
114
115 class DirectBufferHandle : public BufferHandle {
116 public:
DirectBufferHandle(::TpuDriverFn * driver_fn,::TpuBufferHandle * handle)117 explicit DirectBufferHandle(::TpuDriverFn* driver_fn,
118 ::TpuBufferHandle* handle)
119 : handle_(handle), event_(new DirectEvent(driver_fn, handle->event)) {}
120
OnReady()121 std::shared_ptr<Event> OnReady() override { return event_; }
122
size_in_bytes()123 int64_t size_in_bytes() override { return handle_->size_in_bytes; }
124
shape()125 absl::optional<xla::ShapeProto> shape() override {
126 LOG(FATAL) << "Unimplemented.";
127 return absl::nullopt;
128 }
129
130 private:
131 ::TpuBufferHandle* handle_;
132 std::shared_ptr<DirectEvent> event_;
133
134 friend DirectTpuDriver;
135 };
136
137 class DirectCompiledProgramHandle : public CompiledProgramHandle {
138 public:
DirectCompiledProgramHandle(::TpuDriverFn * driver_fn,::TpuCompiledProgramHandle * handle)139 explicit DirectCompiledProgramHandle(::TpuDriverFn* driver_fn,
140 ::TpuCompiledProgramHandle* handle)
141 : handle_(handle),
142 driver_fn_(driver_fn),
143 event_(new DirectEvent(driver_fn, handle->event)) {}
144
~DirectCompiledProgramHandle()145 ~DirectCompiledProgramHandle() override {
146 driver_fn_->TpuDriver_FreeCompiledProgramHandle(handle_);
147 }
148
OnReady()149 std::shared_ptr<Event> OnReady() override { return event_; }
150
size_in_bytes()151 int64_t size_in_bytes() override {
152 LOG(FATAL) << "Unimplemented.";
153 return 0;
154 }
155
program_shape(xla::ProgramShapeProto * program_shape)156 xla::Status program_shape(xla::ProgramShapeProto* program_shape) override {
157 struct CompiledProgramShape* shape =
158 driver_fn_->TpuDriver_GetCompiledProgramShape(handle_);
159 program_shape->ParseFromArray(shape->bytes, shape->size);
160
161 auto status = CreateXlaStatus(shape->status);
162 driver_fn_->TpuDriver_FreeCompiledProgramShape(shape);
163 return status;
164 }
165
166 private:
167 ::TpuCompiledProgramHandle* handle_;
168 ::TpuDriverFn* driver_fn_;
169 std::shared_ptr<DirectEvent> event_;
170
171 friend DirectTpuDriver;
172 };
173
174 class DirectLoadedProgramHandle : public LoadedProgramHandle {
175 public:
DirectLoadedProgramHandle(::TpuDriverFn * driver_fn,::TpuLoadedProgramHandle * handle)176 explicit DirectLoadedProgramHandle(::TpuDriverFn* driver_fn,
177 ::TpuLoadedProgramHandle* handle)
178 : handle_(handle), event_(new DirectEvent(driver_fn, handle->event)) {}
OnReady()179 std::shared_ptr<Event> OnReady() override { return event_; }
180
size_in_bytes()181 int64_t size_in_bytes() override {
182 LOG(FATAL) << "Unimplemented.";
183 return 0;
184 }
185
186 private:
187 ::TpuLoadedProgramHandle* handle_;
188 std::shared_ptr<DirectEvent> event_;
189
190 friend DirectTpuDriver;
191 };
192
193 class DirectTpuLinearizer : public TpuLinearizer {
194 public:
DirectTpuLinearizer(::TpuDriver * driver,::TpuDriverFn * driver_fn)195 explicit DirectTpuLinearizer(::TpuDriver* driver, ::TpuDriverFn* driver_fn)
196 : driver_(driver), driver_fn_(driver_fn) {}
197
ComputeLinearizedBytesFromShape(const xla::ShapeProto & shape)198 int64_t ComputeLinearizedBytesFromShape(
199 const xla::ShapeProto& shape) override {
200 ::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
201 uint64_t size =
202 driver_fn_->TpuDriver_ComputeLinearizedBytesFromShape(driver_, shape_);
203 free(shape_.bytes);
204 return size;
205 }
206
LinearizeShape(void * dst,const void * src,const xla::ShapeProto & shape)207 xla::Status LinearizeShape(void* dst, const void* src,
208 const xla::ShapeProto& shape) override {
209 ::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
210
211 auto tpu_status =
212 driver_fn_->TpuDriver_LinearizeShape(driver_, dst, src, shape_);
213 auto status = CreateXlaStatus(tpu_status);
214 driver_fn_->TpuDriver_FreeStatus(tpu_status);
215 free(shape_.bytes);
216 return status;
217 }
218
DelinearizeShape(void * dst,const void * src,const xla::ShapeProto & shape)219 xla::Status DelinearizeShape(void* dst, const void* src,
220 const xla::ShapeProto& shape) override {
221 ::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
222
223 auto tpu_status =
224 driver_fn_->TpuDriver_DelinearizeShape(driver_, dst, src, shape_);
225 auto status = CreateXlaStatus(tpu_status);
226 driver_fn_->TpuDriver_FreeStatus(tpu_status);
227 free(shape_.bytes);
228 return status;
229 }
230
231 private:
232 ::TpuDriver* driver_;
233 ::TpuDriverFn* driver_fn_;
234 };
235
236 class DirectTpuDriver : public TpuDriver {
237 public:
DirectTpuDriver(const std::string & so_path)238 explicit DirectTpuDriver(const std::string& so_path) {
239 void* handle;
240 handle = dlopen(so_path.c_str(), RTLD_NOW);
241 if (!handle) {
242 LOG(FATAL) << "Unable to load shared library: " << dlerror();
243 }
244
245 PrototypeTpuDriver_Initialize* initialize_fn;
246 *reinterpret_cast<void**>(&initialize_fn) =
247 dlsym(handle, "TpuDriver_Initialize");
248 initialize_fn(&driver_fn_, /*initialize=*/true);
249
250 driver_ = driver_fn_.TpuDriver_Open("local://");
251 }
252
253 #ifdef TPU_SHARED_LIBRARY_COMPILE_LINK
DirectTpuDriver()254 DirectTpuDriver() {
255 TpuDriver_Initialize(&driver_fn_, /*initialize=*/false);
256 driver_ = driver_fn_.TpuDriver_Open("local://");
257 }
258 #endif
259
~DirectTpuDriver()260 ~DirectTpuDriver() override { driver_fn_.TpuDriver_Close(driver_); }
261
QuerySystemInfo(SystemInfo * system_info)262 void QuerySystemInfo(SystemInfo* system_info) override {
263 ::TpuSystemInfo* info = driver_fn_.TpuDriver_QuerySystemInfo(driver_);
264 system_info->ParseFromArray(info->bytes, info->size);
265 driver_fn_.TpuDriver_FreeSystemInfo(info);
266 }
267
Reset()268 xla::Status Reset() override {
269 auto tpu_status = driver_fn_.TpuDriver_Reset(driver_);
270 auto status = CreateXlaStatus(tpu_status);
271 driver_fn_.TpuDriver_FreeStatus(tpu_status);
272 return status;
273 }
274
Allocate(int32_t core_id,MemoryRegion region,int64_t num_bytes,absl::Span<Event * const> wait_for)275 std::unique_ptr<BufferHandle> Allocate(
276 int32_t core_id, MemoryRegion region, int64_t num_bytes,
277 absl::Span<Event* const> wait_for) override {
278 auto tpu_events = MakeEventArray(wait_for);
279 auto bh = absl::make_unique<DirectBufferHandle>(
280 &driver_fn_,
281 driver_fn_.TpuDriver_Allocate(driver_, core_id, region, num_bytes,
282 wait_for.size(), tpu_events));
283 delete[] tpu_events;
284 return bh;
285 }
286
Allocate(int32_t core_id,MemoryRegion region,const xla::ShapeProto & shape,absl::Span<Event * const> wait_for)287 std::unique_ptr<BufferHandle> Allocate(
288 int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
289 absl::Span<Event* const> wait_for) override {
290 auto tpu_events = MakeEventArray(wait_for);
291
292 ::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
293 auto bh = absl::make_unique<DirectBufferHandle>(
294 &driver_fn_,
295 driver_fn_.TpuDriver_AllocateShape(driver_, core_id, region, shape_,
296 wait_for.size(), tpu_events));
297
298 free(shape_.bytes);
299 delete[] tpu_events;
300 return bh;
301 }
302
AllocateTuple(int32_t core_id,MemoryRegion region,absl::Span<BufferHandle * const> children,absl::Span<Event * const> wait_for)303 std::unique_ptr<BufferHandle> AllocateTuple(
304 int32_t core_id, MemoryRegion region,
305 absl::Span<BufferHandle* const> children,
306 absl::Span<Event* const> wait_for) override {
307 auto tpu_events = MakeEventArray(wait_for);
308
309 ::TpuBufferHandle** childbuf = new ::TpuBufferHandle*[children.size()];
310 for (int i = 0; i < children.size(); i++) {
311 childbuf[i] =
312 static_cast<DirectBufferHandle* const>(children[i])->handle_;
313 }
314
315 auto bh = absl::make_unique<DirectBufferHandle>(
316 &driver_fn_, driver_fn_.TpuDriver_AllocateTuple(
317 driver_, core_id, region, children.size(), childbuf,
318 wait_for.size(), tpu_events));
319 delete[] tpu_events;
320 delete[] childbuf;
321
322 return bh;
323 }
324
Deallocate(std::unique_ptr<BufferHandle> handle,absl::Span<Event * const> wait_for)325 std::shared_ptr<Event> Deallocate(
326 std::unique_ptr<BufferHandle> handle,
327 absl::Span<Event* const> wait_for) override {
328 auto tpu_events = MakeEventArray(wait_for);
329 auto* direct_bh = static_cast<DirectBufferHandle*>(handle.get());
330 auto event = std::make_shared<DirectEvent>(
331 &driver_fn_,
332 driver_fn_.TpuDriver_Deallocate(driver_, direct_bh->handle_,
333 wait_for.size(), tpu_events));
334 delete[] tpu_events;
335 return event;
336 }
337
TransferToDevice(const void * src,BufferHandle * dst,absl::Span<Event * const> wait_for)338 std::shared_ptr<Event> TransferToDevice(
339 const void* src, BufferHandle* dst,
340 absl::Span<Event* const> wait_for) override {
341 auto tpu_events = MakeEventArray(wait_for);
342 auto event = std::make_shared<DirectEvent>(
343 &driver_fn_,
344 driver_fn_.TpuDriver_TransferToDevice(
345 driver_, src, static_cast<DirectBufferHandle*>(dst)->handle_,
346 wait_for.size(), tpu_events));
347 delete[] tpu_events;
348 return event;
349 }
350
TransferFromDevice(const BufferHandle * src,void * dst,absl::Span<Event * const> wait_for)351 std::shared_ptr<Event> TransferFromDevice(
352 const BufferHandle* src, void* dst,
353 absl::Span<Event* const> wait_for) override {
354 auto tpu_events = MakeEventArray(wait_for);
355 auto event = std::make_shared<DirectEvent>(
356 &driver_fn_,
357 driver_fn_.TpuDriver_TransferFromDevice(
358 driver_, static_cast<const DirectBufferHandle*>(src)->handle_, dst,
359 wait_for.size(), tpu_events));
360 delete[] tpu_events;
361 return event;
362 }
363
TransferFromDeviceToDevice(const BufferHandle * src,BufferHandle * dst,absl::Span<Event * const> wait_for)364 std::shared_ptr<Event> TransferFromDeviceToDevice(
365 const BufferHandle* src, BufferHandle* dst,
366 absl::Span<Event* const> wait_for) override {
367 auto tpu_events = MakeEventArray(wait_for);
368 auto event = std::make_shared<DirectEvent>(
369 &driver_fn_,
370 driver_fn_.TpuDriver_TransferFromDeviceToDevice(
371 driver_, static_cast<const DirectBufferHandle*>(src)->handle_,
372 static_cast<DirectBufferHandle*>(dst)->handle_, wait_for.size(),
373 tpu_events));
374 delete[] tpu_events;
375 return event;
376 }
377
CompileProgram(const xla::HloProto & source,int32_t num_replicas,absl::Span<Event * const> wait_for)378 std::unique_ptr<CompiledProgramHandle> CompileProgram(
379 const xla::HloProto& source, int32_t num_replicas,
380 absl::Span<Event* const> wait_for) override {
381 auto tpu_events = MakeEventArray(wait_for);
382
383 struct HloProto hlo;
384 hlo.size = source.ByteSizeLong();
385 hlo.buffer = malloc(hlo.size);
386 if (!source.SerializeToArray(hlo.buffer, hlo.size)) {
387 LOG(ERROR) << "Unable to serialize HLO to array.";
388 return nullptr;
389 }
390
391 auto handle = absl::make_unique<DirectCompiledProgramHandle>(
392 &driver_fn_,
393 driver_fn_.TpuDriver_CompileProgram(driver_, hlo, num_replicas,
394 wait_for.size(), tpu_events));
395
396 free(hlo.buffer);
397 delete[] tpu_events;
398 return handle;
399 }
LoadProgram(int32_t core_id,const CompiledProgramHandle * handle,absl::Span<Event * const> wait_for)400 std::unique_ptr<LoadedProgramHandle> LoadProgram(
401 int32_t core_id, const CompiledProgramHandle* handle,
402 absl::Span<Event* const> wait_for) override {
403 auto tpu_events = MakeEventArray(wait_for);
404
405 auto loaded_handle = absl::make_unique<DirectLoadedProgramHandle>(
406 &driver_fn_,
407 driver_fn_.TpuDriver_LoadProgram(
408 driver_, core_id,
409 static_cast<const DirectCompiledProgramHandle*>(handle)->handle_,
410 wait_for.size(), tpu_events));
411
412 delete[] tpu_events;
413 return loaded_handle;
414 }
415
UnloadProgram(std::unique_ptr<LoadedProgramHandle> handle,absl::Span<Event * const> wait_for)416 std::shared_ptr<Event> UnloadProgram(
417 std::unique_ptr<LoadedProgramHandle> handle,
418 absl::Span<Event* const> wait_for) override {
419 auto tpu_events = MakeEventArray(wait_for);
420 auto* direct_lph = static_cast<DirectLoadedProgramHandle*>(handle.get());
421 auto event = std::make_shared<DirectEvent>(
422 &driver_fn_,
423 driver_fn_.TpuDriver_UnloadProgram(driver_, direct_lph->handle_,
424 wait_for.size(), tpu_events));
425 delete[] tpu_events;
426 return event;
427 }
428
ExecuteProgram(LoadedProgramHandle * program,absl::Span<BufferHandle * const> inputs,absl::Span<BufferHandle * const> outputs,const xla::DeviceAssignmentProto & device_assignment,absl::Span<Event * const> wait_for)429 std::shared_ptr<Event> ExecuteProgram(
430 LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
431 absl::Span<BufferHandle* const> outputs,
432 const xla::DeviceAssignmentProto& device_assignment,
433 absl::Span<Event* const> wait_for) override {
434 auto tpu_events = MakeEventArray(wait_for);
435
436 std::vector<::TpuBufferHandle*> inputv;
437 inputv.reserve(inputs.size());
438 for (int i = 0; i < inputs.size(); i++) {
439 inputv.push_back(
440 static_cast<DirectBufferHandle* const>(inputs[i])->handle_);
441 }
442 std::vector<::TpuBufferHandle*> outputv;
443 outputv.reserve(outputs.size());
444 for (int i = 0; i < outputs.size(); i++) {
445 outputv.push_back(
446 static_cast<DirectBufferHandle* const>(outputs[i])->handle_);
447 }
448
449 struct DeviceAssignment da;
450 da.size = device_assignment.ByteSizeLong();
451 da.bytes = malloc(da.size);
452 device_assignment.SerializeToArray(da.bytes, da.size);
453
454 auto event = std::make_shared<DirectEvent>(
455 &driver_fn_,
456 driver_fn_.TpuDriver_ExecuteProgram(
457 driver_, static_cast<DirectLoadedProgramHandle*>(program)->handle_,
458 inputs.size(), inputv.data(), outputs.size(), outputv.data(), da,
459 wait_for.size(), tpu_events));
460
461 free(da.bytes);
462 delete[] tpu_events;
463 return event;
464 }
465
GetLinearizer()466 std::unique_ptr<TpuLinearizer> GetLinearizer() override {
467 return std::make_unique<DirectTpuLinearizer>(driver_, &driver_fn_);
468 }
469
470 private:
471 ::TpuDriverFn driver_fn_;
472 ::TpuDriver* driver_;
473
MakeEventArray(absl::Span<Event * const> wait_for)474 ::TpuEvent** MakeEventArray(absl::Span<Event* const> wait_for) {
475 if (wait_for.empty()) return nullptr;
476 ::TpuEvent** ret = new ::TpuEvent*[wait_for.size()];
477 for (int i = 0; i < wait_for.size(); i++) {
478 ret[i] = static_cast<DirectEvent* const>(wait_for[i])->event_;
479 }
480 return ret;
481 }
482 };
483
RegisterDirectTpuDriver(const TpuDriverConfig & config)484 xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterDirectTpuDriver(
485 const TpuDriverConfig& config) {
486 std::string shared_lib = config.worker().substr(strlen(kDirectProtocol));
487 if (shared_lib == "internal") {
488 #ifdef TPU_SHARED_LIBRARY_COMPILE_LINK
489 return xla::StatusOr<std::unique_ptr<TpuDriver>>(
490 absl::make_unique<DirectTpuDriver>());
491 #else
492 LOG(FATAL) << "Request to use compile-time linked TPU library, but did not "
493 << "link in appropriate library at compile time.";
494 #endif
495 }
496 return xla::StatusOr<std::unique_ptr<TpuDriver>>(
497 absl::make_unique<DirectTpuDriver>(shared_lib));
498 }
499
// Associates RegisterDirectTpuDriver with kDirectProtocol. The registration
// mechanism is supplied by the REGISTER_TPU_DRIVER macro, which is defined
// outside this file; presumably it makes the factory selectable by worker
// prefix (verify against the macro's definition).
REGISTER_TPU_DRIVER(kDirectProtocol, RegisterDirectTpuDriver);
501
502 } // namespace
503 } // namespace tpu_driver
504