• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
17 #define TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
18 
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "absl/base/macros.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
40 
41 namespace xla {
42 
43 // A base class for tests which build and/or run HLO code. The class includes
44 // support for running an HLO module on two platforms and compare the results.
45 // This is a lower level of abstraction than using the client interface and
46 // enables, for one, explicitly building a graph of HLO instructions to run.
47 //
48 // This can also be used to write text/file-based test cases. Note that the test
49 // target is responsible for linking the needed backends. A convenient way to do
50 // this is to make it an xla_test: it will generate test targets linking with
51 // the respective backends, which will be used as the test backend; the
52 // interpreter backend is already linked with hlo_test_base so it will be the
53 // default reference backend. For example, if you want to compare both cpu vs.
54 // interpreter, and gpu vs. interpreter, you can:
55 //
56 //  xla_test (
57 //    name = "sample_text_test",
58 //    srcs = ["sample_text_test.cc"],
59 //    backends = [
60 //      "cpu",
61 //      "gpu",
62 //    ],
63 //    deps = [
64 //      "//third_party/tensorflow/compiler/xla/tests:hlo_test_base",
65 //      ...
66 //    ],
67 //  )
68 //
69 // For a more detailed example, see "../tests/sample_text_test.cc".
70 class HloTestBase : public ::testing::Test {
71  public:
72   // Creates a new HLO module for a test. The module created will have
73   // TestName() for its name; it will also automatically populate its debug
74   // options from command-line flags. If you want a fresh HloModule object and
75   // then add HloComputations to it, it's recommended to use this method in your
76   // tests.
77   //
78   // This returns a vanilla HloModule that doesn't run the HLO verifier on
79   // destruction.
80   ABSL_DEPRECATED("Use CreateNewVerifiedModule instead.")
81   std::unique_ptr<HloModule> CreateNewUnverifiedModule(
82       const string& name = TestName());
83 
84   // Like CreateNewUnverifiedModule, except the HloModule returned here runs the
85   // HLO verifier on destruction.
86   std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
87       const string& name = TestName());
88 
89   // Parses the given string and returns module as a VerifiedHloModule.
90   StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
91       absl::string_view hlo_text);
92   StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
93       absl::string_view hlo_text, const HloModuleConfig& config);
94 
95   // Runs the hlo_pass with the provided module and returns the result. This
96   // function also verifies that the module remains unchanged when hlo_pass
97   // returns false as the StatusOr value.
98   static StatusOr<bool> RunHloPass(HloPassInterface* hlo_pass,
99                                    HloModule* module);
100 
101   static PrecisionConfig DefaultPrecisionConfig(int operands);
102 
103  protected:
104   // This uses the interpreter backend as the reference backend and
105   // automatically finds another supported backend as the test backend. If the
106   // interpreter is the only supported backend, it will be both the test backend
107   // and the reference backend.
108   HloTestBase(bool verifier_layout_sensitive = false,
109               bool allow_mixed_precision_in_hlo_verifier = true,
110               std::function<bool(const HloInstruction*)>
111                   instruction_can_change_layout_func = {});
112 
113   // If your test doesn't use interpreter as the reference backend, you can use
114   // this constructor. Note that your test target is responsible for linking in
115   // both needed backends.
116   HloTestBase(se::Platform* test_platform, se::Platform* reference_platform,
117               bool verifier_layout_sensitive = false,
118               bool allow_mixed_precision_in_hlo_verifier = true,
119               std::function<bool(const HloInstruction*)>
120                   instruction_can_change_layout_func = {});
121 
~HloTestBase()122   ~HloTestBase() override {}
123 
124   // Populates debug options from command-line flags and adjusts the options for
125   // testing. It is recommended to use this when you need to pass in
126   // DebugOptions, e.g. when creating a module from a string or a file.
127   //
128   // This function is virtual so tests can specify an alternative set of debug
129   // options (e.g. disabling additional passes).
130   virtual DebugOptions GetDebugOptionsForTest();
131 
132   // Gets an HloModuleConfig with options appropriate for tests.
GetModuleConfigForTest()133   HloModuleConfig GetModuleConfigForTest() {
134     HloModuleConfig config;
135     config.set_debug_options(GetDebugOptionsForTest());
136     return config;
137   }
138 
139   // Executes the given module and return the result as a Literal.
140   StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
141                             absl::Span<Literal* const> arguments);
142 
143   // Same as above, except the module will be executed without running any HLO
144   // passes on it.
145   Literal ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
146                              absl::Span<Literal* const> arguments);
147 
148   Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module,
149                              absl::Span<Literal* const> arguments);
150 
151   // Executes the given module on multiple replicas.
152   //
153   // use_threads indicates whether this replicated computation will be executed
154   // with a thread-per-replica, vs using an implicitly async call such as
155   // Executable::ExecuteOnStreams.
156   StatusOr<std::vector<Literal>> ExecuteReplicated(
157       std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
158       int64 num_replicas, bool use_threads, bool run_hlo_passes = false);
159 
160   // Same as above, but uses specified device assignment.
161   StatusOr<std::vector<Literal>> ExecuteReplicated(
162       std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
163       int64 num_replicas, DeviceAssignment* device_assignment,
164       bool run_hlo_passes, bool use_threads);
165 
166   // Executes the given hlo module on two backends and compares results.
167   //
168   // 'arguments': the input of the hlo module.
169   //
170   // 'error': if has value, expects the results to be near (within the error
171   // bound). Otherwise, expects the results to be equal.
172   //
173   // 'reference_preprocessor': the module should be ready to run on the test
174   // backend, but it might need to be tailored so that it is able to run on the
175   // reference backend. Note that the program shape of the module must not be
176   // modified.
177   ::testing::AssertionResult RunAndCompare(
178       std::unique_ptr<HloModule> module,
179       const absl::Span<Literal* const> arguments,
180       const absl::optional<ErrorSpec>& error,
181       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
182       TF_MUST_USE_RESULT;
183 
184   // Same as above, except that the module will be executed without Hlo
185   // optimization.
186   ::testing::AssertionResult RunAndCompareNoHloPasses(
187       std::unique_ptr<HloModule> module,
188       const absl::Span<Literal* const> arguments,
189       const absl::optional<ErrorSpec>& error,
190       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
191       TF_MUST_USE_RESULT;
192 
193   // Executes an hlo module with fake inputs and compares the results.
194   ::testing::AssertionResult RunAndCompare(
195       std::unique_ptr<HloModule> module, const absl::optional<ErrorSpec>& error,
196       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
197       TF_MUST_USE_RESULT;
198 
199   // Same as above, except that the module will be executed without Hlo
200   // optimization.
201   ::testing::AssertionResult RunAndCompareNoHloPasses(
202       std::unique_ptr<HloModule> module, const absl::optional<ErrorSpec>& error,
203       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
204       TF_MUST_USE_RESULT;
205 
206   // Convenient wrappers for executing and comparing an hlo module with fake
207   // input. Module can be passed in directly, or parsed from an hlo_string,
208   // or loaded from a file.
209   ::testing::AssertionResult RunAndCompare(
210       const absl::string_view hlo_string,
211       const absl::optional<ErrorSpec>& error,
212       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
213       TF_MUST_USE_RESULT;
214   ::testing::AssertionResult Run(const absl::string_view hlo_string,
215                                  bool run_hlo_passes = true,
216                                  ExecutionProfile* profile = nullptr,
217                                  string backend_config = "") TF_MUST_USE_RESULT;
218 
219   // If assert_determinism is true, the assertion will fail unless all runs
220   // produce exactly the same output.
221   ::testing::AssertionResult RunMultipleTimes(
222       const absl::string_view hlo_string, bool run_hlo_passes,
223       std::vector<ExecutionProfile>* profiles, string backend_config = "",
224       bool assert_determinism = false) TF_MUST_USE_RESULT;
225   ::testing::AssertionResult RunAndCompareFromFile(
226       const string& filename, const absl::optional<ErrorSpec>& error,
227       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
228       TF_MUST_USE_RESULT;
229   ::testing::AssertionResult RunAndCompareNoHloPasses(
230       const absl::string_view hlo_string,
231       const absl::optional<ErrorSpec>& error,
232       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
233       TF_MUST_USE_RESULT;
234   ::testing::AssertionResult RunAndCompareNoHloPassesFromFile(
235       const string& filename, const absl::optional<ErrorSpec>& error,
236       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
237       TF_MUST_USE_RESULT;
238 
239   // Convenience method to force the layout of a given parameter in a module.
240   // The layout of parameter number 'param_no' in the 'module' is set to
241   // 'layout'.
ForceParameterLayout(HloModule * module,int64 param_no,const Layout & layout)242   void ForceParameterLayout(HloModule* module, int64 param_no,
243                             const Layout& layout) {
244     ASSERT_LT(param_no,
245               module->mutable_entry_computation_layout()->parameter_count());
246     module->mutable_entry_computation_layout()
247         ->mutable_parameter_layout(param_no)
248         ->ResetLayout(layout);
249   }
250 
251   // Convenience method to force the layout of the computation result in a
252   // module. The result layout of 'module' is set to 'layout'.
ForceResultLayout(HloModule * module,const Layout & layout)253   void ForceResultLayout(HloModule* module, const Layout& layout) {
254     module->mutable_entry_computation_layout()
255         ->mutable_result_layout()
256         ->ResetLayout(layout);
257   }
258 
ForceResultLayout(HloModule * module,const Layout & layout,ShapeIndexView shape_index)259   void ForceResultLayout(HloModule* module, const Layout& layout,
260                          ShapeIndexView shape_index) {
261     module->mutable_entry_computation_layout()
262         ->mutable_result_layout()
263         ->ResetLayout(layout, shape_index);
264   }
265 
266   // Convenience method to clear the layout of the computation result in
267   // 'module'.
ForceClearResultLayout(HloModule * module)268   void ForceClearResultLayout(HloModule* module) {
269     module->mutable_entry_computation_layout()
270         ->mutable_result_layout()
271         ->Clear();
272   }
273 
274   // Gets the computation/instruction from the given module with the given name.
275   //
276   // This is useful for tests which create HLOs from a string and then want to
277   // inspect a particular computation or instruction.
278   HloComputation* FindComputation(HloModule* module, absl::string_view name);
279   HloInstruction* FindInstruction(HloModule* module, absl::string_view name);
280   // Gets the instruction from the given module with the given opcode.
281   HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode);
282 
283   // Return an HLO verifier constructed for the test backend.
verifier()284   HloVerifier& verifier() const { return *hlo_verifier_; }
285 
286   static string TestName();
287 
288   // Returns the backend owned by the test runner.
289   Backend& backend();
290 
291   HloRunner test_runner_;
292   HloRunner reference_runner_;
293 
294   bool verifier_layout_sensitive_;
295   bool allow_mixed_precision_in_hlo_verifier_;
296   std::unique_ptr<HloVerifier> hlo_verifier_;
297 
298   ErrorSpec error_spec_{0.0001};
299 
300  private:
301   // Given the test module, makes a reference module that is ready to run on the
302   // reference platform. This assumes that the given module is ready to run on
303   // the test platform.
304   StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule(
305       const HloModule& test_module,
306       const std::function<void(HloModule*)>& reference_preprocessor);
307 
308   // Runs the module on two platforms with or without running hlo passes and
309   // compares the results. Returns whether the results are near or equal. If any
310   // error happens before the results are computed, returns the error status.
311   StatusOr<::testing::AssertionResult> RunAndCompareInternal(
312       std::unique_ptr<HloModule> module,
313       const absl::Span<Literal* const> arguments,
314       const absl::optional<ErrorSpec>& error, bool run_hlo_passes,
315       const std::function<void(HloModule*)>& reference_preprocessor);
316 };
317 
318 }  // namespace xla
319 
320 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
321