• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
17 
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/algorithm/container.h"
24 #include "absl/memory/memory.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/debug_options_flags.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/service/hlo_parser.h"
30 #include "tensorflow/compiler/xla/service/platform_util.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
34 #include "tensorflow/compiler/xla/tests/test_utils.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/core/lib/core/status_test_util.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/test.h"
39 #include "tensorflow/core/platform/types.h"
40 
41 namespace xla {
42 
43 namespace {
44 
45 using absl::optional;
46 using absl::string_view;
47 
48 constexpr char kInterpreter[] = "interpreter";
49 
50 // Helper functions to get test and reference platforms.
GetReferencePlatform()51 se::Platform* GetReferencePlatform() {
52   auto result = PlatformUtil::GetPlatform(kInterpreter);
53   TF_CHECK_OK(result.status()) << "could not get interpreter platform";
54   return result.ValueOrDie();
55 }
56 
GetTestPlatform()57 se::Platform* GetTestPlatform() {
58   auto result = PlatformUtil::GetDefaultPlatform();
59   TF_CHECK_OK(result.status()) << "could not get test platform";
60   return result.ValueOrDie();
61 }
62 
ProgramShapesEqual(const ProgramShape & lhs,const ProgramShape & rhs)63 bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) {
64   if (lhs.parameters_size() != rhs.parameters_size()) {
65     return false;
66   }
67   for (int i = 0; i < lhs.parameters_size(); i++) {
68     if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) {
69       return false;
70     }
71   }
72   return ShapeUtil::Equal(lhs.result(), rhs.result());
73 }
74 
GetProgramShapeWithLayout(const HloModule & module)75 ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
76   ProgramShape program_shape;
77   const auto* entry = module.entry_computation();
78   for (const auto* param : entry->parameter_instructions()) {
79     *program_shape.add_parameters() = param->shape();
80     *program_shape.add_parameter_names() = param->name();
81   }
82   *program_shape.mutable_result() = entry->root_instruction()->shape();
83   return program_shape;
84 }
85 
86 }  // namespace
87 
HloTestBase(bool verifier_layout_sensitive,bool allow_mixed_precision_in_hlo_verifier,std::function<bool (const HloInstruction *)> instruction_can_change_layout_func)88 HloTestBase::HloTestBase(bool verifier_layout_sensitive,
89                          bool allow_mixed_precision_in_hlo_verifier,
90                          std::function<bool(const HloInstruction*)>
91                              instruction_can_change_layout_func)
92     : HloTestBase(GetTestPlatform(), GetReferencePlatform(),
93                   verifier_layout_sensitive,
94                   allow_mixed_precision_in_hlo_verifier,
95                   instruction_can_change_layout_func) {}
96 
HloTestBase(se::Platform * test_platform,se::Platform * reference_platform,bool verifier_layout_sensitive,bool allow_mixed_precision_in_hlo_verifier,std::function<bool (const HloInstruction *)> instruction_can_change_layout_func)97 HloTestBase::HloTestBase(se::Platform* test_platform,
98                          se::Platform* reference_platform,
99                          bool verifier_layout_sensitive,
100                          bool allow_mixed_precision_in_hlo_verifier,
101                          std::function<bool(const HloInstruction*)>
102                              instruction_can_change_layout_func)
103     : test_runner_(test_platform),
104       reference_runner_(reference_platform),
105       verifier_layout_sensitive_(verifier_layout_sensitive),
106       allow_mixed_precision_in_hlo_verifier_(
107           allow_mixed_precision_in_hlo_verifier) {
108   hlo_verifier_ = absl::make_unique<HloVerifier>(
109       /*layout_sensitive=*/verifier_layout_sensitive,
110       /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier,
111       instruction_can_change_layout_func);
112 }
113 
CreateNewUnverifiedModule(const string & name)114 std::unique_ptr<HloModule> HloTestBase::CreateNewUnverifiedModule(
115     const string& name) {
116   return absl::make_unique<HloModule>(name, GetModuleConfigForTest());
117 }
118 
CreateNewVerifiedModule(const string & name)119 std::unique_ptr<VerifiedHloModule> HloTestBase::CreateNewVerifiedModule(
120     const string& name) {
121   return absl::make_unique<VerifiedHloModule>(
122       name, GetModuleConfigForTest(), verifier_layout_sensitive_,
123       allow_mixed_precision_in_hlo_verifier_,
124       backend().compiler()->ShapeSizeBytesFunction());
125 }
126 
127 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text)128 HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) {
129   return ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest());
130 }
131 
132 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,const HloModuleConfig & config)133 HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
134                                           const HloModuleConfig& config) {
135   auto module = absl::make_unique<VerifiedHloModule>(
136       TestName(), config, verifier_layout_sensitive_,
137       allow_mixed_precision_in_hlo_verifier_,
138       backend().compiler()->ShapeSizeBytesFunction());
139   TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
140   return std::move(module);
141 }
142 
143 /* static */
RunHloPass(HloPassInterface * hlo_pass,HloModule * module)144 StatusOr<bool> HloTestBase::RunHloPass(HloPassInterface* hlo_pass,
145                                        HloModule* module) {
146   const string module_str_before_run = module->ToProto().ShortDebugString();
147   const auto status_or = hlo_pass->Run(module);
148   if (status_or.status().ok()) {
149     const string module_str_after_run = module->ToProto().ShortDebugString();
150     if (!status_or.ValueOrDie()) {
151       // Check that the proto remains same.
152       EXPECT_EQ(module_str_after_run, module_str_before_run);
153     }
154   }
155   return status_or;
156 }
157 
158 /* static */
DefaultPrecisionConfig(int operands)159 PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) {
160   PrecisionConfig precision_config;
161   precision_config.mutable_operand_precision()->Resize(
162       operands, PrecisionConfig::DEFAULT);
163   return precision_config;
164 }
165 
GetDebugOptionsForTest()166 DebugOptions HloTestBase::GetDebugOptionsForTest() {
167   auto debug_options = GetDebugOptionsFromFlags();
168   // TODO(b/38354253): Change tests to use Parameters instead of Constants.
169   debug_options.add_xla_disable_hlo_passes("constant_folding");
170   debug_options.set_xla_gpu_max_kernel_unroll_factor(1);
171   debug_options.set_xla_hlo_evaluator_use_fast_path(true);
172   return debug_options;
173 }
174 
Execute(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)175 StatusOr<Literal> HloTestBase::Execute(std::unique_ptr<HloModule> module,
176                                        absl::Span<Literal* const> arguments) {
177   return test_runner_.Execute(std::move(module), arguments);
178 }
179 
ExecuteNoHloPasses(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)180 Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
181                                         absl::Span<Literal* const> arguments) {
182   return test_runner_
183       .Execute(std::move(module), arguments,
184                /*run_hlo_passes=*/false)
185       .ValueOrDie();
186 }
187 
ExecuteAndTransfer(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)188 Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
189                                         absl::Span<Literal* const> arguments) {
190   return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
191 }
192 
ExecuteReplicated(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments,int64 num_replicas,bool use_threads,bool run_hlo_passes)193 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
194     std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
195     int64 num_replicas, bool use_threads, bool run_hlo_passes) {
196   HloRunner::ReplicatedExecuteOptions options;
197   options.num_replicas = num_replicas;
198   options.run_hlo_passes = run_hlo_passes;
199   options.use_threads = use_threads;
200   for (auto argument : arguments) {
201     options.arguments.push_back(argument);
202   }
203   return test_runner_.ExecuteReplicated(std::move(module), options);
204 }
205 
ExecuteReplicated(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments,int64 num_replicas,DeviceAssignment * device_assignment,bool run_hlo_passes,bool use_threads)206 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
207     std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
208     int64 num_replicas, DeviceAssignment* device_assignment,
209     bool run_hlo_passes, bool use_threads) {
210   HloRunner::ReplicatedExecuteOptions options;
211   options.num_replicas = num_replicas;
212   options.run_hlo_passes = run_hlo_passes;
213   options.use_threads = use_threads;
214   for (auto argument : arguments) {
215     options.arguments.push_back(argument);
216   }
217   return test_runner_.ExecuteReplicated(std::move(module), options,
218                                         device_assignment);
219 }
220 
MakeReferenceModule(const HloModule & test_module,const std::function<void (HloModule *)> & reference_preprocessor)221 StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
222     const HloModule& test_module,
223     const std::function<void(HloModule*)>& reference_preprocessor) {
224   std::unique_ptr<HloModule> reference_module = test_module.Clone();
225   const auto& program_shape = GetProgramShapeWithLayout(test_module);
226 
227   if (reference_preprocessor != nullptr) {
228     reference_preprocessor(reference_module.get());
229     if (!ProgramShapesEqual(program_shape,
230                             GetProgramShapeWithLayout(*reference_module))) {
231       return InvalidArgument(
232           "reference preprocessor must not modify the program shape");
233     }
234   }
235   TF_RETURN_IF_ERROR(hlo_verifier_->Run(reference_module.get()).status());
236   return std::move(reference_module);
237 }
238 
RunAndCompareInternal(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,bool run_hlo_passes,const std::function<void (HloModule *)> & reference_preprocessor)239 StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
240     std::unique_ptr<HloModule> module,
241     const absl::Span<Literal* const> arguments,
242     const optional<ErrorSpec>& error, bool run_hlo_passes,
243     const std::function<void(HloModule*)>& reference_preprocessor) {
244   TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status());
245   TF_ASSIGN_OR_RETURN(auto reference_module,
246                       MakeReferenceModule(*module, reference_preprocessor));
247 
248   // Execute on two backends.
249   TF_ASSIGN_OR_RETURN(
250       auto test,
251       test_runner_.Execute(std::move(module), arguments, run_hlo_passes));
252   TF_ASSIGN_OR_RETURN(auto reference,
253                       reference_runner_.Execute(std::move(reference_module),
254                                                 arguments, run_hlo_passes));
255   return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test,
256                                       error);
257 }
258 
RunAndCompare(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)259 ::testing::AssertionResult HloTestBase::RunAndCompare(
260     std::unique_ptr<HloModule> module,
261     const absl::Span<Literal* const> arguments,
262     const optional<ErrorSpec>& error,
263     const std::function<void(HloModule*)>& reference_preprocessor) {
264   auto result =
265       RunAndCompareInternal(std::move(module), arguments, error,
266                             /*run_hlo_passes=*/true, reference_preprocessor);
267   if (!result.ok()) {
268     return ::testing::AssertionFailure() << result.status();
269   }
270   return result.ValueOrDie();
271 }
272 
RunAndCompareNoHloPasses(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)273 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
274     std::unique_ptr<HloModule> module,
275     const absl::Span<Literal* const> arguments,
276     const optional<ErrorSpec>& error,
277     const std::function<void(HloModule*)>& reference_preprocessor) {
278   auto result =
279       RunAndCompareInternal(std::move(module), arguments, error,
280                             /*run_hlo_passes=*/false, reference_preprocessor);
281   if (!result.ok()) {
282     return ::testing::AssertionFailure() << result.status();
283   }
284   return result.ValueOrDie();
285 }
286 
RunAndCompare(std::unique_ptr<HloModule> module,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)287 ::testing::AssertionResult HloTestBase::RunAndCompare(
288     std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
289     const std::function<void(HloModule*)>& reference_preprocessor) {
290   auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie();
291 
292   std::vector<Literal*> fake_argument_ptrs;
293   absl::c_transform(
294       fake_arguments, std::back_inserter(fake_argument_ptrs),
295       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
296 
297   return RunAndCompare(std::move(module), fake_argument_ptrs, error,
298                        reference_preprocessor);
299 }
300 
RunAndCompareNoHloPasses(std::unique_ptr<HloModule> module,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)301 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
302     std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
303     const std::function<void(HloModule*)>& reference_preprocessor) {
304   const auto& fake_arguments =
305       MakeFakeArguments(module.get()).ConsumeValueOrDie();
306   std::vector<Literal*> fake_argument_ptrs;
307   absl::c_transform(
308       fake_arguments, std::back_inserter(fake_argument_ptrs),
309       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
310 
311   return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error,
312                                   reference_preprocessor);
313 }
314 
RunAndCompare(string_view hlo_string,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)315 ::testing::AssertionResult HloTestBase::RunAndCompare(
316     string_view hlo_string, const absl::optional<ErrorSpec>& error,
317     const std::function<void(HloModule*)>& reference_preprocessor) {
318   auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
319   if (!module_or_status.ok()) {
320     return ::testing::AssertionFailure()
321            << "Error while parsing HLO text format: "
322            << module_or_status.status().ToString();
323   }
324   return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
325                        reference_preprocessor);
326 }
327 
Run(string_view hlo_string,bool run_hlo_passes,ExecutionProfile * profile,string backend_config)328 ::testing::AssertionResult HloTestBase::Run(string_view hlo_string,
329                                             bool run_hlo_passes,
330                                             ExecutionProfile* profile,
331                                             string backend_config) {
332   auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
333   if (!module_or_status.ok()) {
334     return ::testing::AssertionFailure()
335            << "Error while parsing HLO text format: "
336            << module_or_status.status().ToString();
337   }
338 
339   std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
340   const auto& fake_arguments =
341       MakeFakeArguments(module.get()).ConsumeValueOrDie();
342   std::vector<Literal*> fake_argument_ptrs;
343   absl::c_transform(
344       fake_arguments, std::back_inserter(fake_argument_ptrs),
345       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
346 
347   if (profile != nullptr) {
348     // We have to enable HLO profiling since otherwise currently the
349     // ExecutionProfile is not correct.
350     //
351     // TODO(b/119432044): Fix collection of the ExecutionProfile
352     // so that this is not necessary.
353     HloModuleConfig config = module->config();
354     DebugOptions debug_options = config.debug_options();
355     debug_options.set_xla_hlo_profile(true);
356     config.set_debug_options(debug_options);
357     module->set_config(config);
358   }
359 
360   if (!backend_config.empty()) {
361     // Set backend configuration if it is given.
362     HloInstruction* instruction =
363         module->entry_computation()->root_instruction();
364     instruction->set_raw_backend_config_string(backend_config);
365   }
366 
367   auto output = test_runner_.Execute(std::move(module), fake_argument_ptrs,
368                                      /*run_hlo_passes=*/run_hlo_passes,
369                                      /*profile=*/profile);
370 
371   return output.ok()
372              ? ::testing::AssertionSuccess()
373              : ::testing::AssertionFailure() << output.status().error_message();
374 }
375 
RunMultipleTimes(string_view hlo_string,bool run_hlo_passes,std::vector<ExecutionProfile> * profiles,string backend_config,bool assert_determinism)376 ::testing::AssertionResult HloTestBase::RunMultipleTimes(
377     string_view hlo_string, bool run_hlo_passes,
378     std::vector<ExecutionProfile>* profiles, string backend_config,
379     bool assert_determinism) {
380   int n = profiles->size();
381   std::vector<std::vector<Literal*>> fake_argument_ptrs(n);
382   std::vector<std::vector<Literal>> fake_arguments(n);
383   std::vector<std::unique_ptr<Executable>> executables(n);
384 
385   for (int i = 0; i < n; ++i) {
386     auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
387     if (!module_or_status.ok()) {
388       return ::testing::AssertionFailure()
389              << "Error while parsing HLO text format: "
390              << module_or_status.status().ToString();
391     }
392     std::unique_ptr<HloModule> module =
393         std::move(module_or_status.ValueOrDie());
394 
395     fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie();
396     absl::c_transform(
397         fake_arguments[i], std::back_inserter(fake_argument_ptrs[i]),
398         [](const Literal& literal) { return const_cast<Literal*>(&literal); });
399 
400     if (profiles != nullptr) {
401       // We have to enable HLO profiling since otherwise currently the
402       // ExecutionProfile is not correct.
403       //
404       // TODO(b/119432044): Fix collection of the ExecutionProfile
405       // so that this is not necessary.
406       HloModuleConfig config = module->config();
407       DebugOptions debug_options = config.debug_options();
408       debug_options.set_xla_hlo_profile(true);
409       config.set_debug_options(debug_options);
410       module->set_config(config);
411     }
412 
413     if (!backend_config.empty()) {
414       // Set backend configuration if it is given.
415       HloInstruction* instruction =
416           module->entry_computation()->root_instruction();
417       instruction->set_raw_backend_config_string(backend_config);
418     }
419 
420     auto executable =
421         test_runner_.CreateExecutable(std::move(module), run_hlo_passes);
422     if (!executable.ok()) {
423       return ::testing::AssertionFailure()
424              << executable.status().error_message();
425     }
426     executables[i] = std::move(executable.ValueOrDie());
427   }
428 
429   absl::optional<Literal> canonical_output;
430   for (int i = 0; i < n; ++i) {
431     StatusOr<Literal> output =
432         test_runner_.Execute(std::move(executables[i]), fake_argument_ptrs[i],
433                              /*profile=*/&((*profiles)[i]));
434     if (!output.ok()) {
435       return ::testing::AssertionFailure() << output.status().error_message();
436     }
437 
438     if (assert_determinism) {
439       if (!canonical_output.has_value()) {
440         canonical_output = output.ConsumeValueOrDie();
441       } else {
442         if (*canonical_output != output.ValueOrDie()) {
443           return ::testing::AssertionFailure()
444                  << "Successive runs have returned different results: "
445                  << *canonical_output << " vs. " << output.ValueOrDie();
446         }
447       }
448     }
449   }
450 
451   return ::testing::AssertionSuccess();
452 }
453 
RunAndCompareFromFile(const string & filename,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)454 ::testing::AssertionResult HloTestBase::RunAndCompareFromFile(
455     const string& filename, const absl::optional<ErrorSpec>& error,
456     const std::function<void(HloModule*)>& reference_preprocessor) {
457   auto module_or_status =
458       HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
459   if (!module_or_status.ok()) {
460     return ::testing::AssertionFailure()
461            << "failed reading hlo module from file";
462   }
463   return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
464                        reference_preprocessor);
465 }
466 
RunAndCompareNoHloPasses(string_view hlo_string,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)467 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
468     string_view hlo_string, const absl::optional<ErrorSpec>& error,
469     const std::function<void(HloModule*)>& reference_preprocessor) {
470   auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
471   if (!module_or_status.ok()) {
472     return ::testing::AssertionFailure()
473            << "Error while parsing HLO text format: "
474            << module_or_status.status().ToString();
475   }
476   return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
477                                   reference_preprocessor);
478 }
479 
RunAndCompareNoHloPassesFromFile(const string & filename,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)480 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile(
481     const string& filename, const absl::optional<ErrorSpec>& error,
482     const std::function<void(HloModule*)>& reference_preprocessor) {
483   auto module_or_status =
484       HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
485   if (!module_or_status.ok()) {
486     return ::testing::AssertionFailure()
487            << "failed reading hlo module from file";
488   }
489   return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
490                                   reference_preprocessor);
491 }
492 
FindComputation(HloModule * module,absl::string_view name)493 HloComputation* HloTestBase::FindComputation(HloModule* module,
494                                              absl::string_view name) {
495   auto computations = module->computations();
496   auto it = absl::c_find_if(
497       computations, [&](HloComputation* c) { return c->name() == name; });
498   if (it == computations.end()) {
499     return nullptr;
500   }
501   return *it;
502 }
503 
FindInstruction(HloModule * module,absl::string_view name)504 HloInstruction* HloTestBase::FindInstruction(HloModule* module,
505                                              absl::string_view name) {
506   for (const HloComputation* c : module->computations()) {
507     auto instructions = c->instructions();
508     auto it = absl::c_find_if(
509         instructions, [&](HloInstruction* i) { return i->name() == name; });
510     if (it != instructions.end()) {
511       return *it;
512     }
513   }
514   return nullptr;
515 }
516 
FindInstruction(HloModule * module,HloOpcode opcode)517 HloInstruction* HloTestBase::FindInstruction(HloModule* module,
518                                              HloOpcode opcode) {
519   for (const HloComputation* c : module->computations()) {
520     auto instructions = c->instructions();
521     auto it = absl::c_find_if(
522         instructions, [&](HloInstruction* i) { return i->opcode() == opcode; });
523     if (it != instructions.end()) {
524       return *it;
525     }
526   }
527   return nullptr;
528 }
529 
backend()530 Backend& HloTestBase::backend() { return test_runner_.backend(); }
531 
532 /* static */
TestName()533 string HloTestBase::TestName() {
534   return ::testing::UnitTest::GetInstance()->current_test_info()->name();
535 }
536 
537 }  // namespace xla
538