/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include "tensorflow/compiler/xla/service/hlo_runner.h"

#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

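// Parses the given HLO text into an unverified HloModule configured with
// `debug_options`. Illustrative sketch of typical use (the names `hlo_text`,
// `runner`, and `args` below are assumptions for the example, not part of
// this file):
//
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<HloModule> module,
//       HloRunner::CreateModuleFromString(hlo_text, DebugOptions()));
//   std::vector<const Literal*> args;
//   TF_ASSIGN_OR_RETURN(Literal result,
//                       runner.Execute(std::move(module), args,
//                                      /*run_hlo_passes=*/true,
//                                      /*profile=*/nullptr));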
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::CreateModuleFromString(const absl::string_view hlo_string,
                                  const DebugOptions& debug_options) {
  HloModuleConfig config;
  config.set_debug_options(debug_options);
  return ParseAndReturnUnverifiedModule(hlo_string, config);
}

namespace {

// Creates an HloModule from the given proto.
StatusOr<std::unique_ptr<HloModule>> HloProtoToModule(
    const HloProto& proto, const DebugOptions& debug_options) {
  TF_ASSIGN_OR_RETURN(HloModuleConfig config,
                      HloModule::CreateModuleConfigFromProto(proto.hlo_module(),
                                                             debug_options));
  TF_ASSIGN_OR_RETURN(auto module,
                      HloModule::CreateFromProto(proto.hlo_module(), config));
  return std::move(module);
}

}  // namespace

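// Reads a binary-serialized HloProto from `filename` and converts it to an
// HloModule using the given debug options.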
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromBinaryProtoFile(const std::string& filename,
                                         const DebugOptions& debug_options) {
  HloProto proto;
  TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
                                                 filename, &proto));
  return HloProtoToModule(proto, debug_options);
}

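// Reads a text-format HloProto from `filename` and converts it to an
// HloModule using the given debug options.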
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromTextProtoFile(const std::string& filename,
                                       const DebugOptions& debug_options) {
  HloProto proto;
  TF_RETURN_IF_ERROR(
      tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto));
  return HloProtoToModule(proto, debug_options);
}

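// Reads HLO text from `filename` and parses it into an unverified HloModule.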
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromHloTextFile(const std::string& filename,
                                     const DebugOptions& debug_options) {
  string hlo_string;
  TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
                                                  filename, &hlo_string));
  HloModuleConfig config;
  config.set_debug_options(debug_options);
  return ParseAndReturnUnverifiedModule(hlo_string, config);
}

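// Creates a backend for `platform` with the requested intra-op parallelism.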
HloRunner::HloRunner(se::Platform* platform, int intra_op_parallelism_threads) {
  BackendOptions backend_options;
  backend_options.set_platform(platform);
  backend_options.set_intra_op_parallelism_threads(
      intra_op_parallelism_threads);
  backend_ = Backend::CreateBackend(backend_options).ConsumeValueOrDie();
  VLOG(1) << "Created HloRunner for platform: " << platform->Name();
}

HloRunner::~HloRunner() {}

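// Allocates a device buffer for `literal`'s shape on the default device and
// copies the literal's contents into it.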
StatusOr<ScopedShapedBuffer> HloRunner::TransferLiteralToDevice(
    const Literal& literal) {
  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
                      backend().transfer_manager()->AllocateScopedShapedBuffer(
                          literal.shape(), backend().memory_allocator(),
                          backend().default_device_ordinal()));
  TF_ASSIGN_OR_RETURN(
      auto stream, backend().BorrowStream(backend().default_stream_executor()));
  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
      stream.get(), literal, buffer));
  return std::move(buffer);
}

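// Transfers each literal in `literals` to the device; all pointers must be
// non-null.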
StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
    absl::Span<const Literal* const> literals) {
  std::vector<ScopedShapedBuffer> buffers;
  for (const Literal* literal : literals) {
    CHECK(literal != nullptr);
    TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
                        TransferLiteralToDevice(*literal));
    buffers.push_back(std::move(buffer));
  }
  return std::move(buffers);
}

StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
    absl::Span<const Literal> literals) {
  std::vector<const Literal*> literal_pointers;
  literal_pointers.reserve(literals.size());
  for (const auto& literal : literals) {
    literal_pointers.push_back(&literal);
  }
  return TransferLiteralsToDevice(literal_pointers);
}

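// Copies the contents of `buffer` back from the device into a host Literal.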
StatusOr<Literal> HloRunner::TransferLiteralFromDevice(
    const ShapedBuffer& buffer) {
  TF_ASSIGN_OR_RETURN(
      auto stream, backend().BorrowStream(backend().default_stream_executor()));
  return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(),
                                                                 buffer);
}

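// Transfers `arguments` to the device, compiles and runs `module`, and
// transfers the result back to the host.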
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
                                     absl::Span<const Literal* const> arguments,
                                     bool run_hlo_passes,
                                     ExecutionProfile* profile) {
  TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
                      TransferLiteralsToDevice(arguments));
  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                      ExecuteWithDeviceBuffers(
                          /*module=*/std::move(module),
                          /*arguments=*/argument_buffers,
                          /*run_hlo_passes=*/run_hlo_passes,
                          /*profile=*/profile));
  return TransferLiteralFromDevice(result);
}

StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
                                     absl::Span<const Literal> arguments,
                                     bool run_hlo_passes,
                                     ExecutionProfile* profile) {
  // Construct a vector of plain pointers for the arguments.
  std::vector<const Literal*> argument_pointers;
  argument_pointers.reserve(arguments.size());
  for (const auto& argument : arguments) {
    argument_pointers.push_back(&argument);
  }
  return Execute(
      /*module=*/std::move(module),
      /*arguments=*/argument_pointers,
      /*run_hlo_passes=*/run_hlo_passes,
      /*profile=*/profile);
}

StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
                                     absl::Span<const Literal* const> arguments,
                                     ExecutionProfile* profile) {
  TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
                      TransferLiteralsToDevice(arguments));
  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                      ExecuteWithDeviceBuffers(
                          /*executable=*/executable.get(),
                          /*arguments=*/argument_buffers,
                          /*profile=*/profile));
  return TransferLiteralFromDevice(result);
}

StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
                                     absl::Span<const Literal> arguments,
                                     ExecutionProfile* profile) {
  // Construct a vector of plain pointers for the arguments.
  std::vector<const Literal*> argument_pointers;
  argument_pointers.reserve(arguments.size());
  for (const auto& argument : arguments) {
    argument_pointers.push_back(&argument);
  }
  return Execute(
      /*executable=*/std::move(executable),
      /*arguments=*/argument_pointers,
      /*profile=*/profile);
}

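// Compiles `module` (optionally running the HLO pass pipeline) and executes it
// on the default device with arguments that already reside on the device.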
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
    std::unique_ptr<HloModule> module,
    absl::Span<const ShapedBuffer* const> arguments, bool run_hlo_passes,
    ExecutionProfile* profile) {
  // Get service run options.
  se::Stream stream(backend().default_stream_executor());
  stream.Init();
  ServiceExecutableRunOptions service_run_options =
      GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
                                    nullptr, RunId());
  service_run_options.mutable_run_options()->set_execution_profile(profile);

  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                      CreateExecutable(std::move(module), run_hlo_passes));
  TF_ASSIGN_OR_RETURN(
      ScopedShapedBuffer retval,
      executable->ExecuteOnStreamWrapper(&service_run_options, arguments));
  TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
  return std::move(retval);
}

StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
    std::unique_ptr<HloModule> module,
    absl::Span<const ScopedShapedBuffer> arguments, bool run_hlo_passes,
    ExecutionProfile* profile) {
  std::vector<const ShapedBuffer*> argument_pointers;
  argument_pointers.reserve(arguments.size());
  for (const auto& argument : arguments) {
    argument_pointers.push_back(&argument);
  }
  return ExecuteWithDeviceBuffers(
      /*module=*/std::move(module),
      /*arguments=*/argument_pointers,
      /*run_hlo_passes=*/run_hlo_passes,
      /*profile=*/profile);
}

StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
    Executable* executable, absl::Span<const ShapedBuffer* const> arguments,
    ExecutionProfile* profile) {
  // Get service run options.
  se::Stream stream(backend().default_stream_executor());
  stream.Init();
  ServiceExecutableRunOptions service_run_options =
      GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
                                    nullptr, RunId());
  service_run_options.mutable_run_options()->set_execution_profile(profile);

  TF_ASSIGN_OR_RETURN(
      ScopedShapedBuffer retval,
      executable->ExecuteOnStreamWrapper(&service_run_options, arguments));
  TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
  return std::move(retval);
}

StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
    Executable* executable, absl::Span<const ScopedShapedBuffer> arguments,
    ExecutionProfile* profile) {
  std::vector<const ShapedBuffer*> argument_pointers;
  argument_pointers.reserve(arguments.size());
  for (const auto& argument : arguments) {
    argument_pointers.push_back(&argument);
  }
  return ExecuteWithDeviceBuffers(
      /*executable=*/std::move(executable),
      /*arguments=*/argument_pointers,
      /*profile=*/profile);
}

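// Compiles `module` and runs it once per replica, as described by
// `device_assignment`.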
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
    std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options,
    DeviceAssignment* device_assignment) {
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      CreateExecutable(std::move(module), options.run_hlo_passes));
  return ExecuteReplicated(executable.get(), options, device_assignment);
}

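// Runs `executable` once per replica. Arguments are copied to each replica's
// device up front; optional infeed/outfeed threads stream data while the
// replicas run, either via Executable::ExecuteOnStreams or, when
// options.use_threads is set, on one host thread per replica.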
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
    Executable* executable, const ReplicatedExecuteOptions& options,
    DeviceAssignment* device_assignment, ExecutionProfile* profile) {
  std::vector<std::unique_ptr<se::Stream>> streams;
  std::vector<ServiceExecutableRunOptions> service_run_options;

  std::vector<ScopedShapedBuffer> argument_buffers;
  // This reserve() call is necessary for correctness, because
  // argument_buffer_ptrs contains pointers into the elements of
  // argument_buffers.
  argument_buffers.reserve(options.num_replicas * options.arguments.size());

  // Plus one so we can safely get &argument_buffer_ptrs[0] in case there are
  // no arguments.
  std::vector<const ShapedBuffer*> argument_buffer_ptrs(
      options.num_replicas * options.arguments.size() + 1);
  std::vector<absl::Span<const ShapedBuffer* const>> argument_buffer_slices;
  int64 index = 0;
  RunId run_id;
  for (int64 i = 0; i < options.num_replicas; ++i) {
    int64 device = (*device_assignment)(i, 0);
    TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                        backend().stream_executor(device));
    streams.push_back(absl::make_unique<se::Stream>(executor));
    streams.back()->Init();
    service_run_options.emplace_back(GetServiceRunOptionsForDevice(
        device, streams.back().get(), device_assignment, run_id));

    // Copy arguments to device.
    for (const Literal* argument : options.arguments) {
      TF_ASSIGN_OR_RETURN(
          ScopedShapedBuffer argument_buffer,
          backend().transfer_manager()->AllocateScopedShapedBuffer(
              argument->shape(), backend().memory_allocator(), device));
      TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
          streams.back().get(), *argument, argument_buffer));
      argument_buffers.push_back(std::move(argument_buffer));
      argument_buffer_ptrs[index++] = &argument_buffers.back();
    }
    argument_buffer_slices.emplace_back(
        &argument_buffer_ptrs[index - options.arguments.size()],
        options.arguments.size());
  }

  std::unique_ptr<tensorflow::thread::ThreadPool> pool;
  int64 num_threads = (options.infeed != nullptr) ? options.num_replicas : 0;
  if (ShapeUtil::IsInitialized(options.outfeed_shape)) {
    num_threads += options.num_replicas;
  }
  if (num_threads > 0) {
    pool = absl::make_unique<tensorflow::thread::ThreadPool>(
        tensorflow::Env::Default(), "infeed_outfeed",
        /*num_threads=*/num_threads);
  }
  if (options.infeed != nullptr) {
    for (int64 i = 0; i < options.num_replicas; ++i) {
      int64 device = (*device_assignment)(i, 0);
      pool->Schedule([this, device, &options]() {
        se::StreamExecutor* executor =
            backend().stream_executor(device).ValueOrDie();
        VLOG(1) << "Starting infeed on device " << device;
        for (int64 step = 1;
             options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
          TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToInfeed(
              executor, *options.infeed));
          if (step % 100 == 0) {
            VLOG(1) << "Infeed step " << step;
          }
        }
      });
    }
  }
  if (ShapeUtil::IsInitialized(options.outfeed_shape)) {
    for (int64 i = 0; i < options.num_replicas; ++i) {
      int64 device = (*device_assignment)(i, 0);
      pool->Schedule([this, device, &options]() {
        se::StreamExecutor* executor =
            backend().stream_executor(device).ValueOrDie();
        VLOG(1) << "Starting outfeed on device " << device;
        for (int64 step = 1;
             options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
          Literal literal;
          TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
              executor, options.outfeed_shape, &literal));
          if (options.outfeed_values != nullptr) {
            options.outfeed_values->push_back(std::move(literal));
          }
          if (step % 100 == 0) {
            VLOG(1) << "Outfeed step " << step;
          }
        }
      });
    }
  }

  LOG(INFO) << "Replicated execution started";
  std::vector<ScopedShapedBuffer> results;
  if (!options.use_threads) {
    TF_ASSIGN_OR_RETURN(results,
                        executable->ExecuteOnStreams(service_run_options,
                                                     argument_buffer_slices));
  } else {
    tensorflow::mutex mutex;
    std::vector<StatusOr<ScopedShapedBuffer>> thread_results(
        options.num_replicas);
    {
      LOG(INFO) << "Creating thread pool for " << options.num_replicas
                << " replicas";
      tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
                                          "replicas", options.num_replicas);
      for (int64 i = 0; i < options.num_replicas; ++i) {
        pool.Schedule([&, i] {
          auto result = executable->ExecuteOnStream(
              &service_run_options[i], argument_buffer_slices[i], nullptr);
          tensorflow::mutex_lock lock(mutex);
          thread_results[i] = std::move(result);
        });
      }

      // Note: the thread pool destructor guarantees it completes all work
      // before we leave this scope.
    }
    for (auto& thread_result : thread_results) {
      if (!thread_result.ok()) {
        return thread_result.status();
      }
      results.push_back(std::move(thread_result).ValueOrDie());
    }
  }
  LOG(INFO) << "Replicated execution terminated";

  std::vector<Literal> exec_results;
  for (int64 i = 0; i < options.num_replicas; ++i) {
    TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone());
    TF_ASSIGN_OR_RETURN(Literal literal,
                        backend().transfer_manager()->TransferLiteralFromDevice(
                            streams[i].get(), results[i]));
    exec_results.push_back(std::move(literal));
  }
  return std::move(exec_results);
}

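// Convenience overload that computes a default device assignment for
// `options.num_replicas` replicas.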
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
    std::unique_ptr<HloModule> module,
    const ReplicatedExecuteOptions& options) {
  TF_ASSIGN_OR_RETURN(
      DeviceAssignment device_assignment,
      backend().computation_placer()->AssignDevices(options.num_replicas, 1));
  return ExecuteReplicated(std::move(module), options, &device_assignment);
}

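// Compiles `module` for the default stream executor, running the HLO pass
// pipeline first when `run_hlo_passes` is true.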
StatusOr<std::unique_ptr<Executable>> HloRunner::CreateExecutable(
    std::unique_ptr<HloModule> module, bool run_hlo_passes) {
  if (run_hlo_passes) {
    auto module_group = absl::make_unique<HloModuleGroup>(std::move(module));
    TF_ASSIGN_OR_RETURN(
        auto executables,
        backend().compiler()->Compile(std::move(module_group),
                                      {{backend().default_stream_executor()}},
                                      backend().memory_allocator()));
    return std::move(executables[0]);
  }
  return backend().compiler()->RunBackend(std::move(module),
                                          backend().default_stream_executor(),
                                          backend().memory_allocator());
}

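// Builds run options targeting `device` and `stream`, using the backend's
// memory allocator and intra-op thread pool.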
ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice(
    int64 device, se::Stream* stream, DeviceAssignment* device_assignment,
    RunId run_id) {
  ExecutableRunOptions run_options;
  run_options.set_device_ordinal(device);
  run_options.set_stream(stream);
  run_options.set_allocator(backend().memory_allocator());
  run_options.set_intra_op_thread_pool(
      backend().eigen_intra_op_thread_pool_device());
  if (device_assignment != nullptr) {
    run_options.set_device_assignment(device_assignment);
  }
  run_options.set_run_id(run_id);
  return ServiceExecutableRunOptions(run_options, backend().StreamBorrower());
}

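// Lazily creates the default backend the first time it is needed.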
Backend& HloRunner::backend() {
  if (!backend_) {
    backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
    VLOG(1) << "Executing on platform " << backend().platform()->Name();
  }
  return *backend_;
}

const Backend& HloRunner::backend() const {
  return const_cast<HloRunner*>(this)->backend();
}

}  // namespace xla