/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/base/macros.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"

namespace xla {

// A base class for tests which build and/or run HLO code. The class includes
// support for running an HLO module on two platforms and comparing the
// results. This is a lower level of abstraction than using the client
// interface and enables, among other things, explicitly building a graph of
// HLO instructions to run.
//
// This can also be used to write text/file-based test cases. Note that the
// test target is responsible for linking the needed backends. A convenient
// way to do this is to make it an xla_test: it will generate test targets
// linking with the respective backends, which will be used as the test
// backend; the interpreter backend is already linked with hlo_test_base so it
// will be the default reference backend. For example, if you want to compare
// both cpu vs. interpreter, and gpu vs. interpreter, you can:
//
//  xla_test (
//    name = "sample_text_test",
//    srcs = ["sample_text_test.cc"],
//    backends = [
//      "cpu",
//      "gpu",
//    ],
//    deps = [
//      "//third_party/tensorflow/compiler/xla/tests:hlo_test_base",
//      ...
//    ],
//  )
//
// For a more detailed example, see "../tests/sample_text_test.cc".
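//
// As a rough sketch (illustrative only; the test name and HLO text below are
// made up), a string-based test built on this class might look like:
//
//  class SampleTest : public HloTestBase {};
//
//  TEST_F(SampleTest, AddsTwoConstants) {
//    const char* const kHlo = R"(
//      HloModule add
//      ENTRY main {
//        a = f32[] constant(1)
//        b = f32[] constant(2)
//        ROOT sum = f32[] add(a, b)
//      }
//    )";
//    // Runs the module on the test backend and the reference backend and
//    // compares the results within the given error bound.
//    EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.0001}));
//  }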
class HloTestBase : public ManifestCheckingTest {
 public:
  // Creates a new HLO module for a test. The module created will have
  // TestName() for its name; it will also automatically populate its debug
  // options from command-line flags. If you want a fresh HloModule object to
  // which you can then add HloComputations, it's recommended to use this
  // method in your tests.
  //
  // This returns a vanilla HloModule that doesn't run the HLO verifier on
  // destruction.
  ABSL_DEPRECATED("Use CreateNewVerifiedModule instead.")
  std::unique_ptr<HloModule> CreateNewUnverifiedModule(
      const string& name = TestName());

  // Like CreateNewUnverifiedModule, except the HloModule returned here runs
  // the HLO verifier on destruction.
  std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
      const string& name = TestName(), int64 replica_count = 1);

  // Parses the given string and returns the module as a VerifiedHloModule.
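  //
  // For instance (a sketch; the HLO text is illustrative and
  // TF_ASSERT_OK_AND_ASSIGN is assumed to be available), a test body might do:
  //
  //  TF_ASSERT_OK_AND_ASSIGN(auto module,
  //                          ParseAndReturnVerifiedModule(R"(
  //    HloModule m
  //    ENTRY main {
  //      ROOT p = f32[2] parameter(0)
  //    }
  //  )"));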
  StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
      absl::string_view hlo_text, int64 replica_count = 1);
  StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
      absl::string_view hlo_text, const HloModuleConfig& config);

  // Runs the hlo_pass with the provided module and returns the result. This
  // function also verifies that the module remains unchanged when hlo_pass
  // returns false as the StatusOr value.
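  //
  // A typical use (a sketch; "MyPass" stands in for a real HloPassInterface
  // implementation) looks like:
  //
  //  MyPass pass;
  //  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
  //  EXPECT_TRUE(changed);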
  static StatusOr<bool> RunHloPass(HloPassInterface* hlo_pass,
                                   HloModule* module);

  static PrecisionConfig DefaultPrecisionConfig(int operands);

  // Sets most fast math options to be enabled to model the fast math flags
  // generally used for CPU:AOT compilation.
  static void SetAotFastMathDebugOptions(DebugOptions* options);

 protected:
  // This uses the interpreter backend as the reference backend and
  // automatically finds another supported backend as the test backend. If the
  // interpreter is the only supported backend, it will be both the test
  // backend and the reference backend.
  HloTestBase(bool verifier_layout_sensitive = false,
              bool allow_mixed_precision_in_hlo_verifier = true,
              std::function<bool(const HloInstruction*)>
                  instruction_can_change_layout_func = {});

  // If your test doesn't use the interpreter as the reference backend, you can
  // use this constructor. Note that your test target is responsible for
  // linking in both needed backends.
  HloTestBase(se::Platform* test_platform, se::Platform* reference_platform,
              bool verifier_layout_sensitive = false,
              bool allow_mixed_precision_in_hlo_verifier = true,
              std::function<bool(const HloInstruction*)>
                  instruction_can_change_layout_func = {});

  ~HloTestBase() override {}

  // Populates debug options from command-line flags and adjusts the options
  // for testing. It is recommended to use this when you need to pass in
  // DebugOptions, e.g. when creating a module from a string or a file.
  //
  // This function is virtual so tests can specify an alternative set of debug
  // options (e.g. disabling additional passes).
  virtual DebugOptions GetDebugOptionsForTest();

  // Gets an HloModuleConfig with options appropriate for tests.
  HloModuleConfig GetModuleConfigForTest(int64 replica_count = 1) {
    HloModuleConfig config;
    config.set_debug_options(GetDebugOptionsForTest());
    config.set_replica_count(replica_count);
    return config;
  }

  // Executes the given module and returns the result as a Literal.
  StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
                            absl::Span<Literal* const> arguments);

  // Same as above, except the module will be executed without running any HLO
  // passes on it.
  Literal ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
                             absl::Span<Literal* const> arguments);

  Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module,
                             absl::Span<Literal* const> arguments);

  // Executes the given module on multiple replicas.
  //
  // use_threads indicates whether this replicated computation will be executed
  // with a thread-per-replica, vs using an implicitly async call such as
  // Executable::ExecuteOnStreams.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
      int64 num_replicas, bool use_threads, bool run_hlo_passes = false);

  // Same as above, but uses specified device assignment.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
      int64 num_replicas, DeviceAssignment* device_assignment,
      bool run_hlo_passes, bool use_threads);

  // Same as above, but allows passing different programs for replicas.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::function<Executable*(int64)> executable_provider,
      std::function<int64(int64)> argument_count_provider,
      std::function<const Literal*(int64, int64)> argument_provider,
      int64 num_replicas, bool run_hlo_passes);

  // Executes the given hlo module on two backends and compares results.
  //
  // 'arguments': the input of the hlo module.
  //
  // 'error': if it has a value, expects the results to be near (within the
  // error bound). Otherwise, expects the results to be equal.
  //
  // 'reference_preprocessor': the module should be ready to run on the test
  // backend, but it might need to be tailored so that it is able to run on the
  // reference backend. Note that the program shape of the module must not be
  // modified.
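  //
  // For example (a sketch; module construction is elided and arg0/arg1 are
  // Literals built by the test), a hand-built module can be compared against
  // the reference backend like this:
  //
  //  auto module = CreateNewVerifiedModule();
  //  ... add computations to the module ...
  //  EXPECT_TRUE(RunAndCompare(std::move(module), {&arg0, &arg1},
  //                            ErrorSpec{0.0001}));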
  ::testing::AssertionResult RunAndCompare(
      std::unique_ptr<HloModule> module,
      const absl::Span<Literal* const> arguments,
      const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;

  // Same as above, except that the module will be executed without HLO
  // optimization.
  ::testing::AssertionResult RunAndCompareNoHloPasses(
      std::unique_ptr<HloModule> module,
      const absl::Span<Literal* const> arguments,
      const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;

  // Executes an hlo module with fake inputs and compares the results.
  ::testing::AssertionResult RunAndCompare(
      std::unique_ptr<HloModule> module, const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;

  // Same as above, except that the module will be executed without HLO
  // optimization.
  ::testing::AssertionResult RunAndCompareNoHloPasses(
      std::unique_ptr<HloModule> module, const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;

  // Executes an hlo module with fake inputs and checks that the execution is
  // successful.
  ::testing::AssertionResult Run(std::unique_ptr<HloModule> module,
                                 bool run_hlo_passes) TF_MUST_USE_RESULT;

  // Convenient wrappers for executing and comparing an hlo module with fake
  // input. The module can be passed in directly, parsed from an hlo_string,
  // or loaded from a file.
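  //
  // A brief sketch of the three styles (the HLO string and file path are
  // illustrative):
  //
  //  EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.0001}));
  //  EXPECT_TRUE(Run(kHloString, /*run_hlo_passes=*/false));
  //  EXPECT_TRUE(RunAndCompareFromFile("path/to/test.hlo", absl::nullopt));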
  ::testing::AssertionResult RunAndCompare(
      const absl::string_view hlo_string,
      const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;
  ::testing::AssertionResult Run(const absl::string_view hlo_string,
                                 bool run_hlo_passes = true,
                                 ExecutionProfile* profile = nullptr,
                                 string backend_config = "") TF_MUST_USE_RESULT;

  // Executes an hlo module with fake inputs on multiple replicas.
  ::testing::AssertionResult RunReplicated(
      const absl::string_view hlo_string, bool run_hlo_passes = true,
      int64 num_replicas = 1, string backend_config = "") TF_MUST_USE_RESULT;

  // If assert_determinism is true, the assertion will fail unless all runs
  // produce exactly the same output.
  ::testing::AssertionResult RunMultipleTimes(
      const absl::string_view hlo_string, bool run_hlo_passes,
      std::vector<ExecutionProfile>* profiles, string backend_config = "",
      bool assert_determinism = false) TF_MUST_USE_RESULT;
  ::testing::AssertionResult RunAndCompareFromFile(
      const string& filename, const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;
  ::testing::AssertionResult RunAndCompareNoHloPasses(
      const absl::string_view hlo_string,
      const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;
  ::testing::AssertionResult RunAndCompareNoHloPassesFromFile(
      const string& filename, const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;

  // Convenience method to force the layout of a given parameter in a module.
  // The layout of parameter number 'param_no' in the 'module' is set to
  // 'layout'.
  void ForceParameterLayout(HloModule* module, int64 param_no,
                            const Layout& layout) {
    ASSERT_LT(param_no,
              module->mutable_entry_computation_layout()->parameter_count());
    module->mutable_entry_computation_layout()
        ->mutable_parameter_layout(param_no)
        ->ResetLayout(layout);
  }

  // Convenience method to force the layout of the computation result in a
  // module. The result layout of 'module' is set to 'layout'.
  void ForceResultLayout(HloModule* module, const Layout& layout) {
    module->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->ResetLayout(layout);
  }

  void ForceResultLayout(HloModule* module, const Layout& layout,
                         ShapeIndexView shape_index) {
    module->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->ResetLayout(layout, shape_index);
  }

  // Convenience method to clear the layout of the computation result in
  // 'module'.
  void ForceClearResultLayout(HloModule* module) {
    module->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->Clear();
  }

  // Gets the computation/instruction from the given module with the given
  // name.
  //
  // This is useful for tests which create HLOs from a string and then want to
  // inspect a particular computation or instruction.
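  //
  // For example (a sketch; "reduce" is an illustrative instruction name):
  //
  //  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  //  HloInstruction* reduce = FindInstruction(module.get(), "reduce");
  //  ASSERT_NE(reduce, nullptr);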
  HloComputation* FindComputation(HloModule* module, absl::string_view name);
  HloInstruction* FindInstruction(HloModule* module, absl::string_view name);
  // Gets the instruction from the given module with the given opcode.
  HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode);

  // Returns an HLO verifier constructed for the test backend.
  HloVerifier& verifier() const { return *hlo_verifier_; }

  static string TestName();

  // Returns the backend owned by the test runner.
  Backend& backend();

  HloRunner test_runner_;
  HloRunner reference_runner_;

  bool verifier_layout_sensitive_;
  bool allow_mixed_precision_in_hlo_verifier_;
  std::unique_ptr<HloVerifier> hlo_verifier_;

  ErrorSpec error_spec_{0.0001};

 private:
  // Given the test module, makes a reference module that is ready to run on
  // the reference platform. This assumes that the given module is ready to run
  // on the test platform.
  StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule(
      const HloModule& test_module,
      const std::function<void(HloModule*)>& reference_preprocessor);

  // Runs the module on two platforms with or without running hlo passes and
  // compares the results. Returns whether the results are near or equal. If
  // any error happens before the results are computed, returns the error
  // status.
  StatusOr<::testing::AssertionResult> RunAndCompareInternal(
      std::unique_ptr<HloModule> module,
      const absl::Span<Literal* const> arguments,
      const absl::optional<ErrorSpec>& error, bool run_hlo_passes,
      const std::function<void(HloModule*)>& reference_preprocessor);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
344