/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
#include "tensorflow/compiler/xrt/xrt_device.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/compiler/xrt/xrt_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace tensorflow {

namespace {
uint32 InitialRandomSeed() {
  // Support for plumbing the TF seed through to XLA is being worked on.
  // If a user wants deterministic behavior, their best option is to start
  // from a known checkpoint. This also handles the issue that multiple
  // random calls can be invoked in any order by the TF executor.
  // Another option is to use stateless random ops, which have much cleaner
  // semantics.
  // If a user really wants to set a deterministic seed for XLA-based
  // devices, this is the place to do it.
  std::random_device rd;
  // Make the starting value odd.
  return rd() | 1;
}

uint32 GetXLARandomSeed() {
  // We initialize the counter with an odd number and increment it by two
  // every time. This ensures that it will never be zero, even after an
  // overflow. When seeded with zero, some XLA backends can return all
  // zeros instead of random numbers.
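  // For example (illustrative values, not from the original code): if the
  // counter currently holds 0xFFFFFFFF, fetch_add(2) returns 0xFFFFFFFF and
  // wraps the stored value to 0x00000001 -- still odd, never zero.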
  static std::atomic<uint32> counter(InitialRandomSeed());
  return counter.fetch_add(2);
}

std::vector<bool> GetDynamicInputInfo(
    const xla::ComputationLayout& computation_layout) {
  std::vector<bool> input_is_dynamic;
  input_is_dynamic.reserve(computation_layout.parameter_count());
  for (int64_t i = 0; i < computation_layout.parameter_count(); ++i) {
    input_is_dynamic.push_back(
        !computation_layout.parameter_shape(i).is_static());
  }
  return input_is_dynamic;
}

xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetInputTuples(
    xla::LocalExecutable* executable, XRTMemoryManager::WorkingSet* working_set,
    xla::Backend* backend, const std::vector<InputCoords>& input_coords,
    bool release_inputs, se::DeviceMemoryAllocator* allocator) {
  const xla::ComputationLayout& computation_layout =
      executable->executable()->module_config().entry_computation_layout();

  return GetInputTupleAllocations(
      input_coords, working_set, backend, computation_layout.parameter_count(),
      [&](int64_t i) { return computation_layout.parameter_shape(i); },
      release_inputs, allocator);
}

xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetChainedOpInputTuples(
    const xrt::XRTChainedExecuteOp& op,
    absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs) {
  std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
  input_tuples.reserve(op.inputs_size());
  for (int i = 0; i < op.inputs_size(); ++i) {
    auto& input = op.inputs(i);
    // Thanks to the greatness of proto3, there is no way to query for
    // explicitly set fields, so the default for output_index (zero) means
    // no sub-index. As a consequence, the real index is output_index - 1.
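    // For example (illustrative, not from the original code): output_index
    // == 2 selects sub-buffer {1} of the input tuple, while output_index
    // == 0 forwards the whole input allocation unchanged.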
    if (input.output_index() == 0) {
      input_tuples.emplace_back(op_inputs[i]);
    } else {
      XRTTupleAllocation* sub_tuple;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          op_inputs[i].get(), {input.output_index() - 1}, &sub_tuple,
          /*alias_parent_allocation=*/true));
      input_tuples.emplace_back(sub_tuple);
    }
  }
  return input_tuples;
}

// Given a shape, returns its shape metadata: the dimension sizes stored as
// contiguous S32 values.
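// For example (illustrative): for a runtime shape f32[2,3] the returned
// metadata is {2, 3}.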
std::vector<int32> PrepareMetadata(const xla::Shape& shape) {
  DCHECK(shape.is_static());
  DCHECK(shape.IsArray());
  // Each dimension size is stored as an S32.
  std::vector<int32> result(shape.dimensions_size());
  for (int64_t i = 0; i < shape.dimensions_size(); ++i) {
    result[i] = shape.dimensions(i);
  }
  return result;
}

// Given a buffer with dynamic shape, update the buffer metadata at the
// correct offset starting from that buffer.
//
// +-----------+
// |Payload    |
// +-----------+
// | Padding   |
// +-----------+
// |dim_size_0 | (each dim_size is an S32)
// +-----------+
// |dim_size_1 |
// +-----------+
//  ..........
// +-----------+
//
// Size of payload = ByteSizeOf(runtime_shape)
// Size of payload + padding = ByteSizeOf(compile_time_shape_static)
// Size of payload + padding + metadata = ByteSizeOf(compile_time_shape)
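//
// Worked example (illustrative; actual byte sizes depend on the backend's
// ShapeSizeBytesFunction): for compile_time_shape f32[<=4] and runtime_shape
// f32[2], the payload occupies 8 bytes, padding extends it to the 16-byte
// static bound, and a single S32 metadata word holding the value 2 follows.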
Status UpdateMetadata(se::Stream* stream, se::DeviceMemory<uint8>* buffer,
                      const xla::Shape& compile_time_shape,
                      const xla::Shape& runtime_shape) {
  TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
                                         stream->parent()->platform()));
  TF_ASSIGN_OR_RETURN(
      auto transfer_manager,
      xla::TransferManager::GetForPlatform(stream->parent()->platform()));
  auto shape_size_fn = compiler->ShapeSizeBytesFunction();
  xla::Shape compile_time_shape_static =
      xla::ShapeUtil::MakeStaticShape(compile_time_shape);
  uint64 offset = shape_size_fn(compile_time_shape_static);
  uint64 metadata_size = shape_size_fn(compile_time_shape) - offset;
  auto metadata_buffer =
      stream->parent()->GetSubBuffer(buffer, offset, metadata_size);

  auto metadata_literal = std::make_shared<xla::Literal>(
      xla::LiteralUtil::CreateR1<int32>(PrepareMetadata(runtime_shape)));
  TF_RETURN_IF_ERROR(transfer_manager->TransferArrayToDeviceAsync(
      stream, *metadata_literal, metadata_buffer));
  // Retain the literal until the end of the transfer.
  stream->ThenDoHostCallback([metadata_literal]() { return OkStatus(); });
  return OkStatus();
}

// Given a static input buffer, convert it to dynamic form by expanding it to
// the bounded size and attaching metadata filled with the dimension sizes.
//
// From:
// +--------+
// |Payload |
// +--------+
//
// To:
//
// +--------+
// |Payload |
// +--------+
// | Padding|
// +--------+
// |Metadata|
// +--------+
//
// As we can't expand the size of an existing memory allocation, a
// reallocation is required. A list of new allocations is returned by this
// function. The caller is responsible for maintaining those allocations.
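//
// For example (illustrative): an f32[2] input passed for a parameter with
// compile-time shape f32[<=4] is copied into a freshly allocated buffer big
// enough for the f32[<=4] bound plus its metadata, and the ExecutionInput is
// repointed at the new buffer.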
Status UpdateDynamicInputs(
    se::Stream* stream, se::DeviceMemoryAllocator* allocator,
    std::vector<xla::ExecutionInput>* execution_inputs,
    const std::vector<xla::ShapeLayout>& compile_time_shapes) {
  TF_RET_CHECK(execution_inputs->size() == compile_time_shapes.size());
  TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
                                         stream->parent()->platform()));
  auto shape_size_fn = compiler->ShapeSizeBytesFunction();
  for (int64_t i = 0; i < compile_time_shapes.size(); i++) {
    const xla::Shape& compile_time_shape = compile_time_shapes[i].shape();
    if (compile_time_shape.is_static()) {
      continue;
    }
    xla::ExecutionInput* execution_input = &(*execution_inputs)[i];
    bool element_modified = false;
    TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
        compile_time_shape,
        [&](const xla::Shape& sub_shape,
            const xla::ShapeIndex& index) -> Status {
          if (sub_shape.IsTuple() || sub_shape.is_static()) {
            return OkStatus();
          }
          TF_ASSIGN_OR_RETURN(
              const xla::Shape* runtime_shape,
              xla::ShapeUtil::TryGetSubshape(execution_input->shape(), index));
          TF_RET_CHECK(!runtime_shape->IsTuple());
          TF_RET_CHECK(xla::ShapeUtil::DynamicArrayShapeIsCompatible(
              *runtime_shape, sub_shape));
          TF_ASSIGN_OR_RETURN(
              se::OwningDeviceMemory dynamic_input,
              allocator->Allocate(stream->parent()->device_ordinal(),
                                  shape_size_fn(sub_shape)));

          se::DeviceMemoryBase static_input =
              execution_input->Buffer(index).AsDeviceMemoryBase();
          se::DeviceMemory<uint8>* dynamic_input_base = dynamic_input.ptr();
          // Send the original data to the new location.
          stream->ThenMemcpyD2D(dynamic_input_base, static_input,
                                static_input.size());
          TF_RETURN_IF_ERROR(UpdateMetadata(stream, dynamic_input_base,
                                            sub_shape, *runtime_shape));
          // Modify the memory location in the input shape tree to point to
          // the new input.
          execution_input->SetBuffer(
              index, xla::MaybeOwningDeviceMemory(std::move(dynamic_input)));
          execution_input->ClearUnownedIndex(index);
          element_modified = true;
          return OkStatus();
        }));
    if (element_modified) {
      TF_RETURN_IF_ERROR(execution_input->SetDynamicShape(compile_time_shape));
      TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer,
                          execution_input->ToShapedBuffer(
                              allocator, stream->parent()->device_ordinal()));
      // The input location has been modified, so fix the tuple table to
      // point to the correct address.
      TF_ASSIGN_OR_RETURN(
          auto transfer_manager,
          xla::TransferManager::GetForPlatform(stream->parent()->platform()));
      TF_RETURN_IF_ERROR(
          transfer_manager->WriteTupleIndexTablesAsync(stream, shaped_buffer));
    }
  }
  return OkStatus();
}

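// Wraps the result of an execution into an XRTTupleAllocation. If the output
// shape is dynamic, the actual dimension sizes are first read back from the
// output buffer metadata.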
xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
    se::Stream* stream, xla::ExecutionOutput run_result, xla::Backend* backend,
    int device_ordinal, se::DeviceMemoryAllocator* allocator) {
  XRTTupleAllocation* output_tuple;
  xla::ScopedShapedBuffer* shaped_buffer = run_result.MutableResult();
  if (shaped_buffer->on_device_shape().is_dynamic()) {
    // Update dynamic shapes from the output buffer, and create an XRT tensor
    // with dimension sizes read from the metadata.
    xla::Shape output_device_shape = shaped_buffer->on_device_shape();
    TF_ASSIGN_OR_RETURN(
        auto transfer_manager,
        xla::TransferManager::GetForPlatform(stream->parent()->platform()));
    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
        stream, shaped_buffer, &output_device_shape));
    TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
        *shaped_buffer,
        xla::ShapeUtil::DeviceShapeToHostShape(output_device_shape),
        output_device_shape, backend, device_ordinal, &output_tuple,
        allocator));
  } else {
    // Fast path: don't copy the shapes of the output buffer.
    TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
        *shaped_buffer, backend, device_ordinal, &output_tuple, allocator));
  }
  // After the output tuple is created, release the output result buffers so
  // they won't be freed by the ExecutionOutput destructor.
  (void)run_result.ConsumeResult().release();
  return RefPtr<XRTTupleAllocation>(output_tuple);
}

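// Runs the given executable on device_ref's device: materializes the argument
// buffers (honoring input/output aliasing), expands dynamic inputs, executes,
// and wraps the result into an XRTTupleAllocation.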
xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
    OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref,
    xla::LocalExecutable* executable,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    bool release_inputs, se::Stream* stream, int rng_seed,
    const xrt::CommonExecutionConfig& config) {
  const xla::ComputationLayout& computation_layout =
      executable->executable()->module_config().entry_computation_layout();
  std::vector<bool> input_is_dynamic = GetDynamicInputInfo(computation_layout);
  TF_ASSIGN_OR_RETURN(
      std::vector<xla::ExecutionInput> execution_inputs,
      GetArgumentsBuffers(
          executable->executable()->module().input_output_alias_config(),
          input_tuples, input_is_dynamic, release_inputs));

  se::DeviceMemoryAllocator* allocator = device_ref->allocator();
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(allocator);
  run_options.set_intra_op_thread_pool(&context->eigen_cpu_device());
  run_options.set_rng_seed(rng_seed);
  if (config.run_id() != 0) {
    run_options.set_run_id(xla::RunId(config.run_id()));
  }
  if (executable->executable()
          ->module_config()
          .has_static_device_assignment()) {
    run_options.set_device_assignment(
        &executable->executable()->module_config().static_device_assignment());
  }
  xla::gpu::GpuExecutableRunOptions gpu_options;
  std::vector<xla::GlobalDeviceId> gpu_global_ids;
  if (config.local_replica_mapping_size() > 0) {
    gpu_global_ids.reserve(config.local_replica_mapping_size());
    for (auto& gid : config.local_replica_mapping()) {
      gpu_global_ids.emplace_back(xla::GlobalDeviceId(gid));
    }
    gpu_options.set_gpu_global_device_ids(gpu_global_ids);
  }
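  // If an NCCL unique-id factory has been registered, install a callback that
  // maps the clique's global device ids to replica ids and asks the factory
  // for the NCCL unique id shared by that replica group.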
  std::shared_ptr<NcclUniqueIdFactory> nccl_factory = GetNcclUniqueIdFactory();
  if (nccl_factory != nullptr) {
    auto uid_callback =
        [&](const xla::gpu::NcclCliqueKey& key) -> xla::StatusOr<std::string> {
      std::vector<int64_t> replicas;
      const auto key_devices = key.devices();
      replicas.reserve(key_devices.size());
      for (auto& device : key_devices) {
        replicas.push_back(device.value());
      }
      return nccl_factory->GetUniqueId(replicas);
    };
    gpu_options.set_nccl_unique_id_callback(uid_callback);
  }
  run_options.set_gpu_executable_run_options(&gpu_options);

  const std::vector<xla::ShapeLayout>& shape_layouts =
      executable->executable()
          ->module_config()
          .entry_computation_layout()
          .parameter_layouts();
  TF_RETURN_IF_ERROR(UpdateDynamicInputs(stream, run_options.allocator(),
                                         &execution_inputs, shape_layouts));
  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput run_result,
      executable->Run(std::move(execution_inputs), run_options));

  TF_ASSIGN_OR_RETURN(
      RefPtr<XRTTupleAllocation> output_tuple_ptr,
      CreateOutputTuple(stream, std::move(run_result), device_ref->backend(),
                        device_ref->device_ordinal(), allocator));
  // In case of input/output buffer aliasing, the ScopedShapedBuffer returned
  // by the executable's Run() API might have holes in it, which need to be
  // filled using the input tuple buffers that are the source of the aliasing.
  TF_RETURN_IF_ERROR(RebuildOutputAliases(
      output_tuple_ptr, input_tuples,
      executable->executable()->module().input_output_alias_config()));

  return std::move(output_tuple_ptr);
}

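// Runs the executable under XRTMemoryManager::Run(), so that an out-of-memory
// failure triggers freeing of device memory followed by a retry.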
xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
    OpKernelContext* context, XRTMemoryManager* memory_manager,
    XRTGenericDeviceAccessor::ScopedRef* device_ref,
    xla::LocalExecutable* executable,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    bool release_inputs, se::Stream* stream, int rng_seed,
    const xrt::CommonExecutionConfig& config) {
  auto runfn = [&]() {
    return RunExecutable(context, device_ref, executable, input_tuples,
                         release_inputs, stream, rng_seed, config);
  };

  // We pass zero as requested_free_size since there is no simple way to get
  // the peak heap size. With zero, the Run() API will keep freeing chunks of
  // device memory until either runfn succeeds or we run out of freeable
  // memory.
  return memory_manager->Run<RefPtr<XRTTupleAllocation>>(
      runfn, device_ref->backend(), device_ref->device_ordinal(),
      /*requested_free_size=*/0, device_ref->allocator());
}

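// Overload that first pins the input allocations referenced by input_coords
// into a working set and builds the input tuples before executing.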
xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
    OpKernelContext* context, const RefPtr<XRTMemoryManager>& memory_manager,
    XRTGenericDeviceAccessor::ScopedRef* device_ref,
    xla::LocalExecutable* executable,
    const std::vector<InputCoords>& input_coords, bool release_inputs,
    se::Stream* stream, int rng_seed,
    const xrt::CommonExecutionConfig& config) {
  XRTMemoryManager::WorkingSet working_set(memory_manager);
  TF_ASSIGN_OR_RETURN(
      std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
      GetInputTuples(executable, &working_set, device_ref->backend(),
                     input_coords, release_inputs, device_ref->allocator()));
  return ExecuteComputation(context, memory_manager.get(), device_ref,
                            executable, input_tuples, release_inputs, stream,
                            rng_seed, config);
}

// XRTExecuteOp

class XRTExecuteOp : public AsyncOpKernel {
 public:
  explicit XRTExecuteOp(OpKernelConstruction* context);
  ~XRTExecuteOp() override;

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;

 private:
  Status DoWork(OpKernelContext* context);
};

XRTExecuteOp::XRTExecuteOp(OpKernelConstruction* context)
    : AsyncOpKernel(context) {}

void XRTExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  // Schedule onto the default queue, for unbounded concurrency. See b/73520706
  Env::Default()->SchedClosure([this, context, done]() {
    OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
    done();
  });
}

Status XRTExecuteOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "XRTExecuteOp::Compute";
  auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteCell());
  ResourceMgr* rm;
  TF_RETURN_IF_ERROR(
      XRTGenericDeviceAccessor::GetResourceManager(context, &rm));

  const Tensor& execution_input = context->input(0);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_input.shape()));
  int64_t compilation_handle = execution_input.scalar<int64_t>()();

  const Tensor& execution_config = context->input(1);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
  xrt::XRTExecutionConfig config_proto;
  TF_RET_CHECK(
      ParseFromTString(execution_config.scalar<tstring>()(), &config_proto));

  int core_index_in_replica = config_proto.core_index_in_replica();
  TF_RET_CHECK(core_index_in_replica == 0);
  bool release_inputs = config_proto.release_input_handles();
  bool release_compilation = config_proto.release_compilation_handle();

  TF_ASSIGN_OR_RETURN(auto cache,
                      XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
                          context, /*max_number_of_entries=*/0));
  // We are guaranteed that the underlying device object won't be deleted out
  // from under us while the ScopedRef is live.
  class XRTGenericDeviceAccessor::ScopedRef device_ref;
  TF_RETURN_IF_ERROR(
      XRTGenericDeviceAccessor::InitScopedRef(context, &device_ref));

  int rng_seed = config_proto.rng_seed();
  if (rng_seed == 0) {
    rng_seed = GetXLARandomSeed();
  }

  se::Stream* stream = context->op_device_context()
                           ? context->op_device_context()->stream()
                           : nullptr;
  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
  TF_ASSIGN_OR_RETURN(std::vector<InputCoords> input_coords,
                      GetComputationInputs(context, "input_handles"));

  std::unique_ptr<XRTCompilationCacheEntryRef> entry;
  TF_RETURN_IF_ERROR(cache->Lookup(compilation_handle, &entry));
  xla::LocalExecutable* executable = entry->get().get_executable();
  if (release_compilation) {
    // Release the handle from the process-wide cache of XLA executables.
    TF_RETURN_IF_ERROR(cache->Release(compilation_handle));
    VLOG(2) << "Released compilation handle " << compilation_handle;
  }

  TF_ASSIGN_OR_RETURN(
      RefPtr<XRTTupleAllocation> output_tuple,
      ExecuteComputation(context, memory_manager, &device_ref, executable,
                         input_coords, release_inputs, stream, rng_seed,
                         config_proto.common_config()));

  return CreateExecuteOutput(context, memory_manager.get(),
                             std::move(output_tuple),
                             config_proto.return_exploded_tuple());
}

XRTExecuteOp::~XRTExecuteOp() = default;

class XRTExecuteChainedOp : public AsyncOpKernel {
 public:
  explicit XRTExecuteChainedOp(OpKernelConstruction* context);
  ~XRTExecuteChainedOp() override;

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;

 private:
  Status DoWork(OpKernelContext* context);
};

XRTExecuteChainedOp::XRTExecuteChainedOp(OpKernelConstruction* context)
    : AsyncOpKernel(context) {}

void XRTExecuteChainedOp::ComputeAsync(OpKernelContext* context,
                                       DoneCallback done) {
  // Schedule onto the default queue, for unbounded concurrency. See b/73520706
  Env::Default()->SchedClosure([this, context, done]() {
    OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
    done();
  });
}

Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "XRTExecuteChainedOp::Compute";
  auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteChainedCell());
  ResourceMgr* rm;
  TF_RETURN_IF_ERROR(
      XRTGenericDeviceAccessor::GetResourceManager(context, &rm));

  const Tensor& execution_plan = context->input(0);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_plan.shape()));
  xrt::XRTChainedExecutePlan plan;
  TF_RET_CHECK(ParseFromTString(execution_plan.scalar<tstring>()(), &plan));

  const Tensor& execution_config = context->input(1);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
  xrt::XRTChainedExecuteConfig config;
  TF_RET_CHECK(ParseFromTString(execution_config.scalar<tstring>()(), &config));

  TF_ASSIGN_OR_RETURN(auto cache,
                      XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
                          context, /*max_number_of_entries=*/0));
  // We are guaranteed that the underlying device object won't be deleted out
  // from under us while the ScopedRef is live.
  class XRTGenericDeviceAccessor::ScopedRef device_ref;
  TF_RETURN_IF_ERROR(
      XRTGenericDeviceAccessor::InitScopedRef(context, &device_ref));

  int rng_seed = config.rng_seed();
  if (rng_seed == 0) {
    rng_seed = GetXLARandomSeed();
  }

  se::Stream* stream = context->op_device_context()
                           ? context->op_device_context()->stream()
                           : nullptr;
  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
  auto execute_op = [&](const xrt::XRTChainedExecuteOp& op,
                        absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs)
      -> xla::StatusOr<RefPtr<XRTTupleAllocation>> {
    std::unique_ptr<XRTCompilationCacheEntryRef> entry;
    TF_RETURN_IF_ERROR(cache->Lookup(op.computation_handle(), &entry));
    xla::LocalExecutable* executable = entry->get().get_executable();

    TF_ASSIGN_OR_RETURN(std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
                        GetChainedOpInputTuples(op, op_inputs));

    return ExecuteComputation(
        context, memory_manager.get(), &device_ref, executable, input_tuples,
        /*release_inputs=*/false, stream, rng_seed, config.common_config());
  };

  return ExecuteChained(context, memory_manager, device_ref.backend(),
                        device_ref.device_ordinal(), plan, config, execute_op,
                        device_ref.allocator());
}

XRTExecuteChainedOp::~XRTExecuteChainedOp() = default;


}  // namespace

REGISTER_KERNEL_BUILDER(Name("XRTExecute")
                            .Device(DEVICE_XLA_CPU)
                            .HostMemory("computation_handle")
                            .HostMemory("execution_config")
                            .HostMemory("input_handles")
                            .HostMemory("output_handle"),
                        XRTExecuteOp);

REGISTER_KERNEL_BUILDER(Name("XRTExecute")
                            .Device(DEVICE_XLA_GPU)
                            .HostMemory("computation_handle")
                            .HostMemory("execution_config")
                            .HostMemory("input_handles")
                            .HostMemory("output_handle"),
                        XRTExecuteOp);

REGISTER_KERNEL_BUILDER(Name("XRTExecuteChained")
                            .Device(DEVICE_XLA_CPU)
                            .HostMemory("execution_plan")
                            .HostMemory("execution_config")
                            .HostMemory("output_handle"),
                        XRTExecuteChainedOp);

REGISTER_KERNEL_BUILDER(Name("XRTExecuteChained")
                            .Device(DEVICE_XLA_GPU)
                            .HostMemory("execution_plan")
                            .HostMemory("execution_config")
                            .HostMemory("output_handle"),
                        XRTExecuteChainedOp);

}  // namespace tensorflow