• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
#include <atomic>
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/internal/sysinfo.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/python/tpu_driver/platform/external/compat.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_service.grpc.pb.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/threadpool.h"
29 
/*
 * The RecordingTpuDriver wraps a concrete TpuDriver implementation and records
 * the stream of operations to a log file. This log can later be replayed and
 * analyzed for debugging.
 */
35 
36 namespace tpu_driver {
37 namespace {
38 
39 static std::atomic<int64_t> id_counter(0);
40 
41 using xla::Status;
42 
43 class RecordingTpuDriver;
44 
45 class RecordingEvent : public Event {
46  public:
RecordingEvent(std::shared_ptr<Event> event)47   explicit RecordingEvent(std::shared_ptr<Event> event)
48       : shared_event_(std::move(event)), id_(id_counter++) {}
49 
RecordingEvent(std::shared_ptr<Event> event,int64_t id)50   explicit RecordingEvent(std::shared_ptr<Event> event, int64_t id)
51       : shared_event_(event), id_(id) {}
52 
~RecordingEvent()53   ~RecordingEvent() override {}
54 
Await()55   xla::Status Await() override { return shared_event_->Await(); }
56 
AwaitWithTimeout(absl::Duration duration)57   absl::optional<xla::Status> AwaitWithTimeout(
58       absl::Duration duration) override {
59     return shared_event_->AwaitWithTimeout(duration);
60   }
61 
AddCallback(std::function<void (xla::Status)> callback)62   void AddCallback(std::function<void(xla::Status)> callback) override {
63     return shared_event_->AddCallback(callback);
64   }
65 
66  private:
67   std::shared_ptr<Event> shared_event_;
68 
69   int64_t id_;
70   friend class RecordingTpuDriver;
71 };
72 
73 class RecordingBufferHandle : public BufferHandle {
74  public:
RecordingBufferHandle(std::unique_ptr<BufferHandle> handle)75   explicit RecordingBufferHandle(std::unique_ptr<BufferHandle> handle)
76       : handle_(std::move(handle)),
77         id_(id_counter++),
78         event_(std::make_shared<RecordingEvent>(handle_->OnReady(), id_)) {}
OnReady()79   std::shared_ptr<Event> OnReady() override { return event_; }
size_in_bytes()80   int64_t size_in_bytes() override { return handle_->size_in_bytes(); }
shape()81   absl::optional<xla::ShapeProto> shape() override { return handle_->shape(); }
82 
83  private:
84   std::unique_ptr<BufferHandle> handle_;
85   int64_t id_;
86   std::shared_ptr<RecordingEvent> event_;
87   friend class RecordingTpuDriver;
88 };
89 
90 class RecordingCompiledProgramHandle : public CompiledProgramHandle {
91  public:
RecordingCompiledProgramHandle(std::unique_ptr<CompiledProgramHandle> handle)92   explicit RecordingCompiledProgramHandle(
93       std::unique_ptr<CompiledProgramHandle> handle)
94       : handle_(std::move(handle)),
95         id_(id_counter++),
96         event_(std::make_shared<RecordingEvent>(handle_->OnReady(), id_)) {}
OnReady()97   std::shared_ptr<Event> OnReady() override { return event_; }
size_in_bytes()98   int64_t size_in_bytes() override { return handle_->size_in_bytes(); }
program_shape(xla::ProgramShapeProto * program_shape)99   xla::Status program_shape(xla::ProgramShapeProto* program_shape) override {
100     return handle_->program_shape(program_shape);
101   }
102 
103  private:
104   std::unique_ptr<CompiledProgramHandle> handle_;
105   int64_t id_;
106   std::shared_ptr<RecordingEvent> event_;
107   friend class RecordingTpuDriver;
108 };
109 
110 class RecordingLoadedProgramHandle : public LoadedProgramHandle {
111  public:
RecordingLoadedProgramHandle(std::unique_ptr<LoadedProgramHandle> handle)112   explicit RecordingLoadedProgramHandle(
113       std::unique_ptr<LoadedProgramHandle> handle)
114       : handle_(std::move(handle)),
115         id_(id_counter++),
116         event_(std::make_shared<RecordingEvent>(handle_->OnReady(), id_)) {}
OnReady()117   std::shared_ptr<Event> OnReady() override { return event_; }
size_in_bytes()118   int64_t size_in_bytes() override { return handle_->size_in_bytes(); }
119 
120  private:
121   std::unique_ptr<LoadedProgramHandle> handle_;
122   int64_t id_;
123   std::shared_ptr<RecordingEvent> event_;
124   friend class RecordingTpuDriver;
125 };
126 
127 class RecordingTpuDriver : public TpuDriver {
128  public:
RecordingTpuDriver(std::unique_ptr<TpuDriver> driver,const std::string recording_path,const bool flush)129   explicit RecordingTpuDriver(std::unique_ptr<TpuDriver> driver,
130                               const std::string recording_path,
131                               const bool flush)
132       : driver_(std::move(driver)),
133         recording_path_(recording_path),
134         flush_(flush) {
135     auto file_status = tensorflow::Env::Default()->NewAppendableFile(
136         recording_path_, &log_file_);
137     if (!file_status.ok()) {
138       LOG(FATAL) << "Unable to open " << recording_path_
139                  << " for appending. Error: " << file_status.ToString();
140     }
141   }
~RecordingTpuDriver()142   ~RecordingTpuDriver() override {
143     {
144       log_file_->Flush().IgnoreError();
145       log_file_->Close().IgnoreError();
146       log_file_ = nullptr;
147     }
148   }
149 
QuerySystemInfo(SystemInfo * system_info)150   void QuerySystemInfo(SystemInfo* system_info) override {
151     // TODO(frankchn): Should we even save this event, since it is out-of-band.
152     driver_->QuerySystemInfo(system_info);
153   }
154 
Reset()155   Status Reset() override { return driver_->Reset(); }
156 
Allocate(int32_t core_id,MemoryRegion region,int64_t num_bytes,absl::Span<Event * const> wait_for)157   std::unique_ptr<BufferHandle> Allocate(
158       int32_t core_id, MemoryRegion region, int64_t num_bytes,
159       absl::Span<Event* const> wait_for) override {
160     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
161 
162     auto thread_id = GetCurrentThreadId();
163     auto handle =
164         driver_->Allocate(core_id, region, num_bytes, unwrapped_wait_for);
165     auto recording_handle =
166         std::make_unique<RecordingBufferHandle>(std::move(handle));
167     auto handle_id = recording_handle->id_;
168 
169     {
170       StreamRequest::Entry r;
171       r.mutable_alloc()->set_core_id(core_id);
172       r.mutable_alloc()->set_region(region);
173       r.mutable_alloc()->set_num_bytes(num_bytes);
174 
175       PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
176     }
177 
178     return recording_handle;
179   }
180 
Allocate(int32_t core_id,MemoryRegion region,const xla::ShapeProto & shape,absl::Span<Event * const> wait_for)181   std::unique_ptr<BufferHandle> Allocate(
182       int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
183       absl::Span<Event* const> wait_for) override {
184     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
185 
186     auto thread_id = GetCurrentThreadId();
187     auto handle = driver_->Allocate(core_id, region, shape, unwrapped_wait_for);
188     auto recording_handle =
189         std::make_unique<RecordingBufferHandle>(std::move(handle));
190     auto handle_id = recording_handle->id_;
191 
192     {
193       StreamRequest::Entry r;
194       r.mutable_alloc()->set_core_id(core_id);
195       r.mutable_alloc()->set_region(region);
196       *(r.mutable_alloc()->mutable_shape()) = shape;
197 
198       PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
199     }
200 
201     return recording_handle;
202   }
203 
AllocateTuple(int32_t core_id,MemoryRegion region,absl::Span<BufferHandle * const> children,absl::Span<Event * const> wait_for)204   std::unique_ptr<BufferHandle> AllocateTuple(
205       int32_t core_id, MemoryRegion region,
206       absl::Span<BufferHandle* const> children,
207       absl::Span<Event* const> wait_for) override {
208     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
209 
210     std::vector<BufferHandle*> unwrapped_children;
211     std::vector<int64_t> child_ids;
212     for (auto child : children) {
213       BufferHandle* unwrapped_child =
214           static_cast<const RecordingBufferHandle*>(child)->handle_.get();
215       unwrapped_children.push_back(unwrapped_child);
216       child_ids.push_back(
217           static_cast<const RecordingBufferHandle*>(child)->id_);
218     }
219 
220     auto thread_id = GetCurrentThreadId();
221     auto handle = driver_->AllocateTuple(core_id, region, unwrapped_children,
222                                          unwrapped_wait_for);
223     auto recording_handle =
224         std::make_unique<RecordingBufferHandle>(std::move(handle));
225     auto handle_id = recording_handle->id_;
226 
227     {
228       StreamRequest::Entry r;
229       r.mutable_alloc_tuple()->set_core_id(core_id);
230       r.mutable_alloc_tuple()->set_region(region);
231 
232       for (auto child : child_ids) {
233         r.mutable_alloc_tuple()->add_children(child);
234       }
235 
236       PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
237     }
238 
239     return recording_handle;
240   }
241 
Deallocate(std::unique_ptr<BufferHandle> handle,absl::Span<Event * const> wait_for)242   std::shared_ptr<Event> Deallocate(
243       std::unique_ptr<BufferHandle> handle,
244       absl::Span<Event* const> wait_for) override {
245     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
246 
247     auto thread_id = GetCurrentThreadId();
248     auto recording_handle = static_cast<RecordingBufferHandle*>(handle.get());
249     int64_t recording_handle_id = recording_handle->id_;
250     auto event = driver_->Deallocate(std::move(recording_handle->handle_),
251                                      unwrapped_wait_for);
252     auto recording_event = std::make_shared<RecordingEvent>(std::move(event));
253     int64_t event_id = recording_event->id_;
254 
255     {
256       StreamRequest::Entry r;
257       r.mutable_dealloc()->set_handle(recording_handle_id);
258       PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
259     }
260 
261     return recording_event;
262   }
263 
TransferToDevice(const void * src,BufferHandle * dst,absl::Span<Event * const> wait_for)264   std::shared_ptr<Event> TransferToDevice(
265       const void* src, BufferHandle* dst,
266       absl::Span<Event* const> wait_for) override {
267     int64_t num_bytes = dst->size_in_bytes();
268     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
269 
270     auto thread_id = GetCurrentThreadId();
271     auto recording_handle = static_cast<RecordingBufferHandle*>(dst);
272     int64_t recording_handle_id = recording_handle->id_;
273     auto recording_event =
274         std::make_shared<RecordingEvent>(driver_->TransferToDevice(
275             src, static_cast<RecordingBufferHandle*>(dst)->handle_.get(),
276             unwrapped_wait_for));
277     int64_t event_id = recording_event->id_;
278 
279     {
280       StreamRequest::Entry r;
281       r.mutable_transfer_to()->set_target_handle(recording_handle_id);
282       if (num_bytes > 0) {
283         r.mutable_transfer_to()->mutable_data()->assign(
284             static_cast<const char*>(src), num_bytes);
285       } else {
286         *r.mutable_transfer_to()->mutable_data() = "";
287       }
288       PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
289     }
290 
291     return recording_event;
292   }
293 
TransferFromDevice(const BufferHandle * src,void * dst,absl::Span<Event * const> wait_for)294   std::shared_ptr<Event> TransferFromDevice(
295       const BufferHandle* src, void* dst,
296       absl::Span<Event* const> wait_for) override {
297     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
298 
299     auto thread_id = GetCurrentThreadId();
300     auto src_handle_id = static_cast<const RecordingBufferHandle*>(src)->id_;
301     auto recording_event =
302         std::make_shared<RecordingEvent>(driver_->TransferFromDevice(
303             static_cast<const RecordingBufferHandle*>(src)->handle_.get(), dst,
304             unwrapped_wait_for));
305     auto event_id = recording_event->id_;
306 
307     {
308       StreamRequest::Entry r;
309       r.mutable_transfer_from()->set_source_handle(src_handle_id);
310       PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
311     }
312 
313     return recording_event;
314   }
315 
TransferFromDeviceToDevice(const BufferHandle * src,BufferHandle * dst,absl::Span<Event * const> wait_for)316   std::shared_ptr<Event> TransferFromDeviceToDevice(
317       const BufferHandle* src, BufferHandle* dst,
318       absl::Span<Event* const> wait_for) override {
319     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
320 
321     auto thread_id = GetCurrentThreadId();
322     auto src_handle_id = static_cast<const RecordingBufferHandle*>(src)->id_;
323     auto dst_handle_id = static_cast<const RecordingBufferHandle*>(dst)->id_;
324     auto recording_event =
325         std::make_shared<RecordingEvent>(driver_->TransferFromDeviceToDevice(
326             static_cast<const RecordingBufferHandle*>(src)->handle_.get(),
327             static_cast<const RecordingBufferHandle*>(dst)->handle_.get(),
328             unwrapped_wait_for));
329     auto event_id = recording_event->id_;
330 
331     {
332       StreamRequest::Entry r;
333       r.mutable_transfer_from_to()->set_source_handle(src_handle_id);
334       r.mutable_transfer_from_to()->set_target_handle(dst_handle_id);
335       PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
336     }
337 
338     return recording_event;
339   }
340 
CompileProgram(const xla::HloProto & source,int32_t num_replicas,absl::Span<Event * const> wait_for)341   std::unique_ptr<CompiledProgramHandle> CompileProgram(
342       const xla::HloProto& source, int32_t num_replicas,
343       absl::Span<Event* const> wait_for) override {
344     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
345 
346     auto thread_id = GetCurrentThreadId();
347     auto recording_handle = std::make_unique<RecordingCompiledProgramHandle>(
348         driver_->CompileProgram(source, num_replicas, unwrapped_wait_for));
349     auto handle_id = recording_handle->id_;
350 
351     {
352       StreamRequest::Entry r;
353       *r.mutable_compile()->mutable_hlo_program() = source;
354       r.mutable_compile()->set_num_replicas(num_replicas);
355       PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
356     }
357 
358     return recording_handle;
359   }
360 
LoadProgram(int32_t core_id,const CompiledProgramHandle * handle,absl::Span<Event * const> wait_for)361   std::unique_ptr<LoadedProgramHandle> LoadProgram(
362       int32_t core_id, const CompiledProgramHandle* handle,
363       absl::Span<Event* const> wait_for) override {
364     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
365 
366     auto thread_id = GetCurrentThreadId();
367     auto compiled_handle_id =
368         static_cast<const RecordingCompiledProgramHandle*>(handle)->id_;
369     auto recording_handle =
370         std::make_unique<RecordingLoadedProgramHandle>(driver_->LoadProgram(
371             core_id,
372             static_cast<const RecordingCompiledProgramHandle*>(handle)
373                 ->handle_.get(),
374             unwrapped_wait_for));
375     auto handle_id = recording_handle->id_;
376     {
377       StreamRequest::Entry r;
378       r.mutable_load()->set_core_id(core_id);
379       r.mutable_load()->set_compiled_program_handle(compiled_handle_id);
380       PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
381     }
382 
383     return recording_handle;
384   }
385 
UnloadProgram(std::unique_ptr<LoadedProgramHandle> handle,absl::Span<Event * const> wait_for)386   std::shared_ptr<Event> UnloadProgram(
387       std::unique_ptr<LoadedProgramHandle> handle,
388       absl::Span<Event* const> wait_for) override {
389     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
390 
391     auto thread_id = GetCurrentThreadId();
392     auto loaded_handle_id =
393         static_cast<RecordingLoadedProgramHandle*>(handle.get())->id_;
394     auto recording_event =
395         std::make_shared<RecordingEvent>(driver_->UnloadProgram(
396             std::move(static_cast<RecordingLoadedProgramHandle*>(handle.get())
397                           ->handle_),
398             unwrapped_wait_for));
399     auto event_id = recording_event->id_;
400 
401     {
402       StreamRequest::Entry r;
403       r.mutable_unload()->set_loaded_program_handle(loaded_handle_id);
404       PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
405     }
406 
407     return recording_event;
408   }
409 
ExecuteProgram(LoadedProgramHandle * program,absl::Span<BufferHandle * const> inputs,absl::Span<BufferHandle * const> outputs,const xla::DeviceAssignmentProto & device_assignment,absl::Span<Event * const> wait_for)410   std::shared_ptr<Event> ExecuteProgram(
411       LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
412       absl::Span<BufferHandle* const> outputs,
413       const xla::DeviceAssignmentProto& device_assignment,
414       absl::Span<Event* const> wait_for) override {
415     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
416 
417     auto thread_id = GetCurrentThreadId();
418     auto program_handle_id =
419         static_cast<RecordingLoadedProgramHandle*>(program)->id_;
420 
421     std::vector<BufferHandle*> unwrapped_inputs;
422     std::vector<int64_t> input_ids;
423     for (auto input : inputs) {
424       BufferHandle* unwrapped_input =
425           static_cast<const RecordingBufferHandle*>(input)->handle_.get();
426       unwrapped_inputs.push_back(unwrapped_input);
427       input_ids.push_back(
428           static_cast<const RecordingBufferHandle*>(input)->id_);
429     }
430 
431     std::vector<BufferHandle*> unwrapped_outputs;
432     std::vector<int64_t> output_ids;
433     for (auto output : outputs) {
434       BufferHandle* unwrapped_output =
435           static_cast<const RecordingBufferHandle*>(output)->handle_.get();
436       unwrapped_outputs.push_back(unwrapped_output);
437       output_ids.push_back(
438           static_cast<const RecordingBufferHandle*>(output)->id_);
439     }
440 
441     auto recording_event =
442         std::make_shared<RecordingEvent>(driver_->ExecuteProgram(
443             static_cast<RecordingLoadedProgramHandle*>(program)->handle_.get(),
444             unwrapped_inputs, unwrapped_outputs, device_assignment,
445             unwrapped_wait_for));
446     auto event_id = recording_event->id_;
447 
448     {
449       StreamRequest::Entry r;
450       r.mutable_execute()->set_loaded_program_handle(program_handle_id);
451       for (auto input_id : input_ids) {
452         r.mutable_execute()->add_input_handle(input_id);
453       }
454       for (auto output_id : output_ids) {
455         r.mutable_execute()->add_output_handle(output_id);
456       }
457       *r.mutable_execute()->mutable_device_assignment() = device_assignment;
458 
459       PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
460     }
461 
462     return recording_event;
463   }
464 
GetLinearizer()465   std::unique_ptr<TpuLinearizer> GetLinearizer() override {
466     return driver_->GetLinearizer();
467   }
468 
469  private:
470   std::unique_ptr<TpuDriver> driver_;
471   const std::string recording_path_;
472   const bool flush_;
473 
474   std::unique_ptr<tensorflow::WritableFile> log_file_;
475 
PopulateAndSaveEntry(StreamRequest::Entry * r,absl::Span<Event * const> wait_for,int64_t handle_id,int64_t thread_id)476   void PopulateAndSaveEntry(StreamRequest::Entry* r,
477                             absl::Span<Event* const> wait_for,
478                             int64_t handle_id, int64_t thread_id) {
479     for (auto event : wait_for) {
480       auto recording_event = static_cast<const RecordingEvent*>(event);
481       r->add_wait_for_id(recording_event->id_);
482     }
483     r->set_operation_id(handle_id);
484     r->set_thread_id(thread_id);
485 
486     uint64_t data_size = r->ByteSizeLong();
487     std::vector<char> buffer;
488     buffer.resize(sizeof(data_size) + data_size);
489     memcpy(buffer.data(), &data_size, sizeof(data_size));
490     r->SerializeToArray(buffer.data() + sizeof(data_size), data_size);
491 
492     {
493       if (log_file_ == nullptr) {
494         LOG(WARNING) << "The TPU driver has been shut down before all logging "
495                         "has been written.";
496         return;
497       }
498 
499       tensorflow::StringPiece buffer_sp(buffer.data(), buffer.size());
500       auto data_status = log_file_->Append(buffer_sp);
501       if (!data_status.ok()) {
502         LOG(WARNING) << "Unable to write data to log file. File possibly "
503                         "corrupt. Error: "
504                      << data_status.ToString();
505       }
506 
507       if (flush_) {
508         auto flush_status = log_file_->Flush();
509         if (!flush_status.ok()) {
510           LOG(WARNING) << "Unable to flush data to log file. File possibly "
511                           "corrupt. Error: "
512                        << flush_status.ToString();
513         }
514 
515         auto sync_status = log_file_->Sync();
516         if (!sync_status.ok()) {
517           LOG(WARNING) << "Unable to sync log file. File possibly "
518                           "corrupt. Error: "
519                        << sync_status.ToString();
520         }
521       }
522     }
523   }
524 
UnwrapWaitFor(absl::Span<Event * const> wait_for)525   std::vector<Event*> UnwrapWaitFor(absl::Span<Event* const> wait_for) {
526     std::vector<Event*> unwrapped_events;
527     for (auto event : wait_for) {
528       Event* unwrapped_event =
529           static_cast<RecordingEvent*>(event)->shared_event_.get();
530       unwrapped_events.push_back(unwrapped_event);
531     }
532     return unwrapped_events;
533   }
534 
GetCurrentThreadId()535   int64_t GetCurrentThreadId() { return absl::base_internal::GetTID(); }
536 };
537 
RegisterRecordingTpuDriver(const TpuDriverConfig & config)538 xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterRecordingTpuDriver(
539     const TpuDriverConfig& config) {
540   std::vector<std::string> configs = absl::StrSplit(config.worker(), '|');
541 
542   std::string file;
543   std::string worker;
544   bool flush = false;
545 
546   for (const auto& config : configs) {
547     std::vector<std::string> kv =
548         absl::StrSplit(config, absl::MaxSplits('=', 1));
549     if (kv[0] == "file") {
550       file = kv[1];
551     }
552     if (kv[0] == "worker") {
553       worker = kv[1];
554     }
555     if (kv[0] == "flush") {
556       if (kv[1] == "true" || kv[1] == "1") {
557         flush = true;
558       }
559     }
560   }
561 
562   TpuDriverConfig worker_config;
563   worker_config.set_worker(worker);
564 
565   auto driver_status = TpuDriverRegistry::Open(worker_config);
566   if (!driver_status.ok()) return driver_status.status();
567   auto driver = driver_status.ConsumeValueOrDie();
568 
569   return std::unique_ptr<TpuDriver>(
570       new RecordingTpuDriver(std::move(driver), file, flush));
571 }
572 
573 // To record a sequence of operations, set the worker configuration string to
574 // record://|file=<filename>|worker=grpc://1.2.3.4:8470 (for GRPC).
575 REGISTER_TPU_DRIVER("record://", RegisterRecordingTpuDriver);
576 
577 }  // namespace
578 }  // namespace tpu_driver
579