• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
17 
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/algorithm/container.h"
24 #include "absl/memory/memory.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/debug_options_flags.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/service/hlo_parser.h"
30 #include "tensorflow/compiler/xla/service/platform_util.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
34 #include "tensorflow/compiler/xla/tests/test_utils.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/core/lib/core/status_test_util.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/test.h"
39 #include "tensorflow/core/platform/types.h"
40 
41 namespace xla {
42 
43 namespace {
44 
using absl::optional;
using absl::string_view;

// Name of the platform looked up by GetReferencePlatform(); results computed
// on it serve as the expected values in RunAndCompare*.
constexpr char kInterpreter[] = "interpreter";
49 
50 // Helper functions to get test and reference platforms.
GetReferencePlatform()51 se::Platform* GetReferencePlatform() {
52   auto result = PlatformUtil::GetPlatform(kInterpreter);
53   TF_CHECK_OK(result.status()) << "could not get interpreter platform";
54   return result.ValueOrDie();
55 }
56 
GetTestPlatform()57 se::Platform* GetTestPlatform() {
58   auto result = PlatformUtil::GetDefaultPlatform();
59   TF_CHECK_OK(result.status()) << "could not get test platform";
60   return result.ValueOrDie();
61 }
62 
ProgramShapesEqual(const ProgramShape & lhs,const ProgramShape & rhs)63 bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) {
64   if (lhs.parameters_size() != rhs.parameters_size()) {
65     return false;
66   }
67   for (int i = 0; i < lhs.parameters_size(); i++) {
68     if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) {
69       return false;
70     }
71   }
72   return ShapeUtil::Equal(lhs.result(), rhs.result());
73 }
74 
GetProgramShapeWithLayout(const HloModule & module)75 ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
76   ProgramShape program_shape;
77   const auto* entry = module.entry_computation();
78   for (const auto* param : entry->parameter_instructions()) {
79     *program_shape.add_parameters() = param->shape();
80     *program_shape.add_parameter_names() = param->name();
81   }
82   *program_shape.mutable_result() = entry->root_instruction()->shape();
83   return program_shape;
84 }
85 
86 }  // namespace
87 
HloTestBase(bool verifier_layout_sensitive,bool allow_mixed_precision_in_hlo_verifier,std::function<bool (const HloInstruction *)> instruction_can_change_layout_func)88 HloTestBase::HloTestBase(bool verifier_layout_sensitive,
89                          bool allow_mixed_precision_in_hlo_verifier,
90                          std::function<bool(const HloInstruction*)>
91                              instruction_can_change_layout_func)
92     : HloTestBase(GetTestPlatform(), GetReferencePlatform(),
93                   verifier_layout_sensitive,
94                   allow_mixed_precision_in_hlo_verifier,
95                   instruction_can_change_layout_func) {}
96 
HloTestBase(se::Platform * test_platform,se::Platform * reference_platform,bool verifier_layout_sensitive,bool allow_mixed_precision_in_hlo_verifier,std::function<bool (const HloInstruction *)> instruction_can_change_layout_func)97 HloTestBase::HloTestBase(se::Platform* test_platform,
98                          se::Platform* reference_platform,
99                          bool verifier_layout_sensitive,
100                          bool allow_mixed_precision_in_hlo_verifier,
101                          std::function<bool(const HloInstruction*)>
102                              instruction_can_change_layout_func)
103     : test_runner_(test_platform),
104       reference_runner_(reference_platform),
105       verifier_layout_sensitive_(verifier_layout_sensitive),
106       allow_mixed_precision_in_hlo_verifier_(
107           allow_mixed_precision_in_hlo_verifier) {
108   hlo_verifier_ = absl::make_unique<HloVerifier>(
109       /*layout_sensitive=*/verifier_layout_sensitive,
110       /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier,
111       instruction_can_change_layout_func);
112 }
113 
CreateNewUnverifiedModule(const string & name)114 std::unique_ptr<HloModule> HloTestBase::CreateNewUnverifiedModule(
115     const string& name) {
116   return absl::make_unique<HloModule>(name, GetModuleConfigForTest());
117 }
118 
CreateNewVerifiedModule(const string & name,int64 replica_count)119 std::unique_ptr<VerifiedHloModule> HloTestBase::CreateNewVerifiedModule(
120     const string& name, int64 replica_count) {
121   return absl::make_unique<VerifiedHloModule>(
122       name, GetModuleConfigForTest(replica_count), verifier_layout_sensitive_,
123       allow_mixed_precision_in_hlo_verifier_,
124       backend().compiler()->ShapeSizeBytesFunction());
125 }
126 
127 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,int64 replica_count)128 HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
129                                           int64 replica_count) {
130   return ParseAndReturnVerifiedModule(hlo_text,
131                                       GetModuleConfigForTest(replica_count));
132 }
133 
134 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,const HloModuleConfig & config)135 HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
136                                           const HloModuleConfig& config) {
137   auto module = absl::make_unique<VerifiedHloModule>(
138       TestName(), config, verifier_layout_sensitive_,
139       allow_mixed_precision_in_hlo_verifier_,
140       backend().compiler()->ShapeSizeBytesFunction());
141   TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
142   return std::move(module);
143 }
144 
145 /* static */
RunHloPass(HloPassInterface * hlo_pass,HloModule * module)146 StatusOr<bool> HloTestBase::RunHloPass(HloPassInterface* hlo_pass,
147                                        HloModule* module) {
148   const string module_str_before_run = module->ToProto().ShortDebugString();
149   const auto status_or = hlo_pass->Run(module);
150   if (status_or.status().ok()) {
151     const string module_str_after_run = module->ToProto().ShortDebugString();
152     if (!status_or.ValueOrDie()) {
153       // Check that the proto remains same.
154       EXPECT_EQ(module_str_after_run, module_str_before_run);
155     }
156   }
157   return status_or;
158 }
159 
160 /* static */
DefaultPrecisionConfig(int operands)161 PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) {
162   PrecisionConfig precision_config;
163   precision_config.mutable_operand_precision()->Resize(
164       operands, PrecisionConfig::DEFAULT);
165   return precision_config;
166 }
167 
SetAotFastMathDebugOptions(DebugOptions * options)168 void HloTestBase::SetAotFastMathDebugOptions(DebugOptions* options) {
169   options->set_xla_cpu_enable_fast_math(true);
170   options->set_xla_gpu_enable_fast_min_max(true);
171   options->set_xla_cpu_enable_fast_min_max(true);
172   options->set_xla_cpu_fast_math_honor_nans(false);
173   options->set_xla_cpu_fast_math_honor_infs(false);
174   options->set_xla_cpu_fast_math_honor_functions(false);
175   options->set_xla_cpu_fast_math_honor_division(false);
176 }
177 
GetDebugOptionsForTest()178 DebugOptions HloTestBase::GetDebugOptionsForTest() {
179   auto debug_options = GetDebugOptionsFromFlags();
180   // TODO(b/38354253): Change tests to use Parameters instead of Constants.
181   debug_options.add_xla_disable_hlo_passes("constant_folding");
182   debug_options.set_xla_gpu_max_kernel_unroll_factor(1);
183   debug_options.set_xla_hlo_evaluator_use_fast_path(true);
184   return debug_options;
185 }
186 
Execute(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)187 StatusOr<Literal> HloTestBase::Execute(std::unique_ptr<HloModule> module,
188                                        absl::Span<Literal* const> arguments) {
189   return test_runner_.Execute(std::move(module), arguments);
190 }
191 
ExecuteNoHloPasses(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)192 Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
193                                         absl::Span<Literal* const> arguments) {
194   return test_runner_
195       .Execute(std::move(module), arguments,
196                /*run_hlo_passes=*/false)
197       .ValueOrDie();
198 }
199 
ExecuteAndTransfer(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)200 Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
201                                         absl::Span<Literal* const> arguments) {
202   return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
203 }
204 
ExecuteReplicated(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments,int64 num_replicas,bool use_threads,bool run_hlo_passes)205 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
206     std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
207     int64 num_replicas, bool use_threads, bool run_hlo_passes) {
208   HloRunner::ReplicatedExecuteOptions options;
209   options.num_replicas = num_replicas;
210   options.run_hlo_passes = run_hlo_passes;
211   options.use_threads = use_threads;
212   for (auto argument : arguments) {
213     options.arguments.push_back(argument);
214   }
215   return test_runner_.ExecuteReplicated(std::move(module), options);
216 }
217 
ExecuteReplicated(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments,int64 num_replicas,DeviceAssignment * device_assignment,bool run_hlo_passes,bool use_threads)218 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
219     std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
220     int64 num_replicas, DeviceAssignment* device_assignment,
221     bool run_hlo_passes, bool use_threads) {
222   HloRunner::ReplicatedExecuteOptions options;
223   options.num_replicas = num_replicas;
224   options.run_hlo_passes = run_hlo_passes;
225   options.use_threads = use_threads;
226   for (auto argument : arguments) {
227     options.arguments.push_back(argument);
228   }
229   return test_runner_.ExecuteReplicated(std::move(module), options,
230                                         device_assignment);
231 }
232 
ExecuteReplicated(std::function<Executable * (int64)> executable_provider,std::function<int64 (int64)> argument_count_provider,std::function<const Literal * (int64,int64)> argument_provider,int64 num_replicas,bool run_hlo_passes)233 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
234     std::function<Executable*(int64)> executable_provider,
235     std::function<int64(int64)> argument_count_provider,
236     std::function<const Literal*(int64, int64)> argument_provider,
237     int64 num_replicas, bool run_hlo_passes) {
238   HloRunner::ReplicatedExecuteOptions options;
239   options.num_replicas = num_replicas;
240   options.run_hlo_passes = run_hlo_passes;
241   options.use_threads = true;
242   return test_runner_.ExecuteReplicated(
243       executable_provider, argument_count_provider, argument_provider, options);
244 }
245 
MakeReferenceModule(const HloModule & test_module,const std::function<void (HloModule *)> & reference_preprocessor)246 StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
247     const HloModule& test_module,
248     const std::function<void(HloModule*)>& reference_preprocessor) {
249   std::unique_ptr<HloModule> reference_module = test_module.Clone();
250   const auto& program_shape = GetProgramShapeWithLayout(test_module);
251 
252   if (reference_preprocessor != nullptr) {
253     reference_preprocessor(reference_module.get());
254     if (!ProgramShapesEqual(program_shape,
255                             GetProgramShapeWithLayout(*reference_module))) {
256       return InvalidArgument(
257           "reference preprocessor must not modify the program shape");
258     }
259   }
260   TF_RETURN_IF_ERROR(hlo_verifier_->Run(reference_module.get()).status());
261   return std::move(reference_module);
262 }
263 
RunAndCompareInternal(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,bool run_hlo_passes,const std::function<void (HloModule *)> & reference_preprocessor)264 StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
265     std::unique_ptr<HloModule> module,
266     const absl::Span<Literal* const> arguments,
267     const optional<ErrorSpec>& error, bool run_hlo_passes,
268     const std::function<void(HloModule*)>& reference_preprocessor) {
269   TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status());
270   TF_ASSIGN_OR_RETURN(auto reference_module,
271                       MakeReferenceModule(*module, reference_preprocessor));
272 
273   // Execute on two backends.
274   TF_ASSIGN_OR_RETURN(
275       auto test,
276       test_runner_.Execute(std::move(module), arguments, run_hlo_passes));
277   TF_ASSIGN_OR_RETURN(auto reference,
278                       reference_runner_.Execute(std::move(reference_module),
279                                                 arguments, run_hlo_passes));
280   return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test,
281                                       error);
282 }
283 
RunAndCompare(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)284 ::testing::AssertionResult HloTestBase::RunAndCompare(
285     std::unique_ptr<HloModule> module,
286     const absl::Span<Literal* const> arguments,
287     const optional<ErrorSpec>& error,
288     const std::function<void(HloModule*)>& reference_preprocessor) {
289   auto result =
290       RunAndCompareInternal(std::move(module), arguments, error,
291                             /*run_hlo_passes=*/true, reference_preprocessor);
292   if (!result.ok()) {
293     return ::testing::AssertionFailure() << result.status();
294   }
295   return result.ValueOrDie();
296 }
297 
RunAndCompareNoHloPasses(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)298 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
299     std::unique_ptr<HloModule> module,
300     const absl::Span<Literal* const> arguments,
301     const optional<ErrorSpec>& error,
302     const std::function<void(HloModule*)>& reference_preprocessor) {
303   auto result =
304       RunAndCompareInternal(std::move(module), arguments, error,
305                             /*run_hlo_passes=*/false, reference_preprocessor);
306   if (!result.ok()) {
307     return ::testing::AssertionFailure() << result.status();
308   }
309   return result.ValueOrDie();
310 }
311 
RunAndCompare(std::unique_ptr<HloModule> module,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)312 ::testing::AssertionResult HloTestBase::RunAndCompare(
313     std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
314     const std::function<void(HloModule*)>& reference_preprocessor) {
315   auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie();
316 
317   std::vector<Literal*> fake_argument_ptrs;
318   absl::c_transform(
319       fake_arguments, std::back_inserter(fake_argument_ptrs),
320       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
321 
322   return RunAndCompare(std::move(module), fake_argument_ptrs, error,
323                        reference_preprocessor);
324 }
325 
RunAndCompareNoHloPasses(std::unique_ptr<HloModule> module,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)326 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
327     std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
328     const std::function<void(HloModule*)>& reference_preprocessor) {
329   const auto& fake_arguments =
330       MakeFakeArguments(module.get()).ConsumeValueOrDie();
331   std::vector<Literal*> fake_argument_ptrs;
332   absl::c_transform(
333       fake_arguments, std::back_inserter(fake_argument_ptrs),
334       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
335 
336   return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error,
337                                   reference_preprocessor);
338 }
339 
Run(std::unique_ptr<HloModule> module,bool run_hlo_passes)340 ::testing::AssertionResult HloTestBase::Run(std::unique_ptr<HloModule> module,
341                                             bool run_hlo_passes) {
342   const auto fake_arguments =
343       MakeFakeArguments(module.get()).ConsumeValueOrDie();
344   const auto change = hlo_verifier_->Run(module.get());
345   if (!change.ok()) {
346     return ::testing::AssertionFailure() << change.status();
347   }
348 
349   const auto output =
350       test_runner_.Execute(std::move(module), fake_arguments, run_hlo_passes);
351   return output.ok()
352              ? ::testing::AssertionSuccess()
353              : ::testing::AssertionFailure() << output.status().error_message();
354 }
355 
RunAndCompare(string_view hlo_string,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)356 ::testing::AssertionResult HloTestBase::RunAndCompare(
357     string_view hlo_string, const absl::optional<ErrorSpec>& error,
358     const std::function<void(HloModule*)>& reference_preprocessor) {
359   auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
360   if (!module_or_status.ok()) {
361     return ::testing::AssertionFailure()
362            << "Error while parsing HLO text format: "
363            << module_or_status.status().ToString();
364   }
365   return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
366                        reference_preprocessor);
367 }
368 
Run(string_view hlo_string,bool run_hlo_passes,ExecutionProfile * profile,string backend_config)369 ::testing::AssertionResult HloTestBase::Run(string_view hlo_string,
370                                             bool run_hlo_passes,
371                                             ExecutionProfile* profile,
372                                             string backend_config) {
373   auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
374   if (!module_or_status.ok()) {
375     return ::testing::AssertionFailure()
376            << "Error while parsing HLO text format: "
377            << module_or_status.status().ToString();
378   }
379 
380   std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
381   const auto& fake_arguments =
382       MakeFakeArguments(module.get()).ConsumeValueOrDie();
383   std::vector<Literal*> fake_argument_ptrs;
384   absl::c_transform(
385       fake_arguments, std::back_inserter(fake_argument_ptrs),
386       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
387 
388   if (profile != nullptr) {
389     // We have to enable HLO profiling since otherwise currently the
390     // ExecutionProfile is not correct.
391     //
392     // TODO(b/119432044): Fix collection of the ExecutionProfile
393     // so that this is not necessary.
394     HloModuleConfig config = module->config();
395     DebugOptions debug_options = config.debug_options();
396     debug_options.set_xla_hlo_profile(true);
397     config.set_debug_options(debug_options);
398     module->set_config(config);
399   }
400 
401   if (!backend_config.empty()) {
402     // Set backend configuration if it is given.
403     HloInstruction* instruction =
404         module->entry_computation()->root_instruction();
405     instruction->set_raw_backend_config_string(backend_config);
406   }
407 
408   auto output = test_runner_.Execute(std::move(module), fake_argument_ptrs,
409                                      /*run_hlo_passes=*/run_hlo_passes,
410                                      /*profile=*/profile);
411 
412   return output.ok()
413              ? ::testing::AssertionSuccess()
414              : ::testing::AssertionFailure() << output.status().error_message();
415 }
416 
RunReplicated(string_view hlo_string,bool run_hlo_passes,int64 num_replicas,string backend_config)417 ::testing::AssertionResult HloTestBase::RunReplicated(string_view hlo_string,
418                                                       bool run_hlo_passes,
419                                                       int64 num_replicas,
420                                                       string backend_config) {
421   auto module_or_status =
422       ParseAndReturnVerifiedModule(hlo_string, num_replicas);
423   if (!module_or_status.ok()) {
424     return ::testing::AssertionFailure()
425            << "Error while parsing HLO text format: "
426            << module_or_status.status().ToString();
427   }
428 
429   std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
430   const auto& fake_arguments =
431       MakeFakeArguments(module.get()).ConsumeValueOrDie();
432   std::vector<Literal*> fake_argument_ptrs;
433   absl::c_transform(
434       fake_arguments, std::back_inserter(fake_argument_ptrs),
435       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
436 
437   if (!backend_config.empty()) {
438     // Set backend configuration if it is given.
439     HloInstruction* instruction =
440         module->entry_computation()->root_instruction();
441     instruction->set_raw_backend_config_string(backend_config);
442   }
443 
444   HloRunner::ReplicatedExecuteOptions options;
445   options.num_replicas = num_replicas;
446   options.run_hlo_passes = run_hlo_passes;
447   options.use_threads = true;
448   for (auto argument : fake_argument_ptrs) {
449     options.arguments.push_back(argument);
450   }
451   auto output = test_runner_.ExecuteReplicated(std::move(module), options);
452 
453   return output.ok()
454              ? ::testing::AssertionSuccess()
455              : ::testing::AssertionFailure() << output.status().error_message();
456 }
457 
RunMultipleTimes(string_view hlo_string,bool run_hlo_passes,std::vector<ExecutionProfile> * profiles,string backend_config,bool assert_determinism)458 ::testing::AssertionResult HloTestBase::RunMultipleTimes(
459     string_view hlo_string, bool run_hlo_passes,
460     std::vector<ExecutionProfile>* profiles, string backend_config,
461     bool assert_determinism) {
462   int n = profiles->size();
463   std::vector<std::vector<Literal*>> fake_argument_ptrs(n);
464   std::vector<std::vector<Literal>> fake_arguments(n);
465   std::vector<std::unique_ptr<Executable>> executables(n);
466 
467   for (int i = 0; i < n; ++i) {
468     auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
469     if (!module_or_status.ok()) {
470       return ::testing::AssertionFailure()
471              << "Error while parsing HLO text format: "
472              << module_or_status.status().ToString();
473     }
474     std::unique_ptr<HloModule> module =
475         std::move(module_or_status.ValueOrDie());
476 
477     fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie();
478 
479     if (profiles != nullptr) {
480       // We have to enable HLO profiling since otherwise currently the
481       // ExecutionProfile is not correct.
482       //
483       // TODO(b/119432044): Fix collection of the ExecutionProfile
484       // so that this is not necessary.
485       HloModuleConfig config = module->config();
486       DebugOptions debug_options = config.debug_options();
487       debug_options.set_xla_hlo_profile(true);
488       config.set_debug_options(debug_options);
489       module->set_config(config);
490     }
491 
492     if (!backend_config.empty()) {
493       // Set backend configuration if it is given.
494       HloInstruction* instruction =
495           module->entry_computation()->root_instruction();
496       instruction->set_raw_backend_config_string(backend_config);
497     }
498 
499     auto executable =
500         test_runner_.CreateExecutable(std::move(module), run_hlo_passes);
501     if (!executable.ok()) {
502       return ::testing::AssertionFailure()
503              << executable.status().error_message();
504     }
505     executables[i] = std::move(executable.ValueOrDie());
506   }
507 
508   absl::optional<Literal> canonical_output;
509   for (int i = 0; i < n; ++i) {
510     StatusOr<Literal> output = test_runner_.ExecuteWithExecutable(
511         std::move(executables[i]), fake_arguments[i],
512         /*profile=*/&((*profiles)[i]));
513     if (!output.ok()) {
514       return ::testing::AssertionFailure() << output.status().error_message();
515     }
516 
517     if (assert_determinism) {
518       if (!canonical_output.has_value()) {
519         canonical_output = output.ConsumeValueOrDie();
520       } else {
521         if (*canonical_output != output.ValueOrDie()) {
522           return ::testing::AssertionFailure()
523                  << "Successive runs have returned different results: "
524                  << *canonical_output << " vs. " << output.ValueOrDie();
525         }
526       }
527     }
528   }
529 
530   return ::testing::AssertionSuccess();
531 }
532 
RunAndCompareFromFile(const string & filename,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)533 ::testing::AssertionResult HloTestBase::RunAndCompareFromFile(
534     const string& filename, const absl::optional<ErrorSpec>& error,
535     const std::function<void(HloModule*)>& reference_preprocessor) {
536   auto module_or_status =
537       HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
538   if (!module_or_status.ok()) {
539     return ::testing::AssertionFailure()
540            << "failed reading hlo module from file";
541   }
542   return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
543                        reference_preprocessor);
544 }
545 
RunAndCompareNoHloPasses(string_view hlo_string,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)546 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
547     string_view hlo_string, const absl::optional<ErrorSpec>& error,
548     const std::function<void(HloModule*)>& reference_preprocessor) {
549   auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
550   if (!module_or_status.ok()) {
551     return ::testing::AssertionFailure()
552            << "Error while parsing HLO text format: "
553            << module_or_status.status().ToString();
554   }
555   return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
556                                   reference_preprocessor);
557 }
558 
RunAndCompareNoHloPassesFromFile(const string & filename,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)559 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile(
560     const string& filename, const absl::optional<ErrorSpec>& error,
561     const std::function<void(HloModule*)>& reference_preprocessor) {
562   auto module_or_status =
563       HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
564   if (!module_or_status.ok()) {
565     return ::testing::AssertionFailure()
566            << "failed reading hlo module from file";
567   }
568   return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
569                                   reference_preprocessor);
570 }
571 
FindComputation(HloModule * module,absl::string_view name)572 HloComputation* HloTestBase::FindComputation(HloModule* module,
573                                              absl::string_view name) {
574   auto computations = module->computations();
575   auto it = absl::c_find_if(
576       computations, [&](HloComputation* c) { return c->name() == name; });
577   if (it == computations.end()) {
578     return nullptr;
579   }
580   return *it;
581 }
582 
FindInstruction(HloModule * module,absl::string_view name)583 HloInstruction* HloTestBase::FindInstruction(HloModule* module,
584                                              absl::string_view name) {
585   for (const HloComputation* c : module->computations()) {
586     auto instructions = c->instructions();
587     auto it = absl::c_find_if(
588         instructions, [&](HloInstruction* i) { return i->name() == name; });
589     if (it != instructions.end()) {
590       return *it;
591     }
592   }
593   return nullptr;
594 }
595 
FindInstruction(HloModule * module,HloOpcode opcode)596 HloInstruction* HloTestBase::FindInstruction(HloModule* module,
597                                              HloOpcode opcode) {
598   for (const HloComputation* c : module->computations()) {
599     auto instructions = c->instructions();
600     auto it = absl::c_find_if(
601         instructions, [&](HloInstruction* i) { return i->opcode() == opcode; });
602     if (it != instructions.end()) {
603       return *it;
604     }
605   }
606   return nullptr;
607 }
608 
backend()609 Backend& HloTestBase::backend() { return test_runner_.backend(); }
610 
611 /* static */
TestName()612 string HloTestBase::TestName() {
613   return ::testing::UnitTest::GetInstance()->current_test_info()->name();
614 }
615 
616 }  // namespace xla
617