/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executable.h"

#include "absl/cleanup/cleanup.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/status_helper.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_stream.h"
#include "tensorflow/core/tpu/tpu_executor_api.h"

namespace ApiConverter {
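// Converts a ServiceExecutableRunOptions into the C-API SE_ExecutableRunOptions
// used by the TPU executor API: the allocator and streams are unwrapped to
// their TPU StreamExecutor representations, and the device assignment (if any)
// is serialized as a proto.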
static SE_ExecutableRunOptions ToC(
    const xla::ServiceExecutableRunOptions& options) {
  SE_ExecutableRunOptions se_options;
  se_options.allocator = ApiConverter::ToC(options.run_options().allocator());
  se_options.device_ordinal = options.run_options().device_ordinal();
  if (options.run_options().host_to_device_stream() != nullptr) {
    se_options.host_to_device_stream =
        static_cast<tensorflow::tpu::TpuStream*>(
            options.run_options().host_to_device_stream()->implementation())
            ->se_stream();
  } else {
    se_options.host_to_device_stream = nullptr;
  }

  if (options.run_options().device_assignment() != nullptr) {
    xla::DeviceAssignmentProto dev_assign_proto;
    options.run_options()
        .device_assignment()
        ->Serialize(&dev_assign_proto)
        .IgnoreError();
    se_options.device_assignment =
        stream_executor::tpu::SerializeProto(dev_assign_proto);
  } else {
    se_options.device_assignment.bytes = nullptr;
    se_options.device_assignment.size = 0;
  }

  se_options.rng_seed = options.run_options().rng_seed();
  se_options.run_id = options.run_options().run_id().ToInt();
  se_options.launch_id = options.run_options().launch_id();

  CHECK_EQ(options.run_options().then_execute_function(), nullptr)
      << "ThenExecuteFunction not supported by this platform.";

  auto impl =
      const_cast<stream_executor::Stream*>(options.stream())->implementation();
  se_options.stream =
      static_cast<tensorflow::tpu::TpuStream*>(impl)->se_stream();
  return se_options;
}
}  // namespace ApiConverter

namespace xla {

using ::tensorflow::tpu::ExecutorApiFn;

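// Releases the executable owned by the C API when this wrapper is destroyed.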
TpuExecutable::~TpuExecutable() {
  ExecutorApiFn()->TpuExecutable_FreeFn(se_executable_);
}

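// Executes the program through the TPU C API. The run options and each
// ExecutionInput are first translated into their C representations, the C-side
// ExecuteAsyncOnStream entry point is invoked, and its outputs are converted
// back into an xla::ExecutionOutput. All temporary C structures are freed
// before returning, whether or not execution succeeded.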
StatusOr<ExecutionOutput> TpuExecutable::ExecuteAsyncOnStream(
    const ServiceExecutableRunOptions* run_options,
    std::vector<ExecutionInput> arguments,
    HloExecutionProfile* hlo_execution_profile) {
  SE_ExecutableRunOptions se_run_options = ApiConverter::ToC(*run_options);
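  // Convert each ExecutionInput into an SE_ExecutionInput: the static shape,
  // the buffer tree, the dynamic shape, and the set of unowned (aliased)
  // indices.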
  SE_ExecutionInput** se_args = new SE_ExecutionInput*[arguments.size()];
  for (int i = 0; i < arguments.size(); ++i) {
    auto& arg = arguments[i];
    se_args[i] = new SE_ExecutionInput;

    ApiConverter::ToC(arg.shape(), &se_args[i]->shape_tree.shape);
    auto* arg_buffers = arg.MutableBuffers();
    absl::InlinedVector<SE_MaybeOwningDeviceMemory, 2> se_buffers;
    for (auto& pair : *arg_buffers) {
      bool aliased = arg.unowned_indices().count(pair.first) > 0;
      se_buffers.push_back(ApiConverter::ToC(pair.second, aliased));
    }
    se_args[i]->shape_tree.buffers =
        new SE_MaybeOwningDeviceMemory[se_buffers.size()];
    for (int j = 0; j < se_buffers.size(); ++j) {
      se_args[i]->shape_tree.buffers[j] = se_buffers[j];
    }

    ApiConverter::ToC(arg.shape(), &se_args[i]->dynamic_shape);
    const auto& unowned_indices = arg.unowned_indices();
    se_args[i]->unowned_indices_size = unowned_indices.size();
    se_args[i]->unowned_indices = new XLA_ShapeIndex[unowned_indices.size()];
    int j = 0;
    for (auto& idx : unowned_indices) {
      se_args[i]->unowned_indices[j] = ApiConverter::ToC(idx);
      ++j;
    }
  }
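  // Run the executable through the C API; any error is reported via
  // status.c_status.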
  SE_ExecutionOutput se_execution_output;
  StatusHelper status;
  ExecutorApiFn()->TpuExecutable_ExecuteAsyncOnStreamFn(
      se_executable_, &se_run_options, se_args, arguments.size(), nullptr,
      &se_execution_output, status.c_status);

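  // Free the serialized device assignment and the temporary argument
  // structures regardless of whether execution succeeded.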
  if (se_run_options.device_assignment.bytes != nullptr) {
    stream_executor::tpu::SerializedProto_Free(
        se_run_options.device_assignment);
  }
  for (int i = 0; i < arguments.size(); ++i) {
    ApiConverter::Destroy(&se_args[i]->shape_tree.shape);
    ApiConverter::Destroy(&se_args[i]->dynamic_shape);
    delete[] se_args[i]->unowned_indices;
    delete[] se_args[i]->shape_tree.buffers;
    delete se_args[i];
  }
  delete[] se_args;

  if (!status.ok()) {
    return status.status();
  }

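  // Take ownership of the result buffers, tying their lifetime to the
  // stream's allocator, then release the C-side copy of the result.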
  xla::ScopedShapedBuffer result(
      ApiConverter::FromC(&se_execution_output.result),
      run_options->stream()->parent()->GetAllocator());
  ApiConverter::Destroy(&se_execution_output.result);

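  // Propagate the aliased indices and to-be-released buffers reported by the
  // C API into the ExecutionOutput, then free the C-side arrays.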
  ExecutionOutput output(std::move(result));
  for (int i = 0; i < se_execution_output.aliased_indices_size; ++i) {
    output.AddAliasedIndex(
        ApiConverter::FromC(&se_execution_output.aliased_indices[i]));
  }
  ExecutorApiFn()->TpuExecutable_FreeXlaShapeIndexArrayFn(
      se_execution_output.aliased_indices);

  for (int i = 0; i < se_execution_output.to_be_released_size; ++i) {
    output.AddToBeReleased(
        ApiConverter::FromC(&se_execution_output.to_be_released[i],
                            run_options->stream()->parent()->GetAllocator())
            .Release()
            .value());
  }
  ExecutorApiFn()->TpuExecutable_FreeMaybeOwningDeviceMemoryArrayFn(
      se_execution_output.to_be_released);

  return output;
}

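// Returns the executable fingerprint reported by the C API. The bytes backing
// the returned view are not copied or freed here.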
absl::string_view TpuExecutable::fingerprint() const {
  const char* data;
  size_t size;
  ExecutorApiFn()->TpuExecutable_FingerprintFn(se_executable_, &data, &size);
  return absl::string_view(data, size);
}

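// Serializes the executable through the C API and returns the bytes as a
// std::string.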
StatusOr<std::string> TpuExecutable::Serialize() const {
  SE_ExecutableSerializationHandle* handle = nullptr;
  absl::Cleanup cleanup = [&handle]() {
    ExecutorApiFn()->TpuExecutableSerialize_FreeHandleFn(handle);
  };
  StatusHelper status;
  ExecutorApiFn()->TpuExecutable_SerializeFn(se_executable_, &handle,
                                             status.c_status);
  if (!status.ok()) {
    return status.status();
  }
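  // The C API exposes the serialized bytes in two steps: query the size, then
  // write into a caller-provided buffer.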
  size_t size = ExecutorApiFn()->TpuExecutableSerialize_GetByteSizeFn(handle);
  CHECK_GT(size, 0);
  std::string serialized;
  // NOTE(skyewm): this initializes serialized. If this ever becomes a
  // bottleneck, we could change the return type to std::vector<uint8_t> or
  // similar.
  serialized.resize(size);
  ExecutorApiFn()->TpuExecutableSerialize_WriteToArrayFn(
      handle, size,
      const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(serialized.data())),
      status.c_status);
  if (!status.ok()) {
    return status.status();
  }
  return serialized;
}

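// Reconstructs a TpuExecutable from bytes previously produced by Serialize().
// The HLO module is recovered from the deserialized C-API executable.
//
// Example (illustrative sketch; `executable` is assumed to be an existing
// TpuExecutable):
//
//   TF_ASSIGN_OR_RETURN(std::string blob, executable->Serialize());
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<TpuExecutable> restored,
//                       TpuExecutable::Deserialize(blob));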
StatusOr<std::unique_ptr<TpuExecutable>> TpuExecutable::Deserialize(
    absl::string_view serialized) {
  SE_Executable* se_executable;
  StatusHelper status;
  ExecutorApiFn()->TpuExecutable_DeserializeFn(
      serialized.size(), reinterpret_cast<const uint8_t*>(serialized.data()),
      &se_executable, status.c_status);
  if (!status.ok()) {
    return status.status();
  }
  XLA_HloModule c_module =
      ExecutorApiFn()->TpuExecutable_HloModuleFn(se_executable);
  absl::Cleanup cleanup_c_module = [&c_module]() {
    ApiConverter::Destroy(&c_module);
  };
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
                      ApiConverter::FromC(c_module));
  return std::make_unique<TpuExecutable>(se_executable, std::move(hlo_module));
}

}  // namespace xla