• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // ==============================================================================
15 
#include <dlfcn.h>

#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "absl/time/time.h"
#include "tensorflow/compiler/xla/python/tpu_driver/client/libtpu.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
26 
27 namespace tpu_driver {
28 namespace {
29 
// In Google's internal build libtpu.so is linked statically into the binary,
// so the compile-time-link code paths (guarded by
// TPU_SHARED_LIBRARY_COMPILE_LINK) are switched on there by default.
#if defined(PLATFORM_GOOGLE)
#define TPU_SHARED_LIBRARY_COMPILE_LINK 1
#endif  // defined(PLATFORM_GOOGLE)
35 
CreateXlaStatus(::TpuStatus * status)36 xla::Status CreateXlaStatus(::TpuStatus* status) {
37   if (status->code == tensorflow::error::OK) {
38     return xla::Status::OK();
39   } else {
40     return xla::Status(tensorflow::error::Code(status->code),
41                        absl::StrFormat("%s", status->msg));
42   }
43 }
44 
// Address scheme handled by this driver. The rest of a "direct://" worker
// address is either "internal" or the path of the libtpu shared library.
constexpr char kDirectProtocol[]{"direct://"};
46 
GetTpuAllocationShape(const xla::ShapeProto & shape)47 ::TpuAllocationShape GetTpuAllocationShape(const xla::ShapeProto& shape) {
48   ::TpuAllocationShape shape_;
49   shape_.size = shape.ByteSizeLong();
50   shape_.bytes = malloc(shape_.size);
51   if (!shape.SerializeToArray(shape_.bytes, shape_.size)) {
52     LOG(ERROR) << "Unable to serialize shape to array.";
53     free(shape_.bytes);
54     shape_.size = 0;
55     shape_.bytes = nullptr;
56   }
57   return shape_;
58 }
59 
60 class DirectTpuDriver;
61 
62 class DirectEvent : public Event {
63  public:
DirectEvent(::TpuDriverFn * driver_fn,::TpuEvent * event)64   explicit DirectEvent(::TpuDriverFn* driver_fn, ::TpuEvent* event)
65       : driver_fn_(driver_fn), event_(event) {}
66 
~DirectEvent()67   ~DirectEvent() override { driver_fn_->TpuDriver_FreeEvent(event_); }
68 
Await()69   xla::Status Await() override {
70     auto tpu_status = driver_fn_->TpuDriver_EventAwait(event_, -1);
71     auto ret = CreateXlaStatus(tpu_status);
72     driver_fn_->TpuDriver_FreeStatus(tpu_status);
73     return ret;
74   }
75 
AwaitWithTimeout(absl::Duration duration)76   absl::optional<xla::Status> AwaitWithTimeout(
77       absl::Duration duration) override {
78     auto tpu_status_or = driver_fn_->TpuDriver_EventAwait(
79         event_, absl::ToInt64Microseconds(duration));
80     if (tpu_status_or == nullptr) {
81       return absl::nullopt;
82     } else {
83       auto ret = CreateXlaStatus(tpu_status_or);
84       driver_fn_->TpuDriver_FreeStatus(tpu_status_or);
85       return ret;
86     }
87   }
88 
AddCallback(std::function<void (xla::Status)> callback)89   void AddCallback(std::function<void(xla::Status)> callback) override {
90     // We have to create a new copy of the fn on the heap to make it persist.
91     std::function<void(xla::Status)>* callback_addr =
92         new std::function<void(xla::Status)>(callback);
93 
94     // Using the callback_addr instead of capturing because C++11 lambdas with
95     // variable captures cannot be converted to C function pointers.
96     driver_fn_->TpuDriver_EventAddCallback(
97         event_,
98         [](struct TpuStatus* status, void* additional_info) {
99           auto callback_addr =
100               static_cast<std::function<void(xla::Status)>*>(additional_info);
101           auto xla_status = CreateXlaStatus(status);
102           (*callback_addr)(xla_status);
103           delete callback_addr;
104         },
105         callback_addr);
106   }
107 
108  private:
109   ::TpuDriverFn* driver_fn_;
110   ::TpuEvent* event_;
111 
112   friend DirectTpuDriver;
113 };
114 
115 class DirectBufferHandle : public BufferHandle {
116  public:
DirectBufferHandle(::TpuDriverFn * driver_fn,::TpuBufferHandle * handle)117   explicit DirectBufferHandle(::TpuDriverFn* driver_fn,
118                               ::TpuBufferHandle* handle)
119       : handle_(handle), event_(new DirectEvent(driver_fn, handle->event)) {}
120 
OnReady()121   std::shared_ptr<Event> OnReady() override { return event_; }
122 
size_in_bytes()123   int64_t size_in_bytes() override { return handle_->size_in_bytes; }
124 
shape()125   absl::optional<xla::ShapeProto> shape() override {
126     LOG(FATAL) << "Unimplemented.";
127     return absl::nullopt;
128   }
129 
130  private:
131   ::TpuBufferHandle* handle_;
132   std::shared_ptr<DirectEvent> event_;
133 
134   friend DirectTpuDriver;
135 };
136 
137 class DirectCompiledProgramHandle : public CompiledProgramHandle {
138  public:
DirectCompiledProgramHandle(::TpuDriverFn * driver_fn,::TpuCompiledProgramHandle * handle)139   explicit DirectCompiledProgramHandle(::TpuDriverFn* driver_fn,
140                                        ::TpuCompiledProgramHandle* handle)
141       : handle_(handle),
142         driver_fn_(driver_fn),
143         event_(new DirectEvent(driver_fn, handle->event)) {}
144 
~DirectCompiledProgramHandle()145   ~DirectCompiledProgramHandle() override {
146     driver_fn_->TpuDriver_FreeCompiledProgramHandle(handle_);
147   }
148 
OnReady()149   std::shared_ptr<Event> OnReady() override { return event_; }
150 
size_in_bytes()151   int64_t size_in_bytes() override {
152     LOG(FATAL) << "Unimplemented.";
153     return 0;
154   }
155 
program_shape(xla::ProgramShapeProto * program_shape)156   xla::Status program_shape(xla::ProgramShapeProto* program_shape) override {
157     struct CompiledProgramShape* shape =
158         driver_fn_->TpuDriver_GetCompiledProgramShape(handle_);
159     program_shape->ParseFromArray(shape->bytes, shape->size);
160 
161     auto status = CreateXlaStatus(shape->status);
162     driver_fn_->TpuDriver_FreeCompiledProgramShape(shape);
163     return status;
164   }
165 
166  private:
167   ::TpuCompiledProgramHandle* handle_;
168   ::TpuDriverFn* driver_fn_;
169   std::shared_ptr<DirectEvent> event_;
170 
171   friend DirectTpuDriver;
172 };
173 
174 class DirectLoadedProgramHandle : public LoadedProgramHandle {
175  public:
DirectLoadedProgramHandle(::TpuDriverFn * driver_fn,::TpuLoadedProgramHandle * handle)176   explicit DirectLoadedProgramHandle(::TpuDriverFn* driver_fn,
177                                      ::TpuLoadedProgramHandle* handle)
178       : handle_(handle), event_(new DirectEvent(driver_fn, handle->event)) {}
OnReady()179   std::shared_ptr<Event> OnReady() override { return event_; }
180 
size_in_bytes()181   int64_t size_in_bytes() override {
182     LOG(FATAL) << "Unimplemented.";
183     return 0;
184   }
185 
186  private:
187   ::TpuLoadedProgramHandle* handle_;
188   std::shared_ptr<DirectEvent> event_;
189 
190   friend DirectTpuDriver;
191 };
192 
193 class DirectTpuLinearizer : public TpuLinearizer {
194  public:
DirectTpuLinearizer(::TpuDriver * driver,::TpuDriverFn * driver_fn)195   explicit DirectTpuLinearizer(::TpuDriver* driver, ::TpuDriverFn* driver_fn)
196       : driver_(driver), driver_fn_(driver_fn) {}
197 
ComputeLinearizedBytesFromShape(const xla::ShapeProto & shape)198   int64_t ComputeLinearizedBytesFromShape(
199       const xla::ShapeProto& shape) override {
200     ::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
201     uint64_t size =
202         driver_fn_->TpuDriver_ComputeLinearizedBytesFromShape(driver_, shape_);
203     free(shape_.bytes);
204     return size;
205   }
206 
LinearizeShape(void * dst,const void * src,const xla::ShapeProto & shape)207   xla::Status LinearizeShape(void* dst, const void* src,
208                              const xla::ShapeProto& shape) override {
209     ::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
210 
211     auto tpu_status =
212         driver_fn_->TpuDriver_LinearizeShape(driver_, dst, src, shape_);
213     auto status = CreateXlaStatus(tpu_status);
214     driver_fn_->TpuDriver_FreeStatus(tpu_status);
215     free(shape_.bytes);
216     return status;
217   }
218 
DelinearizeShape(void * dst,const void * src,const xla::ShapeProto & shape)219   xla::Status DelinearizeShape(void* dst, const void* src,
220                                const xla::ShapeProto& shape) override {
221     ::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
222 
223     auto tpu_status =
224         driver_fn_->TpuDriver_DelinearizeShape(driver_, dst, src, shape_);
225     auto status = CreateXlaStatus(tpu_status);
226     driver_fn_->TpuDriver_FreeStatus(tpu_status);
227     free(shape_.bytes);
228     return status;
229   }
230 
231  private:
232   ::TpuDriver* driver_;
233   ::TpuDriverFn* driver_fn_;
234 };
235 
236 class DirectTpuDriver : public TpuDriver {
237  public:
DirectTpuDriver(const std::string & so_path)238   explicit DirectTpuDriver(const std::string& so_path) {
239     void* handle;
240     handle = dlopen(so_path.c_str(), RTLD_NOW);
241     if (!handle) {
242       LOG(FATAL) << "Unable to load shared library: " << dlerror();
243     }
244 
245     PrototypeTpuDriver_Initialize* initialize_fn;
246     *reinterpret_cast<void**>(&initialize_fn) =
247         dlsym(handle, "TpuDriver_Initialize");
248     initialize_fn(&driver_fn_, /*initialize=*/true);
249 
250     driver_ = driver_fn_.TpuDriver_Open("local://");
251   }
252 
253 #ifdef TPU_SHARED_LIBRARY_COMPILE_LINK
DirectTpuDriver()254   DirectTpuDriver() {
255     TpuDriver_Initialize(&driver_fn_, /*initialize=*/false);
256     driver_ = driver_fn_.TpuDriver_Open("local://");
257   }
258 #endif
259 
~DirectTpuDriver()260   ~DirectTpuDriver() override { driver_fn_.TpuDriver_Close(driver_); }
261 
QuerySystemInfo(SystemInfo * system_info)262   void QuerySystemInfo(SystemInfo* system_info) override {
263     ::TpuSystemInfo* info = driver_fn_.TpuDriver_QuerySystemInfo(driver_);
264     system_info->ParseFromArray(info->bytes, info->size);
265     driver_fn_.TpuDriver_FreeSystemInfo(info);
266   }
267 
Reset()268   xla::Status Reset() override {
269     auto tpu_status = driver_fn_.TpuDriver_Reset(driver_);
270     auto status = CreateXlaStatus(tpu_status);
271     driver_fn_.TpuDriver_FreeStatus(tpu_status);
272     return status;
273   }
274 
Allocate(int32_t core_id,MemoryRegion region,int64_t num_bytes,absl::Span<Event * const> wait_for)275   std::unique_ptr<BufferHandle> Allocate(
276       int32_t core_id, MemoryRegion region, int64_t num_bytes,
277       absl::Span<Event* const> wait_for) override {
278     auto tpu_events = MakeEventArray(wait_for);
279     auto bh = absl::make_unique<DirectBufferHandle>(
280         &driver_fn_,
281         driver_fn_.TpuDriver_Allocate(driver_, core_id, region, num_bytes,
282                                       wait_for.size(), tpu_events));
283     delete[] tpu_events;
284     return bh;
285   }
286 
Allocate(int32_t core_id,MemoryRegion region,const xla::ShapeProto & shape,absl::Span<Event * const> wait_for)287   std::unique_ptr<BufferHandle> Allocate(
288       int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
289       absl::Span<Event* const> wait_for) override {
290     auto tpu_events = MakeEventArray(wait_for);
291 
292     ::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
293     auto bh = absl::make_unique<DirectBufferHandle>(
294         &driver_fn_,
295         driver_fn_.TpuDriver_AllocateShape(driver_, core_id, region, shape_,
296                                            wait_for.size(), tpu_events));
297 
298     free(shape_.bytes);
299     delete[] tpu_events;
300     return bh;
301   }
302 
AllocateTuple(int32_t core_id,MemoryRegion region,absl::Span<BufferHandle * const> children,absl::Span<Event * const> wait_for)303   std::unique_ptr<BufferHandle> AllocateTuple(
304       int32_t core_id, MemoryRegion region,
305       absl::Span<BufferHandle* const> children,
306       absl::Span<Event* const> wait_for) override {
307     auto tpu_events = MakeEventArray(wait_for);
308 
309     ::TpuBufferHandle** childbuf = new ::TpuBufferHandle*[children.size()];
310     for (int i = 0; i < children.size(); i++) {
311       childbuf[i] =
312           static_cast<DirectBufferHandle* const>(children[i])->handle_;
313     }
314 
315     auto bh = absl::make_unique<DirectBufferHandle>(
316         &driver_fn_, driver_fn_.TpuDriver_AllocateTuple(
317                          driver_, core_id, region, children.size(), childbuf,
318                          wait_for.size(), tpu_events));
319     delete[] tpu_events;
320     delete[] childbuf;
321 
322     return bh;
323   }
324 
Deallocate(std::unique_ptr<BufferHandle> handle,absl::Span<Event * const> wait_for)325   std::shared_ptr<Event> Deallocate(
326       std::unique_ptr<BufferHandle> handle,
327       absl::Span<Event* const> wait_for) override {
328     auto tpu_events = MakeEventArray(wait_for);
329     auto* direct_bh = static_cast<DirectBufferHandle*>(handle.get());
330     auto event = std::make_shared<DirectEvent>(
331         &driver_fn_,
332         driver_fn_.TpuDriver_Deallocate(driver_, direct_bh->handle_,
333                                         wait_for.size(), tpu_events));
334     delete[] tpu_events;
335     return event;
336   }
337 
TransferToDevice(const void * src,BufferHandle * dst,absl::Span<Event * const> wait_for)338   std::shared_ptr<Event> TransferToDevice(
339       const void* src, BufferHandle* dst,
340       absl::Span<Event* const> wait_for) override {
341     auto tpu_events = MakeEventArray(wait_for);
342     auto event = std::make_shared<DirectEvent>(
343         &driver_fn_,
344         driver_fn_.TpuDriver_TransferToDevice(
345             driver_, src, static_cast<DirectBufferHandle*>(dst)->handle_,
346             wait_for.size(), tpu_events));
347     delete[] tpu_events;
348     return event;
349   }
350 
TransferFromDevice(const BufferHandle * src,void * dst,absl::Span<Event * const> wait_for)351   std::shared_ptr<Event> TransferFromDevice(
352       const BufferHandle* src, void* dst,
353       absl::Span<Event* const> wait_for) override {
354     auto tpu_events = MakeEventArray(wait_for);
355     auto event = std::make_shared<DirectEvent>(
356         &driver_fn_,
357         driver_fn_.TpuDriver_TransferFromDevice(
358             driver_, static_cast<const DirectBufferHandle*>(src)->handle_, dst,
359             wait_for.size(), tpu_events));
360     delete[] tpu_events;
361     return event;
362   }
363 
TransferFromDeviceToDevice(const BufferHandle * src,BufferHandle * dst,absl::Span<Event * const> wait_for)364   std::shared_ptr<Event> TransferFromDeviceToDevice(
365       const BufferHandle* src, BufferHandle* dst,
366       absl::Span<Event* const> wait_for) override {
367     auto tpu_events = MakeEventArray(wait_for);
368     auto event = std::make_shared<DirectEvent>(
369         &driver_fn_,
370         driver_fn_.TpuDriver_TransferFromDeviceToDevice(
371             driver_, static_cast<const DirectBufferHandle*>(src)->handle_,
372             static_cast<DirectBufferHandle*>(dst)->handle_, wait_for.size(),
373             tpu_events));
374     delete[] tpu_events;
375     return event;
376   }
377 
CompileProgram(const xla::HloProto & source,int32_t num_replicas,absl::Span<Event * const> wait_for)378   std::unique_ptr<CompiledProgramHandle> CompileProgram(
379       const xla::HloProto& source, int32_t num_replicas,
380       absl::Span<Event* const> wait_for) override {
381     auto tpu_events = MakeEventArray(wait_for);
382 
383     struct HloProto hlo;
384     hlo.size = source.ByteSizeLong();
385     hlo.buffer = malloc(hlo.size);
386     if (!source.SerializeToArray(hlo.buffer, hlo.size)) {
387       LOG(ERROR) << "Unable to serialize HLO to array.";
388       return nullptr;
389     }
390 
391     auto handle = absl::make_unique<DirectCompiledProgramHandle>(
392         &driver_fn_,
393         driver_fn_.TpuDriver_CompileProgram(driver_, hlo, num_replicas,
394                                             wait_for.size(), tpu_events));
395 
396     free(hlo.buffer);
397     delete[] tpu_events;
398     return handle;
399   }
LoadProgram(int32_t core_id,const CompiledProgramHandle * handle,absl::Span<Event * const> wait_for)400   std::unique_ptr<LoadedProgramHandle> LoadProgram(
401       int32_t core_id, const CompiledProgramHandle* handle,
402       absl::Span<Event* const> wait_for) override {
403     auto tpu_events = MakeEventArray(wait_for);
404 
405     auto loaded_handle = absl::make_unique<DirectLoadedProgramHandle>(
406         &driver_fn_,
407         driver_fn_.TpuDriver_LoadProgram(
408             driver_, core_id,
409             static_cast<const DirectCompiledProgramHandle*>(handle)->handle_,
410             wait_for.size(), tpu_events));
411 
412     delete[] tpu_events;
413     return loaded_handle;
414   }
415 
UnloadProgram(std::unique_ptr<LoadedProgramHandle> handle,absl::Span<Event * const> wait_for)416   std::shared_ptr<Event> UnloadProgram(
417       std::unique_ptr<LoadedProgramHandle> handle,
418       absl::Span<Event* const> wait_for) override {
419     auto tpu_events = MakeEventArray(wait_for);
420     auto* direct_lph = static_cast<DirectLoadedProgramHandle*>(handle.get());
421     auto event = std::make_shared<DirectEvent>(
422         &driver_fn_,
423         driver_fn_.TpuDriver_UnloadProgram(driver_, direct_lph->handle_,
424                                            wait_for.size(), tpu_events));
425     delete[] tpu_events;
426     return event;
427   }
428 
ExecuteProgram(LoadedProgramHandle * program,absl::Span<BufferHandle * const> inputs,absl::Span<BufferHandle * const> outputs,const xla::DeviceAssignmentProto & device_assignment,absl::Span<Event * const> wait_for)429   std::shared_ptr<Event> ExecuteProgram(
430       LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
431       absl::Span<BufferHandle* const> outputs,
432       const xla::DeviceAssignmentProto& device_assignment,
433       absl::Span<Event* const> wait_for) override {
434     auto tpu_events = MakeEventArray(wait_for);
435 
436     std::vector<::TpuBufferHandle*> inputv;
437     inputv.reserve(inputs.size());
438     for (int i = 0; i < inputs.size(); i++) {
439       inputv.push_back(
440           static_cast<DirectBufferHandle* const>(inputs[i])->handle_);
441     }
442     std::vector<::TpuBufferHandle*> outputv;
443     outputv.reserve(outputs.size());
444     for (int i = 0; i < outputs.size(); i++) {
445       outputv.push_back(
446           static_cast<DirectBufferHandle* const>(outputs[i])->handle_);
447     }
448 
449     struct DeviceAssignment da;
450     da.size = device_assignment.ByteSizeLong();
451     da.bytes = malloc(da.size);
452     device_assignment.SerializeToArray(da.bytes, da.size);
453 
454     auto event = std::make_shared<DirectEvent>(
455         &driver_fn_,
456         driver_fn_.TpuDriver_ExecuteProgram(
457             driver_, static_cast<DirectLoadedProgramHandle*>(program)->handle_,
458             inputs.size(), inputv.data(), outputs.size(), outputv.data(), da,
459             wait_for.size(), tpu_events));
460 
461     free(da.bytes);
462     delete[] tpu_events;
463     return event;
464   }
465 
GetLinearizer()466   std::unique_ptr<TpuLinearizer> GetLinearizer() override {
467     return std::make_unique<DirectTpuLinearizer>(driver_, &driver_fn_);
468   }
469 
470  private:
471   ::TpuDriverFn driver_fn_;
472   ::TpuDriver* driver_;
473 
MakeEventArray(absl::Span<Event * const> wait_for)474   ::TpuEvent** MakeEventArray(absl::Span<Event* const> wait_for) {
475     if (wait_for.empty()) return nullptr;
476     ::TpuEvent** ret = new ::TpuEvent*[wait_for.size()];
477     for (int i = 0; i < wait_for.size(); i++) {
478       ret[i] = static_cast<DirectEvent* const>(wait_for[i])->event_;
479     }
480     return ret;
481   }
482 };
483 
RegisterDirectTpuDriver(const TpuDriverConfig & config)484 xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterDirectTpuDriver(
485     const TpuDriverConfig& config) {
486   std::string shared_lib = config.worker().substr(strlen(kDirectProtocol));
487   if (shared_lib == "internal") {
488 #ifdef TPU_SHARED_LIBRARY_COMPILE_LINK
489     return xla::StatusOr<std::unique_ptr<TpuDriver>>(
490         absl::make_unique<DirectTpuDriver>());
491 #else
492     LOG(FATAL) << "Request to use compile-time linked TPU library, but did not "
493                << "link in appropriate library at compile time.";
494 #endif
495   }
496   return xla::StatusOr<std::unique_ptr<TpuDriver>>(
497       absl::make_unique<DirectTpuDriver>(shared_lib));
498 }
499 
500 REGISTER_TPU_DRIVER(kDirectProtocol, RegisterDirectTpuDriver);
501 
502 }  // namespace
503 }  // namespace tpu_driver
504