/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xrt/xrt_util.h"

#include <stdlib.h>
#include <string.h>

#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace {

mutex nccl_factory_mutex(LINKER_INITIALIZED);
std::shared_ptr<NcclUniqueIdFactory>* nccl_factory;

// The ScopedHandles data structure is used in the ExecuteChained() API, and
// its task is to track tuple allocation registrations. It is used both to
// track the intermediate results of a chained computation and its final
// results. Anything which is marked to be released will be released using the
// XRTMemoryManager once the object is destroyed (unless an explicit call to
// Drop() or Release() is made).
class ScopedHandles {
 public:
  explicit ScopedHandles(RefPtr<XRTMemoryManager> memory_manager)
      : memory_manager_(std::move(memory_manager)) {}

  ~ScopedHandles() {
    for (size_t i = 0; i < handles_.size(); ++i) {
      if (handles_release_[i]) {
        memory_manager_->Release(handles_[i]).IgnoreError();
      }
    }
  }

  int64_t operator[](size_t index) const { return handles_.at(index); }

  size_t size() const { return handles_.size(); }

  // Adds the given handle at the index position, marking it releasable
  // according to the release argument. If a to-be-released handle already
  // exists at the same index, it is released first.
  Status Add(size_t index, int64_t handle, bool release) {
    if (index >= handles_.size()) {
      handles_.resize(index + 1, XRTMemoryManager::InvalidKey());
      handles_release_.resize(index + 1, false);
    }
    if (handles_release_[index]) {
      Status status = memory_manager_->Release(handles_[index]);
      if (!status.ok()) {
        if (release) {
          memory_manager_->Release(handle).IgnoreError();
        }
        return status;
      }
    }
    handles_[index] = handle;
    handles_release_[index] = release;
    return OkStatus();
  }

  // Adds a to-be-released tuple allocation at the given index.
  Status Add(size_t index, RefPtr<XRTTupleAllocation> tuple) {
    return Add(index, memory_manager_->Register(std::move(tuple)),
               /*release=*/true);
  }

  // Drops the handle at the given index, releasing it via
  // XRTMemoryManager::Release() if it is marked as to-be-released.
  Status Drop(size_t index) {
    if (handles_release_.at(index)) {
      TF_RETURN_IF_ERROR(memory_manager_->Release(handles_[index]));
    }
    Release(index);
    return OkStatus();
  }

  // Releases the handle at the given index. The destructor will not call the
  // XRTMemoryManager::Release() API on such a handle.
  int64_t Release(size_t index) {
    int64_t handle = handles_.at(index);
    handles_[index] = XRTMemoryManager::InvalidKey();
    handles_release_[index] = false;
    return handle;
  }

  // Looks up the handle stored at the given index, and returns the matching
  // tuple allocation.
  xla::StatusOr<RefPtr<XRTTupleAllocation>> Lookup(size_t index) const {
    return memory_manager_->Lookup(handles_.at(index));
  }

 private:
  RefPtr<XRTMemoryManager> memory_manager_;
  std::vector<int64_t> handles_;
  std::vector<bool> handles_release_;
};
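
// Example usage (illustrative sketch only; `memory_manager` and `data_handle`
// below are hypothetical stand-ins, not names defined in this file):
//
//   ScopedHandles outputs(memory_manager);
//   TF_RETURN_IF_ERROR(outputs.Add(/*index=*/0, data_handle,
//                                  /*release=*/false));
//   TF_ASSIGN_OR_RETURN(RefPtr<XRTTupleAllocation> tuple, outputs.Lookup(0));
//   int64_t handle = outputs.Release(0);  // Destructor will not release it.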

bool DebugOptionsPassThroughEnabled() {
  const char* env = getenv("TF_XLA_DEBUG_OPTIONS_PASSTHROUGH");
  bool enabled =
      env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0);
  if (enabled) {
    LOG(WARNING) << "Passing through XLA debug options!";
  } else {
    LOG(WARNING) << "TF_XLA_DEBUG_OPTIONS_PASSTHROUGH not set, not all options "
                    "will be retained";
  }
  return enabled;
}

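// Returns the given dump path if it is empty or points to a remote (gs:// or
// bigstore://) location; otherwise logs a warning and returns an empty string.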
string SafeDebugPath(const string& path) {
  if (path.empty() || path.compare(0, 5, "gs://") == 0 ||
      path.compare(0, 11, "bigstore://") == 0) {
    return path;
  }
  LOG(WARNING) << "Invalid config path (will be dropped): " << path;
  return string();
}

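// Sets *result to the whole output tuple when index is 0, otherwise to a
// sub-buffer aliasing element (index - 1) of the output tuple.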
Status MakeOutput(const RefPtr<XRTTupleAllocation>& output, int64_t index,
                  RefPtr<XRTTupleAllocation>* result) {
  if (index == 0) {
    *result = output;
  } else {
    XRTTupleAllocation* tuple;
    TF_RETURN_IF_ERROR(
        XRTTupleAllocation::MakeSubBuffer(output.get(), {index - 1}, &tuple,
                                          /*alias_parent_allocation=*/true));
    result->reset(tuple);
  }
  return OkStatus();
}

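// Pins into the working set the allocations produced by the ops feeding the
// chained op at current_index. Every input must refer to an earlier position
// of the post-order plan.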
Status PopulateOpWorkingSet(xla::Backend* backend,
                            const xrt::XRTChainedExecuteOp& op,
                            int current_index, const ScopedHandles& outputs,
                            XRTMemoryManager::WorkingSet* working_set,
                            se::DeviceMemoryAllocator* allocator) {
  for (int i = 0; i < op.inputs_size(); ++i) {
    auto& input = op.inputs(i);
    if (input.op_index() >= current_index) {
      return errors::InvalidArgument(
          "Input index ", input.op_index(),
          " is above the current position: ", current_index);
    }
    TF_RETURN_IF_ERROR(working_set->LookupAndPin(
        backend, outputs[input.op_index()], allocator));
  }
  return OkStatus();
}

}  // namespace

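// Installs the process-wide NcclUniqueIdFactory, replacing any previously set
// factory.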
void SetNcclUniqueIdFactory(std::shared_ptr<NcclUniqueIdFactory> factory) {
  mutex_lock lock(nccl_factory_mutex);
  if (nccl_factory == nullptr) {
    nccl_factory = new std::shared_ptr<NcclUniqueIdFactory>();
  }
  *nccl_factory = std::move(factory);
}

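// Returns the currently installed NcclUniqueIdFactory, or nullptr if none has
// been set.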
std::shared_ptr<NcclUniqueIdFactory> GetNcclUniqueIdFactory() {
  mutex_lock lock(nccl_factory_mutex);
  return nccl_factory != nullptr ? *nccl_factory : nullptr;
}

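// Builds the xla::DebugOptions used for compilation. Unless the
// TF_XLA_DEBUG_OPTIONS_PASSTHROUGH environment variable is set, only the
// dump-related options are copied from ref_options on top of the
// flag-provided defaults.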
xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) {
  static const bool options_passthrough = DebugOptionsPassThroughEnabled();
  if (options_passthrough) {
    return ref_options;
  }
  xla::DebugOptions options = xla::GetDebugOptionsFromFlags();
  options.set_xla_dump_to(SafeDebugPath(ref_options.xla_dump_to()));
  options.set_xla_dump_hlo_as_proto(ref_options.xla_dump_hlo_as_proto());
  options.set_xla_dump_hlo_as_text(ref_options.xla_dump_hlo_as_text());
  options.set_xla_dump_hlo_snapshots(ref_options.xla_dump_hlo_snapshots());
  options.set_xla_dump_hlo_pass_re(ref_options.xla_dump_hlo_pass_re());
  options.set_xla_dump_include_timestamp(
      ref_options.xla_dump_include_timestamp());
  options.set_xla_dump_max_hlo_modules(ref_options.xla_dump_max_hlo_modules());
  for (auto& pass : ref_options.xla_disable_hlo_passes()) {
    options.add_xla_disable_hlo_passes(pass);
  }
  return options;
}

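// Reads the op input list named input_name and flattens its int64 scalars and
// vectors into a single vector of input allocation coordinates.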
xla::StatusOr<std::vector<InputCoords>> GetComputationInputs(
    OpKernelContext* context, const char* input_name) {
  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list(input_name, &arg_list));
  // Concatenate all input uids from the list of scalars-or-vectors carrying
  // them.
  std::vector<InputCoords> input_coords;
  for (int i = 0; i < arg_list.size(); ++i) {
    const Tensor& arg = arg_list[i];
    if (TensorShapeUtils::IsScalar(arg.shape())) {
      input_coords.emplace_back(arg.scalar<int64_t>()());
    } else {
      TF_RET_CHECK(TensorShapeUtils::IsVector(arg.shape()));
      auto arg_vec = arg.vec<int64_t>();
      const int64_t num_elts = arg.shape().dim_size(0);
      // Use a distinct loop variable to avoid shadowing the outer index.
      for (int64_t j = 0; j < num_elts; ++j) {
        input_coords.emplace_back(arg_vec(j));
      }
    }
  }
  return std::move(input_coords);
}

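// Returns true if input_shape is compatible with parameter_shape: matching
// element types, ranks and (for static shapes) layouts, with dynamic parameter
// dimensions accepting any input dimension up to their bound.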
bool InputShapeMatches(const xla::Shape& parameter_shape,
                       const xla::Shape& input_shape) {
  auto shape_checker = [&](const xla::Shape& pshape,
                           const xla::ShapeIndex& index) {
    if (pshape.IsArray()) {
      TF_ASSIGN_OR_RETURN(const xla::Shape* ishape,
                          xla::ShapeUtil::TryGetSubshape(input_shape, index));
      if (pshape.rank() != ishape->rank() ||
          pshape.element_type() != ishape->element_type()) {
        return errors::InvalidArgument("Mismatching shapes");
      }
      if (pshape.is_static() && !xla::Layout::Equal().IgnoreTiles()(
                                    pshape.layout(), ishape->layout())) {
        return errors::InvalidArgument("Mismatching layouts");
      }
      for (int64_t dim = 0; dim < pshape.rank(); ++dim) {
        if (pshape.is_dynamic_dimension(dim)) {
          if (pshape.dimensions(dim) < ishape->dimensions(dim)) {
            return errors::InvalidArgument("Mismatching shapes");
          }
        } else if (pshape.dimensions(dim) != ishape->dimensions(dim)) {
          return errors::InvalidArgument("Mismatching shapes");
        }
      }
    }
    return OkStatus();
  };
  return xla::ShapeUtil::ForEachSubshapeWithStatus(parameter_shape,
                                                   shape_checker)
      .ok();
}

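// Looks up and pins into the working set the tuple allocations referenced by
// input_coords, checking each one against the corresponding executable input
// shape. If release_inputs is true, the handles are released from the
// XRTMemoryManager as soon as they are pinned.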
xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetInputTupleAllocations(
    const std::vector<InputCoords>& input_coords,
    XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend,
    int64_t num_input_shapes,
    const std::function<xla::Shape(int64_t)>& shape_getter, bool release_inputs,
    se::DeviceMemoryAllocator* allocator) {
  if (input_coords.size() != num_input_shapes) {
    return errors::InvalidArgument(
        "Number of inputs does not match executable proto input shapes: ",
        input_coords.size(), " vs. ", num_input_shapes);
  }
  std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
  input_tuples.reserve(input_coords.size());
  for (size_t i = 0; i < input_coords.size(); ++i) {
    TF_RETURN_IF_ERROR(
        working_set->LookupAndPin(backend, input_coords[i].handle, allocator));
    auto tuple = working_set->PinnedTuples().back();
    if (release_inputs) {
      // We are holding a reference to the tuple, so we can safely delete it
      // from the resource manager here.
      TF_RETURN_IF_ERROR(
          working_set->MemoryManager()->Release(input_coords[i].handle));
      VLOG(2) << "Released allocation handle " << input_coords[i].handle;
    }
    xla::Shape input_shape = shape_getter(i);
    if (!InputShapeMatches(input_shape, tuple->on_host_shape())) {
      return errors::InvalidArgument(
          "Run-time shape mismatch for XRTExecute argument[", i, "] (",
          input_coords[i].handle, "). Expected ", input_shape.DebugString(),
          "; got ", tuple->on_host_shape().DebugString());
    }
    if (input_coords[i].index.empty()) {
      input_tuples.emplace_back(std::move(tuple));
    } else {
      XRTTupleAllocation* sub_tuple;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          tuple.get(), input_coords[i].index, &sub_tuple,
          /*alias_parent_allocation=*/true));
      input_tuples.emplace_back(sub_tuple);
    }
  }
  return std::move(input_tuples);
}

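// Re-establishes on the output tuple the input/output buffer aliases described
// by input_output_alias, aliasing buffers from the corresponding input tuples.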
Status RebuildOutputAliases(
    const RefPtr<XRTTupleAllocation>& output_tuple,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const xla::HloInputOutputAliasConfig& input_output_alias) {
  auto alias_function =
      [&](const xla::ShapeIndex& output_index,
          const xla::HloInputOutputAliasConfig::Alias& alias) -> Status {
    TF_RET_CHECK(alias.parameter_number < input_tuples.size());
    return output_tuple->AliasBufferFrom(*input_tuples[alias.parameter_number],
                                         alias.parameter_index, output_index);
  };
  return input_output_alias.ForEachAliasWithStatus(alias_function);
}

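// Builds the xla::ExecutionInput arguments from the input tuple allocations,
// donating input buffers only when aliasing is safe (see the comment in the
// function body).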
xla::StatusOr<std::vector<xla::ExecutionInput>> GetArgumentsBuffers(
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const std::vector<bool>& input_is_dynamic, bool release_inputs) {
  auto is_dynamic = [&](size_t arg) {
    return arg < input_is_dynamic.size() && input_is_dynamic[arg];
  };
  std::vector<xla::ExecutionInput> arguments;
  // Don't alias dynamic inputs -- due to the underlying implementation,
  // aliased inputs have two owners: the XRTAllocation and the return value of
  // this function. If an argument is dynamic and its ownership is released to
  // the output of this function, TPUExecute will free it and reallocate a new
  // one, which creates a double-free issue where the XRTAllocation also
  // attempts to release the buffer.
  bool alias_outputs = release_inputs && input_tuples.size() == 1 &&
                       input_tuples[0]->IsExclusiveOwner() && !is_dynamic(0);
  arguments.reserve(input_tuples.size());
  for (int64_t i = 0; i < input_tuples.size(); ++i) {
    auto alias_checker =
        [&](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
      if (input_output_alias.ParameterHasAlias(i, index)) {
        TF_RET_CHECK(!is_dynamic(i));
        return true;
      }
      return alias_outputs;
    };
    TF_ASSIGN_OR_RETURN(xla::ExecutionInput exec_input,
                        input_tuples[i]->ToExecutionInput(alias_checker));
    arguments.emplace_back(std::move(exec_input));
  }
  return std::move(arguments);
}

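// Registers the execution output with the XRTMemoryManager and emits the
// resulting handles as the op output: a vector of per-element handles when
// return_exploded_tuple is true and the output is a tuple, a single scalar
// handle otherwise.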
Status CreateExecuteOutput(OpKernelContext* context,
                           XRTMemoryManager* memory_manager,
                           RefPtr<XRTTupleAllocation> output_tuple,
                           bool return_exploded_tuple) {
  if (return_exploded_tuple && output_tuple->on_host_shape().IsTuple()) {
    int64_t tuple_element_count =
        xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape());
    Tensor* output_tensor;
    TF_RETURN_IF_ERROR(context->allocate_output(
        0, TensorShape({tuple_element_count}), &output_tensor));

    for (int64_t i = 0; i < tuple_element_count; ++i) {
      XRTTupleAllocation* suballocation;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          output_tuple.get(), {i}, &suballocation,
          /*alias_parent_allocation=*/false));
      output_tensor->vec<int64_t>()(i) =
          memory_manager->Register(suballocation);
    }
  } else {
    Tensor* output_tensor;
    TF_RETURN_IF_ERROR(
        context->allocate_output(0, TensorShape({}), &output_tensor));
    output_tensor->scalar<int64_t>()() =
        memory_manager->Register(std::move(output_tuple));
  }
  return OkStatus();
}

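// Runs the post-order chained execution plan, invoking execute_op for every
// computation node, tracking intermediate results with ScopedHandles, and
// dropping them as soon as they have no more users. The handles of the final
// results are returned in the op output tensor.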
Status ExecuteChained(OpKernelContext* context,
                      const RefPtr<XRTMemoryManager>& memory_manager,
                      xla::Backend* backend, int device_ordinal,
                      const xrt::XRTChainedExecutePlan& plan,
                      const xrt::XRTChainedExecuteConfig& config,
                      const ChainedExecuteFn& execute_op,
                      se::DeviceMemoryAllocator* allocator) {
  // Create the vector which tracks the uses of the intermediate chained
  // operation outputs.
  std::vector<int64_t> uses(plan.ops_size(), 0);
  for (auto& op : plan.ops()) {
    for (auto& input : op.inputs()) {
      uses[input.op_index()] += 1;
    }
  }

  ScopedHandles outputs(memory_manager);
  ScopedHandles results(memory_manager);
  for (int i = 0; i < plan.ops_size(); ++i) {
    auto& op = plan.ops(i);
    if (op.op_oneof_case() == xrt::XRTChainedExecuteOp::kDataHandle) {
      // This operation is a device data load. Set the handle as output and
      // leave the release flag off, since this is not an intermediate output.
      TF_RETURN_IF_ERROR(outputs.Add(i, op.data_handle(), /*release=*/false));
    } else if (op.op_oneof_case() ==
               xrt::XRTChainedExecuteOp::kComputationHandle) {
      // This is an XRT execute operation; forward it to the device specific
      // handler. Populating the working set makes sure the input allocations
      // for this execute operation are pinned to device memory.
      XRTMemoryManager::WorkingSet working_set(memory_manager);
      TF_RETURN_IF_ERROR(PopulateOpWorkingSet(backend, op, i, outputs,
                                              &working_set, allocator));
      TF_ASSIGN_OR_RETURN(auto tuple,
                          execute_op(op, working_set.PinnedTuples()));
      TF_RETURN_IF_ERROR(outputs.Add(i, std::move(tuple)));
    } else {
      return errors::InvalidArgument(
          "Undefined operation kind at post-order position ", i);
    }
    // If the result of this chained operation is an output result, store it at
    // the requested result position.
    for (auto& output : op.outputs()) {
      TF_ASSIGN_OR_RETURN(auto tuple, outputs.Lookup(i));
      RefPtr<XRTTupleAllocation> result;
      TF_RETURN_IF_ERROR(MakeOutput(tuple, output.output_index(), &result));
      TF_RETURN_IF_ERROR(results.Add(output.result_index(), std::move(result)));
    }
    // Drop intermediate results which have no more users.
    for (auto& input : op.inputs()) {
      uses[input.op_index()] -= 1;
      if (uses[input.op_index()] == 0) {
        TF_RETURN_IF_ERROR(outputs.Drop(input.op_index()));
      }
    }
  }

  Tensor* output_tensor;
  TF_RETURN_IF_ERROR(context->allocate_output(
      0, TensorShape({static_cast<int64_t>(results.size())}), &output_tensor));
  for (size_t i = 0; i < results.size(); ++i) {
    output_tensor->vec<int64_t>()(i) = results.Release(i);
  }
  return OkStatus();
}

}  // namespace tensorflow