• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executable.h"
17 
18 #include "absl/cleanup/cleanup.h"
19 #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h"
20 #include "tensorflow/compiler/xla/stream_executor/tpu/status_helper.h"
21 #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_stream.h"
22 #include "tensorflow/core/tpu/tpu_executor_api.h"
23 
24 namespace ApiConverter {
ToC(const xla::ServiceExecutableRunOptions & options)25 static SE_ExecutableRunOptions ToC(
26     const xla::ServiceExecutableRunOptions& options) {
27   SE_ExecutableRunOptions se_options;
28   se_options.allocator = ApiConverter::ToC(options.run_options().allocator());
29   se_options.device_ordinal = options.run_options().device_ordinal();
30   if (options.run_options().host_to_device_stream() != nullptr) {
31     se_options.host_to_device_stream =
32         static_cast<tensorflow::tpu::TpuStream*>(
33             options.run_options().host_to_device_stream()->implementation())
34             ->se_stream();
35   } else {
36     se_options.host_to_device_stream = nullptr;
37   }
38 
39   if (options.run_options().device_assignment() != nullptr) {
40     xla::DeviceAssignmentProto dev_assign_proto;
41     options.run_options()
42         .device_assignment()
43         ->Serialize(&dev_assign_proto)
44         .IgnoreError();
45     se_options.device_assignment =
46         stream_executor::tpu::SerializeProto(dev_assign_proto);
47   } else {
48     se_options.device_assignment.bytes = nullptr;
49     se_options.device_assignment.size = 0;
50   }
51 
52   se_options.rng_seed = options.run_options().rng_seed();
53   se_options.run_id = options.run_options().run_id().ToInt();
54   se_options.launch_id = options.run_options().launch_id();
55 
56   CHECK_EQ(options.run_options().then_execute_function(), nullptr)
57       << "ThenExecuteFunction not supported by this platform.";
58 
59   auto impl =
60       const_cast<stream_executor::Stream*>(options.stream())->implementation();
61   se_options.stream =
62       static_cast<tensorflow::tpu::TpuStream*>(impl)->se_stream();
63   return se_options;
64 }
65 }  // namespace ApiConverter
66 
67 namespace xla {
68 
69 using ::tensorflow::tpu::ExecutorApiFn;
70 
~TpuExecutable()71 TpuExecutable::~TpuExecutable() {
72   ExecutorApiFn()->TpuExecutable_FreeFn(se_executable_);
73 }
74 
ExecuteAsyncOnStream(const ServiceExecutableRunOptions * run_options,std::vector<ExecutionInput> arguments,HloExecutionProfile * hlo_execution_profile)75 StatusOr<ExecutionOutput> TpuExecutable::ExecuteAsyncOnStream(
76     const ServiceExecutableRunOptions* run_options,
77     std::vector<ExecutionInput> arguments,
78     HloExecutionProfile* hlo_execution_profile) {
79   SE_ExecutableRunOptions se_run_options = ApiConverter::ToC(*run_options);
80   SE_ExecutionInput** se_args = new SE_ExecutionInput*[arguments.size()];
81   for (int i = 0; i < arguments.size(); ++i) {
82     auto& arg = arguments[i];
83     se_args[i] = new SE_ExecutionInput;
84 
85     ApiConverter::ToC(arg.shape(), &se_args[i]->shape_tree.shape);
86     auto* arg_buffers = arg.MutableBuffers();
87     absl::InlinedVector<SE_MaybeOwningDeviceMemory, 2> se_buffers;
88     for (auto& pair : *arg_buffers) {
89       bool aliased = arg.unowned_indices().count(pair.first) > 0;
90       se_buffers.push_back(ApiConverter::ToC(pair.second, aliased));
91     }
92     se_args[i]->shape_tree.buffers =
93         new SE_MaybeOwningDeviceMemory[se_buffers.size()];
94     for (int j = 0; j < se_buffers.size(); ++j) {
95       se_args[i]->shape_tree.buffers[j] = se_buffers[j];
96     }
97 
98     ApiConverter::ToC(arg.shape(), &se_args[i]->dynamic_shape);
99     const auto& unowned_indices = arg.unowned_indices();
100     se_args[i]->unowned_indices_size = unowned_indices.size();
101     se_args[i]->unowned_indices = new XLA_ShapeIndex[unowned_indices.size()];
102     int j = 0;
103     for (auto& idx : unowned_indices) {
104       se_args[i]->unowned_indices[j] = ApiConverter::ToC(idx);
105       ++j;
106     }
107   }
108   SE_ExecutionOutput se_execution_output;
109   StatusHelper status;
110   ExecutorApiFn()->TpuExecutable_ExecuteAsyncOnStreamFn(
111       se_executable_, &se_run_options, se_args, arguments.size(), nullptr,
112       &se_execution_output, status.c_status);
113 
114   if (se_run_options.device_assignment.bytes != nullptr) {
115     stream_executor::tpu::SerializedProto_Free(
116         se_run_options.device_assignment);
117   }
118   for (int i = 0; i < arguments.size(); ++i) {
119     ApiConverter::Destroy(&se_args[i]->shape_tree.shape);
120     ApiConverter::Destroy(&se_args[i]->dynamic_shape);
121     delete[] se_args[i]->unowned_indices;
122     delete[] se_args[i]->shape_tree.buffers;
123     delete se_args[i];
124   }
125   delete[] se_args;
126 
127   if (!status.ok()) {
128     return status.status();
129   }
130 
131   xla::ScopedShapedBuffer result(
132       ApiConverter::FromC(&se_execution_output.result),
133       run_options->stream()->parent()->GetAllocator());
134   ApiConverter::Destroy(&se_execution_output.result);
135 
136   ExecutionOutput output(std::move(result));
137   for (int i = 0; i < se_execution_output.aliased_indices_size; ++i) {
138     output.AddAliasedIndex(
139         ApiConverter::FromC(&se_execution_output.aliased_indices[i]));
140   }
141   ExecutorApiFn()->TpuExecutable_FreeXlaShapeIndexArrayFn(
142       se_execution_output.aliased_indices);
143 
144   for (int i = 0; i < se_execution_output.to_be_released_size; ++i) {
145     output.AddToBeReleased(
146         ApiConverter::FromC(&se_execution_output.to_be_released[i],
147                             run_options->stream()->parent()->GetAllocator())
148             .Release()
149             .value());
150   }
151   ExecutorApiFn()->TpuExecutable_FreeMaybeOwningDeviceMemoryArrayFn(
152       se_execution_output.to_be_released);
153 
154   return output;
155 }
156 
fingerprint() const157 absl::string_view TpuExecutable::fingerprint() const {
158   const char* data;
159   size_t size;
160   ExecutorApiFn()->TpuExecutable_FingerprintFn(se_executable_, &data, &size);
161   return absl::string_view(data, size);
162 }
163 
Serialize() const164 StatusOr<std::string> TpuExecutable::Serialize() const {
165   SE_ExecutableSerializationHandle* handle = nullptr;
166   absl::Cleanup cleanup = [&handle]() {
167     ExecutorApiFn()->TpuExecutableSerialize_FreeHandleFn(handle);
168   };
169   StatusHelper status;
170   ExecutorApiFn()->TpuExecutable_SerializeFn(se_executable_, &handle,
171                                              status.c_status);
172   if (!status.ok()) {
173     return status.status();
174   }
175   size_t size = ExecutorApiFn()->TpuExecutableSerialize_GetByteSizeFn(handle);
176   CHECK_GT(size, 0);
177   std::string serialized;
178   // NOTE(skyewm): this initializes serialized. If this ever becomes a
179   // bottleneck, we could change the return type to std::vector<uint8_t> or
180   // similar.
181   serialized.resize(size);
182   ExecutorApiFn()->TpuExecutableSerialize_WriteToArrayFn(
183       handle, size,
184       const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(serialized.data())),
185       status.c_status);
186   if (!status.ok()) {
187     return status.status();
188   }
189   return serialized;
190 }
191 
Deserialize(absl::string_view serialized)192 StatusOr<std::unique_ptr<TpuExecutable>> TpuExecutable::Deserialize(
193     absl::string_view serialized) {
194   SE_Executable* se_executable;
195   StatusHelper status;
196   ExecutorApiFn()->TpuExecutable_DeserializeFn(
197       serialized.size(), reinterpret_cast<const uint8_t*>(serialized.data()),
198       &se_executable, status.c_status);
199   if (!status.ok()) {
200     return status.status();
201   }
202   XLA_HloModule c_module =
203       ExecutorApiFn()->TpuExecutable_HloModuleFn(se_executable);
204   absl::Cleanup cleanup_c_module = [&c_module]() {
205     ApiConverter::Destroy(&c_module);
206   };
207   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
208                       ApiConverter::FromC(c_module));
209   return std::make_unique<TpuExecutable>(se_executable, std::move(hlo_module));
210 }
211 
212 }  // namespace xla
213