/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_tpu_device.h"

#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/layout_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_node_device_util.h"
#include "tensorflow/core/tpu/virtual_device.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_stream_interface.h"

namespace tensorflow {
namespace {

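// Process-wide knobs; RegisterTpuNodeDevice() overwrites these defaults before
// the TPU device factories below are instantiated.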
static bool tpu_autoclustering_flag = false;
static bool tpu_xla_device_failure_closes_chips_flag = true;
static bool tpu_use_substreams_for_cross_tpu_device_transfers_flag = true;

// Given a tensor of `shape` and `type`, what shape should it be stored as on
// the TPU device? This function transposes or flattens excessively padded
// tensors to rank 1, but leaves other tensor shapes alone.
StatusOr<xla::Shape> TpuShapeRepresentation(
    const TensorShape& shape, DataType type, bool use_fast_memory,
    XlaLayoutPreference layout_preference) {
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(
      tensorflow::TensorShapeToXLAShape(type, shape, &xla_shape));
  ApiConverter::StackHelper<XLA_Shape> se_shape(xla_shape);
  ApiConverter::StackHelper<XLA_Shape> tpu_shape;
  StatusHelper status;
  tpu::ExecutorApiFn()->XlaShapeToTpuShapeRepresentationFn(
      &se_shape.value, type, use_fast_memory, &tpu_shape.value,
      status.c_status);
  if (!status.status().ok()) {
    return status.status();
  }
  return tpu_shape.AsCpp<xla::Shape>();
}

// Given a tensor, returns the shape of its representation on device,
// fully padded. Contents of `shape` are undefined on error.
Status TpuPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
  const tensorflow::XlaTensor* xla_tensor =
      tensorflow::XlaTensor::FromTensor(&tensor);
  if (xla_tensor == nullptr) {
    return errors::InvalidArgument(
        "Expected an XlaTensor when computing padded shape");
  }

  if (!xla_tensor->has_shaped_buffer()) {
    return errors::InvalidArgument(
        "XlaTensor is expected to have device memory allocated when "
        "computing padded shape");
  }

  const xla::Shape& on_device_shape =
      xla_tensor->shaped_buffer().on_device_shape();

  StatusHelper status;
  ApiConverter::StackHelper<XLA_Shape> se_shape(on_device_shape);
  ApiConverter::StackHelper<XLA_Shape> tpu_shape;
  tpu::ExecutorApiFn()->XlaShapeToTpuPaddedShapeFn(
      &se_shape.value, &tpu_shape.value, status.c_status);
  if (!status.ok()) {
    return status.status();
  }
  *shape = tpu_shape.AsCpp<xla::Shape>();
  return OkStatus();
}

// Check if TPU has been initialized. TPU initialization is not necessary
// for 1x1.
Status CheckIfTPUInitialized() {
  auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform();
  if (!tpu_platform->Initialized()) {
    return errors::FailedPrecondition(
        "The TPU system has not been initialized.");
  }
  return OkStatus();
}

// Implementation of TPU->TPU device copies that copies over the dedicated TPU
// interconnects, which is much faster than PCIe or the host network.
// TODO(b/117426293): This implementation is only called for direct interconnect
// transfers between TPU devices attached to the same host. Ideally, we would
// generalize this support to direct interconnect transfers across hosts, but
// currently the CopyTensor infrastructure seems to assume that the network
// topology is strictly hierarchical; that is, transfers between devices on
// different hosts can only take place using the host network.
void TpuDeviceToDeviceCopy(DeviceContext* src_dev_context,
                           DeviceContext* dst_dev_context, Device* src,
                           Device* dst, AllocatorAttributes src_allocator_attrs,
                           AllocatorAttributes dst_allocator_attrs,
                           const Tensor* input, Tensor* output,
                           int dev_to_dev_stream_index, StatusCallback done) {
  XlaDeviceContext* const src_xla_context =
      static_cast<XlaDeviceContext*>(src_dev_context);
  XlaDeviceContext* const dst_xla_context =
      static_cast<XlaDeviceContext*>(dst_dev_context);
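  // Latched once on the first copy; later changes to the flag have no effect.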
  static const bool should_use_substream =
      tpu_use_substreams_for_cross_tpu_device_transfers_flag;

  auto impl = [&]() -> Status {
    if (src->name() != dst->name()) {
      Status s = CheckIfTPUInitialized();
      if (!s.ok()) {
        done(s);
        return OkStatus();
      }
    }
    if (input->shape().num_elements() == 0) {
      // Zero-element tensors have no backing buffers.
      done(OkStatus());
      return OkStatus();
    }

    se::Stream* const src_compute_stream = src_xla_context->stream();
    TF_RET_CHECK(src_compute_stream != nullptr);
    TF_RET_CHECK(input->dtype() == output->dtype())
        << "input type: " << DataTypeString(input->dtype()) << " output type "
        << DataTypeString(output->dtype());
    TF_RET_CHECK(input->shape() == output->shape());
    TF_RET_CHECK(DMAHelper::CanUseDMA(input));
    auto* const src_compute_stream_impl = static_cast<tpu::TpuStreamInterface*>(
        src_compute_stream->implementation());

    se::Stream* dst_compute_stream = dst_xla_context->stream();
    auto* const dst_compute_stream_impl = static_cast<tpu::TpuStreamInterface*>(
        dst_compute_stream->implementation());

    if (src_compute_stream_impl->IsSameSharedMemoryLocation(
            dst_compute_stream_impl)) {
      // Surprisingly, this path does get triggered in practice.
      *output = *input;
      done(OkStatus());
      return OkStatus();
    }

    // To avoid stream exhaustion, we pick a substream from a pool if enabled.
    se::Stream* const device_to_device_master_stream =
        should_use_substream ? dst_xla_context->device_to_device_stream(0)
                             : nullptr;
    se::Stream* const dst_device_to_device_stream =
        should_use_substream
            ? device_to_device_master_stream->GetOrCreateSubStream()
            : dst_xla_context->GetDeviceToDeviceStream();
    TF_RET_CHECK(dst_device_to_device_stream != nullptr);
    auto return_substream = gtl::MakeCleanup(
        [device_to_device_master_stream, dst_device_to_device_stream] {
          if (device_to_device_master_stream) {
            device_to_device_master_stream->ReturnSubStream(
                dst_device_to_device_stream);
          }
        });

    auto* const dst_device_to_device_stream_impl =
        static_cast<tpu::TpuStreamInterface*>(
            dst_device_to_device_stream->implementation());

    const int dst_device_ordinal =
        dst_xla_context->stream()->parent()->device_ordinal();

    XlaTensor* const xla_input = XlaTensor::FromTensor(input);
    TF_RET_CHECK(xla_input != nullptr && xla_input->has_shaped_buffer());
    XlaTensor* const xla_output = XlaTensor::FromTensor(output);
    TF_RET_CHECK(xla_output != nullptr && !xla_output->has_shaped_buffer());
    TF_RET_CHECK(input->shape() == output->shape());

    const auto& shape_determination_fns =
        dst_xla_context->shape_determination_fns();
    XlaLayoutPreference layout_preference =
        shape_determination_fns.layout_preference_fn(
            input->shape(), input->dtype(), std::nullopt);
    TF_ASSIGN_OR_RETURN(xla::Shape shape,
                        shape_determination_fns.shape_representation_fn(
                            input->shape(), input->dtype(),
                            /*use_fast_memory=*/false, layout_preference));
    TF_RETURN_IF_ERROR(xla_output->AllocateShapedBuffer(
        input->dtype(), shape, dst_xla_context->client(), dst_device_ordinal));

    VLOG(2) << "TpuDeviceToDeviceCopy: src: "
            << src_compute_stream->parent()->device_ordinal() << ", "
            << " dst: " << dst_compute_stream->parent()->device_ordinal()
            << ", "
            << " input buffers: " << xla_input->shaped_buffer().ToString()
            << " output buffers: " << xla_output->shaped_buffer().ToString();

    // Wait for definition event of the source tensor so the input buffers are
    // available.
    xla_input->WaitForDefinitionEventOnStream(dst_device_to_device_stream);

    // Wait for the destination tensor buffers to be ready, if they are not
    // available for an immediate write.
    if (!dst_xla_context->transfer_manager()->CanShapedBufferBeAccessedNow(
            dst_compute_stream->parent(), xla_output->shaped_buffer())) {
      dst_device_to_device_stream->ThenWaitFor(dst_compute_stream);
      // If the representation is a tuple, we also must wait for the tuple
      // index buffers to be available on the destination host-to-device
      // transfer stream.
      if (xla_output->shaped_buffer().on_device_shape().IsTuple()) {
        dst_xla_context->host_to_device_stream()->ThenWaitFor(
            dst_compute_stream);
      }
    }

    for (const auto& leaf : xla_input->shaped_buffer().buffers().leaves()) {
      const xla::ShapeIndex& index = leaf.first;
      const se::DeviceMemoryBase& input_buffer = leaf.second;
      const se::DeviceMemoryBase& output_buffer =
          xla_output->shaped_buffer().buffer(index);
      TF_RET_CHECK(input_buffer.size() == output_buffer.size())
          << "input: " << input_buffer.size()
          << " output: " << output_buffer.size();
      TF_RETURN_IF_ERROR(
          dst_device_to_device_stream_impl->EnqueueOnTpuDeviceSendRecvLocal(
              input_buffer, output_buffer));
    }

    // If the on-device shape is a tuple, write new tuple index buffers.
    if (xla_output->shaped_buffer().on_device_shape().IsTuple()) {
      TF_RETURN_IF_ERROR(
          dst_xla_context->transfer_manager()->WriteTupleIndexTablesAsync(
              dst_xla_context->host_to_device_stream(),
              xla_output->shaped_buffer()));

      // We need a single definition event for an XlaTensor, so make the
      // device-to-device stream wait for the stream that wrote the tuple
      // index tables on the destination device. Should this prove to be a
      // problem, we can always extend XlaTensor to take a pair of definition
      // events that must all be satisfied, or add an Event::Merge() API that
      // allows us to build an event that is triggered when all of its
      // dependencies are triggered.
      dst_device_to_device_stream->ThenWaitFor(
          dst_xla_context->host_to_device_stream());
    }

    auto definition_event =
        std::make_shared<se::Event>(dst_xla_context->stream()->parent());
    TF_RET_CHECK(definition_event->Init()) << "Event failed to initialize!";
    dst_device_to_device_stream->ThenRecordEvent(definition_event.get());
    xla_output->ResetDefinitionEvent(std::move(definition_event),
                                     dst_device_to_device_stream);

    // The input must remain alive until the transfer completes, so we keep a
    // reference. We also wait until the transfer completes before calling
    // done().
    // The latter may be too conservative, but given that the host is involved
    // in waiting for the transfer to complete anyway there is probably little
    // downside. If we were to add the ability for computations to wait
    // directly on transfers, then we might want to rethink this property.
    // Also, ideally this host callback would be on the source stream rather
    // than the destination stream, but since the send requests might not yet
    // be enqueued to the stream when this function returns, we put it on the
    // destination stream.
    TensorReference input_reference(*input);
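    // Disarm the cleanup above: the host callback below now owns returning the
    // substream to the master stream.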
    std::move(return_substream).release();
    dst_device_to_device_stream->ThenDoHostCallback(
        [input_reference, done = std::move(done),
         device_to_device_master_stream, dst_device_to_device_stream] {
          if (device_to_device_master_stream) {
            device_to_device_master_stream->ReturnSubStream(
                dst_device_to_device_stream);
          }
          input_reference.Unref();
          done(OkStatus());
        });

    return OkStatus();
  };
  Status status = impl();
  if (!status.ok()) {
    done(status);
  }
}

class TpuNodeDeviceFactory : public DeviceFactory {
 public:
  Status ListPhysicalDevices(std::vector<string>* devices) override;
  Status CreateDevices(const SessionOptions& options, const string& name_prefix,
                       std::vector<std::unique_ptr<Device>>* devices) override;
};

Status TpuNodeDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
  tpu::TpuPlatformInterface* platform =
      tpu::TpuPlatformInterface::GetRegisteredPlatform();
  if (platform == nullptr) {
    // If we don't have a platform registered, then we have no devices.
    return OkStatus();
  }

  int device_count = platform->VisibleDeviceCount();

  for (int i = 0; i < device_count; ++i) {
    const string device_name = absl::StrCat("/physical_device:TPU:", i);
    devices->push_back(device_name);
  }

  return OkStatus();
}

Status TpuNodeDeviceFactory::CreateDevices(
    const SessionOptions& session_options, const string& name_prefix,
    std::vector<std::unique_ptr<Device>>* devices) {
  tpu::TpuPlatformInterface* platform =
      tpu::TpuPlatformInterface::GetRegisteredPlatform();
  if (platform == nullptr) {
    // If we don't have a platform registered, then we should not create any.
    return OkStatus();
  }

  if (platform->ShouldRegisterTpuDeviceToDeviceCopy()) {
    RegisterTpuDeviceToDeviceCopy();
  }

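  // Tell the XLA bridge how ops placed on TPU_NODE devices are compiled: they
  // target the TPU_XLA_JIT compilation device, and auto-clustering is always
  // on when tpu_autoclustering_flag is set, otherwise only when explicitly
  // requested.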
  XlaOpRegistry::DeviceRegistration registration;
  registration.compilation_device_name = DEVICE_TPU_XLA_JIT;
  registration.autoclustering_policy =
      tpu_autoclustering_flag
          ? XlaOpRegistry::AutoclusteringPolicy::kAlways
          : XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested;

  registration.cluster_resource_variable_ops_unsafely = true;
  registration.cluster_stack_ops = false;
  registration.cluster_tensor_array_ops = true;
  registration.cluster_stateful_rng_ops = true;
  registration.cluster_control_trigger = true;
  registration.elide_assert_and_checknumerics = true;
  registration.cluster_variant_ops = true;
  registration.cluster_slow_ops = true;
  registration.cluster_inaccurate_ops = true;
  XlaOpRegistry::RegisterCompilationDevice(DEVICE_TPU_NODE, registration);

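  // Kernel registration is process-wide, so a function-local static ensures it
  // runs at most once; the (void) cast silences the unused-variable warning.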
  static XlaDeviceOpRegistrations* registrations =
      RegisterXlaDeviceKernels(DEVICE_TPU_NODE, DEVICE_TPU_XLA_JIT);
  (void)registrations;

  int device_count = platform->VisibleDeviceCount();
  VLOG(1) << "Creating " << device_count << " TPU devices";
  for (int i = 0; i < device_count; ++i) {
    TF_RETURN_IF_ERROR(tpu::TpuNodeContext::Initialize(i));

    XlaDevice::Options options;
    options.platform = platform;
    options.device_name_prefix = name_prefix;
    options.device_name = DEVICE_TPU_NODE;
    options.device_ordinal = i;
    options.compilation_device_name = DEVICE_TPU_XLA_JIT;
    options.use_multiple_streams = true;
    // We set `use_global_compute_stream` to true for TPUs as TPUs can only
    // have one program running on each core at the same time.
    options.use_global_compute_stream = true;
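    // Let XLA choose layouts with no particular preference, but store tensors
    // on device using the TPU-specific representation and padded-shape
    // computation defined above.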
    XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns{
        UseNoPreferenceLayoutFn(), &TpuShapeRepresentation};
    options.shape_determination_fns = {shape_determination_fns};
    options.padded_shape_fn = &TpuPaddedShapeFn;
    auto device = std::make_unique<XlaDevice>(session_options, options);

    // AcceleratorDeviceInfo provides information not only for GPU devices but
    // also for TPUs; the API is a legacy of the pre-TPU dark ages.
    Status status = device->UseAcceleratorDeviceInfo();
    if (!status.ok()) {
      errors::AppendToMessage(&status, "while setting up ", DEVICE_TPU_XLA_JIT,
                              " device number ", i);
      return status;
    }
    device->SetAllowsSyncOnCompletion(false);
    if (tpu_xla_device_failure_closes_chips_flag) {
      device->SetHandleDeviceErrorCallback(&tpu::TpuNodeContext::CloseTpuHost);
    }

    devices->push_back(std::move(device));
  }

  return OkStatus();
}

class TpuSystemDeviceFactory : public DeviceFactory {
 public:
  Status ListPhysicalDevices(std::vector<string>* devices) override;
  Status CreateDevices(const SessionOptions& options, const string& name_prefix,
                       std::vector<std::unique_ptr<Device>>* devices) override;
};

Status TpuSystemDeviceFactory::ListPhysicalDevices(
    std::vector<string>* devices) {
  int device_count = 0;
  TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpusPerHost(&device_count));
  if (device_count == 0) {
    VLOG(1) << "Host has no TPUs, not creating a TPU_SYSTEM device";
    return OkStatus();
  }

  devices->push_back("/physical_device:TPU_SYSTEM:0");

  return OkStatus();
}

Status TpuSystemDeviceFactory::CreateDevices(
    const SessionOptions& options, const string& name_prefix,
    std::vector<std::unique_ptr<Device>>* devices) {
  int device_count = 0;
  TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpusPerHost(&device_count));
  if (device_count == 0) {
    VLOG(1) << "Host has no TPUs, not creating a TPU_SYSTEM device";
    return OkStatus();
  }

  int64_t memory_limit;
  TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpuMemoryLimit(&memory_limit));

  // Creates a device that represents a TPU distributed system.
  const DeviceAttributes attrs = Device::BuildDeviceAttributes(
      absl::StrCat(name_prefix, "/device:", DEVICE_TPU_SYSTEM, ":", 0),
      DeviceType(DEVICE_TPU_SYSTEM), Bytes(memory_limit), DeviceLocality(),
      absl::StrCat("device: ", DEVICE_TPU_SYSTEM, " device"));
  devices->push_back(std::make_unique<VirtualDevice>(options.env, attrs));
  VLOG(1) << "Created TPU_SYSTEM device. This host has " << device_count
          << " TPUs";

  return OkStatus();
}

}  // namespace

void RegisterTpuDeviceToDeviceCopy() {
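  // Registering the TPU->TPU copy function is process-wide and only needs to
  // happen once; the (void) cast suppresses the unused-variable warning.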
  static auto* const register_tpu_tpu_copy = new CopyTensor::Registration(
      DEVICE_TPU_NODE, DEVICE_TPU_NODE, TpuDeviceToDeviceCopy);
  (void)register_tpu_tpu_copy;
}

void RegisterTpuNodeDevice(
    bool tpu_autoclustering, bool tpu_xla_device_failure_closes_chips,
    bool tpu_use_substreams_for_cross_tpu_device_transfers) {
  tpu_autoclustering_flag = tpu_autoclustering;
  tpu_xla_device_failure_closes_chips_flag =
      tpu_xla_device_failure_closes_chips;
  tpu_use_substreams_for_cross_tpu_device_transfers_flag =
      tpu_use_substreams_for_cross_tpu_device_transfers;

  REGISTER_XLA_LAUNCH_KERNEL(DEVICE_TPU_NODE, XlaLocalLaunchOp, kTpuAllTypes);
  REGISTER_XLA_COMPILE_KERNEL(DEVICE_TPU_NODE, XlaCompileOp, kTpuAllTypes);
  REGISTER_XLA_RUN_KERNEL(DEVICE_TPU_NODE, XlaRunOp, kTpuAllTypes);
  REGISTER_XLA_DEVICE_KERNELS(DEVICE_TPU_NODE, kTpuAllTypes);
  REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_NODE, TpuNodeDeviceFactory);
}

void RegisterTpuSystemDevice() {
  REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_SYSTEM, TpuSystemDeviceFactory);
}
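
// Example (a sketch, not prescriptive): builds that do not rely on the
// automatic registration below are expected to wire the devices up themselves
// with something along the lines of
//
//   tensorflow::RegisterTpuNodeDevice(
//       /*tpu_autoclustering=*/false,
//       /*tpu_xla_device_failure_closes_chips=*/true,
//       /*tpu_use_substreams_for_cross_tpu_device_transfers=*/true);
//   tensorflow::RegisterTpuSystemDevice();
//
// The argument values shown here simply mirror the defaults of the static
// flags at the top of this file.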

#if !defined(PLATFORM_GOOGLE)

// We automatically register these devices when building for open source. For
// Google platforms, we initialize these devices in other places.

REGISTER_XLA_LAUNCH_KERNEL(DEVICE_TPU_NODE, XlaLocalLaunchOp, kTpuAllTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_TPU_NODE, XlaCompileOp, kTpuAllTypes);
REGISTER_XLA_RUN_KERNEL(DEVICE_TPU_NODE, XlaRunOp, kTpuAllTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_TPU_NODE, kTpuAllTypes);
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_NODE, TpuNodeDeviceFactory);
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_SYSTEM, TpuSystemDeviceFactory);

#endif  // PLATFORM_GOOGLE

}  // namespace tensorflow