/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_options.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/temporary_device_memory.h"
#include "tensorflow/stream_executor/timer.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_stream.h"

namespace tensorflow {
namespace tpu {

class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
 public:
  using Status = ::stream_executor::port::Status;
  template <typename T>
  using StatusOr = ::stream_executor::port::StatusOr<T>;
  using StatusCallback = std::function<void(const Status&)>;
  using Stream = ::stream_executor::Stream;
  using Event = ::stream_executor::Event;
  using Timer = ::stream_executor::Timer;
  using DeviceMemoryBase = ::stream_executor::DeviceMemoryBase;
  using StreamInterface = ::stream_executor::internal::StreamInterface;
  using StreamExecutorInterface =
      ::stream_executor::internal::StreamExecutorInterface;

  using TimerMap =
      absl::flat_hash_map<stream_executor::internal::TimerInterface*,
                          SE_Timer*>;

  explicit TpuExecutor(::tensorflow::tpu::TpuPlatformInterface* platform,
                       SE_StreamExecutor* executor)
      : platform_(platform), executor_(executor) {}

  ~TpuExecutor() override;

  Status Init(int device_ordinal,
              ::stream_executor::DeviceOptions device_options) override;

  DeviceMemoryBase Allocate(uint64 size, int64_t memory_space) override;

  Status AllocateEvent(Event* event) override;

  bool AllocateStream(Stream* stream) override;

  bool AllocateTimer(Timer* timer) override;

  Status BlockHostUntilDone(::stream_executor::Stream* stream) override;

  Status BlockUntilDoneOrFailed();

  StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
  CreateDeviceDescription() const override;

  bool CreateStreamDependency(Stream* dependent, Stream* other) override;

  void DeallocateStream(Stream* stream) override;

  void Deallocate(const DeviceMemoryBase& memory);

  void Deallocate(DeviceMemoryBase* memory) override;

  Status DeallocateEvent(Event* event) override;

  void DeallocateTimer(Timer* timer) override;

  bool DeviceMemoryUsage(int64* free, int64* total) const override;

  void DequeueOutfeed(int32_t outfeed_queue_index, absl::Span<uint8> bytes,
                      StatusCallback done);

  Status EnqueueInfeed(int32_t infeed_queue_index,
                       absl::Span<const uint8> bytes);

  absl::optional<stream_executor::AllocatorStats> GetAllocatorStats() override;

  tpu::TpuCoreLocationExternal GetCoreLocationExternal() const override;

  Status GetStatus(Stream* stream) override;

  std::unique_ptr<::stream_executor::internal::StreamInterface>
  GetStreamImplementation() override;

  std::unique_ptr<::stream_executor::internal::TimerInterface>
  GetTimerImplementation() override;

  std::unique_ptr<::stream_executor::internal::EventInterface>
  CreateEventImplementation() override;

  bool HostCallback(Stream* stream, std::function<Status()> callback) override;

  bool Memcpy(Stream* stream, void* host_dst,
              const ::stream_executor::DeviceMemoryBase& device_src,
              uint64 size) override;

  bool Memcpy(Stream* stream, ::stream_executor::DeviceMemoryBase* device_dst,
              const void* host_src, uint64 size) override;

  bool MemcpyDeviceToDevice(Stream* stream,
                            ::stream_executor::DeviceMemoryBase* gpu_dst,
                            const ::stream_executor::DeviceMemoryBase& host_src,
                            uint64 size) override;

  void SyncAndForgetFailedStreams();
  bool SynchronizeAllActivity() override;

  Status SynchronousMemcpy(::stream_executor::DeviceMemoryBase* device_dst,
                           const void* host_src, uint64 size) override;
  Status SynchronousMemcpy(
      void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src,
      uint64 size) override;
  Status SynchronousMemcpyDeviceToDevice(
      ::stream_executor::DeviceMemoryBase* device_dst,
      const ::stream_executor::DeviceMemoryBase& device_src,
      uint64 size) override;

  int PlatformDeviceCount() override;

  Event::Status PollForEventStatus(Event* event) override;
  Status RecordEvent(Stream* stream, ::stream_executor::Event* event) override;
  Status WaitForEvent(Stream* stream, ::stream_executor::Event* event) override;

  bool StartTimer(Stream* stream, ::stream_executor::Timer* timer) override;
  bool StopTimer(Stream* stream, ::stream_executor::Timer* timer) override;

  Status WaitForInfeedReady(int32_t infeed_queue_index);

  Status WaitForOutfeedReady(int32_t outfeed_queue_index);

  Status UnloadAllPrograms() override;

  Status EnqueueCompactionOnStreamForHbm(Stream* compaction_stream) override;

  const ::tensorflow::tpu::TpuPlatformInterface& platform() const override {
    return *platform_;
  }

  ::tensorflow::tpu::TpuPlatformInterface& platform() override {
    return *platform_;
  }

  // TODO(henrytan): convert this to override once the base interface is
  // changed to TpuExecutorInterface.
  StatusOr<std::unique_ptr<
      tensorflow::tpu::TpuExecutorInterface::TemporaryDeviceMemory>>
  CreateTemporaryDeviceMemory(int64_t memory_space, int64_t byte_offset,
                              int64_t size) override {
    LOG(FATAL) << "Unimplemented.";
  }

  // -- Unimplemented (stubbed out) methods.
  std::unique_ptr<stream_executor::internal::KernelInterface>
  CreateKernelImplementation() override {
    LOG(FATAL) << "Not yet implemented";
  }

  void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset,
                     uint64 size) override {
    LOG(FATAL) << "not yet implemented";
  }

  Status MemZero(Stream* stream, DeviceMemoryBase* location,
                 uint64 size) override {
    LOG(FATAL) << "not yet implemented";
  }

  Status Memset32(Stream* stream, DeviceMemoryBase* location, uint32 pattern,
                  uint64 size) override {
    LOG(FATAL) << "not yet implemented";
  }

  Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
    LOG(FATAL) << "not yet implemented";
  }

  bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
    LOG(FATAL) << "not yet implemented";
  }

  void* HostMemoryAllocate(uint64 size) override {
    LOG(FATAL) << "not yet implemented";
  }

  void HostMemoryDeallocate(void* mem) override {
    LOG(FATAL) << "not yet implemented";
  }

  bool HostMemoryRegister(void* mem, uint64 size) override {
    LOG(FATAL) << "not yet implemented";
  }

  bool HostMemoryUnregister(void* mem) override {
    LOG(FATAL) << "not yet implemented";
  }

  Status SynchronousMemZero(DeviceMemoryBase* location, uint64 size) override {
    LOG(FATAL) << "not yet implemented";
  }

  Status SynchronousMemSet(DeviceMemoryBase* location, int value,
                           uint64 size) override {
    LOG(FATAL) << "not yet implemented";
  }

  SE_StreamExecutor* se_executor() { return executor_; }

 private:
  TpuPlatform& tpu_platform() {
    return *(tensorflow::down_cast<TpuPlatform*>(platform_));
  }

  TpuPlatform::StreamMap& stream_map() {
    return *(tpu_platform().stream_map());
  }

  SE_Stream* get_stream(StreamInterface* ptr) {
    tensorflow::mutex_lock m(tpu_platform().mutex());
    return stream_map()[ptr];
  }

  TimerMap timer_map_;
  tensorflow::tpu::TpuPlatformInterface* platform_;
  SE_StreamExecutor* executor_;
};

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
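
// The commented-out sketch below is not part of the original header; it only
// illustrates how a TpuExecutor might be constructed and initialized, assuming
// a registered TPU platform and an SE_StreamExecutor handle obtained through
// the TPU executor C API. The variable names (`se_platform`, `c_executor`) are
// hypothetical placeholders, and the elided value of `c_executor` is left as
// an ellipsis on purpose.
//
//   namespace tpu = ::tensorflow::tpu;
//
//   // Look up the platform registered for this process (assumption: a TPU
//   // platform has already been registered).
//   tpu::TpuPlatformInterface* se_platform =
//       tpu::TpuPlatformInterface::GetRegisteredPlatform();
//
//   // Handle produced by the TPU executor C API; how it is obtained is
//   // outside the scope of this header.
//   SE_StreamExecutor* c_executor = ...;
//
//   tpu::TpuExecutor executor(se_platform, c_executor);
//   TF_CHECK_OK(executor.Init(/*device_ordinal=*/0,
//                             ::stream_executor::DeviceOptions::Default()));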