1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
17
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <utility>
22
23 #include "absl/algorithm/container.h"
24 #include "absl/memory/memory.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/debug_options_flags.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/service/hlo_parser.h"
30 #include "tensorflow/compiler/xla/service/platform_util.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
34 #include "tensorflow/compiler/xla/tests/test_utils.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/core/lib/core/status_test_util.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/test.h"
39 #include "tensorflow/core/platform/types.h"
40
41 namespace xla {
42
43 namespace {
44
45 using absl::optional;
46 using absl::string_view;
47
// Platform name passed to PlatformUtil::GetPlatform() by
// GetReferencePlatform() to look up the reference platform.
constexpr char kInterpreter[] = "interpreter";
49
50 // Helper functions to get test and reference platforms.
GetReferencePlatform()51 se::Platform* GetReferencePlatform() {
52 auto result = PlatformUtil::GetPlatform(kInterpreter);
53 TF_CHECK_OK(result.status()) << "could not get interpreter platform";
54 return result.ValueOrDie();
55 }
56
GetTestPlatform()57 se::Platform* GetTestPlatform() {
58 auto result = PlatformUtil::GetDefaultPlatform();
59 TF_CHECK_OK(result.status()) << "could not get test platform";
60 return result.ValueOrDie();
61 }
62
ProgramShapesEqual(const ProgramShape & lhs,const ProgramShape & rhs)63 bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) {
64 if (lhs.parameters_size() != rhs.parameters_size()) {
65 return false;
66 }
67 for (int i = 0; i < lhs.parameters_size(); i++) {
68 if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) {
69 return false;
70 }
71 }
72 return ShapeUtil::Equal(lhs.result(), rhs.result());
73 }
74
GetProgramShapeWithLayout(const HloModule & module)75 ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
76 ProgramShape program_shape;
77 const auto* entry = module.entry_computation();
78 for (const auto* param : entry->parameter_instructions()) {
79 *program_shape.add_parameters() = param->shape();
80 *program_shape.add_parameter_names() = param->name();
81 }
82 *program_shape.mutable_result() = entry->root_instruction()->shape();
83 return program_shape;
84 }
85
86 } // namespace
87
HloTestBase(bool verifier_layout_sensitive,bool allow_mixed_precision_in_hlo_verifier,std::function<bool (const HloInstruction *)> instruction_can_change_layout_func)88 HloTestBase::HloTestBase(bool verifier_layout_sensitive,
89 bool allow_mixed_precision_in_hlo_verifier,
90 std::function<bool(const HloInstruction*)>
91 instruction_can_change_layout_func)
92 : HloTestBase(GetTestPlatform(), GetReferencePlatform(),
93 verifier_layout_sensitive,
94 allow_mixed_precision_in_hlo_verifier,
95 instruction_can_change_layout_func) {}
96
HloTestBase(se::Platform * test_platform,se::Platform * reference_platform,bool verifier_layout_sensitive,bool allow_mixed_precision_in_hlo_verifier,std::function<bool (const HloInstruction *)> instruction_can_change_layout_func)97 HloTestBase::HloTestBase(se::Platform* test_platform,
98 se::Platform* reference_platform,
99 bool verifier_layout_sensitive,
100 bool allow_mixed_precision_in_hlo_verifier,
101 std::function<bool(const HloInstruction*)>
102 instruction_can_change_layout_func)
103 : test_runner_(test_platform),
104 reference_runner_(reference_platform),
105 verifier_layout_sensitive_(verifier_layout_sensitive),
106 allow_mixed_precision_in_hlo_verifier_(
107 allow_mixed_precision_in_hlo_verifier) {
108 hlo_verifier_ = absl::make_unique<HloVerifier>(
109 /*layout_sensitive=*/verifier_layout_sensitive,
110 /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier,
111 instruction_can_change_layout_func);
112 }
113
CreateNewUnverifiedModule(const string & name)114 std::unique_ptr<HloModule> HloTestBase::CreateNewUnverifiedModule(
115 const string& name) {
116 return absl::make_unique<HloModule>(name, GetModuleConfigForTest());
117 }
118
CreateNewVerifiedModule(const string & name)119 std::unique_ptr<VerifiedHloModule> HloTestBase::CreateNewVerifiedModule(
120 const string& name) {
121 return absl::make_unique<VerifiedHloModule>(
122 name, GetModuleConfigForTest(), verifier_layout_sensitive_,
123 allow_mixed_precision_in_hlo_verifier_,
124 backend().compiler()->ShapeSizeBytesFunction());
125 }
126
127 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text)128 HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) {
129 return ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest());
130 }
131
132 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,const HloModuleConfig & config)133 HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
134 const HloModuleConfig& config) {
135 auto module = absl::make_unique<VerifiedHloModule>(
136 TestName(), config, verifier_layout_sensitive_,
137 allow_mixed_precision_in_hlo_verifier_,
138 backend().compiler()->ShapeSizeBytesFunction());
139 TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
140 return std::move(module);
141 }
142
143 /* static */
RunHloPass(HloPassInterface * hlo_pass,HloModule * module)144 StatusOr<bool> HloTestBase::RunHloPass(HloPassInterface* hlo_pass,
145 HloModule* module) {
146 const string module_str_before_run = module->ToProto().ShortDebugString();
147 const auto status_or = hlo_pass->Run(module);
148 if (status_or.status().ok()) {
149 const string module_str_after_run = module->ToProto().ShortDebugString();
150 if (!status_or.ValueOrDie()) {
151 // Check that the proto remains same.
152 EXPECT_EQ(module_str_after_run, module_str_before_run);
153 }
154 }
155 return status_or;
156 }
157
158 /* static */
DefaultPrecisionConfig(int operands)159 PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) {
160 PrecisionConfig precision_config;
161 precision_config.mutable_operand_precision()->Resize(
162 operands, PrecisionConfig::DEFAULT);
163 return precision_config;
164 }
165
GetDebugOptionsForTest()166 DebugOptions HloTestBase::GetDebugOptionsForTest() {
167 auto debug_options = GetDebugOptionsFromFlags();
168 // TODO(b/38354253): Change tests to use Parameters instead of Constants.
169 debug_options.add_xla_disable_hlo_passes("constant_folding");
170 debug_options.set_xla_gpu_max_kernel_unroll_factor(1);
171 debug_options.set_xla_hlo_evaluator_use_fast_path(true);
172 return debug_options;
173 }
174
Execute(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)175 StatusOr<Literal> HloTestBase::Execute(std::unique_ptr<HloModule> module,
176 absl::Span<Literal* const> arguments) {
177 return test_runner_.Execute(std::move(module), arguments);
178 }
179
ExecuteNoHloPasses(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)180 Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
181 absl::Span<Literal* const> arguments) {
182 return test_runner_
183 .Execute(std::move(module), arguments,
184 /*run_hlo_passes=*/false)
185 .ValueOrDie();
186 }
187
ExecuteAndTransfer(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)188 Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
189 absl::Span<Literal* const> arguments) {
190 return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
191 }
192
ExecuteReplicated(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments,int64 num_replicas,bool use_threads,bool run_hlo_passes)193 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
194 std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
195 int64 num_replicas, bool use_threads, bool run_hlo_passes) {
196 HloRunner::ReplicatedExecuteOptions options;
197 options.num_replicas = num_replicas;
198 options.run_hlo_passes = run_hlo_passes;
199 options.use_threads = use_threads;
200 for (auto argument : arguments) {
201 options.arguments.push_back(argument);
202 }
203 return test_runner_.ExecuteReplicated(std::move(module), options);
204 }
205
ExecuteReplicated(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments,int64 num_replicas,DeviceAssignment * device_assignment,bool run_hlo_passes,bool use_threads)206 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
207 std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
208 int64 num_replicas, DeviceAssignment* device_assignment,
209 bool run_hlo_passes, bool use_threads) {
210 HloRunner::ReplicatedExecuteOptions options;
211 options.num_replicas = num_replicas;
212 options.run_hlo_passes = run_hlo_passes;
213 options.use_threads = use_threads;
214 for (auto argument : arguments) {
215 options.arguments.push_back(argument);
216 }
217 return test_runner_.ExecuteReplicated(std::move(module), options,
218 device_assignment);
219 }
220
MakeReferenceModule(const HloModule & test_module,const std::function<void (HloModule *)> & reference_preprocessor)221 StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
222 const HloModule& test_module,
223 const std::function<void(HloModule*)>& reference_preprocessor) {
224 std::unique_ptr<HloModule> reference_module = test_module.Clone();
225 const auto& program_shape = GetProgramShapeWithLayout(test_module);
226
227 if (reference_preprocessor != nullptr) {
228 reference_preprocessor(reference_module.get());
229 if (!ProgramShapesEqual(program_shape,
230 GetProgramShapeWithLayout(*reference_module))) {
231 return InvalidArgument(
232 "reference preprocessor must not modify the program shape");
233 }
234 }
235 TF_RETURN_IF_ERROR(hlo_verifier_->Run(reference_module.get()).status());
236 return std::move(reference_module);
237 }
238
RunAndCompareInternal(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,bool run_hlo_passes,const std::function<void (HloModule *)> & reference_preprocessor)239 StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
240 std::unique_ptr<HloModule> module,
241 const absl::Span<Literal* const> arguments,
242 const optional<ErrorSpec>& error, bool run_hlo_passes,
243 const std::function<void(HloModule*)>& reference_preprocessor) {
244 TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status());
245 TF_ASSIGN_OR_RETURN(auto reference_module,
246 MakeReferenceModule(*module, reference_preprocessor));
247
248 // Execute on two backends.
249 TF_ASSIGN_OR_RETURN(
250 auto test,
251 test_runner_.Execute(std::move(module), arguments, run_hlo_passes));
252 TF_ASSIGN_OR_RETURN(auto reference,
253 reference_runner_.Execute(std::move(reference_module),
254 arguments, run_hlo_passes));
255 return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test,
256 error);
257 }
258
RunAndCompare(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)259 ::testing::AssertionResult HloTestBase::RunAndCompare(
260 std::unique_ptr<HloModule> module,
261 const absl::Span<Literal* const> arguments,
262 const optional<ErrorSpec>& error,
263 const std::function<void(HloModule*)>& reference_preprocessor) {
264 auto result =
265 RunAndCompareInternal(std::move(module), arguments, error,
266 /*run_hlo_passes=*/true, reference_preprocessor);
267 if (!result.ok()) {
268 return ::testing::AssertionFailure() << result.status();
269 }
270 return result.ValueOrDie();
271 }
272
RunAndCompareNoHloPasses(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)273 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
274 std::unique_ptr<HloModule> module,
275 const absl::Span<Literal* const> arguments,
276 const optional<ErrorSpec>& error,
277 const std::function<void(HloModule*)>& reference_preprocessor) {
278 auto result =
279 RunAndCompareInternal(std::move(module), arguments, error,
280 /*run_hlo_passes=*/false, reference_preprocessor);
281 if (!result.ok()) {
282 return ::testing::AssertionFailure() << result.status();
283 }
284 return result.ValueOrDie();
285 }
286
RunAndCompare(std::unique_ptr<HloModule> module,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)287 ::testing::AssertionResult HloTestBase::RunAndCompare(
288 std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
289 const std::function<void(HloModule*)>& reference_preprocessor) {
290 auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie();
291
292 std::vector<Literal*> fake_argument_ptrs;
293 absl::c_transform(
294 fake_arguments, std::back_inserter(fake_argument_ptrs),
295 [](const Literal& literal) { return const_cast<Literal*>(&literal); });
296
297 return RunAndCompare(std::move(module), fake_argument_ptrs, error,
298 reference_preprocessor);
299 }
300
RunAndCompareNoHloPasses(std::unique_ptr<HloModule> module,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)301 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
302 std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
303 const std::function<void(HloModule*)>& reference_preprocessor) {
304 const auto& fake_arguments =
305 MakeFakeArguments(module.get()).ConsumeValueOrDie();
306 std::vector<Literal*> fake_argument_ptrs;
307 absl::c_transform(
308 fake_arguments, std::back_inserter(fake_argument_ptrs),
309 [](const Literal& literal) { return const_cast<Literal*>(&literal); });
310
311 return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error,
312 reference_preprocessor);
313 }
314
RunAndCompare(string_view hlo_string,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)315 ::testing::AssertionResult HloTestBase::RunAndCompare(
316 string_view hlo_string, const absl::optional<ErrorSpec>& error,
317 const std::function<void(HloModule*)>& reference_preprocessor) {
318 auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
319 if (!module_or_status.ok()) {
320 return ::testing::AssertionFailure()
321 << "Error while parsing HLO text format: "
322 << module_or_status.status().ToString();
323 }
324 return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
325 reference_preprocessor);
326 }
327
Run(string_view hlo_string,bool run_hlo_passes,ExecutionProfile * profile,string backend_config)328 ::testing::AssertionResult HloTestBase::Run(string_view hlo_string,
329 bool run_hlo_passes,
330 ExecutionProfile* profile,
331 string backend_config) {
332 auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
333 if (!module_or_status.ok()) {
334 return ::testing::AssertionFailure()
335 << "Error while parsing HLO text format: "
336 << module_or_status.status().ToString();
337 }
338
339 std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
340 const auto& fake_arguments =
341 MakeFakeArguments(module.get()).ConsumeValueOrDie();
342 std::vector<Literal*> fake_argument_ptrs;
343 absl::c_transform(
344 fake_arguments, std::back_inserter(fake_argument_ptrs),
345 [](const Literal& literal) { return const_cast<Literal*>(&literal); });
346
347 if (profile != nullptr) {
348 // We have to enable HLO profiling since otherwise currently the
349 // ExecutionProfile is not correct.
350 //
351 // TODO(b/119432044): Fix collection of the ExecutionProfile
352 // so that this is not necessary.
353 HloModuleConfig config = module->config();
354 DebugOptions debug_options = config.debug_options();
355 debug_options.set_xla_hlo_profile(true);
356 config.set_debug_options(debug_options);
357 module->set_config(config);
358 }
359
360 if (!backend_config.empty()) {
361 // Set backend configuration if it is given.
362 HloInstruction* instruction =
363 module->entry_computation()->root_instruction();
364 instruction->set_raw_backend_config_string(backend_config);
365 }
366
367 auto output = test_runner_.Execute(std::move(module), fake_argument_ptrs,
368 /*run_hlo_passes=*/run_hlo_passes,
369 /*profile=*/profile);
370
371 return output.ok()
372 ? ::testing::AssertionSuccess()
373 : ::testing::AssertionFailure() << output.status().error_message();
374 }
375
RunMultipleTimes(string_view hlo_string,bool run_hlo_passes,std::vector<ExecutionProfile> * profiles,string backend_config,bool assert_determinism)376 ::testing::AssertionResult HloTestBase::RunMultipleTimes(
377 string_view hlo_string, bool run_hlo_passes,
378 std::vector<ExecutionProfile>* profiles, string backend_config,
379 bool assert_determinism) {
380 int n = profiles->size();
381 std::vector<std::vector<Literal*>> fake_argument_ptrs(n);
382 std::vector<std::vector<Literal>> fake_arguments(n);
383 std::vector<std::unique_ptr<Executable>> executables(n);
384
385 for (int i = 0; i < n; ++i) {
386 auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
387 if (!module_or_status.ok()) {
388 return ::testing::AssertionFailure()
389 << "Error while parsing HLO text format: "
390 << module_or_status.status().ToString();
391 }
392 std::unique_ptr<HloModule> module =
393 std::move(module_or_status.ValueOrDie());
394
395 fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie();
396 absl::c_transform(
397 fake_arguments[i], std::back_inserter(fake_argument_ptrs[i]),
398 [](const Literal& literal) { return const_cast<Literal*>(&literal); });
399
400 if (profiles != nullptr) {
401 // We have to enable HLO profiling since otherwise currently the
402 // ExecutionProfile is not correct.
403 //
404 // TODO(b/119432044): Fix collection of the ExecutionProfile
405 // so that this is not necessary.
406 HloModuleConfig config = module->config();
407 DebugOptions debug_options = config.debug_options();
408 debug_options.set_xla_hlo_profile(true);
409 config.set_debug_options(debug_options);
410 module->set_config(config);
411 }
412
413 if (!backend_config.empty()) {
414 // Set backend configuration if it is given.
415 HloInstruction* instruction =
416 module->entry_computation()->root_instruction();
417 instruction->set_raw_backend_config_string(backend_config);
418 }
419
420 auto executable =
421 test_runner_.CreateExecutable(std::move(module), run_hlo_passes);
422 if (!executable.ok()) {
423 return ::testing::AssertionFailure()
424 << executable.status().error_message();
425 }
426 executables[i] = std::move(executable.ValueOrDie());
427 }
428
429 absl::optional<Literal> canonical_output;
430 for (int i = 0; i < n; ++i) {
431 StatusOr<Literal> output =
432 test_runner_.Execute(std::move(executables[i]), fake_argument_ptrs[i],
433 /*profile=*/&((*profiles)[i]));
434 if (!output.ok()) {
435 return ::testing::AssertionFailure() << output.status().error_message();
436 }
437
438 if (assert_determinism) {
439 if (!canonical_output.has_value()) {
440 canonical_output = output.ConsumeValueOrDie();
441 } else {
442 if (*canonical_output != output.ValueOrDie()) {
443 return ::testing::AssertionFailure()
444 << "Successive runs have returned different results: "
445 << *canonical_output << " vs. " << output.ValueOrDie();
446 }
447 }
448 }
449 }
450
451 return ::testing::AssertionSuccess();
452 }
453
RunAndCompareFromFile(const string & filename,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)454 ::testing::AssertionResult HloTestBase::RunAndCompareFromFile(
455 const string& filename, const absl::optional<ErrorSpec>& error,
456 const std::function<void(HloModule*)>& reference_preprocessor) {
457 auto module_or_status =
458 HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
459 if (!module_or_status.ok()) {
460 return ::testing::AssertionFailure()
461 << "failed reading hlo module from file";
462 }
463 return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
464 reference_preprocessor);
465 }
466
RunAndCompareNoHloPasses(string_view hlo_string,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)467 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
468 string_view hlo_string, const absl::optional<ErrorSpec>& error,
469 const std::function<void(HloModule*)>& reference_preprocessor) {
470 auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
471 if (!module_or_status.ok()) {
472 return ::testing::AssertionFailure()
473 << "Error while parsing HLO text format: "
474 << module_or_status.status().ToString();
475 }
476 return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
477 reference_preprocessor);
478 }
479
RunAndCompareNoHloPassesFromFile(const string & filename,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)480 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile(
481 const string& filename, const absl::optional<ErrorSpec>& error,
482 const std::function<void(HloModule*)>& reference_preprocessor) {
483 auto module_or_status =
484 HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
485 if (!module_or_status.ok()) {
486 return ::testing::AssertionFailure()
487 << "failed reading hlo module from file";
488 }
489 return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
490 reference_preprocessor);
491 }
492
FindComputation(HloModule * module,absl::string_view name)493 HloComputation* HloTestBase::FindComputation(HloModule* module,
494 absl::string_view name) {
495 auto computations = module->computations();
496 auto it = absl::c_find_if(
497 computations, [&](HloComputation* c) { return c->name() == name; });
498 if (it == computations.end()) {
499 return nullptr;
500 }
501 return *it;
502 }
503
FindInstruction(HloModule * module,absl::string_view name)504 HloInstruction* HloTestBase::FindInstruction(HloModule* module,
505 absl::string_view name) {
506 for (const HloComputation* c : module->computations()) {
507 auto instructions = c->instructions();
508 auto it = absl::c_find_if(
509 instructions, [&](HloInstruction* i) { return i->name() == name; });
510 if (it != instructions.end()) {
511 return *it;
512 }
513 }
514 return nullptr;
515 }
516
FindInstruction(HloModule * module,HloOpcode opcode)517 HloInstruction* HloTestBase::FindInstruction(HloModule* module,
518 HloOpcode opcode) {
519 for (const HloComputation* c : module->computations()) {
520 auto instructions = c->instructions();
521 auto it = absl::c_find_if(
522 instructions, [&](HloInstruction* i) { return i->opcode() == opcode; });
523 if (it != instructions.end()) {
524 return *it;
525 }
526 }
527 return nullptr;
528 }
529
backend()530 Backend& HloTestBase::backend() { return test_runner_.backend(); }
531
532 /* static */
TestName()533 string HloTestBase::TestName() {
534 return ::testing::UnitTest::GetInstance()->current_test_info()->name();
535 }
536
537 } // namespace xla
538