• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
17 #define TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
18 
19 #include <map>
20 #include <memory>
21 #include <vector>
22 
23 #include "absl/strings/string_view.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/client/client_library.h"
26 #include "tensorflow/compiler/xla/client/local_client.h"
27 #include "tensorflow/compiler/xla/client/xla_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
29 #include "tensorflow/compiler/xla/service/local_service.h"
30 #include "tensorflow/compiler/xla/service/platform_util.h"
31 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
32 #include "tensorflow/compiler/xla/service/transfer_manager.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
35 #include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
36 #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/platform/mutex.h"
39 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
40 #include "tensorflow/core/platform/thread_annotations.h"
41 #include "tensorflow/core/platform/types.h"
42 #include "tensorflow/stream_executor/device_memory_allocator.h"
43 
44 namespace xla {
45 
46 class TestAllocator : public se::StreamExecutorMemoryAllocator {
47  public:
TestAllocator(se::Platform * platform)48   explicit TestAllocator(se::Platform* platform)
49       : se::StreamExecutorMemoryAllocator(
50             platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) {
51   }
52 
53   StatusOr<se::OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
54                                             bool retry_on_failure,
55                                             int64 memory_space) override;
56   Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;
57 
58   // Return the number of allocations that have been performed.
59   int64 allocation_count() const;
60   int64 allocation_count(int device_ordinal) const;
61 
62   // Return the number of deallocations that have been performed.
63   int64 deallocation_count() const;
64   int64 deallocation_count(int device_ordinal) const;
65 
66  private:
67   mutable tensorflow::mutex count_mutex_;
68 
69   // Global counts of allocations and deallocations.
70   int64 allocation_count_ TF_GUARDED_BY(count_mutex_) = 0;
71   int64 deallocation_count_ TF_GUARDED_BY(count_mutex_) = 0;
72 
73   // Per-device counts of allocations and deallocations.
74   std::map<int, int64> device_allocation_count_ TF_GUARDED_BY(count_mutex_);
75   std::map<int, int64> device_deallocation_count_ TF_GUARDED_BY(count_mutex_);
76 };
77 
78 // A base class for tests which exercise the LocalClient interface.
79 class LocalClientTestBase : public ManifestCheckingTest {
80  protected:
81   struct EigenThreadPoolWrapper;
82   explicit LocalClientTestBase(se::Platform* platform = nullptr);
83   virtual ~LocalClientTestBase();
84 
85   static TestAllocator* GetOrCreateAllocator(se::Platform* platform);
86 
87   // Copy the given literal onto the default device and return a
88   // ScopedShapedBuffer. Convenience wrapper around
89   // LocalClient::LiteralToShapedBuffer.
90   ScopedShapedBuffer LiteralToShapedBuffer(const Literal& literal);
91 
92   // Construct and return a literal containing the array represented by
93   // shaped_buffer.
94   Literal ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
95 
96   // Execute the given computation on the local client. With and without
97   // options.
98   StatusOr<ScopedShapedBuffer> ExecuteLocally(
99       const XlaComputation& computation,
100       absl::Span<const ShapedBuffer* const> arguments);
101   StatusOr<ScopedShapedBuffer> ExecuteLocally(
102       const XlaComputation& computation,
103       absl::Span<const ShapedBuffer* const> arguments,
104       const ExecutableBuildOptions& build_options,
105       const ExecutableRunOptions& run_options);
106 
107   ScopedShapedBuffer ExecuteLocallyOrDie(
108       const XlaComputation& computation,
109       absl::Span<const ShapedBuffer* const> arguments);
110   ScopedShapedBuffer ExecuteLocallyOrDie(
111       const XlaComputation& computation,
112       absl::Span<const ShapedBuffer* const> arguments,
113       const ExecutableBuildOptions& build_options,
114       const ExecutableRunOptions& run_options);
115 
116   // Parses the given string and returns module as a VerifiedHloModule.
117   StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
118       absl::string_view hlo_text);
119   StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
120       absl::string_view hlo_text, const HloModuleConfig& config);
121 
122   // Returns a default set of execute options.
123   ExecutableBuildOptions DefaultExecutableBuildOptions() const;
124 
125   // Returns a default set of execute options, configured to use allocator_
126   // as the allocator.
127   ExecutableRunOptions DefaultExecutableRunOptions() const;
128 
TestName()129   string TestName() const {
130     return ::testing::UnitTest::GetInstance()->current_test_info()->name();
131   }
132 
133   // The allocator must live as long as the service, which lives until the end
134   // of the process. So make the allocator static.
135   static TestAllocator* allocator_;
136 
137   se::StreamExecutor* stream_executor_;
138   TransferManager* transfer_manager_;
139 
140   LocalClient* local_client_;
141 
142   std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
143 };
144 
145 }  // namespace xla
146 
147 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
148