1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_ 17 #define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_ 18 19 #include "absl/container/flat_hash_map.h" 20 #include "tensorflow/core/platform/casts.h" 21 #include "tensorflow/core/platform/mutex.h" 22 #include "tensorflow/core/platform/types.h" 23 #include "tensorflow/stream_executor/device_memory.h" 24 #include "tensorflow/stream_executor/device_options.h" 25 #include "tensorflow/stream_executor/event.h" 26 #include "tensorflow/stream_executor/lib/status.h" 27 #include "tensorflow/stream_executor/lib/statusor.h" 28 #include "tensorflow/stream_executor/stream.h" 29 #include "tensorflow/stream_executor/stream_executor.h" 30 #include "tensorflow/stream_executor/stream_executor_internal.h" 31 #include "tensorflow/stream_executor/temporary_device_memory.h" 32 #include "tensorflow/stream_executor/timer.h" 33 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" 34 #include "tensorflow/stream_executor/tpu/tpu_executor_interface.h" 35 #include "tensorflow/stream_executor/tpu/tpu_platform.h" 36 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" 37 #include "tensorflow/stream_executor/tpu/tpu_stream.h" 38 39 namespace tensorflow { 40 namespace tpu { 41 42 class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { 43 public: 44 using Status = ::stream_executor::port::Status; 45 template <typename T> 46 using StatusOr = ::stream_executor::port::StatusOr<T>; 47 using StatusCallback = std::function<void(const Status&)>; 48 using Stream = ::stream_executor::Stream; 49 using Event = ::stream_executor::Event; 50 using Timer = ::stream_executor::Timer; 51 using DeviceMemoryBase = ::stream_executor::DeviceMemoryBase; 52 using StreamInterface = ::stream_executor::internal::StreamInterface; 53 using StreamExecutorInterface = 54 ::stream_executor::internal::StreamExecutorInterface; 55 56 using TimerMap = 57 absl::flat_hash_map<stream_executor::internal::TimerInterface*, 58 SE_Timer*>; 59 TpuExecutor(::tensorflow::tpu::TpuPlatformInterface * platform,SE_StreamExecutor * executor)60 explicit TpuExecutor(::tensorflow::tpu::TpuPlatformInterface* platform, 61 SE_StreamExecutor* executor) 62 : platform_(platform), executor_(executor) {} 63 64 ~TpuExecutor() override; 65 66 Status Init(int device_ordinal, 67 ::stream_executor::DeviceOptions device_options) override; 68 69 DeviceMemoryBase Allocate(uint64 size, int64 memory_space) override; 70 71 Status AllocateEvent(Event* event) override; 72 73 bool AllocateStream(Stream* stream) override; 74 75 bool AllocateTimer(Timer* timer) override; 76 77 Status BlockHostUntilDone(::stream_executor::Stream* stream) override; 78 79 Status BlockUntilDoneOrFailed(); 80 81 StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>> 82 CreateDeviceDescription() const override; 83 84 bool CreateStreamDependency(Stream* dependent, Stream* other) override; 85 86 void DeallocateStream(Stream* stream) override; 87 88 void Deallocate(const DeviceMemoryBase& memory); 89 90 void Deallocate(DeviceMemoryBase* memory) override; 91 92 Status DeallocateEvent(Event* event) override; 93 94 void DeallocateTimer(Timer* timer) override; 95 96 bool DeviceMemoryUsage(int64* free, int64* total) const override; 97 98 void DequeueOutfeed(int32 outfeed_queue_index, absl::Span<uint8> bytes, 99 StatusCallback done); 100 101 Status EnqueueInfeed(int32 infeed_queue_index, absl::Span<const uint8> bytes); 102 103 absl::optional<stream_executor::AllocatorStats> GetAllocatorStats() override; 104 105 tpu::TpuCoreLocationExternal GetCoreLocationExternal() const override; 106 107 Status GetStatus(Stream* stream) override; 108 109 std::unique_ptr<::stream_executor::internal::StreamInterface> 110 GetStreamImplementation() override; 111 112 std::unique_ptr<::stream_executor::internal::TimerInterface> 113 GetTimerImplementation() override; 114 115 std::unique_ptr<::stream_executor::internal::EventInterface> 116 CreateEventImplementation() override; 117 118 bool HostCallback(Stream* stream, std::function<Status()> callback) override; 119 120 bool Memcpy(Stream* stream, void* host_dst, 121 const ::stream_executor::DeviceMemoryBase& device_src, 122 uint64 size) override; 123 124 bool Memcpy(Stream* stream, ::stream_executor::DeviceMemoryBase* device_dst, 125 const void* host_src, uint64 size) override; 126 127 bool MemcpyDeviceToDevice(Stream* stream, 128 ::stream_executor::DeviceMemoryBase* gpu_dst, 129 const ::stream_executor::DeviceMemoryBase& host_src, 130 uint64 size) override; 131 132 void SyncAndForgetFailedStreams(); 133 bool SynchronizeAllActivity() override; 134 135 Status SynchronousMemcpy(::stream_executor::DeviceMemoryBase* device_dst, 136 const void* host_src, uint64 size) override; 137 Status SynchronousMemcpy( 138 void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src, 139 uint64 size) override; 140 Status SynchronousMemcpyDeviceToDevice( 141 ::stream_executor::DeviceMemoryBase* device_dst, 142 const ::stream_executor::DeviceMemoryBase& device_src, 143 uint64 size) override; 144 145 int PlatformDeviceCount() override; 146 147 Event::Status PollForEventStatus(Event* event) override; 148 Status RecordEvent(Stream* stream, ::stream_executor::Event* event) override; 149 Status WaitForEvent(Stream* stream, ::stream_executor::Event* event) override; 150 151 bool StartTimer(Stream* stream, ::stream_executor::Timer* timer) override; 152 bool StopTimer(Stream* stream, ::stream_executor::Timer* timer) override; 153 154 Status WaitForInfeedReady(int32 infeed_queue_index); 155 156 Status WaitForOutfeedReady(int32 outfeed_queue_index); 157 platform()158 const ::tensorflow::tpu::TpuPlatformInterface& platform() const override { 159 return *platform_; 160 } 161 platform()162 ::tensorflow::tpu::TpuPlatformInterface& platform() override { 163 return *platform_; 164 } 165 166 // TODO(henrytan): convert this to override once the base interface is changed 167 // to TpuExecutorInterface. 168 StatusOr<std::unique_ptr< 169 tensorflow::tpu::TpuExecutorInterface::TemporaryDeviceMemory>> CreateTemporaryDeviceMemory(int64 memory_space,int64 byte_offset,int64 size)170 CreateTemporaryDeviceMemory(int64 memory_space, int64 byte_offset, 171 int64 size) override { 172 LOG(FATAL) << "Unimplemented."; 173 } 174 175 // -- Unimplemented (stubbed out) methods. 176 std::unique_ptr<stream_executor::internal::KernelInterface> CreateKernelImplementation()177 CreateKernelImplementation() override { 178 LOG(FATAL) << "Not yet implemented"; 179 } 180 GetSubBuffer(DeviceMemoryBase * parent,uint64 offset,uint64 size)181 void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset, 182 uint64 size) override { 183 LOG(FATAL) << "not yet implemented"; 184 } MemZero(Stream * stream,DeviceMemoryBase * location,uint64 size)185 Status MemZero(Stream* stream, DeviceMemoryBase* location, 186 uint64 size) override { 187 LOG(FATAL) << "not yet implemented"; 188 } Memset32(Stream * stream,DeviceMemoryBase * location,uint32 pattern,uint64 size)189 Status Memset32(Stream* stream, DeviceMemoryBase* location, uint32 pattern, 190 uint64 size) override { 191 LOG(FATAL) << "not yet implemented"; 192 } EnablePeerAccessTo(StreamExecutorInterface * other)193 Status EnablePeerAccessTo(StreamExecutorInterface* other) override { 194 LOG(FATAL) << "not yet implemented"; 195 } CanEnablePeerAccessTo(StreamExecutorInterface * other)196 bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override { 197 LOG(FATAL) << "not yet implemented"; 198 } 199 HostMemoryAllocate(uint64 size)200 void* HostMemoryAllocate(uint64 size) override { 201 LOG(FATAL) << "not yet implemented"; 202 } HostMemoryDeallocate(void * mem)203 void HostMemoryDeallocate(void* mem) override { 204 LOG(FATAL) << "not yet implemented"; 205 } HostMemoryRegister(void * mem,uint64 size)206 bool HostMemoryRegister(void* mem, uint64 size) override { 207 LOG(FATAL) << "not yet implemented"; 208 } HostMemoryUnregister(void * mem)209 bool HostMemoryUnregister(void* mem) override { 210 LOG(FATAL) << "not yet implemented"; 211 } SynchronousMemZero(DeviceMemoryBase * location,uint64 size)212 Status SynchronousMemZero(DeviceMemoryBase* location, uint64 size) override { 213 LOG(FATAL) << "not yet implemented"; 214 } SynchronousMemSet(DeviceMemoryBase * location,int value,uint64 size)215 Status SynchronousMemSet(DeviceMemoryBase* location, int value, 216 uint64 size) override { 217 LOG(FATAL) << "not yet implemented"; 218 } 219 se_executor()220 SE_StreamExecutor* se_executor() { return executor_; } 221 222 private: tpu_platform()223 TpuPlatform& tpu_platform() { 224 return *(tensorflow::down_cast<TpuPlatform*>(platform_)); 225 } 226 stream_map()227 TpuPlatform::StreamMap& stream_map() { 228 return *(tpu_platform().stream_map()); 229 } 230 get_stream(StreamInterface * ptr)231 SE_Stream* get_stream(StreamInterface* ptr) { 232 tensorflow::mutex_lock m(tpu_platform().mutex()); 233 return stream_map()[ptr]; 234 } 235 236 TimerMap timer_map_; 237 tensorflow::tpu::TpuPlatformInterface* platform_; 238 SE_StreamExecutor* executor_; 239 }; 240 241 } // namespace tpu 242 } // namespace tensorflow 243 244 #endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_ 245