• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16 
17 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
18 
19 #include <vector>
20 
21 #include "absl/memory/memory.h"
22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/map_util.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/test_helpers.h"
29 #include "tensorflow/core/common_runtime/eigen_thread_pool.h"
30 #include "tensorflow/core/lib/core/threadpool.h"
31 #include "tensorflow/core/platform/byte_order.h"
32 #include "tensorflow/core/platform/env.h"
33 #include "tensorflow/core/platform/logging.h"
34 
35 namespace xla {
36 
37 /* static */ TestAllocator* LocalClientTestBase::allocator_;
38 
Allocate(int device_ordinal,uint64 size,bool retry_on_failure)39 StatusOr<OwningDeviceMemory> TestAllocator::Allocate(int device_ordinal,
40                                                      uint64 size,
41                                                      bool retry_on_failure) {
42   VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")";
43   {
44     tensorflow::mutex_lock lock(count_mutex_);
45     allocation_count_++;
46     device_allocation_count_[device_ordinal]++;
47   }
48   return StreamExecutorMemoryAllocator::Allocate(device_ordinal, size,
49                                                  retry_on_failure);
50 }
51 
Deallocate(int device_ordinal,se::DeviceMemoryBase mem)52 Status TestAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
53   VLOG(2) << "Deallocate(" << device_ordinal << ")";
54   {
55     tensorflow::mutex_lock lock(count_mutex_);
56     deallocation_count_++;
57     device_deallocation_count_[device_ordinal]++;
58   }
59   return StreamExecutorMemoryAllocator::Deallocate(device_ordinal, mem);
60 }
61 
allocation_count() const62 int64 TestAllocator::allocation_count() const {
63   tensorflow::mutex_lock lock(count_mutex_);
64   return allocation_count_;
65 }
66 
allocation_count(int device_ordinal) const67 int64 TestAllocator::allocation_count(int device_ordinal) const {
68   tensorflow::mutex_lock lock(count_mutex_);
69   auto it = device_allocation_count_.find(device_ordinal);
70   if (it == device_allocation_count_.end()) {
71     return 0;
72   } else {
73     return it->second;
74   }
75 }
76 
deallocation_count() const77 int64 TestAllocator::deallocation_count() const {
78   tensorflow::mutex_lock lock(count_mutex_);
79   return deallocation_count_;
80 }
81 
deallocation_count(int device_ordinal) const82 int64 TestAllocator::deallocation_count(int device_ordinal) const {
83   tensorflow::mutex_lock lock(count_mutex_);
84   auto it = device_deallocation_count_.find(device_ordinal);
85   if (it == device_deallocation_count_.end()) {
86     return 0;
87   } else {
88     return it->second;
89   }
90 }
91 
GetOrCreateAllocator(se::Platform * platform)92 /* static */ TestAllocator* LocalClientTestBase::GetOrCreateAllocator(
93     se::Platform* platform) {
94   static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
95   tensorflow::mutex_lock lock(mu);
96 
97   if (allocator_ == nullptr) {
98     allocator_ = new TestAllocator(
99         platform == nullptr ? PlatformUtil::GetDefaultPlatform().ValueOrDie()
100                             : platform);
101   }
102   return allocator_;
103 }
104 
105 // Define this in .cc file to avoid having to include eigen or forward declare
106 // these types in the header.
107 struct LocalClientTestBase::EigenThreadPoolWrapper {
EigenThreadPoolWrapperxla::LocalClientTestBase::EigenThreadPoolWrapper108   explicit EigenThreadPoolWrapper()
109       : pool(new tensorflow::thread::ThreadPool(
110             tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)),
111         wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())),
112         device(new Eigen::ThreadPoolDevice(wrapper.get(),
113                                            wrapper->NumThreads())) {}
114 
115   std::unique_ptr<tensorflow::thread::ThreadPool> pool;
116   std::unique_ptr<tensorflow::EigenThreadPoolWrapper> wrapper;
117   std::unique_ptr<Eigen::ThreadPoolDevice> device;
118 };
119 
LocalClientTestBase(se::Platform * platform)120 LocalClientTestBase::LocalClientTestBase(se::Platform* platform)
121     : local_client_(
122           ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()),
123       thread_pool_wrapper_(new EigenThreadPoolWrapper()) {
124   stream_executor_ = PlatformUtil::GetStreamExecutors(local_client_->platform())
125                          .ValueOrDie()[local_client_->default_device_ordinal()];
126   transfer_manager_ =
127       TransferManager::GetForPlatform(local_client_->platform()).ValueOrDie();
128 }
129 
~LocalClientTestBase()130 LocalClientTestBase::~LocalClientTestBase() {}
131 
LiteralToShapedBuffer(const Literal & literal)132 ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer(
133     const Literal& literal) {
134   return local_client_
135       ->LiteralToShapedBuffer(literal, local_client_->default_device_ordinal())
136       .ConsumeValueOrDie();
137 }
138 
ShapedBufferToLiteral(const ShapedBuffer & shaped_buffer)139 Literal LocalClientTestBase::ShapedBufferToLiteral(
140     const ShapedBuffer& shaped_buffer) {
141   return local_client_->ShapedBufferToLiteral(shaped_buffer)
142       .ConsumeValueOrDie();
143 }
144 
DefaultExecutableBuildOptions() const145 ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions()
146     const {
147   return ExecutableBuildOptions();
148 }
149 
DefaultExecutableRunOptions() const150 ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const {
151   ExecutableRunOptions run_options;
152   run_options.set_intra_op_thread_pool(thread_pool_wrapper_->device.get());
153   run_options.set_allocator(GetOrCreateAllocator(local_client_->platform()));
154   return run_options;
155 }
156 
ExecuteLocallyOrDie(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments)157 ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
158     const XlaComputation& computation,
159     absl::Span<const ShapedBuffer* const> arguments) {
160   return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
161                         DefaultExecutableRunOptions())
162       .ConsumeValueOrDie();
163 }
164 
ExecuteLocallyOrDie(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments,const ExecutableBuildOptions & build_options,const ExecutableRunOptions & run_options)165 ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
166     const XlaComputation& computation,
167     absl::Span<const ShapedBuffer* const> arguments,
168     const ExecutableBuildOptions& build_options,
169     const ExecutableRunOptions& run_options) {
170   return ExecuteLocally(computation, arguments, build_options, run_options)
171       .ConsumeValueOrDie();
172 }
173 
ExecuteLocally(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments)174 StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
175     const XlaComputation& computation,
176     absl::Span<const ShapedBuffer* const> arguments) {
177   return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
178                         DefaultExecutableRunOptions());
179 }
180 
ExecuteLocally(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments,const ExecutableBuildOptions & build_options,const ExecutableRunOptions & run_options)181 StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
182     const XlaComputation& computation,
183     absl::Span<const ShapedBuffer* const> arguments,
184     const ExecutableBuildOptions& build_options,
185     const ExecutableRunOptions& run_options) {
186   std::vector<const Shape*> argument_layouts(arguments.size());
187   for (int i = 0; i < arguments.size(); ++i) {
188     argument_layouts[i] = &arguments[i]->on_host_shape();
189   }
190   TF_ASSIGN_OR_RETURN(
191       std::unique_ptr<LocalExecutable> executable,
192       local_client_->Compile(computation, argument_layouts, build_options));
193   TF_ASSIGN_OR_RETURN(auto ret, executable->Run(arguments, run_options));
194 
195   auto device_ordinal =
196       build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal();
197   auto* stream = run_options.stream();
198   if (!stream) {
199     stream = local_client_->mutable_backend()
200                  ->BorrowStream(device_ordinal)
201                  .ValueOrDie()
202                  .get();
203   }
204   TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
205   return std::move(ret);
206 }
207 
208 }  // namespace xla
209