/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
17 #define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
18 
19 #include "absl/container/flat_hash_map.h"
20 #include "tensorflow/core/platform/casts.h"
21 #include "tensorflow/core/platform/mutex.h"
22 #include "tensorflow/core/platform/types.h"
23 #include "tensorflow/stream_executor/device_memory.h"
24 #include "tensorflow/stream_executor/device_options.h"
25 #include "tensorflow/stream_executor/event.h"
26 #include "tensorflow/stream_executor/lib/status.h"
27 #include "tensorflow/stream_executor/lib/statusor.h"
28 #include "tensorflow/stream_executor/stream.h"
29 #include "tensorflow/stream_executor/stream_executor.h"
30 #include "tensorflow/stream_executor/stream_executor_internal.h"
31 #include "tensorflow/stream_executor/temporary_device_memory.h"
32 #include "tensorflow/stream_executor/timer.h"
33 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
34 #include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
35 #include "tensorflow/stream_executor/tpu/tpu_platform.h"
36 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
37 #include "tensorflow/stream_executor/tpu/tpu_stream.h"
38 
39 namespace tensorflow {
40 namespace tpu {
41 
42 class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
43  public:
44   using Status = ::stream_executor::port::Status;
45   template <typename T>
46   using StatusOr = ::stream_executor::port::StatusOr<T>;
47   using StatusCallback = std::function<void(const Status&)>;
48   using Stream = ::stream_executor::Stream;
49   using Event = ::stream_executor::Event;
50   using Timer = ::stream_executor::Timer;
51   using DeviceMemoryBase = ::stream_executor::DeviceMemoryBase;
52   using StreamInterface = ::stream_executor::internal::StreamInterface;
53   using StreamExecutorInterface =
54       ::stream_executor::internal::StreamExecutorInterface;
55 
56   using TimerMap =
57       absl::flat_hash_map<stream_executor::internal::TimerInterface*,
58                           SE_Timer*>;
59 
TpuExecutor(::tensorflow::tpu::TpuPlatformInterface * platform,SE_StreamExecutor * executor)60   explicit TpuExecutor(::tensorflow::tpu::TpuPlatformInterface* platform,
61                        SE_StreamExecutor* executor)
62       : platform_(platform), executor_(executor) {}
63 
64   ~TpuExecutor() override;
65 
66   Status Init(int device_ordinal,
67               ::stream_executor::DeviceOptions device_options) override;
68 
69   DeviceMemoryBase Allocate(uint64 size, int64 memory_space) override;
70 
71   Status AllocateEvent(Event* event) override;
72 
73   bool AllocateStream(Stream* stream) override;
74 
75   bool AllocateTimer(Timer* timer) override;
76 
77   Status BlockHostUntilDone(::stream_executor::Stream* stream) override;
78 
79   Status BlockUntilDoneOrFailed();
80 
81   StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
82   CreateDeviceDescription() const override;
83 
84   bool CreateStreamDependency(Stream* dependent, Stream* other) override;
85 
86   void DeallocateStream(Stream* stream) override;
87 
88   void Deallocate(const DeviceMemoryBase& memory);
89 
90   void Deallocate(DeviceMemoryBase* memory) override;
91 
92   Status DeallocateEvent(Event* event) override;
93 
94   void DeallocateTimer(Timer* timer) override;
95 
96   bool DeviceMemoryUsage(int64* free, int64* total) const override;
97 
98   void DequeueOutfeed(int32 outfeed_queue_index, absl::Span<uint8> bytes,
99                       StatusCallback done);
100 
101   Status EnqueueInfeed(int32 infeed_queue_index, absl::Span<const uint8> bytes);
102 
103   absl::optional<stream_executor::AllocatorStats> GetAllocatorStats() override;
104 
105   tpu::TpuCoreLocationExternal GetCoreLocationExternal() const override;
106 
107   Status GetStatus(Stream* stream) override;
108 
109   std::unique_ptr<::stream_executor::internal::StreamInterface>
110   GetStreamImplementation() override;
111 
112   std::unique_ptr<::stream_executor::internal::TimerInterface>
113   GetTimerImplementation() override;
114 
115   std::unique_ptr<::stream_executor::internal::EventInterface>
116   CreateEventImplementation() override;
117 
118   bool HostCallback(Stream* stream, std::function<Status()> callback) override;
119 
120   bool Memcpy(Stream* stream, void* host_dst,
121               const ::stream_executor::DeviceMemoryBase& device_src,
122               uint64 size) override;
123 
124   bool Memcpy(Stream* stream, ::stream_executor::DeviceMemoryBase* device_dst,
125               const void* host_src, uint64 size) override;
126 
127   bool MemcpyDeviceToDevice(Stream* stream,
128                             ::stream_executor::DeviceMemoryBase* gpu_dst,
129                             const ::stream_executor::DeviceMemoryBase& host_src,
130                             uint64 size) override;
131 
132   void SyncAndForgetFailedStreams();
133   bool SynchronizeAllActivity() override;
134 
135   Status SynchronousMemcpy(::stream_executor::DeviceMemoryBase* device_dst,
136                            const void* host_src, uint64 size) override;
137   Status SynchronousMemcpy(
138       void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src,
139       uint64 size) override;
140   Status SynchronousMemcpyDeviceToDevice(
141       ::stream_executor::DeviceMemoryBase* device_dst,
142       const ::stream_executor::DeviceMemoryBase& device_src,
143       uint64 size) override;
144 
145   int PlatformDeviceCount() override;
146 
147   Event::Status PollForEventStatus(Event* event) override;
148   Status RecordEvent(Stream* stream, ::stream_executor::Event* event) override;
149   Status WaitForEvent(Stream* stream, ::stream_executor::Event* event) override;
150 
151   bool StartTimer(Stream* stream, ::stream_executor::Timer* timer) override;
152   bool StopTimer(Stream* stream, ::stream_executor::Timer* timer) override;
153 
154   Status WaitForInfeedReady(int32 infeed_queue_index);
155 
156   Status WaitForOutfeedReady(int32 outfeed_queue_index);
157 
platform()158   const ::tensorflow::tpu::TpuPlatformInterface& platform() const override {
159     return *platform_;
160   }
161 
platform()162   ::tensorflow::tpu::TpuPlatformInterface& platform() override {
163     return *platform_;
164   }
165 
166   // TODO(henrytan): convert this to override once the base interface is changed
167   // to TpuExecutorInterface.
168   StatusOr<std::unique_ptr<
169       tensorflow::tpu::TpuExecutorInterface::TemporaryDeviceMemory>>
CreateTemporaryDeviceMemory(int64 memory_space,int64 byte_offset,int64 size)170   CreateTemporaryDeviceMemory(int64 memory_space, int64 byte_offset,
171                               int64 size) override {
172     LOG(FATAL) << "Unimplemented.";
173   }
174 
175   // -- Unimplemented (stubbed out) methods.
176   std::unique_ptr<stream_executor::internal::KernelInterface>
CreateKernelImplementation()177   CreateKernelImplementation() override {
178     LOG(FATAL) << "Not yet implemented";
179   }
180 
GetSubBuffer(DeviceMemoryBase * parent,uint64 offset,uint64 size)181   void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset,
182                      uint64 size) override {
183     LOG(FATAL) << "not yet implemented";
184   }
MemZero(Stream * stream,DeviceMemoryBase * location,uint64 size)185   Status MemZero(Stream* stream, DeviceMemoryBase* location,
186                  uint64 size) override {
187     LOG(FATAL) << "not yet implemented";
188   }
Memset32(Stream * stream,DeviceMemoryBase * location,uint32 pattern,uint64 size)189   Status Memset32(Stream* stream, DeviceMemoryBase* location, uint32 pattern,
190                   uint64 size) override {
191     LOG(FATAL) << "not yet implemented";
192   }
EnablePeerAccessTo(StreamExecutorInterface * other)193   Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
194     LOG(FATAL) << "not yet implemented";
195   }
CanEnablePeerAccessTo(StreamExecutorInterface * other)196   bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
197     LOG(FATAL) << "not yet implemented";
198   }
199 
HostMemoryAllocate(uint64 size)200   void* HostMemoryAllocate(uint64 size) override {
201     LOG(FATAL) << "not yet implemented";
202   }
HostMemoryDeallocate(void * mem)203   void HostMemoryDeallocate(void* mem) override {
204     LOG(FATAL) << "not yet implemented";
205   }
HostMemoryRegister(void * mem,uint64 size)206   bool HostMemoryRegister(void* mem, uint64 size) override {
207     LOG(FATAL) << "not yet implemented";
208   }
HostMemoryUnregister(void * mem)209   bool HostMemoryUnregister(void* mem) override {
210     LOG(FATAL) << "not yet implemented";
211   }
SynchronousMemZero(DeviceMemoryBase * location,uint64 size)212   Status SynchronousMemZero(DeviceMemoryBase* location, uint64 size) override {
213     LOG(FATAL) << "not yet implemented";
214   }
SynchronousMemSet(DeviceMemoryBase * location,int value,uint64 size)215   Status SynchronousMemSet(DeviceMemoryBase* location, int value,
216                            uint64 size) override {
217     LOG(FATAL) << "not yet implemented";
218   }
219 
se_executor()220   SE_StreamExecutor* se_executor() { return executor_; }
221 
222  private:
tpu_platform()223   TpuPlatform& tpu_platform() {
224     return *(tensorflow::down_cast<TpuPlatform*>(platform_));
225   }
226 
stream_map()227   TpuPlatform::StreamMap& stream_map() {
228     return *(tpu_platform().stream_map());
229   }
230 
get_stream(StreamInterface * ptr)231   SE_Stream* get_stream(StreamInterface* ptr) {
232     tensorflow::mutex_lock m(tpu_platform().mutex());
233     return stream_map()[ptr];
234   }
235 
236   TimerMap timer_map_;
237   tensorflow::tpu::TpuPlatformInterface* platform_;
238   SE_StreamExecutor* executor_;
239 };
240 
241 }  // namespace tpu
242 }  // namespace tensorflow
243 
244 #endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
245