• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
17 #define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
18 
19 #include "absl/container/flat_hash_map.h"
20 #include "tensorflow/core/platform/casts.h"
21 #include "tensorflow/core/platform/mutex.h"
22 #include "tensorflow/core/platform/types.h"
23 #include "tensorflow/stream_executor/device_memory.h"
24 #include "tensorflow/stream_executor/device_options.h"
25 #include "tensorflow/stream_executor/event.h"
26 #include "tensorflow/stream_executor/lib/status.h"
27 #include "tensorflow/stream_executor/lib/statusor.h"
28 #include "tensorflow/stream_executor/stream.h"
29 #include "tensorflow/stream_executor/stream_executor.h"
30 #include "tensorflow/stream_executor/stream_executor_internal.h"
31 #include "tensorflow/stream_executor/temporary_device_memory.h"
32 #include "tensorflow/stream_executor/timer.h"
33 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
34 #include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
35 #include "tensorflow/stream_executor/tpu/tpu_platform.h"
36 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
37 #include "tensorflow/stream_executor/tpu/tpu_stream.h"
38 
39 namespace tensorflow {
40 namespace tpu {
41 
42 class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
43  public:
44   using Status = ::stream_executor::port::Status;
45   template <typename T>
46   using StatusOr = ::stream_executor::port::StatusOr<T>;
47   using StatusCallback = std::function<void(const Status&)>;
48   using Stream = ::stream_executor::Stream;
49   using Event = ::stream_executor::Event;
50   using Timer = ::stream_executor::Timer;
51   using DeviceMemoryBase = ::stream_executor::DeviceMemoryBase;
52   using StreamInterface = ::stream_executor::internal::StreamInterface;
53   using StreamExecutorInterface =
54       ::stream_executor::internal::StreamExecutorInterface;
55 
56   using TimerMap =
57       absl::flat_hash_map<stream_executor::internal::TimerInterface*,
58                           SE_Timer*>;
59 
TpuExecutor(::tensorflow::tpu::TpuPlatformInterface * platform,SE_StreamExecutor * executor)60   explicit TpuExecutor(::tensorflow::tpu::TpuPlatformInterface* platform,
61                        SE_StreamExecutor* executor)
62       : platform_(platform), executor_(executor) {}
63 
64   ~TpuExecutor() override;
65 
66   Status Init(int device_ordinal,
67               ::stream_executor::DeviceOptions device_options) override;
68 
69   DeviceMemoryBase Allocate(uint64 size, int64_t memory_space) override;
70 
71   Status AllocateEvent(Event* event) override;
72 
73   bool AllocateStream(Stream* stream) override;
74 
75   bool AllocateTimer(Timer* timer) override;
76 
77   Status BlockHostUntilDone(::stream_executor::Stream* stream) override;
78 
79   Status BlockUntilDoneOrFailed();
80 
81   StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
82   CreateDeviceDescription() const override;
83 
84   bool CreateStreamDependency(Stream* dependent, Stream* other) override;
85 
86   void DeallocateStream(Stream* stream) override;
87 
88   void Deallocate(const DeviceMemoryBase& memory);
89 
90   void Deallocate(DeviceMemoryBase* memory) override;
91 
92   Status DeallocateEvent(Event* event) override;
93 
94   void DeallocateTimer(Timer* timer) override;
95 
96   bool DeviceMemoryUsage(int64* free, int64* total) const override;
97 
98   void DequeueOutfeed(int32_t outfeed_queue_index, absl::Span<uint8> bytes,
99                       StatusCallback done);
100 
101   Status EnqueueInfeed(int32_t infeed_queue_index,
102                        absl::Span<const uint8> bytes);
103 
104   absl::optional<stream_executor::AllocatorStats> GetAllocatorStats() override;
105 
106   tpu::TpuCoreLocationExternal GetCoreLocationExternal() const override;
107 
108   Status GetStatus(Stream* stream) override;
109 
110   std::unique_ptr<::stream_executor::internal::StreamInterface>
111   GetStreamImplementation() override;
112 
113   std::unique_ptr<::stream_executor::internal::TimerInterface>
114   GetTimerImplementation() override;
115 
116   std::unique_ptr<::stream_executor::internal::EventInterface>
117   CreateEventImplementation() override;
118 
119   bool HostCallback(Stream* stream, std::function<Status()> callback) override;
120 
121   bool Memcpy(Stream* stream, void* host_dst,
122               const ::stream_executor::DeviceMemoryBase& device_src,
123               uint64 size) override;
124 
125   bool Memcpy(Stream* stream, ::stream_executor::DeviceMemoryBase* device_dst,
126               const void* host_src, uint64 size) override;
127 
128   bool MemcpyDeviceToDevice(Stream* stream,
129                             ::stream_executor::DeviceMemoryBase* gpu_dst,
130                             const ::stream_executor::DeviceMemoryBase& host_src,
131                             uint64 size) override;
132 
133   void SyncAndForgetFailedStreams();
134   bool SynchronizeAllActivity() override;
135 
136   Status SynchronousMemcpy(::stream_executor::DeviceMemoryBase* device_dst,
137                            const void* host_src, uint64 size) override;
138   Status SynchronousMemcpy(
139       void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src,
140       uint64 size) override;
141   Status SynchronousMemcpyDeviceToDevice(
142       ::stream_executor::DeviceMemoryBase* device_dst,
143       const ::stream_executor::DeviceMemoryBase& device_src,
144       uint64 size) override;
145 
146   int PlatformDeviceCount() override;
147 
148   Event::Status PollForEventStatus(Event* event) override;
149   Status RecordEvent(Stream* stream, ::stream_executor::Event* event) override;
150   Status WaitForEvent(Stream* stream, ::stream_executor::Event* event) override;
151 
152   bool StartTimer(Stream* stream, ::stream_executor::Timer* timer) override;
153   bool StopTimer(Stream* stream, ::stream_executor::Timer* timer) override;
154 
155   Status WaitForInfeedReady(int32_t infeed_queue_index);
156 
157   Status WaitForOutfeedReady(int32_t outfeed_queue_index);
158 
159   Status UnloadAllPrograms() override;
160 
161   Status EnqueueCompactionOnStreamForHbm(Stream* compaction_stream) override;
162 
platform()163   const ::tensorflow::tpu::TpuPlatformInterface& platform() const override {
164     return *platform_;
165   }
166 
platform()167   ::tensorflow::tpu::TpuPlatformInterface& platform() override {
168     return *platform_;
169   }
170 
171   // TODO(henrytan): convert this to override once the base interface is changed
172   // to TpuExecutorInterface.
173   StatusOr<std::unique_ptr<
174       tensorflow::tpu::TpuExecutorInterface::TemporaryDeviceMemory>>
CreateTemporaryDeviceMemory(int64_t memory_space,int64_t byte_offset,int64_t size)175   CreateTemporaryDeviceMemory(int64_t memory_space, int64_t byte_offset,
176                               int64_t size) override {
177     LOG(FATAL) << "Unimplemented.";
178   }
179 
180   // -- Unimplemented (stubbed out) methods.
181   std::unique_ptr<stream_executor::internal::KernelInterface>
CreateKernelImplementation()182   CreateKernelImplementation() override {
183     LOG(FATAL) << "Not yet implemented";
184   }
185 
GetSubBuffer(DeviceMemoryBase * parent,uint64 offset,uint64 size)186   void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset,
187                      uint64 size) override {
188     LOG(FATAL) << "not yet implemented";
189   }
MemZero(Stream * stream,DeviceMemoryBase * location,uint64 size)190   Status MemZero(Stream* stream, DeviceMemoryBase* location,
191                  uint64 size) override {
192     LOG(FATAL) << "not yet implemented";
193   }
Memset32(Stream * stream,DeviceMemoryBase * location,uint32 pattern,uint64 size)194   Status Memset32(Stream* stream, DeviceMemoryBase* location, uint32 pattern,
195                   uint64 size) override {
196     LOG(FATAL) << "not yet implemented";
197   }
EnablePeerAccessTo(StreamExecutorInterface * other)198   Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
199     LOG(FATAL) << "not yet implemented";
200   }
CanEnablePeerAccessTo(StreamExecutorInterface * other)201   bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
202     LOG(FATAL) << "not yet implemented";
203   }
204 
HostMemoryAllocate(uint64 size)205   void* HostMemoryAllocate(uint64 size) override {
206     LOG(FATAL) << "not yet implemented";
207   }
HostMemoryDeallocate(void * mem)208   void HostMemoryDeallocate(void* mem) override {
209     LOG(FATAL) << "not yet implemented";
210   }
HostMemoryRegister(void * mem,uint64 size)211   bool HostMemoryRegister(void* mem, uint64 size) override {
212     LOG(FATAL) << "not yet implemented";
213   }
HostMemoryUnregister(void * mem)214   bool HostMemoryUnregister(void* mem) override {
215     LOG(FATAL) << "not yet implemented";
216   }
SynchronousMemZero(DeviceMemoryBase * location,uint64 size)217   Status SynchronousMemZero(DeviceMemoryBase* location, uint64 size) override {
218     LOG(FATAL) << "not yet implemented";
219   }
SynchronousMemSet(DeviceMemoryBase * location,int value,uint64 size)220   Status SynchronousMemSet(DeviceMemoryBase* location, int value,
221                            uint64 size) override {
222     LOG(FATAL) << "not yet implemented";
223   }
224 
se_executor()225   SE_StreamExecutor* se_executor() { return executor_; }
226 
227  private:
tpu_platform()228   TpuPlatform& tpu_platform() {
229     return *(tensorflow::down_cast<TpuPlatform*>(platform_));
230   }
231 
stream_map()232   TpuPlatform::StreamMap& stream_map() {
233     return *(tpu_platform().stream_map());
234   }
235 
get_stream(StreamInterface * ptr)236   SE_Stream* get_stream(StreamInterface* ptr) {
237     tensorflow::mutex_lock m(tpu_platform().mutex());
238     return stream_map()[ptr];
239   }
240 
241   TimerMap timer_map_;
242   tensorflow::tpu::TpuPlatformInterface* platform_;
243   SE_StreamExecutor* executor_;
244 };
245 
246 }  // namespace tpu
247 }  // namespace tensorflow
248 
249 #endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
250