• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16 
17 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/memory/memory.h"
23 #include "absl/strings/string_view.h"
24 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
25 #include "tensorflow/compiler/xla/client/local_client.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/map_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
29 #include "tensorflow/compiler/xla/service/hlo_parser.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/test_helpers.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/core/threadpool.h"
36 #include "tensorflow/core/platform/byte_order.h"
37 #include "tensorflow/core/platform/env.h"
38 #include "tensorflow/core/platform/logging.h"
39 
40 namespace xla {
41 
42 /* static */ TestAllocator* LocalClientTestBase::allocator_;
43 
Allocate(int device_ordinal,uint64 size,bool retry_on_failure,int64 memory_space)44 StatusOr<se::OwningDeviceMemory> TestAllocator::Allocate(int device_ordinal,
45                                                          uint64 size,
46                                                          bool retry_on_failure,
47                                                          int64 memory_space) {
48   VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")";
49   {
50     tensorflow::mutex_lock lock(count_mutex_);
51     allocation_count_++;
52     device_allocation_count_[device_ordinal]++;
53   }
54   return se::StreamExecutorMemoryAllocator::Allocate(
55       device_ordinal, size, retry_on_failure, memory_space);
56 }
57 
Deallocate(int device_ordinal,se::DeviceMemoryBase mem)58 Status TestAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
59   VLOG(2) << "Deallocate(" << device_ordinal << ")";
60   {
61     tensorflow::mutex_lock lock(count_mutex_);
62     deallocation_count_++;
63     device_deallocation_count_[device_ordinal]++;
64   }
65   return se::StreamExecutorMemoryAllocator::Deallocate(device_ordinal, mem);
66 }
67 
allocation_count() const68 int64 TestAllocator::allocation_count() const {
69   tensorflow::mutex_lock lock(count_mutex_);
70   return allocation_count_;
71 }
72 
allocation_count(int device_ordinal) const73 int64 TestAllocator::allocation_count(int device_ordinal) const {
74   tensorflow::mutex_lock lock(count_mutex_);
75   auto it = device_allocation_count_.find(device_ordinal);
76   if (it == device_allocation_count_.end()) {
77     return 0;
78   } else {
79     return it->second;
80   }
81 }
82 
deallocation_count() const83 int64 TestAllocator::deallocation_count() const {
84   tensorflow::mutex_lock lock(count_mutex_);
85   return deallocation_count_;
86 }
87 
deallocation_count(int device_ordinal) const88 int64 TestAllocator::deallocation_count(int device_ordinal) const {
89   tensorflow::mutex_lock lock(count_mutex_);
90   auto it = device_deallocation_count_.find(device_ordinal);
91   if (it == device_deallocation_count_.end()) {
92     return 0;
93   } else {
94     return it->second;
95   }
96 }
97 
GetOrCreateAllocator(se::Platform * platform)98 /* static */ TestAllocator* LocalClientTestBase::GetOrCreateAllocator(
99     se::Platform* platform) {
100   static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
101   tensorflow::mutex_lock lock(mu);
102 
103   if (allocator_ == nullptr) {
104     allocator_ = new TestAllocator(
105         platform == nullptr ? PlatformUtil::GetDefaultPlatform().ValueOrDie()
106                             : platform);
107   }
108   return allocator_;
109 }
110 
111 // Define this in .cc file to avoid having to include eigen or forward declare
112 // these types in the header.
113 struct LocalClientTestBase::EigenThreadPoolWrapper {
EigenThreadPoolWrapperxla::LocalClientTestBase::EigenThreadPoolWrapper114   explicit EigenThreadPoolWrapper()
115       : pool(new tensorflow::thread::ThreadPool(
116             tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)),
117         device(new Eigen::ThreadPoolDevice(pool->AsEigenThreadPool(),
118                                            pool->NumThreads())) {}
119 
120   std::unique_ptr<tensorflow::thread::ThreadPool> pool;
121   std::unique_ptr<Eigen::ThreadPoolDevice> device;
122 };
123 
LocalClientTestBase(se::Platform * platform)124 LocalClientTestBase::LocalClientTestBase(se::Platform* platform)
125     : local_client_(
126           ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()),
127       thread_pool_wrapper_(new EigenThreadPoolWrapper()) {
128   // Take the first executor, since it's the default one.
129   stream_executor_ = PlatformUtil::GetStreamExecutors(local_client_->platform())
130                          .ValueOrDie()
131                          .front();
132   transfer_manager_ =
133       TransferManager::GetForPlatform(local_client_->platform()).ValueOrDie();
134 }
135 
~LocalClientTestBase()136 LocalClientTestBase::~LocalClientTestBase() {}
137 
LiteralToShapedBuffer(const Literal & literal)138 ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer(
139     const Literal& literal) {
140   return local_client_
141       ->LiteralToShapedBuffer(literal, local_client_->default_device_ordinal())
142       .ConsumeValueOrDie();
143 }
144 
ShapedBufferToLiteral(const ShapedBuffer & shaped_buffer)145 Literal LocalClientTestBase::ShapedBufferToLiteral(
146     const ShapedBuffer& shaped_buffer) {
147   return local_client_->ShapedBufferToLiteral(shaped_buffer)
148       .ConsumeValueOrDie();
149 }
150 
DefaultExecutableBuildOptions() const151 ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions()
152     const {
153   return ExecutableBuildOptions();
154 }
155 
DefaultExecutableRunOptions() const156 ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const {
157   ExecutableRunOptions run_options;
158   run_options.set_intra_op_thread_pool(thread_pool_wrapper_->device.get());
159   run_options.set_allocator(GetOrCreateAllocator(local_client_->platform()));
160   return run_options;
161 }
162 
ExecuteLocallyOrDie(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments)163 ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
164     const XlaComputation& computation,
165     absl::Span<const ShapedBuffer* const> arguments) {
166   return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
167                         DefaultExecutableRunOptions())
168       .ConsumeValueOrDie();
169 }
170 
ExecuteLocallyOrDie(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments,const ExecutableBuildOptions & build_options,const ExecutableRunOptions & run_options)171 ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
172     const XlaComputation& computation,
173     absl::Span<const ShapedBuffer* const> arguments,
174     const ExecutableBuildOptions& build_options,
175     const ExecutableRunOptions& run_options) {
176   return ExecuteLocally(computation, arguments, build_options, run_options)
177       .ConsumeValueOrDie();
178 }
179 
ExecuteLocally(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments)180 StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
181     const XlaComputation& computation,
182     absl::Span<const ShapedBuffer* const> arguments) {
183   return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
184                         DefaultExecutableRunOptions());
185 }
186 
ExecuteLocally(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments,const ExecutableBuildOptions & build_options,const ExecutableRunOptions & run_options)187 StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
188     const XlaComputation& computation,
189     absl::Span<const ShapedBuffer* const> arguments,
190     const ExecutableBuildOptions& build_options,
191     const ExecutableRunOptions& run_options) {
192   std::vector<const Shape*> argument_layouts(arguments.size());
193   for (int i = 0; i < arguments.size(); ++i) {
194     argument_layouts[i] = &arguments[i]->on_host_shape();
195   }
196   TF_ASSIGN_OR_RETURN(
197       auto executables,
198       local_client_->Compile(computation, argument_layouts, build_options));
199   TF_RET_CHECK(executables.size() == 1);
200   TF_ASSIGN_OR_RETURN(auto ret, executables[0]->Run(arguments, run_options));
201 
202   auto device_ordinal =
203       build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal();
204   auto* stream = run_options.stream();
205   if (!stream) {
206     stream = local_client_->mutable_backend()
207                  ->BorrowStream(device_ordinal)
208                  .ValueOrDie()
209                  .get();
210   }
211   TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
212   return std::move(ret);
213 }
214 
215 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text)216 LocalClientTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) {
217   return ParseAndReturnVerifiedModule(hlo_text, HloModuleConfig());
218 }
219 
220 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,const HloModuleConfig & config)221 LocalClientTestBase::ParseAndReturnVerifiedModule(
222     absl::string_view hlo_text, const HloModuleConfig& config) {
223   auto module = absl::make_unique<VerifiedHloModule>(
224       TestName(), config, /*verifier_layout_sensitive=*/false,
225       /*allow_mixed_precision_in_hlo_verifier=*/true,
226       local_client_->backend().compiler()->ShapeSizeBytesFunction());
227   TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
228   return std::move(module);
229 }
230 
231 }  // namespace xla
232