/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
#include "tensorflow/compiler/xrt/xrt_device.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/compiler/xrt/xrt_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace tensorflow {

namespace {

uint32 InitialRandomSeed() {
  // Support for plumbing the TF seed through to XLA is still being worked on.
  // If a user wants deterministic behavior, their best option
  // is to start with a known checkpoint. This also handles issues when
  // multiple random calls can be invoked in any order by the TF executor.
  // Another option is to use stateless random ops, which have much cleaner
  // semantics.
  // If a user really wants to set a deterministic seed for XLA-based
  // devices, this is the place to do it.
  std::random_device rd;
  // Make the starting value odd.
  return rd() | 1;
}

uint32 GetXLARandomSeed() {
  // We initialize the counter with an odd number and increment it by two
  // every time. This ensures that it will never be zero, even
  // after an overflow. When seeded with zero, some XLA backends
  // can return all zeros instead of random numbers.
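  // (Illustrative note: starting odd and adding two keeps the value odd
  // modulo 2^32, e.g. ..., 0xFFFFFFFD, 0xFFFFFFFF, 0x00000001, ..., so even a
  // wraparound can never land on the problematic zero seed.)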
  static std::atomic<uint32> counter(InitialRandomSeed());
  return counter.fetch_add(2);
}

std::vector<bool> GetDynamicInputInfo(
    const xla::ComputationLayout& computation_layout) {
  std::vector<bool> input_is_dynamic;
  input_is_dynamic.reserve(computation_layout.parameter_count());
  for (int64_t i = 0; i < computation_layout.parameter_count(); ++i) {
    input_is_dynamic.push_back(
        !computation_layout.parameter_shape(i).is_static());
  }
  return input_is_dynamic;
}

xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetInputTuples(
    xla::LocalExecutable* executable, XRTMemoryManager::WorkingSet* working_set,
    xla::Backend* backend, const std::vector<InputCoords>& input_coords,
    bool release_inputs, se::DeviceMemoryAllocator* allocator) {
  const xla::ComputationLayout& computation_layout =
      executable->executable()->module_config().entry_computation_layout();

  return GetInputTupleAllocations(
      input_coords, working_set, backend, computation_layout.parameter_count(),
      [&](int64_t i) { return computation_layout.parameter_shape(i); },
      release_inputs, allocator);
}

xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetChainedOpInputTuples(
    const xrt::XRTChainedExecuteOp& op,
    absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs) {
  std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
  input_tuples.reserve(op.inputs_size());
  for (int i = 0; i < op.inputs_size(); ++i) {
    auto& input = op.inputs(i);
    // Thanks to the greatness of proto3, there is no way to query for
    // explicitly set fields, so the default for output_index (zero) means no
    // sub-index. As a consequence, the real index is output_index - 1.
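    // For example, output_index == 0 forwards the whole op_inputs[i] tuple,
    // while output_index == 2 aliases sub-buffer {1} of that tuple.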
    if (input.output_index() == 0) {
      input_tuples.emplace_back(op_inputs[i]);
    } else {
      XRTTupleAllocation* sub_tuple;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          op_inputs[i].get(), {input.output_index() - 1}, &sub_tuple,
          /*alias_parent_allocation=*/true));
      input_tuples.emplace_back(sub_tuple);
    }
  }
  return input_tuples;
}

// Given a shape, returns a byte array representing the shape metadata of the
// shape. The shape metadata contains the dimension sizes stored as contiguous
// S32 values.
std::vector<int32> PrepareMetadata(const xla::Shape& shape) {
  DCHECK(shape.is_static());
  DCHECK(shape.IsArray());
  // Each dimension size is stored as an S32.
  std::vector<int32> result(shape.dimensions_size());
  for (int64_t i = 0; i < shape.dimensions_size(); ++i) {
    result[i] = shape.dimensions(i);
  }
  return result;
}

// Given a buffer with a dynamic shape, update the buffer's shape metadata at
// the correct offset within that buffer.
//
// +-----------+
// |Payload    |
// +-----------+
// | Padding   |
// +-----------+
// |dim_size_0 |  (each dim_size is an S32)
// +-----------+
// |dim_size_1 |
// +-----------+
//  ..........
// +-----------+
//
// Size of payload = ByteSizeOf(runtime_shape)
// Size of payload + padding = ByteSizeOf(compile_time_shape_static)
// Size of payload + padding + metadata = ByteSizeOf(compile_time_shape)
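//
// For example, with compile_time_shape f32[<=8] and runtime_shape f32[5], the
// payload would be 5 * 4 = 20 bytes, payload + padding would be 8 * 4 = 32
// bytes, and the metadata adds one S32 per dimension (4 bytes here), for a
// total of about 36 bytes. (Illustrative numbers only; the actual sizes come
// from the compiler's ShapeSizeBytesFunction().)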
Status UpdateMetadata(se::Stream* stream, se::DeviceMemory<uint8>* buffer,
                      const xla::Shape& compile_time_shape,
                      const xla::Shape& runtime_shape) {
  TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
                                         stream->parent()->platform()));
  TF_ASSIGN_OR_RETURN(
      auto transfer_manager,
      xla::TransferManager::GetForPlatform(stream->parent()->platform()));
  auto shape_size_fn = compiler->ShapeSizeBytesFunction();
  xla::Shape compile_time_shape_static =
      xla::ShapeUtil::MakeStaticShape(compile_time_shape);
  uint64 offset = shape_size_fn(compile_time_shape_static);
  uint64 metadata_size = shape_size_fn(compile_time_shape) - offset;
  auto metadata_buffer =
      stream->parent()->GetSubBuffer(buffer, offset, metadata_size);

  auto metadata_literal = std::make_shared<xla::Literal>(
      xla::LiteralUtil::CreateR1<int32>(PrepareMetadata(runtime_shape)));
  TF_RETURN_IF_ERROR(transfer_manager->TransferArrayToDeviceAsync(
      stream, *metadata_literal, metadata_buffer));
  // Retain the literal until the end of the transfer.
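  // (The callback below captures the shared_ptr by value, so metadata_literal
  // stays alive at least until the host callback runs, i.e. after the async
  // transfer enqueued above has completed on the stream.)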
  stream->ThenDoHostCallback([metadata_literal]() { return OkStatus(); });
  return OkStatus();
}

// Given a static input buffer, convert it to dynamic form by expanding it to
// the bounded size and attaching metadata filled with the dimension sizes.
//
// From:
// +--------+
// |Payload |
// +--------+
//
// To:
//
// +--------+
// |Payload |
// +--------+
// | Padding|
// +--------+
// |Metadata|
// +--------+
//
// As we can't expand the size of an existing memory allocation, a reallocation
// is required. The new allocations are attached to the corresponding execution
// inputs, and the caller is responsible for maintaining those allocations.
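//
// For example, a parameter compiled as f32[<=4,<=4] but fed an f32[2,3] value
// at runtime is copied into a freshly allocated buffer sized for the full 4x4
// bound plus two S32 metadata entries holding {2, 3}.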
Status UpdateDynamicInputs(
    se::Stream* stream, se::DeviceMemoryAllocator* allocator,
    std::vector<xla::ExecutionInput>* execution_inputs,
    const std::vector<xla::ShapeLayout>& compile_time_shapes) {
  TF_RET_CHECK(execution_inputs->size() == compile_time_shapes.size());
  TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
                                         stream->parent()->platform()));
  auto shape_size_fn = compiler->ShapeSizeBytesFunction();
  for (int64_t i = 0; i < compile_time_shapes.size(); i++) {
    const xla::Shape& compile_time_shape = compile_time_shapes[i].shape();
    if (compile_time_shape.is_static()) {
      continue;
    }
    xla::ExecutionInput* execution_input = &(*execution_inputs)[i];
    bool element_modified = false;
    TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
        compile_time_shape,
        [&](const xla::Shape& sub_shape,
            const xla::ShapeIndex& index) -> Status {
          if (sub_shape.IsTuple() || sub_shape.is_static()) {
            return OkStatus();
          }
          TF_ASSIGN_OR_RETURN(
              const xla::Shape* runtime_shape,
              xla::ShapeUtil::TryGetSubshape(execution_input->shape(), index));
          TF_RET_CHECK(!runtime_shape->IsTuple());
          TF_RET_CHECK(xla::ShapeUtil::DynamicArrayShapeIsCompatible(
              *runtime_shape, sub_shape));
          TF_ASSIGN_OR_RETURN(
              se::OwningDeviceMemory dynamic_input,
              allocator->Allocate(stream->parent()->device_ordinal(),
                                  shape_size_fn(sub_shape)));

          se::DeviceMemoryBase static_input =
              execution_input->Buffer(index).AsDeviceMemoryBase();
          se::DeviceMemory<uint8>* dynamic_input_base = dynamic_input.ptr();
          // Copy the original data to the new location.
          stream->ThenMemcpyD2D(dynamic_input_base, static_input,
                                static_input.size());
          TF_RETURN_IF_ERROR(UpdateMetadata(stream, dynamic_input_base,
                                            sub_shape, *runtime_shape));
          // Modify the memory location in the input shape tree to point to the
          // new input.
          execution_input->SetBuffer(
              index, xla::MaybeOwningDeviceMemory(std::move(dynamic_input)));
          execution_input->ClearUnownedIndex(index);
          element_modified = true;
          return OkStatus();
        }));
    if (element_modified) {
      TF_RETURN_IF_ERROR(execution_input->SetDynamicShape(compile_time_shape));
      TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer,
                          execution_input->ToShapedBuffer(
                              allocator, stream->parent()->device_ordinal()));
      // The input location has been modified, so the tuple table needs to be
      // fixed up to point to the correct address.
      TF_ASSIGN_OR_RETURN(
          auto transfer_manager,
          xla::TransferManager::GetForPlatform(stream->parent()->platform()));
      TF_RETURN_IF_ERROR(
          transfer_manager->WriteTupleIndexTablesAsync(stream, shaped_buffer));
    }
  }
  return OkStatus();
}

xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
    se::Stream* stream, xla::ExecutionOutput run_result, xla::Backend* backend,
    int device_ordinal, se::DeviceMemoryAllocator* allocator) {
  XRTTupleAllocation* output_tuple;
  xla::ScopedShapedBuffer* shaped_buffer = run_result.MutableResult();
  if (shaped_buffer->on_device_shape().is_dynamic()) {
    // Update dynamic shapes from the output buffer, and create an XRT tensor
    // with dimension sizes read from the metadata.
    xla::Shape output_device_shape = shaped_buffer->on_device_shape();
    TF_ASSIGN_OR_RETURN(
        auto transfer_manager,
        xla::TransferManager::GetForPlatform(stream->parent()->platform()));
    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
        stream, shaped_buffer, &output_device_shape));
    TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
        *shaped_buffer,
        xla::ShapeUtil::DeviceShapeToHostShape(output_device_shape),
        output_device_shape, backend, device_ordinal, &output_tuple,
        allocator));
  } else {
    // Fast path: don't copy the shapes of the output buffer.
    TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
        *shaped_buffer, backend, device_ordinal, &output_tuple, allocator));
  }
  // After the output tuple is created, release the output result buffers to
  // make sure they won't be freed by the ExecutionOutput's destructor.
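  // (ConsumeResult() moves the ScopedShapedBuffer out of run_result, and
  // release() drops its ownership of the device memory, which is now
  // referenced by output_tuple.)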
  (void)run_result.ConsumeResult().release();
  return RefPtr<XRTTupleAllocation>(output_tuple);
}

xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
    OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref,
    xla::LocalExecutable* executable,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    bool release_inputs, se::Stream* stream, int rng_seed,
    const xrt::CommonExecutionConfig& config) {
  const xla::ComputationLayout& computation_layout =
      executable->executable()->module_config().entry_computation_layout();
  std::vector<bool> input_is_dynamic = GetDynamicInputInfo(computation_layout);
  TF_ASSIGN_OR_RETURN(
      std::vector<xla::ExecutionInput> execution_inputs,
      GetArgumentsBuffers(
          executable->executable()->module().input_output_alias_config(),
          input_tuples, input_is_dynamic, release_inputs));

  se::DeviceMemoryAllocator* allocator = device_ref->allocator();
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(allocator);
  run_options.set_intra_op_thread_pool(&context->eigen_cpu_device());
  run_options.set_rng_seed(rng_seed);
  if (config.run_id() != 0) {
    run_options.set_run_id(xla::RunId(config.run_id()));
  }
  if (executable->executable()
          ->module_config()
          .has_static_device_assignment()) {
    run_options.set_device_assignment(
        &executable->executable()->module_config().static_device_assignment());
  }
  xla::gpu::GpuExecutableRunOptions gpu_options;
  std::vector<xla::GlobalDeviceId> gpu_global_ids;
  if (config.local_replica_mapping_size() > 0) {
    gpu_global_ids.reserve(config.local_replica_mapping_size());
    for (auto& gid : config.local_replica_mapping()) {
      gpu_global_ids.emplace_back(xla::GlobalDeviceId(gid));
    }
    gpu_options.set_gpu_global_device_ids(gpu_global_ids);
  }
  std::shared_ptr<NcclUniqueIdFactory> nccl_factory = GetNcclUniqueIdFactory();
  if (nccl_factory != nullptr) {
    auto uid_callback =
        [&](const xla::gpu::NcclCliqueKey& key) -> xla::StatusOr<std::string> {
      std::vector<int64_t> replicas;
      const auto key_devices = key.devices();
      replicas.reserve(key_devices.size());
      for (auto& device : key_devices) {
        replicas.push_back(device.value());
      }
      return nccl_factory->GetUniqueId(replicas);
    };
    gpu_options.set_nccl_unique_id_callback(uid_callback);
  }
  run_options.set_gpu_executable_run_options(&gpu_options);

  const std::vector<xla::ShapeLayout>& shape_layouts =
      executable->executable()
          ->module_config()
          .entry_computation_layout()
          .parameter_layouts();
  TF_RETURN_IF_ERROR(UpdateDynamicInputs(stream, run_options.allocator(),
                                         &execution_inputs, shape_layouts));
  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput run_result,
      executable->Run(std::move(execution_inputs), run_options));

  TF_ASSIGN_OR_RETURN(
      RefPtr<XRTTupleAllocation> output_tuple_ptr,
      CreateOutputTuple(stream, std::move(run_result), device_ref->backend(),
                        device_ref->device_ordinal(), allocator));
  // The ScopedShapedBuffer returned by the executable Run() API might, in case
  // of input/output buffer aliasing, have holes in it, which need to be filled
  // using the proper input tuple buffers that are the source of the aliasing.
  TF_RETURN_IF_ERROR(RebuildOutputAliases(
      output_tuple_ptr, input_tuples,
      executable->executable()->module().input_output_alias_config()));

  return std::move(output_tuple_ptr);
}

xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
    OpKernelContext* context, XRTMemoryManager* memory_manager,
    XRTGenericDeviceAccessor::ScopedRef* device_ref,
    xla::LocalExecutable* executable,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    bool release_inputs, se::Stream* stream, int rng_seed,
    const xrt::CommonExecutionConfig& config) {
  auto runfn = [&]() {
    return RunExecutable(context, device_ref, executable, input_tuples,
                         release_inputs, stream, rng_seed, config);
  };

  // We pass zero as requested_free_size, as there is no simple way to get the
  // peak heap size. Given zero, the Run() API will try to free chunks of
  // device memory until either runfn can run or we run out of freeable memory.
  return memory_manager->Run<RefPtr<XRTTupleAllocation>>(
      runfn, device_ref->backend(), device_ref->device_ordinal(),
      /*requested_free_size=*/0, device_ref->allocator());
}

xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
    OpKernelContext* context, const RefPtr<XRTMemoryManager>& memory_manager,
    XRTGenericDeviceAccessor::ScopedRef* device_ref,
    xla::LocalExecutable* executable,
    const std::vector<InputCoords>& input_coords, bool release_inputs,
    se::Stream* stream, int rng_seed,
    const xrt::CommonExecutionConfig& config) {
  XRTMemoryManager::WorkingSet working_set(memory_manager);
  TF_ASSIGN_OR_RETURN(
      std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
      GetInputTuples(executable, &working_set, device_ref->backend(),
                     input_coords, release_inputs, device_ref->allocator()));
  return ExecuteComputation(context, memory_manager.get(), device_ref,
                            executable, input_tuples, release_inputs, stream,
                            rng_seed, config);
}

// XRTExecuteOp

class XRTExecuteOp : public AsyncOpKernel {
 public:
  explicit XRTExecuteOp(OpKernelConstruction* context);
  ~XRTExecuteOp() override;

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;

 private:
  Status DoWork(OpKernelContext* context);
};

XRTExecuteOp::XRTExecuteOp(OpKernelConstruction* context)
    : AsyncOpKernel(context) {}

void XRTExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  // Schedule onto the default queue, for unbounded concurrency. See b/73520706
  Env::Default()->SchedClosure([this, context, done]() {
    OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
    done();
  });
}

Status XRTExecuteOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "XRTExecuteOp::Compute";
  auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteCell());
  ResourceMgr* rm;
  TF_RETURN_IF_ERROR(
      XRTGenericDeviceAccessor::GetResourceManager(context, &rm));

  const Tensor& execution_input = context->input(0);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_input.shape()));
  int64_t compilation_handle = execution_input.scalar<int64_t>()();

  const Tensor& execution_config = context->input(1);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
  xrt::XRTExecutionConfig config_proto;
  TF_RET_CHECK(
      ParseFromTString(execution_config.scalar<tstring>()(), &config_proto));

  int core_index_in_replica = config_proto.core_index_in_replica();
  TF_RET_CHECK(core_index_in_replica == 0);
  bool release_inputs = config_proto.release_input_handles();
  bool release_compilation = config_proto.release_compilation_handle();

  TF_ASSIGN_OR_RETURN(auto cache,
                      XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
                          context, /*max_number_of_entries=*/0));
  // We are guaranteed that the underlying device object won't be deleted out
  // from under us while the ScopedRef is live.
  class XRTGenericDeviceAccessor::ScopedRef device_ref;
  TF_RETURN_IF_ERROR(
      XRTGenericDeviceAccessor::InitScopedRef(context, &device_ref));

  int rng_seed = config_proto.rng_seed();
  if (rng_seed == 0) {
    rng_seed = GetXLARandomSeed();
  }

  se::Stream* stream = context->op_device_context()
                           ? context->op_device_context()->stream()
                           : nullptr;
  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
  TF_ASSIGN_OR_RETURN(std::vector<InputCoords> input_coords,
                      GetComputationInputs(context, "input_handles"));

  std::unique_ptr<XRTCompilationCacheEntryRef> entry;
  TF_RETURN_IF_ERROR(cache->Lookup(compilation_handle, &entry));
  xla::LocalExecutable* executable = entry->get().get_executable();
  if (release_compilation) {
    // Release the handle from the process-wide cache of XLA executables.
    TF_RETURN_IF_ERROR(cache->Release(compilation_handle));
    VLOG(2) << "Released compilation handle " << compilation_handle;
  }

  TF_ASSIGN_OR_RETURN(
      RefPtr<XRTTupleAllocation> output_tuple,
      ExecuteComputation(context, memory_manager, &device_ref, executable,
                         input_coords, release_inputs, stream, rng_seed,
                         config_proto.common_config()));

  return CreateExecuteOutput(context, memory_manager.get(),
                             std::move(output_tuple),
                             config_proto.return_exploded_tuple());
}

XRTExecuteOp::~XRTExecuteOp() = default;

class XRTExecuteChainedOp : public AsyncOpKernel {
 public:
  explicit XRTExecuteChainedOp(OpKernelConstruction* context);
  ~XRTExecuteChainedOp() override;

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;

 private:
  Status DoWork(OpKernelContext* context);
};

XRTExecuteChainedOp::XRTExecuteChainedOp(OpKernelConstruction* context)
    : AsyncOpKernel(context) {}

void XRTExecuteChainedOp::ComputeAsync(OpKernelContext* context,
                                       DoneCallback done) {
  // Schedule onto the default queue, for unbounded concurrency. See b/73520706
  Env::Default()->SchedClosure([this, context, done]() {
    OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
    done();
  });
}

Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "XRTExecuteChainedOp::Compute";
  auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteChainedCell());
  ResourceMgr* rm;
  TF_RETURN_IF_ERROR(
      XRTGenericDeviceAccessor::GetResourceManager(context, &rm));

  const Tensor& execution_plan = context->input(0);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_plan.shape()));
  xrt::XRTChainedExecutePlan plan;
  TF_RET_CHECK(ParseFromTString(execution_plan.scalar<tstring>()(), &plan));

  const Tensor& execution_config = context->input(1);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
  xrt::XRTChainedExecuteConfig config;
  TF_RET_CHECK(ParseFromTString(execution_config.scalar<tstring>()(), &config));

  TF_ASSIGN_OR_RETURN(auto cache,
                      XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
                          context, /*max_number_of_entries=*/0));
  // We are guaranteed that the underlying device object won't be deleted out
  // from under us while the ScopedRef is live.
  class XRTGenericDeviceAccessor::ScopedRef device_ref;
  TF_RETURN_IF_ERROR(
      XRTGenericDeviceAccessor::InitScopedRef(context, &device_ref));

  int rng_seed = config.rng_seed();
  if (rng_seed == 0) {
    rng_seed = GetXLARandomSeed();
  }

  se::Stream* stream = context->op_device_context()
                           ? context->op_device_context()->stream()
                           : nullptr;
  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
  auto execute_op = [&](const xrt::XRTChainedExecuteOp& op,
                        absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs)
      -> xla::StatusOr<RefPtr<XRTTupleAllocation>> {
    std::unique_ptr<XRTCompilationCacheEntryRef> entry;
    TF_RETURN_IF_ERROR(cache->Lookup(op.computation_handle(), &entry));
    xla::LocalExecutable* executable = entry->get().get_executable();

    TF_ASSIGN_OR_RETURN(std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
                        GetChainedOpInputTuples(op, op_inputs));

    return ExecuteComputation(
        context, memory_manager.get(), &device_ref, executable, input_tuples,
        /*release_inputs=*/false, stream, rng_seed, config.common_config());
  };

  return ExecuteChained(context, memory_manager, device_ref.backend(),
                        device_ref.device_ordinal(), plan, config, execute_op,
                        device_ref.allocator());
}

XRTExecuteChainedOp::~XRTExecuteChainedOp() = default;

}  // namespace

REGISTER_KERNEL_BUILDER(Name("XRTExecute")
                            .Device(DEVICE_XLA_CPU)
                            .HostMemory("computation_handle")
                            .HostMemory("execution_config")
                            .HostMemory("input_handles")
                            .HostMemory("output_handle"),
                        XRTExecuteOp);

REGISTER_KERNEL_BUILDER(Name("XRTExecute")
                            .Device(DEVICE_XLA_GPU)
                            .HostMemory("computation_handle")
                            .HostMemory("execution_config")
                            .HostMemory("input_handles")
                            .HostMemory("output_handle"),
                        XRTExecuteOp);

REGISTER_KERNEL_BUILDER(Name("XRTExecuteChained")
                            .Device(DEVICE_XLA_CPU)
                            .HostMemory("execution_plan")
                            .HostMemory("execution_config")
                            .HostMemory("output_handle"),
                        XRTExecuteChainedOp);

REGISTER_KERNEL_BUILDER(Name("XRTExecuteChained")
                            .Device(DEVICE_XLA_GPU)
                            .HostMemory("execution_plan")
                            .HostMemory("execution_config")
                            .HostMemory("output_handle"),
                        XRTExecuteChainedOp);

}  // namespace tensorflow