/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include "tensorflow/compiler/xla/tests/local_client_test_base.h"

#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
39
40 namespace xla {
41
42 /* static */ TestAllocator* LocalClientTestBase::allocator_;
43
Allocate(int device_ordinal,uint64 size,bool retry_on_failure,int64 memory_space)44 StatusOr<se::OwningDeviceMemory> TestAllocator::Allocate(int device_ordinal,
45 uint64 size,
46 bool retry_on_failure,
47 int64 memory_space) {
48 VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")";
49 {
50 tensorflow::mutex_lock lock(count_mutex_);
51 allocation_count_++;
52 device_allocation_count_[device_ordinal]++;
53 }
54 return se::StreamExecutorMemoryAllocator::Allocate(
55 device_ordinal, size, retry_on_failure, memory_space);
56 }
57
Deallocate(int device_ordinal,se::DeviceMemoryBase mem)58 Status TestAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
59 VLOG(2) << "Deallocate(" << device_ordinal << ")";
60 {
61 tensorflow::mutex_lock lock(count_mutex_);
62 deallocation_count_++;
63 device_deallocation_count_[device_ordinal]++;
64 }
65 return se::StreamExecutorMemoryAllocator::Deallocate(device_ordinal, mem);
66 }
67
allocation_count() const68 int64 TestAllocator::allocation_count() const {
69 tensorflow::mutex_lock lock(count_mutex_);
70 return allocation_count_;
71 }
72
allocation_count(int device_ordinal) const73 int64 TestAllocator::allocation_count(int device_ordinal) const {
74 tensorflow::mutex_lock lock(count_mutex_);
75 auto it = device_allocation_count_.find(device_ordinal);
76 if (it == device_allocation_count_.end()) {
77 return 0;
78 } else {
79 return it->second;
80 }
81 }
82
deallocation_count() const83 int64 TestAllocator::deallocation_count() const {
84 tensorflow::mutex_lock lock(count_mutex_);
85 return deallocation_count_;
86 }
87
deallocation_count(int device_ordinal) const88 int64 TestAllocator::deallocation_count(int device_ordinal) const {
89 tensorflow::mutex_lock lock(count_mutex_);
90 auto it = device_deallocation_count_.find(device_ordinal);
91 if (it == device_deallocation_count_.end()) {
92 return 0;
93 } else {
94 return it->second;
95 }
96 }
97
GetOrCreateAllocator(se::Platform * platform)98 /* static */ TestAllocator* LocalClientTestBase::GetOrCreateAllocator(
99 se::Platform* platform) {
100 static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
101 tensorflow::mutex_lock lock(mu);
102
103 if (allocator_ == nullptr) {
104 allocator_ = new TestAllocator(
105 platform == nullptr ? PlatformUtil::GetDefaultPlatform().ValueOrDie()
106 : platform);
107 }
108 return allocator_;
109 }
110
111 // Define this in .cc file to avoid having to include eigen or forward declare
112 // these types in the header.
113 struct LocalClientTestBase::EigenThreadPoolWrapper {
EigenThreadPoolWrapperxla::LocalClientTestBase::EigenThreadPoolWrapper114 explicit EigenThreadPoolWrapper()
115 : pool(new tensorflow::thread::ThreadPool(
116 tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)),
117 device(new Eigen::ThreadPoolDevice(pool->AsEigenThreadPool(),
118 pool->NumThreads())) {}
119
120 std::unique_ptr<tensorflow::thread::ThreadPool> pool;
121 std::unique_ptr<Eigen::ThreadPoolDevice> device;
122 };
123
LocalClientTestBase(se::Platform * platform)124 LocalClientTestBase::LocalClientTestBase(se::Platform* platform)
125 : local_client_(
126 ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()),
127 thread_pool_wrapper_(new EigenThreadPoolWrapper()) {
128 // Take the first executor, since it's the default one.
129 stream_executor_ = PlatformUtil::GetStreamExecutors(local_client_->platform())
130 .ValueOrDie()
131 .front();
132 transfer_manager_ =
133 TransferManager::GetForPlatform(local_client_->platform()).ValueOrDie();
134 }
135
~LocalClientTestBase()136 LocalClientTestBase::~LocalClientTestBase() {}
137
LiteralToShapedBuffer(const Literal & literal)138 ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer(
139 const Literal& literal) {
140 return local_client_
141 ->LiteralToShapedBuffer(literal, local_client_->default_device_ordinal())
142 .ConsumeValueOrDie();
143 }
144
ShapedBufferToLiteral(const ShapedBuffer & shaped_buffer)145 Literal LocalClientTestBase::ShapedBufferToLiteral(
146 const ShapedBuffer& shaped_buffer) {
147 return local_client_->ShapedBufferToLiteral(shaped_buffer)
148 .ConsumeValueOrDie();
149 }
150
DefaultExecutableBuildOptions() const151 ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions()
152 const {
153 return ExecutableBuildOptions();
154 }
155
DefaultExecutableRunOptions() const156 ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const {
157 ExecutableRunOptions run_options;
158 run_options.set_intra_op_thread_pool(thread_pool_wrapper_->device.get());
159 run_options.set_allocator(GetOrCreateAllocator(local_client_->platform()));
160 return run_options;
161 }
162
ExecuteLocallyOrDie(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments)163 ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
164 const XlaComputation& computation,
165 absl::Span<const ShapedBuffer* const> arguments) {
166 return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
167 DefaultExecutableRunOptions())
168 .ConsumeValueOrDie();
169 }
170
ExecuteLocallyOrDie(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments,const ExecutableBuildOptions & build_options,const ExecutableRunOptions & run_options)171 ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
172 const XlaComputation& computation,
173 absl::Span<const ShapedBuffer* const> arguments,
174 const ExecutableBuildOptions& build_options,
175 const ExecutableRunOptions& run_options) {
176 return ExecuteLocally(computation, arguments, build_options, run_options)
177 .ConsumeValueOrDie();
178 }
179
ExecuteLocally(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments)180 StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
181 const XlaComputation& computation,
182 absl::Span<const ShapedBuffer* const> arguments) {
183 return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
184 DefaultExecutableRunOptions());
185 }
186
ExecuteLocally(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments,const ExecutableBuildOptions & build_options,const ExecutableRunOptions & run_options)187 StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
188 const XlaComputation& computation,
189 absl::Span<const ShapedBuffer* const> arguments,
190 const ExecutableBuildOptions& build_options,
191 const ExecutableRunOptions& run_options) {
192 std::vector<const Shape*> argument_layouts(arguments.size());
193 for (int i = 0; i < arguments.size(); ++i) {
194 argument_layouts[i] = &arguments[i]->on_host_shape();
195 }
196 TF_ASSIGN_OR_RETURN(
197 auto executables,
198 local_client_->Compile(computation, argument_layouts, build_options));
199 TF_RET_CHECK(executables.size() == 1);
200 TF_ASSIGN_OR_RETURN(auto ret, executables[0]->Run(arguments, run_options));
201
202 auto device_ordinal =
203 build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal();
204 auto* stream = run_options.stream();
205 if (!stream) {
206 stream = local_client_->mutable_backend()
207 ->BorrowStream(device_ordinal)
208 .ValueOrDie()
209 .get();
210 }
211 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
212 return std::move(ret);
213 }
214
215 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text)216 LocalClientTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) {
217 return ParseAndReturnVerifiedModule(hlo_text, HloModuleConfig());
218 }
219
220 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,const HloModuleConfig & config)221 LocalClientTestBase::ParseAndReturnVerifiedModule(
222 absl::string_view hlo_text, const HloModuleConfig& config) {
223 auto module = absl::make_unique<VerifiedHloModule>(
224 TestName(), config, /*verifier_layout_sensitive=*/false,
225 /*allow_mixed_precision_in_hlo_verifier=*/true,
226 local_client_->backend().compiler()->ShapeSizeBytesFunction());
227 TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
228 return std::move(module);
229 }
230
231 } // namespace xla
232