1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/executable.h"
17
18 #include "absl/memory/memory.h"
19 #include "absl/strings/str_format.h"
20 #include "tensorflow/compiler/xla/debug_options_flags.h"
21 #include "tensorflow/compiler/xla/service/dump.h"
22 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
23 #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
24 #include "tensorflow/compiler/xla/status.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow/core/lib/hash/hash.h"
28 #include "tensorflow/core/lib/io/path.h"
29 #include "tensorflow/core/lib/strings/proto_serialization.h"
30 #include "tensorflow/core/platform/env.h"
31 #include "tensorflow/core/platform/errors.h"
32 #include "tensorflow/stream_executor/device_description.h"
33
34 namespace xla {
35
~ExecutionInput()36 ExecutionInput::~ExecutionInput() {
37 for (auto& index : unowned_indices_) {
38 auto buffer = buffers_.mutable_element(index)->Release();
39 if (buffer) {
40 buffer->Release();
41 }
42 }
43 }
44
SetDynamicShape(Shape dynamic_shape)45 Status ExecutionInput::SetDynamicShape(Shape dynamic_shape) {
46 const Shape& input_shape = shape();
47 if (!ShapeUtil::DynamicShapeIsCompatible(input_shape, dynamic_shape)) {
48 return tensorflow::errors::InvalidArgument(
49 "Cannot set dynamic shape: ", input_shape.DebugString(), " vs. ",
50 dynamic_shape.DebugString());
51 }
52 dynamic_shape_ = absl::make_unique<Shape>(std::move(dynamic_shape));
53 return Status::OK();
54 }
55
SetUnownedBuffer(const ShapeIndex & index,MaybeOwningDeviceMemory buffer)56 void ExecutionInput::SetUnownedBuffer(const ShapeIndex& index,
57 MaybeOwningDeviceMemory buffer) {
58 *buffers_.mutable_element(index) = std::move(buffer);
59 unowned_indices_.insert(index);
60 }
61
ToShapedBuffer(se::DeviceMemoryAllocator * allocator,int device_ordinal) const62 StatusOr<ShapedBuffer> ExecutionInput::ToShapedBuffer(
63 se::DeviceMemoryAllocator* allocator, int device_ordinal) const {
64 const Shape& input_shape = shape();
65 ShapedBuffer shaped_buffer(input_shape, device_ordinal);
66 for (const auto& index_buffer : Buffers()) {
67 const tensorflow::se::OwningDeviceMemory* mem =
68 index_buffer.second.AsOwningDeviceMemory();
69 if (mem != nullptr && (mem->allocator() != allocator ||
70 mem->device_ordinal() != device_ordinal)) {
71 return tensorflow::errors::InvalidArgument(
72 "Device buffer at index ", index_buffer.first.ToString(),
73 " has mismatching allocator/device");
74 }
75 shaped_buffer.set_buffer(index_buffer.second.AsDeviceMemoryBase(),
76 index_buffer.first);
77 }
78 return std::move(shaped_buffer);
79 }
80
ExecuteOnStream(const ServiceExecutableRunOptions * run_options,absl::Span<const ShapedBuffer * const> arguments,HloExecutionProfile * hlo_execution_profile)81 StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(
82 const ServiceExecutableRunOptions* run_options,
83 absl::Span<const ShapedBuffer* const> arguments,
84 HloExecutionProfile* hlo_execution_profile) {
85 StatusOr<ScopedShapedBuffer> result =
86 ExecuteAsyncOnStream(run_options, arguments, hlo_execution_profile);
87 Status blocking_status = run_options->stream()->BlockHostUntilDone();
88 TF_RETURN_IF_ERROR(result.status());
89 TF_RETURN_IF_ERROR(blocking_status);
90 return result;
91 }
92
MakeMaybeOwningDeviceMemoryTree(const ShapedBuffer & shaped_buffer)93 static ExecutionInput MakeMaybeOwningDeviceMemoryTree(
94 const ShapedBuffer& shaped_buffer) {
95 ExecutionInput result(shaped_buffer.on_device_shape());
96 shaped_buffer.buffers().ForEachElement(
97 [&](const ShapeIndex& index, const se::DeviceMemoryBase& mem) {
98 result.SetBuffer(index, MaybeOwningDeviceMemory(mem));
99 });
100 return result;
101 }
102
ExecuteAsyncOnStream(const ServiceExecutableRunOptions * run_options,absl::Span<const ShapedBuffer * const> arguments,HloExecutionProfile * hlo_execution_profile)103 StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStream(
104 const ServiceExecutableRunOptions* run_options,
105 absl::Span<const ShapedBuffer* const> arguments,
106 HloExecutionProfile* hlo_execution_profile) {
107 std::vector<ExecutionInput> args;
108 args.reserve(arguments.size());
109 for (const ShapedBuffer* arg : arguments) {
110 args.emplace_back(MakeMaybeOwningDeviceMemoryTree(*arg));
111 }
112 TF_ASSIGN_OR_RETURN(ExecutionOutput out,
113 ExecuteAsyncOnStream(run_options, std::move(args),
114 hlo_execution_profile));
115 return out.ConsumeResult();
116 }
117
ExecuteOnStream(const ServiceExecutableRunOptions * run_options,std::vector<ExecutionInput> arguments,HloExecutionProfile * hlo_execution_profile)118 StatusOr<ExecutionOutput> Executable::ExecuteOnStream(
119 const ServiceExecutableRunOptions* run_options,
120 std::vector<ExecutionInput> arguments,
121 HloExecutionProfile* hlo_execution_profile) {
122 StatusOr<ExecutionOutput> result = ExecuteAsyncOnStream(
123 run_options, std::move(arguments), hlo_execution_profile);
124 Status blocking_status = run_options->stream()->BlockHostUntilDone();
125 TF_RETURN_IF_ERROR(result.status());
126 TF_RETURN_IF_ERROR(blocking_status);
127 return result;
128 }
129
ExecuteOnStreams(absl::Span<const ServiceExecutableRunOptions> run_options,absl::Span<const absl::Span<const ShapedBuffer * const>> arguments)130 StatusOr<std::vector<ScopedShapedBuffer>> Executable::ExecuteOnStreams(
131 absl::Span<const ServiceExecutableRunOptions> run_options,
132 absl::Span<const absl::Span<const ShapedBuffer* const>> arguments) {
133 TF_RET_CHECK(run_options.size() == arguments.size());
134
135 std::vector<ScopedShapedBuffer> return_values;
136 return_values.reserve(run_options.size());
137
138 if (run_options.size() == 1) {
139 TF_ASSIGN_OR_RETURN(auto rv,
140 ExecuteOnStream(&run_options[0], arguments[0],
141 /*hlo_execution_profile=*/nullptr));
142 return_values.push_back(std::move(rv));
143 return std::move(return_values);
144 }
145
146 for (size_t i = 0; i < run_options.size(); ++i) {
147 // We cannot BlockHostUntilDone() on the already-launched executions in case
148 // of error, since if the executions communicate, the initially launched
149 // executions may never complete if not all executions are running.
150 TF_ASSIGN_OR_RETURN(
151 auto rv, ExecuteAsyncOnStream(&run_options[i], arguments[i],
152 /*hlo_execution_profile=*/nullptr));
153 return_values.push_back(std::move(rv));
154 }
155 for (const auto& options : run_options) {
156 TF_RET_CHECK(options.stream() != nullptr);
157 TF_RETURN_IF_ERROR(options.stream()->BlockHostUntilDone());
158 }
159 return std::move(return_values);
160 }
161
ExecuteOnStreamWrapper(const ServiceExecutableRunOptions * run_options,absl::Span<const ShapedBuffer * const> arguments)162 StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
163 const ServiceExecutableRunOptions* run_options,
164 absl::Span<const ShapedBuffer* const> arguments) {
165 StatusOr<ScopedShapedBuffer> result =
166 ExecuteAsyncOnStreamWrapper(run_options, arguments);
167 Status block_status = run_options->stream()->BlockHostUntilDone();
168 TF_RETURN_IF_ERROR(result.status());
169 TF_RETURN_IF_ERROR(block_status);
170 return result;
171 }
172
ExecuteOnStreamWrapper(const ServiceExecutableRunOptions * run_options,std::vector<ExecutionInput> arguments)173 StatusOr<ExecutionOutput> Executable::ExecuteOnStreamWrapper(
174 const ServiceExecutableRunOptions* run_options,
175 std::vector<ExecutionInput> arguments) {
176 StatusOr<ExecutionOutput> result =
177 ExecuteAsyncOnStreamWrapper(run_options, std::move(arguments));
178 Status block_status = run_options->stream()->BlockHostUntilDone();
179 TF_RETURN_IF_ERROR(result.status());
180 TF_RETURN_IF_ERROR(block_status);
181 return result;
182 }
183
184 struct ExecuteAsyncOnStreamWrapperState {
185 ExecutionProfile* profile;
186 std::shared_ptr<se::Timer> timer;
187 std::shared_ptr<HloExecutionProfile> profile_ptr;
188 };
189
ExecuteWrapperBeforeExecution(const Executable & executable,const ServiceExecutableRunOptions * run_options)190 static ExecuteAsyncOnStreamWrapperState ExecuteWrapperBeforeExecution(
191 const Executable& executable,
192 const ServiceExecutableRunOptions* run_options) {
193 ExecuteAsyncOnStreamWrapperState state;
194 se::Stream* stream = run_options->stream();
195 state.profile = run_options->run_options().execution_profile();
196 if (state.profile != nullptr) {
197 state.timer = std::make_shared<se::Timer>(stream->parent());
198 stream->InitTimer(state.timer.get()).ThenStartTimer(state.timer.get());
199 }
200
201 VLOG(1) << "enqueueing executable on stream...";
202 // If the profiling flag isn't enabled, we pass nullptr as the profile to
203 // indicate profiling is not requested.
204 state.profile_ptr =
205 executable.module_config().debug_options().xla_hlo_profile() &&
206 executable.hlo_profiling_enabled()
207 ? std::make_shared<HloExecutionProfile>(
208 &executable.hlo_profile_printer_data(),
209 &executable.hlo_profile_index_map())
210 : nullptr;
211 return state;
212 }
213
ExecuteWrapperAfterExecution(Executable * executable,const ExecuteAsyncOnStreamWrapperState & state,Status return_status,se::Stream * stream)214 Status ExecuteWrapperAfterExecution(
215 Executable* executable, const ExecuteAsyncOnStreamWrapperState& state,
216 Status return_status, se::Stream* stream) {
217 if (!return_status.ok()) {
218 if (state.profile != nullptr) {
219 // Ensure the ThenStartTimer call has completed before we destroy timer.
220 // We already have a failure status to return, so just log this if it
221 // fails.
222 Status status = stream->BlockHostUntilDone();
223 if (!status.ok()) {
224 LOG(ERROR) << "Failed to BlockHostUntilDone: " << status;
225 }
226 }
227 return return_status;
228 }
229
230 if (state.profile != nullptr) {
231 VLOG(1) << "enqueueing 'stop timer' and profiling callback...";
232 stream->ThenStopTimer(state.timer.get());
233
234 // We block instead of using an async callback because reading the timer
235 // value may call back into the driver on GPU, which is not allowed.
236 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
237
238 const int64 executable_size_in_bytes =
239 executable->SizeOfGeneratedCodeInBytes();
240 // Merge in run-time profile information from execution_profile.
241
242 // Overall execution time (in nanoseconds) from the executor timer.
243 state.profile->set_compute_and_transfer_time_ns(state.timer->Nanoseconds());
244
245 // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually
246 // the compute time without the transfer time, so this way we get the
247 // correct compute time. We should instead have the correct value for
248 // compute_and_transfer_time and set compute_time to the compute time.
249 if (state.profile->compute_time_ns() == 0) {
250 state.profile->set_compute_time_ns(
251 state.profile->compute_and_transfer_time_ns());
252 }
253
254 if (executable_size_in_bytes != 0) {
255 state.profile->set_executable_size_in_bytes(executable_size_in_bytes);
256 }
257 }
258
259 if (executable->module_config().debug_options().xla_hlo_profile() &&
260 state.profile_ptr != nullptr) {
261 DumpToFileInDir(executable->module(), /*file_prefix=*/"",
262 /*file_suffix=*/"hlo_execution_profile_data",
263 state.profile_ptr->ToProto().SerializeAsString());
264 }
265
266 if (state.profile_ptr != nullptr) {
267 const se::DeviceDescription* device_description =
268 &stream->parent()->GetDeviceDescription();
269 std::shared_ptr<HloExecutionProfile> profile = state.profile_ptr;
270 stream->ThenDoHostCallback([profile, device_description]() {
271 XLA_LOG_LINES(tensorflow::INFO,
272 profile->ToString(device_description->clock_rate_ghz()));
273 });
274 }
275
276 return return_status;
277 }
278
ExecuteAsyncOnStreamWrapper(const ServiceExecutableRunOptions * run_options,absl::Span<const ShapedBuffer * const> arguments)279 StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStreamWrapper(
280 const ServiceExecutableRunOptions* run_options,
281 absl::Span<const ShapedBuffer* const> arguments) {
282 auto state = ExecuteWrapperBeforeExecution(*this, run_options);
283 StatusOr<ScopedShapedBuffer> return_value =
284 ExecuteAsyncOnStream(run_options, arguments, state.profile_ptr.get());
285 TF_RETURN_IF_ERROR(ExecuteWrapperAfterExecution(
286 this, state, return_value.status(), run_options->stream()));
287 return return_value;
288 }
289
ExecuteAsyncOnStreamWrapper(const ServiceExecutableRunOptions * run_options,std::vector<ExecutionInput> arguments)290 StatusOr<ExecutionOutput> Executable::ExecuteAsyncOnStreamWrapper(
291 const ServiceExecutableRunOptions* run_options,
292 std::vector<ExecutionInput> arguments) {
293 auto state = ExecuteWrapperBeforeExecution(*this, run_options);
294 StatusOr<ExecutionOutput> return_value = ExecuteAsyncOnStream(
295 run_options, std::move(arguments), state.profile_ptr.get());
296 TF_RETURN_IF_ERROR(ExecuteWrapperAfterExecution(
297 this, state, return_value.status(), run_options->stream()));
298 return return_value;
299 }
300
SizeOfGeneratedCodeInBytes() const301 int64 Executable::SizeOfGeneratedCodeInBytes() const { return -1; }
302
MarkToBeReleasedArguments(absl::Span<ExecutionInput> arguments,ExecutionOutput & result)303 void Executable::MarkToBeReleasedArguments(absl::Span<ExecutionInput> arguments,
304 ExecutionOutput& result) {
305 for (ExecutionInput& argument : arguments) {
306 for (auto& index_buffer : *argument.MutableBuffers()) {
307 if (absl::optional<se::OwningDeviceMemory> maybe_owning_buffer =
308 index_buffer.second.Release()) {
309 result.AddToBeReleased(std::move(*maybe_owning_buffer));
310 }
311 }
312 }
313 }
314
315 } // namespace xla
316