1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
35
36 namespace xla {
37 namespace {
38
// Platform name handed to PlatformUtil::GetPlatform() to select the
// interpreter backend, which this file uses as the reference implementation.
constexpr char kInterpreter[]{"interpreter"};
41
42 // Wrapper function that creates a nicer error message (than a bare
43 // ValueOrDie()) if the platform we intend to test is not available.
GetOrCreateLocalClientOrDie(const LocalClientOptions & client_options)44 LocalClient* GetOrCreateLocalClientOrDie(
45 const LocalClientOptions& client_options) {
46 StatusOr<LocalClient*> result =
47 ClientLibrary::GetOrCreateLocalClient(client_options);
48 TF_CHECK_OK(result.status()) << " could not create local client for testing";
49 return result.ValueOrDie();
50 }
51
52 // Helper functions to get the reference platform.
GetReferencePlatform()53 se::Platform* GetReferencePlatform() {
54 auto result = PlatformUtil::GetPlatform(kInterpreter);
55 TF_CHECK_OK(result.status()) << "could not get interpreter platform";
56 return result.ValueOrDie();
57 }
58
59 } // namespace
60
ClientLibraryTestBase(se::Platform * platform,const LocalClientOptions & client_options)61 ClientLibraryTestBase::ClientLibraryTestBase(
62 se::Platform* platform, const LocalClientOptions& client_options)
63 : client_(GetOrCreateLocalClientOrDie(client_options)),
64 execution_options_(CreateDefaultExecutionOptions()) {
65 CHECK_EQ(platform, client_options.platform());
66
67 LocalClientOptions ref_options;
68 ref_options.set_platform(GetReferencePlatform());
69 ref_client_ = GetOrCreateLocalClientOrDie(ref_options);
70
71 // Disabling constant_folding so that tests (usually written using Constants)
72 // will exercise the intended code paths, instead of being constant folded.
73 //
74 // TODO(b/38354253): Constant folding is currently disabled. Change tests to
75 // use Parameters instead of Constants, and re-enable constant folding by
76 // default.
77 execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
78 "constant_folding");
79
80 execution_options_.mutable_debug_options()
81 ->set_xla_hlo_evaluator_use_fast_path(true);
82 }
83
ClientLibraryTestBase(se::Platform * platform)84 ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform)
85 : execution_options_(CreateDefaultExecutionOptions()) {
86 LocalClientOptions default_options;
87 default_options.set_platform(platform);
88 client_ = GetOrCreateLocalClientOrDie(default_options);
89
90 LocalClientOptions ref_options;
91 ref_options.set_platform(GetReferencePlatform());
92 ref_client_ = GetOrCreateLocalClientOrDie(ref_options);
93
94 execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
95 "constant_folding");
96
97 execution_options_.mutable_debug_options()
98 ->set_xla_hlo_evaluator_use_fast_path(true);
99 }
100
TestName() const101 string ClientLibraryTestBase::TestName() const {
102 return ::testing::UnitTest::GetInstance()->current_test_info()->name();
103 }
104
Execute(XlaBuilder * builder,absl::Span<GlobalData * const> arguments)105 StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
106 XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
107 // Build the computation, as a convenience.
108 TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
109 return client_->Execute(computation, arguments, &execution_options_);
110 }
111
ExecuteAndTransfer(const XlaComputation & computation,absl::Span<GlobalData * const> arguments,const Shape * shape_with_output_layout)112 StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
113 const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
114 const Shape* shape_with_output_layout) {
115 ExecutionOptions execution_options = execution_options_;
116 if (shape_with_output_layout != nullptr) {
117 *execution_options.mutable_shape_with_output_layout() =
118 shape_with_output_layout->ToProto();
119 }
120 return client_->ExecuteAndTransfer(computation, arguments,
121 &execution_options);
122 }
123
ExecuteAndTransfer(XlaBuilder * builder,absl::Span<GlobalData * const> arguments,const Shape * shape_with_output_layout)124 StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
125 XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
126 const Shape* shape_with_output_layout) {
127 // Build the computation, as a convenience.
128 TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
129 return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
130 }
131
ExecuteAndTransferReference(const XlaComputation & computation,absl::Span<GlobalData * const> arguments,const Shape * shape_with_output_layout)132 StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference(
133 const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
134 const Shape* shape_with_output_layout) {
135 ExecutionOptions execution_options = execution_options_;
136 if (shape_with_output_layout != nullptr) {
137 *execution_options.mutable_shape_with_output_layout() =
138 shape_with_output_layout->ToProto();
139 }
140 execution_options.clear_device_handles();
141 return ref_client_->ExecuteAndTransfer(computation, arguments,
142 &execution_options);
143 }
144
ExecuteToString(XlaBuilder * builder,absl::Span<GlobalData * const> arguments)145 string ClientLibraryTestBase::ExecuteToString(
146 XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
147 auto computation_status = builder->Build();
148 if (!computation_status.ok()) {
149 return computation_status.status().ToString();
150 }
151 auto computation = computation_status.ConsumeValueOrDie();
152
153 auto result =
154 client_->ExecuteAndTransfer(computation, arguments, &execution_options_);
155 if (!result.ok()) {
156 return result.status().ToString();
157 } else {
158 return result.ValueOrDie().ToString();
159 }
160 }
161
ComputeAndCompareR1(XlaBuilder * builder,const tensorflow::core::Bitmap & expected,absl::Span<GlobalData * const> arguments)162 void ClientLibraryTestBase::ComputeAndCompareR1(
163 XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
164 absl::Span<GlobalData* const> arguments) {
165 Literal expected_literal = LiteralUtil::CreateR1(expected);
166 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
167 arguments);
168 }
169
ComputeAndCompareLiteral(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments,const Shape * shape_with_layout)170 void ClientLibraryTestBase::ComputeAndCompareLiteral(
171 XlaBuilder* builder, const Literal& expected,
172 absl::Span<GlobalData* const> arguments, const Shape* shape_with_layout) {
173 EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
174 shape_with_layout));
175 }
176
ComputeAndCompareLiteral(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error,const Shape * shape_with_layout)177 void ClientLibraryTestBase::ComputeAndCompareLiteral(
178 XlaBuilder* builder, const Literal& expected,
179 absl::Span<GlobalData* const> arguments, ErrorSpec error,
180 const Shape* shape_with_layout) {
181 EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
182 error, shape_with_layout));
183 }
184
ComputeAndCompareLiteralWithAllOutputLayouts(const xla::XlaComputation & computation,const Literal & expected,absl::Span<GlobalData * const> arguments,const std::function<void (const Literal & actual,const string & error_message)> & verify_output)185 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
186 const xla::XlaComputation& computation, const Literal& expected,
187 absl::Span<GlobalData* const> arguments,
188 const std::function<void(const Literal& actual,
189 const string& error_message)>& verify_output) {
190 // Try with no layout requirement.
191 TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments));
192 verify_output(actual, "");
193
194 // Try with all output layouts.
195 std::vector<int64> minor_to_major(expected.shape().rank());
196 std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
197 do {
198 auto layout = ShapeUtil::MakeShapeWithLayout(
199 expected.shape().element_type(),
200 AsInt64Slice(expected.shape().dimensions()), minor_to_major);
201 TF_ASSIGN_OR_RETURN(auto actual,
202 ExecuteAndTransfer(computation, arguments, &layout));
203 verify_output(actual,
204 absl::StrCat("Test with output layout: ",
205 ShapeUtil::HumanStringWithLayout(layout)));
206 } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
207 return Status::OK();
208 }
209
ComputeAndCompareLiteralWithAllInputLayouts(const xla::XlaComputation & computation,const Literal &,absl::Span<GlobalData * const> arguments,const std::function<void (const Literal & actual,const string & error_message)> & verify_output,const Shape * output_with_layout)210 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
211 const xla::XlaComputation& computation, const Literal& /*expected*/,
212 absl::Span<GlobalData* const> arguments,
213 const std::function<void(const Literal& actual,
214 const string& error_message)>& verify_output,
215 const Shape* output_with_layout) {
216 std::vector<GlobalData*> arguments_with_layout;
217 std::vector<string> layout_strings;
218 // This is a recursive function. It's an std::function instead of a lambda
219 // because it needs to capture itself. The index is the index of the argument
220 // to try all layouts for.
221 std::function<Status(int64)> choose;
222 choose = [&, this](int64 index) -> Status {
223 if (index < arguments.size()) {
224 // Try out all layouts for the operand.
225 TF_ASSIGN_OR_RETURN(auto literal,
226 client_->Transfer(*arguments[index], nullptr));
227 // Skip tuples because they don't have a rank.
228 if (literal.shape().IsTuple()) {
229 layout_strings.push_back(
230 ShapeUtil::HumanStringWithLayout(literal.shape()));
231 arguments_with_layout.push_back(arguments[index]);
232 TF_RETURN_IF_ERROR(choose(index + 1));
233 arguments_with_layout.pop_back();
234 layout_strings.pop_back();
235 return Status::OK();
236 }
237
238 std::vector<int64> minor_to_major(literal.shape().rank());
239 std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
240 do {
241 auto literal_relayout =
242 literal.Relayout(LayoutUtil::MakeLayout(minor_to_major));
243 layout_strings.push_back(
244 ShapeUtil::HumanStringWithLayout(literal_relayout.shape()));
245 TF_ASSIGN_OR_RETURN(auto data,
246 client_->TransferToServer(literal_relayout));
247 arguments_with_layout.push_back(data.get());
248 TF_RETURN_IF_ERROR(choose(index + 1));
249 arguments_with_layout.pop_back();
250 layout_strings.pop_back();
251 } while (
252 std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
253 return Status::OK();
254 }
255
256 // Every argument has an assigned layout.
257 TF_ASSIGN_OR_RETURN(
258 auto actual,
259 ExecuteAndTransfer(computation,
260 absl::Span<GlobalData* const>(arguments_with_layout),
261 output_with_layout));
262 string error_message = "Test with input layouts: ";
263 for (const auto& str : layout_strings) {
264 absl::StrAppend(&error_message, str, " ");
265 }
266 verify_output(actual, error_message);
267 return Status::OK();
268 };
269
270 return choose(0);
271 }
272
ComputeAndTransfer(XlaBuilder * builder,absl::Span<GlobalData * const> arguments_passed_in,const Shape * shape_with_layout)273 StatusOr<Literal> ClientLibraryTestBase::ComputeAndTransfer(
274 XlaBuilder* builder, absl::Span<GlobalData* const> arguments_passed_in,
275 const Shape* shape_with_layout) {
276 std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
277 arguments_passed_in.end());
278
279 // Transfer and use elements of arguments_, if the AddParam() API was used.
280 std::vector<std::unique_ptr<GlobalData>> owning_arguments;
281 if (!arguments_.empty()) {
282 CHECK(arguments.empty());
283 for (const auto& argument : arguments_) {
284 TF_ASSIGN_OR_RETURN(
285 std::unique_ptr<GlobalData> owned_argument,
286 client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
287 owning_arguments.push_back(std::move(owned_argument));
288 arguments.push_back(owning_arguments.back().get());
289 }
290 }
291
292 TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
293 return ExecuteAndTransfer(computation, arguments, shape_with_layout);
294 }
295
ComputeAndCompareLiteralWithStatus(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments_passed_in,const Shape * shape_with_layout)296 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
297 XlaBuilder* builder, const Literal& expected,
298 absl::Span<GlobalData* const> arguments_passed_in,
299 const Shape* shape_with_layout) {
300 std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
301 arguments_passed_in.end());
302
303 // Transfer and use elements of arguments_, if the AddParam() API was used.
304 std::vector<std::unique_ptr<GlobalData>> owning_arguments;
305 if (!arguments_.empty()) {
306 CHECK(arguments.empty());
307 for (const auto& argument : arguments_) {
308 TF_ASSIGN_OR_RETURN(
309 std::unique_ptr<GlobalData> owned_argument,
310 client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
311 owning_arguments.push_back(std::move(owned_argument));
312 arguments.push_back(owning_arguments.back().get());
313 }
314 }
315
316 TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
317 if (ShapeUtil::ElementIsFloating(expected.shape()) ||
318 ShapeUtil::ElementIsComplex(expected.shape())) {
319 LOG(WARNING) << "performing exact comparison of floating point numbers";
320 }
321 // We allow using a float expected literal for a bfloat16 output. In this
322 // case, we need to convert the expected literal to bfloat16.
323 const Literal* expected_ptr = &expected;
324 Literal converted_expected;
325 Shape layout_shape;
326 if (use_bfloat16_) {
327 converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
328 expected_ptr = &converted_expected;
329 if (shape_with_layout != nullptr) {
330 layout_shape = *shape_with_layout;
331 ShapeUtil::ForEachMutableSubshape(
332 &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) {
333 if (subshape->element_type() == F32) {
334 subshape->set_element_type(BF16);
335 }
336 });
337 shape_with_layout = &layout_shape;
338 }
339 }
340 auto expect_equal = [&](const Literal& actual, const string& error_message) {
341 EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)) << error_message;
342 };
343 if (execution_options_.debug_options().xla_test_all_output_layouts()) {
344 return ComputeAndCompareLiteralWithAllOutputLayouts(
345 computation, *expected_ptr, arguments, expect_equal);
346 }
347 if (execution_options_.debug_options().xla_test_all_input_layouts()) {
348 return ComputeAndCompareLiteralWithAllInputLayouts(
349 computation, *expected_ptr, arguments, expect_equal, shape_with_layout);
350 }
351 TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
352 shape_with_layout));
353 EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual));
354 return Status::OK();
355 }
356
ComputeAndCompareLiteralWithStatus(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments_passed_in,ErrorSpec error,const Shape * shape_with_layout)357 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
358 XlaBuilder* builder, const Literal& expected,
359 absl::Span<GlobalData* const> arguments_passed_in, ErrorSpec error,
360 const Shape* shape_with_layout) {
361 std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
362 arguments_passed_in.end());
363
364 // Transfer and use elements of arguments_, if the AddParam() API was used.
365 std::vector<std::unique_ptr<GlobalData>> owning_arguments;
366 if (!arguments_.empty()) {
367 CHECK(arguments.empty());
368 for (const auto& argument : arguments_) {
369 TF_ASSIGN_OR_RETURN(
370 std::unique_ptr<GlobalData> owned_argument,
371 client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
372 owning_arguments.push_back(std::move(owned_argument));
373 arguments.push_back(owning_arguments.back().get());
374 }
375 }
376
377 TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
378 // We allow using a float expected literal for a bfloat16 output. In this
379 // case, we need to convert the expected literal to bfloat16.
380 const Literal* expected_ptr = &expected;
381 Literal converted_expected;
382 Shape layout_shape;
383 if (use_bfloat16_) {
384 converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
385 expected_ptr = &converted_expected;
386 if (shape_with_layout != nullptr) {
387 layout_shape = *shape_with_layout;
388 ShapeUtil::ForEachMutableSubshape(
389 &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) {
390 if (subshape->element_type() == F32) {
391 subshape->set_element_type(BF16);
392 }
393 });
394 shape_with_layout = &layout_shape;
395 }
396 }
397 auto expect_near = [&](const Literal& actual, const string& error_message) {
398 EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error))
399 << error_message;
400 };
401 if (execution_options_.debug_options().xla_test_all_output_layouts()) {
402 return ComputeAndCompareLiteralWithAllOutputLayouts(
403 computation, *expected_ptr, arguments, expect_near);
404 }
405 if (execution_options_.debug_options().xla_test_all_input_layouts()) {
406 return ComputeAndCompareLiteralWithAllInputLayouts(
407 computation, *expected_ptr, arguments, expect_near, shape_with_layout);
408 }
409 TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
410 shape_with_layout));
411 EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error));
412 return Status::OK();
413 }
414
ComputeAndCompareR1U8(XlaBuilder * builder,absl::string_view expected,absl::Span<GlobalData * const> arguments)415 void ClientLibraryTestBase::ComputeAndCompareR1U8(
416 XlaBuilder* builder, absl::string_view expected,
417 absl::Span<GlobalData* const> arguments) {
418 auto actual_status = ExecuteAndTransfer(builder, arguments);
419 EXPECT_IS_OK(actual_status.status());
420 if (!actual_status.ok()) {
421 return;
422 }
423 auto actual = actual_status.ConsumeValueOrDie();
424
425 // Turn the expected value into a literal.
426 Literal expected_literal = LiteralUtil::CreateR1U8(expected);
427
428 VLOG(1) << "expected: " << expected_literal.ToString();
429 VLOG(1) << "actual: " << actual.ToString();
430
431 EXPECT_EQ(expected, actual.GetR1U8AsString());
432 }
433
ComputeAndCompareTuple(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments)434 void ClientLibraryTestBase::ComputeAndCompareTuple(
435 XlaBuilder* builder, const Literal& expected,
436 absl::Span<GlobalData* const> arguments) {
437 auto actual_status = ExecuteAndTransfer(builder, arguments);
438 EXPECT_IS_OK(actual_status.status());
439 if (!actual_status.ok()) {
440 return;
441 }
442 auto actual = actual_status.ConsumeValueOrDie();
443 EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
444 }
445
ComputeAndCompareTuple(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)446 void ClientLibraryTestBase::ComputeAndCompareTuple(
447 XlaBuilder* builder, const Literal& expected,
448 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
449 auto actual_status = ExecuteAndTransfer(builder, arguments);
450 EXPECT_IS_OK(actual_status.status());
451 if (!actual_status.ok()) {
452 return;
453 }
454 auto actual = actual_status.ConsumeValueOrDie();
455 EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error));
456 }
457
ComputeAndCompare(XlaBuilder * builder,absl::Span<const Literal> arguments)458 void ClientLibraryTestBase::ComputeAndCompare(
459 XlaBuilder* builder, absl::Span<const Literal> arguments) {
460 auto status_or_data = ComputeValueAndReference(builder, arguments);
461 EXPECT_IS_OK(status_or_data);
462 if (!status_or_data.ok()) {
463 return;
464 }
465 Literal reference, result;
466 std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
467 EXPECT_TRUE(LiteralTestUtil::Equal(reference, result));
468 }
469
ComputeAndCompare(XlaBuilder * builder,absl::Span<const Literal> arguments,ErrorSpec error)470 void ClientLibraryTestBase::ComputeAndCompare(
471 XlaBuilder* builder, absl::Span<const Literal> arguments, ErrorSpec error) {
472 auto status_or_data = ComputeValueAndReference(builder, arguments);
473 EXPECT_IS_OK(status_or_data);
474 if (!status_or_data.ok()) {
475 return;
476 }
477 Literal reference, result;
478 std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
479 EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error));
480 }
481
482 StatusOr<std::pair<Literal, Literal>>
ComputeValueAndReference(XlaBuilder * builder,absl::Span<const Literal> arguments)483 ClientLibraryTestBase::ComputeValueAndReference(
484 XlaBuilder* builder, absl::Span<const Literal> arguments) {
485 // Transfer the arguments to the executor service. We put the unique_ptr's
486 // into a vector to keep the data alive on the service until the end of this
487 // function.
488 std::vector<std::unique_ptr<GlobalData>> argument_data;
489 std::vector<std::unique_ptr<GlobalData>> ref_argument_data;
490
491 // Use `arguments_` if the AddParam() API was used. Otherwise, use
492 // plain `arguments`.
493 if (!arguments_.empty()) {
494 CHECK_EQ(arguments.size(), 0);
495 arguments = arguments_;
496 }
497
498 for (const auto& arg : arguments) {
499 TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg.Clone()));
500 TF_ASSIGN_OR_RETURN(auto ref_data, ref_client_->TransferToServer(arg));
501 argument_data.push_back(std::move(data));
502 ref_argument_data.push_back(std::move(ref_data));
503 }
504
505 // Create raw pointers to the GlobalData for the rest of the call stack.
506 std::vector<GlobalData*> argument_data_ptr;
507 std::transform(
508 argument_data.begin(), argument_data.end(),
509 std::back_inserter(argument_data_ptr),
510 [](const std::unique_ptr<GlobalData>& data) { return data.get(); });
511 std::vector<GlobalData*> ref_argument_data_ptr;
512 std::transform(
513 ref_argument_data.begin(), ref_argument_data.end(),
514 std::back_inserter(ref_argument_data_ptr),
515 [](const std::unique_ptr<GlobalData>& data) { return data.get(); });
516
517 TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
518
519 TF_ASSIGN_OR_RETURN(auto result,
520 ExecuteAndTransfer(computation, argument_data_ptr));
521
522 TF_ASSIGN_OR_RETURN(auto reference, ExecuteAndTransferReference(
523 computation, ref_argument_data_ptr));
524
525 return std::make_pair(std::move(reference), std::move(result));
526 }
527
CreateScalarRelu()528 XlaComputation ClientLibraryTestBase::CreateScalarRelu() {
529 XlaBuilder builder("relu");
530 auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
531 auto z_value = Parameter(&builder, 0, shape, "z_value");
532 auto zero = use_bfloat16_
533 ? ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(0.0f))
534 : ConstantR0<float>(&builder, 0.0f);
535 Max(z_value, zero);
536 auto computation_status = builder.Build();
537 TF_CHECK_OK(computation_status.status());
538 return computation_status.ConsumeValueOrDie();
539 }
540
CreateScalarMax()541 XlaComputation ClientLibraryTestBase::CreateScalarMax() {
542 XlaBuilder builder("max");
543 auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
544 auto x = Parameter(&builder, 0, shape, "x");
545 auto y = Parameter(&builder, 1, shape, "y");
546 Max(x, y);
547 auto computation_status = builder.Build();
548 TF_CHECK_OK(computation_status.status());
549 return computation_status.ConsumeValueOrDie();
550 }
551
CreateScalarReluSensitivity()552 XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() {
553 XlaBuilder builder("relu_sensitivity");
554 auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
555 auto activation = Parameter(&builder, 0, shape, "activation");
556 auto backprop = Parameter(&builder, 1, shape, "backprop");
557 auto zero = use_bfloat16_
558 ? ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(0.0f))
559 : ConstantR0<float>(&builder, 0.0f);
560 auto activation_gtz = Gt(activation, zero);
561 Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero);
562
563 auto computation_status = builder.Build();
564 TF_CHECK_OK(computation_status.status());
565 return computation_status.ConsumeValueOrDie();
566 }
567
CreatePatternedMatrix(int rows,int cols,float offset)568 std::unique_ptr<Array2D<float>> ClientLibraryTestBase::CreatePatternedMatrix(
569 int rows, int cols, float offset) {
570 auto array = absl::make_unique<Array2D<float>>(rows, cols);
571 for (int64 row = 0; row < rows; ++row) {
572 for (int64 col = 0; col < cols; ++col) {
573 (*array)(row, col) = col + (row * 1000.0f) + offset;
574 }
575 }
576 return array;
577 }
578
579 std::unique_ptr<Array2D<float>>
CreatePatternedMatrixWithZeroPadding(int rows,int cols,int rows_padded,int cols_padded)580 ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols,
581 int rows_padded,
582 int cols_padded) {
583 CHECK_GE(rows_padded, rows);
584 CHECK_GE(cols_padded, cols);
585 auto array = absl::make_unique<Array2D<float>>(rows_padded, cols_padded, 0.0);
586 for (int64 row = 0; row < rows; ++row) {
587 for (int64 col = 0; col < cols; ++col) {
588 (*array)(row, col) = col + (row * 1000.0f);
589 }
590 }
591 return array;
592 }
593
AddParam(const Literal & argument,XlaBuilder * builder)594 XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
595 XlaBuilder* builder) {
596 arguments_.push_back(argument.Clone());
597 return Parameter(builder, /*parameter_number=*/arguments_.size() - 1,
598 MaybeConvertShapeToBfloat16(argument.shape()), "");
599 }
600
CreateConstantFromLiteral(const Literal & literal,XlaBuilder * builder)601 XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
602 XlaBuilder* builder) {
603 return ConstantLiteral(builder, use_bfloat16_
604 ? LiteralUtil::ConvertF32ToBF16(literal)
605 : LiteralSlice(literal));
606 }
607
608 std::unique_ptr<GlobalData>
CreateParameterAndTransferLiteral(int64 parameter_number,const Literal & literal,const string & name,XlaBuilder * builder,XlaOp * data_handle)609 ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number,
610 const Literal& literal,
611 const string& name,
612 XlaBuilder* builder,
613 XlaOp* data_handle) {
614 return CreateParameterAndTransferLiteral(parameter_number, literal, name,
615 nullptr, builder, data_handle);
616 }
617
// Returns a copy of `shape`. When use_bfloat16_ is set, every F32 subshape
// (including those nested in tuples) is retyped to BF16.
Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) {
  Shape converted(shape);
  if (use_bfloat16_) {
    ShapeUtil::ForEachMutableSubshape(
        &converted, [](Shape* subshape, const ShapeIndex& /*index*/) {
          if (subshape->element_type() != F32) {
            return;
          }
          subshape->set_element_type(BF16);
        });
  }
  return converted;
}
631
// Returns `literal` converted with LiteralUtil::ConvertF32ToBF16 when
// use_bfloat16_ is set, otherwise a deep copy (Clone) of it.
Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
    const Literal& literal) {
  return use_bfloat16_ ? LiteralUtil::ConvertF32ToBF16(literal)
                       : literal.Clone();
}
639
640 std::unique_ptr<GlobalData>
CreateParameterAndTransferLiteral(int64 parameter_number,const Literal & literal,const string & name,const DeviceHandle * device_handle,XlaBuilder * builder,XlaOp * data_handle)641 ClientLibraryTestBase::CreateParameterAndTransferLiteral(
642 int64 parameter_number, const Literal& literal, const string& name,
643 const DeviceHandle* device_handle, XlaBuilder* builder,
644 XlaOp* data_handle) {
645 Literal param_literal = MaybeConvertLiteralToBfloat16(literal);
646 std::unique_ptr<GlobalData> data =
647 client_->TransferToServer(param_literal, device_handle)
648 .ConsumeValueOrDie();
649 *data_handle =
650 Parameter(builder, parameter_number, param_literal.shape(), name);
651 return data;
652 }
653
654 } // namespace xla
655