1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Declares the XlaInterpreterExecutor class, which is a CPU-only implementation 17 // of the StreamExecutor interface. For now, this is used for testing and to 18 // examine the performance of host-based StreamExecutor code. 19 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_ 20 #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_ 21 22 #include <functional> 23 #include <memory> 24 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/shape_util.h" 27 #include "tensorflow/compiler/xla/xla_data.pb.h" 28 #include "tensorflow/stream_executor/blas.h" 29 #include "tensorflow/stream_executor/device_description.h" 30 #include "tensorflow/stream_executor/device_memory.h" 31 #include "tensorflow/stream_executor/device_options.h" 32 #include "tensorflow/stream_executor/event.h" 33 #include "tensorflow/stream_executor/host/host_stream.h" 34 #include "tensorflow/stream_executor/host/host_timer.h" 35 #include "tensorflow/stream_executor/kernel.h" 36 #include "tensorflow/stream_executor/kernel_spec.h" 37 #include "tensorflow/stream_executor/launch_dim.h" 38 #include "tensorflow/stream_executor/plugin.h" 39 #include "tensorflow/stream_executor/rng.h" 40 #include "tensorflow/stream_executor/stream.h" 41 #include "tensorflow/stream_executor/stream_executor.h" 42 #include "tensorflow/stream_executor/stream_executor_internal.h" 43 #include "tensorflow/stream_executor/timer.h" 44 45 namespace stream_executor { 46 namespace interpreter { 47 48 using Args = absl::Span<const DeviceMemoryBase>; 49 50 class XlaInterpreterExecutor : public internal::StreamExecutorInterface { 51 public: 52 explicit XlaInterpreterExecutor(const PluginConfig &plugin_config); 53 ~XlaInterpreterExecutor() override; 54 Init(int device_ordinal,DeviceOptions device_options)55 port::Status Init(int device_ordinal, DeviceOptions device_options) override { 56 return ::tensorflow::OkStatus(); 57 } 58 GetKernel(const MultiKernelLoaderSpec & spec,KernelBase * kernel)59 port::Status GetKernel(const MultiKernelLoaderSpec &spec, 60 KernelBase *kernel) override { 61 return port::UnimplementedError("Not Implemented"); 62 } Launch(Stream * stream,const ThreadDim & thread_dims,const BlockDim & block_dims,const KernelBase & kernel,const KernelArgsArrayBase & args)63 port::Status Launch(Stream *stream, const ThreadDim &thread_dims, 64 const BlockDim &block_dims, const KernelBase &kernel, 65 const KernelArgsArrayBase &args) override { 66 return port::UnimplementedError("Not Implemented"); 67 } 68 69 DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; 70 void *GetSubBuffer(DeviceMemoryBase *parent, uint64_t offset_bytes, 71 uint64_t size_bytes) override; 72 void Deallocate(DeviceMemoryBase *mem) override; 73 HostMemoryAllocate(uint64_t size)74 void *HostMemoryAllocate(uint64_t size) override { return new char[size]; } HostMemoryDeallocate(void * mem)75 void HostMemoryDeallocate(void *mem) override { 76 delete[] static_cast<char *>(mem); 77 } HostMemoryRegister(void * mem,uint64_t size)78 bool HostMemoryRegister(void *mem, uint64_t size) override { return true; } HostMemoryUnregister(void * mem)79 bool HostMemoryUnregister(void *mem) override { return true; } 80 81 bool Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &dev_src, 82 uint64_t size) override; 83 bool Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, const void *host_src, 84 uint64_t size) override; MemcpyDeviceToDevice(Stream * stream,DeviceMemoryBase * pop_dst,const DeviceMemoryBase & host_src,uint64_t size)85 bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst, 86 const DeviceMemoryBase &host_src, 87 uint64_t size) override { 88 return false; 89 } 90 MemZero(Stream * stream,DeviceMemoryBase * location,uint64_t size)91 port::Status MemZero(Stream *stream, DeviceMemoryBase *location, 92 uint64_t size) override { 93 return port::InternalError("Interpreter can not memzero"); 94 } Memset(Stream * stream,DeviceMemoryBase * location,uint8_t pattern,uint64_t size)95 port::Status Memset(Stream *stream, DeviceMemoryBase *location, 96 uint8_t pattern, uint64_t size) override { 97 return port::InternalError("Interpreter can not memset"); 98 } Memset32(Stream * stream,DeviceMemoryBase * location,uint32_t pattern,uint64_t size)99 port::Status Memset32(Stream *stream, DeviceMemoryBase *location, 100 uint32_t pattern, uint64_t size) override { 101 return port::InternalError("Interpreter can not memset"); 102 } 103 104 // No "synchronize all activity" implemented for this platform at the moment. SynchronizeAllActivity()105 bool SynchronizeAllActivity() override { return true; } SynchronousMemZero(DeviceMemoryBase * location,uint64_t size)106 port::Status SynchronousMemZero(DeviceMemoryBase *location, 107 uint64_t size) override { 108 return port::InternalError("Interpreter can not memzero"); 109 } 110 SynchronousMemSet(DeviceMemoryBase * location,int value,uint64_t size)111 port::Status SynchronousMemSet(DeviceMemoryBase *location, int value, 112 uint64_t size) override { 113 return port::InternalError("Interpreter can not memset"); 114 } 115 116 port::Status SynchronousMemcpy(DeviceMemoryBase *dev_dst, 117 const void *host_src, uint64_t size) override; 118 port::Status SynchronousMemcpy(void *host_dst, 119 const DeviceMemoryBase &dev_src, 120 uint64_t size) override; SynchronousMemcpyDeviceToDevice(DeviceMemoryBase * pop_dst,const DeviceMemoryBase & pop_src,uint64_t size)121 port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst, 122 const DeviceMemoryBase &pop_src, 123 uint64_t size) override { 124 return port::Status{port::error::UNIMPLEMENTED, ""}; 125 } 126 127 bool HostCallback(Stream *stream, 128 std::function<port::Status()> callback) override; 129 AllocateEvent(Event * event)130 port::Status AllocateEvent(Event *event) override { 131 return ::tensorflow::OkStatus(); 132 } 133 DeallocateEvent(Event * event)134 port::Status DeallocateEvent(Event *event) override { 135 return ::tensorflow::OkStatus(); 136 } 137 RecordEvent(Stream * stream,Event * event)138 port::Status RecordEvent(Stream *stream, Event *event) override { 139 return port::Status{port::error::UNIMPLEMENTED, "RecordEvent"}; 140 } 141 WaitForEvent(Stream * stream,Event * event)142 port::Status WaitForEvent(Stream *stream, Event *event) override { 143 return port::Status{port::error::UNIMPLEMENTED, "WaitForEvent"}; 144 } 145 PollForEventStatus(Event * event)146 Event::Status PollForEventStatus(Event *event) override { 147 return Event::Status::kError; 148 } 149 AllocateStream(Stream * stream)150 bool AllocateStream(Stream *stream) override { return true; } DeallocateStream(Stream * stream)151 void DeallocateStream(Stream *stream) override {} 152 bool CreateStreamDependency(Stream *dependent, Stream *other) override; 153 AllocateTimer(Timer * timer)154 bool AllocateTimer(Timer *timer) override { return true; } DeallocateTimer(Timer * timer)155 void DeallocateTimer(Timer *timer) override {} 156 bool StartTimer(Stream *stream, Timer *timer) override; 157 bool StopTimer(Stream *stream, Timer *timer) override; 158 159 port::Status BlockHostUntilDone(Stream *stream) override; 160 PlatformDeviceCount()161 int PlatformDeviceCount() override { return 1; } 162 DeviceMemoryUsage(int64_t * free,int64_t * total)163 bool DeviceMemoryUsage(int64_t *free, int64_t *total) const override { 164 return false; 165 } 166 CreateDeviceDescription()167 port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription() 168 const override { 169 return CreateDeviceDescription(0); 170 } 171 172 static port::StatusOr<std::unique_ptr<DeviceDescription>> 173 CreateDeviceDescription(int device_ordinal); 174 EnablePeerAccessTo(StreamExecutorInterface * other)175 port::Status EnablePeerAccessTo(StreamExecutorInterface *other) override { 176 return ::tensorflow::OkStatus(); 177 } 178 CanEnablePeerAccessTo(StreamExecutorInterface * other)179 bool CanEnablePeerAccessTo(StreamExecutorInterface *other) override { 180 return true; 181 } 182 CreateEventImplementation()183 std::unique_ptr<internal::EventInterface> CreateEventImplementation() 184 override { 185 return nullptr; 186 } 187 CreateKernelImplementation()188 std::unique_ptr<internal::KernelInterface> CreateKernelImplementation() 189 override { 190 return nullptr; 191 } 192 GetStreamImplementation()193 std::unique_ptr<internal::StreamInterface> GetStreamImplementation() 194 override { 195 return std::unique_ptr<internal::StreamInterface>( 196 new host::HostStream(/*thread_stack_size=*/0)); 197 } 198 GetTimerImplementation()199 std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override { 200 return std::unique_ptr<internal::TimerInterface>(new host::HostTimer()); 201 } 202 203 private: 204 DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape); 205 206 port::StatusOr<DeviceMemoryBase> AllocateOutputBuffer( 207 const xla::Shape &shape); 208 209 const PluginConfig plugin_config_; 210 }; 211 212 } // namespace interpreter 213 } // namespace stream_executor 214 215 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_ 216