/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include "tensorflow/compiler/xla/service/hlo_runner.h"

#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::CreateModuleFromString(const absl::string_view hlo_string,
                                  const DebugOptions& debug_options) {
  HloModuleConfig config;
  config.set_debug_options(debug_options);
  return ParseAndReturnUnverifiedModule(hlo_string, config);
}
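
// Illustrative sketch only; the HLO text and names below are hypothetical,
// not part of this file:
//
//   constexpr absl::string_view kHloText = R"(
//     HloModule add
//     ENTRY add {
//       x = f32[] parameter(0)
//       y = f32[] parameter(1)
//       ROOT sum = f32[] add(x, y)
//     })";
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<HloModule> module,
//       HloRunner::CreateModuleFromString(kHloText, DebugOptions()));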

namespace {

// Creates an HloModule from the given proto.
StatusOr<std::unique_ptr<HloModule>> HloProtoToModule(
    const HloProto& proto, const DebugOptions& debug_options) {
  TF_ASSIGN_OR_RETURN(HloModuleConfig config,
                      HloModule::CreateModuleConfigFromProto(proto.hlo_module(),
                                                             debug_options));
  TF_ASSIGN_OR_RETURN(auto module,
                      HloModule::CreateFromProto(proto.hlo_module(), config));
  return std::move(module);
}

}  // namespace

/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromBinaryProtoFile(const std::string& filename,
                                         const DebugOptions& debug_options) {
  HloProto proto;
  TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
                                                 filename, &proto));
  return HloProtoToModule(proto, debug_options);
}

/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromTextProtoFile(const std::string& filename,
                                       const DebugOptions& debug_options) {
  HloProto proto;
  TF_RETURN_IF_ERROR(
      tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto));
  return HloProtoToModule(proto, debug_options);
}

/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromHloTextFile(const std::string& filename,
                                     const DebugOptions& debug_options) {
  std::string hlo_string;
  TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
                                                  filename, &hlo_string));
  HloModuleConfig config;
  config.set_debug_options(debug_options);
  return ParseAndReturnUnverifiedModule(hlo_string, config);
}
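
// The three loaders above differ only in on-disk format: binary HloProto,
// text HloProto, and HLO text syntax, respectively. A hedged example, with a
// hypothetical path:
//
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
//                       HloRunner::ReadModuleFromHloTextFile(
//                           "/tmp/module.hlo", DebugOptions()));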

HloRunner::HloRunner(se::Platform* platform, int intra_op_parallelism_threads) {
  BackendOptions backend_options;
  backend_options.set_platform(platform);
  backend_options.set_intra_op_parallelism_threads(
      intra_op_parallelism_threads);
  backend_ = Backend::CreateBackend(backend_options).ConsumeValueOrDie();
  VLOG(1) << "Created HloRunner for platform: " << platform->Name();
}

HloRunner::~HloRunner() {}

StatusOr<ScopedShapedBuffer> HloRunner::TransferLiteralToDevice(
    const Literal& literal) {
  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
                      backend().transfer_manager()->AllocateScopedShapedBuffer(
                          literal.shape(), backend().memory_allocator(),
                          backend().default_device_ordinal()));
  TF_ASSIGN_OR_RETURN(
      auto stream, backend().BorrowStream(backend().default_stream_executor()));
  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
      stream.get(), literal, buffer));
  return std::move(buffer);
}

StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
    absl::Span<const Literal* const> literals) {
  std::vector<ScopedShapedBuffer> buffers;
  for (const Literal* literal : literals) {
    CHECK(literal != nullptr);
    TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
                        TransferLiteralToDevice(*literal));
    buffers.push_back(std::move(buffer));
  }
  return std::move(buffers);
}

StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
    absl::Span<const Literal> literals) {
  std::vector<const Literal*> literal_pointers;
  literal_pointers.reserve(literals.size());
  for (const auto& literal : literals) {
    literal_pointers.push_back(&literal);
  }
  return TransferLiteralsToDevice(literal_pointers);
}

StatusOr<Literal> HloRunner::TransferLiteralFromDevice(
    const ShapedBuffer& buffer) {
  TF_ASSIGN_OR_RETURN(
      auto stream, backend().BorrowStream(backend().default_stream_executor()));
  return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(),
                                                                 buffer);
}
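
// The helpers above form a host<->device round trip. A minimal sketch,
// assuming a constructed HloRunner named `runner` (hypothetical):
//
//   Literal input = LiteralUtil::CreateR0<float>(42.0f);
//   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer device_buffer,
//                       runner.TransferLiteralToDevice(input));
//   TF_ASSIGN_OR_RETURN(Literal round_tripped,
//                       runner.TransferLiteralFromDevice(device_buffer));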

StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
                                     absl::Span<const Literal* const> arguments,
                                     bool run_hlo_passes,
                                     ExecutionProfile* profile) {
  TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
                      TransferLiteralsToDevice(arguments));
  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                      ExecuteWithDeviceBuffers(
                          /*module=*/std::move(module),
                          /*arguments=*/argument_buffers,
                          /*run_hlo_passes=*/run_hlo_passes,
                          /*profile=*/profile));
  return TransferLiteralFromDevice(result);
}
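
// End-to-end sketch combining parsing, transfer, and execution (hedged;
// `runner`, `module`, and the argument literals are carried over from the
// illustrative examples above):
//
//   Literal x = LiteralUtil::CreateR0<float>(1.0f);
//   Literal y = LiteralUtil::CreateR0<float>(2.0f);
//   TF_ASSIGN_OR_RETURN(Literal result,
//                       runner.Execute(std::move(module), {&x, &y},
//                                      /*run_hlo_passes=*/true));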

StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
                                     absl::Span<const Literal> arguments,
                                     bool run_hlo_passes,
                                     ExecutionProfile* profile) {
  // Construct a vector of plain pointers for the arguments.
  std::vector<const Literal*> argument_pointers;
  argument_pointers.reserve(arguments.size());
  for (const auto& argument : arguments) {
    argument_pointers.push_back(&argument);
  }
  return Execute(
      /*module=*/std::move(module),
      /*arguments=*/argument_pointers,
      /*run_hlo_passes=*/run_hlo_passes,
      /*profile=*/profile);
}

StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
                                     absl::Span<const Literal* const> arguments,
                                     ExecutionProfile* profile) {
  TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
                      TransferLiteralsToDevice(arguments));
  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                      ExecuteWithDeviceBuffers(
                          /*executable=*/executable.get(),
                          /*arguments=*/argument_buffers,
                          /*profile=*/profile));
  return TransferLiteralFromDevice(result);
}

StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
                                     absl::Span<const Literal> arguments,
                                     ExecutionProfile* profile) {
  // Construct a vector of plain pointers for the arguments.
  std::vector<const Literal*> argument_pointers;
  argument_pointers.reserve(arguments.size());
  for (const auto& argument : arguments) {
    argument_pointers.push_back(&argument);
  }
  return Execute(
      /*executable=*/std::move(executable),
      /*arguments=*/argument_pointers,
      /*profile=*/profile);
}

StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
    std::unique_ptr<HloModule> module,
    absl::Span<const ShapedBuffer* const> arguments, bool run_hlo_passes,
    ExecutionProfile* profile) {
  // Get service run options.
  se::Stream stream(backend().default_stream_executor());
  stream.Init();
  ServiceExecutableRunOptions service_run_options =
      GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
                                    nullptr, RunId());
  service_run_options.mutable_run_options()->set_execution_profile(profile);

  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                      CreateExecutable(std::move(module), run_hlo_passes));
  TF_ASSIGN_OR_RETURN(
      ScopedShapedBuffer retval,
      executable->ExecuteOnStreamWrapper(&service_run_options, arguments));
  TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
  return std::move(retval);
}
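
// The *WithDeviceBuffers variants skip the host round trip that Execute()
// performs, so a caller can keep data device-resident, e.g. when feeding one
// execution's output into the next.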

StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
    std::unique_ptr<HloModule> module,
    absl::Span<const ScopedShapedBuffer> arguments, bool run_hlo_passes,
    ExecutionProfile* profile) {
  std::vector<const ShapedBuffer*> argument_pointers;
  argument_pointers.reserve(arguments.size());
  for (const auto& argument : arguments) {
    argument_pointers.push_back(&argument);
  }
  return ExecuteWithDeviceBuffers(
      /*module=*/std::move(module),
      /*arguments=*/argument_pointers,
      /*run_hlo_passes=*/run_hlo_passes,
      /*profile=*/profile);
}

StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
    Executable* executable, absl::Span<const ShapedBuffer* const> arguments,
    ExecutionProfile* profile) {
  // Get service run options.
  se::Stream stream(backend().default_stream_executor());
  stream.Init();
  ServiceExecutableRunOptions service_run_options =
      GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
                                    nullptr, RunId());
  service_run_options.mutable_run_options()->set_execution_profile(profile);

  TF_ASSIGN_OR_RETURN(
      ScopedShapedBuffer retval,
      executable->ExecuteOnStreamWrapper(&service_run_options, arguments));
  TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
  return std::move(retval);
}

StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
    Executable* executable, absl::Span<const ScopedShapedBuffer> arguments,
    ExecutionProfile* profile) {
  std::vector<const ShapedBuffer*> argument_pointers;
  argument_pointers.reserve(arguments.size());
  for (const auto& argument : arguments) {
    argument_pointers.push_back(&argument);
  }
  return ExecuteWithDeviceBuffers(
      /*executable=*/executable,
      /*arguments=*/argument_pointers,
      /*profile=*/profile);
}

StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
    std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options,
    DeviceAssignment* device_assignment) {
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      CreateExecutable(std::move(module), options.run_hlo_passes));
  return ExecuteReplicated(executable.get(), options, device_assignment);
}

StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
    Executable* executable, const ReplicatedExecuteOptions& options,
    DeviceAssignment* device_assignment, ExecutionProfile* profile) {
  std::vector<std::unique_ptr<se::Stream>> streams;
  std::vector<ServiceExecutableRunOptions> service_run_options;

  std::vector<ScopedShapedBuffer> argument_buffers;
  // This reserve() call is necessary for correctness, because
  // argument_buffer_ptrs contains pointers into the elements of
  // argument_buffers.
  argument_buffers.reserve(options.num_replicas * options.arguments.size());

  // Plus one so we can safely get &argument_buffer_ptrs[0] in case there are
  // no arguments.
  std::vector<const ShapedBuffer*> argument_buffer_ptrs(
      options.num_replicas * options.arguments.size() + 1);
  std::vector<absl::Span<const ShapedBuffer* const>> argument_buffer_slices;
  int64 index = 0;
  RunId run_id;
  for (int64 i = 0; i < options.num_replicas; ++i) {
    int64 device = (*device_assignment)(i, 0);
    TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                        backend().stream_executor(device));
    streams.push_back(absl::make_unique<se::Stream>(executor));
    streams.back()->Init();
    service_run_options.emplace_back(GetServiceRunOptionsForDevice(
        device, streams.back().get(), device_assignment, run_id));

    // Copy arguments to device.
    for (const Literal* argument : options.arguments) {
      TF_ASSIGN_OR_RETURN(
          ScopedShapedBuffer argument_buffer,
          backend().transfer_manager()->AllocateScopedShapedBuffer(
              argument->shape(), backend().memory_allocator(), device));
      TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
          streams.back().get(), *argument, argument_buffer));
      argument_buffers.push_back(std::move(argument_buffer));
      argument_buffer_ptrs[index++] = &argument_buffers.back();
    }
    argument_buffer_slices.emplace_back(
        &argument_buffer_ptrs[index - options.arguments.size()],
        options.arguments.size());
  }

  std::unique_ptr<tensorflow::thread::ThreadPool> pool;
  int64 num_threads = (options.infeed != nullptr) ? options.num_replicas : 0;
  if (ShapeUtil::IsInitialized(options.outfeed_shape)) {
    num_threads += options.num_replicas;
  }
  if (num_threads > 0) {
    pool = absl::make_unique<tensorflow::thread::ThreadPool>(
        tensorflow::Env::Default(), "infeed_outfeed",
        /*num_threads=*/num_threads);
  }
  if (options.infeed != nullptr) {
    for (int64 i = 0; i < options.num_replicas; ++i) {
      int64 device = (*device_assignment)(i, 0);
      pool->Schedule([this, device, &options]() {
        se::StreamExecutor* executor =
            backend().stream_executor(device).ValueOrDie();
        VLOG(1) << "Starting infeed on device " << device;
        for (int64 step = 1;
             options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
          TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToInfeed(
              executor, *options.infeed));
          if (step % 100 == 0) {
            VLOG(1) << "Infeed step " << step;
          }
        }
      });
    }
  }
  if (ShapeUtil::IsInitialized(options.outfeed_shape)) {
    for (int64 i = 0; i < options.num_replicas; ++i) {
      int64 device = (*device_assignment)(i, 0);
      pool->Schedule([this, device, &options]() {
        se::StreamExecutor* executor =
            backend().stream_executor(device).ValueOrDie();
        VLOG(1) << "Starting outfeed on device " << device;
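        // NOTE: the loop below reuses infeed_steps as its bound; the options
        // struct provides no separate outfeed step count.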
        for (int64 step = 1;
             options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
          Literal literal;
          TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
              executor, options.outfeed_shape, &literal));
          if (options.outfeed_values != nullptr) {
            options.outfeed_values->push_back(std::move(literal));
          }
          if (step % 100 == 0) {
            VLOG(1) << "Outfeed step " << step;
          }
        }
      });
    }
  }

  LOG(INFO) << "Replicated execution started";
  std::vector<ScopedShapedBuffer> results;
  if (!options.use_threads) {
    TF_ASSIGN_OR_RETURN(results,
                        executable->ExecuteOnStreams(service_run_options,
                                                     argument_buffer_slices));
  } else {
    tensorflow::mutex mutex;
    std::vector<StatusOr<ScopedShapedBuffer>> thread_results(
        options.num_replicas);
    {
      LOG(INFO) << "Creating thread pool for " << options.num_replicas
                << " replicas";
      tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
                                          "replicas", options.num_replicas);
      for (int64 i = 0; i < options.num_replicas; ++i) {
        pool.Schedule([&, i] {
          auto result = executable->ExecuteOnStream(
              &service_run_options[i], argument_buffer_slices[i], nullptr);
          tensorflow::mutex_lock lock(mutex);
          thread_results[i] = std::move(result);
        });
      }

      // Note: the thread pool destructor guarantees it completes all work
      // before we leave this scope.
    }
    for (auto& thread_result : thread_results) {
      if (!thread_result.ok()) {
        return thread_result.status();
      }
      results.push_back(std::move(thread_result).ValueOrDie());
    }
  }
  LOG(INFO) << "Replicated execution terminated";

  std::vector<Literal> exec_results;
  for (int64 i = 0; i < options.num_replicas; ++i) {
    TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone());
    TF_ASSIGN_OR_RETURN(Literal literal,
                        backend().transfer_manager()->TransferLiteralFromDevice(
                            streams[i].get(), results[i]));
    exec_results.push_back(std::move(literal));
  }
  return std::move(exec_results);
}
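
// Hedged sketch of a replicated run (the module and argument literals are
// illustrative, carried over from the earlier examples):
//
//   HloRunner::ReplicatedExecuteOptions options;
//   options.num_replicas = 2;
//   options.arguments = {&x, &y};
//   options.run_hlo_passes = true;
//   options.use_threads = true;
//   TF_ASSIGN_OR_RETURN(std::vector<Literal> results,
//                       runner.ExecuteReplicated(std::move(module), options));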

StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
    std::unique_ptr<HloModule> module,
    const ReplicatedExecuteOptions& options) {
  TF_ASSIGN_OR_RETURN(
      DeviceAssignment device_assignment,
      backend().computation_placer()->AssignDevices(options.num_replicas, 1));
  return ExecuteReplicated(std::move(module), options, &device_assignment);
}

StatusOr<std::unique_ptr<Executable>> HloRunner::CreateExecutable(
    std::unique_ptr<HloModule> module, bool run_hlo_passes) {
  if (run_hlo_passes) {
    auto module_group = absl::make_unique<HloModuleGroup>(std::move(module));
    TF_ASSIGN_OR_RETURN(
        auto executables,
        backend().compiler()->Compile(std::move(module_group),
                                      {{backend().default_stream_executor()}},
                                      backend().memory_allocator()));
    return std::move(executables[0]);
  }
  return backend().compiler()->RunBackend(std::move(module),
                                          backend().default_stream_executor(),
                                          backend().memory_allocator());
}

ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice(
    int64 device, se::Stream* stream, DeviceAssignment* device_assignment,
    RunId run_id) {
  ExecutableRunOptions run_options;
  run_options.set_device_ordinal(device);
  run_options.set_stream(stream);
  run_options.set_allocator(backend().memory_allocator());
  run_options.set_intra_op_thread_pool(
      backend().eigen_intra_op_thread_pool_device());
  if (device_assignment != nullptr) {
    run_options.set_device_assignment(device_assignment);
  }
  run_options.set_run_id(run_id);
  return ServiceExecutableRunOptions(run_options, backend().StreamBorrower());
}

Backend& HloRunner::backend() {
  if (!backend_) {
    backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
    VLOG(1) << "Executing on platform " << backend().platform()->Name();
  }
  return *backend_;
}

const Backend& HloRunner::backend() const {
  return const_cast<HloRunner*>(this)->backend();
}

}  // namespace xla