/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <atomic>
#include <memory>
#include <random>
#include <string>
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
#include "tensorflow/compiler/xrt/xrt_device.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace tensorflow {

namespace {

uint32 InitialRandomSeed() {
  // Support for plumbing the TF seed through to XLA is being worked on.
  // If a user wants deterministic behavior, their best option
  // is to start with a known checkpoint. This also handles issues when
  // multiple random calls can be invoked in any order by the TF executor.
  // Another option is to use stateless random ops, which have much cleaner
  // semantics.
  // If a user really wants to set a deterministic seed for XLA-based
  // devices, this is the place to do it.
  std::random_device rd;
  // Make the starting value odd.
  return rd() | 1;
}

uint32 GetXLARandomSeed() {
  // We initialize the counter with an odd number and increment it by two
  // every time. This ensures that it will never be zero, even
  // after an overflow. When seeded with zero, some XLA backends
  // can return all zeros instead of random numbers.
  static std::atomic<uint32> counter(InitialRandomSeed());
  return counter.fetch_add(2);
}

// Looks up the input allocations for the computation and populates
// `input_tuples`, `input_allocations`, and `input_pointers` with them.
Status GetComputationInputs(OpKernelContext* context, ResourceMgr* rm,
                            bool release_inputs,
                            std::vector<XRTTupleAllocation*>* input_tuples,
                            std::vector<xla::ShapedBuffer>* input_allocations,
                            std::vector<xla::ShapedBuffer*>* input_pointers) {
  std::vector<int64> input_uids;
  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list("input_handles", &arg_list));

  // Concatenate all input uids from list of scalars-or-vectors carrying them.
  for (int i = 0; i < arg_list.size(); ++i) {
    const Tensor& arg = arg_list[i];
    if (TensorShapeUtils::IsScalar(arg.shape())) {
      input_uids.push_back(arg.scalar<int64>()());
    } else {
      TF_RET_CHECK(TensorShapeUtils::IsVector(arg.shape()));
      auto arg_vec = arg.vec<int64>();
      const int64 num_elts = arg.shape().dim_size(0);
      for (int64 j = 0; j < num_elts; ++j) {
        input_uids.push_back(arg_vec(j));
      }
    }
  }

  // Retrieve allocations for the uids.
  input_tuples->resize(input_uids.size());
  input_pointers->resize(input_uids.size());
  for (int i = 0; i < input_uids.size(); ++i) {
    const int64 input_uid = input_uids[i];
    TF_RETURN_IF_ERROR(
        XRTTupleAllocation::Lookup(rm, input_uid, &(*input_tuples)[i]));
    if (release_inputs) {
      // We are holding a reference to the tuple, so we can safely delete it
      // from the resource manager here.
      TF_RETURN_IF_ERROR(
          XRTTupleAllocation::DeleteFromResourceManager(rm, input_uid));
      VLOG(2) << "Released allocation handle " << input_uid;
    }
    XRTTupleAllocation* tuple = (*input_tuples)[i];
    input_allocations->emplace_back(tuple->ToShapedBuffer());
  }
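  // Take pointers only after `input_allocations` is fully populated; growing
  // the vector above could otherwise invalidate earlier element addresses.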
  for (int i = 0; i < input_uids.size(); ++i) {
    (*input_pointers)[i] = &(*input_allocations)[i];
  }
  return Status::OK();
}

// XRTExecuteOp

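// Op kernel that runs a previously compiled XLA computation: it looks the
// executable up by its compilation handle, gathers the argument allocations
// named by the input handles, executes on the op's device, and returns the
// handle(s) of the newly interned result allocation(s).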
class XRTExecuteOp : public AsyncOpKernel {
 public:
  explicit XRTExecuteOp(OpKernelConstruction* context);
  ~XRTExecuteOp() override;

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;

 private:
  Status DoWork(OpKernelContext* context);
};

XRTExecuteOp::XRTExecuteOp(OpKernelConstruction* context)
    : AsyncOpKernel(context) {}

void XRTExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  // Schedule onto the default queue, for unbounded concurrency. See b/73520706
  Env::Default()->SchedClosure([this, context, done]() {
    OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
    done();
  });
}

Status XRTExecuteOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "XRTExecuteOp::Compute";
  ResourceMgr* rm;
  TF_RETURN_IF_ERROR(
      XRTGenericDeviceAccessor::GetResourceManager(context, &rm));

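  // Input 0 is the scalar handle of the compiled computation to execute.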
  const Tensor& execution_input = context->input(0);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_input.shape()));
  int64 compilation_handle = execution_input.scalar<int64>()();

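  // Input 1 is a serialized xrt::XRTExecutionConfig proto that controls how
  // the computation is run.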
  const Tensor& execution_config = context->input(1);
  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
  xrt::XRTExecutionConfig config_proto;
  TF_RET_CHECK(
      config_proto.ParseFromString(execution_config.scalar<string>()()));

  int core_index_in_replica = config_proto.core_index_in_replica();
  TF_RET_CHECK(core_index_in_replica == 0);
  bool release_inputs = config_proto.release_input_handles();
  bool release_compilation = config_proto.release_compilation_handle();

  // Process-wide cache of XLA executables.
  XRTCompilationCache* cache;
  TF_RETURN_IF_ERROR(rm->Lookup<XRTCompilationCache>(
      rm->default_container(), kXRTCompilationCacheResourceName, &cache));
  core::ScopedUnref cache_unref(cache);

  std::unique_ptr<XRTCompilationCacheEntryRef> entry;
  TF_RETURN_IF_ERROR(cache->Lookup(compilation_handle, &entry));

  if (release_compilation) {
    // `entry` still holds a reference to the executable, so it is safe to
    // drop the compilation handle from the cache before running.
    TF_RETURN_IF_ERROR(cache->Release(compilation_handle));
    VLOG(2) << "Released compilation handle " << compilation_handle;
  }

  std::vector<XRTTupleAllocation*> input_tuples;
  // Make a cleanup object so that we can safely return in error conditions
  // without leaking references to allocations.
  auto buffer_releaser = gtl::MakeCleanup([&input_tuples]() {
    for (auto tuple : input_tuples) {
      if (tuple != nullptr) {
        tuple->Unref();
      }
    }
  });
  std::vector<xla::ShapedBuffer> input_allocations;
  std::vector<xla::ShapedBuffer*> input_pointers;
  TF_RETURN_IF_ERROR(GetComputationInputs(context, rm, release_inputs,
                                          &input_tuples, &input_allocations,
                                          &input_pointers));

  // We are guaranteed that the underlying device object won't be deleted out
  // from under us while the ScopedRef is live.
  class XRTGenericDeviceAccessor::ScopedRef device_ref;
  TF_RETURN_IF_ERROR(
      XRTGenericDeviceAccessor::InitScopedRef(context, 0, &device_ref));

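  // An rng_seed of zero in the config means no explicit seed was requested;
  // fall back to the process-wide XLA seed.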
  int rng_seed = config_proto.rng_seed();
  if (rng_seed == 0) {
    rng_seed = GetXLARandomSeed();
  }

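  // Use the stream from the op's device context when one is available;
  // otherwise the executable runs without an explicitly provided stream.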
  se::Stream* stream = context->op_device_context()
                           ? context->op_device_context()->stream()
                           : nullptr;

  // Execute the computation.
  VLOG(2) << "Executing computation.";
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(device_ref.backend()->memory_allocator());
  run_options.set_intra_op_thread_pool(&context->eigen_cpu_device());
  run_options.set_rng_seed(rng_seed);

  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  xla::LocalExecutable* executable = entry->get().get_executable();
  auto run_result = executable->Run(input_pointers, run_options);
  if (!run_result.ok()) {
    return run_result.status();
  }

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time: " << elapsed << "us";

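  // Release the result buffers from the returned ScopedShapedBuffer and wrap
  // them in an XRTTupleAllocation so they can be handed out by handle.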
  auto scoped_buffer = run_result.ConsumeValueOrDie();
  auto shaped_buffer = scoped_buffer.release();
  XRTTupleAllocation* output_tuple;
  TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
      shaped_buffer, device_ref.backend(), device_ref.device_ordinal(),
      &output_tuple));

  // When input/output buffer aliasing is in effect, the ScopedShapedBuffer
  // returned by the executable's Run() API might have holes in it, which need
  // to be filled from the input tuple buffers that are the source of the
  // aliasing.
  const xla::HloInputOutputAliasConfig& input_output_alias =
      executable->executable()->module().input_output_alias_config();
  auto alias_function =
      [&](const xla::ShapeIndex& output_index,
          const xla::HloInputOutputAliasConfig::Alias& alias) -> Status {
    TF_RET_CHECK(alias.parameter_number < input_tuples.size());
    return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias
               ? output_tuple->AliasBufferFrom(
                     *input_tuples[alias.parameter_number],
                     alias.parameter_index, output_index)
               : Status::OK();
  };
  TF_RETURN_IF_ERROR(input_output_alias.ForEachAliasWithStatus(alias_function));

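  // If the client requested an exploded tuple, intern each top-level element
  // of the result as its own allocation and return a vector of handles;
  // otherwise intern the whole result and return a single scalar handle.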
  if (config_proto.return_exploded_tuple() &&
      output_tuple->on_device_shape().IsTuple()) {
    int64 tuple_element_count =
        xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape());
    Tensor* output_tensor;
    TF_RETURN_IF_ERROR(context->allocate_output(
        0, TensorShape({tuple_element_count}), &output_tensor));

    for (int64 i = 0; i < tuple_element_count; ++i) {
      xla::ShapeIndex shape_index;
      shape_index.push_back(i);

      XRTTupleAllocation* suballocation;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          output_tuple, shape_index, &suballocation,
          /*alias_parent_allocation=*/false));
      int64 key;
      TF_RETURN_IF_ERROR(suballocation->Intern(rm, &key));
      output_tensor->vec<int64>()(i) = key;
    }
    output_tuple->Unref();
  } else {
    Tensor* output_tensor;
    TF_RETURN_IF_ERROR(
        context->allocate_output(0, TensorShape({}), &output_tensor));
    int64 key;
    TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key));
    output_tensor->scalar<int64>()() = key;
  }
  return Status::OK();
}

XRTExecuteOp::~XRTExecuteOp() = default;

}  // namespace

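// Register the op for the XLA_CPU and XLA_GPU devices. All handle and config
// inputs/outputs are kept in host memory.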
REGISTER_KERNEL_BUILDER(Name("XRTExecute")
                            .Device(DEVICE_XLA_CPU)
                            .HostMemory("computation_handle")
                            .HostMemory("execution_config")
                            .HostMemory("input_handles")
                            .HostMemory("output_handle"),
                        XRTExecuteOp);

REGISTER_KERNEL_BUILDER(Name("XRTExecute")
                            .Device(DEVICE_XLA_GPU)
                            .HostMemory("computation_handle")
                            .HostMemory("execution_config")
                            .HostMemory("input_handles")
                            .HostMemory("output_handle"),
                        XRTExecuteOp);

}  // namespace tensorflow