/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/compiler/xrt/xrt_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace tensorflow {
namespace {

using tensorflow::tpu::CompilationCacheEntryRef;
using tensorflow::tpu::TpuCompilationCacheEntry;
using tensorflow::tpu::TpuCompilationCacheLookup;
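// Produces the xla::ExecutionInput arguments for a single program execution.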
using GetBufferFunction =
    std::function<xla::StatusOr<std::vector<xla::ExecutionInput>>()>;

// Looks up the input `key` in the compilation cache.
Status GetComputationCacheEntry(
    ResourceMgr* rm, int64_t key, int core_index_in_replica,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  profiler::TraceMe trace_me("XRTExecuteOp::LookupProto", /*level=*/2);
  TpuCompilationCacheLookup* proto_lookup;
  TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(),
                                tpu::kCompiledProtoCacheResourceName,
                                &proto_lookup));
  core::ScopedUnref lookup_unref(proto_lookup);
  TF_RETURN_IF_ERROR(proto_lookup->Lookup(key, core_index_in_replica, entry));
  return OkStatus();
}

std::vector<bool> GetDynamicInputInfo(
    const TPUExecutableInfoProto& executable_proto) {
  std::vector<bool> input_is_dynamic;
  input_is_dynamic.reserve(executable_proto.input_shapes().size());
  for (int64_t i = 0; i < executable_proto.input_shapes().size(); ++i) {
    input_is_dynamic.push_back(
        !xla::Shape(executable_proto.input_shapes(i)).is_static());
  }
  return input_is_dynamic;
}

xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetChainedOpInputs(
    const xrt::XRTChainedExecuteOp& op,
    absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs,
    const TPUExecutableInfoProto& executable_proto) {
  if (op.inputs_size() != executable_proto.input_shapes_size()) {
    return errors::InvalidArgument(
        "Number of inputs does not match executable proto input shapes: ",
        op.inputs_size(), " vs. ", executable_proto.input_shapes_size());
  }

  std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
  input_tuples.reserve(op.inputs_size());
  for (int i = 0; i < op.inputs_size(); ++i) {
    auto& input = op.inputs(i);
    const RefPtr<XRTTupleAllocation>& tuple = op_inputs[i];
    // Proto3 offers no way to query whether a field was explicitly set, so the
    // default value of output_index (zero) is reserved to mean "no sub-index".
    // As a consequence, the real tuple index is output_index - 1.
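    // For example, an input with output_index == 2 selects element 1 of the
    // producing op's output tuple, while output_index == 0 passes the whole
    // tuple through unchanged.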
    if (input.output_index() == 0) {
      input_tuples.push_back(tuple);
    } else {
      XRTTupleAllocation* sub_tuple;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          tuple.get(), {input.output_index() - 1}, &sub_tuple,
          /*alias_parent_allocation=*/true));
      input_tuples.emplace_back(sub_tuple);
    }
    if (!InputShapeMatches(xla::Shape(executable_proto.input_shapes(i)),
                           input_tuples.back()->on_host_shape())) {
      return errors::InvalidArgument(
          "Run-time shape mismatch for XRTExecute argument[", i, "] (",
          op.computation_handle(), "). Expected ",
          executable_proto.input_shapes(i).DebugString(), "; got ",
          tuple->on_host_shape().DebugString());
    }
  }
  return std::move(input_tuples);
}

xla::StatusOr<xla::HloInputOutputAliasConfig> GetExecutableAliasConfig(
    const tpu::TpuProgramGroup* tpu_program_group, xla::Backend* const backend,
    int core_index) {
  const TPUExecutableInfoProto& executable =
      tpu_program_group->executable_info(core_index);
  return xla::HloInputOutputAliasConfig::CreateFromProto(
      backend->transfer_manager()->HostShapeToDeviceShape(
          xla::Shape(executable.output_shape())),
      tpu_program_group->hlo_metadata(core_index)
          ->hlo_module()
          .input_output_alias());
}

xla::StatusOr<RefPtr<XRTTupleAllocation>> AllocateOutputTuple(
    tpu::TpuNodeContext* node_context, se::Stream* stream,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    xla::ScopedShapedBuffer output_scoped_buffer, int device_ordinal) {
  auto output_shaped_buffer = output_scoped_buffer.release();

  xla::Shape output_device_shape = output_shaped_buffer.on_device_shape();
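  // For executables with dynamic output shapes, the actual dimension sizes are
  // only known after execution; read them back from the device buffers.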
  if (!output_device_shape.is_static()) {
    TF_RETURN_IF_ERROR(
        node_context->backend()->transfer_manager()->ReadDynamicShapes(
            stream, &output_shaped_buffer, &output_device_shape));
  }

  XRTTupleAllocation* output_tuple;
  xla::Shape output_host_shape =
      xla::ShapeUtil::DeviceShapeToHostShape(output_device_shape);

  TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
      output_shaped_buffer, output_host_shape, output_device_shape,
      node_context->backend(), device_ordinal, &output_tuple,
      node_context->backend()->memory_allocator()));
  RefPtr<XRTTupleAllocation> output_tuple_ptr(output_tuple);

  // If the input tuples had to release some buffers in order to provide the
  // proper temporary ownership transfer, we patch the holes here by aliasing
  // the buffers from the result tuple. The device addresses we patch back here
  // are essentially the same ones we carved out in the DoWork() function.
  TF_RETURN_IF_ERROR(
      RebuildOutputAliases(output_tuple_ptr, input_tuples, input_output_alias));

  return std::move(output_tuple_ptr);
}

Status AllocateOutputTensors(
    OpKernelContext* context, XRTMemoryManager* memory_manager,
    tpu::TpuNodeContext* node_context, se::Stream* stream,
    const xrt::XRTExecutionConfig& config_proto,
    const TPUExecutableInfoProto& executable_proto,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    xla::ScopedShapedBuffer output_scoped_buffer, int device_ordinal) {
  TF_ASSIGN_OR_RETURN(
      RefPtr<XRTTupleAllocation> output_tuple,
      AllocateOutputTuple(node_context, stream, input_tuples,
                          input_output_alias, std::move(output_scoped_buffer),
                          device_ordinal));
  return CreateExecuteOutput(context, memory_manager, std::move(output_tuple),
                             config_proto.return_exploded_tuple());
}

xla::StatusOr<xla::ExecutionOutput> RunExecutable(
    OpKernelContext* context, tpu::TpuNodeContext* node_context,
    const TPUExecutableInfoProto& executable,
    std::vector<xla::ExecutionInput> arguments, const string& execution_id,
    const uint32 rng_seed, const tpu::TpuProgramGroup* tpu_program_group,
    xla::Backend* const backend, se::Stream* stream, int core_index,
    int device_ordinal, string rendezvous_key_base) {
  profiler::TraceMe trace_me("RunExecutable", /*level=*/2);

  // se::StreamExecutor* executor = node->stream_executor();

  std::unique_ptr<xla::DeviceAssignment> device_assignment;
  if (executable.has_device_assignment()) {
    TF_ASSIGN_OR_RETURN(device_assignment, xla::DeviceAssignment::Deserialize(
                                               executable.device_assignment()));
  }
  // Ideally this should be the host-to-device stream from XlaDeviceContext.
  // The particular anti-dependency this is avoiding (why we need a separate
  // transfer stream) is between the executable writing tuple tables and
  // TPUExecute()'s deregister_stream; if they come from the same stream pool,
  // anti-dependencies will occur. The XLA backend draws from a different
  // stream pool than the stream->GetOrCreateSubStream() calls that
  // TPUExecute() makes, so the two will never refer to the same stream.
  TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
                      backend->BorrowStream(device_ordinal));
  const TPUHostTransferInfoProto& host_transfer_info =
      tpu_program_group->host_transfer_info(core_index);
  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      TPUExecute(executable, host_transfer_info,
                 *tpu_program_group->hlo_metadata(core_index),
                 std::move(arguments), rendezvous_key_base, rng_seed,
                 node_context, device_assignment.get(),
                 context->cancellation_manager(), context, stream,
                 transfer_stream_ptr.get(),
                 tpu_program_group->tpu_program(core_index)));

  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());

  return output;
}

xla::StatusOr<xla::ExecutionOutput> ExecuteTPUProgram(
    OpKernelContext* context, tpu::TpuNodeContext* node_context,
    XRTMemoryManager* memory_manager, const TPUExecutableInfoProto& executable,
    const GetBufferFunction& get_buffers_fn, const string& execution_id,
    const uint32 rng_seed, const tpu::TpuProgramGroup* tpu_program_group,
    xla::Backend* const backend, se::Stream* stream, int core_index,
    int device_ordinal, string rendezvous_key_base) {
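  // XRTMemoryManager::Run() below may free device memory and retry runfn if
  // the first attempt fails under memory pressure, so both argument
  // materialization and execution happen inside the retry scope.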
  auto runfn = [&]() -> xla::StatusOr<xla::ExecutionOutput> {
    TF_ASSIGN_OR_RETURN(auto arguments, get_buffers_fn());
    return RunExecutable(context, node_context, executable,
                         std::move(arguments), execution_id, rng_seed,
                         tpu_program_group, backend, stream, core_index,
                         device_ordinal, rendezvous_key_base);
  };
  return memory_manager->Run<xla::ExecutionOutput>(
      runfn, backend, device_ordinal, /*requested_free_size=*/0,
      backend->memory_allocator());
}

// XRTExecuteOp

class XRTExecuteOp : public AsyncOpKernel {
 public:
  explicit XRTExecuteOp(OpKernelConstruction* context);

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;

 private:
  Status DoWork(OpKernelContext* context);
};

XRTExecuteOp::XRTExecuteOp(OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {}

void XRTExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  // Schedule onto the default queue, for unbounded concurrency. See b/73520706.
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

Status XRTExecuteOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "XRTExecuteOp::Compute";

  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();
  // We are guaranteed that the object underlying TpuNodeContext won't be
  // deleted out from under us while node_context is alive.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_context,
                      tpu::TpuNodeContext::Create(device_ordinal));
  xla::Backend* const backend = node_context->backend();
  se::Stream* stream = context->op_device_context()->stream();

  auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteCell());
  profiler::TraceMe trace_me(
      [context] {
        return profiler::TraceMeEncode("TpuExecuteOp",
                                       {{"step_id", context->step_id()}});
      },
      /*level=*/2);
  profiler::TraceMe trace_me_init("XRTExecuteOp::Init", /*level=*/2);

  auto* rm = GetTPUConfigResourceMgr();
  TF_RET_CHECK(rm != nullptr);

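  // Input 0 holds the compilation cache handle (scalar int64); input 1 holds
  // the serialized xrt::XRTExecutionConfig proto controlling this execution.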
  const Tensor& execution_input = context->input(0);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_input.shape()));
  int64_t compilation_handle = execution_input.scalar<int64_t>()();

  const Tensor& execution_config = context->input(1);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
  xrt::XRTExecutionConfig config_proto;
  TF_RET_CHECK(
      config_proto.ParseFromString(execution_config.scalar<tstring>()()));

  int core_index_in_replica = config_proto.core_index_in_replica();
  bool release_inputs = config_proto.release_input_handles();
  bool release_compilation = config_proto.release_compilation_handle();

  string rendezvous_key_base = std::to_string(compilation_handle);
  std::unique_ptr<CompilationCacheEntryRef> entry;
  TF_RETURN_IF_ERROR(GetComputationCacheEntry(rm, compilation_handle,
                                              core_index_in_replica, &entry));

  TpuCompilationCacheEntry centry = entry->get();
  const tpu::TpuProgramGroup* tpu_program_group =
      tensorflow::down_cast<const tpu::TpuProgramGroup*>(
          centry.tpu_program_group());
  CHECK_NE(tpu_program_group, nullptr);

  if (release_compilation) {
    // Process-wide cache of TPU executables.
    tpu::TpuCompilationCacheInterface* cache;
    TF_RETURN_IF_ERROR(rm->Lookup<tpu::TpuCompilationCacheInterface>(
        rm->default_container(), tpu::kCompilationCacheResourceName, &cache));
    core::ScopedUnref cache_unref(cache);
    TF_RETURN_IF_ERROR(cache->Release(compilation_handle));
    VLOG(2) << "Released compilation handle " << compilation_handle;
  }

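  // The cache entry's core index selects which per-core program within the
  // compiled TpuProgramGroup this op instance runs.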
  const int core_index = centry.core_index();
  const TPUExecutableInfoProto& executable =
      tpu_program_group->executable_info(core_index);

  std::vector<bool> input_is_dynamic = GetDynamicInputInfo(executable);

  TF_ASSIGN_OR_RETURN(
      xla::HloInputOutputAliasConfig input_output_alias,
      GetExecutableAliasConfig(tpu_program_group, backend, core_index));
  TF_ASSIGN_OR_RETURN(std::vector<InputCoords> input_coords,
                      GetComputationInputs(context, "input_handles"));

  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
  XRTMemoryManager::WorkingSet working_set(memory_manager);
  TF_ASSIGN_OR_RETURN(
      std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
      GetInputTupleAllocations(
          input_coords, &working_set, backend, executable.input_shapes_size(),
          [&](int64_t i) { return xla::Shape(executable.input_shapes(i)); },
          release_inputs, backend->memory_allocator()));
  auto get_buffers_fn = [&]() {
    return GetArgumentsBuffers(input_output_alias, input_tuples,
                               input_is_dynamic, release_inputs);
  };
  trace_me_init.Stop();

  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      ExecuteTPUProgram(
          context, node_context.get(), memory_manager.get(), executable,
          get_buffers_fn, config_proto.execution_instance_key(),
          config_proto.rng_seed(), tpu_program_group, backend, stream,
          core_index, device_ordinal, rendezvous_key_base));

  // AllocateOutputTensors writes the output tuple handle to the op's
  // output_handle tensor.
  TF_RETURN_IF_ERROR(AllocateOutputTensors(
      context, memory_manager.get(), node_context.get(), stream, config_proto,
      executable, input_tuples, input_output_alias, output.ConsumeResult(),
      device_ordinal));
  return OkStatus();
}

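// XRTExecuteChainedOp
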
class XRTExecuteChainedOp : public AsyncOpKernel {
 public:
  explicit XRTExecuteChainedOp(OpKernelConstruction* context);

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;

 private:
  Status DoWork(OpKernelContext* context);
};

XRTExecuteChainedOp::XRTExecuteChainedOp(OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {}

void XRTExecuteChainedOp::ComputeAsync(OpKernelContext* context,
                                       DoneCallback done) {
  // Schedule onto the default queue, for unbounded concurrency. See b/73520706.
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "XRTExecuteChainedOp::Compute";
  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();
  // We are guaranteed that the object underlying TpuNodeContext won't be
  // deleted out from under us while node_context is alive.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_context,
                      tpu::TpuNodeContext::Create(device_ordinal));
  xla::Backend* const backend = node_context->backend();
  se::Stream* stream = context->op_device_context()->stream();
  auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteChainedCell());
  profiler::TraceMe trace_me(
      [context] {
        return profiler::TraceMeEncode("TpuExecuteChainedOp",
                                       {{"step_id", context->step_id()}});
      },
      /*level=*/2);
  ResourceMgr* rm = GetTPUConfigResourceMgr();
  TF_RET_CHECK(rm != nullptr);

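  // Input 0 holds the serialized xrt::XRTChainedExecutePlan describing the
  // operations to run; input 1 holds the serialized
  // xrt::XRTChainedExecuteConfig.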
  const Tensor& execution_plan = context->input(0);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_plan.shape()));
  xrt::XRTChainedExecutePlan plan;
  TF_RET_CHECK(plan.ParseFromString(execution_plan.scalar<tstring>()()));

  const Tensor& execution_config = context->input(1);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
  xrt::XRTChainedExecuteConfig config;
  TF_RET_CHECK(config.ParseFromString(execution_config.scalar<tstring>()()));

  TpuCompilationCacheLookup* proto_lookup;
  TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(),
                                tpu::kCompiledProtoCacheResourceName,
                                &proto_lookup));
  core::ScopedUnref lookup_unref(proto_lookup);
  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
  auto execute_op = [&](const xrt::XRTChainedExecuteOp& op,
                        absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs)
      -> xla::StatusOr<RefPtr<XRTTupleAllocation>> {
    std::unique_ptr<CompilationCacheEntryRef> entry;
    TF_RETURN_IF_ERROR(proto_lookup->Lookup(
        op.computation_handle(), config.core_index_in_replica(), &entry));
    string rendezvous_key_base = std::to_string(op.computation_handle());
    TpuCompilationCacheEntry centry = entry->get();
    const tpu::TpuProgramGroup* tpu_program_group =
        tensorflow::down_cast<const tpu::TpuProgramGroup*>(
            centry.tpu_program_group());
    CHECK_NE(tpu_program_group, nullptr);
    const int core_index = centry.core_index();
    const TPUExecutableInfoProto& executable =
        tpu_program_group->executable_info(core_index);
    std::vector<bool> input_is_dynamic = GetDynamicInputInfo(executable);

    TF_ASSIGN_OR_RETURN(
        xla::HloInputOutputAliasConfig input_output_alias,
        GetExecutableAliasConfig(tpu_program_group, backend, core_index));
    TF_ASSIGN_OR_RETURN(std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
                        GetChainedOpInputs(op, op_inputs, executable));
    auto get_buffers_fn = [&]() {
      return GetArgumentsBuffers(input_output_alias, input_tuples,
                                 input_is_dynamic,
                                 /*release_inputs=*/false);
    };
    TF_ASSIGN_OR_RETURN(
        xla::ExecutionOutput output,
        ExecuteTPUProgram(context, node_context.get(), memory_manager.get(),
                          executable, get_buffers_fn,
                          config.execution_instance_key(), config.rng_seed(),
                          tpu_program_group, backend, stream, core_index,
                          device_ordinal, rendezvous_key_base));
    return AllocateOutputTuple(node_context.get(), stream, input_tuples,
                               input_output_alias, output.ConsumeResult(),
                               device_ordinal);
  };

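  // ExecuteChained walks the plan, invoking execute_op for each
  // XRTChainedExecuteOp node and feeding each op's outputs to its consumers.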
  return ExecuteChained(context, memory_manager, backend, device_ordinal, plan,
                        config, execute_op, backend->memory_allocator());
}

}  // namespace

REGISTER_KERNEL_BUILDER(Name("XRTExecute")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("computation_handle")
                            .HostMemory("execution_config")
                            .HostMemory("input_handles")
                            .HostMemory("output_handle"),
                        XRTExecuteOp);

REGISTER_KERNEL_BUILDER(Name("XRTExecuteChained")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("execution_plan")
                            .HostMemory("execution_config")
                            .HostMemory("output_handle"),
                        XRTExecuteChainedOp);

}  // namespace tensorflow