/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_tpu_device.h"

#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/layout_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_node_device_util.h"
#include "tensorflow/core/tpu/virtual_device.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_stream_interface.h"

namespace tensorflow {
namespace {

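// Process-wide flags controlling TPU device behavior; their values are set by
// RegisterTpuNodeDevice() below.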
static bool tpu_autoclustering_flag = false;
static bool tpu_xla_device_failure_closes_chips_flag = true;
static bool tpu_use_substreams_for_cross_tpu_device_transfers_flag = true;

// Given a tensor of `shape` and `type`, in what shape should it be stored on
// the TPU device? This function transposes or flattens the excessively-padded
// tensors to rank 1, but leaves other tensor shapes alone.
StatusOr<xla::Shape> TpuShapeRepresentation(
    const TensorShape& shape, DataType type, bool use_fast_memory,
    XlaLayoutPreference layout_preference) {
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(
      tensorflow::TensorShapeToXLAShape(type, shape, &xla_shape));
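  // Round-trip the shape through the TPU C API to obtain its on-device
  // representation.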
  ApiConverter::StackHelper<XLA_Shape> se_shape(xla_shape);
  ApiConverter::StackHelper<XLA_Shape> tpu_shape;
  StatusHelper status;
  tpu::ExecutorApiFn()->XlaShapeToTpuShapeRepresentationFn(
      &se_shape.value, type, use_fast_memory, &tpu_shape.value,
      status.c_status);
  if (!status.status().ok()) {
    return status.status();
  }
  return tpu_shape.AsCpp<xla::Shape>();
}

// Given a tensor, returns the shape of its representation on device,
// fully padded. Contents of `shape` are undefined on error.
Status TpuPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
  const tensorflow::XlaTensor* xla_tensor =
      tensorflow::XlaTensor::FromTensor(&tensor);
  if (xla_tensor == nullptr) {
    return errors::InvalidArgument(
        "Expected an XlaTensor when computing padded shape");
  }

  if (!xla_tensor->has_shaped_buffer()) {
    return errors::InvalidArgument(
        "XlaTensor is expected to have device memory allocated when "
        "computing padded shape");
  }

  const xla::Shape& on_device_shape =
      xla_tensor->shaped_buffer().on_device_shape();

  StatusHelper status;
  ApiConverter::StackHelper<XLA_Shape> se_shape(on_device_shape);
  ApiConverter::StackHelper<XLA_Shape> tpu_shape;
  tpu::ExecutorApiFn()->XlaShapeToTpuPaddedShapeFn(
      &se_shape.value, &tpu_shape.value, status.c_status);
  if (!status.ok()) {
    return status.status();
  }
  *shape = tpu_shape.AsCpp<xla::Shape>();
  return OkStatus();
}

// Check if TPU has been initialized. TPU initialization is not necessary
// for 1x1.
Status CheckIfTPUInitialized() {
  auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform();
  if (!tpu_platform->Initialized()) {
    return errors::FailedPrecondition(
        "The TPU system has not been initialized.");
  }
  return OkStatus();
}

// Implementation of TPU->TPU device copies that copy over the dedicated TPU
// interconnects, which is much faster than PCIe or the host network.
// TODO(b/117426293): This implementation is only called for direct interconnect
// transfers between TPU devices attached to the same host. Ideally, we would
// generalize this support to direct interconnect transfers across hosts, but
// currently the CopyTensor infrastructure seems to assume that the network
// topology is strictly hierarchical; that is, transfers between devices on
// different hosts can only take place using the host network.
void TpuDeviceToDeviceCopy(DeviceContext* src_dev_context,
                           DeviceContext* dst_dev_context, Device* src,
                           Device* dst, AllocatorAttributes src_allocator_attrs,
                           AllocatorAttributes dst_allocator_attrs,
                           const Tensor* input, Tensor* output,
                           int dev_to_dev_stream_index, StatusCallback done) {
  XlaDeviceContext* const src_xla_context =
      static_cast<XlaDeviceContext*>(src_dev_context);
  XlaDeviceContext* const dst_xla_context =
      static_cast<XlaDeviceContext*>(dst_dev_context);
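  // The flag value is captured in a function-local static the first time a
  // copy runs, so later changes to the flag do not affect this function.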
  static const bool should_use_substream =
      tpu_use_substreams_for_cross_tpu_device_transfers_flag;

  auto impl = [&]() -> Status {
    if (src->name() != dst->name()) {
      Status s = CheckIfTPUInitialized();
      if (!s.ok()) {
        done(s);
        return OkStatus();
      }
    }
    if (input->shape().num_elements() == 0) {
      // Zero-element tensors have no backing buffers.
      done(OkStatus());
      return OkStatus();
    }

    se::Stream* const src_compute_stream = src_xla_context->stream();
    TF_RET_CHECK(src_compute_stream != nullptr);
    TF_RET_CHECK(input->dtype() == output->dtype())
        << "input type: " << DataTypeString(input->dtype()) << " output type "
        << DataTypeString(output->dtype());
    TF_RET_CHECK(input->shape() == output->shape());
    TF_RET_CHECK(DMAHelper::CanUseDMA(input));
    auto* const src_compute_stream_impl = static_cast<tpu::TpuStreamInterface*>(
        src_compute_stream->implementation());

    se::Stream* dst_compute_stream = dst_xla_context->stream();
    auto* const dst_compute_stream_impl = static_cast<tpu::TpuStreamInterface*>(
        dst_compute_stream->implementation());

    if (src_compute_stream_impl->IsSameSharedMemoryLocation(
            dst_compute_stream_impl)) {
      // Surprisingly, this path does get triggered in practice.
      *output = *input;
      done(OkStatus());
      return OkStatus();
    }

    // To avoid stream exhaustion, we pick a substream from a pool if enabled.
    se::Stream* const device_to_device_master_stream =
        should_use_substream ? dst_xla_context->device_to_device_stream(0)
                             : nullptr;
    se::Stream* const dst_device_to_device_stream =
        should_use_substream
            ? device_to_device_master_stream->GetOrCreateSubStream()
            : dst_xla_context->GetDeviceToDeviceStream();
    TF_RET_CHECK(dst_device_to_device_stream != nullptr);
    auto return_substream = gtl::MakeCleanup(
        [device_to_device_master_stream, dst_device_to_device_stream] {
          if (device_to_device_master_stream) {
            device_to_device_master_stream->ReturnSubStream(
                dst_device_to_device_stream);
          }
        });

    auto* const dst_device_to_device_stream_impl =
        static_cast<tpu::TpuStreamInterface*>(
            dst_device_to_device_stream->implementation());

    const int dst_device_ordinal =
        dst_xla_context->stream()->parent()->device_ordinal();

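    // The source tensor must already be backed by device memory; the
    // destination tensor gets its device buffers allocated below.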
    XlaTensor* const xla_input = XlaTensor::FromTensor(input);
    TF_RET_CHECK(xla_input != nullptr && xla_input->has_shaped_buffer());
    XlaTensor* const xla_output = XlaTensor::FromTensor(output);
    TF_RET_CHECK(xla_output != nullptr && !xla_output->has_shaped_buffer());
    TF_RET_CHECK(input->shape() == output->shape());

    const auto& shape_determination_fns =
        dst_xla_context->shape_determination_fns();
    XlaLayoutPreference layout_preference =
        shape_determination_fns.layout_preference_fn(
            input->shape(), input->dtype(), std::nullopt);
    TF_ASSIGN_OR_RETURN(xla::Shape shape,
                        shape_determination_fns.shape_representation_fn(
                            input->shape(), input->dtype(),
                            /*use_fast_memory=*/false, layout_preference));
    TF_RETURN_IF_ERROR(xla_output->AllocateShapedBuffer(
        input->dtype(), shape, dst_xla_context->client(), dst_device_ordinal));

    VLOG(2) << "TpuDeviceToDeviceCopy: src: "
            << src_compute_stream->parent()->device_ordinal() << ", "
            << " dst: " << dst_compute_stream->parent()->device_ordinal()
            << ", "
            << " input buffers: " << xla_input->shaped_buffer().ToString()
            << " output buffers: " << xla_output->shaped_buffer().ToString();

    // Wait for definition event of the source tensor so the input buffers are
    // available.
    xla_input->WaitForDefinitionEventOnStream(dst_device_to_device_stream);

    // Wait for the destination tensor buffers to be ready, if they are not
    // available for an immediate write.
    if (!dst_xla_context->transfer_manager()->CanShapedBufferBeAccessedNow(
            dst_compute_stream->parent(), xla_output->shaped_buffer())) {
      dst_device_to_device_stream->ThenWaitFor(dst_compute_stream);
      // If the representation is a tuple, we also must wait for the tuple index
      // buffers to be available on the destination host to device transfer
      // stream.
      if (xla_output->shaped_buffer().on_device_shape().IsTuple()) {
        dst_xla_context->host_to_device_stream()->ThenWaitFor(
            dst_compute_stream);
      }
    }

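    // Enqueue one interconnect send/recv per leaf buffer of the (possibly
    // tuple-shaped) input.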
    for (const auto& leaf : xla_input->shaped_buffer().buffers().leaves()) {
      const xla::ShapeIndex& index = leaf.first;
      const se::DeviceMemoryBase& input_buffer = leaf.second;
      const se::DeviceMemoryBase& output_buffer =
          xla_output->shaped_buffer().buffer(index);
      TF_RET_CHECK(input_buffer.size() == output_buffer.size())
          << "input: " << input_buffer.size()
          << " output: " << output_buffer.size();
      TF_RETURN_IF_ERROR(
          dst_device_to_device_stream_impl->EnqueueOnTpuDeviceSendRecvLocal(
              input_buffer, output_buffer));
    }

    // If the on-device shape is a tuple, write new tuple index buffers.
    if (xla_output->shaped_buffer().on_device_shape().IsTuple()) {
      TF_RETURN_IF_ERROR(
          dst_xla_context->transfer_manager()->WriteTupleIndexTablesAsync(
              dst_xla_context->host_to_device_stream(),
              xla_output->shaped_buffer()));

      // We need a single definition event for an XlaTensor, so make the
      // device to device stream wait for the stream that wrote the tuple index
      // tables on the destination device. Should this prove to be a problem,
      // we can always extend XlaTensor to take a pair of definition events that
      // must all be satisfied, or add an Event::Merge() API that allows us to
      // build an event that is triggered when all of its dependencies are
      // triggered.
      dst_device_to_device_stream->ThenWaitFor(
          dst_xla_context->host_to_device_stream());
    }

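    // Record a definition event on the device-to-device stream so that later
    // consumers of the output tensor wait for the copy to complete.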
    auto definition_event =
        std::make_shared<se::Event>(dst_xla_context->stream()->parent());
    TF_RET_CHECK(definition_event->Init()) << "Event failed to initialize!";
    dst_device_to_device_stream->ThenRecordEvent(definition_event.get());
    xla_output->ResetDefinitionEvent(std::move(definition_event),
                                     dst_device_to_device_stream);

    // The input must remain alive until the transfer completes, so we keep a
    // reference. We also wait until the transfer completes before calling
    // done().
    // The latter may be too conservative, but given that the host is involved
    // in waiting for the transfer to complete anyway, there is probably little
    // downside. If we were to add the ability for computations to wait directly
    // on transfers, then we might want to rethink this property.
    // Also, ideally this host callback should be on the source stream rather
    // than the destination stream, but since the send requests might not yet
    // be enqueued to the stream when this function returns, we put it on the
    // destination stream.
    TensorReference input_reference(*input);
    std::move(return_substream).release();
    dst_device_to_device_stream->ThenDoHostCallback(
        [input_reference, done = std::move(done),
         device_to_device_master_stream, dst_device_to_device_stream] {
          if (device_to_device_master_stream) {
            device_to_device_master_stream->ReturnSubStream(
                dst_device_to_device_stream);
          }
          input_reference.Unref();
          done(OkStatus());
        });

    return OkStatus();
  };
  Status status = impl();
  if (!status.ok()) {
    done(status);
  }
}

class TpuNodeDeviceFactory : public DeviceFactory {
 public:
  Status ListPhysicalDevices(std::vector<string>* devices) override;
  Status CreateDevices(const SessionOptions& options, const string& name_prefix,
                       std::vector<std::unique_ptr<Device>>* devices) override;
};

Status TpuNodeDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
  tpu::TpuPlatformInterface* platform =
      tpu::TpuPlatformInterface::GetRegisteredPlatform();
  if (platform == nullptr) {
    // If we don't have a platform registered, then we have no devices.
    return OkStatus();
  }

  int device_count = platform->VisibleDeviceCount();

  for (int i = 0; i < device_count; ++i) {
    const string device_name = absl::StrCat("/physical_device:TPU:", i);
    devices->push_back(device_name);
  }

  return OkStatus();
}

Status TpuNodeDeviceFactory::CreateDevices(
    const SessionOptions& session_options, const string& name_prefix,
    std::vector<std::unique_ptr<Device>>* devices) {
  tpu::TpuPlatformInterface* platform =
      tpu::TpuPlatformInterface::GetRegisteredPlatform();
  if (platform == nullptr) {
    // If we don't have a platform registered, then we should not create any.
    return OkStatus();
  }

  if (platform != nullptr && platform->ShouldRegisterTpuDeviceToDeviceCopy()) {
    RegisterTpuDeviceToDeviceCopy();
  }

  XlaOpRegistry::DeviceRegistration registration;
  registration.compilation_device_name = DEVICE_TPU_XLA_JIT;
  registration.autoclustering_policy =
      tpu_autoclustering_flag
          ? XlaOpRegistry::AutoclusteringPolicy::kAlways
          : XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested;

  registration.cluster_resource_variable_ops_unsafely = true;
  registration.cluster_stack_ops = false;
  registration.cluster_tensor_array_ops = true;
  registration.cluster_stateful_rng_ops = true;
  registration.cluster_control_trigger = true;
  registration.elide_assert_and_checknumerics = true;
  registration.cluster_variant_ops = true;
  registration.cluster_slow_ops = true;
  registration.cluster_inaccurate_ops = true;
  XlaOpRegistry::RegisterCompilationDevice(DEVICE_TPU_NODE, registration);

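  // Kernel registration only needs to happen once per process; the local
  // static guarantees that, and the (void) cast below silences the
  // unused-variable warning.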
  static XlaDeviceOpRegistrations* registrations =
      RegisterXlaDeviceKernels(DEVICE_TPU_NODE, DEVICE_TPU_XLA_JIT);
  (void)registrations;

  int device_count = platform->VisibleDeviceCount();
  VLOG(1) << "Creating " << device_count << " TPU devices";
  for (int i = 0; i < device_count; ++i) {
    TF_RETURN_IF_ERROR(tpu::TpuNodeContext::Initialize(i));

    XlaDevice::Options options;
    options.platform = platform;
    options.device_name_prefix = name_prefix;
    options.device_name = DEVICE_TPU_NODE;
    options.device_ordinal = i;
    options.compilation_device_name = DEVICE_TPU_XLA_JIT;
    options.use_multiple_streams = true;
    // We set `use_global_compute_stream` to true for TPUs as TPUs can only
    // have one program running on each core at the same time.
    options.use_global_compute_stream = true;
    XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns{
        UseNoPreferenceLayoutFn(), &TpuShapeRepresentation};
    options.shape_determination_fns = {shape_determination_fns};
    options.padded_shape_fn = &TpuPaddedShapeFn;
    auto device = std::make_unique<XlaDevice>(session_options, options);

    // The AcceleratorDeviceInfo actually provides information not only for GPU
    // devices but also for TPU. The name is a legacy from the pre-TPU
    // dark ages.
    Status status = device->UseAcceleratorDeviceInfo();
    if (!status.ok()) {
      errors::AppendToMessage(&status, "while setting up ", DEVICE_TPU_XLA_JIT,
                              " device number ", i);
      return status;
    }
    device->SetAllowsSyncOnCompletion(false);
    if (tpu_xla_device_failure_closes_chips_flag) {
      device->SetHandleDeviceErrorCallback(&tpu::TpuNodeContext::CloseTpuHost);
    }

    devices->push_back(std::move(device));
  }

  return OkStatus();
}

class TpuSystemDeviceFactory : public DeviceFactory {
 public:
  Status ListPhysicalDevices(std::vector<string>* devices) override;
  Status CreateDevices(const SessionOptions& options, const string& name_prefix,
                       std::vector<std::unique_ptr<Device>>* devices) override;
};

Status TpuSystemDeviceFactory::ListPhysicalDevices(
    std::vector<string>* devices) {
  int device_count = 0;
  TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpusPerHost(&device_count));
  if (device_count == 0) {
    VLOG(1) << "Host has no TPUs, not creating a TPU_SYSTEM device";
    return OkStatus();
  }

  devices->push_back("/physical_device:TPU_SYSTEM:0");

  return OkStatus();
}

Status TpuSystemDeviceFactory::CreateDevices(
    const SessionOptions& options, const string& name_prefix,
    std::vector<std::unique_ptr<Device>>* devices) {
  int device_count = 0;
  TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpusPerHost(&device_count));
  if (device_count == 0) {
    VLOG(1) << "Host has no TPUs, not creating a TPU_SYSTEM device";
    return OkStatus();
  }

  int64_t memory_limit;
  TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpuMemoryLimit(&memory_limit));

  // Creates a device that represents a TPU distributed system.
  const DeviceAttributes attrs = Device::BuildDeviceAttributes(
      absl::StrCat(name_prefix, "/device:", DEVICE_TPU_SYSTEM, ":", 0),
      DeviceType(DEVICE_TPU_SYSTEM), Bytes(memory_limit), DeviceLocality(),
      absl::StrCat("device: ", DEVICE_TPU_SYSTEM, " device"));
  devices->push_back(std::make_unique<VirtualDevice>(options.env, attrs));
  VLOG(1) << "Created TPU_SYSTEM device. This host has " << device_count
          << " TPUs";

  return OkStatus();
}

}  // namespace

void RegisterTpuDeviceToDeviceCopy() {
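  // Register the TPU->TPU copy implementation exactly once per process.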
  static auto* const register_tpu_tpu_copy = new CopyTensor::Registration(
      DEVICE_TPU_NODE, DEVICE_TPU_NODE, TpuDeviceToDeviceCopy);
  (void)register_tpu_tpu_copy;
}

void RegisterTpuNodeDevice(
    bool tpu_autoclustering, bool tpu_xla_device_failure_closes_chips,
    bool tpu_use_substreams_for_cross_tpu_device_transfers) {
  tpu_autoclustering_flag = tpu_autoclustering;
  tpu_xla_device_failure_closes_chips_flag =
      tpu_xla_device_failure_closes_chips;
  tpu_use_substreams_for_cross_tpu_device_transfers_flag =
      tpu_use_substreams_for_cross_tpu_device_transfers;

  REGISTER_XLA_LAUNCH_KERNEL(DEVICE_TPU_NODE, XlaLocalLaunchOp, kTpuAllTypes);
  REGISTER_XLA_COMPILE_KERNEL(DEVICE_TPU_NODE, XlaCompileOp, kTpuAllTypes);
  REGISTER_XLA_RUN_KERNEL(DEVICE_TPU_NODE, XlaRunOp, kTpuAllTypes);
  REGISTER_XLA_DEVICE_KERNELS(DEVICE_TPU_NODE, kTpuAllTypes);
  REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_NODE, TpuNodeDeviceFactory);
}

void RegisterTpuSystemDevice() {
  REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_SYSTEM, TpuSystemDeviceFactory);
}

#if !defined(PLATFORM_GOOGLE)

// We automatically register these kernels and device factories when building
// for open source. For Google platforms, these devices are initialized in
// other places.

REGISTER_XLA_LAUNCH_KERNEL(DEVICE_TPU_NODE, XlaLocalLaunchOp, kTpuAllTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_TPU_NODE, XlaCompileOp, kTpuAllTypes);
REGISTER_XLA_RUN_KERNEL(DEVICE_TPU_NODE, XlaRunOp, kTpuAllTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_TPU_NODE, kTpuAllTypes);
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_NODE, TpuNodeDeviceFactory);
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_SYSTEM, TpuSystemDeviceFactory);

#endif  // PLATFORM_GOOGLE

}  // namespace tensorflow