1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16
17 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
18
19 #include <vector>
20
21 #include "absl/memory/memory.h"
22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/map_util.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/test_helpers.h"
29 #include "tensorflow/core/common_runtime/eigen_thread_pool.h"
30 #include "tensorflow/core/lib/core/threadpool.h"
31 #include "tensorflow/core/platform/byte_order.h"
32 #include "tensorflow/core/platform/env.h"
33 #include "tensorflow/core/platform/logging.h"
34
35 namespace xla {
36
37 /* static */ TestAllocator* LocalClientTestBase::allocator_;
38
Allocate(int device_ordinal,uint64 size,bool retry_on_failure)39 StatusOr<OwningDeviceMemory> TestAllocator::Allocate(int device_ordinal,
40 uint64 size,
41 bool retry_on_failure) {
42 VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")";
43 {
44 tensorflow::mutex_lock lock(count_mutex_);
45 allocation_count_++;
46 device_allocation_count_[device_ordinal]++;
47 }
48 return StreamExecutorMemoryAllocator::Allocate(device_ordinal, size,
49 retry_on_failure);
50 }
51
Deallocate(int device_ordinal,se::DeviceMemoryBase mem)52 Status TestAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
53 VLOG(2) << "Deallocate(" << device_ordinal << ")";
54 {
55 tensorflow::mutex_lock lock(count_mutex_);
56 deallocation_count_++;
57 device_deallocation_count_[device_ordinal]++;
58 }
59 return StreamExecutorMemoryAllocator::Deallocate(device_ordinal, mem);
60 }
61
allocation_count() const62 int64 TestAllocator::allocation_count() const {
63 tensorflow::mutex_lock lock(count_mutex_);
64 return allocation_count_;
65 }
66
allocation_count(int device_ordinal) const67 int64 TestAllocator::allocation_count(int device_ordinal) const {
68 tensorflow::mutex_lock lock(count_mutex_);
69 auto it = device_allocation_count_.find(device_ordinal);
70 if (it == device_allocation_count_.end()) {
71 return 0;
72 } else {
73 return it->second;
74 }
75 }
76
deallocation_count() const77 int64 TestAllocator::deallocation_count() const {
78 tensorflow::mutex_lock lock(count_mutex_);
79 return deallocation_count_;
80 }
81
deallocation_count(int device_ordinal) const82 int64 TestAllocator::deallocation_count(int device_ordinal) const {
83 tensorflow::mutex_lock lock(count_mutex_);
84 auto it = device_deallocation_count_.find(device_ordinal);
85 if (it == device_deallocation_count_.end()) {
86 return 0;
87 } else {
88 return it->second;
89 }
90 }
91
GetOrCreateAllocator(se::Platform * platform)92 /* static */ TestAllocator* LocalClientTestBase::GetOrCreateAllocator(
93 se::Platform* platform) {
94 static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
95 tensorflow::mutex_lock lock(mu);
96
97 if (allocator_ == nullptr) {
98 allocator_ = new TestAllocator(
99 platform == nullptr ? PlatformUtil::GetDefaultPlatform().ValueOrDie()
100 : platform);
101 }
102 return allocator_;
103 }
104
105 // Define this in .cc file to avoid having to include eigen or forward declare
106 // these types in the header.
107 struct LocalClientTestBase::EigenThreadPoolWrapper {
EigenThreadPoolWrapperxla::LocalClientTestBase::EigenThreadPoolWrapper108 explicit EigenThreadPoolWrapper()
109 : pool(new tensorflow::thread::ThreadPool(
110 tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)),
111 wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())),
112 device(new Eigen::ThreadPoolDevice(wrapper.get(),
113 wrapper->NumThreads())) {}
114
115 std::unique_ptr<tensorflow::thread::ThreadPool> pool;
116 std::unique_ptr<tensorflow::EigenThreadPoolWrapper> wrapper;
117 std::unique_ptr<Eigen::ThreadPoolDevice> device;
118 };
119
LocalClientTestBase(se::Platform * platform)120 LocalClientTestBase::LocalClientTestBase(se::Platform* platform)
121 : local_client_(
122 ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()),
123 thread_pool_wrapper_(new EigenThreadPoolWrapper()) {
124 stream_executor_ = PlatformUtil::GetStreamExecutors(local_client_->platform())
125 .ValueOrDie()[local_client_->default_device_ordinal()];
126 transfer_manager_ =
127 TransferManager::GetForPlatform(local_client_->platform()).ValueOrDie();
128 }
129
~LocalClientTestBase()130 LocalClientTestBase::~LocalClientTestBase() {}
131
LiteralToShapedBuffer(const Literal & literal)132 ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer(
133 const Literal& literal) {
134 return local_client_
135 ->LiteralToShapedBuffer(literal, local_client_->default_device_ordinal())
136 .ConsumeValueOrDie();
137 }
138
ShapedBufferToLiteral(const ShapedBuffer & shaped_buffer)139 Literal LocalClientTestBase::ShapedBufferToLiteral(
140 const ShapedBuffer& shaped_buffer) {
141 return local_client_->ShapedBufferToLiteral(shaped_buffer)
142 .ConsumeValueOrDie();
143 }
144
DefaultExecutableBuildOptions() const145 ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions()
146 const {
147 return ExecutableBuildOptions();
148 }
149
DefaultExecutableRunOptions() const150 ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const {
151 ExecutableRunOptions run_options;
152 run_options.set_intra_op_thread_pool(thread_pool_wrapper_->device.get());
153 run_options.set_allocator(GetOrCreateAllocator(local_client_->platform()));
154 return run_options;
155 }
156
ExecuteLocallyOrDie(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments)157 ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
158 const XlaComputation& computation,
159 absl::Span<const ShapedBuffer* const> arguments) {
160 return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
161 DefaultExecutableRunOptions())
162 .ConsumeValueOrDie();
163 }
164
ExecuteLocallyOrDie(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments,const ExecutableBuildOptions & build_options,const ExecutableRunOptions & run_options)165 ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
166 const XlaComputation& computation,
167 absl::Span<const ShapedBuffer* const> arguments,
168 const ExecutableBuildOptions& build_options,
169 const ExecutableRunOptions& run_options) {
170 return ExecuteLocally(computation, arguments, build_options, run_options)
171 .ConsumeValueOrDie();
172 }
173
ExecuteLocally(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments)174 StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
175 const XlaComputation& computation,
176 absl::Span<const ShapedBuffer* const> arguments) {
177 return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
178 DefaultExecutableRunOptions());
179 }
180
ExecuteLocally(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments,const ExecutableBuildOptions & build_options,const ExecutableRunOptions & run_options)181 StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
182 const XlaComputation& computation,
183 absl::Span<const ShapedBuffer* const> arguments,
184 const ExecutableBuildOptions& build_options,
185 const ExecutableRunOptions& run_options) {
186 std::vector<const Shape*> argument_layouts(arguments.size());
187 for (int i = 0; i < arguments.size(); ++i) {
188 argument_layouts[i] = &arguments[i]->on_host_shape();
189 }
190 TF_ASSIGN_OR_RETURN(
191 std::unique_ptr<LocalExecutable> executable,
192 local_client_->Compile(computation, argument_layouts, build_options));
193 TF_ASSIGN_OR_RETURN(auto ret, executable->Run(arguments, run_options));
194
195 auto device_ordinal =
196 build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal();
197 auto* stream = run_options.stream();
198 if (!stream) {
199 stream = local_client_->mutable_backend()
200 ->BorrowStream(device_ordinal)
201 .ValueOrDie()
202 .get();
203 }
204 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
205 return std::move(ret);
206 }
207
208 } // namespace xla
209