/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/interpreter/executor.h"

#include <cstring>

#include "tensorflow/compiler/xla/status_macros.h"

namespace stream_executor {
namespace interpreter {

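// Returns the host::HostStream backing `stream`. The interpreter reuses the
// host (CPU) stream implementation, so the downcast is expected to succeed.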
host::HostStream *AsExecutorStream(Stream *stream) {
  DCHECK(stream != nullptr);
  return dynamic_cast<host::HostStream *>(stream->implementation());
}

XlaInterpreterExecutor::XlaInterpreterExecutor(
    const PluginConfig &plugin_config)
    : plugin_config_(plugin_config) {}

XlaInterpreterExecutor::~XlaInterpreterExecutor() {}

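// "Device" memory for the interpreter is plain host heap memory: allocation
// is a new[] and deallocation a matching delete[].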
DeviceMemoryBase XlaInterpreterExecutor::Allocate(uint64_t size,
                                                  int64_t memory_space) {
  return DeviceMemoryBase(new char[size], size);
}

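// Returns a pointer `offset_bytes` into the allocation that `parent` wraps.
// The size argument is unused because no bounds checking is performed.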
void *XlaInterpreterExecutor::GetSubBuffer(DeviceMemoryBase *parent,
                                           uint64_t offset_bytes,
                                           uint64_t /*size_bytes*/) {
  // Offset into the underlying allocation, not the DeviceMemoryBase object
  // itself (pointer arithmetic on `parent` would step in units of
  // sizeof(DeviceMemoryBase) and point outside the buffer).
  return reinterpret_cast<char *>(parent->opaque()) + offset_bytes;
}

void XlaInterpreterExecutor::Deallocate(DeviceMemoryBase *mem) {
  delete[] static_cast<char *>(mem->opaque());
}

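// "Asynchronous" copies are emulated by enqueueing a synchronous memcpy on
// the backing host stream and then draining the stream.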
bool XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst,
                                    const DeviceMemoryBase &dev_src,
                                    uint64_t size) {
  AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() {
    // Ignore errors.
    port::Status ok = SynchronousMemcpy(host_dst, dev_src, size);
  });
  port::Status status = AsExecutorStream(stream)->BlockUntilDone();
  if (status.ok()) {
    return true;
  }

  // TODO(b/199316985): Return 'Status' instead of 'bool', so we don't need to
  // throw away error information here.
  LOG(WARNING) << "Memcpy: error on stream: " << status;
  return false;
}

bool XlaInterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst,
                                    const void *host_src, uint64_t size) {
  AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() {
    // Ignore errors.
    port::Status ok = SynchronousMemcpy(dev_dst, host_src, size);
  });
  port::Status status = AsExecutorStream(stream)->BlockUntilDone();
  if (status.ok()) {
    return true;
  }

  // TODO(b/199316985): Return 'Status' instead of 'bool', so we don't need to
  // throw away error information here.
  LOG(WARNING) << "Memcpy: error on stream: " << status;
  return false;
}

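// Because "device" buffers live in host memory, synchronous transfers in
// either direction reduce to a plain memcpy.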
port::Status XlaInterpreterExecutor::SynchronousMemcpy(
    DeviceMemoryBase *dev_dst, const void *host_src, uint64_t size) {
  memcpy(dev_dst->opaque(), host_src, size);
  return ::tensorflow::OkStatus();
}

port::Status XlaInterpreterExecutor::SynchronousMemcpy(
    void *host_dst, const DeviceMemoryBase &dev_src, uint64_t size) {
  memcpy(host_dst, dev_src.opaque(), size);
  return ::tensorflow::OkStatus();
}

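// Enqueues `callback` on the backing host stream so it runs in order with the
// other queued work; its returned status is handled by the stream.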
bool XlaInterpreterExecutor::HostCallback(
    Stream *stream, std::function<port::Status()> callback) {
  AsExecutorStream(stream)->EnqueueTaskWithStatus(callback);
  return true;
}

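// Makes `dependent` wait for `other` by enqueueing a task on `dependent` that
// blocks until `other` has drained, then draining `dependent` itself.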
bool XlaInterpreterExecutor::CreateStreamDependency(Stream *dependent,
                                                    Stream *other) {
  AsExecutorStream(dependent)->EnqueueTask(
      [other]() { return other->BlockHostUntilDone(); });
  port::Status status = AsExecutorStream(dependent)->BlockUntilDone();
  if (status.ok()) {
    return true;
  }

  // TODO(b/199316985): Return 'Status' instead of 'bool', so we don't need to
  // throw away error information here.
  LOG(WARNING) << "CreateStreamDependency: error on stream: " << status;
  return false;
}

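// Timers are delegated to the host (CPU) timer implementation.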
bool XlaInterpreterExecutor::StartTimer(Stream *stream, Timer *timer) {
  dynamic_cast<host::HostTimer *>(timer->implementation())->Start(stream);
  return true;
}

bool XlaInterpreterExecutor::StopTimer(Stream *stream, Timer *timer) {
  dynamic_cast<host::HostTimer *>(timer->implementation())->Stop(stream);
  return true;
}

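// Blocks until all work queued on the backing host stream has completed.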
port::Status XlaInterpreterExecutor::BlockHostUntilDone(Stream *stream) {
  return AsExecutorStream(stream)->BlockUntilDone();
}

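// Describes the interpreter "device": a 64-bit address space, a nominal 4 GiB
// of device memory, and a clock rate derived from CLOCKS_PER_SEC.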
port::StatusOr<std::unique_ptr<DeviceDescription>>
XlaInterpreterExecutor::CreateDeviceDescription(int device_ordinal) {
  internal::DeviceDescriptionBuilder builder;

  builder.set_device_address_bits(64);

  builder.set_name("Interpreter");
  builder.set_device_memory_size(static_cast<uint64_t>(4) * 1024 * 1024 * 1024);
  builder.set_clock_rate_ghz(static_cast<float>(CLOCKS_PER_SEC) / 1e9);

  return builder.Build();
}

}  // namespace interpreter
}  // namespace stream_executor