/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/interpreter/executor.h"

#include <cstring>

#include "tensorflow/compiler/xla/status_macros.h"

namespace stream_executor {
namespace interpreter {

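// Downcasts a generic Stream's platform-specific implementation to the
// host::HostStream that backs interpreter streams, so work can be enqueued
// on its worker thread.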
host::HostStream *AsExecutorStream(Stream *stream) {
  DCHECK(stream != nullptr);
  return dynamic_cast<host::HostStream *>(stream->implementation());
}

XlaInterpreterExecutor::XlaInterpreterExecutor(
    const PluginConfig &plugin_config)
    : plugin_config_(plugin_config) {}

XlaInterpreterExecutor::~XlaInterpreterExecutor() {}

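// "Device" memory for the interpreter is ordinary host heap memory; the
// memory_space argument is ignored.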
DeviceMemoryBase XlaInterpreterExecutor::Allocate(uint64_t size,
                                                  int64_t memory_space) {
  return DeviceMemoryBase(new char[size], size);
}

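// Returns a pointer offset_bytes into the parent allocation's underlying
// host buffer; the requested size is not checked against the parent's size.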
void *XlaInterpreterExecutor::GetSubBuffer(DeviceMemoryBase *parent,
                                           uint64_t offset_bytes,
                                           uint64_t /*size_bytes*/) {
  // Offset into the parent's underlying buffer rather than doing pointer
  // arithmetic on the DeviceMemoryBase handle itself.
  return reinterpret_cast<char *>(parent->opaque()) + offset_bytes;
}

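// Releases the host allocation backing `mem`.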
void XlaInterpreterExecutor::Deallocate(DeviceMemoryBase *mem) {
  delete[] static_cast<char *>(mem->opaque());
}

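// Device-to-host copy: the memcpy is enqueued as a task on the host stream
// and the stream is then drained, so the call is effectively synchronous.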
bool XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst,
                                    const DeviceMemoryBase &dev_src,
                                    uint64_t size) {
  AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() {
    // Ignore errors.
    port::Status ok = SynchronousMemcpy(host_dst, dev_src, size);
  });
  port::Status status = AsExecutorStream(stream)->BlockUntilDone();
  if (status.ok()) {
    return true;
  }

  // TODO(b/199316985): Return 'Status' instead of 'bool', so we don't need to
  // throw away error information here.
  LOG(WARNING) << "Memcpy: error on stream: " << status;
  return false;
}

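// Host-to-device copy; mirrors the device-to-host overload above.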
bool XlaInterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst,
                                    const void *host_src, uint64_t size) {
  AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() {
    // Ignore errors.
    port::Status ok = SynchronousMemcpy(dev_dst, host_src, size);
  });
  port::Status status = AsExecutorStream(stream)->BlockUntilDone();
  if (status.ok()) {
    return true;
  }

  // TODO(b/199316985): Return 'Status' instead of 'bool', so we don't need to
  // throw away error information here.
  LOG(WARNING) << "Memcpy: error on stream: " << status;
  return false;
}

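// Since "device" memory is host memory, synchronous copies are plain memcpys.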
port::Status XlaInterpreterExecutor::SynchronousMemcpy(
    DeviceMemoryBase *dev_dst, const void *host_src, uint64_t size) {
  memcpy(dev_dst->opaque(), host_src, size);
  return ::tensorflow::OkStatus();
}

port::Status XlaInterpreterExecutor::SynchronousMemcpy(
    void *host_dst, const DeviceMemoryBase &dev_src, uint64_t size) {
  memcpy(host_dst, dev_src.opaque(), size);
  return ::tensorflow::OkStatus();
}

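// Enqueues the callback on the host stream's work queue.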
bool XlaInterpreterExecutor::HostCallback(
    Stream *stream, std::function<port::Status()> callback) {
  AsExecutorStream(stream)->EnqueueTaskWithStatus(callback);
  return true;
}

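// Makes `dependent` wait for all work already enqueued on `other` by
// enqueueing a blocking wait and then draining `dependent`.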
bool XlaInterpreterExecutor::CreateStreamDependency(Stream *dependent,
                                                    Stream *other) {
  AsExecutorStream(dependent)->EnqueueTask(
      [other]() { return other->BlockHostUntilDone(); });
  port::Status status = AsExecutorStream(dependent)->BlockUntilDone();
  if (status.ok()) {
    return true;
  }

  // TODO(b/199316985): Return 'Status' instead of 'bool', so we don't need to
  // throw away error information here.
  LOG(WARNING) << "CreateStreamDependency: error on stream: " << status;
  return false;
}

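// Timers delegate to the host-backend timer implementation.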
bool XlaInterpreterExecutor::StartTimer(Stream *stream, Timer *timer) {
  dynamic_cast<host::HostTimer *>(timer->implementation())->Start(stream);
  return true;
}

bool XlaInterpreterExecutor::StopTimer(Stream *stream, Timer *timer) {
  dynamic_cast<host::HostTimer *>(timer->implementation())->Stop(stream);
  return true;
}

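// Drains the stream's pending work and returns any deferred error status.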
port::Status XlaInterpreterExecutor::BlockHostUntilDone(Stream *stream) {
  return AsExecutorStream(stream)->BlockUntilDone();
}

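// Describes the interpreter "device": 64-bit addressing, a nominal 4 GiB of
// device memory, and a clock rate derived from CLOCKS_PER_SEC.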
port::StatusOr<std::unique_ptr<DeviceDescription>>
XlaInterpreterExecutor::CreateDeviceDescription(int device_ordinal) {
  internal::DeviceDescriptionBuilder builder;

  builder.set_device_address_bits(64);

  builder.set_name("Interpreter");
  builder.set_device_memory_size(static_cast<uint64_t>(4) * 1024 * 1024 * 1024);
  builder.set_clock_rate_ghz(static_cast<float>(CLOCKS_PER_SEC) / 1e9);

  return builder.Build();
}

}  // namespace interpreter
}  // namespace stream_executor