/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include "tensorflow/compiler/xla/service/hlo_runner.h"

#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

HloRunner::HloRunner(se::Platform* platform, int intra_op_parallelism_threads) {
  BackendOptions backend_options;
  backend_options.set_platform(platform);
  backend_options.set_intra_op_parallelism_threads(
      intra_op_parallelism_threads);
  backend_ = Backend::CreateBackend(backend_options).ConsumeValueOrDie();
  device_shape_representation_fn_ = [this](const Shape& shape) {
    return backend_->compiler()->DeviceShapeRepresentation(shape);
  };
  VLOG(1) << "Created HloRunner for platform: " << platform->Name();
}

HloRunner::~HloRunner() {}

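// Transfers the given literal to a newly allocated buffer on the default
// device, using the backend's transfer manager.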
StatusOr<ScopedShapedBuffer> HloRunner::TransferLiteralToDevice(
    const Literal& literal) {
  TF_ASSIGN_OR_RETURN(
      ScopedShapedBuffer buffer,
      backend().transfer_manager()->AllocateScopedShapedBuffer(
          literal.shape(), backend().memory_allocator(),
          backend().default_device_ordinal(), device_shape_representation_fn_));
  TF_ASSIGN_OR_RETURN(
      auto stream, backend().BorrowStream(backend().default_stream_executor()));
  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
      stream.get(), literal, buffer));
  return std::move(buffer);
}

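// Transfers each of the given literals to the device. The by-value overload
// below forwards to this pointer-based one.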
StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
    absl::Span<const Literal* const> literals) {
  std::vector<ScopedShapedBuffer> buffers;
  for (const Literal* literal : literals) {
    CHECK(literal != nullptr);
    TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
                        TransferLiteralToDevice(*literal));
    buffers.push_back(std::move(buffer));
  }
  return std::move(buffers);
}

StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
    absl::Span<const Literal> literals) {
  std::vector<const Literal*> literal_pointers;
  literal_pointers.reserve(literals.size());
  for (const auto& literal : literals) {
    literal_pointers.push_back(&literal);
  }
  return TransferLiteralsToDevice(literal_pointers);
}

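// Reads the given device buffer back into a host literal.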
StatusOr<Literal> HloRunner::TransferLiteralFromDevice(
    const ShapedBuffer& buffer) {
  TF_ASSIGN_OR_RETURN(
      auto stream, backend().BorrowStream(backend().default_stream_executor()));
  return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(),
                                                                 buffer);
}

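// Transfers the arguments to the device, compiles and runs the module, and
// transfers the result back to the host as a literal.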
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
                                     absl::Span<const Literal* const> arguments,
                                     bool run_hlo_passes,
                                     ExecutionProfile* profile) {
  UpdateEntryComputationLayout(module.get(), device_shape_representation_fn_);

  TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
                      TransferLiteralsToDevice(arguments));
  TF_ASSIGN_OR_RETURN(ExecutionOutput result,
                      ExecuteWithDeviceBuffers(
                          /*module=*/std::move(module),
                          /*arguments=*/argument_buffers,
                          /*run_hlo_passes=*/run_hlo_passes,
                          /*profile=*/profile));
  return TransferLiteralFromDevice(result.Result());
}

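// Like Execute, but runs an already-compiled executable instead of compiling
// the module first.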
StatusOr<Literal> HloRunner::ExecuteWithExecutable(
    Executable* executable, absl::Span<const Literal* const> arguments,
    ExecutionProfile* profile) {
  TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
                      TransferLiteralsToDevice(arguments));
  TF_ASSIGN_OR_RETURN(ExecutionOutput result,
                      ExecuteWithDeviceBuffers(
                          /*executable=*/executable,
                          /*arguments=*/argument_buffers,
                          /*profile=*/profile));
  return TransferLiteralFromDevice(result.Result());
}

// Convert the owning buffer of inputs into a (partially) owning vector of
// ExecutionInputs, and an owning vector of `OwningDeviceMemory`'s.
static std::vector<ExecutionInput> ExecutionInputsFromScopedShapedBuffers(
    absl::Span<ScopedShapedBuffer const> inputs,
    HloInputOutputAliasConfig alias_config, int device_ordinal,
    se::DeviceMemoryAllocator* allocator) {
  std::vector<ExecutionInput> execution_inputs;
  std::vector<se::OwningDeviceMemory> owned_args;

  for (int param_num = 0; param_num < inputs.size(); param_num++) {
    const ScopedShapedBuffer& input_buffer = inputs[param_num];
    ShapeTree<MaybeOwningDeviceMemory> buffer_tree(
        input_buffer.on_device_shape());

    input_buffer.buffers().ForEachElement(
        [&](const ShapeIndex& index,
            const se::DeviceMemoryBase& execution_input_buffer) {
          if (alias_config.ParameterHasAlias(param_num, index)) {
            // Store owned.
            *buffer_tree.mutable_element(index) = se::OwningDeviceMemory{
                execution_input_buffer, device_ordinal, allocator};
          } else {
            // Store unowned.
            *buffer_tree.mutable_element(index) = execution_input_buffer;
          }
        });
    execution_inputs.emplace_back(std::move(buffer_tree));
  }
  return execution_inputs;
}

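// Compiles the module and then runs it with the given device buffers as
// arguments.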
StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
    std::unique_ptr<HloModule> module,
    absl::Span<ScopedShapedBuffer const> arguments, bool run_hlo_passes,
    ExecutionProfile* profile) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                      CreateExecutable(std::move(module), run_hlo_passes));
  return ExecuteWithDeviceBuffers(executable.get(), arguments, profile);
}

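// Runs a pre-compiled executable on the default device and blocks until the
// stream has completed before returning the execution output.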
StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
    Executable* executable, absl::Span<ScopedShapedBuffer const> arguments,
    ExecutionProfile* profile) {
  // Get service run options.
  se::Stream stream(backend().default_stream_executor());
  stream.Init();
  ServiceExecutableRunOptions service_run_options =
      GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
                                    nullptr, RunId());
  service_run_options.mutable_run_options()->set_execution_profile(profile);

  std::vector<ExecutionInput> execution_arguments =
      ExecutionInputsFromScopedShapedBuffers(
          arguments, executable->module().input_output_alias_config(),
          stream.parent()->device_ordinal(), stream.parent()->GetAllocator());

  TF_ASSIGN_OR_RETURN(
      ExecutionOutput retval,
      executable->ExecuteOnStreamWrapper(&service_run_options,
                                         std::move(execution_arguments)));
  TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
  return std::move(retval);
}

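// Compiles the module and executes it on multiple replicas, as described by
// the given device assignment.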
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
    std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options,
    DeviceAssignment* device_assignment) {
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      CreateExecutable(std::move(module), options.run_hlo_passes));
  return ExecuteReplicated(executable.get(), options, device_assignment);
}

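// Shared implementation of the ExecuteReplicated overloads: creates a stream
// and run options per replica, transfers each replica's arguments to its
// assigned device, optionally spawns infeed/outfeed threads, runs the
// replicas via `execution_helper`, and transfers the results back to the host.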
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicatedImpl(
    std::function<StatusOr<std::vector<ScopedShapedBuffer>>(
        const std::vector<ServiceExecutableRunOptions>&,
        const std::vector<absl::Span<const ShapedBuffer* const>>&)>
        execution_helper,
    std::function<int64(int64_t)> argument_count_provider,
    std::function<const Literal*(int64_t, int64_t)> argument_provider,
    const ReplicatedExecuteOptions& options,
    DeviceAssignment* device_assignment) {
  std::vector<std::unique_ptr<se::Stream>> streams;
  std::vector<ServiceExecutableRunOptions> service_run_options;
  int64_t num_partitions = device_assignment->computation_count();

  std::vector<ScopedShapedBuffer> argument_buffers;
  // This reserve() call is necessary for correctness, because
  // argument_buffer_ptrs contains pointers into the elements of
  // argument_buffers.
  const int64_t total_argument_count = [&]() {
    int64_t total = 0;
    for (int64_t i = 0; i < options.num_replicas; ++i) {
      total += argument_count_provider(i);
    }
    return total;
  }();
  argument_buffers.reserve(total_argument_count);

  // Plus one so we can safely get &argument_buffer_ptrs[0] in case there are
  // no arguments.
  std::vector<const ShapedBuffer*> argument_buffer_ptrs(total_argument_count +
                                                        1);
  std::vector<absl::Span<const ShapedBuffer* const>> argument_buffer_slices;
  int64_t index = 0;
  RunId run_id;
  for (int64_t i = 0; i < options.num_replicas; ++i) {
    int64_t device =
        (*device_assignment)(i / num_partitions, i % num_partitions);
    TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                        backend().stream_executor(device));
    streams.push_back(absl::make_unique<se::Stream>(executor));
    streams.back()->Init();
    service_run_options.emplace_back(GetServiceRunOptionsForDevice(
        device, streams.back().get(), device_assignment, run_id));

    // Copy arguments to device.
    const int64_t argument_count = argument_count_provider(i);
    for (int64_t arg_index = 0; arg_index < argument_count; arg_index++) {
      const Literal* const argument = argument_provider(i, arg_index);
      TF_RET_CHECK(argument != nullptr);
      TF_ASSIGN_OR_RETURN(
          ScopedShapedBuffer argument_buffer,
          backend().transfer_manager()->AllocateScopedShapedBuffer(
              argument->shape(), backend().memory_allocator(), device,
              device_shape_representation_fn_));
      TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
          streams.back().get(), *argument, argument_buffer));
      argument_buffers.push_back(std::move(argument_buffer));
      argument_buffer_ptrs[index++] = &argument_buffers.back();
    }
    argument_buffer_slices.emplace_back(
        &argument_buffer_ptrs[index - argument_count], argument_count);
  }

  std::unique_ptr<tensorflow::thread::ThreadPool> pool;
  TF_RET_CHECK(options.infeed_values.empty() ||
               options.infeed_values.size() == options.num_replicas);
  int64_t num_threads = options.infeed_values.size();
  if (ShapeUtil::IsInitialized(options.outfeed_shape)) {
    num_threads += options.num_replicas;
  }
  if (num_threads > 0) {
    pool = absl::make_unique<tensorflow::thread::ThreadPool>(
        tensorflow::Env::Default(), "infeed_outfeed",
        /*num_threads=*/num_threads);
  }
  if (!options.infeed_values.empty()) {
    for (int64_t i = 0; i < options.num_replicas; ++i) {
      int64_t device =
          (*device_assignment)(i / num_partitions, i % num_partitions);
      pool->Schedule([this, device, &options, i]() {
        se::StreamExecutor* executor =
            backend().stream_executor(device).ValueOrDie();
        VLOG(1) << "Starting infeed on device " << device;
        for (int64_t step = 1;
             options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
          TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToInfeed(
              executor, *options.infeed_values[i]));
          if (step % 100 == 0) {
            VLOG(1) << "Infeed step " << step;
          }
        }
      });
    }
  }
  if (ShapeUtil::IsInitialized(options.outfeed_shape)) {
    if (options.outfeed_values) {
      options.outfeed_values->resize(options.num_replicas);
    }
    for (int64_t i = 0; i < options.num_replicas; ++i) {
      int64_t device =
          (*device_assignment)(i / num_partitions, i % num_partitions);
      pool->Schedule([this, device, &options, i]() {
        se::StreamExecutor* executor =
            backend().stream_executor(device).ValueOrDie();
        VLOG(1) << "Starting outfeed on device " << device;
        for (int64_t step = 1;
             options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
          Literal literal(options.outfeed_shape);
          TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
              executor, &literal));
          if (options.outfeed_values) {
            options.outfeed_values->at(i) = std::move(literal);
          }
          if (step % 100 == 0) {
            VLOG(1) << "Outfeed step " << step;
          }
        }
      });
    }
  }

  LOG(INFO) << "Replicated execution started";
  TF_ASSIGN_OR_RETURN(
      std::vector<ScopedShapedBuffer> results,
      execution_helper(service_run_options, argument_buffer_slices));
  LOG(INFO) << "Replicated execution terminated";

  std::vector<Literal> exec_results;
  for (int64_t i = 0; i < options.num_replicas; ++i) {
    TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone());
    TF_ASSIGN_OR_RETURN(Literal literal,
                        backend().transfer_manager()->TransferLiteralFromDevice(
                            streams[i].get(), results[i]));
    exec_results.push_back(std::move(literal));
  }
  return std::move(exec_results);
}

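// Executes the given executable once per replica. When `options.use_threads`
// is set, each replica runs on its own thread; otherwise the replicas are
// dispatched together through Executable::ExecuteOnStreams.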
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
    Executable* executable, const ReplicatedExecuteOptions& options,
    DeviceAssignment* device_assignment, ExecutionProfile* profile) {
  return ExecuteReplicatedImpl(
      [&](const std::vector<ServiceExecutableRunOptions>& service_run_options,
          const std::vector<absl::Span<const ShapedBuffer* const>>&
              argument_buffer_slices)
          -> StatusOr<std::vector<ScopedShapedBuffer>> {
        std::vector<ScopedShapedBuffer> results;
        if (!options.use_threads) {
          TF_ASSIGN_OR_RETURN(
              results, executable->ExecuteOnStreams(service_run_options,
                                                    argument_buffer_slices));
        } else {
          tensorflow::mutex mutex;
          std::vector<StatusOr<ScopedShapedBuffer>> thread_results(
              options.num_replicas);
          {
            LOG(INFO) << "Creating thread pool for " << options.num_replicas
                      << " replicas";
            tensorflow::thread::ThreadPool pool(
                tensorflow::Env::Default(), "replicas", options.num_replicas);
            for (int64_t i = 0; i < options.num_replicas; ++i) {
              pool.Schedule([&, i] {
                auto result = executable->ExecuteOnStream(
                    &service_run_options[i], argument_buffer_slices[i],
                    nullptr);
                tensorflow::mutex_lock lock(mutex);
                thread_results[i] = std::move(result);
              });
            }

            // Note: the thread pool destructor guarantees it completes all work
            // before we leave this scope.
          }
          for (auto& thread_result : thread_results) {
            if (!thread_result.ok()) {
              return thread_result.status();
            }
            results.push_back(std::move(thread_result).ValueOrDie());
          }
        }
        return results;
      },
      [&](int64_t replica) { return options.arguments.size(); },
      [&](int64_t replica, int64_t index) { return options.arguments[index]; },
      options, device_assignment);
}

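// Executes one (possibly distinct) executable per replica, with per-replica
// argument counts and values supplied by the given callbacks. Only the
// threaded execution mode (`options.use_threads`) is supported here.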
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
    std::function<Executable*(int64_t)> executable_provider,
    std::function<int64(int64_t)> argument_count_provider,
    std::function<const Literal*(int64_t, int64_t)> argument_provider,
    const ReplicatedExecuteOptions& options) {
  TF_ASSIGN_OR_RETURN(
      DeviceAssignment device_assignment,
      backend().computation_placer()->AssignDevices(options.num_replicas, 1));
  return ExecuteReplicatedImpl(
      [&](const std::vector<ServiceExecutableRunOptions>& service_run_options,
          const std::vector<absl::Span<const ShapedBuffer* const>>&
              argument_buffer_slices)
          -> StatusOr<std::vector<ScopedShapedBuffer>> {
        TF_RET_CHECK(options.use_threads);
        std::vector<ScopedShapedBuffer> results;
        tensorflow::mutex mutex;
        std::vector<StatusOr<ScopedShapedBuffer>> thread_results(
            options.num_replicas);
        {
          LOG(INFO) << "Creating thread pool for " << options.num_replicas
                    << " replicas";
          tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
                                              "replicas", options.num_replicas);
          for (int64_t i = 0; i < options.num_replicas; ++i) {
            for (const auto& arg : argument_buffer_slices[i]) {
              TF_RET_CHECK(arg != nullptr);
            }
            pool.Schedule([&, i] {
              auto result = executable_provider(i)->ExecuteOnStream(
                  &service_run_options[i], argument_buffer_slices[i], nullptr);
              tensorflow::mutex_lock lock(mutex);
              thread_results[i] = std::move(result);
            });
          }

          // Note: the thread pool destructor guarantees it completes all work
          // before we leave this scope.
        }
        for (auto& thread_result : thread_results) {
          if (!thread_result.ok()) {
            return thread_result.status();
          }
          results.push_back(std::move(thread_result).ValueOrDie());
        }
        return results;
      },
      argument_count_provider, argument_provider, options, &device_assignment);
}

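// Convenience overload that builds a default device assignment for
// `options.num_replicas` replicas.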
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
    std::unique_ptr<HloModule> module,
    const ReplicatedExecuteOptions& options) {
  TF_ASSIGN_OR_RETURN(
      DeviceAssignment device_assignment,
      backend().computation_placer()->AssignDevices(options.num_replicas, 1));
  return ExecuteReplicated(std::move(module), options, &device_assignment);
}

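// Compiles the module into an executable. If `run_hlo_passes` is true, the
// full compilation pipeline (HLO passes included) runs; otherwise only the
// backend's code generation step is performed.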
StatusOr<std::unique_ptr<Executable>> HloRunner::CreateExecutable(
    std::unique_ptr<HloModule> module, bool run_hlo_passes) {
  if (run_hlo_passes) {
    auto module_group = absl::make_unique<HloModuleGroup>(std::move(module));
    TF_ASSIGN_OR_RETURN(
        auto executables,
        backend().compiler()->Compile(std::move(module_group),
                                      {{backend().default_stream_executor()}},
                                      backend().memory_allocator()));
    return std::move(executables[0]);
  }
  return backend().compiler()->RunBackend(std::move(module),
                                          backend().default_stream_executor(),
                                          backend().memory_allocator());
}

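// Builds ServiceExecutableRunOptions for the given device and stream, using
// the backend's allocator and intra-op thread pool.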
ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice(
    int64_t device, se::Stream* stream, DeviceAssignment* device_assignment,
    RunId run_id) {
  ExecutableRunOptions run_options;
  run_options.set_device_ordinal(device);
  run_options.set_stream(stream);
  run_options.set_allocator(backend().memory_allocator());
  run_options.set_intra_op_thread_pool(
      backend().eigen_intra_op_thread_pool_device());
  if (device_assignment != nullptr) {
    run_options.set_device_assignment(device_assignment);
  }
  run_options.set_run_id(run_id);
  return ServiceExecutableRunOptions(run_options, backend().StreamBorrower());
}

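// Returns the backend, lazily creating the default backend on first use if
// none was constructed earlier. Crashes if backend creation fails.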
Backend& HloRunner::backend() {
  if (!backend_) {
    backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
    VLOG(1) << "Executing on platform " << backend().platform()->Name();
  }
  return *backend_;
}

const Backend& HloRunner::backend() const {
  return const_cast<HloRunner*>(this)->backend();
}

}  // namespace xla