1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/stream_executor/tpu/tpu_executor.h"

#include <mutex>
#include <type_traits>
#include <utility>

#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_event.h"
#include "tensorflow/stream_executor/tpu/tpu_stream.h"
#include "tensorflow/stream_executor/tpu/tpu_timer.h"
25
26 using stream_executor::DeviceMemoryBase;
27
28 namespace tensorflow {
29 namespace tpu {
30
31 namespace {
32 using ::stream_executor::port::Status;
33 } // namespace
34
// Releases the C-API executor handle owned by this wrapper.
TpuExecutor::~TpuExecutor() {
  tpu::ExecutorApiFn()->TpuExecutor_FreeFn(executor_);
}
38
Init(int device_ordinal,::stream_executor::DeviceOptions device_options)39 Status TpuExecutor::Init(int device_ordinal,
40 ::stream_executor::DeviceOptions device_options) {
41 StatusHelper status;
42 SE_DeviceOptions* options =
43 tpu::ExecutorApiFn()->TpuExecutor_NewDeviceOptionsFn(
44 device_options.flags());
45 tpu::ExecutorApiFn()->TpuExecutor_InitFn(executor_, device_ordinal, options,
46 status.c_status);
47 tpu::ExecutorApiFn()->TpuExecutor_FreeDeviceOptionsFn(options);
48 return status.status();
49 }
50
// Forwards the platform device count query to the C API.
int TpuExecutor::PlatformDeviceCount() {
  return tpu::ExecutorApiFn()->TpuExecutor_PlatformDeviceCountFn(executor_);
}
54
// Asks the runtime to synchronize and drop any streams that have failed.
void TpuExecutor::SyncAndForgetFailedStreams() {
  tpu::ExecutorApiFn()->TpuExecutor_SyncAndForgetFailedStreamsFn(executor_);
}
58
// Blocks until all device activity is quiesced; returns the C API's result.
bool TpuExecutor::SynchronizeAllActivity() {
  return tpu::ExecutorApiFn()->TpuExecutor_SynchronizeAllActivityFn(executor_);
}
62
BlockHostUntilDone(Stream * stream)63 Status TpuExecutor::BlockHostUntilDone(Stream* stream) {
64 StatusHelper status;
65 tpu::ExecutorApiFn()->TpuExecutor_BlockHostUntilDoneFn(
66 executor_, get_stream(stream->implementation()), status.c_status);
67 return status.status();
68 }
69
// Blocks until outstanding work on the executor completes or fails.
Status TpuExecutor::BlockUntilDoneOrFailed() {
  StatusHelper status;
  tpu::ExecutorApiFn()->TpuExecutor_BlockUntilDoneOrFailedFn(executor_,
                                                             status.c_status);
  return status.status();
}
76
GetStatus(Stream * stream)77 Status TpuExecutor::GetStatus(Stream* stream) {
78 StatusHelper status;
79 tpu::ExecutorApiFn()->TpuExecutor_GetStatusFn(
80 executor_, get_stream(stream->implementation()), status.c_status);
81 return status.status();
82 }
83
GetCoreLocationExternal() const84 tpu::TpuCoreLocationExternal TpuExecutor::GetCoreLocationExternal() const {
85 return tpu::TpuCoreLocationExternal(
86 tpu::ExecutorApiFn()->TpuExecutor_GetCoreLocationFn(executor_));
87 }
88
AllocateStream(Stream * stream)89 bool TpuExecutor::AllocateStream(Stream* stream) {
90 return tpu::ExecutorApiFn()->TpuExecutor_AllocateStreamFn(
91 executor_, get_stream(stream->implementation()));
92 }
93
DeallocateStream(Stream * stream)94 void TpuExecutor::DeallocateStream(Stream* stream) {
95 tpu::ExecutorApiFn()->TpuExecutor_DeallocateStreamFn(
96 executor_, get_stream(stream->implementation()));
97 tpu_platform().mutex().lock();
98 stream_map().erase(stream->implementation());
99 tpu_platform().mutex().unlock();
100 }
101
CreateStreamDependency(Stream * dependent,Stream * other)102 bool TpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) {
103 return tpu::ExecutorApiFn()->TpuExecutor_CreateStreamDependencyFn(
104 executor_, get_stream(dependent->implementation()),
105 get_stream(other->implementation()));
106 }
107
AllocateEvent(Event * event)108 Status TpuExecutor::AllocateEvent(Event* event) { return Status::OK(); }
109
// Removes the event's entry from the platform event map (the map is populated
// in CreateEventImplementation via InsertEvent).
Status TpuExecutor::DeallocateEvent(Event* event) {
  tpu_platform().EraseEvent(event->implementation());
  return Status::OK();
}
114
115 // AllocateTimer/DeallocateTimer have no specialization.
AllocateTimer(Timer * timer)116 bool TpuExecutor::AllocateTimer(Timer* timer) { return true; }
117
DeallocateTimer(Timer * timer)118 void TpuExecutor::DeallocateTimer(Timer* timer) {}
119
StartTimer(Stream * stream,::stream_executor::Timer * timer)120 bool TpuExecutor::StartTimer(Stream* stream, ::stream_executor::Timer* timer) {
121 return tpu::ExecutorApiFn()->TpuExecutor_StartTimerFn(
122 executor_, get_stream(stream->implementation()),
123 timer_map_.at(timer->implementation()));
124 }
125
StopTimer(Stream * stream,::stream_executor::Timer * timer)126 bool TpuExecutor::StopTimer(Stream* stream, ::stream_executor::Timer* timer) {
127 return tpu::ExecutorApiFn()->TpuExecutor_StopTimerFn(
128 executor_, get_stream(stream->implementation()),
129 timer_map_.at(timer->implementation()));
130 }
131
PollForEventStatus(stream_executor::Event * event)132 stream_executor::Event::Status TpuExecutor::PollForEventStatus(
133 stream_executor::Event* event) {
134 auto se_event = tpu_platform().LookupEvent(event->implementation());
135 return stream_executor::Event::Status(
136 tpu::ExecutorApiFn()->TpuExecutor_PollForEventStatusFn(executor_,
137 se_event));
138 }
139
RecordEvent(Stream * stream,::stream_executor::Event * event)140 Status TpuExecutor::RecordEvent(Stream* stream,
141 ::stream_executor::Event* event) {
142 StatusHelper status;
143 auto se_event = tpu_platform().LookupEvent(event->implementation());
144 tpu::ExecutorApiFn()->TpuExecutor_RecordEventFn(
145 executor_, get_stream(stream->implementation()), se_event,
146 status.c_status);
147 return status.status();
148 }
149
WaitForEvent(Stream * stream,::stream_executor::Event * event)150 Status TpuExecutor::WaitForEvent(Stream* stream,
151 ::stream_executor::Event* event) {
152 StatusHelper status;
153 auto se_event = tpu_platform().LookupEvent(event->implementation());
154 tpu::ExecutorApiFn()->TpuExecutor_WaitForEventFn(
155 executor_, get_stream(stream->implementation()), se_event,
156 status.c_status);
157 return status.status();
158 }
159
160 // Implementations for Timer, Stream, Event
161 // We need to map these implementations to internal equivalents -- thus we
162 // allocate the internal Timer, Stream and Event operations here, and map
163 // the implementations to the internal values. The "wrapper" interfaces are
164 // responsible for deallocating the internal value when they are destroyed.
165
166 // Called by Timer::Timer
167 std::unique_ptr<::stream_executor::internal::TimerInterface>
GetTimerImplementation()168 TpuExecutor::GetTimerImplementation() {
169 SE_Timer* tpu_timer = tpu::ExecutorApiFn()->TpuTimer_NewFn(executor_);
170 auto ptr = absl::make_unique<TpuTimer>(tpu_timer);
171 timer_map_[ptr.get()] = tpu_timer;
172 return ptr;
173 }
174
175 // Called by Stream::Stream
176 std::unique_ptr<::stream_executor::internal::StreamInterface>
GetStreamImplementation()177 TpuExecutor::GetStreamImplementation() {
178 SE_Stream* tpu_stream = tpu::ExecutorApiFn()->TpuStream_NewFn(executor_);
179 auto ptr = absl::make_unique<tpu::TpuStream>(tpu_stream);
180 tpu_platform().mutex().lock();
181 stream_map()[ptr.get()] = tpu_stream;
182 tpu_platform().mutex().unlock();
183 return ptr;
184 }
185
186 // Called by Event::Event
187 std::unique_ptr<::stream_executor::internal::EventInterface>
CreateEventImplementation()188 TpuExecutor::CreateEventImplementation() {
189 SE_Event* tpu_event = tpu::ExecutorApiFn()->TpuEvent_NewFn(executor_);
190 auto ptr = absl::make_unique<TpuEvent>(tpu_event);
191 tpu_platform().InsertEvent(ptr.get(), tpu_event);
192 return ptr;
193 }
194
Allocate(uint64 size,int64 memory_space)195 DeviceMemoryBase TpuExecutor::Allocate(uint64 size, int64 memory_space) {
196 SE_DeviceMemoryBase se_base = tpu::ExecutorApiFn()->TpuExecutor_AllocateFn(
197 executor_, size, memory_space);
198 return ApiConverter::FromC(se_base);
199 }
200
Deallocate(const DeviceMemoryBase & memory)201 void TpuExecutor::Deallocate(const DeviceMemoryBase& memory) {
202 SE_DeviceMemoryBase se_base = ApiConverter::ToC(memory);
203 tpu::ExecutorApiFn()->TpuExecutor_DeallocateFn(executor_, &se_base);
204 }
205
Deallocate(DeviceMemoryBase * memory)206 void TpuExecutor::Deallocate(DeviceMemoryBase* memory) {
207 SE_DeviceMemoryBase se_base = ApiConverter::ToC(*memory);
208 tpu::ExecutorApiFn()->TpuExecutor_DeallocateFn(executor_, &se_base);
209 }
210
DeviceMemoryUsage(int64 * free,int64 * total) const211 bool TpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const {
212 int64_t _free;
213 int64_t _total;
214 if (tpu::ExecutorApiFn()->TpuExecutor_DeviceMemoryUsageFn(executor_, &_free,
215 &_total)) {
216 *free = _free;
217 *total = _total;
218 return true;
219 }
220 return false;
221 }
222
223 absl::optional<stream_executor::AllocatorStats>
GetAllocatorStats()224 TpuExecutor::GetAllocatorStats() {
225 SE_AllocatorStats c_stats;
226 if (tpu::ExecutorApiFn()->TpuExecutor_GetAllocatorStatsFn(executor_,
227 &c_stats)) {
228 ::stream_executor::AllocatorStats stats;
229 stats.num_allocs = c_stats.num_allocs;
230 stats.bytes_in_use = c_stats.bytes_in_use;
231 stats.peak_bytes_in_use = c_stats.peak_bytes_in_use;
232 stats.largest_alloc_size = c_stats.largest_alloc_size;
233 if (c_stats.has_bytes_limit) {
234 stats.bytes_limit = c_stats.bytes_limit;
235 }
236 stats.bytes_reserved = c_stats.bytes_reserved;
237 stats.peak_bytes_reserved = c_stats.peak_bytes_reserved;
238 if (c_stats.has_bytes_reservable_limit) {
239 stats.bytes_reservable_limit = c_stats.bytes_reservable_limit;
240 }
241 stats.largest_free_block_bytes = c_stats.largest_free_block_bytes;
242 return stats;
243 }
244 return {};
245 }
246
WaitForInfeedReady(int32 infeed_queue_index)247 Status TpuExecutor::WaitForInfeedReady(int32 infeed_queue_index) {
248 StatusHelper status;
249 tpu::ExecutorApiFn()->TpuExecutor_WaitForInfeedReadyFn(
250 executor_, infeed_queue_index, status.c_status);
251 return status.status();
252 }
253
WaitForOutfeedReady(int32 outfeed_queue_index)254 Status TpuExecutor::WaitForOutfeedReady(int32 outfeed_queue_index) {
255 StatusHelper status;
256 tpu::ExecutorApiFn()->TpuExecutor_WaitForOutfeedReadyFn(
257 executor_, outfeed_queue_index, status.c_status);
258 return status.status();
259 }
260
DequeueOutfeed(int32 outfeed_queue_index,absl::Span<uint8> bytes,StatusCallback done)261 void TpuExecutor::DequeueOutfeed(int32 outfeed_queue_index,
262 absl::Span<uint8> bytes, StatusCallback done) {
263 StatusHelper status;
264 tpu::ExecutorApiFn()->TpuExecutor_DequeueOutfeedFn(
265 executor_, outfeed_queue_index, bytes.data(), bytes.size(),
266 status.c_status);
267 done(status.status());
268 }
269
EnqueueInfeed(int32 infeed_queue_index,absl::Span<const uint8> bytes)270 Status TpuExecutor::EnqueueInfeed(int32 infeed_queue_index,
271 absl::Span<const uint8> bytes) {
272 StatusHelper status;
273 tpu::ExecutorApiFn()->TpuExecutor_EnqueueInfeedFn(
274 executor_, infeed_queue_index, bytes.data(), bytes.size(),
275 status.c_status);
276 return status.status();
277 }
278
Memcpy(Stream * stream,void * host_dst,const::stream_executor::DeviceMemoryBase & device_src,uint64 size)279 bool TpuExecutor::Memcpy(Stream* stream, void* host_dst,
280 const ::stream_executor::DeviceMemoryBase& device_src,
281 uint64 size) {
282 SE_DeviceMemoryBase se_base = ApiConverter::ToC(device_src);
283 return tpu::ExecutorApiFn()->TpuExecutor_MemcpyToHostFn(
284 executor_, get_stream(stream->implementation()), host_dst, &se_base,
285 size);
286 }
287
Memcpy(Stream * stream,::stream_executor::DeviceMemoryBase * device_dst,const void * host_src,uint64 size)288 bool TpuExecutor::Memcpy(Stream* stream,
289 ::stream_executor::DeviceMemoryBase* device_dst,
290 const void* host_src, uint64 size) {
291 SE_DeviceMemoryBase se_base = ApiConverter::ToC(*device_dst);
292 return tpu::ExecutorApiFn()->TpuExecutor_MemcpyFromHostFn(
293 executor_, get_stream(stream->implementation()), &se_base, host_src,
294 size);
295 }
296
SynchronousMemcpy(::stream_executor::DeviceMemoryBase * device_dst,const void * host_src,uint64 size)297 Status TpuExecutor::SynchronousMemcpy(
298 ::stream_executor::DeviceMemoryBase* device_dst, const void* host_src,
299 uint64 size) {
300 StatusHelper status;
301 SE_DeviceMemoryBase se_base = ApiConverter::ToC(*device_dst);
302 tpu::ExecutorApiFn()->TpuExecutor_SynchronousMemcpyFromHostFn(
303 executor_, &se_base, host_src, size, status.c_status);
304 return status.status();
305 }
306
SynchronousMemcpy(void * host_dst,const::stream_executor::DeviceMemoryBase & device_src,uint64 size)307 Status TpuExecutor::SynchronousMemcpy(
308 void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src,
309 uint64 size) {
310 StatusHelper status;
311 SE_DeviceMemoryBase se_base = ApiConverter::ToC(device_src);
312 tpu::ExecutorApiFn()->TpuExecutor_SynchronousMemcpyToHostFn(
313 executor_, host_dst, &se_base, size, status.c_status);
314 return status.status();
315 }
316
// Device-to-device copies are not supported by the TPU C API; always returns
// an UNIMPLEMENTED status.
Status TpuExecutor::SynchronousMemcpyDeviceToDevice(
    ::stream_executor::DeviceMemoryBase* device_dst,
    const ::stream_executor::DeviceMemoryBase& device_src, uint64 size) {
  return ::stream_executor::port::UnimplementedError(
      "This operation not supported on TPU");
}
323
// Unsupported on TPU; LOG(FATAL) terminates the process, so the missing
// return statement is intentional.
bool TpuExecutor::MemcpyDeviceToDevice(
    Stream* stream, ::stream_executor::DeviceMemoryBase* gpu_dst,
    const ::stream_executor::DeviceMemoryBase& host_src, uint64 size) {
  LOG(FATAL) << __func__ << " not supported on TpuExecutor";
}
329
// Heap-allocated context passed through the C host-callback API as a void*;
// HostCallbackTrampoline deletes it after invoking `callback`.
struct HostCallbackContext {
  std::function<Status()> callback;
};
333
HostCallbackTrampoline(void * ctx)334 TF_Status* HostCallbackTrampoline(void* ctx) {
335 HostCallbackContext* host_ctx = reinterpret_cast<HostCallbackContext*>(ctx);
336 Status status = host_ctx->callback();
337 TF_Status* c_status = tpu::ExecutorApiFn()->TpuStatus_CreateFn(
338 status.code(), status.error_message().c_str());
339 delete host_ctx;
340 return c_status;
341 }
342
HostCallback(Stream * stream,std::function<Status ()> callback)343 bool TpuExecutor::HostCallback(Stream* stream,
344 std::function<Status()> callback) {
345 HostCallbackContext* ctx = new HostCallbackContext{callback};
346 return tpu::ExecutorApiFn()->TpuExecutor_HostCallbackFn(
347 executor_, get_stream(stream->implementation()), &HostCallbackTrampoline,
348 ctx);
349 }
350
351 TpuExecutor::StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
CreateDeviceDescription() const352 TpuExecutor::CreateDeviceDescription() const {
353 StatusHelper status;
354 SE_DeviceDescription* description =
355 tpu::ExecutorApiFn()->TpuDeviceDescription_NewFn();
356 auto cleanup = tensorflow::gtl::MakeCleanup([description]() {
357 tpu::ExecutorApiFn()->TpuDeviceDescription_FreeFn(description);
358 });
359 tpu::ExecutorApiFn()->TpuExecutor_CreateDeviceDescriptionFn(
360 executor_, description, status.c_status);
361 if (status.status().ok()) {
362 stream_executor::internal::DeviceDescriptionBuilder builder;
363 CHECK_NE(description->device_vendor, nullptr);
364 builder.set_device_vendor(description->device_vendor);
365 builder.set_name(description->name);
366 builder.set_clock_rate_ghz(description->clock_rate_ghz);
367 builder.set_core_count(description->core_count);
368 builder.set_ecc_enabled(description->ecc_enabled);
369 builder.set_device_memory_size(description->device_memory_size);
370 builder.set_platform_version(description->platform_version);
371 return builder.Build();
372 }
373 return status.status();
374 }
375
376 } // namespace tpu
377 } // namespace tensorflow
378