/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <atomic>
#include <memory>
#include <random>
#include <string>
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
#include "tensorflow/compiler/xrt/xrt_device.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace tensorflow {

namespace {

uint32 InitialRandomSeed() {
  // Support for plumbing the TF seed through to XLA is being worked on.
  // If a user wants deterministic behavior, their best option
  // is to start with a known checkpoint. This also handles issues when
  // multiple random calls can be invoked in any order by the TF executor.
  // Another option is to use stateless random ops, which have much cleaner
  // semantics.
  // If a user really wants to set a deterministic seed for XLA-based
  // devices, this is the place to do it.
  std::random_device rd;
  // Make the starting value odd.
  return rd() | 1;
}

uint32 GetXLARandomSeed() {
  // We initialize the counter with an odd number and increment it by two
  // every time. This ensures that it will never be zero, even
  // after an overflow. When seeded with zero, some XLA backends
  // can return all zeros instead of random numbers.
  static std::atomic<uint32> counter(InitialRandomSeed());
  return counter.fetch_add(2);
}

// Gathers the inputs to the computation: looks up the tuple allocations
// referenced by the `input_handles` op input and populates `input_tuples`,
// `input_allocations`, and `input_pointers`.
Status GetComputationInputs(OpKernelContext* context, ResourceMgr* rm,
                            bool release_inputs,
                            std::vector<XRTTupleAllocation*>* input_tuples,
                            std::vector<xla::ShapedBuffer>* input_allocations,
                            std::vector<xla::ShapedBuffer*>* input_pointers) {
  std::vector<int64> input_uids;
  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list("input_handles", &arg_list));

  // Concatenate all input uids from the list of scalars-or-vectors carrying
  // them.
  for (int i = 0; i < arg_list.size(); ++i) {
    const Tensor& arg = arg_list[i];
    if (TensorShapeUtils::IsScalar(arg.shape())) {
      input_uids.push_back(arg.scalar<int64>()());
    } else {
      TF_RET_CHECK(TensorShapeUtils::IsVector(arg.shape()));
      auto arg_vec = arg.vec<int64>();
      const int64 num_elts = arg.shape().dim_size(0);
      for (int64 j = 0; j < num_elts; ++j) {
        input_uids.push_back(arg_vec(j));
      }
    }
  }

  // Retrieve allocations for the uids.
  input_tuples->resize(input_uids.size());
  input_pointers->resize(input_uids.size());
  for (int i = 0; i < input_uids.size(); ++i) {
    const int64 input_uid = input_uids[i];
    TF_RETURN_IF_ERROR(
        XRTTupleAllocation::Lookup(rm, input_uid, &(*input_tuples)[i]));
    if (release_inputs) {
      // We are holding a reference to the tuple, so we can safely delete it
      // from the resource manager here.
      TF_RETURN_IF_ERROR(
          XRTTupleAllocation::DeleteFromResourceManager(rm, input_uid));
      VLOG(2) << "Released allocation handle " << input_uid;
    }
    XRTTupleAllocation* tuple = (*input_tuples)[i];
    input_allocations->emplace_back(tuple->ToShapedBuffer());
  }
  for (int i = 0; i < input_uids.size(); ++i) {
    (*input_pointers)[i] = &(*input_allocations)[i];
  }
  return Status::OK();
}

// XRTExecuteOp

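// Runs a previously compiled XLA executable, identified by a compilation
// handle, over the XRT tuple allocations named by the input handles, and
// returns a handle (or handles) to the resulting allocation.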
class XRTExecuteOp : public AsyncOpKernel {
 public:
  explicit XRTExecuteOp(OpKernelConstruction* context);
  ~XRTExecuteOp() override;

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;

 private:
  Status DoWork(OpKernelContext* context);
};

XRTExecuteOp::XRTExecuteOp(OpKernelConstruction* context)
    : AsyncOpKernel(context) {}

void XRTExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  // Schedule onto the default queue, for unbounded concurrency. See b/73520706
  Env::Default()->SchedClosure([this, context, done]() {
    OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
    done();
  });
}

Status XRTExecuteOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "XRTExecuteOp::Compute";
  ResourceMgr* rm;
  TF_RETURN_IF_ERROR(
      XRTGenericDeviceAccessor::GetResourceManager(context, &rm));

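  // Read the compilation handle (input 0) and the serialized
  // XRTExecutionConfig proto (input 1).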
  const Tensor& execution_input = context->input(0);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_input.shape()));
  int64 compilation_handle = execution_input.scalar<int64>()();

  const Tensor& execution_config = context->input(1);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
  xrt::XRTExecutionConfig config_proto;
  TF_RET_CHECK(
      config_proto.ParseFromString(execution_config.scalar<string>()()));

  int core_index_in_replica = config_proto.core_index_in_replica();
  TF_RET_CHECK(core_index_in_replica == 0);
  bool release_inputs = config_proto.release_input_handles();
  bool release_compilation = config_proto.release_compilation_handle();

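  // Look up the process-wide cache of XLA executables and fetch the entry for
  // the given compilation handle.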
  XRTCompilationCache* cache;
  TF_RETURN_IF_ERROR(rm->Lookup<XRTCompilationCache>(
      rm->default_container(), kXRTCompilationCacheResourceName, &cache));
  core::ScopedUnref cache_unref(cache);

  std::unique_ptr<XRTCompilationCacheEntryRef> entry;
  TF_RETURN_IF_ERROR(cache->Lookup(compilation_handle, &entry));

  if (release_compilation) {
    // The entry reference keeps the executable alive, so it is safe to
    // release the compilation handle from the cache before running.
    TF_RETURN_IF_ERROR(cache->Release(compilation_handle));
    VLOG(2) << "Released compilation handle " << compilation_handle;
  }

  std::vector<XRTTupleAllocation*> input_tuples;
  // Make a cleanup object so that we can safely return in error conditions
  // without leaking references to allocations.
  auto buffer_releaser = gtl::MakeCleanup([&input_tuples]() {
    for (auto tuple : input_tuples) {
      if (tuple != nullptr) {
        tuple->Unref();
      }
    }
  });
  std::vector<xla::ShapedBuffer> input_allocations;
  std::vector<xla::ShapedBuffer*> input_pointers;
  TF_RETURN_IF_ERROR(GetComputationInputs(context, rm, release_inputs,
                                          &input_tuples, &input_allocations,
                                          &input_pointers));

  // We are guaranteed that the underlying device object won't be deleted out
  // from under us while the ScopedRef is live.
  class XRTGenericDeviceAccessor::ScopedRef device_ref;
  TF_RETURN_IF_ERROR(
      XRTGenericDeviceAccessor::InitScopedRef(context, 0, &device_ref));

  int rng_seed = config_proto.rng_seed();
  if (rng_seed == 0) {
    rng_seed = GetXLARandomSeed();
  }

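  // Use the compute stream from the op's device context, if one is available.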
  se::Stream* stream = context->op_device_context()
                           ? context->op_device_context()->stream()
                           : nullptr;

  // Execute the computation.
  VLOG(2) << "Executing computation.";
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(device_ref.backend()->memory_allocator());
  run_options.set_intra_op_thread_pool(&context->eigen_cpu_device());
  run_options.set_rng_seed(rng_seed);

  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  xla::LocalExecutable* executable = entry->get().get_executable();
  auto run_result = executable->Run(input_pointers, run_options);
  if (!run_result.ok()) {
    return run_result.status();
  }

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time: " << elapsed << "us";

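  // Take ownership of the result buffers and wrap them in an
  // XRTTupleAllocation so they can be referenced by handle.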
  auto scoped_buffer = run_result.ConsumeValueOrDie();
  auto shaped_buffer = scoped_buffer.release();
  XRTTupleAllocation* output_tuple;
  TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
      shaped_buffer, device_ref.backend(), device_ref.device_ordinal(),
      &output_tuple));

  // In case of input/output buffer aliasing, the ScopedShapedBuffer returned
  // by the executable's Run() API might have holes in it, which need to be
  // filled from the input tuple buffers that are the source of the aliasing.
  const xla::HloInputOutputAliasConfig& input_output_alias =
      executable->executable()->module().input_output_alias_config();
  auto alias_function =
      [&](const xla::ShapeIndex& output_index,
          const xla::HloInputOutputAliasConfig::Alias& alias) -> Status {
    TF_RET_CHECK(alias.parameter_number < input_tuples.size());
    return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias
               ? output_tuple->AliasBufferFrom(
                     *input_tuples[alias.parameter_number],
                     alias.parameter_index, output_index)
               : Status::OK();
  };
  TF_RETURN_IF_ERROR(input_output_alias.ForEachAliasWithStatus(alias_function));

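  // Depending on the execution config, either return one handle per top-level
  // element of the result tuple, or a single handle for the whole result.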
  if (config_proto.return_exploded_tuple() &&
      output_tuple->on_device_shape().IsTuple()) {
    int64 tuple_element_count =
        xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape());
    Tensor* output_tensor;
    TF_RETURN_IF_ERROR(context->allocate_output(
        0, TensorShape({tuple_element_count}), &output_tensor));

    for (int64 i = 0; i < tuple_element_count; ++i) {
      xla::ShapeIndex shape_index;
      shape_index.push_back(i);

      XRTTupleAllocation* suballocation;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          output_tuple, shape_index, &suballocation,
          /*alias_parent_allocation=*/false));
      int64 key;
      TF_RETURN_IF_ERROR(suballocation->Intern(rm, &key));
      output_tensor->vec<int64>()(i) = key;
    }
    output_tuple->Unref();
  } else {
    Tensor* output_tensor;
    TF_RETURN_IF_ERROR(
        context->allocate_output(0, TensorShape({}), &output_tensor));
    int64 key;
    TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key));
    output_tensor->scalar<int64>()() = key;
  }
  return Status::OK();
}

XRTExecuteOp::~XRTExecuteOp() = default;

}  // namespace

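// Register the XRTExecute kernel for the XLA_CPU and XLA_GPU devices. All
// handle and config tensors live in host memory.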
REGISTER_KERNEL_BUILDER(Name("XRTExecute")
                            .Device(DEVICE_XLA_CPU)
                            .HostMemory("computation_handle")
                            .HostMemory("execution_config")
                            .HostMemory("input_handles")
                            .HostMemory("output_handle"),
                        XRTExecuteOp);

REGISTER_KERNEL_BUILDER(Name("XRTExecute")
                            .Device(DEVICE_XLA_GPU)
                            .HostMemory("computation_handle")
                            .HostMemory("execution_config")
                            .HostMemory("input_handles")
                            .HostMemory("output_handle"),
                        XRTExecuteOp);

}  // namespace tensorflow