/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xrt/xrt_util.h"

#include <stdlib.h>
#include <string.h>

#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace {

mutex nccl_factory_mutex(LINKER_INITIALIZED);
std::shared_ptr<NcclUniqueIdFactory>* nccl_factory;

// The ScopedHandles data structure is used in the ExecuteChained() API, and
// its task is to track tuple allocation registrations. It tracks both the
// intermediate results of a chained computation and its final results.
// Anything which is marked to be released will be released through the
// XRTMemoryManager once the object is destroyed (unless an explicit call to
// Drop() or Release() is made). See the usage sketch after the class
// definition.
class ScopedHandles {
 public:
  explicit ScopedHandles(RefPtr<XRTMemoryManager> memory_manager)
      : memory_manager_(std::move(memory_manager)) {}

  ~ScopedHandles() {
    for (size_t i = 0; i < handles_.size(); ++i) {
      if (handles_release_[i]) {
        memory_manager_->Release(handles_[i]).IgnoreError();
      }
    }
  }

  int64 operator[](size_t index) const { return handles_.at(index); }

  size_t size() const { return handles_.size(); }

  // Adds the given handle at the index position, marking it releasable
  // according to the release argument. If a to-be-released handle already
  // exists at the same index, it is released first.
  Status Add(size_t index, int64_t handle, bool release) {
    if (index >= handles_.size()) {
      handles_.resize(index + 1, XRTMemoryManager::InvalidKey());
      handles_release_.resize(index + 1, false);
    }
    if (handles_release_[index]) {
      Status status = memory_manager_->Release(handles_[index]);
      if (!status.ok()) {
        if (release) {
          memory_manager_->Release(handle).IgnoreError();
        }
        return status;
      }
    }
    handles_[index] = handle;
    handles_release_[index] = release;
    return Status::OK();
  }

  // Adds a to-be-released tuple allocation at the given index.
  Status Add(size_t index, RefPtr<XRTTupleAllocation> tuple) {
    return Add(index, memory_manager_->Register(std::move(tuple)),
               /*release=*/true);
  }

  // Drops the handle at the given index, releasing it via
  // XRTMemoryManager::Release() if it is marked as to-be-released.
  Status Drop(size_t index) {
    if (handles_release_.at(index)) {
      TF_RETURN_IF_ERROR(memory_manager_->Release(handles_[index]));
    }
    Release(index);
    return Status::OK();
  }

  // Releases the handle at the given index and returns it. The destructor
  // will no longer call the XRTMemoryManager::Release() API on that handle.
  int64 Release(size_t index) {
    int64_t handle = handles_.at(index);
    handles_[index] = XRTMemoryManager::InvalidKey();
    handles_release_[index] = false;
    return handle;
  }

  // Looks up the handle stored at the given index, and returns the matching
  // tuple allocation.
  xla::StatusOr<RefPtr<XRTTupleAllocation>> Lookup(size_t index) const {
    return memory_manager_->Lookup(handles_.at(index));
  }

 private:
  RefPtr<XRTMemoryManager> memory_manager_;
  std::vector<int64> handles_;
  std::vector<bool> handles_release_;
};
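
// Illustrative usage sketch (not part of the runtime, and only a hedged
// approximation of how ExecuteChained() below drives this class). The
// memory_manager, tuple and data_handle variables are assumed to exist in the
// caller.
//
//   ScopedHandles outputs(memory_manager);
//   // Intermediate result: registered, and released on destruction unless
//   // dropped or released explicitly.
//   TF_RETURN_IF_ERROR(outputs.Add(/*index=*/0, std::move(tuple)));
//   // Device data handle owned by the client: tracked but never released.
//   TF_RETURN_IF_ERROR(outputs.Add(/*index=*/1, data_handle,
//                                  /*release=*/false));
//   // The caller takes ownership of index 0; the destructor will skip it.
//   int64 final_handle = outputs.Release(0);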

bool DebugOptionsPassThroughEnabled() {
  const char* env = getenv("TF_XLA_DEBUG_OPTIONS_PASSTHROUGH");
  bool enabled =
      env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0);
  if (enabled) {
    LOG(WARNING) << "Passing through XLA debug options!";
  } else {
    LOG(WARNING) << "TF_XLA_DEBUG_OPTIONS_PASSTHROUGH not set, not all options "
                    "will be retained";
  }
  return enabled;
}
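
// For reference (a sketch derived from the checks above): only the exact
// values "1" and "true" enable passthrough, e.g.
//
//   export TF_XLA_DEBUG_OPTIONS_PASSTHROUGH=1    # client DebugOptions kept
//   export TF_XLA_DEBUG_OPTIONS_PASSTHROUGH=yes  # treated as disabled
//
// When disabled, BuildXlaDebugOptions() below copies only the dump-related
// fields from the client-provided options.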

string SafeDebugPath(const string& path) {
  if (path.empty() || path.compare(0, 5, "gs://") == 0 ||
      path.compare(0, 11, "bigstore://") == 0) {
    return path;
  }
  LOG(WARNING) << "Invalid config path (will be dropped): " << path;
  return string();
}
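
// Examples (derived from the checks above):
//
//   SafeDebugPath("")                      -> ""         (kept)
//   SafeDebugPath("gs://bucket/xla_dump")  -> unchanged  (kept)
//   SafeDebugPath("/tmp/xla_dump")         -> ""         (dropped, warned)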

Status MakeOutput(const RefPtr<XRTTupleAllocation>& output, int64_t index,
                  RefPtr<XRTTupleAllocation>* result) {
  if (index == 0) {
    *result = output;
  } else {
    XRTTupleAllocation* tuple;
    TF_RETURN_IF_ERROR(
        XRTTupleAllocation::MakeSubBuffer(output.get(), {index - 1}, &tuple,
                                          /*alias_parent_allocation=*/true));
    result->reset(tuple);
  }
  return Status::OK();
}
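
// Sketch of the index convention used by ExecuteChained(): output_index 0
// selects the whole output tuple, while output_index k > 0 selects tuple
// element k - 1 as a sub-buffer aliasing the parent allocation.
//
//   RefPtr<XRTTupleAllocation> whole, second_element;
//   TF_RETURN_IF_ERROR(MakeOutput(output, /*index=*/0, &whole));
//   TF_RETURN_IF_ERROR(MakeOutput(output, /*index=*/2, &second_element));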

Status PopulateOpWorkingSet(xla::Backend* backend,
                            const xrt::XRTChainedExecuteOp& op,
                            int current_index, const ScopedHandles& outputs,
                            XRTMemoryManager::WorkingSet* working_set,
                            se::DeviceMemoryAllocator* allocator) {
  for (int i = 0; i < op.inputs_size(); ++i) {
    auto& input = op.inputs(i);
    if (input.op_index() >= current_index) {
      return errors::InvalidArgument(
          "Input index ", input.op_index(),
          " is above the current position: ", current_index);
    }
    TF_RETURN_IF_ERROR(working_set->LookupAndPin(
        backend, outputs[input.op_index()], allocator));
  }
  return Status::OK();
}

}  // namespace

void SetNcclUniqueIdFactory(std::shared_ptr<NcclUniqueIdFactory> factory) {
  mutex_lock lock(nccl_factory_mutex);
  if (nccl_factory == nullptr) {
    nccl_factory = new std::shared_ptr<NcclUniqueIdFactory>();
  }
  *nccl_factory = std::move(factory);
}

std::shared_ptr<NcclUniqueIdFactory> GetNcclUniqueIdFactory() {
  mutex_lock lock(nccl_factory_mutex);
  return nccl_factory != nullptr ? *nccl_factory : nullptr;
}

xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) {
  static const bool options_passthrough = DebugOptionsPassThroughEnabled();
  if (options_passthrough) {
    return ref_options;
  }
  xla::DebugOptions options = xla::GetDebugOptionsFromFlags();
  options.set_xla_dump_to(SafeDebugPath(ref_options.xla_dump_to()));
  options.set_xla_dump_hlo_as_proto(ref_options.xla_dump_hlo_as_proto());
  options.set_xla_dump_hlo_as_text(ref_options.xla_dump_hlo_as_text());
  options.set_xla_dump_hlo_snapshots(ref_options.xla_dump_hlo_snapshots());
  options.set_xla_dump_hlo_pass_re(ref_options.xla_dump_hlo_pass_re());
  options.set_xla_dump_include_timestamp(
      ref_options.xla_dump_include_timestamp());
  options.set_xla_dump_max_hlo_modules(ref_options.xla_dump_max_hlo_modules());
  for (auto& pass : ref_options.xla_disable_hlo_passes()) {
    options.add_xla_disable_hlo_passes(pass);
  }
  return options;
}
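
// Sketch of the intended call pattern (the client_options variable is an
// assumption for illustration): without passthrough, only the dump-related
// fields above survive, and the dump path is sanitized by SafeDebugPath().
//
//   xla::DebugOptions client_options = ...;  // from the client request
//   xla::DebugOptions debug_options = BuildXlaDebugOptions(client_options);
//   // debug_options.xla_dump_to() is now either empty or a gs://-like path.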

xla::StatusOr<std::vector<InputCoords>> GetComputationInputs(
    OpKernelContext* context, const char* input_name) {
  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list(input_name, &arg_list));
  // Concatenate all input uids from list of scalars-or-vectors carrying them.
  std::vector<InputCoords> input_coords;
  for (int i = 0; i < arg_list.size(); ++i) {
    const Tensor& arg = arg_list[i];
    if (TensorShapeUtils::IsScalar(arg.shape())) {
      input_coords.emplace_back(arg.scalar<int64>()());
    } else {
      TF_RET_CHECK(TensorShapeUtils::IsVector(arg.shape()));
      auto arg_vec = arg.vec<int64>();
      const int64_t num_elts = arg.shape().dim_size(0);
      for (int j = 0; j < num_elts; ++j) {
        input_coords.emplace_back(arg_vec(j));
      }
    }
  }
  return std::move(input_coords);
}
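
// Example of the flattening performed above (handle values are made up):
//
//   input list = [ scalar(10), vector([20, 21]), scalar(30) ]
//   GetComputationInputs(...) -> InputCoords for handles {10, 20, 21, 30},
//   in that order.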

bool InputShapeMatches(const xla::Shape& parameter_shape,
                       const xla::Shape& input_shape) {
  auto shape_checker = [&](const xla::Shape& pshape,
                           const xla::ShapeIndex& index) {
    if (pshape.IsArray()) {
      TF_ASSIGN_OR_RETURN(const xla::Shape* ishape,
                          xla::ShapeUtil::TryGetSubshape(input_shape, index));
      if (pshape.rank() != ishape->rank() ||
          pshape.element_type() != ishape->element_type()) {
        return errors::InvalidArgument("Mismatching shapes");
      }
      if (pshape.is_static() && pshape.layout() != ishape->layout()) {
        return errors::InvalidArgument("Mismatching layouts");
      }
      for (int64_t dim = 0; dim < pshape.rank(); ++dim) {
        if (pshape.is_dynamic_dimension(dim)) {
          if (pshape.dimensions(dim) < ishape->dimensions(dim)) {
            return errors::InvalidArgument("Mismatching shapes");
          }
        } else if (pshape.dimensions(dim) != ishape->dimensions(dim)) {
          return errors::InvalidArgument("Mismatching shapes");
        }
      }
    }
    return Status::OK();
  };
  return xla::ShapeUtil::ForEachSubshapeWithStatus(parameter_shape,
                                                   shape_checker)
      .ok();
}
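
// A few examples of the rules enforced above (shapes in XLA notation):
//
//   parameter f32[4,8]   vs input f32[4,8]  -> match
//   parameter f32[<=8,4] vs input f32[6,4]  -> match (within dynamic bound)
//   parameter f32[<=8,4] vs input f32[9,4]  -> mismatch (exceeds bound)
//   parameter f32[4,8]   vs input s32[4,8]  -> mismatch (element type)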

xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetInputTupleAllocations(
    const std::vector<InputCoords>& input_coords,
    XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend,
    int64_t num_input_shapes,
    const std::function<xla::Shape(int64_t)>& shape_getter, bool release_inputs,
    se::DeviceMemoryAllocator* allocator) {
  if (input_coords.size() != num_input_shapes) {
    return errors::InvalidArgument(
        "Number of inputs does not match executable proto input shapes: ",
        input_coords.size(), " vs. ", num_input_shapes);
  }
  std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
  input_tuples.reserve(input_coords.size());
  for (size_t i = 0; i < input_coords.size(); ++i) {
    TF_RETURN_IF_ERROR(
        working_set->LookupAndPin(backend, input_coords[i].handle, allocator));
    auto tuple = working_set->PinnedTuples().back();
    if (release_inputs) {
      // We are holding a reference to the tuple, so we can safely delete it
      // from the resource manager here.
      TF_RETURN_IF_ERROR(
          working_set->MemoryManager()->Release(input_coords[i].handle));
      VLOG(2) << "Released allocation handle " << input_coords[i].handle;
    }
    xla::Shape input_shape = shape_getter(i);
    if (!InputShapeMatches(input_shape, tuple->on_host_shape())) {
      return errors::InvalidArgument(
          "Run-time shape mismatch for XRTExecute argument[", i, "] (",
          input_coords[i].handle, "). Expected ", input_shape.DebugString(),
          "; got ", tuple->on_host_shape().DebugString());
    }
    if (input_coords[i].index.empty()) {
      input_tuples.emplace_back(std::move(tuple));
    } else {
      XRTTupleAllocation* sub_tuple;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          tuple.get(), input_coords[i].index, &sub_tuple,
          /*alias_parent_allocation=*/true));
      input_tuples.emplace_back(sub_tuple);
    }
  }
  return std::move(input_tuples);
}
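
// Rough shape of a call site, as a hedged sketch only: the executable object
// and its num_parameters()/parameter_shape() accessors are hypothetical names
// used for illustration, not an actual XRT API.
//
//   XRTMemoryManager::WorkingSet working_set(memory_manager);
//   TF_ASSIGN_OR_RETURN(
//       auto input_tuples,
//       GetInputTupleAllocations(
//           input_coords, &working_set, backend,
//           executable->num_parameters(),
//           [&](int64_t i) { return executable->parameter_shape(i); },
//           /*release_inputs=*/false, allocator));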

Status RebuildOutputAliases(
    const RefPtr<XRTTupleAllocation>& output_tuple,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const xla::HloInputOutputAliasConfig& input_output_alias) {
  auto alias_function =
      [&](const xla::ShapeIndex& output_index,
          const xla::HloInputOutputAliasConfig::Alias& alias) -> Status {
    TF_RET_CHECK(alias.parameter_number < input_tuples.size());
    return output_tuple->AliasBufferFrom(*input_tuples[alias.parameter_number],
                                         alias.parameter_index, output_index);
  };
  return input_output_alias.ForEachAliasWithStatus(alias_function);
}

xla::StatusOr<std::vector<xla::ExecutionInput>> GetArgumentsBuffers(
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const std::vector<bool>& input_is_dynamic, bool release_inputs) {
  auto is_dynamic = [&](size_t arg) {
    return arg < input_is_dynamic.size() && input_is_dynamic[arg];
  };
  std::vector<xla::ExecutionInput> arguments;
  // Don't alias dynamic inputs: due to the underlying implementation, aliased
  // inputs have two owners, the XRTAllocation and the return value of this
  // function. If an argument is dynamic and its ownership is released to the
  // output of this function, TPUExecute will free it and reallocate a new one,
  // which creates a double-free issue where the XRTAllocation also attempts
  // to release the buffer.
  bool alias_outputs = release_inputs && input_tuples.size() == 1 &&
                       input_tuples[0]->IsExclusiveOwner() && !is_dynamic(0);
  arguments.reserve(input_tuples.size());
  for (int64_t i = 0; i < input_tuples.size(); ++i) {
    auto alias_checker =
        [&](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
      if (input_output_alias.ParameterHasAlias(i, index)) {
        TF_RET_CHECK(!is_dynamic(i));
        return true;
      }
      return alias_outputs;
    };
    TF_ASSIGN_OR_RETURN(xla::ExecutionInput exec_input,
                        input_tuples[i]->ToExecutionInput(alias_checker));
    arguments.emplace_back(std::move(exec_input));
  }
  return std::move(arguments);
}
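
// Decision sketch for a given (parameter i, buffer index) pair, summarizing
// the logic above:
//
//   ParameterHasAlias(i, index) == true              -> donate (must be static)
//   alias_outputs == true (single exclusively-owned,
//       non-dynamic input tuple and release_inputs)  -> donate
//   otherwise                                        -> keep ownership in XRT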

Status CreateExecuteOutput(OpKernelContext* context,
                           XRTMemoryManager* memory_manager,
                           RefPtr<XRTTupleAllocation> output_tuple,
                           bool return_exploded_tuple) {
  if (return_exploded_tuple && output_tuple->on_host_shape().IsTuple()) {
    int64_t tuple_element_count =
        xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape());
    Tensor* output_tensor;
    TF_RETURN_IF_ERROR(context->allocate_output(
        0, TensorShape({tuple_element_count}), &output_tensor));

    for (int64_t i = 0; i < tuple_element_count; ++i) {
      XRTTupleAllocation* suballocation;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          output_tuple.get(), {i}, &suballocation,
          /*alias_parent_allocation=*/false));
      output_tensor->vec<int64>()(i) = memory_manager->Register(suballocation);
    }
  } else {
    Tensor* output_tensor;
    TF_RETURN_IF_ERROR(
        context->allocate_output(0, TensorShape({}), &output_tensor));
    output_tensor->scalar<int64>()() =
        memory_manager->Register(std::move(output_tuple));
  }
  return Status::OK();
}
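
// Output convention illustrated (handle values are made up): for a result
// tuple (f32[2], s32[]) with return_exploded_tuple=true, the op emits a
// vector<int64> of two fresh handles, e.g. [101, 102], one per tuple element;
// otherwise it emits a single scalar handle, e.g. 100, for the whole tuple.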

Status ExecuteChained(OpKernelContext* context,
                      const RefPtr<XRTMemoryManager>& memory_manager,
                      xla::Backend* backend, int device_ordinal,
                      const xrt::XRTChainedExecutePlan& plan,
                      const xrt::XRTChainedExecuteConfig& config,
                      const ChainedExecuteFn& execute_op,
                      se::DeviceMemoryAllocator* allocator) {
  // Create the vector which tracks the use counts of the intermediate chained
  // operations' outputs.
  std::vector<int64> uses(plan.ops_size(), 0);
  for (auto& op : plan.ops()) {
    for (auto& input : op.inputs()) {
      uses[input.op_index()] += 1;
    }
  }

  ScopedHandles outputs(memory_manager);
  ScopedHandles results(memory_manager);
  for (int i = 0; i < plan.ops_size(); ++i) {
    auto& op = plan.ops(i);
    if (op.op_oneof_case() == xrt::XRTChainedExecuteOp::kDataHandle) {
      // This operation is a device data load. Set the handle as output and
      // leave the release flag off, since this is not an intermediate output.
      TF_RETURN_IF_ERROR(outputs.Add(i, op.data_handle(), /*release=*/false));
    } else if (op.op_oneof_case() ==
               xrt::XRTChainedExecuteOp::kComputationHandle) {
      // This is an XRT execute operation; forward it to the device specific
      // handler. Populating the working set makes sure the input allocations
      // for this execute operation are pinned to device memory.
      XRTMemoryManager::WorkingSet working_set(memory_manager);
      TF_RETURN_IF_ERROR(PopulateOpWorkingSet(backend, op, i, outputs,
                                              &working_set, allocator));
      TF_ASSIGN_OR_RETURN(auto tuple,
                          execute_op(op, working_set.PinnedTuples()));
      TF_RETURN_IF_ERROR(outputs.Add(i, std::move(tuple)));
    } else {
      return errors::InvalidArgument(
          "Undefined operation kind at post-order position ", i);
    }
    // If the result of this chained operation is also a plan output, feed the
    // results list at the requested position.
    for (auto& output : op.outputs()) {
      TF_ASSIGN_OR_RETURN(auto tuple, outputs.Lookup(i));
      RefPtr<XRTTupleAllocation> result;
      TF_RETURN_IF_ERROR(MakeOutput(tuple, output.output_index(), &result));
      TF_RETURN_IF_ERROR(results.Add(output.result_index(), std::move(result)));
    }
    // Drop intermediate results which have no more users.
    for (auto& input : op.inputs()) {
      uses[input.op_index()] -= 1;
      if (uses[input.op_index()] == 0) {
        TF_RETURN_IF_ERROR(outputs.Drop(input.op_index()));
      }
    }
  }

  Tensor* output_tensor;
  TF_RETURN_IF_ERROR(context->allocate_output(
      0, TensorShape({static_cast<int64>(results.size())}), &output_tensor));
  for (size_t i = 0; i < results.size(); ++i) {
    output_tensor->vec<int64>()(i) = results.Release(i);
  }
  return Status::OK();
}
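
// Shape of a chained plan, as a hedged sketch (field names follow the
// xrt::XRTChainedExecutePlan / xrt::XRTChainedExecuteOp usage above; handle
// values are made up):
//
//   op[0]: data_handle = 100                     // device data load
//   op[1]: computation_handle = 7,
//          inputs  = [{op_index: 0}],            // consumes op[0]'s output
//          outputs = [{output_index: 0,          // whole result tuple...
//                      result_index: 0}]         // ...returned as result 0
//
// With this plan, op[0]'s client-owned handle is tracked but never released
// (release=false), op[1]'s result is registered and returned to the caller as
// result index 0, and anything left marked releasable in `outputs` is freed
// when the ScopedHandles instances go out of scope.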

}  // namespace tensorflow