1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
17
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <utility>
22
23 #include "absl/algorithm/container.h"
24 #include "absl/memory/memory.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/debug_options_flags.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/service/hlo_parser.h"
30 #include "tensorflow/compiler/xla/service/platform_util.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
34 #include "tensorflow/compiler/xla/tests/test_utils.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/core/lib/core/status_test_util.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/test.h"
39 #include "tensorflow/core/platform/types.h"
40
41 namespace xla {
42
43 namespace {
44
45 using absl::optional;
46 using absl::string_view;
47
48 constexpr char kInterpreter[] = "interpreter";
49
50 // Helper functions to get test and reference platforms.
GetReferencePlatform()51 se::Platform* GetReferencePlatform() {
52 auto result = PlatformUtil::GetPlatform(kInterpreter);
53 TF_CHECK_OK(result.status()) << "could not get interpreter platform";
54 return result.ValueOrDie();
55 }
56
GetTestPlatform()57 se::Platform* GetTestPlatform() {
58 auto result = PlatformUtil::GetDefaultPlatform();
59 TF_CHECK_OK(result.status()) << "could not get test platform";
60 return result.ValueOrDie();
61 }
62
ProgramShapesEqual(const ProgramShape & lhs,const ProgramShape & rhs)63 bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) {
64 if (lhs.parameters_size() != rhs.parameters_size()) {
65 return false;
66 }
67 for (int i = 0; i < lhs.parameters_size(); i++) {
68 if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) {
69 return false;
70 }
71 }
72 return ShapeUtil::Equal(lhs.result(), rhs.result());
73 }
74
GetProgramShapeWithLayout(const HloModule & module)75 ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
76 ProgramShape program_shape;
77 const auto* entry = module.entry_computation();
78 for (const auto* param : entry->parameter_instructions()) {
79 *program_shape.add_parameters() = param->shape();
80 *program_shape.add_parameter_names() = param->name();
81 }
82 *program_shape.mutable_result() = entry->root_instruction()->shape();
83 return program_shape;
84 }
85
86 } // namespace
87
HloTestBase(bool verifier_layout_sensitive,bool allow_mixed_precision_in_hlo_verifier,std::function<bool (const HloInstruction *)> instruction_can_change_layout_func)88 HloTestBase::HloTestBase(bool verifier_layout_sensitive,
89 bool allow_mixed_precision_in_hlo_verifier,
90 std::function<bool(const HloInstruction*)>
91 instruction_can_change_layout_func)
92 : HloTestBase(GetTestPlatform(), GetReferencePlatform(),
93 verifier_layout_sensitive,
94 allow_mixed_precision_in_hlo_verifier,
95 instruction_can_change_layout_func) {}
96
HloTestBase(se::Platform * test_platform,se::Platform * reference_platform,bool verifier_layout_sensitive,bool allow_mixed_precision_in_hlo_verifier,std::function<bool (const HloInstruction *)> instruction_can_change_layout_func)97 HloTestBase::HloTestBase(se::Platform* test_platform,
98 se::Platform* reference_platform,
99 bool verifier_layout_sensitive,
100 bool allow_mixed_precision_in_hlo_verifier,
101 std::function<bool(const HloInstruction*)>
102 instruction_can_change_layout_func)
103 : test_runner_(test_platform),
104 reference_runner_(reference_platform),
105 verifier_layout_sensitive_(verifier_layout_sensitive),
106 allow_mixed_precision_in_hlo_verifier_(
107 allow_mixed_precision_in_hlo_verifier) {
108 hlo_verifier_ = absl::make_unique<HloVerifier>(
109 /*layout_sensitive=*/verifier_layout_sensitive,
110 /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier,
111 instruction_can_change_layout_func);
112 }
113
CreateNewUnverifiedModule(const string & name)114 std::unique_ptr<HloModule> HloTestBase::CreateNewUnverifiedModule(
115 const string& name) {
116 return absl::make_unique<HloModule>(name, GetModuleConfigForTest());
117 }
118
CreateNewVerifiedModule(const string & name,int64 replica_count)119 std::unique_ptr<VerifiedHloModule> HloTestBase::CreateNewVerifiedModule(
120 const string& name, int64 replica_count) {
121 return absl::make_unique<VerifiedHloModule>(
122 name, GetModuleConfigForTest(replica_count), verifier_layout_sensitive_,
123 allow_mixed_precision_in_hlo_verifier_,
124 backend().compiler()->ShapeSizeBytesFunction());
125 }
126
127 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,int64 replica_count)128 HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
129 int64 replica_count) {
130 return ParseAndReturnVerifiedModule(hlo_text,
131 GetModuleConfigForTest(replica_count));
132 }
133
134 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,const HloModuleConfig & config)135 HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
136 const HloModuleConfig& config) {
137 auto module = absl::make_unique<VerifiedHloModule>(
138 TestName(), config, verifier_layout_sensitive_,
139 allow_mixed_precision_in_hlo_verifier_,
140 backend().compiler()->ShapeSizeBytesFunction());
141 TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
142 return std::move(module);
143 }
144
145 /* static */
RunHloPass(HloPassInterface * hlo_pass,HloModule * module)146 StatusOr<bool> HloTestBase::RunHloPass(HloPassInterface* hlo_pass,
147 HloModule* module) {
148 const string module_str_before_run = module->ToProto().ShortDebugString();
149 const auto status_or = hlo_pass->Run(module);
150 if (status_or.status().ok()) {
151 const string module_str_after_run = module->ToProto().ShortDebugString();
152 if (!status_or.ValueOrDie()) {
153 // Check that the proto remains same.
154 EXPECT_EQ(module_str_after_run, module_str_before_run);
155 }
156 }
157 return status_or;
158 }
159
160 /* static */
DefaultPrecisionConfig(int operands)161 PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) {
162 PrecisionConfig precision_config;
163 precision_config.mutable_operand_precision()->Resize(
164 operands, PrecisionConfig::DEFAULT);
165 return precision_config;
166 }
167
SetAotFastMathDebugOptions(DebugOptions * options)168 void HloTestBase::SetAotFastMathDebugOptions(DebugOptions* options) {
169 options->set_xla_cpu_enable_fast_math(true);
170 options->set_xla_gpu_enable_fast_min_max(true);
171 options->set_xla_cpu_enable_fast_min_max(true);
172 options->set_xla_cpu_fast_math_honor_nans(false);
173 options->set_xla_cpu_fast_math_honor_infs(false);
174 options->set_xla_cpu_fast_math_honor_functions(false);
175 options->set_xla_cpu_fast_math_honor_division(false);
176 }
177
GetDebugOptionsForTest()178 DebugOptions HloTestBase::GetDebugOptionsForTest() {
179 auto debug_options = GetDebugOptionsFromFlags();
180 // TODO(b/38354253): Change tests to use Parameters instead of Constants.
181 debug_options.add_xla_disable_hlo_passes("constant_folding");
182 debug_options.set_xla_gpu_max_kernel_unroll_factor(1);
183 debug_options.set_xla_hlo_evaluator_use_fast_path(true);
184 return debug_options;
185 }
186
Execute(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)187 StatusOr<Literal> HloTestBase::Execute(std::unique_ptr<HloModule> module,
188 absl::Span<Literal* const> arguments) {
189 return test_runner_.Execute(std::move(module), arguments);
190 }
191
ExecuteNoHloPasses(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)192 Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
193 absl::Span<Literal* const> arguments) {
194 return test_runner_
195 .Execute(std::move(module), arguments,
196 /*run_hlo_passes=*/false)
197 .ValueOrDie();
198 }
199
ExecuteAndTransfer(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)200 Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
201 absl::Span<Literal* const> arguments) {
202 return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
203 }
204
ExecuteReplicated(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments,int64 num_replicas,bool use_threads,bool run_hlo_passes)205 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
206 std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
207 int64 num_replicas, bool use_threads, bool run_hlo_passes) {
208 HloRunner::ReplicatedExecuteOptions options;
209 options.num_replicas = num_replicas;
210 options.run_hlo_passes = run_hlo_passes;
211 options.use_threads = use_threads;
212 for (auto argument : arguments) {
213 options.arguments.push_back(argument);
214 }
215 return test_runner_.ExecuteReplicated(std::move(module), options);
216 }
217
ExecuteReplicated(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments,int64 num_replicas,DeviceAssignment * device_assignment,bool run_hlo_passes,bool use_threads)218 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
219 std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
220 int64 num_replicas, DeviceAssignment* device_assignment,
221 bool run_hlo_passes, bool use_threads) {
222 HloRunner::ReplicatedExecuteOptions options;
223 options.num_replicas = num_replicas;
224 options.run_hlo_passes = run_hlo_passes;
225 options.use_threads = use_threads;
226 for (auto argument : arguments) {
227 options.arguments.push_back(argument);
228 }
229 return test_runner_.ExecuteReplicated(std::move(module), options,
230 device_assignment);
231 }
232
ExecuteReplicated(std::function<Executable * (int64)> executable_provider,std::function<int64 (int64)> argument_count_provider,std::function<const Literal * (int64,int64)> argument_provider,int64 num_replicas,bool run_hlo_passes)233 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
234 std::function<Executable*(int64)> executable_provider,
235 std::function<int64(int64)> argument_count_provider,
236 std::function<const Literal*(int64, int64)> argument_provider,
237 int64 num_replicas, bool run_hlo_passes) {
238 HloRunner::ReplicatedExecuteOptions options;
239 options.num_replicas = num_replicas;
240 options.run_hlo_passes = run_hlo_passes;
241 options.use_threads = true;
242 return test_runner_.ExecuteReplicated(
243 executable_provider, argument_count_provider, argument_provider, options);
244 }
245
MakeReferenceModule(const HloModule & test_module,const std::function<void (HloModule *)> & reference_preprocessor)246 StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
247 const HloModule& test_module,
248 const std::function<void(HloModule*)>& reference_preprocessor) {
249 std::unique_ptr<HloModule> reference_module = test_module.Clone();
250 const auto& program_shape = GetProgramShapeWithLayout(test_module);
251
252 if (reference_preprocessor != nullptr) {
253 reference_preprocessor(reference_module.get());
254 if (!ProgramShapesEqual(program_shape,
255 GetProgramShapeWithLayout(*reference_module))) {
256 return InvalidArgument(
257 "reference preprocessor must not modify the program shape");
258 }
259 }
260 TF_RETURN_IF_ERROR(hlo_verifier_->Run(reference_module.get()).status());
261 return std::move(reference_module);
262 }
263
RunAndCompareInternal(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,bool run_hlo_passes,const std::function<void (HloModule *)> & reference_preprocessor)264 StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
265 std::unique_ptr<HloModule> module,
266 const absl::Span<Literal* const> arguments,
267 const optional<ErrorSpec>& error, bool run_hlo_passes,
268 const std::function<void(HloModule*)>& reference_preprocessor) {
269 TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status());
270 TF_ASSIGN_OR_RETURN(auto reference_module,
271 MakeReferenceModule(*module, reference_preprocessor));
272
273 // Execute on two backends.
274 TF_ASSIGN_OR_RETURN(
275 auto test,
276 test_runner_.Execute(std::move(module), arguments, run_hlo_passes));
277 TF_ASSIGN_OR_RETURN(auto reference,
278 reference_runner_.Execute(std::move(reference_module),
279 arguments, run_hlo_passes));
280 return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test,
281 error);
282 }
283
RunAndCompare(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)284 ::testing::AssertionResult HloTestBase::RunAndCompare(
285 std::unique_ptr<HloModule> module,
286 const absl::Span<Literal* const> arguments,
287 const optional<ErrorSpec>& error,
288 const std::function<void(HloModule*)>& reference_preprocessor) {
289 auto result =
290 RunAndCompareInternal(std::move(module), arguments, error,
291 /*run_hlo_passes=*/true, reference_preprocessor);
292 if (!result.ok()) {
293 return ::testing::AssertionFailure() << result.status();
294 }
295 return result.ValueOrDie();
296 }
297
RunAndCompareNoHloPasses(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)298 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
299 std::unique_ptr<HloModule> module,
300 const absl::Span<Literal* const> arguments,
301 const optional<ErrorSpec>& error,
302 const std::function<void(HloModule*)>& reference_preprocessor) {
303 auto result =
304 RunAndCompareInternal(std::move(module), arguments, error,
305 /*run_hlo_passes=*/false, reference_preprocessor);
306 if (!result.ok()) {
307 return ::testing::AssertionFailure() << result.status();
308 }
309 return result.ValueOrDie();
310 }
311
RunAndCompare(std::unique_ptr<HloModule> module,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)312 ::testing::AssertionResult HloTestBase::RunAndCompare(
313 std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
314 const std::function<void(HloModule*)>& reference_preprocessor) {
315 auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie();
316
317 std::vector<Literal*> fake_argument_ptrs;
318 absl::c_transform(
319 fake_arguments, std::back_inserter(fake_argument_ptrs),
320 [](const Literal& literal) { return const_cast<Literal*>(&literal); });
321
322 return RunAndCompare(std::move(module), fake_argument_ptrs, error,
323 reference_preprocessor);
324 }
325
RunAndCompareNoHloPasses(std::unique_ptr<HloModule> module,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)326 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
327 std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
328 const std::function<void(HloModule*)>& reference_preprocessor) {
329 const auto& fake_arguments =
330 MakeFakeArguments(module.get()).ConsumeValueOrDie();
331 std::vector<Literal*> fake_argument_ptrs;
332 absl::c_transform(
333 fake_arguments, std::back_inserter(fake_argument_ptrs),
334 [](const Literal& literal) { return const_cast<Literal*>(&literal); });
335
336 return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error,
337 reference_preprocessor);
338 }
339
Run(std::unique_ptr<HloModule> module,bool run_hlo_passes)340 ::testing::AssertionResult HloTestBase::Run(std::unique_ptr<HloModule> module,
341 bool run_hlo_passes) {
342 const auto fake_arguments =
343 MakeFakeArguments(module.get()).ConsumeValueOrDie();
344 const auto change = hlo_verifier_->Run(module.get());
345 if (!change.ok()) {
346 return ::testing::AssertionFailure() << change.status();
347 }
348
349 const auto output =
350 test_runner_.Execute(std::move(module), fake_arguments, run_hlo_passes);
351 return output.ok()
352 ? ::testing::AssertionSuccess()
353 : ::testing::AssertionFailure() << output.status().error_message();
354 }
355
RunAndCompare(string_view hlo_string,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)356 ::testing::AssertionResult HloTestBase::RunAndCompare(
357 string_view hlo_string, const absl::optional<ErrorSpec>& error,
358 const std::function<void(HloModule*)>& reference_preprocessor) {
359 auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
360 if (!module_or_status.ok()) {
361 return ::testing::AssertionFailure()
362 << "Error while parsing HLO text format: "
363 << module_or_status.status().ToString();
364 }
365 return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
366 reference_preprocessor);
367 }
368
Run(string_view hlo_string,bool run_hlo_passes,ExecutionProfile * profile,string backend_config)369 ::testing::AssertionResult HloTestBase::Run(string_view hlo_string,
370 bool run_hlo_passes,
371 ExecutionProfile* profile,
372 string backend_config) {
373 auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
374 if (!module_or_status.ok()) {
375 return ::testing::AssertionFailure()
376 << "Error while parsing HLO text format: "
377 << module_or_status.status().ToString();
378 }
379
380 std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
381 const auto& fake_arguments =
382 MakeFakeArguments(module.get()).ConsumeValueOrDie();
383 std::vector<Literal*> fake_argument_ptrs;
384 absl::c_transform(
385 fake_arguments, std::back_inserter(fake_argument_ptrs),
386 [](const Literal& literal) { return const_cast<Literal*>(&literal); });
387
388 if (profile != nullptr) {
389 // We have to enable HLO profiling since otherwise currently the
390 // ExecutionProfile is not correct.
391 //
392 // TODO(b/119432044): Fix collection of the ExecutionProfile
393 // so that this is not necessary.
394 HloModuleConfig config = module->config();
395 DebugOptions debug_options = config.debug_options();
396 debug_options.set_xla_hlo_profile(true);
397 config.set_debug_options(debug_options);
398 module->set_config(config);
399 }
400
401 if (!backend_config.empty()) {
402 // Set backend configuration if it is given.
403 HloInstruction* instruction =
404 module->entry_computation()->root_instruction();
405 instruction->set_raw_backend_config_string(backend_config);
406 }
407
408 auto output = test_runner_.Execute(std::move(module), fake_argument_ptrs,
409 /*run_hlo_passes=*/run_hlo_passes,
410 /*profile=*/profile);
411
412 return output.ok()
413 ? ::testing::AssertionSuccess()
414 : ::testing::AssertionFailure() << output.status().error_message();
415 }
416
RunReplicated(string_view hlo_string,bool run_hlo_passes,int64 num_replicas,string backend_config)417 ::testing::AssertionResult HloTestBase::RunReplicated(string_view hlo_string,
418 bool run_hlo_passes,
419 int64 num_replicas,
420 string backend_config) {
421 auto module_or_status =
422 ParseAndReturnVerifiedModule(hlo_string, num_replicas);
423 if (!module_or_status.ok()) {
424 return ::testing::AssertionFailure()
425 << "Error while parsing HLO text format: "
426 << module_or_status.status().ToString();
427 }
428
429 std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
430 const auto& fake_arguments =
431 MakeFakeArguments(module.get()).ConsumeValueOrDie();
432 std::vector<Literal*> fake_argument_ptrs;
433 absl::c_transform(
434 fake_arguments, std::back_inserter(fake_argument_ptrs),
435 [](const Literal& literal) { return const_cast<Literal*>(&literal); });
436
437 if (!backend_config.empty()) {
438 // Set backend configuration if it is given.
439 HloInstruction* instruction =
440 module->entry_computation()->root_instruction();
441 instruction->set_raw_backend_config_string(backend_config);
442 }
443
444 HloRunner::ReplicatedExecuteOptions options;
445 options.num_replicas = num_replicas;
446 options.run_hlo_passes = run_hlo_passes;
447 options.use_threads = true;
448 for (auto argument : fake_argument_ptrs) {
449 options.arguments.push_back(argument);
450 }
451 auto output = test_runner_.ExecuteReplicated(std::move(module), options);
452
453 return output.ok()
454 ? ::testing::AssertionSuccess()
455 : ::testing::AssertionFailure() << output.status().error_message();
456 }
457
RunMultipleTimes(string_view hlo_string,bool run_hlo_passes,std::vector<ExecutionProfile> * profiles,string backend_config,bool assert_determinism)458 ::testing::AssertionResult HloTestBase::RunMultipleTimes(
459 string_view hlo_string, bool run_hlo_passes,
460 std::vector<ExecutionProfile>* profiles, string backend_config,
461 bool assert_determinism) {
462 int n = profiles->size();
463 std::vector<std::vector<Literal*>> fake_argument_ptrs(n);
464 std::vector<std::vector<Literal>> fake_arguments(n);
465 std::vector<std::unique_ptr<Executable>> executables(n);
466
467 for (int i = 0; i < n; ++i) {
468 auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
469 if (!module_or_status.ok()) {
470 return ::testing::AssertionFailure()
471 << "Error while parsing HLO text format: "
472 << module_or_status.status().ToString();
473 }
474 std::unique_ptr<HloModule> module =
475 std::move(module_or_status.ValueOrDie());
476
477 fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie();
478
479 if (profiles != nullptr) {
480 // We have to enable HLO profiling since otherwise currently the
481 // ExecutionProfile is not correct.
482 //
483 // TODO(b/119432044): Fix collection of the ExecutionProfile
484 // so that this is not necessary.
485 HloModuleConfig config = module->config();
486 DebugOptions debug_options = config.debug_options();
487 debug_options.set_xla_hlo_profile(true);
488 config.set_debug_options(debug_options);
489 module->set_config(config);
490 }
491
492 if (!backend_config.empty()) {
493 // Set backend configuration if it is given.
494 HloInstruction* instruction =
495 module->entry_computation()->root_instruction();
496 instruction->set_raw_backend_config_string(backend_config);
497 }
498
499 auto executable =
500 test_runner_.CreateExecutable(std::move(module), run_hlo_passes);
501 if (!executable.ok()) {
502 return ::testing::AssertionFailure()
503 << executable.status().error_message();
504 }
505 executables[i] = std::move(executable.ValueOrDie());
506 }
507
508 absl::optional<Literal> canonical_output;
509 for (int i = 0; i < n; ++i) {
510 StatusOr<Literal> output = test_runner_.ExecuteWithExecutable(
511 std::move(executables[i]), fake_arguments[i],
512 /*profile=*/&((*profiles)[i]));
513 if (!output.ok()) {
514 return ::testing::AssertionFailure() << output.status().error_message();
515 }
516
517 if (assert_determinism) {
518 if (!canonical_output.has_value()) {
519 canonical_output = output.ConsumeValueOrDie();
520 } else {
521 if (*canonical_output != output.ValueOrDie()) {
522 return ::testing::AssertionFailure()
523 << "Successive runs have returned different results: "
524 << *canonical_output << " vs. " << output.ValueOrDie();
525 }
526 }
527 }
528 }
529
530 return ::testing::AssertionSuccess();
531 }
532
RunAndCompareFromFile(const string & filename,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)533 ::testing::AssertionResult HloTestBase::RunAndCompareFromFile(
534 const string& filename, const absl::optional<ErrorSpec>& error,
535 const std::function<void(HloModule*)>& reference_preprocessor) {
536 auto module_or_status =
537 HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
538 if (!module_or_status.ok()) {
539 return ::testing::AssertionFailure()
540 << "failed reading hlo module from file";
541 }
542 return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
543 reference_preprocessor);
544 }
545
RunAndCompareNoHloPasses(string_view hlo_string,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)546 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
547 string_view hlo_string, const absl::optional<ErrorSpec>& error,
548 const std::function<void(HloModule*)>& reference_preprocessor) {
549 auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
550 if (!module_or_status.ok()) {
551 return ::testing::AssertionFailure()
552 << "Error while parsing HLO text format: "
553 << module_or_status.status().ToString();
554 }
555 return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
556 reference_preprocessor);
557 }
558
RunAndCompareNoHloPassesFromFile(const string & filename,const absl::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)559 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile(
560 const string& filename, const absl::optional<ErrorSpec>& error,
561 const std::function<void(HloModule*)>& reference_preprocessor) {
562 auto module_or_status =
563 HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
564 if (!module_or_status.ok()) {
565 return ::testing::AssertionFailure()
566 << "failed reading hlo module from file";
567 }
568 return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
569 reference_preprocessor);
570 }
571
FindComputation(HloModule * module,absl::string_view name)572 HloComputation* HloTestBase::FindComputation(HloModule* module,
573 absl::string_view name) {
574 auto computations = module->computations();
575 auto it = absl::c_find_if(
576 computations, [&](HloComputation* c) { return c->name() == name; });
577 if (it == computations.end()) {
578 return nullptr;
579 }
580 return *it;
581 }
582
FindInstruction(HloModule * module,absl::string_view name)583 HloInstruction* HloTestBase::FindInstruction(HloModule* module,
584 absl::string_view name) {
585 for (const HloComputation* c : module->computations()) {
586 auto instructions = c->instructions();
587 auto it = absl::c_find_if(
588 instructions, [&](HloInstruction* i) { return i->name() == name; });
589 if (it != instructions.end()) {
590 return *it;
591 }
592 }
593 return nullptr;
594 }
595
FindInstruction(HloModule * module,HloOpcode opcode)596 HloInstruction* HloTestBase::FindInstruction(HloModule* module,
597 HloOpcode opcode) {
598 for (const HloComputation* c : module->computations()) {
599 auto instructions = c->instructions();
600 auto it = absl::c_find_if(
601 instructions, [&](HloInstruction* i) { return i->opcode() == opcode; });
602 if (it != instructions.end()) {
603 return *it;
604 }
605 }
606 return nullptr;
607 }
608
backend()609 Backend& HloTestBase::backend() { return test_runner_.backend(); }
610
611 /* static */
TestName()612 string HloTestBase::TestName() {
613 return ::testing::UnitTest::GetInstance()->current_test_info()->name();
614 }
615
616 } // namespace xla
617