/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xrt/xrt_util.h"

#include <stdlib.h>
#include <string.h>

#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace {

mutex nccl_factory_mutex(LINKER_INITIALIZED);
std::shared_ptr<NcclUniqueIdFactory>* nccl_factory;

// The ScopedHandles data structure is used in the ExecuteChained() API and its
// task is to track tuple allocation registrations. It is used to track both
// the intermediate results of a chained computation and its final results.
// Anything which is marked to be released will be released via the
// XRTMemoryManager once the object is destroyed (unless an explicit call to
// Drop() or Release() is made).
class ScopedHandles {
 public:
  explicit ScopedHandles(RefPtr<XRTMemoryManager> memory_manager)
      : memory_manager_(std::move(memory_manager)) {}

  ~ScopedHandles() {
    for (size_t i = 0; i < handles_.size(); ++i) {
      if (handles_release_[i]) {
        memory_manager_->Release(handles_[i]).IgnoreError();
      }
    }
  }

  int64 operator[](size_t index) const { return handles_.at(index); }

  size_t size() const { return handles_.size(); }

  // Adds the given handle at the index position, marking it releasable
  // according to the release argument. If a to-be-released handle already
  // exists at the same index, it is released first.
  Status Add(size_t index, int64_t handle, bool release) {
    if (index >= handles_.size()) {
      handles_.resize(index + 1, XRTMemoryManager::InvalidKey());
      handles_release_.resize(index + 1, false);
    }
    if (handles_release_[index]) {
      Status status = memory_manager_->Release(handles_[index]);
      if (!status.ok()) {
        if (release) {
          memory_manager_->Release(handle).IgnoreError();
        }
        return status;
      }
    }
    handles_[index] = handle;
    handles_release_[index] = release;
    return Status::OK();
  }

  // Adds a to-be-released tuple allocation at the given index.
  Status Add(size_t index, RefPtr<XRTTupleAllocation> tuple) {
    return Add(index, memory_manager_->Register(std::move(tuple)),
               /*release=*/true);
  }

  // Drops the handle at the given index, releasing it via
  // XRTMemoryManager::Release() if it is marked as to-be-released.
  Status Drop(size_t index) {
    if (handles_release_.at(index)) {
      TF_RETURN_IF_ERROR(memory_manager_->Release(handles_[index]));
    }
    Release(index);
    return Status::OK();
  }

  // Releases the handle at the given index. The destructor will not call the
  // XRTMemoryManager::Release() API on such a handle.
  int64 Release(size_t index) {
    int64_t handle = handles_.at(index);
    handles_[index] = XRTMemoryManager::InvalidKey();
    handles_release_[index] = false;
    return handle;
  }

  // Looks up the handle stored at the given index, and returns the matching
  // tuple allocation.
  xla::StatusOr<RefPtr<XRTTupleAllocation>> Lookup(size_t index) const {
    return memory_manager_->Lookup(handles_.at(index));
  }

 private:
  RefPtr<XRTMemoryManager> memory_manager_;
  std::vector<int64> handles_;
  std::vector<bool> handles_release_;
};
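
// Illustrative ScopedHandles usage sketch; the index and handle values here
// are hypothetical:
//
//   ScopedHandles outputs(memory_manager);
//   // Record an existing device-data handle; do not release it on exit.
//   TF_RETURN_IF_ERROR(outputs.Add(/*index=*/0, data_handle,
//                                  /*release=*/false));
//   // Register a freshly produced tuple allocation; it will be released when
//   // the ScopedHandles object is destroyed, unless it is dropped or
//   // released explicitly first.
//   TF_RETURN_IF_ERROR(outputs.Add(/*index=*/1, std::move(tuple)));
//   TF_RETURN_IF_ERROR(outputs.Drop(/*index=*/1));  // No remaining users.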

bool DebugOptionsPassThroughEnabled() {
  const char* env = getenv("TF_XLA_DEBUG_OPTIONS_PASSTHROUGH");
  bool enabled =
      env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0);
  if (enabled) {
    LOG(WARNING) << "Passing through XLA debug options!";
  } else {
    LOG(WARNING) << "TF_XLA_DEBUG_OPTIONS_PASSTHROUGH not set, not all options "
                    "will be retained";
  }
  return enabled;
}

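// Returns the path unchanged if it is empty or starts with "gs://" or
// "bigstore://"; otherwise logs a warning and returns an empty string.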
string SafeDebugPath(const string& path) {
  if (path.empty() || path.compare(0, 5, "gs://") == 0 ||
      path.compare(0, 11, "bigstore://") == 0) {
    return path;
  }
  LOG(WARNING) << "Invalid config path (will be dropped): " << path;
  return string();
}

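// Resolves the result of a chained operation: index 0 refers to the whole
// output tuple, while index N > 0 aliases the (N - 1)-th element of the output
// tuple as a sub-buffer sharing the parent allocation.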
Status MakeOutput(const RefPtr<XRTTupleAllocation>& output, int64_t index,
                  RefPtr<XRTTupleAllocation>* result) {
  if (index == 0) {
    *result = output;
  } else {
    XRTTupleAllocation* tuple;
    TF_RETURN_IF_ERROR(
        XRTTupleAllocation::MakeSubBuffer(output.get(), {index - 1}, &tuple,
                                          /*alias_parent_allocation=*/true));
    result->reset(tuple);
  }
  return Status::OK();
}

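// Pins to device memory the allocations backing the inputs of a chained
// operation. Inputs may only refer to outputs of operations appearing earlier
// in the post-order, i.e. op_index < current_index.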
Status PopulateOpWorkingSet(xla::Backend* backend,
                            const xrt::XRTChainedExecuteOp& op,
                            int current_index, const ScopedHandles& outputs,
                            XRTMemoryManager::WorkingSet* working_set,
                            se::DeviceMemoryAllocator* allocator) {
  for (int i = 0; i < op.inputs_size(); ++i) {
    auto& input = op.inputs(i);
    if (input.op_index() >= current_index) {
      return errors::InvalidArgument(
          "Input index ", input.op_index(),
          " is above the current position: ", current_index);
    }
    TF_RETURN_IF_ERROR(working_set->LookupAndPin(
        backend, outputs[input.op_index()], allocator));
  }
  return Status::OK();
}

}  // namespace

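// Installs the process-wide NcclUniqueIdFactory. The backing pointer is
// lazily allocated under nccl_factory_mutex and never freed, so the factory
// remains valid for the lifetime of the process.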
void SetNcclUniqueIdFactory(std::shared_ptr<NcclUniqueIdFactory> factory) {
  mutex_lock lock(nccl_factory_mutex);
  if (nccl_factory == nullptr) {
    nccl_factory = new std::shared_ptr<NcclUniqueIdFactory>();
  }
  *nccl_factory = std::move(factory);
}

std::shared_ptr<NcclUniqueIdFactory> GetNcclUniqueIdFactory() {
  mutex_lock lock(nccl_factory_mutex);
  return nccl_factory != nullptr ? *nccl_factory : nullptr;
}

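// Builds the xla::DebugOptions to be used for a computation. Unless the
// TF_XLA_DEBUG_OPTIONS_PASSTHROUGH environment variable is set to "1" or
// "true" (for example via `export TF_XLA_DEBUG_OPTIONS_PASSTHROUGH=1` before
// starting the process), only the dump-related fields of the client-provided
// options are retained on top of the flag-derived defaults.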
xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) {
  static const bool options_passthrough = DebugOptionsPassThroughEnabled();
  if (options_passthrough) {
    return ref_options;
  }
  xla::DebugOptions options = xla::GetDebugOptionsFromFlags();
  options.set_xla_dump_to(SafeDebugPath(ref_options.xla_dump_to()));
  options.set_xla_dump_hlo_as_proto(ref_options.xla_dump_hlo_as_proto());
  options.set_xla_dump_hlo_as_text(ref_options.xla_dump_hlo_as_text());
  options.set_xla_dump_hlo_snapshots(ref_options.xla_dump_hlo_snapshots());
  options.set_xla_dump_hlo_pass_re(ref_options.xla_dump_hlo_pass_re());
  options.set_xla_dump_include_timestamp(
      ref_options.xla_dump_include_timestamp());
  options.set_xla_dump_max_hlo_modules(ref_options.xla_dump_max_hlo_modules());
  for (auto& pass : ref_options.xla_disable_hlo_passes()) {
    options.add_xla_disable_hlo_passes(pass);
  }
  return options;
}

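// Collects the XRT allocation handles fed to the op input list named
// input_name. Each list element is either a scalar handle or a vector of
// handles; they are flattened, in order, into a single InputCoords vector.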
xla::StatusOr<std::vector<InputCoords>> GetComputationInputs(
    OpKernelContext* context, const char* input_name) {
  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list(input_name, &arg_list));
  // Concatenate all input uids from list of scalars-or-vectors carrying them.
  std::vector<InputCoords> input_coords;
  for (int i = 0; i < arg_list.size(); ++i) {
    const Tensor& arg = arg_list[i];
    if (TensorShapeUtils::IsScalar(arg.shape())) {
      input_coords.emplace_back(arg.scalar<int64>()());
    } else {
      TF_RET_CHECK(TensorShapeUtils::IsVector(arg.shape()));
      auto arg_vec = arg.vec<int64>();
      const int64_t num_elts = arg.shape().dim_size(0);
      for (int j = 0; j < num_elts; ++j) {
        input_coords.emplace_back(arg_vec(j));
      }
    }
  }
  return std::move(input_coords);
}

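// Returns true if input_shape is compatible with parameter_shape: every array
// subshape must match in rank and element type, layouts must match for static
// shapes, and each dynamic parameter dimension must be at least as large as
// the corresponding input dimension (static dimensions must match exactly).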
bool InputShapeMatches(const xla::Shape& parameter_shape,
                       const xla::Shape& input_shape) {
  auto shape_checker = [&](const xla::Shape& pshape,
                           const xla::ShapeIndex& index) {
    if (pshape.IsArray()) {
      TF_ASSIGN_OR_RETURN(const xla::Shape* ishape,
                          xla::ShapeUtil::TryGetSubshape(input_shape, index));
      if (pshape.rank() != ishape->rank() ||
          pshape.element_type() != ishape->element_type()) {
        return errors::InvalidArgument("Mismatching shapes");
      }
      if (pshape.is_static() && pshape.layout() != ishape->layout()) {
        return errors::InvalidArgument("Mismatching layouts");
      }
      for (int64_t dim = 0; dim < pshape.rank(); ++dim) {
        if (pshape.is_dynamic_dimension(dim)) {
          if (pshape.dimensions(dim) < ishape->dimensions(dim)) {
            return errors::InvalidArgument("Mismatching shapes");
          }
        } else if (pshape.dimensions(dim) != ishape->dimensions(dim)) {
          return errors::InvalidArgument("Mismatching shapes");
        }
      }
    }
    return Status::OK();
  };
  return xla::ShapeUtil::ForEachSubshapeWithStatus(parameter_shape,
                                                   shape_checker)
      .ok();
}

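// Looks up and pins the tuple allocations backing the given input handles.
// The number of inputs must match num_input_shapes, and each input's on-host
// shape must be compatible with the shape returned by shape_getter. If
// release_inputs is true, the handles are released from the memory manager
// (the pinned references keep the allocations alive). A non-empty index in an
// InputCoords selects a sub-tuple of the referenced allocation.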
xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetInputTupleAllocations(
    const std::vector<InputCoords>& input_coords,
    XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend,
    int64_t num_input_shapes,
    const std::function<xla::Shape(int64_t)>& shape_getter, bool release_inputs,
    se::DeviceMemoryAllocator* allocator) {
  if (input_coords.size() != num_input_shapes) {
    return errors::InvalidArgument(
        "Number of inputs does not match executable proto input shapes: ",
        input_coords.size(), " vs. ", num_input_shapes);
  }
  std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
  input_tuples.reserve(input_coords.size());
  for (size_t i = 0; i < input_coords.size(); ++i) {
    TF_RETURN_IF_ERROR(
        working_set->LookupAndPin(backend, input_coords[i].handle, allocator));
    auto tuple = working_set->PinnedTuples().back();
    if (release_inputs) {
      // We are holding a reference to the tuple, so we can safely delete it
      // from the resource manager here.
      TF_RETURN_IF_ERROR(
          working_set->MemoryManager()->Release(input_coords[i].handle));
      VLOG(2) << "Released allocation handle " << input_coords[i].handle;
    }
    xla::Shape input_shape = shape_getter(i);
    if (!InputShapeMatches(input_shape, tuple->on_host_shape())) {
      return errors::InvalidArgument(
          "Run-time shape mismatch for XRTExecute argument[", i, "] (",
          input_coords[i].handle, "). Expected ", input_shape.DebugString(),
          "; got ", tuple->on_host_shape().DebugString());
    }
    if (input_coords[i].index.empty()) {
      input_tuples.emplace_back(std::move(tuple));
    } else {
      XRTTupleAllocation* sub_tuple;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          tuple.get(), input_coords[i].index, &sub_tuple,
          /*alias_parent_allocation=*/true));
      input_tuples.emplace_back(sub_tuple);
    }
  }
  return std::move(input_tuples);
}

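// Re-establishes the donated-buffer aliases described by input_output_alias:
// for every aliased (parameter, output) pair, the output tuple buffer is made
// to alias the corresponding input tuple buffer.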
Status RebuildOutputAliases(
    const RefPtr<XRTTupleAllocation>& output_tuple,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const xla::HloInputOutputAliasConfig& input_output_alias) {
  auto alias_function =
      [&](const xla::ShapeIndex& output_index,
          const xla::HloInputOutputAliasConfig::Alias& alias) -> Status {
    TF_RET_CHECK(alias.parameter_number < input_tuples.size());
    return output_tuple->AliasBufferFrom(*input_tuples[alias.parameter_number],
                                         alias.parameter_index, output_index);
  };
  return input_output_alias.ForEachAliasWithStatus(alias_function);
}

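// Converts the input tuple allocations into xla::ExecutionInput arguments,
// deciding buffer by buffer whether ownership can be donated to the executable
// (see the aliasing notes in the function body).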
xla::StatusOr<std::vector<xla::ExecutionInput>> GetArgumentsBuffers(
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const std::vector<bool>& input_is_dynamic, bool release_inputs) {
  auto is_dynamic = [&](size_t arg) {
    return arg < input_is_dynamic.size() && input_is_dynamic[arg];
  };
  std::vector<xla::ExecutionInput> arguments;
  // Don't alias dynamic inputs: due to the underlying implementation, aliased
  // inputs have two owners, the XRTAllocation and the return value of this
  // function. If an argument is dynamic and its ownership is released to the
  // output of this function, TPUExecute will free it and reallocate a new one,
  // creating a double-free issue where the XRTAllocation also attempts to
  // release the buffer.
  bool alias_outputs = release_inputs && input_tuples.size() == 1 &&
                       input_tuples[0]->IsExclusiveOwner() && !is_dynamic(0);
  arguments.reserve(input_tuples.size());
  for (int64_t i = 0; i < input_tuples.size(); ++i) {
    auto alias_checker =
        [&](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
      if (input_output_alias.ParameterHasAlias(i, index)) {
        TF_RET_CHECK(!is_dynamic(i));
        return true;
      }
      return alias_outputs;
    };
    TF_ASSIGN_OR_RETURN(xla::ExecutionInput exec_input,
                        input_tuples[i]->ToExecutionInput(alias_checker));
    arguments.emplace_back(std::move(exec_input));
  }
  return std::move(arguments);
}

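// Emits the op output for an execute result. If return_exploded_tuple is true
// and the result is a tuple, each top-level tuple element is registered as a
// separate allocation and a vector of handles is produced; otherwise a single
// scalar handle for the whole allocation is produced.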
Status CreateExecuteOutput(OpKernelContext* context,
                           XRTMemoryManager* memory_manager,
                           RefPtr<XRTTupleAllocation> output_tuple,
                           bool return_exploded_tuple) {
  if (return_exploded_tuple && output_tuple->on_host_shape().IsTuple()) {
    int64_t tuple_element_count =
        xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape());
    Tensor* output_tensor;
    TF_RETURN_IF_ERROR(context->allocate_output(
        0, TensorShape({tuple_element_count}), &output_tensor));

    for (int64_t i = 0; i < tuple_element_count; ++i) {
      XRTTupleAllocation* suballocation;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          output_tuple.get(), {i}, &suballocation,
          /*alias_parent_allocation=*/false));
      output_tensor->vec<int64>()(i) = memory_manager->Register(suballocation);
    }
  } else {
    Tensor* output_tensor;
    TF_RETURN_IF_ERROR(
        context->allocate_output(0, TensorShape({}), &output_tensor));
    output_tensor->scalar<int64>()() =
        memory_manager->Register(std::move(output_tuple));
  }
  return Status::OK();
}

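// Runs a chained-execution plan: operations are processed in post-order, data
// handles are recorded as-is, computations are dispatched through execute_op
// with their inputs pinned to device memory, requested outputs are collected
// into the result vector, and intermediate outputs are dropped as soon as they
// have no remaining users. Illustrative sketch of a device-specific execute_op
// callback, assuming the ChainedExecuteFn signature declared in xrt_util.h
// takes the chained op and its pinned input tuples; RunOnDevice() is a
// hypothetical helper:
//
//   ChainedExecuteFn execute_op =
//       [&](const xrt::XRTChainedExecuteOp& op,
//           absl::Span<const RefPtr<XRTTupleAllocation>> inputs)
//       -> xla::StatusOr<RefPtr<XRTTupleAllocation>> {
//     // Execute the computation identified by op.computation_handle() using
//     // the pinned input allocations, returning the result tuple.
//     return RunOnDevice(op.computation_handle(), inputs);
//   };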
Status ExecuteChained(OpKernelContext* context,
                      const RefPtr<XRTMemoryManager>& memory_manager,
                      xla::Backend* backend, int device_ordinal,
                      const xrt::XRTChainedExecutePlan& plan,
                      const xrt::XRTChainedExecuteConfig& config,
                      const ChainedExecuteFn& execute_op,
                      se::DeviceMemoryAllocator* allocator) {
  // Create the vector which tracks how many times each intermediate chained
  // operation output is used.
  std::vector<int64> uses(plan.ops_size(), 0);
  for (auto& op : plan.ops()) {
    for (auto& input : op.inputs()) {
      uses[input.op_index()] += 1;
    }
  }

  ScopedHandles outputs(memory_manager);
  ScopedHandles results(memory_manager);
  for (int i = 0; i < plan.ops_size(); ++i) {
    auto& op = plan.ops(i);
    if (op.op_oneof_case() == xrt::XRTChainedExecuteOp::kDataHandle) {
      // This operation is a device data load. Set the handle as output and
      // leave the release flag off, since this is not an intermediate output.
      TF_RETURN_IF_ERROR(outputs.Add(i, op.data_handle(), /*release=*/false));
    } else if (op.op_oneof_case() ==
               xrt::XRTChainedExecuteOp::kComputationHandle) {
      // This is an XRT execute operation; forward it to the device-specific
      // handler. Populating the working set makes sure the input allocations
      // for this execute operation are pinned to device memory.
      XRTMemoryManager::WorkingSet working_set(memory_manager);
      TF_RETURN_IF_ERROR(PopulateOpWorkingSet(backend, op, i, outputs,
                                              &working_set, allocator));
      TF_ASSIGN_OR_RETURN(auto tuple,
                          execute_op(op, working_set.PinnedTuples()));
      TF_RETURN_IF_ERROR(outputs.Add(i, std::move(tuple)));
    } else {
      return errors::InvalidArgument(
          "Undefined operation kind at post-order position ", i);
    }
    // If the result of this chained operation is an output result, feed the
    // result at the desired position.
    for (auto& output : op.outputs()) {
      TF_ASSIGN_OR_RETURN(auto tuple, outputs.Lookup(i));
      RefPtr<XRTTupleAllocation> result;
      TF_RETURN_IF_ERROR(MakeOutput(tuple, output.output_index(), &result));
      TF_RETURN_IF_ERROR(results.Add(output.result_index(), std::move(result)));
    }
    // Drop intermediate results which have no more users.
    for (auto& input : op.inputs()) {
      uses[input.op_index()] -= 1;
      if (uses[input.op_index()] == 0) {
        TF_RETURN_IF_ERROR(outputs.Drop(input.op_index()));
      }
    }
  }

  Tensor* output_tensor;
  TF_RETURN_IF_ERROR(context->allocate_output(
      0, TensorShape({static_cast<int64>(results.size())}), &output_tensor));
  for (size_t i = 0; i < results.size(); ++i) {
    output_tensor->vec<int64>()(i) = results.Release(i);
  }
  return Status::OK();
}

}  // namespace tensorflow