/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xrt/xrt_util.h"

#include <stdlib.h>
#include <string.h>

#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace {

mutex nccl_factory_mutex(LINKER_INITIALIZED);
std::shared_ptr<NcclUniqueIdFactory>* nccl_factory;

// The ScopedHandles data structure is used in the ExecuteChained() API and its
// task is to track tuple allocation registrations. It tracks both the
// intermediate results of a chained computation and its final results.
// Anything which is marked to be released will be released via the
// XRTMemoryManager once the object is destroyed (unless an explicit call to
// Drop() or Release() is made).
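//
// Illustrative usage sketch only (not part of this file's logic); it assumes a
// valid RefPtr<XRTMemoryManager> named memory_manager and an already-built
// RefPtr<XRTTupleAllocation> named tuple:
//
//   ScopedHandles handles(memory_manager);
//   TF_RETURN_IF_ERROR(handles.Add(/*index=*/0, std::move(tuple)));
//   int64 handle = handles[0];  // Look up the registered handle.
//   handles.Release(0);         // Caller takes ownership; the destructor will
//                               // not release this handle anymore.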
class ScopedHandles {
 public:
  explicit ScopedHandles(RefPtr<XRTMemoryManager> memory_manager)
      : memory_manager_(std::move(memory_manager)) {}

  ~ScopedHandles() {
    for (size_t i = 0; i < handles_.size(); ++i) {
      if (handles_release_[i]) {
        memory_manager_->Release(handles_[i]).IgnoreError();
      }
    }
  }

  int64 operator[](size_t index) const { return handles_.at(index); }

  size_t size() const { return handles_.size(); }

  // Adds the given handle at the index position, marking it releasable
  // according to the release argument. If a to-be-released handle already
  // exists at the same index, it will be released first.
  Status Add(size_t index, int64 handle, bool release) {
    if (index >= handles_.size()) {
      handles_.resize(index + 1, XRTMemoryManager::InvalidKey());
      handles_release_.resize(index + 1, false);
    }
    if (handles_release_[index]) {
      Status status = memory_manager_->Release(handles_[index]);
      if (!status.ok()) {
        if (release) {
          memory_manager_->Release(handle).IgnoreError();
        }
        return status;
      }
    }
    handles_[index] = handle;
    handles_release_[index] = release;
    return Status::OK();
  }

  // Adds a to-be-released tuple allocation at the given index.
  Status Add(size_t index, RefPtr<XRTTupleAllocation> tuple) {
    return Add(index, memory_manager_->Register(std::move(tuple)),
               /*release=*/true);
  }

  // Drops the handle at the given index, releasing it via
  // XRTMemoryManager::Release() if it is marked as to-be-released.
  Status Drop(size_t index) {
    if (handles_release_.at(index)) {
      TF_RETURN_IF_ERROR(memory_manager_->Release(handles_[index]));
    }
    Release(index);
    return Status::OK();
  }

  // Releases the handle at the given index. The destructor will not call
  // XRTMemoryManager::Release() on such a handle.
  int64 Release(size_t index) {
    int64 handle = handles_.at(index);
    handles_[index] = XRTMemoryManager::InvalidKey();
    handles_release_[index] = false;
    return handle;
  }

  // Looks up the handle stored at the given index, and returns the matching
  // tuple allocation.
  xla::StatusOr<RefPtr<XRTTupleAllocation>> Lookup(size_t index) const {
    return memory_manager_->Lookup(handles_.at(index));
  }

 private:
  RefPtr<XRTMemoryManager> memory_manager_;
  std::vector<int64> handles_;
  std::vector<bool> handles_release_;
};

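// Returns true if the TF_XLA_DEBUG_OPTIONS_PASSTHROUGH environment variable is
// set to "1" or "true", in which case BuildXlaDebugOptions() below forwards
// the client-provided xla::DebugOptions unmodified (for example, running the
// server with TF_XLA_DEBUG_OPTIONS_PASSTHROUGH=1 enables the pass-through).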
bool DebugOptionsPassThroughEnabled() {
  const char* env = getenv("TF_XLA_DEBUG_OPTIONS_PASSTHROUGH");
  bool enabled =
      env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0);
  if (enabled) {
    LOG(WARNING) << "Passing through XLA debug options!";
  } else {
    LOG(WARNING) << "TF_XLA_DEBUG_OPTIONS_PASSTHROUGH not set, not all options "
                    "will be retained";
  }
  return enabled;
}

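// Returns the given debug dump path if it is empty or points to a gs:// or
// bigstore:// location; any other path is dropped (with a warning) and
// replaced by the empty string.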
string SafeDebugPath(const string& path) {
  if (path.empty() || path.compare(0, 5, "gs://") == 0 ||
      path.compare(0, 11, "bigstore://") == 0) {
    return path;
  }
  LOG(WARNING) << "Invalid config path (will be dropped): " << path;
  return string();
}

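// Builds the result for a chained-execute output index: index 0 refers to the
// whole output tuple, while index N > 0 aliases the (N - 1)-th sub-buffer of
// the output tuple.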
Status MakeOutput(const RefPtr<XRTTupleAllocation>& output, int64 index,
                  RefPtr<XRTTupleAllocation>* result) {
  if (index == 0) {
    *result = output;
  } else {
    XRTTupleAllocation* tuple;
    TF_RETURN_IF_ERROR(
        XRTTupleAllocation::MakeSubBuffer(output.get(), {index - 1}, &tuple,
                                          /*alias_parent_allocation=*/true));
    result->reset(tuple);
  }
  return Status::OK();
}

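// Pins to device memory all the allocations feeding the chained-execute
// operation at position current_index. Inputs may only reference outputs of
// operations appearing earlier in the post-order (that is, with an op index
// strictly smaller than current_index).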
Status PopulateOpWorkingSet(xla::Backend* backend,
                            const xrt::XRTChainedExecuteOp& op,
                            int current_index, const ScopedHandles& outputs,
                            XRTMemoryManager::WorkingSet* working_set) {
  for (int i = 0; i < op.inputs_size(); ++i) {
    auto& input = op.inputs(i);
    if (input.op_index() >= current_index) {
      return errors::InvalidArgument(
          "Input index ", input.op_index(),
          " is above the current position: ", current_index);
    }
    TF_RETURN_IF_ERROR(
        working_set->LookupAndPin(backend, outputs[input.op_index()]));
  }
  return Status::OK();
}

}  // namespace

void SetNcclUniqueIdFactory(std::shared_ptr<NcclUniqueIdFactory> factory) {
  mutex_lock lock(nccl_factory_mutex);
  if (nccl_factory == nullptr) {
    nccl_factory = new std::shared_ptr<NcclUniqueIdFactory>();
  }
  *nccl_factory = std::move(factory);
}

std::shared_ptr<NcclUniqueIdFactory> GetNcclUniqueIdFactory() {
  mutex_lock lock(nccl_factory_mutex);
  return nccl_factory != nullptr ? *nccl_factory : nullptr;
}

xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) {
  static const bool options_passthrough = DebugOptionsPassThroughEnabled();
  if (options_passthrough) {
    return ref_options;
  }
  xla::DebugOptions options = xla::GetDebugOptionsFromFlags();
  options.set_xla_dump_to(SafeDebugPath(ref_options.xla_dump_to()));
  options.set_xla_dump_hlo_as_proto(ref_options.xla_dump_hlo_as_proto());
  options.set_xla_dump_hlo_as_text(ref_options.xla_dump_hlo_as_text());
  options.set_xla_dump_hlo_snapshots(ref_options.xla_dump_hlo_snapshots());
  options.set_xla_dump_hlo_pass_re(ref_options.xla_dump_hlo_pass_re());
  options.set_xla_dump_include_timestamp(
      ref_options.xla_dump_include_timestamp());
  options.set_xla_dump_max_hlo_modules(ref_options.xla_dump_max_hlo_modules());
  for (auto& pass : ref_options.xla_disable_hlo_passes()) {
    options.add_xla_disable_hlo_passes(pass);
  }
  return options;
}

xla::StatusOr<std::vector<InputCoords>> GetComputationInputs(
    OpKernelContext* context, const char* input_name) {
  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list(input_name, &arg_list));
  // Concatenate all input uids from the list of scalars-or-vectors carrying
  // them.
  std::vector<InputCoords> input_coords;
  for (int i = 0; i < arg_list.size(); ++i) {
    const Tensor& arg = arg_list[i];
    if (TensorShapeUtils::IsScalar(arg.shape())) {
      input_coords.emplace_back(arg.scalar<int64>()());
    } else {
      TF_RET_CHECK(TensorShapeUtils::IsVector(arg.shape()));
      auto arg_vec = arg.vec<int64>();
      const int64 num_elts = arg.shape().dim_size(0);
      // Use a distinct loop variable to avoid shadowing the outer index.
      for (int64 j = 0; j < num_elts; ++j) {
        input_coords.emplace_back(arg_vec(j));
      }
    }
  }
  return std::move(input_coords);
}


bool InputShapeMatches(const xla::Shape& parameter_shape,
                       const xla::Shape& input_shape) {
  auto shape_checker = [&](const xla::Shape& pshape,
                           const xla::ShapeIndex& index) {
    if (pshape.IsArray()) {
      TF_ASSIGN_OR_RETURN(const xla::Shape* ishape,
                          xla::ShapeUtil::TryGetSubshape(input_shape, index));
      if (pshape.rank() != ishape->rank() ||
          pshape.element_type() != ishape->element_type()) {
        return errors::InvalidArgument("Mismatching shapes");
      }
      if (pshape.is_static() && pshape.layout() != ishape->layout()) {
        return errors::InvalidArgument("Mismatching layouts");
      }
      for (int64 dim = 0; dim < pshape.rank(); ++dim) {
        if (pshape.is_dynamic_dimension(dim)) {
          if (pshape.dimensions(dim) < ishape->dimensions(dim)) {
            return errors::InvalidArgument("Mismatching shapes");
          }
        } else if (pshape.dimensions(dim) != ishape->dimensions(dim)) {
          return errors::InvalidArgument("Mismatching shapes");
        }
      }
    }
    return Status::OK();
  };
  return xla::ShapeUtil::ForEachSubshapeWithStatus(parameter_shape,
                                                   shape_checker)
      .ok();
}

xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetInputTupleAllocations(
    const std::vector<InputCoords>& input_coords,
    XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend,
    int64 num_input_shapes,
    const std::function<xla::Shape(int64)>& shape_getter, bool release_inputs) {
  if (input_coords.size() != num_input_shapes) {
    return errors::InvalidArgument(
        "Number of inputs does not match executable proto input shapes: ",
        input_coords.size(), " vs. ", num_input_shapes);
  }
  std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
  input_tuples.reserve(input_coords.size());
  for (size_t i = 0; i < input_coords.size(); ++i) {
    TF_RETURN_IF_ERROR(
        working_set->LookupAndPin(backend, input_coords[i].handle));
    auto tuple = working_set->PinnedTuples().back();
    if (release_inputs) {
      // We are holding a reference to the tuple, so we can safely delete it
      // from the resource manager here.
      TF_RETURN_IF_ERROR(
          working_set->MemoryManager()->Release(input_coords[i].handle));
      VLOG(2) << "Released allocation handle " << input_coords[i].handle;
    }
    xla::Shape input_shape = shape_getter(i);
    if (!InputShapeMatches(input_shape, tuple->on_host_shape())) {
      return errors::InvalidArgument(
          "Run-time shape mismatch for XRTExecute argument[", i, "] (",
          input_coords[i].handle, "). Expected ", input_shape.DebugString(),
          "; got ", tuple->on_host_shape().DebugString());
    }
    if (input_coords[i].index.empty()) {
      input_tuples.emplace_back(std::move(tuple));
    } else {
      XRTTupleAllocation* sub_tuple;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          tuple.get(), input_coords[i].index, &sub_tuple,
          /*alias_parent_allocation=*/true));
      input_tuples.emplace_back(sub_tuple);
    }
  }
  return std::move(input_tuples);
}

Status RebuildOutputAliases(
    const RefPtr<XRTTupleAllocation>& output_tuple,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const xla::HloInputOutputAliasConfig& input_output_alias) {
  auto alias_function =
      [&](const xla::ShapeIndex& output_index,
          const xla::HloInputOutputAliasConfig::Alias& alias) -> Status {
    TF_RET_CHECK(alias.parameter_number < input_tuples.size());
    return output_tuple->AliasBufferFrom(*input_tuples[alias.parameter_number],
                                         alias.parameter_index, output_index);
  };
  return input_output_alias.ForEachAliasWithStatus(alias_function);
}

xla::StatusOr<std::vector<xla::ExecutionInput>> GetArgumentsBuffers(
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const std::vector<bool>& input_is_dynamic, bool release_inputs) {
  auto is_dynamic = [&](size_t arg) {
    return arg < input_is_dynamic.size() && input_is_dynamic[arg];
  };
  std::vector<xla::ExecutionInput> arguments;
  // Don't alias dynamic inputs: due to the underlying implementation, an
  // aliased input has two owners, the XRTAllocation and the return value of
  // this function. If an argument is dynamic and its ownership is released to
  // the output of this function, TPUExecute will free it and reallocate a new
  // one, which creates a double-free issue where XRTAllocation also attempts
  // to release the buffer.
  bool alias_outputs = release_inputs && input_tuples.size() == 1 &&
                       input_tuples[0]->IsExclusiveOwner() && !is_dynamic(0);
  arguments.reserve(input_tuples.size());
  for (int64 i = 0; i < input_tuples.size(); ++i) {
    auto alias_checker =
        [&](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
      if (input_output_alias.ParameterHasAlias(i, index)) {
        TF_RET_CHECK(!is_dynamic(i));
        return true;
      }
      return alias_outputs;
    };
    TF_ASSIGN_OR_RETURN(xla::ExecutionInput exec_input,
                        input_tuples[i]->ToExecutionInput(alias_checker));
    arguments.emplace_back(std::move(exec_input));
  }
  return std::move(arguments);
}

Status CreateExecuteOutput(OpKernelContext* context,
                           XRTMemoryManager* memory_manager,
                           RefPtr<XRTTupleAllocation> output_tuple,
                           bool return_exploded_tuple) {
  if (return_exploded_tuple && output_tuple->on_host_shape().IsTuple()) {
    int64 tuple_element_count =
        xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape());
    Tensor* output_tensor;
    TF_RETURN_IF_ERROR(context->allocate_output(
        0, TensorShape({tuple_element_count}), &output_tensor));

    for (int64 i = 0; i < tuple_element_count; ++i) {
      XRTTupleAllocation* suballocation;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          output_tuple.get(), {i}, &suballocation,
          /*alias_parent_allocation=*/false));
      output_tensor->vec<int64>()(i) = memory_manager->Register(suballocation);
    }
  } else {
    Tensor* output_tensor;
    TF_RETURN_IF_ERROR(
        context->allocate_output(0, TensorShape({}), &output_tensor));
    output_tensor->scalar<int64>()() =
        memory_manager->Register(std::move(output_tuple));
  }
  return Status::OK();
}

Status ExecuteChained(OpKernelContext* context,
                      const RefPtr<XRTMemoryManager>& memory_manager,
                      xla::Backend* backend, int device_ordinal,
                      const xrt::XRTChainedExecutePlan& plan,
                      const xrt::XRTChainedExecuteConfig& config,
                      const ChainedExecuteFn& execute_op) {
  // Create the vector which tracks the number of uses of each intermediate
  // chained operation output.
  std::vector<int64> uses(plan.ops_size(), 0);
  for (auto& op : plan.ops()) {
    for (auto& input : op.inputs()) {
      uses[input.op_index()] += 1;
    }
  }

  ScopedHandles outputs(memory_manager);
  ScopedHandles results(memory_manager);
  for (int i = 0; i < plan.ops_size(); ++i) {
    auto& op = plan.ops(i);
    if (op.op_oneof_case() == xrt::XRTChainedExecuteOp::kDataHandle) {
      // This operation is a device data load. Set the handle as output and
      // leave the release flag off, since this is not an intermediate output.
      TF_RETURN_IF_ERROR(outputs.Add(i, op.data_handle(), /*release=*/false));
    } else if (op.op_oneof_case() ==
               xrt::XRTChainedExecuteOp::kComputationHandle) {
      // This is an XRT execute operation; forward it to the device-specific
      // handler. Populating the working set makes sure the input allocations
      // for this execute operation are pinned to device memory.
      XRTMemoryManager::WorkingSet working_set(memory_manager);
      TF_RETURN_IF_ERROR(
          PopulateOpWorkingSet(backend, op, i, outputs, &working_set));
      TF_ASSIGN_OR_RETURN(auto tuple,
                          execute_op(op, working_set.PinnedTuples()));
      TF_RETURN_IF_ERROR(outputs.Add(i, std::move(tuple)));
    } else {
      return errors::InvalidArgument(
          "Undefined operation kind at post-order position ", i);
    }
    // If the result of this chained operation is required as an output, store
    // it into the results at the requested position.
    for (auto& output : op.outputs()) {
      TF_ASSIGN_OR_RETURN(auto tuple, outputs.Lookup(i));
      RefPtr<XRTTupleAllocation> result;
      TF_RETURN_IF_ERROR(MakeOutput(tuple, output.output_index(), &result));
      TF_RETURN_IF_ERROR(results.Add(output.result_index(), std::move(result)));
    }
    // Drop intermediate results which have no more users.
    for (auto& input : op.inputs()) {
      uses[input.op_index()] -= 1;
      if (uses[input.op_index()] == 0) {
        TF_RETURN_IF_ERROR(outputs.Drop(input.op_index()));
      }
    }
  }

  Tensor* output_tensor;
  TF_RETURN_IF_ERROR(context->allocate_output(
      0, TensorShape({static_cast<int64>(results.size())}), &output_tensor));
  for (size_t i = 0; i < results.size(); ++i) {
    output_tensor->vec<int64>()(i) = results.Release(i);
  }
  return Status::OK();
}

}  // namespace tensorflow