/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xrt/xrt_util.h"

#include <stdlib.h>
#include <string.h>

#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace {

mutex nccl_factory_mutex(LINKER_INITIALIZED);
std::shared_ptr<NcclUniqueIdFactory>* nccl_factory;

// The ScopedHandles data structure is used in the ExecuteChained() API, and
// its task is to track tuple allocation registrations. It tracks both the
// intermediate results of a chained computation and its final results.
// Anything which is marked to be released will be released using the
// XRTMemoryManager once the object is destroyed (unless an explicit call to
// Drop() or Release() is made).
class ScopedHandles {
 public:
  explicit ScopedHandles(RefPtr<XRTMemoryManager> memory_manager)
      : memory_manager_(std::move(memory_manager)) {}

  ~ScopedHandles() {
    for (size_t i = 0; i < handles_.size(); ++i) {
      if (handles_release_[i]) {
        memory_manager_->Release(handles_[i]).IgnoreError();
      }
    }
  }

  int64 operator[](size_t index) const { return handles_.at(index); }

  size_t size() const { return handles_.size(); }

  // Adds the given handle at the index position, marking it releasable
  // according to the release argument. If a to-be-released handle is already
  // present at the same index, it will be released first.
  Status Add(size_t index, int64 handle, bool release) {
    if (index >= handles_.size()) {
      handles_.resize(index + 1, XRTMemoryManager::InvalidKey());
      handles_release_.resize(index + 1, false);
    }
    if (handles_release_[index]) {
      Status status = memory_manager_->Release(handles_[index]);
      if (!status.ok()) {
        if (release) {
          memory_manager_->Release(handle).IgnoreError();
        }
        return status;
      }
    }
    handles_[index] = handle;
    handles_release_[index] = release;
    return Status::OK();
  }

  // Adds a to-be-released tuple allocation at the given index.
  Status Add(size_t index, RefPtr<XRTTupleAllocation> tuple) {
    return Add(index, memory_manager_->Register(std::move(tuple)),
               /*release=*/true);
  }

  // Drops the handle at the given index, releasing it via
  // XRTMemoryManager::Release() if it is marked as to-be-released.
  Status Drop(size_t index) {
    if (handles_release_.at(index)) {
      TF_RETURN_IF_ERROR(memory_manager_->Release(handles_[index]));
    }
    Release(index);
    return Status::OK();
  }

  // Releases the handle at the given index. The destructor will not call
  // XRTMemoryManager::Release() on such a handle.
  int64 Release(size_t index) {
    int64 handle = handles_.at(index);
    handles_[index] = XRTMemoryManager::InvalidKey();
    handles_release_[index] = false;
    return handle;
  }

  // Looks up the handle stored at the given index, and returns the matching
  // tuple allocation.
  xla::StatusOr<RefPtr<XRTTupleAllocation>> Lookup(size_t index) const {
    return memory_manager_->Lookup(handles_.at(index));
  }

 private:
  RefPtr<XRTMemoryManager> memory_manager_;
  std::vector<int64> handles_;
  std::vector<bool> handles_release_;
};
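
// Illustrative usage sketch (not part of the original source); the memory
// manager, tuple, and index values below are hypothetical:
//
//   ScopedHandles outputs(memory_manager);
//   // Register an intermediate tuple; it will be released on destruction
//   // unless Drop()/Release() is called on its index first.
//   TF_RETURN_IF_ERROR(outputs.Add(/*index=*/0, std::move(tuple)));
//   // Hand ownership of the handle at index 0 back to the caller; the
//   // destructor will no longer touch it.
//   int64 handle = outputs.Release(0);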

bool DebugOptionsPassThroughEnabled() {
  const char* env = getenv("TF_XLA_DEBUG_OPTIONS_PASSTHROUGH");
  bool enabled =
      env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0);
  if (enabled) {
    LOG(WARNING) << "Passing through XLA debug options!";
  } else {
    LOG(WARNING) << "TF_XLA_DEBUG_OPTIONS_PASSTHROUGH not set, not all options "
                    "will be retained";
  }
  return enabled;
}

string SafeDebugPath(const string& path) {
  if (path.empty() || path.compare(0, 5, "gs://") == 0 ||
      path.compare(0, 11, "bigstore://") == 0) {
    return path;
  }
  LOG(WARNING) << "Invalid config path (will be dropped): " << path;
  return string();
}

Status MakeOutput(const RefPtr<XRTTupleAllocation>& output, int64 index,
                  RefPtr<XRTTupleAllocation>* result) {
  if (index == 0) {
    *result = output;
  } else {
    XRTTupleAllocation* tuple;
    TF_RETURN_IF_ERROR(
        XRTTupleAllocation::MakeSubBuffer(output.get(), {index - 1}, &tuple,
                                          /*alias_parent_allocation=*/true));
    result->reset(tuple);
  }
  return Status::OK();
}
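
// Index semantics sketch (illustrative, not part of the original source): an
// index of 0 refers to the whole output allocation, while index i > 0 aliases
// tuple element i - 1 of the output. The variable names are hypothetical:
//
//   RefPtr<XRTTupleAllocation> whole, second_element;
//   TF_RETURN_IF_ERROR(MakeOutput(output, /*index=*/0, &whole));
//   TF_RETURN_IF_ERROR(MakeOutput(output, /*index=*/2, &second_element));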

Status PopulateOpWorkingSet(xla::Backend* backend,
                            const xrt::XRTChainedExecuteOp& op,
                            int current_index, const ScopedHandles& outputs,
                            XRTMemoryManager::WorkingSet* working_set) {
  for (int i = 0; i < op.inputs_size(); ++i) {
    auto& input = op.inputs(i);
    if (input.op_index() >= current_index) {
      return errors::InvalidArgument(
          "Input index ", input.op_index(),
          " is above the current position: ", current_index);
    }
    TF_RETURN_IF_ERROR(
        working_set->LookupAndPin(backend, outputs[input.op_index()]));
  }
  return Status::OK();
}

}  // namespace

void SetNcclUniqueIdFactory(std::shared_ptr<NcclUniqueIdFactory> factory) {
  mutex_lock lock(nccl_factory_mutex);
  if (nccl_factory == nullptr) {
    nccl_factory = new std::shared_ptr<NcclUniqueIdFactory>();
  }
  *nccl_factory = std::move(factory);
}

std::shared_ptr<NcclUniqueIdFactory> GetNcclUniqueIdFactory() {
  mutex_lock lock(nccl_factory_mutex);
  return nccl_factory != nullptr ? *nccl_factory : nullptr;
}
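
// Illustrative usage sketch (not part of the original source); MyIdFactory is
// a hypothetical implementation of the NcclUniqueIdFactory interface declared
// in xrt_util.h:
//
//   class MyIdFactory : public NcclUniqueIdFactory { ... };
//   SetNcclUniqueIdFactory(std::make_shared<MyIdFactory>());
//   auto factory = GetNcclUniqueIdFactory();  // shared_ptr, may be nullptr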

xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) {
  static const bool options_passthrough = DebugOptionsPassThroughEnabled();
  if (options_passthrough) {
    return ref_options;
  }
  xla::DebugOptions options = xla::GetDebugOptionsFromFlags();
  options.set_xla_dump_to(SafeDebugPath(ref_options.xla_dump_to()));
  options.set_xla_dump_hlo_as_proto(ref_options.xla_dump_hlo_as_proto());
  options.set_xla_dump_hlo_as_text(ref_options.xla_dump_hlo_as_text());
  options.set_xla_dump_hlo_snapshots(ref_options.xla_dump_hlo_snapshots());
  options.set_xla_dump_hlo_pass_re(ref_options.xla_dump_hlo_pass_re());
  options.set_xla_dump_include_timestamp(
      ref_options.xla_dump_include_timestamp());
  options.set_xla_dump_max_hlo_modules(ref_options.xla_dump_max_hlo_modules());
  for (auto& pass : ref_options.xla_disable_hlo_passes()) {
    options.add_xla_disable_hlo_passes(pass);
  }
  return options;
}
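
// Behavior sketch (illustrative, not part of the original source): unless the
// TF_XLA_DEBUG_OPTIONS_PASSTHROUGH environment variable is set to "1" or
// "true", only the dump-related fields and the disabled-HLO-pass list of the
// incoming options are retained on top of the locally configured flags. The
// dump path below is hypothetical:
//
//   xla::DebugOptions ref_options;
//   ref_options.set_xla_dump_to("gs://some-bucket/dumps");
//   xla::DebugOptions options = BuildXlaDebugOptions(ref_options);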

xla::StatusOr<std::vector<InputCoords>> GetComputationInputs(
    OpKernelContext* context, const char* input_name) {
  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list(input_name, &arg_list));
  // Concatenate all input uids from the list of scalars-or-vectors carrying
  // them.
  std::vector<InputCoords> input_coords;
  for (int i = 0; i < arg_list.size(); ++i) {
    const Tensor& arg = arg_list[i];
    if (TensorShapeUtils::IsScalar(arg.shape())) {
      input_coords.emplace_back(arg.scalar<int64>()());
    } else {
      TF_RET_CHECK(TensorShapeUtils::IsVector(arg.shape()));
      auto arg_vec = arg.vec<int64>();
      const int64 num_elts = arg.shape().dim_size(0);
      for (int64 j = 0; j < num_elts; ++j) {
        input_coords.emplace_back(arg_vec(j));
      }
    }
  }
  return std::move(input_coords);
}
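
// Flattening sketch (illustrative, not part of the original source): a scalar
// input contributes one handle, a vector input contributes one handle per
// element, and the results are concatenated in argument order:
//
//   inputs = [ scalar(7), vector([11, 12]), scalar(3) ]
//   GetComputationInputs(...)  ->  InputCoords for handles {7, 11, 12, 3}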

bool InputShapeMatches(const xla::Shape& parameter_shape,
                       const xla::Shape& input_shape) {
  auto shape_checker = [&](const xla::Shape& pshape,
                           const xla::ShapeIndex& index) {
    if (pshape.IsArray()) {
      TF_ASSIGN_OR_RETURN(const xla::Shape* ishape,
                          xla::ShapeUtil::TryGetSubshape(input_shape, index));
      if (pshape.rank() != ishape->rank() ||
          pshape.element_type() != ishape->element_type()) {
        return errors::InvalidArgument("Mismatching shapes");
      }
      if (pshape.is_static() && pshape.layout() != ishape->layout()) {
        return errors::InvalidArgument("Mismatching layouts");
      }
      for (int64 dim = 0; dim < pshape.rank(); ++dim) {
        if (pshape.is_dynamic_dimension(dim)) {
          if (pshape.dimensions(dim) < ishape->dimensions(dim)) {
            return errors::InvalidArgument("Mismatching shapes");
          }
        } else if (pshape.dimensions(dim) != ishape->dimensions(dim)) {
          return errors::InvalidArgument("Mismatching shapes");
        }
      }
    }
    return Status::OK();
  };
  return xla::ShapeUtil::ForEachSubshapeWithStatus(parameter_shape,
                                                   shape_checker)
      .ok();
}
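
// Matching sketch (illustrative, not part of the original source): a static
// parameter dimension must match the input exactly, while a dynamic dimension
// only acts as an upper bound:
//
//   parameter f32[4, <=8]  vs. input f32[4, 6]  -> matches
//   parameter f32[4, <=8]  vs. input f32[4, 9]  -> mismatch (9 > 8)
//   parameter f32[4, 8]    vs. input f32[4, 6]  -> mismatch (static dim differs)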

xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetInputTupleAllocations(
    const std::vector<InputCoords>& input_coords,
    XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend,
    int64 num_input_shapes,
    const std::function<xla::Shape(int64)>& shape_getter, bool release_inputs) {
  if (input_coords.size() != num_input_shapes) {
    return errors::InvalidArgument(
        "Number of inputs does not match executable proto input shapes: ",
        input_coords.size(), " vs. ", num_input_shapes);
  }
  std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
  input_tuples.reserve(input_coords.size());
  for (size_t i = 0; i < input_coords.size(); ++i) {
    TF_RETURN_IF_ERROR(
        working_set->LookupAndPin(backend, input_coords[i].handle));
    auto tuple = working_set->PinnedTuples().back();
    if (release_inputs) {
      // We are holding a reference to the tuple, so we can safely delete it
      // from the resource manager here.
      TF_RETURN_IF_ERROR(
          working_set->MemoryManager()->Release(input_coords[i].handle));
      VLOG(2) << "Released allocation handle " << input_coords[i].handle;
    }
    xla::Shape input_shape = shape_getter(i);
    if (!InputShapeMatches(input_shape, tuple->on_host_shape())) {
      return errors::InvalidArgument(
          "Run-time shape mismatch for XRTExecute argument[", i, "] (",
          input_coords[i].handle, "). Expected ", input_shape.DebugString(),
          "; got ", tuple->on_host_shape().DebugString());
    }
    if (input_coords[i].index.empty()) {
      input_tuples.emplace_back(std::move(tuple));
    } else {
      XRTTupleAllocation* sub_tuple;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          tuple.get(), input_coords[i].index, &sub_tuple,
          /*alias_parent_allocation=*/true));
      input_tuples.emplace_back(sub_tuple);
    }
  }
  return std::move(input_tuples);
}

Status RebuildOutputAliases(
    const RefPtr<XRTTupleAllocation>& output_tuple,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const xla::HloInputOutputAliasConfig& input_output_alias) {
  auto alias_function =
      [&](const xla::ShapeIndex& output_index,
          const xla::HloInputOutputAliasConfig::Alias& alias) -> Status {
    TF_RET_CHECK(alias.parameter_number < input_tuples.size());
    return output_tuple->AliasBufferFrom(*input_tuples[alias.parameter_number],
                                         alias.parameter_index, output_index);
  };
  return input_output_alias.ForEachAliasWithStatus(alias_function);
}

xla::StatusOr<std::vector<xla::ExecutionInput>> GetArgumentsBuffers(
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const std::vector<bool>& input_is_dynamic, bool release_inputs) {
  auto is_dynamic = [&](size_t arg) {
    return arg < input_is_dynamic.size() && input_is_dynamic[arg];
  };
  std::vector<xla::ExecutionInput> arguments;
  // Don't alias dynamic inputs -- due to the underlying implementation,
  // aliased inputs have two owners: the XRTAllocation and the return value of
  // this function. If an argument is dynamic and its ownership is released to
  // the output of this function, TPUExecute will free it and reallocate a new
  // one, which creates a double-free issue where the XRTAllocation also
  // attempts to release the buffer.
  bool alias_outputs = release_inputs && input_tuples.size() == 1 &&
                       input_tuples[0]->IsExclusiveOwner() && !is_dynamic(0);
  arguments.reserve(input_tuples.size());
  for (int64 i = 0; i < input_tuples.size(); ++i) {
    auto alias_checker =
        [&](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
      if (input_output_alias.ParameterHasAlias(i, index)) {
        TF_RET_CHECK(!is_dynamic(i));
        return true;
      }
      return alias_outputs;
    };
    TF_ASSIGN_OR_RETURN(xla::ExecutionInput exec_input,
                        input_tuples[i]->ToExecutionInput(alias_checker));
    arguments.emplace_back(std::move(exec_input));
  }
  return std::move(arguments);
}
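
// Aliasing decision sketch (illustrative, not part of the original source):
// beyond the buffers the compiled program explicitly aliases, an input buffer
// is only donated to the execution when the returned arguments are its sole
// remaining owner:
//
//   alias_outputs = release_inputs               // handles already released
//                   && input_tuples.size() == 1
//                   && input_tuples[0]->IsExclusiveOwner()
//                   && !is_dynamic(0);           // never donate dynamic shapes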

Status CreateExecuteOutput(OpKernelContext* context,
                           XRTMemoryManager* memory_manager,
                           RefPtr<XRTTupleAllocation> output_tuple,
                           bool return_exploded_tuple) {
  if (return_exploded_tuple && output_tuple->on_host_shape().IsTuple()) {
    int64 tuple_element_count =
        xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape());
    Tensor* output_tensor;
    TF_RETURN_IF_ERROR(context->allocate_output(
        0, TensorShape({tuple_element_count}), &output_tensor));

    for (int64 i = 0; i < tuple_element_count; ++i) {
      XRTTupleAllocation* suballocation;
      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
          output_tuple.get(), {i}, &suballocation,
          /*alias_parent_allocation=*/false));
      output_tensor->vec<int64>()(i) = memory_manager->Register(suballocation);
    }
  } else {
    Tensor* output_tensor;
    TF_RETURN_IF_ERROR(
        context->allocate_output(0, TensorShape({}), &output_tensor));
    output_tensor->scalar<int64>()() =
        memory_manager->Register(std::move(output_tuple));
  }
  return Status::OK();
}
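
// Output shape sketch (illustrative, not part of the original source): with
// return_exploded_tuple set and a tuple-shaped result such as (f32[2], s32[]),
// the op output is a vector of two allocation handles, one per tuple element;
// otherwise it is a single scalar handle for the whole allocation.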

Status ExecuteChained(OpKernelContext* context,
                      const RefPtr<XRTMemoryManager>& memory_manager,
                      xla::Backend* backend, int device_ordinal,
                      const xrt::XRTChainedExecutePlan& plan,
                      const xrt::XRTChainedExecuteConfig& config,
                      const ChainedExecuteFn& execute_op) {
  // Create the vector which tracks the uses of the intermediate chained
  // operations outputs.
  std::vector<int64> uses(plan.ops_size(), 0);
  for (auto& op : plan.ops()) {
    for (auto& input : op.inputs()) {
      uses[input.op_index()] += 1;
    }
  }

  ScopedHandles outputs(memory_manager);
  ScopedHandles results(memory_manager);
  for (int i = 0; i < plan.ops_size(); ++i) {
    auto& op = plan.ops(i);
    if (op.op_oneof_case() == xrt::XRTChainedExecuteOp::kDataHandle) {
      // This operation is a device data load. Set the handle as output and
      // leave the release flag off, since this is not an intermediate output.
      TF_RETURN_IF_ERROR(outputs.Add(i, op.data_handle(), /*release=*/false));
    } else if (op.op_oneof_case() ==
               xrt::XRTChainedExecuteOp::kComputationHandle) {
      // This is an XRT execute operation; forward it to the device specific
      // handler. Populating the working set makes sure the input allocations
      // for this execute operation are pinned to device memory.
      XRTMemoryManager::WorkingSet working_set(memory_manager);
      TF_RETURN_IF_ERROR(
          PopulateOpWorkingSet(backend, op, i, outputs, &working_set));
      TF_ASSIGN_OR_RETURN(auto tuple,
                          execute_op(op, working_set.PinnedTuples()));
      TF_RETURN_IF_ERROR(outputs.Add(i, std::move(tuple)));
    } else {
      return errors::InvalidArgument(
          "Undefined operation kind at post-order position ", i);
    }
    // If the result of this chained operation is an output result, feed the
    // results at the desired positions.
    for (auto& output : op.outputs()) {
      TF_ASSIGN_OR_RETURN(auto tuple, outputs.Lookup(i));
      RefPtr<XRTTupleAllocation> result;
      TF_RETURN_IF_ERROR(MakeOutput(tuple, output.output_index(), &result));
      TF_RETURN_IF_ERROR(results.Add(output.result_index(), std::move(result)));
    }
    // Drop intermediate results which have no more users.
    for (auto& input : op.inputs()) {
      uses[input.op_index()] -= 1;
      if (uses[input.op_index()] == 0) {
        TF_RETURN_IF_ERROR(outputs.Drop(input.op_index()));
      }
    }
  }

  Tensor* output_tensor;
  TF_RETURN_IF_ERROR(context->allocate_output(
      0, TensorShape({static_cast<int64>(results.size())}), &output_tensor));
  for (size_t i = 0; i < results.size(); ++i) {
    output_tensor->vec<int64>()(i) = results.Release(i);
  }
  return Status::OK();
}
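
// Plan walkthrough sketch (illustrative, not part of the original source; the
// op indices below are hypothetical). A plan in post-order might consist of:
//
//   op 0: kDataHandle        -> pre-existing device allocation, never released
//   op 1: kComputationHandle -> reads op 0, produces an intermediate tuple
//   op 2: kComputationHandle -> reads op 1, its output 0 is result_index 0
//
// After op 2 consumes op 1's output, uses[1] drops to zero and the
// intermediate allocation is dropped; the op finally returns a vector with one
// allocation handle per requested result.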

}  // namespace tensorflow