• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Declares the XlaInterpreterExecutor class, which is a CPU-only implementation
17 // of the StreamExecutor interface. For now, this is used for testing and to
18 // examine the performance of host-based StreamExecutor code.
19 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_
20 #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_
21 
22 #include <functional>
23 #include <memory>
24 
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/stream_executor/blas.h"
29 #include "tensorflow/stream_executor/device_description.h"
30 #include "tensorflow/stream_executor/device_memory.h"
31 #include "tensorflow/stream_executor/device_options.h"
32 #include "tensorflow/stream_executor/event.h"
33 #include "tensorflow/stream_executor/host/host_stream.h"
34 #include "tensorflow/stream_executor/host/host_timer.h"
35 #include "tensorflow/stream_executor/kernel.h"
36 #include "tensorflow/stream_executor/kernel_spec.h"
37 #include "tensorflow/stream_executor/launch_dim.h"
38 #include "tensorflow/stream_executor/plugin.h"
39 #include "tensorflow/stream_executor/rng.h"
40 #include "tensorflow/stream_executor/stream.h"
41 #include "tensorflow/stream_executor/stream_executor.h"
42 #include "tensorflow/stream_executor/stream_executor_internal.h"
43 #include "tensorflow/stream_executor/timer.h"
44 
45 namespace stream_executor {
46 namespace interpreter {
47 
// Shorthand for a span (absl::Span) over const DeviceMemoryBase elements.
// NOTE(review): not referenced by any declaration visible in this header;
// presumably used by the out-of-line implementation -- confirm before removing.
48 using Args = absl::Span<const DeviceMemoryBase>;
49 
50 class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
51  public:
52   explicit XlaInterpreterExecutor(const PluginConfig &plugin_config);
53   ~XlaInterpreterExecutor() override;
54 
Init(int device_ordinal,DeviceOptions device_options)55   port::Status Init(int device_ordinal, DeviceOptions device_options) override {
56     return ::tensorflow::OkStatus();
57   }
58 
GetKernel(const MultiKernelLoaderSpec & spec,KernelBase * kernel)59   port::Status GetKernel(const MultiKernelLoaderSpec &spec,
60                          KernelBase *kernel) override {
61     return port::UnimplementedError("Not Implemented");
62   }
Launch(Stream * stream,const ThreadDim & thread_dims,const BlockDim & block_dims,const KernelBase & kernel,const KernelArgsArrayBase & args)63   port::Status Launch(Stream *stream, const ThreadDim &thread_dims,
64                       const BlockDim &block_dims, const KernelBase &kernel,
65                       const KernelArgsArrayBase &args) override {
66     return port::UnimplementedError("Not Implemented");
67   }
68 
69   DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override;
70   void *GetSubBuffer(DeviceMemoryBase *parent, uint64_t offset_bytes,
71                      uint64_t size_bytes) override;
72   void Deallocate(DeviceMemoryBase *mem) override;
73 
HostMemoryAllocate(uint64_t size)74   void *HostMemoryAllocate(uint64_t size) override { return new char[size]; }
HostMemoryDeallocate(void * mem)75   void HostMemoryDeallocate(void *mem) override {
76     delete[] static_cast<char *>(mem);
77   }
HostMemoryRegister(void * mem,uint64_t size)78   bool HostMemoryRegister(void *mem, uint64_t size) override { return true; }
HostMemoryUnregister(void * mem)79   bool HostMemoryUnregister(void *mem) override { return true; }
80 
81   bool Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &dev_src,
82               uint64_t size) override;
83   bool Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, const void *host_src,
84               uint64_t size) override;
MemcpyDeviceToDevice(Stream * stream,DeviceMemoryBase * pop_dst,const DeviceMemoryBase & host_src,uint64_t size)85   bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst,
86                             const DeviceMemoryBase &host_src,
87                             uint64_t size) override {
88     return false;
89   }
90 
MemZero(Stream * stream,DeviceMemoryBase * location,uint64_t size)91   port::Status MemZero(Stream *stream, DeviceMemoryBase *location,
92                        uint64_t size) override {
93     return port::InternalError("Interpreter can not memzero");
94   }
Memset(Stream * stream,DeviceMemoryBase * location,uint8_t pattern,uint64_t size)95   port::Status Memset(Stream *stream, DeviceMemoryBase *location,
96                       uint8_t pattern, uint64_t size) override {
97     return port::InternalError("Interpreter can not memset");
98   }
Memset32(Stream * stream,DeviceMemoryBase * location,uint32_t pattern,uint64_t size)99   port::Status Memset32(Stream *stream, DeviceMemoryBase *location,
100                         uint32_t pattern, uint64_t size) override {
101     return port::InternalError("Interpreter can not memset");
102   }
103 
104   // No "synchronize all activity" implemented for this platform at the moment.
SynchronizeAllActivity()105   bool SynchronizeAllActivity() override { return true; }
SynchronousMemZero(DeviceMemoryBase * location,uint64_t size)106   port::Status SynchronousMemZero(DeviceMemoryBase *location,
107                                   uint64_t size) override {
108     return port::InternalError("Interpreter can not memzero");
109   }
110 
SynchronousMemSet(DeviceMemoryBase * location,int value,uint64_t size)111   port::Status SynchronousMemSet(DeviceMemoryBase *location, int value,
112                                  uint64_t size) override {
113     return port::InternalError("Interpreter can not memset");
114   }
115 
116   port::Status SynchronousMemcpy(DeviceMemoryBase *dev_dst,
117                                  const void *host_src, uint64_t size) override;
118   port::Status SynchronousMemcpy(void *host_dst,
119                                  const DeviceMemoryBase &dev_src,
120                                  uint64_t size) override;
SynchronousMemcpyDeviceToDevice(DeviceMemoryBase * pop_dst,const DeviceMemoryBase & pop_src,uint64_t size)121   port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst,
122                                                const DeviceMemoryBase &pop_src,
123                                                uint64_t size) override {
124     return port::Status{port::error::UNIMPLEMENTED, ""};
125   }
126 
127   bool HostCallback(Stream *stream,
128                     std::function<port::Status()> callback) override;
129 
AllocateEvent(Event * event)130   port::Status AllocateEvent(Event *event) override {
131     return ::tensorflow::OkStatus();
132   }
133 
DeallocateEvent(Event * event)134   port::Status DeallocateEvent(Event *event) override {
135     return ::tensorflow::OkStatus();
136   }
137 
RecordEvent(Stream * stream,Event * event)138   port::Status RecordEvent(Stream *stream, Event *event) override {
139     return port::Status{port::error::UNIMPLEMENTED, "RecordEvent"};
140   }
141 
WaitForEvent(Stream * stream,Event * event)142   port::Status WaitForEvent(Stream *stream, Event *event) override {
143     return port::Status{port::error::UNIMPLEMENTED, "WaitForEvent"};
144   }
145 
PollForEventStatus(Event * event)146   Event::Status PollForEventStatus(Event *event) override {
147     return Event::Status::kError;
148   }
149 
AllocateStream(Stream * stream)150   bool AllocateStream(Stream *stream) override { return true; }
DeallocateStream(Stream * stream)151   void DeallocateStream(Stream *stream) override {}
152   bool CreateStreamDependency(Stream *dependent, Stream *other) override;
153 
AllocateTimer(Timer * timer)154   bool AllocateTimer(Timer *timer) override { return true; }
DeallocateTimer(Timer * timer)155   void DeallocateTimer(Timer *timer) override {}
156   bool StartTimer(Stream *stream, Timer *timer) override;
157   bool StopTimer(Stream *stream, Timer *timer) override;
158 
159   port::Status BlockHostUntilDone(Stream *stream) override;
160 
PlatformDeviceCount()161   int PlatformDeviceCount() override { return 1; }
162 
DeviceMemoryUsage(int64_t * free,int64_t * total)163   bool DeviceMemoryUsage(int64_t *free, int64_t *total) const override {
164     return false;
165   }
166 
CreateDeviceDescription()167   port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
168       const override {
169     return CreateDeviceDescription(0);
170   }
171 
172   static port::StatusOr<std::unique_ptr<DeviceDescription>>
173   CreateDeviceDescription(int device_ordinal);
174 
EnablePeerAccessTo(StreamExecutorInterface * other)175   port::Status EnablePeerAccessTo(StreamExecutorInterface *other) override {
176     return ::tensorflow::OkStatus();
177   }
178 
CanEnablePeerAccessTo(StreamExecutorInterface * other)179   bool CanEnablePeerAccessTo(StreamExecutorInterface *other) override {
180     return true;
181   }
182 
CreateEventImplementation()183   std::unique_ptr<internal::EventInterface> CreateEventImplementation()
184       override {
185     return nullptr;
186   }
187 
CreateKernelImplementation()188   std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
189       override {
190     return nullptr;
191   }
192 
GetStreamImplementation()193   std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
194       override {
195     return std::unique_ptr<internal::StreamInterface>(
196         new host::HostStream(/*thread_stack_size=*/0));
197   }
198 
GetTimerImplementation()199   std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
200     return std::unique_ptr<internal::TimerInterface>(new host::HostTimer());
201   }
202 
203  private:
204   DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape);
205 
206   port::StatusOr<DeviceMemoryBase> AllocateOutputBuffer(
207       const xla::Shape &shape);
208 
209   const PluginConfig plugin_config_;
210 };
211 
212 }  // namespace interpreter
213 }  // namespace stream_executor
214 
215 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_
216