1 // Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
#include <atomic>
#include <functional>
#include <limits>
#include <mutex>

#include "absl/base/internal/sysinfo.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/python/tpu_driver/platform/external/compat.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_service.grpc.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/threadpool.h"
29
/*
 * The RecordingTpuDriver wraps a concrete TpuDriver implementation and records
 * the stream of operations issued through it to a log file. This log can later
 * be replayed and analyzed for debugging.
 */
35
36 namespace tpu_driver {
37 namespace {
38
39 static std::atomic<int64_t> id_counter(0);
40
41 using xla::Status;
42
43 class RecordingTpuDriver;
44
45 class RecordingEvent : public Event {
46 public:
RecordingEvent(std::shared_ptr<Event> event)47 explicit RecordingEvent(std::shared_ptr<Event> event)
48 : shared_event_(std::move(event)), id_(id_counter++) {}
49
RecordingEvent(std::shared_ptr<Event> event,int64_t id)50 explicit RecordingEvent(std::shared_ptr<Event> event, int64_t id)
51 : shared_event_(event), id_(id) {}
52
~RecordingEvent()53 ~RecordingEvent() override {}
54
Await()55 xla::Status Await() override { return shared_event_->Await(); }
56
AwaitWithTimeout(absl::Duration duration)57 absl::optional<xla::Status> AwaitWithTimeout(
58 absl::Duration duration) override {
59 return shared_event_->AwaitWithTimeout(duration);
60 }
61
AddCallback(std::function<void (xla::Status)> callback)62 void AddCallback(std::function<void(xla::Status)> callback) override {
63 return shared_event_->AddCallback(callback);
64 }
65
66 private:
67 std::shared_ptr<Event> shared_event_;
68
69 int64_t id_;
70 friend class RecordingTpuDriver;
71 };
72
73 class RecordingBufferHandle : public BufferHandle {
74 public:
RecordingBufferHandle(std::unique_ptr<BufferHandle> handle)75 explicit RecordingBufferHandle(std::unique_ptr<BufferHandle> handle)
76 : handle_(std::move(handle)),
77 id_(id_counter++),
78 event_(std::make_shared<RecordingEvent>(handle_->OnReady(), id_)) {}
OnReady()79 std::shared_ptr<Event> OnReady() override { return event_; }
size_in_bytes()80 int64_t size_in_bytes() override { return handle_->size_in_bytes(); }
shape()81 absl::optional<xla::ShapeProto> shape() override { return handle_->shape(); }
82
83 private:
84 std::unique_ptr<BufferHandle> handle_;
85 int64_t id_;
86 std::shared_ptr<RecordingEvent> event_;
87 friend class RecordingTpuDriver;
88 };
89
90 class RecordingCompiledProgramHandle : public CompiledProgramHandle {
91 public:
RecordingCompiledProgramHandle(std::unique_ptr<CompiledProgramHandle> handle)92 explicit RecordingCompiledProgramHandle(
93 std::unique_ptr<CompiledProgramHandle> handle)
94 : handle_(std::move(handle)),
95 id_(id_counter++),
96 event_(std::make_shared<RecordingEvent>(handle_->OnReady(), id_)) {}
OnReady()97 std::shared_ptr<Event> OnReady() override { return event_; }
size_in_bytes()98 int64_t size_in_bytes() override { return handle_->size_in_bytes(); }
program_shape(xla::ProgramShapeProto * program_shape)99 xla::Status program_shape(xla::ProgramShapeProto* program_shape) override {
100 return handle_->program_shape(program_shape);
101 }
102
103 private:
104 std::unique_ptr<CompiledProgramHandle> handle_;
105 int64_t id_;
106 std::shared_ptr<RecordingEvent> event_;
107 friend class RecordingTpuDriver;
108 };
109
110 class RecordingLoadedProgramHandle : public LoadedProgramHandle {
111 public:
RecordingLoadedProgramHandle(std::unique_ptr<LoadedProgramHandle> handle)112 explicit RecordingLoadedProgramHandle(
113 std::unique_ptr<LoadedProgramHandle> handle)
114 : handle_(std::move(handle)),
115 id_(id_counter++),
116 event_(std::make_shared<RecordingEvent>(handle_->OnReady(), id_)) {}
OnReady()117 std::shared_ptr<Event> OnReady() override { return event_; }
size_in_bytes()118 int64_t size_in_bytes() override { return handle_->size_in_bytes(); }
119
120 private:
121 std::unique_ptr<LoadedProgramHandle> handle_;
122 int64_t id_;
123 std::shared_ptr<RecordingEvent> event_;
124 friend class RecordingTpuDriver;
125 };
126
127 class RecordingTpuDriver : public TpuDriver {
128 public:
RecordingTpuDriver(std::unique_ptr<TpuDriver> driver,const std::string recording_path,const bool flush)129 explicit RecordingTpuDriver(std::unique_ptr<TpuDriver> driver,
130 const std::string recording_path,
131 const bool flush)
132 : driver_(std::move(driver)),
133 recording_path_(recording_path),
134 flush_(flush) {
135 auto file_status = tensorflow::Env::Default()->NewAppendableFile(
136 recording_path_, &log_file_);
137 if (!file_status.ok()) {
138 LOG(FATAL) << "Unable to open " << recording_path_
139 << " for appending. Error: " << file_status.ToString();
140 }
141 }
~RecordingTpuDriver()142 ~RecordingTpuDriver() override {
143 {
144 log_file_->Flush().IgnoreError();
145 log_file_->Close().IgnoreError();
146 log_file_ = nullptr;
147 }
148 }
149
QuerySystemInfo(SystemInfo * system_info)150 void QuerySystemInfo(SystemInfo* system_info) override {
151 // TODO(frankchn): Should we even save this event, since it is out-of-band.
152 driver_->QuerySystemInfo(system_info);
153 }
154
Reset()155 Status Reset() override { return driver_->Reset(); }
156
Allocate(int32_t core_id,MemoryRegion region,int64_t num_bytes,absl::Span<Event * const> wait_for)157 std::unique_ptr<BufferHandle> Allocate(
158 int32_t core_id, MemoryRegion region, int64_t num_bytes,
159 absl::Span<Event* const> wait_for) override {
160 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
161
162 auto thread_id = GetCurrentThreadId();
163 auto handle =
164 driver_->Allocate(core_id, region, num_bytes, unwrapped_wait_for);
165 auto recording_handle =
166 std::make_unique<RecordingBufferHandle>(std::move(handle));
167 auto handle_id = recording_handle->id_;
168
169 {
170 StreamRequest::Entry r;
171 r.mutable_alloc()->set_core_id(core_id);
172 r.mutable_alloc()->set_region(region);
173 r.mutable_alloc()->set_num_bytes(num_bytes);
174
175 PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
176 }
177
178 return recording_handle;
179 }
180
Allocate(int32_t core_id,MemoryRegion region,const xla::ShapeProto & shape,absl::Span<Event * const> wait_for)181 std::unique_ptr<BufferHandle> Allocate(
182 int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
183 absl::Span<Event* const> wait_for) override {
184 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
185
186 auto thread_id = GetCurrentThreadId();
187 auto handle = driver_->Allocate(core_id, region, shape, unwrapped_wait_for);
188 auto recording_handle =
189 std::make_unique<RecordingBufferHandle>(std::move(handle));
190 auto handle_id = recording_handle->id_;
191
192 {
193 StreamRequest::Entry r;
194 r.mutable_alloc()->set_core_id(core_id);
195 r.mutable_alloc()->set_region(region);
196 *(r.mutable_alloc()->mutable_shape()) = shape;
197
198 PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
199 }
200
201 return recording_handle;
202 }
203
AllocateTuple(int32_t core_id,MemoryRegion region,absl::Span<BufferHandle * const> children,absl::Span<Event * const> wait_for)204 std::unique_ptr<BufferHandle> AllocateTuple(
205 int32_t core_id, MemoryRegion region,
206 absl::Span<BufferHandle* const> children,
207 absl::Span<Event* const> wait_for) override {
208 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
209
210 std::vector<BufferHandle*> unwrapped_children;
211 std::vector<int64_t> child_ids;
212 for (auto child : children) {
213 BufferHandle* unwrapped_child =
214 static_cast<const RecordingBufferHandle*>(child)->handle_.get();
215 unwrapped_children.push_back(unwrapped_child);
216 child_ids.push_back(
217 static_cast<const RecordingBufferHandle*>(child)->id_);
218 }
219
220 auto thread_id = GetCurrentThreadId();
221 auto handle = driver_->AllocateTuple(core_id, region, unwrapped_children,
222 unwrapped_wait_for);
223 auto recording_handle =
224 std::make_unique<RecordingBufferHandle>(std::move(handle));
225 auto handle_id = recording_handle->id_;
226
227 {
228 StreamRequest::Entry r;
229 r.mutable_alloc_tuple()->set_core_id(core_id);
230 r.mutable_alloc_tuple()->set_region(region);
231
232 for (auto child : child_ids) {
233 r.mutable_alloc_tuple()->add_children(child);
234 }
235
236 PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
237 }
238
239 return recording_handle;
240 }
241
Deallocate(std::unique_ptr<BufferHandle> handle,absl::Span<Event * const> wait_for)242 std::shared_ptr<Event> Deallocate(
243 std::unique_ptr<BufferHandle> handle,
244 absl::Span<Event* const> wait_for) override {
245 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
246
247 auto thread_id = GetCurrentThreadId();
248 auto recording_handle = static_cast<RecordingBufferHandle*>(handle.get());
249 int64_t recording_handle_id = recording_handle->id_;
250 auto event = driver_->Deallocate(std::move(recording_handle->handle_),
251 unwrapped_wait_for);
252 auto recording_event = std::make_shared<RecordingEvent>(std::move(event));
253 int64_t event_id = recording_event->id_;
254
255 {
256 StreamRequest::Entry r;
257 r.mutable_dealloc()->set_handle(recording_handle_id);
258 PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
259 }
260
261 return recording_event;
262 }
263
TransferToDevice(const void * src,BufferHandle * dst,absl::Span<Event * const> wait_for)264 std::shared_ptr<Event> TransferToDevice(
265 const void* src, BufferHandle* dst,
266 absl::Span<Event* const> wait_for) override {
267 int64_t num_bytes = dst->size_in_bytes();
268 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
269
270 auto thread_id = GetCurrentThreadId();
271 auto recording_handle = static_cast<RecordingBufferHandle*>(dst);
272 int64_t recording_handle_id = recording_handle->id_;
273 auto recording_event =
274 std::make_shared<RecordingEvent>(driver_->TransferToDevice(
275 src, static_cast<RecordingBufferHandle*>(dst)->handle_.get(),
276 unwrapped_wait_for));
277 int64_t event_id = recording_event->id_;
278
279 {
280 StreamRequest::Entry r;
281 r.mutable_transfer_to()->set_target_handle(recording_handle_id);
282 if (num_bytes > 0) {
283 r.mutable_transfer_to()->mutable_data()->assign(
284 static_cast<const char*>(src), num_bytes);
285 } else {
286 *r.mutable_transfer_to()->mutable_data() = "";
287 }
288 PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
289 }
290
291 return recording_event;
292 }
293
TransferFromDevice(const BufferHandle * src,void * dst,absl::Span<Event * const> wait_for)294 std::shared_ptr<Event> TransferFromDevice(
295 const BufferHandle* src, void* dst,
296 absl::Span<Event* const> wait_for) override {
297 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
298
299 auto thread_id = GetCurrentThreadId();
300 auto src_handle_id = static_cast<const RecordingBufferHandle*>(src)->id_;
301 auto recording_event =
302 std::make_shared<RecordingEvent>(driver_->TransferFromDevice(
303 static_cast<const RecordingBufferHandle*>(src)->handle_.get(), dst,
304 unwrapped_wait_for));
305 auto event_id = recording_event->id_;
306
307 {
308 StreamRequest::Entry r;
309 r.mutable_transfer_from()->set_source_handle(src_handle_id);
310 PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
311 }
312
313 return recording_event;
314 }
315
TransferFromDeviceToDevice(const BufferHandle * src,BufferHandle * dst,absl::Span<Event * const> wait_for)316 std::shared_ptr<Event> TransferFromDeviceToDevice(
317 const BufferHandle* src, BufferHandle* dst,
318 absl::Span<Event* const> wait_for) override {
319 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
320
321 auto thread_id = GetCurrentThreadId();
322 auto src_handle_id = static_cast<const RecordingBufferHandle*>(src)->id_;
323 auto dst_handle_id = static_cast<const RecordingBufferHandle*>(dst)->id_;
324 auto recording_event =
325 std::make_shared<RecordingEvent>(driver_->TransferFromDeviceToDevice(
326 static_cast<const RecordingBufferHandle*>(src)->handle_.get(),
327 static_cast<const RecordingBufferHandle*>(dst)->handle_.get(),
328 unwrapped_wait_for));
329 auto event_id = recording_event->id_;
330
331 {
332 StreamRequest::Entry r;
333 r.mutable_transfer_from_to()->set_source_handle(src_handle_id);
334 r.mutable_transfer_from_to()->set_target_handle(dst_handle_id);
335 PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
336 }
337
338 return recording_event;
339 }
340
CompileProgram(const xla::HloProto & source,int32_t num_replicas,absl::Span<Event * const> wait_for)341 std::unique_ptr<CompiledProgramHandle> CompileProgram(
342 const xla::HloProto& source, int32_t num_replicas,
343 absl::Span<Event* const> wait_for) override {
344 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
345
346 auto thread_id = GetCurrentThreadId();
347 auto recording_handle = std::make_unique<RecordingCompiledProgramHandle>(
348 driver_->CompileProgram(source, num_replicas, unwrapped_wait_for));
349 auto handle_id = recording_handle->id_;
350
351 {
352 StreamRequest::Entry r;
353 *r.mutable_compile()->mutable_hlo_program() = source;
354 r.mutable_compile()->set_num_replicas(num_replicas);
355 PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
356 }
357
358 return recording_handle;
359 }
360
LoadProgram(int32_t core_id,const CompiledProgramHandle * handle,absl::Span<Event * const> wait_for)361 std::unique_ptr<LoadedProgramHandle> LoadProgram(
362 int32_t core_id, const CompiledProgramHandle* handle,
363 absl::Span<Event* const> wait_for) override {
364 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
365
366 auto thread_id = GetCurrentThreadId();
367 auto compiled_handle_id =
368 static_cast<const RecordingCompiledProgramHandle*>(handle)->id_;
369 auto recording_handle =
370 std::make_unique<RecordingLoadedProgramHandle>(driver_->LoadProgram(
371 core_id,
372 static_cast<const RecordingCompiledProgramHandle*>(handle)
373 ->handle_.get(),
374 unwrapped_wait_for));
375 auto handle_id = recording_handle->id_;
376 {
377 StreamRequest::Entry r;
378 r.mutable_load()->set_core_id(core_id);
379 r.mutable_load()->set_compiled_program_handle(compiled_handle_id);
380 PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
381 }
382
383 return recording_handle;
384 }
385
UnloadProgram(std::unique_ptr<LoadedProgramHandle> handle,absl::Span<Event * const> wait_for)386 std::shared_ptr<Event> UnloadProgram(
387 std::unique_ptr<LoadedProgramHandle> handle,
388 absl::Span<Event* const> wait_for) override {
389 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
390
391 auto thread_id = GetCurrentThreadId();
392 auto loaded_handle_id =
393 static_cast<RecordingLoadedProgramHandle*>(handle.get())->id_;
394 auto recording_event =
395 std::make_shared<RecordingEvent>(driver_->UnloadProgram(
396 std::move(static_cast<RecordingLoadedProgramHandle*>(handle.get())
397 ->handle_),
398 unwrapped_wait_for));
399 auto event_id = recording_event->id_;
400
401 {
402 StreamRequest::Entry r;
403 r.mutable_unload()->set_loaded_program_handle(loaded_handle_id);
404 PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
405 }
406
407 return recording_event;
408 }
409
ExecuteProgram(LoadedProgramHandle * program,absl::Span<BufferHandle * const> inputs,absl::Span<BufferHandle * const> outputs,const xla::DeviceAssignmentProto & device_assignment,absl::Span<Event * const> wait_for)410 std::shared_ptr<Event> ExecuteProgram(
411 LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
412 absl::Span<BufferHandle* const> outputs,
413 const xla::DeviceAssignmentProto& device_assignment,
414 absl::Span<Event* const> wait_for) override {
415 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
416
417 auto thread_id = GetCurrentThreadId();
418 auto program_handle_id =
419 static_cast<RecordingLoadedProgramHandle*>(program)->id_;
420
421 std::vector<BufferHandle*> unwrapped_inputs;
422 std::vector<int64_t> input_ids;
423 for (auto input : inputs) {
424 BufferHandle* unwrapped_input =
425 static_cast<const RecordingBufferHandle*>(input)->handle_.get();
426 unwrapped_inputs.push_back(unwrapped_input);
427 input_ids.push_back(
428 static_cast<const RecordingBufferHandle*>(input)->id_);
429 }
430
431 std::vector<BufferHandle*> unwrapped_outputs;
432 std::vector<int64_t> output_ids;
433 for (auto output : outputs) {
434 BufferHandle* unwrapped_output =
435 static_cast<const RecordingBufferHandle*>(output)->handle_.get();
436 unwrapped_outputs.push_back(unwrapped_output);
437 output_ids.push_back(
438 static_cast<const RecordingBufferHandle*>(output)->id_);
439 }
440
441 auto recording_event =
442 std::make_shared<RecordingEvent>(driver_->ExecuteProgram(
443 static_cast<RecordingLoadedProgramHandle*>(program)->handle_.get(),
444 unwrapped_inputs, unwrapped_outputs, device_assignment,
445 unwrapped_wait_for));
446 auto event_id = recording_event->id_;
447
448 {
449 StreamRequest::Entry r;
450 r.mutable_execute()->set_loaded_program_handle(program_handle_id);
451 for (auto input_id : input_ids) {
452 r.mutable_execute()->add_input_handle(input_id);
453 }
454 for (auto output_id : output_ids) {
455 r.mutable_execute()->add_output_handle(output_id);
456 }
457 *r.mutable_execute()->mutable_device_assignment() = device_assignment;
458
459 PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
460 }
461
462 return recording_event;
463 }
464
GetLinearizer()465 std::unique_ptr<TpuLinearizer> GetLinearizer() override {
466 return driver_->GetLinearizer();
467 }
468
469 private:
470 std::unique_ptr<TpuDriver> driver_;
471 const std::string recording_path_;
472 const bool flush_;
473
474 std::unique_ptr<tensorflow::WritableFile> log_file_;
475
PopulateAndSaveEntry(StreamRequest::Entry * r,absl::Span<Event * const> wait_for,int64_t handle_id,int64_t thread_id)476 void PopulateAndSaveEntry(StreamRequest::Entry* r,
477 absl::Span<Event* const> wait_for,
478 int64_t handle_id, int64_t thread_id) {
479 for (auto event : wait_for) {
480 auto recording_event = static_cast<const RecordingEvent*>(event);
481 r->add_wait_for_id(recording_event->id_);
482 }
483 r->set_operation_id(handle_id);
484 r->set_thread_id(thread_id);
485
486 uint64_t data_size = r->ByteSizeLong();
487 std::vector<char> buffer;
488 buffer.resize(sizeof(data_size) + data_size);
489 memcpy(buffer.data(), &data_size, sizeof(data_size));
490 r->SerializeToArray(buffer.data() + sizeof(data_size), data_size);
491
492 {
493 if (log_file_ == nullptr) {
494 LOG(WARNING) << "The TPU driver has been shut down before all logging "
495 "has been written.";
496 return;
497 }
498
499 tensorflow::StringPiece buffer_sp(buffer.data(), buffer.size());
500 auto data_status = log_file_->Append(buffer_sp);
501 if (!data_status.ok()) {
502 LOG(WARNING) << "Unable to write data to log file. File possibly "
503 "corrupt. Error: "
504 << data_status.ToString();
505 }
506
507 if (flush_) {
508 auto flush_status = log_file_->Flush();
509 if (!flush_status.ok()) {
510 LOG(WARNING) << "Unable to flush data to log file. File possibly "
511 "corrupt. Error: "
512 << flush_status.ToString();
513 }
514
515 auto sync_status = log_file_->Sync();
516 if (!sync_status.ok()) {
517 LOG(WARNING) << "Unable to sync log file. File possibly "
518 "corrupt. Error: "
519 << sync_status.ToString();
520 }
521 }
522 }
523 }
524
UnwrapWaitFor(absl::Span<Event * const> wait_for)525 std::vector<Event*> UnwrapWaitFor(absl::Span<Event* const> wait_for) {
526 std::vector<Event*> unwrapped_events;
527 for (auto event : wait_for) {
528 Event* unwrapped_event =
529 static_cast<RecordingEvent*>(event)->shared_event_.get();
530 unwrapped_events.push_back(unwrapped_event);
531 }
532 return unwrapped_events;
533 }
534
GetCurrentThreadId()535 int64_t GetCurrentThreadId() { return absl::base_internal::GetTID(); }
536 };
537
RegisterRecordingTpuDriver(const TpuDriverConfig & config)538 xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterRecordingTpuDriver(
539 const TpuDriverConfig& config) {
540 std::vector<std::string> configs = absl::StrSplit(config.worker(), '|');
541
542 std::string file;
543 std::string worker;
544 bool flush = false;
545
546 for (const auto& config : configs) {
547 std::vector<std::string> kv =
548 absl::StrSplit(config, absl::MaxSplits('=', 1));
549 if (kv[0] == "file") {
550 file = kv[1];
551 }
552 if (kv[0] == "worker") {
553 worker = kv[1];
554 }
555 if (kv[0] == "flush") {
556 if (kv[1] == "true" || kv[1] == "1") {
557 flush = true;
558 }
559 }
560 }
561
562 TpuDriverConfig worker_config;
563 worker_config.set_worker(worker);
564
565 auto driver_status = TpuDriverRegistry::Open(worker_config);
566 if (!driver_status.ok()) return driver_status.status();
567 auto driver = driver_status.ConsumeValueOrDie();
568
569 return std::unique_ptr<TpuDriver>(
570 new RecordingTpuDriver(std::move(driver), file, flush));
571 }
572
// To record a sequence of operations, set the worker configuration string to
// record://|file=<filename>|worker=grpc://1.2.3.4:8470 (for GRPC).
// Append |flush=true (or |flush=1) to flush and sync the log file after every
// recorded operation.
REGISTER_TPU_DRIVER("record://", RegisterRecordingTpuDriver);
576
577 } // namespace
578 } // namespace tpu_driver
579