/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xrt/xrt_util.h"

#include <stdlib.h>
#include <string.h>

#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace {

mutex nccl_factory_mutex(LINKER_INITIALIZED);
std::shared_ptr<NcclUniqueIdFactory>* nccl_factory;

// The ScopedHandles data structure is used in the ExecuteChained() API, and
// its task is to track tuple allocation registrations. It is used to track
// both intermediate results of a chained computation and its final results.
// Anything marked to be released will be released via the XRTMemoryManager
// once the object is destroyed (unless an explicit call to Drop() or Release()
// is made).
class ScopedHandles {
 public:
  explicit ScopedHandles(RefPtr<XRTMemoryManager> memory_manager)
      : memory_manager_(std::move(memory_manager)) {}

  ~ScopedHandles() {
    for (size_t i = 0; i < handles_.size(); ++i) {
      if (handles_release_[i]) {
        memory_manager_->Release(handles_[i]).IgnoreError();
      }
    }
  }

  int64_t operator[](size_t index) const { return handles_.at(index); }

  size_t size() const { return handles_.size(); }

  // Adds the given handle at the index position, marking it releasable
  // according to the release argument. If a to-be-released handle is already
  // stored at the same index, it is released first.
  Status Add(size_t index, int64_t handle, bool release) {
    if (index >= handles_.size()) {
      handles_.resize(index + 1, XRTMemoryManager::InvalidKey());
      handles_release_.resize(index + 1, false);
    }
    if (handles_release_[index]) {
      Status status = memory_manager_->Release(handles_[index]);
      if (!status.ok()) {
        if (release) {
          memory_manager_->Release(handle).IgnoreError();
        }
        return status;
      }
    }
    handles_[index] = handle;
    handles_release_[index] = release;
    return OkStatus();
  }

  // Adds a to-be-released tuple allocation at the given index.
  Status Add(size_t index, RefPtr<XRTTupleAllocation> tuple) {
    return Add(index, memory_manager_->Register(std::move(tuple)),
               /*release=*/true);
  }

  // Drops the handle at the given index, and releases it using the
  // XRTMemoryManager::Release() if marked as to-be-released.
  Status Drop(size_t index) {
    if (handles_release_.at(index)) {
      TF_RETURN_IF_ERROR(memory_manager_->Release(handles_[index]));
    }
    Release(index);
    return OkStatus();
  }

  // Releases the handle at the given index. The destructor will no longer
  // call the XRTMemoryManager::Release() API on that handle.
  int64_t Release(size_t index) {
    int64_t handle = handles_.at(index);
    handles_[index] = XRTMemoryManager::InvalidKey();
    handles_release_[index] = false;
    return handle;
  }

  // Looks up the handle stored at the given index, and returns the matching
  // tuple allocation.
  xla::StatusOr<RefPtr<XRTTupleAllocation>> Lookup(size_t index) const {
    return memory_manager_->Lookup(handles_.at(index));
  }

 private:
  RefPtr<XRTMemoryManager> memory_manager_;
  std::vector<int64_t> handles_;
  std::vector<bool> handles_release_;
};
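
// Example (illustrative sketch, not part of the original source): how
// ExecuteChained-style code can use ScopedHandles. The handle value and the
// `tuple` variable below are hypothetical.
//
//   ScopedHandles outputs(memory_manager);
//   // Track an existing allocation handle without taking ownership of it.
//   TF_RETURN_IF_ERROR(outputs.Add(/*index=*/0, /*handle=*/1001,
//                                  /*release=*/false));
//   // Register a freshly produced tuple; it is released on destruction
//   // unless Drop() or Release() is called for that index.
//   TF_RETURN_IF_ERROR(outputs.Add(/*index=*/1, std::move(tuple)));
//   int64_t kept = outputs.Release(1);  // The caller now owns this handle.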

bool DebugOptionsPassThroughEnabled() {
  const char* env = getenv("TF_XLA_DEBUG_OPTIONS_PASSTHROUGH");
  bool enabled =
      env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0);
  if (enabled) {
    LOG(WARNING) << "Passing through XLA debug options!";
  } else {
    LOG(WARNING) << "TF_XLA_DEBUG_OPTIONS_PASSTHROUGH not set, not all options "
                    "will be retained";
  }
  return enabled;
}

string SafeDebugPath(const string& path) {
  if (path.empty() || path.compare(0, 5, "gs://") == 0 ||
      path.compare(0, 11, "bigstore://") == 0) {
    return path;
  }
  LOG(WARNING) << "Invalid config path (will be dropped): " << path;
  return string();
}

Status MakeOutput(const RefPtr<XRTTupleAllocation>& output, int64_t index,
                  RefPtr<XRTTupleAllocation>* result) {
  if (index == 0) {
    *result = output;
  } else {
    XRTTupleAllocation* tuple;
    TF_RETURN_IF_ERROR(
        XRTTupleAllocation::MakeSubBuffer(output.get(), {index - 1}, &tuple,
                                          /*alias_parent_allocation=*/true));
    result->reset(tuple);
  }
  return OkStatus();
}

Status PopulateOpWorkingSet(xla::Backend* backend,
                            const xrt::XRTChainedExecuteOp& op,
                            int current_index, const ScopedHandles& outputs,
                            XRTMemoryManager::WorkingSet* working_set,
                            se::DeviceMemoryAllocator* allocator) {
  for (int i = 0; i < op.inputs_size(); ++i) {
    auto& input = op.inputs(i);
    if (input.op_index() >= current_index) {
      return errors::InvalidArgument(
          "Input index ", input.op_index(),
          " is above the current position: ", current_index);
    }
    TF_RETURN_IF_ERROR(working_set->LookupAndPin(
        backend, outputs[input.op_index()], allocator));
  }
  return OkStatus();
}

}  // namespace

void SetNcclUniqueIdFactory(std::shared_ptr<NcclUniqueIdFactory> factory) {
  mutex_lock lock(nccl_factory_mutex);
  if (nccl_factory == nullptr) {
    nccl_factory = new std::shared_ptr<NcclUniqueIdFactory>();
  }
  *nccl_factory = std::move(factory);
}

std::shared_ptr<NcclUniqueIdFactory> GetNcclUniqueIdFactory() {
  mutex_lock lock(nccl_factory_mutex);
  return nccl_factory != nullptr ? *nccl_factory : nullptr;
}
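
// Example (illustrative sketch, not part of the original source): installing
// and querying a process-wide NCCL unique-id factory. `MyNcclIdFactory` is a
// hypothetical subclass implementing the NcclUniqueIdFactory interface
// declared in xrt_util.h.
//
//   tensorflow::SetNcclUniqueIdFactory(std::make_shared<MyNcclIdFactory>());
//   auto factory = tensorflow::GetNcclUniqueIdFactory();
//   if (factory != nullptr) {
//     // Use the factory to produce NCCL unique IDs for a replica group.
//   }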

xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) {
  static const bool options_passthrough = DebugOptionsPassThroughEnabled();
  if (options_passthrough) {
    return ref_options;
  }
  xla::DebugOptions options = xla::GetDebugOptionsFromFlags();
  options.set_xla_dump_to(SafeDebugPath(ref_options.xla_dump_to()));
  options.set_xla_dump_hlo_as_proto(ref_options.xla_dump_hlo_as_proto());
  options.set_xla_dump_hlo_as_text(ref_options.xla_dump_hlo_as_text());
  options.set_xla_dump_hlo_snapshots(ref_options.xla_dump_hlo_snapshots());
  options.set_xla_dump_hlo_pass_re(ref_options.xla_dump_hlo_pass_re());
  options.set_xla_dump_include_timestamp(
      ref_options.xla_dump_include_timestamp());
  options.set_xla_dump_max_hlo_modules(ref_options.xla_dump_max_hlo_modules());
  for (auto& pass : ref_options.xla_disable_hlo_passes()) {
    options.add_xla_disable_hlo_passes(pass);
  }
  return options;
}
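
// Behavior sketch (illustrative, not part of the original source): with
// TF_XLA_DEBUG_OPTIONS_PASSTHROUGH unset, only the dump-related fields above
// are copied from the client-provided options, and the dump path is kept only
// if it is empty or points at gs:// or bigstore://.
//
//   xla::DebugOptions ref;
//   ref.set_xla_dump_to("/tmp/dump");    // Dropped by SafeDebugPath().
//   ref.set_xla_dump_hlo_as_text(true);  // Retained.
//   xla::DebugOptions opts = BuildXlaDebugOptions(ref);
//   // opts.xla_dump_to() == "" and opts.xla_dump_hlo_as_text() == true.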

xla::StatusOr<std::vector<InputCoords>> GetComputationInputs(
    OpKernelContext* context, const char* input_name) {
  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list(input_name, &arg_list));
  // Concatenate all input uids from list of scalars-or-vectors carrying them.
  std::vector<InputCoords> input_coords;
  for (int i = 0; i < arg_list.size(); ++i) {
    const Tensor& arg = arg_list[i];
    if (TensorShapeUtils::IsScalar(arg.shape())) {
      input_coords.emplace_back(arg.scalar<int64_t>()());
    } else {
      TF_RET_CHECK(TensorShapeUtils::IsVector(arg.shape()));
      auto arg_vec = arg.vec<int64_t>();
      const int64_t num_elts = arg.shape().dim_size(0);
      for (int64_t j = 0; j < num_elts; ++j) {
        input_coords.emplace_back(arg_vec(j));
      }
    }
  }
  return std::move(input_coords);
}
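
// Example (illustrative sketch, not part of the original source): if the op
// input list carries a scalar tensor {1001} and a vector tensor {1002, 1003},
// the returned InputCoords are the flattened handles 1001, 1002, 1003, in
// that order. The input name and handle values below are hypothetical.
//
//   TF_ASSIGN_OR_RETURN(auto coords,
//                       GetComputationInputs(context, "input_handles"));
//   for (const InputCoords& coord : coords) {
//     VLOG(1) << "Input allocation handle: " << coord.handle;
//   }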

bool InputShapeMatches(const xla::Shape& parameter_shape,
                       const xla::Shape& input_shape) {
  auto shape_checker = [&](const xla::Shape& pshape,
                           const xla::ShapeIndex& index) {
    if (pshape.IsArray()) {
      TF_ASSIGN_OR_RETURN(const xla::Shape* ishape,
                          xla::ShapeUtil::TryGetSubshape(input_shape, index));
      if (pshape.rank() != ishape->rank() ||
          pshape.element_type() != ishape->element_type()) {
        return errors::InvalidArgument("Mismatching shapes");
      }
      if (pshape.is_static() && !xla::Layout::Equal().IgnoreTiles()(
                                    pshape.layout(), ishape->layout())) {
        return errors::InvalidArgument("Mismatching layouts");
      }
      for (int64_t dim = 0; dim < pshape.rank(); ++dim) {
        if (pshape.is_dynamic_dimension(dim)) {
          if (pshape.dimensions(dim) < ishape->dimensions(dim)) {
            return errors::InvalidArgument("Mismatching shapes");
          }
        } else if (pshape.dimensions(dim) != ishape->dimensions(dim)) {
          return errors::InvalidArgument("Mismatching shapes");
        }
      }
    }
    return OkStatus();
  };
  return xla::ShapeUtil::ForEachSubshapeWithStatus(parameter_shape,
                                                   shape_checker)
      .ok();
}
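
// Matching rules sketch (illustrative, not part of the original source):
// element types and ranks must agree at every array subshape; layouts are
// compared (ignoring tiles) only when the parameter shape is static; and a
// dynamic parameter dimension accepts any input extent up to its bound.
//
//   // f32[<=4,8] (dim 0 dynamic) vs. f32[2,8]  -> matches.
//   // f32[4,8]   (static)        vs. f32[2,8]  -> does not match.
//   // f32[4,8]                   vs. s32[4,8]  -> does not match.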

xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetInputTupleAllocations(
    const std::vector<InputCoords>& input_coords,
    XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend,
    int64_t num_input_shapes,
    const std::function<xla::Shape(int64_t)>& shape_getter, bool release_inputs,
    se::DeviceMemoryAllocator* allocator) {
  if (input_coords.size() != num_input_shapes) {
    return errors::InvalidArgument(
        "Number of inputs does not match executable proto input shapes: ",
        input_coords.size(), " vs. ", num_input_shapes);
  }
  std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
  input_tuples.reserve(input_coords.size());
  for (size_t i = 0; i < input_coords.size(); ++i) {
    TF_RETURN_IF_ERROR(
        working_set->LookupAndPin(backend, input_coords[i].handle, allocator));
    auto tuple = working_set->PinnedTuples().back();
    if (release_inputs) {
      // We are holding a reference to the tuple, so we can safely delete it
      // from the resource manager here.
      TF_RETURN_IF_ERROR(
          working_set->MemoryManager()->Release(input_coords[i].handle));
      VLOG(2) << "Released allocation handle " << input_coords[i].handle;
    }
    xla::Shape input_shape = shape_getter(i);
    if (!InputShapeMatches(input_shape, tuple->on_host_shape())) {
      return errors::InvalidArgument(
          "Run-time shape mismatch for XRTExecute argument[", i, "] (",
          input_coords[i].handle, "). Expected ", input_shape.DebugString(),
          "; got ", tuple->on_host_shape().DebugString());
    }
    if (input_coords[i].index.empty()) {
      input_tuples.emplace_back(std::move(tuple));
    } else {
      XRTTupleAllocation* sub_tuple;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          tuple.get(), input_coords[i].index, &sub_tuple,
          /*alias_parent_allocation=*/true));
      input_tuples.emplace_back(sub_tuple);
    }
  }
  return std::move(input_tuples);
}
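
// Usage sketch (illustrative, not part of the original source): each
// InputCoords names an allocation handle and, optionally, a tuple index; when
// an index is present, a sub-buffer aliasing the parent allocation is
// returned in its place. `computation_arg_shape` below is a hypothetical
// shape lookup standing in for the real shape_getter.
//
//   TF_ASSIGN_OR_RETURN(
//       auto input_tuples,
//       GetInputTupleAllocations(
//           coords, &working_set, backend, /*num_input_shapes=*/coords.size(),
//           [&](int64_t i) { return computation_arg_shape(i); },
//           /*release_inputs=*/false, allocator));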

Status RebuildOutputAliases(
    const RefPtr<XRTTupleAllocation>& output_tuple,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const xla::HloInputOutputAliasConfig& input_output_alias) {
  auto alias_function =
      [&](const xla::ShapeIndex& output_index,
          const xla::HloInputOutputAliasConfig::Alias& alias) -> Status {
    TF_RET_CHECK(alias.parameter_number < input_tuples.size());
    return output_tuple->AliasBufferFrom(*input_tuples[alias.parameter_number],
                                         alias.parameter_index, output_index);
  };
  return input_output_alias.ForEachAliasWithStatus(alias_function);
}

xla::StatusOr<std::vector<xla::ExecutionInput>> GetArgumentsBuffers(
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const std::vector<bool>& input_is_dynamic, bool release_inputs) {
  auto is_dynamic = [&](size_t arg) {
    return arg < input_is_dynamic.size() && input_is_dynamic[arg];
  };
  std::vector<xla::ExecutionInput> arguments;
  // Don't alias dynamic inputs. Due to the underlying implementation, an
  // aliased input has two owners: the XRTAllocation and the return value of
  // this function. If an argument is dynamic and its ownership is released to
  // the output of this function, TPUExecute will free it and reallocate a new
  // one, which creates a double-free issue where the XRTAllocation also
  // attempts to release the buffer.
  bool alias_outputs = release_inputs && input_tuples.size() == 1 &&
                       input_tuples[0]->IsExclusiveOwner() && !is_dynamic(0);
  arguments.reserve(input_tuples.size());
  for (int64_t i = 0; i < input_tuples.size(); ++i) {
    auto alias_checker =
        [&](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
      if (input_output_alias.ParameterHasAlias(i, index)) {
        TF_RET_CHECK(!is_dynamic(i));
        return true;
      }
      return alias_outputs;
    };
    TF_ASSIGN_OR_RETURN(xla::ExecutionInput exec_input,
                        input_tuples[i]->ToExecutionInput(alias_checker));
    arguments.emplace_back(std::move(exec_input));
  }
  return std::move(arguments);
}
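
// Aliasing decision sketch (illustrative, not part of the original source):
// a buffer is donated to the executable either because the HLO input/output
// alias config explicitly aliases that parameter index, or because the caller
// released a single, exclusively-owned, non-dynamic input tuple (the
// alias_outputs fast path above). In every other case the ExecutionInput only
// borrows the buffer.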

Status CreateExecuteOutput(OpKernelContext* context,
                           XRTMemoryManager* memory_manager,
                           RefPtr<XRTTupleAllocation> output_tuple,
                           bool return_exploded_tuple) {
  if (return_exploded_tuple && output_tuple->on_host_shape().IsTuple()) {
    int64_t tuple_element_count =
        xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape());
    Tensor* output_tensor;
    TF_RETURN_IF_ERROR(context->allocate_output(
        0, TensorShape({tuple_element_count}), &output_tensor));

    for (int64_t i = 0; i < tuple_element_count; ++i) {
      XRTTupleAllocation* suballocation;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          output_tuple.get(), {i}, &suballocation,
          /*alias_parent_allocation=*/false));
      output_tensor->vec<int64_t>()(i) =
          memory_manager->Register(suballocation);
    }
  } else {
    Tensor* output_tensor;
    TF_RETURN_IF_ERROR(
        context->allocate_output(0, TensorShape({}), &output_tensor));
    output_tensor->scalar<int64_t>()() =
        memory_manager->Register(std::move(output_tuple));
  }
  return OkStatus();
}
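
// Example (illustrative sketch, not part of the original source): for an
// output tuple shaped (f32[2], s32[]) with return_exploded_tuple=true, the
// kernel emits a shape-{2} int64 tensor holding one newly registered handle
// per tuple element; otherwise it emits a scalar int64 tensor holding a
// single handle for the whole tuple.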

Status ExecuteChained(OpKernelContext* context,
                      const RefPtr<XRTMemoryManager>& memory_manager,
                      xla::Backend* backend, int device_ordinal,
                      const xrt::XRTChainedExecutePlan& plan,
                      const xrt::XRTChainedExecuteConfig& config,
                      const ChainedExecuteFn& execute_op,
                      se::DeviceMemoryAllocator* allocator) {
  // Create the vector which tracks the uses of the intermediate chained
  // operations' outputs.
  std::vector<int64_t> uses(plan.ops_size(), 0);
  for (auto& op : plan.ops()) {
    for (auto& input : op.inputs()) {
      uses[input.op_index()] += 1;
    }
  }

  ScopedHandles outputs(memory_manager);
  ScopedHandles results(memory_manager);
  for (int i = 0; i < plan.ops_size(); ++i) {
    auto& op = plan.ops(i);
    if (op.op_oneof_case() == xrt::XRTChainedExecuteOp::kDataHandle) {
      // This operation is a device data load. Set the handle as output and
      // leave the release flag off, since this is not an intermediate output.
      TF_RETURN_IF_ERROR(outputs.Add(i, op.data_handle(), /*release=*/false));
    } else if (op.op_oneof_case() ==
               xrt::XRTChainedExecuteOp::kComputationHandle) {
      // This is an XRT execute operation; forward it to the device-specific
      // handler. Populating the working set makes sure the input allocations
      // for this execute operation are pinned to device memory.
      XRTMemoryManager::WorkingSet working_set(memory_manager);
      TF_RETURN_IF_ERROR(PopulateOpWorkingSet(backend, op, i, outputs,
                                              &working_set, allocator));
      TF_ASSIGN_OR_RETURN(auto tuple,
                          execute_op(op, working_set.PinnedTuples()));
      TF_RETURN_IF_ERROR(outputs.Add(i, std::move(tuple)));
    } else {
      return errors::InvalidArgument(
          "Undefined operation kind at post-order position ", i);
    }
    // If the result of this chained operation is an output result, feed the
    // results at the desired position.
    for (auto& output : op.outputs()) {
      TF_ASSIGN_OR_RETURN(auto tuple, outputs.Lookup(i));
      RefPtr<XRTTupleAllocation> result;
      TF_RETURN_IF_ERROR(MakeOutput(tuple, output.output_index(), &result));
      TF_RETURN_IF_ERROR(results.Add(output.result_index(), std::move(result)));
    }
    // Drop intermediate results which have no more users.
    for (auto& input : op.inputs()) {
      uses[input.op_index()] -= 1;
      if (uses[input.op_index()] == 0) {
        TF_RETURN_IF_ERROR(outputs.Drop(input.op_index()));
      }
    }
  }

  Tensor* output_tensor;
  TF_RETURN_IF_ERROR(context->allocate_output(
      0, TensorShape({static_cast<int64_t>(results.size())}), &output_tensor));
  for (size_t i = 0; i < results.size(); ++i) {
    output_tensor->vec<int64_t>()(i) = results.Release(i);
  }
  return OkStatus();
}
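
// Plan walk-through (illustrative sketch, not part of the original source):
// for a two-op plan where op 0 loads an existing data handle and op 1 runs a
// computation that consumes op 0's output and exposes its result at result
// index 0, the loop above (1) records op 0's handle without taking ownership,
// (2) pins op 0's allocation, runs execute_op for op 1, and registers its
// output tuple, (3) copies that output into the `results` set via
// MakeOutput(), and (4) drops op 0's entry once its use count reaches zero.
// The op kernel then returns one int64 handle per populated result index.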

}  // namespace tensorflow