• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
17 
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/algorithm/container.h"
24 #include "absl/memory/memory.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/debug_options_flags.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/service/hlo_parser.h"
30 #include "tensorflow/compiler/xla/service/platform_util.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
34 #include "tensorflow/compiler/xla/tests/test_utils.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/core/lib/core/status_test_util.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/test.h"
39 #include "tensorflow/core/platform/types.h"
40 
41 namespace xla {
42 
43 namespace {
44 
45 using absl::optional;
46 using absl::string_view;
47 
48 constexpr char kInterpreter[] = "interpreter";
49 
50 // Helper functions to get test and reference platforms.
GetReferencePlatform()51 se::Platform* GetReferencePlatform() {
52   auto result = PlatformUtil::GetPlatform(kInterpreter);
53   TF_CHECK_OK(result.status()) << "could not get interpreter platform";
54   return result.ValueOrDie();
55 }
56 
GetTestPlatform()57 se::Platform* GetTestPlatform() {
58   auto result = PlatformUtil::GetDefaultPlatform();
59   TF_CHECK_OK(result.status()) << "could not get test platform";
60   return result.ValueOrDie();
61 }
62 
ProgramShapesEqual(const ProgramShape & lhs,const ProgramShape & rhs)63 bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) {
64   if (lhs.parameters_size() != rhs.parameters_size()) {
65     return false;
66   }
67   for (int i = 0; i < lhs.parameters_size(); i++) {
68     if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) {
69       return false;
70     }
71   }
72   return ShapeUtil::Equal(lhs.result(), rhs.result());
73 }
74 
GetProgramShapeWithLayout(const HloModule & module)75 ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
76   ProgramShape program_shape;
77   const auto* entry = module.entry_computation();
78   for (const auto* param : entry->parameter_instructions()) {
79     *program_shape.add_parameters() = param->shape();
80     *program_shape.add_parameter_names() = param->name();
81   }
82   *program_shape.mutable_result() = entry->root_instruction()->shape();
83   return program_shape;
84 }
85 
86 }  // namespace
87 
Verify()88 Status VerifiedHloModule::Verify() {
89   if (computation_count() == 0) {
90     // The computation was never built. Nothing to verify.
91     return Status::OK();
92   }
93   return verifier_.Run(this).status();
94 }
95 
VerifyOrAddFailure(const string & message)96 void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
97   Status status = Verify();
98   if (!status.ok()) {
99     ADD_FAILURE() << "HloVerifier failed on module " << name()
100                   << (message.empty() ? "" : absl::StrCat(" (", message, ")"))
101                   << ": " << status;
102     LOG(ERROR) << "Contents of bad module:";
103     XLA_LOG_LINES(tensorflow::ERROR, ToString());
104   }
105 }
106 
HloTestBase(bool verifier_layout_sensitive,bool allow_mixed_precision_in_hlo_verifier,std::function<bool (const HloInstruction *)> instruction_can_change_layout_func)107 HloTestBase::HloTestBase(bool verifier_layout_sensitive,
108                          bool allow_mixed_precision_in_hlo_verifier,
109                          std::function<bool(const HloInstruction*)>
110                              instruction_can_change_layout_func)
111     : HloTestBase(GetTestPlatform(), GetReferencePlatform(),
112                   verifier_layout_sensitive,
113                   allow_mixed_precision_in_hlo_verifier,
114                   instruction_can_change_layout_func) {}
115 
HloTestBase(se::Platform * test_platform,se::Platform * reference_platform,bool verifier_layout_sensitive,bool allow_mixed_precision_in_hlo_verifier,std::function<bool (const HloInstruction *)> instruction_can_change_layout_func)116 HloTestBase::HloTestBase(se::Platform* test_platform,
117                          se::Platform* reference_platform,
118                          bool verifier_layout_sensitive,
119                          bool allow_mixed_precision_in_hlo_verifier,
120                          std::function<bool(const HloInstruction*)>
121                              instruction_can_change_layout_func)
122     : test_runner_(test_platform),
123       reference_runner_(reference_platform),
124       verifier_layout_sensitive_(verifier_layout_sensitive),
125       allow_mixed_precision_in_hlo_verifier_(
126           allow_mixed_precision_in_hlo_verifier) {
127   hlo_verifier_ = absl::make_unique<HloVerifier>(
128       /*layout_sensitive=*/verifier_layout_sensitive,
129       /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier,
130       instruction_can_change_layout_func);
131 }
132 
CreateNewUnverifiedModule(const string & name)133 std::unique_ptr<HloModule> HloTestBase::CreateNewUnverifiedModule(
134     const string& name) {
135   return absl::make_unique<HloModule>(name, GetModuleConfigForTest());
136 }
137 
CreateNewVerifiedModule(const string & name)138 std::unique_ptr<VerifiedHloModule> HloTestBase::CreateNewVerifiedModule(
139     const string& name) {
140   return absl::make_unique<VerifiedHloModule>(
141       name, GetModuleConfigForTest(), verifier_layout_sensitive_,
142       allow_mixed_precision_in_hlo_verifier_,
143       backend().compiler()->ShapeSizeBytesFunction());
144 }
145 
146 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,const HloModuleConfig & config)147 HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
148                                           const HloModuleConfig& config) {
149   auto module = absl::make_unique<VerifiedHloModule>(
150       TestName(), config, verifier_layout_sensitive_,
151       allow_mixed_precision_in_hlo_verifier_,
152       backend().compiler()->ShapeSizeBytesFunction());
153   TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
154   TF_RETURN_IF_ERROR(module->Verify());
155   return std::move(module);
156 }
157 
158 /* static */
RunHloPass(HloPassInterface * hlo_pass,HloModule * module)159 StatusOr<bool> HloTestBase::RunHloPass(HloPassInterface* hlo_pass,
160                                        HloModule* module) {
161   const string module_str_before_run = module->ToProto().ShortDebugString();
162   const auto status_or = hlo_pass->Run(module);
163   if (status_or.status().ok()) {
164     const string module_str_after_run = module->ToProto().ShortDebugString();
165     if (!status_or.ValueOrDie()) {
166       // Check that the proto remains same.
167       EXPECT_EQ(module_str_after_run, module_str_before_run);
168     }
169   }
170   return status_or;
171 }
172 
173 /* static */
DefaultPrecisionConfig(int operands)174 PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) {
175   PrecisionConfig precision_config;
176   precision_config.mutable_operand_precision()->Resize(
177       operands, PrecisionConfig::DEFAULT);
178   return precision_config;
179 }
180 
GetDebugOptionsForTest()181 DebugOptions HloTestBase::GetDebugOptionsForTest() {
182   auto debug_options = GetDebugOptionsFromFlags();
183   // TODO(b/38354253): Change tests to use Parameters instead of Constants.
184   debug_options.add_xla_disable_hlo_passes("constant_folding");
185   debug_options.set_xla_gpu_max_kernel_unroll_factor(1);
186   debug_options.set_xla_hlo_evaluator_use_fast_path(true);
187   return debug_options;
188 }
189 
Execute(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)190 StatusOr<Literal> HloTestBase::Execute(std::unique_ptr<HloModule> module,
191                                        absl::Span<Literal* const> arguments) {
192   return test_runner_.Execute(std::move(module), arguments);
193 }
194 
ExecuteNoHloPasses(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)195 Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
196                                         absl::Span<Literal* const> arguments) {
197   return test_runner_
198       .Execute(std::move(module), arguments,
199                /*run_hlo_passes=*/false)
200       .ValueOrDie();
201 }
202 
ExecuteAndTransfer(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)203 Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
204                                         absl::Span<Literal* const> arguments) {
205   return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
206 }
207 
ExecuteReplicated(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments,int64 num_replicas,bool use_threads)208 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
209     std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
210     int64 num_replicas, bool use_threads) {
211   HloRunner::ReplicatedExecuteOptions options;
212   options.num_replicas = num_replicas;
213   for (auto argument : arguments) {
214     options.arguments.push_back(argument);
215   }
216   return test_runner_.ExecuteReplicated(std::move(module), options,
217                                         use_threads);
218 }
219 
ExecuteReplicated(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments,int64 num_replicas,DeviceAssignment * device_assignment,bool run_hlo_passes,bool use_threads)220 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
221     std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
222     int64 num_replicas, DeviceAssignment* device_assignment,
223     bool run_hlo_passes, bool use_threads) {
224   HloRunner::ReplicatedExecuteOptions options;
225   options.num_replicas = num_replicas;
226   options.run_hlo_passes = run_hlo_passes;
227   for (auto argument : arguments) {
228     options.arguments.push_back(argument);
229   }
230   return test_runner_.ExecuteReplicated(std::move(module), options,
231                                         device_assignment, use_threads);
232 }
233 
MakeReferenceModule(const HloModule & test_module,const std::function<void (HloModule *)> & reference_preprocessor)234 StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
235     const HloModule& test_module,
236     const std::function<void(HloModule*)>& reference_preprocessor) {
237   std::unique_ptr<HloModule> reference_module = test_module.Clone();
238   const auto& program_shape = GetProgramShapeWithLayout(test_module);
239 
240   if (reference_preprocessor != nullptr) {
241     reference_preprocessor(reference_module.get());
242     if (!ProgramShapesEqual(program_shape,
243                             GetProgramShapeWithLayout(*reference_module))) {
244       return InvalidArgument(
245           "reference preprocessor must not modify the program shape");
246     }
247   }
248   TF_RETURN_IF_ERROR(hlo_verifier_->Run(reference_module.get()).status());
249   return std::move(reference_module);
250 }
251 
RunAndCompareInternal(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,bool run_hlo_passes,const std::function<void (HloModule *)> & reference_preprocessor)252 StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
253     std::unique_ptr<HloModule> module,
254     const absl::Span<Literal* const> arguments,
255     const optional<ErrorSpec>& error, bool run_hlo_passes,
256     const std::function<void(HloModule*)>& reference_preprocessor) {
257   TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status());
258   TF_ASSIGN_OR_RETURN(auto reference_module,
259                       MakeReferenceModule(*module, reference_preprocessor));
260 
261   // Execute on two backends.
262   TF_ASSIGN_OR_RETURN(
263       auto test,
264       test_runner_.Execute(std::move(module), arguments, run_hlo_passes));
265   TF_ASSIGN_OR_RETURN(auto reference,
266                       reference_runner_.Execute(std::move(reference_module),
267                                                 arguments, run_hlo_passes));
268   return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test,
269                                       error);
270 }
271 
RunAndCompare(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)272 ::testing::AssertionResult HloTestBase::RunAndCompare(
273     std::unique_ptr<HloModule> module,
274     const absl::Span<Literal* const> arguments,
275     const optional<ErrorSpec>& error,
276     const std::function<void(HloModule*)>& reference_preprocessor) {
277   auto result =
278       RunAndCompareInternal(std::move(module), arguments, error,
279                             /*run_hlo_passes=*/true, reference_preprocessor);
280   if (!result.ok()) {
281     return ::testing::AssertionFailure() << result.status();
282   }
283   return result.ValueOrDie();
284 }
285 
RunAndCompareNoHloPasses(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)286 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
287     std::unique_ptr<HloModule> module,
288     const absl::Span<Literal* const> arguments,
289     const optional<ErrorSpec>& error,
290     const std::function<void(HloModule*)>& reference_preprocessor) {
291   auto result =
292       RunAndCompareInternal(std::move(module), arguments, error,
293                             /*run_hlo_passes=*/false, reference_preprocessor);
294   if (!result.ok()) {
295     return ::testing::AssertionFailure() << result.status();
296   }
297   return result.ValueOrDie();
298 }
299 
RunAndCompare(std::unique_ptr<HloModule> module,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)300 ::testing::AssertionResult HloTestBase::RunAndCompare(
301     std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
302     const std::function<void(HloModule*)>& reference_preprocessor) {
303   auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie();
304 
305   std::vector<Literal*> fake_argument_ptrs;
306   absl::c_transform(
307       fake_arguments, std::back_inserter(fake_argument_ptrs),
308       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
309 
310   return RunAndCompare(std::move(module), fake_argument_ptrs, error,
311                        reference_preprocessor);
312 }
313 
RunAndCompareNoHloPasses(std::unique_ptr<HloModule> module,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)314 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
315     std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
316     const std::function<void(HloModule*)>& reference_preprocessor) {
317   const auto& fake_arguments =
318       MakeFakeArguments(module.get()).ConsumeValueOrDie();
319   std::vector<Literal*> fake_argument_ptrs;
320   absl::c_transform(
321       fake_arguments, std::back_inserter(fake_argument_ptrs),
322       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
323 
324   return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error,
325                                   reference_preprocessor);
326 }
327 
RunAndCompare(string_view hlo_string,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)328 ::testing::AssertionResult HloTestBase::RunAndCompare(
329     string_view hlo_string, const absl::optional<ErrorSpec>& error,
330     const std::function<void(HloModule*)>& reference_preprocessor) {
331   auto module_or_status =
332       HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
333   if (!module_or_status.ok()) {
334     return ::testing::AssertionFailure()
335            << "Error while parsing HLO text format: "
336            << module_or_status.status().ToString();
337   }
338   return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
339                        reference_preprocessor);
340 }
341 
Run(string_view hlo_string,bool run_hlo_passes,ExecutionProfile * profile,string backend_config)342 ::testing::AssertionResult HloTestBase::Run(string_view hlo_string,
343                                             bool run_hlo_passes,
344                                             ExecutionProfile* profile,
345                                             string backend_config) {
346   auto module_or_status =
347       HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
348   if (!module_or_status.ok()) {
349     return ::testing::AssertionFailure()
350            << "Error while parsing HLO text format: "
351            << module_or_status.status().ToString();
352   }
353 
354   std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
355   const auto& fake_arguments =
356       MakeFakeArguments(module.get()).ConsumeValueOrDie();
357   std::vector<Literal*> fake_argument_ptrs;
358   absl::c_transform(
359       fake_arguments, std::back_inserter(fake_argument_ptrs),
360       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
361 
362   if (profile != nullptr) {
363     // We have to enable HLO profiling since otherwise currently the
364     // ExecutionProfile is not correct.
365     //
366     // TODO(b/119432044): Fix collection of the ExecutionProfile
367     // so that this is not necessary.
368     HloModuleConfig config = module->config();
369     DebugOptions debug_options = config.debug_options();
370     debug_options.set_xla_hlo_profile(true);
371     config.set_debug_options(debug_options);
372     module->set_config(config);
373   }
374 
375   if (!backend_config.empty()) {
376     // Set backend configuration if it is given.
377     HloInstruction* instruction =
378         module->entry_computation()->root_instruction();
379     instruction->set_raw_backend_config_string(backend_config);
380   }
381 
382   // return ::testing::AssertionSuccess();
383   auto output = test_runner_.Execute(std::move(module), fake_argument_ptrs,
384                                      /*run_hlo_passes=*/run_hlo_passes,
385                                      /*profile=*/profile);
386 
387   return output.ok()
388              ? ::testing::AssertionSuccess()
389              : ::testing::AssertionFailure() << output.status().error_message();
390 }
391 
RunMultipleTimes(string_view hlo_string,bool run_hlo_passes,std::vector<ExecutionProfile> * profiles,string backend_config)392 ::testing::AssertionResult HloTestBase::RunMultipleTimes(
393     string_view hlo_string, bool run_hlo_passes,
394     std::vector<ExecutionProfile>* profiles, string backend_config) {
395   int n = profiles->size();
396   std::vector<std::vector<Literal*>> fake_argument_ptrs(n);
397   std::vector<std::vector<Literal>> fake_arguments(n);
398   std::vector<std::unique_ptr<Executable>> executables(n);
399 
400   for (int i = 0; i < n; ++i) {
401     auto module_or_status =
402         HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
403     if (!module_or_status.ok()) {
404       return ::testing::AssertionFailure()
405              << "Error while parsing HLO text format: "
406              << module_or_status.status().ToString();
407     }
408     std::unique_ptr<HloModule> module =
409         std::move(module_or_status.ValueOrDie());
410 
411     fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie();
412     absl::c_transform(
413         fake_arguments[i], std::back_inserter(fake_argument_ptrs[i]),
414         [](const Literal& literal) { return const_cast<Literal*>(&literal); });
415 
416     if (profiles != nullptr) {
417       // We have to enable HLO profiling since otherwise currently the
418       // ExecutionProfile is not correct.
419       //
420       // TODO(b/119432044): Fix collection of the ExecutionProfile
421       // so that this is not necessary.
422       HloModuleConfig config = module->config();
423       DebugOptions debug_options = config.debug_options();
424       debug_options.set_xla_hlo_profile(true);
425       config.set_debug_options(debug_options);
426       module->set_config(config);
427     }
428 
429     if (!backend_config.empty()) {
430       // Set backend configuration if it is given.
431       HloInstruction* instruction =
432           module->entry_computation()->root_instruction();
433       instruction->set_raw_backend_config_string(backend_config);
434     }
435 
436     auto executable =
437         test_runner_.CreateExecutable(std::move(module), run_hlo_passes);
438     if (!executable.ok()) {
439       return ::testing::AssertionFailure()
440              << executable.status().error_message();
441     }
442     executables[i] = std::move(executable.ValueOrDie());
443   }
444 
445   for (int i = 0; i < n; ++i) {
446     auto output =
447         test_runner_.Execute(std::move(executables[i]), fake_argument_ptrs[i],
448                              /*profile=*/&((*profiles)[i]));
449     if (!output.ok()) {
450       return ::testing::AssertionFailure() << output.status().error_message();
451     }
452   }
453 
454   return ::testing::AssertionSuccess();
455 }
456 
RunAndCompareFromFile(const string & filename,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)457 ::testing::AssertionResult HloTestBase::RunAndCompareFromFile(
458     const string& filename, const absl::optional<ErrorSpec>& error,
459     const std::function<void(HloModule*)>& reference_preprocessor) {
460   auto module_or_status =
461       HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
462   if (!module_or_status.ok()) {
463     return ::testing::AssertionFailure()
464            << "failed reading hlo module from file";
465   }
466   return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
467                        reference_preprocessor);
468 }
469 
RunAndCompareNoHloPasses(string_view hlo_string,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)470 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
471     string_view hlo_string, const absl::optional<ErrorSpec>& error,
472     const std::function<void(HloModule*)>& reference_preprocessor) {
473   auto module_or_status =
474       HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
475   if (!module_or_status.ok()) {
476     return ::testing::AssertionFailure()
477            << "Error while parsing HLO text format: "
478            << module_or_status.status().ToString();
479   }
480   return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
481                                   reference_preprocessor);
482 }
483 
RunAndCompareNoHloPassesFromFile(const string & filename,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)484 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile(
485     const string& filename, const absl::optional<ErrorSpec>& error,
486     const std::function<void(HloModule*)>& reference_preprocessor) {
487   auto module_or_status =
488       HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
489   if (!module_or_status.ok()) {
490     return ::testing::AssertionFailure()
491            << "failed reading hlo module from file";
492   }
493   return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
494                                   reference_preprocessor);
495 }
496 
FindComputation(HloModule * module,absl::string_view name)497 HloComputation* HloTestBase::FindComputation(HloModule* module,
498                                              absl::string_view name) {
499   auto computations = module->computations();
500   auto it = absl::c_find_if(
501       computations, [&](HloComputation* c) { return c->name() == name; });
502   if (it == computations.end()) {
503     return nullptr;
504   }
505   return *it;
506 }
507 
FindInstruction(HloModule * module,absl::string_view name)508 HloInstruction* HloTestBase::FindInstruction(HloModule* module,
509                                              absl::string_view name) {
510   for (const HloComputation* c : module->computations()) {
511     auto instructions = c->instructions();
512     auto it = absl::c_find_if(
513         instructions, [&](HloInstruction* i) { return i->name() == name; });
514     if (it != instructions.end()) {
515       return *it;
516     }
517   }
518   return nullptr;
519 }
520 
backend()521 Backend& HloTestBase::backend() { return test_runner_.backend(); }
522 
523 /* static */
TestName()524 string HloTestBase::TestName() {
525   return ::testing::UnitTest::GetInstance()->current_test_info()->name();
526 }
527 
528 }  // namespace xla
529