• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
17 #define TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
18 
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/macros.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
39 
40 namespace xla {
41 
42 // An HLO module derived class which verifies itself on destruction. This class
43 // is intended to be used in unit tests. Any verification errors are raised via
44 // ADD_FAILURE.
45 class VerifiedHloModule : public HloModule {
46  public:
VerifiedHloModule(const string & name,const HloModuleConfig & config,bool verifier_layout_sensitive,bool allow_mixed_precision_in_hlo_verifier,std::function<int64 (const Shape &)> shape_size_function)47   VerifiedHloModule(const string& name, const HloModuleConfig& config,
48                     bool verifier_layout_sensitive,
49                     bool allow_mixed_precision_in_hlo_verifier,
50                     std::function<int64(const Shape&)> shape_size_function)
51       : HloModule(name, config),
52         verifier_(
53             verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier,
54             /*instruction_can_change_layout_func=*/{}, shape_size_function) {}
55 
~VerifiedHloModule()56   ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
57 
58   // Verifies the module using HloVerifier and returns the status.
59   Status Verify();
60 
61   // Verifies the module and flags any error with ADD_FAILURE. 'message' is
62   // included in the failure message.
63   void VerifyOrAddFailure(const string& message);
64 
65  private:
66   HloVerifier verifier_;
67 };
68 
69 // A base class for tests which build and/or run HLO code. The class includes
70 // support for running an HLO module on two platforms and compare the results.
71 // This is a lower level of abstraction than using the client interface and
72 // enables, for one, explicitly building a graph of HLO instructions to run.
73 //
74 // This can also be used to write text/file-based test cases. Note that the test
75 // target is responsible for linking the needed backends. A convenient way to do
76 // this is to make it an xla_test: it will generate test targets linking with
77 // the respective backends, which will be used as the test backend; the
78 // interpreter backend is already linked with hlo_test_base so it will be the
79 // default reference backend. For example, if you want to compare both cpu vs.
80 // interpreter, and gpu vs. interpreter, you can:
81 //
82 //  xla_test (
83 //    name = "sample_text_test",
84 //    srcs = ["sample_text_test.cc"],
85 //    backends = [
86 //      "cpu",
87 //      "gpu",
88 //    ],
89 //    deps = [
90 //      "//third_party/tensorflow/compiler/xla/tests:hlo_test_base",
91 //      ...
92 //    ],
93 //  )
94 //
95 // For a more detailed example, see "../tests/sample_text_test.cc".
96 class HloTestBase : public ::testing::Test {
97  public:
98   // Creates a new HLO module for a test. The module created will have
99   // TestName() for its name; it will also automatically populate its debug
100   // options from command-line flags. If you want a fresh HloModule object and
101   // then add HloComputations to it, it's recommended to use this method in your
102   // tests.
103   //
104   // This returns a vanilla HloModule that doesn't run the HLO verifier on
105   // destruction.
106   ABSL_DEPRECATED("Use CreateNewVerifiedModule instead.")
107   std::unique_ptr<HloModule> CreateNewUnverifiedModule(
108       const string& name = TestName());
109 
110   // Like CreateNewUnverifiedModule, except the HloModule returned here runs the
111   // HLO verifier on destruction.
112   std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
113       const string& name = TestName());
114 
115   // Parses the given string and returns module as a VerifiedHloModule.
116   StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
117       absl::string_view hlo_text,
118       const HloModuleConfig& config = HloModuleConfig());
119 
120   // Runs the hlo_pass with the provided module and returns the result. This
121   // function also verifies that the module remains unchanged when hlo_pass
122   // returns false as the StatusOr value.
123   static StatusOr<bool> RunHloPass(HloPassInterface* hlo_pass,
124                                    HloModule* module);
125 
126   static PrecisionConfig DefaultPrecisionConfig(int operands);
127 
128  protected:
129   // This uses the interpreter backend as the reference backend and
130   // automatically finds another supported backend as the test backend. If the
131   // interpreter is the only supported backend, it will be both the test backend
132   // and the reference backend.
133   HloTestBase(bool verifier_layout_sensitive = false,
134               bool allow_mixed_precision_in_hlo_verifier = true,
135               std::function<bool(const HloInstruction*)>
136                   instruction_can_change_layout_func = {});
137 
138   // If your test doesn't use interpreter as the reference backend, you can use
139   // this constructor. Note that your test target is responsible for linking in
140   // both needed backends.
141   HloTestBase(se::Platform* test_platform, se::Platform* reference_platform,
142               bool verifier_layout_sensitive = false,
143               bool allow_mixed_precision_in_hlo_verifier = true,
144               std::function<bool(const HloInstruction*)>
145                   instruction_can_change_layout_func = {});
146 
~HloTestBase()147   ~HloTestBase() override {}
148 
149   // Populates debug options from command-line flags and adjusts the options for
150   // testing. It is recommended to use this when you need to pass in
151   // DebugOptions, e.g. when creating a module from a string or a file.
152   //
153   // This function is virtual so tests can specify an alternative set of debug
154   // options (e.g. disabling additional passes).
155   virtual DebugOptions GetDebugOptionsForTest();
156 
157   // Gets an HloModuleConfig with options appropriate for tests.
GetModuleConfigForTest()158   HloModuleConfig GetModuleConfigForTest() {
159     HloModuleConfig config;
160     config.set_debug_options(GetDebugOptionsForTest());
161     return config;
162   }
163 
164   // Executes the given module and return the result as a Literal.
165   StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
166                             absl::Span<Literal* const> arguments);
167 
168   // Same as above, except the module will be executed without running any HLO
169   // passes on it.
170   Literal ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
171                              absl::Span<Literal* const> arguments);
172 
173   Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module,
174                              absl::Span<Literal* const> arguments);
175 
176   // Executes the given module on multiple replicas.
177   //
178   // use_threads indicates whether this replicated computation will be executed
179   // with a thread-per-replica, vs using an implicitly async call such as
180   // Executable::ExecuteOnStreams.
181   StatusOr<std::vector<Literal>> ExecuteReplicated(
182       std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
183       int64 num_replicas, bool use_threads);
184 
185   // Same as above, but uses specified device assignment.
186   StatusOr<std::vector<Literal>> ExecuteReplicated(
187       std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
188       int64 num_replicas, DeviceAssignment* device_assignment,
189       bool run_hlo_passes, bool use_threads);
190 
191   // Executes the given hlo module on two backends and compares results.
192   //
193   // 'arguments': the input of the hlo module.
194   //
195   // 'error': if has value, expects the results to be near (within the error
196   // bound). Otherwise, expects the results to be equal.
197   //
198   // 'reference_preprocessor': the module should be ready to run on the test
199   // backend, but it might need to be tailored so that it is able to run on the
200   // reference backend. Note that the program shape of the module must not be
201   // modified.
202   ::testing::AssertionResult RunAndCompare(
203       std::unique_ptr<HloModule> module,
204       const absl::Span<Literal* const> arguments,
205       const absl::optional<ErrorSpec>& error,
206       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
207       TF_MUST_USE_RESULT;
208 
209   // Same as above, except that the module will be executed without Hlo
210   // optimization.
211   ::testing::AssertionResult RunAndCompareNoHloPasses(
212       std::unique_ptr<HloModule> module,
213       const absl::Span<Literal* const> arguments,
214       const absl::optional<ErrorSpec>& error,
215       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
216       TF_MUST_USE_RESULT;
217 
218   // Executes an hlo module with fake inputs and compares the results.
219   ::testing::AssertionResult RunAndCompare(
220       std::unique_ptr<HloModule> module, const absl::optional<ErrorSpec>& error,
221       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
222       TF_MUST_USE_RESULT;
223 
224   // Same as above, except that the module will be executed without Hlo
225   // optimization.
226   ::testing::AssertionResult RunAndCompareNoHloPasses(
227       std::unique_ptr<HloModule> module, const absl::optional<ErrorSpec>& error,
228       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
229       TF_MUST_USE_RESULT;
230 
231   // Convenient wrappers for executing and comparing an hlo module with fake
232   // input. Module can be passed in directly, or parsed from an hlo_string,
233   // or loaded from a file.
234   ::testing::AssertionResult RunAndCompare(
235       const absl::string_view hlo_string,
236       const absl::optional<ErrorSpec>& error,
237       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
238       TF_MUST_USE_RESULT;
239   ::testing::AssertionResult Run(const absl::string_view hlo_string,
240                                  bool run_hlo_passes = true,
241                                  ExecutionProfile* profile = nullptr,
242                                  string backend_config = "") TF_MUST_USE_RESULT;
243   ::testing::AssertionResult RunMultipleTimes(
244       const absl::string_view hlo_string, bool run_hlo_passes,
245       std::vector<ExecutionProfile>* profiles,
246       string backend_config = "") TF_MUST_USE_RESULT;
247   ::testing::AssertionResult RunAndCompareFromFile(
248       const string& filename, const absl::optional<ErrorSpec>& error,
249       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
250       TF_MUST_USE_RESULT;
251   ::testing::AssertionResult RunAndCompareNoHloPasses(
252       const absl::string_view hlo_string,
253       const absl::optional<ErrorSpec>& error,
254       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
255       TF_MUST_USE_RESULT;
256   ::testing::AssertionResult RunAndCompareNoHloPassesFromFile(
257       const string& filename, const absl::optional<ErrorSpec>& error,
258       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
259       TF_MUST_USE_RESULT;
260 
261   // Convenience method to force the layout of a given parameter in a module.
262   // The layout of parameter number 'param_no' in the 'module' is set to
263   // 'layout'.
ForceParameterLayout(HloModule * module,int64 param_no,const Layout & layout)264   void ForceParameterLayout(HloModule* module, int64 param_no,
265                             const Layout& layout) {
266     ASSERT_LT(param_no,
267               module->mutable_entry_computation_layout()->parameter_count());
268     module->mutable_entry_computation_layout()
269         ->mutable_parameter_layout(param_no)
270         ->ResetLayout(layout);
271   }
272 
273   // Convenience method to force the layout of the computation result in a
274   // module. The result layout of 'module' is set to 'layout'.
ForceResultLayout(HloModule * module,const Layout & layout)275   void ForceResultLayout(HloModule* module, const Layout& layout) {
276     module->mutable_entry_computation_layout()
277         ->mutable_result_layout()
278         ->ResetLayout(layout);
279   }
280 
ForceResultLayout(HloModule * module,const Layout & layout,ShapeIndexView shape_index)281   void ForceResultLayout(HloModule* module, const Layout& layout,
282                          ShapeIndexView shape_index) {
283     module->mutable_entry_computation_layout()
284         ->mutable_result_layout()
285         ->ResetLayout(layout, shape_index);
286   }
287 
288   // Convenience method to clear the layout of the computation result in
289   // 'module'.
ForceClearResultLayout(HloModule * module)290   void ForceClearResultLayout(HloModule* module) {
291     module->mutable_entry_computation_layout()
292         ->mutable_result_layout()
293         ->Clear();
294   }
295 
296   // Gets the computation/instruction from the given module with the given name.
297   //
298   // This is useful for tests which create HLOs from a string and then want to
299   // inspect a particular computation or instruction.
300   HloComputation* FindComputation(HloModule* module, absl::string_view name);
301   HloInstruction* FindInstruction(HloModule* module, absl::string_view name);
302 
303   // Return an HLO verifier constructed for the test backend.
verifier()304   HloVerifier& verifier() const { return *hlo_verifier_; }
305 
306   static string TestName();
307 
308   // Returns the backend owned by the test runner.
309   Backend& backend();
310 
311   HloRunner test_runner_;
312   HloRunner reference_runner_;
313 
314   bool verifier_layout_sensitive_;
315   bool allow_mixed_precision_in_hlo_verifier_;
316   std::unique_ptr<HloVerifier> hlo_verifier_;
317 
318   ErrorSpec error_spec_{0.0001};
319 
320  private:
321   // Given the test module, makes a reference module that is ready to run on the
322   // reference platform. This assumes that the given module is ready to run on
323   // the test platform.
324   StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule(
325       const HloModule& test_module,
326       const std::function<void(HloModule*)>& reference_preprocessor);
327 
328   // Runs the module on two platforms with or without running hlo passes and
329   // compares the results. Returns whether the results are near or equal. If any
330   // error happens before the results are computed, returns the error status.
331   StatusOr<::testing::AssertionResult> RunAndCompareInternal(
332       std::unique_ptr<HloModule> module,
333       const absl::Span<Literal* const> arguments,
334       const absl::optional<ErrorSpec>& error, bool run_hlo_passes,
335       const std::function<void(HloModule*)>& reference_preprocessor);
336 };
337 
338 }  // namespace xla
339 
340 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
341