/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/base/macros.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"

namespace xla {

44 // A base class for tests which build and/or run HLO code. The class includes
45 // support for running an HLO module on two platforms and compare the results.
46 // This is a lower level of abstraction than using the client interface and
47 // enables, for one, explicitly building a graph of HLO instructions to run.
48 //
49 // This can also be used to write text/file-based test cases. Note that the test
50 // target is responsible for linking the needed backends. A convenient way to do
51 // this is to make it an xla_test: it will generate test targets linking with
52 // the respective backends, which will be used as the test backend; the
53 // interpreter backend is already linked with hlo_test_base so it will be the
54 // default reference backend. For example, if you want to compare both cpu vs.
55 // interpreter, and gpu vs. interpreter, you can:
56 //
57 //  xla_test (
58 //    name = "sample_text_test",
59 //    srcs = ["sample_text_test.cc"],
60 //    backends = [
61 //      "cpu",
62 //      "gpu",
63 //    ],
64 //    deps = [
65 //      "//third_party/tensorflow/compiler/xla/tests:hlo_test_base",
66 //      ...
67 //    ],
68 //  )
69 //
70 // For a more detailed example, see "../tests/sample_text_test.cc".
71 class HloTestBase : public ManifestCheckingTest {
72  public:
73   // Creates a new HLO module for a test. The module created will have
74   // TestName() for its name; it will also automatically populate its debug
75   // options from command-line flags. If you want a fresh HloModule object and
76   // then add HloComputations to it, it's recommended to use this method in your
77   // tests.
78   //
79   // This returns a vanilla HloModule that doesn't run the HLO verifier on
80   // destruction.
81   ABSL_DEPRECATED("Use CreateNewVerifiedModule instead.")
82   std::unique_ptr<HloModule> CreateNewUnverifiedModule(
83       const string& name = TestName());
84 
85   // Like CreateNewUnverifiedModule, except the HloModule returned here runs the
86   // HLO verifier on destruction.
87   std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
88       const string& name = TestName(), int64_t replica_count = 1);
89 
90   // Parses the given string and returns module as a VerifiedHloModule.
91   StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
92       absl::string_view hlo_text, int64_t replica_count = 1,
93       int64_t num_partitions = 1);
94   StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
95       absl::string_view hlo_text, const HloModuleConfig& config);
96 
97   // Runs the hlo_pass with the provided module and returns the result. This
98   // function also verifies that the module remains unchanged when hlo_pass
99   // returns false as the StatusOr value.
100   static StatusOr<bool> RunHloPass(HloPassInterface* hlo_pass,
101                                    HloModule* module);
102 
103   static PrecisionConfig DefaultPrecisionConfig(int operands);
104 
105   // Sets most fath math options to be enabled to model the fast math flags
106   // generally used for CPU:AOT compilation.
107   static void SetAotFastMathDebugOptions(DebugOptions* options);
108 
109  protected:
110   // This uses the interpreter backend as the reference backend and
111   // automatically finds another supported backend as the test backend. If the
112   // interpreter is the only supported backend, it will be both the test backend
113   // and the reference backend.
114   HloTestBase(bool verifier_layout_sensitive = false,
115               bool allow_mixed_precision_in_hlo_verifier = true,
116               std::function<bool(const HloInstruction*)>
117                   instruction_can_change_layout_func = {});
118 
119   // If your test doesn't use interpreter as the reference backend, you can use
120   // this constructor. Note that your test target is responsible for linking in
121   // both needed backends.
122   HloTestBase(se::Platform* test_platform, se::Platform* reference_platform,
123               bool verifier_layout_sensitive = false,
124               bool allow_mixed_precision_in_hlo_verifier = true,
125               std::function<bool(const HloInstruction*)>
126                   instruction_can_change_layout_func = {});
127 
~HloTestBase()128   ~HloTestBase() override {}
129 
130   // Populates debug options from command-line flags and adjusts the options for
131   // testing. It is recommended to use this when you need to pass in
132   // DebugOptions, e.g. when creating a module from a string or a file.
133   //
134   // This function is virtual so tests can specify an alternative set of debug
135   // options (e.g. disabling additional passes).
136   virtual DebugOptions GetDebugOptionsForTest();
137 
138   // Gets an HloModuleConfig with options appropriate for tests.
139   HloModuleConfig GetModuleConfigForTest(int64_t replica_count = 1,
140                                          int64_t num_partitions = 1) {
141     HloModuleConfig config;
142     config.set_debug_options(GetDebugOptionsForTest());
143     config.set_replica_count(replica_count);
144     config.set_num_partitions(num_partitions);
145     return config;
146   }
147 
148   // Executes the given module and return the result as a Literal.
149   StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
150                             absl::Span<Literal* const> arguments);
151 
152   // Same as above, except the module will be executed without running any HLO
153   // passes on it.
154   Literal ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
155                              absl::Span<Literal* const> arguments);
156 
157   Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module,
158                              absl::Span<Literal* const> arguments);
159 
160   // Executes the given module on multiple replicas.
161   //
162   // use_threads indicates whether this replicated computation will be executed
163   // with a thread-per-replica, vs using an implicitly async call such as
164   // Executable::ExecuteOnStreams.
165   StatusOr<std::vector<Literal>> ExecuteReplicated(
166       std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
167       int64_t num_replicas, bool use_threads, bool run_hlo_passes = false);
168 
169   // Same as above, but uses specified device assignment.
170   StatusOr<std::vector<Literal>> ExecuteReplicated(
171       std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
172       int64_t num_replicas, DeviceAssignment* device_assignment,
173       bool run_hlo_passes, bool use_threads);
174 
175   // Same as above, but allows passing different programs for replicas.
176   StatusOr<std::vector<Literal>> ExecuteReplicated(
177       std::function<Executable*(int64_t)> executable_provider,
178       std::function<int64(int64_t)> argument_count_provider,
179       std::function<const Literal*(int64_t, int64_t)> argument_provider,
180       int64_t num_replicas, bool run_hlo_passes);
181 
182   // Executes the given hlo module on two backends and compares results.
183   //
184   // 'arguments': the input of the hlo module.
185   //
186   // 'error': if has value, expects the results to be near (within the error
187   // bound). Otherwise, expects the results to be equal.
188   //
189   // 'reference_preprocessor': the module should be ready to run on the test
190   // backend, but it might need to be tailored so that it is able to run on the
191   // reference backend. Note that the program shape of the module must not be
192   // modified.
193   ::testing::AssertionResult RunAndCompare(
194       std::unique_ptr<HloModule> module,
195       const absl::Span<Literal* const> arguments,
196       const absl::optional<ErrorSpec>& error,
197       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
198       TF_MUST_USE_RESULT;
199 
200   // Same as above, except that the module will be executed without Hlo
201   // optimization.
202   ::testing::AssertionResult RunAndCompareNoHloPasses(
203       std::unique_ptr<HloModule> module,
204       const absl::Span<Literal* const> arguments,
205       const absl::optional<ErrorSpec>& error,
206       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
207       TF_MUST_USE_RESULT;
208 
209   // Executes an hlo module with fake inputs and compares the results.
210   ::testing::AssertionResult RunAndCompare(
211       std::unique_ptr<HloModule> module, const absl::optional<ErrorSpec>& error,
212       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
213       TF_MUST_USE_RESULT;
214 
215   // Same as above, except that the module will be executed without Hlo
216   // optimization.
217   ::testing::AssertionResult RunAndCompareNoHloPasses(
218       std::unique_ptr<HloModule> module, const absl::optional<ErrorSpec>& error,
219       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
220       TF_MUST_USE_RESULT;
221 
222   // Executes an hlo module with fake inputs and checks that the execution is
223   // successful.
224   ::testing::AssertionResult Run(std::unique_ptr<HloModule> module,
225                                  bool run_hlo_passes) TF_MUST_USE_RESULT;
226 
227   // Convenient wrappers for executing and comparing an hlo module with fake
228   // input. Module can be passed in directly, or parsed from an hlo_string,
229   // or loaded from a file.
230   ::testing::AssertionResult RunAndCompare(
231       const absl::string_view hlo_string,
232       const absl::optional<ErrorSpec>& error,
233       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
234       TF_MUST_USE_RESULT;
235   ::testing::AssertionResult Run(const absl::string_view hlo_string,
236                                  bool run_hlo_passes = true,
237                                  ExecutionProfile* profile = nullptr,
238                                  string backend_config = "") TF_MUST_USE_RESULT;
239 
240   // Executes an hlo module with fake inputs on multiple replicas.
241   ::testing::AssertionResult RunReplicated(
242       const absl::string_view hlo_string, bool run_hlo_passes = true,
243       int64_t num_replicas = 1, string backend_config = "") TF_MUST_USE_RESULT;
244 
245   // If assert_determinism is true, the assertion will fail unless all runs
246   // produce exactly the same output.
247   ::testing::AssertionResult RunMultipleTimes(
248       const absl::string_view hlo_string, bool run_hlo_passes,
249       std::vector<ExecutionProfile>* profiles, string backend_config = "",
250       bool assert_determinism = false) TF_MUST_USE_RESULT;
251   ::testing::AssertionResult RunAndCompareFromFile(
252       const string& filename, const absl::optional<ErrorSpec>& error,
253       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
254       TF_MUST_USE_RESULT;
255   ::testing::AssertionResult RunAndCompareNoHloPasses(
256       const absl::string_view hlo_string,
257       const absl::optional<ErrorSpec>& error,
258       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
259       TF_MUST_USE_RESULT;
260   ::testing::AssertionResult RunAndCompareNoHloPassesFromFile(
261       const string& filename, const absl::optional<ErrorSpec>& error,
262       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
263       TF_MUST_USE_RESULT;
264 
265   // Convenience method to force the layout of a given parameter in a module.
266   // The layout of parameter number 'param_no' in the 'module' is set to
267   // 'layout'.
ForceParameterLayout(HloModule * module,int64_t param_no,const Layout & layout)268   void ForceParameterLayout(HloModule* module, int64_t param_no,
269                             const Layout& layout) {
270     ASSERT_LT(param_no,
271               module->mutable_entry_computation_layout()->parameter_count());
272     module->mutable_entry_computation_layout()
273         ->mutable_parameter_layout(param_no)
274         ->ResetLayout(layout);
275   }
276 
277   // Convenience method to force the layout of the computation result in a
278   // module. The result layout of 'module' is set to 'layout'.
ForceResultLayout(HloModule * module,const Layout & layout)279   void ForceResultLayout(HloModule* module, const Layout& layout) {
280     module->mutable_entry_computation_layout()
281         ->mutable_result_layout()
282         ->ResetLayout(layout);
283   }
284 
ForceResultLayout(HloModule * module,const Layout & layout,ShapeIndexView shape_index)285   void ForceResultLayout(HloModule* module, const Layout& layout,
286                          ShapeIndexView shape_index) {
287     module->mutable_entry_computation_layout()
288         ->mutable_result_layout()
289         ->ResetLayout(layout, shape_index);
290   }
291 
292   // Convenience method to clear the layout of the computation result in
293   // 'module'.
ForceClearResultLayout(HloModule * module)294   void ForceClearResultLayout(HloModule* module) {
295     module->mutable_entry_computation_layout()
296         ->mutable_result_layout()
297         ->Clear();
298   }
299 
300   // Gets the computation/instruction from the given module with the given name.
301   //
302   // This is useful for tests which create HLOs from a string and then want to
303   // inspect a particular computation or instruction.
304   HloComputation* FindComputation(HloModule* module, absl::string_view name);
305   HloInstruction* FindInstruction(HloModule* module, absl::string_view name);
306   // Gets the instruction from the given module with the given opcode.
307   HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode);
308 
309   // Return an HLO verifier constructed for the test backend.
verifier()310   HloVerifier& verifier() const { return *hlo_verifier_; }
311 
312   static string TestName();
313 
314   // Returns the backend owned by the test runner.
315   Backend& backend();
316 
317   HloRunner test_runner_;
318   HloRunner reference_runner_;
319 
320   bool verifier_layout_sensitive_;
321   bool allow_mixed_precision_in_hlo_verifier_;
322   std::unique_ptr<HloVerifier> hlo_verifier_;
323 
324   ErrorSpec error_spec_{0.0001};
325 
326  protected:
327   // Helper functions to get test and reference platforms.
328   static se::Platform* GetReferencePlatform();
329   static se::Platform* GetTestPlatform();
330 
331  private:
332   // Given the test module, makes a reference module that is ready to run on the
333   // reference platform. This assumes that the given module is ready to run on
334   // the test platform.
335   StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule(
336       const HloModule& test_module,
337       const std::function<void(HloModule*)>& reference_preprocessor);
338 
339   // Runs the module on two platforms with or without running hlo passes and
340   // compares the results. Returns whether the results are near or equal. If any
341   // error happens before the results are computed, returns the error status.
342   StatusOr<::testing::AssertionResult> RunAndCompareInternal(
343       std::unique_ptr<HloModule> module,
344       const absl::Span<Literal* const> arguments,
345       const absl::optional<ErrorSpec>& error, bool run_hlo_passes,
346       const std::function<void(HloModule*)>& reference_preprocessor);
347 };

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_