/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/compiler/xrt/xrt_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace tensorflow {
namespace {

using tensorflow::tpu::CompilationCacheEntryRef;
using tensorflow::tpu::TpuCompilationCacheEntry;
using tensorflow::tpu::TpuCompilationCacheLookup;
using GetBufferFunction =
    std::function<xla::StatusOr<std::vector<xla::ExecutionInput>>()>;

// Looks up the input `key` in the compilation cache.
Status GetComputationCacheEntry(
    ResourceMgr* rm, int64_t key, int core_index_in_replica,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  profiler::TraceMe trace_me("XRTExecuteOp::LookupProto", /*level=*/2);
  TpuCompilationCacheLookup* proto_lookup;
  TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(),
                                tpu::kCompiledProtoCacheResourceName,
                                &proto_lookup));
  core::ScopedUnref lookup_unref(proto_lookup);
  TF_RETURN_IF_ERROR(proto_lookup->Lookup(key, core_index_in_replica, entry));
  return OkStatus();
}

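// Returns, for each input of `executable_proto`, whether its declared shape
// is dynamic (i.e. not fully static).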
std::vector<bool> GetDynamicInputInfo(
    const TPUExecutableInfoProto& executable_proto) {
  std::vector<bool> input_is_dynamic;
  input_is_dynamic.reserve(executable_proto.input_shapes().size());
  for (int64_t i = 0; i < executable_proto.input_shapes().size(); ++i) {
    input_is_dynamic.push_back(
        !xla::Shape(executable_proto.input_shapes(i)).is_static());
  }
  return input_is_dynamic;
}

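// Resolves the input allocations for a chained-execution op, materializing a
// sub-buffer when an op input refers to a sub-index of a previous result, and
// validates each input shape against the executable's expected input shapes.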
xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetChainedOpInputs(
    const xrt::XRTChainedExecuteOp& op,
    absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs,
    const TPUExecutableInfoProto& executable_proto) {
  if (op.inputs_size() != executable_proto.input_shapes_size()) {
    return errors::InvalidArgument(
        "Number of inputs does not match executable proto input shapes: ",
        op.inputs_size(), " vs. ", executable_proto.input_shapes_size());
  }

  std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
  input_tuples.reserve(op.inputs_size());
  for (int i = 0; i < op.inputs_size(); ++i) {
    auto& input = op.inputs(i);
    const RefPtr<XRTTupleAllocation>& tuple = op_inputs[i];
    // Proto3 provides no way to query whether a field was explicitly set, so
    // the default value for output_index (zero) means "no sub-index". As a
    // consequence, the real index is output_index - 1.
    if (input.output_index() == 0) {
      input_tuples.push_back(tuple);
    } else {
      XRTTupleAllocation* sub_tuple;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          tuple.get(), {input.output_index() - 1}, &sub_tuple,
          /*alias_parent_allocation=*/true));
      input_tuples.emplace_back(sub_tuple);
    }
    if (!InputShapeMatches(xla::Shape(executable_proto.input_shapes(i)),
                           input_tuples.back()->on_host_shape())) {
      return errors::InvalidArgument(
          "Run-time shape mismatch for XRTExecute argument[", i, "] (",
          op.computation_handle(), "). Expected ",
          executable_proto.input_shapes(i).DebugString(), "; got ",
          tuple->on_host_shape().DebugString());
    }
  }
  return std::move(input_tuples);
}

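// Builds the HLO input/output alias configuration for the executable selected
// by `core_index`, using the device shape of its output.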
xla::StatusOr<xla::HloInputOutputAliasConfig> GetExecutableAliasConfig(
    const tpu::TpuProgramGroup* tpu_program_group, xla::Backend* const backend,
    int core_index) {
  const TPUExecutableInfoProto& executable =
      tpu_program_group->executable_info(core_index);
  return xla::HloInputOutputAliasConfig::CreateFromProto(
      backend->transfer_manager()->HostShapeToDeviceShape(
          xla::Shape(executable.output_shape())),
      tpu_program_group->hlo_metadata(core_index)
          ->hlo_module()
          .input_output_alias());
}

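// Wraps the execution result into an XRTTupleAllocation, reading back dynamic
// output shapes if needed and re-aliasing any buffers donated by the inputs.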
xla::StatusOr<RefPtr<XRTTupleAllocation>> AllocateOutputTuple(
    tpu::TpuNodeContext* node_context, se::Stream* stream,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    xla::ScopedShapedBuffer output_scoped_buffer, int device_ordinal) {
  auto output_shaped_buffer = output_scoped_buffer.release();

  xla::Shape output_device_shape = output_shaped_buffer.on_device_shape();
  if (!output_device_shape.is_static()) {
    TF_RETURN_IF_ERROR(
        node_context->backend()->transfer_manager()->ReadDynamicShapes(
            stream, &output_shaped_buffer, &output_device_shape));
  }

  XRTTupleAllocation* output_tuple;
  xla::Shape output_host_shape =
      xla::ShapeUtil::DeviceShapeToHostShape(output_device_shape);

  TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
      output_shaped_buffer, output_host_shape, output_device_shape,
      node_context->backend(), device_ordinal, &output_tuple,
      node_context->backend()->memory_allocator()));
  RefPtr<XRTTupleAllocation> output_tuple_ptr(output_tuple);

  // If the input tuples had to release some buffers in order to provide the
  // proper temporary ownership transfer, we patch the holes here by aliasing
  // the buffers from the result tuple. The device addresses we patch back here
  // are essentially the same ones we carved out in the DoWork() function.
  TF_RETURN_IF_ERROR(
      RebuildOutputAliases(output_tuple_ptr, input_tuples, input_output_alias));

  return std::move(output_tuple_ptr);
}

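// Allocates the output tuple for the execution result and populates the op's
// output tensor with its handle (or exploded handles, per the config).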
Status AllocateOutputTensors(
    OpKernelContext* context, XRTMemoryManager* memory_manager,
    tpu::TpuNodeContext* node_context, se::Stream* stream,
    const xrt::XRTExecutionConfig& config_proto,
    const TPUExecutableInfoProto& executable_proto,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    xla::ScopedShapedBuffer output_scoped_buffer, int device_ordinal) {
  TF_ASSIGN_OR_RETURN(
      RefPtr<XRTTupleAllocation> output_tuple,
      AllocateOutputTuple(node_context, stream, input_tuples,
                          input_output_alias, std::move(output_scoped_buffer),
                          device_ordinal));
  return CreateExecuteOutput(context, memory_manager, std::move(output_tuple),
                             config_proto.return_exploded_tuple());
}

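// Runs `executable` on the TPU with the given arguments and blocks until the
// execution stream has completed.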
xla::StatusOr<xla::ExecutionOutput> RunExecutable(
    OpKernelContext* context, tpu::TpuNodeContext* node_context,
    const TPUExecutableInfoProto& executable,
    std::vector<xla::ExecutionInput> arguments, const string& execution_id,
    const uint32 rng_seed, const tpu::TpuProgramGroup* tpu_program_group,
    xla::Backend* const backend, se::Stream* stream, int core_index,
    int device_ordinal, string rendezvous_key_base) {
  profiler::TraceMe trace_me("RunExecutable", /*level=*/2);

  std::unique_ptr<xla::DeviceAssignment> device_assignment;
  if (executable.has_device_assignment()) {
    TF_ASSIGN_OR_RETURN(device_assignment, xla::DeviceAssignment::Deserialize(
                                               executable.device_assignment()));
  }
  // Ideally this should be the host-to-device stream from XlaDeviceContext.
  // The particular anti-dependency this is avoiding (the reason we need a
  // separate transfer stream) is between the executable writing tuple tables
  // and TPUExecute()'s deregister_stream; if they come from the same stream
  // pool, anti-dependencies will occur. XlaBackend has a different pool of
  // streams from the stream->GetOrCreateSubStream() that TPUExecute() uses,
  // so these will never refer to the same stream.
  TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
                      backend->BorrowStream(device_ordinal));
  const TPUHostTransferInfoProto& host_transfer_info =
      tpu_program_group->host_transfer_info(core_index);
  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      TPUExecute(executable, host_transfer_info,
                 *tpu_program_group->hlo_metadata(core_index),
                 std::move(arguments), rendezvous_key_base, rng_seed,
                 node_context, device_assignment.get(),
                 context->cancellation_manager(), context, stream,
                 transfer_stream_ptr.get(),
                 tpu_program_group->tpu_program(core_index)));

  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());

  return output;
}

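// Materializes the argument buffers via `get_buffers_fn` and runs the
// executable, with device memory managed through the XRTMemoryManager.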
xla::StatusOr<xla::ExecutionOutput> ExecuteTPUProgram(
    OpKernelContext* context, tpu::TpuNodeContext* node_context,
    XRTMemoryManager* memory_manager, const TPUExecutableInfoProto& executable,
    const GetBufferFunction& get_buffers_fn, const string& execution_id,
    const uint32 rng_seed, const tpu::TpuProgramGroup* tpu_program_group,
    xla::Backend* const backend, se::Stream* stream, int core_index,
    int device_ordinal, string rendezvous_key_base) {
  auto runfn = [&]() -> xla::StatusOr<xla::ExecutionOutput> {
    TF_ASSIGN_OR_RETURN(auto arguments, get_buffers_fn());
    return RunExecutable(context, node_context, executable,
                         std::move(arguments), execution_id, rng_seed,
                         tpu_program_group, backend, stream, core_index,
                         device_ordinal, rendezvous_key_base);
  };
  return memory_manager->Run<xla::ExecutionOutput>(
      runfn, backend, device_ordinal, /*requested_free_size=*/0,
      backend->memory_allocator());
}

// XRTExecuteOp

class XRTExecuteOp : public AsyncOpKernel {
 public:
  explicit XRTExecuteOp(OpKernelConstruction* context);

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;

 private:
  Status DoWork(OpKernelContext* context);
};

XRTExecuteOp::XRTExecuteOp(OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {}

void XRTExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  // Schedule onto the default queue, for unbounded concurrency. See b/73520706
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

Status XRTExecuteOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "XRTExecuteOp::Compute";

  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();
  // We are guaranteed that the object underlying TpuNodeContext won't be
  // deleted out from under us while node_context is alive.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_context,
                      tpu::TpuNodeContext::Create(device_ordinal));
  xla::Backend* const backend = node_context->backend();
  se::Stream* stream = context->op_device_context()->stream();

  auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteCell());
  profiler::TraceMe trace_me(
      [context] {
        return profiler::TraceMeEncode("TpuExecuteOp",
                                       {{"step_id", context->step_id()}});
      },
      /*level=*/2);
  profiler::TraceMe trace_me_init("XRTExecuteOp::Init", /*level=*/2);

  auto* rm = GetTPUConfigResourceMgr();
  TF_RET_CHECK(rm != nullptr);

  const Tensor& execution_input = context->input(0);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_input.shape()));
  int64_t compilation_handle = execution_input.scalar<int64_t>()();

  const Tensor& execution_config = context->input(1);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
  xrt::XRTExecutionConfig config_proto;
  TF_RET_CHECK(
      config_proto.ParseFromString(execution_config.scalar<tstring>()()));

  int core_index_in_replica = config_proto.core_index_in_replica();
  bool release_inputs = config_proto.release_input_handles();
  bool release_compilation = config_proto.release_compilation_handle();

  string rendezvous_key_base = std::to_string(compilation_handle);
  std::unique_ptr<CompilationCacheEntryRef> entry;
  TF_RETURN_IF_ERROR(GetComputationCacheEntry(rm, compilation_handle,
                                              core_index_in_replica, &entry));

  TpuCompilationCacheEntry centry = entry->get();
  const tpu::TpuProgramGroup* tpu_program_group =
      tensorflow::down_cast<const tpu::TpuProgramGroup*>(
          centry.tpu_program_group());
  CHECK_NE(tpu_program_group, nullptr);

  if (release_compilation) {
    // Process-wide cache of TPU executables.
    tpu::TpuCompilationCacheInterface* cache;
    TF_RETURN_IF_ERROR(rm->Lookup<tpu::TpuCompilationCacheInterface>(
        rm->default_container(), tpu::kCompilationCacheResourceName, &cache));
    core::ScopedUnref cache_unref(cache);
    TF_RETURN_IF_ERROR(cache->Release(compilation_handle));
    VLOG(2) << "Released compilation handle " << compilation_handle;
  }

  const int core_index = centry.core_index();
  const TPUExecutableInfoProto& executable =
      tpu_program_group->executable_info(core_index);

  std::vector<bool> input_is_dynamic = GetDynamicInputInfo(executable);

  TF_ASSIGN_OR_RETURN(
      xla::HloInputOutputAliasConfig input_output_alias,
      GetExecutableAliasConfig(tpu_program_group, backend, core_index));
  TF_ASSIGN_OR_RETURN(std::vector<InputCoords> input_coords,
                      GetComputationInputs(context, "input_handles"));

  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
  XRTMemoryManager::WorkingSet working_set(memory_manager);
  TF_ASSIGN_OR_RETURN(
      std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
      GetInputTupleAllocations(
          input_coords, &working_set, backend, executable.input_shapes_size(),
          [&](int64_t i) { return xla::Shape(executable.input_shapes(i)); },
          release_inputs, backend->memory_allocator()));
  auto get_buffers_fn = [&]() {
    return GetArgumentsBuffers(input_output_alias, input_tuples,
                               input_is_dynamic, release_inputs);
  };
  trace_me_init.Stop();

  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      ExecuteTPUProgram(
          context, node_context.get(), memory_manager.get(), executable,
          get_buffers_fn, config_proto.execution_instance_key(),
          config_proto.rng_seed(), tpu_program_group, backend, stream,
          core_index, device_ordinal, rendezvous_key_base));

  // AllocateOutputTensors writes the output tuple handle to the output tensor
  // returned by the op.
  TF_RETURN_IF_ERROR(AllocateOutputTensors(
      context, memory_manager.get(), node_context.get(), stream, config_proto,
      executable, input_tuples, input_output_alias, output.ConsumeResult(),
      device_ordinal));
  return OkStatus();
}

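// XRTExecuteChainedOp
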
class XRTExecuteChainedOp : public AsyncOpKernel {
 public:
  explicit XRTExecuteChainedOp(OpKernelConstruction* context);

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;

 private:
  Status DoWork(OpKernelContext* context);
};

XRTExecuteChainedOp::XRTExecuteChainedOp(OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {}

void XRTExecuteChainedOp::ComputeAsync(OpKernelContext* context,
                                       DoneCallback done) {
  // Schedule onto the default queue, for unbounded concurrency. See b/73520706
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "XRTExecuteChainedOp::Compute";
  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();
  // We are guaranteed that the object underlying TpuNodeContext won't be
  // deleted out from under us while node_context is alive.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_context,
                      tpu::TpuNodeContext::Create(device_ordinal));
  xla::Backend* const backend = node_context->backend();
  se::Stream* stream = context->op_device_context()->stream();
  auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteChainedCell());
  profiler::TraceMe trace_me(
      [context] {
        return profiler::TraceMeEncode("TpuExecuteChainedOp",
                                       {{"step_id", context->step_id()}});
      },
      /*level=*/2);
  ResourceMgr* rm = GetTPUConfigResourceMgr();
  TF_RET_CHECK(rm != nullptr);

  const Tensor& execution_plan = context->input(0);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_plan.shape()));
  xrt::XRTChainedExecutePlan plan;
  TF_RET_CHECK(plan.ParseFromString(execution_plan.scalar<tstring>()()));

  const Tensor& execution_config = context->input(1);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
  xrt::XRTChainedExecuteConfig config;
  TF_RET_CHECK(config.ParseFromString(execution_config.scalar<tstring>()()));

  TpuCompilationCacheLookup* proto_lookup;
  TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(),
                                tpu::kCompiledProtoCacheResourceName,
                                &proto_lookup));
  core::ScopedUnref lookup_unref(proto_lookup);
  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
  auto execute_op = [&](const xrt::XRTChainedExecuteOp& op,
                        absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs)
      -> xla::StatusOr<RefPtr<XRTTupleAllocation>> {
    std::unique_ptr<CompilationCacheEntryRef> entry;
    TF_RETURN_IF_ERROR(proto_lookup->Lookup(
        op.computation_handle(), config.core_index_in_replica(), &entry));
    string rendezvous_key_base = std::to_string(op.computation_handle());
    TpuCompilationCacheEntry centry = entry->get();
    const tpu::TpuProgramGroup* tpu_program_group =
        tensorflow::down_cast<const tpu::TpuProgramGroup*>(
            centry.tpu_program_group());
    CHECK_NE(tpu_program_group, nullptr);
    const int core_index = centry.core_index();
    const TPUExecutableInfoProto& executable =
        tpu_program_group->executable_info(core_index);
    std::vector<bool> input_is_dynamic = GetDynamicInputInfo(executable);

    TF_ASSIGN_OR_RETURN(
        xla::HloInputOutputAliasConfig input_output_alias,
        GetExecutableAliasConfig(tpu_program_group, backend, core_index));
    TF_ASSIGN_OR_RETURN(std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
                        GetChainedOpInputs(op, op_inputs, executable));
    auto get_buffers_fn = [&]() {
      return GetArgumentsBuffers(input_output_alias, input_tuples,
                                 input_is_dynamic,
                                 /*release_inputs=*/false);
    };
    TF_ASSIGN_OR_RETURN(
        xla::ExecutionOutput output,
        ExecuteTPUProgram(context, node_context.get(), memory_manager.get(),
                          executable, get_buffers_fn,
                          config.execution_instance_key(), config.rng_seed(),
                          tpu_program_group, backend, stream, core_index,
                          device_ordinal, rendezvous_key_base));
    return AllocateOutputTuple(node_context.get(), stream, input_tuples,
                               input_output_alias, output.ConsumeResult(),
                               device_ordinal);
  };

  return ExecuteChained(context, memory_manager, backend, device_ordinal, plan,
                        config, execute_op, backend->memory_allocator());
}

}  // namespace

REGISTER_KERNEL_BUILDER(Name("XRTExecute")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("computation_handle")
                            .HostMemory("execution_config")
                            .HostMemory("input_handles")
                            .HostMemory("output_handle"),
                        XRTExecuteOp);

REGISTER_KERNEL_BUILDER(Name("XRTExecuteChained")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("execution_plan")
                            .HostMemory("execution_config")
                            .HostMemory("output_handle"),
                        XRTExecuteChainedOp);

}  // namespace tensorflow